Compare commits

..

5 Commits

Author SHA1 Message Date
Zoltán Papp
b1c0d90c64 Revert "[client] Prefer systemd-resolved stub over file mode regardless of resolv.conf header (#5935)"
This reverts commit 75e408f51c.
2026-04-24 09:04:32 +02:00
Zoltan Papp
f732b01a05 [management] unify peer-update test timeout via constant (#5952)
peerShouldReceiveUpdate waited 500ms for the expected update message,
and every outer wrapper across the management/server test suite paired
it with a 1s goroutine-drain timeout. Both were too tight for slower
CI runners (MySQL, FreeBSD, loaded sqlite), producing intermittent
"Timed out waiting for update message" failures in tests like
TestDNSAccountPeersUpdate, TestPeerAccountPeersUpdate, and
TestNameServerAccountPeersUpdate.

Introduce peerUpdateTimeout (5s) next to the helper and use it both in
the helper and in every outer wrapper so the two timeouts stay in sync.
Only runs down on failure; passing tests return as soon as the channel
delivers, so there is no slowdown on green runs.
2026-04-23 21:19:21 +02:00
alsruf36
c07c726ea7 [proxy] Set session cookie path to root (#5915) 2026-04-23 18:20:54 +02:00
Pascal Fischer
fa0d58d093 [management] exclude peers for expiration job that have already been marked expired (#5970) 2026-04-23 16:01:54 +02:00
Vlad
b6038e8acd [management] refactor: changeable pat rate limiting (#5946) 2026-04-23 15:13:22 +02:00
31 changed files with 455 additions and 723 deletions

View File

@@ -1,26 +0,0 @@
You are a GitHub issue resolution classifier.
Your job is to decide whether an open GitHub issue is:
- AUTO_CLOSE
- MANUAL_REVIEW
- KEEP_OPEN
Rules:
1. AUTO_CLOSE is only allowed if there is objective, hard evidence:
- a merged linked PR that clearly resolves the issue, or
- an explicit maintainer/member/owner/collaborator comment saying the issue is fixed, resolved, duplicate, or superseded
2. If there is any contradictory later evidence, do NOT AUTO_CLOSE.
3. If evidence is promising but not airtight, choose MANUAL_REVIEW.
4. If the issue still appears active or unresolved, choose KEEP_OPEN.
5. Do not invent evidence.
6. Output valid JSON only.
Maintainer-authoritative roles:
- MEMBER
- OWNER
- COLLABORATOR
Important:
- Later comments outweigh earlier ones.
- A non-maintainer saying "fixed for me" is not enough for AUTO_CLOSE.
- If uncertain, prefer MANUAL_REVIEW or KEEP_OPEN.

View File

@@ -1,78 +0,0 @@
{
"type": "object",
"additionalProperties": false,
"required": [
"decision",
"reason_code",
"confidence",
"hard_signals",
"contradictions",
"summary",
"close_comment",
"manual_review_note"
],
"properties": {
"decision": {
"type": "string",
"enum": ["AUTO_CLOSE", "MANUAL_REVIEW", "KEEP_OPEN"]
},
"reason_code": {
"type": "string",
"enum": [
"resolved_by_merged_pr",
"maintainer_confirmed_resolved",
"duplicate_confirmed",
"superseded_confirmed",
"likely_fixed_but_unconfirmed",
"still_open",
"unclear"
]
},
"confidence": {
"type": "number",
"minimum": 0,
"maximum": 1
},
"hard_signals": {
"type": "array",
"items": {
"type": "object",
"required": ["type", "url"],
"properties": {
"type": {
"type": "string",
"enum": [
"merged_pr",
"maintainer_comment",
"duplicate_reference",
"superseded_reference"
]
},
"url": { "type": "string" }
}
}
},
"contradictions": {
"type": "array",
"items": {
"type": "object",
"required": ["type", "url"],
"properties": {
"type": {
"type": "string",
"enum": [
"reporter_still_broken",
"later_unresolved_comment",
"ambiguous_pr_link",
"other"
]
},
"url": { "type": "string" }
}
}
},
"summary": { "type": "string" },
"close_comment": { "type": "string" },
"manual_review_note": { "type": "string" }
}
}

View File

@@ -1,152 +0,0 @@
import fs from "node:fs/promises";
const decisions = JSON.parse(await fs.readFile("decisions.json", "utf8"));
const dryRun = String(process.env.DRY_RUN).toLowerCase() === "true";
const headers = {
Authorization: `Bearer ${process.env.GH_TOKEN}`,
Accept: "application/vnd.github+json",
"X-GitHub-Api-Version": "2022-11-28",
};
async function rest(url, method = "GET", body) {
const res = await fetch(url, {
method,
headers,
body: body ? JSON.stringify(body) : undefined
});
if (!res.ok) throw new Error(`${res.status} ${url}: ${await res.text()}`);
return res.status === 204 ? null : res.json();
}
async function graphql(query, variables) {
const res = await fetch("https://api.github.com/graphql", {
method: "POST",
headers,
body: JSON.stringify({ query, variables })
});
if (!res.ok) throw new Error(`${res.status}: ${await res.text()}`);
const json = await res.json();
if (json.errors) throw new Error(JSON.stringify(json.errors));
return json.data;
}
async function addLabel(owner, repo, issueNumber, labels) {
return rest(
`https://api.github.com/repos/${owner}/${repo}/issues/${issueNumber}/labels`,
"POST",
{ labels }
);
}
async function addComment(owner, repo, issueNumber, body) {
return rest(
`https://api.github.com/repos/${owner}/${repo}/issues/${issueNumber}/comments`,
"POST",
{ body }
);
}
async function closeIssue(owner, repo, issueNumber) {
return rest(
`https://api.github.com/repos/${owner}/${repo}/issues/${issueNumber}`,
"PATCH",
{ state: "closed", state_reason: "completed" }
);
}
async function getIssueNodeId(owner, repo, issueNumber) {
const issue = await rest(`https://api.github.com/repos/${owner}/${repo}/issues/${issueNumber}`);
return issue.node_id;
}
async function addToProject(issueNodeId) {
const mutation = `
mutation($projectId: ID!, $contentId: ID!) {
addProjectV2ItemById(input: {projectId: $projectId, contentId: $contentId}) {
item { id }
}
}
`;
const data = await graphql(mutation, {
projectId: process.env.PROJECT_ID,
contentId: issueNodeId
});
return data.addProjectV2ItemById.item.id;
}
async function setTextField(itemId, fieldId, value) {
const mutation = `
mutation($projectId: ID!, $itemId: ID!, $fieldId: ID!, $value: String!) {
updateProjectV2ItemFieldValue(input: {
projectId: $projectId,
itemId: $itemId,
fieldId: $fieldId,
value: { text: $value }
}) {
projectV2Item { id }
}
}
`;
return graphql(mutation, {
projectId: process.env.PROJECT_ID,
itemId,
fieldId,
value
});
}
for (const d of decisions) {
const [owner, repo] = d.repository.split("/");
if (d.final_decision === "AUTO_CLOSE") {
if (dryRun) continue;
await addLabel(owner, repo, d.issue_number, ["auto-closed-resolved"]);
await addComment(
owner,
repo,
d.issue_number,
d.model.close_comment ||
"This appears resolved based on linked evidence, so were closing it automatically. Reply if this still reproduces and well reopen."
);
await closeIssue(owner, repo, d.issue_number);
}
if (d.final_decision === "MANUAL_REVIEW") {
await addLabel(owner, repo, d.issue_number, ["resolution-candidate"]);
const issueNodeId = await getIssueNodeId(owner, repo, d.issue_number);
const itemId = await addToProject(issueNodeId);
if (process.env.PROJECT_CONFIDENCE_FIELD_ID) {
await setTextField(itemId, process.env.PROJECT_CONFIDENCE_FIELD_ID, String(d.model.confidence));
}
if (process.env.PROJECT_REASON_FIELD_ID) {
await setTextField(itemId, process.env.PROJECT_REASON_FIELD_ID, d.model.reason_code);
}
if (process.env.PROJECT_EVIDENCE_FIELD_ID) {
await setTextField(itemId, process.env.PROJECT_EVIDENCE_FIELD_ID, d.issue_url);
}
if (process.env.PROJECT_LINKED_PR_FIELD_ID) {
const linked = (d.model.hard_signals || []).map(x => x.url).join(", ");
if (linked) {
await setTextField(itemId, process.env.PROJECT_LINKED_PR_FIELD_ID, linked);
}
}
if (process.env.PROJECT_REPO_FIELD_ID) {
await setTextField(itemId, process.env.PROJECT_REPO_FIELD_ID, d.repository);
}
await addComment(
owner,
repo,
d.issue_number,
d.model.manual_review_note ||
"This issue looks like a possible resolution candidate, but not with enough certainty for automatic closure. Added to the review queue."
);
}
}

View File

@@ -1,125 +0,0 @@
import fs from "node:fs/promises";
const candidates = JSON.parse(await fs.readFile("candidates.json", "utf8"));
function isMaintainerRole(role) {
return ["MEMBER", "OWNER", "COLLABORATOR"].includes(role || "");
}
function preScore(candidate) {
let score = 0;
const hardSignals = [];
const contradictions = [];
for (const t of candidate.timeline) {
const sourceIssue = t.source?.issue;
if (t.event === "cross-referenced" && sourceIssue?.pull_request?.html_url) {
hardSignals.push({
type: "merged_pr",
url: sourceIssue.html_url
});
score += 40; // provisional until PR merged state is verified
}
if (["referenced", "connected"].includes(t.event)) {
score += 10;
}
}
for (const c of candidate.comments) {
const body = c.body.toLowerCase();
if (
isMaintainerRole(c.author_association) &&
/\b(fixed|resolved|duplicate|superseded|closing)\b/.test(body)
) {
score += 25;
hardSignals.push({
type: "maintainer_comment",
url: c.html_url
});
}
if (/\b(still broken|still happening|not fixed|reproducible)\b/.test(body)) {
score -= 50;
contradictions.push({
type: "later_unresolved_comment",
url: c.html_url
});
}
}
return { score, hardSignals, contradictions };
}
async function callGitHubModel(issuePacket) {
// Replace this stub with the GitHub Models inference call used by your org.
// The workflow already has models: read permission.
return {
decision: "MANUAL_REVIEW",
reason_code: "likely_fixed_but_unconfirmed",
confidence: 0.74,
hard_signals: [],
contradictions: [],
summary: "Potential resolution candidate; evidence is not strong enough to close automatically.",
close_comment: "This appears resolved, so were closing it automatically. Reply if this is still reproducible.",
manual_review_note: "Potential resolution candidate. Please review evidence before closing."
};
}
function enforcePolicy(modelOut, pre) {
const approvedReasons = new Set([
"resolved_by_merged_pr",
"maintainer_confirmed_resolved",
"duplicate_confirmed",
"superseded_confirmed"
]);
const hasHardSignal =
(modelOut.hard_signals || []).some(s =>
["merged_pr", "maintainer_comment", "duplicate_reference", "superseded_reference"].includes(s.type)
) || pre.hardSignals.length > 0;
const hasContradiction =
(modelOut.contradictions || []).length > 0 || pre.contradictions.length > 0;
if (
modelOut.decision === "AUTO_CLOSE" &&
modelOut.confidence >= 0.97 &&
approvedReasons.has(modelOut.reason_code) &&
hasHardSignal &&
!hasContradiction
) {
return "AUTO_CLOSE";
}
if (
modelOut.decision === "MANUAL_REVIEW" ||
modelOut.confidence >= 0.60 ||
pre.score >= 25
) {
return "MANUAL_REVIEW";
}
return "KEEP_OPEN";
}
const decisions = [];
for (const candidate of candidates) {
const pre = preScore(candidate);
const modelOut = await callGitHubModel(candidate);
const finalDecision = enforcePolicy(modelOut, pre);
decisions.push({
repository: candidate.repository,
issue_number: candidate.issue.number,
issue_url: candidate.issue.html_url,
title: candidate.issue.title,
pre_score: pre.score,
final_decision: finalDecision,
model: modelOut
});
}
await fs.writeFile("decisions.json", JSON.stringify(decisions, null, 2));

View File

@@ -1,50 +0,0 @@
name: issue-resolution-triage
on:
workflow_dispatch:
inputs:
dry_run:
description: "If true, do not close issues"
required: false
default: "true"
max_issues:
description: "How many issues to process"
required: false
default: "100"
schedule:
- cron: "17 2 * * *"
permissions:
contents: read
issues: write
pull-requests: read
models: read
jobs:
triage:
runs-on: ubuntu-latest
env:
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
DRY_RUN: ${{ inputs.dry_run || 'true' }}
MAX_ISSUES: ${{ inputs.max_issues || '100' }}
REPO: ${{ github.repository }}
PROJECT_ID: ${{ vars.ISSUE_REVIEW_PROJECT_ID }}
PROJECT_STATUS_FIELD_ID: ${{ vars.PROJECT_STATUS_FIELD_ID }}
PROJECT_CONFIDENCE_FIELD_ID: ${{ vars.PROJECT_CONFIDENCE_FIELD_ID }}
PROJECT_REASON_FIELD_ID: ${{ vars.PROJECT_REASON_FIELD_ID }}
PROJECT_EVIDENCE_FIELD_ID: ${{ vars.PROJECT_EVIDENCE_FIELD_ID }}
PROJECT_LINKED_PR_FIELD_ID: ${{ vars.PROJECT_LINKED_PR_FIELD_ID }}
PROJECT_REPO_FIELD_ID: ${{ vars.PROJECT_REPO_FIELD_ID }}
PROJECT_STATUS_OPTION_NEEDS_REVIEW_ID: ${{ vars.PROJECT_STATUS_OPTION_NEEDS_REVIEW_ID }}
steps:
- uses: actions/checkout@v4
- uses: actions/setup-node@v4
with:
node-version: "20"
- run: npm ci
- run: node scripts/fetch-candidates.mjs
- run: node scripts/classify-candidates.mjs
- run: node scripts/apply-decisions.mjs

View File

@@ -13,7 +13,6 @@ import (
const (
defaultResolvConfPath = "/etc/resolv.conf"
nsswitchConfPath = "/etc/nsswitch.conf"
)
type resolvConf struct {

View File

@@ -46,12 +46,12 @@ type restoreHostManager interface {
}
func newHostManager(wgInterface string) (hostManager, error) {
osManager, reason, err := getOSDNSManagerType()
osManager, err := getOSDNSManagerType()
if err != nil {
return nil, fmt.Errorf("get os dns manager type: %w", err)
}
log.Infof("System DNS manager discovered: %s (%s)", osManager, reason)
log.Infof("System DNS manager discovered: %s", osManager)
mgr, err := newHostManagerFromType(wgInterface, osManager)
// need to explicitly return nil mgr on error to avoid returning a non-nil interface containing a nil value
if err != nil {
@@ -74,49 +74,17 @@ func newHostManagerFromType(wgInterface string, osManager osManagerType) (restor
}
}
func getOSDNSManagerType() (osManagerType, string, error) {
resolved := isSystemdResolvedRunning()
nss := isLibnssResolveUsed()
stub := checkStub()
// Prefer systemd-resolved whenever it owns libc resolution, regardless of
// who wrote /etc/resolv.conf. File-mode rewrites do not affect lookups
// that go through nss-resolve, and in foreign mode they can loop back
// through resolved as an upstream.
if resolved && (nss || stub) {
return systemdManager, fmt.Sprintf("systemd-resolved active (nss-resolve=%t, stub=%t)", nss, stub), nil
}
mgr, reason, rejected, err := scanResolvConfHeader()
if err != nil {
return 0, "", err
}
if reason != "" {
return mgr, reason, nil
}
fallback := fmt.Sprintf("no manager matched (resolved=%t, nss-resolve=%t, stub=%t)", resolved, nss, stub)
if len(rejected) > 0 {
fallback += "; rejected: " + strings.Join(rejected, ", ")
}
return fileManager, fallback, nil
}
// scanResolvConfHeader walks /etc/resolv.conf header comments and returns the
// matching manager. If reason is empty the caller should pick file mode and
// use rejected for diagnostics.
func scanResolvConfHeader() (osManagerType, string, []string, error) {
func getOSDNSManagerType() (osManagerType, error) {
file, err := os.Open(defaultResolvConfPath)
if err != nil {
return 0, "", nil, fmt.Errorf("unable to open %s for checking owner, got error: %w", defaultResolvConfPath, err)
return 0, fmt.Errorf("unable to open %s for checking owner, got error: %w", defaultResolvConfPath, err)
}
defer func() {
if cerr := file.Close(); cerr != nil {
log.Errorf("close file %s: %s", defaultResolvConfPath, cerr)
if err := file.Close(); err != nil {
log.Errorf("close file %s: %s", defaultResolvConfPath, err)
}
}()
var rejected []string
scanner := bufio.NewScanner(file)
for scanner.Scan() {
text := scanner.Text()
@@ -124,48 +92,41 @@ func scanResolvConfHeader() (osManagerType, string, []string, error) {
continue
}
if text[0] != '#' {
break
return fileManager, nil
}
if mgr, reason, rej := matchResolvConfHeader(text); reason != "" {
return mgr, reason, nil, nil
} else if rej != "" {
rejected = append(rejected, rej)
if strings.Contains(text, fileGeneratedResolvConfContentHeader) {
return netbirdManager, nil
}
if strings.Contains(text, "NetworkManager") && isDbusListenerRunning(networkManagerDest, networkManagerDbusObjectNode) && isNetworkManagerSupported() {
return networkManager, nil
}
if strings.Contains(text, "systemd-resolved") && isSystemdResolvedRunning() {
if checkStub() {
return systemdManager, nil
} else {
return fileManager, nil
}
}
if strings.Contains(text, "resolvconf") {
if isSystemdResolveConfMode() {
return systemdManager, nil
}
return resolvConfManager, nil
}
}
if err := scanner.Err(); err != nil && err != io.EOF {
return 0, "", nil, fmt.Errorf("scan: %w", err)
return 0, fmt.Errorf("scan: %w", err)
}
return 0, "", rejected, nil
return fileManager, nil
}
// matchResolvConfHeader inspects a single comment line. Returns either a
// definitive (manager, reason) or a non-empty rejected diagnostic.
func matchResolvConfHeader(text string) (osManagerType, string, string) {
if strings.Contains(text, fileGeneratedResolvConfContentHeader) {
return netbirdManager, "netbird-managed resolv.conf header detected", ""
}
if strings.Contains(text, "NetworkManager") {
if isDbusListenerRunning(networkManagerDest, networkManagerDbusObjectNode) && isNetworkManagerSupported() {
return networkManager, "NetworkManager header + supported version on dbus", ""
}
return 0, "", "NetworkManager header (no dbus or unsupported version)"
}
if strings.Contains(text, "resolvconf") {
if isSystemdResolveConfMode() {
return systemdManager, "resolvconf header in systemd-resolved compatibility mode", ""
}
return resolvConfManager, "resolvconf header detected", ""
}
return 0, "", ""
}
// checkStub reports whether systemd-resolved's stub (127.0.0.53) is listed
// in /etc/resolv.conf. On parse failure we assume it is, to avoid dropping
// into file mode while resolved is active.
// checkStub checks if the stub resolver is disabled in systemd-resolved. If it is disabled, we fall back to file manager.
func checkStub() bool {
rConf, err := parseDefaultResolvConf()
if err != nil {
log.Warnf("failed to parse resolv conf, assuming stub is active: %s", err)
log.Warnf("failed to parse resolv conf: %s", err)
return true
}
@@ -178,36 +139,3 @@ func checkStub() bool {
return false
}
// isLibnssResolveUsed reports whether nss-resolve is listed before dns on
// the hosts: line of /etc/nsswitch.conf. When it is, libc lookups are
// delegated to systemd-resolved regardless of /etc/resolv.conf.
func isLibnssResolveUsed() bool {
bs, err := os.ReadFile(nsswitchConfPath)
if err != nil {
log.Debugf("read %s: %v", nsswitchConfPath, err)
return false
}
return parseNsswitchResolveAhead(bs)
}
func parseNsswitchResolveAhead(data []byte) bool {
for _, line := range strings.Split(string(data), "\n") {
if i := strings.IndexByte(line, '#'); i >= 0 {
line = line[:i]
}
fields := strings.Fields(line)
if len(fields) < 2 || fields[0] != "hosts:" {
continue
}
for _, module := range fields[1:] {
switch module {
case "dns":
return false
case "resolve":
return true
}
}
}
return false
}

View File

@@ -1,76 +0,0 @@
//go:build (linux && !android) || freebsd
package dns
import "testing"
func TestParseNsswitchResolveAhead(t *testing.T) {
tests := []struct {
name string
in string
want bool
}{
{
name: "resolve before dns with action token",
in: "hosts: mymachines resolve [!UNAVAIL=return] files myhostname dns\n",
want: true,
},
{
name: "dns before resolve",
in: "hosts: files mdns4_minimal [NOTFOUND=return] dns resolve\n",
want: false,
},
{
name: "debian default with only dns",
in: "hosts: files mdns4_minimal [NOTFOUND=return] dns mymachines\n",
want: false,
},
{
name: "neither resolve nor dns",
in: "hosts: files myhostname\n",
want: false,
},
{
name: "no hosts line",
in: "passwd: files systemd\ngroup: files systemd\n",
want: false,
},
{
name: "empty",
in: "",
want: false,
},
{
name: "comments and blank lines ignored",
in: "# comment\n\n# another\nhosts: resolve dns\n",
want: true,
},
{
name: "trailing inline comment",
in: "hosts: resolve [!UNAVAIL=return] dns # fallback\n",
want: true,
},
{
name: "hosts token must be the first field",
in: " hosts: resolve dns\n",
want: true,
},
{
name: "other db line mentioning resolve is ignored",
in: "networks: resolve\nhosts: dns\n",
want: false,
},
{
name: "only resolve, no dns",
in: "hosts: files resolve\n",
want: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := parseNsswitchResolveAhead([]byte(tt.in)); got != tt.want {
t.Errorf("parseNsswitchResolveAhead() = %v, want %v", got, tt.want)
}
})
}
}

View File

@@ -30,6 +30,7 @@ import (
nbcache "github.com/netbirdio/netbird/management/server/cache"
nbContext "github.com/netbirdio/netbird/management/server/context"
nbhttp "github.com/netbirdio/netbird/management/server/http"
"github.com/netbirdio/netbird/management/server/http/middleware"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/telemetry"
mgmtProto "github.com/netbirdio/netbird/shared/management/proto"
@@ -109,7 +110,7 @@ func (s *BaseServer) EventStore() activity.Store {
func (s *BaseServer) APIHandler() http.Handler {
return Create(s, func() http.Handler {
httpAPIHandler, err := nbhttp.NewAPIHandler(context.Background(), s.AccountManager(), s.NetworksManager(), s.ResourcesManager(), s.RoutesManager(), s.GroupsManager(), s.GeoLocationManager(), s.AuthManager(), s.Metrics(), s.IntegratedValidator(), s.ProxyController(), s.PermissionsManager(), s.PeersManager(), s.SettingsManager(), s.ZonesManager(), s.RecordsManager(), s.NetworkMapController(), s.IdpManager(), s.ServiceManager(), s.ReverseProxyDomainManager(), s.AccessLogsManager(), s.ReverseProxyGRPCServer(), s.Config.ReverseProxy.TrustedHTTPProxies)
httpAPIHandler, err := nbhttp.NewAPIHandler(context.Background(), s.AccountManager(), s.NetworksManager(), s.ResourcesManager(), s.RoutesManager(), s.GroupsManager(), s.GeoLocationManager(), s.AuthManager(), s.Metrics(), s.IntegratedValidator(), s.ProxyController(), s.PermissionsManager(), s.PeersManager(), s.SettingsManager(), s.ZonesManager(), s.RecordsManager(), s.NetworkMapController(), s.IdpManager(), s.ServiceManager(), s.ReverseProxyDomainManager(), s.AccessLogsManager(), s.ReverseProxyGRPCServer(), s.Config.ReverseProxy.TrustedHTTPProxies, s.RateLimiter())
if err != nil {
log.Fatalf("failed to create API handler: %v", err)
}
@@ -117,6 +118,15 @@ func (s *BaseServer) APIHandler() http.Handler {
})
}
func (s *BaseServer) RateLimiter() *middleware.APIRateLimiter {
return Create(s, func() *middleware.APIRateLimiter {
cfg, enabled := middleware.RateLimiterConfigFromEnv()
limiter := middleware.NewAPIRateLimiter(cfg)
limiter.SetEnabled(enabled)
return limiter
})
}
func (s *BaseServer) GRPCServer() *grpc.Server {
return Create(s, func() *grpc.Server {
trustedPeers := s.Config.ReverseProxy.TrustedPeers

View File

@@ -2311,6 +2311,29 @@ func TestAccount_GetExpiredPeers(t *testing.T) {
}
}
func TestGetExpiredPeers_SkipsAlreadyExpired(t *testing.T) {
ctx := context.Background()
testStore, cleanUp, err := store.NewTestStoreFromSQL(ctx, "testdata/store_with_expired_peers.sql", t.TempDir())
t.Cleanup(cleanUp)
require.NoError(t, err)
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
// Verify the already-expired peer is excluded at the store level
peers, err := testStore.GetAccountPeersWithExpiration(ctx, store.LockingStrengthNone, accountID)
require.NoError(t, err)
for _, peer := range peers {
assert.NotEqual(t, "cg05lnblo1hkg2j514p0", peer.ID, "already expired peer should be excluded by the store query")
assert.False(t, peer.Status.LoginExpired, "returned peers should not already be marked as login expired")
}
// Only the non-expired peer with expiration enabled should be returned
require.Len(t, peers, 1)
assert.Equal(t, "notexpired01", peers[0].ID)
}
func TestAccount_GetInactivePeers(t *testing.T) {
type test struct {
name string
@@ -3230,6 +3253,13 @@ func setupNetworkMapTest(t *testing.T) (*DefaultAccountManager, *update_channel.
return manager, updateManager, account, peer1, peer2, peer3
}
// peerUpdateTimeout bounds how long peerShouldReceiveUpdate and its outer
// wrappers wait for an expected update message. Sized for slow CI runners
// (MySQL, FreeBSD, loaded sqlite) where the channel publish can take
// seconds. Only runs down on failure; passing tests return immediately
// when the channel delivers.
const peerUpdateTimeout = 5 * time.Second
func peerShouldNotReceiveUpdate(t *testing.T, updateMessage <-chan *network_map.UpdateMessage) {
t.Helper()
select {
@@ -3248,7 +3278,7 @@ func peerShouldReceiveUpdate(t *testing.T, updateMessage <-chan *network_map.Upd
if msg == nil {
t.Errorf("Received nil update message, expected valid message")
}
case <-time.After(500 * time.Millisecond):
case <-time.After(peerUpdateTimeout):
t.Error("Timed out waiting for update message")
}
}

View File

@@ -458,7 +458,7 @@ func TestDNSAccountPeersUpdate(t *testing.T) {
select {
case <-done:
case <-time.After(time.Second):
case <-time.After(peerUpdateTimeout):
t.Error("timeout waiting for peerShouldReceiveUpdate")
}
})
@@ -478,7 +478,7 @@ func TestDNSAccountPeersUpdate(t *testing.T) {
select {
case <-done:
case <-time.After(time.Second):
case <-time.After(peerUpdateTimeout):
t.Error("timeout waiting for peerShouldReceiveUpdate")
}
})
@@ -518,7 +518,7 @@ func TestDNSAccountPeersUpdate(t *testing.T) {
select {
case <-done:
case <-time.After(time.Second):
case <-time.After(peerUpdateTimeout):
t.Error("timeout waiting for peerShouldReceiveUpdate")
}
})

View File

@@ -620,7 +620,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
select {
case <-done:
case <-time.After(time.Second):
case <-time.After(peerUpdateTimeout):
t.Error("timeout waiting for peerShouldReceiveUpdate")
}
})
@@ -638,7 +638,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
select {
case <-done:
case <-time.After(time.Second):
case <-time.After(peerUpdateTimeout):
t.Error("timeout waiting for peerShouldReceiveUpdate")
}
})
@@ -656,7 +656,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
select {
case <-done:
case <-time.After(time.Second):
case <-time.After(peerUpdateTimeout):
t.Error("timeout waiting for peerShouldReceiveUpdate")
}
})
@@ -689,7 +689,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
select {
case <-done:
case <-time.After(time.Second):
case <-time.After(peerUpdateTimeout):
t.Error("timeout waiting for peerShouldReceiveUpdate")
}
})
@@ -730,7 +730,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
select {
case <-done:
case <-time.After(time.Second):
case <-time.After(peerUpdateTimeout):
t.Error("timeout waiting for peerShouldReceiveUpdate")
}
})
@@ -757,7 +757,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
select {
case <-done:
case <-time.After(time.Second):
case <-time.After(peerUpdateTimeout):
t.Error("timeout waiting for peerShouldReceiveUpdate")
}
})
@@ -804,7 +804,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
select {
case <-done:
case <-time.After(time.Second):
case <-time.After(peerUpdateTimeout):
t.Error("timeout waiting for peerShouldReceiveUpdate")
}
})

View File

@@ -5,9 +5,6 @@ import (
"fmt"
"net/http"
"net/netip"
"os"
"strconv"
"time"
"github.com/gorilla/mux"
"github.com/rs/cors"
@@ -66,14 +63,11 @@ import (
)
const (
apiPrefix = "/api"
rateLimitingEnabledKey = "NB_API_RATE_LIMITING_ENABLED"
rateLimitingBurstKey = "NB_API_RATE_LIMITING_BURST"
rateLimitingRPMKey = "NB_API_RATE_LIMITING_RPM"
apiPrefix = "/api"
)
// NewAPIHandler creates the Management service HTTP API handler registering all the available endpoints.
func NewAPIHandler(ctx context.Context, accountManager account.Manager, networksManager nbnetworks.Manager, resourceManager resources.Manager, routerManager routers.Manager, groupsManager nbgroups.Manager, LocationManager geolocation.Geolocation, authManager auth.Manager, appMetrics telemetry.AppMetrics, integratedValidator integrated_validator.IntegratedValidator, proxyController port_forwarding.Controller, permissionsManager permissions.Manager, peersManager nbpeers.Manager, settingsManager settings.Manager, zManager zones.Manager, rManager records.Manager, networkMapController network_map.Controller, idpManager idpmanager.Manager, serviceManager service.Manager, reverseProxyDomainManager *manager.Manager, reverseProxyAccessLogsManager accesslogs.Manager, proxyGRPCServer *nbgrpc.ProxyServiceServer, trustedHTTPProxies []netip.Prefix) (http.Handler, error) {
func NewAPIHandler(ctx context.Context, accountManager account.Manager, networksManager nbnetworks.Manager, resourceManager resources.Manager, routerManager routers.Manager, groupsManager nbgroups.Manager, LocationManager geolocation.Geolocation, authManager auth.Manager, appMetrics telemetry.AppMetrics, integratedValidator integrated_validator.IntegratedValidator, proxyController port_forwarding.Controller, permissionsManager permissions.Manager, peersManager nbpeers.Manager, settingsManager settings.Manager, zManager zones.Manager, rManager records.Manager, networkMapController network_map.Controller, idpManager idpmanager.Manager, serviceManager service.Manager, reverseProxyDomainManager *manager.Manager, reverseProxyAccessLogsManager accesslogs.Manager, proxyGRPCServer *nbgrpc.ProxyServiceServer, trustedHTTPProxies []netip.Prefix, rateLimiter *middleware.APIRateLimiter) (http.Handler, error) {
// Register bypass paths for unauthenticated endpoints
if err := bypass.AddBypassPath("/api/instance"); err != nil {
@@ -94,34 +88,10 @@ func NewAPIHandler(ctx context.Context, accountManager account.Manager, networks
return nil, fmt.Errorf("failed to add bypass path: %w", err)
}
var rateLimitingConfig *middleware.RateLimiterConfig
if os.Getenv(rateLimitingEnabledKey) == "true" {
rpm := 6
if v := os.Getenv(rateLimitingRPMKey); v != "" {
value, err := strconv.Atoi(v)
if err != nil {
log.Warnf("parsing %s env var: %v, using default %d", rateLimitingRPMKey, err, rpm)
} else {
rpm = value
}
}
burst := 500
if v := os.Getenv(rateLimitingBurstKey); v != "" {
value, err := strconv.Atoi(v)
if err != nil {
log.Warnf("parsing %s env var: %v, using default %d", rateLimitingBurstKey, err, burst)
} else {
burst = value
}
}
rateLimitingConfig = &middleware.RateLimiterConfig{
RequestsPerMinute: float64(rpm),
Burst: burst,
CleanupInterval: 6 * time.Hour,
LimiterTTL: 24 * time.Hour,
}
if rateLimiter == nil {
log.Warn("NewAPIHandler: nil rate limiter, rate limiting disabled")
rateLimiter = middleware.NewAPIRateLimiter(nil)
rateLimiter.SetEnabled(false)
}
authMiddleware := middleware.NewAuthMiddleware(
@@ -129,7 +99,7 @@ func NewAPIHandler(ctx context.Context, accountManager account.Manager, networks
accountManager.GetAccountIDFromUserAuth,
accountManager.SyncUserJWTGroups,
accountManager.GetUserFromUserAuth,
rateLimitingConfig,
rateLimiter,
appMetrics.GetMeter(),
)

View File

@@ -43,14 +43,9 @@ func NewAuthMiddleware(
ensureAccount EnsureAccountFunc,
syncUserJWTGroups SyncUserJWTGroupsFunc,
getUserFromUserAuth GetUserFromUserAuthFunc,
rateLimiterConfig *RateLimiterConfig,
rateLimiter *APIRateLimiter,
meter metric.Meter,
) *AuthMiddleware {
var rateLimiter *APIRateLimiter
if rateLimiterConfig != nil {
rateLimiter = NewAPIRateLimiter(rateLimiterConfig)
}
var patUsageTracker *PATUsageTracker
if meter != nil {
var err error
@@ -181,10 +176,8 @@ func (m *AuthMiddleware) checkPATFromRequest(r *http.Request, authHeaderParts []
m.patUsageTracker.IncrementUsage(token)
}
if m.rateLimiter != nil && !isTerraformRequest(r) {
if !m.rateLimiter.Allow(token) {
return status.Errorf(status.TooManyRequests, "too many requests")
}
if !isTerraformRequest(r) && !m.rateLimiter.Allow(token) {
return status.Errorf(status.TooManyRequests, "too many requests")
}
ctx := r.Context()

View File

@@ -196,6 +196,8 @@ func TestAuthMiddleware_Handler(t *testing.T) {
GetPATInfoFunc: mockGetAccountInfoFromPAT,
}
disabledLimiter := NewAPIRateLimiter(nil)
disabledLimiter.SetEnabled(false)
authMiddleware := NewAuthMiddleware(
mockAuth,
func(ctx context.Context, userAuth nbauth.UserAuth) (string, string, error) {
@@ -207,7 +209,7 @@ func TestAuthMiddleware_Handler(t *testing.T) {
func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) {
return &types.User{}, nil
},
nil,
disabledLimiter,
nil,
)
@@ -266,7 +268,7 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) {
func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) {
return &types.User{}, nil
},
rateLimitConfig,
NewAPIRateLimiter(rateLimitConfig),
nil,
)
@@ -318,7 +320,7 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) {
func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) {
return &types.User{}, nil
},
rateLimitConfig,
NewAPIRateLimiter(rateLimitConfig),
nil,
)
@@ -361,7 +363,7 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) {
func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) {
return &types.User{}, nil
},
rateLimitConfig,
NewAPIRateLimiter(rateLimitConfig),
nil,
)
@@ -405,7 +407,7 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) {
func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) {
return &types.User{}, nil
},
rateLimitConfig,
NewAPIRateLimiter(rateLimitConfig),
nil,
)
@@ -469,7 +471,7 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) {
func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) {
return &types.User{}, nil
},
rateLimitConfig,
NewAPIRateLimiter(rateLimitConfig),
nil,
)
@@ -528,7 +530,7 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) {
func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) {
return &types.User{}, nil
},
rateLimitConfig,
NewAPIRateLimiter(rateLimitConfig),
nil,
)
@@ -583,7 +585,7 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) {
func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) {
return &types.User{}, nil
},
rateLimitConfig,
NewAPIRateLimiter(rateLimitConfig),
nil,
)
@@ -670,6 +672,8 @@ func TestAuthMiddleware_Handler_Child(t *testing.T) {
GetPATInfoFunc: mockGetAccountInfoFromPAT,
}
disabledLimiter := NewAPIRateLimiter(nil)
disabledLimiter.SetEnabled(false)
authMiddleware := NewAuthMiddleware(
mockAuth,
func(ctx context.Context, userAuth nbauth.UserAuth) (string, string, error) {
@@ -681,7 +685,7 @@ func TestAuthMiddleware_Handler_Child(t *testing.T) {
func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) {
return &types.User{}, nil
},
nil,
disabledLimiter,
nil,
)

View File

@@ -4,14 +4,27 @@ import (
"context"
"net"
"net/http"
"os"
"strconv"
"sync"
"sync/atomic"
"time"
log "github.com/sirupsen/logrus"
"golang.org/x/time/rate"
"github.com/netbirdio/netbird/shared/management/http/util"
)
const (
RateLimitingEnabledEnv = "NB_API_RATE_LIMITING_ENABLED"
RateLimitingBurstEnv = "NB_API_RATE_LIMITING_BURST"
RateLimitingRPMEnv = "NB_API_RATE_LIMITING_RPM"
defaultAPIRPM = 6
defaultAPIBurst = 500
)
// RateLimiterConfig holds configuration for the API rate limiter
type RateLimiterConfig struct {
// RequestsPerMinute defines the rate at which tokens are replenished
@@ -34,6 +47,43 @@ func DefaultRateLimiterConfig() *RateLimiterConfig {
}
}
func RateLimiterConfigFromEnv() (cfg *RateLimiterConfig, enabled bool) {
rpm := defaultAPIRPM
if v := os.Getenv(RateLimitingRPMEnv); v != "" {
value, err := strconv.Atoi(v)
if err != nil {
log.Warnf("parsing %s env var: %v, using default %d", RateLimitingRPMEnv, err, rpm)
} else {
rpm = value
}
}
if rpm <= 0 {
log.Warnf("%s=%d is non-positive, using default %d", RateLimitingRPMEnv, rpm, defaultAPIRPM)
rpm = defaultAPIRPM
}
burst := defaultAPIBurst
if v := os.Getenv(RateLimitingBurstEnv); v != "" {
value, err := strconv.Atoi(v)
if err != nil {
log.Warnf("parsing %s env var: %v, using default %d", RateLimitingBurstEnv, err, burst)
} else {
burst = value
}
}
if burst <= 0 {
log.Warnf("%s=%d is non-positive, using default %d", RateLimitingBurstEnv, burst, defaultAPIBurst)
burst = defaultAPIBurst
}
return &RateLimiterConfig{
RequestsPerMinute: float64(rpm),
Burst: burst,
CleanupInterval: 6 * time.Hour,
LimiterTTL: 24 * time.Hour,
}, os.Getenv(RateLimitingEnabledEnv) == "true"
}
// limiterEntry holds a rate limiter and its last access time
type limiterEntry struct {
limiter *rate.Limiter
@@ -46,6 +96,7 @@ type APIRateLimiter struct {
limiters map[string]*limiterEntry
mu sync.RWMutex
stopChan chan struct{}
enabled atomic.Bool
}
// NewAPIRateLimiter creates a new API rate limiter with the given configuration
@@ -59,14 +110,53 @@ func NewAPIRateLimiter(config *RateLimiterConfig) *APIRateLimiter {
limiters: make(map[string]*limiterEntry),
stopChan: make(chan struct{}),
}
rl.enabled.Store(true)
go rl.cleanupLoop()
return rl
}
func (rl *APIRateLimiter) SetEnabled(enabled bool) {
rl.enabled.Store(enabled)
}
func (rl *APIRateLimiter) Enabled() bool {
return rl.enabled.Load()
}
func (rl *APIRateLimiter) UpdateConfig(config *RateLimiterConfig) {
if config == nil {
return
}
if config.RequestsPerMinute <= 0 || config.Burst <= 0 {
log.Warnf("UpdateConfig: ignoring invalid rpm=%v burst=%d", config.RequestsPerMinute, config.Burst)
return
}
newRPS := rate.Limit(config.RequestsPerMinute / 60.0)
newBurst := config.Burst
rl.mu.Lock()
rl.config.RequestsPerMinute = config.RequestsPerMinute
rl.config.Burst = newBurst
snapshot := make([]*rate.Limiter, 0, len(rl.limiters))
for _, entry := range rl.limiters {
snapshot = append(snapshot, entry.limiter)
}
rl.mu.Unlock()
for _, l := range snapshot {
l.SetLimit(newRPS)
l.SetBurst(newBurst)
}
}
// Allow checks if a request for the given key (token) is allowed
func (rl *APIRateLimiter) Allow(key string) bool {
if !rl.enabled.Load() {
return true
}
limiter := rl.getLimiter(key)
return limiter.Allow()
}
@@ -74,6 +164,9 @@ func (rl *APIRateLimiter) Allow(key string) bool {
// Wait blocks until the rate limiter allows another request for the given key
// Returns an error if the context is canceled
func (rl *APIRateLimiter) Wait(ctx context.Context, key string) error {
if !rl.enabled.Load() {
return nil
}
limiter := rl.getLimiter(key)
return limiter.Wait(ctx)
}
@@ -153,6 +246,10 @@ func (rl *APIRateLimiter) Reset(key string) {
// Returns 429 Too Many Requests if the rate limit is exceeded.
func (rl *APIRateLimiter) Middleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if !rl.enabled.Load() {
next.ServeHTTP(w, r)
return
}
clientIP := getClientIP(r)
if !rl.Allow(clientIP) {
util.WriteErrorResponse("rate limit exceeded, please try again later", http.StatusTooManyRequests, w)

View File

@@ -1,8 +1,10 @@
package middleware
import (
"fmt"
"net/http"
"net/http/httptest"
"sync"
"testing"
"time"
@@ -156,3 +158,172 @@ func TestAPIRateLimiter_Reset(t *testing.T) {
// Should be allowed again
assert.True(t, rl.Allow("test-key"))
}
func TestAPIRateLimiter_SetEnabled(t *testing.T) {
rl := NewAPIRateLimiter(&RateLimiterConfig{
RequestsPerMinute: 60,
Burst: 1,
CleanupInterval: time.Minute,
LimiterTTL: time.Minute,
})
defer rl.Stop()
assert.True(t, rl.Allow("key"))
assert.False(t, rl.Allow("key"), "burst exhausted while enabled")
rl.SetEnabled(false)
assert.False(t, rl.Enabled())
for i := 0; i < 5; i++ {
assert.True(t, rl.Allow("key"), "disabled limiter must always allow")
}
rl.SetEnabled(true)
assert.True(t, rl.Enabled())
assert.False(t, rl.Allow("key"), "re-enabled limiter retains prior bucket state")
}
func TestAPIRateLimiter_UpdateConfig(t *testing.T) {
rl := NewAPIRateLimiter(&RateLimiterConfig{
RequestsPerMinute: 60,
Burst: 2,
CleanupInterval: time.Minute,
LimiterTTL: time.Minute,
})
defer rl.Stop()
assert.True(t, rl.Allow("k1"))
assert.True(t, rl.Allow("k1"))
assert.False(t, rl.Allow("k1"), "burst=2 exhausted")
rl.UpdateConfig(&RateLimiterConfig{
RequestsPerMinute: 60,
Burst: 10,
CleanupInterval: time.Minute,
LimiterTTL: time.Minute,
})
// New burst applies to existing keys in place; bucket refills up to new burst over time,
// but importantly newly-added keys use the updated config immediately.
assert.True(t, rl.Allow("k2"))
for i := 0; i < 9; i++ {
assert.True(t, rl.Allow("k2"))
}
assert.False(t, rl.Allow("k2"), "new burst=10 exhausted")
}
func TestAPIRateLimiter_UpdateConfig_NilIgnored(t *testing.T) {
rl := NewAPIRateLimiter(&RateLimiterConfig{
RequestsPerMinute: 60,
Burst: 1,
CleanupInterval: time.Minute,
LimiterTTL: time.Minute,
})
defer rl.Stop()
rl.UpdateConfig(nil) // must not panic or zero the config
assert.True(t, rl.Allow("k"))
assert.False(t, rl.Allow("k"))
}
func TestAPIRateLimiter_UpdateConfig_NonPositiveIgnored(t *testing.T) {
rl := NewAPIRateLimiter(&RateLimiterConfig{
RequestsPerMinute: 60,
Burst: 1,
CleanupInterval: time.Minute,
LimiterTTL: time.Minute,
})
defer rl.Stop()
assert.True(t, rl.Allow("k"))
assert.False(t, rl.Allow("k"))
rl.UpdateConfig(&RateLimiterConfig{RequestsPerMinute: 0, Burst: 0, CleanupInterval: time.Minute, LimiterTTL: time.Minute})
rl.UpdateConfig(&RateLimiterConfig{RequestsPerMinute: -1, Burst: 5, CleanupInterval: time.Minute, LimiterTTL: time.Minute})
rl.UpdateConfig(&RateLimiterConfig{RequestsPerMinute: 60, Burst: -1, CleanupInterval: time.Minute, LimiterTTL: time.Minute})
rl.Reset("k")
assert.True(t, rl.Allow("k"))
assert.False(t, rl.Allow("k"), "burst should still be 1 — invalid UpdateConfig calls were ignored")
}
func TestAPIRateLimiter_ConcurrentAllowAndUpdate(t *testing.T) {
rl := NewAPIRateLimiter(&RateLimiterConfig{
RequestsPerMinute: 600,
Burst: 10,
CleanupInterval: time.Minute,
LimiterTTL: time.Minute,
})
defer rl.Stop()
var wg sync.WaitGroup
stop := make(chan struct{})
for i := 0; i < 8; i++ {
wg.Add(1)
go func(id int) {
defer wg.Done()
key := fmt.Sprintf("k%d", id)
for {
select {
case <-stop:
return
default:
rl.Allow(key)
}
}
}(i)
}
wg.Add(1)
go func() {
defer wg.Done()
for i := 0; i < 200; i++ {
select {
case <-stop:
return
default:
rl.UpdateConfig(&RateLimiterConfig{
RequestsPerMinute: float64(30 + (i % 90)),
Burst: 1 + (i % 20),
CleanupInterval: time.Minute,
LimiterTTL: time.Minute,
})
rl.SetEnabled(i%2 == 0)
}
}
}()
time.Sleep(100 * time.Millisecond)
close(stop)
wg.Wait()
}
func TestRateLimiterConfigFromEnv(t *testing.T) {
t.Setenv(RateLimitingEnabledEnv, "true")
t.Setenv(RateLimitingRPMEnv, "42")
t.Setenv(RateLimitingBurstEnv, "7")
cfg, enabled := RateLimiterConfigFromEnv()
assert.True(t, enabled)
assert.Equal(t, float64(42), cfg.RequestsPerMinute)
assert.Equal(t, 7, cfg.Burst)
t.Setenv(RateLimitingEnabledEnv, "false")
_, enabled = RateLimiterConfigFromEnv()
assert.False(t, enabled)
t.Setenv(RateLimitingEnabledEnv, "")
t.Setenv(RateLimitingRPMEnv, "")
t.Setenv(RateLimitingBurstEnv, "")
cfg, enabled = RateLimiterConfigFromEnv()
assert.False(t, enabled)
assert.Equal(t, float64(defaultAPIRPM), cfg.RequestsPerMinute)
assert.Equal(t, defaultAPIBurst, cfg.Burst)
t.Setenv(RateLimitingRPMEnv, "0")
t.Setenv(RateLimitingBurstEnv, "-5")
cfg, _ = RateLimiterConfigFromEnv()
assert.Equal(t, float64(defaultAPIRPM), cfg.RequestsPerMinute, "non-positive rpm must fall back to default")
assert.Equal(t, defaultAPIBurst, cfg.Burst, "non-positive burst must fall back to default")
}

View File

@@ -135,7 +135,7 @@ func BuildApiBlackBoxWithDBState(t testing_tools.TB, sqlFile string, expectedPee
customZonesManager := zonesManager.NewManager(store, am, permissionsManager, "")
zoneRecordsManager := recordsManager.NewManager(store, am, permissionsManager)
apiHandler, err := http2.NewAPIHandler(context.Background(), am, networksManager, resourcesManager, routersManager, groupsManager, geoMock, authManagerMock, metrics, validatorMock, proxyController, permissionsManager, peersManager, settingsManager, customZonesManager, zoneRecordsManager, networkMapController, nil, serviceManager, nil, nil, nil, nil)
apiHandler, err := http2.NewAPIHandler(context.Background(), am, networksManager, resourcesManager, routersManager, groupsManager, geoMock, authManagerMock, metrics, validatorMock, proxyController, permissionsManager, peersManager, settingsManager, customZonesManager, zoneRecordsManager, networkMapController, nil, serviceManager, nil, nil, nil, nil, nil)
if err != nil {
t.Fatalf("Failed to create API handler: %v", err)
}
@@ -264,7 +264,7 @@ func BuildApiBlackBoxWithDBStateAndPeerChannel(t testing_tools.TB, sqlFile strin
customZonesManager := zonesManager.NewManager(store, am, permissionsManager, "")
zoneRecordsManager := recordsManager.NewManager(store, am, permissionsManager)
apiHandler, err := http2.NewAPIHandler(context.Background(), am, networksManager, resourcesManager, routersManager, groupsManager, geoMock, authManagerMock, metrics, validatorMock, proxyController, permissionsManager, peersManager, settingsManager, customZonesManager, zoneRecordsManager, networkMapController, nil, serviceManager, nil, nil, nil, nil)
apiHandler, err := http2.NewAPIHandler(context.Background(), am, networksManager, resourcesManager, routersManager, groupsManager, geoMock, authManagerMock, metrics, validatorMock, proxyController, permissionsManager, peersManager, settingsManager, customZonesManager, zoneRecordsManager, networkMapController, nil, serviceManager, nil, nil, nil, nil, nil)
if err != nil {
t.Fatalf("Failed to create API handler: %v", err)
}

View File

@@ -267,8 +267,8 @@ func Test_SyncProtocol(t *testing.T) {
}
// expired peers come separately.
if len(networkMap.GetOfflinePeers()) != 1 {
t.Fatal("expecting SyncResponse to have NetworkMap with 1 offline peer")
if len(networkMap.GetOfflinePeers()) != 2 {
t.Fatal("expecting SyncResponse to have NetworkMap with 2 offline peer")
}
expiredPeerPubKey := "RlSy2vzoG2HyMBTUImXOiVhCBiiBa5qD5xzMxkiFDW4="

View File

@@ -1087,7 +1087,7 @@ func TestNameServerAccountPeersUpdate(t *testing.T) {
select {
case <-done:
case <-time.After(time.Second):
case <-time.After(peerUpdateTimeout):
t.Error("timeout waiting for peerShouldReceiveUpdate")
}
})
@@ -1105,7 +1105,7 @@ func TestNameServerAccountPeersUpdate(t *testing.T) {
select {
case <-done:
case <-time.After(time.Second):
case <-time.After(peerUpdateTimeout):
t.Error("timeout waiting for peerShouldReceiveUpdate")
}
})

View File

@@ -1405,6 +1405,10 @@ func (am *DefaultAccountManager) getExpiredPeers(ctx context.Context, accountID
var peers []*nbpeer.Peer
for _, peer := range peersWithExpiry {
if peer.Status.LoginExpired {
continue
}
expired, _ := peer.LoginExpired(settings.PeerLoginExpiration)
if expired {
peers = append(peers, peer)

View File

@@ -1907,7 +1907,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
select {
case <-done:
case <-time.After(time.Second):
case <-time.After(peerUpdateTimeout):
t.Error("timeout waiting for peerShouldReceiveUpdate")
}
})
@@ -1929,7 +1929,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
select {
case <-done:
case <-time.After(time.Second):
case <-time.After(peerUpdateTimeout):
t.Error("timeout waiting for peerShouldReceiveUpdate")
}
})
@@ -1994,7 +1994,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
select {
case <-done:
case <-time.After(time.Second):
case <-time.After(peerUpdateTimeout):
t.Error("timeout waiting for peerShouldReceiveUpdate")
}
})
@@ -2012,7 +2012,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
select {
case <-done:
case <-time.After(time.Second):
case <-time.After(peerUpdateTimeout):
t.Error("timeout waiting for peerShouldReceiveUpdate")
}
})
@@ -2058,7 +2058,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
select {
case <-done:
case <-time.After(time.Second):
case <-time.After(peerUpdateTimeout):
t.Error("timeout waiting for peerShouldReceiveUpdate")
}
})
@@ -2076,7 +2076,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
select {
case <-done:
case <-time.After(time.Second):
case <-time.After(peerUpdateTimeout):
t.Error("timeout waiting for peerShouldReceiveUpdate")
}
})
@@ -2113,7 +2113,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
select {
case <-done:
case <-time.After(time.Second):
case <-time.After(peerUpdateTimeout):
t.Error("timeout waiting for peerShouldReceiveUpdate")
}
})
@@ -2131,7 +2131,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
select {
case <-done:
case <-time.After(time.Second):
case <-time.After(peerUpdateTimeout):
t.Error("timeout waiting for peerShouldReceiveUpdate")
}
})

