mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-25 02:22:25 -04:00
Compare commits
14 Commits
vnc-server
...
github-iss
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
6b8e40f78d | ||
|
|
5da05ecca6 | ||
|
|
801de8c68d | ||
|
|
a822a33240 | ||
|
|
57b23c5b25 | ||
|
|
1165058fad | ||
|
|
703353d354 | ||
|
|
2fb50aef6b | ||
|
|
eb3aa96257 | ||
|
|
064ec1c832 | ||
|
|
75e408f51c | ||
|
|
5a89e6621b | ||
|
|
06dfa9d4a5 | ||
|
|
45d9ee52c0 |
26
.github/issue-resolution/prompts/issue-resolution-system.txt
vendored
Normal file
26
.github/issue-resolution/prompts/issue-resolution-system.txt
vendored
Normal file
@@ -0,0 +1,26 @@
|
||||
You are a GitHub issue resolution classifier.
|
||||
|
||||
Your job is to decide whether an open GitHub issue is:
|
||||
- AUTO_CLOSE
|
||||
- MANUAL_REVIEW
|
||||
- KEEP_OPEN
|
||||
|
||||
Rules:
|
||||
1. AUTO_CLOSE is only allowed if there is objective, hard evidence:
|
||||
- a merged linked PR that clearly resolves the issue, or
|
||||
- an explicit maintainer/member/owner/collaborator comment saying the issue is fixed, resolved, duplicate, or superseded
|
||||
2. If there is any contradictory later evidence, do NOT AUTO_CLOSE.
|
||||
3. If evidence is promising but not airtight, choose MANUAL_REVIEW.
|
||||
4. If the issue still appears active or unresolved, choose KEEP_OPEN.
|
||||
5. Do not invent evidence.
|
||||
6. Output valid JSON only.
|
||||
|
||||
Maintainer-authoritative roles:
|
||||
- MEMBER
|
||||
- OWNER
|
||||
- COLLABORATOR
|
||||
|
||||
Important:
|
||||
- Later comments outweigh earlier ones.
|
||||
- A non-maintainer saying "fixed for me" is not enough for AUTO_CLOSE.
|
||||
- If uncertain, prefer MANUAL_REVIEW or KEEP_OPEN.
|
||||
78
.github/issue-resolution/schemas/issue-resolution-output.json
vendored
Normal file
78
.github/issue-resolution/schemas/issue-resolution-output.json
vendored
Normal file
@@ -0,0 +1,78 @@
|
||||
{
|
||||
"type": "object",
|
||||
"additionalProperties": false,
|
||||
"required": [
|
||||
"decision",
|
||||
"reason_code",
|
||||
"confidence",
|
||||
"hard_signals",
|
||||
"contradictions",
|
||||
"summary",
|
||||
"close_comment",
|
||||
"manual_review_note"
|
||||
],
|
||||
"properties": {
|
||||
"decision": {
|
||||
"type": "string",
|
||||
"enum": ["AUTO_CLOSE", "MANUAL_REVIEW", "KEEP_OPEN"]
|
||||
},
|
||||
"reason_code": {
|
||||
"type": "string",
|
||||
"enum": [
|
||||
"resolved_by_merged_pr",
|
||||
"maintainer_confirmed_resolved",
|
||||
"duplicate_confirmed",
|
||||
"superseded_confirmed",
|
||||
"likely_fixed_but_unconfirmed",
|
||||
"still_open",
|
||||
"unclear"
|
||||
]
|
||||
},
|
||||
"confidence": {
|
||||
"type": "number",
|
||||
"minimum": 0,
|
||||
"maximum": 1
|
||||
},
|
||||
"hard_signals": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "object",
|
||||
"required": ["type", "url"],
|
||||
"properties": {
|
||||
"type": {
|
||||
"type": "string",
|
||||
"enum": [
|
||||
"merged_pr",
|
||||
"maintainer_comment",
|
||||
"duplicate_reference",
|
||||
"superseded_reference"
|
||||
]
|
||||
},
|
||||
"url": { "type": "string" }
|
||||
}
|
||||
}
|
||||
},
|
||||
"contradictions": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "object",
|
||||
"required": ["type", "url"],
|
||||
"properties": {
|
||||
"type": {
|
||||
"type": "string",
|
||||
"enum": [
|
||||
"reporter_still_broken",
|
||||
"later_unresolved_comment",
|
||||
"ambiguous_pr_link",
|
||||
"other"
|
||||
]
|
||||
},
|
||||
"url": { "type": "string" }
|
||||
}
|
||||
}
|
||||
},
|
||||
"summary": { "type": "string" },
|
||||
"close_comment": { "type": "string" },
|
||||
"manual_review_note": { "type": "string" }
|
||||
}
|
||||
}
|
||||
152
.github/issue-resolution/scripts/apply-decisions.mjs
vendored
Normal file
152
.github/issue-resolution/scripts/apply-decisions.mjs
vendored
Normal file
@@ -0,0 +1,152 @@
|
||||
import fs from "node:fs/promises";
|
||||
|
||||
const decisions = JSON.parse(await fs.readFile("decisions.json", "utf8"));
|
||||
const dryRun = String(process.env.DRY_RUN).toLowerCase() === "true";
|
||||
|
||||
const headers = {
|
||||
Authorization: `Bearer ${process.env.GH_TOKEN}`,
|
||||
Accept: "application/vnd.github+json",
|
||||
"X-GitHub-Api-Version": "2022-11-28",
|
||||
};
|
||||
|
||||
async function rest(url, method = "GET", body) {
|
||||
const res = await fetch(url, {
|
||||
method,
|
||||
headers,
|
||||
body: body ? JSON.stringify(body) : undefined
|
||||
});
|
||||
if (!res.ok) throw new Error(`${res.status} ${url}: ${await res.text()}`);
|
||||
return res.status === 204 ? null : res.json();
|
||||
}
|
||||
|
||||
async function graphql(query, variables) {
|
||||
const res = await fetch("https://api.github.com/graphql", {
|
||||
method: "POST",
|
||||
headers,
|
||||
body: JSON.stringify({ query, variables })
|
||||
});
|
||||
if (!res.ok) throw new Error(`${res.status}: ${await res.text()}`);
|
||||
const json = await res.json();
|
||||
if (json.errors) throw new Error(JSON.stringify(json.errors));
|
||||
return json.data;
|
||||
}
|
||||
|
||||
async function addLabel(owner, repo, issueNumber, labels) {
|
||||
return rest(
|
||||
`https://api.github.com/repos/${owner}/${repo}/issues/${issueNumber}/labels`,
|
||||
"POST",
|
||||
{ labels }
|
||||
);
|
||||
}
|
||||
|
||||
async function addComment(owner, repo, issueNumber, body) {
|
||||
return rest(
|
||||
`https://api.github.com/repos/${owner}/${repo}/issues/${issueNumber}/comments`,
|
||||
"POST",
|
||||
{ body }
|
||||
);
|
||||
}
|
||||
|
||||
async function closeIssue(owner, repo, issueNumber) {
|
||||
return rest(
|
||||
`https://api.github.com/repos/${owner}/${repo}/issues/${issueNumber}`,
|
||||
"PATCH",
|
||||
{ state: "closed", state_reason: "completed" }
|
||||
);
|
||||
}
|
||||
|
||||
async function getIssueNodeId(owner, repo, issueNumber) {
|
||||
const issue = await rest(`https://api.github.com/repos/${owner}/${repo}/issues/${issueNumber}`);
|
||||
return issue.node_id;
|
||||
}
|
||||
|
||||
async function addToProject(issueNodeId) {
|
||||
const mutation = `
|
||||
mutation($projectId: ID!, $contentId: ID!) {
|
||||
addProjectV2ItemById(input: {projectId: $projectId, contentId: $contentId}) {
|
||||
item { id }
|
||||
}
|
||||
}
|
||||
`;
|
||||
|
||||
const data = await graphql(mutation, {
|
||||
projectId: process.env.PROJECT_ID,
|
||||
contentId: issueNodeId
|
||||
});
|
||||
|
||||
return data.addProjectV2ItemById.item.id;
|
||||
}
|
||||
|
||||
async function setTextField(itemId, fieldId, value) {
|
||||
const mutation = `
|
||||
mutation($projectId: ID!, $itemId: ID!, $fieldId: ID!, $value: String!) {
|
||||
updateProjectV2ItemFieldValue(input: {
|
||||
projectId: $projectId,
|
||||
itemId: $itemId,
|
||||
fieldId: $fieldId,
|
||||
value: { text: $value }
|
||||
}) {
|
||||
projectV2Item { id }
|
||||
}
|
||||
}
|
||||
`;
|
||||
|
||||
return graphql(mutation, {
|
||||
projectId: process.env.PROJECT_ID,
|
||||
itemId,
|
||||
fieldId,
|
||||
value
|
||||
});
|
||||
}
|
||||
|
||||
for (const d of decisions) {
|
||||
const [owner, repo] = d.repository.split("/");
|
||||
|
||||
if (d.final_decision === "AUTO_CLOSE") {
|
||||
if (dryRun) continue;
|
||||
|
||||
await addLabel(owner, repo, d.issue_number, ["auto-closed-resolved"]);
|
||||
await addComment(
|
||||
owner,
|
||||
repo,
|
||||
d.issue_number,
|
||||
d.model.close_comment ||
|
||||
"This appears resolved based on linked evidence, so we’re closing it automatically. Reply if this still reproduces and we’ll reopen."
|
||||
);
|
||||
await closeIssue(owner, repo, d.issue_number);
|
||||
}
|
||||
|
||||
if (d.final_decision === "MANUAL_REVIEW") {
|
||||
await addLabel(owner, repo, d.issue_number, ["resolution-candidate"]);
|
||||
|
||||
const issueNodeId = await getIssueNodeId(owner, repo, d.issue_number);
|
||||
const itemId = await addToProject(issueNodeId);
|
||||
|
||||
if (process.env.PROJECT_CONFIDENCE_FIELD_ID) {
|
||||
await setTextField(itemId, process.env.PROJECT_CONFIDENCE_FIELD_ID, String(d.model.confidence));
|
||||
}
|
||||
if (process.env.PROJECT_REASON_FIELD_ID) {
|
||||
await setTextField(itemId, process.env.PROJECT_REASON_FIELD_ID, d.model.reason_code);
|
||||
}
|
||||
if (process.env.PROJECT_EVIDENCE_FIELD_ID) {
|
||||
await setTextField(itemId, process.env.PROJECT_EVIDENCE_FIELD_ID, d.issue_url);
|
||||
}
|
||||
if (process.env.PROJECT_LINKED_PR_FIELD_ID) {
|
||||
const linked = (d.model.hard_signals || []).map(x => x.url).join(", ");
|
||||
if (linked) {
|
||||
await setTextField(itemId, process.env.PROJECT_LINKED_PR_FIELD_ID, linked);
|
||||
}
|
||||
}
|
||||
if (process.env.PROJECT_REPO_FIELD_ID) {
|
||||
await setTextField(itemId, process.env.PROJECT_REPO_FIELD_ID, d.repository);
|
||||
}
|
||||
|
||||
await addComment(
|
||||
owner,
|
||||
repo,
|
||||
d.issue_number,
|
||||
d.model.manual_review_note ||
|
||||
"This issue looks like a possible resolution candidate, but not with enough certainty for automatic closure. Added to the review queue."
|
||||
);
|
||||
}
|
||||
}
|
||||
125
.github/issue-resolution/scripts/classify-candidates.mjs
vendored
Normal file
125
.github/issue-resolution/scripts/classify-candidates.mjs
vendored
Normal file
@@ -0,0 +1,125 @@
|
||||
import fs from "node:fs/promises";
|
||||
|
||||
const candidates = JSON.parse(await fs.readFile("candidates.json", "utf8"));
|
||||
|
||||
function isMaintainerRole(role) {
|
||||
return ["MEMBER", "OWNER", "COLLABORATOR"].includes(role || "");
|
||||
}
|
||||
|
||||
function preScore(candidate) {
|
||||
let score = 0;
|
||||
const hardSignals = [];
|
||||
const contradictions = [];
|
||||
|
||||
for (const t of candidate.timeline) {
|
||||
const sourceIssue = t.source?.issue;
|
||||
|
||||
if (t.event === "cross-referenced" && sourceIssue?.pull_request?.html_url) {
|
||||
hardSignals.push({
|
||||
type: "merged_pr",
|
||||
url: sourceIssue.html_url
|
||||
});
|
||||
score += 40; // provisional until PR merged state is verified
|
||||
}
|
||||
|
||||
if (["referenced", "connected"].includes(t.event)) {
|
||||
score += 10;
|
||||
}
|
||||
}
|
||||
|
||||
for (const c of candidate.comments) {
|
||||
const body = c.body.toLowerCase();
|
||||
|
||||
if (
|
||||
isMaintainerRole(c.author_association) &&
|
||||
/\b(fixed|resolved|duplicate|superseded|closing)\b/.test(body)
|
||||
) {
|
||||
score += 25;
|
||||
hardSignals.push({
|
||||
type: "maintainer_comment",
|
||||
url: c.html_url
|
||||
});
|
||||
}
|
||||
|
||||
if (/\b(still broken|still happening|not fixed|reproducible)\b/.test(body)) {
|
||||
score -= 50;
|
||||
contradictions.push({
|
||||
type: "later_unresolved_comment",
|
||||
url: c.html_url
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
return { score, hardSignals, contradictions };
|
||||
}
|
||||
|
||||
async function callGitHubModel(issuePacket) {
|
||||
// Replace this stub with the GitHub Models inference call used by your org.
|
||||
// The workflow already has models: read permission.
|
||||
return {
|
||||
decision: "MANUAL_REVIEW",
|
||||
reason_code: "likely_fixed_but_unconfirmed",
|
||||
confidence: 0.74,
|
||||
hard_signals: [],
|
||||
contradictions: [],
|
||||
summary: "Potential resolution candidate; evidence is not strong enough to close automatically.",
|
||||
close_comment: "This appears resolved, so we’re closing it automatically. Reply if this is still reproducible.",
|
||||
manual_review_note: "Potential resolution candidate. Please review evidence before closing."
|
||||
};
|
||||
}
|
||||
|
||||
function enforcePolicy(modelOut, pre) {
|
||||
const approvedReasons = new Set([
|
||||
"resolved_by_merged_pr",
|
||||
"maintainer_confirmed_resolved",
|
||||
"duplicate_confirmed",
|
||||
"superseded_confirmed"
|
||||
]);
|
||||
|
||||
const hasHardSignal =
|
||||
(modelOut.hard_signals || []).some(s =>
|
||||
["merged_pr", "maintainer_comment", "duplicate_reference", "superseded_reference"].includes(s.type)
|
||||
) || pre.hardSignals.length > 0;
|
||||
|
||||
const hasContradiction =
|
||||
(modelOut.contradictions || []).length > 0 || pre.contradictions.length > 0;
|
||||
|
||||
if (
|
||||
modelOut.decision === "AUTO_CLOSE" &&
|
||||
modelOut.confidence >= 0.97 &&
|
||||
approvedReasons.has(modelOut.reason_code) &&
|
||||
hasHardSignal &&
|
||||
!hasContradiction
|
||||
) {
|
||||
return "AUTO_CLOSE";
|
||||
}
|
||||
|
||||
if (
|
||||
modelOut.decision === "MANUAL_REVIEW" ||
|
||||
modelOut.confidence >= 0.60 ||
|
||||
pre.score >= 25
|
||||
) {
|
||||
return "MANUAL_REVIEW";
|
||||
}
|
||||
|
||||
return "KEEP_OPEN";
|
||||
}
|
||||
|
||||
const decisions = [];
|
||||
for (const candidate of candidates) {
|
||||
const pre = preScore(candidate);
|
||||
const modelOut = await callGitHubModel(candidate);
|
||||
const finalDecision = enforcePolicy(modelOut, pre);
|
||||
|
||||
decisions.push({
|
||||
repository: candidate.repository,
|
||||
issue_number: candidate.issue.number,
|
||||
issue_url: candidate.issue.html_url,
|
||||
title: candidate.issue.title,
|
||||
pre_score: pre.score,
|
||||
final_decision: finalDecision,
|
||||
model: modelOut
|
||||
});
|
||||
}
|
||||
|
||||
await fs.writeFile("decisions.json", JSON.stringify(decisions, null, 2));
|
||||
50
.github/workflows/issue-resolution-triage.yml
vendored
Normal file
50
.github/workflows/issue-resolution-triage.yml
vendored
Normal file
@@ -0,0 +1,50 @@
|
||||
name: issue-resolution-triage
|
||||
|
||||
on:
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
dry_run:
|
||||
description: "If true, do not close issues"
|
||||
required: false
|
||||
default: "true"
|
||||
max_issues:
|
||||
description: "How many issues to process"
|
||||
required: false
|
||||
default: "100"
|
||||
schedule:
|
||||
- cron: "17 2 * * *"
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
issues: write
|
||||
pull-requests: read
|
||||
models: read
|
||||
|
||||
jobs:
|
||||
triage:
|
||||
runs-on: ubuntu-latest
|
||||
env:
|
||||
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
DRY_RUN: ${{ inputs.dry_run || 'true' }}
|
||||
MAX_ISSUES: ${{ inputs.max_issues || '100' }}
|
||||
REPO: ${{ github.repository }}
|
||||
PROJECT_ID: ${{ vars.ISSUE_REVIEW_PROJECT_ID }}
|
||||
PROJECT_STATUS_FIELD_ID: ${{ vars.PROJECT_STATUS_FIELD_ID }}
|
||||
PROJECT_CONFIDENCE_FIELD_ID: ${{ vars.PROJECT_CONFIDENCE_FIELD_ID }}
|
||||
PROJECT_REASON_FIELD_ID: ${{ vars.PROJECT_REASON_FIELD_ID }}
|
||||
PROJECT_EVIDENCE_FIELD_ID: ${{ vars.PROJECT_EVIDENCE_FIELD_ID }}
|
||||
PROJECT_LINKED_PR_FIELD_ID: ${{ vars.PROJECT_LINKED_PR_FIELD_ID }}
|
||||
PROJECT_REPO_FIELD_ID: ${{ vars.PROJECT_REPO_FIELD_ID }}
|
||||
PROJECT_STATUS_OPTION_NEEDS_REVIEW_ID: ${{ vars.PROJECT_STATUS_OPTION_NEEDS_REVIEW_ID }}
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: "20"
|
||||
|
||||
- run: npm ci
|
||||
- run: node scripts/fetch-candidates.mjs
|
||||
- run: node scripts/classify-candidates.mjs
|
||||
- run: node scripts/apply-decisions.mjs
|
||||
@@ -151,7 +151,6 @@ func init() {
|
||||
rootCmd.AddCommand(logoutCmd)
|
||||
rootCmd.AddCommand(versionCmd)
|
||||
rootCmd.AddCommand(sshCmd)
|
||||
rootCmd.AddCommand(vncCmd)
|
||||
rootCmd.AddCommand(networksCMD)
|
||||
rootCmd.AddCommand(forwardingRulesCmd)
|
||||
rootCmd.AddCommand(debugCmd)
|
||||
|
||||
@@ -36,10 +36,7 @@ const (
|
||||
enableSSHLocalPortForwardFlag = "enable-ssh-local-port-forwarding"
|
||||
enableSSHRemotePortForwardFlag = "enable-ssh-remote-port-forwarding"
|
||||
disableSSHAuthFlag = "disable-ssh-auth"
|
||||
jwtCacheTTLFlag = "jwt-cache-ttl"
|
||||
|
||||
// Alias for backward compatibility.
|
||||
sshJWTCacheTTLFlag = "ssh-jwt-cache-ttl"
|
||||
sshJWTCacheTTLFlag = "ssh-jwt-cache-ttl"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -64,7 +61,7 @@ var (
|
||||
enableSSHLocalPortForward bool
|
||||
enableSSHRemotePortForward bool
|
||||
disableSSHAuth bool
|
||||
jwtCacheTTL int
|
||||
sshJWTCacheTTL int
|
||||
)
|
||||
|
||||
func init() {
|
||||
@@ -74,9 +71,7 @@ func init() {
|
||||
upCmd.PersistentFlags().BoolVar(&enableSSHLocalPortForward, enableSSHLocalPortForwardFlag, false, "Enable local port forwarding for SSH server")
|
||||
upCmd.PersistentFlags().BoolVar(&enableSSHRemotePortForward, enableSSHRemotePortForwardFlag, false, "Enable remote port forwarding for SSH server")
|
||||
upCmd.PersistentFlags().BoolVar(&disableSSHAuth, disableSSHAuthFlag, false, "Disable SSH authentication")
|
||||
upCmd.PersistentFlags().IntVar(&jwtCacheTTL, jwtCacheTTLFlag, 0, "JWT token cache TTL in seconds (0=disabled)")
|
||||
upCmd.PersistentFlags().IntVar(&jwtCacheTTL, sshJWTCacheTTLFlag, 0, "JWT token cache TTL in seconds (alias for --jwt-cache-ttl)")
|
||||
_ = upCmd.PersistentFlags().MarkDeprecated(sshJWTCacheTTLFlag, "use --jwt-cache-ttl instead")
|
||||
upCmd.PersistentFlags().IntVar(&sshJWTCacheTTL, sshJWTCacheTTLFlag, 0, "SSH JWT token cache TTL in seconds (0=disabled)")
|
||||
|
||||
sshCmd.PersistentFlags().IntVarP(&port, "port", "p", sshserver.DefaultSSHPort, "Remote SSH port")
|
||||
sshCmd.PersistentFlags().StringVarP(&username, "user", "u", "", sshUsernameDesc)
|
||||
|
||||
@@ -356,9 +356,6 @@ func setupSetConfigReq(customDNSAddressConverted []byte, cmd *cobra.Command, pro
|
||||
if cmd.Flag(serverSSHAllowedFlag).Changed {
|
||||
req.ServerSSHAllowed = &serverSSHAllowed
|
||||
}
|
||||
if cmd.Flag(serverVNCAllowedFlag).Changed {
|
||||
req.ServerVNCAllowed = &serverVNCAllowed
|
||||
}
|
||||
if cmd.Flag(enableSSHRootFlag).Changed {
|
||||
req.EnableSSHRoot = &enableSSHRoot
|
||||
}
|
||||
@@ -374,12 +371,9 @@ func setupSetConfigReq(customDNSAddressConverted []byte, cmd *cobra.Command, pro
|
||||
if cmd.Flag(disableSSHAuthFlag).Changed {
|
||||
req.DisableSSHAuth = &disableSSHAuth
|
||||
}
|
||||
if cmd.Flag(disableVNCAuthFlag).Changed {
|
||||
req.DisableVNCAuth = &disableVNCAuth
|
||||
}
|
||||
if cmd.Flag(jwtCacheTTLFlag).Changed || cmd.Flag(sshJWTCacheTTLFlag).Changed {
|
||||
jwtCacheTTL32 := int32(jwtCacheTTL)
|
||||
req.SshJWTCacheTTL = &jwtCacheTTL32
|
||||
if cmd.Flag(sshJWTCacheTTLFlag).Changed {
|
||||
sshJWTCacheTTL32 := int32(sshJWTCacheTTL)
|
||||
req.SshJWTCacheTTL = &sshJWTCacheTTL32
|
||||
}
|
||||
if cmd.Flag(interfaceNameFlag).Changed {
|
||||
if err := parseInterfaceName(interfaceName); err != nil {
|
||||
@@ -464,9 +458,6 @@ func setupConfig(customDNSAddressConverted []byte, cmd *cobra.Command, configFil
|
||||
if cmd.Flag(serverSSHAllowedFlag).Changed {
|
||||
ic.ServerSSHAllowed = &serverSSHAllowed
|
||||
}
|
||||
if cmd.Flag(serverVNCAllowedFlag).Changed {
|
||||
ic.ServerVNCAllowed = &serverVNCAllowed
|
||||
}
|
||||
|
||||
if cmd.Flag(enableSSHRootFlag).Changed {
|
||||
ic.EnableSSHRoot = &enableSSHRoot
|
||||
@@ -488,12 +479,8 @@ func setupConfig(customDNSAddressConverted []byte, cmd *cobra.Command, configFil
|
||||
ic.DisableSSHAuth = &disableSSHAuth
|
||||
}
|
||||
|
||||
if cmd.Flag(disableVNCAuthFlag).Changed {
|
||||
ic.DisableVNCAuth = &disableVNCAuth
|
||||
}
|
||||
|
||||
if cmd.Flag(jwtCacheTTLFlag).Changed || cmd.Flag(sshJWTCacheTTLFlag).Changed {
|
||||
ic.SSHJWTCacheTTL = &jwtCacheTTL
|
||||
if cmd.Flag(sshJWTCacheTTLFlag).Changed {
|
||||
ic.SSHJWTCacheTTL = &sshJWTCacheTTL
|
||||
}
|
||||
|
||||
if cmd.Flag(interfaceNameFlag).Changed {
|
||||
@@ -595,9 +582,6 @@ func setupLoginRequest(providedSetupKey string, customDNSAddressConverted []byte
|
||||
if cmd.Flag(serverSSHAllowedFlag).Changed {
|
||||
loginRequest.ServerSSHAllowed = &serverSSHAllowed
|
||||
}
|
||||
if cmd.Flag(serverVNCAllowedFlag).Changed {
|
||||
loginRequest.ServerVNCAllowed = &serverVNCAllowed
|
||||
}
|
||||
|
||||
if cmd.Flag(enableSSHRootFlag).Changed {
|
||||
loginRequest.EnableSSHRoot = &enableSSHRoot
|
||||
@@ -619,13 +603,9 @@ func setupLoginRequest(providedSetupKey string, customDNSAddressConverted []byte
|
||||
loginRequest.DisableSSHAuth = &disableSSHAuth
|
||||
}
|
||||
|
||||
if cmd.Flag(disableVNCAuthFlag).Changed {
|
||||
loginRequest.DisableVNCAuth = &disableVNCAuth
|
||||
}
|
||||
|
||||
if cmd.Flag(jwtCacheTTLFlag).Changed || cmd.Flag(sshJWTCacheTTLFlag).Changed {
|
||||
jwtCacheTTL32 := int32(jwtCacheTTL)
|
||||
loginRequest.SshJWTCacheTTL = &jwtCacheTTL32
|
||||
if cmd.Flag(sshJWTCacheTTLFlag).Changed {
|
||||
sshJWTCacheTTL32 := int32(sshJWTCacheTTL)
|
||||
loginRequest.SshJWTCacheTTL = &sshJWTCacheTTL32
|
||||
}
|
||||
|
||||
if cmd.Flag(disableAutoConnectFlag).Changed {
|
||||
|
||||
@@ -1,271 +0,0 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"os"
|
||||
"os/signal"
|
||||
"os/user"
|
||||
"strings"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/spf13/cobra"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/credentials/insecure"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
nbssh "github.com/netbirdio/netbird/client/ssh"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
)
|
||||
|
||||
var (
|
||||
vncUsername string
|
||||
vncHost string
|
||||
vncMode string
|
||||
vncListen string
|
||||
vncNoBrowser bool
|
||||
vncNoCache bool
|
||||
)
|
||||
|
||||
func init() {
|
||||
vncCmd.PersistentFlags().StringVar(&vncUsername, "user", "", "OS username for session mode")
|
||||
vncCmd.PersistentFlags().StringVar(&vncMode, "mode", "attach", "Connection mode: attach (view current display) or session (virtual desktop)")
|
||||
vncCmd.PersistentFlags().StringVar(&vncListen, "listen", "", "Start local VNC proxy on this address (e.g., :5900) for external VNC viewers")
|
||||
vncCmd.PersistentFlags().BoolVar(&vncNoBrowser, noBrowserFlag, false, noBrowserDesc)
|
||||
vncCmd.PersistentFlags().BoolVar(&vncNoCache, "no-cache", false, "Skip cached JWT token and force fresh authentication")
|
||||
}
|
||||
|
||||
var vncCmd = &cobra.Command{
|
||||
Use: "vnc [flags] [user@]host",
|
||||
Short: "Connect to a NetBird peer via VNC",
|
||||
Long: `Connect to a NetBird peer using VNC with JWT-based authentication.
|
||||
The target peer must have the VNC server enabled.
|
||||
|
||||
Two modes are available:
|
||||
- attach: view the current physical display (remote support)
|
||||
- session: start a virtual desktop as the specified user (passwordless login)
|
||||
|
||||
Use --listen to start a local proxy for external VNC viewers:
|
||||
netbird vnc --listen :5900 peer-hostname
|
||||
vncviewer localhost:5900
|
||||
|
||||
Examples:
|
||||
netbird vnc peer-hostname
|
||||
netbird vnc --mode session --user alice peer-hostname
|
||||
netbird vnc --listen :5900 peer-hostname`,
|
||||
Args: cobra.MinimumNArgs(1),
|
||||
RunE: vncFn,
|
||||
}
|
||||
|
||||
func vncFn(cmd *cobra.Command, args []string) error {
|
||||
SetFlagsFromEnvVars(rootCmd)
|
||||
SetFlagsFromEnvVars(cmd)
|
||||
cmd.SetOut(cmd.OutOrStdout())
|
||||
|
||||
logOutput := "console"
|
||||
if firstLogFile := util.FindFirstLogPath(logFiles); firstLogFile != "" && firstLogFile != defaultLogFile {
|
||||
logOutput = firstLogFile
|
||||
}
|
||||
if err := util.InitLog(logLevel, logOutput); err != nil {
|
||||
return fmt.Errorf("init log: %w", err)
|
||||
}
|
||||
|
||||
if err := parseVNCHostArg(args[0]); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
ctx := internal.CtxInitState(cmd.Context())
|
||||
sig := make(chan os.Signal, 1)
|
||||
signal.Notify(sig, syscall.SIGTERM, syscall.SIGINT)
|
||||
vncCtx, cancel := context.WithCancel(ctx)
|
||||
|
||||
errCh := make(chan error, 1)
|
||||
go func() {
|
||||
if err := runVNC(vncCtx, cmd); err != nil {
|
||||
errCh <- err
|
||||
}
|
||||
cancel()
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-sig:
|
||||
cancel()
|
||||
<-vncCtx.Done()
|
||||
return nil
|
||||
case err := <-errCh:
|
||||
return err
|
||||
case <-vncCtx.Done():
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func parseVNCHostArg(arg string) error {
|
||||
if strings.Contains(arg, "@") {
|
||||
parts := strings.SplitN(arg, "@", 2)
|
||||
if len(parts) != 2 || parts[0] == "" || parts[1] == "" {
|
||||
return fmt.Errorf("invalid user@host format")
|
||||
}
|
||||
if vncUsername == "" {
|
||||
vncUsername = parts[0]
|
||||
}
|
||||
vncHost = parts[1]
|
||||
if vncMode == "attach" {
|
||||
vncMode = "session"
|
||||
}
|
||||
} else {
|
||||
vncHost = arg
|
||||
}
|
||||
|
||||
if vncMode == "session" && vncUsername == "" {
|
||||
if sudoUser := os.Getenv("SUDO_USER"); sudoUser != "" {
|
||||
vncUsername = sudoUser
|
||||
} else if currentUser, err := user.Current(); err == nil {
|
||||
vncUsername = currentUser.Username
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func runVNC(ctx context.Context, cmd *cobra.Command) error {
|
||||
grpcAddr := strings.TrimPrefix(daemonAddr, "tcp://")
|
||||
grpcConn, err := grpc.NewClient(grpcAddr, grpc.WithTransportCredentials(insecure.NewCredentials()))
|
||||
if err != nil {
|
||||
return fmt.Errorf("connect to daemon: %w", err)
|
||||
}
|
||||
defer func() { _ = grpcConn.Close() }()
|
||||
|
||||
daemonClient := proto.NewDaemonServiceClient(grpcConn)
|
||||
|
||||
if vncMode == "session" {
|
||||
cmd.Printf("Connecting to %s@%s [session mode]...\n", vncUsername, vncHost)
|
||||
} else {
|
||||
cmd.Printf("Connecting to %s [attach mode]...\n", vncHost)
|
||||
}
|
||||
|
||||
// Obtain JWT token. If the daemon has no SSO configured, proceed without one
|
||||
// (the server will accept unauthenticated connections if --disable-vnc-auth is set).
|
||||
var jwtToken string
|
||||
hint := profilemanager.GetLoginHint()
|
||||
var browserOpener func(string) error
|
||||
if !vncNoBrowser {
|
||||
browserOpener = util.OpenBrowser
|
||||
}
|
||||
|
||||
token, err := nbssh.RequestJWTToken(ctx, daemonClient, nil, cmd.ErrOrStderr(), !vncNoCache, hint, browserOpener)
|
||||
if err != nil {
|
||||
log.Debugf("JWT authentication unavailable, connecting without token: %v", err)
|
||||
} else {
|
||||
jwtToken = token
|
||||
log.Debug("JWT authentication successful")
|
||||
}
|
||||
|
||||
// Connect to the VNC server on the standard port (5900). The peer's firewall
|
||||
// DNATs 5900 -> 25900 (internal), so both ports work on the overlay network.
|
||||
vncAddr := net.JoinHostPort(vncHost, "5900")
|
||||
vncConn, err := net.DialTimeout("tcp", vncAddr, vncDialTimeout)
|
||||
if err != nil {
|
||||
return fmt.Errorf("connect to VNC at %s: %w", vncAddr, err)
|
||||
}
|
||||
defer vncConn.Close()
|
||||
|
||||
// Send session header with mode, username, and JWT.
|
||||
if err := sendVNCHeader(vncConn, vncMode, vncUsername, jwtToken); err != nil {
|
||||
return fmt.Errorf("send VNC header: %w", err)
|
||||
}
|
||||
|
||||
cmd.Printf("VNC connected to %s\n", vncHost)
|
||||
|
||||
if vncListen != "" {
|
||||
return runVNCLocalProxy(ctx, cmd, vncConn)
|
||||
}
|
||||
|
||||
// No --listen flag: inform the user they need to use --listen for external viewers.
|
||||
cmd.Printf("VNC tunnel established. Use --listen :5900 to proxy for local VNC viewers.\n")
|
||||
cmd.Printf("Press Ctrl+C to disconnect.\n")
|
||||
<-ctx.Done()
|
||||
return nil
|
||||
}
|
||||
|
||||
const vncDialTimeout = 15 * time.Second
|
||||
|
||||
// sendVNCHeader writes the NetBird VNC session header.
|
||||
func sendVNCHeader(conn net.Conn, mode, username, jwt string) error {
|
||||
var modeByte byte
|
||||
if mode == "session" {
|
||||
modeByte = 1
|
||||
}
|
||||
|
||||
usernameBytes := []byte(username)
|
||||
jwtBytes := []byte(jwt)
|
||||
hdr := make([]byte, 3+len(usernameBytes)+2+len(jwtBytes))
|
||||
hdr[0] = modeByte
|
||||
binary.BigEndian.PutUint16(hdr[1:3], uint16(len(usernameBytes)))
|
||||
off := 3
|
||||
copy(hdr[off:], usernameBytes)
|
||||
off += len(usernameBytes)
|
||||
binary.BigEndian.PutUint16(hdr[off:off+2], uint16(len(jwtBytes)))
|
||||
off += 2
|
||||
copy(hdr[off:], jwtBytes)
|
||||
|
||||
_, err := conn.Write(hdr)
|
||||
return err
|
||||
}
|
||||
|
||||
// runVNCLocalProxy listens on the given address and proxies incoming
|
||||
// connections to the already-established VNC tunnel.
|
||||
func runVNCLocalProxy(ctx context.Context, cmd *cobra.Command, vncConn net.Conn) error {
|
||||
listener, err := net.Listen("tcp", vncListen)
|
||||
if err != nil {
|
||||
return fmt.Errorf("listen on %s: %w", vncListen, err)
|
||||
}
|
||||
defer listener.Close()
|
||||
|
||||
cmd.Printf("VNC proxy listening on %s - connect with your VNC viewer\n", listener.Addr())
|
||||
cmd.Printf("Press Ctrl+C to stop.\n")
|
||||
|
||||
go func() {
|
||||
<-ctx.Done()
|
||||
listener.Close()
|
||||
}()
|
||||
|
||||
// Accept a single viewer connection. VNC is single-session: the RFB
|
||||
// handshake completes on vncConn for the first viewer, so subsequent
|
||||
// viewers would get a mid-stream connection. The loop handles transient
|
||||
// accept errors until a valid connection arrives.
|
||||
for {
|
||||
clientConn, err := listener.Accept()
|
||||
if err != nil {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil
|
||||
default:
|
||||
}
|
||||
log.Debugf("accept VNC proxy client: %v", err)
|
||||
continue
|
||||
}
|
||||
|
||||
cmd.Printf("VNC viewer connected from %s\n", clientConn.RemoteAddr())
|
||||
|
||||
// Bidirectional copy.
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
io.Copy(vncConn, clientConn)
|
||||
close(done)
|
||||
}()
|
||||
io.Copy(clientConn, vncConn)
|
||||
<-done
|
||||
clientConn.Close()
|
||||
|
||||
cmd.Printf("VNC viewer disconnected\n")
|
||||
return nil
|
||||
}
|
||||
}
|
||||
@@ -1,62 +0,0 @@
|
||||
//go:build windows
|
||||
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"os"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/spf13/cobra"
|
||||
|
||||
vncserver "github.com/netbirdio/netbird/client/vnc/server"
|
||||
)
|
||||
|
||||
var vncAgentPort string
|
||||
|
||||
func init() {
|
||||
vncAgentCmd.Flags().StringVar(&vncAgentPort, "port", "15900", "Port for the VNC agent to listen on")
|
||||
rootCmd.AddCommand(vncAgentCmd)
|
||||
}
|
||||
|
||||
// vncAgentCmd runs a VNC server in the current user session, listening on
|
||||
// localhost. It is spawned by the NetBird service (Session 0) via
|
||||
// CreateProcessAsUser into the interactive console session.
|
||||
var vncAgentCmd = &cobra.Command{
|
||||
Use: "vnc-agent",
|
||||
Short: "Run VNC capture agent (internal, spawned by service)",
|
||||
Hidden: true,
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
// Agent's stderr is piped to the service which relogs it.
|
||||
// Use JSON format with caller info for structured parsing.
|
||||
log.SetReportCaller(true)
|
||||
log.SetFormatter(&log.JSONFormatter{})
|
||||
log.SetOutput(os.Stderr)
|
||||
|
||||
sessionID := vncserver.GetCurrentSessionID()
|
||||
log.Infof("VNC agent starting on 127.0.0.1:%s (session %d)", vncAgentPort, sessionID)
|
||||
|
||||
capturer := vncserver.NewDesktopCapturer()
|
||||
injector := vncserver.NewWindowsInputInjector()
|
||||
srv := vncserver.New(capturer, injector, "")
|
||||
// Auth is handled by the service. The agent verifies a token on each
|
||||
// connection to ensure only the service process can connect.
|
||||
// The token is passed via environment variable to avoid exposing it
|
||||
// in the process command line (visible via tasklist/wmic).
|
||||
srv.SetDisableAuth(true)
|
||||
srv.SetAgentToken(os.Getenv("NB_VNC_AGENT_TOKEN"))
|
||||
|
||||
port, err := netip.ParseAddrPort("127.0.0.1:" + vncAgentPort)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
loopback := netip.PrefixFrom(netip.AddrFrom4([4]byte{127, 0, 0, 0}), 8)
|
||||
if err := srv.Start(cmd.Context(), port, loopback); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
<-cmd.Context().Done()
|
||||
return srv.Stop()
|
||||
},
|
||||
}
|
||||
@@ -1,16 +0,0 @@
|
||||
package cmd
|
||||
|
||||
const (
|
||||
serverVNCAllowedFlag = "allow-server-vnc"
|
||||
disableVNCAuthFlag = "disable-vnc-auth"
|
||||
)
|
||||
|
||||
var (
|
||||
serverVNCAllowed bool
|
||||
disableVNCAuth bool
|
||||
)
|
||||
|
||||
func init() {
|
||||
upCmd.PersistentFlags().BoolVar(&serverVNCAllowed, serverVNCAllowedFlag, false, "Allow embedded VNC server on peer")
|
||||
upCmd.PersistentFlags().BoolVar(&disableVNCAuth, disableVNCAuthFlag, false, "Disable JWT authentication for VNC")
|
||||
}
|
||||
@@ -1,229 +0,0 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"crypto/ecdh"
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/signal"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"syscall"
|
||||
"text/tabwriter"
|
||||
"time"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
|
||||
vncserver "github.com/netbirdio/netbird/client/vnc/server"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
)
|
||||
|
||||
var vncRecDir string
|
||||
|
||||
func init() {
|
||||
vncRecPlayCmd.Flags().StringVar(&vncRecDir, "dir", "", "Recording directory (default: auto-detect)")
|
||||
vncRecListCmd.Flags().StringVar(&vncRecDir, "dir", "", "Recording directory (default: auto-detect)")
|
||||
vncRecCmd.AddCommand(vncRecListCmd)
|
||||
vncRecCmd.AddCommand(vncRecPlayCmd)
|
||||
vncRecCmd.AddCommand(vncRecKeygenCmd)
|
||||
vncCmd.AddCommand(vncRecCmd)
|
||||
}
|
||||
|
||||
var vncRecCmd = &cobra.Command{
|
||||
Use: "rec",
|
||||
Short: "Manage VNC session recordings",
|
||||
}
|
||||
|
||||
var vncRecKeygenCmd = &cobra.Command{
|
||||
Use: "keygen",
|
||||
Short: "Generate an X25519 keypair for recording encryption",
|
||||
Long: `Generates an X25519 keypair. Put the public key in management settings
|
||||
(Session Recording > Encryption Key). Keep the private key safe for decrypting recordings.`,
|
||||
RunE: vncRecKeygenFn,
|
||||
}
|
||||
|
||||
var vncRecListCmd = &cobra.Command{
|
||||
Use: "list",
|
||||
Short: "List VNC session recordings",
|
||||
RunE: vncRecListFn,
|
||||
}
|
||||
|
||||
var vncRecPlayCmd = &cobra.Command{
|
||||
Use: "play <file-or-name>",
|
||||
Short: "Open a VNC recording in the browser",
|
||||
Long: `Opens a browser-based player with playback controls:
|
||||
play/pause, seek, speed (0.25x to 8x), keyboard shortcuts.
|
||||
|
||||
Examples:
|
||||
netbird vnc rec play last
|
||||
netbird vnc rec play 20260416-104433_vnc.rec`,
|
||||
Args: cobra.ExactArgs(1),
|
||||
RunE: vncRecPlayFn,
|
||||
}
|
||||
|
||||
|
||||
func vncRecListFn(cmd *cobra.Command, _ []string) error {
|
||||
dir, err := resolveVNCRecDir()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
entries, err := os.ReadDir(dir)
|
||||
if err != nil {
|
||||
return fmt.Errorf("read recording dir %s: %w", dir, err)
|
||||
}
|
||||
|
||||
w := tabwriter.NewWriter(cmd.OutOrStdout(), 0, 0, 2, ' ', 0)
|
||||
fmt.Fprintln(w, "FILE\tSIZE\tDIMENSIONS\tUSER\tREMOTE\tMODE\tDATE")
|
||||
|
||||
for _, entry := range entries {
|
||||
if entry.IsDir() || !strings.HasSuffix(entry.Name(), ".rec") {
|
||||
continue
|
||||
}
|
||||
filePath := filepath.Join(dir, entry.Name())
|
||||
info, err := entry.Info()
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
header, err := vncserver.ReadRecordingHeader(filePath)
|
||||
if err != nil {
|
||||
fmt.Fprintf(w, "%s\t%s\t?\t?\t?\t?\t?\n", entry.Name(), vncFormatSize(info.Size()))
|
||||
continue
|
||||
}
|
||||
|
||||
fmt.Fprintf(w, "%s\t%s\t%dx%d\t%s\t%s\t%s\t%s\n",
|
||||
entry.Name(),
|
||||
vncFormatSize(info.Size()),
|
||||
header.Width, header.Height,
|
||||
header.Meta.User,
|
||||
header.Meta.RemoteAddr,
|
||||
header.Meta.Mode,
|
||||
header.StartTime.Format("2006-01-02 15:04:05"),
|
||||
)
|
||||
}
|
||||
|
||||
return w.Flush()
|
||||
}
|
||||
|
||||
func vncRecPlayFn(cmd *cobra.Command, args []string) error {
|
||||
filePath, err := resolveVNCRecFile(args[0])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
header, err := vncserver.ReadRecordingHeader(filePath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("read recording: %w", err)
|
||||
}
|
||||
|
||||
cmd.Printf("Recording: %s (%dx%d)\n", filepath.Base(filePath), header.Width, header.Height)
|
||||
|
||||
url, err := vncserver.ServeWebPlayer(filePath, "localhost:0")
|
||||
if err != nil {
|
||||
return fmt.Errorf("start web player: %w", err)
|
||||
}
|
||||
cmd.Printf("Player: %s\n", url)
|
||||
if err := util.OpenBrowser(url); err != nil {
|
||||
cmd.Printf("Open %s in your browser\n", url)
|
||||
}
|
||||
cmd.Printf("Press Ctrl+C to stop.\n")
|
||||
sig := make(chan os.Signal, 1)
|
||||
signal.Notify(sig, syscall.SIGINT, syscall.SIGTERM)
|
||||
<-sig
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
func vncRecKeygenFn(cmd *cobra.Command, _ []string) error {
|
||||
priv, err := ecdh.X25519().GenerateKey(rand.Reader)
|
||||
if err != nil {
|
||||
return fmt.Errorf("generate key: %w", err)
|
||||
}
|
||||
|
||||
privB64 := base64.StdEncoding.EncodeToString(priv.Bytes())
|
||||
pubB64 := base64.StdEncoding.EncodeToString(priv.PublicKey().Bytes())
|
||||
|
||||
cmd.Printf("Private key (keep secret, for decrypting recordings):\n %s\n\n", privB64)
|
||||
cmd.Printf("Public key (paste into management Settings > Session Recording > Encryption Key):\n %s\n", pubB64)
|
||||
return nil
|
||||
}
|
||||
|
||||
func vncFormatSize(size int64) string {
|
||||
switch {
|
||||
case size >= 1<<20:
|
||||
return fmt.Sprintf("%.1fM", float64(size)/float64(1<<20))
|
||||
case size >= 1<<10:
|
||||
return fmt.Sprintf("%.1fK", float64(size)/float64(1<<10))
|
||||
default:
|
||||
return fmt.Sprintf("%dB", size)
|
||||
}
|
||||
}
|
||||
|
||||
func resolveVNCRecDir() (string, error) {
|
||||
if vncRecDir != "" {
|
||||
return vncRecDir, nil
|
||||
}
|
||||
candidates := []string{
|
||||
"/var/lib/netbird/recordings/vnc",
|
||||
filepath.Join(os.Getenv("HOME"), ".netbird/recordings/vnc"),
|
||||
}
|
||||
for _, dir := range candidates {
|
||||
if fi, err := os.Stat(dir); err == nil && fi.IsDir() {
|
||||
return dir, nil
|
||||
}
|
||||
}
|
||||
return "", fmt.Errorf("no VNC recording directory found; use --dir to specify")
|
||||
}
|
||||
|
||||
func resolveVNCRecFile(arg string) (string, error) {
|
||||
if strings.Contains(arg, "/") || strings.Contains(arg, string(os.PathSeparator)) {
|
||||
return arg, nil
|
||||
}
|
||||
|
||||
dir, err := resolveVNCRecDir()
|
||||
if err != nil && arg != "last" {
|
||||
return arg, nil
|
||||
}
|
||||
|
||||
if arg == "last" {
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return findLatestRec(dir)
|
||||
}
|
||||
|
||||
full := filepath.Join(dir, arg)
|
||||
if _, err := os.Stat(full); err == nil {
|
||||
return full, nil
|
||||
}
|
||||
return arg, nil
|
||||
}
|
||||
|
||||
func findLatestRec(dir string) (string, error) {
|
||||
entries, err := os.ReadDir(dir)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("read dir: %w", err)
|
||||
}
|
||||
|
||||
var latest string
|
||||
var latestTime time.Time
|
||||
for _, entry := range entries {
|
||||
if entry.IsDir() || !strings.HasSuffix(entry.Name(), ".rec") {
|
||||
continue
|
||||
}
|
||||
info, err := entry.Info()
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
if info.ModTime().After(latestTime) {
|
||||
latestTime = info.ModTime()
|
||||
latest = filepath.Join(dir, entry.Name())
|
||||
}
|
||||
}
|
||||
if latest == "" {
|
||||
return "", fmt.Errorf("no recordings found in %s", dir)
|
||||
}
|
||||
return latest, nil
|
||||
}
|
||||
11
client/firewall/firewalld/firewalld.go
Normal file
11
client/firewall/firewalld/firewalld.go
Normal file
@@ -0,0 +1,11 @@
|
||||
// Package firewalld integrates with the firewalld daemon so NetBird can place
|
||||
// its wg interface into firewalld's "trusted" zone. This is required because
|
||||
// firewalld's nftables chains are created with NFT_CHAIN_OWNER on recent
|
||||
// versions, which returns EPERM to any other process that tries to insert
|
||||
// rules into them. The workaround mirrors what Tailscale does: let firewalld
|
||||
// itself add the accept rules to its own chains by trusting the interface.
|
||||
package firewalld
|
||||
|
||||
// TrustedZone is the firewalld zone name used for interfaces whose traffic
|
||||
// should bypass firewalld filtering.
|
||||
const TrustedZone = "trusted"
|
||||
260
client/firewall/firewalld/firewalld_linux.go
Normal file
260
client/firewall/firewalld/firewalld_linux.go
Normal file
@@ -0,0 +1,260 @@
|
||||
//go:build linux
|
||||
|
||||
package firewalld
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os/exec"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/godbus/dbus/v5"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
dbusDest = "org.fedoraproject.FirewallD1"
|
||||
dbusPath = "/org/fedoraproject/FirewallD1"
|
||||
dbusRootIface = "org.fedoraproject.FirewallD1"
|
||||
dbusZoneIface = "org.fedoraproject.FirewallD1.zone"
|
||||
|
||||
errZoneAlreadySet = "ZONE_ALREADY_SET"
|
||||
errAlreadyEnabled = "ALREADY_ENABLED"
|
||||
errUnknownIface = "UNKNOWN_INTERFACE"
|
||||
errNotEnabled = "NOT_ENABLED"
|
||||
|
||||
// callTimeout bounds each individual DBus or firewall-cmd invocation.
|
||||
// A fresh context is created for each call so a slow DBus probe can't
|
||||
// exhaust the deadline before the firewall-cmd fallback gets to run.
|
||||
callTimeout = 3 * time.Second
|
||||
)
|
||||
|
||||
var (
|
||||
errDBusUnavailable = errors.New("firewalld dbus unavailable")
|
||||
|
||||
// trustLogOnce ensures the "added to trusted zone" message is logged at
|
||||
// Info level only for the first successful add per process; repeat adds
|
||||
// from other init paths are quieter.
|
||||
trustLogOnce sync.Once
|
||||
|
||||
parentCtxMu sync.RWMutex
|
||||
parentCtx context.Context = context.Background()
|
||||
)
|
||||
|
||||
// SetParentContext installs a parent context whose cancellation aborts any
|
||||
// in-flight TrustInterface call. It does not affect UntrustInterface, which
|
||||
// always uses a fresh Background-rooted timeout so cleanup can still run
|
||||
// during engine shutdown when the engine context is already cancelled.
|
||||
func SetParentContext(ctx context.Context) {
|
||||
parentCtxMu.Lock()
|
||||
parentCtx = ctx
|
||||
parentCtxMu.Unlock()
|
||||
}
|
||||
|
||||
func getParentContext() context.Context {
|
||||
parentCtxMu.RLock()
|
||||
defer parentCtxMu.RUnlock()
|
||||
return parentCtx
|
||||
}
|
||||
|
||||
// TrustInterface places iface into firewalld's trusted zone if firewalld is
|
||||
// running. It is idempotent and best-effort: errors are returned so callers
|
||||
// can log, but a non-running firewalld is not an error. Only the first
|
||||
// successful call per process logs at Info. Respects the parent context set
|
||||
// via SetParentContext so startup-time cancellation unblocks it.
|
||||
func TrustInterface(iface string) error {
|
||||
parent := getParentContext()
|
||||
if !isRunning(parent) {
|
||||
return nil
|
||||
}
|
||||
if err := addTrusted(parent, iface); err != nil {
|
||||
return fmt.Errorf("add %s to firewalld trusted zone: %w", iface, err)
|
||||
}
|
||||
trustLogOnce.Do(func() {
|
||||
log.Infof("added %s to firewalld trusted zone", iface)
|
||||
})
|
||||
log.Debugf("firewalld: ensured %s is in trusted zone", iface)
|
||||
return nil
|
||||
}
|
||||
|
||||
// UntrustInterface removes iface from firewalld's trusted zone if firewalld
|
||||
// is running. Idempotent. Uses a Background-rooted timeout so it still runs
|
||||
// during shutdown after the engine context has been cancelled.
|
||||
func UntrustInterface(iface string) error {
|
||||
if !isRunning(context.Background()) {
|
||||
return nil
|
||||
}
|
||||
if err := removeTrusted(context.Background(), iface); err != nil {
|
||||
return fmt.Errorf("remove %s from firewalld trusted zone: %w", iface, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func newCallContext(parent context.Context) (context.Context, context.CancelFunc) {
|
||||
return context.WithTimeout(parent, callTimeout)
|
||||
}
|
||||
|
||||
func isRunning(parent context.Context) bool {
|
||||
ctx, cancel := newCallContext(parent)
|
||||
ok, err := isRunningDBus(ctx)
|
||||
cancel()
|
||||
if err == nil {
|
||||
return ok
|
||||
}
|
||||
if errors.Is(err, errDBusUnavailable) || errors.Is(err, context.DeadlineExceeded) {
|
||||
ctx, cancel = newCallContext(parent)
|
||||
defer cancel()
|
||||
return isRunningCLI(ctx)
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func addTrusted(parent context.Context, iface string) error {
|
||||
ctx, cancel := newCallContext(parent)
|
||||
err := addDBus(ctx, iface)
|
||||
cancel()
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
if !errors.Is(err, errDBusUnavailable) {
|
||||
log.Debugf("firewalld: dbus add failed, falling back to firewall-cmd: %v", err)
|
||||
}
|
||||
ctx, cancel = newCallContext(parent)
|
||||
defer cancel()
|
||||
return addCLI(ctx, iface)
|
||||
}
|
||||
|
||||
func removeTrusted(parent context.Context, iface string) error {
|
||||
ctx, cancel := newCallContext(parent)
|
||||
err := removeDBus(ctx, iface)
|
||||
cancel()
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
if !errors.Is(err, errDBusUnavailable) {
|
||||
log.Debugf("firewalld: dbus remove failed, falling back to firewall-cmd: %v", err)
|
||||
}
|
||||
ctx, cancel = newCallContext(parent)
|
||||
defer cancel()
|
||||
return removeCLI(ctx, iface)
|
||||
}
|
||||
|
||||
func isRunningDBus(ctx context.Context) (bool, error) {
|
||||
conn, err := dbus.SystemBus()
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("%w: %v", errDBusUnavailable, err)
|
||||
}
|
||||
obj := conn.Object(dbusDest, dbusPath)
|
||||
|
||||
var zone string
|
||||
if err := obj.CallWithContext(ctx, dbusRootIface+".getDefaultZone", 0).Store(&zone); err != nil {
|
||||
return false, fmt.Errorf("firewalld getDefaultZone: %w", err)
|
||||
}
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func isRunningCLI(ctx context.Context) bool {
|
||||
if _, err := exec.LookPath("firewall-cmd"); err != nil {
|
||||
return false
|
||||
}
|
||||
return exec.CommandContext(ctx, "firewall-cmd", "--state").Run() == nil
|
||||
}
|
||||
|
||||
func addDBus(ctx context.Context, iface string) error {
|
||||
conn, err := dbus.SystemBus()
|
||||
if err != nil {
|
||||
return fmt.Errorf("%w: %v", errDBusUnavailable, err)
|
||||
}
|
||||
obj := conn.Object(dbusDest, dbusPath)
|
||||
|
||||
call := obj.CallWithContext(ctx, dbusZoneIface+".addInterface", 0, TrustedZone, iface)
|
||||
if call.Err == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if dbusErrContains(call.Err, errAlreadyEnabled) {
|
||||
return nil
|
||||
}
|
||||
|
||||
if dbusErrContains(call.Err, errZoneAlreadySet) {
|
||||
move := obj.CallWithContext(ctx, dbusZoneIface+".changeZoneOfInterface", 0, TrustedZone, iface)
|
||||
if move.Err != nil {
|
||||
return fmt.Errorf("firewalld changeZoneOfInterface: %w", move.Err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
return fmt.Errorf("firewalld addInterface: %w", call.Err)
|
||||
}
|
||||
|
||||
func removeDBus(ctx context.Context, iface string) error {
|
||||
conn, err := dbus.SystemBus()
|
||||
if err != nil {
|
||||
return fmt.Errorf("%w: %v", errDBusUnavailable, err)
|
||||
}
|
||||
obj := conn.Object(dbusDest, dbusPath)
|
||||
|
||||
call := obj.CallWithContext(ctx, dbusZoneIface+".removeInterface", 0, TrustedZone, iface)
|
||||
if call.Err == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if dbusErrContains(call.Err, errUnknownIface) || dbusErrContains(call.Err, errNotEnabled) {
|
||||
return nil
|
||||
}
|
||||
|
||||
return fmt.Errorf("firewalld removeInterface: %w", call.Err)
|
||||
}
|
||||
|
||||
func addCLI(ctx context.Context, iface string) error {
|
||||
if _, err := exec.LookPath("firewall-cmd"); err != nil {
|
||||
return fmt.Errorf("firewall-cmd not available: %w", err)
|
||||
}
|
||||
|
||||
// --change-interface (no --permanent) binds the interface for the
|
||||
// current runtime only; we do not want membership to persist across
|
||||
// reboots because netbird re-asserts it on every startup.
|
||||
out, err := exec.CommandContext(ctx,
|
||||
"firewall-cmd", "--zone="+TrustedZone, "--change-interface="+iface,
|
||||
).CombinedOutput()
|
||||
if err != nil {
|
||||
return fmt.Errorf("firewall-cmd change-interface: %w: %s", err, strings.TrimSpace(string(out)))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func removeCLI(ctx context.Context, iface string) error {
|
||||
if _, err := exec.LookPath("firewall-cmd"); err != nil {
|
||||
return fmt.Errorf("firewall-cmd not available: %w", err)
|
||||
}
|
||||
|
||||
out, err := exec.CommandContext(ctx,
|
||||
"firewall-cmd", "--zone="+TrustedZone, "--remove-interface="+iface,
|
||||
).CombinedOutput()
|
||||
if err != nil {
|
||||
msg := strings.TrimSpace(string(out))
|
||||
if strings.Contains(msg, errUnknownIface) || strings.Contains(msg, errNotEnabled) {
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("firewall-cmd remove-interface: %w: %s", err, msg)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func dbusErrContains(err error, code string) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
var de dbus.Error
|
||||
if errors.As(err, &de) {
|
||||
for _, b := range de.Body {
|
||||
if s, ok := b.(string); ok && strings.Contains(s, code) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
return strings.Contains(err.Error(), code)
|
||||
}
|
||||
49
client/firewall/firewalld/firewalld_linux_test.go
Normal file
49
client/firewall/firewalld/firewalld_linux_test.go
Normal file
@@ -0,0 +1,49 @@
|
||||
//go:build linux
|
||||
|
||||
package firewalld
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/godbus/dbus/v5"
|
||||
)
|
||||
|
||||
func TestDBusErrContains(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
err error
|
||||
code string
|
||||
want bool
|
||||
}{
|
||||
{"nil error", nil, errZoneAlreadySet, false},
|
||||
{"plain error match", errors.New("ZONE_ALREADY_SET: wt0"), errZoneAlreadySet, true},
|
||||
{"plain error miss", errors.New("something else"), errZoneAlreadySet, false},
|
||||
{
|
||||
"dbus.Error body match",
|
||||
dbus.Error{Name: "org.fedoraproject.FirewallD1.Exception", Body: []any{"ZONE_ALREADY_SET: wt0"}},
|
||||
errZoneAlreadySet,
|
||||
true,
|
||||
},
|
||||
{
|
||||
"dbus.Error body miss",
|
||||
dbus.Error{Name: "org.fedoraproject.FirewallD1.Exception", Body: []any{"INVALID_INTERFACE"}},
|
||||
errAlreadyEnabled,
|
||||
false,
|
||||
},
|
||||
{
|
||||
"dbus.Error non-string body falls back to Error()",
|
||||
dbus.Error{Name: "x", Body: []any{123}},
|
||||
"x",
|
||||
true,
|
||||
},
|
||||
}
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
got := dbusErrContains(tc.err, tc.code)
|
||||
if got != tc.want {
|
||||
t.Fatalf("dbusErrContains(%v, %q) = %v; want %v", tc.err, tc.code, got, tc.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
25
client/firewall/firewalld/firewalld_other.go
Normal file
25
client/firewall/firewalld/firewalld_other.go
Normal file
@@ -0,0 +1,25 @@
|
||||
//go:build !linux
|
||||
|
||||
package firewalld
|
||||
|
||||
import "context"
|
||||
|
||||
// SetParentContext is a no-op on non-Linux platforms because firewalld only
|
||||
// runs on Linux.
|
||||
func SetParentContext(context.Context) {
|
||||
// intentionally empty: firewalld is a Linux-only daemon
|
||||
}
|
||||
|
||||
// TrustInterface is a no-op on non-Linux platforms because firewalld only
|
||||
// runs on Linux.
|
||||
func TrustInterface(string) error {
|
||||
// intentionally empty: firewalld is a Linux-only daemon
|
||||
return nil
|
||||
}
|
||||
|
||||
// UntrustInterface is a no-op on non-Linux platforms because firewalld only
|
||||
// runs on Linux.
|
||||
func UntrustInterface(string) error {
|
||||
// intentionally empty: firewalld is a Linux-only daemon
|
||||
return nil
|
||||
}
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
"github.com/netbirdio/netbird/client/firewall/firewalld"
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
@@ -86,6 +87,12 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error {
|
||||
log.Warnf("raw table not available, notrack rules will be disabled: %v", err)
|
||||
}
|
||||
|
||||
// Trust after all fatal init steps so a later failure doesn't leave the
|
||||
// interface in firewalld's trusted zone without a corresponding Close.
|
||||
if err := firewalld.TrustInterface(m.wgIface.Name()); err != nil {
|
||||
log.Warnf("failed to trust interface in firewalld: %v", err)
|
||||
}
|
||||
|
||||
// persist early to ensure cleanup of chains
|
||||
go func() {
|
||||
if err := stateManager.PersistState(context.Background()); err != nil {
|
||||
@@ -191,6 +198,12 @@ func (m *Manager) Close(stateManager *statemanager.Manager) error {
|
||||
merr = multierror.Append(merr, fmt.Errorf("reset router: %w", err))
|
||||
}
|
||||
|
||||
// Appending to merr intentionally blocks DeleteState below so ShutdownState
|
||||
// stays persisted and the crash-recovery path retries firewalld cleanup.
|
||||
if err := firewalld.UntrustInterface(m.wgIface.Name()); err != nil {
|
||||
merr = multierror.Append(merr, err)
|
||||
}
|
||||
|
||||
// attempt to delete state only if all other operations succeeded
|
||||
if merr == nil {
|
||||
if err := stateManager.DeleteState(&ShutdownState{}); err != nil {
|
||||
@@ -217,6 +230,11 @@ func (m *Manager) AllowNetbird() error {
|
||||
if err != nil {
|
||||
return fmt.Errorf("allow netbird interface traffic: %w", err)
|
||||
}
|
||||
|
||||
if err := firewalld.TrustInterface(m.wgIface.Name()); err != nil {
|
||||
log.Warnf("failed to trust interface in firewalld: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -14,6 +14,7 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/sys/unix"
|
||||
|
||||
"github.com/netbirdio/netbird/client/firewall/firewalld"
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
@@ -217,6 +218,10 @@ func (m *Manager) AllowNetbird() error {
|
||||
return fmt.Errorf("flush allow input netbird rules: %w", err)
|
||||
}
|
||||
|
||||
if err := firewalld.TrustInterface(m.wgIface.Name()); err != nil {
|
||||
log.Warnf("failed to trust interface in firewalld: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -19,6 +19,7 @@ import (
|
||||
"golang.org/x/sys/unix"
|
||||
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
"github.com/netbirdio/netbird/client/firewall/firewalld"
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
nbid "github.com/netbirdio/netbird/client/internal/acl/id"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/ipfwdstate"
|
||||
@@ -40,6 +41,8 @@ const (
|
||||
chainNameForward = "FORWARD"
|
||||
chainNameMangleForward = "netbird-mangle-forward"
|
||||
|
||||
firewalldTableName = "firewalld"
|
||||
|
||||
userDataAcceptForwardRuleIif = "frwacceptiif"
|
||||
userDataAcceptForwardRuleOif = "frwacceptoif"
|
||||
userDataAcceptInputRule = "inputaccept"
|
||||
@@ -133,6 +136,10 @@ func (r *router) Reset() error {
|
||||
merr = multierror.Append(merr, fmt.Errorf("remove accept filter rules: %w", err))
|
||||
}
|
||||
|
||||
if err := firewalld.UntrustInterface(r.wgIface.Name()); err != nil {
|
||||
merr = multierror.Append(merr, err)
|
||||
}
|
||||
|
||||
if err := r.removeNatPreroutingRules(); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("remove filter prerouting rules: %w", err))
|
||||
}
|
||||
@@ -280,6 +287,10 @@ func (r *router) createContainers() error {
|
||||
log.Errorf("failed to add accept rules for the forward chain: %s", err)
|
||||
}
|
||||
|
||||
if err := firewalld.TrustInterface(r.wgIface.Name()); err != nil {
|
||||
log.Warnf("failed to trust interface in firewalld: %v", err)
|
||||
}
|
||||
|
||||
if err := r.refreshRulesMap(); err != nil {
|
||||
log.Errorf("failed to refresh rules: %s", err)
|
||||
}
|
||||
@@ -1319,6 +1330,13 @@ func (r *router) isExternalChain(chain *nftables.Chain) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// Skip firewalld-owned chains. Firewalld creates its chains with the
|
||||
// NFT_CHAIN_OWNER flag, so inserting rules into them returns EPERM.
|
||||
// We delegate acceptance to firewalld by trusting the interface instead.
|
||||
if chain.Table.Name == firewalldTableName {
|
||||
return false
|
||||
}
|
||||
|
||||
// Skip all iptables-managed tables in the ip family
|
||||
if chain.Table.Family == nftables.TableFamilyIPv4 && isIptablesTable(chain.Table.Name) {
|
||||
return false
|
||||
|
||||
@@ -3,6 +3,9 @@
|
||||
package uspfilter
|
||||
|
||||
import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/firewall/firewalld"
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
)
|
||||
|
||||
@@ -16,6 +19,9 @@ func (m *Manager) Close(stateManager *statemanager.Manager) error {
|
||||
if m.nativeFirewall != nil {
|
||||
return m.nativeFirewall.Close(stateManager)
|
||||
}
|
||||
if err := firewalld.UntrustInterface(m.wgIface.Name()); err != nil {
|
||||
log.Warnf("failed to untrust interface in firewalld: %v", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -24,5 +30,8 @@ func (m *Manager) AllowNetbird() error {
|
||||
if m.nativeFirewall != nil {
|
||||
return m.nativeFirewall.AllowNetbird()
|
||||
}
|
||||
if err := firewalld.TrustInterface(m.wgIface.Name()); err != nil {
|
||||
log.Warnf("failed to trust interface in firewalld: %v", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
|
||||
// IFaceMapper defines subset methods of interface required for manager
|
||||
type IFaceMapper interface {
|
||||
Name() string
|
||||
SetFilter(device.PacketFilter) error
|
||||
Address() wgaddr.Address
|
||||
GetWGDevice() *wgdevice.Device
|
||||
|
||||
@@ -31,12 +31,20 @@ var logger = log.NewFromLogrus(logrus.StandardLogger())
|
||||
var flowLogger = netflow.NewManager(nil, []byte{}, nil).GetLogger()
|
||||
|
||||
type IFaceMock struct {
|
||||
NameFunc func() string
|
||||
SetFilterFunc func(device.PacketFilter) error
|
||||
AddressFunc func() wgaddr.Address
|
||||
GetWGDeviceFunc func() *wgdevice.Device
|
||||
GetDeviceFunc func() *device.FilteredDevice
|
||||
}
|
||||
|
||||
func (i *IFaceMock) Name() string {
|
||||
if i.NameFunc == nil {
|
||||
return "wgtest"
|
||||
}
|
||||
return i.NameFunc()
|
||||
}
|
||||
|
||||
func (i *IFaceMock) GetWGDevice() *wgdevice.Device {
|
||||
if i.GetWGDeviceFunc == nil {
|
||||
return nil
|
||||
|
||||
@@ -239,8 +239,12 @@ func TestICEBind_HandlesConcurrentMixedTraffic(t *testing.T) {
|
||||
ipv6Count++
|
||||
}
|
||||
|
||||
assert.Equal(t, packetsPerFamily, ipv4Count)
|
||||
assert.Equal(t, packetsPerFamily, ipv6Count)
|
||||
// Allow some UDP packet loss under load (e.g. FreeBSD/QEMU runners). The
|
||||
// routing-correctness checks above are the real assertions; the counts
|
||||
// are a sanity bound to catch a totally silent path.
|
||||
minDelivered := packetsPerFamily * 80 / 100
|
||||
assert.GreaterOrEqual(t, ipv4Count, minDelivered, "IPv4 delivery below threshold")
|
||||
assert.GreaterOrEqual(t, ipv6Count, minDelivered, "IPv6 delivery below threshold")
|
||||
}
|
||||
|
||||
func TestICEBind_DetectsAddressFamilyFromConnection(t *testing.T) {
|
||||
|
||||
@@ -315,7 +315,6 @@ func (a *Auth) setSystemInfoFlags(info *system.Info) {
|
||||
a.config.RosenpassEnabled,
|
||||
a.config.RosenpassPermissive,
|
||||
a.config.ServerSSHAllowed,
|
||||
a.config.ServerVNCAllowed,
|
||||
a.config.DisableClientRoutes,
|
||||
a.config.DisableServerRoutes,
|
||||
a.config.DisableDNS,
|
||||
@@ -328,7 +327,6 @@ func (a *Auth) setSystemInfoFlags(info *system.Info) {
|
||||
a.config.EnableSSHLocalPortForwarding,
|
||||
a.config.EnableSSHRemotePortForwarding,
|
||||
a.config.DisableSSHAuth,
|
||||
a.config.DisableVNCAuth,
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
@@ -546,13 +546,11 @@ func createEngineConfig(key wgtypes.Key, config *profilemanager.Config, peerConf
|
||||
RosenpassEnabled: config.RosenpassEnabled,
|
||||
RosenpassPermissive: config.RosenpassPermissive,
|
||||
ServerSSHAllowed: util.ReturnBoolWithDefaultTrue(config.ServerSSHAllowed),
|
||||
ServerVNCAllowed: config.ServerVNCAllowed != nil && *config.ServerVNCAllowed,
|
||||
EnableSSHRoot: config.EnableSSHRoot,
|
||||
EnableSSHSFTP: config.EnableSSHSFTP,
|
||||
EnableSSHLocalPortForwarding: config.EnableSSHLocalPortForwarding,
|
||||
EnableSSHRemotePortForwarding: config.EnableSSHRemotePortForwarding,
|
||||
DisableSSHAuth: config.DisableSSHAuth,
|
||||
DisableVNCAuth: config.DisableVNCAuth,
|
||||
DNSRouteInterval: config.DNSRouteInterval,
|
||||
|
||||
DisableClientRoutes: config.DisableClientRoutes,
|
||||
@@ -629,7 +627,6 @@ func loginToManagement(ctx context.Context, client mgm.Client, pubSSHKey []byte,
|
||||
config.RosenpassEnabled,
|
||||
config.RosenpassPermissive,
|
||||
config.ServerSSHAllowed,
|
||||
config.ServerVNCAllowed,
|
||||
config.DisableClientRoutes,
|
||||
config.DisableServerRoutes,
|
||||
config.DisableDNS,
|
||||
@@ -642,7 +639,6 @@ func loginToManagement(ctx context.Context, client mgm.Client, pubSSHKey []byte,
|
||||
config.EnableSSHLocalPortForwarding,
|
||||
config.EnableSSHRemotePortForwarding,
|
||||
config.DisableSSHAuth,
|
||||
config.DisableVNCAuth,
|
||||
)
|
||||
return client.Login(sysInfo, pubSSHKey, config.DNSLabels)
|
||||
}
|
||||
|
||||
@@ -3,10 +3,12 @@ package debug
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
@@ -19,8 +21,10 @@ func TestUpload(t *testing.T) {
|
||||
t.Skip("Skipping upload test on docker ci")
|
||||
}
|
||||
testDir := t.TempDir()
|
||||
testURL := "http://localhost:8080"
|
||||
addr := reserveLoopbackPort(t)
|
||||
testURL := "http://" + addr
|
||||
t.Setenv("SERVER_URL", testURL)
|
||||
t.Setenv("SERVER_ADDRESS", addr)
|
||||
t.Setenv("STORE_DIR", testDir)
|
||||
srv := server.NewServer()
|
||||
go func() {
|
||||
@@ -33,6 +37,7 @@ func TestUpload(t *testing.T) {
|
||||
t.Errorf("Failed to stop server: %v", err)
|
||||
}
|
||||
})
|
||||
waitForServer(t, addr)
|
||||
|
||||
file := filepath.Join(t.TempDir(), "tmpfile")
|
||||
fileContent := []byte("test file content")
|
||||
@@ -47,3 +52,30 @@ func TestUpload(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, fileContent, createdFileContent)
|
||||
}
|
||||
|
||||
// reserveLoopbackPort binds an ephemeral port on loopback to learn a free
|
||||
// address, then releases it so the server under test can rebind. The close/
|
||||
// rebind window is racy in theory; on loopback with a kernel-assigned port
|
||||
// it's essentially never contended in practice.
|
||||
func reserveLoopbackPort(t *testing.T) string {
|
||||
t.Helper()
|
||||
l, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
addr := l.Addr().String()
|
||||
require.NoError(t, l.Close())
|
||||
return addr
|
||||
}
|
||||
|
||||
func waitForServer(t *testing.T, addr string) {
|
||||
t.Helper()
|
||||
deadline := time.Now().Add(5 * time.Second)
|
||||
for time.Now().Before(deadline) {
|
||||
c, err := net.DialTimeout("tcp", addr, 100*time.Millisecond)
|
||||
if err == nil {
|
||||
_ = c.Close()
|
||||
return
|
||||
}
|
||||
time.Sleep(20 * time.Millisecond)
|
||||
}
|
||||
t.Fatalf("server did not start listening on %s in time", addr)
|
||||
}
|
||||
|
||||
@@ -13,6 +13,7 @@ import (
|
||||
|
||||
const (
|
||||
defaultResolvConfPath = "/etc/resolv.conf"
|
||||
nsswitchConfPath = "/etc/nsswitch.conf"
|
||||
)
|
||||
|
||||
type resolvConf struct {
|
||||
|
||||
@@ -1,7 +1,10 @@
|
||||
package dns
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"math"
|
||||
"net"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
@@ -192,6 +195,12 @@ func (c *HandlerChain) logHandlers() {
|
||||
}
|
||||
|
||||
func (c *HandlerChain) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||
c.dispatch(w, r, math.MaxInt)
|
||||
}
|
||||
|
||||
// dispatch routes a DNS request through the chain, skipping handlers with
|
||||
// priority > maxPriority. Shared by ServeDNS and ResolveInternal.
|
||||
func (c *HandlerChain) dispatch(w dns.ResponseWriter, r *dns.Msg, maxPriority int) {
|
||||
if len(r.Question) == 0 {
|
||||
return
|
||||
}
|
||||
@@ -216,6 +225,9 @@ func (c *HandlerChain) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||
|
||||
// Try handlers in priority order
|
||||
for _, entry := range handlers {
|
||||
if entry.Priority > maxPriority {
|
||||
continue
|
||||
}
|
||||
if !c.isHandlerMatch(qname, entry) {
|
||||
continue
|
||||
}
|
||||
@@ -273,6 +285,55 @@ func (c *HandlerChain) logResponse(logger *log.Entry, cw *ResponseWriterChain, q
|
||||
cw.response.Len(), meta, time.Since(startTime))
|
||||
}
|
||||
|
||||
// ResolveInternal runs an in-process DNS query against the chain, skipping any
|
||||
// handler with priority > maxPriority. Used by internal callers (e.g. the mgmt
|
||||
// cache refresher) that must bypass themselves to avoid loops. Honors ctx
|
||||
// cancellation; on ctx.Done the dispatch goroutine is left to drain on its own
|
||||
// (bounded by the invoked handler's internal timeout).
|
||||
func (c *HandlerChain) ResolveInternal(ctx context.Context, r *dns.Msg, maxPriority int) (*dns.Msg, error) {
|
||||
if len(r.Question) == 0 {
|
||||
return nil, fmt.Errorf("empty question")
|
||||
}
|
||||
|
||||
base := &internalResponseWriter{}
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
c.dispatch(base, r, maxPriority)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-ctx.Done():
|
||||
// Prefer a completed response if dispatch finished concurrently with cancellation.
|
||||
select {
|
||||
case <-done:
|
||||
default:
|
||||
return nil, fmt.Errorf("resolve %s: %w", strings.ToLower(r.Question[0].Name), ctx.Err())
|
||||
}
|
||||
}
|
||||
|
||||
if base.response == nil || base.response.Rcode == dns.RcodeRefused {
|
||||
return nil, fmt.Errorf("no handler resolved %s at priority ≤ %d",
|
||||
strings.ToLower(r.Question[0].Name), maxPriority)
|
||||
}
|
||||
return base.response, nil
|
||||
}
|
||||
|
||||
// HasRootHandlerAtOrBelow reports whether any "." handler is registered at
|
||||
// priority ≤ maxPriority.
|
||||
func (c *HandlerChain) HasRootHandlerAtOrBelow(maxPriority int) bool {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
|
||||
for _, h := range c.handlers {
|
||||
if h.Pattern == "." && h.Priority <= maxPriority {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (c *HandlerChain) isHandlerMatch(qname string, entry HandlerEntry) bool {
|
||||
switch {
|
||||
case entry.Pattern == ".":
|
||||
@@ -291,3 +352,36 @@ func (c *HandlerChain) isHandlerMatch(qname string, entry HandlerEntry) bool {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// internalResponseWriter captures a dns.Msg for in-process chain queries.
|
||||
type internalResponseWriter struct {
|
||||
response *dns.Msg
|
||||
}
|
||||
|
||||
func (w *internalResponseWriter) WriteMsg(m *dns.Msg) error { w.response = m; return nil }
|
||||
func (w *internalResponseWriter) LocalAddr() net.Addr { return nil }
|
||||
func (w *internalResponseWriter) RemoteAddr() net.Addr { return nil }
|
||||
|
||||
// Write unpacks raw DNS bytes so handlers that call Write instead of WriteMsg
|
||||
// still surface their answer to ResolveInternal.
|
||||
func (w *internalResponseWriter) Write(p []byte) (int, error) {
|
||||
msg := new(dns.Msg)
|
||||
if err := msg.Unpack(p); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
w.response = msg
|
||||
return len(p), nil
|
||||
}
|
||||
|
||||
func (w *internalResponseWriter) Close() error { return nil }
|
||||
func (w *internalResponseWriter) TsigStatus() error { return nil }
|
||||
|
||||
// TsigTimersOnly is part of dns.ResponseWriter.
|
||||
func (w *internalResponseWriter) TsigTimersOnly(bool) {
|
||||
// no-op: in-process queries carry no TSIG state.
|
||||
}
|
||||
|
||||
// Hijack is part of dns.ResponseWriter.
|
||||
func (w *internalResponseWriter) Hijack() {
|
||||
// no-op: in-process queries have no underlying connection to hand off.
|
||||
}
|
||||
|
||||
@@ -1,11 +1,15 @@
|
||||
package dns_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/client/internal/dns"
|
||||
"github.com/netbirdio/netbird/client/internal/dns/test"
|
||||
@@ -1042,3 +1046,163 @@ func TestHandlerChain_AddRemoveRoundtrip(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// answeringHandler writes a fixed A record to ack the query. Used to verify
|
||||
// which handler ResolveInternal dispatches to.
|
||||
type answeringHandler struct {
|
||||
name string
|
||||
ip string
|
||||
}
|
||||
|
||||
func (h *answeringHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||
resp := &dns.Msg{}
|
||||
resp.SetReply(r)
|
||||
resp.Answer = []dns.RR{&dns.A{
|
||||
Hdr: dns.RR_Header{Name: r.Question[0].Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
|
||||
A: net.ParseIP(h.ip).To4(),
|
||||
}}
|
||||
_ = w.WriteMsg(resp)
|
||||
}
|
||||
|
||||
func (h *answeringHandler) String() string { return h.name }
|
||||
|
||||
func TestHandlerChain_ResolveInternal_SkipsAboveMaxPriority(t *testing.T) {
|
||||
chain := nbdns.NewHandlerChain()
|
||||
|
||||
high := &answeringHandler{name: "high", ip: "10.0.0.1"}
|
||||
low := &answeringHandler{name: "low", ip: "10.0.0.2"}
|
||||
|
||||
chain.AddHandler("example.com.", high, nbdns.PriorityMgmtCache)
|
||||
chain.AddHandler("example.com.", low, nbdns.PriorityUpstream)
|
||||
|
||||
r := new(dns.Msg)
|
||||
r.SetQuestion("example.com.", dns.TypeA)
|
||||
|
||||
resp, err := chain.ResolveInternal(context.Background(), r, nbdns.PriorityUpstream)
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, resp)
|
||||
assert.Equal(t, 1, len(resp.Answer))
|
||||
a, ok := resp.Answer[0].(*dns.A)
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, "10.0.0.2", a.A.String(), "should skip mgmtCache handler and resolve via upstream")
|
||||
}
|
||||
|
||||
func TestHandlerChain_ResolveInternal_ErrorWhenNoMatch(t *testing.T) {
|
||||
chain := nbdns.NewHandlerChain()
|
||||
high := &answeringHandler{name: "high", ip: "10.0.0.1"}
|
||||
chain.AddHandler("example.com.", high, nbdns.PriorityMgmtCache)
|
||||
|
||||
r := new(dns.Msg)
|
||||
r.SetQuestion("example.com.", dns.TypeA)
|
||||
|
||||
_, err := chain.ResolveInternal(context.Background(), r, nbdns.PriorityUpstream)
|
||||
assert.Error(t, err, "no handler at or below maxPriority should error")
|
||||
}
|
||||
|
||||
// rawWriteHandler packs a response and calls ResponseWriter.Write directly
|
||||
// (instead of WriteMsg), exercising the internalResponseWriter.Write path.
|
||||
type rawWriteHandler struct {
|
||||
ip string
|
||||
}
|
||||
|
||||
func (h *rawWriteHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||
resp := &dns.Msg{}
|
||||
resp.SetReply(r)
|
||||
resp.Answer = []dns.RR{&dns.A{
|
||||
Hdr: dns.RR_Header{Name: r.Question[0].Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
|
||||
A: net.ParseIP(h.ip).To4(),
|
||||
}}
|
||||
packed, err := resp.Pack()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
_, _ = w.Write(packed)
|
||||
}
|
||||
|
||||
func TestHandlerChain_ResolveInternal_CapturesRawWrite(t *testing.T) {
|
||||
chain := nbdns.NewHandlerChain()
|
||||
chain.AddHandler("example.com.", &rawWriteHandler{ip: "10.0.0.3"}, nbdns.PriorityUpstream)
|
||||
|
||||
r := new(dns.Msg)
|
||||
r.SetQuestion("example.com.", dns.TypeA)
|
||||
|
||||
resp, err := chain.ResolveInternal(context.Background(), r, nbdns.PriorityUpstream)
|
||||
assert.NoError(t, err)
|
||||
require.NotNil(t, resp)
|
||||
require.Len(t, resp.Answer, 1)
|
||||
a, ok := resp.Answer[0].(*dns.A)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, "10.0.0.3", a.A.String(), "handlers calling Write(packed) must still surface their answer")
|
||||
}
|
||||
|
||||
func TestHandlerChain_ResolveInternal_EmptyQuestion(t *testing.T) {
|
||||
chain := nbdns.NewHandlerChain()
|
||||
_, err := chain.ResolveInternal(context.Background(), new(dns.Msg), nbdns.PriorityUpstream)
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
// hangingHandler blocks indefinitely until closed, simulating a wedged upstream.
|
||||
type hangingHandler struct {
|
||||
block chan struct{}
|
||||
}
|
||||
|
||||
func (h *hangingHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||
<-h.block
|
||||
resp := &dns.Msg{}
|
||||
resp.SetReply(r)
|
||||
_ = w.WriteMsg(resp)
|
||||
}
|
||||
|
||||
func (h *hangingHandler) String() string { return "hangingHandler" }
|
||||
|
||||
func TestHandlerChain_ResolveInternal_HonorsContextTimeout(t *testing.T) {
|
||||
chain := nbdns.NewHandlerChain()
|
||||
h := &hangingHandler{block: make(chan struct{})}
|
||||
defer close(h.block)
|
||||
|
||||
chain.AddHandler("example.com.", h, nbdns.PriorityUpstream)
|
||||
|
||||
r := new(dns.Msg)
|
||||
r.SetQuestion("example.com.", dns.TypeA)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
start := time.Now()
|
||||
_, err := chain.ResolveInternal(ctx, r, nbdns.PriorityUpstream)
|
||||
elapsed := time.Since(start)
|
||||
|
||||
assert.Error(t, err)
|
||||
assert.ErrorIs(t, err, context.DeadlineExceeded)
|
||||
assert.Less(t, elapsed, 500*time.Millisecond, "ResolveInternal must return shortly after ctx deadline")
|
||||
}
|
||||
|
||||
func TestHandlerChain_HasRootHandlerAtOrBelow(t *testing.T) {
|
||||
chain := nbdns.NewHandlerChain()
|
||||
h := &answeringHandler{name: "h", ip: "10.0.0.1"}
|
||||
|
||||
assert.False(t, chain.HasRootHandlerAtOrBelow(nbdns.PriorityUpstream), "empty chain")
|
||||
|
||||
chain.AddHandler("example.com.", h, nbdns.PriorityUpstream)
|
||||
assert.False(t, chain.HasRootHandlerAtOrBelow(nbdns.PriorityUpstream), "non-root handler does not count")
|
||||
|
||||
chain.AddHandler(".", h, nbdns.PriorityMgmtCache)
|
||||
assert.False(t, chain.HasRootHandlerAtOrBelow(nbdns.PriorityUpstream), "root handler above threshold excluded")
|
||||
|
||||
chain.AddHandler(".", h, nbdns.PriorityDefault)
|
||||
assert.True(t, chain.HasRootHandlerAtOrBelow(nbdns.PriorityUpstream), "root handler at PriorityDefault included")
|
||||
|
||||
chain.RemoveHandler(".", nbdns.PriorityDefault)
|
||||
assert.False(t, chain.HasRootHandlerAtOrBelow(nbdns.PriorityUpstream))
|
||||
|
||||
// Primary nsgroup case: root handler lands at PriorityUpstream.
|
||||
chain.AddHandler(".", h, nbdns.PriorityUpstream)
|
||||
assert.True(t, chain.HasRootHandlerAtOrBelow(nbdns.PriorityUpstream), "root at PriorityUpstream included")
|
||||
chain.RemoveHandler(".", nbdns.PriorityUpstream)
|
||||
|
||||
// Fallback case: original /etc/resolv.conf entries land at PriorityFallback.
|
||||
chain.AddHandler(".", h, nbdns.PriorityFallback)
|
||||
assert.True(t, chain.HasRootHandlerAtOrBelow(nbdns.PriorityUpstream), "root at PriorityFallback included")
|
||||
chain.RemoveHandler(".", nbdns.PriorityFallback)
|
||||
assert.False(t, chain.HasRootHandlerAtOrBelow(nbdns.PriorityUpstream))
|
||||
}
|
||||
|
||||
@@ -46,12 +46,12 @@ type restoreHostManager interface {
|
||||
}
|
||||
|
||||
func newHostManager(wgInterface string) (hostManager, error) {
|
||||
osManager, err := getOSDNSManagerType()
|
||||
osManager, reason, err := getOSDNSManagerType()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get os dns manager type: %w", err)
|
||||
}
|
||||
|
||||
log.Infof("System DNS manager discovered: %s", osManager)
|
||||
log.Infof("System DNS manager discovered: %s (%s)", osManager, reason)
|
||||
mgr, err := newHostManagerFromType(wgInterface, osManager)
|
||||
// need to explicitly return nil mgr on error to avoid returning a non-nil interface containing a nil value
|
||||
if err != nil {
|
||||
@@ -74,17 +74,49 @@ func newHostManagerFromType(wgInterface string, osManager osManagerType) (restor
|
||||
}
|
||||
}
|
||||
|
||||
func getOSDNSManagerType() (osManagerType, error) {
|
||||
func getOSDNSManagerType() (osManagerType, string, error) {
|
||||
resolved := isSystemdResolvedRunning()
|
||||
nss := isLibnssResolveUsed()
|
||||
stub := checkStub()
|
||||
|
||||
// Prefer systemd-resolved whenever it owns libc resolution, regardless of
|
||||
// who wrote /etc/resolv.conf. File-mode rewrites do not affect lookups
|
||||
// that go through nss-resolve, and in foreign mode they can loop back
|
||||
// through resolved as an upstream.
|
||||
if resolved && (nss || stub) {
|
||||
return systemdManager, fmt.Sprintf("systemd-resolved active (nss-resolve=%t, stub=%t)", nss, stub), nil
|
||||
}
|
||||
|
||||
mgr, reason, rejected, err := scanResolvConfHeader()
|
||||
if err != nil {
|
||||
return 0, "", err
|
||||
}
|
||||
if reason != "" {
|
||||
return mgr, reason, nil
|
||||
}
|
||||
|
||||
fallback := fmt.Sprintf("no manager matched (resolved=%t, nss-resolve=%t, stub=%t)", resolved, nss, stub)
|
||||
if len(rejected) > 0 {
|
||||
fallback += "; rejected: " + strings.Join(rejected, ", ")
|
||||
}
|
||||
return fileManager, fallback, nil
|
||||
}
|
||||
|
||||
// scanResolvConfHeader walks /etc/resolv.conf header comments and returns the
|
||||
// matching manager. If reason is empty the caller should pick file mode and
|
||||
// use rejected for diagnostics.
|
||||
func scanResolvConfHeader() (osManagerType, string, []string, error) {
|
||||
file, err := os.Open(defaultResolvConfPath)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("unable to open %s for checking owner, got error: %w", defaultResolvConfPath, err)
|
||||
return 0, "", nil, fmt.Errorf("unable to open %s for checking owner, got error: %w", defaultResolvConfPath, err)
|
||||
}
|
||||
defer func() {
|
||||
if err := file.Close(); err != nil {
|
||||
log.Errorf("close file %s: %s", defaultResolvConfPath, err)
|
||||
if cerr := file.Close(); cerr != nil {
|
||||
log.Errorf("close file %s: %s", defaultResolvConfPath, cerr)
|
||||
}
|
||||
}()
|
||||
|
||||
var rejected []string
|
||||
scanner := bufio.NewScanner(file)
|
||||
for scanner.Scan() {
|
||||
text := scanner.Text()
|
||||
@@ -92,41 +124,48 @@ func getOSDNSManagerType() (osManagerType, error) {
|
||||
continue
|
||||
}
|
||||
if text[0] != '#' {
|
||||
return fileManager, nil
|
||||
break
|
||||
}
|
||||
if strings.Contains(text, fileGeneratedResolvConfContentHeader) {
|
||||
return netbirdManager, nil
|
||||
}
|
||||
if strings.Contains(text, "NetworkManager") && isDbusListenerRunning(networkManagerDest, networkManagerDbusObjectNode) && isNetworkManagerSupported() {
|
||||
return networkManager, nil
|
||||
}
|
||||
if strings.Contains(text, "systemd-resolved") && isSystemdResolvedRunning() {
|
||||
if checkStub() {
|
||||
return systemdManager, nil
|
||||
} else {
|
||||
return fileManager, nil
|
||||
}
|
||||
}
|
||||
if strings.Contains(text, "resolvconf") {
|
||||
if isSystemdResolveConfMode() {
|
||||
return systemdManager, nil
|
||||
}
|
||||
|
||||
return resolvConfManager, nil
|
||||
if mgr, reason, rej := matchResolvConfHeader(text); reason != "" {
|
||||
return mgr, reason, nil, nil
|
||||
} else if rej != "" {
|
||||
rejected = append(rejected, rej)
|
||||
}
|
||||
}
|
||||
if err := scanner.Err(); err != nil && err != io.EOF {
|
||||
return 0, fmt.Errorf("scan: %w", err)
|
||||
return 0, "", nil, fmt.Errorf("scan: %w", err)
|
||||
}
|
||||
|
||||
return fileManager, nil
|
||||
return 0, "", rejected, nil
|
||||
}
|
||||
|
||||
// checkStub checks if the stub resolver is disabled in systemd-resolved. If it is disabled, we fall back to file manager.
|
||||
// matchResolvConfHeader inspects a single comment line. Returns either a
|
||||
// definitive (manager, reason) or a non-empty rejected diagnostic.
|
||||
func matchResolvConfHeader(text string) (osManagerType, string, string) {
|
||||
if strings.Contains(text, fileGeneratedResolvConfContentHeader) {
|
||||
return netbirdManager, "netbird-managed resolv.conf header detected", ""
|
||||
}
|
||||
if strings.Contains(text, "NetworkManager") {
|
||||
if isDbusListenerRunning(networkManagerDest, networkManagerDbusObjectNode) && isNetworkManagerSupported() {
|
||||
return networkManager, "NetworkManager header + supported version on dbus", ""
|
||||
}
|
||||
return 0, "", "NetworkManager header (no dbus or unsupported version)"
|
||||
}
|
||||
if strings.Contains(text, "resolvconf") {
|
||||
if isSystemdResolveConfMode() {
|
||||
return systemdManager, "resolvconf header in systemd-resolved compatibility mode", ""
|
||||
}
|
||||
return resolvConfManager, "resolvconf header detected", ""
|
||||
}
|
||||
return 0, "", ""
|
||||
}
|
||||
|
||||
// checkStub reports whether systemd-resolved's stub (127.0.0.53) is listed
|
||||
// in /etc/resolv.conf. On parse failure we assume it is, to avoid dropping
|
||||
// into file mode while resolved is active.
|
||||
func checkStub() bool {
|
||||
rConf, err := parseDefaultResolvConf()
|
||||
if err != nil {
|
||||
log.Warnf("failed to parse resolv conf: %s", err)
|
||||
log.Warnf("failed to parse resolv conf, assuming stub is active: %s", err)
|
||||
return true
|
||||
}
|
||||
|
||||
@@ -139,3 +178,36 @@ func checkStub() bool {
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// isLibnssResolveUsed reports whether nss-resolve is listed before dns on
|
||||
// the hosts: line of /etc/nsswitch.conf. When it is, libc lookups are
|
||||
// delegated to systemd-resolved regardless of /etc/resolv.conf.
|
||||
func isLibnssResolveUsed() bool {
|
||||
bs, err := os.ReadFile(nsswitchConfPath)
|
||||
if err != nil {
|
||||
log.Debugf("read %s: %v", nsswitchConfPath, err)
|
||||
return false
|
||||
}
|
||||
return parseNsswitchResolveAhead(bs)
|
||||
}
|
||||
|
||||
func parseNsswitchResolveAhead(data []byte) bool {
|
||||
for _, line := range strings.Split(string(data), "\n") {
|
||||
if i := strings.IndexByte(line, '#'); i >= 0 {
|
||||
line = line[:i]
|
||||
}
|
||||
fields := strings.Fields(line)
|
||||
if len(fields) < 2 || fields[0] != "hosts:" {
|
||||
continue
|
||||
}
|
||||
for _, module := range fields[1:] {
|
||||
switch module {
|
||||
case "dns":
|
||||
return false
|
||||
case "resolve":
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
76
client/internal/dns/host_unix_test.go
Normal file
76
client/internal/dns/host_unix_test.go
Normal file
@@ -0,0 +1,76 @@
|
||||
//go:build (linux && !android) || freebsd
|
||||
|
||||
package dns
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestParseNsswitchResolveAhead(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
in string
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
name: "resolve before dns with action token",
|
||||
in: "hosts: mymachines resolve [!UNAVAIL=return] files myhostname dns\n",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "dns before resolve",
|
||||
in: "hosts: files mdns4_minimal [NOTFOUND=return] dns resolve\n",
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "debian default with only dns",
|
||||
in: "hosts: files mdns4_minimal [NOTFOUND=return] dns mymachines\n",
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "neither resolve nor dns",
|
||||
in: "hosts: files myhostname\n",
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "no hosts line",
|
||||
in: "passwd: files systemd\ngroup: files systemd\n",
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "empty",
|
||||
in: "",
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "comments and blank lines ignored",
|
||||
in: "# comment\n\n# another\nhosts: resolve dns\n",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "trailing inline comment",
|
||||
in: "hosts: resolve [!UNAVAIL=return] dns # fallback\n",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "hosts token must be the first field",
|
||||
in: " hosts: resolve dns\n",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "other db line mentioning resolve is ignored",
|
||||
in: "networks: resolve\nhosts: dns\n",
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "only resolve, no dns",
|
||||
in: "hosts: files resolve\n",
|
||||
want: true,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := parseNsswitchResolveAhead([]byte(tt.in)); got != tt.want {
|
||||
t.Errorf("parseNsswitchResolveAhead() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -2,40 +2,83 @@ package mgmt
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"net/url"
|
||||
"os"
|
||||
"slices"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/sync/singleflight"
|
||||
|
||||
dnsconfig "github.com/netbirdio/netbird/client/internal/dns/config"
|
||||
"github.com/netbirdio/netbird/client/internal/dns/resutil"
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
)
|
||||
|
||||
const dnsTimeout = 5 * time.Second
|
||||
const (
|
||||
dnsTimeout = 5 * time.Second
|
||||
defaultTTL = 300 * time.Second
|
||||
refreshBackoff = 30 * time.Second
|
||||
|
||||
// Resolver caches critical NetBird infrastructure domains
|
||||
// envMgmtCacheTTL overrides defaultTTL for integration/dev testing.
|
||||
envMgmtCacheTTL = "NB_MGMT_CACHE_TTL"
|
||||
)
|
||||
|
||||
// ChainResolver lets the cache refresh stale entries through the DNS handler
|
||||
// chain instead of net.DefaultResolver, avoiding loopback when NetBird is the
|
||||
// system resolver.
|
||||
type ChainResolver interface {
|
||||
ResolveInternal(ctx context.Context, msg *dns.Msg, maxPriority int) (*dns.Msg, error)
|
||||
HasRootHandlerAtOrBelow(maxPriority int) bool
|
||||
}
|
||||
|
||||
// cachedRecord holds DNS records plus timestamps used for TTL refresh.
|
||||
// records and cachedAt are set at construction and treated as immutable;
|
||||
// lastFailedRefresh and consecFailures are mutable and must be accessed under
|
||||
// Resolver.mutex.
|
||||
type cachedRecord struct {
|
||||
records []dns.RR
|
||||
cachedAt time.Time
|
||||
lastFailedRefresh time.Time
|
||||
consecFailures int
|
||||
}
|
||||
|
||||
// Resolver caches critical NetBird infrastructure domains.
|
||||
// records, refreshing, mgmtDomain and serverDomains are all guarded by mutex.
|
||||
type Resolver struct {
|
||||
records map[dns.Question][]dns.RR
|
||||
records map[dns.Question]*cachedRecord
|
||||
mgmtDomain *domain.Domain
|
||||
serverDomains *dnsconfig.ServerDomains
|
||||
mutex sync.RWMutex
|
||||
}
|
||||
|
||||
type ipsResponse struct {
|
||||
ips []netip.Addr
|
||||
err error
|
||||
chain ChainResolver
|
||||
chainMaxPriority int
|
||||
refreshGroup singleflight.Group
|
||||
|
||||
// refreshing tracks questions whose refresh is running via the OS
|
||||
// fallback path. A ServeDNS hit for a question in this map indicates
|
||||
// the OS resolver routed the recursive query back to us (loop). Only
|
||||
// the OS path arms this so chain-path refreshes don't produce false
|
||||
// positives. The atomic bool is CAS-flipped once per refresh to
|
||||
// throttle the warning log.
|
||||
refreshing map[dns.Question]*atomic.Bool
|
||||
|
||||
cacheTTL time.Duration
|
||||
}
|
||||
|
||||
// NewResolver creates a new management domains cache resolver.
|
||||
func NewResolver() *Resolver {
|
||||
return &Resolver{
|
||||
records: make(map[dns.Question][]dns.RR),
|
||||
records: make(map[dns.Question]*cachedRecord),
|
||||
refreshing: make(map[dns.Question]*atomic.Bool),
|
||||
cacheTTL: resolveCacheTTL(),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -44,7 +87,19 @@ func (m *Resolver) String() string {
|
||||
return "MgmtCacheResolver"
|
||||
}
|
||||
|
||||
// ServeDNS implements dns.Handler interface.
|
||||
// SetChainResolver wires the handler chain used to refresh stale cache entries.
|
||||
// maxPriority caps which handlers may answer refresh queries (typically
|
||||
// PriorityUpstream, so upstream/default/fallback handlers are consulted and
|
||||
// mgmt/route/local handlers are skipped).
|
||||
func (m *Resolver) SetChainResolver(chain ChainResolver, maxPriority int) {
|
||||
m.mutex.Lock()
|
||||
m.chain = chain
|
||||
m.chainMaxPriority = maxPriority
|
||||
m.mutex.Unlock()
|
||||
}
|
||||
|
||||
// ServeDNS serves cached A/AAAA records. Stale entries are returned
|
||||
// immediately and refreshed asynchronously (stale-while-revalidate).
|
||||
func (m *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||
if len(r.Question) == 0 {
|
||||
m.continueToNext(w, r)
|
||||
@@ -60,7 +115,14 @@ func (m *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||
}
|
||||
|
||||
m.mutex.RLock()
|
||||
records, found := m.records[question]
|
||||
cached, found := m.records[question]
|
||||
inflight := m.refreshing[question]
|
||||
var shouldRefresh bool
|
||||
if found {
|
||||
stale := time.Since(cached.cachedAt) > m.cacheTTL
|
||||
inBackoff := !cached.lastFailedRefresh.IsZero() && time.Since(cached.lastFailedRefresh) < refreshBackoff
|
||||
shouldRefresh = stale && !inBackoff
|
||||
}
|
||||
m.mutex.RUnlock()
|
||||
|
||||
if !found {
|
||||
@@ -68,12 +130,23 @@ func (m *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||
return
|
||||
}
|
||||
|
||||
if inflight != nil && inflight.CompareAndSwap(false, true) {
|
||||
log.Warnf("mgmt cache: possible resolver loop for domain=%s: served stale while an OS-fallback refresh was inflight (if NetBird is the system resolver, the OS-path predicate is wrong)",
|
||||
question.Name)
|
||||
}
|
||||
|
||||
// Skip scheduling a refresh goroutine if one is already inflight for
|
||||
// this question; singleflight would dedup anyway but skipping avoids
|
||||
// a parked goroutine per stale hit under bursty load.
|
||||
if shouldRefresh && inflight == nil {
|
||||
m.scheduleRefresh(question, cached)
|
||||
}
|
||||
|
||||
resp := &dns.Msg{}
|
||||
resp.SetReply(r)
|
||||
resp.Authoritative = false
|
||||
resp.RecursionAvailable = true
|
||||
|
||||
resp.Answer = append(resp.Answer, records...)
|
||||
resp.Answer = cloneRecordsWithTTL(cached.records, m.responseTTL(cached.cachedAt))
|
||||
|
||||
log.Debugf("serving %d cached records for domain=%s", len(resp.Answer), question.Name)
|
||||
|
||||
@@ -98,101 +171,260 @@ func (m *Resolver) continueToNext(w dns.ResponseWriter, r *dns.Msg) {
|
||||
}
|
||||
}
|
||||
|
||||
// AddDomain manually adds a domain to cache by resolving it.
|
||||
// AddDomain resolves a domain and stores its A/AAAA records in the cache.
|
||||
// A family that resolves NODATA (nil err, zero records) evicts any stale
|
||||
// entry for that qtype.
|
||||
func (m *Resolver) AddDomain(ctx context.Context, d domain.Domain) error {
|
||||
dnsName := strings.ToLower(dns.Fqdn(d.PunycodeString()))
|
||||
|
||||
ctx, cancel := context.WithTimeout(ctx, dnsTimeout)
|
||||
defer cancel()
|
||||
|
||||
ips, err := lookupIPWithExtraTimeout(ctx, d)
|
||||
if err != nil {
|
||||
return err
|
||||
aRecords, aaaaRecords, errA, errAAAA := m.lookupBoth(ctx, d, dnsName)
|
||||
|
||||
if errA != nil && errAAAA != nil {
|
||||
return fmt.Errorf("resolve %s: %w", d.SafeString(), errors.Join(errA, errAAAA))
|
||||
}
|
||||
|
||||
var aRecords, aaaaRecords []dns.RR
|
||||
for _, ip := range ips {
|
||||
if ip.Is4() {
|
||||
rr := &dns.A{
|
||||
Hdr: dns.RR_Header{
|
||||
Name: dnsName,
|
||||
Rrtype: dns.TypeA,
|
||||
Class: dns.ClassINET,
|
||||
Ttl: 300,
|
||||
},
|
||||
A: ip.AsSlice(),
|
||||
}
|
||||
aRecords = append(aRecords, rr)
|
||||
} else if ip.Is6() {
|
||||
rr := &dns.AAAA{
|
||||
Hdr: dns.RR_Header{
|
||||
Name: dnsName,
|
||||
Rrtype: dns.TypeAAAA,
|
||||
Class: dns.ClassINET,
|
||||
Ttl: 300,
|
||||
},
|
||||
AAAA: ip.AsSlice(),
|
||||
}
|
||||
aaaaRecords = append(aaaaRecords, rr)
|
||||
if len(aRecords) == 0 && len(aaaaRecords) == 0 {
|
||||
if err := errors.Join(errA, errAAAA); err != nil {
|
||||
return fmt.Errorf("resolve %s: no A/AAAA records: %w", d.SafeString(), err)
|
||||
}
|
||||
return fmt.Errorf("resolve %s: no A/AAAA records", d.SafeString())
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
if len(aRecords) > 0 {
|
||||
aQuestion := dns.Question{
|
||||
Name: dnsName,
|
||||
Qtype: dns.TypeA,
|
||||
Qclass: dns.ClassINET,
|
||||
}
|
||||
m.records[aQuestion] = aRecords
|
||||
}
|
||||
m.applyFamilyRecords(dnsName, dns.TypeA, aRecords, errA, now)
|
||||
m.applyFamilyRecords(dnsName, dns.TypeAAAA, aaaaRecords, errAAAA, now)
|
||||
|
||||
if len(aaaaRecords) > 0 {
|
||||
aaaaQuestion := dns.Question{
|
||||
Name: dnsName,
|
||||
Qtype: dns.TypeAAAA,
|
||||
Qclass: dns.ClassINET,
|
||||
}
|
||||
m.records[aaaaQuestion] = aaaaRecords
|
||||
}
|
||||
|
||||
m.mutex.Unlock()
|
||||
|
||||
log.Debugf("added domain=%s with %d A records and %d AAAA records",
|
||||
log.Debugf("added/updated domain=%s with %d A records and %d AAAA records",
|
||||
d.SafeString(), len(aRecords), len(aaaaRecords))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func lookupIPWithExtraTimeout(ctx context.Context, d domain.Domain) ([]netip.Addr, error) {
|
||||
log.Infof("looking up IP for mgmt domain=%s", d.SafeString())
|
||||
defer log.Infof("done looking up IP for mgmt domain=%s", d.SafeString())
|
||||
resultChan := make(chan *ipsResponse, 1)
|
||||
// applyFamilyRecords writes records, evicts on NODATA, leaves the cache
|
||||
// untouched on error. Caller holds m.mutex.
|
||||
func (m *Resolver) applyFamilyRecords(dnsName string, qtype uint16, records []dns.RR, err error, now time.Time) {
|
||||
q := dns.Question{Name: dnsName, Qtype: qtype, Qclass: dns.ClassINET}
|
||||
switch {
|
||||
case len(records) > 0:
|
||||
m.records[q] = &cachedRecord{records: records, cachedAt: now}
|
||||
case err == nil:
|
||||
delete(m.records, q)
|
||||
}
|
||||
}
|
||||
|
||||
go func() {
|
||||
ips, err := net.DefaultResolver.LookupNetIP(ctx, "ip", d.PunycodeString())
|
||||
resultChan <- &ipsResponse{
|
||||
err: err,
|
||||
ips: ips,
|
||||
// scheduleRefresh kicks off an async refresh. DoChan spawns one goroutine per
|
||||
// unique in-flight key; bursty stale hits share its channel. expected is the
|
||||
// cachedRecord pointer observed by the caller; the refresh only mutates the
|
||||
// cache if that pointer is still the one stored, so a stale in-flight refresh
|
||||
// can't clobber a newer entry written by AddDomain or a competing refresh.
|
||||
func (m *Resolver) scheduleRefresh(question dns.Question, expected *cachedRecord) {
|
||||
key := question.Name + "|" + dns.TypeToString[question.Qtype]
|
||||
_ = m.refreshGroup.DoChan(key, func() (any, error) {
|
||||
return nil, m.refreshQuestion(question, expected)
|
||||
})
|
||||
}
|
||||
|
||||
// refreshQuestion replaces the cached records on success, or marks the entry
|
||||
// failed (arming the backoff) on failure. While this runs, ServeDNS can detect
|
||||
// a resolver loop by spotting a query for this same question arriving on us.
|
||||
// expected pins the cache entry observed at schedule time; mutations only apply
|
||||
// if m.records[question] still points at it.
|
||||
func (m *Resolver) refreshQuestion(question dns.Question, expected *cachedRecord) error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), dnsTimeout)
|
||||
defer cancel()
|
||||
|
||||
d, err := domain.FromString(strings.TrimSuffix(question.Name, "."))
|
||||
if err != nil {
|
||||
m.markRefreshFailed(question, expected)
|
||||
return fmt.Errorf("parse domain: %w", err)
|
||||
}
|
||||
|
||||
records, err := m.lookupRecords(ctx, d, question)
|
||||
if err != nil {
|
||||
fails := m.markRefreshFailed(question, expected)
|
||||
logf := log.Warnf
|
||||
if fails == 0 || fails > 1 {
|
||||
logf = log.Debugf
|
||||
}
|
||||
}()
|
||||
|
||||
var resp *ipsResponse
|
||||
|
||||
select {
|
||||
case <-time.After(dnsTimeout + time.Millisecond*500):
|
||||
log.Warnf("timed out waiting for IP for mgmt domain=%s", d.SafeString())
|
||||
return nil, fmt.Errorf("timed out waiting for ips to be available for domain %s", d.SafeString())
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
case resp = <-resultChan:
|
||||
logf("refresh mgmt cache domain=%s type=%s: %v (consecutive failures=%d)",
|
||||
d.SafeString(), dns.TypeToString[question.Qtype], err, fails)
|
||||
return err
|
||||
}
|
||||
|
||||
if resp.err != nil {
|
||||
return nil, fmt.Errorf("resolve domain %s: %w", d.SafeString(), resp.err)
|
||||
// NOERROR/NODATA: family gone upstream, evict so we stop serving stale.
|
||||
if len(records) == 0 {
|
||||
m.mutex.Lock()
|
||||
if m.records[question] == expected {
|
||||
delete(m.records, question)
|
||||
m.mutex.Unlock()
|
||||
log.Infof("removed mgmt cache domain=%s type=%s: no records returned",
|
||||
d.SafeString(), dns.TypeToString[question.Qtype])
|
||||
return nil
|
||||
}
|
||||
m.mutex.Unlock()
|
||||
log.Debugf("skipping refresh evict for domain=%s type=%s: entry changed during refresh",
|
||||
d.SafeString(), dns.TypeToString[question.Qtype])
|
||||
return nil
|
||||
}
|
||||
return resp.ips, nil
|
||||
|
||||
now := time.Now()
|
||||
m.mutex.Lock()
|
||||
if m.records[question] != expected {
|
||||
m.mutex.Unlock()
|
||||
log.Debugf("skipping refresh write for domain=%s type=%s: entry changed during refresh",
|
||||
d.SafeString(), dns.TypeToString[question.Qtype])
|
||||
return nil
|
||||
}
|
||||
m.records[question] = &cachedRecord{records: records, cachedAt: now}
|
||||
m.mutex.Unlock()
|
||||
|
||||
log.Infof("refreshed mgmt cache domain=%s type=%s",
|
||||
d.SafeString(), dns.TypeToString[question.Qtype])
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Resolver) markRefreshing(question dns.Question) {
|
||||
m.mutex.Lock()
|
||||
m.refreshing[question] = &atomic.Bool{}
|
||||
m.mutex.Unlock()
|
||||
}
|
||||
|
||||
func (m *Resolver) clearRefreshing(question dns.Question) {
|
||||
m.mutex.Lock()
|
||||
delete(m.refreshing, question)
|
||||
m.mutex.Unlock()
|
||||
}
|
||||
|
||||
// markRefreshFailed arms the backoff and returns the new consecutive-failure
|
||||
// count so callers can downgrade subsequent failure logs to debug.
|
||||
func (m *Resolver) markRefreshFailed(question dns.Question, expected *cachedRecord) int {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
c, ok := m.records[question]
|
||||
if !ok || c != expected {
|
||||
return 0
|
||||
}
|
||||
c.lastFailedRefresh = time.Now()
|
||||
c.consecFailures++
|
||||
return c.consecFailures
|
||||
}
|
||||
|
||||
// lookupBoth resolves A and AAAA via chain or OS. Per-family errors let
|
||||
// callers tell records, NODATA (nil err, no records), and failure apart.
|
||||
func (m *Resolver) lookupBoth(ctx context.Context, d domain.Domain, dnsName string) (aRecords, aaaaRecords []dns.RR, errA, errAAAA error) {
|
||||
m.mutex.RLock()
|
||||
chain := m.chain
|
||||
maxPriority := m.chainMaxPriority
|
||||
m.mutex.RUnlock()
|
||||
|
||||
if chain != nil && chain.HasRootHandlerAtOrBelow(maxPriority) {
|
||||
aRecords, errA = m.lookupViaChain(ctx, chain, maxPriority, dnsName, dns.TypeA)
|
||||
aaaaRecords, errAAAA = m.lookupViaChain(ctx, chain, maxPriority, dnsName, dns.TypeAAAA)
|
||||
return
|
||||
}
|
||||
|
||||
// TODO: drop once every supported OS registers a fallback resolver. Safe
|
||||
// today: no root handler at priority ≤ PriorityUpstream means NetBird is
|
||||
// not the system resolver, so net.DefaultResolver will not loop back.
|
||||
aRecords, errA = m.osLookup(ctx, d, dnsName, dns.TypeA)
|
||||
aaaaRecords, errAAAA = m.osLookup(ctx, d, dnsName, dns.TypeAAAA)
|
||||
return
|
||||
}
|
||||
|
||||
// lookupRecords resolves a single record type via chain or OS. The OS branch
|
||||
// arms the loop detector for the duration of its call so that ServeDNS can
|
||||
// spot the OS resolver routing the recursive query back to us.
|
||||
func (m *Resolver) lookupRecords(ctx context.Context, d domain.Domain, q dns.Question) ([]dns.RR, error) {
|
||||
m.mutex.RLock()
|
||||
chain := m.chain
|
||||
maxPriority := m.chainMaxPriority
|
||||
m.mutex.RUnlock()
|
||||
|
||||
if chain != nil && chain.HasRootHandlerAtOrBelow(maxPriority) {
|
||||
return m.lookupViaChain(ctx, chain, maxPriority, q.Name, q.Qtype)
|
||||
}
|
||||
|
||||
// TODO: drop once every supported OS registers a fallback resolver.
|
||||
m.markRefreshing(q)
|
||||
defer m.clearRefreshing(q)
|
||||
|
||||
return m.osLookup(ctx, d, q.Name, q.Qtype)
|
||||
}
|
||||
|
||||
// lookupViaChain resolves via the handler chain and rewrites each RR to use
|
||||
// dnsName as owner and m.cacheTTL as TTL, so CNAME-backed domains don't cache
|
||||
// target-owned records or upstream TTLs. NODATA returns (nil, nil).
|
||||
func (m *Resolver) lookupViaChain(ctx context.Context, chain ChainResolver, maxPriority int, dnsName string, qtype uint16) ([]dns.RR, error) {
|
||||
msg := &dns.Msg{}
|
||||
msg.SetQuestion(dnsName, qtype)
|
||||
msg.RecursionDesired = true
|
||||
|
||||
resp, err := chain.ResolveInternal(ctx, msg, maxPriority)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("chain resolve: %w", err)
|
||||
}
|
||||
if resp == nil {
|
||||
return nil, fmt.Errorf("chain resolve returned nil response")
|
||||
}
|
||||
if resp.Rcode != dns.RcodeSuccess {
|
||||
return nil, fmt.Errorf("chain resolve rcode=%s", dns.RcodeToString[resp.Rcode])
|
||||
}
|
||||
|
||||
ttl := uint32(m.cacheTTL.Seconds())
|
||||
owners := cnameOwners(dnsName, resp.Answer)
|
||||
var filtered []dns.RR
|
||||
for _, rr := range resp.Answer {
|
||||
h := rr.Header()
|
||||
if h.Class != dns.ClassINET || h.Rrtype != qtype {
|
||||
continue
|
||||
}
|
||||
if !owners[strings.ToLower(dns.Fqdn(h.Name))] {
|
||||
continue
|
||||
}
|
||||
if cp := cloneIPRecord(rr, dnsName, ttl); cp != nil {
|
||||
filtered = append(filtered, cp)
|
||||
}
|
||||
}
|
||||
return filtered, nil
|
||||
}
|
||||
|
||||
// osLookup resolves a single family via net.DefaultResolver using resutil,
|
||||
// which disambiguates NODATA from NXDOMAIN and Unmaps v4-mapped-v6. NODATA
|
||||
// returns (nil, nil).
|
||||
func (m *Resolver) osLookup(ctx context.Context, d domain.Domain, dnsName string, qtype uint16) ([]dns.RR, error) {
|
||||
network := resutil.NetworkForQtype(qtype)
|
||||
if network == "" {
|
||||
return nil, fmt.Errorf("unsupported qtype %s", dns.TypeToString[qtype])
|
||||
}
|
||||
|
||||
log.Infof("looking up IP for mgmt domain=%s type=%s", d.SafeString(), dns.TypeToString[qtype])
|
||||
defer log.Infof("done looking up IP for mgmt domain=%s type=%s", d.SafeString(), dns.TypeToString[qtype])
|
||||
|
||||
result := resutil.LookupIP(ctx, net.DefaultResolver, network, d.PunycodeString(), qtype)
|
||||
if result.Rcode == dns.RcodeSuccess {
|
||||
return resutil.IPsToRRs(dnsName, result.IPs, uint32(m.cacheTTL.Seconds())), nil
|
||||
}
|
||||
|
||||
if result.Err != nil {
|
||||
return nil, fmt.Errorf("resolve %s type=%s: %w", d.SafeString(), dns.TypeToString[qtype], result.Err)
|
||||
}
|
||||
return nil, fmt.Errorf("resolve %s type=%s: rcode=%s", d.SafeString(), dns.TypeToString[qtype], dns.RcodeToString[result.Rcode])
|
||||
}
|
||||
|
||||
// responseTTL returns the remaining cache lifetime in seconds (rounded up),
|
||||
// so downstream resolvers don't cache an answer for longer than we will.
|
||||
func (m *Resolver) responseTTL(cachedAt time.Time) uint32 {
|
||||
remaining := m.cacheTTL - time.Since(cachedAt)
|
||||
if remaining <= 0 {
|
||||
return 0
|
||||
}
|
||||
return uint32((remaining + time.Second - 1) / time.Second)
|
||||
}
|
||||
|
||||
// PopulateFromConfig extracts and caches domains from the client configuration.
|
||||
@@ -224,19 +456,12 @@ func (m *Resolver) RemoveDomain(d domain.Domain) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
aQuestion := dns.Question{
|
||||
Name: dnsName,
|
||||
Qtype: dns.TypeA,
|
||||
Qclass: dns.ClassINET,
|
||||
}
|
||||
delete(m.records, aQuestion)
|
||||
|
||||
aaaaQuestion := dns.Question{
|
||||
Name: dnsName,
|
||||
Qtype: dns.TypeAAAA,
|
||||
Qclass: dns.ClassINET,
|
||||
}
|
||||
delete(m.records, aaaaQuestion)
|
||||
qA := dns.Question{Name: dnsName, Qtype: dns.TypeA, Qclass: dns.ClassINET}
|
||||
qAAAA := dns.Question{Name: dnsName, Qtype: dns.TypeAAAA, Qclass: dns.ClassINET}
|
||||
delete(m.records, qA)
|
||||
delete(m.records, qAAAA)
|
||||
delete(m.refreshing, qA)
|
||||
delete(m.refreshing, qAAAA)
|
||||
|
||||
log.Debugf("removed domain=%s from cache", d.SafeString())
|
||||
return nil
|
||||
@@ -394,3 +619,73 @@ func (m *Resolver) extractDomainsFromServerDomains(serverDomains dnsconfig.Serve
|
||||
|
||||
return domains
|
||||
}
|
||||
|
||||
// cloneIPRecord returns a deep copy of rr retargeted to owner with ttl. Non
|
||||
// A/AAAA records return nil.
|
||||
func cloneIPRecord(rr dns.RR, owner string, ttl uint32) dns.RR {
|
||||
switch r := rr.(type) {
|
||||
case *dns.A:
|
||||
cp := *r
|
||||
cp.Hdr.Name = owner
|
||||
cp.Hdr.Ttl = ttl
|
||||
cp.A = slices.Clone(r.A)
|
||||
return &cp
|
||||
case *dns.AAAA:
|
||||
cp := *r
|
||||
cp.Hdr.Name = owner
|
||||
cp.Hdr.Ttl = ttl
|
||||
cp.AAAA = slices.Clone(r.AAAA)
|
||||
return &cp
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// cloneRecordsWithTTL clones A/AAAA records preserving their owner and
|
||||
// stamping ttl so the response shares no memory with the cached slice.
|
||||
func cloneRecordsWithTTL(records []dns.RR, ttl uint32) []dns.RR {
|
||||
out := make([]dns.RR, 0, len(records))
|
||||
for _, rr := range records {
|
||||
if cp := cloneIPRecord(rr, rr.Header().Name, ttl); cp != nil {
|
||||
out = append(out, cp)
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// cnameOwners returns dnsName plus every target reachable by following CNAMEs
|
||||
// in answer, iterating until fixed point so out-of-order chains resolve.
|
||||
func cnameOwners(dnsName string, answer []dns.RR) map[string]bool {
|
||||
owners := map[string]bool{dnsName: true}
|
||||
for {
|
||||
added := false
|
||||
for _, rr := range answer {
|
||||
cname, ok := rr.(*dns.CNAME)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
name := strings.ToLower(dns.Fqdn(cname.Hdr.Name))
|
||||
if !owners[name] {
|
||||
continue
|
||||
}
|
||||
target := strings.ToLower(dns.Fqdn(cname.Target))
|
||||
if !owners[target] {
|
||||
owners[target] = true
|
||||
added = true
|
||||
}
|
||||
}
|
||||
if !added {
|
||||
return owners
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// resolveCacheTTL reads the cache TTL override env var; invalid or empty
|
||||
// values fall back to defaultTTL. Called once per Resolver from NewResolver.
|
||||
func resolveCacheTTL() time.Duration {
|
||||
if v := os.Getenv(envMgmtCacheTTL); v != "" {
|
||||
if d, err := time.ParseDuration(v); err == nil && d > 0 {
|
||||
return d
|
||||
}
|
||||
}
|
||||
return defaultTTL
|
||||
}
|
||||
|
||||
408
client/internal/dns/mgmt/mgmt_refresh_test.go
Normal file
408
client/internal/dns/mgmt/mgmt_refresh_test.go
Normal file
@@ -0,0 +1,408 @@
|
||||
package mgmt
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/dns/test"
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
)
|
||||
|
||||
type fakeChain struct {
|
||||
mu sync.Mutex
|
||||
calls map[string]int
|
||||
answers map[string][]dns.RR
|
||||
err error
|
||||
hasRoot bool
|
||||
onLookup func()
|
||||
}
|
||||
|
||||
func newFakeChain() *fakeChain {
|
||||
return &fakeChain{
|
||||
calls: map[string]int{},
|
||||
answers: map[string][]dns.RR{},
|
||||
hasRoot: true,
|
||||
}
|
||||
}
|
||||
|
||||
func (f *fakeChain) HasRootHandlerAtOrBelow(maxPriority int) bool {
|
||||
f.mu.Lock()
|
||||
defer f.mu.Unlock()
|
||||
return f.hasRoot
|
||||
}
|
||||
|
||||
func (f *fakeChain) ResolveInternal(ctx context.Context, msg *dns.Msg, maxPriority int) (*dns.Msg, error) {
|
||||
f.mu.Lock()
|
||||
q := msg.Question[0]
|
||||
key := q.Name + "|" + dns.TypeToString[q.Qtype]
|
||||
f.calls[key]++
|
||||
answers := f.answers[key]
|
||||
err := f.err
|
||||
onLookup := f.onLookup
|
||||
f.mu.Unlock()
|
||||
|
||||
if onLookup != nil {
|
||||
onLookup()
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
resp := &dns.Msg{}
|
||||
resp.SetReply(msg)
|
||||
resp.Answer = answers
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func (f *fakeChain) setAnswer(name string, qtype uint16, ip string) {
|
||||
f.mu.Lock()
|
||||
defer f.mu.Unlock()
|
||||
key := name + "|" + dns.TypeToString[qtype]
|
||||
hdr := dns.RR_Header{Name: name, Rrtype: qtype, Class: dns.ClassINET, Ttl: 60}
|
||||
switch qtype {
|
||||
case dns.TypeA:
|
||||
f.answers[key] = []dns.RR{&dns.A{Hdr: hdr, A: net.ParseIP(ip).To4()}}
|
||||
case dns.TypeAAAA:
|
||||
f.answers[key] = []dns.RR{&dns.AAAA{Hdr: hdr, AAAA: net.ParseIP(ip).To16()}}
|
||||
}
|
||||
}
|
||||
|
||||
func (f *fakeChain) callCount(name string, qtype uint16) int {
|
||||
f.mu.Lock()
|
||||
defer f.mu.Unlock()
|
||||
return f.calls[name+"|"+dns.TypeToString[qtype]]
|
||||
}
|
||||
|
||||
// waitFor polls the predicate until it returns true or the deadline passes.
|
||||
func waitFor(t *testing.T, d time.Duration, fn func() bool) {
|
||||
t.Helper()
|
||||
deadline := time.Now().Add(d)
|
||||
for time.Now().Before(deadline) {
|
||||
if fn() {
|
||||
return
|
||||
}
|
||||
time.Sleep(5 * time.Millisecond)
|
||||
}
|
||||
t.Fatalf("condition not met within %s", d)
|
||||
}
|
||||
|
||||
func queryA(t *testing.T, r *Resolver, name string) *dns.Msg {
|
||||
t.Helper()
|
||||
msg := new(dns.Msg)
|
||||
msg.SetQuestion(name, dns.TypeA)
|
||||
w := &test.MockResponseWriter{}
|
||||
r.ServeDNS(w, msg)
|
||||
return w.GetLastResponse()
|
||||
}
|
||||
|
||||
func firstA(t *testing.T, resp *dns.Msg) string {
|
||||
t.Helper()
|
||||
require.NotNil(t, resp)
|
||||
require.Greater(t, len(resp.Answer), 0, "expected at least one answer")
|
||||
a, ok := resp.Answer[0].(*dns.A)
|
||||
require.True(t, ok, "expected A record")
|
||||
return a.A.String()
|
||||
}
|
||||
|
||||
func TestResolver_CacheTTLGatesRefresh(t *testing.T) {
|
||||
// Same cached entry age, different cacheTTL values: the shorter TTL must
|
||||
// trigger a background refresh, the longer one must not. Proves that the
|
||||
// per-Resolver cacheTTL field actually drives the stale decision.
|
||||
cachedAt := time.Now().Add(-100 * time.Millisecond)
|
||||
|
||||
newRec := func() *cachedRecord {
|
||||
return &cachedRecord{
|
||||
records: []dns.RR{&dns.A{
|
||||
Hdr: dns.RR_Header{Name: "mgmt.example.com.", Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
|
||||
A: net.ParseIP("10.0.0.1").To4(),
|
||||
}},
|
||||
cachedAt: cachedAt,
|
||||
}
|
||||
}
|
||||
q := dns.Question{Name: "mgmt.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}
|
||||
|
||||
t.Run("short TTL treats entry as stale and refreshes", func(t *testing.T) {
|
||||
r := NewResolver()
|
||||
r.cacheTTL = 10 * time.Millisecond
|
||||
chain := newFakeChain()
|
||||
chain.setAnswer(q.Name, dns.TypeA, "10.0.0.2")
|
||||
r.SetChainResolver(chain, 50)
|
||||
r.records[q] = newRec()
|
||||
|
||||
resp := queryA(t, r, q.Name)
|
||||
assert.Equal(t, "10.0.0.1", firstA(t, resp), "stale entry must be served while refresh runs")
|
||||
|
||||
waitFor(t, time.Second, func() bool {
|
||||
return chain.callCount(q.Name, dns.TypeA) >= 1
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("long TTL keeps entry fresh and skips refresh", func(t *testing.T) {
|
||||
r := NewResolver()
|
||||
r.cacheTTL = time.Hour
|
||||
chain := newFakeChain()
|
||||
chain.setAnswer(q.Name, dns.TypeA, "10.0.0.2")
|
||||
r.SetChainResolver(chain, 50)
|
||||
r.records[q] = newRec()
|
||||
|
||||
resp := queryA(t, r, q.Name)
|
||||
assert.Equal(t, "10.0.0.1", firstA(t, resp))
|
||||
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
assert.Equal(t, 0, chain.callCount(q.Name, dns.TypeA), "fresh entry must not trigger refresh")
|
||||
})
|
||||
}
|
||||
|
||||
func TestResolver_ServeFresh_NoRefresh(t *testing.T) {
|
||||
r := NewResolver()
|
||||
chain := newFakeChain()
|
||||
chain.setAnswer("mgmt.example.com.", dns.TypeA, "10.0.0.2")
|
||||
r.SetChainResolver(chain, 50)
|
||||
|
||||
r.records[dns.Question{Name: "mgmt.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}] = &cachedRecord{
|
||||
records: []dns.RR{&dns.A{
|
||||
Hdr: dns.RR_Header{Name: "mgmt.example.com.", Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
|
||||
A: net.ParseIP("10.0.0.1").To4(),
|
||||
}},
|
||||
cachedAt: time.Now(), // fresh
|
||||
}
|
||||
|
||||
resp := queryA(t, r, "mgmt.example.com.")
|
||||
assert.Equal(t, "10.0.0.1", firstA(t, resp))
|
||||
|
||||
time.Sleep(20 * time.Millisecond)
|
||||
assert.Equal(t, 0, chain.callCount("mgmt.example.com.", dns.TypeA), "fresh entry must not trigger refresh")
|
||||
}
|
||||
|
||||
func TestResolver_StaleTriggersAsyncRefresh(t *testing.T) {
|
||||
r := NewResolver()
|
||||
chain := newFakeChain()
|
||||
chain.setAnswer("mgmt.example.com.", dns.TypeA, "10.0.0.2")
|
||||
r.SetChainResolver(chain, 50)
|
||||
|
||||
q := dns.Question{Name: "mgmt.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}
|
||||
r.records[q] = &cachedRecord{
|
||||
records: []dns.RR{&dns.A{
|
||||
Hdr: dns.RR_Header{Name: q.Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
|
||||
A: net.ParseIP("10.0.0.1").To4(),
|
||||
}},
|
||||
cachedAt: time.Now().Add(-2 * defaultTTL), // stale
|
||||
}
|
||||
|
||||
// First query: serves stale immediately.
|
||||
resp := queryA(t, r, "mgmt.example.com.")
|
||||
assert.Equal(t, "10.0.0.1", firstA(t, resp), "stale entry must be served while refresh runs")
|
||||
|
||||
waitFor(t, time.Second, func() bool {
|
||||
return chain.callCount("mgmt.example.com.", dns.TypeA) >= 1
|
||||
})
|
||||
|
||||
// Next query should now return the refreshed IP.
|
||||
waitFor(t, time.Second, func() bool {
|
||||
resp := queryA(t, r, "mgmt.example.com.")
|
||||
return resp != nil && len(resp.Answer) > 0 && firstA(t, resp) == "10.0.0.2"
|
||||
})
|
||||
}
|
||||
|
||||
func TestResolver_ConcurrentStaleHitsCollapseRefresh(t *testing.T) {
|
||||
r := NewResolver()
|
||||
chain := newFakeChain()
|
||||
chain.setAnswer("mgmt.example.com.", dns.TypeA, "10.0.0.2")
|
||||
|
||||
var inflight atomic.Int32
|
||||
var maxInflight atomic.Int32
|
||||
chain.onLookup = func() {
|
||||
cur := inflight.Add(1)
|
||||
defer inflight.Add(-1)
|
||||
for {
|
||||
prev := maxInflight.Load()
|
||||
if cur <= prev || maxInflight.CompareAndSwap(prev, cur) {
|
||||
break
|
||||
}
|
||||
}
|
||||
time.Sleep(50 * time.Millisecond) // hold inflight long enough to collide
|
||||
}
|
||||
|
||||
r.SetChainResolver(chain, 50)
|
||||
|
||||
q := dns.Question{Name: "mgmt.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}
|
||||
r.records[q] = &cachedRecord{
|
||||
records: []dns.RR{&dns.A{
|
||||
Hdr: dns.RR_Header{Name: q.Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
|
||||
A: net.ParseIP("10.0.0.1").To4(),
|
||||
}},
|
||||
cachedAt: time.Now().Add(-2 * defaultTTL),
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
for i := 0; i < 50; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
queryA(t, r, "mgmt.example.com.")
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
waitFor(t, 2*time.Second, func() bool {
|
||||
return inflight.Load() == 0
|
||||
})
|
||||
|
||||
calls := chain.callCount("mgmt.example.com.", dns.TypeA)
|
||||
assert.LessOrEqual(t, calls, 2, "singleflight must collapse concurrent refreshes (got %d)", calls)
|
||||
assert.Equal(t, int32(1), maxInflight.Load(), "only one refresh should run concurrently")
|
||||
}
|
||||
|
||||
func TestResolver_RefreshFailureArmsBackoff(t *testing.T) {
|
||||
r := NewResolver()
|
||||
chain := newFakeChain()
|
||||
chain.err = errors.New("boom")
|
||||
r.SetChainResolver(chain, 50)
|
||||
|
||||
q := dns.Question{Name: "mgmt.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}
|
||||
r.records[q] = &cachedRecord{
|
||||
records: []dns.RR{&dns.A{
|
||||
Hdr: dns.RR_Header{Name: q.Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
|
||||
A: net.ParseIP("10.0.0.1").To4(),
|
||||
}},
|
||||
cachedAt: time.Now().Add(-2 * defaultTTL),
|
||||
}
|
||||
|
||||
// First stale hit triggers a refresh attempt that fails.
|
||||
resp := queryA(t, r, "mgmt.example.com.")
|
||||
assert.Equal(t, "10.0.0.1", firstA(t, resp), "stale entry served while refresh fails")
|
||||
|
||||
waitFor(t, time.Second, func() bool {
|
||||
return chain.callCount("mgmt.example.com.", dns.TypeA) == 1
|
||||
})
|
||||
waitFor(t, time.Second, func() bool {
|
||||
r.mutex.RLock()
|
||||
defer r.mutex.RUnlock()
|
||||
c, ok := r.records[q]
|
||||
return ok && !c.lastFailedRefresh.IsZero()
|
||||
})
|
||||
|
||||
// Subsequent stale hits within backoff window should not schedule more refreshes.
|
||||
for i := 0; i < 10; i++ {
|
||||
queryA(t, r, "mgmt.example.com.")
|
||||
}
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
assert.Equal(t, 1, chain.callCount("mgmt.example.com.", dns.TypeA), "backoff must suppress further refreshes")
|
||||
}
|
||||
|
||||
func TestResolver_NoRootHandler_SkipsChain(t *testing.T) {
|
||||
r := NewResolver()
|
||||
chain := newFakeChain()
|
||||
chain.hasRoot = false
|
||||
chain.setAnswer("mgmt.example.com.", dns.TypeA, "10.0.0.2")
|
||||
r.SetChainResolver(chain, 50)
|
||||
|
||||
// With hasRoot=false the chain must not be consulted. Use a short
|
||||
// deadline so the OS fallback returns quickly without waiting on a
|
||||
// real network call in CI.
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
|
||||
defer cancel()
|
||||
_, _, _, _ = r.lookupBoth(ctx, domain.Domain("mgmt.example.com"), "mgmt.example.com.")
|
||||
|
||||
assert.Equal(t, 0, chain.callCount("mgmt.example.com.", dns.TypeA),
|
||||
"chain must not be used when no root handler is registered at the bound priority")
|
||||
}
|
||||
|
||||
func TestResolver_ServeDuringRefreshSetsLoopFlag(t *testing.T) {
|
||||
// ServeDNS being invoked for a question while a refresh for that question
|
||||
// is inflight indicates a resolver loop (OS resolver sent the recursive
|
||||
// query back to us). The inflightRefresh.loopLoggedOnce flag must be set.
|
||||
r := NewResolver()
|
||||
|
||||
q := dns.Question{Name: "mgmt.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}
|
||||
r.records[q] = &cachedRecord{
|
||||
records: []dns.RR{&dns.A{
|
||||
Hdr: dns.RR_Header{Name: q.Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
|
||||
A: net.ParseIP("10.0.0.1").To4(),
|
||||
}},
|
||||
cachedAt: time.Now(),
|
||||
}
|
||||
|
||||
// Simulate an inflight refresh.
|
||||
r.markRefreshing(q)
|
||||
defer r.clearRefreshing(q)
|
||||
|
||||
resp := queryA(t, r, "mgmt.example.com.")
|
||||
assert.Equal(t, "10.0.0.1", firstA(t, resp), "stale entry must still be served to avoid breaking external queries")
|
||||
|
||||
r.mutex.RLock()
|
||||
inflight := r.refreshing[q]
|
||||
r.mutex.RUnlock()
|
||||
require.NotNil(t, inflight)
|
||||
assert.True(t, inflight.Load(), "loop flag must be set once a ServeDNS during refresh was observed")
|
||||
}
|
||||
|
||||
func TestResolver_LoopFlagOnlyTrippedOncePerRefresh(t *testing.T) {
|
||||
r := NewResolver()
|
||||
|
||||
q := dns.Question{Name: "mgmt.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}
|
||||
r.records[q] = &cachedRecord{
|
||||
records: []dns.RR{&dns.A{
|
||||
Hdr: dns.RR_Header{Name: q.Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
|
||||
A: net.ParseIP("10.0.0.1").To4(),
|
||||
}},
|
||||
cachedAt: time.Now(),
|
||||
}
|
||||
|
||||
r.markRefreshing(q)
|
||||
defer r.clearRefreshing(q)
|
||||
|
||||
// Multiple ServeDNS calls during the same refresh must not re-set the flag
|
||||
// (CompareAndSwap from false -> true returns true only on the first call).
|
||||
for range 5 {
|
||||
queryA(t, r, "mgmt.example.com.")
|
||||
}
|
||||
|
||||
r.mutex.RLock()
|
||||
inflight := r.refreshing[q]
|
||||
r.mutex.RUnlock()
|
||||
assert.True(t, inflight.Load())
|
||||
}
|
||||
|
||||
func TestResolver_NoLoopFlagWhenNotRefreshing(t *testing.T) {
|
||||
r := NewResolver()
|
||||
|
||||
q := dns.Question{Name: "mgmt.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}
|
||||
r.records[q] = &cachedRecord{
|
||||
records: []dns.RR{&dns.A{
|
||||
Hdr: dns.RR_Header{Name: q.Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
|
||||
A: net.ParseIP("10.0.0.1").To4(),
|
||||
}},
|
||||
cachedAt: time.Now(),
|
||||
}
|
||||
|
||||
queryA(t, r, "mgmt.example.com.")
|
||||
|
||||
r.mutex.RLock()
|
||||
_, ok := r.refreshing[q]
|
||||
r.mutex.RUnlock()
|
||||
assert.False(t, ok, "no refresh inflight means no loop tracking")
|
||||
}
|
||||
|
||||
func TestResolver_AddDomain_UsesChainWhenRootRegistered(t *testing.T) {
|
||||
r := NewResolver()
|
||||
chain := newFakeChain()
|
||||
chain.setAnswer("mgmt.example.com.", dns.TypeA, "10.0.0.2")
|
||||
chain.setAnswer("mgmt.example.com.", dns.TypeAAAA, "fd00::2")
|
||||
r.SetChainResolver(chain, 50)
|
||||
|
||||
require.NoError(t, r.AddDomain(context.Background(), domain.Domain("mgmt.example.com")))
|
||||
|
||||
resp := queryA(t, r, "mgmt.example.com.")
|
||||
assert.Equal(t, "10.0.0.2", firstA(t, resp))
|
||||
assert.Equal(t, 1, chain.callCount("mgmt.example.com.", dns.TypeA))
|
||||
assert.Equal(t, 1, chain.callCount("mgmt.example.com.", dns.TypeAAAA))
|
||||
}
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"net/url"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
"github.com/stretchr/testify/assert"
|
||||
@@ -23,6 +24,60 @@ func TestResolver_NewResolver(t *testing.T) {
|
||||
assert.False(t, resolver.MatchSubdomains())
|
||||
}
|
||||
|
||||
func TestResolveCacheTTL(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
value string
|
||||
want time.Duration
|
||||
}{
|
||||
{"unset falls back to default", "", defaultTTL},
|
||||
{"valid duration", "45s", 45 * time.Second},
|
||||
{"valid minutes", "2m", 2 * time.Minute},
|
||||
{"malformed falls back to default", "not-a-duration", defaultTTL},
|
||||
{"zero falls back to default", "0s", defaultTTL},
|
||||
{"negative falls back to default", "-5s", defaultTTL},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Setenv(envMgmtCacheTTL, tc.value)
|
||||
got := resolveCacheTTL()
|
||||
assert.Equal(t, tc.want, got, "parsed TTL should match")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewResolver_CacheTTLFromEnv(t *testing.T) {
|
||||
t.Setenv(envMgmtCacheTTL, "7s")
|
||||
r := NewResolver()
|
||||
assert.Equal(t, 7*time.Second, r.cacheTTL, "NewResolver should evaluate cacheTTL once from env")
|
||||
}
|
||||
|
||||
func TestResolver_ResponseTTL(t *testing.T) {
|
||||
now := time.Now()
|
||||
tests := []struct {
|
||||
name string
|
||||
cacheTTL time.Duration
|
||||
cachedAt time.Time
|
||||
wantMin uint32
|
||||
wantMax uint32
|
||||
}{
|
||||
{"fresh entry returns full TTL", 60 * time.Second, now, 59, 60},
|
||||
{"half-aged entry returns half TTL", 60 * time.Second, now.Add(-30 * time.Second), 29, 31},
|
||||
{"expired entry returns zero", 60 * time.Second, now.Add(-61 * time.Second), 0, 0},
|
||||
{"exactly expired returns zero", 10 * time.Second, now.Add(-10 * time.Second), 0, 0},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
r := &Resolver{cacheTTL: tc.cacheTTL}
|
||||
got := r.responseTTL(tc.cachedAt)
|
||||
assert.GreaterOrEqual(t, got, tc.wantMin, "remaining TTL should be >= wantMin")
|
||||
assert.LessOrEqual(t, got, tc.wantMax, "remaining TTL should be <= wantMax")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolver_ExtractDomainFromURL(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
|
||||
@@ -212,6 +212,7 @@ func newDefaultServer(
|
||||
ctx, stop := context.WithCancel(ctx)
|
||||
|
||||
mgmtCacheResolver := mgmt.NewResolver()
|
||||
mgmtCacheResolver.SetChainResolver(handlerChain, PriorityUpstream)
|
||||
|
||||
defaultServer := &DefaultServer{
|
||||
ctx: ctx,
|
||||
|
||||
@@ -26,6 +26,7 @@ import (
|
||||
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
"github.com/netbirdio/netbird/client/firewall"
|
||||
"github.com/netbirdio/netbird/client/firewall/firewalld"
|
||||
firewallManager "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
@@ -117,13 +118,11 @@ type EngineConfig struct {
|
||||
RosenpassPermissive bool
|
||||
|
||||
ServerSSHAllowed bool
|
||||
ServerVNCAllowed bool
|
||||
EnableSSHRoot *bool
|
||||
EnableSSHSFTP *bool
|
||||
EnableSSHLocalPortForwarding *bool
|
||||
EnableSSHRemotePortForwarding *bool
|
||||
DisableSSHAuth *bool
|
||||
DisableVNCAuth *bool
|
||||
|
||||
DNSRouteInterval time.Duration
|
||||
|
||||
@@ -200,7 +199,6 @@ type Engine struct {
|
||||
networkMonitor *networkmonitor.NetworkMonitor
|
||||
|
||||
sshServer sshServer
|
||||
vncSrv vncServer
|
||||
|
||||
statusRecorder *peer.Status
|
||||
|
||||
@@ -314,10 +312,6 @@ func (e *Engine) Stop() error {
|
||||
log.Warnf("failed to stop SSH server: %v", err)
|
||||
}
|
||||
|
||||
if err := e.stopVNCServer(); err != nil {
|
||||
log.Warnf("failed to stop VNC server: %v", err)
|
||||
}
|
||||
|
||||
e.cleanupSSHConfig()
|
||||
|
||||
if e.ingressGatewayMgr != nil {
|
||||
@@ -577,7 +571,7 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
|
||||
e.connMgr.Start(e.ctx)
|
||||
|
||||
e.srWatcher = guard.NewSRWatcher(e.signal, e.relayManager, e.mobileDep.IFaceDiscover, iceCfg)
|
||||
e.srWatcher.Start()
|
||||
e.srWatcher.Start(peer.IsForceRelayed())
|
||||
|
||||
e.receiveSignalEvents()
|
||||
e.receiveManagementEvents()
|
||||
@@ -611,6 +605,8 @@ func (e *Engine) createFirewall() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
firewalld.SetParentContext(e.ctx)
|
||||
|
||||
var err error
|
||||
e.firewall, err = firewall.NewFirewall(e.wgInterface, e.stateManager, e.flowManager.GetLogger(), e.config.DisableServerRoutes, e.config.MTU)
|
||||
if err != nil {
|
||||
@@ -1005,7 +1001,6 @@ func (e *Engine) updateChecksIfNew(checks []*mgmProto.Checks) error {
|
||||
e.config.RosenpassEnabled,
|
||||
e.config.RosenpassPermissive,
|
||||
&e.config.ServerSSHAllowed,
|
||||
&e.config.ServerVNCAllowed,
|
||||
e.config.DisableClientRoutes,
|
||||
e.config.DisableServerRoutes,
|
||||
e.config.DisableDNS,
|
||||
@@ -1018,7 +1013,6 @@ func (e *Engine) updateChecksIfNew(checks []*mgmProto.Checks) error {
|
||||
e.config.EnableSSHLocalPortForwarding,
|
||||
e.config.EnableSSHRemotePortForwarding,
|
||||
e.config.DisableSSHAuth,
|
||||
e.config.DisableVNCAuth,
|
||||
)
|
||||
|
||||
if err := e.mgmClient.SyncMeta(info); err != nil {
|
||||
@@ -1046,10 +1040,6 @@ func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error {
|
||||
}
|
||||
}
|
||||
|
||||
if err := e.updateVNC(conf.GetSshConfig()); err != nil {
|
||||
log.Warnf("failed handling VNC server setup: %v", err)
|
||||
}
|
||||
|
||||
state := e.statusRecorder.GetLocalPeerState()
|
||||
state.IP = e.wgInterface.Address().String()
|
||||
state.PubKey = e.config.WgPrivateKey.PublicKey().String()
|
||||
@@ -1152,7 +1142,6 @@ func (e *Engine) receiveManagementEvents() {
|
||||
e.config.RosenpassEnabled,
|
||||
e.config.RosenpassPermissive,
|
||||
&e.config.ServerSSHAllowed,
|
||||
&e.config.ServerVNCAllowed,
|
||||
e.config.DisableClientRoutes,
|
||||
e.config.DisableServerRoutes,
|
||||
e.config.DisableDNS,
|
||||
@@ -1165,7 +1154,6 @@ func (e *Engine) receiveManagementEvents() {
|
||||
e.config.EnableSSHLocalPortForwarding,
|
||||
e.config.EnableSSHRemotePortForwarding,
|
||||
e.config.DisableSSHAuth,
|
||||
e.config.DisableVNCAuth,
|
||||
)
|
||||
|
||||
err = e.mgmClient.Sync(e.ctx, info, e.handleSync)
|
||||
@@ -1340,11 +1328,6 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
|
||||
}
|
||||
|
||||
e.updateSSHServerAuth(networkMap.GetSshAuth())
|
||||
|
||||
// VNC auth: use dedicated VNCAuth if present.
|
||||
if vncAuth := networkMap.GetVncAuth(); vncAuth != nil {
|
||||
e.updateVNCServerAuth(vncAuth)
|
||||
}
|
||||
}
|
||||
|
||||
// must set the exclude list after the peers are added. Without it the manager can not figure out the peers parameters from the store
|
||||
@@ -1754,7 +1737,6 @@ func (e *Engine) readInitialSettings() ([]*route.Route, *nbdns.Config, bool, err
|
||||
e.config.RosenpassEnabled,
|
||||
e.config.RosenpassPermissive,
|
||||
&e.config.ServerSSHAllowed,
|
||||
&e.config.ServerVNCAllowed,
|
||||
e.config.DisableClientRoutes,
|
||||
e.config.DisableServerRoutes,
|
||||
e.config.DisableDNS,
|
||||
@@ -1767,7 +1749,6 @@ func (e *Engine) readInitialSettings() ([]*route.Route, *nbdns.Config, bool, err
|
||||
e.config.EnableSSHLocalPortForwarding,
|
||||
e.config.EnableSSHRemotePortForwarding,
|
||||
e.config.DisableSSHAuth,
|
||||
e.config.DisableVNCAuth,
|
||||
)
|
||||
|
||||
netMap, err := e.mgmClient.GetNetworkMap(info)
|
||||
|
||||
@@ -1,309 +0,0 @@
|
||||
package internal
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
firewallManager "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
||||
sshauth "github.com/netbirdio/netbird/client/ssh/auth"
|
||||
vncserver "github.com/netbirdio/netbird/client/vnc/server"
|
||||
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
|
||||
sshuserhash "github.com/netbirdio/netbird/shared/sshauth"
|
||||
)
|
||||
|
||||
const envVNCForceRecording = "NB_VNC_FORCE_RECORDING"
|
||||
|
||||
const (
|
||||
vncExternalPort uint16 = 5900
|
||||
vncInternalPort uint16 = 25900
|
||||
)
|
||||
|
||||
type vncServer interface {
|
||||
Start(ctx context.Context, addr netip.AddrPort, network netip.Prefix) error
|
||||
Stop() error
|
||||
}
|
||||
|
||||
func (e *Engine) setupVNCPortRedirection() error {
|
||||
if e.firewall == nil || e.wgInterface == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
localAddr := e.wgInterface.Address().IP
|
||||
if !localAddr.IsValid() {
|
||||
return errors.New("invalid local NetBird address")
|
||||
}
|
||||
|
||||
if err := e.firewall.AddInboundDNAT(localAddr, firewallManager.ProtocolTCP, vncExternalPort, vncInternalPort); err != nil {
|
||||
return fmt.Errorf("add VNC port redirection: %w", err)
|
||||
}
|
||||
log.Infof("VNC port redirection: %s:%d -> %s:%d", localAddr, vncExternalPort, localAddr, vncInternalPort)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (e *Engine) cleanupVNCPortRedirection() error {
|
||||
if e.firewall == nil || e.wgInterface == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
localAddr := e.wgInterface.Address().IP
|
||||
if !localAddr.IsValid() {
|
||||
return errors.New("invalid local NetBird address")
|
||||
}
|
||||
|
||||
if err := e.firewall.RemoveInboundDNAT(localAddr, firewallManager.ProtocolTCP, vncExternalPort, vncInternalPort); err != nil {
|
||||
return fmt.Errorf("remove VNC port redirection: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// updateVNC handles starting/stopping the VNC server based on the config flag.
|
||||
// sshConf provides the JWT identity provider config (shared with SSH).
|
||||
func (e *Engine) updateVNC(sshConf *mgmProto.SSHConfig) error {
|
||||
if !e.config.ServerVNCAllowed {
|
||||
if e.vncSrv != nil {
|
||||
log.Info("VNC server disabled, stopping")
|
||||
}
|
||||
return e.stopVNCServer()
|
||||
}
|
||||
|
||||
if e.config.BlockInbound {
|
||||
log.Info("VNC server disabled because inbound connections are blocked")
|
||||
return e.stopVNCServer()
|
||||
}
|
||||
|
||||
if e.vncSrv != nil {
|
||||
// Update JWT config on existing server in case management sent new config.
|
||||
e.updateVNCServerJWT(sshConf)
|
||||
return nil
|
||||
}
|
||||
|
||||
return e.startVNCServer(sshConf)
|
||||
}
|
||||
|
||||
func (e *Engine) startVNCServer(sshConf *mgmProto.SSHConfig) error {
|
||||
if e.wgInterface == nil {
|
||||
return errors.New("wg interface not initialized")
|
||||
}
|
||||
|
||||
capturer, injector := newPlatformVNC()
|
||||
if capturer == nil || injector == nil {
|
||||
log.Debug("VNC server not supported on this platform")
|
||||
return nil
|
||||
}
|
||||
|
||||
netbirdIP := e.wgInterface.Address().IP
|
||||
|
||||
srv := vncserver.New(capturer, injector, "")
|
||||
if vncNeedsServiceMode() {
|
||||
log.Info("VNC: running in Session 0, enabling service mode (agent proxy)")
|
||||
srv.SetServiceMode(true)
|
||||
}
|
||||
|
||||
// Configure VNC authentication.
|
||||
if e.config.DisableVNCAuth != nil && *e.config.DisableVNCAuth {
|
||||
log.Info("VNC: authentication disabled by config")
|
||||
srv.SetDisableAuth(true)
|
||||
} else if protoJWT := sshConf.GetJwtConfig(); protoJWT != nil {
|
||||
audiences := protoJWT.GetAudiences()
|
||||
if len(audiences) == 0 && protoJWT.GetAudience() != "" {
|
||||
audiences = []string{protoJWT.GetAudience()}
|
||||
}
|
||||
srv.SetJWTConfig(&vncserver.JWTConfig{
|
||||
Issuer: protoJWT.GetIssuer(),
|
||||
Audiences: audiences,
|
||||
KeysLocation: protoJWT.GetKeysLocation(),
|
||||
MaxTokenAge: protoJWT.GetMaxTokenAge(),
|
||||
})
|
||||
log.Debugf("VNC: JWT authentication configured (issuer=%s)", protoJWT.GetIssuer())
|
||||
}
|
||||
|
||||
e.configureVNCRecording(srv, sshConf)
|
||||
|
||||
if netstackNet := e.wgInterface.GetNet(); netstackNet != nil {
|
||||
srv.SetNetstackNet(netstackNet)
|
||||
}
|
||||
|
||||
listenAddr := netip.AddrPortFrom(netbirdIP, vncInternalPort)
|
||||
network := e.wgInterface.Address().Network
|
||||
if err := srv.Start(e.ctx, listenAddr, network); err != nil {
|
||||
return fmt.Errorf("start VNC server: %w", err)
|
||||
}
|
||||
|
||||
e.vncSrv = srv
|
||||
|
||||
if registrar, ok := e.firewall.(interface {
|
||||
RegisterNetstackService(protocol nftypes.Protocol, port uint16)
|
||||
}); ok {
|
||||
registrar.RegisterNetstackService(nftypes.TCP, vncInternalPort)
|
||||
log.Debugf("registered VNC service for TCP:%d", vncInternalPort)
|
||||
}
|
||||
|
||||
if err := e.setupVNCPortRedirection(); err != nil {
|
||||
log.Warnf("setup VNC port redirection: %v", err)
|
||||
}
|
||||
|
||||
log.Info("VNC server enabled")
|
||||
return nil
|
||||
}
|
||||
|
||||
// configureVNCRecording enables session recording on the VNC server from the
|
||||
// management-supplied settings. The env var NB_VNC_FORCE_RECORDING overrides
|
||||
// the API for local development: when set, recording is always enabled and
|
||||
// writes into that directory. Otherwise recordings go next to the state file
|
||||
// under vnc-recordings/.
|
||||
func (e *Engine) configureVNCRecording(srv *vncserver.Server, sshConf *mgmProto.SSHConfig) {
|
||||
recDir := os.Getenv(envVNCForceRecording)
|
||||
apiEnabled := sshConf.GetEnableRecording()
|
||||
|
||||
if recDir == "" && !apiEnabled {
|
||||
log.Debugf("VNC recording disabled (env=%q, api=%v)", recDir, apiEnabled)
|
||||
return
|
||||
}
|
||||
|
||||
if recDir == "" {
|
||||
base := e.defaultRecordingBase()
|
||||
if base == "" {
|
||||
log.Warn("VNC recording requested by management but no state directory is available")
|
||||
return
|
||||
}
|
||||
recDir = filepath.Join(base, "vnc-recordings")
|
||||
} else {
|
||||
recDir = filepath.Join(recDir, "vnc")
|
||||
}
|
||||
|
||||
srv.SetRecordingDir(recDir)
|
||||
log.Infof("VNC recording enabled (dir=%s, source=%s)", recDir, recordingSource(apiEnabled))
|
||||
|
||||
encKey := string(sshConf.GetRecordingEncryptionKey())
|
||||
if encKey == "" {
|
||||
encKey = os.Getenv("NB_VNC_RECORDING_ENCRYPTION_KEY")
|
||||
}
|
||||
if encKey != "" {
|
||||
srv.SetRecordingEncryptionKey(encKey)
|
||||
log.Info("VNC recording encryption enabled")
|
||||
}
|
||||
}
|
||||
|
||||
func (e *Engine) defaultRecordingBase() string {
|
||||
if e.stateManager == nil {
|
||||
return ""
|
||||
}
|
||||
p := e.stateManager.FilePath()
|
||||
if p == "" {
|
||||
return ""
|
||||
}
|
||||
return filepath.Dir(p)
|
||||
}
|
||||
|
||||
func recordingSource(api bool) string {
|
||||
if api {
|
||||
return "management"
|
||||
}
|
||||
return "env"
|
||||
}
|
||||
|
||||
// updateVNCServerJWT configures the JWT validation for the VNC server using
|
||||
// the same JWT config as SSH (same identity provider).
|
||||
func (e *Engine) updateVNCServerJWT(sshConf *mgmProto.SSHConfig) {
|
||||
if e.vncSrv == nil {
|
||||
return
|
||||
}
|
||||
|
||||
vncSrv, ok := e.vncSrv.(*vncserver.Server)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
if e.config.DisableVNCAuth != nil && *e.config.DisableVNCAuth {
|
||||
vncSrv.SetDisableAuth(true)
|
||||
return
|
||||
}
|
||||
|
||||
protoJWT := sshConf.GetJwtConfig()
|
||||
if protoJWT == nil {
|
||||
return
|
||||
}
|
||||
|
||||
audiences := protoJWT.GetAudiences()
|
||||
if len(audiences) == 0 && protoJWT.GetAudience() != "" {
|
||||
audiences = []string{protoJWT.GetAudience()}
|
||||
}
|
||||
|
||||
vncSrv.SetJWTConfig(&vncserver.JWTConfig{
|
||||
Issuer: protoJWT.GetIssuer(),
|
||||
Audiences: audiences,
|
||||
KeysLocation: protoJWT.GetKeysLocation(),
|
||||
MaxTokenAge: protoJWT.GetMaxTokenAge(),
|
||||
})
|
||||
}
|
||||
|
||||
// updateVNCServerAuth updates VNC fine-grained access control from management.
|
||||
func (e *Engine) updateVNCServerAuth(vncAuth *mgmProto.VNCAuth) {
|
||||
if vncAuth == nil || e.vncSrv == nil {
|
||||
return
|
||||
}
|
||||
|
||||
vncSrv, ok := e.vncSrv.(*vncserver.Server)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
protoUsers := vncAuth.GetAuthorizedUsers()
|
||||
authorizedUsers := make([]sshuserhash.UserIDHash, len(protoUsers))
|
||||
for i, hash := range protoUsers {
|
||||
if len(hash) != 16 {
|
||||
log.Warnf("invalid VNC auth hash length %d, expected 16", len(hash))
|
||||
return
|
||||
}
|
||||
authorizedUsers[i] = sshuserhash.UserIDHash(hash)
|
||||
}
|
||||
|
||||
machineUsers := make(map[string][]uint32)
|
||||
for osUser, indexes := range vncAuth.GetMachineUsers() {
|
||||
machineUsers[osUser] = indexes.GetIndexes()
|
||||
}
|
||||
|
||||
vncSrv.UpdateVNCAuth(&sshauth.Config{
|
||||
UserIDClaim: vncAuth.GetUserIDClaim(),
|
||||
AuthorizedUsers: authorizedUsers,
|
||||
MachineUsers: machineUsers,
|
||||
})
|
||||
}
|
||||
|
||||
// GetVNCServerStatus returns whether the VNC server is running.
|
||||
func (e *Engine) GetVNCServerStatus() bool {
|
||||
return e.vncSrv != nil
|
||||
}
|
||||
|
||||
func (e *Engine) stopVNCServer() error {
|
||||
if e.vncSrv == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := e.cleanupVNCPortRedirection(); err != nil {
|
||||
log.Warnf("cleanup VNC port redirection: %v", err)
|
||||
}
|
||||
|
||||
if registrar, ok := e.firewall.(interface {
|
||||
UnregisterNetstackService(protocol nftypes.Protocol, port uint16)
|
||||
}); ok {
|
||||
registrar.UnregisterNetstackService(nftypes.TCP, vncInternalPort)
|
||||
}
|
||||
|
||||
log.Info("stopping VNC server")
|
||||
err := e.vncSrv.Stop()
|
||||
e.vncSrv = nil
|
||||
if err != nil {
|
||||
return fmt.Errorf("stop VNC server: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -1,23 +0,0 @@
|
||||
//go:build darwin && !ios
|
||||
|
||||
package internal
|
||||
|
||||
import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
vncserver "github.com/netbirdio/netbird/client/vnc/server"
|
||||
)
|
||||
|
||||
func newPlatformVNC() (vncserver.ScreenCapturer, vncserver.InputInjector) {
|
||||
capturer := vncserver.NewMacPoller()
|
||||
injector, err := vncserver.NewMacInputInjector()
|
||||
if err != nil {
|
||||
log.Debugf("VNC: macOS input injector: %v", err)
|
||||
return capturer, &vncserver.StubInputInjector{}
|
||||
}
|
||||
return capturer, injector
|
||||
}
|
||||
|
||||
func vncNeedsServiceMode() bool {
|
||||
return false
|
||||
}
|
||||
@@ -1,13 +0,0 @@
|
||||
//go:build !windows && !darwin && !freebsd && !(linux && !android)
|
||||
|
||||
package internal
|
||||
|
||||
import vncserver "github.com/netbirdio/netbird/client/vnc/server"
|
||||
|
||||
func newPlatformVNC() (vncserver.ScreenCapturer, vncserver.InputInjector) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func vncNeedsServiceMode() bool {
|
||||
return false
|
||||
}
|
||||
@@ -1,13 +0,0 @@
|
||||
//go:build windows
|
||||
|
||||
package internal
|
||||
|
||||
import vncserver "github.com/netbirdio/netbird/client/vnc/server"
|
||||
|
||||
func newPlatformVNC() (vncserver.ScreenCapturer, vncserver.InputInjector) {
|
||||
return vncserver.NewDesktopCapturer(), vncserver.NewWindowsInputInjector()
|
||||
}
|
||||
|
||||
func vncNeedsServiceMode() bool {
|
||||
return vncserver.GetCurrentSessionID() == 0
|
||||
}
|
||||
@@ -1,23 +0,0 @@
|
||||
//go:build (linux && !android) || freebsd
|
||||
|
||||
package internal
|
||||
|
||||
import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
vncserver "github.com/netbirdio/netbird/client/vnc/server"
|
||||
)
|
||||
|
||||
func newPlatformVNC() (vncserver.ScreenCapturer, vncserver.InputInjector) {
|
||||
capturer := vncserver.NewX11Poller("")
|
||||
injector, err := vncserver.NewX11InputInjector("")
|
||||
if err != nil {
|
||||
log.Debugf("VNC: X11 input injector: %v", err)
|
||||
return capturer, &vncserver.StubInputInjector{}
|
||||
}
|
||||
return capturer, injector
|
||||
}
|
||||
|
||||
func vncNeedsServiceMode() bool {
|
||||
return false
|
||||
}
|
||||
@@ -185,17 +185,20 @@ func (conn *Conn) Open(engineCtx context.Context) error {
|
||||
|
||||
conn.workerRelay = NewWorkerRelay(conn.ctx, conn.Log, isController(conn.config), conn.config, conn, conn.relayManager)
|
||||
|
||||
relayIsSupportedLocally := conn.workerRelay.RelayIsSupportedLocally()
|
||||
workerICE, err := NewWorkerICE(conn.ctx, conn.Log, conn.config, conn, conn.signaler, conn.iFaceDiscover, conn.statusRecorder, relayIsSupportedLocally)
|
||||
if err != nil {
|
||||
return err
|
||||
forceRelay := IsForceRelayed()
|
||||
if !forceRelay {
|
||||
relayIsSupportedLocally := conn.workerRelay.RelayIsSupportedLocally()
|
||||
workerICE, err := NewWorkerICE(conn.ctx, conn.Log, conn.config, conn, conn.signaler, conn.iFaceDiscover, conn.statusRecorder, relayIsSupportedLocally)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
conn.workerICE = workerICE
|
||||
}
|
||||
conn.workerICE = workerICE
|
||||
|
||||
conn.handshaker = NewHandshaker(conn.Log, conn.config, conn.signaler, conn.workerICE, conn.workerRelay, conn.metricsStages)
|
||||
|
||||
conn.handshaker.AddRelayListener(conn.workerRelay.OnNewOffer)
|
||||
if !isForceRelayed() {
|
||||
if !forceRelay {
|
||||
conn.handshaker.AddICEListener(conn.workerICE.OnNewOffer)
|
||||
}
|
||||
|
||||
@@ -251,7 +254,9 @@ func (conn *Conn) Close(signalToRemote bool) {
|
||||
conn.wgWatcherCancel()
|
||||
}
|
||||
conn.workerRelay.CloseConn()
|
||||
conn.workerICE.Close()
|
||||
if conn.workerICE != nil {
|
||||
conn.workerICE.Close()
|
||||
}
|
||||
|
||||
if conn.wgProxyRelay != nil {
|
||||
err := conn.wgProxyRelay.CloseConn()
|
||||
@@ -294,7 +299,9 @@ func (conn *Conn) OnRemoteAnswer(answer OfferAnswer) {
|
||||
// OnRemoteCandidate Handles ICE connection Candidate provided by the remote peer.
|
||||
func (conn *Conn) OnRemoteCandidate(candidate ice.Candidate, haRoutes route.HAMap) {
|
||||
conn.dumpState.RemoteCandidate()
|
||||
conn.workerICE.OnRemoteCandidate(candidate, haRoutes)
|
||||
if conn.workerICE != nil {
|
||||
conn.workerICE.OnRemoteCandidate(candidate, haRoutes)
|
||||
}
|
||||
}
|
||||
|
||||
// SetOnConnected sets a handler function to be triggered by Conn when a new connection to a remote peer established
|
||||
@@ -712,33 +719,35 @@ func (conn *Conn) evalStatus() ConnStatus {
|
||||
return StatusConnecting
|
||||
}
|
||||
|
||||
func (conn *Conn) isConnectedOnAllWay() (connected bool) {
|
||||
// would be better to protect this with a mutex, but it could cause deadlock with Close function
|
||||
|
||||
// isConnectedOnAllWay evaluates the overall connection status based on ICE and Relay transports.
|
||||
//
|
||||
// The result is a tri-state:
|
||||
// - ConnStatusConnected: all available transports are up
|
||||
// - ConnStatusPartiallyConnected: relay is up but ICE is still pending/reconnecting
|
||||
// - ConnStatusDisconnected: no working transport
|
||||
func (conn *Conn) isConnectedOnAllWay() (status guard.ConnStatus) {
|
||||
defer func() {
|
||||
if !connected {
|
||||
if status == guard.ConnStatusDisconnected {
|
||||
conn.logTraceConnState()
|
||||
}
|
||||
}()
|
||||
|
||||
// For JS platform: only relay connection is supported
|
||||
if runtime.GOOS == "js" {
|
||||
return conn.statusRelay.Get() == worker.StatusConnected
|
||||
iceWorkerCreated := conn.workerICE != nil
|
||||
|
||||
var iceInProgress bool
|
||||
if iceWorkerCreated {
|
||||
iceInProgress = conn.workerICE.InProgress()
|
||||
}
|
||||
|
||||
// For non-JS platforms: check ICE connection status
|
||||
if conn.statusICE.Get() == worker.StatusDisconnected && !conn.workerICE.InProgress() {
|
||||
return false
|
||||
}
|
||||
|
||||
// If relay is supported with peer, it must also be connected
|
||||
if conn.workerRelay.IsRelayConnectionSupportedWithPeer() {
|
||||
if conn.statusRelay.Get() == worker.StatusDisconnected {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
return evalConnStatus(connStatusInputs{
|
||||
forceRelay: IsForceRelayed(),
|
||||
peerUsesRelay: conn.workerRelay.IsRelayConnectionSupportedWithPeer(),
|
||||
relayConnected: conn.statusRelay.Get() == worker.StatusConnected,
|
||||
remoteSupportsICE: conn.handshaker.RemoteICESupported(),
|
||||
iceWorkerCreated: iceWorkerCreated,
|
||||
iceStatusConnecting: conn.statusICE.Get() != worker.StatusDisconnected,
|
||||
iceInProgress: iceInProgress,
|
||||
})
|
||||
}
|
||||
|
||||
func (conn *Conn) enableWgWatcherIfNeeded(enabledTime time.Time) {
|
||||
@@ -926,3 +935,43 @@ func isController(config ConnConfig) bool {
|
||||
func isRosenpassEnabled(remoteRosenpassPubKey []byte) bool {
|
||||
return remoteRosenpassPubKey != nil
|
||||
}
|
||||
|
||||
func evalConnStatus(in connStatusInputs) guard.ConnStatus {
|
||||
// "Relay up and needed" — the peer uses relay and the transport is connected.
|
||||
relayUsedAndUp := in.peerUsesRelay && in.relayConnected
|
||||
|
||||
// Force-relay mode: ICE never runs. Relay is the only transport and must be up.
|
||||
if in.forceRelay {
|
||||
return boolToConnStatus(relayUsedAndUp)
|
||||
}
|
||||
|
||||
// Remote peer doesn't support ICE, or we haven't created the worker yet:
|
||||
// relay is the only possible transport.
|
||||
if !in.remoteSupportsICE || !in.iceWorkerCreated {
|
||||
return boolToConnStatus(relayUsedAndUp)
|
||||
}
|
||||
|
||||
// ICE counts as "up" when the status is anything other than Disconnected, OR
|
||||
// when a negotiation is currently in progress (so we don't spam offers while one is in flight).
|
||||
iceUp := in.iceStatusConnecting || in.iceInProgress
|
||||
|
||||
// Relay side is acceptable if the peer doesn't rely on relay, or relay is connected.
|
||||
relayOK := !in.peerUsesRelay || in.relayConnected
|
||||
|
||||
switch {
|
||||
case iceUp && relayOK:
|
||||
return guard.ConnStatusConnected
|
||||
case relayUsedAndUp:
|
||||
// Relay is up but ICE is down — partially connected.
|
||||
return guard.ConnStatusPartiallyConnected
|
||||
default:
|
||||
return guard.ConnStatusDisconnected
|
||||
}
|
||||
}
|
||||
|
||||
func boolToConnStatus(connected bool) guard.ConnStatus {
|
||||
if connected {
|
||||
return guard.ConnStatusConnected
|
||||
}
|
||||
return guard.ConnStatusDisconnected
|
||||
}
|
||||
|
||||
@@ -13,6 +13,20 @@ const (
|
||||
StatusConnected
|
||||
)
|
||||
|
||||
// connStatusInputs is the primitive-valued snapshot of the state that drives the
|
||||
// tri-state connection classification. Extracted so the decision logic can be unit-tested
|
||||
// without constructing full Worker/Handshaker objects.
|
||||
type connStatusInputs struct {
|
||||
forceRelay bool // NB_FORCE_RELAY or JS/WASM
|
||||
peerUsesRelay bool // remote peer advertises relay support AND local has relay
|
||||
relayConnected bool // statusRelay reports Connected (independent of whether peer uses relay)
|
||||
remoteSupportsICE bool // remote peer sent ICE credentials
|
||||
iceWorkerCreated bool // local WorkerICE exists (false in force-relay mode)
|
||||
iceStatusConnecting bool // statusICE is anything other than Disconnected
|
||||
iceInProgress bool // a negotiation is currently in flight
|
||||
}
|
||||
|
||||
|
||||
// ConnStatus describe the status of a peer's connection
|
||||
type ConnStatus int32
|
||||
|
||||
|
||||
201
client/internal/peer/conn_status_eval_test.go
Normal file
201
client/internal/peer/conn_status_eval_test.go
Normal file
@@ -0,0 +1,201 @@
|
||||
package peer
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/peer/guard"
|
||||
)
|
||||
|
||||
func TestEvalConnStatus_ForceRelay(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
in connStatusInputs
|
||||
want guard.ConnStatus
|
||||
}{
|
||||
{
|
||||
name: "force relay, peer uses relay, relay up",
|
||||
in: connStatusInputs{
|
||||
forceRelay: true,
|
||||
peerUsesRelay: true,
|
||||
relayConnected: true,
|
||||
},
|
||||
want: guard.ConnStatusConnected,
|
||||
},
|
||||
{
|
||||
name: "force relay, peer uses relay, relay down",
|
||||
in: connStatusInputs{
|
||||
forceRelay: true,
|
||||
peerUsesRelay: true,
|
||||
relayConnected: false,
|
||||
},
|
||||
want: guard.ConnStatusDisconnected,
|
||||
},
|
||||
{
|
||||
name: "force relay, peer does NOT use relay - disconnected forever",
|
||||
in: connStatusInputs{
|
||||
forceRelay: true,
|
||||
peerUsesRelay: false,
|
||||
relayConnected: true,
|
||||
},
|
||||
want: guard.ConnStatusDisconnected,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
if got := evalConnStatus(tc.in); got != tc.want {
|
||||
t.Fatalf("evalConnStatus = %v, want %v", got, tc.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestEvalConnStatus_ICEUnavailable(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
in connStatusInputs
|
||||
want guard.ConnStatus
|
||||
}{
|
||||
{
|
||||
name: "remote does not support ICE, peer uses relay, relay up",
|
||||
in: connStatusInputs{
|
||||
peerUsesRelay: true,
|
||||
relayConnected: true,
|
||||
remoteSupportsICE: false,
|
||||
iceWorkerCreated: true,
|
||||
},
|
||||
want: guard.ConnStatusConnected,
|
||||
},
|
||||
{
|
||||
name: "remote does not support ICE, peer uses relay, relay down",
|
||||
in: connStatusInputs{
|
||||
peerUsesRelay: true,
|
||||
relayConnected: false,
|
||||
remoteSupportsICE: false,
|
||||
iceWorkerCreated: true,
|
||||
},
|
||||
want: guard.ConnStatusDisconnected,
|
||||
},
|
||||
{
|
||||
name: "ICE worker not yet created, relay up",
|
||||
in: connStatusInputs{
|
||||
peerUsesRelay: true,
|
||||
relayConnected: true,
|
||||
remoteSupportsICE: true,
|
||||
iceWorkerCreated: false,
|
||||
},
|
||||
want: guard.ConnStatusConnected,
|
||||
},
|
||||
{
|
||||
name: "remote does not support ICE, peer does not use relay",
|
||||
in: connStatusInputs{
|
||||
peerUsesRelay: false,
|
||||
relayConnected: false,
|
||||
remoteSupportsICE: false,
|
||||
iceWorkerCreated: true,
|
||||
},
|
||||
want: guard.ConnStatusDisconnected,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
if got := evalConnStatus(tc.in); got != tc.want {
|
||||
t.Fatalf("evalConnStatus = %v, want %v", got, tc.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestEvalConnStatus_FullyAvailable(t *testing.T) {
|
||||
base := connStatusInputs{
|
||||
remoteSupportsICE: true,
|
||||
iceWorkerCreated: true,
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
mutator func(*connStatusInputs)
|
||||
want guard.ConnStatus
|
||||
}{
|
||||
{
|
||||
name: "ICE connected, relay connected, peer uses relay",
|
||||
mutator: func(in *connStatusInputs) {
|
||||
in.peerUsesRelay = true
|
||||
in.relayConnected = true
|
||||
in.iceStatusConnecting = true
|
||||
},
|
||||
want: guard.ConnStatusConnected,
|
||||
},
|
||||
{
|
||||
name: "ICE connected, peer does NOT use relay",
|
||||
mutator: func(in *connStatusInputs) {
|
||||
in.peerUsesRelay = false
|
||||
in.relayConnected = false
|
||||
in.iceStatusConnecting = true
|
||||
},
|
||||
want: guard.ConnStatusConnected,
|
||||
},
|
||||
{
|
||||
name: "ICE InProgress only, peer does NOT use relay",
|
||||
mutator: func(in *connStatusInputs) {
|
||||
in.peerUsesRelay = false
|
||||
in.iceStatusConnecting = false
|
||||
in.iceInProgress = true
|
||||
},
|
||||
want: guard.ConnStatusConnected,
|
||||
},
|
||||
{
|
||||
name: "ICE down, relay up, peer uses relay -> partial",
|
||||
mutator: func(in *connStatusInputs) {
|
||||
in.peerUsesRelay = true
|
||||
in.relayConnected = true
|
||||
in.iceStatusConnecting = false
|
||||
in.iceInProgress = false
|
||||
},
|
||||
want: guard.ConnStatusPartiallyConnected,
|
||||
},
|
||||
{
|
||||
name: "ICE down, peer does NOT use relay -> disconnected",
|
||||
mutator: func(in *connStatusInputs) {
|
||||
in.peerUsesRelay = false
|
||||
in.relayConnected = false
|
||||
in.iceStatusConnecting = false
|
||||
in.iceInProgress = false
|
||||
},
|
||||
want: guard.ConnStatusDisconnected,
|
||||
},
|
||||
{
|
||||
name: "ICE up, peer uses relay but relay down -> partial (relay required, ICE ignored)",
|
||||
mutator: func(in *connStatusInputs) {
|
||||
in.peerUsesRelay = true
|
||||
in.relayConnected = false
|
||||
in.iceStatusConnecting = true
|
||||
},
|
||||
// relayOK = false (peer uses relay but it's down), iceUp = true
|
||||
// first switch arm fails (relayOK false), relayUsedAndUp = false (relay down),
|
||||
// falls into default: Disconnected.
|
||||
want: guard.ConnStatusDisconnected,
|
||||
},
|
||||
{
|
||||
name: "ICE down, relay up but peer does not use relay -> disconnected",
|
||||
mutator: func(in *connStatusInputs) {
|
||||
in.peerUsesRelay = false
|
||||
in.relayConnected = true // not actually used since peer doesn't rely on it
|
||||
in.iceStatusConnecting = false
|
||||
in.iceInProgress = false
|
||||
},
|
||||
want: guard.ConnStatusDisconnected,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
in := base
|
||||
tc.mutator(&in)
|
||||
if got := evalConnStatus(in); got != tc.want {
|
||||
t.Fatalf("evalConnStatus = %v, want %v (inputs: %+v)", got, tc.want, in)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -10,7 +10,7 @@ const (
|
||||
EnvKeyNBForceRelay = "NB_FORCE_RELAY"
|
||||
)
|
||||
|
||||
func isForceRelayed() bool {
|
||||
func IsForceRelayed() bool {
|
||||
if runtime.GOOS == "js" {
|
||||
return true
|
||||
}
|
||||
|
||||
@@ -8,7 +8,19 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
type isConnectedFunc func() bool
|
||||
// ConnStatus represents the connection state as seen by the guard.
|
||||
type ConnStatus int
|
||||
|
||||
const (
|
||||
// ConnStatusDisconnected means neither ICE nor Relay is connected.
|
||||
ConnStatusDisconnected ConnStatus = iota
|
||||
// ConnStatusPartiallyConnected means Relay is connected but ICE is not.
|
||||
ConnStatusPartiallyConnected
|
||||
// ConnStatusConnected means all required connections are established.
|
||||
ConnStatusConnected
|
||||
)
|
||||
|
||||
type connStatusFunc func() ConnStatus
|
||||
|
||||
// Guard is responsible for the reconnection logic.
|
||||
// It will trigger to send an offer to the peer then has connection issues.
|
||||
@@ -20,14 +32,14 @@ type isConnectedFunc func() bool
|
||||
// - ICE candidate changes
|
||||
type Guard struct {
|
||||
log *log.Entry
|
||||
isConnectedOnAllWay isConnectedFunc
|
||||
isConnectedOnAllWay connStatusFunc
|
||||
timeout time.Duration
|
||||
srWatcher *SRWatcher
|
||||
relayedConnDisconnected chan struct{}
|
||||
iCEConnDisconnected chan struct{}
|
||||
}
|
||||
|
||||
func NewGuard(log *log.Entry, isConnectedFn isConnectedFunc, timeout time.Duration, srWatcher *SRWatcher) *Guard {
|
||||
func NewGuard(log *log.Entry, isConnectedFn connStatusFunc, timeout time.Duration, srWatcher *SRWatcher) *Guard {
|
||||
return &Guard{
|
||||
log: log,
|
||||
isConnectedOnAllWay: isConnectedFn,
|
||||
@@ -57,8 +69,17 @@ func (g *Guard) SetICEConnDisconnected() {
|
||||
}
|
||||
}
|
||||
|
||||
// reconnectLoopWithRetry periodically check the connection status.
|
||||
// Try to send offer while the P2P is not established or while the Relay is not connected if is it supported
|
||||
// reconnectLoopWithRetry periodically checks the connection status and sends offers to re-establish connectivity.
|
||||
//
|
||||
// Behavior depends on the connection state reported by isConnectedOnAllWay:
|
||||
// - Connected: no action, the peer is fully reachable.
|
||||
// - Disconnected (neither ICE nor Relay): retries aggressively with exponential backoff (800ms doubling
|
||||
// up to timeout), never gives up. This ensures rapid recovery when the peer has no connectivity at all.
|
||||
// - PartiallyConnected (Relay up, ICE not): retries up to 3 times with exponential backoff, then switches
|
||||
// to one attempt per hour. This limits signaling traffic when relay already provides connectivity.
|
||||
//
|
||||
// External events (relay/ICE disconnect, signal/relay reconnect, candidate changes) reset the retry
|
||||
// counter and backoff ticker, giving ICE a fresh chance after network conditions change.
|
||||
func (g *Guard) reconnectLoopWithRetry(ctx context.Context, callback func()) {
|
||||
srReconnectedChan := g.srWatcher.NewListener()
|
||||
defer g.srWatcher.RemoveListener(srReconnectedChan)
|
||||
@@ -68,36 +89,47 @@ func (g *Guard) reconnectLoopWithRetry(ctx context.Context, callback func()) {
|
||||
|
||||
tickerChannel := ticker.C
|
||||
|
||||
iceState := &iceRetryState{log: g.log}
|
||||
defer iceState.reset()
|
||||
|
||||
for {
|
||||
select {
|
||||
case t := <-tickerChannel:
|
||||
if t.IsZero() {
|
||||
g.log.Infof("retry timed out, stop periodic offer sending")
|
||||
// after backoff timeout the ticker.C will be closed. We need to a dummy channel to avoid loop
|
||||
tickerChannel = make(<-chan time.Time)
|
||||
continue
|
||||
case <-tickerChannel:
|
||||
switch g.isConnectedOnAllWay() {
|
||||
case ConnStatusConnected:
|
||||
// all good, nothing to do
|
||||
case ConnStatusDisconnected:
|
||||
callback()
|
||||
case ConnStatusPartiallyConnected:
|
||||
if iceState.shouldRetry() {
|
||||
callback()
|
||||
} else {
|
||||
iceState.enterHourlyMode()
|
||||
ticker.Stop()
|
||||
tickerChannel = iceState.hourlyC()
|
||||
}
|
||||
}
|
||||
|
||||
if !g.isConnectedOnAllWay() {
|
||||
callback()
|
||||
}
|
||||
case <-g.relayedConnDisconnected:
|
||||
g.log.Debugf("Relay connection changed, reset reconnection ticker")
|
||||
ticker.Stop()
|
||||
ticker = g.prepareExponentTicker(ctx)
|
||||
ticker = g.newReconnectTicker(ctx)
|
||||
tickerChannel = ticker.C
|
||||
iceState.reset()
|
||||
|
||||
case <-g.iCEConnDisconnected:
|
||||
g.log.Debugf("ICE connection changed, reset reconnection ticker")
|
||||
ticker.Stop()
|
||||
ticker = g.prepareExponentTicker(ctx)
|
||||
ticker = g.newReconnectTicker(ctx)
|
||||
tickerChannel = ticker.C
|
||||
iceState.reset()
|
||||
|
||||
case <-srReconnectedChan:
|
||||
g.log.Debugf("has network changes, reset reconnection ticker")
|
||||
ticker.Stop()
|
||||
ticker = g.prepareExponentTicker(ctx)
|
||||
ticker = g.newReconnectTicker(ctx)
|
||||
tickerChannel = ticker.C
|
||||
iceState.reset()
|
||||
|
||||
case <-ctx.Done():
|
||||
g.log.Debugf("context is done, stop reconnect loop")
|
||||
@@ -120,7 +152,7 @@ func (g *Guard) initialTicker(ctx context.Context) *backoff.Ticker {
|
||||
return backoff.NewTicker(bo)
|
||||
}
|
||||
|
||||
func (g *Guard) prepareExponentTicker(ctx context.Context) *backoff.Ticker {
|
||||
func (g *Guard) newReconnectTicker(ctx context.Context) *backoff.Ticker {
|
||||
bo := backoff.WithContext(&backoff.ExponentialBackOff{
|
||||
InitialInterval: 800 * time.Millisecond,
|
||||
RandomizationFactor: 0.1,
|
||||
|
||||
61
client/internal/peer/guard/ice_retry_state.go
Normal file
61
client/internal/peer/guard/ice_retry_state.go
Normal file
@@ -0,0 +1,61 @@
|
||||
package guard
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
// maxICERetries is the maximum number of ICE offer attempts when relay is connected
|
||||
maxICERetries = 3
|
||||
// iceRetryInterval is the periodic retry interval after ICE retries are exhausted
|
||||
iceRetryInterval = 1 * time.Hour
|
||||
)
|
||||
|
||||
// iceRetryState tracks the limited ICE retry attempts when relay is already connected.
|
||||
// After maxICERetries attempts it switches to a periodic hourly retry.
|
||||
type iceRetryState struct {
|
||||
log *log.Entry
|
||||
retries int
|
||||
hourly *time.Ticker
|
||||
}
|
||||
|
||||
func (s *iceRetryState) reset() {
|
||||
s.retries = 0
|
||||
if s.hourly != nil {
|
||||
s.hourly.Stop()
|
||||
s.hourly = nil
|
||||
}
|
||||
}
|
||||
|
||||
// shouldRetry reports whether the caller should send another ICE offer on this tick.
|
||||
// Returns false when the per-cycle retry budget is exhausted and the caller must switch
|
||||
// to the hourly ticker via enterHourlyMode + hourlyC.
|
||||
func (s *iceRetryState) shouldRetry() bool {
|
||||
if s.hourly != nil {
|
||||
s.log.Debugf("hourly ICE retry attempt")
|
||||
return true
|
||||
}
|
||||
|
||||
s.retries++
|
||||
if s.retries <= maxICERetries {
|
||||
s.log.Debugf("ICE retry attempt %d/%d", s.retries, maxICERetries)
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// enterHourlyMode starts the hourly retry ticker. Must be called after shouldRetry returns false.
|
||||
func (s *iceRetryState) enterHourlyMode() {
|
||||
s.log.Infof("ICE retries exhausted (%d/%d), switching to hourly retry", maxICERetries, maxICERetries)
|
||||
s.hourly = time.NewTicker(iceRetryInterval)
|
||||
}
|
||||
|
||||
func (s *iceRetryState) hourlyC() <-chan time.Time {
|
||||
if s.hourly == nil {
|
||||
return nil
|
||||
}
|
||||
return s.hourly.C
|
||||
}
|
||||
103
client/internal/peer/guard/ice_retry_state_test.go
Normal file
103
client/internal/peer/guard/ice_retry_state_test.go
Normal file
@@ -0,0 +1,103 @@
|
||||
package guard
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
func newTestRetryState() *iceRetryState {
|
||||
return &iceRetryState{log: log.NewEntry(log.StandardLogger())}
|
||||
}
|
||||
|
||||
func TestICERetryState_AllowsInitialBudget(t *testing.T) {
|
||||
s := newTestRetryState()
|
||||
|
||||
for i := 1; i <= maxICERetries; i++ {
|
||||
if !s.shouldRetry() {
|
||||
t.Fatalf("shouldRetry returned false on attempt %d, want true (budget = %d)", i, maxICERetries)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestICERetryState_ExhaustsAfterBudget(t *testing.T) {
|
||||
s := newTestRetryState()
|
||||
|
||||
for i := 0; i < maxICERetries; i++ {
|
||||
_ = s.shouldRetry()
|
||||
}
|
||||
|
||||
if s.shouldRetry() {
|
||||
t.Fatalf("shouldRetry returned true after budget exhausted, want false")
|
||||
}
|
||||
}
|
||||
|
||||
func TestICERetryState_HourlyCNilBeforeEnterHourlyMode(t *testing.T) {
|
||||
s := newTestRetryState()
|
||||
|
||||
if s.hourlyC() != nil {
|
||||
t.Fatalf("hourlyC returned non-nil channel before enterHourlyMode")
|
||||
}
|
||||
}
|
||||
|
||||
func TestICERetryState_EnterHourlyModeArmsTicker(t *testing.T) {
|
||||
s := newTestRetryState()
|
||||
for i := 0; i < maxICERetries+1; i++ {
|
||||
_ = s.shouldRetry()
|
||||
}
|
||||
|
||||
s.enterHourlyMode()
|
||||
defer s.reset()
|
||||
|
||||
if s.hourlyC() == nil {
|
||||
t.Fatalf("hourlyC returned nil after enterHourlyMode")
|
||||
}
|
||||
}
|
||||
|
||||
func TestICERetryState_ShouldRetryTrueInHourlyMode(t *testing.T) {
|
||||
s := newTestRetryState()
|
||||
s.enterHourlyMode()
|
||||
defer s.reset()
|
||||
|
||||
if !s.shouldRetry() {
|
||||
t.Fatalf("shouldRetry returned false in hourly mode, want true")
|
||||
}
|
||||
|
||||
// Subsequent calls also return true — we keep retrying on each hourly tick.
|
||||
if !s.shouldRetry() {
|
||||
t.Fatalf("second shouldRetry returned false in hourly mode, want true")
|
||||
}
|
||||
}
|
||||
|
||||
func TestICERetryState_ResetRestoresBudget(t *testing.T) {
|
||||
s := newTestRetryState()
|
||||
for i := 0; i < maxICERetries+1; i++ {
|
||||
_ = s.shouldRetry()
|
||||
}
|
||||
s.enterHourlyMode()
|
||||
|
||||
s.reset()
|
||||
|
||||
if s.hourlyC() != nil {
|
||||
t.Fatalf("hourlyC returned non-nil channel after reset")
|
||||
}
|
||||
if s.retries != 0 {
|
||||
t.Fatalf("retries = %d after reset, want 0", s.retries)
|
||||
}
|
||||
|
||||
for i := 1; i <= maxICERetries; i++ {
|
||||
if !s.shouldRetry() {
|
||||
t.Fatalf("shouldRetry returned false on attempt %d after reset, want true", i)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestICERetryState_ResetIsIdempotent(t *testing.T) {
|
||||
s := newTestRetryState()
|
||||
s.reset()
|
||||
s.reset() // second call must not panic or re-stop a nil ticker
|
||||
|
||||
if s.hourlyC() != nil {
|
||||
t.Fatalf("hourlyC non-nil after double reset")
|
||||
}
|
||||
}
|
||||
@@ -39,7 +39,7 @@ func NewSRWatcher(signalClient chNotifier, relayManager chNotifier, iFaceDiscove
|
||||
return srw
|
||||
}
|
||||
|
||||
func (w *SRWatcher) Start() {
|
||||
func (w *SRWatcher) Start(disableICEMonitor bool) {
|
||||
w.mu.Lock()
|
||||
defer w.mu.Unlock()
|
||||
|
||||
@@ -50,8 +50,10 @@ func (w *SRWatcher) Start() {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
w.cancelIceMonitor = cancel
|
||||
|
||||
iceMonitor := NewICEMonitor(w.iFaceDiscover, w.iceConfig, GetICEMonitorPeriod())
|
||||
go iceMonitor.Start(ctx, w.onICEChanged)
|
||||
if !disableICEMonitor {
|
||||
iceMonitor := NewICEMonitor(w.iFaceDiscover, w.iceConfig, GetICEMonitorPeriod())
|
||||
go iceMonitor.Start(ctx, w.onICEChanged)
|
||||
}
|
||||
w.signalClient.SetOnReconnectedListener(w.onReconnected)
|
||||
w.relayManager.SetOnReconnectedListener(w.onReconnected)
|
||||
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
@@ -43,6 +44,10 @@ type OfferAnswer struct {
|
||||
SessionID *ICESessionID
|
||||
}
|
||||
|
||||
func (o *OfferAnswer) hasICECredentials() bool {
|
||||
return o.IceCredentials.UFrag != "" && o.IceCredentials.Pwd != ""
|
||||
}
|
||||
|
||||
type Handshaker struct {
|
||||
mu sync.Mutex
|
||||
log *log.Entry
|
||||
@@ -59,6 +64,10 @@ type Handshaker struct {
|
||||
relayListener *AsyncOfferListener
|
||||
iceListener func(remoteOfferAnswer *OfferAnswer)
|
||||
|
||||
// remoteICESupported tracks whether the remote peer includes ICE credentials in its offers/answers.
|
||||
// When false, the local side skips ICE listener dispatch and suppresses ICE credentials in responses.
|
||||
remoteICESupported atomic.Bool
|
||||
|
||||
// remoteOffersCh is a channel used to wait for remote credentials to proceed with the connection
|
||||
remoteOffersCh chan OfferAnswer
|
||||
// remoteAnswerCh is a channel used to wait for remote credentials answer (confirmation of our offer) to proceed with the connection
|
||||
@@ -66,7 +75,7 @@ type Handshaker struct {
|
||||
}
|
||||
|
||||
func NewHandshaker(log *log.Entry, config ConnConfig, signaler *Signaler, ice *WorkerICE, relay *WorkerRelay, metricsStages *MetricsStages) *Handshaker {
|
||||
return &Handshaker{
|
||||
h := &Handshaker{
|
||||
log: log,
|
||||
config: config,
|
||||
signaler: signaler,
|
||||
@@ -76,6 +85,13 @@ func NewHandshaker(log *log.Entry, config ConnConfig, signaler *Signaler, ice *W
|
||||
remoteOffersCh: make(chan OfferAnswer),
|
||||
remoteAnswerCh: make(chan OfferAnswer),
|
||||
}
|
||||
// assume remote supports ICE until we learn otherwise from received offers
|
||||
h.remoteICESupported.Store(ice != nil)
|
||||
return h
|
||||
}
|
||||
|
||||
func (h *Handshaker) RemoteICESupported() bool {
|
||||
return h.remoteICESupported.Load()
|
||||
}
|
||||
|
||||
func (h *Handshaker) AddRelayListener(offer func(remoteOfferAnswer *OfferAnswer)) {
|
||||
@@ -90,18 +106,20 @@ func (h *Handshaker) Listen(ctx context.Context) {
|
||||
for {
|
||||
select {
|
||||
case remoteOfferAnswer := <-h.remoteOffersCh:
|
||||
h.log.Infof("received offer, running version %s, remote WireGuard listen port %d, session id: %s", remoteOfferAnswer.Version, remoteOfferAnswer.WgListenPort, remoteOfferAnswer.SessionIDString())
|
||||
h.log.Infof("received offer, running version %s, remote WireGuard listen port %d, session id: %s, remote ICE supported: %t", remoteOfferAnswer.Version, remoteOfferAnswer.WgListenPort, remoteOfferAnswer.SessionIDString(), remoteOfferAnswer.hasICECredentials())
|
||||
|
||||
// Record signaling received for reconnection attempts
|
||||
if h.metricsStages != nil {
|
||||
h.metricsStages.RecordSignalingReceived()
|
||||
}
|
||||
|
||||
h.updateRemoteICEState(&remoteOfferAnswer)
|
||||
|
||||
if h.relayListener != nil {
|
||||
h.relayListener.Notify(&remoteOfferAnswer)
|
||||
}
|
||||
|
||||
if h.iceListener != nil {
|
||||
if h.iceListener != nil && h.RemoteICESupported() {
|
||||
h.iceListener(&remoteOfferAnswer)
|
||||
}
|
||||
|
||||
@@ -110,18 +128,20 @@ func (h *Handshaker) Listen(ctx context.Context) {
|
||||
continue
|
||||
}
|
||||
case remoteOfferAnswer := <-h.remoteAnswerCh:
|
||||
h.log.Infof("received answer, running version %s, remote WireGuard listen port %d, session id: %s", remoteOfferAnswer.Version, remoteOfferAnswer.WgListenPort, remoteOfferAnswer.SessionIDString())
|
||||
h.log.Infof("received answer, running version %s, remote WireGuard listen port %d, session id: %s, remote ICE supported: %t", remoteOfferAnswer.Version, remoteOfferAnswer.WgListenPort, remoteOfferAnswer.SessionIDString(), remoteOfferAnswer.hasICECredentials())
|
||||
|
||||
// Record signaling received for reconnection attempts
|
||||
if h.metricsStages != nil {
|
||||
h.metricsStages.RecordSignalingReceived()
|
||||
}
|
||||
|
||||
h.updateRemoteICEState(&remoteOfferAnswer)
|
||||
|
||||
if h.relayListener != nil {
|
||||
h.relayListener.Notify(&remoteOfferAnswer)
|
||||
}
|
||||
|
||||
if h.iceListener != nil {
|
||||
if h.iceListener != nil && h.RemoteICESupported() {
|
||||
h.iceListener(&remoteOfferAnswer)
|
||||
}
|
||||
case <-ctx.Done():
|
||||
@@ -183,15 +203,18 @@ func (h *Handshaker) sendAnswer() error {
|
||||
}
|
||||
|
||||
func (h *Handshaker) buildOfferAnswer() OfferAnswer {
|
||||
uFrag, pwd := h.ice.GetLocalUserCredentials()
|
||||
sid := h.ice.SessionID()
|
||||
answer := OfferAnswer{
|
||||
IceCredentials: IceCredentials{uFrag, pwd},
|
||||
WgListenPort: h.config.LocalWgPort,
|
||||
Version: version.NetbirdVersion(),
|
||||
RosenpassPubKey: h.config.RosenpassConfig.PubKey,
|
||||
RosenpassAddr: h.config.RosenpassConfig.Addr,
|
||||
SessionID: &sid,
|
||||
}
|
||||
|
||||
if h.ice != nil && h.RemoteICESupported() {
|
||||
uFrag, pwd := h.ice.GetLocalUserCredentials()
|
||||
sid := h.ice.SessionID()
|
||||
answer.IceCredentials = IceCredentials{uFrag, pwd}
|
||||
answer.SessionID = &sid
|
||||
}
|
||||
|
||||
if addr, err := h.relay.RelayInstanceAddress(); err == nil {
|
||||
@@ -200,3 +223,18 @@ func (h *Handshaker) buildOfferAnswer() OfferAnswer {
|
||||
|
||||
return answer
|
||||
}
|
||||
|
||||
func (h *Handshaker) updateRemoteICEState(offer *OfferAnswer) {
|
||||
hasICE := offer.hasICECredentials()
|
||||
prev := h.remoteICESupported.Swap(hasICE)
|
||||
if prev != hasICE {
|
||||
if hasICE {
|
||||
h.log.Infof("remote peer started sending ICE credentials")
|
||||
} else {
|
||||
h.log.Infof("remote peer stopped sending ICE credentials")
|
||||
if h.ice != nil {
|
||||
h.ice.Close()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -46,9 +46,13 @@ func (s *Signaler) Ready() bool {
|
||||
|
||||
// SignalOfferAnswer signals either an offer or an answer to remote peer
|
||||
func (s *Signaler) signalOfferAnswer(offerAnswer OfferAnswer, remoteKey string, bodyType sProto.Body_Type) error {
|
||||
sessionIDBytes, err := offerAnswer.SessionID.Bytes()
|
||||
if err != nil {
|
||||
log.Warnf("failed to get session ID bytes: %v", err)
|
||||
var sessionIDBytes []byte
|
||||
if offerAnswer.SessionID != nil {
|
||||
var err error
|
||||
sessionIDBytes, err = offerAnswer.SessionID.Bytes()
|
||||
if err != nil {
|
||||
log.Warnf("failed to get session ID bytes: %v", err)
|
||||
}
|
||||
}
|
||||
msg, err := signal.MarshalCredential(
|
||||
s.wgPrivateKey,
|
||||
|
||||
@@ -64,13 +64,11 @@ type ConfigInput struct {
|
||||
StateFilePath string
|
||||
PreSharedKey *string
|
||||
ServerSSHAllowed *bool
|
||||
ServerVNCAllowed *bool
|
||||
EnableSSHRoot *bool
|
||||
EnableSSHSFTP *bool
|
||||
EnableSSHLocalPortForwarding *bool
|
||||
EnableSSHRemotePortForwarding *bool
|
||||
DisableSSHAuth *bool
|
||||
DisableVNCAuth *bool
|
||||
SSHJWTCacheTTL *int
|
||||
NATExternalIPs []string
|
||||
CustomDNSAddress []byte
|
||||
@@ -116,13 +114,11 @@ type Config struct {
|
||||
RosenpassEnabled bool
|
||||
RosenpassPermissive bool
|
||||
ServerSSHAllowed *bool
|
||||
ServerVNCAllowed *bool
|
||||
EnableSSHRoot *bool
|
||||
EnableSSHSFTP *bool
|
||||
EnableSSHLocalPortForwarding *bool
|
||||
EnableSSHRemotePortForwarding *bool
|
||||
DisableSSHAuth *bool
|
||||
DisableVNCAuth *bool
|
||||
SSHJWTCacheTTL *int
|
||||
|
||||
DisableClientRoutes bool
|
||||
@@ -419,21 +415,6 @@ func (config *Config) apply(input ConfigInput) (updated bool, err error) {
|
||||
updated = true
|
||||
}
|
||||
|
||||
if input.ServerVNCAllowed != nil {
|
||||
if config.ServerVNCAllowed == nil || *input.ServerVNCAllowed != *config.ServerVNCAllowed {
|
||||
if *input.ServerVNCAllowed {
|
||||
log.Infof("enabling VNC server")
|
||||
} else {
|
||||
log.Infof("disabling VNC server")
|
||||
}
|
||||
config.ServerVNCAllowed = input.ServerVNCAllowed
|
||||
updated = true
|
||||
}
|
||||
} else if config.ServerVNCAllowed == nil {
|
||||
config.ServerVNCAllowed = util.True()
|
||||
updated = true
|
||||
}
|
||||
|
||||
if input.EnableSSHRoot != nil && input.EnableSSHRoot != config.EnableSSHRoot {
|
||||
if *input.EnableSSHRoot {
|
||||
log.Infof("enabling SSH root login")
|
||||
@@ -484,16 +465,6 @@ func (config *Config) apply(input ConfigInput) (updated bool, err error) {
|
||||
updated = true
|
||||
}
|
||||
|
||||
if input.DisableVNCAuth != nil && input.DisableVNCAuth != config.DisableVNCAuth {
|
||||
if *input.DisableVNCAuth {
|
||||
log.Infof("disabling VNC authentication")
|
||||
} else {
|
||||
log.Infof("enabling VNC authentication")
|
||||
}
|
||||
config.DisableVNCAuth = input.DisableVNCAuth
|
||||
updated = true
|
||||
}
|
||||
|
||||
if input.SSHJWTCacheTTL != nil && input.SSHJWTCacheTTL != config.SSHJWTCacheTTL {
|
||||
log.Infof("updating SSH JWT cache TTL to %d seconds", *input.SSHJWTCacheTTL)
|
||||
config.SSHJWTCacheTTL = input.SSHJWTCacheTTL
|
||||
|
||||
@@ -74,14 +74,6 @@ func New(filePath string) *Manager {
|
||||
}
|
||||
}
|
||||
|
||||
// FilePath returns the path of the underlying state file.
|
||||
func (m *Manager) FilePath() string {
|
||||
if m == nil {
|
||||
return ""
|
||||
}
|
||||
return m.filePath
|
||||
}
|
||||
|
||||
// Start starts the state manager periodic save routine
|
||||
func (m *Manager) Start() {
|
||||
if m == nil {
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -209,9 +209,6 @@ message LoginRequest {
|
||||
optional bool enableSSHRemotePortForwarding = 37;
|
||||
optional bool disableSSHAuth = 38;
|
||||
optional int32 sshJWTCacheTTL = 39;
|
||||
|
||||
optional bool serverVNCAllowed = 41;
|
||||
optional bool disableVNCAuth = 42;
|
||||
}
|
||||
|
||||
message LoginResponse {
|
||||
@@ -319,10 +316,6 @@ message GetConfigResponse {
|
||||
bool disableSSHAuth = 25;
|
||||
|
||||
int32 sshJWTCacheTTL = 26;
|
||||
|
||||
bool serverVNCAllowed = 28;
|
||||
|
||||
bool disableVNCAuth = 29;
|
||||
}
|
||||
|
||||
// PeerState contains the latest state of a peer
|
||||
@@ -401,11 +394,6 @@ message SSHServerState {
|
||||
repeated SSHSessionInfo sessions = 2;
|
||||
}
|
||||
|
||||
// VNCServerState contains the latest state of the VNC server
|
||||
message VNCServerState {
|
||||
bool enabled = 1;
|
||||
}
|
||||
|
||||
// FullStatus contains the full state held by the Status instance
|
||||
message FullStatus {
|
||||
ManagementState managementState = 1;
|
||||
@@ -420,7 +408,6 @@ message FullStatus {
|
||||
|
||||
bool lazyConnectionEnabled = 9;
|
||||
SSHServerState sshServerState = 10;
|
||||
VNCServerState vncServerState = 11;
|
||||
}
|
||||
|
||||
// Networks
|
||||
@@ -690,9 +677,6 @@ message SetConfigRequest {
|
||||
optional bool enableSSHRemotePortForwarding = 32;
|
||||
optional bool disableSSHAuth = 33;
|
||||
optional int32 sshJWTCacheTTL = 34;
|
||||
|
||||
optional bool serverVNCAllowed = 36;
|
||||
optional bool disableVNCAuth = 37;
|
||||
}
|
||||
|
||||
message SetConfigResponse{}
|
||||
|
||||
@@ -369,7 +369,6 @@ func (s *Server) SetConfig(callerCtx context.Context, msg *proto.SetConfigReques
|
||||
config.RosenpassPermissive = msg.RosenpassPermissive
|
||||
config.DisableAutoConnect = msg.DisableAutoConnect
|
||||
config.ServerSSHAllowed = msg.ServerSSHAllowed
|
||||
config.ServerVNCAllowed = msg.ServerVNCAllowed
|
||||
config.NetworkMonitor = msg.NetworkMonitor
|
||||
config.DisableClientRoutes = msg.DisableClientRoutes
|
||||
config.DisableServerRoutes = msg.DisableServerRoutes
|
||||
@@ -386,9 +385,6 @@ func (s *Server) SetConfig(callerCtx context.Context, msg *proto.SetConfigReques
|
||||
if msg.DisableSSHAuth != nil {
|
||||
config.DisableSSHAuth = msg.DisableSSHAuth
|
||||
}
|
||||
if msg.DisableVNCAuth != nil {
|
||||
config.DisableVNCAuth = msg.DisableVNCAuth
|
||||
}
|
||||
if msg.SshJWTCacheTTL != nil {
|
||||
ttl := int(*msg.SshJWTCacheTTL)
|
||||
config.SSHJWTCacheTTL = &ttl
|
||||
@@ -1127,7 +1123,6 @@ func (s *Server) Status(
|
||||
pbFullStatus := fullStatus.ToProto()
|
||||
pbFullStatus.Events = s.statusRecorder.GetEventHistory()
|
||||
pbFullStatus.SshServerState = s.getSSHServerState()
|
||||
pbFullStatus.VncServerState = s.getVNCServerState()
|
||||
statusResponse.FullStatus = pbFullStatus
|
||||
}
|
||||
|
||||
@@ -1167,26 +1162,6 @@ func (s *Server) getSSHServerState() *proto.SSHServerState {
|
||||
return sshServerState
|
||||
}
|
||||
|
||||
// getVNCServerState retrieves the current VNC server state.
|
||||
func (s *Server) getVNCServerState() *proto.VNCServerState {
|
||||
s.mutex.Lock()
|
||||
connectClient := s.connectClient
|
||||
s.mutex.Unlock()
|
||||
|
||||
if connectClient == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
engine := connectClient.Engine()
|
||||
if engine == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return &proto.VNCServerState{
|
||||
Enabled: engine.GetVNCServerStatus(),
|
||||
}
|
||||
}
|
||||
|
||||
// GetPeerSSHHostKey retrieves SSH host key for a specific peer
|
||||
func (s *Server) GetPeerSSHHostKey(
|
||||
ctx context.Context,
|
||||
@@ -1528,11 +1503,6 @@ func (s *Server) GetConfig(ctx context.Context, req *proto.GetConfigRequest) (*p
|
||||
disableSSHAuth = *cfg.DisableSSHAuth
|
||||
}
|
||||
|
||||
disableVNCAuth := false
|
||||
if cfg.DisableVNCAuth != nil {
|
||||
disableVNCAuth = *cfg.DisableVNCAuth
|
||||
}
|
||||
|
||||
sshJWTCacheTTL := int32(0)
|
||||
if cfg.SSHJWTCacheTTL != nil {
|
||||
sshJWTCacheTTL = int32(*cfg.SSHJWTCacheTTL)
|
||||
@@ -1547,7 +1517,6 @@ func (s *Server) GetConfig(ctx context.Context, req *proto.GetConfigRequest) (*p
|
||||
Mtu: int64(cfg.MTU),
|
||||
DisableAutoConnect: cfg.DisableAutoConnect,
|
||||
ServerSSHAllowed: *cfg.ServerSSHAllowed,
|
||||
ServerVNCAllowed: cfg.ServerVNCAllowed != nil && *cfg.ServerVNCAllowed,
|
||||
RosenpassEnabled: cfg.RosenpassEnabled,
|
||||
RosenpassPermissive: cfg.RosenpassPermissive,
|
||||
LazyConnectionEnabled: cfg.LazyConnectionEnabled,
|
||||
@@ -1563,7 +1532,6 @@ func (s *Server) GetConfig(ctx context.Context, req *proto.GetConfigRequest) (*p
|
||||
EnableSSHLocalPortForwarding: enableSSHLocalPortForwarding,
|
||||
EnableSSHRemotePortForwarding: enableSSHRemotePortForwarding,
|
||||
DisableSSHAuth: disableSSHAuth,
|
||||
DisableVNCAuth: disableVNCAuth,
|
||||
SshJWTCacheTTL: sshJWTCacheTTL,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -58,8 +58,6 @@ func TestSetConfig_AllFieldsSaved(t *testing.T) {
|
||||
rosenpassEnabled := true
|
||||
rosenpassPermissive := true
|
||||
serverSSHAllowed := true
|
||||
serverVNCAllowed := true
|
||||
disableVNCAuth := true
|
||||
interfaceName := "utun100"
|
||||
wireguardPort := int64(51820)
|
||||
preSharedKey := "test-psk"
|
||||
@@ -84,8 +82,6 @@ func TestSetConfig_AllFieldsSaved(t *testing.T) {
|
||||
RosenpassEnabled: &rosenpassEnabled,
|
||||
RosenpassPermissive: &rosenpassPermissive,
|
||||
ServerSSHAllowed: &serverSSHAllowed,
|
||||
ServerVNCAllowed: &serverVNCAllowed,
|
||||
DisableVNCAuth: &disableVNCAuth,
|
||||
InterfaceName: &interfaceName,
|
||||
WireguardPort: &wireguardPort,
|
||||
OptionalPreSharedKey: &preSharedKey,
|
||||
@@ -129,10 +125,6 @@ func TestSetConfig_AllFieldsSaved(t *testing.T) {
|
||||
require.Equal(t, rosenpassPermissive, cfg.RosenpassPermissive)
|
||||
require.NotNil(t, cfg.ServerSSHAllowed)
|
||||
require.Equal(t, serverSSHAllowed, *cfg.ServerSSHAllowed)
|
||||
require.NotNil(t, cfg.ServerVNCAllowed)
|
||||
require.Equal(t, serverVNCAllowed, *cfg.ServerVNCAllowed)
|
||||
require.NotNil(t, cfg.DisableVNCAuth)
|
||||
require.Equal(t, disableVNCAuth, *cfg.DisableVNCAuth)
|
||||
require.Equal(t, interfaceName, cfg.WgIface)
|
||||
require.Equal(t, int(wireguardPort), cfg.WgPort)
|
||||
require.Equal(t, preSharedKey, cfg.PreSharedKey)
|
||||
@@ -184,8 +176,6 @@ func verifyAllFieldsCovered(t *testing.T, req *proto.SetConfigRequest) {
|
||||
"RosenpassEnabled": true,
|
||||
"RosenpassPermissive": true,
|
||||
"ServerSSHAllowed": true,
|
||||
"ServerVNCAllowed": true,
|
||||
"DisableVNCAuth": true,
|
||||
"InterfaceName": true,
|
||||
"WireguardPort": true,
|
||||
"OptionalPreSharedKey": true,
|
||||
@@ -246,8 +236,6 @@ func TestCLIFlags_MappedToSetConfig(t *testing.T) {
|
||||
"enable-rosenpass": "RosenpassEnabled",
|
||||
"rosenpass-permissive": "RosenpassPermissive",
|
||||
"allow-server-ssh": "ServerSSHAllowed",
|
||||
"allow-server-vnc": "ServerVNCAllowed",
|
||||
"disable-vnc-auth": "DisableVNCAuth",
|
||||
"interface-name": "InterfaceName",
|
||||
"wireguard-port": "WireguardPort",
|
||||
"preshared-key": "OptionalPreSharedKey",
|
||||
|
||||
@@ -200,8 +200,8 @@ func newLsaString(s string) lsaString {
|
||||
}
|
||||
}
|
||||
|
||||
// generateS4UUserToken creates a Windows token using S4U authentication.
|
||||
// This is the same approach OpenSSH for Windows uses for public key authentication.
|
||||
// generateS4UUserToken creates a Windows token using S4U authentication
|
||||
// This is the exact approach OpenSSH for Windows uses for public key authentication
|
||||
func generateS4UUserToken(logger *log.Entry, username, domain string) (windows.Handle, error) {
|
||||
userCpn := buildUserCpn(username, domain)
|
||||
|
||||
|
||||
@@ -507,7 +507,27 @@ func (s *Server) checkTokenAge(token *gojwt.Token, jwtConfig *JWTConfig) error {
|
||||
maxTokenAge = DefaultJWTMaxTokenAge
|
||||
}
|
||||
|
||||
return jwt.CheckTokenAge(token, time.Duration(maxTokenAge)*time.Second)
|
||||
claims, ok := token.Claims.(gojwt.MapClaims)
|
||||
if !ok {
|
||||
userID := extractUserID(token)
|
||||
return fmt.Errorf("token has invalid claims format (user=%s)", userID)
|
||||
}
|
||||
|
||||
iat, ok := claims["iat"].(float64)
|
||||
if !ok {
|
||||
userID := extractUserID(token)
|
||||
return fmt.Errorf("token missing iat claim (user=%s)", userID)
|
||||
}
|
||||
|
||||
issuedAt := time.Unix(int64(iat), 0)
|
||||
tokenAge := time.Since(issuedAt)
|
||||
maxAge := time.Duration(maxTokenAge) * time.Second
|
||||
if tokenAge > maxAge {
|
||||
userID := getUserIDFromClaims(claims)
|
||||
return fmt.Errorf("token expired for user=%s: age=%v, max=%v", userID, tokenAge, maxAge)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Server) extractAndValidateUser(token *gojwt.Token) (*auth.UserAuth, error) {
|
||||
@@ -538,7 +558,27 @@ func (s *Server) hasSSHAccess(userAuth *auth.UserAuth) bool {
|
||||
}
|
||||
|
||||
func extractUserID(token *gojwt.Token) string {
|
||||
return jwt.UserIDFromToken(token)
|
||||
if token == nil {
|
||||
return "unknown"
|
||||
}
|
||||
claims, ok := token.Claims.(gojwt.MapClaims)
|
||||
if !ok {
|
||||
return "unknown"
|
||||
}
|
||||
return getUserIDFromClaims(claims)
|
||||
}
|
||||
|
||||
func getUserIDFromClaims(claims gojwt.MapClaims) string {
|
||||
if sub, ok := claims["sub"].(string); ok && sub != "" {
|
||||
return sub
|
||||
}
|
||||
if userID, ok := claims["user_id"].(string); ok && userID != "" {
|
||||
return userID
|
||||
}
|
||||
if email, ok := claims["email"].(string); ok && email != "" {
|
||||
return email
|
||||
}
|
||||
return "unknown"
|
||||
}
|
||||
|
||||
func (s *Server) parseTokenWithoutValidation(tokenString string) (map[string]interface{}, error) {
|
||||
|
||||
@@ -130,10 +130,6 @@ type SSHServerStateOutput struct {
|
||||
Sessions []SSHSessionOutput `json:"sessions" yaml:"sessions"`
|
||||
}
|
||||
|
||||
type VNCServerStateOutput struct {
|
||||
Enabled bool `json:"enabled" yaml:"enabled"`
|
||||
}
|
||||
|
||||
type OutputOverview struct {
|
||||
Peers PeersStateOutput `json:"peers" yaml:"peers"`
|
||||
CliVersion string `json:"cliVersion" yaml:"cliVersion"`
|
||||
@@ -155,7 +151,6 @@ type OutputOverview struct {
|
||||
LazyConnectionEnabled bool `json:"lazyConnectionEnabled" yaml:"lazyConnectionEnabled"`
|
||||
ProfileName string `json:"profileName" yaml:"profileName"`
|
||||
SSHServerState SSHServerStateOutput `json:"sshServer" yaml:"sshServer"`
|
||||
VNCServerState VNCServerStateOutput `json:"vncServer" yaml:"vncServer"`
|
||||
}
|
||||
|
||||
// ConvertToStatusOutputOverview converts protobuf status to the output overview.
|
||||
@@ -176,9 +171,6 @@ func ConvertToStatusOutputOverview(pbFullStatus *proto.FullStatus, opts ConvertO
|
||||
|
||||
relayOverview := mapRelays(pbFullStatus.GetRelays())
|
||||
sshServerOverview := mapSSHServer(pbFullStatus.GetSshServerState())
|
||||
vncServerOverview := VNCServerStateOutput{
|
||||
Enabled: pbFullStatus.GetVncServerState().GetEnabled(),
|
||||
}
|
||||
peersOverview := mapPeers(pbFullStatus.GetPeers(), opts.StatusFilter, opts.PrefixNamesFilter, opts.PrefixNamesFilterMap, opts.IPsFilter, opts.ConnectionTypeFilter)
|
||||
|
||||
overview := OutputOverview{
|
||||
@@ -202,7 +194,6 @@ func ConvertToStatusOutputOverview(pbFullStatus *proto.FullStatus, opts ConvertO
|
||||
LazyConnectionEnabled: pbFullStatus.GetLazyConnectionEnabled(),
|
||||
ProfileName: opts.ProfileName,
|
||||
SSHServerState: sshServerOverview,
|
||||
VNCServerState: vncServerOverview,
|
||||
}
|
||||
|
||||
if opts.Anonymize {
|
||||
@@ -533,11 +524,6 @@ func (o *OutputOverview) GeneralSummary(showURL bool, showRelays bool, showNameS
|
||||
}
|
||||
}
|
||||
|
||||
vncServerStatus := "Disabled"
|
||||
if o.VNCServerState.Enabled {
|
||||
vncServerStatus = "Enabled"
|
||||
}
|
||||
|
||||
peersCountString := fmt.Sprintf("%d/%d Connected", o.Peers.Connected, o.Peers.Total)
|
||||
|
||||
var forwardingRulesString string
|
||||
@@ -567,7 +553,6 @@ func (o *OutputOverview) GeneralSummary(showURL bool, showRelays bool, showNameS
|
||||
"Quantum resistance: %s\n"+
|
||||
"Lazy connection: %s\n"+
|
||||
"SSH Server: %s\n"+
|
||||
"VNC Server: %s\n"+
|
||||
"Networks: %s\n"+
|
||||
"%s"+
|
||||
"Peers count: %s\n",
|
||||
@@ -585,7 +570,6 @@ func (o *OutputOverview) GeneralSummary(showURL bool, showRelays bool, showNameS
|
||||
rosenpassEnabledStatus,
|
||||
lazyConnectionEnabledStatus,
|
||||
sshServerStatus,
|
||||
vncServerStatus,
|
||||
networks,
|
||||
forwardingRulesString,
|
||||
peersCountString,
|
||||
|
||||
@@ -398,9 +398,6 @@ func TestParsingToJSON(t *testing.T) {
|
||||
"sshServer":{
|
||||
"enabled":false,
|
||||
"sessions":[]
|
||||
},
|
||||
"vncServer":{
|
||||
"enabled":false
|
||||
}
|
||||
}`
|
||||
// @formatter:on
|
||||
@@ -508,8 +505,6 @@ profileName: ""
|
||||
sshServer:
|
||||
enabled: false
|
||||
sessions: []
|
||||
vncServer:
|
||||
enabled: false
|
||||
`
|
||||
|
||||
assert.Equal(t, expectedYAML, yaml)
|
||||
@@ -577,7 +572,6 @@ Interface type: Kernel
|
||||
Quantum resistance: false
|
||||
Lazy connection: false
|
||||
SSH Server: Disabled
|
||||
VNC Server: Disabled
|
||||
Networks: 10.10.0.0/24
|
||||
Peers count: 2/2 Connected
|
||||
`, lastConnectionUpdate1, lastHandshake1, lastConnectionUpdate2, lastHandshake2, runtime.GOOS, runtime.GOARCH, overview.CliVersion)
|
||||
@@ -602,7 +596,6 @@ Interface type: Kernel
|
||||
Quantum resistance: false
|
||||
Lazy connection: false
|
||||
SSH Server: Disabled
|
||||
VNC Server: Disabled
|
||||
Networks: 10.10.0.0/24
|
||||
Peers count: 2/2 Connected
|
||||
`
|
||||
|
||||
@@ -62,7 +62,6 @@ type Info struct {
|
||||
RosenpassEnabled bool
|
||||
RosenpassPermissive bool
|
||||
ServerSSHAllowed bool
|
||||
ServerVNCAllowed bool
|
||||
|
||||
DisableClientRoutes bool
|
||||
DisableServerRoutes bool
|
||||
@@ -78,27 +77,21 @@ type Info struct {
|
||||
EnableSSHLocalPortForwarding bool
|
||||
EnableSSHRemotePortForwarding bool
|
||||
DisableSSHAuth bool
|
||||
DisableVNCAuth bool
|
||||
}
|
||||
|
||||
func (i *Info) SetFlags(
|
||||
rosenpassEnabled, rosenpassPermissive bool,
|
||||
serverSSHAllowed *bool,
|
||||
serverVNCAllowed *bool,
|
||||
disableClientRoutes, disableServerRoutes,
|
||||
disableDNS, disableFirewall, blockLANAccess, blockInbound, lazyConnectionEnabled bool,
|
||||
enableSSHRoot, enableSSHSFTP, enableSSHLocalPortForwarding, enableSSHRemotePortForwarding *bool,
|
||||
disableSSHAuth *bool,
|
||||
disableVNCAuth *bool,
|
||||
) {
|
||||
i.RosenpassEnabled = rosenpassEnabled
|
||||
i.RosenpassPermissive = rosenpassPermissive
|
||||
if serverSSHAllowed != nil {
|
||||
i.ServerSSHAllowed = *serverSSHAllowed
|
||||
}
|
||||
if serverVNCAllowed != nil {
|
||||
i.ServerVNCAllowed = *serverVNCAllowed
|
||||
}
|
||||
|
||||
i.DisableClientRoutes = disableClientRoutes
|
||||
i.DisableServerRoutes = disableServerRoutes
|
||||
@@ -124,9 +117,6 @@ func (i *Info) SetFlags(
|
||||
if disableSSHAuth != nil {
|
||||
i.DisableSSHAuth = *disableSSHAuth
|
||||
}
|
||||
if disableVNCAuth != nil {
|
||||
i.DisableVNCAuth = *disableVNCAuth
|
||||
}
|
||||
}
|
||||
|
||||
// extractUserAgent extracts Netbird's agent (client) name and version from the outgoing context
|
||||
|
||||
@@ -1,474 +0,0 @@
|
||||
//go:build windows
|
||||
|
||||
package server
|
||||
|
||||
import (
|
||||
crand "crypto/rand"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"os"
|
||||
"sync"
|
||||
"time"
|
||||
"unsafe"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/sys/windows"
|
||||
)
|
||||
|
||||
const (
|
||||
agentPort = "15900"
|
||||
|
||||
// agentTokenLen is the length of the random authentication token
|
||||
// used to verify that connections to the agent come from the service.
|
||||
agentTokenLen = 32
|
||||
|
||||
stillActive = 259
|
||||
|
||||
tokenPrimary = 1
|
||||
securityImpersonation = 2
|
||||
tokenSessionID = 12
|
||||
|
||||
createUnicodeEnvironment = 0x00000400
|
||||
createNoWindow = 0x08000000
|
||||
)
|
||||
|
||||
var (
|
||||
kernel32 = windows.NewLazySystemDLL("kernel32.dll")
|
||||
advapi32 = windows.NewLazySystemDLL("advapi32.dll")
|
||||
userenv = windows.NewLazySystemDLL("userenv.dll")
|
||||
|
||||
procWTSGetActiveConsoleSessionId = kernel32.NewProc("WTSGetActiveConsoleSessionId")
|
||||
procSetTokenInformation = advapi32.NewProc("SetTokenInformation")
|
||||
procCreateEnvironmentBlock = userenv.NewProc("CreateEnvironmentBlock")
|
||||
procDestroyEnvironmentBlock = userenv.NewProc("DestroyEnvironmentBlock")
|
||||
|
||||
wtsapi32 = windows.NewLazySystemDLL("wtsapi32.dll")
|
||||
procWTSEnumerateSessionsW = wtsapi32.NewProc("WTSEnumerateSessionsW")
|
||||
procWTSFreeMemory = wtsapi32.NewProc("WTSFreeMemory")
|
||||
)
|
||||
|
||||
// GetCurrentSessionID returns the session ID of the current process.
|
||||
func GetCurrentSessionID() uint32 {
|
||||
var token windows.Token
|
||||
if err := windows.OpenProcessToken(windows.CurrentProcess(),
|
||||
windows.TOKEN_QUERY, &token); err != nil {
|
||||
return 0
|
||||
}
|
||||
defer token.Close()
|
||||
var id uint32
|
||||
var ret uint32
|
||||
_ = windows.GetTokenInformation(token, windows.TokenSessionId,
|
||||
(*byte)(unsafe.Pointer(&id)), 4, &ret)
|
||||
return id
|
||||
}
|
||||
|
||||
func getConsoleSessionID() uint32 {
|
||||
r, _, _ := procWTSGetActiveConsoleSessionId.Call()
|
||||
return uint32(r)
|
||||
}
|
||||
|
||||
const (
|
||||
wtsActive = 0
|
||||
wtsConnected = 1
|
||||
wtsDisconnected = 4
|
||||
)
|
||||
|
||||
type wtsSessionInfo struct {
|
||||
SessionID uint32
|
||||
WinStationName [66]byte // actually *uint16, but we just need the struct size
|
||||
State uint32
|
||||
}
|
||||
|
||||
// getActiveSessionID returns the session ID of the best session to attach to.
|
||||
// Prefers an active (logged-in, interactive) session over the console session.
|
||||
// This avoids kicking out an RDP user when the console is at the login screen.
|
||||
func getActiveSessionID() uint32 {
|
||||
var sessionInfo uintptr
|
||||
var count uint32
|
||||
|
||||
r, _, _ := procWTSEnumerateSessionsW.Call(
|
||||
0, // WTS_CURRENT_SERVER_HANDLE
|
||||
0, // reserved
|
||||
1, // version
|
||||
uintptr(unsafe.Pointer(&sessionInfo)),
|
||||
uintptr(unsafe.Pointer(&count)),
|
||||
)
|
||||
if r == 0 || count == 0 {
|
||||
return getConsoleSessionID()
|
||||
}
|
||||
defer procWTSFreeMemory.Call(sessionInfo)
|
||||
|
||||
type wtsSession struct {
|
||||
SessionID uint32
|
||||
Station *uint16
|
||||
State uint32
|
||||
}
|
||||
sessions := unsafe.Slice((*wtsSession)(unsafe.Pointer(sessionInfo)), count)
|
||||
|
||||
// Find the first active session (not session 0, which is the services session).
|
||||
var bestID uint32
|
||||
found := false
|
||||
for _, s := range sessions {
|
||||
if s.SessionID == 0 {
|
||||
continue
|
||||
}
|
||||
if s.State == wtsActive {
|
||||
bestID = s.SessionID
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !found {
|
||||
return getConsoleSessionID()
|
||||
}
|
||||
return bestID
|
||||
}
|
||||
|
||||
// getSystemTokenForSession duplicates the current SYSTEM token and sets its
|
||||
// session ID so the spawned process runs in the target session. Using a SYSTEM
|
||||
// token gives access to both Default and Winlogon desktops plus UIPI bypass.
|
||||
func getSystemTokenForSession(sessionID uint32) (windows.Token, error) {
|
||||
var cur windows.Token
|
||||
if err := windows.OpenProcessToken(windows.CurrentProcess(),
|
||||
windows.MAXIMUM_ALLOWED, &cur); err != nil {
|
||||
return 0, fmt.Errorf("OpenProcessToken: %w", err)
|
||||
}
|
||||
defer cur.Close()
|
||||
|
||||
var dup windows.Token
|
||||
if err := windows.DuplicateTokenEx(cur, windows.MAXIMUM_ALLOWED, nil,
|
||||
securityImpersonation, tokenPrimary, &dup); err != nil {
|
||||
return 0, fmt.Errorf("DuplicateTokenEx: %w", err)
|
||||
}
|
||||
|
||||
sid := sessionID
|
||||
r, _, err := procSetTokenInformation.Call(
|
||||
uintptr(dup),
|
||||
uintptr(tokenSessionID),
|
||||
uintptr(unsafe.Pointer(&sid)),
|
||||
unsafe.Sizeof(sid),
|
||||
)
|
||||
if r == 0 {
|
||||
dup.Close()
|
||||
return 0, fmt.Errorf("SetTokenInformation(SessionId=%d): %w", sessionID, err)
|
||||
}
|
||||
return dup, nil
|
||||
}
|
||||
|
||||
const agentTokenEnvVar = "NB_VNC_AGENT_TOKEN"
|
||||
|
||||
// injectEnvVar appends a KEY=VALUE entry to a Unicode environment block.
|
||||
// The block is a sequence of null-terminated UTF-16 strings, terminated by
|
||||
// an extra null. Returns a new block pointer with the entry added.
|
||||
func injectEnvVar(envBlock uintptr, key, value string) uintptr {
|
||||
entry := key + "=" + value
|
||||
|
||||
// Walk the existing block to find its total length.
|
||||
ptr := (*uint16)(unsafe.Pointer(envBlock))
|
||||
var totalChars int
|
||||
for {
|
||||
ch := *(*uint16)(unsafe.Pointer(uintptr(unsafe.Pointer(ptr)) + uintptr(totalChars)*2))
|
||||
if ch == 0 {
|
||||
// Check for double-null terminator.
|
||||
next := *(*uint16)(unsafe.Pointer(uintptr(unsafe.Pointer(ptr)) + uintptr(totalChars+1)*2))
|
||||
totalChars++
|
||||
if next == 0 {
|
||||
// End of block (don't count the final null yet, we'll rebuild).
|
||||
break
|
||||
}
|
||||
} else {
|
||||
totalChars++
|
||||
}
|
||||
}
|
||||
|
||||
entryUTF16, _ := windows.UTF16FromString(entry)
|
||||
// New block: existing entries + new entry (null-terminated) + final null.
|
||||
newLen := totalChars + len(entryUTF16) + 1
|
||||
newBlock := make([]uint16, newLen)
|
||||
// Copy existing entries (up to but not including the final null).
|
||||
for i := range totalChars {
|
||||
newBlock[i] = *(*uint16)(unsafe.Pointer(uintptr(unsafe.Pointer(ptr)) + uintptr(i)*2))
|
||||
}
|
||||
copy(newBlock[totalChars:], entryUTF16)
|
||||
newBlock[newLen-1] = 0 // final null terminator
|
||||
|
||||
return uintptr(unsafe.Pointer(&newBlock[0]))
|
||||
}
|
||||
|
||||
func spawnAgentInSession(sessionID uint32, port string, authToken string) (windows.Handle, error) {
|
||||
token, err := getSystemTokenForSession(sessionID)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("get SYSTEM token for session %d: %w", sessionID, err)
|
||||
}
|
||||
defer token.Close()
|
||||
|
||||
var envBlock uintptr
|
||||
r, _, _ := procCreateEnvironmentBlock.Call(
|
||||
uintptr(unsafe.Pointer(&envBlock)),
|
||||
uintptr(token),
|
||||
0,
|
||||
)
|
||||
if r != 0 {
|
||||
defer procDestroyEnvironmentBlock.Call(envBlock)
|
||||
}
|
||||
|
||||
// Inject the auth token into the environment block so it doesn't appear
|
||||
// in the process command line (visible via tasklist/wmic).
|
||||
if r != 0 {
|
||||
envBlock = injectEnvVar(envBlock, agentTokenEnvVar, authToken)
|
||||
}
|
||||
|
||||
exePath, err := os.Executable()
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("get executable path: %w", err)
|
||||
}
|
||||
|
||||
cmdLine := fmt.Sprintf(`"%s" vnc-agent --port %s`, exePath, port)
|
||||
cmdLineW, err := windows.UTF16PtrFromString(cmdLine)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("UTF16 cmdline: %w", err)
|
||||
}
|
||||
|
||||
// Create an inheritable pipe for the agent's stderr so we can relog
|
||||
// its output in the service process.
|
||||
var sa windows.SecurityAttributes
|
||||
sa.Length = uint32(unsafe.Sizeof(sa))
|
||||
sa.InheritHandle = 1
|
||||
|
||||
var stderrRead, stderrWrite windows.Handle
|
||||
if err := windows.CreatePipe(&stderrRead, &stderrWrite, &sa, 0); err != nil {
|
||||
return 0, fmt.Errorf("create stderr pipe: %w", err)
|
||||
}
|
||||
// The read end must NOT be inherited by the child.
|
||||
windows.SetHandleInformation(stderrRead, windows.HANDLE_FLAG_INHERIT, 0)
|
||||
|
||||
desktop, _ := windows.UTF16PtrFromString(`WinSta0\Default`)
|
||||
si := windows.StartupInfo{
|
||||
Cb: uint32(unsafe.Sizeof(windows.StartupInfo{})),
|
||||
Desktop: desktop,
|
||||
Flags: windows.STARTF_USESHOWWINDOW | windows.STARTF_USESTDHANDLES,
|
||||
ShowWindow: 0,
|
||||
StdErr: stderrWrite,
|
||||
StdOutput: stderrWrite,
|
||||
}
|
||||
var pi windows.ProcessInformation
|
||||
|
||||
var envPtr *uint16
|
||||
if envBlock != 0 {
|
||||
envPtr = (*uint16)(unsafe.Pointer(envBlock))
|
||||
}
|
||||
|
||||
err = windows.CreateProcessAsUser(
|
||||
token, nil, cmdLineW,
|
||||
nil, nil, true, // inheritHandles=true for the pipe
|
||||
createUnicodeEnvironment|createNoWindow,
|
||||
envPtr, nil, &si, &pi,
|
||||
)
|
||||
// Close the write end in the parent so reads will get EOF when the child exits.
|
||||
windows.CloseHandle(stderrWrite)
|
||||
if err != nil {
|
||||
windows.CloseHandle(stderrRead)
|
||||
return 0, fmt.Errorf("CreateProcessAsUser: %w", err)
|
||||
}
|
||||
windows.CloseHandle(pi.Thread)
|
||||
|
||||
// Relog agent output in the service with a [vnc-agent] prefix.
|
||||
go relogAgentOutput(stderrRead)
|
||||
|
||||
log.Infof("spawned agent PID=%d in session %d on port %s", pi.ProcessId, sessionID, port)
|
||||
return pi.Process, nil
|
||||
}
|
||||
|
||||
// sessionManager monitors the active console session and ensures a VNC agent
|
||||
// process is running in it. When the session changes (e.g., user switch, RDP
|
||||
// connect/disconnect), it kills the old agent and spawns a new one.
|
||||
type sessionManager struct {
|
||||
port string
|
||||
mu sync.Mutex
|
||||
agentProc windows.Handle
|
||||
sessionID uint32
|
||||
authToken string
|
||||
done chan struct{}
|
||||
}
|
||||
|
||||
func newSessionManager(port string) *sessionManager {
|
||||
return &sessionManager{port: port, sessionID: ^uint32(0), done: make(chan struct{})}
|
||||
}
|
||||
|
||||
// generateAuthToken creates a new random hex token for agent authentication.
|
||||
func generateAuthToken() string {
|
||||
b := make([]byte, agentTokenLen)
|
||||
if _, err := crand.Read(b); err != nil {
|
||||
log.Warnf("generate agent auth token: %v", err)
|
||||
return ""
|
||||
}
|
||||
return hex.EncodeToString(b)
|
||||
}
|
||||
|
||||
// AuthToken returns the current agent authentication token.
|
||||
func (m *sessionManager) AuthToken() string {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
return m.authToken
|
||||
}
|
||||
|
||||
// Stop signals the session manager to exit its polling loop.
|
||||
func (m *sessionManager) Stop() {
|
||||
select {
|
||||
case <-m.done:
|
||||
default:
|
||||
close(m.done)
|
||||
}
|
||||
}
|
||||
|
||||
func (m *sessionManager) run() {
|
||||
ticker := time.NewTicker(2 * time.Second)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
sid := getActiveSessionID()
|
||||
|
||||
m.mu.Lock()
|
||||
if sid != m.sessionID {
|
||||
log.Infof("active session changed: %d -> %d", m.sessionID, sid)
|
||||
m.killAgent()
|
||||
m.sessionID = sid
|
||||
}
|
||||
|
||||
if m.agentProc != 0 {
|
||||
var code uint32
|
||||
_ = windows.GetExitCodeProcess(m.agentProc, &code)
|
||||
if code != stillActive {
|
||||
log.Infof("agent exited (code=%d), respawning", code)
|
||||
windows.CloseHandle(m.agentProc)
|
||||
m.agentProc = 0
|
||||
}
|
||||
}
|
||||
|
||||
if m.agentProc == 0 && sid != 0xFFFFFFFF {
|
||||
m.authToken = generateAuthToken()
|
||||
h, err := spawnAgentInSession(sid, m.port, m.authToken)
|
||||
if err != nil {
|
||||
log.Warnf("spawn agent in session %d: %v", sid, err)
|
||||
m.authToken = ""
|
||||
} else {
|
||||
m.agentProc = h
|
||||
}
|
||||
}
|
||||
m.mu.Unlock()
|
||||
|
||||
select {
|
||||
case <-m.done:
|
||||
m.mu.Lock()
|
||||
m.killAgent()
|
||||
m.mu.Unlock()
|
||||
return
|
||||
case <-ticker.C:
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (m *sessionManager) killAgent() {
|
||||
if m.agentProc != 0 {
|
||||
_ = windows.TerminateProcess(m.agentProc, 0)
|
||||
windows.CloseHandle(m.agentProc)
|
||||
m.agentProc = 0
|
||||
log.Info("killed old agent")
|
||||
}
|
||||
}
|
||||
|
||||
// relogAgentOutput reads JSON log lines from the agent's stderr pipe and
|
||||
// relogs them at the correct level with the service's formatter.
|
||||
func relogAgentOutput(pipe windows.Handle) {
|
||||
defer windows.CloseHandle(pipe)
|
||||
f := os.NewFile(uintptr(pipe), "vnc-agent-stderr")
|
||||
defer f.Close()
|
||||
|
||||
entry := log.WithField("component", "vnc-agent")
|
||||
dec := json.NewDecoder(f)
|
||||
for dec.More() {
|
||||
var m map[string]any
|
||||
if err := dec.Decode(&m); err != nil {
|
||||
break
|
||||
}
|
||||
msg, _ := m["msg"].(string)
|
||||
if msg == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
// Forward extra fields from the agent (skip standard logrus fields).
|
||||
// Remap "caller" to "source" so it doesn't conflict with logrus internals
|
||||
// but still shows the original file/line from the agent process.
|
||||
fields := make(log.Fields)
|
||||
for k, v := range m {
|
||||
switch k {
|
||||
case "msg", "level", "time", "func":
|
||||
continue
|
||||
case "caller":
|
||||
fields["source"] = v
|
||||
default:
|
||||
fields[k] = v
|
||||
}
|
||||
}
|
||||
e := entry.WithFields(fields)
|
||||
|
||||
switch m["level"] {
|
||||
case "error":
|
||||
e.Error(msg)
|
||||
case "warning":
|
||||
e.Warn(msg)
|
||||
case "debug":
|
||||
e.Debug(msg)
|
||||
case "trace":
|
||||
e.Trace(msg)
|
||||
default:
|
||||
e.Info(msg)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// proxyToAgent connects to the agent, sends the auth token, then proxies
|
||||
// the VNC client connection bidirectionally.
|
||||
func proxyToAgent(client net.Conn, port string, authToken string) {
|
||||
defer client.Close()
|
||||
|
||||
addr := "127.0.0.1:" + port
|
||||
var agentConn net.Conn
|
||||
var err error
|
||||
for range 50 {
|
||||
agentConn, err = net.DialTimeout("tcp", addr, time.Second)
|
||||
if err == nil {
|
||||
break
|
||||
}
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
}
|
||||
if err != nil {
|
||||
log.Warnf("proxy cannot reach agent at %s: %v", addr, err)
|
||||
return
|
||||
}
|
||||
defer agentConn.Close()
|
||||
|
||||
// Send the auth token so the agent can verify this connection
|
||||
// comes from the trusted service process.
|
||||
tokenBytes, _ := hex.DecodeString(authToken)
|
||||
if _, err := agentConn.Write(tokenBytes); err != nil {
|
||||
log.Warnf("send auth token to agent: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
log.Debugf("proxy connected to agent, starting bidirectional copy")
|
||||
|
||||
done := make(chan struct{}, 2)
|
||||
cp := func(label string, dst, src net.Conn) {
|
||||
n, err := io.Copy(dst, src)
|
||||
log.Debugf("proxy %s: %d bytes, err=%v", label, n, err)
|
||||
done <- struct{}{}
|
||||
}
|
||||
go cp("client→agent", agentConn, client)
|
||||
go cp("agent→client", client, agentConn)
|
||||
<-done
|
||||
}
|
||||
@@ -1,486 +0,0 @@
|
||||
//go:build darwin && !ios
|
||||
|
||||
package server
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"hash/maphash"
|
||||
"image"
|
||||
"os"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"sync"
|
||||
"time"
|
||||
"unsafe"
|
||||
|
||||
"github.com/ebitengine/purego"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
|
||||
var darwinCaptureOnce sync.Once
|
||||
|
||||
var (
|
||||
cgMainDisplayID func() uint32
|
||||
cgDisplayPixelsWide func(uint32) uintptr
|
||||
cgDisplayPixelsHigh func(uint32) uintptr
|
||||
cgDisplayCreateImage func(uint32) uintptr
|
||||
cgImageGetWidth func(uintptr) uintptr
|
||||
cgImageGetHeight func(uintptr) uintptr
|
||||
cgImageGetBytesPerRow func(uintptr) uintptr
|
||||
cgImageGetBitsPerPixel func(uintptr) uintptr
|
||||
cgImageGetDataProvider func(uintptr) uintptr
|
||||
cgDataProviderCopyData func(uintptr) uintptr
|
||||
cgImageRelease func(uintptr)
|
||||
cfDataGetLength func(uintptr) int64
|
||||
cfDataGetBytePtr func(uintptr) uintptr
|
||||
cfRelease func(uintptr)
|
||||
cgPreflightScreenCaptureAccess func() bool
|
||||
cgRequestScreenCaptureAccess func() bool
|
||||
darwinCaptureReady bool
|
||||
)
|
||||
|
||||
func initDarwinCapture() {
|
||||
darwinCaptureOnce.Do(func() {
|
||||
cg, err := purego.Dlopen("/System/Library/Frameworks/CoreGraphics.framework/CoreGraphics", purego.RTLD_NOW|purego.RTLD_GLOBAL)
|
||||
if err != nil {
|
||||
log.Debugf("load CoreGraphics: %v", err)
|
||||
return
|
||||
}
|
||||
cf, err := purego.Dlopen("/System/Library/Frameworks/CoreFoundation.framework/CoreFoundation", purego.RTLD_NOW|purego.RTLD_GLOBAL)
|
||||
if err != nil {
|
||||
log.Debugf("load CoreFoundation: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
purego.RegisterLibFunc(&cgMainDisplayID, cg, "CGMainDisplayID")
|
||||
purego.RegisterLibFunc(&cgDisplayPixelsWide, cg, "CGDisplayPixelsWide")
|
||||
purego.RegisterLibFunc(&cgDisplayPixelsHigh, cg, "CGDisplayPixelsHigh")
|
||||
purego.RegisterLibFunc(&cgDisplayCreateImage, cg, "CGDisplayCreateImage")
|
||||
purego.RegisterLibFunc(&cgImageGetWidth, cg, "CGImageGetWidth")
|
||||
purego.RegisterLibFunc(&cgImageGetHeight, cg, "CGImageGetHeight")
|
||||
purego.RegisterLibFunc(&cgImageGetBytesPerRow, cg, "CGImageGetBytesPerRow")
|
||||
purego.RegisterLibFunc(&cgImageGetBitsPerPixel, cg, "CGImageGetBitsPerPixel")
|
||||
purego.RegisterLibFunc(&cgImageGetDataProvider, cg, "CGImageGetDataProvider")
|
||||
purego.RegisterLibFunc(&cgDataProviderCopyData, cg, "CGDataProviderCopyData")
|
||||
purego.RegisterLibFunc(&cgImageRelease, cg, "CGImageRelease")
|
||||
purego.RegisterLibFunc(&cfDataGetLength, cf, "CFDataGetLength")
|
||||
purego.RegisterLibFunc(&cfDataGetBytePtr, cf, "CFDataGetBytePtr")
|
||||
purego.RegisterLibFunc(&cfRelease, cf, "CFRelease")
|
||||
|
||||
// Screen capture permission APIs (macOS 11+). Might not exist on older versions.
|
||||
if sym, err := purego.Dlsym(cg, "CGPreflightScreenCaptureAccess"); err == nil {
|
||||
purego.RegisterFunc(&cgPreflightScreenCaptureAccess, sym)
|
||||
}
|
||||
if sym, err := purego.Dlsym(cg, "CGRequestScreenCaptureAccess"); err == nil {
|
||||
purego.RegisterFunc(&cgRequestScreenCaptureAccess, sym)
|
||||
}
|
||||
|
||||
darwinCaptureReady = true
|
||||
})
|
||||
}
|
||||
|
||||
// errFrameUnchanged signals that the raw capture bytes matched the previous
|
||||
// frame, so the caller can skip the expensive BGRA to RGBA conversion.
|
||||
var errFrameUnchanged = errors.New("frame unchanged")
|
||||
|
||||
// CGCapturer captures the macOS main display using Core Graphics.
|
||||
type CGCapturer struct {
|
||||
displayID uint32
|
||||
w, h int
|
||||
// downscale is 1 for pixel-perfect, 2 for Retina 2:1 box-filter downscale.
|
||||
downscale int
|
||||
hashSeed maphash.Seed
|
||||
lastHash uint64
|
||||
hasHash bool
|
||||
}
|
||||
|
||||
// NewCGCapturer creates a screen capturer for the main display.
|
||||
func NewCGCapturer() (*CGCapturer, error) {
|
||||
initDarwinCapture()
|
||||
if !darwinCaptureReady {
|
||||
return nil, fmt.Errorf("CoreGraphics not available")
|
||||
}
|
||||
|
||||
// Request Screen Recording permission (shows system dialog on macOS 11+).
|
||||
if cgPreflightScreenCaptureAccess != nil && !cgPreflightScreenCaptureAccess() {
|
||||
if cgRequestScreenCaptureAccess != nil {
|
||||
cgRequestScreenCaptureAccess()
|
||||
}
|
||||
openPrivacyPane("Privacy_ScreenCapture")
|
||||
log.Warn("Screen Recording permission not granted. " +
|
||||
"Opened System Settings > Privacy & Security > Screen Recording; enable netbird and restart.")
|
||||
}
|
||||
|
||||
displayID := cgMainDisplayID()
|
||||
c := &CGCapturer{displayID: displayID, downscale: 1, hashSeed: maphash.MakeSeed()}
|
||||
|
||||
// Probe actual pixel dimensions via a test capture. CGDisplayPixelsWide/High
|
||||
// returns logical points on Retina, but CGDisplayCreateImage produces native
|
||||
// pixels (often 2x), so probing the image is the only reliable source.
|
||||
img, err := c.Capture()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("probe capture: %w", err)
|
||||
}
|
||||
nativeW := img.Rect.Dx()
|
||||
nativeH := img.Rect.Dy()
|
||||
c.hasHash = false
|
||||
if nativeW == 0 || nativeH == 0 {
|
||||
return nil, errors.New("display dimensions are zero")
|
||||
}
|
||||
|
||||
logicalW := int(cgDisplayPixelsWide(displayID))
|
||||
logicalH := int(cgDisplayPixelsHigh(displayID))
|
||||
|
||||
// Enable 2:1 downscale on Retina unless explicitly disabled. Cuts pixel
|
||||
// count 4x, shrinking convert, diff, and wire data proportionally.
|
||||
if !retinaDownscaleDisabled() && nativeW >= 2*logicalW && nativeH >= 2*logicalH && nativeW%2 == 0 && nativeH%2 == 0 {
|
||||
c.downscale = 2
|
||||
}
|
||||
c.w = nativeW / c.downscale
|
||||
c.h = nativeH / c.downscale
|
||||
|
||||
log.Infof("macOS capturer ready: %dx%d (native %dx%d, logical %dx%d, downscale=%d, display=%d)",
|
||||
c.w, c.h, nativeW, nativeH, logicalW, logicalH, c.downscale, displayID)
|
||||
return c, nil
|
||||
}
|
||||
|
||||
func retinaDownscaleDisabled() bool {
|
||||
v := os.Getenv(EnvVNCDisableDownscale)
|
||||
if v == "" {
|
||||
return false
|
||||
}
|
||||
disabled, err := strconv.ParseBool(v)
|
||||
if err != nil {
|
||||
log.Warnf("parse %s: %v", EnvVNCDisableDownscale, err)
|
||||
return false
|
||||
}
|
||||
return disabled
|
||||
}
|
||||
|
||||
// Width returns the screen width.
|
||||
func (c *CGCapturer) Width() int { return c.w }
|
||||
|
||||
// Height returns the screen height.
|
||||
func (c *CGCapturer) Height() int { return c.h }
|
||||
|
||||
// Capture returns the current screen as an RGBA image.
|
||||
func (c *CGCapturer) Capture() (*image.RGBA, error) {
|
||||
cgImage := cgDisplayCreateImage(c.displayID)
|
||||
if cgImage == 0 {
|
||||
return nil, fmt.Errorf("CGDisplayCreateImage returned nil (screen recording permission?)")
|
||||
}
|
||||
defer cgImageRelease(cgImage)
|
||||
|
||||
w := int(cgImageGetWidth(cgImage))
|
||||
h := int(cgImageGetHeight(cgImage))
|
||||
bytesPerRow := int(cgImageGetBytesPerRow(cgImage))
|
||||
bpp := int(cgImageGetBitsPerPixel(cgImage))
|
||||
|
||||
provider := cgImageGetDataProvider(cgImage)
|
||||
if provider == 0 {
|
||||
return nil, fmt.Errorf("CGImageGetDataProvider returned nil")
|
||||
}
|
||||
|
||||
cfData := cgDataProviderCopyData(provider)
|
||||
if cfData == 0 {
|
||||
return nil, fmt.Errorf("CGDataProviderCopyData returned nil")
|
||||
}
|
||||
defer cfRelease(cfData)
|
||||
|
||||
dataLen := int(cfDataGetLength(cfData))
|
||||
dataPtr := cfDataGetBytePtr(cfData)
|
||||
if dataPtr == 0 || dataLen == 0 {
|
||||
return nil, fmt.Errorf("empty image data")
|
||||
}
|
||||
|
||||
src := unsafe.Slice((*byte)(unsafe.Pointer(dataPtr)), dataLen)
|
||||
|
||||
hash := maphash.Bytes(c.hashSeed, src)
|
||||
if c.hasHash && hash == c.lastHash {
|
||||
return nil, errFrameUnchanged
|
||||
}
|
||||
c.lastHash = hash
|
||||
c.hasHash = true
|
||||
|
||||
ds := c.downscale
|
||||
if ds < 1 {
|
||||
ds = 1
|
||||
}
|
||||
outW := w / ds
|
||||
outH := h / ds
|
||||
img := image.NewRGBA(image.Rect(0, 0, outW, outH))
|
||||
|
||||
bytesPerPixel := bpp / 8
|
||||
if bytesPerPixel == 4 && ds == 1 {
|
||||
convertBGRAToRGBA(img.Pix, img.Stride, src, bytesPerRow, w, h)
|
||||
} else if bytesPerPixel == 4 && ds == 2 {
|
||||
convertBGRAToRGBADownscale2(img.Pix, img.Stride, src, bytesPerRow, outW, outH)
|
||||
} else {
|
||||
for row := 0; row < outH; row++ {
|
||||
srcOff := row * ds * bytesPerRow
|
||||
dstOff := row * img.Stride
|
||||
for col := 0; col < outW; col++ {
|
||||
si := srcOff + col*ds*bytesPerPixel
|
||||
di := dstOff + col*4
|
||||
img.Pix[di+0] = src[si+2]
|
||||
img.Pix[di+1] = src[si+1]
|
||||
img.Pix[di+2] = src[si+0]
|
||||
img.Pix[di+3] = 0xff
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return img, nil
|
||||
}
|
||||
|
||||
// convertBGRAToRGBADownscale2 averages every 2x2 BGRA block into one RGBA
|
||||
// output pixel, parallelised across GOMAXPROCS cores. outW and outH are the
|
||||
// destination dimensions (source is 2*outW by 2*outH).
|
||||
func convertBGRAToRGBADownscale2(dst []byte, dstStride int, src []byte, srcStride, outW, outH int) {
|
||||
workers := runtime.GOMAXPROCS(0)
|
||||
if workers > outH {
|
||||
workers = outH
|
||||
}
|
||||
if workers < 1 || outH < 32 {
|
||||
workers = 1
|
||||
}
|
||||
|
||||
convertRows := func(y0, y1 int) {
|
||||
for row := y0; row < y1; row++ {
|
||||
srcRow0 := 2 * row * srcStride
|
||||
srcRow1 := srcRow0 + srcStride
|
||||
dstOff := row * dstStride
|
||||
for col := 0; col < outW; col++ {
|
||||
s0 := srcRow0 + col*8
|
||||
s1 := srcRow1 + col*8
|
||||
b := (uint32(src[s0]) + uint32(src[s0+4]) + uint32(src[s1]) + uint32(src[s1+4])) >> 2
|
||||
g := (uint32(src[s0+1]) + uint32(src[s0+5]) + uint32(src[s1+1]) + uint32(src[s1+5])) >> 2
|
||||
r := (uint32(src[s0+2]) + uint32(src[s0+6]) + uint32(src[s1+2]) + uint32(src[s1+6])) >> 2
|
||||
di := dstOff + col*4
|
||||
dst[di+0] = byte(r)
|
||||
dst[di+1] = byte(g)
|
||||
dst[di+2] = byte(b)
|
||||
dst[di+3] = 0xff
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if workers == 1 {
|
||||
convertRows(0, outH)
|
||||
return
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
chunk := (outH + workers - 1) / workers
|
||||
for i := 0; i < workers; i++ {
|
||||
y0 := i * chunk
|
||||
y1 := y0 + chunk
|
||||
if y1 > outH {
|
||||
y1 = outH
|
||||
}
|
||||
if y0 >= y1 {
|
||||
break
|
||||
}
|
||||
wg.Add(1)
|
||||
go func(y0, y1 int) {
|
||||
defer wg.Done()
|
||||
convertRows(y0, y1)
|
||||
}(y0, y1)
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
// convertBGRAToRGBA swaps R/B channels using uint32 word operations, and
|
||||
// parallelises across GOMAXPROCS cores for large images.
|
||||
func convertBGRAToRGBA(dst []byte, dstStride int, src []byte, srcStride, w, h int) {
|
||||
workers := runtime.GOMAXPROCS(0)
|
||||
if workers > h {
|
||||
workers = h
|
||||
}
|
||||
if workers < 1 || h < 64 {
|
||||
workers = 1
|
||||
}
|
||||
|
||||
convertRows := func(y0, y1 int) {
|
||||
rowBytes := w * 4
|
||||
for row := y0; row < y1; row++ {
|
||||
dstRow := dst[row*dstStride : row*dstStride+rowBytes]
|
||||
srcRow := src[row*srcStride : row*srcStride+rowBytes]
|
||||
dstU := unsafe.Slice((*uint32)(unsafe.Pointer(&dstRow[0])), w)
|
||||
srcU := unsafe.Slice((*uint32)(unsafe.Pointer(&srcRow[0])), w)
|
||||
for i, p := range srcU {
|
||||
dstU[i] = (p & 0xff00ff00) | ((p & 0x000000ff) << 16) | ((p & 0x00ff0000) >> 16) | 0xff000000
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if workers == 1 {
|
||||
convertRows(0, h)
|
||||
return
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
chunk := (h + workers - 1) / workers
|
||||
for i := 0; i < workers; i++ {
|
||||
y0 := i * chunk
|
||||
y1 := y0 + chunk
|
||||
if y1 > h {
|
||||
y1 = h
|
||||
}
|
||||
if y0 >= y1 {
|
||||
break
|
||||
}
|
||||
wg.Add(1)
|
||||
go func(y0, y1 int) {
|
||||
defer wg.Done()
|
||||
convertRows(y0, y1)
|
||||
}(y0, y1)
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
// MacPoller wraps CGCapturer in a continuous capture loop.
|
||||
type MacPoller struct {
|
||||
mu sync.Mutex
|
||||
frame *image.RGBA
|
||||
w, h int
|
||||
done chan struct{}
|
||||
// wake shortens the init-retry backoff when a client is trying to connect,
|
||||
// so granting Screen Recording mid-session takes effect immediately.
|
||||
wake chan struct{}
|
||||
}
|
||||
|
||||
// NewMacPoller creates a capturer that continuously grabs the macOS display.
|
||||
func NewMacPoller() *MacPoller {
|
||||
p := &MacPoller{
|
||||
done: make(chan struct{}),
|
||||
wake: make(chan struct{}, 1),
|
||||
}
|
||||
go p.loop()
|
||||
return p
|
||||
}
|
||||
|
||||
// Wake pokes the init-retry loop so it doesn't wait out the full backoff
|
||||
// before trying again. Safe to call from any goroutine; extra calls while a
|
||||
// wake is pending are dropped.
|
||||
func (p *MacPoller) Wake() {
|
||||
select {
|
||||
case p.wake <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
// Close stops the capture loop.
|
||||
func (p *MacPoller) Close() {
|
||||
select {
|
||||
case <-p.done:
|
||||
default:
|
||||
close(p.done)
|
||||
}
|
||||
}
|
||||
|
||||
// Width returns the screen width.
|
||||
func (p *MacPoller) Width() int {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
return p.w
|
||||
}
|
||||
|
||||
// Height returns the screen height.
|
||||
func (p *MacPoller) Height() int {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
return p.h
|
||||
}
|
||||
|
||||
// Capture returns the most recent frame.
|
||||
func (p *MacPoller) Capture() (*image.RGBA, error) {
|
||||
p.mu.Lock()
|
||||
img := p.frame
|
||||
p.mu.Unlock()
|
||||
if img != nil {
|
||||
return img, nil
|
||||
}
|
||||
return nil, fmt.Errorf("no frame available yet")
|
||||
}
|
||||
|
||||
func (p *MacPoller) loop() {
|
||||
var capturer *CGCapturer
|
||||
var initFails int
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-p.done:
|
||||
return
|
||||
default:
|
||||
}
|
||||
|
||||
if capturer == nil {
|
||||
var err error
|
||||
capturer, err = NewCGCapturer()
|
||||
if err != nil {
|
||||
initFails++
|
||||
// Retry forever with backoff: the user may grant Screen
|
||||
// Recording after the server started, and we need to pick it
|
||||
// up whenever that happens.
|
||||
delay := 2 * time.Second
|
||||
if initFails > 15 {
|
||||
delay = 30 * time.Second
|
||||
} else if initFails > 5 {
|
||||
delay = 10 * time.Second
|
||||
}
|
||||
if initFails == 1 || initFails%10 == 0 {
|
||||
log.Warnf("macOS capturer: %v (attempt %d, retrying every %s)", err, initFails, delay)
|
||||
} else {
|
||||
log.Debugf("macOS capturer: %v (attempt %d)", err, initFails)
|
||||
}
|
||||
select {
|
||||
case <-p.done:
|
||||
return
|
||||
case <-p.wake:
|
||||
// Client is trying to connect, retry now.
|
||||
case <-time.After(delay):
|
||||
}
|
||||
continue
|
||||
}
|
||||
initFails = 0
|
||||
p.mu.Lock()
|
||||
p.w, p.h = capturer.Width(), capturer.Height()
|
||||
p.mu.Unlock()
|
||||
}
|
||||
|
||||
img, err := capturer.Capture()
|
||||
if errors.Is(err, errFrameUnchanged) {
|
||||
select {
|
||||
case <-p.done:
|
||||
return
|
||||
case <-time.After(33 * time.Millisecond):
|
||||
}
|
||||
continue
|
||||
}
|
||||
if err != nil {
|
||||
log.Debugf("macOS capture: %v", err)
|
||||
capturer = nil
|
||||
select {
|
||||
case <-p.done:
|
||||
return
|
||||
case <-time.After(500 * time.Millisecond):
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
p.mu.Lock()
|
||||
p.frame = img
|
||||
p.mu.Unlock()
|
||||
|
||||
select {
|
||||
case <-p.done:
|
||||
return
|
||||
case <-time.After(33 * time.Millisecond): // ~30 fps
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var _ ScreenCapturer = (*MacPoller)(nil)
|
||||
@@ -1,99 +0,0 @@
|
||||
//go:build windows
|
||||
|
||||
package server
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"image"
|
||||
|
||||
"github.com/kirides/go-d3d/d3d11"
|
||||
"github.com/kirides/go-d3d/outputduplication"
|
||||
)
|
||||
|
||||
// dxgiCapturer captures the desktop using DXGI Desktop Duplication.
|
||||
// Provides GPU-accelerated capture with native dirty rect tracking.
|
||||
// Only works from the interactive user session, not Session 0.
|
||||
//
|
||||
// Uses a double-buffer: DXGI writes into img, then we copy to the current
|
||||
// output buffer and hand it out. Alternating between two output buffers
|
||||
// avoids allocating a new image.RGBA per frame (~8MB at 1080p, 30fps).
|
||||
type dxgiCapturer struct {
|
||||
dup *outputduplication.OutputDuplicator
|
||||
device *d3d11.ID3D11Device
|
||||
ctx *d3d11.ID3D11DeviceContext
|
||||
img *image.RGBA
|
||||
out [2]*image.RGBA
|
||||
outIdx int
|
||||
width int
|
||||
height int
|
||||
}
|
||||
|
||||
func newDXGICapturer() (*dxgiCapturer, error) {
|
||||
device, deviceCtx, err := d3d11.NewD3D11Device()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create D3D11 device: %w", err)
|
||||
}
|
||||
|
||||
dup, err := outputduplication.NewIDXGIOutputDuplication(device, deviceCtx, 0)
|
||||
if err != nil {
|
||||
device.Release()
|
||||
deviceCtx.Release()
|
||||
return nil, fmt.Errorf("create output duplication: %w", err)
|
||||
}
|
||||
|
||||
w, h := screenSize()
|
||||
if w == 0 || h == 0 {
|
||||
dup.Release()
|
||||
device.Release()
|
||||
deviceCtx.Release()
|
||||
return nil, fmt.Errorf("screen dimensions are zero")
|
||||
}
|
||||
|
||||
rect := image.Rect(0, 0, w, h)
|
||||
c := &dxgiCapturer{
|
||||
dup: dup,
|
||||
device: device,
|
||||
ctx: deviceCtx,
|
||||
img: image.NewRGBA(rect),
|
||||
out: [2]*image.RGBA{image.NewRGBA(rect), image.NewRGBA(rect)},
|
||||
width: w,
|
||||
height: h,
|
||||
}
|
||||
|
||||
// Grab the initial frame with a longer timeout to ensure we have
|
||||
// a valid image before returning.
|
||||
_ = dup.GetImage(c.img, 2000)
|
||||
|
||||
return c, nil
|
||||
}
|
||||
|
||||
func (c *dxgiCapturer) capture() (*image.RGBA, error) {
|
||||
err := c.dup.GetImage(c.img, 100)
|
||||
if err != nil && !errors.Is(err, outputduplication.ErrNoImageYet) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Copy into the next output buffer. The DesktopCapturer hands out the
|
||||
// returned pointer to VNC sessions that read pixels concurrently, so we
|
||||
// alternate between two pre-allocated buffers instead of allocating per frame.
|
||||
out := c.out[c.outIdx]
|
||||
c.outIdx ^= 1
|
||||
copy(out.Pix, c.img.Pix)
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (c *dxgiCapturer) close() {
|
||||
if c.dup != nil {
|
||||
c.dup.Release()
|
||||
c.dup = nil
|
||||
}
|
||||
if c.ctx != nil {
|
||||
c.ctx.Release()
|
||||
c.ctx = nil
|
||||
}
|
||||
if c.device != nil {
|
||||
c.device.Release()
|
||||
c.device = nil
|
||||
}
|
||||
}
|
||||
@@ -1,461 +0,0 @@
|
||||
//go:build windows
|
||||
|
||||
package server
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"image"
|
||||
"runtime"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
"unsafe"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/sys/windows"
|
||||
)
|
||||
|
||||
var (
|
||||
gdi32 = windows.NewLazySystemDLL("gdi32.dll")
|
||||
user32 = windows.NewLazySystemDLL("user32.dll")
|
||||
|
||||
procGetDC = user32.NewProc("GetDC")
|
||||
procReleaseDC = user32.NewProc("ReleaseDC")
|
||||
procCreateCompatDC = gdi32.NewProc("CreateCompatibleDC")
|
||||
procCreateDIBSection = gdi32.NewProc("CreateDIBSection")
|
||||
procSelectObject = gdi32.NewProc("SelectObject")
|
||||
procDeleteObject = gdi32.NewProc("DeleteObject")
|
||||
procDeleteDC = gdi32.NewProc("DeleteDC")
|
||||
procBitBlt = gdi32.NewProc("BitBlt")
|
||||
procGetSystemMetrics = user32.NewProc("GetSystemMetrics")
|
||||
|
||||
// Desktop switching for service/Session 0 capture.
|
||||
procOpenInputDesktop = user32.NewProc("OpenInputDesktop")
|
||||
procSetThreadDesktop = user32.NewProc("SetThreadDesktop")
|
||||
procCloseDesktop = user32.NewProc("CloseDesktop")
|
||||
procOpenWindowStation = user32.NewProc("OpenWindowStationW")
|
||||
procSetProcessWindowStation = user32.NewProc("SetProcessWindowStation")
|
||||
procCloseWindowStation = user32.NewProc("CloseWindowStation")
|
||||
procGetUserObjectInformationW = user32.NewProc("GetUserObjectInformationW")
|
||||
)
|
||||
|
||||
const uoiName = 2
|
||||
|
||||
const (
|
||||
smCxScreen = 0
|
||||
smCyScreen = 1
|
||||
srccopy = 0x00CC0020
|
||||
dibRgbColors = 0
|
||||
)
|
||||
|
||||
type bitmapInfoHeader struct {
|
||||
Size uint32
|
||||
Width int32
|
||||
Height int32
|
||||
Planes uint16
|
||||
BitCount uint16
|
||||
Compression uint32
|
||||
SizeImage uint32
|
||||
XPelsPerMeter int32
|
||||
YPelsPerMeter int32
|
||||
ClrUsed uint32
|
||||
ClrImportant uint32
|
||||
}
|
||||
|
||||
type bitmapInfo struct {
|
||||
Header bitmapInfoHeader
|
||||
}
|
||||
|
||||
// setupInteractiveWindowStation associates the current process with WinSta0,
|
||||
// the interactive window station. This is required for a SYSTEM service in
|
||||
// Session 0 to call OpenInputDesktop for screen capture and input injection.
|
||||
func setupInteractiveWindowStation() error {
|
||||
name, err := windows.UTF16PtrFromString("WinSta0")
|
||||
if err != nil {
|
||||
return fmt.Errorf("UTF16 WinSta0: %w", err)
|
||||
}
|
||||
hWinSta, _, err := procOpenWindowStation.Call(
|
||||
uintptr(unsafe.Pointer(name)),
|
||||
0,
|
||||
uintptr(windows.MAXIMUM_ALLOWED),
|
||||
)
|
||||
if hWinSta == 0 {
|
||||
return fmt.Errorf("OpenWindowStation(WinSta0): %w", err)
|
||||
}
|
||||
r, _, err := procSetProcessWindowStation.Call(hWinSta)
|
||||
if r == 0 {
|
||||
procCloseWindowStation.Call(hWinSta)
|
||||
return fmt.Errorf("SetProcessWindowStation: %w", err)
|
||||
}
|
||||
log.Info("process window station set to WinSta0 (interactive)")
|
||||
return nil
|
||||
}
|
||||
|
||||
func screenSize() (int, int) {
|
||||
w, _, _ := procGetSystemMetrics.Call(uintptr(smCxScreen))
|
||||
h, _, _ := procGetSystemMetrics.Call(uintptr(smCyScreen))
|
||||
return int(w), int(h)
|
||||
}
|
||||
|
||||
func getDesktopName(hDesk uintptr) string {
|
||||
var buf [256]uint16
|
||||
var needed uint32
|
||||
procGetUserObjectInformationW.Call(hDesk, uoiName,
|
||||
uintptr(unsafe.Pointer(&buf[0])), 512,
|
||||
uintptr(unsafe.Pointer(&needed)))
|
||||
return windows.UTF16ToString(buf[:])
|
||||
}
|
||||
|
||||
// switchToInputDesktop opens the desktop currently receiving user input
|
||||
// and sets it as the calling OS thread's desktop. Must be called from a
|
||||
// goroutine locked to its OS thread via runtime.LockOSThread().
|
||||
func switchToInputDesktop() (bool, string) {
|
||||
hDesk, _, _ := procOpenInputDesktop.Call(0, 0, uintptr(windows.MAXIMUM_ALLOWED))
|
||||
if hDesk == 0 {
|
||||
return false, ""
|
||||
}
|
||||
name := getDesktopName(hDesk)
|
||||
ret, _, _ := procSetThreadDesktop.Call(hDesk)
|
||||
procCloseDesktop.Call(hDesk)
|
||||
return ret != 0, name
|
||||
}
|
||||
|
||||
// gdiCapturer captures the desktop screen using GDI BitBlt.
|
||||
// GDI objects (DC, DIBSection) are allocated once and reused across frames.
|
||||
type gdiCapturer struct {
|
||||
mu sync.Mutex
|
||||
width int
|
||||
height int
|
||||
|
||||
// Pre-allocated GDI resources, reused across captures.
|
||||
memDC uintptr
|
||||
bmp uintptr
|
||||
bits uintptr
|
||||
}
|
||||
|
||||
func newGDICapturer() (*gdiCapturer, error) {
|
||||
w, h := screenSize()
|
||||
if w == 0 || h == 0 {
|
||||
return nil, fmt.Errorf("screen dimensions are zero")
|
||||
}
|
||||
c := &gdiCapturer{width: w, height: h}
|
||||
if err := c.allocGDI(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return c, nil
|
||||
}
|
||||
|
||||
// allocGDI pre-allocates the compatible DC and DIB section for reuse.
|
||||
func (c *gdiCapturer) allocGDI() error {
|
||||
screenDC, _, _ := procGetDC.Call(0)
|
||||
if screenDC == 0 {
|
||||
return fmt.Errorf("GetDC returned 0")
|
||||
}
|
||||
defer procReleaseDC.Call(0, screenDC)
|
||||
|
||||
memDC, _, _ := procCreateCompatDC.Call(screenDC)
|
||||
if memDC == 0 {
|
||||
return fmt.Errorf("CreateCompatibleDC returned 0")
|
||||
}
|
||||
|
||||
bi := bitmapInfo{
|
||||
Header: bitmapInfoHeader{
|
||||
Size: uint32(unsafe.Sizeof(bitmapInfoHeader{})),
|
||||
Width: int32(c.width),
|
||||
Height: -int32(c.height), // negative = top-down DIB
|
||||
Planes: 1,
|
||||
BitCount: 32,
|
||||
},
|
||||
}
|
||||
|
||||
var bits uintptr
|
||||
bmp, _, _ := procCreateDIBSection.Call(
|
||||
screenDC,
|
||||
uintptr(unsafe.Pointer(&bi)),
|
||||
dibRgbColors,
|
||||
uintptr(unsafe.Pointer(&bits)),
|
||||
0, 0,
|
||||
)
|
||||
if bmp == 0 || bits == 0 {
|
||||
procDeleteDC.Call(memDC)
|
||||
return fmt.Errorf("CreateDIBSection returned 0")
|
||||
}
|
||||
|
||||
procSelectObject.Call(memDC, bmp)
|
||||
|
||||
c.memDC = memDC
|
||||
c.bmp = bmp
|
||||
c.bits = bits
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *gdiCapturer) close() { c.freeGDI() }
|
||||
|
||||
// freeGDI releases pre-allocated GDI resources.
|
||||
func (c *gdiCapturer) freeGDI() {
|
||||
if c.bmp != 0 {
|
||||
procDeleteObject.Call(c.bmp)
|
||||
c.bmp = 0
|
||||
}
|
||||
if c.memDC != 0 {
|
||||
procDeleteDC.Call(c.memDC)
|
||||
c.memDC = 0
|
||||
}
|
||||
c.bits = 0
|
||||
}
|
||||
|
||||
func (c *gdiCapturer) capture() (*image.RGBA, error) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
if c.memDC == 0 {
|
||||
return nil, fmt.Errorf("GDI resources not allocated")
|
||||
}
|
||||
|
||||
screenDC, _, _ := procGetDC.Call(0)
|
||||
if screenDC == 0 {
|
||||
return nil, fmt.Errorf("GetDC returned 0")
|
||||
}
|
||||
defer procReleaseDC.Call(0, screenDC)
|
||||
|
||||
ret, _, _ := procBitBlt.Call(c.memDC, 0, 0, uintptr(c.width), uintptr(c.height),
|
||||
screenDC, 0, 0, srccopy)
|
||||
if ret == 0 {
|
||||
return nil, fmt.Errorf("BitBlt returned 0")
|
||||
}
|
||||
|
||||
n := c.width * c.height * 4
|
||||
raw := unsafe.Slice((*byte)(unsafe.Pointer(c.bits)), n)
|
||||
|
||||
// GDI gives BGRA, the RFB encoder expects RGBA (img.Pix layout).
|
||||
// Swap R and B in bulk using uint32 operations (one load + mask + shift
|
||||
// per pixel instead of three separate byte assignments).
|
||||
img := image.NewRGBA(image.Rect(0, 0, c.width, c.height))
|
||||
pix := img.Pix
|
||||
copy(pix, raw)
|
||||
swizzleBGRAtoRGBA(pix)
|
||||
return img, nil
|
||||
}
|
||||
|
||||
// DesktopCapturer captures the interactive desktop, handling desktop transitions
|
||||
// (login screen, UAC prompts). A dedicated OS-locked goroutine continuously
|
||||
// captures frames, which are retrieved by the VNC session on demand.
|
||||
// Capture pauses automatically when no clients are connected.
|
||||
type DesktopCapturer struct {
|
||||
mu sync.Mutex
|
||||
frame *image.RGBA
|
||||
w, h int
|
||||
|
||||
// clients tracks the number of active VNC sessions. When zero, the
|
||||
// capture loop idles instead of grabbing frames.
|
||||
clients atomic.Int32
|
||||
|
||||
// wake is signaled when a client connects and the loop should resume.
|
||||
wake chan struct{}
|
||||
// done is closed when Close is called, terminating the capture loop.
|
||||
done chan struct{}
|
||||
}
|
||||
|
||||
// NewDesktopCapturer creates a capturer that continuously grabs the active desktop.
|
||||
func NewDesktopCapturer() *DesktopCapturer {
|
||||
c := &DesktopCapturer{
|
||||
wake: make(chan struct{}, 1),
|
||||
done: make(chan struct{}),
|
||||
}
|
||||
go c.loop()
|
||||
return c
|
||||
}
|
||||
|
||||
// ClientConnect increments the active client count, resuming capture if needed.
|
||||
func (c *DesktopCapturer) ClientConnect() {
|
||||
c.clients.Add(1)
|
||||
select {
|
||||
case c.wake <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
// ClientDisconnect decrements the active client count.
|
||||
func (c *DesktopCapturer) ClientDisconnect() {
|
||||
c.clients.Add(-1)
|
||||
}
|
||||
|
||||
// Close stops the capture loop and releases resources.
|
||||
func (c *DesktopCapturer) Close() {
|
||||
select {
|
||||
case <-c.done:
|
||||
default:
|
||||
close(c.done)
|
||||
}
|
||||
}
|
||||
|
||||
// Width returns the current screen width.
|
||||
func (c *DesktopCapturer) Width() int {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
return c.w
|
||||
}
|
||||
|
||||
// Height returns the current screen height.
|
||||
func (c *DesktopCapturer) Height() int {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
return c.h
|
||||
}
|
||||
|
||||
// Capture returns the most recent desktop frame.
|
||||
func (c *DesktopCapturer) Capture() (*image.RGBA, error) {
|
||||
c.mu.Lock()
|
||||
img := c.frame
|
||||
c.mu.Unlock()
|
||||
if img != nil {
|
||||
return img, nil
|
||||
}
|
||||
return nil, fmt.Errorf("no frame available yet")
|
||||
}
|
||||
|
||||
// waitForClient blocks until a client connects or the capturer is closed.
|
||||
func (c *DesktopCapturer) waitForClient() bool {
|
||||
if c.clients.Load() > 0 {
|
||||
return true
|
||||
}
|
||||
select {
|
||||
case <-c.wake:
|
||||
return true
|
||||
case <-c.done:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func (c *DesktopCapturer) loop() {
|
||||
runtime.LockOSThread()
|
||||
|
||||
// When running as a Windows service (Session 0), we need to attach to the
|
||||
// interactive window station before OpenInputDesktop will succeed.
|
||||
if err := setupInteractiveWindowStation(); err != nil {
|
||||
log.Warnf("attach to interactive window station: %v", err)
|
||||
}
|
||||
|
||||
frameTicker := time.NewTicker(33 * time.Millisecond) // ~30 fps
|
||||
defer frameTicker.Stop()
|
||||
|
||||
retryTimer := time.NewTimer(0)
|
||||
retryTimer.Stop()
|
||||
defer retryTimer.Stop()
|
||||
|
||||
type frameCapturer interface {
|
||||
capture() (*image.RGBA, error)
|
||||
close()
|
||||
}
|
||||
|
||||
var cap frameCapturer
|
||||
var desktopFails int
|
||||
var lastDesktop string
|
||||
|
||||
createCapturer := func() (frameCapturer, error) {
|
||||
dc, err := newDXGICapturer()
|
||||
if err == nil {
|
||||
log.Info("using DXGI Desktop Duplication for capture")
|
||||
return dc, nil
|
||||
}
|
||||
log.Debugf("DXGI unavailable (%v), falling back to GDI", err)
|
||||
gc, err := newGDICapturer()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
log.Info("using GDI BitBlt for capture")
|
||||
return gc, nil
|
||||
}
|
||||
|
||||
for {
|
||||
if !c.waitForClient() {
|
||||
if cap != nil {
|
||||
cap.close()
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// No clients: release the capturer and wait.
|
||||
if c.clients.Load() <= 0 {
|
||||
if cap != nil {
|
||||
cap.close()
|
||||
cap = nil
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
ok, desk := switchToInputDesktop()
|
||||
if !ok {
|
||||
desktopFails++
|
||||
if desktopFails == 1 || desktopFails%100 == 0 {
|
||||
log.Warnf("switchToInputDesktop failed (count=%d), no interactive desktop session?", desktopFails)
|
||||
}
|
||||
retryTimer.Reset(100 * time.Millisecond)
|
||||
select {
|
||||
case <-retryTimer.C:
|
||||
case <-c.done:
|
||||
return
|
||||
}
|
||||
continue
|
||||
}
|
||||
if desktopFails > 0 {
|
||||
log.Infof("switchToInputDesktop recovered after %d failures, desktop=%q", desktopFails, desk)
|
||||
desktopFails = 0
|
||||
}
|
||||
if desk != lastDesktop {
|
||||
log.Infof("desktop changed: %q -> %q", lastDesktop, desk)
|
||||
lastDesktop = desk
|
||||
if cap != nil {
|
||||
cap.close()
|
||||
}
|
||||
cap = nil
|
||||
}
|
||||
|
||||
if cap == nil {
|
||||
fc, err := createCapturer()
|
||||
if err != nil {
|
||||
log.Warnf("create capturer: %v", err)
|
||||
retryTimer.Reset(500 * time.Millisecond)
|
||||
select {
|
||||
case <-retryTimer.C:
|
||||
case <-c.done:
|
||||
return
|
||||
}
|
||||
continue
|
||||
}
|
||||
cap = fc
|
||||
w, h := screenSize()
|
||||
c.mu.Lock()
|
||||
c.w, c.h = w, h
|
||||
c.mu.Unlock()
|
||||
log.Infof("screen capturer ready: %dx%d", w, h)
|
||||
}
|
||||
|
||||
img, err := cap.capture()
|
||||
if err != nil {
|
||||
log.Debugf("capture: %v", err)
|
||||
cap.close()
|
||||
cap = nil
|
||||
retryTimer.Reset(100 * time.Millisecond)
|
||||
select {
|
||||
case <-retryTimer.C:
|
||||
case <-c.done:
|
||||
return
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
c.mu.Lock()
|
||||
c.frame = img
|
||||
c.mu.Unlock()
|
||||
|
||||
select {
|
||||
case <-frameTicker.C:
|
||||
case <-c.done:
|
||||
if cap != nil {
|
||||
cap.close()
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,385 +0,0 @@
|
||||
//go:build (linux && !android) || freebsd
|
||||
|
||||
package server
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"image"
|
||||
"os"
|
||||
"os/exec"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/jezek/xgb"
|
||||
"github.com/jezek/xgb/xproto"
|
||||
)
|
||||
|
||||
// X11Capturer captures the screen from an X11 display using the MIT-SHM extension.
|
||||
type X11Capturer struct {
|
||||
mu sync.Mutex
|
||||
conn *xgb.Conn
|
||||
screen *xproto.ScreenInfo
|
||||
w, h int
|
||||
shmID int
|
||||
shmAddr []byte
|
||||
shmSeg uint32 // shm.Seg
|
||||
useSHM bool
|
||||
}
|
||||
|
||||
// detectX11Display finds the active X11 display and sets DISPLAY/XAUTHORITY
|
||||
// environment variables if needed. This is required when running as a system
|
||||
// service where these vars aren't set.
|
||||
func detectX11Display() {
|
||||
if os.Getenv("DISPLAY") != "" {
|
||||
return
|
||||
}
|
||||
|
||||
// Try /proc first (Linux), then ps fallback (FreeBSD and others).
|
||||
if detectX11FromProc() {
|
||||
return
|
||||
}
|
||||
if detectX11FromSockets() {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// detectX11FromProc scans /proc/*/cmdline for Xorg (Linux).
|
||||
func detectX11FromProc() bool {
|
||||
entries, err := os.ReadDir("/proc")
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
for _, e := range entries {
|
||||
if !e.IsDir() {
|
||||
continue
|
||||
}
|
||||
cmdline, err := os.ReadFile("/proc/" + e.Name() + "/cmdline")
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
if display, auth := parseXorgArgs(splitCmdline(cmdline)); display != "" {
|
||||
setDisplayEnv(display, auth)
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// detectX11FromSockets checks /tmp/.X11-unix/ for X sockets and uses ps
|
||||
// to find the auth file. Works on FreeBSD and other systems without /proc.
|
||||
func detectX11FromSockets() bool {
|
||||
entries, err := os.ReadDir("/tmp/.X11-unix")
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
// Find the lowest display number.
|
||||
for _, e := range entries {
|
||||
name := e.Name()
|
||||
if len(name) < 2 || name[0] != 'X' {
|
||||
continue
|
||||
}
|
||||
display := ":" + name[1:]
|
||||
os.Setenv("DISPLAY", display)
|
||||
log.Infof("auto-detected DISPLAY=%s (from socket)", display)
|
||||
|
||||
// Try to find -auth from ps output.
|
||||
if auth := findXorgAuthFromPS(); auth != "" {
|
||||
os.Setenv("XAUTHORITY", auth)
|
||||
log.Infof("auto-detected XAUTHORITY=%s (from ps)", auth)
|
||||
}
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// findXorgAuthFromPS runs ps to find Xorg and extract its -auth argument.
|
||||
func findXorgAuthFromPS() string {
|
||||
out, err := exec.Command("ps", "auxww").Output()
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
for _, line := range strings.Split(string(out), "\n") {
|
||||
if !strings.Contains(line, "Xorg") && !strings.Contains(line, "/X ") {
|
||||
continue
|
||||
}
|
||||
fields := strings.Fields(line)
|
||||
for i, f := range fields {
|
||||
if f == "-auth" && i+1 < len(fields) {
|
||||
return fields[i+1]
|
||||
}
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func parseXorgArgs(args []string) (display, auth string) {
|
||||
if len(args) == 0 {
|
||||
return "", ""
|
||||
}
|
||||
base := args[0]
|
||||
if !(base == "Xorg" || base == "X" || len(base) > 0 && base[len(base)-1] == 'X' ||
|
||||
strings.Contains(base, "/Xorg") || strings.Contains(base, "/X")) {
|
||||
return "", ""
|
||||
}
|
||||
for i, arg := range args[1:] {
|
||||
if len(arg) > 0 && arg[0] == ':' {
|
||||
display = arg
|
||||
}
|
||||
if arg == "-auth" && i+2 < len(args) {
|
||||
auth = args[i+2]
|
||||
}
|
||||
}
|
||||
return display, auth
|
||||
}
|
||||
|
||||
func setDisplayEnv(display, auth string) {
|
||||
os.Setenv("DISPLAY", display)
|
||||
log.Infof("auto-detected DISPLAY=%s", display)
|
||||
if auth != "" {
|
||||
os.Setenv("XAUTHORITY", auth)
|
||||
log.Infof("auto-detected XAUTHORITY=%s", auth)
|
||||
}
|
||||
}
|
||||
|
||||
func splitCmdline(data []byte) []string {
|
||||
var args []string
|
||||
for _, b := range splitNull(data) {
|
||||
if len(b) > 0 {
|
||||
args = append(args, string(b))
|
||||
}
|
||||
}
|
||||
return args
|
||||
}
|
||||
|
||||
func splitNull(data []byte) [][]byte {
|
||||
var parts [][]byte
|
||||
start := 0
|
||||
for i, b := range data {
|
||||
if b == 0 {
|
||||
parts = append(parts, data[start:i])
|
||||
start = i + 1
|
||||
}
|
||||
}
|
||||
if start < len(data) {
|
||||
parts = append(parts, data[start:])
|
||||
}
|
||||
return parts
|
||||
}
|
||||
|
||||
// NewX11Capturer connects to the X11 display and sets up shared memory capture.
|
||||
func NewX11Capturer(display string) (*X11Capturer, error) {
|
||||
detectX11Display()
|
||||
|
||||
if display == "" {
|
||||
display = os.Getenv("DISPLAY")
|
||||
}
|
||||
if display == "" {
|
||||
return nil, fmt.Errorf("DISPLAY not set and no Xorg process found")
|
||||
}
|
||||
|
||||
conn, err := xgb.NewConnDisplay(display)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("connect to X11 display %s: %w", display, err)
|
||||
}
|
||||
|
||||
setup := xproto.Setup(conn)
|
||||
if len(setup.Roots) == 0 {
|
||||
conn.Close()
|
||||
return nil, fmt.Errorf("no X11 screens")
|
||||
}
|
||||
screen := setup.Roots[0]
|
||||
|
||||
c := &X11Capturer{
|
||||
conn: conn,
|
||||
screen: &screen,
|
||||
w: int(screen.WidthInPixels),
|
||||
h: int(screen.HeightInPixels),
|
||||
}
|
||||
|
||||
if err := c.initSHM(); err != nil {
|
||||
log.Debugf("X11 SHM not available, using slow GetImage: %v", err)
|
||||
}
|
||||
|
||||
log.Infof("X11 capturer ready: %dx%d (display=%s, shm=%v)", c.w, c.h, display, c.useSHM)
|
||||
return c, nil
|
||||
}
|
||||
|
||||
// initSHM is implemented in capture_x11_shm_linux.go (requires SysV SHM).
|
||||
// On platforms without SysV SHM (FreeBSD), a stub returns an error and
|
||||
// the capturer falls back to GetImage.
|
||||
|
||||
// Width returns the screen width.
|
||||
func (c *X11Capturer) Width() int { return c.w }
|
||||
|
||||
// Height returns the screen height.
|
||||
func (c *X11Capturer) Height() int { return c.h }
|
||||
|
||||
// Capture returns the current screen as an RGBA image.
|
||||
func (c *X11Capturer) Capture() (*image.RGBA, error) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
if c.useSHM {
|
||||
return c.captureSHM()
|
||||
}
|
||||
return c.captureGetImage()
|
||||
}
|
||||
|
||||
// captureSHM is implemented in capture_x11_shm_linux.go.
|
||||
|
||||
func (c *X11Capturer) captureGetImage() (*image.RGBA, error) {
|
||||
cookie := xproto.GetImage(c.conn, xproto.ImageFormatZPixmap,
|
||||
xproto.Drawable(c.screen.Root),
|
||||
0, 0, uint16(c.w), uint16(c.h), 0xFFFFFFFF)
|
||||
|
||||
reply, err := cookie.Reply()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("GetImage: %w", err)
|
||||
}
|
||||
|
||||
img := image.NewRGBA(image.Rect(0, 0, c.w, c.h))
|
||||
data := reply.Data
|
||||
n := c.w * c.h * 4
|
||||
if len(data) < n {
|
||||
return nil, fmt.Errorf("GetImage returned %d bytes, expected %d", len(data), n)
|
||||
}
|
||||
|
||||
for i := 0; i < n; i += 4 {
|
||||
img.Pix[i+0] = data[i+2] // R
|
||||
img.Pix[i+1] = data[i+1] // G
|
||||
img.Pix[i+2] = data[i+0] // B
|
||||
img.Pix[i+3] = 0xff
|
||||
}
|
||||
return img, nil
|
||||
}
|
||||
|
||||
// Close releases X11 resources.
|
||||
func (c *X11Capturer) Close() {
|
||||
c.closeSHM()
|
||||
c.conn.Close()
|
||||
}
|
||||
|
||||
// closeSHM is implemented in capture_x11_shm_linux.go.
|
||||
|
||||
// X11Poller wraps X11Capturer in a continuous capture loop, matching the
|
||||
// DesktopCapturer pattern from Windows.
|
||||
type X11Poller struct {
|
||||
mu sync.Mutex
|
||||
frame *image.RGBA
|
||||
w, h int
|
||||
display string
|
||||
done chan struct{}
|
||||
}
|
||||
|
||||
// NewX11Poller creates a capturer that continuously grabs the X11 display.
|
||||
func NewX11Poller(display string) *X11Poller {
|
||||
p := &X11Poller{
|
||||
display: display,
|
||||
done: make(chan struct{}),
|
||||
}
|
||||
go p.loop()
|
||||
return p
|
||||
}
|
||||
|
||||
// Close stops the capture loop.
|
||||
func (p *X11Poller) Close() {
|
||||
select {
|
||||
case <-p.done:
|
||||
default:
|
||||
close(p.done)
|
||||
}
|
||||
}
|
||||
|
||||
// Width returns the screen width.
|
||||
func (p *X11Poller) Width() int {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
return p.w
|
||||
}
|
||||
|
||||
// Height returns the screen height.
|
||||
func (p *X11Poller) Height() int {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
return p.h
|
||||
}
|
||||
|
||||
// Capture returns the most recent frame.
|
||||
func (p *X11Poller) Capture() (*image.RGBA, error) {
|
||||
p.mu.Lock()
|
||||
img := p.frame
|
||||
p.mu.Unlock()
|
||||
if img != nil {
|
||||
return img, nil
|
||||
}
|
||||
return nil, fmt.Errorf("no frame available yet")
|
||||
}
|
||||
|
||||
func (p *X11Poller) loop() {
|
||||
var capturer *X11Capturer
|
||||
var initFails int
|
||||
|
||||
defer func() {
|
||||
if capturer != nil {
|
||||
capturer.Close()
|
||||
}
|
||||
}()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-p.done:
|
||||
return
|
||||
default:
|
||||
}
|
||||
|
||||
if capturer == nil {
|
||||
var err error
|
||||
capturer, err = NewX11Capturer(p.display)
|
||||
if err != nil {
|
||||
initFails++
|
||||
if initFails <= maxCapturerRetries {
|
||||
log.Debugf("X11 capturer: %v (attempt %d/%d)", err, initFails, maxCapturerRetries)
|
||||
select {
|
||||
case <-p.done:
|
||||
return
|
||||
case <-time.After(2 * time.Second):
|
||||
}
|
||||
continue
|
||||
}
|
||||
log.Warnf("X11 capturer unavailable after %d attempts, stopping poller", maxCapturerRetries)
|
||||
return
|
||||
}
|
||||
initFails = 0
|
||||
p.mu.Lock()
|
||||
p.w, p.h = capturer.Width(), capturer.Height()
|
||||
p.mu.Unlock()
|
||||
}
|
||||
|
||||
img, err := capturer.Capture()
|
||||
if err != nil {
|
||||
log.Debugf("X11 capture: %v", err)
|
||||
capturer.Close()
|
||||
capturer = nil
|
||||
select {
|
||||
case <-p.done:
|
||||
return
|
||||
case <-time.After(500 * time.Millisecond):
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
p.mu.Lock()
|
||||
p.frame = img
|
||||
p.mu.Unlock()
|
||||
|
||||
select {
|
||||
case <-p.done:
|
||||
return
|
||||
case <-time.After(33 * time.Millisecond): // ~30 fps
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,78 +0,0 @@
|
||||
//go:build linux && !android
|
||||
|
||||
package server
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"image"
|
||||
|
||||
"github.com/jezek/xgb/shm"
|
||||
"github.com/jezek/xgb/xproto"
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
func (c *X11Capturer) initSHM() error {
|
||||
if err := shm.Init(c.conn); err != nil {
|
||||
return fmt.Errorf("init SHM extension: %w", err)
|
||||
}
|
||||
|
||||
size := c.w * c.h * 4
|
||||
id, err := unix.SysvShmGet(unix.IPC_PRIVATE, size, unix.IPC_CREAT|0600)
|
||||
if err != nil {
|
||||
return fmt.Errorf("shmget: %w", err)
|
||||
}
|
||||
|
||||
addr, err := unix.SysvShmAttach(id, 0, 0)
|
||||
if err != nil {
|
||||
unix.SysvShmCtl(id, unix.IPC_RMID, nil)
|
||||
return fmt.Errorf("shmat: %w", err)
|
||||
}
|
||||
|
||||
unix.SysvShmCtl(id, unix.IPC_RMID, nil)
|
||||
|
||||
seg, err := shm.NewSegId(c.conn)
|
||||
if err != nil {
|
||||
unix.SysvShmDetach(addr)
|
||||
return fmt.Errorf("new SHM seg: %w", err)
|
||||
}
|
||||
|
||||
if err := shm.AttachChecked(c.conn, seg, uint32(id), false).Check(); err != nil {
|
||||
unix.SysvShmDetach(addr)
|
||||
return fmt.Errorf("SHM attach to X: %w", err)
|
||||
}
|
||||
|
||||
c.shmID = id
|
||||
c.shmAddr = addr
|
||||
c.shmSeg = uint32(seg)
|
||||
c.useSHM = true
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *X11Capturer) captureSHM() (*image.RGBA, error) {
|
||||
cookie := shm.GetImage(c.conn, xproto.Drawable(c.screen.Root),
|
||||
0, 0, uint16(c.w), uint16(c.h), 0xFFFFFFFF,
|
||||
xproto.ImageFormatZPixmap, shm.Seg(c.shmSeg), 0)
|
||||
|
||||
_, err := cookie.Reply()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("SHM GetImage: %w", err)
|
||||
}
|
||||
|
||||
img := image.NewRGBA(image.Rect(0, 0, c.w, c.h))
|
||||
n := c.w * c.h * 4
|
||||
|
||||
for i := 0; i < n; i += 4 {
|
||||
img.Pix[i+0] = c.shmAddr[i+2] // R
|
||||
img.Pix[i+1] = c.shmAddr[i+1] // G
|
||||
img.Pix[i+2] = c.shmAddr[i+0] // B
|
||||
img.Pix[i+3] = 0xff
|
||||
}
|
||||
return img, nil
|
||||
}
|
||||
|
||||
func (c *X11Capturer) closeSHM() {
|
||||
if c.useSHM {
|
||||
shm.Detach(c.conn, shm.Seg(c.shmSeg))
|
||||
unix.SysvShmDetach(c.shmAddr)
|
||||
}
|
||||
}
|
||||
@@ -1,18 +0,0 @@
|
||||
//go:build freebsd
|
||||
|
||||
package server
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"image"
|
||||
)
|
||||
|
||||
func (c *X11Capturer) initSHM() error {
|
||||
return fmt.Errorf("SysV SHM not available on this platform")
|
||||
}
|
||||
|
||||
func (c *X11Capturer) captureSHM() (*image.RGBA, error) {
|
||||
return nil, fmt.Errorf("SHM capture not available on this platform")
|
||||
}
|
||||
|
||||
func (c *X11Capturer) closeSHM() {}
|
||||
@@ -1,151 +0,0 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"crypto/ecdh"
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
"crypto/sha256"
|
||||
|
||||
"golang.org/x/crypto/hkdf"
|
||||
)
|
||||
|
||||
const (
|
||||
aesKeySize = 32 // AES-256
|
||||
gcmNonceSize = 12
|
||||
)
|
||||
|
||||
// recCrypto holds per-session encryption state.
|
||||
type recCrypto struct {
|
||||
gcm cipher.AEAD
|
||||
frameCounter uint64
|
||||
// ephemeralPub is stored in the recording header so the admin can derive the same key.
|
||||
ephemeralPub []byte
|
||||
}
|
||||
|
||||
// newRecCrypto sets up encryption for a new recording session.
|
||||
// adminPubKeyB64 is the base64-encoded X25519 public key from management settings.
|
||||
func newRecCrypto(adminPubKeyB64 string) (*recCrypto, error) {
|
||||
adminPubBytes, err := base64.StdEncoding.DecodeString(adminPubKeyB64)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("decode admin public key: %w", err)
|
||||
}
|
||||
|
||||
adminPub, err := ecdh.X25519().NewPublicKey(adminPubBytes)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parse admin X25519 public key: %w", err)
|
||||
}
|
||||
|
||||
// Generate ephemeral keypair
|
||||
ephemeral, err := ecdh.X25519().GenerateKey(rand.Reader)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("generate ephemeral key: %w", err)
|
||||
}
|
||||
|
||||
// ECDH shared secret
|
||||
shared, err := ephemeral.ECDH(adminPub)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("ECDH: %w", err)
|
||||
}
|
||||
|
||||
// Derive AES-256 key via HKDF
|
||||
aesKey, err := deriveKey(shared, ephemeral.PublicKey().Bytes())
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("derive key: %w", err)
|
||||
}
|
||||
|
||||
block, err := aes.NewCipher(aesKey)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create AES cipher: %w", err)
|
||||
}
|
||||
|
||||
gcm, err := cipher.NewGCM(block)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create GCM: %w", err)
|
||||
}
|
||||
|
||||
return &recCrypto{
|
||||
gcm: gcm,
|
||||
ephemeralPub: ephemeral.PublicKey().Bytes(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// encrypt encrypts plaintext using a counter-based nonce. Each call increments the counter.
|
||||
func (c *recCrypto) encrypt(plaintext []byte) []byte {
|
||||
nonce := make([]byte, gcmNonceSize)
|
||||
binary.LittleEndian.PutUint64(nonce, c.frameCounter)
|
||||
c.frameCounter++
|
||||
return c.gcm.Seal(nil, nonce, plaintext, nil)
|
||||
}
|
||||
|
||||
// DecryptRecording creates a decryptor from the admin's private key and the ephemeral public key from the header.
|
||||
func DecryptRecording(adminPrivKeyB64 string, ephemeralPubB64 string) (*recDecryptor, error) {
|
||||
adminPrivBytes, err := base64.StdEncoding.DecodeString(adminPrivKeyB64)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("decode admin private key: %w", err)
|
||||
}
|
||||
|
||||
adminPriv, err := ecdh.X25519().NewPrivateKey(adminPrivBytes)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parse admin X25519 private key: %w", err)
|
||||
}
|
||||
|
||||
ephPubBytes, err := base64.StdEncoding.DecodeString(ephemeralPubB64)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("decode ephemeral public key: %w", err)
|
||||
}
|
||||
|
||||
ephPub, err := ecdh.X25519().NewPublicKey(ephPubBytes)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parse ephemeral public key: %w", err)
|
||||
}
|
||||
|
||||
shared, err := adminPriv.ECDH(ephPub)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("ECDH: %w", err)
|
||||
}
|
||||
|
||||
aesKey, err := deriveKey(shared, ephPubBytes)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("derive key: %w", err)
|
||||
}
|
||||
|
||||
block, err := aes.NewCipher(aesKey)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create AES cipher: %w", err)
|
||||
}
|
||||
|
||||
gcm, err := cipher.NewGCM(block)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create GCM: %w", err)
|
||||
}
|
||||
|
||||
return &recDecryptor{gcm: gcm}, nil
|
||||
}
|
||||
|
||||
type recDecryptor struct {
|
||||
gcm cipher.AEAD
|
||||
frameCounter uint64
|
||||
}
|
||||
|
||||
// Decrypt decrypts a frame. Must be called in the same order as encryption.
|
||||
func (d *recDecryptor) Decrypt(ciphertext []byte) ([]byte, error) {
|
||||
nonce := make([]byte, gcmNonceSize)
|
||||
binary.LittleEndian.PutUint64(nonce, d.frameCounter)
|
||||
d.frameCounter++
|
||||
return d.gcm.Open(nil, nonce, ciphertext, nil)
|
||||
}
|
||||
|
||||
func deriveKey(shared, ephemeralPub []byte) ([]byte, error) {
|
||||
hkdfReader := hkdf.New(sha256.New, shared, ephemeralPub, []byte("netbird-recording"))
|
||||
key := make([]byte, aesKeySize)
|
||||
if _, err := io.ReadFull(hkdfReader, key); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return key, nil
|
||||
}
|
||||
@@ -1,129 +0,0 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"crypto/ecdh"
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestCryptoRoundtrip(t *testing.T) {
|
||||
// Generate admin keypair
|
||||
adminPriv, err := ecdh.X25519().GenerateKey(rand.Reader)
|
||||
require.NoError(t, err)
|
||||
|
||||
adminPubB64 := base64.StdEncoding.EncodeToString(adminPriv.PublicKey().Bytes())
|
||||
adminPrivB64 := base64.StdEncoding.EncodeToString(adminPriv.Bytes())
|
||||
|
||||
// Create encryptor (recording side)
|
||||
enc, err := newRecCrypto(adminPubB64)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, enc.ephemeralPub, 32)
|
||||
|
||||
ephPubB64 := base64.StdEncoding.EncodeToString(enc.ephemeralPub)
|
||||
|
||||
// Encrypt some frames
|
||||
plaintext1 := []byte("frame data one - PNG bytes would go here")
|
||||
plaintext2 := []byte("frame data two - different content")
|
||||
plaintext3 := make([]byte, 1024*100) // 100KB frame
|
||||
rand.Read(plaintext3)
|
||||
|
||||
ct1 := enc.encrypt(plaintext1)
|
||||
ct2 := enc.encrypt(plaintext2)
|
||||
ct3 := enc.encrypt(plaintext3)
|
||||
|
||||
// Ciphertext should differ from plaintext
|
||||
assert.NotEqual(t, plaintext1, ct1)
|
||||
// Ciphertext is larger (GCM tag overhead)
|
||||
assert.Greater(t, len(ct1), len(plaintext1))
|
||||
|
||||
// Create decryptor (playback side)
|
||||
dec, err := DecryptRecording(adminPrivB64, ephPubB64)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Decrypt in same order
|
||||
got1, err := dec.Decrypt(ct1)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, plaintext1, got1)
|
||||
|
||||
got2, err := dec.Decrypt(ct2)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, plaintext2, got2)
|
||||
|
||||
got3, err := dec.Decrypt(ct3)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, plaintext3, got3)
|
||||
}
|
||||
|
||||
func TestCryptoWrongKey(t *testing.T) {
|
||||
// Admin key
|
||||
adminPriv, err := ecdh.X25519().GenerateKey(rand.Reader)
|
||||
require.NoError(t, err)
|
||||
adminPubB64 := base64.StdEncoding.EncodeToString(adminPriv.PublicKey().Bytes())
|
||||
|
||||
// Encrypt with admin's public key
|
||||
enc, err := newRecCrypto(adminPubB64)
|
||||
require.NoError(t, err)
|
||||
ephPubB64 := base64.StdEncoding.EncodeToString(enc.ephemeralPub)
|
||||
|
||||
ct := enc.encrypt([]byte("secret frame data"))
|
||||
|
||||
// Try to decrypt with a different private key
|
||||
wrongPriv, err := ecdh.X25519().GenerateKey(rand.Reader)
|
||||
require.NoError(t, err)
|
||||
wrongPrivB64 := base64.StdEncoding.EncodeToString(wrongPriv.Bytes())
|
||||
|
||||
dec, err := DecryptRecording(wrongPrivB64, ephPubB64)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = dec.Decrypt(ct)
|
||||
assert.Error(t, err, "decryption with wrong key should fail")
|
||||
}
|
||||
|
||||
func TestCryptoInvalidKey(t *testing.T) {
|
||||
_, err := newRecCrypto("")
|
||||
assert.Error(t, err, "empty key should fail")
|
||||
|
||||
_, err = newRecCrypto("not-base64!!!")
|
||||
assert.Error(t, err, "invalid base64 should fail")
|
||||
|
||||
_, err = newRecCrypto(base64.StdEncoding.EncodeToString([]byte("too-short")))
|
||||
assert.Error(t, err, "wrong-length key should fail")
|
||||
|
||||
_, err = DecryptRecording("", "validbutirrelevant")
|
||||
assert.Error(t, err, "empty private key should fail")
|
||||
|
||||
_, err = DecryptRecording("not-base64!!!", base64.StdEncoding.EncodeToString(make([]byte, 32)))
|
||||
assert.Error(t, err, "invalid base64 private key should fail")
|
||||
}
|
||||
|
||||
func TestCryptoOutOfOrderFails(t *testing.T) {
|
||||
adminPriv, err := ecdh.X25519().GenerateKey(rand.Reader)
|
||||
require.NoError(t, err)
|
||||
adminPubB64 := base64.StdEncoding.EncodeToString(adminPriv.PublicKey().Bytes())
|
||||
adminPrivB64 := base64.StdEncoding.EncodeToString(adminPriv.Bytes())
|
||||
|
||||
enc, err := newRecCrypto(adminPubB64)
|
||||
require.NoError(t, err)
|
||||
ephPubB64 := base64.StdEncoding.EncodeToString(enc.ephemeralPub)
|
||||
|
||||
ct0 := enc.encrypt([]byte("frame 0"))
|
||||
ct1 := enc.encrypt([]byte("frame 1"))
|
||||
|
||||
dec, err := DecryptRecording(adminPrivB64, ephPubB64)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Skip frame 0, try to decrypt frame 1 first (wrong nonce)
|
||||
_, err = dec.Decrypt(ct1)
|
||||
assert.Error(t, err, "out-of-order decryption should fail due to nonce mismatch")
|
||||
|
||||
// But frame 0 with a fresh decryptor should work
|
||||
dec2, err := DecryptRecording(adminPrivB64, ephPubB64)
|
||||
require.NoError(t, err)
|
||||
got, err := dec2.Decrypt(ct0)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, []byte("frame 0"), got)
|
||||
}
|
||||
@@ -1,540 +0,0 @@
|
||||
//go:build darwin && !ios
|
||||
|
||||
package server
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os/exec"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/ebitengine/purego"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// Core Graphics event constants.
|
||||
const (
|
||||
kCGEventSourceStateCombinedSessionState int32 = 0
|
||||
|
||||
kCGEventLeftMouseDown int32 = 1
|
||||
kCGEventLeftMouseUp int32 = 2
|
||||
kCGEventRightMouseDown int32 = 3
|
||||
kCGEventRightMouseUp int32 = 4
|
||||
kCGEventMouseMoved int32 = 5
|
||||
kCGEventLeftMouseDragged int32 = 6
|
||||
kCGEventRightMouseDragged int32 = 7
|
||||
kCGEventKeyDown int32 = 10
|
||||
kCGEventKeyUp int32 = 11
|
||||
kCGEventOtherMouseDown int32 = 25
|
||||
kCGEventOtherMouseUp int32 = 26
|
||||
|
||||
kCGMouseButtonLeft int32 = 0
|
||||
kCGMouseButtonRight int32 = 1
|
||||
kCGMouseButtonCenter int32 = 2
|
||||
|
||||
kCGHIDEventTap int32 = 0
|
||||
|
||||
// IOKit power management constants.
|
||||
kIOPMUserActiveLocal int32 = 0
|
||||
kIOPMAssertionLevelOn uint32 = 255
|
||||
kCFStringEncodingUTF8 uint32 = 0x08000100
|
||||
)
|
||||
|
||||
var darwinInputOnce sync.Once
|
||||
|
||||
var (
|
||||
cgEventSourceCreate func(int32) uintptr
|
||||
cgEventCreateKeyboardEvent func(uintptr, uint16, bool) uintptr
|
||||
// CGEventCreateMouseEvent takes CGPoint as two separate float64 args.
|
||||
// purego can't handle array/struct types but individual float64s work.
|
||||
cgEventCreateMouseEvent func(uintptr, int32, float64, float64, int32) uintptr
|
||||
cgEventPost func(int32, uintptr)
|
||||
|
||||
// CGEventCreateScrollWheelEvent is variadic, call via SyscallN.
|
||||
cgEventCreateScrollWheelEventAddr uintptr
|
||||
|
||||
axIsProcessTrusted func() bool
|
||||
|
||||
// IOKit power-management bindings used to wake the display and inhibit
|
||||
// idle sleep while a VNC client is driving input.
|
||||
iopmAssertionDeclareUserActivity func(uintptr, int32, *uint32) int32
|
||||
iopmAssertionCreateWithName func(uintptr, uint32, uintptr, *uint32) int32
|
||||
iopmAssertionRelease func(uint32) int32
|
||||
cfStringCreateWithCString func(uintptr, string, uint32) uintptr
|
||||
|
||||
// Cached CFStrings for assertion name and idle-sleep type.
|
||||
pmAssertionNameCFStr uintptr
|
||||
pmPreventIdleDisplayCFStr uintptr
|
||||
|
||||
// Assertion IDs. userActivityID is reused across input events so repeated
|
||||
// calls refresh the same assertion rather than create new ones.
|
||||
pmMu sync.Mutex
|
||||
userActivityID uint32
|
||||
preventSleepID uint32
|
||||
preventSleepHeld bool
|
||||
|
||||
darwinInputReady bool
|
||||
darwinEventSource uintptr
|
||||
)
|
||||
|
||||
func initDarwinInput() {
|
||||
darwinInputOnce.Do(func() {
|
||||
cg, err := purego.Dlopen("/System/Library/Frameworks/CoreGraphics.framework/CoreGraphics", purego.RTLD_NOW|purego.RTLD_GLOBAL)
|
||||
if err != nil {
|
||||
log.Debugf("load CoreGraphics for input: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
purego.RegisterLibFunc(&cgEventSourceCreate, cg, "CGEventSourceCreate")
|
||||
purego.RegisterLibFunc(&cgEventCreateKeyboardEvent, cg, "CGEventCreateKeyboardEvent")
|
||||
purego.RegisterLibFunc(&cgEventCreateMouseEvent, cg, "CGEventCreateMouseEvent")
|
||||
purego.RegisterLibFunc(&cgEventPost, cg, "CGEventPost")
|
||||
|
||||
sym, err := purego.Dlsym(cg, "CGEventCreateScrollWheelEvent")
|
||||
if err == nil {
|
||||
cgEventCreateScrollWheelEventAddr = sym
|
||||
}
|
||||
|
||||
if ax, err := purego.Dlopen("/System/Library/Frameworks/ApplicationServices.framework/ApplicationServices", purego.RTLD_NOW|purego.RTLD_GLOBAL); err == nil {
|
||||
if sym, err := purego.Dlsym(ax, "AXIsProcessTrusted"); err == nil {
|
||||
purego.RegisterFunc(&axIsProcessTrusted, sym)
|
||||
}
|
||||
}
|
||||
|
||||
initPowerAssertions()
|
||||
|
||||
darwinInputReady = true
|
||||
})
|
||||
}
|
||||
|
||||
func initPowerAssertions() {
|
||||
iokit, err := purego.Dlopen("/System/Library/Frameworks/IOKit.framework/IOKit", purego.RTLD_NOW|purego.RTLD_GLOBAL)
|
||||
if err != nil {
|
||||
log.Debugf("load IOKit: %v", err)
|
||||
return
|
||||
}
|
||||
cf, err := purego.Dlopen("/System/Library/Frameworks/CoreFoundation.framework/CoreFoundation", purego.RTLD_NOW|purego.RTLD_GLOBAL)
|
||||
if err != nil {
|
||||
log.Debugf("load CoreFoundation for power assertions: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
purego.RegisterLibFunc(&cfStringCreateWithCString, cf, "CFStringCreateWithCString")
|
||||
purego.RegisterLibFunc(&iopmAssertionDeclareUserActivity, iokit, "IOPMAssertionDeclareUserActivity")
|
||||
purego.RegisterLibFunc(&iopmAssertionCreateWithName, iokit, "IOPMAssertionCreateWithName")
|
||||
purego.RegisterLibFunc(&iopmAssertionRelease, iokit, "IOPMAssertionRelease")
|
||||
|
||||
pmAssertionNameCFStr = cfStringCreateWithCString(0, "NetBird VNC input", kCFStringEncodingUTF8)
|
||||
pmPreventIdleDisplayCFStr = cfStringCreateWithCString(0, "PreventUserIdleDisplaySleep", kCFStringEncodingUTF8)
|
||||
}
|
||||
|
||||
// wakeDisplay declares user activity so macOS treats the synthesized input as
|
||||
// real HID activity, waking the display if it is asleep. Called on every key
|
||||
// and pointer event; the kernel coalesces repeated calls cheaply.
|
||||
func wakeDisplay() {
|
||||
if iopmAssertionDeclareUserActivity == nil || pmAssertionNameCFStr == 0 {
|
||||
return
|
||||
}
|
||||
pmMu.Lock()
|
||||
id := userActivityID
|
||||
pmMu.Unlock()
|
||||
r := iopmAssertionDeclareUserActivity(pmAssertionNameCFStr, kIOPMUserActiveLocal, &id)
|
||||
if r != 0 {
|
||||
log.Tracef("IOPMAssertionDeclareUserActivity returned %d", r)
|
||||
return
|
||||
}
|
||||
pmMu.Lock()
|
||||
userActivityID = id
|
||||
pmMu.Unlock()
|
||||
}
|
||||
|
||||
// holdPreventIdleSleep creates an assertion that keeps the display from going
|
||||
// idle-to-sleep while a VNC session is active. Safe to call repeatedly.
|
||||
func holdPreventIdleSleep() {
|
||||
if iopmAssertionCreateWithName == nil || pmPreventIdleDisplayCFStr == 0 || pmAssertionNameCFStr == 0 {
|
||||
return
|
||||
}
|
||||
pmMu.Lock()
|
||||
defer pmMu.Unlock()
|
||||
if preventSleepHeld {
|
||||
return
|
||||
}
|
||||
var id uint32
|
||||
r := iopmAssertionCreateWithName(pmPreventIdleDisplayCFStr, kIOPMAssertionLevelOn, pmAssertionNameCFStr, &id)
|
||||
if r != 0 {
|
||||
log.Debugf("IOPMAssertionCreateWithName returned %d", r)
|
||||
return
|
||||
}
|
||||
preventSleepID = id
|
||||
preventSleepHeld = true
|
||||
}
|
||||
|
||||
// releasePreventIdleSleep drops the idle-sleep assertion.
|
||||
func releasePreventIdleSleep() {
|
||||
if iopmAssertionRelease == nil {
|
||||
return
|
||||
}
|
||||
pmMu.Lock()
|
||||
defer pmMu.Unlock()
|
||||
if !preventSleepHeld {
|
||||
return
|
||||
}
|
||||
if r := iopmAssertionRelease(preventSleepID); r != 0 {
|
||||
log.Debugf("IOPMAssertionRelease returned %d", r)
|
||||
}
|
||||
preventSleepHeld = false
|
||||
preventSleepID = 0
|
||||
}
|
||||
|
||||
func ensureEventSource() uintptr {
|
||||
if darwinEventSource != 0 {
|
||||
return darwinEventSource
|
||||
}
|
||||
darwinEventSource = cgEventSourceCreate(kCGEventSourceStateCombinedSessionState)
|
||||
return darwinEventSource
|
||||
}
|
||||
|
||||
// MacInputInjector injects keyboard and mouse events via Core Graphics.
|
||||
type MacInputInjector struct {
|
||||
lastButtons uint8
|
||||
pbcopyPath string
|
||||
pbpastePath string
|
||||
}
|
||||
|
||||
// NewMacInputInjector creates a macOS input injector.
|
||||
func NewMacInputInjector() (*MacInputInjector, error) {
|
||||
initDarwinInput()
|
||||
if !darwinInputReady {
|
||||
return nil, fmt.Errorf("CoreGraphics not available for input injection")
|
||||
}
|
||||
checkMacPermissions()
|
||||
|
||||
m := &MacInputInjector{}
|
||||
if path, err := exec.LookPath("pbcopy"); err == nil {
|
||||
m.pbcopyPath = path
|
||||
}
|
||||
if path, err := exec.LookPath("pbpaste"); err == nil {
|
||||
m.pbpastePath = path
|
||||
}
|
||||
if m.pbcopyPath == "" || m.pbpastePath == "" {
|
||||
log.Debugf("clipboard tools not found (pbcopy=%q, pbpaste=%q)", m.pbcopyPath, m.pbpastePath)
|
||||
}
|
||||
|
||||
holdPreventIdleSleep()
|
||||
|
||||
log.Info("macOS input injector ready")
|
||||
return m, nil
|
||||
}
|
||||
|
||||
// checkMacPermissions warns and opens the Privacy pane if Accessibility is
|
||||
// missing. Uses AXIsProcessTrusted which returns immediately; the previous
|
||||
// osascript probe blocked for 120s (AppleEvent timeout) when access was
|
||||
// denied, which delayed VNC server startup past client deadlines.
|
||||
func checkMacPermissions() {
|
||||
if axIsProcessTrusted != nil && !axIsProcessTrusted() {
|
||||
openPrivacyPane("Privacy_Accessibility")
|
||||
log.Warn("Accessibility permission not granted. Input injection will not work. " +
|
||||
"Opened System Settings > Privacy & Security > Accessibility; enable netbird.")
|
||||
}
|
||||
|
||||
log.Info("Screen Recording permission is required for screen capture. " +
|
||||
"If the screen appears black, grant in System Settings > Privacy & Security > Screen Recording.")
|
||||
}
|
||||
|
||||
// openPrivacyPane opens the given Privacy pane in System Settings so the user
|
||||
// can toggle the permission without navigating manually.
|
||||
func openPrivacyPane(pane string) {
|
||||
url := "x-apple.systempreferences:com.apple.preference.security?" + pane
|
||||
if err := exec.Command("open", url).Start(); err != nil {
|
||||
log.Debugf("open privacy pane %s: %v", pane, err)
|
||||
}
|
||||
}
|
||||
|
||||
// InjectKey simulates a key press or release.
|
||||
func (m *MacInputInjector) InjectKey(keysym uint32, down bool) {
|
||||
wakeDisplay()
|
||||
src := ensureEventSource()
|
||||
if src == 0 {
|
||||
return
|
||||
}
|
||||
keycode := keysymToMacKeycode(keysym)
|
||||
if keycode == 0xFFFF {
|
||||
return
|
||||
}
|
||||
event := cgEventCreateKeyboardEvent(src, keycode, down)
|
||||
if event == 0 {
|
||||
return
|
||||
}
|
||||
cgEventPost(kCGHIDEventTap, event)
|
||||
cfRelease(event)
|
||||
}
|
||||
|
||||
// InjectPointer simulates mouse movement and button events.
|
||||
func (m *MacInputInjector) InjectPointer(buttonMask uint8, px, py, serverW, serverH int) {
|
||||
wakeDisplay()
|
||||
if serverW == 0 || serverH == 0 {
|
||||
return
|
||||
}
|
||||
src := ensureEventSource()
|
||||
if src == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
// Framebuffer is in physical pixels (Retina). CGEventCreateMouseEvent
|
||||
// expects logical points, so scale down by the display's pixel/point ratio.
|
||||
x := float64(px)
|
||||
y := float64(py)
|
||||
if cgDisplayPixelsWide != nil && cgMainDisplayID != nil {
|
||||
displayID := cgMainDisplayID()
|
||||
logicalW := int(cgDisplayPixelsWide(displayID))
|
||||
logicalH := int(cgDisplayPixelsHigh(displayID))
|
||||
if logicalW > 0 && logicalH > 0 {
|
||||
x = float64(px) * float64(logicalW) / float64(serverW)
|
||||
y = float64(py) * float64(logicalH) / float64(serverH)
|
||||
}
|
||||
}
|
||||
leftDown := buttonMask&0x01 != 0
|
||||
rightDown := buttonMask&0x04 != 0
|
||||
middleDown := buttonMask&0x02 != 0
|
||||
scrollUp := buttonMask&0x08 != 0
|
||||
scrollDown := buttonMask&0x10 != 0
|
||||
|
||||
wasLeft := m.lastButtons&0x01 != 0
|
||||
wasRight := m.lastButtons&0x04 != 0
|
||||
wasMiddle := m.lastButtons&0x02 != 0
|
||||
|
||||
if leftDown {
|
||||
m.postMouse(src, kCGEventLeftMouseDragged, x, y, kCGMouseButtonLeft)
|
||||
} else if rightDown {
|
||||
m.postMouse(src, kCGEventRightMouseDragged, x, y, kCGMouseButtonRight)
|
||||
} else {
|
||||
m.postMouse(src, kCGEventMouseMoved, x, y, kCGMouseButtonLeft)
|
||||
}
|
||||
|
||||
if leftDown && !wasLeft {
|
||||
m.postMouse(src, kCGEventLeftMouseDown, x, y, kCGMouseButtonLeft)
|
||||
} else if !leftDown && wasLeft {
|
||||
m.postMouse(src, kCGEventLeftMouseUp, x, y, kCGMouseButtonLeft)
|
||||
}
|
||||
if rightDown && !wasRight {
|
||||
m.postMouse(src, kCGEventRightMouseDown, x, y, kCGMouseButtonRight)
|
||||
} else if !rightDown && wasRight {
|
||||
m.postMouse(src, kCGEventRightMouseUp, x, y, kCGMouseButtonRight)
|
||||
}
|
||||
if middleDown && !wasMiddle {
|
||||
m.postMouse(src, kCGEventOtherMouseDown, x, y, kCGMouseButtonCenter)
|
||||
} else if !middleDown && wasMiddle {
|
||||
m.postMouse(src, kCGEventOtherMouseUp, x, y, kCGMouseButtonCenter)
|
||||
}
|
||||
|
||||
if scrollUp {
|
||||
m.postScroll(src, 3)
|
||||
}
|
||||
if scrollDown {
|
||||
m.postScroll(src, -3)
|
||||
}
|
||||
|
||||
m.lastButtons = buttonMask
|
||||
}
|
||||
|
||||
func (m *MacInputInjector) postMouse(src uintptr, eventType int32, x, y float64, button int32) {
|
||||
if cgEventCreateMouseEvent == nil {
|
||||
return
|
||||
}
|
||||
event := cgEventCreateMouseEvent(src, eventType, x, y, button)
|
||||
if event == 0 {
|
||||
return
|
||||
}
|
||||
cgEventPost(kCGHIDEventTap, event)
|
||||
cfRelease(event)
|
||||
}
|
||||
|
||||
func (m *MacInputInjector) postScroll(src uintptr, deltaY int32) {
|
||||
if cgEventCreateScrollWheelEventAddr == 0 {
|
||||
return
|
||||
}
|
||||
// CGEventCreateScrollWheelEvent(source, units, wheelCount, wheel1delta)
|
||||
// units=0 (pixel), wheelCount=1, wheel1delta=deltaY
|
||||
// Variadic C function: pass args as uintptr via SyscallN.
|
||||
r1, _, _ := purego.SyscallN(cgEventCreateScrollWheelEventAddr,
|
||||
src, 0, 1, uintptr(uint32(deltaY)))
|
||||
if r1 == 0 {
|
||||
return
|
||||
}
|
||||
cgEventPost(kCGHIDEventTap, r1)
|
||||
cfRelease(r1)
|
||||
}
|
||||
|
||||
// SetClipboard sets the macOS clipboard using pbcopy.
|
||||
func (m *MacInputInjector) SetClipboard(text string) {
|
||||
if m.pbcopyPath == "" {
|
||||
return
|
||||
}
|
||||
cmd := exec.Command(m.pbcopyPath)
|
||||
cmd.Stdin = strings.NewReader(text)
|
||||
if err := cmd.Run(); err != nil {
|
||||
log.Tracef("set clipboard via pbcopy: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// GetClipboard reads the macOS clipboard using pbpaste.
|
||||
func (m *MacInputInjector) GetClipboard() string {
|
||||
if m.pbpastePath == "" {
|
||||
return ""
|
||||
}
|
||||
out, err := exec.Command(m.pbpastePath).Output()
|
||||
if err != nil {
|
||||
log.Tracef("get clipboard via pbpaste: %v", err)
|
||||
return ""
|
||||
}
|
||||
return string(out)
|
||||
}
|
||||
|
||||
// Close releases the idle-sleep assertion held for the injector's lifetime.
|
||||
func (m *MacInputInjector) Close() {
|
||||
releasePreventIdleSleep()
|
||||
}
|
||||
|
||||
func keysymToMacKeycode(keysym uint32) uint16 {
|
||||
if keysym >= 0x61 && keysym <= 0x7a {
|
||||
return asciiToMacKey[keysym-0x61]
|
||||
}
|
||||
if keysym >= 0x41 && keysym <= 0x5a {
|
||||
return asciiToMacKey[keysym-0x41]
|
||||
}
|
||||
if keysym >= 0x30 && keysym <= 0x39 {
|
||||
return digitToMacKey[keysym-0x30]
|
||||
}
|
||||
if code, ok := specialKeyMap[keysym]; ok {
|
||||
return code
|
||||
}
|
||||
return 0xFFFF
|
||||
}
|
||||
|
||||
var asciiToMacKey = [26]uint16{
|
||||
0x00, 0x0B, 0x08, 0x02, 0x0E, 0x03, 0x05, 0x04,
|
||||
0x22, 0x26, 0x28, 0x25, 0x2E, 0x2D, 0x1F, 0x23,
|
||||
0x0C, 0x0F, 0x01, 0x11, 0x20, 0x09, 0x0D, 0x07,
|
||||
0x10, 0x06,
|
||||
}
|
||||
|
||||
var digitToMacKey = [10]uint16{
|
||||
0x1D, 0x12, 0x13, 0x14, 0x15, 0x17, 0x16, 0x1A, 0x1C, 0x19,
|
||||
}
|
||||
|
||||
var specialKeyMap = map[uint32]uint16{
|
||||
// Whitespace and editing
|
||||
0x0020: 0x31, // space
|
||||
0xff08: 0x33, // BackSpace
|
||||
0xff09: 0x30, // Tab
|
||||
0xff0d: 0x24, // Return
|
||||
0xff1b: 0x35, // Escape
|
||||
0xffff: 0x75, // Delete (forward)
|
||||
|
||||
// Navigation
|
||||
0xff50: 0x73, // Home
|
||||
0xff51: 0x7B, // Left
|
||||
0xff52: 0x7E, // Up
|
||||
0xff53: 0x7C, // Right
|
||||
0xff54: 0x7D, // Down
|
||||
0xff55: 0x74, // Page_Up
|
||||
0xff56: 0x79, // Page_Down
|
||||
0xff57: 0x77, // End
|
||||
0xff63: 0x72, // Insert (Help on Mac)
|
||||
|
||||
// Modifiers
|
||||
0xffe1: 0x38, // Shift_L
|
||||
0xffe2: 0x3C, // Shift_R
|
||||
0xffe3: 0x3B, // Control_L
|
||||
0xffe4: 0x3E, // Control_R
|
||||
0xffe5: 0x39, // Caps_Lock
|
||||
0xffe9: 0x3A, // Alt_L (Option)
|
||||
0xffea: 0x3D, // Alt_R (Option)
|
||||
0xffe7: 0x37, // Meta_L (Command)
|
||||
0xffe8: 0x36, // Meta_R (Command)
|
||||
0xffeb: 0x37, // Super_L (Command) - noVNC sends this
|
||||
0xffec: 0x36, // Super_R (Command)
|
||||
|
||||
// Mode_switch / ISO_Level3_Shift (sent by noVNC for macOS Option remap)
|
||||
0xff7e: 0x3A, // Mode_switch -> Option
|
||||
0xfe03: 0x3D, // ISO_Level3_Shift -> Right Option
|
||||
|
||||
// Function keys
|
||||
0xffbe: 0x7A, // F1
|
||||
0xffbf: 0x78, // F2
|
||||
0xffc0: 0x63, // F3
|
||||
0xffc1: 0x76, // F4
|
||||
0xffc2: 0x60, // F5
|
||||
0xffc3: 0x61, // F6
|
||||
0xffc4: 0x62, // F7
|
||||
0xffc5: 0x64, // F8
|
||||
0xffc6: 0x65, // F9
|
||||
0xffc7: 0x6D, // F10
|
||||
0xffc8: 0x67, // F11
|
||||
0xffc9: 0x6F, // F12
|
||||
0xffca: 0x69, // F13
|
||||
0xffcb: 0x6B, // F14
|
||||
0xffcc: 0x71, // F15
|
||||
0xffcd: 0x6A, // F16
|
||||
0xffce: 0x40, // F17
|
||||
0xffcf: 0x4F, // F18
|
||||
0xffd0: 0x50, // F19
|
||||
0xffd1: 0x5A, // F20
|
||||
|
||||
// Punctuation (US keyboard layout, keysym = ASCII code)
|
||||
0x002d: 0x1B, // minus -
|
||||
0x003d: 0x18, // equal =
|
||||
0x005b: 0x21, // bracketleft [
|
||||
0x005d: 0x1E, // bracketright ]
|
||||
0x005c: 0x2A, // backslash
|
||||
0x003b: 0x29, // semicolon ;
|
||||
0x0027: 0x27, // apostrophe '
|
||||
0x0060: 0x32, // grave `
|
||||
0x002c: 0x2B, // comma ,
|
||||
0x002e: 0x2F, // period .
|
||||
0x002f: 0x2C, // slash /
|
||||
|
||||
// Shifted punctuation (noVNC sends these as separate keysyms)
|
||||
0x005f: 0x1B, // underscore _ (shift+minus)
|
||||
0x002b: 0x18, // plus + (shift+equal)
|
||||
0x007b: 0x21, // braceleft { (shift+[)
|
||||
0x007d: 0x1E, // braceright } (shift+])
|
||||
0x007c: 0x2A, // bar | (shift+\)
|
||||
0x003a: 0x29, // colon : (shift+;)
|
||||
0x0022: 0x27, // quotedbl " (shift+')
|
||||
0x007e: 0x32, // tilde ~ (shift+`)
|
||||
0x003c: 0x2B, // less < (shift+,)
|
||||
0x003e: 0x2F, // greater > (shift+.)
|
||||
0x003f: 0x2C, // question ? (shift+/)
|
||||
0x0021: 0x12, // exclam ! (shift+1)
|
||||
0x0040: 0x13, // at @ (shift+2)
|
||||
0x0023: 0x14, // numbersign # (shift+3)
|
||||
0x0024: 0x15, // dollar $ (shift+4)
|
||||
0x0025: 0x17, // percent % (shift+5)
|
||||
0x005e: 0x16, // asciicircum ^ (shift+6)
|
||||
0x0026: 0x1A, // ampersand & (shift+7)
|
||||
0x002a: 0x1C, // asterisk * (shift+8)
|
||||
0x0028: 0x19, // parenleft ( (shift+9)
|
||||
0x0029: 0x1D, // parenright ) (shift+0)
|
||||
|
||||
// Numpad
|
||||
0xffb0: 0x52, // KP_0
|
||||
0xffb1: 0x53, // KP_1
|
||||
0xffb2: 0x54, // KP_2
|
||||
0xffb3: 0x55, // KP_3
|
||||
0xffb4: 0x56, // KP_4
|
||||
0xffb5: 0x57, // KP_5
|
||||
0xffb6: 0x58, // KP_6
|
||||
0xffb7: 0x59, // KP_7
|
||||
0xffb8: 0x5B, // KP_8
|
||||
0xffb9: 0x5C, // KP_9
|
||||
0xffae: 0x41, // KP_Decimal
|
||||
0xffaa: 0x43, // KP_Multiply
|
||||
0xffab: 0x45, // KP_Add
|
||||
0xffad: 0x4E, // KP_Subtract
|
||||
0xffaf: 0x4B, // KP_Divide
|
||||
0xff8d: 0x4C, // KP_Enter
|
||||
0xffbd: 0x51, // KP_Equal
|
||||
}
|
||||
|
||||
var _ InputInjector = (*MacInputInjector)(nil)
|
||||
@@ -1,398 +0,0 @@
|
||||
//go:build windows
|
||||
|
||||
package server
|
||||
|
||||
import (
|
||||
"runtime"
|
||||
"unsafe"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/sys/windows"
|
||||
)
|
||||
|
||||
var (
|
||||
procOpenEventW = kernel32.NewProc("OpenEventW")
|
||||
procSendInput = user32.NewProc("SendInput")
|
||||
procVkKeyScanA = user32.NewProc("VkKeyScanA")
|
||||
)
|
||||
|
||||
const eventModifyState = 0x0002
|
||||
|
||||
const (
|
||||
inputMouse = 0
|
||||
inputKeyboard = 1
|
||||
|
||||
mouseeventfMove = 0x0001
|
||||
mouseeventfLeftDown = 0x0002
|
||||
mouseeventfLeftUp = 0x0004
|
||||
mouseeventfRightDown = 0x0008
|
||||
mouseeventfRightUp = 0x0010
|
||||
mouseeventfMiddleDown = 0x0020
|
||||
mouseeventfMiddleUp = 0x0040
|
||||
mouseeventfWheel = 0x0800
|
||||
mouseeventfAbsolute = 0x8000
|
||||
|
||||
wheelDelta = 120
|
||||
|
||||
keyeventfKeyUp = 0x0002
|
||||
keyeventfScanCode = 0x0008
|
||||
)
|
||||
|
||||
type mouseInput struct {
|
||||
Dx int32
|
||||
Dy int32
|
||||
MouseData uint32
|
||||
DwFlags uint32
|
||||
Time uint32
|
||||
DwExtraInfo uintptr
|
||||
}
|
||||
|
||||
type keybdInput struct {
|
||||
WVk uint16
|
||||
WScan uint16
|
||||
DwFlags uint32
|
||||
Time uint32
|
||||
DwExtraInfo uintptr
|
||||
_ [8]byte
|
||||
}
|
||||
|
||||
type inputUnion [32]byte
|
||||
|
||||
type winInput struct {
|
||||
Type uint32
|
||||
_ [4]byte
|
||||
Data inputUnion
|
||||
}
|
||||
|
||||
func sendMouseInput(flags uint32, dx, dy int32, mouseData uint32) {
|
||||
mi := mouseInput{
|
||||
Dx: dx,
|
||||
Dy: dy,
|
||||
MouseData: mouseData,
|
||||
DwFlags: flags,
|
||||
}
|
||||
inp := winInput{Type: inputMouse}
|
||||
copy(inp.Data[:], (*[unsafe.Sizeof(mi)]byte)(unsafe.Pointer(&mi))[:])
|
||||
r, _, err := procSendInput.Call(1, uintptr(unsafe.Pointer(&inp)), unsafe.Sizeof(inp))
|
||||
if r == 0 {
|
||||
log.Tracef("SendInput(mouse flags=0x%x): %v", flags, err)
|
||||
}
|
||||
}
|
||||
|
||||
func sendKeyInput(vk uint16, scanCode uint16, flags uint32) {
|
||||
ki := keybdInput{
|
||||
WVk: vk,
|
||||
WScan: scanCode,
|
||||
DwFlags: flags,
|
||||
}
|
||||
inp := winInput{Type: inputKeyboard}
|
||||
copy(inp.Data[:], (*[unsafe.Sizeof(ki)]byte)(unsafe.Pointer(&ki))[:])
|
||||
r, _, err := procSendInput.Call(1, uintptr(unsafe.Pointer(&inp)), unsafe.Sizeof(inp))
|
||||
if r == 0 {
|
||||
log.Tracef("SendInput(key vk=0x%x): %v", vk, err)
|
||||
}
|
||||
}
|
||||
|
||||
const sasEventName = `Global\NetBirdVNC_SAS`
|
||||
|
||||
type inputCmd struct {
|
||||
isKey bool
|
||||
keysym uint32
|
||||
down bool
|
||||
buttonMask uint8
|
||||
x, y int
|
||||
serverW int
|
||||
serverH int
|
||||
}
|
||||
|
||||
// WindowsInputInjector delivers input events from a dedicated OS thread that
|
||||
// calls switchToInputDesktop before each injection. SendInput targets the
|
||||
// calling thread's desktop, so the injection thread must be on the same
|
||||
// desktop the user sees.
|
||||
type WindowsInputInjector struct {
|
||||
ch chan inputCmd
|
||||
prevButtonMask uint8
|
||||
ctrlDown bool
|
||||
altDown bool
|
||||
}
|
||||
|
||||
// NewWindowsInputInjector creates a desktop-aware input injector.
|
||||
func NewWindowsInputInjector() *WindowsInputInjector {
|
||||
w := &WindowsInputInjector{ch: make(chan inputCmd, 64)}
|
||||
go w.loop()
|
||||
return w
|
||||
}
|
||||
|
||||
func (w *WindowsInputInjector) loop() {
|
||||
runtime.LockOSThread()
|
||||
|
||||
for cmd := range w.ch {
|
||||
// Switch to the current input desktop so SendInput reaches the right target.
|
||||
switchToInputDesktop()
|
||||
|
||||
if cmd.isKey {
|
||||
w.doInjectKey(cmd.keysym, cmd.down)
|
||||
} else {
|
||||
w.doInjectPointer(cmd.buttonMask, cmd.x, cmd.y, cmd.serverW, cmd.serverH)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// InjectKey queues a key event for injection on the input desktop thread.
|
||||
func (w *WindowsInputInjector) InjectKey(keysym uint32, down bool) {
|
||||
w.ch <- inputCmd{isKey: true, keysym: keysym, down: down}
|
||||
}
|
||||
|
||||
// InjectPointer queues a pointer event for injection on the input desktop thread.
|
||||
func (w *WindowsInputInjector) InjectPointer(buttonMask uint8, x, y, serverW, serverH int) {
|
||||
w.ch <- inputCmd{buttonMask: buttonMask, x: x, y: y, serverW: serverW, serverH: serverH}
|
||||
}
|
||||
|
||||
func (w *WindowsInputInjector) doInjectKey(keysym uint32, down bool) {
|
||||
switch keysym {
|
||||
case 0xffe3, 0xffe4:
|
||||
w.ctrlDown = down
|
||||
case 0xffe9, 0xffea:
|
||||
w.altDown = down
|
||||
}
|
||||
|
||||
if (keysym == 0xff9f || keysym == 0xffff) && w.ctrlDown && w.altDown && down {
|
||||
signalSAS()
|
||||
return
|
||||
}
|
||||
|
||||
vk, _, extended := keysym2VK(keysym)
|
||||
if vk == 0 {
|
||||
return
|
||||
}
|
||||
var flags uint32
|
||||
if !down {
|
||||
flags |= keyeventfKeyUp
|
||||
}
|
||||
if extended {
|
||||
flags |= keyeventfScanCode
|
||||
}
|
||||
sendKeyInput(vk, 0, flags)
|
||||
}
|
||||
|
||||
// signalSAS signals the SAS named event. A listener in Session 0
|
||||
// (startSASListener) calls SendSAS to trigger the Secure Attention Sequence.
|
||||
func signalSAS() {
|
||||
namePtr, err := windows.UTF16PtrFromString(sasEventName)
|
||||
if err != nil {
|
||||
log.Warnf("SAS UTF16: %v", err)
|
||||
return
|
||||
}
|
||||
h, _, lerr := procOpenEventW.Call(
|
||||
uintptr(eventModifyState),
|
||||
0,
|
||||
uintptr(unsafe.Pointer(namePtr)),
|
||||
)
|
||||
if h == 0 {
|
||||
log.Warnf("OpenEvent(%s): %v", sasEventName, lerr)
|
||||
return
|
||||
}
|
||||
ev := windows.Handle(h)
|
||||
defer windows.CloseHandle(ev)
|
||||
if err := windows.SetEvent(ev); err != nil {
|
||||
log.Warnf("SetEvent SAS: %v", err)
|
||||
} else {
|
||||
log.Info("SAS event signaled")
|
||||
}
|
||||
}
|
||||
|
||||
func (w *WindowsInputInjector) doInjectPointer(buttonMask uint8, x, y, serverW, serverH int) {
|
||||
if serverW == 0 || serverH == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
absX := int32(x * 65535 / serverW)
|
||||
absY := int32(y * 65535 / serverH)
|
||||
|
||||
sendMouseInput(mouseeventfMove|mouseeventfAbsolute, absX, absY, 0)
|
||||
|
||||
changed := buttonMask ^ w.prevButtonMask
|
||||
w.prevButtonMask = buttonMask
|
||||
|
||||
type btnMap struct {
|
||||
bit uint8
|
||||
down uint32
|
||||
up uint32
|
||||
}
|
||||
buttons := [...]btnMap{
|
||||
{0x01, mouseeventfLeftDown, mouseeventfLeftUp},
|
||||
{0x02, mouseeventfMiddleDown, mouseeventfMiddleUp},
|
||||
{0x04, mouseeventfRightDown, mouseeventfRightUp},
|
||||
}
|
||||
for _, b := range buttons {
|
||||
if changed&b.bit == 0 {
|
||||
continue
|
||||
}
|
||||
var flags uint32
|
||||
if buttonMask&b.bit != 0 {
|
||||
flags = b.down
|
||||
} else {
|
||||
flags = b.up
|
||||
}
|
||||
sendMouseInput(flags|mouseeventfAbsolute, absX, absY, 0)
|
||||
}
|
||||
|
||||
negWheelDelta := ^uint32(wheelDelta - 1)
|
||||
if changed&0x08 != 0 && buttonMask&0x08 != 0 {
|
||||
sendMouseInput(mouseeventfWheel|mouseeventfAbsolute, absX, absY, wheelDelta)
|
||||
}
|
||||
if changed&0x10 != 0 && buttonMask&0x10 != 0 {
|
||||
sendMouseInput(mouseeventfWheel|mouseeventfAbsolute, absX, absY, negWheelDelta)
|
||||
}
|
||||
}
|
||||
|
||||
// keysym2VK converts an X11 keysym to a Windows virtual key code.
|
||||
func keysym2VK(keysym uint32) (vk uint16, scan uint16, extended bool) {
|
||||
if keysym >= 0x20 && keysym <= 0x7e {
|
||||
r, _, _ := procVkKeyScanA.Call(uintptr(keysym))
|
||||
vk = uint16(r & 0xff)
|
||||
return
|
||||
}
|
||||
|
||||
if keysym >= 0xffbe && keysym <= 0xffc9 {
|
||||
vk = uint16(0x70 + keysym - 0xffbe)
|
||||
return
|
||||
}
|
||||
|
||||
switch keysym {
|
||||
case 0xff08:
|
||||
vk = 0x08 // Backspace
|
||||
case 0xff09:
|
||||
vk = 0x09 // Tab
|
||||
case 0xff0d:
|
||||
vk = 0x0d // Return
|
||||
case 0xff1b:
|
||||
vk = 0x1b // Escape
|
||||
case 0xff63:
|
||||
vk, extended = 0x2d, true // Insert
|
||||
case 0xff9f, 0xffff:
|
||||
vk, extended = 0x2e, true // Delete
|
||||
case 0xff50:
|
||||
vk, extended = 0x24, true // Home
|
||||
case 0xff57:
|
||||
vk, extended = 0x23, true // End
|
||||
case 0xff55:
|
||||
vk, extended = 0x21, true // PageUp
|
||||
case 0xff56:
|
||||
vk, extended = 0x22, true // PageDown
|
||||
case 0xff51:
|
||||
vk, extended = 0x25, true // Left
|
||||
case 0xff52:
|
||||
vk, extended = 0x26, true // Up
|
||||
case 0xff53:
|
||||
vk, extended = 0x27, true // Right
|
||||
case 0xff54:
|
||||
vk, extended = 0x28, true // Down
|
||||
case 0xffe1, 0xffe2:
|
||||
vk = 0x10 // Shift
|
||||
case 0xffe3, 0xffe4:
|
||||
vk = 0x11 // Control
|
||||
case 0xffe9, 0xffea:
|
||||
vk = 0x12 // Alt
|
||||
case 0xffe5:
|
||||
vk = 0x14 // CapsLock
|
||||
case 0xffe7, 0xffeb:
|
||||
vk, extended = 0x5B, true // Meta_L / Super_L -> Left Windows
|
||||
case 0xffe8, 0xffec:
|
||||
vk, extended = 0x5C, true // Meta_R / Super_R -> Right Windows
|
||||
case 0xff61:
|
||||
vk = 0x2c // PrintScreen
|
||||
case 0xff13:
|
||||
vk = 0x13 // Pause
|
||||
case 0xff14:
|
||||
vk = 0x91 // ScrollLock
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
var (
|
||||
procOpenClipboard = user32.NewProc("OpenClipboard")
|
||||
procCloseClipboard = user32.NewProc("CloseClipboard")
|
||||
procEmptyClipboard = user32.NewProc("EmptyClipboard")
|
||||
procSetClipboardData = user32.NewProc("SetClipboardData")
|
||||
procGetClipboardData = user32.NewProc("GetClipboardData")
|
||||
procIsClipboardFormatAvailable = user32.NewProc("IsClipboardFormatAvailable")
|
||||
|
||||
procGlobalAlloc = kernel32.NewProc("GlobalAlloc")
|
||||
procGlobalLock = kernel32.NewProc("GlobalLock")
|
||||
procGlobalUnlock = kernel32.NewProc("GlobalUnlock")
|
||||
)
|
||||
|
||||
const (
|
||||
cfUnicodeText = 13
|
||||
gmemMoveable = 0x0002
|
||||
)
|
||||
|
||||
// SetClipboard sets the Windows clipboard to the given UTF-8 text.
|
||||
func (w *WindowsInputInjector) SetClipboard(text string) {
|
||||
utf16, err := windows.UTF16FromString(text)
|
||||
if err != nil {
|
||||
log.Tracef("clipboard UTF16 encode: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
size := uintptr(len(utf16) * 2)
|
||||
hMem, _, _ := procGlobalAlloc.Call(gmemMoveable, size)
|
||||
if hMem == 0 {
|
||||
log.Tracef("GlobalAlloc for clipboard: allocation returned nil")
|
||||
return
|
||||
}
|
||||
|
||||
ptr, _, _ := procGlobalLock.Call(hMem)
|
||||
if ptr == 0 {
|
||||
log.Tracef("GlobalLock for clipboard: lock returned nil")
|
||||
return
|
||||
}
|
||||
copy(unsafe.Slice((*uint16)(unsafe.Pointer(ptr)), len(utf16)), utf16)
|
||||
procGlobalUnlock.Call(hMem)
|
||||
|
||||
r, _, lerr := procOpenClipboard.Call(0)
|
||||
if r == 0 {
|
||||
log.Tracef("OpenClipboard: %v", lerr)
|
||||
return
|
||||
}
|
||||
defer procCloseClipboard.Call()
|
||||
|
||||
procEmptyClipboard.Call()
|
||||
r, _, lerr = procSetClipboardData.Call(cfUnicodeText, hMem)
|
||||
if r == 0 {
|
||||
log.Tracef("SetClipboardData: %v", lerr)
|
||||
}
|
||||
}
|
||||
|
||||
// GetClipboard reads the Windows clipboard as UTF-8 text.
|
||||
func (w *WindowsInputInjector) GetClipboard() string {
|
||||
r, _, _ := procIsClipboardFormatAvailable.Call(cfUnicodeText)
|
||||
if r == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
r, _, lerr := procOpenClipboard.Call(0)
|
||||
if r == 0 {
|
||||
log.Tracef("OpenClipboard for read: %v", lerr)
|
||||
return ""
|
||||
}
|
||||
defer procCloseClipboard.Call()
|
||||
|
||||
hData, _, _ := procGetClipboardData.Call(cfUnicodeText)
|
||||
if hData == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
ptr, _, _ := procGlobalLock.Call(hData)
|
||||
if ptr == 0 {
|
||||
return ""
|
||||
}
|
||||
defer procGlobalUnlock.Call(hData)
|
||||
|
||||
return windows.UTF16PtrToString((*uint16)(unsafe.Pointer(ptr)))
|
||||
}
|
||||
|
||||
var _ InputInjector = (*WindowsInputInjector)(nil)
|
||||
|
||||
var _ ScreenCapturer = (*DesktopCapturer)(nil)
|
||||
@@ -1,242 +0,0 @@
|
||||
//go:build (linux && !android) || freebsd
|
||||
|
||||
package server
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"strings"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/jezek/xgb"
|
||||
"github.com/jezek/xgb/xproto"
|
||||
"github.com/jezek/xgb/xtest"
|
||||
)
|
||||
|
||||
// X11InputInjector injects keyboard and mouse events via the XTest extension.
|
||||
type X11InputInjector struct {
|
||||
conn *xgb.Conn
|
||||
root xproto.Window
|
||||
screen *xproto.ScreenInfo
|
||||
display string
|
||||
keysymMap map[uint32]byte
|
||||
lastButtons uint8
|
||||
clipboardTool string
|
||||
clipboardToolName string
|
||||
}
|
||||
|
||||
// NewX11InputInjector connects to the X11 display and initializes XTest.
|
||||
func NewX11InputInjector(display string) (*X11InputInjector, error) {
|
||||
detectX11Display()
|
||||
|
||||
if display == "" {
|
||||
display = os.Getenv("DISPLAY")
|
||||
}
|
||||
if display == "" {
|
||||
return nil, fmt.Errorf("DISPLAY not set and no Xorg process found")
|
||||
}
|
||||
|
||||
conn, err := xgb.NewConnDisplay(display)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("connect to X11 display %s: %w", display, err)
|
||||
}
|
||||
|
||||
if err := xtest.Init(conn); err != nil {
|
||||
conn.Close()
|
||||
return nil, fmt.Errorf("init XTest extension: %w", err)
|
||||
}
|
||||
|
||||
setup := xproto.Setup(conn)
|
||||
if len(setup.Roots) == 0 {
|
||||
conn.Close()
|
||||
return nil, fmt.Errorf("no X11 screens")
|
||||
}
|
||||
screen := setup.Roots[0]
|
||||
|
||||
inj := &X11InputInjector{
|
||||
conn: conn,
|
||||
root: screen.Root,
|
||||
screen: &screen,
|
||||
display: display,
|
||||
}
|
||||
inj.cacheKeyboardMapping()
|
||||
inj.resolveClipboardTool()
|
||||
|
||||
log.Infof("X11 input injector ready (display=%s)", display)
|
||||
return inj, nil
|
||||
}
|
||||
|
||||
// InjectKey simulates a key press or release. keysym is an X11 KeySym.
|
||||
func (x *X11InputInjector) InjectKey(keysym uint32, down bool) {
|
||||
keycode := x.keysymToKeycode(keysym)
|
||||
if keycode == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
var eventType byte
|
||||
if down {
|
||||
eventType = xproto.KeyPress
|
||||
} else {
|
||||
eventType = xproto.KeyRelease
|
||||
}
|
||||
|
||||
xtest.FakeInput(x.conn, eventType, keycode, 0, x.root, 0, 0, 0)
|
||||
}
|
||||
|
||||
// InjectPointer simulates mouse movement and button events.
|
||||
func (x *X11InputInjector) InjectPointer(buttonMask uint8, px, py, serverW, serverH int) {
|
||||
if serverW == 0 || serverH == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
// Scale to actual screen coordinates.
|
||||
screenW := int(x.screen.WidthInPixels)
|
||||
screenH := int(x.screen.HeightInPixels)
|
||||
absX := px * screenW / serverW
|
||||
absY := py * screenH / serverH
|
||||
|
||||
// Move pointer.
|
||||
xtest.FakeInput(x.conn, xproto.MotionNotify, 0, 0, x.root, int16(absX), int16(absY), 0)
|
||||
|
||||
// Handle button events. RFB button mask: bit0=left, bit1=middle, bit2=right,
|
||||
// bit3=scrollUp, bit4=scrollDown. X11 buttons: 1=left, 2=middle, 3=right,
|
||||
// 4=scrollUp, 5=scrollDown.
|
||||
type btnMap struct {
|
||||
rfbBit uint8
|
||||
x11Btn byte
|
||||
}
|
||||
buttons := [...]btnMap{
|
||||
{0x01, 1}, // left
|
||||
{0x02, 2}, // middle
|
||||
{0x04, 3}, // right
|
||||
{0x08, 4}, // scroll up
|
||||
{0x10, 5}, // scroll down
|
||||
}
|
||||
|
||||
for _, b := range buttons {
|
||||
pressed := buttonMask&b.rfbBit != 0
|
||||
wasPressed := x.lastButtons&b.rfbBit != 0
|
||||
if b.x11Btn >= 4 {
|
||||
// Scroll: send press+release on each scroll event.
|
||||
if pressed {
|
||||
xtest.FakeInput(x.conn, xproto.ButtonPress, b.x11Btn, 0, x.root, 0, 0, 0)
|
||||
xtest.FakeInput(x.conn, xproto.ButtonRelease, b.x11Btn, 0, x.root, 0, 0, 0)
|
||||
}
|
||||
} else {
|
||||
if pressed && !wasPressed {
|
||||
xtest.FakeInput(x.conn, xproto.ButtonPress, b.x11Btn, 0, x.root, 0, 0, 0)
|
||||
} else if !pressed && wasPressed {
|
||||
xtest.FakeInput(x.conn, xproto.ButtonRelease, b.x11Btn, 0, x.root, 0, 0, 0)
|
||||
}
|
||||
}
|
||||
}
|
||||
x.lastButtons = buttonMask
|
||||
}
|
||||
|
||||
// cacheKeyboardMapping fetches the X11 keyboard mapping once and stores it
|
||||
// as a keysym-to-keycode map, avoiding a round-trip per keystroke.
|
||||
func (x *X11InputInjector) cacheKeyboardMapping() {
|
||||
setup := xproto.Setup(x.conn)
|
||||
minKeycode := setup.MinKeycode
|
||||
maxKeycode := setup.MaxKeycode
|
||||
|
||||
reply, err := xproto.GetKeyboardMapping(x.conn, minKeycode,
|
||||
byte(maxKeycode-minKeycode+1)).Reply()
|
||||
if err != nil {
|
||||
log.Debugf("cache keyboard mapping: %v", err)
|
||||
x.keysymMap = make(map[uint32]byte)
|
||||
return
|
||||
}
|
||||
|
||||
m := make(map[uint32]byte, int(maxKeycode-minKeycode+1)*int(reply.KeysymsPerKeycode))
|
||||
keysymsPerKeycode := int(reply.KeysymsPerKeycode)
|
||||
for i := int(minKeycode); i <= int(maxKeycode); i++ {
|
||||
offset := (i - int(minKeycode)) * keysymsPerKeycode
|
||||
for j := 0; j < keysymsPerKeycode; j++ {
|
||||
ks := uint32(reply.Keysyms[offset+j])
|
||||
if ks != 0 {
|
||||
if _, exists := m[ks]; !exists {
|
||||
m[ks] = byte(i)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
x.keysymMap = m
|
||||
}
|
||||
|
||||
// keysymToKeycode looks up a cached keysym-to-keycode mapping.
|
||||
// Returns 0 if the keysym is not mapped.
|
||||
func (x *X11InputInjector) keysymToKeycode(keysym uint32) byte {
|
||||
return x.keysymMap[keysym]
|
||||
}
|
||||
|
||||
// SetClipboard sets the X11 clipboard using xclip or xsel.
|
||||
func (x *X11InputInjector) SetClipboard(text string) {
|
||||
if x.clipboardTool == "" {
|
||||
return
|
||||
}
|
||||
|
||||
var cmd *exec.Cmd
|
||||
if x.clipboardToolName == "xclip" {
|
||||
cmd = exec.Command(x.clipboardTool, "-selection", "clipboard")
|
||||
} else {
|
||||
cmd = exec.Command(x.clipboardTool, "--clipboard", "--input")
|
||||
}
|
||||
cmd.Env = x.clipboardEnv()
|
||||
cmd.Stdin = strings.NewReader(text)
|
||||
if err := cmd.Run(); err != nil {
|
||||
log.Debugf("set clipboard via %s: %v", x.clipboardToolName, err)
|
||||
}
|
||||
}
|
||||
|
||||
func (x *X11InputInjector) resolveClipboardTool() {
|
||||
for _, name := range []string{"xclip", "xsel"} {
|
||||
path, err := exec.LookPath(name)
|
||||
if err == nil {
|
||||
x.clipboardTool = path
|
||||
x.clipboardToolName = name
|
||||
log.Debugf("clipboard tool resolved to %s", path)
|
||||
return
|
||||
}
|
||||
}
|
||||
log.Debugf("no clipboard tool (xclip/xsel) found, clipboard sync disabled")
|
||||
}
|
||||
|
||||
// GetClipboard reads the X11 clipboard using xclip or xsel.
|
||||
func (x *X11InputInjector) GetClipboard() string {
|
||||
if x.clipboardTool == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
var cmd *exec.Cmd
|
||||
if x.clipboardToolName == "xclip" {
|
||||
cmd = exec.Command(x.clipboardTool, "-selection", "clipboard", "-o")
|
||||
} else {
|
||||
cmd = exec.Command(x.clipboardTool, "--clipboard", "--output")
|
||||
}
|
||||
cmd.Env = x.clipboardEnv()
|
||||
out, err := cmd.Output()
|
||||
if err != nil {
|
||||
log.Tracef("get clipboard via %s: %v", x.clipboardToolName, err)
|
||||
return ""
|
||||
}
|
||||
return string(out)
|
||||
}
|
||||
|
||||
func (x *X11InputInjector) clipboardEnv() []string {
|
||||
env := []string{"DISPLAY=" + x.display}
|
||||
if auth := os.Getenv("XAUTHORITY"); auth != "" {
|
||||
env = append(env, "XAUTHORITY="+auth)
|
||||
}
|
||||
return env
|
||||
}
|
||||
|
||||
// Close releases X11 resources.
|
||||
func (x *X11InputInjector) Close() {
|
||||
x.conn.Close()
|
||||
}
|
||||
|
||||
var _ InputInjector = (*X11InputInjector)(nil)
|
||||
var _ ScreenCapturer = (*X11Poller)(nil)
|
||||
@@ -1,175 +0,0 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/base64"
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"image"
|
||||
"image/png"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// Recording file format:
|
||||
//
|
||||
// Header: magic(6) + width(2) + height(2) + startTime(8) + metaLen(4) + metaJSON
|
||||
// Frames: offsetMs(4) + pngLen(4) + PNG image data
|
||||
//
|
||||
// Each frame is a PNG-encoded screenshot. Only changed frames are stored.
|
||||
const recMagic = "NBVNC\x01"
|
||||
|
||||
// RecordingMeta holds metadata written to the recording file header.
|
||||
type RecordingMeta struct {
|
||||
User string `json:"user,omitempty"`
|
||||
RemoteAddr string `json:"remote_addr"`
|
||||
JWTUser string `json:"jwt_user,omitempty"`
|
||||
Mode string `json:"mode,omitempty"`
|
||||
Encrypted bool `json:"encrypted,omitempty"`
|
||||
EphemeralKey string `json:"ephemeral_key,omitempty"`
|
||||
}
|
||||
|
||||
// vncRecorder writes VNC session frames to a recording file.
|
||||
type vncRecorder struct {
|
||||
mu sync.Mutex
|
||||
file *os.File
|
||||
startTime time.Time
|
||||
closed bool
|
||||
log *log.Entry
|
||||
prevFrame *image.RGBA
|
||||
pngEnc *png.Encoder
|
||||
pngBuf bytes.Buffer
|
||||
crypto *recCrypto
|
||||
}
|
||||
|
||||
func newVNCRecorder(dir string, width, height int, meta *RecordingMeta, encryptionKey string, logger *log.Entry) (*vncRecorder, error) {
|
||||
if err := os.MkdirAll(dir, 0700); err != nil {
|
||||
return nil, fmt.Errorf("create recording dir: %w", err)
|
||||
}
|
||||
|
||||
now := time.Now().UTC()
|
||||
filename := fmt.Sprintf("%s_vnc.rec", now.Format("20060102-150405"))
|
||||
filePath := filepath.Join(dir, filename)
|
||||
|
||||
f, err := os.OpenFile(filePath, os.O_CREATE|os.O_WRONLY|os.O_EXCL, 0600)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create recording file: %w", err)
|
||||
}
|
||||
|
||||
var crypto *recCrypto
|
||||
if encryptionKey != "" {
|
||||
var cryptoErr error
|
||||
crypto, cryptoErr = newRecCrypto(encryptionKey)
|
||||
if cryptoErr != nil {
|
||||
f.Close()
|
||||
os.Remove(filePath)
|
||||
return nil, fmt.Errorf("init encryption: %w", cryptoErr)
|
||||
}
|
||||
meta.Encrypted = true
|
||||
meta.EphemeralKey = base64.StdEncoding.EncodeToString(crypto.ephemeralPub)
|
||||
}
|
||||
|
||||
metaJSON, err := json.Marshal(meta)
|
||||
if err != nil {
|
||||
f.Close()
|
||||
os.Remove(filePath)
|
||||
return nil, fmt.Errorf("marshal meta: %w", err)
|
||||
}
|
||||
|
||||
var hdr [6 + 2 + 2 + 8 + 4]byte
|
||||
copy(hdr[:6], recMagic)
|
||||
binary.BigEndian.PutUint16(hdr[6:8], uint16(width))
|
||||
binary.BigEndian.PutUint16(hdr[8:10], uint16(height))
|
||||
binary.BigEndian.PutUint64(hdr[10:18], uint64(now.UnixMilli()))
|
||||
binary.BigEndian.PutUint32(hdr[18:22], uint32(len(metaJSON)))
|
||||
|
||||
if _, err := f.Write(hdr[:]); err != nil {
|
||||
f.Close()
|
||||
os.Remove(filePath)
|
||||
return nil, fmt.Errorf("write header: %w", err)
|
||||
}
|
||||
if _, err := f.Write(metaJSON); err != nil {
|
||||
f.Close()
|
||||
os.Remove(filePath)
|
||||
return nil, fmt.Errorf("write meta: %w", err)
|
||||
}
|
||||
|
||||
r := &vncRecorder{
|
||||
file: f,
|
||||
startTime: now,
|
||||
log: logger.WithField("recording", filepath.Base(filePath)),
|
||||
pngEnc: &png.Encoder{CompressionLevel: png.BestSpeed},
|
||||
crypto: crypto,
|
||||
}
|
||||
if crypto != nil {
|
||||
r.log.Infof("VNC recording started (encrypted): %s", filePath)
|
||||
} else {
|
||||
r.log.Infof("VNC recording started: %s", filePath)
|
||||
}
|
||||
return r, nil
|
||||
}
|
||||
|
||||
// writeFrame records a screen frame. Only writes if the frame differs from the previous one.
|
||||
func (r *vncRecorder) writeFrame(img *image.RGBA) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
if r.closed {
|
||||
return
|
||||
}
|
||||
|
||||
if r.prevFrame != nil && bytes.Equal(r.prevFrame.Pix, img.Pix) {
|
||||
return
|
||||
}
|
||||
|
||||
offsetMs := uint32(time.Since(r.startTime).Milliseconds())
|
||||
|
||||
r.pngBuf.Reset()
|
||||
if err := r.pngEnc.Encode(&r.pngBuf, img); err != nil {
|
||||
r.log.Debugf("encode PNG frame: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
frameData := r.pngBuf.Bytes()
|
||||
if r.crypto != nil {
|
||||
frameData = r.crypto.encrypt(frameData)
|
||||
}
|
||||
|
||||
var frameHdr [8]byte
|
||||
binary.BigEndian.PutUint32(frameHdr[0:4], offsetMs)
|
||||
binary.BigEndian.PutUint32(frameHdr[4:8], uint32(len(frameData)))
|
||||
|
||||
if _, err := r.file.Write(frameHdr[:]); err != nil {
|
||||
r.log.Debugf("write frame header: %v", err)
|
||||
return
|
||||
}
|
||||
if _, err := r.file.Write(frameData); err != nil {
|
||||
r.log.Debugf("write frame data: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if r.prevFrame == nil {
|
||||
r.prevFrame = image.NewRGBA(img.Rect)
|
||||
}
|
||||
copy(r.prevFrame.Pix, img.Pix)
|
||||
}
|
||||
|
||||
func (r *vncRecorder) close() {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
if r.closed {
|
||||
return
|
||||
}
|
||||
r.closed = true
|
||||
|
||||
duration := time.Since(r.startTime)
|
||||
r.log.Infof("VNC recording stopped after %v", duration.Round(time.Millisecond))
|
||||
r.file.Close()
|
||||
}
|
||||
|
||||
@@ -1,202 +0,0 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"crypto/ecdh"
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"image"
|
||||
"image/color"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func makeTestImage(w, h int, c color.RGBA) *image.RGBA {
|
||||
img := image.NewRGBA(image.Rect(0, 0, w, h))
|
||||
for i := 0; i < len(img.Pix); i += 4 {
|
||||
img.Pix[i] = c.R
|
||||
img.Pix[i+1] = c.G
|
||||
img.Pix[i+2] = c.B
|
||||
img.Pix[i+3] = c.A
|
||||
}
|
||||
return img
|
||||
}
|
||||
|
||||
func TestRecorderWriteAndReadHeader(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
logger := log.WithField("test", t.Name())
|
||||
|
||||
meta := &RecordingMeta{
|
||||
User: "alice",
|
||||
RemoteAddr: "100.0.1.5:12345",
|
||||
JWTUser: "google|123",
|
||||
Mode: "session",
|
||||
}
|
||||
|
||||
rec, err := newVNCRecorder(dir, 800, 600, meta, "", logger)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Write some frames
|
||||
red := makeTestImage(800, 600, color.RGBA{255, 0, 0, 255})
|
||||
blue := makeTestImage(800, 600, color.RGBA{0, 0, 255, 255})
|
||||
|
||||
rec.writeFrame(red)
|
||||
rec.writeFrame(red) // duplicate, should be skipped
|
||||
rec.writeFrame(blue)
|
||||
rec.close()
|
||||
|
||||
// Read back the header
|
||||
files, err := os.ReadDir(dir)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, files, 1)
|
||||
|
||||
filePath := filepath.Join(dir, files[0].Name())
|
||||
header, err := ReadRecordingHeader(filePath)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, 800, header.Width)
|
||||
assert.Equal(t, 600, header.Height)
|
||||
assert.Equal(t, "alice", header.Meta.User)
|
||||
assert.Equal(t, "100.0.1.5:12345", header.Meta.RemoteAddr)
|
||||
assert.Equal(t, "google|123", header.Meta.JWTUser)
|
||||
assert.Equal(t, "session", header.Meta.Mode)
|
||||
assert.False(t, header.Meta.Encrypted)
|
||||
|
||||
// Verify file is valid by checking size is reasonable
|
||||
fi, err := os.Stat(filePath)
|
||||
require.NoError(t, err)
|
||||
assert.Greater(t, fi.Size(), int64(100), "recording should have content")
|
||||
}
|
||||
|
||||
func TestRecorderDuplicateFrameSkip(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
logger := log.WithField("test", t.Name())
|
||||
|
||||
rec, err := newVNCRecorder(dir, 100, 100, &RecordingMeta{RemoteAddr: "test"}, "", logger)
|
||||
require.NoError(t, err)
|
||||
|
||||
img := makeTestImage(100, 100, color.RGBA{128, 128, 128, 255})
|
||||
|
||||
rec.writeFrame(img)
|
||||
rec.writeFrame(img) // duplicate
|
||||
rec.writeFrame(img) // duplicate
|
||||
rec.close()
|
||||
|
||||
files, _ := os.ReadDir(dir)
|
||||
filePath := filepath.Join(dir, files[0].Name())
|
||||
|
||||
// Count frames by parsing
|
||||
f, err := os.Open(filePath)
|
||||
require.NoError(t, err)
|
||||
defer f.Close()
|
||||
|
||||
_, err = parseRecHeader(f)
|
||||
require.NoError(t, err)
|
||||
|
||||
frameCount := 0
|
||||
var hdr [8]byte
|
||||
for {
|
||||
if _, err := f.Read(hdr[:]); err != nil {
|
||||
break
|
||||
}
|
||||
pngLen := int64(hdr[4])<<24 | int64(hdr[5])<<16 | int64(hdr[6])<<8 | int64(hdr[7])
|
||||
f.Seek(pngLen, 1)
|
||||
frameCount++
|
||||
}
|
||||
|
||||
assert.Equal(t, 1, frameCount, "duplicate frames should be skipped")
|
||||
}
|
||||
|
||||
func TestRecorderEncrypted(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
logger := log.WithField("test", t.Name())
|
||||
|
||||
// Generate admin keypair
|
||||
adminPriv, err := ecdh.X25519().GenerateKey(rand.Reader)
|
||||
require.NoError(t, err)
|
||||
adminPubB64 := base64.StdEncoding.EncodeToString(adminPriv.PublicKey().Bytes())
|
||||
|
||||
meta := &RecordingMeta{
|
||||
RemoteAddr: "100.0.1.5:12345",
|
||||
Mode: "attach",
|
||||
}
|
||||
|
||||
rec, err := newVNCRecorder(dir, 200, 150, meta, adminPubB64, logger)
|
||||
require.NoError(t, err)
|
||||
|
||||
img := makeTestImage(200, 150, color.RGBA{255, 0, 0, 255})
|
||||
rec.writeFrame(img)
|
||||
rec.close()
|
||||
|
||||
// Read header and verify encryption metadata
|
||||
files, _ := os.ReadDir(dir)
|
||||
filePath := filepath.Join(dir, files[0].Name())
|
||||
|
||||
header, err := ReadRecordingHeader(filePath)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.True(t, header.Meta.Encrypted)
|
||||
assert.NotEmpty(t, header.Meta.EphemeralKey)
|
||||
assert.Equal(t, 200, header.Width)
|
||||
assert.Equal(t, 150, header.Height)
|
||||
}
|
||||
|
||||
func TestRecorderEncryptedDecryptRoundtrip(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
logger := log.WithField("test", t.Name())
|
||||
|
||||
adminPriv, err := ecdh.X25519().GenerateKey(rand.Reader)
|
||||
require.NoError(t, err)
|
||||
adminPubB64 := base64.StdEncoding.EncodeToString(adminPriv.PublicKey().Bytes())
|
||||
adminPrivB64 := base64.StdEncoding.EncodeToString(adminPriv.Bytes())
|
||||
|
||||
rec, err := newVNCRecorder(dir, 100, 100, &RecordingMeta{RemoteAddr: "test"}, adminPubB64, logger)
|
||||
require.NoError(t, err)
|
||||
|
||||
red := makeTestImage(100, 100, color.RGBA{255, 0, 0, 255})
|
||||
green := makeTestImage(100, 100, color.RGBA{0, 255, 0, 255})
|
||||
|
||||
rec.writeFrame(red)
|
||||
rec.writeFrame(green)
|
||||
rec.close()
|
||||
|
||||
// Read back and decrypt
|
||||
files, _ := os.ReadDir(dir)
|
||||
filePath := filepath.Join(dir, files[0].Name())
|
||||
|
||||
header, err := ReadRecordingHeader(filePath)
|
||||
require.NoError(t, err)
|
||||
require.True(t, header.Meta.Encrypted)
|
||||
|
||||
dec, err := DecryptRecording(adminPrivB64, header.Meta.EphemeralKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Read raw frames and decrypt
|
||||
f, err := os.Open(filePath)
|
||||
require.NoError(t, err)
|
||||
defer f.Close()
|
||||
|
||||
_, err = parseRecHeader(f)
|
||||
require.NoError(t, err)
|
||||
|
||||
decryptedFrames := 0
|
||||
var hdr [8]byte
|
||||
for {
|
||||
if _, readErr := f.Read(hdr[:]); readErr != nil {
|
||||
break
|
||||
}
|
||||
frameLen := int(hdr[4])<<24 | int(hdr[5])<<16 | int(hdr[6])<<8 | int(hdr[7])
|
||||
ct := make([]byte, frameLen)
|
||||
f.Read(ct)
|
||||
|
||||
_, err := dec.Decrypt(ct)
|
||||
require.NoError(t, err, "frame %d decrypt should succeed", decryptedFrames)
|
||||
decryptedFrames++
|
||||
}
|
||||
|
||||
assert.Equal(t, 2, decryptedFrames)
|
||||
}
|
||||
@@ -1,64 +0,0 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"time"
|
||||
)
|
||||
|
||||
// RecordingHeader holds parsed header data from a VNC recording file.
|
||||
type RecordingHeader struct {
|
||||
Width int
|
||||
Height int
|
||||
StartTime time.Time
|
||||
Meta RecordingMeta
|
||||
}
|
||||
|
||||
// ReadRecordingHeader parses and returns the recording header without loading frames.
|
||||
func ReadRecordingHeader(filePath string) (*RecordingHeader, error) {
|
||||
f, err := os.Open(filePath)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer f.Close()
|
||||
return parseRecHeader(f)
|
||||
}
|
||||
|
||||
func parseRecHeader(r io.Reader) (*RecordingHeader, error) {
|
||||
var hdr [22]byte
|
||||
if _, err := io.ReadFull(r, hdr[:]); err != nil {
|
||||
return nil, fmt.Errorf("read header: %w", err)
|
||||
}
|
||||
if string(hdr[:6]) != recMagic {
|
||||
return nil, fmt.Errorf("invalid magic: %x", hdr[:6])
|
||||
}
|
||||
|
||||
width := int(binary.BigEndian.Uint16(hdr[6:8]))
|
||||
height := int(binary.BigEndian.Uint16(hdr[8:10]))
|
||||
startMs := int64(binary.BigEndian.Uint64(hdr[10:18]))
|
||||
metaLen := binary.BigEndian.Uint32(hdr[18:22])
|
||||
|
||||
if metaLen > 1<<20 {
|
||||
return nil, fmt.Errorf("meta too large: %d bytes", metaLen)
|
||||
}
|
||||
|
||||
metaJSON := make([]byte, metaLen)
|
||||
if _, err := io.ReadFull(r, metaJSON); err != nil {
|
||||
return nil, fmt.Errorf("read meta: %w", err)
|
||||
}
|
||||
|
||||
var meta RecordingMeta
|
||||
if err := json.Unmarshal(metaJSON, &meta); err != nil {
|
||||
return nil, fmt.Errorf("parse meta: %w", err)
|
||||
}
|
||||
|
||||
return &RecordingHeader{
|
||||
Width: width,
|
||||
Height: height,
|
||||
StartTime: time.UnixMilli(startMs),
|
||||
Meta: meta,
|
||||
}, nil
|
||||
}
|
||||
@@ -1,264 +0,0 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"compress/zlib"
|
||||
"crypto/des"
|
||||
"encoding/binary"
|
||||
"image"
|
||||
)
|
||||
|
||||
const (
|
||||
rfbProtocolVersion = "RFB 003.008\n"
|
||||
|
||||
secNone = 1
|
||||
secVNCAuth = 2
|
||||
|
||||
// Client message types.
|
||||
clientSetPixelFormat = 0
|
||||
clientSetEncodings = 2
|
||||
clientFramebufferUpdateRequest = 3
|
||||
clientKeyEvent = 4
|
||||
clientPointerEvent = 5
|
||||
clientCutText = 6
|
||||
|
||||
// Server message types.
|
||||
serverFramebufferUpdate = 0
|
||||
serverCutText = 3
|
||||
|
||||
// Encoding types.
|
||||
encRaw = 0
|
||||
encZlib = 6
|
||||
)
|
||||
|
||||
// serverPixelFormat is the default pixel format advertised by the server:
|
||||
// 32bpp RGBA, big-endian, true-colour, 8 bits per channel.
|
||||
var serverPixelFormat = [16]byte{
|
||||
32, // bits-per-pixel
|
||||
24, // depth
|
||||
1, // big-endian-flag
|
||||
1, // true-colour-flag
|
||||
0, 255, // red-max
|
||||
0, 255, // green-max
|
||||
0, 255, // blue-max
|
||||
16, // red-shift
|
||||
8, // green-shift
|
||||
0, // blue-shift
|
||||
0, 0, 0, // padding
|
||||
}
|
||||
|
||||
// clientPixelFormat holds the negotiated pixel format from the client.
|
||||
type clientPixelFormat struct {
|
||||
bpp uint8
|
||||
bigEndian uint8
|
||||
rMax uint16
|
||||
gMax uint16
|
||||
bMax uint16
|
||||
rShift uint8
|
||||
gShift uint8
|
||||
bShift uint8
|
||||
}
|
||||
|
||||
func defaultClientPixelFormat() clientPixelFormat {
|
||||
return clientPixelFormat{
|
||||
bpp: serverPixelFormat[0],
|
||||
bigEndian: serverPixelFormat[2],
|
||||
rMax: binary.BigEndian.Uint16(serverPixelFormat[4:6]),
|
||||
gMax: binary.BigEndian.Uint16(serverPixelFormat[6:8]),
|
||||
bMax: binary.BigEndian.Uint16(serverPixelFormat[8:10]),
|
||||
rShift: serverPixelFormat[10],
|
||||
gShift: serverPixelFormat[11],
|
||||
bShift: serverPixelFormat[12],
|
||||
}
|
||||
}
|
||||
|
||||
func parsePixelFormat(pf []byte) clientPixelFormat {
|
||||
return clientPixelFormat{
|
||||
bpp: pf[0],
|
||||
bigEndian: pf[2],
|
||||
rMax: binary.BigEndian.Uint16(pf[4:6]),
|
||||
gMax: binary.BigEndian.Uint16(pf[6:8]),
|
||||
bMax: binary.BigEndian.Uint16(pf[8:10]),
|
||||
rShift: pf[10],
|
||||
gShift: pf[11],
|
||||
bShift: pf[12],
|
||||
}
|
||||
}
|
||||
|
||||
// encodeRawRect encodes a framebuffer region as a raw RFB rectangle.
|
||||
// The returned buffer includes the FramebufferUpdate header (1 rectangle).
|
||||
func encodeRawRect(img *image.RGBA, pf clientPixelFormat, x, y, w, h int) []byte {
|
||||
bytesPerPixel := max(int(pf.bpp)/8, 1)
|
||||
|
||||
pixelBytes := w * h * bytesPerPixel
|
||||
buf := make([]byte, 4+12+pixelBytes)
|
||||
|
||||
// FramebufferUpdate header.
|
||||
buf[0] = serverFramebufferUpdate
|
||||
buf[1] = 0 // padding
|
||||
binary.BigEndian.PutUint16(buf[2:4], 1)
|
||||
|
||||
// Rectangle header.
|
||||
binary.BigEndian.PutUint16(buf[4:6], uint16(x))
|
||||
binary.BigEndian.PutUint16(buf[6:8], uint16(y))
|
||||
binary.BigEndian.PutUint16(buf[8:10], uint16(w))
|
||||
binary.BigEndian.PutUint16(buf[10:12], uint16(h))
|
||||
binary.BigEndian.PutUint32(buf[12:16], uint32(encRaw))
|
||||
|
||||
off := 16
|
||||
stride := img.Stride
|
||||
for row := y; row < y+h; row++ {
|
||||
for col := x; col < x+w; col++ {
|
||||
p := row*stride + col*4
|
||||
r, g, b := img.Pix[p], img.Pix[p+1], img.Pix[p+2]
|
||||
|
||||
rv := uint32(r) * uint32(pf.rMax) / 255
|
||||
gv := uint32(g) * uint32(pf.gMax) / 255
|
||||
bv := uint32(b) * uint32(pf.bMax) / 255
|
||||
pixel := (rv << pf.rShift) | (gv << pf.gShift) | (bv << pf.bShift)
|
||||
|
||||
if pf.bigEndian != 0 {
|
||||
for i := range bytesPerPixel {
|
||||
buf[off+i] = byte(pixel >> uint((bytesPerPixel-1-i)*8))
|
||||
}
|
||||
} else {
|
||||
for i := range bytesPerPixel {
|
||||
buf[off+i] = byte(pixel >> uint(i*8))
|
||||
}
|
||||
}
|
||||
off += bytesPerPixel
|
||||
}
|
||||
}
|
||||
|
||||
return buf
|
||||
}
|
||||
|
||||
// vncAuthEncrypt encrypts a 16-byte challenge using the VNC DES scheme.
|
||||
func vncAuthEncrypt(challenge []byte, password string) []byte {
|
||||
key := make([]byte, 8)
|
||||
for i, c := range []byte(password) {
|
||||
if i >= 8 {
|
||||
break
|
||||
}
|
||||
key[i] = reverseBits(c)
|
||||
}
|
||||
block, _ := des.NewCipher(key)
|
||||
out := make([]byte, 16)
|
||||
block.Encrypt(out[:8], challenge[:8])
|
||||
block.Encrypt(out[8:], challenge[8:])
|
||||
return out
|
||||
}
|
||||
|
||||
func reverseBits(b byte) byte {
|
||||
var r byte
|
||||
for range 8 {
|
||||
r = (r << 1) | (b & 1)
|
||||
b >>= 1
|
||||
}
|
||||
return r
|
||||
}
|
||||
|
||||
// encodeZlibRect encodes a framebuffer region using Zlib compression.
|
||||
// The zlib stream is continuous for the entire VNC session: noVNC creates
|
||||
// one inflate context at startup and reuses it for all zlib-encoded rects.
|
||||
// We must NOT reset the zlib writer between calls.
|
||||
func encodeZlibRect(img *image.RGBA, pf clientPixelFormat, x, y, w, h int, zw *zlib.Writer, zbuf *bytes.Buffer) []byte {
|
||||
bytesPerPixel := max(int(pf.bpp)/8, 1)
|
||||
|
||||
// Clear the output buffer but keep the deflate dictionary intact.
|
||||
zbuf.Reset()
|
||||
|
||||
stride := img.Stride
|
||||
pixel := make([]byte, bytesPerPixel)
|
||||
for row := y; row < y+h; row++ {
|
||||
for col := x; col < x+w; col++ {
|
||||
p := row*stride + col*4
|
||||
r, g, b := img.Pix[p], img.Pix[p+1], img.Pix[p+2]
|
||||
|
||||
rv := uint32(r) * uint32(pf.rMax) / 255
|
||||
gv := uint32(g) * uint32(pf.gMax) / 255
|
||||
bv := uint32(b) * uint32(pf.bMax) / 255
|
||||
val := (rv << pf.rShift) | (gv << pf.gShift) | (bv << pf.bShift)
|
||||
|
||||
if pf.bigEndian != 0 {
|
||||
for i := range bytesPerPixel {
|
||||
pixel[i] = byte(val >> uint((bytesPerPixel-1-i)*8))
|
||||
}
|
||||
} else {
|
||||
for i := range bytesPerPixel {
|
||||
pixel[i] = byte(val >> uint(i*8))
|
||||
}
|
||||
}
|
||||
zw.Write(pixel)
|
||||
}
|
||||
}
|
||||
zw.Flush()
|
||||
|
||||
compressed := zbuf.Bytes()
|
||||
|
||||
// Build the FramebufferUpdate message.
|
||||
buf := make([]byte, 4+12+4+len(compressed))
|
||||
buf[0] = serverFramebufferUpdate
|
||||
buf[1] = 0
|
||||
binary.BigEndian.PutUint16(buf[2:4], 1) // 1 rectangle
|
||||
|
||||
binary.BigEndian.PutUint16(buf[4:6], uint16(x))
|
||||
binary.BigEndian.PutUint16(buf[6:8], uint16(y))
|
||||
binary.BigEndian.PutUint16(buf[8:10], uint16(w))
|
||||
binary.BigEndian.PutUint16(buf[10:12], uint16(h))
|
||||
binary.BigEndian.PutUint32(buf[12:16], uint32(encZlib))
|
||||
binary.BigEndian.PutUint32(buf[16:20], uint32(len(compressed)))
|
||||
copy(buf[20:], compressed)
|
||||
|
||||
return buf
|
||||
}
|
||||
|
||||
// diffRects compares two RGBA images and returns a list of dirty rectangles.
|
||||
// Divides the screen into tiles and checks each for changes.
|
||||
func diffRects(prev, cur *image.RGBA, w, h, tileSize int) [][4]int {
|
||||
if prev == nil {
|
||||
return [][4]int{{0, 0, w, h}}
|
||||
}
|
||||
|
||||
var rects [][4]int
|
||||
for ty := 0; ty < h; ty += tileSize {
|
||||
th := min(tileSize, h-ty)
|
||||
for tx := 0; tx < w; tx += tileSize {
|
||||
tw := min(tileSize, w-tx)
|
||||
if tileChanged(prev, cur, tx, ty, tw, th) {
|
||||
rects = append(rects, [4]int{tx, ty, tw, th})
|
||||
}
|
||||
}
|
||||
}
|
||||
return rects
|
||||
}
|
||||
|
||||
func tileChanged(prev, cur *image.RGBA, x, y, w, h int) bool {
|
||||
stride := prev.Stride
|
||||
for row := y; row < y+h; row++ {
|
||||
off := row*stride + x*4
|
||||
end := off + w*4
|
||||
prevRow := prev.Pix[off:end]
|
||||
curRow := cur.Pix[off:end]
|
||||
if !bytes.Equal(prevRow, curRow) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// zlibState holds the persistent zlib writer and buffer for a session.
|
||||
type zlibState struct {
|
||||
buf *bytes.Buffer
|
||||
w *zlib.Writer
|
||||
}
|
||||
|
||||
func newZlibState() *zlibState {
|
||||
buf := &bytes.Buffer{}
|
||||
w, _ := zlib.NewWriterLevel(buf, zlib.BestSpeed)
|
||||
return &zlibState{buf: buf, w: w}
|
||||
}
|
||||
|
||||
func (z *zlibState) Close() error {
|
||||
return z.w.Close()
|
||||
}
|
||||
@@ -1,690 +0,0 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/subtle"
|
||||
"encoding/binary"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"image"
|
||||
"io"
|
||||
"net"
|
||||
"net/netip"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
gojwt "github.com/golang-jwt/jwt/v5"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.zx2c4.com/wireguard/tun/netstack"
|
||||
|
||||
sshauth "github.com/netbirdio/netbird/client/ssh/auth"
|
||||
nbjwt "github.com/netbirdio/netbird/shared/auth/jwt"
|
||||
)
|
||||
|
||||
// Connection modes sent by the client in the session header.
|
||||
const (
|
||||
ModeAttach byte = 0 // Capture current display
|
||||
ModeSession byte = 1 // Virtual session as specified user
|
||||
)
|
||||
|
||||
// RFB security-failure reason codes sent to the client. These prefixes are
|
||||
// stable so dashboard/noVNC integrations can branch on them without parsing
|
||||
// free text. Format: "CODE: human message".
|
||||
const (
|
||||
RejectCodeJWTMissing = "AUTH_JWT_MISSING"
|
||||
RejectCodeJWTExpired = "AUTH_JWT_EXPIRED"
|
||||
RejectCodeJWTInvalid = "AUTH_JWT_INVALID"
|
||||
RejectCodeAuthForbidden = "AUTH_FORBIDDEN"
|
||||
RejectCodeAuthConfig = "AUTH_CONFIG"
|
||||
RejectCodeSessionError = "SESSION_ERROR"
|
||||
RejectCodeCapturerError = "CAPTURER_ERROR"
|
||||
RejectCodeUnsupportedOS = "UNSUPPORTED"
|
||||
RejectCodeBadRequest = "BAD_REQUEST"
|
||||
)
|
||||
|
||||
// EnvVNCDisableDownscale disables any platform-specific framebuffer
|
||||
// downscaling (e.g. Retina 2:1). Set to 1/true to send the native resolution.
|
||||
const EnvVNCDisableDownscale = "NB_VNC_DISABLE_DOWNSCALE"
|
||||
|
||||
// ScreenCapturer grabs desktop frames for the VNC server.
|
||||
type ScreenCapturer interface {
|
||||
// Width returns the current screen width in pixels.
|
||||
Width() int
|
||||
// Height returns the current screen height in pixels.
|
||||
Height() int
|
||||
// Capture returns the current desktop as an RGBA image.
|
||||
Capture() (*image.RGBA, error)
|
||||
}
|
||||
|
||||
// InputInjector delivers keyboard and mouse events to the OS.
|
||||
type InputInjector interface {
|
||||
// InjectKey simulates a key press or release. keysym is an X11 KeySym.
|
||||
InjectKey(keysym uint32, down bool)
|
||||
// InjectPointer simulates mouse movement and button state.
|
||||
InjectPointer(buttonMask uint8, x, y, serverW, serverH int)
|
||||
// SetClipboard sets the system clipboard to the given text.
|
||||
SetClipboard(text string)
|
||||
// GetClipboard returns the current system clipboard text.
|
||||
GetClipboard() string
|
||||
}
|
||||
|
||||
// JWTConfig holds JWT validation configuration for VNC auth.
|
||||
type JWTConfig struct {
|
||||
Issuer string
|
||||
KeysLocation string
|
||||
MaxTokenAge int64
|
||||
Audiences []string
|
||||
}
|
||||
|
||||
// connectionHeader is sent by the client before the RFB handshake to specify
|
||||
// the VNC session mode and authenticate.
|
||||
type connectionHeader struct {
|
||||
mode byte
|
||||
username string
|
||||
jwt string
|
||||
sessionID uint32 // Windows session ID (0 = console/auto)
|
||||
}
|
||||
|
||||
// Server is the embedded VNC server that listens on the WireGuard interface.
|
||||
// It supports two operating modes:
|
||||
// - Direct mode: captures the screen and handles VNC sessions in-process.
|
||||
// Used when running in a user session with desktop access.
|
||||
// - Service mode: proxies VNC connections to an agent process spawned in
|
||||
// the active console session. Used when running as a Windows service in
|
||||
// Session 0.
|
||||
//
|
||||
// Within direct mode, each connection can request one of two session modes
|
||||
// via the connection header:
|
||||
// - Attach: capture the current physical display.
|
||||
// - Session: start a virtual Xvfb display as the requested user.
|
||||
type Server struct {
|
||||
capturer ScreenCapturer
|
||||
injector InputInjector
|
||||
password string
|
||||
serviceMode bool
|
||||
disableAuth bool
|
||||
localAddr netip.Addr // NetBird WireGuard IP this server is bound to
|
||||
network netip.Prefix // NetBird overlay network
|
||||
log *log.Entry
|
||||
|
||||
recordingDir string // when set, VNC sessions are recorded to this directory
|
||||
recordingEncKey string // base64-encoded X25519 public key for encrypting recordings
|
||||
|
||||
mu sync.Mutex
|
||||
listener net.Listener
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
vmgr virtualSessionManager
|
||||
jwtConfig *JWTConfig
|
||||
jwtValidator *nbjwt.Validator
|
||||
jwtExtractor *nbjwt.ClaimsExtractor
|
||||
authorizer *sshauth.Authorizer
|
||||
netstackNet *netstack.Net
|
||||
agentToken []byte // raw token bytes for agent-mode auth
|
||||
}
|
||||
|
||||
// vncSession provides capturer and injector for a virtual display session.
|
||||
type vncSession interface {
|
||||
Capturer() ScreenCapturer
|
||||
Injector() InputInjector
|
||||
Display() string
|
||||
ClientConnect()
|
||||
ClientDisconnect()
|
||||
}
|
||||
|
||||
// virtualSessionManager is implemented by sessionManager on Linux.
|
||||
type virtualSessionManager interface {
|
||||
GetOrCreate(username string) (vncSession, error)
|
||||
StopAll()
|
||||
}
|
||||
|
||||
// New creates a VNC server with the given screen capturer and input injector.
|
||||
func New(capturer ScreenCapturer, injector InputInjector, password string) *Server {
|
||||
return &Server{
|
||||
capturer: capturer,
|
||||
injector: injector,
|
||||
password: password,
|
||||
authorizer: sshauth.NewAuthorizer(),
|
||||
log: log.WithField("component", "vnc-server"),
|
||||
}
|
||||
}
|
||||
|
||||
// SetServiceMode enables proxy-to-agent mode for Windows service operation.
|
||||
func (s *Server) SetServiceMode(enabled bool) {
|
||||
s.serviceMode = enabled
|
||||
}
|
||||
|
||||
// SetJWTConfig configures JWT authentication for VNC connections.
|
||||
// Pass nil to disable JWT (public mode).
|
||||
func (s *Server) SetJWTConfig(config *JWTConfig) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.jwtConfig = config
|
||||
s.jwtValidator = nil
|
||||
s.jwtExtractor = nil
|
||||
}
|
||||
|
||||
// SetDisableAuth disables authentication entirely.
|
||||
func (s *Server) SetDisableAuth(disable bool) {
|
||||
s.disableAuth = disable
|
||||
}
|
||||
|
||||
// SetAgentToken sets a hex-encoded token that must be presented by incoming
|
||||
// connections before any VNC data. Used in agent mode to verify that only the
|
||||
// trusted service process connects.
|
||||
func (s *Server) SetAgentToken(hexToken string) {
|
||||
if hexToken == "" {
|
||||
return
|
||||
}
|
||||
b, err := hex.DecodeString(hexToken)
|
||||
if err != nil {
|
||||
s.log.Warnf("invalid agent token: %v", err)
|
||||
return
|
||||
}
|
||||
s.agentToken = b
|
||||
}
|
||||
|
||||
// SetNetstackNet sets the netstack network for userspace-only listening.
|
||||
// When set, the VNC server listens via netstack instead of a real OS socket.
|
||||
func (s *Server) SetNetstackNet(n *netstack.Net) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.netstackNet = n
|
||||
}
|
||||
|
||||
// SetRecordingDir enables VNC session recording to the given directory.
|
||||
func (s *Server) SetRecordingDir(dir string) {
|
||||
s.recordingDir = dir
|
||||
}
|
||||
|
||||
// SetRecordingEncryptionKey sets the base64-encoded X25519 public key for encrypting recordings.
|
||||
func (s *Server) SetRecordingEncryptionKey(key string) {
|
||||
s.recordingEncKey = key
|
||||
}
|
||||
|
||||
// UpdateVNCAuth updates the fine-grained authorization configuration.
|
||||
func (s *Server) UpdateVNCAuth(config *sshauth.Config) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.jwtValidator = nil
|
||||
s.jwtExtractor = nil
|
||||
s.authorizer.Update(config)
|
||||
}
|
||||
|
||||
// Start begins listening for VNC connections on the given address.
|
||||
// network is the NetBird overlay prefix used to validate connection sources.
|
||||
func (s *Server) Start(ctx context.Context, addr netip.AddrPort, network netip.Prefix) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
if s.listener != nil {
|
||||
return fmt.Errorf("server already running")
|
||||
}
|
||||
|
||||
s.ctx, s.cancel = context.WithCancel(ctx)
|
||||
s.vmgr = s.platformSessionManager()
|
||||
s.localAddr = addr.Addr()
|
||||
s.network = network
|
||||
|
||||
var listener net.Listener
|
||||
var listenDesc string
|
||||
if s.netstackNet != nil {
|
||||
ln, err := s.netstackNet.ListenTCPAddrPort(addr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("listen on netstack %s: %w", addr, err)
|
||||
}
|
||||
listener = ln
|
||||
listenDesc = fmt.Sprintf("netstack %s", addr)
|
||||
} else {
|
||||
tcpAddr := net.TCPAddrFromAddrPort(addr)
|
||||
ln, err := net.ListenTCP("tcp", tcpAddr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("listen on %s: %w", addr, err)
|
||||
}
|
||||
listener = ln
|
||||
listenDesc = addr.String()
|
||||
}
|
||||
s.listener = listener
|
||||
|
||||
if s.serviceMode {
|
||||
s.platformInit()
|
||||
}
|
||||
|
||||
if s.serviceMode {
|
||||
go s.serviceAcceptLoop()
|
||||
} else {
|
||||
go s.acceptLoop()
|
||||
}
|
||||
|
||||
s.log.Infof("started on %s (service_mode=%v)", listenDesc, s.serviceMode)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Stop shuts down the server and closes all connections.
|
||||
func (s *Server) Stop() error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
if s.cancel != nil {
|
||||
s.cancel()
|
||||
s.cancel = nil
|
||||
}
|
||||
|
||||
if s.vmgr != nil {
|
||||
s.vmgr.StopAll()
|
||||
}
|
||||
|
||||
if c, ok := s.capturer.(interface{ Close() }); ok {
|
||||
c.Close()
|
||||
}
|
||||
|
||||
if s.listener != nil {
|
||||
err := s.listener.Close()
|
||||
s.listener = nil
|
||||
if err != nil {
|
||||
return fmt.Errorf("close VNC listener: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
s.log.Info("stopped")
|
||||
return nil
|
||||
}
|
||||
|
||||
// acceptLoop handles VNC connections directly (user session mode).
|
||||
func (s *Server) acceptLoop() {
|
||||
for {
|
||||
conn, err := s.listener.Accept()
|
||||
if err != nil {
|
||||
select {
|
||||
case <-s.ctx.Done():
|
||||
return
|
||||
default:
|
||||
}
|
||||
s.log.Debugf("accept VNC connection: %v", err)
|
||||
continue
|
||||
}
|
||||
|
||||
go s.handleConnection(conn)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) validateCapturer(cap ScreenCapturer) error {
|
||||
// Quick check first: if already ready, return immediately.
|
||||
if cap.Width() > 0 && cap.Height() > 0 {
|
||||
return nil
|
||||
}
|
||||
// Capturer not ready: poke any retry loop that supports it so it doesn't
|
||||
// wait out its full backoff (e.g. macOS waiting for Screen Recording).
|
||||
if w, ok := cap.(interface{ Wake() }); ok {
|
||||
w.Wake()
|
||||
}
|
||||
// Wait up to 5s for the capturer to become ready.
|
||||
for range 50 {
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
if cap.Width() > 0 && cap.Height() > 0 {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
return errors.New("no display available (check X11 on Linux or Screen Recording permission on macOS)")
|
||||
}
|
||||
|
||||
// isAllowedSource rejects connections from outside the NetBird overlay network
|
||||
// and from the local WireGuard IP (prevents local privilege escalation).
|
||||
// Matches the SSH server's connectionValidator logic.
|
||||
func (s *Server) isAllowedSource(addr net.Addr) bool {
|
||||
tcpAddr, ok := addr.(*net.TCPAddr)
|
||||
if !ok {
|
||||
s.log.Warnf("connection rejected: non-TCP address %s", addr)
|
||||
return false
|
||||
}
|
||||
|
||||
remoteIP, ok := netip.AddrFromSlice(tcpAddr.IP)
|
||||
if !ok {
|
||||
s.log.Warnf("connection rejected: invalid remote IP %s", tcpAddr.IP)
|
||||
return false
|
||||
}
|
||||
remoteIP = remoteIP.Unmap()
|
||||
|
||||
if remoteIP.IsLoopback() && s.localAddr.IsLoopback() {
|
||||
return true
|
||||
}
|
||||
|
||||
if remoteIP == s.localAddr {
|
||||
s.log.Warnf("connection rejected from own IP %s", remoteIP)
|
||||
return false
|
||||
}
|
||||
|
||||
if s.network.IsValid() && !s.network.Contains(remoteIP) {
|
||||
s.log.Warnf("connection rejected from non-NetBird IP %s", remoteIP)
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func (s *Server) handleConnection(conn net.Conn) {
|
||||
connLog := s.log.WithField("remote", conn.RemoteAddr().String())
|
||||
|
||||
if !s.isAllowedSource(conn.RemoteAddr()) {
|
||||
conn.Close()
|
||||
return
|
||||
}
|
||||
|
||||
if len(s.agentToken) > 0 {
|
||||
buf := make([]byte, len(s.agentToken))
|
||||
if err := conn.SetReadDeadline(time.Now().Add(5 * time.Second)); err != nil {
|
||||
connLog.Debugf("set agent token deadline: %v", err)
|
||||
conn.Close()
|
||||
return
|
||||
}
|
||||
if _, err := io.ReadFull(conn, buf); err != nil {
|
||||
connLog.Warnf("agent auth: read token: %v", err)
|
||||
conn.Close()
|
||||
return
|
||||
}
|
||||
conn.SetReadDeadline(time.Time{}) //nolint:errcheck
|
||||
if subtle.ConstantTimeCompare(buf, s.agentToken) != 1 {
|
||||
connLog.Warn("agent auth: invalid token, rejecting")
|
||||
conn.Close()
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
header, err := readConnectionHeader(conn)
|
||||
if err != nil {
|
||||
connLog.Warnf("read connection header: %v", err)
|
||||
conn.Close()
|
||||
return
|
||||
}
|
||||
|
||||
if !s.disableAuth {
|
||||
if s.jwtConfig == nil {
|
||||
rejectConnection(conn, codeMessage(RejectCodeAuthConfig, "auth enabled but no identity provider configured"))
|
||||
connLog.Warn("auth rejected: no identity provider configured")
|
||||
return
|
||||
}
|
||||
jwtUserID, err := s.authenticateJWT(header)
|
||||
if err != nil {
|
||||
rejectConnection(conn, codeMessage(jwtErrorCode(err), err.Error()))
|
||||
connLog.Warnf("auth rejected: %v", err)
|
||||
return
|
||||
}
|
||||
connLog = connLog.WithField("jwt_user", jwtUserID)
|
||||
}
|
||||
|
||||
var capturer ScreenCapturer
|
||||
var injector InputInjector
|
||||
|
||||
switch header.mode {
|
||||
case ModeSession:
|
||||
if s.vmgr == nil {
|
||||
rejectConnection(conn, codeMessage(RejectCodeUnsupportedOS, "virtual sessions not supported on this platform"))
|
||||
connLog.Warn("session rejected: not supported on this platform")
|
||||
return
|
||||
}
|
||||
if header.username == "" {
|
||||
rejectConnection(conn, codeMessage(RejectCodeBadRequest, "session mode requires a username"))
|
||||
connLog.Warn("session rejected: no username provided")
|
||||
return
|
||||
}
|
||||
vs, err := s.vmgr.GetOrCreate(header.username)
|
||||
if err != nil {
|
||||
rejectConnection(conn, codeMessage(RejectCodeSessionError, fmt.Sprintf("create virtual session: %v", err)))
|
||||
connLog.Warnf("create virtual session for %s: %v", header.username, err)
|
||||
return
|
||||
}
|
||||
capturer = vs.Capturer()
|
||||
injector = vs.Injector()
|
||||
vs.ClientConnect()
|
||||
defer vs.ClientDisconnect()
|
||||
connLog = connLog.WithField("vnc_user", header.username)
|
||||
connLog.Infof("session mode: user=%s display=%s", header.username, vs.Display())
|
||||
|
||||
default:
|
||||
capturer = s.capturer
|
||||
injector = s.injector
|
||||
if cc, ok := capturer.(interface{ ClientConnect() }); ok {
|
||||
cc.ClientConnect()
|
||||
}
|
||||
defer func() {
|
||||
if cd, ok := capturer.(interface{ ClientDisconnect() }); ok {
|
||||
cd.ClientDisconnect()
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
if err := s.validateCapturer(capturer); err != nil {
|
||||
rejectConnection(conn, codeMessage(RejectCodeCapturerError, fmt.Sprintf("screen capturer: %v", err)))
|
||||
connLog.Warnf("capturer not ready: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
var rec *vncRecorder
|
||||
if s.recordingDir != "" {
|
||||
mode := "attach"
|
||||
if header.mode == ModeSession {
|
||||
mode = "session"
|
||||
}
|
||||
jwtUser, _ := connLog.Data["jwt_user"].(string)
|
||||
var err error
|
||||
rec, err = newVNCRecorder(s.recordingDir, capturer.Width(), capturer.Height(), &RecordingMeta{
|
||||
User: header.username,
|
||||
RemoteAddr: conn.RemoteAddr().String(),
|
||||
JWTUser: jwtUser,
|
||||
Mode: mode,
|
||||
}, s.recordingEncKey, connLog)
|
||||
if err != nil {
|
||||
connLog.Warnf("start VNC recording: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
sess := &session{
|
||||
conn: conn,
|
||||
capturer: capturer,
|
||||
injector: injector,
|
||||
serverW: capturer.Width(),
|
||||
serverH: capturer.Height(),
|
||||
password: s.password,
|
||||
log: connLog,
|
||||
recorder: rec,
|
||||
}
|
||||
sess.serve()
|
||||
}
|
||||
|
||||
// codeMessage formats a stable reject code with a human-readable message.
|
||||
// Dashboards split on the first ": " to recover the code without parsing the
|
||||
// free-text suffix.
|
||||
func codeMessage(code, msg string) string {
|
||||
return code + ": " + msg
|
||||
}
|
||||
|
||||
// jwtErrorCode maps a JWT auth error to a stable reject code.
|
||||
func jwtErrorCode(err error) string {
|
||||
if err == nil {
|
||||
return RejectCodeJWTInvalid
|
||||
}
|
||||
if errors.Is(err, nbjwt.ErrTokenExpired) {
|
||||
return RejectCodeJWTExpired
|
||||
}
|
||||
msg := err.Error()
|
||||
switch {
|
||||
case strings.Contains(msg, "JWT required but not provided"):
|
||||
return RejectCodeJWTMissing
|
||||
case strings.Contains(msg, "authorize") || strings.Contains(msg, "not authorized"):
|
||||
return RejectCodeAuthForbidden
|
||||
default:
|
||||
return RejectCodeJWTInvalid
|
||||
}
|
||||
}
|
||||
|
||||
// rejectConnection sends a minimal RFB handshake with a security failure
|
||||
// reason, so VNC clients display the error message instead of a generic
|
||||
// "unexpected disconnect."
|
||||
func rejectConnection(conn net.Conn, reason string) {
|
||||
defer conn.Close()
|
||||
// RFB 3.8 server version.
|
||||
io.WriteString(conn, "RFB 003.008\n")
|
||||
// Read client version (12 bytes), ignore errors.
|
||||
var clientVer [12]byte
|
||||
conn.SetReadDeadline(time.Now().Add(2 * time.Second))
|
||||
io.ReadFull(conn, clientVer[:])
|
||||
conn.SetReadDeadline(time.Time{})
|
||||
// Send 0 security types = connection failed, followed by reason.
|
||||
msg := []byte(reason)
|
||||
buf := make([]byte, 1+4+len(msg))
|
||||
buf[0] = 0 // 0 security types = failure
|
||||
binary.BigEndian.PutUint32(buf[1:5], uint32(len(msg)))
|
||||
copy(buf[5:], msg)
|
||||
conn.Write(buf)
|
||||
}
|
||||
|
||||
const defaultJWTMaxTokenAge = 10 * 60 // 10 minutes
|
||||
|
||||
// authenticateJWT validates the JWT from the connection header and checks
|
||||
// authorization. For attach mode, just checks membership in the authorized
|
||||
// user list. For session mode, additionally validates the OS user mapping.
|
||||
func (s *Server) authenticateJWT(header *connectionHeader) (string, error) {
|
||||
if header.jwt == "" {
|
||||
return "", fmt.Errorf("JWT required but not provided")
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
if err := s.ensureJWTValidator(); err != nil {
|
||||
s.mu.Unlock()
|
||||
return "", fmt.Errorf("initialize JWT validator: %w", err)
|
||||
}
|
||||
validator := s.jwtValidator
|
||||
extractor := s.jwtExtractor
|
||||
s.mu.Unlock()
|
||||
|
||||
token, err := validator.ValidateAndParse(context.Background(), header.jwt)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("validate JWT: %w", err)
|
||||
}
|
||||
|
||||
if err := s.checkTokenAge(token); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
userAuth, err := extractor.ToUserAuth(token)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("extract user from JWT: %w", err)
|
||||
}
|
||||
if userAuth.UserId == "" {
|
||||
return "", fmt.Errorf("JWT has no user ID")
|
||||
}
|
||||
|
||||
switch header.mode {
|
||||
case ModeSession:
|
||||
// Session mode: check user + OS username mapping.
|
||||
if _, err := s.authorizer.Authorize(userAuth.UserId, header.username); err != nil {
|
||||
return "", fmt.Errorf("authorize session for %s: %w", header.username, err)
|
||||
}
|
||||
default:
|
||||
// Attach mode: just check user is in the authorized list (wildcard OS user).
|
||||
if _, err := s.authorizer.Authorize(userAuth.UserId, "*"); err != nil {
|
||||
return "", fmt.Errorf("user not authorized for VNC: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return userAuth.UserId, nil
|
||||
}
|
||||
|
||||
// ensureJWTValidator lazily initializes the JWT validator. Must be called with mu held.
|
||||
func (s *Server) ensureJWTValidator() error {
|
||||
if s.jwtValidator != nil && s.jwtExtractor != nil {
|
||||
return nil
|
||||
}
|
||||
if s.jwtConfig == nil {
|
||||
return fmt.Errorf("no JWT config")
|
||||
}
|
||||
|
||||
s.jwtValidator = nbjwt.NewValidator(
|
||||
s.jwtConfig.Issuer,
|
||||
s.jwtConfig.Audiences,
|
||||
s.jwtConfig.KeysLocation,
|
||||
false,
|
||||
)
|
||||
|
||||
opts := []nbjwt.ClaimsExtractorOption{nbjwt.WithAudience(s.jwtConfig.Audiences[0])}
|
||||
if claim := s.authorizer.GetUserIDClaim(); claim != "" {
|
||||
opts = append(opts, nbjwt.WithUserIDClaim(claim))
|
||||
}
|
||||
s.jwtExtractor = nbjwt.NewClaimsExtractor(opts...)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Server) checkTokenAge(token *gojwt.Token) error {
|
||||
maxAge := defaultJWTMaxTokenAge
|
||||
if s.jwtConfig != nil && s.jwtConfig.MaxTokenAge > 0 {
|
||||
maxAge = int(s.jwtConfig.MaxTokenAge)
|
||||
}
|
||||
return nbjwt.CheckTokenAge(token, time.Duration(maxAge)*time.Second)
|
||||
}
|
||||
|
||||
// readConnectionHeader reads the NetBird VNC session header from the connection.
|
||||
// Format: [mode: 1 byte] [username_len: 2 bytes BE] [username: N bytes]
|
||||
//
|
||||
// [jwt_len: 2 bytes BE] [jwt: N bytes]
|
||||
//
|
||||
// Uses a short timeout: our WASM proxy sends the header immediately after
|
||||
// connecting. Standard VNC clients don't send anything first (server speaks
|
||||
// first in RFB), so they time out and get the default attach mode.
|
||||
func readConnectionHeader(conn net.Conn) (*connectionHeader, error) {
|
||||
if err := conn.SetReadDeadline(time.Now().Add(2 * time.Second)); err != nil {
|
||||
return nil, fmt.Errorf("set deadline: %w", err)
|
||||
}
|
||||
defer conn.SetReadDeadline(time.Time{}) //nolint:errcheck
|
||||
|
||||
var hdr [3]byte
|
||||
if _, err := io.ReadFull(conn, hdr[:]); err != nil {
|
||||
// Timeout or error: assume no header, use attach mode.
|
||||
return &connectionHeader{mode: ModeAttach}, nil
|
||||
}
|
||||
|
||||
// Restore a longer deadline for reading variable-length fields.
|
||||
if err := conn.SetReadDeadline(time.Now().Add(5 * time.Second)); err != nil {
|
||||
return nil, fmt.Errorf("set deadline: %w", err)
|
||||
}
|
||||
|
||||
mode := hdr[0]
|
||||
usernameLen := binary.BigEndian.Uint16(hdr[1:3])
|
||||
|
||||
var username string
|
||||
if usernameLen > 0 {
|
||||
if usernameLen > 256 {
|
||||
return nil, fmt.Errorf("username too long: %d", usernameLen)
|
||||
}
|
||||
buf := make([]byte, usernameLen)
|
||||
if _, err := io.ReadFull(conn, buf); err != nil {
|
||||
return nil, fmt.Errorf("read username: %w", err)
|
||||
}
|
||||
username = string(buf)
|
||||
}
|
||||
|
||||
// Read JWT token length and data.
|
||||
var jwtLenBuf [2]byte
|
||||
var jwtToken string
|
||||
if _, err := io.ReadFull(conn, jwtLenBuf[:]); err == nil {
|
||||
jwtLen := binary.BigEndian.Uint16(jwtLenBuf[:])
|
||||
if jwtLen > 0 && jwtLen < 8192 {
|
||||
buf := make([]byte, jwtLen)
|
||||
if _, err := io.ReadFull(conn, buf); err != nil {
|
||||
return nil, fmt.Errorf("read JWT: %w", err)
|
||||
}
|
||||
jwtToken = string(buf)
|
||||
}
|
||||
}
|
||||
|
||||
// Read optional Windows session ID (4 bytes BE). Missing = 0 (console/auto).
|
||||
var sessionID uint32
|
||||
var sidBuf [4]byte
|
||||
if _, err := io.ReadFull(conn, sidBuf[:]); err == nil {
|
||||
sessionID = binary.BigEndian.Uint32(sidBuf[:])
|
||||
}
|
||||
|
||||
return &connectionHeader{mode: mode, username: username, jwt: jwtToken, sessionID: sessionID}, nil
|
||||
}
|
||||
@@ -1,15 +0,0 @@
|
||||
//go:build darwin && !ios
|
||||
|
||||
package server
|
||||
|
||||
func (s *Server) platformInit() {}
|
||||
|
||||
// serviceAcceptLoop is not supported on macOS.
|
||||
func (s *Server) serviceAcceptLoop() {
|
||||
s.log.Warn("service mode not supported on macOS, falling back to direct mode")
|
||||
s.acceptLoop()
|
||||
}
|
||||
|
||||
func (s *Server) platformSessionManager() virtualSessionManager {
|
||||
return nil
|
||||
}
|
||||
@@ -1,15 +0,0 @@
|
||||
//go:build !windows && !darwin && !freebsd && !(linux && !android)
|
||||
|
||||
package server
|
||||
|
||||
func (s *Server) platformInit() {}
|
||||
|
||||
// serviceAcceptLoop is not supported on non-Windows platforms.
|
||||
func (s *Server) serviceAcceptLoop() {
|
||||
s.log.Warn("service mode not supported on this platform, falling back to direct mode")
|
||||
s.acceptLoop()
|
||||
}
|
||||
|
||||
func (s *Server) platformSessionManager() virtualSessionManager {
|
||||
return nil
|
||||
}
|
||||
@@ -1,136 +0,0 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"image"
|
||||
"io"
|
||||
"net"
|
||||
"net/netip"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// testCapturer returns a 100x100 image for test sessions.
|
||||
type testCapturer struct{}
|
||||
|
||||
func (t *testCapturer) Width() int { return 100 }
|
||||
func (t *testCapturer) Height() int { return 100 }
|
||||
func (t *testCapturer) Capture() (*image.RGBA, error) { return image.NewRGBA(image.Rect(0, 0, 100, 100)), nil }
|
||||
|
||||
func startTestServer(t *testing.T, disableAuth bool, jwtConfig *JWTConfig) (net.Addr, *Server) {
|
||||
t.Helper()
|
||||
|
||||
srv := New(&testCapturer{}, &StubInputInjector{}, "")
|
||||
srv.SetDisableAuth(disableAuth)
|
||||
if jwtConfig != nil {
|
||||
srv.SetJWTConfig(jwtConfig)
|
||||
}
|
||||
|
||||
addr := netip.MustParseAddrPort("127.0.0.1:0")
|
||||
network := netip.MustParsePrefix("127.0.0.0/8")
|
||||
require.NoError(t, srv.Start(t.Context(), addr, network))
|
||||
// Override local address so source validation doesn't reject 127.0.0.1 as "own IP".
|
||||
srv.localAddr = netip.MustParseAddr("10.99.99.1")
|
||||
t.Cleanup(func() { _ = srv.Stop() })
|
||||
|
||||
return srv.listener.Addr(), srv
|
||||
}
|
||||
|
||||
func TestAuthEnabled_NoJWTConfig_RejectsConnection(t *testing.T) {
|
||||
addr, _ := startTestServer(t, false, nil)
|
||||
|
||||
conn, err := net.Dial("tcp", addr.String())
|
||||
require.NoError(t, err)
|
||||
defer conn.Close()
|
||||
|
||||
// Send session header: attach mode, no username, no JWT.
|
||||
header := []byte{ModeAttach, 0, 0, 0, 0}
|
||||
_, err = conn.Write(header)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Server should send RFB version then security failure.
|
||||
var version [12]byte
|
||||
_, err = io.ReadFull(conn, version[:])
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "RFB 003.008\n", string(version[:]))
|
||||
|
||||
// Write client version to proceed through handshake.
|
||||
_, err = conn.Write(version[:])
|
||||
require.NoError(t, err)
|
||||
|
||||
// Read security types: 0 means failure, followed by reason.
|
||||
var numTypes [1]byte
|
||||
_, err = io.ReadFull(conn, numTypes[:])
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, byte(0), numTypes[0], "should have 0 security types (failure)")
|
||||
|
||||
var reasonLen [4]byte
|
||||
_, err = io.ReadFull(conn, reasonLen[:])
|
||||
require.NoError(t, err)
|
||||
|
||||
reason := make([]byte, binary.BigEndian.Uint32(reasonLen[:]))
|
||||
_, err = io.ReadFull(conn, reason)
|
||||
require.NoError(t, err)
|
||||
assert.Contains(t, string(reason), "identity provider", "rejection reason should mention missing IdP config")
|
||||
}
|
||||
|
||||
func TestAuthDisabled_AllowsConnection(t *testing.T) {
|
||||
addr, _ := startTestServer(t, true, nil)
|
||||
|
||||
conn, err := net.Dial("tcp", addr.String())
|
||||
require.NoError(t, err)
|
||||
defer conn.Close()
|
||||
|
||||
// Send session header: attach mode, no username, no JWT.
|
||||
header := []byte{ModeAttach, 0, 0, 0, 0}
|
||||
_, err = conn.Write(header)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Server should send RFB version.
|
||||
var version [12]byte
|
||||
_, err = io.ReadFull(conn, version[:])
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "RFB 003.008\n", string(version[:]))
|
||||
|
||||
// Write client version.
|
||||
_, err = conn.Write(version[:])
|
||||
require.NoError(t, err)
|
||||
|
||||
// Should get security types (not 0 = failure).
|
||||
var numTypes [1]byte
|
||||
_, err = io.ReadFull(conn, numTypes[:])
|
||||
require.NoError(t, err)
|
||||
assert.NotEqual(t, byte(0), numTypes[0], "should have at least one security type (auth disabled)")
|
||||
}
|
||||
|
||||
func TestAuthEnabled_EmptyJWT_Rejected(t *testing.T) {
|
||||
// Auth enabled with a (bogus) JWT config: connections without JWT should be rejected.
|
||||
addr, _ := startTestServer(t, false, &JWTConfig{
|
||||
Issuer: "https://example.com",
|
||||
KeysLocation: "https://example.com/.well-known/jwks.json",
|
||||
Audiences: []string{"test"},
|
||||
})
|
||||
|
||||
conn, err := net.Dial("tcp", addr.String())
|
||||
require.NoError(t, err)
|
||||
defer conn.Close()
|
||||
|
||||
// Send session header with empty JWT.
|
||||
header := []byte{ModeAttach, 0, 0, 0, 0}
|
||||
_, err = conn.Write(header)
|
||||
require.NoError(t, err)
|
||||
|
||||
var version [12]byte
|
||||
_, err = io.ReadFull(conn, version[:])
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = conn.Write(version[:])
|
||||
require.NoError(t, err)
|
||||
|
||||
var numTypes [1]byte
|
||||
_, err = io.ReadFull(conn, numTypes[:])
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, byte(0), numTypes[0], "should reject with 0 security types")
|
||||
}
|
||||
@@ -1,222 +0,0 @@
|
||||
//go:build windows
|
||||
|
||||
package server
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"net"
|
||||
"unsafe"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/sys/windows"
|
||||
"golang.org/x/sys/windows/registry"
|
||||
)
|
||||
|
||||
var (
|
||||
sasDLL = windows.NewLazySystemDLL("sas.dll")
|
||||
procSendSAS = sasDLL.NewProc("SendSAS")
|
||||
|
||||
procConvertStringSecurityDescriptorToSecurityDescriptor = advapi32.NewProc("ConvertStringSecurityDescriptorToSecurityDescriptorW")
|
||||
)
|
||||
|
||||
// sasSecurityAttributes builds a SECURITY_ATTRIBUTES that grants
|
||||
// EVENT_MODIFY_STATE only to the SYSTEM account, preventing unprivileged
|
||||
// local processes from triggering the Secure Attention Sequence.
|
||||
func sasSecurityAttributes() (*windows.SecurityAttributes, error) {
|
||||
// SDDL: grant full access to SYSTEM (creates/waits) and EVENT_MODIFY_STATE
|
||||
// to the interactive user (IU) so the VNC agent in the console session can
|
||||
// signal it. Other local users and network users are denied.
|
||||
sddl, err := windows.UTF16PtrFromString("D:(A;;GA;;;SY)(A;;0x0002;;;IU)")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var sd uintptr
|
||||
r, _, lerr := procConvertStringSecurityDescriptorToSecurityDescriptor.Call(
|
||||
uintptr(unsafe.Pointer(sddl)),
|
||||
1, // SDDL_REVISION_1
|
||||
uintptr(unsafe.Pointer(&sd)),
|
||||
0,
|
||||
)
|
||||
if r == 0 {
|
||||
return nil, lerr
|
||||
}
|
||||
return &windows.SecurityAttributes{
|
||||
Length: uint32(unsafe.Sizeof(windows.SecurityAttributes{})),
|
||||
SecurityDescriptor: (*windows.SECURITY_DESCRIPTOR)(unsafe.Pointer(sd)),
|
||||
InheritHandle: 0,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// enableSoftwareSAS sets the SoftwareSASGeneration registry key to allow
|
||||
// services to trigger the Secure Attention Sequence via SendSAS. Without this,
|
||||
// SendSAS silently does nothing on most Windows editions.
|
||||
func enableSoftwareSAS() {
|
||||
key, _, err := registry.CreateKey(
|
||||
registry.LOCAL_MACHINE,
|
||||
`SOFTWARE\Microsoft\Windows\CurrentVersion\Policies\System`,
|
||||
registry.SET_VALUE,
|
||||
)
|
||||
if err != nil {
|
||||
log.Warnf("open SoftwareSASGeneration registry key: %v", err)
|
||||
return
|
||||
}
|
||||
defer key.Close()
|
||||
|
||||
if err := key.SetDWordValue("SoftwareSASGeneration", 1); err != nil {
|
||||
log.Warnf("set SoftwareSASGeneration: %v", err)
|
||||
return
|
||||
}
|
||||
log.Debug("SoftwareSASGeneration registry key set to 1 (services allowed)")
|
||||
}
|
||||
|
||||
// startSASListener creates a named event with a restricted DACL and waits for
|
||||
// the VNC input injector to signal it. When signaled, it calls SendSAS(FALSE)
|
||||
// from Session 0 to trigger the Secure Attention Sequence (Ctrl+Alt+Del).
|
||||
// Only SYSTEM processes can open the event.
|
||||
func startSASListener() {
|
||||
enableSoftwareSAS()
|
||||
namePtr, err := windows.UTF16PtrFromString(sasEventName)
|
||||
if err != nil {
|
||||
log.Warnf("SAS listener UTF16: %v", err)
|
||||
return
|
||||
}
|
||||
sa, err := sasSecurityAttributes()
|
||||
if err != nil {
|
||||
log.Warnf("build SAS security descriptor: %v", err)
|
||||
return
|
||||
}
|
||||
ev, err := windows.CreateEvent(sa, 0, 0, namePtr)
|
||||
if err != nil {
|
||||
log.Warnf("SAS CreateEvent: %v", err)
|
||||
return
|
||||
}
|
||||
log.Info("SAS listener ready (Session 0)")
|
||||
go func() {
|
||||
defer windows.CloseHandle(ev)
|
||||
for {
|
||||
ret, _ := windows.WaitForSingleObject(ev, windows.INFINITE)
|
||||
if ret == windows.WAIT_OBJECT_0 {
|
||||
r, _, sasErr := procSendSAS.Call(0) // FALSE = not from service desktop
|
||||
if r == 0 {
|
||||
log.Warnf("SendSAS: %v", sasErr)
|
||||
} else {
|
||||
log.Info("SendSAS called from Session 0")
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// enablePrivilege enables a named privilege on the current process token.
|
||||
func enablePrivilege(name string) error {
|
||||
var token windows.Token
|
||||
if err := windows.OpenProcessToken(windows.CurrentProcess(),
|
||||
windows.TOKEN_ADJUST_PRIVILEGES|windows.TOKEN_QUERY, &token); err != nil {
|
||||
return err
|
||||
}
|
||||
defer token.Close()
|
||||
|
||||
var luid windows.LUID
|
||||
namePtr, _ := windows.UTF16PtrFromString(name)
|
||||
if err := windows.LookupPrivilegeValue(nil, namePtr, &luid); err != nil {
|
||||
return err
|
||||
}
|
||||
tp := windows.Tokenprivileges{PrivilegeCount: 1}
|
||||
tp.Privileges[0].Luid = luid
|
||||
tp.Privileges[0].Attributes = windows.SE_PRIVILEGE_ENABLED
|
||||
return windows.AdjustTokenPrivileges(token, false, &tp, 0, nil, nil)
|
||||
}
|
||||
|
||||
func (s *Server) platformSessionManager() virtualSessionManager {
|
||||
return nil
|
||||
}
|
||||
|
||||
// platformInit starts the SAS listener and enables privileges needed for
|
||||
// Session 0 operations (agent spawning, SendSAS).
|
||||
func (s *Server) platformInit() {
|
||||
for _, priv := range []string{"SeTcbPrivilege", "SeAssignPrimaryTokenPrivilege"} {
|
||||
if err := enablePrivilege(priv); err != nil {
|
||||
log.Debugf("enable %s: %v", priv, err)
|
||||
}
|
||||
}
|
||||
startSASListener()
|
||||
}
|
||||
|
||||
// serviceAcceptLoop runs in Session 0. It validates source IP and
|
||||
// authenticates via JWT before proxying connections to the user-session agent.
|
||||
func (s *Server) serviceAcceptLoop() {
|
||||
|
||||
sm := newSessionManager(agentPort)
|
||||
go sm.run()
|
||||
|
||||
log.Infof("service mode, proxying connections to agent on 127.0.0.1:%s", agentPort)
|
||||
|
||||
for {
|
||||
conn, err := s.listener.Accept()
|
||||
if err != nil {
|
||||
select {
|
||||
case <-s.ctx.Done():
|
||||
sm.Stop()
|
||||
return
|
||||
default:
|
||||
}
|
||||
s.log.Debugf("accept VNC connection: %v", err)
|
||||
continue
|
||||
}
|
||||
|
||||
go s.handleServiceConnection(conn, sm)
|
||||
}
|
||||
}
|
||||
|
||||
// handleServiceConnection validates the source IP and JWT, then proxies
|
||||
// the connection (with header bytes replayed) to the agent.
|
||||
func (s *Server) handleServiceConnection(conn net.Conn, sm *sessionManager) {
|
||||
connLog := s.log.WithField("remote", conn.RemoteAddr().String())
|
||||
|
||||
if !s.isAllowedSource(conn.RemoteAddr()) {
|
||||
conn.Close()
|
||||
return
|
||||
}
|
||||
|
||||
var headerBuf bytes.Buffer
|
||||
tee := io.TeeReader(conn, &headerBuf)
|
||||
teeConn := &prefixConn{Reader: tee, Conn: conn}
|
||||
|
||||
header, err := readConnectionHeader(teeConn)
|
||||
if err != nil {
|
||||
connLog.Debugf("read connection header: %v", err)
|
||||
conn.Close()
|
||||
return
|
||||
}
|
||||
|
||||
if !s.disableAuth {
|
||||
if s.jwtConfig == nil {
|
||||
rejectConnection(conn, codeMessage(RejectCodeAuthConfig, "auth enabled but no identity provider configured"))
|
||||
connLog.Warn("auth rejected: no identity provider configured")
|
||||
return
|
||||
}
|
||||
if _, err := s.authenticateJWT(header); err != nil {
|
||||
rejectConnection(conn, codeMessage(jwtErrorCode(err), err.Error()))
|
||||
connLog.Warnf("auth rejected: %v", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Replay buffered header bytes + remaining stream to the agent.
|
||||
replayConn := &prefixConn{
|
||||
Reader: io.MultiReader(&headerBuf, conn),
|
||||
Conn: conn,
|
||||
}
|
||||
proxyToAgent(replayConn, agentPort, sm.AuthToken())
|
||||
}
|
||||
|
||||
// prefixConn wraps a net.Conn, overriding Read to use a different reader.
|
||||
type prefixConn struct {
|
||||
io.Reader
|
||||
net.Conn
|
||||
}
|
||||
|
||||
func (p *prefixConn) Read(b []byte) (int, error) {
|
||||
return p.Reader.Read(b)
|
||||
}
|
||||
@@ -1,15 +0,0 @@
|
||||
//go:build (linux && !android) || freebsd
|
||||
|
||||
package server
|
||||
|
||||
func (s *Server) platformInit() {}
|
||||
|
||||
// serviceAcceptLoop is not supported on Linux.
|
||||
func (s *Server) serviceAcceptLoop() {
|
||||
s.log.Warn("service mode not supported on Linux, falling back to direct mode")
|
||||
s.acceptLoop()
|
||||
}
|
||||
|
||||
func (s *Server) platformSessionManager() virtualSessionManager {
|
||||
return newSessionManager(s.log)
|
||||
}
|
||||
@@ -1,451 +0,0 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/rand"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"image"
|
||||
"io"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
readDeadline = 60 * time.Second
|
||||
maxCutTextBytes = 1 << 20 // 1 MiB
|
||||
)
|
||||
|
||||
const tileSize = 64 // pixels per tile for dirty-rect detection
|
||||
|
||||
type session struct {
|
||||
conn net.Conn
|
||||
capturer ScreenCapturer
|
||||
injector InputInjector
|
||||
serverW int
|
||||
serverH int
|
||||
password string
|
||||
log *log.Entry
|
||||
recorder *vncRecorder
|
||||
|
||||
writeMu sync.Mutex
|
||||
pf clientPixelFormat
|
||||
useZlib bool
|
||||
zlib *zlibState
|
||||
prevFrame *image.RGBA
|
||||
idleFrames int
|
||||
}
|
||||
|
||||
func (s *session) addr() string { return s.conn.RemoteAddr().String() }
|
||||
|
||||
// serve runs the full RFB session lifecycle.
|
||||
func (s *session) serve() {
|
||||
defer s.conn.Close()
|
||||
if s.recorder != nil {
|
||||
defer s.recorder.close()
|
||||
}
|
||||
s.pf = defaultClientPixelFormat()
|
||||
|
||||
if err := s.handshake(); err != nil {
|
||||
s.log.Warnf("handshake with %s: %v", s.addr(), err)
|
||||
return
|
||||
}
|
||||
s.log.Infof("client connected: %s", s.addr())
|
||||
|
||||
done := make(chan struct{})
|
||||
defer close(done)
|
||||
go s.clipboardPoll(done)
|
||||
|
||||
if err := s.messageLoop(); err != nil && err != io.EOF {
|
||||
s.log.Warnf("client %s disconnected: %v", s.addr(), err)
|
||||
} else {
|
||||
s.log.Infof("client disconnected: %s", s.addr())
|
||||
}
|
||||
}
|
||||
|
||||
// clipboardPoll periodically checks the server-side clipboard and sends
|
||||
// changes to the VNC client. Only runs during active sessions.
|
||||
func (s *session) clipboardPoll(done <-chan struct{}) {
|
||||
ticker := time.NewTicker(2 * time.Second)
|
||||
defer ticker.Stop()
|
||||
|
||||
var lastClip string
|
||||
for {
|
||||
select {
|
||||
case <-done:
|
||||
return
|
||||
case <-ticker.C:
|
||||
text := s.injector.GetClipboard()
|
||||
if len(text) > maxCutTextBytes {
|
||||
text = text[:maxCutTextBytes]
|
||||
}
|
||||
if text != "" && text != lastClip {
|
||||
lastClip = text
|
||||
if err := s.sendServerCutText(text); err != nil {
|
||||
s.log.Debugf("send clipboard to client: %v", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *session) handshake() error {
|
||||
// Send protocol version.
|
||||
if _, err := io.WriteString(s.conn, rfbProtocolVersion); err != nil {
|
||||
return fmt.Errorf("send version: %w", err)
|
||||
}
|
||||
|
||||
// Read client version.
|
||||
var clientVer [12]byte
|
||||
if _, err := io.ReadFull(s.conn, clientVer[:]); err != nil {
|
||||
return fmt.Errorf("read client version: %w", err)
|
||||
}
|
||||
|
||||
// Send supported security types.
|
||||
if err := s.sendSecurityTypes(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Read chosen security type.
|
||||
var secType [1]byte
|
||||
if _, err := io.ReadFull(s.conn, secType[:]); err != nil {
|
||||
return fmt.Errorf("read security type: %w", err)
|
||||
}
|
||||
|
||||
if err := s.handleSecurity(secType[0]); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Read ClientInit.
|
||||
var clientInit [1]byte
|
||||
if _, err := io.ReadFull(s.conn, clientInit[:]); err != nil {
|
||||
return fmt.Errorf("read ClientInit: %w", err)
|
||||
}
|
||||
|
||||
return s.sendServerInit()
|
||||
}
|
||||
|
||||
func (s *session) sendSecurityTypes() error {
|
||||
if s.password == "" {
|
||||
_, err := s.conn.Write([]byte{1, secNone})
|
||||
return err
|
||||
}
|
||||
_, err := s.conn.Write([]byte{1, secVNCAuth})
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *session) handleSecurity(secType byte) error {
|
||||
switch secType {
|
||||
case secVNCAuth:
|
||||
return s.doVNCAuth()
|
||||
case secNone:
|
||||
return binary.Write(s.conn, binary.BigEndian, uint32(0))
|
||||
default:
|
||||
return fmt.Errorf("unsupported security type: %d", secType)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *session) doVNCAuth() error {
|
||||
challenge := make([]byte, 16)
|
||||
if _, err := rand.Read(challenge); err != nil {
|
||||
return fmt.Errorf("generate challenge: %w", err)
|
||||
}
|
||||
if _, err := s.conn.Write(challenge); err != nil {
|
||||
return fmt.Errorf("send challenge: %w", err)
|
||||
}
|
||||
|
||||
response := make([]byte, 16)
|
||||
if _, err := io.ReadFull(s.conn, response); err != nil {
|
||||
return fmt.Errorf("read auth response: %w", err)
|
||||
}
|
||||
|
||||
var result uint32
|
||||
if s.password != "" {
|
||||
expected := vncAuthEncrypt(challenge, s.password)
|
||||
if !bytes.Equal(expected, response) {
|
||||
result = 1
|
||||
}
|
||||
}
|
||||
|
||||
if err := binary.Write(s.conn, binary.BigEndian, result); err != nil {
|
||||
return fmt.Errorf("send auth result: %w", err)
|
||||
}
|
||||
if result != 0 {
|
||||
msg := "authentication failed"
|
||||
_ = binary.Write(s.conn, binary.BigEndian, uint32(len(msg)))
|
||||
_, _ = s.conn.Write([]byte(msg))
|
||||
return fmt.Errorf("authentication failed from %s", s.addr())
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *session) sendServerInit() error {
|
||||
name := []byte("NetBird VNC")
|
||||
buf := make([]byte, 0, 4+16+4+len(name))
|
||||
|
||||
// Framebuffer width and height.
|
||||
buf = append(buf, byte(s.serverW>>8), byte(s.serverW))
|
||||
buf = append(buf, byte(s.serverH>>8), byte(s.serverH))
|
||||
|
||||
// Server pixel format.
|
||||
buf = append(buf, serverPixelFormat[:]...)
|
||||
|
||||
// Desktop name.
|
||||
buf = append(buf,
|
||||
byte(len(name)>>24), byte(len(name)>>16),
|
||||
byte(len(name)>>8), byte(len(name)),
|
||||
)
|
||||
buf = append(buf, name...)
|
||||
|
||||
_, err := s.conn.Write(buf)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *session) messageLoop() error {
|
||||
for {
|
||||
var msgType [1]byte
|
||||
if err := s.conn.SetDeadline(time.Now().Add(readDeadline)); err != nil {
|
||||
return fmt.Errorf("set deadline: %w", err)
|
||||
}
|
||||
if _, err := io.ReadFull(s.conn, msgType[:]); err != nil {
|
||||
return err
|
||||
}
|
||||
_ = s.conn.SetDeadline(time.Time{})
|
||||
|
||||
switch msgType[0] {
|
||||
case clientSetPixelFormat:
|
||||
if err := s.handleSetPixelFormat(); err != nil {
|
||||
return err
|
||||
}
|
||||
case clientSetEncodings:
|
||||
if err := s.handleSetEncodings(); err != nil {
|
||||
return err
|
||||
}
|
||||
case clientFramebufferUpdateRequest:
|
||||
if err := s.handleFBUpdateRequest(); err != nil {
|
||||
return err
|
||||
}
|
||||
case clientKeyEvent:
|
||||
if err := s.handleKeyEvent(); err != nil {
|
||||
return err
|
||||
}
|
||||
case clientPointerEvent:
|
||||
if err := s.handlePointerEvent(); err != nil {
|
||||
return err
|
||||
}
|
||||
case clientCutText:
|
||||
if err := s.handleCutText(); err != nil {
|
||||
return err
|
||||
}
|
||||
default:
|
||||
return fmt.Errorf("unknown client message type: %d", msgType[0])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *session) handleSetPixelFormat() error {
|
||||
var buf [19]byte // 3 padding + 16 pixel format
|
||||
if _, err := io.ReadFull(s.conn, buf[:]); err != nil {
|
||||
return fmt.Errorf("read SetPixelFormat: %w", err)
|
||||
}
|
||||
s.pf = parsePixelFormat(buf[3:19])
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *session) handleSetEncodings() error {
|
||||
var header [3]byte // 1 padding + 2 number-of-encodings
|
||||
if _, err := io.ReadFull(s.conn, header[:]); err != nil {
|
||||
return fmt.Errorf("read SetEncodings header: %w", err)
|
||||
}
|
||||
numEnc := binary.BigEndian.Uint16(header[1:3])
|
||||
buf := make([]byte, int(numEnc)*4)
|
||||
if _, err := io.ReadFull(s.conn, buf); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Check if client supports zlib encoding.
|
||||
for i := range int(numEnc) {
|
||||
enc := int32(binary.BigEndian.Uint32(buf[i*4 : i*4+4]))
|
||||
if enc == encZlib {
|
||||
s.useZlib = true
|
||||
if s.zlib == nil {
|
||||
s.zlib = newZlibState()
|
||||
}
|
||||
s.log.Debugf("client supports zlib encoding")
|
||||
break
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *session) handleFBUpdateRequest() error {
|
||||
var req [9]byte
|
||||
if _, err := io.ReadFull(s.conn, req[:]); err != nil {
|
||||
return fmt.Errorf("read FBUpdateRequest: %w", err)
|
||||
}
|
||||
incremental := req[0]
|
||||
|
||||
img, err := s.capturer.Capture()
|
||||
if err != nil {
|
||||
return fmt.Errorf("capture screen: %w", err)
|
||||
}
|
||||
|
||||
if s.recorder != nil {
|
||||
s.recorder.writeFrame(img)
|
||||
}
|
||||
|
||||
if incremental == 1 && s.prevFrame != nil {
|
||||
rects := diffRects(s.prevFrame, img, s.serverW, s.serverH, tileSize)
|
||||
if len(rects) == 0 {
|
||||
// Nothing changed. Back off briefly before responding to reduce
|
||||
// CPU usage when the screen is static. The client re-requests
|
||||
// immediately after receiving our empty response, so without
|
||||
// this delay we'd spin at ~1000fps checking for changes.
|
||||
s.idleFrames++
|
||||
delay := min(s.idleFrames*5, 100) // 5ms → 100ms adaptive backoff
|
||||
time.Sleep(time.Duration(delay) * time.Millisecond)
|
||||
s.savePrevFrame(img)
|
||||
return s.sendEmptyUpdate()
|
||||
}
|
||||
s.idleFrames = 0
|
||||
s.savePrevFrame(img)
|
||||
return s.sendDirtyRects(img, rects)
|
||||
}
|
||||
|
||||
// Full update.
|
||||
s.idleFrames = 0
|
||||
s.savePrevFrame(img)
|
||||
return s.sendFullUpdate(img)
|
||||
}
|
||||
|
||||
// savePrevFrame copies img's pixel data into prevFrame. This is necessary
|
||||
// because some capturers (DXGI) reuse the same image buffer across calls,
|
||||
// so a simple pointer assignment would make prevFrame alias the live buffer
|
||||
// and diffRects would always see zero changes.
|
||||
func (s *session) savePrevFrame(img *image.RGBA) {
|
||||
if s.prevFrame == nil || s.prevFrame.Rect != img.Rect {
|
||||
s.prevFrame = image.NewRGBA(img.Rect)
|
||||
}
|
||||
copy(s.prevFrame.Pix, img.Pix)
|
||||
}
|
||||
|
||||
// sendEmptyUpdate sends a FramebufferUpdate with zero rectangles.
|
||||
func (s *session) sendEmptyUpdate() error {
|
||||
var buf [4]byte
|
||||
buf[0] = serverFramebufferUpdate
|
||||
s.writeMu.Lock()
|
||||
_, err := s.conn.Write(buf[:])
|
||||
s.writeMu.Unlock()
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *session) sendFullUpdate(img *image.RGBA) error {
|
||||
w, h := s.serverW, s.serverH
|
||||
|
||||
var buf []byte
|
||||
if s.useZlib && s.zlib != nil {
|
||||
buf = encodeZlibRect(img, s.pf, 0, 0, w, h, s.zlib.w, s.zlib.buf)
|
||||
} else {
|
||||
buf = encodeRawRect(img, s.pf, 0, 0, w, h)
|
||||
}
|
||||
|
||||
s.writeMu.Lock()
|
||||
_, err := s.conn.Write(buf)
|
||||
s.writeMu.Unlock()
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *session) sendDirtyRects(img *image.RGBA, rects [][4]int) error {
|
||||
// Build a multi-rectangle FramebufferUpdate.
|
||||
// Header: type(1) + padding(1) + numRects(2)
|
||||
header := make([]byte, 4)
|
||||
header[0] = serverFramebufferUpdate
|
||||
binary.BigEndian.PutUint16(header[2:4], uint16(len(rects)))
|
||||
|
||||
s.writeMu.Lock()
|
||||
defer s.writeMu.Unlock()
|
||||
|
||||
if _, err := s.conn.Write(header); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, r := range rects {
|
||||
x, y, w, h := r[0], r[1], r[2], r[3]
|
||||
|
||||
var rectBuf []byte
|
||||
if s.useZlib && s.zlib != nil {
|
||||
rectBuf = encodeZlibRect(img, s.pf, x, y, w, h, s.zlib.w, s.zlib.buf)
|
||||
// encodeZlibRect includes its own FBUpdate header for 1 rect.
|
||||
// For multi-rect, we need just the rect data without the FBUpdate header.
|
||||
// Skip the 4-byte FBUpdate header since we already sent ours.
|
||||
rectBuf = rectBuf[4:]
|
||||
} else {
|
||||
rectBuf = encodeRawRect(img, s.pf, x, y, w, h)
|
||||
rectBuf = rectBuf[4:] // skip FBUpdate header
|
||||
}
|
||||
|
||||
if _, err := s.conn.Write(rectBuf); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *session) handleKeyEvent() error {
|
||||
var data [7]byte
|
||||
if _, err := io.ReadFull(s.conn, data[:]); err != nil {
|
||||
return fmt.Errorf("read KeyEvent: %w", err)
|
||||
}
|
||||
down := data[0] == 1
|
||||
keysym := binary.BigEndian.Uint32(data[3:7])
|
||||
s.injector.InjectKey(keysym, down)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *session) handlePointerEvent() error {
|
||||
var data [5]byte
|
||||
if _, err := io.ReadFull(s.conn, data[:]); err != nil {
|
||||
return fmt.Errorf("read PointerEvent: %w", err)
|
||||
}
|
||||
buttonMask := data[0]
|
||||
x := int(binary.BigEndian.Uint16(data[1:3]))
|
||||
y := int(binary.BigEndian.Uint16(data[3:5]))
|
||||
s.injector.InjectPointer(buttonMask, x, y, s.serverW, s.serverH)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *session) handleCutText() error {
|
||||
var header [7]byte // 3 padding + 4 length
|
||||
if _, err := io.ReadFull(s.conn, header[:]); err != nil {
|
||||
return fmt.Errorf("read CutText header: %w", err)
|
||||
}
|
||||
length := binary.BigEndian.Uint32(header[3:7])
|
||||
if length > maxCutTextBytes {
|
||||
return fmt.Errorf("cut text too large: %d bytes", length)
|
||||
}
|
||||
buf := make([]byte, length)
|
||||
if _, err := io.ReadFull(s.conn, buf); err != nil {
|
||||
return fmt.Errorf("read CutText payload: %w", err)
|
||||
}
|
||||
s.injector.SetClipboard(string(buf))
|
||||
return nil
|
||||
}
|
||||
|
||||
// sendServerCutText sends clipboard text from the server to the client.
|
||||
func (s *session) sendServerCutText(text string) error {
|
||||
data := []byte(text)
|
||||
buf := make([]byte, 8+len(data))
|
||||
buf[0] = serverCutText
|
||||
// buf[1:4] = padding (zero)
|
||||
binary.BigEndian.PutUint32(buf[4:8], uint32(len(data)))
|
||||
copy(buf[8:], data)
|
||||
|
||||
s.writeMu.Lock()
|
||||
_, err := s.conn.Write(buf)
|
||||
s.writeMu.Unlock()
|
||||
return err
|
||||
}
|
||||
@@ -1,79 +0,0 @@
|
||||
//go:build !windows
|
||||
|
||||
package server
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
"syscall"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// ShutdownState tracks VNC virtual session processes for crash recovery.
|
||||
// Persisted by the state manager; on restart, residual processes are killed.
|
||||
type ShutdownState struct {
|
||||
// Processes maps a description to its PID (e.g., "xvfb:50" -> 1234).
|
||||
Processes map[string]int `json:"processes,omitempty"`
|
||||
}
|
||||
|
||||
// Name returns the state name for the state manager.
|
||||
func (s *ShutdownState) Name() string {
|
||||
return "vnc_sessions_state"
|
||||
}
|
||||
|
||||
// Cleanup kills any residual VNC session processes left from a crash.
|
||||
func (s *ShutdownState) Cleanup() error {
|
||||
if len(s.Processes) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
for desc, pid := range s.Processes {
|
||||
if pid <= 0 {
|
||||
continue
|
||||
}
|
||||
if !isOurProcess(pid, desc) {
|
||||
log.Debugf("cleanup:skipping PID %d (%s), not ours", pid, desc)
|
||||
continue
|
||||
}
|
||||
log.Infof("cleanup:killing residual process %d (%s)", pid, desc)
|
||||
// Kill the process group (negative PID) to get children too.
|
||||
if err := syscall.Kill(-pid, syscall.SIGTERM); err != nil {
|
||||
// Try individual process if group kill fails.
|
||||
syscall.Kill(pid, syscall.SIGKILL)
|
||||
}
|
||||
}
|
||||
|
||||
s.Processes = nil
|
||||
return nil
|
||||
}
|
||||
|
||||
// isOurProcess verifies the PID still belongs to a VNC-related process
|
||||
// by checking /proc/<pid>/cmdline (Linux) or the process name.
|
||||
func isOurProcess(pid int, desc string) bool {
|
||||
// Check if the process exists at all.
|
||||
if err := syscall.Kill(pid, 0); err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
// On Linux, verify via /proc cmdline.
|
||||
cmdline, err := os.ReadFile(fmt.Sprintf("/proc/%d/cmdline", pid))
|
||||
if err != nil {
|
||||
// No /proc (FreeBSD): trust the PID if the process exists.
|
||||
// PID reuse is unlikely in the short window between crash and restart.
|
||||
return true
|
||||
}
|
||||
|
||||
cmd := string(cmdline)
|
||||
// Match against expected process types.
|
||||
if strings.Contains(desc, "xvfb") || strings.Contains(desc, "xorg") {
|
||||
return strings.Contains(cmd, "Xvfb") || strings.Contains(cmd, "Xorg")
|
||||
}
|
||||
if strings.Contains(desc, "desktop") {
|
||||
return strings.Contains(cmd, "session") || strings.Contains(cmd, "plasma") ||
|
||||
strings.Contains(cmd, "gnome") || strings.Contains(cmd, "xfce") ||
|
||||
strings.Contains(cmd, "dbus-launch")
|
||||
}
|
||||
return false
|
||||
}
|
||||
@@ -1,37 +0,0 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"image"
|
||||
)
|
||||
|
||||
const maxCapturerRetries = 5
|
||||
|
||||
// StubCapturer is a placeholder for platforms without screen capture support.
|
||||
type StubCapturer struct{}
|
||||
|
||||
// Width returns 0 on unsupported platforms.
|
||||
func (c *StubCapturer) Width() int { return 0 }
|
||||
|
||||
// Height returns 0 on unsupported platforms.
|
||||
func (c *StubCapturer) Height() int { return 0 }
|
||||
|
||||
// Capture returns an error on unsupported platforms.
|
||||
func (c *StubCapturer) Capture() (*image.RGBA, error) {
|
||||
return nil, fmt.Errorf("screen capture not supported on this platform")
|
||||
}
|
||||
|
||||
// StubInputInjector is a placeholder for platforms without input injection support.
|
||||
type StubInputInjector struct{}
|
||||
|
||||
// InjectKey is a no-op on unsupported platforms.
|
||||
func (s *StubInputInjector) InjectKey(_ uint32, _ bool) {}
|
||||
|
||||
// InjectPointer is a no-op on unsupported platforms.
|
||||
func (s *StubInputInjector) InjectPointer(_ uint8, _, _, _, _ int) {}
|
||||
|
||||
// SetClipboard is a no-op on unsupported platforms.
|
||||
func (s *StubInputInjector) SetClipboard(_ string) {}
|
||||
|
||||
// GetClipboard returns empty on unsupported platforms.
|
||||
func (s *StubInputInjector) GetClipboard() string { return "" }
|
||||
@@ -1,19 +0,0 @@
|
||||
//go:build windows
|
||||
|
||||
package server
|
||||
|
||||
import "unsafe"
|
||||
|
||||
// swizzleBGRAtoRGBA swaps B and R channels in a BGRA pixel buffer in-place.
|
||||
// Operates on uint32 words for throughput: one read-modify-write per pixel.
|
||||
func swizzleBGRAtoRGBA(pix []byte) {
|
||||
n := len(pix) / 4
|
||||
pixels := unsafe.Slice((*uint32)(unsafe.Pointer(&pix[0])), n)
|
||||
for i := range n {
|
||||
p := pixels[i]
|
||||
// p = 0xAABBGGRR (little-endian BGRA in memory: B,G,R,A bytes)
|
||||
// We want 0xAABBGGRR -> 0xAARRGGBB (RGBA in memory: R,G,B,A bytes)
|
||||
// Swap byte 0 (B) and byte 2 (R), keep byte 1 (G) and byte 3 (A).
|
||||
pixels[i] = (p & 0xFF00FF00) | ((p & 0x00FF0000) >> 16) | ((p & 0x000000FF) << 16)
|
||||
}
|
||||
}
|
||||
@@ -1,634 +0,0 @@
|
||||
//go:build (linux && !android) || freebsd
|
||||
|
||||
package server
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"os/user"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// VirtualSession manages a virtual X11 display (Xvfb) with a desktop session
|
||||
// running as a target user. It implements ScreenCapturer and InputInjector by
|
||||
// delegating to an X11Capturer/X11InputInjector pointed at the virtual display.
|
||||
const sessionIdleTimeout = 5 * time.Minute
|
||||
|
||||
type VirtualSession struct {
|
||||
mu sync.Mutex
|
||||
display string
|
||||
user *user.User
|
||||
uid uint32
|
||||
gid uint32
|
||||
groups []uint32
|
||||
xvfb *exec.Cmd
|
||||
desktop *exec.Cmd
|
||||
poller *X11Poller
|
||||
injector *X11InputInjector
|
||||
log *log.Entry
|
||||
stopped bool
|
||||
clients int
|
||||
idleTimer *time.Timer
|
||||
onIdle func() // called when idle timeout fires or Xvfb dies
|
||||
}
|
||||
|
||||
// StartVirtualSession creates and starts a virtual X11 session for the given user.
|
||||
// Requires root privileges to create sessions as other users.
|
||||
func StartVirtualSession(username string, logger *log.Entry) (*VirtualSession, error) {
|
||||
if os.Getuid() != 0 {
|
||||
return nil, fmt.Errorf("virtual sessions require root privileges")
|
||||
}
|
||||
|
||||
if _, err := exec.LookPath("Xvfb"); err != nil {
|
||||
if _, err := exec.LookPath("Xorg"); err != nil {
|
||||
return nil, fmt.Errorf("neither Xvfb nor Xorg found (install xvfb or xserver-xorg)")
|
||||
}
|
||||
if !hasDummyDriver() {
|
||||
return nil, fmt.Errorf("Xvfb not found and Xorg dummy driver not installed (install xvfb or xf86-video-dummy)")
|
||||
}
|
||||
}
|
||||
|
||||
u, err := user.Lookup(username)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("lookup user %s: %w", username, err)
|
||||
}
|
||||
|
||||
uid, err := strconv.ParseUint(u.Uid, 10, 32)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parse uid: %w", err)
|
||||
}
|
||||
gid, err := strconv.ParseUint(u.Gid, 10, 32)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parse gid: %w", err)
|
||||
}
|
||||
|
||||
groups, err := supplementaryGroups(u)
|
||||
if err != nil {
|
||||
logger.Debugf("supplementary groups for %s: %v", username, err)
|
||||
}
|
||||
|
||||
vs := &VirtualSession{
|
||||
user: u,
|
||||
uid: uint32(uid),
|
||||
gid: uint32(gid),
|
||||
groups: groups,
|
||||
log: logger.WithField("vnc_user", username),
|
||||
}
|
||||
|
||||
if err := vs.start(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return vs, nil
|
||||
}
|
||||
|
||||
func (vs *VirtualSession) start() error {
|
||||
display, err := findFreeDisplay()
|
||||
if err != nil {
|
||||
return fmt.Errorf("find free display: %w", err)
|
||||
}
|
||||
vs.display = display
|
||||
|
||||
if err := vs.startXvfb(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
socketPath := fmt.Sprintf("/tmp/.X11-unix/X%s", vs.display[1:])
|
||||
if err := waitForPath(socketPath, 5*time.Second); err != nil {
|
||||
vs.stopXvfb()
|
||||
return fmt.Errorf("wait for X11 socket %s: %w", socketPath, err)
|
||||
}
|
||||
|
||||
// Grant the target user access to the display via xhost.
|
||||
xhostCmd := exec.Command("xhost", "+SI:localuser:"+vs.user.Username)
|
||||
xhostCmd.Env = []string{"DISPLAY=" + vs.display}
|
||||
if out, err := xhostCmd.CombinedOutput(); err != nil {
|
||||
vs.log.Debugf("xhost: %s (%v)", strings.TrimSpace(string(out)), err)
|
||||
}
|
||||
|
||||
vs.poller = NewX11Poller(vs.display)
|
||||
|
||||
injector, err := NewX11InputInjector(vs.display)
|
||||
if err != nil {
|
||||
vs.stopXvfb()
|
||||
return fmt.Errorf("create X11 injector for %s: %w", vs.display, err)
|
||||
}
|
||||
vs.injector = injector
|
||||
|
||||
if err := vs.startDesktop(); err != nil {
|
||||
vs.injector.Close()
|
||||
vs.stopXvfb()
|
||||
return fmt.Errorf("start desktop: %w", err)
|
||||
}
|
||||
|
||||
vs.log.Infof("virtual session started: display=%s user=%s", vs.display, vs.user.Username)
|
||||
return nil
|
||||
}
|
||||
|
||||
// ClientConnect increments the client count and cancels any idle timer.
|
||||
func (vs *VirtualSession) ClientConnect() {
|
||||
vs.mu.Lock()
|
||||
defer vs.mu.Unlock()
|
||||
vs.clients++
|
||||
if vs.idleTimer != nil {
|
||||
vs.idleTimer.Stop()
|
||||
vs.idleTimer = nil
|
||||
}
|
||||
}
|
||||
|
||||
// ClientDisconnect decrements the client count. When the last client
|
||||
// disconnects, starts an idle timer that destroys the session.
|
||||
func (vs *VirtualSession) ClientDisconnect() {
|
||||
vs.mu.Lock()
|
||||
defer vs.mu.Unlock()
|
||||
vs.clients--
|
||||
if vs.clients <= 0 {
|
||||
vs.clients = 0
|
||||
vs.log.Infof("no VNC clients connected, session will be destroyed in %s", sessionIdleTimeout)
|
||||
vs.idleTimer = time.AfterFunc(sessionIdleTimeout, vs.idleExpired)
|
||||
}
|
||||
}
|
||||
|
||||
// idleExpired is called by the idle timer. It stops the session and
|
||||
// notifies the session manager via onIdle so it removes us from the map.
|
||||
func (vs *VirtualSession) idleExpired() {
|
||||
vs.log.Info("idle timeout reached, destroying virtual session")
|
||||
vs.Stop()
|
||||
// onIdle acquires sessionManager.mu; safe because Stop() has released vs.mu.
|
||||
if vs.onIdle != nil {
|
||||
vs.onIdle()
|
||||
}
|
||||
}
|
||||
|
||||
// isAlive returns true if the session is running and its X server socket exists.
|
||||
func (vs *VirtualSession) isAlive() bool {
|
||||
vs.mu.Lock()
|
||||
stopped := vs.stopped
|
||||
display := vs.display
|
||||
vs.mu.Unlock()
|
||||
|
||||
if stopped {
|
||||
return false
|
||||
}
|
||||
// Verify the X socket still exists on disk.
|
||||
socketPath := fmt.Sprintf("/tmp/.X11-unix/X%s", display[1:])
|
||||
if _, err := os.Stat(socketPath); err != nil {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// Capturer returns the screen capturer for this virtual session.
|
||||
func (vs *VirtualSession) Capturer() ScreenCapturer {
|
||||
return vs.poller
|
||||
}
|
||||
|
||||
// Injector returns the input injector for this virtual session.
|
||||
func (vs *VirtualSession) Injector() InputInjector {
|
||||
return vs.injector
|
||||
}
|
||||
|
||||
// Display returns the X11 display string (e.g., ":99").
|
||||
func (vs *VirtualSession) Display() string {
|
||||
return vs.display
|
||||
}
|
||||
|
||||
// Stop terminates the virtual session, killing the desktop and Xvfb.
|
||||
func (vs *VirtualSession) Stop() {
|
||||
vs.mu.Lock()
|
||||
defer vs.mu.Unlock()
|
||||
|
||||
if vs.stopped {
|
||||
return
|
||||
}
|
||||
vs.stopped = true
|
||||
|
||||
if vs.injector != nil {
|
||||
vs.injector.Close()
|
||||
}
|
||||
|
||||
vs.stopDesktop()
|
||||
vs.stopXvfb()
|
||||
|
||||
vs.log.Info("virtual session stopped")
|
||||
}
|
||||
|
||||
func (vs *VirtualSession) startXvfb() error {
|
||||
if _, err := exec.LookPath("Xvfb"); err == nil {
|
||||
return vs.startXvfbDirect()
|
||||
}
|
||||
return vs.startXorgDummy()
|
||||
}
|
||||
|
||||
func (vs *VirtualSession) startXvfbDirect() error {
|
||||
vs.xvfb = exec.Command("Xvfb", vs.display,
|
||||
"-screen", "0", "1280x800x24",
|
||||
"-ac",
|
||||
"-nolisten", "tcp",
|
||||
)
|
||||
vs.xvfb.SysProcAttr = &syscall.SysProcAttr{Setsid: true, Pdeathsig: syscall.SIGTERM}
|
||||
|
||||
if err := vs.xvfb.Start(); err != nil {
|
||||
return fmt.Errorf("start Xvfb on %s: %w", vs.display, err)
|
||||
}
|
||||
vs.log.Infof("Xvfb started on %s (pid=%d)", vs.display, vs.xvfb.Process.Pid)
|
||||
|
||||
go vs.monitorXvfb()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// startXorgDummy starts Xorg with the dummy video driver as a fallback when
|
||||
// Xvfb is not installed. Most systems with a desktop have Xorg available.
|
||||
func (vs *VirtualSession) startXorgDummy() error {
|
||||
confPath := fmt.Sprintf("/tmp/nbvnc-dummy-%s.conf", vs.display[1:])
|
||||
conf := `Section "Device"
|
||||
Identifier "dummy"
|
||||
Driver "dummy"
|
||||
VideoRam 256000
|
||||
EndSection
|
||||
Section "Screen"
|
||||
Identifier "screen"
|
||||
Device "dummy"
|
||||
DefaultDepth 24
|
||||
SubSection "Display"
|
||||
Depth 24
|
||||
Modes "1280x800"
|
||||
EndSubSection
|
||||
EndSection
|
||||
`
|
||||
if err := os.WriteFile(confPath, []byte(conf), 0644); err != nil {
|
||||
return fmt.Errorf("write Xorg dummy config: %w", err)
|
||||
}
|
||||
|
||||
vs.xvfb = exec.Command("Xorg", vs.display,
|
||||
"-config", confPath,
|
||||
"-noreset",
|
||||
"-nolisten", "tcp",
|
||||
"-ac",
|
||||
)
|
||||
vs.xvfb.SysProcAttr = &syscall.SysProcAttr{Setsid: true, Pdeathsig: syscall.SIGTERM}
|
||||
|
||||
if err := vs.xvfb.Start(); err != nil {
|
||||
os.Remove(confPath)
|
||||
return fmt.Errorf("start Xorg dummy on %s: %w", vs.display, err)
|
||||
}
|
||||
vs.log.Infof("Xorg (dummy driver) started on %s (pid=%d)", vs.display, vs.xvfb.Process.Pid)
|
||||
|
||||
go func() {
|
||||
vs.monitorXvfb()
|
||||
os.Remove(confPath)
|
||||
}()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// monitorXvfb waits for the Xvfb/Xorg process to exit. If it exits
|
||||
// unexpectedly (not via Stop), the session is marked as dead and the
|
||||
// onIdle callback fires so the session manager removes it from the map.
|
||||
// The next GetOrCreate call for this user will create a fresh session.
|
||||
func (vs *VirtualSession) monitorXvfb() {
|
||||
if err := vs.xvfb.Wait(); err != nil {
|
||||
vs.log.Debugf("X server exited: %v", err)
|
||||
}
|
||||
|
||||
vs.mu.Lock()
|
||||
alreadyStopped := vs.stopped
|
||||
if !alreadyStopped {
|
||||
vs.log.Warn("X server exited unexpectedly, marking session as dead")
|
||||
vs.stopped = true
|
||||
if vs.idleTimer != nil {
|
||||
vs.idleTimer.Stop()
|
||||
vs.idleTimer = nil
|
||||
}
|
||||
if vs.injector != nil {
|
||||
vs.injector.Close()
|
||||
}
|
||||
vs.stopDesktop()
|
||||
}
|
||||
onIdle := vs.onIdle
|
||||
vs.mu.Unlock()
|
||||
|
||||
if !alreadyStopped && onIdle != nil {
|
||||
onIdle()
|
||||
}
|
||||
}
|
||||
|
||||
func (vs *VirtualSession) stopXvfb() {
|
||||
if vs.xvfb != nil && vs.xvfb.Process != nil {
|
||||
syscall.Kill(-vs.xvfb.Process.Pid, syscall.SIGTERM)
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
syscall.Kill(-vs.xvfb.Process.Pid, syscall.SIGKILL)
|
||||
}
|
||||
}
|
||||
|
||||
func (vs *VirtualSession) startDesktop() error {
|
||||
session := detectDesktopSession()
|
||||
|
||||
// Wrap the desktop command with dbus-launch to provide a session bus.
|
||||
// Without this, most desktop environments (XFCE, MATE, etc.) fail immediately.
|
||||
var args []string
|
||||
if _, err := exec.LookPath("dbus-launch"); err == nil {
|
||||
args = append([]string{"dbus-launch", "--exit-with-session"}, session...)
|
||||
} else {
|
||||
args = session
|
||||
}
|
||||
|
||||
vs.desktop = exec.Command(args[0], args[1:]...)
|
||||
vs.desktop.Dir = vs.user.HomeDir
|
||||
vs.desktop.Env = vs.buildUserEnv()
|
||||
vs.desktop.SysProcAttr = &syscall.SysProcAttr{
|
||||
Credential: &syscall.Credential{
|
||||
Uid: vs.uid,
|
||||
Gid: vs.gid,
|
||||
Groups: vs.groups,
|
||||
},
|
||||
Setsid: true,
|
||||
Pdeathsig: syscall.SIGTERM,
|
||||
}
|
||||
|
||||
if err := vs.desktop.Start(); err != nil {
|
||||
return fmt.Errorf("start desktop session (%v): %w", args, err)
|
||||
}
|
||||
vs.log.Infof("desktop session started: %v (pid=%d)", args, vs.desktop.Process.Pid)
|
||||
|
||||
go func() {
|
||||
if err := vs.desktop.Wait(); err != nil {
|
||||
vs.log.Debugf("desktop session exited: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (vs *VirtualSession) stopDesktop() {
|
||||
if vs.desktop != nil && vs.desktop.Process != nil {
|
||||
syscall.Kill(-vs.desktop.Process.Pid, syscall.SIGTERM)
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
syscall.Kill(-vs.desktop.Process.Pid, syscall.SIGKILL)
|
||||
}
|
||||
}
|
||||
|
||||
func (vs *VirtualSession) buildUserEnv() []string {
|
||||
return []string{
|
||||
"DISPLAY=" + vs.display,
|
||||
"HOME=" + vs.user.HomeDir,
|
||||
"USER=" + vs.user.Username,
|
||||
"LOGNAME=" + vs.user.Username,
|
||||
"SHELL=" + getUserShell(vs.user.Uid),
|
||||
"PATH=/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin",
|
||||
"XDG_RUNTIME_DIR=/run/user/" + vs.user.Uid,
|
||||
"DBUS_SESSION_BUS_ADDRESS=unix:path=/run/user/" + vs.user.Uid + "/bus",
|
||||
}
|
||||
}
|
||||
|
||||
// detectDesktopSession discovers available desktop sessions from the standard
|
||||
// /usr/share/xsessions/*.desktop files (FreeDesktop standard, used by all
|
||||
// display managers). Falls back to a hardcoded list if no .desktop files found.
|
||||
func detectDesktopSession() []string {
|
||||
// Scan xsessions directories (Linux: /usr/share, FreeBSD: /usr/local/share).
|
||||
for _, dir := range []string{"/usr/share/xsessions", "/usr/local/share/xsessions"} {
|
||||
if cmd := findXSession(dir); cmd != nil {
|
||||
return cmd
|
||||
}
|
||||
}
|
||||
|
||||
// Fallback: try common session commands directly.
|
||||
fallbacks := [][]string{
|
||||
{"startplasma-x11"},
|
||||
{"gnome-session"},
|
||||
{"xfce4-session"},
|
||||
{"mate-session"},
|
||||
{"cinnamon-session"},
|
||||
{"openbox-session"},
|
||||
{"xterm"},
|
||||
}
|
||||
for _, s := range fallbacks {
|
||||
if _, err := exec.LookPath(s[0]); err == nil {
|
||||
return s
|
||||
}
|
||||
}
|
||||
return []string{"xterm"}
|
||||
}
|
||||
|
||||
// sessionPriority defines preference order for desktop environments.
|
||||
// Lower number = higher priority. Unknown sessions get 100.
|
||||
var sessionPriority = map[string]int{
|
||||
"plasma": 1, // KDE
|
||||
"gnome": 2,
|
||||
"xfce": 3,
|
||||
"mate": 4,
|
||||
"cinnamon": 5,
|
||||
"lxqt": 6,
|
||||
"lxde": 7,
|
||||
"budgie": 8,
|
||||
"openbox": 20,
|
||||
"fluxbox": 21,
|
||||
"i3": 22,
|
||||
"xinit": 50, // generic user session
|
||||
"lightdm": 50,
|
||||
"default": 50,
|
||||
}
|
||||
|
||||
func findXSession(dir string) []string {
|
||||
entries, err := os.ReadDir(dir)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
type candidate struct {
|
||||
cmd string
|
||||
priority int
|
||||
}
|
||||
var candidates []candidate
|
||||
|
||||
for _, e := range entries {
|
||||
if !strings.HasSuffix(e.Name(), ".desktop") {
|
||||
continue
|
||||
}
|
||||
data, err := os.ReadFile(filepath.Join(dir, e.Name()))
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
execCmd := ""
|
||||
for _, line := range strings.Split(string(data), "\n") {
|
||||
if strings.HasPrefix(line, "Exec=") {
|
||||
execCmd = strings.TrimSpace(strings.TrimPrefix(line, "Exec="))
|
||||
break
|
||||
}
|
||||
}
|
||||
if execCmd == "" || execCmd == "default" {
|
||||
continue
|
||||
}
|
||||
|
||||
// Determine priority from the filename or exec command.
|
||||
pri := 100
|
||||
lower := strings.ToLower(e.Name() + " " + execCmd)
|
||||
for keyword, p := range sessionPriority {
|
||||
if strings.Contains(lower, keyword) && p < pri {
|
||||
pri = p
|
||||
}
|
||||
}
|
||||
candidates = append(candidates, candidate{cmd: execCmd, priority: pri})
|
||||
}
|
||||
|
||||
if len(candidates) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Pick the highest priority (lowest number).
|
||||
best := candidates[0]
|
||||
for _, c := range candidates[1:] {
|
||||
if c.priority < best.priority {
|
||||
best = c
|
||||
}
|
||||
}
|
||||
|
||||
// Verify the binary exists.
|
||||
parts := strings.Fields(best.cmd)
|
||||
if _, err := exec.LookPath(parts[0]); err != nil {
|
||||
return nil
|
||||
}
|
||||
return parts
|
||||
}
|
||||
|
||||
// findFreeDisplay scans for an unused X11 display number.
|
||||
func findFreeDisplay() (string, error) {
|
||||
for n := 50; n < 200; n++ {
|
||||
lockFile := fmt.Sprintf("/tmp/.X%d-lock", n)
|
||||
socketFile := fmt.Sprintf("/tmp/.X11-unix/X%d", n)
|
||||
if _, err := os.Stat(lockFile); err == nil {
|
||||
continue
|
||||
}
|
||||
if _, err := os.Stat(socketFile); err == nil {
|
||||
continue
|
||||
}
|
||||
return fmt.Sprintf(":%d", n), nil
|
||||
}
|
||||
return "", fmt.Errorf("no free X11 display found (checked :50-:199)")
|
||||
}
|
||||
|
||||
// waitForPath polls until a filesystem path exists or the timeout expires.
|
||||
func waitForPath(path string, timeout time.Duration) error {
|
||||
deadline := time.Now().Add(timeout)
|
||||
for time.Now().Before(deadline) {
|
||||
if _, err := os.Stat(path); err == nil {
|
||||
return nil
|
||||
}
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
}
|
||||
return fmt.Errorf("timeout waiting for %s", path)
|
||||
}
|
||||
|
||||
// getUserShell returns the login shell for the given UID.
|
||||
func getUserShell(uid string) string {
|
||||
data, err := os.ReadFile("/etc/passwd")
|
||||
if err != nil {
|
||||
return "/bin/sh"
|
||||
}
|
||||
for _, line := range strings.Split(string(data), "\n") {
|
||||
fields := strings.Split(line, ":")
|
||||
if len(fields) >= 7 && fields[2] == uid {
|
||||
return fields[6]
|
||||
}
|
||||
}
|
||||
return "/bin/sh"
|
||||
}
|
||||
|
||||
// supplementaryGroups returns the supplementary group IDs for a user.
|
||||
func supplementaryGroups(u *user.User) ([]uint32, error) {
|
||||
gids, err := u.GroupIds()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var groups []uint32
|
||||
for _, g := range gids {
|
||||
id, err := strconv.ParseUint(g, 10, 32)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
groups = append(groups, uint32(id))
|
||||
}
|
||||
return groups, nil
|
||||
}
|
||||
|
||||
// sessionManager tracks active virtual sessions by username.
|
||||
type sessionManager struct {
|
||||
mu sync.Mutex
|
||||
sessions map[string]*VirtualSession
|
||||
log *log.Entry
|
||||
}
|
||||
|
||||
func newSessionManager(logger *log.Entry) *sessionManager {
|
||||
return &sessionManager{
|
||||
sessions: make(map[string]*VirtualSession),
|
||||
log: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// GetOrCreate returns an existing virtual session or creates a new one.
|
||||
// If a previous session for this user is stopped or its X server died, it is replaced.
|
||||
func (sm *sessionManager) GetOrCreate(username string) (vncSession, error) {
|
||||
sm.mu.Lock()
|
||||
defer sm.mu.Unlock()
|
||||
|
||||
if vs, ok := sm.sessions[username]; ok {
|
||||
if vs.isAlive() {
|
||||
return vs, nil
|
||||
}
|
||||
sm.log.Infof("replacing dead virtual session for %s", username)
|
||||
vs.Stop()
|
||||
delete(sm.sessions, username)
|
||||
}
|
||||
|
||||
vs, err := StartVirtualSession(username, sm.log)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
vs.onIdle = func() {
|
||||
sm.mu.Lock()
|
||||
defer sm.mu.Unlock()
|
||||
if cur, ok := sm.sessions[username]; ok && cur == vs {
|
||||
delete(sm.sessions, username)
|
||||
sm.log.Infof("removed idle virtual session for %s", username)
|
||||
}
|
||||
}
|
||||
sm.sessions[username] = vs
|
||||
return vs, nil
|
||||
}
|
||||
|
||||
// hasDummyDriver checks common paths for the Xorg dummy video driver.
|
||||
func hasDummyDriver() bool {
|
||||
paths := []string{
|
||||
"/usr/lib/xorg/modules/drivers/dummy_drv.so", // Debian/Ubuntu
|
||||
"/usr/lib64/xorg/modules/drivers/dummy_drv.so", // RHEL/Fedora
|
||||
"/usr/local/lib/xorg/modules/drivers/dummy_drv.so", // FreeBSD
|
||||
"/usr/lib/x86_64-linux-gnu/xorg/modules/drivers/dummy_drv.so", // Debian multiarch
|
||||
}
|
||||
for _, p := range paths {
|
||||
if _, err := os.Stat(p); err == nil {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// StopAll terminates all active virtual sessions.
|
||||
func (sm *sessionManager) StopAll() {
|
||||
sm.mu.Lock()
|
||||
defer sm.mu.Unlock()
|
||||
|
||||
for username, vs := range sm.sessions {
|
||||
vs.Stop()
|
||||
delete(sm.sessions, username)
|
||||
sm.log.Infof("stopped virtual session for %s", username)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,50 +0,0 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
_ "embed"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
)
|
||||
|
||||
//go:embed webplayer.html
|
||||
var webPlayerHTML []byte
|
||||
|
||||
// ServeWebPlayer starts a local HTTP server that serves the recording file
|
||||
// and an HTML player page. Returns the URL to open.
|
||||
func ServeWebPlayer(recPath, listenAddr string) (string, error) {
|
||||
if listenAddr == "" {
|
||||
listenAddr = "localhost:0"
|
||||
}
|
||||
|
||||
ln, err := net.Listen("tcp", listenAddr)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("listen: %w", err)
|
||||
}
|
||||
|
||||
mux := http.NewServeMux()
|
||||
|
||||
mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "text/html; charset=utf-8")
|
||||
w.Write(webPlayerHTML) //nolint:errcheck
|
||||
})
|
||||
|
||||
mux.HandleFunc("/recording.rec", func(w http.ResponseWriter, r *http.Request) {
|
||||
f, err := os.Open(recPath)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), 500)
|
||||
return
|
||||
}
|
||||
defer f.Close()
|
||||
fi, _ := f.Stat()
|
||||
w.Header().Set("Content-Type", "application/octet-stream")
|
||||
http.ServeContent(w, r, "recording.rec", fi.ModTime(), f)
|
||||
})
|
||||
|
||||
url := fmt.Sprintf("http://%s", ln.Addr())
|
||||
|
||||
go http.Serve(ln, mux) //nolint:errcheck
|
||||
|
||||
return url, nil
|
||||
}
|
||||
@@ -1,291 +0,0 @@
|
||||
<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<title>NetBird - VNC Session Recording</title>
|
||||
<style>
|
||||
* { box-sizing: border-box; margin: 0; padding: 0; }
|
||||
body { background: #0d1117; color: #e6edf3; font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', sans-serif; display: flex; flex-direction: column; height: 100vh; }
|
||||
|
||||
#header { background: #161b22; padding: 10px 20px; display: flex; align-items: center; gap: 16px; border-bottom: 1px solid #30363d; flex-wrap: wrap; }
|
||||
.logo { display: flex; align-items: center; gap: 8px; }
|
||||
.logo svg { width: 20px; height: 20px; }
|
||||
.logo-text { font-weight: 600; font-size: 14px; color: #f0f6fc; }
|
||||
.logo-badge { font-size: 11px; background: #f4722b; color: #fff; padding: 1px 7px; border-radius: 10px; font-weight: 500; }
|
||||
#rec-info { font-size: 12px; color: #8b949e; display: flex; gap: 6px; flex-wrap: wrap; }
|
||||
#rec-info span { background: #21262d; padding: 2px 8px; border-radius: 4px; }
|
||||
#rec-info .label { color: #6e7681; }
|
||||
|
||||
#controls { background: #161b22; padding: 6px 20px; display: flex; align-items: center; gap: 10px; border-bottom: 1px solid #30363d; }
|
||||
#controls button { background: #21262d; color: #e6edf3; border: 1px solid #30363d; padding: 4px 14px; border-radius: 6px; cursor: pointer; font-size: 14px; min-width: 36px; }
|
||||
#controls button:hover { background: #30363d; border-color: #8b949e; }
|
||||
#controls button.active { background: #f4722b; border-color: #f4722b; color: #fff; }
|
||||
#seek { flex: 1; cursor: pointer; accent-color: #f4722b; height: 4px; margin: 0; padding: 0; }
|
||||
#speed-select { background: #21262d; color: #e6edf3; border: 1px solid #30363d; padding: 3px 6px; border-radius: 4px; font-size: 12px; }
|
||||
#time { font-size: 12px; font-variant-numeric: tabular-nums; min-width: 90px; text-align: center; color: #8b949e; }
|
||||
#frame-info { font-size: 11px; color: #6e7681; }
|
||||
|
||||
#canvas-wrap { flex: 1; display: flex; align-items: center; justify-content: center; overflow: hidden; background: #010409; }
|
||||
canvas { max-width: 100%; max-height: 100%; }
|
||||
|
||||
#footer { background: #161b22; padding: 5px 20px; font-size: 11px; color: #484f58; border-top: 1px solid #30363d; display: flex; justify-content: space-between; }
|
||||
#status { font-size: 12px; color: #8b949e; }
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<div id="header">
|
||||
<div class="logo">
|
||||
<svg width="24" height="18" viewBox="0 0 31 23" fill="none"><path d="M21.4631 0.523438C17.8173 0.857913 16.0028 2.95675 15.3171 4.01871L4.66406 22.4734H17.5163L30.1929 0.523438H21.4631Z" fill="#F68330"/><path d="M17.5265 22.4737L0 3.88525C0 3.88525 19.8177 -1.44128 21.7493 15.1738L17.5265 22.4737Z" fill="#F68330"/><path d="M14.9236 4.70563L9.54688 14.0208L17.5158 22.4747L21.7385 15.158C21.0696 9.44682 18.2851 6.32784 14.9236 4.69727" fill="#F05252"/></svg>
|
||||
<span class="logo-text">NetBird</span>
|
||||
<span class="logo-badge">VNC Session Recording</span>
|
||||
</div>
|
||||
<div id="rec-info"></div>
|
||||
</div>
|
||||
<div id="controls">
|
||||
<button id="playBtn" onclick="togglePlay()" title="Space">▶</button>
|
||||
<input type="range" id="seek" min="0" max="1000" value="0" oninput="seekTo(this.value)">
|
||||
<span id="time">0:00 / 0:00</span>
|
||||
<select id="speed-select" onchange="setSpeed(this.value)" title="Playback speed">
|
||||
<option value="0.25">0.25x</option>
|
||||
<option value="0.5">0.5x</option>
|
||||
<option value="1" selected>1x</option>
|
||||
<option value="2">2x</option>
|
||||
<option value="4">4x</option>
|
||||
<option value="8">8x</option>
|
||||
</select>
|
||||
<span id="frame-info"></span>
|
||||
<span id="status">Loading...</span>
|
||||
</div>
|
||||
<div id="canvas-wrap"><canvas id="canvas"></canvas></div>
|
||||
<div id="footer">
|
||||
<span>Space: play/pause | Left/Right: seek 5s | Scroll: speed</span>
|
||||
<span id="file-info"></span>
|
||||
</div>
|
||||
|
||||
<script>
|
||||
const canvas = document.getElementById('canvas');
|
||||
const ctx = canvas.getContext('2d');
|
||||
const seekBar = document.getElementById('seek');
|
||||
const timeEl = document.getElementById('time');
|
||||
const statusEl = document.getElementById('status');
|
||||
const recInfoEl = document.getElementById('rec-info');
|
||||
const frameInfoEl = document.getElementById('frame-info');
|
||||
const fileInfoEl = document.getElementById('file-info');
|
||||
const playBtn = document.getElementById('playBtn');
|
||||
|
||||
let frames = []; // { offsetMs, bitmap }
|
||||
let header = null;
|
||||
let playing = false;
|
||||
let speed = 1;
|
||||
let startTime = 0;
|
||||
let pauseOffset = 0;
|
||||
let currentFrame = 0;
|
||||
let animId = null;
|
||||
let durationMs = 0;
|
||||
|
||||
function fmt(ms) {
|
||||
const s = Math.floor(ms / 1000);
|
||||
const m = Math.floor(s / 60);
|
||||
const h = Math.floor(m / 60);
|
||||
if (h > 0) return `${h}:${String(m % 60).padStart(2, '0')}:${String(s % 60).padStart(2, '0')}`;
|
||||
return `${m}:${String(s % 60).padStart(2, '0')}`;
|
||||
}
|
||||
|
||||
function fmtSize(bytes) {
|
||||
if (bytes >= 1048576) return (bytes / 1048576).toFixed(1) + ' MB';
|
||||
if (bytes >= 1024) return (bytes / 1024).toFixed(1) + ' KB';
|
||||
return bytes + ' B';
|
||||
}
|
||||
|
||||
async function load() {
|
||||
statusEl.textContent = 'Fetching...';
|
||||
const resp = await fetch('/recording.rec');
|
||||
const buf = await resp.arrayBuffer();
|
||||
const view = new DataView(buf);
|
||||
|
||||
const magic = new TextDecoder().decode(new Uint8Array(buf, 0, 6));
|
||||
if (magic !== 'NBVNC\x01') { statusEl.textContent = 'Invalid recording file'; return; }
|
||||
|
||||
const width = view.getUint16(6);
|
||||
const height = view.getUint16(8);
|
||||
const startMs = Number(view.getBigUint64(10));
|
||||
const metaLen = view.getUint32(18);
|
||||
const metaJSON = new TextDecoder().decode(new Uint8Array(buf, 22, metaLen));
|
||||
const meta = JSON.parse(metaJSON);
|
||||
|
||||
header = { width, height, startMs, meta };
|
||||
canvas.width = width;
|
||||
canvas.height = height;
|
||||
|
||||
const dateStr = new Date(startMs).toLocaleString();
|
||||
const parts = [];
|
||||
if (meta.mode) parts.push(`<span><span class="label">Type:</span> vnc (${meta.mode})</span>`);
|
||||
if (meta.remote_addr) parts.push(`<span><span class="label">Remote:</span> ${meta.remote_addr}</span>`);
|
||||
if (meta.jwt_user) parts.push(`<span><span class="label">JWT:</span> ${meta.jwt_user}</span>`);
|
||||
if (meta.user) parts.push(`<span><span class="label">User:</span> ${meta.user}</span>`);
|
||||
parts.push(`<span><span class="label">Date:</span> ${dateStr}</span>`);
|
||||
parts.push(`<span>${width}x${height}</span>`);
|
||||
recInfoEl.innerHTML = parts.join('');
|
||||
fileInfoEl.textContent = fmtSize(buf.byteLength);
|
||||
document.title = `NetBird - VNC Session Recording - ${meta.remote_addr || ''}`;
|
||||
|
||||
// Parse frame offsets (fast pass, no decoding)
|
||||
const rawFrames = [];
|
||||
let offset = 22 + metaLen;
|
||||
while (offset + 8 <= buf.byteLength) {
|
||||
const offsetMs = view.getUint32(offset);
|
||||
const pngLen = view.getUint32(offset + 4);
|
||||
offset += 8;
|
||||
if (offset + pngLen > buf.byteLength) break;
|
||||
rawFrames.push({ offsetMs, start: offset, length: pngLen });
|
||||
offset += pngLen;
|
||||
}
|
||||
|
||||
if (rawFrames.length === 0) { statusEl.textContent = 'No frames in recording'; return; }
|
||||
|
||||
// Handle encrypted recordings
|
||||
let decryptKey = null;
|
||||
if (meta.encrypted) {
|
||||
const privKeyB64 = prompt('This recording is encrypted.\nPaste your base64-encoded X25519 private key:');
|
||||
if (!privKeyB64) { statusEl.textContent = 'Decryption key required'; return; }
|
||||
try {
|
||||
decryptKey = await deriveDecryptKey(privKeyB64, meta.ephemeral_key);
|
||||
} catch (e) {
|
||||
statusEl.textContent = 'Decryption key error: ' + e.message;
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
// Decode PNG frames to ImageData. We decode bitmaps in parallel batches,
|
||||
// then draw them sequentially to avoid OffscreenCanvas races.
|
||||
const offscreen = new OffscreenCanvas(width, height);
|
||||
const offCtx = offscreen.getContext('2d');
|
||||
const batchSize = 20;
|
||||
for (let i = 0; i < rawFrames.length; i += batchSize) {
|
||||
const batch = rawFrames.slice(i, i + batchSize);
|
||||
const bitmaps = await Promise.all(batch.map(async (f, batchIdx) => {
|
||||
let pngData = new Uint8Array(buf, f.start, f.length);
|
||||
if (decryptKey) {
|
||||
const frameIdx = i + batchIdx;
|
||||
pngData = await decryptFrame(decryptKey, pngData, frameIdx);
|
||||
}
|
||||
const blob = new Blob([pngData], { type: 'image/png' });
|
||||
return createImageBitmap(blob);
|
||||
}));
|
||||
for (let j = 0; j < bitmaps.length; j++) {
|
||||
offCtx.drawImage(bitmaps[j], 0, 0);
|
||||
bitmaps[j].close();
|
||||
frames.push({ offsetMs: batch[j].offsetMs, imgData: offCtx.getImageData(0, 0, width, height) });
|
||||
}
|
||||
statusEl.textContent = `Loading ${frames.length}/${rawFrames.length}`;
|
||||
if (i % (batchSize * 3) === 0) await new Promise(r => setTimeout(r, 0));
|
||||
}
|
||||
|
||||
const firstMs = frames[0].offsetMs;
|
||||
durationMs = frames[frames.length - 1].offsetMs;
|
||||
seekBar.min = firstMs;
|
||||
seekBar.max = durationMs;
|
||||
timeEl.textContent = `0:00 / ${fmt(durationMs)}`;
|
||||
statusEl.textContent = `${frames.length} frames, ${fmt(durationMs)}`;
|
||||
renderFrame(0);
|
||||
}
|
||||
|
||||
function renderFrame(idx) {
|
||||
if (idx < 0 || idx >= frames.length) return;
|
||||
currentFrame = idx;
|
||||
const frame = frames[idx];
|
||||
ctx.putImageData(frame.imgData, 0, 0);
|
||||
seekBar.value = frame.offsetMs;
|
||||
timeEl.textContent = `${fmt(frame.offsetMs)} / ${fmt(durationMs)}`;
|
||||
frameInfoEl.textContent = `${idx + 1}/${frames.length}`;
|
||||
}
|
||||
|
||||
function togglePlay() { playing ? pause() : play(); }
|
||||
|
||||
function play() {
|
||||
if (frames.length === 0) return;
|
||||
playing = true;
|
||||
playBtn.innerHTML = '▮▮';
|
||||
playBtn.classList.add('active');
|
||||
if (currentFrame >= frames.length - 1) { currentFrame = 0; pauseOffset = 0; }
|
||||
startTime = performance.now() - pauseOffset / speed;
|
||||
tick();
|
||||
}
|
||||
|
||||
function pause() {
|
||||
playing = false;
|
||||
playBtn.innerHTML = '▶';
|
||||
playBtn.classList.remove('active');
|
||||
if (animId) { cancelAnimationFrame(animId); animId = null; }
|
||||
pauseOffset = frames[currentFrame].offsetMs;
|
||||
}
|
||||
|
||||
function tick() {
|
||||
if (!playing) return;
|
||||
const targetMs = (performance.now() - startTime) * speed;
|
||||
while (currentFrame < frames.length - 1 && frames[currentFrame + 1].offsetMs <= targetMs) currentFrame++;
|
||||
renderFrame(currentFrame);
|
||||
if (currentFrame >= frames.length - 1) { pause(); return; }
|
||||
animId = requestAnimationFrame(tick);
|
||||
}
|
||||
|
||||
function seekTo(val) {
|
||||
const ms = parseInt(val);
|
||||
let idx = 0;
|
||||
for (let i = 0; i < frames.length; i++) {
|
||||
if (frames[i].offsetMs <= ms) idx = i; else break;
|
||||
}
|
||||
renderFrame(idx);
|
||||
pauseOffset = frames[idx].offsetMs;
|
||||
if (playing) startTime = performance.now() - pauseOffset / speed;
|
||||
}
|
||||
|
||||
function setSpeed(val) {
|
||||
const old = speed;
|
||||
speed = parseFloat(val);
|
||||
if (playing) startTime = performance.now() - (performance.now() - startTime) * old / speed;
|
||||
}
|
||||
|
||||
document.addEventListener('keydown', e => {
|
||||
if (e.code === 'Space') { e.preventDefault(); togglePlay(); }
|
||||
if (e.code === 'ArrowRight') seekTo(Math.min(durationMs, frames[currentFrame].offsetMs + 5000));
|
||||
if (e.code === 'ArrowLeft') seekTo(Math.max(0, frames[currentFrame].offsetMs - 5000));
|
||||
});
|
||||
document.addEventListener('wheel', e => {
|
||||
const sel = document.getElementById('speed-select');
|
||||
const idx = sel.selectedIndex + (e.deltaY > 0 ? -1 : 1);
|
||||
if (idx >= 0 && idx < sel.options.length) { sel.selectedIndex = idx; setSpeed(sel.value); }
|
||||
}, { passive: true });
|
||||
|
||||
// Crypto helpers for encrypted recordings (X25519 ECDH + HKDF + AES-256-GCM)
|
||||
async function deriveDecryptKey(privKeyB64, ephPubB64) {
|
||||
const privBytes = Uint8Array.from(atob(privKeyB64), c => c.charCodeAt(0));
|
||||
const ephPubBytes = Uint8Array.from(atob(ephPubB64), c => c.charCodeAt(0));
|
||||
|
||||
const privKey = await crypto.subtle.importKey('raw', privBytes, { name: 'X25519' }, false, ['deriveBits']);
|
||||
const ephPub = await crypto.subtle.importKey('raw', ephPubBytes, { name: 'X25519' }, false, []);
|
||||
|
||||
const shared = await crypto.subtle.deriveBits({ name: 'X25519', public: ephPub }, privKey, 256);
|
||||
|
||||
// HKDF-SHA256 with salt=ephemeralPub, info="netbird-recording" (matches Go side)
|
||||
const hkdfKey = await crypto.subtle.importKey('raw', shared, 'HKDF', false, ['deriveKey']);
|
||||
const aesKey = await crypto.subtle.deriveKey(
|
||||
{ name: 'HKDF', hash: 'SHA-256', salt: ephPubBytes, info: new TextEncoder().encode('netbird-recording') },
|
||||
hkdfKey,
|
||||
{ name: 'AES-GCM', length: 256 },
|
||||
false,
|
||||
['decrypt'],
|
||||
);
|
||||
return aesKey;
|
||||
}
|
||||
|
||||
async function decryptFrame(key, ciphertext, frameIndex) {
|
||||
const nonce = new Uint8Array(12);
|
||||
new DataView(nonce.buffer).setUint32(0, frameIndex, true); // little-endian u64, upper 4 bytes zero
|
||||
const plain = await crypto.subtle.decrypt({ name: 'AES-GCM', iv: nonce }, key, ciphertext);
|
||||
return new Uint8Array(plain);
|
||||
}
|
||||
|
||||
load().catch(e => { statusEl.textContent = 'Error: ' + e.message; console.error(e); });
|
||||
</script>
|
||||
</body>
|
||||
</html>
|
||||
@@ -1,174 +0,0 @@
|
||||
<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<title>VNC Test</title>
|
||||
<style>
|
||||
body { margin: 0; background: #111; color: #eee; font-family: monospace; font-size: 13px; }
|
||||
#toolbar { position: fixed; top: 0; left: 0; right: 0; z-index: 10; background: #222; padding: 4px 8px; display: flex; gap: 8px; align-items: center; }
|
||||
#toolbar button { padding: 4px 12px; cursor: pointer; background: #444; color: #eee; border: 1px solid #666; border-radius: 3px; }
|
||||
#toolbar button:hover { background: #555; }
|
||||
#toolbar #status { flex: 1; }
|
||||
#vnc-container { width: 100vw; height: calc(100vh - 28px); margin-top: 28px; }
|
||||
#log { position: fixed; bottom: 0; left: 0; right: 0; max-height: 150px; overflow-y: auto; background: rgba(0,0,0,0.85); padding: 4px 8px; font-size: 11px; z-index: 10; display: none; }
|
||||
#log.visible { display: block; }
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<div id="toolbar">
|
||||
<span id="status">Loading WASM...</span>
|
||||
<button onclick="sendCAD()">Ctrl+Alt+Del</button>
|
||||
<button onclick="document.getElementById('log').classList.toggle('visible')">Log</button>
|
||||
</div>
|
||||
<div id="vnc-container"></div>
|
||||
<div id="log"></div>
|
||||
|
||||
<script>
|
||||
const params = new URLSearchParams(location.search);
|
||||
const HOST = params.get('host') || '';
|
||||
const PORT = params.get('port') || '5900';
|
||||
const MODE = params.get('mode') || 'attach'; // 'attach' or 'session'
|
||||
const USER = params.get('user') || '';
|
||||
const SETUP_KEY = params.get('setup_key') || '64BB8FF4-5A96-488F-B0AE-316555E916B0';
|
||||
const MGMT_URL = params.get('mgmt') || 'http://192.168.100.1:8080';
|
||||
|
||||
const statusEl = document.getElementById('status');
|
||||
const logEl = document.getElementById('log');
|
||||
function addLog(msg) {
|
||||
const line = document.createElement('div');
|
||||
line.textContent = `[${new Date().toISOString().slice(11,23)}] ${msg}`;
|
||||
logEl.appendChild(line);
|
||||
logEl.scrollTop = logEl.scrollHeight;
|
||||
console.log('[vnc-test]', msg);
|
||||
}
|
||||
function setStatus(s) { statusEl.textContent = s; addLog(s); }
|
||||
|
||||
let rfbInstance = null;
|
||||
window.sendCAD = () => { if (rfbInstance) { rfbInstance.sendCtrlAltDel(); addLog('Sent Ctrl+Alt+Del'); } };
|
||||
|
||||
// VNC WebSocket proxy (bridges noVNC WebSocket API to Go WASM tunnel)
|
||||
class VNCProxyWS extends EventTarget {
|
||||
constructor(url) {
|
||||
super();
|
||||
this.url = url;
|
||||
this.readyState = 0;
|
||||
this.protocol = '';
|
||||
this.extensions = '';
|
||||
this.bufferedAmount = 0;
|
||||
this.binaryType = 'arraybuffer';
|
||||
this.onopen = null; this.onclose = null; this.onerror = null; this.onmessage = null;
|
||||
const match = url.match(/vnc\.proxy\.local\/(.+)/);
|
||||
this._proxyID = match ? match[1] : 'default';
|
||||
setTimeout(() => this._connect(), 0);
|
||||
}
|
||||
get CONNECTING() { return 0; } get OPEN() { return 1; } get CLOSING() { return 2; } get CLOSED() { return 3; }
|
||||
_connect() {
|
||||
try {
|
||||
const handler = window[`handleVNCWebSocket_${this._proxyID}`];
|
||||
if (!handler) throw new Error(`No VNC handler for ${this._proxyID}`);
|
||||
handler(this);
|
||||
this.readyState = 1;
|
||||
const ev = new Event('open');
|
||||
if (this.onopen) this.onopen(ev);
|
||||
this.dispatchEvent(ev);
|
||||
} catch (err) {
|
||||
addLog(`WS proxy error: ${err.message}`);
|
||||
this.readyState = 3;
|
||||
}
|
||||
}
|
||||
receiveFromGo(data) {
|
||||
const ev = new MessageEvent('message', { data });
|
||||
if (this.onmessage) this.onmessage(ev);
|
||||
this.dispatchEvent(ev);
|
||||
}
|
||||
send(data) {
|
||||
if (this.readyState !== 1) return;
|
||||
let u8;
|
||||
if (data instanceof ArrayBuffer) u8 = new Uint8Array(data);
|
||||
else if (data instanceof Uint8Array) u8 = data;
|
||||
else if (typeof data === 'string') u8 = new TextEncoder().encode(data);
|
||||
else if (data.buffer) u8 = new Uint8Array(data.buffer, data.byteOffset, data.byteLength);
|
||||
else return;
|
||||
if (this.onGoMessage) this.onGoMessage(u8);
|
||||
}
|
||||
close(code, reason) {
|
||||
if (this.readyState >= 2) return;
|
||||
this.readyState = 2;
|
||||
if (this.onGoClose) this.onGoClose();
|
||||
setTimeout(() => {
|
||||
this.readyState = 3;
|
||||
const ev = new CloseEvent('close', { code: code||1000, reason: reason||'', wasClean: true });
|
||||
if (this.onclose) this.onclose(ev);
|
||||
this.dispatchEvent(ev);
|
||||
}, 0);
|
||||
}
|
||||
}
|
||||
|
||||
async function main() {
|
||||
if (!HOST) { setStatus('Usage: ?host=<peer_ip>&setup_key=<key>[&mode=session&user=alice]'); return; }
|
||||
|
||||
// Install WS proxy before anything creates WebSockets
|
||||
const OrigWS = window.WebSocket;
|
||||
window.WebSocket = new Proxy(OrigWS, {
|
||||
construct(target, args) {
|
||||
if (args[0] && args[0].includes('vnc.proxy.local')) return new VNCProxyWS(args[0]);
|
||||
return new target(args[0], args[1]);
|
||||
}
|
||||
});
|
||||
|
||||
// Load WASM
|
||||
setStatus('Loading WASM runtime...');
|
||||
await new Promise((resolve, reject) => {
|
||||
const s = document.createElement('script');
|
||||
s.src = '/wasm_exec.js'; s.onload = resolve; s.onerror = reject;
|
||||
document.head.appendChild(s);
|
||||
});
|
||||
|
||||
setStatus('Loading NetBird WASM...');
|
||||
const go = new Go();
|
||||
const wasm = await WebAssembly.instantiateStreaming(fetch('/netbird.wasm'), go.importObject);
|
||||
go.run(wasm.instance);
|
||||
const t0 = Date.now();
|
||||
while (!window.NetBirdClient && Date.now() - t0 < 10000) await new Promise(r => setTimeout(r, 100));
|
||||
if (!window.NetBirdClient) { setStatus('WASM init timeout'); return; }
|
||||
addLog('WASM ready');
|
||||
|
||||
// Connect NetBird with setup key
|
||||
setStatus('Connecting NetBird...');
|
||||
let client;
|
||||
try {
|
||||
client = await window.NetBirdClient({
|
||||
setupKey: SETUP_KEY,
|
||||
managementURL: MGMT_URL,
|
||||
logLevel: 'debug',
|
||||
});
|
||||
addLog('Client created, starting...');
|
||||
await client.start();
|
||||
addLog('NetBird connected');
|
||||
} catch (err) {
|
||||
setStatus('NetBird error: ' + (err && err.message ? err.message : String(err)));
|
||||
return;
|
||||
}
|
||||
|
||||
// Create VNC proxy
|
||||
setStatus(`Creating VNC proxy (mode=${MODE}${USER ? ', user=' + USER : ''})...`);
|
||||
const proxyURL = await client.createVNCProxy(HOST, PORT, MODE, USER);
|
||||
addLog(`Proxy: ${proxyURL}`);
|
||||
|
||||
// Connect noVNC
|
||||
setStatus('Connecting VNC...');
|
||||
const { default: RFB } = await import('/novnc-pkg/core/rfb.js');
|
||||
const container = document.getElementById('vnc-container');
|
||||
rfbInstance = new RFB(container, proxyURL, { wsProtocols: [] });
|
||||
rfbInstance.scaleViewport = true;
|
||||
rfbInstance.resizeSession = false;
|
||||
rfbInstance.showDotCursor = true;
|
||||
rfbInstance.addEventListener('connect', () => setStatus(`Connected: ${HOST}`));
|
||||
rfbInstance.addEventListener('disconnect', e => setStatus(`Disconnected${e.detail?.clean ? '' : ' (unexpected)'}`));
|
||||
rfbInstance.addEventListener('credentialsrequired', () => rfbInstance.sendCredentials({ password: '' }));
|
||||
window.rfb = rfbInstance;
|
||||
}
|
||||
|
||||
main().catch(err => { setStatus(`Error: ${err.message}`); console.error(err); });
|
||||
</script>
|
||||
</body>
|
||||
</html>
|
||||
@@ -1,44 +0,0 @@
|
||||
//go:build ignore
|
||||
|
||||
// Simple file server for the VNC test page.
|
||||
// Usage: go run serve.go
|
||||
// Then open: http://localhost:9090?host=100.0.23.250
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
)
|
||||
|
||||
func main() {
|
||||
// Serve from the dashboard's public dir (has wasm, noVNC, etc.)
|
||||
dashboardPublic := os.Getenv("DASHBOARD_PUBLIC")
|
||||
if dashboardPublic == "" {
|
||||
home, _ := os.UserHomeDir()
|
||||
dashboardPublic = filepath.Join(home, "dev", "dashboard", "public")
|
||||
}
|
||||
|
||||
// Serve test page index.html from this directory
|
||||
testDir, _ := os.Getwd()
|
||||
|
||||
mux := http.NewServeMux()
|
||||
// Test page itself
|
||||
mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path == "/" || r.URL.Path == "/index.html" {
|
||||
http.ServeFile(w, r, filepath.Join(testDir, "index.html"))
|
||||
return
|
||||
}
|
||||
// Everything else from dashboard public (wasm, noVNC, etc.)
|
||||
http.FileServer(http.Dir(dashboardPublic)).ServeHTTP(w, r)
|
||||
})
|
||||
|
||||
addr := ":9090"
|
||||
fmt.Printf("VNC test page: http://localhost%s?host=<peer_ip>\n", addr)
|
||||
fmt.Printf("Serving assets from: %s\n", dashboardPublic)
|
||||
if err := http.ListenAndServe(addr, mux); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "listen: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
||||
@@ -15,8 +15,8 @@ import (
|
||||
sshdetection "github.com/netbirdio/netbird/client/ssh/detection"
|
||||
nbstatus "github.com/netbirdio/netbird/client/status"
|
||||
"github.com/netbirdio/netbird/client/wasm/internal/http"
|
||||
"github.com/netbirdio/netbird/client/wasm/internal/rdp"
|
||||
"github.com/netbirdio/netbird/client/wasm/internal/ssh"
|
||||
"github.com/netbirdio/netbird/client/wasm/internal/vnc"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
)
|
||||
|
||||
@@ -317,13 +317,8 @@ func createProxyRequestMethod(client *netbird.Client) js.Func {
|
||||
})
|
||||
}
|
||||
|
||||
// createVNCProxyMethod creates the VNC proxy method for raw TCP-over-WebSocket bridging.
|
||||
// JS signature: createVNCProxy(hostname, port, mode?, username?, jwt?, sessionID?)
|
||||
// mode: "attach" (default) or "session"
|
||||
// username: required when mode is "session"
|
||||
// jwt: authentication token (from OIDC session)
|
||||
// sessionID: Windows session ID (0 = console/auto)
|
||||
func createVNCProxyMethod(client *netbird.Client) js.Func {
|
||||
// createRDPProxyMethod creates the RDP proxy method
|
||||
func createRDPProxyMethod(client *netbird.Client) js.Func {
|
||||
return js.FuncOf(func(_ js.Value, args []js.Value) any {
|
||||
if len(args) < 2 {
|
||||
return js.ValueOf("error: hostname and port required")
|
||||
@@ -340,25 +335,8 @@ func createVNCProxyMethod(client *netbird.Client) js.Func {
|
||||
})
|
||||
}
|
||||
|
||||
mode := "attach"
|
||||
username := ""
|
||||
jwtToken := ""
|
||||
var sessionID uint32
|
||||
if len(args) > 2 && args[2].Type() == js.TypeString {
|
||||
mode = args[2].String()
|
||||
}
|
||||
if len(args) > 3 && args[3].Type() == js.TypeString {
|
||||
username = args[3].String()
|
||||
}
|
||||
if len(args) > 4 && args[4].Type() == js.TypeString {
|
||||
jwtToken = args[4].String()
|
||||
}
|
||||
if len(args) > 5 && args[5].Type() == js.TypeNumber {
|
||||
sessionID = uint32(args[5].Int())
|
||||
}
|
||||
|
||||
proxy := vnc.NewVNCProxy(client)
|
||||
return proxy.CreateProxy(args[0].String(), args[1].String(), mode, username, jwtToken, sessionID)
|
||||
proxy := rdp.NewRDCleanPathProxy(client)
|
||||
return proxy.CreateProxy(args[0].String(), args[1].String())
|
||||
})
|
||||
}
|
||||
|
||||
@@ -537,7 +515,7 @@ func createClientObject(client *netbird.Client) js.Value {
|
||||
obj["detectSSHServerType"] = createDetectSSHServerMethod(client)
|
||||
obj["createSSHConnection"] = createSSHMethod(client)
|
||||
obj["proxyRequest"] = createProxyRequestMethod(client)
|
||||
obj["createVNCProxy"] = createVNCProxyMethod(client)
|
||||
obj["createRDPProxy"] = createRDPProxyMethod(client)
|
||||
obj["status"] = createStatusMethod(client)
|
||||
obj["statusSummary"] = createStatusSummaryMethod(client)
|
||||
obj["statusDetail"] = createStatusDetailMethod(client)
|
||||
|
||||
107
client/wasm/internal/rdp/cert_validation.go
Normal file
107
client/wasm/internal/rdp/cert_validation.go
Normal file
@@ -0,0 +1,107 @@
|
||||
//go:build js
|
||||
|
||||
package rdp
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"fmt"
|
||||
"syscall/js"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
certValidationTimeout = 60 * time.Second
|
||||
)
|
||||
|
||||
func (p *RDCleanPathProxy) validateCertificateWithJS(conn *proxyConnection, certChain [][]byte) (bool, error) {
|
||||
if !conn.wsHandlers.Get("onCertificateRequest").Truthy() {
|
||||
return false, fmt.Errorf("certificate validation handler not configured")
|
||||
}
|
||||
|
||||
certInfo := js.Global().Get("Object").New()
|
||||
certInfo.Set("ServerAddr", conn.destination)
|
||||
|
||||
certArray := js.Global().Get("Array").New()
|
||||
for i, certBytes := range certChain {
|
||||
uint8Array := js.Global().Get("Uint8Array").New(len(certBytes))
|
||||
js.CopyBytesToJS(uint8Array, certBytes)
|
||||
certArray.SetIndex(i, uint8Array)
|
||||
}
|
||||
certInfo.Set("ServerCertChain", certArray)
|
||||
if len(certChain) > 0 {
|
||||
cert, err := x509.ParseCertificate(certChain[0])
|
||||
if err == nil {
|
||||
info := js.Global().Get("Object").New()
|
||||
info.Set("subject", cert.Subject.String())
|
||||
info.Set("issuer", cert.Issuer.String())
|
||||
info.Set("validFrom", cert.NotBefore.Format(time.RFC3339))
|
||||
info.Set("validTo", cert.NotAfter.Format(time.RFC3339))
|
||||
info.Set("serialNumber", cert.SerialNumber.String())
|
||||
certInfo.Set("CertificateInfo", info)
|
||||
}
|
||||
}
|
||||
|
||||
promise := conn.wsHandlers.Call("onCertificateRequest", certInfo)
|
||||
|
||||
resultChan := make(chan bool)
|
||||
errorChan := make(chan error)
|
||||
|
||||
promise.Call("then", js.FuncOf(func(this js.Value, args []js.Value) interface{} {
|
||||
result := args[0].Bool()
|
||||
resultChan <- result
|
||||
return nil
|
||||
})).Call("catch", js.FuncOf(func(this js.Value, args []js.Value) interface{} {
|
||||
errorChan <- fmt.Errorf("certificate validation failed")
|
||||
return nil
|
||||
}))
|
||||
|
||||
select {
|
||||
case result := <-resultChan:
|
||||
if result {
|
||||
log.Info("Certificate accepted by user")
|
||||
} else {
|
||||
log.Info("Certificate rejected by user")
|
||||
}
|
||||
return result, nil
|
||||
case err := <-errorChan:
|
||||
return false, err
|
||||
case <-time.After(certValidationTimeout):
|
||||
return false, fmt.Errorf("certificate validation timeout")
|
||||
}
|
||||
}
|
||||
|
||||
func (p *RDCleanPathProxy) getTLSConfigWithValidation(conn *proxyConnection, requiresCredSSP bool) *tls.Config {
|
||||
config := &tls.Config{
|
||||
InsecureSkipVerify: true, // We'll validate manually after handshake
|
||||
VerifyConnection: func(cs tls.ConnectionState) error {
|
||||
var certChain [][]byte
|
||||
for _, cert := range cs.PeerCertificates {
|
||||
certChain = append(certChain, cert.Raw)
|
||||
}
|
||||
|
||||
accepted, err := p.validateCertificateWithJS(conn, certChain)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !accepted {
|
||||
return fmt.Errorf("certificate rejected by user")
|
||||
}
|
||||
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
// CredSSP (NLA) requires TLS 1.2 - it's incompatible with TLS 1.3
|
||||
if requiresCredSSP {
|
||||
config.MinVersion = tls.VersionTLS12
|
||||
config.MaxVersion = tls.VersionTLS12
|
||||
} else {
|
||||
config.MinVersion = tls.VersionTLS12
|
||||
config.MaxVersion = tls.VersionTLS13
|
||||
}
|
||||
|
||||
return config
|
||||
}
|
||||
344
client/wasm/internal/rdp/rdcleanpath.go
Normal file
344
client/wasm/internal/rdp/rdcleanpath.go
Normal file
@@ -0,0 +1,344 @@
|
||||
//go:build js
|
||||
|
||||
package rdp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"encoding/asn1"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"sync"
|
||||
"syscall/js"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
RDCleanPathVersion = 3390
|
||||
RDCleanPathProxyHost = "rdcleanpath.proxy.local"
|
||||
RDCleanPathProxyScheme = "ws"
|
||||
|
||||
rdpDialTimeout = 15 * time.Second
|
||||
|
||||
GeneralErrorCode = 1
|
||||
WSAETimedOut = 10060
|
||||
WSAEConnRefused = 10061
|
||||
WSAEConnAborted = 10053
|
||||
WSAEConnReset = 10054
|
||||
WSAEGenericError = 10050
|
||||
)
|
||||
|
||||
type RDCleanPathPDU struct {
|
||||
Version int64 `asn1:"tag:0,explicit"`
|
||||
Error RDCleanPathErr `asn1:"tag:1,explicit,optional"`
|
||||
Destination string `asn1:"utf8,tag:2,explicit,optional"`
|
||||
ProxyAuth string `asn1:"utf8,tag:3,explicit,optional"`
|
||||
ServerAuth string `asn1:"utf8,tag:4,explicit,optional"`
|
||||
PreconnectionBlob string `asn1:"utf8,tag:5,explicit,optional"`
|
||||
X224ConnectionPDU []byte `asn1:"tag:6,explicit,optional"`
|
||||
ServerCertChain [][]byte `asn1:"tag:7,explicit,optional"`
|
||||
ServerAddr string `asn1:"utf8,tag:9,explicit,optional"`
|
||||
}
|
||||
|
||||
type RDCleanPathErr struct {
|
||||
ErrorCode int16 `asn1:"tag:0,explicit"`
|
||||
HTTPStatusCode int16 `asn1:"tag:1,explicit,optional"`
|
||||
WSALastError int16 `asn1:"tag:2,explicit,optional"`
|
||||
TLSAlertCode int8 `asn1:"tag:3,explicit,optional"`
|
||||
}
|
||||
|
||||
type RDCleanPathProxy struct {
|
||||
nbClient interface {
|
||||
Dial(ctx context.Context, network, address string) (net.Conn, error)
|
||||
}
|
||||
activeConnections map[string]*proxyConnection
|
||||
destinations map[string]string
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
type proxyConnection struct {
|
||||
id string
|
||||
destination string
|
||||
rdpConn net.Conn
|
||||
tlsConn *tls.Conn
|
||||
wsHandlers js.Value
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
}
|
||||
|
||||
// NewRDCleanPathProxy creates a new RDCleanPath proxy
|
||||
func NewRDCleanPathProxy(client interface {
|
||||
Dial(ctx context.Context, network, address string) (net.Conn, error)
|
||||
}) *RDCleanPathProxy {
|
||||
return &RDCleanPathProxy{
|
||||
nbClient: client,
|
||||
activeConnections: make(map[string]*proxyConnection),
|
||||
}
|
||||
}
|
||||
|
||||
// CreateProxy creates a new proxy endpoint for the given destination
|
||||
func (p *RDCleanPathProxy) CreateProxy(hostname, port string) js.Value {
|
||||
destination := fmt.Sprintf("%s:%s", hostname, port)
|
||||
|
||||
return js.Global().Get("Promise").New(js.FuncOf(func(_ js.Value, args []js.Value) any {
|
||||
resolve := args[0]
|
||||
|
||||
go func() {
|
||||
proxyID := fmt.Sprintf("proxy_%d", len(p.activeConnections))
|
||||
|
||||
p.mu.Lock()
|
||||
if p.destinations == nil {
|
||||
p.destinations = make(map[string]string)
|
||||
}
|
||||
p.destinations[proxyID] = destination
|
||||
p.mu.Unlock()
|
||||
|
||||
proxyURL := fmt.Sprintf("%s://%s/%s", RDCleanPathProxyScheme, RDCleanPathProxyHost, proxyID)
|
||||
|
||||
// Register the WebSocket handler for this specific proxy
|
||||
js.Global().Set(fmt.Sprintf("handleRDCleanPathWebSocket_%s", proxyID), js.FuncOf(func(_ js.Value, args []js.Value) any {
|
||||
if len(args) < 1 {
|
||||
return js.ValueOf("error: requires WebSocket argument")
|
||||
}
|
||||
|
||||
ws := args[0]
|
||||
p.HandleWebSocketConnection(ws, proxyID)
|
||||
return nil
|
||||
}))
|
||||
|
||||
log.Infof("Created RDCleanPath proxy endpoint: %s for destination: %s", proxyURL, destination)
|
||||
resolve.Invoke(proxyURL)
|
||||
}()
|
||||
|
||||
return nil
|
||||
}))
|
||||
}
|
||||
|
||||
// HandleWebSocketConnection handles incoming WebSocket connections from IronRDP
|
||||
func (p *RDCleanPathProxy) HandleWebSocketConnection(ws js.Value, proxyID string) {
|
||||
p.mu.Lock()
|
||||
destination := p.destinations[proxyID]
|
||||
p.mu.Unlock()
|
||||
|
||||
if destination == "" {
|
||||
log.Errorf("No destination found for proxy ID: %s", proxyID)
|
||||
return
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
// Don't defer cancel here - it will be called by cleanupConnection
|
||||
|
||||
conn := &proxyConnection{
|
||||
id: proxyID,
|
||||
destination: destination,
|
||||
wsHandlers: ws,
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
}
|
||||
|
||||
p.mu.Lock()
|
||||
p.activeConnections[proxyID] = conn
|
||||
p.mu.Unlock()
|
||||
|
||||
p.setupWebSocketHandlers(ws, conn)
|
||||
|
||||
log.Infof("RDCleanPath proxy WebSocket connection established for %s", proxyID)
|
||||
}
|
||||
|
||||
func (p *RDCleanPathProxy) setupWebSocketHandlers(ws js.Value, conn *proxyConnection) {
|
||||
ws.Set("onGoMessage", js.FuncOf(func(this js.Value, args []js.Value) any {
|
||||
if len(args) < 1 {
|
||||
return nil
|
||||
}
|
||||
|
||||
data := args[0]
|
||||
go p.handleWebSocketMessage(conn, data)
|
||||
return nil
|
||||
}))
|
||||
|
||||
ws.Set("onGoClose", js.FuncOf(func(_ js.Value, args []js.Value) any {
|
||||
log.Debug("WebSocket closed by JavaScript")
|
||||
conn.cancel()
|
||||
return nil
|
||||
}))
|
||||
}
|
||||
|
||||
func (p *RDCleanPathProxy) handleWebSocketMessage(conn *proxyConnection, data js.Value) {
|
||||
if !data.InstanceOf(js.Global().Get("Uint8Array")) {
|
||||
return
|
||||
}
|
||||
|
||||
length := data.Get("length").Int()
|
||||
bytes := make([]byte, length)
|
||||
js.CopyBytesToGo(bytes, data)
|
||||
|
||||
if conn.rdpConn != nil || conn.tlsConn != nil {
|
||||
p.forwardToRDP(conn, bytes)
|
||||
return
|
||||
}
|
||||
|
||||
var pdu RDCleanPathPDU
|
||||
_, err := asn1.Unmarshal(bytes, &pdu)
|
||||
if err != nil {
|
||||
log.Warnf("Failed to parse RDCleanPath PDU: %v", err)
|
||||
n := len(bytes)
|
||||
if n > 20 {
|
||||
n = 20
|
||||
}
|
||||
log.Warnf("First %d bytes: %x", n, bytes[:n])
|
||||
|
||||
if len(bytes) > 0 && bytes[0] == 0x03 {
|
||||
log.Debug("Received raw RDP packet instead of RDCleanPath PDU")
|
||||
go p.handleDirectRDP(conn, bytes)
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
go p.processRDCleanPathPDU(conn, pdu)
|
||||
}
|
||||
|
||||
func (p *RDCleanPathProxy) forwardToRDP(conn *proxyConnection, bytes []byte) {
|
||||
var writer io.Writer
|
||||
var connType string
|
||||
|
||||
if conn.tlsConn != nil {
|
||||
writer = conn.tlsConn
|
||||
connType = "TLS"
|
||||
} else if conn.rdpConn != nil {
|
||||
writer = conn.rdpConn
|
||||
connType = "TCP"
|
||||
} else {
|
||||
log.Error("No RDP connection available")
|
||||
return
|
||||
}
|
||||
|
||||
if _, err := writer.Write(bytes); err != nil {
|
||||
log.Errorf("Failed to write to %s: %v", connType, err)
|
||||
}
|
||||
}
|
||||
|
||||
func (p *RDCleanPathProxy) handleDirectRDP(conn *proxyConnection, firstPacket []byte) {
|
||||
defer p.cleanupConnection(conn)
|
||||
|
||||
destination := conn.destination
|
||||
log.Infof("Direct RDP mode: Connecting to %s via NetBird", destination)
|
||||
|
||||
ctx, cancel := context.WithTimeout(conn.ctx, rdpDialTimeout)
|
||||
defer cancel()
|
||||
|
||||
rdpConn, err := p.nbClient.Dial(ctx, "tcp", destination)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to connect to %s: %v", destination, err)
|
||||
p.sendRDCleanPathError(conn, newWSAError(err))
|
||||
return
|
||||
}
|
||||
conn.rdpConn = rdpConn
|
||||
|
||||
_, err = rdpConn.Write(firstPacket)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to write first packet: %v", err)
|
||||
p.sendRDCleanPathError(conn, newWSAError(err))
|
||||
return
|
||||
}
|
||||
|
||||
response := make([]byte, 1024)
|
||||
n, err := rdpConn.Read(response)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to read X.224 response: %v", err)
|
||||
p.sendRDCleanPathError(conn, newWSAError(err))
|
||||
return
|
||||
}
|
||||
|
||||
p.sendToWebSocket(conn, response[:n])
|
||||
|
||||
go p.forwardWSToConn(conn, conn.rdpConn, "TCP")
|
||||
go p.forwardConnToWS(conn, conn.rdpConn, "TCP")
|
||||
}
|
||||
|
||||
func (p *RDCleanPathProxy) cleanupConnection(conn *proxyConnection) {
|
||||
log.Debugf("Cleaning up connection %s", conn.id)
|
||||
conn.cancel()
|
||||
if conn.tlsConn != nil {
|
||||
log.Debug("Closing TLS connection")
|
||||
if err := conn.tlsConn.Close(); err != nil {
|
||||
log.Debugf("Error closing TLS connection: %v", err)
|
||||
}
|
||||
conn.tlsConn = nil
|
||||
}
|
||||
if conn.rdpConn != nil {
|
||||
log.Debug("Closing TCP connection")
|
||||
if err := conn.rdpConn.Close(); err != nil {
|
||||
log.Debugf("Error closing TCP connection: %v", err)
|
||||
}
|
||||
conn.rdpConn = nil
|
||||
}
|
||||
p.mu.Lock()
|
||||
delete(p.activeConnections, conn.id)
|
||||
p.mu.Unlock()
|
||||
}
|
||||
|
||||
func (p *RDCleanPathProxy) sendToWebSocket(conn *proxyConnection, data []byte) {
|
||||
if conn.wsHandlers.Get("receiveFromGo").Truthy() {
|
||||
uint8Array := js.Global().Get("Uint8Array").New(len(data))
|
||||
js.CopyBytesToJS(uint8Array, data)
|
||||
conn.wsHandlers.Call("receiveFromGo", uint8Array.Get("buffer"))
|
||||
} else if conn.wsHandlers.Get("send").Truthy() {
|
||||
uint8Array := js.Global().Get("Uint8Array").New(len(data))
|
||||
js.CopyBytesToJS(uint8Array, data)
|
||||
conn.wsHandlers.Call("send", uint8Array.Get("buffer"))
|
||||
}
|
||||
}
|
||||
|
||||
func (p *RDCleanPathProxy) sendRDCleanPathError(conn *proxyConnection, pdu RDCleanPathPDU) {
|
||||
data, err := asn1.Marshal(pdu)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to marshal error PDU: %v", err)
|
||||
return
|
||||
}
|
||||
p.sendToWebSocket(conn, data)
|
||||
}
|
||||
|
||||
func errorToWSACode(err error) int16 {
|
||||
if err == nil {
|
||||
return WSAEGenericError
|
||||
}
|
||||
var netErr *net.OpError
|
||||
if errors.As(err, &netErr) && netErr.Timeout() {
|
||||
return WSAETimedOut
|
||||
}
|
||||
if errors.Is(err, context.DeadlineExceeded) {
|
||||
return WSAETimedOut
|
||||
}
|
||||
if errors.Is(err, context.Canceled) {
|
||||
return WSAEConnAborted
|
||||
}
|
||||
if errors.Is(err, io.EOF) {
|
||||
return WSAEConnReset
|
||||
}
|
||||
return WSAEGenericError
|
||||
}
|
||||
|
||||
func newWSAError(err error) RDCleanPathPDU {
|
||||
return RDCleanPathPDU{
|
||||
Version: RDCleanPathVersion,
|
||||
Error: RDCleanPathErr{
|
||||
ErrorCode: GeneralErrorCode,
|
||||
WSALastError: errorToWSACode(err),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func newHTTPError(statusCode int16) RDCleanPathPDU {
|
||||
return RDCleanPathPDU{
|
||||
Version: RDCleanPathVersion,
|
||||
Error: RDCleanPathErr{
|
||||
ErrorCode: GeneralErrorCode,
|
||||
HTTPStatusCode: statusCode,
|
||||
},
|
||||
}
|
||||
}
|
||||
244
client/wasm/internal/rdp/rdcleanpath_handlers.go
Normal file
244
client/wasm/internal/rdp/rdcleanpath_handlers.go
Normal file
@@ -0,0 +1,244 @@
|
||||
//go:build js
|
||||
|
||||
package rdp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"encoding/asn1"
|
||||
"io"
|
||||
"syscall/js"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
// MS-RDPBCGR: confusingly named, actually means PROTOCOL_HYBRID (CredSSP)
|
||||
protocolSSL = 0x00000001
|
||||
protocolHybridEx = 0x00000008
|
||||
)
|
||||
|
||||
func (p *RDCleanPathProxy) processRDCleanPathPDU(conn *proxyConnection, pdu RDCleanPathPDU) {
|
||||
log.Infof("Processing RDCleanPath PDU: Version=%d, Destination=%s", pdu.Version, pdu.Destination)
|
||||
|
||||
if pdu.Version != RDCleanPathVersion {
|
||||
p.sendRDCleanPathError(conn, newHTTPError(400))
|
||||
return
|
||||
}
|
||||
|
||||
destination := conn.destination
|
||||
if pdu.Destination != "" {
|
||||
destination = pdu.Destination
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(conn.ctx, rdpDialTimeout)
|
||||
defer cancel()
|
||||
|
||||
rdpConn, err := p.nbClient.Dial(ctx, "tcp", destination)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to connect to %s: %v", destination, err)
|
||||
p.sendRDCleanPathError(conn, newWSAError(err))
|
||||
p.cleanupConnection(conn)
|
||||
return
|
||||
}
|
||||
conn.rdpConn = rdpConn
|
||||
|
||||
// RDP always starts with X.224 negotiation, then determines if TLS is needed
|
||||
// Modern RDP (since Windows Vista/2008) typically requires TLS
|
||||
// The X.224 Connection Confirm response will indicate if TLS is required
|
||||
// For now, we'll attempt TLS for all connections as it's the modern default
|
||||
p.setupTLSConnection(conn, pdu)
|
||||
}
|
||||
|
||||
// detectCredSSPFromX224 checks if the X.224 response indicates NLA/CredSSP is required.
|
||||
// Per MS-RDPBCGR spec: byte 11 = TYPE_RDP_NEG_RSP (0x02), bytes 15-18 = selectedProtocol flags.
|
||||
// Returns (requiresTLS12, selectedProtocol, detectionSuccessful).
|
||||
func (p *RDCleanPathProxy) detectCredSSPFromX224(x224Response []byte) (bool, uint32, bool) {
|
||||
const minResponseLength = 19
|
||||
|
||||
if len(x224Response) < minResponseLength {
|
||||
return false, 0, false
|
||||
}
|
||||
|
||||
// Per X.224 specification:
|
||||
// x224Response[0] == 0x03: Length of X.224 header (3 bytes)
|
||||
// x224Response[5] == 0xD0: X.224 Data TPDU code
|
||||
if x224Response[0] != 0x03 || x224Response[5] != 0xD0 {
|
||||
return false, 0, false
|
||||
}
|
||||
|
||||
if x224Response[11] == 0x02 {
|
||||
flags := uint32(x224Response[15]) | uint32(x224Response[16])<<8 |
|
||||
uint32(x224Response[17])<<16 | uint32(x224Response[18])<<24
|
||||
|
||||
hasNLA := (flags & (protocolSSL | protocolHybridEx)) != 0
|
||||
return hasNLA, flags, true
|
||||
}
|
||||
|
||||
return false, 0, false
|
||||
}
|
||||
|
||||
func (p *RDCleanPathProxy) setupTLSConnection(conn *proxyConnection, pdu RDCleanPathPDU) {
|
||||
var x224Response []byte
|
||||
if len(pdu.X224ConnectionPDU) > 0 {
|
||||
log.Debugf("Forwarding X.224 Connection Request (%d bytes)", len(pdu.X224ConnectionPDU))
|
||||
_, err := conn.rdpConn.Write(pdu.X224ConnectionPDU)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to write X.224 PDU: %v", err)
|
||||
p.sendRDCleanPathError(conn, newWSAError(err))
|
||||
return
|
||||
}
|
||||
|
||||
response := make([]byte, 1024)
|
||||
n, err := conn.rdpConn.Read(response)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to read X.224 response: %v", err)
|
||||
p.sendRDCleanPathError(conn, newWSAError(err))
|
||||
return
|
||||
}
|
||||
x224Response = response[:n]
|
||||
log.Debugf("Received X.224 Connection Confirm (%d bytes)", n)
|
||||
}
|
||||
|
||||
requiresCredSSP, selectedProtocol, detected := p.detectCredSSPFromX224(x224Response)
|
||||
if detected {
|
||||
if requiresCredSSP {
|
||||
log.Warnf("Detected NLA/CredSSP (selectedProtocol: 0x%08X), forcing TLS 1.2 for compatibility", selectedProtocol)
|
||||
} else {
|
||||
log.Warnf("No NLA/CredSSP detected (selectedProtocol: 0x%08X), allowing up to TLS 1.3", selectedProtocol)
|
||||
}
|
||||
} else {
|
||||
log.Warnf("Could not detect RDP security protocol, allowing up to TLS 1.3")
|
||||
}
|
||||
|
||||
tlsConfig := p.getTLSConfigWithValidation(conn, requiresCredSSP)
|
||||
|
||||
tlsConn := tls.Client(conn.rdpConn, tlsConfig)
|
||||
conn.tlsConn = tlsConn
|
||||
|
||||
if err := tlsConn.Handshake(); err != nil {
|
||||
log.Errorf("TLS handshake failed: %v", err)
|
||||
p.sendRDCleanPathError(conn, newWSAError(err))
|
||||
return
|
||||
}
|
||||
|
||||
log.Info("TLS handshake successful")
|
||||
|
||||
// Certificate validation happens during handshake via VerifyConnection callback
|
||||
var certChain [][]byte
|
||||
connState := tlsConn.ConnectionState()
|
||||
if len(connState.PeerCertificates) > 0 {
|
||||
for _, cert := range connState.PeerCertificates {
|
||||
certChain = append(certChain, cert.Raw)
|
||||
}
|
||||
log.Debugf("Extracted %d certificates from TLS connection", len(certChain))
|
||||
}
|
||||
|
||||
responsePDU := RDCleanPathPDU{
|
||||
Version: RDCleanPathVersion,
|
||||
ServerAddr: conn.destination,
|
||||
ServerCertChain: certChain,
|
||||
}
|
||||
|
||||
if len(x224Response) > 0 {
|
||||
responsePDU.X224ConnectionPDU = x224Response
|
||||
}
|
||||
|
||||
p.sendRDCleanPathPDU(conn, responsePDU)
|
||||
|
||||
log.Debug("Starting TLS forwarding")
|
||||
go p.forwardConnToWS(conn, conn.tlsConn, "TLS")
|
||||
go p.forwardWSToConn(conn, conn.tlsConn, "TLS")
|
||||
|
||||
<-conn.ctx.Done()
|
||||
log.Debug("TLS connection context done, cleaning up")
|
||||
p.cleanupConnection(conn)
|
||||
}
|
||||
|
||||
func (p *RDCleanPathProxy) sendRDCleanPathPDU(conn *proxyConnection, pdu RDCleanPathPDU) {
|
||||
data, err := asn1.Marshal(pdu)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to marshal RDCleanPath PDU: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
log.Debugf("Sending RDCleanPath PDU response (%d bytes)", len(data))
|
||||
p.sendToWebSocket(conn, data)
|
||||
}
|
||||
|
||||
func (p *RDCleanPathProxy) readWebSocketMessage(conn *proxyConnection) ([]byte, error) {
|
||||
msgChan := make(chan []byte)
|
||||
errChan := make(chan error)
|
||||
|
||||
handler := js.FuncOf(func(this js.Value, args []js.Value) interface{} {
|
||||
if len(args) < 1 {
|
||||
errChan <- io.EOF
|
||||
return nil
|
||||
}
|
||||
|
||||
data := args[0]
|
||||
if data.InstanceOf(js.Global().Get("Uint8Array")) {
|
||||
length := data.Get("length").Int()
|
||||
bytes := make([]byte, length)
|
||||
js.CopyBytesToGo(bytes, data)
|
||||
msgChan <- bytes
|
||||
}
|
||||
return nil
|
||||
})
|
||||
defer handler.Release()
|
||||
|
||||
conn.wsHandlers.Set("onceGoMessage", handler)
|
||||
|
||||
select {
|
||||
case msg := <-msgChan:
|
||||
return msg, nil
|
||||
case err := <-errChan:
|
||||
return nil, err
|
||||
case <-conn.ctx.Done():
|
||||
return nil, conn.ctx.Err()
|
||||
}
|
||||
}
|
||||
|
||||
func (p *RDCleanPathProxy) forwardWSToConn(conn *proxyConnection, dst io.Writer, connType string) {
|
||||
for {
|
||||
if conn.ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
|
||||
msg, err := p.readWebSocketMessage(conn)
|
||||
if err != nil {
|
||||
if err != io.EOF {
|
||||
log.Errorf("Failed to read from WebSocket: %v", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
_, err = dst.Write(msg)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to write to %s: %v", connType, err)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (p *RDCleanPathProxy) forwardConnToWS(conn *proxyConnection, src io.Reader, connType string) {
|
||||
buffer := make([]byte, 32*1024)
|
||||
|
||||
for {
|
||||
if conn.ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
|
||||
n, err := src.Read(buffer)
|
||||
if err != nil {
|
||||
if err != io.EOF {
|
||||
log.Errorf("Failed to read from %s: %v", connType, err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if n > 0 {
|
||||
p.sendToWebSocket(conn, buffer[:n])
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,360 +0,0 @@
|
||||
//go:build js
|
||||
|
||||
package vnc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"syscall/js"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
vncProxyHost = "vnc.proxy.local"
|
||||
vncProxyScheme = "ws"
|
||||
vncDialTimeout = 15 * time.Second
|
||||
|
||||
// Connection modes matching server/server.go constants.
|
||||
modeAttach byte = 0
|
||||
modeSession byte = 1
|
||||
)
|
||||
|
||||
// VNCProxy bridges WebSocket connections from noVNC in the browser
|
||||
// to TCP VNC server connections through the NetBird tunnel.
|
||||
type VNCProxy struct {
|
||||
nbClient interface {
|
||||
Dial(ctx context.Context, network, address string) (net.Conn, error)
|
||||
}
|
||||
activeConnections map[string]*vncConnection
|
||||
destinations map[string]vncDestination
|
||||
// pendingHandlers holds the js.Func for handleVNCWebSocket_<id> between
|
||||
// CreateProxy and handleWebSocketConnection so we can move it onto the
|
||||
// vncConnection for later release.
|
||||
pendingHandlers map[string]js.Func
|
||||
mu sync.Mutex
|
||||
nextID atomic.Uint64
|
||||
}
|
||||
|
||||
type vncDestination struct {
|
||||
address string
|
||||
mode byte
|
||||
username string
|
||||
jwt string
|
||||
sessionID uint32 // Windows session ID (0 = auto/console)
|
||||
}
|
||||
|
||||
type vncConnection struct {
|
||||
id string
|
||||
destination vncDestination
|
||||
mu sync.Mutex
|
||||
vncConn net.Conn
|
||||
wsHandlers js.Value
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
// Go-side callbacks exposed to JS. js.FuncOf pins the Go closure in a
|
||||
// global handle map and MUST be released, otherwise every connection
|
||||
// leaks the Go memory the closure captures.
|
||||
wsHandlerFn js.Func
|
||||
onMessageFn js.Func
|
||||
onCloseFn js.Func
|
||||
}
|
||||
|
||||
// NewVNCProxy creates a new VNC proxy.
|
||||
func NewVNCProxy(client interface {
|
||||
Dial(ctx context.Context, network, address string) (net.Conn, error)
|
||||
}) *VNCProxy {
|
||||
return &VNCProxy{
|
||||
nbClient: client,
|
||||
activeConnections: make(map[string]*vncConnection),
|
||||
}
|
||||
}
|
||||
|
||||
// CreateProxy creates a new proxy endpoint for the given VNC destination.
|
||||
// mode is "attach" (capture current display) or "session" (virtual session).
|
||||
// username is required for session mode.
|
||||
// Returns a JS Promise that resolves to the WebSocket proxy URL.
|
||||
func (p *VNCProxy) CreateProxy(hostname, port, mode, username, jwt string, sessionID uint32) js.Value {
|
||||
address := fmt.Sprintf("%s:%s", hostname, port)
|
||||
|
||||
var m byte
|
||||
if mode == "session" {
|
||||
m = modeSession
|
||||
}
|
||||
|
||||
dest := vncDestination{
|
||||
address: address,
|
||||
mode: m,
|
||||
username: username,
|
||||
jwt: jwt,
|
||||
sessionID: sessionID,
|
||||
}
|
||||
|
||||
return js.Global().Get("Promise").New(js.FuncOf(func(_ js.Value, args []js.Value) any {
|
||||
resolve := args[0]
|
||||
|
||||
go func() {
|
||||
proxyID := fmt.Sprintf("vnc_proxy_%d", p.nextID.Add(1))
|
||||
|
||||
p.mu.Lock()
|
||||
if p.destinations == nil {
|
||||
p.destinations = make(map[string]vncDestination)
|
||||
}
|
||||
p.destinations[proxyID] = dest
|
||||
p.mu.Unlock()
|
||||
|
||||
proxyURL := fmt.Sprintf("%s://%s/%s", vncProxyScheme, vncProxyHost, proxyID)
|
||||
|
||||
handlerFn := js.FuncOf(func(_ js.Value, args []js.Value) any {
|
||||
if len(args) < 1 {
|
||||
return js.ValueOf("error: requires WebSocket argument")
|
||||
}
|
||||
p.handleWebSocketConnection(args[0], proxyID)
|
||||
return nil
|
||||
})
|
||||
p.mu.Lock()
|
||||
if p.pendingHandlers == nil {
|
||||
p.pendingHandlers = make(map[string]js.Func)
|
||||
}
|
||||
p.pendingHandlers[proxyID] = handlerFn
|
||||
p.mu.Unlock()
|
||||
js.Global().Set(fmt.Sprintf("handleVNCWebSocket_%s", proxyID), handlerFn)
|
||||
|
||||
log.Infof("created VNC proxy: %s -> %s (mode=%s, user=%s)", proxyURL, address, mode, username)
|
||||
resolve.Invoke(proxyURL)
|
||||
}()
|
||||
|
||||
return nil
|
||||
}))
|
||||
}
|
||||
|
||||
func (p *VNCProxy) handleWebSocketConnection(ws js.Value, proxyID string) {
|
||||
p.mu.Lock()
|
||||
dest, ok := p.destinations[proxyID]
|
||||
handlerFn := p.pendingHandlers[proxyID]
|
||||
delete(p.pendingHandlers, proxyID)
|
||||
p.mu.Unlock()
|
||||
|
||||
if !ok {
|
||||
log.Errorf("no destination for VNC proxy %s", proxyID)
|
||||
return
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
conn := &vncConnection{
|
||||
id: proxyID,
|
||||
destination: dest,
|
||||
wsHandlers: ws,
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
wsHandlerFn: handlerFn,
|
||||
}
|
||||
|
||||
p.mu.Lock()
|
||||
p.activeConnections[proxyID] = conn
|
||||
p.mu.Unlock()
|
||||
|
||||
p.setupWebSocketHandlers(ws, conn)
|
||||
go p.connectToVNC(conn)
|
||||
|
||||
log.Infof("VNC proxy WebSocket connection established for %s", proxyID)
|
||||
}
|
||||
|
||||
func (p *VNCProxy) setupWebSocketHandlers(ws js.Value, conn *vncConnection) {
|
||||
conn.onMessageFn = js.FuncOf(func(_ js.Value, args []js.Value) any {
|
||||
if len(args) < 1 {
|
||||
return nil
|
||||
}
|
||||
data := args[0]
|
||||
go p.handleWebSocketMessage(conn, data)
|
||||
return nil
|
||||
})
|
||||
ws.Set("onGoMessage", conn.onMessageFn)
|
||||
|
||||
conn.onCloseFn = js.FuncOf(func(_ js.Value, _ []js.Value) any {
|
||||
log.Debug("VNC WebSocket closed by JavaScript")
|
||||
conn.cancel()
|
||||
return nil
|
||||
})
|
||||
ws.Set("onGoClose", conn.onCloseFn)
|
||||
}
|
||||
|
||||
func (p *VNCProxy) handleWebSocketMessage(conn *vncConnection, data js.Value) {
|
||||
if !data.InstanceOf(js.Global().Get("Uint8Array")) {
|
||||
return
|
||||
}
|
||||
|
||||
length := data.Get("length").Int()
|
||||
buf := make([]byte, length)
|
||||
js.CopyBytesToGo(buf, data)
|
||||
|
||||
conn.mu.Lock()
|
||||
vncConn := conn.vncConn
|
||||
conn.mu.Unlock()
|
||||
|
||||
if vncConn == nil {
|
||||
return
|
||||
}
|
||||
|
||||
if _, err := vncConn.Write(buf); err != nil {
|
||||
log.Debugf("write to VNC server: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (p *VNCProxy) connectToVNC(conn *vncConnection) {
|
||||
ctx, cancel := context.WithTimeout(conn.ctx, vncDialTimeout)
|
||||
defer cancel()
|
||||
|
||||
vncConn, err := p.nbClient.Dial(ctx, "tcp", conn.destination.address)
|
||||
if err != nil {
|
||||
log.Errorf("VNC connect to %s: %v", conn.destination.address, err)
|
||||
// Close the WebSocket so noVNC fires a disconnect event.
|
||||
if conn.wsHandlers.Get("close").Truthy() {
|
||||
conn.wsHandlers.Call("close", 1006, fmt.Sprintf("connect to peer: %v", err))
|
||||
}
|
||||
p.cleanupConnection(conn)
|
||||
return
|
||||
}
|
||||
conn.mu.Lock()
|
||||
conn.vncConn = vncConn
|
||||
conn.mu.Unlock()
|
||||
|
||||
// Send the NetBird VNC session header before the RFB handshake.
|
||||
if err := p.sendSessionHeader(vncConn, conn.destination); err != nil {
|
||||
log.Errorf("send VNC session header: %v", err)
|
||||
p.cleanupConnection(conn)
|
||||
return
|
||||
}
|
||||
|
||||
// WS→TCP is handled by the onGoMessage handler set in setupWebSocketHandlers,
|
||||
// which writes directly to the VNC connection as data arrives from JS.
|
||||
// Only the TCP→WS direction needs a read loop here.
|
||||
go p.forwardConnToWS(conn)
|
||||
|
||||
<-conn.ctx.Done()
|
||||
p.cleanupConnection(conn)
|
||||
}
|
||||
|
||||
// sendSessionHeader writes mode, username, and JWT to the VNC server.
|
||||
// Format: [mode: 1 byte] [username_len: 2 bytes BE] [username: N bytes]
|
||||
//
|
||||
// [jwt_len: 2 bytes BE] [jwt: N bytes]
|
||||
func (p *VNCProxy) sendSessionHeader(conn net.Conn, dest vncDestination) error {
|
||||
usernameBytes := []byte(dest.username)
|
||||
jwtBytes := []byte(dest.jwt)
|
||||
// Format: [mode:1] [username_len:2] [username:N] [jwt_len:2] [jwt:N] [session_id:4]
|
||||
hdr := make([]byte, 3+len(usernameBytes)+2+len(jwtBytes)+4)
|
||||
hdr[0] = dest.mode
|
||||
hdr[1] = byte(len(usernameBytes) >> 8)
|
||||
hdr[2] = byte(len(usernameBytes))
|
||||
off := 3
|
||||
copy(hdr[off:], usernameBytes)
|
||||
off += len(usernameBytes)
|
||||
hdr[off] = byte(len(jwtBytes) >> 8)
|
||||
hdr[off+1] = byte(len(jwtBytes))
|
||||
off += 2
|
||||
copy(hdr[off:], jwtBytes)
|
||||
off += len(jwtBytes)
|
||||
hdr[off] = byte(dest.sessionID >> 24)
|
||||
hdr[off+1] = byte(dest.sessionID >> 16)
|
||||
hdr[off+2] = byte(dest.sessionID >> 8)
|
||||
hdr[off+3] = byte(dest.sessionID)
|
||||
|
||||
_, err := conn.Write(hdr)
|
||||
return err
|
||||
}
|
||||
|
||||
func (p *VNCProxy) forwardConnToWS(conn *vncConnection) {
|
||||
buf := make([]byte, 32*1024)
|
||||
|
||||
for {
|
||||
if conn.ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Set a read deadline so we detect dead connections instead of
|
||||
// blocking forever when the remote peer dies.
|
||||
conn.mu.Lock()
|
||||
vc := conn.vncConn
|
||||
conn.mu.Unlock()
|
||||
if vc == nil {
|
||||
return
|
||||
}
|
||||
vc.SetReadDeadline(time.Now().Add(30 * time.Second))
|
||||
|
||||
n, err := vc.Read(buf)
|
||||
if err != nil {
|
||||
if conn.ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
if netErr, ok := err.(interface{ Timeout() bool }); ok && netErr.Timeout() {
|
||||
// Read timeout: connection might be stale. Send a ping-like
|
||||
// empty read to check. If the connection is truly dead, the
|
||||
// next iteration will fail too and we'll close.
|
||||
continue
|
||||
}
|
||||
if err != io.EOF {
|
||||
log.Debugf("read from VNC connection: %v", err)
|
||||
}
|
||||
// Close the WebSocket to notify noVNC.
|
||||
if conn.wsHandlers.Get("close").Truthy() {
|
||||
conn.wsHandlers.Call("close", 1006, "VNC connection lost")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if n > 0 {
|
||||
p.sendToWebSocket(conn, buf[:n])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (p *VNCProxy) sendToWebSocket(conn *vncConnection, data []byte) {
|
||||
if conn.wsHandlers.Get("receiveFromGo").Truthy() {
|
||||
uint8Array := js.Global().Get("Uint8Array").New(len(data))
|
||||
js.CopyBytesToJS(uint8Array, data)
|
||||
conn.wsHandlers.Call("receiveFromGo", uint8Array.Get("buffer"))
|
||||
} else if conn.wsHandlers.Get("send").Truthy() {
|
||||
uint8Array := js.Global().Get("Uint8Array").New(len(data))
|
||||
js.CopyBytesToJS(uint8Array, data)
|
||||
conn.wsHandlers.Call("send", uint8Array.Get("buffer"))
|
||||
}
|
||||
}
|
||||
|
||||
func (p *VNCProxy) cleanupConnection(conn *vncConnection) {
|
||||
log.Debugf("cleaning up VNC connection %s", conn.id)
|
||||
conn.cancel()
|
||||
|
||||
conn.mu.Lock()
|
||||
vncConn := conn.vncConn
|
||||
conn.vncConn = nil
|
||||
conn.mu.Unlock()
|
||||
|
||||
if vncConn != nil {
|
||||
if err := vncConn.Close(); err != nil {
|
||||
log.Debugf("close VNC connection: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Remove the global JS handler registered in CreateProxy.
|
||||
globalName := fmt.Sprintf("handleVNCWebSocket_%s", conn.id)
|
||||
js.Global().Delete(globalName)
|
||||
|
||||
// Release all js.Func handles; js.FuncOf pins the Go closure and the
|
||||
// allocations it captures until Release is called.
|
||||
conn.wsHandlerFn.Release()
|
||||
conn.onMessageFn.Release()
|
||||
conn.onCloseFn.Release()
|
||||
|
||||
p.mu.Lock()
|
||||
delete(p.activeConnections, conn.id)
|
||||
delete(p.destinations, conn.id)
|
||||
delete(p.pendingHandlers, conn.id)
|
||||
p.mu.Unlock()
|
||||
}
|
||||
@@ -119,6 +119,8 @@ server:
|
||||
|
||||
# Reverse proxy settings (optional)
|
||||
# reverseProxy:
|
||||
# trustedHTTPProxies: []
|
||||
# trustedHTTPProxiesCount: 0
|
||||
# trustedPeers: []
|
||||
# trustedHTTPProxies: [] # CIDRs of trusted reverse proxies (e.g. ["10.0.0.0/8"])
|
||||
# trustedHTTPProxiesCount: 0 # Number of trusted proxies in front of the server (alternative to trustedHTTPProxies)
|
||||
# trustedPeers: [] # CIDRs of trusted peer networks (e.g. ["100.64.0.0/10"])
|
||||
# accessLogRetentionDays: 7 # Days to retain HTTP access logs. 0 (or unset) defaults to 7. Negative values disable cleanup (logs kept indefinitely).
|
||||
# accessLogCleanupIntervalHours: 24 # How often (in hours) to run the access-log cleanup job. 0 (or unset) is treated as "not set" and defaults to 24 hours; cleanup remains enabled. To disable cleanup, set accessLogRetentionDays to a negative value.
|
||||
|
||||
@@ -457,6 +457,18 @@ func TestReceive_ProtocolErrorStreamReconnect(t *testing.T) {
|
||||
|
||||
client, err := flow.NewClient("http://"+server.addr, "test-payload", "test-signature", 1*time.Second)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Cleanups run LIFO: the goroutine-drain registered here runs after Close below,
|
||||
// which is when Receive has actually returned. Without this, the Receive goroutine
|
||||
// can outlive the test and call t.Logf after teardown, panicking.
|
||||
receiveDone := make(chan struct{})
|
||||
t.Cleanup(func() {
|
||||
select {
|
||||
case <-receiveDone:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Error("Receive goroutine did not exit after Close")
|
||||
}
|
||||
})
|
||||
t.Cleanup(func() {
|
||||
err := client.Close()
|
||||
assert.NoError(t, err, "failed to close flow")
|
||||
@@ -468,6 +480,7 @@ func TestReceive_ProtocolErrorStreamReconnect(t *testing.T) {
|
||||
receivedAfterReconnect := make(chan struct{})
|
||||
|
||||
go func() {
|
||||
defer close(receiveDone)
|
||||
err := client.Receive(ctx, 1*time.Second, func(msg *proto.FlowEventAck) error {
|
||||
if msg.IsInitiator || len(msg.EventId) == 0 {
|
||||
return nil
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user