Compare commits


14 Commits

Author SHA1 Message Date
Ashley Mensah
6b8e40f78d initial commit - workflow yaml, prompts and schemas 2026-04-23 18:48:38 +02:00
Zoltan Papp
5da05ecca6 [client] increase gRPC health check timeout to 5s (#5961)
Bump the IsHealthy() context timeout from 1s to 5s for both the
management and signal gRPC clients to reduce false negatives on
slower or congested connections.
2026-04-22 20:54:18 +02:00
Viktor Liu
801de8c68d [client] Add TTL-based refresh to mgmt DNS cache via handler chain (#5945) 2026-04-22 15:10:14 +02:00
Viktor Liu
a822a33240 [self-hosted] Use cscli lapi status for CrowdSec readiness in installer (#5949) 2026-04-22 10:35:22 +02:00
Bethuel Mmbaga
57b23c5b25 [management] Propagate context changes to upstream middleware (#5956) 2026-04-21 23:06:52 +03:00
Zoltan Papp
1165058fad [client] fix port collision in TestUpload (#5950)
* [debug] fix port collision in TestUpload

TestUpload hardcoded :8080, so it failed deterministically when anything
was already on that port and collided across concurrent test runs.
Bind a :0 listener in the test to get a kernel-assigned free port, and
add Server.Serve so tests can hand the listener in without reaching
into unexported state.

* [debug] drop test-only Server.Serve, use SERVER_ADDRESS env

The previous commit added a Server.Serve method on the upload-server,
used only by TestUpload. That left production with an unused function.
Reserve an ephemeral loopback port in the test, release it, and pass
the address through SERVER_ADDRESS (which the server already reads).
A small wait helper ensures the server is accepting connections before
the upload runs, so the close/rebind gap does not cause a false failure.
2026-04-21 19:07:20 +02:00
Zoltan Papp
703353d354 [flow] fix goroutine leak in TestReceive_ProtocolErrorStreamReconnect (#5951)
The Receive goroutine could outlive the test and call t.Logf after
teardown, panicking with "Log in goroutine after ... has completed".
Register a cleanup that waits for the goroutine to exit; ordering is
LIFO so it runs after client.Close, which is what unblocks Receive.
2026-04-21 19:06:47 +02:00
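The LIFO ordering this fix relies on can be demonstrated with a plain cleanup stack standing in for testing.T: register the goroutine wait first so it runs after the cleanup that closes the client, since Close is what unblocks the goroutine. Names here are illustrative.

```go
package main

// cleanupStack mimics testing.T's last-in-first-out cleanup ordering.
type cleanupStack struct{ fns []func() }

func (c *cleanupStack) add(fn func()) { c.fns = append(c.fns, fn) }

func (c *cleanupStack) run() {
	for i := len(c.fns) - 1; i >= 0; i-- {
		c.fns[i]()
	}
}

// demoOrder wires a blocked "Receive" goroutine to the stack and returns
// the order the cleanups ran in: close first, then the wait.
func demoOrder() []string {
	var order []string
	var c cleanupStack
	stop := make(chan struct{}) // stands in for client.Close
	done := make(chan struct{}) // closed when the goroutine exits
	go func() { <-stop; close(done) }()
	c.add(func() { <-done; order = append(order, "wait-goroutine") }) // registered first, runs last
	c.add(func() { close(stop); order = append(order, "close-client") })
	c.run()
	return order
}
```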
Zoltan Papp
2fb50aef6b [client] allow UDP packet loss in TestICEBind_HandlesConcurrentMixedTraffic (#5953)
The test writes 500 packets per family and asserted exact-count
delivery within a 5s window, even though its own comment says "Some
packet loss is acceptable for UDP". On FreeBSD/QEMU runners the writer
loops cannot always finish all 500 before the 5s deadline closes the
readers (we have seen 411/500 in CI).

The real assertion of this test is the routing check — IPv4 peer only
gets v4-prefixed packets, IPv6 peer only gets v6-prefixed packets — which remains
strict. Replace the exact-count assertions with a >=80% delivery
threshold so runner speed variance no longer causes false failures.
2026-04-21 19:05:58 +02:00
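The relaxed assertion reduces to a simple ratio check, sketched below with an illustrative helper name: delivery passes at or above a fraction of what was sent instead of requiring an exact count.

```go
package main

// meetsDeliveryThreshold reports whether enough packets arrived,
// tolerating UDP loss: received must be at least minRatio of sent.
func meetsDeliveryThreshold(received, sent int, minRatio float64) bool {
	return float64(received) >= minRatio*float64(sent)
}
```

With a 0.80 threshold, the 411/500 result seen in CI passes while genuinely lossy runs still fail.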
Vlad
eb3aa96257 [management] check policy for changes before actual db update (#5405) 2026-04-21 18:37:04 +02:00
Viktor Liu
064ec1c832 [client] Trust wg interface in firewalld to bypass owner-flagged chains (#5928) 2026-04-21 17:57:16 +02:00
Viktor Liu
75e408f51c [client] Prefer systemd-resolved stub over file mode regardless of resolv.conf header (#5935) 2026-04-21 17:56:56 +02:00
Zoltan Papp
5a89e6621b [client] Suppress ICE signaling (#5820)
* [client] Suppress ICE signaling and periodic offers in force-relay mode

When NB_FORCE_RELAY is enabled, skip WorkerICE creation entirely,
suppress ICE credentials in offer/answer messages, disable the
periodic ICE candidate monitor, and fix isConnectedOnAllWay to
only check relay status so the guard stops sending unnecessary offers.

* [client] Dynamically suppress ICE based on remote peer's offer credentials

Track whether the remote peer includes ICE credentials in its
offers/answers. When remote stops sending ICE credentials, skip
ICE listener dispatch, suppress ICE credentials in responses, and
exclude ICE from the guard connectivity check. When remote resumes
sending ICE credentials, re-enable all ICE behavior.

* [client] Fix nil SessionID panic and force ICE teardown on relay-only transition

Fix nil pointer dereference in signalOfferAnswer when SessionID is nil
(relay-only offers). Close stale ICE agent immediately when remote peer
stops sending ICE credentials to avoid traffic black-hole during the
ICE disconnect timeout.

* [client] Add relay-only fallback check when ICE is unavailable

Ensure the peer supports a relay connection when ICE is disabled, to prevent connectivity issues.

* [client] Add tri-state connection status to guard for smarter ICE retry (#5828)

* [client] Add tri-state connection status to guard for smarter ICE retry

Refactor isConnectedOnAllWay to return a ConnStatus enum (Connected,
Disconnected, PartiallyConnected) instead of a boolean. When relay is
up but ICE is not (PartiallyConnected), limit ICE offers to 3 retries
with exponential backoff then fall back to hourly attempts, reducing
unnecessary signaling traffic. Fully disconnected peers continue to
retry aggressively. External events (relay/ICE disconnect, signal/relay
reconnect) reset retry state to give ICE a fresh chance.

* [client] Clarify guard ICE retry state and trace log trigger

Split iceRetryState.attempt into shouldRetry (pure predicate) and
enterHourlyMode (explicit state transition) so the caller in
reconnectLoopWithRetry reads top-to-bottom. Restore the original
trace-log behavior in isConnectedOnAllWay so it only logs on full
disconnection, not on the new PartiallyConnected state.

* [client] Extract pure evalConnStatus and add unit tests

Split isConnectedOnAllWay into a thin method that snapshots state and
a pure evalConnStatus helper that takes a connStatusInputs struct, so
the tri-state decision logic can be exercised without constructing
full Worker or Handshaker objects. Add table-driven tests covering
force-relay, ICE-unavailable and fully-available code paths, plus
unit tests for iceRetryState budget/hourly transitions and reset.

* [client] Improve grammar in logs and refactor ICE credential checks
2026-04-21 15:52:08 +02:00
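The tri-state decision described above can be reduced to a small pure function. This is a hypothetical sketch, not the real evalConnStatus (which takes a connStatusInputs struct): relay-only setups count as Connected when ICE is not expected, while a live relay with a down-but-expected ICE path is PartiallyConnected.

```go
package main

// ConnStatus is the tri-state result replacing the old boolean.
type ConnStatus int

const (
	Disconnected ConnStatus = iota
	PartiallyConnected
	Connected
)

// evalConnStatus decides the connection state from relay health, whether
// ICE is expected (disabled in force-relay mode or by the remote peer's
// offers), and ICE health.
func evalConnStatus(relayUp, iceExpected, iceUp bool) ConnStatus {
	switch {
	case relayUp && (!iceExpected || iceUp):
		return Connected
	case relayUp:
		return PartiallyConnected // relay up, ICE expected but down: limited retries
	default:
		return Disconnected // retry aggressively
	}
}
```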
Misha Bragin
06dfa9d4a5 [management] replace mailru/easyjson with netbirdio/easyjson fork (#5938) 2026-04-21 13:59:35 +02:00
Misha Bragin
45d9ee52c0 [self-hosted] add reverse proxy retention fields to combined YAML (#5930) 2026-04-21 10:21:11 +02:00
131 changed files with 5405 additions and 10736 deletions

View File

@@ -0,0 +1,26 @@
You are a GitHub issue resolution classifier.
Your job is to decide whether an open GitHub issue is:
- AUTO_CLOSE
- MANUAL_REVIEW
- KEEP_OPEN
Rules:
1. AUTO_CLOSE is only allowed if there is objective, hard evidence:
- a merged linked PR that clearly resolves the issue, or
- an explicit maintainer/member/owner/collaborator comment saying the issue is fixed, resolved, duplicate, or superseded
2. If there is any contradictory later evidence, do NOT AUTO_CLOSE.
3. If evidence is promising but not airtight, choose MANUAL_REVIEW.
4. If the issue still appears active or unresolved, choose KEEP_OPEN.
5. Do not invent evidence.
6. Output valid JSON only.
Maintainer-authoritative roles:
- MEMBER
- OWNER
- COLLABORATOR
Important:
- Later comments outweigh earlier ones.
- A non-maintainer saying "fixed for me" is not enough for AUTO_CLOSE.
- If uncertain, prefer MANUAL_REVIEW or KEEP_OPEN.

View File

@@ -0,0 +1,78 @@
{
"type": "object",
"additionalProperties": false,
"required": [
"decision",
"reason_code",
"confidence",
"hard_signals",
"contradictions",
"summary",
"close_comment",
"manual_review_note"
],
"properties": {
"decision": {
"type": "string",
"enum": ["AUTO_CLOSE", "MANUAL_REVIEW", "KEEP_OPEN"]
},
"reason_code": {
"type": "string",
"enum": [
"resolved_by_merged_pr",
"maintainer_confirmed_resolved",
"duplicate_confirmed",
"superseded_confirmed",
"likely_fixed_but_unconfirmed",
"still_open",
"unclear"
]
},
"confidence": {
"type": "number",
"minimum": 0,
"maximum": 1
},
"hard_signals": {
"type": "array",
"items": {
"type": "object",
"required": ["type", "url"],
"properties": {
"type": {
"type": "string",
"enum": [
"merged_pr",
"maintainer_comment",
"duplicate_reference",
"superseded_reference"
]
},
"url": { "type": "string" }
}
}
},
"contradictions": {
"type": "array",
"items": {
"type": "object",
"required": ["type", "url"],
"properties": {
"type": {
"type": "string",
"enum": [
"reporter_still_broken",
"later_unresolved_comment",
"ambiguous_pr_link",
"other"
]
},
"url": { "type": "string" }
}
}
},
"summary": { "type": "string" },
"close_comment": { "type": "string" },
"manual_review_note": { "type": "string" }
}
}

View File

@@ -0,0 +1,152 @@
import fs from "node:fs/promises";
const decisions = JSON.parse(await fs.readFile("decisions.json", "utf8"));
const dryRun = String(process.env.DRY_RUN).toLowerCase() === "true";
const headers = {
Authorization: `Bearer ${process.env.GH_TOKEN}`,
Accept: "application/vnd.github+json",
"X-GitHub-Api-Version": "2022-11-28",
};
async function rest(url, method = "GET", body) {
const res = await fetch(url, {
method,
headers,
body: body ? JSON.stringify(body) : undefined
});
if (!res.ok) throw new Error(`${res.status} ${url}: ${await res.text()}`);
return res.status === 204 ? null : res.json();
}
async function graphql(query, variables) {
const res = await fetch("https://api.github.com/graphql", {
method: "POST",
headers,
body: JSON.stringify({ query, variables })
});
if (!res.ok) throw new Error(`${res.status}: ${await res.text()}`);
const json = await res.json();
if (json.errors) throw new Error(JSON.stringify(json.errors));
return json.data;
}
async function addLabel(owner, repo, issueNumber, labels) {
return rest(
`https://api.github.com/repos/${owner}/${repo}/issues/${issueNumber}/labels`,
"POST",
{ labels }
);
}
async function addComment(owner, repo, issueNumber, body) {
return rest(
`https://api.github.com/repos/${owner}/${repo}/issues/${issueNumber}/comments`,
"POST",
{ body }
);
}
async function closeIssue(owner, repo, issueNumber) {
return rest(
`https://api.github.com/repos/${owner}/${repo}/issues/${issueNumber}`,
"PATCH",
{ state: "closed", state_reason: "completed" }
);
}
async function getIssueNodeId(owner, repo, issueNumber) {
const issue = await rest(`https://api.github.com/repos/${owner}/${repo}/issues/${issueNumber}`);
return issue.node_id;
}
async function addToProject(issueNodeId) {
const mutation = `
mutation($projectId: ID!, $contentId: ID!) {
addProjectV2ItemById(input: {projectId: $projectId, contentId: $contentId}) {
item { id }
}
}
`;
const data = await graphql(mutation, {
projectId: process.env.PROJECT_ID,
contentId: issueNodeId
});
return data.addProjectV2ItemById.item.id;
}
async function setTextField(itemId, fieldId, value) {
const mutation = `
mutation($projectId: ID!, $itemId: ID!, $fieldId: ID!, $value: String!) {
updateProjectV2ItemFieldValue(input: {
projectId: $projectId,
itemId: $itemId,
fieldId: $fieldId,
value: { text: $value }
}) {
projectV2Item { id }
}
}
`;
return graphql(mutation, {
projectId: process.env.PROJECT_ID,
itemId,
fieldId,
value
});
}
for (const d of decisions) {
const [owner, repo] = d.repository.split("/");
if (d.final_decision === "AUTO_CLOSE") {
if (dryRun) continue;
await addLabel(owner, repo, d.issue_number, ["auto-closed-resolved"]);
await addComment(
owner,
repo,
d.issue_number,
d.model.close_comment ||
"This appears resolved based on linked evidence, so we're closing it automatically. Reply if this still reproduces and we'll reopen."
);
await closeIssue(owner, repo, d.issue_number);
}
if (d.final_decision === "MANUAL_REVIEW") {
await addLabel(owner, repo, d.issue_number, ["resolution-candidate"]);
const issueNodeId = await getIssueNodeId(owner, repo, d.issue_number);
const itemId = await addToProject(issueNodeId);
if (process.env.PROJECT_CONFIDENCE_FIELD_ID) {
await setTextField(itemId, process.env.PROJECT_CONFIDENCE_FIELD_ID, String(d.model.confidence));
}
if (process.env.PROJECT_REASON_FIELD_ID) {
await setTextField(itemId, process.env.PROJECT_REASON_FIELD_ID, d.model.reason_code);
}
if (process.env.PROJECT_EVIDENCE_FIELD_ID) {
await setTextField(itemId, process.env.PROJECT_EVIDENCE_FIELD_ID, d.issue_url);
}
if (process.env.PROJECT_LINKED_PR_FIELD_ID) {
const linked = (d.model.hard_signals || []).map(x => x.url).join(", ");
if (linked) {
await setTextField(itemId, process.env.PROJECT_LINKED_PR_FIELD_ID, linked);
}
}
if (process.env.PROJECT_REPO_FIELD_ID) {
await setTextField(itemId, process.env.PROJECT_REPO_FIELD_ID, d.repository);
}
await addComment(
owner,
repo,
d.issue_number,
d.model.manual_review_note ||
"This issue looks like a possible resolution candidate, but the evidence is not certain enough for automatic closure. Added to the review queue."
);
}
}

View File

@@ -0,0 +1,125 @@
import fs from "node:fs/promises";
const candidates = JSON.parse(await fs.readFile("candidates.json", "utf8"));
function isMaintainerRole(role) {
return ["MEMBER", "OWNER", "COLLABORATOR"].includes(role || "");
}
function preScore(candidate) {
let score = 0;
const hardSignals = [];
const contradictions = [];
for (const t of candidate.timeline) {
const sourceIssue = t.source?.issue;
if (t.event === "cross-referenced" && sourceIssue?.pull_request?.html_url) {
hardSignals.push({
type: "merged_pr",
url: sourceIssue.html_url
});
score += 40; // provisional until PR merged state is verified
}
if (["referenced", "connected"].includes(t.event)) {
score += 10;
}
}
for (const c of candidate.comments) {
const body = c.body.toLowerCase();
if (
isMaintainerRole(c.author_association) &&
/\b(fixed|resolved|duplicate|superseded|closing)\b/.test(body)
) {
score += 25;
hardSignals.push({
type: "maintainer_comment",
url: c.html_url
});
}
if (/\b(still broken|still happening|not fixed|reproducible)\b/.test(body)) {
score -= 50;
contradictions.push({
type: "later_unresolved_comment",
url: c.html_url
});
}
}
return { score, hardSignals, contradictions };
}
async function callGitHubModel(issuePacket) {
// Replace this stub with the GitHub Models inference call used by your org.
// The workflow already has models: read permission.
return {
decision: "MANUAL_REVIEW",
reason_code: "likely_fixed_but_unconfirmed",
confidence: 0.74,
hard_signals: [],
contradictions: [],
summary: "Potential resolution candidate; evidence is not strong enough to close automatically.",
close_comment: "This appears resolved, so we're closing it automatically. Reply if this is still reproducible.",
manual_review_note: "Potential resolution candidate. Please review evidence before closing."
};
}
function enforcePolicy(modelOut, pre) {
const approvedReasons = new Set([
"resolved_by_merged_pr",
"maintainer_confirmed_resolved",
"duplicate_confirmed",
"superseded_confirmed"
]);
const hasHardSignal =
(modelOut.hard_signals || []).some(s =>
["merged_pr", "maintainer_comment", "duplicate_reference", "superseded_reference"].includes(s.type)
) || pre.hardSignals.length > 0;
const hasContradiction =
(modelOut.contradictions || []).length > 0 || pre.contradictions.length > 0;
if (
modelOut.decision === "AUTO_CLOSE" &&
modelOut.confidence >= 0.97 &&
approvedReasons.has(modelOut.reason_code) &&
hasHardSignal &&
!hasContradiction
) {
return "AUTO_CLOSE";
}
if (
modelOut.decision === "MANUAL_REVIEW" ||
modelOut.confidence >= 0.60 ||
pre.score >= 25
) {
return "MANUAL_REVIEW";
}
return "KEEP_OPEN";
}
const decisions = [];
for (const candidate of candidates) {
const pre = preScore(candidate);
const modelOut = await callGitHubModel(candidate);
const finalDecision = enforcePolicy(modelOut, pre);
decisions.push({
repository: candidate.repository,
issue_number: candidate.issue.number,
issue_url: candidate.issue.html_url,
title: candidate.issue.title,
pre_score: pre.score,
final_decision: finalDecision,
model: modelOut
});
}
await fs.writeFile("decisions.json", JSON.stringify(decisions, null, 2));

View File

@@ -0,0 +1,50 @@
name: issue-resolution-triage
on:
workflow_dispatch:
inputs:
dry_run:
description: "If true, do not close issues"
required: false
default: "true"
max_issues:
description: "How many issues to process"
required: false
default: "100"
schedule:
- cron: "17 2 * * *"
permissions:
contents: read
issues: write
pull-requests: read
models: read
jobs:
triage:
runs-on: ubuntu-latest
env:
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
DRY_RUN: ${{ inputs.dry_run || 'true' }}
MAX_ISSUES: ${{ inputs.max_issues || '100' }}
REPO: ${{ github.repository }}
PROJECT_ID: ${{ vars.ISSUE_REVIEW_PROJECT_ID }}
PROJECT_STATUS_FIELD_ID: ${{ vars.PROJECT_STATUS_FIELD_ID }}
PROJECT_CONFIDENCE_FIELD_ID: ${{ vars.PROJECT_CONFIDENCE_FIELD_ID }}
PROJECT_REASON_FIELD_ID: ${{ vars.PROJECT_REASON_FIELD_ID }}
PROJECT_EVIDENCE_FIELD_ID: ${{ vars.PROJECT_EVIDENCE_FIELD_ID }}
PROJECT_LINKED_PR_FIELD_ID: ${{ vars.PROJECT_LINKED_PR_FIELD_ID }}
PROJECT_REPO_FIELD_ID: ${{ vars.PROJECT_REPO_FIELD_ID }}
PROJECT_STATUS_OPTION_NEEDS_REVIEW_ID: ${{ vars.PROJECT_STATUS_OPTION_NEEDS_REVIEW_ID }}
steps:
- uses: actions/checkout@v4
- uses: actions/setup-node@v4
with:
node-version: "20"
- run: npm ci
- run: node scripts/fetch-candidates.mjs
- run: node scripts/classify-candidates.mjs
- run: node scripts/apply-decisions.mjs

View File

@@ -151,7 +151,6 @@ func init() {
rootCmd.AddCommand(logoutCmd)
rootCmd.AddCommand(versionCmd)
rootCmd.AddCommand(sshCmd)
rootCmd.AddCommand(vncCmd)
rootCmd.AddCommand(networksCMD)
rootCmd.AddCommand(forwardingRulesCmd)
rootCmd.AddCommand(debugCmd)

View File

@@ -36,10 +36,7 @@ const (
enableSSHLocalPortForwardFlag = "enable-ssh-local-port-forwarding"
enableSSHRemotePortForwardFlag = "enable-ssh-remote-port-forwarding"
disableSSHAuthFlag = "disable-ssh-auth"
jwtCacheTTLFlag = "jwt-cache-ttl"
// Alias for backward compatibility.
sshJWTCacheTTLFlag = "ssh-jwt-cache-ttl"
sshJWTCacheTTLFlag = "ssh-jwt-cache-ttl"
)
var (
@@ -64,7 +61,7 @@ var (
enableSSHLocalPortForward bool
enableSSHRemotePortForward bool
disableSSHAuth bool
jwtCacheTTL int
sshJWTCacheTTL int
)
func init() {
@@ -74,9 +71,7 @@ func init() {
upCmd.PersistentFlags().BoolVar(&enableSSHLocalPortForward, enableSSHLocalPortForwardFlag, false, "Enable local port forwarding for SSH server")
upCmd.PersistentFlags().BoolVar(&enableSSHRemotePortForward, enableSSHRemotePortForwardFlag, false, "Enable remote port forwarding for SSH server")
upCmd.PersistentFlags().BoolVar(&disableSSHAuth, disableSSHAuthFlag, false, "Disable SSH authentication")
upCmd.PersistentFlags().IntVar(&jwtCacheTTL, jwtCacheTTLFlag, 0, "JWT token cache TTL in seconds (0=disabled)")
upCmd.PersistentFlags().IntVar(&jwtCacheTTL, sshJWTCacheTTLFlag, 0, "JWT token cache TTL in seconds (alias for --jwt-cache-ttl)")
_ = upCmd.PersistentFlags().MarkDeprecated(sshJWTCacheTTLFlag, "use --jwt-cache-ttl instead")
upCmd.PersistentFlags().IntVar(&sshJWTCacheTTL, sshJWTCacheTTLFlag, 0, "SSH JWT token cache TTL in seconds (0=disabled)")
sshCmd.PersistentFlags().IntVarP(&port, "port", "p", sshserver.DefaultSSHPort, "Remote SSH port")
sshCmd.PersistentFlags().StringVarP(&username, "user", "u", "", sshUsernameDesc)

View File

@@ -356,9 +356,6 @@ func setupSetConfigReq(customDNSAddressConverted []byte, cmd *cobra.Command, pro
if cmd.Flag(serverSSHAllowedFlag).Changed {
req.ServerSSHAllowed = &serverSSHAllowed
}
if cmd.Flag(serverVNCAllowedFlag).Changed {
req.ServerVNCAllowed = &serverVNCAllowed
}
if cmd.Flag(enableSSHRootFlag).Changed {
req.EnableSSHRoot = &enableSSHRoot
}
@@ -374,12 +371,9 @@ func setupSetConfigReq(customDNSAddressConverted []byte, cmd *cobra.Command, pro
if cmd.Flag(disableSSHAuthFlag).Changed {
req.DisableSSHAuth = &disableSSHAuth
}
if cmd.Flag(disableVNCAuthFlag).Changed {
req.DisableVNCAuth = &disableVNCAuth
}
if cmd.Flag(jwtCacheTTLFlag).Changed || cmd.Flag(sshJWTCacheTTLFlag).Changed {
jwtCacheTTL32 := int32(jwtCacheTTL)
req.SshJWTCacheTTL = &jwtCacheTTL32
if cmd.Flag(sshJWTCacheTTLFlag).Changed {
sshJWTCacheTTL32 := int32(sshJWTCacheTTL)
req.SshJWTCacheTTL = &sshJWTCacheTTL32
}
if cmd.Flag(interfaceNameFlag).Changed {
if err := parseInterfaceName(interfaceName); err != nil {
@@ -464,9 +458,6 @@ func setupConfig(customDNSAddressConverted []byte, cmd *cobra.Command, configFil
if cmd.Flag(serverSSHAllowedFlag).Changed {
ic.ServerSSHAllowed = &serverSSHAllowed
}
if cmd.Flag(serverVNCAllowedFlag).Changed {
ic.ServerVNCAllowed = &serverVNCAllowed
}
if cmd.Flag(enableSSHRootFlag).Changed {
ic.EnableSSHRoot = &enableSSHRoot
@@ -488,12 +479,8 @@ func setupConfig(customDNSAddressConverted []byte, cmd *cobra.Command, configFil
ic.DisableSSHAuth = &disableSSHAuth
}
if cmd.Flag(disableVNCAuthFlag).Changed {
ic.DisableVNCAuth = &disableVNCAuth
}
if cmd.Flag(jwtCacheTTLFlag).Changed || cmd.Flag(sshJWTCacheTTLFlag).Changed {
ic.SSHJWTCacheTTL = &jwtCacheTTL
if cmd.Flag(sshJWTCacheTTLFlag).Changed {
ic.SSHJWTCacheTTL = &sshJWTCacheTTL
}
if cmd.Flag(interfaceNameFlag).Changed {
@@ -595,9 +582,6 @@ func setupLoginRequest(providedSetupKey string, customDNSAddressConverted []byte
if cmd.Flag(serverSSHAllowedFlag).Changed {
loginRequest.ServerSSHAllowed = &serverSSHAllowed
}
if cmd.Flag(serverVNCAllowedFlag).Changed {
loginRequest.ServerVNCAllowed = &serverVNCAllowed
}
if cmd.Flag(enableSSHRootFlag).Changed {
loginRequest.EnableSSHRoot = &enableSSHRoot
@@ -619,13 +603,9 @@ func setupLoginRequest(providedSetupKey string, customDNSAddressConverted []byte
loginRequest.DisableSSHAuth = &disableSSHAuth
}
if cmd.Flag(disableVNCAuthFlag).Changed {
loginRequest.DisableVNCAuth = &disableVNCAuth
}
if cmd.Flag(jwtCacheTTLFlag).Changed || cmd.Flag(sshJWTCacheTTLFlag).Changed {
jwtCacheTTL32 := int32(jwtCacheTTL)
loginRequest.SshJWTCacheTTL = &jwtCacheTTL32
if cmd.Flag(sshJWTCacheTTLFlag).Changed {
sshJWTCacheTTL32 := int32(sshJWTCacheTTL)
loginRequest.SshJWTCacheTTL = &sshJWTCacheTTL32
}
if cmd.Flag(disableAutoConnectFlag).Changed {

View File

@@ -1,271 +0,0 @@
package cmd
import (
"context"
"encoding/binary"
"fmt"
"io"
"net"
"os"
"os/signal"
"os/user"
"strings"
"syscall"
"time"
log "github.com/sirupsen/logrus"
"github.com/spf13/cobra"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure"
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/profilemanager"
"github.com/netbirdio/netbird/client/proto"
nbssh "github.com/netbirdio/netbird/client/ssh"
"github.com/netbirdio/netbird/util"
)
var (
vncUsername string
vncHost string
vncMode string
vncListen string
vncNoBrowser bool
vncNoCache bool
)
func init() {
vncCmd.PersistentFlags().StringVar(&vncUsername, "user", "", "OS username for session mode")
vncCmd.PersistentFlags().StringVar(&vncMode, "mode", "attach", "Connection mode: attach (view current display) or session (virtual desktop)")
vncCmd.PersistentFlags().StringVar(&vncListen, "listen", "", "Start local VNC proxy on this address (e.g., :5900) for external VNC viewers")
vncCmd.PersistentFlags().BoolVar(&vncNoBrowser, noBrowserFlag, false, noBrowserDesc)
vncCmd.PersistentFlags().BoolVar(&vncNoCache, "no-cache", false, "Skip cached JWT token and force fresh authentication")
}
var vncCmd = &cobra.Command{
Use: "vnc [flags] [user@]host",
Short: "Connect to a NetBird peer via VNC",
Long: `Connect to a NetBird peer using VNC with JWT-based authentication.
The target peer must have the VNC server enabled.
Two modes are available:
- attach: view the current physical display (remote support)
- session: start a virtual desktop as the specified user (passwordless login)
Use --listen to start a local proxy for external VNC viewers:
netbird vnc --listen :5900 peer-hostname
vncviewer localhost:5900
Examples:
netbird vnc peer-hostname
netbird vnc --mode session --user alice peer-hostname
netbird vnc --listen :5900 peer-hostname`,
Args: cobra.MinimumNArgs(1),
RunE: vncFn,
}
func vncFn(cmd *cobra.Command, args []string) error {
SetFlagsFromEnvVars(rootCmd)
SetFlagsFromEnvVars(cmd)
cmd.SetOut(cmd.OutOrStdout())
logOutput := "console"
if firstLogFile := util.FindFirstLogPath(logFiles); firstLogFile != "" && firstLogFile != defaultLogFile {
logOutput = firstLogFile
}
if err := util.InitLog(logLevel, logOutput); err != nil {
return fmt.Errorf("init log: %w", err)
}
if err := parseVNCHostArg(args[0]); err != nil {
return err
}
ctx := internal.CtxInitState(cmd.Context())
sig := make(chan os.Signal, 1)
signal.Notify(sig, syscall.SIGTERM, syscall.SIGINT)
vncCtx, cancel := context.WithCancel(ctx)
errCh := make(chan error, 1)
go func() {
if err := runVNC(vncCtx, cmd); err != nil {
errCh <- err
}
cancel()
}()
select {
case <-sig:
cancel()
<-vncCtx.Done()
return nil
case err := <-errCh:
return err
case <-vncCtx.Done():
}
return nil
}
func parseVNCHostArg(arg string) error {
if strings.Contains(arg, "@") {
parts := strings.SplitN(arg, "@", 2)
if len(parts) != 2 || parts[0] == "" || parts[1] == "" {
return fmt.Errorf("invalid user@host format")
}
if vncUsername == "" {
vncUsername = parts[0]
}
vncHost = parts[1]
if vncMode == "attach" {
vncMode = "session"
}
} else {
vncHost = arg
}
if vncMode == "session" && vncUsername == "" {
if sudoUser := os.Getenv("SUDO_USER"); sudoUser != "" {
vncUsername = sudoUser
} else if currentUser, err := user.Current(); err == nil {
vncUsername = currentUser.Username
}
}
return nil
}
func runVNC(ctx context.Context, cmd *cobra.Command) error {
grpcAddr := strings.TrimPrefix(daemonAddr, "tcp://")
grpcConn, err := grpc.NewClient(grpcAddr, grpc.WithTransportCredentials(insecure.NewCredentials()))
if err != nil {
return fmt.Errorf("connect to daemon: %w", err)
}
defer func() { _ = grpcConn.Close() }()
daemonClient := proto.NewDaemonServiceClient(grpcConn)
if vncMode == "session" {
cmd.Printf("Connecting to %s@%s [session mode]...\n", vncUsername, vncHost)
} else {
cmd.Printf("Connecting to %s [attach mode]...\n", vncHost)
}
// Obtain JWT token. If the daemon has no SSO configured, proceed without one
// (the server will accept unauthenticated connections if --disable-vnc-auth is set).
var jwtToken string
hint := profilemanager.GetLoginHint()
var browserOpener func(string) error
if !vncNoBrowser {
browserOpener = util.OpenBrowser
}
token, err := nbssh.RequestJWTToken(ctx, daemonClient, nil, cmd.ErrOrStderr(), !vncNoCache, hint, browserOpener)
if err != nil {
log.Debugf("JWT authentication unavailable, connecting without token: %v", err)
} else {
jwtToken = token
log.Debug("JWT authentication successful")
}
// Connect to the VNC server on the standard port (5900). The peer's firewall
// DNATs 5900 -> 25900 (internal), so both ports work on the overlay network.
vncAddr := net.JoinHostPort(vncHost, "5900")
vncConn, err := net.DialTimeout("tcp", vncAddr, vncDialTimeout)
if err != nil {
return fmt.Errorf("connect to VNC at %s: %w", vncAddr, err)
}
defer vncConn.Close()
// Send session header with mode, username, and JWT.
if err := sendVNCHeader(vncConn, vncMode, vncUsername, jwtToken); err != nil {
return fmt.Errorf("send VNC header: %w", err)
}
cmd.Printf("VNC connected to %s\n", vncHost)
if vncListen != "" {
return runVNCLocalProxy(ctx, cmd, vncConn)
}
// No --listen flag: inform the user they need to use --listen for external viewers.
cmd.Printf("VNC tunnel established. Use --listen :5900 to proxy for local VNC viewers.\n")
cmd.Printf("Press Ctrl+C to disconnect.\n")
<-ctx.Done()
return nil
}
const vncDialTimeout = 15 * time.Second
// sendVNCHeader writes the NetBird VNC session header.
func sendVNCHeader(conn net.Conn, mode, username, jwt string) error {
var modeByte byte
if mode == "session" {
modeByte = 1
}
usernameBytes := []byte(username)
jwtBytes := []byte(jwt)
hdr := make([]byte, 3+len(usernameBytes)+2+len(jwtBytes))
hdr[0] = modeByte
binary.BigEndian.PutUint16(hdr[1:3], uint16(len(usernameBytes)))
off := 3
copy(hdr[off:], usernameBytes)
off += len(usernameBytes)
binary.BigEndian.PutUint16(hdr[off:off+2], uint16(len(jwtBytes)))
off += 2
copy(hdr[off:], jwtBytes)
_, err := conn.Write(hdr)
return err
}
// runVNCLocalProxy listens on the given address and proxies incoming
// connections to the already-established VNC tunnel.
func runVNCLocalProxy(ctx context.Context, cmd *cobra.Command, vncConn net.Conn) error {
listener, err := net.Listen("tcp", vncListen)
if err != nil {
return fmt.Errorf("listen on %s: %w", vncListen, err)
}
defer listener.Close()
cmd.Printf("VNC proxy listening on %s - connect with your VNC viewer\n", listener.Addr())
cmd.Printf("Press Ctrl+C to stop.\n")
go func() {
<-ctx.Done()
listener.Close()
}()
// Accept a single viewer connection. VNC is single-session: the RFB
// handshake completes on vncConn for the first viewer, so subsequent
// viewers would get a mid-stream connection. The loop handles transient
// accept errors until a valid connection arrives.
for {
clientConn, err := listener.Accept()
if err != nil {
select {
case <-ctx.Done():
return nil
default:
}
log.Debugf("accept VNC proxy client: %v", err)
continue
}
cmd.Printf("VNC viewer connected from %s\n", clientConn.RemoteAddr())
// Bidirectional copy.
done := make(chan struct{})
go func() {
io.Copy(vncConn, clientConn)
close(done)
}()
io.Copy(clientConn, vncConn)
<-done
clientConn.Close()
cmd.Printf("VNC viewer disconnected\n")
return nil
}
}

View File

@@ -1,62 +0,0 @@
//go:build windows
package cmd
import (
"net/netip"
"os"
log "github.com/sirupsen/logrus"
"github.com/spf13/cobra"
vncserver "github.com/netbirdio/netbird/client/vnc/server"
)
var vncAgentPort string
func init() {
vncAgentCmd.Flags().StringVar(&vncAgentPort, "port", "15900", "Port for the VNC agent to listen on")
rootCmd.AddCommand(vncAgentCmd)
}
// vncAgentCmd runs a VNC server in the current user session, listening on
// localhost. It is spawned by the NetBird service (Session 0) via
// CreateProcessAsUser into the interactive console session.
var vncAgentCmd = &cobra.Command{
Use: "vnc-agent",
Short: "Run VNC capture agent (internal, spawned by service)",
Hidden: true,
RunE: func(cmd *cobra.Command, args []string) error {
// Agent's stderr is piped to the service which relogs it.
// Use JSON format with caller info for structured parsing.
log.SetReportCaller(true)
log.SetFormatter(&log.JSONFormatter{})
log.SetOutput(os.Stderr)
sessionID := vncserver.GetCurrentSessionID()
log.Infof("VNC agent starting on 127.0.0.1:%s (session %d)", vncAgentPort, sessionID)
capturer := vncserver.NewDesktopCapturer()
injector := vncserver.NewWindowsInputInjector()
srv := vncserver.New(capturer, injector, "")
// Auth is handled by the service. The agent verifies a token on each
// connection to ensure only the service process can connect.
// The token is passed via environment variable to avoid exposing it
// in the process command line (visible via tasklist/wmic).
srv.SetDisableAuth(true)
srv.SetAgentToken(os.Getenv("NB_VNC_AGENT_TOKEN"))
port, err := netip.ParseAddrPort("127.0.0.1:" + vncAgentPort)
if err != nil {
return err
}
loopback := netip.PrefixFrom(netip.AddrFrom4([4]byte{127, 0, 0, 0}), 8)
if err := srv.Start(cmd.Context(), port, loopback); err != nil {
return err
}
<-cmd.Context().Done()
return srv.Stop()
},
}

View File

@@ -1,16 +0,0 @@
package cmd
const (
serverVNCAllowedFlag = "allow-server-vnc"
disableVNCAuthFlag = "disable-vnc-auth"
)
var (
serverVNCAllowed bool
disableVNCAuth bool
)
func init() {
upCmd.PersistentFlags().BoolVar(&serverVNCAllowed, serverVNCAllowedFlag, false, "Allow embedded VNC server on peer")
upCmd.PersistentFlags().BoolVar(&disableVNCAuth, disableVNCAuthFlag, false, "Disable JWT authentication for VNC")
}

View File

@@ -1,229 +0,0 @@
package cmd
import (
"crypto/ecdh"
"crypto/rand"
"encoding/base64"
"fmt"
"os"
"os/signal"
"path/filepath"
"strings"
"syscall"
"text/tabwriter"
"time"
"github.com/spf13/cobra"
vncserver "github.com/netbirdio/netbird/client/vnc/server"
"github.com/netbirdio/netbird/util"
)
var vncRecDir string
func init() {
vncRecPlayCmd.Flags().StringVar(&vncRecDir, "dir", "", "Recording directory (default: auto-detect)")
vncRecListCmd.Flags().StringVar(&vncRecDir, "dir", "", "Recording directory (default: auto-detect)")
vncRecCmd.AddCommand(vncRecListCmd)
vncRecCmd.AddCommand(vncRecPlayCmd)
vncRecCmd.AddCommand(vncRecKeygenCmd)
vncCmd.AddCommand(vncRecCmd)
}
var vncRecCmd = &cobra.Command{
Use: "rec",
Short: "Manage VNC session recordings",
}
var vncRecKeygenCmd = &cobra.Command{
Use: "keygen",
Short: "Generate an X25519 keypair for recording encryption",
Long: `Generates an X25519 keypair. Put the public key in management settings
(Session Recording > Encryption Key). Keep the private key safe for decrypting recordings.`,
RunE: vncRecKeygenFn,
}
var vncRecListCmd = &cobra.Command{
Use: "list",
Short: "List VNC session recordings",
RunE: vncRecListFn,
}
var vncRecPlayCmd = &cobra.Command{
Use: "play <file-or-name>",
Short: "Open a VNC recording in the browser",
Long: `Opens a browser-based player with playback controls:
play/pause, seek, speed (0.25x to 8x), keyboard shortcuts.
Examples:
netbird vnc rec play last
netbird vnc rec play 20260416-104433_vnc.rec`,
Args: cobra.ExactArgs(1),
RunE: vncRecPlayFn,
}
func vncRecListFn(cmd *cobra.Command, _ []string) error {
dir, err := resolveVNCRecDir()
if err != nil {
return err
}
entries, err := os.ReadDir(dir)
if err != nil {
return fmt.Errorf("read recording dir %s: %w", dir, err)
}
w := tabwriter.NewWriter(cmd.OutOrStdout(), 0, 0, 2, ' ', 0)
fmt.Fprintln(w, "FILE\tSIZE\tDIMENSIONS\tUSER\tREMOTE\tMODE\tDATE")
for _, entry := range entries {
if entry.IsDir() || !strings.HasSuffix(entry.Name(), ".rec") {
continue
}
filePath := filepath.Join(dir, entry.Name())
info, err := entry.Info()
if err != nil {
continue
}
header, err := vncserver.ReadRecordingHeader(filePath)
if err != nil {
fmt.Fprintf(w, "%s\t%s\t?\t?\t?\t?\t?\n", entry.Name(), vncFormatSize(info.Size()))
continue
}
fmt.Fprintf(w, "%s\t%s\t%dx%d\t%s\t%s\t%s\t%s\n",
entry.Name(),
vncFormatSize(info.Size()),
header.Width, header.Height,
header.Meta.User,
header.Meta.RemoteAddr,
header.Meta.Mode,
header.StartTime.Format("2006-01-02 15:04:05"),
)
}
return w.Flush()
}
func vncRecPlayFn(cmd *cobra.Command, args []string) error {
filePath, err := resolveVNCRecFile(args[0])
if err != nil {
return err
}
header, err := vncserver.ReadRecordingHeader(filePath)
if err != nil {
return fmt.Errorf("read recording: %w", err)
}
cmd.Printf("Recording: %s (%dx%d)\n", filepath.Base(filePath), header.Width, header.Height)
url, err := vncserver.ServeWebPlayer(filePath, "localhost:0")
if err != nil {
return fmt.Errorf("start web player: %w", err)
}
cmd.Printf("Player: %s\n", url)
if err := util.OpenBrowser(url); err != nil {
cmd.Printf("Open %s in your browser\n", url)
}
cmd.Printf("Press Ctrl+C to stop.\n")
sig := make(chan os.Signal, 1)
signal.Notify(sig, syscall.SIGINT, syscall.SIGTERM)
<-sig
return nil
}
func vncRecKeygenFn(cmd *cobra.Command, _ []string) error {
priv, err := ecdh.X25519().GenerateKey(rand.Reader)
if err != nil {
return fmt.Errorf("generate key: %w", err)
}
privB64 := base64.StdEncoding.EncodeToString(priv.Bytes())
pubB64 := base64.StdEncoding.EncodeToString(priv.PublicKey().Bytes())
cmd.Printf("Private key (keep secret, for decrypting recordings):\n %s\n\n", privB64)
cmd.Printf("Public key (paste into management Settings > Session Recording > Encryption Key):\n %s\n", pubB64)
return nil
}
func vncFormatSize(size int64) string {
switch {
case size >= 1<<20:
return fmt.Sprintf("%.1fM", float64(size)/float64(1<<20))
case size >= 1<<10:
return fmt.Sprintf("%.1fK", float64(size)/float64(1<<10))
default:
return fmt.Sprintf("%dB", size)
}
}
func resolveVNCRecDir() (string, error) {
if vncRecDir != "" {
return vncRecDir, nil
}
candidates := []string{
"/var/lib/netbird/recordings/vnc",
filepath.Join(os.Getenv("HOME"), ".netbird/recordings/vnc"),
}
for _, dir := range candidates {
if fi, err := os.Stat(dir); err == nil && fi.IsDir() {
return dir, nil
}
}
return "", fmt.Errorf("no VNC recording directory found; use --dir to specify")
}
func resolveVNCRecFile(arg string) (string, error) {
if strings.Contains(arg, "/") || strings.Contains(arg, string(os.PathSeparator)) {
return arg, nil
}
dir, err := resolveVNCRecDir()
// Without a resolvable recording directory, treat arg as a literal path;
// "last" is the exception because it needs a directory to search.
if err != nil && arg != "last" {
return arg, nil
}
if arg == "last" {
if err != nil {
return "", err
}
return findLatestRec(dir)
}
full := filepath.Join(dir, arg)
if _, err := os.Stat(full); err == nil {
return full, nil
}
return arg, nil
}
func findLatestRec(dir string) (string, error) {
entries, err := os.ReadDir(dir)
if err != nil {
return "", fmt.Errorf("read dir: %w", err)
}
var latest string
var latestTime time.Time
for _, entry := range entries {
if entry.IsDir() || !strings.HasSuffix(entry.Name(), ".rec") {
continue
}
info, err := entry.Info()
if err != nil {
continue
}
if info.ModTime().After(latestTime) {
latestTime = info.ModTime()
latest = filepath.Join(dir, entry.Name())
}
}
if latest == "" {
return "", fmt.Errorf("no recordings found in %s", dir)
}
return latest, nil
}


@@ -0,0 +1,11 @@
// Package firewalld integrates with the firewalld daemon so NetBird can place
// its wg interface into firewalld's "trusted" zone. This is required because
// recent firewalld versions create their nftables chains with the
// NFT_CHAIN_OWNER flag, so the kernel returns EPERM to any other process that
// tries to insert rules into them. The workaround mirrors what Tailscale
// does: let firewalld itself add the accept rules to its own chains by
// trusting the interface.
package firewalld
// TrustedZone is the firewalld zone name used for interfaces whose traffic
// should bypass firewalld filtering.
const TrustedZone = "trusted"
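The add/remove paths in the next file treat firewalld's ALREADY_ENABLED, ZONE_ALREADY_SET, UNKNOWN_INTERFACE, and NOT_ENABLED replies as success, which is what makes TrustInterface and UntrustInterface idempotent. A minimal in-memory sketch of that contract (the `zone` type and its methods are hypothetical stand-ins for the daemon, not part of this package):

```go
package main

import "fmt"

// zone models firewalld's trusted-zone membership in memory, purely to
// illustrate the idempotency contract the real calls rely on.
type zone struct{ members map[string]bool }

// trust mirrors TrustInterface: adding an interface that is already a
// member is not an error (firewalld would report ALREADY_ENABLED, which
// callers swallow).
func (z *zone) trust(iface string) error {
	z.members[iface] = true
	return nil
}

// untrust mirrors UntrustInterface: removing a non-member is not an error
// (firewalld would report UNKNOWN_INTERFACE/NOT_ENABLED, also swallowed).
func (z *zone) untrust(iface string) error {
	delete(z.members, iface)
	return nil
}

func main() {
	z := &zone{members: map[string]bool{}}
	fmt.Println(z.trust("wt0"), z.trust("wt0"))     // <nil> <nil>
	fmt.Println(z.untrust("wt0"), z.untrust("wt0")) // <nil> <nil>
}
```

Because repeat calls are free, every init path (nftables, iptables, uspfilter) can call TrustInterface unconditionally without coordinating with the others.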


@@ -0,0 +1,260 @@
//go:build linux
package firewalld
import (
"context"
"errors"
"fmt"
"os/exec"
"strings"
"sync"
"time"
"github.com/godbus/dbus/v5"
log "github.com/sirupsen/logrus"
)
const (
dbusDest = "org.fedoraproject.FirewallD1"
dbusPath = "/org/fedoraproject/FirewallD1"
dbusRootIface = "org.fedoraproject.FirewallD1"
dbusZoneIface = "org.fedoraproject.FirewallD1.zone"
errZoneAlreadySet = "ZONE_ALREADY_SET"
errAlreadyEnabled = "ALREADY_ENABLED"
errUnknownIface = "UNKNOWN_INTERFACE"
errNotEnabled = "NOT_ENABLED"
// callTimeout bounds each individual DBus or firewall-cmd invocation.
// A fresh context is created for each call so a slow DBus probe can't
// exhaust the deadline before the firewall-cmd fallback gets to run.
callTimeout = 3 * time.Second
)
var (
errDBusUnavailable = errors.New("firewalld dbus unavailable")
// trustLogOnce ensures the "added to trusted zone" message is logged at
// Info level only for the first successful add per process; repeat adds
// from other init paths are quieter.
trustLogOnce sync.Once
parentCtxMu sync.RWMutex
parentCtx context.Context = context.Background()
)
// SetParentContext installs a parent context whose cancellation aborts any
// in-flight TrustInterface call. It does not affect UntrustInterface, which
// always uses a fresh Background-rooted timeout so cleanup can still run
// during engine shutdown when the engine context is already cancelled.
func SetParentContext(ctx context.Context) {
parentCtxMu.Lock()
parentCtx = ctx
parentCtxMu.Unlock()
}
func getParentContext() context.Context {
parentCtxMu.RLock()
defer parentCtxMu.RUnlock()
return parentCtx
}
// TrustInterface places iface into firewalld's trusted zone if firewalld is
// running. It is idempotent and best-effort: errors are returned so callers
// can log, but a non-running firewalld is not an error. Only the first
// successful call per process logs at Info. Respects the parent context set
// via SetParentContext so startup-time cancellation unblocks it.
func TrustInterface(iface string) error {
parent := getParentContext()
if !isRunning(parent) {
return nil
}
if err := addTrusted(parent, iface); err != nil {
return fmt.Errorf("add %s to firewalld trusted zone: %w", iface, err)
}
trustLogOnce.Do(func() {
log.Infof("added %s to firewalld trusted zone", iface)
})
log.Debugf("firewalld: ensured %s is in trusted zone", iface)
return nil
}
// UntrustInterface removes iface from firewalld's trusted zone if firewalld
// is running. Idempotent. Uses a Background-rooted timeout so it still runs
// during shutdown after the engine context has been cancelled.
func UntrustInterface(iface string) error {
if !isRunning(context.Background()) {
return nil
}
if err := removeTrusted(context.Background(), iface); err != nil {
return fmt.Errorf("remove %s from firewalld trusted zone: %w", iface, err)
}
return nil
}
func newCallContext(parent context.Context) (context.Context, context.CancelFunc) {
return context.WithTimeout(parent, callTimeout)
}
func isRunning(parent context.Context) bool {
ctx, cancel := newCallContext(parent)
ok, err := isRunningDBus(ctx)
cancel()
if err == nil {
return ok
}
if errors.Is(err, errDBusUnavailable) || errors.Is(err, context.DeadlineExceeded) {
ctx, cancel = newCallContext(parent)
defer cancel()
return isRunningCLI(ctx)
}
return false
}
func addTrusted(parent context.Context, iface string) error {
ctx, cancel := newCallContext(parent)
err := addDBus(ctx, iface)
cancel()
if err == nil {
return nil
}
if !errors.Is(err, errDBusUnavailable) {
log.Debugf("firewalld: dbus add failed, falling back to firewall-cmd: %v", err)
}
ctx, cancel = newCallContext(parent)
defer cancel()
return addCLI(ctx, iface)
}
func removeTrusted(parent context.Context, iface string) error {
ctx, cancel := newCallContext(parent)
err := removeDBus(ctx, iface)
cancel()
if err == nil {
return nil
}
if !errors.Is(err, errDBusUnavailable) {
log.Debugf("firewalld: dbus remove failed, falling back to firewall-cmd: %v", err)
}
ctx, cancel = newCallContext(parent)
defer cancel()
return removeCLI(ctx, iface)
}
func isRunningDBus(ctx context.Context) (bool, error) {
conn, err := dbus.SystemBus()
if err != nil {
return false, fmt.Errorf("%w: %v", errDBusUnavailable, err)
}
obj := conn.Object(dbusDest, dbusPath)
var zone string
if err := obj.CallWithContext(ctx, dbusRootIface+".getDefaultZone", 0).Store(&zone); err != nil {
return false, fmt.Errorf("firewalld getDefaultZone: %w", err)
}
return true, nil
}
func isRunningCLI(ctx context.Context) bool {
if _, err := exec.LookPath("firewall-cmd"); err != nil {
return false
}
return exec.CommandContext(ctx, "firewall-cmd", "--state").Run() == nil
}
func addDBus(ctx context.Context, iface string) error {
conn, err := dbus.SystemBus()
if err != nil {
return fmt.Errorf("%w: %v", errDBusUnavailable, err)
}
obj := conn.Object(dbusDest, dbusPath)
call := obj.CallWithContext(ctx, dbusZoneIface+".addInterface", 0, TrustedZone, iface)
if call.Err == nil {
return nil
}
if dbusErrContains(call.Err, errAlreadyEnabled) {
return nil
}
if dbusErrContains(call.Err, errZoneAlreadySet) {
move := obj.CallWithContext(ctx, dbusZoneIface+".changeZoneOfInterface", 0, TrustedZone, iface)
if move.Err != nil {
return fmt.Errorf("firewalld changeZoneOfInterface: %w", move.Err)
}
return nil
}
return fmt.Errorf("firewalld addInterface: %w", call.Err)
}
func removeDBus(ctx context.Context, iface string) error {
conn, err := dbus.SystemBus()
if err != nil {
return fmt.Errorf("%w: %v", errDBusUnavailable, err)
}
obj := conn.Object(dbusDest, dbusPath)
call := obj.CallWithContext(ctx, dbusZoneIface+".removeInterface", 0, TrustedZone, iface)
if call.Err == nil {
return nil
}
if dbusErrContains(call.Err, errUnknownIface) || dbusErrContains(call.Err, errNotEnabled) {
return nil
}
return fmt.Errorf("firewalld removeInterface: %w", call.Err)
}
func addCLI(ctx context.Context, iface string) error {
if _, err := exec.LookPath("firewall-cmd"); err != nil {
return fmt.Errorf("firewall-cmd not available: %w", err)
}
// --change-interface (no --permanent) binds the interface for the
// current runtime only; we do not want membership to persist across
// reboots because netbird re-asserts it on every startup.
out, err := exec.CommandContext(ctx,
"firewall-cmd", "--zone="+TrustedZone, "--change-interface="+iface,
).CombinedOutput()
if err != nil {
return fmt.Errorf("firewall-cmd change-interface: %w: %s", err, strings.TrimSpace(string(out)))
}
return nil
}
func removeCLI(ctx context.Context, iface string) error {
if _, err := exec.LookPath("firewall-cmd"); err != nil {
return fmt.Errorf("firewall-cmd not available: %w", err)
}
out, err := exec.CommandContext(ctx,
"firewall-cmd", "--zone="+TrustedZone, "--remove-interface="+iface,
).CombinedOutput()
if err != nil {
msg := strings.TrimSpace(string(out))
if strings.Contains(msg, errUnknownIface) || strings.Contains(msg, errNotEnabled) {
return nil
}
return fmt.Errorf("firewall-cmd remove-interface: %w: %s", err, msg)
}
return nil
}
func dbusErrContains(err error, code string) bool {
if err == nil {
return false
}
var de dbus.Error
if errors.As(err, &de) {
for _, b := range de.Body {
if s, ok := b.(string); ok && strings.Contains(s, code) {
return true
}
}
}
return strings.Contains(err.Error(), code)
}


@@ -0,0 +1,49 @@
//go:build linux
package firewalld
import (
"errors"
"testing"
"github.com/godbus/dbus/v5"
)
func TestDBusErrContains(t *testing.T) {
tests := []struct {
name string
err error
code string
want bool
}{
{"nil error", nil, errZoneAlreadySet, false},
{"plain error match", errors.New("ZONE_ALREADY_SET: wt0"), errZoneAlreadySet, true},
{"plain error miss", errors.New("something else"), errZoneAlreadySet, false},
{
"dbus.Error body match",
dbus.Error{Name: "org.fedoraproject.FirewallD1.Exception", Body: []any{"ZONE_ALREADY_SET: wt0"}},
errZoneAlreadySet,
true,
},
{
"dbus.Error body miss",
dbus.Error{Name: "org.fedoraproject.FirewallD1.Exception", Body: []any{"INVALID_INTERFACE"}},
errAlreadyEnabled,
false,
},
{
"dbus.Error non-string body falls back to Error()",
dbus.Error{Name: "x", Body: []any{123}},
"x",
true,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
got := dbusErrContains(tc.err, tc.code)
if got != tc.want {
t.Fatalf("dbusErrContains(%v, %q) = %v; want %v", tc.err, tc.code, got, tc.want)
}
})
}
}


@@ -0,0 +1,25 @@
//go:build !linux
package firewalld
import "context"
// SetParentContext is a no-op on non-Linux platforms because firewalld only
// runs on Linux.
func SetParentContext(context.Context) {
// intentionally empty: firewalld is a Linux-only daemon
}
// TrustInterface is a no-op on non-Linux platforms because firewalld only
// runs on Linux.
func TrustInterface(string) error {
// intentionally empty: firewalld is a Linux-only daemon
return nil
}
// UntrustInterface is a no-op on non-Linux platforms because firewalld only
// runs on Linux.
func UntrustInterface(string) error {
// intentionally empty: firewalld is a Linux-only daemon
return nil
}


@@ -12,6 +12,7 @@ import (
log "github.com/sirupsen/logrus"
nberrors "github.com/netbirdio/netbird/client/errors"
"github.com/netbirdio/netbird/client/firewall/firewalld"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/internal/statemanager"
@@ -86,6 +87,12 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error {
log.Warnf("raw table not available, notrack rules will be disabled: %v", err)
}
// Trust after all fatal init steps so a later failure doesn't leave the
// interface in firewalld's trusted zone without a corresponding Close.
if err := firewalld.TrustInterface(m.wgIface.Name()); err != nil {
log.Warnf("failed to trust interface in firewalld: %v", err)
}
// persist early to ensure cleanup of chains
go func() {
if err := stateManager.PersistState(context.Background()); err != nil {
@@ -191,6 +198,12 @@ func (m *Manager) Close(stateManager *statemanager.Manager) error {
merr = multierror.Append(merr, fmt.Errorf("reset router: %w", err))
}
// Appending to merr intentionally blocks DeleteState below so ShutdownState
// stays persisted and the crash-recovery path retries firewalld cleanup.
if err := firewalld.UntrustInterface(m.wgIface.Name()); err != nil {
merr = multierror.Append(merr, err)
}
// attempt to delete state only if all other operations succeeded
if merr == nil {
if err := stateManager.DeleteState(&ShutdownState{}); err != nil {
@@ -217,6 +230,11 @@ func (m *Manager) AllowNetbird() error {
if err != nil {
return fmt.Errorf("allow netbird interface traffic: %w", err)
}
if err := firewalld.TrustInterface(m.wgIface.Name()); err != nil {
log.Warnf("failed to trust interface in firewalld: %v", err)
}
return nil
}


@@ -14,6 +14,7 @@ import (
log "github.com/sirupsen/logrus"
"golang.org/x/sys/unix"
"github.com/netbirdio/netbird/client/firewall/firewalld"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/internal/statemanager"
@@ -217,6 +218,10 @@ func (m *Manager) AllowNetbird() error {
return fmt.Errorf("flush allow input netbird rules: %w", err)
}
if err := firewalld.TrustInterface(m.wgIface.Name()); err != nil {
log.Warnf("failed to trust interface in firewalld: %v", err)
}
return nil
}


@@ -19,6 +19,7 @@ import (
"golang.org/x/sys/unix"
nberrors "github.com/netbirdio/netbird/client/errors"
"github.com/netbirdio/netbird/client/firewall/firewalld"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
nbid "github.com/netbirdio/netbird/client/internal/acl/id"
"github.com/netbirdio/netbird/client/internal/routemanager/ipfwdstate"
@@ -40,6 +41,8 @@ const (
chainNameForward = "FORWARD"
chainNameMangleForward = "netbird-mangle-forward"
firewalldTableName = "firewalld"
userDataAcceptForwardRuleIif = "frwacceptiif"
userDataAcceptForwardRuleOif = "frwacceptoif"
userDataAcceptInputRule = "inputaccept"
@@ -133,6 +136,10 @@ func (r *router) Reset() error {
merr = multierror.Append(merr, fmt.Errorf("remove accept filter rules: %w", err))
}
if err := firewalld.UntrustInterface(r.wgIface.Name()); err != nil {
merr = multierror.Append(merr, err)
}
if err := r.removeNatPreroutingRules(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove filter prerouting rules: %w", err))
}
@@ -280,6 +287,10 @@ func (r *router) createContainers() error {
log.Errorf("failed to add accept rules for the forward chain: %s", err)
}
if err := firewalld.TrustInterface(r.wgIface.Name()); err != nil {
log.Warnf("failed to trust interface in firewalld: %v", err)
}
if err := r.refreshRulesMap(); err != nil {
log.Errorf("failed to refresh rules: %s", err)
}
@@ -1319,6 +1330,13 @@ func (r *router) isExternalChain(chain *nftables.Chain) bool {
return false
}
// Skip firewalld-owned chains. Firewalld creates its chains with the
// NFT_CHAIN_OWNER flag, so inserting rules into them returns EPERM.
// We delegate acceptance to firewalld by trusting the interface instead.
if chain.Table.Name == firewalldTableName {
return false
}
// Skip all iptables-managed tables in the ip family
if chain.Table.Family == nftables.TableFamilyIPv4 && isIptablesTable(chain.Table.Name) {
return false


@@ -3,6 +3,9 @@
package uspfilter
import (
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/firewall/firewalld"
"github.com/netbirdio/netbird/client/internal/statemanager"
)
@@ -16,6 +19,9 @@ func (m *Manager) Close(stateManager *statemanager.Manager) error {
if m.nativeFirewall != nil {
return m.nativeFirewall.Close(stateManager)
}
if err := firewalld.UntrustInterface(m.wgIface.Name()); err != nil {
log.Warnf("failed to untrust interface in firewalld: %v", err)
}
return nil
}
@@ -24,5 +30,8 @@ func (m *Manager) AllowNetbird() error {
if m.nativeFirewall != nil {
return m.nativeFirewall.AllowNetbird()
}
if err := firewalld.TrustInterface(m.wgIface.Name()); err != nil {
log.Warnf("failed to trust interface in firewalld: %v", err)
}
return nil
}


@@ -9,6 +9,7 @@ import (
// IFaceMapper defines subset methods of interface required for manager
type IFaceMapper interface {
Name() string
SetFilter(device.PacketFilter) error
Address() wgaddr.Address
GetWGDevice() *wgdevice.Device


@@ -31,12 +31,20 @@ var logger = log.NewFromLogrus(logrus.StandardLogger())
var flowLogger = netflow.NewManager(nil, []byte{}, nil).GetLogger()
type IFaceMock struct {
NameFunc func() string
SetFilterFunc func(device.PacketFilter) error
AddressFunc func() wgaddr.Address
GetWGDeviceFunc func() *wgdevice.Device
GetDeviceFunc func() *device.FilteredDevice
}
func (i *IFaceMock) Name() string {
if i.NameFunc == nil {
return "wgtest"
}
return i.NameFunc()
}
func (i *IFaceMock) GetWGDevice() *wgdevice.Device {
if i.GetWGDeviceFunc == nil {
return nil


@@ -239,8 +239,12 @@ func TestICEBind_HandlesConcurrentMixedTraffic(t *testing.T) {
ipv6Count++
}
assert.Equal(t, packetsPerFamily, ipv4Count)
assert.Equal(t, packetsPerFamily, ipv6Count)
// Allow some UDP packet loss under load (e.g. FreeBSD/QEMU runners). The
// routing-correctness checks above are the real assertions; the counts
// are a sanity bound to catch a totally silent path.
minDelivered := packetsPerFamily * 80 / 100
assert.GreaterOrEqual(t, ipv4Count, minDelivered, "IPv4 delivery below threshold")
assert.GreaterOrEqual(t, ipv6Count, minDelivered, "IPv6 delivery below threshold")
}
func TestICEBind_DetectsAddressFamilyFromConnection(t *testing.T) {


@@ -315,7 +315,6 @@ func (a *Auth) setSystemInfoFlags(info *system.Info) {
a.config.RosenpassEnabled,
a.config.RosenpassPermissive,
a.config.ServerSSHAllowed,
a.config.ServerVNCAllowed,
a.config.DisableClientRoutes,
a.config.DisableServerRoutes,
a.config.DisableDNS,
@@ -328,7 +327,6 @@ func (a *Auth) setSystemInfoFlags(info *system.Info) {
a.config.EnableSSHLocalPortForwarding,
a.config.EnableSSHRemotePortForwarding,
a.config.DisableSSHAuth,
a.config.DisableVNCAuth,
)
}


@@ -546,13 +546,11 @@ func createEngineConfig(key wgtypes.Key, config *profilemanager.Config, peerConf
RosenpassEnabled: config.RosenpassEnabled,
RosenpassPermissive: config.RosenpassPermissive,
ServerSSHAllowed: util.ReturnBoolWithDefaultTrue(config.ServerSSHAllowed),
ServerVNCAllowed: config.ServerVNCAllowed != nil && *config.ServerVNCAllowed,
EnableSSHRoot: config.EnableSSHRoot,
EnableSSHSFTP: config.EnableSSHSFTP,
EnableSSHLocalPortForwarding: config.EnableSSHLocalPortForwarding,
EnableSSHRemotePortForwarding: config.EnableSSHRemotePortForwarding,
DisableSSHAuth: config.DisableSSHAuth,
DisableVNCAuth: config.DisableVNCAuth,
DNSRouteInterval: config.DNSRouteInterval,
DisableClientRoutes: config.DisableClientRoutes,
@@ -629,7 +627,6 @@ func loginToManagement(ctx context.Context, client mgm.Client, pubSSHKey []byte,
config.RosenpassEnabled,
config.RosenpassPermissive,
config.ServerSSHAllowed,
config.ServerVNCAllowed,
config.DisableClientRoutes,
config.DisableServerRoutes,
config.DisableDNS,
@@ -642,7 +639,6 @@ func loginToManagement(ctx context.Context, client mgm.Client, pubSSHKey []byte,
config.EnableSSHLocalPortForwarding,
config.EnableSSHRemotePortForwarding,
config.DisableSSHAuth,
config.DisableVNCAuth,
)
return client.Login(sysInfo, pubSSHKey, config.DNSLabels)
}


@@ -3,10 +3,12 @@ package debug
import (
"context"
"errors"
"net"
"net/http"
"os"
"path/filepath"
"testing"
"time"
"github.com/stretchr/testify/require"
@@ -19,8 +21,10 @@ func TestUpload(t *testing.T) {
t.Skip("Skipping upload test on docker ci")
}
testDir := t.TempDir()
testURL := "http://localhost:8080"
addr := reserveLoopbackPort(t)
testURL := "http://" + addr
t.Setenv("SERVER_URL", testURL)
t.Setenv("SERVER_ADDRESS", addr)
t.Setenv("STORE_DIR", testDir)
srv := server.NewServer()
go func() {
@@ -33,6 +37,7 @@ func TestUpload(t *testing.T) {
t.Errorf("Failed to stop server: %v", err)
}
})
waitForServer(t, addr)
file := filepath.Join(t.TempDir(), "tmpfile")
fileContent := []byte("test file content")
@@ -47,3 +52,30 @@ func TestUpload(t *testing.T) {
require.NoError(t, err)
require.Equal(t, fileContent, createdFileContent)
}
// reserveLoopbackPort binds an ephemeral port on loopback to learn a free
// address, then releases it so the server under test can rebind. The close/
// rebind window is racy in theory; on loopback with a kernel-assigned port
// it's essentially never contended in practice.
func reserveLoopbackPort(t *testing.T) string {
t.Helper()
l, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
addr := l.Addr().String()
require.NoError(t, l.Close())
return addr
}
func waitForServer(t *testing.T, addr string) {
t.Helper()
deadline := time.Now().Add(5 * time.Second)
for time.Now().Before(deadline) {
c, err := net.DialTimeout("tcp", addr, 100*time.Millisecond)
if err == nil {
_ = c.Close()
return
}
time.Sleep(20 * time.Millisecond)
}
t.Fatalf("server did not start listening on %s in time", addr)
}


@@ -13,6 +13,7 @@ import (
const (
defaultResolvConfPath = "/etc/resolv.conf"
nsswitchConfPath = "/etc/nsswitch.conf"
)
type resolvConf struct {


@@ -1,7 +1,10 @@
package dns
import (
"context"
"fmt"
"math"
"net"
"slices"
"strconv"
"strings"
@@ -192,6 +195,12 @@ func (c *HandlerChain) logHandlers() {
}
func (c *HandlerChain) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
c.dispatch(w, r, math.MaxInt)
}
// dispatch routes a DNS request through the chain, skipping handlers with
// priority > maxPriority. Shared by ServeDNS and ResolveInternal.
func (c *HandlerChain) dispatch(w dns.ResponseWriter, r *dns.Msg, maxPriority int) {
if len(r.Question) == 0 {
return
}
@@ -216,6 +225,9 @@ func (c *HandlerChain) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
// Try handlers in priority order
for _, entry := range handlers {
if entry.Priority > maxPriority {
continue
}
if !c.isHandlerMatch(qname, entry) {
continue
}
@@ -273,6 +285,55 @@ func (c *HandlerChain) logResponse(logger *log.Entry, cw *ResponseWriterChain, q
cw.response.Len(), meta, time.Since(startTime))
}
// ResolveInternal runs an in-process DNS query against the chain, skipping any
// handler with priority > maxPriority. Used by internal callers (e.g. the mgmt
// cache refresher) that must bypass themselves to avoid loops. Honors ctx
// cancellation; on ctx.Done the dispatch goroutine is left to drain on its own
// (bounded by the invoked handler's internal timeout).
func (c *HandlerChain) ResolveInternal(ctx context.Context, r *dns.Msg, maxPriority int) (*dns.Msg, error) {
if len(r.Question) == 0 {
return nil, fmt.Errorf("empty question")
}
base := &internalResponseWriter{}
done := make(chan struct{})
go func() {
c.dispatch(base, r, maxPriority)
close(done)
}()
select {
case <-done:
case <-ctx.Done():
// Prefer a completed response if dispatch finished concurrently with cancellation.
select {
case <-done:
default:
return nil, fmt.Errorf("resolve %s: %w", strings.ToLower(r.Question[0].Name), ctx.Err())
}
}
if base.response == nil || base.response.Rcode == dns.RcodeRefused {
return nil, fmt.Errorf("no handler resolved %s at priority ≤ %d",
strings.ToLower(r.Question[0].Name), maxPriority)
}
return base.response, nil
}
// HasRootHandlerAtOrBelow reports whether any "." handler is registered at
// priority ≤ maxPriority.
func (c *HandlerChain) HasRootHandlerAtOrBelow(maxPriority int) bool {
c.mu.RLock()
defer c.mu.RUnlock()
for _, h := range c.handlers {
if h.Pattern == "." && h.Priority <= maxPriority {
return true
}
}
return false
}
func (c *HandlerChain) isHandlerMatch(qname string, entry HandlerEntry) bool {
switch {
case entry.Pattern == ".":
@@ -291,3 +352,36 @@ func (c *HandlerChain) isHandlerMatch(qname string, entry HandlerEntry) bool {
}
}
}
// internalResponseWriter captures a dns.Msg for in-process chain queries.
type internalResponseWriter struct {
response *dns.Msg
}
func (w *internalResponseWriter) WriteMsg(m *dns.Msg) error { w.response = m; return nil }
func (w *internalResponseWriter) LocalAddr() net.Addr { return nil }
func (w *internalResponseWriter) RemoteAddr() net.Addr { return nil }
// Write unpacks raw DNS bytes so handlers that call Write instead of WriteMsg
// still surface their answer to ResolveInternal.
func (w *internalResponseWriter) Write(p []byte) (int, error) {
msg := new(dns.Msg)
if err := msg.Unpack(p); err != nil {
return 0, err
}
w.response = msg
return len(p), nil
}
func (w *internalResponseWriter) Close() error { return nil }
func (w *internalResponseWriter) TsigStatus() error { return nil }
// TsigTimersOnly is part of dns.ResponseWriter.
func (w *internalResponseWriter) TsigTimersOnly(bool) {
// no-op: in-process queries carry no TSIG state.
}
// Hijack is part of dns.ResponseWriter.
func (w *internalResponseWriter) Hijack() {
// no-op: in-process queries have no underlying connection to hand off.
}


@@ -1,11 +1,15 @@
package dns_test
import (
"context"
"net"
"testing"
"time"
"github.com/miekg/dns"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
nbdns "github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/dns/test"
@@ -1042,3 +1046,163 @@ func TestHandlerChain_AddRemoveRoundtrip(t *testing.T) {
})
}
}
// answeringHandler writes a fixed A record to ack the query. Used to verify
// which handler ResolveInternal dispatches to.
type answeringHandler struct {
name string
ip string
}
func (h *answeringHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
resp := &dns.Msg{}
resp.SetReply(r)
resp.Answer = []dns.RR{&dns.A{
Hdr: dns.RR_Header{Name: r.Question[0].Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
A: net.ParseIP(h.ip).To4(),
}}
_ = w.WriteMsg(resp)
}
func (h *answeringHandler) String() string { return h.name }
func TestHandlerChain_ResolveInternal_SkipsAboveMaxPriority(t *testing.T) {
chain := nbdns.NewHandlerChain()
high := &answeringHandler{name: "high", ip: "10.0.0.1"}
low := &answeringHandler{name: "low", ip: "10.0.0.2"}
chain.AddHandler("example.com.", high, nbdns.PriorityMgmtCache)
chain.AddHandler("example.com.", low, nbdns.PriorityUpstream)
r := new(dns.Msg)
r.SetQuestion("example.com.", dns.TypeA)
resp, err := chain.ResolveInternal(context.Background(), r, nbdns.PriorityUpstream)
assert.NoError(t, err)
assert.NotNil(t, resp)
assert.Equal(t, 1, len(resp.Answer))
a, ok := resp.Answer[0].(*dns.A)
assert.True(t, ok)
assert.Equal(t, "10.0.0.2", a.A.String(), "should skip mgmtCache handler and resolve via upstream")
}
func TestHandlerChain_ResolveInternal_ErrorWhenNoMatch(t *testing.T) {
chain := nbdns.NewHandlerChain()
high := &answeringHandler{name: "high", ip: "10.0.0.1"}
chain.AddHandler("example.com.", high, nbdns.PriorityMgmtCache)
r := new(dns.Msg)
r.SetQuestion("example.com.", dns.TypeA)
_, err := chain.ResolveInternal(context.Background(), r, nbdns.PriorityUpstream)
assert.Error(t, err, "no handler at or below maxPriority should error")
}
// rawWriteHandler packs a response and calls ResponseWriter.Write directly
// (instead of WriteMsg), exercising the internalResponseWriter.Write path.
type rawWriteHandler struct {
ip string
}
func (h *rawWriteHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
resp := &dns.Msg{}
resp.SetReply(r)
resp.Answer = []dns.RR{&dns.A{
Hdr: dns.RR_Header{Name: r.Question[0].Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
A: net.ParseIP(h.ip).To4(),
}}
packed, err := resp.Pack()
if err != nil {
return
}
_, _ = w.Write(packed)
}
func TestHandlerChain_ResolveInternal_CapturesRawWrite(t *testing.T) {
chain := nbdns.NewHandlerChain()
chain.AddHandler("example.com.", &rawWriteHandler{ip: "10.0.0.3"}, nbdns.PriorityUpstream)
r := new(dns.Msg)
r.SetQuestion("example.com.", dns.TypeA)
resp, err := chain.ResolveInternal(context.Background(), r, nbdns.PriorityUpstream)
assert.NoError(t, err)
require.NotNil(t, resp)
require.Len(t, resp.Answer, 1)
a, ok := resp.Answer[0].(*dns.A)
require.True(t, ok)
assert.Equal(t, "10.0.0.3", a.A.String(), "handlers calling Write(packed) must still surface their answer")
}
func TestHandlerChain_ResolveInternal_EmptyQuestion(t *testing.T) {
chain := nbdns.NewHandlerChain()
_, err := chain.ResolveInternal(context.Background(), new(dns.Msg), nbdns.PriorityUpstream)
assert.Error(t, err)
}
// hangingHandler blocks indefinitely until closed, simulating a wedged upstream.
type hangingHandler struct {
block chan struct{}
}
func (h *hangingHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
<-h.block
resp := &dns.Msg{}
resp.SetReply(r)
_ = w.WriteMsg(resp)
}
func (h *hangingHandler) String() string { return "hangingHandler" }
func TestHandlerChain_ResolveInternal_HonorsContextTimeout(t *testing.T) {
chain := nbdns.NewHandlerChain()
h := &hangingHandler{block: make(chan struct{})}
defer close(h.block)
chain.AddHandler("example.com.", h, nbdns.PriorityUpstream)
r := new(dns.Msg)
r.SetQuestion("example.com.", dns.TypeA)
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer cancel()
start := time.Now()
_, err := chain.ResolveInternal(ctx, r, nbdns.PriorityUpstream)
elapsed := time.Since(start)
assert.Error(t, err)
assert.ErrorIs(t, err, context.DeadlineExceeded)
assert.Less(t, elapsed, 500*time.Millisecond, "ResolveInternal must return shortly after ctx deadline")
}
func TestHandlerChain_HasRootHandlerAtOrBelow(t *testing.T) {
chain := nbdns.NewHandlerChain()
h := &answeringHandler{name: "h", ip: "10.0.0.1"}
assert.False(t, chain.HasRootHandlerAtOrBelow(nbdns.PriorityUpstream), "empty chain")
chain.AddHandler("example.com.", h, nbdns.PriorityUpstream)
assert.False(t, chain.HasRootHandlerAtOrBelow(nbdns.PriorityUpstream), "non-root handler does not count")
chain.AddHandler(".", h, nbdns.PriorityMgmtCache)
assert.False(t, chain.HasRootHandlerAtOrBelow(nbdns.PriorityUpstream), "root handler above threshold excluded")
chain.AddHandler(".", h, nbdns.PriorityDefault)
assert.True(t, chain.HasRootHandlerAtOrBelow(nbdns.PriorityUpstream), "root handler at PriorityDefault included")
chain.RemoveHandler(".", nbdns.PriorityDefault)
assert.False(t, chain.HasRootHandlerAtOrBelow(nbdns.PriorityUpstream))
// Primary nsgroup case: root handler lands at PriorityUpstream.
chain.AddHandler(".", h, nbdns.PriorityUpstream)
assert.True(t, chain.HasRootHandlerAtOrBelow(nbdns.PriorityUpstream), "root at PriorityUpstream included")
chain.RemoveHandler(".", nbdns.PriorityUpstream)
// Fallback case: original /etc/resolv.conf entries land at PriorityFallback.
chain.AddHandler(".", h, nbdns.PriorityFallback)
assert.True(t, chain.HasRootHandlerAtOrBelow(nbdns.PriorityUpstream), "root at PriorityFallback included")
chain.RemoveHandler(".", nbdns.PriorityFallback)
assert.False(t, chain.HasRootHandlerAtOrBelow(nbdns.PriorityUpstream))
}
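The capped-resolve behaviour these tests pin down can be sketched stand-alone. The following is a minimal, self-contained illustration, not the nbdns implementation: the `handler` struct, `resolveInternal`, and the numeric priority convention (higher value consulted first, eligibility capped at `maxPriority`) are assumptions made for the sketch.

```go
package main

import (
	"errors"
	"fmt"
	"sort"
)

// handler owns a domain pattern at a numeric priority; "." matches any name.
type handler struct {
	pattern  string
	priority int
	answer   string
}

// resolveInternal walks handlers from highest to lowest priority and returns
// the first answer from a handler at or below maxPriority. Handlers above the
// cap (e.g. mgmt/route/local in the real chain) are skipped entirely.
func resolveInternal(handlers []handler, name string, maxPriority int) (string, error) {
	sorted := append([]handler(nil), handlers...)
	sort.Slice(sorted, func(i, j int) bool { return sorted[i].priority > sorted[j].priority })
	for _, h := range sorted {
		if h.priority > maxPriority {
			continue // above the cap: invisible to this resolve
		}
		if h.pattern == "." || h.pattern == name {
			return h.answer, nil
		}
	}
	return "", errors.New("no handler at or below maxPriority matched")
}

func main() {
	chain := []handler{
		{pattern: ".", priority: 100, answer: "10.0.0.1"}, // skipped by the cap
		{pattern: ".", priority: 50, answer: "10.0.0.2"},  // eligible
	}
	ip, err := resolveInternal(chain, "example.com.", 50)
	fmt.Println(ip, err)
}
```

With only a priority-100 handler registered, a resolve capped at 50 errors out, which is exactly the shape of `TestHandlerChain_ResolveInternal_ErrorWhenNoMatch` above.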

View File

@@ -46,12 +46,12 @@ type restoreHostManager interface {
}
func newHostManager(wgInterface string) (hostManager, error) {
osManager, err := getOSDNSManagerType()
osManager, reason, err := getOSDNSManagerType()
if err != nil {
return nil, fmt.Errorf("get os dns manager type: %w", err)
}
log.Infof("System DNS manager discovered: %s", osManager)
log.Infof("System DNS manager discovered: %s (%s)", osManager, reason)
mgr, err := newHostManagerFromType(wgInterface, osManager)
// need to explicitly return nil mgr on error to avoid returning a non-nil interface containing a nil value
if err != nil {
@@ -74,17 +74,49 @@ func newHostManagerFromType(wgInterface string, osManager osManagerType) (restor
}
}
func getOSDNSManagerType() (osManagerType, error) {
func getOSDNSManagerType() (osManagerType, string, error) {
resolved := isSystemdResolvedRunning()
nss := isLibnssResolveUsed()
stub := checkStub()
// Prefer systemd-resolved whenever it owns libc resolution, regardless of
// who wrote /etc/resolv.conf. File-mode rewrites do not affect lookups
// that go through nss-resolve, and in foreign mode they can loop back
// through resolved as an upstream.
if resolved && (nss || stub) {
return systemdManager, fmt.Sprintf("systemd-resolved active (nss-resolve=%t, stub=%t)", nss, stub), nil
}
mgr, reason, rejected, err := scanResolvConfHeader()
if err != nil {
return 0, "", err
}
if reason != "" {
return mgr, reason, nil
}
fallback := fmt.Sprintf("no manager matched (resolved=%t, nss-resolve=%t, stub=%t)", resolved, nss, stub)
if len(rejected) > 0 {
fallback += "; rejected: " + strings.Join(rejected, ", ")
}
return fileManager, fallback, nil
}
// scanResolvConfHeader walks /etc/resolv.conf header comments and returns the
// matching manager. If reason is empty the caller should pick file mode and
// use rejected for diagnostics.
func scanResolvConfHeader() (osManagerType, string, []string, error) {
file, err := os.Open(defaultResolvConfPath)
if err != nil {
return 0, fmt.Errorf("unable to open %s for checking owner, got error: %w", defaultResolvConfPath, err)
return 0, "", nil, fmt.Errorf("unable to open %s for checking owner, got error: %w", defaultResolvConfPath, err)
}
defer func() {
if err := file.Close(); err != nil {
log.Errorf("close file %s: %s", defaultResolvConfPath, err)
if cerr := file.Close(); cerr != nil {
log.Errorf("close file %s: %s", defaultResolvConfPath, cerr)
}
}()
var rejected []string
scanner := bufio.NewScanner(file)
for scanner.Scan() {
text := scanner.Text()
@@ -92,41 +124,48 @@ func getOSDNSManagerType() (osManagerType, error) {
continue
}
if text[0] != '#' {
return fileManager, nil
break
}
if strings.Contains(text, fileGeneratedResolvConfContentHeader) {
return netbirdManager, nil
}
if strings.Contains(text, "NetworkManager") && isDbusListenerRunning(networkManagerDest, networkManagerDbusObjectNode) && isNetworkManagerSupported() {
return networkManager, nil
}
if strings.Contains(text, "systemd-resolved") && isSystemdResolvedRunning() {
if checkStub() {
return systemdManager, nil
} else {
return fileManager, nil
}
}
if strings.Contains(text, "resolvconf") {
if isSystemdResolveConfMode() {
return systemdManager, nil
}
return resolvConfManager, nil
if mgr, reason, rej := matchResolvConfHeader(text); reason != "" {
return mgr, reason, nil, nil
} else if rej != "" {
rejected = append(rejected, rej)
}
}
if err := scanner.Err(); err != nil && err != io.EOF {
return 0, fmt.Errorf("scan: %w", err)
return 0, "", nil, fmt.Errorf("scan: %w", err)
}
return fileManager, nil
return 0, "", rejected, nil
}
// checkStub checks if the stub resolver is disabled in systemd-resolved. If it is disabled, we fall back to file manager.
// matchResolvConfHeader inspects a single comment line. Returns either a
// definitive (manager, reason) or a non-empty rejected diagnostic.
func matchResolvConfHeader(text string) (osManagerType, string, string) {
if strings.Contains(text, fileGeneratedResolvConfContentHeader) {
return netbirdManager, "netbird-managed resolv.conf header detected", ""
}
if strings.Contains(text, "NetworkManager") {
if isDbusListenerRunning(networkManagerDest, networkManagerDbusObjectNode) && isNetworkManagerSupported() {
return networkManager, "NetworkManager header + supported version on dbus", ""
}
return 0, "", "NetworkManager header (no dbus or unsupported version)"
}
if strings.Contains(text, "resolvconf") {
if isSystemdResolveConfMode() {
return systemdManager, "resolvconf header in systemd-resolved compatibility mode", ""
}
return resolvConfManager, "resolvconf header detected", ""
}
return 0, "", ""
}
// checkStub reports whether systemd-resolved's stub (127.0.0.53) is listed
// in /etc/resolv.conf. On parse failure we assume it is, to avoid dropping
// into file mode while resolved is active.
func checkStub() bool {
rConf, err := parseDefaultResolvConf()
if err != nil {
log.Warnf("failed to parse resolv conf: %s", err)
log.Warnf("failed to parse resolv conf, assuming stub is active: %s", err)
return true
}
@@ -139,3 +178,36 @@ func checkStub() bool {
return false
}
// isLibnssResolveUsed reports whether nss-resolve is listed before dns on
// the hosts: line of /etc/nsswitch.conf. When it is, libc lookups are
// delegated to systemd-resolved regardless of /etc/resolv.conf.
func isLibnssResolveUsed() bool {
bs, err := os.ReadFile(nsswitchConfPath)
if err != nil {
log.Debugf("read %s: %v", nsswitchConfPath, err)
return false
}
return parseNsswitchResolveAhead(bs)
}
func parseNsswitchResolveAhead(data []byte) bool {
for _, line := range strings.Split(string(data), "\n") {
if i := strings.IndexByte(line, '#'); i >= 0 {
line = line[:i]
}
fields := strings.Fields(line)
if len(fields) < 2 || fields[0] != "hosts:" {
continue
}
for _, module := range fields[1:] {
switch module {
case "dns":
return false
case "resolve":
return true
}
}
}
return false
}
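`checkStub` above relies on `parseDefaultResolvConf`, which is outside this diff. As a rough, self-contained sketch of what that stub check amounts to, the helper below scans resolv.conf content for a nameserver line pointing at systemd-resolved's stub listener 127.0.0.53; the name `stubListed` and the line-by-line parsing are illustrative assumptions, not the real parser.

```go
package main

import (
	"fmt"
	"strings"
)

// stubListed reports whether a resolv.conf body lists systemd-resolved's
// stub address 127.0.0.53 as a nameserver. Comments ('#' to end of line)
// are stripped before tokenizing each line.
func stubListed(resolvConf string) bool {
	for _, line := range strings.Split(resolvConf, "\n") {
		if i := strings.IndexByte(line, '#'); i >= 0 {
			line = line[:i]
		}
		fields := strings.Fields(line)
		if len(fields) >= 2 && fields[0] == "nameserver" && fields[1] == "127.0.0.53" {
			return true
		}
	}
	return false
}

func main() {
	fmt.Println(stubListed("nameserver 127.0.0.53\noptions edns0\n")) // true
	fmt.Println(stubListed("nameserver 1.1.1.1\n"))                   // false
}
```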

View File

@@ -0,0 +1,76 @@
//go:build (linux && !android) || freebsd
package dns
import "testing"
func TestParseNsswitchResolveAhead(t *testing.T) {
tests := []struct {
name string
in string
want bool
}{
{
name: "resolve before dns with action token",
in: "hosts: mymachines resolve [!UNAVAIL=return] files myhostname dns\n",
want: true,
},
{
name: "dns before resolve",
in: "hosts: files mdns4_minimal [NOTFOUND=return] dns resolve\n",
want: false,
},
{
name: "debian default with only dns",
in: "hosts: files mdns4_minimal [NOTFOUND=return] dns mymachines\n",
want: false,
},
{
name: "neither resolve nor dns",
in: "hosts: files myhostname\n",
want: false,
},
{
name: "no hosts line",
in: "passwd: files systemd\ngroup: files systemd\n",
want: false,
},
{
name: "empty",
in: "",
want: false,
},
{
name: "comments and blank lines ignored",
in: "# comment\n\n# another\nhosts: resolve dns\n",
want: true,
},
{
name: "trailing inline comment",
in: "hosts: resolve [!UNAVAIL=return] dns # fallback\n",
want: true,
},
{
name: "hosts token must be the first field",
in: " hosts: resolve dns\n",
want: true,
},
{
name: "other db line mentioning resolve is ignored",
in: "networks: resolve\nhosts: dns\n",
want: false,
},
{
name: "only resolve, no dns",
in: "hosts: files resolve\n",
want: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := parseNsswitchResolveAhead([]byte(tt.in)); got != tt.want {
t.Errorf("parseNsswitchResolveAhead() = %v, want %v", got, tt.want)
}
})
}
}

View File

@@ -2,40 +2,83 @@ package mgmt
import (
"context"
"errors"
"fmt"
"net"
"net/netip"
"net/url"
"os"
"slices"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/miekg/dns"
log "github.com/sirupsen/logrus"
"golang.org/x/sync/singleflight"
dnsconfig "github.com/netbirdio/netbird/client/internal/dns/config"
"github.com/netbirdio/netbird/client/internal/dns/resutil"
"github.com/netbirdio/netbird/shared/management/domain"
)
const dnsTimeout = 5 * time.Second
const (
dnsTimeout = 5 * time.Second
defaultTTL = 300 * time.Second
refreshBackoff = 30 * time.Second
// Resolver caches critical NetBird infrastructure domains
// envMgmtCacheTTL overrides defaultTTL for integration/dev testing.
envMgmtCacheTTL = "NB_MGMT_CACHE_TTL"
)
// ChainResolver lets the cache refresh stale entries through the DNS handler
// chain instead of net.DefaultResolver, avoiding loopback when NetBird is the
// system resolver.
type ChainResolver interface {
ResolveInternal(ctx context.Context, msg *dns.Msg, maxPriority int) (*dns.Msg, error)
HasRootHandlerAtOrBelow(maxPriority int) bool
}
// cachedRecord holds DNS records plus timestamps used for TTL refresh.
// records and cachedAt are set at construction and treated as immutable;
// lastFailedRefresh and consecFailures are mutable and must be accessed under
// Resolver.mutex.
type cachedRecord struct {
records []dns.RR
cachedAt time.Time
lastFailedRefresh time.Time
consecFailures int
}
// Resolver caches critical NetBird infrastructure domains.
// records, refreshing, mgmtDomain and serverDomains are all guarded by mutex.
type Resolver struct {
records map[dns.Question][]dns.RR
records map[dns.Question]*cachedRecord
mgmtDomain *domain.Domain
serverDomains *dnsconfig.ServerDomains
mutex sync.RWMutex
}
type ipsResponse struct {
ips []netip.Addr
err error
chain ChainResolver
chainMaxPriority int
refreshGroup singleflight.Group
// refreshing tracks questions whose refresh is running via the OS
// fallback path. A ServeDNS hit for a question in this map indicates
// the OS resolver routed the recursive query back to us (loop). Only
// the OS path arms this so chain-path refreshes don't produce false
// positives. The atomic bool is CAS-flipped once per refresh to
// throttle the warning log.
refreshing map[dns.Question]*atomic.Bool
cacheTTL time.Duration
}
// NewResolver creates a new management domains cache resolver.
func NewResolver() *Resolver {
return &Resolver{
records: make(map[dns.Question][]dns.RR),
records: make(map[dns.Question]*cachedRecord),
refreshing: make(map[dns.Question]*atomic.Bool),
cacheTTL: resolveCacheTTL(),
}
}
@@ -44,7 +87,19 @@ func (m *Resolver) String() string {
return "MgmtCacheResolver"
}
// ServeDNS implements dns.Handler interface.
// SetChainResolver wires the handler chain used to refresh stale cache entries.
// maxPriority caps which handlers may answer refresh queries (typically
// PriorityUpstream, so upstream/default/fallback handlers are consulted and
// mgmt/route/local handlers are skipped).
func (m *Resolver) SetChainResolver(chain ChainResolver, maxPriority int) {
m.mutex.Lock()
m.chain = chain
m.chainMaxPriority = maxPriority
m.mutex.Unlock()
}
// ServeDNS serves cached A/AAAA records. Stale entries are returned
// immediately and refreshed asynchronously (stale-while-revalidate).
func (m *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
if len(r.Question) == 0 {
m.continueToNext(w, r)
@@ -60,7 +115,14 @@ func (m *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
}
m.mutex.RLock()
records, found := m.records[question]
cached, found := m.records[question]
inflight := m.refreshing[question]
var shouldRefresh bool
if found {
stale := time.Since(cached.cachedAt) > m.cacheTTL
inBackoff := !cached.lastFailedRefresh.IsZero() && time.Since(cached.lastFailedRefresh) < refreshBackoff
shouldRefresh = stale && !inBackoff
}
m.mutex.RUnlock()
if !found {
@@ -68,12 +130,23 @@ func (m *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
return
}
if inflight != nil && inflight.CompareAndSwap(false, true) {
log.Warnf("mgmt cache: possible resolver loop for domain=%s: served stale while an OS-fallback refresh was inflight (if NetBird is the system resolver, the OS-path predicate is wrong)",
question.Name)
}
// Skip scheduling a refresh goroutine if one is already inflight for
// this question; singleflight would dedup anyway but skipping avoids
// a parked goroutine per stale hit under bursty load.
if shouldRefresh && inflight == nil {
m.scheduleRefresh(question, cached)
}
resp := &dns.Msg{}
resp.SetReply(r)
resp.Authoritative = false
resp.RecursionAvailable = true
resp.Answer = append(resp.Answer, records...)
resp.Answer = cloneRecordsWithTTL(cached.records, m.responseTTL(cached.cachedAt))
log.Debugf("serving %d cached records for domain=%s", len(resp.Answer), question.Name)
@@ -98,101 +171,260 @@ func (m *Resolver) continueToNext(w dns.ResponseWriter, r *dns.Msg) {
}
}
// AddDomain manually adds a domain to cache by resolving it.
// AddDomain resolves a domain and stores its A/AAAA records in the cache.
// A family that resolves NODATA (nil err, zero records) evicts any stale
// entry for that qtype.
func (m *Resolver) AddDomain(ctx context.Context, d domain.Domain) error {
dnsName := strings.ToLower(dns.Fqdn(d.PunycodeString()))
ctx, cancel := context.WithTimeout(ctx, dnsTimeout)
defer cancel()
ips, err := lookupIPWithExtraTimeout(ctx, d)
if err != nil {
return err
aRecords, aaaaRecords, errA, errAAAA := m.lookupBoth(ctx, d, dnsName)
if errA != nil && errAAAA != nil {
return fmt.Errorf("resolve %s: %w", d.SafeString(), errors.Join(errA, errAAAA))
}
var aRecords, aaaaRecords []dns.RR
for _, ip := range ips {
if ip.Is4() {
rr := &dns.A{
Hdr: dns.RR_Header{
Name: dnsName,
Rrtype: dns.TypeA,
Class: dns.ClassINET,
Ttl: 300,
},
A: ip.AsSlice(),
}
aRecords = append(aRecords, rr)
} else if ip.Is6() {
rr := &dns.AAAA{
Hdr: dns.RR_Header{
Name: dnsName,
Rrtype: dns.TypeAAAA,
Class: dns.ClassINET,
Ttl: 300,
},
AAAA: ip.AsSlice(),
}
aaaaRecords = append(aaaaRecords, rr)
if len(aRecords) == 0 && len(aaaaRecords) == 0 {
if err := errors.Join(errA, errAAAA); err != nil {
return fmt.Errorf("resolve %s: no A/AAAA records: %w", d.SafeString(), err)
}
return fmt.Errorf("resolve %s: no A/AAAA records", d.SafeString())
}
now := time.Now()
m.mutex.Lock()
defer m.mutex.Unlock()
if len(aRecords) > 0 {
aQuestion := dns.Question{
Name: dnsName,
Qtype: dns.TypeA,
Qclass: dns.ClassINET,
}
m.records[aQuestion] = aRecords
}
m.applyFamilyRecords(dnsName, dns.TypeA, aRecords, errA, now)
m.applyFamilyRecords(dnsName, dns.TypeAAAA, aaaaRecords, errAAAA, now)
if len(aaaaRecords) > 0 {
aaaaQuestion := dns.Question{
Name: dnsName,
Qtype: dns.TypeAAAA,
Qclass: dns.ClassINET,
}
m.records[aaaaQuestion] = aaaaRecords
}
m.mutex.Unlock()
log.Debugf("added domain=%s with %d A records and %d AAAA records",
log.Debugf("added/updated domain=%s with %d A records and %d AAAA records",
d.SafeString(), len(aRecords), len(aaaaRecords))
return nil
}
func lookupIPWithExtraTimeout(ctx context.Context, d domain.Domain) ([]netip.Addr, error) {
log.Infof("looking up IP for mgmt domain=%s", d.SafeString())
defer log.Infof("done looking up IP for mgmt domain=%s", d.SafeString())
resultChan := make(chan *ipsResponse, 1)
// applyFamilyRecords writes records, evicts on NODATA, and leaves the cache
// untouched on error. Caller holds m.mutex.
func (m *Resolver) applyFamilyRecords(dnsName string, qtype uint16, records []dns.RR, err error, now time.Time) {
q := dns.Question{Name: dnsName, Qtype: qtype, Qclass: dns.ClassINET}
switch {
case len(records) > 0:
m.records[q] = &cachedRecord{records: records, cachedAt: now}
case err == nil:
delete(m.records, q)
}
}
go func() {
ips, err := net.DefaultResolver.LookupNetIP(ctx, "ip", d.PunycodeString())
resultChan <- &ipsResponse{
err: err,
ips: ips,
// scheduleRefresh kicks off an async refresh. DoChan spawns one goroutine per
// unique in-flight key; bursty stale hits share its channel. expected is the
// cachedRecord pointer observed by the caller; the refresh only mutates the
// cache if that pointer is still the one stored, so a stale in-flight refresh
// can't clobber a newer entry written by AddDomain or a competing refresh.
func (m *Resolver) scheduleRefresh(question dns.Question, expected *cachedRecord) {
key := question.Name + "|" + dns.TypeToString[question.Qtype]
_ = m.refreshGroup.DoChan(key, func() (any, error) {
return nil, m.refreshQuestion(question, expected)
})
}
// refreshQuestion replaces the cached records on success, or marks the entry
// failed (arming the backoff) on failure. While this runs, ServeDNS can detect
// a resolver loop by spotting a query for this same question arriving back at us.
// expected pins the cache entry observed at schedule time; mutations only apply
// if m.records[question] still points at it.
func (m *Resolver) refreshQuestion(question dns.Question, expected *cachedRecord) error {
ctx, cancel := context.WithTimeout(context.Background(), dnsTimeout)
defer cancel()
d, err := domain.FromString(strings.TrimSuffix(question.Name, "."))
if err != nil {
m.markRefreshFailed(question, expected)
return fmt.Errorf("parse domain: %w", err)
}
records, err := m.lookupRecords(ctx, d, question)
if err != nil {
fails := m.markRefreshFailed(question, expected)
logf := log.Warnf
if fails == 0 || fails > 1 {
logf = log.Debugf
}
}()
var resp *ipsResponse
select {
case <-time.After(dnsTimeout + time.Millisecond*500):
log.Warnf("timed out waiting for IP for mgmt domain=%s", d.SafeString())
return nil, fmt.Errorf("timed out waiting for ips to be available for domain %s", d.SafeString())
case <-ctx.Done():
return nil, ctx.Err()
case resp = <-resultChan:
logf("refresh mgmt cache domain=%s type=%s: %v (consecutive failures=%d)",
d.SafeString(), dns.TypeToString[question.Qtype], err, fails)
return err
}
if resp.err != nil {
return nil, fmt.Errorf("resolve domain %s: %w", d.SafeString(), resp.err)
// NOERROR/NODATA: family gone upstream, evict so we stop serving stale.
if len(records) == 0 {
m.mutex.Lock()
if m.records[question] == expected {
delete(m.records, question)
m.mutex.Unlock()
log.Infof("removed mgmt cache domain=%s type=%s: no records returned",
d.SafeString(), dns.TypeToString[question.Qtype])
return nil
}
m.mutex.Unlock()
log.Debugf("skipping refresh evict for domain=%s type=%s: entry changed during refresh",
d.SafeString(), dns.TypeToString[question.Qtype])
return nil
}
return resp.ips, nil
now := time.Now()
m.mutex.Lock()
if m.records[question] != expected {
m.mutex.Unlock()
log.Debugf("skipping refresh write for domain=%s type=%s: entry changed during refresh",
d.SafeString(), dns.TypeToString[question.Qtype])
return nil
}
m.records[question] = &cachedRecord{records: records, cachedAt: now}
m.mutex.Unlock()
log.Infof("refreshed mgmt cache domain=%s type=%s",
d.SafeString(), dns.TypeToString[question.Qtype])
return nil
}
func (m *Resolver) markRefreshing(question dns.Question) {
m.mutex.Lock()
m.refreshing[question] = &atomic.Bool{}
m.mutex.Unlock()
}
func (m *Resolver) clearRefreshing(question dns.Question) {
m.mutex.Lock()
delete(m.refreshing, question)
m.mutex.Unlock()
}
// markRefreshFailed arms the backoff and returns the new consecutive-failure
// count so callers can downgrade subsequent failure logs to debug.
func (m *Resolver) markRefreshFailed(question dns.Question, expected *cachedRecord) int {
m.mutex.Lock()
defer m.mutex.Unlock()
c, ok := m.records[question]
if !ok || c != expected {
return 0
}
c.lastFailedRefresh = time.Now()
c.consecFailures++
return c.consecFailures
}
// lookupBoth resolves A and AAAA via chain or OS. Per-family errors let
// callers distinguish real records, NODATA (nil err, no records), and failure.
func (m *Resolver) lookupBoth(ctx context.Context, d domain.Domain, dnsName string) (aRecords, aaaaRecords []dns.RR, errA, errAAAA error) {
m.mutex.RLock()
chain := m.chain
maxPriority := m.chainMaxPriority
m.mutex.RUnlock()
if chain != nil && chain.HasRootHandlerAtOrBelow(maxPriority) {
aRecords, errA = m.lookupViaChain(ctx, chain, maxPriority, dnsName, dns.TypeA)
aaaaRecords, errAAAA = m.lookupViaChain(ctx, chain, maxPriority, dnsName, dns.TypeAAAA)
return
}
// TODO: drop once every supported OS registers a fallback resolver. Safe
// today: no root handler at priority ≤ PriorityUpstream means NetBird is
// not the system resolver, so net.DefaultResolver will not loop back.
aRecords, errA = m.osLookup(ctx, d, dnsName, dns.TypeA)
aaaaRecords, errAAAA = m.osLookup(ctx, d, dnsName, dns.TypeAAAA)
return
}
// lookupRecords resolves a single record type via chain or OS. The OS branch
// arms the loop detector for the duration of its call so that ServeDNS can
// spot the OS resolver routing the recursive query back to us.
func (m *Resolver) lookupRecords(ctx context.Context, d domain.Domain, q dns.Question) ([]dns.RR, error) {
m.mutex.RLock()
chain := m.chain
maxPriority := m.chainMaxPriority
m.mutex.RUnlock()
if chain != nil && chain.HasRootHandlerAtOrBelow(maxPriority) {
return m.lookupViaChain(ctx, chain, maxPriority, q.Name, q.Qtype)
}
// TODO: drop once every supported OS registers a fallback resolver.
m.markRefreshing(q)
defer m.clearRefreshing(q)
return m.osLookup(ctx, d, q.Name, q.Qtype)
}
// lookupViaChain resolves via the handler chain and rewrites each RR to use
// dnsName as owner and m.cacheTTL as TTL, so CNAME-backed domains don't cache
// target-owned records or upstream TTLs. NODATA returns (nil, nil).
func (m *Resolver) lookupViaChain(ctx context.Context, chain ChainResolver, maxPriority int, dnsName string, qtype uint16) ([]dns.RR, error) {
msg := &dns.Msg{}
msg.SetQuestion(dnsName, qtype)
msg.RecursionDesired = true
resp, err := chain.ResolveInternal(ctx, msg, maxPriority)
if err != nil {
return nil, fmt.Errorf("chain resolve: %w", err)
}
if resp == nil {
return nil, fmt.Errorf("chain resolve returned nil response")
}
if resp.Rcode != dns.RcodeSuccess {
return nil, fmt.Errorf("chain resolve rcode=%s", dns.RcodeToString[resp.Rcode])
}
ttl := uint32(m.cacheTTL.Seconds())
owners := cnameOwners(dnsName, resp.Answer)
var filtered []dns.RR
for _, rr := range resp.Answer {
h := rr.Header()
if h.Class != dns.ClassINET || h.Rrtype != qtype {
continue
}
if !owners[strings.ToLower(dns.Fqdn(h.Name))] {
continue
}
if cp := cloneIPRecord(rr, dnsName, ttl); cp != nil {
filtered = append(filtered, cp)
}
}
return filtered, nil
}
// osLookup resolves a single family via net.DefaultResolver using resutil,
// which disambiguates NODATA from NXDOMAIN and Unmaps v4-mapped-v6. NODATA
// returns (nil, nil).
func (m *Resolver) osLookup(ctx context.Context, d domain.Domain, dnsName string, qtype uint16) ([]dns.RR, error) {
network := resutil.NetworkForQtype(qtype)
if network == "" {
return nil, fmt.Errorf("unsupported qtype %s", dns.TypeToString[qtype])
}
log.Infof("looking up IP for mgmt domain=%s type=%s", d.SafeString(), dns.TypeToString[qtype])
defer log.Infof("done looking up IP for mgmt domain=%s type=%s", d.SafeString(), dns.TypeToString[qtype])
result := resutil.LookupIP(ctx, net.DefaultResolver, network, d.PunycodeString(), qtype)
if result.Rcode == dns.RcodeSuccess {
return resutil.IPsToRRs(dnsName, result.IPs, uint32(m.cacheTTL.Seconds())), nil
}
if result.Err != nil {
return nil, fmt.Errorf("resolve %s type=%s: %w", d.SafeString(), dns.TypeToString[qtype], result.Err)
}
return nil, fmt.Errorf("resolve %s type=%s: rcode=%s", d.SafeString(), dns.TypeToString[qtype], dns.RcodeToString[result.Rcode])
}
// responseTTL returns the remaining cache lifetime in seconds (rounded up),
// so downstream resolvers don't cache an answer for longer than we will.
func (m *Resolver) responseTTL(cachedAt time.Time) uint32 {
remaining := m.cacheTTL - time.Since(cachedAt)
if remaining <= 0 {
return 0
}
return uint32((remaining + time.Second - 1) / time.Second)
}
// PopulateFromConfig extracts and caches domains from the client configuration.
@@ -224,19 +456,12 @@ func (m *Resolver) RemoveDomain(d domain.Domain) error {
m.mutex.Lock()
defer m.mutex.Unlock()
aQuestion := dns.Question{
Name: dnsName,
Qtype: dns.TypeA,
Qclass: dns.ClassINET,
}
delete(m.records, aQuestion)
aaaaQuestion := dns.Question{
Name: dnsName,
Qtype: dns.TypeAAAA,
Qclass: dns.ClassINET,
}
delete(m.records, aaaaQuestion)
qA := dns.Question{Name: dnsName, Qtype: dns.TypeA, Qclass: dns.ClassINET}
qAAAA := dns.Question{Name: dnsName, Qtype: dns.TypeAAAA, Qclass: dns.ClassINET}
delete(m.records, qA)
delete(m.records, qAAAA)
delete(m.refreshing, qA)
delete(m.refreshing, qAAAA)
log.Debugf("removed domain=%s from cache", d.SafeString())
return nil
@@ -394,3 +619,73 @@ func (m *Resolver) extractDomainsFromServerDomains(serverDomains dnsconfig.Serve
return domains
}
// cloneIPRecord returns a deep copy of rr retargeted to owner with ttl.
// Non-A/AAAA records return nil.
func cloneIPRecord(rr dns.RR, owner string, ttl uint32) dns.RR {
switch r := rr.(type) {
case *dns.A:
cp := *r
cp.Hdr.Name = owner
cp.Hdr.Ttl = ttl
cp.A = slices.Clone(r.A)
return &cp
case *dns.AAAA:
cp := *r
cp.Hdr.Name = owner
cp.Hdr.Ttl = ttl
cp.AAAA = slices.Clone(r.AAAA)
return &cp
}
return nil
}
// cloneRecordsWithTTL clones A/AAAA records preserving their owner and
// stamping ttl so the response shares no memory with the cached slice.
func cloneRecordsWithTTL(records []dns.RR, ttl uint32) []dns.RR {
out := make([]dns.RR, 0, len(records))
for _, rr := range records {
if cp := cloneIPRecord(rr, rr.Header().Name, ttl); cp != nil {
out = append(out, cp)
}
}
return out
}
// cnameOwners returns dnsName plus every target reachable by following CNAMEs
// in answer, iterating until fixed point so out-of-order chains resolve.
func cnameOwners(dnsName string, answer []dns.RR) map[string]bool {
owners := map[string]bool{dnsName: true}
for {
added := false
for _, rr := range answer {
cname, ok := rr.(*dns.CNAME)
if !ok {
continue
}
name := strings.ToLower(dns.Fqdn(cname.Hdr.Name))
if !owners[name] {
continue
}
target := strings.ToLower(dns.Fqdn(cname.Target))
if !owners[target] {
owners[target] = true
added = true
}
}
if !added {
return owners
}
}
}
// resolveCacheTTL reads the cache TTL override env var; invalid or empty
// values fall back to defaultTTL. Called once per Resolver from NewResolver.
func resolveCacheTTL() time.Duration {
if v := os.Getenv(envMgmtCacheTTL); v != "" {
if d, err := time.ParseDuration(v); err == nil && d > 0 {
return d
}
}
return defaultTTL
}

View File

@@ -0,0 +1,408 @@
package mgmt
import (
"context"
"errors"
"net"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/miekg/dns"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/client/internal/dns/test"
"github.com/netbirdio/netbird/shared/management/domain"
)
type fakeChain struct {
mu sync.Mutex
calls map[string]int
answers map[string][]dns.RR
err error
hasRoot bool
onLookup func()
}
func newFakeChain() *fakeChain {
return &fakeChain{
calls: map[string]int{},
answers: map[string][]dns.RR{},
hasRoot: true,
}
}
func (f *fakeChain) HasRootHandlerAtOrBelow(maxPriority int) bool {
f.mu.Lock()
defer f.mu.Unlock()
return f.hasRoot
}
func (f *fakeChain) ResolveInternal(ctx context.Context, msg *dns.Msg, maxPriority int) (*dns.Msg, error) {
f.mu.Lock()
q := msg.Question[0]
key := q.Name + "|" + dns.TypeToString[q.Qtype]
f.calls[key]++
answers := f.answers[key]
err := f.err
onLookup := f.onLookup
f.mu.Unlock()
if onLookup != nil {
onLookup()
}
if err != nil {
return nil, err
}
resp := &dns.Msg{}
resp.SetReply(msg)
resp.Answer = answers
return resp, nil
}
func (f *fakeChain) setAnswer(name string, qtype uint16, ip string) {
f.mu.Lock()
defer f.mu.Unlock()
key := name + "|" + dns.TypeToString[qtype]
hdr := dns.RR_Header{Name: name, Rrtype: qtype, Class: dns.ClassINET, Ttl: 60}
switch qtype {
case dns.TypeA:
f.answers[key] = []dns.RR{&dns.A{Hdr: hdr, A: net.ParseIP(ip).To4()}}
case dns.TypeAAAA:
f.answers[key] = []dns.RR{&dns.AAAA{Hdr: hdr, AAAA: net.ParseIP(ip).To16()}}
}
}
func (f *fakeChain) callCount(name string, qtype uint16) int {
f.mu.Lock()
defer f.mu.Unlock()
return f.calls[name+"|"+dns.TypeToString[qtype]]
}
// waitFor polls the predicate until it returns true or the deadline passes.
func waitFor(t *testing.T, d time.Duration, fn func() bool) {
t.Helper()
deadline := time.Now().Add(d)
for time.Now().Before(deadline) {
if fn() {
return
}
time.Sleep(5 * time.Millisecond)
}
t.Fatalf("condition not met within %s", d)
}
func queryA(t *testing.T, r *Resolver, name string) *dns.Msg {
t.Helper()
msg := new(dns.Msg)
msg.SetQuestion(name, dns.TypeA)
w := &test.MockResponseWriter{}
r.ServeDNS(w, msg)
return w.GetLastResponse()
}
func firstA(t *testing.T, resp *dns.Msg) string {
t.Helper()
require.NotNil(t, resp)
require.Greater(t, len(resp.Answer), 0, "expected at least one answer")
a, ok := resp.Answer[0].(*dns.A)
require.True(t, ok, "expected A record")
return a.A.String()
}
func TestResolver_CacheTTLGatesRefresh(t *testing.T) {
// Same cached entry age, different cacheTTL values: the shorter TTL must
// trigger a background refresh, the longer one must not. Proves that the
// per-Resolver cacheTTL field actually drives the stale decision.
cachedAt := time.Now().Add(-100 * time.Millisecond)
newRec := func() *cachedRecord {
return &cachedRecord{
records: []dns.RR{&dns.A{
Hdr: dns.RR_Header{Name: "mgmt.example.com.", Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
A: net.ParseIP("10.0.0.1").To4(),
}},
cachedAt: cachedAt,
}
}
q := dns.Question{Name: "mgmt.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}
t.Run("short TTL treats entry as stale and refreshes", func(t *testing.T) {
r := NewResolver()
r.cacheTTL = 10 * time.Millisecond
chain := newFakeChain()
chain.setAnswer(q.Name, dns.TypeA, "10.0.0.2")
r.SetChainResolver(chain, 50)
r.records[q] = newRec()
resp := queryA(t, r, q.Name)
assert.Equal(t, "10.0.0.1", firstA(t, resp), "stale entry must be served while refresh runs")
waitFor(t, time.Second, func() bool {
return chain.callCount(q.Name, dns.TypeA) >= 1
})
})
t.Run("long TTL keeps entry fresh and skips refresh", func(t *testing.T) {
r := NewResolver()
r.cacheTTL = time.Hour
chain := newFakeChain()
chain.setAnswer(q.Name, dns.TypeA, "10.0.0.2")
r.SetChainResolver(chain, 50)
r.records[q] = newRec()
resp := queryA(t, r, q.Name)
assert.Equal(t, "10.0.0.1", firstA(t, resp))
time.Sleep(50 * time.Millisecond)
assert.Equal(t, 0, chain.callCount(q.Name, dns.TypeA), "fresh entry must not trigger refresh")
})
}
func TestResolver_ServeFresh_NoRefresh(t *testing.T) {
r := NewResolver()
chain := newFakeChain()
chain.setAnswer("mgmt.example.com.", dns.TypeA, "10.0.0.2")
r.SetChainResolver(chain, 50)
r.records[dns.Question{Name: "mgmt.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}] = &cachedRecord{
records: []dns.RR{&dns.A{
Hdr: dns.RR_Header{Name: "mgmt.example.com.", Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
A: net.ParseIP("10.0.0.1").To4(),
}},
cachedAt: time.Now(), // fresh
}
resp := queryA(t, r, "mgmt.example.com.")
assert.Equal(t, "10.0.0.1", firstA(t, resp))
time.Sleep(20 * time.Millisecond)
assert.Equal(t, 0, chain.callCount("mgmt.example.com.", dns.TypeA), "fresh entry must not trigger refresh")
}
func TestResolver_StaleTriggersAsyncRefresh(t *testing.T) {
r := NewResolver()
chain := newFakeChain()
chain.setAnswer("mgmt.example.com.", dns.TypeA, "10.0.0.2")
r.SetChainResolver(chain, 50)
q := dns.Question{Name: "mgmt.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}
r.records[q] = &cachedRecord{
records: []dns.RR{&dns.A{
Hdr: dns.RR_Header{Name: q.Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
A: net.ParseIP("10.0.0.1").To4(),
}},
cachedAt: time.Now().Add(-2 * defaultTTL), // stale
}
// First query: serves stale immediately.
resp := queryA(t, r, "mgmt.example.com.")
assert.Equal(t, "10.0.0.1", firstA(t, resp), "stale entry must be served while refresh runs")
waitFor(t, time.Second, func() bool {
return chain.callCount("mgmt.example.com.", dns.TypeA) >= 1
})
// Next query should now return the refreshed IP.
waitFor(t, time.Second, func() bool {
resp := queryA(t, r, "mgmt.example.com.")
return resp != nil && len(resp.Answer) > 0 && firstA(t, resp) == "10.0.0.2"
})
}
func TestResolver_ConcurrentStaleHitsCollapseRefresh(t *testing.T) {
r := NewResolver()
chain := newFakeChain()
chain.setAnswer("mgmt.example.com.", dns.TypeA, "10.0.0.2")
var inflight atomic.Int32
var maxInflight atomic.Int32
chain.onLookup = func() {
cur := inflight.Add(1)
defer inflight.Add(-1)
for {
prev := maxInflight.Load()
if cur <= prev || maxInflight.CompareAndSwap(prev, cur) {
break
}
}
time.Sleep(50 * time.Millisecond) // hold inflight long enough to collide
}
r.SetChainResolver(chain, 50)
q := dns.Question{Name: "mgmt.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}
r.records[q] = &cachedRecord{
records: []dns.RR{&dns.A{
Hdr: dns.RR_Header{Name: q.Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
A: net.ParseIP("10.0.0.1").To4(),
}},
cachedAt: time.Now().Add(-2 * defaultTTL),
}
var wg sync.WaitGroup
for i := 0; i < 50; i++ {
wg.Add(1)
go func() {
defer wg.Done()
queryA(t, r, "mgmt.example.com.")
}()
}
wg.Wait()
waitFor(t, 2*time.Second, func() bool {
return inflight.Load() == 0
})
calls := chain.callCount("mgmt.example.com.", dns.TypeA)
assert.LessOrEqual(t, calls, 2, "singleflight must collapse concurrent refreshes (got %d)", calls)
assert.Equal(t, int32(1), maxInflight.Load(), "only one refresh should run concurrently")
}
func TestResolver_RefreshFailureArmsBackoff(t *testing.T) {
r := NewResolver()
chain := newFakeChain()
chain.err = errors.New("boom")
r.SetChainResolver(chain, 50)
q := dns.Question{Name: "mgmt.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}
r.records[q] = &cachedRecord{
records: []dns.RR{&dns.A{
Hdr: dns.RR_Header{Name: q.Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
A: net.ParseIP("10.0.0.1").To4(),
}},
cachedAt: time.Now().Add(-2 * defaultTTL),
}
// First stale hit triggers a refresh attempt that fails.
resp := queryA(t, r, "mgmt.example.com.")
assert.Equal(t, "10.0.0.1", firstA(t, resp), "stale entry served while refresh fails")
waitFor(t, time.Second, func() bool {
return chain.callCount("mgmt.example.com.", dns.TypeA) == 1
})
waitFor(t, time.Second, func() bool {
r.mutex.RLock()
defer r.mutex.RUnlock()
c, ok := r.records[q]
return ok && !c.lastFailedRefresh.IsZero()
})
// Subsequent stale hits within backoff window should not schedule more refreshes.
for i := 0; i < 10; i++ {
queryA(t, r, "mgmt.example.com.")
}
time.Sleep(50 * time.Millisecond)
assert.Equal(t, 1, chain.callCount("mgmt.example.com.", dns.TypeA), "backoff must suppress further refreshes")
}
func TestResolver_NoRootHandler_SkipsChain(t *testing.T) {
r := NewResolver()
chain := newFakeChain()
chain.hasRoot = false
chain.setAnswer("mgmt.example.com.", dns.TypeA, "10.0.0.2")
r.SetChainResolver(chain, 50)
// With hasRoot=false the chain must not be consulted. Use a short
// deadline so the OS fallback returns quickly without waiting on a
// real network call in CI.
ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
defer cancel()
_, _, _, _ = r.lookupBoth(ctx, domain.Domain("mgmt.example.com"), "mgmt.example.com.")
assert.Equal(t, 0, chain.callCount("mgmt.example.com.", dns.TypeA),
"chain must not be used when no root handler is registered at the bound priority")
}
func TestResolver_ServeDuringRefreshSetsLoopFlag(t *testing.T) {
// ServeDNS being invoked for a question while a refresh for that question
// is inflight indicates a resolver loop (OS resolver sent the recursive
// query back to us). The inflightRefresh.loopLoggedOnce flag must be set.
r := NewResolver()
q := dns.Question{Name: "mgmt.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}
r.records[q] = &cachedRecord{
records: []dns.RR{&dns.A{
Hdr: dns.RR_Header{Name: q.Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
A: net.ParseIP("10.0.0.1").To4(),
}},
cachedAt: time.Now(),
}
// Simulate an inflight refresh.
r.markRefreshing(q)
defer r.clearRefreshing(q)
resp := queryA(t, r, "mgmt.example.com.")
assert.Equal(t, "10.0.0.1", firstA(t, resp), "cached entry must still be served while a refresh is inflight, to avoid breaking external queries")
r.mutex.RLock()
inflight := r.refreshing[q]
r.mutex.RUnlock()
require.NotNil(t, inflight)
assert.True(t, inflight.Load(), "loop flag must be set once a ServeDNS during refresh was observed")
}
func TestResolver_LoopFlagOnlyTrippedOncePerRefresh(t *testing.T) {
r := NewResolver()
q := dns.Question{Name: "mgmt.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}
r.records[q] = &cachedRecord{
records: []dns.RR{&dns.A{
Hdr: dns.RR_Header{Name: q.Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
A: net.ParseIP("10.0.0.1").To4(),
}},
cachedAt: time.Now(),
}
r.markRefreshing(q)
defer r.clearRefreshing(q)
// Multiple ServeDNS calls during the same refresh must not re-set the flag
// (CompareAndSwap from false -> true returns true only on the first call).
for range 5 {
queryA(t, r, "mgmt.example.com.")
}
r.mutex.RLock()
inflight := r.refreshing[q]
r.mutex.RUnlock()
assert.True(t, inflight.Load())
}
func TestResolver_NoLoopFlagWhenNotRefreshing(t *testing.T) {
r := NewResolver()
q := dns.Question{Name: "mgmt.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}
r.records[q] = &cachedRecord{
records: []dns.RR{&dns.A{
Hdr: dns.RR_Header{Name: q.Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
A: net.ParseIP("10.0.0.1").To4(),
}},
cachedAt: time.Now(),
}
queryA(t, r, "mgmt.example.com.")
r.mutex.RLock()
_, ok := r.refreshing[q]
r.mutex.RUnlock()
assert.False(t, ok, "no refresh inflight means no loop tracking")
}
func TestResolver_AddDomain_UsesChainWhenRootRegistered(t *testing.T) {
r := NewResolver()
chain := newFakeChain()
chain.setAnswer("mgmt.example.com.", dns.TypeA, "10.0.0.2")
chain.setAnswer("mgmt.example.com.", dns.TypeAAAA, "fd00::2")
r.SetChainResolver(chain, 50)
require.NoError(t, r.AddDomain(context.Background(), domain.Domain("mgmt.example.com")))
resp := queryA(t, r, "mgmt.example.com.")
assert.Equal(t, "10.0.0.2", firstA(t, resp))
assert.Equal(t, 1, chain.callCount("mgmt.example.com.", dns.TypeA))
assert.Equal(t, 1, chain.callCount("mgmt.example.com.", dns.TypeAAAA))
}



@@ -6,6 +6,7 @@ import (
"net/url"
"strings"
"testing"
"time"
"github.com/miekg/dns"
"github.com/stretchr/testify/assert"
@@ -23,6 +24,60 @@ func TestResolver_NewResolver(t *testing.T) {
assert.False(t, resolver.MatchSubdomains())
}
func TestResolveCacheTTL(t *testing.T) {
tests := []struct {
name string
value string
want time.Duration
}{
{"unset falls back to default", "", defaultTTL},
{"valid duration", "45s", 45 * time.Second},
{"valid minutes", "2m", 2 * time.Minute},
{"malformed falls back to default", "not-a-duration", defaultTTL},
{"zero falls back to default", "0s", defaultTTL},
{"negative falls back to default", "-5s", defaultTTL},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
t.Setenv(envMgmtCacheTTL, tc.value)
got := resolveCacheTTL()
assert.Equal(t, tc.want, got, "parsed TTL should match")
})
}
}
func TestNewResolver_CacheTTLFromEnv(t *testing.T) {
t.Setenv(envMgmtCacheTTL, "7s")
r := NewResolver()
assert.Equal(t, 7*time.Second, r.cacheTTL, "NewResolver should evaluate cacheTTL once from env")
}
func TestResolver_ResponseTTL(t *testing.T) {
now := time.Now()
tests := []struct {
name string
cacheTTL time.Duration
cachedAt time.Time
wantMin uint32
wantMax uint32
}{
{"fresh entry returns full TTL", 60 * time.Second, now, 59, 60},
{"half-aged entry returns half TTL", 60 * time.Second, now.Add(-30 * time.Second), 29, 31},
{"expired entry returns zero", 60 * time.Second, now.Add(-61 * time.Second), 0, 0},
{"exactly expired returns zero", 10 * time.Second, now.Add(-10 * time.Second), 0, 0},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
r := &Resolver{cacheTTL: tc.cacheTTL}
got := r.responseTTL(tc.cachedAt)
assert.GreaterOrEqual(t, got, tc.wantMin, "remaining TTL should be >= wantMin")
assert.LessOrEqual(t, got, tc.wantMax, "remaining TTL should be <= wantMax")
})
}
}
func TestResolver_ExtractDomainFromURL(t *testing.T) {
tests := []struct {
name string


@@ -212,6 +212,7 @@ func newDefaultServer(
ctx, stop := context.WithCancel(ctx)
mgmtCacheResolver := mgmt.NewResolver()
mgmtCacheResolver.SetChainResolver(handlerChain, PriorityUpstream)
defaultServer := &DefaultServer{
ctx: ctx,


@@ -26,6 +26,7 @@ import (
nberrors "github.com/netbirdio/netbird/client/errors"
"github.com/netbirdio/netbird/client/firewall"
"github.com/netbirdio/netbird/client/firewall/firewalld"
firewallManager "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/device"
@@ -117,13 +118,11 @@ type EngineConfig struct {
RosenpassPermissive bool
ServerSSHAllowed bool
ServerVNCAllowed bool
EnableSSHRoot *bool
EnableSSHSFTP *bool
EnableSSHLocalPortForwarding *bool
EnableSSHRemotePortForwarding *bool
DisableSSHAuth *bool
DisableVNCAuth *bool
DNSRouteInterval time.Duration
@@ -200,7 +199,6 @@ type Engine struct {
networkMonitor *networkmonitor.NetworkMonitor
sshServer sshServer
vncSrv vncServer
statusRecorder *peer.Status
@@ -314,10 +312,6 @@ func (e *Engine) Stop() error {
log.Warnf("failed to stop SSH server: %v", err)
}
if err := e.stopVNCServer(); err != nil {
log.Warnf("failed to stop VNC server: %v", err)
}
e.cleanupSSHConfig()
if e.ingressGatewayMgr != nil {
@@ -577,7 +571,7 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
e.connMgr.Start(e.ctx)
e.srWatcher = guard.NewSRWatcher(e.signal, e.relayManager, e.mobileDep.IFaceDiscover, iceCfg)
e.srWatcher.Start()
e.srWatcher.Start(peer.IsForceRelayed())
e.receiveSignalEvents()
e.receiveManagementEvents()
@@ -611,6 +605,8 @@ func (e *Engine) createFirewall() error {
return nil
}
firewalld.SetParentContext(e.ctx)
var err error
e.firewall, err = firewall.NewFirewall(e.wgInterface, e.stateManager, e.flowManager.GetLogger(), e.config.DisableServerRoutes, e.config.MTU)
if err != nil {
@@ -1005,7 +1001,6 @@ func (e *Engine) updateChecksIfNew(checks []*mgmProto.Checks) error {
e.config.RosenpassEnabled,
e.config.RosenpassPermissive,
&e.config.ServerSSHAllowed,
&e.config.ServerVNCAllowed,
e.config.DisableClientRoutes,
e.config.DisableServerRoutes,
e.config.DisableDNS,
@@ -1018,7 +1013,6 @@ func (e *Engine) updateChecksIfNew(checks []*mgmProto.Checks) error {
e.config.EnableSSHLocalPortForwarding,
e.config.EnableSSHRemotePortForwarding,
e.config.DisableSSHAuth,
e.config.DisableVNCAuth,
)
if err := e.mgmClient.SyncMeta(info); err != nil {
@@ -1046,10 +1040,6 @@ func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error {
}
}
if err := e.updateVNC(conf.GetSshConfig()); err != nil {
log.Warnf("failed handling VNC server setup: %v", err)
}
state := e.statusRecorder.GetLocalPeerState()
state.IP = e.wgInterface.Address().String()
state.PubKey = e.config.WgPrivateKey.PublicKey().String()
@@ -1152,7 +1142,6 @@ func (e *Engine) receiveManagementEvents() {
e.config.RosenpassEnabled,
e.config.RosenpassPermissive,
&e.config.ServerSSHAllowed,
&e.config.ServerVNCAllowed,
e.config.DisableClientRoutes,
e.config.DisableServerRoutes,
e.config.DisableDNS,
@@ -1165,7 +1154,6 @@ func (e *Engine) receiveManagementEvents() {
e.config.EnableSSHLocalPortForwarding,
e.config.EnableSSHRemotePortForwarding,
e.config.DisableSSHAuth,
e.config.DisableVNCAuth,
)
err = e.mgmClient.Sync(e.ctx, info, e.handleSync)
@@ -1340,11 +1328,6 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
}
e.updateSSHServerAuth(networkMap.GetSshAuth())
// VNC auth: use dedicated VNCAuth if present.
if vncAuth := networkMap.GetVncAuth(); vncAuth != nil {
e.updateVNCServerAuth(vncAuth)
}
}
// must set the exclude list after the peers are added. Without it the manager can not figure out the peers parameters from the store
@@ -1754,7 +1737,6 @@ func (e *Engine) readInitialSettings() ([]*route.Route, *nbdns.Config, bool, err
e.config.RosenpassEnabled,
e.config.RosenpassPermissive,
&e.config.ServerSSHAllowed,
&e.config.ServerVNCAllowed,
e.config.DisableClientRoutes,
e.config.DisableServerRoutes,
e.config.DisableDNS,
@@ -1767,7 +1749,6 @@ func (e *Engine) readInitialSettings() ([]*route.Route, *nbdns.Config, bool, err
e.config.EnableSSHLocalPortForwarding,
e.config.EnableSSHRemotePortForwarding,
e.config.DisableSSHAuth,
e.config.DisableVNCAuth,
)
netMap, err := e.mgmClient.GetNetworkMap(info)


@@ -1,309 +0,0 @@
package internal
import (
"context"
"errors"
"fmt"
"net/netip"
"os"
"path/filepath"
log "github.com/sirupsen/logrus"
firewallManager "github.com/netbirdio/netbird/client/firewall/manager"
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
sshauth "github.com/netbirdio/netbird/client/ssh/auth"
vncserver "github.com/netbirdio/netbird/client/vnc/server"
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
sshuserhash "github.com/netbirdio/netbird/shared/sshauth"
)
const envVNCForceRecording = "NB_VNC_FORCE_RECORDING"
const (
vncExternalPort uint16 = 5900
vncInternalPort uint16 = 25900
)
type vncServer interface {
Start(ctx context.Context, addr netip.AddrPort, network netip.Prefix) error
Stop() error
}
func (e *Engine) setupVNCPortRedirection() error {
if e.firewall == nil || e.wgInterface == nil {
return nil
}
localAddr := e.wgInterface.Address().IP
if !localAddr.IsValid() {
return errors.New("invalid local NetBird address")
}
if err := e.firewall.AddInboundDNAT(localAddr, firewallManager.ProtocolTCP, vncExternalPort, vncInternalPort); err != nil {
return fmt.Errorf("add VNC port redirection: %w", err)
}
log.Infof("VNC port redirection: %s:%d -> %s:%d", localAddr, vncExternalPort, localAddr, vncInternalPort)
return nil
}
func (e *Engine) cleanupVNCPortRedirection() error {
if e.firewall == nil || e.wgInterface == nil {
return nil
}
localAddr := e.wgInterface.Address().IP
if !localAddr.IsValid() {
return errors.New("invalid local NetBird address")
}
if err := e.firewall.RemoveInboundDNAT(localAddr, firewallManager.ProtocolTCP, vncExternalPort, vncInternalPort); err != nil {
return fmt.Errorf("remove VNC port redirection: %w", err)
}
return nil
}
// updateVNC handles starting/stopping the VNC server based on the config flag.
// sshConf provides the JWT identity provider config (shared with SSH).
func (e *Engine) updateVNC(sshConf *mgmProto.SSHConfig) error {
if !e.config.ServerVNCAllowed {
if e.vncSrv != nil {
log.Info("VNC server disabled, stopping")
}
return e.stopVNCServer()
}
if e.config.BlockInbound {
log.Info("VNC server disabled because inbound connections are blocked")
return e.stopVNCServer()
}
if e.vncSrv != nil {
// Update JWT config on existing server in case management sent new config.
e.updateVNCServerJWT(sshConf)
return nil
}
return e.startVNCServer(sshConf)
}
func (e *Engine) startVNCServer(sshConf *mgmProto.SSHConfig) error {
if e.wgInterface == nil {
return errors.New("wg interface not initialized")
}
capturer, injector := newPlatformVNC()
if capturer == nil || injector == nil {
log.Debug("VNC server not supported on this platform")
return nil
}
netbirdIP := e.wgInterface.Address().IP
srv := vncserver.New(capturer, injector, "")
if vncNeedsServiceMode() {
log.Info("VNC: running in Session 0, enabling service mode (agent proxy)")
srv.SetServiceMode(true)
}
// Configure VNC authentication.
if e.config.DisableVNCAuth != nil && *e.config.DisableVNCAuth {
log.Info("VNC: authentication disabled by config")
srv.SetDisableAuth(true)
} else if protoJWT := sshConf.GetJwtConfig(); protoJWT != nil {
audiences := protoJWT.GetAudiences()
if len(audiences) == 0 && protoJWT.GetAudience() != "" {
audiences = []string{protoJWT.GetAudience()}
}
srv.SetJWTConfig(&vncserver.JWTConfig{
Issuer: protoJWT.GetIssuer(),
Audiences: audiences,
KeysLocation: protoJWT.GetKeysLocation(),
MaxTokenAge: protoJWT.GetMaxTokenAge(),
})
log.Debugf("VNC: JWT authentication configured (issuer=%s)", protoJWT.GetIssuer())
}
e.configureVNCRecording(srv, sshConf)
if netstackNet := e.wgInterface.GetNet(); netstackNet != nil {
srv.SetNetstackNet(netstackNet)
}
listenAddr := netip.AddrPortFrom(netbirdIP, vncInternalPort)
network := e.wgInterface.Address().Network
if err := srv.Start(e.ctx, listenAddr, network); err != nil {
return fmt.Errorf("start VNC server: %w", err)
}
e.vncSrv = srv
if registrar, ok := e.firewall.(interface {
RegisterNetstackService(protocol nftypes.Protocol, port uint16)
}); ok {
registrar.RegisterNetstackService(nftypes.TCP, vncInternalPort)
log.Debugf("registered VNC service for TCP:%d", vncInternalPort)
}
if err := e.setupVNCPortRedirection(); err != nil {
log.Warnf("setup VNC port redirection: %v", err)
}
log.Info("VNC server enabled")
return nil
}
// configureVNCRecording enables session recording on the VNC server from the
// management-supplied settings. The env var NB_VNC_FORCE_RECORDING overrides
// the API for local development: when set, recording is always enabled and
// writes into that directory. Otherwise recordings go next to the state file
// under vnc-recordings/.
func (e *Engine) configureVNCRecording(srv *vncserver.Server, sshConf *mgmProto.SSHConfig) {
recDir := os.Getenv(envVNCForceRecording)
apiEnabled := sshConf.GetEnableRecording()
if recDir == "" && !apiEnabled {
log.Debugf("VNC recording disabled (env=%q, api=%v)", recDir, apiEnabled)
return
}
if recDir == "" {
base := e.defaultRecordingBase()
if base == "" {
log.Warn("VNC recording requested by management but no state directory is available")
return
}
recDir = filepath.Join(base, "vnc-recordings")
} else {
recDir = filepath.Join(recDir, "vnc")
}
srv.SetRecordingDir(recDir)
log.Infof("VNC recording enabled (dir=%s, source=%s)", recDir, recordingSource(apiEnabled))
encKey := string(sshConf.GetRecordingEncryptionKey())
if encKey == "" {
encKey = os.Getenv("NB_VNC_RECORDING_ENCRYPTION_KEY")
}
if encKey != "" {
srv.SetRecordingEncryptionKey(encKey)
log.Info("VNC recording encryption enabled")
}
}
func (e *Engine) defaultRecordingBase() string {
if e.stateManager == nil {
return ""
}
p := e.stateManager.FilePath()
if p == "" {
return ""
}
return filepath.Dir(p)
}
func recordingSource(api bool) string {
if api {
return "management"
}
return "env"
}
// updateVNCServerJWT configures the JWT validation for the VNC server using
// the same JWT config as SSH (same identity provider).
func (e *Engine) updateVNCServerJWT(sshConf *mgmProto.SSHConfig) {
if e.vncSrv == nil {
return
}
vncSrv, ok := e.vncSrv.(*vncserver.Server)
if !ok {
return
}
if e.config.DisableVNCAuth != nil && *e.config.DisableVNCAuth {
vncSrv.SetDisableAuth(true)
return
}
protoJWT := sshConf.GetJwtConfig()
if protoJWT == nil {
return
}
audiences := protoJWT.GetAudiences()
if len(audiences) == 0 && protoJWT.GetAudience() != "" {
audiences = []string{protoJWT.GetAudience()}
}
vncSrv.SetJWTConfig(&vncserver.JWTConfig{
Issuer: protoJWT.GetIssuer(),
Audiences: audiences,
KeysLocation: protoJWT.GetKeysLocation(),
MaxTokenAge: protoJWT.GetMaxTokenAge(),
})
}
// updateVNCServerAuth updates VNC fine-grained access control from management.
func (e *Engine) updateVNCServerAuth(vncAuth *mgmProto.VNCAuth) {
if vncAuth == nil || e.vncSrv == nil {
return
}
vncSrv, ok := e.vncSrv.(*vncserver.Server)
if !ok {
return
}
protoUsers := vncAuth.GetAuthorizedUsers()
authorizedUsers := make([]sshuserhash.UserIDHash, len(protoUsers))
for i, hash := range protoUsers {
if len(hash) != 16 {
log.Warnf("invalid VNC auth hash length %d, expected 16", len(hash))
return
}
authorizedUsers[i] = sshuserhash.UserIDHash(hash)
}
machineUsers := make(map[string][]uint32)
for osUser, indexes := range vncAuth.GetMachineUsers() {
machineUsers[osUser] = indexes.GetIndexes()
}
vncSrv.UpdateVNCAuth(&sshauth.Config{
UserIDClaim: vncAuth.GetUserIDClaim(),
AuthorizedUsers: authorizedUsers,
MachineUsers: machineUsers,
})
}
// GetVNCServerStatus returns whether the VNC server is running.
func (e *Engine) GetVNCServerStatus() bool {
return e.vncSrv != nil
}
func (e *Engine) stopVNCServer() error {
if e.vncSrv == nil {
return nil
}
if err := e.cleanupVNCPortRedirection(); err != nil {
log.Warnf("cleanup VNC port redirection: %v", err)
}
if registrar, ok := e.firewall.(interface {
UnregisterNetstackService(protocol nftypes.Protocol, port uint16)
}); ok {
registrar.UnregisterNetstackService(nftypes.TCP, vncInternalPort)
}
log.Info("stopping VNC server")
err := e.vncSrv.Stop()
e.vncSrv = nil
if err != nil {
return fmt.Errorf("stop VNC server: %w", err)
}
return nil
}


@@ -1,23 +0,0 @@
//go:build darwin && !ios
package internal
import (
log "github.com/sirupsen/logrus"
vncserver "github.com/netbirdio/netbird/client/vnc/server"
)
func newPlatformVNC() (vncserver.ScreenCapturer, vncserver.InputInjector) {
capturer := vncserver.NewMacPoller()
injector, err := vncserver.NewMacInputInjector()
if err != nil {
log.Debugf("VNC: macOS input injector: %v", err)
return capturer, &vncserver.StubInputInjector{}
}
return capturer, injector
}
func vncNeedsServiceMode() bool {
return false
}


@@ -1,13 +0,0 @@
//go:build !windows && !darwin && !freebsd && !(linux && !android)
package internal
import vncserver "github.com/netbirdio/netbird/client/vnc/server"
func newPlatformVNC() (vncserver.ScreenCapturer, vncserver.InputInjector) {
return nil, nil
}
func vncNeedsServiceMode() bool {
return false
}


@@ -1,13 +0,0 @@
//go:build windows
package internal
import vncserver "github.com/netbirdio/netbird/client/vnc/server"
func newPlatformVNC() (vncserver.ScreenCapturer, vncserver.InputInjector) {
return vncserver.NewDesktopCapturer(), vncserver.NewWindowsInputInjector()
}
func vncNeedsServiceMode() bool {
return vncserver.GetCurrentSessionID() == 0
}


@@ -1,23 +0,0 @@
//go:build (linux && !android) || freebsd
package internal
import (
log "github.com/sirupsen/logrus"
vncserver "github.com/netbirdio/netbird/client/vnc/server"
)
func newPlatformVNC() (vncserver.ScreenCapturer, vncserver.InputInjector) {
capturer := vncserver.NewX11Poller("")
injector, err := vncserver.NewX11InputInjector("")
if err != nil {
log.Debugf("VNC: X11 input injector: %v", err)
return capturer, &vncserver.StubInputInjector{}
}
return capturer, injector
}
func vncNeedsServiceMode() bool {
return false
}


@@ -185,17 +185,20 @@ func (conn *Conn) Open(engineCtx context.Context) error {
conn.workerRelay = NewWorkerRelay(conn.ctx, conn.Log, isController(conn.config), conn.config, conn, conn.relayManager)
relayIsSupportedLocally := conn.workerRelay.RelayIsSupportedLocally()
workerICE, err := NewWorkerICE(conn.ctx, conn.Log, conn.config, conn, conn.signaler, conn.iFaceDiscover, conn.statusRecorder, relayIsSupportedLocally)
if err != nil {
return err
forceRelay := IsForceRelayed()
if !forceRelay {
relayIsSupportedLocally := conn.workerRelay.RelayIsSupportedLocally()
workerICE, err := NewWorkerICE(conn.ctx, conn.Log, conn.config, conn, conn.signaler, conn.iFaceDiscover, conn.statusRecorder, relayIsSupportedLocally)
if err != nil {
return err
}
conn.workerICE = workerICE
}
conn.workerICE = workerICE
conn.handshaker = NewHandshaker(conn.Log, conn.config, conn.signaler, conn.workerICE, conn.workerRelay, conn.metricsStages)
conn.handshaker.AddRelayListener(conn.workerRelay.OnNewOffer)
if !isForceRelayed() {
if !forceRelay {
conn.handshaker.AddICEListener(conn.workerICE.OnNewOffer)
}
@@ -251,7 +254,9 @@ func (conn *Conn) Close(signalToRemote bool) {
conn.wgWatcherCancel()
}
conn.workerRelay.CloseConn()
conn.workerICE.Close()
if conn.workerICE != nil {
conn.workerICE.Close()
}
if conn.wgProxyRelay != nil {
err := conn.wgProxyRelay.CloseConn()
@@ -294,7 +299,9 @@ func (conn *Conn) OnRemoteAnswer(answer OfferAnswer) {
// OnRemoteCandidate Handles ICE connection Candidate provided by the remote peer.
func (conn *Conn) OnRemoteCandidate(candidate ice.Candidate, haRoutes route.HAMap) {
conn.dumpState.RemoteCandidate()
conn.workerICE.OnRemoteCandidate(candidate, haRoutes)
if conn.workerICE != nil {
conn.workerICE.OnRemoteCandidate(candidate, haRoutes)
}
}
// SetOnConnected sets a handler function to be triggered by Conn when a new connection to a remote peer established
@@ -712,33 +719,35 @@ func (conn *Conn) evalStatus() ConnStatus {
return StatusConnecting
}
func (conn *Conn) isConnectedOnAllWay() (connected bool) {
// would be better to protect this with a mutex, but it could cause deadlock with Close function
// isConnectedOnAllWay evaluates the overall connection status based on ICE and Relay transports.
//
// The result is a tri-state:
// - ConnStatusConnected: all available transports are up
// - ConnStatusPartiallyConnected: relay is up but ICE is still pending/reconnecting
// - ConnStatusDisconnected: no working transport
func (conn *Conn) isConnectedOnAllWay() (status guard.ConnStatus) {
defer func() {
if !connected {
if status == guard.ConnStatusDisconnected {
conn.logTraceConnState()
}
}()
// For JS platform: only relay connection is supported
if runtime.GOOS == "js" {
return conn.statusRelay.Get() == worker.StatusConnected
iceWorkerCreated := conn.workerICE != nil
var iceInProgress bool
if iceWorkerCreated {
iceInProgress = conn.workerICE.InProgress()
}
// For non-JS platforms: check ICE connection status
if conn.statusICE.Get() == worker.StatusDisconnected && !conn.workerICE.InProgress() {
return false
}
// If relay is supported with peer, it must also be connected
if conn.workerRelay.IsRelayConnectionSupportedWithPeer() {
if conn.statusRelay.Get() == worker.StatusDisconnected {
return false
}
}
return true
return evalConnStatus(connStatusInputs{
forceRelay: IsForceRelayed(),
peerUsesRelay: conn.workerRelay.IsRelayConnectionSupportedWithPeer(),
relayConnected: conn.statusRelay.Get() == worker.StatusConnected,
remoteSupportsICE: conn.handshaker.RemoteICESupported(),
iceWorkerCreated: iceWorkerCreated,
iceStatusConnecting: conn.statusICE.Get() != worker.StatusDisconnected,
iceInProgress: iceInProgress,
})
}
func (conn *Conn) enableWgWatcherIfNeeded(enabledTime time.Time) {
@@ -926,3 +935,43 @@ func isController(config ConnConfig) bool {
func isRosenpassEnabled(remoteRosenpassPubKey []byte) bool {
return remoteRosenpassPubKey != nil
}
func evalConnStatus(in connStatusInputs) guard.ConnStatus {
// "Relay up and needed" — the peer uses relay and the transport is connected.
relayUsedAndUp := in.peerUsesRelay && in.relayConnected
// Force-relay mode: ICE never runs. Relay is the only transport and must be up.
if in.forceRelay {
return boolToConnStatus(relayUsedAndUp)
}
// Remote peer doesn't support ICE, or we haven't created the worker yet:
// relay is the only possible transport.
if !in.remoteSupportsICE || !in.iceWorkerCreated {
return boolToConnStatus(relayUsedAndUp)
}
// ICE counts as "up" when the status is anything other than Disconnected, OR
// when a negotiation is currently in progress (so we don't spam offers while one is in flight).
iceUp := in.iceStatusConnecting || in.iceInProgress
// Relay side is acceptable if the peer doesn't rely on relay, or relay is connected.
relayOK := !in.peerUsesRelay || in.relayConnected
switch {
case iceUp && relayOK:
return guard.ConnStatusConnected
case relayUsedAndUp:
// Relay is up but ICE is down — partially connected.
return guard.ConnStatusPartiallyConnected
default:
return guard.ConnStatusDisconnected
}
}
func boolToConnStatus(connected bool) guard.ConnStatus {
if connected {
return guard.ConnStatusConnected
}
return guard.ConnStatusDisconnected
}


@@ -13,6 +13,20 @@ const (
StatusConnected
)
// connStatusInputs is the primitive-valued snapshot of the state that drives the
// tri-state connection classification. Extracted so the decision logic can be unit-tested
// without constructing full Worker/Handshaker objects.
type connStatusInputs struct {
forceRelay bool // NB_FORCE_RELAY or JS/WASM
peerUsesRelay bool // remote peer advertises relay support AND local has relay
relayConnected bool // statusRelay reports Connected (independent of whether peer uses relay)
remoteSupportsICE bool // remote peer sent ICE credentials
iceWorkerCreated bool // local WorkerICE exists (false in force-relay mode)
iceStatusConnecting bool // statusICE is anything other than Disconnected
iceInProgress bool // a negotiation is currently in flight
}
// ConnStatus describe the status of a peer's connection
type ConnStatus int32


@@ -0,0 +1,201 @@
package peer
import (
"testing"
"github.com/netbirdio/netbird/client/internal/peer/guard"
)
func TestEvalConnStatus_ForceRelay(t *testing.T) {
tests := []struct {
name string
in connStatusInputs
want guard.ConnStatus
}{
{
name: "force relay, peer uses relay, relay up",
in: connStatusInputs{
forceRelay: true,
peerUsesRelay: true,
relayConnected: true,
},
want: guard.ConnStatusConnected,
},
{
name: "force relay, peer uses relay, relay down",
in: connStatusInputs{
forceRelay: true,
peerUsesRelay: true,
relayConnected: false,
},
want: guard.ConnStatusDisconnected,
},
{
name: "force relay, peer does NOT use relay - disconnected forever",
in: connStatusInputs{
forceRelay: true,
peerUsesRelay: false,
relayConnected: true,
},
want: guard.ConnStatusDisconnected,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
if got := evalConnStatus(tc.in); got != tc.want {
t.Fatalf("evalConnStatus = %v, want %v", got, tc.want)
}
})
}
}
func TestEvalConnStatus_ICEUnavailable(t *testing.T) {
tests := []struct {
name string
in connStatusInputs
want guard.ConnStatus
}{
{
name: "remote does not support ICE, peer uses relay, relay up",
in: connStatusInputs{
peerUsesRelay: true,
relayConnected: true,
remoteSupportsICE: false,
iceWorkerCreated: true,
},
want: guard.ConnStatusConnected,
},
{
name: "remote does not support ICE, peer uses relay, relay down",
in: connStatusInputs{
peerUsesRelay: true,
relayConnected: false,
remoteSupportsICE: false,
iceWorkerCreated: true,
},
want: guard.ConnStatusDisconnected,
},
{
name: "ICE worker not yet created, relay up",
in: connStatusInputs{
peerUsesRelay: true,
relayConnected: true,
remoteSupportsICE: true,
iceWorkerCreated: false,
},
want: guard.ConnStatusConnected,
},
{
name: "remote does not support ICE, peer does not use relay",
in: connStatusInputs{
peerUsesRelay: false,
relayConnected: false,
remoteSupportsICE: false,
iceWorkerCreated: true,
},
want: guard.ConnStatusDisconnected,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
if got := evalConnStatus(tc.in); got != tc.want {
t.Fatalf("evalConnStatus = %v, want %v", got, tc.want)
}
})
}
}
func TestEvalConnStatus_FullyAvailable(t *testing.T) {
base := connStatusInputs{
remoteSupportsICE: true,
iceWorkerCreated: true,
}
tests := []struct {
name string
mutator func(*connStatusInputs)
want guard.ConnStatus
}{
{
name: "ICE connected, relay connected, peer uses relay",
mutator: func(in *connStatusInputs) {
in.peerUsesRelay = true
in.relayConnected = true
in.iceStatusConnecting = true
},
want: guard.ConnStatusConnected,
},
{
name: "ICE connected, peer does NOT use relay",
mutator: func(in *connStatusInputs) {
in.peerUsesRelay = false
in.relayConnected = false
in.iceStatusConnecting = true
},
want: guard.ConnStatusConnected,
},
{
name: "ICE InProgress only, peer does NOT use relay",
mutator: func(in *connStatusInputs) {
in.peerUsesRelay = false
in.iceStatusConnecting = false
in.iceInProgress = true
},
want: guard.ConnStatusConnected,
},
{
name: "ICE down, relay up, peer uses relay -> partial",
mutator: func(in *connStatusInputs) {
in.peerUsesRelay = true
in.relayConnected = true
in.iceStatusConnecting = false
in.iceInProgress = false
},
want: guard.ConnStatusPartiallyConnected,
},
{
name: "ICE down, peer does NOT use relay -> disconnected",
mutator: func(in *connStatusInputs) {
in.peerUsesRelay = false
in.relayConnected = false
in.iceStatusConnecting = false
in.iceInProgress = false
},
want: guard.ConnStatusDisconnected,
},
{
name: "ICE up, peer uses relay but relay down -> partial (relay required, ICE ignored)",
mutator: func(in *connStatusInputs) {
in.peerUsesRelay = true
in.relayConnected = false
in.iceStatusConnecting = true
},
// relayOK = false (peer uses relay but it's down), iceUp = true
// first switch arm fails (relayOK false), relayUsedAndUp = false (relay down),
// falls into default: Disconnected.
want: guard.ConnStatusDisconnected,
},
{
name: "ICE down, relay up but peer does not use relay -> disconnected",
mutator: func(in *connStatusInputs) {
in.peerUsesRelay = false
in.relayConnected = true // not actually used since peer doesn't rely on it
in.iceStatusConnecting = false
in.iceInProgress = false
},
want: guard.ConnStatusDisconnected,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
in := base
tc.mutator(&in)
if got := evalConnStatus(in); got != tc.want {
t.Fatalf("evalConnStatus = %v, want %v (inputs: %+v)", got, tc.want, in)
}
})
}
}

View File

@@ -10,7 +10,7 @@ const (
EnvKeyNBForceRelay = "NB_FORCE_RELAY"
)
func isForceRelayed() bool {
func IsForceRelayed() bool {
if runtime.GOOS == "js" {
return true
}

View File

@@ -8,7 +8,19 @@ import (
log "github.com/sirupsen/logrus"
)
type isConnectedFunc func() bool
// ConnStatus represents the connection state as seen by the guard.
type ConnStatus int
const (
// ConnStatusDisconnected means neither ICE nor Relay is connected.
ConnStatusDisconnected ConnStatus = iota
// ConnStatusPartiallyConnected means Relay is connected but ICE is not.
ConnStatusPartiallyConnected
// ConnStatusConnected means all required connections are established.
ConnStatusConnected
)
type connStatusFunc func() ConnStatus
// Guard is responsible for the reconnection logic.
// It triggers sending an offer to the peer when the connection has issues.
@@ -20,14 +32,14 @@ type isConnectedFunc func() bool
// - ICE candidate changes
type Guard struct {
log *log.Entry
isConnectedOnAllWay isConnectedFunc
isConnectedOnAllWay connStatusFunc
timeout time.Duration
srWatcher *SRWatcher
relayedConnDisconnected chan struct{}
iCEConnDisconnected chan struct{}
}
func NewGuard(log *log.Entry, isConnectedFn isConnectedFunc, timeout time.Duration, srWatcher *SRWatcher) *Guard {
func NewGuard(log *log.Entry, isConnectedFn connStatusFunc, timeout time.Duration, srWatcher *SRWatcher) *Guard {
return &Guard{
log: log,
isConnectedOnAllWay: isConnectedFn,
@@ -57,8 +69,17 @@ func (g *Guard) SetICEConnDisconnected() {
}
}
// reconnectLoopWithRetry periodically check the connection status.
// Try to send an offer while the P2P connection is not established, or while the Relay is not connected if it is supported
// reconnectLoopWithRetry periodically checks the connection status and sends offers to re-establish connectivity.
//
// Behavior depends on the connection state reported by isConnectedOnAllWay:
// - Connected: no action, the peer is fully reachable.
// - Disconnected (neither ICE nor Relay): retries aggressively with exponential backoff (800ms doubling
// up to timeout), never gives up. This ensures rapid recovery when the peer has no connectivity at all.
// - PartiallyConnected (Relay up, ICE not): retries up to 3 times with exponential backoff, then switches
// to one attempt per hour. This limits signaling traffic when relay already provides connectivity.
//
// External events (relay/ICE disconnect, signal/relay reconnect, candidate changes) reset the retry
// counter and backoff ticker, giving ICE a fresh chance after network conditions change.
func (g *Guard) reconnectLoopWithRetry(ctx context.Context, callback func()) {
srReconnectedChan := g.srWatcher.NewListener()
defer g.srWatcher.RemoveListener(srReconnectedChan)
@@ -68,36 +89,47 @@ func (g *Guard) reconnectLoopWithRetry(ctx context.Context, callback func()) {
tickerChannel := ticker.C
iceState := &iceRetryState{log: g.log}
defer iceState.reset()
for {
select {
case t := <-tickerChannel:
if t.IsZero() {
g.log.Infof("retry timed out, stop periodic offer sending")
// after backoff timeout the ticker.C will be closed. We need to a dummy channel to avoid loop
tickerChannel = make(<-chan time.Time)
continue
case <-tickerChannel:
switch g.isConnectedOnAllWay() {
case ConnStatusConnected:
// all good, nothing to do
case ConnStatusDisconnected:
callback()
case ConnStatusPartiallyConnected:
if iceState.shouldRetry() {
callback()
} else {
iceState.enterHourlyMode()
ticker.Stop()
tickerChannel = iceState.hourlyC()
}
}
if !g.isConnectedOnAllWay() {
callback()
}
case <-g.relayedConnDisconnected:
g.log.Debugf("Relay connection changed, reset reconnection ticker")
ticker.Stop()
ticker = g.prepareExponentTicker(ctx)
ticker = g.newReconnectTicker(ctx)
tickerChannel = ticker.C
iceState.reset()
case <-g.iCEConnDisconnected:
g.log.Debugf("ICE connection changed, reset reconnection ticker")
ticker.Stop()
ticker = g.prepareExponentTicker(ctx)
ticker = g.newReconnectTicker(ctx)
tickerChannel = ticker.C
iceState.reset()
case <-srReconnectedChan:
g.log.Debugf("has network changes, reset reconnection ticker")
ticker.Stop()
ticker = g.prepareExponentTicker(ctx)
ticker = g.newReconnectTicker(ctx)
tickerChannel = ticker.C
iceState.reset()
case <-ctx.Done():
g.log.Debugf("context is done, stop reconnect loop")
@@ -120,7 +152,7 @@ func (g *Guard) initialTicker(ctx context.Context) *backoff.Ticker {
return backoff.NewTicker(bo)
}
func (g *Guard) prepareExponentTicker(ctx context.Context) *backoff.Ticker {
func (g *Guard) newReconnectTicker(ctx context.Context) *backoff.Ticker {
bo := backoff.WithContext(&backoff.ExponentialBackOff{
InitialInterval: 800 * time.Millisecond,
RandomizationFactor: 0.1,

View File

@@ -0,0 +1,61 @@
package guard
import (
"time"
log "github.com/sirupsen/logrus"
)
const (
// maxICERetries is the maximum number of ICE offer attempts when relay is connected
maxICERetries = 3
// iceRetryInterval is the periodic retry interval after ICE retries are exhausted
iceRetryInterval = 1 * time.Hour
)
// iceRetryState tracks the limited ICE retry attempts when relay is already connected.
// After maxICERetries attempts it switches to a periodic hourly retry.
type iceRetryState struct {
log *log.Entry
retries int
hourly *time.Ticker
}
func (s *iceRetryState) reset() {
s.retries = 0
if s.hourly != nil {
s.hourly.Stop()
s.hourly = nil
}
}
// shouldRetry reports whether the caller should send another ICE offer on this tick.
// Returns false when the per-cycle retry budget is exhausted and the caller must switch
// to the hourly ticker via enterHourlyMode + hourlyC.
func (s *iceRetryState) shouldRetry() bool {
if s.hourly != nil {
s.log.Debugf("hourly ICE retry attempt")
return true
}
s.retries++
if s.retries <= maxICERetries {
s.log.Debugf("ICE retry attempt %d/%d", s.retries, maxICERetries)
return true
}
return false
}
// enterHourlyMode starts the hourly retry ticker. Must be called after shouldRetry returns false.
func (s *iceRetryState) enterHourlyMode() {
s.log.Infof("ICE retries exhausted (%d/%d), switching to hourly retry", maxICERetries, maxICERetries)
s.hourly = time.NewTicker(iceRetryInterval)
}
func (s *iceRetryState) hourlyC() <-chan time.Time {
if s.hourly == nil {
return nil
}
return s.hourly.C
}

View File

@@ -0,0 +1,103 @@
package guard
import (
"testing"
log "github.com/sirupsen/logrus"
)
func newTestRetryState() *iceRetryState {
return &iceRetryState{log: log.NewEntry(log.StandardLogger())}
}
func TestICERetryState_AllowsInitialBudget(t *testing.T) {
s := newTestRetryState()
for i := 1; i <= maxICERetries; i++ {
if !s.shouldRetry() {
t.Fatalf("shouldRetry returned false on attempt %d, want true (budget = %d)", i, maxICERetries)
}
}
}
func TestICERetryState_ExhaustsAfterBudget(t *testing.T) {
s := newTestRetryState()
for i := 0; i < maxICERetries; i++ {
_ = s.shouldRetry()
}
if s.shouldRetry() {
t.Fatalf("shouldRetry returned true after budget exhausted, want false")
}
}
func TestICERetryState_HourlyCNilBeforeEnterHourlyMode(t *testing.T) {
s := newTestRetryState()
if s.hourlyC() != nil {
t.Fatalf("hourlyC returned non-nil channel before enterHourlyMode")
}
}
func TestICERetryState_EnterHourlyModeArmsTicker(t *testing.T) {
s := newTestRetryState()
for i := 0; i < maxICERetries+1; i++ {
_ = s.shouldRetry()
}
s.enterHourlyMode()
defer s.reset()
if s.hourlyC() == nil {
t.Fatalf("hourlyC returned nil after enterHourlyMode")
}
}
func TestICERetryState_ShouldRetryTrueInHourlyMode(t *testing.T) {
s := newTestRetryState()
s.enterHourlyMode()
defer s.reset()
if !s.shouldRetry() {
t.Fatalf("shouldRetry returned false in hourly mode, want true")
}
// Subsequent calls also return true — we keep retrying on each hourly tick.
if !s.shouldRetry() {
t.Fatalf("second shouldRetry returned false in hourly mode, want true")
}
}
func TestICERetryState_ResetRestoresBudget(t *testing.T) {
s := newTestRetryState()
for i := 0; i < maxICERetries+1; i++ {
_ = s.shouldRetry()
}
s.enterHourlyMode()
s.reset()
if s.hourlyC() != nil {
t.Fatalf("hourlyC returned non-nil channel after reset")
}
if s.retries != 0 {
t.Fatalf("retries = %d after reset, want 0", s.retries)
}
for i := 1; i <= maxICERetries; i++ {
if !s.shouldRetry() {
t.Fatalf("shouldRetry returned false on attempt %d after reset, want true", i)
}
}
}
func TestICERetryState_ResetIsIdempotent(t *testing.T) {
s := newTestRetryState()
s.reset()
s.reset() // second call must not panic or re-stop a nil ticker
if s.hourlyC() != nil {
t.Fatalf("hourlyC non-nil after double reset")
}
}

View File

@@ -39,7 +39,7 @@ func NewSRWatcher(signalClient chNotifier, relayManager chNotifier, iFaceDiscove
return srw
}
func (w *SRWatcher) Start() {
func (w *SRWatcher) Start(disableICEMonitor bool) {
w.mu.Lock()
defer w.mu.Unlock()
@@ -50,8 +50,10 @@ func (w *SRWatcher) Start() {
ctx, cancel := context.WithCancel(context.Background())
w.cancelIceMonitor = cancel
iceMonitor := NewICEMonitor(w.iFaceDiscover, w.iceConfig, GetICEMonitorPeriod())
go iceMonitor.Start(ctx, w.onICEChanged)
if !disableICEMonitor {
iceMonitor := NewICEMonitor(w.iFaceDiscover, w.iceConfig, GetICEMonitorPeriod())
go iceMonitor.Start(ctx, w.onICEChanged)
}
w.signalClient.SetOnReconnectedListener(w.onReconnected)
w.relayManager.SetOnReconnectedListener(w.onReconnected)

View File

@@ -4,6 +4,7 @@ import (
"context"
"errors"
"sync"
"sync/atomic"
log "github.com/sirupsen/logrus"
@@ -43,6 +44,10 @@ type OfferAnswer struct {
SessionID *ICESessionID
}
func (o *OfferAnswer) hasICECredentials() bool {
return o.IceCredentials.UFrag != "" && o.IceCredentials.Pwd != ""
}
type Handshaker struct {
mu sync.Mutex
log *log.Entry
@@ -59,6 +64,10 @@ type Handshaker struct {
relayListener *AsyncOfferListener
iceListener func(remoteOfferAnswer *OfferAnswer)
// remoteICESupported tracks whether the remote peer includes ICE credentials in its offers/answers.
// When false, the local side skips ICE listener dispatch and suppresses ICE credentials in responses.
remoteICESupported atomic.Bool
// remoteOffersCh is a channel used to wait for remote credentials to proceed with the connection
remoteOffersCh chan OfferAnswer
// remoteAnswerCh is a channel used to wait for remote credentials answer (confirmation of our offer) to proceed with the connection
@@ -66,7 +75,7 @@ type Handshaker struct {
}
func NewHandshaker(log *log.Entry, config ConnConfig, signaler *Signaler, ice *WorkerICE, relay *WorkerRelay, metricsStages *MetricsStages) *Handshaker {
return &Handshaker{
h := &Handshaker{
log: log,
config: config,
signaler: signaler,
@@ -76,6 +85,13 @@ func NewHandshaker(log *log.Entry, config ConnConfig, signaler *Signaler, ice *W
remoteOffersCh: make(chan OfferAnswer),
remoteAnswerCh: make(chan OfferAnswer),
}
// assume remote supports ICE until we learn otherwise from received offers
h.remoteICESupported.Store(ice != nil)
return h
}
func (h *Handshaker) RemoteICESupported() bool {
return h.remoteICESupported.Load()
}
func (h *Handshaker) AddRelayListener(offer func(remoteOfferAnswer *OfferAnswer)) {
@@ -90,18 +106,20 @@ func (h *Handshaker) Listen(ctx context.Context) {
for {
select {
case remoteOfferAnswer := <-h.remoteOffersCh:
h.log.Infof("received offer, running version %s, remote WireGuard listen port %d, session id: %s", remoteOfferAnswer.Version, remoteOfferAnswer.WgListenPort, remoteOfferAnswer.SessionIDString())
h.log.Infof("received offer, running version %s, remote WireGuard listen port %d, session id: %s, remote ICE supported: %t", remoteOfferAnswer.Version, remoteOfferAnswer.WgListenPort, remoteOfferAnswer.SessionIDString(), remoteOfferAnswer.hasICECredentials())
// Record signaling received for reconnection attempts
if h.metricsStages != nil {
h.metricsStages.RecordSignalingReceived()
}
h.updateRemoteICEState(&remoteOfferAnswer)
if h.relayListener != nil {
h.relayListener.Notify(&remoteOfferAnswer)
}
if h.iceListener != nil {
if h.iceListener != nil && h.RemoteICESupported() {
h.iceListener(&remoteOfferAnswer)
}
@@ -110,18 +128,20 @@ func (h *Handshaker) Listen(ctx context.Context) {
continue
}
case remoteOfferAnswer := <-h.remoteAnswerCh:
h.log.Infof("received answer, running version %s, remote WireGuard listen port %d, session id: %s", remoteOfferAnswer.Version, remoteOfferAnswer.WgListenPort, remoteOfferAnswer.SessionIDString())
h.log.Infof("received answer, running version %s, remote WireGuard listen port %d, session id: %s, remote ICE supported: %t", remoteOfferAnswer.Version, remoteOfferAnswer.WgListenPort, remoteOfferAnswer.SessionIDString(), remoteOfferAnswer.hasICECredentials())
// Record signaling received for reconnection attempts
if h.metricsStages != nil {
h.metricsStages.RecordSignalingReceived()
}
h.updateRemoteICEState(&remoteOfferAnswer)
if h.relayListener != nil {
h.relayListener.Notify(&remoteOfferAnswer)
}
if h.iceListener != nil {
if h.iceListener != nil && h.RemoteICESupported() {
h.iceListener(&remoteOfferAnswer)
}
case <-ctx.Done():
@@ -183,15 +203,18 @@ func (h *Handshaker) sendAnswer() error {
}
func (h *Handshaker) buildOfferAnswer() OfferAnswer {
uFrag, pwd := h.ice.GetLocalUserCredentials()
sid := h.ice.SessionID()
answer := OfferAnswer{
IceCredentials: IceCredentials{uFrag, pwd},
WgListenPort: h.config.LocalWgPort,
Version: version.NetbirdVersion(),
RosenpassPubKey: h.config.RosenpassConfig.PubKey,
RosenpassAddr: h.config.RosenpassConfig.Addr,
SessionID: &sid,
}
if h.ice != nil && h.RemoteICESupported() {
uFrag, pwd := h.ice.GetLocalUserCredentials()
sid := h.ice.SessionID()
answer.IceCredentials = IceCredentials{uFrag, pwd}
answer.SessionID = &sid
}
if addr, err := h.relay.RelayInstanceAddress(); err == nil {
@@ -200,3 +223,18 @@ func (h *Handshaker) buildOfferAnswer() OfferAnswer {
return answer
}
func (h *Handshaker) updateRemoteICEState(offer *OfferAnswer) {
hasICE := offer.hasICECredentials()
prev := h.remoteICESupported.Swap(hasICE)
if prev != hasICE {
if hasICE {
h.log.Infof("remote peer started sending ICE credentials")
} else {
h.log.Infof("remote peer stopped sending ICE credentials")
if h.ice != nil {
h.ice.Close()
}
}
}
}

View File

@@ -46,9 +46,13 @@ func (s *Signaler) Ready() bool {
// SignalOfferAnswer signals either an offer or an answer to remote peer
func (s *Signaler) signalOfferAnswer(offerAnswer OfferAnswer, remoteKey string, bodyType sProto.Body_Type) error {
sessionIDBytes, err := offerAnswer.SessionID.Bytes()
if err != nil {
log.Warnf("failed to get session ID bytes: %v", err)
var sessionIDBytes []byte
if offerAnswer.SessionID != nil {
var err error
sessionIDBytes, err = offerAnswer.SessionID.Bytes()
if err != nil {
log.Warnf("failed to get session ID bytes: %v", err)
}
}
msg, err := signal.MarshalCredential(
s.wgPrivateKey,

View File

@@ -64,13 +64,11 @@ type ConfigInput struct {
StateFilePath string
PreSharedKey *string
ServerSSHAllowed *bool
ServerVNCAllowed *bool
EnableSSHRoot *bool
EnableSSHSFTP *bool
EnableSSHLocalPortForwarding *bool
EnableSSHRemotePortForwarding *bool
DisableSSHAuth *bool
DisableVNCAuth *bool
SSHJWTCacheTTL *int
NATExternalIPs []string
CustomDNSAddress []byte
@@ -116,13 +114,11 @@ type Config struct {
RosenpassEnabled bool
RosenpassPermissive bool
ServerSSHAllowed *bool
ServerVNCAllowed *bool
EnableSSHRoot *bool
EnableSSHSFTP *bool
EnableSSHLocalPortForwarding *bool
EnableSSHRemotePortForwarding *bool
DisableSSHAuth *bool
DisableVNCAuth *bool
SSHJWTCacheTTL *int
DisableClientRoutes bool
@@ -419,21 +415,6 @@ func (config *Config) apply(input ConfigInput) (updated bool, err error) {
updated = true
}
if input.ServerVNCAllowed != nil {
if config.ServerVNCAllowed == nil || *input.ServerVNCAllowed != *config.ServerVNCAllowed {
if *input.ServerVNCAllowed {
log.Infof("enabling VNC server")
} else {
log.Infof("disabling VNC server")
}
config.ServerVNCAllowed = input.ServerVNCAllowed
updated = true
}
} else if config.ServerVNCAllowed == nil {
config.ServerVNCAllowed = util.True()
updated = true
}
if input.EnableSSHRoot != nil && input.EnableSSHRoot != config.EnableSSHRoot {
if *input.EnableSSHRoot {
log.Infof("enabling SSH root login")
@@ -484,16 +465,6 @@ func (config *Config) apply(input ConfigInput) (updated bool, err error) {
updated = true
}
if input.DisableVNCAuth != nil && input.DisableVNCAuth != config.DisableVNCAuth {
if *input.DisableVNCAuth {
log.Infof("disabling VNC authentication")
} else {
log.Infof("enabling VNC authentication")
}
config.DisableVNCAuth = input.DisableVNCAuth
updated = true
}
if input.SSHJWTCacheTTL != nil && input.SSHJWTCacheTTL != config.SSHJWTCacheTTL {
log.Infof("updating SSH JWT cache TTL to %d seconds", *input.SSHJWTCacheTTL)
config.SSHJWTCacheTTL = input.SSHJWTCacheTTL

View File

@@ -74,14 +74,6 @@ func New(filePath string) *Manager {
}
}
// FilePath returns the path of the underlying state file.
func (m *Manager) FilePath() string {
if m == nil {
return ""
}
return m.filePath
}
// Start starts the state manager periodic save routine
func (m *Manager) Start() {
if m == nil {

File diff suppressed because it is too large

View File

@@ -209,9 +209,6 @@ message LoginRequest {
optional bool enableSSHRemotePortForwarding = 37;
optional bool disableSSHAuth = 38;
optional int32 sshJWTCacheTTL = 39;
optional bool serverVNCAllowed = 41;
optional bool disableVNCAuth = 42;
}
message LoginResponse {
@@ -319,10 +316,6 @@ message GetConfigResponse {
bool disableSSHAuth = 25;
int32 sshJWTCacheTTL = 26;
bool serverVNCAllowed = 28;
bool disableVNCAuth = 29;
}
// PeerState contains the latest state of a peer
@@ -401,11 +394,6 @@ message SSHServerState {
repeated SSHSessionInfo sessions = 2;
}
// VNCServerState contains the latest state of the VNC server
message VNCServerState {
bool enabled = 1;
}
// FullStatus contains the full state held by the Status instance
message FullStatus {
ManagementState managementState = 1;
@@ -420,7 +408,6 @@ message FullStatus {
bool lazyConnectionEnabled = 9;
SSHServerState sshServerState = 10;
VNCServerState vncServerState = 11;
}
// Networks
@@ -690,9 +677,6 @@ message SetConfigRequest {
optional bool enableSSHRemotePortForwarding = 32;
optional bool disableSSHAuth = 33;
optional int32 sshJWTCacheTTL = 34;
optional bool serverVNCAllowed = 36;
optional bool disableVNCAuth = 37;
}
message SetConfigResponse{}

View File

@@ -369,7 +369,6 @@ func (s *Server) SetConfig(callerCtx context.Context, msg *proto.SetConfigReques
config.RosenpassPermissive = msg.RosenpassPermissive
config.DisableAutoConnect = msg.DisableAutoConnect
config.ServerSSHAllowed = msg.ServerSSHAllowed
config.ServerVNCAllowed = msg.ServerVNCAllowed
config.NetworkMonitor = msg.NetworkMonitor
config.DisableClientRoutes = msg.DisableClientRoutes
config.DisableServerRoutes = msg.DisableServerRoutes
@@ -386,9 +385,6 @@ func (s *Server) SetConfig(callerCtx context.Context, msg *proto.SetConfigReques
if msg.DisableSSHAuth != nil {
config.DisableSSHAuth = msg.DisableSSHAuth
}
if msg.DisableVNCAuth != nil {
config.DisableVNCAuth = msg.DisableVNCAuth
}
if msg.SshJWTCacheTTL != nil {
ttl := int(*msg.SshJWTCacheTTL)
config.SSHJWTCacheTTL = &ttl
@@ -1127,7 +1123,6 @@ func (s *Server) Status(
pbFullStatus := fullStatus.ToProto()
pbFullStatus.Events = s.statusRecorder.GetEventHistory()
pbFullStatus.SshServerState = s.getSSHServerState()
pbFullStatus.VncServerState = s.getVNCServerState()
statusResponse.FullStatus = pbFullStatus
}
@@ -1167,26 +1162,6 @@ func (s *Server) getSSHServerState() *proto.SSHServerState {
return sshServerState
}
// getVNCServerState retrieves the current VNC server state.
func (s *Server) getVNCServerState() *proto.VNCServerState {
s.mutex.Lock()
connectClient := s.connectClient
s.mutex.Unlock()
if connectClient == nil {
return nil
}
engine := connectClient.Engine()
if engine == nil {
return nil
}
return &proto.VNCServerState{
Enabled: engine.GetVNCServerStatus(),
}
}
// GetPeerSSHHostKey retrieves SSH host key for a specific peer
func (s *Server) GetPeerSSHHostKey(
ctx context.Context,
@@ -1528,11 +1503,6 @@ func (s *Server) GetConfig(ctx context.Context, req *proto.GetConfigRequest) (*p
disableSSHAuth = *cfg.DisableSSHAuth
}
disableVNCAuth := false
if cfg.DisableVNCAuth != nil {
disableVNCAuth = *cfg.DisableVNCAuth
}
sshJWTCacheTTL := int32(0)
if cfg.SSHJWTCacheTTL != nil {
sshJWTCacheTTL = int32(*cfg.SSHJWTCacheTTL)
@@ -1547,7 +1517,6 @@ func (s *Server) GetConfig(ctx context.Context, req *proto.GetConfigRequest) (*p
Mtu: int64(cfg.MTU),
DisableAutoConnect: cfg.DisableAutoConnect,
ServerSSHAllowed: *cfg.ServerSSHAllowed,
ServerVNCAllowed: cfg.ServerVNCAllowed != nil && *cfg.ServerVNCAllowed,
RosenpassEnabled: cfg.RosenpassEnabled,
RosenpassPermissive: cfg.RosenpassPermissive,
LazyConnectionEnabled: cfg.LazyConnectionEnabled,
@@ -1563,7 +1532,6 @@ func (s *Server) GetConfig(ctx context.Context, req *proto.GetConfigRequest) (*p
EnableSSHLocalPortForwarding: enableSSHLocalPortForwarding,
EnableSSHRemotePortForwarding: enableSSHRemotePortForwarding,
DisableSSHAuth: disableSSHAuth,
DisableVNCAuth: disableVNCAuth,
SshJWTCacheTTL: sshJWTCacheTTL,
}, nil
}

View File

@@ -58,8 +58,6 @@ func TestSetConfig_AllFieldsSaved(t *testing.T) {
rosenpassEnabled := true
rosenpassPermissive := true
serverSSHAllowed := true
serverVNCAllowed := true
disableVNCAuth := true
interfaceName := "utun100"
wireguardPort := int64(51820)
preSharedKey := "test-psk"
@@ -84,8 +82,6 @@ func TestSetConfig_AllFieldsSaved(t *testing.T) {
RosenpassEnabled: &rosenpassEnabled,
RosenpassPermissive: &rosenpassPermissive,
ServerSSHAllowed: &serverSSHAllowed,
ServerVNCAllowed: &serverVNCAllowed,
DisableVNCAuth: &disableVNCAuth,
InterfaceName: &interfaceName,
WireguardPort: &wireguardPort,
OptionalPreSharedKey: &preSharedKey,
@@ -129,10 +125,6 @@ func TestSetConfig_AllFieldsSaved(t *testing.T) {
require.Equal(t, rosenpassPermissive, cfg.RosenpassPermissive)
require.NotNil(t, cfg.ServerSSHAllowed)
require.Equal(t, serverSSHAllowed, *cfg.ServerSSHAllowed)
require.NotNil(t, cfg.ServerVNCAllowed)
require.Equal(t, serverVNCAllowed, *cfg.ServerVNCAllowed)
require.NotNil(t, cfg.DisableVNCAuth)
require.Equal(t, disableVNCAuth, *cfg.DisableVNCAuth)
require.Equal(t, interfaceName, cfg.WgIface)
require.Equal(t, int(wireguardPort), cfg.WgPort)
require.Equal(t, preSharedKey, cfg.PreSharedKey)
@@ -184,8 +176,6 @@ func verifyAllFieldsCovered(t *testing.T, req *proto.SetConfigRequest) {
"RosenpassEnabled": true,
"RosenpassPermissive": true,
"ServerSSHAllowed": true,
"ServerVNCAllowed": true,
"DisableVNCAuth": true,
"InterfaceName": true,
"WireguardPort": true,
"OptionalPreSharedKey": true,
@@ -246,8 +236,6 @@ func TestCLIFlags_MappedToSetConfig(t *testing.T) {
"enable-rosenpass": "RosenpassEnabled",
"rosenpass-permissive": "RosenpassPermissive",
"allow-server-ssh": "ServerSSHAllowed",
"allow-server-vnc": "ServerVNCAllowed",
"disable-vnc-auth": "DisableVNCAuth",
"interface-name": "InterfaceName",
"wireguard-port": "WireguardPort",
"preshared-key": "OptionalPreSharedKey",

View File

@@ -200,8 +200,8 @@ func newLsaString(s string) lsaString {
}
}
// generateS4UUserToken creates a Windows token using S4U authentication.
// This is the same approach OpenSSH for Windows uses for public key authentication.
// generateS4UUserToken creates a Windows token using S4U authentication
// This is the exact approach OpenSSH for Windows uses for public key authentication
func generateS4UUserToken(logger *log.Entry, username, domain string) (windows.Handle, error) {
userCpn := buildUserCpn(username, domain)

View File

@@ -507,7 +507,27 @@ func (s *Server) checkTokenAge(token *gojwt.Token, jwtConfig *JWTConfig) error {
maxTokenAge = DefaultJWTMaxTokenAge
}
return jwt.CheckTokenAge(token, time.Duration(maxTokenAge)*time.Second)
claims, ok := token.Claims.(gojwt.MapClaims)
if !ok {
userID := extractUserID(token)
return fmt.Errorf("token has invalid claims format (user=%s)", userID)
}
iat, ok := claims["iat"].(float64)
if !ok {
userID := extractUserID(token)
return fmt.Errorf("token missing iat claim (user=%s)", userID)
}
issuedAt := time.Unix(int64(iat), 0)
tokenAge := time.Since(issuedAt)
maxAge := time.Duration(maxTokenAge) * time.Second
if tokenAge > maxAge {
userID := getUserIDFromClaims(claims)
return fmt.Errorf("token expired for user=%s: age=%v, max=%v", userID, tokenAge, maxAge)
}
return nil
}
func (s *Server) extractAndValidateUser(token *gojwt.Token) (*auth.UserAuth, error) {
@@ -538,7 +558,27 @@ func (s *Server) hasSSHAccess(userAuth *auth.UserAuth) bool {
}
func extractUserID(token *gojwt.Token) string {
return jwt.UserIDFromToken(token)
if token == nil {
return "unknown"
}
claims, ok := token.Claims.(gojwt.MapClaims)
if !ok {
return "unknown"
}
return getUserIDFromClaims(claims)
}
func getUserIDFromClaims(claims gojwt.MapClaims) string {
if sub, ok := claims["sub"].(string); ok && sub != "" {
return sub
}
if userID, ok := claims["user_id"].(string); ok && userID != "" {
return userID
}
if email, ok := claims["email"].(string); ok && email != "" {
return email
}
return "unknown"
}
func (s *Server) parseTokenWithoutValidation(tokenString string) (map[string]interface{}, error) {

View File

@@ -130,10 +130,6 @@ type SSHServerStateOutput struct {
Sessions []SSHSessionOutput `json:"sessions" yaml:"sessions"`
}
type VNCServerStateOutput struct {
Enabled bool `json:"enabled" yaml:"enabled"`
}
type OutputOverview struct {
Peers PeersStateOutput `json:"peers" yaml:"peers"`
CliVersion string `json:"cliVersion" yaml:"cliVersion"`
@@ -155,7 +151,6 @@ type OutputOverview struct {
LazyConnectionEnabled bool `json:"lazyConnectionEnabled" yaml:"lazyConnectionEnabled"`
ProfileName string `json:"profileName" yaml:"profileName"`
SSHServerState SSHServerStateOutput `json:"sshServer" yaml:"sshServer"`
VNCServerState VNCServerStateOutput `json:"vncServer" yaml:"vncServer"`
}
// ConvertToStatusOutputOverview converts protobuf status to the output overview.
@@ -176,9 +171,6 @@ func ConvertToStatusOutputOverview(pbFullStatus *proto.FullStatus, opts ConvertO
relayOverview := mapRelays(pbFullStatus.GetRelays())
sshServerOverview := mapSSHServer(pbFullStatus.GetSshServerState())
vncServerOverview := VNCServerStateOutput{
Enabled: pbFullStatus.GetVncServerState().GetEnabled(),
}
peersOverview := mapPeers(pbFullStatus.GetPeers(), opts.StatusFilter, opts.PrefixNamesFilter, opts.PrefixNamesFilterMap, opts.IPsFilter, opts.ConnectionTypeFilter)
overview := OutputOverview{
@@ -202,7 +194,6 @@ func ConvertToStatusOutputOverview(pbFullStatus *proto.FullStatus, opts ConvertO
LazyConnectionEnabled: pbFullStatus.GetLazyConnectionEnabled(),
ProfileName: opts.ProfileName,
SSHServerState: sshServerOverview,
VNCServerState: vncServerOverview,
}
if opts.Anonymize {
@@ -533,11 +524,6 @@ func (o *OutputOverview) GeneralSummary(showURL bool, showRelays bool, showNameS
}
}
vncServerStatus := "Disabled"
if o.VNCServerState.Enabled {
vncServerStatus = "Enabled"
}
peersCountString := fmt.Sprintf("%d/%d Connected", o.Peers.Connected, o.Peers.Total)
var forwardingRulesString string
@@ -567,7 +553,6 @@ func (o *OutputOverview) GeneralSummary(showURL bool, showRelays bool, showNameS
"Quantum resistance: %s\n"+
"Lazy connection: %s\n"+
"SSH Server: %s\n"+
"VNC Server: %s\n"+
"Networks: %s\n"+
"%s"+
"Peers count: %s\n",
@@ -585,7 +570,6 @@ func (o *OutputOverview) GeneralSummary(showURL bool, showRelays bool, showNameS
rosenpassEnabledStatus,
lazyConnectionEnabledStatus,
sshServerStatus,
vncServerStatus,
networks,
forwardingRulesString,
peersCountString,

View File

@@ -398,9 +398,6 @@ func TestParsingToJSON(t *testing.T) {
"sshServer":{
"enabled":false,
"sessions":[]
},
"vncServer":{
"enabled":false
}
}`
// @formatter:on
@@ -508,8 +505,6 @@ profileName: ""
sshServer:
enabled: false
sessions: []
vncServer:
enabled: false
`
assert.Equal(t, expectedYAML, yaml)
@@ -577,7 +572,6 @@ Interface type: Kernel
Quantum resistance: false
Lazy connection: false
SSH Server: Disabled
VNC Server: Disabled
Networks: 10.10.0.0/24
Peers count: 2/2 Connected
`, lastConnectionUpdate1, lastHandshake1, lastConnectionUpdate2, lastHandshake2, runtime.GOOS, runtime.GOARCH, overview.CliVersion)
@@ -602,7 +596,6 @@ Interface type: Kernel
Quantum resistance: false
Lazy connection: false
SSH Server: Disabled
VNC Server: Disabled
Networks: 10.10.0.0/24
Peers count: 2/2 Connected
`

View File

@@ -62,7 +62,6 @@ type Info struct {
RosenpassEnabled bool
RosenpassPermissive bool
ServerSSHAllowed bool
ServerVNCAllowed bool
DisableClientRoutes bool
DisableServerRoutes bool
@@ -78,27 +77,21 @@ type Info struct {
EnableSSHLocalPortForwarding bool
EnableSSHRemotePortForwarding bool
DisableSSHAuth bool
DisableVNCAuth bool
}
func (i *Info) SetFlags(
rosenpassEnabled, rosenpassPermissive bool,
serverSSHAllowed *bool,
serverVNCAllowed *bool,
disableClientRoutes, disableServerRoutes,
disableDNS, disableFirewall, blockLANAccess, blockInbound, lazyConnectionEnabled bool,
enableSSHRoot, enableSSHSFTP, enableSSHLocalPortForwarding, enableSSHRemotePortForwarding *bool,
disableSSHAuth *bool,
disableVNCAuth *bool,
) {
i.RosenpassEnabled = rosenpassEnabled
i.RosenpassPermissive = rosenpassPermissive
if serverSSHAllowed != nil {
i.ServerSSHAllowed = *serverSSHAllowed
}
if serverVNCAllowed != nil {
i.ServerVNCAllowed = *serverVNCAllowed
}
i.DisableClientRoutes = disableClientRoutes
i.DisableServerRoutes = disableServerRoutes
@@ -124,9 +117,6 @@ func (i *Info) SetFlags(
if disableSSHAuth != nil {
i.DisableSSHAuth = *disableSSHAuth
}
if disableVNCAuth != nil {
i.DisableVNCAuth = *disableVNCAuth
}
}
// extractUserAgent extracts Netbird's agent (client) name and version from the outgoing context

View File

@@ -1,474 +0,0 @@
//go:build windows
package server
import (
crand "crypto/rand"
"encoding/hex"
"encoding/json"
"fmt"
"io"
"net"
"os"
"sync"
"time"
"unsafe"
log "github.com/sirupsen/logrus"
"golang.org/x/sys/windows"
)
const (
agentPort = "15900"
// agentTokenLen is the length of the random authentication token
// used to verify that connections to the agent come from the service.
agentTokenLen = 32
stillActive = 259
tokenPrimary = 1
securityImpersonation = 2
tokenSessionID = 12
createUnicodeEnvironment = 0x00000400
createNoWindow = 0x08000000
)
var (
kernel32 = windows.NewLazySystemDLL("kernel32.dll")
advapi32 = windows.NewLazySystemDLL("advapi32.dll")
userenv = windows.NewLazySystemDLL("userenv.dll")
procWTSGetActiveConsoleSessionId = kernel32.NewProc("WTSGetActiveConsoleSessionId")
procSetTokenInformation = advapi32.NewProc("SetTokenInformation")
procCreateEnvironmentBlock = userenv.NewProc("CreateEnvironmentBlock")
procDestroyEnvironmentBlock = userenv.NewProc("DestroyEnvironmentBlock")
wtsapi32 = windows.NewLazySystemDLL("wtsapi32.dll")
procWTSEnumerateSessionsW = wtsapi32.NewProc("WTSEnumerateSessionsW")
procWTSFreeMemory = wtsapi32.NewProc("WTSFreeMemory")
)
// GetCurrentSessionID returns the session ID of the current process.
func GetCurrentSessionID() uint32 {
var token windows.Token
if err := windows.OpenProcessToken(windows.CurrentProcess(),
windows.TOKEN_QUERY, &token); err != nil {
return 0
}
defer token.Close()
var id uint32
var ret uint32
_ = windows.GetTokenInformation(token, windows.TokenSessionId,
(*byte)(unsafe.Pointer(&id)), 4, &ret)
return id
}
func getConsoleSessionID() uint32 {
r, _, _ := procWTSGetActiveConsoleSessionId.Call()
return uint32(r)
}
const (
wtsActive = 0
wtsConnected = 1
wtsDisconnected = 4
)
type wtsSessionInfo struct {
SessionID uint32
WinStationName [66]byte // actually *uint16, but we just need the struct size
State uint32
}
// getActiveSessionID returns the session ID of the best session to attach to.
// Prefers an active (logged-in, interactive) session over the console session.
// This avoids kicking out an RDP user when the console is at the login screen.
func getActiveSessionID() uint32 {
var sessionInfo uintptr
var count uint32
r, _, _ := procWTSEnumerateSessionsW.Call(
0, // WTS_CURRENT_SERVER_HANDLE
0, // reserved
1, // version
uintptr(unsafe.Pointer(&sessionInfo)),
uintptr(unsafe.Pointer(&count)),
)
if r == 0 || count == 0 {
return getConsoleSessionID()
}
defer procWTSFreeMemory.Call(sessionInfo)
type wtsSession struct {
SessionID uint32
Station *uint16
State uint32
}
sessions := unsafe.Slice((*wtsSession)(unsafe.Pointer(sessionInfo)), count)
// Find the first active session (not session 0, which is the services session).
var bestID uint32
found := false
for _, s := range sessions {
if s.SessionID == 0 {
continue
}
if s.State == wtsActive {
bestID = s.SessionID
found = true
break
}
}
if !found {
return getConsoleSessionID()
}
return bestID
}
// getSystemTokenForSession duplicates the current SYSTEM token and sets its
// session ID so the spawned process runs in the target session. Using a SYSTEM
// token gives access to both Default and Winlogon desktops plus UIPI bypass.
func getSystemTokenForSession(sessionID uint32) (windows.Token, error) {
var cur windows.Token
if err := windows.OpenProcessToken(windows.CurrentProcess(),
windows.MAXIMUM_ALLOWED, &cur); err != nil {
return 0, fmt.Errorf("OpenProcessToken: %w", err)
}
defer cur.Close()
var dup windows.Token
if err := windows.DuplicateTokenEx(cur, windows.MAXIMUM_ALLOWED, nil,
securityImpersonation, tokenPrimary, &dup); err != nil {
return 0, fmt.Errorf("DuplicateTokenEx: %w", err)
}
sid := sessionID
r, _, err := procSetTokenInformation.Call(
uintptr(dup),
uintptr(tokenSessionID),
uintptr(unsafe.Pointer(&sid)),
unsafe.Sizeof(sid),
)
if r == 0 {
dup.Close()
return 0, fmt.Errorf("SetTokenInformation(SessionId=%d): %w", sessionID, err)
}
return dup, nil
}
const agentTokenEnvVar = "NB_VNC_AGENT_TOKEN"
// injectEnvVar appends a KEY=VALUE entry to a Unicode environment block.
// The block is a sequence of null-terminated UTF-16 strings, terminated by
// an extra null. Returns a new block pointer with the entry added.
func injectEnvVar(envBlock uintptr, key, value string) uintptr {
entry := key + "=" + value
// Walk the existing block to find its total length.
ptr := (*uint16)(unsafe.Pointer(envBlock))
var totalChars int
for {
ch := *(*uint16)(unsafe.Pointer(uintptr(unsafe.Pointer(ptr)) + uintptr(totalChars)*2))
if ch == 0 {
// Check for double-null terminator.
next := *(*uint16)(unsafe.Pointer(uintptr(unsafe.Pointer(ptr)) + uintptr(totalChars+1)*2))
totalChars++
if next == 0 {
// End of block (don't count the final null yet, we'll rebuild).
break
}
} else {
totalChars++
}
}
entryUTF16, _ := windows.UTF16FromString(entry)
// New block: existing entries + new entry (null-terminated) + final null.
newLen := totalChars + len(entryUTF16) + 1
newBlock := make([]uint16, newLen)
// Copy existing entries (up to but not including the final null).
for i := range totalChars {
newBlock[i] = *(*uint16)(unsafe.Pointer(uintptr(unsafe.Pointer(ptr)) + uintptr(i)*2))
}
copy(newBlock[totalChars:], entryUTF16)
newBlock[newLen-1] = 0 // final null terminator
return uintptr(unsafe.Pointer(&newBlock[0]))
}
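The double-null-terminated UTF-16 layout that `injectEnvVar` walks can be exercised in portable Go without touching Windows memory. A minimal sketch under that layout assumption — `buildEnvBlock` and `appendEnvEntry` are illustrative names of mine, not functions from this file:

```go
package main

import (
	"fmt"
	"unicode/utf16"
)

// buildEnvBlock encodes KEY=VALUE entries as NUL-terminated UTF-16
// strings followed by one extra NUL, the layout CreateProcess expects.
func buildEnvBlock(entries []string) []uint16 {
	var block []uint16
	for _, e := range entries {
		block = append(block, utf16.Encode([]rune(e))...)
		block = append(block, 0)
	}
	return append(block, 0)
}

// appendEnvEntry mirrors injectEnvVar's walk: scan to the double-NUL
// terminator, then rebuild the block with the new entry appended.
func appendEnvEntry(block []uint16, key, value string) []uint16 {
	n := 0
	for n+1 < len(block) && !(block[n] == 0 && block[n+1] == 0) {
		n++
	}
	n++ // keep the NUL that terminates the last existing entry
	out := append([]uint16{}, block[:n]...)
	out = append(out, utf16.Encode([]rune(key+"="+value))...)
	return append(out, 0, 0)
}

func main() {
	b := appendEnvEntry(buildEnvBlock([]string{"A=1"}), "NB_VNC_AGENT_TOKEN", "x")
	fmt.Println(len(b)) // original entry + new entry + terminators
}
```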
func spawnAgentInSession(sessionID uint32, port string, authToken string) (windows.Handle, error) {
token, err := getSystemTokenForSession(sessionID)
if err != nil {
return 0, fmt.Errorf("get SYSTEM token for session %d: %w", sessionID, err)
}
defer token.Close()
var envBlock uintptr
r, _, _ := procCreateEnvironmentBlock.Call(
uintptr(unsafe.Pointer(&envBlock)),
uintptr(token),
0,
)
if r != 0 {
defer procDestroyEnvironmentBlock.Call(envBlock)
}
// Inject the auth token into the environment block so it doesn't appear
// in the process command line (visible via tasklist/wmic).
if r != 0 {
envBlock = injectEnvVar(envBlock, agentTokenEnvVar, authToken)
}
exePath, err := os.Executable()
if err != nil {
return 0, fmt.Errorf("get executable path: %w", err)
}
cmdLine := fmt.Sprintf(`"%s" vnc-agent --port %s`, exePath, port)
cmdLineW, err := windows.UTF16PtrFromString(cmdLine)
if err != nil {
return 0, fmt.Errorf("UTF16 cmdline: %w", err)
}
// Create an inheritable pipe for the agent's stderr so we can relog
// its output in the service process.
var sa windows.SecurityAttributes
sa.Length = uint32(unsafe.Sizeof(sa))
sa.InheritHandle = 1
var stderrRead, stderrWrite windows.Handle
if err := windows.CreatePipe(&stderrRead, &stderrWrite, &sa, 0); err != nil {
return 0, fmt.Errorf("create stderr pipe: %w", err)
}
// The read end must NOT be inherited by the child.
windows.SetHandleInformation(stderrRead, windows.HANDLE_FLAG_INHERIT, 0)
desktop, _ := windows.UTF16PtrFromString(`WinSta0\Default`)
si := windows.StartupInfo{
Cb: uint32(unsafe.Sizeof(windows.StartupInfo{})),
Desktop: desktop,
Flags: windows.STARTF_USESHOWWINDOW | windows.STARTF_USESTDHANDLES,
ShowWindow: 0,
StdErr: stderrWrite,
StdOutput: stderrWrite,
}
var pi windows.ProcessInformation
var envPtr *uint16
if envBlock != 0 {
envPtr = (*uint16)(unsafe.Pointer(envBlock))
}
err = windows.CreateProcessAsUser(
token, nil, cmdLineW,
nil, nil, true, // inheritHandles=true for the pipe
createUnicodeEnvironment|createNoWindow,
envPtr, nil, &si, &pi,
)
// Close the write end in the parent so reads will get EOF when the child exits.
windows.CloseHandle(stderrWrite)
if err != nil {
windows.CloseHandle(stderrRead)
return 0, fmt.Errorf("CreateProcessAsUser: %w", err)
}
windows.CloseHandle(pi.Thread)
// Relog agent output in the service with a [vnc-agent] prefix.
go relogAgentOutput(stderrRead)
log.Infof("spawned agent PID=%d in session %d on port %s", pi.ProcessId, sessionID, port)
return pi.Process, nil
}
// sessionManager monitors the active console session and ensures a VNC agent
// process is running in it. When the session changes (e.g., user switch, RDP
// connect/disconnect), it kills the old agent and spawns a new one.
type sessionManager struct {
port string
mu sync.Mutex
agentProc windows.Handle
sessionID uint32
authToken string
done chan struct{}
}
func newSessionManager(port string) *sessionManager {
return &sessionManager{port: port, sessionID: ^uint32(0), done: make(chan struct{})}
}
// generateAuthToken creates a new random hex token for agent authentication.
func generateAuthToken() string {
b := make([]byte, agentTokenLen)
if _, err := crand.Read(b); err != nil {
log.Warnf("generate agent auth token: %v", err)
return ""
}
return hex.EncodeToString(b)
}
// AuthToken returns the current agent authentication token.
func (m *sessionManager) AuthToken() string {
m.mu.Lock()
defer m.mu.Unlock()
return m.authToken
}
// Stop signals the session manager to exit its polling loop.
func (m *sessionManager) Stop() {
select {
case <-m.done:
default:
close(m.done)
}
}
func (m *sessionManager) run() {
ticker := time.NewTicker(2 * time.Second)
defer ticker.Stop()
for {
sid := getActiveSessionID()
m.mu.Lock()
if sid != m.sessionID {
log.Infof("active session changed: %d -> %d", m.sessionID, sid)
m.killAgent()
m.sessionID = sid
}
if m.agentProc != 0 {
var code uint32
_ = windows.GetExitCodeProcess(m.agentProc, &code)
if code != stillActive {
log.Infof("agent exited (code=%d), respawning", code)
windows.CloseHandle(m.agentProc)
m.agentProc = 0
}
}
if m.agentProc == 0 && sid != 0xFFFFFFFF {
m.authToken = generateAuthToken()
h, err := spawnAgentInSession(sid, m.port, m.authToken)
if err != nil {
log.Warnf("spawn agent in session %d: %v", sid, err)
m.authToken = ""
} else {
m.agentProc = h
}
}
m.mu.Unlock()
select {
case <-m.done:
m.mu.Lock()
m.killAgent()
m.mu.Unlock()
return
case <-ticker.C:
}
}
}
func (m *sessionManager) killAgent() {
if m.agentProc != 0 {
_ = windows.TerminateProcess(m.agentProc, 0)
windows.CloseHandle(m.agentProc)
m.agentProc = 0
log.Info("killed old agent")
}
}
// relogAgentOutput reads JSON log lines from the agent's stderr pipe and
// relogs them at the correct level with the service's formatter.
func relogAgentOutput(pipe windows.Handle) {
defer windows.CloseHandle(pipe)
f := os.NewFile(uintptr(pipe), "vnc-agent-stderr")
defer f.Close()
entry := log.WithField("component", "vnc-agent")
dec := json.NewDecoder(f)
for dec.More() {
var m map[string]any
if err := dec.Decode(&m); err != nil {
break
}
msg, _ := m["msg"].(string)
if msg == "" {
continue
}
// Forward extra fields from the agent (skip standard logrus fields).
// Remap "caller" to "source" so it doesn't conflict with logrus internals
// but still shows the original file/line from the agent process.
fields := make(log.Fields)
for k, v := range m {
switch k {
case "msg", "level", "time", "func":
continue
case "caller":
fields["source"] = v
default:
fields[k] = v
}
}
e := entry.WithFields(fields)
switch m["level"] {
case "error":
e.Error(msg)
case "warning":
e.Warn(msg)
case "debug":
e.Debug(msg)
case "trace":
e.Trace(msg)
default:
e.Info(msg)
}
}
}
// proxyToAgent connects to the agent, sends the auth token, then proxies
// the VNC client connection bidirectionally.
func proxyToAgent(client net.Conn, port string, authToken string) {
defer client.Close()
addr := "127.0.0.1:" + port
var agentConn net.Conn
var err error
for range 50 {
agentConn, err = net.DialTimeout("tcp", addr, time.Second)
if err == nil {
break
}
time.Sleep(200 * time.Millisecond)
}
if err != nil {
log.Warnf("proxy cannot reach agent at %s: %v", addr, err)
return
}
defer agentConn.Close()
// Send the auth token so the agent can verify this connection
// comes from the trusted service process.
tokenBytes, _ := hex.DecodeString(authToken)
if _, err := agentConn.Write(tokenBytes); err != nil {
log.Warnf("send auth token to agent: %v", err)
return
}
log.Debugf("proxy connected to agent, starting bidirectional copy")
done := make(chan struct{}, 2)
cp := func(label string, dst, src net.Conn) {
n, err := io.Copy(dst, src)
log.Debugf("proxy %s: %d bytes, err=%v", label, n, err)
done <- struct{}{}
}
go cp("client→agent", agentConn, client)
go cp("agent→client", client, agentConn)
<-done
}


@@ -1,486 +0,0 @@
//go:build darwin && !ios
package server
import (
"errors"
"fmt"
"hash/maphash"
"image"
"os"
"runtime"
"strconv"
"sync"
"time"
"unsafe"
"github.com/ebitengine/purego"
log "github.com/sirupsen/logrus"
)
var darwinCaptureOnce sync.Once
var (
cgMainDisplayID func() uint32
cgDisplayPixelsWide func(uint32) uintptr
cgDisplayPixelsHigh func(uint32) uintptr
cgDisplayCreateImage func(uint32) uintptr
cgImageGetWidth func(uintptr) uintptr
cgImageGetHeight func(uintptr) uintptr
cgImageGetBytesPerRow func(uintptr) uintptr
cgImageGetBitsPerPixel func(uintptr) uintptr
cgImageGetDataProvider func(uintptr) uintptr
cgDataProviderCopyData func(uintptr) uintptr
cgImageRelease func(uintptr)
cfDataGetLength func(uintptr) int64
cfDataGetBytePtr func(uintptr) uintptr
cfRelease func(uintptr)
cgPreflightScreenCaptureAccess func() bool
cgRequestScreenCaptureAccess func() bool
darwinCaptureReady bool
)
func initDarwinCapture() {
darwinCaptureOnce.Do(func() {
cg, err := purego.Dlopen("/System/Library/Frameworks/CoreGraphics.framework/CoreGraphics", purego.RTLD_NOW|purego.RTLD_GLOBAL)
if err != nil {
log.Debugf("load CoreGraphics: %v", err)
return
}
cf, err := purego.Dlopen("/System/Library/Frameworks/CoreFoundation.framework/CoreFoundation", purego.RTLD_NOW|purego.RTLD_GLOBAL)
if err != nil {
log.Debugf("load CoreFoundation: %v", err)
return
}
purego.RegisterLibFunc(&cgMainDisplayID, cg, "CGMainDisplayID")
purego.RegisterLibFunc(&cgDisplayPixelsWide, cg, "CGDisplayPixelsWide")
purego.RegisterLibFunc(&cgDisplayPixelsHigh, cg, "CGDisplayPixelsHigh")
purego.RegisterLibFunc(&cgDisplayCreateImage, cg, "CGDisplayCreateImage")
purego.RegisterLibFunc(&cgImageGetWidth, cg, "CGImageGetWidth")
purego.RegisterLibFunc(&cgImageGetHeight, cg, "CGImageGetHeight")
purego.RegisterLibFunc(&cgImageGetBytesPerRow, cg, "CGImageGetBytesPerRow")
purego.RegisterLibFunc(&cgImageGetBitsPerPixel, cg, "CGImageGetBitsPerPixel")
purego.RegisterLibFunc(&cgImageGetDataProvider, cg, "CGImageGetDataProvider")
purego.RegisterLibFunc(&cgDataProviderCopyData, cg, "CGDataProviderCopyData")
purego.RegisterLibFunc(&cgImageRelease, cg, "CGImageRelease")
purego.RegisterLibFunc(&cfDataGetLength, cf, "CFDataGetLength")
purego.RegisterLibFunc(&cfDataGetBytePtr, cf, "CFDataGetBytePtr")
purego.RegisterLibFunc(&cfRelease, cf, "CFRelease")
// Screen capture permission APIs (macOS 11+). Might not exist on older versions.
if sym, err := purego.Dlsym(cg, "CGPreflightScreenCaptureAccess"); err == nil {
purego.RegisterFunc(&cgPreflightScreenCaptureAccess, sym)
}
if sym, err := purego.Dlsym(cg, "CGRequestScreenCaptureAccess"); err == nil {
purego.RegisterFunc(&cgRequestScreenCaptureAccess, sym)
}
darwinCaptureReady = true
})
}
// errFrameUnchanged signals that the raw capture bytes matched the previous
// frame, so the caller can skip the expensive BGRA to RGBA conversion.
var errFrameUnchanged = errors.New("frame unchanged")
// CGCapturer captures the macOS main display using Core Graphics.
type CGCapturer struct {
displayID uint32
w, h int
// downscale is 1 for pixel-perfect, 2 for Retina 2:1 box-filter downscale.
downscale int
hashSeed maphash.Seed
lastHash uint64
hasHash bool
}
// NewCGCapturer creates a screen capturer for the main display.
func NewCGCapturer() (*CGCapturer, error) {
initDarwinCapture()
if !darwinCaptureReady {
return nil, fmt.Errorf("CoreGraphics not available")
}
// Request Screen Recording permission (shows system dialog on macOS 11+).
if cgPreflightScreenCaptureAccess != nil && !cgPreflightScreenCaptureAccess() {
if cgRequestScreenCaptureAccess != nil {
cgRequestScreenCaptureAccess()
}
openPrivacyPane("Privacy_ScreenCapture")
log.Warn("Screen Recording permission not granted. " +
"Opened System Settings > Privacy & Security > Screen Recording; enable netbird and restart.")
}
displayID := cgMainDisplayID()
c := &CGCapturer{displayID: displayID, downscale: 1, hashSeed: maphash.MakeSeed()}
// Probe actual pixel dimensions via a test capture. CGDisplayPixelsWide/High
// return logical points on Retina, but CGDisplayCreateImage produces native
// pixels (often 2x), so probing the image is the only reliable source.
img, err := c.Capture()
if err != nil {
return nil, fmt.Errorf("probe capture: %w", err)
}
nativeW := img.Rect.Dx()
nativeH := img.Rect.Dy()
c.hasHash = false
if nativeW == 0 || nativeH == 0 {
return nil, errors.New("display dimensions are zero")
}
logicalW := int(cgDisplayPixelsWide(displayID))
logicalH := int(cgDisplayPixelsHigh(displayID))
// Enable 2:1 downscale on Retina unless explicitly disabled. Cuts pixel
// count 4x, shrinking convert, diff, and wire data proportionally.
if !retinaDownscaleDisabled() && nativeW >= 2*logicalW && nativeH >= 2*logicalH && nativeW%2 == 0 && nativeH%2 == 0 {
c.downscale = 2
}
c.w = nativeW / c.downscale
c.h = nativeH / c.downscale
log.Infof("macOS capturer ready: %dx%d (native %dx%d, logical %dx%d, downscale=%d, display=%d)",
c.w, c.h, nativeW, nativeH, logicalW, logicalH, c.downscale, displayID)
return c, nil
}
func retinaDownscaleDisabled() bool {
v := os.Getenv(EnvVNCDisableDownscale)
if v == "" {
return false
}
disabled, err := strconv.ParseBool(v)
if err != nil {
log.Warnf("parse %s: %v", EnvVNCDisableDownscale, err)
return false
}
return disabled
}
// Width returns the screen width.
func (c *CGCapturer) Width() int { return c.w }
// Height returns the screen height.
func (c *CGCapturer) Height() int { return c.h }
// Capture returns the current screen as an RGBA image.
func (c *CGCapturer) Capture() (*image.RGBA, error) {
cgImage := cgDisplayCreateImage(c.displayID)
if cgImage == 0 {
return nil, fmt.Errorf("CGDisplayCreateImage returned nil (screen recording permission?)")
}
defer cgImageRelease(cgImage)
w := int(cgImageGetWidth(cgImage))
h := int(cgImageGetHeight(cgImage))
bytesPerRow := int(cgImageGetBytesPerRow(cgImage))
bpp := int(cgImageGetBitsPerPixel(cgImage))
provider := cgImageGetDataProvider(cgImage)
if provider == 0 {
return nil, fmt.Errorf("CGImageGetDataProvider returned nil")
}
cfData := cgDataProviderCopyData(provider)
if cfData == 0 {
return nil, fmt.Errorf("CGDataProviderCopyData returned nil")
}
defer cfRelease(cfData)
dataLen := int(cfDataGetLength(cfData))
dataPtr := cfDataGetBytePtr(cfData)
if dataPtr == 0 || dataLen == 0 {
return nil, fmt.Errorf("empty image data")
}
src := unsafe.Slice((*byte)(unsafe.Pointer(dataPtr)), dataLen)
hash := maphash.Bytes(c.hashSeed, src)
if c.hasHash && hash == c.lastHash {
return nil, errFrameUnchanged
}
c.lastHash = hash
c.hasHash = true
ds := c.downscale
if ds < 1 {
ds = 1
}
outW := w / ds
outH := h / ds
img := image.NewRGBA(image.Rect(0, 0, outW, outH))
bytesPerPixel := bpp / 8
if bytesPerPixel == 4 && ds == 1 {
convertBGRAToRGBA(img.Pix, img.Stride, src, bytesPerRow, w, h)
} else if bytesPerPixel == 4 && ds == 2 {
convertBGRAToRGBADownscale2(img.Pix, img.Stride, src, bytesPerRow, outW, outH)
} else {
for row := 0; row < outH; row++ {
srcOff := row * ds * bytesPerRow
dstOff := row * img.Stride
for col := 0; col < outW; col++ {
si := srcOff + col*ds*bytesPerPixel
di := dstOff + col*4
img.Pix[di+0] = src[si+2]
img.Pix[di+1] = src[si+1]
img.Pix[di+2] = src[si+0]
img.Pix[di+3] = 0xff
}
}
}
return img, nil
}
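The hash-based change detection inside `Capture` is independent of CoreGraphics and worth isolating: hash the raw bytes, and if the hash matches the previous frame, skip the conversion entirely. A self-contained sketch (`frameDeduper` is my name for the extracted state, mirroring `lastHash`/`hasHash`):

```go
package main

import (
	"fmt"
	"hash/maphash"
)

// frameDeduper reports whether a frame's raw bytes differ from the
// previous call, mirroring the lastHash/hasHash logic in CGCapturer.
type frameDeduper struct {
	seed maphash.Seed
	last uint64
	has  bool
}

func newFrameDeduper() *frameDeduper {
	return &frameDeduper{seed: maphash.MakeSeed()}
}

func (d *frameDeduper) changed(frame []byte) bool {
	h := maphash.Bytes(d.seed, frame)
	if d.has && h == d.last {
		return false // same bytes as last frame: skip conversion
	}
	d.last, d.has = h, true
	return true
}

func main() {
	d := newFrameDeduper()
	fmt.Println(d.changed([]byte("frame1")), d.changed([]byte("frame1")), d.changed([]byte("frame2")))
	// true false true
}
```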
// convertBGRAToRGBADownscale2 averages every 2x2 BGRA block into one RGBA
// output pixel, parallelised across GOMAXPROCS cores. outW and outH are the
// destination dimensions (source is 2*outW by 2*outH).
func convertBGRAToRGBADownscale2(dst []byte, dstStride int, src []byte, srcStride, outW, outH int) {
workers := runtime.GOMAXPROCS(0)
if workers > outH {
workers = outH
}
if workers < 1 || outH < 32 {
workers = 1
}
convertRows := func(y0, y1 int) {
for row := y0; row < y1; row++ {
srcRow0 := 2 * row * srcStride
srcRow1 := srcRow0 + srcStride
dstOff := row * dstStride
for col := 0; col < outW; col++ {
s0 := srcRow0 + col*8
s1 := srcRow1 + col*8
b := (uint32(src[s0]) + uint32(src[s0+4]) + uint32(src[s1]) + uint32(src[s1+4])) >> 2
g := (uint32(src[s0+1]) + uint32(src[s0+5]) + uint32(src[s1+1]) + uint32(src[s1+5])) >> 2
r := (uint32(src[s0+2]) + uint32(src[s0+6]) + uint32(src[s1+2]) + uint32(src[s1+6])) >> 2
di := dstOff + col*4
dst[di+0] = byte(r)
dst[di+1] = byte(g)
dst[di+2] = byte(b)
dst[di+3] = 0xff
}
}
}
if workers == 1 {
convertRows(0, outH)
return
}
var wg sync.WaitGroup
chunk := (outH + workers - 1) / workers
for i := 0; i < workers; i++ {
y0 := i * chunk
y1 := y0 + chunk
if y1 > outH {
y1 = outH
}
if y0 >= y1 {
break
}
wg.Add(1)
go func(y0, y1 int) {
defer wg.Done()
convertRows(y0, y1)
}(y0, y1)
}
wg.Wait()
}
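Running one 2x2 block through the box filter makes the arithmetic concrete: each output channel is the mean of the four source pixels' channel, computed with a shift instead of a divide. A per-block sketch (`average2x2` is an extracted illustration, not a function in this file):

```go
package main

import "fmt"

// average2x2 reduces a 2x2 block of BGRA pixels (two rows of 8 bytes)
// to one RGBA pixel, the per-block step of convertBGRAToRGBADownscale2.
func average2x2(row0, row1 []byte) [4]byte {
	b := (uint32(row0[0]) + uint32(row0[4]) + uint32(row1[0]) + uint32(row1[4])) >> 2
	g := (uint32(row0[1]) + uint32(row0[5]) + uint32(row1[1]) + uint32(row1[5])) >> 2
	r := (uint32(row0[2]) + uint32(row0[6]) + uint32(row1[2]) + uint32(row1[6])) >> 2
	return [4]byte{byte(r), byte(g), byte(b), 0xff}
}

func main() {
	row0 := []byte{0, 0, 100, 255, 0, 0, 200, 255} // two BGRA pixels, red channel only
	row1 := []byte{0, 0, 100, 255, 0, 0, 200, 255}
	fmt.Println(average2x2(row0, row1)) // [150 0 0 255]
}
```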
// convertBGRAToRGBA swaps R/B channels using uint32 word operations, and
// parallelises across GOMAXPROCS cores for large images.
func convertBGRAToRGBA(dst []byte, dstStride int, src []byte, srcStride, w, h int) {
workers := runtime.GOMAXPROCS(0)
if workers > h {
workers = h
}
if workers < 1 || h < 64 {
workers = 1
}
convertRows := func(y0, y1 int) {
rowBytes := w * 4
for row := y0; row < y1; row++ {
dstRow := dst[row*dstStride : row*dstStride+rowBytes]
srcRow := src[row*srcStride : row*srcStride+rowBytes]
dstU := unsafe.Slice((*uint32)(unsafe.Pointer(&dstRow[0])), w)
srcU := unsafe.Slice((*uint32)(unsafe.Pointer(&srcRow[0])), w)
for i, p := range srcU {
dstU[i] = (p & 0xff00ff00) | ((p & 0x000000ff) << 16) | ((p & 0x00ff0000) >> 16) | 0xff000000
}
}
}
if workers == 1 {
convertRows(0, h)
return
}
var wg sync.WaitGroup
chunk := (h + workers - 1) / workers
for i := 0; i < workers; i++ {
y0 := i * chunk
y1 := y0 + chunk
if y1 > h {
y1 = h
}
if y0 >= y1 {
break
}
wg.Add(1)
go func(y0, y1 int) {
defer wg.Done()
convertRows(y0, y1)
}(y0, y1)
}
wg.Wait()
}
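The word-level channel swap is easiest to verify on a single pixel. Note the implicit assumption: the masks only line up this way on little-endian hosts, which covers the platforms this capturer targets. A one-pixel sketch:

```go
package main

import (
	"encoding/binary"
	"fmt"
)

// swizzlePixel applies the uint32 transform from convertBGRAToRGBA to a
// single BGRA pixel: keep the G and A lanes, swap the B and R bytes,
// then force alpha opaque. Little-endian byte order is assumed.
func swizzlePixel(bgra [4]byte) [4]byte {
	p := binary.LittleEndian.Uint32(bgra[:])
	p = (p & 0xff00ff00) | ((p & 0x000000ff) << 16) | ((p & 0x00ff0000) >> 16) | 0xff000000
	var rgba [4]byte
	binary.LittleEndian.PutUint32(rgba[:], p)
	return rgba
}

func main() {
	// B=0x10, G=0x20, R=0x30 in; R=0x30, G=0x20, B=0x10, A=0xff out.
	fmt.Println(swizzlePixel([4]byte{0x10, 0x20, 0x30, 0x00})) // [48 32 16 255]
}
```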
// MacPoller wraps CGCapturer in a continuous capture loop.
type MacPoller struct {
mu sync.Mutex
frame *image.RGBA
w, h int
done chan struct{}
// wake shortens the init-retry backoff when a client is trying to connect,
// so granting Screen Recording mid-session takes effect immediately.
wake chan struct{}
}
// NewMacPoller creates a capturer that continuously grabs the macOS display.
func NewMacPoller() *MacPoller {
p := &MacPoller{
done: make(chan struct{}),
wake: make(chan struct{}, 1),
}
go p.loop()
return p
}
// Wake pokes the init-retry loop so it doesn't wait out the full backoff
// before trying again. Safe to call from any goroutine; extra calls while a
// wake is pending are dropped.
func (p *MacPoller) Wake() {
select {
case p.wake <- struct{}{}:
default:
}
}
// Close stops the capture loop.
func (p *MacPoller) Close() {
select {
case <-p.done:
default:
close(p.done)
}
}
// Width returns the screen width.
func (p *MacPoller) Width() int {
p.mu.Lock()
defer p.mu.Unlock()
return p.w
}
// Height returns the screen height.
func (p *MacPoller) Height() int {
p.mu.Lock()
defer p.mu.Unlock()
return p.h
}
// Capture returns the most recent frame.
func (p *MacPoller) Capture() (*image.RGBA, error) {
p.mu.Lock()
img := p.frame
p.mu.Unlock()
if img != nil {
return img, nil
}
return nil, fmt.Errorf("no frame available yet")
}
func (p *MacPoller) loop() {
var capturer *CGCapturer
var initFails int
for {
select {
case <-p.done:
return
default:
}
if capturer == nil {
var err error
capturer, err = NewCGCapturer()
if err != nil {
initFails++
// Retry forever with backoff: the user may grant Screen
// Recording after the server started, and we need to pick it
// up whenever that happens.
delay := 2 * time.Second
if initFails > 15 {
delay = 30 * time.Second
} else if initFails > 5 {
delay = 10 * time.Second
}
if initFails == 1 || initFails%10 == 0 {
log.Warnf("macOS capturer: %v (attempt %d, retrying every %s)", err, initFails, delay)
} else {
log.Debugf("macOS capturer: %v (attempt %d)", err, initFails)
}
select {
case <-p.done:
return
case <-p.wake:
// Client is trying to connect, retry now.
case <-time.After(delay):
}
continue
}
initFails = 0
p.mu.Lock()
p.w, p.h = capturer.Width(), capturer.Height()
p.mu.Unlock()
}
img, err := capturer.Capture()
if errors.Is(err, errFrameUnchanged) {
select {
case <-p.done:
return
case <-time.After(33 * time.Millisecond):
}
continue
}
if err != nil {
log.Debugf("macOS capture: %v", err)
capturer = nil
select {
case <-p.done:
return
case <-time.After(500 * time.Millisecond):
}
continue
}
p.mu.Lock()
p.frame = img
p.mu.Unlock()
select {
case <-p.done:
return
case <-time.After(33 * time.Millisecond): // ~30 fps
}
}
}
var _ ScreenCapturer = (*MacPoller)(nil)


@@ -1,99 +0,0 @@
//go:build windows
package server
import (
"errors"
"fmt"
"image"
"github.com/kirides/go-d3d/d3d11"
"github.com/kirides/go-d3d/outputduplication"
)
// dxgiCapturer captures the desktop using DXGI Desktop Duplication.
// Provides GPU-accelerated capture with native dirty rect tracking.
// Only works from the interactive user session, not Session 0.
//
// Uses a double-buffer: DXGI writes into img, then we copy to the current
// output buffer and hand it out. Alternating between two output buffers
// avoids allocating a new image.RGBA per frame (~8MB at 1080p, 30fps).
type dxgiCapturer struct {
dup *outputduplication.OutputDuplicator
device *d3d11.ID3D11Device
ctx *d3d11.ID3D11DeviceContext
img *image.RGBA
out [2]*image.RGBA
outIdx int
width int
height int
}
func newDXGICapturer() (*dxgiCapturer, error) {
device, deviceCtx, err := d3d11.NewD3D11Device()
if err != nil {
return nil, fmt.Errorf("create D3D11 device: %w", err)
}
dup, err := outputduplication.NewIDXGIOutputDuplication(device, deviceCtx, 0)
if err != nil {
device.Release()
deviceCtx.Release()
return nil, fmt.Errorf("create output duplication: %w", err)
}
w, h := screenSize()
if w == 0 || h == 0 {
dup.Release()
device.Release()
deviceCtx.Release()
return nil, fmt.Errorf("screen dimensions are zero")
}
rect := image.Rect(0, 0, w, h)
c := &dxgiCapturer{
dup: dup,
device: device,
ctx: deviceCtx,
img: image.NewRGBA(rect),
out: [2]*image.RGBA{image.NewRGBA(rect), image.NewRGBA(rect)},
width: w,
height: h,
}
// Grab the initial frame with a longer timeout to ensure we have
// a valid image before returning.
_ = dup.GetImage(c.img, 2000)
return c, nil
}
func (c *dxgiCapturer) capture() (*image.RGBA, error) {
err := c.dup.GetImage(c.img, 100)
if err != nil && !errors.Is(err, outputduplication.ErrNoImageYet) {
return nil, err
}
// Copy into the next output buffer. The DesktopCapturer hands out the
// returned pointer to VNC sessions that read pixels concurrently, so we
// alternate between two pre-allocated buffers instead of allocating per frame.
out := c.out[c.outIdx]
c.outIdx ^= 1
copy(out.Pix, c.img.Pix)
return out, nil
}
func (c *dxgiCapturer) close() {
if c.dup != nil {
c.dup.Release()
c.dup = nil
}
if c.ctx != nil {
c.ctx.Release()
c.ctx = nil
}
if c.device != nil {
c.device.Release()
c.device = nil
}
}


@@ -1,461 +0,0 @@
//go:build windows
package server
import (
"fmt"
"image"
"runtime"
"sync"
"sync/atomic"
"time"
"unsafe"
log "github.com/sirupsen/logrus"
"golang.org/x/sys/windows"
)
var (
gdi32 = windows.NewLazySystemDLL("gdi32.dll")
user32 = windows.NewLazySystemDLL("user32.dll")
procGetDC = user32.NewProc("GetDC")
procReleaseDC = user32.NewProc("ReleaseDC")
procCreateCompatDC = gdi32.NewProc("CreateCompatibleDC")
procCreateDIBSection = gdi32.NewProc("CreateDIBSection")
procSelectObject = gdi32.NewProc("SelectObject")
procDeleteObject = gdi32.NewProc("DeleteObject")
procDeleteDC = gdi32.NewProc("DeleteDC")
procBitBlt = gdi32.NewProc("BitBlt")
procGetSystemMetrics = user32.NewProc("GetSystemMetrics")
// Desktop switching for service/Session 0 capture.
procOpenInputDesktop = user32.NewProc("OpenInputDesktop")
procSetThreadDesktop = user32.NewProc("SetThreadDesktop")
procCloseDesktop = user32.NewProc("CloseDesktop")
procOpenWindowStation = user32.NewProc("OpenWindowStationW")
procSetProcessWindowStation = user32.NewProc("SetProcessWindowStation")
procCloseWindowStation = user32.NewProc("CloseWindowStation")
procGetUserObjectInformationW = user32.NewProc("GetUserObjectInformationW")
)
const uoiName = 2
const (
smCxScreen = 0
smCyScreen = 1
srccopy = 0x00CC0020
dibRgbColors = 0
)
type bitmapInfoHeader struct {
Size uint32
Width int32
Height int32
Planes uint16
BitCount uint16
Compression uint32
SizeImage uint32
XPelsPerMeter int32
YPelsPerMeter int32
ClrUsed uint32
ClrImportant uint32
}
type bitmapInfo struct {
Header bitmapInfoHeader
}
// setupInteractiveWindowStation associates the current process with WinSta0,
// the interactive window station. This is required for a SYSTEM service in
// Session 0 to call OpenInputDesktop for screen capture and input injection.
func setupInteractiveWindowStation() error {
name, err := windows.UTF16PtrFromString("WinSta0")
if err != nil {
return fmt.Errorf("UTF16 WinSta0: %w", err)
}
hWinSta, _, err := procOpenWindowStation.Call(
uintptr(unsafe.Pointer(name)),
0,
uintptr(windows.MAXIMUM_ALLOWED),
)
if hWinSta == 0 {
return fmt.Errorf("OpenWindowStation(WinSta0): %w", err)
}
r, _, err := procSetProcessWindowStation.Call(hWinSta)
if r == 0 {
procCloseWindowStation.Call(hWinSta)
return fmt.Errorf("SetProcessWindowStation: %w", err)
}
log.Info("process window station set to WinSta0 (interactive)")
return nil
}
func screenSize() (int, int) {
w, _, _ := procGetSystemMetrics.Call(uintptr(smCxScreen))
h, _, _ := procGetSystemMetrics.Call(uintptr(smCyScreen))
return int(w), int(h)
}
func getDesktopName(hDesk uintptr) string {
var buf [256]uint16
var needed uint32
procGetUserObjectInformationW.Call(hDesk, uoiName,
uintptr(unsafe.Pointer(&buf[0])), 512,
uintptr(unsafe.Pointer(&needed)))
return windows.UTF16ToString(buf[:])
}
// switchToInputDesktop opens the desktop currently receiving user input
// and sets it as the calling OS thread's desktop. Must be called from a
// goroutine locked to its OS thread via runtime.LockOSThread().
func switchToInputDesktop() (bool, string) {
hDesk, _, _ := procOpenInputDesktop.Call(0, 0, uintptr(windows.MAXIMUM_ALLOWED))
if hDesk == 0 {
return false, ""
}
name := getDesktopName(hDesk)
ret, _, _ := procSetThreadDesktop.Call(hDesk)
procCloseDesktop.Call(hDesk)
return ret != 0, name
}
// gdiCapturer captures the desktop screen using GDI BitBlt.
// GDI objects (DC, DIBSection) are allocated once and reused across frames.
type gdiCapturer struct {
mu sync.Mutex
width int
height int
// Pre-allocated GDI resources, reused across captures.
memDC uintptr
bmp uintptr
bits uintptr
}
func newGDICapturer() (*gdiCapturer, error) {
w, h := screenSize()
if w == 0 || h == 0 {
return nil, fmt.Errorf("screen dimensions are zero")
}
c := &gdiCapturer{width: w, height: h}
if err := c.allocGDI(); err != nil {
return nil, err
}
return c, nil
}
// allocGDI pre-allocates the compatible DC and DIB section for reuse.
func (c *gdiCapturer) allocGDI() error {
screenDC, _, _ := procGetDC.Call(0)
if screenDC == 0 {
return fmt.Errorf("GetDC returned 0")
}
defer procReleaseDC.Call(0, screenDC)
memDC, _, _ := procCreateCompatDC.Call(screenDC)
if memDC == 0 {
return fmt.Errorf("CreateCompatibleDC returned 0")
}
bi := bitmapInfo{
Header: bitmapInfoHeader{
Size: uint32(unsafe.Sizeof(bitmapInfoHeader{})),
Width: int32(c.width),
Height: -int32(c.height), // negative = top-down DIB
Planes: 1,
BitCount: 32,
},
}
var bits uintptr
bmp, _, _ := procCreateDIBSection.Call(
screenDC,
uintptr(unsafe.Pointer(&bi)),
dibRgbColors,
uintptr(unsafe.Pointer(&bits)),
0, 0,
)
if bmp == 0 || bits == 0 {
procDeleteDC.Call(memDC)
return fmt.Errorf("CreateDIBSection returned 0")
}
procSelectObject.Call(memDC, bmp)
c.memDC = memDC
c.bmp = bmp
c.bits = bits
return nil
}
func (c *gdiCapturer) close() { c.freeGDI() }
// freeGDI releases pre-allocated GDI resources.
func (c *gdiCapturer) freeGDI() {
if c.bmp != 0 {
procDeleteObject.Call(c.bmp)
c.bmp = 0
}
if c.memDC != 0 {
procDeleteDC.Call(c.memDC)
c.memDC = 0
}
c.bits = 0
}
func (c *gdiCapturer) capture() (*image.RGBA, error) {
c.mu.Lock()
defer c.mu.Unlock()
if c.memDC == 0 {
return nil, fmt.Errorf("GDI resources not allocated")
}
screenDC, _, _ := procGetDC.Call(0)
if screenDC == 0 {
return nil, fmt.Errorf("GetDC returned 0")
}
defer procReleaseDC.Call(0, screenDC)
ret, _, _ := procBitBlt.Call(c.memDC, 0, 0, uintptr(c.width), uintptr(c.height),
screenDC, 0, 0, srccopy)
if ret == 0 {
return nil, fmt.Errorf("BitBlt returned 0")
}
n := c.width * c.height * 4
raw := unsafe.Slice((*byte)(unsafe.Pointer(c.bits)), n)
// GDI gives BGRA, the RFB encoder expects RGBA (img.Pix layout).
// Swap R and B in bulk using uint32 operations (one load + mask + shift
// per pixel instead of three separate byte assignments).
img := image.NewRGBA(image.Rect(0, 0, c.width, c.height))
pix := img.Pix
copy(pix, raw)
swizzleBGRAtoRGBA(pix)
return img, nil
}
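swizzleBGRAtoRGBA is defined elsewhere in the package; the uint32 trick the comment above describes can be sketched like this (the exact masks here are an assumption, not the shipped implementation):

```go
package main

import (
	"encoding/binary"
	"fmt"
)

// swizzleBGRAtoRGBA swaps the R and B channels in place. Each pixel is
// loaded as one little-endian uint32 (BGRA bytes read as 0xAARRGGBB):
// keep the G and A bytes, and exchange the low (B) and third (R) bytes
// with two masks and two shifts instead of three byte assignments.
func swizzleBGRAtoRGBA(pix []byte) {
	for i := 0; i+4 <= len(pix); i += 4 {
		v := binary.LittleEndian.Uint32(pix[i:])
		v = (v & 0xFF00FF00) | ((v & 0x00FF0000) >> 16) | ((v & 0x000000FF) << 16)
		binary.LittleEndian.PutUint32(pix[i:], v)
	}
}

func main() {
	pix := []byte{0x10, 0x20, 0x30, 0xFF} // B, G, R, A
	swizzleBGRAtoRGBA(pix)
	fmt.Println(pix) // [48 32 16 255]: now R, G, B, A
}
```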
// DesktopCapturer captures the interactive desktop, handling desktop transitions
// (login screen, UAC prompts). A dedicated OS-locked goroutine continuously
// captures frames, which are retrieved by the VNC session on demand.
// Capture pauses automatically when no clients are connected.
type DesktopCapturer struct {
mu sync.Mutex
frame *image.RGBA
w, h int
// clients tracks the number of active VNC sessions. When zero, the
// capture loop idles instead of grabbing frames.
clients atomic.Int32
// wake is signaled when a client connects and the loop should resume.
wake chan struct{}
// done is closed when Close is called, terminating the capture loop.
done chan struct{}
}
// NewDesktopCapturer creates a capturer that continuously grabs the active desktop.
func NewDesktopCapturer() *DesktopCapturer {
c := &DesktopCapturer{
wake: make(chan struct{}, 1),
done: make(chan struct{}),
}
go c.loop()
return c
}
// ClientConnect increments the active client count, resuming capture if needed.
func (c *DesktopCapturer) ClientConnect() {
c.clients.Add(1)
select {
case c.wake <- struct{}{}:
default:
}
}
// ClientDisconnect decrements the active client count.
func (c *DesktopCapturer) ClientDisconnect() {
c.clients.Add(-1)
}
// Close stops the capture loop and releases resources.
func (c *DesktopCapturer) Close() {
select {
case <-c.done:
default:
close(c.done)
}
}
// Width returns the current screen width.
func (c *DesktopCapturer) Width() int {
c.mu.Lock()
defer c.mu.Unlock()
return c.w
}
// Height returns the current screen height.
func (c *DesktopCapturer) Height() int {
c.mu.Lock()
defer c.mu.Unlock()
return c.h
}
// Capture returns the most recent desktop frame.
func (c *DesktopCapturer) Capture() (*image.RGBA, error) {
c.mu.Lock()
img := c.frame
c.mu.Unlock()
if img != nil {
return img, nil
}
return nil, fmt.Errorf("no frame available yet")
}
// waitForClient blocks until a client connects or the capturer is closed.
func (c *DesktopCapturer) waitForClient() bool {
if c.clients.Load() > 0 {
return true
}
select {
case <-c.wake:
return true
case <-c.done:
return false
}
}
func (c *DesktopCapturer) loop() {
runtime.LockOSThread()
// When running as a Windows service (Session 0), we need to attach to the
// interactive window station before OpenInputDesktop will succeed.
if err := setupInteractiveWindowStation(); err != nil {
log.Warnf("attach to interactive window station: %v", err)
}
frameTicker := time.NewTicker(33 * time.Millisecond) // ~30 fps
defer frameTicker.Stop()
retryTimer := time.NewTimer(0)
retryTimer.Stop()
defer retryTimer.Stop()
type frameCapturer interface {
capture() (*image.RGBA, error)
close()
}
var cap frameCapturer
var desktopFails int
var lastDesktop string
createCapturer := func() (frameCapturer, error) {
dc, err := newDXGICapturer()
if err == nil {
log.Info("using DXGI Desktop Duplication for capture")
return dc, nil
}
log.Debugf("DXGI unavailable (%v), falling back to GDI", err)
gc, err := newGDICapturer()
if err != nil {
return nil, err
}
log.Info("using GDI BitBlt for capture")
return gc, nil
}
for {
if !c.waitForClient() {
if cap != nil {
cap.close()
}
return
}
// No clients: release the capturer and wait.
if c.clients.Load() <= 0 {
if cap != nil {
cap.close()
cap = nil
}
continue
}
ok, desk := switchToInputDesktop()
if !ok {
desktopFails++
if desktopFails == 1 || desktopFails%100 == 0 {
log.Warnf("switchToInputDesktop failed (count=%d), no interactive desktop session?", desktopFails)
}
retryTimer.Reset(100 * time.Millisecond)
select {
case <-retryTimer.C:
case <-c.done:
return
}
continue
}
if desktopFails > 0 {
log.Infof("switchToInputDesktop recovered after %d failures, desktop=%q", desktopFails, desk)
desktopFails = 0
}
if desk != lastDesktop {
log.Infof("desktop changed: %q -> %q", lastDesktop, desk)
lastDesktop = desk
if cap != nil {
cap.close()
}
cap = nil
}
if cap == nil {
fc, err := createCapturer()
if err != nil {
log.Warnf("create capturer: %v", err)
retryTimer.Reset(500 * time.Millisecond)
select {
case <-retryTimer.C:
case <-c.done:
return
}
continue
}
cap = fc
w, h := screenSize()
c.mu.Lock()
c.w, c.h = w, h
c.mu.Unlock()
log.Infof("screen capturer ready: %dx%d", w, h)
}
img, err := cap.capture()
if err != nil {
log.Debugf("capture: %v", err)
cap.close()
cap = nil
retryTimer.Reset(100 * time.Millisecond)
select {
case <-retryTimer.C:
case <-c.done:
return
}
continue
}
c.mu.Lock()
c.frame = img
c.mu.Unlock()
select {
case <-frameTicker.C:
case <-c.done:
if cap != nil {
cap.close()
}
return
}
}
}
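The pause/resume mechanics above boil down to a capacity-1 wake channel plus an atomic client counter: ClientConnect never blocks, and at most one wakeup is ever pending. A toy, self-contained model of the same pattern (not the production loop):

```go
package main

import (
	"fmt"
	"sync/atomic"
	"time"
)

// pausedLoop mirrors the capturer's idle mechanics in miniature.
type pausedLoop struct {
	clients atomic.Int32
	wake    chan struct{}
	done    chan struct{}
}

func (p *pausedLoop) connect() {
	p.clients.Add(1)
	select {
	case p.wake <- struct{}{}: // resume the loop if it is parked
	default: // a wakeup is already pending; dropping this one is fine
	}
}

// run parks while no clients are connected and counts "frames" otherwise.
func (p *pausedLoop) run(frames *atomic.Int32) {
	for {
		if p.clients.Load() <= 0 {
			select {
			case <-p.wake:
			case <-p.done:
				return
			}
			continue
		}
		frames.Add(1) // stand-in for one screen grab
		select {
		case <-p.done:
			return
		case <-time.After(time.Millisecond):
		}
	}
}

func main() {
	p := &pausedLoop{wake: make(chan struct{}, 1), done: make(chan struct{})}
	var frames atomic.Int32
	go p.run(&frames)
	time.Sleep(10 * time.Millisecond)
	idle := frames.Load() // still zero: no client has connected yet
	p.connect()
	for frames.Load() == 0 { // resumes promptly after connect
		time.Sleep(time.Millisecond)
	}
	close(p.done)
	fmt.Println(idle == 0, frames.Load() > 0) // true true
}
```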


@@ -1,385 +0,0 @@
//go:build (linux && !android) || freebsd
package server
import (
"fmt"
"image"
"os"
"os/exec"
"strings"
"sync"
"time"
log "github.com/sirupsen/logrus"
"github.com/jezek/xgb"
"github.com/jezek/xgb/xproto"
)
// X11Capturer captures the screen from an X11 display using the MIT-SHM extension.
type X11Capturer struct {
mu sync.Mutex
conn *xgb.Conn
screen *xproto.ScreenInfo
w, h int
shmID int
shmAddr []byte
shmSeg uint32 // shm.Seg
useSHM bool
}
// detectX11Display finds the active X11 display and sets DISPLAY/XAUTHORITY
// environment variables if needed. This is required when running as a system
// service where these vars aren't set.
func detectX11Display() {
if os.Getenv("DISPLAY") != "" {
return
}
// Try /proc first (Linux), then ps fallback (FreeBSD and others).
if detectX11FromProc() {
return
}
if detectX11FromSockets() {
return
}
}
// detectX11FromProc scans /proc/*/cmdline for Xorg (Linux).
func detectX11FromProc() bool {
entries, err := os.ReadDir("/proc")
if err != nil {
return false
}
for _, e := range entries {
if !e.IsDir() {
continue
}
cmdline, err := os.ReadFile("/proc/" + e.Name() + "/cmdline")
if err != nil {
continue
}
if display, auth := parseXorgArgs(splitCmdline(cmdline)); display != "" {
setDisplayEnv(display, auth)
return true
}
}
return false
}
// detectX11FromSockets checks /tmp/.X11-unix/ for X sockets and uses ps
// to find the auth file. Works on FreeBSD and other systems without /proc.
func detectX11FromSockets() bool {
entries, err := os.ReadDir("/tmp/.X11-unix")
if err != nil {
return false
}
// Use the first socket entry; ReadDir returns sorted names, so X0 is picked before X1.
for _, e := range entries {
name := e.Name()
if len(name) < 2 || name[0] != 'X' {
continue
}
display := ":" + name[1:]
os.Setenv("DISPLAY", display)
log.Infof("auto-detected DISPLAY=%s (from socket)", display)
// Try to find -auth from ps output.
if auth := findXorgAuthFromPS(); auth != "" {
os.Setenv("XAUTHORITY", auth)
log.Infof("auto-detected XAUTHORITY=%s (from ps)", auth)
}
return true
}
return false
}
// findXorgAuthFromPS runs ps to find Xorg and extract its -auth argument.
func findXorgAuthFromPS() string {
out, err := exec.Command("ps", "auxww").Output()
if err != nil {
return ""
}
for _, line := range strings.Split(string(out), "\n") {
if !strings.Contains(line, "Xorg") && !strings.Contains(line, "/X ") {
continue
}
fields := strings.Fields(line)
for i, f := range fields {
if f == "-auth" && i+1 < len(fields) {
return fields[i+1]
}
}
}
return ""
}
func parseXorgArgs(args []string) (display, auth string) {
if len(args) == 0 {
return "", ""
}
base := args[0]
// Accept Xorg, any binary name ending in X, or a path containing /X
// (which covers /usr/lib/Xorg and similar).
if base != "Xorg" && !strings.HasSuffix(base, "X") && !strings.Contains(base, "/X") {
return "", ""
}
for i, arg := range args[1:] {
if len(arg) > 0 && arg[0] == ':' {
display = arg
}
if arg == "-auth" && i+2 < len(args) {
auth = args[i+2]
}
}
return display, auth
}
func setDisplayEnv(display, auth string) {
os.Setenv("DISPLAY", display)
log.Infof("auto-detected DISPLAY=%s", display)
if auth != "" {
os.Setenv("XAUTHORITY", auth)
log.Infof("auto-detected XAUTHORITY=%s", auth)
}
}
func splitCmdline(data []byte) []string {
var args []string
for _, b := range splitNull(data) {
if len(b) > 0 {
args = append(args, string(b))
}
}
return args
}
func splitNull(data []byte) [][]byte {
var parts [][]byte
start := 0
for i, b := range data {
if b == 0 {
parts = append(parts, data[start:i])
start = i + 1
}
}
if start < len(data) {
parts = append(parts, data[start:])
}
return parts
}
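For reference, /proc/&lt;pid&gt;/cmdline is NUL-separated with a possible trailing NUL, which is what splitCmdline/splitNull handle. A self-contained sketch of the same parsing (the Xorg path and auth file below are illustrative):

```go
package main

import "fmt"

// splitArgs splits NUL-separated /proc/<pid>/cmdline bytes into argv,
// dropping empty fields, equivalent to splitCmdline above.
func splitArgs(data []byte) []string {
	var args []string
	start := 0
	flush := func(end int) {
		if end > start {
			args = append(args, string(data[start:end]))
		}
		start = end + 1
	}
	for i, b := range data {
		if b == 0 {
			flush(i)
		}
	}
	if start < len(data) {
		args = append(args, string(data[start:]))
	}
	return args
}

func main() {
	// A typical Xorg cmdline as read from /proc (paths are illustrative).
	raw := []byte("/usr/lib/Xorg\x00:0\x00-auth\x00/run/user/120/gdm/Xauthority\x00vt2")
	args := splitArgs(raw)
	fmt.Println(args[1]) // :0 -> becomes DISPLAY
	fmt.Println(args[3]) // the XAUTHORITY path
}
```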
// NewX11Capturer connects to the X11 display and sets up shared memory capture.
func NewX11Capturer(display string) (*X11Capturer, error) {
detectX11Display()
if display == "" {
display = os.Getenv("DISPLAY")
}
if display == "" {
return nil, fmt.Errorf("DISPLAY not set and no Xorg process found")
}
conn, err := xgb.NewConnDisplay(display)
if err != nil {
return nil, fmt.Errorf("connect to X11 display %s: %w", display, err)
}
setup := xproto.Setup(conn)
if len(setup.Roots) == 0 {
conn.Close()
return nil, fmt.Errorf("no X11 screens")
}
screen := setup.Roots[0]
c := &X11Capturer{
conn: conn,
screen: &screen,
w: int(screen.WidthInPixels),
h: int(screen.HeightInPixels),
}
if err := c.initSHM(); err != nil {
log.Debugf("X11 SHM not available, using slow GetImage: %v", err)
}
log.Infof("X11 capturer ready: %dx%d (display=%s, shm=%v)", c.w, c.h, display, c.useSHM)
return c, nil
}
// initSHM is implemented in capture_x11_shm_linux.go (requires SysV SHM).
// On platforms without SysV SHM (FreeBSD), a stub returns an error and
// the capturer falls back to GetImage.
// Width returns the screen width.
func (c *X11Capturer) Width() int { return c.w }
// Height returns the screen height.
func (c *X11Capturer) Height() int { return c.h }
// Capture returns the current screen as an RGBA image.
func (c *X11Capturer) Capture() (*image.RGBA, error) {
c.mu.Lock()
defer c.mu.Unlock()
if c.useSHM {
return c.captureSHM()
}
return c.captureGetImage()
}
// captureSHM is implemented in capture_x11_shm_linux.go.
func (c *X11Capturer) captureGetImage() (*image.RGBA, error) {
cookie := xproto.GetImage(c.conn, xproto.ImageFormatZPixmap,
xproto.Drawable(c.screen.Root),
0, 0, uint16(c.w), uint16(c.h), 0xFFFFFFFF)
reply, err := cookie.Reply()
if err != nil {
return nil, fmt.Errorf("GetImage: %w", err)
}
img := image.NewRGBA(image.Rect(0, 0, c.w, c.h))
data := reply.Data
n := c.w * c.h * 4
if len(data) < n {
return nil, fmt.Errorf("GetImage returned %d bytes, expected %d", len(data), n)
}
for i := 0; i < n; i += 4 {
img.Pix[i+0] = data[i+2] // R
img.Pix[i+1] = data[i+1] // G
img.Pix[i+2] = data[i+0] // B
img.Pix[i+3] = 0xff
}
return img, nil
}
// Close releases X11 resources.
func (c *X11Capturer) Close() {
c.closeSHM()
c.conn.Close()
}
// closeSHM is implemented in capture_x11_shm_linux.go.
// X11Poller wraps X11Capturer in a continuous capture loop, matching the
// DesktopCapturer pattern from Windows.
type X11Poller struct {
mu sync.Mutex
frame *image.RGBA
w, h int
display string
done chan struct{}
}
// NewX11Poller creates a capturer that continuously grabs the X11 display.
func NewX11Poller(display string) *X11Poller {
p := &X11Poller{
display: display,
done: make(chan struct{}),
}
go p.loop()
return p
}
// Close stops the capture loop.
func (p *X11Poller) Close() {
select {
case <-p.done:
default:
close(p.done)
}
}
// Width returns the screen width.
func (p *X11Poller) Width() int {
p.mu.Lock()
defer p.mu.Unlock()
return p.w
}
// Height returns the screen height.
func (p *X11Poller) Height() int {
p.mu.Lock()
defer p.mu.Unlock()
return p.h
}
// Capture returns the most recent frame.
func (p *X11Poller) Capture() (*image.RGBA, error) {
p.mu.Lock()
img := p.frame
p.mu.Unlock()
if img != nil {
return img, nil
}
return nil, fmt.Errorf("no frame available yet")
}
func (p *X11Poller) loop() {
var capturer *X11Capturer
var initFails int
defer func() {
if capturer != nil {
capturer.Close()
}
}()
for {
select {
case <-p.done:
return
default:
}
if capturer == nil {
var err error
capturer, err = NewX11Capturer(p.display)
if err != nil {
initFails++
if initFails <= maxCapturerRetries {
log.Debugf("X11 capturer: %v (attempt %d/%d)", err, initFails, maxCapturerRetries)
select {
case <-p.done:
return
case <-time.After(2 * time.Second):
}
continue
}
log.Warnf("X11 capturer unavailable after %d attempts, stopping poller", maxCapturerRetries)
return
}
initFails = 0
p.mu.Lock()
p.w, p.h = capturer.Width(), capturer.Height()
p.mu.Unlock()
}
img, err := capturer.Capture()
if err != nil {
log.Debugf("X11 capture: %v", err)
capturer.Close()
capturer = nil
select {
case <-p.done:
return
case <-time.After(500 * time.Millisecond):
}
continue
}
p.mu.Lock()
p.frame = img
p.mu.Unlock()
select {
case <-p.done:
return
case <-time.After(33 * time.Millisecond): // ~30 fps
}
}
}


@@ -1,78 +0,0 @@
//go:build linux && !android
package server
import (
"fmt"
"image"
"github.com/jezek/xgb/shm"
"github.com/jezek/xgb/xproto"
"golang.org/x/sys/unix"
)
func (c *X11Capturer) initSHM() error {
if err := shm.Init(c.conn); err != nil {
return fmt.Errorf("init SHM extension: %w", err)
}
size := c.w * c.h * 4
id, err := unix.SysvShmGet(unix.IPC_PRIVATE, size, unix.IPC_CREAT|0600)
if err != nil {
return fmt.Errorf("shmget: %w", err)
}
addr, err := unix.SysvShmAttach(id, 0, 0)
if err != nil {
unix.SysvShmCtl(id, unix.IPC_RMID, nil)
return fmt.Errorf("shmat: %w", err)
}
unix.SysvShmCtl(id, unix.IPC_RMID, nil) // mark for removal now; the segment persists until this process and the X server both detach
seg, err := shm.NewSegId(c.conn)
if err != nil {
unix.SysvShmDetach(addr)
return fmt.Errorf("new SHM seg: %w", err)
}
if err := shm.AttachChecked(c.conn, seg, uint32(id), false).Check(); err != nil {
unix.SysvShmDetach(addr)
return fmt.Errorf("SHM attach to X: %w", err)
}
c.shmID = id
c.shmAddr = addr
c.shmSeg = uint32(seg)
c.useSHM = true
return nil
}
func (c *X11Capturer) captureSHM() (*image.RGBA, error) {
cookie := shm.GetImage(c.conn, xproto.Drawable(c.screen.Root),
0, 0, uint16(c.w), uint16(c.h), 0xFFFFFFFF,
xproto.ImageFormatZPixmap, shm.Seg(c.shmSeg), 0)
_, err := cookie.Reply()
if err != nil {
return nil, fmt.Errorf("SHM GetImage: %w", err)
}
img := image.NewRGBA(image.Rect(0, 0, c.w, c.h))
n := c.w * c.h * 4
for i := 0; i < n; i += 4 {
img.Pix[i+0] = c.shmAddr[i+2] // R
img.Pix[i+1] = c.shmAddr[i+1] // G
img.Pix[i+2] = c.shmAddr[i+0] // B
img.Pix[i+3] = 0xff
}
return img, nil
}
func (c *X11Capturer) closeSHM() {
if c.useSHM {
shm.Detach(c.conn, shm.Seg(c.shmSeg))
unix.SysvShmDetach(c.shmAddr)
}
}


@@ -1,18 +0,0 @@
//go:build freebsd
package server
import (
"fmt"
"image"
)
func (c *X11Capturer) initSHM() error {
return fmt.Errorf("SysV SHM not available on this platform")
}
func (c *X11Capturer) captureSHM() (*image.RGBA, error) {
return nil, fmt.Errorf("SHM capture not available on this platform")
}
func (c *X11Capturer) closeSHM() {}


@@ -1,151 +0,0 @@
package server
import (
"crypto/aes"
"crypto/cipher"
"crypto/ecdh"
"crypto/rand"
"crypto/sha256"
"encoding/base64"
"encoding/binary"
"fmt"
"io"
"golang.org/x/crypto/hkdf"
)
const (
aesKeySize = 32 // AES-256
gcmNonceSize = 12
)
// recCrypto holds per-session encryption state.
type recCrypto struct {
gcm cipher.AEAD
frameCounter uint64
// ephemeralPub is stored in the recording header so the admin can derive the same key.
ephemeralPub []byte
}
// newRecCrypto sets up encryption for a new recording session.
// adminPubKeyB64 is the base64-encoded X25519 public key from management settings.
func newRecCrypto(adminPubKeyB64 string) (*recCrypto, error) {
adminPubBytes, err := base64.StdEncoding.DecodeString(adminPubKeyB64)
if err != nil {
return nil, fmt.Errorf("decode admin public key: %w", err)
}
adminPub, err := ecdh.X25519().NewPublicKey(adminPubBytes)
if err != nil {
return nil, fmt.Errorf("parse admin X25519 public key: %w", err)
}
// Generate ephemeral keypair
ephemeral, err := ecdh.X25519().GenerateKey(rand.Reader)
if err != nil {
return nil, fmt.Errorf("generate ephemeral key: %w", err)
}
// ECDH shared secret
shared, err := ephemeral.ECDH(adminPub)
if err != nil {
return nil, fmt.Errorf("ECDH: %w", err)
}
// Derive AES-256 key via HKDF
aesKey, err := deriveKey(shared, ephemeral.PublicKey().Bytes())
if err != nil {
return nil, fmt.Errorf("derive key: %w", err)
}
block, err := aes.NewCipher(aesKey)
if err != nil {
return nil, fmt.Errorf("create AES cipher: %w", err)
}
gcm, err := cipher.NewGCM(block)
if err != nil {
return nil, fmt.Errorf("create GCM: %w", err)
}
return &recCrypto{
gcm: gcm,
ephemeralPub: ephemeral.PublicKey().Bytes(),
}, nil
}
// encrypt encrypts plaintext using a counter-based nonce. Each call increments the counter.
func (c *recCrypto) encrypt(plaintext []byte) []byte {
nonce := make([]byte, gcmNonceSize)
binary.LittleEndian.PutUint64(nonce, c.frameCounter)
c.frameCounter++
return c.gcm.Seal(nil, nonce, plaintext, nil)
}
// DecryptRecording creates a decryptor from the admin's private key and the ephemeral public key from the header.
func DecryptRecording(adminPrivKeyB64 string, ephemeralPubB64 string) (*recDecryptor, error) {
adminPrivBytes, err := base64.StdEncoding.DecodeString(adminPrivKeyB64)
if err != nil {
return nil, fmt.Errorf("decode admin private key: %w", err)
}
adminPriv, err := ecdh.X25519().NewPrivateKey(adminPrivBytes)
if err != nil {
return nil, fmt.Errorf("parse admin X25519 private key: %w", err)
}
ephPubBytes, err := base64.StdEncoding.DecodeString(ephemeralPubB64)
if err != nil {
return nil, fmt.Errorf("decode ephemeral public key: %w", err)
}
ephPub, err := ecdh.X25519().NewPublicKey(ephPubBytes)
if err != nil {
return nil, fmt.Errorf("parse ephemeral public key: %w", err)
}
shared, err := adminPriv.ECDH(ephPub)
if err != nil {
return nil, fmt.Errorf("ECDH: %w", err)
}
aesKey, err := deriveKey(shared, ephPubBytes)
if err != nil {
return nil, fmt.Errorf("derive key: %w", err)
}
block, err := aes.NewCipher(aesKey)
if err != nil {
return nil, fmt.Errorf("create AES cipher: %w", err)
}
gcm, err := cipher.NewGCM(block)
if err != nil {
return nil, fmt.Errorf("create GCM: %w", err)
}
return &recDecryptor{gcm: gcm}, nil
}
type recDecryptor struct {
gcm cipher.AEAD
frameCounter uint64
}
// Decrypt decrypts a frame. Must be called in the same order as encryption.
func (d *recDecryptor) Decrypt(ciphertext []byte) ([]byte, error) {
nonce := make([]byte, gcmNonceSize)
binary.LittleEndian.PutUint64(nonce, d.frameCounter)
d.frameCounter++
return d.gcm.Open(nil, nonce, ciphertext, nil)
}
func deriveKey(shared, ephemeralPub []byte) ([]byte, error) {
hkdfReader := hkdf.New(sha256.New, shared, ephemeralPub, []byte("netbird-recording"))
key := make([]byte, aesKeySize)
if _, err := io.ReadFull(hkdfReader, key); err != nil {
return nil, err
}
return key, nil
}
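The counter-nonce scheme above can be exercised in isolation with stdlib AES-GCM. A minimal sketch (fixed all-zero key for illustration only; the real code derives a fresh per-session key via ephemeral ECDH + HKDF, which is what makes counter nonces safe):

```go
package main

import (
	"crypto/aes"
	"crypto/cipher"
	"encoding/binary"
	"fmt"
)

// ctrNonce builds a 12-byte GCM nonce for frame number ctr, matching the
// little-endian counter layout recCrypto uses.
func ctrNonce(ctr uint64) []byte {
	n := make([]byte, 12)
	binary.LittleEndian.PutUint64(n, ctr)
	return n
}

// newSessionGCM wraps AES-256-GCM setup. Counter nonces never repeat a
// (key, nonce) pair here because each session gets a unique key.
func newSessionGCM(key []byte) cipher.AEAD {
	blk, err := aes.NewCipher(key)
	if err != nil {
		panic(err)
	}
	gcm, err := cipher.NewGCM(blk)
	if err != nil {
		panic(err)
	}
	return gcm
}

func main() {
	gcm := newSessionGCM(make([]byte, 32)) // zero key for illustration only
	ct0 := gcm.Seal(nil, ctrNonce(0), []byte("frame 0"), nil)
	ct1 := gcm.Seal(nil, ctrNonce(1), []byte("frame 1"), nil)

	// In-order decryption succeeds; decrypting frame 1 with frame 0's
	// nonce fails the GCM tag check, which is why out-of-order playback
	// needs a fresh decryptor.
	_, inOrder := gcm.Open(nil, ctrNonce(0), ct0, nil)
	_, skipped := gcm.Open(nil, ctrNonce(0), ct1, nil)
	fmt.Println(inOrder == nil, skipped != nil) // true true
}
```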


@@ -1,129 +0,0 @@
package server
import (
"crypto/ecdh"
"crypto/rand"
"encoding/base64"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestCryptoRoundtrip(t *testing.T) {
// Generate admin keypair
adminPriv, err := ecdh.X25519().GenerateKey(rand.Reader)
require.NoError(t, err)
adminPubB64 := base64.StdEncoding.EncodeToString(adminPriv.PublicKey().Bytes())
adminPrivB64 := base64.StdEncoding.EncodeToString(adminPriv.Bytes())
// Create encryptor (recording side)
enc, err := newRecCrypto(adminPubB64)
require.NoError(t, err)
assert.Len(t, enc.ephemeralPub, 32)
ephPubB64 := base64.StdEncoding.EncodeToString(enc.ephemeralPub)
// Encrypt some frames
plaintext1 := []byte("frame data one - PNG bytes would go here")
plaintext2 := []byte("frame data two - different content")
plaintext3 := make([]byte, 1024*100) // 100KB frame
rand.Read(plaintext3)
ct1 := enc.encrypt(plaintext1)
ct2 := enc.encrypt(plaintext2)
ct3 := enc.encrypt(plaintext3)
// Ciphertext should differ from plaintext
assert.NotEqual(t, plaintext1, ct1)
// Ciphertext is larger (GCM tag overhead)
assert.Greater(t, len(ct1), len(plaintext1))
// Create decryptor (playback side)
dec, err := DecryptRecording(adminPrivB64, ephPubB64)
require.NoError(t, err)
// Decrypt in same order
got1, err := dec.Decrypt(ct1)
require.NoError(t, err)
assert.Equal(t, plaintext1, got1)
got2, err := dec.Decrypt(ct2)
require.NoError(t, err)
assert.Equal(t, plaintext2, got2)
got3, err := dec.Decrypt(ct3)
require.NoError(t, err)
assert.Equal(t, plaintext3, got3)
}
func TestCryptoWrongKey(t *testing.T) {
// Admin key
adminPriv, err := ecdh.X25519().GenerateKey(rand.Reader)
require.NoError(t, err)
adminPubB64 := base64.StdEncoding.EncodeToString(adminPriv.PublicKey().Bytes())
// Encrypt with admin's public key
enc, err := newRecCrypto(adminPubB64)
require.NoError(t, err)
ephPubB64 := base64.StdEncoding.EncodeToString(enc.ephemeralPub)
ct := enc.encrypt([]byte("secret frame data"))
// Try to decrypt with a different private key
wrongPriv, err := ecdh.X25519().GenerateKey(rand.Reader)
require.NoError(t, err)
wrongPrivB64 := base64.StdEncoding.EncodeToString(wrongPriv.Bytes())
dec, err := DecryptRecording(wrongPrivB64, ephPubB64)
require.NoError(t, err)
_, err = dec.Decrypt(ct)
assert.Error(t, err, "decryption with wrong key should fail")
}
func TestCryptoInvalidKey(t *testing.T) {
_, err := newRecCrypto("")
assert.Error(t, err, "empty key should fail")
_, err = newRecCrypto("not-base64!!!")
assert.Error(t, err, "invalid base64 should fail")
_, err = newRecCrypto(base64.StdEncoding.EncodeToString([]byte("too-short")))
assert.Error(t, err, "wrong-length key should fail")
_, err = DecryptRecording("", "validbutirrelevant")
assert.Error(t, err, "empty private key should fail")
_, err = DecryptRecording("not-base64!!!", base64.StdEncoding.EncodeToString(make([]byte, 32)))
assert.Error(t, err, "invalid base64 private key should fail")
}
func TestCryptoOutOfOrderFails(t *testing.T) {
adminPriv, err := ecdh.X25519().GenerateKey(rand.Reader)
require.NoError(t, err)
adminPubB64 := base64.StdEncoding.EncodeToString(adminPriv.PublicKey().Bytes())
adminPrivB64 := base64.StdEncoding.EncodeToString(adminPriv.Bytes())
enc, err := newRecCrypto(adminPubB64)
require.NoError(t, err)
ephPubB64 := base64.StdEncoding.EncodeToString(enc.ephemeralPub)
ct0 := enc.encrypt([]byte("frame 0"))
ct1 := enc.encrypt([]byte("frame 1"))
dec, err := DecryptRecording(adminPrivB64, ephPubB64)
require.NoError(t, err)
// Skip frame 0, try to decrypt frame 1 first (wrong nonce)
_, err = dec.Decrypt(ct1)
assert.Error(t, err, "out-of-order decryption should fail due to nonce mismatch")
// But frame 0 with a fresh decryptor should work
dec2, err := DecryptRecording(adminPrivB64, ephPubB64)
require.NoError(t, err)
got, err := dec2.Decrypt(ct0)
require.NoError(t, err)
assert.Equal(t, []byte("frame 0"), got)
}


@@ -1,540 +0,0 @@
//go:build darwin && !ios
package server
import (
"fmt"
"os/exec"
"strings"
"sync"
"github.com/ebitengine/purego"
log "github.com/sirupsen/logrus"
)
// Core Graphics event constants.
const (
kCGEventSourceStateCombinedSessionState int32 = 0
kCGEventLeftMouseDown int32 = 1
kCGEventLeftMouseUp int32 = 2
kCGEventRightMouseDown int32 = 3
kCGEventRightMouseUp int32 = 4
kCGEventMouseMoved int32 = 5
kCGEventLeftMouseDragged int32 = 6
kCGEventRightMouseDragged int32 = 7
kCGEventKeyDown int32 = 10
kCGEventKeyUp int32 = 11
kCGEventOtherMouseDown int32 = 25
kCGEventOtherMouseUp int32 = 26
kCGMouseButtonLeft int32 = 0
kCGMouseButtonRight int32 = 1
kCGMouseButtonCenter int32 = 2
kCGHIDEventTap int32 = 0
// IOKit power management constants.
kIOPMUserActiveLocal int32 = 0
kIOPMAssertionLevelOn uint32 = 255
kCFStringEncodingUTF8 uint32 = 0x08000100
)
var darwinInputOnce sync.Once
var (
cgEventSourceCreate func(int32) uintptr
cgEventCreateKeyboardEvent func(uintptr, uint16, bool) uintptr
// CGEventCreateMouseEvent takes CGPoint as two separate float64 args.
// purego can't handle array/struct types but individual float64s work.
cgEventCreateMouseEvent func(uintptr, int32, float64, float64, int32) uintptr
cgEventPost func(int32, uintptr)
// CGEventCreateScrollWheelEvent is variadic, call via SyscallN.
cgEventCreateScrollWheelEventAddr uintptr
axIsProcessTrusted func() bool
// IOKit power-management bindings used to wake the display and inhibit
// idle sleep while a VNC client is driving input.
iopmAssertionDeclareUserActivity func(uintptr, int32, *uint32) int32
iopmAssertionCreateWithName func(uintptr, uint32, uintptr, *uint32) int32
iopmAssertionRelease func(uint32) int32
cfStringCreateWithCString func(uintptr, string, uint32) uintptr
// Cached CFStrings for assertion name and idle-sleep type.
pmAssertionNameCFStr uintptr
pmPreventIdleDisplayCFStr uintptr
// Assertion IDs. userActivityID is reused across input events so repeated
// calls refresh the same assertion rather than create new ones.
pmMu sync.Mutex
userActivityID uint32
preventSleepID uint32
preventSleepHeld bool
darwinInputReady bool
darwinEventSource uintptr
)
func initDarwinInput() {
darwinInputOnce.Do(func() {
cg, err := purego.Dlopen("/System/Library/Frameworks/CoreGraphics.framework/CoreGraphics", purego.RTLD_NOW|purego.RTLD_GLOBAL)
if err != nil {
log.Debugf("load CoreGraphics for input: %v", err)
return
}
purego.RegisterLibFunc(&cgEventSourceCreate, cg, "CGEventSourceCreate")
purego.RegisterLibFunc(&cgEventCreateKeyboardEvent, cg, "CGEventCreateKeyboardEvent")
purego.RegisterLibFunc(&cgEventCreateMouseEvent, cg, "CGEventCreateMouseEvent")
purego.RegisterLibFunc(&cgEventPost, cg, "CGEventPost")
sym, err := purego.Dlsym(cg, "CGEventCreateScrollWheelEvent")
if err == nil {
cgEventCreateScrollWheelEventAddr = sym
}
if ax, err := purego.Dlopen("/System/Library/Frameworks/ApplicationServices.framework/ApplicationServices", purego.RTLD_NOW|purego.RTLD_GLOBAL); err == nil {
if sym, err := purego.Dlsym(ax, "AXIsProcessTrusted"); err == nil {
purego.RegisterFunc(&axIsProcessTrusted, sym)
}
}
initPowerAssertions()
darwinInputReady = true
})
}
func initPowerAssertions() {
iokit, err := purego.Dlopen("/System/Library/Frameworks/IOKit.framework/IOKit", purego.RTLD_NOW|purego.RTLD_GLOBAL)
if err != nil {
log.Debugf("load IOKit: %v", err)
return
}
cf, err := purego.Dlopen("/System/Library/Frameworks/CoreFoundation.framework/CoreFoundation", purego.RTLD_NOW|purego.RTLD_GLOBAL)
if err != nil {
log.Debugf("load CoreFoundation for power assertions: %v", err)
return
}
purego.RegisterLibFunc(&cfStringCreateWithCString, cf, "CFStringCreateWithCString")
purego.RegisterLibFunc(&iopmAssertionDeclareUserActivity, iokit, "IOPMAssertionDeclareUserActivity")
purego.RegisterLibFunc(&iopmAssertionCreateWithName, iokit, "IOPMAssertionCreateWithName")
purego.RegisterLibFunc(&iopmAssertionRelease, iokit, "IOPMAssertionRelease")
pmAssertionNameCFStr = cfStringCreateWithCString(0, "NetBird VNC input", kCFStringEncodingUTF8)
pmPreventIdleDisplayCFStr = cfStringCreateWithCString(0, "PreventUserIdleDisplaySleep", kCFStringEncodingUTF8)
}
// wakeDisplay declares user activity so macOS treats the synthesized input as
// real HID activity, waking the display if it is asleep. Called on every key
// and pointer event; the kernel coalesces repeated calls cheaply.
func wakeDisplay() {
if iopmAssertionDeclareUserActivity == nil || pmAssertionNameCFStr == 0 {
return
}
pmMu.Lock()
id := userActivityID
pmMu.Unlock()
r := iopmAssertionDeclareUserActivity(pmAssertionNameCFStr, kIOPMUserActiveLocal, &id)
if r != 0 {
log.Tracef("IOPMAssertionDeclareUserActivity returned %d", r)
return
}
pmMu.Lock()
userActivityID = id
pmMu.Unlock()
}
// holdPreventIdleSleep creates an assertion that keeps the display from going
// idle-to-sleep while a VNC session is active. Safe to call repeatedly.
func holdPreventIdleSleep() {
if iopmAssertionCreateWithName == nil || pmPreventIdleDisplayCFStr == 0 || pmAssertionNameCFStr == 0 {
return
}
pmMu.Lock()
defer pmMu.Unlock()
if preventSleepHeld {
return
}
var id uint32
r := iopmAssertionCreateWithName(pmPreventIdleDisplayCFStr, kIOPMAssertionLevelOn, pmAssertionNameCFStr, &id)
if r != 0 {
log.Debugf("IOPMAssertionCreateWithName returned %d", r)
return
}
preventSleepID = id
preventSleepHeld = true
}
// releasePreventIdleSleep drops the idle-sleep assertion.
func releasePreventIdleSleep() {
if iopmAssertionRelease == nil {
return
}
pmMu.Lock()
defer pmMu.Unlock()
if !preventSleepHeld {
return
}
if r := iopmAssertionRelease(preventSleepID); r != 0 {
log.Debugf("IOPMAssertionRelease returned %d", r)
}
preventSleepHeld = false
preventSleepID = 0
}
func ensureEventSource() uintptr {
if darwinEventSource != 0 {
return darwinEventSource
}
darwinEventSource = cgEventSourceCreate(kCGEventSourceStateCombinedSessionState)
return darwinEventSource
}
// MacInputInjector injects keyboard and mouse events via Core Graphics.
type MacInputInjector struct {
lastButtons uint8
pbcopyPath string
pbpastePath string
}
// NewMacInputInjector creates a macOS input injector.
func NewMacInputInjector() (*MacInputInjector, error) {
initDarwinInput()
if !darwinInputReady {
return nil, fmt.Errorf("CoreGraphics not available for input injection")
}
checkMacPermissions()
m := &MacInputInjector{}
if path, err := exec.LookPath("pbcopy"); err == nil {
m.pbcopyPath = path
}
if path, err := exec.LookPath("pbpaste"); err == nil {
m.pbpastePath = path
}
if m.pbcopyPath == "" || m.pbpastePath == "" {
log.Debugf("clipboard tools not found (pbcopy=%q, pbpaste=%q)", m.pbcopyPath, m.pbpastePath)
}
holdPreventIdleSleep()
log.Info("macOS input injector ready")
return m, nil
}
// checkMacPermissions warns and opens the Privacy pane if Accessibility is
// missing. Uses AXIsProcessTrusted which returns immediately; the previous
// osascript probe blocked for 120s (AppleEvent timeout) when access was
// denied, which delayed VNC server startup past client deadlines.
func checkMacPermissions() {
if axIsProcessTrusted != nil && !axIsProcessTrusted() {
openPrivacyPane("Privacy_Accessibility")
log.Warn("Accessibility permission not granted. Input injection will not work. " +
"Opened System Settings > Privacy & Security > Accessibility; enable netbird.")
}
log.Info("Screen Recording permission is required for screen capture. " +
"If the screen appears black, grant in System Settings > Privacy & Security > Screen Recording.")
}
// openPrivacyPane opens the given Privacy pane in System Settings so the user
// can toggle the permission without navigating manually.
func openPrivacyPane(pane string) {
url := "x-apple.systempreferences:com.apple.preference.security?" + pane
if err := exec.Command("open", url).Start(); err != nil {
log.Debugf("open privacy pane %s: %v", pane, err)
}
}
// InjectKey simulates a key press or release.
func (m *MacInputInjector) InjectKey(keysym uint32, down bool) {
wakeDisplay()
src := ensureEventSource()
if src == 0 {
return
}
keycode := keysymToMacKeycode(keysym)
if keycode == 0xFFFF {
return
}
event := cgEventCreateKeyboardEvent(src, keycode, down)
if event == 0 {
return
}
cgEventPost(kCGHIDEventTap, event)
cfRelease(event)
}
// InjectPointer simulates mouse movement and button events.
func (m *MacInputInjector) InjectPointer(buttonMask uint8, px, py, serverW, serverH int) {
wakeDisplay()
if serverW == 0 || serverH == 0 {
return
}
src := ensureEventSource()
if src == 0 {
return
}
// Framebuffer is in physical pixels (Retina). CGEventCreateMouseEvent
// expects logical points, so scale down by the display's pixel/point ratio.
x := float64(px)
y := float64(py)
if cgDisplayPixelsWide != nil && cgMainDisplayID != nil {
displayID := cgMainDisplayID()
logicalW := int(cgDisplayPixelsWide(displayID))
logicalH := int(cgDisplayPixelsHigh(displayID))
if logicalW > 0 && logicalH > 0 {
x = float64(px) * float64(logicalW) / float64(serverW)
y = float64(py) * float64(logicalH) / float64(serverH)
}
}
leftDown := buttonMask&0x01 != 0
rightDown := buttonMask&0x04 != 0
middleDown := buttonMask&0x02 != 0
scrollUp := buttonMask&0x08 != 0
scrollDown := buttonMask&0x10 != 0
wasLeft := m.lastButtons&0x01 != 0
wasRight := m.lastButtons&0x04 != 0
wasMiddle := m.lastButtons&0x02 != 0
if leftDown {
m.postMouse(src, kCGEventLeftMouseDragged, x, y, kCGMouseButtonLeft)
} else if rightDown {
m.postMouse(src, kCGEventRightMouseDragged, x, y, kCGMouseButtonRight)
} else {
m.postMouse(src, kCGEventMouseMoved, x, y, kCGMouseButtonLeft)
}
if leftDown && !wasLeft {
m.postMouse(src, kCGEventLeftMouseDown, x, y, kCGMouseButtonLeft)
} else if !leftDown && wasLeft {
m.postMouse(src, kCGEventLeftMouseUp, x, y, kCGMouseButtonLeft)
}
if rightDown && !wasRight {
m.postMouse(src, kCGEventRightMouseDown, x, y, kCGMouseButtonRight)
} else if !rightDown && wasRight {
m.postMouse(src, kCGEventRightMouseUp, x, y, kCGMouseButtonRight)
}
if middleDown && !wasMiddle {
m.postMouse(src, kCGEventOtherMouseDown, x, y, kCGMouseButtonCenter)
} else if !middleDown && wasMiddle {
m.postMouse(src, kCGEventOtherMouseUp, x, y, kCGMouseButtonCenter)
}
if scrollUp {
m.postScroll(src, 3)
}
if scrollDown {
m.postScroll(src, -3)
}
m.lastButtons = buttonMask
}
func (m *MacInputInjector) postMouse(src uintptr, eventType int32, x, y float64, button int32) {
if cgEventCreateMouseEvent == nil {
return
}
event := cgEventCreateMouseEvent(src, eventType, x, y, button)
if event == 0 {
return
}
cgEventPost(kCGHIDEventTap, event)
cfRelease(event)
}
func (m *MacInputInjector) postScroll(src uintptr, deltaY int32) {
if cgEventCreateScrollWheelEventAddr == 0 {
return
}
// CGEventCreateScrollWheelEvent(source, units, wheelCount, wheel1delta)
// units=0 (pixel), wheelCount=1, wheel1delta=deltaY
// Variadic C function: pass args as uintptr via SyscallN.
r1, _, _ := purego.SyscallN(cgEventCreateScrollWheelEventAddr,
src, 0, 1, uintptr(uint32(deltaY)))
if r1 == 0 {
return
}
cgEventPost(kCGHIDEventTap, r1)
cfRelease(r1)
}
// SetClipboard sets the macOS clipboard using pbcopy.
func (m *MacInputInjector) SetClipboard(text string) {
if m.pbcopyPath == "" {
return
}
cmd := exec.Command(m.pbcopyPath)
cmd.Stdin = strings.NewReader(text)
if err := cmd.Run(); err != nil {
log.Tracef("set clipboard via pbcopy: %v", err)
}
}
// GetClipboard reads the macOS clipboard using pbpaste.
func (m *MacInputInjector) GetClipboard() string {
if m.pbpastePath == "" {
return ""
}
out, err := exec.Command(m.pbpastePath).Output()
if err != nil {
log.Tracef("get clipboard via pbpaste: %v", err)
return ""
}
return string(out)
}
// Close releases the idle-sleep assertion held for the injector's lifetime.
func (m *MacInputInjector) Close() {
releasePreventIdleSleep()
}
func keysymToMacKeycode(keysym uint32) uint16 {
if keysym >= 0x61 && keysym <= 0x7a {
return asciiToMacKey[keysym-0x61]
}
if keysym >= 0x41 && keysym <= 0x5a {
return asciiToMacKey[keysym-0x41]
}
if keysym >= 0x30 && keysym <= 0x39 {
return digitToMacKey[keysym-0x30]
}
if code, ok := specialKeyMap[keysym]; ok {
return code
}
return 0xFFFF
}
var asciiToMacKey = [26]uint16{
0x00, 0x0B, 0x08, 0x02, 0x0E, 0x03, 0x05, 0x04,
0x22, 0x26, 0x28, 0x25, 0x2E, 0x2D, 0x1F, 0x23,
0x0C, 0x0F, 0x01, 0x11, 0x20, 0x09, 0x0D, 0x07,
0x10, 0x06,
}
var digitToMacKey = [10]uint16{
0x1D, 0x12, 0x13, 0x14, 0x15, 0x17, 0x16, 0x1A, 0x1C, 0x19,
}
var specialKeyMap = map[uint32]uint16{
// Whitespace and editing
0x0020: 0x31, // space
0xff08: 0x33, // BackSpace
0xff09: 0x30, // Tab
0xff0d: 0x24, // Return
0xff1b: 0x35, // Escape
0xffff: 0x75, // Delete (forward)
// Navigation
0xff50: 0x73, // Home
0xff51: 0x7B, // Left
0xff52: 0x7E, // Up
0xff53: 0x7C, // Right
0xff54: 0x7D, // Down
0xff55: 0x74, // Page_Up
0xff56: 0x79, // Page_Down
0xff57: 0x77, // End
0xff63: 0x72, // Insert (Help on Mac)
// Modifiers
0xffe1: 0x38, // Shift_L
0xffe2: 0x3C, // Shift_R
0xffe3: 0x3B, // Control_L
0xffe4: 0x3E, // Control_R
0xffe5: 0x39, // Caps_Lock
0xffe9: 0x3A, // Alt_L (Option)
0xffea: 0x3D, // Alt_R (Option)
0xffe7: 0x37, // Meta_L (Command)
0xffe8: 0x36, // Meta_R (Command)
0xffeb: 0x37, // Super_L (Command) - noVNC sends this
0xffec: 0x36, // Super_R (Command)
// Mode_switch / ISO_Level3_Shift (sent by noVNC for macOS Option remap)
0xff7e: 0x3A, // Mode_switch -> Option
0xfe03: 0x3D, // ISO_Level3_Shift -> Right Option
// Function keys
0xffbe: 0x7A, // F1
0xffbf: 0x78, // F2
0xffc0: 0x63, // F3
0xffc1: 0x76, // F4
0xffc2: 0x60, // F5
0xffc3: 0x61, // F6
0xffc4: 0x62, // F7
0xffc5: 0x64, // F8
0xffc6: 0x65, // F9
0xffc7: 0x6D, // F10
0xffc8: 0x67, // F11
0xffc9: 0x6F, // F12
0xffca: 0x69, // F13
0xffcb: 0x6B, // F14
0xffcc: 0x71, // F15
0xffcd: 0x6A, // F16
0xffce: 0x40, // F17
0xffcf: 0x4F, // F18
0xffd0: 0x50, // F19
0xffd1: 0x5A, // F20
// Punctuation (US keyboard layout, keysym = ASCII code)
0x002d: 0x1B, // minus -
0x003d: 0x18, // equal =
0x005b: 0x21, // bracketleft [
0x005d: 0x1E, // bracketright ]
0x005c: 0x2A, // backslash
0x003b: 0x29, // semicolon ;
0x0027: 0x27, // apostrophe '
0x0060: 0x32, // grave `
0x002c: 0x2B, // comma ,
0x002e: 0x2F, // period .
0x002f: 0x2C, // slash /
// Shifted punctuation (noVNC sends these as separate keysyms)
0x005f: 0x1B, // underscore _ (shift+minus)
0x002b: 0x18, // plus + (shift+equal)
0x007b: 0x21, // braceleft { (shift+[)
0x007d: 0x1E, // braceright } (shift+])
0x007c: 0x2A, // bar | (shift+\)
0x003a: 0x29, // colon : (shift+;)
0x0022: 0x27, // quotedbl " (shift+')
0x007e: 0x32, // tilde ~ (shift+`)
0x003c: 0x2B, // less < (shift+,)
0x003e: 0x2F, // greater > (shift+.)
0x003f: 0x2C, // question ? (shift+/)
0x0021: 0x12, // exclam ! (shift+1)
0x0040: 0x13, // at @ (shift+2)
0x0023: 0x14, // numbersign # (shift+3)
0x0024: 0x15, // dollar $ (shift+4)
0x0025: 0x17, // percent % (shift+5)
0x005e: 0x16, // asciicircum ^ (shift+6)
0x0026: 0x1A, // ampersand & (shift+7)
0x002a: 0x1C, // asterisk * (shift+8)
0x0028: 0x19, // parenleft ( (shift+9)
0x0029: 0x1D, // parenright ) (shift+0)
// Numpad
0xffb0: 0x52, // KP_0
0xffb1: 0x53, // KP_1
0xffb2: 0x54, // KP_2
0xffb3: 0x55, // KP_3
0xffb4: 0x56, // KP_4
0xffb5: 0x57, // KP_5
0xffb6: 0x58, // KP_6
0xffb7: 0x59, // KP_7
0xffb8: 0x5B, // KP_8
0xffb9: 0x5C, // KP_9
0xffae: 0x41, // KP_Decimal
0xffaa: 0x43, // KP_Multiply
0xffab: 0x45, // KP_Add
0xffad: 0x4E, // KP_Subtract
0xffaf: 0x4B, // KP_Divide
0xff8d: 0x4C, // KP_Enter
0xffbd: 0x51, // KP_Equal
}
var _ InputInjector = (*MacInputInjector)(nil)
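The per-bit edge detection that `InjectPointer` performs on the RFB button mask (press on a 0→1 transition, release on 1→0) can be exercised in isolation. A self-contained sketch; `buttonEvents` is an illustrative helper, not part of the diff:

```go
package main

import "fmt"

// buttonEvents reports which RFB buttons changed state between two masks.
// RFB PointerEvent mask: bit 0 = left, bit 1 = middle, bit 2 = right.
func buttonEvents(prev, cur uint8) (pressed, released []string) {
	names := map[uint8]string{0x01: "left", 0x02: "middle", 0x04: "right"}
	for bit, name := range names {
		switch {
		case cur&bit != 0 && prev&bit == 0:
			pressed = append(pressed, name)
		case cur&bit == 0 && prev&bit != 0:
			released = append(released, name)
		}
	}
	return
}

func main() {
	p, r := buttonEvents(0x00, 0x01) // left goes down
	fmt.Println(p, r)                // [left] []
	p, r = buttonEvents(0x01, 0x04) // left released, right pressed
	fmt.Println(p, r)               // [right] [left]
}
```

Tracking `lastButtons` this way means a client that holds a button across many pointer messages generates exactly one down and one up event.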


@@ -1,398 +0,0 @@
//go:build windows
package server
import (
"runtime"
"unsafe"
log "github.com/sirupsen/logrus"
"golang.org/x/sys/windows"
)
var (
procOpenEventW = kernel32.NewProc("OpenEventW")
procSendInput = user32.NewProc("SendInput")
procVkKeyScanA = user32.NewProc("VkKeyScanA")
)
const eventModifyState = 0x0002
const (
inputMouse = 0
inputKeyboard = 1
mouseeventfMove = 0x0001
mouseeventfLeftDown = 0x0002
mouseeventfLeftUp = 0x0004
mouseeventfRightDown = 0x0008
mouseeventfRightUp = 0x0010
mouseeventfMiddleDown = 0x0020
mouseeventfMiddleUp = 0x0040
mouseeventfWheel = 0x0800
mouseeventfAbsolute = 0x8000
wheelDelta = 120
keyeventfKeyUp = 0x0002
keyeventfScanCode = 0x0008
)
type mouseInput struct {
Dx int32
Dy int32
MouseData uint32
DwFlags uint32
Time uint32
DwExtraInfo uintptr
}
type keybdInput struct {
WVk uint16
WScan uint16
DwFlags uint32
Time uint32
DwExtraInfo uintptr
_ [8]byte
}
type inputUnion [32]byte
type winInput struct {
Type uint32
_ [4]byte
Data inputUnion
}
func sendMouseInput(flags uint32, dx, dy int32, mouseData uint32) {
mi := mouseInput{
Dx: dx,
Dy: dy,
MouseData: mouseData,
DwFlags: flags,
}
inp := winInput{Type: inputMouse}
copy(inp.Data[:], (*[unsafe.Sizeof(mi)]byte)(unsafe.Pointer(&mi))[:])
r, _, err := procSendInput.Call(1, uintptr(unsafe.Pointer(&inp)), unsafe.Sizeof(inp))
if r == 0 {
log.Tracef("SendInput(mouse flags=0x%x): %v", flags, err)
}
}
func sendKeyInput(vk uint16, scanCode uint16, flags uint32) {
ki := keybdInput{
WVk: vk,
WScan: scanCode,
DwFlags: flags,
}
inp := winInput{Type: inputKeyboard}
copy(inp.Data[:], (*[unsafe.Sizeof(ki)]byte)(unsafe.Pointer(&ki))[:])
r, _, err := procSendInput.Call(1, uintptr(unsafe.Pointer(&inp)), unsafe.Sizeof(inp))
if r == 0 {
log.Tracef("SendInput(key vk=0x%x): %v", vk, err)
}
}
const sasEventName = `Global\NetBirdVNC_SAS`
type inputCmd struct {
isKey bool
keysym uint32
down bool
buttonMask uint8
x, y int
serverW int
serverH int
}
// WindowsInputInjector delivers input events from a dedicated OS thread that
// calls switchToInputDesktop before each injection. SendInput targets the
// calling thread's desktop, so the injection thread must be on the same
// desktop the user sees.
type WindowsInputInjector struct {
ch chan inputCmd
prevButtonMask uint8
ctrlDown bool
altDown bool
}
// NewWindowsInputInjector creates a desktop-aware input injector.
func NewWindowsInputInjector() *WindowsInputInjector {
w := &WindowsInputInjector{ch: make(chan inputCmd, 64)}
go w.loop()
return w
}
func (w *WindowsInputInjector) loop() {
runtime.LockOSThread()
for cmd := range w.ch {
// Switch to the current input desktop so SendInput reaches the right target.
switchToInputDesktop()
if cmd.isKey {
w.doInjectKey(cmd.keysym, cmd.down)
} else {
w.doInjectPointer(cmd.buttonMask, cmd.x, cmd.y, cmd.serverW, cmd.serverH)
}
}
}
// InjectKey queues a key event for injection on the input desktop thread.
func (w *WindowsInputInjector) InjectKey(keysym uint32, down bool) {
w.ch <- inputCmd{isKey: true, keysym: keysym, down: down}
}
// InjectPointer queues a pointer event for injection on the input desktop thread.
func (w *WindowsInputInjector) InjectPointer(buttonMask uint8, x, y, serverW, serverH int) {
w.ch <- inputCmd{buttonMask: buttonMask, x: x, y: y, serverW: serverW, serverH: serverH}
}
func (w *WindowsInputInjector) doInjectKey(keysym uint32, down bool) {
switch keysym {
case 0xffe3, 0xffe4:
w.ctrlDown = down
case 0xffe9, 0xffea:
w.altDown = down
}
if (keysym == 0xff9f || keysym == 0xffff) && w.ctrlDown && w.altDown && down {
signalSAS()
return
}
vk, _, extended := keysym2VK(keysym)
if vk == 0 {
return
}
var flags uint32
if !down {
flags |= keyeventfKeyUp
}
if extended {
flags |= 0x0001 // KEYEVENTF_EXTENDEDKEY; KEYEVENTF_SCANCODE would be wrong here since scanCode is always 0
}
sendKeyInput(vk, 0, flags)
}
// signalSAS signals the SAS named event. A listener in Session 0
// (startSASListener) calls SendSAS to trigger the Secure Attention Sequence.
func signalSAS() {
namePtr, err := windows.UTF16PtrFromString(sasEventName)
if err != nil {
log.Warnf("SAS UTF16: %v", err)
return
}
h, _, lerr := procOpenEventW.Call(
uintptr(eventModifyState),
0,
uintptr(unsafe.Pointer(namePtr)),
)
if h == 0 {
log.Warnf("OpenEvent(%s): %v", sasEventName, lerr)
return
}
ev := windows.Handle(h)
defer windows.CloseHandle(ev)
if err := windows.SetEvent(ev); err != nil {
log.Warnf("SetEvent SAS: %v", err)
} else {
log.Info("SAS event signaled")
}
}
func (w *WindowsInputInjector) doInjectPointer(buttonMask uint8, x, y, serverW, serverH int) {
if serverW == 0 || serverH == 0 {
return
}
absX := int32(x * 65535 / serverW)
absY := int32(y * 65535 / serverH)
sendMouseInput(mouseeventfMove|mouseeventfAbsolute, absX, absY, 0)
changed := buttonMask ^ w.prevButtonMask
w.prevButtonMask = buttonMask
type btnMap struct {
bit uint8
down uint32
up uint32
}
buttons := [...]btnMap{
{0x01, mouseeventfLeftDown, mouseeventfLeftUp},
{0x02, mouseeventfMiddleDown, mouseeventfMiddleUp},
{0x04, mouseeventfRightDown, mouseeventfRightUp},
}
for _, b := range buttons {
if changed&b.bit == 0 {
continue
}
var flags uint32
if buttonMask&b.bit != 0 {
flags = b.down
} else {
flags = b.up
}
sendMouseInput(flags|mouseeventfAbsolute, absX, absY, 0)
}
negWheelDelta := ^uint32(wheelDelta - 1) // two's-complement -120: one wheel notch down
if changed&0x08 != 0 && buttonMask&0x08 != 0 {
sendMouseInput(mouseeventfWheel|mouseeventfAbsolute, absX, absY, wheelDelta)
}
if changed&0x10 != 0 && buttonMask&0x10 != 0 {
sendMouseInput(mouseeventfWheel|mouseeventfAbsolute, absX, absY, negWheelDelta)
}
}
// keysym2VK converts an X11 keysym to a Windows virtual key code.
func keysym2VK(keysym uint32) (vk uint16, scan uint16, extended bool) {
if keysym >= 0x20 && keysym <= 0x7e {
r, _, _ := procVkKeyScanA.Call(uintptr(keysym))
if r&0xffff == 0xffff {
return // VkKeyScanA found no mapping; vk stays 0 and the caller skips the event
}
vk = uint16(r & 0xff)
return
}
if keysym >= 0xffbe && keysym <= 0xffc9 {
vk = uint16(0x70 + keysym - 0xffbe)
return
}
switch keysym {
case 0xff08:
vk = 0x08 // Backspace
case 0xff09:
vk = 0x09 // Tab
case 0xff0d:
vk = 0x0d // Return
case 0xff1b:
vk = 0x1b // Escape
case 0xff63:
vk, extended = 0x2d, true // Insert
case 0xff9f, 0xffff:
vk, extended = 0x2e, true // Delete
case 0xff50:
vk, extended = 0x24, true // Home
case 0xff57:
vk, extended = 0x23, true // End
case 0xff55:
vk, extended = 0x21, true // PageUp
case 0xff56:
vk, extended = 0x22, true // PageDown
case 0xff51:
vk, extended = 0x25, true // Left
case 0xff52:
vk, extended = 0x26, true // Up
case 0xff53:
vk, extended = 0x27, true // Right
case 0xff54:
vk, extended = 0x28, true // Down
case 0xffe1, 0xffe2:
vk = 0x10 // Shift
case 0xffe3, 0xffe4:
vk = 0x11 // Control
case 0xffe9, 0xffea:
vk = 0x12 // Alt
case 0xffe5:
vk = 0x14 // CapsLock
case 0xffe7, 0xffeb:
vk, extended = 0x5B, true // Meta_L / Super_L -> Left Windows
case 0xffe8, 0xffec:
vk, extended = 0x5C, true // Meta_R / Super_R -> Right Windows
case 0xff61:
vk = 0x2c // PrintScreen
case 0xff13:
vk = 0x13 // Pause
case 0xff14:
vk = 0x91 // ScrollLock
}
return
}
var (
procOpenClipboard = user32.NewProc("OpenClipboard")
procCloseClipboard = user32.NewProc("CloseClipboard")
procEmptyClipboard = user32.NewProc("EmptyClipboard")
procSetClipboardData = user32.NewProc("SetClipboardData")
procGetClipboardData = user32.NewProc("GetClipboardData")
procIsClipboardFormatAvailable = user32.NewProc("IsClipboardFormatAvailable")
procGlobalAlloc = kernel32.NewProc("GlobalAlloc")
procGlobalLock = kernel32.NewProc("GlobalLock")
procGlobalUnlock = kernel32.NewProc("GlobalUnlock")
)
const (
cfUnicodeText = 13
gmemMoveable = 0x0002
)
// SetClipboard sets the Windows clipboard to the given UTF-8 text.
func (w *WindowsInputInjector) SetClipboard(text string) {
utf16, err := windows.UTF16FromString(text)
if err != nil {
log.Tracef("clipboard UTF16 encode: %v", err)
return
}
size := uintptr(len(utf16) * 2)
hMem, _, _ := procGlobalAlloc.Call(gmemMoveable, size)
if hMem == 0 {
log.Tracef("GlobalAlloc for clipboard: allocation returned nil")
return
}
ptr, _, _ := procGlobalLock.Call(hMem)
if ptr == 0 {
log.Tracef("GlobalLock for clipboard: lock returned nil")
return
}
copy(unsafe.Slice((*uint16)(unsafe.Pointer(ptr)), len(utf16)), utf16)
procGlobalUnlock.Call(hMem)
r, _, lerr := procOpenClipboard.Call(0)
if r == 0 {
log.Tracef("OpenClipboard: %v", lerr)
return
}
defer procCloseClipboard.Call()
procEmptyClipboard.Call()
r, _, lerr = procSetClipboardData.Call(cfUnicodeText, hMem)
if r == 0 {
log.Tracef("SetClipboardData: %v", lerr)
}
}
// GetClipboard reads the Windows clipboard as UTF-8 text.
func (w *WindowsInputInjector) GetClipboard() string {
r, _, _ := procIsClipboardFormatAvailable.Call(cfUnicodeText)
if r == 0 {
return ""
}
r, _, lerr := procOpenClipboard.Call(0)
if r == 0 {
log.Tracef("OpenClipboard for read: %v", lerr)
return ""
}
defer procCloseClipboard.Call()
hData, _, _ := procGetClipboardData.Call(cfUnicodeText)
if hData == 0 {
return ""
}
ptr, _, _ := procGlobalLock.Call(hData)
if ptr == 0 {
return ""
}
defer procGlobalUnlock.Call(hData)
return windows.UTF16PtrToString((*uint16)(unsafe.Pointer(ptr)))
}
var _ InputInjector = (*WindowsInputInjector)(nil)
var _ ScreenCapturer = (*DesktopCapturer)(nil)
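`doInjectPointer` scales framebuffer coordinates into the normalized 0..65535 absolute range that SendInput expects with MOUSEEVENTF_ABSOLUTE. A minimal self-contained sketch of that scaling; `toAbs` is an illustrative name:

```go
package main

import "fmt"

// toAbs maps a coordinate in a framebuffer of the given extent to the
// normalized 0..65535 range used by SendInput with MOUSEEVENTF_ABSOLUTE.
func toAbs(v, extent int) int32 {
	if extent == 0 {
		return 0 // mirror the serverW/serverH guard in doInjectPointer
	}
	return int32(v * 65535 / extent)
}

func main() {
	fmt.Println(toAbs(0, 1920), toAbs(960, 1920), toAbs(1919, 1920)) // 0 32767 65500
}
```

Because the range is normalized, the same client coordinates land on the correct point regardless of the target desktop's actual resolution.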


@@ -1,242 +0,0 @@
//go:build (linux && !android) || freebsd
package server
import (
"fmt"
"os"
"os/exec"
"strings"
log "github.com/sirupsen/logrus"
"github.com/jezek/xgb"
"github.com/jezek/xgb/xproto"
"github.com/jezek/xgb/xtest"
)
// X11InputInjector injects keyboard and mouse events via the XTest extension.
type X11InputInjector struct {
conn *xgb.Conn
root xproto.Window
screen *xproto.ScreenInfo
display string
keysymMap map[uint32]byte
lastButtons uint8
clipboardTool string
clipboardToolName string
}
// NewX11InputInjector connects to the X11 display and initializes XTest.
func NewX11InputInjector(display string) (*X11InputInjector, error) {
detectX11Display()
if display == "" {
display = os.Getenv("DISPLAY")
}
if display == "" {
return nil, fmt.Errorf("DISPLAY not set and no Xorg process found")
}
conn, err := xgb.NewConnDisplay(display)
if err != nil {
return nil, fmt.Errorf("connect to X11 display %s: %w", display, err)
}
if err := xtest.Init(conn); err != nil {
conn.Close()
return nil, fmt.Errorf("init XTest extension: %w", err)
}
setup := xproto.Setup(conn)
if len(setup.Roots) == 0 {
conn.Close()
return nil, fmt.Errorf("no X11 screens")
}
screen := setup.Roots[0]
inj := &X11InputInjector{
conn: conn,
root: screen.Root,
screen: &screen,
display: display,
}
inj.cacheKeyboardMapping()
inj.resolveClipboardTool()
log.Infof("X11 input injector ready (display=%s)", display)
return inj, nil
}
// InjectKey simulates a key press or release. keysym is an X11 KeySym.
func (x *X11InputInjector) InjectKey(keysym uint32, down bool) {
keycode := x.keysymToKeycode(keysym)
if keycode == 0 {
return
}
var eventType byte
if down {
eventType = xproto.KeyPress
} else {
eventType = xproto.KeyRelease
}
xtest.FakeInput(x.conn, eventType, keycode, 0, x.root, 0, 0, 0)
}
// InjectPointer simulates mouse movement and button events.
func (x *X11InputInjector) InjectPointer(buttonMask uint8, px, py, serverW, serverH int) {
if serverW == 0 || serverH == 0 {
return
}
// Scale to actual screen coordinates.
screenW := int(x.screen.WidthInPixels)
screenH := int(x.screen.HeightInPixels)
absX := px * screenW / serverW
absY := py * screenH / serverH
// Move pointer.
xtest.FakeInput(x.conn, xproto.MotionNotify, 0, 0, x.root, int16(absX), int16(absY), 0)
// Handle button events. RFB button mask: bit0=left, bit1=middle, bit2=right,
// bit3=scrollUp, bit4=scrollDown. X11 buttons: 1=left, 2=middle, 3=right,
// 4=scrollUp, 5=scrollDown.
type btnMap struct {
rfbBit uint8
x11Btn byte
}
buttons := [...]btnMap{
{0x01, 1}, // left
{0x02, 2}, // middle
{0x04, 3}, // right
{0x08, 4}, // scroll up
{0x10, 5}, // scroll down
}
for _, b := range buttons {
pressed := buttonMask&b.rfbBit != 0
wasPressed := x.lastButtons&b.rfbBit != 0
if b.x11Btn >= 4 {
// Scroll: send press+release on each scroll event.
if pressed {
xtest.FakeInput(x.conn, xproto.ButtonPress, b.x11Btn, 0, x.root, 0, 0, 0)
xtest.FakeInput(x.conn, xproto.ButtonRelease, b.x11Btn, 0, x.root, 0, 0, 0)
}
} else {
if pressed && !wasPressed {
xtest.FakeInput(x.conn, xproto.ButtonPress, b.x11Btn, 0, x.root, 0, 0, 0)
} else if !pressed && wasPressed {
xtest.FakeInput(x.conn, xproto.ButtonRelease, b.x11Btn, 0, x.root, 0, 0, 0)
}
}
}
x.lastButtons = buttonMask
}
// cacheKeyboardMapping fetches the X11 keyboard mapping once and stores it
// as a keysym-to-keycode map, avoiding a round-trip per keystroke.
func (x *X11InputInjector) cacheKeyboardMapping() {
setup := xproto.Setup(x.conn)
minKeycode := setup.MinKeycode
maxKeycode := setup.MaxKeycode
reply, err := xproto.GetKeyboardMapping(x.conn, minKeycode,
byte(maxKeycode-minKeycode+1)).Reply()
if err != nil {
log.Debugf("cache keyboard mapping: %v", err)
x.keysymMap = make(map[uint32]byte)
return
}
m := make(map[uint32]byte, int(maxKeycode-minKeycode+1)*int(reply.KeysymsPerKeycode))
keysymsPerKeycode := int(reply.KeysymsPerKeycode)
for i := int(minKeycode); i <= int(maxKeycode); i++ {
offset := (i - int(minKeycode)) * keysymsPerKeycode
for j := 0; j < keysymsPerKeycode; j++ {
ks := uint32(reply.Keysyms[offset+j])
if ks != 0 {
if _, exists := m[ks]; !exists {
m[ks] = byte(i)
}
}
}
}
x.keysymMap = m
}
// keysymToKeycode looks up a cached keysym-to-keycode mapping.
// Returns 0 if the keysym is not mapped.
func (x *X11InputInjector) keysymToKeycode(keysym uint32) byte {
return x.keysymMap[keysym]
}
// SetClipboard sets the X11 clipboard using xclip or xsel.
func (x *X11InputInjector) SetClipboard(text string) {
if x.clipboardTool == "" {
return
}
var cmd *exec.Cmd
if x.clipboardToolName == "xclip" {
cmd = exec.Command(x.clipboardTool, "-selection", "clipboard")
} else {
cmd = exec.Command(x.clipboardTool, "--clipboard", "--input")
}
cmd.Env = x.clipboardEnv()
cmd.Stdin = strings.NewReader(text)
if err := cmd.Run(); err != nil {
log.Debugf("set clipboard via %s: %v", x.clipboardToolName, err)
}
}
func (x *X11InputInjector) resolveClipboardTool() {
for _, name := range []string{"xclip", "xsel"} {
path, err := exec.LookPath(name)
if err == nil {
x.clipboardTool = path
x.clipboardToolName = name
log.Debugf("clipboard tool resolved to %s", path)
return
}
}
log.Debugf("no clipboard tool (xclip/xsel) found, clipboard sync disabled")
}
// GetClipboard reads the X11 clipboard using xclip or xsel.
func (x *X11InputInjector) GetClipboard() string {
if x.clipboardTool == "" {
return ""
}
var cmd *exec.Cmd
if x.clipboardToolName == "xclip" {
cmd = exec.Command(x.clipboardTool, "-selection", "clipboard", "-o")
} else {
cmd = exec.Command(x.clipboardTool, "--clipboard", "--output")
}
cmd.Env = x.clipboardEnv()
out, err := cmd.Output()
if err != nil {
log.Tracef("get clipboard via %s: %v", x.clipboardToolName, err)
return ""
}
return string(out)
}
func (x *X11InputInjector) clipboardEnv() []string {
env := []string{"DISPLAY=" + x.display}
if auth := os.Getenv("XAUTHORITY"); auth != "" {
env = append(env, "XAUTHORITY="+auth)
}
return env
}
// Close releases X11 resources.
func (x *X11InputInjector) Close() {
x.conn.Close()
}
var _ InputInjector = (*X11InputInjector)(nil)
var _ ScreenCapturer = (*X11Poller)(nil)
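`cacheKeyboardMapping` indexes GetKeyboardMapping's flat keysym array as `(keycode - minKeycode) * keysymsPerKeycode + column`. A small sketch of that indexing with made-up mapping data (`keysymAt` and the sample values are illustrative, not from the diff):

```go
package main

import "fmt"

// keysymAt indexes the flat keysym array returned by X11's GetKeyboardMapping:
// one row of perKeycode entries per keycode, rows starting at minKeycode.
func keysymAt(keysyms []uint32, minKeycode, perKeycode, keycode, column int) uint32 {
	return keysyms[(keycode-minKeycode)*perKeycode+column]
}

func main() {
	// Hypothetical mapping: keycodes 8 and 9, two keysym columns each
	// (unshifted, shifted), i.e. 'a'/'A' and 'b'/'B'.
	syms := []uint32{0x61, 0x41, 0x62, 0x42}
	fmt.Printf("%c %c\n", keysymAt(syms, 8, 2, 9, 0), keysymAt(syms, 8, 2, 9, 1)) // b B
}
```

The injector inverts this layout once at startup into a keysym→keycode map, taking the first keycode seen for each keysym, so keystroke injection needs no X11 round-trip.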


@@ -1,175 +0,0 @@
package server
import (
"bytes"
"encoding/base64"
"encoding/binary"
"encoding/json"
"fmt"
"image"
"image/png"
"os"
"path/filepath"
"sync"
"time"
log "github.com/sirupsen/logrus"
)
// Recording file format:
//
// Header: magic(6) + width(2) + height(2) + startTime(8) + metaLen(4) + metaJSON
// Frames: offsetMs(4) + pngLen(4) + PNG image data
//
// Each frame is a PNG-encoded screenshot. Only changed frames are stored.
const recMagic = "NBVNC\x01"
// RecordingMeta holds metadata written to the recording file header.
type RecordingMeta struct {
User string `json:"user,omitempty"`
RemoteAddr string `json:"remote_addr"`
JWTUser string `json:"jwt_user,omitempty"`
Mode string `json:"mode,omitempty"`
Encrypted bool `json:"encrypted,omitempty"`
EphemeralKey string `json:"ephemeral_key,omitempty"`
}
// vncRecorder writes VNC session frames to a recording file.
type vncRecorder struct {
mu sync.Mutex
file *os.File
startTime time.Time
closed bool
log *log.Entry
prevFrame *image.RGBA
pngEnc *png.Encoder
pngBuf bytes.Buffer
crypto *recCrypto
}
func newVNCRecorder(dir string, width, height int, meta *RecordingMeta, encryptionKey string, logger *log.Entry) (*vncRecorder, error) {
if err := os.MkdirAll(dir, 0700); err != nil {
return nil, fmt.Errorf("create recording dir: %w", err)
}
now := time.Now().UTC()
filename := fmt.Sprintf("%s_vnc.rec", now.Format("20060102-150405"))
filePath := filepath.Join(dir, filename)
f, err := os.OpenFile(filePath, os.O_CREATE|os.O_WRONLY|os.O_EXCL, 0600)
if err != nil {
return nil, fmt.Errorf("create recording file: %w", err)
}
var crypto *recCrypto
if encryptionKey != "" {
var cryptoErr error
crypto, cryptoErr = newRecCrypto(encryptionKey)
if cryptoErr != nil {
f.Close()
os.Remove(filePath)
return nil, fmt.Errorf("init encryption: %w", cryptoErr)
}
meta.Encrypted = true
meta.EphemeralKey = base64.StdEncoding.EncodeToString(crypto.ephemeralPub)
}
metaJSON, err := json.Marshal(meta)
if err != nil {
f.Close()
os.Remove(filePath)
return nil, fmt.Errorf("marshal meta: %w", err)
}
var hdr [6 + 2 + 2 + 8 + 4]byte
copy(hdr[:6], recMagic)
binary.BigEndian.PutUint16(hdr[6:8], uint16(width))
binary.BigEndian.PutUint16(hdr[8:10], uint16(height))
binary.BigEndian.PutUint64(hdr[10:18], uint64(now.UnixMilli()))
binary.BigEndian.PutUint32(hdr[18:22], uint32(len(metaJSON)))
if _, err := f.Write(hdr[:]); err != nil {
f.Close()
os.Remove(filePath)
return nil, fmt.Errorf("write header: %w", err)
}
if _, err := f.Write(metaJSON); err != nil {
f.Close()
os.Remove(filePath)
return nil, fmt.Errorf("write meta: %w", err)
}
r := &vncRecorder{
file: f,
startTime: now,
log: logger.WithField("recording", filepath.Base(filePath)),
pngEnc: &png.Encoder{CompressionLevel: png.BestSpeed},
crypto: crypto,
}
if crypto != nil {
r.log.Infof("VNC recording started (encrypted): %s", filePath)
} else {
r.log.Infof("VNC recording started: %s", filePath)
}
return r, nil
}
// writeFrame records a screen frame. Only writes if the frame differs from the previous one.
func (r *vncRecorder) writeFrame(img *image.RGBA) {
r.mu.Lock()
defer r.mu.Unlock()
if r.closed {
return
}
if r.prevFrame != nil && bytes.Equal(r.prevFrame.Pix, img.Pix) {
return
}
offsetMs := uint32(time.Since(r.startTime).Milliseconds())
r.pngBuf.Reset()
if err := r.pngEnc.Encode(&r.pngBuf, img); err != nil {
r.log.Debugf("encode PNG frame: %v", err)
return
}
frameData := r.pngBuf.Bytes()
if r.crypto != nil {
frameData = r.crypto.encrypt(frameData)
}
var frameHdr [8]byte
binary.BigEndian.PutUint32(frameHdr[0:4], offsetMs)
binary.BigEndian.PutUint32(frameHdr[4:8], uint32(len(frameData)))
if _, err := r.file.Write(frameHdr[:]); err != nil {
r.log.Debugf("write frame header: %v", err)
return
}
if _, err := r.file.Write(frameData); err != nil {
r.log.Debugf("write frame data: %v", err)
return
}
if r.prevFrame == nil {
r.prevFrame = image.NewRGBA(img.Rect)
}
copy(r.prevFrame.Pix, img.Pix)
}
func (r *vncRecorder) close() {
r.mu.Lock()
defer r.mu.Unlock()
if r.closed {
return
}
r.closed = true
duration := time.Since(r.startTime)
r.log.Infof("VNC recording stopped after %v", duration.Round(time.Millisecond))
r.file.Close()
}


@@ -1,202 +0,0 @@
package server
import (
"crypto/ecdh"
"crypto/rand"
"encoding/base64"
"image"
"image/color"
"os"
"path/filepath"
"testing"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func makeTestImage(w, h int, c color.RGBA) *image.RGBA {
img := image.NewRGBA(image.Rect(0, 0, w, h))
for i := 0; i < len(img.Pix); i += 4 {
img.Pix[i] = c.R
img.Pix[i+1] = c.G
img.Pix[i+2] = c.B
img.Pix[i+3] = c.A
}
return img
}
func TestRecorderWriteAndReadHeader(t *testing.T) {
dir := t.TempDir()
logger := log.WithField("test", t.Name())
meta := &RecordingMeta{
User: "alice",
RemoteAddr: "100.0.1.5:12345",
JWTUser: "google|123",
Mode: "session",
}
rec, err := newVNCRecorder(dir, 800, 600, meta, "", logger)
require.NoError(t, err)
// Write some frames
red := makeTestImage(800, 600, color.RGBA{255, 0, 0, 255})
blue := makeTestImage(800, 600, color.RGBA{0, 0, 255, 255})
rec.writeFrame(red)
rec.writeFrame(red) // duplicate, should be skipped
rec.writeFrame(blue)
rec.close()
// Read back the header
files, err := os.ReadDir(dir)
require.NoError(t, err)
require.Len(t, files, 1)
filePath := filepath.Join(dir, files[0].Name())
header, err := ReadRecordingHeader(filePath)
require.NoError(t, err)
assert.Equal(t, 800, header.Width)
assert.Equal(t, 600, header.Height)
assert.Equal(t, "alice", header.Meta.User)
assert.Equal(t, "100.0.1.5:12345", header.Meta.RemoteAddr)
assert.Equal(t, "google|123", header.Meta.JWTUser)
assert.Equal(t, "session", header.Meta.Mode)
assert.False(t, header.Meta.Encrypted)
// Verify file is valid by checking size is reasonable
fi, err := os.Stat(filePath)
require.NoError(t, err)
assert.Greater(t, fi.Size(), int64(100), "recording should have content")
}
func TestRecorderDuplicateFrameSkip(t *testing.T) {
dir := t.TempDir()
logger := log.WithField("test", t.Name())
rec, err := newVNCRecorder(dir, 100, 100, &RecordingMeta{RemoteAddr: "test"}, "", logger)
require.NoError(t, err)
img := makeTestImage(100, 100, color.RGBA{128, 128, 128, 255})
rec.writeFrame(img)
rec.writeFrame(img) // duplicate
rec.writeFrame(img) // duplicate
rec.close()
files, _ := os.ReadDir(dir)
filePath := filepath.Join(dir, files[0].Name())
// Count frames by parsing
f, err := os.Open(filePath)
require.NoError(t, err)
defer f.Close()
_, err = parseRecHeader(f)
require.NoError(t, err)
frameCount := 0
var hdr [8]byte
for {
if _, err := f.Read(hdr[:]); err != nil {
break
}
pngLen := int64(hdr[4])<<24 | int64(hdr[5])<<16 | int64(hdr[6])<<8 | int64(hdr[7])
f.Seek(pngLen, 1)
frameCount++
}
assert.Equal(t, 1, frameCount, "duplicate frames should be skipped")
}
func TestRecorderEncrypted(t *testing.T) {
dir := t.TempDir()
logger := log.WithField("test", t.Name())
// Generate admin keypair
adminPriv, err := ecdh.X25519().GenerateKey(rand.Reader)
require.NoError(t, err)
adminPubB64 := base64.StdEncoding.EncodeToString(adminPriv.PublicKey().Bytes())
meta := &RecordingMeta{
RemoteAddr: "100.0.1.5:12345",
Mode: "attach",
}
rec, err := newVNCRecorder(dir, 200, 150, meta, adminPubB64, logger)
require.NoError(t, err)
img := makeTestImage(200, 150, color.RGBA{255, 0, 0, 255})
rec.writeFrame(img)
rec.close()
// Read header and verify encryption metadata
files, _ := os.ReadDir(dir)
filePath := filepath.Join(dir, files[0].Name())
header, err := ReadRecordingHeader(filePath)
require.NoError(t, err)
assert.True(t, header.Meta.Encrypted)
assert.NotEmpty(t, header.Meta.EphemeralKey)
assert.Equal(t, 200, header.Width)
assert.Equal(t, 150, header.Height)
}
func TestRecorderEncryptedDecryptRoundtrip(t *testing.T) {
dir := t.TempDir()
logger := log.WithField("test", t.Name())
adminPriv, err := ecdh.X25519().GenerateKey(rand.Reader)
require.NoError(t, err)
adminPubB64 := base64.StdEncoding.EncodeToString(adminPriv.PublicKey().Bytes())
adminPrivB64 := base64.StdEncoding.EncodeToString(adminPriv.Bytes())
rec, err := newVNCRecorder(dir, 100, 100, &RecordingMeta{RemoteAddr: "test"}, adminPubB64, logger)
require.NoError(t, err)
red := makeTestImage(100, 100, color.RGBA{255, 0, 0, 255})
green := makeTestImage(100, 100, color.RGBA{0, 255, 0, 255})
rec.writeFrame(red)
rec.writeFrame(green)
rec.close()
// Read back and decrypt
files, _ := os.ReadDir(dir)
filePath := filepath.Join(dir, files[0].Name())
header, err := ReadRecordingHeader(filePath)
require.NoError(t, err)
require.True(t, header.Meta.Encrypted)
dec, err := DecryptRecording(adminPrivB64, header.Meta.EphemeralKey)
require.NoError(t, err)
// Read raw frames and decrypt
f, err := os.Open(filePath)
require.NoError(t, err)
defer f.Close()
_, err = parseRecHeader(f)
require.NoError(t, err)
decryptedFrames := 0
var hdr [8]byte
for {
if _, readErr := f.Read(hdr[:]); readErr != nil {
break
}
frameLen := int(hdr[4])<<24 | int(hdr[5])<<16 | int(hdr[6])<<8 | int(hdr[7])
ct := make([]byte, frameLen)
if _, readErr := f.Read(ct); readErr != nil {
break
}
_, err := dec.Decrypt(ct)
require.NoError(t, err, "frame %d decrypt should succeed", decryptedFrames)
decryptedFrames++
}
assert.Equal(t, 2, decryptedFrames)
}


@@ -1,64 +0,0 @@
package server
import (
"encoding/binary"
"encoding/json"
"fmt"
"io"
"os"
"time"
)
// RecordingHeader holds parsed header data from a VNC recording file.
type RecordingHeader struct {
Width int
Height int
StartTime time.Time
Meta RecordingMeta
}
// ReadRecordingHeader parses and returns the recording header without loading frames.
func ReadRecordingHeader(filePath string) (*RecordingHeader, error) {
f, err := os.Open(filePath)
if err != nil {
return nil, err
}
defer f.Close()
return parseRecHeader(f)
}
func parseRecHeader(r io.Reader) (*RecordingHeader, error) {
var hdr [22]byte
if _, err := io.ReadFull(r, hdr[:]); err != nil {
return nil, fmt.Errorf("read header: %w", err)
}
if string(hdr[:6]) != recMagic {
return nil, fmt.Errorf("invalid magic: %x", hdr[:6])
}
width := int(binary.BigEndian.Uint16(hdr[6:8]))
height := int(binary.BigEndian.Uint16(hdr[8:10]))
startMs := int64(binary.BigEndian.Uint64(hdr[10:18]))
metaLen := binary.BigEndian.Uint32(hdr[18:22])
if metaLen > 1<<20 {
return nil, fmt.Errorf("meta too large: %d bytes", metaLen)
}
metaJSON := make([]byte, metaLen)
if _, err := io.ReadFull(r, metaJSON); err != nil {
return nil, fmt.Errorf("read meta: %w", err)
}
var meta RecordingMeta
if err := json.Unmarshal(metaJSON, &meta); err != nil {
return nil, fmt.Errorf("parse meta: %w", err)
}
return &RecordingHeader{
Width: width,
Height: height,
StartTime: time.UnixMilli(startMs),
Meta: meta,
}, nil
}


@@ -1,264 +0,0 @@
package server
import (
"bytes"
"compress/zlib"
"crypto/des"
"encoding/binary"
"image"
)
const (
rfbProtocolVersion = "RFB 003.008\n"
secNone = 1
secVNCAuth = 2
// Client message types.
clientSetPixelFormat = 0
clientSetEncodings = 2
clientFramebufferUpdateRequest = 3
clientKeyEvent = 4
clientPointerEvent = 5
clientCutText = 6
// Server message types.
serverFramebufferUpdate = 0
serverCutText = 3
// Encoding types.
encRaw = 0
encZlib = 6
)
// serverPixelFormat is the default pixel format advertised by the server:
// 32bpp RGBA, big-endian, true-colour, 8 bits per channel.
var serverPixelFormat = [16]byte{
32, // bits-per-pixel
24, // depth
1, // big-endian-flag
1, // true-colour-flag
0, 255, // red-max
0, 255, // green-max
0, 255, // blue-max
16, // red-shift
8, // green-shift
0, // blue-shift
0, 0, 0, // padding
}
// clientPixelFormat holds the negotiated pixel format from the client.
type clientPixelFormat struct {
bpp uint8
bigEndian uint8
rMax uint16
gMax uint16
bMax uint16
rShift uint8
gShift uint8
bShift uint8
}
func defaultClientPixelFormat() clientPixelFormat {
return clientPixelFormat{
bpp: serverPixelFormat[0],
bigEndian: serverPixelFormat[2],
rMax: binary.BigEndian.Uint16(serverPixelFormat[4:6]),
gMax: binary.BigEndian.Uint16(serverPixelFormat[6:8]),
bMax: binary.BigEndian.Uint16(serverPixelFormat[8:10]),
rShift: serverPixelFormat[10],
gShift: serverPixelFormat[11],
bShift: serverPixelFormat[12],
}
}
func parsePixelFormat(pf []byte) clientPixelFormat {
return clientPixelFormat{
bpp: pf[0],
bigEndian: pf[2],
rMax: binary.BigEndian.Uint16(pf[4:6]),
gMax: binary.BigEndian.Uint16(pf[6:8]),
bMax: binary.BigEndian.Uint16(pf[8:10]),
rShift: pf[10],
gShift: pf[11],
bShift: pf[12],
}
}
// encodeRawRect encodes a framebuffer region as a raw RFB rectangle.
// The returned buffer includes the FramebufferUpdate header (1 rectangle).
func encodeRawRect(img *image.RGBA, pf clientPixelFormat, x, y, w, h int) []byte {
bytesPerPixel := max(int(pf.bpp)/8, 1)
pixelBytes := w * h * bytesPerPixel
buf := make([]byte, 4+12+pixelBytes)
// FramebufferUpdate header.
buf[0] = serverFramebufferUpdate
buf[1] = 0 // padding
binary.BigEndian.PutUint16(buf[2:4], 1)
// Rectangle header.
binary.BigEndian.PutUint16(buf[4:6], uint16(x))
binary.BigEndian.PutUint16(buf[6:8], uint16(y))
binary.BigEndian.PutUint16(buf[8:10], uint16(w))
binary.BigEndian.PutUint16(buf[10:12], uint16(h))
binary.BigEndian.PutUint32(buf[12:16], uint32(encRaw))
off := 16
stride := img.Stride
for row := y; row < y+h; row++ {
for col := x; col < x+w; col++ {
p := row*stride + col*4
r, g, b := img.Pix[p], img.Pix[p+1], img.Pix[p+2]
rv := uint32(r) * uint32(pf.rMax) / 255
gv := uint32(g) * uint32(pf.gMax) / 255
bv := uint32(b) * uint32(pf.bMax) / 255
pixel := (rv << pf.rShift) | (gv << pf.gShift) | (bv << pf.bShift)
if pf.bigEndian != 0 {
for i := range bytesPerPixel {
buf[off+i] = byte(pixel >> uint((bytesPerPixel-1-i)*8))
}
} else {
for i := range bytesPerPixel {
buf[off+i] = byte(pixel >> uint(i*8))
}
}
off += bytesPerPixel
}
}
return buf
}
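The per-pixel scale-and-shift math above can be checked in isolation. The sketch below extracts it into a hypothetical packPixel helper (not part of the package) and packs a pure red pixel under the server's default format (255-max channels, shifts 16/8/0) and under a 5-6-5 16bpp format:

```go
package main

import "fmt"

// packPixel scales 8-bit RGB into the client's channel ranges and shifts
// each channel into place, mirroring the inner loop of encodeRawRect.
func packPixel(r, g, b uint8, rMax, gMax, bMax uint16, rShift, gShift, bShift uint8) uint32 {
	rv := uint32(r) * uint32(rMax) / 255
	gv := uint32(g) * uint32(gMax) / 255
	bv := uint32(b) * uint32(bMax) / 255
	return (rv << rShift) | (gv << gShift) | (bv << bShift)
}

func main() {
	// Default server format: 8 bits per channel, shifts 16/8/0.
	fmt.Printf("%#x\n", packPixel(255, 0, 0, 255, 255, 255, 16, 8, 0)) // 0xff0000
	// A 5-6-5 16bpp client format: white fills all 16 bits.
	fmt.Printf("%#x\n", packPixel(255, 255, 255, 31, 63, 31, 11, 5, 0)) // 0xffff
}
```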
// vncAuthEncrypt encrypts a 16-byte challenge using the VNC DES scheme.
func vncAuthEncrypt(challenge []byte, password string) []byte {
key := make([]byte, 8)
for i, c := range []byte(password) {
if i >= 8 {
break
}
key[i] = reverseBits(c)
}
// The key is always exactly 8 bytes, so DES key setup cannot fail.
block, _ := des.NewCipher(key)
out := make([]byte, 16)
block.Encrypt(out[:8], challenge[:8])
block.Encrypt(out[8:], challenge[8:])
return out
}
func reverseBits(b byte) byte {
var r byte
for range 8 {
r = (r << 1) | (b & 1)
b >>= 1
}
return r
}
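The bit reversal is the one non-obvious part of VNC Authentication: the original RFB reference implementation fed DES key bytes in least-significant-bit first, so every conforming implementation must mirror each password byte. A standalone sketch of the key derivation (vncDESKey is an illustrative helper, not part of the package):

```go
package main

import "fmt"

// reverseBits mirrors the bit order of a byte (0b1000_0000 -> 0b0000_0001).
func reverseBits(b byte) byte {
	var r byte
	for i := 0; i < 8; i++ {
		r = (r << 1) | (b & 1)
		b >>= 1
	}
	return r
}

// vncDESKey derives the 8-byte DES key from a VNC password: truncate or
// zero-pad to 8 bytes, then mirror the bits of each byte.
func vncDESKey(password string) [8]byte {
	var key [8]byte
	for i := 0; i < len(password) && i < 8; i++ {
		key[i] = reverseBits(password[i])
	}
	return key
}

func main() {
	key := vncDESKey("ab")
	fmt.Printf("%x\n", key[:]) // 'a'=0x61 -> 0x86, 'b'=0x62 -> 0x46, rest zero-padded
}
```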
// encodeZlibRect encodes a framebuffer region using Zlib compression.
// The zlib stream is continuous for the entire VNC session: noVNC creates
// one inflate context at startup and reuses it for all zlib-encoded rects.
// We must NOT reset the zlib writer between calls.
func encodeZlibRect(img *image.RGBA, pf clientPixelFormat, x, y, w, h int, zw *zlib.Writer, zbuf *bytes.Buffer) []byte {
bytesPerPixel := max(int(pf.bpp)/8, 1)
// Clear the output buffer but keep the deflate dictionary intact.
zbuf.Reset()
stride := img.Stride
pixel := make([]byte, bytesPerPixel)
for row := y; row < y+h; row++ {
for col := x; col < x+w; col++ {
p := row*stride + col*4
r, g, b := img.Pix[p], img.Pix[p+1], img.Pix[p+2]
rv := uint32(r) * uint32(pf.rMax) / 255
gv := uint32(g) * uint32(pf.gMax) / 255
bv := uint32(b) * uint32(pf.bMax) / 255
val := (rv << pf.rShift) | (gv << pf.gShift) | (bv << pf.bShift)
if pf.bigEndian != 0 {
for i := range bytesPerPixel {
pixel[i] = byte(val >> uint((bytesPerPixel-1-i)*8))
}
} else {
for i := range bytesPerPixel {
pixel[i] = byte(val >> uint(i*8))
}
}
zw.Write(pixel)
}
}
zw.Flush()
compressed := zbuf.Bytes()
// Build the FramebufferUpdate message.
buf := make([]byte, 4+12+4+len(compressed))
buf[0] = serverFramebufferUpdate
buf[1] = 0
binary.BigEndian.PutUint16(buf[2:4], 1) // 1 rectangle
binary.BigEndian.PutUint16(buf[4:6], uint16(x))
binary.BigEndian.PutUint16(buf[6:8], uint16(y))
binary.BigEndian.PutUint16(buf[8:10], uint16(w))
binary.BigEndian.PutUint16(buf[10:12], uint16(h))
binary.BigEndian.PutUint32(buf[12:16], uint32(encZlib))
binary.BigEndian.PutUint32(buf[16:20], uint32(len(compressed)))
copy(buf[20:], compressed)
return buf
}
// diffRects compares two RGBA images and returns a list of dirty rectangles.
// It divides the screen into tiles and checks each one for changes.
func diffRects(prev, cur *image.RGBA, w, h, tileSize int) [][4]int {
if prev == nil {
return [][4]int{{0, 0, w, h}}
}
var rects [][4]int
for ty := 0; ty < h; ty += tileSize {
th := min(tileSize, h-ty)
for tx := 0; tx < w; tx += tileSize {
tw := min(tileSize, w-tx)
if tileChanged(prev, cur, tx, ty, tw, th) {
rects = append(rects, [4]int{tx, ty, tw, th})
}
}
}
return rects
}
func tileChanged(prev, cur *image.RGBA, x, y, w, h int) bool {
stride := prev.Stride
for row := y; row < y+h; row++ {
off := row*stride + x*4
end := off + w*4
prevRow := prev.Pix[off:end]
curRow := cur.Pix[off:end]
if !bytes.Equal(prevRow, curRow) {
return true
}
}
return false
}
// zlibState holds the persistent zlib writer and buffer for a session.
type zlibState struct {
buf *bytes.Buffer
w *zlib.Writer
}
func newZlibState() *zlibState {
buf := &bytes.Buffer{}
w, _ := zlib.NewWriterLevel(buf, zlib.BestSpeed)
return &zlibState{buf: buf, w: w}
}
func (z *zlibState) Close() error {
return z.w.Close()
}


@@ -1,690 +0,0 @@
package server
import (
"context"
"crypto/subtle"
"encoding/binary"
"encoding/hex"
"errors"
"fmt"
"image"
"io"
"net"
"net/netip"
"strings"
"sync"
"time"
gojwt "github.com/golang-jwt/jwt/v5"
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/tun/netstack"
sshauth "github.com/netbirdio/netbird/client/ssh/auth"
nbjwt "github.com/netbirdio/netbird/shared/auth/jwt"
)
// Connection modes sent by the client in the session header.
const (
ModeAttach byte = 0 // Capture current display
ModeSession byte = 1 // Virtual session as specified user
)
// RFB security-failure reason codes sent to the client. These prefixes are
// stable so dashboard/noVNC integrations can branch on them without parsing
// free text. Format: "CODE: human message".
const (
RejectCodeJWTMissing = "AUTH_JWT_MISSING"
RejectCodeJWTExpired = "AUTH_JWT_EXPIRED"
RejectCodeJWTInvalid = "AUTH_JWT_INVALID"
RejectCodeAuthForbidden = "AUTH_FORBIDDEN"
RejectCodeAuthConfig = "AUTH_CONFIG"
RejectCodeSessionError = "SESSION_ERROR"
RejectCodeCapturerError = "CAPTURER_ERROR"
RejectCodeUnsupportedOS = "UNSUPPORTED"
RejectCodeBadRequest = "BAD_REQUEST"
)
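Because the "CODE: human message" format is stable, a consumer can recover the code with a single split on the first ": ". A minimal standalone sketch (splitRejectReason is an illustrative helper, not part of the package):

```go
package main

import (
	"fmt"
	"strings"
)

// splitRejectReason recovers the stable code and the free-text message
// from an RFB security-failure reason of the form "CODE: message".
func splitRejectReason(reason string) (code, msg string) {
	code, msg, found := strings.Cut(reason, ": ")
	if !found {
		// No code prefix: treat the whole string as the message.
		return "", reason
	}
	return code, msg
}

func main() {
	code, msg := splitRejectReason("AUTH_JWT_EXPIRED: validate JWT: token is expired")
	fmt.Println(code) // AUTH_JWT_EXPIRED
	fmt.Println(msg)  // validate JWT: token is expired
}
```

Note that only the first ": " is significant; any colons in the free-text suffix stay in the message.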
// EnvVNCDisableDownscale disables any platform-specific framebuffer
// downscaling (e.g. Retina 2:1). Set to 1/true to send the native resolution.
const EnvVNCDisableDownscale = "NB_VNC_DISABLE_DOWNSCALE"
// ScreenCapturer grabs desktop frames for the VNC server.
type ScreenCapturer interface {
// Width returns the current screen width in pixels.
Width() int
// Height returns the current screen height in pixels.
Height() int
// Capture returns the current desktop as an RGBA image.
Capture() (*image.RGBA, error)
}
// InputInjector delivers keyboard and mouse events to the OS.
type InputInjector interface {
// InjectKey simulates a key press or release. keysym is an X11 KeySym.
InjectKey(keysym uint32, down bool)
// InjectPointer simulates mouse movement and button state.
InjectPointer(buttonMask uint8, x, y, serverW, serverH int)
// SetClipboard sets the system clipboard to the given text.
SetClipboard(text string)
// GetClipboard returns the current system clipboard text.
GetClipboard() string
}
// JWTConfig holds JWT validation configuration for VNC auth.
type JWTConfig struct {
Issuer string
KeysLocation string
MaxTokenAge int64
Audiences []string
}
// connectionHeader is sent by the client before the RFB handshake to specify
// the VNC session mode and authenticate.
type connectionHeader struct {
mode byte
username string
jwt string
sessionID uint32 // Windows session ID (0 = console/auto)
}
// Server is the embedded VNC server that listens on the WireGuard interface.
// It supports two operating modes:
// - Direct mode: captures the screen and handles VNC sessions in-process.
// Used when running in a user session with desktop access.
// - Service mode: proxies VNC connections to an agent process spawned in
// the active console session. Used when running as a Windows service in
// Session 0.
//
// Within direct mode, each connection can request one of two session modes
// via the connection header:
// - Attach: capture the current physical display.
// - Session: start a virtual Xvfb display as the requested user.
type Server struct {
capturer ScreenCapturer
injector InputInjector
password string
serviceMode bool
disableAuth bool
localAddr netip.Addr // NetBird WireGuard IP this server is bound to
network netip.Prefix // NetBird overlay network
log *log.Entry
recordingDir string // when set, VNC sessions are recorded to this directory
recordingEncKey string // base64-encoded X25519 public key for encrypting recordings
mu sync.Mutex
listener net.Listener
ctx context.Context
cancel context.CancelFunc
vmgr virtualSessionManager
jwtConfig *JWTConfig
jwtValidator *nbjwt.Validator
jwtExtractor *nbjwt.ClaimsExtractor
authorizer *sshauth.Authorizer
netstackNet *netstack.Net
agentToken []byte // raw token bytes for agent-mode auth
}
// vncSession provides capturer and injector for a virtual display session.
type vncSession interface {
Capturer() ScreenCapturer
Injector() InputInjector
Display() string
ClientConnect()
ClientDisconnect()
}
// virtualSessionManager is implemented by sessionManager on Linux.
type virtualSessionManager interface {
GetOrCreate(username string) (vncSession, error)
StopAll()
}
// New creates a VNC server with the given screen capturer and input injector.
func New(capturer ScreenCapturer, injector InputInjector, password string) *Server {
return &Server{
capturer: capturer,
injector: injector,
password: password,
authorizer: sshauth.NewAuthorizer(),
log: log.WithField("component", "vnc-server"),
}
}
// SetServiceMode enables proxy-to-agent mode for Windows service operation.
func (s *Server) SetServiceMode(enabled bool) {
s.serviceMode = enabled
}
// SetJWTConfig configures JWT authentication for VNC connections.
// Pass nil to disable JWT (public mode).
func (s *Server) SetJWTConfig(config *JWTConfig) {
s.mu.Lock()
defer s.mu.Unlock()
s.jwtConfig = config
s.jwtValidator = nil
s.jwtExtractor = nil
}
// SetDisableAuth disables authentication entirely.
func (s *Server) SetDisableAuth(disable bool) {
s.disableAuth = disable
}
// SetAgentToken sets a hex-encoded token that must be presented by incoming
// connections before any VNC data. Used in agent mode to verify that only the
// trusted service process connects.
func (s *Server) SetAgentToken(hexToken string) {
if hexToken == "" {
return
}
b, err := hex.DecodeString(hexToken)
if err != nil {
s.log.Warnf("invalid agent token: %v", err)
return
}
s.agentToken = b
}
// SetNetstackNet sets the netstack network for userspace-only listening.
// When set, the VNC server listens via netstack instead of a real OS socket.
func (s *Server) SetNetstackNet(n *netstack.Net) {
s.mu.Lock()
defer s.mu.Unlock()
s.netstackNet = n
}
// SetRecordingDir enables VNC session recording to the given directory.
func (s *Server) SetRecordingDir(dir string) {
s.recordingDir = dir
}
// SetRecordingEncryptionKey sets the base64-encoded X25519 public key for encrypting recordings.
func (s *Server) SetRecordingEncryptionKey(key string) {
s.recordingEncKey = key
}
// UpdateVNCAuth updates the fine-grained authorization configuration.
func (s *Server) UpdateVNCAuth(config *sshauth.Config) {
s.mu.Lock()
defer s.mu.Unlock()
s.jwtValidator = nil
s.jwtExtractor = nil
s.authorizer.Update(config)
}
// Start begins listening for VNC connections on the given address.
// network is the NetBird overlay prefix used to validate connection sources.
func (s *Server) Start(ctx context.Context, addr netip.AddrPort, network netip.Prefix) error {
s.mu.Lock()
defer s.mu.Unlock()
if s.listener != nil {
return fmt.Errorf("server already running")
}
s.ctx, s.cancel = context.WithCancel(ctx)
s.vmgr = s.platformSessionManager()
s.localAddr = addr.Addr()
s.network = network
var listener net.Listener
var listenDesc string
if s.netstackNet != nil {
ln, err := s.netstackNet.ListenTCPAddrPort(addr)
if err != nil {
return fmt.Errorf("listen on netstack %s: %w", addr, err)
}
listener = ln
listenDesc = fmt.Sprintf("netstack %s", addr)
} else {
tcpAddr := net.TCPAddrFromAddrPort(addr)
ln, err := net.ListenTCP("tcp", tcpAddr)
if err != nil {
return fmt.Errorf("listen on %s: %w", addr, err)
}
listener = ln
listenDesc = addr.String()
}
s.listener = listener
if s.serviceMode {
s.platformInit()
go s.serviceAcceptLoop()
} else {
go s.acceptLoop()
}
s.log.Infof("started on %s (service_mode=%v)", listenDesc, s.serviceMode)
return nil
}
// Stop shuts down the server and closes all connections.
func (s *Server) Stop() error {
s.mu.Lock()
defer s.mu.Unlock()
if s.cancel != nil {
s.cancel()
s.cancel = nil
}
if s.vmgr != nil {
s.vmgr.StopAll()
}
if c, ok := s.capturer.(interface{ Close() }); ok {
c.Close()
}
if s.listener != nil {
err := s.listener.Close()
s.listener = nil
if err != nil {
return fmt.Errorf("close VNC listener: %w", err)
}
}
s.log.Info("stopped")
return nil
}
// acceptLoop handles VNC connections directly (user session mode).
func (s *Server) acceptLoop() {
for {
conn, err := s.listener.Accept()
if err != nil {
select {
case <-s.ctx.Done():
return
default:
}
s.log.Debugf("accept VNC connection: %v", err)
continue
}
go s.handleConnection(conn)
}
}
func (s *Server) validateCapturer(cap ScreenCapturer) error {
// Quick check first: if already ready, return immediately.
if cap.Width() > 0 && cap.Height() > 0 {
return nil
}
// Capturer not ready: poke any retry loop that supports it so it doesn't
// wait out its full backoff (e.g. macOS waiting for Screen Recording).
if w, ok := cap.(interface{ Wake() }); ok {
w.Wake()
}
// Wait up to 5s for the capturer to become ready.
for range 50 {
time.Sleep(100 * time.Millisecond)
if cap.Width() > 0 && cap.Height() > 0 {
return nil
}
}
return errors.New("no display available (check X11 on Linux or Screen Recording permission on macOS)")
}
// isAllowedSource rejects connections from outside the NetBird overlay network
// and from the local WireGuard IP (prevents local privilege escalation).
// Matches the SSH server's connectionValidator logic.
func (s *Server) isAllowedSource(addr net.Addr) bool {
tcpAddr, ok := addr.(*net.TCPAddr)
if !ok {
s.log.Warnf("connection rejected: non-TCP address %s", addr)
return false
}
remoteIP, ok := netip.AddrFromSlice(tcpAddr.IP)
if !ok {
s.log.Warnf("connection rejected: invalid remote IP %s", tcpAddr.IP)
return false
}
remoteIP = remoteIP.Unmap()
if remoteIP.IsLoopback() && s.localAddr.IsLoopback() {
return true
}
if remoteIP == s.localAddr {
s.log.Warnf("connection rejected from own IP %s", remoteIP)
return false
}
if s.network.IsValid() && !s.network.Contains(remoteIP) {
s.log.Warnf("connection rejected from non-NetBird IP %s", remoteIP)
return false
}
return true
}
func (s *Server) handleConnection(conn net.Conn) {
connLog := s.log.WithField("remote", conn.RemoteAddr().String())
if !s.isAllowedSource(conn.RemoteAddr()) {
conn.Close()
return
}
if len(s.agentToken) > 0 {
buf := make([]byte, len(s.agentToken))
if err := conn.SetReadDeadline(time.Now().Add(5 * time.Second)); err != nil {
connLog.Debugf("set agent token deadline: %v", err)
conn.Close()
return
}
if _, err := io.ReadFull(conn, buf); err != nil {
connLog.Warnf("agent auth: read token: %v", err)
conn.Close()
return
}
conn.SetReadDeadline(time.Time{}) //nolint:errcheck
if subtle.ConstantTimeCompare(buf, s.agentToken) != 1 {
connLog.Warn("agent auth: invalid token, rejecting")
conn.Close()
return
}
}
header, err := readConnectionHeader(conn)
if err != nil {
connLog.Warnf("read connection header: %v", err)
conn.Close()
return
}
if !s.disableAuth {
if s.jwtConfig == nil {
rejectConnection(conn, codeMessage(RejectCodeAuthConfig, "auth enabled but no identity provider configured"))
connLog.Warn("auth rejected: no identity provider configured")
return
}
jwtUserID, err := s.authenticateJWT(header)
if err != nil {
rejectConnection(conn, codeMessage(jwtErrorCode(err), err.Error()))
connLog.Warnf("auth rejected: %v", err)
return
}
connLog = connLog.WithField("jwt_user", jwtUserID)
}
var capturer ScreenCapturer
var injector InputInjector
switch header.mode {
case ModeSession:
if s.vmgr == nil {
rejectConnection(conn, codeMessage(RejectCodeUnsupportedOS, "virtual sessions not supported on this platform"))
connLog.Warn("session rejected: not supported on this platform")
return
}
if header.username == "" {
rejectConnection(conn, codeMessage(RejectCodeBadRequest, "session mode requires a username"))
connLog.Warn("session rejected: no username provided")
return
}
vs, err := s.vmgr.GetOrCreate(header.username)
if err != nil {
rejectConnection(conn, codeMessage(RejectCodeSessionError, fmt.Sprintf("create virtual session: %v", err)))
connLog.Warnf("create virtual session for %s: %v", header.username, err)
return
}
capturer = vs.Capturer()
injector = vs.Injector()
vs.ClientConnect()
defer vs.ClientDisconnect()
connLog = connLog.WithField("vnc_user", header.username)
connLog.Infof("session mode: user=%s display=%s", header.username, vs.Display())
default:
capturer = s.capturer
injector = s.injector
if cc, ok := capturer.(interface{ ClientConnect() }); ok {
cc.ClientConnect()
}
defer func() {
if cd, ok := capturer.(interface{ ClientDisconnect() }); ok {
cd.ClientDisconnect()
}
}()
}
if err := s.validateCapturer(capturer); err != nil {
rejectConnection(conn, codeMessage(RejectCodeCapturerError, fmt.Sprintf("screen capturer: %v", err)))
connLog.Warnf("capturer not ready: %v", err)
return
}
var rec *vncRecorder
if s.recordingDir != "" {
mode := "attach"
if header.mode == ModeSession {
mode = "session"
}
jwtUser, _ := connLog.Data["jwt_user"].(string)
var err error
rec, err = newVNCRecorder(s.recordingDir, capturer.Width(), capturer.Height(), &RecordingMeta{
User: header.username,
RemoteAddr: conn.RemoteAddr().String(),
JWTUser: jwtUser,
Mode: mode,
}, s.recordingEncKey, connLog)
if err != nil {
connLog.Warnf("start VNC recording: %v", err)
}
}
sess := &session{
conn: conn,
capturer: capturer,
injector: injector,
serverW: capturer.Width(),
serverH: capturer.Height(),
password: s.password,
log: connLog,
recorder: rec,
}
sess.serve()
}
// codeMessage formats a stable reject code with a human-readable message.
// Dashboards split on the first ": " to recover the code without parsing the
// free-text suffix.
func codeMessage(code, msg string) string {
return code + ": " + msg
}
// jwtErrorCode maps a JWT auth error to a stable reject code.
func jwtErrorCode(err error) string {
if err == nil {
return RejectCodeJWTInvalid
}
if errors.Is(err, nbjwt.ErrTokenExpired) {
return RejectCodeJWTExpired
}
msg := err.Error()
switch {
case strings.Contains(msg, "JWT required but not provided"):
return RejectCodeJWTMissing
case strings.Contains(msg, "authorize") || strings.Contains(msg, "not authorized"):
return RejectCodeAuthForbidden
default:
return RejectCodeJWTInvalid
}
}
// rejectConnection sends a minimal RFB handshake with a security failure
// reason, so VNC clients display the error message instead of a generic
// "unexpected disconnect."
func rejectConnection(conn net.Conn, reason string) {
defer conn.Close()
// RFB 3.8 server version.
io.WriteString(conn, "RFB 003.008\n")
// Read client version (12 bytes), ignore errors.
var clientVer [12]byte
conn.SetReadDeadline(time.Now().Add(2 * time.Second))
io.ReadFull(conn, clientVer[:])
conn.SetReadDeadline(time.Time{})
// Send 0 security types = connection failed, followed by reason.
msg := []byte(reason)
buf := make([]byte, 1+4+len(msg))
buf[0] = 0 // 0 security types = failure
binary.BigEndian.PutUint32(buf[1:5], uint32(len(msg)))
copy(buf[5:], msg)
conn.Write(buf)
}
const defaultJWTMaxTokenAge = 10 * 60 // 10 minutes
// authenticateJWT validates the JWT from the connection header and checks
// authorization. For attach mode, just checks membership in the authorized
// user list. For session mode, additionally validates the OS user mapping.
func (s *Server) authenticateJWT(header *connectionHeader) (string, error) {
if header.jwt == "" {
return "", fmt.Errorf("JWT required but not provided")
}
s.mu.Lock()
if err := s.ensureJWTValidator(); err != nil {
s.mu.Unlock()
return "", fmt.Errorf("initialize JWT validator: %w", err)
}
validator := s.jwtValidator
extractor := s.jwtExtractor
s.mu.Unlock()
token, err := validator.ValidateAndParse(context.Background(), header.jwt)
if err != nil {
return "", fmt.Errorf("validate JWT: %w", err)
}
if err := s.checkTokenAge(token); err != nil {
return "", err
}
userAuth, err := extractor.ToUserAuth(token)
if err != nil {
return "", fmt.Errorf("extract user from JWT: %w", err)
}
if userAuth.UserId == "" {
return "", fmt.Errorf("JWT has no user ID")
}
switch header.mode {
case ModeSession:
// Session mode: check user + OS username mapping.
if _, err := s.authorizer.Authorize(userAuth.UserId, header.username); err != nil {
return "", fmt.Errorf("authorize session for %s: %w", header.username, err)
}
default:
// Attach mode: just check user is in the authorized list (wildcard OS user).
if _, err := s.authorizer.Authorize(userAuth.UserId, "*"); err != nil {
return "", fmt.Errorf("user not authorized for VNC: %w", err)
}
}
return userAuth.UserId, nil
}
// ensureJWTValidator lazily initializes the JWT validator. Must be called with mu held.
func (s *Server) ensureJWTValidator() error {
if s.jwtValidator != nil && s.jwtExtractor != nil {
return nil
}
if s.jwtConfig == nil {
return fmt.Errorf("no JWT config")
}
s.jwtValidator = nbjwt.NewValidator(
s.jwtConfig.Issuer,
s.jwtConfig.Audiences,
s.jwtConfig.KeysLocation,
false,
)
var opts []nbjwt.ClaimsExtractorOption
if len(s.jwtConfig.Audiences) > 0 {
opts = append(opts, nbjwt.WithAudience(s.jwtConfig.Audiences[0]))
}
if claim := s.authorizer.GetUserIDClaim(); claim != "" {
opts = append(opts, nbjwt.WithUserIDClaim(claim))
}
s.jwtExtractor = nbjwt.NewClaimsExtractor(opts...)
return nil
}
func (s *Server) checkTokenAge(token *gojwt.Token) error {
maxAge := defaultJWTMaxTokenAge
if s.jwtConfig != nil && s.jwtConfig.MaxTokenAge > 0 {
maxAge = int(s.jwtConfig.MaxTokenAge)
}
return nbjwt.CheckTokenAge(token, time.Duration(maxAge)*time.Second)
}
// readConnectionHeader reads the NetBird VNC session header from the connection.
// Format: [mode: 1 byte] [username_len: 2 bytes BE] [username: N bytes]
//
// [jwt_len: 2 bytes BE] [jwt: N bytes]
//
// Uses a short timeout: our WASM proxy sends the header immediately after
// connecting. Standard VNC clients don't send anything first (server speaks
// first in RFB), so they time out and get the default attach mode.
func readConnectionHeader(conn net.Conn) (*connectionHeader, error) {
if err := conn.SetReadDeadline(time.Now().Add(2 * time.Second)); err != nil {
return nil, fmt.Errorf("set deadline: %w", err)
}
defer conn.SetReadDeadline(time.Time{}) //nolint:errcheck
var hdr [3]byte
if _, err := io.ReadFull(conn, hdr[:]); err != nil {
// Timeout or error: assume no header, use attach mode.
return &connectionHeader{mode: ModeAttach}, nil
}
// Restore a longer deadline for reading variable-length fields.
if err := conn.SetReadDeadline(time.Now().Add(5 * time.Second)); err != nil {
return nil, fmt.Errorf("set deadline: %w", err)
}
mode := hdr[0]
usernameLen := binary.BigEndian.Uint16(hdr[1:3])
var username string
if usernameLen > 0 {
if usernameLen > 256 {
return nil, fmt.Errorf("username too long: %d", usernameLen)
}
buf := make([]byte, usernameLen)
if _, err := io.ReadFull(conn, buf); err != nil {
return nil, fmt.Errorf("read username: %w", err)
}
username = string(buf)
}
// Read JWT token length and data.
var jwtLenBuf [2]byte
var jwtToken string
if _, err := io.ReadFull(conn, jwtLenBuf[:]); err == nil {
jwtLen := binary.BigEndian.Uint16(jwtLenBuf[:])
if jwtLen > 0 && jwtLen < 8192 {
buf := make([]byte, jwtLen)
if _, err := io.ReadFull(conn, buf); err != nil {
return nil, fmt.Errorf("read JWT: %w", err)
}
jwtToken = string(buf)
}
}
// Read optional Windows session ID (4 bytes BE). Missing = 0 (console/auto).
var sessionID uint32
var sidBuf [4]byte
if _, err := io.ReadFull(conn, sidBuf[:]); err == nil {
sessionID = binary.BigEndian.Uint32(sidBuf[:])
}
return &connectionHeader{mode: mode, username: username, jwt: jwtToken, sessionID: sessionID}, nil
}


@@ -1,15 +0,0 @@
//go:build darwin && !ios
package server
func (s *Server) platformInit() {}
// serviceAcceptLoop is not supported on macOS.
func (s *Server) serviceAcceptLoop() {
s.log.Warn("service mode not supported on macOS, falling back to direct mode")
s.acceptLoop()
}
func (s *Server) platformSessionManager() virtualSessionManager {
return nil
}


@@ -1,15 +0,0 @@
//go:build !windows && !darwin && !freebsd && !(linux && !android)
package server
func (s *Server) platformInit() {}
// serviceAcceptLoop is not supported on non-Windows platforms.
func (s *Server) serviceAcceptLoop() {
s.log.Warn("service mode not supported on this platform, falling back to direct mode")
s.acceptLoop()
}
func (s *Server) platformSessionManager() virtualSessionManager {
return nil
}


@@ -1,136 +0,0 @@
package server
import (
"encoding/binary"
"image"
"io"
"net"
"net/netip"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// testCapturer returns a 100x100 image for test sessions.
type testCapturer struct{}
func (t *testCapturer) Width() int { return 100 }
func (t *testCapturer) Height() int { return 100 }
func (t *testCapturer) Capture() (*image.RGBA, error) { return image.NewRGBA(image.Rect(0, 0, 100, 100)), nil }
func startTestServer(t *testing.T, disableAuth bool, jwtConfig *JWTConfig) (net.Addr, *Server) {
t.Helper()
srv := New(&testCapturer{}, &StubInputInjector{}, "")
srv.SetDisableAuth(disableAuth)
if jwtConfig != nil {
srv.SetJWTConfig(jwtConfig)
}
addr := netip.MustParseAddrPort("127.0.0.1:0")
network := netip.MustParsePrefix("127.0.0.0/8")
require.NoError(t, srv.Start(t.Context(), addr, network))
// Override local address so source validation doesn't reject 127.0.0.1 as "own IP".
srv.localAddr = netip.MustParseAddr("10.99.99.1")
t.Cleanup(func() { _ = srv.Stop() })
return srv.listener.Addr(), srv
}
func TestAuthEnabled_NoJWTConfig_RejectsConnection(t *testing.T) {
addr, _ := startTestServer(t, false, nil)
conn, err := net.Dial("tcp", addr.String())
require.NoError(t, err)
defer conn.Close()
// Send session header: attach mode, no username, no JWT.
header := []byte{ModeAttach, 0, 0, 0, 0}
_, err = conn.Write(header)
require.NoError(t, err)
// Server should send RFB version then security failure.
var version [12]byte
_, err = io.ReadFull(conn, version[:])
require.NoError(t, err)
assert.Equal(t, "RFB 003.008\n", string(version[:]))
// Write client version to proceed through handshake.
_, err = conn.Write(version[:])
require.NoError(t, err)
// Read security types: 0 means failure, followed by reason.
var numTypes [1]byte
_, err = io.ReadFull(conn, numTypes[:])
require.NoError(t, err)
assert.Equal(t, byte(0), numTypes[0], "should have 0 security types (failure)")
var reasonLen [4]byte
_, err = io.ReadFull(conn, reasonLen[:])
require.NoError(t, err)
reason := make([]byte, binary.BigEndian.Uint32(reasonLen[:]))
_, err = io.ReadFull(conn, reason)
require.NoError(t, err)
assert.Contains(t, string(reason), "identity provider", "rejection reason should mention missing IdP config")
}
func TestAuthDisabled_AllowsConnection(t *testing.T) {
addr, _ := startTestServer(t, true, nil)
conn, err := net.Dial("tcp", addr.String())
require.NoError(t, err)
defer conn.Close()
// Send session header: attach mode, no username, no JWT.
header := []byte{ModeAttach, 0, 0, 0, 0}
_, err = conn.Write(header)
require.NoError(t, err)
// Server should send RFB version.
var version [12]byte
_, err = io.ReadFull(conn, version[:])
require.NoError(t, err)
assert.Equal(t, "RFB 003.008\n", string(version[:]))
// Write client version.
_, err = conn.Write(version[:])
require.NoError(t, err)
// Should get security types (not 0 = failure).
var numTypes [1]byte
_, err = io.ReadFull(conn, numTypes[:])
require.NoError(t, err)
assert.NotEqual(t, byte(0), numTypes[0], "should have at least one security type (auth disabled)")
}
func TestAuthEnabled_EmptyJWT_Rejected(t *testing.T) {
// Auth enabled with a (bogus) JWT config: connections without JWT should be rejected.
addr, _ := startTestServer(t, false, &JWTConfig{
Issuer: "https://example.com",
KeysLocation: "https://example.com/.well-known/jwks.json",
Audiences: []string{"test"},
})
conn, err := net.Dial("tcp", addr.String())
require.NoError(t, err)
defer conn.Close()
// Send session header with empty JWT.
header := []byte{ModeAttach, 0, 0, 0, 0}
_, err = conn.Write(header)
require.NoError(t, err)
var version [12]byte
_, err = io.ReadFull(conn, version[:])
require.NoError(t, err)
_, err = conn.Write(version[:])
require.NoError(t, err)
var numTypes [1]byte
_, err = io.ReadFull(conn, numTypes[:])
require.NoError(t, err)
assert.Equal(t, byte(0), numTypes[0], "should reject with 0 security types")
}


@@ -1,222 +0,0 @@
//go:build windows
package server
import (
"bytes"
"io"
"net"
"unsafe"
log "github.com/sirupsen/logrus"
"golang.org/x/sys/windows"
"golang.org/x/sys/windows/registry"
)
var (
sasDLL = windows.NewLazySystemDLL("sas.dll")
procSendSAS = sasDLL.NewProc("SendSAS")
procConvertStringSecurityDescriptorToSecurityDescriptor = advapi32.NewProc("ConvertStringSecurityDescriptorToSecurityDescriptorW")
)
// sasSecurityAttributes builds a SECURITY_ATTRIBUTES that grants
// EVENT_MODIFY_STATE only to the SYSTEM account, preventing unprivileged
// local processes from triggering the Secure Attention Sequence.
func sasSecurityAttributes() (*windows.SecurityAttributes, error) {
// SDDL: grant full access to SYSTEM (creates/waits) and EVENT_MODIFY_STATE
// to the interactive user (IU) so the VNC agent in the console session can
// signal it. Other local users and network users are denied.
sddl, err := windows.UTF16PtrFromString("D:(A;;GA;;;SY)(A;;0x0002;;;IU)")
if err != nil {
return nil, err
}
var sd uintptr
r, _, lerr := procConvertStringSecurityDescriptorToSecurityDescriptor.Call(
uintptr(unsafe.Pointer(sddl)),
1, // SDDL_REVISION_1
uintptr(unsafe.Pointer(&sd)),
0,
)
if r == 0 {
return nil, lerr
}
return &windows.SecurityAttributes{
Length: uint32(unsafe.Sizeof(windows.SecurityAttributes{})),
SecurityDescriptor: (*windows.SECURITY_DESCRIPTOR)(unsafe.Pointer(sd)),
InheritHandle: 0,
}, nil
}
// enableSoftwareSAS sets the SoftwareSASGeneration registry key to allow
// services to trigger the Secure Attention Sequence via SendSAS. Without this,
// SendSAS silently does nothing on most Windows editions.
func enableSoftwareSAS() {
key, _, err := registry.CreateKey(
registry.LOCAL_MACHINE,
`SOFTWARE\Microsoft\Windows\CurrentVersion\Policies\System`,
registry.SET_VALUE,
)
if err != nil {
log.Warnf("open SoftwareSASGeneration registry key: %v", err)
return
}
defer key.Close()
if err := key.SetDWordValue("SoftwareSASGeneration", 1); err != nil {
log.Warnf("set SoftwareSASGeneration: %v", err)
return
}
log.Debug("SoftwareSASGeneration registry key set to 1 (services allowed)")
}
// startSASListener creates a named event with a restricted DACL and waits for
// the VNC input injector to signal it. When signaled, it calls SendSAS(FALSE)
// from Session 0 to trigger the Secure Attention Sequence (Ctrl+Alt+Del).
// Only SYSTEM and the interactive user can signal the event.
func startSASListener() {
enableSoftwareSAS()
namePtr, err := windows.UTF16PtrFromString(sasEventName)
if err != nil {
log.Warnf("SAS listener UTF16: %v", err)
return
}
sa, err := sasSecurityAttributes()
if err != nil {
log.Warnf("build SAS security descriptor: %v", err)
return
}
ev, err := windows.CreateEvent(sa, 0, 0, namePtr)
if err != nil {
log.Warnf("SAS CreateEvent: %v", err)
return
}
log.Info("SAS listener ready (Session 0)")
go func() {
defer windows.CloseHandle(ev)
for {
ret, _ := windows.WaitForSingleObject(ev, windows.INFINITE)
if ret == windows.WAIT_OBJECT_0 {
// SendSAS has no return value, so the values returned by Call are
// not meaningful and are intentionally ignored.
procSendSAS.Call(0) // FALSE = caller is a service, not the logged-on user
log.Info("SendSAS called from Session 0")
}
}
}()
}
// enablePrivilege enables a named privilege on the current process token.
func enablePrivilege(name string) error {
var token windows.Token
if err := windows.OpenProcessToken(windows.CurrentProcess(),
windows.TOKEN_ADJUST_PRIVILEGES|windows.TOKEN_QUERY, &token); err != nil {
return err
}
defer token.Close()
var luid windows.LUID
namePtr, _ := windows.UTF16PtrFromString(name)
if err := windows.LookupPrivilegeValue(nil, namePtr, &luid); err != nil {
return err
}
tp := windows.Tokenprivileges{PrivilegeCount: 1}
tp.Privileges[0].Luid = luid
tp.Privileges[0].Attributes = windows.SE_PRIVILEGE_ENABLED
return windows.AdjustTokenPrivileges(token, false, &tp, 0, nil, nil)
}
func (s *Server) platformSessionManager() virtualSessionManager {
return nil
}
// platformInit starts the SAS listener and enables privileges needed for
// Session 0 operations (agent spawning, SendSAS).
func (s *Server) platformInit() {
for _, priv := range []string{"SeTcbPrivilege", "SeAssignPrimaryTokenPrivilege"} {
if err := enablePrivilege(priv); err != nil {
log.Debugf("enable %s: %v", priv, err)
}
}
startSASListener()
}
// serviceAcceptLoop runs in Session 0. It validates source IP and
// authenticates via JWT before proxying connections to the user-session agent.
func (s *Server) serviceAcceptLoop() {
sm := newSessionManager(agentPort)
go sm.run()
log.Infof("service mode, proxying connections to agent on 127.0.0.1:%s", agentPort)
for {
conn, err := s.listener.Accept()
if err != nil {
select {
case <-s.ctx.Done():
sm.Stop()
return
default:
}
s.log.Debugf("accept VNC connection: %v", err)
continue
}
go s.handleServiceConnection(conn, sm)
}
}
// handleServiceConnection validates the source IP and JWT, then proxies
// the connection (with header bytes replayed) to the agent.
func (s *Server) handleServiceConnection(conn net.Conn, sm *sessionManager) {
connLog := s.log.WithField("remote", conn.RemoteAddr().String())
if !s.isAllowedSource(conn.RemoteAddr()) {
conn.Close()
return
}
var headerBuf bytes.Buffer
tee := io.TeeReader(conn, &headerBuf)
teeConn := &prefixConn{Reader: tee, Conn: conn}
header, err := readConnectionHeader(teeConn)
if err != nil {
connLog.Debugf("read connection header: %v", err)
conn.Close()
return
}
if !s.disableAuth {
if s.jwtConfig == nil {
rejectConnection(conn, codeMessage(RejectCodeAuthConfig, "auth enabled but no identity provider configured"))
connLog.Warn("auth rejected: no identity provider configured")
return
}
if _, err := s.authenticateJWT(header); err != nil {
rejectConnection(conn, codeMessage(jwtErrorCode(err), err.Error()))
connLog.Warnf("auth rejected: %v", err)
return
}
}
// Replay buffered header bytes + remaining stream to the agent.
replayConn := &prefixConn{
Reader: io.MultiReader(&headerBuf, conn),
Conn: conn,
}
proxyToAgent(replayConn, agentPort, sm.AuthToken())
}
// prefixConn wraps a net.Conn, overriding Read to use a different reader.
type prefixConn struct {
io.Reader
net.Conn
}
func (p *prefixConn) Read(b []byte) (int, error) {
return p.Reader.Read(b)
}


@@ -1,15 +0,0 @@
//go:build (linux && !android) || freebsd
package server
func (s *Server) platformInit() {}
// serviceAcceptLoop is not supported on Linux or FreeBSD.
func (s *Server) serviceAcceptLoop() {
s.log.Warn("service mode not supported on this platform, falling back to direct mode")
s.acceptLoop()
}
func (s *Server) platformSessionManager() virtualSessionManager {
return newSessionManager(s.log)
}


@@ -1,451 +0,0 @@
package server
import (
"bytes"
"crypto/rand"
"encoding/binary"
"fmt"
"image"
"io"
"net"
"sync"
"time"
log "github.com/sirupsen/logrus"
)
const (
readDeadline = 60 * time.Second
maxCutTextBytes = 1 << 20 // 1 MiB
)
const tileSize = 64 // pixels per tile for dirty-rect detection
type session struct {
conn net.Conn
capturer ScreenCapturer
injector InputInjector
serverW int
serverH int
password string
log *log.Entry
recorder *vncRecorder
writeMu sync.Mutex
pf clientPixelFormat
useZlib bool
zlib *zlibState
prevFrame *image.RGBA
idleFrames int
}
func (s *session) addr() string { return s.conn.RemoteAddr().String() }
// serve runs the full RFB session lifecycle.
func (s *session) serve() {
defer s.conn.Close()
if s.recorder != nil {
defer s.recorder.close()
}
s.pf = defaultClientPixelFormat()
if err := s.handshake(); err != nil {
s.log.Warnf("handshake with %s: %v", s.addr(), err)
return
}
s.log.Infof("client connected: %s", s.addr())
done := make(chan struct{})
defer close(done)
go s.clipboardPoll(done)
if err := s.messageLoop(); err != nil && err != io.EOF {
s.log.Warnf("client %s disconnected: %v", s.addr(), err)
} else {
s.log.Infof("client disconnected: %s", s.addr())
}
}
// clipboardPoll periodically checks the server-side clipboard and sends
// changes to the VNC client. Only runs during active sessions.
func (s *session) clipboardPoll(done <-chan struct{}) {
ticker := time.NewTicker(2 * time.Second)
defer ticker.Stop()
var lastClip string
for {
select {
case <-done:
return
case <-ticker.C:
text := s.injector.GetClipboard()
if len(text) > maxCutTextBytes {
text = text[:maxCutTextBytes]
}
if text != "" && text != lastClip {
lastClip = text
if err := s.sendServerCutText(text); err != nil {
s.log.Debugf("send clipboard to client: %v", err)
return
}
}
}
}
}
func (s *session) handshake() error {
// Send protocol version.
if _, err := io.WriteString(s.conn, rfbProtocolVersion); err != nil {
return fmt.Errorf("send version: %w", err)
}
// Read client version.
var clientVer [12]byte
if _, err := io.ReadFull(s.conn, clientVer[:]); err != nil {
return fmt.Errorf("read client version: %w", err)
}
// Send supported security types.
if err := s.sendSecurityTypes(); err != nil {
return err
}
// Read chosen security type.
var secType [1]byte
if _, err := io.ReadFull(s.conn, secType[:]); err != nil {
return fmt.Errorf("read security type: %w", err)
}
if err := s.handleSecurity(secType[0]); err != nil {
return err
}
// Read ClientInit.
var clientInit [1]byte
if _, err := io.ReadFull(s.conn, clientInit[:]); err != nil {
return fmt.Errorf("read ClientInit: %w", err)
}
return s.sendServerInit()
}
func (s *session) sendSecurityTypes() error {
if s.password == "" {
_, err := s.conn.Write([]byte{1, secNone})
return err
}
_, err := s.conn.Write([]byte{1, secVNCAuth})
return err
}
func (s *session) handleSecurity(secType byte) error {
switch secType {
case secVNCAuth:
return s.doVNCAuth()
case secNone:
return binary.Write(s.conn, binary.BigEndian, uint32(0))
default:
return fmt.Errorf("unsupported security type: %d", secType)
}
}
func (s *session) doVNCAuth() error {
challenge := make([]byte, 16)
if _, err := rand.Read(challenge); err != nil {
return fmt.Errorf("generate challenge: %w", err)
}
if _, err := s.conn.Write(challenge); err != nil {
return fmt.Errorf("send challenge: %w", err)
}
response := make([]byte, 16)
if _, err := io.ReadFull(s.conn, response); err != nil {
return fmt.Errorf("read auth response: %w", err)
}
var result uint32
if s.password != "" {
expected := vncAuthEncrypt(challenge, s.password)
if !bytes.Equal(expected, response) {
result = 1
}
}
if err := binary.Write(s.conn, binary.BigEndian, result); err != nil {
return fmt.Errorf("send auth result: %w", err)
}
if result != 0 {
msg := "authentication failed"
_ = binary.Write(s.conn, binary.BigEndian, uint32(len(msg)))
_, _ = s.conn.Write([]byte(msg))
return fmt.Errorf("authentication failed from %s", s.addr())
}
return nil
}
func (s *session) sendServerInit() error {
name := []byte("NetBird VNC")
buf := make([]byte, 0, 4+16+4+len(name))
// Framebuffer width and height.
buf = append(buf, byte(s.serverW>>8), byte(s.serverW))
buf = append(buf, byte(s.serverH>>8), byte(s.serverH))
// Server pixel format.
buf = append(buf, serverPixelFormat[:]...)
// Desktop name.
buf = append(buf,
byte(len(name)>>24), byte(len(name)>>16),
byte(len(name)>>8), byte(len(name)),
)
buf = append(buf, name...)
_, err := s.conn.Write(buf)
return err
}
func (s *session) messageLoop() error {
for {
var msgType [1]byte
if err := s.conn.SetDeadline(time.Now().Add(readDeadline)); err != nil {
return fmt.Errorf("set deadline: %w", err)
}
if _, err := io.ReadFull(s.conn, msgType[:]); err != nil {
return err
}
_ = s.conn.SetDeadline(time.Time{})
switch msgType[0] {
case clientSetPixelFormat:
if err := s.handleSetPixelFormat(); err != nil {
return err
}
case clientSetEncodings:
if err := s.handleSetEncodings(); err != nil {
return err
}
case clientFramebufferUpdateRequest:
if err := s.handleFBUpdateRequest(); err != nil {
return err
}
case clientKeyEvent:
if err := s.handleKeyEvent(); err != nil {
return err
}
case clientPointerEvent:
if err := s.handlePointerEvent(); err != nil {
return err
}
case clientCutText:
if err := s.handleCutText(); err != nil {
return err
}
default:
return fmt.Errorf("unknown client message type: %d", msgType[0])
}
}
}
func (s *session) handleSetPixelFormat() error {
var buf [19]byte // 3 padding + 16 pixel format
if _, err := io.ReadFull(s.conn, buf[:]); err != nil {
return fmt.Errorf("read SetPixelFormat: %w", err)
}
s.pf = parsePixelFormat(buf[3:19])
return nil
}
func (s *session) handleSetEncodings() error {
var header [3]byte // 1 padding + 2 number-of-encodings
if _, err := io.ReadFull(s.conn, header[:]); err != nil {
return fmt.Errorf("read SetEncodings header: %w", err)
}
numEnc := binary.BigEndian.Uint16(header[1:3])
buf := make([]byte, int(numEnc)*4)
if _, err := io.ReadFull(s.conn, buf); err != nil {
return err
}
// Check if client supports zlib encoding.
for i := range int(numEnc) {
enc := int32(binary.BigEndian.Uint32(buf[i*4 : i*4+4]))
if enc == encZlib {
s.useZlib = true
if s.zlib == nil {
s.zlib = newZlibState()
}
s.log.Debugf("client supports zlib encoding")
break
}
}
return nil
}
func (s *session) handleFBUpdateRequest() error {
var req [9]byte
if _, err := io.ReadFull(s.conn, req[:]); err != nil {
return fmt.Errorf("read FBUpdateRequest: %w", err)
}
incremental := req[0]
img, err := s.capturer.Capture()
if err != nil {
return fmt.Errorf("capture screen: %w", err)
}
if s.recorder != nil {
s.recorder.writeFrame(img)
}
if incremental == 1 && s.prevFrame != nil {
rects := diffRects(s.prevFrame, img, s.serverW, s.serverH, tileSize)
if len(rects) == 0 {
// Nothing changed. Back off briefly before responding to reduce
// CPU usage when the screen is static. The client re-requests
// immediately after receiving our empty response, so without
// this delay we'd spin at ~1000fps checking for changes.
s.idleFrames++
delay := min(s.idleFrames*5, 100) // 5ms → 100ms adaptive backoff
time.Sleep(time.Duration(delay) * time.Millisecond)
s.savePrevFrame(img)
return s.sendEmptyUpdate()
}
s.idleFrames = 0
s.savePrevFrame(img)
return s.sendDirtyRects(img, rects)
}
// Full update.
s.idleFrames = 0
s.savePrevFrame(img)
return s.sendFullUpdate(img)
}
// savePrevFrame copies img's pixel data into prevFrame. This is necessary
// because some capturers (DXGI) reuse the same image buffer across calls,
// so a simple pointer assignment would make prevFrame alias the live buffer
// and diffRects would always see zero changes.
func (s *session) savePrevFrame(img *image.RGBA) {
if s.prevFrame == nil || s.prevFrame.Rect != img.Rect {
s.prevFrame = image.NewRGBA(img.Rect)
}
copy(s.prevFrame.Pix, img.Pix)
}
// sendEmptyUpdate sends a FramebufferUpdate with zero rectangles.
func (s *session) sendEmptyUpdate() error {
var buf [4]byte
buf[0] = serverFramebufferUpdate
s.writeMu.Lock()
_, err := s.conn.Write(buf[:])
s.writeMu.Unlock()
return err
}
func (s *session) sendFullUpdate(img *image.RGBA) error {
w, h := s.serverW, s.serverH
var buf []byte
if s.useZlib && s.zlib != nil {
buf = encodeZlibRect(img, s.pf, 0, 0, w, h, s.zlib.w, s.zlib.buf)
} else {
buf = encodeRawRect(img, s.pf, 0, 0, w, h)
}
s.writeMu.Lock()
_, err := s.conn.Write(buf)
s.writeMu.Unlock()
return err
}
func (s *session) sendDirtyRects(img *image.RGBA, rects [][4]int) error {
// Build a multi-rectangle FramebufferUpdate.
// Header: type(1) + padding(1) + numRects(2)
header := make([]byte, 4)
header[0] = serverFramebufferUpdate
binary.BigEndian.PutUint16(header[2:4], uint16(len(rects)))
s.writeMu.Lock()
defer s.writeMu.Unlock()
if _, err := s.conn.Write(header); err != nil {
return err
}
for _, r := range rects {
x, y, w, h := r[0], r[1], r[2], r[3]
var rectBuf []byte
if s.useZlib && s.zlib != nil {
rectBuf = encodeZlibRect(img, s.pf, x, y, w, h, s.zlib.w, s.zlib.buf)
// encodeZlibRect includes its own FBUpdate header for 1 rect.
// For multi-rect, we need just the rect data without the FBUpdate header.
// Skip the 4-byte FBUpdate header since we already sent ours.
rectBuf = rectBuf[4:]
} else {
rectBuf = encodeRawRect(img, s.pf, x, y, w, h)
rectBuf = rectBuf[4:] // skip FBUpdate header
}
if _, err := s.conn.Write(rectBuf); err != nil {
return err
}
}
return nil
}
func (s *session) handleKeyEvent() error {
var data [7]byte
if _, err := io.ReadFull(s.conn, data[:]); err != nil {
return fmt.Errorf("read KeyEvent: %w", err)
}
down := data[0] == 1
keysym := binary.BigEndian.Uint32(data[3:7])
s.injector.InjectKey(keysym, down)
return nil
}
func (s *session) handlePointerEvent() error {
var data [5]byte
if _, err := io.ReadFull(s.conn, data[:]); err != nil {
return fmt.Errorf("read PointerEvent: %w", err)
}
buttonMask := data[0]
x := int(binary.BigEndian.Uint16(data[1:3]))
y := int(binary.BigEndian.Uint16(data[3:5]))
s.injector.InjectPointer(buttonMask, x, y, s.serverW, s.serverH)
return nil
}
func (s *session) handleCutText() error {
var header [7]byte // 3 padding + 4 length
if _, err := io.ReadFull(s.conn, header[:]); err != nil {
return fmt.Errorf("read CutText header: %w", err)
}
length := binary.BigEndian.Uint32(header[3:7])
if length > maxCutTextBytes {
return fmt.Errorf("cut text too large: %d bytes", length)
}
buf := make([]byte, length)
if _, err := io.ReadFull(s.conn, buf); err != nil {
return fmt.Errorf("read CutText payload: %w", err)
}
s.injector.SetClipboard(string(buf))
return nil
}
// sendServerCutText sends clipboard text from the server to the client.
func (s *session) sendServerCutText(text string) error {
data := []byte(text)
buf := make([]byte, 8+len(data))
buf[0] = serverCutText
// buf[1:4] = padding (zero)
binary.BigEndian.PutUint32(buf[4:8], uint32(len(data)))
copy(buf[8:], data)
s.writeMu.Lock()
_, err := s.conn.Write(buf)
s.writeMu.Unlock()
return err
}
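The fixed-header framing used by sendServerCutText — message type, zero padding, big-endian length, then the payload — can be shown as a standalone sketch. `cutTextMessage` is a hypothetical helper for illustration; the type value 3 is the RFB ServerCutText message type:

```go
package main

import (
	"encoding/binary"
	"fmt"
)

const serverCutText = 3 // RFB ServerCutText message type

// cutTextMessage builds a ServerCutText message:
// type(1) + padding(3) + big-endian length(4) + text.
func cutTextMessage(text string) []byte {
	data := []byte(text)
	buf := make([]byte, 8+len(data))
	buf[0] = serverCutText
	// buf[1:4] stays zero (padding).
	binary.BigEndian.PutUint32(buf[4:8], uint32(len(data)))
	copy(buf[8:], data)
	return buf
}

func main() {
	msg := cutTextMessage("hi")
	fmt.Printf("% x\n", msg) // 03 00 00 00 00 00 00 02 68 69
}
```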


@@ -1,79 +0,0 @@
//go:build !windows
package server
import (
"fmt"
"os"
"strings"
"syscall"
log "github.com/sirupsen/logrus"
)
// ShutdownState tracks VNC virtual session processes for crash recovery.
// Persisted by the state manager; on restart, residual processes are killed.
type ShutdownState struct {
// Processes maps a description to its PID (e.g., "xvfb:50" -> 1234).
Processes map[string]int `json:"processes,omitempty"`
}
// Name returns the state name for the state manager.
func (s *ShutdownState) Name() string {
return "vnc_sessions_state"
}
// Cleanup kills any residual VNC session processes left from a crash.
func (s *ShutdownState) Cleanup() error {
if len(s.Processes) == 0 {
return nil
}
for desc, pid := range s.Processes {
if pid <= 0 {
continue
}
if !isOurProcess(pid, desc) {
log.Debugf("cleanup: skipping PID %d (%s), not ours", pid, desc)
continue
}
log.Infof("cleanup: killing residual process %d (%s)", pid, desc)
// Kill the process group (negative PID) to get children too.
if err := syscall.Kill(-pid, syscall.SIGTERM); err != nil {
// Try individual process if group kill fails.
syscall.Kill(pid, syscall.SIGKILL)
}
}
s.Processes = nil
return nil
}
// isOurProcess verifies the PID still belongs to a VNC-related process
// by checking /proc/<pid>/cmdline (Linux) or the process name.
func isOurProcess(pid int, desc string) bool {
// Check if the process exists at all.
if err := syscall.Kill(pid, 0); err != nil {
return false
}
// On Linux, verify via /proc cmdline.
cmdline, err := os.ReadFile(fmt.Sprintf("/proc/%d/cmdline", pid))
if err != nil {
// No /proc (FreeBSD): trust the PID if the process exists.
// PID reuse is unlikely in the short window between crash and restart.
return true
}
cmd := string(cmdline)
// Match against expected process types.
if strings.Contains(desc, "xvfb") || strings.Contains(desc, "xorg") {
return strings.Contains(cmd, "Xvfb") || strings.Contains(cmd, "Xorg")
}
if strings.Contains(desc, "desktop") {
return strings.Contains(cmd, "session") || strings.Contains(cmd, "plasma") ||
strings.Contains(cmd, "gnome") || strings.Contains(cmd, "xfce") ||
strings.Contains(cmd, "dbus-launch")
}
return false
}


@@ -1,37 +0,0 @@
package server
import (
"fmt"
"image"
)
const maxCapturerRetries = 5
// StubCapturer is a placeholder for platforms without screen capture support.
type StubCapturer struct{}
// Width returns 0 on unsupported platforms.
func (c *StubCapturer) Width() int { return 0 }
// Height returns 0 on unsupported platforms.
func (c *StubCapturer) Height() int { return 0 }
// Capture returns an error on unsupported platforms.
func (c *StubCapturer) Capture() (*image.RGBA, error) {
return nil, fmt.Errorf("screen capture not supported on this platform")
}
// StubInputInjector is a placeholder for platforms without input injection support.
type StubInputInjector struct{}
// InjectKey is a no-op on unsupported platforms.
func (s *StubInputInjector) InjectKey(_ uint32, _ bool) {}
// InjectPointer is a no-op on unsupported platforms.
func (s *StubInputInjector) InjectPointer(_ uint8, _, _, _, _ int) {}
// SetClipboard is a no-op on unsupported platforms.
func (s *StubInputInjector) SetClipboard(_ string) {}
// GetClipboard returns empty on unsupported platforms.
func (s *StubInputInjector) GetClipboard() string { return "" }


@@ -1,19 +0,0 @@
//go:build windows
package server
import "unsafe"
// swizzleBGRAtoRGBA swaps B and R channels in a BGRA pixel buffer in-place.
// Operates on uint32 words for throughput: one read-modify-write per pixel.
func swizzleBGRAtoRGBA(pix []byte) {
n := len(pix) / 4
pixels := unsafe.Slice((*uint32)(unsafe.Pointer(&pix[0])), n)
for i := range n {
p := pixels[i]
// Little-endian BGRA in memory (B,G,R,A bytes) loads as p = 0xAARRGGBB.
// We want RGBA in memory (R,G,B,A bytes), i.e. p = 0xAABBGGRR:
// swap byte 0 (B) and byte 2 (R), keep byte 1 (G) and byte 3 (A).
pixels[i] = (p & 0xFF00FF00) | ((p & 0x00FF0000) >> 16) | ((p & 0x000000FF) << 16)
}
}
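For reference, the same channel swap can be written byte-wise without `unsafe`; this portable sketch (hypothetical `swizzleBGRAtoRGBAPortable`, not part of the package) is slower than the word-at-a-time version above but makes the B↔R exchange explicit:

```go
package main

import "fmt"

// swizzleBGRAtoRGBAPortable swaps the B and R channels of each 4-byte pixel
// in place, leaving G and A untouched.
func swizzleBGRAtoRGBAPortable(pix []byte) {
	for i := 0; i+3 < len(pix); i += 4 {
		pix[i], pix[i+2] = pix[i+2], pix[i] // B <-> R
	}
}

func main() {
	// One pixel: B=0x10 G=0x20 R=0x30 A=0xFF
	px := []byte{0x10, 0x20, 0x30, 0xFF}
	swizzleBGRAtoRGBAPortable(px)
	fmt.Printf("% x\n", px) // 30 20 10 ff
}
```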


@@ -1,634 +0,0 @@
//go:build (linux && !android) || freebsd
package server
import (
"fmt"
"os"
"os/exec"
"os/user"
"path/filepath"
"strconv"
"strings"
"sync"
"syscall"
"time"
log "github.com/sirupsen/logrus"
)
// VirtualSession manages a virtual X11 display (Xvfb) with a desktop session
// running as a target user. It implements ScreenCapturer and InputInjector by
// delegating to an X11Capturer/X11InputInjector pointed at the virtual display.
const sessionIdleTimeout = 5 * time.Minute
type VirtualSession struct {
mu sync.Mutex
display string
user *user.User
uid uint32
gid uint32
groups []uint32
xvfb *exec.Cmd
desktop *exec.Cmd
poller *X11Poller
injector *X11InputInjector
log *log.Entry
stopped bool
clients int
idleTimer *time.Timer
onIdle func() // called when idle timeout fires or Xvfb dies
}
// StartVirtualSession creates and starts a virtual X11 session for the given user.
// Requires root privileges to create sessions as other users.
func StartVirtualSession(username string, logger *log.Entry) (*VirtualSession, error) {
if os.Getuid() != 0 {
return nil, fmt.Errorf("virtual sessions require root privileges")
}
if _, err := exec.LookPath("Xvfb"); err != nil {
if _, err := exec.LookPath("Xorg"); err != nil {
return nil, fmt.Errorf("neither Xvfb nor Xorg found (install xvfb or xserver-xorg)")
}
if !hasDummyDriver() {
return nil, fmt.Errorf("Xvfb not found and Xorg dummy driver not installed (install xvfb or xf86-video-dummy)")
}
}
u, err := user.Lookup(username)
if err != nil {
return nil, fmt.Errorf("lookup user %s: %w", username, err)
}
uid, err := strconv.ParseUint(u.Uid, 10, 32)
if err != nil {
return nil, fmt.Errorf("parse uid: %w", err)
}
gid, err := strconv.ParseUint(u.Gid, 10, 32)
if err != nil {
return nil, fmt.Errorf("parse gid: %w", err)
}
groups, err := supplementaryGroups(u)
if err != nil {
logger.Debugf("supplementary groups for %s: %v", username, err)
}
vs := &VirtualSession{
user: u,
uid: uint32(uid),
gid: uint32(gid),
groups: groups,
log: logger.WithField("vnc_user", username),
}
if err := vs.start(); err != nil {
return nil, err
}
return vs, nil
}
func (vs *VirtualSession) start() error {
display, err := findFreeDisplay()
if err != nil {
return fmt.Errorf("find free display: %w", err)
}
vs.display = display
if err := vs.startXvfb(); err != nil {
return err
}
socketPath := fmt.Sprintf("/tmp/.X11-unix/X%s", vs.display[1:])
if err := waitForPath(socketPath, 5*time.Second); err != nil {
vs.stopXvfb()
return fmt.Errorf("wait for X11 socket %s: %w", socketPath, err)
}
// Grant the target user access to the display via xhost.
xhostCmd := exec.Command("xhost", "+SI:localuser:"+vs.user.Username)
xhostCmd.Env = []string{"DISPLAY=" + vs.display}
if out, err := xhostCmd.CombinedOutput(); err != nil {
vs.log.Debugf("xhost: %s (%v)", strings.TrimSpace(string(out)), err)
}
vs.poller = NewX11Poller(vs.display)
injector, err := NewX11InputInjector(vs.display)
if err != nil {
vs.stopXvfb()
return fmt.Errorf("create X11 injector for %s: %w", vs.display, err)
}
vs.injector = injector
if err := vs.startDesktop(); err != nil {
vs.injector.Close()
vs.stopXvfb()
return fmt.Errorf("start desktop: %w", err)
}
vs.log.Infof("virtual session started: display=%s user=%s", vs.display, vs.user.Username)
return nil
}
// ClientConnect increments the client count and cancels any idle timer.
func (vs *VirtualSession) ClientConnect() {
vs.mu.Lock()
defer vs.mu.Unlock()
vs.clients++
if vs.idleTimer != nil {
vs.idleTimer.Stop()
vs.idleTimer = nil
}
}
// ClientDisconnect decrements the client count. When the last client
// disconnects, starts an idle timer that destroys the session.
func (vs *VirtualSession) ClientDisconnect() {
vs.mu.Lock()
defer vs.mu.Unlock()
vs.clients--
if vs.clients <= 0 {
vs.clients = 0
vs.log.Infof("no VNC clients connected, session will be destroyed in %s", sessionIdleTimeout)
vs.idleTimer = time.AfterFunc(sessionIdleTimeout, vs.idleExpired)
}
}
// idleExpired is called by the idle timer. It stops the session and
// notifies the session manager via onIdle so it removes us from the map.
func (vs *VirtualSession) idleExpired() {
vs.log.Info("idle timeout reached, destroying virtual session")
vs.Stop()
// onIdle acquires sessionManager.mu; safe because Stop() has released vs.mu.
if vs.onIdle != nil {
vs.onIdle()
}
}
// isAlive returns true if the session is running and its X server socket exists.
func (vs *VirtualSession) isAlive() bool {
vs.mu.Lock()
stopped := vs.stopped
display := vs.display
vs.mu.Unlock()
if stopped {
return false
}
// Verify the X socket still exists on disk.
socketPath := fmt.Sprintf("/tmp/.X11-unix/X%s", display[1:])
if _, err := os.Stat(socketPath); err != nil {
return false
}
return true
}
// Capturer returns the screen capturer for this virtual session.
func (vs *VirtualSession) Capturer() ScreenCapturer {
return vs.poller
}
// Injector returns the input injector for this virtual session.
func (vs *VirtualSession) Injector() InputInjector {
return vs.injector
}
// Display returns the X11 display string (e.g., ":99").
func (vs *VirtualSession) Display() string {
return vs.display
}
// Stop terminates the virtual session, killing the desktop and Xvfb.
func (vs *VirtualSession) Stop() {
vs.mu.Lock()
defer vs.mu.Unlock()
if vs.stopped {
return
}
vs.stopped = true
if vs.injector != nil {
vs.injector.Close()
}
vs.stopDesktop()
vs.stopXvfb()
vs.log.Info("virtual session stopped")
}
func (vs *VirtualSession) startXvfb() error {
if _, err := exec.LookPath("Xvfb"); err == nil {
return vs.startXvfbDirect()
}
return vs.startXorgDummy()
}
func (vs *VirtualSession) startXvfbDirect() error {
vs.xvfb = exec.Command("Xvfb", vs.display,
"-screen", "0", "1280x800x24",
"-ac",
"-nolisten", "tcp",
)
vs.xvfb.SysProcAttr = &syscall.SysProcAttr{Setsid: true, Pdeathsig: syscall.SIGTERM}
if err := vs.xvfb.Start(); err != nil {
return fmt.Errorf("start Xvfb on %s: %w", vs.display, err)
}
vs.log.Infof("Xvfb started on %s (pid=%d)", vs.display, vs.xvfb.Process.Pid)
go vs.monitorXvfb()
return nil
}
// startXorgDummy starts Xorg with the dummy video driver as a fallback when
// Xvfb is not installed. Most systems with a desktop have Xorg available.
func (vs *VirtualSession) startXorgDummy() error {
confPath := fmt.Sprintf("/tmp/nbvnc-dummy-%s.conf", vs.display[1:])
conf := `Section "Device"
Identifier "dummy"
Driver "dummy"
VideoRam 256000
EndSection
Section "Screen"
Identifier "screen"
Device "dummy"
DefaultDepth 24
SubSection "Display"
Depth 24
Modes "1280x800"
EndSubSection
EndSection
`
if err := os.WriteFile(confPath, []byte(conf), 0644); err != nil {
return fmt.Errorf("write Xorg dummy config: %w", err)
}
vs.xvfb = exec.Command("Xorg", vs.display,
"-config", confPath,
"-noreset",
"-nolisten", "tcp",
"-ac",
)
vs.xvfb.SysProcAttr = &syscall.SysProcAttr{Setsid: true, Pdeathsig: syscall.SIGTERM}
if err := vs.xvfb.Start(); err != nil {
os.Remove(confPath)
return fmt.Errorf("start Xorg dummy on %s: %w", vs.display, err)
}
vs.log.Infof("Xorg (dummy driver) started on %s (pid=%d)", vs.display, vs.xvfb.Process.Pid)
go func() {
vs.monitorXvfb()
os.Remove(confPath)
}()
return nil
}
// monitorXvfb waits for the Xvfb/Xorg process to exit. If it exits
// unexpectedly (not via Stop), the session is marked as dead and the
// onIdle callback fires so the session manager removes it from the map.
// The next GetOrCreate call for this user will create a fresh session.
func (vs *VirtualSession) monitorXvfb() {
if err := vs.xvfb.Wait(); err != nil {
vs.log.Debugf("X server exited: %v", err)
}
vs.mu.Lock()
alreadyStopped := vs.stopped
if !alreadyStopped {
vs.log.Warn("X server exited unexpectedly, marking session as dead")
vs.stopped = true
if vs.idleTimer != nil {
vs.idleTimer.Stop()
vs.idleTimer = nil
}
if vs.injector != nil {
vs.injector.Close()
}
vs.stopDesktop()
}
onIdle := vs.onIdle
vs.mu.Unlock()
if !alreadyStopped && onIdle != nil {
onIdle()
}
}
func (vs *VirtualSession) stopXvfb() {
if vs.xvfb != nil && vs.xvfb.Process != nil {
syscall.Kill(-vs.xvfb.Process.Pid, syscall.SIGTERM)
time.Sleep(200 * time.Millisecond)
syscall.Kill(-vs.xvfb.Process.Pid, syscall.SIGKILL)
}
}
func (vs *VirtualSession) startDesktop() error {
session := detectDesktopSession()
// Wrap the desktop command with dbus-launch to provide a session bus.
// Without this, most desktop environments (XFCE, MATE, etc.) fail immediately.
var args []string
if _, err := exec.LookPath("dbus-launch"); err == nil {
args = append([]string{"dbus-launch", "--exit-with-session"}, session...)
} else {
args = session
}
vs.desktop = exec.Command(args[0], args[1:]...)
vs.desktop.Dir = vs.user.HomeDir
vs.desktop.Env = vs.buildUserEnv()
vs.desktop.SysProcAttr = &syscall.SysProcAttr{
Credential: &syscall.Credential{
Uid: vs.uid,
Gid: vs.gid,
Groups: vs.groups,
},
Setsid: true,
Pdeathsig: syscall.SIGTERM,
}
if err := vs.desktop.Start(); err != nil {
return fmt.Errorf("start desktop session (%v): %w", args, err)
}
vs.log.Infof("desktop session started: %v (pid=%d)", args, vs.desktop.Process.Pid)
go func() {
if err := vs.desktop.Wait(); err != nil {
vs.log.Debugf("desktop session exited: %v", err)
}
}()
return nil
}
func (vs *VirtualSession) stopDesktop() {
if vs.desktop != nil && vs.desktop.Process != nil {
syscall.Kill(-vs.desktop.Process.Pid, syscall.SIGTERM)
time.Sleep(200 * time.Millisecond)
syscall.Kill(-vs.desktop.Process.Pid, syscall.SIGKILL)
}
}
func (vs *VirtualSession) buildUserEnv() []string {
return []string{
"DISPLAY=" + vs.display,
"HOME=" + vs.user.HomeDir,
"USER=" + vs.user.Username,
"LOGNAME=" + vs.user.Username,
"SHELL=" + getUserShell(vs.user.Uid),
"PATH=/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin",
"XDG_RUNTIME_DIR=/run/user/" + vs.user.Uid,
"DBUS_SESSION_BUS_ADDRESS=unix:path=/run/user/" + vs.user.Uid + "/bus",
}
}
// detectDesktopSession discovers available desktop sessions from the standard
// /usr/share/xsessions/*.desktop files (FreeDesktop standard, used by all
// display managers). Falls back to a hardcoded list if no .desktop files found.
func detectDesktopSession() []string {
// Scan xsessions directories (Linux: /usr/share, FreeBSD: /usr/local/share).
for _, dir := range []string{"/usr/share/xsessions", "/usr/local/share/xsessions"} {
if cmd := findXSession(dir); cmd != nil {
return cmd
}
}
// Fallback: try common session commands directly.
fallbacks := [][]string{
{"startplasma-x11"},
{"gnome-session"},
{"xfce4-session"},
{"mate-session"},
{"cinnamon-session"},
{"openbox-session"},
{"xterm"},
}
for _, s := range fallbacks {
if _, err := exec.LookPath(s[0]); err == nil {
return s
}
}
return []string{"xterm"}
}
// sessionPriority defines preference order for desktop environments.
// Lower number = higher priority. Unknown sessions get 100.
var sessionPriority = map[string]int{
"plasma": 1, // KDE
"gnome": 2,
"xfce": 3,
"mate": 4,
"cinnamon": 5,
"lxqt": 6,
"lxde": 7,
"budgie": 8,
"openbox": 20,
"fluxbox": 21,
"i3": 22,
"xinit": 50, // generic user session
"lightdm": 50,
"default": 50,
}
func findXSession(dir string) []string {
entries, err := os.ReadDir(dir)
if err != nil {
return nil
}
type candidate struct {
cmd string
priority int
}
var candidates []candidate
for _, e := range entries {
if !strings.HasSuffix(e.Name(), ".desktop") {
continue
}
data, err := os.ReadFile(filepath.Join(dir, e.Name()))
if err != nil {
continue
}
execCmd := ""
for _, line := range strings.Split(string(data), "\n") {
if strings.HasPrefix(line, "Exec=") {
execCmd = strings.TrimSpace(strings.TrimPrefix(line, "Exec="))
break
}
}
if execCmd == "" || execCmd == "default" {
continue
}
// Determine priority from the filename or exec command.
pri := 100
lower := strings.ToLower(e.Name() + " " + execCmd)
for keyword, p := range sessionPriority {
if strings.Contains(lower, keyword) && p < pri {
pri = p
}
}
candidates = append(candidates, candidate{cmd: execCmd, priority: pri})
}
if len(candidates) == 0 {
return nil
}
// Pick the highest priority (lowest number).
best := candidates[0]
for _, c := range candidates[1:] {
if c.priority < best.priority {
best = c
}
}
// Verify the binary exists.
parts := strings.Fields(best.cmd)
if _, err := exec.LookPath(parts[0]); err != nil {
return nil
}
return parts
}
// findFreeDisplay scans for an unused X11 display number.
func findFreeDisplay() (string, error) {
for n := 50; n < 200; n++ {
lockFile := fmt.Sprintf("/tmp/.X%d-lock", n)
socketFile := fmt.Sprintf("/tmp/.X11-unix/X%d", n)
if _, err := os.Stat(lockFile); err == nil {
continue
}
if _, err := os.Stat(socketFile); err == nil {
continue
}
return fmt.Sprintf(":%d", n), nil
}
return "", fmt.Errorf("no free X11 display found (checked :50-:199)")
}
// waitForPath polls until a filesystem path exists or the timeout expires.
func waitForPath(path string, timeout time.Duration) error {
deadline := time.Now().Add(timeout)
for time.Now().Before(deadline) {
if _, err := os.Stat(path); err == nil {
return nil
}
time.Sleep(50 * time.Millisecond)
}
return fmt.Errorf("timeout waiting for %s", path)
}
// getUserShell returns the login shell for the given UID.
func getUserShell(uid string) string {
data, err := os.ReadFile("/etc/passwd")
if err != nil {
return "/bin/sh"
}
for _, line := range strings.Split(string(data), "\n") {
fields := strings.Split(line, ":")
if len(fields) >= 7 && fields[2] == uid {
return fields[6]
}
}
return "/bin/sh"
}
// supplementaryGroups returns the supplementary group IDs for a user.
func supplementaryGroups(u *user.User) ([]uint32, error) {
gids, err := u.GroupIds()
if err != nil {
return nil, err
}
var groups []uint32
for _, g := range gids {
id, err := strconv.ParseUint(g, 10, 32)
if err != nil {
continue
}
groups = append(groups, uint32(id))
}
return groups, nil
}
// sessionManager tracks active virtual sessions by username.
type sessionManager struct {
mu sync.Mutex
sessions map[string]*VirtualSession
log *log.Entry
}
func newSessionManager(logger *log.Entry) *sessionManager {
return &sessionManager{
sessions: make(map[string]*VirtualSession),
log: logger,
}
}
// GetOrCreate returns an existing virtual session or creates a new one.
// If a previous session for this user is stopped or its X server died, it is replaced.
func (sm *sessionManager) GetOrCreate(username string) (vncSession, error) {
sm.mu.Lock()
defer sm.mu.Unlock()
if vs, ok := sm.sessions[username]; ok {
if vs.isAlive() {
return vs, nil
}
sm.log.Infof("replacing dead virtual session for %s", username)
vs.Stop()
delete(sm.sessions, username)
}
vs, err := StartVirtualSession(username, sm.log)
if err != nil {
return nil, err
}
vs.onIdle = func() {
sm.mu.Lock()
defer sm.mu.Unlock()
if cur, ok := sm.sessions[username]; ok && cur == vs {
delete(sm.sessions, username)
sm.log.Infof("removed idle virtual session for %s", username)
}
}
sm.sessions[username] = vs
return vs, nil
}
// hasDummyDriver checks common paths for the Xorg dummy video driver.
func hasDummyDriver() bool {
paths := []string{
"/usr/lib/xorg/modules/drivers/dummy_drv.so", // Debian/Ubuntu
"/usr/lib64/xorg/modules/drivers/dummy_drv.so", // RHEL/Fedora
"/usr/local/lib/xorg/modules/drivers/dummy_drv.so", // FreeBSD
"/usr/lib/x86_64-linux-gnu/xorg/modules/drivers/dummy_drv.so", // Debian multiarch
}
for _, p := range paths {
if _, err := os.Stat(p); err == nil {
return true
}
}
return false
}
// StopAll terminates all active virtual sessions.
func (sm *sessionManager) StopAll() {
sm.mu.Lock()
defer sm.mu.Unlock()
for username, vs := range sm.sessions {
vs.Stop()
delete(sm.sessions, username)
sm.log.Infof("stopped virtual session for %s", username)
}
}



@@ -1,50 +0,0 @@
package server
import (
_ "embed"
"fmt"
"net"
"net/http"
"os"
)
//go:embed webplayer.html
var webPlayerHTML []byte
// ServeWebPlayer starts a local HTTP server that serves the recording file
// and an HTML player page. Returns the URL to open.
func ServeWebPlayer(recPath, listenAddr string) (string, error) {
if listenAddr == "" {
listenAddr = "localhost:0"
}
ln, err := net.Listen("tcp", listenAddr)
if err != nil {
return "", fmt.Errorf("listen: %w", err)
}
mux := http.NewServeMux()
mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/html; charset=utf-8")
w.Write(webPlayerHTML) //nolint:errcheck
})
mux.HandleFunc("/recording.rec", func(w http.ResponseWriter, r *http.Request) {
f, err := os.Open(recPath)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
defer f.Close()
fi, err := f.Stat()
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
w.Header().Set("Content-Type", "application/octet-stream")
http.ServeContent(w, r, "recording.rec", fi.ModTime(), f)
})
url := fmt.Sprintf("http://%s", ln.Addr())
go http.Serve(ln, mux) //nolint:errcheck
return url, nil
}


@@ -1,291 +0,0 @@
<!DOCTYPE html>
<html>
<head>
<title>NetBird - VNC Session Recording</title>
<style>
* { box-sizing: border-box; margin: 0; padding: 0; }
body { background: #0d1117; color: #e6edf3; font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', sans-serif; display: flex; flex-direction: column; height: 100vh; }
#header { background: #161b22; padding: 10px 20px; display: flex; align-items: center; gap: 16px; border-bottom: 1px solid #30363d; flex-wrap: wrap; }
.logo { display: flex; align-items: center; gap: 8px; }
.logo svg { width: 20px; height: 20px; }
.logo-text { font-weight: 600; font-size: 14px; color: #f0f6fc; }
.logo-badge { font-size: 11px; background: #f4722b; color: #fff; padding: 1px 7px; border-radius: 10px; font-weight: 500; }
#rec-info { font-size: 12px; color: #8b949e; display: flex; gap: 6px; flex-wrap: wrap; }
#rec-info span { background: #21262d; padding: 2px 8px; border-radius: 4px; }
#rec-info .label { color: #6e7681; }
#controls { background: #161b22; padding: 6px 20px; display: flex; align-items: center; gap: 10px; border-bottom: 1px solid #30363d; }
#controls button { background: #21262d; color: #e6edf3; border: 1px solid #30363d; padding: 4px 14px; border-radius: 6px; cursor: pointer; font-size: 14px; min-width: 36px; }
#controls button:hover { background: #30363d; border-color: #8b949e; }
#controls button.active { background: #f4722b; border-color: #f4722b; color: #fff; }
#seek { flex: 1; cursor: pointer; accent-color: #f4722b; height: 4px; margin: 0; padding: 0; }
#speed-select { background: #21262d; color: #e6edf3; border: 1px solid #30363d; padding: 3px 6px; border-radius: 4px; font-size: 12px; }
#time { font-size: 12px; font-variant-numeric: tabular-nums; min-width: 90px; text-align: center; color: #8b949e; }
#frame-info { font-size: 11px; color: #6e7681; }
#canvas-wrap { flex: 1; display: flex; align-items: center; justify-content: center; overflow: hidden; background: #010409; }
canvas { max-width: 100%; max-height: 100%; }
#footer { background: #161b22; padding: 5px 20px; font-size: 11px; color: #484f58; border-top: 1px solid #30363d; display: flex; justify-content: space-between; }
#status { font-size: 12px; color: #8b949e; }
</style>
</head>
<body>
<div id="header">
<div class="logo">
<svg width="24" height="18" viewBox="0 0 31 23" fill="none"><path d="M21.4631 0.523438C17.8173 0.857913 16.0028 2.95675 15.3171 4.01871L4.66406 22.4734H17.5163L30.1929 0.523438H21.4631Z" fill="#F68330"/><path d="M17.5265 22.4737L0 3.88525C0 3.88525 19.8177 -1.44128 21.7493 15.1738L17.5265 22.4737Z" fill="#F68330"/><path d="M14.9236 4.70563L9.54688 14.0208L17.5158 22.4747L21.7385 15.158C21.0696 9.44682 18.2851 6.32784 14.9236 4.69727" fill="#F05252"/></svg>
<span class="logo-text">NetBird</span>
<span class="logo-badge">VNC Session Recording</span>
</div>
<div id="rec-info"></div>
</div>
<div id="controls">
<button id="playBtn" onclick="togglePlay()" title="Space">&#9654;</button>
<input type="range" id="seek" min="0" max="1000" value="0" oninput="seekTo(this.value)">
<span id="time">0:00 / 0:00</span>
<select id="speed-select" onchange="setSpeed(this.value)" title="Playback speed">
<option value="0.25">0.25x</option>
<option value="0.5">0.5x</option>
<option value="1" selected>1x</option>
<option value="2">2x</option>
<option value="4">4x</option>
<option value="8">8x</option>
</select>
<span id="frame-info"></span>
<span id="status">Loading...</span>
</div>
<div id="canvas-wrap"><canvas id="canvas"></canvas></div>
<div id="footer">
<span>Space: play/pause | Left/Right: seek 5s | Scroll: speed</span>
<span id="file-info"></span>
</div>
<script>
const canvas = document.getElementById('canvas');
const ctx = canvas.getContext('2d');
const seekBar = document.getElementById('seek');
const timeEl = document.getElementById('time');
const statusEl = document.getElementById('status');
const recInfoEl = document.getElementById('rec-info');
const frameInfoEl = document.getElementById('frame-info');
const fileInfoEl = document.getElementById('file-info');
const playBtn = document.getElementById('playBtn');
let frames = []; // { offsetMs, bitmap }
let header = null;
let playing = false;
let speed = 1;
let startTime = 0;
let pauseOffset = 0;
let currentFrame = 0;
let animId = null;
let durationMs = 0;
function fmt(ms) {
const s = Math.floor(ms / 1000);
const m = Math.floor(s / 60);
const h = Math.floor(m / 60);
if (h > 0) return `${h}:${String(m % 60).padStart(2, '0')}:${String(s % 60).padStart(2, '0')}`;
return `${m}:${String(s % 60).padStart(2, '0')}`;
}
function fmtSize(bytes) {
if (bytes >= 1048576) return (bytes / 1048576).toFixed(1) + ' MB';
if (bytes >= 1024) return (bytes / 1024).toFixed(1) + ' KB';
return bytes + ' B';
}
async function load() {
statusEl.textContent = 'Fetching...';
const resp = await fetch('/recording.rec');
if (!resp.ok) { statusEl.textContent = `Fetch failed: ${resp.status}`; return; }
const buf = await resp.arrayBuffer();
const view = new DataView(buf);
const magic = new TextDecoder().decode(new Uint8Array(buf, 0, 6));
if (magic !== 'NBVNC\x01') { statusEl.textContent = 'Invalid recording file'; return; }
const width = view.getUint16(6);
const height = view.getUint16(8);
const startMs = Number(view.getBigUint64(10));
const metaLen = view.getUint32(18);
const metaJSON = new TextDecoder().decode(new Uint8Array(buf, 22, metaLen));
const meta = JSON.parse(metaJSON);
header = { width, height, startMs, meta };
canvas.width = width;
canvas.height = height;
const dateStr = new Date(startMs).toLocaleString();
const parts = [];
if (meta.mode) parts.push(`<span><span class="label">Type:</span> vnc (${meta.mode})</span>`);
if (meta.remote_addr) parts.push(`<span><span class="label">Remote:</span> ${meta.remote_addr}</span>`);
if (meta.jwt_user) parts.push(`<span><span class="label">JWT:</span> ${meta.jwt_user}</span>`);
if (meta.user) parts.push(`<span><span class="label">User:</span> ${meta.user}</span>`);
parts.push(`<span><span class="label">Date:</span> ${dateStr}</span>`);
parts.push(`<span>${width}x${height}</span>`);
recInfoEl.innerHTML = parts.join('');
fileInfoEl.textContent = fmtSize(buf.byteLength);
document.title = `NetBird - VNC Session Recording - ${meta.remote_addr || ''}`;
// Parse frame offsets (fast pass, no decoding)
const rawFrames = [];
let offset = 22 + metaLen;
while (offset + 8 <= buf.byteLength) {
const offsetMs = view.getUint32(offset);
const pngLen = view.getUint32(offset + 4);
offset += 8;
if (offset + pngLen > buf.byteLength) break;
rawFrames.push({ offsetMs, start: offset, length: pngLen });
offset += pngLen;
}
if (rawFrames.length === 0) { statusEl.textContent = 'No frames in recording'; return; }
// Handle encrypted recordings
let decryptKey = null;
if (meta.encrypted) {
const privKeyB64 = prompt('This recording is encrypted.\nPaste your base64-encoded X25519 private key:');
if (!privKeyB64) { statusEl.textContent = 'Decryption key required'; return; }
try {
decryptKey = await deriveDecryptKey(privKeyB64, meta.ephemeral_key);
} catch (e) {
statusEl.textContent = 'Decryption key error: ' + e.message;
return;
}
}
// Decode PNG frames to ImageData. We decode bitmaps in parallel batches,
// then draw them sequentially to avoid OffscreenCanvas races.
const offscreen = new OffscreenCanvas(width, height);
const offCtx = offscreen.getContext('2d');
const batchSize = 20;
for (let i = 0; i < rawFrames.length; i += batchSize) {
const batch = rawFrames.slice(i, i + batchSize);
const bitmaps = await Promise.all(batch.map(async (f, batchIdx) => {
let pngData = new Uint8Array(buf, f.start, f.length);
if (decryptKey) {
const frameIdx = i + batchIdx;
pngData = await decryptFrame(decryptKey, pngData, frameIdx);
}
const blob = new Blob([pngData], { type: 'image/png' });
return createImageBitmap(blob);
}));
for (let j = 0; j < bitmaps.length; j++) {
offCtx.drawImage(bitmaps[j], 0, 0);
bitmaps[j].close();
frames.push({ offsetMs: batch[j].offsetMs, imgData: offCtx.getImageData(0, 0, width, height) });
}
statusEl.textContent = `Loading ${frames.length}/${rawFrames.length}`;
if (i % (batchSize * 3) === 0) await new Promise(r => setTimeout(r, 0));
}
const firstMs = frames[0].offsetMs;
durationMs = frames[frames.length - 1].offsetMs;
seekBar.min = firstMs;
seekBar.max = durationMs;
timeEl.textContent = `0:00 / ${fmt(durationMs)}`;
statusEl.textContent = `${frames.length} frames, ${fmt(durationMs)}`;
renderFrame(0);
}
function renderFrame(idx) {
if (idx < 0 || idx >= frames.length) return;
currentFrame = idx;
const frame = frames[idx];
ctx.putImageData(frame.imgData, 0, 0);
seekBar.value = frame.offsetMs;
timeEl.textContent = `${fmt(frame.offsetMs)} / ${fmt(durationMs)}`;
frameInfoEl.textContent = `${idx + 1}/${frames.length}`;
}
function togglePlay() { playing ? pause() : play(); }
function play() {
if (frames.length === 0) return;
playing = true;
playBtn.innerHTML = '&#9646;&#9646;';
playBtn.classList.add('active');
if (currentFrame >= frames.length - 1) { currentFrame = 0; pauseOffset = 0; }
startTime = performance.now() - pauseOffset / speed;
tick();
}
function pause() {
playing = false;
playBtn.innerHTML = '&#9654;';
playBtn.classList.remove('active');
if (animId) { cancelAnimationFrame(animId); animId = null; }
pauseOffset = frames[currentFrame].offsetMs;
}
function tick() {
if (!playing) return;
const targetMs = (performance.now() - startTime) * speed;
while (currentFrame < frames.length - 1 && frames[currentFrame + 1].offsetMs <= targetMs) currentFrame++;
renderFrame(currentFrame);
if (currentFrame >= frames.length - 1) { pause(); return; }
animId = requestAnimationFrame(tick);
}
function seekTo(val) {
const ms = parseInt(val);
let idx = 0;
for (let i = 0; i < frames.length; i++) {
if (frames[i].offsetMs <= ms) idx = i; else break;
}
renderFrame(idx);
pauseOffset = frames[idx].offsetMs;
if (playing) startTime = performance.now() - pauseOffset / speed;
}
function setSpeed(val) {
const old = speed;
speed = parseFloat(val);
if (playing) startTime = performance.now() - (performance.now() - startTime) * old / speed;
}
document.addEventListener('keydown', e => {
if (frames.length === 0) return; // recording not loaded yet
if (e.code === 'Space') { e.preventDefault(); togglePlay(); }
if (e.code === 'ArrowRight') seekTo(Math.min(durationMs, frames[currentFrame].offsetMs + 5000));
if (e.code === 'ArrowLeft') seekTo(Math.max(0, frames[currentFrame].offsetMs - 5000));
});
document.addEventListener('wheel', e => {
const sel = document.getElementById('speed-select');
const idx = sel.selectedIndex + (e.deltaY > 0 ? -1 : 1);
if (idx >= 0 && idx < sel.options.length) { sel.selectedIndex = idx; setSpeed(sel.value); }
}, { passive: true });
// Crypto helpers for encrypted recordings (X25519 ECDH + HKDF + AES-256-GCM)
async function deriveDecryptKey(privKeyB64, ephPubB64) {
const privBytes = Uint8Array.from(atob(privKeyB64), c => c.charCodeAt(0));
const ephPubBytes = Uint8Array.from(atob(ephPubB64), c => c.charCodeAt(0));
// WebCrypto imports X25519 private keys only as PKCS#8 ('raw' is for public
// keys), so wrap the 32-byte scalar in the fixed DER prefix for OID 1.3.101.110.
const pkcs8 = new Uint8Array([0x30,0x2e,0x02,0x01,0x00,0x30,0x05,0x06,0x03,0x2b,0x65,0x6e,0x04,0x22,0x04,0x20, ...privBytes]);
const privKey = await crypto.subtle.importKey('pkcs8', pkcs8, { name: 'X25519' }, false, ['deriveBits']);
const ephPub = await crypto.subtle.importKey('raw', ephPubBytes, { name: 'X25519' }, false, []);
const shared = await crypto.subtle.deriveBits({ name: 'X25519', public: ephPub }, privKey, 256);
// HKDF-SHA256 with salt=ephemeralPub, info="netbird-recording" (matches Go side)
const hkdfKey = await crypto.subtle.importKey('raw', shared, 'HKDF', false, ['deriveKey']);
const aesKey = await crypto.subtle.deriveKey(
{ name: 'HKDF', hash: 'SHA-256', salt: ephPubBytes, info: new TextEncoder().encode('netbird-recording') },
hkdfKey,
{ name: 'AES-GCM', length: 256 },
false,
['decrypt'],
);
return aesKey;
}
async function decryptFrame(key, ciphertext, frameIndex) {
const nonce = new Uint8Array(12);
new DataView(nonce.buffer).setUint32(0, frameIndex, true); // little-endian u64, upper 4 bytes zero
const plain = await crypto.subtle.decrypt({ name: 'AES-GCM', iv: nonce }, key, ciphertext);
return new Uint8Array(plain);
}
load().catch(e => { statusEl.textContent = 'Error: ' + e.message; console.error(e); });
</script>
</body>
</html>


@@ -1,174 +0,0 @@
<!DOCTYPE html>
<html>
<head>
<title>VNC Test</title>
<style>
body { margin: 0; background: #111; color: #eee; font-family: monospace; font-size: 13px; }
#toolbar { position: fixed; top: 0; left: 0; right: 0; z-index: 10; background: #222; padding: 4px 8px; display: flex; gap: 8px; align-items: center; }
#toolbar button { padding: 4px 12px; cursor: pointer; background: #444; color: #eee; border: 1px solid #666; border-radius: 3px; }
#toolbar button:hover { background: #555; }
#toolbar #status { flex: 1; }
#vnc-container { width: 100vw; height: calc(100vh - 28px); margin-top: 28px; }
#log { position: fixed; bottom: 0; left: 0; right: 0; max-height: 150px; overflow-y: auto; background: rgba(0,0,0,0.85); padding: 4px 8px; font-size: 11px; z-index: 10; display: none; }
#log.visible { display: block; }
</style>
</head>
<body>
<div id="toolbar">
<span id="status">Loading WASM...</span>
<button onclick="sendCAD()">Ctrl+Alt+Del</button>
<button onclick="document.getElementById('log').classList.toggle('visible')">Log</button>
</div>
<div id="vnc-container"></div>
<div id="log"></div>
<script>
const params = new URLSearchParams(location.search);
const HOST = params.get('host') || '';
const PORT = params.get('port') || '5900';
const MODE = params.get('mode') || 'attach'; // 'attach' or 'session'
const USER = params.get('user') || '';
const SETUP_KEY = params.get('setup_key') || '64BB8FF4-5A96-488F-B0AE-316555E916B0';
const MGMT_URL = params.get('mgmt') || 'http://192.168.100.1:8080';
const statusEl = document.getElementById('status');
const logEl = document.getElementById('log');
function addLog(msg) {
const line = document.createElement('div');
line.textContent = `[${new Date().toISOString().slice(11,23)}] ${msg}`;
logEl.appendChild(line);
logEl.scrollTop = logEl.scrollHeight;
console.log('[vnc-test]', msg);
}
function setStatus(s) { statusEl.textContent = s; addLog(s); }
let rfbInstance = null;
window.sendCAD = () => { if (rfbInstance) { rfbInstance.sendCtrlAltDel(); addLog('Sent Ctrl+Alt+Del'); } };
// VNC WebSocket proxy (bridges noVNC WebSocket API to Go WASM tunnel)
class VNCProxyWS extends EventTarget {
constructor(url) {
super();
this.url = url;
this.readyState = 0;
this.protocol = '';
this.extensions = '';
this.bufferedAmount = 0;
this.binaryType = 'arraybuffer';
this.onopen = null; this.onclose = null; this.onerror = null; this.onmessage = null;
const match = url.match(/vnc\.proxy\.local\/(.+)/);
this._proxyID = match ? match[1] : 'default';
setTimeout(() => this._connect(), 0);
}
get CONNECTING() { return 0; } get OPEN() { return 1; } get CLOSING() { return 2; } get CLOSED() { return 3; }
_connect() {
try {
const handler = window[`handleVNCWebSocket_${this._proxyID}`];
if (!handler) throw new Error(`No VNC handler for ${this._proxyID}`);
handler(this);
this.readyState = 1;
const ev = new Event('open');
if (this.onopen) this.onopen(ev);
this.dispatchEvent(ev);
} catch (err) {
addLog(`WS proxy error: ${err.message}`);
this.readyState = 3;
}
}
receiveFromGo(data) {
const ev = new MessageEvent('message', { data });
if (this.onmessage) this.onmessage(ev);
this.dispatchEvent(ev);
}
send(data) {
if (this.readyState !== 1) return;
let u8;
if (data instanceof ArrayBuffer) u8 = new Uint8Array(data);
else if (data instanceof Uint8Array) u8 = data;
else if (typeof data === 'string') u8 = new TextEncoder().encode(data);
else if (data.buffer) u8 = new Uint8Array(data.buffer, data.byteOffset, data.byteLength);
else return;
if (this.onGoMessage) this.onGoMessage(u8);
}
close(code, reason) {
if (this.readyState >= 2) return;
this.readyState = 2;
if (this.onGoClose) this.onGoClose();
setTimeout(() => {
this.readyState = 3;
const ev = new CloseEvent('close', { code: code||1000, reason: reason||'', wasClean: true });
if (this.onclose) this.onclose(ev);
this.dispatchEvent(ev);
}, 0);
}
}
async function main() {
if (!HOST) { setStatus('Usage: ?host=<peer_ip>&setup_key=<key>[&mode=session&user=alice]'); return; }
// Install WS proxy before anything creates WebSockets
const OrigWS = window.WebSocket;
window.WebSocket = new Proxy(OrigWS, {
construct(target, args) {
if (args[0] && args[0].includes('vnc.proxy.local')) return new VNCProxyWS(args[0]);
return new target(args[0], args[1]);
}
});
// Load WASM
setStatus('Loading WASM runtime...');
await new Promise((resolve, reject) => {
const s = document.createElement('script');
s.src = '/wasm_exec.js'; s.onload = resolve; s.onerror = reject;
document.head.appendChild(s);
});
setStatus('Loading NetBird WASM...');
const go = new Go();
const wasm = await WebAssembly.instantiateStreaming(fetch('/netbird.wasm'), go.importObject);
go.run(wasm.instance);
const t0 = Date.now();
while (!window.NetBirdClient && Date.now() - t0 < 10000) await new Promise(r => setTimeout(r, 100));
if (!window.NetBirdClient) { setStatus('WASM init timeout'); return; }
addLog('WASM ready');
// Connect NetBird with setup key
setStatus('Connecting NetBird...');
let client;
try {
client = await window.NetBirdClient({
setupKey: SETUP_KEY,
managementURL: MGMT_URL,
logLevel: 'debug',
});
addLog('Client created, starting...');
await client.start();
addLog('NetBird connected');
} catch (err) {
setStatus('NetBird error: ' + (err && err.message ? err.message : String(err)));
return;
}
// Create VNC proxy
setStatus(`Creating VNC proxy (mode=${MODE}${USER ? ', user=' + USER : ''})...`);
const proxyURL = await client.createVNCProxy(HOST, PORT, MODE, USER);
addLog(`Proxy: ${proxyURL}`);
// Connect noVNC
setStatus('Connecting VNC...');
const { default: RFB } = await import('/novnc-pkg/core/rfb.js');
const container = document.getElementById('vnc-container');
rfbInstance = new RFB(container, proxyURL, { wsProtocols: [] });
rfbInstance.scaleViewport = true;
rfbInstance.resizeSession = false;
rfbInstance.showDotCursor = true;
rfbInstance.addEventListener('connect', () => setStatus(`Connected: ${HOST}`));
rfbInstance.addEventListener('disconnect', e => setStatus(`Disconnected${e.detail?.clean ? '' : ' (unexpected)'}`));
rfbInstance.addEventListener('credentialsrequired', () => rfbInstance.sendCredentials({ password: '' }));
window.rfb = rfbInstance;
}
main().catch(err => { setStatus(`Error: ${err.message}`); console.error(err); });
</script>
</body>
</html>


@@ -1,44 +0,0 @@
//go:build ignore
// Simple file server for the VNC test page.
// Usage: go run serve.go
// Then open: http://localhost:9090?host=100.0.23.250
package main
import (
"fmt"
"net/http"
"os"
"path/filepath"
)
func main() {
// Serve from the dashboard's public dir (has wasm, noVNC, etc.)
dashboardPublic := os.Getenv("DASHBOARD_PUBLIC")
if dashboardPublic == "" {
home, _ := os.UserHomeDir()
dashboardPublic = filepath.Join(home, "dev", "dashboard", "public")
}
// Serve test page index.html from this directory
testDir, _ := os.Getwd()
mux := http.NewServeMux()
// Test page itself
mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/" || r.URL.Path == "/index.html" {
http.ServeFile(w, r, filepath.Join(testDir, "index.html"))
return
}
// Everything else from dashboard public (wasm, noVNC, etc.)
http.FileServer(http.Dir(dashboardPublic)).ServeHTTP(w, r)
})
addr := ":9090"
fmt.Printf("VNC test page: http://localhost%s?host=<peer_ip>\n", addr)
fmt.Printf("Serving assets from: %s\n", dashboardPublic)
if err := http.ListenAndServe(addr, mux); err != nil {
fmt.Fprintf(os.Stderr, "listen: %v\n", err)
os.Exit(1)
}
}


@@ -15,8 +15,8 @@ import (
sshdetection "github.com/netbirdio/netbird/client/ssh/detection"
nbstatus "github.com/netbirdio/netbird/client/status"
"github.com/netbirdio/netbird/client/wasm/internal/http"
"github.com/netbirdio/netbird/client/wasm/internal/rdp"
"github.com/netbirdio/netbird/client/wasm/internal/ssh"
"github.com/netbirdio/netbird/client/wasm/internal/vnc"
"github.com/netbirdio/netbird/util"
)
@@ -317,13 +317,8 @@ func createProxyRequestMethod(client *netbird.Client) js.Func {
})
}
// createVNCProxyMethod creates the VNC proxy method for raw TCP-over-WebSocket bridging.
// JS signature: createVNCProxy(hostname, port, mode?, username?, jwt?, sessionID?)
// mode: "attach" (default) or "session"
// username: required when mode is "session"
// jwt: authentication token (from OIDC session)
// sessionID: Windows session ID (0 = console/auto)
func createVNCProxyMethod(client *netbird.Client) js.Func {
// createRDPProxyMethod creates the RDP proxy method
func createRDPProxyMethod(client *netbird.Client) js.Func {
return js.FuncOf(func(_ js.Value, args []js.Value) any {
if len(args) < 2 {
return js.ValueOf("error: hostname and port required")
@@ -340,25 +335,8 @@ func createVNCProxyMethod(client *netbird.Client) js.Func {
})
}
mode := "attach"
username := ""
jwtToken := ""
var sessionID uint32
if len(args) > 2 && args[2].Type() == js.TypeString {
mode = args[2].String()
}
if len(args) > 3 && args[3].Type() == js.TypeString {
username = args[3].String()
}
if len(args) > 4 && args[4].Type() == js.TypeString {
jwtToken = args[4].String()
}
if len(args) > 5 && args[5].Type() == js.TypeNumber {
sessionID = uint32(args[5].Int())
}
proxy := vnc.NewVNCProxy(client)
return proxy.CreateProxy(args[0].String(), args[1].String(), mode, username, jwtToken, sessionID)
proxy := rdp.NewRDCleanPathProxy(client)
return proxy.CreateProxy(args[0].String(), args[1].String())
})
}
@@ -537,7 +515,7 @@ func createClientObject(client *netbird.Client) js.Value {
obj["detectSSHServerType"] = createDetectSSHServerMethod(client)
obj["createSSHConnection"] = createSSHMethod(client)
obj["proxyRequest"] = createProxyRequestMethod(client)
obj["createVNCProxy"] = createVNCProxyMethod(client)
obj["createRDPProxy"] = createRDPProxyMethod(client)
obj["status"] = createStatusMethod(client)
obj["statusSummary"] = createStatusSummaryMethod(client)
obj["statusDetail"] = createStatusDetailMethod(client)


@@ -0,0 +1,107 @@
//go:build js
package rdp
import (
"crypto/tls"
"crypto/x509"
"fmt"
"syscall/js"
"time"
log "github.com/sirupsen/logrus"
)
const (
certValidationTimeout = 60 * time.Second
)
func (p *RDCleanPathProxy) validateCertificateWithJS(conn *proxyConnection, certChain [][]byte) (bool, error) {
if !conn.wsHandlers.Get("onCertificateRequest").Truthy() {
return false, fmt.Errorf("certificate validation handler not configured")
}
certInfo := js.Global().Get("Object").New()
certInfo.Set("ServerAddr", conn.destination)
certArray := js.Global().Get("Array").New()
for i, certBytes := range certChain {
uint8Array := js.Global().Get("Uint8Array").New(len(certBytes))
js.CopyBytesToJS(uint8Array, certBytes)
certArray.SetIndex(i, uint8Array)
}
certInfo.Set("ServerCertChain", certArray)
if len(certChain) > 0 {
cert, err := x509.ParseCertificate(certChain[0])
if err == nil {
info := js.Global().Get("Object").New()
info.Set("subject", cert.Subject.String())
info.Set("issuer", cert.Issuer.String())
info.Set("validFrom", cert.NotBefore.Format(time.RFC3339))
info.Set("validTo", cert.NotAfter.Format(time.RFC3339))
info.Set("serialNumber", cert.SerialNumber.String())
certInfo.Set("CertificateInfo", info)
}
}
promise := conn.wsHandlers.Call("onCertificateRequest", certInfo)
// Buffered so the JS then/catch callbacks never block if the timeout path wins.
resultChan := make(chan bool, 1)
errorChan := make(chan error, 1)
promise.Call("then", js.FuncOf(func(this js.Value, args []js.Value) interface{} {
result := args[0].Bool()
resultChan <- result
return nil
})).Call("catch", js.FuncOf(func(this js.Value, args []js.Value) interface{} {
errorChan <- fmt.Errorf("certificate validation failed")
return nil
}))
select {
case result := <-resultChan:
if result {
log.Info("Certificate accepted by user")
} else {
log.Info("Certificate rejected by user")
}
return result, nil
case err := <-errorChan:
return false, err
case <-time.After(certValidationTimeout):
return false, fmt.Errorf("certificate validation timeout")
}
}
func (p *RDCleanPathProxy) getTLSConfigWithValidation(conn *proxyConnection, requiresCredSSP bool) *tls.Config {
config := &tls.Config{
InsecureSkipVerify: true, // We'll validate manually after handshake
VerifyConnection: func(cs tls.ConnectionState) error {
var certChain [][]byte
for _, cert := range cs.PeerCertificates {
certChain = append(certChain, cert.Raw)
}
accepted, err := p.validateCertificateWithJS(conn, certChain)
if err != nil {
return err
}
if !accepted {
return fmt.Errorf("certificate rejected by user")
}
return nil
},
}
// CredSSP (NLA) requires TLS 1.2 - it's incompatible with TLS 1.3
if requiresCredSSP {
config.MinVersion = tls.VersionTLS12
config.MaxVersion = tls.VersionTLS12
} else {
config.MinVersion = tls.VersionTLS12
config.MaxVersion = tls.VersionTLS13
}
return config
}


@@ -0,0 +1,344 @@
//go:build js
package rdp
import (
"context"
"crypto/tls"
"encoding/asn1"
"errors"
"fmt"
"io"
"net"
"sync"
"syscall/js"
"time"
log "github.com/sirupsen/logrus"
)
const (
RDCleanPathVersion = 3390
RDCleanPathProxyHost = "rdcleanpath.proxy.local"
RDCleanPathProxyScheme = "ws"
rdpDialTimeout = 15 * time.Second
GeneralErrorCode = 1
WSAETimedOut = 10060
WSAEConnRefused = 10061
WSAEConnAborted = 10053
WSAEConnReset = 10054
WSAEGenericError = 10050
)
type RDCleanPathPDU struct {
Version int64 `asn1:"tag:0,explicit"`
Error RDCleanPathErr `asn1:"tag:1,explicit,optional"`
Destination string `asn1:"utf8,tag:2,explicit,optional"`
ProxyAuth string `asn1:"utf8,tag:3,explicit,optional"`
ServerAuth string `asn1:"utf8,tag:4,explicit,optional"`
PreconnectionBlob string `asn1:"utf8,tag:5,explicit,optional"`
X224ConnectionPDU []byte `asn1:"tag:6,explicit,optional"`
ServerCertChain [][]byte `asn1:"tag:7,explicit,optional"`
ServerAddr string `asn1:"utf8,tag:9,explicit,optional"`
}
type RDCleanPathErr struct {
ErrorCode int16 `asn1:"tag:0,explicit"`
HTTPStatusCode int16 `asn1:"tag:1,explicit,optional"`
WSALastError int16 `asn1:"tag:2,explicit,optional"`
TLSAlertCode int8 `asn1:"tag:3,explicit,optional"`
}
type RDCleanPathProxy struct {
nbClient interface {
Dial(ctx context.Context, network, address string) (net.Conn, error)
}
activeConnections map[string]*proxyConnection
destinations map[string]string
mu sync.Mutex
}
type proxyConnection struct {
id string
destination string
rdpConn net.Conn
tlsConn *tls.Conn
wsHandlers js.Value
ctx context.Context
cancel context.CancelFunc
}
// NewRDCleanPathProxy creates a new RDCleanPath proxy
func NewRDCleanPathProxy(client interface {
Dial(ctx context.Context, network, address string) (net.Conn, error)
}) *RDCleanPathProxy {
return &RDCleanPathProxy{
nbClient: client,
activeConnections: make(map[string]*proxyConnection),
}
}
// CreateProxy creates a new proxy endpoint for the given destination
func (p *RDCleanPathProxy) CreateProxy(hostname, port string) js.Value {
destination := fmt.Sprintf("%s:%s", hostname, port)
return js.Global().Get("Promise").New(js.FuncOf(func(_ js.Value, args []js.Value) any {
resolve := args[0]
go func() {
p.mu.Lock()
if p.destinations == nil {
p.destinations = make(map[string]string)
}
// Generate the ID under the lock so concurrent CreateProxy calls cannot race,
// and count destinations (which only grow here) rather than active connections.
proxyID := fmt.Sprintf("proxy_%d", len(p.destinations))
p.destinations[proxyID] = destination
p.mu.Unlock()
proxyURL := fmt.Sprintf("%s://%s/%s", RDCleanPathProxyScheme, RDCleanPathProxyHost, proxyID)
// Register the WebSocket handler for this specific proxy
js.Global().Set(fmt.Sprintf("handleRDCleanPathWebSocket_%s", proxyID), js.FuncOf(func(_ js.Value, args []js.Value) any {
if len(args) < 1 {
return js.ValueOf("error: requires WebSocket argument")
}
ws := args[0]
p.HandleWebSocketConnection(ws, proxyID)
return nil
}))
log.Infof("Created RDCleanPath proxy endpoint: %s for destination: %s", proxyURL, destination)
resolve.Invoke(proxyURL)
}()
return nil
}))
}
// HandleWebSocketConnection handles incoming WebSocket connections from IronRDP
func (p *RDCleanPathProxy) HandleWebSocketConnection(ws js.Value, proxyID string) {
p.mu.Lock()
destination := p.destinations[proxyID]
p.mu.Unlock()
if destination == "" {
log.Errorf("No destination found for proxy ID: %s", proxyID)
return
}
ctx, cancel := context.WithCancel(context.Background())
// Don't defer cancel here - it will be called by cleanupConnection
conn := &proxyConnection{
id: proxyID,
destination: destination,
wsHandlers: ws,
ctx: ctx,
cancel: cancel,
}
p.mu.Lock()
p.activeConnections[proxyID] = conn
p.mu.Unlock()
p.setupWebSocketHandlers(ws, conn)
log.Infof("RDCleanPath proxy WebSocket connection established for %s", proxyID)
}
func (p *RDCleanPathProxy) setupWebSocketHandlers(ws js.Value, conn *proxyConnection) {
ws.Set("onGoMessage", js.FuncOf(func(this js.Value, args []js.Value) any {
if len(args) < 1 {
return nil
}
data := args[0]
go p.handleWebSocketMessage(conn, data)
return nil
}))
ws.Set("onGoClose", js.FuncOf(func(_ js.Value, args []js.Value) any {
log.Debug("WebSocket closed by JavaScript")
conn.cancel()
return nil
}))
}
func (p *RDCleanPathProxy) handleWebSocketMessage(conn *proxyConnection, data js.Value) {
if !data.InstanceOf(js.Global().Get("Uint8Array")) {
return
}
length := data.Get("length").Int()
bytes := make([]byte, length)
js.CopyBytesToGo(bytes, data)
if conn.rdpConn != nil || conn.tlsConn != nil {
p.forwardToRDP(conn, bytes)
return
}
var pdu RDCleanPathPDU
_, err := asn1.Unmarshal(bytes, &pdu)
if err != nil {
log.Warnf("Failed to parse RDCleanPath PDU: %v", err)
n := len(bytes)
if n > 20 {
n = 20
}
log.Warnf("First %d bytes: %x", n, bytes[:n])
if len(bytes) > 0 && bytes[0] == 0x03 {
log.Debug("Received raw RDP packet instead of RDCleanPath PDU")
go p.handleDirectRDP(conn, bytes)
return
}
return
}
go p.processRDCleanPathPDU(conn, pdu)
}
func (p *RDCleanPathProxy) forwardToRDP(conn *proxyConnection, bytes []byte) {
var writer io.Writer
var connType string
if conn.tlsConn != nil {
writer = conn.tlsConn
connType = "TLS"
} else if conn.rdpConn != nil {
writer = conn.rdpConn
connType = "TCP"
} else {
log.Error("No RDP connection available")
return
}
if _, err := writer.Write(bytes); err != nil {
log.Errorf("Failed to write to %s: %v", connType, err)
}
}
func (p *RDCleanPathProxy) handleDirectRDP(conn *proxyConnection, firstPacket []byte) {
defer p.cleanupConnection(conn)
destination := conn.destination
log.Infof("Direct RDP mode: Connecting to %s via NetBird", destination)
ctx, cancel := context.WithTimeout(conn.ctx, rdpDialTimeout)
defer cancel()
rdpConn, err := p.nbClient.Dial(ctx, "tcp", destination)
if err != nil {
log.Errorf("Failed to connect to %s: %v", destination, err)
p.sendRDCleanPathError(conn, newWSAError(err))
return
}
conn.rdpConn = rdpConn
_, err = rdpConn.Write(firstPacket)
if err != nil {
log.Errorf("Failed to write first packet: %v", err)
p.sendRDCleanPathError(conn, newWSAError(err))
return
}
response := make([]byte, 1024)
n, err := rdpConn.Read(response)
if err != nil {
log.Errorf("Failed to read X.224 response: %v", err)
p.sendRDCleanPathError(conn, newWSAError(err))
return
}
p.sendToWebSocket(conn, response[:n])
go p.forwardWSToConn(conn, conn.rdpConn, "TCP")
go p.forwardConnToWS(conn, conn.rdpConn, "TCP")
// Block until teardown; returning immediately would run the deferred
// cleanupConnection and close the socket underneath the forwarders.
<-conn.ctx.Done()
}
func (p *RDCleanPathProxy) cleanupConnection(conn *proxyConnection) {
log.Debugf("Cleaning up connection %s", conn.id)
conn.cancel()
if conn.tlsConn != nil {
log.Debug("Closing TLS connection")
if err := conn.tlsConn.Close(); err != nil {
log.Debugf("Error closing TLS connection: %v", err)
}
conn.tlsConn = nil
}
if conn.rdpConn != nil {
log.Debug("Closing TCP connection")
if err := conn.rdpConn.Close(); err != nil {
log.Debugf("Error closing TCP connection: %v", err)
}
conn.rdpConn = nil
}
p.mu.Lock()
delete(p.activeConnections, conn.id)
p.mu.Unlock()
}
func (p *RDCleanPathProxy) sendToWebSocket(conn *proxyConnection, data []byte) {
if conn.wsHandlers.Get("receiveFromGo").Truthy() {
uint8Array := js.Global().Get("Uint8Array").New(len(data))
js.CopyBytesToJS(uint8Array, data)
conn.wsHandlers.Call("receiveFromGo", uint8Array.Get("buffer"))
} else if conn.wsHandlers.Get("send").Truthy() {
uint8Array := js.Global().Get("Uint8Array").New(len(data))
js.CopyBytesToJS(uint8Array, data)
conn.wsHandlers.Call("send", uint8Array.Get("buffer"))
}
}
func (p *RDCleanPathProxy) sendRDCleanPathError(conn *proxyConnection, pdu RDCleanPathPDU) {
data, err := asn1.Marshal(pdu)
if err != nil {
log.Errorf("Failed to marshal error PDU: %v", err)
return
}
p.sendToWebSocket(conn, data)
}
func errorToWSACode(err error) int16 {
if err == nil {
return WSAEGenericError
}
var netErr *net.OpError
if errors.As(err, &netErr) && netErr.Timeout() {
return WSAETimedOut
}
if errors.Is(err, context.DeadlineExceeded) {
return WSAETimedOut
}
if errors.Is(err, context.Canceled) {
return WSAEConnAborted
}
if errors.Is(err, io.EOF) {
return WSAEConnReset
}
return WSAEGenericError
}
func newWSAError(err error) RDCleanPathPDU {
return RDCleanPathPDU{
Version: RDCleanPathVersion,
Error: RDCleanPathErr{
ErrorCode: GeneralErrorCode,
WSALastError: errorToWSACode(err),
},
}
}
func newHTTPError(statusCode int16) RDCleanPathPDU {
return RDCleanPathPDU{
Version: RDCleanPathVersion,
Error: RDCleanPathErr{
ErrorCode: GeneralErrorCode,
HTTPStatusCode: statusCode,
},
}
}


@@ -0,0 +1,244 @@
//go:build js
package rdp
import (
"context"
"crypto/tls"
"encoding/asn1"
"io"
"syscall/js"
log "github.com/sirupsen/logrus"
)
const (
// MS-RDPBCGR: confusingly named, actually means PROTOCOL_HYBRID (CredSSP)
protocolSSL = 0x00000001
protocolHybridEx = 0x00000008
)
func (p *RDCleanPathProxy) processRDCleanPathPDU(conn *proxyConnection, pdu RDCleanPathPDU) {
log.Infof("Processing RDCleanPath PDU: Version=%d, Destination=%s", pdu.Version, pdu.Destination)
if pdu.Version != RDCleanPathVersion {
p.sendRDCleanPathError(conn, newHTTPError(400))
return
}
destination := conn.destination
if pdu.Destination != "" {
destination = pdu.Destination
}
ctx, cancel := context.WithTimeout(conn.ctx, rdpDialTimeout)
defer cancel()
rdpConn, err := p.nbClient.Dial(ctx, "tcp", destination)
if err != nil {
log.Errorf("Failed to connect to %s: %v", destination, err)
p.sendRDCleanPathError(conn, newWSAError(err))
p.cleanupConnection(conn)
return
}
conn.rdpConn = rdpConn
// RDP always starts with X.224 negotiation, then determines if TLS is needed
// Modern RDP (since Windows Vista/2008) typically requires TLS
// The X.224 Connection Confirm response will indicate if TLS is required
// For now, we'll attempt TLS for all connections as it's the modern default
p.setupTLSConnection(conn, pdu)
}
// detectCredSSPFromX224 checks if the X.224 response indicates NLA/CredSSP is required.
// Per MS-RDPBCGR spec: byte 11 = TYPE_RDP_NEG_RSP (0x02), bytes 15-18 = selectedProtocol flags.
// Returns (requiresCredSSP, selectedProtocol, detected).
func (p *RDCleanPathProxy) detectCredSSPFromX224(x224Response []byte) (bool, uint32, bool) {
const minResponseLength = 19
if len(x224Response) < minResponseLength {
return false, 0, false
}
// Per TPKT/X.224 framing:
// x224Response[0] == 0x03: TPKT version (always 3)
// x224Response[5] == 0xD0: X.224 Connection Confirm (CC) TPDU code
if x224Response[0] != 0x03 || x224Response[5] != 0xD0 {
return false, 0, false
}
if x224Response[11] == 0x02 {
flags := uint32(x224Response[15]) | uint32(x224Response[16])<<8 |
uint32(x224Response[17])<<16 | uint32(x224Response[18])<<24
hasNLA := (flags & (protocolSSL | protocolHybridEx)) != 0
return hasNLA, flags, true
}
return false, 0, false
}
func (p *RDCleanPathProxy) setupTLSConnection(conn *proxyConnection, pdu RDCleanPathPDU) {
var x224Response []byte
if len(pdu.X224ConnectionPDU) > 0 {
log.Debugf("Forwarding X.224 Connection Request (%d bytes)", len(pdu.X224ConnectionPDU))
_, err := conn.rdpConn.Write(pdu.X224ConnectionPDU)
if err != nil {
log.Errorf("Failed to write X.224 PDU: %v", err)
p.sendRDCleanPathError(conn, newWSAError(err))
p.cleanupConnection(conn)
return
}
response := make([]byte, 1024)
n, err := conn.rdpConn.Read(response)
if err != nil {
log.Errorf("Failed to read X.224 response: %v", err)
p.sendRDCleanPathError(conn, newWSAError(err))
p.cleanupConnection(conn)
return
}
x224Response = response[:n]
log.Debugf("Received X.224 Connection Confirm (%d bytes)", n)
}
requiresCredSSP, selectedProtocol, detected := p.detectCredSSPFromX224(x224Response)
if detected {
if requiresCredSSP {
log.Warnf("Detected NLA/CredSSP (selectedProtocol: 0x%08X), forcing TLS 1.2 for compatibility", selectedProtocol)
} else {
log.Warnf("No NLA/CredSSP detected (selectedProtocol: 0x%08X), allowing up to TLS 1.3", selectedProtocol)
}
} else {
log.Warnf("Could not detect RDP security protocol, allowing up to TLS 1.3")
}
tlsConfig := p.getTLSConfigWithValidation(conn, requiresCredSSP)
tlsConn := tls.Client(conn.rdpConn, tlsConfig)
conn.tlsConn = tlsConn
if err := tlsConn.Handshake(); err != nil {
log.Errorf("TLS handshake failed: %v", err)
p.sendRDCleanPathError(conn, newWSAError(err))
p.cleanupConnection(conn)
return
}
log.Info("TLS handshake successful")
// Certificate validation happens during handshake via VerifyConnection callback
var certChain [][]byte
connState := tlsConn.ConnectionState()
if len(connState.PeerCertificates) > 0 {
for _, cert := range connState.PeerCertificates {
certChain = append(certChain, cert.Raw)
}
log.Debugf("Extracted %d certificates from TLS connection", len(certChain))
}
responsePDU := RDCleanPathPDU{
Version: RDCleanPathVersion,
ServerAddr: conn.destination,
ServerCertChain: certChain,
}
if len(x224Response) > 0 {
responsePDU.X224ConnectionPDU = x224Response
}
p.sendRDCleanPathPDU(conn, responsePDU)
log.Debug("Starting TLS forwarding")
go p.forwardConnToWS(conn, conn.tlsConn, "TLS")
go p.forwardWSToConn(conn, conn.tlsConn, "TLS")
<-conn.ctx.Done()
log.Debug("TLS connection context done, cleaning up")
p.cleanupConnection(conn)
}
func (p *RDCleanPathProxy) sendRDCleanPathPDU(conn *proxyConnection, pdu RDCleanPathPDU) {
data, err := asn1.Marshal(pdu)
if err != nil {
log.Errorf("Failed to marshal RDCleanPath PDU: %v", err)
return
}
log.Debugf("Sending RDCleanPath PDU response (%d bytes)", len(data))
p.sendToWebSocket(conn, data)
}
func (p *RDCleanPathProxy) readWebSocketMessage(conn *proxyConnection) ([]byte, error) {
// Buffered so a late JS callback cannot block (and leak) after the select
// below has already returned on ctx cancellation.
msgChan := make(chan []byte, 1)
errChan := make(chan error, 1)
handler := js.FuncOf(func(this js.Value, args []js.Value) any {
if len(args) < 1 {
errChan <- io.EOF
return nil
}
data := args[0]
if data.InstanceOf(js.Global().Get("Uint8Array")) {
length := data.Get("length").Int()
bytes := make([]byte, length)
js.CopyBytesToGo(bytes, data)
msgChan <- bytes
}
return nil
})
defer handler.Release()
conn.wsHandlers.Set("onceGoMessage", handler)
select {
case msg := <-msgChan:
return msg, nil
case err := <-errChan:
return nil, err
case <-conn.ctx.Done():
return nil, conn.ctx.Err()
}
}
func (p *RDCleanPathProxy) forwardWSToConn(conn *proxyConnection, dst io.Writer, connType string) {
for {
if conn.ctx.Err() != nil {
return
}
msg, err := p.readWebSocketMessage(conn)
if err != nil {
if err != io.EOF {
log.Errorf("Failed to read from WebSocket: %v", err)
}
return
}
_, err = dst.Write(msg)
if err != nil {
log.Errorf("Failed to write to %s: %v", connType, err)
return
}
}
}
func (p *RDCleanPathProxy) forwardConnToWS(conn *proxyConnection, src io.Reader, connType string) {
buffer := make([]byte, 32*1024)
for {
if conn.ctx.Err() != nil {
return
}
n, err := src.Read(buffer)
if err != nil {
if err != io.EOF {
log.Errorf("Failed to read from %s: %v", connType, err)
}
return
}
if n > 0 {
p.sendToWebSocket(conn, buffer[:n])
}
}
}


@@ -1,360 +0,0 @@
//go:build js
package vnc
import (
"context"
"fmt"
"io"
"net"
"sync"
"sync/atomic"
"syscall/js"
"time"
log "github.com/sirupsen/logrus"
)
const (
vncProxyHost = "vnc.proxy.local"
vncProxyScheme = "ws"
vncDialTimeout = 15 * time.Second
// Connection modes matching server/server.go constants.
modeAttach byte = 0
modeSession byte = 1
)
// VNCProxy bridges WebSocket connections from noVNC in the browser
// to TCP VNC server connections through the NetBird tunnel.
type VNCProxy struct {
nbClient interface {
Dial(ctx context.Context, network, address string) (net.Conn, error)
}
activeConnections map[string]*vncConnection
destinations map[string]vncDestination
// pendingHandlers holds the js.Func for handleVNCWebSocket_<id> between
// CreateProxy and handleWebSocketConnection so we can move it onto the
// vncConnection for later release.
pendingHandlers map[string]js.Func
mu sync.Mutex
nextID atomic.Uint64
}
type vncDestination struct {
address string
mode byte
username string
jwt string
sessionID uint32 // Windows session ID (0 = auto/console)
}
type vncConnection struct {
id string
destination vncDestination
mu sync.Mutex
vncConn net.Conn
wsHandlers js.Value
ctx context.Context
cancel context.CancelFunc
// Go-side callbacks exposed to JS. js.FuncOf pins the Go closure in a
// global handle map and MUST be released, otherwise every connection
// leaks the Go memory the closure captures.
wsHandlerFn js.Func
onMessageFn js.Func
onCloseFn js.Func
}
// NewVNCProxy creates a new VNC proxy.
func NewVNCProxy(client interface {
Dial(ctx context.Context, network, address string) (net.Conn, error)
}) *VNCProxy {
return &VNCProxy{
nbClient: client,
activeConnections: make(map[string]*vncConnection),
}
}
// CreateProxy creates a new proxy endpoint for the given VNC destination.
// mode is "attach" (capture current display) or "session" (virtual session).
// username is required for session mode.
// Returns a JS Promise that resolves to the WebSocket proxy URL.
func (p *VNCProxy) CreateProxy(hostname, port, mode, username, jwt string, sessionID uint32) js.Value {
address := fmt.Sprintf("%s:%s", hostname, port)
var m byte
if mode == "session" {
m = modeSession
}
dest := vncDestination{
address: address,
mode: m,
username: username,
jwt: jwt,
sessionID: sessionID,
}
return js.Global().Get("Promise").New(js.FuncOf(func(_ js.Value, args []js.Value) any {
resolve := args[0]
go func() {
proxyID := fmt.Sprintf("vnc_proxy_%d", p.nextID.Add(1))
p.mu.Lock()
if p.destinations == nil {
p.destinations = make(map[string]vncDestination)
}
p.destinations[proxyID] = dest
p.mu.Unlock()
proxyURL := fmt.Sprintf("%s://%s/%s", vncProxyScheme, vncProxyHost, proxyID)
handlerFn := js.FuncOf(func(_ js.Value, args []js.Value) any {
if len(args) < 1 {
return js.ValueOf("error: requires WebSocket argument")
}
p.handleWebSocketConnection(args[0], proxyID)
return nil
})
p.mu.Lock()
if p.pendingHandlers == nil {
p.pendingHandlers = make(map[string]js.Func)
}
p.pendingHandlers[proxyID] = handlerFn
p.mu.Unlock()
js.Global().Set(fmt.Sprintf("handleVNCWebSocket_%s", proxyID), handlerFn)
log.Infof("created VNC proxy: %s -> %s (mode=%s, user=%s)", proxyURL, address, mode, username)
resolve.Invoke(proxyURL)
}()
return nil
}))
}
func (p *VNCProxy) handleWebSocketConnection(ws js.Value, proxyID string) {
p.mu.Lock()
dest, ok := p.destinations[proxyID]
handlerFn := p.pendingHandlers[proxyID]
delete(p.pendingHandlers, proxyID)
p.mu.Unlock()
if !ok {
log.Errorf("no destination for VNC proxy %s", proxyID)
return
}
ctx, cancel := context.WithCancel(context.Background())
conn := &vncConnection{
id: proxyID,
destination: dest,
wsHandlers: ws,
ctx: ctx,
cancel: cancel,
wsHandlerFn: handlerFn,
}
p.mu.Lock()
p.activeConnections[proxyID] = conn
p.mu.Unlock()
p.setupWebSocketHandlers(ws, conn)
go p.connectToVNC(conn)
log.Infof("VNC proxy WebSocket connection established for %s", proxyID)
}
func (p *VNCProxy) setupWebSocketHandlers(ws js.Value, conn *vncConnection) {
conn.onMessageFn = js.FuncOf(func(_ js.Value, args []js.Value) any {
if len(args) < 1 {
return nil
}
data := args[0]
go p.handleWebSocketMessage(conn, data)
return nil
})
ws.Set("onGoMessage", conn.onMessageFn)
conn.onCloseFn = js.FuncOf(func(_ js.Value, _ []js.Value) any {
log.Debug("VNC WebSocket closed by JavaScript")
conn.cancel()
return nil
})
ws.Set("onGoClose", conn.onCloseFn)
}
func (p *VNCProxy) handleWebSocketMessage(conn *vncConnection, data js.Value) {
if !data.InstanceOf(js.Global().Get("Uint8Array")) {
return
}
length := data.Get("length").Int()
buf := make([]byte, length)
js.CopyBytesToGo(buf, data)
conn.mu.Lock()
vncConn := conn.vncConn
conn.mu.Unlock()
if vncConn == nil {
return
}
if _, err := vncConn.Write(buf); err != nil {
log.Debugf("write to VNC server: %v", err)
}
}
func (p *VNCProxy) connectToVNC(conn *vncConnection) {
ctx, cancel := context.WithTimeout(conn.ctx, vncDialTimeout)
defer cancel()
vncConn, err := p.nbClient.Dial(ctx, "tcp", conn.destination.address)
if err != nil {
log.Errorf("VNC connect to %s: %v", conn.destination.address, err)
// Close the WebSocket so noVNC fires a disconnect event.
if conn.wsHandlers.Get("close").Truthy() {
conn.wsHandlers.Call("close", 1006, fmt.Sprintf("connect to peer: %v", err))
}
p.cleanupConnection(conn)
return
}
conn.mu.Lock()
conn.vncConn = vncConn
conn.mu.Unlock()
// Send the NetBird VNC session header before the RFB handshake.
if err := p.sendSessionHeader(vncConn, conn.destination); err != nil {
log.Errorf("send VNC session header: %v", err)
p.cleanupConnection(conn)
return
}
// WS→TCP is handled by the onGoMessage handler set in setupWebSocketHandlers,
// which writes directly to the VNC connection as data arrives from JS.
// Only the TCP→WS direction needs a read loop here.
go p.forwardConnToWS(conn)
<-conn.ctx.Done()
p.cleanupConnection(conn)
}
// sendSessionHeader writes mode, username, JWT, and session ID to the VNC server.
// Format: [mode: 1 byte] [username_len: 2 bytes BE] [username: N bytes]
// [jwt_len: 2 bytes BE] [jwt: N bytes] [session_id: 4 bytes BE]
func (p *VNCProxy) sendSessionHeader(conn net.Conn, dest vncDestination) error {
usernameBytes := []byte(dest.username)
jwtBytes := []byte(dest.jwt)
// Format: [mode:1] [username_len:2] [username:N] [jwt_len:2] [jwt:N] [session_id:4]
hdr := make([]byte, 3+len(usernameBytes)+2+len(jwtBytes)+4)
hdr[0] = dest.mode
hdr[1] = byte(len(usernameBytes) >> 8)
hdr[2] = byte(len(usernameBytes))
off := 3
copy(hdr[off:], usernameBytes)
off += len(usernameBytes)
hdr[off] = byte(len(jwtBytes) >> 8)
hdr[off+1] = byte(len(jwtBytes))
off += 2
copy(hdr[off:], jwtBytes)
off += len(jwtBytes)
hdr[off] = byte(dest.sessionID >> 24)
hdr[off+1] = byte(dest.sessionID >> 16)
hdr[off+2] = byte(dest.sessionID >> 8)
hdr[off+3] = byte(dest.sessionID)
_, err := conn.Write(hdr)
return err
}
func (p *VNCProxy) forwardConnToWS(conn *vncConnection) {
buf := make([]byte, 32*1024)
for {
if conn.ctx.Err() != nil {
return
}
// Set a read deadline so we detect dead connections instead of
// blocking forever when the remote peer dies.
conn.mu.Lock()
vc := conn.vncConn
conn.mu.Unlock()
if vc == nil {
return
}
vc.SetReadDeadline(time.Now().Add(30 * time.Second))
n, err := vc.Read(buf)
if err != nil {
if conn.ctx.Err() != nil {
return
}
if netErr, ok := err.(interface{ Timeout() bool }); ok && netErr.Timeout() {
// Read timeout: the connection may just be idle. Retry; if the
// peer is truly dead, a subsequent read fails with a hard error
// and we close below.
continue
}
if err != io.EOF {
log.Debugf("read from VNC connection: %v", err)
}
// Close the WebSocket to notify noVNC.
if conn.wsHandlers.Get("close").Truthy() {
conn.wsHandlers.Call("close", 1006, "VNC connection lost")
}
return
}
if n > 0 {
p.sendToWebSocket(conn, buf[:n])
}
}
}
func (p *VNCProxy) sendToWebSocket(conn *vncConnection, data []byte) {
if conn.wsHandlers.Get("receiveFromGo").Truthy() {
uint8Array := js.Global().Get("Uint8Array").New(len(data))
js.CopyBytesToJS(uint8Array, data)
conn.wsHandlers.Call("receiveFromGo", uint8Array.Get("buffer"))
} else if conn.wsHandlers.Get("send").Truthy() {
uint8Array := js.Global().Get("Uint8Array").New(len(data))
js.CopyBytesToJS(uint8Array, data)
conn.wsHandlers.Call("send", uint8Array.Get("buffer"))
}
}
func (p *VNCProxy) cleanupConnection(conn *vncConnection) {
log.Debugf("cleaning up VNC connection %s", conn.id)
conn.cancel()
conn.mu.Lock()
vncConn := conn.vncConn
conn.vncConn = nil
conn.mu.Unlock()
if vncConn != nil {
if err := vncConn.Close(); err != nil {
log.Debugf("close VNC connection: %v", err)
}
}
// Remove the global JS handler registered in CreateProxy.
globalName := fmt.Sprintf("handleVNCWebSocket_%s", conn.id)
js.Global().Delete(globalName)
// Release all js.Func handles; js.FuncOf pins the Go closure and the
// allocations it captures until Release is called.
conn.wsHandlerFn.Release()
conn.onMessageFn.Release()
conn.onCloseFn.Release()
p.mu.Lock()
delete(p.activeConnections, conn.id)
delete(p.destinations, conn.id)
delete(p.pendingHandlers, conn.id)
p.mu.Unlock()
}


@@ -119,6 +119,8 @@ server:
# Reverse proxy settings (optional)
# reverseProxy:
# trustedHTTPProxies: []
# trustedHTTPProxiesCount: 0
# trustedPeers: []
# trustedHTTPProxies: [] # CIDRs of trusted reverse proxies (e.g. ["10.0.0.0/8"])
# trustedHTTPProxiesCount: 0 # Number of trusted proxies in front of the server (alternative to trustedHTTPProxies)
# trustedPeers: [] # CIDRs of trusted peer networks (e.g. ["100.64.0.0/10"])
# accessLogRetentionDays: 7 # Days to retain HTTP access logs. 0 (or unset) defaults to 7. Negative values disable cleanup (logs kept indefinitely).
# accessLogCleanupIntervalHours: 24 # How often (in hours) to run the access-log cleanup job. 0 (or unset) is treated as "not set" and defaults to 24 hours; cleanup remains enabled. To disable cleanup, set accessLogRetentionDays to a negative value.


@@ -457,6 +457,18 @@ func TestReceive_ProtocolErrorStreamReconnect(t *testing.T) {
client, err := flow.NewClient("http://"+server.addr, "test-payload", "test-signature", 1*time.Second)
require.NoError(t, err)
// Cleanups run LIFO: the goroutine-drain registered here runs after Close below,
// which is when Receive has actually returned. Without this, the Receive goroutine
// can outlive the test and call t.Logf after teardown, panicking.
receiveDone := make(chan struct{})
t.Cleanup(func() {
select {
case <-receiveDone:
case <-time.After(2 * time.Second):
t.Error("Receive goroutine did not exit after Close")
}
})
t.Cleanup(func() {
err := client.Close()
assert.NoError(t, err, "failed to close flow")
@@ -468,6 +480,7 @@ func TestReceive_ProtocolErrorStreamReconnect(t *testing.T) {
receivedAfterReconnect := make(chan struct{})
go func() {
defer close(receiveDone)
err := client.Receive(ctx, 1*time.Second, func(msg *proto.FlowEventAck) error {
if msg.IsInitiator || len(msg.EventId) == 0 {
return nil

Some files were not shown because too many files have changed in this diff.