View File

@@ -1231,7 +1231,7 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
select {
case <-done:
case <-time.After(time.Second):
case <-time.After(peerUpdateTimeout):
t.Error("timeout waiting for peerShouldReceiveUpdate")
}
})
@@ -1263,7 +1263,7 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
select {
case <-done:
case <-time.After(time.Second):
case <-time.After(peerUpdateTimeout):
t.Error("timeout waiting for peerShouldReceiveUpdate")
}
})
@@ -1294,7 +1294,7 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
select {
case <-done:
case <-time.After(time.Second):
case <-time.After(peerUpdateTimeout):
t.Error("timeout waiting for peerShouldReceiveUpdate")
}
})
@@ -1314,7 +1314,7 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
select {
case <-done:
case <-time.After(time.Second):
case <-time.After(peerUpdateTimeout):
t.Error("timeout waiting for peerShouldReceiveUpdate")
}
})
@@ -1355,7 +1355,7 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
select {
case <-done:
case <-time.After(time.Second):
case <-time.After(peerUpdateTimeout):
t.Error("timeout waiting for peerShouldReceiveUpdate")
}
})
@@ -1373,7 +1373,7 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
select {
case <-done:
case <-time.After(time.Second):
case <-time.After(peerUpdateTimeout):
t.Error("timeout waiting for peerShouldReceiveUpdate")
}
@@ -1393,7 +1393,7 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
select {
case <-done:
case <-time.After(time.Second):
case <-time.After(peerUpdateTimeout):
t.Error("timeout waiting for peerShouldReceiveUpdate")
}
})

