Compare commits

..

3 Commits

Author SHA1 Message Date
Viktor Liu
dfe1bba287 Add embedded VNC server with JWT auth, DXGI capture, and dashboard integration 2026-04-14 13:54:31 +02:00
Zoltan Papp
13539543af [client] Fix/grpc retry (#5750)
* [client] Fix flow client Receive retry loop not stopping after Close

Use backoff.Permanent for canceled gRPC errors so Receive returns
immediately instead of retrying until context deadline when the
connection is already closed. Add TestNewClient_PermanentClose to
verify the behavior.

The connectivity.Shutdown check was meaningless because when the connection is
shut down, c.realClient.Events(ctx, grpc.WaitForReady(true)) on the nex line
already fails with codes.Canceled — which is now handled as a permanent error.
The explicit state check was just duplicating what gRPC already reports
through its normal error path.

* [client] remove WaitForReady from stream open call

grpc.WaitForReady(true) parks the RPC call internally until the
connection reaches READY, only unblocking on ctx cancellation.
This means the external backoff.Retry loop in Receive() never gets
control back during a connection outage — it cannot tick, log, or
apply its retry intervals while WaitForReady is blocking.

Removing it restores fail-fast behaviour: Events() returns immediately
with codes.Unavailable when the connection is not ready, which is
exactly what the backoff loop expects. The backoff becomes the single
authority over retry timing and cadence, as originally intended.

* [client] Add connection recreation and improve flow client error handling

Store gRPC dial options on the client to enable connection recreation
on Internal errors (RST_STREAM/PROTOCOL_ERROR). Treat Unauthenticated,
PermissionDenied, and Unimplemented as permanent failures. Unify mutex
usage and add reconnection logging for better observability.

* [client] Remove Unauthenticated, PermissionDenied, and Unimplemented from permanent error handling

* [client] Fix error handling in Receive to properly re-establish stream and improve reconnection messaging

* Fix test

* [client] Add graceful shutdown handling and test for concurrent Close during Receive

Prevent reconnection attempts after client closure by tracking a `closed` flag. Use `backoff.Permanent` for errors caused by operations on a closed client. Add a test to ensure `Close` does not block when `Receive` is actively running.

* [client] Fix connection swap to properly close old gRPC connection

Close the old `gRPC.ClientConn` after successfully swapping to a new connection during reconnection.

* [client] Reset backoff

* [client] Ensure stream closure on error during initialization

* [client] Add test for handling server-side stream closure and reconnection

Introduce `TestReceive_ServerClosesStream` to verify the client's ability to recover and process acknowledgments after the server closes the stream. Enhance test server with a controlled stream closure mechanism.

* [client] Add protocol error simulation and enhance reconnection test

Introduce `connTrackListener` to simulate HTTP/2 RST_STREAM with PROTOCOL_ERROR for testing. Refactor and rename `TestReceive_ServerClosesStream` to `TestReceive_ProtocolErrorStreamReconnect` to verify client recovery on protocol errors.

* [client] Update Close error message in test for clarity

* [client] Fine-tune the tests

* [client] Adjust connection tracking in reconnection test

* [client] Wait for Events handler to exit in RST_STREAM reconnection test

Ensure the old `Events` handler exits fully before proceeding in the reconnection test to avoid dropped acknowledgments on a broken stream. Add a `handlerDone` channel to synchronize handler exits.

* [client] Prevent panic on nil connection during Close

* [client] Refactor connection handling to use explicit target tracking

Introduce `target` field to store the gRPC connection target directly, simplifying reconnections and ensuring consistent connection reuse logic.

* [client] Rename `isCancellation` to `isContextDone` and extend handling for `DeadlineExceeded`

Refactor error handling to include `DeadlineExceeded` scenarios alongside `Canceled`. Update related condition checks for consistency.

* [client] Add connection generation tracking to prevent stale reconnections

Introduce `connGen` to track connection generations and ensure that stale `recreateConnection` calls do not override newer connections. Update stream establishment and reconnection logic to incorporate generation validation.

* [client] Add backoff reset condition to prevent short-lived retry cycles

Refine backoff reset logic to ensure it only occurs for sufficiently long-lived stream connections, avoiding interference with `MaxElapsedTime`.

* [client] Introduce `minHealthyDuration` to refine backoff reset logic

Add `minHealthyDuration` constant to ensure stream retries only reset the backoff timer if the stream survives beyond a minimum duration. Prevents unhealthy, short-lived streams from interfering with `MaxElapsedTime`.

* [client] IPv6 friendly connection

parsedURL.Hostname() strips IPv6 brackets. For http://[::1]:443, this turns it into ::1:443, which is not a valid host:port target for gRPC. Additionally, fmt.Sprintf("%s:%s", hostname, port) produces a trailing colon when the URL has no explicit port—http://example.com becomes example.com:. Both cases break the initial dial and reconnect paths. Use parsedURL.Host directly instead.

* [client] Add `handlerStarted` channel to synchronize stream establishment in tests

Introduce `handlerStarted` channel in the test server to signal when the server-side handler begins, ensuring robust synchronization between client and server during stream establishment. Update relevant test cases to wait for this signal before proceeding.

* [client] Replace `receivedAcks` map with atomic counter and improve stream establishment sync in tests

Refactor acknowledgment tracking in tests to use an `atomic.Int32` counter instead of a map. Replace fixed sleep with robust synchronization by waiting on `handlerStarted` signal for stream establishment.

* [client] Extract `handleReceiveError` to simplify receive logic

Refactor error handling in `receive` to a dedicated `handleReceiveError` method. Streamlines the main logic and isolates error recovery, including backoff reset and connection recreation.

* [client] recreate gRPC ClientConn on every retry to prevent dual backoff

The flow client had two competing retry loops: our custom exponential
backoff and gRPC's internal subchannel reconnection. When establishStream
failed, the same ClientConn was reused, allowing gRPC's internal backoff
state to accumulate and control dial timing independently.

Changes:
- Consolidate error handling into handleRetryableError, which now
 handles context cancellation, permanent errors, backoff reset,
 and connection recreation in a single path
- Call recreateConnection on every retryable error so each retry
 gets a fresh ClientConn with no internal backoff state
- Remove connGen tracking since Receive is sequential and protected
 by a new receiving guard against concurrent calls
- Reduce RandomizationFactor from 1 to 0.5 to avoid near-zero
 backoff intervals
2026-04-13 10:42:24 +02:00
Zoltan Papp
7483fec048 Fix Android internet blackhole caused by stale route re-injection on TUN rebuild (#5865)
extraInitialRoutes() was meant to preserve only the fake IP route
(240.0.0.0/8) across TUN rebuilds, but it re-injected any initial
route missing from the current set. When the management server
advertised exit node routes (0.0.0.0/0) that were later filtered
by the route selector, extraInitialRoutes() re-added them, causing
the Android VPN to capture all traffic with no peer to handle it.

Store the fake IP route explicitly and append only that in notify(),
removing the overly broad initial route diffing.
2026-04-13 09:38:38 +02:00
132 changed files with 9203 additions and 12053 deletions

2
.gitignore vendored
View File

@@ -33,5 +33,3 @@ infrastructure_files/setup-*.env
vendor/
/netbird
client/netbird-electron/
management/server/types/testdata/comparison/
management/server/types/testdata/*.json

View File

@@ -150,6 +150,7 @@ func init() {
rootCmd.AddCommand(logoutCmd)
rootCmd.AddCommand(versionCmd)
rootCmd.AddCommand(sshCmd)
rootCmd.AddCommand(vncCmd)
rootCmd.AddCommand(networksCMD)
rootCmd.AddCommand(forwardingRulesCmd)
rootCmd.AddCommand(debugCmd)

View File

@@ -36,7 +36,10 @@ const (
enableSSHLocalPortForwardFlag = "enable-ssh-local-port-forwarding"
enableSSHRemotePortForwardFlag = "enable-ssh-remote-port-forwarding"
disableSSHAuthFlag = "disable-ssh-auth"
sshJWTCacheTTLFlag = "ssh-jwt-cache-ttl"
jwtCacheTTLFlag = "jwt-cache-ttl"
// Alias for backward compatibility.
sshJWTCacheTTLFlag = "ssh-jwt-cache-ttl"
)
var (
@@ -61,7 +64,7 @@ var (
enableSSHLocalPortForward bool
enableSSHRemotePortForward bool
disableSSHAuth bool
sshJWTCacheTTL int
jwtCacheTTL int
)
func init() {
@@ -71,7 +74,9 @@ func init() {
upCmd.PersistentFlags().BoolVar(&enableSSHLocalPortForward, enableSSHLocalPortForwardFlag, false, "Enable local port forwarding for SSH server")
upCmd.PersistentFlags().BoolVar(&enableSSHRemotePortForward, enableSSHRemotePortForwardFlag, false, "Enable remote port forwarding for SSH server")
upCmd.PersistentFlags().BoolVar(&disableSSHAuth, disableSSHAuthFlag, false, "Disable SSH authentication")
upCmd.PersistentFlags().IntVar(&sshJWTCacheTTL, sshJWTCacheTTLFlag, 0, "SSH JWT token cache TTL in seconds (0=disabled)")
upCmd.PersistentFlags().IntVar(&jwtCacheTTL, jwtCacheTTLFlag, 0, "JWT token cache TTL in seconds (0=disabled)")
upCmd.PersistentFlags().IntVar(&jwtCacheTTL, sshJWTCacheTTLFlag, 0, "JWT token cache TTL in seconds (alias for --jwt-cache-ttl)")
_ = upCmd.PersistentFlags().MarkDeprecated(sshJWTCacheTTLFlag, "use --jwt-cache-ttl instead")
sshCmd.PersistentFlags().IntVarP(&port, "port", "p", sshserver.DefaultSSHPort, "Remote SSH port")
sshCmd.PersistentFlags().StringVarP(&username, "user", "u", "", sshUsernameDesc)

View File

@@ -356,6 +356,9 @@ func setupSetConfigReq(customDNSAddressConverted []byte, cmd *cobra.Command, pro
if cmd.Flag(serverSSHAllowedFlag).Changed {
req.ServerSSHAllowed = &serverSSHAllowed
}
if cmd.Flag(serverVNCAllowedFlag).Changed {
req.ServerVNCAllowed = &serverVNCAllowed
}
if cmd.Flag(enableSSHRootFlag).Changed {
req.EnableSSHRoot = &enableSSHRoot
}
@@ -371,9 +374,12 @@ func setupSetConfigReq(customDNSAddressConverted []byte, cmd *cobra.Command, pro
if cmd.Flag(disableSSHAuthFlag).Changed {
req.DisableSSHAuth = &disableSSHAuth
}
if cmd.Flag(sshJWTCacheTTLFlag).Changed {
sshJWTCacheTTL32 := int32(sshJWTCacheTTL)
req.SshJWTCacheTTL = &sshJWTCacheTTL32
if cmd.Flag(disableVNCAuthFlag).Changed {
req.DisableVNCAuth = &disableVNCAuth
}
if cmd.Flag(jwtCacheTTLFlag).Changed || cmd.Flag(sshJWTCacheTTLFlag).Changed {
jwtCacheTTL32 := int32(jwtCacheTTL)
req.SshJWTCacheTTL = &jwtCacheTTL32
}
if cmd.Flag(interfaceNameFlag).Changed {
if err := parseInterfaceName(interfaceName); err != nil {
@@ -458,6 +464,9 @@ func setupConfig(customDNSAddressConverted []byte, cmd *cobra.Command, configFil
if cmd.Flag(serverSSHAllowedFlag).Changed {
ic.ServerSSHAllowed = &serverSSHAllowed
}
if cmd.Flag(serverVNCAllowedFlag).Changed {
ic.ServerVNCAllowed = &serverVNCAllowed
}
if cmd.Flag(enableSSHRootFlag).Changed {
ic.EnableSSHRoot = &enableSSHRoot
@@ -479,8 +488,12 @@ func setupConfig(customDNSAddressConverted []byte, cmd *cobra.Command, configFil
ic.DisableSSHAuth = &disableSSHAuth
}
if cmd.Flag(sshJWTCacheTTLFlag).Changed {
ic.SSHJWTCacheTTL = &sshJWTCacheTTL
if cmd.Flag(disableVNCAuthFlag).Changed {
ic.DisableVNCAuth = &disableVNCAuth
}
if cmd.Flag(jwtCacheTTLFlag).Changed || cmd.Flag(sshJWTCacheTTLFlag).Changed {
ic.SSHJWTCacheTTL = &jwtCacheTTL
}
if cmd.Flag(interfaceNameFlag).Changed {
@@ -582,6 +595,9 @@ func setupLoginRequest(providedSetupKey string, customDNSAddressConverted []byte
if cmd.Flag(serverSSHAllowedFlag).Changed {
loginRequest.ServerSSHAllowed = &serverSSHAllowed
}
if cmd.Flag(serverVNCAllowedFlag).Changed {
loginRequest.ServerVNCAllowed = &serverVNCAllowed
}
if cmd.Flag(enableSSHRootFlag).Changed {
loginRequest.EnableSSHRoot = &enableSSHRoot
@@ -603,9 +619,13 @@ func setupLoginRequest(providedSetupKey string, customDNSAddressConverted []byte
loginRequest.DisableSSHAuth = &disableSSHAuth
}
if cmd.Flag(sshJWTCacheTTLFlag).Changed {
sshJWTCacheTTL32 := int32(sshJWTCacheTTL)
loginRequest.SshJWTCacheTTL = &sshJWTCacheTTL32
if cmd.Flag(disableVNCAuthFlag).Changed {
loginRequest.DisableVNCAuth = &disableVNCAuth
}
if cmd.Flag(jwtCacheTTLFlag).Changed || cmd.Flag(sshJWTCacheTTLFlag).Changed {
jwtCacheTTL32 := int32(jwtCacheTTL)
loginRequest.SshJWTCacheTTL = &jwtCacheTTL32
}
if cmd.Flag(disableAutoConnectFlag).Changed {

271
client/cmd/vnc.go Normal file
View File

@@ -0,0 +1,271 @@
package cmd
import (
"context"
"encoding/binary"
"fmt"
"io"
"net"
"os"
"os/signal"
"os/user"
"strings"
"syscall"
"time"
log "github.com/sirupsen/logrus"
"github.com/spf13/cobra"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure"
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/profilemanager"
"github.com/netbirdio/netbird/client/proto"
nbssh "github.com/netbirdio/netbird/client/ssh"
"github.com/netbirdio/netbird/util"
)
var (
vncUsername string
vncHost string
vncMode string
vncListen string
vncNoBrowser bool
vncNoCache bool
)
func init() {
vncCmd.PersistentFlags().StringVar(&vncUsername, "user", "", "OS username for session mode")
vncCmd.PersistentFlags().StringVar(&vncMode, "mode", "attach", "Connection mode: attach (view current display) or session (virtual desktop)")
vncCmd.PersistentFlags().StringVar(&vncListen, "listen", "", "Start local VNC proxy on this address (e.g., :5900) for external VNC viewers")
vncCmd.PersistentFlags().BoolVar(&vncNoBrowser, noBrowserFlag, false, noBrowserDesc)
vncCmd.PersistentFlags().BoolVar(&vncNoCache, "no-cache", false, "Skip cached JWT token and force fresh authentication")
}
var vncCmd = &cobra.Command{
Use: "vnc [flags] [user@]host",
Short: "Connect to a NetBird peer via VNC",
Long: `Connect to a NetBird peer using VNC with JWT-based authentication.
The target peer must have the VNC server enabled.
Two modes are available:
- attach: view the current physical display (remote support)
- session: start a virtual desktop as the specified user (passwordless login)
Use --listen to start a local proxy for external VNC viewers:
netbird vnc --listen :5900 peer-hostname
vncviewer localhost:5900
Examples:
netbird vnc peer-hostname
netbird vnc --mode session --user alice peer-hostname
netbird vnc --listen :5900 peer-hostname`,
Args: cobra.MinimumNArgs(1),
RunE: vncFn,
}
func vncFn(cmd *cobra.Command, args []string) error {
SetFlagsFromEnvVars(rootCmd)
SetFlagsFromEnvVars(cmd)
cmd.SetOut(cmd.OutOrStdout())
logOutput := "console"
if firstLogFile := util.FindFirstLogPath(logFiles); firstLogFile != "" && firstLogFile != defaultLogFile {
logOutput = firstLogFile
}
if err := util.InitLog(logLevel, logOutput); err != nil {
return fmt.Errorf("init log: %w", err)
}
if err := parseVNCHostArg(args[0]); err != nil {
return err
}
ctx := internal.CtxInitState(cmd.Context())
sig := make(chan os.Signal, 1)
signal.Notify(sig, syscall.SIGTERM, syscall.SIGINT)
vncCtx, cancel := context.WithCancel(ctx)
errCh := make(chan error, 1)
go func() {
if err := runVNC(vncCtx, cmd); err != nil {
errCh <- err
}
cancel()
}()
select {
case <-sig:
cancel()
<-vncCtx.Done()
return nil
case err := <-errCh:
return err
case <-vncCtx.Done():
}
return nil
}
func parseVNCHostArg(arg string) error {
if strings.Contains(arg, "@") {
parts := strings.SplitN(arg, "@", 2)
if len(parts) != 2 || parts[0] == "" || parts[1] == "" {
return fmt.Errorf("invalid user@host format")
}
if vncUsername == "" {
vncUsername = parts[0]
}
vncHost = parts[1]
if vncMode == "attach" {
vncMode = "session"
}
} else {
vncHost = arg
}
if vncMode == "session" && vncUsername == "" {
if sudoUser := os.Getenv("SUDO_USER"); sudoUser != "" {
vncUsername = sudoUser
} else if currentUser, err := user.Current(); err == nil {
vncUsername = currentUser.Username
}
}
return nil
}
func runVNC(ctx context.Context, cmd *cobra.Command) error {
grpcAddr := strings.TrimPrefix(daemonAddr, "tcp://")
grpcConn, err := grpc.NewClient(grpcAddr, grpc.WithTransportCredentials(insecure.NewCredentials()))
if err != nil {
return fmt.Errorf("connect to daemon: %w", err)
}
defer func() { _ = grpcConn.Close() }()
daemonClient := proto.NewDaemonServiceClient(grpcConn)
if vncMode == "session" {
cmd.Printf("Connecting to %s@%s [session mode]...\n", vncUsername, vncHost)
} else {
cmd.Printf("Connecting to %s [attach mode]...\n", vncHost)
}
// Obtain JWT token. If the daemon has no SSO configured, proceed without one
// (the server will accept unauthenticated connections if --disable-vnc-auth is set).
var jwtToken string
hint := profilemanager.GetLoginHint()
var browserOpener func(string) error
if !vncNoBrowser {
browserOpener = util.OpenBrowser
}
token, err := nbssh.RequestJWTToken(ctx, daemonClient, nil, cmd.ErrOrStderr(), !vncNoCache, hint, browserOpener)
if err != nil {
log.Debugf("JWT authentication unavailable, connecting without token: %v", err)
} else {
jwtToken = token
log.Debug("JWT authentication successful")
}
// Connect to the VNC server on the standard port (5900). The peer's firewall
// DNATs 5900 -> 25900 (internal), so both ports work on the overlay network.
vncAddr := net.JoinHostPort(vncHost, "5900")
vncConn, err := net.DialTimeout("tcp", vncAddr, vncDialTimeout)
if err != nil {
return fmt.Errorf("connect to VNC at %s: %w", vncAddr, err)
}
defer vncConn.Close()
// Send session header with mode, username, and JWT.
if err := sendVNCHeader(vncConn, vncMode, vncUsername, jwtToken); err != nil {
return fmt.Errorf("send VNC header: %w", err)
}
cmd.Printf("VNC connected to %s\n", vncHost)
if vncListen != "" {
return runVNCLocalProxy(ctx, cmd, vncConn)
}
// No --listen flag: inform the user they need to use --listen for external viewers.
cmd.Printf("VNC tunnel established. Use --listen :5900 to proxy for local VNC viewers.\n")
cmd.Printf("Press Ctrl+C to disconnect.\n")
<-ctx.Done()
return nil
}
const vncDialTimeout = 15 * time.Second
// sendVNCHeader writes the NetBird VNC session header.
func sendVNCHeader(conn net.Conn, mode, username, jwt string) error {
var modeByte byte
if mode == "session" {
modeByte = 1
}
usernameBytes := []byte(username)
jwtBytes := []byte(jwt)
hdr := make([]byte, 3+len(usernameBytes)+2+len(jwtBytes))
hdr[0] = modeByte
binary.BigEndian.PutUint16(hdr[1:3], uint16(len(usernameBytes)))
off := 3
copy(hdr[off:], usernameBytes)
off += len(usernameBytes)
binary.BigEndian.PutUint16(hdr[off:off+2], uint16(len(jwtBytes)))
off += 2
copy(hdr[off:], jwtBytes)
_, err := conn.Write(hdr)
return err
}
// runVNCLocalProxy listens on the given address and proxies incoming
// connections to the already-established VNC tunnel.
func runVNCLocalProxy(ctx context.Context, cmd *cobra.Command, vncConn net.Conn) error {
listener, err := net.Listen("tcp", vncListen)
if err != nil {
return fmt.Errorf("listen on %s: %w", vncListen, err)
}
defer listener.Close()
cmd.Printf("VNC proxy listening on %s - connect with your VNC viewer\n", listener.Addr())
cmd.Printf("Press Ctrl+C to stop.\n")
go func() {
<-ctx.Done()
listener.Close()
}()
// Accept a single viewer connection. VNC is single-session: the RFB
// handshake completes on vncConn for the first viewer, so subsequent
// viewers would get a mid-stream connection. The loop handles transient
// accept errors until a valid connection arrives.
for {
clientConn, err := listener.Accept()
if err != nil {
select {
case <-ctx.Done():
return nil
default:
}
log.Debugf("accept VNC proxy client: %v", err)
continue
}
cmd.Printf("VNC viewer connected from %s\n", clientConn.RemoteAddr())
// Bidirectional copy.
done := make(chan struct{})
go func() {
io.Copy(vncConn, clientConn)
close(done)
}()
io.Copy(clientConn, vncConn)
<-done
clientConn.Close()
cmd.Printf("VNC viewer disconnected\n")
return nil
}
}

62
client/cmd/vnc_agent.go Normal file
View File

@@ -0,0 +1,62 @@
//go:build windows
package cmd
import (
"net/netip"
"os"
log "github.com/sirupsen/logrus"
"github.com/spf13/cobra"
vncserver "github.com/netbirdio/netbird/client/vnc/server"
)
var vncAgentPort string
func init() {
vncAgentCmd.Flags().StringVar(&vncAgentPort, "port", "15900", "Port for the VNC agent to listen on")
rootCmd.AddCommand(vncAgentCmd)
}
// vncAgentCmd runs a VNC server in the current user session, listening on
// localhost. It is spawned by the NetBird service (Session 0) via
// CreateProcessAsUser into the interactive console session.
var vncAgentCmd = &cobra.Command{
Use: "vnc-agent",
Short: "Run VNC capture agent (internal, spawned by service)",
Hidden: true,
RunE: func(cmd *cobra.Command, args []string) error {
// Agent's stderr is piped to the service which relogs it.
// Use JSON format with caller info for structured parsing.
log.SetReportCaller(true)
log.SetFormatter(&log.JSONFormatter{})
log.SetOutput(os.Stderr)
sessionID := vncserver.GetCurrentSessionID()
log.Infof("VNC agent starting on 127.0.0.1:%s (session %d)", vncAgentPort, sessionID)
capturer := vncserver.NewDesktopCapturer()
injector := vncserver.NewWindowsInputInjector()
srv := vncserver.New(capturer, injector, "")
// Auth is handled by the service. The agent verifies a token on each
// connection to ensure only the service process can connect.
// The token is passed via environment variable to avoid exposing it
// in the process command line (visible via tasklist/wmic).
srv.SetDisableAuth(true)
srv.SetAgentToken(os.Getenv("NB_VNC_AGENT_TOKEN"))
port, err := netip.ParseAddrPort("127.0.0.1:" + vncAgentPort)
if err != nil {
return err
}
loopback := netip.PrefixFrom(netip.AddrFrom4([4]byte{127, 0, 0, 0}), 8)
if err := srv.Start(cmd.Context(), port, loopback); err != nil {
return err
}
<-cmd.Context().Done()
return srv.Stop()
},
}

16
client/cmd/vnc_flags.go Normal file
View File

@@ -0,0 +1,16 @@
package cmd
const (
serverVNCAllowedFlag = "allow-server-vnc"
disableVNCAuthFlag = "disable-vnc-auth"
)
var (
serverVNCAllowed bool
disableVNCAuth bool
)
func init() {
upCmd.PersistentFlags().BoolVar(&serverVNCAllowed, serverVNCAllowedFlag, false, "Allow embedded VNC server on peer")
upCmd.PersistentFlags().BoolVar(&disableVNCAuth, disableVNCAuthFlag, false, "Disable JWT authentication for VNC")
}

View File

@@ -364,28 +364,6 @@ func (m *Manager) SetupEBPFProxyNoTrack(proxyPort, wgPort uint16) error {
return nil
}
// AddTProxyRule adds TPROXY redirect rules for the transparent proxy.
func (m *Manager) AddTProxyRule(ruleID string, sources []netip.Prefix, dstPorts []uint16, redirectPort uint16) error {
m.mutex.Lock()
defer m.mutex.Unlock()
return m.router.AddTProxyRule(ruleID, sources, dstPorts, redirectPort)
}
// RemoveTProxyRule removes TPROXY redirect rules by ID.
func (m *Manager) RemoveTProxyRule(ruleID string) error {
m.mutex.Lock()
defer m.mutex.Unlock()
return m.router.RemoveTProxyRule(ruleID)
}
// AddUDPInspectionHook is a no-op for iptables (kernel-mode firewall has no userspace packet hooks).
func (m *Manager) AddUDPInspectionHook(_ uint16, _ func([]byte) bool) string { return "" }
// RemoveUDPInspectionHook is a no-op for iptables.
func (m *Manager) RemoveUDPInspectionHook(_ string) {}
func (m *Manager) initNoTrackChain() error {
if err := m.cleanupNoTrackChain(); err != nil {
log.Debugf("cleanup notrack chain: %v", err)

View File

@@ -89,8 +89,6 @@ type router struct {
stateManager *statemanager.Manager
ipFwdState *ipfwdstate.IPForwardingState
tproxyRules []tproxyRuleEntry
}
func newRouter(iptablesClient *iptables.IPTables, wgIface iFaceMapper, mtu uint16) (*router, error) {
@@ -1111,92 +1109,3 @@ func (r *router) addPrefixToIPSet(name string, prefix netip.Prefix) error {
func (r *router) destroyIPSet(name string) error {
return ipset.Destroy(name)
}
// AddTProxyRule adds iptables nat PREROUTING REDIRECT rules for transparent proxy interception.
// Traffic from sources on dstPorts arriving on the WG interface is redirected
// to the transparent proxy listener on redirectPort.
func (r *router) AddTProxyRule(ruleID string, sources []netip.Prefix, dstPorts []uint16, redirectPort uint16) error {
portStr := fmt.Sprintf("%d", redirectPort)
for _, proto := range []string{"tcp", "udp"} {
srcSpecs := r.buildSourceSpecs(sources)
for _, srcSpec := range srcSpecs {
if len(dstPorts) == 0 {
rule := append(srcSpec,
"-i", r.wgIface.Name(),
"-p", proto,
"-j", "REDIRECT",
"--to-ports", portStr,
)
if err := r.iptablesClient.AppendUnique(tableNat, chainRTRDR, rule...); err != nil {
return fmt.Errorf("add redirect rule %s/%s: %w", ruleID, proto, err)
}
r.tproxyRules = append(r.tproxyRules, tproxyRuleEntry{
ruleID: ruleID,
table: tableNat,
chain: chainRTRDR,
spec: rule,
})
} else {
for _, port := range dstPorts {
rule := append(srcSpec,
"-i", r.wgIface.Name(),
"-p", proto,
"--dport", fmt.Sprintf("%d", port),
"-j", "REDIRECT",
"--to-ports", portStr,
)
if err := r.iptablesClient.AppendUnique(tableNat, chainRTRDR, rule...); err != nil {
return fmt.Errorf("add redirect rule %s/%s/%d: %w", ruleID, proto, port, err)
}
r.tproxyRules = append(r.tproxyRules, tproxyRuleEntry{
ruleID: ruleID,
table: tableNat,
chain: chainRTRDR,
spec: rule,
})
}
}
}
}
return nil
}
// RemoveTProxyRule removes all iptables REDIRECT rules for the given ruleID.
func (r *router) RemoveTProxyRule(ruleID string) error {
var remaining []tproxyRuleEntry
for _, entry := range r.tproxyRules {
if entry.ruleID != ruleID {
remaining = append(remaining, entry)
continue
}
if err := r.iptablesClient.DeleteIfExists(entry.table, entry.chain, entry.spec...); err != nil {
log.Debugf("remove tproxy rule %s: %v", ruleID, err)
}
}
r.tproxyRules = remaining
return nil
}
type tproxyRuleEntry struct {
ruleID string
table string
chain string
spec []string
}
func (r *router) buildSourceSpecs(sources []netip.Prefix) [][]string {
if len(sources) == 0 {
return [][]string{{}} // empty spec = match any source
}
specs := make([][]string, 0, len(sources))
for _, src := range sources {
specs = append(specs, []string{"-s", src.String()})
}
return specs
}

View File

@@ -180,22 +180,6 @@ type Manager interface {
// SetupEBPFProxyNoTrack creates static notrack rules for eBPF proxy loopback traffic.
// This prevents conntrack from interfering with WireGuard proxy communication.
SetupEBPFProxyNoTrack(proxyPort, wgPort uint16) error
// AddTProxyRule adds TPROXY redirect rules for specific source CIDRs and destination ports.
// Traffic from sources on dstPorts is redirected to the transparent proxy on redirectPort.
// Empty dstPorts means redirect all ports.
AddTProxyRule(ruleID string, sources []netip.Prefix, dstPorts []uint16, redirectPort uint16) error
// RemoveTProxyRule removes TPROXY redirect rules by ID.
RemoveTProxyRule(ruleID string) error
// AddUDPInspectionHook registers a hook that inspects UDP packets before forwarding.
// The hook receives the raw packet and returns true to drop it.
// Used for QUIC SNI-based blocking. Returns a hook ID for removal.
AddUDPInspectionHook(dstPort uint16, hook func(packet []byte) bool) string
// RemoveUDPInspectionHook removes a previously registered inspection hook.
RemoveUDPInspectionHook(hookID string)
}
func GenKey(format string, pair RouterPair) string {

View File

@@ -482,28 +482,6 @@ func (m *Manager) SetupEBPFProxyNoTrack(proxyPort, wgPort uint16) error {
return nil
}
// AddTProxyRule adds TPROXY redirect rules for the transparent proxy.
func (m *Manager) AddTProxyRule(ruleID string, sources []netip.Prefix, dstPorts []uint16, redirectPort uint16) error {
m.mutex.Lock()
defer m.mutex.Unlock()
return m.router.AddTProxyRule(ruleID, sources, dstPorts, redirectPort)
}
// RemoveTProxyRule removes TPROXY redirect rules by ID.
func (m *Manager) RemoveTProxyRule(ruleID string) error {
m.mutex.Lock()
defer m.mutex.Unlock()
return m.router.RemoveTProxyRule(ruleID)
}
// AddUDPInspectionHook is a no-op for nftables (kernel-mode firewall has no userspace packet hooks).
func (m *Manager) AddUDPInspectionHook(_ uint16, _ func([]byte) bool) string { return "" }
// RemoveUDPInspectionHook is a no-op for nftables.
func (m *Manager) RemoveUDPInspectionHook(_ string) {}
func (m *Manager) initNoTrackChains(table *nftables.Table) error {
m.notrackOutputChain = m.rConn.AddChain(&nftables.Chain{
Name: chainNameRawOutput,

View File

@@ -77,7 +77,6 @@ type router struct {
ipFwdState *ipfwdstate.IPForwardingState
legacyManagement bool
mtu uint16
}
func newRouter(workTable *nftables.Table, wgIface iFaceMapper, mtu uint16) (*router, error) {
@@ -2138,227 +2137,3 @@ func getIpSetExprs(ref refcounter.Ref[*nftables.Set], isSource bool) ([]expr.Any
},
}, nil
}
// AddTProxyRule adds nftables TPROXY redirect rules in the mangle prerouting chain.
// Traffic from sources on dstPorts arriving on the WG interface is redirected to
// the transparent proxy listener on redirectPort.
// Separate rules are created for TCP and UDP protocols.
func (r *router) AddTProxyRule(ruleID string, sources []netip.Prefix, dstPorts []uint16, redirectPort uint16) error {
if err := r.refreshRulesMap(); err != nil {
return fmt.Errorf(refreshRulesMapError, err)
}
// Use the nat redirect chain for DNAT rules.
// TPROXY doesn't work on WG kernel interfaces (socket assignment silently fails),
// so we use DNAT to 127.0.0.1:proxy_port instead. The proxy reads the original
// destination via SO_ORIGINAL_DST (conntrack).
chain := r.chains[chainNameRoutingRdr]
if chain == nil {
return fmt.Errorf("nat redirect chain not initialized")
}
for _, proto := range []uint8{unix.IPPROTO_TCP, unix.IPPROTO_UDP} {
protoName := "tcp"
if proto == unix.IPPROTO_UDP {
protoName = "udp"
}
ruleKey := fmt.Sprintf("tproxy-%s-%s", ruleID, protoName)
if existing, ok := r.rules[ruleKey]; ok && existing.Handle != 0 {
if err := r.decrementSetCounter(existing); err != nil {
log.Debugf("decrement set counter for %s: %v", ruleKey, err)
}
if err := r.conn.DelRule(existing); err != nil {
log.Debugf("remove existing tproxy rule %s: %v", ruleKey, err)
}
delete(r.rules, ruleKey)
}
exprs, err := r.buildRedirectExprs(proto, sources, dstPorts, redirectPort)
if err != nil {
return fmt.Errorf("build redirect exprs for %s: %w", protoName, err)
}
r.rules[ruleKey] = r.conn.AddRule(&nftables.Rule{
Table: r.workTable,
Chain: chain,
Exprs: exprs,
UserData: []byte(ruleKey),
})
}
// Accept redirected packets in the ACL input chain. After REDIRECT, the
// destination port becomes the proxy port. Without this rule, the ACL filter
// drops the packet. We match on ct state dnat so only REDIRECT'd connections
// are accepted: direct connections to the proxy port are blocked.
inputAcceptKey := fmt.Sprintf("tproxy-%s-input", ruleID)
if _, ok := r.rules[inputAcceptKey]; !ok {
inputChain := &nftables.Chain{
Name: "netbird-acl-input-rules",
Table: r.workTable,
}
r.rules[inputAcceptKey] = r.conn.InsertRule(&nftables.Rule{
Table: r.workTable,
Chain: inputChain,
Exprs: []expr.Any{
// Only accept connections that were REDIRECT'd (ct status dnat)
&expr.Ct{Register: 1, Key: expr.CtKeySTATUS},
&expr.Bitwise{
SourceRegister: 1,
DestRegister: 1,
Len: 4,
Mask: binaryutil.NativeEndian.PutUint32(0x20), // IPS_DST_NAT
Xor: binaryutil.NativeEndian.PutUint32(0),
},
&expr.Cmp{
Op: expr.CmpOpNeq,
Register: 1,
Data: binaryutil.NativeEndian.PutUint32(0),
},
// Accept both TCP and UDP redirected to the proxy port.
&expr.Payload{
DestRegister: 1,
Base: expr.PayloadBaseTransportHeader,
Offset: 2,
Len: 2,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: binaryutil.BigEndian.PutUint16(redirectPort),
},
&expr.Verdict{Kind: expr.VerdictAccept},
},
UserData: []byte(inputAcceptKey),
})
}
if err := r.conn.Flush(); err != nil {
return fmt.Errorf("flush tproxy rules for %s: %w", ruleID, err)
}
return nil
}
// RemoveTProxyRule removes TPROXY redirect rules by ID (both TCP and UDP variants).
func (r *router) RemoveTProxyRule(ruleID string) error {
if err := r.refreshRulesMap(); err != nil {
return fmt.Errorf(refreshRulesMapError, err)
}
var removed int
for _, suffix := range []string{"tcp", "udp", "input"} {
ruleKey := fmt.Sprintf("tproxy-%s-%s", ruleID, suffix)
rule, ok := r.rules[ruleKey]
if !ok {
continue
}
if rule.Handle == 0 {
delete(r.rules, ruleKey)
continue
}
if err := r.decrementSetCounter(rule); err != nil {
log.Debugf("decrement set counter for %s: %v", ruleKey, err)
}
if err := r.conn.DelRule(rule); err != nil {
log.Debugf("delete tproxy rule %s: %v", ruleKey, err)
}
delete(r.rules, ruleKey)
removed++
}
if removed > 0 {
if err := r.conn.Flush(); err != nil {
return fmt.Errorf("flush tproxy rule removal for %s: %w", ruleID, err)
}
}
return nil
}
// buildRedirectExprs builds nftables expressions for a REDIRECT rule.
// Matches WG interface ingress, source CIDRs, destination ports, then REDIRECTs to the proxy port.
func (r *router) buildRedirectExprs(proto uint8, sources []netip.Prefix, dstPorts []uint16, redirectPort uint16) ([]expr.Any, error) {
var exprs []expr.Any
exprs = append(exprs,
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: ifname(r.wgIface.Name())},
)
exprs = append(exprs,
&expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 1},
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: []byte{proto}},
)
// Source CIDRs use the named ipset shared with route rules.
if len(sources) > 0 {
srcSet := firewall.NewPrefixSet(sources)
srcExprs, err := r.getIpSet(srcSet, sources, true)
if err != nil {
return nil, fmt.Errorf("get source ipset: %w", err)
}
exprs = append(exprs, srcExprs...)
}
if len(dstPorts) == 1 {
exprs = append(exprs,
&expr.Payload{
DestRegister: 1,
Base: expr.PayloadBaseTransportHeader,
Offset: 2,
Len: 2,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: binaryutil.BigEndian.PutUint16(dstPorts[0]),
},
)
} else if len(dstPorts) > 1 {
setElements := make([]nftables.SetElement, len(dstPorts))
for i, p := range dstPorts {
setElements[i] = nftables.SetElement{Key: binaryutil.BigEndian.PutUint16(p)}
}
portSet := &nftables.Set{
Table: r.workTable,
Anonymous: true,
Constant: true,
KeyType: nftables.TypeInetService,
}
if err := r.conn.AddSet(portSet, setElements); err != nil {
return nil, fmt.Errorf("create port set: %w", err)
}
exprs = append(exprs,
&expr.Payload{
DestRegister: 1,
Base: expr.PayloadBaseTransportHeader,
Offset: 2,
Len: 2,
},
&expr.Lookup{
SourceRegister: 1,
SetName: portSet.Name,
SetID: portSet.ID,
},
)
}
// REDIRECT to local proxy port. Changes the destination to the interface's
// primary address + specified port. Conntrack tracks the original destination,
// readable via SO_ORIGINAL_DST.
exprs = append(exprs,
&expr.Immediate{Register: 1, Data: binaryutil.BigEndian.PutUint16(redirectPort)},
&expr.Redir{
RegisterProtoMin: 1,
},
)
return exprs, nil
}

View File

@@ -641,45 +641,6 @@ func (m *Manager) SetupEBPFProxyNoTrack(proxyPort, wgPort uint16) error {
return m.nativeFirewall.SetupEBPFProxyNoTrack(proxyPort, wgPort)
}
// AddTProxyRule delegates to the native firewall for TPROXY rules.
// In userspace mode (no native firewall), this is a no-op since the
// forwarder intercepts traffic directly.
func (m *Manager) AddTProxyRule(ruleID string, sources []netip.Prefix, dstPorts []uint16, redirectPort uint16) error {
if m.nativeFirewall == nil {
return nil
}
return m.nativeFirewall.AddTProxyRule(ruleID, sources, dstPorts, redirectPort)
}
// AddUDPInspectionHook registers a hook for QUIC/UDP inspection via the packet filter.
func (m *Manager) AddUDPInspectionHook(dstPort uint16, hook func(packet []byte) bool) string {
m.SetUDPPacketHook(netip.Addr{}, dstPort, hook)
return "udp-inspection"
}
// RemoveUDPInspectionHook removes a previously registered inspection hook.
func (m *Manager) RemoveUDPInspectionHook(_ string) {
m.SetUDPPacketHook(netip.Addr{}, 0, nil)
}
// RemoveTProxyRule delegates to the native firewall for TPROXY rules.
func (m *Manager) RemoveTProxyRule(ruleID string) error {
if m.nativeFirewall == nil {
return nil
}
return m.nativeFirewall.RemoveTProxyRule(ruleID)
}
// IsLocalIP reports whether the given IP belongs to the local machine.
func (m *Manager) IsLocalIP(ip netip.Addr) bool {
return m.localipmanager.IsLocalIP(ip)
}
// GetForwarder returns the userspace packet forwarder, or nil if not initialized.
func (m *Manager) GetForwarder() *forwarder.Forwarder {
return m.forwarder.Load()
}
// UpdateSet updates the rule destinations associated with the given set
// by merging the existing prefixes with the new ones, then deduplicating.
func (m *Manager) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {

View File

@@ -7,7 +7,6 @@ import (
"net/netip"
"runtime"
"sync"
"sync/atomic"
"time"
log "github.com/sirupsen/logrus"
@@ -22,7 +21,6 @@ import (
"github.com/netbirdio/netbird/client/firewall/uspfilter/common"
"github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack"
"github.com/netbirdio/netbird/client/inspect"
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
)
@@ -48,10 +46,6 @@ type Forwarder struct {
netstack bool
hasRawICMPAccess bool
pingSemaphore chan struct{}
// proxy is the optional inspection engine.
// When set, TCP connections are handed to the engine for protocol detection
// and rule evaluation. Swapped atomically for lock-free hot-path access.
proxy atomic.Pointer[inspect.Proxy]
}
func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.FlowLogger, netstack bool, mtu uint16) (*Forwarder, error) {
@@ -85,7 +79,7 @@ func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.Flow
}
if err := s.AddProtocolAddress(nicID, protoAddr, stack.AddressProperties{}); err != nil {
return nil, fmt.Errorf("add protocol address: %s", err)
return nil, fmt.Errorf("failed to add protocol address: %s", err)
}
defaultSubnet, err := tcpip.NewSubnet(
@@ -161,13 +155,6 @@ func (f *Forwarder) InjectIncomingPacket(payload []byte) error {
return nil
}
// SetProxy sets the inspection engine. When set, TCP connections are handed
// to it for protocol detection and rule evaluation instead of direct relay.
// Pass nil to disable inspection.
func (f *Forwarder) SetProxy(p *inspect.Proxy) {
f.proxy.Store(p)
}
// Stop gracefully shuts down the forwarder
func (f *Forwarder) Stop() {
f.cancel()
@@ -180,25 +167,6 @@ func (f *Forwarder) Stop() {
f.stack.Wait()
}
// CheckUDPPacket inspects a UDP payload against proxy rules before injection.
// This is called by the filter for QUIC SNI-based blocking.
// Returns true if the packet should be allowed, false if it should be dropped.
func (f *Forwarder) CheckUDPPacket(payload []byte, srcIP, dstIP netip.Addr, srcPort, dstPort uint16, ruleID []byte) bool {
p := f.proxy.Load()
if p == nil {
return true
}
dst := netip.AddrPortFrom(dstIP, dstPort)
src := inspect.SourceInfo{
IP: srcIP,
PolicyID: inspect.PolicyID(ruleID),
}
action := p.HandleUDPPacket(payload, dst, src)
return action != inspect.ActionBlock
}
func (f *Forwarder) determineDialAddr(addr tcpip.Address) net.IP {
if f.netstack && f.ip.Equal(addr) {
return net.IPv4(127, 0, 0, 1)

View File

@@ -16,7 +16,6 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
"gvisor.dev/gvisor/pkg/waiter"
"github.com/netbirdio/netbird/client/inspect"
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
)
@@ -24,86 +23,6 @@ import (
func (f *Forwarder) handleTCP(r *tcp.ForwarderRequest) {
id := r.ID()
// If the inspection engine is configured, accept the connection first and hand it off.
if p := f.proxy.Load(); p != nil {
f.handleTCPWithInspection(r, id, p)
return
}
f.handleTCPDirect(r, id)
}
// handleTCPWithInspection accepts the connection and hands it to the inspection
// engine. For allow decisions, the forwarder does its own relay (passthrough).
// For block/inspect, the engine handles everything internally.
func (f *Forwarder) handleTCPWithInspection(r *tcp.ForwarderRequest, id stack.TransportEndpointID, p *inspect.Proxy) {
flowID := uuid.New()
f.sendTCPEvent(nftypes.TypeStart, flowID, id, 0, 0, 0, 0)
wq := waiter.Queue{}
ep, epErr := r.CreateEndpoint(&wq)
if epErr != nil {
f.logger.Error1("forwarder: create TCP endpoint for inspection: %v", epErr)
r.Complete(true)
f.sendTCPEvent(nftypes.TypeEnd, flowID, id, 0, 0, 0, 0)
return
}
r.Complete(false)
inConn := gonet.NewTCPConn(&wq, ep)
srcIP := netip.AddrFrom4(id.RemoteAddress.As4())
dstIP := netip.AddrFrom4(id.LocalAddress.As4())
dst := netip.AddrPortFrom(dstIP, id.LocalPort)
var policyID []byte
if ruleID, ok := f.getRuleID(srcIP, dstIP, id.RemotePort, id.LocalPort); ok {
policyID = ruleID
}
src := inspect.SourceInfo{
IP: srcIP,
PolicyID: inspect.PolicyID(policyID),
}
f.logger.Trace1("forwarder: handing TCP %v to inspection engine", epID(id))
go func() {
result, err := p.InspectTCP(f.ctx, inConn, dst, src)
if err != nil && err != inspect.ErrBlocked {
f.logger.Debug2("forwarder: inspection error for %v: %v", epID(id), err)
}
// Passthrough: engine returned allow, forwarder does the relay.
if result.PassthroughConn != nil {
dialAddr := fmt.Sprintf("%s:%d", f.determineDialAddr(id.LocalAddress), id.LocalPort)
outConn, dialErr := (&net.Dialer{}).DialContext(f.ctx, "tcp", dialAddr)
if dialErr != nil {
f.logger.Trace2("forwarder: passthrough dial error for %v: %v", epID(id), dialErr)
if closeErr := result.PassthroughConn.Close(); closeErr != nil {
f.logger.Debug1("forwarder: close passthrough conn: %v", closeErr)
}
ep.Close()
f.sendTCPEvent(nftypes.TypeEnd, flowID, id, 0, 0, 0, 0)
return
}
f.proxyTCPPassthrough(id, result.PassthroughConn, outConn, ep, flowID)
return
}
// Engine handled it (block/inspect/HTTP). Capture stats and clean up.
var rxPackets, txPackets uint64
if tcpStats, ok := ep.Stats().(*tcp.Stats); ok {
rxPackets = tcpStats.SegmentsSent.Value()
txPackets = tcpStats.SegmentsReceived.Value()
}
ep.Close()
f.sendTCPEvent(nftypes.TypeEnd, flowID, id, 0, 0, rxPackets, txPackets)
}()
}
// handleTCPDirect handles TCP connections with direct relay (no proxy).
func (f *Forwarder) handleTCPDirect(r *tcp.ForwarderRequest, id stack.TransportEndpointID) {
flowID := uuid.New()
f.sendTCPEvent(nftypes.TypeStart, flowID, id, 0, 0, 0, 0)
@@ -123,6 +42,7 @@ func (f *Forwarder) handleTCPDirect(r *tcp.ForwarderRequest, id stack.TransportE
return
}
// Create wait queue for blocking syscalls
wq := waiter.Queue{}
ep, epErr := r.CreateEndpoint(&wq)
@@ -135,6 +55,7 @@ func (f *Forwarder) handleTCPDirect(r *tcp.ForwarderRequest, id stack.TransportE
return
}
// Complete the handshake
r.Complete(false)
inConn := gonet.NewTCPConn(&wq, ep)
@@ -152,6 +73,7 @@ func (f *Forwarder) proxyTCP(id stack.TransportEndpointID, inConn *gonet.TCPConn
go func() {
<-ctx.Done()
// Close connections and endpoint.
if err := inConn.Close(); err != nil && !isClosedError(err) {
f.logger.Debug1("forwarder: inConn close error: %v", err)
}
@@ -210,66 +132,6 @@ func (f *Forwarder) proxyTCP(id stack.TransportEndpointID, inConn *gonet.TCPConn
f.sendTCPEvent(nftypes.TypeEnd, flowID, id, uint64(bytesFromOutToIn), uint64(bytesFromInToOut), rxPackets, txPackets)
}
// proxyTCPPassthrough relays traffic between a peeked inbound connection
// (from the inspection engine passthrough) and the outbound connection.
// It accepts net.Conn for inConn since the inspection engine wraps it in a peekConn.
func (f *Forwarder) proxyTCPPassthrough(id stack.TransportEndpointID, inConn net.Conn, outConn net.Conn, ep tcpip.Endpoint, flowID uuid.UUID) {
ctx, cancel := context.WithCancel(f.ctx)
defer cancel()
go func() {
<-ctx.Done()
if err := inConn.Close(); err != nil && !isClosedError(err) {
f.logger.Debug1("forwarder: passthrough inConn close: %v", err)
}
if err := outConn.Close(); err != nil && !isClosedError(err) {
f.logger.Debug1("forwarder: passthrough outConn close: %v", err)
}
ep.Close()
}()
var wg sync.WaitGroup
wg.Add(2)
var (
bytesIn int64
bytesOut int64
errIn error
errOut error
)
go func() {
bytesIn, errIn = io.Copy(outConn, inConn)
cancel()
wg.Done()
}()
go func() {
bytesOut, errOut = io.Copy(inConn, outConn)
cancel()
wg.Done()
}()
wg.Wait()
if errIn != nil && !isClosedError(errIn) {
f.logger.Error2("proxyTCPPassthrough: copy error (in→out) for %s: %v", epID(id), errIn)
}
if errOut != nil && !isClosedError(errOut) {
f.logger.Error2("proxyTCPPassthrough: copy error (out→in) for %s: %v", epID(id), errOut)
}
var rxPackets, txPackets uint64
if tcpStats, ok := ep.Stats().(*tcp.Stats); ok {
rxPackets = tcpStats.SegmentsSent.Value()
txPackets = tcpStats.SegmentsReceived.Value()
}
f.logger.Trace5("forwarder: passthrough TCP %s [in: %d Pkts/%d B, out: %d Pkts/%d B]", epID(id), rxPackets, bytesOut, txPackets, bytesIn)
f.sendTCPEvent(nftypes.TypeEnd, flowID, id, uint64(bytesOut), uint64(bytesIn), rxPackets, txPackets)
}
func (f *Forwarder) sendTCPEvent(typ nftypes.Type, flowID uuid.UUID, id stack.TransportEndpointID, rxBytes, txBytes, rxPackets, txPackets uint64) {
srcIp := netip.AddrFrom4(id.RemoteAddress.As4())
dstIp := netip.AddrFrom4(id.LocalAddress.As4())

View File

@@ -1,212 +0,0 @@
package inspect
import (
"crypto"
"crypto/x509"
"net"
"net/netip"
"net/url"
"strings"
"github.com/netbirdio/netbird/client/internal/acl/id"
"github.com/netbirdio/netbird/shared/management/domain"
)
// InspectResult holds the outcome of connection inspection.
type InspectResult struct {
// Action is the rule evaluation result.
Action Action
// PassthroughConn is the client connection with buffered peeked bytes.
// Non-nil only when Action is ActionAllow and the caller should relay
// (TLS passthrough or non-HTTP/TLS protocol). The caller takes ownership
// and is responsible for closing this connection.
PassthroughConn net.Conn
}
const (
// DefaultTProxyPort is the default TPROXY listener port for kernel mode.
// Override with NB_TPROXY_PORT environment variable.
DefaultTProxyPort = 22080
)
// Action determines how the proxy handles a matched connection.
type Action string
const (
// ActionAllow passes the connection through without decryption.
ActionAllow Action = "allow"
// ActionBlock denies the connection.
ActionBlock Action = "block"
// ActionInspect decrypts (MITM) and inspects the connection.
ActionInspect Action = "inspect"
)
// ProxyMode determines the proxy operating mode.
type ProxyMode string
const (
// ModeBuiltin uses the built-in proxy with rules and optional ICAP.
ModeBuiltin ProxyMode = "builtin"
// ModeEnvoy runs a local envoy sidecar for L7 processing.
// Go manages envoy lifecycle, config generation, and rule evaluation.
// USP path forwards via PROXY protocol v2; kernel path uses nftables redirect.
ModeEnvoy ProxyMode = "envoy"
// ModeExternal forwards all traffic to an external proxy.
ModeExternal ProxyMode = "external"
)
// PolicyID is the management policy identifier associated with a connection.
type PolicyID []byte
// MatchDomain reports whether target matches the pattern.
// If pattern starts with "*.", it matches any subdomain (but not the base itself).
// Otherwise it requires an exact match.
func MatchDomain(pattern, target domain.Domain) bool {
p := pattern.PunycodeString()
t := target.PunycodeString()
if strings.HasPrefix(p, "*.") {
base := p[2:]
return strings.HasSuffix(t, "."+base)
}
return p == t
}
// SourceInfo carries source identity context for rule evaluation.
// The source may be a direct WireGuard peer or a host behind
// a site-to-site gateway.
type SourceInfo struct {
// IP is the original source address from the packet.
IP netip.Addr
// PolicyID is the management policy that allowed this traffic
// through route ACLs.
PolicyID PolicyID
}
// ProtoType identifies a protocol handled by the proxy.
type ProtoType string
const (
ProtoHTTP ProtoType = "http"
ProtoHTTPS ProtoType = "https"
ProtoH2 ProtoType = "h2"
ProtoH3 ProtoType = "h3"
ProtoWebSocket ProtoType = "websocket"
ProtoOther ProtoType = "other"
)
// Rule defines a proxy inspection/filtering rule.
type Rule struct {
// ID uniquely identifies this rule.
ID id.RuleID
// Sources are the source CIDRs this rule applies to.
// Includes both direct peer IPs and routed networks behind gateways.
Sources []netip.Prefix
// Domains are the destination domain patterns to match (via SNI or Host header).
// Supports exact match ("example.com") and wildcard ("*.example.com").
Domains []domain.Domain
// Networks are the destination CIDRs to match.
Networks []netip.Prefix
// Ports are the destination ports to match. Empty means all ports.
Ports []uint16
// Protocols restricts which protocols this rule applies to.
// Empty means all protocols.
Protocols []ProtoType
// Paths are URL path patterns to match (HTTP only, requires inspect for HTTPS).
// Supports prefix ("/api/"), exact ("/login"), and wildcard ("/admin/*").
// Empty means all paths.
Paths []string
// Action determines what to do with matched connections.
Action Action
// Priority controls evaluation order. Lower values are evaluated first.
Priority int
}
// ICAPConfig holds ICAP service configuration.
type ICAPConfig struct {
// ReqModURL is the ICAP REQMOD service URL (e.g., icap://server:1344/reqmod).
ReqModURL *url.URL
// RespModURL is the ICAP RESPMOD service URL (e.g., icap://server:1344/respmod).
RespModURL *url.URL
// MaxConnections is the connection pool size. Zero uses a default.
MaxConnections int
}
// TLSConfig holds the MITM CA configuration for TLS inspection.
type TLSConfig struct {
// CA is the certificate authority used to sign dynamic certificates.
CA *x509.Certificate
// CAKey is the CA's private key.
CAKey crypto.PrivateKey
}
// Config holds the transparent proxy configuration.
type Config struct {
// Enabled controls whether the proxy is active.
Enabled bool
// Mode selects built-in or external proxy operation.
Mode ProxyMode
// ExternalURL is the upstream proxy URL for ModeExternal.
// Supports http:// and socks5:// schemes.
ExternalURL *url.URL
// DefaultAction applies when no rule matches a connection.
DefaultAction Action
// RedirectSources are the source CIDRs whose traffic should be intercepted.
// Admin decides: "activate for these users/subnets."
// Used for both kernel TPROXY rules and userspace forwarder source filtering.
RedirectSources []netip.Prefix
// RedirectPorts are the destination ports to intercept. Empty means all ports.
RedirectPorts []uint16
// Rules are the proxy inspection/filtering rules, evaluated in Priority order.
Rules []Rule
// ICAP holds ICAP service configuration. Nil disables ICAP.
ICAP *ICAPConfig
// TLS holds the MITM CA. Nil means no MITM capability (ActionInspect rules ignored).
TLS *TLSConfig
// Envoy configuration (ModeEnvoy only)
Envoy *EnvoyConfig
// ListenAddr is the TPROXY listen address for kernel mode.
// Zero value disables the TPROXY listener.
ListenAddr netip.AddrPort
// WGNetwork is the WireGuard overlay network prefix.
// The proxy blocks dialing destinations inside this network.
WGNetwork netip.Prefix
// LocalIPChecker reports whether an IP belongs to the routing peer.
// Used to prevent SSRF to local services. May be nil.
LocalIPChecker LocalIPChecker
}
// EnvoyConfig holds configuration for the envoy sidecar mode.
type EnvoyConfig struct {
// BinaryPath is the path to the envoy binary.
// Empty means search $PATH for "envoy".
BinaryPath string
// AdminPort is the port for envoy's admin API (health checks, stats).
// Zero means auto-assign.
AdminPort uint16
// Snippets are user-provided config fragments merged into the generated bootstrap.
Snippets *EnvoySnippets
}
// EnvoySnippets holds user-provided YAML fragments for envoy config customization.
// Only safe snippet types are allowed: filters (HTTP and network) and clusters
// needed as dependencies for filter services. Listeners and bootstrap overrides
// are not exposed since we manage the listener and bootstrap.
type EnvoySnippets struct {
// HTTPFilters is YAML injected into the HCM filter chain before the router filter.
// Used for ext_authz, rate limiting, Lua, Wasm, RBAC, JWT auth, etc.
HTTPFilters string
// NetworkFilters is YAML injected into the TLS filter chain before tcp_proxy.
// Used for network-level RBAC, rate limiting, ext_authz on raw TCP.
NetworkFilters string
// Clusters is YAML for additional upstream clusters referenced by filters.
// Needed when filters call external services (ext_authz backend, rate limit service).
Clusters string
}

View File

@@ -1,93 +0,0 @@
package inspect
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/shared/management/domain"
)
func TestMatchDomain(t *testing.T) {
tests := []struct {
name string
pattern string
target string
want bool
}{
{
name: "exact match",
pattern: "example.com",
target: "example.com",
want: true,
},
{
name: "exact no match",
pattern: "example.com",
target: "other.com",
want: false,
},
{
name: "wildcard matches subdomain",
pattern: "*.example.com",
target: "foo.example.com",
want: true,
},
{
name: "wildcard matches deep subdomain",
pattern: "*.example.com",
target: "a.b.c.example.com",
want: true,
},
{
name: "wildcard does not match base",
pattern: "*.example.com",
target: "example.com",
want: false,
},
{
name: "wildcard does not match unrelated",
pattern: "*.example.com",
target: "foo.other.com",
want: false,
},
{
name: "case insensitive exact match",
pattern: "Example.COM",
target: "example.com",
want: true,
},
{
name: "case insensitive wildcard match",
pattern: "*.Example.COM",
target: "FOO.example.com",
want: true,
},
{
name: "wildcard does not match partial suffix",
pattern: "*.example.com",
target: "notexample.com",
want: false,
},
{
name: "unicode domain punycode match",
pattern: "*.münchen.de",
target: "sub.xn--mnchen-3ya.de",
want: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
pattern, err := domain.FromString(tt.pattern)
require.NoError(t, err)
target, err := domain.FromString(tt.target)
require.NoError(t, err)
got := MatchDomain(pattern, target)
assert.Equal(t, tt.want, got)
})
}
}

View File

@@ -1,25 +0,0 @@
package inspect
import (
"net"
"syscall"
)
// newOutboundDialer creates a net.Dialer that clears the socket fwmark.
// In kernel TPROXY mode, accepted connections inherit the TPROXY fwmark.
// Without clearing it, outbound connections from the proxy would match
// the ip rule (fwmark -> local loopback) and loop back to the proxy
// instead of reaching the real destination.
func newOutboundDialer() net.Dialer {
return net.Dialer{
Control: func(_, _ string, c syscall.RawConn) error {
var sockErr error
if err := c.Control(func(fd uintptr) {
sockErr = syscall.SetsockoptInt(int(fd), syscall.SOL_SOCKET, syscall.SO_MARK, 0)
}); err != nil {
return err
}
return sockErr
},
}
}

View File

@@ -1,11 +0,0 @@
//go:build !linux
package inspect
import "net"
// newOutboundDialer returns a plain dialer on non-Linux platforms.
// TPROXY is Linux-only, so no fwmark clearing is needed.
func newOutboundDialer() net.Dialer {
return net.Dialer{}
}

View File

@@ -1,298 +0,0 @@
package inspect
import (
"context"
"fmt"
"io"
"net"
"net/http"
"net/netip"
"os"
"os/exec"
"path/filepath"
"strings"
"sync"
"time"
log "github.com/sirupsen/logrus"
)
const (
envoyStartTimeout = 15 * time.Second
envoyHealthInterval = 500 * time.Millisecond
envoyStopTimeout = 10 * time.Second
envoyDrainTime = 5
)
// envoyManager manages the lifecycle of an envoy sidecar process.
type envoyManager struct {
log *log.Entry
cmd *exec.Cmd
configPath string
listenPort uint16
adminPort uint16
cancel context.CancelFunc
blockPagePath string
mu sync.Mutex
running bool
}
// startEnvoy finds the envoy binary, generates config, and spawns the process.
// It blocks until envoy reports healthy or the timeout expires.
func startEnvoy(ctx context.Context, logger *log.Entry, config Config) (*envoyManager, error) {
envCfg := config.Envoy
if envCfg == nil {
return nil, fmt.Errorf("envoy config is nil")
}
binaryPath, err := findEnvoyBinary(envCfg.BinaryPath)
if err != nil {
return nil, fmt.Errorf("find envoy binary: %w", err)
}
// Pick admin port
adminPort := envCfg.AdminPort
if adminPort == 0 {
p, err := findFreePort()
if err != nil {
return nil, fmt.Errorf("find free admin port: %w", err)
}
adminPort = p
}
// Pick listener port
listenPort, err := findFreePort()
if err != nil {
return nil, fmt.Errorf("find free listener port: %w", err)
}
// Use a private temp directory (0700) to prevent local attackers from
// replacing the config file between write and envoy read.
configDir, err := os.MkdirTemp("", "nb-envoy-*")
if err != nil {
return nil, fmt.Errorf("create envoy config directory: %w", err)
}
// Write the block page HTML for envoy's direct_response to reference.
blockPagePath := filepath.Join(configDir, "block.html")
blockHTML := fmt.Sprintf(blockPageHTML, "blocked domain", "this domain")
if err := os.WriteFile(blockPagePath, []byte(blockHTML), 0600); err != nil {
return nil, fmt.Errorf("write envoy block page: %w", err)
}
// Generate config with the block page path embedded.
bootstrap, err := generateBootstrap(config, listenPort, adminPort, blockPagePath)
if err != nil {
return nil, fmt.Errorf("generate envoy bootstrap: %w", err)
}
configPath := filepath.Join(configDir, "bootstrap.yaml")
if err := os.WriteFile(configPath, bootstrap, 0600); err != nil {
return nil, fmt.Errorf("write envoy config: %w", err)
}
ctx, cancel := context.WithCancel(ctx)
cmd := exec.CommandContext(ctx, binaryPath,
"-c", configPath,
"--drain-time-s", fmt.Sprintf("%d", envoyDrainTime),
)
// Pipe envoy output to our logger.
cmd.Stdout = &logWriter{entry: logger, level: log.DebugLevel}
cmd.Stderr = &logWriter{entry: logger, level: log.WarnLevel}
if err := cmd.Start(); err != nil {
cancel()
os.Remove(configPath)
return nil, fmt.Errorf("start envoy: %w", err)
}
mgr := &envoyManager{
log: logger,
cmd: cmd,
configPath: configPath,
listenPort: listenPort,
adminPort: adminPort,
blockPagePath: blockPagePath,
cancel: cancel,
running: true,
}
// Wait for envoy to become healthy.
if err := mgr.waitHealthy(ctx); err != nil {
mgr.Stop()
return nil, fmt.Errorf("wait for envoy readiness: %w", err)
}
logger.Infof("inspect: envoy started (pid=%d, listen=%d, admin=%d)", cmd.Process.Pid, listenPort, adminPort)
// Monitor process exit in background.
go mgr.monitor()
return mgr, nil
}
// ListenAddr returns the address envoy listens on for forwarded connections.
func (m *envoyManager) ListenAddr() netip.AddrPort {
return netip.AddrPortFrom(netip.AddrFrom4([4]byte{127, 0, 0, 1}), m.listenPort)
}
// AdminAddr returns the envoy admin API address.
func (m *envoyManager) AdminAddr() string {
return fmt.Sprintf("127.0.0.1:%d", m.adminPort)
}
// Reload writes a new config and sends SIGHUP to envoy.
func (m *envoyManager) Reload(config Config) error {
m.mu.Lock()
defer m.mu.Unlock()
if !m.running {
return fmt.Errorf("envoy is not running")
}
bootstrap, err := generateBootstrap(config, m.listenPort, m.adminPort, m.blockPagePath)
if err != nil {
return fmt.Errorf("generate envoy bootstrap: %w", err)
}
if err := os.WriteFile(m.configPath, bootstrap, 0600); err != nil {
return fmt.Errorf("write envoy config: %w", err)
}
if err := signalReload(m.cmd.Process); err != nil {
return fmt.Errorf("signal envoy reload: %w", err)
}
m.log.Debugf("inspect: envoy config reloaded")
return nil
}
// Healthy checks the envoy admin API /ready endpoint.
func (m *envoyManager) Healthy() bool {
resp, err := http.Get(fmt.Sprintf("http://%s/ready", m.AdminAddr()))
if err != nil {
return false
}
defer resp.Body.Close()
return resp.StatusCode == http.StatusOK
}
// Stop terminates the envoy process and cleans up.
func (m *envoyManager) Stop() {
m.mu.Lock()
defer m.mu.Unlock()
if !m.running {
return
}
m.running = false
m.cancel()
if m.cmd.Process != nil {
done := make(chan struct{})
go func() {
m.cmd.Wait()
close(done)
}()
select {
case <-done:
case <-time.After(envoyStopTimeout):
m.log.Warnf("inspect: envoy did not exit in %s, killing", envoyStopTimeout)
m.cmd.Process.Kill()
<-done
}
}
os.RemoveAll(filepath.Dir(m.configPath))
m.log.Infof("inspect: envoy stopped")
}
// waitHealthy polls the admin API until envoy is ready or timeout.
func (m *envoyManager) waitHealthy(ctx context.Context) error {
deadline := time.After(envoyStartTimeout)
ticker := time.NewTicker(envoyHealthInterval)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return ctx.Err()
case <-deadline:
return fmt.Errorf("envoy not ready after %s", envoyStartTimeout)
case <-ticker.C:
if m.Healthy() {
return nil
}
}
}
}
// monitor watches for unexpected envoy exits.
func (m *envoyManager) monitor() {
err := m.cmd.Wait()
m.mu.Lock()
wasRunning := m.running
m.running = false
m.mu.Unlock()
if wasRunning {
m.log.Errorf("inspect: envoy exited unexpectedly: %v", err)
}
}
// findEnvoyBinary resolves the envoy binary path.
func findEnvoyBinary(configPath string) (string, error) {
if configPath != "" {
if _, err := os.Stat(configPath); err != nil {
return "", fmt.Errorf("envoy binary not found at %s: %w", configPath, err)
}
return configPath, nil
}
path, err := exec.LookPath("envoy")
if err != nil {
return "", fmt.Errorf("envoy not found in PATH: %w", err)
}
return path, nil
}
// findFreePort asks the OS for an available TCP port.
func findFreePort() (uint16, error) {
ln, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
return 0, err
}
port := uint16(ln.Addr().(*net.TCPAddr).Port)
ln.Close()
return port, nil
}
// logWriter adapts log.Entry to io.Writer for piping process output.
type logWriter struct {
entry *log.Entry
level log.Level
}
func (w *logWriter) Write(p []byte) (int, error) {
msg := strings.TrimRight(string(p), "\n\r")
if msg == "" {
return len(p), nil
}
switch w.level {
case log.WarnLevel:
w.entry.Warn(msg)
default:
w.entry.Debug(msg)
}
return len(p), nil
}
// Ensure logWriter satisfies io.Writer.
var _ io.Writer = (*logWriter)(nil)

View File

@@ -1,382 +0,0 @@
package inspect
import (
"bytes"
"fmt"
"strings"
"text/template"
)
// envoyBootstrapTmpl generates the full envoy bootstrap with rule translation.
// TLS rules become per-SNI filter chains; HTTP rules become per-domain virtual hosts.
var envoyBootstrapTmpl = template.Must(template.New("bootstrap").Funcs(template.FuncMap{
"quote": func(s string) string { return fmt.Sprintf("%q", s) },
}).Parse(`node:
id: netbird-inspect
cluster: netbird
admin:
address:
socket_address:
address: 127.0.0.1
port_value: {{.AdminPort}}
static_resources:
listeners:
- name: inspect_listener
address:
socket_address:
address: 127.0.0.1
port_value: {{.ListenPort}}
listener_filters:
- name: envoy.filters.listener.proxy_protocol
typed_config:
"@type": type.googleapis.com/envoy.extensions.filters.listener.proxy_protocol.v3.ProxyProtocol
- name: envoy.filters.listener.tls_inspector
typed_config:
"@type": type.googleapis.com/envoy.extensions.filters.listener.tls_inspector.v3.TlsInspector
filter_chains:
{{- /* TLS filter chains: per-SNI block/allow + default */ -}}
{{- range .TLSChains}}
- filter_chain_match:
transport_protocol: tls
{{- if .ServerNames}}
server_names:
{{- range .ServerNames}}
- {{quote .}}
{{- end}}
{{- end}}
filters:
{{$.NetworkFiltersSnippet}} - name: envoy.filters.network.tcp_proxy
typed_config:
"@type": type.googleapis.com/envoy.extensions.filters.network.tcp_proxy.v3.TcpProxy
stat_prefix: {{.StatPrefix}}
cluster: original_dst
access_log:
- name: envoy.access_loggers.stderr
typed_config:
"@type": type.googleapis.com/envoy.extensions.access_loggers.stream.v3.StderrAccessLog
log_format:
text_format: "[%START_TIME%] tcp %DOWNSTREAM_REMOTE_ADDRESS% -> %UPSTREAM_HOST% %RESPONSE_FLAGS% %DURATION%ms\n"
{{- end}}
{{- /* Plain HTTP filter chain with per-domain virtual hosts */}}
- filters:
- name: envoy.filters.network.http_connection_manager
typed_config:
"@type": type.googleapis.com/envoy.extensions.filters.network.http_connection_manager.v3.HttpConnectionManager
stat_prefix: inspect_http
access_log:
- name: envoy.access_loggers.stderr
typed_config:
"@type": type.googleapis.com/envoy.extensions.access_loggers.stream.v3.StderrAccessLog
log_format:
text_format: "[%START_TIME%] http %DOWNSTREAM_REMOTE_ADDRESS% %REQ(:AUTHORITY)% %REQ(:METHOD)% %REQ(X-ENVOY-ORIGINAL-PATH?:PATH)% %RESPONSE_CODE% %RESPONSE_FLAGS% %DURATION%ms\n"
http_filters:
{{.HTTPFiltersSnippet}} - name: envoy.filters.http.router
typed_config:
"@type": type.googleapis.com/envoy.extensions.filters.http.router.v3.Router
route_config:
virtual_hosts:
{{- range .VirtualHosts}}
- name: {{.Name}}
domains: [{{.DomainsStr}}]
routes:
{{- range .Routes}}
- match:
prefix: "{{if .PathPrefix}}{{.PathPrefix}}{{else}}/{{end}}"
{{- if .Block}}
direct_response:
status: 403
body:
filename: "{{$.BlockPagePath}}"
{{- else}}
route:
cluster: original_dst
{{- end}}
{{- end}}
{{- end}}
clusters:
- name: original_dst
type: ORIGINAL_DST
lb_policy: CLUSTER_PROVIDED
connect_timeout: 10s
{{.ExtraClusters}}`))
// tlsChain represents a TLS filter chain entry for the template.
// All TLS chains are passthrough (block decisions happen in Go before envoy).
type tlsChain struct {
// ServerNames restricts this chain to specific SNIs. Empty is catch-all.
ServerNames []string
StatPrefix string
}
// envoyRoute represents a single route entry within a virtual host.
type envoyRoute struct {
// PathPrefix for envoy prefix match. Empty means catch-all "/".
PathPrefix string
Block bool
}
// virtualHost represents an HTTP virtual host entry for the template.
type virtualHost struct {
Name string
// DomainsStr is pre-formatted for the template: "a", "b".
DomainsStr string
Routes []envoyRoute
}
type bootstrapData struct {
AdminPort uint16
ListenPort uint16
BlockPagePath string
TLSChains []tlsChain
VirtualHosts []virtualHost
HTTPFiltersSnippet string
NetworkFiltersSnippet string
ExtraClusters string
}
// generateBootstrap produces the envoy bootstrap YAML from the inspect config.
// Translates inspection rules into envoy-native per-SNI and per-domain routing.
// blockPagePath is the path to the HTML block page file served by direct_response.
func generateBootstrap(config Config, listenPort, adminPort uint16, blockPagePath string) ([]byte, error) {
data := bootstrapData{
AdminPort: adminPort,
BlockPagePath: blockPagePath,
ListenPort: listenPort,
TLSChains: buildTLSChains(config),
VirtualHosts: buildVirtualHosts(config),
}
if config.Envoy != nil && config.Envoy.Snippets != nil {
s := config.Envoy.Snippets
data.HTTPFiltersSnippet = indentSnippet(s.HTTPFilters, 18)
data.NetworkFiltersSnippet = indentSnippet(s.NetworkFilters, 12)
data.ExtraClusters = indentSnippet(s.Clusters, 4)
}
var buf bytes.Buffer
if err := envoyBootstrapTmpl.Execute(&buf, data); err != nil {
return nil, fmt.Errorf("execute bootstrap template: %w", err)
}
return buf.Bytes(), nil
}
// buildTLSChains translates inspection rules into envoy TLS filter chains.
// Block rules -> per-SNI chain routing to blackhole.
// Allow rules (when default=block) -> per-SNI chain routing to original_dst.
// Default chain follows DefaultAction.
func buildTLSChains(config Config) []tlsChain {
// TLS block decisions happen in Go before forwarding to envoy, so we only
// generate allow/passthrough chains here. Envoy can't cleanly close a TLS
// connection without completing a handshake, so blocked SNIs never reach envoy.
var allowed []string
for _, rule := range config.Rules {
if !ruleTouchesProtocol(rule, ProtoHTTPS, ProtoH2) {
continue
}
for _, d := range rule.Domains {
sni := d.PunycodeString()
if rule.Action == ActionAllow || rule.Action == ActionInspect {
allowed = append(allowed, sni)
}
}
}
var chains []tlsChain
if len(allowed) > 0 && config.DefaultAction == ActionBlock {
chains = append(chains, tlsChain{
ServerNames: allowed,
StatPrefix: "tls_allowed",
})
}
// Default catch-all: passthrough (blocked SNIs never arrive here)
chains = append(chains, tlsChain{
StatPrefix: "tls_default",
})
return chains
}
// buildVirtualHosts translates inspection rules into envoy HTTP virtual hosts.
// Groups rules by domain, generates per-path routes within each virtual host.
func buildVirtualHosts(config Config) []virtualHost {
// Group rules by domain for per-domain virtual hosts.
type domainRules struct {
domains []string
routes []envoyRoute
}
domainRouteMap := make(map[string][]envoyRoute)
for _, rule := range config.Rules {
if !ruleTouchesProtocol(rule, ProtoHTTP, ProtoWebSocket) {
continue
}
isBlock := rule.Action == ActionBlock
// Rules without domains or paths are handled by the default action.
if len(rule.Domains) == 0 && len(rule.Paths) == 0 {
continue
}
// Build routes for this rule's paths
var routes []envoyRoute
if len(rule.Paths) > 0 {
for _, p := range rule.Paths {
// Convert our path patterns to envoy prefix match.
// Strip trailing * for envoy prefix matching.
prefix := strings.TrimSuffix(p, "*")
routes = append(routes, envoyRoute{PathPrefix: prefix, Block: isBlock})
}
} else {
routes = append(routes, envoyRoute{Block: isBlock})
}
if len(rule.Domains) > 0 {
for _, d := range rule.Domains {
host := d.PunycodeString()
domainRouteMap[host] = append(domainRouteMap[host], routes...)
}
} else {
// No domain: applies to all, add to default host
domainRouteMap["*"] = append(domainRouteMap["*"], routes...)
}
}
var hosts []virtualHost
idx := 0
// Per-domain virtual hosts with path routes
for domain, routes := range domainRouteMap {
if domain == "*" {
continue
}
// Add a catch-all route after path-specific routes.
// The catch-all follows the default action.
routes = append(routes, envoyRoute{Block: config.DefaultAction == ActionBlock})
hosts = append(hosts, virtualHost{
Name: fmt.Sprintf("domain_%d", idx),
DomainsStr: fmt.Sprintf("%q", domain),
Routes: routes,
})
idx++
}
// Default virtual host (catch-all for unmatched domains)
defaultRoutes := domainRouteMap["*"]
defaultRoutes = append(defaultRoutes, envoyRoute{Block: config.DefaultAction == ActionBlock})
hosts = append(hosts, virtualHost{
Name: "default",
DomainsStr: `"*"`,
Routes: defaultRoutes,
})
return hosts
}
// ruleTouchesProtocol returns true if the rule's protocol list includes any of the given protocols,
// or if the protocol list is empty (matches all).
func ruleTouchesProtocol(rule Rule, protos ...ProtoType) bool {
if len(rule.Protocols) == 0 {
return true
}
for _, rp := range rule.Protocols {
for _, p := range protos {
if rp == p {
return true
}
}
}
return false
}
// indentSnippet prepends each line of the YAML snippet with the given number of spaces.
// Returns empty string if snippet is empty.
func indentSnippet(snippet string, spaces int) string {
if snippet == "" {
return ""
}
prefix := make([]byte, spaces)
for i := range prefix {
prefix[i] = ' '
}
var buf bytes.Buffer
for i, line := range bytes.Split([]byte(snippet), []byte("\n")) {
if i > 0 {
buf.WriteByte('\n')
}
if len(line) > 0 {
buf.Write(prefix)
buf.Write(line)
}
}
buf.WriteByte('\n')
return buf.String()
}
// ValidateSnippets checks that user-provided snippets are safe to inject
// into the envoy config. Returns an error describing the first violation found.
//
// Validation rules:
// - Each snippet must be valid YAML (prevents syntax-level injection)
// - Snippets must not contain YAML document separators (--- or ...) that could
// break out of the indentation context
// - Snippets must only contain list items (starting with "- ") at the top level,
// matching what envoy expects for filters and clusters
func ValidateSnippets(snippets *EnvoySnippets) error {
if snippets == nil {
return nil
}
fields := []struct {
name string
value string
}{
{"http_filters", snippets.HTTPFilters},
{"network_filters", snippets.NetworkFilters},
{"clusters", snippets.Clusters},
}
for _, f := range fields {
if f.value == "" {
continue
}
if err := validateSnippetYAML(f.name, f.value); err != nil {
return err
}
}
return nil
}
func validateSnippetYAML(name, snippet string) error {
// Check for YAML document markers that could break template structure.
for _, line := range strings.Split(snippet, "\n") {
trimmed := strings.TrimSpace(line)
if trimmed == "---" || trimmed == "..." {
return fmt.Errorf("snippet %q: YAML document separators (--- or ...) are not allowed", name)
}
}
// Verify it's valid YAML by checking it doesn't cause template execution issues.
// We can't import yaml.v3 here without adding a dependency, so we do structural checks.
// Check for null bytes or control characters that could confuse YAML parsers.
for i, b := range []byte(snippet) {
if b == 0 {
return fmt.Errorf("snippet %q: null byte at position %d", name, i)
}
if b < 0x09 || (b > 0x0D && b < 0x20 && b != 0x1B) {
return fmt.Errorf("snippet %q: control character 0x%02x at position %d", name, b, i)
}
}
return nil
}

View File

@@ -1,88 +0,0 @@
package inspect
import (
"context"
"encoding/binary"
"fmt"
"net"
"net/netip"
)
// PROXY protocol v2 constants (RFC 7239 / HAProxy spec)
var proxyV2Signature = [12]byte{
0x0D, 0x0A, 0x0D, 0x0A, 0x00, 0x0D, 0x0A, 0x51,
0x55, 0x49, 0x54, 0x0A,
}
const (
proxyV2VersionCommand = 0x21 // version 2, PROXY command
proxyV2FamilyTCP4 = 0x11 // AF_INET, STREAM
proxyV2FamilyTCP6 = 0x21 // AF_INET6, STREAM
)
// forwardToEnvoy forwards a connection to the given envoy sidecar via PROXY protocol v2.
// The caller provides the envoy manager snapshot to avoid accessing p.envoy without lock.
func (p *Proxy) forwardToEnvoy(ctx context.Context, pconn *peekConn, dst netip.AddrPort, src SourceInfo, em *envoyManager) error {
envoyAddr := em.ListenAddr()
conn, err := (&net.Dialer{}).DialContext(ctx, "tcp", envoyAddr.String())
if err != nil {
return fmt.Errorf("dial envoy at %s: %w", envoyAddr, err)
}
defer func() {
if err := conn.Close(); err != nil {
p.log.Debugf("close envoy conn: %v", err)
}
}()
if err := writeProxyV2Header(conn, src.IP, dst); err != nil {
return fmt.Errorf("write PROXY v2 header: %w", err)
}
p.log.Tracef("envoy: forwarded %s -> %s via PROXY v2", src.IP, dst)
return relay(ctx, pconn, conn)
}
// writeProxyV2Header writes a PROXY protocol v2 header to w.
// The header encodes the original source IP and the destination address:port.
func writeProxyV2Header(w net.Conn, srcIP netip.Addr, dst netip.AddrPort) error {
srcIP = srcIP.Unmap()
dstIP := dst.Addr().Unmap()
var (
family byte
addrs []byte
)
if srcIP.Is4() && dstIP.Is4() {
family = proxyV2FamilyTCP4
s4 := srcIP.As4()
d4 := dstIP.As4()
addrs = make([]byte, 12) // 4+4+2+2
copy(addrs[0:4], s4[:])
copy(addrs[4:8], d4[:])
binary.BigEndian.PutUint16(addrs[8:10], 0) // src port unknown
binary.BigEndian.PutUint16(addrs[10:12], dst.Port())
} else {
family = proxyV2FamilyTCP6
s16 := srcIP.As16()
d16 := dstIP.As16()
addrs = make([]byte, 36) // 16+16+2+2
copy(addrs[0:16], s16[:])
copy(addrs[16:32], d16[:])
binary.BigEndian.PutUint16(addrs[32:34], 0) // src port unknown
binary.BigEndian.PutUint16(addrs[34:36], dst.Port())
}
// Header: signature(12) + ver_cmd(1) + family(1) + len(2) + addrs
header := make([]byte, 16+len(addrs))
copy(header[0:12], proxyV2Signature[:])
header[12] = proxyV2VersionCommand
header[13] = family
binary.BigEndian.PutUint16(header[14:16], uint16(len(addrs)))
copy(header[16:], addrs)
_, err := w.Write(header)
return err
}

View File

@@ -1,13 +0,0 @@
//go:build !windows
package inspect
import (
"os"
"syscall"
)
// signalReload sends SIGHUP to the envoy process to trigger config reload.
func signalReload(p *os.Process) error {
return p.Signal(syscall.SIGHUP)
}

View File

@@ -1,13 +0,0 @@
//go:build windows
package inspect
import (
"fmt"
"os"
)
// signalReload is not supported on Windows. Envoy must be restarted.
func signalReload(_ *os.Process) error {
return fmt.Errorf("envoy config reload via signal not supported on Windows")
}

View File

@@ -1,229 +0,0 @@
package inspect
import (
"bufio"
"context"
"encoding/base64"
"fmt"
"io"
"net"
"net/http"
"net/netip"
"net/url"
"time"
)
const (
externalDialTimeout = 10 * time.Second
)
// handleExternal forwards the connection to an external proxy.
// For TLS connections, it uses HTTP CONNECT to tunnel through the proxy.
// For HTTP connections, it rewrites the request to use the proxy.
func (p *Proxy) handleExternal(ctx context.Context, pconn *peekConn, dst netip.AddrPort) error {
p.mu.RLock()
proxyURL := p.config.ExternalURL
p.mu.RUnlock()
if proxyURL == nil {
return fmt.Errorf("external proxy URL not configured")
}
switch proxyURL.Scheme {
case "http", "https":
return p.externalHTTPProxy(ctx, pconn, dst, proxyURL)
case "socks5":
return p.externalSOCKS5(ctx, pconn, dst, proxyURL)
default:
return fmt.Errorf("unsupported external proxy scheme: %s", proxyURL.Scheme)
}
}
// externalHTTPProxy tunnels through an HTTP proxy using CONNECT.
func (p *Proxy) externalHTTPProxy(ctx context.Context, pconn *peekConn, dst netip.AddrPort, proxyURL *url.URL) error {
proxyAddr := proxyURL.Host
if _, _, err := net.SplitHostPort(proxyAddr); err != nil {
proxyAddr = net.JoinHostPort(proxyAddr, "8080")
}
proxyConn, err := (&net.Dialer{Timeout: externalDialTimeout}).DialContext(ctx, "tcp", proxyAddr)
if err != nil {
return fmt.Errorf("dial external proxy %s: %w", proxyAddr, err)
}
defer func() {
if err := proxyConn.Close(); err != nil {
p.log.Debugf("close external proxy conn: %v", err)
}
}()
connectReq := fmt.Sprintf("CONNECT %s HTTP/1.1\r\nHost: %s\r\n", dst.String(), dst.String())
if proxyURL.User != nil {
connectReq += "Proxy-Authorization: Basic " + basicAuth(proxyURL.User) + "\r\n"
}
connectReq += "\r\n"
if _, err := io.WriteString(proxyConn, connectReq); err != nil {
return fmt.Errorf("send CONNECT to proxy: %w", err)
}
resp, err := http.ReadResponse(bufio.NewReader(proxyConn), nil)
if err != nil {
return fmt.Errorf("read CONNECT response: %w", err)
}
if err := resp.Body.Close(); err != nil {
p.log.Debugf("close CONNECT resp body: %v", err)
}
if resp.StatusCode != http.StatusOK {
return fmt.Errorf("proxy CONNECT failed: %s", resp.Status)
}
return relay(ctx, pconn, proxyConn)
}
// externalSOCKS5 tunnels through a SOCKS5 proxy.
func (p *Proxy) externalSOCKS5(ctx context.Context, pconn *peekConn, dst netip.AddrPort, proxyURL *url.URL) error {
proxyAddr := proxyURL.Host
if _, _, err := net.SplitHostPort(proxyAddr); err != nil {
proxyAddr = net.JoinHostPort(proxyAddr, "1080")
}
proxyConn, err := (&net.Dialer{Timeout: externalDialTimeout}).DialContext(ctx, "tcp", proxyAddr)
if err != nil {
return fmt.Errorf("dial SOCKS5 proxy %s: %w", proxyAddr, err)
}
defer func() {
if err := proxyConn.Close(); err != nil {
p.log.Debugf("close SOCKS5 proxy conn: %v", err)
}
}()
if err := socks5Handshake(proxyConn, dst, proxyURL.User); err != nil {
return fmt.Errorf("SOCKS5 handshake: %w", err)
}
return relay(ctx, pconn, proxyConn)
}
// socks5Handshake performs the SOCKS5 handshake to connect through the proxy.
func socks5Handshake(conn net.Conn, dst netip.AddrPort, userinfo *url.Userinfo) error {
needAuth := userinfo != nil
// Greeting
var methods []byte
if needAuth {
methods = []byte{0x00, 0x02} // no auth, username/password
} else {
methods = []byte{0x00} // no auth
}
greeting := append([]byte{0x05, byte(len(methods))}, methods...)
if _, err := conn.Write(greeting); err != nil {
return fmt.Errorf("send greeting: %w", err)
}
// Server method selection
var methodResp [2]byte
if _, err := io.ReadFull(conn, methodResp[:]); err != nil {
return fmt.Errorf("read method selection: %w", err)
}
if methodResp[0] != 0x05 {
return fmt.Errorf("unexpected SOCKS version: %d", methodResp[0])
}
// Handle authentication if selected
if methodResp[1] == 0x02 {
if err := socks5Auth(conn, userinfo); err != nil {
return err
}
} else if methodResp[1] != 0x00 {
return fmt.Errorf("unsupported SOCKS5 auth method: %d", methodResp[1])
}
// Connection request
addr := dst.Addr()
var addrBytes []byte
if addr.Is4() {
a4 := addr.As4()
addrBytes = append([]byte{0x01}, a4[:]...) // IPv4
} else {
a16 := addr.As16()
addrBytes = append([]byte{0x04}, a16[:]...) // IPv6
}
port := dst.Port()
connectReq := append([]byte{0x05, 0x01, 0x00}, addrBytes...)
connectReq = append(connectReq, byte(port>>8), byte(port))
if _, err := conn.Write(connectReq); err != nil {
return fmt.Errorf("send connect request: %w", err)
}
// Read response (minimum 10 bytes for IPv4)
var respHeader [4]byte
if _, err := io.ReadFull(conn, respHeader[:]); err != nil {
return fmt.Errorf("read connect response: %w", err)
}
if respHeader[1] != 0x00 {
return fmt.Errorf("SOCKS5 connect failed: status %d", respHeader[1])
}
// Skip bound address
switch respHeader[3] {
case 0x01: // IPv4
var skip [4 + 2]byte
if _, err := io.ReadFull(conn, skip[:]); err != nil {
return fmt.Errorf("read SOCKS5 bound IPv4 address: %w", err)
}
case 0x04: // IPv6
var skip [16 + 2]byte
if _, err := io.ReadFull(conn, skip[:]); err != nil {
return fmt.Errorf("read SOCKS5 bound IPv6 address: %w", err)
}
case 0x03: // Domain
var dLen [1]byte
if _, err := io.ReadFull(conn, dLen[:]); err != nil {
return fmt.Errorf("read domain length: %w", err)
}
skip := make([]byte, int(dLen[0])+2)
if _, err := io.ReadFull(conn, skip); err != nil {
return fmt.Errorf("read SOCKS5 bound domain address: %w", err)
}
}
return nil
}
func socks5Auth(conn net.Conn, userinfo *url.Userinfo) error {
if userinfo == nil {
return fmt.Errorf("SOCKS5 auth required but no credentials provided")
}
user := userinfo.Username()
pass, _ := userinfo.Password()
// Username/password auth (RFC 1929)
auth := []byte{0x01, byte(len(user))}
auth = append(auth, []byte(user)...)
auth = append(auth, byte(len(pass)))
auth = append(auth, []byte(pass)...)
if _, err := conn.Write(auth); err != nil {
return fmt.Errorf("send auth: %w", err)
}
var resp [2]byte
if _, err := io.ReadFull(conn, resp[:]); err != nil {
return fmt.Errorf("read auth response: %w", err)
}
if resp[1] != 0x00 {
return fmt.Errorf("SOCKS5 auth failed: status %d", resp[1])
}
return nil
}
func basicAuth(userinfo *url.Userinfo) string {
user := userinfo.Username()
pass, _ := userinfo.Password()
return base64.StdEncoding.EncodeToString([]byte(user + ":" + pass))
}

View File

@@ -1,532 +0,0 @@
package inspect
import (
"bufio"
"context"
"fmt"
"io"
"net"
"net/http"
"net/netip"
"strings"
"sync"
"time"
"github.com/netbirdio/netbird/shared/management/domain"
)
const (
headerUpgrade = "Upgrade"
valueWebSocket = "websocket"
)
// inspectHTTP runs the HTTP inspection pipeline on decrypted traffic.
// It handles HTTP/1.1 (request-response loop), HTTP/2 (via Go stdlib reverse proxy),
// and WebSocket upgrade detection.
func (p *Proxy) inspectHTTP(ctx context.Context, client, remote net.Conn, dst netip.AddrPort, sni domain.Domain, src SourceInfo, proto string) error {
if proto == "h2" {
return p.inspectH2(ctx, client, remote, dst, sni, src)
}
return p.inspectH1(ctx, client, remote, dst, sni, src)
}
// inspectH1 handles HTTP/1.1 request-response inspection in a loop.
func (p *Proxy) inspectH1(ctx context.Context, client, remote net.Conn, dst netip.AddrPort, sni domain.Domain, src SourceInfo) error {
clientReader := bufio.NewReader(client)
remoteReader := bufio.NewReader(remote)
for {
if ctx.Err() != nil {
return ctx.Err()
}
// Set idle timeout between requests to prevent connection hogging.
if err := client.SetReadDeadline(time.Now().Add(idleTimeout)); err != nil {
return fmt.Errorf("set idle deadline: %w", err)
}
req, err := http.ReadRequest(clientReader)
if err != nil {
if isClosedErr(err) {
return nil
}
return fmt.Errorf("read HTTP request: %w", err)
}
if err := client.SetReadDeadline(time.Time{}); err != nil {
return fmt.Errorf("clear read deadline: %w", err)
}
// Re-evaluate rules based on Host header if SNI was empty
host := hostFromRequest(req, sni)
// Domain fronting: Host header doesn't match TLS SNI
if isDomainFronting(req, sni) {
p.log.Debugf("domain fronting detected: SNI=%s Host=%s", sni.PunycodeString(), host.PunycodeString())
writeBlockResponse(client, req, host)
return ErrBlocked
}
proto := ProtoHTTP
if isWebSocketUpgrade(req) {
proto = ProtoWebSocket
}
action := p.evaluateAction(src.IP, host, dst, proto, req.URL.Path)
if action == ActionBlock {
p.log.Debugf("block: HTTP %s %s (host=%s)", req.Method, req.URL.Path, host.PunycodeString())
writeBlockResponse(client, req, host)
return ErrBlocked
}
p.log.Tracef("allow: HTTP %s %s (host=%s, action=%s)", req.Method, req.URL.Path, host.PunycodeString(), action)
// ICAP REQMOD: send request for inspection.
// Snapshot ICAP client under lock to avoid use-after-close races.
p.mu.RLock()
icap := p.icap
p.mu.RUnlock()
if icap != nil {
modified, err := icap.ReqMod(req)
if err != nil {
p.log.Debugf("ICAP REQMOD error for %s: %v", host.PunycodeString(), err)
// Fail-closed: block on ICAP error
writeBlockResponse(client, req, host)
return fmt.Errorf("ICAP REQMOD: %w", err)
}
req = modified
}
if isWebSocketUpgrade(req) {
return p.handleWebSocket(ctx, req, client, clientReader, remote, remoteReader)
}
removeHopByHopHeaders(req.Header)
if err := req.Write(remote); err != nil {
return fmt.Errorf("forward request: %w", err)
}
resp, err := http.ReadResponse(remoteReader, req)
if err != nil {
return fmt.Errorf("read HTTP response: %w", err)
}
// ICAP RESPMOD: send response for inspection
if icap != nil {
modified, err := icap.RespMod(req, resp)
if err != nil {
p.log.Debugf("ICAP RESPMOD error for %s: %v", host.PunycodeString(), err)
if err := resp.Body.Close(); err != nil {
p.log.Debugf("close resp body: %v", err)
}
writeBlockResponse(client, req, host)
return fmt.Errorf("ICAP RESPMOD: %w", err)
}
resp = modified
}
removeHopByHopHeaders(resp.Header)
if err := resp.Write(client); err != nil {
if closeErr := resp.Body.Close(); closeErr != nil {
p.log.Debugf("close resp body: %v", closeErr)
}
return fmt.Errorf("forward response: %w", err)
}
if err := resp.Body.Close(); err != nil {
p.log.Debugf("close resp body: %v", err)
}
// Connection: close means we're done
if resp.Close || req.Close {
return nil
}
}
}
// inspectH2 proxies HTTP/2 traffic using Go's http stack.
// Client and remote are already-established TLS connections with h2 negotiated.
func (p *Proxy) inspectH2(ctx context.Context, client, remote net.Conn, dst netip.AddrPort, sni domain.Domain, src SourceInfo) error {
// For h2 MITM inspection, we use a local http.Server reading from the client
// connection and an http.Transport writing to the remote connection.
//
// The transport is configured to use the existing TLS connection to the
// real server. The handler inspects each request/response pair.
transport := &http.Transport{
DialContext: func(_ context.Context, _, _ string) (net.Conn, error) {
return remote, nil
},
DialTLSContext: func(_ context.Context, _, _ string) (net.Conn, error) {
return remote, nil
},
ForceAttemptHTTP2: true,
}
handler := &h2InspectionHandler{
proxy: p,
transport: transport,
dst: dst,
sni: sni,
src: src,
}
server := &http.Server{
Handler: handler,
}
// Serve the single client connection.
// ServeConn blocks until the connection is done.
errCh := make(chan error, 1)
go func() {
// http.Server doesn't have a direct ServeConn for h2,
// so we use Serve with a single-connection listener.
ln := &singleConnListener{conn: client}
errCh <- server.Serve(ln)
}()
select {
case <-ctx.Done():
if err := server.Close(); err != nil {
p.log.Debugf("close h2 server: %v", err)
}
return ctx.Err()
case err := <-errCh:
if err == http.ErrServerClosed {
return nil
}
return err
}
}
// h2InspectionHandler inspects each HTTP/2 request/response pair.
type h2InspectionHandler struct {
proxy *Proxy
transport http.RoundTripper
dst netip.AddrPort
sni domain.Domain
src SourceInfo
}
func (h *h2InspectionHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
host := hostFromRequest(req, h.sni)
if isDomainFronting(req, h.sni) {
h.proxy.log.Debugf("domain fronting detected: SNI=%s Host=%s", h.sni.PunycodeString(), host.PunycodeString())
writeBlockPage(w, host)
return
}
action := h.proxy.evaluateAction(h.src.IP, host, h.dst, ProtoH2, req.URL.Path)
if action == ActionBlock {
h.proxy.log.Debugf("block: H2 %s %s (host=%s)", req.Method, req.URL.Path, host.PunycodeString())
writeBlockPage(w, host)
return
}
// ICAP REQMOD
if h.proxy.icap != nil {
modified, err := h.proxy.icap.ReqMod(req)
if err != nil {
h.proxy.log.Debugf("ICAP REQMOD error for %s: %v", host.PunycodeString(), err)
writeBlockPage(w, host)
return
}
req = modified
}
// Forward to upstream
req.URL.Scheme = "https"
req.URL.Host = h.sni.PunycodeString()
req.RequestURI = ""
resp, err := h.transport.RoundTrip(req)
if err != nil {
h.proxy.log.Debugf("h2 upstream error for %s: %v", host.PunycodeString(), err)
http.Error(w, "Bad Gateway", http.StatusBadGateway)
return
}
defer func() {
if err := resp.Body.Close(); err != nil {
h.proxy.log.Debugf("close h2 resp body: %v", err)
}
}()
// ICAP RESPMOD
if h.proxy.icap != nil {
modified, err := h.proxy.icap.RespMod(req, resp)
if err != nil {
h.proxy.log.Debugf("ICAP RESPMOD error for %s: %v", host.PunycodeString(), err)
writeBlockPage(w, host)
return
}
resp = modified
}
// Copy response headers and body
for k, vals := range resp.Header {
for _, v := range vals {
w.Header().Add(k, v)
}
}
w.WriteHeader(resp.StatusCode)
if _, err := io.Copy(w, resp.Body); err != nil {
h.proxy.log.Debugf("h2 response copy error: %v", err)
}
}
// handleWebSocket completes the WebSocket upgrade and relays frames bidirectionally.
func (p *Proxy) handleWebSocket(ctx context.Context, req *http.Request, client io.ReadWriter, clientReader *bufio.Reader, remote io.ReadWriter, remoteReader *bufio.Reader) error {
if err := req.Write(remote); err != nil {
return fmt.Errorf("forward WebSocket upgrade: %w", err)
}
resp, err := http.ReadResponse(remoteReader, req)
if err != nil {
return fmt.Errorf("read WebSocket upgrade response: %w", err)
}
if err := resp.Write(client); err != nil {
if closeErr := resp.Body.Close(); closeErr != nil {
p.log.Debugf("close ws resp body: %v", closeErr)
}
return fmt.Errorf("forward WebSocket upgrade response: %w", err)
}
if err := resp.Body.Close(); err != nil {
p.log.Debugf("close ws resp body: %v", err)
}
if resp.StatusCode != http.StatusSwitchingProtocols {
return fmt.Errorf("WebSocket upgrade rejected: status %d", resp.StatusCode)
}
p.log.Tracef("allow: WebSocket upgrade for %s", req.Host)
// Relay WebSocket frames bidirectionally.
// clientReader/remoteReader may have buffered data.
clientConn := mergeReadWriter(clientReader, client)
remoteConn := mergeReadWriter(remoteReader, remote)
return relayRW(ctx, clientConn, remoteConn)
}
// hostFromRequest extracts a domain.Domain from the HTTP request Host header,
// falling back to the SNI if Host is empty or an IP.
func hostFromRequest(req *http.Request, fallback domain.Domain) domain.Domain {
host := req.Host
if host == "" {
return fallback
}
// Strip port if present
if h, _, err := net.SplitHostPort(host); err == nil {
host = h
}
// If it's an IP address, use the SNI fallback
if _, err := netip.ParseAddr(host); err == nil {
return fallback
}
d, err := domain.FromString(host)
if err != nil {
return fallback
}
return d
}
// isDomainFronting detects domain fronting: the Host header doesn't match the
// SNI used during the TLS handshake. Only meaningful when SNI is non-empty
// (i.e., we're in MITM mode and know the original SNI).
func isDomainFronting(req *http.Request, sni domain.Domain) bool {
if sni == "" {
return false
}
host := hostFromRequest(req, "")
if host == "" {
return false
}
// Host should match SNI or be a subdomain of SNI
if host == sni {
return false
}
// Allow www.example.com when SNI is example.com
sniStr := sni.PunycodeString()
hostStr := host.PunycodeString()
if strings.HasSuffix(hostStr, "."+sniStr) {
return false
}
return true
}
func isWebSocketUpgrade(req *http.Request) bool {
return strings.EqualFold(req.Header.Get(headerUpgrade), valueWebSocket)
}
// writeBlockPage writes the styled HTML block page to an http.ResponseWriter (H2 path).
func writeBlockPage(w http.ResponseWriter, host domain.Domain) {
hostname := host.PunycodeString()
body := fmt.Sprintf(blockPageHTML, hostname, hostname)
w.Header().Set("Content-Type", "text/html; charset=utf-8")
w.Header().Set("Cache-Control", "no-store")
w.WriteHeader(http.StatusForbidden)
io.WriteString(w, body)
}
func writeBlockResponse(w io.Writer, _ *http.Request, host domain.Domain) {
hostname := host.PunycodeString()
body := fmt.Sprintf(blockPageHTML, hostname, hostname)
resp := &http.Response{
StatusCode: http.StatusForbidden,
ProtoMajor: 1,
ProtoMinor: 1,
Header: make(http.Header),
ContentLength: int64(len(body)),
Body: io.NopCloser(strings.NewReader(body)),
}
resp.Header.Set("Content-Type", "text/html; charset=utf-8")
resp.Header.Set("Connection", "close")
resp.Header.Set("Cache-Control", "no-store")
_ = resp.Write(w)
}
// blockPageHTML is the self-contained HTML block page.
// Uses NetBird dark theme with orange accent. Two format args: page title domain, displayed domain.
const blockPageHTML = `<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="utf-8">
<meta name="viewport" content="width=device-width,initial-scale=1">
<title>Blocked - %s</title>
<style>
*{margin:0;padding:0;box-sizing:border-box}
body{background:#181a1d;color:#d1d5db;font-family:-apple-system,BlinkMacSystemFont,"Segoe UI",Roboto,sans-serif;min-height:100vh;display:flex;align-items:center;justify-content:center}
.c{text-align:center;max-width:460px;padding:2rem}
.shield{width:56px;height:56px;margin:0 auto 1.5rem;border-radius:16px;background:#2b2f33;display:flex;align-items:center;justify-content:center}
.shield svg{width:28px;height:28px;color:#f68330}
.code{font-size:.8rem;font-weight:500;color:#f68330;font-family:ui-monospace,monospace;letter-spacing:.05em;margin-bottom:.5rem}
h1{font-size:1.5rem;font-weight:600;color:#f4f4f5;margin-bottom:.5rem}
p{font-size:.95rem;line-height:1.5;color:#9ca3af;margin-bottom:1.75rem}
.domain{display:inline-block;background:#25282d;border:1px solid #32363d;border-radius:6px;padding:.15rem .5rem;font-family:ui-monospace,monospace;font-size:.85rem;color:#d1d5db}
.footer{font-size:.7rem;color:#6b7280;margin-top:2rem;letter-spacing:.03em}
.footer a{color:#6b7280;text-decoration:none}
.footer a:hover{color:#9ca3af}
</style>
</head>
<body>
<div class="c">
<div class="shield"><svg xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24" stroke-width="1.5" stroke="currentColor"><path stroke-linecap="round" stroke-linejoin="round" d="M12 9v3.75m0-10.036A11.959 11.959 0 0 1 3.598 6 11.99 11.99 0 0 0 3 9.75c0 5.592 3.824 10.29 9 11.622 5.176-1.332 9-6.03 9-11.622 0-1.31-.21-2.571-.598-3.751A11.96 11.96 0 0 0 12 3.714Z"/></svg></div>
<div class="code">403 BLOCKED</div>
<h1>Access Denied</h1>
<p>This connection to <span class="domain">%s</span> has been blocked by your organization's network policy.</p>
<div class="footer">Protected by <a href="https://netbird.io" target="_blank" rel="noopener">NetBird</a></div>
</div>
</body>
</html>`
// singleConnListener is a net.Listener that yields a single connection.
type singleConnListener struct {
conn net.Conn
once sync.Once
ch chan struct{}
}
func (l *singleConnListener) Accept() (net.Conn, error) {
var accepted bool
l.once.Do(func() {
l.ch = make(chan struct{})
accepted = true
})
if accepted {
return l.conn, nil
}
// Block until Close
<-l.ch
return nil, net.ErrClosed
}
func (l *singleConnListener) Close() error {
l.once.Do(func() {
l.ch = make(chan struct{})
})
select {
case <-l.ch:
default:
close(l.ch)
}
return nil
}
func (l *singleConnListener) Addr() net.Addr {
return l.conn.LocalAddr()
}
type readWriter struct {
io.Reader
io.Writer
}
func mergeReadWriter(r io.Reader, w io.Writer) io.ReadWriter {
return &readWriter{Reader: r, Writer: w}
}
// relayRW copies data bidirectionally between two ReadWriters.
func relayRW(ctx context.Context, a, b io.ReadWriter) error {
ctx, cancel := context.WithCancel(ctx)
defer cancel()
errCh := make(chan error, 2)
go func() {
_, err := io.Copy(b, a)
cancel()
errCh <- err
}()
go func() {
_, err := io.Copy(a, b)
cancel()
errCh <- err
}()
var firstErr error
for range 2 {
if err := <-errCh; err != nil && firstErr == nil {
if !isClosedErr(err) {
firstErr = err
}
}
}
return firstErr
}
// hopByHopHeaders are HTTP/1.1 headers that apply to a single connection
// and must not be forwarded by a proxy (RFC 7230, Section 6.1).
var hopByHopHeaders = []string{
"Connection",
"Keep-Alive",
"Proxy-Authenticate",
"Proxy-Authorization",
"TE",
"Trailers",
"Transfer-Encoding",
"Upgrade",
}
// removeHopByHopHeaders strips hop-by-hop headers from h.
// Also removes headers listed in the Connection header value.
func removeHopByHopHeaders(h http.Header) {
// First, remove any headers named in the Connection header
for _, connHeader := range h["Connection"] {
for _, name := range strings.Split(connHeader, ",") {
h.Del(strings.TrimSpace(name))
}
}
for _, name := range hopByHopHeaders {
h.Del(name)
}
}

View File

@@ -1,479 +0,0 @@
package inspect
import (
"bufio"
"bytes"
"fmt"
"io"
"net"
"net/http"
"net/textproto"
"net/url"
"strconv"
"strings"
"sync"
"time"
log "github.com/sirupsen/logrus"
)
const (
icapVersion = "ICAP/1.0"
icapDefaultPort = "1344"
icapConnTimeout = 30 * time.Second
icapRWTimeout = 60 * time.Second
icapMaxPoolSize = 8
icapIdleTimeout = 60 * time.Second
icapMaxRespSize = 4 * 1024 * 1024 // 4 MB
)
// ICAPClient implements an ICAP (RFC 3507) client with persistent connection pooling.
type ICAPClient struct {
reqModURL *url.URL
respModURL *url.URL
pool chan *icapConn
mu sync.Mutex
log *log.Entry
maxPool int
}
type icapConn struct {
conn net.Conn
reader *bufio.Reader
lastUse time.Time
}
// NewICAPClient creates an ICAP client. Either or both URLs may be nil
// to disable that mode.
func NewICAPClient(logger *log.Entry, cfg *ICAPConfig) *ICAPClient {
maxPool := cfg.MaxConnections
if maxPool <= 0 {
maxPool = icapMaxPoolSize
}
return &ICAPClient{
reqModURL: cfg.ReqModURL,
respModURL: cfg.RespModURL,
pool: make(chan *icapConn, maxPool),
log: logger,
maxPool: maxPool,
}
}
// ReqMod sends an HTTP request to the ICAP REQMOD service for inspection.
// Returns the (possibly modified) request, or the original if ICAP returns 204.
// Returns nil, nil if REQMOD is not configured.
func (c *ICAPClient) ReqMod(req *http.Request) (*http.Request, error) {
if c.reqModURL == nil {
return req, nil
}
var reqBuf bytes.Buffer
if err := req.Write(&reqBuf); err != nil {
return nil, fmt.Errorf("serialize request: %w", err)
}
respBody, err := c.send("REQMOD", c.reqModURL, reqBuf.Bytes(), nil)
if err != nil {
return nil, err
}
if respBody == nil {
return req, nil
}
modified, err := http.ReadRequest(bufio.NewReader(bytes.NewReader(respBody)))
if err != nil {
return nil, fmt.Errorf("parse ICAP modified request: %w", err)
}
return modified, nil
}
// RespMod sends an HTTP response to the ICAP RESPMOD service for inspection.
// Returns the (possibly modified) response, or the original if ICAP returns 204.
// Returns nil, nil if RESPMOD is not configured.
func (c *ICAPClient) RespMod(req *http.Request, resp *http.Response) (*http.Response, error) {
if c.respModURL == nil {
return resp, nil
}
var reqBuf bytes.Buffer
if err := req.Write(&reqBuf); err != nil {
return nil, fmt.Errorf("serialize request: %w", err)
}
var respBuf bytes.Buffer
if err := resp.Write(&respBuf); err != nil {
return nil, fmt.Errorf("serialize response: %w", err)
}
respBody, err := c.send("RESPMOD", c.respModURL, reqBuf.Bytes(), respBuf.Bytes())
if err != nil {
return nil, err
}
if respBody == nil {
// 204 No Content: ICAP server didn't modify the response.
// Reconstruct from the buffered copy since resp.Body was consumed by Write.
reconstructed, err := http.ReadResponse(bufio.NewReader(bytes.NewReader(respBuf.Bytes())), req)
if err != nil {
return nil, fmt.Errorf("reconstruct response after ICAP 204: %w", err)
}
return reconstructed, nil
}
modified, err := http.ReadResponse(bufio.NewReader(bytes.NewReader(respBody)), req)
if err != nil {
return nil, fmt.Errorf("parse ICAP modified response: %w", err)
}
return modified, nil
}
// Close drains and closes all pooled connections.
func (c *ICAPClient) Close() {
close(c.pool)
for ic := range c.pool {
if err := ic.conn.Close(); err != nil {
c.log.Debugf("close ICAP connection: %v", err)
}
}
}
// send executes an ICAP request and returns the encapsulated body from the response.
// Returns nil body for 204 No Content (no modification).
// Retries once on stale pooled connection (EOF on read).
func (c *ICAPClient) send(method string, serviceURL *url.URL, reqData, respData []byte) ([]byte, error) {
statusCode, headers, body, err := c.trySend(method, serviceURL, reqData, respData)
if err != nil && isStaleConnErr(err) {
// Retry once with a fresh connection (stale pool entry).
c.log.Debugf("ICAP %s: retrying after stale connection: %v", method, err)
statusCode, headers, body, err = c.trySend(method, serviceURL, reqData, respData)
}
if err != nil {
return nil, err
}
switch statusCode {
case 204:
return nil, nil
case 200:
return body, nil
default:
c.log.Debugf("ICAP %s returned status %d, headers: %v", method, statusCode, headers)
return nil, fmt.Errorf("ICAP %s: status %d", method, statusCode)
}
}
func (c *ICAPClient) trySend(method string, serviceURL *url.URL, reqData, respData []byte) (int, textproto.MIMEHeader, []byte, error) {
ic, err := c.getConn(serviceURL)
if err != nil {
return 0, nil, nil, fmt.Errorf("get ICAP connection: %w", err)
}
if err := c.writeRequest(ic, method, serviceURL, reqData, respData); err != nil {
if closeErr := ic.conn.Close(); closeErr != nil {
c.log.Debugf("close ICAP conn after write error: %v", closeErr)
}
return 0, nil, nil, fmt.Errorf("write ICAP %s: %w", method, err)
}
statusCode, headers, body, err := c.readResponse(ic)
if err != nil {
if closeErr := ic.conn.Close(); closeErr != nil {
c.log.Debugf("close ICAP conn after read error: %v", closeErr)
}
return 0, nil, nil, fmt.Errorf("read ICAP response: %w", err)
}
c.putConn(ic)
return statusCode, headers, body, nil
}
func isStaleConnErr(err error) bool {
if err == nil {
return false
}
s := err.Error()
return strings.Contains(s, "EOF") || strings.Contains(s, "broken pipe") || strings.Contains(s, "connection reset")
}
func (c *ICAPClient) writeRequest(ic *icapConn, method string, serviceURL *url.URL, reqData, respData []byte) error {
if err := ic.conn.SetWriteDeadline(time.Now().Add(icapRWTimeout)); err != nil {
return fmt.Errorf("set write deadline: %w", err)
}
// For RESPMOD, split the serialized HTTP response into headers and body.
// The body must be sent chunked per RFC 3507.
var respHdr, respBody []byte
if respData != nil {
if idx := bytes.Index(respData, []byte("\r\n\r\n")); idx >= 0 {
respHdr = respData[:idx+4] // include the \r\n\r\n separator
respBody = respData[idx+4:]
} else {
respHdr = respData
}
}
var buf bytes.Buffer
// Request line
fmt.Fprintf(&buf, "%s %s %s\r\n", method, serviceURL.String(), icapVersion)
// Headers
host := serviceURL.Host
fmt.Fprintf(&buf, "Host: %s\r\n", host)
fmt.Fprintf(&buf, "Connection: keep-alive\r\n")
fmt.Fprintf(&buf, "Allow: 204\r\n")
// Build Encapsulated header
offset := 0
var encapParts []string
if reqData != nil {
encapParts = append(encapParts, fmt.Sprintf("req-hdr=%d", offset))
offset += len(reqData)
}
if respHdr != nil {
encapParts = append(encapParts, fmt.Sprintf("res-hdr=%d", offset))
offset += len(respHdr)
}
if len(respBody) > 0 {
encapParts = append(encapParts, fmt.Sprintf("res-body=%d", offset))
} else {
encapParts = append(encapParts, fmt.Sprintf("null-body=%d", offset))
}
fmt.Fprintf(&buf, "Encapsulated: %s\r\n", strings.Join(encapParts, ", "))
fmt.Fprintf(&buf, "\r\n")
// Encapsulated sections
if reqData != nil {
buf.Write(reqData)
}
if respHdr != nil {
buf.Write(respHdr)
}
// Body in chunked encoding (only when there is an actual body section).
// Per RFC 3507 Section 4.4.1, null-body must not include any entity data.
if len(respBody) > 0 {
fmt.Fprintf(&buf, "%x\r\n", len(respBody))
buf.Write(respBody)
buf.WriteString("\r\n")
buf.WriteString("0\r\n\r\n")
}
_, err := ic.conn.Write(buf.Bytes())
return err
}
func (c *ICAPClient) readResponse(ic *icapConn) (int, textproto.MIMEHeader, []byte, error) {
if err := ic.conn.SetReadDeadline(time.Now().Add(icapRWTimeout)); err != nil {
return 0, nil, nil, fmt.Errorf("set read deadline: %w", err)
}
tp := textproto.NewReader(ic.reader)
// Status line: "ICAP/1.0 200 OK"
statusLine, err := tp.ReadLine()
if err != nil {
return 0, nil, nil, fmt.Errorf("read status line: %w", err)
}
statusCode, err := parseICAPStatus(statusLine)
if err != nil {
return 0, nil, nil, err
}
// Headers
headers, err := tp.ReadMIMEHeader()
if err != nil {
return statusCode, nil, nil, fmt.Errorf("read ICAP headers: %w", err)
}
if statusCode == 204 {
return statusCode, headers, nil, nil
}
// Read encapsulated body based on Encapsulated header
body, err := c.readEncapsulatedBody(ic.reader, headers)
if err != nil {
return statusCode, headers, nil, fmt.Errorf("read encapsulated body: %w", err)
}
return statusCode, headers, body, nil
}
func (c *ICAPClient) readEncapsulatedBody(r *bufio.Reader, headers textproto.MIMEHeader) ([]byte, error) {
encap := headers.Get("Encapsulated")
if encap == "" {
return nil, nil
}
// Find the body offset from the Encapsulated header.
// The last section with a non-zero offset is the body.
// Read everything from the reader as the encapsulated content.
var totalSize int
parts := strings.Split(encap, ",")
for _, part := range parts {
part = strings.TrimSpace(part)
eqIdx := strings.Index(part, "=")
if eqIdx < 0 {
continue
}
offset, err := strconv.Atoi(strings.TrimSpace(part[eqIdx+1:]))
if err != nil {
continue
}
if offset > totalSize {
totalSize = offset
}
}
// Read all available encapsulated data (headers + body)
// The body section uses chunked encoding per RFC 3507
var buf bytes.Buffer
if totalSize > 0 {
// Read the header sections (everything before the body offset)
headerBytes := make([]byte, totalSize)
if _, err := io.ReadFull(r, headerBytes); err != nil {
return nil, fmt.Errorf("read encapsulated headers: %w", err)
}
buf.Write(headerBytes)
}
// Read chunked body
chunked := newChunkedReader(r)
body, err := io.ReadAll(io.LimitReader(chunked, icapMaxRespSize))
if err != nil {
return nil, fmt.Errorf("read chunked body: %w", err)
}
buf.Write(body)
return buf.Bytes(), nil
}
func (c *ICAPClient) getConn(serviceURL *url.URL) (*icapConn, error) {
// Try to get a pooled connection
for {
select {
case ic := <-c.pool:
if time.Since(ic.lastUse) > icapIdleTimeout {
if err := ic.conn.Close(); err != nil {
c.log.Debugf("close idle ICAP connection: %v", err)
}
continue
}
return ic, nil
default:
return c.dialConn(serviceURL)
}
}
}
func (c *ICAPClient) putConn(ic *icapConn) {
c.mu.Lock()
defer c.mu.Unlock()
ic.lastUse = time.Now()
select {
case c.pool <- ic:
default:
// Pool full, close connection.
if err := ic.conn.Close(); err != nil {
c.log.Debugf("close excess ICAP connection: %v", err)
}
}
}
func (c *ICAPClient) dialConn(serviceURL *url.URL) (*icapConn, error) {
host := serviceURL.Host
if _, _, err := net.SplitHostPort(host); err != nil {
host = net.JoinHostPort(host, icapDefaultPort)
}
conn, err := net.DialTimeout("tcp", host, icapConnTimeout)
if err != nil {
return nil, fmt.Errorf("dial ICAP %s: %w", host, err)
}
return &icapConn{
conn: conn,
reader: bufio.NewReader(conn),
lastUse: time.Now(),
}, nil
}
func parseICAPStatus(line string) (int, error) {
// "ICAP/1.0 200 OK"
parts := strings.SplitN(line, " ", 3)
if len(parts) < 2 {
return 0, fmt.Errorf("malformed ICAP status line: %q", line)
}
code, err := strconv.Atoi(parts[1])
if err != nil {
return 0, fmt.Errorf("parse ICAP status code %q: %w", parts[1], err)
}
return code, nil
}
// chunkedReader reads ICAP chunked encoding (same as HTTP chunked, terminated by "0\r\n\r\n").
type chunkedReader struct {
r *bufio.Reader
remaining int
done bool
}
func newChunkedReader(r *bufio.Reader) *chunkedReader {
return &chunkedReader{r: r}
}
func (cr *chunkedReader) Read(p []byte) (int, error) {
if cr.done {
return 0, io.EOF
}
if cr.remaining == 0 {
// Read chunk size line
line, err := cr.r.ReadString('\n')
if err != nil {
return 0, err
}
line = strings.TrimSpace(line)
// Strip any chunk extensions
if idx := strings.Index(line, ";"); idx >= 0 {
line = line[:idx]
}
size, err := strconv.ParseInt(line, 16, 64)
if err != nil {
return 0, fmt.Errorf("parse chunk size %q: %w", line, err)
}
if size == 0 {
cr.done = true
// Consume trailing \r\n
_, _ = cr.r.ReadString('\n')
return 0, io.EOF
}
if size < 0 || size > icapMaxRespSize {
return 0, fmt.Errorf("chunk size %d out of range (max %d)", size, icapMaxRespSize)
}
cr.remaining = int(size)
}
toRead := len(p)
if toRead > cr.remaining {
toRead = cr.remaining
}
n, err := cr.r.Read(p[:toRead])
cr.remaining -= n
if cr.remaining == 0 {
// Consume chunk-terminating \r\n
_, _ = cr.r.ReadString('\n')
}
return n, err
}

View File

@@ -1,21 +0,0 @@
//go:build !linux
package inspect
import (
"fmt"
"net"
"net/netip"
log "github.com/sirupsen/logrus"
)
// newTPROXYListener is not supported on non-Linux platforms.
func newTPROXYListener(_ *log.Entry, addr netip.AddrPort, _ netip.Prefix) (net.Listener, error) {
return nil, fmt.Errorf("TPROXY listener not supported on this platform (requested %s)", addr)
}
// getOriginalDst is not supported on non-Linux platforms.
func getOriginalDst(_ net.Conn) (netip.AddrPort, error) {
return netip.AddrPort{}, fmt.Errorf("SO_ORIGINAL_DST not supported on this platform")
}

View File

@@ -1,89 +0,0 @@
package inspect
import (
"fmt"
"net"
"net/netip"
"unsafe"
log "github.com/sirupsen/logrus"
"golang.org/x/sys/unix"
)
// newTPROXYListener creates a TCP listener for the transparent proxy.
// After nftables REDIRECT, accepted connections have LocalAddr = WG_IP:proxy_port.
// The original destination is retrieved via getsockopt(SO_ORIGINAL_DST).
func newTPROXYListener(logger *log.Entry, addr netip.AddrPort, _ netip.Prefix) (net.Listener, error) {
ln, err := net.Listen("tcp", addr.String())
if err != nil {
return nil, fmt.Errorf("listen on %s: %w", addr, err)
}
logger.Infof("inspect: listener started on %s", ln.Addr())
return ln, nil
}
// getOriginalDst reads the original destination from conntrack via SO_ORIGINAL_DST.
// This is set by the kernel when the connection was REDIRECT'd/DNAT'd.
// Tries IPv4 first, then falls back to IPv6 (IP6T_SO_ORIGINAL_DST).
func getOriginalDst(conn net.Conn) (netip.AddrPort, error) {
tc, ok := conn.(*net.TCPConn)
if !ok {
return netip.AddrPort{}, fmt.Errorf("not a TCPConn")
}
raw, err := tc.SyscallConn()
if err != nil {
return netip.AddrPort{}, fmt.Errorf("get syscall conn: %w", err)
}
var origDst netip.AddrPort
var sockErr error
if err := raw.Control(func(fd uintptr) {
// Try IPv4 first (SO_ORIGINAL_DST = 80)
var sa4 unix.RawSockaddrInet4
sa4Len := uint32(unsafe.Sizeof(sa4))
_, _, errno := unix.Syscall6(
unix.SYS_GETSOCKOPT,
fd,
unix.SOL_IP,
80, // SO_ORIGINAL_DST
uintptr(unsafe.Pointer(&sa4)),
uintptr(unsafe.Pointer(&sa4Len)),
0,
)
if errno == 0 {
addr := netip.AddrFrom4(sa4.Addr)
port := uint16(sa4.Port>>8) | uint16(sa4.Port<<8)
origDst = netip.AddrPortFrom(addr.Unmap(), port)
return
}
// Fall back to IPv6 (IP6T_SO_ORIGINAL_DST = 80 on SOL_IPV6)
var sa6 unix.RawSockaddrInet6
sa6Len := uint32(unsafe.Sizeof(sa6))
_, _, errno = unix.Syscall6(
unix.SYS_GETSOCKOPT,
fd,
unix.SOL_IPV6,
80, // IP6T_SO_ORIGINAL_DST
uintptr(unsafe.Pointer(&sa6)),
uintptr(unsafe.Pointer(&sa6Len)),
0,
)
if errno != 0 {
sockErr = fmt.Errorf("getsockopt SO_ORIGINAL_DST (v4 and v6): %w", errno)
return
}
addr := netip.AddrFrom16(sa6.Addr)
port := uint16(sa6.Port>>8) | uint16(sa6.Port<<8)
origDst = netip.AddrPortFrom(addr.Unmap(), port)
}); err != nil {
return netip.AddrPort{}, fmt.Errorf("control raw conn: %w", err)
}
if sockErr != nil {
return netip.AddrPort{}, sockErr
}
return origDst, nil
}

View File

@@ -1,200 +0,0 @@
package inspect
import (
"crypto"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"fmt"
"math/big"
mrand "math/rand/v2"
"sync"
"time"
)
const (
// certCacheSize is the maximum number of cached leaf certificates.
certCacheSize = 1024
// certTTL is how long generated certificates remain valid.
certTTL = 24 * time.Hour
)
// certCache is a bounded LRU cache for generated TLS certificates.
type certCache struct {
mu sync.Mutex
entries map[string]*certEntry
// order tracks LRU eviction, most recent at end.
order []string
maxSize int
}
type certEntry struct {
cert *tls.Certificate
expiresAt time.Time
}
func newCertCache(maxSize int) *certCache {
return &certCache{
entries: make(map[string]*certEntry, maxSize),
order: make([]string, 0, maxSize),
maxSize: maxSize,
}
}
func (c *certCache) get(hostname string) (*tls.Certificate, bool) {
c.mu.Lock()
defer c.mu.Unlock()
entry, ok := c.entries[hostname]
if !ok {
return nil, false
}
if time.Now().After(entry.expiresAt) {
c.removeLocked(hostname)
return nil, false
}
// Move to end (most recently used)
c.touchLocked(hostname)
return entry.cert, true
}
func (c *certCache) put(hostname string, cert *tls.Certificate) {
c.mu.Lock()
defer c.mu.Unlock()
// Jitter the TTL by +/- 20% to prevent thundering herd on expiry.
jitter := time.Duration(float64(certTTL) * (0.8 + 0.4*mrand.Float64()))
if _, exists := c.entries[hostname]; exists {
c.entries[hostname] = &certEntry{
cert: cert,
expiresAt: time.Now().Add(jitter),
}
c.touchLocked(hostname)
return
}
// Evict oldest if at capacity
for len(c.entries) >= c.maxSize && len(c.order) > 0 {
c.removeLocked(c.order[0])
}
c.entries[hostname] = &certEntry{
cert: cert,
expiresAt: time.Now().Add(jitter),
}
c.order = append(c.order, hostname)
}
func (c *certCache) touchLocked(hostname string) {
for i, h := range c.order {
if h == hostname {
c.order = append(c.order[:i], c.order[i+1:]...)
c.order = append(c.order, hostname)
return
}
}
}
func (c *certCache) removeLocked(hostname string) {
delete(c.entries, hostname)
for i, h := range c.order {
if h == hostname {
c.order = append(c.order[:i], c.order[i+1:]...)
return
}
}
}
// CertProvider generates TLS certificates on the fly, signed by a CA.
// Generated certificates are cached in an LRU cache.
type CertProvider struct {
ca *x509.Certificate
caKey crypto.PrivateKey
cache *certCache
}
// NewCertProvider creates a certificate provider using the given CA.
func NewCertProvider(ca *x509.Certificate, caKey crypto.PrivateKey) *CertProvider {
return &CertProvider{
ca: ca,
caKey: caKey,
cache: newCertCache(certCacheSize),
}
}
// GetCertificate returns a TLS certificate for the given hostname,
// generating and caching one if necessary.
func (p *CertProvider) GetCertificate(hostname string) (*tls.Certificate, error) {
if cert, ok := p.cache.get(hostname); ok {
return cert, nil
}
cert, err := p.generateCert(hostname)
if err != nil {
return nil, fmt.Errorf("generate cert for %s: %w", hostname, err)
}
p.cache.put(hostname, cert)
return cert, nil
}
// GetTLSConfig returns a tls.Config that dynamically provides certificates
// for any hostname using the MITM CA.
func (p *CertProvider) GetTLSConfig() *tls.Config {
return &tls.Config{
GetCertificate: func(hello *tls.ClientHelloInfo) (*tls.Certificate, error) {
return p.GetCertificate(hello.ServerName)
},
NextProtos: []string{"h2", "http/1.1"},
MinVersion: tls.VersionTLS12,
}
}
func (p *CertProvider) generateCert(hostname string) (*tls.Certificate, error) {
serialNumber, err := rand.Int(rand.Reader, new(big.Int).Lsh(big.NewInt(1), 128))
if err != nil {
return nil, fmt.Errorf("generate serial number: %w", err)
}
now := time.Now()
template := &x509.Certificate{
SerialNumber: serialNumber,
Subject: pkix.Name{
CommonName: hostname,
},
NotBefore: now.Add(-5 * time.Minute),
NotAfter: now.Add(certTTL),
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment,
ExtKeyUsage: []x509.ExtKeyUsage{
x509.ExtKeyUsageServerAuth,
},
DNSNames: []string{hostname},
}
leafKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
if err != nil {
return nil, fmt.Errorf("generate leaf key: %w", err)
}
certDER, err := x509.CreateCertificate(rand.Reader, template, p.ca, &leafKey.PublicKey, p.caKey)
if err != nil {
return nil, fmt.Errorf("sign leaf certificate: %w", err)
}
leafCert, err := x509.ParseCertificate(certDER)
if err != nil {
return nil, fmt.Errorf("parse generated certificate: %w", err)
}
return &tls.Certificate{
Certificate: [][]byte{certDER, p.ca.Raw},
PrivateKey: leafKey,
Leaf: leafCert,
}, nil
}

View File

@@ -1,133 +0,0 @@
package inspect
import (
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"math/big"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func generateTestCA(t *testing.T) (*x509.Certificate, *ecdsa.PrivateKey) {
t.Helper()
key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
require.NoError(t, err)
template := &x509.Certificate{
SerialNumber: big.NewInt(1),
Subject: pkix.Name{
CommonName: "Test CA",
},
NotBefore: time.Now().Add(-time.Hour),
NotAfter: time.Now().Add(24 * time.Hour),
KeyUsage: x509.KeyUsageCertSign | x509.KeyUsageCRLSign,
BasicConstraintsValid: true,
IsCA: true,
}
certDER, err := x509.CreateCertificate(rand.Reader, template, template, &key.PublicKey, key)
require.NoError(t, err)
cert, err := x509.ParseCertificate(certDER)
require.NoError(t, err)
return cert, key
}
func TestCertProvider_GetCertificate(t *testing.T) {
ca, caKey := generateTestCA(t)
provider := NewCertProvider(ca, caKey)
cert, err := provider.GetCertificate("example.com")
require.NoError(t, err)
require.NotNil(t, cert)
// Verify the leaf certificate
assert.Equal(t, "example.com", cert.Leaf.Subject.CommonName)
assert.Contains(t, cert.Leaf.DNSNames, "example.com")
// Verify chain: leaf + CA
assert.Len(t, cert.Certificate, 2)
// Verify leaf is signed by our CA
pool := x509.NewCertPool()
pool.AddCert(ca)
_, err = cert.Leaf.Verify(x509.VerifyOptions{
Roots: pool,
})
require.NoError(t, err)
}
func TestCertProvider_CachesResults(t *testing.T) {
ca, caKey := generateTestCA(t)
provider := NewCertProvider(ca, caKey)
cert1, err := provider.GetCertificate("cached.example.com")
require.NoError(t, err)
cert2, err := provider.GetCertificate("cached.example.com")
require.NoError(t, err)
// Same pointer = cached
assert.Equal(t, cert1, cert2)
}
func TestCertProvider_DifferentHostsDifferentCerts(t *testing.T) {
ca, caKey := generateTestCA(t)
provider := NewCertProvider(ca, caKey)
cert1, err := provider.GetCertificate("a.example.com")
require.NoError(t, err)
cert2, err := provider.GetCertificate("b.example.com")
require.NoError(t, err)
assert.NotEqual(t, cert1.Leaf.SerialNumber, cert2.Leaf.SerialNumber)
}
func TestCertProvider_TLSConfigHandshake(t *testing.T) {
ca, caKey := generateTestCA(t)
provider := NewCertProvider(ca, caKey)
tlsConfig := provider.GetTLSConfig()
require.NotNil(t, tlsConfig)
require.NotNil(t, tlsConfig.GetCertificate)
// Simulate a ClientHelloInfo
hello := &tls.ClientHelloInfo{
ServerName: "handshake.example.com",
}
cert, err := tlsConfig.GetCertificate(hello)
require.NoError(t, err)
assert.Equal(t, "handshake.example.com", cert.Leaf.Subject.CommonName)
}
func TestCertCache_Eviction(t *testing.T) {
cache := newCertCache(3)
for i := range 5 {
hostname := string(rune('a'+i)) + ".example.com"
cache.put(hostname, &tls.Certificate{})
}
// Only 3 should remain (c, d, e - the most recent)
assert.Len(t, cache.entries, 3)
_, ok := cache.get("a.example.com")
assert.False(t, ok, "oldest entry should be evicted")
_, ok = cache.get("b.example.com")
assert.False(t, ok, "second oldest should be evicted")
_, ok = cache.get("e.example.com")
assert.True(t, ok, "newest entry should exist")
}

View File

@@ -1,109 +0,0 @@
package inspect
import (
"bytes"
"fmt"
"io"
"net"
)
// peekConn wraps a net.Conn with a buffer that allows reading ahead
// without consuming data. Subsequent Read calls return the buffered
// bytes first, then read from the underlying connection.
type peekConn struct {
net.Conn
buf bytes.Buffer
// peeked holds the raw bytes that were peeked, available for replay.
peeked []byte
}
// newPeekConn wraps conn for peek-ahead reading.
func newPeekConn(conn net.Conn) *peekConn {
return &peekConn{Conn: conn}
}
// Peek reads exactly n bytes from the connection without consuming them.
// The peeked bytes are replayed on subsequent Read calls.
// Peek may only be called once; calling it again returns an error.
func (c *peekConn) Peek(n int) ([]byte, error) {
if c.peeked != nil {
return nil, fmt.Errorf("peek already called")
}
buf := make([]byte, n)
if _, err := io.ReadFull(c.Conn, buf); err != nil {
return nil, fmt.Errorf("peek %d bytes: %w", n, err)
}
c.peeked = buf
c.buf.Write(buf)
return buf, nil
}
// PeekAll reads up to n bytes, returning whatever is available.
// Unlike Peek, it does not require exactly n bytes.
func (c *peekConn) PeekAll(n int) ([]byte, error) {
if c.peeked != nil {
return nil, fmt.Errorf("peek already called")
}
buf := make([]byte, n)
nr, err := c.Conn.Read(buf)
if nr > 0 {
c.peeked = buf[:nr]
c.buf.Write(c.peeked)
}
if err != nil && nr == 0 {
return nil, fmt.Errorf("peek: %w", err)
}
return c.peeked, nil
}
// PeekMore extends the peeked buffer to at least n total bytes.
// The buffer is reset and refilled with the extended data.
// The returned slice is the internal peeked buffer; callers must not
// retain references from prior Peek/PeekMore calls after calling this.
func (c *peekConn) PeekMore(n int) ([]byte, error) {
if len(c.peeked) >= n {
return c.peeked[:n], nil
}
remaining := n - len(c.peeked)
extra := make([]byte, remaining)
if _, err := io.ReadFull(c.Conn, extra); err != nil {
return nil, fmt.Errorf("peek more %d bytes: %w", remaining, err)
}
// Pre-allocate to avoid reallocation detaching previously returned slices.
combined := make([]byte, 0, n)
combined = append(combined, c.peeked...)
combined = append(combined, extra...)
c.peeked = combined
c.buf.Reset()
c.buf.Write(c.peeked)
return c.peeked, nil
}
// Peeked returns the bytes that were peeked so far, or nil if Peek hasn't been called.
func (c *peekConn) Peeked() []byte {
return c.peeked
}
// Read returns buffered peek data first, then reads from the underlying connection.
func (c *peekConn) Read(p []byte) (int, error) {
if c.buf.Len() > 0 {
return c.buf.Read(p)
}
return c.Conn.Read(p)
}
// reader returns an io.Reader that replays buffered bytes then reads from conn.
func (c *peekConn) reader() io.Reader {
if c.buf.Len() > 0 {
return io.MultiReader(&c.buf, c.Conn)
}
return c.Conn
}

View File

@@ -1,482 +0,0 @@
package inspect
import (
"context"
"errors"
"fmt"
"net"
"net/netip"
"sync"
"time"
log "github.com/sirupsen/logrus"
)
// ErrBlocked is returned when a connection is denied by proxy policy.
var ErrBlocked = errors.New("connection blocked by proxy policy")
const (
// headerReadTimeout is the deadline for reading the initial protocol header.
// Prevents slow loris attacks where a client opens a connection but sends data slowly.
headerReadTimeout = 10 * time.Second
// idleTimeout is the deadline for idle connections between HTTP requests.
idleTimeout = 120 * time.Second
)
// Proxy is the inspection engine for traffic passing through a NetBird
// routing peer. It handles protocol detection, rule evaluation, MITM TLS
// decryption, ICAP delegation, and external proxy forwarding.
type Proxy struct {
config Config
rules *RuleEngine
certs *CertProvider
icap *ICAPClient
// envoy is nil unless mode is ModeEnvoy.
envoy *envoyManager
// dialer is the outbound dialer (with SO_MARK cleared on Linux).
dialer net.Dialer
log *log.Entry
// wgNetwork is the WG overlay prefix; dial targets inside it are blocked.
wgNetwork netip.Prefix
// localIPs reports the routing peer's own IPs; dial targets are blocked.
localIPs LocalIPChecker
// listener is the TPROXY/REDIRECT listener for kernel mode.
listener net.Listener
mu sync.RWMutex
ctx context.Context
cancel context.CancelFunc
}
// LocalIPChecker reports whether an IP belongs to the local machine.
type LocalIPChecker interface {
IsLocalIP(netip.Addr) bool
}
// New creates a transparent proxy with the given configuration.
func New(ctx context.Context, logger *log.Entry, config Config) (*Proxy, error) {
ctx, cancel := context.WithCancel(ctx)
p := &Proxy{
config: config,
rules: NewRuleEngine(logger, config.DefaultAction),
dialer: newOutboundDialer(),
log: logger,
wgNetwork: config.WGNetwork,
localIPs: config.LocalIPChecker,
ctx: ctx,
cancel: cancel,
}
p.rules.UpdateRules(config.Rules, config.DefaultAction)
// Initialize MITM certificate provider
if config.TLS != nil {
p.certs = NewCertProvider(config.TLS.CA, config.TLS.CAKey)
}
// Initialize ICAP client
if config.ICAP != nil {
p.icap = NewICAPClient(logger, config.ICAP)
}
// Start envoy sidecar if configured
if config.Mode == ModeEnvoy {
envoyLog := logger.WithField("sidecar", "envoy")
em, err := startEnvoy(ctx, envoyLog, config)
if err != nil {
cancel()
return nil, fmt.Errorf("start envoy sidecar: %w", err)
}
p.envoy = em
}
// Start TPROXY listener for kernel mode
if config.ListenAddr.IsValid() {
ln, err := newTPROXYListener(logger, config.ListenAddr, netip.Prefix{})
if err != nil {
cancel()
return nil, fmt.Errorf("start TPROXY listener on %s: %w", config.ListenAddr, err)
}
p.listener = ln
go p.acceptLoop(ln)
}
return p, nil
}
// HandleTCP is the entry point for TCP connections from the userspace forwarder.
// It determines the protocol (TLS or plaintext HTTP), evaluates rules,
// and either blocks, passes through, inspects, or forwards to an external proxy.
func (p *Proxy) HandleTCP(ctx context.Context, clientConn net.Conn, dst netip.AddrPort, src SourceInfo) error {
defer func() {
if err := clientConn.Close(); err != nil {
p.log.Debugf("close client conn: %v", err)
}
}()
p.mu.RLock()
mode := p.config.Mode
p.mu.RUnlock()
if mode == ModeExternal {
pconn := newPeekConn(clientConn)
return p.handleExternal(ctx, pconn, dst)
}
// Envoy and builtin modes both peek the protocol header for rule evaluation.
// Envoy mode forwards non-blocked traffic to envoy; builtin mode handles all locally.
// TLS blocks are handled by Go (instant close) since envoy can't cleanly RST a TLS connection.
// Built-in and envoy mode: peek 5 bytes (TLS record header size) to determine protocol.
// Set a read deadline to prevent slow loris attacks.
if err := clientConn.SetReadDeadline(time.Now().Add(headerReadTimeout)); err != nil {
return fmt.Errorf("set read deadline: %w", err)
}
pconn := newPeekConn(clientConn)
header, err := pconn.Peek(5)
if err != nil {
return fmt.Errorf("peek protocol header: %w", err)
}
if err := clientConn.SetReadDeadline(time.Time{}); err != nil {
return fmt.Errorf("clear read deadline: %w", err)
}
if isTLSHandshake(header[0]) {
return p.handleTLS(ctx, pconn, dst, src)
}
if isHTTPMethod(header) {
return p.handlePlainHTTP(ctx, pconn, dst, src)
}
// Not TLS and not HTTP: evaluate rules with ProtoOther.
// If no rule explicitly allows "other", this falls through to the default action.
action := p.rules.Evaluate(src.IP, "", dst.Addr(), dst.Port(), ProtoOther, "")
if action == ActionAllow {
remote, err := p.dialTCP(ctx, dst)
if err != nil {
return fmt.Errorf("dial for passthrough: %w", err)
}
defer func() {
if err := remote.Close(); err != nil {
p.log.Debugf("close remote conn: %v", err)
}
}()
return relay(ctx, pconn, remote)
}
p.log.Debugf("block: non-HTTP/TLS to %s (action=%s, first bytes: %x)", dst, action, header)
return ErrBlocked
}
// InspectTCP evaluates rules for a TCP connection and returns the result.
// Unlike HandleTCP, it can return early for allow decisions, letting the caller
// handle the relay (USP forwarder passthrough optimization).
//
// When InspectResult.PassthroughConn is non-nil, ownership transfers to the caller:
// the caller must close the connection and relay traffic. The engine does not close it.
//
// When PassthroughConn is nil, the engine handled everything internally
// (block, inspect/MITM, or plain HTTP inspection) and closed the connection.
func (p *Proxy) InspectTCP(ctx context.Context, clientConn net.Conn, dst netip.AddrPort, src SourceInfo) (InspectResult, error) {
p.mu.RLock()
mode := p.config.Mode
envoy := p.envoy
p.mu.RUnlock()
// External mode: handle internally, engine owns the connection.
if mode == ModeExternal {
defer func() {
if err := clientConn.Close(); err != nil {
p.log.Debugf("close client conn: %v", err)
}
}()
pconn := newPeekConn(clientConn)
err := p.handleExternal(ctx, pconn, dst)
return InspectResult{Action: ActionAllow}, err
}
// Peek protocol header.
if err := clientConn.SetReadDeadline(time.Now().Add(headerReadTimeout)); err != nil {
clientConn.Close()
return InspectResult{}, fmt.Errorf("set read deadline: %w", err)
}
pconn := newPeekConn(clientConn)
header, err := pconn.Peek(5)
if err != nil {
clientConn.Close()
return InspectResult{}, fmt.Errorf("peek protocol header: %w", err)
}
if err := clientConn.SetReadDeadline(time.Time{}); err != nil {
clientConn.Close()
return InspectResult{}, fmt.Errorf("clear read deadline: %w", err)
}
// TLS: may return passthrough for allow.
if isTLSHandshake(header[0]) {
result, err := p.inspectTLS(ctx, pconn, dst, src)
if err != nil && result.PassthroughConn == nil {
clientConn.Close()
return result, err
}
// Envoy mode: forward allowed TLS to envoy instead of returning passthrough.
if result.PassthroughConn != nil && envoy != nil {
defer clientConn.Close()
envoyErr := p.forwardToEnvoy(ctx, pconn, dst, src, envoy)
return InspectResult{Action: ActionAllow}, envoyErr
}
return result, err
}
// Plain HTTP: in envoy mode, forward to envoy for L7 processing.
// In builtin mode, inspect per-request locally.
if isHTTPMethod(header) {
defer func() {
if err := clientConn.Close(); err != nil {
p.log.Debugf("close client conn: %v", err)
}
}()
if envoy != nil {
err := p.forwardToEnvoy(ctx, pconn, dst, src, envoy)
return InspectResult{Action: ActionAllow}, err
}
err := p.handlePlainHTTP(ctx, pconn, dst, src)
return InspectResult{Action: ActionInspect}, err
}
// Other protocol: evaluate rules.
action := p.rules.Evaluate(src.IP, "", dst.Addr(), dst.Port(), ProtoOther, "")
if action == ActionAllow {
// Envoy mode: forward to envoy.
if envoy != nil {
defer clientConn.Close()
err := p.forwardToEnvoy(ctx, pconn, dst, src, envoy)
return InspectResult{Action: ActionAllow}, err
}
return InspectResult{Action: ActionAllow, PassthroughConn: pconn}, nil
}
p.log.Debugf("block: non-HTTP/TLS to %s (action=%s, first bytes: %x)", dst, action, header)
clientConn.Close()
return InspectResult{Action: ActionBlock}, ErrBlocked
}
// HandleUDPPacket inspects a UDP packet for QUIC Initial packets.
// Returns the action to take: ActionAllow to continue normal forwarding,
// ActionBlock to drop the packet.
// Non-QUIC packets always return ActionAllow.
func (p *Proxy) HandleUDPPacket(data []byte, dst netip.AddrPort, src SourceInfo) Action {
if len(data) < 5 {
return ActionAllow
}
// Check for QUIC Long Header
if data[0]&0x80 == 0 {
return ActionAllow
}
sni, err := ExtractQUICSNI(data)
if err != nil {
// Can't parse QUIC, allow through (could be non-QUIC UDP)
p.log.Tracef("QUIC SNI extraction failed for %s: %v", dst, err)
return ActionAllow
}
if sni == "" {
return ActionAllow
}
action := p.rules.Evaluate(src.IP, sni, dst.Addr(), dst.Port(), ProtoH3, "")
if action == ActionBlock {
p.log.Debugf("block: QUIC to %s (SNI=%s)", dst, sni.PunycodeString())
return ActionBlock
}
// QUIC can't be MITMed, treat Inspect as Allow
if action == ActionInspect {
p.log.Debugf("allow: QUIC to %s (SNI=%s), MITM not supported for QUIC", dst, sni.PunycodeString())
} else {
p.log.Tracef("allow: QUIC to %s (SNI=%s)", dst, sni.PunycodeString())
}
return ActionAllow
}
// handlePlainHTTP handles plaintext HTTP connections.
func (p *Proxy) handlePlainHTTP(ctx context.Context, pconn *peekConn, dst netip.AddrPort, src SourceInfo) error {
remote, err := p.dialTCP(ctx, dst)
if err != nil {
return fmt.Errorf("dial %s: %w", dst, err)
}
defer func() {
if err := remote.Close(); err != nil {
p.log.Debugf("close remote for %s: %v", dst, err)
}
}()
// For plaintext HTTP, always inspect (we can see the traffic)
return p.inspectHTTP(ctx, pconn, remote, dst, "", src, "http/1.1")
}
// UpdateConfig replaces the inspection engine configuration at runtime.
func (p *Proxy) UpdateConfig(config Config) {
p.log.Debugf("config update: mode=%s rules=%d default=%s has_tls=%v has_icap=%v",
config.Mode, len(config.Rules), config.DefaultAction, config.TLS != nil, config.ICAP != nil)
p.mu.Lock()
p.config = config
p.rules.UpdateRules(config.Rules, config.DefaultAction)
// Update MITM provider
if config.TLS != nil {
p.certs = NewCertProvider(config.TLS.CA, config.TLS.CAKey)
} else {
p.certs = nil
}
// Swap ICAP client under lock, close the old one outside to avoid blocking.
var oldICAP *ICAPClient
if config.ICAP != nil {
oldICAP = p.icap
p.icap = NewICAPClient(p.log, config.ICAP)
} else {
oldICAP = p.icap
p.icap = nil
}
// If switching away from envoy mode, clear and stop the old envoy.
var oldEnvoy *envoyManager
if config.Mode != ModeEnvoy && p.envoy != nil {
oldEnvoy = p.envoy
p.envoy = nil
}
envoy := p.envoy
p.mu.Unlock()
if oldICAP != nil {
oldICAP.Close()
}
if oldEnvoy != nil {
oldEnvoy.Stop()
}
// Reload envoy config if still in envoy mode.
if envoy != nil && config.Mode == ModeEnvoy {
if err := envoy.Reload(config); err != nil {
p.log.Errorf("inspect: envoy config reload: %v", err)
}
}
}
// Mode returns the current proxy operating mode.
func (p *Proxy) Mode() ProxyMode {
p.mu.RLock()
defer p.mu.RUnlock()
return p.config.Mode
}
// ListenPort returns the port to use for kernel-mode nftables REDIRECT.
// For builtin mode: the TPROXY listener port.
// For envoy mode: the envoy listener port (nftables redirects directly to envoy).
// Returns 0 if no listener is active.
func (p *Proxy) ListenPort() uint16 {
p.mu.RLock()
envoy := p.envoy
p.mu.RUnlock()
if envoy != nil {
return envoy.listenPort
}
if p.listener == nil {
return 0
}
tcpAddr, ok := p.listener.Addr().(*net.TCPAddr)
if !ok {
return 0
}
return uint16(tcpAddr.Port)
}
// Close shuts down the proxy and releases resources.
func (p *Proxy) Close() error {
p.cancel()
p.mu.Lock()
envoy := p.envoy
p.envoy = nil
icap := p.icap
p.icap = nil
p.mu.Unlock()
if envoy != nil {
envoy.Stop()
}
if p.listener != nil {
if err := p.listener.Close(); err != nil {
p.log.Debugf("close TPROXY listener: %v", err)
}
}
if icap != nil {
icap.Close()
}
return nil
}
// acceptLoop accepts connections from the redirected listener (kernel mode).
// Connections arrive via nftables REDIRECT; original destination is read from conntrack.
func (p *Proxy) acceptLoop(ln net.Listener) {
for {
conn, err := ln.Accept()
if err != nil {
if p.ctx.Err() != nil {
return
}
p.log.Debugf("accept error: %v", err)
continue
}
go func() {
// Read original destination from conntrack (SO_ORIGINAL_DST).
// nftables REDIRECT changes dst to the local WG IP:proxy_port,
// but conntrack preserves the real destination.
dstAddr, err := getOriginalDst(conn)
if err != nil {
p.log.Debugf("get original dst: %v", err)
if closeErr := conn.Close(); closeErr != nil {
p.log.Debugf("close conn: %v", closeErr)
}
return
}
p.log.Tracef("accepted: %s -> %s (original dst %s)",
conn.RemoteAddr(), conn.LocalAddr(), dstAddr)
srcAddr, err := netip.ParseAddrPort(conn.RemoteAddr().String())
if err != nil {
p.log.Debugf("parse source: %v", err)
if closeErr := conn.Close(); closeErr != nil {
p.log.Debugf("close conn: %v", closeErr)
}
return
}
src := SourceInfo{
IP: srcAddr.Addr().Unmap(),
}
if err := p.HandleTCP(p.ctx, conn, dstAddr, src); err != nil && !errors.Is(err, ErrBlocked) {
p.log.Debugf("connection to %s: %v", dstAddr, err)
}
}()
}
}

View File

@@ -1,388 +0,0 @@
package inspect
import (
"crypto/aes"
"crypto/cipher"
"crypto/sha256"
"encoding/binary"
"fmt"
"io"
"golang.org/x/crypto/hkdf"
"github.com/netbirdio/netbird/shared/management/domain"
)
// QUIC version constants
const (
quicV1Version uint32 = 0x00000001
quicV2Version uint32 = 0x6b3343cf
)
// quicV1Salt is the initial salt for QUIC v1 (RFC 9001 Section 5.2).
var quicV1Salt = []byte{
0x38, 0x76, 0x2c, 0xf7, 0xf5, 0x59, 0x34, 0xb3,
0x4d, 0x17, 0x9a, 0xe6, 0xa4, 0xc8, 0x0c, 0xad,
0xcc, 0xbb, 0x7f, 0x0a,
}
// quicV2Salt is the initial salt for QUIC v2 (RFC 9369).
var quicV2Salt = []byte{
0x0d, 0xed, 0xe3, 0xde, 0xf7, 0x00, 0xa6, 0xdb,
0x81, 0x93, 0x81, 0xbe, 0x6e, 0x26, 0x9d, 0xcb,
0xf9, 0xbd, 0x2e, 0xd9,
}
// ExtractQUICSNI extracts the SNI from a QUIC Initial packet.
// The Initial packet's encryption uses well-known keys derived from the
// Destination Connection ID, so any observer can decrypt it (by design).
func ExtractQUICSNI(data []byte) (domain.Domain, error) {
if len(data) < 5 {
return "", fmt.Errorf("packet too short")
}
// Check for QUIC Long Header (form bit set)
if data[0]&0x80 == 0 {
return "", fmt.Errorf("not a QUIC long header packet")
}
// Version
version := binary.BigEndian.Uint32(data[1:5])
var salt []byte
var initialLabel, keyLabel, ivLabel, hpLabel string
switch version {
case quicV1Version:
salt = quicV1Salt
initialLabel = "client in"
keyLabel = "quic key"
ivLabel = "quic iv"
hpLabel = "quic hp"
case quicV2Version:
salt = quicV2Salt
initialLabel = "client in"
keyLabel = "quicv2 key"
ivLabel = "quicv2 iv"
hpLabel = "quicv2 hp"
default:
return "", fmt.Errorf("unsupported QUIC version: 0x%08x", version)
}
// Parse Long Header
if len(data) < 6 {
return "", fmt.Errorf("packet too short for DCID length")
}
dcidLen := int(data[5])
if len(data) < 6+dcidLen+1 {
return "", fmt.Errorf("packet too short for DCID")
}
dcid := data[6 : 6+dcidLen]
scidLenOff := 6 + dcidLen
scidLen := int(data[scidLenOff])
tokenLenOff := scidLenOff + 1 + scidLen
if tokenLenOff >= len(data) {
return "", fmt.Errorf("packet too short for token length")
}
// Token length is a variable-length integer
tokenLen, tokenLenSize, err := readVarInt(data[tokenLenOff:])
if err != nil {
return "", fmt.Errorf("read token length: %w", err)
}
payloadLenOff := tokenLenOff + tokenLenSize + int(tokenLen)
if payloadLenOff >= len(data) {
return "", fmt.Errorf("packet too short for payload length")
}
// Payload length is a variable-length integer
payloadLen, payloadLenSize, err := readVarInt(data[payloadLenOff:])
if err != nil {
return "", fmt.Errorf("read payload length: %w", err)
}
pnOffset := payloadLenOff + payloadLenSize
if pnOffset+4 > len(data) {
return "", fmt.Errorf("packet too short for packet number")
}
// Derive initial keys
clientKey, clientIV, clientHP, err := deriveInitialKeys(dcid, salt, initialLabel, keyLabel, ivLabel, hpLabel)
if err != nil {
return "", fmt.Errorf("derive initial keys: %w", err)
}
// Remove header protection
sampleOffset := pnOffset + 4 // sample starts 4 bytes after pn offset
if sampleOffset+16 > len(data) {
return "", fmt.Errorf("packet too short for HP sample")
}
sample := data[sampleOffset : sampleOffset+16]
hpBlock, err := aes.NewCipher(clientHP)
if err != nil {
return "", fmt.Errorf("create HP cipher: %w", err)
}
mask := make([]byte, 16)
hpBlock.Encrypt(mask, sample)
// Unmask header byte
header := make([]byte, len(data))
copy(header, data)
header[0] ^= mask[0] & 0x0f // Long header: low 4 bits
// Determine packet number length
pnLen := int(header[0]&0x03) + 1
// Unmask packet number
for i := 0; i < pnLen; i++ {
header[pnOffset+i] ^= mask[1+i]
}
// Reconstruct packet number
var pn uint32
for i := 0; i < pnLen; i++ {
pn = (pn << 8) | uint32(header[pnOffset+i])
}
// Build nonce
nonce := make([]byte, len(clientIV))
copy(nonce, clientIV)
for i := 0; i < 4; i++ {
nonce[len(nonce)-1-i] ^= byte(pn >> (8 * i))
}
// Decrypt payload
block, err := aes.NewCipher(clientKey)
if err != nil {
return "", fmt.Errorf("create AES cipher: %w", err)
}
aead, err := cipher.NewGCM(block)
if err != nil {
return "", fmt.Errorf("create AEAD: %w", err)
}
encryptedPayload := header[pnOffset+pnLen : pnOffset+int(payloadLen)]
aad := header[:pnOffset+pnLen]
plaintext, err := aead.Open(nil, nonce, encryptedPayload, aad)
if err != nil {
return "", fmt.Errorf("decrypt QUIC payload: %w", err)
}
// Parse CRYPTO frames to extract ClientHello
clientHello, err := extractCryptoFrames(plaintext)
if err != nil {
return "", fmt.Errorf("extract CRYPTO frames: %w", err)
}
info, err := parseHelloBody(clientHello)
return info.SNI, err
}
// deriveInitialKeys derives the client's initial encryption keys from the DCID.
func deriveInitialKeys(dcid, salt []byte, initialLabel, keyLabel, ivLabel, hpLabel string) (key, iv, hp []byte, err error) {
// initial_secret = HKDF-Extract(salt, DCID)
initialSecret := hkdf.Extract(sha256.New, dcid, salt)
// client_initial_secret = HKDF-Expand-Label(initial_secret, initialLabel, "", 32)
clientSecret, err := hkdfExpandLabel(initialSecret, initialLabel, nil, 32)
if err != nil {
return nil, nil, nil, fmt.Errorf("derive client secret: %w", err)
}
// client_key = HKDF-Expand-Label(client_secret, keyLabel, "", 16)
key, err = hkdfExpandLabel(clientSecret, keyLabel, nil, 16)
if err != nil {
return nil, nil, nil, fmt.Errorf("derive key: %w", err)
}
// client_iv = HKDF-Expand-Label(client_secret, ivLabel, "", 12)
iv, err = hkdfExpandLabel(clientSecret, ivLabel, nil, 12)
if err != nil {
return nil, nil, nil, fmt.Errorf("derive IV: %w", err)
}
// client_hp = HKDF-Expand-Label(client_secret, hpLabel, "", 16)
hp, err = hkdfExpandLabel(clientSecret, hpLabel, nil, 16)
if err != nil {
return nil, nil, nil, fmt.Errorf("derive HP key: %w", err)
}
return key, iv, hp, nil
}
// hkdfExpandLabel implements TLS 1.3 HKDF-Expand-Label.
func hkdfExpandLabel(secret []byte, label string, context []byte, length int) ([]byte, error) {
// HkdfLabel = struct {
// uint16 length;
// opaque label<7..255> = "tls13 " + Label;
// opaque context<0..255> = Context;
// }
fullLabel := "tls13 " + label
hkdfLabel := make([]byte, 2+1+len(fullLabel)+1+len(context))
binary.BigEndian.PutUint16(hkdfLabel[0:2], uint16(length))
hkdfLabel[2] = byte(len(fullLabel))
copy(hkdfLabel[3:], fullLabel)
hkdfLabel[3+len(fullLabel)] = byte(len(context))
if len(context) > 0 {
copy(hkdfLabel[4+len(fullLabel):], context)
}
expander := hkdf.Expand(sha256.New, secret, hkdfLabel)
out := make([]byte, length)
if _, err := io.ReadFull(expander, out); err != nil {
return nil, err
}
return out, nil
}
// maxCryptoFrameSize limits total CRYPTO frame data to prevent memory exhaustion.
const maxCryptoFrameSize = 64 * 1024
// extractCryptoFrames reassembles CRYPTO frame data from QUIC frames.
func extractCryptoFrames(frames []byte) ([]byte, error) {
var result []byte
pos := 0
for pos < len(frames) {
frameType := frames[pos]
switch {
case frameType == 0x00:
// PADDING frame
pos++
case frameType == 0x06:
// CRYPTO frame
pos++
offset, n, err := readVarInt(frames[pos:])
if err != nil {
return nil, fmt.Errorf("read crypto offset: %w", err)
}
pos += n
_ = offset // We assume ordered, offset 0 for Initial
dataLen, n, err := readVarInt(frames[pos:])
if err != nil {
return nil, fmt.Errorf("read crypto data length: %w", err)
}
pos += n
end := pos + int(dataLen)
if end > len(frames) {
return nil, fmt.Errorf("CRYPTO frame data truncated")
}
result = append(result, frames[pos:end]...)
if len(result) > maxCryptoFrameSize {
return nil, fmt.Errorf("CRYPTO frame data exceeds %d bytes", maxCryptoFrameSize)
}
pos = end
case frameType == 0x01:
// PING frame
pos++
case frameType == 0x02 || frameType == 0x03:
// ACK frame - skip
pos++
// Largest Acknowledged
_, n, err := readVarInt(frames[pos:])
if err != nil {
return nil, fmt.Errorf("read ACK: %w", err)
}
pos += n
// ACK Delay
_, n, err = readVarInt(frames[pos:])
if err != nil {
return nil, fmt.Errorf("read ACK delay: %w", err)
}
pos += n
// ACK Range Count
rangeCount, n, err := readVarInt(frames[pos:])
if err != nil {
return nil, fmt.Errorf("read ACK range count: %w", err)
}
pos += n
// First ACK Range
_, n, err = readVarInt(frames[pos:])
if err != nil {
return nil, fmt.Errorf("read first ACK range: %w", err)
}
pos += n
// Additional ranges
for i := uint64(0); i < rangeCount; i++ {
_, n, err = readVarInt(frames[pos:])
if err != nil {
return nil, fmt.Errorf("read ACK gap: %w", err)
}
pos += n
_, n, err = readVarInt(frames[pos:])
if err != nil {
return nil, fmt.Errorf("read ACK range: %w", err)
}
pos += n
}
// ECN counts for type 0x03
if frameType == 0x03 {
for range 3 {
_, n, err = readVarInt(frames[pos:])
if err != nil {
return nil, fmt.Errorf("read ECN count: %w", err)
}
pos += n
}
}
default:
// Unknown frame type, stop parsing
if len(result) > 0 {
return result, nil
}
return nil, fmt.Errorf("unknown QUIC frame type: 0x%02x at offset %d", frameType, pos)
}
}
if len(result) == 0 {
return nil, fmt.Errorf("no CRYPTO frames found")
}
return result, nil
}
// readVarInt reads a QUIC variable-length integer.
// Returns (value, bytes consumed, error).
func readVarInt(data []byte) (uint64, int, error) {
if len(data) == 0 {
return 0, 0, fmt.Errorf("empty data for varint")
}
prefix := data[0] >> 6
length := 1 << prefix
if len(data) < length {
return 0, 0, fmt.Errorf("varint truncated: need %d, have %d", length, len(data))
}
var val uint64
switch length {
case 1:
val = uint64(data[0] & 0x3f)
case 2:
val = uint64(binary.BigEndian.Uint16(data[:2])) & 0x3fff
case 4:
val = uint64(binary.BigEndian.Uint32(data[:4])) & 0x3fffffff
case 8:
val = binary.BigEndian.Uint64(data[:8]) & 0x3fffffffffffffff
}
return val, length, nil
}

View File

@@ -1,99 +0,0 @@
package inspect
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestReadVarInt(t *testing.T) {
tests := []struct {
name string
data []byte
want uint64
n int
}{
{
name: "1 byte value",
data: []byte{0x25},
want: 37,
n: 1,
},
{
name: "2 byte value",
data: []byte{0x7b, 0xbd},
want: 15293,
n: 2,
},
{
name: "4 byte value",
data: []byte{0x9d, 0x7f, 0x3e, 0x7d},
want: 494878333,
n: 4,
},
{
name: "zero",
data: []byte{0x00},
want: 0,
n: 1,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
val, n, err := readVarInt(tt.data)
require.NoError(t, err)
assert.Equal(t, tt.want, val)
assert.Equal(t, tt.n, n)
})
}
}
func TestReadVarInt_Empty(t *testing.T) {
_, _, err := readVarInt(nil)
require.Error(t, err)
}
func TestReadVarInt_Truncated(t *testing.T) {
// 2-byte prefix but only 1 byte
_, _, err := readVarInt([]byte{0x40})
require.Error(t, err)
}
func TestExtractQUICSNI_NotLongHeader(t *testing.T) {
// Short header packet (form bit not set)
data := make([]byte, 100)
data[0] = 0x40 // short header
_, err := ExtractQUICSNI(data)
require.Error(t, err)
assert.Contains(t, err.Error(), "not a QUIC long header")
}
func TestExtractQUICSNI_UnsupportedVersion(t *testing.T) {
data := make([]byte, 100)
data[0] = 0xC0 // long header
// Version 0xdeadbeef
data[1] = 0xde
data[2] = 0xad
data[3] = 0xbe
data[4] = 0xef
_, err := ExtractQUICSNI(data)
require.Error(t, err)
assert.Contains(t, err.Error(), "unsupported QUIC version")
}
func TestExtractQUICSNI_TooShort(t *testing.T) {
_, err := ExtractQUICSNI([]byte{0xC0, 0x00})
require.Error(t, err)
}
func TestHkdfExpandLabel(t *testing.T) {
// Smoke test: ensure it returns the right length and doesn't error
secret := make([]byte, 32)
result, err := hkdfExpandLabel(secret, "quic key", nil, 16)
require.NoError(t, err)
assert.Len(t, result, 16)
}

View File

@@ -1,253 +0,0 @@
package inspect
import (
"net/netip"
"slices"
"sort"
"strings"
"sync"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/acl/id"
"github.com/netbirdio/netbird/shared/management/domain"
)
// RuleEngine evaluates proxy rules against connection metadata.
// It is safe for concurrent use.
type RuleEngine struct {
mu sync.RWMutex
rules []Rule
// defaultAction applies when no rule matches.
defaultAction Action
log *log.Entry
}
// NewRuleEngine creates a rule engine with the given default action.
func NewRuleEngine(logger *log.Entry, defaultAction Action) *RuleEngine {
return &RuleEngine{
defaultAction: defaultAction,
log: logger,
}
}
// UpdateRules replaces the rule set and default action. Rules are sorted by priority.
func (e *RuleEngine) UpdateRules(rules []Rule, defaultAction Action) {
sorted := make([]Rule, len(rules))
copy(sorted, rules)
sort.Slice(sorted, func(i, j int) bool {
return sorted[i].Priority < sorted[j].Priority
})
e.mu.Lock()
e.rules = sorted
e.defaultAction = defaultAction
e.mu.Unlock()
}
// EvalResult holds the outcome of a rule evaluation.
type EvalResult struct {
Action Action
RuleID id.RuleID
}
// Evaluate determines the action for a connection based on the rule set.
// Pass empty path for connection-level evaluation (TLS/SNI), non-empty for request-level (HTTP).
func (e *RuleEngine) Evaluate(src netip.Addr, dstDomain domain.Domain, dstAddr netip.Addr, dstPort uint16, proto ProtoType, path string) Action {
r := e.EvaluateWithResult(src, dstDomain, dstAddr, dstPort, proto, path)
return r.Action
}
// EvaluateWithResult is like Evaluate but also returns the matched rule ID.
func (e *RuleEngine) EvaluateWithResult(src netip.Addr, dstDomain domain.Domain, dstAddr netip.Addr, dstPort uint16, proto ProtoType, path string) EvalResult {
e.mu.RLock()
defer e.mu.RUnlock()
for i := range e.rules {
rule := &e.rules[i]
if e.ruleMatches(rule, src, dstDomain, dstAddr, dstPort, proto, path) {
e.log.Tracef("rule %s matched: action=%s src=%s domain=%s dst=%s:%d proto=%s path=%s",
rule.ID, rule.Action, src, dstDomain.SafeString(), dstAddr, dstPort, proto, path)
return EvalResult{Action: rule.Action, RuleID: rule.ID}
}
}
e.log.Tracef("no rule matched, default=%s: src=%s domain=%s dst=%s:%d proto=%s path=%s",
e.defaultAction, src, dstDomain.SafeString(), dstAddr, dstPort, proto, path)
return EvalResult{Action: e.defaultAction}
}
// HasPathRulesForDomain returns true if any rule matching the domain has non-empty Paths.
// Used to force MITM inspection when path-level rules exist (paths are only visible after decryption).
func (e *RuleEngine) HasPathRulesForDomain(dstDomain domain.Domain) bool {
e.mu.RLock()
defer e.mu.RUnlock()
for i := range e.rules {
if len(e.rules[i].Paths) > 0 && e.matchDomain(&e.rules[i], dstDomain) {
return true
}
}
return false
}
// ruleMatches checks whether all non-empty fields of a rule match.
// Empty fields are treated as "match any".
// All specified fields must match (AND logic).
func (e *RuleEngine) ruleMatches(rule *Rule, src netip.Addr, dstDomain domain.Domain, dstAddr netip.Addr, dstPort uint16, proto ProtoType, path string) bool {
if !e.matchSource(rule, src) {
return false
}
if !e.matchDomain(rule, dstDomain) {
return false
}
if !e.matchNetwork(rule, dstAddr) {
return false
}
if !e.matchPort(rule, dstPort) {
return false
}
if !e.matchProtocol(rule, proto) {
return false
}
if !e.matchPaths(rule, path) {
return false
}
return true
}
// matchSource returns true if src matches any of the rule's source CIDRs,
// or if no source CIDRs are specified (match any).
func (e *RuleEngine) matchSource(rule *Rule, src netip.Addr) bool {
if len(rule.Sources) == 0 {
return true
}
for _, prefix := range rule.Sources {
if prefix.Contains(src) {
return true
}
}
return false
}
// matchDomain returns true if dstDomain matches any of the rule's domain patterns,
// or if no domain patterns are specified (match any).
func (e *RuleEngine) matchDomain(rule *Rule, dstDomain domain.Domain) bool {
if len(rule.Domains) == 0 {
return true
}
// If we have domain rules but no domain to match against (e.g., raw IP connection),
// the domain condition does not match.
if dstDomain == "" {
return false
}
for _, pattern := range rule.Domains {
if MatchDomain(pattern, dstDomain) {
return true
}
}
return false
}
// matchNetwork returns true if dstAddr is within any of the rule's destination CIDRs,
// or if no destination CIDRs are specified (match any).
func (e *RuleEngine) matchNetwork(rule *Rule, dstAddr netip.Addr) bool {
if len(rule.Networks) == 0 {
return true
}
for _, prefix := range rule.Networks {
if prefix.Contains(dstAddr) {
return true
}
}
return false
}
// matchProtocol returns true if proto matches any of the rule's protocols,
// or if no protocols are specified (match any).
func (e *RuleEngine) matchProtocol(rule *Rule, proto ProtoType) bool {
if len(rule.Protocols) == 0 {
return true
}
for _, p := range rule.Protocols {
if p == proto {
return true
}
}
return false
}
// matchPort returns true if dstPort matches any of the rule's destination ports,
// or if no ports are specified (match any).
func (e *RuleEngine) matchPort(rule *Rule, dstPort uint16) bool {
if len(rule.Ports) == 0 {
return true
}
return slices.Contains(rule.Ports, dstPort)
}
// matchPaths returns true if path matches any of the rule's path patterns,
// or if no paths are specified (match any). Empty path (connection-level eval) matches all.
func (e *RuleEngine) matchPaths(rule *Rule, path string) bool {
if len(rule.Paths) == 0 {
return true
}
// Connection-level (path=""): rules with paths don't match at connection level.
// HasPathRulesForDomain forces the connection to inspect, so paths are
// checked per-request once the HTTP request is visible.
if path == "" {
return false
}
for _, pattern := range rule.Paths {
if matchPath(pattern, path) {
return true
}
}
return false
}
// matchPath checks if a URL path matches a pattern.
// Supports: exact ("/login"), prefix with wildcard ("/api/*"),
// and contains ("*/admin/*"). A bare "*" matches everything.
func matchPath(pattern, path string) bool {
if pattern == "*" {
return true
}
hasLeadingStar := strings.HasPrefix(pattern, "*")
hasTrailingStar := strings.HasSuffix(pattern, "*")
switch {
case hasLeadingStar && hasTrailingStar:
// */admin/* = contains
middle := strings.Trim(pattern, "*")
return strings.Contains(path, middle)
case hasTrailingStar:
// /api/* = prefix
prefix := strings.TrimSuffix(pattern, "*")
return strings.HasPrefix(path, prefix)
case hasLeadingStar:
// *.json = suffix
suffix := strings.TrimPrefix(pattern, "*")
return strings.HasSuffix(path, suffix)
default:
// exact
return path == pattern
}
}

View File

@@ -1,338 +0,0 @@
package inspect
import (
"net/netip"
"testing"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/client/internal/acl/id"
"github.com/netbirdio/netbird/shared/management/domain"
)
func testLogger() *log.Entry {
return log.WithField("test", true)
}
func mustDomain(t *testing.T, s string) domain.Domain {
t.Helper()
d, err := domain.FromString(s)
require.NoError(t, err)
return d
}
func TestRuleEngine_Evaluate(t *testing.T) {
tests := []struct {
name string
rules []Rule
defaultAction Action
src netip.Addr
dstDomain domain.Domain
dstAddr netip.Addr
dstPort uint16
want Action
}{
{
name: "no rules returns default allow",
defaultAction: ActionAllow,
src: netip.MustParseAddr("10.0.0.1"),
dstAddr: netip.MustParseAddr("1.2.3.4"),
dstPort: 443,
want: ActionAllow,
},
{
name: "no rules returns default block",
defaultAction: ActionBlock,
src: netip.MustParseAddr("10.0.0.1"),
dstAddr: netip.MustParseAddr("1.2.3.4"),
dstPort: 443,
want: ActionBlock,
},
{
name: "domain exact match blocks",
defaultAction: ActionAllow,
rules: []Rule{
{
ID: id.RuleID("r1"),
Domains: []domain.Domain{mustDomain(t, "malware.example.com")},
Action: ActionBlock,
},
},
src: netip.MustParseAddr("10.0.0.1"),
dstDomain: mustDomain(t, "malware.example.com"),
dstAddr: netip.MustParseAddr("1.2.3.4"),
dstPort: 443,
want: ActionBlock,
},
{
name: "domain wildcard match blocks",
defaultAction: ActionAllow,
rules: []Rule{
{
ID: id.RuleID("r1"),
Domains: []domain.Domain{mustDomain(t, "*.evil.com")},
Action: ActionBlock,
},
},
src: netip.MustParseAddr("10.0.0.1"),
dstDomain: mustDomain(t, "phishing.evil.com"),
dstAddr: netip.MustParseAddr("1.2.3.4"),
dstPort: 443,
want: ActionBlock,
},
{
name: "domain wildcard does not match base",
defaultAction: ActionAllow,
rules: []Rule{
{
ID: id.RuleID("r1"),
Domains: []domain.Domain{mustDomain(t, "*.evil.com")},
Action: ActionBlock,
},
},
src: netip.MustParseAddr("10.0.0.1"),
dstDomain: mustDomain(t, "evil.com"),
dstAddr: netip.MustParseAddr("1.2.3.4"),
dstPort: 443,
want: ActionAllow,
},
{
name: "case insensitive domain match",
defaultAction: ActionAllow,
rules: []Rule{
{
ID: id.RuleID("r1"),
Domains: []domain.Domain{mustDomain(t, "Example.COM")},
Action: ActionBlock,
},
},
src: netip.MustParseAddr("10.0.0.1"),
dstDomain: mustDomain(t, "EXAMPLE.com"),
dstAddr: netip.MustParseAddr("1.2.3.4"),
dstPort: 443,
want: ActionBlock,
},
{
name: "source CIDR match",
defaultAction: ActionAllow,
rules: []Rule{
{
ID: id.RuleID("r1"),
Sources: []netip.Prefix{netip.MustParsePrefix("192.168.1.0/24")},
Action: ActionInspect,
},
},
src: netip.MustParseAddr("192.168.1.50"),
dstAddr: netip.MustParseAddr("1.2.3.4"),
dstPort: 443,
want: ActionInspect,
},
{
name: "source CIDR no match",
defaultAction: ActionAllow,
rules: []Rule{
{
ID: id.RuleID("r1"),
Sources: []netip.Prefix{netip.MustParsePrefix("192.168.1.0/24")},
Action: ActionBlock,
},
},
src: netip.MustParseAddr("10.0.0.5"),
dstAddr: netip.MustParseAddr("1.2.3.4"),
dstPort: 443,
want: ActionAllow,
},
{
name: "destination network match",
defaultAction: ActionAllow,
rules: []Rule{
{
ID: id.RuleID("r1"),
Networks: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/8")},
Action: ActionInspect,
},
},
src: netip.MustParseAddr("192.168.1.1"),
dstAddr: netip.MustParseAddr("10.50.0.1"),
dstPort: 80,
want: ActionInspect,
},
{
name: "port match",
defaultAction: ActionAllow,
rules: []Rule{
{
ID: id.RuleID("r1"),
Ports: []uint16{443, 8443},
Action: ActionInspect,
},
},
src: netip.MustParseAddr("10.0.0.1"),
dstAddr: netip.MustParseAddr("1.2.3.4"),
dstPort: 443,
want: ActionInspect,
},
{
name: "port no match",
defaultAction: ActionAllow,
rules: []Rule{
{
ID: id.RuleID("r1"),
Ports: []uint16{443, 8443},
Action: ActionBlock,
},
},
src: netip.MustParseAddr("10.0.0.1"),
dstAddr: netip.MustParseAddr("1.2.3.4"),
dstPort: 22,
want: ActionAllow,
},
{
name: "priority ordering first match wins",
defaultAction: ActionAllow,
rules: []Rule{
{
ID: id.RuleID("allow-internal"),
Domains: []domain.Domain{mustDomain(t, "*.internal.corp")},
Action: ActionAllow,
Priority: 1,
},
{
ID: id.RuleID("inspect-all"),
Action: ActionInspect,
Priority: 10,
},
},
src: netip.MustParseAddr("10.0.0.1"),
dstDomain: mustDomain(t, "api.internal.corp"),
dstAddr: netip.MustParseAddr("10.1.0.5"),
dstPort: 443,
want: ActionAllow,
},
{
name: "all fields must match (AND logic)",
defaultAction: ActionAllow,
rules: []Rule{
{
ID: id.RuleID("r1"),
Sources: []netip.Prefix{netip.MustParsePrefix("192.168.1.0/24")},
Domains: []domain.Domain{mustDomain(t, "*.evil.com")},
Ports: []uint16{443},
Action: ActionBlock,
},
},
// Source matches, domain matches, but port doesn't
src: netip.MustParseAddr("192.168.1.10"),
dstDomain: mustDomain(t, "phish.evil.com"),
dstAddr: netip.MustParseAddr("1.2.3.4"),
dstPort: 8080,
want: ActionAllow,
},
{
name: "empty domain with domain rule does not match",
defaultAction: ActionAllow,
rules: []Rule{
{
ID: id.RuleID("r1"),
Domains: []domain.Domain{mustDomain(t, "example.com")},
Action: ActionBlock,
},
},
src: netip.MustParseAddr("10.0.0.1"),
dstDomain: "", // raw IP connection, no SNI
dstAddr: netip.MustParseAddr("1.2.3.4"),
dstPort: 443,
want: ActionAllow,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
engine := NewRuleEngine(testLogger(), tt.defaultAction)
engine.UpdateRules(tt.rules, tt.defaultAction)
got := engine.Evaluate(tt.src, tt.dstDomain, tt.dstAddr, tt.dstPort, "", "")
assert.Equal(t, tt.want, got)
})
}
}
func TestRuleEngine_ProtocolMatching(t *testing.T) {
engine := NewRuleEngine(testLogger(), ActionAllow)
engine.UpdateRules([]Rule{
{
ID: "block-websocket",
Protocols: []ProtoType{ProtoWebSocket},
Action: ActionBlock,
Priority: 1,
},
{
ID: "inspect-h2",
Protocols: []ProtoType{ProtoH2},
Action: ActionInspect,
Priority: 2,
},
}, ActionAllow)
src := netip.MustParseAddr("10.0.0.1")
dst := netip.MustParseAddr("1.2.3.4")
// WebSocket: blocked by rule
assert.Equal(t, ActionBlock, engine.Evaluate(src, "", dst, 443, ProtoWebSocket, ""))
// HTTP/2: inspected by rule
assert.Equal(t, ActionInspect, engine.Evaluate(src, "", dst, 443, ProtoH2, ""))
// Plain HTTP: no protocol rule matches, default allow
assert.Equal(t, ActionAllow, engine.Evaluate(src, "", dst, 80, ProtoHTTP, ""))
// HTTPS: no protocol rule matches, default allow
assert.Equal(t, ActionAllow, engine.Evaluate(src, "", dst, 443, ProtoHTTPS, ""))
// QUIC/H3: no protocol rule matches, default allow
assert.Equal(t, ActionAllow, engine.Evaluate(src, "", dst, 443, ProtoH3, ""))
// Empty protocol (unknown): no protocol rule matches, default allow
assert.Equal(t, ActionAllow, engine.Evaluate(src, "", dst, 443, "", ""))
}
func TestRuleEngine_EmptyProtocolsMatchAll(t *testing.T) {
engine := NewRuleEngine(testLogger(), ActionAllow)
engine.UpdateRules([]Rule{
{
ID: "block-all-protos",
Action: ActionBlock,
// No Protocols field = match all protocols
Priority: 1,
},
}, ActionAllow)
src := netip.MustParseAddr("10.0.0.1")
dst := netip.MustParseAddr("1.2.3.4")
assert.Equal(t, ActionBlock, engine.Evaluate(src, "", dst, 443, ProtoHTTP, ""))
assert.Equal(t, ActionBlock, engine.Evaluate(src, "", dst, 443, ProtoHTTPS, ""))
assert.Equal(t, ActionBlock, engine.Evaluate(src, "", dst, 443, ProtoWebSocket, ""))
assert.Equal(t, ActionBlock, engine.Evaluate(src, "", dst, 443, ProtoH2, ""))
assert.Equal(t, ActionBlock, engine.Evaluate(src, "", dst, 443, "", ""))
}
func TestRuleEngine_UpdateRulesSortsByPriority(t *testing.T) {
engine := NewRuleEngine(testLogger(), ActionAllow)
engine.UpdateRules([]Rule{
{ID: "c", Priority: 30, Action: ActionBlock},
{ID: "a", Priority: 10, Action: ActionInspect},
{ID: "b", Priority: 20, Action: ActionAllow},
}, ActionAllow)
engine.mu.RLock()
defer engine.mu.RUnlock()
require.Len(t, engine.rules, 3)
assert.Equal(t, id.RuleID("a"), engine.rules[0].ID)
assert.Equal(t, id.RuleID("b"), engine.rules[1].ID)
assert.Equal(t, id.RuleID("c"), engine.rules[2].ID)
}

View File

@@ -1,287 +0,0 @@
package inspect
import (
"encoding/binary"
"fmt"
"io"
"github.com/netbirdio/netbird/shared/management/domain"
)
const (
recordTypeHandshake = 0x16
handshakeTypeClientHello = 0x01
extensionTypeSNI = 0x0000
extensionTypeALPN = 0x0010
sniTypeHostName = 0x00
// maxClientHelloSize is the maximum ClientHello size we'll read.
// Real-world ClientHellos are typically under 1KB but can reach ~16KB with
// many extensions (post-quantum key shares, etc.).
maxClientHelloSize = 16384
)
// ClientHelloInfo holds data extracted from a TLS ClientHello.
type ClientHelloInfo struct {
SNI domain.Domain
ALPN []string
}
// isTLSHandshake reports whether the first byte indicates a TLS handshake record.
func isTLSHandshake(b byte) bool {
return b == recordTypeHandshake
}
// httpMethods lists the first bytes of valid HTTP method tokens.
var httpMethods = [][]byte{
[]byte("GET "),
[]byte("POST"),
[]byte("PUT "),
[]byte("DELE"),
[]byte("HEAD"),
[]byte("OPTI"),
[]byte("PATC"),
[]byte("CONN"),
[]byte("TRAC"),
}
// isHTTPMethod reports whether the peeked bytes look like the start of an HTTP request.
func isHTTPMethod(b []byte) bool {
if len(b) < 4 {
return false
}
for _, m := range httpMethods {
if b[0] == m[0] && b[1] == m[1] && b[2] == m[2] && b[3] == m[3] {
return true
}
}
return false
}
// parseClientHello reads a TLS ClientHello from r and returns SNI and ALPN.
func parseClientHello(r io.Reader) (ClientHelloInfo, error) {
// TLS record header: type(1) + version(2) + length(2)
var recordHeader [5]byte
if _, err := io.ReadFull(r, recordHeader[:]); err != nil {
return ClientHelloInfo{}, fmt.Errorf("read TLS record header: %w", err)
}
if recordHeader[0] != recordTypeHandshake {
return ClientHelloInfo{}, fmt.Errorf("not a TLS handshake record (type=%d)", recordHeader[0])
}
recordLen := int(binary.BigEndian.Uint16(recordHeader[3:5]))
if recordLen < 4 || recordLen > maxClientHelloSize {
return ClientHelloInfo{}, fmt.Errorf("invalid TLS record length: %d", recordLen)
}
// Read the full handshake message
msg := make([]byte, recordLen)
if _, err := io.ReadFull(r, msg); err != nil {
return ClientHelloInfo{}, fmt.Errorf("read handshake message: %w", err)
}
return parseClientHelloMsg(msg)
}
// extractSNI reads a TLS ClientHello from r and returns the SNI hostname.
// Returns empty domain if no SNI extension is present.
func extractSNI(r io.Reader) (domain.Domain, error) {
info, err := parseClientHello(r)
return info.SNI, err
}
// extractSNIFromBytes parses SNI from raw bytes that start with the TLS record header.
func extractSNIFromBytes(data []byte) (domain.Domain, error) {
info, err := parseClientHelloFromBytes(data)
return info.SNI, err
}
// parseClientHelloFromBytes parses a ClientHello from raw bytes starting with the TLS record header.
func parseClientHelloFromBytes(data []byte) (ClientHelloInfo, error) {
if len(data) < 5 {
return ClientHelloInfo{}, fmt.Errorf("data too short for TLS record header")
}
if data[0] != recordTypeHandshake {
return ClientHelloInfo{}, fmt.Errorf("not a TLS handshake record (type=%d)", data[0])
}
recordLen := int(binary.BigEndian.Uint16(data[3:5]))
if recordLen < 4 {
return ClientHelloInfo{}, fmt.Errorf("invalid TLS record length: %d", recordLen)
}
end := 5 + recordLen
if end > len(data) {
return ClientHelloInfo{}, fmt.Errorf("TLS record truncated: need %d, have %d", end, len(data))
}
return parseClientHelloMsg(data[5:end])
}
// parseClientHelloMsg extracts SNI and ALPN from a raw ClientHello handshake message.
// msg starts at the handshake type byte.
func parseClientHelloMsg(msg []byte) (ClientHelloInfo, error) {
if len(msg) < 4 {
return ClientHelloInfo{}, fmt.Errorf("handshake message too short")
}
if msg[0] != handshakeTypeClientHello {
return ClientHelloInfo{}, fmt.Errorf("not a ClientHello (type=%d)", msg[0])
}
// Handshake header: type(1) + length(3)
helloLen := int(msg[1])<<16 | int(msg[2])<<8 | int(msg[3])
if helloLen+4 > len(msg) {
return ClientHelloInfo{}, fmt.Errorf("ClientHello truncated")
}
hello := msg[4 : 4+helloLen]
return parseHelloBody(hello)
}
// parseHelloBody parses the ClientHello body (after handshake header)
// and extracts SNI and ALPN.
func parseHelloBody(hello []byte) (ClientHelloInfo, error) {
// ClientHello structure:
// version(2) + random(32) + session_id_len(1) + session_id(var)
// + cipher_suites_len(2) + cipher_suites(var)
// + compression_len(1) + compression(var)
// + extensions_len(2) + extensions(var)
var info ClientHelloInfo
if len(hello) < 35 {
return info, fmt.Errorf("ClientHello body too short")
}
pos := 2 + 32 // skip version + random
// Skip session ID
if pos >= len(hello) {
return info, fmt.Errorf("ClientHello truncated at session ID")
}
sessionIDLen := int(hello[pos])
pos += 1 + sessionIDLen
// Skip cipher suites
if pos+2 > len(hello) {
return info, fmt.Errorf("ClientHello truncated at cipher suites")
}
cipherLen := int(binary.BigEndian.Uint16(hello[pos : pos+2]))
pos += 2 + cipherLen
// Skip compression methods
if pos >= len(hello) {
return info, fmt.Errorf("ClientHello truncated at compression")
}
compLen := int(hello[pos])
pos += 1 + compLen
// Extensions
if pos+2 > len(hello) {
return info, nil
}
extLen := int(binary.BigEndian.Uint16(hello[pos : pos+2]))
pos += 2
extEnd := pos + extLen
if extEnd > len(hello) {
return info, fmt.Errorf("extensions block truncated")
}
// Walk extensions looking for SNI and ALPN
for pos+4 <= extEnd {
extType := binary.BigEndian.Uint16(hello[pos : pos+2])
extDataLen := int(binary.BigEndian.Uint16(hello[pos+2 : pos+4]))
pos += 4
if pos+extDataLen > extEnd {
return info, fmt.Errorf("extension data truncated")
}
switch extType {
case extensionTypeSNI:
sni, err := parseSNIExtension(hello[pos : pos+extDataLen])
if err != nil {
return info, err
}
info.SNI = sni
case extensionTypeALPN:
info.ALPN = parseALPNExtension(hello[pos : pos+extDataLen])
}
pos += extDataLen
}
return info, nil
}
// parseALPNExtension parses the ALPN extension data and returns protocol names.
// ALPN extension: list_length(2) + entries (each: len(1) + protocol_name(var))
func parseALPNExtension(data []byte) []string {
if len(data) < 2 {
return nil
}
listLen := int(binary.BigEndian.Uint16(data[0:2]))
if listLen+2 > len(data) {
return nil
}
var protocols []string
pos := 2
end := 2 + listLen
for pos < end {
if pos >= len(data) {
break
}
nameLen := int(data[pos])
pos++
if pos+nameLen > end {
break
}
protocols = append(protocols, string(data[pos:pos+nameLen]))
pos += nameLen
}
return protocols
}
// parseSNIExtension parses the SNI extension data and returns the hostname.
func parseSNIExtension(data []byte) (domain.Domain, error) {
// SNI extension: list_length(2) + entries
if len(data) < 2 {
return "", fmt.Errorf("SNI extension too short")
}
listLen := int(binary.BigEndian.Uint16(data[0:2]))
if listLen+2 > len(data) {
return "", fmt.Errorf("SNI list truncated")
}
pos := 2
end := 2 + listLen
for pos+3 <= end {
nameType := data[pos]
nameLen := int(binary.BigEndian.Uint16(data[pos+1 : pos+3]))
pos += 3
if pos+nameLen > end {
return "", fmt.Errorf("SNI name truncated")
}
if nameType == sniTypeHostName {
hostname := string(data[pos : pos+nameLen])
return domain.FromString(hostname)
}
pos += nameLen
}
return "", nil
}

View File

@@ -1,109 +0,0 @@
package inspect
import (
"bytes"
"crypto/tls"
"net"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestExtractSNI(t *testing.T) {
tests := []struct {
name string
sni string
wantSNI string
wantErr bool
}{
{
name: "standard domain",
sni: "example.com",
wantSNI: "example.com",
},
{
name: "subdomain",
sni: "api.staging.example.com",
wantSNI: "api.staging.example.com",
},
{
name: "mixed case normalized to lowercase",
sni: "Example.COM",
wantSNI: "example.com",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
clientHello := buildClientHello(t, tt.sni)
sni, err := extractSNI(bytes.NewReader(clientHello))
if tt.wantErr {
require.Error(t, err)
return
}
require.NoError(t, err)
assert.Equal(t, tt.wantSNI, sni.PunycodeString())
})
}
}
func TestExtractSNI_NotTLS(t *testing.T) {
// HTTP request instead of TLS
data := []byte("GET / HTTP/1.1\r\nHost: example.com\r\n\r\n")
_, err := extractSNI(bytes.NewReader(data))
require.Error(t, err)
assert.Contains(t, err.Error(), "not a TLS handshake")
}
func TestExtractSNI_Truncated(t *testing.T) {
// Just the record header, no body
data := []byte{0x16, 0x03, 0x01, 0x00, 0x05}
_, err := extractSNI(bytes.NewReader(data))
require.Error(t, err)
}
func TestExtractSNIFromBytes(t *testing.T) {
clientHello := buildClientHello(t, "test.example.com")
sni, err := extractSNIFromBytes(clientHello)
require.NoError(t, err)
assert.Equal(t, "test.example.com", sni.PunycodeString())
}
// buildClientHello generates a real TLS ClientHello with the given SNI.
func buildClientHello(t *testing.T, serverName string) []byte {
t.Helper()
// Use a pipe to capture the ClientHello bytes
clientConn, serverConn := net.Pipe()
done := make(chan []byte, 1)
go func() {
buf := make([]byte, 4096)
n, _ := serverConn.Read(buf)
done <- buf[:n]
serverConn.Close()
}()
tlsConn := tls.Client(clientConn, &tls.Config{
ServerName: serverName,
InsecureSkipVerify: true,
})
// Trigger the handshake (will fail since server isn't TLS, but we capture the ClientHello)
go func() {
_ = tlsConn.Handshake()
tlsConn.Close()
}()
clientHello := <-done
clientConn.Close()
require.True(t, len(clientHello) > 5, "ClientHello too short")
require.Equal(t, byte(0x16), clientHello[0], "not a TLS handshake record")
return clientHello
}

View File

@@ -1,287 +0,0 @@
package inspect
import (
"context"
"crypto/tls"
"fmt"
"io"
"net"
"net/netip"
"github.com/netbirdio/netbird/shared/management/domain"
)
// handleTLS processes a TLS connection for the kernel-mode path: extracts SNI,
// evaluates rules, and handles the connection internally.
// In envoy mode, allowed connections are forwarded to envoy instead of direct relay.
func (p *Proxy) handleTLS(ctx context.Context, pconn *peekConn, dst netip.AddrPort, src SourceInfo) error {
result, err := p.inspectTLS(ctx, pconn, dst, src)
if err != nil {
return err
}
if result.PassthroughConn != nil {
p.mu.RLock()
envoy := p.envoy
p.mu.RUnlock()
if envoy != nil {
return p.forwardToEnvoy(ctx, pconn, dst, src, envoy)
}
return p.tlsPassthrough(ctx, pconn, dst, "")
}
return nil
}
// inspectTLS extracts SNI, evaluates rules, and returns the result.
// For ActionAllow: returns the peekConn as PassthroughConn (caller relays).
// For ActionBlock/ActionInspect: handles internally and returns nil PassthroughConn.
func (p *Proxy) inspectTLS(ctx context.Context, pconn *peekConn, dst netip.AddrPort, src SourceInfo) (InspectResult, error) {
// The first 5 bytes (TLS record header) are already peeked.
// Extend to read the full TLS record so bytes remain in the buffer for passthrough.
peeked := pconn.Peeked()
recordLen := int(peeked[3])<<8 | int(peeked[4])
if _, err := pconn.PeekMore(5 + recordLen); err != nil {
return InspectResult{}, fmt.Errorf("read TLS record: %w", err)
}
hello, err := parseClientHelloFromBytes(pconn.Peeked())
if err != nil {
return InspectResult{}, fmt.Errorf("parse ClientHello: %w", err)
}
sni := hello.SNI
proto := protoFromALPN(hello.ALPN)
// Connection-level evaluation: pass empty path.
action := p.evaluateAction(src.IP, sni, dst, proto, "")
// If any rule for this domain has path patterns, force inspect so paths can
// be checked per-request after MITM decryption.
if action == ActionAllow && p.rules.HasPathRulesForDomain(sni) {
p.log.Debugf("upgrading to inspect for %s (path rules exist)", sni.PunycodeString())
action = ActionInspect
}
// Snapshot cert provider under lock for use in this connection.
p.mu.RLock()
certs := p.certs
p.mu.RUnlock()
switch action {
case ActionBlock:
p.log.Debugf("block: TLS to %s (SNI=%s)", dst, sni.PunycodeString())
if certs != nil {
return InspectResult{Action: ActionBlock}, p.tlsBlockPage(ctx, pconn, sni, certs)
}
return InspectResult{Action: ActionBlock}, ErrBlocked
case ActionAllow:
p.log.Tracef("allow: TLS passthrough to %s (SNI=%s)", dst, sni.PunycodeString())
return InspectResult{Action: ActionAllow, PassthroughConn: pconn}, nil
case ActionInspect:
if certs == nil {
p.log.Warnf("allow: %s (inspect requested but no MITM CA configured)", sni.PunycodeString())
return InspectResult{Action: ActionAllow, PassthroughConn: pconn}, nil
}
err := p.tlsMITM(ctx, pconn, dst, sni, src, certs)
return InspectResult{Action: ActionInspect}, err
default:
p.log.Warnf("block: unknown action %q for %s", action, sni.PunycodeString())
return InspectResult{Action: ActionBlock}, ErrBlocked
}
}
// tlsBlockPage completes a MITM TLS handshake with the client using a dynamic
// certificate, then serves an HTTP 403 block page so the user sees a clear
// message instead of a cryptic SSL error.
func (p *Proxy) tlsBlockPage(ctx context.Context, pconn *peekConn, sni domain.Domain, certs *CertProvider) error {
hostname := sni.PunycodeString()
// Force HTTP/1.1 only: block pages are simple responses, no need for h2
tlsCfg := certs.GetTLSConfig()
tlsCfg.NextProtos = []string{"http/1.1"}
clientTLS := tls.Server(pconn, tlsCfg)
if err := clientTLS.HandshakeContext(ctx); err != nil {
// Client may not trust our CA, handshake fails. That's expected.
return fmt.Errorf("block page TLS handshake for %s: %w", hostname, err)
}
defer func() {
if err := clientTLS.Close(); err != nil {
p.log.Debugf("close block page TLS for %s: %v", hostname, err)
}
}()
writeBlockResponse(clientTLS, nil, sni)
return ErrBlocked
}
// tlsPassthrough connects to the destination and relays encrypted traffic
// without decryption. The peeked ClientHello bytes are replayed.
func (p *Proxy) tlsPassthrough(ctx context.Context, pconn *peekConn, dst netip.AddrPort, sni domain.Domain) error {
remote, err := p.dialTCP(ctx, dst)
if err != nil {
return fmt.Errorf("dial %s: %w", dst, err)
}
defer func() {
if err := remote.Close(); err != nil {
p.log.Debugf("close remote for %s: %v", dst, err)
}
}()
p.log.Tracef("allow: TLS passthrough to %s (SNI=%s)", dst, sni.PunycodeString())
return relay(ctx, pconn, remote)
}
// tlsMITM terminates the client TLS connection with a dynamic certificate,
// establishes a new TLS connection to the real destination, and runs the
// HTTP inspection pipeline on the decrypted traffic.
func (p *Proxy) tlsMITM(ctx context.Context, pconn *peekConn, dst netip.AddrPort, sni domain.Domain, src SourceInfo, certs *CertProvider) error {
hostname := sni.PunycodeString()
// TLS handshake with client using dynamic cert
clientTLS := tls.Server(pconn, certs.GetTLSConfig())
if err := clientTLS.HandshakeContext(ctx); err != nil {
return fmt.Errorf("client TLS handshake for %s: %w", hostname, err)
}
defer func() {
if err := clientTLS.Close(); err != nil {
p.log.Debugf("close client TLS for %s: %v", hostname, err)
}
}()
// TLS connection to real destination
remoteTLS, err := p.dialTLS(ctx, dst, hostname)
if err != nil {
return fmt.Errorf("dial TLS %s (%s): %w", dst, hostname, err)
}
defer func() {
if err := remoteTLS.Close(); err != nil {
p.log.Debugf("close remote TLS for %s: %v", hostname, err)
}
}()
negotiatedProto := clientTLS.ConnectionState().NegotiatedProtocol
p.log.Tracef("inspect: MITM established for %s (proto=%s)", hostname, negotiatedProto)
return p.inspectHTTP(ctx, clientTLS, remoteTLS, dst, sni, src, negotiatedProto)
}
// dialTLS connects to the destination with TLS, verifying the real server certificate.
func (p *Proxy) dialTLS(ctx context.Context, dst netip.AddrPort, serverName string) (net.Conn, error) {
rawConn, err := p.dialTCP(ctx, dst)
if err != nil {
return nil, err
}
tlsConn := tls.Client(rawConn, &tls.Config{
ServerName: serverName,
NextProtos: []string{"h2", "http/1.1"},
MinVersion: tls.VersionTLS12,
})
if err := tlsConn.HandshakeContext(ctx); err != nil {
if closeErr := rawConn.Close(); closeErr != nil {
p.log.Debugf("close raw conn after TLS handshake failure: %v", closeErr)
}
return nil, fmt.Errorf("TLS handshake with %s: %w", serverName, err)
}
return tlsConn, nil
}
// protoFromALPN maps TLS ALPN protocol names to proxy ProtoType.
// Falls back to ProtoHTTPS when no recognized ALPN is present.
func protoFromALPN(alpn []string) ProtoType {
for _, p := range alpn {
switch p {
case "h2":
return ProtoH2
case "h3": // unlikely in TLS, but handle anyway
return ProtoH3
}
}
// No ALPN or only "http/1.1": treat as HTTPS
return ProtoHTTPS
}
// relay copies data bidirectionally between client and remote until one
// side closes or the context is cancelled.
func relay(ctx context.Context, client, remote net.Conn) error {
ctx, cancel := context.WithCancel(ctx)
defer cancel()
errCh := make(chan error, 2)
go func() {
_, err := io.Copy(remote, client)
cancel()
errCh <- err
}()
go func() {
_, err := io.Copy(client, remote)
cancel()
errCh <- err
}()
var firstErr error
for range 2 {
if err := <-errCh; err != nil && firstErr == nil {
if !isClosedErr(err) {
firstErr = err
}
}
}
return firstErr
}
// evaluateAction runs rule evaluation and resolves the effective action.
// Pass empty path for connection-level (TLS), non-empty for request-level (HTTP).
func (p *Proxy) evaluateAction(src netip.Addr, sni domain.Domain, dst netip.AddrPort, proto ProtoType, path string) Action {
return p.rules.Evaluate(src, sni, dst.Addr(), dst.Port(), proto, path)
}
// dialTCP dials the destination, blocking connections to loopback, link-local,
// multicast, and WG overlay network addresses.
func (p *Proxy) dialTCP(ctx context.Context, dst netip.AddrPort) (net.Conn, error) {
ip := dst.Addr().Unmap()
if err := p.validateDialTarget(ip); err != nil {
return nil, fmt.Errorf("dial %s: %w", dst, err)
}
return p.dialer.DialContext(ctx, "tcp", dst.String())
}
// validateDialTarget blocks destinations that should never be dialed by the proxy.
// Mirrors the route validation in systemops.validateRoute.
func (p *Proxy) validateDialTarget(addr netip.Addr) error {
switch {
case !addr.IsValid():
return fmt.Errorf("invalid address")
case addr.IsLoopback():
return fmt.Errorf("loopback address not allowed")
case addr.IsLinkLocalUnicast(), addr.IsLinkLocalMulticast(), addr.IsInterfaceLocalMulticast():
return fmt.Errorf("link-local address not allowed")
case addr.IsMulticast():
return fmt.Errorf("multicast address not allowed")
case p.wgNetwork.IsValid() && p.wgNetwork.Contains(addr):
return fmt.Errorf("overlay network address not allowed")
case p.localIPs != nil && p.localIPs.IsLocalIP(addr):
return fmt.Errorf("local address not allowed")
}
return nil
}
func isClosedErr(err error) bool {
if err == nil {
return false
}
return err == io.EOF ||
err == io.ErrClosedPipe ||
err == net.ErrClosed ||
err == context.Canceled
}

View File

@@ -315,6 +315,7 @@ func (a *Auth) setSystemInfoFlags(info *system.Info) {
a.config.RosenpassEnabled,
a.config.RosenpassPermissive,
a.config.ServerSSHAllowed,
a.config.ServerVNCAllowed,
a.config.DisableClientRoutes,
a.config.DisableServerRoutes,
a.config.DisableDNS,
@@ -327,6 +328,7 @@ func (a *Auth) setSystemInfoFlags(info *system.Info) {
a.config.EnableSSHLocalPortForwarding,
a.config.EnableSSHRemotePortForwarding,
a.config.DisableSSHAuth,
a.config.DisableVNCAuth,
)
}

View File

@@ -543,11 +543,13 @@ func createEngineConfig(key wgtypes.Key, config *profilemanager.Config, peerConf
RosenpassEnabled: config.RosenpassEnabled,
RosenpassPermissive: config.RosenpassPermissive,
ServerSSHAllowed: util.ReturnBoolWithDefaultTrue(config.ServerSSHAllowed),
ServerVNCAllowed: config.ServerVNCAllowed != nil && *config.ServerVNCAllowed,
EnableSSHRoot: config.EnableSSHRoot,
EnableSSHSFTP: config.EnableSSHSFTP,
EnableSSHLocalPortForwarding: config.EnableSSHLocalPortForwarding,
EnableSSHRemotePortForwarding: config.EnableSSHRemotePortForwarding,
DisableSSHAuth: config.DisableSSHAuth,
DisableVNCAuth: config.DisableVNCAuth,
DNSRouteInterval: config.DNSRouteInterval,
DisableClientRoutes: config.DisableClientRoutes,
@@ -562,9 +564,6 @@ func createEngineConfig(key wgtypes.Key, config *profilemanager.Config, peerConf
MTU: selectMTU(config.MTU, peerConfig.Mtu),
LogPath: logPath,
InspectionCACertPath: config.InspectionCACertPath,
InspectionCAKeyPath: config.InspectionCAKeyPath,
ProfileConfig: config,
}
@@ -627,6 +626,7 @@ func loginToManagement(ctx context.Context, client mgm.Client, pubSSHKey []byte,
config.RosenpassEnabled,
config.RosenpassPermissive,
config.ServerSSHAllowed,
config.ServerVNCAllowed,
config.DisableClientRoutes,
config.DisableServerRoutes,
config.DisableDNS,
@@ -639,6 +639,7 @@ func loginToManagement(ctx context.Context, client mgm.Client, pubSSHKey []byte,
config.EnableSSHLocalPortForwarding,
config.EnableSSHRemotePortForwarding,
config.DisableSSHAuth,
config.DisableVNCAuth,
)
return client.Login(sysInfo, pubSSHKey, config.DNSLabels)
}

View File

@@ -31,7 +31,6 @@ import (
"github.com/netbirdio/netbird/client/iface/device"
nbnetstack "github.com/netbirdio/netbird/client/iface/netstack"
"github.com/netbirdio/netbird/client/iface/udpmux"
"github.com/netbirdio/netbird/client/inspect"
"github.com/netbirdio/netbird/client/internal/acl"
"github.com/netbirdio/netbird/client/internal/debug"
"github.com/netbirdio/netbird/client/internal/dns"
@@ -118,11 +117,13 @@ type EngineConfig struct {
RosenpassPermissive bool
ServerSSHAllowed bool
ServerVNCAllowed bool
EnableSSHRoot *bool
EnableSSHSFTP *bool
EnableSSHLocalPortForwarding *bool
EnableSSHRemotePortForwarding *bool
DisableSSHAuth *bool
DisableVNCAuth *bool
DNSRouteInterval time.Duration
@@ -137,12 +138,6 @@ type EngineConfig struct {
MTU uint16
// InspectionCACertPath is a local CA cert for transparent proxy MITM.
// Takes priority over management-pushed CA.
InspectionCACertPath string
// InspectionCAKeyPath is the corresponding private key.
InspectionCAKeyPath string
// for debug bundle generation
ProfileConfig *profilemanager.Config
@@ -204,6 +199,7 @@ type Engine struct {
networkMonitor *networkmonitor.NetworkMonitor
sshServer sshServer
vncSrv vncServer
statusRecorder *peer.Status
@@ -229,10 +225,6 @@ type Engine struct {
latestSyncResponse *mgmProto.SyncResponse
flowManager nftypes.FlowManager
// transparentProxy is the transparent forward proxy for traffic inspection.
transparentProxy *inspect.Proxy
udpInspectionHookID string
// auto-update
updateManager *updater.Manager
@@ -321,6 +313,10 @@ func (e *Engine) Stop() error {
log.Warnf("failed to stop SSH server: %v", err)
}
if err := e.stopVNCServer(); err != nil {
log.Warnf("failed to stop VNC server: %v", err)
}
e.cleanupSSHConfig()
if e.ingressGatewayMgr != nil {
@@ -1008,6 +1004,7 @@ func (e *Engine) updateChecksIfNew(checks []*mgmProto.Checks) error {
e.config.RosenpassEnabled,
e.config.RosenpassPermissive,
&e.config.ServerSSHAllowed,
&e.config.ServerVNCAllowed,
e.config.DisableClientRoutes,
e.config.DisableServerRoutes,
e.config.DisableDNS,
@@ -1020,6 +1017,7 @@ func (e *Engine) updateChecksIfNew(checks []*mgmProto.Checks) error {
e.config.EnableSSHLocalPortForwarding,
e.config.EnableSSHRemotePortForwarding,
e.config.DisableSSHAuth,
e.config.DisableVNCAuth,
)
if err := e.mgmClient.SyncMeta(info); err != nil {
@@ -1047,6 +1045,10 @@ func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error {
}
}
if err := e.updateVNC(conf.GetSshConfig()); err != nil {
log.Warnf("failed handling VNC server setup: %v", err)
}
state := e.statusRecorder.GetLocalPeerState()
state.IP = e.wgInterface.Address().String()
state.PubKey = e.config.WgPrivateKey.PublicKey().String()
@@ -1148,6 +1150,7 @@ func (e *Engine) receiveManagementEvents() {
e.config.RosenpassEnabled,
e.config.RosenpassPermissive,
&e.config.ServerSSHAllowed,
&e.config.ServerVNCAllowed,
e.config.DisableClientRoutes,
e.config.DisableServerRoutes,
e.config.DisableDNS,
@@ -1160,6 +1163,7 @@ func (e *Engine) receiveManagementEvents() {
e.config.EnableSSHLocalPortForwarding,
e.config.EnableSSHRemotePortForwarding,
e.config.DisableSSHAuth,
e.config.DisableVNCAuth,
)
err = e.mgmClient.Sync(e.ctx, info, e.handleSync)
@@ -1283,9 +1287,6 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
fwdEntries := toRouteDomains(e.config.WgPrivateKey.PublicKey().String(), routes)
e.updateDNSForwarder(dnsRouteFeatureFlag, fwdEntries)
// Transparent proxy
e.updateTransparentProxy(networkMap.GetTransparentProxyConfig())
// Ingress forward rules
forwardingRules, err := e.updateForwardRules(networkMap.GetForwardingRules())
if err != nil {
@@ -1337,6 +1338,11 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
}
e.updateSSHServerAuth(networkMap.GetSshAuth())
// VNC auth: use dedicated VNCAuth if present.
if vncAuth := networkMap.GetVncAuth(); vncAuth != nil {
e.updateVNCServerAuth(vncAuth)
}
}
// must set the exclude list after the peers are added. Without it the manager can not figure out the peers parameters from the store
@@ -1709,8 +1715,6 @@ func (e *Engine) parseNATExternalIPMappings() []string {
func (e *Engine) close() {
log.Debugf("removing Netbird interface %s", e.config.WgIfaceName)
e.stopTransparentProxy()
if e.wgInterface != nil {
if err := e.wgInterface.Close(); err != nil {
log.Errorf("failed closing Netbird interface %s %v", e.config.WgIfaceName, err)
@@ -1748,6 +1752,7 @@ func (e *Engine) readInitialSettings() ([]*route.Route, *nbdns.Config, bool, err
e.config.RosenpassEnabled,
e.config.RosenpassPermissive,
&e.config.ServerSSHAllowed,
&e.config.ServerVNCAllowed,
e.config.DisableClientRoutes,
e.config.DisableServerRoutes,
e.config.DisableDNS,
@@ -1760,6 +1765,7 @@ func (e *Engine) readInitialSettings() ([]*route.Route, *nbdns.Config, bool, err
e.config.EnableSSHLocalPortForwarding,
e.config.EnableSSHRemotePortForwarding,
e.config.DisableSSHAuth,
e.config.DisableVNCAuth,
)
netMap, err := e.mgmClient.GetNetworkMap(info)

View File

@@ -1,571 +0,0 @@
package internal
import (
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/x509"
"crypto/x509/pkix"
"encoding/pem"
"fmt"
"math/big"
"net/netip"
"net/url"
"os"
"strconv"
"time"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/firewall/uspfilter/forwarder"
"github.com/netbirdio/netbird/client/inspect"
"github.com/netbirdio/netbird/client/internal/acl/id"
"github.com/netbirdio/netbird/shared/management/domain"
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
)
// updateTransparentProxy processes transparent proxy configuration from the network map.
func (e *Engine) updateTransparentProxy(cfg *mgmProto.TransparentProxyConfig) {
if cfg == nil || !cfg.Enabled {
if cfg == nil {
log.Tracef("inspect: config is nil")
} else {
log.Tracef("inspect: config disabled")
}
// Only stop if explicitly disabled. Don't stop on nil config to avoid
// a gap during policy edits where management briefly pushes empty config.
if cfg != nil && !cfg.Enabled {
e.stopTransparentProxy()
}
return
}
log.Debugf("inspect: config received: enabled=%v mode=%v default_action=%v rules=%d has_ca=%v",
cfg.Enabled, cfg.Mode, cfg.DefaultAction, len(cfg.Rules), len(cfg.CaCertPem) > 0)
// BlockInbound prevents adding TPROXY rules since kernel TPROXY bypasses ACLs.
// The userspace forwarder path still works as it operates within the forwarder hook.
if e.config.BlockInbound {
log.Warnf("inspect: BlockInbound is set, skipping redirect rules (userspace path still active)")
}
proxyConfig, err := toProxyConfig(cfg)
if err != nil {
log.Errorf("inspect: parse config: %v", err)
e.stopTransparentProxy()
return
}
// CA priority: local config > management-pushed > auto-generated self-signed.
// Local wins over mgmt to prevent compromised management from injecting a CA.
e.resolveInspectionCA(&proxyConfig)
if e.transparentProxy != nil {
// Mode change requires full recreate (envoy lifecycle, listener changes).
if proxyConfig.Mode != e.transparentProxy.Mode() {
log.Infof("inspect: mode changed to %s, recreating engine", proxyConfig.Mode)
e.stopTransparentProxy()
} else {
e.transparentProxy.UpdateConfig(proxyConfig)
e.syncTProxyRules(proxyConfig)
e.syncUDPInspectionHook()
return
}
}
if e.wgInterface != nil {
proxyConfig.WGNetwork = e.wgInterface.Address().Network
proxyConfig.ListenAddr = netip.AddrPortFrom(
e.wgInterface.Address().IP.Unmap(),
proxyConfig.ListenAddr.Port(),
)
}
// Pass local IP checker for SSRF prevention
if checker, ok := e.firewall.(inspect.LocalIPChecker); ok {
proxyConfig.LocalIPChecker = checker
}
p, err := inspect.New(e.ctx, log.WithField("component", "inspect"), proxyConfig)
if err != nil {
log.Errorf("inspect: start engine: %v", err)
return
}
e.transparentProxy = p
e.attachProxyToForwarder(p)
e.syncTProxyRules(proxyConfig)
e.syncUDPInspectionHook()
log.Infof("inspect: engine started (mode=%s, rules=%d)", proxyConfig.Mode, len(proxyConfig.Rules))
}
// stopTransparentProxy shuts down the transparent proxy and removes interception.
func (e *Engine) stopTransparentProxy() {
if e.transparentProxy == nil {
return
}
e.attachProxyToForwarder(nil)
e.removeTProxyRule()
e.removeUDPInspectionHook()
if err := e.transparentProxy.Close(); err != nil {
log.Debugf("inspect: close engine: %v", err)
}
e.transparentProxy = nil
log.Info("inspect: engine stopped")
}
const tproxyRuleID = "tproxy-redirect"
// syncTProxyRules adds a TPROXY rule via the firewall manager to intercept
// matching traffic on the WG interface and redirect it to the proxy socket.
func (e *Engine) syncTProxyRules(config inspect.Config) {
if e.config.BlockInbound {
e.removeTProxyRule()
return
}
var listenPort uint16
if e.transparentProxy != nil {
listenPort = e.transparentProxy.ListenPort()
}
if listenPort == 0 {
e.removeTProxyRule()
return
}
if e.firewall == nil {
return
}
dstPorts := make([]uint16, len(config.RedirectPorts))
copy(dstPorts, config.RedirectPorts)
log.Debugf("inspect: syncing redirect rules: listen port %d, redirect ports %v, sources %v",
listenPort, dstPorts, config.RedirectSources)
if err := e.firewall.AddTProxyRule(tproxyRuleID, config.RedirectSources, dstPorts, listenPort); err != nil {
log.Errorf("inspect: add redirect rule: %v", err)
return
}
}
// removeTProxyRule removes the TPROXY redirect rule.
func (e *Engine) removeTProxyRule() {
if e.firewall == nil {
return
}
if err := e.firewall.RemoveTProxyRule(tproxyRuleID); err != nil {
log.Debugf("inspect: remove redirect rule: %v", err)
}
}
// syncUDPInspectionHook registers a UDP packet hook on port 443 for QUIC SNI blocking.
// The hook is called by the USP filter for each UDP packet matching the port,
// allowing the inspection engine to extract QUIC SNI and block by domain.
func (e *Engine) syncUDPInspectionHook() {
e.removeUDPInspectionHook()
if e.firewall == nil || e.transparentProxy == nil {
return
}
p := e.transparentProxy
hookID := e.firewall.AddUDPInspectionHook(443, func(packet []byte) bool {
srcIP, dstIP, dstPort, udpPayload, ok := parseUDPPacket(packet)
if !ok {
return false
}
src := inspect.SourceInfo{IP: srcIP}
dst := netip.AddrPortFrom(dstIP, dstPort)
action := p.HandleUDPPacket(udpPayload, dst, src)
return action == inspect.ActionBlock
})
e.udpInspectionHookID = hookID
log.Debugf("inspect: registered UDP inspection hook on port 443 (id=%s)", hookID)
}
// removeUDPInspectionHook removes the QUIC inspection hook.
func (e *Engine) removeUDPInspectionHook() {
if e.udpInspectionHookID == "" || e.firewall == nil {
return
}
e.firewall.RemoveUDPInspectionHook(e.udpInspectionHookID)
e.udpInspectionHookID = ""
}
// parseUDPPacket extracts source/destination IP, destination port, and UDP
// payload from a raw IP packet. Supports both IPv4 and IPv6.
func parseUDPPacket(packet []byte) (srcIP, dstIP netip.Addr, dstPort uint16, payload []byte, ok bool) {
if len(packet) < 1 {
return srcIP, dstIP, 0, nil, false
}
version := packet[0] >> 4
var udpOffset int
switch version {
case 4:
if len(packet) < 20 {
return srcIP, dstIP, 0, nil, false
}
ihl := int(packet[0]&0x0f) * 4
if len(packet) < ihl+8 {
return srcIP, dstIP, 0, nil, false
}
var srcOK, dstOK bool
srcIP, srcOK = netip.AddrFromSlice(packet[12:16])
dstIP, dstOK = netip.AddrFromSlice(packet[16:20])
if !srcOK || !dstOK {
return srcIP, dstIP, 0, nil, false
}
udpOffset = ihl
case 6:
// IPv6 fixed header is 40 bytes. Next header must be UDP (17).
if len(packet) < 48 { // 40 header + 8 UDP
return srcIP, dstIP, 0, nil, false
}
nextHeader := packet[6]
if nextHeader != 17 { // not UDP (may have extension headers)
return srcIP, dstIP, 0, nil, false
}
var srcOK, dstOK bool
srcIP, srcOK = netip.AddrFromSlice(packet[8:24])
dstIP, dstOK = netip.AddrFromSlice(packet[24:40])
if !srcOK || !dstOK {
return srcIP, dstIP, 0, nil, false
}
udpOffset = 40
default:
return srcIP, dstIP, 0, nil, false
}
srcIP = srcIP.Unmap()
dstIP = dstIP.Unmap()
dstPort = uint16(packet[udpOffset+2])<<8 | uint16(packet[udpOffset+3])
payload = packet[udpOffset+8:]
return srcIP, dstIP, dstPort, payload, true
}
// attachProxyToForwarder sets or clears the proxy on the userspace forwarder.
func (e *Engine) attachProxyToForwarder(p *inspect.Proxy) {
type forwarderGetter interface {
GetForwarder() *forwarder.Forwarder
}
if fg, ok := e.firewall.(forwarderGetter); ok {
if fwd := fg.GetForwarder(); fwd != nil {
fwd.SetProxy(p)
}
}
}
// toProxyConfig converts a proto TransparentProxyConfig to the inspect.Config type.
func toProxyConfig(cfg *mgmProto.TransparentProxyConfig) (inspect.Config, error) {
config := inspect.Config{
Enabled: cfg.Enabled,
DefaultAction: toProxyAction(cfg.DefaultAction),
}
switch cfg.Mode {
case mgmProto.TransparentProxyMode_TP_MODE_ENVOY:
config.Mode = inspect.ModeEnvoy
case mgmProto.TransparentProxyMode_TP_MODE_EXTERNAL:
config.Mode = inspect.ModeExternal
default:
config.Mode = inspect.ModeBuiltin
}
if cfg.ExternalProxyUrl != "" {
u, err := url.Parse(cfg.ExternalProxyUrl)
if err != nil {
return inspect.Config{}, fmt.Errorf("parse external proxy URL: %w", err)
}
config.ExternalURL = u
}
for _, s := range cfg.RedirectSources {
prefix, err := netip.ParsePrefix(s)
if err != nil {
return inspect.Config{}, fmt.Errorf("parse redirect source %q: %w", s, err)
}
config.RedirectSources = append(config.RedirectSources, prefix)
}
for _, p := range cfg.RedirectPorts {
config.RedirectPorts = append(config.RedirectPorts, uint16(p))
}
// TPROXY listen port: fixed default, overridable via env var.
if config.Mode == inspect.ModeBuiltin {
port := uint16(inspect.DefaultTProxyPort)
if v := os.Getenv("NB_TPROXY_PORT"); v != "" {
if p, err := strconv.ParseUint(v, 10, 16); err == nil {
port = uint16(p)
} else {
log.Warnf("invalid NB_TPROXY_PORT %q, using default %d", v, inspect.DefaultTProxyPort)
}
}
config.ListenAddr = netip.AddrPortFrom(netip.IPv4Unspecified(), port)
}
for _, r := range cfg.Rules {
rule, err := toProxyRule(r)
if err != nil {
return inspect.Config{}, fmt.Errorf("parse rule %q: %w", r.Id, err)
}
config.Rules = append(config.Rules, rule)
}
if cfg.Icap != nil {
icapCfg, err := toICAPConfig(cfg.Icap)
if err != nil {
return inspect.Config{}, fmt.Errorf("parse ICAP config: %w", err)
}
config.ICAP = icapCfg
}
if len(cfg.CaCertPem) > 0 && len(cfg.CaKeyPem) > 0 {
tlsCfg, err := parseTLSConfig(cfg.CaCertPem, cfg.CaKeyPem)
if err != nil {
return inspect.Config{}, fmt.Errorf("parse TLS config: %w", err)
}
config.TLS = tlsCfg
}
if config.Mode == inspect.ModeEnvoy {
envCfg := &inspect.EnvoyConfig{
BinaryPath: cfg.EnvoyBinaryPath,
AdminPort: uint16(cfg.EnvoyAdminPort),
}
if cfg.EnvoySnippets != nil {
envCfg.Snippets = &inspect.EnvoySnippets{
HTTPFilters: cfg.EnvoySnippets.HttpFilters,
NetworkFilters: cfg.EnvoySnippets.NetworkFilters,
Clusters: cfg.EnvoySnippets.Clusters,
}
}
config.Envoy = envCfg
}
return config, nil
}
func toProxyRule(r *mgmProto.TransparentProxyRule) (inspect.Rule, error) {
rule := inspect.Rule{
ID: id.RuleID(r.Id),
Action: toProxyAction(r.Action),
Priority: int(r.Priority),
}
for _, d := range r.Domains {
dom, err := domain.FromString(d)
if err != nil {
return inspect.Rule{}, fmt.Errorf("parse domain %q: %w", d, err)
}
rule.Domains = append(rule.Domains, dom)
}
for _, n := range r.Networks {
prefix, err := netip.ParsePrefix(n)
if err != nil {
return inspect.Rule{}, fmt.Errorf("parse network %q: %w", n, err)
}
rule.Networks = append(rule.Networks, prefix)
}
for _, p := range r.Ports {
rule.Ports = append(rule.Ports, uint16(p))
}
for _, proto := range r.Protocols {
rule.Protocols = append(rule.Protocols, toProxyProtoType(proto))
}
rule.Paths = r.Paths
return rule, nil
}
func toProxyProtoType(p mgmProto.TransparentProxyProtocol) inspect.ProtoType {
switch p {
case mgmProto.TransparentProxyProtocol_TP_PROTO_HTTP:
return inspect.ProtoHTTP
case mgmProto.TransparentProxyProtocol_TP_PROTO_HTTPS:
return inspect.ProtoHTTPS
case mgmProto.TransparentProxyProtocol_TP_PROTO_H2:
return inspect.ProtoH2
case mgmProto.TransparentProxyProtocol_TP_PROTO_H3:
return inspect.ProtoH3
case mgmProto.TransparentProxyProtocol_TP_PROTO_WEBSOCKET:
return inspect.ProtoWebSocket
case mgmProto.TransparentProxyProtocol_TP_PROTO_OTHER:
return inspect.ProtoOther
default:
return ""
}
}
func toProxyAction(a mgmProto.TransparentProxyAction) inspect.Action {
switch a {
case mgmProto.TransparentProxyAction_TP_ACTION_BLOCK:
return inspect.ActionBlock
case mgmProto.TransparentProxyAction_TP_ACTION_INSPECT:
return inspect.ActionInspect
default:
return inspect.ActionAllow
}
}
func toICAPConfig(cfg *mgmProto.TransparentProxyICAPConfig) (*inspect.ICAPConfig, error) {
icap := &inspect.ICAPConfig{
MaxConnections: int(cfg.MaxConnections),
}
if cfg.ReqmodUrl != "" {
u, err := url.Parse(cfg.ReqmodUrl)
if err != nil {
return nil, fmt.Errorf("parse ICAP reqmod URL: %w", err)
}
icap.ReqModURL = u
}
if cfg.RespmodUrl != "" {
u, err := url.Parse(cfg.RespmodUrl)
if err != nil {
return nil, fmt.Errorf("parse ICAP respmod URL: %w", err)
}
icap.RespModURL = u
}
return icap, nil
}
func parseTLSConfig(certPEM, keyPEM []byte) (*inspect.TLSConfig, error) {
block, _ := pem.Decode(certPEM)
if block == nil {
return nil, fmt.Errorf("decode CA certificate PEM")
}
cert, err := x509.ParseCertificate(block.Bytes)
if err != nil {
return nil, fmt.Errorf("parse CA certificate: %w", err)
}
keyBlock, _ := pem.Decode(keyPEM)
if keyBlock == nil {
return nil, fmt.Errorf("decode CA key PEM")
}
key, err := x509.ParseECPrivateKey(keyBlock.Bytes)
if err != nil {
// Try PKCS8 as fallback
pkcs8Key, pkcs8Err := x509.ParsePKCS8PrivateKey(keyBlock.Bytes)
if pkcs8Err != nil {
return nil, fmt.Errorf("parse CA private key (tried EC and PKCS8): %w", err)
}
return &inspect.TLSConfig{CA: cert, CAKey: pkcs8Key}, nil
}
return &inspect.TLSConfig{CA: cert, CAKey: key}, nil
}
// resolveInspectionCA sets the TLS config on the proxy config using priority:
// 1. Local config file CA (InspectionCACertPath/InspectionCAKeyPath)
// 2. Management-pushed CA (already parsed in toProxyConfig)
// 3. Auto-generated self-signed CA (ephemeral, for testing)
// Local always wins to prevent a compromised management server from injecting a CA.
func (e *Engine) resolveInspectionCA(config *inspect.Config) {
// 1. Local CA from config file or env vars
certPath := e.config.InspectionCACertPath
keyPath := e.config.InspectionCAKeyPath
if certPath == "" {
certPath = os.Getenv("NB_INSPECTION_CA_CERT")
}
if keyPath == "" {
keyPath = os.Getenv("NB_INSPECTION_CA_KEY")
}
if certPath != "" && keyPath != "" {
certPEM, err := os.ReadFile(certPath)
if err != nil {
log.Errorf("read local inspection CA cert %s: %v", certPath, err)
return
}
keyPEM, err := os.ReadFile(keyPath)
if err != nil {
log.Errorf("read local inspection CA key %s: %v", keyPath, err)
return
}
tlsCfg, err := parseTLSConfig(certPEM, keyPEM)
if err != nil {
log.Errorf("parse local inspection CA: %v", err)
return
}
log.Infof("inspect: using local CA from %s", certPath)
config.TLS = tlsCfg
return
}
// 2. Management-pushed CA (already set by toProxyConfig)
if config.TLS != nil {
log.Infof("inspect: using management-pushed CA")
return
}
// 3. Auto-generate self-signed CA for testing / accept-cert UX
tlsCfg, err := generateSelfSignedCA()
if err != nil {
log.Errorf("generate self-signed inspection CA: %v", err)
return
}
log.Infof("inspect: using auto-generated self-signed CA (clients will see certificate warnings)")
config.TLS = tlsCfg
}
// generateSelfSignedCA creates an ephemeral ECDSA P-256 CA certificate.
// Clients will see certificate warnings but can choose to accept.
func generateSelfSignedCA() (*inspect.TLSConfig, error) {
key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
if err != nil {
return nil, fmt.Errorf("generate CA key: %w", err)
}
serial, err := rand.Int(rand.Reader, new(big.Int).Lsh(big.NewInt(1), 128))
if err != nil {
return nil, fmt.Errorf("generate serial: %w", err)
}
template := &x509.Certificate{
SerialNumber: serial,
Subject: pkix.Name{
Organization: []string{"NetBird Transparent Proxy"},
CommonName: "NetBird Inspection CA (auto-generated)",
},
NotBefore: time.Now().Add(-1 * time.Hour),
NotAfter: time.Now().Add(365 * 24 * time.Hour),
KeyUsage: x509.KeyUsageCertSign | x509.KeyUsageCRLSign,
BasicConstraintsValid: true,
IsCA: true,
MaxPathLen: 0,
}
certDER, err := x509.CreateCertificate(rand.Reader, template, template, &key.PublicKey, key)
if err != nil {
return nil, fmt.Errorf("create CA certificate: %w", err)
}
cert, err := x509.ParseCertificate(certDER)
if err != nil {
return nil, fmt.Errorf("parse generated CA certificate: %w", err)
}
return &inspect.TLSConfig{CA: cert, CAKey: key}, nil
}

View File

@@ -1,279 +0,0 @@
package internal
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/client/inspect"
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
)
func TestToProxyConfig_Basic(t *testing.T) {
cfg := &mgmProto.TransparentProxyConfig{
Enabled: true,
Mode: mgmProto.TransparentProxyMode_TP_MODE_BUILTIN,
DefaultAction: mgmProto.TransparentProxyAction_TP_ACTION_ALLOW,
RedirectSources: []string{
"10.0.0.0/24",
"192.168.1.0/24",
},
RedirectPorts: []uint32{80, 443},
Rules: []*mgmProto.TransparentProxyRule{
{
Id: "block-evil",
Domains: []string{"*.evil.com", "malware.example.com"},
Action: mgmProto.TransparentProxyAction_TP_ACTION_BLOCK,
Priority: 1,
},
{
Id: "inspect-internal",
Domains: []string{"*.internal.corp"},
Networks: []string{"10.1.0.0/16"},
Ports: []uint32{443, 8443},
Action: mgmProto.TransparentProxyAction_TP_ACTION_INSPECT,
Priority: 10,
},
},
ListenPort: 8443,
}
config, err := toProxyConfig(cfg)
require.NoError(t, err)
assert.True(t, config.Enabled)
assert.Equal(t, inspect.ModeBuiltin, config.Mode)
assert.Equal(t, inspect.ActionAllow, config.DefaultAction)
require.Len(t, config.RedirectSources, 2)
assert.Equal(t, "10.0.0.0/24", config.RedirectSources[0].String())
assert.Equal(t, "192.168.1.0/24", config.RedirectSources[1].String())
require.Len(t, config.RedirectPorts, 2)
assert.Equal(t, uint16(80), config.RedirectPorts[0])
assert.Equal(t, uint16(443), config.RedirectPorts[1])
require.Len(t, config.Rules, 2)
// Rule 1: block evil domains
assert.Equal(t, "block-evil", string(config.Rules[0].ID))
assert.Equal(t, inspect.ActionBlock, config.Rules[0].Action)
assert.Equal(t, 1, config.Rules[0].Priority)
require.Len(t, config.Rules[0].Domains, 2)
assert.Equal(t, "*.evil.com", config.Rules[0].Domains[0].PunycodeString())
assert.Equal(t, "malware.example.com", config.Rules[0].Domains[1].PunycodeString())
// Rule 2: inspect internal
assert.Equal(t, "inspect-internal", string(config.Rules[1].ID))
assert.Equal(t, inspect.ActionInspect, config.Rules[1].Action)
assert.Equal(t, 10, config.Rules[1].Priority)
require.Len(t, config.Rules[1].Networks, 1)
assert.Equal(t, "10.1.0.0/16", config.Rules[1].Networks[0].String())
require.Len(t, config.Rules[1].Ports, 2)
// Listen address
assert.True(t, config.ListenAddr.IsValid())
assert.Equal(t, uint16(8443), config.ListenAddr.Port())
}
func TestToProxyConfig_ExternalMode(t *testing.T) {
cfg := &mgmProto.TransparentProxyConfig{
Enabled: true,
Mode: mgmProto.TransparentProxyMode_TP_MODE_EXTERNAL,
ExternalProxyUrl: "http://proxy.corp:8080",
DefaultAction: mgmProto.TransparentProxyAction_TP_ACTION_BLOCK,
}
config, err := toProxyConfig(cfg)
require.NoError(t, err)
assert.Equal(t, inspect.ModeExternal, config.Mode)
assert.Equal(t, inspect.ActionBlock, config.DefaultAction)
require.NotNil(t, config.ExternalURL)
assert.Equal(t, "http", config.ExternalURL.Scheme)
assert.Equal(t, "proxy.corp:8080", config.ExternalURL.Host)
}
func TestToProxyConfig_ICAP(t *testing.T) {
cfg := &mgmProto.TransparentProxyConfig{
Enabled: true,
Icap: &mgmProto.TransparentProxyICAPConfig{
ReqmodUrl: "icap://icap-server:1344/reqmod",
RespmodUrl: "icap://icap-server:1344/respmod",
MaxConnections: 16,
},
}
config, err := toProxyConfig(cfg)
require.NoError(t, err)
require.NotNil(t, config.ICAP)
assert.Equal(t, "icap", config.ICAP.ReqModURL.Scheme)
assert.Equal(t, "icap-server:1344", config.ICAP.ReqModURL.Host)
assert.Equal(t, "/reqmod", config.ICAP.ReqModURL.Path)
assert.Equal(t, "/respmod", config.ICAP.RespModURL.Path)
assert.Equal(t, 16, config.ICAP.MaxConnections)
}
func TestToProxyConfig_Empty(t *testing.T) {
cfg := &mgmProto.TransparentProxyConfig{
Enabled: true,
}
config, err := toProxyConfig(cfg)
require.NoError(t, err)
assert.True(t, config.Enabled)
assert.Equal(t, inspect.ModeBuiltin, config.Mode)
assert.Equal(t, inspect.ActionAllow, config.DefaultAction)
assert.Empty(t, config.RedirectSources)
assert.Empty(t, config.RedirectPorts)
assert.Empty(t, config.Rules)
assert.Nil(t, config.ICAP)
assert.Nil(t, config.TLS)
assert.False(t, config.ListenAddr.IsValid())
}
func TestToProxyConfig_InvalidSource(t *testing.T) {
cfg := &mgmProto.TransparentProxyConfig{
Enabled: true,
RedirectSources: []string{"not-a-cidr"},
}
_, err := toProxyConfig(cfg)
require.Error(t, err)
assert.Contains(t, err.Error(), "parse redirect source")
}
func TestToProxyConfig_InvalidNetwork(t *testing.T) {
cfg := &mgmProto.TransparentProxyConfig{
Enabled: true,
Rules: []*mgmProto.TransparentProxyRule{
{
Id: "bad",
Networks: []string{"not-a-cidr"},
},
},
}
_, err := toProxyConfig(cfg)
require.Error(t, err)
assert.Contains(t, err.Error(), "parse network")
}
func TestToProxyAction(t *testing.T) {
assert.Equal(t, inspect.ActionAllow, toProxyAction(mgmProto.TransparentProxyAction_TP_ACTION_ALLOW))
assert.Equal(t, inspect.ActionBlock, toProxyAction(mgmProto.TransparentProxyAction_TP_ACTION_BLOCK))
assert.Equal(t, inspect.ActionInspect, toProxyAction(mgmProto.TransparentProxyAction_TP_ACTION_INSPECT))
// Unknown defaults to allow
assert.Equal(t, inspect.ActionAllow, toProxyAction(99))
}
func TestParseUDPPacket_IPv4(t *testing.T) {
// Build a minimal IPv4/UDP packet: 20-byte IPv4 header + 8-byte UDP header + payload
packet := make([]byte, 20+8+4)
// IPv4 header: version=4, IHL=5 (20 bytes)
packet[0] = 0x45
// Protocol = UDP (17)
packet[9] = 17
// Source IP: 10.0.0.1
packet[12], packet[13], packet[14], packet[15] = 10, 0, 0, 1
// Dest IP: 192.168.1.1
packet[16], packet[17], packet[18], packet[19] = 192, 168, 1, 1
// UDP source port: 54321 (0xD431)
packet[20] = 0xD4
packet[21] = 0x31
// UDP dest port: 443 (0x01BB)
packet[22] = 0x01
packet[23] = 0xBB
// Payload
packet[28] = 0xDE
packet[29] = 0xAD
packet[30] = 0xBE
packet[31] = 0xEF
srcIP, dstIP, dstPort, payload, ok := parseUDPPacket(packet)
require.True(t, ok)
assert.Equal(t, "10.0.0.1", srcIP.String())
assert.Equal(t, "192.168.1.1", dstIP.String())
assert.Equal(t, uint16(443), dstPort)
assert.Equal(t, []byte{0xDE, 0xAD, 0xBE, 0xEF}, payload)
}
func TestParseUDPPacket_IPv6(t *testing.T) {
// Build a minimal IPv6/UDP packet: 40-byte IPv6 header + 8-byte UDP header + payload
packet := make([]byte, 40+8+4)
// Version = 6 (0x60 in high nibble)
packet[0] = 0x60
// Payload length: 8 (UDP header) + 4 (payload)
packet[4] = 0
packet[5] = 12
// Next header: UDP (17)
packet[6] = 17
// Source: 2001:db8::1
packet[8] = 0x20
packet[9] = 0x01
packet[10] = 0x0d
packet[11] = 0xb8
packet[23] = 0x01
// Dest: 2001:db8::2
packet[24] = 0x20
packet[25] = 0x01
packet[26] = 0x0d
packet[27] = 0xb8
packet[39] = 0x02
// UDP source port: 54321 (0xD431)
packet[40] = 0xD4
packet[41] = 0x31
// UDP dest port: 443 (0x01BB)
packet[42] = 0x01
packet[43] = 0xBB
// Payload
packet[48] = 0xCA
packet[49] = 0xFE
packet[50] = 0xBA
packet[51] = 0xBE
srcIP, dstIP, dstPort, payload, ok := parseUDPPacket(packet)
require.True(t, ok)
assert.Equal(t, "2001:db8::1", srcIP.String())
assert.Equal(t, "2001:db8::2", dstIP.String())
assert.Equal(t, uint16(443), dstPort)
assert.Equal(t, []byte{0xCA, 0xFE, 0xBA, 0xBE}, payload)
}
func TestParseUDPPacket_TooShort(t *testing.T) {
_, _, _, _, ok := parseUDPPacket(nil)
assert.False(t, ok)
_, _, _, _, ok = parseUDPPacket([]byte{0x45, 0x00})
assert.False(t, ok)
}
func TestParseUDPPacket_IPv6ExtensionHeader(t *testing.T) {
// IPv6 with next header != UDP should be rejected
packet := make([]byte, 48)
packet[0] = 0x60
packet[6] = 6 // TCP, not UDP
_, _, _, _, ok := parseUDPPacket(packet)
assert.False(t, ok, "should reject IPv6 packets with non-UDP next header")
}
func TestParseUDPPacket_IPv4MappedIPv6(t *testing.T) {
// IPv4 packet with normal addresses should Unmap correctly
packet := make([]byte, 28)
packet[0] = 0x45
packet[9] = 17
packet[12], packet[13], packet[14], packet[15] = 127, 0, 0, 1
packet[16], packet[17], packet[18], packet[19] = 10, 0, 0, 1
packet[22] = 0x01
packet[23] = 0xBB
srcIP, dstIP, _, _, ok := parseUDPPacket(packet)
require.True(t, ok)
assert.True(t, srcIP.Is4(), "should be plain IPv4, not mapped")
assert.True(t, dstIP.Is4(), "should be plain IPv4, not mapped")
}

View File

@@ -0,0 +1,247 @@
package internal
import (
"context"
"errors"
"fmt"
"net/netip"
log "github.com/sirupsen/logrus"
firewallManager "github.com/netbirdio/netbird/client/firewall/manager"
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
sshauth "github.com/netbirdio/netbird/client/ssh/auth"
vncserver "github.com/netbirdio/netbird/client/vnc/server"
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
sshuserhash "github.com/netbirdio/netbird/shared/sshauth"
)
const (
vncExternalPort uint16 = 5900
vncInternalPort uint16 = 25900
)
type vncServer interface {
Start(ctx context.Context, addr netip.AddrPort, network netip.Prefix) error
Stop() error
}
func (e *Engine) setupVNCPortRedirection() error {
if e.firewall == nil || e.wgInterface == nil {
return nil
}
localAddr := e.wgInterface.Address().IP
if !localAddr.IsValid() {
return errors.New("invalid local NetBird address")
}
if err := e.firewall.AddInboundDNAT(localAddr, firewallManager.ProtocolTCP, vncExternalPort, vncInternalPort); err != nil {
return fmt.Errorf("add VNC port redirection: %w", err)
}
log.Infof("VNC port redirection: %s:%d -> %s:%d", localAddr, vncExternalPort, localAddr, vncInternalPort)
return nil
}
func (e *Engine) cleanupVNCPortRedirection() error {
if e.firewall == nil || e.wgInterface == nil {
return nil
}
localAddr := e.wgInterface.Address().IP
if !localAddr.IsValid() {
return errors.New("invalid local NetBird address")
}
if err := e.firewall.RemoveInboundDNAT(localAddr, firewallManager.ProtocolTCP, vncExternalPort, vncInternalPort); err != nil {
return fmt.Errorf("remove VNC port redirection: %w", err)
}
return nil
}
// updateVNC handles starting/stopping the VNC server based on the config flag.
// sshConf provides the JWT identity provider config (shared with SSH).
func (e *Engine) updateVNC(sshConf *mgmProto.SSHConfig) error {
if !e.config.ServerVNCAllowed {
if e.vncSrv != nil {
log.Info("VNC server disabled, stopping")
}
return e.stopVNCServer()
}
if e.config.BlockInbound {
log.Info("VNC server disabled because inbound connections are blocked")
return e.stopVNCServer()
}
if e.vncSrv != nil {
// Update JWT config on existing server in case management sent new config.
e.updateVNCServerJWT(sshConf)
return nil
}
return e.startVNCServer(sshConf)
}
func (e *Engine) startVNCServer(sshConf *mgmProto.SSHConfig) error {
if e.wgInterface == nil {
return errors.New("wg interface not initialized")
}
capturer, injector := newPlatformVNC()
if capturer == nil || injector == nil {
log.Debug("VNC server not supported on this platform")
return nil
}
netbirdIP := e.wgInterface.Address().IP
srv := vncserver.New(capturer, injector, "")
if vncNeedsServiceMode() {
log.Info("VNC: running in Session 0, enabling service mode (agent proxy)")
srv.SetServiceMode(true)
}
// Configure VNC authentication.
if e.config.DisableVNCAuth != nil && *e.config.DisableVNCAuth {
log.Info("VNC: authentication disabled by config")
srv.SetDisableAuth(true)
} else if protoJWT := sshConf.GetJwtConfig(); protoJWT != nil {
audiences := protoJWT.GetAudiences()
if len(audiences) == 0 && protoJWT.GetAudience() != "" {
audiences = []string{protoJWT.GetAudience()}
}
srv.SetJWTConfig(&vncserver.JWTConfig{
Issuer: protoJWT.GetIssuer(),
Audiences: audiences,
KeysLocation: protoJWT.GetKeysLocation(),
MaxTokenAge: protoJWT.GetMaxTokenAge(),
})
log.Debugf("VNC: JWT authentication configured (issuer=%s)", protoJWT.GetIssuer())
}
if netstackNet := e.wgInterface.GetNet(); netstackNet != nil {
srv.SetNetstackNet(netstackNet)
}
listenAddr := netip.AddrPortFrom(netbirdIP, vncInternalPort)
network := e.wgInterface.Address().Network
if err := srv.Start(e.ctx, listenAddr, network); err != nil {
return fmt.Errorf("start VNC server: %w", err)
}
e.vncSrv = srv
if registrar, ok := e.firewall.(interface {
RegisterNetstackService(protocol nftypes.Protocol, port uint16)
}); ok {
registrar.RegisterNetstackService(nftypes.TCP, vncInternalPort)
log.Debugf("registered VNC service for TCP:%d", vncInternalPort)
}
if err := e.setupVNCPortRedirection(); err != nil {
log.Warnf("setup VNC port redirection: %v", err)
}
log.Info("VNC server enabled")
return nil
}
// updateVNCServerJWT configures the JWT validation for the VNC server using
// the same JWT config as SSH (same identity provider).
func (e *Engine) updateVNCServerJWT(sshConf *mgmProto.SSHConfig) {
if e.vncSrv == nil {
return
}
vncSrv, ok := e.vncSrv.(*vncserver.Server)
if !ok {
return
}
if e.config.DisableVNCAuth != nil && *e.config.DisableVNCAuth {
vncSrv.SetDisableAuth(true)
return
}
protoJWT := sshConf.GetJwtConfig()
if protoJWT == nil {
return
}
audiences := protoJWT.GetAudiences()
if len(audiences) == 0 && protoJWT.GetAudience() != "" {
audiences = []string{protoJWT.GetAudience()}
}
vncSrv.SetJWTConfig(&vncserver.JWTConfig{
Issuer: protoJWT.GetIssuer(),
Audiences: audiences,
KeysLocation: protoJWT.GetKeysLocation(),
MaxTokenAge: protoJWT.GetMaxTokenAge(),
})
}
// updateVNCServerAuth updates VNC fine-grained access control from management.
func (e *Engine) updateVNCServerAuth(vncAuth *mgmProto.VNCAuth) {
if vncAuth == nil || e.vncSrv == nil {
return
}
vncSrv, ok := e.vncSrv.(*vncserver.Server)
if !ok {
return
}
protoUsers := vncAuth.GetAuthorizedUsers()
authorizedUsers := make([]sshuserhash.UserIDHash, len(protoUsers))
for i, hash := range protoUsers {
if len(hash) != 16 {
log.Warnf("invalid VNC auth hash length %d, expected 16", len(hash))
return
}
authorizedUsers[i] = sshuserhash.UserIDHash(hash)
}
machineUsers := make(map[string][]uint32)
for osUser, indexes := range vncAuth.GetMachineUsers() {
machineUsers[osUser] = indexes.GetIndexes()
}
vncSrv.UpdateVNCAuth(&sshauth.Config{
UserIDClaim: vncAuth.GetUserIDClaim(),
AuthorizedUsers: authorizedUsers,
MachineUsers: machineUsers,
})
}
// GetVNCServerStatus returns whether the VNC server is running.
func (e *Engine) GetVNCServerStatus() bool {
return e.vncSrv != nil
}
func (e *Engine) stopVNCServer() error {
if e.vncSrv == nil {
return nil
}
if err := e.cleanupVNCPortRedirection(); err != nil {
log.Warnf("cleanup VNC port redirection: %v", err)
}
if registrar, ok := e.firewall.(interface {
UnregisterNetstackService(protocol nftypes.Protocol, port uint16)
}); ok {
registrar.UnregisterNetstackService(nftypes.TCP, vncInternalPort)
}
log.Info("stopping VNC server")
err := e.vncSrv.Stop()
e.vncSrv = nil
if err != nil {
return fmt.Errorf("stop VNC server: %w", err)
}
return nil
}

View File

@@ -0,0 +1,23 @@
//go:build darwin && !ios
package internal
import (
log "github.com/sirupsen/logrus"
vncserver "github.com/netbirdio/netbird/client/vnc/server"
)
func newPlatformVNC() (vncserver.ScreenCapturer, vncserver.InputInjector) {
capturer := vncserver.NewMacPoller()
injector, err := vncserver.NewMacInputInjector()
if err != nil {
log.Debugf("VNC: macOS input injector: %v", err)
return capturer, &vncserver.StubInputInjector{}
}
return capturer, injector
}
func vncNeedsServiceMode() bool {
return false
}

View File

@@ -0,0 +1,13 @@
//go:build !windows && !darwin && !freebsd && !(linux && !android)
package internal
import vncserver "github.com/netbirdio/netbird/client/vnc/server"
func newPlatformVNC() (vncserver.ScreenCapturer, vncserver.InputInjector) {
return nil, nil
}
func vncNeedsServiceMode() bool {
return false
}

View File

@@ -0,0 +1,13 @@
//go:build windows
package internal
import vncserver "github.com/netbirdio/netbird/client/vnc/server"
func newPlatformVNC() (vncserver.ScreenCapturer, vncserver.InputInjector) {
return vncserver.NewDesktopCapturer(), vncserver.NewWindowsInputInjector()
}
func vncNeedsServiceMode() bool {
return vncserver.GetCurrentSessionID() == 0
}

View File

@@ -0,0 +1,23 @@
//go:build (linux && !android) || freebsd
package internal
import (
log "github.com/sirupsen/logrus"
vncserver "github.com/netbirdio/netbird/client/vnc/server"
)
func newPlatformVNC() (vncserver.ScreenCapturer, vncserver.InputInjector) {
capturer := vncserver.NewX11Poller("")
injector, err := vncserver.NewX11InputInjector("")
if err != nil {
log.Debugf("VNC: X11 input injector: %v", err)
return capturer, &vncserver.StubInputInjector{}
}
return capturer, injector
}
func vncNeedsServiceMode() bool {
return false
}

View File

@@ -64,11 +64,13 @@ type ConfigInput struct {
StateFilePath string
PreSharedKey *string
ServerSSHAllowed *bool
ServerVNCAllowed *bool
EnableSSHRoot *bool
EnableSSHSFTP *bool
EnableSSHLocalPortForwarding *bool
EnableSSHRemotePortForwarding *bool
DisableSSHAuth *bool
DisableVNCAuth *bool
SSHJWTCacheTTL *int
NATExternalIPs []string
CustomDNSAddress []byte
@@ -97,9 +99,6 @@ type ConfigInput struct {
LazyConnectionEnabled *bool
MTU *uint16
InspectionCACertPath string
InspectionCAKeyPath string
}
// Config Configuration type
@@ -117,11 +116,13 @@ type Config struct {
RosenpassEnabled bool
RosenpassPermissive bool
ServerSSHAllowed *bool
ServerVNCAllowed *bool
EnableSSHRoot *bool
EnableSSHSFTP *bool
EnableSSHLocalPortForwarding *bool
EnableSSHRemotePortForwarding *bool
DisableSSHAuth *bool
DisableVNCAuth *bool
SSHJWTCacheTTL *int
DisableClientRoutes bool
@@ -174,13 +175,6 @@ type Config struct {
LazyConnectionEnabled bool
MTU uint16
// InspectionCACertPath is the path to a PEM CA certificate for transparent proxy MITM.
// Local CA takes priority over management-pushed CA.
InspectionCACertPath string
// InspectionCAKeyPath is the path to the PEM CA private key for transparent proxy MITM.
InspectionCAKeyPath string
}
var ConfigDirOverride string
@@ -425,6 +419,21 @@ func (config *Config) apply(input ConfigInput) (updated bool, err error) {
updated = true
}
if input.ServerVNCAllowed != nil {
if config.ServerVNCAllowed == nil || *input.ServerVNCAllowed != *config.ServerVNCAllowed {
if *input.ServerVNCAllowed {
log.Infof("enabling VNC server")
} else {
log.Infof("disabling VNC server")
}
config.ServerVNCAllowed = input.ServerVNCAllowed
updated = true
}
} else if config.ServerVNCAllowed == nil {
config.ServerVNCAllowed = util.True()
updated = true
}
if input.EnableSSHRoot != nil && input.EnableSSHRoot != config.EnableSSHRoot {
if *input.EnableSSHRoot {
log.Infof("enabling SSH root login")
@@ -475,6 +484,16 @@ func (config *Config) apply(input ConfigInput) (updated bool, err error) {
updated = true
}
if input.DisableVNCAuth != nil && input.DisableVNCAuth != config.DisableVNCAuth {
if *input.DisableVNCAuth {
log.Infof("disabling VNC authentication")
} else {
log.Infof("enabling VNC authentication")
}
config.DisableVNCAuth = input.DisableVNCAuth
updated = true
}
if input.SSHJWTCacheTTL != nil && input.SSHJWTCacheTTL != config.SSHJWTCacheTTL {
log.Infof("updating SSH JWT cache TTL to %d seconds", *input.SSHJWTCacheTTL)
config.SSHJWTCacheTTL = input.SSHJWTCacheTTL
@@ -613,17 +632,6 @@ func (config *Config) apply(input ConfigInput) (updated bool, err error) {
updated = true
}
if input.InspectionCACertPath != "" && input.InspectionCACertPath != config.InspectionCACertPath {
log.Infof("updating inspection CA cert path to %s", input.InspectionCACertPath)
config.InspectionCACertPath = input.InspectionCACertPath
updated = true
}
if input.InspectionCAKeyPath != "" && input.InspectionCAKeyPath != config.InspectionCAKeyPath {
log.Infof("updating inspection CA key path to %s", input.InspectionCAKeyPath)
config.InspectionCAKeyPath = input.InspectionCAKeyPath
updated = true
}
return updated, nil
}

View File

@@ -168,6 +168,7 @@ func (m *DefaultManager) setupAndroidRoutes(config ManagerConfig) {
NetworkType: route.IPv4Network,
}
cr = append(cr, fakeIPRoute)
m.notifier.SetFakeIPRoute(fakeIPRoute)
}
m.notifier.SetInitialClientRoutes(cr, routesForComparison)

View File

@@ -16,6 +16,7 @@ import (
type Notifier struct {
initialRoutes []*route.Route
currentRoutes []*route.Route
fakeIPRoute *route.Route
listener listener.NetworkChangeListener
listenerMux sync.Mutex
@@ -31,13 +32,17 @@ func (n *Notifier) SetListener(listener listener.NetworkChangeListener) {
n.listener = listener
}
// SetInitialClientRoutes stores the full initial route set (including fake IP blocks)
// and a separate comparison set (without fake IP blocks) for diff detection.
// SetInitialClientRoutes stores the initial route sets for TUN configuration.
func (n *Notifier) SetInitialClientRoutes(initialRoutes []*route.Route, routesForComparison []*route.Route) {
n.initialRoutes = filterStatic(initialRoutes)
n.currentRoutes = filterStatic(routesForComparison)
}
// SetFakeIPRoute stores the fake IP route to be included in every TUN rebuild.
func (n *Notifier) SetFakeIPRoute(r *route.Route) {
n.fakeIPRoute = r
}
func (n *Notifier) OnNewRoutes(idMap route.HAMap) {
var newRoutes []*route.Route
for _, routes := range idMap {
@@ -69,7 +74,9 @@ func (n *Notifier) notify() {
}
allRoutes := slices.Clone(n.currentRoutes)
allRoutes = append(allRoutes, n.extraInitialRoutes()...)
if n.fakeIPRoute != nil {
allRoutes = append(allRoutes, n.fakeIPRoute)
}
routeStrings := n.routesToStrings(allRoutes)
sort.Strings(routeStrings)
@@ -78,23 +85,6 @@ func (n *Notifier) notify() {
}(n.listener)
}
// extraInitialRoutes returns initialRoutes whose network prefix is absent
// from currentRoutes (e.g. the fake IP block added at setup time).
func (n *Notifier) extraInitialRoutes() []*route.Route {
currentNets := make(map[netip.Prefix]struct{}, len(n.currentRoutes))
for _, r := range n.currentRoutes {
currentNets[r.Network] = struct{}{}
}
var extra []*route.Route
for _, r := range n.initialRoutes {
if _, ok := currentNets[r.Network]; !ok {
extra = append(extra, r)
}
}
return extra
}
func filterStatic(routes []*route.Route) []*route.Route {
out := make([]*route.Route, 0, len(routes))
for _, r := range routes {

View File

@@ -34,6 +34,10 @@ func (n *Notifier) SetInitialClientRoutes([]*route.Route, []*route.Route) {
// iOS doesn't care about initial routes
}
func (n *Notifier) SetFakeIPRoute(*route.Route) {
// Not used on iOS
}
func (n *Notifier) OnNewRoutes(route.HAMap) {
// Not used on iOS
}

View File

@@ -23,6 +23,10 @@ func (n *Notifier) SetInitialClientRoutes([]*route.Route, []*route.Route) {
// Not used on non-mobile platforms
}
func (n *Notifier) SetFakeIPRoute(*route.Route) {
// Not used on non-mobile platforms
}
func (n *Notifier) OnNewRoutes(idMap route.HAMap) {
// Not used on non-mobile platforms
}

File diff suppressed because it is too large Load Diff

View File

@@ -209,6 +209,9 @@ message LoginRequest {
optional bool enableSSHRemotePortForwarding = 37;
optional bool disableSSHAuth = 38;
optional int32 sshJWTCacheTTL = 39;
optional bool serverVNCAllowed = 41;
optional bool disableVNCAuth = 42;
}
message LoginResponse {
@@ -316,6 +319,10 @@ message GetConfigResponse {
bool disableSSHAuth = 25;
int32 sshJWTCacheTTL = 26;
bool serverVNCAllowed = 28;
bool disableVNCAuth = 29;
}
// PeerState contains the latest state of a peer
@@ -394,6 +401,11 @@ message SSHServerState {
repeated SSHSessionInfo sessions = 2;
}
// VNCServerState contains the latest state of the VNC server
message VNCServerState {
bool enabled = 1;
}
// FullStatus contains the full state held by the Status instance
message FullStatus {
ManagementState managementState = 1;
@@ -408,6 +420,7 @@ message FullStatus {
bool lazyConnectionEnabled = 9;
SSHServerState sshServerState = 10;
VNCServerState vncServerState = 11;
}
// Networks
@@ -677,6 +690,9 @@ message SetConfigRequest {
optional bool enableSSHRemotePortForwarding = 32;
optional bool disableSSHAuth = 33;
optional int32 sshJWTCacheTTL = 34;
optional bool serverVNCAllowed = 36;
optional bool disableVNCAuth = 37;
}
message SetConfigResponse{}

View File

@@ -366,6 +366,7 @@ func (s *Server) SetConfig(callerCtx context.Context, msg *proto.SetConfigReques
config.RosenpassPermissive = msg.RosenpassPermissive
config.DisableAutoConnect = msg.DisableAutoConnect
config.ServerSSHAllowed = msg.ServerSSHAllowed
config.ServerVNCAllowed = msg.ServerVNCAllowed
config.NetworkMonitor = msg.NetworkMonitor
config.DisableClientRoutes = msg.DisableClientRoutes
config.DisableServerRoutes = msg.DisableServerRoutes
@@ -382,6 +383,9 @@ func (s *Server) SetConfig(callerCtx context.Context, msg *proto.SetConfigReques
if msg.DisableSSHAuth != nil {
config.DisableSSHAuth = msg.DisableSSHAuth
}
if msg.DisableVNCAuth != nil {
config.DisableVNCAuth = msg.DisableVNCAuth
}
if msg.SshJWTCacheTTL != nil {
ttl := int(*msg.SshJWTCacheTTL)
config.SSHJWTCacheTTL = &ttl
@@ -1120,6 +1124,7 @@ func (s *Server) Status(
pbFullStatus := fullStatus.ToProto()
pbFullStatus.Events = s.statusRecorder.GetEventHistory()
pbFullStatus.SshServerState = s.getSSHServerState()
pbFullStatus.VncServerState = s.getVNCServerState()
statusResponse.FullStatus = pbFullStatus
}
@@ -1159,6 +1164,26 @@ func (s *Server) getSSHServerState() *proto.SSHServerState {
return sshServerState
}
// getVNCServerState retrieves the current VNC server state.
func (s *Server) getVNCServerState() *proto.VNCServerState {
s.mutex.Lock()
connectClient := s.connectClient
s.mutex.Unlock()
if connectClient == nil {
return nil
}
engine := connectClient.Engine()
if engine == nil {
return nil
}
return &proto.VNCServerState{
Enabled: engine.GetVNCServerStatus(),
}
}
// GetPeerSSHHostKey retrieves SSH host key for a specific peer
func (s *Server) GetPeerSSHHostKey(
ctx context.Context,
@@ -1500,6 +1525,11 @@ func (s *Server) GetConfig(ctx context.Context, req *proto.GetConfigRequest) (*p
disableSSHAuth = *cfg.DisableSSHAuth
}
disableVNCAuth := false
if cfg.DisableVNCAuth != nil {
disableVNCAuth = *cfg.DisableVNCAuth
}
sshJWTCacheTTL := int32(0)
if cfg.SSHJWTCacheTTL != nil {
sshJWTCacheTTL = int32(*cfg.SSHJWTCacheTTL)
@@ -1514,6 +1544,7 @@ func (s *Server) GetConfig(ctx context.Context, req *proto.GetConfigRequest) (*p
Mtu: int64(cfg.MTU),
DisableAutoConnect: cfg.DisableAutoConnect,
ServerSSHAllowed: *cfg.ServerSSHAllowed,
ServerVNCAllowed: cfg.ServerVNCAllowed != nil && *cfg.ServerVNCAllowed,
RosenpassEnabled: cfg.RosenpassEnabled,
RosenpassPermissive: cfg.RosenpassPermissive,
LazyConnectionEnabled: cfg.LazyConnectionEnabled,
@@ -1529,6 +1560,7 @@ func (s *Server) GetConfig(ctx context.Context, req *proto.GetConfigRequest) (*p
EnableSSHLocalPortForwarding: enableSSHLocalPortForwarding,
EnableSSHRemotePortForwarding: enableSSHRemotePortForwarding,
DisableSSHAuth: disableSSHAuth,
DisableVNCAuth: disableVNCAuth,
SshJWTCacheTTL: sshJWTCacheTTL,
}, nil
}

View File

@@ -58,6 +58,8 @@ func TestSetConfig_AllFieldsSaved(t *testing.T) {
rosenpassEnabled := true
rosenpassPermissive := true
serverSSHAllowed := true
serverVNCAllowed := true
disableVNCAuth := true
interfaceName := "utun100"
wireguardPort := int64(51820)
preSharedKey := "test-psk"
@@ -82,6 +84,8 @@ func TestSetConfig_AllFieldsSaved(t *testing.T) {
RosenpassEnabled: &rosenpassEnabled,
RosenpassPermissive: &rosenpassPermissive,
ServerSSHAllowed: &serverSSHAllowed,
ServerVNCAllowed: &serverVNCAllowed,
DisableVNCAuth: &disableVNCAuth,
InterfaceName: &interfaceName,
WireguardPort: &wireguardPort,
OptionalPreSharedKey: &preSharedKey,
@@ -125,6 +129,10 @@ func TestSetConfig_AllFieldsSaved(t *testing.T) {
require.Equal(t, rosenpassPermissive, cfg.RosenpassPermissive)
require.NotNil(t, cfg.ServerSSHAllowed)
require.Equal(t, serverSSHAllowed, *cfg.ServerSSHAllowed)
require.NotNil(t, cfg.ServerVNCAllowed)
require.Equal(t, serverVNCAllowed, *cfg.ServerVNCAllowed)
require.NotNil(t, cfg.DisableVNCAuth)
require.Equal(t, disableVNCAuth, *cfg.DisableVNCAuth)
require.Equal(t, interfaceName, cfg.WgIface)
require.Equal(t, int(wireguardPort), cfg.WgPort)
require.Equal(t, preSharedKey, cfg.PreSharedKey)
@@ -176,6 +184,8 @@ func verifyAllFieldsCovered(t *testing.T, req *proto.SetConfigRequest) {
"RosenpassEnabled": true,
"RosenpassPermissive": true,
"ServerSSHAllowed": true,
"ServerVNCAllowed": true,
"DisableVNCAuth": true,
"InterfaceName": true,
"WireguardPort": true,
"OptionalPreSharedKey": true,
@@ -236,6 +246,8 @@ func TestCLIFlags_MappedToSetConfig(t *testing.T) {
"enable-rosenpass": "RosenpassEnabled",
"rosenpass-permissive": "RosenpassPermissive",
"allow-server-ssh": "ServerSSHAllowed",
"allow-server-vnc": "ServerVNCAllowed",
"disable-vnc-auth": "DisableVNCAuth",
"interface-name": "InterfaceName",
"wireguard-port": "WireguardPort",
"preshared-key": "OptionalPreSharedKey",

View File

@@ -200,8 +200,8 @@ func newLsaString(s string) lsaString {
}
}
// generateS4UUserToken creates a Windows token using S4U authentication
// This is the exact approach OpenSSH for Windows uses for public key authentication
// generateS4UUserToken creates a Windows token using S4U authentication.
// This is the same approach OpenSSH for Windows uses for public key authentication.
func generateS4UUserToken(logger *log.Entry, username, domain string) (windows.Handle, error) {
userCpn := buildUserCpn(username, domain)

View File

@@ -507,27 +507,7 @@ func (s *Server) checkTokenAge(token *gojwt.Token, jwtConfig *JWTConfig) error {
maxTokenAge = DefaultJWTMaxTokenAge
}
claims, ok := token.Claims.(gojwt.MapClaims)
if !ok {
userID := extractUserID(token)
return fmt.Errorf("token has invalid claims format (user=%s)", userID)
}
iat, ok := claims["iat"].(float64)
if !ok {
userID := extractUserID(token)
return fmt.Errorf("token missing iat claim (user=%s)", userID)
}
issuedAt := time.Unix(int64(iat), 0)
tokenAge := time.Since(issuedAt)
maxAge := time.Duration(maxTokenAge) * time.Second
if tokenAge > maxAge {
userID := getUserIDFromClaims(claims)
return fmt.Errorf("token expired for user=%s: age=%v, max=%v", userID, tokenAge, maxAge)
}
return nil
return jwt.CheckTokenAge(token, time.Duration(maxTokenAge)*time.Second)
}
func (s *Server) extractAndValidateUser(token *gojwt.Token) (*auth.UserAuth, error) {
@@ -558,27 +538,7 @@ func (s *Server) hasSSHAccess(userAuth *auth.UserAuth) bool {
}
func extractUserID(token *gojwt.Token) string {
if token == nil {
return "unknown"
}
claims, ok := token.Claims.(gojwt.MapClaims)
if !ok {
return "unknown"
}
return getUserIDFromClaims(claims)
}
func getUserIDFromClaims(claims gojwt.MapClaims) string {
if sub, ok := claims["sub"].(string); ok && sub != "" {
return sub
}
if userID, ok := claims["user_id"].(string); ok && userID != "" {
return userID
}
if email, ok := claims["email"].(string); ok && email != "" {
return email
}
return "unknown"
return jwt.UserIDFromToken(token)
}
func (s *Server) parseTokenWithoutValidation(tokenString string) (map[string]interface{}, error) {

View File

@@ -130,6 +130,10 @@ type SSHServerStateOutput struct {
Sessions []SSHSessionOutput `json:"sessions" yaml:"sessions"`
}
type VNCServerStateOutput struct {
Enabled bool `json:"enabled" yaml:"enabled"`
}
type OutputOverview struct {
Peers PeersStateOutput `json:"peers" yaml:"peers"`
CliVersion string `json:"cliVersion" yaml:"cliVersion"`
@@ -151,6 +155,7 @@ type OutputOverview struct {
LazyConnectionEnabled bool `json:"lazyConnectionEnabled" yaml:"lazyConnectionEnabled"`
ProfileName string `json:"profileName" yaml:"profileName"`
SSHServerState SSHServerStateOutput `json:"sshServer" yaml:"sshServer"`
VNCServerState VNCServerStateOutput `json:"vncServer" yaml:"vncServer"`
}
// ConvertToStatusOutputOverview converts protobuf status to the output overview.
@@ -171,6 +176,9 @@ func ConvertToStatusOutputOverview(pbFullStatus *proto.FullStatus, opts ConvertO
relayOverview := mapRelays(pbFullStatus.GetRelays())
sshServerOverview := mapSSHServer(pbFullStatus.GetSshServerState())
vncServerOverview := VNCServerStateOutput{
Enabled: pbFullStatus.GetVncServerState().GetEnabled(),
}
peersOverview := mapPeers(pbFullStatus.GetPeers(), opts.StatusFilter, opts.PrefixNamesFilter, opts.PrefixNamesFilterMap, opts.IPsFilter, opts.ConnectionTypeFilter)
overview := OutputOverview{
@@ -194,6 +202,7 @@ func ConvertToStatusOutputOverview(pbFullStatus *proto.FullStatus, opts ConvertO
LazyConnectionEnabled: pbFullStatus.GetLazyConnectionEnabled(),
ProfileName: opts.ProfileName,
SSHServerState: sshServerOverview,
VNCServerState: vncServerOverview,
}
if opts.Anonymize {
@@ -524,6 +533,11 @@ func (o *OutputOverview) GeneralSummary(showURL bool, showRelays bool, showNameS
}
}
vncServerStatus := "Disabled"
if o.VNCServerState.Enabled {
vncServerStatus = "Enabled"
}
peersCountString := fmt.Sprintf("%d/%d Connected", o.Peers.Connected, o.Peers.Total)
var forwardingRulesString string
@@ -553,6 +567,7 @@ func (o *OutputOverview) GeneralSummary(showURL bool, showRelays bool, showNameS
"Quantum resistance: %s\n"+
"Lazy connection: %s\n"+
"SSH Server: %s\n"+
"VNC Server: %s\n"+
"Networks: %s\n"+
"%s"+
"Peers count: %s\n",
@@ -570,6 +585,7 @@ func (o *OutputOverview) GeneralSummary(showURL bool, showRelays bool, showNameS
rosenpassEnabledStatus,
lazyConnectionEnabledStatus,
sshServerStatus,
vncServerStatus,
networks,
forwardingRulesString,
peersCountString,

View File

@@ -398,6 +398,9 @@ func TestParsingToJSON(t *testing.T) {
"sshServer":{
"enabled":false,
"sessions":[]
},
"vncServer":{
"enabled":false
}
}`
// @formatter:on
@@ -505,6 +508,8 @@ profileName: ""
sshServer:
enabled: false
sessions: []
vncServer:
enabled: false
`
assert.Equal(t, expectedYAML, yaml)
@@ -572,6 +577,7 @@ Interface type: Kernel
Quantum resistance: false
Lazy connection: false
SSH Server: Disabled
VNC Server: Disabled
Networks: 10.10.0.0/24
Peers count: 2/2 Connected
`, lastConnectionUpdate1, lastHandshake1, lastConnectionUpdate2, lastHandshake2, runtime.GOOS, runtime.GOARCH, overview.CliVersion)
@@ -596,6 +602,7 @@ Interface type: Kernel
Quantum resistance: false
Lazy connection: false
SSH Server: Disabled
VNC Server: Disabled
Networks: 10.10.0.0/24
Peers count: 2/2 Connected
`

View File

@@ -63,6 +63,7 @@ type Info struct {
RosenpassEnabled bool
RosenpassPermissive bool
ServerSSHAllowed bool
ServerVNCAllowed bool
DisableClientRoutes bool
DisableServerRoutes bool
@@ -78,21 +79,27 @@ type Info struct {
EnableSSHLocalPortForwarding bool
EnableSSHRemotePortForwarding bool
DisableSSHAuth bool
DisableVNCAuth bool
}
func (i *Info) SetFlags(
rosenpassEnabled, rosenpassPermissive bool,
serverSSHAllowed *bool,
serverVNCAllowed *bool,
disableClientRoutes, disableServerRoutes,
disableDNS, disableFirewall, blockLANAccess, blockInbound, lazyConnectionEnabled bool,
enableSSHRoot, enableSSHSFTP, enableSSHLocalPortForwarding, enableSSHRemotePortForwarding *bool,
disableSSHAuth *bool,
disableVNCAuth *bool,
) {
i.RosenpassEnabled = rosenpassEnabled
i.RosenpassPermissive = rosenpassPermissive
if serverSSHAllowed != nil {
i.ServerSSHAllowed = *serverSSHAllowed
}
if serverVNCAllowed != nil {
i.ServerVNCAllowed = *serverVNCAllowed
}
i.DisableClientRoutes = disableClientRoutes
i.DisableServerRoutes = disableServerRoutes
@@ -118,6 +125,9 @@ func (i *Info) SetFlags(
if disableSSHAuth != nil {
i.DisableSSHAuth = *disableSSHAuth
}
if disableVNCAuth != nil {
i.DisableVNCAuth = *disableVNCAuth
}
}
// extractUserAgent extracts Netbird's agent (client) name and version from the outgoing context

View File

@@ -0,0 +1,474 @@
//go:build windows
package server
import (
crand "crypto/rand"
"encoding/hex"
"encoding/json"
"fmt"
"io"
"net"
"os"
"sync"
"time"
"unsafe"
log "github.com/sirupsen/logrus"
"golang.org/x/sys/windows"
)
const (
agentPort = "15900"
// agentTokenLen is the length of the random authentication token
// used to verify that connections to the agent come from the service.
agentTokenLen = 32
stillActive = 259
tokenPrimary = 1
securityImpersonation = 2
tokenSessionID = 12
createUnicodeEnvironment = 0x00000400
createNoWindow = 0x08000000
)
var (
kernel32 = windows.NewLazySystemDLL("kernel32.dll")
advapi32 = windows.NewLazySystemDLL("advapi32.dll")
userenv = windows.NewLazySystemDLL("userenv.dll")
procWTSGetActiveConsoleSessionId = kernel32.NewProc("WTSGetActiveConsoleSessionId")
procSetTokenInformation = advapi32.NewProc("SetTokenInformation")
procCreateEnvironmentBlock = userenv.NewProc("CreateEnvironmentBlock")
procDestroyEnvironmentBlock = userenv.NewProc("DestroyEnvironmentBlock")
wtsapi32 = windows.NewLazySystemDLL("wtsapi32.dll")
procWTSEnumerateSessionsW = wtsapi32.NewProc("WTSEnumerateSessionsW")
procWTSFreeMemory = wtsapi32.NewProc("WTSFreeMemory")
)
// GetCurrentSessionID returns the session ID of the current process.
func GetCurrentSessionID() uint32 {
var token windows.Token
if err := windows.OpenProcessToken(windows.CurrentProcess(),
windows.TOKEN_QUERY, &token); err != nil {
return 0
}
defer token.Close()
var id uint32
var ret uint32
_ = windows.GetTokenInformation(token, windows.TokenSessionId,
(*byte)(unsafe.Pointer(&id)), 4, &ret)
return id
}
func getConsoleSessionID() uint32 {
r, _, _ := procWTSGetActiveConsoleSessionId.Call()
return uint32(r)
}
const (
wtsActive = 0
wtsConnected = 1
wtsDisconnected = 4
)
type wtsSessionInfo struct {
SessionID uint32
WinStationName [66]byte // actually *uint16, but we just need the struct size
State uint32
}
// getActiveSessionID returns the session ID of the best session to attach to.
// Prefers an active (logged-in, interactive) session over the console session.
// This avoids kicking out an RDP user when the console is at the login screen.
func getActiveSessionID() uint32 {
var sessionInfo uintptr
var count uint32
r, _, _ := procWTSEnumerateSessionsW.Call(
0, // WTS_CURRENT_SERVER_HANDLE
0, // reserved
1, // version
uintptr(unsafe.Pointer(&sessionInfo)),
uintptr(unsafe.Pointer(&count)),
)
if r == 0 || count == 0 {
return getConsoleSessionID()
}
defer procWTSFreeMemory.Call(sessionInfo)
type wtsSession struct {
SessionID uint32
Station *uint16
State uint32
}
sessions := unsafe.Slice((*wtsSession)(unsafe.Pointer(sessionInfo)), count)
// Find the first active session (not session 0, which is the services session).
var bestID uint32
found := false
for _, s := range sessions {
if s.SessionID == 0 {
continue
}
if s.State == wtsActive {
bestID = s.SessionID
found = true
break
}
}
if !found {
return getConsoleSessionID()
}
return bestID
}
// getSystemTokenForSession duplicates the current SYSTEM token and sets its
// session ID so the spawned process runs in the target session. Using a SYSTEM
// token gives access to both Default and Winlogon desktops plus UIPI bypass.
func getSystemTokenForSession(sessionID uint32) (windows.Token, error) {
var cur windows.Token
if err := windows.OpenProcessToken(windows.CurrentProcess(),
windows.MAXIMUM_ALLOWED, &cur); err != nil {
return 0, fmt.Errorf("OpenProcessToken: %w", err)
}
defer cur.Close()
var dup windows.Token
if err := windows.DuplicateTokenEx(cur, windows.MAXIMUM_ALLOWED, nil,
securityImpersonation, tokenPrimary, &dup); err != nil {
return 0, fmt.Errorf("DuplicateTokenEx: %w", err)
}
sid := sessionID
r, _, err := procSetTokenInformation.Call(
uintptr(dup),
uintptr(tokenSessionID),
uintptr(unsafe.Pointer(&sid)),
unsafe.Sizeof(sid),
)
if r == 0 {
dup.Close()
return 0, fmt.Errorf("SetTokenInformation(SessionId=%d): %w", sessionID, err)
}
return dup, nil
}
const agentTokenEnvVar = "NB_VNC_AGENT_TOKEN"
// injectEnvVar appends a KEY=VALUE entry to a Unicode environment block.
// The block is a sequence of null-terminated UTF-16 strings, terminated by
// an extra null. Returns a new block pointer with the entry added.
func injectEnvVar(envBlock uintptr, key, value string) uintptr {
entry := key + "=" + value
// Walk the existing block to find its total length.
ptr := (*uint16)(unsafe.Pointer(envBlock))
var totalChars int
for {
ch := *(*uint16)(unsafe.Pointer(uintptr(unsafe.Pointer(ptr)) + uintptr(totalChars)*2))
if ch == 0 {
// Check for double-null terminator.
next := *(*uint16)(unsafe.Pointer(uintptr(unsafe.Pointer(ptr)) + uintptr(totalChars+1)*2))
totalChars++
if next == 0 {
// End of block (don't count the final null yet, we'll rebuild).
break
}
} else {
totalChars++
}
}
entryUTF16, _ := windows.UTF16FromString(entry)
// New block: existing entries + new entry (null-terminated) + final null.
newLen := totalChars + len(entryUTF16) + 1
newBlock := make([]uint16, newLen)
// Copy existing entries (up to but not including the final null).
for i := range totalChars {
newBlock[i] = *(*uint16)(unsafe.Pointer(uintptr(unsafe.Pointer(ptr)) + uintptr(i)*2))
}
copy(newBlock[totalChars:], entryUTF16)
newBlock[newLen-1] = 0 // final null terminator
return uintptr(unsafe.Pointer(&newBlock[0]))
}
func spawnAgentInSession(sessionID uint32, port string, authToken string) (windows.Handle, error) {
token, err := getSystemTokenForSession(sessionID)
if err != nil {
return 0, fmt.Errorf("get SYSTEM token for session %d: %w", sessionID, err)
}
defer token.Close()
var envBlock uintptr
r, _, _ := procCreateEnvironmentBlock.Call(
uintptr(unsafe.Pointer(&envBlock)),
uintptr(token),
0,
)
if r != 0 {
defer procDestroyEnvironmentBlock.Call(envBlock)
}
// Inject the auth token into the environment block so it doesn't appear
// in the process command line (visible via tasklist/wmic).
if r != 0 {
envBlock = injectEnvVar(envBlock, agentTokenEnvVar, authToken)
}
exePath, err := os.Executable()
if err != nil {
return 0, fmt.Errorf("get executable path: %w", err)
}
cmdLine := fmt.Sprintf(`"%s" vnc-agent --port %s`, exePath, port)
cmdLineW, err := windows.UTF16PtrFromString(cmdLine)
if err != nil {
return 0, fmt.Errorf("UTF16 cmdline: %w", err)
}
// Create an inheritable pipe for the agent's stderr so we can relog
// its output in the service process.
var sa windows.SecurityAttributes
sa.Length = uint32(unsafe.Sizeof(sa))
sa.InheritHandle = 1
var stderrRead, stderrWrite windows.Handle
if err := windows.CreatePipe(&stderrRead, &stderrWrite, &sa, 0); err != nil {
return 0, fmt.Errorf("create stderr pipe: %w", err)
}
// The read end must NOT be inherited by the child.
windows.SetHandleInformation(stderrRead, windows.HANDLE_FLAG_INHERIT, 0)
desktop, _ := windows.UTF16PtrFromString(`WinSta0\Default`)
si := windows.StartupInfo{
Cb: uint32(unsafe.Sizeof(windows.StartupInfo{})),
Desktop: desktop,
Flags: windows.STARTF_USESHOWWINDOW | windows.STARTF_USESTDHANDLES,
ShowWindow: 0,
StdErr: stderrWrite,
StdOutput: stderrWrite,
}
var pi windows.ProcessInformation
var envPtr *uint16
if envBlock != 0 {
envPtr = (*uint16)(unsafe.Pointer(envBlock))
}
err = windows.CreateProcessAsUser(
token, nil, cmdLineW,
nil, nil, true, // inheritHandles=true for the pipe
createUnicodeEnvironment|createNoWindow,
envPtr, nil, &si, &pi,
)
// Close the write end in the parent so reads will get EOF when the child exits.
windows.CloseHandle(stderrWrite)
if err != nil {
windows.CloseHandle(stderrRead)
return 0, fmt.Errorf("CreateProcessAsUser: %w", err)
}
windows.CloseHandle(pi.Thread)
// Relog agent output in the service with a [vnc-agent] prefix.
go relogAgentOutput(stderrRead)
log.Infof("spawned agent PID=%d in session %d on port %s", pi.ProcessId, sessionID, port)
return pi.Process, nil
}
// sessionManager monitors the active console session and ensures a VNC agent
// process is running in it. When the session changes (e.g., user switch, RDP
// connect/disconnect), it kills the old agent and spawns a new one.
type sessionManager struct {
port string
mu sync.Mutex
agentProc windows.Handle
sessionID uint32
authToken string
done chan struct{}
}
func newSessionManager(port string) *sessionManager {
return &sessionManager{port: port, sessionID: ^uint32(0), done: make(chan struct{})}
}
// generateAuthToken creates a new random hex token for agent authentication.
func generateAuthToken() string {
b := make([]byte, agentTokenLen)
if _, err := crand.Read(b); err != nil {
log.Warnf("generate agent auth token: %v", err)
return ""
}
return hex.EncodeToString(b)
}
// AuthToken returns the current agent authentication token.
func (m *sessionManager) AuthToken() string {
m.mu.Lock()
defer m.mu.Unlock()
return m.authToken
}
// Stop signals the session manager to exit its polling loop.
func (m *sessionManager) Stop() {
select {
case <-m.done:
default:
close(m.done)
}
}
func (m *sessionManager) run() {
ticker := time.NewTicker(2 * time.Second)
defer ticker.Stop()
for {
sid := getActiveSessionID()
m.mu.Lock()
if sid != m.sessionID {
log.Infof("active session changed: %d -> %d", m.sessionID, sid)
m.killAgent()
m.sessionID = sid
}
if m.agentProc != 0 {
var code uint32
_ = windows.GetExitCodeProcess(m.agentProc, &code)
if code != stillActive {
log.Infof("agent exited (code=%d), respawning", code)
windows.CloseHandle(m.agentProc)
m.agentProc = 0
}
}
if m.agentProc == 0 && sid != 0xFFFFFFFF {
m.authToken = generateAuthToken()
h, err := spawnAgentInSession(sid, m.port, m.authToken)
if err != nil {
log.Warnf("spawn agent in session %d: %v", sid, err)
m.authToken = ""
} else {
m.agentProc = h
}
}
m.mu.Unlock()
select {
case <-m.done:
m.mu.Lock()
m.killAgent()
m.mu.Unlock()
return
case <-ticker.C:
}
}
}
func (m *sessionManager) killAgent() {
if m.agentProc != 0 {
_ = windows.TerminateProcess(m.agentProc, 0)
windows.CloseHandle(m.agentProc)
m.agentProc = 0
log.Info("killed old agent")
}
}
// relogAgentOutput reads JSON log lines from the agent's stderr pipe and
// relogs them at the correct level with the service's formatter.
func relogAgentOutput(pipe windows.Handle) {
defer windows.CloseHandle(pipe)
f := os.NewFile(uintptr(pipe), "vnc-agent-stderr")
defer f.Close()
entry := log.WithField("component", "vnc-agent")
dec := json.NewDecoder(f)
for dec.More() {
var m map[string]any
if err := dec.Decode(&m); err != nil {
break
}
msg, _ := m["msg"].(string)
if msg == "" {
continue
}
// Forward extra fields from the agent (skip standard logrus fields).
// Remap "caller" to "source" so it doesn't conflict with logrus internals
// but still shows the original file/line from the agent process.
fields := make(log.Fields)
for k, v := range m {
switch k {
case "msg", "level", "time", "func":
continue
case "caller":
fields["source"] = v
default:
fields[k] = v
}
}
e := entry.WithFields(fields)
switch m["level"] {
case "error":
e.Error(msg)
case "warning":
e.Warn(msg)
case "debug":
e.Debug(msg)
case "trace":
e.Trace(msg)
default:
e.Info(msg)
}
}
}
// proxyToAgent connects to the agent, sends the auth token, then proxies
// the VNC client connection bidirectionally.
func proxyToAgent(client net.Conn, port string, authToken string) {
defer client.Close()
addr := "127.0.0.1:" + port
var agentConn net.Conn
var err error
for range 50 {
agentConn, err = net.DialTimeout("tcp", addr, time.Second)
if err == nil {
break
}
time.Sleep(200 * time.Millisecond)
}
if err != nil {
log.Warnf("proxy cannot reach agent at %s: %v", addr, err)
return
}
defer agentConn.Close()
// Send the auth token so the agent can verify this connection
// comes from the trusted service process.
tokenBytes, _ := hex.DecodeString(authToken)
if _, err := agentConn.Write(tokenBytes); err != nil {
log.Warnf("send auth token to agent: %v", err)
return
}
log.Debugf("proxy connected to agent, starting bidirectional copy")
done := make(chan struct{}, 2)
cp := func(label string, dst, src net.Conn) {
n, err := io.Copy(dst, src)
log.Debugf("proxy %s: %d bytes, err=%v", label, n, err)
done <- struct{}{}
}
go cp("client→agent", agentConn, client)
go cp("agent→client", client, agentConn)
<-done
}

View File

@@ -0,0 +1,274 @@
//go:build darwin && !ios
package server
import (
"fmt"
"image"
"sync"
"time"
"unsafe"
"github.com/ebitengine/purego"
log "github.com/sirupsen/logrus"
)
var darwinCaptureOnce sync.Once
var (
cgMainDisplayID func() uint32
cgDisplayPixelsWide func(uint32) uintptr
cgDisplayPixelsHigh func(uint32) uintptr
cgDisplayCreateImage func(uint32) uintptr
cgImageGetWidth func(uintptr) uintptr
cgImageGetHeight func(uintptr) uintptr
cgImageGetBytesPerRow func(uintptr) uintptr
cgImageGetBitsPerPixel func(uintptr) uintptr
cgImageGetDataProvider func(uintptr) uintptr
cgDataProviderCopyData func(uintptr) uintptr
cgImageRelease func(uintptr)
cfDataGetLength func(uintptr) int64
cfDataGetBytePtr func(uintptr) uintptr
cfRelease func(uintptr)
cgPreflightScreenCaptureAccess func() bool
cgRequestScreenCaptureAccess func() bool
darwinCaptureReady bool
)
func initDarwinCapture() {
darwinCaptureOnce.Do(func() {
cg, err := purego.Dlopen("/System/Library/Frameworks/CoreGraphics.framework/CoreGraphics", purego.RTLD_NOW|purego.RTLD_GLOBAL)
if err != nil {
log.Debugf("load CoreGraphics: %v", err)
return
}
cf, err := purego.Dlopen("/System/Library/Frameworks/CoreFoundation.framework/CoreFoundation", purego.RTLD_NOW|purego.RTLD_GLOBAL)
if err != nil {
log.Debugf("load CoreFoundation: %v", err)
return
}
purego.RegisterLibFunc(&cgMainDisplayID, cg, "CGMainDisplayID")
purego.RegisterLibFunc(&cgDisplayPixelsWide, cg, "CGDisplayPixelsWide")
purego.RegisterLibFunc(&cgDisplayPixelsHigh, cg, "CGDisplayPixelsHigh")
purego.RegisterLibFunc(&cgDisplayCreateImage, cg, "CGDisplayCreateImage")
purego.RegisterLibFunc(&cgImageGetWidth, cg, "CGImageGetWidth")
purego.RegisterLibFunc(&cgImageGetHeight, cg, "CGImageGetHeight")
purego.RegisterLibFunc(&cgImageGetBytesPerRow, cg, "CGImageGetBytesPerRow")
purego.RegisterLibFunc(&cgImageGetBitsPerPixel, cg, "CGImageGetBitsPerPixel")
purego.RegisterLibFunc(&cgImageGetDataProvider, cg, "CGImageGetDataProvider")
purego.RegisterLibFunc(&cgDataProviderCopyData, cg, "CGDataProviderCopyData")
purego.RegisterLibFunc(&cgImageRelease, cg, "CGImageRelease")
purego.RegisterLibFunc(&cfDataGetLength, cf, "CFDataGetLength")
purego.RegisterLibFunc(&cfDataGetBytePtr, cf, "CFDataGetBytePtr")
purego.RegisterLibFunc(&cfRelease, cf, "CFRelease")
// Screen capture permission APIs (macOS 11+). Might not exist on older versions.
if sym, err := purego.Dlsym(cg, "CGPreflightScreenCaptureAccess"); err == nil {
purego.RegisterFunc(&cgPreflightScreenCaptureAccess, sym)
}
if sym, err := purego.Dlsym(cg, "CGRequestScreenCaptureAccess"); err == nil {
purego.RegisterFunc(&cgRequestScreenCaptureAccess, sym)
}
darwinCaptureReady = true
})
}
// CGCapturer captures the macOS main display using Core Graphics.
type CGCapturer struct {
displayID uint32
w, h int
}
// NewCGCapturer creates a screen capturer for the main display.
func NewCGCapturer() (*CGCapturer, error) {
initDarwinCapture()
if !darwinCaptureReady {
return nil, fmt.Errorf("CoreGraphics not available")
}
// Request Screen Recording permission (shows system dialog on macOS 11+).
if cgPreflightScreenCaptureAccess != nil && !cgPreflightScreenCaptureAccess() {
if cgRequestScreenCaptureAccess != nil {
cgRequestScreenCaptureAccess()
}
log.Warn("Screen Recording permission not granted. " +
"Grant in System Settings > Privacy & Security > Screen Recording, then restart.")
}
displayID := cgMainDisplayID()
w := int(cgDisplayPixelsWide(displayID))
h := int(cgDisplayPixelsHigh(displayID))
if w == 0 || h == 0 {
return nil, fmt.Errorf("display dimensions are zero")
}
log.Infof("macOS capturer ready: %dx%d (display=%d)", w, h, displayID)
return &CGCapturer{displayID: displayID, w: w, h: h}, nil
}
// Width returns the screen width.
func (c *CGCapturer) Width() int { return c.w }
// Height returns the screen height.
func (c *CGCapturer) Height() int { return c.h }
// Capture returns the current screen as an RGBA image.
func (c *CGCapturer) Capture() (*image.RGBA, error) {
cgImage := cgDisplayCreateImage(c.displayID)
if cgImage == 0 {
return nil, fmt.Errorf("CGDisplayCreateImage returned nil (screen recording permission?)")
}
defer cgImageRelease(cgImage)
w := int(cgImageGetWidth(cgImage))
h := int(cgImageGetHeight(cgImage))
bytesPerRow := int(cgImageGetBytesPerRow(cgImage))
bpp := int(cgImageGetBitsPerPixel(cgImage))
provider := cgImageGetDataProvider(cgImage)
if provider == 0 {
return nil, fmt.Errorf("CGImageGetDataProvider returned nil")
}
cfData := cgDataProviderCopyData(provider)
if cfData == 0 {
return nil, fmt.Errorf("CGDataProviderCopyData returned nil")
}
defer cfRelease(cfData)
dataLen := int(cfDataGetLength(cfData))
dataPtr := cfDataGetBytePtr(cfData)
if dataPtr == 0 || dataLen == 0 {
return nil, fmt.Errorf("empty image data")
}
src := unsafe.Slice((*byte)(unsafe.Pointer(dataPtr)), dataLen)
img := image.NewRGBA(image.Rect(0, 0, w, h))
bytesPerPixel := bpp / 8
for row := 0; row < h; row++ {
srcOff := row * bytesPerRow
dstOff := row * img.Stride
for col := 0; col < w; col++ {
si := srcOff + col*bytesPerPixel
di := dstOff + col*4
img.Pix[di+0] = src[si+2] // R (from BGRA)
img.Pix[di+1] = src[si+1] // G
img.Pix[di+2] = src[si+0] // B
img.Pix[di+3] = 0xff
}
}
return img, nil
}
// MacPoller wraps CGCapturer in a continuous capture loop.
type MacPoller struct {
mu sync.Mutex
frame *image.RGBA
w, h int
done chan struct{}
}
// NewMacPoller creates a capturer that continuously grabs the macOS display.
func NewMacPoller() *MacPoller {
p := &MacPoller{done: make(chan struct{})}
go p.loop()
return p
}
// Close stops the capture loop.
func (p *MacPoller) Close() {
select {
case <-p.done:
default:
close(p.done)
}
}
// Width returns the screen width.
func (p *MacPoller) Width() int {
p.mu.Lock()
defer p.mu.Unlock()
return p.w
}
// Height returns the screen height.
func (p *MacPoller) Height() int {
p.mu.Lock()
defer p.mu.Unlock()
return p.h
}
// Capture returns the most recent frame.
func (p *MacPoller) Capture() (*image.RGBA, error) {
p.mu.Lock()
img := p.frame
p.mu.Unlock()
if img != nil {
return img, nil
}
return nil, fmt.Errorf("no frame available yet")
}
func (p *MacPoller) loop() {
var capturer *CGCapturer
var initFails int
for {
select {
case <-p.done:
return
default:
}
if capturer == nil {
var err error
capturer, err = NewCGCapturer()
if err != nil {
initFails++
if initFails <= maxCapturerRetries {
log.Debugf("macOS capturer: %v (attempt %d/%d)", err, initFails, maxCapturerRetries)
select {
case <-p.done:
return
case <-time.After(2 * time.Second):
}
continue
}
log.Warnf("macOS capturer unavailable after %d attempts, stopping poller", maxCapturerRetries)
return
}
initFails = 0
p.mu.Lock()
p.w, p.h = capturer.Width(), capturer.Height()
p.mu.Unlock()
}
img, err := capturer.Capture()
if err != nil {
log.Debugf("macOS capture: %v", err)
capturer = nil
select {
case <-p.done:
return
case <-time.After(500 * time.Millisecond):
}
continue
}
p.mu.Lock()
p.frame = img
p.mu.Unlock()
select {
case <-p.done:
return
case <-time.After(33 * time.Millisecond): // ~30 fps
}
}
}
var _ ScreenCapturer = (*MacPoller)(nil)

View File

@@ -0,0 +1,99 @@
//go:build windows
package server
import (
"errors"
"fmt"
"image"
"github.com/kirides/go-d3d/d3d11"
"github.com/kirides/go-d3d/outputduplication"
)
// dxgiCapturer captures the desktop using DXGI Desktop Duplication.
// Provides GPU-accelerated capture with native dirty rect tracking.
// Only works from the interactive user session, not Session 0.
//
// Uses a double-buffer: DXGI writes into img, then we copy to the current
// output buffer and hand it out. Alternating between two output buffers
// avoids allocating a new image.RGBA per frame (~8MB at 1080p, 30fps).
type dxgiCapturer struct {
dup *outputduplication.OutputDuplicator
device *d3d11.ID3D11Device
ctx *d3d11.ID3D11DeviceContext
img *image.RGBA
out [2]*image.RGBA
outIdx int
width int
height int
}
func newDXGICapturer() (*dxgiCapturer, error) {
device, deviceCtx, err := d3d11.NewD3D11Device()
if err != nil {
return nil, fmt.Errorf("create D3D11 device: %w", err)
}
dup, err := outputduplication.NewIDXGIOutputDuplication(device, deviceCtx, 0)
if err != nil {
device.Release()
deviceCtx.Release()
return nil, fmt.Errorf("create output duplication: %w", err)
}
w, h := screenSize()
if w == 0 || h == 0 {
dup.Release()
device.Release()
deviceCtx.Release()
return nil, fmt.Errorf("screen dimensions are zero")
}
rect := image.Rect(0, 0, w, h)
c := &dxgiCapturer{
dup: dup,
device: device,
ctx: deviceCtx,
img: image.NewRGBA(rect),
out: [2]*image.RGBA{image.NewRGBA(rect), image.NewRGBA(rect)},
width: w,
height: h,
}
// Grab the initial frame with a longer timeout to ensure we have
// a valid image before returning.
_ = dup.GetImage(c.img, 2000)
return c, nil
}
func (c *dxgiCapturer) capture() (*image.RGBA, error) {
err := c.dup.GetImage(c.img, 100)
if err != nil && !errors.Is(err, outputduplication.ErrNoImageYet) {
return nil, err
}
// Copy into the next output buffer. The DesktopCapturer hands out the
// returned pointer to VNC sessions that read pixels concurrently, so we
// alternate between two pre-allocated buffers instead of allocating per frame.
out := c.out[c.outIdx]
c.outIdx ^= 1
copy(out.Pix, c.img.Pix)
return out, nil
}
func (c *dxgiCapturer) close() {
if c.dup != nil {
c.dup.Release()
c.dup = nil
}
if c.ctx != nil {
c.ctx.Release()
c.ctx = nil
}
if c.device != nil {
c.device.Release()
c.device = nil
}
}

View File

@@ -0,0 +1,461 @@
//go:build windows
package server
import (
"fmt"
"image"
"runtime"
"sync"
"sync/atomic"
"time"
"unsafe"
log "github.com/sirupsen/logrus"
"golang.org/x/sys/windows"
)
var (
gdi32 = windows.NewLazySystemDLL("gdi32.dll")
user32 = windows.NewLazySystemDLL("user32.dll")
procGetDC = user32.NewProc("GetDC")
procReleaseDC = user32.NewProc("ReleaseDC")
procCreateCompatDC = gdi32.NewProc("CreateCompatibleDC")
procCreateDIBSection = gdi32.NewProc("CreateDIBSection")
procSelectObject = gdi32.NewProc("SelectObject")
procDeleteObject = gdi32.NewProc("DeleteObject")
procDeleteDC = gdi32.NewProc("DeleteDC")
procBitBlt = gdi32.NewProc("BitBlt")
procGetSystemMetrics = user32.NewProc("GetSystemMetrics")
// Desktop switching for service/Session 0 capture.
procOpenInputDesktop = user32.NewProc("OpenInputDesktop")
procSetThreadDesktop = user32.NewProc("SetThreadDesktop")
procCloseDesktop = user32.NewProc("CloseDesktop")
procOpenWindowStation = user32.NewProc("OpenWindowStationW")
procSetProcessWindowStation = user32.NewProc("SetProcessWindowStation")
procCloseWindowStation = user32.NewProc("CloseWindowStation")
procGetUserObjectInformationW = user32.NewProc("GetUserObjectInformationW")
)
const uoiName = 2
const (
smCxScreen = 0
smCyScreen = 1
srccopy = 0x00CC0020
dibRgbColors = 0
)
type bitmapInfoHeader struct {
Size uint32
Width int32
Height int32
Planes uint16
BitCount uint16
Compression uint32
SizeImage uint32
XPelsPerMeter int32
YPelsPerMeter int32
ClrUsed uint32
ClrImportant uint32
}
type bitmapInfo struct {
Header bitmapInfoHeader
}
// setupInteractiveWindowStation associates the current process with WinSta0,
// the interactive window station. This is required for a SYSTEM service in
// Session 0 to call OpenInputDesktop for screen capture and input injection.
func setupInteractiveWindowStation() error {
name, err := windows.UTF16PtrFromString("WinSta0")
if err != nil {
return fmt.Errorf("UTF16 WinSta0: %w", err)
}
hWinSta, _, err := procOpenWindowStation.Call(
uintptr(unsafe.Pointer(name)),
0,
uintptr(windows.MAXIMUM_ALLOWED),
)
if hWinSta == 0 {
return fmt.Errorf("OpenWindowStation(WinSta0): %w", err)
}
r, _, err := procSetProcessWindowStation.Call(hWinSta)
if r == 0 {
procCloseWindowStation.Call(hWinSta)
return fmt.Errorf("SetProcessWindowStation: %w", err)
}
log.Info("process window station set to WinSta0 (interactive)")
return nil
}
func screenSize() (int, int) {
w, _, _ := procGetSystemMetrics.Call(uintptr(smCxScreen))
h, _, _ := procGetSystemMetrics.Call(uintptr(smCyScreen))
return int(w), int(h)
}
func getDesktopName(hDesk uintptr) string {
var buf [256]uint16
var needed uint32
procGetUserObjectInformationW.Call(hDesk, uoiName,
uintptr(unsafe.Pointer(&buf[0])), 512,
uintptr(unsafe.Pointer(&needed)))
return windows.UTF16ToString(buf[:])
}
// switchToInputDesktop opens the desktop currently receiving user input
// and sets it as the calling OS thread's desktop. Must be called from a
// goroutine locked to its OS thread via runtime.LockOSThread().
func switchToInputDesktop() (bool, string) {
hDesk, _, _ := procOpenInputDesktop.Call(0, 0, uintptr(windows.MAXIMUM_ALLOWED))
if hDesk == 0 {
return false, ""
}
name := getDesktopName(hDesk)
ret, _, _ := procSetThreadDesktop.Call(hDesk)
procCloseDesktop.Call(hDesk)
return ret != 0, name
}
// gdiCapturer captures the desktop screen using GDI BitBlt.
// GDI objects (DC, DIBSection) are allocated once and reused across frames.
type gdiCapturer struct {
mu sync.Mutex
width int
height int
// Pre-allocated GDI resources, reused across captures.
memDC uintptr
bmp uintptr
bits uintptr
}
func newGDICapturer() (*gdiCapturer, error) {
w, h := screenSize()
if w == 0 || h == 0 {
return nil, fmt.Errorf("screen dimensions are zero")
}
c := &gdiCapturer{width: w, height: h}
if err := c.allocGDI(); err != nil {
return nil, err
}
return c, nil
}
// allocGDI pre-allocates the compatible DC and DIB section for reuse.
func (c *gdiCapturer) allocGDI() error {
screenDC, _, _ := procGetDC.Call(0)
if screenDC == 0 {
return fmt.Errorf("GetDC returned 0")
}
defer procReleaseDC.Call(0, screenDC)
memDC, _, _ := procCreateCompatDC.Call(screenDC)
if memDC == 0 {
return fmt.Errorf("CreateCompatibleDC returned 0")
}
bi := bitmapInfo{
Header: bitmapInfoHeader{
Size: uint32(unsafe.Sizeof(bitmapInfoHeader{})),
Width: int32(c.width),
Height: -int32(c.height), // negative = top-down DIB
Planes: 1,
BitCount: 32,
},
}
var bits uintptr
bmp, _, _ := procCreateDIBSection.Call(
screenDC,
uintptr(unsafe.Pointer(&bi)),
dibRgbColors,
uintptr(unsafe.Pointer(&bits)),
0, 0,
)
if bmp == 0 || bits == 0 {
procDeleteDC.Call(memDC)
return fmt.Errorf("CreateDIBSection returned 0")
}
procSelectObject.Call(memDC, bmp)
c.memDC = memDC
c.bmp = bmp
c.bits = bits
return nil
}
func (c *gdiCapturer) close() { c.freeGDI() }
// freeGDI releases pre-allocated GDI resources.
func (c *gdiCapturer) freeGDI() {
if c.bmp != 0 {
procDeleteObject.Call(c.bmp)
c.bmp = 0
}
if c.memDC != 0 {
procDeleteDC.Call(c.memDC)
c.memDC = 0
}
c.bits = 0
}
func (c *gdiCapturer) capture() (*image.RGBA, error) {
c.mu.Lock()
defer c.mu.Unlock()
if c.memDC == 0 {
return nil, fmt.Errorf("GDI resources not allocated")
}
screenDC, _, _ := procGetDC.Call(0)
if screenDC == 0 {
return nil, fmt.Errorf("GetDC returned 0")
}
defer procReleaseDC.Call(0, screenDC)
ret, _, _ := procBitBlt.Call(c.memDC, 0, 0, uintptr(c.width), uintptr(c.height),
screenDC, 0, 0, srccopy)
if ret == 0 {
return nil, fmt.Errorf("BitBlt returned 0")
}
n := c.width * c.height * 4
raw := unsafe.Slice((*byte)(unsafe.Pointer(c.bits)), n)
// GDI gives BGRA, the RFB encoder expects RGBA (img.Pix layout).
// Swap R and B in bulk using uint32 operations (one load + mask + shift
// per pixel instead of three separate byte assignments).
img := image.NewRGBA(image.Rect(0, 0, c.width, c.height))
pix := img.Pix
copy(pix, raw)
swizzleBGRAtoRGBA(pix)
return img, nil
}
// DesktopCapturer captures the interactive desktop, handling desktop transitions
// (login screen, UAC prompts). A dedicated OS-locked goroutine continuously
// captures frames, which are retrieved by the VNC session on demand.
// Capture pauses automatically when no clients are connected.
type DesktopCapturer struct {
mu sync.Mutex
frame *image.RGBA
w, h int
// clients tracks the number of active VNC sessions. When zero, the
// capture loop idles instead of grabbing frames.
clients atomic.Int32
// wake is signaled when a client connects and the loop should resume.
wake chan struct{}
// done is closed when Close is called, terminating the capture loop.
done chan struct{}
}
// NewDesktopCapturer creates a capturer that continuously grabs the active desktop.
func NewDesktopCapturer() *DesktopCapturer {
c := &DesktopCapturer{
wake: make(chan struct{}, 1),
done: make(chan struct{}),
}
go c.loop()
return c
}
// ClientConnect increments the active client count, resuming capture if needed.
func (c *DesktopCapturer) ClientConnect() {
c.clients.Add(1)
select {
case c.wake <- struct{}{}:
default:
}
}
// ClientDisconnect decrements the active client count.
func (c *DesktopCapturer) ClientDisconnect() {
c.clients.Add(-1)
}
// Close stops the capture loop and releases resources.
func (c *DesktopCapturer) Close() {
select {
case <-c.done:
default:
close(c.done)
}
}
// Width returns the current screen width.
func (c *DesktopCapturer) Width() int {
c.mu.Lock()
defer c.mu.Unlock()
return c.w
}
// Height returns the current screen height.
func (c *DesktopCapturer) Height() int {
c.mu.Lock()
defer c.mu.Unlock()
return c.h
}
// Capture returns the most recent desktop frame.
func (c *DesktopCapturer) Capture() (*image.RGBA, error) {
c.mu.Lock()
img := c.frame
c.mu.Unlock()
if img != nil {
return img, nil
}
return nil, fmt.Errorf("no frame available yet")
}
// waitForClient blocks until a client connects or the capturer is closed.
func (c *DesktopCapturer) waitForClient() bool {
if c.clients.Load() > 0 {
return true
}
select {
case <-c.wake:
return true
case <-c.done:
return false
}
}
func (c *DesktopCapturer) loop() {
runtime.LockOSThread()
// When running as a Windows service (Session 0), we need to attach to the
// interactive window station before OpenInputDesktop will succeed.
if err := setupInteractiveWindowStation(); err != nil {
log.Warnf("attach to interactive window station: %v", err)
}
frameTicker := time.NewTicker(33 * time.Millisecond) // ~30 fps
defer frameTicker.Stop()
retryTimer := time.NewTimer(0)
retryTimer.Stop()
defer retryTimer.Stop()
type frameCapturer interface {
capture() (*image.RGBA, error)
close()
}
var cap frameCapturer
var desktopFails int
var lastDesktop string
createCapturer := func() (frameCapturer, error) {
dc, err := newDXGICapturer()
if err == nil {
log.Info("using DXGI Desktop Duplication for capture")
return dc, nil
}
log.Debugf("DXGI unavailable (%v), falling back to GDI", err)
gc, err := newGDICapturer()
if err != nil {
return nil, err
}
log.Info("using GDI BitBlt for capture")
return gc, nil
}
for {
if !c.waitForClient() {
if cap != nil {
cap.close()
}
return
}
// No clients: release the capturer and wait.
if c.clients.Load() <= 0 {
if cap != nil {
cap.close()
cap = nil
}
continue
}
ok, desk := switchToInputDesktop()
if !ok {
desktopFails++
if desktopFails == 1 || desktopFails%100 == 0 {
log.Warnf("switchToInputDesktop failed (count=%d), no interactive desktop session?", desktopFails)
}
retryTimer.Reset(100 * time.Millisecond)
select {
case <-retryTimer.C:
case <-c.done:
return
}
continue
}
if desktopFails > 0 {
log.Infof("switchToInputDesktop recovered after %d failures, desktop=%q", desktopFails, desk)
desktopFails = 0
}
if desk != lastDesktop {
log.Infof("desktop changed: %q -> %q", lastDesktop, desk)
lastDesktop = desk
if cap != nil {
cap.close()
}
cap = nil
}
if cap == nil {
fc, err := createCapturer()
if err != nil {
log.Warnf("create capturer: %v", err)
retryTimer.Reset(500 * time.Millisecond)
select {
case <-retryTimer.C:
case <-c.done:
return
}
continue
}
cap = fc
w, h := screenSize()
c.mu.Lock()
c.w, c.h = w, h
c.mu.Unlock()
log.Infof("screen capturer ready: %dx%d", w, h)
}
img, err := cap.capture()
if err != nil {
log.Debugf("capture: %v", err)
cap.close()
cap = nil
retryTimer.Reset(100 * time.Millisecond)
select {
case <-retryTimer.C:
case <-c.done:
return
}
continue
}
c.mu.Lock()
c.frame = img
c.mu.Unlock()
select {
case <-frameTicker.C:
case <-c.done:
if cap != nil {
cap.close()
}
return
}
}
}

View File

@@ -0,0 +1,385 @@
//go:build (linux && !android) || freebsd
package server
import (
"fmt"
"image"
"os"
"os/exec"
"strings"
"sync"
"time"
log "github.com/sirupsen/logrus"
"github.com/jezek/xgb"
"github.com/jezek/xgb/xproto"
)
// X11Capturer captures the screen from an X11 display using the MIT-SHM extension.
type X11Capturer struct {
mu sync.Mutex
conn *xgb.Conn
screen *xproto.ScreenInfo
w, h int
shmID int
shmAddr []byte
shmSeg uint32 // shm.Seg
useSHM bool
}
// detectX11Display finds the active X11 display and sets DISPLAY/XAUTHORITY
// environment variables if needed. This is required when running as a system
// service where these vars aren't set.
func detectX11Display() {
if os.Getenv("DISPLAY") != "" {
return
}
// Try /proc first (Linux), then ps fallback (FreeBSD and others).
if detectX11FromProc() {
return
}
if detectX11FromSockets() {
return
}
}
// detectX11FromProc scans /proc/*/cmdline for Xorg (Linux).
func detectX11FromProc() bool {
entries, err := os.ReadDir("/proc")
if err != nil {
return false
}
for _, e := range entries {
if !e.IsDir() {
continue
}
cmdline, err := os.ReadFile("/proc/" + e.Name() + "/cmdline")
if err != nil {
continue
}
if display, auth := parseXorgArgs(splitCmdline(cmdline)); display != "" {
setDisplayEnv(display, auth)
return true
}
}
return false
}
// detectX11FromSockets checks /tmp/.X11-unix/ for X sockets and uses ps
// to find the auth file. Works on FreeBSD and other systems without /proc.
func detectX11FromSockets() bool {
entries, err := os.ReadDir("/tmp/.X11-unix")
if err != nil {
return false
}
// Find the lowest display number.
for _, e := range entries {
name := e.Name()
if len(name) < 2 || name[0] != 'X' {
continue
}
display := ":" + name[1:]
os.Setenv("DISPLAY", display)
log.Infof("auto-detected DISPLAY=%s (from socket)", display)
// Try to find -auth from ps output.
if auth := findXorgAuthFromPS(); auth != "" {
os.Setenv("XAUTHORITY", auth)
log.Infof("auto-detected XAUTHORITY=%s (from ps)", auth)
}
return true
}
return false
}
// findXorgAuthFromPS runs ps to find Xorg and extract its -auth argument.
func findXorgAuthFromPS() string {
out, err := exec.Command("ps", "auxww").Output()
if err != nil {
return ""
}
for _, line := range strings.Split(string(out), "\n") {
if !strings.Contains(line, "Xorg") && !strings.Contains(line, "/X ") {
continue
}
fields := strings.Fields(line)
for i, f := range fields {
if f == "-auth" && i+1 < len(fields) {
return fields[i+1]
}
}
}
return ""
}
func parseXorgArgs(args []string) (display, auth string) {
if len(args) == 0 {
return "", ""
}
base := args[0]
if !(base == "Xorg" || base == "X" || len(base) > 0 && base[len(base)-1] == 'X' ||
strings.Contains(base, "/Xorg") || strings.Contains(base, "/X")) {
return "", ""
}
for i, arg := range args[1:] {
if len(arg) > 0 && arg[0] == ':' {
display = arg
}
if arg == "-auth" && i+2 < len(args) {
auth = args[i+2]
}
}
return display, auth
}
func setDisplayEnv(display, auth string) {
os.Setenv("DISPLAY", display)
log.Infof("auto-detected DISPLAY=%s", display)
if auth != "" {
os.Setenv("XAUTHORITY", auth)
log.Infof("auto-detected XAUTHORITY=%s", auth)
}
}
func splitCmdline(data []byte) []string {
var args []string
for _, b := range splitNull(data) {
if len(b) > 0 {
args = append(args, string(b))
}
}
return args
}
func splitNull(data []byte) [][]byte {
var parts [][]byte
start := 0
for i, b := range data {
if b == 0 {
parts = append(parts, data[start:i])
start = i + 1
}
}
if start < len(data) {
parts = append(parts, data[start:])
}
return parts
}
// NewX11Capturer connects to the X11 display and sets up shared memory capture.
func NewX11Capturer(display string) (*X11Capturer, error) {
detectX11Display()
if display == "" {
display = os.Getenv("DISPLAY")
}
if display == "" {
return nil, fmt.Errorf("DISPLAY not set and no Xorg process found")
}
conn, err := xgb.NewConnDisplay(display)
if err != nil {
return nil, fmt.Errorf("connect to X11 display %s: %w", display, err)
}
setup := xproto.Setup(conn)
if len(setup.Roots) == 0 {
conn.Close()
return nil, fmt.Errorf("no X11 screens")
}
screen := setup.Roots[0]
c := &X11Capturer{
conn: conn,
screen: &screen,
w: int(screen.WidthInPixels),
h: int(screen.HeightInPixels),
}
if err := c.initSHM(); err != nil {
log.Debugf("X11 SHM not available, using slow GetImage: %v", err)
}
log.Infof("X11 capturer ready: %dx%d (display=%s, shm=%v)", c.w, c.h, display, c.useSHM)
return c, nil
}
// initSHM is implemented in capture_x11_shm_linux.go (requires SysV SHM).
// On platforms without SysV SHM (FreeBSD), a stub returns an error and
// the capturer falls back to GetImage.
// Width returns the screen width.
func (c *X11Capturer) Width() int { return c.w }
// Height returns the screen height.
func (c *X11Capturer) Height() int { return c.h }
// Capture returns the current screen as an RGBA image.
func (c *X11Capturer) Capture() (*image.RGBA, error) {
c.mu.Lock()
defer c.mu.Unlock()
if c.useSHM {
return c.captureSHM()
}
return c.captureGetImage()
}
// captureSHM is implemented in capture_x11_shm_linux.go.
func (c *X11Capturer) captureGetImage() (*image.RGBA, error) {
cookie := xproto.GetImage(c.conn, xproto.ImageFormatZPixmap,
xproto.Drawable(c.screen.Root),
0, 0, uint16(c.w), uint16(c.h), 0xFFFFFFFF)
reply, err := cookie.Reply()
if err != nil {
return nil, fmt.Errorf("GetImage: %w", err)
}
img := image.NewRGBA(image.Rect(0, 0, c.w, c.h))
data := reply.Data
n := c.w * c.h * 4
if len(data) < n {
return nil, fmt.Errorf("GetImage returned %d bytes, expected %d", len(data), n)
}
for i := 0; i < n; i += 4 {
img.Pix[i+0] = data[i+2] // R
img.Pix[i+1] = data[i+1] // G
img.Pix[i+2] = data[i+0] // B
img.Pix[i+3] = 0xff
}
return img, nil
}
// Close releases X11 resources.
func (c *X11Capturer) Close() {
c.closeSHM()
c.conn.Close()
}
// closeSHM is implemented in capture_x11_shm_linux.go.
// X11Poller wraps X11Capturer in a continuous capture loop, matching the
// DesktopCapturer pattern from Windows.
type X11Poller struct {
mu sync.Mutex
frame *image.RGBA
w, h int
display string
done chan struct{}
}
// NewX11Poller creates a capturer that continuously grabs the X11 display.
func NewX11Poller(display string) *X11Poller {
p := &X11Poller{
display: display,
done: make(chan struct{}),
}
go p.loop()
return p
}
// Close stops the capture loop.
func (p *X11Poller) Close() {
select {
case <-p.done:
default:
close(p.done)
}
}
// Width returns the screen width.
func (p *X11Poller) Width() int {
p.mu.Lock()
defer p.mu.Unlock()
return p.w
}
// Height returns the screen height.
func (p *X11Poller) Height() int {
p.mu.Lock()
defer p.mu.Unlock()
return p.h
}
// Capture returns the most recent frame.
func (p *X11Poller) Capture() (*image.RGBA, error) {
p.mu.Lock()
img := p.frame
p.mu.Unlock()
if img != nil {
return img, nil
}
return nil, fmt.Errorf("no frame available yet")
}
func (p *X11Poller) loop() {
var capturer *X11Capturer
var initFails int
defer func() {
if capturer != nil {
capturer.Close()
}
}()
for {
select {
case <-p.done:
return
default:
}
if capturer == nil {
var err error
capturer, err = NewX11Capturer(p.display)
if err != nil {
initFails++
if initFails <= maxCapturerRetries {
log.Debugf("X11 capturer: %v (attempt %d/%d)", err, initFails, maxCapturerRetries)
select {
case <-p.done:
return
case <-time.After(2 * time.Second):
}
continue
}
log.Warnf("X11 capturer unavailable after %d attempts, stopping poller", maxCapturerRetries)
return
}
initFails = 0
p.mu.Lock()
p.w, p.h = capturer.Width(), capturer.Height()
p.mu.Unlock()
}
img, err := capturer.Capture()
if err != nil {
log.Debugf("X11 capture: %v", err)
capturer.Close()
capturer = nil
select {
case <-p.done:
return
case <-time.After(500 * time.Millisecond):
}
continue
}
p.mu.Lock()
p.frame = img
p.mu.Unlock()
select {
case <-p.done:
return
case <-time.After(33 * time.Millisecond): // ~30 fps
}
}
}

View File

@@ -0,0 +1,78 @@
//go:build linux && !android
package server
import (
"fmt"
"image"
"github.com/jezek/xgb/shm"
"github.com/jezek/xgb/xproto"
"golang.org/x/sys/unix"
)
func (c *X11Capturer) initSHM() error {
if err := shm.Init(c.conn); err != nil {
return fmt.Errorf("init SHM extension: %w", err)
}
size := c.w * c.h * 4
id, err := unix.SysvShmGet(unix.IPC_PRIVATE, size, unix.IPC_CREAT|0600)
if err != nil {
return fmt.Errorf("shmget: %w", err)
}
addr, err := unix.SysvShmAttach(id, 0, 0)
if err != nil {
unix.SysvShmCtl(id, unix.IPC_RMID, nil)
return fmt.Errorf("shmat: %w", err)
}
unix.SysvShmCtl(id, unix.IPC_RMID, nil)
seg, err := shm.NewSegId(c.conn)
if err != nil {
unix.SysvShmDetach(addr)
return fmt.Errorf("new SHM seg: %w", err)
}
if err := shm.AttachChecked(c.conn, seg, uint32(id), false).Check(); err != nil {
unix.SysvShmDetach(addr)
return fmt.Errorf("SHM attach to X: %w", err)
}
c.shmID = id
c.shmAddr = addr
c.shmSeg = uint32(seg)
c.useSHM = true
return nil
}
func (c *X11Capturer) captureSHM() (*image.RGBA, error) {
cookie := shm.GetImage(c.conn, xproto.Drawable(c.screen.Root),
0, 0, uint16(c.w), uint16(c.h), 0xFFFFFFFF,
xproto.ImageFormatZPixmap, shm.Seg(c.shmSeg), 0)
_, err := cookie.Reply()
if err != nil {
return nil, fmt.Errorf("SHM GetImage: %w", err)
}
img := image.NewRGBA(image.Rect(0, 0, c.w, c.h))
n := c.w * c.h * 4
for i := 0; i < n; i += 4 {
img.Pix[i+0] = c.shmAddr[i+2] // R
img.Pix[i+1] = c.shmAddr[i+1] // G
img.Pix[i+2] = c.shmAddr[i+0] // B
img.Pix[i+3] = 0xff
}
return img, nil
}
func (c *X11Capturer) closeSHM() {
if c.useSHM {
shm.Detach(c.conn, shm.Seg(c.shmSeg))
unix.SysvShmDetach(c.shmAddr)
}
}

View File

@@ -0,0 +1,18 @@
//go:build freebsd
package server
import (
"fmt"
"image"
)
func (c *X11Capturer) initSHM() error {
return fmt.Errorf("SysV SHM not available on this platform")
}
func (c *X11Capturer) captureSHM() (*image.RGBA, error) {
return nil, fmt.Errorf("SHM capture not available on this platform")
}
func (c *X11Capturer) closeSHM() {}

View File

@@ -0,0 +1,403 @@
//go:build darwin && !ios
package server
import (
"fmt"
"os/exec"
"strings"
"sync"
"github.com/ebitengine/purego"
log "github.com/sirupsen/logrus"
)
// Core Graphics event constants.
const (
kCGEventSourceStateCombinedSessionState int32 = 0
kCGEventLeftMouseDown int32 = 1
kCGEventLeftMouseUp int32 = 2
kCGEventRightMouseDown int32 = 3
kCGEventRightMouseUp int32 = 4
kCGEventMouseMoved int32 = 5
kCGEventLeftMouseDragged int32 = 6
kCGEventRightMouseDragged int32 = 7
kCGEventKeyDown int32 = 10
kCGEventKeyUp int32 = 11
kCGEventOtherMouseDown int32 = 25
kCGEventOtherMouseUp int32 = 26
kCGMouseButtonLeft int32 = 0
kCGMouseButtonRight int32 = 1
kCGMouseButtonCenter int32 = 2
kCGHIDEventTap int32 = 0
)
var darwinInputOnce sync.Once
var (
cgEventSourceCreate func(int32) uintptr
cgEventCreateKeyboardEvent func(uintptr, uint16, bool) uintptr
// CGEventCreateMouseEvent takes CGPoint as two separate float64 args.
// purego can't handle array/struct types but individual float64s work.
cgEventCreateMouseEvent func(uintptr, int32, float64, float64, int32) uintptr
cgEventPost func(int32, uintptr)
// CGEventCreateScrollWheelEvent is variadic, call via SyscallN.
cgEventCreateScrollWheelEventAddr uintptr
darwinInputReady bool
darwinEventSource uintptr
)
func initDarwinInput() {
darwinInputOnce.Do(func() {
cg, err := purego.Dlopen("/System/Library/Frameworks/CoreGraphics.framework/CoreGraphics", purego.RTLD_NOW|purego.RTLD_GLOBAL)
if err != nil {
log.Debugf("load CoreGraphics for input: %v", err)
return
}
purego.RegisterLibFunc(&cgEventSourceCreate, cg, "CGEventSourceCreate")
purego.RegisterLibFunc(&cgEventCreateKeyboardEvent, cg, "CGEventCreateKeyboardEvent")
purego.RegisterLibFunc(&cgEventCreateMouseEvent, cg, "CGEventCreateMouseEvent")
purego.RegisterLibFunc(&cgEventPost, cg, "CGEventPost")
sym, err := purego.Dlsym(cg, "CGEventCreateScrollWheelEvent")
if err == nil {
cgEventCreateScrollWheelEventAddr = sym
}
darwinInputReady = true
})
}
func ensureEventSource() uintptr {
if darwinEventSource != 0 {
return darwinEventSource
}
darwinEventSource = cgEventSourceCreate(kCGEventSourceStateCombinedSessionState)
return darwinEventSource
}
// MacInputInjector injects keyboard and mouse events via Core Graphics.
type MacInputInjector struct {
lastButtons uint8
pbcopyPath string
pbpastePath string
}
// NewMacInputInjector creates a macOS input injector.
func NewMacInputInjector() (*MacInputInjector, error) {
initDarwinInput()
if !darwinInputReady {
return nil, fmt.Errorf("CoreGraphics not available for input injection")
}
checkMacPermissions()
m := &MacInputInjector{}
if path, err := exec.LookPath("pbcopy"); err == nil {
m.pbcopyPath = path
}
if path, err := exec.LookPath("pbpaste"); err == nil {
m.pbpastePath = path
}
if m.pbcopyPath == "" || m.pbpastePath == "" {
log.Debugf("clipboard tools not found (pbcopy=%q, pbpaste=%q)", m.pbcopyPath, m.pbpastePath)
}
log.Info("macOS input injector ready")
return m, nil
}
// checkMacPermissions logs warnings and triggers the Accessibility prompt.
// Screen Recording has no programmatic prompt, the user must grant it manually.
func checkMacPermissions() {
// Check Accessibility via osascript (triggers the system prompt dialog).
out, err := exec.Command("osascript", "-e",
`tell application "System Events" to return name of first process`).CombinedOutput()
if err != nil {
log.Warn("Accessibility permission not granted. Input injection will not work. " +
"Grant in System Settings > Privacy & Security > Accessibility.")
log.Debugf("accessibility check output: %s (%v)", strings.TrimSpace(string(out)), err)
}
log.Info("Screen Recording permission is required for screen capture. " +
"If the screen appears black, grant in System Settings > Privacy & Security > Screen Recording.")
}
// InjectKey simulates a key press or release.
func (m *MacInputInjector) InjectKey(keysym uint32, down bool) {
src := ensureEventSource()
if src == 0 {
return
}
keycode := keysymToMacKeycode(keysym)
if keycode == 0xFFFF {
return
}
event := cgEventCreateKeyboardEvent(src, keycode, down)
if event == 0 {
return
}
cgEventPost(kCGHIDEventTap, event)
cfRelease(event)
}
// InjectPointer simulates mouse movement and button events.
func (m *MacInputInjector) InjectPointer(buttonMask uint8, px, py, serverW, serverH int) {
if serverW == 0 || serverH == 0 {
return
}
src := ensureEventSource()
if src == 0 {
return
}
x := float64(px)
y := float64(py)
leftDown := buttonMask&0x01 != 0
rightDown := buttonMask&0x04 != 0
middleDown := buttonMask&0x02 != 0
scrollUp := buttonMask&0x08 != 0
scrollDown := buttonMask&0x10 != 0
wasLeft := m.lastButtons&0x01 != 0
wasRight := m.lastButtons&0x04 != 0
wasMiddle := m.lastButtons&0x02 != 0
if leftDown {
m.postMouse(src, kCGEventLeftMouseDragged, x, y, kCGMouseButtonLeft)
} else if rightDown {
m.postMouse(src, kCGEventRightMouseDragged, x, y, kCGMouseButtonRight)
} else {
m.postMouse(src, kCGEventMouseMoved, x, y, kCGMouseButtonLeft)
}
if leftDown && !wasLeft {
m.postMouse(src, kCGEventLeftMouseDown, x, y, kCGMouseButtonLeft)
} else if !leftDown && wasLeft {
m.postMouse(src, kCGEventLeftMouseUp, x, y, kCGMouseButtonLeft)
}
if rightDown && !wasRight {
m.postMouse(src, kCGEventRightMouseDown, x, y, kCGMouseButtonRight)
} else if !rightDown && wasRight {
m.postMouse(src, kCGEventRightMouseUp, x, y, kCGMouseButtonRight)
}
if middleDown && !wasMiddle {
m.postMouse(src, kCGEventOtherMouseDown, x, y, kCGMouseButtonCenter)
} else if !middleDown && wasMiddle {
m.postMouse(src, kCGEventOtherMouseUp, x, y, kCGMouseButtonCenter)
}
if scrollUp {
m.postScroll(src, 3)
}
if scrollDown {
m.postScroll(src, -3)
}
m.lastButtons = buttonMask
}
func (m *MacInputInjector) postMouse(src uintptr, eventType int32, x, y float64, button int32) {
if cgEventCreateMouseEvent == nil {
return
}
event := cgEventCreateMouseEvent(src, eventType, x, y, button)
if event == 0 {
return
}
cgEventPost(kCGHIDEventTap, event)
cfRelease(event)
}
func (m *MacInputInjector) postScroll(src uintptr, deltaY int32) {
if cgEventCreateScrollWheelEventAddr == 0 {
return
}
// CGEventCreateScrollWheelEvent(source, units, wheelCount, wheel1delta)
// units=0 (pixel), wheelCount=1, wheel1delta=deltaY
// Variadic C function: pass args as uintptr via SyscallN.
r1, _, _ := purego.SyscallN(cgEventCreateScrollWheelEventAddr,
src, 0, 1, uintptr(uint32(deltaY)))
if r1 == 0 {
return
}
cgEventPost(kCGHIDEventTap, r1)
cfRelease(r1)
}
// SetClipboard sets the macOS clipboard using pbcopy.
func (m *MacInputInjector) SetClipboard(text string) {
if m.pbcopyPath == "" {
return
}
cmd := exec.Command(m.pbcopyPath)
cmd.Stdin = strings.NewReader(text)
if err := cmd.Run(); err != nil {
log.Tracef("set clipboard via pbcopy: %v", err)
}
}
// GetClipboard reads the macOS clipboard using pbpaste.
func (m *MacInputInjector) GetClipboard() string {
if m.pbpastePath == "" {
return ""
}
out, err := exec.Command(m.pbpastePath).Output()
if err != nil {
log.Tracef("get clipboard via pbpaste: %v", err)
return ""
}
return string(out)
}
// Close is a no-op on macOS.
func (m *MacInputInjector) Close() {}
func keysymToMacKeycode(keysym uint32) uint16 {
if keysym >= 0x61 && keysym <= 0x7a {
return asciiToMacKey[keysym-0x61]
}
if keysym >= 0x41 && keysym <= 0x5a {
return asciiToMacKey[keysym-0x41]
}
if keysym >= 0x30 && keysym <= 0x39 {
return digitToMacKey[keysym-0x30]
}
if code, ok := specialKeyMap[keysym]; ok {
return code
}
return 0xFFFF
}
var asciiToMacKey = [26]uint16{
0x00, 0x0B, 0x08, 0x02, 0x0E, 0x03, 0x05, 0x04,
0x22, 0x26, 0x28, 0x25, 0x2E, 0x2D, 0x1F, 0x23,
0x0C, 0x0F, 0x01, 0x11, 0x20, 0x09, 0x0D, 0x07,
0x10, 0x06,
}
var digitToMacKey = [10]uint16{
0x1D, 0x12, 0x13, 0x14, 0x15, 0x17, 0x16, 0x1A, 0x1C, 0x19,
}
var specialKeyMap = map[uint32]uint16{
// Whitespace and editing
0x0020: 0x31, // space
0xff08: 0x33, // BackSpace
0xff09: 0x30, // Tab
0xff0d: 0x24, // Return
0xff1b: 0x35, // Escape
0xffff: 0x75, // Delete (forward)
// Navigation
0xff50: 0x73, // Home
0xff51: 0x7B, // Left
0xff52: 0x7E, // Up
0xff53: 0x7C, // Right
0xff54: 0x7D, // Down
0xff55: 0x74, // Page_Up
0xff56: 0x79, // Page_Down
0xff57: 0x77, // End
0xff63: 0x72, // Insert (Help on Mac)
// Modifiers
0xffe1: 0x38, // Shift_L
0xffe2: 0x3C, // Shift_R
0xffe3: 0x3B, // Control_L
0xffe4: 0x3E, // Control_R
0xffe5: 0x39, // Caps_Lock
0xffe9: 0x3A, // Alt_L (Option)
0xffea: 0x3D, // Alt_R (Option)
0xffe7: 0x37, // Meta_L (Command)
0xffe8: 0x36, // Meta_R (Command)
0xffeb: 0x37, // Super_L (Command) - noVNC sends this
0xffec: 0x36, // Super_R (Command)
// Mode_switch / ISO_Level3_Shift (sent by noVNC for macOS Option remap)
0xff7e: 0x3A, // Mode_switch -> Option
0xfe03: 0x3D, // ISO_Level3_Shift -> Right Option
// Function keys
0xffbe: 0x7A, // F1
0xffbf: 0x78, // F2
0xffc0: 0x63, // F3
0xffc1: 0x76, // F4
0xffc2: 0x60, // F5
0xffc3: 0x61, // F6
0xffc4: 0x62, // F7
0xffc5: 0x64, // F8
0xffc6: 0x65, // F9
0xffc7: 0x6D, // F10
0xffc8: 0x67, // F11
0xffc9: 0x6F, // F12
0xffca: 0x69, // F13
0xffcb: 0x6B, // F14
0xffcc: 0x71, // F15
0xffcd: 0x6A, // F16
0xffce: 0x40, // F17
0xffcf: 0x4F, // F18
0xffd0: 0x50, // F19
0xffd1: 0x5A, // F20
// Punctuation (US keyboard layout, keysym = ASCII code)
0x002d: 0x1B, // minus -
0x003d: 0x18, // equal =
0x005b: 0x21, // bracketleft [
0x005d: 0x1E, // bracketright ]
0x005c: 0x2A, // backslash
0x003b: 0x29, // semicolon ;
0x0027: 0x27, // apostrophe '
0x0060: 0x32, // grave `
0x002c: 0x2B, // comma ,
0x002e: 0x2F, // period .
0x002f: 0x2C, // slash /
// Shifted punctuation (noVNC sends these as separate keysyms)
0x005f: 0x1B, // underscore _ (shift+minus)
0x002b: 0x18, // plus + (shift+equal)
0x007b: 0x21, // braceleft { (shift+[)
0x007d: 0x1E, // braceright } (shift+])
0x007c: 0x2A, // bar | (shift+\)
0x003a: 0x29, // colon : (shift+;)
0x0022: 0x27, // quotedbl " (shift+')
0x007e: 0x32, // tilde ~ (shift+`)
0x003c: 0x2B, // less < (shift+,)
0x003e: 0x2F, // greater > (shift+.)
0x003f: 0x2C, // question ? (shift+/)
0x0021: 0x12, // exclam ! (shift+1)
0x0040: 0x13, // at @ (shift+2)
0x0023: 0x14, // numbersign # (shift+3)
0x0024: 0x15, // dollar $ (shift+4)
0x0025: 0x17, // percent % (shift+5)
0x005e: 0x16, // asciicircum ^ (shift+6)
0x0026: 0x1A, // ampersand & (shift+7)
0x002a: 0x1C, // asterisk * (shift+8)
0x0028: 0x19, // parenleft ( (shift+9)
0x0029: 0x1D, // parenright ) (shift+0)
// Numpad
0xffb0: 0x52, // KP_0
0xffb1: 0x53, // KP_1
0xffb2: 0x54, // KP_2
0xffb3: 0x55, // KP_3
0xffb4: 0x56, // KP_4
0xffb5: 0x57, // KP_5
0xffb6: 0x58, // KP_6
0xffb7: 0x59, // KP_7
0xffb8: 0x5B, // KP_8
0xffb9: 0x5C, // KP_9
0xffae: 0x41, // KP_Decimal
0xffaa: 0x43, // KP_Multiply
0xffab: 0x45, // KP_Add
0xffad: 0x4E, // KP_Subtract
0xffaf: 0x4B, // KP_Divide
0xff8d: 0x4C, // KP_Enter
0xffbd: 0x51, // KP_Equal
}
var _ InputInjector = (*MacInputInjector)(nil)

View File

@@ -0,0 +1,398 @@
//go:build windows
package server
import (
"runtime"
"unsafe"
log "github.com/sirupsen/logrus"
"golang.org/x/sys/windows"
)
var (
procOpenEventW = kernel32.NewProc("OpenEventW")
procSendInput = user32.NewProc("SendInput")
procVkKeyScanA = user32.NewProc("VkKeyScanA")
)
const eventModifyState = 0x0002
const (
inputMouse = 0
inputKeyboard = 1
mouseeventfMove = 0x0001
mouseeventfLeftDown = 0x0002
mouseeventfLeftUp = 0x0004
mouseeventfRightDown = 0x0008
mouseeventfRightUp = 0x0010
mouseeventfMiddleDown = 0x0020
mouseeventfMiddleUp = 0x0040
mouseeventfWheel = 0x0800
mouseeventfAbsolute = 0x8000
wheelDelta = 120
keyeventfKeyUp = 0x0002
keyeventfScanCode = 0x0008
)
type mouseInput struct {
Dx int32
Dy int32
MouseData uint32
DwFlags uint32
Time uint32
DwExtraInfo uintptr
}
type keybdInput struct {
WVk uint16
WScan uint16
DwFlags uint32
Time uint32
DwExtraInfo uintptr
_ [8]byte
}
type inputUnion [32]byte
type winInput struct {
Type uint32
_ [4]byte
Data inputUnion
}
func sendMouseInput(flags uint32, dx, dy int32, mouseData uint32) {
mi := mouseInput{
Dx: dx,
Dy: dy,
MouseData: mouseData,
DwFlags: flags,
}
inp := winInput{Type: inputMouse}
copy(inp.Data[:], (*[unsafe.Sizeof(mi)]byte)(unsafe.Pointer(&mi))[:])
r, _, err := procSendInput.Call(1, uintptr(unsafe.Pointer(&inp)), unsafe.Sizeof(inp))
if r == 0 {
log.Tracef("SendInput(mouse flags=0x%x): %v", flags, err)
}
}
func sendKeyInput(vk uint16, scanCode uint16, flags uint32) {
ki := keybdInput{
WVk: vk,
WScan: scanCode,
DwFlags: flags,
}
inp := winInput{Type: inputKeyboard}
copy(inp.Data[:], (*[unsafe.Sizeof(ki)]byte)(unsafe.Pointer(&ki))[:])
r, _, err := procSendInput.Call(1, uintptr(unsafe.Pointer(&inp)), unsafe.Sizeof(inp))
if r == 0 {
log.Tracef("SendInput(key vk=0x%x): %v", vk, err)
}
}
const sasEventName = `Global\NetBirdVNC_SAS`
type inputCmd struct {
isKey bool
keysym uint32
down bool
buttonMask uint8
x, y int
serverW int
serverH int
}
// WindowsInputInjector delivers input events from a dedicated OS thread that
// calls switchToInputDesktop before each injection. SendInput targets the
// calling thread's desktop, so the injection thread must be on the same
// desktop the user sees.
type WindowsInputInjector struct {
ch chan inputCmd
prevButtonMask uint8
ctrlDown bool
altDown bool
}
// NewWindowsInputInjector creates a desktop-aware input injector.
func NewWindowsInputInjector() *WindowsInputInjector {
w := &WindowsInputInjector{ch: make(chan inputCmd, 64)}
go w.loop()
return w
}
func (w *WindowsInputInjector) loop() {
runtime.LockOSThread()
for cmd := range w.ch {
// Switch to the current input desktop so SendInput reaches the right target.
switchToInputDesktop()
if cmd.isKey {
w.doInjectKey(cmd.keysym, cmd.down)
} else {
w.doInjectPointer(cmd.buttonMask, cmd.x, cmd.y, cmd.serverW, cmd.serverH)
}
}
}
// InjectKey queues a key event for injection on the input desktop thread.
func (w *WindowsInputInjector) InjectKey(keysym uint32, down bool) {
w.ch <- inputCmd{isKey: true, keysym: keysym, down: down}
}
// InjectPointer queues a pointer event for injection on the input desktop thread.
func (w *WindowsInputInjector) InjectPointer(buttonMask uint8, x, y, serverW, serverH int) {
w.ch <- inputCmd{buttonMask: buttonMask, x: x, y: y, serverW: serverW, serverH: serverH}
}
func (w *WindowsInputInjector) doInjectKey(keysym uint32, down bool) {
switch keysym {
case 0xffe3, 0xffe4:
w.ctrlDown = down
case 0xffe9, 0xffea:
w.altDown = down
}
if (keysym == 0xff9f || keysym == 0xffff) && w.ctrlDown && w.altDown && down {
signalSAS()
return
}
vk, _, extended := keysym2VK(keysym)
if vk == 0 {
return
}
var flags uint32
if !down {
flags |= keyeventfKeyUp
}
if extended {
flags |= keyeventfScanCode
}
sendKeyInput(vk, 0, flags)
}
// signalSAS signals the SAS named event. A listener in Session 0
// (startSASListener) calls SendSAS to trigger the Secure Attention Sequence.
func signalSAS() {
namePtr, err := windows.UTF16PtrFromString(sasEventName)
if err != nil {
log.Warnf("SAS UTF16: %v", err)
return
}
h, _, lerr := procOpenEventW.Call(
uintptr(eventModifyState),
0,
uintptr(unsafe.Pointer(namePtr)),
)
if h == 0 {
log.Warnf("OpenEvent(%s): %v", sasEventName, lerr)
return
}
ev := windows.Handle(h)
defer windows.CloseHandle(ev)
if err := windows.SetEvent(ev); err != nil {
log.Warnf("SetEvent SAS: %v", err)
} else {
log.Info("SAS event signaled")
}
}
func (w *WindowsInputInjector) doInjectPointer(buttonMask uint8, x, y, serverW, serverH int) {
if serverW == 0 || serverH == 0 {
return
}
absX := int32(x * 65535 / serverW)
absY := int32(y * 65535 / serverH)
sendMouseInput(mouseeventfMove|mouseeventfAbsolute, absX, absY, 0)
changed := buttonMask ^ w.prevButtonMask
w.prevButtonMask = buttonMask
type btnMap struct {
bit uint8
down uint32
up uint32
}
buttons := [...]btnMap{
{0x01, mouseeventfLeftDown, mouseeventfLeftUp},
{0x02, mouseeventfMiddleDown, mouseeventfMiddleUp},
{0x04, mouseeventfRightDown, mouseeventfRightUp},
}
for _, b := range buttons {
if changed&b.bit == 0 {
continue
}
var flags uint32
if buttonMask&b.bit != 0 {
flags = b.down
} else {
flags = b.up
}
sendMouseInput(flags|mouseeventfAbsolute, absX, absY, 0)
}
negWheelDelta := ^uint32(wheelDelta - 1)
if changed&0x08 != 0 && buttonMask&0x08 != 0 {
sendMouseInput(mouseeventfWheel|mouseeventfAbsolute, absX, absY, wheelDelta)
}
if changed&0x10 != 0 && buttonMask&0x10 != 0 {
sendMouseInput(mouseeventfWheel|mouseeventfAbsolute, absX, absY, negWheelDelta)
}
}
// keysym2VK converts an X11 keysym to a Windows virtual key code.
func keysym2VK(keysym uint32) (vk uint16, scan uint16, extended bool) {
if keysym >= 0x20 && keysym <= 0x7e {
r, _, _ := procVkKeyScanA.Call(uintptr(keysym))
vk = uint16(r & 0xff)
return
}
if keysym >= 0xffbe && keysym <= 0xffc9 {
vk = uint16(0x70 + keysym - 0xffbe)
return
}
switch keysym {
case 0xff08:
vk = 0x08 // Backspace
case 0xff09:
vk = 0x09 // Tab
case 0xff0d:
vk = 0x0d // Return
case 0xff1b:
vk = 0x1b // Escape
case 0xff63:
vk, extended = 0x2d, true // Insert
case 0xff9f, 0xffff:
vk, extended = 0x2e, true // Delete
case 0xff50:
vk, extended = 0x24, true // Home
case 0xff57:
vk, extended = 0x23, true // End
case 0xff55:
vk, extended = 0x21, true // PageUp
case 0xff56:
vk, extended = 0x22, true // PageDown
case 0xff51:
vk, extended = 0x25, true // Left
case 0xff52:
vk, extended = 0x26, true // Up
case 0xff53:
vk, extended = 0x27, true // Right
case 0xff54:
vk, extended = 0x28, true // Down
case 0xffe1, 0xffe2:
vk = 0x10 // Shift
case 0xffe3, 0xffe4:
vk = 0x11 // Control
case 0xffe9, 0xffea:
vk = 0x12 // Alt
case 0xffe5:
vk = 0x14 // CapsLock
case 0xffe7, 0xffeb:
vk, extended = 0x5B, true // Meta_L / Super_L -> Left Windows
case 0xffe8, 0xffec:
vk, extended = 0x5C, true // Meta_R / Super_R -> Right Windows
case 0xff61:
vk = 0x2c // PrintScreen
case 0xff13:
vk = 0x13 // Pause
case 0xff14:
vk = 0x91 // ScrollLock
}
return
}
var (
procOpenClipboard = user32.NewProc("OpenClipboard")
procCloseClipboard = user32.NewProc("CloseClipboard")
procEmptyClipboard = user32.NewProc("EmptyClipboard")
procSetClipboardData = user32.NewProc("SetClipboardData")
procGetClipboardData = user32.NewProc("GetClipboardData")
procIsClipboardFormatAvailable = user32.NewProc("IsClipboardFormatAvailable")
procGlobalAlloc = kernel32.NewProc("GlobalAlloc")
procGlobalLock = kernel32.NewProc("GlobalLock")
procGlobalUnlock = kernel32.NewProc("GlobalUnlock")
)
const (
cfUnicodeText = 13
gmemMoveable = 0x0002
)
// SetClipboard sets the Windows clipboard to the given UTF-8 text.
func (w *WindowsInputInjector) SetClipboard(text string) {
utf16, err := windows.UTF16FromString(text)
if err != nil {
log.Tracef("clipboard UTF16 encode: %v", err)
return
}
size := uintptr(len(utf16) * 2)
hMem, _, _ := procGlobalAlloc.Call(gmemMoveable, size)
if hMem == 0 {
log.Tracef("GlobalAlloc for clipboard: allocation returned nil")
return
}
ptr, _, _ := procGlobalLock.Call(hMem)
if ptr == 0 {
log.Tracef("GlobalLock for clipboard: lock returned nil")
return
}
copy(unsafe.Slice((*uint16)(unsafe.Pointer(ptr)), len(utf16)), utf16)
procGlobalUnlock.Call(hMem)
r, _, lerr := procOpenClipboard.Call(0)
if r == 0 {
log.Tracef("OpenClipboard: %v", lerr)
return
}
defer procCloseClipboard.Call()
procEmptyClipboard.Call()
r, _, lerr = procSetClipboardData.Call(cfUnicodeText, hMem)
if r == 0 {
log.Tracef("SetClipboardData: %v", lerr)
}
}
// GetClipboard reads the Windows clipboard as UTF-8 text.
func (w *WindowsInputInjector) GetClipboard() string {
r, _, _ := procIsClipboardFormatAvailable.Call(cfUnicodeText)
if r == 0 {
return ""
}
r, _, lerr := procOpenClipboard.Call(0)
if r == 0 {
log.Tracef("OpenClipboard for read: %v", lerr)
return ""
}
defer procCloseClipboard.Call()
hData, _, _ := procGetClipboardData.Call(cfUnicodeText)
if hData == 0 {
return ""
}
ptr, _, _ := procGlobalLock.Call(hData)
if ptr == 0 {
return ""
}
defer procGlobalUnlock.Call(hData)
return windows.UTF16PtrToString((*uint16)(unsafe.Pointer(ptr)))
}
var _ InputInjector = (*WindowsInputInjector)(nil)
var _ ScreenCapturer = (*DesktopCapturer)(nil)

View File

@@ -0,0 +1,242 @@
//go:build (linux && !android) || freebsd
package server
import (
"fmt"
"os"
"os/exec"
"strings"
log "github.com/sirupsen/logrus"
"github.com/jezek/xgb"
"github.com/jezek/xgb/xproto"
"github.com/jezek/xgb/xtest"
)
// X11InputInjector injects keyboard and mouse events via the XTest extension.
type X11InputInjector struct {
conn *xgb.Conn
root xproto.Window
screen *xproto.ScreenInfo
display string
keysymMap map[uint32]byte
lastButtons uint8
clipboardTool string
clipboardToolName string
}
// NewX11InputInjector connects to the X11 display and initializes XTest.
func NewX11InputInjector(display string) (*X11InputInjector, error) {
detectX11Display()
if display == "" {
display = os.Getenv("DISPLAY")
}
if display == "" {
return nil, fmt.Errorf("DISPLAY not set and no Xorg process found")
}
conn, err := xgb.NewConnDisplay(display)
if err != nil {
return nil, fmt.Errorf("connect to X11 display %s: %w", display, err)
}
if err := xtest.Init(conn); err != nil {
conn.Close()
return nil, fmt.Errorf("init XTest extension: %w", err)
}
setup := xproto.Setup(conn)
if len(setup.Roots) == 0 {
conn.Close()
return nil, fmt.Errorf("no X11 screens")
}
screen := setup.Roots[0]
inj := &X11InputInjector{
conn: conn,
root: screen.Root,
screen: &screen,
display: display,
}
inj.cacheKeyboardMapping()
inj.resolveClipboardTool()
log.Infof("X11 input injector ready (display=%s)", display)
return inj, nil
}
// InjectKey simulates a key press or release. keysym is an X11 KeySym.
func (x *X11InputInjector) InjectKey(keysym uint32, down bool) {
keycode := x.keysymToKeycode(keysym)
if keycode == 0 {
return
}
var eventType byte
if down {
eventType = xproto.KeyPress
} else {
eventType = xproto.KeyRelease
}
xtest.FakeInput(x.conn, eventType, keycode, 0, x.root, 0, 0, 0)
}
// InjectPointer simulates mouse movement and button events.
func (x *X11InputInjector) InjectPointer(buttonMask uint8, px, py, serverW, serverH int) {
if serverW == 0 || serverH == 0 {
return
}
// Scale to actual screen coordinates.
screenW := int(x.screen.WidthInPixels)
screenH := int(x.screen.HeightInPixels)
absX := px * screenW / serverW
absY := py * screenH / serverH
// Move pointer.
xtest.FakeInput(x.conn, xproto.MotionNotify, 0, 0, x.root, int16(absX), int16(absY), 0)
// Handle button events. RFB button mask: bit0=left, bit1=middle, bit2=right,
// bit3=scrollUp, bit4=scrollDown. X11 buttons: 1=left, 2=middle, 3=right,
// 4=scrollUp, 5=scrollDown.
type btnMap struct {
rfbBit uint8
x11Btn byte
}
buttons := [...]btnMap{
{0x01, 1}, // left
{0x02, 2}, // middle
{0x04, 3}, // right
{0x08, 4}, // scroll up
{0x10, 5}, // scroll down
}
for _, b := range buttons {
pressed := buttonMask&b.rfbBit != 0
wasPressed := x.lastButtons&b.rfbBit != 0
if b.x11Btn >= 4 {
// Scroll: send press+release on each scroll event.
if pressed {
xtest.FakeInput(x.conn, xproto.ButtonPress, b.x11Btn, 0, x.root, 0, 0, 0)
xtest.FakeInput(x.conn, xproto.ButtonRelease, b.x11Btn, 0, x.root, 0, 0, 0)
}
} else {
if pressed && !wasPressed {
xtest.FakeInput(x.conn, xproto.ButtonPress, b.x11Btn, 0, x.root, 0, 0, 0)
} else if !pressed && wasPressed {
xtest.FakeInput(x.conn, xproto.ButtonRelease, b.x11Btn, 0, x.root, 0, 0, 0)
}
}
}
x.lastButtons = buttonMask
}
// cacheKeyboardMapping fetches the X11 keyboard mapping once and stores it
// as a keysym-to-keycode map, avoiding a round-trip per keystroke.
func (x *X11InputInjector) cacheKeyboardMapping() {
setup := xproto.Setup(x.conn)
minKeycode := setup.MinKeycode
maxKeycode := setup.MaxKeycode
reply, err := xproto.GetKeyboardMapping(x.conn, minKeycode,
byte(maxKeycode-minKeycode+1)).Reply()
if err != nil {
log.Debugf("cache keyboard mapping: %v", err)
x.keysymMap = make(map[uint32]byte)
return
}
m := make(map[uint32]byte, int(maxKeycode-minKeycode+1)*int(reply.KeysymsPerKeycode))
keysymsPerKeycode := int(reply.KeysymsPerKeycode)
for i := int(minKeycode); i <= int(maxKeycode); i++ {
offset := (i - int(minKeycode)) * keysymsPerKeycode
for j := 0; j < keysymsPerKeycode; j++ {
ks := uint32(reply.Keysyms[offset+j])
if ks != 0 {
if _, exists := m[ks]; !exists {
m[ks] = byte(i)
}
}
}
}
x.keysymMap = m
}
// keysymToKeycode looks up a cached keysym-to-keycode mapping.
// Returns 0 if the keysym is not mapped.
func (x *X11InputInjector) keysymToKeycode(keysym uint32) byte {
return x.keysymMap[keysym]
}
// SetClipboard sets the X11 clipboard using xclip or xsel.
func (x *X11InputInjector) SetClipboard(text string) {
if x.clipboardTool == "" {
return
}
var cmd *exec.Cmd
if x.clipboardToolName == "xclip" {
cmd = exec.Command(x.clipboardTool, "-selection", "clipboard")
} else {
cmd = exec.Command(x.clipboardTool, "--clipboard", "--input")
}
cmd.Env = x.clipboardEnv()
cmd.Stdin = strings.NewReader(text)
if err := cmd.Run(); err != nil {
log.Debugf("set clipboard via %s: %v", x.clipboardToolName, err)
}
}
func (x *X11InputInjector) resolveClipboardTool() {
for _, name := range []string{"xclip", "xsel"} {
path, err := exec.LookPath(name)
if err == nil {
x.clipboardTool = path
x.clipboardToolName = name
log.Debugf("clipboard tool resolved to %s", path)
return
}
}
log.Debugf("no clipboard tool (xclip/xsel) found, clipboard sync disabled")
}
// GetClipboard reads the X11 clipboard using xclip or xsel.
func (x *X11InputInjector) GetClipboard() string {
if x.clipboardTool == "" {
return ""
}
var cmd *exec.Cmd
if x.clipboardToolName == "xclip" {
cmd = exec.Command(x.clipboardTool, "-selection", "clipboard", "-o")
} else {
cmd = exec.Command(x.clipboardTool, "--clipboard", "--output")
}
cmd.Env = x.clipboardEnv()
out, err := cmd.Output()
if err != nil {
log.Tracef("get clipboard via %s: %v", x.clipboardToolName, err)
return ""
}
return string(out)
}
func (x *X11InputInjector) clipboardEnv() []string {
env := []string{"DISPLAY=" + x.display}
if auth := os.Getenv("XAUTHORITY"); auth != "" {
env = append(env, "XAUTHORITY="+auth)
}
return env
}
// Close releases X11 resources.
func (x *X11InputInjector) Close() {
x.conn.Close()
}
var _ InputInjector = (*X11InputInjector)(nil)
var _ ScreenCapturer = (*X11Poller)(nil)

264
client/vnc/server/rfb.go Normal file
View File

@@ -0,0 +1,264 @@
package server
import (
"bytes"
"compress/zlib"
"crypto/des"
"encoding/binary"
"image"
)
const (
rfbProtocolVersion = "RFB 003.008\n"
secNone = 1
secVNCAuth = 2
// Client message types.
clientSetPixelFormat = 0
clientSetEncodings = 2
clientFramebufferUpdateRequest = 3
clientKeyEvent = 4
clientPointerEvent = 5
clientCutText = 6
// Server message types.
serverFramebufferUpdate = 0
serverCutText = 3
// Encoding types.
encRaw = 0
encZlib = 6
)
// serverPixelFormat is the default pixel format advertised by the server:
// 32bpp RGBA, big-endian, true-colour, 8 bits per channel.
var serverPixelFormat = [16]byte{
32, // bits-per-pixel
24, // depth
1, // big-endian-flag
1, // true-colour-flag
0, 255, // red-max
0, 255, // green-max
0, 255, // blue-max
16, // red-shift
8, // green-shift
0, // blue-shift
0, 0, 0, // padding
}
// clientPixelFormat holds the negotiated pixel format from the client.
type clientPixelFormat struct {
bpp uint8
bigEndian uint8
rMax uint16
gMax uint16
bMax uint16
rShift uint8
gShift uint8
bShift uint8
}
func defaultClientPixelFormat() clientPixelFormat {
return clientPixelFormat{
bpp: serverPixelFormat[0],
bigEndian: serverPixelFormat[2],
rMax: binary.BigEndian.Uint16(serverPixelFormat[4:6]),
gMax: binary.BigEndian.Uint16(serverPixelFormat[6:8]),
bMax: binary.BigEndian.Uint16(serverPixelFormat[8:10]),
rShift: serverPixelFormat[10],
gShift: serverPixelFormat[11],
bShift: serverPixelFormat[12],
}
}
func parsePixelFormat(pf []byte) clientPixelFormat {
return clientPixelFormat{
bpp: pf[0],
bigEndian: pf[2],
rMax: binary.BigEndian.Uint16(pf[4:6]),
gMax: binary.BigEndian.Uint16(pf[6:8]),
bMax: binary.BigEndian.Uint16(pf[8:10]),
rShift: pf[10],
gShift: pf[11],
bShift: pf[12],
}
}
// encodeRawRect encodes a framebuffer region as a raw RFB rectangle.
// The returned buffer includes the FramebufferUpdate header (1 rectangle).
func encodeRawRect(img *image.RGBA, pf clientPixelFormat, x, y, w, h int) []byte {
bytesPerPixel := max(int(pf.bpp)/8, 1)
pixelBytes := w * h * bytesPerPixel
buf := make([]byte, 4+12+pixelBytes)
// FramebufferUpdate header.
buf[0] = serverFramebufferUpdate
buf[1] = 0 // padding
binary.BigEndian.PutUint16(buf[2:4], 1)
// Rectangle header.
binary.BigEndian.PutUint16(buf[4:6], uint16(x))
binary.BigEndian.PutUint16(buf[6:8], uint16(y))
binary.BigEndian.PutUint16(buf[8:10], uint16(w))
binary.BigEndian.PutUint16(buf[10:12], uint16(h))
binary.BigEndian.PutUint32(buf[12:16], uint32(encRaw))
off := 16
stride := img.Stride
for row := y; row < y+h; row++ {
for col := x; col < x+w; col++ {
p := row*stride + col*4
r, g, b := img.Pix[p], img.Pix[p+1], img.Pix[p+2]
rv := uint32(r) * uint32(pf.rMax) / 255
gv := uint32(g) * uint32(pf.gMax) / 255
bv := uint32(b) * uint32(pf.bMax) / 255
pixel := (rv << pf.rShift) | (gv << pf.gShift) | (bv << pf.bShift)
if pf.bigEndian != 0 {
for i := range bytesPerPixel {
buf[off+i] = byte(pixel >> uint((bytesPerPixel-1-i)*8))
}
} else {
for i := range bytesPerPixel {
buf[off+i] = byte(pixel >> uint(i*8))
}
}
off += bytesPerPixel
}
}
return buf
}
// vncAuthEncrypt encrypts a 16-byte challenge using the VNC DES scheme.
func vncAuthEncrypt(challenge []byte, password string) []byte {
key := make([]byte, 8)
for i, c := range []byte(password) {
if i >= 8 {
break
}
key[i] = reverseBits(c)
}
block, _ := des.NewCipher(key)
out := make([]byte, 16)
block.Encrypt(out[:8], challenge[:8])
block.Encrypt(out[8:], challenge[8:])
return out
}
func reverseBits(b byte) byte {
var r byte
for range 8 {
r = (r << 1) | (b & 1)
b >>= 1
}
return r
}
// encodeZlibRect encodes a framebuffer region using Zlib compression.
// The zlib stream is continuous for the entire VNC session: noVNC creates
// one inflate context at startup and reuses it for all zlib-encoded rects.
// We must NOT reset the zlib writer between calls.
func encodeZlibRect(img *image.RGBA, pf clientPixelFormat, x, y, w, h int, zw *zlib.Writer, zbuf *bytes.Buffer) []byte {
bytesPerPixel := max(int(pf.bpp)/8, 1)
// Clear the output buffer but keep the deflate dictionary intact.
zbuf.Reset()
stride := img.Stride
pixel := make([]byte, bytesPerPixel)
for row := y; row < y+h; row++ {
for col := x; col < x+w; col++ {
p := row*stride + col*4
r, g, b := img.Pix[p], img.Pix[p+1], img.Pix[p+2]
rv := uint32(r) * uint32(pf.rMax) / 255
gv := uint32(g) * uint32(pf.gMax) / 255
bv := uint32(b) * uint32(pf.bMax) / 255
val := (rv << pf.rShift) | (gv << pf.gShift) | (bv << pf.bShift)
if pf.bigEndian != 0 {
for i := range bytesPerPixel {
pixel[i] = byte(val >> uint((bytesPerPixel-1-i)*8))
}
} else {
for i := range bytesPerPixel {
pixel[i] = byte(val >> uint(i*8))
}
}
zw.Write(pixel)
}
}
zw.Flush()
compressed := zbuf.Bytes()
// Build the FramebufferUpdate message.
buf := make([]byte, 4+12+4+len(compressed))
buf[0] = serverFramebufferUpdate
buf[1] = 0
binary.BigEndian.PutUint16(buf[2:4], 1) // 1 rectangle
binary.BigEndian.PutUint16(buf[4:6], uint16(x))
binary.BigEndian.PutUint16(buf[6:8], uint16(y))
binary.BigEndian.PutUint16(buf[8:10], uint16(w))
binary.BigEndian.PutUint16(buf[10:12], uint16(h))
binary.BigEndian.PutUint32(buf[12:16], uint32(encZlib))
binary.BigEndian.PutUint32(buf[16:20], uint32(len(compressed)))
copy(buf[20:], compressed)
return buf
}
// diffRects compares two RGBA images and returns a list of dirty rectangles.
// Divides the screen into tiles and checks each for changes.
func diffRects(prev, cur *image.RGBA, w, h, tileSize int) [][4]int {
if prev == nil {
return [][4]int{{0, 0, w, h}}
}
var rects [][4]int
for ty := 0; ty < h; ty += tileSize {
th := min(tileSize, h-ty)
for tx := 0; tx < w; tx += tileSize {
tw := min(tileSize, w-tx)
if tileChanged(prev, cur, tx, ty, tw, th) {
rects = append(rects, [4]int{tx, ty, tw, th})
}
}
}
return rects
}
func tileChanged(prev, cur *image.RGBA, x, y, w, h int) bool {
stride := prev.Stride
for row := y; row < y+h; row++ {
off := row*stride + x*4
end := off + w*4
prevRow := prev.Pix[off:end]
curRow := cur.Pix[off:end]
if !bytes.Equal(prevRow, curRow) {
return true
}
}
return false
}
// zlibState holds the persistent zlib writer and buffer for a session.
type zlibState struct {
buf *bytes.Buffer
w *zlib.Writer
}
func newZlibState() *zlibState {
buf := &bytes.Buffer{}
w, _ := zlib.NewWriterLevel(buf, zlib.BestSpeed)
return &zlibState{buf: buf, w: w}
}
func (z *zlibState) Close() error {
return z.w.Close()
}

605
client/vnc/server/server.go Normal file
View File

@@ -0,0 +1,605 @@
package server
import (
"context"
"crypto/subtle"
"encoding/binary"
"encoding/hex"
"fmt"
"image"
"io"
"net"
"net/netip"
"sync"
"time"
gojwt "github.com/golang-jwt/jwt/v5"
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/tun/netstack"
sshauth "github.com/netbirdio/netbird/client/ssh/auth"
nbjwt "github.com/netbirdio/netbird/shared/auth/jwt"
)
// Connection modes sent by the client in the session header.
const (
ModeAttach byte = 0 // Capture current display
ModeSession byte = 1 // Virtual session as specified user
)
// ScreenCapturer grabs desktop frames for the VNC server.
type ScreenCapturer interface {
// Width returns the current screen width in pixels.
Width() int
// Height returns the current screen height in pixels.
Height() int
// Capture returns the current desktop as an RGBA image.
Capture() (*image.RGBA, error)
}
// InputInjector delivers keyboard and mouse events to the OS.
type InputInjector interface {
// InjectKey simulates a key press or release. keysym is an X11 KeySym.
InjectKey(keysym uint32, down bool)
// InjectPointer simulates mouse movement and button state.
InjectPointer(buttonMask uint8, x, y, serverW, serverH int)
// SetClipboard sets the system clipboard to the given text.
SetClipboard(text string)
// GetClipboard returns the current system clipboard text.
GetClipboard() string
}
// JWTConfig holds JWT validation configuration for VNC auth.
type JWTConfig struct {
Issuer string
KeysLocation string
MaxTokenAge int64
Audiences []string
}
// connectionHeader is sent by the client before the RFB handshake to specify
// the VNC session mode and authenticate.
type connectionHeader struct {
mode byte
username string
jwt string
sessionID uint32 // Windows session ID (0 = console/auto)
}
// Server is the embedded VNC server that listens on the WireGuard interface.
// It supports two operating modes:
// - Direct mode: captures the screen and handles VNC sessions in-process.
// Used when running in a user session with desktop access.
// - Service mode: proxies VNC connections to an agent process spawned in
// the active console session. Used when running as a Windows service in
// Session 0.
//
// Within direct mode, each connection can request one of two session modes
// via the connection header:
// - Attach: capture the current physical display.
// - Session: start a virtual Xvfb display as the requested user.
type Server struct {
capturer ScreenCapturer
injector InputInjector
password string
serviceMode bool
disableAuth bool
localAddr netip.Addr // NetBird WireGuard IP this server is bound to
network netip.Prefix // NetBird overlay network
log *log.Entry
mu sync.Mutex
listener net.Listener
ctx context.Context
cancel context.CancelFunc
vmgr virtualSessionManager
jwtConfig *JWTConfig
jwtValidator *nbjwt.Validator
jwtExtractor *nbjwt.ClaimsExtractor
authorizer *sshauth.Authorizer
netstackNet *netstack.Net
agentToken []byte // raw token bytes for agent-mode auth
}
// vncSession provides capturer and injector for a virtual display session.
type vncSession interface {
Capturer() ScreenCapturer
Injector() InputInjector
Display() string
ClientConnect()
ClientDisconnect()
}
// virtualSessionManager is implemented by sessionManager on Linux.
type virtualSessionManager interface {
GetOrCreate(username string) (vncSession, error)
StopAll()
}
// New creates a VNC server with the given screen capturer and input injector.
func New(capturer ScreenCapturer, injector InputInjector, password string) *Server {
return &Server{
capturer: capturer,
injector: injector,
password: password,
authorizer: sshauth.NewAuthorizer(),
log: log.WithField("component", "vnc-server"),
}
}
// SetServiceMode enables proxy-to-agent mode for Windows service operation.
func (s *Server) SetServiceMode(enabled bool) {
s.serviceMode = enabled
}
// SetJWTConfig configures JWT authentication for VNC connections.
// Pass nil to disable JWT (public mode).
func (s *Server) SetJWTConfig(config *JWTConfig) {
s.mu.Lock()
defer s.mu.Unlock()
s.jwtConfig = config
s.jwtValidator = nil
s.jwtExtractor = nil
}
// SetDisableAuth disables authentication entirely.
func (s *Server) SetDisableAuth(disable bool) {
s.disableAuth = disable
}
// SetAgentToken sets a hex-encoded token that must be presented by incoming
// connections before any VNC data. Used in agent mode to verify that only the
// trusted service process connects.
func (s *Server) SetAgentToken(hexToken string) {
if hexToken == "" {
return
}
b, err := hex.DecodeString(hexToken)
if err != nil {
s.log.Warnf("invalid agent token: %v", err)
return
}
s.agentToken = b
}
// SetNetstackNet sets the netstack network for userspace-only listening.
// When set, the VNC server listens via netstack instead of a real OS socket.
func (s *Server) SetNetstackNet(n *netstack.Net) {
s.mu.Lock()
defer s.mu.Unlock()
s.netstackNet = n
}
// UpdateVNCAuth updates the fine-grained authorization configuration.
func (s *Server) UpdateVNCAuth(config *sshauth.Config) {
s.mu.Lock()
defer s.mu.Unlock()
s.jwtValidator = nil
s.jwtExtractor = nil
s.authorizer.Update(config)
}
// Start begins listening for VNC connections on the given address.
// network is the NetBird overlay prefix used to validate connection sources.
func (s *Server) Start(ctx context.Context, addr netip.AddrPort, network netip.Prefix) error {
s.mu.Lock()
defer s.mu.Unlock()
if s.listener != nil {
return fmt.Errorf("server already running")
}
s.ctx, s.cancel = context.WithCancel(ctx)
s.vmgr = s.platformSessionManager()
s.localAddr = addr.Addr()
s.network = network
var listener net.Listener
var listenDesc string
if s.netstackNet != nil {
ln, err := s.netstackNet.ListenTCPAddrPort(addr)
if err != nil {
return fmt.Errorf("listen on netstack %s: %w", addr, err)
}
listener = ln
listenDesc = fmt.Sprintf("netstack %s", addr)
} else {
tcpAddr := net.TCPAddrFromAddrPort(addr)
ln, err := net.ListenTCP("tcp", tcpAddr)
if err != nil {
return fmt.Errorf("listen on %s: %w", addr, err)
}
listener = ln
listenDesc = addr.String()
}
s.listener = listener
if s.serviceMode {
s.platformInit()
}
if s.serviceMode {
go s.serviceAcceptLoop()
} else {
go s.acceptLoop()
}
s.log.Infof("started on %s (service_mode=%v)", listenDesc, s.serviceMode)
return nil
}
// Stop shuts down the server and closes all connections.
func (s *Server) Stop() error {
s.mu.Lock()
defer s.mu.Unlock()
if s.cancel != nil {
s.cancel()
s.cancel = nil
}
if s.vmgr != nil {
s.vmgr.StopAll()
}
if c, ok := s.capturer.(interface{ Close() }); ok {
c.Close()
}
if s.listener != nil {
err := s.listener.Close()
s.listener = nil
if err != nil {
return fmt.Errorf("close VNC listener: %w", err)
}
}
s.log.Info("stopped")
return nil
}
// acceptLoop handles VNC connections directly (user session mode).
func (s *Server) acceptLoop() {
for {
conn, err := s.listener.Accept()
if err != nil {
select {
case <-s.ctx.Done():
return
default:
}
s.log.Debugf("accept VNC connection: %v", err)
continue
}
go s.handleConnection(conn)
}
}
func (s *Server) validateCapturer(cap ScreenCapturer) error {
// Quick check first: if already ready, return immediately.
if cap.Width() > 0 && cap.Height() > 0 {
return nil
}
// Wait up to 5s for the capturer to become ready.
for range 50 {
time.Sleep(100 * time.Millisecond)
if cap.Width() > 0 && cap.Height() > 0 {
return nil
}
}
return fmt.Errorf("no display available (X11 not running or screen recording not permitted)")
}
// isAllowedSource rejects connections from outside the NetBird overlay network
// and from the local WireGuard IP (prevents local privilege escalation).
// Matches the SSH server's connectionValidator logic.
func (s *Server) isAllowedSource(addr net.Addr) bool {
tcpAddr, ok := addr.(*net.TCPAddr)
if !ok {
s.log.Warnf("connection rejected: non-TCP address %s", addr)
return false
}
remoteIP, ok := netip.AddrFromSlice(tcpAddr.IP)
if !ok {
s.log.Warnf("connection rejected: invalid remote IP %s", tcpAddr.IP)
return false
}
remoteIP = remoteIP.Unmap()
if remoteIP.IsLoopback() && s.localAddr.IsLoopback() {
return true
}
if remoteIP == s.localAddr {
s.log.Warnf("connection rejected from own IP %s", remoteIP)
return false
}
if s.network.IsValid() && !s.network.Contains(remoteIP) {
s.log.Warnf("connection rejected from non-NetBird IP %s", remoteIP)
return false
}
return true
}
func (s *Server) handleConnection(conn net.Conn) {
connLog := s.log.WithField("remote", conn.RemoteAddr().String())
if !s.isAllowedSource(conn.RemoteAddr()) {
conn.Close()
return
}
if len(s.agentToken) > 0 {
buf := make([]byte, len(s.agentToken))
if err := conn.SetReadDeadline(time.Now().Add(5 * time.Second)); err != nil {
connLog.Debugf("set agent token deadline: %v", err)
conn.Close()
return
}
if _, err := io.ReadFull(conn, buf); err != nil {
connLog.Warnf("agent auth: read token: %v", err)
conn.Close()
return
}
conn.SetReadDeadline(time.Time{}) //nolint:errcheck
if subtle.ConstantTimeCompare(buf, s.agentToken) != 1 {
connLog.Warn("agent auth: invalid token, rejecting")
conn.Close()
return
}
}
header, err := readConnectionHeader(conn)
if err != nil {
connLog.Warnf("read connection header: %v", err)
conn.Close()
return
}
if !s.disableAuth {
if s.jwtConfig == nil {
rejectConnection(conn, "auth enabled but no identity provider configured")
connLog.Warn("auth rejected: no identity provider configured")
return
}
jwtUserID, err := s.authenticateJWT(header)
if err != nil {
rejectConnection(conn, fmt.Sprintf("auth: %v", err))
connLog.Warnf("auth rejected: %v", err)
return
}
connLog = connLog.WithField("jwt_user", jwtUserID)
}
var capturer ScreenCapturer
var injector InputInjector
switch header.mode {
case ModeSession:
if s.vmgr == nil {
rejectConnection(conn, "virtual sessions not supported on this platform")
connLog.Warn("session rejected: not supported on this platform")
return
}
if header.username == "" {
rejectConnection(conn, "session mode requires a username")
connLog.Warn("session rejected: no username provided")
return
}
vs, err := s.vmgr.GetOrCreate(header.username)
if err != nil {
rejectConnection(conn, fmt.Sprintf("create virtual session: %v", err))
connLog.Warnf("create virtual session for %s: %v", header.username, err)
return
}
capturer = vs.Capturer()
injector = vs.Injector()
vs.ClientConnect()
defer vs.ClientDisconnect()
connLog = connLog.WithField("vnc_user", header.username)
connLog.Infof("session mode: user=%s display=%s", header.username, vs.Display())
default:
capturer = s.capturer
injector = s.injector
if cc, ok := capturer.(interface{ ClientConnect() }); ok {
cc.ClientConnect()
}
defer func() {
if cd, ok := capturer.(interface{ ClientDisconnect() }); ok {
cd.ClientDisconnect()
}
}()
}
if err := s.validateCapturer(capturer); err != nil {
rejectConnection(conn, fmt.Sprintf("screen capturer: %v", err))
connLog.Warnf("capturer not ready: %v", err)
return
}
sess := &session{
conn: conn,
capturer: capturer,
injector: injector,
serverW: capturer.Width(),
serverH: capturer.Height(),
password: s.password,
log: connLog,
}
sess.serve()
}
// rejectConnection sends a minimal RFB handshake with a security failure
// reason, so VNC clients display the error message instead of a generic
// "unexpected disconnect."
func rejectConnection(conn net.Conn, reason string) {
defer conn.Close()
// RFB 3.8 server version.
io.WriteString(conn, "RFB 003.008\n")
// Read client version (12 bytes), ignore errors.
var clientVer [12]byte
conn.SetReadDeadline(time.Now().Add(2 * time.Second))
io.ReadFull(conn, clientVer[:])
conn.SetReadDeadline(time.Time{})
// Send 0 security types = connection failed, followed by reason.
msg := []byte(reason)
buf := make([]byte, 1+4+len(msg))
buf[0] = 0 // 0 security types = failure
binary.BigEndian.PutUint32(buf[1:5], uint32(len(msg)))
copy(buf[5:], msg)
conn.Write(buf)
}
const defaultJWTMaxTokenAge = 10 * 60 // 10 minutes
// authenticateJWT validates the JWT from the connection header and checks
// authorization. For attach mode, just checks membership in the authorized
// user list. For session mode, additionally validates the OS user mapping.
func (s *Server) authenticateJWT(header *connectionHeader) (string, error) {
if header.jwt == "" {
return "", fmt.Errorf("JWT required but not provided")
}
s.mu.Lock()
if err := s.ensureJWTValidator(); err != nil {
s.mu.Unlock()
return "", fmt.Errorf("initialize JWT validator: %w", err)
}
validator := s.jwtValidator
extractor := s.jwtExtractor
s.mu.Unlock()
token, err := validator.ValidateAndParse(context.Background(), header.jwt)
if err != nil {
return "", fmt.Errorf("validate JWT: %w", err)
}
if err := s.checkTokenAge(token); err != nil {
return "", err
}
userAuth, err := extractor.ToUserAuth(token)
if err != nil {
return "", fmt.Errorf("extract user from JWT: %w", err)
}
if userAuth.UserId == "" {
return "", fmt.Errorf("JWT has no user ID")
}
switch header.mode {
case ModeSession:
// Session mode: check user + OS username mapping.
if _, err := s.authorizer.Authorize(userAuth.UserId, header.username); err != nil {
return "", fmt.Errorf("authorize session for %s: %w", header.username, err)
}
default:
// Attach mode: just check user is in the authorized list (wildcard OS user).
if _, err := s.authorizer.Authorize(userAuth.UserId, "*"); err != nil {
return "", fmt.Errorf("user not authorized for VNC: %w", err)
}
}
return userAuth.UserId, nil
}
// ensureJWTValidator lazily initializes the JWT validator. Must be called with mu held.
func (s *Server) ensureJWTValidator() error {
if s.jwtValidator != nil && s.jwtExtractor != nil {
return nil
}
if s.jwtConfig == nil {
return fmt.Errorf("no JWT config")
}
s.jwtValidator = nbjwt.NewValidator(
s.jwtConfig.Issuer,
s.jwtConfig.Audiences,
s.jwtConfig.KeysLocation,
false,
)
opts := []nbjwt.ClaimsExtractorOption{nbjwt.WithAudience(s.jwtConfig.Audiences[0])}
if claim := s.authorizer.GetUserIDClaim(); claim != "" {
opts = append(opts, nbjwt.WithUserIDClaim(claim))
}
s.jwtExtractor = nbjwt.NewClaimsExtractor(opts...)
return nil
}
func (s *Server) checkTokenAge(token *gojwt.Token) error {
maxAge := defaultJWTMaxTokenAge
if s.jwtConfig != nil && s.jwtConfig.MaxTokenAge > 0 {
maxAge = int(s.jwtConfig.MaxTokenAge)
}
return nbjwt.CheckTokenAge(token, time.Duration(maxAge)*time.Second)
}
// readConnectionHeader reads the NetBird VNC session header from the connection.
// Format: [mode: 1 byte] [username_len: 2 bytes BE] [username: N bytes]
//
// [jwt_len: 2 bytes BE] [jwt: N bytes]
//
// Uses a short timeout: our WASM proxy sends the header immediately after
// connecting. Standard VNC clients don't send anything first (server speaks
// first in RFB), so they time out and get the default attach mode.
func readConnectionHeader(conn net.Conn) (*connectionHeader, error) {
if err := conn.SetReadDeadline(time.Now().Add(2 * time.Second)); err != nil {
return nil, fmt.Errorf("set deadline: %w", err)
}
defer conn.SetReadDeadline(time.Time{}) //nolint:errcheck
var hdr [3]byte
if _, err := io.ReadFull(conn, hdr[:]); err != nil {
// Timeout or error: assume no header, use attach mode.
return &connectionHeader{mode: ModeAttach}, nil
}
// Restore a longer deadline for reading variable-length fields.
if err := conn.SetReadDeadline(time.Now().Add(5 * time.Second)); err != nil {
return nil, fmt.Errorf("set deadline: %w", err)
}
mode := hdr[0]
usernameLen := binary.BigEndian.Uint16(hdr[1:3])
var username string
if usernameLen > 0 {
if usernameLen > 256 {
return nil, fmt.Errorf("username too long: %d", usernameLen)
}
buf := make([]byte, usernameLen)
if _, err := io.ReadFull(conn, buf); err != nil {
return nil, fmt.Errorf("read username: %w", err)
}
username = string(buf)
}
// Read JWT token length and data.
var jwtLenBuf [2]byte
var jwtToken string
if _, err := io.ReadFull(conn, jwtLenBuf[:]); err == nil {
jwtLen := binary.BigEndian.Uint16(jwtLenBuf[:])
if jwtLen > 0 && jwtLen < 8192 {
buf := make([]byte, jwtLen)
if _, err := io.ReadFull(conn, buf); err != nil {
return nil, fmt.Errorf("read JWT: %w", err)
}
jwtToken = string(buf)
}
}
// Read optional Windows session ID (4 bytes BE). Missing = 0 (console/auto).
var sessionID uint32
var sidBuf [4]byte
if _, err := io.ReadFull(conn, sidBuf[:]); err == nil {
sessionID = binary.BigEndian.Uint32(sidBuf[:])
}
return &connectionHeader{mode: mode, username: username, jwt: jwtToken, sessionID: sessionID}, nil
}

View File

@@ -0,0 +1,15 @@
//go:build darwin && !ios
package server
func (s *Server) platformInit() {}
// serviceAcceptLoop is not supported on macOS.
func (s *Server) serviceAcceptLoop() {
s.log.Warn("service mode not supported on macOS, falling back to direct mode")
s.acceptLoop()
}
func (s *Server) platformSessionManager() virtualSessionManager {
return nil
}

View File

@@ -0,0 +1,15 @@
//go:build !windows && !darwin && !freebsd && !(linux && !android)
package server
func (s *Server) platformInit() {}
// serviceAcceptLoop is not supported on non-Windows platforms.
func (s *Server) serviceAcceptLoop() {
s.log.Warn("service mode not supported on this platform, falling back to direct mode")
s.acceptLoop()
}
func (s *Server) platformSessionManager() virtualSessionManager {
return nil
}

View File

@@ -0,0 +1,136 @@
package server
import (
"encoding/binary"
"image"
"io"
"net"
"net/netip"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// testCapturer returns a 100x100 image for test sessions.
type testCapturer struct{}
func (t *testCapturer) Width() int { return 100 }
func (t *testCapturer) Height() int { return 100 }
func (t *testCapturer) Capture() (*image.RGBA, error) { return image.NewRGBA(image.Rect(0, 0, 100, 100)), nil }
func startTestServer(t *testing.T, disableAuth bool, jwtConfig *JWTConfig) (net.Addr, *Server) {
t.Helper()
srv := New(&testCapturer{}, &StubInputInjector{}, "")
srv.SetDisableAuth(disableAuth)
if jwtConfig != nil {
srv.SetJWTConfig(jwtConfig)
}
addr := netip.MustParseAddrPort("127.0.0.1:0")
network := netip.MustParsePrefix("127.0.0.0/8")
require.NoError(t, srv.Start(t.Context(), addr, network))
// Override local address so source validation doesn't reject 127.0.0.1 as "own IP".
srv.localAddr = netip.MustParseAddr("10.99.99.1")
t.Cleanup(func() { _ = srv.Stop() })
return srv.listener.Addr(), srv
}
func TestAuthEnabled_NoJWTConfig_RejectsConnection(t *testing.T) {
addr, _ := startTestServer(t, false, nil)
conn, err := net.Dial("tcp", addr.String())
require.NoError(t, err)
defer conn.Close()
// Send session header: attach mode, no username, no JWT.
header := []byte{ModeAttach, 0, 0, 0, 0}
_, err = conn.Write(header)
require.NoError(t, err)
// Server should send RFB version then security failure.
var version [12]byte
_, err = io.ReadFull(conn, version[:])
require.NoError(t, err)
assert.Equal(t, "RFB 003.008\n", string(version[:]))
// Write client version to proceed through handshake.
_, err = conn.Write(version[:])
require.NoError(t, err)
// Read security types: 0 means failure, followed by reason.
var numTypes [1]byte
_, err = io.ReadFull(conn, numTypes[:])
require.NoError(t, err)
assert.Equal(t, byte(0), numTypes[0], "should have 0 security types (failure)")
var reasonLen [4]byte
_, err = io.ReadFull(conn, reasonLen[:])
require.NoError(t, err)
reason := make([]byte, binary.BigEndian.Uint32(reasonLen[:]))
_, err = io.ReadFull(conn, reason)
require.NoError(t, err)
assert.Contains(t, string(reason), "identity provider", "rejection reason should mention missing IdP config")
}
func TestAuthDisabled_AllowsConnection(t *testing.T) {
addr, _ := startTestServer(t, true, nil)
conn, err := net.Dial("tcp", addr.String())
require.NoError(t, err)
defer conn.Close()
// Send session header: attach mode, no username, no JWT.
header := []byte{ModeAttach, 0, 0, 0, 0}
_, err = conn.Write(header)
require.NoError(t, err)
// Server should send RFB version.
var version [12]byte
_, err = io.ReadFull(conn, version[:])
require.NoError(t, err)
assert.Equal(t, "RFB 003.008\n", string(version[:]))
// Write client version.
_, err = conn.Write(version[:])
require.NoError(t, err)
// Should get security types (not 0 = failure).
var numTypes [1]byte
_, err = io.ReadFull(conn, numTypes[:])
require.NoError(t, err)
assert.NotEqual(t, byte(0), numTypes[0], "should have at least one security type (auth disabled)")
}
func TestAuthEnabled_EmptyJWT_Rejected(t *testing.T) {
// Auth enabled with a (bogus) JWT config: connections without JWT should be rejected.
addr, _ := startTestServer(t, false, &JWTConfig{
Issuer: "https://example.com",
KeysLocation: "https://example.com/.well-known/jwks.json",
Audiences: []string{"test"},
})
conn, err := net.Dial("tcp", addr.String())
require.NoError(t, err)
defer conn.Close()
// Send session header with empty JWT.
header := []byte{ModeAttach, 0, 0, 0, 0}
_, err = conn.Write(header)
require.NoError(t, err)
var version [12]byte
_, err = io.ReadFull(conn, version[:])
require.NoError(t, err)
_, err = conn.Write(version[:])
require.NoError(t, err)
var numTypes [1]byte
_, err = io.ReadFull(conn, numTypes[:])
require.NoError(t, err)
assert.Equal(t, byte(0), numTypes[0], "should reject with 0 security types")
}

View File

@@ -0,0 +1,223 @@
//go:build windows
package server
import (
"bytes"
"fmt"
"io"
"net"
"unsafe"
log "github.com/sirupsen/logrus"
"golang.org/x/sys/windows"
"golang.org/x/sys/windows/registry"
)
var (
sasDLL = windows.NewLazySystemDLL("sas.dll")
procSendSAS = sasDLL.NewProc("SendSAS")
procConvertStringSecurityDescriptorToSecurityDescriptor = advapi32.NewProc("ConvertStringSecurityDescriptorToSecurityDescriptorW")
)
// sasSecurityAttributes builds a SECURITY_ATTRIBUTES that grants
// EVENT_MODIFY_STATE only to the SYSTEM account, preventing unprivileged
// local processes from triggering the Secure Attention Sequence.
func sasSecurityAttributes() (*windows.SecurityAttributes, error) {
// SDDL: grant full access to SYSTEM (creates/waits) and EVENT_MODIFY_STATE
// to the interactive user (IU) so the VNC agent in the console session can
// signal it. Other local users and network users are denied.
sddl, err := windows.UTF16PtrFromString("D:(A;;GA;;;SY)(A;;0x0002;;;IU)")
if err != nil {
return nil, err
}
var sd uintptr
r, _, lerr := procConvertStringSecurityDescriptorToSecurityDescriptor.Call(
uintptr(unsafe.Pointer(sddl)),
1, // SDDL_REVISION_1
uintptr(unsafe.Pointer(&sd)),
0,
)
if r == 0 {
return nil, lerr
}
return &windows.SecurityAttributes{
Length: uint32(unsafe.Sizeof(windows.SecurityAttributes{})),
SecurityDescriptor: (*windows.SECURITY_DESCRIPTOR)(unsafe.Pointer(sd)),
InheritHandle: 0,
}, nil
}
// enableSoftwareSAS sets the SoftwareSASGeneration registry key to allow
// services to trigger the Secure Attention Sequence via SendSAS. Without this,
// SendSAS silently does nothing on most Windows editions.
func enableSoftwareSAS() {
key, _, err := registry.CreateKey(
registry.LOCAL_MACHINE,
`SOFTWARE\Microsoft\Windows\CurrentVersion\Policies\System`,
registry.SET_VALUE,
)
if err != nil {
log.Warnf("open SoftwareSASGeneration registry key: %v", err)
return
}
defer key.Close()
if err := key.SetDWordValue("SoftwareSASGeneration", 1); err != nil {
log.Warnf("set SoftwareSASGeneration: %v", err)
return
}
log.Debug("SoftwareSASGeneration registry key set to 1 (services allowed)")
}
// startSASListener creates a named event with a restricted DACL and waits for
// the VNC input injector to signal it. When signaled, it calls SendSAS(FALSE)
// from Session 0 to trigger the Secure Attention Sequence (Ctrl+Alt+Del).
// Only SYSTEM processes can open the event.
func startSASListener() {
enableSoftwareSAS()
namePtr, err := windows.UTF16PtrFromString(sasEventName)
if err != nil {
log.Warnf("SAS listener UTF16: %v", err)
return
}
sa, err := sasSecurityAttributes()
if err != nil {
log.Warnf("build SAS security descriptor: %v", err)
return
}
ev, err := windows.CreateEvent(sa, 0, 0, namePtr)
if err != nil {
log.Warnf("SAS CreateEvent: %v", err)
return
}
log.Info("SAS listener ready (Session 0)")
go func() {
defer windows.CloseHandle(ev)
for {
ret, _ := windows.WaitForSingleObject(ev, windows.INFINITE)
if ret == windows.WAIT_OBJECT_0 {
r, _, sasErr := procSendSAS.Call(0) // FALSE = not from service desktop
if r == 0 {
log.Warnf("SendSAS: %v", sasErr)
} else {
log.Info("SendSAS called from Session 0")
}
}
}
}()
}
// enablePrivilege enables a named privilege on the current process token.
func enablePrivilege(name string) error {
var token windows.Token
if err := windows.OpenProcessToken(windows.CurrentProcess(),
windows.TOKEN_ADJUST_PRIVILEGES|windows.TOKEN_QUERY, &token); err != nil {
return err
}
defer token.Close()
var luid windows.LUID
namePtr, _ := windows.UTF16PtrFromString(name)
if err := windows.LookupPrivilegeValue(nil, namePtr, &luid); err != nil {
return err
}
tp := windows.Tokenprivileges{PrivilegeCount: 1}
tp.Privileges[0].Luid = luid
tp.Privileges[0].Attributes = windows.SE_PRIVILEGE_ENABLED
return windows.AdjustTokenPrivileges(token, false, &tp, 0, nil, nil)
}
func (s *Server) platformSessionManager() virtualSessionManager {
return nil
}
// platformInit starts the SAS listener and enables privileges needed for
// Session 0 operations (agent spawning, SendSAS).
func (s *Server) platformInit() {
for _, priv := range []string{"SeTcbPrivilege", "SeAssignPrimaryTokenPrivilege"} {
if err := enablePrivilege(priv); err != nil {
log.Debugf("enable %s: %v", priv, err)
}
}
startSASListener()
}
// serviceAcceptLoop runs in Session 0. It validates source IP and
// authenticates via JWT before proxying connections to the user-session agent.
func (s *Server) serviceAcceptLoop() {
sm := newSessionManager(agentPort)
go sm.run()
log.Infof("service mode, proxying connections to agent on 127.0.0.1:%s", agentPort)
for {
conn, err := s.listener.Accept()
if err != nil {
select {
case <-s.ctx.Done():
sm.Stop()
return
default:
}
s.log.Debugf("accept VNC connection: %v", err)
continue
}
go s.handleServiceConnection(conn, sm)
}
}
// handleServiceConnection validates the source IP and JWT, then proxies
// the connection (with header bytes replayed) to the agent.
func (s *Server) handleServiceConnection(conn net.Conn, sm *sessionManager) {
connLog := s.log.WithField("remote", conn.RemoteAddr().String())
if !s.isAllowedSource(conn.RemoteAddr()) {
conn.Close()
return
}
var headerBuf bytes.Buffer
tee := io.TeeReader(conn, &headerBuf)
teeConn := &prefixConn{Reader: tee, Conn: conn}
header, err := readConnectionHeader(teeConn)
if err != nil {
connLog.Debugf("read connection header: %v", err)
conn.Close()
return
}
if !s.disableAuth {
if s.jwtConfig == nil {
rejectConnection(conn, "auth enabled but no identity provider configured")
connLog.Warn("auth rejected: no identity provider configured")
return
}
if _, err := s.authenticateJWT(header); err != nil {
rejectConnection(conn, fmt.Sprintf("auth: %v", err))
connLog.Warnf("auth rejected: %v", err)
return
}
}
// Replay buffered header bytes + remaining stream to the agent.
replayConn := &prefixConn{
Reader: io.MultiReader(&headerBuf, conn),
Conn: conn,
}
proxyToAgent(replayConn, agentPort, sm.AuthToken())
}
// prefixConn wraps a net.Conn, overriding Read to use a different reader.
type prefixConn struct {
io.Reader
net.Conn
}
func (p *prefixConn) Read(b []byte) (int, error) {
return p.Reader.Read(b)
}

View File

@@ -0,0 +1,15 @@
//go:build (linux && !android) || freebsd
package server
func (s *Server) platformInit() {}
// serviceAcceptLoop is not supported on Linux.
func (s *Server) serviceAcceptLoop() {
s.log.Warn("service mode not supported on Linux, falling back to direct mode")
s.acceptLoop()
}
func (s *Server) platformSessionManager() virtualSessionManager {
return newSessionManager(s.log)
}

View File

@@ -0,0 +1,443 @@
package server
import (
"bytes"
"crypto/rand"
"encoding/binary"
"fmt"
"image"
"io"
"net"
"sync"
"time"
log "github.com/sirupsen/logrus"
)
const (
readDeadline = 60 * time.Second
maxCutTextBytes = 1 << 20 // 1 MiB
)
const tileSize = 64 // pixels per tile for dirty-rect detection
type session struct {
conn net.Conn
capturer ScreenCapturer
injector InputInjector
serverW int
serverH int
password string
log *log.Entry
writeMu sync.Mutex
pf clientPixelFormat
useZlib bool
zlib *zlibState
prevFrame *image.RGBA
idleFrames int
}
func (s *session) addr() string { return s.conn.RemoteAddr().String() }
// serve runs the full RFB session lifecycle.
func (s *session) serve() {
defer s.conn.Close()
s.pf = defaultClientPixelFormat()
if err := s.handshake(); err != nil {
s.log.Warnf("handshake with %s: %v", s.addr(), err)
return
}
s.log.Infof("client connected: %s", s.addr())
done := make(chan struct{})
defer close(done)
go s.clipboardPoll(done)
if err := s.messageLoop(); err != nil && err != io.EOF {
s.log.Warnf("client %s disconnected: %v", s.addr(), err)
} else {
s.log.Infof("client disconnected: %s", s.addr())
}
}
// clipboardPoll periodically checks the server-side clipboard and sends
// changes to the VNC client. Only runs during active sessions.
func (s *session) clipboardPoll(done <-chan struct{}) {
ticker := time.NewTicker(2 * time.Second)
defer ticker.Stop()
var lastClip string
for {
select {
case <-done:
return
case <-ticker.C:
text := s.injector.GetClipboard()
if len(text) > maxCutTextBytes {
text = text[:maxCutTextBytes]
}
if text != "" && text != lastClip {
lastClip = text
if err := s.sendServerCutText(text); err != nil {
s.log.Debugf("send clipboard to client: %v", err)
return
}
}
}
}
}
func (s *session) handshake() error {
// Send protocol version.
if _, err := io.WriteString(s.conn, rfbProtocolVersion); err != nil {
return fmt.Errorf("send version: %w", err)
}
// Read client version.
var clientVer [12]byte
if _, err := io.ReadFull(s.conn, clientVer[:]); err != nil {
return fmt.Errorf("read client version: %w", err)
}
// Send supported security types.
if err := s.sendSecurityTypes(); err != nil {
return err
}
// Read chosen security type.
var secType [1]byte
if _, err := io.ReadFull(s.conn, secType[:]); err != nil {
return fmt.Errorf("read security type: %w", err)
}
if err := s.handleSecurity(secType[0]); err != nil {
return err
}
// Read ClientInit.
var clientInit [1]byte
if _, err := io.ReadFull(s.conn, clientInit[:]); err != nil {
return fmt.Errorf("read ClientInit: %w", err)
}
return s.sendServerInit()
}
func (s *session) sendSecurityTypes() error {
if s.password == "" {
_, err := s.conn.Write([]byte{1, secNone})
return err
}
_, err := s.conn.Write([]byte{1, secVNCAuth})
return err
}
func (s *session) handleSecurity(secType byte) error {
switch secType {
case secVNCAuth:
return s.doVNCAuth()
case secNone:
return binary.Write(s.conn, binary.BigEndian, uint32(0))
default:
return fmt.Errorf("unsupported security type: %d", secType)
}
}
func (s *session) doVNCAuth() error {
challenge := make([]byte, 16)
if _, err := rand.Read(challenge); err != nil {
return fmt.Errorf("generate challenge: %w", err)
}
if _, err := s.conn.Write(challenge); err != nil {
return fmt.Errorf("send challenge: %w", err)
}
response := make([]byte, 16)
if _, err := io.ReadFull(s.conn, response); err != nil {
return fmt.Errorf("read auth response: %w", err)
}
var result uint32
if s.password != "" {
expected := vncAuthEncrypt(challenge, s.password)
if !bytes.Equal(expected, response) {
result = 1
}
}
if err := binary.Write(s.conn, binary.BigEndian, result); err != nil {
return fmt.Errorf("send auth result: %w", err)
}
if result != 0 {
msg := "authentication failed"
_ = binary.Write(s.conn, binary.BigEndian, uint32(len(msg)))
_, _ = s.conn.Write([]byte(msg))
return fmt.Errorf("authentication failed from %s", s.addr())
}
return nil
}
func (s *session) sendServerInit() error {
name := []byte("NetBird VNC")
buf := make([]byte, 0, 4+16+4+len(name))
// Framebuffer width and height.
buf = append(buf, byte(s.serverW>>8), byte(s.serverW))
buf = append(buf, byte(s.serverH>>8), byte(s.serverH))
// Server pixel format.
buf = append(buf, serverPixelFormat[:]...)
// Desktop name.
buf = append(buf,
byte(len(name)>>24), byte(len(name)>>16),
byte(len(name)>>8), byte(len(name)),
)
buf = append(buf, name...)
_, err := s.conn.Write(buf)
return err
}
func (s *session) messageLoop() error {
for {
var msgType [1]byte
if err := s.conn.SetDeadline(time.Now().Add(readDeadline)); err != nil {
return fmt.Errorf("set deadline: %w", err)
}
if _, err := io.ReadFull(s.conn, msgType[:]); err != nil {
return err
}
_ = s.conn.SetDeadline(time.Time{})
switch msgType[0] {
case clientSetPixelFormat:
if err := s.handleSetPixelFormat(); err != nil {
return err
}
case clientSetEncodings:
if err := s.handleSetEncodings(); err != nil {
return err
}
case clientFramebufferUpdateRequest:
if err := s.handleFBUpdateRequest(); err != nil {
return err
}
case clientKeyEvent:
if err := s.handleKeyEvent(); err != nil {
return err
}
case clientPointerEvent:
if err := s.handlePointerEvent(); err != nil {
return err
}
case clientCutText:
if err := s.handleCutText(); err != nil {
return err
}
default:
return fmt.Errorf("unknown client message type: %d", msgType[0])
}
}
}
func (s *session) handleSetPixelFormat() error {
var buf [19]byte // 3 padding + 16 pixel format
if _, err := io.ReadFull(s.conn, buf[:]); err != nil {
return fmt.Errorf("read SetPixelFormat: %w", err)
}
s.pf = parsePixelFormat(buf[3:19])
return nil
}
func (s *session) handleSetEncodings() error {
var header [3]byte // 1 padding + 2 number-of-encodings
if _, err := io.ReadFull(s.conn, header[:]); err != nil {
return fmt.Errorf("read SetEncodings header: %w", err)
}
numEnc := binary.BigEndian.Uint16(header[1:3])
buf := make([]byte, int(numEnc)*4)
if _, err := io.ReadFull(s.conn, buf); err != nil {
return err
}
// Check if client supports zlib encoding.
for i := range int(numEnc) {
enc := int32(binary.BigEndian.Uint32(buf[i*4 : i*4+4]))
if enc == encZlib {
s.useZlib = true
if s.zlib == nil {
s.zlib = newZlibState()
}
s.log.Debugf("client supports zlib encoding")
break
}
}
return nil
}
func (s *session) handleFBUpdateRequest() error {
var req [9]byte
if _, err := io.ReadFull(s.conn, req[:]); err != nil {
return fmt.Errorf("read FBUpdateRequest: %w", err)
}
incremental := req[0]
img, err := s.capturer.Capture()
if err != nil {
return fmt.Errorf("capture screen: %w", err)
}
if incremental == 1 && s.prevFrame != nil {
rects := diffRects(s.prevFrame, img, s.serverW, s.serverH, tileSize)
if len(rects) == 0 {
// Nothing changed. Back off briefly before responding to reduce
// CPU usage when the screen is static. The client re-requests
// immediately after receiving our empty response, so without
// this delay we'd spin at ~1000fps checking for changes.
s.idleFrames++
delay := min(s.idleFrames*5, 100) // 5ms → 100ms adaptive backoff
time.Sleep(time.Duration(delay) * time.Millisecond)
s.savePrevFrame(img)
return s.sendEmptyUpdate()
}
s.idleFrames = 0
s.savePrevFrame(img)
return s.sendDirtyRects(img, rects)
}
// Full update.
s.idleFrames = 0
s.savePrevFrame(img)
return s.sendFullUpdate(img)
}
// savePrevFrame copies img's pixel data into prevFrame. This is necessary
// because some capturers (DXGI) reuse the same image buffer across calls,
// so a simple pointer assignment would make prevFrame alias the live buffer
// and diffRects would always see zero changes.
func (s *session) savePrevFrame(img *image.RGBA) {
if s.prevFrame == nil || s.prevFrame.Rect != img.Rect {
s.prevFrame = image.NewRGBA(img.Rect)
}
copy(s.prevFrame.Pix, img.Pix)
}
// sendEmptyUpdate sends a FramebufferUpdate with zero rectangles.
func (s *session) sendEmptyUpdate() error {
var buf [4]byte
buf[0] = serverFramebufferUpdate
s.writeMu.Lock()
_, err := s.conn.Write(buf[:])
s.writeMu.Unlock()
return err
}
func (s *session) sendFullUpdate(img *image.RGBA) error {
w, h := s.serverW, s.serverH
var buf []byte
if s.useZlib && s.zlib != nil {
buf = encodeZlibRect(img, s.pf, 0, 0, w, h, s.zlib.w, s.zlib.buf)
} else {
buf = encodeRawRect(img, s.pf, 0, 0, w, h)
}
s.writeMu.Lock()
_, err := s.conn.Write(buf)
s.writeMu.Unlock()
return err
}
func (s *session) sendDirtyRects(img *image.RGBA, rects [][4]int) error {
// Build a multi-rectangle FramebufferUpdate.
// Header: type(1) + padding(1) + numRects(2)
header := make([]byte, 4)
header[0] = serverFramebufferUpdate
binary.BigEndian.PutUint16(header[2:4], uint16(len(rects)))
s.writeMu.Lock()
defer s.writeMu.Unlock()
if _, err := s.conn.Write(header); err != nil {
return err
}
for _, r := range rects {
x, y, w, h := r[0], r[1], r[2], r[3]
var rectBuf []byte
if s.useZlib && s.zlib != nil {
rectBuf = encodeZlibRect(img, s.pf, x, y, w, h, s.zlib.w, s.zlib.buf)
// encodeZlibRect includes its own FBUpdate header for 1 rect.
// For multi-rect, we need just the rect data without the FBUpdate header.
// Skip the 4-byte FBUpdate header since we already sent ours.
rectBuf = rectBuf[4:]
} else {
rectBuf = encodeRawRect(img, s.pf, x, y, w, h)
rectBuf = rectBuf[4:] // skip FBUpdate header
}
if _, err := s.conn.Write(rectBuf); err != nil {
return err
}
}
return nil
}
func (s *session) handleKeyEvent() error {
var data [7]byte
if _, err := io.ReadFull(s.conn, data[:]); err != nil {
return fmt.Errorf("read KeyEvent: %w", err)
}
down := data[0] == 1
keysym := binary.BigEndian.Uint32(data[3:7])
s.injector.InjectKey(keysym, down)
return nil
}
func (s *session) handlePointerEvent() error {
var data [5]byte
if _, err := io.ReadFull(s.conn, data[:]); err != nil {
return fmt.Errorf("read PointerEvent: %w", err)
}
buttonMask := data[0]
x := int(binary.BigEndian.Uint16(data[1:3]))
y := int(binary.BigEndian.Uint16(data[3:5]))
s.injector.InjectPointer(buttonMask, x, y, s.serverW, s.serverH)
return nil
}
func (s *session) handleCutText() error {
var header [7]byte // 3 padding + 4 length
if _, err := io.ReadFull(s.conn, header[:]); err != nil {
return fmt.Errorf("read CutText header: %w", err)
}
length := binary.BigEndian.Uint32(header[3:7])
if length > maxCutTextBytes {
return fmt.Errorf("cut text too large: %d bytes", length)
}
buf := make([]byte, length)
if _, err := io.ReadFull(s.conn, buf); err != nil {
return fmt.Errorf("read CutText payload: %w", err)
}
s.injector.SetClipboard(string(buf))
return nil
}
// sendServerCutText sends clipboard text from the server to the client.
func (s *session) sendServerCutText(text string) error {
data := []byte(text)
buf := make([]byte, 8+len(data))
buf[0] = serverCutText
// buf[1:4] = padding (zero)
binary.BigEndian.PutUint32(buf[4:8], uint32(len(data)))
copy(buf[8:], data)
s.writeMu.Lock()
_, err := s.conn.Write(buf)
s.writeMu.Unlock()
return err
}

View File

@@ -0,0 +1,79 @@
//go:build !windows
package server
import (
"fmt"
"os"
"strings"
"syscall"
log "github.com/sirupsen/logrus"
)
// ShutdownState tracks VNC virtual session processes for crash recovery.
// Persisted by the state manager; on restart, residual processes are killed.
type ShutdownState struct {
// Processes maps a description to its PID (e.g., "xvfb:50" -> 1234).
Processes map[string]int `json:"processes,omitempty"`
}
// Name returns the state name for the state manager.
func (s *ShutdownState) Name() string {
return "vnc_sessions_state"
}
// Cleanup kills any residual VNC session processes left from a crash.
func (s *ShutdownState) Cleanup() error {
if len(s.Processes) == 0 {
return nil
}
for desc, pid := range s.Processes {
if pid <= 0 {
continue
}
if !isOurProcess(pid, desc) {
log.Debugf("cleanup:skipping PID %d (%s), not ours", pid, desc)
continue
}
log.Infof("cleanup:killing residual process %d (%s)", pid, desc)
// Kill the process group (negative PID) to get children too.
if err := syscall.Kill(-pid, syscall.SIGTERM); err != nil {
// Try individual process if group kill fails.
syscall.Kill(pid, syscall.SIGKILL)
}
}
s.Processes = nil
return nil
}
// isOurProcess verifies the PID still belongs to a VNC-related process
// by checking /proc/<pid>/cmdline (Linux) or the process name.
func isOurProcess(pid int, desc string) bool {
// Check if the process exists at all.
if err := syscall.Kill(pid, 0); err != nil {
return false
}
// On Linux, verify via /proc cmdline.
cmdline, err := os.ReadFile(fmt.Sprintf("/proc/%d/cmdline", pid))
if err != nil {
// No /proc (FreeBSD): trust the PID if the process exists.
// PID reuse is unlikely in the short window between crash and restart.
return true
}
cmd := string(cmdline)
// Match against expected process types.
if strings.Contains(desc, "xvfb") || strings.Contains(desc, "xorg") {
return strings.Contains(cmd, "Xvfb") || strings.Contains(cmd, "Xorg")
}
if strings.Contains(desc, "desktop") {
return strings.Contains(cmd, "session") || strings.Contains(cmd, "plasma") ||
strings.Contains(cmd, "gnome") || strings.Contains(cmd, "xfce") ||
strings.Contains(cmd, "dbus-launch")
}
return false
}

View File

@@ -0,0 +1,37 @@
package server
import (
"fmt"
"image"
)
const maxCapturerRetries = 5
// StubCapturer is a placeholder for platforms without screen capture support.
type StubCapturer struct{}
// Width returns 0 on unsupported platforms.
func (c *StubCapturer) Width() int { return 0 }
// Height returns 0 on unsupported platforms.
func (c *StubCapturer) Height() int { return 0 }
// Capture returns an error on unsupported platforms.
func (c *StubCapturer) Capture() (*image.RGBA, error) {
return nil, fmt.Errorf("screen capture not supported on this platform")
}
// StubInputInjector is a placeholder for platforms without input injection support.
type StubInputInjector struct{}
// InjectKey is a no-op on unsupported platforms.
func (s *StubInputInjector) InjectKey(_ uint32, _ bool) {}
// InjectPointer is a no-op on unsupported platforms.
func (s *StubInputInjector) InjectPointer(_ uint8, _, _, _, _ int) {}
// SetClipboard is a no-op on unsupported platforms.
func (s *StubInputInjector) SetClipboard(_ string) {}
// GetClipboard returns empty on unsupported platforms.
func (s *StubInputInjector) GetClipboard() string { return "" }

View File

@@ -0,0 +1,19 @@
//go:build windows
package server
import "unsafe"
// swizzleBGRAtoRGBA swaps B and R channels in a BGRA pixel buffer in-place.
// Operates on uint32 words for throughput: one read-modify-write per pixel.
func swizzleBGRAtoRGBA(pix []byte) {
n := len(pix) / 4
pixels := unsafe.Slice((*uint32)(unsafe.Pointer(&pix[0])), n)
for i := range n {
p := pixels[i]
// p = 0xAABBGGRR (little-endian BGRA in memory: B,G,R,A bytes)
// We want 0xAABBGGRR -> 0xAARRGGBB (RGBA in memory: R,G,B,A bytes)
// Swap byte 0 (B) and byte 2 (R), keep byte 1 (G) and byte 3 (A).
pixels[i] = (p & 0xFF00FF00) | ((p & 0x00FF0000) >> 16) | ((p & 0x000000FF) << 16)
}
}

View File

@@ -0,0 +1,634 @@
//go:build (linux && !android) || freebsd
package server
import (
"fmt"
"os"
"os/exec"
"os/user"
"path/filepath"
"strconv"
"strings"
"sync"
"syscall"
"time"
log "github.com/sirupsen/logrus"
)
// VirtualSession manages a virtual X11 display (Xvfb) with a desktop session
// running as a target user. It implements ScreenCapturer and InputInjector by
// delegating to an X11Capturer/X11InputInjector pointed at the virtual display.
const sessionIdleTimeout = 5 * time.Minute
type VirtualSession struct {
mu sync.Mutex
display string
user *user.User
uid uint32
gid uint32
groups []uint32
xvfb *exec.Cmd
desktop *exec.Cmd
poller *X11Poller
injector *X11InputInjector
log *log.Entry
stopped bool
clients int
idleTimer *time.Timer
onIdle func() // called when idle timeout fires or Xvfb dies
}
// StartVirtualSession creates and starts a virtual X11 session for the given user.
// Requires root privileges to create sessions as other users.
func StartVirtualSession(username string, logger *log.Entry) (*VirtualSession, error) {
if os.Getuid() != 0 {
return nil, fmt.Errorf("virtual sessions require root privileges")
}
if _, err := exec.LookPath("Xvfb"); err != nil {
if _, err := exec.LookPath("Xorg"); err != nil {
return nil, fmt.Errorf("neither Xvfb nor Xorg found (install xvfb or xserver-xorg)")
}
if !hasDummyDriver() {
return nil, fmt.Errorf("Xvfb not found and Xorg dummy driver not installed (install xvfb or xf86-video-dummy)")
}
}
u, err := user.Lookup(username)
if err != nil {
return nil, fmt.Errorf("lookup user %s: %w", username, err)
}
uid, err := strconv.ParseUint(u.Uid, 10, 32)
if err != nil {
return nil, fmt.Errorf("parse uid: %w", err)
}
gid, err := strconv.ParseUint(u.Gid, 10, 32)
if err != nil {
return nil, fmt.Errorf("parse gid: %w", err)
}
groups, err := supplementaryGroups(u)
if err != nil {
logger.Debugf("supplementary groups for %s: %v", username, err)
}
vs := &VirtualSession{
user: u,
uid: uint32(uid),
gid: uint32(gid),
groups: groups,
log: logger.WithField("vnc_user", username),
}
if err := vs.start(); err != nil {
return nil, err
}
return vs, nil
}
func (vs *VirtualSession) start() error {
display, err := findFreeDisplay()
if err != nil {
return fmt.Errorf("find free display: %w", err)
}
vs.display = display
if err := vs.startXvfb(); err != nil {
return err
}
socketPath := fmt.Sprintf("/tmp/.X11-unix/X%s", vs.display[1:])
if err := waitForPath(socketPath, 5*time.Second); err != nil {
vs.stopXvfb()
return fmt.Errorf("wait for X11 socket %s: %w", socketPath, err)
}
// Grant the target user access to the display via xhost.
xhostCmd := exec.Command("xhost", "+SI:localuser:"+vs.user.Username)
xhostCmd.Env = []string{"DISPLAY=" + vs.display}
if out, err := xhostCmd.CombinedOutput(); err != nil {
vs.log.Debugf("xhost: %s (%v)", strings.TrimSpace(string(out)), err)
}
vs.poller = NewX11Poller(vs.display)
injector, err := NewX11InputInjector(vs.display)
if err != nil {
vs.stopXvfb()
return fmt.Errorf("create X11 injector for %s: %w", vs.display, err)
}
vs.injector = injector
if err := vs.startDesktop(); err != nil {
vs.injector.Close()
vs.stopXvfb()
return fmt.Errorf("start desktop: %w", err)
}
vs.log.Infof("virtual session started: display=%s user=%s", vs.display, vs.user.Username)
return nil
}
// ClientConnect increments the client count and cancels any idle timer.
func (vs *VirtualSession) ClientConnect() {
vs.mu.Lock()
defer vs.mu.Unlock()
vs.clients++
if vs.idleTimer != nil {
vs.idleTimer.Stop()
vs.idleTimer = nil
}
}
// ClientDisconnect decrements the client count. When the last client
// disconnects, starts an idle timer that destroys the session.
func (vs *VirtualSession) ClientDisconnect() {
vs.mu.Lock()
defer vs.mu.Unlock()
vs.clients--
if vs.clients <= 0 {
vs.clients = 0
vs.log.Infof("no VNC clients connected, session will be destroyed in %s", sessionIdleTimeout)
vs.idleTimer = time.AfterFunc(sessionIdleTimeout, vs.idleExpired)
}
}
// idleExpired is called by the idle timer. It stops the session and
// notifies the session manager via onIdle so it removes us from the map.
func (vs *VirtualSession) idleExpired() {
vs.log.Info("idle timeout reached, destroying virtual session")
vs.Stop()
// onIdle acquires sessionManager.mu; safe because Stop() has released vs.mu.
if vs.onIdle != nil {
vs.onIdle()
}
}
// isAlive returns true if the session is running and its X server socket exists.
func (vs *VirtualSession) isAlive() bool {
vs.mu.Lock()
stopped := vs.stopped
display := vs.display
vs.mu.Unlock()
if stopped {
return false
}
// Verify the X socket still exists on disk.
socketPath := fmt.Sprintf("/tmp/.X11-unix/X%s", display[1:])
if _, err := os.Stat(socketPath); err != nil {
return false
}
return true
}
// Capturer returns the screen capturer for this virtual session.
func (vs *VirtualSession) Capturer() ScreenCapturer {
return vs.poller
}
// Injector returns the input injector for this virtual session.
func (vs *VirtualSession) Injector() InputInjector {
return vs.injector
}
// Display returns the X11 display string (e.g., ":99").
func (vs *VirtualSession) Display() string {
return vs.display
}
// Stop terminates the virtual session, killing the desktop and Xvfb.
func (vs *VirtualSession) Stop() {
vs.mu.Lock()
defer vs.mu.Unlock()
if vs.stopped {
return
}
vs.stopped = true
if vs.injector != nil {
vs.injector.Close()
}
vs.stopDesktop()
vs.stopXvfb()
vs.log.Info("virtual session stopped")
}
func (vs *VirtualSession) startXvfb() error {
if _, err := exec.LookPath("Xvfb"); err == nil {
return vs.startXvfbDirect()
}
return vs.startXorgDummy()
}
func (vs *VirtualSession) startXvfbDirect() error {
vs.xvfb = exec.Command("Xvfb", vs.display,
"-screen", "0", "1280x800x24",
"-ac",
"-nolisten", "tcp",
)
vs.xvfb.SysProcAttr = &syscall.SysProcAttr{Setsid: true, Pdeathsig: syscall.SIGTERM}
if err := vs.xvfb.Start(); err != nil {
return fmt.Errorf("start Xvfb on %s: %w", vs.display, err)
}
vs.log.Infof("Xvfb started on %s (pid=%d)", vs.display, vs.xvfb.Process.Pid)
go vs.monitorXvfb()
return nil
}
// startXorgDummy starts Xorg with the dummy video driver as a fallback when
// Xvfb is not installed. Most systems with a desktop have Xorg available.
func (vs *VirtualSession) startXorgDummy() error {
confPath := fmt.Sprintf("/tmp/nbvnc-dummy-%s.conf", vs.display[1:])
conf := `Section "Device"
Identifier "dummy"
Driver "dummy"
VideoRam 256000
EndSection
Section "Screen"
Identifier "screen"
Device "dummy"
DefaultDepth 24
SubSection "Display"
Depth 24
Modes "1280x800"
EndSubSection
EndSection
`
if err := os.WriteFile(confPath, []byte(conf), 0644); err != nil {
return fmt.Errorf("write Xorg dummy config: %w", err)
}
vs.xvfb = exec.Command("Xorg", vs.display,
"-config", confPath,
"-noreset",
"-nolisten", "tcp",
"-ac",
)
vs.xvfb.SysProcAttr = &syscall.SysProcAttr{Setsid: true, Pdeathsig: syscall.SIGTERM}
if err := vs.xvfb.Start(); err != nil {
os.Remove(confPath)
return fmt.Errorf("start Xorg dummy on %s: %w", vs.display, err)
}
vs.log.Infof("Xorg (dummy driver) started on %s (pid=%d)", vs.display, vs.xvfb.Process.Pid)
go func() {
vs.monitorXvfb()
os.Remove(confPath)
}()
return nil
}
// monitorXvfb waits for the Xvfb/Xorg process to exit. If it exits
// unexpectedly (not via Stop), the session is marked as dead and the
// onIdle callback fires so the session manager removes it from the map.
// The next GetOrCreate call for this user will create a fresh session.
func (vs *VirtualSession) monitorXvfb() {
if err := vs.xvfb.Wait(); err != nil {
vs.log.Debugf("X server exited: %v", err)
}
vs.mu.Lock()
alreadyStopped := vs.stopped
if !alreadyStopped {
vs.log.Warn("X server exited unexpectedly, marking session as dead")
vs.stopped = true
if vs.idleTimer != nil {
vs.idleTimer.Stop()
vs.idleTimer = nil
}
if vs.injector != nil {
vs.injector.Close()
}
vs.stopDesktop()
}
onIdle := vs.onIdle
vs.mu.Unlock()
if !alreadyStopped && onIdle != nil {
onIdle()
}
}
func (vs *VirtualSession) stopXvfb() {
if vs.xvfb != nil && vs.xvfb.Process != nil {
syscall.Kill(-vs.xvfb.Process.Pid, syscall.SIGTERM)
time.Sleep(200 * time.Millisecond)
syscall.Kill(-vs.xvfb.Process.Pid, syscall.SIGKILL)
}
}
func (vs *VirtualSession) startDesktop() error {
session := detectDesktopSession()
// Wrap the desktop command with dbus-launch to provide a session bus.
// Without this, most desktop environments (XFCE, MATE, etc.) fail immediately.
var args []string
if _, err := exec.LookPath("dbus-launch"); err == nil {
args = append([]string{"dbus-launch", "--exit-with-session"}, session...)
} else {
args = session
}
vs.desktop = exec.Command(args[0], args[1:]...)
vs.desktop.Dir = vs.user.HomeDir
vs.desktop.Env = vs.buildUserEnv()
vs.desktop.SysProcAttr = &syscall.SysProcAttr{
Credential: &syscall.Credential{
Uid: vs.uid,
Gid: vs.gid,
Groups: vs.groups,
},
Setsid: true,
Pdeathsig: syscall.SIGTERM,
}
if err := vs.desktop.Start(); err != nil {
return fmt.Errorf("start desktop session (%v): %w", args, err)
}
vs.log.Infof("desktop session started: %v (pid=%d)", args, vs.desktop.Process.Pid)
go func() {
if err := vs.desktop.Wait(); err != nil {
vs.log.Debugf("desktop session exited: %v", err)
}
}()
return nil
}
func (vs *VirtualSession) stopDesktop() {
if vs.desktop != nil && vs.desktop.Process != nil {
syscall.Kill(-vs.desktop.Process.Pid, syscall.SIGTERM)
time.Sleep(200 * time.Millisecond)
syscall.Kill(-vs.desktop.Process.Pid, syscall.SIGKILL)
}
}
func (vs *VirtualSession) buildUserEnv() []string {
return []string{
"DISPLAY=" + vs.display,
"HOME=" + vs.user.HomeDir,
"USER=" + vs.user.Username,
"LOGNAME=" + vs.user.Username,
"SHELL=" + getUserShell(vs.user.Uid),
"PATH=/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin",
"XDG_RUNTIME_DIR=/run/user/" + vs.user.Uid,
"DBUS_SESSION_BUS_ADDRESS=unix:path=/run/user/" + vs.user.Uid + "/bus",
}
}
// detectDesktopSession discovers available desktop sessions from the standard
// /usr/share/xsessions/*.desktop files (FreeDesktop standard, used by all
// display managers). Falls back to a hardcoded list if no .desktop files found.
func detectDesktopSession() []string {
// Scan xsessions directories (Linux: /usr/share, FreeBSD: /usr/local/share).
for _, dir := range []string{"/usr/share/xsessions", "/usr/local/share/xsessions"} {
if cmd := findXSession(dir); cmd != nil {
return cmd
}
}
// Fallback: try common session commands directly.
fallbacks := [][]string{
{"startplasma-x11"},
{"gnome-session"},
{"xfce4-session"},
{"mate-session"},
{"cinnamon-session"},
{"openbox-session"},
{"xterm"},
}
for _, s := range fallbacks {
if _, err := exec.LookPath(s[0]); err == nil {
return s
}
}
return []string{"xterm"}
}
// sessionPriority defines preference order for desktop environments.
// Lower number = higher priority. Unknown sessions get 100.
var sessionPriority = map[string]int{
"plasma": 1, // KDE
"gnome": 2,
"xfce": 3,
"mate": 4,
"cinnamon": 5,
"lxqt": 6,
"lxde": 7,
"budgie": 8,
"openbox": 20,
"fluxbox": 21,
"i3": 22,
"xinit": 50, // generic user session
"lightdm": 50,
"default": 50,
}
func findXSession(dir string) []string {
entries, err := os.ReadDir(dir)
if err != nil {
return nil
}
type candidate struct {
cmd string
priority int
}
var candidates []candidate
for _, e := range entries {
if !strings.HasSuffix(e.Name(), ".desktop") {
continue
}
data, err := os.ReadFile(filepath.Join(dir, e.Name()))
if err != nil {
continue
}
execCmd := ""
for _, line := range strings.Split(string(data), "\n") {
if strings.HasPrefix(line, "Exec=") {
execCmd = strings.TrimSpace(strings.TrimPrefix(line, "Exec="))
break
}
}
if execCmd == "" || execCmd == "default" {
continue
}
// Determine priority from the filename or exec command.
pri := 100
lower := strings.ToLower(e.Name() + " " + execCmd)
for keyword, p := range sessionPriority {
if strings.Contains(lower, keyword) && p < pri {
pri = p
}
}
candidates = append(candidates, candidate{cmd: execCmd, priority: pri})
}
if len(candidates) == 0 {
return nil
}
// Pick the highest priority (lowest number).
best := candidates[0]
for _, c := range candidates[1:] {
if c.priority < best.priority {
best = c
}
}
// Verify the binary exists.
parts := strings.Fields(best.cmd)
if _, err := exec.LookPath(parts[0]); err != nil {
return nil
}
return parts
}
// findFreeDisplay scans for an unused X11 display number.
func findFreeDisplay() (string, error) {
for n := 50; n < 200; n++ {
lockFile := fmt.Sprintf("/tmp/.X%d-lock", n)
socketFile := fmt.Sprintf("/tmp/.X11-unix/X%d", n)
if _, err := os.Stat(lockFile); err == nil {
continue
}
if _, err := os.Stat(socketFile); err == nil {
continue
}
return fmt.Sprintf(":%d", n), nil
}
return "", fmt.Errorf("no free X11 display found (checked :50-:199)")
}
// waitForPath polls until a filesystem path exists or the timeout expires.
func waitForPath(path string, timeout time.Duration) error {
deadline := time.Now().Add(timeout)
for time.Now().Before(deadline) {
if _, err := os.Stat(path); err == nil {
return nil
}
time.Sleep(50 * time.Millisecond)
}
return fmt.Errorf("timeout waiting for %s", path)
}
// getUserShell returns the login shell for the given UID.
func getUserShell(uid string) string {
data, err := os.ReadFile("/etc/passwd")
if err != nil {
return "/bin/sh"
}
for _, line := range strings.Split(string(data), "\n") {
fields := strings.Split(line, ":")
if len(fields) >= 7 && fields[2] == uid {
return fields[6]
}
}
return "/bin/sh"
}
// supplementaryGroups returns the supplementary group IDs for a user.
func supplementaryGroups(u *user.User) ([]uint32, error) {
gids, err := u.GroupIds()
if err != nil {
return nil, err
}
var groups []uint32
for _, g := range gids {
id, err := strconv.ParseUint(g, 10, 32)
if err != nil {
continue
}
groups = append(groups, uint32(id))
}
return groups, nil
}
// sessionManager tracks active virtual sessions by username.
type sessionManager struct {
mu sync.Mutex
sessions map[string]*VirtualSession
log *log.Entry
}
func newSessionManager(logger *log.Entry) *sessionManager {
return &sessionManager{
sessions: make(map[string]*VirtualSession),
log: logger,
}
}
// GetOrCreate returns an existing virtual session or creates a new one.
// If a previous session for this user is stopped or its X server died, it is replaced.
func (sm *sessionManager) GetOrCreate(username string) (vncSession, error) {
sm.mu.Lock()
defer sm.mu.Unlock()
if vs, ok := sm.sessions[username]; ok {
if vs.isAlive() {
return vs, nil
}
sm.log.Infof("replacing dead virtual session for %s", username)
vs.Stop()
delete(sm.sessions, username)
}
vs, err := StartVirtualSession(username, sm.log)
if err != nil {
return nil, err
}
vs.onIdle = func() {
sm.mu.Lock()
defer sm.mu.Unlock()
if cur, ok := sm.sessions[username]; ok && cur == vs {
delete(sm.sessions, username)
sm.log.Infof("removed idle virtual session for %s", username)
}
}
sm.sessions[username] = vs
return vs, nil
}
// hasDummyDriver checks common paths for the Xorg dummy video driver.
func hasDummyDriver() bool {
paths := []string{
"/usr/lib/xorg/modules/drivers/dummy_drv.so", // Debian/Ubuntu
"/usr/lib64/xorg/modules/drivers/dummy_drv.so", // RHEL/Fedora
"/usr/local/lib/xorg/modules/drivers/dummy_drv.so", // FreeBSD
"/usr/lib/x86_64-linux-gnu/xorg/modules/drivers/dummy_drv.so", // Debian multiarch
}
for _, p := range paths {
if _, err := os.Stat(p); err == nil {
return true
}
}
return false
}
// StopAll terminates all active virtual sessions.
func (sm *sessionManager) StopAll() {
sm.mu.Lock()
defer sm.mu.Unlock()
for username, vs := range sm.sessions {
vs.Stop()
delete(sm.sessions, username)
sm.log.Infof("stopped virtual session for %s", username)
}
}

View File

@@ -0,0 +1,174 @@
<!DOCTYPE html>
<html>
<head>
<title>VNC Test</title>
<style>
body { margin: 0; background: #111; color: #eee; font-family: monospace; font-size: 13px; }
#toolbar { position: fixed; top: 0; left: 0; right: 0; z-index: 10; background: #222; padding: 4px 8px; display: flex; gap: 8px; align-items: center; }
#toolbar button { padding: 4px 12px; cursor: pointer; background: #444; color: #eee; border: 1px solid #666; border-radius: 3px; }
#toolbar button:hover { background: #555; }
#toolbar #status { flex: 1; }
#vnc-container { width: 100vw; height: calc(100vh - 28px); margin-top: 28px; }
#log { position: fixed; bottom: 0; left: 0; right: 0; max-height: 150px; overflow-y: auto; background: rgba(0,0,0,0.85); padding: 4px 8px; font-size: 11px; z-index: 10; display: none; }
#log.visible { display: block; }
</style>
</head>
<body>
<div id="toolbar">
<span id="status">Loading WASM...</span>
<button onclick="sendCAD()">Ctrl+Alt+Del</button>
<button onclick="document.getElementById('log').classList.toggle('visible')">Log</button>
</div>
<div id="vnc-container"></div>
<div id="log"></div>
<script>
const params = new URLSearchParams(location.search);
const HOST = params.get('host') || '';
const PORT = params.get('port') || '5900';
const MODE = params.get('mode') || 'attach'; // 'attach' or 'session'
const USER = params.get('user') || '';
const SETUP_KEY = params.get('setup_key') || '64BB8FF4-5A96-488F-B0AE-316555E916B0';
const MGMT_URL = params.get('mgmt') || 'http://192.168.100.1:8080';
const statusEl = document.getElementById('status');
const logEl = document.getElementById('log');
function addLog(msg) {
const line = document.createElement('div');
line.textContent = `[${new Date().toISOString().slice(11,23)}] ${msg}`;
logEl.appendChild(line);
logEl.scrollTop = logEl.scrollHeight;
console.log('[vnc-test]', msg);
}
function setStatus(s) { statusEl.textContent = s; addLog(s); }
let rfbInstance = null;
window.sendCAD = () => { if (rfbInstance) { rfbInstance.sendCtrlAltDel(); addLog('Sent Ctrl+Alt+Del'); } };
// VNC WebSocket proxy (bridges noVNC WebSocket API to Go WASM tunnel)
class VNCProxyWS extends EventTarget {
constructor(url) {
super();
this.url = url;
this.readyState = 0;
this.protocol = '';
this.extensions = '';
this.bufferedAmount = 0;
this.binaryType = 'arraybuffer';
this.onopen = null; this.onclose = null; this.onerror = null; this.onmessage = null;
const match = url.match(/vnc\.proxy\.local\/(.+)/);
this._proxyID = match ? match[1] : 'default';
setTimeout(() => this._connect(), 0);
}
get CONNECTING() { return 0; } get OPEN() { return 1; } get CLOSING() { return 2; } get CLOSED() { return 3; }
_connect() {
try {
const handler = window[`handleVNCWebSocket_${this._proxyID}`];
if (!handler) throw new Error(`No VNC handler for ${this._proxyID}`);
handler(this);
this.readyState = 1;
const ev = new Event('open');
if (this.onopen) this.onopen(ev);
this.dispatchEvent(ev);
} catch (err) {
addLog(`WS proxy error: ${err.message}`);
this.readyState = 3;
}
}
receiveFromGo(data) {
const ev = new MessageEvent('message', { data });
if (this.onmessage) this.onmessage(ev);
this.dispatchEvent(ev);
}
send(data) {
if (this.readyState !== 1) return;
let u8;
if (data instanceof ArrayBuffer) u8 = new Uint8Array(data);
else if (data instanceof Uint8Array) u8 = data;
else if (typeof data === 'string') u8 = new TextEncoder().encode(data);
else if (data.buffer) u8 = new Uint8Array(data.buffer, data.byteOffset, data.byteLength);
else return;
if (this.onGoMessage) this.onGoMessage(u8);
}
close(code, reason) {
if (this.readyState >= 2) return;
this.readyState = 2;
if (this.onGoClose) this.onGoClose();
setTimeout(() => {
this.readyState = 3;
const ev = new CloseEvent('close', { code: code||1000, reason: reason||'', wasClean: true });
if (this.onclose) this.onclose(ev);
this.dispatchEvent(ev);
}, 0);
}
}
async function main() {
if (!HOST) { setStatus('Usage: ?host=<peer_ip>&setup_key=<key>[&mode=session&user=alice]'); return; }
// Install WS proxy before anything creates WebSockets
const OrigWS = window.WebSocket;
window.WebSocket = new Proxy(OrigWS, {
construct(target, args) {
if (args[0] && args[0].includes('vnc.proxy.local')) return new VNCProxyWS(args[0]);
return new target(args[0], args[1]);
}
});
// Load WASM
setStatus('Loading WASM runtime...');
await new Promise((resolve, reject) => {
const s = document.createElement('script');
s.src = '/wasm_exec.js'; s.onload = resolve; s.onerror = reject;
document.head.appendChild(s);
});
setStatus('Loading NetBird WASM...');
const go = new Go();
const wasm = await WebAssembly.instantiateStreaming(fetch('/netbird.wasm'), go.importObject);
go.run(wasm.instance);
const t0 = Date.now();
while (!window.NetBirdClient && Date.now() - t0 < 10000) await new Promise(r => setTimeout(r, 100));
if (!window.NetBirdClient) { setStatus('WASM init timeout'); return; }
addLog('WASM ready');
// Connect NetBird with setup key
setStatus('Connecting NetBird...');
let client;
try {
client = await window.NetBirdClient({
setupKey: SETUP_KEY,
managementURL: MGMT_URL,
logLevel: 'debug',
});
addLog('Client created, starting...');
await client.start();
addLog('NetBird connected');
} catch (err) {
setStatus('NetBird error: ' + (err && err.message ? err.message : String(err)));
return;
}
// Create VNC proxy
setStatus(`Creating VNC proxy (mode=${MODE}${USER ? ', user=' + USER : ''})...`);
const proxyURL = await client.createVNCProxy(HOST, PORT, MODE, USER);
addLog(`Proxy: ${proxyURL}`);
// Connect noVNC
setStatus('Connecting VNC...');
const { default: RFB } = await import('/novnc-pkg/core/rfb.js');
const container = document.getElementById('vnc-container');
rfbInstance = new RFB(container, proxyURL, { wsProtocols: [] });
rfbInstance.scaleViewport = true;
rfbInstance.resizeSession = false;
rfbInstance.showDotCursor = true;
rfbInstance.addEventListener('connect', () => setStatus(`Connected: ${HOST}`));
rfbInstance.addEventListener('disconnect', e => setStatus(`Disconnected${e.detail?.clean ? '' : ' (unexpected)'}`));
rfbInstance.addEventListener('credentialsrequired', () => rfbInstance.sendCredentials({ password: '' }));
window.rfb = rfbInstance;
}
main().catch(err => { setStatus(`Error: ${err.message}`); console.error(err); });
</script>
</body>
</html>

View File

@@ -0,0 +1,44 @@
//go:build ignore
// Simple file server for the VNC test page.
// Usage: go run serve.go
// Then open: http://localhost:9090?host=100.0.23.250
package main
import (
"fmt"
"net/http"
"os"
"path/filepath"
)
func main() {
// Serve from the dashboard's public dir (has wasm, noVNC, etc.)
dashboardPublic := os.Getenv("DASHBOARD_PUBLIC")
if dashboardPublic == "" {
home, _ := os.UserHomeDir()
dashboardPublic = filepath.Join(home, "dev", "dashboard", "public")
}
// Serve test page index.html from this directory
testDir, _ := os.Getwd()
mux := http.NewServeMux()
// Test page itself
mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/" || r.URL.Path == "/index.html" {
http.ServeFile(w, r, filepath.Join(testDir, "index.html"))
return
}
// Everything else from dashboard public (wasm, noVNC, etc.)
http.FileServer(http.Dir(dashboardPublic)).ServeHTTP(w, r)
})
addr := ":9090"
fmt.Printf("VNC test page: http://localhost%s?host=<peer_ip>\n", addr)
fmt.Printf("Serving assets from: %s\n", dashboardPublic)
if err := http.ListenAndServe(addr, mux); err != nil {
fmt.Fprintf(os.Stderr, "listen: %v\n", err)
os.Exit(1)
}
}

View File

@@ -15,8 +15,8 @@ import (
sshdetection "github.com/netbirdio/netbird/client/ssh/detection"
nbstatus "github.com/netbirdio/netbird/client/status"
"github.com/netbirdio/netbird/client/wasm/internal/http"
"github.com/netbirdio/netbird/client/wasm/internal/rdp"
"github.com/netbirdio/netbird/client/wasm/internal/ssh"
"github.com/netbirdio/netbird/client/wasm/internal/vnc"
"github.com/netbirdio/netbird/util"
)
@@ -317,8 +317,13 @@ func createProxyRequestMethod(client *netbird.Client) js.Func {
})
}
// createRDPProxyMethod creates the RDP proxy method
func createRDPProxyMethod(client *netbird.Client) js.Func {
// createVNCProxyMethod creates the VNC proxy method for raw TCP-over-WebSocket bridging.
// JS signature: createVNCProxy(hostname, port, mode?, username?, jwt?, sessionID?)
// mode: "attach" (default) or "session"
// username: required when mode is "session"
// jwt: authentication token (from OIDC session)
// sessionID: Windows session ID (0 = console/auto)
func createVNCProxyMethod(client *netbird.Client) js.Func {
return js.FuncOf(func(_ js.Value, args []js.Value) any {
if len(args) < 2 {
return js.ValueOf("error: hostname and port required")
@@ -335,8 +340,25 @@ func createRDPProxyMethod(client *netbird.Client) js.Func {
})
}
proxy := rdp.NewRDCleanPathProxy(client)
return proxy.CreateProxy(args[0].String(), args[1].String())
mode := "attach"
username := ""
jwtToken := ""
var sessionID uint32
if len(args) > 2 && args[2].Type() == js.TypeString {
mode = args[2].String()
}
if len(args) > 3 && args[3].Type() == js.TypeString {
username = args[3].String()
}
if len(args) > 4 && args[4].Type() == js.TypeString {
jwtToken = args[4].String()
}
if len(args) > 5 && args[5].Type() == js.TypeNumber {
sessionID = uint32(args[5].Int())
}
proxy := vnc.NewVNCProxy(client)
return proxy.CreateProxy(args[0].String(), args[1].String(), mode, username, jwtToken, sessionID)
})
}
@@ -515,7 +537,7 @@ func createClientObject(client *netbird.Client) js.Value {
obj["detectSSHServerType"] = createDetectSSHServerMethod(client)
obj["createSSHConnection"] = createSSHMethod(client)
obj["proxyRequest"] = createProxyRequestMethod(client)
obj["createRDPProxy"] = createRDPProxyMethod(client)
obj["createVNCProxy"] = createVNCProxyMethod(client)
obj["status"] = createStatusMethod(client)
obj["statusSummary"] = createStatusSummaryMethod(client)
obj["statusDetail"] = createStatusDetailMethod(client)

View File

@@ -1,107 +0,0 @@
//go:build js
package rdp
import (
"crypto/tls"
"crypto/x509"
"fmt"
"syscall/js"
"time"
log "github.com/sirupsen/logrus"
)
const (
certValidationTimeout = 60 * time.Second
)
func (p *RDCleanPathProxy) validateCertificateWithJS(conn *proxyConnection, certChain [][]byte) (bool, error) {
if !conn.wsHandlers.Get("onCertificateRequest").Truthy() {
return false, fmt.Errorf("certificate validation handler not configured")
}
certInfo := js.Global().Get("Object").New()
certInfo.Set("ServerAddr", conn.destination)
certArray := js.Global().Get("Array").New()
for i, certBytes := range certChain {
uint8Array := js.Global().Get("Uint8Array").New(len(certBytes))
js.CopyBytesToJS(uint8Array, certBytes)
certArray.SetIndex(i, uint8Array)
}
certInfo.Set("ServerCertChain", certArray)
if len(certChain) > 0 {
cert, err := x509.ParseCertificate(certChain[0])
if err == nil {
info := js.Global().Get("Object").New()
info.Set("subject", cert.Subject.String())
info.Set("issuer", cert.Issuer.String())
info.Set("validFrom", cert.NotBefore.Format(time.RFC3339))
info.Set("validTo", cert.NotAfter.Format(time.RFC3339))
info.Set("serialNumber", cert.SerialNumber.String())
certInfo.Set("CertificateInfo", info)
}
}
promise := conn.wsHandlers.Call("onCertificateRequest", certInfo)
resultChan := make(chan bool)
errorChan := make(chan error)
promise.Call("then", js.FuncOf(func(this js.Value, args []js.Value) interface{} {
result := args[0].Bool()
resultChan <- result
return nil
})).Call("catch", js.FuncOf(func(this js.Value, args []js.Value) interface{} {
errorChan <- fmt.Errorf("certificate validation failed")
return nil
}))
select {
case result := <-resultChan:
if result {
log.Info("Certificate accepted by user")
} else {
log.Info("Certificate rejected by user")
}
return result, nil
case err := <-errorChan:
return false, err
case <-time.After(certValidationTimeout):
return false, fmt.Errorf("certificate validation timeout")
}
}
func (p *RDCleanPathProxy) getTLSConfigWithValidation(conn *proxyConnection, requiresCredSSP bool) *tls.Config {
config := &tls.Config{
InsecureSkipVerify: true, // We'll validate manually after handshake
VerifyConnection: func(cs tls.ConnectionState) error {
var certChain [][]byte
for _, cert := range cs.PeerCertificates {
certChain = append(certChain, cert.Raw)
}
accepted, err := p.validateCertificateWithJS(conn, certChain)
if err != nil {
return err
}
if !accepted {
return fmt.Errorf("certificate rejected by user")
}
return nil
},
}
// CredSSP (NLA) requires TLS 1.2 - it's incompatible with TLS 1.3
if requiresCredSSP {
config.MinVersion = tls.VersionTLS12
config.MaxVersion = tls.VersionTLS12
} else {
config.MinVersion = tls.VersionTLS12
config.MaxVersion = tls.VersionTLS13
}
return config
}

View File

@@ -1,344 +0,0 @@
//go:build js
package rdp
import (
"context"
"crypto/tls"
"encoding/asn1"
"errors"
"fmt"
"io"
"net"
"sync"
"syscall/js"
"time"
log "github.com/sirupsen/logrus"
)
const (
RDCleanPathVersion = 3390
RDCleanPathProxyHost = "rdcleanpath.proxy.local"
RDCleanPathProxyScheme = "ws"
rdpDialTimeout = 15 * time.Second
GeneralErrorCode = 1
WSAETimedOut = 10060
WSAEConnRefused = 10061
WSAEConnAborted = 10053
WSAEConnReset = 10054
WSAEGenericError = 10050
)
type RDCleanPathPDU struct {
Version int64 `asn1:"tag:0,explicit"`
Error RDCleanPathErr `asn1:"tag:1,explicit,optional"`
Destination string `asn1:"utf8,tag:2,explicit,optional"`
ProxyAuth string `asn1:"utf8,tag:3,explicit,optional"`
ServerAuth string `asn1:"utf8,tag:4,explicit,optional"`
PreconnectionBlob string `asn1:"utf8,tag:5,explicit,optional"`
X224ConnectionPDU []byte `asn1:"tag:6,explicit,optional"`
ServerCertChain [][]byte `asn1:"tag:7,explicit,optional"`
ServerAddr string `asn1:"utf8,tag:9,explicit,optional"`
}
type RDCleanPathErr struct {
ErrorCode int16 `asn1:"tag:0,explicit"`
HTTPStatusCode int16 `asn1:"tag:1,explicit,optional"`
WSALastError int16 `asn1:"tag:2,explicit,optional"`
TLSAlertCode int8 `asn1:"tag:3,explicit,optional"`
}
type RDCleanPathProxy struct {
nbClient interface {
Dial(ctx context.Context, network, address string) (net.Conn, error)
}
activeConnections map[string]*proxyConnection
destinations map[string]string
mu sync.Mutex
}
type proxyConnection struct {
id string
destination string
rdpConn net.Conn
tlsConn *tls.Conn
wsHandlers js.Value
ctx context.Context
cancel context.CancelFunc
}
// NewRDCleanPathProxy creates a new RDCleanPath proxy
func NewRDCleanPathProxy(client interface {
Dial(ctx context.Context, network, address string) (net.Conn, error)
}) *RDCleanPathProxy {
return &RDCleanPathProxy{
nbClient: client,
activeConnections: make(map[string]*proxyConnection),
}
}
// CreateProxy creates a new proxy endpoint for the given destination
func (p *RDCleanPathProxy) CreateProxy(hostname, port string) js.Value {
destination := fmt.Sprintf("%s:%s", hostname, port)
return js.Global().Get("Promise").New(js.FuncOf(func(_ js.Value, args []js.Value) any {
resolve := args[0]
go func() {
proxyID := fmt.Sprintf("proxy_%d", len(p.activeConnections))
p.mu.Lock()
if p.destinations == nil {
p.destinations = make(map[string]string)
}
p.destinations[proxyID] = destination
p.mu.Unlock()
proxyURL := fmt.Sprintf("%s://%s/%s", RDCleanPathProxyScheme, RDCleanPathProxyHost, proxyID)
// Register the WebSocket handler for this specific proxy
js.Global().Set(fmt.Sprintf("handleRDCleanPathWebSocket_%s", proxyID), js.FuncOf(func(_ js.Value, args []js.Value) any {
if len(args) < 1 {
return js.ValueOf("error: requires WebSocket argument")
}
ws := args[0]
p.HandleWebSocketConnection(ws, proxyID)
return nil
}))
log.Infof("Created RDCleanPath proxy endpoint: %s for destination: %s", proxyURL, destination)
resolve.Invoke(proxyURL)
}()
return nil
}))
}
// HandleWebSocketConnection handles incoming WebSocket connections from IronRDP
func (p *RDCleanPathProxy) HandleWebSocketConnection(ws js.Value, proxyID string) {
p.mu.Lock()
destination := p.destinations[proxyID]
p.mu.Unlock()
if destination == "" {
log.Errorf("No destination found for proxy ID: %s", proxyID)
return
}
ctx, cancel := context.WithCancel(context.Background())
// Don't defer cancel here - it will be called by cleanupConnection
conn := &proxyConnection{
id: proxyID,
destination: destination,
wsHandlers: ws,
ctx: ctx,
cancel: cancel,
}
p.mu.Lock()
p.activeConnections[proxyID] = conn
p.mu.Unlock()
p.setupWebSocketHandlers(ws, conn)
log.Infof("RDCleanPath proxy WebSocket connection established for %s", proxyID)
}
func (p *RDCleanPathProxy) setupWebSocketHandlers(ws js.Value, conn *proxyConnection) {
ws.Set("onGoMessage", js.FuncOf(func(this js.Value, args []js.Value) any {
if len(args) < 1 {
return nil
}
data := args[0]
go p.handleWebSocketMessage(conn, data)
return nil
}))
ws.Set("onGoClose", js.FuncOf(func(_ js.Value, args []js.Value) any {
log.Debug("WebSocket closed by JavaScript")
conn.cancel()
return nil
}))
}
func (p *RDCleanPathProxy) handleWebSocketMessage(conn *proxyConnection, data js.Value) {
if !data.InstanceOf(js.Global().Get("Uint8Array")) {
return
}
length := data.Get("length").Int()
bytes := make([]byte, length)
js.CopyBytesToGo(bytes, data)
if conn.rdpConn != nil || conn.tlsConn != nil {
p.forwardToRDP(conn, bytes)
return
}
var pdu RDCleanPathPDU
_, err := asn1.Unmarshal(bytes, &pdu)
if err != nil {
log.Warnf("Failed to parse RDCleanPath PDU: %v", err)
n := len(bytes)
if n > 20 {
n = 20
}
log.Warnf("First %d bytes: %x", n, bytes[:n])
if len(bytes) > 0 && bytes[0] == 0x03 {
log.Debug("Received raw RDP packet instead of RDCleanPath PDU")
go p.handleDirectRDP(conn, bytes)
return
}
return
}
go p.processRDCleanPathPDU(conn, pdu)
}
func (p *RDCleanPathProxy) forwardToRDP(conn *proxyConnection, bytes []byte) {
var writer io.Writer
var connType string
if conn.tlsConn != nil {
writer = conn.tlsConn
connType = "TLS"
} else if conn.rdpConn != nil {
writer = conn.rdpConn
connType = "TCP"
} else {
log.Error("No RDP connection available")
return
}
if _, err := writer.Write(bytes); err != nil {
log.Errorf("Failed to write to %s: %v", connType, err)
}
}
func (p *RDCleanPathProxy) handleDirectRDP(conn *proxyConnection, firstPacket []byte) {
defer p.cleanupConnection(conn)
destination := conn.destination
log.Infof("Direct RDP mode: Connecting to %s via NetBird", destination)
ctx, cancel := context.WithTimeout(conn.ctx, rdpDialTimeout)
defer cancel()
rdpConn, err := p.nbClient.Dial(ctx, "tcp", destination)
if err != nil {
log.Errorf("Failed to connect to %s: %v", destination, err)
p.sendRDCleanPathError(conn, newWSAError(err))
return
}
conn.rdpConn = rdpConn
_, err = rdpConn.Write(firstPacket)
if err != nil {
log.Errorf("Failed to write first packet: %v", err)
p.sendRDCleanPathError(conn, newWSAError(err))
return
}
response := make([]byte, 1024)
n, err := rdpConn.Read(response)
if err != nil {
log.Errorf("Failed to read X.224 response: %v", err)
p.sendRDCleanPathError(conn, newWSAError(err))
return
}
p.sendToWebSocket(conn, response[:n])
go p.forwardWSToConn(conn, conn.rdpConn, "TCP")
go p.forwardConnToWS(conn, conn.rdpConn, "TCP")
}
func (p *RDCleanPathProxy) cleanupConnection(conn *proxyConnection) {
log.Debugf("Cleaning up connection %s", conn.id)
conn.cancel()
if conn.tlsConn != nil {
log.Debug("Closing TLS connection")
if err := conn.tlsConn.Close(); err != nil {
log.Debugf("Error closing TLS connection: %v", err)
}
conn.tlsConn = nil
}
if conn.rdpConn != nil {
log.Debug("Closing TCP connection")
if err := conn.rdpConn.Close(); err != nil {
log.Debugf("Error closing TCP connection: %v", err)
}
conn.rdpConn = nil
}
p.mu.Lock()
delete(p.activeConnections, conn.id)
p.mu.Unlock()
}
func (p *RDCleanPathProxy) sendToWebSocket(conn *proxyConnection, data []byte) {
if conn.wsHandlers.Get("receiveFromGo").Truthy() {
uint8Array := js.Global().Get("Uint8Array").New(len(data))
js.CopyBytesToJS(uint8Array, data)
conn.wsHandlers.Call("receiveFromGo", uint8Array.Get("buffer"))
} else if conn.wsHandlers.Get("send").Truthy() {
uint8Array := js.Global().Get("Uint8Array").New(len(data))
js.CopyBytesToJS(uint8Array, data)
conn.wsHandlers.Call("send", uint8Array.Get("buffer"))
}
}
func (p *RDCleanPathProxy) sendRDCleanPathError(conn *proxyConnection, pdu RDCleanPathPDU) {
data, err := asn1.Marshal(pdu)
if err != nil {
log.Errorf("Failed to marshal error PDU: %v", err)
return
}
p.sendToWebSocket(conn, data)
}
func errorToWSACode(err error) int16 {
if err == nil {
return WSAEGenericError
}
var netErr *net.OpError
if errors.As(err, &netErr) && netErr.Timeout() {
return WSAETimedOut
}
if errors.Is(err, context.DeadlineExceeded) {
return WSAETimedOut
}
if errors.Is(err, context.Canceled) {
return WSAEConnAborted
}
if errors.Is(err, io.EOF) {
return WSAEConnReset
}
return WSAEGenericError
}
func newWSAError(err error) RDCleanPathPDU {
return RDCleanPathPDU{
Version: RDCleanPathVersion,
Error: RDCleanPathErr{
ErrorCode: GeneralErrorCode,
WSALastError: errorToWSACode(err),
},
}
}
func newHTTPError(statusCode int16) RDCleanPathPDU {
return RDCleanPathPDU{
Version: RDCleanPathVersion,
Error: RDCleanPathErr{
ErrorCode: GeneralErrorCode,
HTTPStatusCode: statusCode,
},
}
}

View File

@@ -1,244 +0,0 @@
//go:build js
package rdp
import (
"context"
"crypto/tls"
"encoding/asn1"
"io"
"syscall/js"
log "github.com/sirupsen/logrus"
)
const (
// MS-RDPBCGR: confusingly named, actually means PROTOCOL_HYBRID (CredSSP)
protocolSSL = 0x00000001
protocolHybridEx = 0x00000008
)
func (p *RDCleanPathProxy) processRDCleanPathPDU(conn *proxyConnection, pdu RDCleanPathPDU) {
log.Infof("Processing RDCleanPath PDU: Version=%d, Destination=%s", pdu.Version, pdu.Destination)
if pdu.Version != RDCleanPathVersion {
p.sendRDCleanPathError(conn, newHTTPError(400))
return
}
destination := conn.destination
if pdu.Destination != "" {
destination = pdu.Destination
}
ctx, cancel := context.WithTimeout(conn.ctx, rdpDialTimeout)
defer cancel()
rdpConn, err := p.nbClient.Dial(ctx, "tcp", destination)
if err != nil {
log.Errorf("Failed to connect to %s: %v", destination, err)
p.sendRDCleanPathError(conn, newWSAError(err))
p.cleanupConnection(conn)
return
}
conn.rdpConn = rdpConn
// RDP always starts with X.224 negotiation, then determines if TLS is needed
// Modern RDP (since Windows Vista/2008) typically requires TLS
// The X.224 Connection Confirm response will indicate if TLS is required
// For now, we'll attempt TLS for all connections as it's the modern default
p.setupTLSConnection(conn, pdu)
}
// detectCredSSPFromX224 checks if the X.224 response indicates NLA/CredSSP is required.
// Per MS-RDPBCGR spec: byte 11 = TYPE_RDP_NEG_RSP (0x02), bytes 15-18 = selectedProtocol flags.
// Returns (requiresTLS12, selectedProtocol, detectionSuccessful).
func (p *RDCleanPathProxy) detectCredSSPFromX224(x224Response []byte) (bool, uint32, bool) {
const minResponseLength = 19
if len(x224Response) < minResponseLength {
return false, 0, false
}
// Per X.224 specification:
// x224Response[0] == 0x03: Length of X.224 header (3 bytes)
// x224Response[5] == 0xD0: X.224 Data TPDU code
if x224Response[0] != 0x03 || x224Response[5] != 0xD0 {
return false, 0, false
}
if x224Response[11] == 0x02 {
flags := uint32(x224Response[15]) | uint32(x224Response[16])<<8 |
uint32(x224Response[17])<<16 | uint32(x224Response[18])<<24
hasNLA := (flags & (protocolSSL | protocolHybridEx)) != 0
return hasNLA, flags, true
}
return false, 0, false
}
func (p *RDCleanPathProxy) setupTLSConnection(conn *proxyConnection, pdu RDCleanPathPDU) {
var x224Response []byte
if len(pdu.X224ConnectionPDU) > 0 {
log.Debugf("Forwarding X.224 Connection Request (%d bytes)", len(pdu.X224ConnectionPDU))
_, err := conn.rdpConn.Write(pdu.X224ConnectionPDU)
if err != nil {
log.Errorf("Failed to write X.224 PDU: %v", err)
p.sendRDCleanPathError(conn, newWSAError(err))
return
}
response := make([]byte, 1024)
n, err := conn.rdpConn.Read(response)
if err != nil {
log.Errorf("Failed to read X.224 response: %v", err)
p.sendRDCleanPathError(conn, newWSAError(err))
return
}
x224Response = response[:n]
log.Debugf("Received X.224 Connection Confirm (%d bytes)", n)
}
requiresCredSSP, selectedProtocol, detected := p.detectCredSSPFromX224(x224Response)
if detected {
if requiresCredSSP {
log.Warnf("Detected NLA/CredSSP (selectedProtocol: 0x%08X), forcing TLS 1.2 for compatibility", selectedProtocol)
} else {
log.Warnf("No NLA/CredSSP detected (selectedProtocol: 0x%08X), allowing up to TLS 1.3", selectedProtocol)
}
} else {
log.Warnf("Could not detect RDP security protocol, allowing up to TLS 1.3")
}
tlsConfig := p.getTLSConfigWithValidation(conn, requiresCredSSP)
tlsConn := tls.Client(conn.rdpConn, tlsConfig)
conn.tlsConn = tlsConn
if err := tlsConn.Handshake(); err != nil {
log.Errorf("TLS handshake failed: %v", err)
p.sendRDCleanPathError(conn, newWSAError(err))
return
}
log.Info("TLS handshake successful")
// Certificate validation happens during handshake via VerifyConnection callback
var certChain [][]byte
connState := tlsConn.ConnectionState()
if len(connState.PeerCertificates) > 0 {
for _, cert := range connState.PeerCertificates {
certChain = append(certChain, cert.Raw)
}
log.Debugf("Extracted %d certificates from TLS connection", len(certChain))
}
responsePDU := RDCleanPathPDU{
Version: RDCleanPathVersion,
ServerAddr: conn.destination,
ServerCertChain: certChain,
}
if len(x224Response) > 0 {
responsePDU.X224ConnectionPDU = x224Response
}
p.sendRDCleanPathPDU(conn, responsePDU)
log.Debug("Starting TLS forwarding")
go p.forwardConnToWS(conn, conn.tlsConn, "TLS")
go p.forwardWSToConn(conn, conn.tlsConn, "TLS")
<-conn.ctx.Done()
log.Debug("TLS connection context done, cleaning up")
p.cleanupConnection(conn)
}
func (p *RDCleanPathProxy) sendRDCleanPathPDU(conn *proxyConnection, pdu RDCleanPathPDU) {
data, err := asn1.Marshal(pdu)
if err != nil {
log.Errorf("Failed to marshal RDCleanPath PDU: %v", err)
return
}
log.Debugf("Sending RDCleanPath PDU response (%d bytes)", len(data))
p.sendToWebSocket(conn, data)
}
func (p *RDCleanPathProxy) readWebSocketMessage(conn *proxyConnection) ([]byte, error) {
msgChan := make(chan []byte)
errChan := make(chan error)
handler := js.FuncOf(func(this js.Value, args []js.Value) interface{} {
if len(args) < 1 {
errChan <- io.EOF
return nil
}
data := args[0]
if data.InstanceOf(js.Global().Get("Uint8Array")) {
length := data.Get("length").Int()
bytes := make([]byte, length)
js.CopyBytesToGo(bytes, data)
msgChan <- bytes
}
return nil
})
defer handler.Release()
conn.wsHandlers.Set("onceGoMessage", handler)
select {
case msg := <-msgChan:
return msg, nil
case err := <-errChan:
return nil, err
case <-conn.ctx.Done():
return nil, conn.ctx.Err()
}
}
func (p *RDCleanPathProxy) forwardWSToConn(conn *proxyConnection, dst io.Writer, connType string) {
for {
if conn.ctx.Err() != nil {
return
}
msg, err := p.readWebSocketMessage(conn)
if err != nil {
if err != io.EOF {
log.Errorf("Failed to read from WebSocket: %v", err)
}
return
}
_, err = dst.Write(msg)
if err != nil {
log.Errorf("Failed to write to %s: %v", connType, err)
return
}
}
}
func (p *RDCleanPathProxy) forwardConnToWS(conn *proxyConnection, src io.Reader, connType string) {
buffer := make([]byte, 32*1024)
for {
if conn.ctx.Err() != nil {
return
}
n, err := src.Read(buffer)
if err != nil {
if err != io.EOF {
log.Errorf("Failed to read from %s: %v", connType, err)
}
return
}
if n > 0 {
p.sendToWebSocket(conn, buffer[:n])
}
}
}

View File

@@ -0,0 +1,332 @@
//go:build js
package vnc
import (
"context"
"fmt"
"io"
"net"
"sync"
"sync/atomic"
"syscall/js"
"time"
log "github.com/sirupsen/logrus"
)
const (
vncProxyHost = "vnc.proxy.local"
vncProxyScheme = "ws"
vncDialTimeout = 15 * time.Second
// Connection modes matching server/server.go constants.
modeAttach byte = 0
modeSession byte = 1
)
// VNCProxy bridges WebSocket connections from noVNC in the browser
// to TCP VNC server connections through the NetBird tunnel.
type VNCProxy struct {
nbClient interface {
Dial(ctx context.Context, network, address string) (net.Conn, error)
}
activeConnections map[string]*vncConnection
destinations map[string]vncDestination
mu sync.Mutex
nextID atomic.Uint64
}
type vncDestination struct {
address string
mode byte
username string
jwt string
sessionID uint32 // Windows session ID (0 = auto/console)
}
type vncConnection struct {
id string
destination vncDestination
mu sync.Mutex
vncConn net.Conn
wsHandlers js.Value
ctx context.Context
cancel context.CancelFunc
}
// NewVNCProxy creates a new VNC proxy.
func NewVNCProxy(client interface {
Dial(ctx context.Context, network, address string) (net.Conn, error)
}) *VNCProxy {
return &VNCProxy{
nbClient: client,
activeConnections: make(map[string]*vncConnection),
}
}
// CreateProxy creates a new proxy endpoint for the given VNC destination.
// mode is "attach" (capture current display) or "session" (virtual session).
// username is required for session mode.
// Returns a JS Promise that resolves to the WebSocket proxy URL.
func (p *VNCProxy) CreateProxy(hostname, port, mode, username, jwt string, sessionID uint32) js.Value {
address := fmt.Sprintf("%s:%s", hostname, port)
var m byte
if mode == "session" {
m = modeSession
}
dest := vncDestination{
address: address,
mode: m,
username: username,
jwt: jwt,
sessionID: sessionID,
}
return js.Global().Get("Promise").New(js.FuncOf(func(_ js.Value, args []js.Value) any {
resolve := args[0]
go func() {
proxyID := fmt.Sprintf("vnc_proxy_%d", p.nextID.Add(1))
p.mu.Lock()
if p.destinations == nil {
p.destinations = make(map[string]vncDestination)
}
p.destinations[proxyID] = dest
p.mu.Unlock()
proxyURL := fmt.Sprintf("%s://%s/%s", vncProxyScheme, vncProxyHost, proxyID)
js.Global().Set(fmt.Sprintf("handleVNCWebSocket_%s", proxyID), js.FuncOf(func(_ js.Value, args []js.Value) any {
if len(args) < 1 {
return js.ValueOf("error: requires WebSocket argument")
}
ws := args[0]
p.handleWebSocketConnection(ws, proxyID)
return nil
}))
log.Infof("created VNC proxy: %s -> %s (mode=%s, user=%s)", proxyURL, address, mode, username)
resolve.Invoke(proxyURL)
}()
return nil
}))
}
func (p *VNCProxy) handleWebSocketConnection(ws js.Value, proxyID string) {
p.mu.Lock()
dest, ok := p.destinations[proxyID]
p.mu.Unlock()
if !ok {
log.Errorf("no destination for VNC proxy %s", proxyID)
return
}
ctx, cancel := context.WithCancel(context.Background())
conn := &vncConnection{
id: proxyID,
destination: dest,
wsHandlers: ws,
ctx: ctx,
cancel: cancel,
}
p.mu.Lock()
p.activeConnections[proxyID] = conn
p.mu.Unlock()
p.setupWebSocketHandlers(ws, conn)
go p.connectToVNC(conn)
log.Infof("VNC proxy WebSocket connection established for %s", proxyID)
}
func (p *VNCProxy) setupWebSocketHandlers(ws js.Value, conn *vncConnection) {
ws.Set("onGoMessage", js.FuncOf(func(_ js.Value, args []js.Value) any {
if len(args) < 1 {
return nil
}
data := args[0]
go p.handleWebSocketMessage(conn, data)
return nil
}))
ws.Set("onGoClose", js.FuncOf(func(_ js.Value, _ []js.Value) any {
log.Debug("VNC WebSocket closed by JavaScript")
conn.cancel()
return nil
}))
}
func (p *VNCProxy) handleWebSocketMessage(conn *vncConnection, data js.Value) {
if !data.InstanceOf(js.Global().Get("Uint8Array")) {
return
}
length := data.Get("length").Int()
buf := make([]byte, length)
js.CopyBytesToGo(buf, data)
conn.mu.Lock()
vncConn := conn.vncConn
conn.mu.Unlock()
if vncConn == nil {
return
}
if _, err := vncConn.Write(buf); err != nil {
log.Debugf("write to VNC server: %v", err)
}
}
func (p *VNCProxy) connectToVNC(conn *vncConnection) {
ctx, cancel := context.WithTimeout(conn.ctx, vncDialTimeout)
defer cancel()
vncConn, err := p.nbClient.Dial(ctx, "tcp", conn.destination.address)
if err != nil {
log.Errorf("VNC connect to %s: %v", conn.destination.address, err)
// Close the WebSocket so noVNC fires a disconnect event.
if conn.wsHandlers.Get("close").Truthy() {
conn.wsHandlers.Call("close", 1006, fmt.Sprintf("connect to peer: %v", err))
}
p.cleanupConnection(conn)
return
}
conn.mu.Lock()
conn.vncConn = vncConn
conn.mu.Unlock()
// Send the NetBird VNC session header before the RFB handshake.
if err := p.sendSessionHeader(vncConn, conn.destination); err != nil {
log.Errorf("send VNC session header: %v", err)
p.cleanupConnection(conn)
return
}
// WS→TCP is handled by the onGoMessage handler set in setupWebSocketHandlers,
// which writes directly to the VNC connection as data arrives from JS.
// Only the TCP→WS direction needs a read loop here.
go p.forwardConnToWS(conn)
<-conn.ctx.Done()
p.cleanupConnection(conn)
}
// sendSessionHeader writes mode, username, and JWT to the VNC server.
// Format: [mode: 1 byte] [username_len: 2 bytes BE] [username: N bytes]
//
// [jwt_len: 2 bytes BE] [jwt: N bytes]
func (p *VNCProxy) sendSessionHeader(conn net.Conn, dest vncDestination) error {
usernameBytes := []byte(dest.username)
jwtBytes := []byte(dest.jwt)
// Format: [mode:1] [username_len:2] [username:N] [jwt_len:2] [jwt:N] [session_id:4]
hdr := make([]byte, 3+len(usernameBytes)+2+len(jwtBytes)+4)
hdr[0] = dest.mode
hdr[1] = byte(len(usernameBytes) >> 8)
hdr[2] = byte(len(usernameBytes))
off := 3
copy(hdr[off:], usernameBytes)
off += len(usernameBytes)
hdr[off] = byte(len(jwtBytes) >> 8)
hdr[off+1] = byte(len(jwtBytes))
off += 2
copy(hdr[off:], jwtBytes)
off += len(jwtBytes)
hdr[off] = byte(dest.sessionID >> 24)
hdr[off+1] = byte(dest.sessionID >> 16)
hdr[off+2] = byte(dest.sessionID >> 8)
hdr[off+3] = byte(dest.sessionID)
_, err := conn.Write(hdr)
return err
}
func (p *VNCProxy) forwardConnToWS(conn *vncConnection) {
buf := make([]byte, 32*1024)
for {
if conn.ctx.Err() != nil {
return
}
// Set a read deadline so we detect dead connections instead of
// blocking forever when the remote peer dies.
conn.mu.Lock()
vc := conn.vncConn
conn.mu.Unlock()
if vc == nil {
return
}
vc.SetReadDeadline(time.Now().Add(30 * time.Second))
n, err := vc.Read(buf)
if err != nil {
if conn.ctx.Err() != nil {
return
}
if netErr, ok := err.(interface{ Timeout() bool }); ok && netErr.Timeout() {
// Read timeout: connection might be stale. Send a ping-like
// empty read to check. If the connection is truly dead, the
// next iteration will fail too and we'll close.
continue
}
if err != io.EOF {
log.Debugf("read from VNC connection: %v", err)
}
// Close the WebSocket to notify noVNC.
if conn.wsHandlers.Get("close").Truthy() {
conn.wsHandlers.Call("close", 1006, "VNC connection lost")
}
return
}
if n > 0 {
p.sendToWebSocket(conn, buf[:n])
}
}
}
func (p *VNCProxy) sendToWebSocket(conn *vncConnection, data []byte) {
if conn.wsHandlers.Get("receiveFromGo").Truthy() {
uint8Array := js.Global().Get("Uint8Array").New(len(data))
js.CopyBytesToJS(uint8Array, data)
conn.wsHandlers.Call("receiveFromGo", uint8Array.Get("buffer"))
} else if conn.wsHandlers.Get("send").Truthy() {
uint8Array := js.Global().Get("Uint8Array").New(len(data))
js.CopyBytesToJS(uint8Array, data)
conn.wsHandlers.Call("send", uint8Array.Get("buffer"))
}
}
func (p *VNCProxy) cleanupConnection(conn *vncConnection) {
log.Debugf("cleaning up VNC connection %s", conn.id)
conn.cancel()
conn.mu.Lock()
vncConn := conn.vncConn
conn.vncConn = nil
conn.mu.Unlock()
if vncConn != nil {
if err := vncConn.Close(); err != nil {
log.Debugf("close VNC connection: %v", err)
}
}
// Remove the global JS handler registered in CreateProxy.
globalName := fmt.Sprintf("handleVNCWebSocket_%s", conn.id)
js.Global().Delete(globalName)
p.mu.Lock()
delete(p.activeConnections, conn.id)
delete(p.destinations, conn.id)
p.mu.Unlock()
}

View File

@@ -14,7 +14,6 @@ import (
log "github.com/sirupsen/logrus"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/connectivity"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/keepalive"
@@ -26,11 +25,22 @@ import (
"github.com/netbirdio/netbird/util/wsproxy"
)
var ErrClientClosed = errors.New("client is closed")
// minHealthyDuration is the minimum time a stream must survive before a failure
// resets the backoff timer. Streams that fail faster are considered unhealthy and
// should not reset backoff, so that MaxElapsedTime can eventually stop retries.
const minHealthyDuration = 5 * time.Second
type GRPCClient struct {
realClient proto.FlowServiceClient
clientConn *grpc.ClientConn
stream proto.FlowService_EventsClient
streamMu sync.Mutex
target string
opts []grpc.DialOption
closed bool // prevent creating conn in the middle of the Close
receiving bool // prevent concurrent Receive calls
mu sync.Mutex // protects clientConn, realClient, stream, closed, and receiving
}
func NewClient(addr, payload, signature string, interval time.Duration) (*GRPCClient, error) {
@@ -65,7 +75,8 @@ func NewClient(addr, payload, signature string, interval time.Duration) (*GRPCCl
grpc.WithDefaultServiceConfig(`{"healthCheckConfig": {"serviceName": ""}}`),
)
conn, err := grpc.NewClient(fmt.Sprintf("%s:%s", parsedURL.Hostname(), parsedURL.Port()), opts...)
target := parsedURL.Host
conn, err := grpc.NewClient(target, opts...)
if err != nil {
return nil, fmt.Errorf("creating new grpc client: %w", err)
}
@@ -73,30 +84,73 @@ func NewClient(addr, payload, signature string, interval time.Duration) (*GRPCCl
return &GRPCClient{
realClient: proto.NewFlowServiceClient(conn),
clientConn: conn,
target: target,
opts: opts,
}, nil
}
func (c *GRPCClient) Close() error {
c.streamMu.Lock()
defer c.streamMu.Unlock()
c.mu.Lock()
c.closed = true
c.stream = nil
if err := c.clientConn.Close(); err != nil && !errors.Is(err, context.Canceled) {
conn := c.clientConn
c.clientConn = nil
c.mu.Unlock()
if conn == nil {
return nil
}
if err := conn.Close(); err != nil && !errors.Is(err, context.Canceled) {
return fmt.Errorf("close client connection: %w", err)
}
return nil
}
func (c *GRPCClient) Send(event *proto.FlowEvent) error {
c.mu.Lock()
stream := c.stream
c.mu.Unlock()
if stream == nil {
return errors.New("stream not initialized")
}
if err := stream.Send(event); err != nil {
return fmt.Errorf("send flow event: %w", err)
}
return nil
}
func (c *GRPCClient) Receive(ctx context.Context, interval time.Duration, msgHandler func(msg *proto.FlowEventAck) error) error {
c.mu.Lock()
if c.receiving {
c.mu.Unlock()
return errors.New("concurrent Receive calls are not supported")
}
c.receiving = true
c.mu.Unlock()
defer func() {
c.mu.Lock()
c.receiving = false
c.mu.Unlock()
}()
backOff := defaultBackoff(ctx, interval)
operation := func() error {
if err := c.establishStreamAndReceive(ctx, msgHandler); err != nil {
if s, ok := status.FromError(err); ok && s.Code() == codes.Canceled {
return fmt.Errorf("receive: %w: %w", err, context.Canceled)
}
stream, err := c.establishStream(ctx)
if err != nil {
log.Errorf("failed to establish flow stream, retrying: %v", err)
return c.handleRetryableError(err, time.Time{}, backOff)
}
streamStart := time.Now()
if err := c.receive(stream, msgHandler); err != nil {
log.Errorf("receive failed: %v", err)
return fmt.Errorf("receive: %w", err)
return c.handleRetryableError(err, streamStart, backOff)
}
return nil
}
@@ -108,37 +162,106 @@ func (c *GRPCClient) Receive(ctx context.Context, interval time.Duration, msgHan
return nil
}
func (c *GRPCClient) establishStreamAndReceive(ctx context.Context, msgHandler func(msg *proto.FlowEventAck) error) error {
if c.clientConn.GetState() == connectivity.Shutdown {
return errors.New("connection to flow receiver has been shut down")
// handleRetryableError resets the backoff timer if the stream was healthy long
// enough and recreates the underlying ClientConn so that gRPC's internal
// subchannel backoff does not accumulate and compete with our own retry timer.
// A zero streamStart means the stream was never established.
func (c *GRPCClient) handleRetryableError(err error, streamStart time.Time, backOff backoff.BackOff) error {
if isContextDone(err) {
return backoff.Permanent(err)
}
stream, err := c.realClient.Events(ctx, grpc.WaitForReady(true))
if err != nil {
return fmt.Errorf("create event stream: %w", err)
var permErr *backoff.PermanentError
if errors.As(err, &permErr) {
return err
}
err = stream.Send(&proto.FlowEvent{IsInitiator: true})
// Reset the backoff so the next retry starts with a short delay instead of
// continuing the already-elapsed timer. Only do this if the stream was healthy
// long enough; short-lived connect/drop cycles must not defeat MaxElapsedTime.
if !streamStart.IsZero() && time.Since(streamStart) >= minHealthyDuration {
backOff.Reset()
}
if recreateErr := c.recreateConnection(); recreateErr != nil {
log.Errorf("recreate connection: %v", recreateErr)
return recreateErr
}
log.Infof("connection recreated, retrying stream")
return fmt.Errorf("retrying after error: %w", err)
}
func (c *GRPCClient) recreateConnection() error {
c.mu.Lock()
if c.closed {
c.mu.Unlock()
return backoff.Permanent(ErrClientClosed)
}
conn, err := grpc.NewClient(c.target, c.opts...)
if err != nil {
log.Infof("failed to send initiator message to flow receiver but will attempt to continue. Error: %s", err)
c.mu.Unlock()
return fmt.Errorf("create new connection: %w", err)
}
old := c.clientConn
c.clientConn = conn
c.realClient = proto.NewFlowServiceClient(conn)
c.stream = nil
c.mu.Unlock()
_ = old.Close()
return nil
}
func (c *GRPCClient) establishStream(ctx context.Context) (proto.FlowService_EventsClient, error) {
c.mu.Lock()
if c.closed {
c.mu.Unlock()
return nil, backoff.Permanent(ErrClientClosed)
}
cl := c.realClient
c.mu.Unlock()
// open stream outside the lock — blocking operation
stream, err := cl.Events(ctx)
if err != nil {
return nil, fmt.Errorf("create event stream: %w", err)
}
streamReady := false
defer func() {
if !streamReady {
_ = stream.CloseSend()
}
}()
if err = stream.Send(&proto.FlowEvent{IsInitiator: true}); err != nil {
return nil, fmt.Errorf("send initiator: %w", err)
}
if err = checkHeader(stream); err != nil {
return fmt.Errorf("check header: %w", err)
return nil, fmt.Errorf("check header: %w", err)
}
c.streamMu.Lock()
c.mu.Lock()
if c.closed {
c.mu.Unlock()
return nil, backoff.Permanent(ErrClientClosed)
}
c.stream = stream
c.streamMu.Unlock()
c.mu.Unlock()
streamReady = true
return c.receive(stream, msgHandler)
return stream, nil
}
func (c *GRPCClient) receive(stream proto.FlowService_EventsClient, msgHandler func(msg *proto.FlowEventAck) error) error {
for {
msg, err := stream.Recv()
if err != nil {
return fmt.Errorf("receive from stream: %w", err)
return err
}
if msg.IsInitiator {
@@ -169,7 +292,7 @@ func checkHeader(stream proto.FlowService_EventsClient) error {
func defaultBackoff(ctx context.Context, interval time.Duration) backoff.BackOff {
return backoff.WithContext(&backoff.ExponentialBackOff{
InitialInterval: 800 * time.Millisecond,
RandomizationFactor: 1,
RandomizationFactor: 0.5,
Multiplier: 1.7,
MaxInterval: interval / 2,
MaxElapsedTime: 3 * 30 * 24 * time.Hour, // 3 months
@@ -178,18 +301,12 @@ func defaultBackoff(ctx context.Context, interval time.Duration) backoff.BackOff
}, ctx)
}
func (c *GRPCClient) Send(event *proto.FlowEvent) error {
c.streamMu.Lock()
stream := c.stream
c.streamMu.Unlock()
if stream == nil {
return errors.New("stream not initialized")
func isContextDone(err error) bool {
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
return true
}
if err := stream.Send(event); err != nil {
return fmt.Errorf("send flow event: %w", err)
if s, ok := status.FromError(err); ok {
return s.Code() == codes.Canceled || s.Code() == codes.DeadlineExceeded
}
return nil
return false
}

View File

@@ -2,8 +2,11 @@ package client_test
import (
"context"
"encoding/binary"
"errors"
"net"
"sync"
"sync/atomic"
"testing"
"time"
@@ -11,6 +14,8 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
flow "github.com/netbirdio/netbird/flow/client"
"github.com/netbirdio/netbird/flow/proto"
@@ -18,21 +23,89 @@ import (
type testServer struct {
proto.UnimplementedFlowServiceServer
events chan *proto.FlowEvent
acks chan *proto.FlowEventAck
grpcSrv *grpc.Server
addr string
events chan *proto.FlowEvent
acks chan *proto.FlowEventAck
grpcSrv *grpc.Server
addr string
listener *connTrackListener
closeStream chan struct{} // signal server to close the stream
handlerDone chan struct{} // signaled each time Events() exits
handlerStarted chan struct{} // signaled each time Events() begins
}
// connTrackListener wraps a net.Listener to track accepted connections
// so tests can forcefully close them to simulate PROTOCOL_ERROR/RST_STREAM.
type connTrackListener struct {
net.Listener
mu sync.Mutex
conns []net.Conn
}
func (l *connTrackListener) Accept() (net.Conn, error) {
c, err := l.Listener.Accept()
if err != nil {
return nil, err
}
l.mu.Lock()
l.conns = append(l.conns, c)
l.mu.Unlock()
return c, nil
}
// sendRSTStream writes a raw HTTP/2 RST_STREAM frame with PROTOCOL_ERROR
// (error code 0x1) on every tracked connection. This produces the exact error:
//
// rpc error: code = Internal desc = stream terminated by RST_STREAM with error code: PROTOCOL_ERROR
//
// HTTP/2 RST_STREAM frame format (9-byte header + 4-byte payload):
//
// Length (3 bytes): 0x000004
// Type (1 byte): 0x03 (RST_STREAM)
// Flags (1 byte): 0x00
// Stream ID (4 bytes): target stream (must have bit 31 clear)
// Error Code (4 bytes): 0x00000001 (PROTOCOL_ERROR)
func (l *connTrackListener) connCount() int {
l.mu.Lock()
defer l.mu.Unlock()
return len(l.conns)
}
func (l *connTrackListener) sendRSTStream(streamID uint32) {
l.mu.Lock()
defer l.mu.Unlock()
frame := make([]byte, 13) // 9-byte header + 4-byte payload
// Length = 4 (3 bytes, big-endian)
frame[0], frame[1], frame[2] = 0, 0, 4
// Type = RST_STREAM (0x03)
frame[3] = 0x03
// Flags = 0
frame[4] = 0x00
// Stream ID (4 bytes, big-endian, bit 31 reserved = 0)
binary.BigEndian.PutUint32(frame[5:9], streamID)
// Error Code = PROTOCOL_ERROR (0x1)
binary.BigEndian.PutUint32(frame[9:13], 0x1)
for _, c := range l.conns {
_, _ = c.Write(frame)
}
}
func newTestServer(t *testing.T) *testServer {
listener, err := net.Listen("tcp", "127.0.0.1:0")
rawListener, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
listener := &connTrackListener{Listener: rawListener}
s := &testServer{
events: make(chan *proto.FlowEvent, 100),
acks: make(chan *proto.FlowEventAck, 100),
grpcSrv: grpc.NewServer(),
addr: listener.Addr().String(),
events: make(chan *proto.FlowEvent, 100),
acks: make(chan *proto.FlowEventAck, 100),
grpcSrv: grpc.NewServer(),
addr: rawListener.Addr().String(),
listener: listener,
closeStream: make(chan struct{}, 1),
handlerDone: make(chan struct{}, 10),
handlerStarted: make(chan struct{}, 10),
}
proto.RegisterFlowServiceServer(s.grpcSrv, s)
@@ -51,11 +124,23 @@ func newTestServer(t *testing.T) *testServer {
}
func (s *testServer) Events(stream proto.FlowService_EventsServer) error {
defer func() {
select {
case s.handlerDone <- struct{}{}:
default:
}
}()
err := stream.Send(&proto.FlowEventAck{IsInitiator: true})
if err != nil {
return err
}
select {
case s.handlerStarted <- struct{}{}:
default:
}
ctx, cancel := context.WithCancel(stream.Context())
defer cancel()
@@ -91,6 +176,8 @@ func (s *testServer) Events(stream proto.FlowService_EventsServer) error {
if err := stream.Send(ack); err != nil {
return err
}
case <-s.closeStream:
return status.Errorf(codes.Internal, "server closing stream")
case <-ctx.Done():
return ctx.Err()
}
@@ -110,16 +197,13 @@ func TestReceive(t *testing.T) {
assert.NoError(t, err, "failed to close flow")
})
receivedAcks := make(map[string]bool)
var ackCount atomic.Int32
receiveDone := make(chan struct{})
go func() {
err := client.Receive(ctx, 1*time.Second, func(msg *proto.FlowEventAck) error {
if !msg.IsInitiator && len(msg.EventId) > 0 {
id := string(msg.EventId)
receivedAcks[id] = true
if len(receivedAcks) >= 3 {
if ackCount.Add(1) >= 3 {
close(receiveDone)
}
}
@@ -130,7 +214,11 @@ func TestReceive(t *testing.T) {
}
}()
time.Sleep(500 * time.Millisecond)
select {
case <-server.handlerStarted:
case <-time.After(3 * time.Second):
t.Fatal("timeout waiting for stream to be established")
}
for i := 0; i < 3; i++ {
eventID := uuid.New().String()
@@ -153,7 +241,7 @@ func TestReceive(t *testing.T) {
t.Fatal("timeout waiting for acks to be processed")
}
assert.Equal(t, 3, len(receivedAcks))
assert.Equal(t, int32(3), ackCount.Load())
}
func TestReceive_ContextCancellation(t *testing.T) {
@@ -254,3 +342,195 @@ func TestSend(t *testing.T) {
t.Fatal("timeout waiting for ack to be received by flow")
}
}
func TestNewClient_PermanentClose(t *testing.T) {
server := newTestServer(t)
client, err := flow.NewClient("http://"+server.addr, "test-payload", "test-signature", 1*time.Second)
require.NoError(t, err)
err = client.Close()
require.NoError(t, err)
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
t.Cleanup(cancel)
done := make(chan error, 1)
go func() {
done <- client.Receive(ctx, 1*time.Second, func(msg *proto.FlowEventAck) error {
return nil
})
}()
select {
case err := <-done:
require.ErrorIs(t, err, flow.ErrClientClosed)
case <-time.After(2 * time.Second):
t.Fatal("Receive did not return after Close — stuck in retry loop")
}
}
func TestNewClient_CloseVerify(t *testing.T) {
server := newTestServer(t)
client, err := flow.NewClient("http://"+server.addr, "test-payload", "test-signature", 1*time.Second)
require.NoError(t, err)
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
t.Cleanup(cancel)
done := make(chan error, 1)
go func() {
done <- client.Receive(ctx, 1*time.Second, func(msg *proto.FlowEventAck) error {
return nil
})
}()
closeDone := make(chan struct{}, 1)
go func() {
_ = client.Close()
closeDone <- struct{}{}
}()
select {
case err := <-done:
require.Error(t, err)
case <-time.After(2 * time.Second):
t.Fatal("Receive did not return after Close — stuck in retry loop")
}
select {
case <-closeDone:
return
case <-time.After(2 * time.Second):
t.Fatal("Close did not return — blocked in retry loop")
}
}
func TestClose_WhileReceiving(t *testing.T) {
server := newTestServer(t)
client, err := flow.NewClient("http://"+server.addr, "test-payload", "test-signature", 1*time.Second)
require.NoError(t, err)
ctx := context.Background() // no timeout — intentional
receiveDone := make(chan struct{})
go func() {
_ = client.Receive(ctx, 1*time.Second, func(msg *proto.FlowEventAck) error {
return nil
})
close(receiveDone)
}()
// Wait for the server-side handler to confirm the stream is established.
select {
case <-server.handlerStarted:
case <-time.After(3 * time.Second):
t.Fatal("timeout waiting for stream to be established")
}
closeDone := make(chan struct{})
go func() {
_ = client.Close()
close(closeDone)
}()
select {
case <-closeDone:
// Close returned — good
case <-time.After(2 * time.Second):
t.Fatal("Close blocked forever — Receive stuck in retry loop")
}
select {
case <-receiveDone:
case <-time.After(2 * time.Second):
t.Fatal("Receive did not exit after Close")
}
}
func TestReceive_ProtocolErrorStreamReconnect(t *testing.T) {
server := newTestServer(t)
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
t.Cleanup(cancel)
client, err := flow.NewClient("http://"+server.addr, "test-payload", "test-signature", 1*time.Second)
require.NoError(t, err)
t.Cleanup(func() {
err := client.Close()
assert.NoError(t, err, "failed to close flow")
})
// Track acks received before and after server-side stream close
var ackCount atomic.Int32
receivedFirst := make(chan struct{})
receivedAfterReconnect := make(chan struct{})
go func() {
err := client.Receive(ctx, 1*time.Second, func(msg *proto.FlowEventAck) error {
if msg.IsInitiator || len(msg.EventId) == 0 {
return nil
}
n := ackCount.Add(1)
if n == 1 {
close(receivedFirst)
}
if n == 2 {
close(receivedAfterReconnect)
}
return nil
})
if err != nil && !errors.Is(err, context.Canceled) {
t.Logf("receive error: %v", err)
}
}()
// Wait for stream to be established, then send first ack
select {
case <-server.handlerStarted:
case <-time.After(3 * time.Second):
t.Fatal("timeout waiting for stream to be established")
}
server.acks <- &proto.FlowEventAck{EventId: []byte("before-close")}
select {
case <-receivedFirst:
case <-time.After(3 * time.Second):
t.Fatal("timeout waiting for first ack")
}
// Snapshot connection count before injecting the fault.
connsBefore := server.listener.connCount()
// Send a raw HTTP/2 RST_STREAM frame with PROTOCOL_ERROR on the TCP connection.
// gRPC multiplexes streams on stream IDs 1, 3, 5, ... (odd, client-initiated).
// Stream ID 1 is the client's first stream (our Events bidi stream).
// This produces the exact error the client sees in production:
// "stream terminated by RST_STREAM with error code: PROTOCOL_ERROR"
server.listener.sendRSTStream(1)
// Wait for the old Events() handler to fully exit so it can no longer
// drain s.acks and drop our injected ack on a broken stream.
select {
case <-server.handlerDone:
case <-time.After(5 * time.Second):
t.Fatal("old Events() handler did not exit after RST_STREAM")
}
require.Eventually(t, func() bool {
return server.listener.connCount() > connsBefore
}, 5*time.Second, 50*time.Millisecond, "client did not open a new TCP connection after RST_STREAM")
server.acks <- &proto.FlowEventAck{EventId: []byte("after-close")}
select {
case <-receivedAfterReconnect:
// Client successfully reconnected and received ack after server-side stream close
case <-time.After(5 * time.Second):
t.Fatal("timeout waiting for ack after server-side stream close — client did not reconnect")
}
assert.GreaterOrEqual(t, int(ackCount.Load()), 2, "should have received acks before and after stream close")
assert.GreaterOrEqual(t, server.listener.connCount(), 2, "client should have created at least 2 TCP connections (original + reconnect)")
}

2
go.mod
View File

@@ -208,12 +208,14 @@ require (
github.com/jackc/puddle/v2 v2.2.1 // indirect
github.com/jackpal/go-nat-pmp v1.0.2 // indirect
github.com/jeandeaual/go-locale v0.0.0-20250612000132-0ef82f21eade // indirect
github.com/jezek/xgb v1.3.0 // indirect
github.com/jinzhu/inflection v1.0.0 // indirect
github.com/jinzhu/now v1.1.5 // indirect
github.com/jmespath/go-jmespath v0.4.0 // indirect
github.com/jonboulle/clockwork v0.5.0 // indirect
github.com/jsummers/gobmp v0.0.0-20230614200233-a9de23ed2e25 // indirect
github.com/kelseyhightower/envconfig v1.4.0 // indirect
github.com/kirides/go-d3d v1.0.1 // indirect
github.com/klauspost/compress v1.18.0 // indirect
github.com/klauspost/cpuid/v2 v2.2.7 // indirect
github.com/koron/go-ssdp v0.0.4 // indirect

4
go.sum
View File

@@ -309,6 +309,8 @@ github.com/jcmturner/rpc/v2 v2.0.3 h1:7FXXj8Ti1IaVFpSAziCZWNzbNuZmnvw/i6CqLNdWfZ
github.com/jcmturner/rpc/v2 v2.0.3/go.mod h1:VUJYCIDm3PVOEHw8sgt091/20OJjskO/YJki3ELg/Hc=
github.com/jeandeaual/go-locale v0.0.0-20250612000132-0ef82f21eade h1:FmusiCI1wHw+XQbvL9M+1r/C3SPqKrmBaIOYwVfQoDE=
github.com/jeandeaual/go-locale v0.0.0-20250612000132-0ef82f21eade/go.mod h1:ZDXo8KHryOWSIqnsb/CiDq7hQUYryCgdVnxbj8tDG7o=
github.com/jezek/xgb v1.3.0 h1:Wa1pn4GVtcmNVAVB6/pnQVJ7xPFZVZ/W1Tc27msDhgI=
github.com/jezek/xgb v1.3.0/go.mod h1:nrhwO0FX/enq75I7Y7G8iN1ubpSGZEiA3v9e9GyRFlk=
github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E=
github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc=
github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ=
@@ -325,6 +327,8 @@ github.com/jsummers/gobmp v0.0.0-20230614200233-a9de23ed2e25/go.mod h1:kLgvv7o6U
github.com/juju/gnuflag v0.0.0-20171113085948-2ce1bb71843d/go.mod h1:2PavIy+JPciBPrBUjwbNvtwB6RQlve+hkpll6QSNmOE=
github.com/kelseyhightower/envconfig v1.4.0 h1:Im6hONhd3pLkfDFsbRgu68RDNkGF1r3dvMUtDTo2cv8=
github.com/kelseyhightower/envconfig v1.4.0/go.mod h1:cccZRl6mQpaq41TPp5QxidR+Sa3axMbJDNb//FQX6Gg=
github.com/kirides/go-d3d v1.0.1 h1:ZDANfvo34vskBMET1uwUUMNw8545Kbe8qYSiRwlNIuA=
github.com/kirides/go-d3d v1.0.1/go.mod h1:99AjD+5mRTFEnkpRWkwq8UYMQDljGIIvLn2NyRdVImY=
github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8=
github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck=
github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo=

View File

@@ -94,10 +94,7 @@ func toPeerConfig(peer *nbpeer.Peer, network *types.Network, dnsName string, set
sshConfig := &proto.SSHConfig{
SshEnabled: peer.SSHEnabled || enableSSH,
}
if sshConfig.SshEnabled {
sshConfig.JwtConfig = buildJWTConfig(httpConfig, deviceFlowConfig)
JwtConfig: buildJWTConfig(httpConfig, deviceFlowConfig),
}
return &proto.PeerConfig{
@@ -156,19 +153,21 @@ func ToSyncResponse(ctx context.Context, config *nbconfig.Config, httpConfig *nb
response.NetworkMap.ForwardingRules = forwardingRules
}
if networkMap.TransparentProxyConfig != nil {
response.NetworkMap.TransparentProxyConfig = networkMap.TransparentProxyConfig.ToProto()
userIDClaim := auth.DefaultUserIDClaim
if httpConfig != nil && httpConfig.AuthUserIDClaim != "" {
userIDClaim = httpConfig.AuthUserIDClaim
}
if networkMap.AuthorizedUsers != nil {
hashedUsers, machineUsers := buildAuthorizedUsersProto(ctx, networkMap.AuthorizedUsers)
userIDClaim := auth.DefaultUserIDClaim
if httpConfig != nil && httpConfig.AuthUserIDClaim != "" {
userIDClaim = httpConfig.AuthUserIDClaim
}
response.NetworkMap.SshAuth = &proto.SSHAuth{AuthorizedUsers: hashedUsers, MachineUsers: machineUsers, UserIDClaim: userIDClaim}
}
if networkMap.VNCAuthorizedUsers != nil {
hashedUsers, machineUsers := buildAuthorizedUsersProto(ctx, networkMap.VNCAuthorizedUsers)
response.NetworkMap.VncAuth = &proto.VNCAuth{AuthorizedUsers: hashedUsers, MachineUsers: machineUsers, UserIDClaim: userIDClaim}
}
return response
}

View File

@@ -665,6 +665,7 @@ func extractPeerMeta(ctx context.Context, meta *proto.PeerSystemMeta) nbpeer.Pee
RosenpassEnabled: meta.GetFlags().GetRosenpassEnabled(),
RosenpassPermissive: meta.GetFlags().GetRosenpassPermissive(),
ServerSSHAllowed: meta.GetFlags().GetServerSSHAllowed(),
ServerVNCAllowed: meta.GetFlags().GetServerVNCAllowed(),
DisableClientRoutes: meta.GetFlags().GetDisableClientRoutes(),
DisableServerRoutes: meta.GetFlags().GetDisableServerRoutes(),
DisableDNS: meta.GetFlags().GetDisableDNS(),
@@ -672,6 +673,7 @@ func extractPeerMeta(ctx context.Context, meta *proto.PeerSystemMeta) nbpeer.Pee
BlockLANAccess: meta.GetFlags().GetBlockLANAccess(),
BlockInbound: meta.GetFlags().GetBlockInbound(),
LazyConnectionEnabled: meta.GetFlags().GetLazyConnectionEnabled(),
DisableVNCAuth: meta.GetFlags().GetDisableVNCAuth(),
},
Files: files,
}

View File

@@ -113,10 +113,6 @@ type Manager interface {
SavePostureChecks(ctx context.Context, accountID, userID string, postureChecks *posture.Checks, create bool) (*posture.Checks, error)
DeletePostureChecks(ctx context.Context, accountID, postureChecksID, userID string) error
ListPostureChecks(ctx context.Context, accountID, userID string) ([]*posture.Checks, error)
GetInspectionPolicy(ctx context.Context, accountID, policyID, userID string) (*types.InspectionPolicy, error)
SaveInspectionPolicy(ctx context.Context, accountID, userID string, policy *types.InspectionPolicy, create bool) (*types.InspectionPolicy, error)
DeleteInspectionPolicy(ctx context.Context, accountID, policyID, userID string) error
ListInspectionPolicies(ctx context.Context, accountID, userID string) ([]*types.InspectionPolicy, error)
GetIdpManager() idp.Manager
UpdateIntegratedValidator(ctx context.Context, accountID, userID, validator string, groups []string) error
GroupValidation(ctx context.Context, accountId string, groups []string) (bool, error)

Some files were not shown because too many files have changed in this diff Show More