mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-25 02:42:08 -04:00
Compare commits
5 Commits
github-iss
...
revert-dns
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
b1c0d90c64 | ||
|
|
f732b01a05 | ||
|
|
c07c726ea7 | ||
|
|
fa0d58d093 | ||
|
|
b6038e8acd |
@@ -1,26 +0,0 @@
|
||||
You are a GitHub issue resolution classifier.
|
||||
|
||||
Your job is to decide whether an open GitHub issue is:
|
||||
- AUTO_CLOSE
|
||||
- MANUAL_REVIEW
|
||||
- KEEP_OPEN
|
||||
|
||||
Rules:
|
||||
1. AUTO_CLOSE is only allowed if there is objective, hard evidence:
|
||||
- a merged linked PR that clearly resolves the issue, or
|
||||
- an explicit maintainer/member/owner/collaborator comment saying the issue is fixed, resolved, duplicate, or superseded
|
||||
2. If there is any contradictory later evidence, do NOT AUTO_CLOSE.
|
||||
3. If evidence is promising but not airtight, choose MANUAL_REVIEW.
|
||||
4. If the issue still appears active or unresolved, choose KEEP_OPEN.
|
||||
5. Do not invent evidence.
|
||||
6. Output valid JSON only.
|
||||
|
||||
Maintainer-authoritative roles:
|
||||
- MEMBER
|
||||
- OWNER
|
||||
- COLLABORATOR
|
||||
|
||||
Important:
|
||||
- Later comments outweigh earlier ones.
|
||||
- A non-maintainer saying "fixed for me" is not enough for AUTO_CLOSE.
|
||||
- If uncertain, prefer MANUAL_REVIEW or KEEP_OPEN.
|
||||
@@ -1,78 +0,0 @@
|
||||
{
|
||||
"type": "object",
|
||||
"additionalProperties": false,
|
||||
"required": [
|
||||
"decision",
|
||||
"reason_code",
|
||||
"confidence",
|
||||
"hard_signals",
|
||||
"contradictions",
|
||||
"summary",
|
||||
"close_comment",
|
||||
"manual_review_note"
|
||||
],
|
||||
"properties": {
|
||||
"decision": {
|
||||
"type": "string",
|
||||
"enum": ["AUTO_CLOSE", "MANUAL_REVIEW", "KEEP_OPEN"]
|
||||
},
|
||||
"reason_code": {
|
||||
"type": "string",
|
||||
"enum": [
|
||||
"resolved_by_merged_pr",
|
||||
"maintainer_confirmed_resolved",
|
||||
"duplicate_confirmed",
|
||||
"superseded_confirmed",
|
||||
"likely_fixed_but_unconfirmed",
|
||||
"still_open",
|
||||
"unclear"
|
||||
]
|
||||
},
|
||||
"confidence": {
|
||||
"type": "number",
|
||||
"minimum": 0,
|
||||
"maximum": 1
|
||||
},
|
||||
"hard_signals": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "object",
|
||||
"required": ["type", "url"],
|
||||
"properties": {
|
||||
"type": {
|
||||
"type": "string",
|
||||
"enum": [
|
||||
"merged_pr",
|
||||
"maintainer_comment",
|
||||
"duplicate_reference",
|
||||
"superseded_reference"
|
||||
]
|
||||
},
|
||||
"url": { "type": "string" }
|
||||
}
|
||||
}
|
||||
},
|
||||
"contradictions": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "object",
|
||||
"required": ["type", "url"],
|
||||
"properties": {
|
||||
"type": {
|
||||
"type": "string",
|
||||
"enum": [
|
||||
"reporter_still_broken",
|
||||
"later_unresolved_comment",
|
||||
"ambiguous_pr_link",
|
||||
"other"
|
||||
]
|
||||
},
|
||||
"url": { "type": "string" }
|
||||
}
|
||||
}
|
||||
},
|
||||
"summary": { "type": "string" },
|
||||
"close_comment": { "type": "string" },
|
||||
"manual_review_note": { "type": "string" }
|
||||
}
|
||||
}
|
||||
152
.github/issue-resolution/scripts/apply-decisions.mjs
vendored
152
.github/issue-resolution/scripts/apply-decisions.mjs
vendored
@@ -1,152 +0,0 @@
|
||||
import fs from "node:fs/promises";
|
||||
|
||||
const decisions = JSON.parse(await fs.readFile("decisions.json", "utf8"));
|
||||
const dryRun = String(process.env.DRY_RUN).toLowerCase() === "true";
|
||||
|
||||
const headers = {
|
||||
Authorization: `Bearer ${process.env.GH_TOKEN}`,
|
||||
Accept: "application/vnd.github+json",
|
||||
"X-GitHub-Api-Version": "2022-11-28",
|
||||
};
|
||||
|
||||
async function rest(url, method = "GET", body) {
|
||||
const res = await fetch(url, {
|
||||
method,
|
||||
headers,
|
||||
body: body ? JSON.stringify(body) : undefined
|
||||
});
|
||||
if (!res.ok) throw new Error(`${res.status} ${url}: ${await res.text()}`);
|
||||
return res.status === 204 ? null : res.json();
|
||||
}
|
||||
|
||||
async function graphql(query, variables) {
|
||||
const res = await fetch("https://api.github.com/graphql", {
|
||||
method: "POST",
|
||||
headers,
|
||||
body: JSON.stringify({ query, variables })
|
||||
});
|
||||
if (!res.ok) throw new Error(`${res.status}: ${await res.text()}`);
|
||||
const json = await res.json();
|
||||
if (json.errors) throw new Error(JSON.stringify(json.errors));
|
||||
return json.data;
|
||||
}
|
||||
|
||||
async function addLabel(owner, repo, issueNumber, labels) {
|
||||
return rest(
|
||||
`https://api.github.com/repos/${owner}/${repo}/issues/${issueNumber}/labels`,
|
||||
"POST",
|
||||
{ labels }
|
||||
);
|
||||
}
|
||||
|
||||
async function addComment(owner, repo, issueNumber, body) {
|
||||
return rest(
|
||||
`https://api.github.com/repos/${owner}/${repo}/issues/${issueNumber}/comments`,
|
||||
"POST",
|
||||
{ body }
|
||||
);
|
||||
}
|
||||
|
||||
async function closeIssue(owner, repo, issueNumber) {
|
||||
return rest(
|
||||
`https://api.github.com/repos/${owner}/${repo}/issues/${issueNumber}`,
|
||||
"PATCH",
|
||||
{ state: "closed", state_reason: "completed" }
|
||||
);
|
||||
}
|
||||
|
||||
async function getIssueNodeId(owner, repo, issueNumber) {
|
||||
const issue = await rest(`https://api.github.com/repos/${owner}/${repo}/issues/${issueNumber}`);
|
||||
return issue.node_id;
|
||||
}
|
||||
|
||||
async function addToProject(issueNodeId) {
|
||||
const mutation = `
|
||||
mutation($projectId: ID!, $contentId: ID!) {
|
||||
addProjectV2ItemById(input: {projectId: $projectId, contentId: $contentId}) {
|
||||
item { id }
|
||||
}
|
||||
}
|
||||
`;
|
||||
|
||||
const data = await graphql(mutation, {
|
||||
projectId: process.env.PROJECT_ID,
|
||||
contentId: issueNodeId
|
||||
});
|
||||
|
||||
return data.addProjectV2ItemById.item.id;
|
||||
}
|
||||
|
||||
async function setTextField(itemId, fieldId, value) {
|
||||
const mutation = `
|
||||
mutation($projectId: ID!, $itemId: ID!, $fieldId: ID!, $value: String!) {
|
||||
updateProjectV2ItemFieldValue(input: {
|
||||
projectId: $projectId,
|
||||
itemId: $itemId,
|
||||
fieldId: $fieldId,
|
||||
value: { text: $value }
|
||||
}) {
|
||||
projectV2Item { id }
|
||||
}
|
||||
}
|
||||
`;
|
||||
|
||||
return graphql(mutation, {
|
||||
projectId: process.env.PROJECT_ID,
|
||||
itemId,
|
||||
fieldId,
|
||||
value
|
||||
});
|
||||
}
|
||||
|
||||
for (const d of decisions) {
|
||||
const [owner, repo] = d.repository.split("/");
|
||||
|
||||
if (d.final_decision === "AUTO_CLOSE") {
|
||||
if (dryRun) continue;
|
||||
|
||||
await addLabel(owner, repo, d.issue_number, ["auto-closed-resolved"]);
|
||||
await addComment(
|
||||
owner,
|
||||
repo,
|
||||
d.issue_number,
|
||||
d.model.close_comment ||
|
||||
"This appears resolved based on linked evidence, so we’re closing it automatically. Reply if this still reproduces and we’ll reopen."
|
||||
);
|
||||
await closeIssue(owner, repo, d.issue_number);
|
||||
}
|
||||
|
||||
if (d.final_decision === "MANUAL_REVIEW") {
|
||||
await addLabel(owner, repo, d.issue_number, ["resolution-candidate"]);
|
||||
|
||||
const issueNodeId = await getIssueNodeId(owner, repo, d.issue_number);
|
||||
const itemId = await addToProject(issueNodeId);
|
||||
|
||||
if (process.env.PROJECT_CONFIDENCE_FIELD_ID) {
|
||||
await setTextField(itemId, process.env.PROJECT_CONFIDENCE_FIELD_ID, String(d.model.confidence));
|
||||
}
|
||||
if (process.env.PROJECT_REASON_FIELD_ID) {
|
||||
await setTextField(itemId, process.env.PROJECT_REASON_FIELD_ID, d.model.reason_code);
|
||||
}
|
||||
if (process.env.PROJECT_EVIDENCE_FIELD_ID) {
|
||||
await setTextField(itemId, process.env.PROJECT_EVIDENCE_FIELD_ID, d.issue_url);
|
||||
}
|
||||
if (process.env.PROJECT_LINKED_PR_FIELD_ID) {
|
||||
const linked = (d.model.hard_signals || []).map(x => x.url).join(", ");
|
||||
if (linked) {
|
||||
await setTextField(itemId, process.env.PROJECT_LINKED_PR_FIELD_ID, linked);
|
||||
}
|
||||
}
|
||||
if (process.env.PROJECT_REPO_FIELD_ID) {
|
||||
await setTextField(itemId, process.env.PROJECT_REPO_FIELD_ID, d.repository);
|
||||
}
|
||||
|
||||
await addComment(
|
||||
owner,
|
||||
repo,
|
||||
d.issue_number,
|
||||
d.model.manual_review_note ||
|
||||
"This issue looks like a possible resolution candidate, but not with enough certainty for automatic closure. Added to the review queue."
|
||||
);
|
||||
}
|
||||
}
|
||||
@@ -1,125 +0,0 @@
|
||||
import fs from "node:fs/promises";
|
||||
|
||||
const candidates = JSON.parse(await fs.readFile("candidates.json", "utf8"));
|
||||
|
||||
function isMaintainerRole(role) {
|
||||
return ["MEMBER", "OWNER", "COLLABORATOR"].includes(role || "");
|
||||
}
|
||||
|
||||
function preScore(candidate) {
|
||||
let score = 0;
|
||||
const hardSignals = [];
|
||||
const contradictions = [];
|
||||
|
||||
for (const t of candidate.timeline) {
|
||||
const sourceIssue = t.source?.issue;
|
||||
|
||||
if (t.event === "cross-referenced" && sourceIssue?.pull_request?.html_url) {
|
||||
hardSignals.push({
|
||||
type: "merged_pr",
|
||||
url: sourceIssue.html_url
|
||||
});
|
||||
score += 40; // provisional until PR merged state is verified
|
||||
}
|
||||
|
||||
if (["referenced", "connected"].includes(t.event)) {
|
||||
score += 10;
|
||||
}
|
||||
}
|
||||
|
||||
for (const c of candidate.comments) {
|
||||
const body = c.body.toLowerCase();
|
||||
|
||||
if (
|
||||
isMaintainerRole(c.author_association) &&
|
||||
/\b(fixed|resolved|duplicate|superseded|closing)\b/.test(body)
|
||||
) {
|
||||
score += 25;
|
||||
hardSignals.push({
|
||||
type: "maintainer_comment",
|
||||
url: c.html_url
|
||||
});
|
||||
}
|
||||
|
||||
if (/\b(still broken|still happening|not fixed|reproducible)\b/.test(body)) {
|
||||
score -= 50;
|
||||
contradictions.push({
|
||||
type: "later_unresolved_comment",
|
||||
url: c.html_url
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
return { score, hardSignals, contradictions };
|
||||
}
|
||||
|
||||
async function callGitHubModel(issuePacket) {
|
||||
// Replace this stub with the GitHub Models inference call used by your org.
|
||||
// The workflow already has models: read permission.
|
||||
return {
|
||||
decision: "MANUAL_REVIEW",
|
||||
reason_code: "likely_fixed_but_unconfirmed",
|
||||
confidence: 0.74,
|
||||
hard_signals: [],
|
||||
contradictions: [],
|
||||
summary: "Potential resolution candidate; evidence is not strong enough to close automatically.",
|
||||
close_comment: "This appears resolved, so we’re closing it automatically. Reply if this is still reproducible.",
|
||||
manual_review_note: "Potential resolution candidate. Please review evidence before closing."
|
||||
};
|
||||
}
|
||||
|
||||
function enforcePolicy(modelOut, pre) {
|
||||
const approvedReasons = new Set([
|
||||
"resolved_by_merged_pr",
|
||||
"maintainer_confirmed_resolved",
|
||||
"duplicate_confirmed",
|
||||
"superseded_confirmed"
|
||||
]);
|
||||
|
||||
const hasHardSignal =
|
||||
(modelOut.hard_signals || []).some(s =>
|
||||
["merged_pr", "maintainer_comment", "duplicate_reference", "superseded_reference"].includes(s.type)
|
||||
) || pre.hardSignals.length > 0;
|
||||
|
||||
const hasContradiction =
|
||||
(modelOut.contradictions || []).length > 0 || pre.contradictions.length > 0;
|
||||
|
||||
if (
|
||||
modelOut.decision === "AUTO_CLOSE" &&
|
||||
modelOut.confidence >= 0.97 &&
|
||||
approvedReasons.has(modelOut.reason_code) &&
|
||||
hasHardSignal &&
|
||||
!hasContradiction
|
||||
) {
|
||||
return "AUTO_CLOSE";
|
||||
}
|
||||
|
||||
if (
|
||||
modelOut.decision === "MANUAL_REVIEW" ||
|
||||
modelOut.confidence >= 0.60 ||
|
||||
pre.score >= 25
|
||||
) {
|
||||
return "MANUAL_REVIEW";
|
||||
}
|
||||
|
||||
return "KEEP_OPEN";
|
||||
}
|
||||
|
||||
const decisions = [];
|
||||
for (const candidate of candidates) {
|
||||
const pre = preScore(candidate);
|
||||
const modelOut = await callGitHubModel(candidate);
|
||||
const finalDecision = enforcePolicy(modelOut, pre);
|
||||
|
||||
decisions.push({
|
||||
repository: candidate.repository,
|
||||
issue_number: candidate.issue.number,
|
||||
issue_url: candidate.issue.html_url,
|
||||
title: candidate.issue.title,
|
||||
pre_score: pre.score,
|
||||
final_decision: finalDecision,
|
||||
model: modelOut
|
||||
});
|
||||
}
|
||||
|
||||
await fs.writeFile("decisions.json", JSON.stringify(decisions, null, 2));
|
||||
50
.github/workflows/issue-resolution-triage.yml
vendored
50
.github/workflows/issue-resolution-triage.yml
vendored
@@ -1,50 +0,0 @@
|
||||
name: issue-resolution-triage
|
||||
|
||||
on:
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
dry_run:
|
||||
description: "If true, do not close issues"
|
||||
required: false
|
||||
default: "true"
|
||||
max_issues:
|
||||
description: "How many issues to process"
|
||||
required: false
|
||||
default: "100"
|
||||
schedule:
|
||||
- cron: "17 2 * * *"
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
issues: write
|
||||
pull-requests: read
|
||||
models: read
|
||||
|
||||
jobs:
|
||||
triage:
|
||||
runs-on: ubuntu-latest
|
||||
env:
|
||||
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
DRY_RUN: ${{ inputs.dry_run || 'true' }}
|
||||
MAX_ISSUES: ${{ inputs.max_issues || '100' }}
|
||||
REPO: ${{ github.repository }}
|
||||
PROJECT_ID: ${{ vars.ISSUE_REVIEW_PROJECT_ID }}
|
||||
PROJECT_STATUS_FIELD_ID: ${{ vars.PROJECT_STATUS_FIELD_ID }}
|
||||
PROJECT_CONFIDENCE_FIELD_ID: ${{ vars.PROJECT_CONFIDENCE_FIELD_ID }}
|
||||
PROJECT_REASON_FIELD_ID: ${{ vars.PROJECT_REASON_FIELD_ID }}
|
||||
PROJECT_EVIDENCE_FIELD_ID: ${{ vars.PROJECT_EVIDENCE_FIELD_ID }}
|
||||
PROJECT_LINKED_PR_FIELD_ID: ${{ vars.PROJECT_LINKED_PR_FIELD_ID }}
|
||||
PROJECT_REPO_FIELD_ID: ${{ vars.PROJECT_REPO_FIELD_ID }}
|
||||
PROJECT_STATUS_OPTION_NEEDS_REVIEW_ID: ${{ vars.PROJECT_STATUS_OPTION_NEEDS_REVIEW_ID }}
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: "20"
|
||||
|
||||
- run: npm ci
|
||||
- run: node scripts/fetch-candidates.mjs
|
||||
- run: node scripts/classify-candidates.mjs
|
||||
- run: node scripts/apply-decisions.mjs
|
||||
@@ -13,7 +13,6 @@ import (
|
||||
|
||||
const (
|
||||
defaultResolvConfPath = "/etc/resolv.conf"
|
||||
nsswitchConfPath = "/etc/nsswitch.conf"
|
||||
)
|
||||
|
||||
type resolvConf struct {
|
||||
|
||||
@@ -46,12 +46,12 @@ type restoreHostManager interface {
|
||||
}
|
||||
|
||||
func newHostManager(wgInterface string) (hostManager, error) {
|
||||
osManager, reason, err := getOSDNSManagerType()
|
||||
osManager, err := getOSDNSManagerType()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get os dns manager type: %w", err)
|
||||
}
|
||||
|
||||
log.Infof("System DNS manager discovered: %s (%s)", osManager, reason)
|
||||
log.Infof("System DNS manager discovered: %s", osManager)
|
||||
mgr, err := newHostManagerFromType(wgInterface, osManager)
|
||||
// need to explicitly return nil mgr on error to avoid returning a non-nil interface containing a nil value
|
||||
if err != nil {
|
||||
@@ -74,49 +74,17 @@ func newHostManagerFromType(wgInterface string, osManager osManagerType) (restor
|
||||
}
|
||||
}
|
||||
|
||||
func getOSDNSManagerType() (osManagerType, string, error) {
|
||||
resolved := isSystemdResolvedRunning()
|
||||
nss := isLibnssResolveUsed()
|
||||
stub := checkStub()
|
||||
|
||||
// Prefer systemd-resolved whenever it owns libc resolution, regardless of
|
||||
// who wrote /etc/resolv.conf. File-mode rewrites do not affect lookups
|
||||
// that go through nss-resolve, and in foreign mode they can loop back
|
||||
// through resolved as an upstream.
|
||||
if resolved && (nss || stub) {
|
||||
return systemdManager, fmt.Sprintf("systemd-resolved active (nss-resolve=%t, stub=%t)", nss, stub), nil
|
||||
}
|
||||
|
||||
mgr, reason, rejected, err := scanResolvConfHeader()
|
||||
if err != nil {
|
||||
return 0, "", err
|
||||
}
|
||||
if reason != "" {
|
||||
return mgr, reason, nil
|
||||
}
|
||||
|
||||
fallback := fmt.Sprintf("no manager matched (resolved=%t, nss-resolve=%t, stub=%t)", resolved, nss, stub)
|
||||
if len(rejected) > 0 {
|
||||
fallback += "; rejected: " + strings.Join(rejected, ", ")
|
||||
}
|
||||
return fileManager, fallback, nil
|
||||
}
|
||||
|
||||
// scanResolvConfHeader walks /etc/resolv.conf header comments and returns the
|
||||
// matching manager. If reason is empty the caller should pick file mode and
|
||||
// use rejected for diagnostics.
|
||||
func scanResolvConfHeader() (osManagerType, string, []string, error) {
|
||||
func getOSDNSManagerType() (osManagerType, error) {
|
||||
file, err := os.Open(defaultResolvConfPath)
|
||||
if err != nil {
|
||||
return 0, "", nil, fmt.Errorf("unable to open %s for checking owner, got error: %w", defaultResolvConfPath, err)
|
||||
return 0, fmt.Errorf("unable to open %s for checking owner, got error: %w", defaultResolvConfPath, err)
|
||||
}
|
||||
defer func() {
|
||||
if cerr := file.Close(); cerr != nil {
|
||||
log.Errorf("close file %s: %s", defaultResolvConfPath, cerr)
|
||||
if err := file.Close(); err != nil {
|
||||
log.Errorf("close file %s: %s", defaultResolvConfPath, err)
|
||||
}
|
||||
}()
|
||||
|
||||
var rejected []string
|
||||
scanner := bufio.NewScanner(file)
|
||||
for scanner.Scan() {
|
||||
text := scanner.Text()
|
||||
@@ -124,48 +92,41 @@ func scanResolvConfHeader() (osManagerType, string, []string, error) {
|
||||
continue
|
||||
}
|
||||
if text[0] != '#' {
|
||||
break
|
||||
return fileManager, nil
|
||||
}
|
||||
if mgr, reason, rej := matchResolvConfHeader(text); reason != "" {
|
||||
return mgr, reason, nil, nil
|
||||
} else if rej != "" {
|
||||
rejected = append(rejected, rej)
|
||||
if strings.Contains(text, fileGeneratedResolvConfContentHeader) {
|
||||
return netbirdManager, nil
|
||||
}
|
||||
if strings.Contains(text, "NetworkManager") && isDbusListenerRunning(networkManagerDest, networkManagerDbusObjectNode) && isNetworkManagerSupported() {
|
||||
return networkManager, nil
|
||||
}
|
||||
if strings.Contains(text, "systemd-resolved") && isSystemdResolvedRunning() {
|
||||
if checkStub() {
|
||||
return systemdManager, nil
|
||||
} else {
|
||||
return fileManager, nil
|
||||
}
|
||||
}
|
||||
if strings.Contains(text, "resolvconf") {
|
||||
if isSystemdResolveConfMode() {
|
||||
return systemdManager, nil
|
||||
}
|
||||
|
||||
return resolvConfManager, nil
|
||||
}
|
||||
}
|
||||
if err := scanner.Err(); err != nil && err != io.EOF {
|
||||
return 0, "", nil, fmt.Errorf("scan: %w", err)
|
||||
return 0, fmt.Errorf("scan: %w", err)
|
||||
}
|
||||
return 0, "", rejected, nil
|
||||
|
||||
return fileManager, nil
|
||||
}
|
||||
|
||||
// matchResolvConfHeader inspects a single comment line. Returns either a
|
||||
// definitive (manager, reason) or a non-empty rejected diagnostic.
|
||||
func matchResolvConfHeader(text string) (osManagerType, string, string) {
|
||||
if strings.Contains(text, fileGeneratedResolvConfContentHeader) {
|
||||
return netbirdManager, "netbird-managed resolv.conf header detected", ""
|
||||
}
|
||||
if strings.Contains(text, "NetworkManager") {
|
||||
if isDbusListenerRunning(networkManagerDest, networkManagerDbusObjectNode) && isNetworkManagerSupported() {
|
||||
return networkManager, "NetworkManager header + supported version on dbus", ""
|
||||
}
|
||||
return 0, "", "NetworkManager header (no dbus or unsupported version)"
|
||||
}
|
||||
if strings.Contains(text, "resolvconf") {
|
||||
if isSystemdResolveConfMode() {
|
||||
return systemdManager, "resolvconf header in systemd-resolved compatibility mode", ""
|
||||
}
|
||||
return resolvConfManager, "resolvconf header detected", ""
|
||||
}
|
||||
return 0, "", ""
|
||||
}
|
||||
|
||||
// checkStub reports whether systemd-resolved's stub (127.0.0.53) is listed
|
||||
// in /etc/resolv.conf. On parse failure we assume it is, to avoid dropping
|
||||
// into file mode while resolved is active.
|
||||
// checkStub checks if the stub resolver is disabled in systemd-resolved. If it is disabled, we fall back to file manager.
|
||||
func checkStub() bool {
|
||||
rConf, err := parseDefaultResolvConf()
|
||||
if err != nil {
|
||||
log.Warnf("failed to parse resolv conf, assuming stub is active: %s", err)
|
||||
log.Warnf("failed to parse resolv conf: %s", err)
|
||||
return true
|
||||
}
|
||||
|
||||
@@ -178,36 +139,3 @@ func checkStub() bool {
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// isLibnssResolveUsed reports whether nss-resolve is listed before dns on
|
||||
// the hosts: line of /etc/nsswitch.conf. When it is, libc lookups are
|
||||
// delegated to systemd-resolved regardless of /etc/resolv.conf.
|
||||
func isLibnssResolveUsed() bool {
|
||||
bs, err := os.ReadFile(nsswitchConfPath)
|
||||
if err != nil {
|
||||
log.Debugf("read %s: %v", nsswitchConfPath, err)
|
||||
return false
|
||||
}
|
||||
return parseNsswitchResolveAhead(bs)
|
||||
}
|
||||
|
||||
func parseNsswitchResolveAhead(data []byte) bool {
|
||||
for _, line := range strings.Split(string(data), "\n") {
|
||||
if i := strings.IndexByte(line, '#'); i >= 0 {
|
||||
line = line[:i]
|
||||
}
|
||||
fields := strings.Fields(line)
|
||||
if len(fields) < 2 || fields[0] != "hosts:" {
|
||||
continue
|
||||
}
|
||||
for _, module := range fields[1:] {
|
||||
switch module {
|
||||
case "dns":
|
||||
return false
|
||||
case "resolve":
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -1,76 +0,0 @@
|
||||
//go:build (linux && !android) || freebsd
|
||||
|
||||
package dns
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestParseNsswitchResolveAhead(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
in string
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
name: "resolve before dns with action token",
|
||||
in: "hosts: mymachines resolve [!UNAVAIL=return] files myhostname dns\n",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "dns before resolve",
|
||||
in: "hosts: files mdns4_minimal [NOTFOUND=return] dns resolve\n",
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "debian default with only dns",
|
||||
in: "hosts: files mdns4_minimal [NOTFOUND=return] dns mymachines\n",
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "neither resolve nor dns",
|
||||
in: "hosts: files myhostname\n",
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "no hosts line",
|
||||
in: "passwd: files systemd\ngroup: files systemd\n",
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "empty",
|
||||
in: "",
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "comments and blank lines ignored",
|
||||
in: "# comment\n\n# another\nhosts: resolve dns\n",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "trailing inline comment",
|
||||
in: "hosts: resolve [!UNAVAIL=return] dns # fallback\n",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "hosts token must be the first field",
|
||||
in: " hosts: resolve dns\n",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "other db line mentioning resolve is ignored",
|
||||
in: "networks: resolve\nhosts: dns\n",
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "only resolve, no dns",
|
||||
in: "hosts: files resolve\n",
|
||||
want: true,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := parseNsswitchResolveAhead([]byte(tt.in)); got != tt.want {
|
||||
t.Errorf("parseNsswitchResolveAhead() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -30,6 +30,7 @@ import (
|
||||
nbcache "github.com/netbirdio/netbird/management/server/cache"
|
||||
nbContext "github.com/netbirdio/netbird/management/server/context"
|
||||
nbhttp "github.com/netbirdio/netbird/management/server/http"
|
||||
"github.com/netbirdio/netbird/management/server/http/middleware"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
mgmtProto "github.com/netbirdio/netbird/shared/management/proto"
|
||||
@@ -109,7 +110,7 @@ func (s *BaseServer) EventStore() activity.Store {
|
||||
|
||||
func (s *BaseServer) APIHandler() http.Handler {
|
||||
return Create(s, func() http.Handler {
|
||||
httpAPIHandler, err := nbhttp.NewAPIHandler(context.Background(), s.AccountManager(), s.NetworksManager(), s.ResourcesManager(), s.RoutesManager(), s.GroupsManager(), s.GeoLocationManager(), s.AuthManager(), s.Metrics(), s.IntegratedValidator(), s.ProxyController(), s.PermissionsManager(), s.PeersManager(), s.SettingsManager(), s.ZonesManager(), s.RecordsManager(), s.NetworkMapController(), s.IdpManager(), s.ServiceManager(), s.ReverseProxyDomainManager(), s.AccessLogsManager(), s.ReverseProxyGRPCServer(), s.Config.ReverseProxy.TrustedHTTPProxies)
|
||||
httpAPIHandler, err := nbhttp.NewAPIHandler(context.Background(), s.AccountManager(), s.NetworksManager(), s.ResourcesManager(), s.RoutesManager(), s.GroupsManager(), s.GeoLocationManager(), s.AuthManager(), s.Metrics(), s.IntegratedValidator(), s.ProxyController(), s.PermissionsManager(), s.PeersManager(), s.SettingsManager(), s.ZonesManager(), s.RecordsManager(), s.NetworkMapController(), s.IdpManager(), s.ServiceManager(), s.ReverseProxyDomainManager(), s.AccessLogsManager(), s.ReverseProxyGRPCServer(), s.Config.ReverseProxy.TrustedHTTPProxies, s.RateLimiter())
|
||||
if err != nil {
|
||||
log.Fatalf("failed to create API handler: %v", err)
|
||||
}
|
||||
@@ -117,6 +118,15 @@ func (s *BaseServer) APIHandler() http.Handler {
|
||||
})
|
||||
}
|
||||
|
||||
func (s *BaseServer) RateLimiter() *middleware.APIRateLimiter {
|
||||
return Create(s, func() *middleware.APIRateLimiter {
|
||||
cfg, enabled := middleware.RateLimiterConfigFromEnv()
|
||||
limiter := middleware.NewAPIRateLimiter(cfg)
|
||||
limiter.SetEnabled(enabled)
|
||||
return limiter
|
||||
})
|
||||
}
|
||||
|
||||
func (s *BaseServer) GRPCServer() *grpc.Server {
|
||||
return Create(s, func() *grpc.Server {
|
||||
trustedPeers := s.Config.ReverseProxy.TrustedPeers
|
||||
|
||||
@@ -2311,6 +2311,29 @@ func TestAccount_GetExpiredPeers(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetExpiredPeers_SkipsAlreadyExpired(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
testStore, cleanUp, err := store.NewTestStoreFromSQL(ctx, "testdata/store_with_expired_peers.sql", t.TempDir())
|
||||
t.Cleanup(cleanUp)
|
||||
require.NoError(t, err)
|
||||
|
||||
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||
|
||||
// Verify the already-expired peer is excluded at the store level
|
||||
peers, err := testStore.GetAccountPeersWithExpiration(ctx, store.LockingStrengthNone, accountID)
|
||||
require.NoError(t, err)
|
||||
|
||||
for _, peer := range peers {
|
||||
assert.NotEqual(t, "cg05lnblo1hkg2j514p0", peer.ID, "already expired peer should be excluded by the store query")
|
||||
assert.False(t, peer.Status.LoginExpired, "returned peers should not already be marked as login expired")
|
||||
}
|
||||
|
||||
// Only the non-expired peer with expiration enabled should be returned
|
||||
require.Len(t, peers, 1)
|
||||
assert.Equal(t, "notexpired01", peers[0].ID)
|
||||
}
|
||||
|
||||
func TestAccount_GetInactivePeers(t *testing.T) {
|
||||
type test struct {
|
||||
name string
|
||||
@@ -3230,6 +3253,13 @@ func setupNetworkMapTest(t *testing.T) (*DefaultAccountManager, *update_channel.
|
||||
return manager, updateManager, account, peer1, peer2, peer3
|
||||
}
|
||||
|
||||
// peerUpdateTimeout bounds how long peerShouldReceiveUpdate and its outer
|
||||
// wrappers wait for an expected update message. Sized for slow CI runners
|
||||
// (MySQL, FreeBSD, loaded sqlite) where the channel publish can take
|
||||
// seconds. Only runs down on failure; passing tests return immediately
|
||||
// when the channel delivers.
|
||||
const peerUpdateTimeout = 5 * time.Second
|
||||
|
||||
func peerShouldNotReceiveUpdate(t *testing.T, updateMessage <-chan *network_map.UpdateMessage) {
|
||||
t.Helper()
|
||||
select {
|
||||
@@ -3248,7 +3278,7 @@ func peerShouldReceiveUpdate(t *testing.T, updateMessage <-chan *network_map.Upd
|
||||
if msg == nil {
|
||||
t.Errorf("Received nil update message, expected valid message")
|
||||
}
|
||||
case <-time.After(500 * time.Millisecond):
|
||||
case <-time.After(peerUpdateTimeout):
|
||||
t.Error("Timed out waiting for update message")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -458,7 +458,7 @@ func TestDNSAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
case <-time.After(peerUpdateTimeout):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
@@ -478,7 +478,7 @@ func TestDNSAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
case <-time.After(peerUpdateTimeout):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
@@ -518,7 +518,7 @@ func TestDNSAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
case <-time.After(peerUpdateTimeout):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
|
||||
@@ -620,7 +620,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
case <-time.After(peerUpdateTimeout):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
@@ -638,7 +638,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
case <-time.After(peerUpdateTimeout):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
@@ -656,7 +656,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
case <-time.After(peerUpdateTimeout):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
@@ -689,7 +689,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
case <-time.After(peerUpdateTimeout):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
@@ -730,7 +730,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
case <-time.After(peerUpdateTimeout):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
@@ -757,7 +757,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
case <-time.After(peerUpdateTimeout):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
@@ -804,7 +804,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
case <-time.After(peerUpdateTimeout):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
|
||||
@@ -5,9 +5,6 @@ import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/netip"
|
||||
"os"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/rs/cors"
|
||||
@@ -66,14 +63,11 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
apiPrefix = "/api"
|
||||
rateLimitingEnabledKey = "NB_API_RATE_LIMITING_ENABLED"
|
||||
rateLimitingBurstKey = "NB_API_RATE_LIMITING_BURST"
|
||||
rateLimitingRPMKey = "NB_API_RATE_LIMITING_RPM"
|
||||
apiPrefix = "/api"
|
||||
)
|
||||
|
||||
// NewAPIHandler creates the Management service HTTP API handler registering all the available endpoints.
|
||||
func NewAPIHandler(ctx context.Context, accountManager account.Manager, networksManager nbnetworks.Manager, resourceManager resources.Manager, routerManager routers.Manager, groupsManager nbgroups.Manager, LocationManager geolocation.Geolocation, authManager auth.Manager, appMetrics telemetry.AppMetrics, integratedValidator integrated_validator.IntegratedValidator, proxyController port_forwarding.Controller, permissionsManager permissions.Manager, peersManager nbpeers.Manager, settingsManager settings.Manager, zManager zones.Manager, rManager records.Manager, networkMapController network_map.Controller, idpManager idpmanager.Manager, serviceManager service.Manager, reverseProxyDomainManager *manager.Manager, reverseProxyAccessLogsManager accesslogs.Manager, proxyGRPCServer *nbgrpc.ProxyServiceServer, trustedHTTPProxies []netip.Prefix) (http.Handler, error) {
|
||||
func NewAPIHandler(ctx context.Context, accountManager account.Manager, networksManager nbnetworks.Manager, resourceManager resources.Manager, routerManager routers.Manager, groupsManager nbgroups.Manager, LocationManager geolocation.Geolocation, authManager auth.Manager, appMetrics telemetry.AppMetrics, integratedValidator integrated_validator.IntegratedValidator, proxyController port_forwarding.Controller, permissionsManager permissions.Manager, peersManager nbpeers.Manager, settingsManager settings.Manager, zManager zones.Manager, rManager records.Manager, networkMapController network_map.Controller, idpManager idpmanager.Manager, serviceManager service.Manager, reverseProxyDomainManager *manager.Manager, reverseProxyAccessLogsManager accesslogs.Manager, proxyGRPCServer *nbgrpc.ProxyServiceServer, trustedHTTPProxies []netip.Prefix, rateLimiter *middleware.APIRateLimiter) (http.Handler, error) {
|
||||
|
||||
// Register bypass paths for unauthenticated endpoints
|
||||
if err := bypass.AddBypassPath("/api/instance"); err != nil {
|
||||
@@ -94,34 +88,10 @@ func NewAPIHandler(ctx context.Context, accountManager account.Manager, networks
|
||||
return nil, fmt.Errorf("failed to add bypass path: %w", err)
|
||||
}
|
||||
|
||||
var rateLimitingConfig *middleware.RateLimiterConfig
|
||||
if os.Getenv(rateLimitingEnabledKey) == "true" {
|
||||
rpm := 6
|
||||
if v := os.Getenv(rateLimitingRPMKey); v != "" {
|
||||
value, err := strconv.Atoi(v)
|
||||
if err != nil {
|
||||
log.Warnf("parsing %s env var: %v, using default %d", rateLimitingRPMKey, err, rpm)
|
||||
} else {
|
||||
rpm = value
|
||||
}
|
||||
}
|
||||
|
||||
burst := 500
|
||||
if v := os.Getenv(rateLimitingBurstKey); v != "" {
|
||||
value, err := strconv.Atoi(v)
|
||||
if err != nil {
|
||||
log.Warnf("parsing %s env var: %v, using default %d", rateLimitingBurstKey, err, burst)
|
||||
} else {
|
||||
burst = value
|
||||
}
|
||||
}
|
||||
|
||||
rateLimitingConfig = &middleware.RateLimiterConfig{
|
||||
RequestsPerMinute: float64(rpm),
|
||||
Burst: burst,
|
||||
CleanupInterval: 6 * time.Hour,
|
||||
LimiterTTL: 24 * time.Hour,
|
||||
}
|
||||
if rateLimiter == nil {
|
||||
log.Warn("NewAPIHandler: nil rate limiter, rate limiting disabled")
|
||||
rateLimiter = middleware.NewAPIRateLimiter(nil)
|
||||
rateLimiter.SetEnabled(false)
|
||||
}
|
||||
|
||||
authMiddleware := middleware.NewAuthMiddleware(
|
||||
@@ -129,7 +99,7 @@ func NewAPIHandler(ctx context.Context, accountManager account.Manager, networks
|
||||
accountManager.GetAccountIDFromUserAuth,
|
||||
accountManager.SyncUserJWTGroups,
|
||||
accountManager.GetUserFromUserAuth,
|
||||
rateLimitingConfig,
|
||||
rateLimiter,
|
||||
appMetrics.GetMeter(),
|
||||
)
|
||||
|
||||
|
||||
@@ -43,14 +43,9 @@ func NewAuthMiddleware(
|
||||
ensureAccount EnsureAccountFunc,
|
||||
syncUserJWTGroups SyncUserJWTGroupsFunc,
|
||||
getUserFromUserAuth GetUserFromUserAuthFunc,
|
||||
rateLimiterConfig *RateLimiterConfig,
|
||||
rateLimiter *APIRateLimiter,
|
||||
meter metric.Meter,
|
||||
) *AuthMiddleware {
|
||||
var rateLimiter *APIRateLimiter
|
||||
if rateLimiterConfig != nil {
|
||||
rateLimiter = NewAPIRateLimiter(rateLimiterConfig)
|
||||
}
|
||||
|
||||
var patUsageTracker *PATUsageTracker
|
||||
if meter != nil {
|
||||
var err error
|
||||
@@ -181,10 +176,8 @@ func (m *AuthMiddleware) checkPATFromRequest(r *http.Request, authHeaderParts []
|
||||
m.patUsageTracker.IncrementUsage(token)
|
||||
}
|
||||
|
||||
if m.rateLimiter != nil && !isTerraformRequest(r) {
|
||||
if !m.rateLimiter.Allow(token) {
|
||||
return status.Errorf(status.TooManyRequests, "too many requests")
|
||||
}
|
||||
if !isTerraformRequest(r) && !m.rateLimiter.Allow(token) {
|
||||
return status.Errorf(status.TooManyRequests, "too many requests")
|
||||
}
|
||||
|
||||
ctx := r.Context()
|
||||
|
||||
@@ -196,6 +196,8 @@ func TestAuthMiddleware_Handler(t *testing.T) {
|
||||
GetPATInfoFunc: mockGetAccountInfoFromPAT,
|
||||
}
|
||||
|
||||
disabledLimiter := NewAPIRateLimiter(nil)
|
||||
disabledLimiter.SetEnabled(false)
|
||||
authMiddleware := NewAuthMiddleware(
|
||||
mockAuth,
|
||||
func(ctx context.Context, userAuth nbauth.UserAuth) (string, string, error) {
|
||||
@@ -207,7 +209,7 @@ func TestAuthMiddleware_Handler(t *testing.T) {
|
||||
func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) {
|
||||
return &types.User{}, nil
|
||||
},
|
||||
nil,
|
||||
disabledLimiter,
|
||||
nil,
|
||||
)
|
||||
|
||||
@@ -266,7 +268,7 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) {
|
||||
func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) {
|
||||
return &types.User{}, nil
|
||||
},
|
||||
rateLimitConfig,
|
||||
NewAPIRateLimiter(rateLimitConfig),
|
||||
nil,
|
||||
)
|
||||
|
||||
@@ -318,7 +320,7 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) {
|
||||
func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) {
|
||||
return &types.User{}, nil
|
||||
},
|
||||
rateLimitConfig,
|
||||
NewAPIRateLimiter(rateLimitConfig),
|
||||
nil,
|
||||
)
|
||||
|
||||
@@ -361,7 +363,7 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) {
|
||||
func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) {
|
||||
return &types.User{}, nil
|
||||
},
|
||||
rateLimitConfig,
|
||||
NewAPIRateLimiter(rateLimitConfig),
|
||||
nil,
|
||||
)
|
||||
|
||||
@@ -405,7 +407,7 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) {
|
||||
func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) {
|
||||
return &types.User{}, nil
|
||||
},
|
||||
rateLimitConfig,
|
||||
NewAPIRateLimiter(rateLimitConfig),
|
||||
nil,
|
||||
)
|
||||
|
||||
@@ -469,7 +471,7 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) {
|
||||
func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) {
|
||||
return &types.User{}, nil
|
||||
},
|
||||
rateLimitConfig,
|
||||
NewAPIRateLimiter(rateLimitConfig),
|
||||
nil,
|
||||
)
|
||||
|
||||
@@ -528,7 +530,7 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) {
|
||||
func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) {
|
||||
return &types.User{}, nil
|
||||
},
|
||||
rateLimitConfig,
|
||||
NewAPIRateLimiter(rateLimitConfig),
|
||||
nil,
|
||||
)
|
||||
|
||||
@@ -583,7 +585,7 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) {
|
||||
func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) {
|
||||
return &types.User{}, nil
|
||||
},
|
||||
rateLimitConfig,
|
||||
NewAPIRateLimiter(rateLimitConfig),
|
||||
nil,
|
||||
)
|
||||
|
||||
@@ -670,6 +672,8 @@ func TestAuthMiddleware_Handler_Child(t *testing.T) {
|
||||
GetPATInfoFunc: mockGetAccountInfoFromPAT,
|
||||
}
|
||||
|
||||
disabledLimiter := NewAPIRateLimiter(nil)
|
||||
disabledLimiter.SetEnabled(false)
|
||||
authMiddleware := NewAuthMiddleware(
|
||||
mockAuth,
|
||||
func(ctx context.Context, userAuth nbauth.UserAuth) (string, string, error) {
|
||||
@@ -681,7 +685,7 @@ func TestAuthMiddleware_Handler_Child(t *testing.T) {
|
||||
func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) {
|
||||
return &types.User{}, nil
|
||||
},
|
||||
nil,
|
||||
disabledLimiter,
|
||||
nil,
|
||||
)
|
||||
|
||||
|
||||
@@ -4,14 +4,27 @@ import (
|
||||
"context"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"strconv"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/time/rate"
|
||||
|
||||
"github.com/netbirdio/netbird/shared/management/http/util"
|
||||
)
|
||||
|
||||
const (
|
||||
RateLimitingEnabledEnv = "NB_API_RATE_LIMITING_ENABLED"
|
||||
RateLimitingBurstEnv = "NB_API_RATE_LIMITING_BURST"
|
||||
RateLimitingRPMEnv = "NB_API_RATE_LIMITING_RPM"
|
||||
|
||||
defaultAPIRPM = 6
|
||||
defaultAPIBurst = 500
|
||||
)
|
||||
|
||||
// RateLimiterConfig holds configuration for the API rate limiter
|
||||
type RateLimiterConfig struct {
|
||||
// RequestsPerMinute defines the rate at which tokens are replenished
|
||||
@@ -34,6 +47,43 @@ func DefaultRateLimiterConfig() *RateLimiterConfig {
|
||||
}
|
||||
}
|
||||
|
||||
func RateLimiterConfigFromEnv() (cfg *RateLimiterConfig, enabled bool) {
|
||||
rpm := defaultAPIRPM
|
||||
if v := os.Getenv(RateLimitingRPMEnv); v != "" {
|
||||
value, err := strconv.Atoi(v)
|
||||
if err != nil {
|
||||
log.Warnf("parsing %s env var: %v, using default %d", RateLimitingRPMEnv, err, rpm)
|
||||
} else {
|
||||
rpm = value
|
||||
}
|
||||
}
|
||||
if rpm <= 0 {
|
||||
log.Warnf("%s=%d is non-positive, using default %d", RateLimitingRPMEnv, rpm, defaultAPIRPM)
|
||||
rpm = defaultAPIRPM
|
||||
}
|
||||
|
||||
burst := defaultAPIBurst
|
||||
if v := os.Getenv(RateLimitingBurstEnv); v != "" {
|
||||
value, err := strconv.Atoi(v)
|
||||
if err != nil {
|
||||
log.Warnf("parsing %s env var: %v, using default %d", RateLimitingBurstEnv, err, burst)
|
||||
} else {
|
||||
burst = value
|
||||
}
|
||||
}
|
||||
if burst <= 0 {
|
||||
log.Warnf("%s=%d is non-positive, using default %d", RateLimitingBurstEnv, burst, defaultAPIBurst)
|
||||
burst = defaultAPIBurst
|
||||
}
|
||||
|
||||
return &RateLimiterConfig{
|
||||
RequestsPerMinute: float64(rpm),
|
||||
Burst: burst,
|
||||
CleanupInterval: 6 * time.Hour,
|
||||
LimiterTTL: 24 * time.Hour,
|
||||
}, os.Getenv(RateLimitingEnabledEnv) == "true"
|
||||
}
|
||||
|
||||
// limiterEntry holds a rate limiter and its last access time
|
||||
type limiterEntry struct {
|
||||
limiter *rate.Limiter
|
||||
@@ -46,6 +96,7 @@ type APIRateLimiter struct {
|
||||
limiters map[string]*limiterEntry
|
||||
mu sync.RWMutex
|
||||
stopChan chan struct{}
|
||||
enabled atomic.Bool
|
||||
}
|
||||
|
||||
// NewAPIRateLimiter creates a new API rate limiter with the given configuration
|
||||
@@ -59,14 +110,53 @@ func NewAPIRateLimiter(config *RateLimiterConfig) *APIRateLimiter {
|
||||
limiters: make(map[string]*limiterEntry),
|
||||
stopChan: make(chan struct{}),
|
||||
}
|
||||
rl.enabled.Store(true)
|
||||
|
||||
go rl.cleanupLoop()
|
||||
|
||||
return rl
|
||||
}
|
||||
|
||||
func (rl *APIRateLimiter) SetEnabled(enabled bool) {
|
||||
rl.enabled.Store(enabled)
|
||||
}
|
||||
|
||||
func (rl *APIRateLimiter) Enabled() bool {
|
||||
return rl.enabled.Load()
|
||||
}
|
||||
|
||||
func (rl *APIRateLimiter) UpdateConfig(config *RateLimiterConfig) {
|
||||
if config == nil {
|
||||
return
|
||||
}
|
||||
if config.RequestsPerMinute <= 0 || config.Burst <= 0 {
|
||||
log.Warnf("UpdateConfig: ignoring invalid rpm=%v burst=%d", config.RequestsPerMinute, config.Burst)
|
||||
return
|
||||
}
|
||||
|
||||
newRPS := rate.Limit(config.RequestsPerMinute / 60.0)
|
||||
newBurst := config.Burst
|
||||
|
||||
rl.mu.Lock()
|
||||
rl.config.RequestsPerMinute = config.RequestsPerMinute
|
||||
rl.config.Burst = newBurst
|
||||
snapshot := make([]*rate.Limiter, 0, len(rl.limiters))
|
||||
for _, entry := range rl.limiters {
|
||||
snapshot = append(snapshot, entry.limiter)
|
||||
}
|
||||
rl.mu.Unlock()
|
||||
|
||||
for _, l := range snapshot {
|
||||
l.SetLimit(newRPS)
|
||||
l.SetBurst(newBurst)
|
||||
}
|
||||
}
|
||||
|
||||
// Allow checks if a request for the given key (token) is allowed
|
||||
func (rl *APIRateLimiter) Allow(key string) bool {
|
||||
if !rl.enabled.Load() {
|
||||
return true
|
||||
}
|
||||
limiter := rl.getLimiter(key)
|
||||
return limiter.Allow()
|
||||
}
|
||||
@@ -74,6 +164,9 @@ func (rl *APIRateLimiter) Allow(key string) bool {
|
||||
// Wait blocks until the rate limiter allows another request for the given key
|
||||
// Returns an error if the context is canceled
|
||||
func (rl *APIRateLimiter) Wait(ctx context.Context, key string) error {
|
||||
if !rl.enabled.Load() {
|
||||
return nil
|
||||
}
|
||||
limiter := rl.getLimiter(key)
|
||||
return limiter.Wait(ctx)
|
||||
}
|
||||
@@ -153,6 +246,10 @@ func (rl *APIRateLimiter) Reset(key string) {
|
||||
// Returns 429 Too Many Requests if the rate limit is exceeded.
|
||||
func (rl *APIRateLimiter) Middleware(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if !rl.enabled.Load() {
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
clientIP := getClientIP(r)
|
||||
if !rl.Allow(clientIP) {
|
||||
util.WriteErrorResponse("rate limit exceeded, please try again later", http.StatusTooManyRequests, w)
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -156,3 +158,172 @@ func TestAPIRateLimiter_Reset(t *testing.T) {
|
||||
// Should be allowed again
|
||||
assert.True(t, rl.Allow("test-key"))
|
||||
}
|
||||
|
||||
func TestAPIRateLimiter_SetEnabled(t *testing.T) {
|
||||
rl := NewAPIRateLimiter(&RateLimiterConfig{
|
||||
RequestsPerMinute: 60,
|
||||
Burst: 1,
|
||||
CleanupInterval: time.Minute,
|
||||
LimiterTTL: time.Minute,
|
||||
})
|
||||
defer rl.Stop()
|
||||
|
||||
assert.True(t, rl.Allow("key"))
|
||||
assert.False(t, rl.Allow("key"), "burst exhausted while enabled")
|
||||
|
||||
rl.SetEnabled(false)
|
||||
assert.False(t, rl.Enabled())
|
||||
for i := 0; i < 5; i++ {
|
||||
assert.True(t, rl.Allow("key"), "disabled limiter must always allow")
|
||||
}
|
||||
|
||||
rl.SetEnabled(true)
|
||||
assert.True(t, rl.Enabled())
|
||||
assert.False(t, rl.Allow("key"), "re-enabled limiter retains prior bucket state")
|
||||
}
|
||||
|
||||
func TestAPIRateLimiter_UpdateConfig(t *testing.T) {
|
||||
rl := NewAPIRateLimiter(&RateLimiterConfig{
|
||||
RequestsPerMinute: 60,
|
||||
Burst: 2,
|
||||
CleanupInterval: time.Minute,
|
||||
LimiterTTL: time.Minute,
|
||||
})
|
||||
defer rl.Stop()
|
||||
|
||||
assert.True(t, rl.Allow("k1"))
|
||||
assert.True(t, rl.Allow("k1"))
|
||||
assert.False(t, rl.Allow("k1"), "burst=2 exhausted")
|
||||
|
||||
rl.UpdateConfig(&RateLimiterConfig{
|
||||
RequestsPerMinute: 60,
|
||||
Burst: 10,
|
||||
CleanupInterval: time.Minute,
|
||||
LimiterTTL: time.Minute,
|
||||
})
|
||||
|
||||
// New burst applies to existing keys in place; bucket refills up to new burst over time,
|
||||
// but importantly newly-added keys use the updated config immediately.
|
||||
assert.True(t, rl.Allow("k2"))
|
||||
for i := 0; i < 9; i++ {
|
||||
assert.True(t, rl.Allow("k2"))
|
||||
}
|
||||
assert.False(t, rl.Allow("k2"), "new burst=10 exhausted")
|
||||
}
|
||||
|
||||
func TestAPIRateLimiter_UpdateConfig_NilIgnored(t *testing.T) {
|
||||
rl := NewAPIRateLimiter(&RateLimiterConfig{
|
||||
RequestsPerMinute: 60,
|
||||
Burst: 1,
|
||||
CleanupInterval: time.Minute,
|
||||
LimiterTTL: time.Minute,
|
||||
})
|
||||
defer rl.Stop()
|
||||
|
||||
rl.UpdateConfig(nil) // must not panic or zero the config
|
||||
|
||||
assert.True(t, rl.Allow("k"))
|
||||
assert.False(t, rl.Allow("k"))
|
||||
}
|
||||
|
||||
func TestAPIRateLimiter_UpdateConfig_NonPositiveIgnored(t *testing.T) {
|
||||
rl := NewAPIRateLimiter(&RateLimiterConfig{
|
||||
RequestsPerMinute: 60,
|
||||
Burst: 1,
|
||||
CleanupInterval: time.Minute,
|
||||
LimiterTTL: time.Minute,
|
||||
})
|
||||
defer rl.Stop()
|
||||
|
||||
assert.True(t, rl.Allow("k"))
|
||||
assert.False(t, rl.Allow("k"))
|
||||
|
||||
rl.UpdateConfig(&RateLimiterConfig{RequestsPerMinute: 0, Burst: 0, CleanupInterval: time.Minute, LimiterTTL: time.Minute})
|
||||
rl.UpdateConfig(&RateLimiterConfig{RequestsPerMinute: -1, Burst: 5, CleanupInterval: time.Minute, LimiterTTL: time.Minute})
|
||||
rl.UpdateConfig(&RateLimiterConfig{RequestsPerMinute: 60, Burst: -1, CleanupInterval: time.Minute, LimiterTTL: time.Minute})
|
||||
|
||||
rl.Reset("k")
|
||||
assert.True(t, rl.Allow("k"))
|
||||
assert.False(t, rl.Allow("k"), "burst should still be 1 — invalid UpdateConfig calls were ignored")
|
||||
}
|
||||
|
||||
func TestAPIRateLimiter_ConcurrentAllowAndUpdate(t *testing.T) {
|
||||
rl := NewAPIRateLimiter(&RateLimiterConfig{
|
||||
RequestsPerMinute: 600,
|
||||
Burst: 10,
|
||||
CleanupInterval: time.Minute,
|
||||
LimiterTTL: time.Minute,
|
||||
})
|
||||
defer rl.Stop()
|
||||
|
||||
var wg sync.WaitGroup
|
||||
stop := make(chan struct{})
|
||||
|
||||
for i := 0; i < 8; i++ {
|
||||
wg.Add(1)
|
||||
go func(id int) {
|
||||
defer wg.Done()
|
||||
key := fmt.Sprintf("k%d", id)
|
||||
for {
|
||||
select {
|
||||
case <-stop:
|
||||
return
|
||||
default:
|
||||
rl.Allow(key)
|
||||
}
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
for i := 0; i < 200; i++ {
|
||||
select {
|
||||
case <-stop:
|
||||
return
|
||||
default:
|
||||
rl.UpdateConfig(&RateLimiterConfig{
|
||||
RequestsPerMinute: float64(30 + (i % 90)),
|
||||
Burst: 1 + (i % 20),
|
||||
CleanupInterval: time.Minute,
|
||||
LimiterTTL: time.Minute,
|
||||
})
|
||||
rl.SetEnabled(i%2 == 0)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
close(stop)
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func TestRateLimiterConfigFromEnv(t *testing.T) {
|
||||
t.Setenv(RateLimitingEnabledEnv, "true")
|
||||
t.Setenv(RateLimitingRPMEnv, "42")
|
||||
t.Setenv(RateLimitingBurstEnv, "7")
|
||||
|
||||
cfg, enabled := RateLimiterConfigFromEnv()
|
||||
assert.True(t, enabled)
|
||||
assert.Equal(t, float64(42), cfg.RequestsPerMinute)
|
||||
assert.Equal(t, 7, cfg.Burst)
|
||||
|
||||
t.Setenv(RateLimitingEnabledEnv, "false")
|
||||
_, enabled = RateLimiterConfigFromEnv()
|
||||
assert.False(t, enabled)
|
||||
|
||||
t.Setenv(RateLimitingEnabledEnv, "")
|
||||
t.Setenv(RateLimitingRPMEnv, "")
|
||||
t.Setenv(RateLimitingBurstEnv, "")
|
||||
cfg, enabled = RateLimiterConfigFromEnv()
|
||||
assert.False(t, enabled)
|
||||
assert.Equal(t, float64(defaultAPIRPM), cfg.RequestsPerMinute)
|
||||
assert.Equal(t, defaultAPIBurst, cfg.Burst)
|
||||
|
||||
t.Setenv(RateLimitingRPMEnv, "0")
|
||||
t.Setenv(RateLimitingBurstEnv, "-5")
|
||||
cfg, _ = RateLimiterConfigFromEnv()
|
||||
assert.Equal(t, float64(defaultAPIRPM), cfg.RequestsPerMinute, "non-positive rpm must fall back to default")
|
||||
assert.Equal(t, defaultAPIBurst, cfg.Burst, "non-positive burst must fall back to default")
|
||||
}
|
||||
|
||||
@@ -135,7 +135,7 @@ func BuildApiBlackBoxWithDBState(t testing_tools.TB, sqlFile string, expectedPee
|
||||
customZonesManager := zonesManager.NewManager(store, am, permissionsManager, "")
|
||||
zoneRecordsManager := recordsManager.NewManager(store, am, permissionsManager)
|
||||
|
||||
apiHandler, err := http2.NewAPIHandler(context.Background(), am, networksManager, resourcesManager, routersManager, groupsManager, geoMock, authManagerMock, metrics, validatorMock, proxyController, permissionsManager, peersManager, settingsManager, customZonesManager, zoneRecordsManager, networkMapController, nil, serviceManager, nil, nil, nil, nil)
|
||||
apiHandler, err := http2.NewAPIHandler(context.Background(), am, networksManager, resourcesManager, routersManager, groupsManager, geoMock, authManagerMock, metrics, validatorMock, proxyController, permissionsManager, peersManager, settingsManager, customZonesManager, zoneRecordsManager, networkMapController, nil, serviceManager, nil, nil, nil, nil, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create API handler: %v", err)
|
||||
}
|
||||
@@ -264,7 +264,7 @@ func BuildApiBlackBoxWithDBStateAndPeerChannel(t testing_tools.TB, sqlFile strin
|
||||
customZonesManager := zonesManager.NewManager(store, am, permissionsManager, "")
|
||||
zoneRecordsManager := recordsManager.NewManager(store, am, permissionsManager)
|
||||
|
||||
apiHandler, err := http2.NewAPIHandler(context.Background(), am, networksManager, resourcesManager, routersManager, groupsManager, geoMock, authManagerMock, metrics, validatorMock, proxyController, permissionsManager, peersManager, settingsManager, customZonesManager, zoneRecordsManager, networkMapController, nil, serviceManager, nil, nil, nil, nil)
|
||||
apiHandler, err := http2.NewAPIHandler(context.Background(), am, networksManager, resourcesManager, routersManager, groupsManager, geoMock, authManagerMock, metrics, validatorMock, proxyController, permissionsManager, peersManager, settingsManager, customZonesManager, zoneRecordsManager, networkMapController, nil, serviceManager, nil, nil, nil, nil, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create API handler: %v", err)
|
||||
}
|
||||
|
||||
@@ -267,8 +267,8 @@ func Test_SyncProtocol(t *testing.T) {
|
||||
}
|
||||
|
||||
// expired peers come separately.
|
||||
if len(networkMap.GetOfflinePeers()) != 1 {
|
||||
t.Fatal("expecting SyncResponse to have NetworkMap with 1 offline peer")
|
||||
if len(networkMap.GetOfflinePeers()) != 2 {
|
||||
t.Fatal("expecting SyncResponse to have NetworkMap with 2 offline peer")
|
||||
}
|
||||
|
||||
expiredPeerPubKey := "RlSy2vzoG2HyMBTUImXOiVhCBiiBa5qD5xzMxkiFDW4="
|
||||
|
||||
@@ -1087,7 +1087,7 @@ func TestNameServerAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
case <-time.After(peerUpdateTimeout):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
@@ -1105,7 +1105,7 @@ func TestNameServerAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
case <-time.After(peerUpdateTimeout):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
|
||||
@@ -1405,6 +1405,10 @@ func (am *DefaultAccountManager) getExpiredPeers(ctx context.Context, accountID
|
||||
|
||||
var peers []*nbpeer.Peer
|
||||
for _, peer := range peersWithExpiry {
|
||||
if peer.Status.LoginExpired {
|
||||
continue
|
||||
}
|
||||
|
||||
expired, _ := peer.LoginExpired(settings.PeerLoginExpiration)
|
||||
if expired {
|
||||
peers = append(peers, peer)
|
||||
|
||||
@@ -1907,7 +1907,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
case <-time.After(peerUpdateTimeout):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
@@ -1929,7 +1929,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
case <-time.After(peerUpdateTimeout):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
@@ -1994,7 +1994,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
case <-time.After(peerUpdateTimeout):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
@@ -2012,7 +2012,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
case <-time.After(peerUpdateTimeout):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
@@ -2058,7 +2058,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
case <-time.After(peerUpdateTimeout):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
@@ -2076,7 +2076,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
case <-time.After(peerUpdateTimeout):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
@@ -2113,7 +2113,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
case <-time.After(peerUpdateTimeout):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
@@ -2131,7 +2131,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
case <-time.After(peerUpdateTimeout):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
|
||||
@@ -1231,7 +1231,7 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
case <-time.After(peerUpdateTimeout):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
@@ -1263,7 +1263,7 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
case <-time.After(peerUpdateTimeout):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
@@ -1294,7 +1294,7 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
case <-time.After(peerUpdateTimeout):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
@@ -1314,7 +1314,7 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
case <-time.After(peerUpdateTimeout):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
@@ -1355,7 +1355,7 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
case <-time.After(peerUpdateTimeout):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
@@ -1373,7 +1373,7 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
case <-time.After(peerUpdateTimeout):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
|
||||
@@ -1393,7 +1393,7 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
case <-time.After(peerUpdateTimeout):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
|
||||
@@ -244,7 +244,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
case <-time.After(peerUpdateTimeout):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
@@ -273,7 +273,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
case <-time.After(peerUpdateTimeout):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
@@ -292,7 +292,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
case <-time.After(peerUpdateTimeout):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
@@ -395,7 +395,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
case <-time.After(peerUpdateTimeout):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
@@ -438,7 +438,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
case <-time.After(peerUpdateTimeout):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
|
||||
@@ -2070,7 +2070,7 @@ func TestRouteAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
case <-time.After(peerUpdateTimeout):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
|
||||
@@ -2107,7 +2107,7 @@ func TestRouteAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
case <-time.After(peerUpdateTimeout):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
@@ -2127,7 +2127,7 @@ func TestRouteAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
case <-time.After(peerUpdateTimeout):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
@@ -2145,7 +2145,7 @@ func TestRouteAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
case <-time.After(peerUpdateTimeout):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
@@ -2185,7 +2185,7 @@ func TestRouteAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
case <-time.After(peerUpdateTimeout):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
@@ -2225,7 +2225,7 @@ func TestRouteAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
case <-time.After(peerUpdateTimeout):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
|
||||
@@ -3310,7 +3310,7 @@ func (s *SqlStore) GetAccountPeersWithExpiration(ctx context.Context, lockStreng
|
||||
|
||||
var peers []*nbpeer.Peer
|
||||
result := tx.
|
||||
Where("login_expiration_enabled = ? AND user_id IS NOT NULL AND user_id != ''", true).
|
||||
Where("login_expiration_enabled = ? AND peer_status_login_expired != ? AND user_id IS NOT NULL AND user_id != ''", true, true).
|
||||
Find(&peers, accountIDCondition, accountID)
|
||||
if err := result.Error; err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get peers with expiration from the store: %s", result.Error)
|
||||
|
||||
@@ -2729,7 +2729,7 @@ func TestSqlStore_GetAccountPeers(t *testing.T) {
|
||||
{
|
||||
name: "should retrieve peers for an existing account ID",
|
||||
accountID: "bf1c8084-ba50-4ce7-9439-34653001fc3b",
|
||||
expectedCount: 4,
|
||||
expectedCount: 5,
|
||||
},
|
||||
{
|
||||
name: "should return no peers for a non-existing account ID",
|
||||
@@ -2751,7 +2751,7 @@ func TestSqlStore_GetAccountPeers(t *testing.T) {
|
||||
name: "should filter peers by partial name",
|
||||
accountID: "bf1c8084-ba50-4ce7-9439-34653001fc3b",
|
||||
nameFilter: "host",
|
||||
expectedCount: 3,
|
||||
expectedCount: 4,
|
||||
},
|
||||
{
|
||||
name: "should filter peers by ip",
|
||||
@@ -2777,14 +2777,16 @@ func TestSqlStore_GetAccountPeersWithExpiration(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
accountID string
|
||||
expectedCount int
|
||||
name string
|
||||
accountID string
|
||||
expectedCount int
|
||||
expectedPeerIDs []string
|
||||
}{
|
||||
{
|
||||
name: "should retrieve peers with expiration for an existing account ID",
|
||||
accountID: "bf1c8084-ba50-4ce7-9439-34653001fc3b",
|
||||
expectedCount: 1,
|
||||
name: "should retrieve only non-expired peers with expiration enabled",
|
||||
accountID: "bf1c8084-ba50-4ce7-9439-34653001fc3b",
|
||||
expectedCount: 1,
|
||||
expectedPeerIDs: []string{"notexpired01"},
|
||||
},
|
||||
{
|
||||
name: "should return no peers with expiration for a non-existing account ID",
|
||||
@@ -2803,10 +2805,30 @@ func TestSqlStore_GetAccountPeersWithExpiration(t *testing.T) {
|
||||
peers, err := store.GetAccountPeersWithExpiration(context.Background(), LockingStrengthNone, tt.accountID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, peers, tt.expectedCount)
|
||||
for i, peer := range peers {
|
||||
assert.Equal(t, tt.expectedPeerIDs[i], peer.ID)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSqlStore_GetAccountPeersWithExpiration_ExcludesAlreadyExpired(t *testing.T) {
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store_with_expired_peers.sql", t.TempDir())
|
||||
t.Cleanup(cleanup)
|
||||
require.NoError(t, err)
|
||||
|
||||
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||
|
||||
peers, err := store.GetAccountPeersWithExpiration(context.Background(), LockingStrengthNone, accountID)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify the already-expired peer (cg05lnblo1hkg2j514p0) is not returned
|
||||
for _, peer := range peers {
|
||||
assert.NotEqual(t, "cg05lnblo1hkg2j514p0", peer.ID, "already expired peer should not be returned")
|
||||
assert.False(t, peer.Status.LoginExpired, "returned peers should not have LoginExpired set")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSqlStore_GetAccountPeersWithInactivity(t *testing.T) {
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store_with_expired_peers.sql", t.TempDir())
|
||||
t.Cleanup(cleanup)
|
||||
@@ -2887,7 +2909,7 @@ func TestSqlStore_GetUserPeers(t *testing.T) {
|
||||
name: "should retrieve peers for another valid account ID and user ID",
|
||||
accountID: "bf1c8084-ba50-4ce7-9439-34653001fc3b",
|
||||
userID: "edafee4e-63fb-11ec-90d6-0242ac120003",
|
||||
expectedCount: 2,
|
||||
expectedCount: 3,
|
||||
},
|
||||
{
|
||||
name: "should return no peers for existing account ID with empty user ID",
|
||||
|
||||
@@ -31,6 +31,7 @@ INSERT INTO peers VALUES('cfvprsrlo1hqoo49ohog','bf1c8084-ba50-4ce7-9439-3465300
|
||||
INSERT INTO peers VALUES('cg05lnblo1hkg2j514p0','bf1c8084-ba50-4ce7-9439-34653001fc3b','RlSy2vzoG2HyMBTUImXOiVhCBiiBa5qD5xzMxkiFDW4=','','"100.64.39.54"','expiredhost','linux','Linux','22.04','x86_64','Ubuntu','','development','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'expiredhost','expiredhost','2023-03-02 09:19:57.276717255+01:00',0,1,0,'edafee4e-63fb-11ec-90d6-0242ac120003','ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIMbK5ZXJsGOOWoBT4OmkPtgdPZe2Q7bDuS/zjn2CZxhK',0,1,0,'2023-03-02 09:14:21.791679181+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0);
|
||||
INSERT INTO peers VALUES('cg3161rlo1hs9cq94gdg','bf1c8084-ba50-4ce7-9439-34653001fc3b','mVABSKj28gv+JRsf7e0NEGKgSOGTfU/nPB2cpuG56HU=','','"100.64.117.96"','testhost','linux','Linux','22.04','x86_64','Ubuntu','','development','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'testhost','testhost','2023-03-06 18:21:27.252010027+01:00',0,0,0,'edafee4e-63fb-11ec-90d6-0242ac120003','ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAINWvvUkFFcrj48CWTkNUb/do/n52i1L5dH4DhGu+4ZuM',0,0,0,'2023-03-07 09:02:47.442857106+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0);
|
||||
INSERT INTO peers VALUES('csrnkiq7qv9d8aitqd50','bf1c8084-ba50-4ce7-9439-34653001fc3b','nVABSKj28gv+JRsf7e0NEGKgSOGTfU/nPB2cpuG56HX=','','"100.64.117.97"','testhost','linux','Linux','22.04','x86_64','Ubuntu','','development','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'testhost','testhost-1','2023-03-06 18:21:27.252010027+01:00',0,0,0,'f4f6d672-63fb-11ec-90d6-0242ac120003','ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAINWvvUkFFcrj48CWTkNUb/do/n52i1L5dH4DhGu+4ZuM',0,0,1,'2023-03-07 09:02:47.442857106+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0);
|
||||
INSERT INTO peers VALUES('notexpired01','bf1c8084-ba50-4ce7-9439-34653001fc3b','oVABSKj28gv+JRsf7e0NEGKgSOGTfU/nPB2cpuG56HY=','','"100.64.117.98"','activehost','linux','Linux','22.04','x86_64','Ubuntu','','development','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'activehost','activehost','2023-03-06 18:21:27.252010027+01:00',0,0,0,'edafee4e-63fb-11ec-90d6-0242ac120003','ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAINWvvUkFFcrj48CWTkNUb/do/n52i1L5dH4DhGu+4ZuM',0,1,0,'2023-03-07 09:02:47.442857106+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0);
|
||||
INSERT INTO users VALUES('f4f6d672-63fb-11ec-90d6-0242ac120003','bf1c8084-ba50-4ce7-9439-34653001fc3b','user',0,0,'','[]',0,NULL,'2024-10-02 17:00:32.528196+02:00','api',0,'');
|
||||
INSERT INTO users VALUES('edafee4e-63fb-11ec-90d6-0242ac120003','bf1c8084-ba50-4ce7-9439-34653001fc3b','admin',0,0,'','[]',0,NULL,'2024-10-02 17:00:32.528196+02:00','api',0,'');
|
||||
INSERT INTO installations VALUES(1,'');
|
||||
|
||||
@@ -1586,7 +1586,7 @@ func TestUserAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
case <-time.After(peerUpdateTimeout):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
@@ -1609,7 +1609,7 @@ func TestUserAccountPeersUpdate(t *testing.T) {
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
case <-time.After(peerUpdateTimeout):
|
||||
t.Error("timeout waiting for peerShouldReceiveUpdate")
|
||||
}
|
||||
})
|
||||
|
||||
@@ -433,6 +433,7 @@ func setSessionCookie(w http.ResponseWriter, token string, expiration time.Durat
|
||||
http.SetCookie(w, &http.Cookie{
|
||||
Name: auth.SessionCookieName,
|
||||
Value: token,
|
||||
Path: "/",
|
||||
HttpOnly: true,
|
||||
Secure: true,
|
||||
SameSite: http.SameSiteLaxMode,
|
||||
|
||||
@@ -391,6 +391,15 @@ func TestProtect_SchemeAuthRedirectsWithCookie(t *testing.T) {
|
||||
assert.Equal(t, http.SameSiteLaxMode, sessionCookie.SameSite)
|
||||
}
|
||||
|
||||
func TestSetSessionCookieHasRootPath(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
setSessionCookie(w, "test-token", time.Hour)
|
||||
|
||||
cookies := w.Result().Cookies()
|
||||
require.Len(t, cookies, 1)
|
||||
assert.Equal(t, "/", cookies[0].Path, "session cookie must be scoped to root so it applies to all paths")
|
||||
}
|
||||
|
||||
func TestProtect_FailedAuthDoesNotSetCookie(t *testing.T) {
|
||||
mw := NewMiddleware(log.StandardLogger(), nil, nil)
|
||||
kp := generateTestKeyPair(t)
|
||||
|
||||
Reference in New Issue
Block a user