View File

@@ -244,7 +244,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
select {
case <-done:
case <-time.After(time.Second):
case <-time.After(peerUpdateTimeout):
t.Error("timeout waiting for peerShouldReceiveUpdate")
}
})
@@ -273,7 +273,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
select {
case <-done:
case <-time.After(time.Second):
case <-time.After(peerUpdateTimeout):
t.Error("timeout waiting for peerShouldReceiveUpdate")
}
})
@@ -292,7 +292,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
select {
case <-done:
case <-time.After(time.Second):
case <-time.After(peerUpdateTimeout):
t.Error("timeout waiting for peerShouldReceiveUpdate")
}
})
@@ -395,7 +395,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
select {
case <-done:
case <-time.After(time.Second):
case <-time.After(peerUpdateTimeout):
t.Error("timeout waiting for peerShouldReceiveUpdate")
}
})
@@ -438,7 +438,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
select {
case <-done:
case <-time.After(time.Second):
case <-time.After(peerUpdateTimeout):
t.Error("timeout waiting for peerShouldReceiveUpdate")
}
})

View File

@@ -2070,7 +2070,7 @@ func TestRouteAccountPeersUpdate(t *testing.T) {
select {
case <-done:
case <-time.After(time.Second):
case <-time.After(peerUpdateTimeout):
t.Error("timeout waiting for peerShouldReceiveUpdate")
}
@@ -2107,7 +2107,7 @@ func TestRouteAccountPeersUpdate(t *testing.T) {
select {
case <-done:
case <-time.After(time.Second):
case <-time.After(peerUpdateTimeout):
t.Error("timeout waiting for peerShouldReceiveUpdate")
}
})
@@ -2127,7 +2127,7 @@ func TestRouteAccountPeersUpdate(t *testing.T) {
select {
case <-done:
case <-time.After(time.Second):
case <-time.After(peerUpdateTimeout):
t.Error("timeout waiting for peerShouldReceiveUpdate")
}
})
@@ -2145,7 +2145,7 @@ func TestRouteAccountPeersUpdate(t *testing.T) {
select {
case <-done:
case <-time.After(time.Second):
case <-time.After(peerUpdateTimeout):
t.Error("timeout waiting for peerShouldReceiveUpdate")
}
})
@@ -2185,7 +2185,7 @@ func TestRouteAccountPeersUpdate(t *testing.T) {
select {
case <-done:
case <-time.After(time.Second):
case <-time.After(peerUpdateTimeout):
t.Error("timeout waiting for peerShouldReceiveUpdate")
}
})
@@ -2225,7 +2225,7 @@ func TestRouteAccountPeersUpdate(t *testing.T) {
select {
case <-done:
case <-time.After(time.Second):
case <-time.After(peerUpdateTimeout):
t.Error("timeout waiting for peerShouldReceiveUpdate")
}
})

