Compare commits

..

1 Commits

Author SHA1 Message Date
Zoltán Papp
dc8c2edf50 Revert "[client] Add TTL-based refresh to mgmt DNS cache via handler chain (#5945)"
This reverts commit 801de8c68d.
2026-04-23 21:29:46 +02:00
45 changed files with 5535 additions and 2828 deletions

View File

@@ -1,5 +0,0 @@
{
"name": "issue-resolution",
"private": true,
"type": "module"
}

View File

@@ -1,32 +0,0 @@
You are a GitHub issue resolution classifier.
Your job is to decide whether an open GitHub issue is:
- AUTO_CLOSE
- MANUAL_REVIEW
- KEEP_OPEN
Rules:
1. AUTO_CLOSE is only allowed if there is objective, hard evidence:
- a merged linked PR that clearly resolves the issue, or
- an explicit maintainer/member/owner/collaborator comment saying the issue is fixed, resolved, duplicate, or superseded
2. If there is any contradictory later evidence, do NOT AUTO_CLOSE.
3. If evidence is promising but not airtight, choose MANUAL_REVIEW.
4. If the issue still appears active or unresolved, choose KEEP_OPEN.
5. Do not invent evidence.
6. Output valid JSON only.
Maintainer-authoritative roles:
- MEMBER
- OWNER
- COLLABORATOR
Workarounds vs. actual fixes:
- A WORKAROUND is when a user changes their own setup to avoid the problem (editing configs, using a different setting, manual SQL fixes, switching tools). Workarounds do NOT count as resolution — the underlying issue is still present in the product.
- An ACTUAL FIX is when a user reports the problem went away after upgrading to a specific version (e.g., "fixed after updating to v0.65.1") or after a specific PR was merged. This suggests the fix was shipped in the product itself.
- If only workarounds exist and no maintainer has confirmed a fix, classify as KEEP_OPEN.
- If a user reports an actual fix via a version upgrade but no maintainer confirmed it, classify as MANUAL_REVIEW (not AUTO_CLOSE).
Important:
- Later comments outweigh earlier ones.
- A non-maintainer saying "fixed for me" is not enough for AUTO_CLOSE.
- If uncertain, prefer MANUAL_REVIEW or KEEP_OPEN.

View File

@@ -1,80 +0,0 @@
{
"type": "object",
"additionalProperties": false,
"required": [
"decision",
"reason_code",
"confidence",
"hard_signals",
"contradictions",
"summary",
"close_comment",
"manual_review_note"
],
"properties": {
"decision": {
"type": "string",
"enum": ["AUTO_CLOSE", "MANUAL_REVIEW", "KEEP_OPEN"]
},
"reason_code": {
"type": "string",
"enum": [
"resolved_by_merged_pr",
"maintainer_confirmed_resolved",
"duplicate_confirmed",
"superseded_confirmed",
"likely_fixed_but_unconfirmed",
"still_open",
"unclear"
]
},
"confidence": {
"type": "number",
"minimum": 0,
"maximum": 1
},
"hard_signals": {
"type": "array",
"items": {
"type": "object",
"additionalProperties": false,
"required": ["type", "url"],
"properties": {
"type": {
"type": "string",
"enum": [
"merged_pr",
"maintainer_comment",
"duplicate_reference",
"superseded_reference"
]
},
"url": { "type": "string" }
}
}
},
"contradictions": {
"type": "array",
"items": {
"type": "object",
"additionalProperties": false,
"required": ["type", "url"],
"properties": {
"type": {
"type": "string",
"enum": [
"reporter_still_broken",
"later_unresolved_comment",
"ambiguous_pr_link",
"other"
]
},
"url": { "type": "string" }
}
}
},
"summary": { "type": "string" },
"close_comment": { "type": "string" },
"manual_review_note": { "type": "string" }
}
}

View File

@@ -1,193 +0,0 @@
import fs from "node:fs/promises";
// Decisions written by the classification step, read from the working directory.
const decisionsJson = await fs.readFile("decisions.json", "utf8");
const decisions = JSON.parse(decisionsJson);

// Only the literal string "true" (any casing) in DRY_RUN enables dry-run mode.
const dryRun = `${process.env.DRY_RUN}`.toLowerCase() === "true";

// Header fields shared by every GitHub API request made from this script.
const baseHeaders = {
  Accept: "application/vnd.github+json",
  "X-GitHub-Api-Version": "2022-11-28",
};

// REST calls authenticate with GH_TOKEN.
const ghHeaders = {
  Authorization: `Bearer ${process.env.GH_TOKEN}`,
  ...baseHeaders,
};

// Project board (GraphQL) calls prefer PROJECT_PAT and fall back to GH_TOKEN.
const projectHeaders = {
  Authorization: `Bearer ${process.env.PROJECT_PAT || process.env.GH_TOKEN}`,
  ...baseHeaders,
};
/**
 * Send an authenticated request to the GitHub REST API.
 *
 * @param {string} url - Absolute endpoint URL.
 * @param {string} [method="GET"] - HTTP method.
 * @param {object} [body] - Payload to send as JSON; omitted when falsy.
 * @returns {Promise<object|null>} Parsed JSON response, or null on HTTP 204.
 * @throws {Error} When the response status is not OK; the message carries the
 *   status, URL and response text.
 */
async function rest(url, method = "GET", body) {
  const init = { method, headers: ghHeaders, body: undefined };
  if (body) {
    init.body = JSON.stringify(body);
  }

  const response = await fetch(url, init);
  if (!response.ok) {
    const detail = await response.text();
    throw new Error(`${response.status} ${url}: ${detail}`);
  }

  // 204 No Content has no body to parse.
  if (response.status === 204) {
    return null;
  }
  return response.json();
}
/**
 * Execute a GraphQL operation against the GitHub API using the project headers.
 *
 * @param {string} query - GraphQL document.
 * @param {object} variables - Variables for the document.
 * @returns {Promise<object>} The `data` member of the response.
 * @throws {Error} On a non-OK HTTP status, or when the response carries `errors`.
 */
async function graphql(query, variables) {
  const payload = JSON.stringify({ query, variables });
  const response = await fetch("https://api.github.com/graphql", {
    method: "POST",
    headers: projectHeaders,
    body: payload,
  });

  if (!response.ok) {
    throw new Error(`${response.status}: ${await response.text()}`);
  }

  // GraphQL reports operation failures in the body even on HTTP 200.
  const { errors, data } = await response.json();
  if (errors) {
    throw new Error(JSON.stringify(errors));
  }
  return data;
}
/**
 * Add labels to an issue via POST to the issue's labels endpoint.
 *
 * @param {string} owner - Repository owner.
 * @param {string} repo - Repository name.
 * @param {number} issueNumber - Issue number.
 * @param {string[]} labels - Label names to add.
 * @returns {Promise<object|null>} Whatever `rest` returns for the request.
 */
async function addLabel(owner, repo, issueNumber, labels) {
  const endpoint = `https://api.github.com/repos/${owner}/${repo}/issues/${issueNumber}/labels`;
  const payload = { labels };
  return rest(endpoint, "POST", payload);
}
/**
 * Post a comment on an issue.
 *
 * @param {string} owner - Repository owner.
 * @param {string} repo - Repository name.
 * @param {number} issueNumber - Issue number.
 * @param {string} body - Comment text.
 * @returns {Promise<object|null>} Whatever `rest` returns for the request.
 */
async function addComment(owner, repo, issueNumber, body) {
  const endpoint = `https://api.github.com/repos/${owner}/${repo}/issues/${issueNumber}/comments`;
  const payload = { body };
  return rest(endpoint, "POST", payload);
}
/**
 * Close an issue, recording the state reason "completed".
 *
 * @param {string} owner - Repository owner.
 * @param {string} repo - Repository name.
 * @param {number} issueNumber - Issue number.
 * @returns {Promise<object|null>} Whatever `rest` returns for the request.
 */
async function closeIssue(owner, repo, issueNumber) {
  const endpoint = `https://api.github.com/repos/${owner}/${repo}/issues/${issueNumber}`;
  const patch = { state: "closed", state_reason: "completed" };
  return rest(endpoint, "PATCH", patch);
}
/**
 * Look up the `node_id` of an issue from the REST issue resource.
 *
 * @param {string} owner - Repository owner.
 * @param {string} repo - Repository name.
 * @param {number} issueNumber - Issue number.
 * @returns {Promise<string>} The issue's `node_id`.
 */
async function getIssueNodeId(owner, repo, issueNumber) {
  const endpoint = `https://api.github.com/repos/${owner}/${repo}/issues/${issueNumber}`;
  const { node_id: nodeId } = await rest(endpoint);
  return nodeId;
}
/**
 * Add an issue to the project identified by the PROJECT_ID environment variable.
 *
 * Best effort: any failure is logged as a warning and reported as `null`
 * instead of being thrown.
 *
 * @param {string} issueNodeId - Node ID of the issue to add.
 * @returns {Promise<string|null>} The created project item ID, or null on failure.
 */
async function addToProject(issueNodeId) {
  const mutation = `
mutation($projectId: ID!, $contentId: ID!) {
addProjectV2ItemById(input: {projectId: $projectId, contentId: $contentId}) {
item { id }
}
}
`;
  const variables = {
    projectId: process.env.PROJECT_ID,
    contentId: issueNodeId,
  };

  try {
    const result = await graphql(mutation, variables);
    return result.addProjectV2ItemById.item.id;
  } catch (err) {
    console.warn(`[WARN] Could not add to project (needs PAT with project scope): ${err.message}`);
    return null;
  }
}
/**
 * Set a text field on a project item in the project named by PROJECT_ID.
 *
 * @param {string} itemId - Project item ID.
 * @param {string} fieldId - Project field ID.
 * @param {string} value - Text to store.
 * @returns {Promise<object>} The GraphQL `data` payload.
 */
async function setTextField(itemId, fieldId, value) {
  const { PROJECT_ID: projectId } = process.env;
  const mutation = `
mutation($projectId: ID!, $itemId: ID!, $fieldId: ID!, $value: String!) {
updateProjectV2ItemFieldValue(input: {
projectId: $projectId,
itemId: $itemId,
fieldId: $fieldId,
value: { text: $value }
}) {
projectV2Item { id }
}
}
`;
  const variables = { projectId, itemId, fieldId, value };
  return graphql(mutation, variables);
}
/**
 * Set a number field on a project item in the project named by PROJECT_ID.
 *
 * @param {string} itemId - Project item ID.
 * @param {string} fieldId - Project field ID.
 * @param {number} value - Number to store (sent as a GraphQL Float).
 * @returns {Promise<object>} The GraphQL `data` payload.
 */
async function setNumberField(itemId, fieldId, value) {
  const { PROJECT_ID: projectId } = process.env;
  const mutation = `
mutation($projectId: ID!, $itemId: ID!, $fieldId: ID!, $value: Float!) {
updateProjectV2ItemFieldValue(input: {
projectId: $projectId,
itemId: $itemId,
fieldId: $fieldId,
value: { number: $value }
}) {
projectV2Item { id }
}
}
`;
  const variables = { projectId, itemId, fieldId, value };
  return graphql(mutation, variables);
}
/**
 * Add the issue behind a decision to the project board and fill in the
 * configured custom fields.
 *
 * Each field is written only when its PROJECT_*_FIELD_ID environment variable
 * is set. Nothing is written when the issue could not be added to the project.
 *
 * @param {string} owner - Repository owner.
 * @param {string} repo - Repository name.
 * @param {object} d - Decision record (reads issue_number, issue_url, repository, model).
 * @returns {Promise<void>}
 */
async function addToProjectWithFields(owner, repo, d) {
  const nodeId = await getIssueNodeId(owner, repo, d.issue_number);
  const itemId = await addToProject(nodeId);
  if (!itemId) {
    return;
  }

  // Field writes in order. Values are read lazily so a field that is not
  // configured never touches the decision record.
  // The built-in "Linked pull requests" field can't be set via API; GitHub
  // auto-populates it from issue cross-references.
  const fieldWrites = [
    ["PROJECT_CONFIDENCE_FIELD_ID", setNumberField, () => d.model.confidence],
    ["PROJECT_REASON_FIELD_ID", setTextField, () => d.model.reason_code],
    ["PROJECT_EVIDENCE_FIELD_ID", setTextField, () => d.issue_url],
    ["PROJECT_REPO_FIELD_ID", setTextField, () => d.repository],
  ];
  for (const [envName, writeField, readValue] of fieldWrites) {
    const fieldId = process.env[envName];
    if (fieldId) {
      await writeField(itemId, fieldId, readValue());
    }
  }

  console.log(` → Added to project board`);
}
// Comment posted on auto-closed issues when the model supplied no close_comment,
// so an empty comment body is never sent. Mirrors the manual-review fallback below.
const fallbackCloseComment =
  "This issue appears to be resolved, so it is being closed automatically. Please comment or reopen it if the problem persists.";

// Comment posted on manual-review candidates when the model supplied no note.
const fallbackReviewNote =
  "This issue looks like a possible resolution candidate, but not with enough certainty for automatic closure. Added to the review queue.";

// Apply every decision in file order:
//  - KEEP_OPEN is only logged;
//  - in dry-run mode the project board is populated but issues are untouched;
//  - otherwise AUTO_CLOSE labels, comments, closes and files the issue, and
//    MANUAL_REVIEW labels, files and comments on it.
// Any API failure propagates and stops the run.
for (const d of decisions) {
  const [owner, repo] = d.repository.split("/");

  if (d.final_decision === "KEEP_OPEN") {
    console.log(`#${d.issue_number} → KEEP_OPEN (confidence: ${d.model.confidence}, reason: ${d.model.reason_code})`);
    continue;
  }

  if (dryRun) {
    console.log(`[DRY RUN] #${d.issue_number} → ${d.final_decision} (confidence: ${d.model.confidence}, reason: ${d.model.reason_code})`);
    // In dry-run: populate project board but don't touch issues
    if (d.final_decision === "MANUAL_REVIEW" || d.final_decision === "AUTO_CLOSE") {
      await addToProjectWithFields(owner, repo, d);
    }
    continue;
  }

  if (d.final_decision === "AUTO_CLOSE") {
    await addLabel(owner, repo, d.issue_number, ["auto-closed-resolved"]);
    await addComment(owner, repo, d.issue_number, d.model.close_comment || fallbackCloseComment);
    await closeIssue(owner, repo, d.issue_number);
    await addToProjectWithFields(owner, repo, d);
  }

  if (d.final_decision === "MANUAL_REVIEW") {
    await addLabel(owner, repo, d.issue_number, ["resolution-candidate"]);
    await addToProjectWithFields(owner, repo, d);
    await addComment(owner, repo, d.issue_number, d.model.manual_review_note || fallbackReviewNote);
  }
}

View File

@@ -1,259 +0,0 @@
import fs from "node:fs/promises";
// Candidate issues collected by the fetch step.
const candidatesJson = await fs.readFile("candidates.json", "utf8");
const candidates = JSON.parse(candidatesJson);

// System prompt sent verbatim with every classification request.
const systemPrompt = await fs.readFile("prompts/issue-resolution-system.txt", "utf8");

// JSON schema passed to the model as the structured response format.
const outputSchemaJson = await fs.readFile("schemas/issue-resolution-output.json", "utf8");
const outputSchema = JSON.parse(outputSchemaJson);
/**
 * Tell whether an author association counts as maintainer-authoritative.
 *
 * @param {string} [role] - Author association value; falsy values are treated as "".
 * @returns {boolean} True for MEMBER, OWNER or COLLABORATOR.
 */
function isMaintainerRole(role) {
  const association = role || "";
  return (
    association === "MEMBER" ||
    association === "OWNER" ||
    association === "COLLABORATOR"
  );
}
/**
 * Scan a candidate's timeline and comments for evidence before the model runs.
 *
 * Scoring:
 *  - +40 and a `merged_pr` hard signal for a cross-referenced pull request.
 *    When the candidate carries `linked_prs` (verified PR state), only PRs
 *    listed there as merged count. Without `linked_prs` every cross-referenced
 *    PR counts, as before.
 *  - +10 for each `referenced` / `connected` timeline event.
 *  - +25 and a `maintainer_comment` hard signal for a maintainer comment
 *    containing fixed/resolved/duplicate/superseded/closing.
 *  - -50 and a `later_unresolved_comment` contradiction for any comment
 *    containing "still broken", "still happening", "not fixed" or "reproducible".
 *
 * @param {object} candidate - Candidate with `timeline`, `comments` and optionally `linked_prs`.
 * @returns {{score: number, hardSignals: Array<{type: string, url: string}>, contradictions: Array<{type: string, url: string}>}}
 */
function preScore(candidate) {
  let score = 0;
  const hardSignals = [];
  const contradictions = [];

  // URLs of PRs whose merge was verified by the fetch step, when available.
  const hasVerifiedPrs = Array.isArray(candidate.linked_prs);
  const mergedPrUrls = new Set(
    (hasVerifiedPrs ? candidate.linked_prs : [])
      .filter((pr) => pr.merged)
      .map((pr) => pr.url)
  );

  for (const t of candidate.timeline) {
    const sourceIssue = t.source?.issue;
    const prUrl = sourceIssue?.pull_request?.html_url;
    if (t.event === "cross-referenced" && prUrl) {
      const verifiedMerged = mergedPrUrls.has(prUrl) || mergedPrUrls.has(sourceIssue.html_url);
      // An open or closed-unmerged PR must not be reported as a "merged_pr"
      // hard signal: it would mislead the model and satisfy the auto-close gate.
      if (!hasVerifiedPrs || verifiedMerged) {
        hardSignals.push({
          type: "merged_pr",
          url: sourceIssue.html_url
        });
        score += 40;
      }
    }
    if (["referenced", "connected"].includes(t.event)) {
      score += 10;
    }
  }

  for (const c of candidate.comments) {
    // A comment body may be missing; treat it as empty rather than throwing.
    const body = (c.body || "").toLowerCase();
    if (
      isMaintainerRole(c.author_association) &&
      /\b(fixed|resolved|duplicate|superseded|closing)\b/.test(body)
    ) {
      score += 25;
      hardSignals.push({
        type: "maintainer_comment",
        url: c.html_url
      });
    }
    if (/\b(still broken|still happening|not fixed|reproducible)\b/.test(body)) {
      score -= 50;
      contradictions.push({
        type: "later_unresolved_comment",
        url: c.html_url
      });
    }
  }

  return { score, hardSignals, contradictions };
}
// Character budget for the user message. The figure assumes an 8000-token
// model input limit with ~2000 tokens reserved for the system prompt and
// response, at roughly 4 characters per token (6000 * 4 = 24000).
// NOTE(review): the limit was stated for gpt-4o; confirm it for the configured MODEL.
const MAX_USER_MESSAGE_CHARS = 24000;

/**
 * Cut text down to `maxChars` characters and append a truncation marker.
 * Text that already fits is returned unchanged. The marker is added on top of
 * `maxChars`, so a truncated result is longer than the limit.
 *
 * @param {string} text - Text to limit.
 * @param {number} maxChars - Maximum number of characters kept from `text`.
 * @returns {string} The original or truncated text.
 */
function truncate(text, maxChars) {
  const fits = text.length <= maxChars;
  return fits ? text : `${text.slice(0, maxChars)}\n\n[... truncated due to length]`;
}
/**
 * Render a candidate issue as the markdown-style user message for the model.
 *
 * Sections, in order: issue header, body (capped at 4000 chars), comments,
 * relevant timeline events, verified linked PRs (if any) and the automated
 * evidence scan (if it found anything). The whole message is capped at
 * MAX_USER_MESSAGE_CHARS.
 *
 * @param {object} candidate - Candidate with `issue`, `comments`, `timeline`, optional `linked_prs`.
 * @param {{hardSignals: Array, contradictions: Array}} pre - Result of the pre-score scan.
 * @returns {string} The message text.
 */
function buildUserMessage(candidate, pre) {
  const { issue, comments, timeline } = candidate;

  const commentBlock = comments
    // A missing body is rendered as empty text instead of the string "null".
    .map((c) => `[${c.author_association}] ${c.user} (${c.created_at}):\n${c.body ?? ""}`)
    .join("\n---\n");

  const timelineBlock = timeline
    .filter((t) => ["cross-referenced", "referenced", "connected", "closed", "reopened"].includes(t.event))
    .map((t) => {
      // Separate the event from its source URL so the two are not fused together.
      let line = `${t.event} (${t.created_at})`;
      if (t.source?.issue?.html_url) line += ` → ${t.source.issue.html_url}`;
      if (t.source?.issue?.pull_request?.html_url) line += ` (PR: ${t.source.issue.pull_request.html_url})`;
      return line;
    })
    .join("\n");

  const sections = [
    `## Issue #${issue.number}: ${issue.title}`,
    `URL: ${issue.html_url}`,
    `Created: ${issue.created_at} | Updated: ${issue.updated_at}`,
    `Labels: ${issue.labels.join(", ") || "none"}`,
    "",
    "### Body",
    truncate(issue.body || "(empty)", 4000),
    "",
    "### Comments",
    commentBlock || "(none)",
    "",
    "### Timeline events",
    timelineBlock || "(none)",
  ];

  if (candidate.linked_prs?.length) {
    sections.push("");
    sections.push("### Linked PRs (verified state)");
    for (const pr of candidate.linked_prs) {
      const status = pr.merged ? `MERGED (${pr.merged_at})` : pr.state.toUpperCase();
      sections.push(`- PR #${pr.number}: ${pr.title} — ${status} — ${pr.url}`);
    }
  }

  if (pre.hardSignals.length || pre.contradictions.length) {
    sections.push("");
    sections.push("### Automated evidence scan");
    for (const s of pre.hardSignals) {
      sections.push(`- SIGNAL: ${s.type} — ${s.url}`);
    }
    for (const c of pre.contradictions) {
      sections.push(`- CONTRADICTION: ${c.type} — ${c.url}`);
    }
  }

  return truncate(sections.join("\n"), MAX_USER_MESSAGE_CHARS);
}
// Model identifier sent with every chat-completion request.
const MODEL = "gpt-4o-mini";
// Upper bound used by the rate-limit retry loop.
const MAX_RETRIES = 5;

/**
 * Resolve after the given delay.
 *
 * @param {number} ms - Delay in milliseconds.
 * @returns {Promise<void>} Promise that resolves once the timer fires.
 */
function sleep(ms) {
  return new Promise((done) => {
    setTimeout(done, ms);
  });
}
/**
 * Ask the model to classify one candidate and return its parsed JSON verdict.
 *
 * Rate limiting (HTTP 429) is retried up to MAX_RETRIES times, waiting for the
 * `retry-after` interval (30s when absent or unparsable). A requested wait
 * above 120s is treated as quota exhaustion and reported as `null`.
 *
 * @param {object} candidate - Candidate issue passed to buildUserMessage.
 * @param {object} pre - Pre-score result passed to buildUserMessage.
 * @returns {Promise<object|null>} Parsed model output, or null when quota is exhausted.
 * @throws {Error} On a non-OK, non-429 response, on a response without message
 *   content, on unparsable content, or when all retries were rate limited.
 */
async function callGitHubModel(candidate, pre) {
  const body = JSON.stringify({
    model: MODEL,
    messages: [
      { role: "system", content: systemPrompt },
      { role: "user", content: buildUserMessage(candidate, pre) },
    ],
    response_format: {
      type: "json_schema",
      json_schema: {
        name: "issue_resolution",
        strict: true,
        schema: outputSchema,
      },
    },
    temperature: 0.1,
  });

  // One initial attempt plus up to MAX_RETRIES retries.
  for (let attempt = 0; attempt <= MAX_RETRIES; attempt++) {
    const res = await fetch("https://models.inference.ai.azure.com/chat/completions", {
      method: "POST",
      headers: {
        Authorization: `Bearer ${process.env.GH_TOKEN}`,
        "Content-Type": "application/json",
      },
      body,
    });

    if (res.status === 429) {
      const retryAfter = Number(res.headers.get("retry-after")) || 30;
      if (retryAfter > 120) {
        console.warn(` [QUOTA EXHAUSTED] API wants ${retryAfter}s wait — skipping remaining issues.`);
        return null;
      }
      // No retries left: fail now instead of sleeping once more for nothing.
      if (attempt === MAX_RETRIES) break;
      console.warn(` [RATE LIMITED] Waiting ${retryAfter}s (retry ${attempt + 1}/${MAX_RETRIES})...`);
      await sleep(retryAfter * 1000);
      continue;
    }

    if (!res.ok) {
      const text = await res.text();
      throw new Error(`GitHub Models ${res.status}: ${text}`);
    }

    const data = await res.json();
    const content = data?.choices?.[0]?.message?.content;
    if (typeof content !== "string") {
      // Surface the raw payload instead of an opaque TypeError.
      throw new Error(`GitHub Models: response contained no message content: ${JSON.stringify(data)}`);
    }
    return JSON.parse(content);
  }

  throw new Error(`GitHub Models: exceeded ${MAX_RETRIES} retries due to rate limiting`);
}
/**
 * Apply the hard auto-close gate on top of the model's verdict.
 *
 * AUTO_CLOSE survives only when the model's confidence is at least 0.97, the
 * reason code is one of the approved resolution reasons, at least one hard
 * signal exists (from the model or the pre-scan) and no contradiction exists
 * (from either source). A model AUTO_CLOSE that fails the gate becomes
 * MANUAL_REVIEW. Any other model decision is returned as is.
 *
 * @param {object} modelOut - Parsed model output.
 * @param {{hardSignals: Array, contradictions: Array}} pre - Pre-scan result.
 * @returns {string} The final decision.
 */
function enforcePolicy(modelOut, pre) {
  const closableReasons = [
    "resolved_by_merged_pr",
    "maintainer_confirmed_resolved",
    "duplicate_confirmed",
    "superseded_confirmed",
  ];
  const hardSignalTypes = new Set([
    "merged_pr",
    "maintainer_comment",
    "duplicate_reference",
    "superseded_reference",
  ]);

  const modelSignals = modelOut.hard_signals || [];
  const modelHasHardSignal = modelSignals.some((signal) => hardSignalTypes.has(signal.type));
  const hasHardSignal = modelHasHardSignal || pre.hardSignals.length > 0;

  const modelContradictions = modelOut.contradictions || [];
  const hasContradiction = modelContradictions.length > 0 || pre.contradictions.length > 0;

  // Anything other than AUTO_CLOSE is trusted unchanged.
  if (modelOut.decision !== "AUTO_CLOSE") {
    return modelOut.decision;
  }

  const passesGate =
    modelOut.confidence >= 0.97 &&
    closableReasons.includes(modelOut.reason_code) &&
    hasHardSignal &&
    !hasContradiction;

  // A failed gate downgrades the close to a human review.
  return passesGate ? "AUTO_CLOSE" : "MANUAL_REVIEW";
}
console.log(`Classifying ${candidates.length} candidates with ${MODEL}...\n`);
// Pacing interval between request starts. A 15 req/min limit allows one
// request every 4s; 4.5s leaves a safety margin.
const PACE_MS = 4500;
// Timestamp (ms since epoch) at which the most recent paced call was started.
let lastRequestTime = 0;

/**
 * Run `fn` no sooner than PACE_MS after the previous paced call started.
 *
 * @param {Function} fn - Function to invoke once the pacing delay has passed.
 * @returns {Promise<*>} Whatever `fn` returns.
 */
async function paced(fn) {
  const remainingMs = PACE_MS - (Date.now() - lastRequestTime);
  if (remainingMs > 0) {
    await sleep(remainingMs);
  }
  // Stamp before invoking so the interval is measured start-to-start.
  lastRequestTime = Date.now();
  return fn();
}
// Classify candidates one at a time (paced), stopping early when the model
// quota is exhausted; whatever was classified so far is still written out.
const decisions = [];
for (const candidate of candidates) {
  const { issue } = candidate;
  const pre = preScore(candidate);
  const modelOut = await paced(() => callGitHubModel(candidate, pre));

  // null means the API asked for a wait too long to be worth retrying.
  if (modelOut === null) {
    console.warn(`\nQuota exhausted after ${decisions.length} issues. Writing partial results.`);
    break;
  }

  const finalDecision = enforcePolicy(modelOut, pre);
  const record = {
    repository: candidate.repository,
    issue_number: issue.number,
    issue_url: issue.html_url,
    title: issue.title,
    pre_score: pre.score,
    final_decision: finalDecision,
    model: modelOut,
  };
  decisions.push(record);

  console.log(
    `#${issue.number} | pre_score: ${pre.score} | model: ${modelOut.decision} @ ${modelOut.confidence} | final: ${finalDecision} | ${modelOut.reason_code}`
  );
}

const serialized = JSON.stringify(decisions, null, 2);
await fs.writeFile("decisions.json", serialized);
console.log(`\nWrote ${decisions.length} decisions to decisions.json`);

View File

@@ -1,123 +0,0 @@
import fs from "node:fs/promises";
// Runtime configuration taken from the environment.
// REPO is expected in "owner/repo" form.
const { GH_TOKEN: token, REPO: repo } = process.env;

// MAX_ISSUES falls back to 100 when unset, non-numeric or zero.
const requestedMaxIssues = Number(process.env.MAX_ISSUES);
const maxIssues = requestedMaxIssues || 100;

// Headers attached to every GitHub REST request made by this script.
const headers = {
  Authorization: `Bearer ${token}`,
  Accept: "application/vnd.github+json",
  "X-GitHub-Api-Version": "2022-11-28",
};
/**
 * GET a GitHub REST resource and parse it as JSON.
 *
 * @param {string} url - Absolute endpoint URL.
 * @returns {Promise<*>} Parsed JSON response.
 * @throws {Error} When the response status is not OK; the message carries the
 *   status, URL and response text.
 */
async function rest(url) {
  const response = await fetch(url, { headers });
  if (response.ok) {
    return response.json();
  }
  const detail = await response.text();
  throw new Error(`${response.status} ${url}: ${detail}`);
}
/**
 * GET a GitHub REST resource, reporting an unsuccessful status as `null`
 * instead of throwing. Network-level failures from `fetch` still propagate.
 *
 * @param {string} url - Absolute endpoint URL.
 * @returns {Promise<*|null>} Parsed JSON response, or null when the status is not OK.
 */
async function restSafe(url) {
  const response = await fetch(url, { headers });
  return response.ok ? response.json() : null;
}
/**
 * Collect up to `max` items from a page-numbered GitHub list endpoint.
 *
 * The page size is fixed for the whole walk. Changing `per_page` between
 * requests while incrementing `page` shifts the page boundaries, which makes
 * later pages repeat items that were already returned.
 *
 * @param {string} url - List endpoint, with or without an existing query string.
 * @param {number} max - Maximum number of items to return.
 * @returns {Promise<Array>} At most `max` items, in API order.
 */
async function paginate(url, max) {
  const items = [];
  // One constant page size: as large as allowed, but no larger than needed.
  const perPage = Math.min(100, max);
  const sep = url.includes("?") ? "&" : "?";

  for (let page = 1; items.length < max; page++) {
    const batch = await rest(`${url}${sep}per_page=${perPage}&page=${page}`);
    if (!batch.length) break;
    items.push(...batch);
  }

  // The last page may overshoot `max`; trim the surplus.
  return items.slice(0, max);
}
console.log(`Fetching up to ${maxIssues} open issues from ${repo}...`);
const issues = await paginate(
  `https://api.github.com/repos/${repo}/issues?state=open&sort=updated&direction=desc`,
  maxIssues
);

// Filter out pull requests (GitHub API returns PRs as issues too)
const realIssues = issues.filter((i) => !i.pull_request);
console.log(`Found ${realIssues.length} open issues (excluded PRs).`);

// Build one candidate per issue: trimmed issue data, comments, timeline and
// the verified state of every cross-referenced pull request.
const candidates = [];
for (const issue of realIssues) {
  // Only the first page (up to 100 entries) of comments and timeline events is requested.
  const [comments, timeline] = await Promise.all([
    rest(`https://api.github.com/repos/${repo}/issues/${issue.number}/comments?per_page=100`),
    rest(`https://api.github.com/repos/${repo}/issues/${issue.number}/timeline?per_page=100`),
  ]);

  const candidate = {
    repository: repo,
    issue: {
      number: issue.number,
      html_url: issue.html_url,
      title: issue.title,
      body: issue.body,
      created_at: issue.created_at,
      updated_at: issue.updated_at,
      labels: issue.labels.map((l) => l.name),
    },
    comments: comments.map((c) => ({
      body: c.body,
      author_association: c.author_association,
      html_url: c.html_url,
      created_at: c.created_at,
      user: c.user?.login,
    })),
    timeline: timeline.map((t) => ({
      event: t.event,
      created_at: t.created_at,
      source: t.source
        ? {
            issue: {
              html_url: t.source.issue?.html_url,
              pull_request: t.source.issue?.pull_request
                ? { html_url: t.source.issue.pull_request.html_url }
                : undefined,
            },
          }
        : undefined,
    })),
    linked_prs: [],
  };
  candidates.push(candidate);

  // Fetch merge status for cross-referenced PRs (each PR URL once).
  const prUrls = new Set();
  for (const t of timeline) {
    const prHtml = t.source?.issue?.pull_request?.html_url;
    if (t.event === "cross-referenced" && prHtml) {
      prUrls.add(prHtml);
    }
  }

  for (const prHtml of prUrls) {
    // Extract owner/repo and PR number from URL like https://github.com/owner/repo/pull/123
    const match = prHtml.match(/github\.com\/([^/]+\/[^/]+)\/pull\/(\d+)/);
    if (!match) continue;
    const [, prRepo, prNum] = match;
    // A PR that cannot be fetched is left out rather than failing the run.
    const pr = await restSafe(`https://api.github.com/repos/${prRepo}/pulls/${prNum}`);
    if (!pr) continue;
    candidate.linked_prs.push({
      number: pr.number,
      title: pr.title,
      url: prHtml,
      state: pr.state,
      merged: pr.merged || false,
      merged_at: pr.merged_at,
    });
  }

  // Keep the issue number visually separate from the counts that follow.
  console.log(` #${issue.number} — ${comments.length} comments, ${timeline.length} timeline events, ${candidate.linked_prs.length} linked PRs`);
}

await fs.writeFile("candidates.json", JSON.stringify(candidates, null, 2));
console.log(`Wrote ${candidates.length} candidates to candidates.json`);

View File

@@ -1,65 +0,0 @@
# Classifies open issues and labels, comments on or closes them.
# Runs on pushes to the feature branch, on manual dispatch and nightly.
name: issue-resolution-triage

on:
  push:
    branches: [github-issue-resolver]
  workflow_dispatch:
    inputs:
      dry_run:
        description: "If true, do not close issues"
        required: false
        default: "true"
      max_issues:
        description: "How many issues to process"
        required: false
        default: "100"
  schedule:
    - cron: "17 2 * * *"

permissions:
  contents: read
  issues: write
  pull-requests: read
  models: read

# todo: remove hardcoded values
jobs:
  triage:
    runs-on: ubuntu-latest
    env:
      GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
      PROJECT_PAT: ${{ secrets.PROJECT_PAT }}
      # Honor the workflow_dispatch inputs. Push and schedule runs have no
      # inputs, so they fall back to the safe defaults (dry run, 100 issues).
      DRY_RUN: ${{ github.event.inputs.dry_run || 'true' }}
      MAX_ISSUES: ${{ github.event.inputs.max_issues || '100' }}
      REPO: ${{ github.repository }}
      PROJECT_ID: "PVT_kwDOBfz4Jc4BVeWR"
      PROJECT_STATUS_FIELD_ID: "PVTSSF_lADOBfz4Jc4BVeWRzhQ56sU"
      PROJECT_CONFIDENCE_FIELD_ID: "PVTF_lADOBfz4Jc4BVeWRzhQ57x4"
      PROJECT_REASON_FIELD_ID: "PVTF_lADOBfz4Jc4BVeWRzhQ5-Lg"
      PROJECT_EVIDENCE_FIELD_ID: "PVTF_lADOBfz4Jc4BVeWRzhQ5-Pw"
      PROJECT_LINKED_PR_FIELD_ID: "PVTF_lADOBfz4Jc4BVeWRzhQ56sc"
      PROJECT_REPO_FIELD_ID: "PVTF_lADOBfz4Jc4BVeWRzhQ56sk"
      PROJECT_STATUS_OPTION_NEEDS_REVIEW_ID: "a55a2be9"
    defaults:
      run:
        working-directory: .github/issue-resolution
    steps:
      - uses: actions/checkout@v4
      - uses: actions/setup-node@v4
        with:
          node-version: "20"
      - run: node scripts/fetch-candidates.mjs
      - run: node scripts/classify-candidates.mjs
      - run: node scripts/apply-decisions.mjs
      # Keep the intermediate files even when a step above failed.
      - uses: actions/upload-artifact@v4
        if: always()
        with:
          name: triage-results
          path: |
            .github/issue-resolution/candidates.json
            .github/issue-resolution/decisions.json

View File

@@ -9,7 +9,7 @@ on:
pull_request:
env:
SIGN_PIPE_VER: "v0.1.4"
SIGN_PIPE_VER: "v0.1.2"
GORELEASER_VER: "v2.14.3"
PRODUCT_NAME: "NetBird"
COPYRIGHT: "NetBird GmbH"

View File

@@ -333,10 +333,6 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
c.statusRecorder.MarkSignalConnected()
relayURLs, token := parseRelayInfo(loginResp)
if override, ok := peer.OverrideRelayURLs(); ok {
log.Infof("overriding relay URLs from %s: %v", peer.EnvKeyNBHomeRelayServers, override)
relayURLs = override
}
peerConfig := loginResp.GetPeerConfig()
engineConfig, err := createEngineConfig(myPrivateKey, c.config, peerConfig, logPath)

View File

@@ -1,10 +1,7 @@
package dns
import (
"context"
"fmt"
"math"
"net"
"slices"
"strconv"
"strings"
@@ -195,12 +192,6 @@ func (c *HandlerChain) logHandlers() {
}
func (c *HandlerChain) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
c.dispatch(w, r, math.MaxInt)
}
// dispatch routes a DNS request through the chain, skipping handlers with
// priority > maxPriority. Shared by ServeDNS and ResolveInternal.
func (c *HandlerChain) dispatch(w dns.ResponseWriter, r *dns.Msg, maxPriority int) {
if len(r.Question) == 0 {
return
}
@@ -225,9 +216,6 @@ func (c *HandlerChain) dispatch(w dns.ResponseWriter, r *dns.Msg, maxPriority in
// Try handlers in priority order
for _, entry := range handlers {
if entry.Priority > maxPriority {
continue
}
if !c.isHandlerMatch(qname, entry) {
continue
}
@@ -285,55 +273,6 @@ func (c *HandlerChain) logResponse(logger *log.Entry, cw *ResponseWriterChain, q
cw.response.Len(), meta, time.Since(startTime))
}
// ResolveInternal runs an in-process DNS query against the chain, skipping any
// handler with priority > maxPriority. Used by internal callers (e.g. the mgmt
// cache refresher) that must bypass themselves to avoid loops. Honors ctx
// cancellation; on ctx.Done the dispatch goroutine is left to drain on its own
// (bounded by the invoked handler's internal timeout).
//
// It returns an error when r has no question, when ctx ends before dispatch
// completes, or when dispatch produced no response or a REFUSED response.
func (c *HandlerChain) ResolveInternal(ctx context.Context, r *dns.Msg, maxPriority int) (*dns.Msg, error) {
	// The error messages below read r.Question[0], so reject empty queries first.
	if len(r.Question) == 0 {
		return nil, fmt.Errorf("empty question")
	}
	// base captures whatever the selected handler writes.
	base := &internalResponseWriter{}
	// done is closed once dispatch has returned.
	done := make(chan struct{})
	go func() {
		c.dispatch(base, r, maxPriority)
		close(done)
	}()
	select {
	case <-done:
	case <-ctx.Done():
		// Prefer a completed response if dispatch finished concurrently with cancellation.
		select {
		case <-done:
		default:
			return nil, fmt.Errorf("resolve %s: %w", strings.ToLower(r.Question[0].Name), ctx.Err())
		}
	}
	// Reaching this point means done was closed, so base.response is only read
	// after dispatch has finished writing it.
	if base.response == nil || base.response.Rcode == dns.RcodeRefused {
		return nil, fmt.Errorf("no handler resolved %s at priority ≤ %d",
			strings.ToLower(r.Question[0].Name), maxPriority)
	}
	return base.response, nil
}
// HasRootHandlerAtOrBelow tells whether the chain holds at least one handler
// registered for the root pattern "." whose priority does not exceed
// maxPriority. The handler list is read under the chain's read lock.
func (c *HandlerChain) HasRootHandlerAtOrBelow(maxPriority int) bool {
	c.mu.RLock()
	defer c.mu.RUnlock()

	found := false
	for _, entry := range c.handlers {
		// Skip non-root patterns and root handlers above the threshold.
		if entry.Pattern != "." || entry.Priority > maxPriority {
			continue
		}
		found = true
		break
	}
	return found
}
func (c *HandlerChain) isHandlerMatch(qname string, entry HandlerEntry) bool {
switch {
case entry.Pattern == ".":
@@ -352,36 +291,3 @@ func (c *HandlerChain) isHandlerMatch(qname string, entry HandlerEntry) bool {
}
}
}
// internalResponseWriter is a dns.ResponseWriter with no connection behind it:
// it only remembers the most recent message a handler wrote, for in-process
// chain queries.
type internalResponseWriter struct {
	response *dns.Msg
}

// WriteMsg stores m as the captured response.
func (w *internalResponseWriter) WriteMsg(m *dns.Msg) error {
	w.response = m
	return nil
}

// LocalAddr returns nil.
func (w *internalResponseWriter) LocalAddr() net.Addr {
	return nil
}

// RemoteAddr returns nil.
func (w *internalResponseWriter) RemoteAddr() net.Addr {
	return nil
}

// Write decodes raw DNS wire bytes and stores the result, so that handlers
// calling Write rather than WriteMsg still surface their answer to
// ResolveInternal. Bytes that fail to unpack are rejected and the previously
// captured response is kept.
func (w *internalResponseWriter) Write(p []byte) (int, error) {
	var decoded dns.Msg
	if err := decoded.Unpack(p); err != nil {
		return 0, err
	}
	w.response = &decoded
	return len(p), nil
}

// Close does nothing and reports success.
func (w *internalResponseWriter) Close() error {
	return nil
}

// TsigStatus always reports success.
func (w *internalResponseWriter) TsigStatus() error {
	return nil
}

// TsigTimersOnly is part of dns.ResponseWriter.
func (w *internalResponseWriter) TsigTimersOnly(bool) {
	// no-op: in-process queries carry no TSIG state.
}

// Hijack is part of dns.ResponseWriter.
func (w *internalResponseWriter) Hijack() {
	// no-op: in-process queries have no underlying connection to hand off.
}

View File

@@ -1,15 +1,11 @@
package dns_test
import (
"context"
"net"
"testing"
"time"
"github.com/miekg/dns"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
nbdns "github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/dns/test"
@@ -1046,163 +1042,3 @@ func TestHandlerChain_AddRemoveRoundtrip(t *testing.T) {
})
}
}
// answeringHandler writes a fixed A record to ack the query. Used to verify
// which handler ResolveInternal dispatches to.
type answeringHandler struct {
	// name is what String reports.
	name string
	// ip is the IPv4 address placed in the answer.
	ip string
}

// ServeDNS replies to r with a single A record for the first question's name,
// carrying h.ip and a TTL of 60. The WriteMsg error is ignored.
func (h *answeringHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
	resp := &dns.Msg{}
	resp.SetReply(r)
	resp.Answer = []dns.RR{&dns.A{
		Hdr: dns.RR_Header{Name: r.Question[0].Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
		A:   net.ParseIP(h.ip).To4(),
	}}
	_ = w.WriteMsg(resp)
}

// String returns the handler's name.
func (h *answeringHandler) String() string { return h.name }
// TestHandlerChain_ResolveInternal_SkipsAboveMaxPriority registers two handlers
// for the same name, one at PriorityMgmtCache and one at PriorityUpstream, and
// checks that a query capped at PriorityUpstream is answered by the latter.
func TestHandlerChain_ResolveInternal_SkipsAboveMaxPriority(t *testing.T) {
	chain := nbdns.NewHandlerChain()
	high := &answeringHandler{name: "high", ip: "10.0.0.1"}
	low := &answeringHandler{name: "low", ip: "10.0.0.2"}
	chain.AddHandler("example.com.", high, nbdns.PriorityMgmtCache)
	chain.AddHandler("example.com.", low, nbdns.PriorityUpstream)
	r := new(dns.Msg)
	r.SetQuestion("example.com.", dns.TypeA)
	resp, err := chain.ResolveInternal(context.Background(), r, nbdns.PriorityUpstream)
	assert.NoError(t, err)
	assert.NotNil(t, resp)
	assert.Equal(t, 1, len(resp.Answer))
	a, ok := resp.Answer[0].(*dns.A)
	assert.True(t, ok)
	// 10.0.0.2 is the address served by the "low" handler.
	assert.Equal(t, "10.0.0.2", a.A.String(), "should skip mgmtCache handler and resolve via upstream")
}
// TestHandlerChain_ResolveInternal_ErrorWhenNoMatch checks that ResolveInternal
// returns an error when the only matching handler is registered at
// PriorityMgmtCache and the query is capped at PriorityUpstream.
func TestHandlerChain_ResolveInternal_ErrorWhenNoMatch(t *testing.T) {
	chain := nbdns.NewHandlerChain()
	high := &answeringHandler{name: "high", ip: "10.0.0.1"}
	chain.AddHandler("example.com.", high, nbdns.PriorityMgmtCache)
	r := new(dns.Msg)
	r.SetQuestion("example.com.", dns.TypeA)
	_, err := chain.ResolveInternal(context.Background(), r, nbdns.PriorityUpstream)
	assert.Error(t, err, "no handler at or below maxPriority should error")
}
// rawWriteHandler packs a response and calls ResponseWriter.Write directly
// (instead of WriteMsg), exercising the internalResponseWriter.Write path.
type rawWriteHandler struct {
	// ip is the IPv4 address placed in the answer.
	ip string
}

// ServeDNS builds a reply with one A record for the first question's name,
// packs it to wire format and hands the bytes to w.Write. Nothing is written
// if packing fails; the Write result is ignored.
func (h *rawWriteHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
	resp := &dns.Msg{}
	resp.SetReply(r)
	resp.Answer = []dns.RR{&dns.A{
		Hdr: dns.RR_Header{Name: r.Question[0].Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
		A:   net.ParseIP(h.ip).To4(),
	}}
	packed, err := resp.Pack()
	if err != nil {
		return
	}
	_, _ = w.Write(packed)
}
// TestHandlerChain_ResolveInternal_CapturesRawWrite checks that an answer
// delivered through ResponseWriter.Write (packed bytes) is returned by
// ResolveInternal just like one delivered through WriteMsg.
func TestHandlerChain_ResolveInternal_CapturesRawWrite(t *testing.T) {
	chain := nbdns.NewHandlerChain()
	chain.AddHandler("example.com.", &rawWriteHandler{ip: "10.0.0.3"}, nbdns.PriorityUpstream)
	r := new(dns.Msg)
	r.SetQuestion("example.com.", dns.TypeA)
	resp, err := chain.ResolveInternal(context.Background(), r, nbdns.PriorityUpstream)
	assert.NoError(t, err)
	// require stops the test here so the index and type assertion below cannot panic.
	require.NotNil(t, resp)
	require.Len(t, resp.Answer, 1)
	a, ok := resp.Answer[0].(*dns.A)
	require.True(t, ok)
	assert.Equal(t, "10.0.0.3", a.A.String(), "handlers calling Write(packed) must still surface their answer")
}
// TestHandlerChain_ResolveInternal_EmptyQuestion checks that a message without
// any question is rejected with an error.
func TestHandlerChain_ResolveInternal_EmptyQuestion(t *testing.T) {
	chain := nbdns.NewHandlerChain()
	_, err := chain.ResolveInternal(context.Background(), new(dns.Msg), nbdns.PriorityUpstream)
	assert.Error(t, err)
}
// hangingHandler blocks indefinitely until closed, simulating a wedged upstream.
type hangingHandler struct {
	// block is received from before replying; closing it releases ServeDNS.
	block chan struct{}
}

// ServeDNS waits on h.block and then writes an empty reply to r.
func (h *hangingHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
	<-h.block
	resp := &dns.Msg{}
	resp.SetReply(r)
	_ = w.WriteMsg(resp)
}

// String returns a fixed name for the handler.
func (h *hangingHandler) String() string { return "hangingHandler" }
// TestHandlerChain_ResolveInternal_HonorsContextTimeout checks that
// ResolveInternal gives up with context.DeadlineExceeded shortly after the
// context deadline when the selected handler never answers.
func TestHandlerChain_ResolveInternal_HonorsContextTimeout(t *testing.T) {
	chain := nbdns.NewHandlerChain()
	h := &hangingHandler{block: make(chan struct{})}
	// Closing block at the end of the test releases the handler goroutine.
	defer close(h.block)
	chain.AddHandler("example.com.", h, nbdns.PriorityUpstream)
	r := new(dns.Msg)
	r.SetQuestion("example.com.", dns.TypeA)
	ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
	defer cancel()
	start := time.Now()
	_, err := chain.ResolveInternal(ctx, r, nbdns.PriorityUpstream)
	elapsed := time.Since(start)
	assert.Error(t, err)
	assert.ErrorIs(t, err, context.DeadlineExceeded)
	// The deadline is 100ms; 500ms leaves headroom for scheduling delays.
	assert.Less(t, elapsed, 500*time.Millisecond, "ResolveInternal must return shortly after ctx deadline")
}
// TestHandlerChain_HasRootHandlerAtOrBelow walks one chain through a sequence
// of additions and removals and checks HasRootHandlerAtOrBelow(PriorityUpstream)
// after each step. The steps build on one another, so their order matters.
func TestHandlerChain_HasRootHandlerAtOrBelow(t *testing.T) {
	chain := nbdns.NewHandlerChain()
	h := &answeringHandler{name: "h", ip: "10.0.0.1"}
	assert.False(t, chain.HasRootHandlerAtOrBelow(nbdns.PriorityUpstream), "empty chain")
	chain.AddHandler("example.com.", h, nbdns.PriorityUpstream)
	assert.False(t, chain.HasRootHandlerAtOrBelow(nbdns.PriorityUpstream), "non-root handler does not count")
	chain.AddHandler(".", h, nbdns.PriorityMgmtCache)
	assert.False(t, chain.HasRootHandlerAtOrBelow(nbdns.PriorityUpstream), "root handler above threshold excluded")
	chain.AddHandler(".", h, nbdns.PriorityDefault)
	assert.True(t, chain.HasRootHandlerAtOrBelow(nbdns.PriorityUpstream), "root handler at PriorityDefault included")
	// Removing the only qualifying root handler must flip the result back.
	chain.RemoveHandler(".", nbdns.PriorityDefault)
	assert.False(t, chain.HasRootHandlerAtOrBelow(nbdns.PriorityUpstream))
	// Primary nsgroup case: root handler lands at PriorityUpstream.
	chain.AddHandler(".", h, nbdns.PriorityUpstream)
	assert.True(t, chain.HasRootHandlerAtOrBelow(nbdns.PriorityUpstream), "root at PriorityUpstream included")
	chain.RemoveHandler(".", nbdns.PriorityUpstream)
	// Fallback case: original /etc/resolv.conf entries land at PriorityFallback.
	chain.AddHandler(".", h, nbdns.PriorityFallback)
	assert.True(t, chain.HasRootHandlerAtOrBelow(nbdns.PriorityUpstream), "root at PriorityFallback included")
	chain.RemoveHandler(".", nbdns.PriorityFallback)
	assert.False(t, chain.HasRootHandlerAtOrBelow(nbdns.PriorityUpstream))
}

View File

@@ -2,83 +2,40 @@ package mgmt
import (
"context"
"errors"
"fmt"
"net"
"net/netip"
"net/url"
"os"
"slices"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/miekg/dns"
log "github.com/sirupsen/logrus"
"golang.org/x/sync/singleflight"
dnsconfig "github.com/netbirdio/netbird/client/internal/dns/config"
"github.com/netbirdio/netbird/client/internal/dns/resutil"
"github.com/netbirdio/netbird/shared/management/domain"
)
const (
dnsTimeout = 5 * time.Second
defaultTTL = 300 * time.Second
refreshBackoff = 30 * time.Second
const dnsTimeout = 5 * time.Second
// envMgmtCacheTTL overrides defaultTTL for integration/dev testing.
envMgmtCacheTTL = "NB_MGMT_CACHE_TTL"
)
// ChainResolver lets the cache refresh stale entries through the DNS handler
// chain instead of net.DefaultResolver, avoiding loopback when NetBird is the
// system resolver.
type ChainResolver interface {
ResolveInternal(ctx context.Context, msg *dns.Msg, maxPriority int) (*dns.Msg, error)
HasRootHandlerAtOrBelow(maxPriority int) bool
}
// cachedRecord holds DNS records plus timestamps used for TTL refresh.
// records and cachedAt are set at construction and treated as immutable;
// lastFailedRefresh and consecFailures are mutable and must be accessed under
// Resolver.mutex.
type cachedRecord struct {
// records are the answers served for the cached question.
records []dns.RR
// cachedAt is when records were stored; its age is compared against the
// resolver's cacheTTL to decide whether the entry is stale.
cachedAt time.Time
// lastFailedRefresh is the time of the most recent failed refresh (zero if
// none). While it is younger than refreshBackoff no new refresh is scheduled.
lastFailedRefresh time.Time
// consecFailures counts the failed refreshes recorded on this entry.
consecFailures int
}
// Resolver caches critical NetBird infrastructure domains.
// records, refreshing, mgmtDomain and serverDomains are all guarded by mutex.
// Resolver caches critical NetBird infrastructure domains
type Resolver struct {
records map[dns.Question]*cachedRecord
records map[dns.Question][]dns.RR
mgmtDomain *domain.Domain
serverDomains *dnsconfig.ServerDomains
mutex sync.RWMutex
}
chain ChainResolver
chainMaxPriority int
refreshGroup singleflight.Group
// refreshing tracks questions whose refresh is running via the OS
// fallback path. A ServeDNS hit for a question in this map indicates
// the OS resolver routed the recursive query back to us (loop). Only
// the OS path arms this so chain-path refreshes don't produce false
// positives. The atomic bool is CAS-flipped once per refresh to
// throttle the warning log.
refreshing map[dns.Question]*atomic.Bool
cacheTTL time.Duration
// ipsResponse carries the outcome of the background LookupNetIP call started
// by lookupIPWithExtraTimeout back to the waiting caller.
type ipsResponse struct {
// ips are the addresses returned by the lookup.
ips []netip.Addr
// err is the lookup error, if any.
err error
}
// NewResolver creates a new management domains cache resolver.
func NewResolver() *Resolver {
return &Resolver{
records: make(map[dns.Question]*cachedRecord),
refreshing: make(map[dns.Question]*atomic.Bool),
cacheTTL: resolveCacheTTL(),
records: make(map[dns.Question][]dns.RR),
}
}
@@ -87,19 +44,7 @@ func (m *Resolver) String() string {
return "MgmtCacheResolver"
}
// SetChainResolver installs the handler chain used to refresh stale cache
// entries. maxPriority bounds which handlers may answer refresh queries
// (typically PriorityUpstream, so upstream/default/fallback handlers are
// consulted and mgmt/route/local handlers are skipped).
func (m *Resolver) SetChainResolver(chain ChainResolver, maxPriority int) {
	m.mutex.Lock()
	defer m.mutex.Unlock()

	m.chain, m.chainMaxPriority = chain, maxPriority
}
// ServeDNS serves cached A/AAAA records. Stale entries are returned
// immediately and refreshed asynchronously (stale-while-revalidate).
// ServeDNS implements dns.Handler interface.
func (m *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
if len(r.Question) == 0 {
m.continueToNext(w, r)
@@ -115,14 +60,7 @@ func (m *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
}
m.mutex.RLock()
cached, found := m.records[question]
inflight := m.refreshing[question]
var shouldRefresh bool
if found {
stale := time.Since(cached.cachedAt) > m.cacheTTL
inBackoff := !cached.lastFailedRefresh.IsZero() && time.Since(cached.lastFailedRefresh) < refreshBackoff
shouldRefresh = stale && !inBackoff
}
records, found := m.records[question]
m.mutex.RUnlock()
if !found {
@@ -130,23 +68,12 @@ func (m *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
return
}
if inflight != nil && inflight.CompareAndSwap(false, true) {
log.Warnf("mgmt cache: possible resolver loop for domain=%s: served stale while an OS-fallback refresh was inflight (if NetBird is the system resolver, the OS-path predicate is wrong)",
question.Name)
}
// Skip scheduling a refresh goroutine if one is already inflight for
// this question; singleflight would dedup anyway but skipping avoids
// a parked goroutine per stale hit under bursty load.
if shouldRefresh && inflight == nil {
m.scheduleRefresh(question, cached)
}
resp := &dns.Msg{}
resp.SetReply(r)
resp.Authoritative = false
resp.RecursionAvailable = true
resp.Answer = cloneRecordsWithTTL(cached.records, m.responseTTL(cached.cachedAt))
resp.Answer = append(resp.Answer, records...)
log.Debugf("serving %d cached records for domain=%s", len(resp.Answer), question.Name)
@@ -171,260 +98,101 @@ func (m *Resolver) continueToNext(w dns.ResponseWriter, r *dns.Msg) {
}
}
// AddDomain resolves a domain and stores its A/AAAA records in the cache.
// A family that resolves NODATA (nil err, zero records) evicts any stale
// entry for that qtype.
// AddDomain manually adds a domain to cache by resolving it.
func (m *Resolver) AddDomain(ctx context.Context, d domain.Domain) error {
dnsName := strings.ToLower(dns.Fqdn(d.PunycodeString()))
ctx, cancel := context.WithTimeout(ctx, dnsTimeout)
defer cancel()
aRecords, aaaaRecords, errA, errAAAA := m.lookupBoth(ctx, d, dnsName)
if errA != nil && errAAAA != nil {
return fmt.Errorf("resolve %s: %w", d.SafeString(), errors.Join(errA, errAAAA))
ips, err := lookupIPWithExtraTimeout(ctx, d)
if err != nil {
return err
}
if len(aRecords) == 0 && len(aaaaRecords) == 0 {
if err := errors.Join(errA, errAAAA); err != nil {
return fmt.Errorf("resolve %s: no A/AAAA records: %w", d.SafeString(), err)
var aRecords, aaaaRecords []dns.RR
for _, ip := range ips {
if ip.Is4() {
rr := &dns.A{
Hdr: dns.RR_Header{
Name: dnsName,
Rrtype: dns.TypeA,
Class: dns.ClassINET,
Ttl: 300,
},
A: ip.AsSlice(),
}
aRecords = append(aRecords, rr)
} else if ip.Is6() {
rr := &dns.AAAA{
Hdr: dns.RR_Header{
Name: dnsName,
Rrtype: dns.TypeAAAA,
Class: dns.ClassINET,
Ttl: 300,
},
AAAA: ip.AsSlice(),
}
aaaaRecords = append(aaaaRecords, rr)
}
return fmt.Errorf("resolve %s: no A/AAAA records", d.SafeString())
}
now := time.Now()
m.mutex.Lock()
defer m.mutex.Unlock()
m.applyFamilyRecords(dnsName, dns.TypeA, aRecords, errA, now)
m.applyFamilyRecords(dnsName, dns.TypeAAAA, aaaaRecords, errAAAA, now)
if len(aRecords) > 0 {
aQuestion := dns.Question{
Name: dnsName,
Qtype: dns.TypeA,
Qclass: dns.ClassINET,
}
m.records[aQuestion] = aRecords
}
log.Debugf("added/updated domain=%s with %d A records and %d AAAA records",
if len(aaaaRecords) > 0 {
aaaaQuestion := dns.Question{
Name: dnsName,
Qtype: dns.TypeAAAA,
Qclass: dns.ClassINET,
}
m.records[aaaaQuestion] = aaaaRecords
}
m.mutex.Unlock()
log.Debugf("added domain=%s with %d A records and %d AAAA records",
d.SafeString(), len(aRecords), len(aaaaRecords))
return nil
}
// applyFamilyRecords updates the cache entry for one record family
// (dnsName/qtype). A non-empty result replaces the entry, an empty result
// without error (NODATA) evicts it, and an empty result with an error leaves
// the cache untouched. Caller holds m.mutex.
func (m *Resolver) applyFamilyRecords(dnsName string, qtype uint16, records []dns.RR, err error, now time.Time) {
	key := dns.Question{Name: dnsName, Qtype: qtype, Qclass: dns.ClassINET}

	if len(records) > 0 {
		m.records[key] = &cachedRecord{records: records, cachedAt: now}
		return
	}
	if err != nil {
		// Lookup failed: keep whatever is currently cached.
		return
	}
	// NODATA: the family no longer exists, stop serving it.
	delete(m.records, key)
}
func lookupIPWithExtraTimeout(ctx context.Context, d domain.Domain) ([]netip.Addr, error) {
log.Infof("looking up IP for mgmt domain=%s", d.SafeString())
defer log.Infof("done looking up IP for mgmt domain=%s", d.SafeString())
resultChan := make(chan *ipsResponse, 1)
// scheduleRefresh starts an asynchronous refresh of question. The work is
// submitted through refreshGroup.DoChan keyed by name and record type, so
// bursty stale hits for the same question share one in-flight refresh.
// expected is the cachedRecord pointer the caller observed; the refresh only
// mutates the cache while that pointer is still the stored one, so a stale
// in-flight refresh cannot clobber a newer entry written by AddDomain or a
// competing refresh.
func (m *Resolver) scheduleRefresh(question dns.Question, expected *cachedRecord) {
	refresh := func() (any, error) {
		return nil, m.refreshQuestion(question, expected)
	}

	flightKey := question.Name + "|" + dns.TypeToString[question.Qtype]
	// The result channel is intentionally dropped: nobody waits for the refresh.
	_ = m.refreshGroup.DoChan(flightKey, refresh)
}
// refreshQuestion replaces the cached records on success, or marks the entry
// failed (arming the backoff) on failure. While this runs, ServeDNS can detect
// a resolver loop by spotting a query for this same question arriving on us.
// expected pins the cache entry observed at schedule time; mutations only apply
// if m.records[question] still points at it.
func (m *Resolver) refreshQuestion(question dns.Question, expected *cachedRecord) error {
ctx, cancel := context.WithTimeout(context.Background(), dnsTimeout)
defer cancel()
d, err := domain.FromString(strings.TrimSuffix(question.Name, "."))
if err != nil {
m.markRefreshFailed(question, expected)
return fmt.Errorf("parse domain: %w", err)
}
records, err := m.lookupRecords(ctx, d, question)
if err != nil {
fails := m.markRefreshFailed(question, expected)
logf := log.Warnf
if fails == 0 || fails > 1 {
logf = log.Debugf
go func() {
ips, err := net.DefaultResolver.LookupNetIP(ctx, "ip", d.PunycodeString())
resultChan <- &ipsResponse{
err: err,
ips: ips,
}
logf("refresh mgmt cache domain=%s type=%s: %v (consecutive failures=%d)",
d.SafeString(), dns.TypeToString[question.Qtype], err, fails)
return err
}()
var resp *ipsResponse
select {
case <-time.After(dnsTimeout + time.Millisecond*500):
log.Warnf("timed out waiting for IP for mgmt domain=%s", d.SafeString())
return nil, fmt.Errorf("timed out waiting for ips to be available for domain %s", d.SafeString())
case <-ctx.Done():
return nil, ctx.Err()
case resp = <-resultChan:
}
// NOERROR/NODATA: family gone upstream, evict so we stop serving stale.
if len(records) == 0 {
m.mutex.Lock()
if m.records[question] == expected {
delete(m.records, question)
m.mutex.Unlock()
log.Infof("removed mgmt cache domain=%s type=%s: no records returned",
d.SafeString(), dns.TypeToString[question.Qtype])
return nil
}
m.mutex.Unlock()
log.Debugf("skipping refresh evict for domain=%s type=%s: entry changed during refresh",
d.SafeString(), dns.TypeToString[question.Qtype])
return nil
if resp.err != nil {
return nil, fmt.Errorf("resolve domain %s: %w", d.SafeString(), resp.err)
}
now := time.Now()
m.mutex.Lock()
if m.records[question] != expected {
m.mutex.Unlock()
log.Debugf("skipping refresh write for domain=%s type=%s: entry changed during refresh",
d.SafeString(), dns.TypeToString[question.Qtype])
return nil
}
m.records[question] = &cachedRecord{records: records, cachedAt: now}
m.mutex.Unlock()
log.Infof("refreshed mgmt cache domain=%s type=%s",
d.SafeString(), dns.TypeToString[question.Qtype])
return nil
}
// markRefreshing registers question as having a refresh in flight by storing
// a fresh, unset flag for it in m.refreshing.
func (m *Resolver) markRefreshing(question dns.Question) {
	m.mutex.Lock()
	defer m.mutex.Unlock()

	m.refreshing[question] = new(atomic.Bool)
}
// clearRefreshing drops the in-flight refresh marker for question.
func (m *Resolver) clearRefreshing(question dns.Question) {
	m.mutex.Lock()
	defer m.mutex.Unlock()

	delete(m.refreshing, question)
}
// markRefreshFailed records a failed refresh on the cached entry for question,
// which arms the backoff, and returns the updated consecutive-failure count so
// callers can downgrade repeated failure logs to debug. It returns 0 without
// touching anything when the entry is gone or is no longer expected.
func (m *Resolver) markRefreshFailed(question dns.Question, expected *cachedRecord) int {
	m.mutex.Lock()
	defer m.mutex.Unlock()

	if entry, present := m.records[question]; present && entry == expected {
		entry.lastFailedRefresh = time.Now()
		entry.consecFailures++
		return entry.consecFailures
	}
	return 0
}
// lookupBoth resolves the A and AAAA families for dnsName, through the handler
// chain when it has a usable root handler and through the OS resolver
// otherwise. Errors are reported per family so callers can tell records,
// NODATA (nil err, no records) and failure apart.
func (m *Resolver) lookupBoth(ctx context.Context, d domain.Domain, dnsName string) (aRecords, aaaaRecords []dns.RR, errA, errAAAA error) {
	m.mutex.RLock()
	chain, maxPriority := m.chain, m.chainMaxPriority
	m.mutex.RUnlock()

	viaChain := chain != nil && chain.HasRootHandlerAtOrBelow(maxPriority)
	resolve := func(qtype uint16) ([]dns.RR, error) {
		if viaChain {
			return m.lookupViaChain(ctx, chain, maxPriority, dnsName, qtype)
		}
		// TODO: drop once every supported OS registers a fallback resolver. Safe
		// today: no root handler at priority ≤ PriorityUpstream means NetBird is
		// not the system resolver, so net.DefaultResolver will not loop back.
		return m.osLookup(ctx, d, dnsName, qtype)
	}

	aRecords, errA = resolve(dns.TypeA)
	aaaaRecords, errAAAA = resolve(dns.TypeAAAA)
	return aRecords, aaaaRecords, errA, errAAAA
}
// lookupRecords resolves one record type for q, through the handler chain when
// it has a usable root handler and through the OS resolver otherwise. Only the
// OS branch marks q as refreshing for the duration of the call, so that
// ServeDNS can spot the OS resolver routing the recursive query back to us.
func (m *Resolver) lookupRecords(ctx context.Context, d domain.Domain, q dns.Question) ([]dns.RR, error) {
	m.mutex.RLock()
	chain, maxPriority := m.chain, m.chainMaxPriority
	m.mutex.RUnlock()

	if chain == nil || !chain.HasRootHandlerAtOrBelow(maxPriority) {
		// TODO: drop once every supported OS registers a fallback resolver.
		m.markRefreshing(q)
		defer m.clearRefreshing(q)
		return m.osLookup(ctx, d, q.Name, q.Qtype)
	}

	return m.lookupViaChain(ctx, chain, maxPriority, q.Name, q.Qtype)
}
// lookupViaChain resolves dnsName/qtype through the handler chain. Every
// accepted answer is cloned with dnsName as owner and m.cacheTTL as TTL, so
// CNAME-backed domains don't cache target-owned records or upstream TTLs.
// Answers of another class or type, or owned by a name outside the CNAME chain
// starting at dnsName, are dropped. NODATA returns (nil, nil).
func (m *Resolver) lookupViaChain(ctx context.Context, chain ChainResolver, maxPriority int, dnsName string, qtype uint16) ([]dns.RR, error) {
	query := new(dns.Msg)
	query.SetQuestion(dnsName, qtype)
	query.RecursionDesired = true

	reply, err := chain.ResolveInternal(ctx, query, maxPriority)
	switch {
	case err != nil:
		return nil, fmt.Errorf("chain resolve: %w", err)
	case reply == nil:
		return nil, fmt.Errorf("chain resolve returned nil response")
	case reply.Rcode != dns.RcodeSuccess:
		return nil, fmt.Errorf("chain resolve rcode=%s", dns.RcodeToString[reply.Rcode])
	}

	var (
		ttl    = uint32(m.cacheTTL.Seconds())
		owners = cnameOwners(dnsName, reply.Answer)
		kept   []dns.RR
	)
	for _, rr := range reply.Answer {
		hdr := rr.Header()
		wanted := hdr.Class == dns.ClassINET &&
			hdr.Rrtype == qtype &&
			owners[strings.ToLower(dns.Fqdn(hdr.Name))]
		if !wanted {
			continue
		}
		if clone := cloneIPRecord(rr, dnsName, ttl); clone != nil {
			kept = append(kept, clone)
		}
	}
	return kept, nil
}
// osLookup resolves one address family through net.DefaultResolver using
// resutil, which disambiguates NODATA from NXDOMAIN and unmaps v4-mapped-v6
// addresses. A successful rcode yields the IPs as records owned by dnsName
// with m.cacheTTL as TTL; NODATA returns (nil, nil). Unsupported qtypes and
// failed lookups return an error.
func (m *Resolver) osLookup(ctx context.Context, d domain.Domain, dnsName string, qtype uint16) ([]dns.RR, error) {
	typeName := dns.TypeToString[qtype]

	network := resutil.NetworkForQtype(qtype)
	if network == "" {
		return nil, fmt.Errorf("unsupported qtype %s", typeName)
	}

	log.Infof("looking up IP for mgmt domain=%s type=%s", d.SafeString(), typeName)
	defer log.Infof("done looking up IP for mgmt domain=%s type=%s", d.SafeString(), typeName)

	res := resutil.LookupIP(ctx, net.DefaultResolver, network, d.PunycodeString(), qtype)
	switch {
	case res.Rcode == dns.RcodeSuccess:
		return resutil.IPsToRRs(dnsName, res.IPs, uint32(m.cacheTTL.Seconds())), nil
	case res.Err != nil:
		return nil, fmt.Errorf("resolve %s type=%s: %w", d.SafeString(), typeName, res.Err)
	default:
		return nil, fmt.Errorf("resolve %s type=%s: rcode=%s", d.SafeString(), typeName, dns.RcodeToString[res.Rcode])
	}
}
// responseTTL returns the remaining cache lifetime in seconds (rounded up),
// so downstream resolvers don't cache an answer for longer than we will.
func (m *Resolver) responseTTL(cachedAt time.Time) uint32 {
remaining := m.cacheTTL - time.Since(cachedAt)
if remaining <= 0 {
return 0
}
return uint32((remaining + time.Second - 1) / time.Second)
return resp.ips, nil
}
// PopulateFromConfig extracts and caches domains from the client configuration.
@@ -456,12 +224,19 @@ func (m *Resolver) RemoveDomain(d domain.Domain) error {
m.mutex.Lock()
defer m.mutex.Unlock()
qA := dns.Question{Name: dnsName, Qtype: dns.TypeA, Qclass: dns.ClassINET}
qAAAA := dns.Question{Name: dnsName, Qtype: dns.TypeAAAA, Qclass: dns.ClassINET}
delete(m.records, qA)
delete(m.records, qAAAA)
delete(m.refreshing, qA)
delete(m.refreshing, qAAAA)
aQuestion := dns.Question{
Name: dnsName,
Qtype: dns.TypeA,
Qclass: dns.ClassINET,
}
delete(m.records, aQuestion)
aaaaQuestion := dns.Question{
Name: dnsName,
Qtype: dns.TypeAAAA,
Qclass: dns.ClassINET,
}
delete(m.records, aaaaQuestion)
log.Debugf("removed domain=%s from cache", d.SafeString())
return nil
@@ -619,73 +394,3 @@ func (m *Resolver) extractDomainsFromServerDomains(serverDomains dnsconfig.Serve
return domains
}
// cloneIPRecord deep-copies an A or AAAA record, replacing its owner name with
// owner and its TTL with ttl. The address bytes are cloned so the copy shares
// no memory with rr. Any other record type yields nil.
func cloneIPRecord(rr dns.RR, owner string, ttl uint32) dns.RR {
	if a, ok := rr.(*dns.A); ok {
		dup := *a
		dup.Hdr.Name, dup.Hdr.Ttl = owner, ttl
		dup.A = slices.Clone(a.A)
		return &dup
	}
	if aaaa, ok := rr.(*dns.AAAA); ok {
		dup := *aaaa
		dup.Hdr.Name, dup.Hdr.Ttl = owner, ttl
		dup.AAAA = slices.Clone(aaaa.AAAA)
		return &dup
	}
	return nil
}
// cloneRecordsWithTTL returns copies of the A/AAAA records in records, each
// keeping its own owner name but stamped with ttl, so the result shares no
// memory with the cached slice. Records of other types are skipped.
func cloneRecordsWithTTL(records []dns.RR, ttl uint32) []dns.RR {
	clones := make([]dns.RR, 0, len(records))
	for i := range records {
		clone := cloneIPRecord(records[i], records[i].Header().Name, ttl)
		if clone == nil {
			continue
		}
		clones = append(clones, clone)
	}
	return clones
}
// cnameOwners returns the set of owner names that belong to dnsName's answer:
// dnsName itself plus every target reachable by following CNAME records in
// answer. The scan repeats until no new target is added (fixed point), so
// chains listed out of order still resolve.
//
// All keys, including the seed, are lower-cased fully-qualified names. The
// seed is normalised the same way as CNAME owners and targets so that a
// mixed-case or non-FQDN dnsName still matches the records that refer to it.
func cnameOwners(dnsName string, answer []dns.RR) map[string]bool {
	normalize := func(name string) string {
		return strings.ToLower(dns.Fqdn(name))
	}

	owners := map[string]bool{normalize(dnsName): true}
	for {
		added := false
		for _, rr := range answer {
			cname, ok := rr.(*dns.CNAME)
			if !ok {
				continue
			}
			// Only follow CNAMEs whose owner is already part of the chain.
			if !owners[normalize(cname.Hdr.Name)] {
				continue
			}
			if target := normalize(cname.Target); !owners[target] {
				owners[target] = true
				added = true
			}
		}
		if !added {
			return owners
		}
	}
}
// resolveCacheTTL returns the cache TTL taken from the envMgmtCacheTTL
// environment variable. An empty, unparsable, zero or negative value falls
// back to defaultTTL. Called once per Resolver from NewResolver.
func resolveCacheTTL() time.Duration {
	raw := os.Getenv(envMgmtCacheTTL)
	if raw == "" {
		return defaultTTL
	}

	parsed, err := time.ParseDuration(raw)
	if err != nil || parsed <= 0 {
		return defaultTTL
	}
	return parsed
}

View File

@@ -1,408 +0,0 @@
package mgmt
import (
"context"
"errors"
"net"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/miekg/dns"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/client/internal/dns/test"
"github.com/netbirdio/netbird/shared/management/domain"
)
// fakeChain is a test double for ChainResolver that returns canned answers and
// counts lookups. All fields are read under mu by its methods.
type fakeChain struct {
mu sync.Mutex
// calls counts ResolveInternal invocations, keyed by "<name>|<TYPE>".
calls map[string]int
// answers holds the records returned by ResolveInternal, keyed like calls.
answers map[string][]dns.RR
// err, when non-nil, is returned by ResolveInternal instead of an answer.
err error
// hasRoot is the value reported by HasRootHandlerAtOrBelow.
hasRoot bool
// onLookup, when non-nil, is invoked by ResolveInternal after mu is released.
onLookup func()
}
// newFakeChain returns a fakeChain with empty call and answer tables that
// reports a registered root handler.
func newFakeChain() *fakeChain {
	f := new(fakeChain)
	f.calls = make(map[string]int)
	f.answers = make(map[string][]dns.RR)
	f.hasRoot = true
	return f
}
// HasRootHandlerAtOrBelow reports the configured hasRoot value; maxPriority is
// ignored by the fake.
func (f *fakeChain) HasRootHandlerAtOrBelow(maxPriority int) bool {
	f.mu.Lock()
	registered := f.hasRoot
	f.mu.Unlock()
	return registered
}
// ResolveInternal records a lookup for the first question of msg, runs the
// optional onLookup hook outside the lock, and then returns either the
// configured error or a reply to msg carrying the canned answers for that
// question. ctx and maxPriority are ignored by the fake.
func (f *fakeChain) ResolveInternal(ctx context.Context, msg *dns.Msg, maxPriority int) (*dns.Msg, error) {
	f.mu.Lock()
	first := msg.Question[0]
	lookupKey := first.Name + "|" + dns.TypeToString[first.Qtype]
	f.calls[lookupKey]++
	// Snapshot the state so the hook and the reply are built without the lock.
	rrs, failure, hook := f.answers[lookupKey], f.err, f.onLookup
	f.mu.Unlock()

	if hook != nil {
		hook()
	}
	if failure != nil {
		return nil, failure
	}

	reply := new(dns.Msg)
	reply.SetReply(msg)
	reply.Answer = rrs
	return reply, nil
}
// setAnswer configures the single record returned for name/qtype. Only A and
// AAAA are supported; any other qtype leaves the answers untouched.
func (f *fakeChain) setAnswer(name string, qtype uint16, ip string) {
	f.mu.Lock()
	defer f.mu.Unlock()

	header := dns.RR_Header{Name: name, Rrtype: qtype, Class: dns.ClassINET, Ttl: 60}

	var rr dns.RR
	if qtype == dns.TypeA {
		rr = &dns.A{Hdr: header, A: net.ParseIP(ip).To4()}
	} else if qtype == dns.TypeAAAA {
		rr = &dns.AAAA{Hdr: header, AAAA: net.ParseIP(ip).To16()}
	} else {
		return
	}
	f.answers[name+"|"+dns.TypeToString[qtype]] = []dns.RR{rr}
}
// callCount returns how many times ResolveInternal was asked for name/qtype.
func (f *fakeChain) callCount(name string, qtype uint16) int {
	f.mu.Lock()
	n := f.calls[name+"|"+dns.TypeToString[qtype]]
	f.mu.Unlock()
	return n
}
// waitFor polls the predicate until it returns true or the deadline passes.
func waitFor(t *testing.T, d time.Duration, fn func() bool) {
t.Helper()
deadline := time.Now().Add(d)
for time.Now().Before(deadline) {
if fn() {
return
}
time.Sleep(5 * time.Millisecond)
}
t.Fatalf("condition not met within %s", d)
}
// queryA sends an A query for name straight into r.ServeDNS and returns the
// last response captured by the mock writer.
func queryA(t *testing.T, r *Resolver, name string) *dns.Msg {
	t.Helper()

	query := &dns.Msg{}
	query.SetQuestion(name, dns.TypeA)

	writer := new(test.MockResponseWriter)
	r.ServeDNS(writer, query)
	return writer.GetLastResponse()
}
// firstA returns the address of the first answer in resp as a string. It
// fails the test when resp is nil, has no answers, or the first answer is not
// an A record.
func firstA(t *testing.T, resp *dns.Msg) string {
	t.Helper()

	require.NotNil(t, resp)
	require.Greater(t, len(resp.Answer), 0, "expected at least one answer")

	record, isA := resp.Answer[0].(*dns.A)
	require.True(t, isA, "expected A record")
	return record.A.String()
}
// TestResolver_CacheTTLGatesRefresh checks that the resolver's cacheTTL field
// decides whether a cached entry is considered stale.
func TestResolver_CacheTTLGatesRefresh(t *testing.T) {
// Same cached entry age, different cacheTTL values: the shorter TTL must
// trigger a background refresh, the longer one must not. Proves that the
// per-Resolver cacheTTL field actually drives the stale decision.
cachedAt := time.Now().Add(-100 * time.Millisecond)
// newRec builds a fresh entry per subtest so they do not share state.
newRec := func() *cachedRecord {
return &cachedRecord{
records: []dns.RR{&dns.A{
Hdr: dns.RR_Header{Name: "mgmt.example.com.", Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
A: net.ParseIP("10.0.0.1").To4(),
}},
cachedAt: cachedAt,
}
}
q := dns.Question{Name: "mgmt.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}
t.Run("short TTL treats entry as stale and refreshes", func(t *testing.T) {
r := NewResolver()
// 10ms is below the 100ms entry age, so the entry counts as stale.
r.cacheTTL = 10 * time.Millisecond
chain := newFakeChain()
chain.setAnswer(q.Name, dns.TypeA, "10.0.0.2")
// The priority value is arbitrary here: fakeChain ignores it.
r.SetChainResolver(chain, 50)
r.records[q] = newRec()
resp := queryA(t, r, q.Name)
assert.Equal(t, "10.0.0.1", firstA(t, resp), "stale entry must be served while refresh runs")
// The refresh is asynchronous, so poll until the chain has been consulted.
waitFor(t, time.Second, func() bool {
return chain.callCount(q.Name, dns.TypeA) >= 1
})
})
t.Run("long TTL keeps entry fresh and skips refresh", func(t *testing.T) {
r := NewResolver()
r.cacheTTL = time.Hour
chain := newFakeChain()
chain.setAnswer(q.Name, dns.TypeA, "10.0.0.2")
r.SetChainResolver(chain, 50)
r.records[q] = newRec()
resp := queryA(t, r, q.Name)
assert.Equal(t, "10.0.0.1", firstA(t, resp))
// Give a wrongly scheduled refresh time to reach the chain before asserting.
time.Sleep(50 * time.Millisecond)
assert.Equal(t, 0, chain.callCount(q.Name, dns.TypeA), "fresh entry must not trigger refresh")
})
}
func TestResolver_ServeFresh_NoRefresh(t *testing.T) {
r := NewResolver()
chain := newFakeChain()
chain.setAnswer("mgmt.example.com.", dns.TypeA, "10.0.0.2")
r.SetChainResolver(chain, 50)
r.records[dns.Question{Name: "mgmt.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}] = &cachedRecord{
records: []dns.RR{&dns.A{
Hdr: dns.RR_Header{Name: "mgmt.example.com.", Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
A: net.ParseIP("10.0.0.1").To4(),
}},
cachedAt: time.Now(), // fresh
}
resp := queryA(t, r, "mgmt.example.com.")
assert.Equal(t, "10.0.0.1", firstA(t, resp))
time.Sleep(20 * time.Millisecond)
assert.Equal(t, 0, chain.callCount("mgmt.example.com.", dns.TypeA), "fresh entry must not trigger refresh")
}
// TestResolver_StaleTriggersAsyncRefresh checks the stale-while-revalidate
// path: a stale entry is served immediately, the chain is consulted in the
// background, and a later query returns the refreshed address.
func TestResolver_StaleTriggersAsyncRefresh(t *testing.T) {
r := NewResolver()
chain := newFakeChain()
chain.setAnswer("mgmt.example.com.", dns.TypeA, "10.0.0.2")
r.SetChainResolver(chain, 50)
q := dns.Question{Name: "mgmt.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}
r.records[q] = &cachedRecord{
records: []dns.RR{&dns.A{
Hdr: dns.RR_Header{Name: q.Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
A: net.ParseIP("10.0.0.1").To4(),
}},
cachedAt: time.Now().Add(-2 * defaultTTL), // stale
}
// First query: serves stale immediately.
resp := queryA(t, r, "mgmt.example.com.")
assert.Equal(t, "10.0.0.1", firstA(t, resp), "stale entry must be served while refresh runs")
// The refresh runs asynchronously; wait until the chain has been asked.
waitFor(t, time.Second, func() bool {
return chain.callCount("mgmt.example.com.", dns.TypeA) >= 1
})
// Next query should now return the refreshed IP.
waitFor(t, time.Second, func() bool {
resp := queryA(t, r, "mgmt.example.com.")
return resp != nil && len(resp.Answer) > 0 && firstA(t, resp) == "10.0.0.2"
})
}
// TestResolver_ConcurrentStaleHitsCollapseRefresh fires 50 concurrent queries
// at a stale entry and checks that the chain lookups never overlap and that at
// most two lookups are issued in total.
func TestResolver_ConcurrentStaleHitsCollapseRefresh(t *testing.T) {
r := NewResolver()
chain := newFakeChain()
chain.setAnswer("mgmt.example.com.", dns.TypeA, "10.0.0.2")
// inflight counts lookups currently inside onLookup; maxInflight keeps the
// highest value inflight ever reached.
var inflight atomic.Int32
var maxInflight atomic.Int32
chain.onLookup = func() {
cur := inflight.Add(1)
defer inflight.Add(-1)
// Raise maxInflight to cur with a CAS loop unless it is already higher.
for {
prev := maxInflight.Load()
if cur <= prev || maxInflight.CompareAndSwap(prev, cur) {
break
}
}
time.Sleep(50 * time.Millisecond) // hold inflight long enough to collide
}
r.SetChainResolver(chain, 50)
q := dns.Question{Name: "mgmt.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}
r.records[q] = &cachedRecord{
records: []dns.RR{&dns.A{
Hdr: dns.RR_Header{Name: q.Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
A: net.ParseIP("10.0.0.1").To4(),
}},
cachedAt: time.Now().Add(-2 * defaultTTL),
}
var wg sync.WaitGroup
for i := 0; i < 50; i++ {
wg.Add(1)
go func() {
defer wg.Done()
queryA(t, r, "mgmt.example.com.")
}()
}
wg.Wait()
// Wait for any background lookup still sleeping in onLookup to finish.
waitFor(t, 2*time.Second, func() bool {
return inflight.Load() == 0
})
calls := chain.callCount("mgmt.example.com.", dns.TypeA)
assert.LessOrEqual(t, calls, 2, "singleflight must collapse concurrent refreshes (got %d)", calls)
assert.Equal(t, int32(1), maxInflight.Load(), "only one refresh should run concurrently")
}
// TestResolver_RefreshFailureArmsBackoff checks that a failed refresh stamps
// lastFailedRefresh on the entry and that further stale hits do not issue more
// chain lookups while that stamp is recent.
func TestResolver_RefreshFailureArmsBackoff(t *testing.T) {
r := NewResolver()
chain := newFakeChain()
// Every chain lookup fails with this error.
chain.err = errors.New("boom")
r.SetChainResolver(chain, 50)
q := dns.Question{Name: "mgmt.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}
r.records[q] = &cachedRecord{
records: []dns.RR{&dns.A{
Hdr: dns.RR_Header{Name: q.Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
A: net.ParseIP("10.0.0.1").To4(),
}},
cachedAt: time.Now().Add(-2 * defaultTTL),
}
// First stale hit triggers a refresh attempt that fails.
resp := queryA(t, r, "mgmt.example.com.")
assert.Equal(t, "10.0.0.1", firstA(t, resp), "stale entry served while refresh fails")
waitFor(t, time.Second, func() bool {
return chain.callCount("mgmt.example.com.", dns.TypeA) == 1
})
// Wait until the failure has been recorded on the cached entry.
waitFor(t, time.Second, func() bool {
r.mutex.RLock()
defer r.mutex.RUnlock()
c, ok := r.records[q]
return ok && !c.lastFailedRefresh.IsZero()
})
// Subsequent stale hits within backoff window should not schedule more refreshes.
for i := 0; i < 10; i++ {
queryA(t, r, "mgmt.example.com.")
}
time.Sleep(50 * time.Millisecond)
assert.Equal(t, 1, chain.callCount("mgmt.example.com.", dns.TypeA), "backoff must suppress further refreshes")
}
// TestResolver_NoRootHandler_SkipsChain checks that lookupBoth does not call
// the chain when it reports no root handler.
func TestResolver_NoRootHandler_SkipsChain(t *testing.T) {
r := NewResolver()
chain := newFakeChain()
chain.hasRoot = false
chain.setAnswer("mgmt.example.com.", dns.TypeA, "10.0.0.2")
r.SetChainResolver(chain, 50)
// With hasRoot=false the chain must not be consulted. Use a short
// deadline so the OS fallback returns quickly without waiting on a
// real network call in CI.
ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
defer cancel()
// The lookup results are irrelevant here; only the chain call count matters.
_, _, _, _ = r.lookupBoth(ctx, domain.Domain("mgmt.example.com"), "mgmt.example.com.")
assert.Equal(t, 0, chain.callCount("mgmt.example.com.", dns.TypeA),
"chain must not be used when no root handler is registered at the bound priority")
}
// TestResolver_ServeDuringRefreshSetsLoopFlag checks that serving a question
// that is marked as refreshing flips the flag stored for it in r.refreshing.
func TestResolver_ServeDuringRefreshSetsLoopFlag(t *testing.T) {
// ServeDNS being invoked for a question while a refresh for that question
// is inflight indicates a resolver loop (OS resolver sent the recursive
// query back to us). The atomic flag stored in r.refreshing[q] must be set.
r := NewResolver()
q := dns.Question{Name: "mgmt.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}
r.records[q] = &cachedRecord{
records: []dns.RR{&dns.A{
Hdr: dns.RR_Header{Name: q.Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
A: net.ParseIP("10.0.0.1").To4(),
}},
cachedAt: time.Now(),
}
// Simulate an inflight refresh.
r.markRefreshing(q)
defer r.clearRefreshing(q)
resp := queryA(t, r, "mgmt.example.com.")
assert.Equal(t, "10.0.0.1", firstA(t, resp), "stale entry must still be served to avoid breaking external queries")
// Read the flag under the same lock that guards the refreshing map.
r.mutex.RLock()
inflight := r.refreshing[q]
r.mutex.RUnlock()
require.NotNil(t, inflight)
assert.True(t, inflight.Load(), "loop flag must be set once a ServeDNS during refresh was observed")
}
// TestResolver_LoopFlagOnlyTrippedOncePerRefresh serves the same question
// several times while it is marked as refreshing and checks that the flag ends
// up set. NOTE(review): the test only asserts the final flag value; it does not
// count how many times the flag transition (or the warning log) happened.
func TestResolver_LoopFlagOnlyTrippedOncePerRefresh(t *testing.T) {
r := NewResolver()
q := dns.Question{Name: "mgmt.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}
r.records[q] = &cachedRecord{
records: []dns.RR{&dns.A{
Hdr: dns.RR_Header{Name: q.Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
A: net.ParseIP("10.0.0.1").To4(),
}},
cachedAt: time.Now(),
}
r.markRefreshing(q)
defer r.clearRefreshing(q)
// Multiple ServeDNS calls during the same refresh must not re-set the flag
// (CompareAndSwap from false -> true returns true only on the first call).
for range 5 {
queryA(t, r, "mgmt.example.com.")
}
r.mutex.RLock()
inflight := r.refreshing[q]
r.mutex.RUnlock()
assert.True(t, inflight.Load())
}
// TestResolver_NoLoopFlagWhenNotRefreshing checks that serving a cached entry
// with no refresh in flight leaves no entry behind in r.refreshing.
func TestResolver_NoLoopFlagWhenNotRefreshing(t *testing.T) {
	const name = "mgmt.example.com."

	r := NewResolver()
	question := dns.Question{Name: name, Qtype: dns.TypeA, Qclass: dns.ClassINET}
	cachedRR := &dns.A{
		Hdr: dns.RR_Header{Name: question.Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
		A:   net.ParseIP("10.0.0.1").To4(),
	}
	r.records[question] = &cachedRecord{
		records:  []dns.RR{cachedRR},
		cachedAt: time.Now(),
	}

	queryA(t, r, name)

	r.mutex.RLock()
	_, tracked := r.refreshing[question]
	r.mutex.RUnlock()
	assert.False(t, tracked, "no refresh inflight means no loop tracking")
}
func TestResolver_AddDomain_UsesChainWhenRootRegistered(t *testing.T) {
r := NewResolver()
chain := newFakeChain()
chain.setAnswer("mgmt.example.com.", dns.TypeA, "10.0.0.2")
chain.setAnswer("mgmt.example.com.", dns.TypeAAAA, "fd00::2")
r.SetChainResolver(chain, 50)
require.NoError(t, r.AddDomain(context.Background(), domain.Domain("mgmt.example.com")))
resp := queryA(t, r, "mgmt.example.com.")
assert.Equal(t, "10.0.0.2", firstA(t, resp))
assert.Equal(t, 1, chain.callCount("mgmt.example.com.", dns.TypeA))
assert.Equal(t, 1, chain.callCount("mgmt.example.com.", dns.TypeAAAA))
}

View File

@@ -6,7 +6,6 @@ import (
"net/url"
"strings"
"testing"
"time"
"github.com/miekg/dns"
"github.com/stretchr/testify/assert"
@@ -24,60 +23,6 @@ func TestResolver_NewResolver(t *testing.T) {
assert.False(t, resolver.MatchSubdomains())
}
// TestResolveCacheTTL checks how resolveCacheTTL interprets the
// envMgmtCacheTTL environment variable: valid positive durations are used,
// everything else falls back to defaultTTL.
func TestResolveCacheTTL(t *testing.T) {
	type ttlCase struct {
		name     string
		env      string
		expected time.Duration
	}

	cases := []ttlCase{
		{name: "unset falls back to default", env: "", expected: defaultTTL},
		{name: "valid duration", env: "45s", expected: 45 * time.Second},
		{name: "valid minutes", env: "2m", expected: 2 * time.Minute},
		{name: "malformed falls back to default", env: "not-a-duration", expected: defaultTTL},
		{name: "zero falls back to default", env: "0s", expected: defaultTTL},
		{name: "negative falls back to default", env: "-5s", expected: defaultTTL},
	}

	for _, c := range cases {
		t.Run(c.name, func(t *testing.T) {
			t.Setenv(envMgmtCacheTTL, c.env)
			assert.Equal(t, c.expected, resolveCacheTTL(), "parsed TTL should match")
		})
	}
}
// TestNewResolver_CacheTTLFromEnv checks that NewResolver picks up the cache
// TTL from the envMgmtCacheTTL environment variable.
func TestNewResolver_CacheTTLFromEnv(t *testing.T) {
	const want = 7 * time.Second

	t.Setenv(envMgmtCacheTTL, "7s")

	got := NewResolver().cacheTTL
	assert.Equal(t, want, got, "NewResolver should evaluate cacheTTL once from env")
}
// TestResolver_ResponseTTL checks that responseTTL reports the remaining cache
// lifetime in seconds and zero once the entry has expired. Expectations are
// ranges because wall-clock time passes between setup and the call.
func TestResolver_ResponseTTL(t *testing.T) {
	type ttlCase struct {
		name     string
		cacheTTL time.Duration
		cachedAt time.Time
		lower    uint32
		upper    uint32
	}

	start := time.Now()
	cases := []ttlCase{
		{name: "fresh entry returns full TTL", cacheTTL: 60 * time.Second, cachedAt: start, lower: 59, upper: 60},
		{name: "half-aged entry returns half TTL", cacheTTL: 60 * time.Second, cachedAt: start.Add(-30 * time.Second), lower: 29, upper: 31},
		{name: "expired entry returns zero", cacheTTL: 60 * time.Second, cachedAt: start.Add(-61 * time.Second), lower: 0, upper: 0},
		{name: "exactly expired returns zero", cacheTTL: 10 * time.Second, cachedAt: start.Add(-10 * time.Second), lower: 0, upper: 0},
	}

	for _, c := range cases {
		t.Run(c.name, func(t *testing.T) {
			resolver := &Resolver{cacheTTL: c.cacheTTL}
			remaining := resolver.responseTTL(c.cachedAt)
			assert.GreaterOrEqual(t, remaining, c.lower, "remaining TTL should be >= wantMin")
			assert.LessOrEqual(t, remaining, c.upper, "remaining TTL should be <= wantMax")
		})
	}
}
func TestResolver_ExtractDomainFromURL(t *testing.T) {
tests := []struct {
name string

View File

@@ -212,7 +212,6 @@ func newDefaultServer(
ctx, stop := context.WithCancel(ctx)
mgmtCacheResolver := mgmt.NewResolver()
mgmtCacheResolver.SetChainResolver(handlerChain, PriorityUpstream)
defaultServer := &DefaultServer{
ctx: ctx,

View File

@@ -944,12 +944,7 @@ func (e *Engine) handleRelayUpdate(update *mgmProto.RelayConfig) error {
return fmt.Errorf("update relay token: %w", err)
}
urls := update.Urls
if override, ok := peer.OverrideRelayURLs(); ok {
log.Infof("overriding relay URLs from %s: %v", peer.EnvKeyNBHomeRelayServers, override)
urls = override
}
e.relayManager.UpdateServerURLs(urls)
e.relayManager.UpdateServerURLs(update.Urls)
// Just in case the agent started with an MGM server where the relay was disabled but was later enabled.
// We can ignore all errors because the guard will manage the reconnection retries.

View File

@@ -7,8 +7,7 @@ import (
)
const (
EnvKeyNBForceRelay = "NB_FORCE_RELAY"
EnvKeyNBHomeRelayServers = "NB_HOME_RELAY_SERVERS"
EnvKeyNBForceRelay = "NB_FORCE_RELAY"
)
func IsForceRelayed() bool {
@@ -17,28 +16,3 @@ func IsForceRelayed() bool {
}
return strings.EqualFold(os.Getenv(EnvKeyNBForceRelay), "true")
}
// OverrideRelayURLs reads NB_HOME_RELAY_SERVERS, a comma-separated list of
// relay server URLs, and reports whether an override is in effect.
//
// Each entry is trimmed of surrounding whitespace and empty entries are
// dropped. When the variable is unset, empty, or contains no non-empty entry
// the result is (nil, false) and the caller should keep the list received from
// the management server. Intended for lab/debug scenarios where a peer must
// pin to a specific home relay regardless of what management offers.
func OverrideRelayURLs() ([]string, bool) {
	value := os.Getenv(EnvKeyNBHomeRelayServers)
	if value == "" {
		return nil, false
	}

	entries := strings.Split(value, ",")
	urls := make([]string, 0, len(entries))
	for i := range entries {
		trimmed := strings.TrimSpace(entries[i])
		if trimmed == "" {
			continue
		}
		urls = append(urls, trimmed)
	}

	if len(urls) > 0 {
		return urls, true
	}
	return nil, false
}

View File

@@ -193,7 +193,7 @@ func (c *Connector) ToStorageConnector() (storage.Connector, error) {
// are stored with types that Dex can open.
func mapConnectorToDex(connType string, config map[string]interface{}) (string, map[string]interface{}) {
switch connType {
case "oidc", "zitadel", "entra", "okta", "pocketid", "authentik", "keycloak", "adfs":
case "oidc", "zitadel", "entra", "okta", "pocketid", "authentik", "keycloak":
return "oidc", applyOIDCDefaults(connType, config)
default:
return connType, config
@@ -218,8 +218,6 @@ func applyOIDCDefaults(connType string, config map[string]interface{}) map[strin
setDefault(augmented, "claimMapping", map[string]string{"email": "preferred_username"})
case "okta", "pocketid":
augmented["scopes"] = []string{"openid", "profile", "email", "groups"}
case "adfs":
augmented["scopes"] = []string{"openid", "profile", "email", "allatclaims"}
}
return augmented

View File

@@ -168,7 +168,7 @@ func (p *Provider) buildStorageConnector(cfg *ConnectorConfig) (storage.Connecto
var err error
switch cfg.Type {
case "oidc", "zitadel", "entra", "okta", "pocketid", "authentik", "keycloak", "adfs":
case "oidc", "zitadel", "entra", "okta", "pocketid", "authentik", "keycloak":
dexType = "oidc"
configData, err = buildOIDCConnectorConfig(cfg, redirectURI)
case "google":
@@ -220,8 +220,6 @@ func buildOIDCConnectorConfig(cfg *ConnectorConfig, redirectURI string) ([]byte,
oidcConfig["scopes"] = []string{"openid", "profile", "email", "groups"}
case "pocketid":
oidcConfig["scopes"] = []string{"openid", "profile", "email", "groups"}
case "adfs":
oidcConfig["scopes"] = []string{"openid", "profile", "email", "allatclaims"}
}
return encodeConnectorConfig(oidcConfig)
}
@@ -285,7 +283,7 @@ func inferIdentityProviderType(dexType, connectorID string, _ map[string]interfa
// inferOIDCProviderType infers the specific OIDC provider from connector ID
func inferOIDCProviderType(connectorID string) string {
connectorIDLower := strings.ToLower(connectorID)
for _, provider := range []string{"pocketid", "zitadel", "entra", "okta", "authentik", "keycloak", "adfs"} {
for _, provider := range []string{"pocketid", "zitadel", "entra", "okta", "authentik", "keycloak"} {
if strings.Contains(connectorIDLower, provider) {
return provider
}

View File

@@ -7,6 +7,7 @@ import (
"os"
"slices"
"strconv"
"strings"
"sync"
"sync/atomic"
"time"
@@ -15,9 +16,11 @@ import (
"golang.org/x/exp/maps"
"golang.org/x/mod/semver"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller/cache"
"github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral"
"github.com/netbirdio/netbird/management/internals/modules/zones"
"github.com/netbirdio/netbird/management/internals/server/config"
"github.com/netbirdio/netbird/management/internals/shared/grpc"
"github.com/netbirdio/netbird/management/server/account"
@@ -55,6 +58,13 @@ type Controller struct {
proxyController port_forwarding.Controller
integratedPeerValidator integrated_validator.IntegratedValidator
holder *types.Holder
expNewNetworkMap bool
expNewNetworkMapAIDs map[string]struct{}
compactedNetworkMap bool
}
type bufferUpdate struct {
@@ -71,6 +81,29 @@ func NewController(ctx context.Context, store store.Store, metrics telemetry.App
log.Fatal(fmt.Errorf("error creating metrics: %w", err))
}
newNetworkMapBuilder, err := strconv.ParseBool(os.Getenv(network_map.EnvNewNetworkMapBuilder))
if err != nil {
log.WithContext(ctx).Warnf("failed to parse %s, using default value false: %v", network_map.EnvNewNetworkMapBuilder, err)
newNetworkMapBuilder = false
}
compactedNetworkMap := true
compactedEnv := os.Getenv(types.EnvNewNetworkMapCompacted)
parsedCompactedNmap, err := strconv.ParseBool(compactedEnv)
if err != nil && len(compactedEnv) > 0 {
log.WithContext(ctx).Warnf("failed to parse %s, using default value true: %v", types.EnvNewNetworkMapCompacted, err)
}
if err == nil && !parsedCompactedNmap {
log.WithContext(ctx).Info("disabling compacted mode")
compactedNetworkMap = false
}
ids := strings.Split(os.Getenv(network_map.EnvNewNetworkMapAccounts), ",")
expIDs := make(map[string]struct{}, len(ids))
for _, id := range ids {
expIDs[id] = struct{}{}
}
return &Controller{
repo: newRepository(store),
metrics: nMetrics,
@@ -84,6 +117,12 @@ func NewController(ctx context.Context, store store.Store, metrics telemetry.App
proxyController: proxyController,
EphemeralPeersManager: ephemeralPeersManager,
holder: types.NewHolder(),
expNewNetworkMap: newNetworkMapBuilder,
expNewNetworkMapAIDs: expIDs,
compactedNetworkMap: compactedNetworkMap,
}
}
@@ -114,9 +153,17 @@ func (c *Controller) CountStreams() int {
func (c *Controller) sendUpdateAccountPeers(ctx context.Context, accountID string) error {
log.WithContext(ctx).Tracef("updating peers for account %s from %s", accountID, util.GetCallerName())
account, err := c.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
if err != nil {
return fmt.Errorf("failed to get account: %v", err)
var (
account *types.Account
err error
)
if c.experimentalNetworkMap(accountID) {
account = c.getAccountFromHolderOrInit(ctx, accountID)
} else {
account, err = c.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
if err != nil {
return fmt.Errorf("failed to get account: %v", err)
}
}
globalStart := time.Now()
@@ -150,6 +197,10 @@ func (c *Controller) sendUpdateAccountPeers(ctx context.Context, accountID strin
routers := account.GetResourceRoutersMap()
groupIDToUserIDs := account.GetActiveGroupUsers()
if c.experimentalNetworkMap(accountID) {
c.initNetworkMapBuilderIfNeeded(account, approvedPeersMap)
}
proxyNetworkMaps, err := c.proxyController.GetProxyNetworkMapsAll(ctx, accountID, account.Peers)
if err != nil {
log.WithContext(ctx).Errorf("failed to get proxy network maps: %v", err)
@@ -192,7 +243,16 @@ func (c *Controller) sendUpdateAccountPeers(ctx context.Context, accountID strin
c.metrics.CountCalcPostureChecksDuration(time.Since(start))
start = time.Now()
remotePeerNetworkMap := account.GetPeerNetworkMapFromComponents(ctx, p.ID, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
var remotePeerNetworkMap *types.NetworkMap
switch {
case c.experimentalNetworkMap(accountID):
remotePeerNetworkMap = c.getPeerNetworkMapExp(ctx, p.AccountID, p.ID, approvedPeersMap, peersCustomZone, accountZones, c.accountManagerMetrics)
case c.compactedNetworkMap:
remotePeerNetworkMap = account.GetPeerNetworkMapFromComponents(ctx, p.ID, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
default:
remotePeerNetworkMap = account.GetPeerNetworkMap(ctx, p.ID, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
}
c.metrics.CountCalcPeerNetworkMapDuration(time.Since(start))
@@ -258,6 +318,10 @@ func (c *Controller) bufferSendUpdateAccountPeers(ctx context.Context, accountID
// UpdateAccountPeers updates all peers that belong to an account.
// Should be called when changes have to be synced to peers.
//
// It first calls RecalculateNetworkMapCache for the account and then
// sendUpdateAccountPeers. A failure of the cache recalculation aborts the
// update; the underlying error is wrapped with %w so callers can inspect it
// with errors.Is / errors.As (the formatted message is unchanged).
func (c *Controller) UpdateAccountPeers(ctx context.Context, accountID string) error {
	if err := c.RecalculateNetworkMapCache(ctx, accountID); err != nil {
		return fmt.Errorf("recalculate network map cache: %w", err)
	}

	return c.sendUpdateAccountPeers(ctx, accountID)
}
@@ -307,7 +371,16 @@ func (c *Controller) UpdateAccountPeer(ctx context.Context, accountId string, pe
return err
}
remotePeerNetworkMap := account.GetPeerNetworkMapFromComponents(ctx, peerId, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
var remotePeerNetworkMap *types.NetworkMap
switch {
case c.experimentalNetworkMap(accountId):
remotePeerNetworkMap = c.getPeerNetworkMapExp(ctx, peer.AccountID, peer.ID, approvedPeersMap, peersCustomZone, accountZones, c.accountManagerMetrics)
case c.compactedNetworkMap:
remotePeerNetworkMap = account.GetPeerNetworkMapFromComponents(ctx, peerId, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
default:
remotePeerNetworkMap = account.GetPeerNetworkMap(ctx, peerId, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
}
proxyNetworkMap, ok := proxyNetworkMaps[peer.ID]
if ok {
@@ -378,9 +451,17 @@ func (c *Controller) GetValidatedPeerWithMap(ctx context.Context, isRequiresAppr
return peer, emptyMap, nil, 0, nil
}
account, err := c.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
if err != nil {
return nil, nil, nil, 0, err
var (
account *types.Account
err error
)
if c.experimentalNetworkMap(accountID) {
account = c.getAccountFromHolderOrInit(ctx, accountID)
} else {
account, err = c.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
if err != nil {
return nil, nil, nil, 0, err
}
}
account.InjectProxyPolicies(ctx)
@@ -412,10 +493,20 @@ func (c *Controller) GetValidatedPeerWithMap(ctx context.Context, isRequiresAppr
return nil, nil, nil, 0, err
}
resourcePolicies := account.GetResourcePoliciesMap()
routers := account.GetResourceRoutersMap()
groupIDToUserIDs := account.GetActiveGroupUsers()
networkMap := account.GetPeerNetworkMapFromComponents(ctx, peer.ID, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
var networkMap *types.NetworkMap
if c.experimentalNetworkMap(accountID) {
networkMap = c.getPeerNetworkMapExp(ctx, peer.AccountID, peer.ID, approvedPeersMap, peersCustomZone, accountZones, c.accountManagerMetrics)
} else {
resourcePolicies := account.GetResourcePoliciesMap()
routers := account.GetResourceRoutersMap()
groupIDToUserIDs := account.GetActiveGroupUsers()
if c.compactedNetworkMap {
networkMap = account.GetPeerNetworkMapFromComponents(ctx, peer.ID, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
} else {
networkMap = account.GetPeerNetworkMap(ctx, peer.ID, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
}
}
proxyNetworkMap, ok := proxyNetworkMaps[peer.ID]
if ok {
@@ -427,6 +518,108 @@ func (c *Controller) GetValidatedPeerWithMap(ctx context.Context, isRequiresAppr
return peer, networkMap, postureChecks, dnsFwdPort, nil
}
// initNetworkMapBuilderIfNeeded first reconciles the account with the holder
// via enrichAccountFromHolder, then calls account.InitNetworkMapBuilderIfNeeded
// with the given validated peers.
func (c *Controller) initNetworkMapBuilderIfNeeded(account *types.Account, validatedPeers map[string]struct{}) {
c.enrichAccountFromHolder(account)
account.InitNetworkMapBuilderIfNeeded(validatedPeers)
}
// getPeerNetworkMapExp returns the network map for peerId as computed by
// account.GetPeerNetworkMapExp, using the account obtained from
// getAccountFromHolderOrInit.
//
// When no account can be obtained, a warning is logged and a NetworkMap
// containing only an empty Network is returned.
func (c *Controller) getPeerNetworkMapExp(
	ctx context.Context,
	accountId string,
	peerId string,
	validatedPeers map[string]struct{},
	peersCustomZone nbdns.CustomZone,
	accountZones []*zones.Zone,
	metrics *telemetry.AccountManagerMetrics,
) *types.NetworkMap {
	heldAccount := c.getAccountFromHolderOrInit(ctx, accountId)
	if heldAccount != nil {
		return heldAccount.GetPeerNetworkMapExp(ctx, peerId, peersCustomZone, accountZones, validatedPeers, metrics)
	}

	log.WithContext(ctx).Warnf("account %s not found in holder when getting peer network map", accountId)
	emptyMap := &types.NetworkMap{
		Network: &types.Network{},
	}
	return emptyMap
}
// onPeersAddedUpdNetworkMapCache reconciles the account with the holder via
// enrichAccountFromHolder, then forwards the added peer IDs to
// account.OnPeersAddedUpdNetworkMapCache.
func (c *Controller) onPeersAddedUpdNetworkMapCache(account *types.Account, peerIds ...string) {
c.enrichAccountFromHolder(account)
account.OnPeersAddedUpdNetworkMapCache(peerIds...)
}
// onPeerDeletedUpdNetworkMapCache reconciles the account with the holder via
// enrichAccountFromHolder, then returns the result of
// account.OnPeerDeletedUpdNetworkMapCache for the deleted peer ID.
func (c *Controller) onPeerDeletedUpdNetworkMapCache(account *types.Account, peerId string) error {
c.enrichAccountFromHolder(account)
return account.OnPeerDeletedUpdNetworkMapCache(peerId)
}
// UpdatePeerInNetworkMapCache forwards the peer to
// account.UpdatePeerInNetworkMapCache on the account held for accountId.
// It is a no-op when the holder has no account for that ID.
func (c *Controller) UpdatePeerInNetworkMapCache(accountId string, peer *nbpeer.Peer) {
	if held := c.getAccountFromHolder(accountId); held != nil {
		held.UpdatePeerInNetworkMapCache(peer)
	}
}
// recalculateNetworkMapCache calls account.RecalculateNetworkMapCache with the
// validated peers and then stores the account in the holder.
func (c *Controller) recalculateNetworkMapCache(account *types.Account, validatedPeers map[string]struct{}) {
account.RecalculateNetworkMapCache(validatedPeers)
c.updateAccountInHolder(account)
}
// RecalculateNetworkMapCache rebuilds the network map cache of the account
// when the experimental network map is enabled for it; otherwise it does
// nothing and returns nil.
//
// The account is fetched through the request buffer and its validated peers
// through the integrated peer validator; an error from either step is
// returned to the caller.
func (c *Controller) RecalculateNetworkMapCache(ctx context.Context, accountId string) error {
	if !c.experimentalNetworkMap(accountId) {
		return nil
	}

	account, err := c.requestBuffer.GetAccountWithBackpressure(ctx, accountId)
	if err != nil {
		return err
	}

	groups := maps.Values(account.Groups)
	peers := maps.Values(account.Peers)
	validatedPeers, err := c.integratedPeerValidator.GetValidatedPeers(ctx, account.Id, groups, peers, account.Settings.Extra)
	if err != nil {
		log.WithContext(ctx).Errorf("failed to get validate peers: %v", err)
		return err
	}

	c.recalculateNetworkMapCache(account, validatedPeers)
	return nil
}
// experimentalNetworkMap reports whether the experimental network map applies
// to accountId: either the expNewNetworkMap flag is set, or the account ID is
// a key of expNewNetworkMapAIDs.
func (c *Controller) experimentalNetworkMap(accountId string) bool {
	if c.expNewNetworkMap {
		return true
	}
	_, listed := c.expNewNetworkMapAIDs[accountId]
	return listed
}
// enrichAccountFromHolder reconciles the given account with the one stored in
// the holder under the same ID.
//
// If the holder has no such account, the given account is stored. Otherwise
// the held account's NetworkMapCache is copied onto the given account, and the
// given account is stored in the holder only when that cache is non-nil.
func (c *Controller) enrichAccountFromHolder(account *types.Account) {
	held := c.holder.GetAccount(account.Id)
	if held != nil {
		account.NetworkMapCache = held.NetworkMapCache
		if account.NetworkMapCache == nil {
			// The held account has no cache: leave the holder untouched.
			return
		}
	}
	c.holder.AddAccount(account)
}
// getAccountFromHolder returns whatever c.holder.GetAccount yields for
// accountID; callers in this file treat a nil result as "not held".
func (c *Controller) getAccountFromHolder(accountID string) *types.Account {
return c.holder.GetAccount(accountID)
}
// getAccountFromHolderOrInit returns the account held in c.holder for
// accountID. When the holder has no entry, it calls holder.LoadOrStoreFunc
// with c.requestBuffer.GetAccountWithBackpressure as the loader.
//
// Returns nil when that call fails. The error is logged here because the nil
// return value cannot carry it to the caller; callers must handle a nil result.
func (c *Controller) getAccountFromHolderOrInit(ctx context.Context, accountID string) *types.Account {
	if held := c.holder.GetAccount(accountID); held != nil {
		return held
	}

	account, err := c.holder.LoadOrStoreFunc(ctx, accountID, c.requestBuffer.GetAccountWithBackpressure)
	if err != nil {
		// Previously this error was dropped silently, hiding the root cause.
		log.WithContext(ctx).Errorf("failed to load account %s into holder: %v", accountID, err)
		return nil
	}
	return account
}
// updateAccountInHolder stores the account in the holder via
// c.holder.AddAccount.
func (c *Controller) updateAccountInHolder(account *types.Account) {
c.holder.AddAccount(account)
}
// GetDNSDomain returns the configured dnsDomain
func (c *Controller) GetDNSDomain(settings *types.Settings) string {
if settings == nil {
@@ -563,7 +756,16 @@ func isPeerInPolicySourceGroups(account *types.Account, peerID string, policy *t
}
func (c *Controller) OnPeersUpdated(ctx context.Context, accountID string, peerIDs []string) error {
err := c.bufferSendUpdateAccountPeers(ctx, accountID)
peers, err := c.repo.GetPeersByIDs(ctx, accountID, peerIDs)
if err != nil {
return fmt.Errorf("failed to get peers by ids: %w", err)
}
for _, peer := range peers {
c.UpdatePeerInNetworkMapCache(accountID, peer)
}
err = c.bufferSendUpdateAccountPeers(ctx, accountID)
if err != nil {
log.WithContext(ctx).Errorf("failed to buffer update account peers for peer update in account %s: %v", accountID, err)
}
@@ -573,6 +775,14 @@ func (c *Controller) OnPeersUpdated(ctx context.Context, accountID string, peerI
// OnPeersAdded reacts to peers being added to an account.
//
// When the experimental network map is enabled for the account, the account is
// fetched through the request buffer and the new peer IDs are pushed into its
// network map cache; a fetch error is returned immediately. In every other
// case the method ends by buffering an account-wide peer update and returning
// that call's result.
func (c *Controller) OnPeersAdded(ctx context.Context, accountID string, peerIDs []string) error {
	log.WithContext(ctx).Debugf("OnPeersAdded call to add peers: %v", peerIDs)

	if !c.experimentalNetworkMap(accountID) {
		return c.bufferSendUpdateAccountPeers(ctx, accountID)
	}

	account, err := c.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
	if err != nil {
		return err
	}
	log.WithContext(ctx).Debugf("peers are ready to be added to networkmap cache: %v", peerIDs)
	c.onPeersAddedUpdNetworkMapCache(account, peerIDs...)

	return c.bufferSendUpdateAccountPeers(ctx, accountID)
}
@@ -607,6 +817,19 @@ func (c *Controller) OnPeersDeleted(ctx context.Context, accountID string, peerI
MessageType: network_map.MessageTypeNetworkMap,
})
c.peersUpdateManager.CloseChannel(ctx, peerID)
if c.experimentalNetworkMap(accountID) {
account, err := c.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
if err != nil {
log.WithContext(ctx).Errorf("failed to get account %s: %v", accountID, err)
continue
}
err = c.onPeerDeletedUpdNetworkMapCache(account, peerID)
if err != nil {
log.WithContext(ctx).Errorf("failed to update network map cache for deleted peer %s in account %s: %v", peerID, accountID, err)
continue
}
}
}
return c.bufferSendUpdateAccountPeers(ctx, accountID)
@@ -649,11 +872,21 @@ func (c *Controller) GetNetworkMap(ctx context.Context, peerID string) (*types.N
return nil, err
}
account.InjectProxyPolicies(ctx)
resourcePolicies := account.GetResourcePoliciesMap()
routers := account.GetResourceRoutersMap()
groupIDToUserIDs := account.GetActiveGroupUsers()
networkMap := account.GetPeerNetworkMapFromComponents(ctx, peer.ID, peersCustomZone, accountZones, validatedPeers, resourcePolicies, routers, nil, groupIDToUserIDs)
var networkMap *types.NetworkMap
if c.experimentalNetworkMap(peer.AccountID) {
networkMap = c.getPeerNetworkMapExp(ctx, peer.AccountID, peerID, validatedPeers, peersCustomZone, accountZones, nil)
} else {
account.InjectProxyPolicies(ctx)
resourcePolicies := account.GetResourcePoliciesMap()
routers := account.GetResourceRoutersMap()
groupIDToUserIDs := account.GetActiveGroupUsers()
if c.compactedNetworkMap {
networkMap = account.GetPeerNetworkMapFromComponents(ctx, peer.ID, peersCustomZone, accountZones, validatedPeers, resourcePolicies, routers, nil, groupIDToUserIDs)
} else {
networkMap = account.GetPeerNetworkMap(ctx, peer.ID, peersCustomZone, accountZones, validatedPeers, resourcePolicies, routers, nil, groupIDToUserIDs)
}
}
proxyNetworkMap, ok := proxyNetworkMaps[peer.ID]
if ok {

View File

@@ -12,6 +12,9 @@ import (
)
const (
EnvNewNetworkMapBuilder = "NB_EXPERIMENT_NETWORK_MAP"
EnvNewNetworkMapAccounts = "NB_EXPERIMENT_NETWORK_MAP_ACCOUNTS"
DnsForwarderPort = nbdns.ForwarderServerPort
OldForwarderPort = nbdns.ForwarderClientPort
DnsForwarderPortMinVersion = "v0.59.0"

View File

@@ -408,7 +408,7 @@ func TestAccount_GetPeerNetworkMap(t *testing.T) {
}
customZone := account.GetPeersCustomZone(context.Background(), "netbird.io")
networkMap := account.GetPeerNetworkMapFromComponents(context.Background(), testCase.peerID, customZone, nil, validatedPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil, account.GetActiveGroupUsers())
networkMap := account.GetPeerNetworkMap(context.Background(), testCase.peerID, customZone, nil, validatedPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil, account.GetActiveGroupUsers())
assert.Len(t, networkMap.Peers, len(testCase.expectedPeers))
assert.Len(t, networkMap.OfflinePeers, len(testCase.expectedOfflinePeers))
}
@@ -1171,6 +1171,11 @@ func TestAccountManager_AddPeerWithUserID(t *testing.T) {
assert.Equal(t, peer.IP.String(), fmt.Sprint(ev.Meta["ip"]))
}
// TestAccountManager_NetworkUpdates_SaveGroup_Experimental runs the shared
// SaveGroup scenario with network_map.EnvNewNetworkMapBuilder set to "true"
// for the duration of the test.
func TestAccountManager_NetworkUpdates_SaveGroup_Experimental(t *testing.T) {
t.Setenv(network_map.EnvNewNetworkMapBuilder, "true")
testAccountManager_NetworkUpdates_SaveGroup(t)
}
// TestAccountManager_NetworkUpdates_SaveGroup runs the shared SaveGroup
// scenario without setting any environment variable in this wrapper.
func TestAccountManager_NetworkUpdates_SaveGroup(t *testing.T) {
testAccountManager_NetworkUpdates_SaveGroup(t)
}
@@ -1226,6 +1231,11 @@ func testAccountManager_NetworkUpdates_SaveGroup(t *testing.T) {
wg.Wait()
}
// TestAccountManager_NetworkUpdates_DeletePolicy_Experimental runs the shared
// DeletePolicy scenario with network_map.EnvNewNetworkMapBuilder set to "true"
// for the duration of the test.
func TestAccountManager_NetworkUpdates_DeletePolicy_Experimental(t *testing.T) {
t.Setenv(network_map.EnvNewNetworkMapBuilder, "true")
testAccountManager_NetworkUpdates_DeletePolicy(t)
}
// TestAccountManager_NetworkUpdates_DeletePolicy runs the shared DeletePolicy
// scenario without setting any environment variable in this wrapper.
func TestAccountManager_NetworkUpdates_DeletePolicy(t *testing.T) {
testAccountManager_NetworkUpdates_DeletePolicy(t)
}
@@ -1264,6 +1274,11 @@ func testAccountManager_NetworkUpdates_DeletePolicy(t *testing.T) {
wg.Wait()
}
// TestAccountManager_NetworkUpdates_SavePolicy_Experimental runs the shared
// SavePolicy scenario with network_map.EnvNewNetworkMapBuilder set to "true"
// for the duration of the test.
func TestAccountManager_NetworkUpdates_SavePolicy_Experimental(t *testing.T) {
t.Setenv(network_map.EnvNewNetworkMapBuilder, "true")
testAccountManager_NetworkUpdates_SavePolicy(t)
}
// TestAccountManager_NetworkUpdates_SavePolicy runs the shared SavePolicy
// scenario without setting any environment variable in this wrapper.
func TestAccountManager_NetworkUpdates_SavePolicy(t *testing.T) {
testAccountManager_NetworkUpdates_SavePolicy(t)
}
@@ -1317,6 +1332,11 @@ func testAccountManager_NetworkUpdates_SavePolicy(t *testing.T) {
wg.Wait()
}
// TestAccountManager_NetworkUpdates_DeletePeer_Experimental runs the shared
// DeletePeer scenario with network_map.EnvNewNetworkMapBuilder set to "true"
// for the duration of the test.
func TestAccountManager_NetworkUpdates_DeletePeer_Experimental(t *testing.T) {
t.Setenv(network_map.EnvNewNetworkMapBuilder, "true")
testAccountManager_NetworkUpdates_DeletePeer(t)
}
// TestAccountManager_NetworkUpdates_DeletePeer runs the shared DeletePeer
// scenario without setting any environment variable in this wrapper.
func TestAccountManager_NetworkUpdates_DeletePeer(t *testing.T) {
testAccountManager_NetworkUpdates_DeletePeer(t)
}
@@ -1377,6 +1397,11 @@ func testAccountManager_NetworkUpdates_DeletePeer(t *testing.T) {
wg.Wait()
}
// TestAccountManager_NetworkUpdates_DeleteGroup_Experimental runs the shared
// DeleteGroup scenario with network_map.EnvNewNetworkMapBuilder set to "true"
// for the duration of the test.
func TestAccountManager_NetworkUpdates_DeleteGroup_Experimental(t *testing.T) {
t.Setenv(network_map.EnvNewNetworkMapBuilder, "true")
testAccountManager_NetworkUpdates_DeleteGroup(t)
}
// TestAccountManager_NetworkUpdates_DeleteGroup runs the shared DeleteGroup
// scenario without setting any environment variable in this wrapper.
func TestAccountManager_NetworkUpdates_DeleteGroup(t *testing.T) {
testAccountManager_NetworkUpdates_DeleteGroup(t)
}
@@ -1608,6 +1633,75 @@ func TestFileStore_GetRoutesByPrefix(t *testing.T) {
assert.Contains(t, routeIDs, route.ID("route-2"))
}
// TestAccount_GetRoutesToSync builds an account with three peers, a single
// group ("group1") containing peer-1 and peer-2, and three enabled routes
// distributed to that group, then checks which routes GetRoutesToSync returns
// for peer-2 and for peer-3.
func TestAccount_GetRoutesToSync(t *testing.T) {
_, prefix, err := route.ParseNetwork("192.168.64.0/24")
if err != nil {
t.Fatal(err)
}
_, prefix2, err := route.ParseNetwork("192.168.0.0/24")
if err != nil {
t.Fatal(err)
}
account := &types.Account{
Peers: map[string]*nbpeer.Peer{
// NOTE(review): "peer-3" is declared with Key "peer-1" — confirm this is intentional.
"peer-1": {Key: "peer-1", Meta: nbpeer.PeerSystemMeta{GoOS: "linux"}}, "peer-2": {Key: "peer-2", Meta: nbpeer.PeerSystemMeta{GoOS: "linux"}}, "peer-3": {Key: "peer-1", Meta: nbpeer.PeerSystemMeta{GoOS: "linux"}},
},
// peer-3 is deliberately left out of the only group.
Groups: map[string]*types.Group{"group1": {ID: "group1", Peers: []string{"peer-1", "peer-2"}}},
Routes: map[route.ID]*route.Route{
// route-1 and route-3 share the same network and NetID but use different routing peers.
"route-1": {
ID: "route-1",
Network: prefix,
NetID: "network-1",
Description: "network-1",
Peer: "peer-1",
NetworkType: 0,
Masquerade: false,
Metric: 999,
Enabled: true,
Groups: []string{"group1"},
},
"route-2": {
ID: "route-2",
Network: prefix2,
NetID: "network-2",
Description: "network-2",
Peer: "peer-2",
NetworkType: 0,
Masquerade: false,
Metric: 999,
Enabled: true,
Groups: []string{"group1"},
},
"route-3": {
ID: "route-3",
Network: prefix,
NetID: "network-1",
Description: "network-1",
Peer: "peer-2",
NetworkType: 0,
Masquerade: false,
Metric: 999,
Enabled: true,
Groups: []string{"group1"},
},
},
}
// For peer-2 exactly two routes are expected: route-2 and route-3.
routes := account.GetRoutesToSync(context.Background(), "peer-2", []*nbpeer.Peer{{Key: "peer-1"}, {Key: "peer-3"}}, account.GetPeerGroups("peer-2"))
assert.Len(t, routes, 2)
// Collect the returned IDs into a set so the assertions do not depend on order.
routeIDs := make(map[route.ID]struct{}, 2)
for _, r := range routes {
routeIDs[r.ID] = struct{}{}
}
assert.Contains(t, routeIDs, route.ID("route-2"))
assert.Contains(t, routeIDs, route.ID("route-3"))
// peer-3 is not a member of group1; no routes are expected for it.
emptyRoutes := account.GetRoutesToSync(context.Background(), "peer-3", []*nbpeer.Peer{{Key: "peer-1"}, {Key: "peer-2"}}, account.GetPeerGroups("peer-3"))
assert.Len(t, emptyRoutes, 0)
}
func TestAccount_Copy(t *testing.T) {
account := &types.Account{
Id: "account1",
@@ -1730,7 +1824,9 @@ func TestAccount_Copy(t *testing.T) {
AccountID: "account1",
},
},
NetworkMapCache: &types.NetworkMapBuilder{},
}
account.InitOnce()
err := hasNilField(account)
if err != nil {
t.Fatal(err)

View File

@@ -417,7 +417,7 @@ func (h *Handler) GetAccessiblePeers(w http.ResponseWriter, r *http.Request) {
dnsDomain := h.networkMapController.GetDNSDomain(account.Settings)
netMap := account.GetPeerNetworkMapFromComponents(r.Context(), peerID, dns.CustomZone{}, nil, validPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil, account.GetActiveGroupUsers())
netMap := account.GetPeerNetworkMap(r.Context(), peerID, dns.CustomZone{}, nil, validPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil, account.GetActiveGroupUsers())
util.WriteJSONObject(r.Context(), w, toAccessiblePeers(netMap, dnsDomain))
}

View File

@@ -274,7 +274,7 @@ func identityProviderToConnectorConfig(idpConfig *types.IdentityProvider) *dex.C
}
// generateIdentityProviderID generates a unique ID for an identity provider.
// For specific provider types (okta, zitadel, entra, google, pocketid, microsoft, adfs),
// For specific provider types (okta, zitadel, entra, google, pocketid, microsoft),
// the ID is prefixed with the type name. Generic OIDC providers get no prefix.
func generateIdentityProviderID(idpType types.IdentityProviderType) string {
id := xid.New().String()
@@ -296,8 +296,6 @@ func generateIdentityProviderID(idpType types.IdentityProviderType) string {
return "authentik-" + id
case types.IdentityProviderTypeKeycloak:
return "keycloak-" + id
case types.IdentityProviderTypeADFS:
return "adfs-" + id
default:
// Generic OIDC - no prefix
return id

View File

@@ -179,6 +179,11 @@ func TestAccountManager_GetNetworkMap(t *testing.T) {
testGetNetworkMapGeneral(t)
}
// TestAccountManager_GetNetworkMap_Experimental runs the shared
// testGetNetworkMapGeneral scenario with network_map.EnvNewNetworkMapBuilder
// set to "true" for the duration of the test.
func TestAccountManager_GetNetworkMap_Experimental(t *testing.T) {
t.Setenv(network_map.EnvNewNetworkMapBuilder, "true")
testGetNetworkMapGeneral(t)
}
func testGetNetworkMapGeneral(t *testing.T) {
manager, _, err := createManager(t)
if err != nil {
@@ -1011,6 +1016,11 @@ func BenchmarkUpdateAccountPeers(b *testing.B) {
}
}
// TestUpdateAccountPeers_Experimental runs the shared testUpdateAccountPeers
// scenario with network_map.EnvNewNetworkMapBuilder set to "true" for the
// duration of the test.
func TestUpdateAccountPeers_Experimental(t *testing.T) {
t.Setenv(network_map.EnvNewNetworkMapBuilder, "true")
testUpdateAccountPeers(t)
}
// TestUpdateAccountPeers runs the shared testUpdateAccountPeers scenario
// without setting any environment variable in this wrapper.
func TestUpdateAccountPeers(t *testing.T) {
testUpdateAccountPeers(t)
}
@@ -1590,6 +1600,7 @@ func Test_RegisterPeerRollbackOnFailure(t *testing.T) {
}
func Test_LoginPeer(t *testing.T) {
t.Setenv(network_map.EnvNewNetworkMapBuilder, "true")
if runtime.GOOS == "windows" {
t.Skip("The SQLite store is not properly supported by Windows yet")
}

View File

@@ -2,8 +2,10 @@ package server
import (
"context"
"fmt"
"net"
"net/netip"
"sort"
"testing"
"time"
@@ -1838,6 +1840,11 @@ func TestAccount_getPeersRoutesFirewall(t *testing.T) {
},
}
validatedPeers := make(map[string]struct{})
for p := range account.Peers {
validatedPeers[p] = struct{}{}
}
t.Run("check applied policies for the route", func(t *testing.T) {
route1 := account.Routes["route1"]
policies := types.GetAllRoutePoliciesFromGroups(account, route1.AccessControlGroups)
@@ -1851,6 +1858,116 @@ func TestAccount_getPeersRoutesFirewall(t *testing.T) {
policies = types.GetAllRoutePoliciesFromGroups(account, route3.AccessControlGroups)
assert.Len(t, policies, 0)
})
t.Run("check peer routes firewall rules", func(t *testing.T) {
routesFirewallRules := account.GetPeerRoutesFirewallRules(context.Background(), "peerA", validatedPeers)
assert.Len(t, routesFirewallRules, 4)
expectedRoutesFirewallRules := []*types.RouteFirewallRule{
{
SourceRanges: []string{
fmt.Sprintf(types.AllowedIPsFormat, peerCIp),
fmt.Sprintf(types.AllowedIPsFormat, peerHIp),
fmt.Sprintf(types.AllowedIPsFormat, peerBIp),
},
Action: "accept",
Destination: "192.168.0.0/16",
Protocol: "all",
Port: 80,
RouteID: "route1:peerA",
},
{
SourceRanges: []string{
fmt.Sprintf(types.AllowedIPsFormat, peerCIp),
fmt.Sprintf(types.AllowedIPsFormat, peerHIp),
fmt.Sprintf(types.AllowedIPsFormat, peerBIp),
},
Action: "accept",
Destination: "192.168.0.0/16",
Protocol: "all",
Port: 320,
RouteID: "route1:peerA",
},
}
additionalFirewallRule := []*types.RouteFirewallRule{
{
SourceRanges: []string{
fmt.Sprintf(types.AllowedIPsFormat, peerJIp),
},
Action: "accept",
Destination: "192.168.10.0/16",
Protocol: "tcp",
Port: 80,
RouteID: "route4:peerA",
},
{
SourceRanges: []string{
fmt.Sprintf(types.AllowedIPsFormat, peerKIp),
},
Action: "accept",
Destination: "192.168.10.0/16",
Protocol: "all",
RouteID: "route4:peerA",
},
}
assert.ElementsMatch(t, orderRuleSourceRanges(routesFirewallRules), orderRuleSourceRanges(append(expectedRoutesFirewallRules, additionalFirewallRule...)))
// peerD is also the routing peer for route1, should contain same routes firewall rules as peerA
routesFirewallRules = account.GetPeerRoutesFirewallRules(context.Background(), "peerD", validatedPeers)
assert.Len(t, routesFirewallRules, 2)
for _, rule := range expectedRoutesFirewallRules {
rule.RouteID = "route1:peerD"
}
assert.ElementsMatch(t, orderRuleSourceRanges(routesFirewallRules), orderRuleSourceRanges(expectedRoutesFirewallRules))
// peerE is a single routing peer for route 2 and route 3
routesFirewallRules = account.GetPeerRoutesFirewallRules(context.Background(), "peerE", validatedPeers)
assert.Len(t, routesFirewallRules, 3)
expectedRoutesFirewallRules = []*types.RouteFirewallRule{
{
SourceRanges: []string{"100.65.250.202/32", "100.65.13.186/32"},
Action: "accept",
Destination: existingNetwork.String(),
Protocol: "tcp",
PortRange: types.RulePortRange{Start: 80, End: 350},
RouteID: "route2",
},
{
SourceRanges: []string{"0.0.0.0/0"},
Action: "accept",
Destination: "192.0.2.0/32",
Protocol: "all",
Domains: domain.List{"example.com"},
IsDynamic: true,
RouteID: "route3",
},
{
SourceRanges: []string{"::/0"},
Action: "accept",
Destination: "192.0.2.0/32",
Protocol: "all",
Domains: domain.List{"example.com"},
IsDynamic: true,
RouteID: "route3",
},
}
assert.ElementsMatch(t, orderRuleSourceRanges(routesFirewallRules), orderRuleSourceRanges(expectedRoutesFirewallRules))
// peerC is part of route1 distribution groups but should not receive the routes firewall rules
routesFirewallRules = account.GetPeerRoutesFirewallRules(context.Background(), "peerC", validatedPeers)
assert.Len(t, routesFirewallRules, 0)
})
}
// orderRuleSourceRanges sorts the SourceRanges of every rule in ruleList in
// ascending string order and returns the same slice, so rule lists can be
// compared without depending on source-range order. The rules are modified in
// place.
func orderRuleSourceRanges(ruleList []*types.RouteFirewallRule) []*types.RouteFirewallRule {
	for i := range ruleList {
		sort.Strings(ruleList[i].SourceRanges)
	}
	return ruleList
}
func TestRouteAccountPeersUpdate(t *testing.T) {
@@ -2548,6 +2665,11 @@ func TestAccount_GetPeerNetworkResourceFirewallRules(t *testing.T) {
},
}
validatedPeers := make(map[string]struct{})
for p := range account.Peers {
validatedPeers[p] = struct{}{}
}
t.Run("validate applied policies for different network resources", func(t *testing.T) {
// Test case: Resource1 is directly applied to the policy (policyResource1)
policies := account.GetPoliciesForNetworkResource("resource1")
@@ -2571,4 +2693,127 @@ func TestAccount_GetPeerNetworkResourceFirewallRules(t *testing.T) {
policies = account.GetPoliciesForNetworkResource("resource6")
assert.Len(t, policies, 1, "resource6 should have exactly 1 policy applied via access control groups")
})
t.Run("validate routing peer firewall rules for network resources", func(t *testing.T) {
resourcePoliciesMap := account.GetResourcePoliciesMap()
resourceRoutersMap := account.GetResourceRoutersMap()
_, routes, sourcePeers := account.GetNetworkResourcesRoutesToSync(context.Background(), "peerA", resourcePoliciesMap, resourceRoutersMap)
firewallRules := account.GetPeerNetworkResourceFirewallRules(context.Background(), account.Peers["peerA"], validatedPeers, routes, resourcePoliciesMap)
assert.Len(t, firewallRules, 4)
assert.Len(t, sourcePeers, 5)
expectedFirewallRules := []*types.RouteFirewallRule{
{
SourceRanges: []string{
fmt.Sprintf(types.AllowedIPsFormat, peerCIp),
fmt.Sprintf(types.AllowedIPsFormat, peerHIp),
fmt.Sprintf(types.AllowedIPsFormat, peerBIp),
},
Action: "accept",
Destination: "192.168.0.0/16",
Protocol: "all",
Port: 80,
RouteID: "resource2:peerA",
},
{
SourceRanges: []string{
fmt.Sprintf(types.AllowedIPsFormat, peerCIp),
fmt.Sprintf(types.AllowedIPsFormat, peerHIp),
fmt.Sprintf(types.AllowedIPsFormat, peerBIp),
},
Action: "accept",
Destination: "192.168.0.0/16",
Protocol: "all",
Port: 320,
RouteID: "resource2:peerA",
},
}
additionalFirewallRules := []*types.RouteFirewallRule{
{
SourceRanges: []string{
fmt.Sprintf(types.AllowedIPsFormat, peerJIp),
},
Action: "accept",
Destination: "192.0.2.0/32",
Protocol: "tcp",
Port: 80,
Domains: domain.List{"example.com"},
IsDynamic: true,
RouteID: "resource4:peerA",
},
{
SourceRanges: []string{
fmt.Sprintf(types.AllowedIPsFormat, peerKIp),
},
Action: "accept",
Destination: "192.0.2.0/32",
Protocol: "all",
Domains: domain.List{"example.com"},
IsDynamic: true,
RouteID: "resource4:peerA",
},
}
assert.ElementsMatch(t, orderRuleSourceRanges(firewallRules), orderRuleSourceRanges(append(expectedFirewallRules, additionalFirewallRules...)))
// peerD is also the routing peer for resource2
_, routes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), "peerD", resourcePoliciesMap, resourceRoutersMap)
firewallRules = account.GetPeerNetworkResourceFirewallRules(context.Background(), account.Peers["peerD"], validatedPeers, routes, resourcePoliciesMap)
assert.Len(t, firewallRules, 2)
for _, rule := range expectedFirewallRules {
rule.RouteID = "resource2:peerD"
}
assert.ElementsMatch(t, orderRuleSourceRanges(firewallRules), orderRuleSourceRanges(expectedFirewallRules))
assert.Len(t, sourcePeers, 3)
// peerE is a single routing peer for resource1 and resource3
// PeerE should only receive rules for resource1 since resource3 has no applied policy
_, routes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), "peerE", resourcePoliciesMap, resourceRoutersMap)
firewallRules = account.GetPeerNetworkResourceFirewallRules(context.Background(), account.Peers["peerE"], validatedPeers, routes, resourcePoliciesMap)
assert.Len(t, firewallRules, 1)
assert.Len(t, sourcePeers, 2)
expectedFirewallRules = []*types.RouteFirewallRule{
{
SourceRanges: []string{"100.65.250.202/32", "100.65.13.186/32"},
Action: "accept",
Destination: "10.10.10.0/24",
Protocol: "tcp",
PortRange: types.RulePortRange{Start: 80, End: 350},
RouteID: "resource1:peerE",
},
}
assert.ElementsMatch(t, orderRuleSourceRanges(firewallRules), orderRuleSourceRanges(expectedFirewallRules))
// peerC is part of distribution groups for resource2 but should not receive the firewall rules
firewallRules = account.GetPeerRoutesFirewallRules(context.Background(), "peerC", validatedPeers)
assert.Len(t, firewallRules, 0)
// peerL is the single routing peer for resource5
_, routes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), "peerL", resourcePoliciesMap, resourceRoutersMap)
assert.Len(t, routes, 1)
firewallRules = account.GetPeerNetworkResourceFirewallRules(context.Background(), account.Peers["peerL"], validatedPeers, routes, resourcePoliciesMap)
assert.Len(t, firewallRules, 1)
assert.Len(t, sourcePeers, 1)
expectedFirewallRules = []*types.RouteFirewallRule{
{
SourceRanges: []string{"100.65.29.67/32"},
Action: "accept",
Destination: "10.12.12.1/32",
Protocol: "tcp",
Port: 8080,
RouteID: "resource5:peerL",
},
}
assert.ElementsMatch(t, orderRuleSourceRanges(firewallRules), orderRuleSourceRanges(expectedFirewallRules))
_, routes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), "peerM", resourcePoliciesMap, resourceRoutersMap)
assert.Len(t, routes, 1)
assert.Len(t, sourcePeers, 0)
_, routes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), "peerN", resourcePoliciesMap, resourceRoutersMap)
assert.Len(t, routes, 1)
assert.Len(t, sourcePeers, 2)
})
}

View File

@@ -1196,6 +1196,7 @@ func (s *SqlStore) getAccountGorm(ctx context.Context, accountID string) (*types
account.NameServerGroups[ns.ID] = &ns
}
account.NameServerGroupsG = nil
account.InitOnce()
return &account, nil
}
@@ -1634,6 +1635,7 @@ func (s *SqlStore) getAccount(ctx context.Context, accountID string) (*types.Acc
if sExtraIntegratedValidatorGroups.Valid {
_ = json.Unmarshal([]byte(sExtraIntegratedValidatorGroups.String), &account.Settings.Extra.IntegratedValidatorGroups)
}
account.InitOnce()
return &account, nil
}

View File

@@ -8,6 +8,7 @@ import (
"slices"
"strconv"
"strings"
"sync"
"time"
"github.com/hashicorp/go-multierror"
@@ -26,6 +27,7 @@ import (
networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/posture"
"github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/management/server/util"
"github.com/netbirdio/netbird/route"
"github.com/netbirdio/netbird/shared/management/domain"
@@ -108,9 +110,16 @@ type Account struct {
NetworkResources []*resourceTypes.NetworkResource `gorm:"foreignKey:AccountID;references:id"`
Onboarding AccountOnboarding `gorm:"foreignKey:AccountID;references:id;constraint:OnDelete:CASCADE"`
NetworkMapCache *NetworkMapBuilder `gorm:"-"`
nmapInitOnce *sync.Once `gorm:"-"`
ReverseProxyFreeDomainNonce string
}
// InitOnce assigns a brand-new sync.Once to the account's nmapInitOnce field.
// Every call replaces the previous Once, so whatever that Once guards may run
// once more after InitOnce has been called again.
func (a *Account) InitOnce() {
	a.nmapInitOnce = new(sync.Once)
}
// this class is used by gorm only
type PrimaryAccountInfo struct {
IsDomainPrimaryAccount bool
@@ -146,6 +155,108 @@ func (o AccountOnboarding) IsEqual(onboarding AccountOnboarding) bool {
o.SignupFormPending == onboarding.SignupFormPending
}
// GetRoutesToSync builds the list of routes to send to the peer with the given ID.
// The result starts with the peer's own enabled routes and is extended with the
// enabled routes of every ACL peer, provided the route is distributed to one of
// peerGroups and does not belong to an HA group the peer already serves itself
// (enabled or disabled).
// Note that enabled route.Route objects carry Peer.Key instead of Peer.ID.
func (a *Account) GetRoutesToSync(ctx context.Context, peerID string, aclPeers []*nbpeer.Peer, peerGroups LookupMap) []*route.Route {
	routes, peerDisabledRoutes := a.getRoutingPeerRoutes(ctx, peerID)

	// HA groups the peer itself is a member of, regardless of the route state.
	ownHAGroups := make(LookupMap)
	for _, own := range routes {
		ownHAGroups[string(own.GetHAUniqueID())] = struct{}{}
	}
	for _, own := range peerDisabledRoutes {
		ownHAGroups[string(own.GetHAUniqueID())] = struct{}{}
	}

	for _, aclPeer := range aclPeers {
		aclPeerRoutes, _ := a.getRoutingPeerRoutes(ctx, aclPeer.ID)
		distributed := a.filterRoutesByGroups(aclPeerRoutes, peerGroups)
		routes = append(routes, a.filterRoutesFromPeersOfSameHAGroup(distributed, ownHAGroups)...)
	}

	return routes
}
// filterRoutesFromPeersOfSameHAGroup keeps only the routes whose HA unique ID is
// absent from peerMemberships, i.e. routes that do not share an HA group with the peer.
// The result is nil when no route qualifies.
func (a *Account) filterRoutesFromPeersOfSameHAGroup(routes []*route.Route, peerMemberships LookupMap) []*route.Route {
	var kept []*route.Route
	for _, candidate := range routes {
		if _, sameHAGroup := peerMemberships[string(candidate.GetHAUniqueID())]; sameHAGroup {
			continue
		}
		kept = append(kept, candidate)
	}
	return kept
}
// filterRoutesByGroups keeps the routes that have at least one distribution group
// present in groupListMap. Each route appears at most once in the result, which
// is nil when nothing matches.
func (a *Account) filterRoutesByGroups(routes []*route.Route, groupListMap LookupMap) []*route.Route {
	var matched []*route.Route

nextRoute:
	for _, candidate := range routes {
		for _, groupID := range candidate.Groups {
			if _, ok := groupListMap[groupID]; ok {
				matched = append(matched, candidate)
				// one matching group is enough, move on to the next route
				continue nextRoute
			}
		}
	}

	return matched
}
// getRoutingPeerRoutes splits the routes served by the given routing peer into an
// enabled and a disabled list. A route is served by the peer either directly
// (route.Peer equals the peer ID) or through one of the route's peer groups.
// Every returned route is a copy; enabled copies carry Peer.Key instead of Peer.ID.
// Both lists are empty when the peer does not exist or serves no route.
func (a *Account) getRoutingPeerRoutes(ctx context.Context, peerID string) (enabledRoutes []*route.Route, disabledRoutes []*route.Route) {
	peer := a.GetPeer(peerID)
	if peer == nil {
		log.WithContext(ctx).Errorf("peer %s that doesn't exist under account %s", peerID, a.Id)
		return enabledRoutes, disabledRoutes
	}

	// collect files a route copy into the matching list, ignoring IDs already handled.
	handled := make(map[route.ID]struct{})
	collect := func(candidate *route.Route) {
		if _, dup := handled[candidate.ID]; dup {
			return
		}
		handled[candidate.ID] = struct{}{}

		if !candidate.Enabled {
			disabledRoutes = append(disabledRoutes, candidate)
			return
		}
		candidate.Peer = peer.Key
		enabledRoutes = append(enabledRoutes, candidate)
	}

	for _, r := range a.Routes {
		for _, groupID := range r.PeerGroups {
			group := a.GetGroup(groupID)
			if group == nil {
				log.WithContext(ctx).Errorf("route %s has peers group %s that doesn't exist under account %s", r.ID, groupID, a.Id)
				continue
			}
			if !slices.Contains(group.Peers, peerID) {
				continue
			}

			perPeerRoute := r.Copy()
			perPeerRoute.Peer = peerID
			perPeerRoute.PeerGroups = nil
			// the route ID must be unique per routing peer when the network map is distributed
			perPeerRoute.ID = route.ID(string(r.ID) + ":" + peerID)
			collect(perPeerRoute)
		}

		if r.Peer == peerID {
			collect(r.Copy())
		}
	}

	return enabledRoutes, disabledRoutes
}
// GetRoutesByPrefixOrDomains return list of routes by account and route prefix
func (a *Account) GetRoutesByPrefixOrDomains(prefix netip.Prefix, domains domain.List) []*route.Route {
var routes []*route.Route
@@ -165,6 +276,106 @@ func (a *Account) GetGroup(groupID string) *Group {
return a.Groups[groupID]
}
// GetPeerNetworkMap returns the networkmap for the given peer ID.
//
// If the peer is not part of the account, or is missing from validatedPeersMap,
// the returned map contains only a copy of the account network. Otherwise the map
// is assembled from the peer's connection resources, its routes, the network
// resource routes and their firewall rules, and the DNS configuration.
// ACL peers whose login has expired (when login expiration is enabled) are
// reported as OfflinePeers instead of Peers.
//
// When metrics is non-nil, the object count and the elapsed time are recorded,
// and a trace line is logged for maps holding more than 5000 objects.
func (a *Account) GetPeerNetworkMap(
	ctx context.Context,
	peerID string,
	peersCustomZone nbdns.CustomZone,
	accountZones []*zones.Zone,
	validatedPeersMap map[string]struct{},
	resourcePolicies map[string][]*Policy,
	routers map[string]map[string]*routerTypes.NetworkRouter,
	metrics *telemetry.AccountManagerMetrics,
	groupIDToUserIDs map[string][]string,
) *NetworkMap {
	start := time.Now()

	peer := a.Peers[peerID]
	if peer == nil {
		return &NetworkMap{
			Network: a.Network.Copy(),
		}
	}

	if _, ok := validatedPeersMap[peerID]; !ok {
		return &NetworkMap{
			Network: a.Network.Copy(),
		}
	}

	peerGroups := a.GetPeerGroups(peerID)

	aclPeers, firewallRules, authorizedUsers, enableSSH := a.GetPeerConnectionResources(ctx, peer, validatedPeersMap, groupIDToUserIDs)

	// exclude expired peers
	var peersToConnect []*nbpeer.Peer
	var expiredPeers []*nbpeer.Peer
	for _, p := range aclPeers {
		expired, _ := p.LoginExpired(a.Settings.PeerLoginExpiration)
		if a.Settings.PeerLoginExpirationEnabled && expired {
			expiredPeers = append(expiredPeers, p)
			continue
		}
		peersToConnect = append(peersToConnect, p)
	}

	routesUpdate := a.GetRoutesToSync(ctx, peerID, peersToConnect, peerGroups)
	routesFirewallRules := a.GetPeerRoutesFirewallRules(ctx, peerID, validatedPeersMap)

	isRouter, networkResourcesRoutes, sourcePeers := a.GetNetworkResourcesRoutesToSync(ctx, peerID, resourcePolicies, routers)
	// only routing peers receive firewall rules for network resources
	var networkResourcesFirewallRules []*RouteFirewallRule
	if isRouter {
		networkResourcesFirewallRules = a.GetPeerNetworkResourceFirewallRules(ctx, peer, validatedPeersMap, networkResourcesRoutes, resourcePolicies)
	}

	peersToConnectIncludingRouters := a.addNetworksRoutingPeers(networkResourcesRoutes, peer, peersToConnect, expiredPeers, isRouter, sourcePeers)

	dnsManagementStatus := a.getPeerDNSManagementStatus(peerID)
	dnsUpdate := nbdns.Config{
		ServiceEnable: dnsManagementStatus,
	}

	if dnsManagementStatus {
		// named customZones so the imported zones package is not shadowed
		var customZones []nbdns.CustomZone

		if peersCustomZone.Domain != "" {
			records := filterZoneRecordsForPeers(peer, peersCustomZone, peersToConnectIncludingRouters, expiredPeers)
			customZones = append(customZones, nbdns.CustomZone{
				Domain:  peersCustomZone.Domain,
				Records: records,
			})
		}

		filteredAccountZones := filterPeerAppliedZones(ctx, accountZones, peerGroups)
		customZones = append(customZones, filteredAccountZones...)

		dnsUpdate.CustomZones = customZones
		dnsUpdate.NameServerGroups = getPeerNSGroups(a, peerID)
	}

	nm := &NetworkMap{
		Peers:               peersToConnectIncludingRouters,
		Network:             a.Network.Copy(),
		Routes:              slices.Concat(networkResourcesRoutes, routesUpdate),
		DNSConfig:           dnsUpdate,
		OfflinePeers:        expiredPeers,
		FirewallRules:       firewallRules,
		RoutesFirewallRules: slices.Concat(networkResourcesFirewallRules, routesFirewallRules),
		AuthorizedUsers:     authorizedUsers,
		EnableSSH:           enableSSH,
	}

	if metrics != nil {
		objectCount := int64(len(peersToConnectIncludingRouters) + len(expiredPeers) + len(routesUpdate) + len(networkResourcesRoutes) + len(firewallRules) + len(networkResourcesFirewallRules) + len(routesFirewallRules))
		metrics.CountNetworkMapObjects(objectCount)
		metrics.CountGetPeerNetworkMapDuration(time.Since(start))

		if objectCount > 5000 {
			log.WithContext(ctx).Tracef("account: %s has a total resource count of %d objects, "+
				"peers to connect: %d, expired peers: %d, routes: %d, firewall rules: %d, network resources routes: %d, network resources firewall rules: %d, routes firewall rules: %d",
				a.Id, objectCount, len(peersToConnectIncludingRouters), len(expiredPeers), len(routesUpdate), len(firewallRules), len(networkResourcesRoutes), len(networkResourcesFirewallRules), len(routesFirewallRules))
		}
	}

	return nm
}
func (a *Account) addNetworksRoutingPeers(
networkResourcesRoutes []*route.Route,
peer *nbpeer.Peer,
@@ -210,6 +421,39 @@ func (a *Account) addNetworksRoutingPeers(
return peersToConnect
}
// getPeerNSGroups returns copies of the enabled nameserver groups that apply to
// the given peer: the nameserver group must be distributed to at least one group
// the peer belongs to, and the peer must not itself be one of its nameservers.
// The result is nil when no nameserver group applies.
func getPeerNSGroups(account *Account, peerID string) []*nbdns.NameServerGroup {
	groupList := account.GetPeerGroups(peerID)
	// looked up once here instead of on every matching group inside the loops
	peer := account.GetPeer(peerID)

	var peerNSGroups []*nbdns.NameServerGroup
	for _, nsGroup := range account.NameServerGroups {
		if !nsGroup.Enabled {
			continue
		}

		distributedToPeer := false
		for _, gID := range nsGroup.Groups {
			if _, found := groupList[gID]; found {
				distributedToPeer = true
				break
			}
		}

		// The nameserver check does not depend on which group matched, so it is
		// evaluated at most once per nameserver group.
		if distributedToPeer && !peerIsNameserver(peer, nsGroup) {
			peerNSGroups = append(peerNSGroups, nsGroup.Copy())
		}
	}

	return peerNSGroups
}
// peerIsNameserver reports whether the peer's IP equals the IP of any nameserver
// listed in nsGroup.
func peerIsNameserver(peer *nbpeer.Peer, nsGroup *nbdns.NameServerGroup) bool {
	for i := range nsGroup.NameServers {
		if peer.IP.Equal(nsGroup.NameServers[i].IP.AsSlice()) {
			return true
		}
	}
	return false
}
func AddPeerLabelsToAccount(ctx context.Context, account *Account, peerLabels LookupMap) {
for _, peer := range account.Peers {
label, err := GetPeerHostLabel(peer.Name, peerLabels)
@@ -556,6 +800,19 @@ func (a *Account) GetPeerGroupsList(peerID string) []string {
return grps
}
// getPeerDNSManagementStatus reports whether DNS management is enabled for the
// peer. It is disabled as soon as the peer belongs to one of the groups listed in
// DNSSettings.DisabledManagementGroups.
func (a *Account) getPeerDNSManagementStatus(peerID string) bool {
	memberOf := a.GetPeerGroups(peerID)
	for _, disabledGroupID := range a.DNSSettings.DisabledManagementGroups {
		if _, isMember := memberOf[disabledGroupID]; isMember {
			return false
		}
	}
	return true
}
func (a *Account) GetPeerGroups(peerID string) LookupMap {
groupList := make(LookupMap)
for groupID, group := range a.Groups {
@@ -684,6 +941,8 @@ func (a *Account) Copy() *Account {
NetworkResources: networkResources,
Services: services,
Onboarding: a.Onboarding,
NetworkMapCache: a.NetworkMapCache,
nmapInitOnce: a.nmapInitOnce,
Domains: domains,
}
}
@@ -1045,6 +1304,31 @@ func (a *Account) GetPostureChecks(postureChecksID string) *posture.Checks {
return nil
}
// GetPeerRoutesFirewallRules gets the routes firewall rules associated with a routing peer ID for the account.
// Enabled routes without access control groups get the default permit rules; for
// the others, rules are derived from the policies of each access control group.
// The returned slice is never nil.
func (a *Account) GetPeerRoutesFirewallRules(ctx context.Context, peerID string, validatedPeersMap map[string]struct{}) []*RouteFirewallRule {
	enabledRoutes, _ := a.getRoutingPeerRoutes(ctx, peerID)
	rules := make([]*RouteFirewallRule, 0, len(a.Routes))

	for _, r := range enabledRoutes {
		// If no access control groups are specified, accept all traffic.
		if len(r.AccessControlGroups) == 0 {
			rules = append(rules, getDefaultPermit(r)...)
			continue
		}

		distributionPeers := a.getDistributionGroupsPeers(r)
		for _, accessGroup := range r.AccessControlGroups {
			groupPolicies := GetAllRoutePoliciesFromGroups(a, []string{accessGroup})
			rules = append(rules, a.getRouteFirewallRules(ctx, peerID, groupPolicies, r, validatedPeersMap, distributionPeers)...)
		}
	}

	return rules
}
func (a *Account) getRouteFirewallRules(ctx context.Context, peerID string, policies []*Policy, route *route.Route, validatedPeersMap map[string]struct{}, distributionPeers map[string]struct{}) []*RouteFirewallRule {
var fwRules []*RouteFirewallRule
for _, policy := range policies {
@@ -1103,6 +1387,50 @@ func (a *Account) getRulePeers(rule *PolicyRule, postureChecks []string, peerID
return distributionGroupPeers
}
// getDistributionGroupsPeers returns the set of peer IDs that belong to any of
// the route's distribution groups. Group IDs unknown to the account are skipped.
func (a *Account) getDistributionGroupsPeers(route *route.Route) map[string]struct{} {
	peerSet := make(map[string]struct{})
	for _, groupID := range route.Groups {
		if group := a.Groups[groupID]; group != nil {
			for _, memberID := range group.Peers {
				peerSet[memberID] = struct{}{}
			}
		}
	}
	return peerSet
}
// getDefaultPermit builds the accept-all firewall rules for a route: one rule
// whose source range is "any" in the address family of the route network, plus,
// for dynamic routes, a second rule with the IPv6 "any" source.
func getDefaultPermit(route *route.Route) []*RouteFirewallRule {
	anySource := "0.0.0.0/0"
	if route.Network.Addr().Is6() {
		anySource = "::/0"
	}

	permit := RouteFirewallRule{
		SourceRanges: []string{anySource},
		Action:       string(PolicyTrafficActionAccept),
		Destination:  route.Network.String(),
		Protocol:     string(PolicyRuleProtocolALL),
		Domains:      route.Domains,
		IsDynamic:    route.IsDynamic(),
		RouteID:      route.ID,
	}
	rules := []*RouteFirewallRule{&permit}

	// dynamic routes always contain an IPv4 placeholder as destination, hence we must add IPv6 rules additionally
	if route.IsDynamic() {
		permitV6 := permit
		permitV6.SourceRanges = []string{"::/0"}
		rules = append(rules, &permitV6)
	}

	return rules
}
// GetAllRoutePoliciesFromGroups retrieves route policies associated with the specified access control groups
// and returns a list of policies that have rules with destinations matching the specified groups.
func GetAllRoutePoliciesFromGroups(account *Account, accessControlGroups []string) []*Policy {
@@ -1180,6 +1508,65 @@ func (a *Account) GetResourcePoliciesMap() map[string][]*Policy {
return resourcePolicies
}
// GetNetworkResourcesRoutesToSync returns network routes for syncing with a specific peer and its ACL peers.
//
// It walks every enabled network resource and returns:
//   - whether the peer is a routing peer for at least one of those resources,
//   - the routes the peer must receive,
//   - the set of posture-valid source peer IDs allowed to reach the resources the
//     peer routes (filled only for resources the peer is a router of).
//
// For a resource the peer routes, the peer's own route is added and the source
// peers of every policy applied to the resource are collected. For any other
// resource, the routes of all its routing peers are added once, as soon as one
// policy lists the peer as a source and the peer passes that policy's posture checks.
func (a *Account) GetNetworkResourcesRoutesToSync(ctx context.Context, peerID string, resourcePolicies map[string][]*Policy, routers map[string]map[string]*routerTypes.NetworkRouter) (bool, []*route.Route, map[string]struct{}) {
	var isRoutingPeer bool
	var routes []*route.Route
	allSourcePeers := make(map[string]struct{}, len(a.Peers))

	for _, resource := range a.NetworkResources {
		if !resource.Enabled {
			continue
		}

		// a missing network yields a nil map: lookups miss and ranging is a no-op
		networkRoutingPeers := routers[resource.NetworkID]
		ownRouter, addSourcePeers := networkRoutingPeers[peerID]
		if addSourcePeers {
			isRoutingPeer = true
			routes = append(routes, a.getNetworkResourcesRoutes(resource, peerID, ownRouter, resourcePolicies)...)
		}

		for _, policy := range resourcePolicies[resource.ID] {
			var peers []string
			// guard the index: a policy without rules must not panic here
			if len(policy.Rules) > 0 && policy.Rules[0].SourceResource.Type == ResourceTypePeer && policy.Rules[0].SourceResource.ID != "" {
				peers = []string{policy.Rules[0].SourceResource.ID}
			} else {
				peers = a.getUniquePeerIDsFromGroupsIDs(ctx, policy.SourceGroups())
			}

			if addSourcePeers {
				for _, pID := range a.getPostureValidPeers(peers, policy.SourcePostureChecks) {
					allSourcePeers[pID] = struct{}{}
				}
				continue
			}

			if slices.Contains(peers, peerID) && a.validatePostureChecksOnPeer(ctx, policy.SourcePostureChecks, peerID) {
				// add routes for the resource if the peer is in the distribution group
				for routingPeerID, resourceRouter := range networkRoutingPeers {
					routes = append(routes, a.getNetworkResourcesRoutes(resource, routingPeerID, resourceRouter, resourcePolicies)...)
				}
				// one matching policy is enough; the resource routes are added only once
				break
			}
		}
	}

	return isRoutingPeer, routes, allSourcePeers
}
// getPostureValidPeers returns, in input order, the peer IDs from inputPeers that
// pass the given posture checks. The checks run with a background context.
// The result is nil when no peer passes.
func (a *Account) getPostureValidPeers(inputPeers []string, postureChecksIDs []string) []string {
	var compliant []string
	for _, candidateID := range inputPeers {
		if !a.validatePostureChecksOnPeer(context.Background(), postureChecksIDs, candidateID) {
			continue
		}
		compliant = append(compliant, candidateID)
	}
	return compliant
}
func (a *Account) getUniquePeerIDsFromGroupsIDs(ctx context.Context, groups []string) []string {
peerIDs := make(map[string]struct{}, len(groups)) // we expect at least one peer per group as initial capacity
for _, groupID := range groups {
@@ -1271,6 +1658,22 @@ func (a *Account) GetPoliciesAppliedInNetwork(networkID string) []string {
return result
}
// getNetworkResourcesRoutes converts a network resource into the route served by
// the given routing peer. It returns nil when no policy is applied to the
// resource or when the routing peer does not exist in the account.
func (a *Account) getNetworkResourcesRoutes(resource *resourceTypes.NetworkResource, peerId string, router *routerTypes.NetworkRouter, resourcePolicies map[string][]*Policy) []*route.Route {
	// distribute the resource routes only if there is policy applied to it
	if len(resourcePolicies[resource.ID]) == 0 {
		return nil
	}

	routingPeer := a.GetPeer(peerId)
	if routingPeer == nil {
		return nil
	}

	return []*route.Route{resource.ToRoute(routingPeer, router)}
}
func (a *Account) GetResourceRoutersMap() map[string]map[string]*routerTypes.NetworkRouter {
routers := make(map[string]map[string]*routerTypes.NetworkRouter)

View File

@@ -4,6 +4,8 @@ import (
"context"
"fmt"
"net"
"net/netip"
"slices"
"testing"
"github.com/miekg/dns"
@@ -17,6 +19,7 @@ import (
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/posture"
"github.com/netbirdio/netbird/route"
)
@@ -448,6 +451,402 @@ func Test_AddNetworksRoutingPeersHandlesNoMissingPeers(t *testing.T) {
require.Len(t, result, 0)
}
// Identifiers shared by the network-resource fixture built in
// getBasicAccountsWithResource and by the tests that use it.
const (
// account, network and group IDs
accID = "accountID"
network1ID = "network1ID"
group1ID = "group1"
// peer IDs: two source peers and one routing peer
accNetResourcePeer1ID = "peer1"
accNetResourcePeer2ID = "peer2"
accNetResourceRouter1ID = "router1"
accNetResource1ID = "resource1ID"
// posture check IDs; the fixture gives them NetBird minimum versions of
// 0.35.0 (restrict), 0.0.1 (relaxed) and 7.7.7 (locked), and a Linux minimum
// kernel version of 0.0.0 (linux)
accNetResourceRestrictPostureCheckID = "restrictPostureCheck"
accNetResourceRelaxedPostureCheckID = "relaxedPostureCheck"
accNetResourceLockedPostureCheckID = "lockedPostureCheck"
accNetResourceLinuxPostureCheckID = "linuxPostureCheck"
)
// IP addresses of the fixture peers, and the validated-peers set passed to
// GetPeerNetworkResourceFirewallRules by the tests.
var (
accNetResourcePeer1IP = net.IP{192, 168, 1, 1}
accNetResourcePeer2IP = net.IP{192, 168, 1, 2}
accNetResourceRouter1IP = net.IP{192, 168, 1, 3}
// contains peer1 and peer2 only; router1 is not listed
accNetResourceValidPeers = map[string]struct{}{accNetResourcePeer1ID: {}, accNetResourcePeer2ID: {}}
)
// getBasicAccountsWithResource builds a fresh Account fixture for the network
// resource tests: three peers (peer1, peer2, router1), one group holding peer1
// and peer2, one network routed by router1, one enabled subnet resource
// (10.10.10.0/24), one TCP/80 accept policy from the group to the resource with
// no posture checks attached, and four posture checks that tests can attach.
func getBasicAccountsWithResource() *Account {
return &Account{
Id: accID,
Peers: map[string]*nbpeer.Peer{
// peer1: linux, NetBird 0.35.1
accNetResourcePeer1ID: {
ID: accNetResourcePeer1ID,
AccountID: accID,
Key: "peer1Key",
IP: accNetResourcePeer1IP,
Meta: nbpeer.PeerSystemMeta{
GoOS: "linux",
WtVersion: "0.35.1",
KernelVersion: "4.4.0",
},
},
// peer2: windows, NetBird 0.34.1 (below the 0.35.0 minimum of the restrict check)
accNetResourcePeer2ID: {
ID: accNetResourcePeer2ID,
AccountID: accID,
Key: "peer2Key",
IP: accNetResourcePeer2IP,
Meta: nbpeer.PeerSystemMeta{
GoOS: "windows",
WtVersion: "0.34.1",
KernelVersion: "4.4.0",
},
},
// router1: the routing peer of network1
accNetResourceRouter1ID: {
ID: accNetResourceRouter1ID,
AccountID: accID,
Key: "router1Key",
IP: accNetResourceRouter1IP,
Meta: nbpeer.PeerSystemMeta{
GoOS: "linux",
WtVersion: "0.35.1",
KernelVersion: "4.4.0",
},
},
},
// group1 is the policy source group; the router is not a member
Groups: map[string]*Group{
group1ID: {
ID: group1ID,
Peers: []string{accNetResourcePeer1ID, accNetResourcePeer2ID},
},
},
Networks: []*networkTypes.Network{
{
ID: network1ID,
AccountID: accID,
Name: "network1",
},
},
// router1 is assigned directly as a peer, not through peer groups
NetworkRouters: []*routerTypes.NetworkRouter{
{
ID: accNetResourceRouter1ID,
NetworkID: network1ID,
AccountID: accID,
Peer: accNetResourceRouter1ID,
PeerGroups: []string{},
Masquerade: false,
Metric: 100,
Enabled: true,
},
},
NetworkResources: []*resourceTypes.NetworkResource{
{
ID: accNetResource1ID,
AccountID: accID,
NetworkID: network1ID,
Address: "10.10.10.0/24",
Prefix: netip.MustParsePrefix("10.10.10.0/24"),
Type: resourceTypes.NetworkResourceType("subnet"),
Enabled: true,
},
},
// single policy: group1 -> resource1, TCP port 80, accept
Policies: []*Policy{
{
ID: "policy1ID",
AccountID: accID,
Enabled: true,
Rules: []*PolicyRule{
{
ID: "rule1ID",
Enabled: true,
Sources: []string{group1ID},
DestinationResource: Resource{
ID: accNetResource1ID,
Type: "Host",
},
Protocol: PolicyRuleProtocolTCP,
Ports: []string{"80"},
Action: PolicyTrafficActionAccept,
},
},
// tests set this field to exercise the posture checks below
SourcePostureChecks: nil,
},
},
PostureChecks: []*posture.Checks{
// restrict: NetBird version >= 0.35.0
{
ID: accNetResourceRestrictPostureCheckID,
Name: accNetResourceRestrictPostureCheckID,
Checks: posture.ChecksDefinition{
NBVersionCheck: &posture.NBVersionCheck{
MinVersion: "0.35.0",
},
},
},
// relaxed: NetBird version >= 0.0.1
{
ID: accNetResourceRelaxedPostureCheckID,
Name: accNetResourceRelaxedPostureCheckID,
Checks: posture.ChecksDefinition{
NBVersionCheck: &posture.NBVersionCheck{
MinVersion: "0.0.1",
},
},
},
// locked: NetBird version >= 7.7.7, higher than any fixture peer's version
{
ID: accNetResourceLockedPostureCheckID,
Name: accNetResourceLockedPostureCheckID,
Checks: posture.ChecksDefinition{
NBVersionCheck: &posture.NBVersionCheck{
MinVersion: "7.7.7",
},
},
},
// linux: OS version check with a Linux minimum kernel version of 0.0.0
{
ID: accNetResourceLinuxPostureCheckID,
Name: accNetResourceLinuxPostureCheckID,
Checks: posture.ChecksDefinition{
OSVersionCheck: &posture.OSVersionCheck{
Linux: &posture.MinKernelVersionCheck{
MinKernelVersion: "0.0.0"},
},
},
},
},
}
}
// Test_NetworksNetMapGenWithNoPostureChecks verifies that, with no posture checks
// on the policy, both group peers receive the resource route and the router
// receives both of them as source peers plus a single TCP/80 firewall rule.
func Test_NetworksNetMapGenWithNoPostureChecks(t *testing.T) {
	account := getBasicAccountsWithResource()

	// all peers should match the policy

	// validate for peer1
	isRouter, networkResourcesRoutes, sourcePeers := account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourcePeer1ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap())
	assert.False(t, isRouter, "expected router status")
	assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match")
	assert.Len(t, sourcePeers, 0, "expected source peers don't match")

	// validate for peer2
	isRouter, networkResourcesRoutes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourcePeer2ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap())
	assert.False(t, isRouter, "expected router status")
	assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match")
	assert.Len(t, sourcePeers, 0, "expected source peers don't match")

	// validate routes for router1
	isRouter, networkResourcesRoutes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourceRouter1ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap())
	assert.True(t, isRouter, "should be router")
	assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match")
	assert.Len(t, sourcePeers, 2, "expected source peers don't match")
	// sourcePeers is a map[string]struct{}: indexing a missing key yields struct{}{},
	// which is never nil, so membership has to be asserted on the keys.
	assert.Contains(t, sourcePeers, accNetResourcePeer1ID, "expected source peers don't match")
	assert.Contains(t, sourcePeers, accNetResourcePeer2ID, "expected source peers don't match")

	// validate rules for router1
	rules := account.GetPeerNetworkResourceFirewallRules(context.Background(), account.Peers[accNetResourceRouter1ID], accNetResourceValidPeers, networkResourcesRoutes, account.GetResourcePoliciesMap())
	// require stops the test here instead of panicking on rules[0] below
	require.Len(t, rules, 1, "expected rules count don't match")
	assert.Equal(t, uint16(80), rules[0].Port, "should have port 80")
	assert.Equal(t, "tcp", rules[0].Protocol, "should have protocol tcp")
	if !slices.Contains(rules[0].SourceRanges, accNetResourcePeer1IP.String()+"/32") {
		t.Errorf("%s should have source range of peer1 %s", rules[0].SourceRanges, accNetResourcePeer1IP.String())
	}
	if !slices.Contains(rules[0].SourceRanges, accNetResourcePeer2IP.String()+"/32") {
		t.Errorf("%s should have source range of peer2 %s", rules[0].SourceRanges, accNetResourcePeer2IP.String())
	}
}
// Test_NetworksNetMapGenWithPostureChecks attaches the restrict posture check
// (minimum NetBird version 0.35.0) to the policy and verifies that only peer1
// gets the resource route and appears as a source on the router.
func Test_NetworksNetMapGenWithPostureChecks(t *testing.T) {
	account := getBasicAccountsWithResource()

	// should allow peer1 to match the policy
	policy := account.Policies[0]
	policy.SourcePostureChecks = []string{accNetResourceRestrictPostureCheckID}

	// validate for peer1
	isRouter, networkResourcesRoutes, sourcePeers := account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourcePeer1ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap())
	assert.False(t, isRouter, "expected router status")
	assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match")
	assert.Len(t, sourcePeers, 0, "expected source peers don't match")

	// validate for peer2
	isRouter, networkResourcesRoutes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourcePeer2ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap())
	assert.False(t, isRouter, "expected router status")
	assert.Len(t, networkResourcesRoutes, 0, "expected network resource route don't match")
	assert.Len(t, sourcePeers, 0, "expected source peers don't match")

	// validate routes for router1
	isRouter, networkResourcesRoutes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourceRouter1ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap())
	assert.True(t, isRouter, "should be router")
	assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match")
	assert.Len(t, sourcePeers, 1, "expected source peers don't match")
	// sourcePeers is a map[string]struct{}: indexing a missing key yields struct{}{},
	// which is never nil, so membership has to be asserted on the keys.
	assert.Contains(t, sourcePeers, accNetResourcePeer1ID, "expected source peers don't match")

	// validate rules for router1
	rules := account.GetPeerNetworkResourceFirewallRules(context.Background(), account.Peers[accNetResourceRouter1ID], accNetResourceValidPeers, networkResourcesRoutes, account.GetResourcePoliciesMap())
	// require stops the test here instead of panicking on rules[0] below
	require.Len(t, rules, 1, "expected rules count don't match")
	assert.Equal(t, uint16(80), rules[0].Port, "should have port 80")
	assert.Equal(t, "tcp", rules[0].Protocol, "should have protocol tcp")
	if !slices.Contains(rules[0].SourceRanges, accNetResourcePeer1IP.String()+"/32") {
		t.Errorf("%s should have source range of peer1 %s", rules[0].SourceRanges, accNetResourcePeer1IP.String())
	}
	if slices.Contains(rules[0].SourceRanges, accNetResourcePeer2IP.String()+"/32") {
		t.Errorf("%s should not have source range of peer2 %s", rules[0].SourceRanges, accNetResourcePeer2IP.String())
	}
}
// Test_NetworksNetMapGenWithNoMatchedPostureChecks attaches the locked posture
// check (minimum NetBird version 7.7.7) to the policy and verifies that no group
// peer gets the resource route, while the router still gets its own route but
// neither source peers nor firewall rules.
func Test_NetworksNetMapGenWithNoMatchedPostureChecks(t *testing.T) {
	ctx := context.Background()
	account := getBasicAccountsWithResource()

	// should not match any peer
	account.Policies[0].SourcePostureChecks = []string{accNetResourceLockedPostureCheckID}

	// validate for peer1, then for peer2: same expectations for both
	for _, peerID := range []string{accNetResourcePeer1ID, accNetResourcePeer2ID} {
		isRouter, networkResourcesRoutes, sourcePeers := account.GetNetworkResourcesRoutesToSync(ctx, peerID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap())
		assert.False(t, isRouter, "expected router status")
		assert.Len(t, networkResourcesRoutes, 0, "expected network resource route don't match")
		assert.Len(t, sourcePeers, 0, "expected source peers don't match")
	}

	// validate routes for router1
	isRouter, routerRoutes, routerSourcePeers := account.GetNetworkResourcesRoutesToSync(ctx, accNetResourceRouter1ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap())
	assert.True(t, isRouter, "should be router")
	assert.Len(t, routerRoutes, 1, "expected network resource route don't match")
	assert.Len(t, routerSourcePeers, 0, "expected source peers don't match")

	// validate rules for router1
	routerRules := account.GetPeerNetworkResourceFirewallRules(ctx, account.Peers[accNetResourceRouter1ID], accNetResourceValidPeers, routerRoutes, account.GetResourcePoliciesMap())
	assert.Len(t, routerRules, 0, "expected rules count don't match")
}
// Test_NetworksNetMapGenWithTwoPoliciesAndPostureChecks applies two policies to
// the same resource: the original TCP/80 policy with the restrict posture check
// (peer1 only) and a TCP/22 policy with the relaxed posture check (peer1 and
// peer2). Both peers must get the resource route, and the router must get one
// firewall rule per policy with the matching source ranges.
func Test_NetworksNetMapGenWithTwoPoliciesAndPostureChecks(t *testing.T) {
	account := getBasicAccountsWithResource()

	// should allow peer1 to match the policy
	policy := account.Policies[0]
	policy.SourcePostureChecks = []string{accNetResourceRestrictPostureCheckID}

	// should allow peer1 and peer2 to match the policy
	newPolicy := &Policy{
		ID:        "policy2ID",
		AccountID: accID,
		Enabled:   true,
		Rules: []*PolicyRule{
			{
				ID:      "policy2ID",
				Enabled: true,
				Sources: []string{group1ID},
				DestinationResource: Resource{
					ID:   accNetResource1ID,
					Type: "Host",
				},
				Protocol: PolicyRuleProtocolTCP,
				Ports:    []string{"22"},
				Action:   PolicyTrafficActionAccept,
			},
		},
		SourcePostureChecks: []string{accNetResourceRelaxedPostureCheckID},
	}

	account.Policies = append(account.Policies, newPolicy)

	// validate for peer1
	isRouter, networkResourcesRoutes, sourcePeers := account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourcePeer1ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap())
	assert.False(t, isRouter, "expected router status")
	assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match")
	assert.Len(t, sourcePeers, 0, "expected source peers don't match")

	// validate for peer2
	isRouter, networkResourcesRoutes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourcePeer2ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap())
	assert.False(t, isRouter, "expected router status")
	assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match")
	assert.Len(t, sourcePeers, 0, "expected source peers don't match")

	// validate routes for router1
	isRouter, networkResourcesRoutes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourceRouter1ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap())
	assert.True(t, isRouter, "should be router")
	assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match")
	assert.Len(t, sourcePeers, 2, "expected source peers don't match")
	// sourcePeers is a map[string]struct{}: indexing a missing key yields struct{}{},
	// which is never nil, so membership has to be asserted on the keys.
	assert.Contains(t, sourcePeers, accNetResourcePeer1ID, "expected source peers don't match")
	assert.Contains(t, sourcePeers, accNetResourcePeer2ID, "expected source peers don't match")

	// validate rules for router1
	rules := account.GetPeerNetworkResourceFirewallRules(context.Background(), account.Peers[accNetResourceRouter1ID], accNetResourceValidPeers, networkResourcesRoutes, account.GetResourcePoliciesMap())
	// require stops the test here instead of panicking on rules[0] / rules[1] below
	require.Len(t, rules, 2, "expected rules count don't match")
	assert.Equal(t, uint16(80), rules[0].Port, "should have port 80")
	assert.Equal(t, "tcp", rules[0].Protocol, "should have protocol tcp")
	if !slices.Contains(rules[0].SourceRanges, accNetResourcePeer1IP.String()+"/32") {
		t.Errorf("%s should have source range of peer1 %s", rules[0].SourceRanges, accNetResourcePeer1IP.String())
	}
	if slices.Contains(rules[0].SourceRanges, accNetResourcePeer2IP.String()+"/32") {
		t.Errorf("%s should not have source range of peer2 %s", rules[0].SourceRanges, accNetResourcePeer2IP.String())
	}

	assert.Equal(t, uint16(22), rules[1].Port, "should have port 22")
	assert.Equal(t, "tcp", rules[1].Protocol, "should have protocol tcp")
	if !slices.Contains(rules[1].SourceRanges, accNetResourcePeer1IP.String()+"/32") {
		t.Errorf("%s should have source range of peer1 %s", rules[1].SourceRanges, accNetResourcePeer1IP.String())
	}
	if !slices.Contains(rules[1].SourceRanges, accNetResourcePeer2IP.String()+"/32") {
		t.Errorf("%s should have source range of peer2 %s", rules[1].SourceRanges, accNetResourcePeer2IP.String())
	}
}
// Test_NetworksNetMapGenWithTwoPostureChecks exercises a single policy that
// carries two source posture checks (relaxed + linux). The assertions expect
// only peer1 to receive the resource route and to be listed as a source peer
// on the router, and the router's single port-80 rule to contain peer1 but not
// peer2.
func Test_NetworksNetMapGenWithTwoPostureChecks(t *testing.T) {
	account := getBasicAccountsWithResource()

	// two posture checks should match only the peers that match both checks
	policy := account.Policies[0]
	policy.SourcePostureChecks = []string{accNetResourceRelaxedPostureCheckID, accNetResourceLinuxPostureCheckID}

	// validate for peer1
	isRouter, networkResourcesRoutes, sourcePeers := account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourcePeer1ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap())
	assert.False(t, isRouter, "expected router status")
	assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match")
	assert.Len(t, sourcePeers, 0, "expected source peers don't match")

	// validate for peer2
	isRouter, networkResourcesRoutes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourcePeer2ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap())
	assert.False(t, isRouter, "expected router status")
	assert.Len(t, networkResourcesRoutes, 0, "expected network resource route don't match")
	assert.Len(t, sourcePeers, 0, "expected source peers don't match")

	// validate routes for router1
	isRouter, networkResourcesRoutes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourceRouter1ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap())
	assert.True(t, isRouter, "should be router")
	assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match")
	assert.Len(t, sourcePeers, 1, "expected source peers don't match")
	assert.NotNil(t, sourcePeers[accNetResourcePeer1ID], "expected source peers don't match")

	// validate rules for router1
	rules := account.GetPeerNetworkResourceFirewallRules(context.Background(), account.Peers[accNetResourceRouter1ID], accNetResourceValidPeers, networkResourcesRoutes, account.GetResourcePoliciesMap())
	// assert.Len only records the failure; bail out so rules[0] below cannot
	// panic with "index out of range" when no rule was generated.
	if !assert.Len(t, rules, 1, "expected rules count don't match") {
		return
	}

	assert.Equal(t, uint16(80), rules[0].Port, "should have port 80")
	assert.Equal(t, "tcp", rules[0].Protocol, "should have protocol tcp")
	if !slices.Contains(rules[0].SourceRanges, accNetResourcePeer1IP.String()+"/32") {
		t.Errorf("%s should have source range of peer1 %s", rules[0].SourceRanges, accNetResourcePeer1IP.String())
	}
	if slices.Contains(rules[0].SourceRanges, accNetResourcePeer2IP.String()+"/32") {
		t.Errorf("%s should not have source range of peer2 %s", rules[0].SourceRanges, accNetResourcePeer2IP.String())
	}
}
// Test_NetworksNetMapGenShouldExcludeOtherRouters registers a second router
// peer on network1ID and checks that router1 still reports one resource route
// and exactly two source peers.
func Test_NetworksNetMapGenShouldExcludeOtherRouters(t *testing.T) {
	const secondRouterID = "router2Id"

	account := getBasicAccountsWithResource()

	// add a second routing peer to the same network
	secondRouterPeer := &nbpeer.Peer{
		Key:       "router2Key",
		ID:        secondRouterID,
		AccountID: accID,
		IP:        net.IP{192, 168, 1, 4},
	}
	account.Peers[secondRouterID] = secondRouterPeer

	secondRouter := &routerTypes.NetworkRouter{
		ID:        secondRouterID,
		NetworkID: network1ID,
		AccountID: accID,
		Peer:      secondRouterID,
	}
	account.NetworkRouters = append(account.NetworkRouters, secondRouter)

	// validate routes for router1
	ctx := context.Background()
	policiesByResource := account.GetResourcePoliciesMap()
	routersByNetwork := account.GetResourceRoutersMap()
	isRouter, resourceRoutes, sources := account.GetNetworkResourcesRoutesToSync(ctx, accNetResourceRouter1ID, policiesByResource, routersByNetwork)

	assert.True(t, isRouter, "should be router")
	assert.Len(t, resourceRoutes, 1, "expected network resource route don't match")
	assert.Len(t, sources, 2, "expected source peers don't match")
}
func Test_ExpandPortsAndRanges_SSHRuleExpansion(t *testing.T) {
tests := []struct {
name string

View File

@@ -0,0 +1,47 @@
package types
import (
"context"
"sync"
)
// Holder is an in-memory collection of *Account values indexed by a string ID.
// Its methods serialize access to the map through mu.
type Holder struct {
// mu guards accounts: GetAccount takes the read lock, the mutating methods take the write lock.
mu sync.RWMutex
// accounts maps an account ID to the stored account.
accounts map[string]*Account
}
// NewHolder returns a Holder whose account map is allocated and empty.
func NewHolder() *Holder {
	holder := new(Holder)
	holder.accounts = map[string]*Account{}
	return holder
}
// GetAccount looks up the account stored under id while holding the read lock.
// It returns nil when nothing is stored for that id.
func (h *Holder) GetAccount(id string) *Account {
	h.mu.RLock()
	stored := h.accounts[id]
	h.mu.RUnlock()

	return stored
}
// AddAccount stores account under account.Id unless an account with an equal or
// higher network serial is already stored for that ID, in which case the call
// is a no-op.
//
// A nil account is ignored instead of panicking, and the map is allocated on
// first use so a zero-value Holder is usable.
//
// NOTE(review): account.Network (and the stored account's Network) are assumed
// to be usable with CurrentSerial — confirm whether a nil Network can reach here.
func (h *Holder) AddAccount(account *Account) {
	if account == nil {
		return
	}

	h.mu.Lock()
	defer h.mu.Unlock()

	if h.accounts == nil {
		h.accounts = make(map[string]*Account)
	}

	// keep the stored copy when it is at least as recent as the incoming one
	stored := h.accounts[account.Id]
	if stored != nil && stored.Network.CurrentSerial() >= account.Network.CurrentSerial() {
		return
	}
	h.accounts[account.Id] = account
}
// LoadOrStoreFunc returns the account stored under id, or, when none is
// stored, obtains it from accGetter, stores it and returns it.
//
// The write lock is held for the whole call, including while accGetter runs,
// so concurrent callers (for any id) wait until the getter returns.
//
// An error from accGetter is returned as-is and nothing is stored. A nil
// account returned without an error is passed back to the caller but is NOT
// cached, so a later call queries accGetter again instead of serving a stored
// nil forever; a nil entry already present in the map is likewise treated as a
// miss. The map is allocated on first use so a zero-value Holder is usable.
func (h *Holder) LoadOrStoreFunc(ctx context.Context, id string, accGetter func(context.Context, string) (*Account, error)) (*Account, error) {
	h.mu.Lock()
	defer h.mu.Unlock()

	if acc := h.accounts[id]; acc != nil {
		return acc, nil
	}

	account, err := accGetter(ctx, id)
	if err != nil {
		return nil, err
	}
	if account == nil {
		// nothing to cache; let the next call ask the getter again
		return nil, nil
	}

	if h.accounts == nil {
		h.accounts = make(map[string]*Account)
	}
	h.accounts[id] = account
	return account, nil
}

View File

@@ -39,8 +39,6 @@ const (
IdentityProviderTypeAuthentik IdentityProviderType = "authentik"
// IdentityProviderTypeKeycloak is the Keycloak identity provider
IdentityProviderTypeKeycloak IdentityProviderType = "keycloak"
// IdentityProviderTypeADFS is the Microsoft AD FS identity provider
IdentityProviderTypeADFS IdentityProviderType = "adfs"
)
// IdentityProvider represents an identity provider configuration
@@ -114,8 +112,7 @@ func (t IdentityProviderType) IsValid() bool {
switch t {
case IdentityProviderTypeOIDC, IdentityProviderTypeZitadel, IdentityProviderTypeEntra,
IdentityProviderTypeGoogle, IdentityProviderTypeOkta, IdentityProviderTypePocketID,
IdentityProviderTypeMicrosoft, IdentityProviderTypeAuthentik, IdentityProviderTypeKeycloak,
IdentityProviderTypeADFS:
IdentityProviderTypeMicrosoft, IdentityProviderTypeAuthentik, IdentityProviderTypeKeycloak:
return true
}
return false

View File

@@ -0,0 +1,67 @@
package types
import (
"context"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/internals/modules/zones"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/telemetry"
)
// initNetworkMapBuilder populates a.NetworkMapCache with a builder created from
// the account and validatedPeers. It does nothing when the cache is already
// set, and the construction itself is wrapped in a.nmapInitOnce so it runs at
// most once per account value.
//
// NOTE(review): the nil check on NetworkMapCache happens outside the Once and
// without any visible lock — confirm callers do not race on this field.
func (a *Account) initNetworkMapBuilder(validatedPeers map[string]struct{}) {
	if a.NetworkMapCache == nil {
		build := func() {
			a.NetworkMapCache = NewNetworkMapBuilder(a, validatedPeers)
		}
		a.nmapInitOnce.Do(build)
	}
}
// InitNetworkMapBuilderIfNeeded is the exported wrapper around
// initNetworkMapBuilder: it forwards validatedPeers unchanged and has no other
// effect.
func (a *Account) InitNetworkMapBuilderIfNeeded(validatedPeers map[string]struct{}) {
a.initNetworkMapBuilder(validatedPeers)
}
// GetPeerNetworkMapExp makes sure the account's network map builder exists
// (creating it from validatedPeers if necessary) and then delegates to the
// builder's GetPeerNetworkMap with the arguments passed through unchanged.
func (a *Account) GetPeerNetworkMapExp(
	ctx context.Context,
	peerID string,
	peersCustomZone nbdns.CustomZone,
	accountZones []*zones.Zone,
	validatedPeers map[string]struct{},
	metrics *telemetry.AccountManagerMetrics,
) *NetworkMap {
	a.initNetworkMapBuilder(validatedPeers)

	builder := a.NetworkMapCache
	networkMap := builder.GetPeerNetworkMap(ctx, peerID, peersCustomZone, accountZones, validatedPeers, metrics)
	return networkMap
}
// OnPeerAddedUpdNetworkMapCache forwards a newly added peer to the network map
// builder's OnPeerAddedIncremental and returns its error. When no builder has
// been created yet it does nothing and returns nil.
func (a *Account) OnPeerAddedUpdNetworkMapCache(peerId string) error {
	if builder := a.NetworkMapCache; builder != nil {
		return builder.OnPeerAddedIncremental(a, peerId)
	}
	return nil
}
// OnPeersAddedUpdNetworkMapCache hands the given peer IDs to the network map
// builder's EnqueuePeersForIncrementalAdd. It is a no-op when no builder has
// been created yet.
func (a *Account) OnPeersAddedUpdNetworkMapCache(peerIds ...string) {
	if builder := a.NetworkMapCache; builder != nil {
		builder.EnqueuePeersForIncrementalAdd(a, peerIds...)
	}
}
// OnPeerDeletedUpdNetworkMapCache forwards a peer deletion to the network map
// builder's OnPeerDeleted and returns its error. When no builder has been
// created yet it does nothing and returns nil.
func (a *Account) OnPeerDeletedUpdNetworkMapCache(peerId string) error {
	if builder := a.NetworkMapCache; builder != nil {
		return builder.OnPeerDeleted(a, peerId)
	}
	return nil
}
// UpdatePeerInNetworkMapCache passes peer to the network map builder's
// UpdatePeer. It is a no-op when no builder has been created yet.
func (a *Account) UpdatePeerInNetworkMapCache(peer *nbpeer.Peer) {
	if builder := a.NetworkMapCache; builder != nil {
		builder.UpdatePeer(peer)
	}
}
// RecalculateNetworkMapCache currently only calls initNetworkMapBuilder with
// validatedPeers, so it creates the builder when none exists and otherwise
// leaves the existing one untouched.
// NOTE(review): despite the name, nothing here rebuilds an already-initialized
// cache — confirm that this is the intended behavior.
func (a *Account) RecalculateNetworkMapCache(validatedPeers map[string]struct{}) {
a.initNetworkMapBuilder(validatedPeers)
}

View File

@@ -0,0 +1,592 @@
package types
import (
"context"
"encoding/json"
"fmt"
"net"
"net/netip"
"os"
"path/filepath"
"sort"
"testing"
"time"
"github.com/stretchr/testify/require"
nbdns "github.com/netbirdio/netbird/dns"
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/posture"
"github.com/netbirdio/netbird/route"
)
// TestNetworkMapComponents_CompareWithLegacy builds the network map for
// testingPeerID twice — once with the legacy GetPeerNetworkMap and once via
// GetPeerNetworkMapComponents + CalculateNetworkMapFromComponents — and hands
// both results to compareNetworkMaps.
func TestNetworkMapComponents_CompareWithLegacy(t *testing.T) {
	var (
		account = createTestAccount()
		ctx     = context.Background()
		zone    = nbdns.CustomZone{}
	)

	// every generated peer except the offline one counts as validated
	validated := make(map[string]struct{})
	for n := range numPeers {
		if id := fmt.Sprintf("peer-%d", n); id != offlinePeerID {
			validated[id] = struct{}{}
		}
	}

	policiesByResource := account.GetResourcePoliciesMap()
	routersByNetwork := account.GetResourceRoutersMap()
	usersByGroup := account.GetActiveGroupUsers()

	legacy := account.GetPeerNetworkMap(
		ctx, testingPeerID, zone, nil, validated,
		policiesByResource, routersByNetwork, nil, usersByGroup,
	)

	components := account.GetPeerNetworkMapComponents(
		ctx, testingPeerID, zone, nil, validated,
		policiesByResource, routersByNetwork, usersByGroup,
	)
	if components == nil {
		t.Fatal("GetPeerNetworkMapComponents returned nil")
	}

	fromComponents := CalculateNetworkMapFromComponents(ctx, components)
	if fromComponents == nil {
		t.Fatal("CalculateNetworkMapFromComponents returned nil")
	}

	compareNetworkMaps(t, legacy, fromComponents)
}
// TestNetworkMapComponents_GoldenFileComparison computes the network map for
// testingPeerID with both the legacy and the components-based code paths,
// normalizes their ordering, and requires the two JSON encodings to be equal.
//
// As a side effect the test (re)writes three files under testdata/comparison
// on every run: the legacy map, the components-derived map, and the raw
// components. They are outputs for manual inspection; the test does not read
// them back.
func TestNetworkMapComponents_GoldenFileComparison(t *testing.T) {
	account := createTestAccount()
	ctx := context.Background()
	peerID := testingPeerID

	// every generated peer except the offline one counts as validated
	validatedPeersMap := make(map[string]struct{})
	for i := range numPeers {
		pid := fmt.Sprintf("peer-%d", i)
		if pid == offlinePeerID {
			continue
		}
		validatedPeersMap[pid] = struct{}{}
	}

	peersCustomZone := nbdns.CustomZone{}
	resourcePolicies := account.GetResourcePoliciesMap()
	routers := account.GetResourceRoutersMap()
	groupIDToUserIDs := account.GetActiveGroupUsers()

	legacyNetworkMap := account.GetPeerNetworkMap(
		ctx,
		peerID,
		peersCustomZone,
		nil,
		validatedPeersMap,
		resourcePolicies,
		routers,
		nil,
		groupIDToUserIDs,
	)

	components := account.GetPeerNetworkMapComponents(
		ctx,
		peerID,
		peersCustomZone,
		nil,
		validatedPeersMap,
		resourcePolicies,
		routers,
		groupIDToUserIDs,
	)
	require.NotNil(t, components, "GetPeerNetworkMapComponents returned nil")

	newNetworkMap := CalculateNetworkMapFromComponents(ctx, components)
	require.NotNil(t, newNetworkMap, "CalculateNetworkMapFromComponents returned nil")

	// make slice ordering deterministic so the JSON comparison is order-insensitive
	normalizeAndSortNetworkMap(legacyNetworkMap)
	normalizeAndSortNetworkMap(newNetworkMap)

	componentsJSON, err := json.MarshalIndent(components, "", " ")
	require.NoError(t, err, "error marshaling components to JSON")
	legacyJSON, err := json.MarshalIndent(legacyNetworkMap, "", " ")
	require.NoError(t, err, "error marshaling legacy network map to JSON")
	newJSON, err := json.MarshalIndent(newNetworkMap, "", " ")
	require.NoError(t, err, "error marshaling new network map to JSON")

	goldenDir := filepath.Join("testdata", "comparison")
	err = os.MkdirAll(goldenDir, 0755)
	require.NoError(t, err)

	legacyGoldenPath := filepath.Join(goldenDir, "legacy_networkmap.json")
	err = os.WriteFile(legacyGoldenPath, legacyJSON, 0644)
	require.NoError(t, err, "error writing legacy golden file")

	newGoldenPath := filepath.Join(goldenDir, "components_networkmap.json")
	err = os.WriteFile(newGoldenPath, newJSON, 0644)
	require.NoError(t, err, "error writing components golden file")

	componentsPath := filepath.Join(goldenDir, "components.json")
	err = os.WriteFile(componentsPath, componentsJSON, 0644)
	// distinct message: this write used to reuse the text of the one above,
	// making the two failures indistinguishable
	require.NoError(t, err, "error writing raw components file")

	require.JSONEq(t, string(legacyJSON), string(newJSON),
		"NetworkMaps from legacy and components approaches do not match.\n"+
			"Legacy JSON saved to: %s\n"+
			"Components JSON saved to: %s",
		legacyGoldenPath, newGoldenPath)

	t.Logf("✅ NetworkMaps are identical")
	t.Logf(" Legacy NetworkMap: %s", legacyGoldenPath)
	t.Logf(" Components NetworkMap: %s", newGoldenPath)
}
// normalizeAndSortNetworkMap sorts the slices of nm in place so that two maps
// with the same content serialize identically. A nil map is left alone.
//
// Ordering applied:
//   - Peers, OfflinePeers: by ID
//   - Routes: by ID (as string)
//   - FirewallRules: by PeerIP, Direction, Protocol, Port, PolicyID
//   - RoutesFirewallRules: each SourceRanges is sorted first, then the rules
//     are ordered by Destination, SourceRanges (element-wise, shorter first on
//     a shared prefix), RouteID, PolicyID, Port, Protocol
//   - DNSConfig.CustomZones[i].Records: by Name
//   - DNSConfig.NameServerGroups: by Name
func normalizeAndSortNetworkMap(nm *NetworkMap) {
	if nm == nil {
		return
	}

	peers := nm.Peers
	sort.Slice(peers, func(a, b int) bool { return peers[a].ID < peers[b].ID })

	offline := nm.OfflinePeers
	sort.Slice(offline, func(a, b int) bool { return offline[a].ID < offline[b].ID })

	routes := nm.Routes
	sort.Slice(routes, func(a, b int) bool { return string(routes[a].ID) < string(routes[b].ID) })

	fwRules := nm.FirewallRules
	sort.Slice(fwRules, func(a, b int) bool {
		left, right := fwRules[a], fwRules[b]
		switch {
		case left.PeerIP != right.PeerIP:
			return left.PeerIP < right.PeerIP
		case left.Direction != right.Direction:
			return left.Direction < right.Direction
		case left.Protocol != right.Protocol:
			return left.Protocol < right.Protocol
		case left.Port != right.Port:
			return left.Port < right.Port
		default:
			return left.PolicyID < right.PolicyID
		}
	})

	routeRules := nm.RoutesFirewallRules
	// source ranges must be sorted before they are used as a sort key below
	for idx := range routeRules {
		sort.Strings(routeRules[idx].SourceRanges)
	}
	sort.Slice(routeRules, func(a, b int) bool {
		left, right := routeRules[a], routeRules[b]
		if left.Destination != right.Destination {
			return left.Destination < right.Destination
		}

		// element-wise comparison over the shared prefix of both range lists
		shared := min(len(left.SourceRanges), len(right.SourceRanges))
		for k := range shared {
			if l, r := left.SourceRanges[k], right.SourceRanges[k]; l != r {
				return l < r
			}
		}

		switch {
		case len(left.SourceRanges) != len(right.SourceRanges):
			return len(left.SourceRanges) < len(right.SourceRanges)
		case string(left.RouteID) != string(right.RouteID):
			return string(left.RouteID) < string(right.RouteID)
		case left.PolicyID != right.PolicyID:
			return left.PolicyID < right.PolicyID
		case left.Port != right.Port:
			return left.Port < right.Port
		default:
			return left.Protocol < right.Protocol
		}
	})

	// ranging over a nil CustomZones slice is a no-op, so no explicit nil check
	for idx := range nm.DNSConfig.CustomZones {
		records := nm.DNSConfig.CustomZones[idx].Records
		sort.Slice(records, func(a, b int) bool { return records[a].Name < records[b].Name })
	}

	if groups := nm.DNSConfig.NameServerGroups; len(groups) > 0 {
		sort.Slice(groups, func(a, b int) bool { return groups[a].Name < groups[b].Name })
	}
}
// compareNetworkMaps reports differences between a legacy and a current
// network map on t.
//
// Test errors are raised for: network serial, peer count, peers present in
// current but absent from legacy, offline peer count, and
// DNSConfig.ServiceEnable. Count differences in FirewallRules, Routes and
// RoutesFirewallRules are only logged.
func compareNetworkMaps(t *testing.T, legacy, current *NetworkMap) {
	t.Helper()

	if want, got := legacy.Network.Serial, current.Network.Serial; want != got {
		t.Errorf("Network Serial mismatch: legacy=%d, current=%d", want, got)
	}

	if want, got := len(legacy.Peers), len(current.Peers); want != got {
		t.Errorf("Peers count mismatch: legacy=%d, current=%d", want, got)
	}

	// every peer of the current map must also exist in the legacy map
	knownLegacy := make(map[string]struct{})
	for _, peer := range legacy.Peers {
		knownLegacy[peer.ID] = struct{}{}
	}
	for _, peer := range current.Peers {
		if _, found := knownLegacy[peer.ID]; found {
			continue
		}
		t.Errorf("Current NetworkMap contains peer %s not in legacy", peer.ID)
	}

	if want, got := len(legacy.OfflinePeers), len(current.OfflinePeers); want != got {
		t.Errorf("OfflinePeers count mismatch: legacy=%d, current=%d", want, got)
	}

	// the following count differences are informational only
	if want, got := len(legacy.FirewallRules), len(current.FirewallRules); want != got {
		t.Logf("FirewallRules count mismatch: legacy=%d, current=%d", want, got)
	}
	if want, got := len(legacy.Routes), len(current.Routes); want != got {
		t.Logf("Routes count mismatch: legacy=%d, current=%d", want, got)
	}
	if want, got := len(legacy.RoutesFirewallRules), len(current.RoutesFirewallRules); want != got {
		t.Logf("RoutesFirewallRules count mismatch: legacy=%d, current=%d", want, got)
	}

	if want, got := legacy.DNSConfig.ServiceEnable, current.DNSConfig.ServiceEnable; want != got {
		t.Errorf("DNSConfig.ServiceEnable mismatch: legacy=%v, current=%v", want, got)
	}
}
// Fixture identifiers shared by createTestAccount and the tests/benchmarks in this file.
const (
// numPeers is how many peers createTestAccount generates ("peer-0" .. "peer-<numPeers-1>").
numPeers = 100
// Group IDs: the first half of the peers goes to dev, the second half to ops, all peers to all.
devGroupID = "group-dev"
opsGroupID = "group-ops"
allGroupID = "group-all"
// Route IDs; the two HA routes share the same network prefix in createTestAccount.
routeID = route.ID("route-main")
routeHA1ID = route.ID("route-ha-1")
routeHA2ID = route.ID("route-ha-2")
// Policy IDs (each policy's single rule reuses the policy ID).
policyIDDevOps = "policy-dev-ops"
policyIDAll = "policy-all"
policyIDPosture = "policy-posture"
policyIDDrop = "policy-drop"
// postureCheckID is the posture check attached to the policyIDPosture policy.
postureCheckID = "posture-check-ver"
// Network resource, its network, and the router serving that network.
networkResourceID = "res-database"
networkID = "net-database"
networkRouterID = "router-database"
nameserverGroupID = "ns-group-main"
// testingPeerID is the peer whose network map the tests and benchmarks compute.
testingPeerID = "peer-60"
// expiredPeerID gets login expiration enabled with a LastLogin two hours in the past.
expiredPeerID = "peer-98"
// offlinePeerID is the peer left out of the validated-peers map.
offlinePeerID = "peer-99"
// routingPeerID is the peer assigned to the network router.
routingPeerID = "peer-95"
testAccountID = "account-comparison-test"
)
// createTestAccount assembles the fixture account used by the comparison tests
// and benchmarks in this file.
//
// It contains numPeers connected linux peers (even indexes run version 0.40.0,
// odd ones 0.25.0; expiredPeerID has an expired login), three groups (all,
// dev = first half, ops = second half), four policies, three routes (two of
// them sharing the 10.10.0.0/16 prefix), one nameserver group, one posture
// check, and one network resource with its network and router. Every policy
// and route gets the account ID assigned before the account is returned.
func createTestAccount() *Account {
	var (
		peers  = make(map[string]*nbpeer.Peer)
		devIDs = []string{}
		opsIDs = []string{}
		allIDs = []string{}
	)

	for idx := range numPeers {
		id := fmt.Sprintf("peer-%d", idx)

		version := "0.40.0"
		if idx%2 != 0 {
			version = "0.25.0"
		}

		peer := &nbpeer.Peer{
			ID:       id,
			IP:       net.IP{100, 64, 0, byte(idx + 1)},
			Key:      fmt.Sprintf("key-%s", id),
			DNSLabel: fmt.Sprintf("peer%d", idx+1),
			Status:   &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now()},
			UserID:   "user-admin",
			Meta:     nbpeer.PeerSystemMeta{WtVersion: version, GoOS: "linux"},
		}
		if id == expiredPeerID {
			// last login lies further back than the 1h expiration configured below
			lastLogin := time.Now().Add(-2 * time.Hour)
			peer.LoginExpirationEnabled = true
			peer.LastLogin = &lastLogin
		}
		peers[id] = peer

		allIDs = append(allIDs, id)
		if idx >= numPeers/2 {
			opsIDs = append(opsIDs, id)
		} else {
			devIDs = append(devIDs, id)
		}
	}

	groups := map[string]*Group{
		allGroupID: {ID: allGroupID, Name: "All", Peers: allIDs},
		devGroupID: {ID: devGroupID, Name: "Developers", Peers: devIDs},
		opsGroupID: {ID: opsGroupID, Name: "Operations", Peers: opsIDs},
	}

	allowAll := &Policy{
		ID: policyIDAll, Name: "Default-Allow", Enabled: true,
		Rules: []*PolicyRule{{
			ID:            policyIDAll,
			Name:          "Allow All",
			Enabled:       true,
			Action:        PolicyTrafficActionAccept,
			Protocol:      PolicyRuleProtocolALL,
			Bidirectional: true,
			Sources:       []string{allGroupID},
			Destinations:  []string{allGroupID},
		}},
	}
	devToOps := &Policy{
		ID: policyIDDevOps, Name: "Dev to Ops Web Access", Enabled: true,
		Rules: []*PolicyRule{{
			ID:            policyIDDevOps,
			Name:          "Dev -> Ops (HTTP Range)",
			Enabled:       true,
			Action:        PolicyTrafficActionAccept,
			Protocol:      PolicyRuleProtocolTCP,
			Bidirectional: false,
			PortRanges:    []RulePortRange{{Start: 8080, End: 8090}},
			Sources:       []string{devGroupID},
			Destinations:  []string{opsGroupID},
		}},
	}
	dropDB := &Policy{
		ID: policyIDDrop, Name: "Drop DB traffic", Enabled: true,
		Rules: []*PolicyRule{{
			ID:            policyIDDrop,
			Name:          "Drop DB",
			Enabled:       true,
			Action:        PolicyTrafficActionDrop,
			Protocol:      PolicyRuleProtocolTCP,
			Ports:         []string{"5432"},
			Bidirectional: true,
			Sources:       []string{devGroupID},
			Destinations:  []string{opsGroupID},
		}},
	}
	postureGuarded := &Policy{
		ID: policyIDPosture, Name: "Posture Check for DB Resource", Enabled: true,
		SourcePostureChecks: []string{postureCheckID},
		Rules: []*PolicyRule{{
			ID:                  policyIDPosture,
			Name:                "Allow DB Access",
			Enabled:             true,
			Action:              PolicyTrafficActionAccept,
			Protocol:            PolicyRuleProtocolALL,
			Bidirectional:       true,
			Sources:             []string{opsGroupID},
			DestinationResource: Resource{ID: networkResourceID},
		}},
	}
	policies := []*Policy{allowAll, devToOps, dropDB, postureGuarded}

	routes := map[route.ID]*route.Route{
		routeID: {
			ID:                  routeID,
			Network:             netip.MustParsePrefix("192.168.10.0/24"),
			Peer:                peers["peer-75"].Key,
			PeerID:              "peer-75",
			Description:         "Route to internal resource",
			Enabled:             true,
			PeerGroups:          []string{devGroupID, opsGroupID},
			Groups:              []string{devGroupID, opsGroupID},
			AccessControlGroups: []string{devGroupID},
		},
		routeHA1ID: {
			ID:                  routeHA1ID,
			Network:             netip.MustParsePrefix("10.10.0.0/16"),
			Peer:                peers["peer-80"].Key,
			PeerID:              "peer-80",
			Description:         "HA Route 1",
			Enabled:             true,
			Metric:              1000,
			PeerGroups:          []string{allGroupID},
			Groups:              []string{allGroupID},
			AccessControlGroups: []string{allGroupID},
		},
		routeHA2ID: {
			ID:                  routeHA2ID,
			Network:             netip.MustParsePrefix("10.10.0.0/16"),
			Peer:                peers["peer-90"].Key,
			PeerID:              "peer-90",
			Description:         "HA Route 2",
			Enabled:             true,
			Metric:              900,
			PeerGroups:          []string{devGroupID, opsGroupID},
			Groups:              []string{devGroupID, opsGroupID},
			AccessControlGroups: []string{allGroupID},
		},
	}

	account := &Account{
		Id:       testAccountID,
		Peers:    peers,
		Groups:   groups,
		Policies: policies,
		Routes:   routes,
		Network: &Network{
			Identifier: "net-comparison-test",
			Net:        net.IPNet{IP: net.IP{100, 64, 0, 0}, Mask: net.CIDRMask(16, 32)},
			Serial:     1,
		},
		DNSSettings: DNSSettings{DisabledManagementGroups: []string{opsGroupID}},
		NameServerGroups: map[string]*nbdns.NameServerGroup{
			nameserverGroupID: {
				ID:      nameserverGroupID,
				Name:    "Main NS",
				Enabled: true,
				Groups:  []string{devGroupID},
				NameServers: []nbdns.NameServer{
					{IP: netip.MustParseAddr("8.8.8.8"), NSType: nbdns.UDPNameServerType, Port: 53},
				},
			},
		},
		PostureChecks: []*posture.Checks{
			{
				ID:   postureCheckID,
				Name: "Check version",
				Checks: posture.ChecksDefinition{
					NBVersionCheck: &posture.NBVersionCheck{MinVersion: "0.26.0"},
				},
			},
		},
		NetworkResources: []*resourceTypes.NetworkResource{
			{ID: networkResourceID, NetworkID: networkID, AccountID: testAccountID, Enabled: true, Address: "db.netbird.cloud"},
		},
		Networks: []*networkTypes.Network{
			{ID: networkID, Name: "DB Network", AccountID: testAccountID},
		},
		NetworkRouters: []*routerTypes.NetworkRouter{
			{ID: networkRouterID, NetworkID: networkID, Peer: routingPeerID, Enabled: true, AccountID: testAccountID},
		},
		Settings: &Settings{PeerLoginExpirationEnabled: true, PeerLoginExpiration: 1 * time.Hour},
	}

	// stamp the owning account onto policies and routes
	for _, policy := range account.Policies {
		policy.AccountID = account.Id
	}
	for _, rt := range account.Routes {
		rt.AccountID = account.Id
	}

	return account
}
// BenchmarkLegacyNetworkMap times account.GetPeerNetworkMap for testingPeerID
// on the fixture account. Fixture construction and lookup-map preparation are
// excluded from the measurement via b.ResetTimer.
func BenchmarkLegacyNetworkMap(b *testing.B) {
	var (
		account = createTestAccount()
		ctx     = context.Background()
		zone    = nbdns.CustomZone{}
	)

	// every generated peer except the offline one counts as validated
	validated := make(map[string]struct{})
	for n := range numPeers {
		id := fmt.Sprintf("peer-%d", n)
		if id == offlinePeerID {
			continue
		}
		validated[id] = struct{}{}
	}

	policiesByResource := account.GetResourcePoliciesMap()
	routersByNetwork := account.GetResourceRoutersMap()
	usersByGroup := account.GetActiveGroupUsers()

	b.ResetTimer()
	for range b.N {
		_ = account.GetPeerNetworkMap(
			ctx, testingPeerID, zone, nil, validated,
			policiesByResource, routersByNetwork, nil, usersByGroup,
		)
	}
}
// BenchmarkComponentsNetworkMap times the full components-based path for
// testingPeerID: GetPeerNetworkMapComponents followed by
// CalculateNetworkMapFromComponents. Fixture setup is excluded from the
// measurement via b.ResetTimer.
func BenchmarkComponentsNetworkMap(b *testing.B) {
	var (
		account = createTestAccount()
		ctx     = context.Background()
		zone    = nbdns.CustomZone{}
	)

	// every generated peer except the offline one counts as validated
	validated := make(map[string]struct{})
	for n := range numPeers {
		id := fmt.Sprintf("peer-%d", n)
		if id == offlinePeerID {
			continue
		}
		validated[id] = struct{}{}
	}

	policiesByResource := account.GetResourcePoliciesMap()
	routersByNetwork := account.GetResourceRoutersMap()
	usersByGroup := account.GetActiveGroupUsers()

	b.ResetTimer()
	for range b.N {
		components := account.GetPeerNetworkMapComponents(
			ctx, testingPeerID, zone, nil, validated,
			policiesByResource, routersByNetwork, usersByGroup,
		)
		_ = CalculateNetworkMapFromComponents(ctx, components)
	}
}
// BenchmarkComponentsCreation times only GetPeerNetworkMapComponents for
// testingPeerID, without computing the network map from the result. Fixture
// setup is excluded from the measurement via b.ResetTimer.
func BenchmarkComponentsCreation(b *testing.B) {
	var (
		account = createTestAccount()
		ctx     = context.Background()
		zone    = nbdns.CustomZone{}
	)

	// every generated peer except the offline one counts as validated
	validated := make(map[string]struct{})
	for n := range numPeers {
		id := fmt.Sprintf("peer-%d", n)
		if id == offlinePeerID {
			continue
		}
		validated[id] = struct{}{}
	}

	policiesByResource := account.GetResourcePoliciesMap()
	routersByNetwork := account.GetResourceRoutersMap()
	usersByGroup := account.GetActiveGroupUsers()

	b.ResetTimer()
	for range b.N {
		_ = account.GetPeerNetworkMapComponents(
			ctx, testingPeerID, zone, nil, validated,
			policiesByResource, routersByNetwork, usersByGroup,
		)
	}
}
// BenchmarkCalculationFromComponents times only
// CalculateNetworkMapFromComponents. The components for testingPeerID are
// built once up front and reused by every iteration; that preparation is
// excluded from the measurement via b.ResetTimer.
func BenchmarkCalculationFromComponents(b *testing.B) {
	var (
		account = createTestAccount()
		ctx     = context.Background()
		zone    = nbdns.CustomZone{}
	)

	// every generated peer except the offline one counts as validated
	validated := make(map[string]struct{})
	for n := range numPeers {
		id := fmt.Sprintf("peer-%d", n)
		if id == offlinePeerID {
			continue
		}
		validated[id] = struct{}{}
	}

	policiesByResource := account.GetResourcePoliciesMap()
	routersByNetwork := account.GetResourceRoutersMap()
	usersByGroup := account.GetActiveGroupUsers()

	components := account.GetPeerNetworkMapComponents(
		ctx, testingPeerID, zone, nil, validated,
		policiesByResource, routersByNetwork, usersByGroup,
	)

	b.ResetTimer()
	for range b.N {
		_ = CalculateNetworkMapFromComponents(ctx, components)
	}
}

View File

@@ -19,6 +19,8 @@ import (
"github.com/netbirdio/netbird/shared/management/domain"
)
// EnvNewNetworkMapCompacted holds the string "NB_NETWORK_MAP_COMPACTED".
// NOTE(review): presumably the name of an environment variable that toggles a
// compacted network map — confirm against the code that reads it.
const EnvNewNetworkMapCompacted = "NB_NETWORK_MAP_COMPACTED"
type NetworkMapComponents struct {
PeerID string

View File

@@ -1,787 +0,0 @@
package types_test
import (
"context"
"net"
"net/netip"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
nbdns "github.com/netbirdio/netbird/dns"
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/posture"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/route"
)
// networkMapFromComponents is a test helper that calls
// account.GetPeerNetworkMapFromComponents for peerID using a background
// context, the "netbird.io" peers custom zone, and the account's own resource
// policy, resource router and active group-user maps.
func networkMapFromComponents(t *testing.T, account *types.Account, peerID string, validatedPeers map[string]struct{}) *types.NetworkMap {
	t.Helper()

	ctx := context.Background()
	zone := account.GetPeersCustomZone(context.Background(), "netbird.io")
	policiesByResource := account.GetResourcePoliciesMap()
	routersByNetwork := account.GetResourceRoutersMap()
	usersByGroup := account.GetActiveGroupUsers()

	return account.GetPeerNetworkMapFromComponents(
		ctx, peerID, zone, nil, validatedPeers,
		policiesByResource, routersByNetwork, nil, usersByGroup,
	)
}
// allPeersValidated returns a validated-peers set containing every peer ID of
// the account except those listed in excludePeerIDs.
func allPeersValidated(account *types.Account, excludePeerIDs ...string) map[string]struct{} {
	skip := make(map[string]struct{}, len(excludePeerIDs))
	for idx := range excludePeerIDs {
		skip[excludePeerIDs[idx]] = struct{}{}
	}

	result := make(map[string]struct{}, len(account.Peers))
	for peerID := range account.Peers {
		if _, skipped := skip[peerID]; skipped {
			continue
		}
		result[peerID] = struct{}{}
	}
	return result
}
// peerIDs returns the IDs of the given peers, in the same order.
func peerIDs(peers []*nbpeer.Peer) []string {
	out := make([]string, 0, len(peers))
	for _, peer := range peers {
		out = append(out, peer.ID)
	}
	return out
}
// TestNetworkMapComponents_RegularPeerConnectivity checks which peers
// "peer-src-1" sees in its components-based network map: the destination-group
// peer and the router peer must be present, the peer itself must not, and no
// offline peers are expected.
func TestNetworkMapComponents_RegularPeerConnectivity(t *testing.T) {
	account := createComponentTestAccount()
	validated := allPeersValidated(account)

	nm := networkMapFromComponents(t, account, "peer-src-1", validated)
	// require (not assert): the statements below dereference nm and would
	// panic on a nil map instead of reporting a clean failure
	require.NotNil(t, nm)

	assert.Contains(t, peerIDs(nm.Peers), "peer-dst-1", "should see peer from destination group via bidirectional policy")
	assert.Contains(t, peerIDs(nm.Peers), "peer-router-1", "should see router peer via resource policy")
	assert.NotContains(t, peerIDs(nm.Peers), "peer-src-1", "should not see itself")
	assert.Empty(t, nm.OfflinePeers, "no expired peers expected")
}
// TestNetworkMapComponents_IntraGroupConnectivity adds a bidirectional
// accept-all policy from "group-src" to itself and expects "peer-src-1" to see
// "peer-src-2" afterwards.
func TestNetworkMapComponents_IntraGroupConnectivity(t *testing.T) {
	account := createComponentTestAccount()

	intraGroupRule := &types.PolicyRule{
		ID:            "rule-intra-src",
		Name:          "src <-> src",
		Enabled:       true,
		Action:        types.PolicyTrafficActionAccept,
		Protocol:      types.PolicyRuleProtocolALL,
		Bidirectional: true,
		Sources:       []string{"group-src"},
		Destinations:  []string{"group-src"},
	}
	intraGroupPolicy := &types.Policy{
		ID:        "policy-intra-src",
		Name:      "Intra-source connectivity",
		Enabled:   true,
		AccountID: account.Id,
		Rules:     []*types.PolicyRule{intraGroupRule},
	}
	account.Policies = append(account.Policies, intraGroupPolicy)

	netMap := networkMapFromComponents(t, account, "peer-src-1", allPeersValidated(account))

	visible := peerIDs(netMap.Peers)
	assert.Contains(t, visible, "peer-src-2", "should see peer from same group with intra-group policy")
}
// TestNetworkMapComponents_FirewallRules expects the network map of
// "peer-src-1" to contain firewall rules, at least one of which is an accept
// rule for the ALL protocol.
func TestNetworkMapComponents_FirewallRules(t *testing.T) {
	account := createComponentTestAccount()
	netMap := networkMapFromComponents(t, account, "peer-src-1", allPeersValidated(account))

	require.NotEmpty(t, netMap.FirewallRules, "firewall rules should be generated")

	wantProtocol := string(types.PolicyRuleProtocolALL)
	wantAction := string(types.PolicyTrafficActionAccept)

	acceptAllFound := false
	for idx := range netMap.FirewallRules {
		candidate := netMap.FirewallRules[idx]
		if candidate.Protocol != wantProtocol || candidate.Action != wantAction {
			continue
		}
		acceptAllFound = true
	}
	assert.True(t, acceptAllFound, "should have an accept-all firewall rule from the base policy")
}
// TestNetworkMapComponents_LoginExpiration enables a one-hour login expiration
// on the account, gives "peer-dst-1" a last login two hours in the past, and
// expects that peer to be reported under OfflinePeers rather than Peers.
func TestNetworkMapComponents_LoginExpiration(t *testing.T) {
	const expiredPeer = "peer-dst-1"

	account := createComponentTestAccount()
	account.Settings.PeerLoginExpirationEnabled = true
	account.Settings.PeerLoginExpiration = 1 * time.Hour

	lastLogin := time.Now().Add(-2 * time.Hour)
	target := account.Peers[expiredPeer]
	target.LoginExpirationEnabled = true
	target.LastLogin = &lastLogin

	netMap := networkMapFromComponents(t, account, "peer-src-1", allPeersValidated(account))

	assert.Contains(t, peerIDs(netMap.OfflinePeers), expiredPeer, "expired peer should be in OfflinePeers")
	assert.NotContains(t, peerIDs(netMap.Peers), expiredPeer, "expired peer should NOT be in active Peers")
}
// TestNetworkMapComponents_InvalidatedPeerExcluded leaves "peer-dst-1" out of
// the validated set and expects it to appear in neither Peers nor OfflinePeers
// of the map computed for "peer-src-1".
func TestNetworkMapComponents_InvalidatedPeerExcluded(t *testing.T) {
	const unvalidatedPeer = "peer-dst-1"

	account := createComponentTestAccount()
	netMap := networkMapFromComponents(t, account, "peer-src-1", allPeersValidated(account, unvalidatedPeer))

	active := peerIDs(netMap.Peers)
	offline := peerIDs(netMap.OfflinePeers)
	assert.NotContains(t, active, unvalidatedPeer, "non-validated peer should be excluded")
	assert.NotContains(t, offline, unvalidatedPeer, "non-validated peer should not be in offline peers either")
}
// TestNetworkMapComponents_NonValidatedTargetPeer requests the map for a peer
// that is itself missing from the validated set and expects it to contain no
// peers and no firewall rules.
func TestNetworkMapComponents_NonValidatedTargetPeer(t *testing.T) {
	const target = "peer-src-1"

	account := createComponentTestAccount()
	validatedWithoutTarget := allPeersValidated(account, target)

	netMap := networkMapFromComponents(t, account, target, validatedWithoutTarget)

	assert.Empty(t, netMap.Peers, "non-validated target peer should get empty network map")
	assert.Empty(t, netMap.FirewallRules)
}
// TestNetworkMapComponents_NetworkResourceRoutes_SourcePeer expects
// "peer-src-1" to receive a route for 10.200.0.1/32 and to see the routing
// peer "peer-router-1".
func TestNetworkMapComponents_NetworkResourceRoutes_SourcePeer(t *testing.T) {
	account := createComponentTestAccount()
	netMap := networkMapFromComponents(t, account, "peer-src-1", allPeersValidated(account))

	// stop scanning at the first route that matches the resource prefix
	routeFound := false
	for idx := 0; idx < len(netMap.Routes) && !routeFound; idx++ {
		routeFound = netMap.Routes[idx].Network.String() == "10.200.0.1/32"
	}

	assert.True(t, routeFound, "source peer should receive route to network resource via router")
	assert.Contains(t, peerIDs(netMap.Peers), "peer-router-1", "source peer should see the routing peer")
}
// TestNetworkMapComponents_NetworkResourceRoutes_RouterPeer checks that
// peer-router-1 receives a route for 10.200.0.1/32 and has a non-empty set of
// route firewall rules.
func TestNetworkMapComponents_NetworkResourceRoutes_RouterPeer(t *testing.T) {
	acc := createComponentTestAccount()
	nm := networkMapFromComponents(t, acc, "peer-router-1", allPeersValidated(acc))

	found := false
	for _, rt := range nm.Routes {
		if rt.Network.String() != "10.200.0.1/32" {
			continue
		}
		found = true
		break
	}

	assert.True(t, found, "router peer should receive network resource route")
	assert.NotEmpty(t, nm.RoutesFirewallRules, "router peer should have route firewall rules for the resource")
}
// TestNetworkMapComponents_NetworkResourceRoutes_UnrelatedPeer checks that no
// route in peer-dst-1's map targets 10.200.0.1/32.
func TestNetworkMapComponents_NetworkResourceRoutes_UnrelatedPeer(t *testing.T) {
	acc := createComponentTestAccount()
	nm := networkMapFromComponents(t, acc, "peer-dst-1", allPeersValidated(acc))

	// Every route is checked individually so each offending route is reported.
	for _, rt := range nm.Routes {
		prefix := rt.Network.String()
		assert.NotEqual(t, "10.200.0.1/32", prefix, "unrelated peer should not receive network resource route")
	}
}
// TestNetworkMapComponents_NetworkResource_WithPostureCheck adds a network
// resource (10.200.1.1/32) guarded by a policy carrying a minimum-version
// posture check, then verifies that peer-src-1 receives the resource route only
// when its WtVersion satisfies the check.
//
// The subtests mutate the shared account (peer-src-1's WtVersion), so each one
// sets the version it needs before building the map.
func TestNetworkMapComponents_NetworkResource_WithPostureCheck(t *testing.T) {
account := createComponentTestAccount()
validated := allPeersValidated(account)
// Replace the account's posture checks with a single minimum-version check.
account.PostureChecks = []*posture.Checks{
{ID: "pc-version", Name: "Version check", Checks: posture.ChecksDefinition{
NBVersionCheck: &posture.NBVersionCheck{MinVersion: "0.30.0"},
}},
}
// Policy from group-src to the guarded resource, gated by the posture check.
account.Policies = append(account.Policies, &types.Policy{
ID: "policy-posture-resource", Name: "Posture resource access", Enabled: true, AccountID: account.Id,
SourcePostureChecks: []string{"pc-version"},
Rules: []*types.PolicyRule{{
ID: "rule-posture-resource", Name: "Posture -> Resource", Enabled: true,
Action: types.PolicyTrafficActionAccept, Protocol: types.PolicyRuleProtocolALL,
Sources: []string{"group-src"},
DestinationResource: types.Resource{ID: "resource-guarded"},
}},
})
// The guarded resource lives in its own network, routed by peer-router-1.
account.NetworkResources = append(account.NetworkResources, &resourceTypes.NetworkResource{
ID: "resource-guarded", NetworkID: "net-guarded", AccountID: account.Id, Enabled: true,
Type: resourceTypes.Host, Prefix: netip.MustParsePrefix("10.200.1.1/32"), Address: "10.200.1.1/32",
})
account.Networks = append(account.Networks, &networkTypes.Network{
ID: "net-guarded", Name: "Guarded Net", AccountID: account.Id,
})
account.NetworkRouters = append(account.NetworkRouters, &routerTypes.NetworkRouter{
ID: "router-guarded", NetworkID: "net-guarded", Peer: "peer-router-1", Enabled: true, AccountID: account.Id,
})
t.Run("peer passes posture check", func(t *testing.T) {
// 0.35.0 is above the 0.30.0 minimum.
account.Peers["peer-src-1"].Meta.WtVersion = "0.35.0"
nm := networkMapFromComponents(t, account, "peer-src-1", validated)
var hasGuardedRoute bool
for _, r := range nm.Routes {
if r.Network.String() == "10.200.1.1/32" {
hasGuardedRoute = true
}
}
assert.True(t, hasGuardedRoute, "peer passing posture check should get guarded resource route")
})
t.Run("peer fails posture check", func(t *testing.T) {
// 0.20.0 is below the 0.30.0 minimum.
account.Peers["peer-src-1"].Meta.WtVersion = "0.20.0"
nm := networkMapFromComponents(t, account, "peer-src-1", validated)
for _, r := range nm.Routes {
assert.NotEqual(t, "10.200.1.1/32", r.Network.String(), "peer failing posture check should NOT get guarded resource route")
}
})
}
// TestNetworkMapComponents_NetworkResource_MultiplePostureChecks guards a
// network resource (10.200.2.1/32) with a policy that lists two posture checks
// (minimum client version and minimum Linux kernel version) and verifies that
// peer-src-1 receives the route only when both checks are satisfied.
//
// The subtests share and mutate the same account, so each one sets the peer
// metadata fields it depends on before building the map.
func TestNetworkMapComponents_NetworkResource_MultiplePostureChecks(t *testing.T) {
account := createComponentTestAccount()
validated := allPeersValidated(account)
// Two independent checks: client version >= 0.30.0 and Linux kernel >= 5.0.
account.PostureChecks = []*posture.Checks{
{ID: "pc-version", Name: "Version", Checks: posture.ChecksDefinition{
NBVersionCheck: &posture.NBVersionCheck{MinVersion: "0.30.0"},
}},
{ID: "pc-os", Name: "OS check", Checks: posture.ChecksDefinition{
OSVersionCheck: &posture.OSVersionCheck{Linux: &posture.MinKernelVersionCheck{MinKernelVersion: "5.0"}},
}},
}
// Policy from group-src to the strict resource, gated by both checks.
account.Policies = append(account.Policies, &types.Policy{
ID: "policy-multi-posture", Name: "Multi posture", Enabled: true, AccountID: account.Id,
SourcePostureChecks: []string{"pc-version", "pc-os"},
Rules: []*types.PolicyRule{{
ID: "rule-multi-posture", Name: "Multi posture rule", Enabled: true,
Action: types.PolicyTrafficActionAccept, Protocol: types.PolicyRuleProtocolALL,
Sources: []string{"group-src"},
DestinationResource: types.Resource{ID: "resource-strict"},
}},
})
// The strict resource lives in its own network, routed by peer-router-1.
account.NetworkResources = append(account.NetworkResources, &resourceTypes.NetworkResource{
ID: "resource-strict", NetworkID: "net-strict", AccountID: account.Id, Enabled: true,
Type: resourceTypes.Host, Prefix: netip.MustParsePrefix("10.200.2.1/32"), Address: "10.200.2.1/32",
})
account.Networks = append(account.Networks, &networkTypes.Network{
ID: "net-strict", Name: "Strict Net", AccountID: account.Id,
})
account.NetworkRouters = append(account.NetworkRouters, &routerTypes.NetworkRouter{
ID: "router-strict", NetworkID: "net-strict", Peer: "peer-router-1", Enabled: true, AccountID: account.Id,
})
t.Run("passes both posture checks", func(t *testing.T) {
account.Peers["peer-src-1"].Meta.WtVersion = "0.35.0"
account.Peers["peer-src-1"].Meta.GoOS = "linux"
account.Peers["peer-src-1"].Meta.KernelVersion = "6.1.0"
nm := networkMapFromComponents(t, account, "peer-src-1", validated)
var found bool
for _, r := range nm.Routes {
if r.Network.String() == "10.200.2.1/32" {
found = true
}
}
assert.True(t, found, "peer passing both checks should get resource route")
})
t.Run("fails version posture check", func(t *testing.T) {
// Version below the minimum, kernel still above it.
account.Peers["peer-src-1"].Meta.WtVersion = "0.20.0"
account.Peers["peer-src-1"].Meta.KernelVersion = "6.1.0"
nm := networkMapFromComponents(t, account, "peer-src-1", validated)
for _, r := range nm.Routes {
assert.NotEqual(t, "10.200.2.1/32", r.Network.String(), "peer failing version check should NOT get resource route")
}
})
t.Run("fails OS posture check", func(t *testing.T) {
// Version above the minimum, kernel below it.
account.Peers["peer-src-1"].Meta.WtVersion = "0.35.0"
account.Peers["peer-src-1"].Meta.KernelVersion = "4.0.0"
nm := networkMapFromComponents(t, account, "peer-src-1", validated)
for _, r := range nm.Routes {
assert.NotEqual(t, "10.200.2.1/32", r.Network.String(), "peer failing OS check should NOT get resource route")
}
})
}
// TestNetworkMapComponents_RouterPeerFirewallRules checks that peer-router-1
// gets route firewall rules whose destination is 10.200.0.1/32 and that at
// least one of them lists the /32 address of peer-src-1 or peer-src-2 as a
// source range.
func TestNetworkMapComponents_RouterPeerFirewallRules(t *testing.T) {
	acc := createComponentTestAccount()
	nm := networkMapFromComponents(t, acc, "peer-router-1", allPeersValidated(acc))

	// Keep only the rules that target the network resource.
	var resourceRules []*types.RouteFirewallRule
	for _, fw := range nm.RoutesFirewallRules {
		if fw.Destination != "10.200.0.1/32" {
			continue
		}
		resourceRules = append(resourceRules, fw)
	}
	assert.NotEmpty(t, resourceRules, "router should have firewall rules for the network resource")

	src1CIDR := acc.Peers["peer-src-1"].IP.String() + "/32"
	src2CIDR := acc.Peers["peer-src-2"].IP.String() + "/32"

	sawSourcePeer := false
	for _, fw := range resourceRules {
		for _, cidr := range fw.SourceRanges {
			if cidr == src1CIDR || cidr == src2CIDR {
				sawSourcePeer = true
			}
		}
	}
	assert.True(t, sawSourcePeer, "resource firewall rules should include source peer IPs")
}
// TestNetworkMapComponents_DNSManagement checks DNSConfig.ServiceEnable for a
// peer outside the DNS-disabled groups (peer-src-1) and for a peer inside one
// (peer-dst-1).
func TestNetworkMapComponents_DNSManagement(t *testing.T) {
	acc := createComponentTestAccount()
	validated := allPeersValidated(acc)

	// serviceEnabled builds the map for peerID and reports its DNS service flag.
	serviceEnabled := func(t *testing.T, peerID string) bool {
		nm := networkMapFromComponents(t, acc, peerID, validated)
		return nm.DNSConfig.ServiceEnable
	}

	t.Run("peer in DNS-enabled group", func(t *testing.T) {
		assert.True(t, serviceEnabled(t, "peer-src-1"), "peer in non-disabled group should have DNS enabled")
	})

	t.Run("peer in DNS-disabled group", func(t *testing.T) {
		assert.False(t, serviceEnabled(t, "peer-dst-1"), "peer in DNS-disabled group should have DNS disabled")
	})
}
// TestNetworkMapComponents_NameServerGroups checks that peer-src-1 has DNS
// enabled and receives the nameserver group with ID "ns-main".
func TestNetworkMapComponents_NameServerGroups(t *testing.T) {
	acc := createComponentTestAccount()
	nm := networkMapFromComponents(t, acc, "peer-src-1", allPeersValidated(acc))

	assert.True(t, nm.DNSConfig.ServiceEnable)

	sawMainGroup := false
	for _, group := range nm.DNSConfig.NameServerGroups {
		if group.ID != "ns-main" {
			continue
		}
		sawMainGroup = true
	}
	assert.True(t, sawMainGroup, "peer in NS group should receive nameserver configuration")
}
// TestNetworkMapComponents_RoutesWithHADeduplication registers two routes for
// the same prefix (172.16.0.0/16), one served by peer-dst-1 and one by
// peer-src-1, and expects peer-src-1's map to contain exactly one route for
// that prefix.
func TestNetworkMapComponents_RoutesWithHADeduplication(t *testing.T) {
	acc := createComponentTestAccount()
	validated := allPeersValidated(acc)

	haPrefix := netip.MustParsePrefix("172.16.0.0/16")
	distribution := []string{"group-src", "group-dst"}

	acc.Routes["route-ha-1"] = &route.Route{
		ID:         "route-ha-1",
		AccountID:  acc.Id,
		Network:    haPrefix,
		PeerID:     "peer-dst-1",
		Peer:       acc.Peers["peer-dst-1"].Key,
		PeerGroups: []string{"group-dst"},
		Groups:     distribution,
		Metric:     100,
		Enabled:    true,
	}
	acc.Routes["route-ha-2"] = &route.Route{
		ID:         "route-ha-2",
		AccountID:  acc.Id,
		Network:    netip.MustParsePrefix("172.16.0.0/16"),
		PeerID:     "peer-src-1",
		Peer:       acc.Peers["peer-src-1"].Key,
		PeerGroups: []string{"group-src"},
		Groups:     []string{"group-src", "group-dst"},
		Metric:     200,
		Enabled:    true,
	}

	nm := networkMapFromComponents(t, acc, "peer-src-1", validated)

	var matches int
	for _, rt := range nm.Routes {
		if rt.Network.String() != "172.16.0.0/16" {
			continue
		}
		matches++
	}
	assert.Equal(t, 1, matches, "peer should only receive one route from HA group (not both, since it's a member of one)")
}
// TestNetworkMapComponents_RoutesFirewallRulesForAccessControl adds a route
// served by peer-src-1 that carries AccessControlGroups and expects the routing
// peer's map to contain a route firewall rule for 192.168.100.0/24.
func TestNetworkMapComponents_RoutesFirewallRulesForAccessControl(t *testing.T) {
	acc := createComponentTestAccount()
	validated := allPeersValidated(acc)

	acc.Routes["route-acl"] = &route.Route{
		ID:                  "route-acl",
		AccountID:           acc.Id,
		Network:             netip.MustParsePrefix("192.168.100.0/24"),
		PeerID:              "peer-src-1",
		Peer:                acc.Peers["peer-src-1"].Key,
		PeerGroups:          []string{"group-src"},
		Groups:              []string{"group-dst"},
		AccessControlGroups: []string{"group-dst"},
		Metric:              100,
		Enabled:             true,
	}

	nm := networkMapFromComponents(t, acc, "peer-src-1", validated)

	ruleFound := false
	for _, fw := range nm.RoutesFirewallRules {
		if fw.Destination != "192.168.100.0/24" {
			continue
		}
		ruleFound = true
	}
	assert.True(t, ruleFound, "routing peer should have firewall rules for route with access control groups")
}
// TestNetworkMapComponents_RoutesDefaultPermit adds a route served by
// peer-src-1 that has no AccessControlGroups and expects the routing peer's map
// to still contain a route firewall rule for 10.99.0.0/16.
func TestNetworkMapComponents_RoutesDefaultPermit(t *testing.T) {
	acc := createComponentTestAccount()
	validated := allPeersValidated(acc)

	acc.Routes["route-open"] = &route.Route{
		ID:         "route-open",
		AccountID:  acc.Id,
		Network:    netip.MustParsePrefix("10.99.0.0/16"),
		PeerID:     "peer-src-1",
		Peer:       acc.Peers["peer-src-1"].Key,
		PeerGroups: []string{"group-src"},
		Groups:     []string{"group-src"},
		Metric:     100,
		Enabled:    true,
	}

	nm := networkMapFromComponents(t, acc, "peer-src-1", validated)

	ruleFound := false
	for _, fw := range nm.RoutesFirewallRules {
		if fw.Destination != "10.99.0.0/16" {
			continue
		}
		ruleFound = true
	}
	assert.True(t, ruleFound, "route without access control groups should have default permit firewall rules")
}
// TestNetworkMapComponents_SSHAuthorizedUsers enables SSH on peer-dst-1, adds a
// bidirectional accept-all policy between group-src and group-dst, and expects
// EnableSSH to be set in peer-dst-1's network map.
func TestNetworkMapComponents_SSHAuthorizedUsers(t *testing.T) {
	acc := createComponentTestAccount()
	validated := allPeersValidated(acc)

	acc.Peers["peer-dst-1"].SSHEnabled = true

	sshRule := &types.PolicyRule{
		ID:            "rule-ssh",
		Name:          "SSH to dst",
		Enabled:       true,
		Action:        types.PolicyTrafficActionAccept,
		Protocol:      types.PolicyRuleProtocolALL,
		Bidirectional: true,
		Sources:       []string{"group-src"},
		Destinations:  []string{"group-dst"},
	}
	acc.Policies = append(acc.Policies, &types.Policy{
		ID:        "policy-ssh",
		Name:      "SSH Access",
		Enabled:   true,
		AccountID: acc.Id,
		Rules:     []*types.PolicyRule{sshRule},
	})

	nm := networkMapFromComponents(t, acc, "peer-dst-1", validated)
	assert.True(t, nm.EnableSSH, "SSH-enabled peer with matching policy should have EnableSSH")
}
// TestNetworkMapComponents_DisabledPolicyIgnored disables every policy of the
// account and expects peer-src-1's map to have no peers and no firewall rules.
func TestNetworkMapComponents_DisabledPolicyIgnored(t *testing.T) {
	acc := createComponentTestAccount()
	validated := allPeersValidated(acc)

	for i := range acc.Policies {
		acc.Policies[i].Enabled = false
	}

	nm := networkMapFromComponents(t, acc, "peer-src-1", validated)

	assert.Empty(t, nm.Peers, "with all policies disabled, peer should see no other peers")
	assert.Empty(t, nm.FirewallRules)
}
// TestNetworkMapComponents_DisabledRouteIgnored disables every route and every
// network resource of the account and expects peer-src-1's map to contain no
// routes at all.
func TestNetworkMapComponents_DisabledRouteIgnored(t *testing.T) {
	acc := createComponentTestAccount()
	validated := allPeersValidated(acc)

	for id := range acc.Routes {
		acc.Routes[id].Enabled = false
	}
	// Network resources are switched off as well so that none of them can
	// contribute a route.
	for i := range acc.NetworkResources {
		acc.NetworkResources[i].Enabled = false
	}

	nm := networkMapFromComponents(t, acc, "peer-src-1", validated)
	assert.Empty(t, nm.Routes, "disabled routes should not appear in network map")
}
// TestNetworkMapComponents_DisabledNetworkResourceIgnored disables every
// network resource and expects peer-router-1's map to contain no route for
// 10.200.0.1/32.
func TestNetworkMapComponents_DisabledNetworkResourceIgnored(t *testing.T) {
	acc := createComponentTestAccount()
	validated := allPeersValidated(acc)

	for i := range acc.NetworkResources {
		acc.NetworkResources[i].Enabled = false
	}

	nm := networkMapFromComponents(t, acc, "peer-router-1", validated)
	for _, rt := range nm.Routes {
		prefix := rt.Network.String()
		assert.NotEqual(t, "10.200.0.1/32", prefix, "disabled resource should not generate routes")
	}
}
// TestNetworkMapComponents_BidirectionalPolicy checks that peer-src-1 and
// peer-dst-1 each appear in the other's network map.
func TestNetworkMapComponents_BidirectionalPolicy(t *testing.T) {
	acc := createComponentTestAccount()
	validated := allPeersValidated(acc)

	srcView := networkMapFromComponents(t, acc, "peer-src-1", validated)
	dstView := networkMapFromComponents(t, acc, "peer-dst-1", validated)

	assert.Contains(t, peerIDs(srcView.Peers), "peer-dst-1", "src should see dst via bidirectional policy")
	assert.Contains(t, peerIDs(dstView.Peers), "peer-src-1", "dst should see src via bidirectional policy")
}
// TestNetworkMapComponents_DropPolicy adds a TCP drop policy on port 5432 from
// group-src to group-dst and expects peer-src-1's firewall rules to contain a
// drop rule for that port.
func TestNetworkMapComponents_DropPolicy(t *testing.T) {
	acc := createComponentTestAccount()
	validated := allPeersValidated(acc)

	dropRule := &types.PolicyRule{
		ID:           "rule-drop",
		Name:         "Drop src->dst",
		Enabled:      true,
		Action:       types.PolicyTrafficActionDrop,
		Protocol:     types.PolicyRuleProtocolTCP,
		Ports:        []string{"5432"},
		Sources:      []string{"group-src"},
		Destinations: []string{"group-dst"},
	}
	acc.Policies = append(acc.Policies, &types.Policy{
		ID:        "policy-drop",
		Name:      "Drop traffic",
		Enabled:   true,
		AccountID: acc.Id,
		Rules:     []*types.PolicyRule{dropRule},
	})

	nm := networkMapFromComponents(t, acc, "peer-src-1", validated)

	sawDrop := false
	for _, fw := range nm.FirewallRules {
		if fw.Action != string(types.PolicyTrafficActionDrop) || fw.Port != "5432" {
			continue
		}
		sawDrop = true
	}
	assert.True(t, sawDrop, "drop policy should generate drop firewall rule")
}
// TestNetworkMapComponents_PortRangePolicy adds a TCP accept policy with the
// port range 8080-8090 and expects peer-src-1 (whose WtVersion is set to
// 0.50.0) to receive a firewall rule carrying that range.
func TestNetworkMapComponents_PortRangePolicy(t *testing.T) {
	acc := createComponentTestAccount()
	validated := allPeersValidated(acc)

	// NOTE(review): the version bump presumably ensures the client is new
	// enough to be sent port ranges — confirm against the map builder.
	acc.Peers["peer-src-1"].Meta.WtVersion = "0.50.0"

	rangeRule := &types.PolicyRule{
		ID:           "rule-range",
		Name:         "Range rule",
		Enabled:      true,
		Action:       types.PolicyTrafficActionAccept,
		Protocol:     types.PolicyRuleProtocolTCP,
		PortRanges:   []types.RulePortRange{{Start: 8080, End: 8090}},
		Sources:      []string{"group-src"},
		Destinations: []string{"group-dst"},
	}
	acc.Policies = append(acc.Policies, &types.Policy{
		ID:        "policy-range",
		Name:      "Port range",
		Enabled:   true,
		AccountID: acc.Id,
		Rules:     []*types.PolicyRule{rangeRule},
	})

	nm := networkMapFromComponents(t, acc, "peer-src-1", validated)

	sawRange := false
	for _, fw := range nm.FirewallRules {
		if fw.PortRange.Start != 8080 || fw.PortRange.End != 8090 {
			continue
		}
		sawRange = true
	}
	assert.True(t, sawRange, "port range policy should generate corresponding firewall rule")
}
// TestNetworkMapComponents_MultipleNetworkResources adds a second host resource
// (10.200.0.2/32) to net-1 together with a group and an access policy, and
// expects peer-router-1's map to contain two routes across the two resource
// prefixes.
func TestNetworkMapComponents_MultipleNetworkResources(t *testing.T) {
	acc := createComponentTestAccount()
	validated := allPeersValidated(acc)

	acc.NetworkResources = append(acc.NetworkResources, &resourceTypes.NetworkResource{
		ID:        "resource-2",
		NetworkID: "net-1",
		AccountID: acc.Id,
		Enabled:   true,
		Type:      resourceTypes.Host,
		Prefix:    netip.MustParsePrefix("10.200.0.2/32"),
		Address:   "10.200.0.2/32",
	})

	acc.Groups["group-res2"] = &types.Group{
		ID:        "group-res2",
		Name:      "Resource 2 Group",
		Peers:     []string{"peer-src-1", "peer-src-2"},
		Resources: []types.Resource{{ID: "resource-2"}},
	}

	accessRule := &types.PolicyRule{
		ID:                  "rule-res2",
		Name:                "Access Resource 2",
		Enabled:             true,
		Action:              types.PolicyTrafficActionAccept,
		Protocol:            types.PolicyRuleProtocolALL,
		Sources:             []string{"group-src"},
		DestinationResource: types.Resource{ID: "resource-2"},
	}
	acc.Policies = append(acc.Policies, &types.Policy{
		ID:        "policy-res2",
		Name:      "Resource 2 Policy",
		Enabled:   true,
		AccountID: acc.Id,
		Rules:     []*types.PolicyRule{accessRule},
	})

	nm := networkMapFromComponents(t, acc, "peer-router-1", validated)

	var resourceRoutes int
	for _, rt := range nm.Routes {
		switch rt.Network.String() {
		case "10.200.0.1/32", "10.200.0.2/32":
			resourceRoutes++
		}
	}
	assert.Equal(t, 2, resourceRoutes, "router should have routes for both network resources")
}
// TestNetworkMapComponents_DomainNetworkResource adds a domain-type resource
// (api.example.com) with a group and an access policy from group-src, and
// expects peer-src-1 to receive a domain route whose first domain is
// api.example.com.
func TestNetworkMapComponents_DomainNetworkResource(t *testing.T) {
	acc := createComponentTestAccount()
	validated := allPeersValidated(acc)

	acc.NetworkResources = append(acc.NetworkResources, &resourceTypes.NetworkResource{
		ID:        "resource-domain",
		NetworkID: "net-1",
		AccountID: acc.Id,
		Enabled:   true,
		Type:      resourceTypes.Domain,
		Domain:    "api.example.com",
		Address:   "api.example.com",
	})

	acc.Groups["group-res-domain"] = &types.Group{
		ID:        "group-res-domain",
		Name:      "Domain Resource Group",
		Resources: []types.Resource{{ID: "resource-domain"}},
	}

	accessRule := &types.PolicyRule{
		ID:                  "rule-domain",
		Name:                "Access domain resource",
		Enabled:             true,
		Action:              types.PolicyTrafficActionAccept,
		Protocol:            types.PolicyRuleProtocolALL,
		Sources:             []string{"group-src"},
		DestinationResource: types.Resource{ID: "resource-domain"},
	}
	acc.Policies = append(acc.Policies, &types.Policy{
		ID:        "policy-domain",
		Name:      "Domain resource policy",
		Enabled:   true,
		AccountID: acc.Id,
		Rules:     []*types.PolicyRule{accessRule},
	})

	nm := networkMapFromComponents(t, acc, "peer-src-1", validated)

	sawDomainRoute := false
	for _, rt := range nm.Routes {
		// The length check comes first so Domains[0] is never read on an empty list.
		if rt.NetworkType != route.DomainNetwork || len(rt.Domains) == 0 {
			continue
		}
		if rt.Domains[0].SafeString() == "api.example.com" {
			sawDomainRoute = true
		}
	}
	assert.True(t, sawDomainRoute, "source peer should receive domain route for domain network resource")
}
// TestNetworkMapComponents_NetworkEmpty requests the network map for a peer ID
// that does not exist in the account and expects a non-nil map with no peers,
// no firewall rules and a non-nil Network.
func TestNetworkMapComponents_NetworkEmpty(t *testing.T) {
	account := createComponentTestAccount()
	validated := allPeersValidated(account)

	nm := networkMapFromComponents(t, account, "nonexistent-peer", validated)

	// assert.NotNil only records the failure and lets the test continue, so
	// bail out explicitly; otherwise the field accesses below would panic on a
	// nil map instead of producing a readable failure.
	if !assert.NotNil(t, nm) {
		return
	}

	assert.Empty(t, nm.Peers)
	assert.Empty(t, nm.FirewallRules)
	assert.NotNil(t, nm.Network)
}
// TestNetworkMapComponents_RouterExcludesOtherNetworkRoutes adds a second
// network (net-other) with its own resource (10.200.99.1/32) routed by
// peer-dst-1, and expects peer-router-1's map to contain no route for that
// resource.
func TestNetworkMapComponents_RouterExcludesOtherNetworkRoutes(t *testing.T) {
	acc := createComponentTestAccount()
	validated := allPeersValidated(acc)

	acc.NetworkResources = append(acc.NetworkResources, &resourceTypes.NetworkResource{
		ID:        "resource-other",
		NetworkID: "net-other",
		AccountID: acc.Id,
		Enabled:   true,
		Type:      resourceTypes.Host,
		Prefix:    netip.MustParsePrefix("10.200.99.1/32"),
		Address:   "10.200.99.1/32",
	})
	acc.Networks = append(acc.Networks, &networkTypes.Network{
		ID:        "net-other",
		Name:      "Other Net",
		AccountID: acc.Id,
	})
	// The other network is routed by peer-dst-1, not by peer-router-1.
	acc.NetworkRouters = append(acc.NetworkRouters, &routerTypes.NetworkRouter{
		ID:        "router-other",
		NetworkID: "net-other",
		Peer:      "peer-dst-1",
		Enabled:   true,
		AccountID: acc.Id,
	})

	acc.Groups["group-res-other"] = &types.Group{
		ID:        "group-res-other",
		Name:      "Other resource group",
		Resources: []types.Resource{{ID: "resource-other"}},
	}

	accessRule := &types.PolicyRule{
		ID:                  "rule-other",
		Name:                "Other resource access",
		Enabled:             true,
		Action:              types.PolicyTrafficActionAccept,
		Protocol:            types.PolicyRuleProtocolALL,
		Sources:             []string{"group-src"},
		DestinationResource: types.Resource{ID: "resource-other"},
	}
	acc.Policies = append(acc.Policies, &types.Policy{
		ID:        "policy-other-resource",
		Name:      "Other resource policy",
		Enabled:   true,
		AccountID: acc.Id,
		Rules:     []*types.PolicyRule{accessRule},
	})

	nm := networkMapFromComponents(t, acc, "peer-router-1", validated)
	for _, rt := range nm.Routes {
		prefix := rt.Network.String()
		assert.NotEqual(t, "10.200.99.1/32", prefix, "router-1 should NOT get routes for other network's resources")
	}
}
// createComponentTestAccount builds the account fixture shared by the
// TestNetworkMapComponents_* tests.
//
// The fixture contains:
//   - four connected peers: peer-src-1, peer-src-2, peer-dst-1, peer-router-1;
//   - groups group-src (both src peers), group-dst (peer-dst-1), group-all
//     (all four peers) and group-res (holding resource-1);
//   - a bidirectional accept-all policy between group-src and group-dst, and a
//     policy granting group-src access to resource-1;
//   - one route (192.168.10.0/24) served by peer-dst-1;
//   - one host network resource (10.200.0.1/32) in net-1, routed by
//     peer-router-1;
//   - DNS management disabled for group-dst and a nameserver group "ns-main"
//     distributed to group-src;
//   - peer login expiration disabled, with a 24h expiration value.
//
// Each call returns a freshly built account, so tests may mutate it freely.
func createComponentTestAccount() *types.Account {
peers := map[string]*nbpeer.Peer{
"peer-src-1": {
ID: "peer-src-1", IP: net.IP{100, 64, 0, 1}, Key: "key-src-1", DNSLabel: "src1",
Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now()}, UserID: "user-1",
Meta: nbpeer.PeerSystemMeta{WtVersion: "0.35.0", GoOS: "linux"},
},
"peer-src-2": {
ID: "peer-src-2", IP: net.IP{100, 64, 0, 2}, Key: "key-src-2", DNSLabel: "src2",
Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now()}, UserID: "user-1",
Meta: nbpeer.PeerSystemMeta{WtVersion: "0.35.0", GoOS: "linux"},
},
"peer-dst-1": {
ID: "peer-dst-1", IP: net.IP{100, 64, 0, 3}, Key: "key-dst-1", DNSLabel: "dst1",
Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now()}, UserID: "user-2",
Meta: nbpeer.PeerSystemMeta{WtVersion: "0.35.0", GoOS: "linux"},
},
"peer-router-1": {
ID: "peer-router-1", IP: net.IP{100, 64, 0, 10}, Key: "key-router-1", DNSLabel: "router1",
Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now()}, UserID: "user-1",
Meta: nbpeer.PeerSystemMeta{WtVersion: "0.35.0", GoOS: "linux"},
},
}
// peer-router-1 belongs only to group-all; group-res holds no peers, only the resource.
groups := map[string]*types.Group{
"group-src": {ID: "group-src", Name: "Sources", Peers: []string{"peer-src-1", "peer-src-2"}},
"group-dst": {ID: "group-dst", Name: "Destinations", Peers: []string{"peer-dst-1"}},
"group-all": {ID: "group-all", Name: "All", Peers: []string{"peer-src-1", "peer-src-2", "peer-dst-1", "peer-router-1"}},
"group-res": {
ID: "group-res", Name: "Resource Group",
Resources: []types.Resource{{ID: "resource-1"}},
},
}
// AccountID is left unset on the policies here and filled in by the loop at the end.
policies := []*types.Policy{
{
ID: "policy-base", Name: "Base connectivity", Enabled: true,
Rules: []*types.PolicyRule{{
ID: "rule-base", Name: "Allow src <-> dst", Enabled: true,
Action: types.PolicyTrafficActionAccept, Protocol: types.PolicyRuleProtocolALL,
Bidirectional: true,
Sources: []string{"group-src"}, Destinations: []string{"group-dst"},
}},
},
{
ID: "policy-resource", Name: "Network resource access", Enabled: true,
Rules: []*types.PolicyRule{{
ID: "rule-resource", Name: "Source -> Resource", Enabled: true,
Action: types.PolicyTrafficActionAccept, Protocol: types.PolicyRuleProtocolALL,
Sources: []string{"group-src"},
DestinationResource: types.Resource{ID: "resource-1"},
}},
},
}
// AccountID is likewise filled in for the routes at the end.
routes := map[route.ID]*route.Route{
"route-main": {
ID: "route-main", Network: netip.MustParsePrefix("192.168.10.0/24"),
Peer: peers["peer-dst-1"].Key, PeerID: "peer-dst-1",
Enabled: true, Metric: 100,
Groups: []string{"group-src", "group-dst"}, PeerGroups: []string{"group-dst"},
},
}
users := map[string]*types.User{
"user-1": {Id: "user-1", Role: types.UserRoleAdmin, IsServiceUser: false, AutoGroups: []string{"group-all"}},
"user-2": {Id: "user-2", Role: types.UserRoleUser, IsServiceUser: false, AutoGroups: []string{"group-all"}},
}
account := &types.Account{
Id: "account-components-test", Peers: peers, Groups: groups, Policies: policies, Routes: routes,
Users: users,
Network: &types.Network{
Identifier: "net-test", Net: net.IPNet{IP: net.IP{100, 64, 0, 0}, Mask: net.CIDRMask(16, 32)}, Serial: 1,
},
DNSSettings: types.DNSSettings{DisabledManagementGroups: []string{"group-dst"}},
NameServerGroups: map[string]*nbdns.NameServerGroup{
"ns-main": {
ID: "ns-main", Name: "Main NS", Enabled: true, Groups: []string{"group-src"},
NameServers: []nbdns.NameServer{{IP: netip.MustParseAddr("8.8.8.8"), NSType: nbdns.UDPNameServerType, Port: 53}},
},
},
PostureChecks: []*posture.Checks{},
NetworkResources: []*resourceTypes.NetworkResource{
{
ID: "resource-1", NetworkID: "net-1", AccountID: "account-components-test", Enabled: true,
Type: resourceTypes.Host, Prefix: netip.MustParsePrefix("10.200.0.1/32"), Address: "10.200.0.1/32",
},
},
Networks: []*networkTypes.Network{
{ID: "net-1", Name: "Resource Net", AccountID: "account-components-test"},
},
NetworkRouters: []*routerTypes.NetworkRouter{
{ID: "router-1", NetworkID: "net-1", Peer: "peer-router-1", Enabled: true, AccountID: "account-components-test"},
},
Settings: &types.Settings{PeerLoginExpirationEnabled: false, PeerLoginExpiration: 24 * time.Hour},
}
// Stamp the account ID onto policies and routes, which were declared without it.
for _, p := range account.Policies {
p.AccountID = account.Id
}
for _, r := range account.Routes {
r.AccountID = account.Id
}
return account
}

View File

@@ -0,0 +1,967 @@
package types_test
import (
"context"
"encoding/json"
"fmt"
"net"
"net/netip"
"os"
"path/filepath"
"slices"
"sort"
"testing"
"time"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/internals/modules/zones"
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/posture"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/route"
)
// Identifiers shared by the golden tests and benchmarks in this file.
const (
// numPeers is the number of "peer-<i>" IDs the tests iterate over.
numPeers = 100
// Group identifiers.
devGroupID = "group-dev"
opsGroupID = "group-ops"
allGroupID = "group-all"
sshUsersGroupID = "group-ssh-users"
// Route identifiers.
routeID = route.ID("route-main")
routeHA1ID = route.ID("route-ha-1")
routeHA2ID = route.ID("route-ha-2")
// Policy identifiers.
policyIDDevOps = "policy-dev-ops"
policyIDAll = "policy-all"
policyIDPosture = "policy-posture"
policyIDDrop = "policy-drop"
policyIDSSH = "policy-ssh"
// Posture check, network resource/router and nameserver group identifiers.
postureCheckID = "posture-check-ver"
networkResourceID = "res-database"
networkID = "net-database"
networkRouterID = "router-database"
nameserverGroupID = "ns-group-main"
// Peers with a dedicated role in the tests.
testingPeerID = "peer-60" // A peer from the "dev" group, should receive the most detailed map.
expiredPeerID = "peer-98" // This peer will be online but with an expired session.
offlinePeerID = "peer-99" // This peer will be completely offline.
routingPeerID = "peer-95" // This peer is used for routing, it has a route to the network.
// Account and user identifiers.
testAccountID = "account-golden-test"
userAdminID = "user-admin"
userDevID = "user-dev"
userOpsID = "user-ops"
)
// TestGetPeerNetworkMap_Golden builds the network map for testingPeerID twice,
// once through account.GetPeerNetworkMap (legacy path) and once through
// types.NewNetworkMapBuilder, and requires both to serialize to the same JSON
// after normalization.
//
// On mismatch both JSON documents are written under testdata/ to make the
// difference inspectable before the test fails.
func TestGetPeerNetworkMap_Golden(t *testing.T) {
account := createTestAccountWithEntities()
ctx := context.Background()
// Every peer except offlinePeerID is treated as validated.
validatedPeersMap := make(map[string]struct{})
for i := range numPeers {
peerID := fmt.Sprintf("peer-%d", i)
if peerID == offlinePeerID {
continue
}
validatedPeersMap[peerID] = struct{}{}
}
resourcePolicies := account.GetResourcePoliciesMap()
routers := account.GetResourceRoutersMap()
// Legacy path.
legacyNetworkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, []*zones.Zone{}, validatedPeersMap, resourcePolicies, routers, nil, account.GetActiveGroupUsers())
normalizeAndSortNetworkMap(legacyNetworkMap)
legacyJSON, err := json.MarshalIndent(toNetworkMapJSON(legacyNetworkMap), "", " ")
require.NoError(t, err, "error marshaling legacy network map to JSON")
// Builder path.
builder := types.NewNetworkMapBuilder(account, validatedPeersMap)
newNetworkMap := builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, nil, validatedPeersMap, nil)
normalizeAndSortNetworkMap(newNetworkMap)
newJSON, err := json.MarshalIndent(toNetworkMapJSON(newNetworkMap), "", " ")
require.NoError(t, err, "error marshaling new network map to JSON")
// Files are only written when the two serializations differ.
if string(legacyJSON) != string(newJSON) {
legacyFilePath := filepath.Join("testdata", "networkmap_golden.json")
newFilePath := filepath.Join("testdata", "networkmap_golden_new.json")
err = os.MkdirAll(filepath.Dir(legacyFilePath), 0755)
require.NoError(t, err)
err = os.WriteFile(legacyFilePath, legacyJSON, 0644)
require.NoError(t, err)
t.Logf("Saved legacy network map to %s", legacyFilePath)
err = os.WriteFile(newFilePath, newJSON, 0644)
require.NoError(t, err)
t.Logf("Saved new network map to %s", newFilePath)
// JSONEq compares the documents semantically and reports the diff on failure.
require.JSONEq(t, string(legacyJSON), string(newJSON), "network maps from legacy and new builder do not match")
}
}
// BenchmarkGetPeerNetworkMap measures building the network map for every
// "peer-<i>" ID, once through account.GetPeerNetworkMap ("old builder") and
// once through a NetworkMapBuilder ("new builder"). The new-builder case
// constructs the builder inside each iteration, so its construction cost is
// part of the measurement.
func BenchmarkGetPeerNetworkMap(b *testing.B) {
	var (
		account   = createTestAccountWithEntities()
		ctx       = context.Background()
		validated = make(map[string]struct{})
		ids       []string
	)
	for n := range numPeers {
		id := fmt.Sprintf("peer-%d", n)
		ids = append(ids, id)
		validated[id] = struct{}{}
	}

	legacy := func(b *testing.B) {
		for range b.N {
			for _, id := range ids {
				_ = account.GetPeerNetworkMap(ctx, id, dns.CustomZone{}, []*zones.Zone{}, validated, nil, nil, nil, account.GetActiveGroupUsers())
			}
		}
	}

	viaBuilder := func(b *testing.B) {
		for range b.N {
			nmBuilder := types.NewNetworkMapBuilder(account, validated)
			for _, id := range ids {
				_ = nmBuilder.GetPeerNetworkMap(ctx, id, dns.CustomZone{}, nil, validated, nil)
			}
		}
	}

	b.ResetTimer()
	b.Run("old builder", legacy)

	b.ResetTimer()
	b.Run("new builder", viaBuilder)
}
// TestGetPeerNetworkMap_Golden_WithNewPeer verifies that a NetworkMapBuilder
// created before a peer is added, and then updated through
// OnPeerAddedIncremental, yields the same network map for testingPeerID as the
// legacy account.GetPeerNetworkMap computed on the already-updated account.
//
// On mismatch both JSON documents are written under testdata/ before the test
// fails.
func TestGetPeerNetworkMap_Golden_WithNewPeer(t *testing.T) {
account := createTestAccountWithEntities()
ctx := context.Background()
// Every peer except offlinePeerID is treated as validated.
validatedPeersMap := make(map[string]struct{})
for i := range numPeers {
peerID := fmt.Sprintf("peer-%d", i)
if peerID == offlinePeerID {
continue
}
validatedPeersMap[peerID] = struct{}{}
}
// The builder is created BEFORE the new peer is added to the account, so it
// only learns about the peer through OnPeerAddedIncremental below.
builder := types.NewNetworkMapBuilder(account, validatedPeersMap)
newPeerID := "peer-new-101"
newPeerIP := net.IP{100, 64, 1, 1}
newPeer := &nbpeer.Peer{
ID: newPeerID,
IP: newPeerIP,
Key: fmt.Sprintf("key-%s", newPeerID),
DNSLabel: "peernew101",
Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now()},
UserID: "user-admin",
Meta: nbpeer.PeerSystemMeta{WtVersion: "0.26.0", GoOS: "linux"},
// Note: the closure's local `t` is a time.Time that shadows the *testing.T parameter.
LastLogin: func() *time.Time { t := time.Now(); return &t }(),
}
// Register the peer in the account, in the dev and all groups, and as validated.
account.Peers[newPeerID] = newPeer
if devGroup, exists := account.Groups[devGroupID]; exists {
devGroup.Peers = append(devGroup.Peers, newPeerID)
}
if allGroup, exists := account.Groups[allGroupID]; exists {
allGroup.Peers = append(allGroup.Peers, newPeerID)
}
validatedPeersMap[newPeerID] = struct{}{}
// Bump the network serial to reflect the account change.
if account.Network != nil {
account.Network.Serial++
}
resourcePolicies := account.GetResourcePoliciesMap()
routers := account.GetResourceRoutersMap()
// Legacy path, computed on the updated account.
legacyNetworkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, []*zones.Zone{}, validatedPeersMap, resourcePolicies, routers, nil, account.GetActiveGroupUsers())
normalizeAndSortNetworkMap(legacyNetworkMap)
legacyJSON, err := json.MarshalIndent(toNetworkMapJSON(legacyNetworkMap), "", " ")
require.NoError(t, err, "error marshaling legacy network map to JSON")
// Builder path: apply the incremental update, then build the map.
err = builder.OnPeerAddedIncremental(account, newPeerID)
require.NoError(t, err, "error adding peer to cache")
newNetworkMap := builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, nil, validatedPeersMap, nil)
normalizeAndSortNetworkMap(newNetworkMap)
newJSON, err := json.MarshalIndent(toNetworkMapJSON(newNetworkMap), "", " ")
require.NoError(t, err, "error marshaling new network map to JSON")
// Files are only written when the two serializations differ.
if string(legacyJSON) != string(newJSON) {
legacyFilePath := filepath.Join("testdata", "networkmap_golden_with_new_peer.json")
newFilePath := filepath.Join("testdata", "networkmap_golden_new_with_onpeeradded.json")
err = os.MkdirAll(filepath.Dir(legacyFilePath), 0755)
require.NoError(t, err)
err = os.WriteFile(legacyFilePath, legacyJSON, 0644)
require.NoError(t, err)
t.Logf("Saved legacy network map to %s", legacyFilePath)
err = os.WriteFile(newFilePath, newJSON, 0644)
require.NoError(t, err)
t.Logf("Saved new network map to %s", newFilePath)
require.JSONEq(t, string(legacyJSON), string(newJSON), "network maps with new peer from legacy and new builder do not match")
}
}
// BenchmarkGetPeerNetworkMap_AfterPeerAdded measures network map generation for
// all "peer-<i>" IDs after one extra peer has been added to the account.
//
// "old builder after add" recomputes every map through
// account.GetPeerNetworkMap. "new builder after add" reuses a builder created
// before the peer was added and applies OnPeerAddedIncremental in every
// iteration before generating the maps, so the incremental update cost is part
// of the measurement.
func BenchmarkGetPeerNetworkMap_AfterPeerAdded(b *testing.B) {
	account := createTestAccountWithEntities()
	ctx := context.Background()

	validatedPeersMap := make(map[string]struct{})
	var peerIDs []string
	for i := range numPeers {
		peerID := fmt.Sprintf("peer-%d", i)
		validatedPeersMap[peerID] = struct{}{}
		peerIDs = append(peerIDs, peerID)
	}

	// The builder is created before the new peer exists in the account.
	builder := types.NewNetworkMapBuilder(account, validatedPeersMap)

	newPeerID := "peer-new-101"
	newPeer := &nbpeer.Peer{
		ID:       newPeerID,
		IP:       net.IP{100, 64, 1, 1},
		Key:      fmt.Sprintf("key-%s", newPeerID),
		DNSLabel: "peernew101",
		Status:   &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now()},
		UserID:   "user-admin",
		Meta:     nbpeer.PeerSystemMeta{WtVersion: "0.26.0", GoOS: "linux"},
	}

	account.Peers[newPeerID] = newPeer
	account.Groups[devGroupID].Peers = append(account.Groups[devGroupID].Peers, newPeerID)
	account.Groups[allGroupID].Peers = append(account.Groups[allGroupID].Peers, newPeerID)
	validatedPeersMap[newPeerID] = struct{}{}

	b.ResetTimer()
	b.Run("old builder after add", func(b *testing.B) {
		for range b.N {
			// The loop variable is named peerID so it does not shadow the
			// package-level testingPeerID constant.
			for _, peerID := range peerIDs {
				_ = account.GetPeerNetworkMap(ctx, peerID, dns.CustomZone{}, []*zones.Zone{}, validatedPeersMap, nil, nil, nil, account.GetActiveGroupUsers())
			}
		}
	})

	b.ResetTimer()
	b.Run("new builder after add", func(b *testing.B) {
		for range b.N {
			// The error is deliberately ignored: the same peer is re-added on
			// every iteration and the benchmark only measures the call cost.
			_ = builder.OnPeerAddedIncremental(account, newPeerID)
			for _, peerID := range peerIDs {
				_ = builder.GetPeerNetworkMap(ctx, peerID, dns.CustomZone{}, nil, validatedPeersMap, nil)
			}
		}
	})
}
// TestGetPeerNetworkMap_Golden_WithNewRoutingPeer adds a new peer that also
// serves a new route, then checks that the legacy Account.GetPeerNetworkMap
// result and the NetworkMapBuilder result (after OnPeerAddedIncremental)
// serialize to the same JSON for testingPeerID. When the two differ, both
// documents are written under testdata before the assertion fails.
func TestGetPeerNetworkMap_Golden_WithNewRoutingPeer(t *testing.T) {
	account := createTestAccountWithEntities()
	ctx := context.Background()

	// Every generated peer except the offline one counts as validated.
	validated := make(map[string]struct{})
	for idx := range numPeers {
		id := fmt.Sprintf("peer-%d", idx)
		if id != offlinePeerID {
			validated[id] = struct{}{}
		}
	}

	// The builder is created before the router is added to the account.
	builder := types.NewNetworkMapBuilder(account, validated)

	routerID := "peer-new-router-102"
	loginTime := time.Now()
	router := &nbpeer.Peer{
		ID:        routerID,
		IP:        net.IP{100, 64, 1, 2},
		Key:       fmt.Sprintf("key-%s", routerID),
		DNSLabel:  "newrouter102",
		Status:    &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now()},
		UserID:    "user-admin",
		Meta:      nbpeer.PeerSystemMeta{WtVersion: "0.26.0", GoOS: "linux"},
		LastLogin: &loginTime,
	}
	account.Peers[routerID] = router

	// Register the router in the groups that exist.
	for _, groupID := range []string{opsGroupID, allGroupID} {
		if group, ok := account.Groups[groupID]; ok {
			group.Peers = append(group.Peers, routerID)
		}
	}

	routerRoute := &route.Route{
		ID:                  route.ID("route-new-router"),
		Network:             netip.MustParsePrefix("172.16.0.0/24"),
		Peer:                router.Key,
		PeerID:              routerID,
		Description:         "Route from new router",
		Enabled:             true,
		PeerGroups:          []string{opsGroupID},
		Groups:              []string{devGroupID, opsGroupID},
		AccessControlGroups: []string{devGroupID},
		AccountID:           account.Id,
	}
	account.Routes[routerRoute.ID] = routerRoute
	validated[routerID] = struct{}{}

	if account.Network != nil {
		account.Network.Serial++
	}

	// Legacy path: computed directly from the mutated account.
	legacyMap := account.GetPeerNetworkMap(
		ctx, testingPeerID, dns.CustomZone{}, []*zones.Zone{}, validated,
		account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil, account.GetActiveGroupUsers(),
	)
	normalizeAndSortNetworkMap(legacyMap)
	legacyJSON, err := json.MarshalIndent(toNetworkMapJSON(legacyMap), "", " ")
	require.NoError(t, err, "error marshaling legacy network map to JSON")

	// Builder path: apply the incremental update, then query.
	require.NoError(t, builder.OnPeerAddedIncremental(account, routerID), "error adding router to cache")
	builderMap := builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, nil, validated, nil)
	normalizeAndSortNetworkMap(builderMap)
	newJSON, err := json.MarshalIndent(toNetworkMapJSON(builderMap), "", " ")
	require.NoError(t, err, "error marshaling new network map to JSON")

	if string(legacyJSON) == string(newJSON) {
		return
	}

	// Mismatch: persist both documents so they can be diffed, then fail.
	legacyFilePath := filepath.Join("testdata", "networkmap_golden_with_new_router.json")
	newFilePath := filepath.Join("testdata", "networkmap_golden_new_with_onpeeradded_router.json")
	require.NoError(t, os.MkdirAll(filepath.Dir(legacyFilePath), 0755))
	require.NoError(t, os.WriteFile(legacyFilePath, legacyJSON, 0644))
	t.Logf("Saved legacy network map to %s", legacyFilePath)
	require.NoError(t, os.WriteFile(newFilePath, newJSON, 0644))
	t.Logf("Saved new network map to %s", newFilePath)
	require.JSONEq(t, string(legacyJSON), string(newJSON), "network maps with new router from legacy and new builder do not match")
}
// BenchmarkGetPeerNetworkMap_AfterRouterPeerAdded measures network map
// generation for every pre-existing peer after a new routing peer and its
// route were added to the account. It runs the legacy
// Account.GetPeerNetworkMap path and the NetworkMapBuilder path (which calls
// OnPeerAddedIncremental on every iteration) as separate sub-benchmarks.
func BenchmarkGetPeerNetworkMap_AfterRouterPeerAdded(b *testing.B) {
	account := createTestAccountWithEntities()
	ctx := context.Background()

	// All generated peers are validated and all of them are queried below.
	validated := make(map[string]struct{})
	var existingPeerIDs []string
	for idx := range numPeers {
		id := fmt.Sprintf("peer-%d", idx)
		existingPeerIDs = append(existingPeerIDs, id)
		validated[id] = struct{}{}
	}

	// The builder is created before the router is added to the account.
	builder := types.NewNetworkMapBuilder(account, validated)

	routerID := "peer-new-router-102"
	loginTime := time.Now()
	router := &nbpeer.Peer{
		ID:        routerID,
		IP:        net.IP{100, 64, 1, 2},
		Key:       fmt.Sprintf("key-%s", routerID),
		DNSLabel:  "newrouter102",
		Status:    &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now()},
		UserID:    "user-admin",
		Meta:      nbpeer.PeerSystemMeta{WtVersion: "0.26.0", GoOS: "linux"},
		LastLogin: &loginTime,
	}
	account.Peers[routerID] = router

	for _, groupID := range []string{opsGroupID, allGroupID} {
		if group, ok := account.Groups[groupID]; ok {
			group.Peers = append(group.Peers, routerID)
		}
	}

	routerRoute := &route.Route{
		ID:                  route.ID("route-new-router"),
		Network:             netip.MustParsePrefix("172.16.0.0/24"),
		Peer:                router.Key,
		PeerID:              routerID,
		Description:         "Route from new router",
		Enabled:             true,
		PeerGroups:          []string{opsGroupID},
		Groups:              []string{devGroupID, opsGroupID},
		AccessControlGroups: []string{devGroupID},
		AccountID:           account.Id,
	}
	account.Routes[routerRoute.ID] = routerRoute
	validated[routerID] = struct{}{}

	b.ResetTimer()
	b.Run("old builder after add", func(b *testing.B) {
		for range b.N {
			for _, id := range existingPeerIDs {
				_ = account.GetPeerNetworkMap(ctx, id, dns.CustomZone{}, []*zones.Zone{}, validated, nil, nil, nil, account.GetActiveGroupUsers())
			}
		}
	})

	b.ResetTimer()
	b.Run("new builder after add", func(b *testing.B) {
		for range b.N {
			// The incremental update is part of the measured work; its
			// error is deliberately discarded here.
			_ = builder.OnPeerAddedIncremental(account, routerID)
			for _, id := range existingPeerIDs {
				_ = builder.GetPeerNetworkMap(ctx, id, dns.CustomZone{}, nil, validated, nil)
			}
		}
	})
}
// TestGetPeerNetworkMap_Golden_WithDeletedPeer removes "peer-25" from the
// account, then checks that the legacy Account.GetPeerNetworkMap result and
// the NetworkMapBuilder result (after OnPeerDeleted) serialize to the same
// JSON for testingPeerID. When the two differ, both documents are written
// under testdata before the assertion fails.
func TestGetPeerNetworkMap_Golden_WithDeletedPeer(t *testing.T) {
	account := createTestAccountWithEntities()
	ctx := context.Background()

	// Every generated peer except the offline one counts as validated.
	validated := make(map[string]struct{})
	for idx := range numPeers {
		id := fmt.Sprintf("peer-%d", idx)
		if id != offlinePeerID {
			validated[id] = struct{}{}
		}
	}

	// The builder is created while the peer still exists in the account.
	builder := types.NewNetworkMapBuilder(account, validated)

	removedPeerID := "peer-25"
	isRemoved := func(id string) bool { return id == removedPeerID }

	delete(account.Peers, removedPeerID)
	for _, groupID := range []string{devGroupID, allGroupID} {
		if group, ok := account.Groups[groupID]; ok {
			group.Peers = slices.DeleteFunc(group.Peers, isRemoved)
		}
	}
	delete(validated, removedPeerID)

	if account.Network != nil {
		account.Network.Serial++
	}

	// Legacy path: computed directly from the mutated account.
	legacyMap := account.GetPeerNetworkMap(
		ctx, testingPeerID, dns.CustomZone{}, []*zones.Zone{}, validated,
		account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil, account.GetActiveGroupUsers(),
	)
	normalizeAndSortNetworkMap(legacyMap)
	legacyJSON, err := json.MarshalIndent(toNetworkMapJSON(legacyMap), "", " ")
	require.NoError(t, err, "error marshaling legacy network map to JSON")

	// Builder path: apply the deletion, then query.
	require.NoError(t, builder.OnPeerDeleted(account, removedPeerID), "error deleting peer from cache")
	builderMap := builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, nil, validated, nil)
	normalizeAndSortNetworkMap(builderMap)
	newJSON, err := json.MarshalIndent(toNetworkMapJSON(builderMap), "", " ")
	require.NoError(t, err, "error marshaling new network map to JSON")

	if string(legacyJSON) == string(newJSON) {
		return
	}

	// Mismatch: persist both documents so they can be diffed, then fail.
	legacyFilePath := filepath.Join("testdata", "networkmap_golden_with_deleted_peer.json")
	newFilePath := filepath.Join("testdata", "networkmap_golden_new_with_onpeerdeleted.json")
	require.NoError(t, os.MkdirAll(filepath.Dir(legacyFilePath), 0755))
	require.NoError(t, os.WriteFile(legacyFilePath, legacyJSON, 0644))
	t.Logf("Saved legacy network map to %s", legacyFilePath)
	require.NoError(t, os.WriteFile(newFilePath, newJSON, 0644))
	t.Logf("Saved new network map to %s", newFilePath)
	require.JSONEq(t, string(legacyJSON), string(newJSON), "network maps with deleted peer from legacy and new builder do not match")
}
// TestGetPeerNetworkMap_Golden_WithDeletedRouterPeer removes the routing peer
// "peer-75" together with every route that references it, then checks that
// the legacy Account.GetPeerNetworkMap result and the NetworkMapBuilder result
// (after OnPeerDeleted) serialize to the same JSON for testingPeerID. When the
// two differ, both documents are written under testdata before the assertion
// fails.
func TestGetPeerNetworkMap_Golden_WithDeletedRouterPeer(t *testing.T) {
	account := createTestAccountWithEntities()
	ctx := context.Background()

	// Every generated peer except the offline one counts as validated.
	validated := make(map[string]struct{})
	for idx := range numPeers {
		id := fmt.Sprintf("peer-%d", idx)
		if id != offlinePeerID {
			validated[id] = struct{}{}
		}
	}

	// The builder is created while the router still exists in the account.
	builder := types.NewNetworkMapBuilder(account, validated)

	routerID := "peer-75"

	// Precondition: the fixture must contain at least one route served by
	// this peer, otherwise the scenario is meaningless.
	var routerRoute *route.Route
	for _, candidate := range account.Routes {
		if candidate.PeerID == routerID {
			routerRoute = candidate
			break
		}
	}
	require.NotNil(t, routerRoute, "Router peer should have a route")

	// Drop the router from every group.
	isRouter := func(id string) bool { return id == routerID }
	for _, group := range account.Groups {
		group.Peers = slices.DeleteFunc(group.Peers, isRouter)
	}

	// Drop every route that points at the router, by key or by ID. The key
	// has to be read before the peer itself is removed from the account.
	routerKey := account.Peers[routerID].Key
	for id, candidate := range account.Routes {
		if candidate.Peer == routerKey || candidate.PeerID == routerID {
			delete(account.Routes, id)
		}
	}

	delete(account.Peers, routerID)
	delete(validated, routerID)

	if account.Network != nil {
		account.Network.Serial++
	}

	// Legacy path: computed directly from the mutated account.
	legacyMap := account.GetPeerNetworkMap(
		ctx, testingPeerID, dns.CustomZone{}, []*zones.Zone{}, validated,
		account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil, account.GetActiveGroupUsers(),
	)
	normalizeAndSortNetworkMap(legacyMap)
	legacyJSON, err := json.MarshalIndent(toNetworkMapJSON(legacyMap), "", " ")
	require.NoError(t, err, "error marshaling legacy network map to JSON")

	// Builder path: apply the deletion, then query.
	require.NoError(t, builder.OnPeerDeleted(account, routerID), "error deleting routing peer from cache")
	builderMap := builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, nil, validated, nil)
	normalizeAndSortNetworkMap(builderMap)
	newJSON, err := json.MarshalIndent(toNetworkMapJSON(builderMap), "", " ")
	require.NoError(t, err, "error marshaling new network map to JSON")

	if string(legacyJSON) == string(newJSON) {
		return
	}

	// Mismatch: persist both documents so they can be diffed, then fail.
	legacyFilePath := filepath.Join("testdata", "networkmap_golden_with_deleted_router_peer.json")
	newFilePath := filepath.Join("testdata", "networkmap_golden_new_with_deleted_router.json")
	require.NoError(t, os.MkdirAll(filepath.Dir(legacyFilePath), 0755))
	require.NoError(t, os.WriteFile(legacyFilePath, legacyJSON, 0644))
	t.Logf("Saved legacy network map to %s", legacyFilePath)
	require.NoError(t, os.WriteFile(newFilePath, newJSON, 0644))
	t.Logf("Saved new network map to %s", newFilePath)
	require.JSONEq(t, string(legacyJSON), string(newJSON), "network maps with deleted router from legacy and new builder do not match")
}
// BenchmarkGetPeerNetworkMap_AfterPeerDeleted measures network map generation
// after "peer-25" has been removed from the account. It runs the legacy
// Account.GetPeerNetworkMap path and the NetworkMapBuilder path (which calls
// OnPeerDeleted on every iteration) as separate sub-benchmarks.
//
// NOTE(review): the queried ID list still contains the deleted peer, and the
// builder is constructed after the deletion has already been applied to the
// account - confirm both are intentional.
func BenchmarkGetPeerNetworkMap_AfterPeerDeleted(b *testing.B) {
	account := createTestAccountWithEntities()
	ctx := context.Background()

	validated := make(map[string]struct{})
	var queriedPeerIDs []string
	for idx := range numPeers {
		id := fmt.Sprintf("peer-%d", idx)
		queriedPeerIDs = append(queriedPeerIDs, id)
		validated[id] = struct{}{}
	}

	removedPeerID := "peer-25"
	isRemoved := func(id string) bool { return id == removedPeerID }

	delete(account.Peers, removedPeerID)
	for _, groupID := range []string{devGroupID, allGroupID} {
		group := account.Groups[groupID]
		group.Peers = slices.DeleteFunc(group.Peers, isRemoved)
	}
	delete(validated, removedPeerID)

	builder := types.NewNetworkMapBuilder(account, validated)

	b.ResetTimer()
	b.Run("old builder after delete", func(b *testing.B) {
		for range b.N {
			for _, id := range queriedPeerIDs {
				_ = account.GetPeerNetworkMap(ctx, id, dns.CustomZone{}, []*zones.Zone{}, validated, nil, nil, nil, account.GetActiveGroupUsers())
			}
		}
	})

	b.ResetTimer()
	b.Run("new builder after delete", func(b *testing.B) {
		for range b.N {
			// The deletion hook is part of the measured work; its error is
			// deliberately discarded here.
			_ = builder.OnPeerDeleted(account, removedPeerID)
			for _, id := range queriedPeerIDs {
				_ = builder.GetPeerNetworkMap(ctx, id, dns.CustomZone{}, nil, validated, nil)
			}
		}
	})
}
// normalizeAndSortNetworkMap rewrites networkMap in place so that two maps
// describing the same state serialize identically:
//
//   - Status.LastSeen and LastLogin of every peer (online and offline) are
//     reset to the zero time, because they are wall-clock values.
//   - Peers, OfflinePeers and Routes are ordered by ID.
//   - FirewallRules are ordered by PeerIP, Protocol, Direction, Action, Port.
//   - The SourceRanges of every route firewall rule are ordered, and then
//     RoutesFirewallRules are ordered by RouteID, Action, Destination, first
//     source range (when both rules have one) and Port.
//
// networkMap must be non-nil.
func normalizeAndSortNetworkMap(networkMap *types.NetworkMap) {
	// Both peer lists get the same timestamp scrubbing.
	resetTimestamps := func(peers []*nbpeer.Peer) {
		for _, peer := range peers {
			if peer.Status != nil {
				peer.Status.LastSeen = time.Time{}
			}
			peer.LastLogin = &time.Time{}
		}
	}
	resetTimestamps(networkMap.Peers)
	resetTimestamps(networkMap.OfflinePeers)

	sort.Slice(networkMap.Peers, func(i, j int) bool { return networkMap.Peers[i].ID < networkMap.Peers[j].ID })
	sort.Slice(networkMap.OfflinePeers, func(i, j int) bool { return networkMap.OfflinePeers[i].ID < networkMap.OfflinePeers[j].ID })
	sort.Slice(networkMap.Routes, func(i, j int) bool { return networkMap.Routes[i].ID < networkMap.Routes[j].ID })

	sort.Slice(networkMap.FirewallRules, func(i, j int) bool {
		r1, r2 := networkMap.FirewallRules[i], networkMap.FirewallRules[j]
		if r1.PeerIP != r2.PeerIP {
			return r1.PeerIP < r2.PeerIP
		}
		if r1.Protocol != r2.Protocol {
			return r1.Protocol < r2.Protocol
		}
		if r1.Direction != r2.Direction {
			return r1.Direction < r2.Direction
		}
		if r1.Action != r2.Action {
			return r1.Action < r2.Action
		}
		return r1.Port < r2.Port
	})

	// Order each rule's source ranges BEFORE ordering the rules themselves:
	// the rule comparator below uses SourceRanges[0] as a tie-breaker, so it
	// must see the canonical first element rather than whatever order the
	// producer happened to emit.
	for _, rule := range networkMap.RoutesFirewallRules {
		ranges := rule.SourceRanges
		sort.Slice(ranges, func(i, j int) bool { return ranges[i] < ranges[j] })
	}

	sort.Slice(networkMap.RoutesFirewallRules, func(i, j int) bool {
		r1, r2 := networkMap.RoutesFirewallRules[i], networkMap.RoutesFirewallRules[j]
		if r1.RouteID != r2.RouteID {
			return r1.RouteID < r2.RouteID
		}
		if r1.Action != r2.Action {
			return r1.Action < r2.Action
		}
		if r1.Destination != r2.Destination {
			return r1.Destination < r2.Destination
		}
		if len(r1.SourceRanges) > 0 && len(r2.SourceRanges) > 0 {
			if r1.SourceRanges[0] != r2.SourceRanges[0] {
				return r1.SourceRanges[0] < r2.SourceRanges[0]
			}
		}
		return r1.Port < r2.Port
	})
}
// networkMapJSON is the serialization shape used by the golden tests when a
// network map is compared as JSON. It carries the network map fields that
// toNetworkMapJSON copies over, each under an explicit JSON key.
type networkMapJSON struct {
Peers []*nbpeer.Peer `json:"Peers"`
Network *types.Network `json:"Network"`
Routes []*route.Route `json:"Routes"`
DNSConfig dns.Config `json:"DNSConfig"`
OfflinePeers []*nbpeer.Peer `json:"OfflinePeers"`
FirewallRules []*types.FirewallRule `json:"FirewallRules"`
RoutesFirewallRules []*types.RouteFirewallRule `json:"RoutesFirewallRules"`
ForwardingRules []*types.ForwardingRule `json:"ForwardingRules"`
// AuthorizedUsers maps each key of the source map to a sorted list of the
// IDs stored under it; the field is omitted from the output when empty.
AuthorizedUsers map[string][]string `json:"AuthorizedUsers,omitempty"`
EnableSSH bool `json:"EnableSSH"`
}
// toNetworkMapJSON converts nm into its networkMapJSON representation.
// Plain fields are copied as-is. AuthorizedUsers, when non-empty, is
// flattened: for every local user the set of user IDs becomes a sorted string
// slice, so the marshaled output does not depend on map iteration order.
// When AuthorizedUsers is empty the field is left nil and therefore omitted
// from the JSON.
func toNetworkMapJSON(nm *types.NetworkMap) *networkMapJSON {
	out := new(networkMapJSON)
	out.Peers = nm.Peers
	out.Network = nm.Network
	out.Routes = nm.Routes
	out.DNSConfig = nm.DNSConfig
	out.OfflinePeers = nm.OfflinePeers
	out.FirewallRules = nm.FirewallRules
	out.RoutesFirewallRules = nm.RoutesFirewallRules
	out.ForwardingRules = nm.ForwardingRules
	out.EnableSSH = nm.EnableSSH

	if len(nm.AuthorizedUsers) == 0 {
		return out
	}

	// The outer keys land in a map, so their visiting order is irrelevant;
	// only the per-user ID lists need an explicit order.
	out.AuthorizedUsers = make(map[string][]string)
	for localUser, idSet := range nm.AuthorizedUsers {
		ids := make([]string, 0, len(idSet))
		for id := range idSet {
			ids = append(ids, id)
		}
		sort.Strings(ids)
		out.AuthorizedUsers[localUser] = ids
	}
	return out
}
// createTestAccountWithEntities builds the in-memory account fixture shared by
// the golden tests and benchmarks in this file. It contains:
//
//   - numPeers peers named "peer-<i>" with addresses 100.64.0.<i+1>; even
//     indexes report WtVersion "0.40.0", odd ones "0.25.0". The peer whose ID
//     equals expiredPeerID has login expiration enabled and a LastLogin two
//     hours in the past.
//   - four groups: all peers, developers (first half of the peers),
//     operations (second half) and an initially empty SSH users group.
//   - five policies, three routes, three users, one nameserver group, one
//     posture check, and one network with a resource and a router.
//
// Every policy and route is stamped with the account ID before returning.
func createTestAccountWithEntities() *types.Account {
	var (
		peers         = make(map[string]*nbpeer.Peer)
		allGroupPeers = []string{}
		devGroupPeers = []string{}
		opsGroupPeers = []string{}
	)

	for i := range numPeers {
		peerID := fmt.Sprintf("peer-%d", i)

		// Alternate client versions between even and odd peers.
		wtVersion := "0.40.0"
		if i%2 != 0 {
			wtVersion = "0.25.0"
		}

		peer := &nbpeer.Peer{
			ID:       peerID,
			IP:       net.IP{100, 64, 0, byte(i + 1)},
			Key:      fmt.Sprintf("key-%s", peerID),
			DNSLabel: fmt.Sprintf("peer%d", i+1),
			Status:   &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now()},
			UserID:   "user-admin",
			Meta:     nbpeer.PeerSystemMeta{WtVersion: wtVersion, GoOS: "linux"},
		}
		if peerID == expiredPeerID {
			lastLogin := time.Now().Add(-2 * time.Hour)
			peer.LoginExpirationEnabled = true
			peer.LastLogin = &lastLogin
		}
		peers[peerID] = peer

		// First half of the peers are developers, the rest operations.
		allGroupPeers = append(allGroupPeers, peerID)
		if i >= numPeers/2 {
			opsGroupPeers = append(opsGroupPeers, peerID)
		} else {
			devGroupPeers = append(devGroupPeers, peerID)
		}
	}

	groups := map[string]*types.Group{
		allGroupID:      {ID: allGroupID, Name: "All", Peers: allGroupPeers},
		devGroupID:      {ID: devGroupID, Name: "Developers", Peers: devGroupPeers},
		opsGroupID:      {ID: opsGroupID, Name: "Operations", Peers: opsGroupPeers},
		sshUsersGroupID: {ID: sshUsersGroupID, Name: "SSH Users", Peers: []string{}},
	}

	policies := []*types.Policy{
		{
			ID:      policyIDAll,
			Name:    "Default-Allow",
			Enabled: true,
			Rules: []*types.PolicyRule{{
				ID:            policyIDAll,
				Name:          "Allow All",
				Enabled:       true,
				Action:        types.PolicyTrafficActionAccept,
				Protocol:      types.PolicyRuleProtocolALL,
				Bidirectional: true,
				Sources:       []string{allGroupID},
				Destinations:  []string{allGroupID},
			}},
		},
		{
			ID:      policyIDDevOps,
			Name:    "Dev to Ops Web Access",
			Enabled: true,
			Rules: []*types.PolicyRule{{
				ID:            policyIDDevOps,
				Name:          "Dev -> Ops (HTTP Range)",
				Enabled:       true,
				Action:        types.PolicyTrafficActionAccept,
				Protocol:      types.PolicyRuleProtocolTCP,
				Bidirectional: false,
				PortRanges:    []types.RulePortRange{{Start: 8080, End: 8090}},
				Sources:       []string{devGroupID},
				Destinations:  []string{opsGroupID},
			}},
		},
		{
			ID:      policyIDDrop,
			Name:    "Drop DB traffic",
			Enabled: true,
			Rules: []*types.PolicyRule{{
				ID:            policyIDDrop,
				Name:          "Drop DB",
				Enabled:       true,
				Action:        types.PolicyTrafficActionDrop,
				Protocol:      types.PolicyRuleProtocolTCP,
				Ports:         []string{"5432"},
				Bidirectional: true,
				Sources:       []string{devGroupID},
				Destinations:  []string{opsGroupID},
			}},
		},
		{
			ID:                  policyIDPosture,
			Name:                "Posture Check for DB Resource",
			Enabled:             true,
			SourcePostureChecks: []string{postureCheckID},
			Rules: []*types.PolicyRule{{
				ID:                  policyIDPosture,
				Name:                "Allow DB Access",
				Enabled:             true,
				Action:              types.PolicyTrafficActionAccept,
				Protocol:            types.PolicyRuleProtocolALL,
				Bidirectional:       true,
				Sources:             []string{opsGroupID},
				DestinationResource: types.Resource{ID: networkResourceID},
			}},
		},
		{
			ID:      policyIDSSH,
			Name:    "SSH Access Policy",
			Enabled: true,
			Rules: []*types.PolicyRule{{
				ID:               policyIDSSH,
				Name:             "Allow SSH to Ops",
				Enabled:          true,
				Action:           types.PolicyTrafficActionAccept,
				Protocol:         types.PolicyRuleProtocolNetbirdSSH,
				Bidirectional:    false,
				Sources:          []string{devGroupID},
				Destinations:     []string{opsGroupID},
				AuthorizedGroups: map[string][]string{sshUsersGroupID: {"root", "admin"}},
			}},
		},
	}

	routes := map[route.ID]*route.Route{
		routeID: {
			ID:                  routeID,
			Network:             netip.MustParsePrefix("192.168.10.0/24"),
			Peer:                peers["peer-75"].Key,
			PeerID:              "peer-75",
			Description:         "Route to internal resource",
			Enabled:             true,
			PeerGroups:          []string{devGroupID, opsGroupID},
			Groups:              []string{devGroupID, opsGroupID},
			AccessControlGroups: []string{devGroupID},
		},
		routeHA1ID: {
			ID:                  routeHA1ID,
			Network:             netip.MustParsePrefix("10.10.0.0/16"),
			Peer:                peers["peer-80"].Key,
			PeerID:              "peer-80",
			Description:         "HA Route 1",
			Enabled:             true,
			Metric:              1000,
			PeerGroups:          []string{allGroupID},
			Groups:              []string{allGroupID},
			AccessControlGroups: []string{allGroupID},
		},
		routeHA2ID: {
			ID:                  routeHA2ID,
			Network:             netip.MustParsePrefix("10.10.0.0/16"),
			Peer:                peers["peer-90"].Key,
			PeerID:              "peer-90",
			Description:         "HA Route 2",
			Enabled:             true,
			Metric:              900,
			PeerGroups:          []string{devGroupID, opsGroupID},
			Groups:              []string{devGroupID, opsGroupID},
			AccessControlGroups: []string{allGroupID},
		},
	}

	users := map[string]*types.User{
		userAdminID: {
			Id:            userAdminID,
			Role:          types.UserRoleAdmin,
			IsServiceUser: false,
			AccountID:     testAccountID,
			AutoGroups:    []string{allGroupID},
		},
		userDevID: {
			Id:            userDevID,
			Role:          types.UserRoleUser,
			IsServiceUser: false,
			AccountID:     testAccountID,
			AutoGroups:    []string{sshUsersGroupID, devGroupID},
		},
		userOpsID: {
			Id:            userOpsID,
			Role:          types.UserRoleUser,
			IsServiceUser: false,
			AccountID:     testAccountID,
			AutoGroups:    []string{sshUsersGroupID, opsGroupID},
		},
	}

	account := &types.Account{
		Id:       testAccountID,
		Peers:    peers,
		Groups:   groups,
		Policies: policies,
		Routes:   routes,
		Users:    users,
		Network: &types.Network{
			Identifier: "net-golden-test",
			Net:        net.IPNet{IP: net.IP{100, 64, 0, 0}, Mask: net.CIDRMask(16, 32)},
			Serial:     1,
		},
		DNSSettings: types.DNSSettings{DisabledManagementGroups: []string{opsGroupID}},
		NameServerGroups: map[string]*dns.NameServerGroup{
			nameserverGroupID: {
				ID:      nameserverGroupID,
				Name:    "Main NS",
				Enabled: true,
				Groups:  []string{devGroupID},
				NameServers: []dns.NameServer{
					{IP: netip.MustParseAddr("8.8.8.8"), NSType: dns.UDPNameServerType, Port: 53},
				},
			},
		},
		PostureChecks: []*posture.Checks{
			{
				ID:   postureCheckID,
				Name: "Check version",
				Checks: posture.ChecksDefinition{
					NBVersionCheck: &posture.NBVersionCheck{MinVersion: "0.26.0"},
				},
			},
		},
		NetworkResources: []*resourceTypes.NetworkResource{
			{ID: networkResourceID, NetworkID: networkID, AccountID: testAccountID, Enabled: true, Address: "db.netbird.cloud"},
		},
		Networks: []*networkTypes.Network{
			{ID: networkID, Name: "DB Network", AccountID: testAccountID},
		},
		NetworkRouters: []*routerTypes.NetworkRouter{
			{ID: networkRouterID, NetworkID: networkID, Peer: routingPeerID, Enabled: true, AccountID: testAccountID},
		},
		Settings: &types.Settings{PeerLoginExpirationEnabled: true, PeerLoginExpiration: 1 * time.Hour},
	}

	// Back-fill the owning account on entities declared without it above.
	for _, policy := range account.Policies {
		policy.AccountID = account.Id
	}
	for _, rt := range account.Routes {
		rt.AccountID = account.Id
	}

	return account
}
// TestGetPeerNetworkMap_Golden_New_WithOnPeerAddedRouter_Batched adds a new
// routing peer through the builder's EnqueuePeersForIncrementalAdd entry
// point, waits briefly, and then serializes the builder's network map for
// testingPeerID.
//
// NOTE(review): the golden file is rewritten with the freshly produced JSON
// immediately before it is read back and compared, so the final assertion
// compares the output with itself - confirm this regenerate-always behavior
// is intended.
func TestGetPeerNetworkMap_Golden_New_WithOnPeerAddedRouter_Batched(t *testing.T) {
	account := createTestAccountWithEntities()
	ctx := context.Background()

	// Every generated peer except the offline one counts as validated.
	validated := make(map[string]struct{})
	for idx := range numPeers {
		id := fmt.Sprintf("peer-%d", idx)
		if id != offlinePeerID {
			validated[id] = struct{}{}
		}
	}

	// The builder is created before the router is added to the account.
	builder := types.NewNetworkMapBuilder(account, validated)

	routerID := "peer-new-router-102"
	loginTime := time.Now()
	router := &nbpeer.Peer{
		ID:        routerID,
		IP:        net.IP{100, 64, 1, 2},
		Key:       fmt.Sprintf("key-%s", routerID),
		DNSLabel:  "newrouter102",
		Status:    &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now()},
		UserID:    "user-admin",
		Meta:      nbpeer.PeerSystemMeta{WtVersion: "0.26.0", GoOS: "linux"},
		LastLogin: &loginTime,
	}
	account.Peers[routerID] = router

	for _, groupID := range []string{opsGroupID, allGroupID} {
		if group, ok := account.Groups[groupID]; ok {
			group.Peers = append(group.Peers, routerID)
		}
	}

	routerRoute := &route.Route{
		ID:                  route.ID("route-new-router"),
		Network:             netip.MustParsePrefix("172.16.0.0/24"),
		Peer:                router.Key,
		PeerID:              routerID,
		Description:         "Route from new router",
		Enabled:             true,
		PeerGroups:          []string{opsGroupID},
		Groups:              []string{devGroupID, opsGroupID},
		AccessControlGroups: []string{devGroupID},
		AccountID:           account.Id,
	}
	account.Routes[routerRoute.ID] = routerRoute
	validated[routerID] = struct{}{}

	if account.Network != nil {
		account.Network.Serial++
	}

	// Hand the peer to the builder, then pause before querying.
	// NOTE(review): presumably the enqueue is processed asynchronously and
	// the sleep gives it time to finish - verify against the builder.
	builder.EnqueuePeersForIncrementalAdd(account, routerID)
	time.Sleep(100 * time.Millisecond)

	builderMap := builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, nil, validated, nil)
	normalizeAndSortNetworkMap(builderMap)
	jsonData, err := json.MarshalIndent(builderMap, "", " ")
	require.NoError(t, err, "error marshaling network map to JSON")

	goldenFilePath := filepath.Join("testdata", "networkmap_golden_new_with_onpeeradded_router.json")

	t.Log("Update golden file with OnPeerAdded router...")
	require.NoError(t, os.MkdirAll(filepath.Dir(goldenFilePath), 0755))
	require.NoError(t, os.WriteFile(goldenFilePath, jsonData, 0644))

	expectedJSON, err := os.ReadFile(goldenFilePath)
	require.NoError(t, err, "error reading golden file")
	require.JSONEq(t, string(expectedJSON), string(jsonData), "network map from NEW builder with OnPeerAdded router does not match golden file")
}

File diff suppressed because it is too large Load Diff

View File

@@ -146,11 +146,7 @@ func (c *ClaimsExtractor) ToGroups(token *jwt.Token, claimName string) []string
userJWTGroups := make([]string, 0)
if claim, ok := claims[claimName]; ok {
switch claimGroups := claim.(type) {
case string:
// Some IdPs emit a single group claim as a string instead of an array.
userJWTGroups = append(userJWTGroups, claimGroups)
case []any:
if claimGroups, ok := claim.([]interface{}); ok {
for _, g := range claimGroups {
if group, ok := g.(string); ok {
userJWTGroups = append(userJWTGroups, group)
@@ -158,11 +154,9 @@ func (c *ClaimsExtractor) ToGroups(token *jwt.Token, claimName string) []string
log.Debugf("JWT claim %q contains a non-string group (type: %T): %v", claimName, g, g)
}
}
default:
log.Debugf("JWT claim %q is not a string or string array (type: %T): %v", claimName, claim, claim)
}
} else {
log.Debugf("JWT claim %q is missing", claimName)
log.Debugf("JWT claim %q is not a string array", claimName)
}
return userJWTGroups

View File

@@ -249,15 +249,6 @@ func TestClaimsExtractor_ToGroups(t *testing.T) {
groupClaimName: "groups",
expectedGroups: []string{},
},
{
name: "extracts single group string from claim",
claims: jwt.MapClaims{
"sub": "user-123",
"groups": "admin",
},
groupClaimName: "groups",
expectedGroups: []string{"admin"},
},
{
name: "handles custom claim name",
claims: jwt.MapClaims{

View File

@@ -252,19 +252,21 @@ func (c *GrpcClient) handleJobStream(
c.notifyDisconnected(err)
return backoff.Permanent(err) // unrecoverable error, propagate to the upper layer
case codes.Canceled:
log.Debugf("job stream context has been canceled, this usually indicates shutdown")
log.Debugf("management connection context has been canceled, this usually indicates shutdown")
return err
case codes.Unimplemented:
log.Warn("Job feature is not supported by the current management server version. " +
"Please update the management service to use this feature.")
return nil
default:
log.Warnf("job stream disconnected, will retry silently. Reason: %v", err)
c.notifyDisconnected(err)
log.Warnf("disconnected from the Management service but will retry silently. Reason: %v", err)
return err
}
} else {
// non-gRPC error
log.Warnf("job stream disconnected, will retry silently. Reason: %v", err)
c.notifyDisconnected(err)
log.Warnf("disconnected from the Management service but will retry silently. Reason: %v", err)
return err
}
}

View File

@@ -2917,7 +2917,6 @@ components:
- okta
- pocketid
- microsoft
- adfs
example: oidc
IdentityProvider:
type: object

View File

@@ -518,7 +518,6 @@ const (
IdentityProviderTypeOkta IdentityProviderType = "okta"
IdentityProviderTypePocketid IdentityProviderType = "pocketid"
IdentityProviderTypeZitadel IdentityProviderType = "zitadel"
IdentityProviderTypeAdfs IdentityProviderType = "adfs"
)
// Valid indicates whether the value is a known member of the IdentityProviderType enum.
@@ -538,8 +537,6 @@ func (e IdentityProviderType) Valid() bool {
return true
case IdentityProviderTypeZitadel:
return true
case IdentityProviderTypeAdfs:
return true
default:
return false
}

View File

@@ -8,7 +8,10 @@ import (
log "github.com/sirupsen/logrus"
)
const defaultMaxBackoffInterval = 60 * time.Second
const (
// TODO: make it configurable, the manager should validate all configurable parameters
reconnectingTimeout = 60 * time.Second
)
// Guard manages the reconnection attempts to the Relay server in case of a disconnection event.
type Guard struct {
@@ -16,23 +19,14 @@ type Guard struct {
OnNewRelayClient chan *Client
OnReconnected chan struct{}
serverPicker *ServerPicker
// maxBackoffInterval caps the exponential backoff between reconnect
// attempts.
maxBackoffInterval time.Duration
}
// NewGuard creates a new guard for the relay client. A non-positive
// maxBackoffInterval falls back to defaultMaxBackoffInterval.
func NewGuard(sp *ServerPicker, maxBackoffInterval time.Duration) *Guard {
if maxBackoffInterval <= 0 {
maxBackoffInterval = defaultMaxBackoffInterval
}
// NewGuard creates a new guard for the relay client.
func NewGuard(sp *ServerPicker) *Guard {
g := &Guard{
OnNewRelayClient: make(chan *Client, 1),
OnReconnected: make(chan struct{}, 1),
serverPicker: sp,
maxBackoffInterval: maxBackoffInterval,
OnNewRelayClient: make(chan *Client, 1),
OnReconnected: make(chan struct{}, 1),
serverPicker: sp,
}
return g
}
@@ -55,7 +49,7 @@ func (g *Guard) StartReconnectTrys(ctx context.Context, relayClient *Client) {
}
// start a ticker to pick a new server
ticker := g.exponentTicker(ctx)
ticker := exponentTicker(ctx)
defer ticker.Stop()
for {
@@ -131,11 +125,11 @@ func (g *Guard) notifyReconnected() {
}
}
func (g *Guard) exponentTicker(ctx context.Context) *backoff.Ticker {
func exponentTicker(ctx context.Context) *backoff.Ticker {
bo := backoff.WithContext(&backoff.ExponentialBackOff{
InitialInterval: 2 * time.Second,
Multiplier: 2,
MaxInterval: g.maxBackoffInterval,
MaxInterval: reconnectingTimeout,
Clock: backoff.SystemClock,
}, ctx)

View File

@@ -39,15 +39,6 @@ func NewRelayTrack() *RelayTrack {
type OnServerCloseListener func()
// ManagerOption configures a Manager at construction time.
type ManagerOption func(*Manager)
// WithMaxBackoffInterval caps the exponential backoff between reconnect
// attempts to the home relay. A non-positive value keeps the default.
func WithMaxBackoffInterval(d time.Duration) ManagerOption {
return func(m *Manager) { m.maxBackoffInterval = d }
}
// Manager is a manager for the relay client instances. It establishes one persistent connection to the given relay URL
// and automatically reconnects to it in case of disconnection.
// The manager also manage temporary relay connection. If a client wants to communicate with a client on a
@@ -73,13 +64,12 @@ type Manager struct {
onReconnectedListenerFn func()
listenerLock sync.Mutex
mtu uint16
maxBackoffInterval time.Duration
mtu uint16
}
// NewManager creates a new manager instance.
// The serverURL address can be empty. In this case, the manager will not serve.
func NewManager(ctx context.Context, serverURLs []string, peerID string, mtu uint16, opts ...ManagerOption) *Manager {
func NewManager(ctx context.Context, serverURLs []string, peerID string, mtu uint16) *Manager {
tokenStore := &relayAuth.TokenStore{}
m := &Manager{
@@ -96,11 +86,8 @@ func NewManager(ctx context.Context, serverURLs []string, peerID string, mtu uin
relayClients: make(map[string]*RelayTrack),
onDisconnectedListeners: make(map[string]*list.List),
}
for _, opt := range opts {
opt(m)
}
m.serverPicker.ServerURLs.Store(serverURLs)
m.reconnectGuard = NewGuard(m.serverPicker, m.maxBackoffInterval)
m.reconnectGuard = NewGuard(m.serverPicker)
return m
}
@@ -303,36 +290,19 @@ func (m *Manager) onServerConnected() {
go m.onReconnectedListenerFn()
}
// onServerDisconnected handles relay disconnect events. For the home server it
// starts the reconnect guard. For foreign servers it evicts the now-dead client
// from the cache so the next OpenConn builds a fresh one instead of reusing a
// closed client.
// onServerDisconnected starts reconnection for the home server only
func (m *Manager) onServerDisconnected(serverAddress string) {
m.relayClientMu.Lock()
isHome := m.relayClient != nil && serverAddress == m.relayClient.connectionURL
if isHome {
if serverAddress == m.relayClient.connectionURL {
go func(client *Client) {
m.reconnectGuard.StartReconnectTrys(m.ctx, client)
}(m.relayClient)
}
m.relayClientMu.Unlock()
if !isHome {
m.evictForeignRelay(serverAddress)
}
m.notifyOnDisconnectListeners(serverAddress)
}
func (m *Manager) evictForeignRelay(serverAddress string) {
m.relayClientsMutex.Lock()
defer m.relayClientsMutex.Unlock()
if _, ok := m.relayClients[serverAddress]; ok {
delete(m.relayClients, serverAddress)
log.Debugf("evicted disconnected foreign relay client: %s", serverAddress)
}
}
func (m *Manager) listenGuardEvent(ctx context.Context) {
for {
select {

View File

@@ -2,7 +2,6 @@ package client
import (
"context"
"fmt"
"testing"
"time"
@@ -361,8 +360,7 @@ func TestAutoReconnect(t *testing.T) {
t.Fatalf("failed to serve manager: %s", err)
}
clientAlice := NewManager(mCtx, toURL(srvCfg), "alice", iface.DefaultMTU,
WithMaxBackoffInterval(2*time.Second))
clientAlice := NewManager(mCtx, toURL(srvCfg), "alice", iface.DefaultMTU)
err = clientAlice.Serve()
if err != nil {
t.Fatalf("failed to serve manager: %s", err)
@@ -386,9 +384,7 @@ func TestAutoReconnect(t *testing.T) {
}
log.Infof("waiting for reconnection")
if err := waitForReady(ctx, clientAlice, 15*time.Second); err != nil {
t.Fatalf("manager did not reconnect: %s", err)
}
time.Sleep(reconnectingTimeout + 1*time.Second)
log.Infof("reopent the connection")
_, err = clientAlice.OpenConn(ctx, ra, "bob")
@@ -397,21 +393,6 @@ func TestAutoReconnect(t *testing.T) {
}
}
func waitForReady(ctx context.Context, m *Manager, timeout time.Duration) error {
deadline := time.Now().Add(timeout)
for time.Now().Before(deadline) {
if m.Ready() {
return nil
}
select {
case <-time.After(100 * time.Millisecond):
case <-ctx.Done():
return ctx.Err()
}
}
return fmt.Errorf("manager not ready within %s", timeout)
}
func TestNotifierDoubleAdd(t *testing.T) {
ctx := context.Background()