Compare commits

..

49 Commits

Author SHA1 Message Date
Viktor Liu
dfe1bba287 Add embedded VNC server with JWT auth, DXGI capture, and dashboard integration 2026-04-14 13:54:31 +02:00
Zoltan Papp
13539543af [client] Fix/grpc retry (#5750)
* [client] Fix flow client Receive retry loop not stopping after Close

Use backoff.Permanent for canceled gRPC errors so Receive returns
immediately instead of retrying until context deadline when the
connection is already closed. Add TestNewClient_PermanentClose to
verify the behavior.

The connectivity.Shutdown check was meaningless because when the connection is
shut down, c.realClient.Events(ctx, grpc.WaitForReady(true)) on the nex line
already fails with codes.Canceled — which is now handled as a permanent error.
The explicit state check was just duplicating what gRPC already reports
through its normal error path.

* [client] remove WaitForReady from stream open call

grpc.WaitForReady(true) parks the RPC call internally until the
connection reaches READY, only unblocking on ctx cancellation.
This means the external backoff.Retry loop in Receive() never gets
control back during a connection outage — it cannot tick, log, or
apply its retry intervals while WaitForReady is blocking.

Removing it restores fail-fast behaviour: Events() returns immediately
with codes.Unavailable when the connection is not ready, which is
exactly what the backoff loop expects. The backoff becomes the single
authority over retry timing and cadence, as originally intended.

* [client] Add connection recreation and improve flow client error handling

Store gRPC dial options on the client to enable connection recreation
on Internal errors (RST_STREAM/PROTOCOL_ERROR). Treat Unauthenticated,
PermissionDenied, and Unimplemented as permanent failures. Unify mutex
usage and add reconnection logging for better observability.

* [client] Remove Unauthenticated, PermissionDenied, and Unimplemented from permanent error handling

* [client] Fix error handling in Receive to properly re-establish stream and improve reconnection messaging

* Fix test

* [client] Add graceful shutdown handling and test for concurrent Close during Receive

Prevent reconnection attempts after client closure by tracking a `closed` flag. Use `backoff.Permanent` for errors caused by operations on a closed client. Add a test to ensure `Close` does not block when `Receive` is actively running.

* [client] Fix connection swap to properly close old gRPC connection

Close the old `gRPC.ClientConn` after successfully swapping to a new connection during reconnection.

* [client] Reset backoff

* [client] Ensure stream closure on error during initialization

* [client] Add test for handling server-side stream closure and reconnection

Introduce `TestReceive_ServerClosesStream` to verify the client's ability to recover and process acknowledgments after the server closes the stream. Enhance test server with a controlled stream closure mechanism.

* [client] Add protocol error simulation and enhance reconnection test

Introduce `connTrackListener` to simulate HTTP/2 RST_STREAM with PROTOCOL_ERROR for testing. Refactor and rename `TestReceive_ServerClosesStream` to `TestReceive_ProtocolErrorStreamReconnect` to verify client recovery on protocol errors.

* [client] Update Close error message in test for clarity

* [client] Fine-tune the tests

* [client] Adjust connection tracking in reconnection test

* [client] Wait for Events handler to exit in RST_STREAM reconnection test

Ensure the old `Events` handler exits fully before proceeding in the reconnection test to avoid dropped acknowledgments on a broken stream. Add a `handlerDone` channel to synchronize handler exits.

* [client] Prevent panic on nil connection during Close

* [client] Refactor connection handling to use explicit target tracking

Introduce `target` field to store the gRPC connection target directly, simplifying reconnections and ensuring consistent connection reuse logic.

* [client] Rename `isCancellation` to `isContextDone` and extend handling for `DeadlineExceeded`

Refactor error handling to include `DeadlineExceeded` scenarios alongside `Canceled`. Update related condition checks for consistency.

* [client] Add connection generation tracking to prevent stale reconnections

Introduce `connGen` to track connection generations and ensure that stale `recreateConnection` calls do not override newer connections. Update stream establishment and reconnection logic to incorporate generation validation.

* [client] Add backoff reset condition to prevent short-lived retry cycles

Refine backoff reset logic to ensure it only occurs for sufficiently long-lived stream connections, avoiding interference with `MaxElapsedTime`.

* [client] Introduce `minHealthyDuration` to refine backoff reset logic

Add `minHealthyDuration` constant to ensure stream retries only reset the backoff timer if the stream survives beyond a minimum duration. Prevents unhealthy, short-lived streams from interfering with `MaxElapsedTime`.

* [client] IPv6 friendly connection

parsedURL.Hostname() strips IPv6 brackets. For http://[::1]:443, this turns it into ::1:443, which is not a valid host:port target for gRPC. Additionally, fmt.Sprintf("%s:%s", hostname, port) produces a trailing colon when the URL has no explicit port—http://example.com becomes example.com:. Both cases break the initial dial and reconnect paths. Use parsedURL.Host directly instead.

* [client] Add `handlerStarted` channel to synchronize stream establishment in tests

Introduce `handlerStarted` channel in the test server to signal when the server-side handler begins, ensuring robust synchronization between client and server during stream establishment. Update relevant test cases to wait for this signal before proceeding.

* [client] Replace `receivedAcks` map with atomic counter and improve stream establishment sync in tests

Refactor acknowledgment tracking in tests to use an `atomic.Int32` counter instead of a map. Replace fixed sleep with robust synchronization by waiting on `handlerStarted` signal for stream establishment.

* [client] Extract `handleReceiveError` to simplify receive logic

Refactor error handling in `receive` to a dedicated `handleReceiveError` method. Streamlines the main logic and isolates error recovery, including backoff reset and connection recreation.

* [client] recreate gRPC ClientConn on every retry to prevent dual backoff

The flow client had two competing retry loops: our custom exponential
backoff and gRPC's internal subchannel reconnection. When establishStream
failed, the same ClientConn was reused, allowing gRPC's internal backoff
state to accumulate and control dial timing independently.

Changes:
- Consolidate error handling into handleRetryableError, which now
 handles context cancellation, permanent errors, backoff reset,
 and connection recreation in a single path
- Call recreateConnection on every retryable error so each retry
 gets a fresh ClientConn with no internal backoff state
- Remove connGen tracking since Receive is sequential and protected
 by a new receiving guard against concurrent calls
- Reduce RandomizationFactor from 1 to 0.5 to avoid near-zero
 backoff intervals
2026-04-13 10:42:24 +02:00
Zoltan Papp
7483fec048 Fix Android internet blackhole caused by stale route re-injection on TUN rebuild (#5865)
extraInitialRoutes() was meant to preserve only the fake IP route
(240.0.0.0/8) across TUN rebuilds, but it re-injected any initial
route missing from the current set. When the management server
advertised exit node routes (0.0.0.0/0) that were later filtered
by the route selector, extraInitialRoutes() re-added them, causing
the Android VPN to capture all traffic with no peer to handle it.

Store the fake IP route explicitly and append only that in notify(),
removing the overly broad initial route diffing.
2026-04-13 09:38:38 +02:00
Pascal Fischer
5259e5df51 [management] add domain and service cleanup migration (#5850) 2026-04-11 12:00:40 +02:00
Zoltan Papp
ebd78e0122 [client] Update RaceDial to accept context for improved cancellation handling (#5849) 2026-04-10 20:51:04 +02:00
Pascal Fischer
cf86b9a528 [management] enable access log cleanup by default (#5842) 2026-04-10 17:07:27 +02:00
Pascal Fischer
ee588e1536 Revert "[management] allow local routing peer resource (#5814)" (#5847) 2026-04-10 14:53:47 +02:00
Pascal Fischer
2a8aacc5c9 [management] allow local routing peer resource (#5814) 2026-04-10 13:08:21 +02:00
Pascal Fischer
15709bc666 [management] update account delete with proper proxy domain and service cleanup (#5817) 2026-04-10 13:08:04 +02:00
Pascal Fischer
789b4113fe [misc] update dashboards (#5840) 2026-04-10 12:15:58 +02:00
Viktor Liu
d2cdc0efec [client] Use native firewall for peer ACLs in userspace WireGuard mode (#5668) 2026-04-10 09:12:13 +08:00
Pascal Fischer
ee343d5d77 [management] use sql null vars (#5844) 2026-04-09 18:12:38 +02:00
Maycon Santos
099c493b18 [management] network map tests (#5795)
* Add network map benchmark and correctness test files

* Add tests for network map components correctness and edge cases

* Skip benchmarks in CI and enhance network map test coverage with new helper functions

* Remove legacy network map benchmarks and tests; refactor components-based test coverage for clarity and scalability.
2026-04-08 21:28:29 +02:00
Pascal Fischer
c1d1229ae0 [management] use NullBool for terminated flag (#5829) 2026-04-08 21:08:43 +02:00
Viktor Liu
94a36cb53e [client] Handle UPnP routers that only support permanent leases (#5826) 2026-04-08 17:59:59 +02:00
Viktor Liu
c7ba931466 [client] Populate network addresses in FreeBSD system info (#5827) 2026-04-08 17:14:16 +02:00
Viktor Liu
413d95b740 [client] Include service.json in debug bundle (#5825)
* Include service.json in debug bundle

* Add tests for service params sanitization logic
2026-04-08 21:10:31 +08:00
Viktor Liu
332c624c55 [client] Don't abort UI debug bundle when up/down fails (#5780) 2026-04-08 10:33:46 +02:00
Viktor Liu
dc160aff36 [client] Fix SSH proxy stripping shell quoting from forwarded commands (#5669) 2026-04-08 10:25:57 +02:00
Zoltan Papp
96806bf55f [relay] Replace net.Conn with context-aware Conn interface (#5770)
* [relay] Replace net.Conn with context-aware Conn interface for relay transports

Introduce a listener.Conn interface with context-based Read/Write methods,
replacing net.Conn throughout the relay server. This enables proper timeout
propagation (e.g. handshake timeout) without goroutine-based workarounds
and removes unused LocalAddr/SetDeadline methods from WS and QUIC conns.

* [relay] Refactor Peer context management to ensure proper cleanup

Integrate context creation (`context.WithCancel`) directly in `NewPeer` and remove redundant initialization in `Work`. Add `ctxCancel` calls to ensure context is properly canceled during `Close` operations.
2026-04-08 09:38:31 +02:00
Viktor Liu
d33cd4c95b [client] Add NAT-PMP/UPnP support (#5202) 2026-04-08 15:29:32 +08:00
Maycon Santos
e2c2f64be7 [client] Fix iOS DNS upstream routing for deselected exit nodes (#5803)
- Add GetSelectedClientRoutes() to the route manager that filters through FilterSelectedExitNodes, returning only active routes instead of all management routes              
  - Use GetSelectedClientRoutes() in the DNS route checker so deselected exit nodes' 0.0.0.0/0 no longer matches upstream DNS IPs — this prevented the resolver from switching
  away from the utun-bound socket after exit node deselection                                                                                                                   
  - Initialize iOS DNS server with host DNS fallback addresses (1.1.1.1:53, 1.0.0.1:53) and a permanent root zone handler, matching Android's behavior — without this, unmatched
   DNS queries arriving via the 0.0.0.0/0 tunnel route had no handler and were silently dropped
2026-04-08 08:43:48 +02:00
Viktor Liu
cb73b94ffb [client] Add TCP DNS support for local listener (#5758) 2026-04-08 07:40:36 +02:00
Viktor Liu
1d920d700c [client] Fix SSH server Stop() deadlock when sessions are active (#5717) 2026-04-07 17:56:54 +02:00
Viktor Liu
bb85eee40a [client] Skip down interfaces in network address collection for posture checks (#5768) 2026-04-07 17:56:48 +02:00
Viktor Liu
aba5d6f0d2 [client] Error out on netbird expose when block inbound is enabled (#5818) 2026-04-07 17:55:35 +02:00
Viktor Liu
0588d2dbe1 [management] Load missing service columns in pgx account loader (#5816) 2026-04-07 14:56:56 +02:00
Pascal Fischer
14b3b77bda [management] validate permissions on groups read with name (#5749) 2026-04-07 14:13:09 +02:00
Zoltan Papp
6da34e483c [client] Fix mgmProber interface to match unexported GetServerPublicKey (#5815)
Update the mgmProber interface to use HealthCheck() instead of the
now-unexported GetServerPublicKey(), aligning with the changes in the
management client API.
2026-04-07 13:13:38 +02:00
Zoltan Papp
0efef671d7 [client] Unexport GetServerPublicKey, add HealthCheck method (#5735)
* Unexport GetServerPublicKey, add HealthCheck method

Internalize server key fetching into Login, Register,
GetDeviceAuthorizationFlow, and GetPKCEAuthorizationFlow methods,
removing the need for callers to fetch and pass the key separately.

Replace the exported GetServerPublicKey with a HealthCheck() error
method for connection validation, keeping IsHealthy() bool for
non-blocking background monitoring.

Fix test encryption to use correct key pairs (client public key as
remotePubKey instead of server private key).

* Refactor `doMgmLogin` to return only error, removing unused response
2026-04-07 12:18:21 +02:00
Eduard Gert
435203b13b [proxy] Update proxy web packages (#5661)
* [proxy] Update package-lock.json

* Update packages
2026-04-07 10:35:09 +02:00
Maycon Santos
decb5dd3af [client] Add GetSelectedClientRoutes to route manager and update DNS route check (#5802)
- DNS resolution broke after deselecting an exit node because the route checker used all client routes (including deselected ones) to decide how to forward upstream DNS
  queries
  - Added GetSelectedClientRoutes() to the route manager that filters out deselected exit nodes, and switched the DNS route checker to use it
  - Confirmed fix via device testing: after deselecting exit node, DNS queries now correctly use a regular network socket instead of binding to the utun interface
2026-04-05 13:44:53 +02:00
Viktor Liu
28fbf96b2a [client] Fix flaky TestServiceLifecycle/Restart on FreeBSD (#5786) 2026-04-02 21:45:49 +02:00
Bethuel Mmbaga
9d1a37c644 [management,client] Revert gRPC client secret removal (#5781)
* This reverts commit e5914e4e8b

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Deprecate client secret in proto

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Fix lint

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

---------

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2026-04-02 18:21:00 +02:00
Viktor Liu
5bf2372c4d [management] Fix L4 service creation deadlock on single-connection databases (#5779) 2026-04-02 14:46:14 +02:00
Bethuel Mmbaga
c2c6396a04 [management] Allow updating embedded IdP user name and email (#5721) 2026-04-02 13:02:10 +03:00
Misha Bragin
aaf813fc0c Add selfhosted scaling note (#5769) 2026-04-01 19:23:39 +02:00
Vlad
d97fe84296 [management] fix race condition in the setup flow that enables creation of multiple owner users (#5754) 2026-04-01 16:25:35 +02:00
tham-le
81f45dab21 [client] Support embed.Client on Android with netstack mode (#5623)
* [client] Support embed.Client on Android with netstack mode

embed.Client.Start() calls ConnectClient.Run() which passes an empty
MobileDependency{}. On Android, the engine dereferences nil fields
(IFaceDiscover, NetworkChangeListener, DnsReadyListener) causing panics.

Provide complete no-op stubs so the engine's existing Android code
paths work unchanged — zero modifications to engine.go:

- Add androidRunOverride hook in Run() for Android-specific dispatch
- Add runOnAndroidEmbed() with complete MobileDependency (all stubs)
- Wire default stubs via init() in connect_android_default.go:
  noopIFaceDiscover, noopNetworkChangeListener, noopDnsReadyListener
- Forward logPath to c.run()

Tested: embed.Client starts on Android arm64, joins mesh via relay,
discovers peers, localhost proxy works for TCP+UDP forwarding.

* [client] Fix TestServiceParamsPath for Windows path separators

Use filepath.Join in test assertions instead of hardcoded POSIX paths
so the test passes on Windows where filepath.Join uses backslashes.
2026-04-01 16:19:34 +02:00
Zoltan Papp
d670e7382a [client] Fix ipv6 address in quic server (#5763)
* [client] Use `net.JoinHostPort` for consistency in constructing host-port pairs

* [client] Fix handling of IPv6 addresses by trimming brackets in `net.JoinHostPort`
2026-04-01 15:11:23 +02:00
Pascal Fischer
cd8c686339 [misc] add path traversal and file size protections (#5755) 2026-04-01 14:23:24 +02:00
Pascal Fischer
f5c41e3018 [misc] set permissions on env file for getting started scripts (#5761) 2026-04-01 14:13:53 +02:00
Pascal Fischer
2477f99d89 [proxy] Add pprof (#5764) 2026-04-01 14:10:41 +02:00
shuuri-labs
940f530ac2 [management] Legacy to embedded IdP migration tool (#5586) 2026-04-01 13:53:19 +02:00
Zoltan Papp
4d3e2f8ad3 Fix path join (#5762) 2026-04-01 13:21:19 +02:00
Vlad
5ae986e1c4 [management] fix panic on management reboot (#5759) 2026-04-01 12:31:30 +02:00
Bethuel Mmbaga
e5914e4e8b [management,client] Remove client secret from gRPC auth flow (#5751)
Remove client secret from gRPC auth flow. The secret was originally included to support providers like Google Workspace that don't offer a proper PKCE flow, but this is no longer necessary with the embedded IdP. Deployments using such providers should migrate to the embedded IdP instead.
2026-03-31 18:50:49 +03:00
Pascal Fischer
c238f5425f [management] proper module permission validation for posture check delete (#5742) 2026-03-31 16:43:49 +02:00
Pascal Fischer
3c3097ea74 [management] add target user account validation (#5741) 2026-03-31 16:43:16 +02:00
225 changed files with 20014 additions and 3869 deletions

View File

@@ -31,7 +31,7 @@ jobs:
while IFS= read -r dir; do
echo "=== Checking $dir ==="
# Search for problematic imports, excluding test files
RESULTS=$(grep -r "github.com/netbirdio/netbird/\(management\|signal\|relay\|proxy\)" "$dir" --include="*.go" 2>/dev/null | grep -v "_test.go" | grep -v "test_" | grep -v "/test/" || true)
RESULTS=$(grep -r "github.com/netbirdio/netbird/\(management\|signal\|relay\|proxy\)" "$dir" --include="*.go" 2>/dev/null | grep -v "_test.go" | grep -v "test_" | grep -v "/test/" | grep -v "tools/idp-migrate/" || true)
if [ -n "$RESULTS" ]; then
echo "❌ Found problematic dependencies:"
echo "$RESULTS"
@@ -88,7 +88,7 @@ jobs:
IMPORTERS=$(go list -json -deps ./... 2>/dev/null | jq -r "select(.Imports[]? == \"$package\") | .ImportPath")
# Check if any importer is NOT in management/signal/relay
BSD_IMPORTER=$(echo "$IMPORTERS" | grep -v "github.com/netbirdio/netbird/\(management\|signal\|relay\|proxy\|combined\)" | head -1)
BSD_IMPORTER=$(echo "$IMPORTERS" | grep -v "github.com/netbirdio/netbird/\(management\|signal\|relay\|proxy\|combined\|tools/idp-migrate\)" | head -1)
if [ -n "$BSD_IMPORTER" ]; then
echo "❌ $package ($license) is imported by BSD-licensed code: $BSD_IMPORTER"

View File

@@ -19,7 +19,7 @@ jobs:
- name: codespell
uses: codespell-project/actions-codespell@v2
with:
ignore_words_list: erro,clienta,hastable,iif,groupd,testin,groupe,cros,ans,deriver,te
ignore_words_list: erro,clienta,hastable,iif,groupd,testin,groupe,cros,ans,deriver,te,userA
skip: go.mod,go.sum,**/proxy/web/**
golangci:
strategy:

View File

@@ -154,6 +154,26 @@ builds:
- -s -w -X main.Version={{.Version}} -X main.Commit={{.Commit}} -X main.BuildDate={{.CommitDate}}
mod_timestamp: "{{ .CommitTimestamp }}"
- id: netbird-idp-migrate
dir: tools/idp-migrate
env:
- CGO_ENABLED=1
- >-
{{- if eq .Runtime.Goos "linux" }}
{{- if eq .Arch "arm64"}}CC=aarch64-linux-gnu-gcc{{- end }}
{{- if eq .Arch "arm"}}CC=arm-linux-gnueabihf-gcc{{- end }}
{{- end }}
binary: netbird-idp-migrate
goos:
- linux
goarch:
- amd64
- arm64
- arm
ldflags:
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
mod_timestamp: "{{ .CommitTimestamp }}"
universal_binaries:
- id: netbird
@@ -166,6 +186,10 @@ archives:
- netbird-wasm
name_template: "{{ .ProjectName }}_{{ .Version }}"
format: binary
- id: netbird-idp-migrate
builds:
- netbird-idp-migrate
name_template: "netbird-idp-migrate_{{ .Version }}_{{ .Os }}_{{ .Arch }}"
nfpms:
- maintainer: Netbird <dev@netbird.io>

View File

@@ -199,9 +199,11 @@ func runForDuration(cmd *cobra.Command, args []string) error {
cmd.Println("Log level set to trace.")
}
needsRestoreUp := false
if _, err := client.Down(cmd.Context(), &proto.DownRequest{}); err != nil {
cmd.PrintErrf("Failed to bring service down: %v\n", status.Convert(err).Message())
} else {
needsRestoreUp = !stateWasDown
cmd.Println("netbird down")
}
@@ -217,6 +219,7 @@ func runForDuration(cmd *cobra.Command, args []string) error {
if _, err := client.Up(cmd.Context(), &proto.UpRequest{}); err != nil {
cmd.PrintErrf("Failed to bring service up: %v\n", status.Convert(err).Message())
} else {
needsRestoreUp = false
cmd.Println("netbird up")
}
@@ -264,6 +267,14 @@ func runForDuration(cmd *cobra.Command, args []string) error {
return fmt.Errorf("failed to bundle debug: %v", status.Convert(err).Message())
}
if needsRestoreUp {
if _, err := client.Up(cmd.Context(), &proto.UpRequest{}); err != nil {
cmd.PrintErrf("Failed to restore service up state: %v\n", status.Convert(err).Message())
} else {
cmd.Println("netbird up (restored)")
}
}
if stateWasDown {
if _, err := client.Down(cmd.Context(), &proto.DownRequest{}); err != nil {
cmd.PrintErrf("Failed to restore service down state: %v\n", status.Convert(err).Message())

View File

@@ -14,6 +14,7 @@ import (
log "github.com/sirupsen/logrus"
"github.com/spf13/cobra"
"google.golang.org/grpc/status"
"github.com/netbirdio/netbird/client/internal/expose"
"github.com/netbirdio/netbird/client/proto"
@@ -201,7 +202,7 @@ func exposeFn(cmd *cobra.Command, args []string) error {
stream, err := client.ExposeService(ctx, req)
if err != nil {
return fmt.Errorf("expose service: %w", err)
return fmt.Errorf("expose service: %v", status.Convert(err).Message())
}
if err := handleExposeReady(cmd, stream, port); err != nil {
@@ -236,7 +237,7 @@ func toExposeProtocol(exposeProtocol string) (proto.ExposeProtocol, error) {
func handleExposeReady(cmd *cobra.Command, stream proto.DaemonService_ExposeServiceClient, port uint64) error {
event, err := stream.Recv()
if err != nil {
return fmt.Errorf("receive expose event: %w", err)
return fmt.Errorf("receive expose event: %v", status.Convert(err).Message())
}
ready, ok := event.Event.(*proto.ExposeServiceEvent_Ready)

View File

@@ -150,6 +150,7 @@ func init() {
rootCmd.AddCommand(logoutCmd)
rootCmd.AddCommand(versionCmd)
rootCmd.AddCommand(sshCmd)
rootCmd.AddCommand(vncCmd)
rootCmd.AddCommand(networksCMD)
rootCmd.AddCommand(forwardingRulesCmd)
rootCmd.AddCommand(debugCmd)

View File

@@ -25,10 +25,10 @@ func TestServiceParamsPath(t *testing.T) {
t.Cleanup(func() { configs.StateDir = original })
configs.StateDir = "/var/lib/netbird"
assert.Equal(t, "/var/lib/netbird/service.json", serviceParamsPath())
assert.Equal(t, filepath.Join("/var/lib/netbird", "service.json"), serviceParamsPath())
configs.StateDir = "/custom/state"
assert.Equal(t, "/custom/state/service.json", serviceParamsPath())
assert.Equal(t, filepath.Join("/custom/state", "service.json"), serviceParamsPath())
}
func TestSaveAndLoadServiceParams(t *testing.T) {

View File

@@ -4,7 +4,9 @@ import (
"context"
"fmt"
"os"
"os/signal"
"runtime"
"syscall"
"testing"
"time"
@@ -13,6 +15,22 @@ import (
"github.com/stretchr/testify/require"
)
// TestMain intercepts when this test binary is run as a daemon subprocess.
// On FreeBSD, the rc.d service script runs the binary via daemon(8) -r with
// "service run ..." arguments. Since the test binary can't handle cobra CLI
// args, it exits immediately, causing daemon -r to respawn rapidly until
// hitting the rate limit and exiting. This makes service restart unreliable.
// Blocking here keeps the subprocess alive until the init system sends SIGTERM.
func TestMain(m *testing.M) {
if len(os.Args) > 2 && os.Args[1] == "service" && os.Args[2] == "run" {
sig := make(chan os.Signal, 1)
signal.Notify(sig, syscall.SIGTERM, os.Interrupt)
<-sig
return
}
os.Exit(m.Run())
}
const (
serviceStartTimeout = 10 * time.Second
serviceStopTimeout = 5 * time.Second
@@ -79,6 +97,34 @@ func TestServiceLifecycle(t *testing.T) {
logLevel = "info"
daemonAddr = fmt.Sprintf("unix://%s/netbird-test.sock", tempDir)
// Ensure cleanup even if a subtest fails and Stop/Uninstall subtests don't run.
t.Cleanup(func() {
cfg, err := newSVCConfig()
if err != nil {
t.Errorf("cleanup: create service config: %v", err)
return
}
ctxSvc, cancel := context.WithCancel(context.Background())
defer cancel()
s, err := newSVC(newProgram(ctxSvc, cancel), cfg)
if err != nil {
t.Errorf("cleanup: create service: %v", err)
return
}
// If the subtests already cleaned up, there's nothing to do.
if _, err := s.Status(); err != nil {
return
}
if err := s.Stop(); err != nil {
t.Errorf("cleanup: stop service: %v", err)
}
if err := s.Uninstall(); err != nil {
t.Errorf("cleanup: uninstall service: %v", err)
}
})
ctx := context.Background()
t.Run("Install", func(t *testing.T) {

View File

@@ -36,7 +36,10 @@ const (
enableSSHLocalPortForwardFlag = "enable-ssh-local-port-forwarding"
enableSSHRemotePortForwardFlag = "enable-ssh-remote-port-forwarding"
disableSSHAuthFlag = "disable-ssh-auth"
sshJWTCacheTTLFlag = "ssh-jwt-cache-ttl"
jwtCacheTTLFlag = "jwt-cache-ttl"
// Alias for backward compatibility.
sshJWTCacheTTLFlag = "ssh-jwt-cache-ttl"
)
var (
@@ -61,7 +64,7 @@ var (
enableSSHLocalPortForward bool
enableSSHRemotePortForward bool
disableSSHAuth bool
sshJWTCacheTTL int
jwtCacheTTL int
)
func init() {
@@ -71,7 +74,9 @@ func init() {
upCmd.PersistentFlags().BoolVar(&enableSSHLocalPortForward, enableSSHLocalPortForwardFlag, false, "Enable local port forwarding for SSH server")
upCmd.PersistentFlags().BoolVar(&enableSSHRemotePortForward, enableSSHRemotePortForwardFlag, false, "Enable remote port forwarding for SSH server")
upCmd.PersistentFlags().BoolVar(&disableSSHAuth, disableSSHAuthFlag, false, "Disable SSH authentication")
upCmd.PersistentFlags().IntVar(&sshJWTCacheTTL, sshJWTCacheTTLFlag, 0, "SSH JWT token cache TTL in seconds (0=disabled)")
upCmd.PersistentFlags().IntVar(&jwtCacheTTL, jwtCacheTTLFlag, 0, "JWT token cache TTL in seconds (0=disabled)")
upCmd.PersistentFlags().IntVar(&jwtCacheTTL, sshJWTCacheTTLFlag, 0, "JWT token cache TTL in seconds (alias for --jwt-cache-ttl)")
_ = upCmd.PersistentFlags().MarkDeprecated(sshJWTCacheTTLFlag, "use --jwt-cache-ttl instead")
sshCmd.PersistentFlags().IntVarP(&port, "port", "p", sshserver.DefaultSSHPort, "Remote SSH port")
sshCmd.PersistentFlags().StringVarP(&username, "user", "u", "", sshUsernameDesc)

View File

@@ -356,6 +356,9 @@ func setupSetConfigReq(customDNSAddressConverted []byte, cmd *cobra.Command, pro
if cmd.Flag(serverSSHAllowedFlag).Changed {
req.ServerSSHAllowed = &serverSSHAllowed
}
if cmd.Flag(serverVNCAllowedFlag).Changed {
req.ServerVNCAllowed = &serverVNCAllowed
}
if cmd.Flag(enableSSHRootFlag).Changed {
req.EnableSSHRoot = &enableSSHRoot
}
@@ -371,9 +374,12 @@ func setupSetConfigReq(customDNSAddressConverted []byte, cmd *cobra.Command, pro
if cmd.Flag(disableSSHAuthFlag).Changed {
req.DisableSSHAuth = &disableSSHAuth
}
if cmd.Flag(sshJWTCacheTTLFlag).Changed {
sshJWTCacheTTL32 := int32(sshJWTCacheTTL)
req.SshJWTCacheTTL = &sshJWTCacheTTL32
if cmd.Flag(disableVNCAuthFlag).Changed {
req.DisableVNCAuth = &disableVNCAuth
}
if cmd.Flag(jwtCacheTTLFlag).Changed || cmd.Flag(sshJWTCacheTTLFlag).Changed {
jwtCacheTTL32 := int32(jwtCacheTTL)
req.SshJWTCacheTTL = &jwtCacheTTL32
}
if cmd.Flag(interfaceNameFlag).Changed {
if err := parseInterfaceName(interfaceName); err != nil {
@@ -458,6 +464,9 @@ func setupConfig(customDNSAddressConverted []byte, cmd *cobra.Command, configFil
if cmd.Flag(serverSSHAllowedFlag).Changed {
ic.ServerSSHAllowed = &serverSSHAllowed
}
if cmd.Flag(serverVNCAllowedFlag).Changed {
ic.ServerVNCAllowed = &serverVNCAllowed
}
if cmd.Flag(enableSSHRootFlag).Changed {
ic.EnableSSHRoot = &enableSSHRoot
@@ -479,8 +488,12 @@ func setupConfig(customDNSAddressConverted []byte, cmd *cobra.Command, configFil
ic.DisableSSHAuth = &disableSSHAuth
}
if cmd.Flag(sshJWTCacheTTLFlag).Changed {
ic.SSHJWTCacheTTL = &sshJWTCacheTTL
if cmd.Flag(disableVNCAuthFlag).Changed {
ic.DisableVNCAuth = &disableVNCAuth
}
if cmd.Flag(jwtCacheTTLFlag).Changed || cmd.Flag(sshJWTCacheTTLFlag).Changed {
ic.SSHJWTCacheTTL = &jwtCacheTTL
}
if cmd.Flag(interfaceNameFlag).Changed {
@@ -582,6 +595,9 @@ func setupLoginRequest(providedSetupKey string, customDNSAddressConverted []byte
if cmd.Flag(serverSSHAllowedFlag).Changed {
loginRequest.ServerSSHAllowed = &serverSSHAllowed
}
if cmd.Flag(serverVNCAllowedFlag).Changed {
loginRequest.ServerVNCAllowed = &serverVNCAllowed
}
if cmd.Flag(enableSSHRootFlag).Changed {
loginRequest.EnableSSHRoot = &enableSSHRoot
@@ -603,9 +619,13 @@ func setupLoginRequest(providedSetupKey string, customDNSAddressConverted []byte
loginRequest.DisableSSHAuth = &disableSSHAuth
}
if cmd.Flag(sshJWTCacheTTLFlag).Changed {
sshJWTCacheTTL32 := int32(sshJWTCacheTTL)
loginRequest.SshJWTCacheTTL = &sshJWTCacheTTL32
if cmd.Flag(disableVNCAuthFlag).Changed {
loginRequest.DisableVNCAuth = &disableVNCAuth
}
if cmd.Flag(jwtCacheTTLFlag).Changed || cmd.Flag(sshJWTCacheTTLFlag).Changed {
jwtCacheTTL32 := int32(jwtCacheTTL)
loginRequest.SshJWTCacheTTL = &jwtCacheTTL32
}
if cmd.Flag(disableAutoConnectFlag).Changed {

271
client/cmd/vnc.go Normal file
View File

@@ -0,0 +1,271 @@
package cmd
import (
"context"
"encoding/binary"
"fmt"
"io"
"net"
"os"
"os/signal"
"os/user"
"strings"
"syscall"
"time"
log "github.com/sirupsen/logrus"
"github.com/spf13/cobra"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure"
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/profilemanager"
"github.com/netbirdio/netbird/client/proto"
nbssh "github.com/netbirdio/netbird/client/ssh"
"github.com/netbirdio/netbird/util"
)
var (
vncUsername string
vncHost string
vncMode string
vncListen string
vncNoBrowser bool
vncNoCache bool
)
func init() {
vncCmd.PersistentFlags().StringVar(&vncUsername, "user", "", "OS username for session mode")
vncCmd.PersistentFlags().StringVar(&vncMode, "mode", "attach", "Connection mode: attach (view current display) or session (virtual desktop)")
vncCmd.PersistentFlags().StringVar(&vncListen, "listen", "", "Start local VNC proxy on this address (e.g., :5900) for external VNC viewers")
vncCmd.PersistentFlags().BoolVar(&vncNoBrowser, noBrowserFlag, false, noBrowserDesc)
vncCmd.PersistentFlags().BoolVar(&vncNoCache, "no-cache", false, "Skip cached JWT token and force fresh authentication")
}
var vncCmd = &cobra.Command{
Use: "vnc [flags] [user@]host",
Short: "Connect to a NetBird peer via VNC",
Long: `Connect to a NetBird peer using VNC with JWT-based authentication.
The target peer must have the VNC server enabled.
Two modes are available:
- attach: view the current physical display (remote support)
- session: start a virtual desktop as the specified user (passwordless login)
Use --listen to start a local proxy for external VNC viewers:
netbird vnc --listen :5900 peer-hostname
vncviewer localhost:5900
Examples:
netbird vnc peer-hostname
netbird vnc --mode session --user alice peer-hostname
netbird vnc --listen :5900 peer-hostname`,
Args: cobra.MinimumNArgs(1),
RunE: vncFn,
}
func vncFn(cmd *cobra.Command, args []string) error {
SetFlagsFromEnvVars(rootCmd)
SetFlagsFromEnvVars(cmd)
cmd.SetOut(cmd.OutOrStdout())
logOutput := "console"
if firstLogFile := util.FindFirstLogPath(logFiles); firstLogFile != "" && firstLogFile != defaultLogFile {
logOutput = firstLogFile
}
if err := util.InitLog(logLevel, logOutput); err != nil {
return fmt.Errorf("init log: %w", err)
}
if err := parseVNCHostArg(args[0]); err != nil {
return err
}
ctx := internal.CtxInitState(cmd.Context())
sig := make(chan os.Signal, 1)
signal.Notify(sig, syscall.SIGTERM, syscall.SIGINT)
vncCtx, cancel := context.WithCancel(ctx)
errCh := make(chan error, 1)
go func() {
if err := runVNC(vncCtx, cmd); err != nil {
errCh <- err
}
cancel()
}()
select {
case <-sig:
cancel()
<-vncCtx.Done()
return nil
case err := <-errCh:
return err
case <-vncCtx.Done():
}
return nil
}
func parseVNCHostArg(arg string) error {
if strings.Contains(arg, "@") {
parts := strings.SplitN(arg, "@", 2)
if len(parts) != 2 || parts[0] == "" || parts[1] == "" {
return fmt.Errorf("invalid user@host format")
}
if vncUsername == "" {
vncUsername = parts[0]
}
vncHost = parts[1]
if vncMode == "attach" {
vncMode = "session"
}
} else {
vncHost = arg
}
if vncMode == "session" && vncUsername == "" {
if sudoUser := os.Getenv("SUDO_USER"); sudoUser != "" {
vncUsername = sudoUser
} else if currentUser, err := user.Current(); err == nil {
vncUsername = currentUser.Username
}
}
return nil
}
func runVNC(ctx context.Context, cmd *cobra.Command) error {
grpcAddr := strings.TrimPrefix(daemonAddr, "tcp://")
grpcConn, err := grpc.NewClient(grpcAddr, grpc.WithTransportCredentials(insecure.NewCredentials()))
if err != nil {
return fmt.Errorf("connect to daemon: %w", err)
}
defer func() { _ = grpcConn.Close() }()
daemonClient := proto.NewDaemonServiceClient(grpcConn)
if vncMode == "session" {
cmd.Printf("Connecting to %s@%s [session mode]...\n", vncUsername, vncHost)
} else {
cmd.Printf("Connecting to %s [attach mode]...\n", vncHost)
}
// Obtain JWT token. If the daemon has no SSO configured, proceed without one
// (the server will accept unauthenticated connections if --disable-vnc-auth is set).
var jwtToken string
hint := profilemanager.GetLoginHint()
var browserOpener func(string) error
if !vncNoBrowser {
browserOpener = util.OpenBrowser
}
token, err := nbssh.RequestJWTToken(ctx, daemonClient, nil, cmd.ErrOrStderr(), !vncNoCache, hint, browserOpener)
if err != nil {
log.Debugf("JWT authentication unavailable, connecting without token: %v", err)
} else {
jwtToken = token
log.Debug("JWT authentication successful")
}
// Connect to the VNC server on the standard port (5900). The peer's firewall
// DNATs 5900 -> 25900 (internal), so both ports work on the overlay network.
vncAddr := net.JoinHostPort(vncHost, "5900")
vncConn, err := net.DialTimeout("tcp", vncAddr, vncDialTimeout)
if err != nil {
return fmt.Errorf("connect to VNC at %s: %w", vncAddr, err)
}
defer vncConn.Close()
// Send session header with mode, username, and JWT.
if err := sendVNCHeader(vncConn, vncMode, vncUsername, jwtToken); err != nil {
return fmt.Errorf("send VNC header: %w", err)
}
cmd.Printf("VNC connected to %s\n", vncHost)
if vncListen != "" {
return runVNCLocalProxy(ctx, cmd, vncConn)
}
// No --listen flag: inform the user they need to use --listen for external viewers.
cmd.Printf("VNC tunnel established. Use --listen :5900 to proxy for local VNC viewers.\n")
cmd.Printf("Press Ctrl+C to disconnect.\n")
<-ctx.Done()
return nil
}
const vncDialTimeout = 15 * time.Second
// sendVNCHeader writes the NetBird VNC session header.
func sendVNCHeader(conn net.Conn, mode, username, jwt string) error {
var modeByte byte
if mode == "session" {
modeByte = 1
}
usernameBytes := []byte(username)
jwtBytes := []byte(jwt)
hdr := make([]byte, 3+len(usernameBytes)+2+len(jwtBytes))
hdr[0] = modeByte
binary.BigEndian.PutUint16(hdr[1:3], uint16(len(usernameBytes)))
off := 3
copy(hdr[off:], usernameBytes)
off += len(usernameBytes)
binary.BigEndian.PutUint16(hdr[off:off+2], uint16(len(jwtBytes)))
off += 2
copy(hdr[off:], jwtBytes)
_, err := conn.Write(hdr)
return err
}
// runVNCLocalProxy listens on the given address and proxies incoming
// connections to the already-established VNC tunnel.
func runVNCLocalProxy(ctx context.Context, cmd *cobra.Command, vncConn net.Conn) error {
listener, err := net.Listen("tcp", vncListen)
if err != nil {
return fmt.Errorf("listen on %s: %w", vncListen, err)
}
defer listener.Close()
cmd.Printf("VNC proxy listening on %s - connect with your VNC viewer\n", listener.Addr())
cmd.Printf("Press Ctrl+C to stop.\n")
go func() {
<-ctx.Done()
listener.Close()
}()
// Accept a single viewer connection. VNC is single-session: the RFB
// handshake completes on vncConn for the first viewer, so subsequent
// viewers would get a mid-stream connection. The loop handles transient
// accept errors until a valid connection arrives.
for {
clientConn, err := listener.Accept()
if err != nil {
select {
case <-ctx.Done():
return nil
default:
}
log.Debugf("accept VNC proxy client: %v", err)
continue
}
cmd.Printf("VNC viewer connected from %s\n", clientConn.RemoteAddr())
// Bidirectional copy.
done := make(chan struct{})
go func() {
io.Copy(vncConn, clientConn)
close(done)
}()
io.Copy(clientConn, vncConn)
<-done
clientConn.Close()
cmd.Printf("VNC viewer disconnected\n")
return nil
}
}

62
client/cmd/vnc_agent.go Normal file
View File

@@ -0,0 +1,62 @@
//go:build windows
package cmd
import (
"net/netip"
"os"
log "github.com/sirupsen/logrus"
"github.com/spf13/cobra"
vncserver "github.com/netbirdio/netbird/client/vnc/server"
)
var vncAgentPort string
func init() {
vncAgentCmd.Flags().StringVar(&vncAgentPort, "port", "15900", "Port for the VNC agent to listen on")
rootCmd.AddCommand(vncAgentCmd)
}
// vncAgentCmd runs a VNC server in the current user session, listening on
// localhost. It is spawned by the NetBird service (Session 0) via
// CreateProcessAsUser into the interactive console session.
var vncAgentCmd = &cobra.Command{
Use: "vnc-agent",
Short: "Run VNC capture agent (internal, spawned by service)",
Hidden: true,
RunE: func(cmd *cobra.Command, args []string) error {
// Agent's stderr is piped to the service which relogs it.
// Use JSON format with caller info for structured parsing.
log.SetReportCaller(true)
log.SetFormatter(&log.JSONFormatter{})
log.SetOutput(os.Stderr)
sessionID := vncserver.GetCurrentSessionID()
log.Infof("VNC agent starting on 127.0.0.1:%s (session %d)", vncAgentPort, sessionID)
capturer := vncserver.NewDesktopCapturer()
injector := vncserver.NewWindowsInputInjector()
srv := vncserver.New(capturer, injector, "")
// Auth is handled by the service. The agent verifies a token on each
// connection to ensure only the service process can connect.
// The token is passed via environment variable to avoid exposing it
// in the process command line (visible via tasklist/wmic).
srv.SetDisableAuth(true)
srv.SetAgentToken(os.Getenv("NB_VNC_AGENT_TOKEN"))
port, err := netip.ParseAddrPort("127.0.0.1:" + vncAgentPort)
if err != nil {
return err
}
loopback := netip.PrefixFrom(netip.AddrFrom4([4]byte{127, 0, 0, 0}), 8)
if err := srv.Start(cmd.Context(), port, loopback); err != nil {
return err
}
<-cmd.Context().Done()
return srv.Stop()
},
}

16
client/cmd/vnc_flags.go Normal file
View File

@@ -0,0 +1,16 @@
package cmd
const (
serverVNCAllowedFlag = "allow-server-vnc"
disableVNCAuthFlag = "disable-vnc-auth"
)
var (
serverVNCAllowed bool
disableVNCAuth bool
)
func init() {
upCmd.PersistentFlags().BoolVar(&serverVNCAllowed, serverVNCAllowedFlag, false, "Allow embedded VNC server on peer")
upCmd.PersistentFlags().BoolVar(&disableVNCAuth, disableVNCAuthFlag, false, "Disable JWT authentication for VNC")
}

View File

@@ -6,6 +6,7 @@ import (
"errors"
"fmt"
"os"
"strconv"
"github.com/coreos/go-iptables/iptables"
"github.com/google/nftables"
@@ -35,20 +36,27 @@ const SKIP_NFTABLES_ENV = "NB_SKIP_NFTABLES_CHECK"
type FWType int
func NewFirewall(iface IFaceMapper, stateManager *statemanager.Manager, flowLogger nftypes.FlowLogger, disableServerRoutes bool, mtu uint16) (firewall.Manager, error) {
// on the linux system we try to user nftables or iptables
// in any case, because we need to allow netbird interface traffic
// so we use AllowNetbird traffic from these firewall managers
// for the userspace packet filtering firewall
// We run in userspace mode and force userspace firewall was requested. We don't attempt native firewall.
if iface.IsUserspaceBind() && forceUserspaceFirewall() {
log.Info("forcing userspace firewall")
return createUserspaceFirewall(iface, nil, disableServerRoutes, flowLogger, mtu)
}
// Use native firewall for either kernel or userspace, the interface appears identical to netfilter
fm, err := createNativeFirewall(iface, stateManager, disableServerRoutes, mtu)
// Kernel cannot fall back to anything else, need to return error
if !iface.IsUserspaceBind() {
return fm, err
}
// Fall back to the userspace packet filter if native is unavailable
if err != nil {
log.Warnf("failed to create native firewall: %v. Proceeding with userspace", err)
return createUserspaceFirewall(iface, nil, disableServerRoutes, flowLogger, mtu)
}
return createUserspaceFirewall(iface, fm, disableServerRoutes, flowLogger, mtu)
return fm, nil
}
func createNativeFirewall(iface IFaceMapper, stateManager *statemanager.Manager, routes bool, mtu uint16) (firewall.Manager, error) {
@@ -160,3 +168,17 @@ func isIptablesClientAvailable(client *iptables.IPTables) bool {
_, err := client.ListChains("filter")
return err == nil
}
func forceUserspaceFirewall() bool {
val := os.Getenv(EnvForceUserspaceFirewall)
if val == "" {
return false
}
force, err := strconv.ParseBool(val)
if err != nil {
log.Warnf("failed to parse %s: %v", EnvForceUserspaceFirewall, err)
return false
}
return force
}

View File

@@ -7,6 +7,12 @@ import (
"github.com/netbirdio/netbird/client/iface/wgaddr"
)
// EnvForceUserspaceFirewall forces the use of the userspace packet filter even when
// native iptables/nftables is available. This only applies when the WireGuard interface
// runs in userspace mode. When set, peer ACLs are handled by USPFilter instead of
// kernel netfilter rules.
const EnvForceUserspaceFirewall = "NB_FORCE_USERSPACE_FIREWALL"
// IFaceMapper defines subset methods of interface required for manager
type IFaceMapper interface {
Name() string

View File

@@ -33,7 +33,6 @@ type Manager struct {
type iFaceMapper interface {
Name() string
Address() wgaddr.Address
IsUserspaceBind() bool
}
// Create iptables firewall manager
@@ -64,10 +63,9 @@ func Create(wgIface iFaceMapper, mtu uint16) (*Manager, error) {
func (m *Manager) Init(stateManager *statemanager.Manager) error {
state := &ShutdownState{
InterfaceState: &InterfaceState{
NameStr: m.wgIface.Name(),
WGAddress: m.wgIface.Address(),
UserspaceBind: m.wgIface.IsUserspaceBind(),
MTU: m.router.mtu,
NameStr: m.wgIface.Name(),
WGAddress: m.wgIface.Address(),
MTU: m.router.mtu,
},
}
stateManager.RegisterState(state)
@@ -203,12 +201,10 @@ func (m *Manager) Close(stateManager *statemanager.Manager) error {
return nberrors.FormatErrorOrNil(merr)
}
// AllowNetbird allows netbird interface traffic
// AllowNetbird allows netbird interface traffic.
// This is called when USPFilter wraps the native firewall, adding blanket accept
// rules so that packet filtering is handled in userspace instead of by netfilter.
func (m *Manager) AllowNetbird() error {
if !m.wgIface.IsUserspaceBind() {
return nil
}
_, err := m.AddPeerFiltering(
nil,
net.IP{0, 0, 0, 0},
@@ -286,6 +282,22 @@ func (m *Manager) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Prot
return m.router.RemoveInboundDNAT(localAddr, protocol, sourcePort, targetPort)
}
// AddOutputDNAT adds an OUTPUT chain DNAT rule for locally-generated traffic.
func (m *Manager) AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
m.mutex.Lock()
defer m.mutex.Unlock()
return m.router.AddOutputDNAT(localAddr, protocol, sourcePort, targetPort)
}
// RemoveOutputDNAT removes an OUTPUT chain DNAT rule.
func (m *Manager) RemoveOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
m.mutex.Lock()
defer m.mutex.Unlock()
return m.router.RemoveOutputDNAT(localAddr, protocol, sourcePort, targetPort)
}
const (
chainNameRaw = "NETBIRD-RAW"
chainOUTPUT = "OUTPUT"

View File

@@ -47,8 +47,6 @@ func (i *iFaceMock) Address() wgaddr.Address {
panic("AddressFunc is not set")
}
func (i *iFaceMock) IsUserspaceBind() bool { return false }
func TestIptablesManager(t *testing.T) {
ipv4Client, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
require.NoError(t, err)

View File

@@ -36,6 +36,7 @@ const (
chainRTFWDOUT = "NETBIRD-RT-FWD-OUT"
chainRTPRE = "NETBIRD-RT-PRE"
chainRTRDR = "NETBIRD-RT-RDR"
chainNATOutput = "NETBIRD-NAT-OUTPUT"
chainRTMSSCLAMP = "NETBIRD-RT-MSSCLAMP"
routingFinalForwardJump = "ACCEPT"
routingFinalNatJump = "MASQUERADE"
@@ -43,6 +44,7 @@ const (
jumpManglePre = "jump-mangle-pre"
jumpNatPre = "jump-nat-pre"
jumpNatPost = "jump-nat-post"
jumpNatOutput = "jump-nat-output"
jumpMSSClamp = "jump-mss-clamp"
markManglePre = "mark-mangle-pre"
markManglePost = "mark-mangle-post"
@@ -387,6 +389,14 @@ func (r *router) cleanUpDefaultForwardRules() error {
}
log.Debug("flushing routing related tables")
// Remove jump rules from built-in chains before deleting custom chains,
// otherwise the chain deletion fails with "device or resource busy".
jumpRule := []string{"-j", chainNATOutput}
if err := r.iptablesClient.Delete(tableNat, "OUTPUT", jumpRule...); err != nil {
log.Debugf("clean OUTPUT jump rule: %v", err)
}
for _, chainInfo := range []struct {
chain string
table string
@@ -396,6 +406,7 @@ func (r *router) cleanUpDefaultForwardRules() error {
{chainRTPRE, tableMangle},
{chainRTNAT, tableNat},
{chainRTRDR, tableNat},
{chainNATOutput, tableNat},
{chainRTMSSCLAMP, tableMangle},
} {
ok, err := r.iptablesClient.ChainExists(chainInfo.table, chainInfo.chain)
@@ -970,6 +981,81 @@ func (r *router) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Proto
return nil
}
// ensureNATOutputChain lazily creates the OUTPUT NAT chain and jump rule on first use.
func (r *router) ensureNATOutputChain() error {
if _, exists := r.rules[jumpNatOutput]; exists {
return nil
}
chainExists, err := r.iptablesClient.ChainExists(tableNat, chainNATOutput)
if err != nil {
return fmt.Errorf("check chain %s: %w", chainNATOutput, err)
}
if !chainExists {
if err := r.iptablesClient.NewChain(tableNat, chainNATOutput); err != nil {
return fmt.Errorf("create chain %s: %w", chainNATOutput, err)
}
}
jumpRule := []string{"-j", chainNATOutput}
if err := r.iptablesClient.Insert(tableNat, "OUTPUT", 1, jumpRule...); err != nil {
if !chainExists {
if delErr := r.iptablesClient.ClearAndDeleteChain(tableNat, chainNATOutput); delErr != nil {
log.Warnf("failed to rollback chain %s: %v", chainNATOutput, delErr)
}
}
return fmt.Errorf("add OUTPUT jump rule: %w", err)
}
r.rules[jumpNatOutput] = jumpRule
r.updateState()
return nil
}
// AddOutputDNAT adds an OUTPUT chain DNAT rule for locally-generated traffic.
func (r *router) AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
ruleID := fmt.Sprintf("output-dnat-%s-%s-%d-%d", localAddr.String(), protocol, sourcePort, targetPort)
if _, exists := r.rules[ruleID]; exists {
return nil
}
if err := r.ensureNATOutputChain(); err != nil {
return err
}
dnatRule := []string{
"-p", strings.ToLower(string(protocol)),
"--dport", strconv.Itoa(int(sourcePort)),
"-d", localAddr.String(),
"-j", "DNAT",
"--to-destination", ":" + strconv.Itoa(int(targetPort)),
}
if err := r.iptablesClient.Append(tableNat, chainNATOutput, dnatRule...); err != nil {
return fmt.Errorf("add output DNAT rule: %w", err)
}
r.rules[ruleID] = dnatRule
r.updateState()
return nil
}
// RemoveOutputDNAT removes an OUTPUT chain DNAT rule.
func (r *router) RemoveOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
ruleID := fmt.Sprintf("output-dnat-%s-%s-%d-%d", localAddr.String(), protocol, sourcePort, targetPort)
if dnatRule, exists := r.rules[ruleID]; exists {
if err := r.iptablesClient.Delete(tableNat, chainNATOutput, dnatRule...); err != nil {
return fmt.Errorf("delete output DNAT rule: %w", err)
}
delete(r.rules, ruleID)
}
r.updateState()
return nil
}
func applyPort(flag string, port *firewall.Port) []string {
if port == nil {
return nil

View File

@@ -9,10 +9,9 @@ import (
)
type InterfaceState struct {
NameStr string `json:"name"`
WGAddress wgaddr.Address `json:"wg_address"`
UserspaceBind bool `json:"userspace_bind"`
MTU uint16 `json:"mtu"`
NameStr string `json:"name"`
WGAddress wgaddr.Address `json:"wg_address"`
MTU uint16 `json:"mtu"`
}
func (i *InterfaceState) Name() string {
@@ -23,10 +22,6 @@ func (i *InterfaceState) Address() wgaddr.Address {
return i.WGAddress
}
func (i *InterfaceState) IsUserspaceBind() bool {
return i.UserspaceBind
}
type ShutdownState struct {
sync.Mutex

View File

@@ -169,6 +169,14 @@ type Manager interface {
// RemoveInboundDNAT removes inbound DNAT rule
RemoveInboundDNAT(localAddr netip.Addr, protocol Protocol, sourcePort, targetPort uint16) error
// AddOutputDNAT adds an OUTPUT chain DNAT rule for locally-generated traffic.
// localAddr must be IPv4; the underlying iptables/nftables backends are IPv4-only.
AddOutputDNAT(localAddr netip.Addr, protocol Protocol, sourcePort, targetPort uint16) error
// RemoveOutputDNAT removes an OUTPUT chain DNAT rule.
// localAddr must be IPv4; the underlying iptables/nftables backends are IPv4-only.
RemoveOutputDNAT(localAddr netip.Addr, protocol Protocol, sourcePort, targetPort uint16) error
// SetupEBPFProxyNoTrack creates static notrack rules for eBPF proxy loopback traffic.
// This prevents conntrack from interfering with WireGuard proxy communication.
SetupEBPFProxyNoTrack(proxyPort, wgPort uint16) error

View File

@@ -40,7 +40,6 @@ func getTableName() string {
type iFaceMapper interface {
Name() string
Address() wgaddr.Address
IsUserspaceBind() bool
}
// Manager of iptables firewall
@@ -106,10 +105,9 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error {
// cleanup using Close() without needing to store specific rules.
if err := stateManager.UpdateState(&ShutdownState{
InterfaceState: &InterfaceState{
NameStr: m.wgIface.Name(),
WGAddress: m.wgIface.Address(),
UserspaceBind: m.wgIface.IsUserspaceBind(),
MTU: m.router.mtu,
NameStr: m.wgIface.Name(),
WGAddress: m.wgIface.Address(),
MTU: m.router.mtu,
},
}); err != nil {
log.Errorf("failed to update state: %v", err)
@@ -205,12 +203,10 @@ func (m *Manager) RemoveNatRule(pair firewall.RouterPair) error {
return m.router.RemoveNatRule(pair)
}
// AllowNetbird allows netbird interface traffic
// AllowNetbird allows netbird interface traffic.
// This is called when USPFilter wraps the native firewall, adding blanket accept
// rules so that packet filtering is handled in userspace instead of by netfilter.
func (m *Manager) AllowNetbird() error {
if !m.wgIface.IsUserspaceBind() {
return nil
}
m.mutex.Lock()
defer m.mutex.Unlock()
@@ -346,6 +342,22 @@ func (m *Manager) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Prot
return m.router.RemoveInboundDNAT(localAddr, protocol, sourcePort, targetPort)
}
// AddOutputDNAT adds an OUTPUT chain DNAT rule for locally-generated traffic.
func (m *Manager) AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
m.mutex.Lock()
defer m.mutex.Unlock()
return m.router.AddOutputDNAT(localAddr, protocol, sourcePort, targetPort)
}
// RemoveOutputDNAT removes an OUTPUT chain DNAT rule.
func (m *Manager) RemoveOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
m.mutex.Lock()
defer m.mutex.Unlock()
return m.router.RemoveOutputDNAT(localAddr, protocol, sourcePort, targetPort)
}
const (
chainNameRawOutput = "netbird-raw-out"
chainNameRawPrerouting = "netbird-raw-pre"

View File

@@ -52,8 +52,6 @@ func (i *iFaceMock) Address() wgaddr.Address {
panic("AddressFunc is not set")
}
func (i *iFaceMock) IsUserspaceBind() bool { return false }
func TestNftablesManager(t *testing.T) {
// just check on the local interface

View File

@@ -36,6 +36,7 @@ const (
chainNameRoutingFw = "netbird-rt-fwd"
chainNameRoutingNat = "netbird-rt-postrouting"
chainNameRoutingRdr = "netbird-rt-redirect"
chainNameNATOutput = "netbird-nat-output"
chainNameForward = "FORWARD"
chainNameMangleForward = "netbird-mangle-forward"
@@ -1853,6 +1854,130 @@ func (r *router) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Proto
return nil
}
// ensureNATOutputChain lazily creates the OUTPUT NAT chain on first use.
func (r *router) ensureNATOutputChain() error {
if _, exists := r.chains[chainNameNATOutput]; exists {
return nil
}
r.chains[chainNameNATOutput] = r.conn.AddChain(&nftables.Chain{
Name: chainNameNATOutput,
Table: r.workTable,
Hooknum: nftables.ChainHookOutput,
Priority: nftables.ChainPriorityNATDest,
Type: nftables.ChainTypeNAT,
})
if err := r.conn.Flush(); err != nil {
delete(r.chains, chainNameNATOutput)
return fmt.Errorf("create NAT output chain: %w", err)
}
return nil
}
// AddOutputDNAT adds an OUTPUT chain DNAT rule for locally-generated traffic.
func (r *router) AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
ruleID := fmt.Sprintf("output-dnat-%s-%s-%d-%d", localAddr.String(), protocol, sourcePort, targetPort)
if _, exists := r.rules[ruleID]; exists {
return nil
}
if err := r.ensureNATOutputChain(); err != nil {
return err
}
protoNum, err := protoToInt(protocol)
if err != nil {
return fmt.Errorf("convert protocol to number: %w", err)
}
exprs := []expr.Any{
&expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 1},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: []byte{protoNum},
},
&expr.Payload{
DestRegister: 2,
Base: expr.PayloadBaseTransportHeader,
Offset: 2,
Len: 2,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 2,
Data: binaryutil.BigEndian.PutUint16(sourcePort),
},
}
exprs = append(exprs, applyPrefix(netip.PrefixFrom(localAddr, 32), false)...)
exprs = append(exprs,
&expr.Immediate{
Register: 1,
Data: localAddr.AsSlice(),
},
&expr.Immediate{
Register: 2,
Data: binaryutil.BigEndian.PutUint16(targetPort),
},
&expr.NAT{
Type: expr.NATTypeDestNAT,
Family: uint32(nftables.TableFamilyIPv4),
RegAddrMin: 1,
RegProtoMin: 2,
},
)
dnatRule := &nftables.Rule{
Table: r.workTable,
Chain: r.chains[chainNameNATOutput],
Exprs: exprs,
UserData: []byte(ruleID),
}
r.conn.AddRule(dnatRule)
if err := r.conn.Flush(); err != nil {
return fmt.Errorf("add output DNAT rule: %w", err)
}
r.rules[ruleID] = dnatRule
return nil
}
// RemoveOutputDNAT removes an OUTPUT chain DNAT rule.
func (r *router) RemoveOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
if err := r.refreshRulesMap(); err != nil {
return fmt.Errorf(refreshRulesMapError, err)
}
ruleID := fmt.Sprintf("output-dnat-%s-%s-%d-%d", localAddr.String(), protocol, sourcePort, targetPort)
rule, exists := r.rules[ruleID]
if !exists {
return nil
}
if rule.Handle == 0 {
log.Warnf("output DNAT rule %s has no handle, removing stale entry", ruleID)
delete(r.rules, ruleID)
return nil
}
if err := r.conn.DelRule(rule); err != nil {
return fmt.Errorf("delete output DNAT rule %s: %w", ruleID, err)
}
if err := r.conn.Flush(); err != nil {
return fmt.Errorf("flush delete output DNAT rule: %w", err)
}
delete(r.rules, ruleID)
return nil
}
// applyNetwork generates nftables expressions for networks (CIDR) or sets
func (r *router) applyNetwork(
network firewall.Network,

View File

@@ -8,10 +8,9 @@ import (
)
type InterfaceState struct {
NameStr string `json:"name"`
WGAddress wgaddr.Address `json:"wg_address"`
UserspaceBind bool `json:"userspace_bind"`
MTU uint16 `json:"mtu"`
NameStr string `json:"name"`
WGAddress wgaddr.Address `json:"wg_address"`
MTU uint16 `json:"mtu"`
}
func (i *InterfaceState) Name() string {
@@ -22,10 +21,6 @@ func (i *InterfaceState) Address() wgaddr.Address {
return i.WGAddress
}
func (i *InterfaceState) IsUserspaceBind() bool {
return i.UserspaceBind
}
type ShutdownState struct {
InterfaceState *InterfaceState `json:"interface_state,omitempty"`
}

View File

@@ -140,6 +140,17 @@ type Manager struct {
mtu uint16
mssClampValue uint16
mssClampEnabled bool
// Only one hook per protocol is supported. Outbound direction only.
udpHookOut atomic.Pointer[packetHook]
tcpHookOut atomic.Pointer[packetHook]
}
// packetHook stores a registered hook for a specific IP:port.
type packetHook struct {
ip netip.Addr
port uint16
fn func([]byte) bool
}
// decoder for packages
@@ -594,6 +605,8 @@ func (m *Manager) resetState() {
maps.Clear(m.incomingRules)
maps.Clear(m.routeRulesMap)
m.routeRules = m.routeRules[:0]
m.udpHookOut.Store(nil)
m.tcpHookOut.Store(nil)
if m.udpTracker != nil {
m.udpTracker.Close()
@@ -713,6 +726,9 @@ func (m *Manager) filterOutbound(packetData []byte, size int) bool {
return true
}
case layers.LayerTypeTCP:
if m.tcpHooksDrop(uint16(d.tcp.DstPort), dstIP, packetData) {
return true
}
// Clamp MSS on all TCP SYN packets, including those from local IPs.
// SNATed routed traffic may appear as local IP but still requires clamping.
if m.mssClampEnabled {
@@ -895,38 +911,21 @@ func (m *Manager) trackInbound(d *decoder, srcIP, dstIP netip.Addr, ruleID []byt
d.dnatOrigPort = 0
}
// udpHooksDrop checks if any UDP hooks should drop the packet
func (m *Manager) udpHooksDrop(dport uint16, dstIP netip.Addr, packetData []byte) bool {
m.mutex.RLock()
defer m.mutex.RUnlock()
return hookMatches(m.udpHookOut.Load(), dstIP, dport, packetData)
}
// Check specific destination IP first
if rules, exists := m.outgoingRules[dstIP]; exists {
for _, rule := range rules {
if rule.udpHook != nil && portsMatch(rule.dPort, dport) {
return rule.udpHook(packetData)
}
}
func (m *Manager) tcpHooksDrop(dport uint16, dstIP netip.Addr, packetData []byte) bool {
return hookMatches(m.tcpHookOut.Load(), dstIP, dport, packetData)
}
func hookMatches(h *packetHook, dstIP netip.Addr, dport uint16, packetData []byte) bool {
if h == nil {
return false
}
// Check IPv4 unspecified address
if rules, exists := m.outgoingRules[netip.IPv4Unspecified()]; exists {
for _, rule := range rules {
if rule.udpHook != nil && portsMatch(rule.dPort, dport) {
return rule.udpHook(packetData)
}
}
if h.ip == dstIP && h.port == dport {
return h.fn(packetData)
}
// Check IPv6 unspecified address
if rules, exists := m.outgoingRules[netip.IPv6Unspecified()]; exists {
for _, rule := range rules {
if rule.udpHook != nil && portsMatch(rule.dPort, dport) {
return rule.udpHook(packetData)
}
}
}
return false
}
@@ -1278,12 +1277,6 @@ func validateRule(ip netip.Addr, packetData []byte, rules map[string]PeerRule, d
return rule.mgmtId, rule.drop, true
}
case layers.LayerTypeUDP:
// if rule has UDP hook (and if we are here we match this rule)
// we ignore rule.drop and call this hook
if rule.udpHook != nil {
return rule.mgmtId, rule.udpHook(packetData), true
}
if portsMatch(rule.sPort, uint16(d.udp.SrcPort)) && portsMatch(rule.dPort, uint16(d.udp.DstPort)) {
return rule.mgmtId, rule.drop, true
}
@@ -1342,65 +1335,30 @@ func (m *Manager) ruleMatches(rule *RouteRule, srcAddr, dstAddr netip.Addr, prot
return sourceMatched
}
// AddUDPPacketHook calls hook when UDP packet from given direction matched
//
// Hook function returns flag which indicates should be the matched package dropped or not
func (m *Manager) AddUDPPacketHook(in bool, ip netip.Addr, dPort uint16, hook func(packet []byte) bool) string {
r := PeerRule{
id: uuid.New().String(),
ip: ip,
protoLayer: layers.LayerTypeUDP,
dPort: &firewall.Port{Values: []uint16{dPort}},
ipLayer: layers.LayerTypeIPv6,
udpHook: hook,
// SetUDPPacketHook sets the outbound UDP packet hook. Pass nil hook to remove.
func (m *Manager) SetUDPPacketHook(ip netip.Addr, dPort uint16, hook func(packet []byte) bool) {
if hook == nil {
m.udpHookOut.Store(nil)
return
}
if ip.Is4() {
r.ipLayer = layers.LayerTypeIPv4
}
m.mutex.Lock()
if in {
// Incoming UDP hooks are stored in allow rules map
if _, ok := m.incomingRules[r.ip]; !ok {
m.incomingRules[r.ip] = make(map[string]PeerRule)
}
m.incomingRules[r.ip][r.id] = r
} else {
if _, ok := m.outgoingRules[r.ip]; !ok {
m.outgoingRules[r.ip] = make(map[string]PeerRule)
}
m.outgoingRules[r.ip][r.id] = r
}
m.mutex.Unlock()
return r.id
m.udpHookOut.Store(&packetHook{
ip: ip,
port: dPort,
fn: hook,
})
}
// RemovePacketHook removes packet hook by given ID
func (m *Manager) RemovePacketHook(hookID string) error {
m.mutex.Lock()
defer m.mutex.Unlock()
// Check incoming hooks (stored in allow rules)
for _, arr := range m.incomingRules {
for _, r := range arr {
if r.id == hookID {
delete(arr, r.id)
return nil
}
}
// SetTCPPacketHook sets the outbound TCP packet hook. Pass nil hook to remove.
func (m *Manager) SetTCPPacketHook(ip netip.Addr, dPort uint16, hook func(packet []byte) bool) {
if hook == nil {
m.tcpHookOut.Store(nil)
return
}
// Check outgoing hooks
for _, arr := range m.outgoingRules {
for _, r := range arr {
if r.id == hookID {
delete(arr, r.id)
return nil
}
}
}
return fmt.Errorf("hook with given id not found")
m.tcpHookOut.Store(&packetHook{
ip: ip,
port: dPort,
fn: hook,
})
}
// SetLogLevel sets the log level for the firewall manager

View File

@@ -12,6 +12,7 @@ import (
"github.com/google/gopacket"
"github.com/google/gopacket/layers"
"github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
wgdevice "golang.zx2c4.com/wireguard/device"
@@ -186,81 +187,52 @@ func TestManagerDeleteRule(t *testing.T) {
}
}
func TestAddUDPPacketHook(t *testing.T) {
tests := []struct {
name string
in bool
expDir fw.RuleDirection
ip netip.Addr
dPort uint16
hook func([]byte) bool
expectedID string
}{
{
name: "Test Outgoing UDP Packet Hook",
in: false,
expDir: fw.RuleDirectionOUT,
ip: netip.MustParseAddr("10.168.0.1"),
dPort: 8000,
hook: func([]byte) bool { return true },
},
{
name: "Test Incoming UDP Packet Hook",
in: true,
expDir: fw.RuleDirectionIN,
ip: netip.MustParseAddr("::1"),
dPort: 9000,
hook: func([]byte) bool { return false },
},
}
func TestSetUDPPacketHook(t *testing.T) {
manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger, nbiface.DefaultMTU)
require.NoError(t, err)
t.Cleanup(func() { require.NoError(t, manager.Close(nil)) })
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger, nbiface.DefaultMTU)
require.NoError(t, err)
var called bool
manager.SetUDPPacketHook(netip.MustParseAddr("10.168.0.1"), 8000, func([]byte) bool {
called = true
return true
})
manager.AddUDPPacketHook(tt.in, tt.ip, tt.dPort, tt.hook)
h := manager.udpHookOut.Load()
require.NotNil(t, h)
assert.Equal(t, netip.MustParseAddr("10.168.0.1"), h.ip)
assert.Equal(t, uint16(8000), h.port)
assert.True(t, h.fn(nil))
assert.True(t, called)
var addedRule PeerRule
if tt.in {
// Incoming UDP hooks are stored in allow rules map
if len(manager.incomingRules[tt.ip]) != 1 {
t.Errorf("expected 1 incoming rule, got %d", len(manager.incomingRules[tt.ip]))
return
}
for _, rule := range manager.incomingRules[tt.ip] {
addedRule = rule
}
} else {
if len(manager.outgoingRules[tt.ip]) != 1 {
t.Errorf("expected 1 outgoing rule, got %d", len(manager.outgoingRules[tt.ip]))
return
}
for _, rule := range manager.outgoingRules[tt.ip] {
addedRule = rule
}
}
manager.SetUDPPacketHook(netip.MustParseAddr("10.168.0.1"), 8000, nil)
assert.Nil(t, manager.udpHookOut.Load())
}
if tt.ip.Compare(addedRule.ip) != 0 {
t.Errorf("expected ip %s, got %s", tt.ip, addedRule.ip)
return
}
if tt.dPort != addedRule.dPort.Values[0] {
t.Errorf("expected dPort %d, got %d", tt.dPort, addedRule.dPort.Values[0])
return
}
if layers.LayerTypeUDP != addedRule.protoLayer {
t.Errorf("expected protoLayer %s, got %s", layers.LayerTypeUDP, addedRule.protoLayer)
return
}
if addedRule.udpHook == nil {
t.Errorf("expected udpHook to be set")
return
}
})
}
func TestSetTCPPacketHook(t *testing.T) {
manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger, nbiface.DefaultMTU)
require.NoError(t, err)
t.Cleanup(func() { require.NoError(t, manager.Close(nil)) })
var called bool
manager.SetTCPPacketHook(netip.MustParseAddr("10.168.0.1"), 53, func([]byte) bool {
called = true
return true
})
h := manager.tcpHookOut.Load()
require.NotNil(t, h)
assert.Equal(t, netip.MustParseAddr("10.168.0.1"), h.ip)
assert.Equal(t, uint16(53), h.port)
assert.True(t, h.fn(nil))
assert.True(t, called)
manager.SetTCPPacketHook(netip.MustParseAddr("10.168.0.1"), 53, nil)
assert.Nil(t, manager.tcpHookOut.Load())
}
// TestPeerRuleLifecycleDenyRules verifies that deny rules are correctly added
@@ -530,39 +502,12 @@ func TestRemovePacketHook(t *testing.T) {
require.NoError(t, manager.Close(nil))
}()
// Add a UDP packet hook
hookFunc := func(data []byte) bool { return true }
hookID := manager.AddUDPPacketHook(false, netip.MustParseAddr("192.168.0.1"), 8080, hookFunc)
manager.SetUDPPacketHook(netip.MustParseAddr("192.168.0.1"), 8080, func([]byte) bool { return true })
// Assert the hook is added by finding it in the manager's outgoing rules
found := false
for _, arr := range manager.outgoingRules {
for _, rule := range arr {
if rule.id == hookID {
found = true
break
}
}
}
require.NotNil(t, manager.udpHookOut.Load(), "hook should be registered")
if !found {
t.Fatalf("The hook was not added properly.")
}
// Now remove the packet hook
err = manager.RemovePacketHook(hookID)
if err != nil {
t.Fatalf("Failed to remove hook: %s", err)
}
// Assert the hook is removed by checking it in the manager's outgoing rules
for _, arr := range manager.outgoingRules {
for _, rule := range arr {
if rule.id == hookID {
t.Fatalf("The hook was not removed properly.")
}
}
}
manager.SetUDPPacketHook(netip.MustParseAddr("192.168.0.1"), 8080, nil)
assert.Nil(t, manager.udpHookOut.Load(), "hook should be removed")
}
func TestProcessOutgoingHooks(t *testing.T) {
@@ -592,8 +537,7 @@ func TestProcessOutgoingHooks(t *testing.T) {
}
hookCalled := false
hookID := manager.AddUDPPacketHook(
false,
manager.SetUDPPacketHook(
netip.MustParseAddr("100.10.0.100"),
53,
func([]byte) bool {
@@ -601,7 +545,6 @@ func TestProcessOutgoingHooks(t *testing.T) {
return true
},
)
require.NotEmpty(t, hookID)
// Create test UDP packet
ipv4 := &layers.IPv4{

View File

@@ -144,6 +144,8 @@ func (m *localIPManager) UpdateLocalIPs(iface common.IFaceMapper) (err error) {
if err != nil {
log.Warnf("failed to get interfaces: %v", err)
} else {
// TODO: filter out down interfaces (net.FlagUp). Also handle the reverse
// case where an interface comes up between refreshes.
for _, intf := range interfaces {
m.processInterface(intf, &newIPv4Bitmap, ipv4Set, &ipv4Addresses)
}

View File

@@ -421,6 +421,7 @@ func (m *Manager) addPortRedirection(targetIP netip.Addr, protocol gopacket.Laye
}
// AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services.
// TODO: also delegate to nativeFirewall when available for kernel WG mode
func (m *Manager) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
var layerType gopacket.LayerType
switch protocol {
@@ -466,6 +467,22 @@ func (m *Manager) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Prot
return m.removePortRedirection(localAddr, layerType, sourcePort, targetPort)
}
// AddOutputDNAT delegates to the native firewall if available.
func (m *Manager) AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
if m.nativeFirewall == nil {
return fmt.Errorf("output DNAT not supported without native firewall")
}
return m.nativeFirewall.AddOutputDNAT(localAddr, protocol, sourcePort, targetPort)
}
// RemoveOutputDNAT delegates to the native firewall if available.
func (m *Manager) RemoveOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
if m.nativeFirewall == nil {
return nil
}
return m.nativeFirewall.RemoveOutputDNAT(localAddr, protocol, sourcePort, targetPort)
}
// translateInboundPortDNAT applies port-specific DNAT translation to inbound packets.
func (m *Manager) translateInboundPortDNAT(packetData []byte, d *decoder, srcIP, dstIP netip.Addr) bool {
if !m.portDNATEnabled.Load() {

View File

@@ -18,9 +18,7 @@ type PeerRule struct {
protoLayer gopacket.LayerType
sPort *firewall.Port
dPort *firewall.Port
drop bool
udpHook func([]byte) bool
drop bool
}
// ID returns the rule id

View File

@@ -399,21 +399,17 @@ func TestTracePacket(t *testing.T) {
{
name: "UDPTraffic_WithHook",
setup: func(m *Manager) {
hookFunc := func([]byte) bool {
return true
}
m.AddUDPPacketHook(true, netip.MustParseAddr("1.1.1.1"), 53, hookFunc)
m.SetUDPPacketHook(netip.MustParseAddr("100.10.255.254"), 53, func([]byte) bool {
return true // drop (intercepted by hook)
})
},
packetBuilder: func() *PacketBuilder {
return createPacketBuilder("1.1.1.1", "100.10.0.100", "udp", 12345, 53, fw.RuleDirectionIN)
return createPacketBuilder("100.10.0.100", "100.10.255.254", "udp", 12345, 53, fw.RuleDirectionOUT)
},
expectedStages: []PacketStage{
StageReceived,
StageInboundPortDNAT,
StageInbound1to1NAT,
StageConntrack,
StageRouting,
StagePeerACL,
StageOutbound1to1NAT,
StageOutboundPortReverse,
StageCompleted,
},
expectedAllow: false,

View File

@@ -15,14 +15,17 @@ type PacketFilter interface {
// FilterInbound filter incoming packets from external sources to host
FilterInbound(packetData []byte, size int) bool
// AddUDPPacketHook calls hook when UDP packet from given direction matched
//
// Hook function returns flag which indicates should be the matched package dropped or not.
// Hook function receives raw network packet data as argument.
AddUDPPacketHook(in bool, ip netip.Addr, dPort uint16, hook func(packet []byte) bool) string
// SetUDPPacketHook registers a hook for outbound UDP packets matching the given IP and port.
// Hook function returns true if the packet should be dropped.
// Only one UDP hook is supported; calling again replaces the previous hook.
// Pass nil hook to remove.
SetUDPPacketHook(ip netip.Addr, dPort uint16, hook func(packet []byte) bool)
// RemovePacketHook removes hook by ID
RemovePacketHook(hookID string) error
// SetTCPPacketHook registers a hook for outbound TCP packets matching the given IP and port.
// Hook function returns true if the packet should be dropped.
// Only one TCP hook is supported; calling again replaces the previous hook.
// Pass nil hook to remove.
SetTCPPacketHook(ip netip.Addr, dPort uint16, hook func(packet []byte) bool)
}
// FilteredDevice to override Read or Write of packets

View File

@@ -34,18 +34,28 @@ func (m *MockPacketFilter) EXPECT() *MockPacketFilterMockRecorder {
return m.recorder
}
// AddUDPPacketHook mocks base method.
func (m *MockPacketFilter) AddUDPPacketHook(arg0 bool, arg1 netip.Addr, arg2 uint16, arg3 func([]byte) bool) string {
// SetUDPPacketHook mocks base method.
func (m *MockPacketFilter) SetUDPPacketHook(arg0 netip.Addr, arg1 uint16, arg2 func([]byte) bool) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "AddUDPPacketHook", arg0, arg1, arg2, arg3)
ret0, _ := ret[0].(string)
return ret0
m.ctrl.Call(m, "SetUDPPacketHook", arg0, arg1, arg2)
}
// AddUDPPacketHook indicates an expected call of AddUDPPacketHook.
func (mr *MockPacketFilterMockRecorder) AddUDPPacketHook(arg0, arg1, arg2, arg3 interface{}) *gomock.Call {
// SetUDPPacketHook indicates an expected call of SetUDPPacketHook.
func (mr *MockPacketFilterMockRecorder) SetUDPPacketHook(arg0, arg1, arg2 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddUDPPacketHook", reflect.TypeOf((*MockPacketFilter)(nil).AddUDPPacketHook), arg0, arg1, arg2, arg3)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetUDPPacketHook", reflect.TypeOf((*MockPacketFilter)(nil).SetUDPPacketHook), arg0, arg1, arg2)
}
// SetTCPPacketHook mocks base method.
func (m *MockPacketFilter) SetTCPPacketHook(arg0 netip.Addr, arg1 uint16, arg2 func([]byte) bool) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "SetTCPPacketHook", arg0, arg1, arg2)
}
// SetTCPPacketHook indicates an expected call of SetTCPPacketHook.
func (mr *MockPacketFilterMockRecorder) SetTCPPacketHook(arg0, arg1, arg2 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetTCPPacketHook", reflect.TypeOf((*MockPacketFilter)(nil).SetTCPPacketHook), arg0, arg1, arg2)
}
// FilterInbound mocks base method.
@@ -75,17 +85,3 @@ func (mr *MockPacketFilterMockRecorder) FilterOutbound(arg0 interface{}, arg1 an
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FilterOutbound", reflect.TypeOf((*MockPacketFilter)(nil).FilterOutbound), arg0, arg1)
}
// RemovePacketHook mocks base method.
func (m *MockPacketFilter) RemovePacketHook(arg0 string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "RemovePacketHook", arg0)
ret0, _ := ret[0].(error)
return ret0
}
// RemovePacketHook indicates an expected call of RemovePacketHook.
func (mr *MockPacketFilterMockRecorder) RemovePacketHook(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemovePacketHook", reflect.TypeOf((*MockPacketFilter)(nil).RemovePacketHook), arg0)
}

View File

@@ -1,87 +0,0 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/netbirdio/netbird/client/iface (interfaces: PacketFilter)
// Package mocks is a generated GoMock package.
package mocks
import (
net "net"
reflect "reflect"
gomock "github.com/golang/mock/gomock"
)
// MockPacketFilter is a mock of PacketFilter interface.
type MockPacketFilter struct {
ctrl *gomock.Controller
recorder *MockPacketFilterMockRecorder
}
// MockPacketFilterMockRecorder is the mock recorder for MockPacketFilter.
type MockPacketFilterMockRecorder struct {
mock *MockPacketFilter
}
// NewMockPacketFilter creates a new mock instance.
func NewMockPacketFilter(ctrl *gomock.Controller) *MockPacketFilter {
mock := &MockPacketFilter{ctrl: ctrl}
mock.recorder = &MockPacketFilterMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockPacketFilter) EXPECT() *MockPacketFilterMockRecorder {
return m.recorder
}
// AddUDPPacketHook mocks base method.
func (m *MockPacketFilter) AddUDPPacketHook(arg0 bool, arg1 net.IP, arg2 uint16, arg3 func(*net.UDPAddr, []byte) bool) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "AddUDPPacketHook", arg0, arg1, arg2, arg3)
}
// AddUDPPacketHook indicates an expected call of AddUDPPacketHook.
func (mr *MockPacketFilterMockRecorder) AddUDPPacketHook(arg0, arg1, arg2, arg3 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddUDPPacketHook", reflect.TypeOf((*MockPacketFilter)(nil).AddUDPPacketHook), arg0, arg1, arg2, arg3)
}
// FilterInbound mocks base method.
func (m *MockPacketFilter) FilterInbound(arg0 []byte) bool {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "FilterInbound", arg0)
ret0, _ := ret[0].(bool)
return ret0
}
// FilterInbound indicates an expected call of FilterInbound.
func (mr *MockPacketFilterMockRecorder) FilterInbound(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FilterInbound", reflect.TypeOf((*MockPacketFilter)(nil).FilterInbound), arg0)
}
// FilterOutbound mocks base method.
func (m *MockPacketFilter) FilterOutbound(arg0 []byte) bool {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "FilterOutbound", arg0)
ret0, _ := ret[0].(bool)
return ret0
}
// FilterOutbound indicates an expected call of FilterOutbound.
func (mr *MockPacketFilterMockRecorder) FilterOutbound(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FilterOutbound", reflect.TypeOf((*MockPacketFilter)(nil).FilterOutbound), arg0)
}
// SetNetwork mocks base method.
func (m *MockPacketFilter) SetNetwork(arg0 *net.IPNet) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "SetNetwork", arg0)
}
// SetNetwork indicates an expected call of SetNetwork.
func (mr *MockPacketFilterMockRecorder) SetNetwork(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetNetwork", reflect.TypeOf((*MockPacketFilter)(nil).SetNetwork), arg0)
}

View File

@@ -19,6 +19,9 @@ import (
var flowLogger = netflow.NewManager(nil, []byte{}, nil).GetLogger()
func TestDefaultManager(t *testing.T) {
t.Setenv("NB_WG_KERNEL_DISABLED", "true")
t.Setenv(firewall.EnvForceUserspaceFirewall, "true")
networkMap := &mgmProto.NetworkMap{
FirewallRules: []*mgmProto.FirewallRule{
{
@@ -135,6 +138,7 @@ func TestDefaultManager(t *testing.T) {
func TestDefaultManagerStateless(t *testing.T) {
// stateless currently only in userspace, so we have to disable kernel
t.Setenv("NB_WG_KERNEL_DISABLED", "true")
t.Setenv(firewall.EnvForceUserspaceFirewall, "true")
t.Setenv("NB_DISABLE_CONNTRACK", "true")
networkMap := &mgmProto.NetworkMap{
@@ -194,6 +198,7 @@ func TestDefaultManagerStateless(t *testing.T) {
// This tests the full ACL manager -> uspfilter integration.
func TestDenyRulesNotAccumulatedOnRepeatedApply(t *testing.T) {
t.Setenv("NB_WG_KERNEL_DISABLED", "true")
t.Setenv(firewall.EnvForceUserspaceFirewall, "true")
networkMap := &mgmProto.NetworkMap{
FirewallRules: []*mgmProto.FirewallRule{
@@ -258,6 +263,7 @@ func TestDenyRulesNotAccumulatedOnRepeatedApply(t *testing.T) {
// up when they're removed from the network map in a subsequent update.
func TestDenyRulesCleanedUpOnRemoval(t *testing.T) {
t.Setenv("NB_WG_KERNEL_DISABLED", "true")
t.Setenv(firewall.EnvForceUserspaceFirewall, "true")
ctrl := gomock.NewController(t)
defer ctrl.Finish()
@@ -339,6 +345,7 @@ func TestDenyRulesCleanedUpOnRemoval(t *testing.T) {
// one added without leaking.
func TestRuleUpdateChangingAction(t *testing.T) {
t.Setenv("NB_WG_KERNEL_DISABLED", "true")
t.Setenv(firewall.EnvForceUserspaceFirewall, "true")
ctrl := gomock.NewController(t)
defer ctrl.Finish()

View File

@@ -155,7 +155,7 @@ func (a *Auth) IsLoginRequired(ctx context.Context) (bool, error) {
var needsLogin bool
err = a.withRetry(ctx, func(client *mgm.GrpcClient) error {
_, _, err := a.doMgmLogin(client, ctx, pubSSHKey)
err := a.doMgmLogin(client, ctx, pubSSHKey)
if isLoginNeeded(err) {
needsLogin = true
return nil
@@ -179,8 +179,8 @@ func (a *Auth) Login(ctx context.Context, setupKey string, jwtToken string) (err
var isAuthError bool
err = a.withRetry(ctx, func(client *mgm.GrpcClient) error {
serverKey, _, err := a.doMgmLogin(client, ctx, pubSSHKey)
if serverKey != nil && isRegistrationNeeded(err) {
err := a.doMgmLogin(client, ctx, pubSSHKey)
if isRegistrationNeeded(err) {
log.Debugf("peer registration required")
_, err = a.registerPeer(client, ctx, setupKey, jwtToken, pubSSHKey)
if err != nil {
@@ -201,13 +201,7 @@ func (a *Auth) Login(ctx context.Context, setupKey string, jwtToken string) (err
// getPKCEFlow retrieves PKCE authorization flow configuration and creates a flow instance
func (a *Auth) getPKCEFlow(client *mgm.GrpcClient) (*PKCEAuthorizationFlow, error) {
serverKey, err := client.GetServerPublicKey()
if err != nil {
log.Errorf("failed while getting Management Service public key: %v", err)
return nil, err
}
protoFlow, err := client.GetPKCEAuthorizationFlow(*serverKey)
protoFlow, err := client.GetPKCEAuthorizationFlow()
if err != nil {
if s, ok := status.FromError(err); ok && s.Code() == codes.NotFound {
log.Warnf("server couldn't find pkce flow, contact admin: %v", err)
@@ -221,7 +215,7 @@ func (a *Auth) getPKCEFlow(client *mgm.GrpcClient) (*PKCEAuthorizationFlow, erro
config := &PKCEAuthProviderConfig{
Audience: protoConfig.GetAudience(),
ClientID: protoConfig.GetClientID(),
ClientSecret: protoConfig.GetClientSecret(),
ClientSecret: protoConfig.GetClientSecret(), //nolint:staticcheck
TokenEndpoint: protoConfig.GetTokenEndpoint(),
AuthorizationEndpoint: protoConfig.GetAuthorizationEndpoint(),
Scope: protoConfig.GetScope(),
@@ -246,13 +240,7 @@ func (a *Auth) getPKCEFlow(client *mgm.GrpcClient) (*PKCEAuthorizationFlow, erro
// getDeviceFlow retrieves device authorization flow configuration and creates a flow instance
func (a *Auth) getDeviceFlow(client *mgm.GrpcClient) (*DeviceAuthorizationFlow, error) {
serverKey, err := client.GetServerPublicKey()
if err != nil {
log.Errorf("failed while getting Management Service public key: %v", err)
return nil, err
}
protoFlow, err := client.GetDeviceAuthorizationFlow(*serverKey)
protoFlow, err := client.GetDeviceAuthorizationFlow()
if err != nil {
if s, ok := status.FromError(err); ok && s.Code() == codes.NotFound {
log.Warnf("server couldn't find device flow, contact admin: %v", err)
@@ -266,7 +254,7 @@ func (a *Auth) getDeviceFlow(client *mgm.GrpcClient) (*DeviceAuthorizationFlow,
config := &DeviceAuthProviderConfig{
Audience: protoConfig.GetAudience(),
ClientID: protoConfig.GetClientID(),
ClientSecret: protoConfig.GetClientSecret(),
ClientSecret: protoConfig.GetClientSecret(), //nolint:staticcheck
Domain: protoConfig.Domain,
TokenEndpoint: protoConfig.GetTokenEndpoint(),
DeviceAuthEndpoint: protoConfig.GetDeviceAuthEndpoint(),
@@ -292,28 +280,16 @@ func (a *Auth) getDeviceFlow(client *mgm.GrpcClient) (*DeviceAuthorizationFlow,
}
// doMgmLogin performs the actual login operation with the management service
func (a *Auth) doMgmLogin(client *mgm.GrpcClient, ctx context.Context, pubSSHKey []byte) (*wgtypes.Key, *mgmProto.LoginResponse, error) {
serverKey, err := client.GetServerPublicKey()
if err != nil {
log.Errorf("failed while getting Management Service public key: %v", err)
return nil, nil, err
}
func (a *Auth) doMgmLogin(client *mgm.GrpcClient, ctx context.Context, pubSSHKey []byte) error {
sysInfo := system.GetInfo(ctx)
a.setSystemInfoFlags(sysInfo)
loginResp, err := client.Login(*serverKey, sysInfo, pubSSHKey, a.config.DNSLabels)
return serverKey, loginResp, err
_, err := client.Login(sysInfo, pubSSHKey, a.config.DNSLabels)
return err
}
// registerPeer checks whether setupKey was provided via cmd line and if not then it prompts user to enter a key.
// Otherwise tries to register with the provided setupKey via command line.
func (a *Auth) registerPeer(client *mgm.GrpcClient, ctx context.Context, setupKey string, jwtToken string, pubSSHKey []byte) (*mgmProto.LoginResponse, error) {
serverPublicKey, err := client.GetServerPublicKey()
if err != nil {
log.Errorf("failed while getting Management Service public key: %v", err)
return nil, err
}
validSetupKey, err := uuid.Parse(setupKey)
if err != nil && jwtToken == "" {
return nil, status.Errorf(codes.InvalidArgument, "invalid setup-key or no sso information provided, err: %v", err)
@@ -322,7 +298,7 @@ func (a *Auth) registerPeer(client *mgm.GrpcClient, ctx context.Context, setupKe
log.Debugf("sending peer registration request to Management Service")
info := system.GetInfo(ctx)
a.setSystemInfoFlags(info)
loginResp, err := client.Register(*serverPublicKey, validSetupKey.String(), jwtToken, info, pubSSHKey, a.config.DNSLabels)
loginResp, err := client.Register(validSetupKey.String(), jwtToken, info, pubSSHKey, a.config.DNSLabels)
if err != nil {
log.Errorf("failed registering peer %v", err)
return nil, err
@@ -339,6 +315,7 @@ func (a *Auth) setSystemInfoFlags(info *system.Info) {
a.config.RosenpassEnabled,
a.config.RosenpassPermissive,
a.config.ServerSSHAllowed,
a.config.ServerVNCAllowed,
a.config.DisableClientRoutes,
a.config.DisableServerRoutes,
a.config.DisableDNS,
@@ -351,6 +328,7 @@ func (a *Auth) setSystemInfoFlags(info *system.Info) {
a.config.EnableSSHLocalPortForwarding,
a.config.EnableSSHRemotePortForwarding,
a.config.DisableSSHAuth,
a.config.DisableVNCAuth,
)
}

View File

@@ -44,6 +44,10 @@ import (
"github.com/netbirdio/netbird/version"
)
// androidRunOverride is set on Android to inject mobile dependencies
// when using embed.Client (which calls Run() with empty MobileDependency).
var androidRunOverride func(c *ConnectClient, runningChan chan struct{}, logPath string) error
type ConnectClient struct {
ctx context.Context
config *profilemanager.Config
@@ -76,6 +80,9 @@ func (c *ConnectClient) SetUpdateManager(um *updater.Manager) {
// Run with main logic.
func (c *ConnectClient) Run(runningChan chan struct{}, logPath string) error {
if androidRunOverride != nil {
return androidRunOverride(c, runningChan, logPath)
}
return c.run(MobileDependency{}, runningChan, logPath)
}
@@ -104,6 +111,7 @@ func (c *ConnectClient) RunOniOS(
fileDescriptor int32,
networkChangeListener listener.NetworkChangeListener,
dnsManager dns.IosDnsManager,
dnsAddresses []netip.AddrPort,
stateFilePath string,
) error {
// Set GC percent to 5% to reduce memory usage as iOS only allows 50MB of memory for the extension.
@@ -113,6 +121,7 @@ func (c *ConnectClient) RunOniOS(
FileDescriptor: fileDescriptor,
NetworkChangeListener: networkChangeListener,
DnsManager: dnsManager,
HostDNSAddresses: dnsAddresses,
StateFilePath: stateFilePath,
}
return c.run(mobileDependency, nil, "")
@@ -534,11 +543,13 @@ func createEngineConfig(key wgtypes.Key, config *profilemanager.Config, peerConf
RosenpassEnabled: config.RosenpassEnabled,
RosenpassPermissive: config.RosenpassPermissive,
ServerSSHAllowed: util.ReturnBoolWithDefaultTrue(config.ServerSSHAllowed),
ServerVNCAllowed: config.ServerVNCAllowed != nil && *config.ServerVNCAllowed,
EnableSSHRoot: config.EnableSSHRoot,
EnableSSHSFTP: config.EnableSSHSFTP,
EnableSSHLocalPortForwarding: config.EnableSSHLocalPortForwarding,
EnableSSHRemotePortForwarding: config.EnableSSHRemotePortForwarding,
DisableSSHAuth: config.DisableSSHAuth,
DisableVNCAuth: config.DisableVNCAuth,
DNSRouteInterval: config.DNSRouteInterval,
DisableClientRoutes: config.DisableClientRoutes,
@@ -610,17 +621,12 @@ func connectToSignal(ctx context.Context, wtConfig *mgmProto.NetbirdConfig, ourP
// loginToManagement creates Management ServiceDependencies client, establishes a connection, logs-in and gets a global Netbird config (signal, turn, stun hosts, etc)
func loginToManagement(ctx context.Context, client mgm.Client, pubSSHKey []byte, config *profilemanager.Config) (*mgmProto.LoginResponse, error) {
serverPublicKey, err := client.GetServerPublicKey()
if err != nil {
return nil, gstatus.Errorf(codes.FailedPrecondition, "failed while getting Management Service public key: %s", err)
}
sysInfo := system.GetInfo(ctx)
sysInfo.SetFlags(
config.RosenpassEnabled,
config.RosenpassPermissive,
config.ServerSSHAllowed,
config.ServerVNCAllowed,
config.DisableClientRoutes,
config.DisableServerRoutes,
config.DisableDNS,
@@ -633,13 +639,9 @@ func loginToManagement(ctx context.Context, client mgm.Client, pubSSHKey []byte,
config.EnableSSHLocalPortForwarding,
config.EnableSSHRemotePortForwarding,
config.DisableSSHAuth,
config.DisableVNCAuth,
)
loginResp, err := client.Login(*serverPublicKey, sysInfo, pubSSHKey, config.DNSLabels)
if err != nil {
return nil, err
}
return loginResp, nil
return client.Login(sysInfo, pubSSHKey, config.DNSLabels)
}
func statusRecorderToMgmConnStateNotifier(statusRecorder *peer.Status) mgm.ConnStateNotifier {

View File

@@ -0,0 +1,73 @@
//go:build android
package internal
import (
"net/netip"
"github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/listener"
"github.com/netbirdio/netbird/client/internal/stdnet"
)
// noopIFaceDiscover is a stub ExternalIFaceDiscover for embed.Client on Android.
// It returns an empty interface list, which means ICE P2P candidates won't be
// discovered — connections will fall back to relay. Applications that need P2P
// should provide a real implementation via runOnAndroidEmbed that uses
// Android's ConnectivityManager to enumerate network interfaces.
type noopIFaceDiscover struct{}
func (noopIFaceDiscover) IFaces() (string, error) {
// Return empty JSON array — no local interfaces advertised for ICE.
// This is intentional: without Android's ConnectivityManager, we cannot
// reliably enumerate interfaces (netlink is restricted on Android 11+).
// Relay connections still work; only P2P hole-punching is disabled.
return "[]", nil
}
// noopNetworkChangeListener is a stub for embed.Client on Android.
// Network change events are ignored since the embed client manages its own
// reconnection logic via the engine's built-in retry mechanism.
type noopNetworkChangeListener struct{}
func (noopNetworkChangeListener) OnNetworkChanged(string) {
// No-op: embed.Client relies on the engine's internal reconnection
// logic rather than OS-level network change notifications.
}
func (noopNetworkChangeListener) SetInterfaceIP(string) {
// No-op: in netstack mode, the overlay IP is managed by the userspace
// network stack, not by OS-level interface configuration.
}
// noopDnsReadyListener is a stub for embed.Client on Android.
// DNS readiness notifications are not needed in netstack/embed mode
// since system DNS is disabled and DNS resolution happens externally.
type noopDnsReadyListener struct{}
func (noopDnsReadyListener) OnReady() {
// No-op: embed.Client does not need DNS readiness notifications.
// System DNS is disabled in netstack mode.
}
var _ stdnet.ExternalIFaceDiscover = noopIFaceDiscover{}
var _ listener.NetworkChangeListener = noopNetworkChangeListener{}
var _ dns.ReadyListener = noopDnsReadyListener{}
func init() {
// Wire up the default override so embed.Client.Start() works on Android
// with netstack mode. Provides complete no-op stubs for all mobile
// dependencies so the engine's existing Android code paths work unchanged.
// Applications that need P2P ICE or real DNS should replace this by
// setting androidRunOverride before calling Start().
androidRunOverride = func(c *ConnectClient, runningChan chan struct{}, logPath string) error {
return c.runOnAndroidEmbed(
noopIFaceDiscover{},
noopNetworkChangeListener{},
[]netip.AddrPort{},
noopDnsReadyListener{},
runningChan,
logPath,
)
}
}

View File

@@ -0,0 +1,32 @@
//go:build android
package internal
import (
"net/netip"
"github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/listener"
"github.com/netbirdio/netbird/client/internal/stdnet"
)
// runOnAndroidEmbed is like RunOnAndroid but accepts a runningChan
// so embed.Client.Start() can detect when the engine is ready.
// It provides complete MobileDependency so the engine's existing
// Android code paths work unchanged.
func (c *ConnectClient) runOnAndroidEmbed(
iFaceDiscover stdnet.ExternalIFaceDiscover,
networkChangeListener listener.NetworkChangeListener,
dnsAddresses []netip.AddrPort,
dnsReadyListener dns.ReadyListener,
runningChan chan struct{},
logPath string,
) error {
mobileDependency := MobileDependency{
IFaceDiscover: iFaceDiscover,
NetworkChangeListener: networkChangeListener,
HostDNSAddresses: dnsAddresses,
DnsReadyListener: dnsReadyListener,
}
return c.run(mobileDependency, runningChan, logPath)
}

View File

@@ -25,6 +25,7 @@ import (
"google.golang.org/protobuf/encoding/protojson"
"github.com/netbirdio/netbird/client/anonymize"
"github.com/netbirdio/netbird/client/configs"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/profilemanager"
"github.com/netbirdio/netbird/client/internal/updater/installer"
@@ -52,6 +53,7 @@ resolved_domains.txt: Anonymized resolved domain IP addresses from the status re
config.txt: Anonymized configuration information of the NetBird client.
network_map.json: Anonymized sync response containing peer configurations, routes, DNS settings, and firewall rules.
state.json: Anonymized client state dump containing netbird states for the active profile.
service_params.json: Sanitized service install parameters (service.json). Sensitive environment variable values are masked. Only present when service.json exists.
metrics.txt: Buffered client metrics in InfluxDB line protocol format. Only present when metrics collection is enabled. Peer identifiers are anonymized.
mutex.prof: Mutex profiling information.
goroutine.prof: Goroutine profiling information.
@@ -359,6 +361,10 @@ func (g *BundleGenerator) createArchive() error {
log.Errorf("failed to add corrupted state files to debug bundle: %v", err)
}
if err := g.addServiceParams(); err != nil {
log.Errorf("failed to add service params to debug bundle: %v", err)
}
if err := g.addMetrics(); err != nil {
log.Errorf("failed to add metrics to debug bundle: %v", err)
}
@@ -488,6 +494,90 @@ func (g *BundleGenerator) addConfig() error {
return nil
}
const (
serviceParamsFile = "service.json"
serviceParamsBundle = "service_params.json"
maskedValue = "***"
envVarPrefix = "NB_"
jsonKeyManagementURL = "management_url"
jsonKeyServiceEnv = "service_env_vars"
)
var sensitiveEnvSubstrings = []string{"key", "token", "secret", "password", "credential"}
// addServiceParams reads the service.json file and adds a sanitized version to the bundle.
// Non-NB_ env vars and vars with sensitive names are masked. Other NB_ values are anonymized.
func (g *BundleGenerator) addServiceParams() error {
path := filepath.Join(configs.StateDir, serviceParamsFile)
data, err := os.ReadFile(path)
if err != nil {
if os.IsNotExist(err) {
return nil
}
return fmt.Errorf("read service params: %w", err)
}
var params map[string]any
if err := json.Unmarshal(data, &params); err != nil {
return fmt.Errorf("parse service params: %w", err)
}
if g.anonymize {
if mgmtURL, ok := params[jsonKeyManagementURL].(string); ok && mgmtURL != "" {
params[jsonKeyManagementURL] = g.anonymizer.AnonymizeURI(mgmtURL)
}
}
g.sanitizeServiceEnvVars(params)
sanitizedData, err := json.MarshalIndent(params, "", " ")
if err != nil {
return fmt.Errorf("marshal sanitized service params: %w", err)
}
if err := g.addFileToZip(bytes.NewReader(sanitizedData), serviceParamsBundle); err != nil {
return fmt.Errorf("add service params to zip: %w", err)
}
return nil
}
// sanitizeServiceEnvVars masks or anonymizes env var values in service params.
// Non-NB_ vars and vars with sensitive names (key, token, etc.) are fully masked.
// Other NB_ var values are passed through the anonymizer when anonymization is enabled.
func (g *BundleGenerator) sanitizeServiceEnvVars(params map[string]any) {
envVars, ok := params[jsonKeyServiceEnv].(map[string]any)
if !ok {
return
}
sanitized := make(map[string]any, len(envVars))
for k, v := range envVars {
val, _ := v.(string)
switch {
case !strings.HasPrefix(k, envVarPrefix) || isSensitiveEnvVar(k):
sanitized[k] = maskedValue
case g.anonymize:
sanitized[k] = g.anonymizer.AnonymizeString(val)
default:
sanitized[k] = val
}
}
params[jsonKeyServiceEnv] = sanitized
}
// isSensitiveEnvVar returns true for env var names that may contain secrets.
func isSensitiveEnvVar(key string) bool {
lower := strings.ToLower(key)
for _, s := range sensitiveEnvSubstrings {
if strings.Contains(lower, s) {
return true
}
}
return false
}
func (g *BundleGenerator) addCommonConfigFields(configContent *strings.Builder) {
configContent.WriteString("NetBird Client Configuration:\n\n")

View File

@@ -1,8 +1,12 @@
package debug
import (
"archive/zip"
"bytes"
"encoding/json"
"net"
"os"
"path/filepath"
"strings"
"testing"
@@ -10,6 +14,7 @@ import (
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/client/anonymize"
"github.com/netbirdio/netbird/client/configs"
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
)
@@ -420,6 +425,226 @@ func TestAnonymizeNetworkMap(t *testing.T) {
}
}
func TestIsSensitiveEnvVar(t *testing.T) {
tests := []struct {
key string
sensitive bool
}{
{"NB_SETUP_KEY", true},
{"NB_API_TOKEN", true},
{"NB_CLIENT_SECRET", true},
{"NB_PASSWORD", true},
{"NB_CREDENTIAL", true},
{"NB_LOG_LEVEL", false},
{"NB_MANAGEMENT_URL", false},
{"NB_HOSTNAME", false},
{"HOME", false},
{"PATH", false},
}
for _, tt := range tests {
t.Run(tt.key, func(t *testing.T) {
assert.Equal(t, tt.sensitive, isSensitiveEnvVar(tt.key))
})
}
}
func TestSanitizeServiceEnvVars(t *testing.T) {
tests := []struct {
name string
anonymize bool
input map[string]any
check func(t *testing.T, params map[string]any)
}{
{
name: "no env vars key",
anonymize: false,
input: map[string]any{"management_url": "https://mgmt.example.com"},
check: func(t *testing.T, params map[string]any) {
t.Helper()
assert.Equal(t, "https://mgmt.example.com", params["management_url"], "non-env fields should be untouched")
_, ok := params[jsonKeyServiceEnv]
assert.False(t, ok, "service_env_vars should not be added")
},
},
{
name: "non-NB vars are masked",
anonymize: false,
input: map[string]any{
jsonKeyServiceEnv: map[string]any{
"HOME": "/root",
"PATH": "/usr/bin",
"NB_LOG_LEVEL": "debug",
},
},
check: func(t *testing.T, params map[string]any) {
t.Helper()
env := params[jsonKeyServiceEnv].(map[string]any)
assert.Equal(t, maskedValue, env["HOME"], "non-NB_ var should be masked")
assert.Equal(t, maskedValue, env["PATH"], "non-NB_ var should be masked")
assert.Equal(t, "debug", env["NB_LOG_LEVEL"], "safe NB_ var should pass through")
},
},
{
name: "sensitive NB vars are masked",
anonymize: false,
input: map[string]any{
jsonKeyServiceEnv: map[string]any{
"NB_SETUP_KEY": "abc123",
"NB_API_TOKEN": "tok_xyz",
"NB_LOG_LEVEL": "info",
},
},
check: func(t *testing.T, params map[string]any) {
t.Helper()
env := params[jsonKeyServiceEnv].(map[string]any)
assert.Equal(t, maskedValue, env["NB_SETUP_KEY"], "sensitive NB_ var should be masked")
assert.Equal(t, maskedValue, env["NB_API_TOKEN"], "sensitive NB_ var should be masked")
assert.Equal(t, "info", env["NB_LOG_LEVEL"], "safe NB_ var should pass through")
},
},
{
name: "safe NB vars anonymized when anonymize is true",
anonymize: true,
input: map[string]any{
jsonKeyServiceEnv: map[string]any{
"NB_MANAGEMENT_URL": "https://mgmt.example.com:443",
"NB_LOG_LEVEL": "debug",
"NB_SETUP_KEY": "secret",
"SOME_OTHER": "val",
},
},
check: func(t *testing.T, params map[string]any) {
t.Helper()
env := params[jsonKeyServiceEnv].(map[string]any)
// Safe NB_ values should be anonymized (not the original, not masked)
mgmtVal := env["NB_MANAGEMENT_URL"].(string)
assert.NotEqual(t, "https://mgmt.example.com:443", mgmtVal, "should be anonymized")
assert.NotEqual(t, maskedValue, mgmtVal, "should not be masked")
logVal := env["NB_LOG_LEVEL"].(string)
assert.NotEqual(t, maskedValue, logVal, "safe NB_ var should not be masked")
// Sensitive and non-NB_ still masked
assert.Equal(t, maskedValue, env["NB_SETUP_KEY"])
assert.Equal(t, maskedValue, env["SOME_OTHER"])
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
anonymizer := anonymize.NewAnonymizer(anonymize.DefaultAddresses())
g := &BundleGenerator{
anonymize: tt.anonymize,
anonymizer: anonymizer,
}
g.sanitizeServiceEnvVars(tt.input)
tt.check(t, tt.input)
})
}
}
func TestAddServiceParams(t *testing.T) {
t.Run("missing service.json returns nil", func(t *testing.T) {
g := &BundleGenerator{
anonymizer: anonymize.NewAnonymizer(anonymize.DefaultAddresses()),
}
origStateDir := configs.StateDir
configs.StateDir = t.TempDir()
t.Cleanup(func() { configs.StateDir = origStateDir })
err := g.addServiceParams()
assert.NoError(t, err)
})
t.Run("management_url anonymized when anonymize is true", func(t *testing.T) {
dir := t.TempDir()
origStateDir := configs.StateDir
configs.StateDir = dir
t.Cleanup(func() { configs.StateDir = origStateDir })
input := map[string]any{
jsonKeyManagementURL: "https://api.example.com:443",
jsonKeyServiceEnv: map[string]any{
"NB_LOG_LEVEL": "trace",
},
}
data, err := json.Marshal(input)
require.NoError(t, err)
require.NoError(t, os.WriteFile(filepath.Join(dir, serviceParamsFile), data, 0600))
var buf bytes.Buffer
zw := zip.NewWriter(&buf)
g := &BundleGenerator{
anonymize: true,
anonymizer: anonymize.NewAnonymizer(anonymize.DefaultAddresses()),
archive: zw,
}
require.NoError(t, g.addServiceParams())
require.NoError(t, zw.Close())
zr, err := zip.NewReader(bytes.NewReader(buf.Bytes()), int64(buf.Len()))
require.NoError(t, err)
require.Len(t, zr.File, 1)
assert.Equal(t, serviceParamsBundle, zr.File[0].Name)
rc, err := zr.File[0].Open()
require.NoError(t, err)
defer rc.Close()
var result map[string]any
require.NoError(t, json.NewDecoder(rc).Decode(&result))
mgmt := result[jsonKeyManagementURL].(string)
assert.NotEqual(t, "https://api.example.com:443", mgmt, "management_url should be anonymized")
assert.NotEmpty(t, mgmt)
env := result[jsonKeyServiceEnv].(map[string]any)
assert.NotEqual(t, maskedValue, env["NB_LOG_LEVEL"], "safe NB_ var should not be masked")
})
t.Run("management_url preserved when anonymize is false", func(t *testing.T) {
dir := t.TempDir()
origStateDir := configs.StateDir
configs.StateDir = dir
t.Cleanup(func() { configs.StateDir = origStateDir })
input := map[string]any{
jsonKeyManagementURL: "https://api.example.com:443",
}
data, err := json.Marshal(input)
require.NoError(t, err)
require.NoError(t, os.WriteFile(filepath.Join(dir, serviceParamsFile), data, 0600))
var buf bytes.Buffer
zw := zip.NewWriter(&buf)
g := &BundleGenerator{
anonymize: false,
anonymizer: anonymize.NewAnonymizer(anonymize.DefaultAddresses()),
archive: zw,
}
require.NoError(t, g.addServiceParams())
require.NoError(t, zw.Close())
zr, err := zip.NewReader(bytes.NewReader(buf.Bytes()), int64(buf.Len()))
require.NoError(t, err)
rc, err := zr.File[0].Open()
require.NoError(t, err)
defer rc.Close()
var result map[string]any
require.NoError(t, json.NewDecoder(rc).Decode(&result))
assert.Equal(t, "https://api.example.com:443", result[jsonKeyManagementURL], "management_url should be preserved")
})
}
// Helper function to check if IP is in CGNAT range
func isInCGNATRange(ip net.IP) bool {
cgnat := net.IPNet{

View File

@@ -73,6 +73,9 @@ func (w *ResponseWriterChain) WriteMsg(m *dns.Msg) error {
return nil
}
w.response = m
if m.MsgHdr.Truncated {
w.SetMeta("truncated", "true")
}
return w.ResponseWriter.WriteMsg(m)
}
@@ -195,10 +198,14 @@ func (c *HandlerChain) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
startTime := time.Now()
requestID := resutil.GenerateRequestID()
logger := log.WithFields(log.Fields{
fields := log.Fields{
"request_id": requestID,
"dns_id": fmt.Sprintf("%04x", r.Id),
})
}
if addr := w.RemoteAddr(); addr != nil {
fields["client"] = addr.String()
}
logger := log.WithFields(fields)
question := r.Question[0]
qname := strings.ToLower(question.Name)
@@ -261,9 +268,9 @@ func (c *HandlerChain) logResponse(logger *log.Entry, cw *ResponseWriterChain, q
meta += " " + k + "=" + v
}
logger.Tracef("response: domain=%s rcode=%s answers=%s%s took=%s",
logger.Tracef("response: domain=%s rcode=%s answers=%s size=%dB%s took=%s",
qname, dns.RcodeToString[cw.response.Rcode], resutil.FormatAnswers(cw.response.Answer),
meta, time.Since(startTime))
cw.response.Len(), meta, time.Since(startTime))
}
func (c *HandlerChain) isHandlerMatch(qname string, entry HandlerEntry) bool {

View File

@@ -1263,9 +1263,9 @@ func TestLocalResolver_AuthoritativeFlag(t *testing.T) {
})
}
// TestLocalResolver_Stop tests cleanup on Stop
// TestLocalResolver_Stop tests cleanup on GracefullyStop
func TestLocalResolver_Stop(t *testing.T) {
t.Run("Stop clears all state", func(t *testing.T) {
t.Run("GracefullyStop clears all state", func(t *testing.T) {
resolver := NewResolver()
resolver.Update([]nbdns.CustomZone{{
Domain: "example.com.",
@@ -1285,7 +1285,7 @@ func TestLocalResolver_Stop(t *testing.T) {
assert.False(t, resolver.isInManagedZone("host.example.com."))
})
t.Run("Stop is safe to call multiple times", func(t *testing.T) {
t.Run("GracefullyStop is safe to call multiple times", func(t *testing.T) {
resolver := NewResolver()
resolver.Update([]nbdns.CustomZone{{
Domain: "example.com.",
@@ -1299,7 +1299,7 @@ func TestLocalResolver_Stop(t *testing.T) {
resolver.Stop()
})
t.Run("Stop cancels in-flight external resolution", func(t *testing.T) {
t.Run("GracefullyStop cancels in-flight external resolution", func(t *testing.T) {
resolver := NewResolver()
lookupStarted := make(chan struct{})

View File

@@ -90,6 +90,11 @@ func (m *MockServer) SetRouteChecker(func(netip.Addr) bool) {
// Mock implementation - no-op
}
// SetFirewall mock implementation of SetFirewall from Server interface
func (m *MockServer) SetFirewall(Firewall) {
// Mock implementation - no-op
}
// BeginBatch mock implementation of BeginBatch from Server interface
func (m *MockServer) BeginBatch() {
// Mock implementation - no-op

View File

@@ -104,3 +104,23 @@ func (r *responseWriter) TsigTimersOnly(bool) {
// After a call to Hijack(), the DNS package will not do anything with the connection.
func (r *responseWriter) Hijack() {
}
// remoteAddrFromPacket extracts the source IP:port from a decoded packet for logging.
func remoteAddrFromPacket(packet gopacket.Packet) *net.UDPAddr {
var srcIP net.IP
if ipv4 := packet.Layer(layers.LayerTypeIPv4); ipv4 != nil {
srcIP = ipv4.(*layers.IPv4).SrcIP
} else if ipv6 := packet.Layer(layers.LayerTypeIPv6); ipv6 != nil {
srcIP = ipv6.(*layers.IPv6).SrcIP
}
var srcPort int
if udp := packet.Layer(layers.LayerTypeUDP); udp != nil {
srcPort = int(udp.(*layers.UDP).SrcPort)
}
if srcIP == nil {
return nil
}
return &net.UDPAddr{IP: srcIP, Port: srcPort}
}

View File

@@ -58,6 +58,7 @@ type Server interface {
UpdateServerConfig(domains dnsconfig.ServerDomains) error
PopulateManagementDomain(mgmtURL *url.URL) error
SetRouteChecker(func(netip.Addr) bool)
SetFirewall(Firewall)
}
type nsGroupsByDomain struct {
@@ -151,7 +152,7 @@ func NewDefaultServer(ctx context.Context, config DefaultServerConfig) (*Default
if config.WgInterface.IsUserspaceBind() {
dnsService = NewServiceViaMemory(config.WgInterface)
} else {
dnsService = newServiceViaListener(config.WgInterface, addrPort)
dnsService = newServiceViaListener(config.WgInterface, addrPort, nil)
}
server := newDefaultServer(ctx, config.WgInterface, dnsService, config.StatusRecorder, config.StateManager, config.DisableSys)
@@ -186,11 +187,16 @@ func NewDefaultServerIos(
ctx context.Context,
wgInterface WGIface,
iosDnsManager IosDnsManager,
hostsDnsList []netip.AddrPort,
statusRecorder *peer.Status,
disableSys bool,
) *DefaultServer {
log.Debugf("iOS host dns address list is: %v", hostsDnsList)
ds := newDefaultServer(ctx, wgInterface, NewServiceViaMemory(wgInterface), statusRecorder, nil, disableSys)
ds.iosDnsManager = iosDnsManager
ds.hostsDNSHolder.set(hostsDnsList)
ds.permanent = true
ds.addHostRootZone()
return ds
}
@@ -374,6 +380,17 @@ func (s *DefaultServer) DnsIP() netip.Addr {
return s.service.RuntimeIP()
}
// SetFirewall sets the firewall used for DNS port DNAT rules.
// This must be called before Initialize when using the listener-based service,
// because the firewall is typically not available at construction time.
func (s *DefaultServer) SetFirewall(fw Firewall) {
if svc, ok := s.service.(*serviceViaListener); ok {
svc.listenerFlagLock.Lock()
svc.firewall = fw
svc.listenerFlagLock.Unlock()
}
}
// Stop stops the server
func (s *DefaultServer) Stop() {
s.probeMu.Lock()
@@ -395,8 +412,12 @@ func (s *DefaultServer) Stop() {
maps.Clear(s.extraDomains)
}
func (s *DefaultServer) disableDNS() error {
defer s.service.Stop()
func (s *DefaultServer) disableDNS() (retErr error) {
defer func() {
if err := s.service.Stop(); err != nil {
retErr = errors.Join(retErr, fmt.Errorf("stop DNS service: %w", err))
}
}()
if s.isUsingNoopHostManager() {
return nil

View File

@@ -476,8 +476,8 @@ func TestDNSFakeResolverHandleUpdates(t *testing.T) {
packetfilter := pfmock.NewMockPacketFilter(ctrl)
packetfilter.EXPECT().FilterOutbound(gomock.Any(), gomock.Any()).AnyTimes()
packetfilter.EXPECT().AddUDPPacketHook(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
packetfilter.EXPECT().RemovePacketHook(gomock.Any())
packetfilter.EXPECT().SetUDPPacketHook(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes()
packetfilter.EXPECT().SetTCPPacketHook(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes()
if err := wgIface.SetFilter(packetfilter); err != nil {
t.Errorf("set packet filter: %v", err)
@@ -1071,7 +1071,7 @@ func (m *mockHandler) ID() types.HandlerID { return types.Hand
type mockService struct{}
func (m *mockService) Listen() error { return nil }
func (m *mockService) Stop() {}
func (m *mockService) Stop() error { return nil }
func (m *mockService) RuntimeIP() netip.Addr { return netip.MustParseAddr("127.0.0.1") }
func (m *mockService) RuntimePort() int { return 53 }
func (m *mockService) RegisterMux(string, dns.Handler) {}

View File

@@ -4,15 +4,25 @@ import (
"net/netip"
"github.com/miekg/dns"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
)
const (
DefaultPort = 53
)
// Firewall provides DNAT capabilities for DNS port redirection.
// This is used when the DNS server cannot bind port 53 directly
// and needs firewall rules to redirect traffic.
type Firewall interface {
AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error
RemoveOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error
}
type service interface {
Listen() error
Stop()
Stop() error
RegisterMux(domain string, handler dns.Handler)
DeregisterMux(key string)
RuntimePort() int

View File

@@ -10,9 +10,13 @@ import (
"sync"
"time"
"github.com/hashicorp/go-multierror"
"github.com/miekg/dns"
log "github.com/sirupsen/logrus"
nberrors "github.com/netbirdio/netbird/client/errors"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/internal/ebpf"
ebpfMgr "github.com/netbirdio/netbird/client/internal/ebpf/manager"
)
@@ -31,25 +35,33 @@ type serviceViaListener struct {
dnsMux *dns.ServeMux
customAddr *netip.AddrPort
server *dns.Server
tcpServer *dns.Server
listenIP netip.Addr
listenPort uint16
listenerIsRunning bool
listenerFlagLock sync.Mutex
ebpfService ebpfMgr.Manager
firewall Firewall
tcpDNATConfigured bool
}
func newServiceViaListener(wgIface WGIface, customAddr *netip.AddrPort) *serviceViaListener {
func newServiceViaListener(wgIface WGIface, customAddr *netip.AddrPort, fw Firewall) *serviceViaListener {
mux := dns.NewServeMux()
s := &serviceViaListener{
wgInterface: wgIface,
dnsMux: mux,
customAddr: customAddr,
firewall: fw,
server: &dns.Server{
Net: "udp",
Handler: mux,
UDPSize: 65535,
},
tcpServer: &dns.Server{
Net: "tcp",
Handler: mux,
},
}
return s
@@ -70,43 +82,86 @@ func (s *serviceViaListener) Listen() error {
return fmt.Errorf("eval listen address: %w", err)
}
s.listenIP = s.listenIP.Unmap()
s.server.Addr = net.JoinHostPort(s.listenIP.String(), strconv.Itoa(int(s.listenPort)))
log.Debugf("starting dns on %s", s.server.Addr)
go func() {
s.setListenerStatus(true)
defer s.setListenerStatus(false)
addr := net.JoinHostPort(s.listenIP.String(), strconv.Itoa(int(s.listenPort)))
s.server.Addr = addr
s.tcpServer.Addr = addr
err := s.server.ListenAndServe()
if err != nil {
log.Errorf("dns server running with %d port returned an error: %v. Will not retry", s.listenPort, err)
log.Debugf("starting dns on %s (UDP + TCP)", addr)
s.listenerIsRunning = true
go func() {
if err := s.server.ListenAndServe(); err != nil {
log.Errorf("failed to run DNS UDP server on port %d: %v", s.listenPort, err)
}
s.listenerFlagLock.Lock()
unexpected := s.listenerIsRunning
s.listenerIsRunning = false
s.listenerFlagLock.Unlock()
if unexpected {
if err := s.tcpServer.Shutdown(); err != nil {
log.Debugf("failed to shutdown DNS TCP server: %v", err)
}
}
}()
go func() {
if err := s.tcpServer.ListenAndServe(); err != nil {
log.Errorf("failed to run DNS TCP server on port %d: %v", s.listenPort, err)
}
}()
// When eBPF redirects UDP port 53 to our listen port, TCP still needs
// a DNAT rule because eBPF only handles UDP.
if s.ebpfService != nil && s.firewall != nil && s.listenPort != DefaultPort {
if err := s.firewall.AddOutputDNAT(s.listenIP, firewall.ProtocolTCP, DefaultPort, s.listenPort); err != nil {
log.Warnf("failed to add DNS TCP DNAT rule, TCP DNS on port 53 will not work: %v", err)
} else {
s.tcpDNATConfigured = true
log.Infof("added DNS TCP DNAT rule: %s:%d -> %s:%d", s.listenIP, DefaultPort, s.listenIP, s.listenPort)
}
}
return nil
}
func (s *serviceViaListener) Stop() {
func (s *serviceViaListener) Stop() error {
s.listenerFlagLock.Lock()
defer s.listenerFlagLock.Unlock()
if !s.listenerIsRunning {
return
return nil
}
s.listenerIsRunning = false
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
err := s.server.ShutdownContext(ctx)
if err != nil {
log.Errorf("stopping dns server listener returned an error: %v", err)
var merr *multierror.Error
if err := s.server.ShutdownContext(ctx); err != nil {
merr = multierror.Append(merr, fmt.Errorf("stop DNS UDP server: %w", err))
}
if err := s.tcpServer.ShutdownContext(ctx); err != nil {
merr = multierror.Append(merr, fmt.Errorf("stop DNS TCP server: %w", err))
}
if s.tcpDNATConfigured && s.firewall != nil {
if err := s.firewall.RemoveOutputDNAT(s.listenIP, firewall.ProtocolTCP, DefaultPort, s.listenPort); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove DNS TCP DNAT rule: %w", err))
}
s.tcpDNATConfigured = false
}
if s.ebpfService != nil {
err = s.ebpfService.FreeDNSFwd()
if err != nil {
log.Errorf("stopping traffic forwarder returned an error: %v", err)
if err := s.ebpfService.FreeDNSFwd(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("stop traffic forwarder: %w", err))
}
}
return nberrors.FormatErrorOrNil(merr)
}
func (s *serviceViaListener) RegisterMux(pattern string, handler dns.Handler) {
@@ -133,12 +188,6 @@ func (s *serviceViaListener) RuntimeIP() netip.Addr {
return s.listenIP
}
func (s *serviceViaListener) setListenerStatus(running bool) {
s.listenerFlagLock.Lock()
defer s.listenerFlagLock.Unlock()
s.listenerIsRunning = running
}
// evalListenAddress figure out the listen address for the DNS server
// first check the 53 port availability on WG interface or lo, if not success
@@ -187,18 +236,28 @@ func (s *serviceViaListener) testFreePort(port int) (netip.Addr, bool) {
}
func (s *serviceViaListener) tryToBind(ip netip.Addr, port int) bool {
addrString := net.JoinHostPort(ip.String(), strconv.Itoa(port))
udpAddr := net.UDPAddrFromAddrPort(netip.MustParseAddrPort(addrString))
probeListener, err := net.ListenUDP("udp", udpAddr)
addrPort := netip.AddrPortFrom(ip, uint16(port))
udpAddr := net.UDPAddrFromAddrPort(addrPort)
udpLn, err := net.ListenUDP("udp", udpAddr)
if err != nil {
log.Warnf("binding dns on %s is not available, error: %s", addrString, err)
log.Warnf("binding dns UDP on %s is not available: %s", addrPort, err)
return false
}
err = probeListener.Close()
if err != nil {
log.Errorf("got an error closing the probe listener, error: %s", err)
if err := udpLn.Close(); err != nil {
log.Debugf("close UDP probe listener: %s", err)
}
tcpAddr := net.TCPAddrFromAddrPort(addrPort)
tcpLn, err := net.ListenTCP("tcp", tcpAddr)
if err != nil {
log.Warnf("binding dns TCP on %s is not available: %s", addrPort, err)
return false
}
if err := tcpLn.Close(); err != nil {
log.Debugf("close TCP probe listener: %s", err)
}
return true
}

View File

@@ -0,0 +1,86 @@
package dns
import (
"fmt"
"net"
"net/netip"
"testing"
"time"
"github.com/miekg/dns"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestServiceViaListener_TCPAndUDP(t *testing.T) {
handler := dns.HandlerFunc(func(w dns.ResponseWriter, r *dns.Msg) {
m := new(dns.Msg)
m.SetReply(r)
m.Answer = append(m.Answer, &dns.A{
Hdr: dns.RR_Header{Name: r.Question[0].Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
A: net.ParseIP("192.0.2.1"),
})
if err := w.WriteMsg(m); err != nil {
t.Logf("write msg: %v", err)
}
})
// Create a service using a custom address to avoid needing root
svc := newServiceViaListener(nil, nil, nil)
svc.dnsMux.Handle(".", handler)
// Bind both transports up front to avoid TOCTOU races.
udpAddr := net.UDPAddrFromAddrPort(netip.AddrPortFrom(customIP, 0))
udpConn, err := net.ListenUDP("udp", udpAddr)
if err != nil {
t.Skip("cannot bind to 127.0.0.153, skipping")
}
port := uint16(udpConn.LocalAddr().(*net.UDPAddr).Port)
tcpAddr := net.TCPAddrFromAddrPort(netip.AddrPortFrom(customIP, port))
tcpLn, err := net.ListenTCP("tcp", tcpAddr)
if err != nil {
udpConn.Close()
t.Skip("cannot bind TCP on same port, skipping")
}
addr := fmt.Sprintf("%s:%d", customIP, port)
svc.server.PacketConn = udpConn
svc.tcpServer.Listener = tcpLn
svc.listenIP = customIP
svc.listenPort = port
go func() {
if err := svc.server.ActivateAndServe(); err != nil {
t.Logf("udp server: %v", err)
}
}()
go func() {
if err := svc.tcpServer.ActivateAndServe(); err != nil {
t.Logf("tcp server: %v", err)
}
}()
svc.listenerIsRunning = true
defer func() {
require.NoError(t, svc.Stop())
}()
q := new(dns.Msg).SetQuestion("example.com.", dns.TypeA)
// Test UDP query
udpClient := &dns.Client{Net: "udp", Timeout: 2 * time.Second}
udpResp, _, err := udpClient.Exchange(q, addr)
require.NoError(t, err, "UDP query should succeed")
require.NotNil(t, udpResp)
require.NotEmpty(t, udpResp.Answer)
assert.Contains(t, udpResp.Answer[0].String(), "192.0.2.1", "UDP response should contain expected IP")
// Test TCP query
tcpClient := &dns.Client{Net: "tcp", Timeout: 2 * time.Second}
tcpResp, _, err := tcpClient.Exchange(q, addr)
require.NoError(t, err, "TCP query should succeed")
require.NotNil(t, tcpResp)
require.NotEmpty(t, tcpResp.Answer)
assert.Contains(t, tcpResp.Answer[0].String(), "192.0.2.1", "TCP response should contain expected IP")
}

View File

@@ -1,6 +1,7 @@
package dns
import (
"errors"
"fmt"
"net/netip"
"sync"
@@ -10,6 +11,7 @@ import (
"github.com/miekg/dns"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/iface"
nbnet "github.com/netbirdio/netbird/client/net"
)
@@ -18,7 +20,8 @@ type ServiceViaMemory struct {
dnsMux *dns.ServeMux
runtimeIP netip.Addr
runtimePort int
udpFilterHookID string
tcpDNS *tcpDNSServer
tcpHookSet bool
listenerIsRunning bool
listenerFlagLock sync.Mutex
}
@@ -28,14 +31,13 @@ func NewServiceViaMemory(wgIface WGIface) *ServiceViaMemory {
if err != nil {
log.Errorf("get last ip from network: %v", err)
}
s := &ServiceViaMemory{
return &ServiceViaMemory{
wgInterface: wgIface,
dnsMux: dns.NewServeMux(),
runtimeIP: lastIP,
runtimePort: DefaultPort,
}
return s
}
func (s *ServiceViaMemory) Listen() error {
@@ -46,10 +48,8 @@ func (s *ServiceViaMemory) Listen() error {
return nil
}
var err error
s.udpFilterHookID, err = s.filterDNSTraffic()
if err != nil {
return fmt.Errorf("filter dns traffice: %w", err)
if err := s.filterDNSTraffic(); err != nil {
return fmt.Errorf("filter dns traffic: %w", err)
}
s.listenerIsRunning = true
@@ -57,19 +57,29 @@ func (s *ServiceViaMemory) Listen() error {
return nil
}
func (s *ServiceViaMemory) Stop() {
func (s *ServiceViaMemory) Stop() error {
s.listenerFlagLock.Lock()
defer s.listenerFlagLock.Unlock()
if !s.listenerIsRunning {
return
return nil
}
if err := s.wgInterface.GetFilter().RemovePacketHook(s.udpFilterHookID); err != nil {
log.Errorf("unable to remove DNS packet hook: %s", err)
filter := s.wgInterface.GetFilter()
if filter != nil {
filter.SetUDPPacketHook(s.runtimeIP, uint16(s.runtimePort), nil)
if s.tcpHookSet {
filter.SetTCPPacketHook(s.runtimeIP, uint16(s.runtimePort), nil)
}
}
if s.tcpDNS != nil {
s.tcpDNS.Stop()
}
s.listenerIsRunning = false
return nil
}
func (s *ServiceViaMemory) RegisterMux(pattern string, handler dns.Handler) {
@@ -88,10 +98,18 @@ func (s *ServiceViaMemory) RuntimeIP() netip.Addr {
return s.runtimeIP
}
func (s *ServiceViaMemory) filterDNSTraffic() (string, error) {
func (s *ServiceViaMemory) filterDNSTraffic() error {
filter := s.wgInterface.GetFilter()
if filter == nil {
return "", fmt.Errorf("can't set DNS filter, filter not initialized")
return errors.New("DNS filter not initialized")
}
// Create TCP DNS server lazily here since the device may not exist at construction time.
if s.tcpDNS == nil {
if dev := s.wgInterface.GetDevice(); dev != nil {
// MTU only affects TCP segment sizing; DNS messages are small so this has no practical impact.
s.tcpDNS = newTCPDNSServer(s.dnsMux, dev.Device, s.runtimeIP, uint16(s.runtimePort), iface.DefaultMTU)
}
}
firstLayerDecoder := layers.LayerTypeIPv4
@@ -100,12 +118,16 @@ func (s *ServiceViaMemory) filterDNSTraffic() (string, error) {
}
hook := func(packetData []byte) bool {
// Decode the packet
packet := gopacket.NewPacket(packetData, firstLayerDecoder, gopacket.Default)
// Get the UDP layer
udpLayer := packet.Layer(layers.LayerTypeUDP)
udp := udpLayer.(*layers.UDP)
if udpLayer == nil {
return true
}
udp, ok := udpLayer.(*layers.UDP)
if !ok {
return true
}
msg := new(dns.Msg)
if err := msg.Unpack(udp.Payload); err != nil {
@@ -113,13 +135,30 @@ func (s *ServiceViaMemory) filterDNSTraffic() (string, error) {
return true
}
writer := responseWriter{
packet: packet,
device: s.wgInterface.GetDevice().Device,
dev := s.wgInterface.GetDevice()
if dev == nil {
return true
}
go s.dnsMux.ServeDNS(&writer, msg)
writer := &responseWriter{
remote: remoteAddrFromPacket(packet),
packet: packet,
device: dev.Device,
}
go s.dnsMux.ServeDNS(writer, msg)
return true
}
return filter.AddUDPPacketHook(false, s.runtimeIP, uint16(s.runtimePort), hook), nil
filter.SetUDPPacketHook(s.runtimeIP, uint16(s.runtimePort), hook)
if s.tcpDNS != nil {
tcpHook := func(packetData []byte) bool {
s.tcpDNS.InjectPacket(packetData)
return true
}
filter.SetTCPPacketHook(s.runtimeIP, uint16(s.runtimePort), tcpHook)
s.tcpHookSet = true
}
return nil
}

View File

@@ -0,0 +1,444 @@
package dns
import (
"errors"
"fmt"
"io"
"net"
"net/netip"
"sync"
"sync/atomic"
"time"
"github.com/miekg/dns"
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/tun"
"gvisor.dev/gvisor/pkg/buffer"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
"gvisor.dev/gvisor/pkg/waiter"
)
const (
dnsTCPReceiveWindow = 8192
dnsTCPMaxInFlight = 16
dnsTCPIdleTimeout = 30 * time.Second
dnsTCPReadTimeout = 5 * time.Second
)
// tcpDNSServer is an on-demand TCP DNS server backed by a minimal gvisor stack.
// It is started lazily when a truncated DNS response is detected and shuts down
// after a period of inactivity to conserve resources.
type tcpDNSServer struct {
mu sync.Mutex
s *stack.Stack
ep *dnsEndpoint
mux *dns.ServeMux
tunDev tun.Device
ip netip.Addr
port uint16
mtu uint16
running bool
closed bool
timerID uint64
timer *time.Timer
}
func newTCPDNSServer(mux *dns.ServeMux, tunDev tun.Device, ip netip.Addr, port uint16, mtu uint16) *tcpDNSServer {
return &tcpDNSServer{
mux: mux,
tunDev: tunDev,
ip: ip,
port: port,
mtu: mtu,
}
}
// InjectPacket ensures the stack is running and delivers a raw IP packet into
// the gvisor stack for TCP processing. Combining both operations under a single
// lock prevents a race where the idle timer could stop the stack between
// start and delivery.
func (t *tcpDNSServer) InjectPacket(payload []byte) {
t.mu.Lock()
defer t.mu.Unlock()
if t.closed {
return
}
if !t.running {
if err := t.startLocked(); err != nil {
log.Errorf("failed to start TCP DNS stack: %v", err)
return
}
t.running = true
log.Debugf("TCP DNS stack started on %s:%d (triggered by %s)", t.ip, t.port, srcAddrFromPacket(payload))
}
t.resetTimerLocked()
ep := t.ep
if ep == nil || ep.dispatcher == nil {
return
}
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
Payload: buffer.MakeWithData(payload),
})
// DeliverNetworkPacket takes ownership of the packet buffer; do not DecRef.
ep.dispatcher.DeliverNetworkPacket(ipv4.ProtocolNumber, pkt)
}
// Stop tears down the gvisor stack and releases resources permanently.
// After Stop, InjectPacket becomes a no-op.
func (t *tcpDNSServer) Stop() {
t.mu.Lock()
defer t.mu.Unlock()
t.stopLocked()
t.closed = true
}
func (t *tcpDNSServer) startLocked() error {
// TODO: add ipv6.NewProtocol when IPv6 overlay support lands.
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol},
TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol},
HandleLocal: false,
})
nicID := tcpip.NICID(1)
ep := &dnsEndpoint{
tunDev: t.tunDev,
}
ep.mtu.Store(uint32(t.mtu))
if err := s.CreateNIC(nicID, ep); err != nil {
s.Close()
s.Wait()
return fmt.Errorf("create NIC: %v", err)
}
protoAddr := tcpip.ProtocolAddress{
Protocol: ipv4.ProtocolNumber,
AddressWithPrefix: tcpip.AddressWithPrefix{
Address: tcpip.AddrFromSlice(t.ip.AsSlice()),
PrefixLen: 32,
},
}
if err := s.AddProtocolAddress(nicID, protoAddr, stack.AddressProperties{}); err != nil {
s.Close()
s.Wait()
return fmt.Errorf("add protocol address: %s", err)
}
if err := s.SetPromiscuousMode(nicID, true); err != nil {
s.Close()
s.Wait()
return fmt.Errorf("set promiscuous mode: %s", err)
}
if err := s.SetSpoofing(nicID, true); err != nil {
s.Close()
s.Wait()
return fmt.Errorf("set spoofing: %s", err)
}
defaultSubnet, err := tcpip.NewSubnet(
tcpip.AddrFrom4([4]byte{0, 0, 0, 0}),
tcpip.MaskFromBytes([]byte{0, 0, 0, 0}),
)
if err != nil {
s.Close()
s.Wait()
return fmt.Errorf("create default subnet: %w", err)
}
s.SetRouteTable([]tcpip.Route{
{Destination: defaultSubnet, NIC: nicID},
})
tcpFwd := tcp.NewForwarder(s, dnsTCPReceiveWindow, dnsTCPMaxInFlight, func(r *tcp.ForwarderRequest) {
t.handleTCPDNS(r)
})
s.SetTransportProtocolHandler(tcp.ProtocolNumber, tcpFwd.HandlePacket)
t.s = s
t.ep = ep
return nil
}
func (t *tcpDNSServer) stopLocked() {
if !t.running {
return
}
if t.timer != nil {
t.timer.Stop()
t.timer = nil
}
if t.s != nil {
t.s.Close()
t.s.Wait()
t.s = nil
}
t.ep = nil
t.running = false
log.Debugf("TCP DNS stack stopped")
}
func (t *tcpDNSServer) resetTimerLocked() {
if t.timer != nil {
t.timer.Stop()
}
t.timerID++
id := t.timerID
t.timer = time.AfterFunc(dnsTCPIdleTimeout, func() {
t.mu.Lock()
defer t.mu.Unlock()
// Only stop if this timer is still the active one.
// A racing InjectPacket may have replaced it.
if t.timerID != id {
return
}
t.stopLocked()
})
}
func (t *tcpDNSServer) handleTCPDNS(r *tcp.ForwarderRequest) {
id := r.ID()
wq := waiter.Queue{}
ep, epErr := r.CreateEndpoint(&wq)
if epErr != nil {
log.Debugf("TCP DNS: failed to create endpoint: %v", epErr)
r.Complete(true)
return
}
r.Complete(false)
conn := gonet.NewTCPConn(&wq, ep)
defer func() {
if err := conn.Close(); err != nil {
log.Tracef("TCP DNS: close conn: %v", err)
}
}()
// Reset idle timer on activity
t.mu.Lock()
t.resetTimerLocked()
t.mu.Unlock()
localAddr := &net.TCPAddr{
IP: id.LocalAddress.AsSlice(),
Port: int(id.LocalPort),
}
remoteAddr := &net.TCPAddr{
IP: id.RemoteAddress.AsSlice(),
Port: int(id.RemotePort),
}
for {
if err := conn.SetReadDeadline(time.Now().Add(dnsTCPReadTimeout)); err != nil {
log.Debugf("TCP DNS: set deadline for %s: %v", remoteAddr, err)
break
}
msg, err := readTCPDNSMessage(conn)
if err != nil {
if !errors.Is(err, io.EOF) && !errors.Is(err, io.ErrUnexpectedEOF) {
log.Debugf("TCP DNS: read from %s: %v", remoteAddr, err)
}
break
}
writer := &tcpResponseWriter{
conn: conn,
localAddr: localAddr,
remoteAddr: remoteAddr,
}
t.mux.ServeDNS(writer, msg)
}
}
// dnsEndpoint implements stack.LinkEndpoint for writing packets back via the tun device.
type dnsEndpoint struct {
dispatcher stack.NetworkDispatcher
tunDev tun.Device
mtu atomic.Uint32
}
func (e *dnsEndpoint) Attach(dispatcher stack.NetworkDispatcher) { e.dispatcher = dispatcher }
func (e *dnsEndpoint) IsAttached() bool { return e.dispatcher != nil }
func (e *dnsEndpoint) MTU() uint32 { return e.mtu.Load() }
func (e *dnsEndpoint) Capabilities() stack.LinkEndpointCapabilities { return stack.CapabilityNone }
func (e *dnsEndpoint) MaxHeaderLength() uint16 { return 0 }
func (e *dnsEndpoint) LinkAddress() tcpip.LinkAddress { return "" }
func (e *dnsEndpoint) Wait() { /* no async work */ }
func (e *dnsEndpoint) ARPHardwareType() header.ARPHardwareType { return header.ARPHardwareNone }
func (e *dnsEndpoint) AddHeader(*stack.PacketBuffer) { /* IP-level endpoint, no link header */ }
func (e *dnsEndpoint) ParseHeader(*stack.PacketBuffer) bool { return true }
func (e *dnsEndpoint) Close() { /* lifecycle managed by tcpDNSServer */ }
func (e *dnsEndpoint) SetLinkAddress(tcpip.LinkAddress) { /* no link address for tun */ }
func (e *dnsEndpoint) SetMTU(mtu uint32) { e.mtu.Store(mtu) }
func (e *dnsEndpoint) SetOnCloseAction(func()) { /* not needed */ }
const tunPacketOffset = 40
func (e *dnsEndpoint) WritePackets(pkts stack.PacketBufferList) (int, tcpip.Error) {
var written int
for _, pkt := range pkts.AsSlice() {
data := stack.PayloadSince(pkt.NetworkHeader())
if data == nil {
continue
}
raw := data.AsSlice()
buf := make([]byte, tunPacketOffset, tunPacketOffset+len(raw))
buf = append(buf, raw...)
data.Release()
if _, err := e.tunDev.Write([][]byte{buf}, tunPacketOffset); err != nil {
log.Tracef("TCP DNS endpoint: failed to write packet: %v", err)
continue
}
written++
}
return written, nil
}
// tcpResponseWriter implements dns.ResponseWriter for TCP DNS connections.
type tcpResponseWriter struct {
conn *gonet.TCPConn
localAddr net.Addr
remoteAddr net.Addr
}
func (w *tcpResponseWriter) LocalAddr() net.Addr {
return w.localAddr
}
func (w *tcpResponseWriter) RemoteAddr() net.Addr {
return w.remoteAddr
}
func (w *tcpResponseWriter) WriteMsg(msg *dns.Msg) error {
data, err := msg.Pack()
if err != nil {
return fmt.Errorf("pack: %w", err)
}
// DNS TCP: 2-byte length prefix + message
buf := make([]byte, 2+len(data))
buf[0] = byte(len(data) >> 8)
buf[1] = byte(len(data))
copy(buf[2:], data)
if _, err = w.conn.Write(buf); err != nil {
return err
}
return nil
}
func (w *tcpResponseWriter) Write(data []byte) (int, error) {
buf := make([]byte, 2+len(data))
buf[0] = byte(len(data) >> 8)
buf[1] = byte(len(data))
copy(buf[2:], data)
if _, err := w.conn.Write(buf); err != nil {
return 0, err
}
return len(data), nil
}
func (w *tcpResponseWriter) Close() error {
return w.conn.Close()
}
func (w *tcpResponseWriter) TsigStatus() error { return nil }
func (w *tcpResponseWriter) TsigTimersOnly(bool) { /* TSIG not supported */ }
func (w *tcpResponseWriter) Hijack() { /* not supported */ }
// readTCPDNSMessage reads a single DNS message from a TCP connection (length-prefixed).
func readTCPDNSMessage(conn *gonet.TCPConn) (*dns.Msg, error) {
// DNS over TCP uses a 2-byte length prefix
lenBuf := make([]byte, 2)
if _, err := io.ReadFull(conn, lenBuf); err != nil {
return nil, fmt.Errorf("read length: %w", err)
}
msgLen := int(lenBuf[0])<<8 | int(lenBuf[1])
if msgLen == 0 || msgLen > 65535 {
return nil, fmt.Errorf("invalid message length: %d", msgLen)
}
msgBuf := make([]byte, msgLen)
if _, err := io.ReadFull(conn, msgBuf); err != nil {
return nil, fmt.Errorf("read message: %w", err)
}
msg := new(dns.Msg)
if err := msg.Unpack(msgBuf); err != nil {
return nil, fmt.Errorf("unpack: %w", err)
}
return msg, nil
}
// srcAddrFromPacket extracts the source IP:port from a raw IP+TCP packet for logging.
// Supports both IPv4 and IPv6.
func srcAddrFromPacket(pkt []byte) netip.AddrPort {
if len(pkt) == 0 {
return netip.AddrPort{}
}
srcIP, transportOffset := srcIPFromPacket(pkt)
if !srcIP.IsValid() || len(pkt) < transportOffset+2 {
return netip.AddrPort{}
}
srcPort := uint16(pkt[transportOffset])<<8 | uint16(pkt[transportOffset+1])
return netip.AddrPortFrom(srcIP.Unmap(), srcPort)
}
func srcIPFromPacket(pkt []byte) (netip.Addr, int) {
switch header.IPVersion(pkt) {
case 4:
return srcIPv4(pkt)
case 6:
return srcIPv6(pkt)
default:
return netip.Addr{}, 0
}
}
func srcIPv4(pkt []byte) (netip.Addr, int) {
if len(pkt) < header.IPv4MinimumSize {
return netip.Addr{}, 0
}
hdr := header.IPv4(pkt)
src := hdr.SourceAddress()
ip, ok := netip.AddrFromSlice(src.AsSlice())
if !ok {
return netip.Addr{}, 0
}
return ip, int(hdr.HeaderLength())
}
func srcIPv6(pkt []byte) (netip.Addr, int) {
if len(pkt) < header.IPv6MinimumSize {
return netip.Addr{}, 0
}
hdr := header.IPv6(pkt)
src := hdr.SourceAddress()
ip, ok := netip.AddrFromSlice(src.AsSlice())
if !ok {
return netip.Addr{}, 0
}
return ip, header.IPv6MinimumSize
}

View File

@@ -41,10 +41,61 @@ const (
reactivatePeriod = 30 * time.Second
probeTimeout = 2 * time.Second
// ipv6HeaderSize + udpHeaderSize, used to derive the maximum DNS UDP
// payload from the tunnel MTU.
ipUDPHeaderSize = 60 + 8
)
const testRecord = "com."
const (
protoUDP = "udp"
protoTCP = "tcp"
)
type dnsProtocolKey struct{}
// contextWithDNSProtocol stores the inbound DNS protocol ("udp" or "tcp") in context.
func contextWithDNSProtocol(ctx context.Context, network string) context.Context {
return context.WithValue(ctx, dnsProtocolKey{}, network)
}
// dnsProtocolFromContext retrieves the inbound DNS protocol from context.
func dnsProtocolFromContext(ctx context.Context) string {
if ctx == nil {
return ""
}
if v, ok := ctx.Value(dnsProtocolKey{}).(string); ok {
return v
}
return ""
}
type upstreamProtocolKey struct{}
// upstreamProtocolResult holds the protocol used for the upstream exchange.
// Stored as a pointer in context so the exchange function can set it.
type upstreamProtocolResult struct {
protocol string
}
// contextWithupstreamProtocolResult stores a mutable result holder in the context.
func contextWithupstreamProtocolResult(ctx context.Context) (context.Context, *upstreamProtocolResult) {
r := &upstreamProtocolResult{}
return context.WithValue(ctx, upstreamProtocolKey{}, r), r
}
// setUpstreamProtocol sets the upstream protocol on the result holder in context, if present.
func setUpstreamProtocol(ctx context.Context, protocol string) {
if ctx == nil {
return
}
if r, ok := ctx.Value(upstreamProtocolKey{}).(*upstreamProtocolResult); ok && r != nil {
r.protocol = protocol
}
}
type upstreamClient interface {
exchange(ctx context.Context, upstream string, r *dns.Msg) (*dns.Msg, time.Duration, error)
}
@@ -138,7 +189,16 @@ func (u *upstreamResolverBase) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
return
}
ok, failures := u.tryUpstreamServers(w, r, logger)
// Propagate inbound protocol so upstream exchange can use TCP directly
// when the request came in over TCP.
ctx := u.ctx
if addr := w.RemoteAddr(); addr != nil {
network := addr.Network()
ctx = contextWithDNSProtocol(ctx, network)
resutil.SetMeta(w, "protocol", network)
}
ok, failures := u.tryUpstreamServers(ctx, w, r, logger)
if len(failures) > 0 {
u.logUpstreamFailures(r.Question[0].Name, failures, ok, logger)
}
@@ -153,7 +213,7 @@ func (u *upstreamResolverBase) prepareRequest(r *dns.Msg) {
}
}
func (u *upstreamResolverBase) tryUpstreamServers(w dns.ResponseWriter, r *dns.Msg, logger *log.Entry) (bool, []upstreamFailure) {
func (u *upstreamResolverBase) tryUpstreamServers(ctx context.Context, w dns.ResponseWriter, r *dns.Msg, logger *log.Entry) (bool, []upstreamFailure) {
timeout := u.upstreamTimeout
if len(u.upstreamServers) > 1 {
maxTotal := 5 * time.Second
@@ -168,7 +228,7 @@ func (u *upstreamResolverBase) tryUpstreamServers(w dns.ResponseWriter, r *dns.M
var failures []upstreamFailure
for _, upstream := range u.upstreamServers {
if failure := u.queryUpstream(w, r, upstream, timeout, logger); failure != nil {
if failure := u.queryUpstream(ctx, w, r, upstream, timeout, logger); failure != nil {
failures = append(failures, *failure)
} else {
return true, failures
@@ -178,15 +238,17 @@ func (u *upstreamResolverBase) tryUpstreamServers(w dns.ResponseWriter, r *dns.M
}
// queryUpstream queries a single upstream server. Returns nil on success, or failure info to try next upstream.
func (u *upstreamResolverBase) queryUpstream(w dns.ResponseWriter, r *dns.Msg, upstream netip.AddrPort, timeout time.Duration, logger *log.Entry) *upstreamFailure {
func (u *upstreamResolverBase) queryUpstream(parentCtx context.Context, w dns.ResponseWriter, r *dns.Msg, upstream netip.AddrPort, timeout time.Duration, logger *log.Entry) *upstreamFailure {
var rm *dns.Msg
var t time.Duration
var err error
var startTime time.Time
var upstreamProto *upstreamProtocolResult
func() {
ctx, cancel := context.WithTimeout(u.ctx, timeout)
ctx, cancel := context.WithTimeout(parentCtx, timeout)
defer cancel()
ctx, upstreamProto = contextWithupstreamProtocolResult(ctx)
startTime = time.Now()
rm, t, err = u.upstreamClient.exchange(ctx, upstream.String(), r)
}()
@@ -203,7 +265,7 @@ func (u *upstreamResolverBase) queryUpstream(w dns.ResponseWriter, r *dns.Msg, u
return &upstreamFailure{upstream: upstream, reason: dns.RcodeToString[rm.Rcode]}
}
u.writeSuccessResponse(w, rm, upstream, r.Question[0].Name, t, logger)
u.writeSuccessResponse(w, rm, upstream, r.Question[0].Name, t, upstreamProto, logger)
return nil
}
@@ -220,10 +282,13 @@ func (u *upstreamResolverBase) handleUpstreamError(err error, upstream netip.Add
return &upstreamFailure{upstream: upstream, reason: reason}
}
func (u *upstreamResolverBase) writeSuccessResponse(w dns.ResponseWriter, rm *dns.Msg, upstream netip.AddrPort, domain string, t time.Duration, logger *log.Entry) bool {
func (u *upstreamResolverBase) writeSuccessResponse(w dns.ResponseWriter, rm *dns.Msg, upstream netip.AddrPort, domain string, t time.Duration, upstreamProto *upstreamProtocolResult, logger *log.Entry) bool {
u.successCount.Add(1)
resutil.SetMeta(w, "upstream", upstream.String())
if upstreamProto != nil && upstreamProto.protocol != "" {
resutil.SetMeta(w, "upstream_protocol", upstreamProto.protocol)
}
// Clear Zero bit from external responses to prevent upstream servers from
// manipulating our internal fallthrough signaling mechanism
@@ -428,13 +493,42 @@ func (u *upstreamResolverBase) testNameserver(baseCtx context.Context, externalC
return err
}
// clientUDPMaxSize returns the maximum UDP response size the client accepts.
func clientUDPMaxSize(r *dns.Msg) int {
if opt := r.IsEdns0(); opt != nil {
return int(opt.UDPSize())
}
return dns.MinMsgSize
}
// ExchangeWithFallback exchanges a DNS message with the upstream server.
// It first tries to use UDP, and if it is truncated, it falls back to TCP.
// If the inbound request came over TCP (via context), it skips the UDP attempt.
// If the passed context is nil, this will use Exchange instead of ExchangeContext.
func ExchangeWithFallback(ctx context.Context, client *dns.Client, r *dns.Msg, upstream string) (*dns.Msg, time.Duration, error) {
// MTU - ip + udp headers
// Note: this could be sent out on an interface that is not ours, but higher MTU settings could break truncation handling.
client.UDPSize = uint16(currentMTU - (60 + 8))
// If the request came in over TCP, go straight to TCP upstream.
if dnsProtocolFromContext(ctx) == protoTCP {
tcpClient := *client
tcpClient.Net = protoTCP
rm, t, err := tcpClient.ExchangeContext(ctx, r, upstream)
if err != nil {
return nil, t, fmt.Errorf("with tcp: %w", err)
}
setUpstreamProtocol(ctx, protoTCP)
return rm, t, nil
}
clientMaxSize := clientUDPMaxSize(r)
// Cap EDNS0 to our tunnel MTU so the upstream doesn't send a
// response larger than our read buffer.
// Note: the query could be sent out on an interface that is not ours,
// but higher MTU settings could break truncation handling.
maxUDPPayload := uint16(currentMTU - ipUDPHeaderSize)
client.UDPSize = maxUDPPayload
if opt := r.IsEdns0(); opt != nil && opt.UDPSize() > maxUDPPayload {
opt.SetUDPSize(maxUDPPayload)
}
var (
rm *dns.Msg
@@ -453,25 +547,32 @@ func ExchangeWithFallback(ctx context.Context, client *dns.Client, r *dns.Msg, u
}
if rm == nil || !rm.MsgHdr.Truncated {
setUpstreamProtocol(ctx, protoUDP)
return rm, t, nil
}
log.Tracef("udp response for domain=%s type=%v class=%v is truncated, trying TCP.",
r.Question[0].Name, r.Question[0].Qtype, r.Question[0].Qclass)
// TODO: if the upstream's truncated UDP response already contains more
// data than the client's buffer, we could truncate locally and skip
// the TCP retry.
client.Net = "tcp"
tcpClient := *client
tcpClient.Net = protoTCP
if ctx == nil {
rm, t, err = client.Exchange(r, upstream)
rm, t, err = tcpClient.Exchange(r, upstream)
} else {
rm, t, err = client.ExchangeContext(ctx, r, upstream)
rm, t, err = tcpClient.ExchangeContext(ctx, r, upstream)
}
if err != nil {
return nil, t, fmt.Errorf("with tcp: %w", err)
}
// TODO: once TCP is implemented, rm.Truncate() if the request came in over UDP
setUpstreamProtocol(ctx, protoTCP)
if rm.Len() > clientMaxSize {
rm.Truncate(clientMaxSize)
}
return rm, t, nil
}
@@ -479,18 +580,46 @@ func ExchangeWithFallback(ctx context.Context, client *dns.Client, r *dns.Msg, u
// ExchangeWithNetstack performs a DNS exchange using netstack for dialing.
// This is needed when netstack is enabled to reach peer IPs through the tunnel.
func ExchangeWithNetstack(ctx context.Context, nsNet *netstack.Net, r *dns.Msg, upstream string) (*dns.Msg, error) {
reply, err := netstackExchange(ctx, nsNet, r, upstream, "udp")
// If request came in over TCP, go straight to TCP upstream
if dnsProtocolFromContext(ctx) == protoTCP {
rm, err := netstackExchange(ctx, nsNet, r, upstream, protoTCP)
if err != nil {
return nil, err
}
setUpstreamProtocol(ctx, protoTCP)
return rm, nil
}
clientMaxSize := clientUDPMaxSize(r)
// Cap EDNS0 to our tunnel MTU so the upstream doesn't send a
// response larger than what we can read over UDP.
maxUDPPayload := uint16(currentMTU - ipUDPHeaderSize)
if opt := r.IsEdns0(); opt != nil && opt.UDPSize() > maxUDPPayload {
opt.SetUDPSize(maxUDPPayload)
}
reply, err := netstackExchange(ctx, nsNet, r, upstream, protoUDP)
if err != nil {
return nil, err
}
// If response is truncated, retry with TCP
if reply != nil && reply.MsgHdr.Truncated {
log.Tracef("udp response for domain=%s type=%v class=%v is truncated, trying TCP",
r.Question[0].Name, r.Question[0].Qtype, r.Question[0].Qclass)
return netstackExchange(ctx, nsNet, r, upstream, "tcp")
rm, err := netstackExchange(ctx, nsNet, r, upstream, protoTCP)
if err != nil {
return nil, err
}
setUpstreamProtocol(ctx, protoTCP)
if rm.Len() > clientMaxSize {
rm.Truncate(clientMaxSize)
}
return rm, nil
}
setUpstreamProtocol(ctx, protoUDP)
return reply, nil
}
@@ -511,7 +640,7 @@ func netstackExchange(ctx context.Context, nsNet *netstack.Net, r *dns.Msg, upst
}
}
dnsConn := &dns.Conn{Conn: conn}
dnsConn := &dns.Conn{Conn: conn, UDPSize: uint16(currentMTU - ipUDPHeaderSize)}
if err := dnsConn.WriteMsg(r); err != nil {
return nil, fmt.Errorf("write %s message: %w", network, err)

View File

@@ -51,7 +51,7 @@ func (u *upstreamResolver) exchangeWithinVPN(ctx context.Context, upstream strin
upstreamExchangeClient := &dns.Client{
Timeout: ClientTimeout,
}
return upstreamExchangeClient.ExchangeContext(ctx, r, upstream)
return ExchangeWithFallback(ctx, upstreamExchangeClient, r, upstream)
}
// exchangeWithoutVPN protect the UDP socket by Android SDK to avoid to goes through the VPN
@@ -76,7 +76,7 @@ func (u *upstreamResolver) exchangeWithoutVPN(ctx context.Context, upstream stri
Timeout: timeout,
}
return upstreamExchangeClient.ExchangeContext(ctx, r, upstream)
return ExchangeWithFallback(ctx, upstreamExchangeClient, r, upstream)
}
func (u *upstreamResolver) isLocalResolver(upstream string) bool {

View File

@@ -475,3 +475,298 @@ func TestFormatFailures(t *testing.T) {
})
}
}
func TestDNSProtocolContext(t *testing.T) {
t.Run("roundtrip udp", func(t *testing.T) {
ctx := contextWithDNSProtocol(context.Background(), protoUDP)
assert.Equal(t, protoUDP, dnsProtocolFromContext(ctx))
})
t.Run("roundtrip tcp", func(t *testing.T) {
ctx := contextWithDNSProtocol(context.Background(), protoTCP)
assert.Equal(t, protoTCP, dnsProtocolFromContext(ctx))
})
t.Run("missing returns empty", func(t *testing.T) {
assert.Equal(t, "", dnsProtocolFromContext(context.Background()))
})
}
func TestExchangeWithFallback_TCPContext(t *testing.T) {
// Start a local DNS server that responds on TCP only
tcpHandler := dns.HandlerFunc(func(w dns.ResponseWriter, r *dns.Msg) {
m := new(dns.Msg)
m.SetReply(r)
m.Answer = append(m.Answer, &dns.A{
Hdr: dns.RR_Header{Name: r.Question[0].Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
A: net.ParseIP("10.0.0.1"),
})
if err := w.WriteMsg(m); err != nil {
t.Logf("write msg: %v", err)
}
})
tcpServer := &dns.Server{
Addr: "127.0.0.1:0",
Net: "tcp",
Handler: tcpHandler,
}
tcpLn, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
tcpServer.Listener = tcpLn
go func() {
if err := tcpServer.ActivateAndServe(); err != nil {
t.Logf("tcp server: %v", err)
}
}()
defer func() {
_ = tcpServer.Shutdown()
}()
upstream := tcpLn.Addr().String()
// With TCP context, should connect directly via TCP without trying UDP
ctx := contextWithDNSProtocol(context.Background(), protoTCP)
client := &dns.Client{Timeout: 2 * time.Second}
r := new(dns.Msg).SetQuestion("example.com.", dns.TypeA)
rm, _, err := ExchangeWithFallback(ctx, client, r, upstream)
require.NoError(t, err)
require.NotNil(t, rm)
require.NotEmpty(t, rm.Answer)
assert.Contains(t, rm.Answer[0].String(), "10.0.0.1")
}
func TestExchangeWithFallback_UDPFallbackToTCP(t *testing.T) {
// UDP handler returns a truncated response to trigger TCP retry.
udpHandler := dns.HandlerFunc(func(w dns.ResponseWriter, r *dns.Msg) {
m := new(dns.Msg)
m.SetReply(r)
m.Truncated = true
if err := w.WriteMsg(m); err != nil {
t.Logf("write msg: %v", err)
}
})
// TCP handler returns the full answer.
tcpHandler := dns.HandlerFunc(func(w dns.ResponseWriter, r *dns.Msg) {
m := new(dns.Msg)
m.SetReply(r)
m.Answer = append(m.Answer, &dns.A{
Hdr: dns.RR_Header{Name: r.Question[0].Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
A: net.ParseIP("10.0.0.3"),
})
if err := w.WriteMsg(m); err != nil {
t.Logf("write msg: %v", err)
}
})
udpPC, err := net.ListenPacket("udp", "127.0.0.1:0")
require.NoError(t, err)
addr := udpPC.LocalAddr().String()
udpServer := &dns.Server{
PacketConn: udpPC,
Net: "udp",
Handler: udpHandler,
}
tcpLn, err := net.Listen("tcp", addr)
require.NoError(t, err)
tcpServer := &dns.Server{
Listener: tcpLn,
Net: "tcp",
Handler: tcpHandler,
}
go func() {
if err := udpServer.ActivateAndServe(); err != nil {
t.Logf("udp server: %v", err)
}
}()
go func() {
if err := tcpServer.ActivateAndServe(); err != nil {
t.Logf("tcp server: %v", err)
}
}()
defer func() {
_ = udpServer.Shutdown()
_ = tcpServer.Shutdown()
}()
ctx := context.Background()
client := &dns.Client{Timeout: 2 * time.Second}
r := new(dns.Msg).SetQuestion("example.com.", dns.TypeA)
rm, _, err := ExchangeWithFallback(ctx, client, r, addr)
require.NoError(t, err, "should fall back to TCP after truncated UDP response")
require.NotNil(t, rm)
require.NotEmpty(t, rm.Answer, "TCP response should contain the full answer")
assert.Contains(t, rm.Answer[0].String(), "10.0.0.3")
assert.False(t, rm.Truncated, "TCP response should not be truncated")
}
func TestExchangeWithFallback_TCPContextSkipsUDP(t *testing.T) {
// Start only a TCP server (no UDP). With TCP context it should succeed.
tcpHandler := dns.HandlerFunc(func(w dns.ResponseWriter, r *dns.Msg) {
m := new(dns.Msg)
m.SetReply(r)
m.Answer = append(m.Answer, &dns.A{
Hdr: dns.RR_Header{Name: r.Question[0].Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
A: net.ParseIP("10.0.0.2"),
})
if err := w.WriteMsg(m); err != nil {
t.Logf("write msg: %v", err)
}
})
tcpLn, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
tcpServer := &dns.Server{
Listener: tcpLn,
Net: "tcp",
Handler: tcpHandler,
}
go func() {
if err := tcpServer.ActivateAndServe(); err != nil {
t.Logf("tcp server: %v", err)
}
}()
defer func() {
_ = tcpServer.Shutdown()
}()
upstream := tcpLn.Addr().String()
// TCP context: should skip UDP entirely and go directly to TCP
ctx := contextWithDNSProtocol(context.Background(), protoTCP)
client := &dns.Client{Timeout: 2 * time.Second}
r := new(dns.Msg).SetQuestion("example.com.", dns.TypeA)
rm, _, err := ExchangeWithFallback(ctx, client, r, upstream)
require.NoError(t, err)
require.NotNil(t, rm)
require.NotEmpty(t, rm.Answer)
assert.Contains(t, rm.Answer[0].String(), "10.0.0.2")
// Without TCP context, trying to reach a TCP-only server via UDP should fail
ctx2 := context.Background()
client2 := &dns.Client{Timeout: 500 * time.Millisecond}
_, _, err = ExchangeWithFallback(ctx2, client2, r, upstream)
assert.Error(t, err, "should fail when no UDP server and no TCP context")
}
func TestExchangeWithFallback_EDNS0Capped(t *testing.T) {
// Verify that a client EDNS0 larger than our MTU-derived limit gets
// capped in the outgoing request so the upstream doesn't send a
// response larger than our read buffer.
var receivedUDPSize uint16
udpHandler := dns.HandlerFunc(func(w dns.ResponseWriter, r *dns.Msg) {
if opt := r.IsEdns0(); opt != nil {
receivedUDPSize = opt.UDPSize()
}
m := new(dns.Msg)
m.SetReply(r)
m.Answer = append(m.Answer, &dns.A{
Hdr: dns.RR_Header{Name: r.Question[0].Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
A: net.ParseIP("10.0.0.1"),
})
if err := w.WriteMsg(m); err != nil {
t.Logf("write msg: %v", err)
}
})
udpPC, err := net.ListenPacket("udp", "127.0.0.1:0")
require.NoError(t, err)
addr := udpPC.LocalAddr().String()
udpServer := &dns.Server{PacketConn: udpPC, Net: "udp", Handler: udpHandler}
go func() { _ = udpServer.ActivateAndServe() }()
t.Cleanup(func() { _ = udpServer.Shutdown() })
ctx := context.Background()
client := &dns.Client{Timeout: 2 * time.Second}
r := new(dns.Msg).SetQuestion("example.com.", dns.TypeA)
r.SetEdns0(4096, false)
rm, _, err := ExchangeWithFallback(ctx, client, r, addr)
require.NoError(t, err)
require.NotNil(t, rm)
expectedMax := uint16(currentMTU - ipUDPHeaderSize)
assert.Equal(t, expectedMax, receivedUDPSize,
"upstream should see capped EDNS0, not the client's 4096")
}
func TestExchangeWithFallback_TCPTruncatesToClientSize(t *testing.T) {
// When the client advertises a large EDNS0 (4096) and the upstream
// truncates, the TCP response should NOT be truncated since the full
// answer fits within the client's original buffer.
udpHandler := dns.HandlerFunc(func(w dns.ResponseWriter, r *dns.Msg) {
m := new(dns.Msg)
m.SetReply(r)
m.Truncated = true
if err := w.WriteMsg(m); err != nil {
t.Logf("write msg: %v", err)
}
})
tcpHandler := dns.HandlerFunc(func(w dns.ResponseWriter, r *dns.Msg) {
m := new(dns.Msg)
m.SetReply(r)
// Add enough records to exceed MTU but fit within 4096
for i := range 20 {
m.Answer = append(m.Answer, &dns.TXT{
Hdr: dns.RR_Header{Name: r.Question[0].Name, Rrtype: dns.TypeTXT, Class: dns.ClassINET, Ttl: 60},
Txt: []string{fmt.Sprintf("record-%d-padding-data-to-make-it-longer", i)},
})
}
if err := w.WriteMsg(m); err != nil {
t.Logf("write msg: %v", err)
}
})
udpPC, err := net.ListenPacket("udp", "127.0.0.1:0")
require.NoError(t, err)
addr := udpPC.LocalAddr().String()
udpServer := &dns.Server{PacketConn: udpPC, Net: "udp", Handler: udpHandler}
tcpLn, err := net.Listen("tcp", addr)
require.NoError(t, err)
tcpServer := &dns.Server{Listener: tcpLn, Net: "tcp", Handler: tcpHandler}
go func() { _ = udpServer.ActivateAndServe() }()
go func() { _ = tcpServer.ActivateAndServe() }()
t.Cleanup(func() {
_ = udpServer.Shutdown()
_ = tcpServer.Shutdown()
})
ctx := context.Background()
client := &dns.Client{Timeout: 2 * time.Second}
// Client with large buffer: should get all records without truncation
r := new(dns.Msg).SetQuestion("example.com.", dns.TypeTXT)
r.SetEdns0(4096, false)
rm, _, err := ExchangeWithFallback(ctx, client, r, addr)
require.NoError(t, err)
require.NotNil(t, rm)
assert.Len(t, rm.Answer, 20, "large EDNS0 client should get all records")
assert.False(t, rm.Truncated, "response should not be truncated for large buffer client")
// Client with small buffer: should get truncated response
r2 := new(dns.Msg).SetQuestion("example.com.", dns.TypeTXT)
r2.SetEdns0(512, false)
rm2, _, err := ExchangeWithFallback(ctx, &dns.Client{Timeout: 2 * time.Second}, r2, addr)
require.NoError(t, err)
require.NotNil(t, rm2)
assert.Less(t, len(rm2.Answer), 20, "small EDNS0 client should get fewer records")
assert.True(t, rm2.Truncated, "response should be truncated for small buffer client")
}

View File

@@ -237,8 +237,8 @@ func (f *DNSForwarder) writeResponse(logger *log.Entry, w dns.ResponseWriter, re
return
}
logger.Tracef("response: domain=%s rcode=%s answers=%s took=%s",
qname, dns.RcodeToString[resp.Rcode], resutil.FormatAnswers(resp.Answer), time.Since(startTime))
logger.Tracef("response: domain=%s rcode=%s answers=%s size=%dB took=%s",
qname, dns.RcodeToString[resp.Rcode], resutil.FormatAnswers(resp.Answer), resp.Len(), time.Since(startTime))
}
// udpResponseWriter wraps a dns.ResponseWriter to handle UDP-specific truncation.
@@ -263,20 +263,28 @@ func (u *udpResponseWriter) WriteMsg(resp *dns.Msg) error {
func (f *DNSForwarder) handleDNSQueryUDP(w dns.ResponseWriter, query *dns.Msg) {
startTime := time.Now()
logger := log.WithFields(log.Fields{
fields := log.Fields{
"request_id": resutil.GenerateRequestID(),
"dns_id": fmt.Sprintf("%04x", query.Id),
})
}
if addr := w.RemoteAddr(); addr != nil {
fields["client"] = addr.String()
}
logger := log.WithFields(fields)
f.handleDNSQuery(logger, &udpResponseWriter{ResponseWriter: w, query: query}, query, startTime)
}
func (f *DNSForwarder) handleDNSQueryTCP(w dns.ResponseWriter, query *dns.Msg) {
startTime := time.Now()
logger := log.WithFields(log.Fields{
fields := log.Fields{
"request_id": resutil.GenerateRequestID(),
"dns_id": fmt.Sprintf("%04x", query.Id),
})
}
if addr := w.RemoteAddr(); addr != nil {
fields["client"] = addr.String()
}
logger := log.WithFields(fields)
f.handleDNSQuery(logger, w, query, startTime)
}

View File

@@ -46,6 +46,7 @@ import (
"github.com/netbirdio/netbird/client/internal/peer/guard"
icemaker "github.com/netbirdio/netbird/client/internal/peer/ice"
"github.com/netbirdio/netbird/client/internal/peerstore"
"github.com/netbirdio/netbird/client/internal/portforward"
"github.com/netbirdio/netbird/client/internal/profilemanager"
"github.com/netbirdio/netbird/client/internal/relay"
"github.com/netbirdio/netbird/client/internal/rosenpass"
@@ -116,11 +117,13 @@ type EngineConfig struct {
RosenpassPermissive bool
ServerSSHAllowed bool
ServerVNCAllowed bool
EnableSSHRoot *bool
EnableSSHSFTP *bool
EnableSSHLocalPortForwarding *bool
EnableSSHRemotePortForwarding *bool
DisableSSHAuth *bool
DisableVNCAuth *bool
DNSRouteInterval time.Duration
@@ -196,6 +199,7 @@ type Engine struct {
networkMonitor *networkmonitor.NetworkMonitor
sshServer sshServer
vncSrv vncServer
statusRecorder *peer.Status
@@ -210,9 +214,10 @@ type Engine struct {
// checks are the client-applied posture checks that need to be evaluated on the client
checks []*mgmProto.Checks
relayManager *relayClient.Manager
stateManager *statemanager.Manager
srWatcher *guard.SRWatcher
relayManager *relayClient.Manager
stateManager *statemanager.Manager
portForwardManager *portforward.Manager
srWatcher *guard.SRWatcher
// Sync response persistence (protected by syncRespMux)
syncRespMux sync.RWMutex
@@ -259,26 +264,27 @@ func NewEngine(
mobileDep MobileDependency,
) *Engine {
engine := &Engine{
clientCtx: clientCtx,
clientCancel: clientCancel,
signal: services.SignalClient,
signaler: peer.NewSignaler(services.SignalClient, config.WgPrivateKey),
mgmClient: services.MgmClient,
relayManager: services.RelayManager,
peerStore: peerstore.NewConnStore(),
syncMsgMux: &sync.Mutex{},
config: config,
mobileDep: mobileDep,
STUNs: []*stun.URI{},
TURNs: []*stun.URI{},
networkSerial: 0,
statusRecorder: services.StatusRecorder,
stateManager: services.StateManager,
checks: services.Checks,
probeStunTurn: relay.NewStunTurnProbe(relay.DefaultCacheTTL),
jobExecutor: jobexec.NewExecutor(),
clientMetrics: services.ClientMetrics,
updateManager: services.UpdateManager,
clientCtx: clientCtx,
clientCancel: clientCancel,
signal: services.SignalClient,
signaler: peer.NewSignaler(services.SignalClient, config.WgPrivateKey),
mgmClient: services.MgmClient,
relayManager: services.RelayManager,
peerStore: peerstore.NewConnStore(),
syncMsgMux: &sync.Mutex{},
config: config,
mobileDep: mobileDep,
STUNs: []*stun.URI{},
TURNs: []*stun.URI{},
networkSerial: 0,
statusRecorder: services.StatusRecorder,
stateManager: services.StateManager,
portForwardManager: portforward.NewManager(),
checks: services.Checks,
probeStunTurn: relay.NewStunTurnProbe(relay.DefaultCacheTTL),
jobExecutor: jobexec.NewExecutor(),
clientMetrics: services.ClientMetrics,
updateManager: services.UpdateManager,
}
log.Infof("I am: %s", config.WgPrivateKey.PublicKey().String())
@@ -307,6 +313,10 @@ func (e *Engine) Stop() error {
log.Warnf("failed to stop SSH server: %v", err)
}
if err := e.stopVNCServer(); err != nil {
log.Warnf("failed to stop VNC server: %v", err)
}
e.cleanupSSHConfig()
if e.ingressGatewayMgr != nil {
@@ -500,7 +510,7 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
e.routeManager.SetRouteChangeListener(e.mobileDep.NetworkChangeListener)
e.dnsServer.SetRouteChecker(func(ip netip.Addr) bool {
for _, routes := range e.routeManager.GetClientRoutes() {
for _, routes := range e.routeManager.GetSelectedClientRoutes() {
for _, r := range routes {
if r.Network.Contains(ip) {
return true
@@ -521,6 +531,11 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
return err
}
// Inject firewall into DNS server now that it's available.
// The DNS server is created before the firewall because the route manager
// depends on the DNS server, and the firewall depends on the wg interface.
e.dnsServer.SetFirewall(e.firewall)
e.udpMux, err = e.wgInterface.Up()
if err != nil {
log.Errorf("failed to pull up wgInterface [%s]: %s", e.wgInterface.Name(), err.Error())
@@ -532,6 +547,13 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
// conntrack entries from being created before the rules are in place
e.setupWGProxyNoTrack()
// Start after interface is up since port may have been resolved from 0 or changed if occupied
e.shutdownWg.Add(1)
go func() {
defer e.shutdownWg.Done()
e.portForwardManager.Start(e.ctx, uint16(e.config.WgPort))
}()
// Set the WireGuard interface for rosenpass after interface is up
if e.rpManager != nil {
e.rpManager.SetInterface(e.wgInterface)
@@ -982,6 +1004,7 @@ func (e *Engine) updateChecksIfNew(checks []*mgmProto.Checks) error {
e.config.RosenpassEnabled,
e.config.RosenpassPermissive,
&e.config.ServerSSHAllowed,
&e.config.ServerVNCAllowed,
e.config.DisableClientRoutes,
e.config.DisableServerRoutes,
e.config.DisableDNS,
@@ -994,6 +1017,7 @@ func (e *Engine) updateChecksIfNew(checks []*mgmProto.Checks) error {
e.config.EnableSSHLocalPortForwarding,
e.config.EnableSSHRemotePortForwarding,
e.config.DisableSSHAuth,
e.config.DisableVNCAuth,
)
if err := e.mgmClient.SyncMeta(info); err != nil {
@@ -1021,6 +1045,10 @@ func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error {
}
}
if err := e.updateVNC(conf.GetSshConfig()); err != nil {
log.Warnf("failed handling VNC server setup: %v", err)
}
state := e.statusRecorder.GetLocalPeerState()
state.IP = e.wgInterface.Address().String()
state.PubKey = e.config.WgPrivateKey.PublicKey().String()
@@ -1122,6 +1150,7 @@ func (e *Engine) receiveManagementEvents() {
e.config.RosenpassEnabled,
e.config.RosenpassPermissive,
&e.config.ServerSSHAllowed,
&e.config.ServerVNCAllowed,
e.config.DisableClientRoutes,
e.config.DisableServerRoutes,
e.config.DisableDNS,
@@ -1134,6 +1163,7 @@ func (e *Engine) receiveManagementEvents() {
e.config.EnableSSHLocalPortForwarding,
e.config.EnableSSHRemotePortForwarding,
e.config.DisableSSHAuth,
e.config.DisableVNCAuth,
)
err = e.mgmClient.Sync(e.ctx, info, e.handleSync)
@@ -1308,6 +1338,11 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
}
e.updateSSHServerAuth(networkMap.GetSshAuth())
// VNC auth: use dedicated VNCAuth if present.
if vncAuth := networkMap.GetVncAuth(); vncAuth != nil {
e.updateVNCServerAuth(vncAuth)
}
}
// must set the exclude list after the peers are added. Without it the manager can not figure out the peers parameters from the store
@@ -1535,12 +1570,13 @@ func (e *Engine) createPeerConn(pubKey string, allowedIPs []netip.Prefix, agentV
}
serviceDependencies := peer.ServiceDependencies{
StatusRecorder: e.statusRecorder,
Signaler: e.signaler,
IFaceDiscover: e.mobileDep.IFaceDiscover,
RelayManager: e.relayManager,
SrWatcher: e.srWatcher,
MetricsRecorder: e.clientMetrics,
StatusRecorder: e.statusRecorder,
Signaler: e.signaler,
IFaceDiscover: e.mobileDep.IFaceDiscover,
RelayManager: e.relayManager,
SrWatcher: e.srWatcher,
PortForwardManager: e.portForwardManager,
MetricsRecorder: e.clientMetrics,
}
peerConn, err := peer.NewConn(config, serviceDependencies)
if err != nil {
@@ -1697,6 +1733,12 @@ func (e *Engine) close() {
if e.rpManager != nil {
_ = e.rpManager.Close()
}
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
if err := e.portForwardManager.GracefullyStop(ctx); err != nil {
log.Warnf("failed to gracefully stop port forwarding manager: %s", err)
}
}
func (e *Engine) readInitialSettings() ([]*route.Route, *nbdns.Config, bool, error) {
@@ -1710,6 +1752,7 @@ func (e *Engine) readInitialSettings() ([]*route.Route, *nbdns.Config, bool, err
e.config.RosenpassEnabled,
e.config.RosenpassPermissive,
&e.config.ServerSSHAllowed,
&e.config.ServerVNCAllowed,
e.config.DisableClientRoutes,
e.config.DisableServerRoutes,
e.config.DisableDNS,
@@ -1722,6 +1765,7 @@ func (e *Engine) readInitialSettings() ([]*route.Route, *nbdns.Config, bool, err
e.config.EnableSSHLocalPortForwarding,
e.config.EnableSSHRemotePortForwarding,
e.config.DisableSSHAuth,
e.config.DisableVNCAuth,
)
netMap, err := e.mgmClient.GetNetworkMap(info)
@@ -1800,7 +1844,7 @@ func (e *Engine) newDnsServer(dnsConfig *nbdns.Config) (dns.Server, error) {
return dnsServer, nil
case "ios":
dnsServer := dns.NewDefaultServerIos(e.ctx, e.wgInterface, e.mobileDep.DnsManager, e.statusRecorder, e.config.DisableDNS)
dnsServer := dns.NewDefaultServerIos(e.ctx, e.wgInterface, e.mobileDep.DnsManager, e.mobileDep.HostDNSAddresses, e.statusRecorder, e.config.DisableDNS)
return dnsServer, nil
default:
@@ -1837,6 +1881,11 @@ func (e *Engine) GetExposeManager() *expose.Manager {
return e.exposeManager
}
// IsBlockInbound returns whether inbound connections are blocked.
func (e *Engine) IsBlockInbound() bool {
return e.config.BlockInbound
}
// GetClientMetrics returns the client metrics
func (e *Engine) GetClientMetrics() *metrics.ClientMetrics {
return e.clientMetrics

View File

@@ -828,7 +828,7 @@ func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) {
WgPrivateKey: key,
WgPort: 33100,
MTU: iface.DefaultMTU,
}, EngineServices{
}, EngineServices{
SignalClient: &signal.MockClient{},
MgmClient: &mgmt.MockClient{},
RelayManager: relayMgr,
@@ -1035,7 +1035,7 @@ func TestEngine_UpdateNetworkMapWithDNSUpdate(t *testing.T) {
WgPrivateKey: key,
WgPort: 33100,
MTU: iface.DefaultMTU,
}, EngineServices{
}, EngineServices{
SignalClient: &signal.MockClient{},
MgmClient: &mgmt.MockClient{},
RelayManager: relayMgr,
@@ -1538,13 +1538,8 @@ func createEngine(ctx context.Context, cancel context.CancelFunc, setupKey strin
return nil, err
}
publicKey, err := mgmtClient.GetServerPublicKey()
if err != nil {
return nil, err
}
info := system.GetInfo(ctx)
resp, err := mgmtClient.Register(*publicKey, setupKey, "", info, nil, nil)
resp, err := mgmtClient.Register(setupKey, "", info, nil, nil)
if err != nil {
return nil, err
}
@@ -1566,7 +1561,7 @@ func createEngine(ctx context.Context, cancel context.CancelFunc, setupKey strin
}
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU)
e, err := NewEngine(ctx, cancel, conf, EngineServices{
e, err := NewEngine(ctx, cancel, conf, EngineServices{
SignalClient: signalClient,
MgmClient: mgmtClient,
RelayManager: relayMgr,

View File

@@ -0,0 +1,247 @@
package internal
import (
"context"
"errors"
"fmt"
"net/netip"
log "github.com/sirupsen/logrus"
firewallManager "github.com/netbirdio/netbird/client/firewall/manager"
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
sshauth "github.com/netbirdio/netbird/client/ssh/auth"
vncserver "github.com/netbirdio/netbird/client/vnc/server"
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
sshuserhash "github.com/netbirdio/netbird/shared/sshauth"
)
const (
vncExternalPort uint16 = 5900
vncInternalPort uint16 = 25900
)
type vncServer interface {
Start(ctx context.Context, addr netip.AddrPort, network netip.Prefix) error
Stop() error
}
func (e *Engine) setupVNCPortRedirection() error {
if e.firewall == nil || e.wgInterface == nil {
return nil
}
localAddr := e.wgInterface.Address().IP
if !localAddr.IsValid() {
return errors.New("invalid local NetBird address")
}
if err := e.firewall.AddInboundDNAT(localAddr, firewallManager.ProtocolTCP, vncExternalPort, vncInternalPort); err != nil {
return fmt.Errorf("add VNC port redirection: %w", err)
}
log.Infof("VNC port redirection: %s:%d -> %s:%d", localAddr, vncExternalPort, localAddr, vncInternalPort)
return nil
}
func (e *Engine) cleanupVNCPortRedirection() error {
if e.firewall == nil || e.wgInterface == nil {
return nil
}
localAddr := e.wgInterface.Address().IP
if !localAddr.IsValid() {
return errors.New("invalid local NetBird address")
}
if err := e.firewall.RemoveInboundDNAT(localAddr, firewallManager.ProtocolTCP, vncExternalPort, vncInternalPort); err != nil {
return fmt.Errorf("remove VNC port redirection: %w", err)
}
return nil
}
// updateVNC handles starting/stopping the VNC server based on the config flag.
// sshConf provides the JWT identity provider config (shared with SSH).
func (e *Engine) updateVNC(sshConf *mgmProto.SSHConfig) error {
if !e.config.ServerVNCAllowed {
if e.vncSrv != nil {
log.Info("VNC server disabled, stopping")
}
return e.stopVNCServer()
}
if e.config.BlockInbound {
log.Info("VNC server disabled because inbound connections are blocked")
return e.stopVNCServer()
}
if e.vncSrv != nil {
// Update JWT config on existing server in case management sent new config.
e.updateVNCServerJWT(sshConf)
return nil
}
return e.startVNCServer(sshConf)
}
func (e *Engine) startVNCServer(sshConf *mgmProto.SSHConfig) error {
if e.wgInterface == nil {
return errors.New("wg interface not initialized")
}
capturer, injector := newPlatformVNC()
if capturer == nil || injector == nil {
log.Debug("VNC server not supported on this platform")
return nil
}
netbirdIP := e.wgInterface.Address().IP
srv := vncserver.New(capturer, injector, "")
if vncNeedsServiceMode() {
log.Info("VNC: running in Session 0, enabling service mode (agent proxy)")
srv.SetServiceMode(true)
}
// Configure VNC authentication.
if e.config.DisableVNCAuth != nil && *e.config.DisableVNCAuth {
log.Info("VNC: authentication disabled by config")
srv.SetDisableAuth(true)
} else if protoJWT := sshConf.GetJwtConfig(); protoJWT != nil {
audiences := protoJWT.GetAudiences()
if len(audiences) == 0 && protoJWT.GetAudience() != "" {
audiences = []string{protoJWT.GetAudience()}
}
srv.SetJWTConfig(&vncserver.JWTConfig{
Issuer: protoJWT.GetIssuer(),
Audiences: audiences,
KeysLocation: protoJWT.GetKeysLocation(),
MaxTokenAge: protoJWT.GetMaxTokenAge(),
})
log.Debugf("VNC: JWT authentication configured (issuer=%s)", protoJWT.GetIssuer())
}
if netstackNet := e.wgInterface.GetNet(); netstackNet != nil {
srv.SetNetstackNet(netstackNet)
}
listenAddr := netip.AddrPortFrom(netbirdIP, vncInternalPort)
network := e.wgInterface.Address().Network
if err := srv.Start(e.ctx, listenAddr, network); err != nil {
return fmt.Errorf("start VNC server: %w", err)
}
e.vncSrv = srv
if registrar, ok := e.firewall.(interface {
RegisterNetstackService(protocol nftypes.Protocol, port uint16)
}); ok {
registrar.RegisterNetstackService(nftypes.TCP, vncInternalPort)
log.Debugf("registered VNC service for TCP:%d", vncInternalPort)
}
if err := e.setupVNCPortRedirection(); err != nil {
log.Warnf("setup VNC port redirection: %v", err)
}
log.Info("VNC server enabled")
return nil
}
// updateVNCServerJWT configures the JWT validation for the VNC server using
// the same JWT config as SSH (same identity provider).
func (e *Engine) updateVNCServerJWT(sshConf *mgmProto.SSHConfig) {
if e.vncSrv == nil {
return
}
vncSrv, ok := e.vncSrv.(*vncserver.Server)
if !ok {
return
}
if e.config.DisableVNCAuth != nil && *e.config.DisableVNCAuth {
vncSrv.SetDisableAuth(true)
return
}
protoJWT := sshConf.GetJwtConfig()
if protoJWT == nil {
return
}
audiences := protoJWT.GetAudiences()
if len(audiences) == 0 && protoJWT.GetAudience() != "" {
audiences = []string{protoJWT.GetAudience()}
}
vncSrv.SetJWTConfig(&vncserver.JWTConfig{
Issuer: protoJWT.GetIssuer(),
Audiences: audiences,
KeysLocation: protoJWT.GetKeysLocation(),
MaxTokenAge: protoJWT.GetMaxTokenAge(),
})
}
// updateVNCServerAuth updates VNC fine-grained access control from management.
func (e *Engine) updateVNCServerAuth(vncAuth *mgmProto.VNCAuth) {
if vncAuth == nil || e.vncSrv == nil {
return
}
vncSrv, ok := e.vncSrv.(*vncserver.Server)
if !ok {
return
}
protoUsers := vncAuth.GetAuthorizedUsers()
authorizedUsers := make([]sshuserhash.UserIDHash, len(protoUsers))
for i, hash := range protoUsers {
if len(hash) != 16 {
log.Warnf("invalid VNC auth hash length %d, expected 16", len(hash))
return
}
authorizedUsers[i] = sshuserhash.UserIDHash(hash)
}
machineUsers := make(map[string][]uint32)
for osUser, indexes := range vncAuth.GetMachineUsers() {
machineUsers[osUser] = indexes.GetIndexes()
}
vncSrv.UpdateVNCAuth(&sshauth.Config{
UserIDClaim: vncAuth.GetUserIDClaim(),
AuthorizedUsers: authorizedUsers,
MachineUsers: machineUsers,
})
}
// GetVNCServerStatus returns whether the VNC server is running.
func (e *Engine) GetVNCServerStatus() bool {
return e.vncSrv != nil
}
func (e *Engine) stopVNCServer() error {
if e.vncSrv == nil {
return nil
}
if err := e.cleanupVNCPortRedirection(); err != nil {
log.Warnf("cleanup VNC port redirection: %v", err)
}
if registrar, ok := e.firewall.(interface {
UnregisterNetstackService(protocol nftypes.Protocol, port uint16)
}); ok {
registrar.UnregisterNetstackService(nftypes.TCP, vncInternalPort)
}
log.Info("stopping VNC server")
err := e.vncSrv.Stop()
e.vncSrv = nil
if err != nil {
return fmt.Errorf("stop VNC server: %w", err)
}
return nil
}

View File

@@ -0,0 +1,23 @@
//go:build darwin && !ios
package internal
import (
log "github.com/sirupsen/logrus"
vncserver "github.com/netbirdio/netbird/client/vnc/server"
)
func newPlatformVNC() (vncserver.ScreenCapturer, vncserver.InputInjector) {
capturer := vncserver.NewMacPoller()
injector, err := vncserver.NewMacInputInjector()
if err != nil {
log.Debugf("VNC: macOS input injector: %v", err)
return capturer, &vncserver.StubInputInjector{}
}
return capturer, injector
}
func vncNeedsServiceMode() bool {
return false
}

View File

@@ -0,0 +1,13 @@
//go:build !windows && !darwin && !freebsd && !(linux && !android)
package internal
import vncserver "github.com/netbirdio/netbird/client/vnc/server"
func newPlatformVNC() (vncserver.ScreenCapturer, vncserver.InputInjector) {
return nil, nil
}
func vncNeedsServiceMode() bool {
return false
}

View File

@@ -0,0 +1,13 @@
//go:build windows
package internal
import vncserver "github.com/netbirdio/netbird/client/vnc/server"
func newPlatformVNC() (vncserver.ScreenCapturer, vncserver.InputInjector) {
return vncserver.NewDesktopCapturer(), vncserver.NewWindowsInputInjector()
}
func vncNeedsServiceMode() bool {
return vncserver.GetCurrentSessionID() == 0
}

View File

@@ -0,0 +1,23 @@
//go:build (linux && !android) || freebsd
package internal
import (
log "github.com/sirupsen/logrus"
vncserver "github.com/netbirdio/netbird/client/vnc/server"
)
func newPlatformVNC() (vncserver.ScreenCapturer, vncserver.InputInjector) {
capturer := vncserver.NewX11Poller("")
injector, err := vncserver.NewX11InputInjector("")
if err != nil {
log.Debugf("VNC: X11 input injector: %v", err)
return capturer, &vncserver.StubInputInjector{}
}
return capturer, injector
}
func vncNeedsServiceMode() bool {
return false
}

View File

@@ -22,6 +22,7 @@ import (
icemaker "github.com/netbirdio/netbird/client/internal/peer/ice"
"github.com/netbirdio/netbird/client/internal/peer/id"
"github.com/netbirdio/netbird/client/internal/peer/worker"
"github.com/netbirdio/netbird/client/internal/portforward"
"github.com/netbirdio/netbird/client/internal/stdnet"
"github.com/netbirdio/netbird/route"
relayClient "github.com/netbirdio/netbird/shared/relay/client"
@@ -45,6 +46,7 @@ type ServiceDependencies struct {
RelayManager *relayClient.Manager
SrWatcher *guard.SRWatcher
PeerConnDispatcher *dispatcher.ConnectionDispatcher
PortForwardManager *portforward.Manager
MetricsRecorder MetricsRecorder
}
@@ -87,16 +89,17 @@ type ConnConfig struct {
}
type Conn struct {
Log *log.Entry
mu sync.Mutex
ctx context.Context
ctxCancel context.CancelFunc
config ConnConfig
statusRecorder *Status
signaler *Signaler
iFaceDiscover stdnet.ExternalIFaceDiscover
relayManager *relayClient.Manager
srWatcher *guard.SRWatcher
Log *log.Entry
mu sync.Mutex
ctx context.Context
ctxCancel context.CancelFunc
config ConnConfig
statusRecorder *Status
signaler *Signaler
iFaceDiscover stdnet.ExternalIFaceDiscover
relayManager *relayClient.Manager
srWatcher *guard.SRWatcher
portForwardManager *portforward.Manager
onConnected func(remoteWireGuardKey string, remoteRosenpassPubKey []byte, wireGuardIP string, remoteRosenpassAddr string)
onDisconnected func(remotePeer string)
@@ -145,19 +148,20 @@ func NewConn(config ConnConfig, services ServiceDependencies) (*Conn, error) {
dumpState := newStateDump(config.Key, connLog, services.StatusRecorder)
var conn = &Conn{
Log: connLog,
config: config,
statusRecorder: services.StatusRecorder,
signaler: services.Signaler,
iFaceDiscover: services.IFaceDiscover,
relayManager: services.RelayManager,
srWatcher: services.SrWatcher,
statusRelay: worker.NewAtomicStatus(),
statusICE: worker.NewAtomicStatus(),
dumpState: dumpState,
endpointUpdater: NewEndpointUpdater(connLog, config.WgConfig, isController(config)),
wgWatcher: NewWGWatcher(connLog, config.WgConfig.WgInterface, config.Key, dumpState),
metricsRecorder: services.MetricsRecorder,
Log: connLog,
config: config,
statusRecorder: services.StatusRecorder,
signaler: services.Signaler,
iFaceDiscover: services.IFaceDiscover,
relayManager: services.RelayManager,
srWatcher: services.SrWatcher,
portForwardManager: services.PortForwardManager,
statusRelay: worker.NewAtomicStatus(),
statusICE: worker.NewAtomicStatus(),
dumpState: dumpState,
endpointUpdater: NewEndpointUpdater(connLog, config.WgConfig, isController(config)),
wgWatcher: NewWGWatcher(connLog, config.WgConfig.WgInterface, config.Key, dumpState),
metricsRecorder: services.MetricsRecorder,
}
return conn, nil

View File

@@ -16,6 +16,7 @@ import (
"github.com/netbirdio/netbird/client/iface/udpmux"
"github.com/netbirdio/netbird/client/internal/peer/conntype"
icemaker "github.com/netbirdio/netbird/client/internal/peer/ice"
"github.com/netbirdio/netbird/client/internal/portforward"
"github.com/netbirdio/netbird/client/internal/stdnet"
"github.com/netbirdio/netbird/route"
)
@@ -61,6 +62,9 @@ type WorkerICE struct {
// we record the last known state of the ICE agent to avoid duplicate on disconnected events
lastKnownState ice.ConnectionState
// portForwardAttempted tracks if we've already tried port forwarding this session
portForwardAttempted bool
}
func NewWorkerICE(ctx context.Context, log *log.Entry, config ConnConfig, conn *Conn, signaler *Signaler, ifaceDiscover stdnet.ExternalIFaceDiscover, statusRecorder *Status, hasRelayOnLocally bool) (*WorkerICE, error) {
@@ -214,6 +218,8 @@ func (w *WorkerICE) Close() {
}
func (w *WorkerICE) reCreateAgent(dialerCancel context.CancelFunc, candidates []ice.CandidateType) (*icemaker.ThreadSafeAgent, error) {
w.portForwardAttempted = false
agent, err := icemaker.NewAgent(w.ctx, w.iFaceDiscover, w.config.ICEConfig, candidates, w.localUfrag, w.localPwd)
if err != nil {
return nil, fmt.Errorf("create agent: %w", err)
@@ -370,6 +376,93 @@ func (w *WorkerICE) onICECandidate(candidate ice.Candidate) {
w.log.Errorf("failed signaling candidate to the remote peer %s %s", w.config.Key, err)
}
}()
if candidate.Type() == ice.CandidateTypeServerReflexive {
w.injectPortForwardedCandidate(candidate)
}
}
// injectPortForwardedCandidate signals an additional candidate using the pre-created port mapping.
func (w *WorkerICE) injectPortForwardedCandidate(srflxCandidate ice.Candidate) {
pfManager := w.conn.portForwardManager
if pfManager == nil {
return
}
mapping := pfManager.GetMapping()
if mapping == nil {
return
}
w.muxAgent.Lock()
if w.portForwardAttempted {
w.muxAgent.Unlock()
return
}
w.portForwardAttempted = true
w.muxAgent.Unlock()
forwardedCandidate, err := w.createForwardedCandidate(srflxCandidate, mapping)
if err != nil {
w.log.Warnf("create forwarded candidate: %v", err)
return
}
w.log.Debugf("injecting port-forwarded candidate: %s (mapping: %d -> %d via %s, priority: %d)",
forwardedCandidate.String(), mapping.InternalPort, mapping.ExternalPort, mapping.NATType, forwardedCandidate.Priority())
go func() {
if err := w.signaler.SignalICECandidate(forwardedCandidate, w.config.Key); err != nil {
w.log.Errorf("signal port-forwarded candidate: %v", err)
}
}()
}
// createForwardedCandidate creates a new server reflexive candidate with the forwarded port.
// It uses the NAT gateway's external IP with the forwarded port.
func (w *WorkerICE) createForwardedCandidate(srflxCandidate ice.Candidate, mapping *portforward.Mapping) (ice.Candidate, error) {
var externalIP string
if mapping.ExternalIP != nil && !mapping.ExternalIP.IsUnspecified() {
externalIP = mapping.ExternalIP.String()
} else {
// Fallback to STUN-discovered address if NAT didn't provide external IP
externalIP = srflxCandidate.Address()
}
// Per RFC 8445, the related address for srflx is the base (host candidate address).
// If the original srflx has unspecified related address, use its own address as base.
relAddr := srflxCandidate.RelatedAddress().Address
if relAddr == "" || relAddr == "0.0.0.0" || relAddr == "::" {
relAddr = srflxCandidate.Address()
}
// Arbitrary +1000 boost on top of RFC 8445 priority to favor port-forwarded candidates
// over regular srflx during ICE connectivity checks.
priority := srflxCandidate.Priority() + 1000
candidate, err := ice.NewCandidateServerReflexive(&ice.CandidateServerReflexiveConfig{
Network: srflxCandidate.NetworkType().String(),
Address: externalIP,
Port: int(mapping.ExternalPort),
Component: srflxCandidate.Component(),
Priority: priority,
RelAddr: relAddr,
RelPort: int(mapping.InternalPort),
})
if err != nil {
return nil, fmt.Errorf("create candidate: %w", err)
}
for _, e := range srflxCandidate.Extensions() {
if e.Key == ice.ExtensionKeyCandidateID {
e.Value = srflxCandidate.ID()
}
if err := candidate.AddExtension(e); err != nil {
return nil, fmt.Errorf("add extension: %w", err)
}
}
return candidate, nil
}
func (w *WorkerICE) onICESelectedCandidatePair(agent *icemaker.ThreadSafeAgent, c1, c2 ice.Candidate) {
@@ -411,10 +504,10 @@ func (w *WorkerICE) logSuccessfulPaths(agent *icemaker.ThreadSafeAgent) {
if !lok || !rok {
continue
}
w.log.Debugf("successful ICE path %s: [%s %s %s] <-> [%s %s %s] rtt=%.3fms",
w.log.Debugf("successful ICE path %s: [%s %s %s:%d] <-> [%s %s %s:%d] rtt=%.3fms",
sessionID,
local.NetworkType(), local.Type(), local.Address(),
remote.NetworkType(), remote.Type(), remote.Address(),
local.NetworkType(), local.Type(), local.Address(), local.Port(),
remote.NetworkType(), remote.Type(), remote.Address(), remote.Port(),
stat.CurrentRoundTripTime*1000)
}
}

View File

@@ -0,0 +1,26 @@
package portforward
import (
"os"
"strconv"
log "github.com/sirupsen/logrus"
)
const (
envDisableNATMapper = "NB_DISABLE_NAT_MAPPER"
)
func isDisabledByEnv() bool {
val := os.Getenv(envDisableNATMapper)
if val == "" {
return false
}
disabled, err := strconv.ParseBool(val)
if err != nil {
log.Warnf("failed to parse %s: %v", envDisableNATMapper, err)
return false
}
return disabled
}

View File

@@ -0,0 +1,280 @@
//go:build !js
package portforward
import (
"context"
"fmt"
"net"
"regexp"
"sync"
"time"
"github.com/libp2p/go-nat"
log "github.com/sirupsen/logrus"
)
const (
defaultMappingTTL = 2 * time.Hour
discoveryTimeout = 10 * time.Second
mappingDescription = "NetBird"
)
// upnpErrPermanentLeaseOnly matches UPnP error 725 in SOAP fault XML,
// allowing for whitespace/newlines between tags from different router firmware.
var upnpErrPermanentLeaseOnly = regexp.MustCompile(`<errorCode>\s*725\s*</errorCode>`)
// Mapping represents an active NAT port mapping.
type Mapping struct {
Protocol string
InternalPort uint16
ExternalPort uint16
ExternalIP net.IP
NATType string
// TTL is the lease duration. Zero means a permanent lease that never expires.
TTL time.Duration
}
// TODO: persist mapping state for crash recovery cleanup of permanent leases.
// Currently not done because State.Cleanup requires NAT gateway re-discovery,
// which blocks startup for ~10s when no gateway is present (affects all clients).
type Manager struct {
cancel context.CancelFunc
mapping *Mapping
mappingLock sync.Mutex
wgPort uint16
done chan struct{}
stopCtx chan context.Context
// protect exported functions
mu sync.Mutex
}
// NewManager creates a new port forwarding manager.
func NewManager() *Manager {
return &Manager{
stopCtx: make(chan context.Context, 1),
}
}
func (m *Manager) Start(ctx context.Context, wgPort uint16) {
m.mu.Lock()
if m.cancel != nil {
m.mu.Unlock()
return
}
if isDisabledByEnv() {
log.Infof("NAT port mapper disabled via %s", envDisableNATMapper)
m.mu.Unlock()
return
}
if wgPort == 0 {
log.Warnf("invalid WireGuard port 0; NAT mapping disabled")
m.mu.Unlock()
return
}
m.wgPort = wgPort
m.done = make(chan struct{})
defer close(m.done)
ctx, m.cancel = context.WithCancel(ctx)
m.mu.Unlock()
gateway, mapping, err := m.setup(ctx)
if err != nil {
log.Infof("port forwarding setup: %v", err)
return
}
m.mappingLock.Lock()
m.mapping = mapping
m.mappingLock.Unlock()
m.renewLoop(ctx, gateway, mapping.TTL)
select {
case cleanupCtx := <-m.stopCtx:
// block the Start while cleaned up gracefully
m.cleanup(cleanupCtx, gateway)
default:
// return Start immediately and cleanup in background
cleanupCtx, cleanupCancel := context.WithTimeout(context.Background(), 10*time.Second)
go func() {
defer cleanupCancel()
m.cleanup(cleanupCtx, gateway)
}()
}
}
// GetMapping returns the current mapping if ready, nil otherwise
func (m *Manager) GetMapping() *Mapping {
m.mappingLock.Lock()
defer m.mappingLock.Unlock()
if m.mapping == nil {
return nil
}
mapping := *m.mapping
return &mapping
}
// GracefullyStop cancels the manager and attempts to delete the port mapping.
// After GracefullyStop returns, the manager cannot be restarted.
func (m *Manager) GracefullyStop(ctx context.Context) error {
m.mu.Lock()
defer m.mu.Unlock()
if m.cancel == nil {
return nil
}
// Send cleanup context before cancelling, so Start picks it up after renewLoop exits.
m.startTearDown(ctx)
m.cancel()
m.cancel = nil
select {
case <-ctx.Done():
return ctx.Err()
case <-m.done:
return nil
}
}
func (m *Manager) setup(ctx context.Context) (nat.NAT, *Mapping, error) {
discoverCtx, discoverCancel := context.WithTimeout(ctx, discoveryTimeout)
defer discoverCancel()
gateway, err := nat.DiscoverGateway(discoverCtx)
if err != nil {
return nil, nil, fmt.Errorf("discover gateway: %w", err)
}
log.Infof("discovered NAT gateway: %s", gateway.Type())
mapping, err := m.createMapping(ctx, gateway)
if err != nil {
return nil, nil, fmt.Errorf("create port mapping: %w", err)
}
return gateway, mapping, nil
}
func (m *Manager) createMapping(ctx context.Context, gateway nat.NAT) (*Mapping, error) {
ctx, cancel := context.WithTimeout(ctx, 30*time.Second)
defer cancel()
ttl := defaultMappingTTL
externalPort, err := gateway.AddPortMapping(ctx, "udp", int(m.wgPort), mappingDescription, ttl)
if err != nil {
if !isPermanentLeaseRequired(err) {
return nil, err
}
log.Infof("gateway only supports permanent leases, retrying with indefinite duration")
ttl = 0
externalPort, err = gateway.AddPortMapping(ctx, "udp", int(m.wgPort), mappingDescription, ttl)
if err != nil {
return nil, err
}
}
externalIP, err := gateway.GetExternalAddress()
if err != nil {
log.Debugf("failed to get external address: %v", err)
// todo return with err?
}
mapping := &Mapping{
Protocol: "udp",
InternalPort: m.wgPort,
ExternalPort: uint16(externalPort),
ExternalIP: externalIP,
NATType: gateway.Type(),
TTL: ttl,
}
log.Infof("created port mapping: %d -> %d via %s (external IP: %s)",
m.wgPort, externalPort, gateway.Type(), externalIP)
return mapping, nil
}
func (m *Manager) renewLoop(ctx context.Context, gateway nat.NAT, ttl time.Duration) {
if ttl == 0 {
// Permanent mappings don't expire, just wait for cancellation.
<-ctx.Done()
return
}
ticker := time.NewTicker(ttl / 2)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
if err := m.renewMapping(ctx, gateway); err != nil {
log.Warnf("failed to renew port mapping: %v", err)
continue
}
}
}
}
func (m *Manager) renewMapping(ctx context.Context, gateway nat.NAT) error {
ctx, cancel := context.WithTimeout(ctx, 30*time.Second)
defer cancel()
externalPort, err := gateway.AddPortMapping(ctx, m.mapping.Protocol, int(m.mapping.InternalPort), mappingDescription, m.mapping.TTL)
if err != nil {
return fmt.Errorf("add port mapping: %w", err)
}
if uint16(externalPort) != m.mapping.ExternalPort {
log.Warnf("external port changed on renewal: %d -> %d (candidate may be stale)", m.mapping.ExternalPort, externalPort)
m.mappingLock.Lock()
m.mapping.ExternalPort = uint16(externalPort)
m.mappingLock.Unlock()
}
log.Debugf("renewed port mapping: %d -> %d", m.mapping.InternalPort, m.mapping.ExternalPort)
return nil
}
func (m *Manager) cleanup(ctx context.Context, gateway nat.NAT) {
m.mappingLock.Lock()
mapping := m.mapping
m.mapping = nil
m.mappingLock.Unlock()
if mapping == nil {
return
}
if err := gateway.DeletePortMapping(ctx, mapping.Protocol, int(mapping.InternalPort)); err != nil {
log.Warnf("delete port mapping on stop: %v", err)
return
}
log.Infof("deleted port mapping for port %d", mapping.InternalPort)
}
func (m *Manager) startTearDown(ctx context.Context) {
select {
case m.stopCtx <- ctx:
default:
}
}
// isPermanentLeaseRequired checks if a UPnP error indicates the gateway only supports permanent leases (error 725).
func isPermanentLeaseRequired(err error) bool {
return err != nil && upnpErrPermanentLeaseOnly.MatchString(err.Error())
}

View File

@@ -0,0 +1,39 @@
package portforward
import (
"context"
"net"
"time"
)
// Mapping represents an active NAT port mapping.
type Mapping struct {
Protocol string
InternalPort uint16
ExternalPort uint16
ExternalIP net.IP
NATType string
// TTL is the lease duration. Zero means a permanent lease that never expires.
TTL time.Duration
}
// Manager is a stub for js/wasm builds where NAT-PMP/UPnP is not supported.
type Manager struct{}
// NewManager returns a stub manager for js/wasm builds.
func NewManager() *Manager {
return &Manager{}
}
// Start is a no-op on js/wasm: NAT-PMP/UPnP is not available in browser environments.
func (m *Manager) Start(context.Context, uint16) {
// no NAT traversal in wasm
}
// GracefullyStop is a no-op on js/wasm.
func (m *Manager) GracefullyStop(context.Context) error { return nil }
// GetMapping always returns nil on js/wasm.
func (m *Manager) GetMapping() *Mapping {
return nil
}

View File

@@ -0,0 +1,201 @@
//go:build !js
package portforward
import (
"context"
"fmt"
"net"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
type mockNAT struct {
natType string
deviceAddr net.IP
externalAddr net.IP
internalAddr net.IP
mappings map[int]int
addMappingErr error
deleteMappingErr error
onlyPermanentLeases bool
lastTimeout time.Duration
}
func newMockNAT() *mockNAT {
return &mockNAT{
natType: "Mock-NAT",
deviceAddr: net.ParseIP("192.168.1.1"),
externalAddr: net.ParseIP("203.0.113.50"),
internalAddr: net.ParseIP("192.168.1.100"),
mappings: make(map[int]int),
}
}
func (m *mockNAT) Type() string {
return m.natType
}
func (m *mockNAT) GetDeviceAddress() (net.IP, error) {
return m.deviceAddr, nil
}
func (m *mockNAT) GetExternalAddress() (net.IP, error) {
return m.externalAddr, nil
}
func (m *mockNAT) GetInternalAddress() (net.IP, error) {
return m.internalAddr, nil
}
func (m *mockNAT) AddPortMapping(ctx context.Context, protocol string, internalPort int, description string, timeout time.Duration) (int, error) {
if m.addMappingErr != nil {
return 0, m.addMappingErr
}
if m.onlyPermanentLeases && timeout != 0 {
return 0, fmt.Errorf("SOAP fault. Code: | Explanation: | Detail: <UPnPError xmlns=\"urn:schemas-upnp-org:control-1-0\"><errorCode>725</errorCode><errorDescription>OnlyPermanentLeasesSupported</errorDescription></UPnPError>")
}
externalPort := internalPort
m.mappings[internalPort] = externalPort
m.lastTimeout = timeout
return externalPort, nil
}
func (m *mockNAT) DeletePortMapping(ctx context.Context, protocol string, internalPort int) error {
if m.deleteMappingErr != nil {
return m.deleteMappingErr
}
delete(m.mappings, internalPort)
return nil
}
func TestManager_CreateMapping(t *testing.T) {
m := NewManager()
m.wgPort = 51820
gateway := newMockNAT()
mapping, err := m.createMapping(context.Background(), gateway)
require.NoError(t, err)
require.NotNil(t, mapping)
assert.Equal(t, "udp", mapping.Protocol)
assert.Equal(t, uint16(51820), mapping.InternalPort)
assert.Equal(t, uint16(51820), mapping.ExternalPort)
assert.Equal(t, "Mock-NAT", mapping.NATType)
assert.Equal(t, net.ParseIP("203.0.113.50").To4(), mapping.ExternalIP.To4())
assert.Equal(t, defaultMappingTTL, mapping.TTL)
}
func TestManager_GetMapping_ReturnsNilWhenNotReady(t *testing.T) {
m := NewManager()
assert.Nil(t, m.GetMapping())
}
func TestManager_GetMapping_ReturnsCopy(t *testing.T) {
m := NewManager()
m.mapping = &Mapping{
Protocol: "udp",
InternalPort: 51820,
ExternalPort: 51820,
}
mapping := m.GetMapping()
require.NotNil(t, mapping)
assert.Equal(t, uint16(51820), mapping.InternalPort)
// Mutating the returned copy should not affect the manager's mapping.
mapping.ExternalPort = 9999
assert.Equal(t, uint16(51820), m.GetMapping().ExternalPort)
}
func TestManager_Cleanup_DeletesMapping(t *testing.T) {
m := NewManager()
m.mapping = &Mapping{
Protocol: "udp",
InternalPort: 51820,
ExternalPort: 51820,
}
gateway := newMockNAT()
// Seed the mock so we can verify deletion.
gateway.mappings[51820] = 51820
m.cleanup(context.Background(), gateway)
_, exists := gateway.mappings[51820]
assert.False(t, exists, "mapping should be deleted from gateway")
assert.Nil(t, m.GetMapping(), "in-memory mapping should be cleared")
}
func TestManager_Cleanup_NilMapping(t *testing.T) {
m := NewManager()
gateway := newMockNAT()
// Should not panic or call gateway.
m.cleanup(context.Background(), gateway)
}
func TestManager_CreateMapping_PermanentLeaseFallback(t *testing.T) {
m := NewManager()
m.wgPort = 51820
gateway := newMockNAT()
gateway.onlyPermanentLeases = true
mapping, err := m.createMapping(context.Background(), gateway)
require.NoError(t, err)
require.NotNil(t, mapping)
assert.Equal(t, uint16(51820), mapping.InternalPort)
assert.Equal(t, time.Duration(0), mapping.TTL, "should return zero TTL for permanent lease")
assert.Equal(t, time.Duration(0), gateway.lastTimeout, "should have retried with zero duration")
}
func TestIsPermanentLeaseRequired(t *testing.T) {
tests := []struct {
name string
err error
expected bool
}{
{
name: "nil error",
err: nil,
expected: false,
},
{
name: "UPnP error 725",
err: fmt.Errorf("SOAP fault. Code: | Detail: <UPnPError><errorCode>725</errorCode><errorDescription>OnlyPermanentLeasesSupported</errorDescription></UPnPError>"),
expected: true,
},
{
name: "wrapped error with 725",
err: fmt.Errorf("add port mapping: %w", fmt.Errorf("Detail: <errorCode>725</errorCode>")),
expected: true,
},
{
name: "error 725 with newlines in XML",
err: fmt.Errorf("<errorCode>\n 725\n</errorCode>"),
expected: true,
},
{
name: "bare 725 without XML tag",
err: fmt.Errorf("error code 725"),
expected: false,
},
{
name: "unrelated error",
err: fmt.Errorf("connection refused"),
expected: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
assert.Equal(t, tt.expected, isPermanentLeaseRequired(tt.err))
})
}
}

View File

@@ -41,7 +41,7 @@ const (
// mgmProber is the subset of management client needed for URL migration probes.
type mgmProber interface {
GetServerPublicKey() (*wgtypes.Key, error)
HealthCheck() error
Close() error
}
@@ -64,11 +64,13 @@ type ConfigInput struct {
StateFilePath string
PreSharedKey *string
ServerSSHAllowed *bool
ServerVNCAllowed *bool
EnableSSHRoot *bool
EnableSSHSFTP *bool
EnableSSHLocalPortForwarding *bool
EnableSSHRemotePortForwarding *bool
DisableSSHAuth *bool
DisableVNCAuth *bool
SSHJWTCacheTTL *int
NATExternalIPs []string
CustomDNSAddress []byte
@@ -114,11 +116,13 @@ type Config struct {
RosenpassEnabled bool
RosenpassPermissive bool
ServerSSHAllowed *bool
ServerVNCAllowed *bool
EnableSSHRoot *bool
EnableSSHSFTP *bool
EnableSSHLocalPortForwarding *bool
EnableSSHRemotePortForwarding *bool
DisableSSHAuth *bool
DisableVNCAuth *bool
SSHJWTCacheTTL *int
DisableClientRoutes bool
@@ -415,6 +419,21 @@ func (config *Config) apply(input ConfigInput) (updated bool, err error) {
updated = true
}
if input.ServerVNCAllowed != nil {
if config.ServerVNCAllowed == nil || *input.ServerVNCAllowed != *config.ServerVNCAllowed {
if *input.ServerVNCAllowed {
log.Infof("enabling VNC server")
} else {
log.Infof("disabling VNC server")
}
config.ServerVNCAllowed = input.ServerVNCAllowed
updated = true
}
} else if config.ServerVNCAllowed == nil {
config.ServerVNCAllowed = util.True()
updated = true
}
if input.EnableSSHRoot != nil && input.EnableSSHRoot != config.EnableSSHRoot {
if *input.EnableSSHRoot {
log.Infof("enabling SSH root login")
@@ -465,6 +484,16 @@ func (config *Config) apply(input ConfigInput) (updated bool, err error) {
updated = true
}
if input.DisableVNCAuth != nil && input.DisableVNCAuth != config.DisableVNCAuth {
if *input.DisableVNCAuth {
log.Infof("disabling VNC authentication")
} else {
log.Infof("enabling VNC authentication")
}
config.DisableVNCAuth = input.DisableVNCAuth
updated = true
}
if input.SSHJWTCacheTTL != nil && input.SSHJWTCacheTTL != config.SSHJWTCacheTTL {
log.Infof("updating SSH JWT cache TTL to %d seconds", *input.SSHJWTCacheTTL)
config.SSHJWTCacheTTL = input.SSHJWTCacheTTL
@@ -777,8 +806,7 @@ func UpdateOldManagementURL(ctx context.Context, config *Config, configPath stri
}()
// gRPC check
_, err = client.GetServerPublicKey()
if err != nil {
if err = client.HealthCheck(); err != nil {
log.Infof("couldn't switch to the new Management %s", newURL.String())
return nil, err
}

View File

@@ -17,12 +17,10 @@ import (
"github.com/netbirdio/netbird/util"
)
type mockMgmProber struct {
key wgtypes.Key
}
type mockMgmProber struct{}
func (m *mockMgmProber) GetServerPublicKey() (*wgtypes.Key, error) {
return &m.key, nil
func (m *mockMgmProber) HealthCheck() error {
return nil
}
func (m *mockMgmProber) Close() error { return nil }
@@ -247,11 +245,7 @@ func TestWireguardPortDefaultVsExplicit(t *testing.T) {
func TestUpdateOldManagementURL(t *testing.T) {
origProber := newMgmProber
newMgmProber = func(_ context.Context, _ string, _ wgtypes.Key, _ bool) (mgmProber, error) {
key, err := wgtypes.GenerateKey()
if err != nil {
return nil, err
}
return &mockMgmProber{key: key.PublicKey()}, nil
return &mockMgmProber{}, nil
}
t.Cleanup(func() { newMgmProber = origProber })

View File

@@ -52,6 +52,7 @@ type Manager interface {
TriggerSelection(route.HAMap)
GetRouteSelector() *routeselector.RouteSelector
GetClientRoutes() route.HAMap
GetSelectedClientRoutes() route.HAMap
GetClientRoutesWithNetID() map[route.NetID][]*route.Route
SetRouteChangeListener(listener listener.NetworkChangeListener)
InitialRouteRange() []string
@@ -167,6 +168,7 @@ func (m *DefaultManager) setupAndroidRoutes(config ManagerConfig) {
NetworkType: route.IPv4Network,
}
cr = append(cr, fakeIPRoute)
m.notifier.SetFakeIPRoute(fakeIPRoute)
}
m.notifier.SetInitialClientRoutes(cr, routesForComparison)
@@ -465,6 +467,16 @@ func (m *DefaultManager) GetClientRoutes() route.HAMap {
return maps.Clone(m.clientRoutes)
}
// GetSelectedClientRoutes returns only the currently selected/active client routes,
// filtering out deselected exit nodes. Use this instead of GetClientRoutes when checking
// if traffic should be routed through the tunnel.
func (m *DefaultManager) GetSelectedClientRoutes() route.HAMap {
m.mux.Lock()
defer m.mux.Unlock()
return m.routeSelector.FilterSelectedExitNodes(maps.Clone(m.clientRoutes))
}
// GetClientRoutesWithNetID returns the current routes from the route map, but the keys consist of the network ID only
func (m *DefaultManager) GetClientRoutesWithNetID() map[route.NetID][]*route.Route {
m.mux.Lock()

View File

@@ -18,6 +18,7 @@ type MockManager struct {
TriggerSelectionFunc func(haMap route.HAMap)
GetRouteSelectorFunc func() *routeselector.RouteSelector
GetClientRoutesFunc func() route.HAMap
GetSelectedClientRoutesFunc func() route.HAMap
GetClientRoutesWithNetIDFunc func() map[route.NetID][]*route.Route
StopFunc func(manager *statemanager.Manager)
}
@@ -61,7 +62,7 @@ func (m *MockManager) GetRouteSelector() *routeselector.RouteSelector {
return nil
}
// GetClientRoutes mock implementation of GetClientRoutes from Manager interface
// GetClientRoutes mock implementation of GetClientRoutes from the Manager interface
func (m *MockManager) GetClientRoutes() route.HAMap {
if m.GetClientRoutesFunc != nil {
return m.GetClientRoutesFunc()
@@ -69,6 +70,14 @@ func (m *MockManager) GetClientRoutes() route.HAMap {
return nil
}
// GetSelectedClientRoutes mock implementation of GetSelectedClientRoutes from the Manager interface
func (m *MockManager) GetSelectedClientRoutes() route.HAMap {
if m.GetSelectedClientRoutesFunc != nil {
return m.GetSelectedClientRoutesFunc()
}
return nil
}
// GetClientRoutesWithNetID mock implementation of GetClientRoutesWithNetID from Manager interface
func (m *MockManager) GetClientRoutesWithNetID() map[route.NetID][]*route.Route {
if m.GetClientRoutesWithNetIDFunc != nil {

View File

@@ -16,6 +16,7 @@ import (
type Notifier struct {
initialRoutes []*route.Route
currentRoutes []*route.Route
fakeIPRoute *route.Route
listener listener.NetworkChangeListener
listenerMux sync.Mutex
@@ -31,13 +32,17 @@ func (n *Notifier) SetListener(listener listener.NetworkChangeListener) {
n.listener = listener
}
// SetInitialClientRoutes stores the full initial route set (including fake IP blocks)
// and a separate comparison set (without fake IP blocks) for diff detection.
// SetInitialClientRoutes stores the initial route sets for TUN configuration.
func (n *Notifier) SetInitialClientRoutes(initialRoutes []*route.Route, routesForComparison []*route.Route) {
n.initialRoutes = filterStatic(initialRoutes)
n.currentRoutes = filterStatic(routesForComparison)
}
// SetFakeIPRoute stores the fake IP route to be included in every TUN rebuild.
func (n *Notifier) SetFakeIPRoute(r *route.Route) {
n.fakeIPRoute = r
}
func (n *Notifier) OnNewRoutes(idMap route.HAMap) {
var newRoutes []*route.Route
for _, routes := range idMap {
@@ -69,7 +74,9 @@ func (n *Notifier) notify() {
}
allRoutes := slices.Clone(n.currentRoutes)
allRoutes = append(allRoutes, n.extraInitialRoutes()...)
if n.fakeIPRoute != nil {
allRoutes = append(allRoutes, n.fakeIPRoute)
}
routeStrings := n.routesToStrings(allRoutes)
sort.Strings(routeStrings)
@@ -78,23 +85,6 @@ func (n *Notifier) notify() {
}(n.listener)
}
// extraInitialRoutes returns initialRoutes whose network prefix is absent
// from currentRoutes (e.g. the fake IP block added at setup time).
func (n *Notifier) extraInitialRoutes() []*route.Route {
currentNets := make(map[netip.Prefix]struct{}, len(n.currentRoutes))
for _, r := range n.currentRoutes {
currentNets[r.Network] = struct{}{}
}
var extra []*route.Route
for _, r := range n.initialRoutes {
if _, ok := currentNets[r.Network]; !ok {
extra = append(extra, r)
}
}
return extra
}
func filterStatic(routes []*route.Route) []*route.Route {
out := make([]*route.Route, 0, len(routes))
for _, r := range routes {

View File

@@ -34,6 +34,10 @@ func (n *Notifier) SetInitialClientRoutes([]*route.Route, []*route.Route) {
// iOS doesn't care about initial routes
}
func (n *Notifier) SetFakeIPRoute(*route.Route) {
// Not used on iOS
}
func (n *Notifier) OnNewRoutes(route.HAMap) {
// Not used on iOS
}
@@ -53,7 +57,6 @@ func (n *Notifier) OnNewPrefixes(prefixes []netip.Prefix) {
n.currentPrefixes = newNets
n.notify()
}
func (n *Notifier) notify() {
n.listenerMux.Lock()
defer n.listenerMux.Unlock()

View File

@@ -23,6 +23,10 @@ func (n *Notifier) SetInitialClientRoutes([]*route.Route, []*route.Route) {
// Not used on non-mobile platforms
}
func (n *Notifier) SetFakeIPRoute(*route.Route) {
// Not used on non-mobile platforms
}
func (n *Notifier) OnNewRoutes(idMap route.HAMap) {
// Not used on non-mobile platforms
}

View File

@@ -161,7 +161,11 @@ func (c *Client) Run(fd int32, interfaceName string, envList *EnvList) error {
cfg.WgIface = interfaceName
c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder)
return c.connectClient.RunOniOS(fd, c.networkChangeListener, c.dnsManager, c.stateFile)
hostDNS := []netip.AddrPort{
netip.MustParseAddrPort("9.9.9.9:53"),
netip.MustParseAddrPort("149.112.112.112:53"),
}
return c.connectClient.RunOniOS(fd, c.networkChangeListener, c.dnsManager, hostDNS, c.stateFile)
}
// Stop the internal client and free the resources

File diff suppressed because it is too large Load Diff

View File

@@ -209,6 +209,9 @@ message LoginRequest {
optional bool enableSSHRemotePortForwarding = 37;
optional bool disableSSHAuth = 38;
optional int32 sshJWTCacheTTL = 39;
optional bool serverVNCAllowed = 41;
optional bool disableVNCAuth = 42;
}
message LoginResponse {
@@ -316,6 +319,10 @@ message GetConfigResponse {
bool disableSSHAuth = 25;
int32 sshJWTCacheTTL = 26;
bool serverVNCAllowed = 28;
bool disableVNCAuth = 29;
}
// PeerState contains the latest state of a peer
@@ -394,6 +401,11 @@ message SSHServerState {
repeated SSHSessionInfo sessions = 2;
}
// VNCServerState contains the latest state of the VNC server
message VNCServerState {
bool enabled = 1;
}
// FullStatus contains the full state held by the Status instance
message FullStatus {
ManagementState managementState = 1;
@@ -408,6 +420,7 @@ message FullStatus {
bool lazyConnectionEnabled = 9;
SSHServerState sshServerState = 10;
VNCServerState vncServerState = 11;
}
// Networks
@@ -677,6 +690,9 @@ message SetConfigRequest {
optional bool enableSSHRemotePortForwarding = 32;
optional bool disableSSHAuth = 33;
optional int32 sshJWTCacheTTL = 34;
optional bool serverVNCAllowed = 36;
optional bool disableVNCAuth = 37;
}
message SetConfigResponse{}

View File

@@ -366,6 +366,7 @@ func (s *Server) SetConfig(callerCtx context.Context, msg *proto.SetConfigReques
config.RosenpassPermissive = msg.RosenpassPermissive
config.DisableAutoConnect = msg.DisableAutoConnect
config.ServerSSHAllowed = msg.ServerSSHAllowed
config.ServerVNCAllowed = msg.ServerVNCAllowed
config.NetworkMonitor = msg.NetworkMonitor
config.DisableClientRoutes = msg.DisableClientRoutes
config.DisableServerRoutes = msg.DisableServerRoutes
@@ -382,6 +383,9 @@ func (s *Server) SetConfig(callerCtx context.Context, msg *proto.SetConfigReques
if msg.DisableSSHAuth != nil {
config.DisableSSHAuth = msg.DisableSSHAuth
}
if msg.DisableVNCAuth != nil {
config.DisableVNCAuth = msg.DisableVNCAuth
}
if msg.SshJWTCacheTTL != nil {
ttl := int(*msg.SshJWTCacheTTL)
config.SSHJWTCacheTTL = &ttl
@@ -1120,6 +1124,7 @@ func (s *Server) Status(
pbFullStatus := fullStatus.ToProto()
pbFullStatus.Events = s.statusRecorder.GetEventHistory()
pbFullStatus.SshServerState = s.getSSHServerState()
pbFullStatus.VncServerState = s.getVNCServerState()
statusResponse.FullStatus = pbFullStatus
}
@@ -1159,6 +1164,26 @@ func (s *Server) getSSHServerState() *proto.SSHServerState {
return sshServerState
}
// getVNCServerState retrieves the current VNC server state.
func (s *Server) getVNCServerState() *proto.VNCServerState {
s.mutex.Lock()
connectClient := s.connectClient
s.mutex.Unlock()
if connectClient == nil {
return nil
}
engine := connectClient.Engine()
if engine == nil {
return nil
}
return &proto.VNCServerState{
Enabled: engine.GetVNCServerStatus(),
}
}
// GetPeerSSHHostKey retrieves SSH host key for a specific peer
func (s *Server) GetPeerSSHHostKey(
ctx context.Context,
@@ -1359,6 +1384,10 @@ func (s *Server) ExposeService(req *proto.ExposeServiceRequest, srv proto.Daemon
return gstatus.Errorf(codes.FailedPrecondition, "engine not initialized")
}
if engine.IsBlockInbound() {
return gstatus.Errorf(codes.FailedPrecondition, "expose requires inbound connections but 'block inbound' is enabled, disable it first")
}
mgr := engine.GetExposeManager()
if mgr == nil {
return gstatus.Errorf(codes.Internal, "expose manager not available")
@@ -1496,6 +1525,11 @@ func (s *Server) GetConfig(ctx context.Context, req *proto.GetConfigRequest) (*p
disableSSHAuth = *cfg.DisableSSHAuth
}
disableVNCAuth := false
if cfg.DisableVNCAuth != nil {
disableVNCAuth = *cfg.DisableVNCAuth
}
sshJWTCacheTTL := int32(0)
if cfg.SSHJWTCacheTTL != nil {
sshJWTCacheTTL = int32(*cfg.SSHJWTCacheTTL)
@@ -1510,6 +1544,7 @@ func (s *Server) GetConfig(ctx context.Context, req *proto.GetConfigRequest) (*p
Mtu: int64(cfg.MTU),
DisableAutoConnect: cfg.DisableAutoConnect,
ServerSSHAllowed: *cfg.ServerSSHAllowed,
ServerVNCAllowed: cfg.ServerVNCAllowed != nil && *cfg.ServerVNCAllowed,
RosenpassEnabled: cfg.RosenpassEnabled,
RosenpassPermissive: cfg.RosenpassPermissive,
LazyConnectionEnabled: cfg.LazyConnectionEnabled,
@@ -1525,6 +1560,7 @@ func (s *Server) GetConfig(ctx context.Context, req *proto.GetConfigRequest) (*p
EnableSSHLocalPortForwarding: enableSSHLocalPortForwarding,
EnableSSHRemotePortForwarding: enableSSHRemotePortForwarding,
DisableSSHAuth: disableSSHAuth,
DisableVNCAuth: disableVNCAuth,
SshJWTCacheTTL: sshJWTCacheTTL,
}, nil
}

View File

@@ -58,6 +58,8 @@ func TestSetConfig_AllFieldsSaved(t *testing.T) {
rosenpassEnabled := true
rosenpassPermissive := true
serverSSHAllowed := true
serverVNCAllowed := true
disableVNCAuth := true
interfaceName := "utun100"
wireguardPort := int64(51820)
preSharedKey := "test-psk"
@@ -82,6 +84,8 @@ func TestSetConfig_AllFieldsSaved(t *testing.T) {
RosenpassEnabled: &rosenpassEnabled,
RosenpassPermissive: &rosenpassPermissive,
ServerSSHAllowed: &serverSSHAllowed,
ServerVNCAllowed: &serverVNCAllowed,
DisableVNCAuth: &disableVNCAuth,
InterfaceName: &interfaceName,
WireguardPort: &wireguardPort,
OptionalPreSharedKey: &preSharedKey,
@@ -125,6 +129,10 @@ func TestSetConfig_AllFieldsSaved(t *testing.T) {
require.Equal(t, rosenpassPermissive, cfg.RosenpassPermissive)
require.NotNil(t, cfg.ServerSSHAllowed)
require.Equal(t, serverSSHAllowed, *cfg.ServerSSHAllowed)
require.NotNil(t, cfg.ServerVNCAllowed)
require.Equal(t, serverVNCAllowed, *cfg.ServerVNCAllowed)
require.NotNil(t, cfg.DisableVNCAuth)
require.Equal(t, disableVNCAuth, *cfg.DisableVNCAuth)
require.Equal(t, interfaceName, cfg.WgIface)
require.Equal(t, int(wireguardPort), cfg.WgPort)
require.Equal(t, preSharedKey, cfg.PreSharedKey)
@@ -176,6 +184,8 @@ func verifyAllFieldsCovered(t *testing.T, req *proto.SetConfigRequest) {
"RosenpassEnabled": true,
"RosenpassPermissive": true,
"ServerSSHAllowed": true,
"ServerVNCAllowed": true,
"DisableVNCAuth": true,
"InterfaceName": true,
"WireguardPort": true,
"OptionalPreSharedKey": true,
@@ -236,6 +246,8 @@ func TestCLIFlags_MappedToSetConfig(t *testing.T) {
"enable-rosenpass": "RosenpassEnabled",
"rosenpass-permissive": "RosenpassPermissive",
"allow-server-ssh": "ServerSSHAllowed",
"allow-server-vnc": "ServerVNCAllowed",
"disable-vnc-auth": "DisableVNCAuth",
"interface-name": "InterfaceName",
"wireguard-port": "WireguardPort",
"preshared-key": "OptionalPreSharedKey",

View File

@@ -9,6 +9,7 @@ import (
"github.com/netbirdio/netbird/client/ssh/config"
)
// registerStates registers all states that need crash recovery cleanup.
func registerStates(mgr *statemanager.Manager) {
mgr.RegisterState(&dns.ShutdownState{})
mgr.RegisterState(&systemops.ShutdownState{})

View File

@@ -11,6 +11,7 @@ import (
"github.com/netbirdio/netbird/client/ssh/config"
)
// registerStates registers all states that need crash recovery cleanup.
func registerStates(mgr *statemanager.Manager) {
mgr.RegisterState(&dns.ShutdownState{})
mgr.RegisterState(&systemops.ShutdownState{})

View File

@@ -141,7 +141,7 @@ func (p *SSHProxy) runProxySSHServer(jwtToken string) error {
func (p *SSHProxy) handleSSHSession(session ssh.Session) {
ptyReq, winCh, isPty := session.Pty()
hasCommand := len(session.Command()) > 0
hasCommand := session.RawCommand() != ""
sshClient, err := p.getOrCreateBackendClient(session.Context(), session.User())
if err != nil {
@@ -180,7 +180,7 @@ func (p *SSHProxy) handleSSHSession(session ssh.Session) {
}
if hasCommand {
if err := serverSession.Run(strings.Join(session.Command(), " ")); err != nil {
if err := serverSession.Run(session.RawCommand()); err != nil {
log.Debugf("run command: %v", err)
p.handleProxyExitCode(session, err)
}

View File

@@ -1,6 +1,7 @@
package proxy
import (
"bytes"
"context"
"crypto/rand"
"crypto/rsa"
@@ -245,6 +246,191 @@ func TestSSHProxy_Connect(t *testing.T) {
cancel()
}
// TestSSHProxy_CommandQuoting verifies that the proxy preserves shell quoting
// when forwarding commands to the backend. This is critical for tools like
// Ansible that send commands such as:
//
// /bin/sh -c '( umask 77 && mkdir -p ... ) && sleep 0'
//
// The single quotes must be preserved so the backend shell receives the
// subshell expression as a single argument to -c.
func TestSSHProxy_CommandQuoting(t *testing.T) {
if testing.Short() {
t.Skip("Skipping integration test in short mode")
}
sshClient, cleanup := setupProxySSHClient(t)
defer cleanup()
// These commands simulate what the SSH protocol delivers as exec payloads.
// When a user types: ssh host '/bin/sh -c "( echo hello )"'
// the local shell strips the outer single quotes, and the SSH exec request
// contains the raw string: /bin/sh -c "( echo hello )"
//
// The proxy must forward this string verbatim. Using session.Command()
// (shlex.Split + strings.Join) strips the inner double quotes, breaking
// the command on the backend.
tests := []struct {
name string
command string
expect string
}{
{
name: "subshell_in_double_quotes",
command: `/bin/sh -c "( echo from-subshell ) && echo outer"`,
expect: "from-subshell\nouter\n",
},
{
name: "printf_with_special_chars",
command: `/bin/sh -c "printf '%s\n' 'hello world'"`,
expect: "hello world\n",
},
{
name: "nested_command_substitution",
command: `/bin/sh -c "echo $(echo nested)"`,
expect: "nested\n",
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
session, err := sshClient.NewSession()
require.NoError(t, err)
defer func() { _ = session.Close() }()
var stderrBuf bytes.Buffer
session.Stderr = &stderrBuf
outputCh := make(chan []byte, 1)
errCh := make(chan error, 1)
go func() {
output, err := session.Output(tc.command)
outputCh <- output
errCh <- err
}()
select {
case output := <-outputCh:
err := <-errCh
if stderrBuf.Len() > 0 {
t.Logf("stderr: %s", stderrBuf.String())
}
require.NoError(t, err, "command should succeed: %s", tc.command)
assert.Equal(t, tc.expect, string(output), "output mismatch for: %s", tc.command)
case <-time.After(5 * time.Second):
t.Fatalf("command timed out: %s", tc.command)
}
})
}
}
// setupProxySSHClient creates a full proxy test environment and returns
// an SSH client connected through the proxy to a backend NetBird SSH server.
func setupProxySSHClient(t *testing.T) (*cryptossh.Client, func()) {
t.Helper()
const (
issuer = "https://test-issuer.example.com"
audience = "test-audience"
)
jwksServer, privateKey, jwksURL := setupJWKSServer(t)
hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
require.NoError(t, err)
hostPubKey, err := nbssh.GeneratePublicKey(hostKey)
require.NoError(t, err)
serverConfig := &server.Config{
HostKeyPEM: hostKey,
JWT: &server.JWTConfig{
Issuer: issuer,
Audiences: []string{audience},
KeysLocation: jwksURL,
},
}
sshServer := server.New(serverConfig)
sshServer.SetAllowRootLogin(true)
testUsername := testutil.GetTestUsername(t)
testJWTUser := "test-username"
testUserHash, err := sshuserhash.HashUserID(testJWTUser)
require.NoError(t, err)
authConfig := &sshauth.Config{
UserIDClaim: sshauth.DefaultUserIDClaim,
AuthorizedUsers: []sshuserhash.UserIDHash{testUserHash},
MachineUsers: map[string][]uint32{
testUsername: {0},
},
}
sshServer.UpdateSSHAuth(authConfig)
sshServerAddr := server.StartTestServer(t, sshServer)
mockDaemon := startMockDaemon(t)
host, portStr, err := net.SplitHostPort(sshServerAddr)
require.NoError(t, err)
port, err := strconv.Atoi(portStr)
require.NoError(t, err)
mockDaemon.setHostKey(host, hostPubKey)
validToken := generateValidJWT(t, privateKey, issuer, audience, testJWTUser)
mockDaemon.setJWTToken(validToken)
proxyInstance, err := New(mockDaemon.addr, host, port, io.Discard, nil)
require.NoError(t, err)
origStdin := os.Stdin
origStdout := os.Stdout
stdinReader, stdinWriter, err := os.Pipe()
require.NoError(t, err)
stdoutReader, stdoutWriter, err := os.Pipe()
require.NoError(t, err)
os.Stdin = stdinReader
os.Stdout = stdoutWriter
clientConn, proxyConn := net.Pipe()
go func() { _, _ = io.Copy(stdinWriter, proxyConn) }()
go func() { _, _ = io.Copy(proxyConn, stdoutReader) }()
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
go func() {
_ = proxyInstance.Connect(ctx)
}()
sshConfig := &cryptossh.ClientConfig{
User: testutil.GetTestUsername(t),
Auth: []cryptossh.AuthMethod{},
HostKeyCallback: cryptossh.InsecureIgnoreHostKey(),
Timeout: 5 * time.Second,
}
sshClientConn, chans, reqs, err := cryptossh.NewClientConn(clientConn, "test", sshConfig)
require.NoError(t, err)
client := cryptossh.NewClient(sshClientConn, chans, reqs)
cleanupFn := func() {
_ = client.Close()
_ = clientConn.Close()
cancel()
os.Stdin = origStdin
os.Stdout = origStdout
_ = sshServer.Stop()
mockDaemon.stop()
jwksServer.Close()
}
return client, cleanupFn
}
type mockDaemonServer struct {
proto.UnimplementedDaemonServiceServer
hostKeys map[string][]byte

View File

@@ -200,8 +200,8 @@ func newLsaString(s string) lsaString {
}
}
// generateS4UUserToken creates a Windows token using S4U authentication
// This is the exact approach OpenSSH for Windows uses for public key authentication
// generateS4UUserToken creates a Windows token using S4U authentication.
// This is the same approach OpenSSH for Windows uses for public key authentication.
func generateS4UUserToken(logger *log.Entry, username, domain string) (windows.Handle, error) {
userCpn := buildUserCpn(username, domain)

View File

@@ -284,19 +284,21 @@ func (s *Server) closeListener(ln net.Listener) {
// Stop closes the SSH server
func (s *Server) Stop() error {
s.mu.Lock()
defer s.mu.Unlock()
if s.sshServer == nil {
sshServer := s.sshServer
if sshServer == nil {
s.mu.Unlock()
return nil
}
s.sshServer = nil
s.listener = nil
s.mu.Unlock()
if err := s.sshServer.Close(); err != nil {
// Close outside the lock: session handlers need s.mu for unregisterSession.
if err := sshServer.Close(); err != nil {
log.Debugf("close SSH server: %v", err)
}
s.sshServer = nil
s.listener = nil
s.mu.Lock()
maps.Clear(s.sessions)
maps.Clear(s.pendingAuthJWT)
maps.Clear(s.connections)
@@ -307,6 +309,7 @@ func (s *Server) Stop() error {
}
}
maps.Clear(s.remoteForwardListeners)
s.mu.Unlock()
return nil
}
@@ -504,27 +507,7 @@ func (s *Server) checkTokenAge(token *gojwt.Token, jwtConfig *JWTConfig) error {
maxTokenAge = DefaultJWTMaxTokenAge
}
claims, ok := token.Claims.(gojwt.MapClaims)
if !ok {
userID := extractUserID(token)
return fmt.Errorf("token has invalid claims format (user=%s)", userID)
}
iat, ok := claims["iat"].(float64)
if !ok {
userID := extractUserID(token)
return fmt.Errorf("token missing iat claim (user=%s)", userID)
}
issuedAt := time.Unix(int64(iat), 0)
tokenAge := time.Since(issuedAt)
maxAge := time.Duration(maxTokenAge) * time.Second
if tokenAge > maxAge {
userID := getUserIDFromClaims(claims)
return fmt.Errorf("token expired for user=%s: age=%v, max=%v", userID, tokenAge, maxAge)
}
return nil
return jwt.CheckTokenAge(token, time.Duration(maxTokenAge)*time.Second)
}
func (s *Server) extractAndValidateUser(token *gojwt.Token) (*auth.UserAuth, error) {
@@ -555,27 +538,7 @@ func (s *Server) hasSSHAccess(userAuth *auth.UserAuth) bool {
}
func extractUserID(token *gojwt.Token) string {
if token == nil {
return "unknown"
}
claims, ok := token.Claims.(gojwt.MapClaims)
if !ok {
return "unknown"
}
return getUserIDFromClaims(claims)
}
func getUserIDFromClaims(claims gojwt.MapClaims) string {
if sub, ok := claims["sub"].(string); ok && sub != "" {
return sub
}
if userID, ok := claims["user_id"].(string); ok && userID != "" {
return userID
}
if email, ok := claims["email"].(string); ok && email != "" {
return email
}
return "unknown"
return jwt.UserIDFromToken(token)
}
func (s *Server) parseTokenWithoutValidation(tokenString string) (map[string]interface{}, error) {

View File

@@ -60,7 +60,7 @@ func (s *Server) sessionHandler(session ssh.Session) {
}
ptyReq, winCh, isPty := session.Pty()
hasCommand := len(session.Command()) > 0
hasCommand := session.RawCommand() != ""
if isPty && !hasCommand {
// ssh <host> - PTY interactive session (login)

View File

@@ -130,6 +130,10 @@ type SSHServerStateOutput struct {
Sessions []SSHSessionOutput `json:"sessions" yaml:"sessions"`
}
type VNCServerStateOutput struct {
Enabled bool `json:"enabled" yaml:"enabled"`
}
type OutputOverview struct {
Peers PeersStateOutput `json:"peers" yaml:"peers"`
CliVersion string `json:"cliVersion" yaml:"cliVersion"`
@@ -151,6 +155,7 @@ type OutputOverview struct {
LazyConnectionEnabled bool `json:"lazyConnectionEnabled" yaml:"lazyConnectionEnabled"`
ProfileName string `json:"profileName" yaml:"profileName"`
SSHServerState SSHServerStateOutput `json:"sshServer" yaml:"sshServer"`
VNCServerState VNCServerStateOutput `json:"vncServer" yaml:"vncServer"`
}
// ConvertToStatusOutputOverview converts protobuf status to the output overview.
@@ -171,6 +176,9 @@ func ConvertToStatusOutputOverview(pbFullStatus *proto.FullStatus, opts ConvertO
relayOverview := mapRelays(pbFullStatus.GetRelays())
sshServerOverview := mapSSHServer(pbFullStatus.GetSshServerState())
vncServerOverview := VNCServerStateOutput{
Enabled: pbFullStatus.GetVncServerState().GetEnabled(),
}
peersOverview := mapPeers(pbFullStatus.GetPeers(), opts.StatusFilter, opts.PrefixNamesFilter, opts.PrefixNamesFilterMap, opts.IPsFilter, opts.ConnectionTypeFilter)
overview := OutputOverview{
@@ -194,6 +202,7 @@ func ConvertToStatusOutputOverview(pbFullStatus *proto.FullStatus, opts ConvertO
LazyConnectionEnabled: pbFullStatus.GetLazyConnectionEnabled(),
ProfileName: opts.ProfileName,
SSHServerState: sshServerOverview,
VNCServerState: vncServerOverview,
}
if opts.Anonymize {
@@ -524,6 +533,11 @@ func (o *OutputOverview) GeneralSummary(showURL bool, showRelays bool, showNameS
}
}
vncServerStatus := "Disabled"
if o.VNCServerState.Enabled {
vncServerStatus = "Enabled"
}
peersCountString := fmt.Sprintf("%d/%d Connected", o.Peers.Connected, o.Peers.Total)
var forwardingRulesString string
@@ -553,6 +567,7 @@ func (o *OutputOverview) GeneralSummary(showURL bool, showRelays bool, showNameS
"Quantum resistance: %s\n"+
"Lazy connection: %s\n"+
"SSH Server: %s\n"+
"VNC Server: %s\n"+
"Networks: %s\n"+
"%s"+
"Peers count: %s\n",
@@ -570,6 +585,7 @@ func (o *OutputOverview) GeneralSummary(showURL bool, showRelays bool, showNameS
rosenpassEnabledStatus,
lazyConnectionEnabledStatus,
sshServerStatus,
vncServerStatus,
networks,
forwardingRulesString,
peersCountString,

View File

@@ -398,6 +398,9 @@ func TestParsingToJSON(t *testing.T) {
"sshServer":{
"enabled":false,
"sessions":[]
},
"vncServer":{
"enabled":false
}
}`
// @formatter:on
@@ -505,6 +508,8 @@ profileName: ""
sshServer:
enabled: false
sessions: []
vncServer:
enabled: false
`
assert.Equal(t, expectedYAML, yaml)
@@ -572,6 +577,7 @@ Interface type: Kernel
Quantum resistance: false
Lazy connection: false
SSH Server: Disabled
VNC Server: Disabled
Networks: 10.10.0.0/24
Peers count: 2/2 Connected
`, lastConnectionUpdate1, lastHandshake1, lastConnectionUpdate2, lastHandshake2, runtime.GOOS, runtime.GOARCH, overview.CliVersion)
@@ -596,6 +602,7 @@ Interface type: Kernel
Quantum resistance: false
Lazy connection: false
SSH Server: Disabled
VNC Server: Disabled
Networks: 10.10.0.0/24
Peers count: 2/2 Connected
`

View File

@@ -63,6 +63,7 @@ type Info struct {
RosenpassEnabled bool
RosenpassPermissive bool
ServerSSHAllowed bool
ServerVNCAllowed bool
DisableClientRoutes bool
DisableServerRoutes bool
@@ -78,21 +79,27 @@ type Info struct {
EnableSSHLocalPortForwarding bool
EnableSSHRemotePortForwarding bool
DisableSSHAuth bool
DisableVNCAuth bool
}
func (i *Info) SetFlags(
rosenpassEnabled, rosenpassPermissive bool,
serverSSHAllowed *bool,
serverVNCAllowed *bool,
disableClientRoutes, disableServerRoutes,
disableDNS, disableFirewall, blockLANAccess, blockInbound, lazyConnectionEnabled bool,
enableSSHRoot, enableSSHSFTP, enableSSHLocalPortForwarding, enableSSHRemotePortForwarding *bool,
disableSSHAuth *bool,
disableVNCAuth *bool,
) {
i.RosenpassEnabled = rosenpassEnabled
i.RosenpassPermissive = rosenpassPermissive
if serverSSHAllowed != nil {
i.ServerSSHAllowed = *serverSSHAllowed
}
if serverVNCAllowed != nil {
i.ServerVNCAllowed = *serverVNCAllowed
}
i.DisableClientRoutes = disableClientRoutes
i.DisableServerRoutes = disableServerRoutes
@@ -118,6 +125,9 @@ func (i *Info) SetFlags(
if disableSSHAuth != nil {
i.DisableSSHAuth = *disableSSHAuth
}
if disableVNCAuth != nil {
i.DisableVNCAuth = *disableVNCAuth
}
}
// extractUserAgent extracts Netbird's agent (client) name and version from the outgoing context
@@ -153,6 +163,9 @@ func networkAddresses() ([]NetworkAddress, error) {
var netAddresses []NetworkAddress
for _, iface := range interfaces {
if iface.Flags&net.FlagUp == 0 {
continue
}
if iface.HardwareAddr.String() == "" {
continue
}

View File

@@ -43,18 +43,24 @@ func GetInfo(ctx context.Context) *Info {
systemHostname, _ := os.Hostname()
addrs, err := networkAddresses()
if err != nil {
log.Warnf("failed to discover network addresses: %s", err)
}
return &Info{
GoOS: runtime.GOOS,
Kernel: osInfo[0],
Platform: runtime.GOARCH,
OS: osName,
OSVersion: osVersion,
Hostname: extractDeviceName(ctx, systemHostname),
CPUs: runtime.NumCPU(),
NetbirdVersion: version.NetbirdVersion(),
UIVersion: extractUserAgent(ctx),
KernelVersion: osInfo[1],
Environment: env,
GoOS: runtime.GOOS,
Kernel: osInfo[0],
Platform: runtime.GOARCH,
OS: osName,
OSVersion: osVersion,
Hostname: extractDeviceName(ctx, systemHostname),
CPUs: runtime.NumCPU(),
NetbirdVersion: version.NetbirdVersion(),
UIVersion: extractUserAgent(ctx),
KernelVersion: osInfo[1],
NetworkAddresses: addrs,
Environment: env,
}
}

View File

@@ -24,9 +24,10 @@ import (
// Initial state for the debug collection
type debugInitialState struct {
wasDown bool
logLevel proto.LogLevel
isLevelTrace bool
wasDown bool
needsRestoreUp bool
logLevel proto.LogLevel
isLevelTrace bool
}
// Debug collection parameters
@@ -371,46 +372,51 @@ func (s *serviceClient) configureServiceForDebug(
conn proto.DaemonServiceClient,
state *debugInitialState,
enablePersistence bool,
) error {
) {
if state.wasDown {
if _, err := conn.Up(s.ctx, &proto.UpRequest{}); err != nil {
return fmt.Errorf("bring service up: %v", err)
log.Warnf("failed to bring service up: %v", err)
} else {
log.Info("Service brought up for debug")
time.Sleep(time.Second * 10)
}
log.Info("Service brought up for debug")
time.Sleep(time.Second * 10)
}
if !state.isLevelTrace {
if _, err := conn.SetLogLevel(s.ctx, &proto.SetLogLevelRequest{Level: proto.LogLevel_TRACE}); err != nil {
return fmt.Errorf("set log level to TRACE: %v", err)
log.Warnf("failed to set log level to TRACE: %v", err)
} else {
log.Info("Log level set to TRACE for debug")
}
log.Info("Log level set to TRACE for debug")
}
if _, err := conn.Down(s.ctx, &proto.DownRequest{}); err != nil {
return fmt.Errorf("bring service down: %v", err)
log.Warnf("failed to bring service down: %v", err)
} else {
state.needsRestoreUp = !state.wasDown
time.Sleep(time.Second)
}
time.Sleep(time.Second)
if enablePersistence {
if _, err := conn.SetSyncResponsePersistence(s.ctx, &proto.SetSyncResponsePersistenceRequest{
Enabled: true,
}); err != nil {
return fmt.Errorf("enable sync response persistence: %v", err)
log.Warnf("failed to enable sync response persistence: %v", err)
} else {
log.Info("Sync response persistence enabled for debug")
}
log.Info("Sync response persistence enabled for debug")
}
if _, err := conn.Up(s.ctx, &proto.UpRequest{}); err != nil {
return fmt.Errorf("bring service back up: %v", err)
log.Warnf("failed to bring service back up: %v", err)
} else {
state.needsRestoreUp = false
time.Sleep(time.Second * 3)
}
time.Sleep(time.Second * 3)
if _, err := conn.StartCPUProfile(s.ctx, &proto.StartCPUProfileRequest{}); err != nil {
log.Warnf("failed to start CPU profiling: %v", err)
}
return nil
}
func (s *serviceClient) collectDebugData(
@@ -424,9 +430,7 @@ func (s *serviceClient) collectDebugData(
var wg sync.WaitGroup
startProgressTracker(ctx, &wg, params.duration, progress)
if err := s.configureServiceForDebug(conn, state, params.enablePersistence); err != nil {
return err
}
s.configureServiceForDebug(conn, state, params.enablePersistence)
wg.Wait()
progress.progressBar.Hide()
@@ -482,9 +486,17 @@ func (s *serviceClient) createDebugBundleFromCollection(
// Restore service to original state
func (s *serviceClient) restoreServiceState(conn proto.DaemonServiceClient, state *debugInitialState) {
if state.needsRestoreUp {
if _, err := conn.Up(s.ctx, &proto.UpRequest{}); err != nil {
log.Warnf("failed to restore up state: %v", err)
} else {
log.Info("Service state restored to up")
}
}
if state.wasDown {
if _, err := conn.Down(s.ctx, &proto.DownRequest{}); err != nil {
log.Errorf("Failed to restore down state: %v", err)
log.Warnf("failed to restore down state: %v", err)
} else {
log.Info("Service state restored to down")
}
@@ -492,7 +504,7 @@ func (s *serviceClient) restoreServiceState(conn proto.DaemonServiceClient, stat
if !state.isLevelTrace {
if _, err := conn.SetLogLevel(s.ctx, &proto.SetLogLevelRequest{Level: state.logLevel}); err != nil {
log.Errorf("Failed to restore log level: %v", err)
log.Warnf("failed to restore log level: %v", err)
} else {
log.Info("Log level restored to original setting")
}

View File

@@ -0,0 +1,474 @@
//go:build windows
package server
import (
crand "crypto/rand"
"encoding/hex"
"encoding/json"
"fmt"
"io"
"net"
"os"
"sync"
"time"
"unsafe"
log "github.com/sirupsen/logrus"
"golang.org/x/sys/windows"
)
const (
agentPort = "15900"
// agentTokenLen is the length of the random authentication token
// used to verify that connections to the agent come from the service.
agentTokenLen = 32
stillActive = 259
tokenPrimary = 1
securityImpersonation = 2
tokenSessionID = 12
createUnicodeEnvironment = 0x00000400
createNoWindow = 0x08000000
)
var (
kernel32 = windows.NewLazySystemDLL("kernel32.dll")
advapi32 = windows.NewLazySystemDLL("advapi32.dll")
userenv = windows.NewLazySystemDLL("userenv.dll")
procWTSGetActiveConsoleSessionId = kernel32.NewProc("WTSGetActiveConsoleSessionId")
procSetTokenInformation = advapi32.NewProc("SetTokenInformation")
procCreateEnvironmentBlock = userenv.NewProc("CreateEnvironmentBlock")
procDestroyEnvironmentBlock = userenv.NewProc("DestroyEnvironmentBlock")
wtsapi32 = windows.NewLazySystemDLL("wtsapi32.dll")
procWTSEnumerateSessionsW = wtsapi32.NewProc("WTSEnumerateSessionsW")
procWTSFreeMemory = wtsapi32.NewProc("WTSFreeMemory")
)
// GetCurrentSessionID returns the session ID of the current process.
func GetCurrentSessionID() uint32 {
var token windows.Token
if err := windows.OpenProcessToken(windows.CurrentProcess(),
windows.TOKEN_QUERY, &token); err != nil {
return 0
}
defer token.Close()
var id uint32
var ret uint32
_ = windows.GetTokenInformation(token, windows.TokenSessionId,
(*byte)(unsafe.Pointer(&id)), 4, &ret)
return id
}
func getConsoleSessionID() uint32 {
r, _, _ := procWTSGetActiveConsoleSessionId.Call()
return uint32(r)
}
const (
wtsActive = 0
wtsConnected = 1
wtsDisconnected = 4
)
type wtsSessionInfo struct {
SessionID uint32
WinStationName [66]byte // actually *uint16, but we just need the struct size
State uint32
}
// getActiveSessionID returns the session ID of the best session to attach to.
// Prefers an active (logged-in, interactive) session over the console session.
// This avoids kicking out an RDP user when the console is at the login screen.
func getActiveSessionID() uint32 {
var sessionInfo uintptr
var count uint32
r, _, _ := procWTSEnumerateSessionsW.Call(
0, // WTS_CURRENT_SERVER_HANDLE
0, // reserved
1, // version
uintptr(unsafe.Pointer(&sessionInfo)),
uintptr(unsafe.Pointer(&count)),
)
if r == 0 || count == 0 {
return getConsoleSessionID()
}
defer procWTSFreeMemory.Call(sessionInfo)
type wtsSession struct {
SessionID uint32
Station *uint16
State uint32
}
sessions := unsafe.Slice((*wtsSession)(unsafe.Pointer(sessionInfo)), count)
// Find the first active session (not session 0, which is the services session).
var bestID uint32
found := false
for _, s := range sessions {
if s.SessionID == 0 {
continue
}
if s.State == wtsActive {
bestID = s.SessionID
found = true
break
}
}
if !found {
return getConsoleSessionID()
}
return bestID
}
// getSystemTokenForSession duplicates the current SYSTEM token and sets its
// session ID so the spawned process runs in the target session. Using a SYSTEM
// token gives access to both Default and Winlogon desktops plus UIPI bypass.
func getSystemTokenForSession(sessionID uint32) (windows.Token, error) {
var cur windows.Token
if err := windows.OpenProcessToken(windows.CurrentProcess(),
windows.MAXIMUM_ALLOWED, &cur); err != nil {
return 0, fmt.Errorf("OpenProcessToken: %w", err)
}
defer cur.Close()
var dup windows.Token
if err := windows.DuplicateTokenEx(cur, windows.MAXIMUM_ALLOWED, nil,
securityImpersonation, tokenPrimary, &dup); err != nil {
return 0, fmt.Errorf("DuplicateTokenEx: %w", err)
}
sid := sessionID
r, _, err := procSetTokenInformation.Call(
uintptr(dup),
uintptr(tokenSessionID),
uintptr(unsafe.Pointer(&sid)),
unsafe.Sizeof(sid),
)
if r == 0 {
dup.Close()
return 0, fmt.Errorf("SetTokenInformation(SessionId=%d): %w", sessionID, err)
}
return dup, nil
}
const agentTokenEnvVar = "NB_VNC_AGENT_TOKEN"
// injectEnvVar appends a KEY=VALUE entry to a Unicode environment block.
// The block is a sequence of null-terminated UTF-16 strings, terminated by
// an extra null. Returns a new block pointer with the entry added.
func injectEnvVar(envBlock uintptr, key, value string) uintptr {
entry := key + "=" + value
// Walk the existing block to find its total length.
ptr := (*uint16)(unsafe.Pointer(envBlock))
var totalChars int
for {
ch := *(*uint16)(unsafe.Pointer(uintptr(unsafe.Pointer(ptr)) + uintptr(totalChars)*2))
if ch == 0 {
// Check for double-null terminator.
next := *(*uint16)(unsafe.Pointer(uintptr(unsafe.Pointer(ptr)) + uintptr(totalChars+1)*2))
totalChars++
if next == 0 {
// End of block (don't count the final null yet, we'll rebuild).
break
}
} else {
totalChars++
}
}
entryUTF16, _ := windows.UTF16FromString(entry)
// New block: existing entries + new entry (null-terminated) + final null.
newLen := totalChars + len(entryUTF16) + 1
newBlock := make([]uint16, newLen)
// Copy existing entries (up to but not including the final null).
for i := range totalChars {
newBlock[i] = *(*uint16)(unsafe.Pointer(uintptr(unsafe.Pointer(ptr)) + uintptr(i)*2))
}
copy(newBlock[totalChars:], entryUTF16)
newBlock[newLen-1] = 0 // final null terminator
return uintptr(unsafe.Pointer(&newBlock[0]))
}
func spawnAgentInSession(sessionID uint32, port string, authToken string) (windows.Handle, error) {
token, err := getSystemTokenForSession(sessionID)
if err != nil {
return 0, fmt.Errorf("get SYSTEM token for session %d: %w", sessionID, err)
}
defer token.Close()
var envBlock uintptr
r, _, _ := procCreateEnvironmentBlock.Call(
uintptr(unsafe.Pointer(&envBlock)),
uintptr(token),
0,
)
if r != 0 {
defer procDestroyEnvironmentBlock.Call(envBlock)
}
// Inject the auth token into the environment block so it doesn't appear
// in the process command line (visible via tasklist/wmic).
if r != 0 {
envBlock = injectEnvVar(envBlock, agentTokenEnvVar, authToken)
}
exePath, err := os.Executable()
if err != nil {
return 0, fmt.Errorf("get executable path: %w", err)
}
cmdLine := fmt.Sprintf(`"%s" vnc-agent --port %s`, exePath, port)
cmdLineW, err := windows.UTF16PtrFromString(cmdLine)
if err != nil {
return 0, fmt.Errorf("UTF16 cmdline: %w", err)
}
// Create an inheritable pipe for the agent's stderr so we can relog
// its output in the service process.
var sa windows.SecurityAttributes
sa.Length = uint32(unsafe.Sizeof(sa))
sa.InheritHandle = 1
var stderrRead, stderrWrite windows.Handle
if err := windows.CreatePipe(&stderrRead, &stderrWrite, &sa, 0); err != nil {
return 0, fmt.Errorf("create stderr pipe: %w", err)
}
// The read end must NOT be inherited by the child.
windows.SetHandleInformation(stderrRead, windows.HANDLE_FLAG_INHERIT, 0)
desktop, _ := windows.UTF16PtrFromString(`WinSta0\Default`)
si := windows.StartupInfo{
Cb: uint32(unsafe.Sizeof(windows.StartupInfo{})),
Desktop: desktop,
Flags: windows.STARTF_USESHOWWINDOW | windows.STARTF_USESTDHANDLES,
ShowWindow: 0,
StdErr: stderrWrite,
StdOutput: stderrWrite,
}
var pi windows.ProcessInformation
var envPtr *uint16
if envBlock != 0 {
envPtr = (*uint16)(unsafe.Pointer(envBlock))
}
err = windows.CreateProcessAsUser(
token, nil, cmdLineW,
nil, nil, true, // inheritHandles=true for the pipe
createUnicodeEnvironment|createNoWindow,
envPtr, nil, &si, &pi,
)
// Close the write end in the parent so reads will get EOF when the child exits.
windows.CloseHandle(stderrWrite)
if err != nil {
windows.CloseHandle(stderrRead)
return 0, fmt.Errorf("CreateProcessAsUser: %w", err)
}
windows.CloseHandle(pi.Thread)
// Relog agent output in the service with a [vnc-agent] prefix.
go relogAgentOutput(stderrRead)
log.Infof("spawned agent PID=%d in session %d on port %s", pi.ProcessId, sessionID, port)
return pi.Process, nil
}
// sessionManager monitors the active console session and ensures a VNC agent
// process is running in it. When the session changes (e.g., user switch, RDP
// connect/disconnect), it kills the old agent and spawns a new one.
type sessionManager struct {
port string
mu sync.Mutex
agentProc windows.Handle
sessionID uint32
authToken string
done chan struct{}
}
func newSessionManager(port string) *sessionManager {
return &sessionManager{port: port, sessionID: ^uint32(0), done: make(chan struct{})}
}
// generateAuthToken creates a new random hex token for agent authentication.
func generateAuthToken() string {
b := make([]byte, agentTokenLen)
if _, err := crand.Read(b); err != nil {
log.Warnf("generate agent auth token: %v", err)
return ""
}
return hex.EncodeToString(b)
}
// AuthToken returns the current agent authentication token.
func (m *sessionManager) AuthToken() string {
m.mu.Lock()
defer m.mu.Unlock()
return m.authToken
}
// Stop signals the session manager to exit its polling loop.
func (m *sessionManager) Stop() {
select {
case <-m.done:
default:
close(m.done)
}
}
func (m *sessionManager) run() {
ticker := time.NewTicker(2 * time.Second)
defer ticker.Stop()
for {
sid := getActiveSessionID()
m.mu.Lock()
if sid != m.sessionID {
log.Infof("active session changed: %d -> %d", m.sessionID, sid)
m.killAgent()
m.sessionID = sid
}
if m.agentProc != 0 {
var code uint32
_ = windows.GetExitCodeProcess(m.agentProc, &code)
if code != stillActive {
log.Infof("agent exited (code=%d), respawning", code)
windows.CloseHandle(m.agentProc)
m.agentProc = 0
}
}
if m.agentProc == 0 && sid != 0xFFFFFFFF {
m.authToken = generateAuthToken()
h, err := spawnAgentInSession(sid, m.port, m.authToken)
if err != nil {
log.Warnf("spawn agent in session %d: %v", sid, err)
m.authToken = ""
} else {
m.agentProc = h
}
}
m.mu.Unlock()
select {
case <-m.done:
m.mu.Lock()
m.killAgent()
m.mu.Unlock()
return
case <-ticker.C:
}
}
}
func (m *sessionManager) killAgent() {
if m.agentProc != 0 {
_ = windows.TerminateProcess(m.agentProc, 0)
windows.CloseHandle(m.agentProc)
m.agentProc = 0
log.Info("killed old agent")
}
}
// relogAgentOutput reads JSON log lines from the agent's stderr pipe and
// relogs them at the correct level with the service's formatter.
func relogAgentOutput(pipe windows.Handle) {
defer windows.CloseHandle(pipe)
f := os.NewFile(uintptr(pipe), "vnc-agent-stderr")
defer f.Close()
entry := log.WithField("component", "vnc-agent")
dec := json.NewDecoder(f)
for dec.More() {
var m map[string]any
if err := dec.Decode(&m); err != nil {
break
}
msg, _ := m["msg"].(string)
if msg == "" {
continue
}
// Forward extra fields from the agent (skip standard logrus fields).
// Remap "caller" to "source" so it doesn't conflict with logrus internals
// but still shows the original file/line from the agent process.
fields := make(log.Fields)
for k, v := range m {
switch k {
case "msg", "level", "time", "func":
continue
case "caller":
fields["source"] = v
default:
fields[k] = v
}
}
e := entry.WithFields(fields)
switch m["level"] {
case "error":
e.Error(msg)
case "warning":
e.Warn(msg)
case "debug":
e.Debug(msg)
case "trace":
e.Trace(msg)
default:
e.Info(msg)
}
}
}
// proxyToAgent connects to the agent, sends the auth token, then proxies
// the VNC client connection bidirectionally.
func proxyToAgent(client net.Conn, port string, authToken string) {
defer client.Close()
addr := "127.0.0.1:" + port
var agentConn net.Conn
var err error
for range 50 {
agentConn, err = net.DialTimeout("tcp", addr, time.Second)
if err == nil {
break
}
time.Sleep(200 * time.Millisecond)
}
if err != nil {
log.Warnf("proxy cannot reach agent at %s: %v", addr, err)
return
}
defer agentConn.Close()
// Send the auth token so the agent can verify this connection
// comes from the trusted service process.
tokenBytes, _ := hex.DecodeString(authToken)
if _, err := agentConn.Write(tokenBytes); err != nil {
log.Warnf("send auth token to agent: %v", err)
return
}
log.Debugf("proxy connected to agent, starting bidirectional copy")
done := make(chan struct{}, 2)
cp := func(label string, dst, src net.Conn) {
n, err := io.Copy(dst, src)
log.Debugf("proxy %s: %d bytes, err=%v", label, n, err)
done <- struct{}{}
}
go cp("client→agent", agentConn, client)
go cp("agent→client", client, agentConn)
<-done
}

View File

@@ -0,0 +1,274 @@
//go:build darwin && !ios
package server
import (
"fmt"
"image"
"sync"
"time"
"unsafe"
"github.com/ebitengine/purego"
log "github.com/sirupsen/logrus"
)
var darwinCaptureOnce sync.Once
var (
cgMainDisplayID func() uint32
cgDisplayPixelsWide func(uint32) uintptr
cgDisplayPixelsHigh func(uint32) uintptr
cgDisplayCreateImage func(uint32) uintptr
cgImageGetWidth func(uintptr) uintptr
cgImageGetHeight func(uintptr) uintptr
cgImageGetBytesPerRow func(uintptr) uintptr
cgImageGetBitsPerPixel func(uintptr) uintptr
cgImageGetDataProvider func(uintptr) uintptr
cgDataProviderCopyData func(uintptr) uintptr
cgImageRelease func(uintptr)
cfDataGetLength func(uintptr) int64
cfDataGetBytePtr func(uintptr) uintptr
cfRelease func(uintptr)
cgPreflightScreenCaptureAccess func() bool
cgRequestScreenCaptureAccess func() bool
darwinCaptureReady bool
)
func initDarwinCapture() {
darwinCaptureOnce.Do(func() {
cg, err := purego.Dlopen("/System/Library/Frameworks/CoreGraphics.framework/CoreGraphics", purego.RTLD_NOW|purego.RTLD_GLOBAL)
if err != nil {
log.Debugf("load CoreGraphics: %v", err)
return
}
cf, err := purego.Dlopen("/System/Library/Frameworks/CoreFoundation.framework/CoreFoundation", purego.RTLD_NOW|purego.RTLD_GLOBAL)
if err != nil {
log.Debugf("load CoreFoundation: %v", err)
return
}
purego.RegisterLibFunc(&cgMainDisplayID, cg, "CGMainDisplayID")
purego.RegisterLibFunc(&cgDisplayPixelsWide, cg, "CGDisplayPixelsWide")
purego.RegisterLibFunc(&cgDisplayPixelsHigh, cg, "CGDisplayPixelsHigh")
purego.RegisterLibFunc(&cgDisplayCreateImage, cg, "CGDisplayCreateImage")
purego.RegisterLibFunc(&cgImageGetWidth, cg, "CGImageGetWidth")
purego.RegisterLibFunc(&cgImageGetHeight, cg, "CGImageGetHeight")
purego.RegisterLibFunc(&cgImageGetBytesPerRow, cg, "CGImageGetBytesPerRow")
purego.RegisterLibFunc(&cgImageGetBitsPerPixel, cg, "CGImageGetBitsPerPixel")
purego.RegisterLibFunc(&cgImageGetDataProvider, cg, "CGImageGetDataProvider")
purego.RegisterLibFunc(&cgDataProviderCopyData, cg, "CGDataProviderCopyData")
purego.RegisterLibFunc(&cgImageRelease, cg, "CGImageRelease")
purego.RegisterLibFunc(&cfDataGetLength, cf, "CFDataGetLength")
purego.RegisterLibFunc(&cfDataGetBytePtr, cf, "CFDataGetBytePtr")
purego.RegisterLibFunc(&cfRelease, cf, "CFRelease")
// Screen capture permission APIs (macOS 11+). Might not exist on older versions.
if sym, err := purego.Dlsym(cg, "CGPreflightScreenCaptureAccess"); err == nil {
purego.RegisterFunc(&cgPreflightScreenCaptureAccess, sym)
}
if sym, err := purego.Dlsym(cg, "CGRequestScreenCaptureAccess"); err == nil {
purego.RegisterFunc(&cgRequestScreenCaptureAccess, sym)
}
darwinCaptureReady = true
})
}
// CGCapturer captures the macOS main display using Core Graphics.
type CGCapturer struct {
displayID uint32
w, h int
}
// NewCGCapturer creates a screen capturer for the main display.
func NewCGCapturer() (*CGCapturer, error) {
initDarwinCapture()
if !darwinCaptureReady {
return nil, fmt.Errorf("CoreGraphics not available")
}
// Request Screen Recording permission (shows system dialog on macOS 11+).
if cgPreflightScreenCaptureAccess != nil && !cgPreflightScreenCaptureAccess() {
if cgRequestScreenCaptureAccess != nil {
cgRequestScreenCaptureAccess()
}
log.Warn("Screen Recording permission not granted. " +
"Grant in System Settings > Privacy & Security > Screen Recording, then restart.")
}
displayID := cgMainDisplayID()
w := int(cgDisplayPixelsWide(displayID))
h := int(cgDisplayPixelsHigh(displayID))
if w == 0 || h == 0 {
return nil, fmt.Errorf("display dimensions are zero")
}
log.Infof("macOS capturer ready: %dx%d (display=%d)", w, h, displayID)
return &CGCapturer{displayID: displayID, w: w, h: h}, nil
}
// Width returns the screen width.
func (c *CGCapturer) Width() int { return c.w }
// Height returns the screen height.
func (c *CGCapturer) Height() int { return c.h }
// Capture returns the current screen as an RGBA image.
func (c *CGCapturer) Capture() (*image.RGBA, error) {
cgImage := cgDisplayCreateImage(c.displayID)
if cgImage == 0 {
return nil, fmt.Errorf("CGDisplayCreateImage returned nil (screen recording permission?)")
}
defer cgImageRelease(cgImage)
w := int(cgImageGetWidth(cgImage))
h := int(cgImageGetHeight(cgImage))
bytesPerRow := int(cgImageGetBytesPerRow(cgImage))
bpp := int(cgImageGetBitsPerPixel(cgImage))
provider := cgImageGetDataProvider(cgImage)
if provider == 0 {
return nil, fmt.Errorf("CGImageGetDataProvider returned nil")
}
cfData := cgDataProviderCopyData(provider)
if cfData == 0 {
return nil, fmt.Errorf("CGDataProviderCopyData returned nil")
}
defer cfRelease(cfData)
dataLen := int(cfDataGetLength(cfData))
dataPtr := cfDataGetBytePtr(cfData)
if dataPtr == 0 || dataLen == 0 {
return nil, fmt.Errorf("empty image data")
}
src := unsafe.Slice((*byte)(unsafe.Pointer(dataPtr)), dataLen)
img := image.NewRGBA(image.Rect(0, 0, w, h))
bytesPerPixel := bpp / 8
for row := 0; row < h; row++ {
srcOff := row * bytesPerRow
dstOff := row * img.Stride
for col := 0; col < w; col++ {
si := srcOff + col*bytesPerPixel
di := dstOff + col*4
img.Pix[di+0] = src[si+2] // R (from BGRA)
img.Pix[di+1] = src[si+1] // G
img.Pix[di+2] = src[si+0] // B
img.Pix[di+3] = 0xff
}
}
return img, nil
}
// MacPoller wraps CGCapturer in a continuous capture loop.
type MacPoller struct {
mu sync.Mutex
frame *image.RGBA
w, h int
done chan struct{}
}
// NewMacPoller creates a capturer that continuously grabs the macOS display.
func NewMacPoller() *MacPoller {
p := &MacPoller{done: make(chan struct{})}
go p.loop()
return p
}
// Close stops the capture loop.
func (p *MacPoller) Close() {
select {
case <-p.done:
default:
close(p.done)
}
}
// Width returns the screen width.
func (p *MacPoller) Width() int {
p.mu.Lock()
defer p.mu.Unlock()
return p.w
}
// Height returns the screen height.
func (p *MacPoller) Height() int {
p.mu.Lock()
defer p.mu.Unlock()
return p.h
}
// Capture returns the most recent frame.
func (p *MacPoller) Capture() (*image.RGBA, error) {
p.mu.Lock()
img := p.frame
p.mu.Unlock()
if img != nil {
return img, nil
}
return nil, fmt.Errorf("no frame available yet")
}
func (p *MacPoller) loop() {
var capturer *CGCapturer
var initFails int
for {
select {
case <-p.done:
return
default:
}
if capturer == nil {
var err error
capturer, err = NewCGCapturer()
if err != nil {
initFails++
if initFails <= maxCapturerRetries {
log.Debugf("macOS capturer: %v (attempt %d/%d)", err, initFails, maxCapturerRetries)
select {
case <-p.done:
return
case <-time.After(2 * time.Second):
}
continue
}
log.Warnf("macOS capturer unavailable after %d attempts, stopping poller", maxCapturerRetries)
return
}
initFails = 0
p.mu.Lock()
p.w, p.h = capturer.Width(), capturer.Height()
p.mu.Unlock()
}
img, err := capturer.Capture()
if err != nil {
log.Debugf("macOS capture: %v", err)
capturer = nil
select {
case <-p.done:
return
case <-time.After(500 * time.Millisecond):
}
continue
}
p.mu.Lock()
p.frame = img
p.mu.Unlock()
select {
case <-p.done:
return
case <-time.After(33 * time.Millisecond): // ~30 fps
}
}
}
var _ ScreenCapturer = (*MacPoller)(nil)

View File

@@ -0,0 +1,99 @@
//go:build windows
package server
import (
"errors"
"fmt"
"image"
"github.com/kirides/go-d3d/d3d11"
"github.com/kirides/go-d3d/outputduplication"
)
// dxgiCapturer captures the desktop using DXGI Desktop Duplication.
// Provides GPU-accelerated capture with native dirty rect tracking.
// Only works from the interactive user session, not Session 0.
//
// Uses a double-buffer: DXGI writes into img, then we copy to the current
// output buffer and hand it out. Alternating between two output buffers
// avoids allocating a new image.RGBA per frame (~8MB at 1080p, 30fps).
type dxgiCapturer struct {
dup *outputduplication.OutputDuplicator
device *d3d11.ID3D11Device
ctx *d3d11.ID3D11DeviceContext
img *image.RGBA
out [2]*image.RGBA
outIdx int
width int
height int
}
func newDXGICapturer() (*dxgiCapturer, error) {
device, deviceCtx, err := d3d11.NewD3D11Device()
if err != nil {
return nil, fmt.Errorf("create D3D11 device: %w", err)
}
dup, err := outputduplication.NewIDXGIOutputDuplication(device, deviceCtx, 0)
if err != nil {
device.Release()
deviceCtx.Release()
return nil, fmt.Errorf("create output duplication: %w", err)
}
w, h := screenSize()
if w == 0 || h == 0 {
dup.Release()
device.Release()
deviceCtx.Release()
return nil, fmt.Errorf("screen dimensions are zero")
}
rect := image.Rect(0, 0, w, h)
c := &dxgiCapturer{
dup: dup,
device: device,
ctx: deviceCtx,
img: image.NewRGBA(rect),
out: [2]*image.RGBA{image.NewRGBA(rect), image.NewRGBA(rect)},
width: w,
height: h,
}
// Grab the initial frame with a longer timeout to ensure we have
// a valid image before returning.
_ = dup.GetImage(c.img, 2000)
return c, nil
}
func (c *dxgiCapturer) capture() (*image.RGBA, error) {
err := c.dup.GetImage(c.img, 100)
if err != nil && !errors.Is(err, outputduplication.ErrNoImageYet) {
return nil, err
}
// Copy into the next output buffer. The DesktopCapturer hands out the
// returned pointer to VNC sessions that read pixels concurrently, so we
// alternate between two pre-allocated buffers instead of allocating per frame.
out := c.out[c.outIdx]
c.outIdx ^= 1
copy(out.Pix, c.img.Pix)
return out, nil
}
func (c *dxgiCapturer) close() {
if c.dup != nil {
c.dup.Release()
c.dup = nil
}
if c.ctx != nil {
c.ctx.Release()
c.ctx = nil
}
if c.device != nil {
c.device.Release()
c.device = nil
}
}

View File

@@ -0,0 +1,461 @@
//go:build windows
package server
import (
"fmt"
"image"
"runtime"
"sync"
"sync/atomic"
"time"
"unsafe"
log "github.com/sirupsen/logrus"
"golang.org/x/sys/windows"
)
var (
gdi32 = windows.NewLazySystemDLL("gdi32.dll")
user32 = windows.NewLazySystemDLL("user32.dll")
procGetDC = user32.NewProc("GetDC")
procReleaseDC = user32.NewProc("ReleaseDC")
procCreateCompatDC = gdi32.NewProc("CreateCompatibleDC")
procCreateDIBSection = gdi32.NewProc("CreateDIBSection")
procSelectObject = gdi32.NewProc("SelectObject")
procDeleteObject = gdi32.NewProc("DeleteObject")
procDeleteDC = gdi32.NewProc("DeleteDC")
procBitBlt = gdi32.NewProc("BitBlt")
procGetSystemMetrics = user32.NewProc("GetSystemMetrics")
// Desktop switching for service/Session 0 capture.
procOpenInputDesktop = user32.NewProc("OpenInputDesktop")
procSetThreadDesktop = user32.NewProc("SetThreadDesktop")
procCloseDesktop = user32.NewProc("CloseDesktop")
procOpenWindowStation = user32.NewProc("OpenWindowStationW")
procSetProcessWindowStation = user32.NewProc("SetProcessWindowStation")
procCloseWindowStation = user32.NewProc("CloseWindowStation")
procGetUserObjectInformationW = user32.NewProc("GetUserObjectInformationW")
)
const uoiName = 2
const (
smCxScreen = 0
smCyScreen = 1
srccopy = 0x00CC0020
dibRgbColors = 0
)
type bitmapInfoHeader struct {
Size uint32
Width int32
Height int32
Planes uint16
BitCount uint16
Compression uint32
SizeImage uint32
XPelsPerMeter int32
YPelsPerMeter int32
ClrUsed uint32
ClrImportant uint32
}
type bitmapInfo struct {
Header bitmapInfoHeader
}
// setupInteractiveWindowStation associates the current process with WinSta0,
// the interactive window station. This is required for a SYSTEM service in
// Session 0 to call OpenInputDesktop for screen capture and input injection.
func setupInteractiveWindowStation() error {
name, err := windows.UTF16PtrFromString("WinSta0")
if err != nil {
return fmt.Errorf("UTF16 WinSta0: %w", err)
}
hWinSta, _, err := procOpenWindowStation.Call(
uintptr(unsafe.Pointer(name)),
0,
uintptr(windows.MAXIMUM_ALLOWED),
)
if hWinSta == 0 {
return fmt.Errorf("OpenWindowStation(WinSta0): %w", err)
}
r, _, err := procSetProcessWindowStation.Call(hWinSta)
if r == 0 {
procCloseWindowStation.Call(hWinSta)
return fmt.Errorf("SetProcessWindowStation: %w", err)
}
log.Info("process window station set to WinSta0 (interactive)")
return nil
}
func screenSize() (int, int) {
w, _, _ := procGetSystemMetrics.Call(uintptr(smCxScreen))
h, _, _ := procGetSystemMetrics.Call(uintptr(smCyScreen))
return int(w), int(h)
}
func getDesktopName(hDesk uintptr) string {
var buf [256]uint16
var needed uint32
procGetUserObjectInformationW.Call(hDesk, uoiName,
uintptr(unsafe.Pointer(&buf[0])), 512,
uintptr(unsafe.Pointer(&needed)))
return windows.UTF16ToString(buf[:])
}
// switchToInputDesktop opens the desktop currently receiving user input
// and sets it as the calling OS thread's desktop. Must be called from a
// goroutine locked to its OS thread via runtime.LockOSThread().
func switchToInputDesktop() (bool, string) {
hDesk, _, _ := procOpenInputDesktop.Call(0, 0, uintptr(windows.MAXIMUM_ALLOWED))
if hDesk == 0 {
return false, ""
}
name := getDesktopName(hDesk)
ret, _, _ := procSetThreadDesktop.Call(hDesk)
procCloseDesktop.Call(hDesk)
return ret != 0, name
}
// gdiCapturer captures the desktop screen using GDI BitBlt.
// GDI objects (DC, DIBSection) are allocated once and reused across frames.
type gdiCapturer struct {
mu sync.Mutex
width int
height int
// Pre-allocated GDI resources, reused across captures.
memDC uintptr
bmp uintptr
bits uintptr
}
func newGDICapturer() (*gdiCapturer, error) {
w, h := screenSize()
if w == 0 || h == 0 {
return nil, fmt.Errorf("screen dimensions are zero")
}
c := &gdiCapturer{width: w, height: h}
if err := c.allocGDI(); err != nil {
return nil, err
}
return c, nil
}
// allocGDI pre-allocates the compatible DC and DIB section for reuse.
func (c *gdiCapturer) allocGDI() error {
screenDC, _, _ := procGetDC.Call(0)
if screenDC == 0 {
return fmt.Errorf("GetDC returned 0")
}
defer procReleaseDC.Call(0, screenDC)
memDC, _, _ := procCreateCompatDC.Call(screenDC)
if memDC == 0 {
return fmt.Errorf("CreateCompatibleDC returned 0")
}
bi := bitmapInfo{
Header: bitmapInfoHeader{
Size: uint32(unsafe.Sizeof(bitmapInfoHeader{})),
Width: int32(c.width),
Height: -int32(c.height), // negative = top-down DIB
Planes: 1,
BitCount: 32,
},
}
var bits uintptr
bmp, _, _ := procCreateDIBSection.Call(
screenDC,
uintptr(unsafe.Pointer(&bi)),
dibRgbColors,
uintptr(unsafe.Pointer(&bits)),
0, 0,
)
if bmp == 0 || bits == 0 {
procDeleteDC.Call(memDC)
return fmt.Errorf("CreateDIBSection returned 0")
}
procSelectObject.Call(memDC, bmp)
c.memDC = memDC
c.bmp = bmp
c.bits = bits
return nil
}
func (c *gdiCapturer) close() { c.freeGDI() }
// freeGDI releases pre-allocated GDI resources.
func (c *gdiCapturer) freeGDI() {
if c.bmp != 0 {
procDeleteObject.Call(c.bmp)
c.bmp = 0
}
if c.memDC != 0 {
procDeleteDC.Call(c.memDC)
c.memDC = 0
}
c.bits = 0
}
func (c *gdiCapturer) capture() (*image.RGBA, error) {
c.mu.Lock()
defer c.mu.Unlock()
if c.memDC == 0 {
return nil, fmt.Errorf("GDI resources not allocated")
}
screenDC, _, _ := procGetDC.Call(0)
if screenDC == 0 {
return nil, fmt.Errorf("GetDC returned 0")
}
defer procReleaseDC.Call(0, screenDC)
ret, _, _ := procBitBlt.Call(c.memDC, 0, 0, uintptr(c.width), uintptr(c.height),
screenDC, 0, 0, srccopy)
if ret == 0 {
return nil, fmt.Errorf("BitBlt returned 0")
}
n := c.width * c.height * 4
raw := unsafe.Slice((*byte)(unsafe.Pointer(c.bits)), n)
// GDI gives BGRA, the RFB encoder expects RGBA (img.Pix layout).
// Swap R and B in bulk using uint32 operations (one load + mask + shift
// per pixel instead of three separate byte assignments).
img := image.NewRGBA(image.Rect(0, 0, c.width, c.height))
pix := img.Pix
copy(pix, raw)
swizzleBGRAtoRGBA(pix)
return img, nil
}
// DesktopCapturer captures the interactive desktop, handling desktop transitions
// (login screen, UAC prompts). A dedicated OS-locked goroutine continuously
// captures frames, which are retrieved by the VNC session on demand.
// Capture pauses automatically when no clients are connected.
type DesktopCapturer struct {
mu sync.Mutex
frame *image.RGBA
w, h int
// clients tracks the number of active VNC sessions. When zero, the
// capture loop idles instead of grabbing frames.
clients atomic.Int32
// wake is signaled when a client connects and the loop should resume.
wake chan struct{}
// done is closed when Close is called, terminating the capture loop.
done chan struct{}
}
// NewDesktopCapturer creates a capturer that continuously grabs the active desktop.
func NewDesktopCapturer() *DesktopCapturer {
c := &DesktopCapturer{
wake: make(chan struct{}, 1),
done: make(chan struct{}),
}
go c.loop()
return c
}
// ClientConnect increments the active client count, resuming capture if needed.
func (c *DesktopCapturer) ClientConnect() {
c.clients.Add(1)
select {
case c.wake <- struct{}{}:
default:
}
}
// ClientDisconnect decrements the active client count.
func (c *DesktopCapturer) ClientDisconnect() {
c.clients.Add(-1)
}
// Close stops the capture loop and releases resources.
func (c *DesktopCapturer) Close() {
select {
case <-c.done:
default:
close(c.done)
}
}
// Width returns the current screen width.
func (c *DesktopCapturer) Width() int {
c.mu.Lock()
defer c.mu.Unlock()
return c.w
}
// Height returns the current screen height.
func (c *DesktopCapturer) Height() int {
c.mu.Lock()
defer c.mu.Unlock()
return c.h
}
// Capture returns the most recent desktop frame.
func (c *DesktopCapturer) Capture() (*image.RGBA, error) {
c.mu.Lock()
img := c.frame
c.mu.Unlock()
if img != nil {
return img, nil
}
return nil, fmt.Errorf("no frame available yet")
}
// waitForClient blocks until a client connects or the capturer is closed.
func (c *DesktopCapturer) waitForClient() bool {
if c.clients.Load() > 0 {
return true
}
select {
case <-c.wake:
return true
case <-c.done:
return false
}
}
func (c *DesktopCapturer) loop() {
runtime.LockOSThread()
// When running as a Windows service (Session 0), we need to attach to the
// interactive window station before OpenInputDesktop will succeed.
if err := setupInteractiveWindowStation(); err != nil {
log.Warnf("attach to interactive window station: %v", err)
}
frameTicker := time.NewTicker(33 * time.Millisecond) // ~30 fps
defer frameTicker.Stop()
retryTimer := time.NewTimer(0)
retryTimer.Stop()
defer retryTimer.Stop()
type frameCapturer interface {
capture() (*image.RGBA, error)
close()
}
var cap frameCapturer
var desktopFails int
var lastDesktop string
createCapturer := func() (frameCapturer, error) {
dc, err := newDXGICapturer()
if err == nil {
log.Info("using DXGI Desktop Duplication for capture")
return dc, nil
}
log.Debugf("DXGI unavailable (%v), falling back to GDI", err)
gc, err := newGDICapturer()
if err != nil {
return nil, err
}
log.Info("using GDI BitBlt for capture")
return gc, nil
}
for {
if !c.waitForClient() {
if cap != nil {
cap.close()
}
return
}
// No clients: release the capturer and wait.
if c.clients.Load() <= 0 {
if cap != nil {
cap.close()
cap = nil
}
continue
}
ok, desk := switchToInputDesktop()
if !ok {
desktopFails++
if desktopFails == 1 || desktopFails%100 == 0 {
log.Warnf("switchToInputDesktop failed (count=%d), no interactive desktop session?", desktopFails)
}
retryTimer.Reset(100 * time.Millisecond)
select {
case <-retryTimer.C:
case <-c.done:
return
}
continue
}
if desktopFails > 0 {
log.Infof("switchToInputDesktop recovered after %d failures, desktop=%q", desktopFails, desk)
desktopFails = 0
}
if desk != lastDesktop {
log.Infof("desktop changed: %q -> %q", lastDesktop, desk)
lastDesktop = desk
if cap != nil {
cap.close()
}
cap = nil
}
if cap == nil {
fc, err := createCapturer()
if err != nil {
log.Warnf("create capturer: %v", err)
retryTimer.Reset(500 * time.Millisecond)
select {
case <-retryTimer.C:
case <-c.done:
return
}
continue
}
cap = fc
w, h := screenSize()
c.mu.Lock()
c.w, c.h = w, h
c.mu.Unlock()
log.Infof("screen capturer ready: %dx%d", w, h)
}
img, err := cap.capture()
if err != nil {
log.Debugf("capture: %v", err)
cap.close()
cap = nil
retryTimer.Reset(100 * time.Millisecond)
select {
case <-retryTimer.C:
case <-c.done:
return
}
continue
}
c.mu.Lock()
c.frame = img
c.mu.Unlock()
select {
case <-frameTicker.C:
case <-c.done:
if cap != nil {
cap.close()
}
return
}
}
}

View File

@@ -0,0 +1,385 @@
//go:build (linux && !android) || freebsd
package server
import (
"fmt"
"image"
"os"
"os/exec"
"strings"
"sync"
"time"
log "github.com/sirupsen/logrus"
"github.com/jezek/xgb"
"github.com/jezek/xgb/xproto"
)
// X11Capturer captures the screen from an X11 display using the MIT-SHM extension.
type X11Capturer struct {
mu sync.Mutex
conn *xgb.Conn
screen *xproto.ScreenInfo
w, h int
shmID int
shmAddr []byte
shmSeg uint32 // shm.Seg
useSHM bool
}
// detectX11Display finds the active X11 display and sets DISPLAY/XAUTHORITY
// environment variables if needed. This is required when running as a system
// service where these vars aren't set.
func detectX11Display() {
if os.Getenv("DISPLAY") != "" {
return
}
// Try /proc first (Linux), then ps fallback (FreeBSD and others).
if detectX11FromProc() {
return
}
if detectX11FromSockets() {
return
}
}
// detectX11FromProc scans /proc/*/cmdline for Xorg (Linux).
func detectX11FromProc() bool {
entries, err := os.ReadDir("/proc")
if err != nil {
return false
}
for _, e := range entries {
if !e.IsDir() {
continue
}
cmdline, err := os.ReadFile("/proc/" + e.Name() + "/cmdline")
if err != nil {
continue
}
if display, auth := parseXorgArgs(splitCmdline(cmdline)); display != "" {
setDisplayEnv(display, auth)
return true
}
}
return false
}
// detectX11FromSockets checks /tmp/.X11-unix/ for X sockets and uses ps
// to find the auth file. Works on FreeBSD and other systems without /proc.
func detectX11FromSockets() bool {
entries, err := os.ReadDir("/tmp/.X11-unix")
if err != nil {
return false
}
// Find the lowest display number.
for _, e := range entries {
name := e.Name()
if len(name) < 2 || name[0] != 'X' {
continue
}
display := ":" + name[1:]
os.Setenv("DISPLAY", display)
log.Infof("auto-detected DISPLAY=%s (from socket)", display)
// Try to find -auth from ps output.
if auth := findXorgAuthFromPS(); auth != "" {
os.Setenv("XAUTHORITY", auth)
log.Infof("auto-detected XAUTHORITY=%s (from ps)", auth)
}
return true
}
return false
}
// findXorgAuthFromPS runs ps to find Xorg and extract its -auth argument.
func findXorgAuthFromPS() string {
out, err := exec.Command("ps", "auxww").Output()
if err != nil {
return ""
}
for _, line := range strings.Split(string(out), "\n") {
if !strings.Contains(line, "Xorg") && !strings.Contains(line, "/X ") {
continue
}
fields := strings.Fields(line)
for i, f := range fields {
if f == "-auth" && i+1 < len(fields) {
return fields[i+1]
}
}
}
return ""
}
func parseXorgArgs(args []string) (display, auth string) {
if len(args) == 0 {
return "", ""
}
base := args[0]
if !(base == "Xorg" || base == "X" || len(base) > 0 && base[len(base)-1] == 'X' ||
strings.Contains(base, "/Xorg") || strings.Contains(base, "/X")) {
return "", ""
}
for i, arg := range args[1:] {
if len(arg) > 0 && arg[0] == ':' {
display = arg
}
if arg == "-auth" && i+2 < len(args) {
auth = args[i+2]
}
}
return display, auth
}
func setDisplayEnv(display, auth string) {
os.Setenv("DISPLAY", display)
log.Infof("auto-detected DISPLAY=%s", display)
if auth != "" {
os.Setenv("XAUTHORITY", auth)
log.Infof("auto-detected XAUTHORITY=%s", auth)
}
}
func splitCmdline(data []byte) []string {
var args []string
for _, b := range splitNull(data) {
if len(b) > 0 {
args = append(args, string(b))
}
}
return args
}
func splitNull(data []byte) [][]byte {
var parts [][]byte
start := 0
for i, b := range data {
if b == 0 {
parts = append(parts, data[start:i])
start = i + 1
}
}
if start < len(data) {
parts = append(parts, data[start:])
}
return parts
}
// NewX11Capturer connects to the X11 display and sets up shared memory capture.
func NewX11Capturer(display string) (*X11Capturer, error) {
detectX11Display()
if display == "" {
display = os.Getenv("DISPLAY")
}
if display == "" {
return nil, fmt.Errorf("DISPLAY not set and no Xorg process found")
}
conn, err := xgb.NewConnDisplay(display)
if err != nil {
return nil, fmt.Errorf("connect to X11 display %s: %w", display, err)
}
setup := xproto.Setup(conn)
if len(setup.Roots) == 0 {
conn.Close()
return nil, fmt.Errorf("no X11 screens")
}
screen := setup.Roots[0]
c := &X11Capturer{
conn: conn,
screen: &screen,
w: int(screen.WidthInPixels),
h: int(screen.HeightInPixels),
}
if err := c.initSHM(); err != nil {
log.Debugf("X11 SHM not available, using slow GetImage: %v", err)
}
log.Infof("X11 capturer ready: %dx%d (display=%s, shm=%v)", c.w, c.h, display, c.useSHM)
return c, nil
}
// initSHM is implemented in capture_x11_shm_linux.go (requires SysV SHM).
// On platforms without SysV SHM (FreeBSD), a stub returns an error and
// the capturer falls back to GetImage.
// Width returns the screen width.
func (c *X11Capturer) Width() int { return c.w }
// Height returns the screen height.
func (c *X11Capturer) Height() int { return c.h }
// Capture returns the current screen as an RGBA image.
func (c *X11Capturer) Capture() (*image.RGBA, error) {
c.mu.Lock()
defer c.mu.Unlock()
if c.useSHM {
return c.captureSHM()
}
return c.captureGetImage()
}
// captureSHM is implemented in capture_x11_shm_linux.go.
func (c *X11Capturer) captureGetImage() (*image.RGBA, error) {
cookie := xproto.GetImage(c.conn, xproto.ImageFormatZPixmap,
xproto.Drawable(c.screen.Root),
0, 0, uint16(c.w), uint16(c.h), 0xFFFFFFFF)
reply, err := cookie.Reply()
if err != nil {
return nil, fmt.Errorf("GetImage: %w", err)
}
img := image.NewRGBA(image.Rect(0, 0, c.w, c.h))
data := reply.Data
n := c.w * c.h * 4
if len(data) < n {
return nil, fmt.Errorf("GetImage returned %d bytes, expected %d", len(data), n)
}
for i := 0; i < n; i += 4 {
img.Pix[i+0] = data[i+2] // R
img.Pix[i+1] = data[i+1] // G
img.Pix[i+2] = data[i+0] // B
img.Pix[i+3] = 0xff
}
return img, nil
}
// Close releases X11 resources.
func (c *X11Capturer) Close() {
c.closeSHM()
c.conn.Close()
}
// closeSHM is implemented in capture_x11_shm_linux.go.
// X11Poller wraps X11Capturer in a continuous capture loop, matching the
// DesktopCapturer pattern from Windows.
type X11Poller struct {
mu sync.Mutex
frame *image.RGBA
w, h int
display string
done chan struct{}
}
// NewX11Poller creates a capturer that continuously grabs the X11 display.
func NewX11Poller(display string) *X11Poller {
p := &X11Poller{
display: display,
done: make(chan struct{}),
}
go p.loop()
return p
}
// Close stops the capture loop.
func (p *X11Poller) Close() {
select {
case <-p.done:
default:
close(p.done)
}
}
// Width returns the screen width.
func (p *X11Poller) Width() int {
p.mu.Lock()
defer p.mu.Unlock()
return p.w
}
// Height returns the screen height.
func (p *X11Poller) Height() int {
p.mu.Lock()
defer p.mu.Unlock()
return p.h
}
// Capture returns the most recent frame.
func (p *X11Poller) Capture() (*image.RGBA, error) {
p.mu.Lock()
img := p.frame
p.mu.Unlock()
if img != nil {
return img, nil
}
return nil, fmt.Errorf("no frame available yet")
}
func (p *X11Poller) loop() {
var capturer *X11Capturer
var initFails int
defer func() {
if capturer != nil {
capturer.Close()
}
}()
for {
select {
case <-p.done:
return
default:
}
if capturer == nil {
var err error
capturer, err = NewX11Capturer(p.display)
if err != nil {
initFails++
if initFails <= maxCapturerRetries {
log.Debugf("X11 capturer: %v (attempt %d/%d)", err, initFails, maxCapturerRetries)
select {
case <-p.done:
return
case <-time.After(2 * time.Second):
}
continue
}
log.Warnf("X11 capturer unavailable after %d attempts, stopping poller", maxCapturerRetries)
return
}
initFails = 0
p.mu.Lock()
p.w, p.h = capturer.Width(), capturer.Height()
p.mu.Unlock()
}
img, err := capturer.Capture()
if err != nil {
log.Debugf("X11 capture: %v", err)
capturer.Close()
capturer = nil
select {
case <-p.done:
return
case <-time.After(500 * time.Millisecond):
}
continue
}
p.mu.Lock()
p.frame = img
p.mu.Unlock()
select {
case <-p.done:
return
case <-time.After(33 * time.Millisecond): // ~30 fps
}
}
}

View File

@@ -0,0 +1,78 @@
//go:build linux && !android
package server
import (
"fmt"
"image"
"github.com/jezek/xgb/shm"
"github.com/jezek/xgb/xproto"
"golang.org/x/sys/unix"
)
func (c *X11Capturer) initSHM() error {
if err := shm.Init(c.conn); err != nil {
return fmt.Errorf("init SHM extension: %w", err)
}
size := c.w * c.h * 4
id, err := unix.SysvShmGet(unix.IPC_PRIVATE, size, unix.IPC_CREAT|0600)
if err != nil {
return fmt.Errorf("shmget: %w", err)
}
addr, err := unix.SysvShmAttach(id, 0, 0)
if err != nil {
unix.SysvShmCtl(id, unix.IPC_RMID, nil)
return fmt.Errorf("shmat: %w", err)
}
unix.SysvShmCtl(id, unix.IPC_RMID, nil)
seg, err := shm.NewSegId(c.conn)
if err != nil {
unix.SysvShmDetach(addr)
return fmt.Errorf("new SHM seg: %w", err)
}
if err := shm.AttachChecked(c.conn, seg, uint32(id), false).Check(); err != nil {
unix.SysvShmDetach(addr)
return fmt.Errorf("SHM attach to X: %w", err)
}
c.shmID = id
c.shmAddr = addr
c.shmSeg = uint32(seg)
c.useSHM = true
return nil
}
func (c *X11Capturer) captureSHM() (*image.RGBA, error) {
cookie := shm.GetImage(c.conn, xproto.Drawable(c.screen.Root),
0, 0, uint16(c.w), uint16(c.h), 0xFFFFFFFF,
xproto.ImageFormatZPixmap, shm.Seg(c.shmSeg), 0)
_, err := cookie.Reply()
if err != nil {
return nil, fmt.Errorf("SHM GetImage: %w", err)
}
img := image.NewRGBA(image.Rect(0, 0, c.w, c.h))
n := c.w * c.h * 4
for i := 0; i < n; i += 4 {
img.Pix[i+0] = c.shmAddr[i+2] // R
img.Pix[i+1] = c.shmAddr[i+1] // G
img.Pix[i+2] = c.shmAddr[i+0] // B
img.Pix[i+3] = 0xff
}
return img, nil
}
func (c *X11Capturer) closeSHM() {
if c.useSHM {
shm.Detach(c.conn, shm.Seg(c.shmSeg))
unix.SysvShmDetach(c.shmAddr)
}
}

View File

@@ -0,0 +1,18 @@
//go:build freebsd
package server
import (
"fmt"
"image"
)
func (c *X11Capturer) initSHM() error {
return fmt.Errorf("SysV SHM not available on this platform")
}
func (c *X11Capturer) captureSHM() (*image.RGBA, error) {
return nil, fmt.Errorf("SHM capture not available on this platform")
}
func (c *X11Capturer) closeSHM() {}

View File

@@ -0,0 +1,403 @@
//go:build darwin && !ios
package server
import (
"fmt"
"os/exec"
"strings"
"sync"
"github.com/ebitengine/purego"
log "github.com/sirupsen/logrus"
)
// Core Graphics event constants.
const (
kCGEventSourceStateCombinedSessionState int32 = 0
kCGEventLeftMouseDown int32 = 1
kCGEventLeftMouseUp int32 = 2
kCGEventRightMouseDown int32 = 3
kCGEventRightMouseUp int32 = 4
kCGEventMouseMoved int32 = 5
kCGEventLeftMouseDragged int32 = 6
kCGEventRightMouseDragged int32 = 7
kCGEventKeyDown int32 = 10
kCGEventKeyUp int32 = 11
kCGEventOtherMouseDown int32 = 25
kCGEventOtherMouseUp int32 = 26
kCGMouseButtonLeft int32 = 0
kCGMouseButtonRight int32 = 1
kCGMouseButtonCenter int32 = 2
kCGHIDEventTap int32 = 0
)
var darwinInputOnce sync.Once
var (
cgEventSourceCreate func(int32) uintptr
cgEventCreateKeyboardEvent func(uintptr, uint16, bool) uintptr
// CGEventCreateMouseEvent takes CGPoint as two separate float64 args.
// purego can't handle array/struct types but individual float64s work.
cgEventCreateMouseEvent func(uintptr, int32, float64, float64, int32) uintptr
cgEventPost func(int32, uintptr)
// CGEventCreateScrollWheelEvent is variadic, call via SyscallN.
cgEventCreateScrollWheelEventAddr uintptr
darwinInputReady bool
darwinEventSource uintptr
)
func initDarwinInput() {
darwinInputOnce.Do(func() {
cg, err := purego.Dlopen("/System/Library/Frameworks/CoreGraphics.framework/CoreGraphics", purego.RTLD_NOW|purego.RTLD_GLOBAL)
if err != nil {
log.Debugf("load CoreGraphics for input: %v", err)
return
}
purego.RegisterLibFunc(&cgEventSourceCreate, cg, "CGEventSourceCreate")
purego.RegisterLibFunc(&cgEventCreateKeyboardEvent, cg, "CGEventCreateKeyboardEvent")
purego.RegisterLibFunc(&cgEventCreateMouseEvent, cg, "CGEventCreateMouseEvent")
purego.RegisterLibFunc(&cgEventPost, cg, "CGEventPost")
sym, err := purego.Dlsym(cg, "CGEventCreateScrollWheelEvent")
if err == nil {
cgEventCreateScrollWheelEventAddr = sym
}
darwinInputReady = true
})
}
func ensureEventSource() uintptr {
if darwinEventSource != 0 {
return darwinEventSource
}
darwinEventSource = cgEventSourceCreate(kCGEventSourceStateCombinedSessionState)
return darwinEventSource
}
// MacInputInjector injects keyboard and mouse events via Core Graphics.
type MacInputInjector struct {
lastButtons uint8
pbcopyPath string
pbpastePath string
}
// NewMacInputInjector creates a macOS input injector.
func NewMacInputInjector() (*MacInputInjector, error) {
initDarwinInput()
if !darwinInputReady {
return nil, fmt.Errorf("CoreGraphics not available for input injection")
}
checkMacPermissions()
m := &MacInputInjector{}
if path, err := exec.LookPath("pbcopy"); err == nil {
m.pbcopyPath = path
}
if path, err := exec.LookPath("pbpaste"); err == nil {
m.pbpastePath = path
}
if m.pbcopyPath == "" || m.pbpastePath == "" {
log.Debugf("clipboard tools not found (pbcopy=%q, pbpaste=%q)", m.pbcopyPath, m.pbpastePath)
}
log.Info("macOS input injector ready")
return m, nil
}
// checkMacPermissions logs warnings and triggers the Accessibility prompt.
// Screen Recording has no programmatic prompt, the user must grant it manually.
func checkMacPermissions() {
// Check Accessibility via osascript (triggers the system prompt dialog).
out, err := exec.Command("osascript", "-e",
`tell application "System Events" to return name of first process`).CombinedOutput()
if err != nil {
log.Warn("Accessibility permission not granted. Input injection will not work. " +
"Grant in System Settings > Privacy & Security > Accessibility.")
log.Debugf("accessibility check output: %s (%v)", strings.TrimSpace(string(out)), err)
}
log.Info("Screen Recording permission is required for screen capture. " +
"If the screen appears black, grant in System Settings > Privacy & Security > Screen Recording.")
}
// InjectKey simulates a key press or release.
func (m *MacInputInjector) InjectKey(keysym uint32, down bool) {
src := ensureEventSource()
if src == 0 {
return
}
keycode := keysymToMacKeycode(keysym)
if keycode == 0xFFFF {
return
}
event := cgEventCreateKeyboardEvent(src, keycode, down)
if event == 0 {
return
}
cgEventPost(kCGHIDEventTap, event)
cfRelease(event)
}
// InjectPointer simulates mouse movement and button events.
func (m *MacInputInjector) InjectPointer(buttonMask uint8, px, py, serverW, serverH int) {
if serverW == 0 || serverH == 0 {
return
}
src := ensureEventSource()
if src == 0 {
return
}
x := float64(px)
y := float64(py)
leftDown := buttonMask&0x01 != 0
rightDown := buttonMask&0x04 != 0
middleDown := buttonMask&0x02 != 0
scrollUp := buttonMask&0x08 != 0
scrollDown := buttonMask&0x10 != 0
wasLeft := m.lastButtons&0x01 != 0
wasRight := m.lastButtons&0x04 != 0
wasMiddle := m.lastButtons&0x02 != 0
if leftDown {
m.postMouse(src, kCGEventLeftMouseDragged, x, y, kCGMouseButtonLeft)
} else if rightDown {
m.postMouse(src, kCGEventRightMouseDragged, x, y, kCGMouseButtonRight)
} else {
m.postMouse(src, kCGEventMouseMoved, x, y, kCGMouseButtonLeft)
}
if leftDown && !wasLeft {
m.postMouse(src, kCGEventLeftMouseDown, x, y, kCGMouseButtonLeft)
} else if !leftDown && wasLeft {
m.postMouse(src, kCGEventLeftMouseUp, x, y, kCGMouseButtonLeft)
}
if rightDown && !wasRight {
m.postMouse(src, kCGEventRightMouseDown, x, y, kCGMouseButtonRight)
} else if !rightDown && wasRight {
m.postMouse(src, kCGEventRightMouseUp, x, y, kCGMouseButtonRight)
}
if middleDown && !wasMiddle {
m.postMouse(src, kCGEventOtherMouseDown, x, y, kCGMouseButtonCenter)
} else if !middleDown && wasMiddle {
m.postMouse(src, kCGEventOtherMouseUp, x, y, kCGMouseButtonCenter)
}
if scrollUp {
m.postScroll(src, 3)
}
if scrollDown {
m.postScroll(src, -3)
}
m.lastButtons = buttonMask
}
func (m *MacInputInjector) postMouse(src uintptr, eventType int32, x, y float64, button int32) {
if cgEventCreateMouseEvent == nil {
return
}
event := cgEventCreateMouseEvent(src, eventType, x, y, button)
if event == 0 {
return
}
cgEventPost(kCGHIDEventTap, event)
cfRelease(event)
}
func (m *MacInputInjector) postScroll(src uintptr, deltaY int32) {
if cgEventCreateScrollWheelEventAddr == 0 {
return
}
// CGEventCreateScrollWheelEvent(source, units, wheelCount, wheel1delta)
// units=0 (pixel), wheelCount=1, wheel1delta=deltaY
// Variadic C function: pass args as uintptr via SyscallN.
r1, _, _ := purego.SyscallN(cgEventCreateScrollWheelEventAddr,
src, 0, 1, uintptr(uint32(deltaY)))
if r1 == 0 {
return
}
cgEventPost(kCGHIDEventTap, r1)
cfRelease(r1)
}
// SetClipboard sets the macOS clipboard using pbcopy.
func (m *MacInputInjector) SetClipboard(text string) {
if m.pbcopyPath == "" {
return
}
cmd := exec.Command(m.pbcopyPath)
cmd.Stdin = strings.NewReader(text)
if err := cmd.Run(); err != nil {
log.Tracef("set clipboard via pbcopy: %v", err)
}
}
// GetClipboard reads the macOS clipboard using pbpaste.
func (m *MacInputInjector) GetClipboard() string {
if m.pbpastePath == "" {
return ""
}
out, err := exec.Command(m.pbpastePath).Output()
if err != nil {
log.Tracef("get clipboard via pbpaste: %v", err)
return ""
}
return string(out)
}
// Close is a no-op on macOS.
func (m *MacInputInjector) Close() {}
func keysymToMacKeycode(keysym uint32) uint16 {
if keysym >= 0x61 && keysym <= 0x7a {
return asciiToMacKey[keysym-0x61]
}
if keysym >= 0x41 && keysym <= 0x5a {
return asciiToMacKey[keysym-0x41]
}
if keysym >= 0x30 && keysym <= 0x39 {
return digitToMacKey[keysym-0x30]
}
if code, ok := specialKeyMap[keysym]; ok {
return code
}
return 0xFFFF
}
var asciiToMacKey = [26]uint16{
0x00, 0x0B, 0x08, 0x02, 0x0E, 0x03, 0x05, 0x04,
0x22, 0x26, 0x28, 0x25, 0x2E, 0x2D, 0x1F, 0x23,
0x0C, 0x0F, 0x01, 0x11, 0x20, 0x09, 0x0D, 0x07,
0x10, 0x06,
}
var digitToMacKey = [10]uint16{
0x1D, 0x12, 0x13, 0x14, 0x15, 0x17, 0x16, 0x1A, 0x1C, 0x19,
}
var specialKeyMap = map[uint32]uint16{
// Whitespace and editing
0x0020: 0x31, // space
0xff08: 0x33, // BackSpace
0xff09: 0x30, // Tab
0xff0d: 0x24, // Return
0xff1b: 0x35, // Escape
0xffff: 0x75, // Delete (forward)
// Navigation
0xff50: 0x73, // Home
0xff51: 0x7B, // Left
0xff52: 0x7E, // Up
0xff53: 0x7C, // Right
0xff54: 0x7D, // Down
0xff55: 0x74, // Page_Up
0xff56: 0x79, // Page_Down
0xff57: 0x77, // End
0xff63: 0x72, // Insert (Help on Mac)
// Modifiers
0xffe1: 0x38, // Shift_L
0xffe2: 0x3C, // Shift_R
0xffe3: 0x3B, // Control_L
0xffe4: 0x3E, // Control_R
0xffe5: 0x39, // Caps_Lock
0xffe9: 0x3A, // Alt_L (Option)
0xffea: 0x3D, // Alt_R (Option)
0xffe7: 0x37, // Meta_L (Command)
0xffe8: 0x36, // Meta_R (Command)
0xffeb: 0x37, // Super_L (Command) - noVNC sends this
0xffec: 0x36, // Super_R (Command)
// Mode_switch / ISO_Level3_Shift (sent by noVNC for macOS Option remap)
0xff7e: 0x3A, // Mode_switch -> Option
0xfe03: 0x3D, // ISO_Level3_Shift -> Right Option
// Function keys
0xffbe: 0x7A, // F1
0xffbf: 0x78, // F2
0xffc0: 0x63, // F3
0xffc1: 0x76, // F4
0xffc2: 0x60, // F5
0xffc3: 0x61, // F6
0xffc4: 0x62, // F7
0xffc5: 0x64, // F8
0xffc6: 0x65, // F9
0xffc7: 0x6D, // F10
0xffc8: 0x67, // F11
0xffc9: 0x6F, // F12
0xffca: 0x69, // F13
0xffcb: 0x6B, // F14
0xffcc: 0x71, // F15
0xffcd: 0x6A, // F16
0xffce: 0x40, // F17
0xffcf: 0x4F, // F18
0xffd0: 0x50, // F19
0xffd1: 0x5A, // F20
// Punctuation (US keyboard layout, keysym = ASCII code)
0x002d: 0x1B, // minus -
0x003d: 0x18, // equal =
0x005b: 0x21, // bracketleft [
0x005d: 0x1E, // bracketright ]
0x005c: 0x2A, // backslash
0x003b: 0x29, // semicolon ;
0x0027: 0x27, // apostrophe '
0x0060: 0x32, // grave `
0x002c: 0x2B, // comma ,
0x002e: 0x2F, // period .
0x002f: 0x2C, // slash /
// Shifted punctuation (noVNC sends these as separate keysyms)
0x005f: 0x1B, // underscore _ (shift+minus)
0x002b: 0x18, // plus + (shift+equal)
0x007b: 0x21, // braceleft { (shift+[)
0x007d: 0x1E, // braceright } (shift+])
0x007c: 0x2A, // bar | (shift+\)
0x003a: 0x29, // colon : (shift+;)
0x0022: 0x27, // quotedbl " (shift+')
0x007e: 0x32, // tilde ~ (shift+`)
0x003c: 0x2B, // less < (shift+,)
0x003e: 0x2F, // greater > (shift+.)
0x003f: 0x2C, // question ? (shift+/)
0x0021: 0x12, // exclam ! (shift+1)
0x0040: 0x13, // at @ (shift+2)
0x0023: 0x14, // numbersign # (shift+3)
0x0024: 0x15, // dollar $ (shift+4)
0x0025: 0x17, // percent % (shift+5)
0x005e: 0x16, // asciicircum ^ (shift+6)
0x0026: 0x1A, // ampersand & (shift+7)
0x002a: 0x1C, // asterisk * (shift+8)
0x0028: 0x19, // parenleft ( (shift+9)
0x0029: 0x1D, // parenright ) (shift+0)
// Numpad
0xffb0: 0x52, // KP_0
0xffb1: 0x53, // KP_1
0xffb2: 0x54, // KP_2
0xffb3: 0x55, // KP_3
0xffb4: 0x56, // KP_4
0xffb5: 0x57, // KP_5
0xffb6: 0x58, // KP_6
0xffb7: 0x59, // KP_7
0xffb8: 0x5B, // KP_8
0xffb9: 0x5C, // KP_9
0xffae: 0x41, // KP_Decimal
0xffaa: 0x43, // KP_Multiply
0xffab: 0x45, // KP_Add
0xffad: 0x4E, // KP_Subtract
0xffaf: 0x4B, // KP_Divide
0xff8d: 0x4C, // KP_Enter
0xffbd: 0x51, // KP_Equal
}
var _ InputInjector = (*MacInputInjector)(nil)

Some files were not shown because too many files have changed in this diff Show More