View File

@@ -3310,7 +3310,7 @@ func (s *SqlStore) GetAccountPeersWithExpiration(ctx context.Context, lockStreng
var peers []*nbpeer.Peer
result := tx.
Where("login_expiration_enabled = ? AND user_id IS NOT NULL AND user_id != ''", true).
Where("login_expiration_enabled = ? AND peer_status_login_expired != ? AND user_id IS NOT NULL AND user_id != ''", true, true).
Find(&peers, accountIDCondition, accountID)
if err := result.Error; err != nil {
log.WithContext(ctx).Errorf("failed to get peers with expiration from the store: %s", result.Error)

View File

@@ -2729,7 +2729,7 @@ func TestSqlStore_GetAccountPeers(t *testing.T) {
{
name: "should retrieve peers for an existing account ID",
accountID: "bf1c8084-ba50-4ce7-9439-34653001fc3b",
expectedCount: 4,
expectedCount: 5,
},
{
name: "should return no peers for a non-existing account ID",
@@ -2751,7 +2751,7 @@ func TestSqlStore_GetAccountPeers(t *testing.T) {
name: "should filter peers by partial name",
accountID: "bf1c8084-ba50-4ce7-9439-34653001fc3b",
nameFilter: "host",
expectedCount: 3,
expectedCount: 4,
},
{
name: "should filter peers by ip",
@@ -2777,14 +2777,16 @@ func TestSqlStore_GetAccountPeersWithExpiration(t *testing.T) {
require.NoError(t, err)
tests := []struct {
name string
accountID string
expectedCount int
name string
accountID string
expectedCount int
expectedPeerIDs []string
}{
{
name: "should retrieve peers with expiration for an existing account ID",
accountID: "bf1c8084-ba50-4ce7-9439-34653001fc3b",
expectedCount: 1,
name: "should retrieve only non-expired peers with expiration enabled",
accountID: "bf1c8084-ba50-4ce7-9439-34653001fc3b",
expectedCount: 1,
expectedPeerIDs: []string{"notexpired01"},
},
{
name: "should return no peers with expiration for a non-existing account ID",
@@ -2803,10 +2805,30 @@ func TestSqlStore_GetAccountPeersWithExpiration(t *testing.T) {
peers, err := store.GetAccountPeersWithExpiration(context.Background(), LockingStrengthNone, tt.accountID)
require.NoError(t, err)
require.Len(t, peers, tt.expectedCount)
for i, peer := range peers {
assert.Equal(t, tt.expectedPeerIDs[i], peer.ID)
}
})
}
}
func TestSqlStore_GetAccountPeersWithExpiration_ExcludesAlreadyExpired(t *testing.T) {
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store_with_expired_peers.sql", t.TempDir())
t.Cleanup(cleanup)
require.NoError(t, err)
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
peers, err := store.GetAccountPeersWithExpiration(context.Background(), LockingStrengthNone, accountID)
require.NoError(t, err)
// Verify the already-expired peer (cg05lnblo1hkg2j514p0) is not returned
for _, peer := range peers {
assert.NotEqual(t, "cg05lnblo1hkg2j514p0", peer.ID, "already expired peer should not be returned")
assert.False(t, peer.Status.LoginExpired, "returned peers should not have LoginExpired set")
}
}
func TestSqlStore_GetAccountPeersWithInactivity(t *testing.T) {
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store_with_expired_peers.sql", t.TempDir())
t.Cleanup(cleanup)
@@ -2887,7 +2909,7 @@ func TestSqlStore_GetUserPeers(t *testing.T) {
name: "should retrieve peers for another valid account ID and user ID",
accountID: "bf1c8084-ba50-4ce7-9439-34653001fc3b",
userID: "edafee4e-63fb-11ec-90d6-0242ac120003",
expectedCount: 2,
expectedCount: 3,
},
{
name: "should return no peers for existing account ID with empty user ID",

View File

@@ -31,6 +31,7 @@ INSERT INTO peers VALUES('cfvprsrlo1hqoo49ohog','bf1c8084-ba50-4ce7-9439-3465300
INSERT INTO peers VALUES('cg05lnblo1hkg2j514p0','bf1c8084-ba50-4ce7-9439-34653001fc3b','RlSy2vzoG2HyMBTUImXOiVhCBiiBa5qD5xzMxkiFDW4=','','"100.64.39.54"','expiredhost','linux','Linux','22.04','x86_64','Ubuntu','','development','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'expiredhost','expiredhost','2023-03-02 09:19:57.276717255+01:00',0,1,0,'edafee4e-63fb-11ec-90d6-0242ac120003','ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIMbK5ZXJsGOOWoBT4OmkPtgdPZe2Q7bDuS/zjn2CZxhK',0,1,0,'2023-03-02 09:14:21.791679181+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0);
INSERT INTO peers VALUES('cg3161rlo1hs9cq94gdg','bf1c8084-ba50-4ce7-9439-34653001fc3b','mVABSKj28gv+JRsf7e0NEGKgSOGTfU/nPB2cpuG56HU=','','"100.64.117.96"','testhost','linux','Linux','22.04','x86_64','Ubuntu','','development','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'testhost','testhost','2023-03-06 18:21:27.252010027+01:00',0,0,0,'edafee4e-63fb-11ec-90d6-0242ac120003','ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAINWvvUkFFcrj48CWTkNUb/do/n52i1L5dH4DhGu+4ZuM',0,0,0,'2023-03-07 09:02:47.442857106+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0);
INSERT INTO peers VALUES('csrnkiq7qv9d8aitqd50','bf1c8084-ba50-4ce7-9439-34653001fc3b','nVABSKj28gv+JRsf7e0NEGKgSOGTfU/nPB2cpuG56HX=','','"100.64.117.97"','testhost','linux','Linux','22.04','x86_64','Ubuntu','','development','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'testhost','testhost-1','2023-03-06 18:21:27.252010027+01:00',0,0,0,'f4f6d672-63fb-11ec-90d6-0242ac120003','ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAINWvvUkFFcrj48CWTkNUb/do/n52i1L5dH4DhGu+4ZuM',0,0,1,'2023-03-07 09:02:47.442857106+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0);
INSERT INTO peers VALUES('notexpired01','bf1c8084-ba50-4ce7-9439-34653001fc3b','oVABSKj28gv+JRsf7e0NEGKgSOGTfU/nPB2cpuG56HY=','','"100.64.117.98"','activehost','linux','Linux','22.04','x86_64','Ubuntu','','development','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'activehost','activehost','2023-03-06 18:21:27.252010027+01:00',0,0,0,'edafee4e-63fb-11ec-90d6-0242ac120003','ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAINWvvUkFFcrj48CWTkNUb/do/n52i1L5dH4DhGu+4ZuM',0,1,0,'2023-03-07 09:02:47.442857106+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0);
INSERT INTO users VALUES('f4f6d672-63fb-11ec-90d6-0242ac120003','bf1c8084-ba50-4ce7-9439-34653001fc3b','user',0,0,'','[]',0,NULL,'2024-10-02 17:00:32.528196+02:00','api',0,'');
INSERT INTO users VALUES('edafee4e-63fb-11ec-90d6-0242ac120003','bf1c8084-ba50-4ce7-9439-34653001fc3b','admin',0,0,'','[]',0,NULL,'2024-10-02 17:00:32.528196+02:00','api',0,'');
INSERT INTO installations VALUES(1,'');

View File

@@ -1586,7 +1586,7 @@ func TestUserAccountPeersUpdate(t *testing.T) {
select {
case <-done:
case <-time.After(time.Second):
case <-time.After(peerUpdateTimeout):
t.Error("timeout waiting for peerShouldReceiveUpdate")
}
})
@@ -1609,7 +1609,7 @@ func TestUserAccountPeersUpdate(t *testing.T) {
select {
case <-done:
case <-time.After(time.Second):
case <-time.After(peerUpdateTimeout):
t.Error("timeout waiting for peerShouldReceiveUpdate")
}
})

View File

@@ -433,6 +433,7 @@ func setSessionCookie(w http.ResponseWriter, token string, expiration time.Durat
http.SetCookie(w, &http.Cookie{
Name: auth.SessionCookieName,
Value: token,
Path: "/",
HttpOnly: true,
Secure: true,
SameSite: http.SameSiteLaxMode,

View File

@@ -391,6 +391,15 @@ func TestProtect_SchemeAuthRedirectsWithCookie(t *testing.T) {
assert.Equal(t, http.SameSiteLaxMode, sessionCookie.SameSite)
}
func TestSetSessionCookieHasRootPath(t *testing.T) {
w := httptest.NewRecorder()
setSessionCookie(w, "test-token", time.Hour)
cookies := w.Result().Cookies()
require.Len(t, cookies, 1)
assert.Equal(t, "/", cookies[0].Path, "session cookie must be scoped to root so it applies to all paths")
}
func TestProtect_FailedAuthDoesNotSetCookie(t *testing.T) {
mw := NewMiddleware(log.StandardLogger(), nil, nil)
kp := generateTestKeyPair(t)