mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-16 22:26:23 -04:00
Compare commits
3 Commits
transparen
...
vnc-server
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
dfe1bba287 | ||
|
|
13539543af | ||
|
|
7483fec048 |
2
.gitignore
vendored
2
.gitignore
vendored
@@ -33,5 +33,3 @@ infrastructure_files/setup-*.env
|
||||
vendor/
|
||||
/netbird
|
||||
client/netbird-electron/
|
||||
management/server/types/testdata/comparison/
|
||||
management/server/types/testdata/*.json
|
||||
|
||||
@@ -150,6 +150,7 @@ func init() {
|
||||
rootCmd.AddCommand(logoutCmd)
|
||||
rootCmd.AddCommand(versionCmd)
|
||||
rootCmd.AddCommand(sshCmd)
|
||||
rootCmd.AddCommand(vncCmd)
|
||||
rootCmd.AddCommand(networksCMD)
|
||||
rootCmd.AddCommand(forwardingRulesCmd)
|
||||
rootCmd.AddCommand(debugCmd)
|
||||
|
||||
@@ -36,7 +36,10 @@ const (
|
||||
enableSSHLocalPortForwardFlag = "enable-ssh-local-port-forwarding"
|
||||
enableSSHRemotePortForwardFlag = "enable-ssh-remote-port-forwarding"
|
||||
disableSSHAuthFlag = "disable-ssh-auth"
|
||||
sshJWTCacheTTLFlag = "ssh-jwt-cache-ttl"
|
||||
jwtCacheTTLFlag = "jwt-cache-ttl"
|
||||
|
||||
// Alias for backward compatibility.
|
||||
sshJWTCacheTTLFlag = "ssh-jwt-cache-ttl"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -61,7 +64,7 @@ var (
|
||||
enableSSHLocalPortForward bool
|
||||
enableSSHRemotePortForward bool
|
||||
disableSSHAuth bool
|
||||
sshJWTCacheTTL int
|
||||
jwtCacheTTL int
|
||||
)
|
||||
|
||||
func init() {
|
||||
@@ -71,7 +74,9 @@ func init() {
|
||||
upCmd.PersistentFlags().BoolVar(&enableSSHLocalPortForward, enableSSHLocalPortForwardFlag, false, "Enable local port forwarding for SSH server")
|
||||
upCmd.PersistentFlags().BoolVar(&enableSSHRemotePortForward, enableSSHRemotePortForwardFlag, false, "Enable remote port forwarding for SSH server")
|
||||
upCmd.PersistentFlags().BoolVar(&disableSSHAuth, disableSSHAuthFlag, false, "Disable SSH authentication")
|
||||
upCmd.PersistentFlags().IntVar(&sshJWTCacheTTL, sshJWTCacheTTLFlag, 0, "SSH JWT token cache TTL in seconds (0=disabled)")
|
||||
upCmd.PersistentFlags().IntVar(&jwtCacheTTL, jwtCacheTTLFlag, 0, "JWT token cache TTL in seconds (0=disabled)")
|
||||
upCmd.PersistentFlags().IntVar(&jwtCacheTTL, sshJWTCacheTTLFlag, 0, "JWT token cache TTL in seconds (alias for --jwt-cache-ttl)")
|
||||
_ = upCmd.PersistentFlags().MarkDeprecated(sshJWTCacheTTLFlag, "use --jwt-cache-ttl instead")
|
||||
|
||||
sshCmd.PersistentFlags().IntVarP(&port, "port", "p", sshserver.DefaultSSHPort, "Remote SSH port")
|
||||
sshCmd.PersistentFlags().StringVarP(&username, "user", "u", "", sshUsernameDesc)
|
||||
|
||||
@@ -356,6 +356,9 @@ func setupSetConfigReq(customDNSAddressConverted []byte, cmd *cobra.Command, pro
|
||||
if cmd.Flag(serverSSHAllowedFlag).Changed {
|
||||
req.ServerSSHAllowed = &serverSSHAllowed
|
||||
}
|
||||
if cmd.Flag(serverVNCAllowedFlag).Changed {
|
||||
req.ServerVNCAllowed = &serverVNCAllowed
|
||||
}
|
||||
if cmd.Flag(enableSSHRootFlag).Changed {
|
||||
req.EnableSSHRoot = &enableSSHRoot
|
||||
}
|
||||
@@ -371,9 +374,12 @@ func setupSetConfigReq(customDNSAddressConverted []byte, cmd *cobra.Command, pro
|
||||
if cmd.Flag(disableSSHAuthFlag).Changed {
|
||||
req.DisableSSHAuth = &disableSSHAuth
|
||||
}
|
||||
if cmd.Flag(sshJWTCacheTTLFlag).Changed {
|
||||
sshJWTCacheTTL32 := int32(sshJWTCacheTTL)
|
||||
req.SshJWTCacheTTL = &sshJWTCacheTTL32
|
||||
if cmd.Flag(disableVNCAuthFlag).Changed {
|
||||
req.DisableVNCAuth = &disableVNCAuth
|
||||
}
|
||||
if cmd.Flag(jwtCacheTTLFlag).Changed || cmd.Flag(sshJWTCacheTTLFlag).Changed {
|
||||
jwtCacheTTL32 := int32(jwtCacheTTL)
|
||||
req.SshJWTCacheTTL = &jwtCacheTTL32
|
||||
}
|
||||
if cmd.Flag(interfaceNameFlag).Changed {
|
||||
if err := parseInterfaceName(interfaceName); err != nil {
|
||||
@@ -458,6 +464,9 @@ func setupConfig(customDNSAddressConverted []byte, cmd *cobra.Command, configFil
|
||||
if cmd.Flag(serverSSHAllowedFlag).Changed {
|
||||
ic.ServerSSHAllowed = &serverSSHAllowed
|
||||
}
|
||||
if cmd.Flag(serverVNCAllowedFlag).Changed {
|
||||
ic.ServerVNCAllowed = &serverVNCAllowed
|
||||
}
|
||||
|
||||
if cmd.Flag(enableSSHRootFlag).Changed {
|
||||
ic.EnableSSHRoot = &enableSSHRoot
|
||||
@@ -479,8 +488,12 @@ func setupConfig(customDNSAddressConverted []byte, cmd *cobra.Command, configFil
|
||||
ic.DisableSSHAuth = &disableSSHAuth
|
||||
}
|
||||
|
||||
if cmd.Flag(sshJWTCacheTTLFlag).Changed {
|
||||
ic.SSHJWTCacheTTL = &sshJWTCacheTTL
|
||||
if cmd.Flag(disableVNCAuthFlag).Changed {
|
||||
ic.DisableVNCAuth = &disableVNCAuth
|
||||
}
|
||||
|
||||
if cmd.Flag(jwtCacheTTLFlag).Changed || cmd.Flag(sshJWTCacheTTLFlag).Changed {
|
||||
ic.SSHJWTCacheTTL = &jwtCacheTTL
|
||||
}
|
||||
|
||||
if cmd.Flag(interfaceNameFlag).Changed {
|
||||
@@ -582,6 +595,9 @@ func setupLoginRequest(providedSetupKey string, customDNSAddressConverted []byte
|
||||
if cmd.Flag(serverSSHAllowedFlag).Changed {
|
||||
loginRequest.ServerSSHAllowed = &serverSSHAllowed
|
||||
}
|
||||
if cmd.Flag(serverVNCAllowedFlag).Changed {
|
||||
loginRequest.ServerVNCAllowed = &serverVNCAllowed
|
||||
}
|
||||
|
||||
if cmd.Flag(enableSSHRootFlag).Changed {
|
||||
loginRequest.EnableSSHRoot = &enableSSHRoot
|
||||
@@ -603,9 +619,13 @@ func setupLoginRequest(providedSetupKey string, customDNSAddressConverted []byte
|
||||
loginRequest.DisableSSHAuth = &disableSSHAuth
|
||||
}
|
||||
|
||||
if cmd.Flag(sshJWTCacheTTLFlag).Changed {
|
||||
sshJWTCacheTTL32 := int32(sshJWTCacheTTL)
|
||||
loginRequest.SshJWTCacheTTL = &sshJWTCacheTTL32
|
||||
if cmd.Flag(disableVNCAuthFlag).Changed {
|
||||
loginRequest.DisableVNCAuth = &disableVNCAuth
|
||||
}
|
||||
|
||||
if cmd.Flag(jwtCacheTTLFlag).Changed || cmd.Flag(sshJWTCacheTTLFlag).Changed {
|
||||
jwtCacheTTL32 := int32(jwtCacheTTL)
|
||||
loginRequest.SshJWTCacheTTL = &jwtCacheTTL32
|
||||
}
|
||||
|
||||
if cmd.Flag(disableAutoConnectFlag).Changed {
|
||||
|
||||
271
client/cmd/vnc.go
Normal file
271
client/cmd/vnc.go
Normal file
@@ -0,0 +1,271 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"os"
|
||||
"os/signal"
|
||||
"os/user"
|
||||
"strings"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/spf13/cobra"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/credentials/insecure"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
nbssh "github.com/netbirdio/netbird/client/ssh"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
)
|
||||
|
||||
var (
|
||||
vncUsername string
|
||||
vncHost string
|
||||
vncMode string
|
||||
vncListen string
|
||||
vncNoBrowser bool
|
||||
vncNoCache bool
|
||||
)
|
||||
|
||||
func init() {
|
||||
vncCmd.PersistentFlags().StringVar(&vncUsername, "user", "", "OS username for session mode")
|
||||
vncCmd.PersistentFlags().StringVar(&vncMode, "mode", "attach", "Connection mode: attach (view current display) or session (virtual desktop)")
|
||||
vncCmd.PersistentFlags().StringVar(&vncListen, "listen", "", "Start local VNC proxy on this address (e.g., :5900) for external VNC viewers")
|
||||
vncCmd.PersistentFlags().BoolVar(&vncNoBrowser, noBrowserFlag, false, noBrowserDesc)
|
||||
vncCmd.PersistentFlags().BoolVar(&vncNoCache, "no-cache", false, "Skip cached JWT token and force fresh authentication")
|
||||
}
|
||||
|
||||
var vncCmd = &cobra.Command{
|
||||
Use: "vnc [flags] [user@]host",
|
||||
Short: "Connect to a NetBird peer via VNC",
|
||||
Long: `Connect to a NetBird peer using VNC with JWT-based authentication.
|
||||
The target peer must have the VNC server enabled.
|
||||
|
||||
Two modes are available:
|
||||
- attach: view the current physical display (remote support)
|
||||
- session: start a virtual desktop as the specified user (passwordless login)
|
||||
|
||||
Use --listen to start a local proxy for external VNC viewers:
|
||||
netbird vnc --listen :5900 peer-hostname
|
||||
vncviewer localhost:5900
|
||||
|
||||
Examples:
|
||||
netbird vnc peer-hostname
|
||||
netbird vnc --mode session --user alice peer-hostname
|
||||
netbird vnc --listen :5900 peer-hostname`,
|
||||
Args: cobra.MinimumNArgs(1),
|
||||
RunE: vncFn,
|
||||
}
|
||||
|
||||
func vncFn(cmd *cobra.Command, args []string) error {
|
||||
SetFlagsFromEnvVars(rootCmd)
|
||||
SetFlagsFromEnvVars(cmd)
|
||||
cmd.SetOut(cmd.OutOrStdout())
|
||||
|
||||
logOutput := "console"
|
||||
if firstLogFile := util.FindFirstLogPath(logFiles); firstLogFile != "" && firstLogFile != defaultLogFile {
|
||||
logOutput = firstLogFile
|
||||
}
|
||||
if err := util.InitLog(logLevel, logOutput); err != nil {
|
||||
return fmt.Errorf("init log: %w", err)
|
||||
}
|
||||
|
||||
if err := parseVNCHostArg(args[0]); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
ctx := internal.CtxInitState(cmd.Context())
|
||||
sig := make(chan os.Signal, 1)
|
||||
signal.Notify(sig, syscall.SIGTERM, syscall.SIGINT)
|
||||
vncCtx, cancel := context.WithCancel(ctx)
|
||||
|
||||
errCh := make(chan error, 1)
|
||||
go func() {
|
||||
if err := runVNC(vncCtx, cmd); err != nil {
|
||||
errCh <- err
|
||||
}
|
||||
cancel()
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-sig:
|
||||
cancel()
|
||||
<-vncCtx.Done()
|
||||
return nil
|
||||
case err := <-errCh:
|
||||
return err
|
||||
case <-vncCtx.Done():
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func parseVNCHostArg(arg string) error {
|
||||
if strings.Contains(arg, "@") {
|
||||
parts := strings.SplitN(arg, "@", 2)
|
||||
if len(parts) != 2 || parts[0] == "" || parts[1] == "" {
|
||||
return fmt.Errorf("invalid user@host format")
|
||||
}
|
||||
if vncUsername == "" {
|
||||
vncUsername = parts[0]
|
||||
}
|
||||
vncHost = parts[1]
|
||||
if vncMode == "attach" {
|
||||
vncMode = "session"
|
||||
}
|
||||
} else {
|
||||
vncHost = arg
|
||||
}
|
||||
|
||||
if vncMode == "session" && vncUsername == "" {
|
||||
if sudoUser := os.Getenv("SUDO_USER"); sudoUser != "" {
|
||||
vncUsername = sudoUser
|
||||
} else if currentUser, err := user.Current(); err == nil {
|
||||
vncUsername = currentUser.Username
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func runVNC(ctx context.Context, cmd *cobra.Command) error {
|
||||
grpcAddr := strings.TrimPrefix(daemonAddr, "tcp://")
|
||||
grpcConn, err := grpc.NewClient(grpcAddr, grpc.WithTransportCredentials(insecure.NewCredentials()))
|
||||
if err != nil {
|
||||
return fmt.Errorf("connect to daemon: %w", err)
|
||||
}
|
||||
defer func() { _ = grpcConn.Close() }()
|
||||
|
||||
daemonClient := proto.NewDaemonServiceClient(grpcConn)
|
||||
|
||||
if vncMode == "session" {
|
||||
cmd.Printf("Connecting to %s@%s [session mode]...\n", vncUsername, vncHost)
|
||||
} else {
|
||||
cmd.Printf("Connecting to %s [attach mode]...\n", vncHost)
|
||||
}
|
||||
|
||||
// Obtain JWT token. If the daemon has no SSO configured, proceed without one
|
||||
// (the server will accept unauthenticated connections if --disable-vnc-auth is set).
|
||||
var jwtToken string
|
||||
hint := profilemanager.GetLoginHint()
|
||||
var browserOpener func(string) error
|
||||
if !vncNoBrowser {
|
||||
browserOpener = util.OpenBrowser
|
||||
}
|
||||
|
||||
token, err := nbssh.RequestJWTToken(ctx, daemonClient, nil, cmd.ErrOrStderr(), !vncNoCache, hint, browserOpener)
|
||||
if err != nil {
|
||||
log.Debugf("JWT authentication unavailable, connecting without token: %v", err)
|
||||
} else {
|
||||
jwtToken = token
|
||||
log.Debug("JWT authentication successful")
|
||||
}
|
||||
|
||||
// Connect to the VNC server on the standard port (5900). The peer's firewall
|
||||
// DNATs 5900 -> 25900 (internal), so both ports work on the overlay network.
|
||||
vncAddr := net.JoinHostPort(vncHost, "5900")
|
||||
vncConn, err := net.DialTimeout("tcp", vncAddr, vncDialTimeout)
|
||||
if err != nil {
|
||||
return fmt.Errorf("connect to VNC at %s: %w", vncAddr, err)
|
||||
}
|
||||
defer vncConn.Close()
|
||||
|
||||
// Send session header with mode, username, and JWT.
|
||||
if err := sendVNCHeader(vncConn, vncMode, vncUsername, jwtToken); err != nil {
|
||||
return fmt.Errorf("send VNC header: %w", err)
|
||||
}
|
||||
|
||||
cmd.Printf("VNC connected to %s\n", vncHost)
|
||||
|
||||
if vncListen != "" {
|
||||
return runVNCLocalProxy(ctx, cmd, vncConn)
|
||||
}
|
||||
|
||||
// No --listen flag: inform the user they need to use --listen for external viewers.
|
||||
cmd.Printf("VNC tunnel established. Use --listen :5900 to proxy for local VNC viewers.\n")
|
||||
cmd.Printf("Press Ctrl+C to disconnect.\n")
|
||||
<-ctx.Done()
|
||||
return nil
|
||||
}
|
||||
|
||||
const vncDialTimeout = 15 * time.Second
|
||||
|
||||
// sendVNCHeader writes the NetBird VNC session header.
|
||||
func sendVNCHeader(conn net.Conn, mode, username, jwt string) error {
|
||||
var modeByte byte
|
||||
if mode == "session" {
|
||||
modeByte = 1
|
||||
}
|
||||
|
||||
usernameBytes := []byte(username)
|
||||
jwtBytes := []byte(jwt)
|
||||
hdr := make([]byte, 3+len(usernameBytes)+2+len(jwtBytes))
|
||||
hdr[0] = modeByte
|
||||
binary.BigEndian.PutUint16(hdr[1:3], uint16(len(usernameBytes)))
|
||||
off := 3
|
||||
copy(hdr[off:], usernameBytes)
|
||||
off += len(usernameBytes)
|
||||
binary.BigEndian.PutUint16(hdr[off:off+2], uint16(len(jwtBytes)))
|
||||
off += 2
|
||||
copy(hdr[off:], jwtBytes)
|
||||
|
||||
_, err := conn.Write(hdr)
|
||||
return err
|
||||
}
|
||||
|
||||
// runVNCLocalProxy listens on the given address and proxies incoming
|
||||
// connections to the already-established VNC tunnel.
|
||||
func runVNCLocalProxy(ctx context.Context, cmd *cobra.Command, vncConn net.Conn) error {
|
||||
listener, err := net.Listen("tcp", vncListen)
|
||||
if err != nil {
|
||||
return fmt.Errorf("listen on %s: %w", vncListen, err)
|
||||
}
|
||||
defer listener.Close()
|
||||
|
||||
cmd.Printf("VNC proxy listening on %s - connect with your VNC viewer\n", listener.Addr())
|
||||
cmd.Printf("Press Ctrl+C to stop.\n")
|
||||
|
||||
go func() {
|
||||
<-ctx.Done()
|
||||
listener.Close()
|
||||
}()
|
||||
|
||||
// Accept a single viewer connection. VNC is single-session: the RFB
|
||||
// handshake completes on vncConn for the first viewer, so subsequent
|
||||
// viewers would get a mid-stream connection. The loop handles transient
|
||||
// accept errors until a valid connection arrives.
|
||||
for {
|
||||
clientConn, err := listener.Accept()
|
||||
if err != nil {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil
|
||||
default:
|
||||
}
|
||||
log.Debugf("accept VNC proxy client: %v", err)
|
||||
continue
|
||||
}
|
||||
|
||||
cmd.Printf("VNC viewer connected from %s\n", clientConn.RemoteAddr())
|
||||
|
||||
// Bidirectional copy.
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
io.Copy(vncConn, clientConn)
|
||||
close(done)
|
||||
}()
|
||||
io.Copy(clientConn, vncConn)
|
||||
<-done
|
||||
clientConn.Close()
|
||||
|
||||
cmd.Printf("VNC viewer disconnected\n")
|
||||
return nil
|
||||
}
|
||||
}
|
||||
62
client/cmd/vnc_agent.go
Normal file
62
client/cmd/vnc_agent.go
Normal file
@@ -0,0 +1,62 @@
|
||||
//go:build windows
|
||||
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"os"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/spf13/cobra"
|
||||
|
||||
vncserver "github.com/netbirdio/netbird/client/vnc/server"
|
||||
)
|
||||
|
||||
var vncAgentPort string
|
||||
|
||||
func init() {
|
||||
vncAgentCmd.Flags().StringVar(&vncAgentPort, "port", "15900", "Port for the VNC agent to listen on")
|
||||
rootCmd.AddCommand(vncAgentCmd)
|
||||
}
|
||||
|
||||
// vncAgentCmd runs a VNC server in the current user session, listening on
|
||||
// localhost. It is spawned by the NetBird service (Session 0) via
|
||||
// CreateProcessAsUser into the interactive console session.
|
||||
var vncAgentCmd = &cobra.Command{
|
||||
Use: "vnc-agent",
|
||||
Short: "Run VNC capture agent (internal, spawned by service)",
|
||||
Hidden: true,
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
// Agent's stderr is piped to the service which relogs it.
|
||||
// Use JSON format with caller info for structured parsing.
|
||||
log.SetReportCaller(true)
|
||||
log.SetFormatter(&log.JSONFormatter{})
|
||||
log.SetOutput(os.Stderr)
|
||||
|
||||
sessionID := vncserver.GetCurrentSessionID()
|
||||
log.Infof("VNC agent starting on 127.0.0.1:%s (session %d)", vncAgentPort, sessionID)
|
||||
|
||||
capturer := vncserver.NewDesktopCapturer()
|
||||
injector := vncserver.NewWindowsInputInjector()
|
||||
srv := vncserver.New(capturer, injector, "")
|
||||
// Auth is handled by the service. The agent verifies a token on each
|
||||
// connection to ensure only the service process can connect.
|
||||
// The token is passed via environment variable to avoid exposing it
|
||||
// in the process command line (visible via tasklist/wmic).
|
||||
srv.SetDisableAuth(true)
|
||||
srv.SetAgentToken(os.Getenv("NB_VNC_AGENT_TOKEN"))
|
||||
|
||||
port, err := netip.ParseAddrPort("127.0.0.1:" + vncAgentPort)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
loopback := netip.PrefixFrom(netip.AddrFrom4([4]byte{127, 0, 0, 0}), 8)
|
||||
if err := srv.Start(cmd.Context(), port, loopback); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
<-cmd.Context().Done()
|
||||
return srv.Stop()
|
||||
},
|
||||
}
|
||||
16
client/cmd/vnc_flags.go
Normal file
16
client/cmd/vnc_flags.go
Normal file
@@ -0,0 +1,16 @@
|
||||
package cmd
|
||||
|
||||
const (
|
||||
serverVNCAllowedFlag = "allow-server-vnc"
|
||||
disableVNCAuthFlag = "disable-vnc-auth"
|
||||
)
|
||||
|
||||
var (
|
||||
serverVNCAllowed bool
|
||||
disableVNCAuth bool
|
||||
)
|
||||
|
||||
func init() {
|
||||
upCmd.PersistentFlags().BoolVar(&serverVNCAllowed, serverVNCAllowedFlag, false, "Allow embedded VNC server on peer")
|
||||
upCmd.PersistentFlags().BoolVar(&disableVNCAuth, disableVNCAuthFlag, false, "Disable JWT authentication for VNC")
|
||||
}
|
||||
@@ -364,28 +364,6 @@ func (m *Manager) SetupEBPFProxyNoTrack(proxyPort, wgPort uint16) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// AddTProxyRule adds TPROXY redirect rules for the transparent proxy.
|
||||
func (m *Manager) AddTProxyRule(ruleID string, sources []netip.Prefix, dstPorts []uint16, redirectPort uint16) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
return m.router.AddTProxyRule(ruleID, sources, dstPorts, redirectPort)
|
||||
}
|
||||
|
||||
// RemoveTProxyRule removes TPROXY redirect rules by ID.
|
||||
func (m *Manager) RemoveTProxyRule(ruleID string) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
return m.router.RemoveTProxyRule(ruleID)
|
||||
}
|
||||
|
||||
// AddUDPInspectionHook is a no-op for iptables (kernel-mode firewall has no userspace packet hooks).
|
||||
func (m *Manager) AddUDPInspectionHook(_ uint16, _ func([]byte) bool) string { return "" }
|
||||
|
||||
// RemoveUDPInspectionHook is a no-op for iptables.
|
||||
func (m *Manager) RemoveUDPInspectionHook(_ string) {}
|
||||
|
||||
func (m *Manager) initNoTrackChain() error {
|
||||
if err := m.cleanupNoTrackChain(); err != nil {
|
||||
log.Debugf("cleanup notrack chain: %v", err)
|
||||
|
||||
@@ -89,8 +89,6 @@ type router struct {
|
||||
|
||||
stateManager *statemanager.Manager
|
||||
ipFwdState *ipfwdstate.IPForwardingState
|
||||
|
||||
tproxyRules []tproxyRuleEntry
|
||||
}
|
||||
|
||||
func newRouter(iptablesClient *iptables.IPTables, wgIface iFaceMapper, mtu uint16) (*router, error) {
|
||||
@@ -1111,92 +1109,3 @@ func (r *router) addPrefixToIPSet(name string, prefix netip.Prefix) error {
|
||||
func (r *router) destroyIPSet(name string) error {
|
||||
return ipset.Destroy(name)
|
||||
}
|
||||
|
||||
// AddTProxyRule adds iptables nat PREROUTING REDIRECT rules for transparent proxy interception.
|
||||
// Traffic from sources on dstPorts arriving on the WG interface is redirected
|
||||
// to the transparent proxy listener on redirectPort.
|
||||
func (r *router) AddTProxyRule(ruleID string, sources []netip.Prefix, dstPorts []uint16, redirectPort uint16) error {
|
||||
portStr := fmt.Sprintf("%d", redirectPort)
|
||||
|
||||
for _, proto := range []string{"tcp", "udp"} {
|
||||
srcSpecs := r.buildSourceSpecs(sources)
|
||||
|
||||
for _, srcSpec := range srcSpecs {
|
||||
if len(dstPorts) == 0 {
|
||||
rule := append(srcSpec,
|
||||
"-i", r.wgIface.Name(),
|
||||
"-p", proto,
|
||||
"-j", "REDIRECT",
|
||||
"--to-ports", portStr,
|
||||
)
|
||||
if err := r.iptablesClient.AppendUnique(tableNat, chainRTRDR, rule...); err != nil {
|
||||
return fmt.Errorf("add redirect rule %s/%s: %w", ruleID, proto, err)
|
||||
}
|
||||
r.tproxyRules = append(r.tproxyRules, tproxyRuleEntry{
|
||||
ruleID: ruleID,
|
||||
table: tableNat,
|
||||
chain: chainRTRDR,
|
||||
spec: rule,
|
||||
})
|
||||
} else {
|
||||
for _, port := range dstPorts {
|
||||
rule := append(srcSpec,
|
||||
"-i", r.wgIface.Name(),
|
||||
"-p", proto,
|
||||
"--dport", fmt.Sprintf("%d", port),
|
||||
"-j", "REDIRECT",
|
||||
"--to-ports", portStr,
|
||||
)
|
||||
if err := r.iptablesClient.AppendUnique(tableNat, chainRTRDR, rule...); err != nil {
|
||||
return fmt.Errorf("add redirect rule %s/%s/%d: %w", ruleID, proto, port, err)
|
||||
}
|
||||
r.tproxyRules = append(r.tproxyRules, tproxyRuleEntry{
|
||||
ruleID: ruleID,
|
||||
table: tableNat,
|
||||
chain: chainRTRDR,
|
||||
spec: rule,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// RemoveTProxyRule removes all iptables REDIRECT rules for the given ruleID.
|
||||
func (r *router) RemoveTProxyRule(ruleID string) error {
|
||||
var remaining []tproxyRuleEntry
|
||||
for _, entry := range r.tproxyRules {
|
||||
if entry.ruleID != ruleID {
|
||||
remaining = append(remaining, entry)
|
||||
continue
|
||||
}
|
||||
if err := r.iptablesClient.DeleteIfExists(entry.table, entry.chain, entry.spec...); err != nil {
|
||||
log.Debugf("remove tproxy rule %s: %v", ruleID, err)
|
||||
}
|
||||
}
|
||||
r.tproxyRules = remaining
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
type tproxyRuleEntry struct {
|
||||
ruleID string
|
||||
table string
|
||||
chain string
|
||||
spec []string
|
||||
}
|
||||
|
||||
func (r *router) buildSourceSpecs(sources []netip.Prefix) [][]string {
|
||||
if len(sources) == 0 {
|
||||
return [][]string{{}} // empty spec = match any source
|
||||
}
|
||||
|
||||
specs := make([][]string, 0, len(sources))
|
||||
for _, src := range sources {
|
||||
specs = append(specs, []string{"-s", src.String()})
|
||||
}
|
||||
return specs
|
||||
}
|
||||
|
||||
|
||||
@@ -180,22 +180,6 @@ type Manager interface {
|
||||
// SetupEBPFProxyNoTrack creates static notrack rules for eBPF proxy loopback traffic.
|
||||
// This prevents conntrack from interfering with WireGuard proxy communication.
|
||||
SetupEBPFProxyNoTrack(proxyPort, wgPort uint16) error
|
||||
|
||||
// AddTProxyRule adds TPROXY redirect rules for specific source CIDRs and destination ports.
|
||||
// Traffic from sources on dstPorts is redirected to the transparent proxy on redirectPort.
|
||||
// Empty dstPorts means redirect all ports.
|
||||
AddTProxyRule(ruleID string, sources []netip.Prefix, dstPorts []uint16, redirectPort uint16) error
|
||||
|
||||
// RemoveTProxyRule removes TPROXY redirect rules by ID.
|
||||
RemoveTProxyRule(ruleID string) error
|
||||
|
||||
// AddUDPInspectionHook registers a hook that inspects UDP packets before forwarding.
|
||||
// The hook receives the raw packet and returns true to drop it.
|
||||
// Used for QUIC SNI-based blocking. Returns a hook ID for removal.
|
||||
AddUDPInspectionHook(dstPort uint16, hook func(packet []byte) bool) string
|
||||
|
||||
// RemoveUDPInspectionHook removes a previously registered inspection hook.
|
||||
RemoveUDPInspectionHook(hookID string)
|
||||
}
|
||||
|
||||
func GenKey(format string, pair RouterPair) string {
|
||||
|
||||
@@ -482,28 +482,6 @@ func (m *Manager) SetupEBPFProxyNoTrack(proxyPort, wgPort uint16) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// AddTProxyRule adds TPROXY redirect rules for the transparent proxy.
|
||||
func (m *Manager) AddTProxyRule(ruleID string, sources []netip.Prefix, dstPorts []uint16, redirectPort uint16) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
return m.router.AddTProxyRule(ruleID, sources, dstPorts, redirectPort)
|
||||
}
|
||||
|
||||
// RemoveTProxyRule removes TPROXY redirect rules by ID.
|
||||
func (m *Manager) RemoveTProxyRule(ruleID string) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
return m.router.RemoveTProxyRule(ruleID)
|
||||
}
|
||||
|
||||
// AddUDPInspectionHook is a no-op for nftables (kernel-mode firewall has no userspace packet hooks).
|
||||
func (m *Manager) AddUDPInspectionHook(_ uint16, _ func([]byte) bool) string { return "" }
|
||||
|
||||
// RemoveUDPInspectionHook is a no-op for nftables.
|
||||
func (m *Manager) RemoveUDPInspectionHook(_ string) {}
|
||||
|
||||
func (m *Manager) initNoTrackChains(table *nftables.Table) error {
|
||||
m.notrackOutputChain = m.rConn.AddChain(&nftables.Chain{
|
||||
Name: chainNameRawOutput,
|
||||
|
||||
@@ -77,7 +77,6 @@ type router struct {
|
||||
ipFwdState *ipfwdstate.IPForwardingState
|
||||
legacyManagement bool
|
||||
mtu uint16
|
||||
|
||||
}
|
||||
|
||||
func newRouter(workTable *nftables.Table, wgIface iFaceMapper, mtu uint16) (*router, error) {
|
||||
@@ -2138,227 +2137,3 @@ func getIpSetExprs(ref refcounter.Ref[*nftables.Set], isSource bool) ([]expr.Any
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
// AddTProxyRule adds nftables TPROXY redirect rules in the mangle prerouting chain.
|
||||
// Traffic from sources on dstPorts arriving on the WG interface is redirected to
|
||||
// the transparent proxy listener on redirectPort.
|
||||
// Separate rules are created for TCP and UDP protocols.
|
||||
func (r *router) AddTProxyRule(ruleID string, sources []netip.Prefix, dstPorts []uint16, redirectPort uint16) error {
|
||||
if err := r.refreshRulesMap(); err != nil {
|
||||
return fmt.Errorf(refreshRulesMapError, err)
|
||||
}
|
||||
|
||||
// Use the nat redirect chain for DNAT rules.
|
||||
// TPROXY doesn't work on WG kernel interfaces (socket assignment silently fails),
|
||||
// so we use DNAT to 127.0.0.1:proxy_port instead. The proxy reads the original
|
||||
// destination via SO_ORIGINAL_DST (conntrack).
|
||||
chain := r.chains[chainNameRoutingRdr]
|
||||
if chain == nil {
|
||||
return fmt.Errorf("nat redirect chain not initialized")
|
||||
}
|
||||
|
||||
for _, proto := range []uint8{unix.IPPROTO_TCP, unix.IPPROTO_UDP} {
|
||||
protoName := "tcp"
|
||||
if proto == unix.IPPROTO_UDP {
|
||||
protoName = "udp"
|
||||
}
|
||||
|
||||
ruleKey := fmt.Sprintf("tproxy-%s-%s", ruleID, protoName)
|
||||
|
||||
if existing, ok := r.rules[ruleKey]; ok && existing.Handle != 0 {
|
||||
if err := r.decrementSetCounter(existing); err != nil {
|
||||
log.Debugf("decrement set counter for %s: %v", ruleKey, err)
|
||||
}
|
||||
if err := r.conn.DelRule(existing); err != nil {
|
||||
log.Debugf("remove existing tproxy rule %s: %v", ruleKey, err)
|
||||
}
|
||||
delete(r.rules, ruleKey)
|
||||
}
|
||||
|
||||
exprs, err := r.buildRedirectExprs(proto, sources, dstPorts, redirectPort)
|
||||
if err != nil {
|
||||
return fmt.Errorf("build redirect exprs for %s: %w", protoName, err)
|
||||
}
|
||||
|
||||
r.rules[ruleKey] = r.conn.AddRule(&nftables.Rule{
|
||||
Table: r.workTable,
|
||||
Chain: chain,
|
||||
Exprs: exprs,
|
||||
UserData: []byte(ruleKey),
|
||||
})
|
||||
}
|
||||
|
||||
// Accept redirected packets in the ACL input chain. After REDIRECT, the
|
||||
// destination port becomes the proxy port. Without this rule, the ACL filter
|
||||
// drops the packet. We match on ct state dnat so only REDIRECT'd connections
|
||||
// are accepted: direct connections to the proxy port are blocked.
|
||||
inputAcceptKey := fmt.Sprintf("tproxy-%s-input", ruleID)
|
||||
if _, ok := r.rules[inputAcceptKey]; !ok {
|
||||
inputChain := &nftables.Chain{
|
||||
Name: "netbird-acl-input-rules",
|
||||
Table: r.workTable,
|
||||
}
|
||||
r.rules[inputAcceptKey] = r.conn.InsertRule(&nftables.Rule{
|
||||
Table: r.workTable,
|
||||
Chain: inputChain,
|
||||
Exprs: []expr.Any{
|
||||
// Only accept connections that were REDIRECT'd (ct status dnat)
|
||||
&expr.Ct{Register: 1, Key: expr.CtKeySTATUS},
|
||||
&expr.Bitwise{
|
||||
SourceRegister: 1,
|
||||
DestRegister: 1,
|
||||
Len: 4,
|
||||
Mask: binaryutil.NativeEndian.PutUint32(0x20), // IPS_DST_NAT
|
||||
Xor: binaryutil.NativeEndian.PutUint32(0),
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpNeq,
|
||||
Register: 1,
|
||||
Data: binaryutil.NativeEndian.PutUint32(0),
|
||||
},
|
||||
// Accept both TCP and UDP redirected to the proxy port.
|
||||
&expr.Payload{
|
||||
DestRegister: 1,
|
||||
Base: expr.PayloadBaseTransportHeader,
|
||||
Offset: 2,
|
||||
Len: 2,
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: binaryutil.BigEndian.PutUint16(redirectPort),
|
||||
},
|
||||
&expr.Verdict{Kind: expr.VerdictAccept},
|
||||
},
|
||||
UserData: []byte(inputAcceptKey),
|
||||
})
|
||||
}
|
||||
|
||||
if err := r.conn.Flush(); err != nil {
|
||||
return fmt.Errorf("flush tproxy rules for %s: %w", ruleID, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// RemoveTProxyRule removes TPROXY redirect rules by ID (both TCP and UDP variants).
|
||||
func (r *router) RemoveTProxyRule(ruleID string) error {
|
||||
if err := r.refreshRulesMap(); err != nil {
|
||||
return fmt.Errorf(refreshRulesMapError, err)
|
||||
}
|
||||
|
||||
var removed int
|
||||
for _, suffix := range []string{"tcp", "udp", "input"} {
|
||||
ruleKey := fmt.Sprintf("tproxy-%s-%s", ruleID, suffix)
|
||||
|
||||
rule, ok := r.rules[ruleKey]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
if rule.Handle == 0 {
|
||||
delete(r.rules, ruleKey)
|
||||
continue
|
||||
}
|
||||
|
||||
if err := r.decrementSetCounter(rule); err != nil {
|
||||
log.Debugf("decrement set counter for %s: %v", ruleKey, err)
|
||||
}
|
||||
if err := r.conn.DelRule(rule); err != nil {
|
||||
log.Debugf("delete tproxy rule %s: %v", ruleKey, err)
|
||||
}
|
||||
delete(r.rules, ruleKey)
|
||||
removed++
|
||||
}
|
||||
|
||||
if removed > 0 {
|
||||
if err := r.conn.Flush(); err != nil {
|
||||
return fmt.Errorf("flush tproxy rule removal for %s: %w", ruleID, err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// buildRedirectExprs builds nftables expressions for a REDIRECT rule.
|
||||
// Matches WG interface ingress, source CIDRs, destination ports, then REDIRECTs to the proxy port.
|
||||
func (r *router) buildRedirectExprs(proto uint8, sources []netip.Prefix, dstPorts []uint16, redirectPort uint16) ([]expr.Any, error) {
|
||||
var exprs []expr.Any
|
||||
|
||||
exprs = append(exprs,
|
||||
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
|
||||
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: ifname(r.wgIface.Name())},
|
||||
)
|
||||
|
||||
exprs = append(exprs,
|
||||
&expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 1},
|
||||
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: []byte{proto}},
|
||||
)
|
||||
|
||||
// Source CIDRs use the named ipset shared with route rules.
|
||||
if len(sources) > 0 {
|
||||
srcSet := firewall.NewPrefixSet(sources)
|
||||
srcExprs, err := r.getIpSet(srcSet, sources, true)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get source ipset: %w", err)
|
||||
}
|
||||
exprs = append(exprs, srcExprs...)
|
||||
}
|
||||
|
||||
if len(dstPorts) == 1 {
|
||||
exprs = append(exprs,
|
||||
&expr.Payload{
|
||||
DestRegister: 1,
|
||||
Base: expr.PayloadBaseTransportHeader,
|
||||
Offset: 2,
|
||||
Len: 2,
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: binaryutil.BigEndian.PutUint16(dstPorts[0]),
|
||||
},
|
||||
)
|
||||
} else if len(dstPorts) > 1 {
|
||||
setElements := make([]nftables.SetElement, len(dstPorts))
|
||||
for i, p := range dstPorts {
|
||||
setElements[i] = nftables.SetElement{Key: binaryutil.BigEndian.PutUint16(p)}
|
||||
}
|
||||
portSet := &nftables.Set{
|
||||
Table: r.workTable,
|
||||
Anonymous: true,
|
||||
Constant: true,
|
||||
KeyType: nftables.TypeInetService,
|
||||
}
|
||||
if err := r.conn.AddSet(portSet, setElements); err != nil {
|
||||
return nil, fmt.Errorf("create port set: %w", err)
|
||||
}
|
||||
exprs = append(exprs,
|
||||
&expr.Payload{
|
||||
DestRegister: 1,
|
||||
Base: expr.PayloadBaseTransportHeader,
|
||||
Offset: 2,
|
||||
Len: 2,
|
||||
},
|
||||
&expr.Lookup{
|
||||
SourceRegister: 1,
|
||||
SetName: portSet.Name,
|
||||
SetID: portSet.ID,
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
// REDIRECT to local proxy port. Changes the destination to the interface's
|
||||
// primary address + specified port. Conntrack tracks the original destination,
|
||||
// readable via SO_ORIGINAL_DST.
|
||||
exprs = append(exprs,
|
||||
&expr.Immediate{Register: 1, Data: binaryutil.BigEndian.PutUint16(redirectPort)},
|
||||
&expr.Redir{
|
||||
RegisterProtoMin: 1,
|
||||
},
|
||||
)
|
||||
|
||||
return exprs, nil
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -641,45 +641,6 @@ func (m *Manager) SetupEBPFProxyNoTrack(proxyPort, wgPort uint16) error {
|
||||
return m.nativeFirewall.SetupEBPFProxyNoTrack(proxyPort, wgPort)
|
||||
}
|
||||
|
||||
// AddTProxyRule delegates to the native firewall for TPROXY rules.
|
||||
// In userspace mode (no native firewall), this is a no-op since the
|
||||
// forwarder intercepts traffic directly.
|
||||
func (m *Manager) AddTProxyRule(ruleID string, sources []netip.Prefix, dstPorts []uint16, redirectPort uint16) error {
|
||||
if m.nativeFirewall == nil {
|
||||
return nil
|
||||
}
|
||||
return m.nativeFirewall.AddTProxyRule(ruleID, sources, dstPorts, redirectPort)
|
||||
}
|
||||
|
||||
// AddUDPInspectionHook registers a hook for QUIC/UDP inspection via the packet filter.
|
||||
func (m *Manager) AddUDPInspectionHook(dstPort uint16, hook func(packet []byte) bool) string {
|
||||
m.SetUDPPacketHook(netip.Addr{}, dstPort, hook)
|
||||
return "udp-inspection"
|
||||
}
|
||||
|
||||
// RemoveUDPInspectionHook removes a previously registered inspection hook.
|
||||
func (m *Manager) RemoveUDPInspectionHook(_ string) {
|
||||
m.SetUDPPacketHook(netip.Addr{}, 0, nil)
|
||||
}
|
||||
|
||||
// RemoveTProxyRule delegates to the native firewall for TPROXY rules.
|
||||
func (m *Manager) RemoveTProxyRule(ruleID string) error {
|
||||
if m.nativeFirewall == nil {
|
||||
return nil
|
||||
}
|
||||
return m.nativeFirewall.RemoveTProxyRule(ruleID)
|
||||
}
|
||||
|
||||
// IsLocalIP reports whether the given IP belongs to the local machine.
|
||||
func (m *Manager) IsLocalIP(ip netip.Addr) bool {
|
||||
return m.localipmanager.IsLocalIP(ip)
|
||||
}
|
||||
|
||||
// GetForwarder returns the userspace packet forwarder, or nil if not initialized.
|
||||
func (m *Manager) GetForwarder() *forwarder.Forwarder {
|
||||
return m.forwarder.Load()
|
||||
}
|
||||
|
||||
// UpdateSet updates the rule destinations associated with the given set
|
||||
// by merging the existing prefixes with the new ones, then deduplicating.
|
||||
func (m *Manager) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
|
||||
|
||||
@@ -7,7 +7,6 @@ import (
|
||||
"net/netip"
|
||||
"runtime"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
@@ -22,7 +21,6 @@ import (
|
||||
|
||||
"github.com/netbirdio/netbird/client/firewall/uspfilter/common"
|
||||
"github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack"
|
||||
"github.com/netbirdio/netbird/client/inspect"
|
||||
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
|
||||
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
||||
)
|
||||
@@ -48,10 +46,6 @@ type Forwarder struct {
|
||||
netstack bool
|
||||
hasRawICMPAccess bool
|
||||
pingSemaphore chan struct{}
|
||||
// proxy is the optional inspection engine.
|
||||
// When set, TCP connections are handed to the engine for protocol detection
|
||||
// and rule evaluation. Swapped atomically for lock-free hot-path access.
|
||||
proxy atomic.Pointer[inspect.Proxy]
|
||||
}
|
||||
|
||||
func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.FlowLogger, netstack bool, mtu uint16) (*Forwarder, error) {
|
||||
@@ -85,7 +79,7 @@ func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.Flow
|
||||
}
|
||||
|
||||
if err := s.AddProtocolAddress(nicID, protoAddr, stack.AddressProperties{}); err != nil {
|
||||
return nil, fmt.Errorf("add protocol address: %s", err)
|
||||
return nil, fmt.Errorf("failed to add protocol address: %s", err)
|
||||
}
|
||||
|
||||
defaultSubnet, err := tcpip.NewSubnet(
|
||||
@@ -161,13 +155,6 @@ func (f *Forwarder) InjectIncomingPacket(payload []byte) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetProxy sets the inspection engine. When set, TCP connections are handed
|
||||
// to it for protocol detection and rule evaluation instead of direct relay.
|
||||
// Pass nil to disable inspection.
|
||||
func (f *Forwarder) SetProxy(p *inspect.Proxy) {
|
||||
f.proxy.Store(p)
|
||||
}
|
||||
|
||||
// Stop gracefully shuts down the forwarder
|
||||
func (f *Forwarder) Stop() {
|
||||
f.cancel()
|
||||
@@ -180,25 +167,6 @@ func (f *Forwarder) Stop() {
|
||||
f.stack.Wait()
|
||||
}
|
||||
|
||||
// CheckUDPPacket inspects a UDP payload against proxy rules before injection.
|
||||
// This is called by the filter for QUIC SNI-based blocking.
|
||||
// Returns true if the packet should be allowed, false if it should be dropped.
|
||||
func (f *Forwarder) CheckUDPPacket(payload []byte, srcIP, dstIP netip.Addr, srcPort, dstPort uint16, ruleID []byte) bool {
|
||||
p := f.proxy.Load()
|
||||
if p == nil {
|
||||
return true
|
||||
}
|
||||
|
||||
dst := netip.AddrPortFrom(dstIP, dstPort)
|
||||
src := inspect.SourceInfo{
|
||||
IP: srcIP,
|
||||
PolicyID: inspect.PolicyID(ruleID),
|
||||
}
|
||||
|
||||
action := p.HandleUDPPacket(payload, dst, src)
|
||||
return action != inspect.ActionBlock
|
||||
}
|
||||
|
||||
func (f *Forwarder) determineDialAddr(addr tcpip.Address) net.IP {
|
||||
if f.netstack && f.ip.Equal(addr) {
|
||||
return net.IPv4(127, 0, 0, 1)
|
||||
|
||||
@@ -16,7 +16,6 @@ import (
|
||||
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
|
||||
"gvisor.dev/gvisor/pkg/waiter"
|
||||
|
||||
"github.com/netbirdio/netbird/client/inspect"
|
||||
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
||||
)
|
||||
|
||||
@@ -24,86 +23,6 @@ import (
|
||||
func (f *Forwarder) handleTCP(r *tcp.ForwarderRequest) {
|
||||
id := r.ID()
|
||||
|
||||
// If the inspection engine is configured, accept the connection first and hand it off.
|
||||
if p := f.proxy.Load(); p != nil {
|
||||
f.handleTCPWithInspection(r, id, p)
|
||||
return
|
||||
}
|
||||
|
||||
f.handleTCPDirect(r, id)
|
||||
}
|
||||
|
||||
// handleTCPWithInspection accepts the connection and hands it to the inspection
|
||||
// engine. For allow decisions, the forwarder does its own relay (passthrough).
|
||||
// For block/inspect, the engine handles everything internally.
|
||||
func (f *Forwarder) handleTCPWithInspection(r *tcp.ForwarderRequest, id stack.TransportEndpointID, p *inspect.Proxy) {
|
||||
flowID := uuid.New()
|
||||
f.sendTCPEvent(nftypes.TypeStart, flowID, id, 0, 0, 0, 0)
|
||||
|
||||
wq := waiter.Queue{}
|
||||
ep, epErr := r.CreateEndpoint(&wq)
|
||||
if epErr != nil {
|
||||
f.logger.Error1("forwarder: create TCP endpoint for inspection: %v", epErr)
|
||||
r.Complete(true)
|
||||
f.sendTCPEvent(nftypes.TypeEnd, flowID, id, 0, 0, 0, 0)
|
||||
return
|
||||
}
|
||||
r.Complete(false)
|
||||
|
||||
inConn := gonet.NewTCPConn(&wq, ep)
|
||||
|
||||
srcIP := netip.AddrFrom4(id.RemoteAddress.As4())
|
||||
dstIP := netip.AddrFrom4(id.LocalAddress.As4())
|
||||
dst := netip.AddrPortFrom(dstIP, id.LocalPort)
|
||||
|
||||
var policyID []byte
|
||||
if ruleID, ok := f.getRuleID(srcIP, dstIP, id.RemotePort, id.LocalPort); ok {
|
||||
policyID = ruleID
|
||||
}
|
||||
|
||||
src := inspect.SourceInfo{
|
||||
IP: srcIP,
|
||||
PolicyID: inspect.PolicyID(policyID),
|
||||
}
|
||||
|
||||
f.logger.Trace1("forwarder: handing TCP %v to inspection engine", epID(id))
|
||||
|
||||
go func() {
|
||||
result, err := p.InspectTCP(f.ctx, inConn, dst, src)
|
||||
if err != nil && err != inspect.ErrBlocked {
|
||||
f.logger.Debug2("forwarder: inspection error for %v: %v", epID(id), err)
|
||||
}
|
||||
|
||||
// Passthrough: engine returned allow, forwarder does the relay.
|
||||
if result.PassthroughConn != nil {
|
||||
dialAddr := fmt.Sprintf("%s:%d", f.determineDialAddr(id.LocalAddress), id.LocalPort)
|
||||
outConn, dialErr := (&net.Dialer{}).DialContext(f.ctx, "tcp", dialAddr)
|
||||
if dialErr != nil {
|
||||
f.logger.Trace2("forwarder: passthrough dial error for %v: %v", epID(id), dialErr)
|
||||
if closeErr := result.PassthroughConn.Close(); closeErr != nil {
|
||||
f.logger.Debug1("forwarder: close passthrough conn: %v", closeErr)
|
||||
}
|
||||
ep.Close()
|
||||
f.sendTCPEvent(nftypes.TypeEnd, flowID, id, 0, 0, 0, 0)
|
||||
return
|
||||
}
|
||||
f.proxyTCPPassthrough(id, result.PassthroughConn, outConn, ep, flowID)
|
||||
return
|
||||
}
|
||||
|
||||
// Engine handled it (block/inspect/HTTP). Capture stats and clean up.
|
||||
var rxPackets, txPackets uint64
|
||||
if tcpStats, ok := ep.Stats().(*tcp.Stats); ok {
|
||||
rxPackets = tcpStats.SegmentsSent.Value()
|
||||
txPackets = tcpStats.SegmentsReceived.Value()
|
||||
}
|
||||
ep.Close()
|
||||
f.sendTCPEvent(nftypes.TypeEnd, flowID, id, 0, 0, rxPackets, txPackets)
|
||||
}()
|
||||
}
|
||||
|
||||
// handleTCPDirect handles TCP connections with direct relay (no proxy).
|
||||
func (f *Forwarder) handleTCPDirect(r *tcp.ForwarderRequest, id stack.TransportEndpointID) {
|
||||
flowID := uuid.New()
|
||||
|
||||
f.sendTCPEvent(nftypes.TypeStart, flowID, id, 0, 0, 0, 0)
|
||||
@@ -123,6 +42,7 @@ func (f *Forwarder) handleTCPDirect(r *tcp.ForwarderRequest, id stack.TransportE
|
||||
return
|
||||
}
|
||||
|
||||
// Create wait queue for blocking syscalls
|
||||
wq := waiter.Queue{}
|
||||
|
||||
ep, epErr := r.CreateEndpoint(&wq)
|
||||
@@ -135,6 +55,7 @@ func (f *Forwarder) handleTCPDirect(r *tcp.ForwarderRequest, id stack.TransportE
|
||||
return
|
||||
}
|
||||
|
||||
// Complete the handshake
|
||||
r.Complete(false)
|
||||
|
||||
inConn := gonet.NewTCPConn(&wq, ep)
|
||||
@@ -152,6 +73,7 @@ func (f *Forwarder) proxyTCP(id stack.TransportEndpointID, inConn *gonet.TCPConn
|
||||
|
||||
go func() {
|
||||
<-ctx.Done()
|
||||
// Close connections and endpoint.
|
||||
if err := inConn.Close(); err != nil && !isClosedError(err) {
|
||||
f.logger.Debug1("forwarder: inConn close error: %v", err)
|
||||
}
|
||||
@@ -210,66 +132,6 @@ func (f *Forwarder) proxyTCP(id stack.TransportEndpointID, inConn *gonet.TCPConn
|
||||
f.sendTCPEvent(nftypes.TypeEnd, flowID, id, uint64(bytesFromOutToIn), uint64(bytesFromInToOut), rxPackets, txPackets)
|
||||
}
|
||||
|
||||
// proxyTCPPassthrough relays traffic between a peeked inbound connection
|
||||
// (from the inspection engine passthrough) and the outbound connection.
|
||||
// It accepts net.Conn for inConn since the inspection engine wraps it in a peekConn.
|
||||
func (f *Forwarder) proxyTCPPassthrough(id stack.TransportEndpointID, inConn net.Conn, outConn net.Conn, ep tcpip.Endpoint, flowID uuid.UUID) {
|
||||
ctx, cancel := context.WithCancel(f.ctx)
|
||||
defer cancel()
|
||||
|
||||
go func() {
|
||||
<-ctx.Done()
|
||||
if err := inConn.Close(); err != nil && !isClosedError(err) {
|
||||
f.logger.Debug1("forwarder: passthrough inConn close: %v", err)
|
||||
}
|
||||
if err := outConn.Close(); err != nil && !isClosedError(err) {
|
||||
f.logger.Debug1("forwarder: passthrough outConn close: %v", err)
|
||||
}
|
||||
ep.Close()
|
||||
}()
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(2)
|
||||
|
||||
var (
|
||||
bytesIn int64
|
||||
bytesOut int64
|
||||
errIn error
|
||||
errOut error
|
||||
)
|
||||
|
||||
go func() {
|
||||
bytesIn, errIn = io.Copy(outConn, inConn)
|
||||
cancel()
|
||||
wg.Done()
|
||||
}()
|
||||
|
||||
go func() {
|
||||
bytesOut, errOut = io.Copy(inConn, outConn)
|
||||
cancel()
|
||||
wg.Done()
|
||||
}()
|
||||
|
||||
wg.Wait()
|
||||
|
||||
if errIn != nil && !isClosedError(errIn) {
|
||||
f.logger.Error2("proxyTCPPassthrough: copy error (in→out) for %s: %v", epID(id), errIn)
|
||||
}
|
||||
if errOut != nil && !isClosedError(errOut) {
|
||||
f.logger.Error2("proxyTCPPassthrough: copy error (out→in) for %s: %v", epID(id), errOut)
|
||||
}
|
||||
|
||||
var rxPackets, txPackets uint64
|
||||
if tcpStats, ok := ep.Stats().(*tcp.Stats); ok {
|
||||
rxPackets = tcpStats.SegmentsSent.Value()
|
||||
txPackets = tcpStats.SegmentsReceived.Value()
|
||||
}
|
||||
|
||||
f.logger.Trace5("forwarder: passthrough TCP %s [in: %d Pkts/%d B, out: %d Pkts/%d B]", epID(id), rxPackets, bytesOut, txPackets, bytesIn)
|
||||
|
||||
f.sendTCPEvent(nftypes.TypeEnd, flowID, id, uint64(bytesOut), uint64(bytesIn), rxPackets, txPackets)
|
||||
}
|
||||
|
||||
func (f *Forwarder) sendTCPEvent(typ nftypes.Type, flowID uuid.UUID, id stack.TransportEndpointID, rxBytes, txBytes, rxPackets, txPackets uint64) {
|
||||
srcIp := netip.AddrFrom4(id.RemoteAddress.As4())
|
||||
dstIp := netip.AddrFrom4(id.LocalAddress.As4())
|
||||
|
||||
@@ -1,212 +0,0 @@
|
||||
package inspect
|
||||
|
||||
import (
|
||||
"crypto"
|
||||
"crypto/x509"
|
||||
"net"
|
||||
"net/netip"
|
||||
"net/url"
|
||||
"strings"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/acl/id"
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
)
|
||||
|
||||
// InspectResult holds the outcome of connection inspection.
|
||||
type InspectResult struct {
|
||||
// Action is the rule evaluation result.
|
||||
Action Action
|
||||
// PassthroughConn is the client connection with buffered peeked bytes.
|
||||
// Non-nil only when Action is ActionAllow and the caller should relay
|
||||
// (TLS passthrough or non-HTTP/TLS protocol). The caller takes ownership
|
||||
// and is responsible for closing this connection.
|
||||
PassthroughConn net.Conn
|
||||
}
|
||||
|
||||
const (
|
||||
// DefaultTProxyPort is the default TPROXY listener port for kernel mode.
|
||||
// Override with NB_TPROXY_PORT environment variable.
|
||||
DefaultTProxyPort = 22080
|
||||
)
|
||||
|
||||
// Action determines how the proxy handles a matched connection.
|
||||
type Action string
|
||||
|
||||
const (
|
||||
// ActionAllow passes the connection through without decryption.
|
||||
ActionAllow Action = "allow"
|
||||
// ActionBlock denies the connection.
|
||||
ActionBlock Action = "block"
|
||||
// ActionInspect decrypts (MITM) and inspects the connection.
|
||||
ActionInspect Action = "inspect"
|
||||
)
|
||||
|
||||
// ProxyMode determines the proxy operating mode.
|
||||
type ProxyMode string
|
||||
|
||||
const (
|
||||
// ModeBuiltin uses the built-in proxy with rules and optional ICAP.
|
||||
ModeBuiltin ProxyMode = "builtin"
|
||||
// ModeEnvoy runs a local envoy sidecar for L7 processing.
|
||||
// Go manages envoy lifecycle, config generation, and rule evaluation.
|
||||
// USP path forwards via PROXY protocol v2; kernel path uses nftables redirect.
|
||||
ModeEnvoy ProxyMode = "envoy"
|
||||
// ModeExternal forwards all traffic to an external proxy.
|
||||
ModeExternal ProxyMode = "external"
|
||||
)
|
||||
|
||||
// PolicyID is the management policy identifier associated with a connection.
|
||||
type PolicyID []byte
|
||||
|
||||
// MatchDomain reports whether target matches the pattern.
|
||||
// If pattern starts with "*.", it matches any subdomain (but not the base itself).
|
||||
// Otherwise it requires an exact match.
|
||||
func MatchDomain(pattern, target domain.Domain) bool {
|
||||
p := pattern.PunycodeString()
|
||||
t := target.PunycodeString()
|
||||
|
||||
if strings.HasPrefix(p, "*.") {
|
||||
base := p[2:]
|
||||
return strings.HasSuffix(t, "."+base)
|
||||
}
|
||||
|
||||
return p == t
|
||||
}
|
||||
|
||||
// SourceInfo carries source identity context for rule evaluation.
|
||||
// The source may be a direct WireGuard peer or a host behind
|
||||
// a site-to-site gateway.
|
||||
type SourceInfo struct {
|
||||
// IP is the original source address from the packet.
|
||||
IP netip.Addr
|
||||
// PolicyID is the management policy that allowed this traffic
|
||||
// through route ACLs.
|
||||
PolicyID PolicyID
|
||||
}
|
||||
|
||||
// ProtoType identifies a protocol handled by the proxy.
|
||||
type ProtoType string
|
||||
|
||||
const (
|
||||
ProtoHTTP ProtoType = "http"
|
||||
ProtoHTTPS ProtoType = "https"
|
||||
ProtoH2 ProtoType = "h2"
|
||||
ProtoH3 ProtoType = "h3"
|
||||
ProtoWebSocket ProtoType = "websocket"
|
||||
ProtoOther ProtoType = "other"
|
||||
)
|
||||
|
||||
// Rule defines a proxy inspection/filtering rule.
|
||||
type Rule struct {
|
||||
// ID uniquely identifies this rule.
|
||||
ID id.RuleID
|
||||
// Sources are the source CIDRs this rule applies to.
|
||||
// Includes both direct peer IPs and routed networks behind gateways.
|
||||
Sources []netip.Prefix
|
||||
// Domains are the destination domain patterns to match (via SNI or Host header).
|
||||
// Supports exact match ("example.com") and wildcard ("*.example.com").
|
||||
Domains []domain.Domain
|
||||
// Networks are the destination CIDRs to match.
|
||||
Networks []netip.Prefix
|
||||
// Ports are the destination ports to match. Empty means all ports.
|
||||
Ports []uint16
|
||||
// Protocols restricts which protocols this rule applies to.
|
||||
// Empty means all protocols.
|
||||
Protocols []ProtoType
|
||||
// Paths are URL path patterns to match (HTTP only, requires inspect for HTTPS).
|
||||
// Supports prefix ("/api/"), exact ("/login"), and wildcard ("/admin/*").
|
||||
// Empty means all paths.
|
||||
Paths []string
|
||||
// Action determines what to do with matched connections.
|
||||
Action Action
|
||||
// Priority controls evaluation order. Lower values are evaluated first.
|
||||
Priority int
|
||||
}
|
||||
|
||||
// ICAPConfig holds ICAP service configuration.
|
||||
type ICAPConfig struct {
|
||||
// ReqModURL is the ICAP REQMOD service URL (e.g., icap://server:1344/reqmod).
|
||||
ReqModURL *url.URL
|
||||
// RespModURL is the ICAP RESPMOD service URL (e.g., icap://server:1344/respmod).
|
||||
RespModURL *url.URL
|
||||
// MaxConnections is the connection pool size. Zero uses a default.
|
||||
MaxConnections int
|
||||
}
|
||||
|
||||
// TLSConfig holds the MITM CA configuration for TLS inspection.
|
||||
type TLSConfig struct {
|
||||
// CA is the certificate authority used to sign dynamic certificates.
|
||||
CA *x509.Certificate
|
||||
// CAKey is the CA's private key.
|
||||
CAKey crypto.PrivateKey
|
||||
}
|
||||
|
||||
// Config holds the transparent proxy configuration.
|
||||
type Config struct {
|
||||
// Enabled controls whether the proxy is active.
|
||||
Enabled bool
|
||||
// Mode selects built-in or external proxy operation.
|
||||
Mode ProxyMode
|
||||
// ExternalURL is the upstream proxy URL for ModeExternal.
|
||||
// Supports http:// and socks5:// schemes.
|
||||
ExternalURL *url.URL
|
||||
|
||||
// DefaultAction applies when no rule matches a connection.
|
||||
DefaultAction Action
|
||||
|
||||
// RedirectSources are the source CIDRs whose traffic should be intercepted.
|
||||
// Admin decides: "activate for these users/subnets."
|
||||
// Used for both kernel TPROXY rules and userspace forwarder source filtering.
|
||||
RedirectSources []netip.Prefix
|
||||
// RedirectPorts are the destination ports to intercept. Empty means all ports.
|
||||
RedirectPorts []uint16
|
||||
|
||||
// Rules are the proxy inspection/filtering rules, evaluated in Priority order.
|
||||
Rules []Rule
|
||||
|
||||
// ICAP holds ICAP service configuration. Nil disables ICAP.
|
||||
ICAP *ICAPConfig
|
||||
// TLS holds the MITM CA. Nil means no MITM capability (ActionInspect rules ignored).
|
||||
TLS *TLSConfig
|
||||
|
||||
// Envoy configuration (ModeEnvoy only)
|
||||
Envoy *EnvoyConfig
|
||||
|
||||
// ListenAddr is the TPROXY listen address for kernel mode.
|
||||
// Zero value disables the TPROXY listener.
|
||||
ListenAddr netip.AddrPort
|
||||
// WGNetwork is the WireGuard overlay network prefix.
|
||||
// The proxy blocks dialing destinations inside this network.
|
||||
WGNetwork netip.Prefix
|
||||
// LocalIPChecker reports whether an IP belongs to the routing peer.
|
||||
// Used to prevent SSRF to local services. May be nil.
|
||||
LocalIPChecker LocalIPChecker
|
||||
}
|
||||
|
||||
// EnvoyConfig holds configuration for the envoy sidecar mode.
|
||||
type EnvoyConfig struct {
|
||||
// BinaryPath is the path to the envoy binary.
|
||||
// Empty means search $PATH for "envoy".
|
||||
BinaryPath string
|
||||
// AdminPort is the port for envoy's admin API (health checks, stats).
|
||||
// Zero means auto-assign.
|
||||
AdminPort uint16
|
||||
// Snippets are user-provided config fragments merged into the generated bootstrap.
|
||||
Snippets *EnvoySnippets
|
||||
}
|
||||
|
||||
// EnvoySnippets holds user-provided YAML fragments for envoy config customization.
|
||||
// Only safe snippet types are allowed: filters (HTTP and network) and clusters
|
||||
// needed as dependencies for filter services. Listeners and bootstrap overrides
|
||||
// are not exposed since we manage the listener and bootstrap.
|
||||
type EnvoySnippets struct {
|
||||
// HTTPFilters is YAML injected into the HCM filter chain before the router filter.
|
||||
// Used for ext_authz, rate limiting, Lua, Wasm, RBAC, JWT auth, etc.
|
||||
HTTPFilters string
|
||||
// NetworkFilters is YAML injected into the TLS filter chain before tcp_proxy.
|
||||
// Used for network-level RBAC, rate limiting, ext_authz on raw TCP.
|
||||
NetworkFilters string
|
||||
// Clusters is YAML for additional upstream clusters referenced by filters.
|
||||
// Needed when filters call external services (ext_authz backend, rate limit service).
|
||||
Clusters string
|
||||
}
|
||||
@@ -1,93 +0,0 @@
|
||||
package inspect
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
)
|
||||
|
||||
func TestMatchDomain(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
pattern string
|
||||
target string
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
name: "exact match",
|
||||
pattern: "example.com",
|
||||
target: "example.com",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "exact no match",
|
||||
pattern: "example.com",
|
||||
target: "other.com",
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "wildcard matches subdomain",
|
||||
pattern: "*.example.com",
|
||||
target: "foo.example.com",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "wildcard matches deep subdomain",
|
||||
pattern: "*.example.com",
|
||||
target: "a.b.c.example.com",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "wildcard does not match base",
|
||||
pattern: "*.example.com",
|
||||
target: "example.com",
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "wildcard does not match unrelated",
|
||||
pattern: "*.example.com",
|
||||
target: "foo.other.com",
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "case insensitive exact match",
|
||||
pattern: "Example.COM",
|
||||
target: "example.com",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "case insensitive wildcard match",
|
||||
pattern: "*.Example.COM",
|
||||
target: "FOO.example.com",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "wildcard does not match partial suffix",
|
||||
pattern: "*.example.com",
|
||||
target: "notexample.com",
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "unicode domain punycode match",
|
||||
pattern: "*.münchen.de",
|
||||
target: "sub.xn--mnchen-3ya.de",
|
||||
want: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
pattern, err := domain.FromString(tt.pattern)
|
||||
require.NoError(t, err)
|
||||
|
||||
target, err := domain.FromString(tt.target)
|
||||
require.NoError(t, err)
|
||||
|
||||
got := MatchDomain(pattern, target)
|
||||
assert.Equal(t, tt.want, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -1,25 +0,0 @@
|
||||
package inspect
|
||||
|
||||
import (
|
||||
"net"
|
||||
"syscall"
|
||||
)
|
||||
|
||||
// newOutboundDialer creates a net.Dialer that clears the socket fwmark.
|
||||
// In kernel TPROXY mode, accepted connections inherit the TPROXY fwmark.
|
||||
// Without clearing it, outbound connections from the proxy would match
|
||||
// the ip rule (fwmark -> local loopback) and loop back to the proxy
|
||||
// instead of reaching the real destination.
|
||||
func newOutboundDialer() net.Dialer {
|
||||
return net.Dialer{
|
||||
Control: func(_, _ string, c syscall.RawConn) error {
|
||||
var sockErr error
|
||||
if err := c.Control(func(fd uintptr) {
|
||||
sockErr = syscall.SetsockoptInt(int(fd), syscall.SOL_SOCKET, syscall.SO_MARK, 0)
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
return sockErr
|
||||
},
|
||||
}
|
||||
}
|
||||
@@ -1,11 +0,0 @@
|
||||
//go:build !linux
|
||||
|
||||
package inspect
|
||||
|
||||
import "net"
|
||||
|
||||
// newOutboundDialer returns a plain dialer on non-Linux platforms.
|
||||
// TPROXY is Linux-only, so no fwmark clearing is needed.
|
||||
func newOutboundDialer() net.Dialer {
|
||||
return net.Dialer{}
|
||||
}
|
||||
@@ -1,298 +0,0 @@
|
||||
package inspect
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/netip"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
envoyStartTimeout = 15 * time.Second
|
||||
envoyHealthInterval = 500 * time.Millisecond
|
||||
envoyStopTimeout = 10 * time.Second
|
||||
envoyDrainTime = 5
|
||||
)
|
||||
|
||||
// envoyManager manages the lifecycle of an envoy sidecar process.
|
||||
type envoyManager struct {
|
||||
log *log.Entry
|
||||
cmd *exec.Cmd
|
||||
configPath string
|
||||
listenPort uint16
|
||||
adminPort uint16
|
||||
cancel context.CancelFunc
|
||||
|
||||
blockPagePath string
|
||||
|
||||
mu sync.Mutex
|
||||
running bool
|
||||
}
|
||||
|
||||
// startEnvoy finds the envoy binary, generates config, and spawns the process.
|
||||
// It blocks until envoy reports healthy or the timeout expires.
|
||||
func startEnvoy(ctx context.Context, logger *log.Entry, config Config) (*envoyManager, error) {
|
||||
envCfg := config.Envoy
|
||||
if envCfg == nil {
|
||||
return nil, fmt.Errorf("envoy config is nil")
|
||||
}
|
||||
|
||||
binaryPath, err := findEnvoyBinary(envCfg.BinaryPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("find envoy binary: %w", err)
|
||||
}
|
||||
|
||||
// Pick admin port
|
||||
adminPort := envCfg.AdminPort
|
||||
if adminPort == 0 {
|
||||
p, err := findFreePort()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("find free admin port: %w", err)
|
||||
}
|
||||
adminPort = p
|
||||
}
|
||||
|
||||
// Pick listener port
|
||||
listenPort, err := findFreePort()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("find free listener port: %w", err)
|
||||
}
|
||||
|
||||
// Use a private temp directory (0700) to prevent local attackers from
|
||||
// replacing the config file between write and envoy read.
|
||||
configDir, err := os.MkdirTemp("", "nb-envoy-*")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create envoy config directory: %w", err)
|
||||
}
|
||||
|
||||
// Write the block page HTML for envoy's direct_response to reference.
|
||||
blockPagePath := filepath.Join(configDir, "block.html")
|
||||
blockHTML := fmt.Sprintf(blockPageHTML, "blocked domain", "this domain")
|
||||
if err := os.WriteFile(blockPagePath, []byte(blockHTML), 0600); err != nil {
|
||||
return nil, fmt.Errorf("write envoy block page: %w", err)
|
||||
}
|
||||
|
||||
// Generate config with the block page path embedded.
|
||||
bootstrap, err := generateBootstrap(config, listenPort, adminPort, blockPagePath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("generate envoy bootstrap: %w", err)
|
||||
}
|
||||
|
||||
configPath := filepath.Join(configDir, "bootstrap.yaml")
|
||||
if err := os.WriteFile(configPath, bootstrap, 0600); err != nil {
|
||||
return nil, fmt.Errorf("write envoy config: %w", err)
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
|
||||
cmd := exec.CommandContext(ctx, binaryPath,
|
||||
"-c", configPath,
|
||||
"--drain-time-s", fmt.Sprintf("%d", envoyDrainTime),
|
||||
)
|
||||
|
||||
// Pipe envoy output to our logger.
|
||||
cmd.Stdout = &logWriter{entry: logger, level: log.DebugLevel}
|
||||
cmd.Stderr = &logWriter{entry: logger, level: log.WarnLevel}
|
||||
|
||||
if err := cmd.Start(); err != nil {
|
||||
cancel()
|
||||
os.Remove(configPath)
|
||||
return nil, fmt.Errorf("start envoy: %w", err)
|
||||
}
|
||||
|
||||
mgr := &envoyManager{
|
||||
log: logger,
|
||||
cmd: cmd,
|
||||
configPath: configPath,
|
||||
listenPort: listenPort,
|
||||
adminPort: adminPort,
|
||||
blockPagePath: blockPagePath,
|
||||
cancel: cancel,
|
||||
running: true,
|
||||
}
|
||||
|
||||
// Wait for envoy to become healthy.
|
||||
if err := mgr.waitHealthy(ctx); err != nil {
|
||||
mgr.Stop()
|
||||
return nil, fmt.Errorf("wait for envoy readiness: %w", err)
|
||||
}
|
||||
|
||||
logger.Infof("inspect: envoy started (pid=%d, listen=%d, admin=%d)", cmd.Process.Pid, listenPort, adminPort)
|
||||
|
||||
// Monitor process exit in background.
|
||||
go mgr.monitor()
|
||||
|
||||
return mgr, nil
|
||||
}
|
||||
|
||||
// ListenAddr returns the address envoy listens on for forwarded connections.
|
||||
func (m *envoyManager) ListenAddr() netip.AddrPort {
|
||||
return netip.AddrPortFrom(netip.AddrFrom4([4]byte{127, 0, 0, 1}), m.listenPort)
|
||||
}
|
||||
|
||||
// AdminAddr returns the envoy admin API address.
|
||||
func (m *envoyManager) AdminAddr() string {
|
||||
return fmt.Sprintf("127.0.0.1:%d", m.adminPort)
|
||||
}
|
||||
|
||||
// Reload writes a new config and sends SIGHUP to envoy.
|
||||
func (m *envoyManager) Reload(config Config) error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
if !m.running {
|
||||
return fmt.Errorf("envoy is not running")
|
||||
}
|
||||
|
||||
bootstrap, err := generateBootstrap(config, m.listenPort, m.adminPort, m.blockPagePath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("generate envoy bootstrap: %w", err)
|
||||
}
|
||||
|
||||
if err := os.WriteFile(m.configPath, bootstrap, 0600); err != nil {
|
||||
return fmt.Errorf("write envoy config: %w", err)
|
||||
}
|
||||
|
||||
if err := signalReload(m.cmd.Process); err != nil {
|
||||
return fmt.Errorf("signal envoy reload: %w", err)
|
||||
}
|
||||
|
||||
m.log.Debugf("inspect: envoy config reloaded")
|
||||
return nil
|
||||
}
|
||||
|
||||
// Healthy checks the envoy admin API /ready endpoint.
|
||||
func (m *envoyManager) Healthy() bool {
|
||||
resp, err := http.Get(fmt.Sprintf("http://%s/ready", m.AdminAddr()))
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
return resp.StatusCode == http.StatusOK
|
||||
}
|
||||
|
||||
// Stop terminates the envoy process and cleans up.
|
||||
func (m *envoyManager) Stop() {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
if !m.running {
|
||||
return
|
||||
}
|
||||
m.running = false
|
||||
|
||||
m.cancel()
|
||||
|
||||
if m.cmd.Process != nil {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
m.cmd.Wait()
|
||||
close(done)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(envoyStopTimeout):
|
||||
m.log.Warnf("inspect: envoy did not exit in %s, killing", envoyStopTimeout)
|
||||
m.cmd.Process.Kill()
|
||||
<-done
|
||||
}
|
||||
}
|
||||
|
||||
os.RemoveAll(filepath.Dir(m.configPath))
|
||||
m.log.Infof("inspect: envoy stopped")
|
||||
}
|
||||
|
||||
// waitHealthy polls the admin API until envoy is ready or timeout.
|
||||
func (m *envoyManager) waitHealthy(ctx context.Context) error {
|
||||
deadline := time.After(envoyStartTimeout)
|
||||
ticker := time.NewTicker(envoyHealthInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case <-deadline:
|
||||
return fmt.Errorf("envoy not ready after %s", envoyStartTimeout)
|
||||
case <-ticker.C:
|
||||
if m.Healthy() {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// monitor watches for unexpected envoy exits.
|
||||
func (m *envoyManager) monitor() {
|
||||
err := m.cmd.Wait()
|
||||
|
||||
m.mu.Lock()
|
||||
wasRunning := m.running
|
||||
m.running = false
|
||||
m.mu.Unlock()
|
||||
|
||||
if wasRunning {
|
||||
m.log.Errorf("inspect: envoy exited unexpectedly: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// findEnvoyBinary resolves the envoy binary path.
|
||||
func findEnvoyBinary(configPath string) (string, error) {
|
||||
if configPath != "" {
|
||||
if _, err := os.Stat(configPath); err != nil {
|
||||
return "", fmt.Errorf("envoy binary not found at %s: %w", configPath, err)
|
||||
}
|
||||
return configPath, nil
|
||||
}
|
||||
|
||||
path, err := exec.LookPath("envoy")
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("envoy not found in PATH: %w", err)
|
||||
}
|
||||
return path, nil
|
||||
}
|
||||
|
||||
// findFreePort asks the OS for an available TCP port.
|
||||
func findFreePort() (uint16, error) {
|
||||
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
port := uint16(ln.Addr().(*net.TCPAddr).Port)
|
||||
ln.Close()
|
||||
return port, nil
|
||||
}
|
||||
|
||||
// logWriter adapts log.Entry to io.Writer for piping process output.
|
||||
type logWriter struct {
|
||||
entry *log.Entry
|
||||
level log.Level
|
||||
}
|
||||
|
||||
func (w *logWriter) Write(p []byte) (int, error) {
|
||||
msg := strings.TrimRight(string(p), "\n\r")
|
||||
if msg == "" {
|
||||
return len(p), nil
|
||||
}
|
||||
switch w.level {
|
||||
case log.WarnLevel:
|
||||
w.entry.Warn(msg)
|
||||
default:
|
||||
w.entry.Debug(msg)
|
||||
}
|
||||
return len(p), nil
|
||||
}
|
||||
|
||||
// Ensure logWriter satisfies io.Writer.
|
||||
var _ io.Writer = (*logWriter)(nil)
|
||||
@@ -1,382 +0,0 @@
|
||||
package inspect
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"strings"
|
||||
"text/template"
|
||||
)
|
||||
|
||||
// envoyBootstrapTmpl generates the full envoy bootstrap with rule translation.
|
||||
// TLS rules become per-SNI filter chains; HTTP rules become per-domain virtual hosts.
|
||||
var envoyBootstrapTmpl = template.Must(template.New("bootstrap").Funcs(template.FuncMap{
|
||||
"quote": func(s string) string { return fmt.Sprintf("%q", s) },
|
||||
}).Parse(`node:
|
||||
id: netbird-inspect
|
||||
cluster: netbird
|
||||
admin:
|
||||
address:
|
||||
socket_address:
|
||||
address: 127.0.0.1
|
||||
port_value: {{.AdminPort}}
|
||||
static_resources:
|
||||
listeners:
|
||||
- name: inspect_listener
|
||||
address:
|
||||
socket_address:
|
||||
address: 127.0.0.1
|
||||
port_value: {{.ListenPort}}
|
||||
listener_filters:
|
||||
- name: envoy.filters.listener.proxy_protocol
|
||||
typed_config:
|
||||
"@type": type.googleapis.com/envoy.extensions.filters.listener.proxy_protocol.v3.ProxyProtocol
|
||||
- name: envoy.filters.listener.tls_inspector
|
||||
typed_config:
|
||||
"@type": type.googleapis.com/envoy.extensions.filters.listener.tls_inspector.v3.TlsInspector
|
||||
filter_chains:
|
||||
{{- /* TLS filter chains: per-SNI block/allow + default */ -}}
|
||||
{{- range .TLSChains}}
|
||||
- filter_chain_match:
|
||||
transport_protocol: tls
|
||||
{{- if .ServerNames}}
|
||||
server_names:
|
||||
{{- range .ServerNames}}
|
||||
- {{quote .}}
|
||||
{{- end}}
|
||||
{{- end}}
|
||||
filters:
|
||||
{{$.NetworkFiltersSnippet}} - name: envoy.filters.network.tcp_proxy
|
||||
typed_config:
|
||||
"@type": type.googleapis.com/envoy.extensions.filters.network.tcp_proxy.v3.TcpProxy
|
||||
stat_prefix: {{.StatPrefix}}
|
||||
cluster: original_dst
|
||||
access_log:
|
||||
- name: envoy.access_loggers.stderr
|
||||
typed_config:
|
||||
"@type": type.googleapis.com/envoy.extensions.access_loggers.stream.v3.StderrAccessLog
|
||||
log_format:
|
||||
text_format: "[%START_TIME%] tcp %DOWNSTREAM_REMOTE_ADDRESS% -> %UPSTREAM_HOST% %RESPONSE_FLAGS% %DURATION%ms\n"
|
||||
{{- end}}
|
||||
{{- /* Plain HTTP filter chain with per-domain virtual hosts */}}
|
||||
- filters:
|
||||
- name: envoy.filters.network.http_connection_manager
|
||||
typed_config:
|
||||
"@type": type.googleapis.com/envoy.extensions.filters.network.http_connection_manager.v3.HttpConnectionManager
|
||||
stat_prefix: inspect_http
|
||||
access_log:
|
||||
- name: envoy.access_loggers.stderr
|
||||
typed_config:
|
||||
"@type": type.googleapis.com/envoy.extensions.access_loggers.stream.v3.StderrAccessLog
|
||||
log_format:
|
||||
text_format: "[%START_TIME%] http %DOWNSTREAM_REMOTE_ADDRESS% %REQ(:AUTHORITY)% %REQ(:METHOD)% %REQ(X-ENVOY-ORIGINAL-PATH?:PATH)% %RESPONSE_CODE% %RESPONSE_FLAGS% %DURATION%ms\n"
|
||||
http_filters:
|
||||
{{.HTTPFiltersSnippet}} - name: envoy.filters.http.router
|
||||
typed_config:
|
||||
"@type": type.googleapis.com/envoy.extensions.filters.http.router.v3.Router
|
||||
route_config:
|
||||
virtual_hosts:
|
||||
{{- range .VirtualHosts}}
|
||||
- name: {{.Name}}
|
||||
domains: [{{.DomainsStr}}]
|
||||
routes:
|
||||
{{- range .Routes}}
|
||||
- match:
|
||||
prefix: "{{if .PathPrefix}}{{.PathPrefix}}{{else}}/{{end}}"
|
||||
{{- if .Block}}
|
||||
direct_response:
|
||||
status: 403
|
||||
body:
|
||||
filename: "{{$.BlockPagePath}}"
|
||||
{{- else}}
|
||||
route:
|
||||
cluster: original_dst
|
||||
{{- end}}
|
||||
{{- end}}
|
||||
{{- end}}
|
||||
clusters:
|
||||
- name: original_dst
|
||||
type: ORIGINAL_DST
|
||||
lb_policy: CLUSTER_PROVIDED
|
||||
connect_timeout: 10s
|
||||
{{.ExtraClusters}}`))
|
||||
|
||||
// tlsChain represents a TLS filter chain entry for the template.
|
||||
// All TLS chains are passthrough (block decisions happen in Go before envoy).
|
||||
type tlsChain struct {
|
||||
// ServerNames restricts this chain to specific SNIs. Empty is catch-all.
|
||||
ServerNames []string
|
||||
StatPrefix string
|
||||
}
|
||||
|
||||
// envoyRoute represents a single route entry within a virtual host.
|
||||
type envoyRoute struct {
|
||||
// PathPrefix for envoy prefix match. Empty means catch-all "/".
|
||||
PathPrefix string
|
||||
Block bool
|
||||
}
|
||||
|
||||
// virtualHost represents an HTTP virtual host entry for the template.
|
||||
type virtualHost struct {
|
||||
Name string
|
||||
// DomainsStr is pre-formatted for the template: "a", "b".
|
||||
DomainsStr string
|
||||
Routes []envoyRoute
|
||||
}
|
||||
|
||||
type bootstrapData struct {
|
||||
AdminPort uint16
|
||||
ListenPort uint16
|
||||
BlockPagePath string
|
||||
TLSChains []tlsChain
|
||||
VirtualHosts []virtualHost
|
||||
HTTPFiltersSnippet string
|
||||
NetworkFiltersSnippet string
|
||||
ExtraClusters string
|
||||
}
|
||||
|
||||
// generateBootstrap produces the envoy bootstrap YAML from the inspect config.
|
||||
// Translates inspection rules into envoy-native per-SNI and per-domain routing.
|
||||
// blockPagePath is the path to the HTML block page file served by direct_response.
|
||||
func generateBootstrap(config Config, listenPort, adminPort uint16, blockPagePath string) ([]byte, error) {
|
||||
data := bootstrapData{
|
||||
AdminPort: adminPort,
|
||||
BlockPagePath: blockPagePath,
|
||||
ListenPort: listenPort,
|
||||
TLSChains: buildTLSChains(config),
|
||||
VirtualHosts: buildVirtualHosts(config),
|
||||
}
|
||||
|
||||
if config.Envoy != nil && config.Envoy.Snippets != nil {
|
||||
s := config.Envoy.Snippets
|
||||
data.HTTPFiltersSnippet = indentSnippet(s.HTTPFilters, 18)
|
||||
data.NetworkFiltersSnippet = indentSnippet(s.NetworkFilters, 12)
|
||||
data.ExtraClusters = indentSnippet(s.Clusters, 4)
|
||||
}
|
||||
|
||||
var buf bytes.Buffer
|
||||
if err := envoyBootstrapTmpl.Execute(&buf, data); err != nil {
|
||||
return nil, fmt.Errorf("execute bootstrap template: %w", err)
|
||||
}
|
||||
|
||||
return buf.Bytes(), nil
|
||||
}
|
||||
|
||||
// buildTLSChains translates inspection rules into envoy TLS filter chains.
|
||||
// Block rules -> per-SNI chain routing to blackhole.
|
||||
// Allow rules (when default=block) -> per-SNI chain routing to original_dst.
|
||||
// Default chain follows DefaultAction.
|
||||
func buildTLSChains(config Config) []tlsChain {
|
||||
// TLS block decisions happen in Go before forwarding to envoy, so we only
|
||||
// generate allow/passthrough chains here. Envoy can't cleanly close a TLS
|
||||
// connection without completing a handshake, so blocked SNIs never reach envoy.
|
||||
var allowed []string
|
||||
|
||||
for _, rule := range config.Rules {
|
||||
if !ruleTouchesProtocol(rule, ProtoHTTPS, ProtoH2) {
|
||||
continue
|
||||
}
|
||||
for _, d := range rule.Domains {
|
||||
sni := d.PunycodeString()
|
||||
if rule.Action == ActionAllow || rule.Action == ActionInspect {
|
||||
allowed = append(allowed, sni)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var chains []tlsChain
|
||||
|
||||
if len(allowed) > 0 && config.DefaultAction == ActionBlock {
|
||||
chains = append(chains, tlsChain{
|
||||
ServerNames: allowed,
|
||||
StatPrefix: "tls_allowed",
|
||||
})
|
||||
}
|
||||
|
||||
// Default catch-all: passthrough (blocked SNIs never arrive here)
|
||||
chains = append(chains, tlsChain{
|
||||
StatPrefix: "tls_default",
|
||||
})
|
||||
|
||||
return chains
|
||||
}
|
||||
|
||||
// buildVirtualHosts translates inspection rules into envoy HTTP virtual hosts.
|
||||
// Groups rules by domain, generates per-path routes within each virtual host.
|
||||
func buildVirtualHosts(config Config) []virtualHost {
|
||||
// Group rules by domain for per-domain virtual hosts.
|
||||
type domainRules struct {
|
||||
domains []string
|
||||
routes []envoyRoute
|
||||
}
|
||||
|
||||
domainRouteMap := make(map[string][]envoyRoute)
|
||||
|
||||
for _, rule := range config.Rules {
|
||||
if !ruleTouchesProtocol(rule, ProtoHTTP, ProtoWebSocket) {
|
||||
continue
|
||||
}
|
||||
isBlock := rule.Action == ActionBlock
|
||||
|
||||
// Rules without domains or paths are handled by the default action.
|
||||
if len(rule.Domains) == 0 && len(rule.Paths) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
// Build routes for this rule's paths
|
||||
var routes []envoyRoute
|
||||
if len(rule.Paths) > 0 {
|
||||
for _, p := range rule.Paths {
|
||||
// Convert our path patterns to envoy prefix match.
|
||||
// Strip trailing * for envoy prefix matching.
|
||||
prefix := strings.TrimSuffix(p, "*")
|
||||
routes = append(routes, envoyRoute{PathPrefix: prefix, Block: isBlock})
|
||||
}
|
||||
} else {
|
||||
routes = append(routes, envoyRoute{Block: isBlock})
|
||||
}
|
||||
|
||||
if len(rule.Domains) > 0 {
|
||||
for _, d := range rule.Domains {
|
||||
host := d.PunycodeString()
|
||||
domainRouteMap[host] = append(domainRouteMap[host], routes...)
|
||||
}
|
||||
} else {
|
||||
// No domain: applies to all, add to default host
|
||||
domainRouteMap["*"] = append(domainRouteMap["*"], routes...)
|
||||
}
|
||||
}
|
||||
|
||||
var hosts []virtualHost
|
||||
idx := 0
|
||||
|
||||
// Per-domain virtual hosts with path routes
|
||||
for domain, routes := range domainRouteMap {
|
||||
if domain == "*" {
|
||||
continue
|
||||
}
|
||||
// Add a catch-all route after path-specific routes.
|
||||
// The catch-all follows the default action.
|
||||
routes = append(routes, envoyRoute{Block: config.DefaultAction == ActionBlock})
|
||||
|
||||
hosts = append(hosts, virtualHost{
|
||||
Name: fmt.Sprintf("domain_%d", idx),
|
||||
DomainsStr: fmt.Sprintf("%q", domain),
|
||||
Routes: routes,
|
||||
})
|
||||
idx++
|
||||
}
|
||||
|
||||
// Default virtual host (catch-all for unmatched domains)
|
||||
defaultRoutes := domainRouteMap["*"]
|
||||
defaultRoutes = append(defaultRoutes, envoyRoute{Block: config.DefaultAction == ActionBlock})
|
||||
hosts = append(hosts, virtualHost{
|
||||
Name: "default",
|
||||
DomainsStr: `"*"`,
|
||||
Routes: defaultRoutes,
|
||||
})
|
||||
|
||||
return hosts
|
||||
}
|
||||
|
||||
// ruleTouchesProtocol returns true if the rule's protocol list includes any of the given protocols,
|
||||
// or if the protocol list is empty (matches all).
|
||||
func ruleTouchesProtocol(rule Rule, protos ...ProtoType) bool {
|
||||
if len(rule.Protocols) == 0 {
|
||||
return true
|
||||
}
|
||||
for _, rp := range rule.Protocols {
|
||||
for _, p := range protos {
|
||||
if rp == p {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// indentSnippet prepends each line of the YAML snippet with the given number of spaces.
|
||||
// Returns empty string if snippet is empty.
|
||||
func indentSnippet(snippet string, spaces int) string {
|
||||
if snippet == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
prefix := make([]byte, spaces)
|
||||
for i := range prefix {
|
||||
prefix[i] = ' '
|
||||
}
|
||||
|
||||
var buf bytes.Buffer
|
||||
for i, line := range bytes.Split([]byte(snippet), []byte("\n")) {
|
||||
if i > 0 {
|
||||
buf.WriteByte('\n')
|
||||
}
|
||||
if len(line) > 0 {
|
||||
buf.Write(prefix)
|
||||
buf.Write(line)
|
||||
}
|
||||
}
|
||||
buf.WriteByte('\n')
|
||||
|
||||
return buf.String()
|
||||
}
|
||||
|
||||
// ValidateSnippets checks that user-provided snippets are safe to inject
|
||||
// into the envoy config. Returns an error describing the first violation found.
|
||||
//
|
||||
// Validation rules:
|
||||
// - Each snippet must be valid YAML (prevents syntax-level injection)
|
||||
// - Snippets must not contain YAML document separators (--- or ...) that could
|
||||
// break out of the indentation context
|
||||
// - Snippets must only contain list items (starting with "- ") at the top level,
|
||||
// matching what envoy expects for filters and clusters
|
||||
func ValidateSnippets(snippets *EnvoySnippets) error {
|
||||
if snippets == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
fields := []struct {
|
||||
name string
|
||||
value string
|
||||
}{
|
||||
{"http_filters", snippets.HTTPFilters},
|
||||
{"network_filters", snippets.NetworkFilters},
|
||||
{"clusters", snippets.Clusters},
|
||||
}
|
||||
|
||||
for _, f := range fields {
|
||||
if f.value == "" {
|
||||
continue
|
||||
}
|
||||
if err := validateSnippetYAML(f.name, f.value); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func validateSnippetYAML(name, snippet string) error {
|
||||
// Check for YAML document markers that could break template structure.
|
||||
for _, line := range strings.Split(snippet, "\n") {
|
||||
trimmed := strings.TrimSpace(line)
|
||||
if trimmed == "---" || trimmed == "..." {
|
||||
return fmt.Errorf("snippet %q: YAML document separators (--- or ...) are not allowed", name)
|
||||
}
|
||||
}
|
||||
|
||||
// Verify it's valid YAML by checking it doesn't cause template execution issues.
|
||||
// We can't import yaml.v3 here without adding a dependency, so we do structural checks.
|
||||
|
||||
// Check for null bytes or control characters that could confuse YAML parsers.
|
||||
for i, b := range []byte(snippet) {
|
||||
if b == 0 {
|
||||
return fmt.Errorf("snippet %q: null byte at position %d", name, i)
|
||||
}
|
||||
if b < 0x09 || (b > 0x0D && b < 0x20 && b != 0x1B) {
|
||||
return fmt.Errorf("snippet %q: control character 0x%02x at position %d", name, b, i)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -1,88 +0,0 @@
|
||||
package inspect
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
)
|
||||
|
||||
// PROXY protocol v2 constants (RFC 7239 / HAProxy spec)
|
||||
var proxyV2Signature = [12]byte{
|
||||
0x0D, 0x0A, 0x0D, 0x0A, 0x00, 0x0D, 0x0A, 0x51,
|
||||
0x55, 0x49, 0x54, 0x0A,
|
||||
}
|
||||
|
||||
const (
|
||||
proxyV2VersionCommand = 0x21 // version 2, PROXY command
|
||||
proxyV2FamilyTCP4 = 0x11 // AF_INET, STREAM
|
||||
proxyV2FamilyTCP6 = 0x21 // AF_INET6, STREAM
|
||||
)
|
||||
|
||||
// forwardToEnvoy forwards a connection to the given envoy sidecar via PROXY protocol v2.
|
||||
// The caller provides the envoy manager snapshot to avoid accessing p.envoy without lock.
|
||||
func (p *Proxy) forwardToEnvoy(ctx context.Context, pconn *peekConn, dst netip.AddrPort, src SourceInfo, em *envoyManager) error {
|
||||
envoyAddr := em.ListenAddr()
|
||||
|
||||
conn, err := (&net.Dialer{}).DialContext(ctx, "tcp", envoyAddr.String())
|
||||
if err != nil {
|
||||
return fmt.Errorf("dial envoy at %s: %w", envoyAddr, err)
|
||||
}
|
||||
defer func() {
|
||||
if err := conn.Close(); err != nil {
|
||||
p.log.Debugf("close envoy conn: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
if err := writeProxyV2Header(conn, src.IP, dst); err != nil {
|
||||
return fmt.Errorf("write PROXY v2 header: %w", err)
|
||||
}
|
||||
|
||||
p.log.Tracef("envoy: forwarded %s -> %s via PROXY v2", src.IP, dst)
|
||||
|
||||
return relay(ctx, pconn, conn)
|
||||
}
|
||||
|
||||
// writeProxyV2Header writes a PROXY protocol v2 header to w.
|
||||
// The header encodes the original source IP and the destination address:port.
|
||||
func writeProxyV2Header(w net.Conn, srcIP netip.Addr, dst netip.AddrPort) error {
|
||||
srcIP = srcIP.Unmap()
|
||||
dstIP := dst.Addr().Unmap()
|
||||
|
||||
var (
|
||||
family byte
|
||||
addrs []byte
|
||||
)
|
||||
|
||||
if srcIP.Is4() && dstIP.Is4() {
|
||||
family = proxyV2FamilyTCP4
|
||||
s4 := srcIP.As4()
|
||||
d4 := dstIP.As4()
|
||||
addrs = make([]byte, 12) // 4+4+2+2
|
||||
copy(addrs[0:4], s4[:])
|
||||
copy(addrs[4:8], d4[:])
|
||||
binary.BigEndian.PutUint16(addrs[8:10], 0) // src port unknown
|
||||
binary.BigEndian.PutUint16(addrs[10:12], dst.Port())
|
||||
} else {
|
||||
family = proxyV2FamilyTCP6
|
||||
s16 := srcIP.As16()
|
||||
d16 := dstIP.As16()
|
||||
addrs = make([]byte, 36) // 16+16+2+2
|
||||
copy(addrs[0:16], s16[:])
|
||||
copy(addrs[16:32], d16[:])
|
||||
binary.BigEndian.PutUint16(addrs[32:34], 0) // src port unknown
|
||||
binary.BigEndian.PutUint16(addrs[34:36], dst.Port())
|
||||
}
|
||||
|
||||
// Header: signature(12) + ver_cmd(1) + family(1) + len(2) + addrs
|
||||
header := make([]byte, 16+len(addrs))
|
||||
copy(header[0:12], proxyV2Signature[:])
|
||||
header[12] = proxyV2VersionCommand
|
||||
header[13] = family
|
||||
binary.BigEndian.PutUint16(header[14:16], uint16(len(addrs)))
|
||||
copy(header[16:], addrs)
|
||||
|
||||
_, err := w.Write(header)
|
||||
return err
|
||||
}
|
||||
@@ -1,13 +0,0 @@
|
||||
//go:build !windows
|
||||
|
||||
package inspect
|
||||
|
||||
import (
|
||||
"os"
|
||||
"syscall"
|
||||
)
|
||||
|
||||
// signalReload sends SIGHUP to the envoy process to trigger config reload.
|
||||
func signalReload(p *os.Process) error {
|
||||
return p.Signal(syscall.SIGHUP)
|
||||
}
|
||||
@@ -1,13 +0,0 @@
|
||||
//go:build windows
|
||||
|
||||
package inspect
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
)
|
||||
|
||||
// signalReload is not supported on Windows. Envoy must be restarted.
|
||||
func signalReload(_ *os.Process) error {
|
||||
return fmt.Errorf("envoy config reload via signal not supported on Windows")
|
||||
}
|
||||
@@ -1,229 +0,0 @@
|
||||
package inspect
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/netip"
|
||||
"net/url"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
externalDialTimeout = 10 * time.Second
|
||||
)
|
||||
|
||||
// handleExternal forwards the connection to an external proxy.
|
||||
// For TLS connections, it uses HTTP CONNECT to tunnel through the proxy.
|
||||
// For HTTP connections, it rewrites the request to use the proxy.
|
||||
func (p *Proxy) handleExternal(ctx context.Context, pconn *peekConn, dst netip.AddrPort) error {
|
||||
p.mu.RLock()
|
||||
proxyURL := p.config.ExternalURL
|
||||
p.mu.RUnlock()
|
||||
|
||||
if proxyURL == nil {
|
||||
return fmt.Errorf("external proxy URL not configured")
|
||||
}
|
||||
|
||||
switch proxyURL.Scheme {
|
||||
case "http", "https":
|
||||
return p.externalHTTPProxy(ctx, pconn, dst, proxyURL)
|
||||
case "socks5":
|
||||
return p.externalSOCKS5(ctx, pconn, dst, proxyURL)
|
||||
default:
|
||||
return fmt.Errorf("unsupported external proxy scheme: %s", proxyURL.Scheme)
|
||||
}
|
||||
}
|
||||
|
||||
// externalHTTPProxy tunnels through an HTTP proxy using CONNECT.
|
||||
func (p *Proxy) externalHTTPProxy(ctx context.Context, pconn *peekConn, dst netip.AddrPort, proxyURL *url.URL) error {
|
||||
proxyAddr := proxyURL.Host
|
||||
if _, _, err := net.SplitHostPort(proxyAddr); err != nil {
|
||||
proxyAddr = net.JoinHostPort(proxyAddr, "8080")
|
||||
}
|
||||
|
||||
proxyConn, err := (&net.Dialer{Timeout: externalDialTimeout}).DialContext(ctx, "tcp", proxyAddr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("dial external proxy %s: %w", proxyAddr, err)
|
||||
}
|
||||
defer func() {
|
||||
if err := proxyConn.Close(); err != nil {
|
||||
p.log.Debugf("close external proxy conn: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
connectReq := fmt.Sprintf("CONNECT %s HTTP/1.1\r\nHost: %s\r\n", dst.String(), dst.String())
|
||||
if proxyURL.User != nil {
|
||||
connectReq += "Proxy-Authorization: Basic " + basicAuth(proxyURL.User) + "\r\n"
|
||||
}
|
||||
connectReq += "\r\n"
|
||||
|
||||
if _, err := io.WriteString(proxyConn, connectReq); err != nil {
|
||||
return fmt.Errorf("send CONNECT to proxy: %w", err)
|
||||
}
|
||||
|
||||
resp, err := http.ReadResponse(bufio.NewReader(proxyConn), nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("read CONNECT response: %w", err)
|
||||
}
|
||||
if err := resp.Body.Close(); err != nil {
|
||||
p.log.Debugf("close CONNECT resp body: %v", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return fmt.Errorf("proxy CONNECT failed: %s", resp.Status)
|
||||
}
|
||||
|
||||
return relay(ctx, pconn, proxyConn)
|
||||
}
|
||||
|
||||
// externalSOCKS5 tunnels through a SOCKS5 proxy.
|
||||
func (p *Proxy) externalSOCKS5(ctx context.Context, pconn *peekConn, dst netip.AddrPort, proxyURL *url.URL) error {
|
||||
proxyAddr := proxyURL.Host
|
||||
if _, _, err := net.SplitHostPort(proxyAddr); err != nil {
|
||||
proxyAddr = net.JoinHostPort(proxyAddr, "1080")
|
||||
}
|
||||
|
||||
proxyConn, err := (&net.Dialer{Timeout: externalDialTimeout}).DialContext(ctx, "tcp", proxyAddr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("dial SOCKS5 proxy %s: %w", proxyAddr, err)
|
||||
}
|
||||
defer func() {
|
||||
if err := proxyConn.Close(); err != nil {
|
||||
p.log.Debugf("close SOCKS5 proxy conn: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
if err := socks5Handshake(proxyConn, dst, proxyURL.User); err != nil {
|
||||
return fmt.Errorf("SOCKS5 handshake: %w", err)
|
||||
}
|
||||
|
||||
return relay(ctx, pconn, proxyConn)
|
||||
}
|
||||
|
||||
// socks5Handshake performs the SOCKS5 handshake to connect through the proxy.
|
||||
func socks5Handshake(conn net.Conn, dst netip.AddrPort, userinfo *url.Userinfo) error {
|
||||
needAuth := userinfo != nil
|
||||
|
||||
// Greeting
|
||||
var methods []byte
|
||||
if needAuth {
|
||||
methods = []byte{0x00, 0x02} // no auth, username/password
|
||||
} else {
|
||||
methods = []byte{0x00} // no auth
|
||||
}
|
||||
greeting := append([]byte{0x05, byte(len(methods))}, methods...)
|
||||
if _, err := conn.Write(greeting); err != nil {
|
||||
return fmt.Errorf("send greeting: %w", err)
|
||||
}
|
||||
|
||||
// Server method selection
|
||||
var methodResp [2]byte
|
||||
if _, err := io.ReadFull(conn, methodResp[:]); err != nil {
|
||||
return fmt.Errorf("read method selection: %w", err)
|
||||
}
|
||||
if methodResp[0] != 0x05 {
|
||||
return fmt.Errorf("unexpected SOCKS version: %d", methodResp[0])
|
||||
}
|
||||
|
||||
// Handle authentication if selected
|
||||
if methodResp[1] == 0x02 {
|
||||
if err := socks5Auth(conn, userinfo); err != nil {
|
||||
return err
|
||||
}
|
||||
} else if methodResp[1] != 0x00 {
|
||||
return fmt.Errorf("unsupported SOCKS5 auth method: %d", methodResp[1])
|
||||
}
|
||||
|
||||
// Connection request
|
||||
addr := dst.Addr()
|
||||
var addrBytes []byte
|
||||
if addr.Is4() {
|
||||
a4 := addr.As4()
|
||||
addrBytes = append([]byte{0x01}, a4[:]...) // IPv4
|
||||
} else {
|
||||
a16 := addr.As16()
|
||||
addrBytes = append([]byte{0x04}, a16[:]...) // IPv6
|
||||
}
|
||||
|
||||
port := dst.Port()
|
||||
connectReq := append([]byte{0x05, 0x01, 0x00}, addrBytes...)
|
||||
connectReq = append(connectReq, byte(port>>8), byte(port))
|
||||
|
||||
if _, err := conn.Write(connectReq); err != nil {
|
||||
return fmt.Errorf("send connect request: %w", err)
|
||||
}
|
||||
|
||||
// Read response (minimum 10 bytes for IPv4)
|
||||
var respHeader [4]byte
|
||||
if _, err := io.ReadFull(conn, respHeader[:]); err != nil {
|
||||
return fmt.Errorf("read connect response: %w", err)
|
||||
}
|
||||
if respHeader[1] != 0x00 {
|
||||
return fmt.Errorf("SOCKS5 connect failed: status %d", respHeader[1])
|
||||
}
|
||||
|
||||
// Skip bound address
|
||||
switch respHeader[3] {
|
||||
case 0x01: // IPv4
|
||||
var skip [4 + 2]byte
|
||||
if _, err := io.ReadFull(conn, skip[:]); err != nil {
|
||||
return fmt.Errorf("read SOCKS5 bound IPv4 address: %w", err)
|
||||
}
|
||||
case 0x04: // IPv6
|
||||
var skip [16 + 2]byte
|
||||
if _, err := io.ReadFull(conn, skip[:]); err != nil {
|
||||
return fmt.Errorf("read SOCKS5 bound IPv6 address: %w", err)
|
||||
}
|
||||
case 0x03: // Domain
|
||||
var dLen [1]byte
|
||||
if _, err := io.ReadFull(conn, dLen[:]); err != nil {
|
||||
return fmt.Errorf("read domain length: %w", err)
|
||||
}
|
||||
skip := make([]byte, int(dLen[0])+2)
|
||||
if _, err := io.ReadFull(conn, skip); err != nil {
|
||||
return fmt.Errorf("read SOCKS5 bound domain address: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func socks5Auth(conn net.Conn, userinfo *url.Userinfo) error {
|
||||
if userinfo == nil {
|
||||
return fmt.Errorf("SOCKS5 auth required but no credentials provided")
|
||||
}
|
||||
|
||||
user := userinfo.Username()
|
||||
pass, _ := userinfo.Password()
|
||||
|
||||
// Username/password auth (RFC 1929)
|
||||
auth := []byte{0x01, byte(len(user))}
|
||||
auth = append(auth, []byte(user)...)
|
||||
auth = append(auth, byte(len(pass)))
|
||||
auth = append(auth, []byte(pass)...)
|
||||
|
||||
if _, err := conn.Write(auth); err != nil {
|
||||
return fmt.Errorf("send auth: %w", err)
|
||||
}
|
||||
|
||||
var resp [2]byte
|
||||
if _, err := io.ReadFull(conn, resp[:]); err != nil {
|
||||
return fmt.Errorf("read auth response: %w", err)
|
||||
}
|
||||
if resp[1] != 0x00 {
|
||||
return fmt.Errorf("SOCKS5 auth failed: status %d", resp[1])
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func basicAuth(userinfo *url.Userinfo) string {
|
||||
user := userinfo.Username()
|
||||
pass, _ := userinfo.Password()
|
||||
return base64.StdEncoding.EncodeToString([]byte(user + ":" + pass))
|
||||
}
|
||||
@@ -1,532 +0,0 @@
|
||||
package inspect
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/netip"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
)
|
||||
|
||||
const (
|
||||
headerUpgrade = "Upgrade"
|
||||
valueWebSocket = "websocket"
|
||||
)
|
||||
|
||||
// inspectHTTP runs the HTTP inspection pipeline on decrypted traffic.
|
||||
// It handles HTTP/1.1 (request-response loop), HTTP/2 (via Go stdlib reverse proxy),
|
||||
// and WebSocket upgrade detection.
|
||||
func (p *Proxy) inspectHTTP(ctx context.Context, client, remote net.Conn, dst netip.AddrPort, sni domain.Domain, src SourceInfo, proto string) error {
|
||||
if proto == "h2" {
|
||||
return p.inspectH2(ctx, client, remote, dst, sni, src)
|
||||
}
|
||||
return p.inspectH1(ctx, client, remote, dst, sni, src)
|
||||
}
|
||||
|
||||
// inspectH1 handles HTTP/1.1 request-response inspection in a loop.
|
||||
func (p *Proxy) inspectH1(ctx context.Context, client, remote net.Conn, dst netip.AddrPort, sni domain.Domain, src SourceInfo) error {
|
||||
clientReader := bufio.NewReader(client)
|
||||
remoteReader := bufio.NewReader(remote)
|
||||
|
||||
for {
|
||||
if ctx.Err() != nil {
|
||||
return ctx.Err()
|
||||
}
|
||||
|
||||
// Set idle timeout between requests to prevent connection hogging.
|
||||
if err := client.SetReadDeadline(time.Now().Add(idleTimeout)); err != nil {
|
||||
return fmt.Errorf("set idle deadline: %w", err)
|
||||
}
|
||||
req, err := http.ReadRequest(clientReader)
|
||||
if err != nil {
|
||||
if isClosedErr(err) {
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("read HTTP request: %w", err)
|
||||
}
|
||||
if err := client.SetReadDeadline(time.Time{}); err != nil {
|
||||
return fmt.Errorf("clear read deadline: %w", err)
|
||||
}
|
||||
|
||||
// Re-evaluate rules based on Host header if SNI was empty
|
||||
host := hostFromRequest(req, sni)
|
||||
|
||||
// Domain fronting: Host header doesn't match TLS SNI
|
||||
if isDomainFronting(req, sni) {
|
||||
p.log.Debugf("domain fronting detected: SNI=%s Host=%s", sni.PunycodeString(), host.PunycodeString())
|
||||
writeBlockResponse(client, req, host)
|
||||
return ErrBlocked
|
||||
}
|
||||
|
||||
proto := ProtoHTTP
|
||||
if isWebSocketUpgrade(req) {
|
||||
proto = ProtoWebSocket
|
||||
}
|
||||
action := p.evaluateAction(src.IP, host, dst, proto, req.URL.Path)
|
||||
if action == ActionBlock {
|
||||
p.log.Debugf("block: HTTP %s %s (host=%s)", req.Method, req.URL.Path, host.PunycodeString())
|
||||
writeBlockResponse(client, req, host)
|
||||
return ErrBlocked
|
||||
}
|
||||
p.log.Tracef("allow: HTTP %s %s (host=%s, action=%s)", req.Method, req.URL.Path, host.PunycodeString(), action)
|
||||
|
||||
// ICAP REQMOD: send request for inspection.
|
||||
// Snapshot ICAP client under lock to avoid use-after-close races.
|
||||
p.mu.RLock()
|
||||
icap := p.icap
|
||||
p.mu.RUnlock()
|
||||
if icap != nil {
|
||||
modified, err := icap.ReqMod(req)
|
||||
if err != nil {
|
||||
p.log.Debugf("ICAP REQMOD error for %s: %v", host.PunycodeString(), err)
|
||||
// Fail-closed: block on ICAP error
|
||||
writeBlockResponse(client, req, host)
|
||||
return fmt.Errorf("ICAP REQMOD: %w", err)
|
||||
}
|
||||
req = modified
|
||||
}
|
||||
|
||||
if isWebSocketUpgrade(req) {
|
||||
return p.handleWebSocket(ctx, req, client, clientReader, remote, remoteReader)
|
||||
}
|
||||
|
||||
removeHopByHopHeaders(req.Header)
|
||||
|
||||
if err := req.Write(remote); err != nil {
|
||||
return fmt.Errorf("forward request: %w", err)
|
||||
}
|
||||
|
||||
resp, err := http.ReadResponse(remoteReader, req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("read HTTP response: %w", err)
|
||||
}
|
||||
|
||||
// ICAP RESPMOD: send response for inspection
|
||||
if icap != nil {
|
||||
modified, err := icap.RespMod(req, resp)
|
||||
if err != nil {
|
||||
p.log.Debugf("ICAP RESPMOD error for %s: %v", host.PunycodeString(), err)
|
||||
if err := resp.Body.Close(); err != nil {
|
||||
p.log.Debugf("close resp body: %v", err)
|
||||
}
|
||||
writeBlockResponse(client, req, host)
|
||||
return fmt.Errorf("ICAP RESPMOD: %w", err)
|
||||
}
|
||||
resp = modified
|
||||
}
|
||||
|
||||
removeHopByHopHeaders(resp.Header)
|
||||
|
||||
if err := resp.Write(client); err != nil {
|
||||
if closeErr := resp.Body.Close(); closeErr != nil {
|
||||
p.log.Debugf("close resp body: %v", closeErr)
|
||||
}
|
||||
return fmt.Errorf("forward response: %w", err)
|
||||
}
|
||||
if err := resp.Body.Close(); err != nil {
|
||||
p.log.Debugf("close resp body: %v", err)
|
||||
}
|
||||
|
||||
// Connection: close means we're done
|
||||
if resp.Close || req.Close {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// inspectH2 proxies HTTP/2 traffic using Go's http stack.
|
||||
// Client and remote are already-established TLS connections with h2 negotiated.
|
||||
func (p *Proxy) inspectH2(ctx context.Context, client, remote net.Conn, dst netip.AddrPort, sni domain.Domain, src SourceInfo) error {
|
||||
// For h2 MITM inspection, we use a local http.Server reading from the client
|
||||
// connection and an http.Transport writing to the remote connection.
|
||||
//
|
||||
// The transport is configured to use the existing TLS connection to the
|
||||
// real server. The handler inspects each request/response pair.
|
||||
|
||||
transport := &http.Transport{
|
||||
DialContext: func(_ context.Context, _, _ string) (net.Conn, error) {
|
||||
return remote, nil
|
||||
},
|
||||
DialTLSContext: func(_ context.Context, _, _ string) (net.Conn, error) {
|
||||
return remote, nil
|
||||
},
|
||||
ForceAttemptHTTP2: true,
|
||||
}
|
||||
|
||||
handler := &h2InspectionHandler{
|
||||
proxy: p,
|
||||
transport: transport,
|
||||
dst: dst,
|
||||
sni: sni,
|
||||
src: src,
|
||||
}
|
||||
|
||||
server := &http.Server{
|
||||
Handler: handler,
|
||||
}
|
||||
|
||||
// Serve the single client connection.
|
||||
// ServeConn blocks until the connection is done.
|
||||
errCh := make(chan error, 1)
|
||||
go func() {
|
||||
// http.Server doesn't have a direct ServeConn for h2,
|
||||
// so we use Serve with a single-connection listener.
|
||||
ln := &singleConnListener{conn: client}
|
||||
errCh <- server.Serve(ln)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
if err := server.Close(); err != nil {
|
||||
p.log.Debugf("close h2 server: %v", err)
|
||||
}
|
||||
return ctx.Err()
|
||||
case err := <-errCh:
|
||||
if err == http.ErrServerClosed {
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// h2InspectionHandler inspects each HTTP/2 request/response pair.
|
||||
type h2InspectionHandler struct {
|
||||
proxy *Proxy
|
||||
transport http.RoundTripper
|
||||
dst netip.AddrPort
|
||||
sni domain.Domain
|
||||
src SourceInfo
|
||||
}
|
||||
|
||||
func (h *h2InspectionHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
|
||||
host := hostFromRequest(req, h.sni)
|
||||
|
||||
if isDomainFronting(req, h.sni) {
|
||||
h.proxy.log.Debugf("domain fronting detected: SNI=%s Host=%s", h.sni.PunycodeString(), host.PunycodeString())
|
||||
writeBlockPage(w, host)
|
||||
return
|
||||
}
|
||||
|
||||
action := h.proxy.evaluateAction(h.src.IP, host, h.dst, ProtoH2, req.URL.Path)
|
||||
if action == ActionBlock {
|
||||
h.proxy.log.Debugf("block: H2 %s %s (host=%s)", req.Method, req.URL.Path, host.PunycodeString())
|
||||
writeBlockPage(w, host)
|
||||
return
|
||||
}
|
||||
|
||||
// ICAP REQMOD
|
||||
if h.proxy.icap != nil {
|
||||
modified, err := h.proxy.icap.ReqMod(req)
|
||||
if err != nil {
|
||||
h.proxy.log.Debugf("ICAP REQMOD error for %s: %v", host.PunycodeString(), err)
|
||||
writeBlockPage(w, host)
|
||||
return
|
||||
}
|
||||
req = modified
|
||||
}
|
||||
|
||||
// Forward to upstream
|
||||
req.URL.Scheme = "https"
|
||||
req.URL.Host = h.sni.PunycodeString()
|
||||
req.RequestURI = ""
|
||||
|
||||
resp, err := h.transport.RoundTrip(req)
|
||||
if err != nil {
|
||||
h.proxy.log.Debugf("h2 upstream error for %s: %v", host.PunycodeString(), err)
|
||||
http.Error(w, "Bad Gateway", http.StatusBadGateway)
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
if err := resp.Body.Close(); err != nil {
|
||||
h.proxy.log.Debugf("close h2 resp body: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
// ICAP RESPMOD
|
||||
if h.proxy.icap != nil {
|
||||
modified, err := h.proxy.icap.RespMod(req, resp)
|
||||
if err != nil {
|
||||
h.proxy.log.Debugf("ICAP RESPMOD error for %s: %v", host.PunycodeString(), err)
|
||||
writeBlockPage(w, host)
|
||||
return
|
||||
}
|
||||
resp = modified
|
||||
}
|
||||
|
||||
// Copy response headers and body
|
||||
for k, vals := range resp.Header {
|
||||
for _, v := range vals {
|
||||
w.Header().Add(k, v)
|
||||
}
|
||||
}
|
||||
w.WriteHeader(resp.StatusCode)
|
||||
if _, err := io.Copy(w, resp.Body); err != nil {
|
||||
h.proxy.log.Debugf("h2 response copy error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// handleWebSocket completes the WebSocket upgrade and relays frames bidirectionally.
|
||||
func (p *Proxy) handleWebSocket(ctx context.Context, req *http.Request, client io.ReadWriter, clientReader *bufio.Reader, remote io.ReadWriter, remoteReader *bufio.Reader) error {
|
||||
if err := req.Write(remote); err != nil {
|
||||
return fmt.Errorf("forward WebSocket upgrade: %w", err)
|
||||
}
|
||||
|
||||
resp, err := http.ReadResponse(remoteReader, req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("read WebSocket upgrade response: %w", err)
|
||||
}
|
||||
|
||||
if err := resp.Write(client); err != nil {
|
||||
if closeErr := resp.Body.Close(); closeErr != nil {
|
||||
p.log.Debugf("close ws resp body: %v", closeErr)
|
||||
}
|
||||
return fmt.Errorf("forward WebSocket upgrade response: %w", err)
|
||||
}
|
||||
if err := resp.Body.Close(); err != nil {
|
||||
p.log.Debugf("close ws resp body: %v", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusSwitchingProtocols {
|
||||
return fmt.Errorf("WebSocket upgrade rejected: status %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
p.log.Tracef("allow: WebSocket upgrade for %s", req.Host)
|
||||
|
||||
// Relay WebSocket frames bidirectionally.
|
||||
// clientReader/remoteReader may have buffered data.
|
||||
clientConn := mergeReadWriter(clientReader, client)
|
||||
remoteConn := mergeReadWriter(remoteReader, remote)
|
||||
|
||||
return relayRW(ctx, clientConn, remoteConn)
|
||||
}
|
||||
|
||||
// hostFromRequest extracts a domain.Domain from the HTTP request Host header,
|
||||
// falling back to the SNI if Host is empty or an IP.
|
||||
func hostFromRequest(req *http.Request, fallback domain.Domain) domain.Domain {
|
||||
host := req.Host
|
||||
if host == "" {
|
||||
return fallback
|
||||
}
|
||||
|
||||
// Strip port if present
|
||||
if h, _, err := net.SplitHostPort(host); err == nil {
|
||||
host = h
|
||||
}
|
||||
|
||||
// If it's an IP address, use the SNI fallback
|
||||
if _, err := netip.ParseAddr(host); err == nil {
|
||||
return fallback
|
||||
}
|
||||
|
||||
d, err := domain.FromString(host)
|
||||
if err != nil {
|
||||
return fallback
|
||||
}
|
||||
return d
|
||||
}
|
||||
|
||||
// isDomainFronting detects domain fronting: the Host header doesn't match the
|
||||
// SNI used during the TLS handshake. Only meaningful when SNI is non-empty
|
||||
// (i.e., we're in MITM mode and know the original SNI).
|
||||
func isDomainFronting(req *http.Request, sni domain.Domain) bool {
|
||||
if sni == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
host := hostFromRequest(req, "")
|
||||
if host == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
// Host should match SNI or be a subdomain of SNI
|
||||
if host == sni {
|
||||
return false
|
||||
}
|
||||
|
||||
// Allow www.example.com when SNI is example.com
|
||||
sniStr := sni.PunycodeString()
|
||||
hostStr := host.PunycodeString()
|
||||
if strings.HasSuffix(hostStr, "."+sniStr) {
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func isWebSocketUpgrade(req *http.Request) bool {
|
||||
return strings.EqualFold(req.Header.Get(headerUpgrade), valueWebSocket)
|
||||
}
|
||||
|
||||
// writeBlockPage writes the styled HTML block page to an http.ResponseWriter (H2 path).
|
||||
func writeBlockPage(w http.ResponseWriter, host domain.Domain) {
|
||||
hostname := host.PunycodeString()
|
||||
body := fmt.Sprintf(blockPageHTML, hostname, hostname)
|
||||
w.Header().Set("Content-Type", "text/html; charset=utf-8")
|
||||
w.Header().Set("Cache-Control", "no-store")
|
||||
w.WriteHeader(http.StatusForbidden)
|
||||
io.WriteString(w, body)
|
||||
}
|
||||
|
||||
func writeBlockResponse(w io.Writer, _ *http.Request, host domain.Domain) {
|
||||
hostname := host.PunycodeString()
|
||||
body := fmt.Sprintf(blockPageHTML, hostname, hostname)
|
||||
|
||||
resp := &http.Response{
|
||||
StatusCode: http.StatusForbidden,
|
||||
ProtoMajor: 1,
|
||||
ProtoMinor: 1,
|
||||
Header: make(http.Header),
|
||||
ContentLength: int64(len(body)),
|
||||
Body: io.NopCloser(strings.NewReader(body)),
|
||||
}
|
||||
resp.Header.Set("Content-Type", "text/html; charset=utf-8")
|
||||
resp.Header.Set("Connection", "close")
|
||||
resp.Header.Set("Cache-Control", "no-store")
|
||||
_ = resp.Write(w)
|
||||
}
|
||||
|
||||
// blockPageHTML is the self-contained HTML block page.
|
||||
// Uses NetBird dark theme with orange accent. Two format args: page title domain, displayed domain.
|
||||
const blockPageHTML = `<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta charset="utf-8">
|
||||
<meta name="viewport" content="width=device-width,initial-scale=1">
|
||||
<title>Blocked - %s</title>
|
||||
<style>
|
||||
*{margin:0;padding:0;box-sizing:border-box}
|
||||
body{background:#181a1d;color:#d1d5db;font-family:-apple-system,BlinkMacSystemFont,"Segoe UI",Roboto,sans-serif;min-height:100vh;display:flex;align-items:center;justify-content:center}
|
||||
.c{text-align:center;max-width:460px;padding:2rem}
|
||||
.shield{width:56px;height:56px;margin:0 auto 1.5rem;border-radius:16px;background:#2b2f33;display:flex;align-items:center;justify-content:center}
|
||||
.shield svg{width:28px;height:28px;color:#f68330}
|
||||
.code{font-size:.8rem;font-weight:500;color:#f68330;font-family:ui-monospace,monospace;letter-spacing:.05em;margin-bottom:.5rem}
|
||||
h1{font-size:1.5rem;font-weight:600;color:#f4f4f5;margin-bottom:.5rem}
|
||||
p{font-size:.95rem;line-height:1.5;color:#9ca3af;margin-bottom:1.75rem}
|
||||
.domain{display:inline-block;background:#25282d;border:1px solid #32363d;border-radius:6px;padding:.15rem .5rem;font-family:ui-monospace,monospace;font-size:.85rem;color:#d1d5db}
|
||||
.footer{font-size:.7rem;color:#6b7280;margin-top:2rem;letter-spacing:.03em}
|
||||
.footer a{color:#6b7280;text-decoration:none}
|
||||
.footer a:hover{color:#9ca3af}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<div class="c">
|
||||
<div class="shield"><svg xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24" stroke-width="1.5" stroke="currentColor"><path stroke-linecap="round" stroke-linejoin="round" d="M12 9v3.75m0-10.036A11.959 11.959 0 0 1 3.598 6 11.99 11.99 0 0 0 3 9.75c0 5.592 3.824 10.29 9 11.622 5.176-1.332 9-6.03 9-11.622 0-1.31-.21-2.571-.598-3.751A11.96 11.96 0 0 0 12 3.714Z"/></svg></div>
|
||||
<div class="code">403 BLOCKED</div>
|
||||
<h1>Access Denied</h1>
|
||||
<p>This connection to <span class="domain">%s</span> has been blocked by your organization's network policy.</p>
|
||||
<div class="footer">Protected by <a href="https://netbird.io" target="_blank" rel="noopener">NetBird</a></div>
|
||||
</div>
|
||||
</body>
|
||||
</html>`
|
||||
|
||||
// singleConnListener is a net.Listener that yields a single connection.
|
||||
type singleConnListener struct {
|
||||
conn net.Conn
|
||||
once sync.Once
|
||||
ch chan struct{}
|
||||
}
|
||||
|
||||
func (l *singleConnListener) Accept() (net.Conn, error) {
|
||||
var accepted bool
|
||||
l.once.Do(func() {
|
||||
l.ch = make(chan struct{})
|
||||
accepted = true
|
||||
})
|
||||
if accepted {
|
||||
return l.conn, nil
|
||||
}
|
||||
// Block until Close
|
||||
<-l.ch
|
||||
return nil, net.ErrClosed
|
||||
}
|
||||
|
||||
func (l *singleConnListener) Close() error {
|
||||
l.once.Do(func() {
|
||||
l.ch = make(chan struct{})
|
||||
})
|
||||
select {
|
||||
case <-l.ch:
|
||||
default:
|
||||
close(l.ch)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (l *singleConnListener) Addr() net.Addr {
|
||||
return l.conn.LocalAddr()
|
||||
}
|
||||
|
||||
type readWriter struct {
|
||||
io.Reader
|
||||
io.Writer
|
||||
}
|
||||
|
||||
func mergeReadWriter(r io.Reader, w io.Writer) io.ReadWriter {
|
||||
return &readWriter{Reader: r, Writer: w}
|
||||
}
|
||||
|
||||
// relayRW copies data bidirectionally between two ReadWriters.
|
||||
func relayRW(ctx context.Context, a, b io.ReadWriter) error {
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
|
||||
errCh := make(chan error, 2)
|
||||
|
||||
go func() {
|
||||
_, err := io.Copy(b, a)
|
||||
cancel()
|
||||
errCh <- err
|
||||
}()
|
||||
|
||||
go func() {
|
||||
_, err := io.Copy(a, b)
|
||||
cancel()
|
||||
errCh <- err
|
||||
}()
|
||||
|
||||
var firstErr error
|
||||
for range 2 {
|
||||
if err := <-errCh; err != nil && firstErr == nil {
|
||||
if !isClosedErr(err) {
|
||||
firstErr = err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return firstErr
|
||||
}
|
||||
|
||||
// hopByHopHeaders are HTTP/1.1 headers that apply to a single connection
|
||||
// and must not be forwarded by a proxy (RFC 7230, Section 6.1).
|
||||
var hopByHopHeaders = []string{
|
||||
"Connection",
|
||||
"Keep-Alive",
|
||||
"Proxy-Authenticate",
|
||||
"Proxy-Authorization",
|
||||
"TE",
|
||||
"Trailers",
|
||||
"Transfer-Encoding",
|
||||
"Upgrade",
|
||||
}
|
||||
|
||||
// removeHopByHopHeaders strips hop-by-hop headers from h.
|
||||
// Also removes headers listed in the Connection header value.
|
||||
func removeHopByHopHeaders(h http.Header) {
|
||||
// First, remove any headers named in the Connection header
|
||||
for _, connHeader := range h["Connection"] {
|
||||
for _, name := range strings.Split(connHeader, ",") {
|
||||
h.Del(strings.TrimSpace(name))
|
||||
}
|
||||
}
|
||||
|
||||
for _, name := range hopByHopHeaders {
|
||||
h.Del(name)
|
||||
}
|
||||
}
|
||||
@@ -1,479 +0,0 @@
|
||||
package inspect
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/textproto"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
icapVersion = "ICAP/1.0"
|
||||
icapDefaultPort = "1344"
|
||||
icapConnTimeout = 30 * time.Second
|
||||
icapRWTimeout = 60 * time.Second
|
||||
icapMaxPoolSize = 8
|
||||
icapIdleTimeout = 60 * time.Second
|
||||
icapMaxRespSize = 4 * 1024 * 1024 // 4 MB
|
||||
)
|
||||
|
||||
// ICAPClient implements an ICAP (RFC 3507) client with persistent connection pooling.
|
||||
type ICAPClient struct {
|
||||
reqModURL *url.URL
|
||||
respModURL *url.URL
|
||||
pool chan *icapConn
|
||||
mu sync.Mutex
|
||||
log *log.Entry
|
||||
maxPool int
|
||||
}
|
||||
|
||||
type icapConn struct {
|
||||
conn net.Conn
|
||||
reader *bufio.Reader
|
||||
lastUse time.Time
|
||||
}
|
||||
|
||||
// NewICAPClient creates an ICAP client. Either or both URLs may be nil
|
||||
// to disable that mode.
|
||||
func NewICAPClient(logger *log.Entry, cfg *ICAPConfig) *ICAPClient {
|
||||
maxPool := cfg.MaxConnections
|
||||
if maxPool <= 0 {
|
||||
maxPool = icapMaxPoolSize
|
||||
}
|
||||
|
||||
return &ICAPClient{
|
||||
reqModURL: cfg.ReqModURL,
|
||||
respModURL: cfg.RespModURL,
|
||||
pool: make(chan *icapConn, maxPool),
|
||||
log: logger,
|
||||
maxPool: maxPool,
|
||||
}
|
||||
}
|
||||
|
||||
// ReqMod sends an HTTP request to the ICAP REQMOD service for inspection.
|
||||
// Returns the (possibly modified) request, or the original if ICAP returns 204.
|
||||
// Returns nil, nil if REQMOD is not configured.
|
||||
func (c *ICAPClient) ReqMod(req *http.Request) (*http.Request, error) {
|
||||
if c.reqModURL == nil {
|
||||
return req, nil
|
||||
}
|
||||
|
||||
var reqBuf bytes.Buffer
|
||||
if err := req.Write(&reqBuf); err != nil {
|
||||
return nil, fmt.Errorf("serialize request: %w", err)
|
||||
}
|
||||
|
||||
respBody, err := c.send("REQMOD", c.reqModURL, reqBuf.Bytes(), nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if respBody == nil {
|
||||
return req, nil
|
||||
}
|
||||
|
||||
modified, err := http.ReadRequest(bufio.NewReader(bytes.NewReader(respBody)))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parse ICAP modified request: %w", err)
|
||||
}
|
||||
return modified, nil
|
||||
}
|
||||
|
||||
// RespMod sends an HTTP response to the ICAP RESPMOD service for inspection.
|
||||
// Returns the (possibly modified) response, or the original if ICAP returns 204.
|
||||
// Returns nil, nil if RESPMOD is not configured.
|
||||
func (c *ICAPClient) RespMod(req *http.Request, resp *http.Response) (*http.Response, error) {
|
||||
if c.respModURL == nil {
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
var reqBuf bytes.Buffer
|
||||
if err := req.Write(&reqBuf); err != nil {
|
||||
return nil, fmt.Errorf("serialize request: %w", err)
|
||||
}
|
||||
|
||||
var respBuf bytes.Buffer
|
||||
if err := resp.Write(&respBuf); err != nil {
|
||||
return nil, fmt.Errorf("serialize response: %w", err)
|
||||
}
|
||||
|
||||
respBody, err := c.send("RESPMOD", c.respModURL, reqBuf.Bytes(), respBuf.Bytes())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if respBody == nil {
|
||||
// 204 No Content: ICAP server didn't modify the response.
|
||||
// Reconstruct from the buffered copy since resp.Body was consumed by Write.
|
||||
reconstructed, err := http.ReadResponse(bufio.NewReader(bytes.NewReader(respBuf.Bytes())), req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("reconstruct response after ICAP 204: %w", err)
|
||||
}
|
||||
return reconstructed, nil
|
||||
}
|
||||
|
||||
modified, err := http.ReadResponse(bufio.NewReader(bytes.NewReader(respBody)), req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parse ICAP modified response: %w", err)
|
||||
}
|
||||
return modified, nil
|
||||
}
|
||||
|
||||
// Close drains and closes all pooled connections.
|
||||
func (c *ICAPClient) Close() {
|
||||
close(c.pool)
|
||||
for ic := range c.pool {
|
||||
if err := ic.conn.Close(); err != nil {
|
||||
c.log.Debugf("close ICAP connection: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// send executes an ICAP request and returns the encapsulated body from the response.
|
||||
// Returns nil body for 204 No Content (no modification).
|
||||
// Retries once on stale pooled connection (EOF on read).
|
||||
func (c *ICAPClient) send(method string, serviceURL *url.URL, reqData, respData []byte) ([]byte, error) {
|
||||
statusCode, headers, body, err := c.trySend(method, serviceURL, reqData, respData)
|
||||
if err != nil && isStaleConnErr(err) {
|
||||
// Retry once with a fresh connection (stale pool entry).
|
||||
c.log.Debugf("ICAP %s: retrying after stale connection: %v", method, err)
|
||||
statusCode, headers, body, err = c.trySend(method, serviceURL, reqData, respData)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
switch statusCode {
|
||||
case 204:
|
||||
return nil, nil
|
||||
case 200:
|
||||
return body, nil
|
||||
default:
|
||||
c.log.Debugf("ICAP %s returned status %d, headers: %v", method, statusCode, headers)
|
||||
return nil, fmt.Errorf("ICAP %s: status %d", method, statusCode)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *ICAPClient) trySend(method string, serviceURL *url.URL, reqData, respData []byte) (int, textproto.MIMEHeader, []byte, error) {
|
||||
ic, err := c.getConn(serviceURL)
|
||||
if err != nil {
|
||||
return 0, nil, nil, fmt.Errorf("get ICAP connection: %w", err)
|
||||
}
|
||||
|
||||
if err := c.writeRequest(ic, method, serviceURL, reqData, respData); err != nil {
|
||||
if closeErr := ic.conn.Close(); closeErr != nil {
|
||||
c.log.Debugf("close ICAP conn after write error: %v", closeErr)
|
||||
}
|
||||
return 0, nil, nil, fmt.Errorf("write ICAP %s: %w", method, err)
|
||||
}
|
||||
|
||||
statusCode, headers, body, err := c.readResponse(ic)
|
||||
if err != nil {
|
||||
if closeErr := ic.conn.Close(); closeErr != nil {
|
||||
c.log.Debugf("close ICAP conn after read error: %v", closeErr)
|
||||
}
|
||||
return 0, nil, nil, fmt.Errorf("read ICAP response: %w", err)
|
||||
}
|
||||
|
||||
c.putConn(ic)
|
||||
return statusCode, headers, body, nil
|
||||
}
|
||||
|
||||
func isStaleConnErr(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
s := err.Error()
|
||||
return strings.Contains(s, "EOF") || strings.Contains(s, "broken pipe") || strings.Contains(s, "connection reset")
|
||||
}
|
||||
|
||||
func (c *ICAPClient) writeRequest(ic *icapConn, method string, serviceURL *url.URL, reqData, respData []byte) error {
|
||||
if err := ic.conn.SetWriteDeadline(time.Now().Add(icapRWTimeout)); err != nil {
|
||||
return fmt.Errorf("set write deadline: %w", err)
|
||||
}
|
||||
|
||||
// For RESPMOD, split the serialized HTTP response into headers and body.
|
||||
// The body must be sent chunked per RFC 3507.
|
||||
var respHdr, respBody []byte
|
||||
if respData != nil {
|
||||
if idx := bytes.Index(respData, []byte("\r\n\r\n")); idx >= 0 {
|
||||
respHdr = respData[:idx+4] // include the \r\n\r\n separator
|
||||
respBody = respData[idx+4:]
|
||||
} else {
|
||||
respHdr = respData
|
||||
}
|
||||
}
|
||||
|
||||
var buf bytes.Buffer
|
||||
|
||||
// Request line
|
||||
fmt.Fprintf(&buf, "%s %s %s\r\n", method, serviceURL.String(), icapVersion)
|
||||
|
||||
// Headers
|
||||
host := serviceURL.Host
|
||||
fmt.Fprintf(&buf, "Host: %s\r\n", host)
|
||||
fmt.Fprintf(&buf, "Connection: keep-alive\r\n")
|
||||
fmt.Fprintf(&buf, "Allow: 204\r\n")
|
||||
|
||||
// Build Encapsulated header
|
||||
offset := 0
|
||||
var encapParts []string
|
||||
if reqData != nil {
|
||||
encapParts = append(encapParts, fmt.Sprintf("req-hdr=%d", offset))
|
||||
offset += len(reqData)
|
||||
}
|
||||
if respHdr != nil {
|
||||
encapParts = append(encapParts, fmt.Sprintf("res-hdr=%d", offset))
|
||||
offset += len(respHdr)
|
||||
}
|
||||
if len(respBody) > 0 {
|
||||
encapParts = append(encapParts, fmt.Sprintf("res-body=%d", offset))
|
||||
} else {
|
||||
encapParts = append(encapParts, fmt.Sprintf("null-body=%d", offset))
|
||||
}
|
||||
fmt.Fprintf(&buf, "Encapsulated: %s\r\n", strings.Join(encapParts, ", "))
|
||||
fmt.Fprintf(&buf, "\r\n")
|
||||
|
||||
// Encapsulated sections
|
||||
if reqData != nil {
|
||||
buf.Write(reqData)
|
||||
}
|
||||
if respHdr != nil {
|
||||
buf.Write(respHdr)
|
||||
}
|
||||
// Body in chunked encoding (only when there is an actual body section).
|
||||
// Per RFC 3507 Section 4.4.1, null-body must not include any entity data.
|
||||
if len(respBody) > 0 {
|
||||
fmt.Fprintf(&buf, "%x\r\n", len(respBody))
|
||||
buf.Write(respBody)
|
||||
buf.WriteString("\r\n")
|
||||
buf.WriteString("0\r\n\r\n")
|
||||
}
|
||||
|
||||
_, err := ic.conn.Write(buf.Bytes())
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *ICAPClient) readResponse(ic *icapConn) (int, textproto.MIMEHeader, []byte, error) {
|
||||
if err := ic.conn.SetReadDeadline(time.Now().Add(icapRWTimeout)); err != nil {
|
||||
return 0, nil, nil, fmt.Errorf("set read deadline: %w", err)
|
||||
}
|
||||
|
||||
tp := textproto.NewReader(ic.reader)
|
||||
|
||||
// Status line: "ICAP/1.0 200 OK"
|
||||
statusLine, err := tp.ReadLine()
|
||||
if err != nil {
|
||||
return 0, nil, nil, fmt.Errorf("read status line: %w", err)
|
||||
}
|
||||
|
||||
statusCode, err := parseICAPStatus(statusLine)
|
||||
if err != nil {
|
||||
return 0, nil, nil, err
|
||||
}
|
||||
|
||||
// Headers
|
||||
headers, err := tp.ReadMIMEHeader()
|
||||
if err != nil {
|
||||
return statusCode, nil, nil, fmt.Errorf("read ICAP headers: %w", err)
|
||||
}
|
||||
|
||||
if statusCode == 204 {
|
||||
return statusCode, headers, nil, nil
|
||||
}
|
||||
|
||||
// Read encapsulated body based on Encapsulated header
|
||||
body, err := c.readEncapsulatedBody(ic.reader, headers)
|
||||
if err != nil {
|
||||
return statusCode, headers, nil, fmt.Errorf("read encapsulated body: %w", err)
|
||||
}
|
||||
|
||||
return statusCode, headers, body, nil
|
||||
}
|
||||
|
||||
func (c *ICAPClient) readEncapsulatedBody(r *bufio.Reader, headers textproto.MIMEHeader) ([]byte, error) {
|
||||
encap := headers.Get("Encapsulated")
|
||||
if encap == "" {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Find the body offset from the Encapsulated header.
|
||||
// The last section with a non-zero offset is the body.
|
||||
// Read everything from the reader as the encapsulated content.
|
||||
var totalSize int
|
||||
parts := strings.Split(encap, ",")
|
||||
for _, part := range parts {
|
||||
part = strings.TrimSpace(part)
|
||||
eqIdx := strings.Index(part, "=")
|
||||
if eqIdx < 0 {
|
||||
continue
|
||||
}
|
||||
offset, err := strconv.Atoi(strings.TrimSpace(part[eqIdx+1:]))
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
if offset > totalSize {
|
||||
totalSize = offset
|
||||
}
|
||||
}
|
||||
|
||||
// Read all available encapsulated data (headers + body)
|
||||
// The body section uses chunked encoding per RFC 3507
|
||||
var buf bytes.Buffer
|
||||
if totalSize > 0 {
|
||||
// Read the header sections (everything before the body offset)
|
||||
headerBytes := make([]byte, totalSize)
|
||||
if _, err := io.ReadFull(r, headerBytes); err != nil {
|
||||
return nil, fmt.Errorf("read encapsulated headers: %w", err)
|
||||
}
|
||||
buf.Write(headerBytes)
|
||||
}
|
||||
|
||||
// Read chunked body
|
||||
chunked := newChunkedReader(r)
|
||||
body, err := io.ReadAll(io.LimitReader(chunked, icapMaxRespSize))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read chunked body: %w", err)
|
||||
}
|
||||
buf.Write(body)
|
||||
|
||||
return buf.Bytes(), nil
|
||||
}
|
||||
|
||||
func (c *ICAPClient) getConn(serviceURL *url.URL) (*icapConn, error) {
|
||||
// Try to get a pooled connection
|
||||
for {
|
||||
select {
|
||||
case ic := <-c.pool:
|
||||
if time.Since(ic.lastUse) > icapIdleTimeout {
|
||||
if err := ic.conn.Close(); err != nil {
|
||||
c.log.Debugf("close idle ICAP connection: %v", err)
|
||||
}
|
||||
continue
|
||||
}
|
||||
return ic, nil
|
||||
default:
|
||||
return c.dialConn(serviceURL)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *ICAPClient) putConn(ic *icapConn) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
ic.lastUse = time.Now()
|
||||
select {
|
||||
case c.pool <- ic:
|
||||
default:
|
||||
// Pool full, close connection.
|
||||
if err := ic.conn.Close(); err != nil {
|
||||
c.log.Debugf("close excess ICAP connection: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *ICAPClient) dialConn(serviceURL *url.URL) (*icapConn, error) {
|
||||
host := serviceURL.Host
|
||||
if _, _, err := net.SplitHostPort(host); err != nil {
|
||||
host = net.JoinHostPort(host, icapDefaultPort)
|
||||
}
|
||||
|
||||
conn, err := net.DialTimeout("tcp", host, icapConnTimeout)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("dial ICAP %s: %w", host, err)
|
||||
}
|
||||
|
||||
return &icapConn{
|
||||
conn: conn,
|
||||
reader: bufio.NewReader(conn),
|
||||
lastUse: time.Now(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func parseICAPStatus(line string) (int, error) {
|
||||
// "ICAP/1.0 200 OK"
|
||||
parts := strings.SplitN(line, " ", 3)
|
||||
if len(parts) < 2 {
|
||||
return 0, fmt.Errorf("malformed ICAP status line: %q", line)
|
||||
}
|
||||
code, err := strconv.Atoi(parts[1])
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("parse ICAP status code %q: %w", parts[1], err)
|
||||
}
|
||||
return code, nil
|
||||
}
|
||||
|
||||
// chunkedReader reads ICAP chunked encoding (same as HTTP chunked, terminated by "0\r\n\r\n").
|
||||
type chunkedReader struct {
|
||||
r *bufio.Reader
|
||||
remaining int
|
||||
done bool
|
||||
}
|
||||
|
||||
func newChunkedReader(r *bufio.Reader) *chunkedReader {
|
||||
return &chunkedReader{r: r}
|
||||
}
|
||||
|
||||
func (cr *chunkedReader) Read(p []byte) (int, error) {
|
||||
if cr.done {
|
||||
return 0, io.EOF
|
||||
}
|
||||
|
||||
if cr.remaining == 0 {
|
||||
// Read chunk size line
|
||||
line, err := cr.r.ReadString('\n')
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
line = strings.TrimSpace(line)
|
||||
|
||||
// Strip any chunk extensions
|
||||
if idx := strings.Index(line, ";"); idx >= 0 {
|
||||
line = line[:idx]
|
||||
}
|
||||
|
||||
size, err := strconv.ParseInt(line, 16, 64)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("parse chunk size %q: %w", line, err)
|
||||
}
|
||||
|
||||
if size == 0 {
|
||||
cr.done = true
|
||||
// Consume trailing \r\n
|
||||
_, _ = cr.r.ReadString('\n')
|
||||
return 0, io.EOF
|
||||
}
|
||||
|
||||
if size < 0 || size > icapMaxRespSize {
|
||||
return 0, fmt.Errorf("chunk size %d out of range (max %d)", size, icapMaxRespSize)
|
||||
}
|
||||
|
||||
cr.remaining = int(size)
|
||||
}
|
||||
|
||||
toRead := len(p)
|
||||
if toRead > cr.remaining {
|
||||
toRead = cr.remaining
|
||||
}
|
||||
|
||||
n, err := cr.r.Read(p[:toRead])
|
||||
cr.remaining -= n
|
||||
|
||||
if cr.remaining == 0 {
|
||||
// Consume chunk-terminating \r\n
|
||||
_, _ = cr.r.ReadString('\n')
|
||||
}
|
||||
|
||||
return n, err
|
||||
}
|
||||
@@ -1,21 +0,0 @@
|
||||
//go:build !linux
|
||||
|
||||
package inspect
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// newTPROXYListener is not supported on non-Linux platforms.
|
||||
func newTPROXYListener(_ *log.Entry, addr netip.AddrPort, _ netip.Prefix) (net.Listener, error) {
|
||||
return nil, fmt.Errorf("TPROXY listener not supported on this platform (requested %s)", addr)
|
||||
}
|
||||
|
||||
// getOriginalDst is not supported on non-Linux platforms.
|
||||
func getOriginalDst(_ net.Conn) (netip.AddrPort, error) {
|
||||
return netip.AddrPort{}, fmt.Errorf("SO_ORIGINAL_DST not supported on this platform")
|
||||
}
|
||||
@@ -1,89 +0,0 @@
|
||||
package inspect
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"unsafe"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
// newTPROXYListener creates a TCP listener for the transparent proxy.
|
||||
// After nftables REDIRECT, accepted connections have LocalAddr = WG_IP:proxy_port.
|
||||
// The original destination is retrieved via getsockopt(SO_ORIGINAL_DST).
|
||||
func newTPROXYListener(logger *log.Entry, addr netip.AddrPort, _ netip.Prefix) (net.Listener, error) {
|
||||
ln, err := net.Listen("tcp", addr.String())
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("listen on %s: %w", addr, err)
|
||||
}
|
||||
|
||||
logger.Infof("inspect: listener started on %s", ln.Addr())
|
||||
return ln, nil
|
||||
}
|
||||
|
||||
// getOriginalDst reads the original destination from conntrack via SO_ORIGINAL_DST.
|
||||
// This is set by the kernel when the connection was REDIRECT'd/DNAT'd.
|
||||
// Tries IPv4 first, then falls back to IPv6 (IP6T_SO_ORIGINAL_DST).
|
||||
func getOriginalDst(conn net.Conn) (netip.AddrPort, error) {
|
||||
tc, ok := conn.(*net.TCPConn)
|
||||
if !ok {
|
||||
return netip.AddrPort{}, fmt.Errorf("not a TCPConn")
|
||||
}
|
||||
|
||||
raw, err := tc.SyscallConn()
|
||||
if err != nil {
|
||||
return netip.AddrPort{}, fmt.Errorf("get syscall conn: %w", err)
|
||||
}
|
||||
|
||||
var origDst netip.AddrPort
|
||||
var sockErr error
|
||||
if err := raw.Control(func(fd uintptr) {
|
||||
// Try IPv4 first (SO_ORIGINAL_DST = 80)
|
||||
var sa4 unix.RawSockaddrInet4
|
||||
sa4Len := uint32(unsafe.Sizeof(sa4))
|
||||
_, _, errno := unix.Syscall6(
|
||||
unix.SYS_GETSOCKOPT,
|
||||
fd,
|
||||
unix.SOL_IP,
|
||||
80, // SO_ORIGINAL_DST
|
||||
uintptr(unsafe.Pointer(&sa4)),
|
||||
uintptr(unsafe.Pointer(&sa4Len)),
|
||||
0,
|
||||
)
|
||||
if errno == 0 {
|
||||
addr := netip.AddrFrom4(sa4.Addr)
|
||||
port := uint16(sa4.Port>>8) | uint16(sa4.Port<<8)
|
||||
origDst = netip.AddrPortFrom(addr.Unmap(), port)
|
||||
return
|
||||
}
|
||||
|
||||
// Fall back to IPv6 (IP6T_SO_ORIGINAL_DST = 80 on SOL_IPV6)
|
||||
var sa6 unix.RawSockaddrInet6
|
||||
sa6Len := uint32(unsafe.Sizeof(sa6))
|
||||
_, _, errno = unix.Syscall6(
|
||||
unix.SYS_GETSOCKOPT,
|
||||
fd,
|
||||
unix.SOL_IPV6,
|
||||
80, // IP6T_SO_ORIGINAL_DST
|
||||
uintptr(unsafe.Pointer(&sa6)),
|
||||
uintptr(unsafe.Pointer(&sa6Len)),
|
||||
0,
|
||||
)
|
||||
if errno != 0 {
|
||||
sockErr = fmt.Errorf("getsockopt SO_ORIGINAL_DST (v4 and v6): %w", errno)
|
||||
return
|
||||
}
|
||||
addr := netip.AddrFrom16(sa6.Addr)
|
||||
port := uint16(sa6.Port>>8) | uint16(sa6.Port<<8)
|
||||
origDst = netip.AddrPortFrom(addr.Unmap(), port)
|
||||
}); err != nil {
|
||||
return netip.AddrPort{}, fmt.Errorf("control raw conn: %w", err)
|
||||
}
|
||||
if sockErr != nil {
|
||||
return netip.AddrPort{}, sockErr
|
||||
}
|
||||
|
||||
return origDst, nil
|
||||
}
|
||||
@@ -1,200 +0,0 @@
|
||||
package inspect
|
||||
|
||||
import (
|
||||
"crypto"
|
||||
"crypto/ecdsa"
|
||||
"crypto/elliptic"
|
||||
"crypto/rand"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"crypto/x509/pkix"
|
||||
"fmt"
|
||||
"math/big"
|
||||
mrand "math/rand/v2"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
// certCacheSize is the maximum number of cached leaf certificates.
|
||||
certCacheSize = 1024
|
||||
// certTTL is how long generated certificates remain valid.
|
||||
certTTL = 24 * time.Hour
|
||||
)
|
||||
|
||||
// certCache is a bounded LRU cache for generated TLS certificates.
|
||||
type certCache struct {
|
||||
mu sync.Mutex
|
||||
entries map[string]*certEntry
|
||||
// order tracks LRU eviction, most recent at end.
|
||||
order []string
|
||||
maxSize int
|
||||
}
|
||||
|
||||
type certEntry struct {
|
||||
cert *tls.Certificate
|
||||
expiresAt time.Time
|
||||
}
|
||||
|
||||
func newCertCache(maxSize int) *certCache {
|
||||
return &certCache{
|
||||
entries: make(map[string]*certEntry, maxSize),
|
||||
order: make([]string, 0, maxSize),
|
||||
maxSize: maxSize,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *certCache) get(hostname string) (*tls.Certificate, bool) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
entry, ok := c.entries[hostname]
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
if time.Now().After(entry.expiresAt) {
|
||||
c.removeLocked(hostname)
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// Move to end (most recently used)
|
||||
c.touchLocked(hostname)
|
||||
return entry.cert, true
|
||||
}
|
||||
|
||||
func (c *certCache) put(hostname string, cert *tls.Certificate) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
// Jitter the TTL by +/- 20% to prevent thundering herd on expiry.
|
||||
jitter := time.Duration(float64(certTTL) * (0.8 + 0.4*mrand.Float64()))
|
||||
|
||||
if _, exists := c.entries[hostname]; exists {
|
||||
c.entries[hostname] = &certEntry{
|
||||
cert: cert,
|
||||
expiresAt: time.Now().Add(jitter),
|
||||
}
|
||||
c.touchLocked(hostname)
|
||||
return
|
||||
}
|
||||
|
||||
// Evict oldest if at capacity
|
||||
for len(c.entries) >= c.maxSize && len(c.order) > 0 {
|
||||
c.removeLocked(c.order[0])
|
||||
}
|
||||
|
||||
c.entries[hostname] = &certEntry{
|
||||
cert: cert,
|
||||
expiresAt: time.Now().Add(jitter),
|
||||
}
|
||||
c.order = append(c.order, hostname)
|
||||
}
|
||||
|
||||
func (c *certCache) touchLocked(hostname string) {
|
||||
for i, h := range c.order {
|
||||
if h == hostname {
|
||||
c.order = append(c.order[:i], c.order[i+1:]...)
|
||||
c.order = append(c.order, hostname)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *certCache) removeLocked(hostname string) {
|
||||
delete(c.entries, hostname)
|
||||
for i, h := range c.order {
|
||||
if h == hostname {
|
||||
c.order = append(c.order[:i], c.order[i+1:]...)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// CertProvider generates TLS certificates on the fly, signed by a CA.
|
||||
// Generated certificates are cached in an LRU cache.
|
||||
type CertProvider struct {
|
||||
ca *x509.Certificate
|
||||
caKey crypto.PrivateKey
|
||||
cache *certCache
|
||||
}
|
||||
|
||||
// NewCertProvider creates a certificate provider using the given CA.
|
||||
func NewCertProvider(ca *x509.Certificate, caKey crypto.PrivateKey) *CertProvider {
|
||||
return &CertProvider{
|
||||
ca: ca,
|
||||
caKey: caKey,
|
||||
cache: newCertCache(certCacheSize),
|
||||
}
|
||||
}
|
||||
|
||||
// GetCertificate returns a TLS certificate for the given hostname,
|
||||
// generating and caching one if necessary.
|
||||
func (p *CertProvider) GetCertificate(hostname string) (*tls.Certificate, error) {
|
||||
if cert, ok := p.cache.get(hostname); ok {
|
||||
return cert, nil
|
||||
}
|
||||
|
||||
cert, err := p.generateCert(hostname)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("generate cert for %s: %w", hostname, err)
|
||||
}
|
||||
|
||||
p.cache.put(hostname, cert)
|
||||
return cert, nil
|
||||
}
|
||||
|
||||
// GetTLSConfig returns a tls.Config that dynamically provides certificates
|
||||
// for any hostname using the MITM CA.
|
||||
func (p *CertProvider) GetTLSConfig() *tls.Config {
|
||||
return &tls.Config{
|
||||
GetCertificate: func(hello *tls.ClientHelloInfo) (*tls.Certificate, error) {
|
||||
return p.GetCertificate(hello.ServerName)
|
||||
},
|
||||
NextProtos: []string{"h2", "http/1.1"},
|
||||
MinVersion: tls.VersionTLS12,
|
||||
}
|
||||
}
|
||||
|
||||
func (p *CertProvider) generateCert(hostname string) (*tls.Certificate, error) {
|
||||
serialNumber, err := rand.Int(rand.Reader, new(big.Int).Lsh(big.NewInt(1), 128))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("generate serial number: %w", err)
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
template := &x509.Certificate{
|
||||
SerialNumber: serialNumber,
|
||||
Subject: pkix.Name{
|
||||
CommonName: hostname,
|
||||
},
|
||||
NotBefore: now.Add(-5 * time.Minute),
|
||||
NotAfter: now.Add(certTTL),
|
||||
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment,
|
||||
ExtKeyUsage: []x509.ExtKeyUsage{
|
||||
x509.ExtKeyUsageServerAuth,
|
||||
},
|
||||
DNSNames: []string{hostname},
|
||||
}
|
||||
|
||||
leafKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("generate leaf key: %w", err)
|
||||
}
|
||||
|
||||
certDER, err := x509.CreateCertificate(rand.Reader, template, p.ca, &leafKey.PublicKey, p.caKey)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("sign leaf certificate: %w", err)
|
||||
}
|
||||
|
||||
leafCert, err := x509.ParseCertificate(certDER)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parse generated certificate: %w", err)
|
||||
}
|
||||
|
||||
return &tls.Certificate{
|
||||
Certificate: [][]byte{certDER, p.ca.Raw},
|
||||
PrivateKey: leafKey,
|
||||
Leaf: leafCert,
|
||||
}, nil
|
||||
}
|
||||
@@ -1,133 +0,0 @@
|
||||
package inspect
|
||||
|
||||
import (
|
||||
"crypto/ecdsa"
|
||||
"crypto/elliptic"
|
||||
"crypto/rand"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"crypto/x509/pkix"
|
||||
"math/big"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func generateTestCA(t *testing.T) (*x509.Certificate, *ecdsa.PrivateKey) {
|
||||
t.Helper()
|
||||
|
||||
key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||
require.NoError(t, err)
|
||||
|
||||
template := &x509.Certificate{
|
||||
SerialNumber: big.NewInt(1),
|
||||
Subject: pkix.Name{
|
||||
CommonName: "Test CA",
|
||||
},
|
||||
NotBefore: time.Now().Add(-time.Hour),
|
||||
NotAfter: time.Now().Add(24 * time.Hour),
|
||||
KeyUsage: x509.KeyUsageCertSign | x509.KeyUsageCRLSign,
|
||||
BasicConstraintsValid: true,
|
||||
IsCA: true,
|
||||
}
|
||||
|
||||
certDER, err := x509.CreateCertificate(rand.Reader, template, template, &key.PublicKey, key)
|
||||
require.NoError(t, err)
|
||||
|
||||
cert, err := x509.ParseCertificate(certDER)
|
||||
require.NoError(t, err)
|
||||
|
||||
return cert, key
|
||||
}
|
||||
|
||||
func TestCertProvider_GetCertificate(t *testing.T) {
|
||||
ca, caKey := generateTestCA(t)
|
||||
provider := NewCertProvider(ca, caKey)
|
||||
|
||||
cert, err := provider.GetCertificate("example.com")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, cert)
|
||||
|
||||
// Verify the leaf certificate
|
||||
assert.Equal(t, "example.com", cert.Leaf.Subject.CommonName)
|
||||
assert.Contains(t, cert.Leaf.DNSNames, "example.com")
|
||||
|
||||
// Verify chain: leaf + CA
|
||||
assert.Len(t, cert.Certificate, 2)
|
||||
|
||||
// Verify leaf is signed by our CA
|
||||
pool := x509.NewCertPool()
|
||||
pool.AddCert(ca)
|
||||
_, err = cert.Leaf.Verify(x509.VerifyOptions{
|
||||
Roots: pool,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestCertProvider_CachesResults(t *testing.T) {
|
||||
ca, caKey := generateTestCA(t)
|
||||
provider := NewCertProvider(ca, caKey)
|
||||
|
||||
cert1, err := provider.GetCertificate("cached.example.com")
|
||||
require.NoError(t, err)
|
||||
|
||||
cert2, err := provider.GetCertificate("cached.example.com")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Same pointer = cached
|
||||
assert.Equal(t, cert1, cert2)
|
||||
}
|
||||
|
||||
func TestCertProvider_DifferentHostsDifferentCerts(t *testing.T) {
|
||||
ca, caKey := generateTestCA(t)
|
||||
provider := NewCertProvider(ca, caKey)
|
||||
|
||||
cert1, err := provider.GetCertificate("a.example.com")
|
||||
require.NoError(t, err)
|
||||
|
||||
cert2, err := provider.GetCertificate("b.example.com")
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.NotEqual(t, cert1.Leaf.SerialNumber, cert2.Leaf.SerialNumber)
|
||||
}
|
||||
|
||||
func TestCertProvider_TLSConfigHandshake(t *testing.T) {
|
||||
ca, caKey := generateTestCA(t)
|
||||
provider := NewCertProvider(ca, caKey)
|
||||
|
||||
tlsConfig := provider.GetTLSConfig()
|
||||
require.NotNil(t, tlsConfig)
|
||||
require.NotNil(t, tlsConfig.GetCertificate)
|
||||
|
||||
// Simulate a ClientHelloInfo
|
||||
hello := &tls.ClientHelloInfo{
|
||||
ServerName: "handshake.example.com",
|
||||
}
|
||||
|
||||
cert, err := tlsConfig.GetCertificate(hello)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "handshake.example.com", cert.Leaf.Subject.CommonName)
|
||||
}
|
||||
|
||||
func TestCertCache_Eviction(t *testing.T) {
|
||||
cache := newCertCache(3)
|
||||
|
||||
for i := range 5 {
|
||||
hostname := string(rune('a'+i)) + ".example.com"
|
||||
cache.put(hostname, &tls.Certificate{})
|
||||
}
|
||||
|
||||
// Only 3 should remain (c, d, e - the most recent)
|
||||
assert.Len(t, cache.entries, 3)
|
||||
|
||||
_, ok := cache.get("a.example.com")
|
||||
assert.False(t, ok, "oldest entry should be evicted")
|
||||
|
||||
_, ok = cache.get("b.example.com")
|
||||
assert.False(t, ok, "second oldest should be evicted")
|
||||
|
||||
_, ok = cache.get("e.example.com")
|
||||
assert.True(t, ok, "newest entry should exist")
|
||||
}
|
||||
@@ -1,109 +0,0 @@
|
||||
package inspect
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
)
|
||||
|
||||
// peekConn wraps a net.Conn with a buffer that allows reading ahead
|
||||
// without consuming data. Subsequent Read calls return the buffered
|
||||
// bytes first, then read from the underlying connection.
|
||||
type peekConn struct {
|
||||
net.Conn
|
||||
buf bytes.Buffer
|
||||
// peeked holds the raw bytes that were peeked, available for replay.
|
||||
peeked []byte
|
||||
}
|
||||
|
||||
// newPeekConn wraps conn for peek-ahead reading.
|
||||
func newPeekConn(conn net.Conn) *peekConn {
|
||||
return &peekConn{Conn: conn}
|
||||
}
|
||||
|
||||
// Peek reads exactly n bytes from the connection without consuming them.
|
||||
// The peeked bytes are replayed on subsequent Read calls.
|
||||
// Peek may only be called once; calling it again returns an error.
|
||||
func (c *peekConn) Peek(n int) ([]byte, error) {
|
||||
if c.peeked != nil {
|
||||
return nil, fmt.Errorf("peek already called")
|
||||
}
|
||||
|
||||
buf := make([]byte, n)
|
||||
if _, err := io.ReadFull(c.Conn, buf); err != nil {
|
||||
return nil, fmt.Errorf("peek %d bytes: %w", n, err)
|
||||
}
|
||||
|
||||
c.peeked = buf
|
||||
c.buf.Write(buf)
|
||||
|
||||
return buf, nil
|
||||
}
|
||||
|
||||
// PeekAll reads up to n bytes, returning whatever is available.
|
||||
// Unlike Peek, it does not require exactly n bytes.
|
||||
func (c *peekConn) PeekAll(n int) ([]byte, error) {
|
||||
if c.peeked != nil {
|
||||
return nil, fmt.Errorf("peek already called")
|
||||
}
|
||||
|
||||
buf := make([]byte, n)
|
||||
nr, err := c.Conn.Read(buf)
|
||||
if nr > 0 {
|
||||
c.peeked = buf[:nr]
|
||||
c.buf.Write(c.peeked)
|
||||
}
|
||||
if err != nil && nr == 0 {
|
||||
return nil, fmt.Errorf("peek: %w", err)
|
||||
}
|
||||
|
||||
return c.peeked, nil
|
||||
}
|
||||
|
||||
// PeekMore extends the peeked buffer to at least n total bytes.
|
||||
// The buffer is reset and refilled with the extended data.
|
||||
// The returned slice is the internal peeked buffer; callers must not
|
||||
// retain references from prior Peek/PeekMore calls after calling this.
|
||||
func (c *peekConn) PeekMore(n int) ([]byte, error) {
|
||||
if len(c.peeked) >= n {
|
||||
return c.peeked[:n], nil
|
||||
}
|
||||
|
||||
remaining := n - len(c.peeked)
|
||||
extra := make([]byte, remaining)
|
||||
if _, err := io.ReadFull(c.Conn, extra); err != nil {
|
||||
return nil, fmt.Errorf("peek more %d bytes: %w", remaining, err)
|
||||
}
|
||||
|
||||
// Pre-allocate to avoid reallocation detaching previously returned slices.
|
||||
combined := make([]byte, 0, n)
|
||||
combined = append(combined, c.peeked...)
|
||||
combined = append(combined, extra...)
|
||||
c.peeked = combined
|
||||
c.buf.Reset()
|
||||
c.buf.Write(c.peeked)
|
||||
|
||||
return c.peeked, nil
|
||||
}
|
||||
|
||||
// Peeked returns the bytes that were peeked so far, or nil if Peek hasn't been called.
|
||||
func (c *peekConn) Peeked() []byte {
|
||||
return c.peeked
|
||||
}
|
||||
|
||||
// Read returns buffered peek data first, then reads from the underlying connection.
|
||||
func (c *peekConn) Read(p []byte) (int, error) {
|
||||
if c.buf.Len() > 0 {
|
||||
return c.buf.Read(p)
|
||||
}
|
||||
return c.Conn.Read(p)
|
||||
}
|
||||
|
||||
// reader returns an io.Reader that replays buffered bytes then reads from conn.
|
||||
func (c *peekConn) reader() io.Reader {
|
||||
if c.buf.Len() > 0 {
|
||||
return io.MultiReader(&c.buf, c.Conn)
|
||||
}
|
||||
return c.Conn
|
||||
}
|
||||
@@ -1,482 +0,0 @@
|
||||
package inspect
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// ErrBlocked is returned when a connection is denied by proxy policy.
|
||||
var ErrBlocked = errors.New("connection blocked by proxy policy")
|
||||
|
||||
const (
|
||||
// headerReadTimeout is the deadline for reading the initial protocol header.
|
||||
// Prevents slow loris attacks where a client opens a connection but sends data slowly.
|
||||
headerReadTimeout = 10 * time.Second
|
||||
|
||||
// idleTimeout is the deadline for idle connections between HTTP requests.
|
||||
idleTimeout = 120 * time.Second
|
||||
)
|
||||
|
||||
// Proxy is the inspection engine for traffic passing through a NetBird
|
||||
// routing peer. It handles protocol detection, rule evaluation, MITM TLS
|
||||
// decryption, ICAP delegation, and external proxy forwarding.
|
||||
type Proxy struct {
|
||||
config Config
|
||||
rules *RuleEngine
|
||||
certs *CertProvider
|
||||
icap *ICAPClient
|
||||
// envoy is nil unless mode is ModeEnvoy.
|
||||
envoy *envoyManager
|
||||
// dialer is the outbound dialer (with SO_MARK cleared on Linux).
|
||||
dialer net.Dialer
|
||||
log *log.Entry
|
||||
// wgNetwork is the WG overlay prefix; dial targets inside it are blocked.
|
||||
wgNetwork netip.Prefix
|
||||
// localIPs reports the routing peer's own IPs; dial targets are blocked.
|
||||
localIPs LocalIPChecker
|
||||
// listener is the TPROXY/REDIRECT listener for kernel mode.
|
||||
listener net.Listener
|
||||
|
||||
mu sync.RWMutex
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
}
|
||||
|
||||
// LocalIPChecker reports whether an IP belongs to the local machine.
|
||||
type LocalIPChecker interface {
|
||||
IsLocalIP(netip.Addr) bool
|
||||
}
|
||||
|
||||
// New creates a transparent proxy with the given configuration.
|
||||
func New(ctx context.Context, logger *log.Entry, config Config) (*Proxy, error) {
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
|
||||
p := &Proxy{
|
||||
config: config,
|
||||
rules: NewRuleEngine(logger, config.DefaultAction),
|
||||
dialer: newOutboundDialer(),
|
||||
log: logger,
|
||||
wgNetwork: config.WGNetwork,
|
||||
localIPs: config.LocalIPChecker,
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
}
|
||||
|
||||
p.rules.UpdateRules(config.Rules, config.DefaultAction)
|
||||
|
||||
// Initialize MITM certificate provider
|
||||
if config.TLS != nil {
|
||||
p.certs = NewCertProvider(config.TLS.CA, config.TLS.CAKey)
|
||||
}
|
||||
|
||||
// Initialize ICAP client
|
||||
if config.ICAP != nil {
|
||||
p.icap = NewICAPClient(logger, config.ICAP)
|
||||
}
|
||||
|
||||
// Start envoy sidecar if configured
|
||||
if config.Mode == ModeEnvoy {
|
||||
envoyLog := logger.WithField("sidecar", "envoy")
|
||||
em, err := startEnvoy(ctx, envoyLog, config)
|
||||
if err != nil {
|
||||
cancel()
|
||||
return nil, fmt.Errorf("start envoy sidecar: %w", err)
|
||||
}
|
||||
p.envoy = em
|
||||
}
|
||||
|
||||
// Start TPROXY listener for kernel mode
|
||||
if config.ListenAddr.IsValid() {
|
||||
ln, err := newTPROXYListener(logger, config.ListenAddr, netip.Prefix{})
|
||||
if err != nil {
|
||||
cancel()
|
||||
return nil, fmt.Errorf("start TPROXY listener on %s: %w", config.ListenAddr, err)
|
||||
}
|
||||
p.listener = ln
|
||||
go p.acceptLoop(ln)
|
||||
}
|
||||
|
||||
return p, nil
|
||||
}
|
||||
|
||||
// HandleTCP is the entry point for TCP connections from the userspace forwarder.
|
||||
// It determines the protocol (TLS or plaintext HTTP), evaluates rules,
|
||||
// and either blocks, passes through, inspects, or forwards to an external proxy.
|
||||
func (p *Proxy) HandleTCP(ctx context.Context, clientConn net.Conn, dst netip.AddrPort, src SourceInfo) error {
|
||||
defer func() {
|
||||
if err := clientConn.Close(); err != nil {
|
||||
p.log.Debugf("close client conn: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
p.mu.RLock()
|
||||
mode := p.config.Mode
|
||||
p.mu.RUnlock()
|
||||
|
||||
if mode == ModeExternal {
|
||||
pconn := newPeekConn(clientConn)
|
||||
return p.handleExternal(ctx, pconn, dst)
|
||||
}
|
||||
|
||||
// Envoy and builtin modes both peek the protocol header for rule evaluation.
|
||||
// Envoy mode forwards non-blocked traffic to envoy; builtin mode handles all locally.
|
||||
// TLS blocks are handled by Go (instant close) since envoy can't cleanly RST a TLS connection.
|
||||
|
||||
// Built-in and envoy mode: peek 5 bytes (TLS record header size) to determine protocol.
|
||||
// Set a read deadline to prevent slow loris attacks.
|
||||
if err := clientConn.SetReadDeadline(time.Now().Add(headerReadTimeout)); err != nil {
|
||||
return fmt.Errorf("set read deadline: %w", err)
|
||||
}
|
||||
pconn := newPeekConn(clientConn)
|
||||
header, err := pconn.Peek(5)
|
||||
if err != nil {
|
||||
return fmt.Errorf("peek protocol header: %w", err)
|
||||
}
|
||||
if err := clientConn.SetReadDeadline(time.Time{}); err != nil {
|
||||
return fmt.Errorf("clear read deadline: %w", err)
|
||||
}
|
||||
|
||||
if isTLSHandshake(header[0]) {
|
||||
return p.handleTLS(ctx, pconn, dst, src)
|
||||
}
|
||||
|
||||
if isHTTPMethod(header) {
|
||||
return p.handlePlainHTTP(ctx, pconn, dst, src)
|
||||
}
|
||||
|
||||
// Not TLS and not HTTP: evaluate rules with ProtoOther.
|
||||
// If no rule explicitly allows "other", this falls through to the default action.
|
||||
action := p.rules.Evaluate(src.IP, "", dst.Addr(), dst.Port(), ProtoOther, "")
|
||||
if action == ActionAllow {
|
||||
remote, err := p.dialTCP(ctx, dst)
|
||||
if err != nil {
|
||||
return fmt.Errorf("dial for passthrough: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
if err := remote.Close(); err != nil {
|
||||
p.log.Debugf("close remote conn: %v", err)
|
||||
}
|
||||
}()
|
||||
return relay(ctx, pconn, remote)
|
||||
}
|
||||
|
||||
p.log.Debugf("block: non-HTTP/TLS to %s (action=%s, first bytes: %x)", dst, action, header)
|
||||
return ErrBlocked
|
||||
}
|
||||
|
||||
// InspectTCP evaluates rules for a TCP connection and returns the result.
|
||||
// Unlike HandleTCP, it can return early for allow decisions, letting the caller
|
||||
// handle the relay (USP forwarder passthrough optimization).
|
||||
//
|
||||
// When InspectResult.PassthroughConn is non-nil, ownership transfers to the caller:
|
||||
// the caller must close the connection and relay traffic. The engine does not close it.
|
||||
//
|
||||
// When PassthroughConn is nil, the engine handled everything internally
|
||||
// (block, inspect/MITM, or plain HTTP inspection) and closed the connection.
|
||||
func (p *Proxy) InspectTCP(ctx context.Context, clientConn net.Conn, dst netip.AddrPort, src SourceInfo) (InspectResult, error) {
|
||||
p.mu.RLock()
|
||||
mode := p.config.Mode
|
||||
envoy := p.envoy
|
||||
p.mu.RUnlock()
|
||||
|
||||
// External mode: handle internally, engine owns the connection.
|
||||
if mode == ModeExternal {
|
||||
defer func() {
|
||||
if err := clientConn.Close(); err != nil {
|
||||
p.log.Debugf("close client conn: %v", err)
|
||||
}
|
||||
}()
|
||||
pconn := newPeekConn(clientConn)
|
||||
err := p.handleExternal(ctx, pconn, dst)
|
||||
return InspectResult{Action: ActionAllow}, err
|
||||
}
|
||||
|
||||
// Peek protocol header.
|
||||
if err := clientConn.SetReadDeadline(time.Now().Add(headerReadTimeout)); err != nil {
|
||||
clientConn.Close()
|
||||
return InspectResult{}, fmt.Errorf("set read deadline: %w", err)
|
||||
}
|
||||
pconn := newPeekConn(clientConn)
|
||||
header, err := pconn.Peek(5)
|
||||
if err != nil {
|
||||
clientConn.Close()
|
||||
return InspectResult{}, fmt.Errorf("peek protocol header: %w", err)
|
||||
}
|
||||
if err := clientConn.SetReadDeadline(time.Time{}); err != nil {
|
||||
clientConn.Close()
|
||||
return InspectResult{}, fmt.Errorf("clear read deadline: %w", err)
|
||||
}
|
||||
|
||||
// TLS: may return passthrough for allow.
|
||||
if isTLSHandshake(header[0]) {
|
||||
result, err := p.inspectTLS(ctx, pconn, dst, src)
|
||||
if err != nil && result.PassthroughConn == nil {
|
||||
clientConn.Close()
|
||||
return result, err
|
||||
}
|
||||
// Envoy mode: forward allowed TLS to envoy instead of returning passthrough.
|
||||
if result.PassthroughConn != nil && envoy != nil {
|
||||
defer clientConn.Close()
|
||||
envoyErr := p.forwardToEnvoy(ctx, pconn, dst, src, envoy)
|
||||
return InspectResult{Action: ActionAllow}, envoyErr
|
||||
}
|
||||
return result, err
|
||||
}
|
||||
|
||||
// Plain HTTP: in envoy mode, forward to envoy for L7 processing.
|
||||
// In builtin mode, inspect per-request locally.
|
||||
if isHTTPMethod(header) {
|
||||
defer func() {
|
||||
if err := clientConn.Close(); err != nil {
|
||||
p.log.Debugf("close client conn: %v", err)
|
||||
}
|
||||
}()
|
||||
if envoy != nil {
|
||||
err := p.forwardToEnvoy(ctx, pconn, dst, src, envoy)
|
||||
return InspectResult{Action: ActionAllow}, err
|
||||
}
|
||||
err := p.handlePlainHTTP(ctx, pconn, dst, src)
|
||||
return InspectResult{Action: ActionInspect}, err
|
||||
}
|
||||
|
||||
// Other protocol: evaluate rules.
|
||||
action := p.rules.Evaluate(src.IP, "", dst.Addr(), dst.Port(), ProtoOther, "")
|
||||
if action == ActionAllow {
|
||||
// Envoy mode: forward to envoy.
|
||||
if envoy != nil {
|
||||
defer clientConn.Close()
|
||||
err := p.forwardToEnvoy(ctx, pconn, dst, src, envoy)
|
||||
return InspectResult{Action: ActionAllow}, err
|
||||
}
|
||||
return InspectResult{Action: ActionAllow, PassthroughConn: pconn}, nil
|
||||
}
|
||||
|
||||
p.log.Debugf("block: non-HTTP/TLS to %s (action=%s, first bytes: %x)", dst, action, header)
|
||||
clientConn.Close()
|
||||
return InspectResult{Action: ActionBlock}, ErrBlocked
|
||||
}
|
||||
|
||||
// HandleUDPPacket inspects a UDP packet for QUIC Initial packets.
|
||||
// Returns the action to take: ActionAllow to continue normal forwarding,
|
||||
// ActionBlock to drop the packet.
|
||||
// Non-QUIC packets always return ActionAllow.
|
||||
func (p *Proxy) HandleUDPPacket(data []byte, dst netip.AddrPort, src SourceInfo) Action {
|
||||
if len(data) < 5 {
|
||||
return ActionAllow
|
||||
}
|
||||
|
||||
// Check for QUIC Long Header
|
||||
if data[0]&0x80 == 0 {
|
||||
return ActionAllow
|
||||
}
|
||||
|
||||
sni, err := ExtractQUICSNI(data)
|
||||
if err != nil {
|
||||
// Can't parse QUIC, allow through (could be non-QUIC UDP)
|
||||
p.log.Tracef("QUIC SNI extraction failed for %s: %v", dst, err)
|
||||
return ActionAllow
|
||||
}
|
||||
|
||||
if sni == "" {
|
||||
return ActionAllow
|
||||
}
|
||||
|
||||
action := p.rules.Evaluate(src.IP, sni, dst.Addr(), dst.Port(), ProtoH3, "")
|
||||
|
||||
if action == ActionBlock {
|
||||
p.log.Debugf("block: QUIC to %s (SNI=%s)", dst, sni.PunycodeString())
|
||||
return ActionBlock
|
||||
}
|
||||
|
||||
// QUIC can't be MITMed, treat Inspect as Allow
|
||||
if action == ActionInspect {
|
||||
p.log.Debugf("allow: QUIC to %s (SNI=%s), MITM not supported for QUIC", dst, sni.PunycodeString())
|
||||
} else {
|
||||
p.log.Tracef("allow: QUIC to %s (SNI=%s)", dst, sni.PunycodeString())
|
||||
}
|
||||
|
||||
return ActionAllow
|
||||
}
|
||||
|
||||
// handlePlainHTTP handles plaintext HTTP connections.
|
||||
func (p *Proxy) handlePlainHTTP(ctx context.Context, pconn *peekConn, dst netip.AddrPort, src SourceInfo) error {
|
||||
remote, err := p.dialTCP(ctx, dst)
|
||||
if err != nil {
|
||||
return fmt.Errorf("dial %s: %w", dst, err)
|
||||
}
|
||||
defer func() {
|
||||
if err := remote.Close(); err != nil {
|
||||
p.log.Debugf("close remote for %s: %v", dst, err)
|
||||
}
|
||||
}()
|
||||
|
||||
// For plaintext HTTP, always inspect (we can see the traffic)
|
||||
return p.inspectHTTP(ctx, pconn, remote, dst, "", src, "http/1.1")
|
||||
}
|
||||
|
||||
// UpdateConfig replaces the inspection engine configuration at runtime.
|
||||
func (p *Proxy) UpdateConfig(config Config) {
|
||||
p.log.Debugf("config update: mode=%s rules=%d default=%s has_tls=%v has_icap=%v",
|
||||
config.Mode, len(config.Rules), config.DefaultAction, config.TLS != nil, config.ICAP != nil)
|
||||
|
||||
p.mu.Lock()
|
||||
|
||||
p.config = config
|
||||
p.rules.UpdateRules(config.Rules, config.DefaultAction)
|
||||
|
||||
// Update MITM provider
|
||||
if config.TLS != nil {
|
||||
p.certs = NewCertProvider(config.TLS.CA, config.TLS.CAKey)
|
||||
} else {
|
||||
p.certs = nil
|
||||
}
|
||||
|
||||
// Swap ICAP client under lock, close the old one outside to avoid blocking.
|
||||
var oldICAP *ICAPClient
|
||||
if config.ICAP != nil {
|
||||
oldICAP = p.icap
|
||||
p.icap = NewICAPClient(p.log, config.ICAP)
|
||||
} else {
|
||||
oldICAP = p.icap
|
||||
p.icap = nil
|
||||
}
|
||||
|
||||
// If switching away from envoy mode, clear and stop the old envoy.
|
||||
var oldEnvoy *envoyManager
|
||||
if config.Mode != ModeEnvoy && p.envoy != nil {
|
||||
oldEnvoy = p.envoy
|
||||
p.envoy = nil
|
||||
}
|
||||
|
||||
envoy := p.envoy
|
||||
|
||||
p.mu.Unlock()
|
||||
|
||||
if oldICAP != nil {
|
||||
oldICAP.Close()
|
||||
}
|
||||
|
||||
if oldEnvoy != nil {
|
||||
oldEnvoy.Stop()
|
||||
}
|
||||
|
||||
// Reload envoy config if still in envoy mode.
|
||||
if envoy != nil && config.Mode == ModeEnvoy {
|
||||
if err := envoy.Reload(config); err != nil {
|
||||
p.log.Errorf("inspect: envoy config reload: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Mode returns the current proxy operating mode.
|
||||
func (p *Proxy) Mode() ProxyMode {
|
||||
p.mu.RLock()
|
||||
defer p.mu.RUnlock()
|
||||
return p.config.Mode
|
||||
}
|
||||
|
||||
// ListenPort returns the port to use for kernel-mode nftables REDIRECT.
|
||||
// For builtin mode: the TPROXY listener port.
|
||||
// For envoy mode: the envoy listener port (nftables redirects directly to envoy).
|
||||
// Returns 0 if no listener is active.
|
||||
func (p *Proxy) ListenPort() uint16 {
|
||||
p.mu.RLock()
|
||||
envoy := p.envoy
|
||||
p.mu.RUnlock()
|
||||
|
||||
if envoy != nil {
|
||||
return envoy.listenPort
|
||||
}
|
||||
if p.listener == nil {
|
||||
return 0
|
||||
}
|
||||
tcpAddr, ok := p.listener.Addr().(*net.TCPAddr)
|
||||
if !ok {
|
||||
return 0
|
||||
}
|
||||
return uint16(tcpAddr.Port)
|
||||
}
|
||||
|
||||
// Close shuts down the proxy and releases resources.
|
||||
func (p *Proxy) Close() error {
|
||||
p.cancel()
|
||||
|
||||
p.mu.Lock()
|
||||
envoy := p.envoy
|
||||
p.envoy = nil
|
||||
icap := p.icap
|
||||
p.icap = nil
|
||||
p.mu.Unlock()
|
||||
|
||||
if envoy != nil {
|
||||
envoy.Stop()
|
||||
}
|
||||
|
||||
if p.listener != nil {
|
||||
if err := p.listener.Close(); err != nil {
|
||||
p.log.Debugf("close TPROXY listener: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
if icap != nil {
|
||||
icap.Close()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// acceptLoop accepts connections from the redirected listener (kernel mode).
|
||||
// Connections arrive via nftables REDIRECT; original destination is read from conntrack.
|
||||
func (p *Proxy) acceptLoop(ln net.Listener) {
|
||||
for {
|
||||
conn, err := ln.Accept()
|
||||
if err != nil {
|
||||
if p.ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
p.log.Debugf("accept error: %v", err)
|
||||
continue
|
||||
}
|
||||
|
||||
go func() {
|
||||
// Read original destination from conntrack (SO_ORIGINAL_DST).
|
||||
// nftables REDIRECT changes dst to the local WG IP:proxy_port,
|
||||
// but conntrack preserves the real destination.
|
||||
dstAddr, err := getOriginalDst(conn)
|
||||
if err != nil {
|
||||
p.log.Debugf("get original dst: %v", err)
|
||||
if closeErr := conn.Close(); closeErr != nil {
|
||||
p.log.Debugf("close conn: %v", closeErr)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
p.log.Tracef("accepted: %s -> %s (original dst %s)",
|
||||
conn.RemoteAddr(), conn.LocalAddr(), dstAddr)
|
||||
|
||||
srcAddr, err := netip.ParseAddrPort(conn.RemoteAddr().String())
|
||||
if err != nil {
|
||||
p.log.Debugf("parse source: %v", err)
|
||||
if closeErr := conn.Close(); closeErr != nil {
|
||||
p.log.Debugf("close conn: %v", closeErr)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
src := SourceInfo{
|
||||
IP: srcAddr.Addr().Unmap(),
|
||||
}
|
||||
|
||||
if err := p.HandleTCP(p.ctx, conn, dstAddr, src); err != nil && !errors.Is(err, ErrBlocked) {
|
||||
p.log.Debugf("connection to %s: %v", dstAddr, err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
}
|
||||
@@ -1,388 +0,0 @@
|
||||
package inspect
|
||||
|
||||
import (
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"crypto/sha256"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
"golang.org/x/crypto/hkdf"
|
||||
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
)
|
||||
|
||||
// QUIC version constants
|
||||
const (
|
||||
quicV1Version uint32 = 0x00000001
|
||||
quicV2Version uint32 = 0x6b3343cf
|
||||
)
|
||||
|
||||
// quicV1Salt is the initial salt for QUIC v1 (RFC 9001 Section 5.2).
|
||||
var quicV1Salt = []byte{
|
||||
0x38, 0x76, 0x2c, 0xf7, 0xf5, 0x59, 0x34, 0xb3,
|
||||
0x4d, 0x17, 0x9a, 0xe6, 0xa4, 0xc8, 0x0c, 0xad,
|
||||
0xcc, 0xbb, 0x7f, 0x0a,
|
||||
}
|
||||
|
||||
// quicV2Salt is the initial salt for QUIC v2 (RFC 9369).
|
||||
var quicV2Salt = []byte{
|
||||
0x0d, 0xed, 0xe3, 0xde, 0xf7, 0x00, 0xa6, 0xdb,
|
||||
0x81, 0x93, 0x81, 0xbe, 0x6e, 0x26, 0x9d, 0xcb,
|
||||
0xf9, 0xbd, 0x2e, 0xd9,
|
||||
}
|
||||
|
||||
// ExtractQUICSNI extracts the SNI from a QUIC Initial packet.
|
||||
// The Initial packet's encryption uses well-known keys derived from the
|
||||
// Destination Connection ID, so any observer can decrypt it (by design).
|
||||
func ExtractQUICSNI(data []byte) (domain.Domain, error) {
|
||||
if len(data) < 5 {
|
||||
return "", fmt.Errorf("packet too short")
|
||||
}
|
||||
|
||||
// Check for QUIC Long Header (form bit set)
|
||||
if data[0]&0x80 == 0 {
|
||||
return "", fmt.Errorf("not a QUIC long header packet")
|
||||
}
|
||||
|
||||
// Version
|
||||
version := binary.BigEndian.Uint32(data[1:5])
|
||||
|
||||
var salt []byte
|
||||
var initialLabel, keyLabel, ivLabel, hpLabel string
|
||||
|
||||
switch version {
|
||||
case quicV1Version:
|
||||
salt = quicV1Salt
|
||||
initialLabel = "client in"
|
||||
keyLabel = "quic key"
|
||||
ivLabel = "quic iv"
|
||||
hpLabel = "quic hp"
|
||||
case quicV2Version:
|
||||
salt = quicV2Salt
|
||||
initialLabel = "client in"
|
||||
keyLabel = "quicv2 key"
|
||||
ivLabel = "quicv2 iv"
|
||||
hpLabel = "quicv2 hp"
|
||||
default:
|
||||
return "", fmt.Errorf("unsupported QUIC version: 0x%08x", version)
|
||||
}
|
||||
|
||||
// Parse Long Header
|
||||
if len(data) < 6 {
|
||||
return "", fmt.Errorf("packet too short for DCID length")
|
||||
}
|
||||
dcidLen := int(data[5])
|
||||
if len(data) < 6+dcidLen+1 {
|
||||
return "", fmt.Errorf("packet too short for DCID")
|
||||
}
|
||||
dcid := data[6 : 6+dcidLen]
|
||||
|
||||
scidLenOff := 6 + dcidLen
|
||||
scidLen := int(data[scidLenOff])
|
||||
tokenLenOff := scidLenOff + 1 + scidLen
|
||||
|
||||
if tokenLenOff >= len(data) {
|
||||
return "", fmt.Errorf("packet too short for token length")
|
||||
}
|
||||
|
||||
// Token length is a variable-length integer
|
||||
tokenLen, tokenLenSize, err := readVarInt(data[tokenLenOff:])
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("read token length: %w", err)
|
||||
}
|
||||
|
||||
payloadLenOff := tokenLenOff + tokenLenSize + int(tokenLen)
|
||||
if payloadLenOff >= len(data) {
|
||||
return "", fmt.Errorf("packet too short for payload length")
|
||||
}
|
||||
|
||||
// Payload length is a variable-length integer
|
||||
payloadLen, payloadLenSize, err := readVarInt(data[payloadLenOff:])
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("read payload length: %w", err)
|
||||
}
|
||||
|
||||
pnOffset := payloadLenOff + payloadLenSize
|
||||
if pnOffset+4 > len(data) {
|
||||
return "", fmt.Errorf("packet too short for packet number")
|
||||
}
|
||||
|
||||
// Derive initial keys
|
||||
clientKey, clientIV, clientHP, err := deriveInitialKeys(dcid, salt, initialLabel, keyLabel, ivLabel, hpLabel)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("derive initial keys: %w", err)
|
||||
}
|
||||
|
||||
// Remove header protection
|
||||
sampleOffset := pnOffset + 4 // sample starts 4 bytes after pn offset
|
||||
if sampleOffset+16 > len(data) {
|
||||
return "", fmt.Errorf("packet too short for HP sample")
|
||||
}
|
||||
sample := data[sampleOffset : sampleOffset+16]
|
||||
|
||||
hpBlock, err := aes.NewCipher(clientHP)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("create HP cipher: %w", err)
|
||||
}
|
||||
|
||||
mask := make([]byte, 16)
|
||||
hpBlock.Encrypt(mask, sample)
|
||||
|
||||
// Unmask header byte
|
||||
header := make([]byte, len(data))
|
||||
copy(header, data)
|
||||
header[0] ^= mask[0] & 0x0f // Long header: low 4 bits
|
||||
|
||||
// Determine packet number length
|
||||
pnLen := int(header[0]&0x03) + 1
|
||||
|
||||
// Unmask packet number
|
||||
for i := 0; i < pnLen; i++ {
|
||||
header[pnOffset+i] ^= mask[1+i]
|
||||
}
|
||||
|
||||
// Reconstruct packet number
|
||||
var pn uint32
|
||||
for i := 0; i < pnLen; i++ {
|
||||
pn = (pn << 8) | uint32(header[pnOffset+i])
|
||||
}
|
||||
|
||||
// Build nonce
|
||||
nonce := make([]byte, len(clientIV))
|
||||
copy(nonce, clientIV)
|
||||
for i := 0; i < 4; i++ {
|
||||
nonce[len(nonce)-1-i] ^= byte(pn >> (8 * i))
|
||||
}
|
||||
|
||||
// Decrypt payload
|
||||
block, err := aes.NewCipher(clientKey)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("create AES cipher: %w", err)
|
||||
}
|
||||
|
||||
aead, err := cipher.NewGCM(block)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("create AEAD: %w", err)
|
||||
}
|
||||
|
||||
encryptedPayload := header[pnOffset+pnLen : pnOffset+int(payloadLen)]
|
||||
aad := header[:pnOffset+pnLen]
|
||||
|
||||
plaintext, err := aead.Open(nil, nonce, encryptedPayload, aad)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("decrypt QUIC payload: %w", err)
|
||||
}
|
||||
|
||||
// Parse CRYPTO frames to extract ClientHello
|
||||
clientHello, err := extractCryptoFrames(plaintext)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("extract CRYPTO frames: %w", err)
|
||||
}
|
||||
|
||||
info, err := parseHelloBody(clientHello)
|
||||
return info.SNI, err
|
||||
}
|
||||
|
||||
// deriveInitialKeys derives the client's initial encryption keys from the DCID.
|
||||
func deriveInitialKeys(dcid, salt []byte, initialLabel, keyLabel, ivLabel, hpLabel string) (key, iv, hp []byte, err error) {
|
||||
// initial_secret = HKDF-Extract(salt, DCID)
|
||||
initialSecret := hkdf.Extract(sha256.New, dcid, salt)
|
||||
|
||||
// client_initial_secret = HKDF-Expand-Label(initial_secret, initialLabel, "", 32)
|
||||
clientSecret, err := hkdfExpandLabel(initialSecret, initialLabel, nil, 32)
|
||||
if err != nil {
|
||||
return nil, nil, nil, fmt.Errorf("derive client secret: %w", err)
|
||||
}
|
||||
|
||||
// client_key = HKDF-Expand-Label(client_secret, keyLabel, "", 16)
|
||||
key, err = hkdfExpandLabel(clientSecret, keyLabel, nil, 16)
|
||||
if err != nil {
|
||||
return nil, nil, nil, fmt.Errorf("derive key: %w", err)
|
||||
}
|
||||
|
||||
// client_iv = HKDF-Expand-Label(client_secret, ivLabel, "", 12)
|
||||
iv, err = hkdfExpandLabel(clientSecret, ivLabel, nil, 12)
|
||||
if err != nil {
|
||||
return nil, nil, nil, fmt.Errorf("derive IV: %w", err)
|
||||
}
|
||||
|
||||
// client_hp = HKDF-Expand-Label(client_secret, hpLabel, "", 16)
|
||||
hp, err = hkdfExpandLabel(clientSecret, hpLabel, nil, 16)
|
||||
if err != nil {
|
||||
return nil, nil, nil, fmt.Errorf("derive HP key: %w", err)
|
||||
}
|
||||
|
||||
return key, iv, hp, nil
|
||||
}
|
||||
|
||||
// hkdfExpandLabel implements TLS 1.3 HKDF-Expand-Label.
|
||||
func hkdfExpandLabel(secret []byte, label string, context []byte, length int) ([]byte, error) {
|
||||
// HkdfLabel = struct {
|
||||
// uint16 length;
|
||||
// opaque label<7..255> = "tls13 " + Label;
|
||||
// opaque context<0..255> = Context;
|
||||
// }
|
||||
fullLabel := "tls13 " + label
|
||||
|
||||
hkdfLabel := make([]byte, 2+1+len(fullLabel)+1+len(context))
|
||||
binary.BigEndian.PutUint16(hkdfLabel[0:2], uint16(length))
|
||||
hkdfLabel[2] = byte(len(fullLabel))
|
||||
copy(hkdfLabel[3:], fullLabel)
|
||||
hkdfLabel[3+len(fullLabel)] = byte(len(context))
|
||||
if len(context) > 0 {
|
||||
copy(hkdfLabel[4+len(fullLabel):], context)
|
||||
}
|
||||
|
||||
expander := hkdf.Expand(sha256.New, secret, hkdfLabel)
|
||||
out := make([]byte, length)
|
||||
if _, err := io.ReadFull(expander, out); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// maxCryptoFrameSize limits total CRYPTO frame data to prevent memory exhaustion.
|
||||
const maxCryptoFrameSize = 64 * 1024
|
||||
|
||||
// extractCryptoFrames reassembles CRYPTO frame data from QUIC frames.
|
||||
func extractCryptoFrames(frames []byte) ([]byte, error) {
|
||||
var result []byte
|
||||
pos := 0
|
||||
|
||||
for pos < len(frames) {
|
||||
frameType := frames[pos]
|
||||
|
||||
switch {
|
||||
case frameType == 0x00:
|
||||
// PADDING frame
|
||||
pos++
|
||||
|
||||
case frameType == 0x06:
|
||||
// CRYPTO frame
|
||||
pos++
|
||||
|
||||
offset, n, err := readVarInt(frames[pos:])
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read crypto offset: %w", err)
|
||||
}
|
||||
pos += n
|
||||
_ = offset // We assume ordered, offset 0 for Initial
|
||||
|
||||
dataLen, n, err := readVarInt(frames[pos:])
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read crypto data length: %w", err)
|
||||
}
|
||||
pos += n
|
||||
|
||||
end := pos + int(dataLen)
|
||||
if end > len(frames) {
|
||||
return nil, fmt.Errorf("CRYPTO frame data truncated")
|
||||
}
|
||||
|
||||
result = append(result, frames[pos:end]...)
|
||||
if len(result) > maxCryptoFrameSize {
|
||||
return nil, fmt.Errorf("CRYPTO frame data exceeds %d bytes", maxCryptoFrameSize)
|
||||
}
|
||||
pos = end
|
||||
|
||||
case frameType == 0x01:
|
||||
// PING frame
|
||||
pos++
|
||||
|
||||
case frameType == 0x02 || frameType == 0x03:
|
||||
// ACK frame - skip
|
||||
pos++
|
||||
// Largest Acknowledged
|
||||
_, n, err := readVarInt(frames[pos:])
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read ACK: %w", err)
|
||||
}
|
||||
pos += n
|
||||
// ACK Delay
|
||||
_, n, err = readVarInt(frames[pos:])
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read ACK delay: %w", err)
|
||||
}
|
||||
pos += n
|
||||
// ACK Range Count
|
||||
rangeCount, n, err := readVarInt(frames[pos:])
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read ACK range count: %w", err)
|
||||
}
|
||||
pos += n
|
||||
// First ACK Range
|
||||
_, n, err = readVarInt(frames[pos:])
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read first ACK range: %w", err)
|
||||
}
|
||||
pos += n
|
||||
// Additional ranges
|
||||
for i := uint64(0); i < rangeCount; i++ {
|
||||
_, n, err = readVarInt(frames[pos:])
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read ACK gap: %w", err)
|
||||
}
|
||||
pos += n
|
||||
_, n, err = readVarInt(frames[pos:])
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read ACK range: %w", err)
|
||||
}
|
||||
pos += n
|
||||
}
|
||||
// ECN counts for type 0x03
|
||||
if frameType == 0x03 {
|
||||
for range 3 {
|
||||
_, n, err = readVarInt(frames[pos:])
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read ECN count: %w", err)
|
||||
}
|
||||
pos += n
|
||||
}
|
||||
}
|
||||
|
||||
default:
|
||||
// Unknown frame type, stop parsing
|
||||
if len(result) > 0 {
|
||||
return result, nil
|
||||
}
|
||||
return nil, fmt.Errorf("unknown QUIC frame type: 0x%02x at offset %d", frameType, pos)
|
||||
}
|
||||
}
|
||||
|
||||
if len(result) == 0 {
|
||||
return nil, fmt.Errorf("no CRYPTO frames found")
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// readVarInt reads a QUIC variable-length integer.
|
||||
// Returns (value, bytes consumed, error).
|
||||
func readVarInt(data []byte) (uint64, int, error) {
|
||||
if len(data) == 0 {
|
||||
return 0, 0, fmt.Errorf("empty data for varint")
|
||||
}
|
||||
|
||||
prefix := data[0] >> 6
|
||||
length := 1 << prefix
|
||||
|
||||
if len(data) < length {
|
||||
return 0, 0, fmt.Errorf("varint truncated: need %d, have %d", length, len(data))
|
||||
}
|
||||
|
||||
var val uint64
|
||||
switch length {
|
||||
case 1:
|
||||
val = uint64(data[0] & 0x3f)
|
||||
case 2:
|
||||
val = uint64(binary.BigEndian.Uint16(data[:2])) & 0x3fff
|
||||
case 4:
|
||||
val = uint64(binary.BigEndian.Uint32(data[:4])) & 0x3fffffff
|
||||
case 8:
|
||||
val = binary.BigEndian.Uint64(data[:8]) & 0x3fffffffffffffff
|
||||
}
|
||||
|
||||
return val, length, nil
|
||||
}
|
||||
@@ -1,99 +0,0 @@
|
||||
package inspect
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestReadVarInt(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
data []byte
|
||||
want uint64
|
||||
n int
|
||||
}{
|
||||
{
|
||||
name: "1 byte value",
|
||||
data: []byte{0x25},
|
||||
want: 37,
|
||||
n: 1,
|
||||
},
|
||||
{
|
||||
name: "2 byte value",
|
||||
data: []byte{0x7b, 0xbd},
|
||||
want: 15293,
|
||||
n: 2,
|
||||
},
|
||||
{
|
||||
name: "4 byte value",
|
||||
data: []byte{0x9d, 0x7f, 0x3e, 0x7d},
|
||||
want: 494878333,
|
||||
n: 4,
|
||||
},
|
||||
{
|
||||
name: "zero",
|
||||
data: []byte{0x00},
|
||||
want: 0,
|
||||
n: 1,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
val, n, err := readVarInt(tt.data)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, tt.want, val)
|
||||
assert.Equal(t, tt.n, n)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadVarInt_Empty(t *testing.T) {
|
||||
_, _, err := readVarInt(nil)
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestReadVarInt_Truncated(t *testing.T) {
|
||||
// 2-byte prefix but only 1 byte
|
||||
_, _, err := readVarInt([]byte{0x40})
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestExtractQUICSNI_NotLongHeader(t *testing.T) {
|
||||
// Short header packet (form bit not set)
|
||||
data := make([]byte, 100)
|
||||
data[0] = 0x40 // short header
|
||||
|
||||
_, err := ExtractQUICSNI(data)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "not a QUIC long header")
|
||||
}
|
||||
|
||||
func TestExtractQUICSNI_UnsupportedVersion(t *testing.T) {
|
||||
data := make([]byte, 100)
|
||||
data[0] = 0xC0 // long header
|
||||
// Version 0xdeadbeef
|
||||
data[1] = 0xde
|
||||
data[2] = 0xad
|
||||
data[3] = 0xbe
|
||||
data[4] = 0xef
|
||||
|
||||
_, err := ExtractQUICSNI(data)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "unsupported QUIC version")
|
||||
}
|
||||
|
||||
func TestExtractQUICSNI_TooShort(t *testing.T) {
|
||||
_, err := ExtractQUICSNI([]byte{0xC0, 0x00})
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestHkdfExpandLabel(t *testing.T) {
|
||||
// Smoke test: ensure it returns the right length and doesn't error
|
||||
secret := make([]byte, 32)
|
||||
result, err := hkdfExpandLabel(secret, "quic key", nil, 16)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, result, 16)
|
||||
}
|
||||
@@ -1,253 +0,0 @@
|
||||
package inspect
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"slices"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/acl/id"
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
)
|
||||
|
||||
// RuleEngine evaluates proxy rules against connection metadata.
|
||||
// It is safe for concurrent use.
|
||||
type RuleEngine struct {
|
||||
mu sync.RWMutex
|
||||
rules []Rule
|
||||
// defaultAction applies when no rule matches.
|
||||
defaultAction Action
|
||||
log *log.Entry
|
||||
}
|
||||
|
||||
// NewRuleEngine creates a rule engine with the given default action.
|
||||
func NewRuleEngine(logger *log.Entry, defaultAction Action) *RuleEngine {
|
||||
return &RuleEngine{
|
||||
defaultAction: defaultAction,
|
||||
log: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// UpdateRules replaces the rule set and default action. Rules are sorted by priority.
|
||||
func (e *RuleEngine) UpdateRules(rules []Rule, defaultAction Action) {
|
||||
sorted := make([]Rule, len(rules))
|
||||
copy(sorted, rules)
|
||||
sort.Slice(sorted, func(i, j int) bool {
|
||||
return sorted[i].Priority < sorted[j].Priority
|
||||
})
|
||||
|
||||
e.mu.Lock()
|
||||
e.rules = sorted
|
||||
e.defaultAction = defaultAction
|
||||
e.mu.Unlock()
|
||||
}
|
||||
|
||||
// EvalResult holds the outcome of a rule evaluation.
|
||||
type EvalResult struct {
|
||||
Action Action
|
||||
RuleID id.RuleID
|
||||
}
|
||||
|
||||
// Evaluate determines the action for a connection based on the rule set.
|
||||
// Pass empty path for connection-level evaluation (TLS/SNI), non-empty for request-level (HTTP).
|
||||
func (e *RuleEngine) Evaluate(src netip.Addr, dstDomain domain.Domain, dstAddr netip.Addr, dstPort uint16, proto ProtoType, path string) Action {
|
||||
r := e.EvaluateWithResult(src, dstDomain, dstAddr, dstPort, proto, path)
|
||||
return r.Action
|
||||
}
|
||||
|
||||
// EvaluateWithResult is like Evaluate but also returns the matched rule ID.
|
||||
func (e *RuleEngine) EvaluateWithResult(src netip.Addr, dstDomain domain.Domain, dstAddr netip.Addr, dstPort uint16, proto ProtoType, path string) EvalResult {
|
||||
e.mu.RLock()
|
||||
defer e.mu.RUnlock()
|
||||
|
||||
for i := range e.rules {
|
||||
rule := &e.rules[i]
|
||||
if e.ruleMatches(rule, src, dstDomain, dstAddr, dstPort, proto, path) {
|
||||
e.log.Tracef("rule %s matched: action=%s src=%s domain=%s dst=%s:%d proto=%s path=%s",
|
||||
rule.ID, rule.Action, src, dstDomain.SafeString(), dstAddr, dstPort, proto, path)
|
||||
return EvalResult{Action: rule.Action, RuleID: rule.ID}
|
||||
}
|
||||
}
|
||||
|
||||
e.log.Tracef("no rule matched, default=%s: src=%s domain=%s dst=%s:%d proto=%s path=%s",
|
||||
e.defaultAction, src, dstDomain.SafeString(), dstAddr, dstPort, proto, path)
|
||||
return EvalResult{Action: e.defaultAction}
|
||||
}
|
||||
|
||||
// HasPathRulesForDomain returns true if any rule matching the domain has non-empty Paths.
|
||||
// Used to force MITM inspection when path-level rules exist (paths are only visible after decryption).
|
||||
func (e *RuleEngine) HasPathRulesForDomain(dstDomain domain.Domain) bool {
|
||||
e.mu.RLock()
|
||||
defer e.mu.RUnlock()
|
||||
|
||||
for i := range e.rules {
|
||||
if len(e.rules[i].Paths) > 0 && e.matchDomain(&e.rules[i], dstDomain) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// ruleMatches checks whether all non-empty fields of a rule match.
|
||||
// Empty fields are treated as "match any".
|
||||
// All specified fields must match (AND logic).
|
||||
func (e *RuleEngine) ruleMatches(rule *Rule, src netip.Addr, dstDomain domain.Domain, dstAddr netip.Addr, dstPort uint16, proto ProtoType, path string) bool {
|
||||
if !e.matchSource(rule, src) {
|
||||
return false
|
||||
}
|
||||
|
||||
if !e.matchDomain(rule, dstDomain) {
|
||||
return false
|
||||
}
|
||||
|
||||
if !e.matchNetwork(rule, dstAddr) {
|
||||
return false
|
||||
}
|
||||
|
||||
if !e.matchPort(rule, dstPort) {
|
||||
return false
|
||||
}
|
||||
|
||||
if !e.matchProtocol(rule, proto) {
|
||||
return false
|
||||
}
|
||||
|
||||
if !e.matchPaths(rule, path) {
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// matchSource returns true if src matches any of the rule's source CIDRs,
|
||||
// or if no source CIDRs are specified (match any).
|
||||
func (e *RuleEngine) matchSource(rule *Rule, src netip.Addr) bool {
|
||||
if len(rule.Sources) == 0 {
|
||||
return true
|
||||
}
|
||||
|
||||
for _, prefix := range rule.Sources {
|
||||
if prefix.Contains(src) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// matchDomain returns true if dstDomain matches any of the rule's domain patterns,
|
||||
// or if no domain patterns are specified (match any).
|
||||
func (e *RuleEngine) matchDomain(rule *Rule, dstDomain domain.Domain) bool {
|
||||
if len(rule.Domains) == 0 {
|
||||
return true
|
||||
}
|
||||
|
||||
// If we have domain rules but no domain to match against (e.g., raw IP connection),
|
||||
// the domain condition does not match.
|
||||
if dstDomain == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
for _, pattern := range rule.Domains {
|
||||
if MatchDomain(pattern, dstDomain) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// matchNetwork returns true if dstAddr is within any of the rule's destination CIDRs,
|
||||
// or if no destination CIDRs are specified (match any).
|
||||
func (e *RuleEngine) matchNetwork(rule *Rule, dstAddr netip.Addr) bool {
|
||||
if len(rule.Networks) == 0 {
|
||||
return true
|
||||
}
|
||||
|
||||
for _, prefix := range rule.Networks {
|
||||
if prefix.Contains(dstAddr) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// matchProtocol returns true if proto matches any of the rule's protocols,
|
||||
// or if no protocols are specified (match any).
|
||||
func (e *RuleEngine) matchProtocol(rule *Rule, proto ProtoType) bool {
|
||||
if len(rule.Protocols) == 0 {
|
||||
return true
|
||||
}
|
||||
|
||||
for _, p := range rule.Protocols {
|
||||
if p == proto {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// matchPort returns true if dstPort matches any of the rule's destination ports,
|
||||
// or if no ports are specified (match any).
|
||||
func (e *RuleEngine) matchPort(rule *Rule, dstPort uint16) bool {
|
||||
if len(rule.Ports) == 0 {
|
||||
return true
|
||||
}
|
||||
|
||||
return slices.Contains(rule.Ports, dstPort)
|
||||
}
|
||||
|
||||
// matchPaths returns true if path matches any of the rule's path patterns,
|
||||
// or if no paths are specified (match any). Empty path (connection-level eval) matches all.
|
||||
func (e *RuleEngine) matchPaths(rule *Rule, path string) bool {
|
||||
if len(rule.Paths) == 0 {
|
||||
return true
|
||||
}
|
||||
// Connection-level (path=""): rules with paths don't match at connection level.
|
||||
// HasPathRulesForDomain forces the connection to inspect, so paths are
|
||||
// checked per-request once the HTTP request is visible.
|
||||
if path == "" {
|
||||
return false
|
||||
}
|
||||
for _, pattern := range rule.Paths {
|
||||
if matchPath(pattern, path) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// matchPath checks if a URL path matches a pattern.
|
||||
// Supports: exact ("/login"), prefix with wildcard ("/api/*"),
|
||||
// and contains ("*/admin/*"). A bare "*" matches everything.
|
||||
func matchPath(pattern, path string) bool {
|
||||
if pattern == "*" {
|
||||
return true
|
||||
}
|
||||
|
||||
hasLeadingStar := strings.HasPrefix(pattern, "*")
|
||||
hasTrailingStar := strings.HasSuffix(pattern, "*")
|
||||
|
||||
switch {
|
||||
case hasLeadingStar && hasTrailingStar:
|
||||
// */admin/* = contains
|
||||
middle := strings.Trim(pattern, "*")
|
||||
return strings.Contains(path, middle)
|
||||
case hasTrailingStar:
|
||||
// /api/* = prefix
|
||||
prefix := strings.TrimSuffix(pattern, "*")
|
||||
return strings.HasPrefix(path, prefix)
|
||||
case hasLeadingStar:
|
||||
// *.json = suffix
|
||||
suffix := strings.TrimPrefix(pattern, "*")
|
||||
return strings.HasSuffix(path, suffix)
|
||||
default:
|
||||
// exact
|
||||
return path == pattern
|
||||
}
|
||||
}
|
||||
@@ -1,338 +0,0 @@
|
||||
package inspect
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"testing"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/acl/id"
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
)
|
||||
|
||||
func testLogger() *log.Entry {
|
||||
return log.WithField("test", true)
|
||||
}
|
||||
|
||||
func mustDomain(t *testing.T, s string) domain.Domain {
|
||||
t.Helper()
|
||||
d, err := domain.FromString(s)
|
||||
require.NoError(t, err)
|
||||
return d
|
||||
}
|
||||
|
||||
func TestRuleEngine_Evaluate(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
rules []Rule
|
||||
defaultAction Action
|
||||
src netip.Addr
|
||||
dstDomain domain.Domain
|
||||
dstAddr netip.Addr
|
||||
dstPort uint16
|
||||
want Action
|
||||
}{
|
||||
{
|
||||
name: "no rules returns default allow",
|
||||
defaultAction: ActionAllow,
|
||||
src: netip.MustParseAddr("10.0.0.1"),
|
||||
dstAddr: netip.MustParseAddr("1.2.3.4"),
|
||||
dstPort: 443,
|
||||
want: ActionAllow,
|
||||
},
|
||||
{
|
||||
name: "no rules returns default block",
|
||||
defaultAction: ActionBlock,
|
||||
src: netip.MustParseAddr("10.0.0.1"),
|
||||
dstAddr: netip.MustParseAddr("1.2.3.4"),
|
||||
dstPort: 443,
|
||||
want: ActionBlock,
|
||||
},
|
||||
{
|
||||
name: "domain exact match blocks",
|
||||
defaultAction: ActionAllow,
|
||||
rules: []Rule{
|
||||
{
|
||||
ID: id.RuleID("r1"),
|
||||
Domains: []domain.Domain{mustDomain(t, "malware.example.com")},
|
||||
Action: ActionBlock,
|
||||
},
|
||||
},
|
||||
src: netip.MustParseAddr("10.0.0.1"),
|
||||
dstDomain: mustDomain(t, "malware.example.com"),
|
||||
dstAddr: netip.MustParseAddr("1.2.3.4"),
|
||||
dstPort: 443,
|
||||
want: ActionBlock,
|
||||
},
|
||||
{
|
||||
name: "domain wildcard match blocks",
|
||||
defaultAction: ActionAllow,
|
||||
rules: []Rule{
|
||||
{
|
||||
ID: id.RuleID("r1"),
|
||||
Domains: []domain.Domain{mustDomain(t, "*.evil.com")},
|
||||
Action: ActionBlock,
|
||||
},
|
||||
},
|
||||
src: netip.MustParseAddr("10.0.0.1"),
|
||||
dstDomain: mustDomain(t, "phishing.evil.com"),
|
||||
dstAddr: netip.MustParseAddr("1.2.3.4"),
|
||||
dstPort: 443,
|
||||
want: ActionBlock,
|
||||
},
|
||||
{
|
||||
name: "domain wildcard does not match base",
|
||||
defaultAction: ActionAllow,
|
||||
rules: []Rule{
|
||||
{
|
||||
ID: id.RuleID("r1"),
|
||||
Domains: []domain.Domain{mustDomain(t, "*.evil.com")},
|
||||
Action: ActionBlock,
|
||||
},
|
||||
},
|
||||
src: netip.MustParseAddr("10.0.0.1"),
|
||||
dstDomain: mustDomain(t, "evil.com"),
|
||||
dstAddr: netip.MustParseAddr("1.2.3.4"),
|
||||
dstPort: 443,
|
||||
want: ActionAllow,
|
||||
},
|
||||
{
|
||||
name: "case insensitive domain match",
|
||||
defaultAction: ActionAllow,
|
||||
rules: []Rule{
|
||||
{
|
||||
ID: id.RuleID("r1"),
|
||||
Domains: []domain.Domain{mustDomain(t, "Example.COM")},
|
||||
Action: ActionBlock,
|
||||
},
|
||||
},
|
||||
src: netip.MustParseAddr("10.0.0.1"),
|
||||
dstDomain: mustDomain(t, "EXAMPLE.com"),
|
||||
dstAddr: netip.MustParseAddr("1.2.3.4"),
|
||||
dstPort: 443,
|
||||
want: ActionBlock,
|
||||
},
|
||||
{
|
||||
name: "source CIDR match",
|
||||
defaultAction: ActionAllow,
|
||||
rules: []Rule{
|
||||
{
|
||||
ID: id.RuleID("r1"),
|
||||
Sources: []netip.Prefix{netip.MustParsePrefix("192.168.1.0/24")},
|
||||
Action: ActionInspect,
|
||||
},
|
||||
},
|
||||
src: netip.MustParseAddr("192.168.1.50"),
|
||||
dstAddr: netip.MustParseAddr("1.2.3.4"),
|
||||
dstPort: 443,
|
||||
want: ActionInspect,
|
||||
},
|
||||
{
|
||||
name: "source CIDR no match",
|
||||
defaultAction: ActionAllow,
|
||||
rules: []Rule{
|
||||
{
|
||||
ID: id.RuleID("r1"),
|
||||
Sources: []netip.Prefix{netip.MustParsePrefix("192.168.1.0/24")},
|
||||
Action: ActionBlock,
|
||||
},
|
||||
},
|
||||
src: netip.MustParseAddr("10.0.0.5"),
|
||||
dstAddr: netip.MustParseAddr("1.2.3.4"),
|
||||
dstPort: 443,
|
||||
want: ActionAllow,
|
||||
},
|
||||
{
|
||||
name: "destination network match",
|
||||
defaultAction: ActionAllow,
|
||||
rules: []Rule{
|
||||
{
|
||||
ID: id.RuleID("r1"),
|
||||
Networks: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/8")},
|
||||
Action: ActionInspect,
|
||||
},
|
||||
},
|
||||
src: netip.MustParseAddr("192.168.1.1"),
|
||||
dstAddr: netip.MustParseAddr("10.50.0.1"),
|
||||
dstPort: 80,
|
||||
want: ActionInspect,
|
||||
},
|
||||
{
|
||||
name: "port match",
|
||||
defaultAction: ActionAllow,
|
||||
rules: []Rule{
|
||||
{
|
||||
ID: id.RuleID("r1"),
|
||||
Ports: []uint16{443, 8443},
|
||||
Action: ActionInspect,
|
||||
},
|
||||
},
|
||||
src: netip.MustParseAddr("10.0.0.1"),
|
||||
dstAddr: netip.MustParseAddr("1.2.3.4"),
|
||||
dstPort: 443,
|
||||
want: ActionInspect,
|
||||
},
|
||||
{
|
||||
name: "port no match",
|
||||
defaultAction: ActionAllow,
|
||||
rules: []Rule{
|
||||
{
|
||||
ID: id.RuleID("r1"),
|
||||
Ports: []uint16{443, 8443},
|
||||
Action: ActionBlock,
|
||||
},
|
||||
},
|
||||
src: netip.MustParseAddr("10.0.0.1"),
|
||||
dstAddr: netip.MustParseAddr("1.2.3.4"),
|
||||
dstPort: 22,
|
||||
want: ActionAllow,
|
||||
},
|
||||
{
|
||||
name: "priority ordering first match wins",
|
||||
defaultAction: ActionAllow,
|
||||
rules: []Rule{
|
||||
{
|
||||
ID: id.RuleID("allow-internal"),
|
||||
Domains: []domain.Domain{mustDomain(t, "*.internal.corp")},
|
||||
Action: ActionAllow,
|
||||
Priority: 1,
|
||||
},
|
||||
{
|
||||
ID: id.RuleID("inspect-all"),
|
||||
Action: ActionInspect,
|
||||
Priority: 10,
|
||||
},
|
||||
},
|
||||
src: netip.MustParseAddr("10.0.0.1"),
|
||||
dstDomain: mustDomain(t, "api.internal.corp"),
|
||||
dstAddr: netip.MustParseAddr("10.1.0.5"),
|
||||
dstPort: 443,
|
||||
want: ActionAllow,
|
||||
},
|
||||
{
|
||||
name: "all fields must match (AND logic)",
|
||||
defaultAction: ActionAllow,
|
||||
rules: []Rule{
|
||||
{
|
||||
ID: id.RuleID("r1"),
|
||||
Sources: []netip.Prefix{netip.MustParsePrefix("192.168.1.0/24")},
|
||||
Domains: []domain.Domain{mustDomain(t, "*.evil.com")},
|
||||
Ports: []uint16{443},
|
||||
Action: ActionBlock,
|
||||
},
|
||||
},
|
||||
// Source matches, domain matches, but port doesn't
|
||||
src: netip.MustParseAddr("192.168.1.10"),
|
||||
dstDomain: mustDomain(t, "phish.evil.com"),
|
||||
dstAddr: netip.MustParseAddr("1.2.3.4"),
|
||||
dstPort: 8080,
|
||||
want: ActionAllow,
|
||||
},
|
||||
{
|
||||
name: "empty domain with domain rule does not match",
|
||||
defaultAction: ActionAllow,
|
||||
rules: []Rule{
|
||||
{
|
||||
ID: id.RuleID("r1"),
|
||||
Domains: []domain.Domain{mustDomain(t, "example.com")},
|
||||
Action: ActionBlock,
|
||||
},
|
||||
},
|
||||
src: netip.MustParseAddr("10.0.0.1"),
|
||||
dstDomain: "", // raw IP connection, no SNI
|
||||
dstAddr: netip.MustParseAddr("1.2.3.4"),
|
||||
dstPort: 443,
|
||||
want: ActionAllow,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
engine := NewRuleEngine(testLogger(), tt.defaultAction)
|
||||
engine.UpdateRules(tt.rules, tt.defaultAction)
|
||||
|
||||
got := engine.Evaluate(tt.src, tt.dstDomain, tt.dstAddr, tt.dstPort, "", "")
|
||||
assert.Equal(t, tt.want, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRuleEngine_ProtocolMatching(t *testing.T) {
|
||||
engine := NewRuleEngine(testLogger(), ActionAllow)
|
||||
engine.UpdateRules([]Rule{
|
||||
{
|
||||
ID: "block-websocket",
|
||||
Protocols: []ProtoType{ProtoWebSocket},
|
||||
Action: ActionBlock,
|
||||
Priority: 1,
|
||||
},
|
||||
{
|
||||
ID: "inspect-h2",
|
||||
Protocols: []ProtoType{ProtoH2},
|
||||
Action: ActionInspect,
|
||||
Priority: 2,
|
||||
},
|
||||
}, ActionAllow)
|
||||
|
||||
src := netip.MustParseAddr("10.0.0.1")
|
||||
dst := netip.MustParseAddr("1.2.3.4")
|
||||
|
||||
// WebSocket: blocked by rule
|
||||
assert.Equal(t, ActionBlock, engine.Evaluate(src, "", dst, 443, ProtoWebSocket, ""))
|
||||
|
||||
// HTTP/2: inspected by rule
|
||||
assert.Equal(t, ActionInspect, engine.Evaluate(src, "", dst, 443, ProtoH2, ""))
|
||||
|
||||
// Plain HTTP: no protocol rule matches, default allow
|
||||
assert.Equal(t, ActionAllow, engine.Evaluate(src, "", dst, 80, ProtoHTTP, ""))
|
||||
|
||||
// HTTPS: no protocol rule matches, default allow
|
||||
assert.Equal(t, ActionAllow, engine.Evaluate(src, "", dst, 443, ProtoHTTPS, ""))
|
||||
|
||||
// QUIC/H3: no protocol rule matches, default allow
|
||||
assert.Equal(t, ActionAllow, engine.Evaluate(src, "", dst, 443, ProtoH3, ""))
|
||||
|
||||
// Empty protocol (unknown): no protocol rule matches, default allow
|
||||
assert.Equal(t, ActionAllow, engine.Evaluate(src, "", dst, 443, "", ""))
|
||||
}
|
||||
|
||||
func TestRuleEngine_EmptyProtocolsMatchAll(t *testing.T) {
|
||||
engine := NewRuleEngine(testLogger(), ActionAllow)
|
||||
engine.UpdateRules([]Rule{
|
||||
{
|
||||
ID: "block-all-protos",
|
||||
Action: ActionBlock,
|
||||
// No Protocols field = match all protocols
|
||||
Priority: 1,
|
||||
},
|
||||
}, ActionAllow)
|
||||
|
||||
src := netip.MustParseAddr("10.0.0.1")
|
||||
dst := netip.MustParseAddr("1.2.3.4")
|
||||
|
||||
assert.Equal(t, ActionBlock, engine.Evaluate(src, "", dst, 443, ProtoHTTP, ""))
|
||||
assert.Equal(t, ActionBlock, engine.Evaluate(src, "", dst, 443, ProtoHTTPS, ""))
|
||||
assert.Equal(t, ActionBlock, engine.Evaluate(src, "", dst, 443, ProtoWebSocket, ""))
|
||||
assert.Equal(t, ActionBlock, engine.Evaluate(src, "", dst, 443, ProtoH2, ""))
|
||||
assert.Equal(t, ActionBlock, engine.Evaluate(src, "", dst, 443, "", ""))
|
||||
}
|
||||
|
||||
func TestRuleEngine_UpdateRulesSortsByPriority(t *testing.T) {
|
||||
engine := NewRuleEngine(testLogger(), ActionAllow)
|
||||
|
||||
engine.UpdateRules([]Rule{
|
||||
{ID: "c", Priority: 30, Action: ActionBlock},
|
||||
{ID: "a", Priority: 10, Action: ActionInspect},
|
||||
{ID: "b", Priority: 20, Action: ActionAllow},
|
||||
}, ActionAllow)
|
||||
|
||||
engine.mu.RLock()
|
||||
defer engine.mu.RUnlock()
|
||||
|
||||
require.Len(t, engine.rules, 3)
|
||||
assert.Equal(t, id.RuleID("a"), engine.rules[0].ID)
|
||||
assert.Equal(t, id.RuleID("b"), engine.rules[1].ID)
|
||||
assert.Equal(t, id.RuleID("c"), engine.rules[2].ID)
|
||||
}
|
||||
@@ -1,287 +0,0 @@
|
||||
package inspect
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
)
|
||||
|
||||
const (
|
||||
recordTypeHandshake = 0x16
|
||||
handshakeTypeClientHello = 0x01
|
||||
extensionTypeSNI = 0x0000
|
||||
extensionTypeALPN = 0x0010
|
||||
sniTypeHostName = 0x00
|
||||
|
||||
// maxClientHelloSize is the maximum ClientHello size we'll read.
|
||||
// Real-world ClientHellos are typically under 1KB but can reach ~16KB with
|
||||
// many extensions (post-quantum key shares, etc.).
|
||||
maxClientHelloSize = 16384
|
||||
)
|
||||
|
||||
// ClientHelloInfo holds data extracted from a TLS ClientHello.
|
||||
type ClientHelloInfo struct {
|
||||
SNI domain.Domain
|
||||
ALPN []string
|
||||
}
|
||||
|
||||
// isTLSHandshake reports whether the first byte indicates a TLS handshake record.
|
||||
func isTLSHandshake(b byte) bool {
|
||||
return b == recordTypeHandshake
|
||||
}
|
||||
|
||||
// httpMethods lists the first bytes of valid HTTP method tokens.
|
||||
var httpMethods = [][]byte{
|
||||
[]byte("GET "),
|
||||
[]byte("POST"),
|
||||
[]byte("PUT "),
|
||||
[]byte("DELE"),
|
||||
[]byte("HEAD"),
|
||||
[]byte("OPTI"),
|
||||
[]byte("PATC"),
|
||||
[]byte("CONN"),
|
||||
[]byte("TRAC"),
|
||||
}
|
||||
|
||||
// isHTTPMethod reports whether the peeked bytes look like the start of an HTTP request.
|
||||
func isHTTPMethod(b []byte) bool {
|
||||
if len(b) < 4 {
|
||||
return false
|
||||
}
|
||||
for _, m := range httpMethods {
|
||||
if b[0] == m[0] && b[1] == m[1] && b[2] == m[2] && b[3] == m[3] {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// parseClientHello reads a TLS ClientHello from r and returns SNI and ALPN.
|
||||
func parseClientHello(r io.Reader) (ClientHelloInfo, error) {
|
||||
// TLS record header: type(1) + version(2) + length(2)
|
||||
var recordHeader [5]byte
|
||||
if _, err := io.ReadFull(r, recordHeader[:]); err != nil {
|
||||
return ClientHelloInfo{}, fmt.Errorf("read TLS record header: %w", err)
|
||||
}
|
||||
|
||||
if recordHeader[0] != recordTypeHandshake {
|
||||
return ClientHelloInfo{}, fmt.Errorf("not a TLS handshake record (type=%d)", recordHeader[0])
|
||||
}
|
||||
|
||||
recordLen := int(binary.BigEndian.Uint16(recordHeader[3:5]))
|
||||
if recordLen < 4 || recordLen > maxClientHelloSize {
|
||||
return ClientHelloInfo{}, fmt.Errorf("invalid TLS record length: %d", recordLen)
|
||||
}
|
||||
|
||||
// Read the full handshake message
|
||||
msg := make([]byte, recordLen)
|
||||
if _, err := io.ReadFull(r, msg); err != nil {
|
||||
return ClientHelloInfo{}, fmt.Errorf("read handshake message: %w", err)
|
||||
}
|
||||
|
||||
return parseClientHelloMsg(msg)
|
||||
}
|
||||
|
||||
// extractSNI reads a TLS ClientHello from r and returns the SNI hostname.
|
||||
// Returns empty domain if no SNI extension is present.
|
||||
func extractSNI(r io.Reader) (domain.Domain, error) {
|
||||
info, err := parseClientHello(r)
|
||||
return info.SNI, err
|
||||
}
|
||||
|
||||
// extractSNIFromBytes parses SNI from raw bytes that start with the TLS record header.
|
||||
func extractSNIFromBytes(data []byte) (domain.Domain, error) {
|
||||
info, err := parseClientHelloFromBytes(data)
|
||||
return info.SNI, err
|
||||
}
|
||||
|
||||
// parseClientHelloFromBytes parses a ClientHello from raw bytes starting with the TLS record header.
|
||||
func parseClientHelloFromBytes(data []byte) (ClientHelloInfo, error) {
|
||||
if len(data) < 5 {
|
||||
return ClientHelloInfo{}, fmt.Errorf("data too short for TLS record header")
|
||||
}
|
||||
|
||||
if data[0] != recordTypeHandshake {
|
||||
return ClientHelloInfo{}, fmt.Errorf("not a TLS handshake record (type=%d)", data[0])
|
||||
}
|
||||
|
||||
recordLen := int(binary.BigEndian.Uint16(data[3:5]))
|
||||
if recordLen < 4 {
|
||||
return ClientHelloInfo{}, fmt.Errorf("invalid TLS record length: %d", recordLen)
|
||||
}
|
||||
|
||||
end := 5 + recordLen
|
||||
if end > len(data) {
|
||||
return ClientHelloInfo{}, fmt.Errorf("TLS record truncated: need %d, have %d", end, len(data))
|
||||
}
|
||||
|
||||
return parseClientHelloMsg(data[5:end])
|
||||
}
|
||||
|
||||
// parseClientHelloMsg extracts SNI and ALPN from a raw ClientHello handshake message.
|
||||
// msg starts at the handshake type byte.
|
||||
func parseClientHelloMsg(msg []byte) (ClientHelloInfo, error) {
|
||||
if len(msg) < 4 {
|
||||
return ClientHelloInfo{}, fmt.Errorf("handshake message too short")
|
||||
}
|
||||
|
||||
if msg[0] != handshakeTypeClientHello {
|
||||
return ClientHelloInfo{}, fmt.Errorf("not a ClientHello (type=%d)", msg[0])
|
||||
}
|
||||
|
||||
// Handshake header: type(1) + length(3)
|
||||
helloLen := int(msg[1])<<16 | int(msg[2])<<8 | int(msg[3])
|
||||
if helloLen+4 > len(msg) {
|
||||
return ClientHelloInfo{}, fmt.Errorf("ClientHello truncated")
|
||||
}
|
||||
|
||||
hello := msg[4 : 4+helloLen]
|
||||
return parseHelloBody(hello)
|
||||
}
|
||||
|
||||
// parseHelloBody parses the ClientHello body (after handshake header)
|
||||
// and extracts SNI and ALPN.
|
||||
func parseHelloBody(hello []byte) (ClientHelloInfo, error) {
|
||||
// ClientHello structure:
|
||||
// version(2) + random(32) + session_id_len(1) + session_id(var)
|
||||
// + cipher_suites_len(2) + cipher_suites(var)
|
||||
// + compression_len(1) + compression(var)
|
||||
// + extensions_len(2) + extensions(var)
|
||||
|
||||
var info ClientHelloInfo
|
||||
|
||||
if len(hello) < 35 {
|
||||
return info, fmt.Errorf("ClientHello body too short")
|
||||
}
|
||||
|
||||
pos := 2 + 32 // skip version + random
|
||||
|
||||
// Skip session ID
|
||||
if pos >= len(hello) {
|
||||
return info, fmt.Errorf("ClientHello truncated at session ID")
|
||||
}
|
||||
sessionIDLen := int(hello[pos])
|
||||
pos += 1 + sessionIDLen
|
||||
|
||||
// Skip cipher suites
|
||||
if pos+2 > len(hello) {
|
||||
return info, fmt.Errorf("ClientHello truncated at cipher suites")
|
||||
}
|
||||
cipherLen := int(binary.BigEndian.Uint16(hello[pos : pos+2]))
|
||||
pos += 2 + cipherLen
|
||||
|
||||
// Skip compression methods
|
||||
if pos >= len(hello) {
|
||||
return info, fmt.Errorf("ClientHello truncated at compression")
|
||||
}
|
||||
compLen := int(hello[pos])
|
||||
pos += 1 + compLen
|
||||
|
||||
// Extensions
|
||||
if pos+2 > len(hello) {
|
||||
return info, nil
|
||||
}
|
||||
|
||||
extLen := int(binary.BigEndian.Uint16(hello[pos : pos+2]))
|
||||
pos += 2
|
||||
|
||||
extEnd := pos + extLen
|
||||
if extEnd > len(hello) {
|
||||
return info, fmt.Errorf("extensions block truncated")
|
||||
}
|
||||
|
||||
// Walk extensions looking for SNI and ALPN
|
||||
for pos+4 <= extEnd {
|
||||
extType := binary.BigEndian.Uint16(hello[pos : pos+2])
|
||||
extDataLen := int(binary.BigEndian.Uint16(hello[pos+2 : pos+4]))
|
||||
pos += 4
|
||||
|
||||
if pos+extDataLen > extEnd {
|
||||
return info, fmt.Errorf("extension data truncated")
|
||||
}
|
||||
|
||||
switch extType {
|
||||
case extensionTypeSNI:
|
||||
sni, err := parseSNIExtension(hello[pos : pos+extDataLen])
|
||||
if err != nil {
|
||||
return info, err
|
||||
}
|
||||
info.SNI = sni
|
||||
case extensionTypeALPN:
|
||||
info.ALPN = parseALPNExtension(hello[pos : pos+extDataLen])
|
||||
}
|
||||
|
||||
pos += extDataLen
|
||||
}
|
||||
|
||||
return info, nil
|
||||
}
|
||||
|
||||
// parseALPNExtension parses the ALPN extension data and returns protocol names.
|
||||
// ALPN extension: list_length(2) + entries (each: len(1) + protocol_name(var))
|
||||
func parseALPNExtension(data []byte) []string {
|
||||
if len(data) < 2 {
|
||||
return nil
|
||||
}
|
||||
|
||||
listLen := int(binary.BigEndian.Uint16(data[0:2]))
|
||||
if listLen+2 > len(data) {
|
||||
return nil
|
||||
}
|
||||
|
||||
var protocols []string
|
||||
pos := 2
|
||||
end := 2 + listLen
|
||||
|
||||
for pos < end {
|
||||
if pos >= len(data) {
|
||||
break
|
||||
}
|
||||
nameLen := int(data[pos])
|
||||
pos++
|
||||
if pos+nameLen > end {
|
||||
break
|
||||
}
|
||||
protocols = append(protocols, string(data[pos:pos+nameLen]))
|
||||
pos += nameLen
|
||||
}
|
||||
|
||||
return protocols
|
||||
}
|
||||
|
||||
// parseSNIExtension parses the SNI extension data and returns the hostname.
|
||||
func parseSNIExtension(data []byte) (domain.Domain, error) {
|
||||
// SNI extension: list_length(2) + entries
|
||||
if len(data) < 2 {
|
||||
return "", fmt.Errorf("SNI extension too short")
|
||||
}
|
||||
|
||||
listLen := int(binary.BigEndian.Uint16(data[0:2]))
|
||||
if listLen+2 > len(data) {
|
||||
return "", fmt.Errorf("SNI list truncated")
|
||||
}
|
||||
|
||||
pos := 2
|
||||
end := 2 + listLen
|
||||
|
||||
for pos+3 <= end {
|
||||
nameType := data[pos]
|
||||
nameLen := int(binary.BigEndian.Uint16(data[pos+1 : pos+3]))
|
||||
pos += 3
|
||||
|
||||
if pos+nameLen > end {
|
||||
return "", fmt.Errorf("SNI name truncated")
|
||||
}
|
||||
|
||||
if nameType == sniTypeHostName {
|
||||
hostname := string(data[pos : pos+nameLen])
|
||||
return domain.FromString(hostname)
|
||||
}
|
||||
|
||||
pos += nameLen
|
||||
}
|
||||
|
||||
return "", nil
|
||||
}
|
||||
@@ -1,109 +0,0 @@
|
||||
package inspect
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/tls"
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestExtractSNI(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
sni string
|
||||
wantSNI string
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "standard domain",
|
||||
sni: "example.com",
|
||||
wantSNI: "example.com",
|
||||
},
|
||||
{
|
||||
name: "subdomain",
|
||||
sni: "api.staging.example.com",
|
||||
wantSNI: "api.staging.example.com",
|
||||
},
|
||||
{
|
||||
name: "mixed case normalized to lowercase",
|
||||
sni: "Example.COM",
|
||||
wantSNI: "example.com",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
clientHello := buildClientHello(t, tt.sni)
|
||||
|
||||
sni, err := extractSNI(bytes.NewReader(clientHello))
|
||||
if tt.wantErr {
|
||||
require.Error(t, err)
|
||||
return
|
||||
}
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, tt.wantSNI, sni.PunycodeString())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractSNI_NotTLS(t *testing.T) {
|
||||
// HTTP request instead of TLS
|
||||
data := []byte("GET / HTTP/1.1\r\nHost: example.com\r\n\r\n")
|
||||
_, err := extractSNI(bytes.NewReader(data))
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "not a TLS handshake")
|
||||
}
|
||||
|
||||
func TestExtractSNI_Truncated(t *testing.T) {
|
||||
// Just the record header, no body
|
||||
data := []byte{0x16, 0x03, 0x01, 0x00, 0x05}
|
||||
_, err := extractSNI(bytes.NewReader(data))
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestExtractSNIFromBytes(t *testing.T) {
|
||||
clientHello := buildClientHello(t, "test.example.com")
|
||||
|
||||
sni, err := extractSNIFromBytes(clientHello)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "test.example.com", sni.PunycodeString())
|
||||
}
|
||||
|
||||
// buildClientHello generates a real TLS ClientHello with the given SNI.
|
||||
func buildClientHello(t *testing.T, serverName string) []byte {
|
||||
t.Helper()
|
||||
|
||||
// Use a pipe to capture the ClientHello bytes
|
||||
clientConn, serverConn := net.Pipe()
|
||||
|
||||
done := make(chan []byte, 1)
|
||||
go func() {
|
||||
buf := make([]byte, 4096)
|
||||
n, _ := serverConn.Read(buf)
|
||||
done <- buf[:n]
|
||||
serverConn.Close()
|
||||
}()
|
||||
|
||||
tlsConn := tls.Client(clientConn, &tls.Config{
|
||||
ServerName: serverName,
|
||||
InsecureSkipVerify: true,
|
||||
})
|
||||
|
||||
// Trigger the handshake (will fail since server isn't TLS, but we capture the ClientHello)
|
||||
go func() {
|
||||
_ = tlsConn.Handshake()
|
||||
tlsConn.Close()
|
||||
}()
|
||||
|
||||
clientHello := <-done
|
||||
clientConn.Close()
|
||||
|
||||
require.True(t, len(clientHello) > 5, "ClientHello too short")
|
||||
require.Equal(t, byte(0x16), clientHello[0], "not a TLS handshake record")
|
||||
|
||||
return clientHello
|
||||
}
|
||||
@@ -1,287 +0,0 @@
|
||||
package inspect
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/netip"
|
||||
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
)
|
||||
|
||||
// handleTLS processes a TLS connection for the kernel-mode path: extracts SNI,
|
||||
// evaluates rules, and handles the connection internally.
|
||||
// In envoy mode, allowed connections are forwarded to envoy instead of direct relay.
|
||||
func (p *Proxy) handleTLS(ctx context.Context, pconn *peekConn, dst netip.AddrPort, src SourceInfo) error {
|
||||
result, err := p.inspectTLS(ctx, pconn, dst, src)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if result.PassthroughConn != nil {
|
||||
p.mu.RLock()
|
||||
envoy := p.envoy
|
||||
p.mu.RUnlock()
|
||||
|
||||
if envoy != nil {
|
||||
return p.forwardToEnvoy(ctx, pconn, dst, src, envoy)
|
||||
}
|
||||
return p.tlsPassthrough(ctx, pconn, dst, "")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// inspectTLS extracts SNI, evaluates rules, and returns the result.
|
||||
// For ActionAllow: returns the peekConn as PassthroughConn (caller relays).
|
||||
// For ActionBlock/ActionInspect: handles internally and returns nil PassthroughConn.
|
||||
func (p *Proxy) inspectTLS(ctx context.Context, pconn *peekConn, dst netip.AddrPort, src SourceInfo) (InspectResult, error) {
|
||||
// The first 5 bytes (TLS record header) are already peeked.
|
||||
// Extend to read the full TLS record so bytes remain in the buffer for passthrough.
|
||||
peeked := pconn.Peeked()
|
||||
recordLen := int(peeked[3])<<8 | int(peeked[4])
|
||||
if _, err := pconn.PeekMore(5 + recordLen); err != nil {
|
||||
return InspectResult{}, fmt.Errorf("read TLS record: %w", err)
|
||||
}
|
||||
|
||||
hello, err := parseClientHelloFromBytes(pconn.Peeked())
|
||||
if err != nil {
|
||||
return InspectResult{}, fmt.Errorf("parse ClientHello: %w", err)
|
||||
}
|
||||
|
||||
sni := hello.SNI
|
||||
proto := protoFromALPN(hello.ALPN)
|
||||
// Connection-level evaluation: pass empty path.
|
||||
action := p.evaluateAction(src.IP, sni, dst, proto, "")
|
||||
|
||||
// If any rule for this domain has path patterns, force inspect so paths can
|
||||
// be checked per-request after MITM decryption.
|
||||
if action == ActionAllow && p.rules.HasPathRulesForDomain(sni) {
|
||||
p.log.Debugf("upgrading to inspect for %s (path rules exist)", sni.PunycodeString())
|
||||
action = ActionInspect
|
||||
}
|
||||
|
||||
// Snapshot cert provider under lock for use in this connection.
|
||||
p.mu.RLock()
|
||||
certs := p.certs
|
||||
p.mu.RUnlock()
|
||||
|
||||
switch action {
|
||||
case ActionBlock:
|
||||
p.log.Debugf("block: TLS to %s (SNI=%s)", dst, sni.PunycodeString())
|
||||
if certs != nil {
|
||||
return InspectResult{Action: ActionBlock}, p.tlsBlockPage(ctx, pconn, sni, certs)
|
||||
}
|
||||
return InspectResult{Action: ActionBlock}, ErrBlocked
|
||||
|
||||
case ActionAllow:
|
||||
p.log.Tracef("allow: TLS passthrough to %s (SNI=%s)", dst, sni.PunycodeString())
|
||||
return InspectResult{Action: ActionAllow, PassthroughConn: pconn}, nil
|
||||
|
||||
case ActionInspect:
|
||||
if certs == nil {
|
||||
p.log.Warnf("allow: %s (inspect requested but no MITM CA configured)", sni.PunycodeString())
|
||||
return InspectResult{Action: ActionAllow, PassthroughConn: pconn}, nil
|
||||
}
|
||||
err := p.tlsMITM(ctx, pconn, dst, sni, src, certs)
|
||||
return InspectResult{Action: ActionInspect}, err
|
||||
|
||||
default:
|
||||
p.log.Warnf("block: unknown action %q for %s", action, sni.PunycodeString())
|
||||
return InspectResult{Action: ActionBlock}, ErrBlocked
|
||||
}
|
||||
}
|
||||
|
||||
// tlsBlockPage completes a MITM TLS handshake with the client using a dynamic
|
||||
// certificate, then serves an HTTP 403 block page so the user sees a clear
|
||||
// message instead of a cryptic SSL error.
|
||||
func (p *Proxy) tlsBlockPage(ctx context.Context, pconn *peekConn, sni domain.Domain, certs *CertProvider) error {
|
||||
hostname := sni.PunycodeString()
|
||||
|
||||
// Force HTTP/1.1 only: block pages are simple responses, no need for h2
|
||||
tlsCfg := certs.GetTLSConfig()
|
||||
tlsCfg.NextProtos = []string{"http/1.1"}
|
||||
clientTLS := tls.Server(pconn, tlsCfg)
|
||||
if err := clientTLS.HandshakeContext(ctx); err != nil {
|
||||
// Client may not trust our CA, handshake fails. That's expected.
|
||||
return fmt.Errorf("block page TLS handshake for %s: %w", hostname, err)
|
||||
}
|
||||
defer func() {
|
||||
if err := clientTLS.Close(); err != nil {
|
||||
p.log.Debugf("close block page TLS for %s: %v", hostname, err)
|
||||
}
|
||||
}()
|
||||
|
||||
writeBlockResponse(clientTLS, nil, sni)
|
||||
return ErrBlocked
|
||||
}
|
||||
|
||||
// tlsPassthrough connects to the destination and relays encrypted traffic
|
||||
// without decryption. The peeked ClientHello bytes are replayed.
|
||||
func (p *Proxy) tlsPassthrough(ctx context.Context, pconn *peekConn, dst netip.AddrPort, sni domain.Domain) error {
|
||||
remote, err := p.dialTCP(ctx, dst)
|
||||
if err != nil {
|
||||
return fmt.Errorf("dial %s: %w", dst, err)
|
||||
}
|
||||
defer func() {
|
||||
if err := remote.Close(); err != nil {
|
||||
p.log.Debugf("close remote for %s: %v", dst, err)
|
||||
}
|
||||
}()
|
||||
|
||||
p.log.Tracef("allow: TLS passthrough to %s (SNI=%s)", dst, sni.PunycodeString())
|
||||
|
||||
return relay(ctx, pconn, remote)
|
||||
}
|
||||
|
||||
// tlsMITM terminates the client TLS connection with a dynamic certificate,
|
||||
// establishes a new TLS connection to the real destination, and runs the
|
||||
// HTTP inspection pipeline on the decrypted traffic.
|
||||
func (p *Proxy) tlsMITM(ctx context.Context, pconn *peekConn, dst netip.AddrPort, sni domain.Domain, src SourceInfo, certs *CertProvider) error {
|
||||
hostname := sni.PunycodeString()
|
||||
|
||||
// TLS handshake with client using dynamic cert
|
||||
clientTLS := tls.Server(pconn, certs.GetTLSConfig())
|
||||
if err := clientTLS.HandshakeContext(ctx); err != nil {
|
||||
return fmt.Errorf("client TLS handshake for %s: %w", hostname, err)
|
||||
}
|
||||
defer func() {
|
||||
if err := clientTLS.Close(); err != nil {
|
||||
p.log.Debugf("close client TLS for %s: %v", hostname, err)
|
||||
}
|
||||
}()
|
||||
|
||||
// TLS connection to real destination
|
||||
remoteTLS, err := p.dialTLS(ctx, dst, hostname)
|
||||
if err != nil {
|
||||
return fmt.Errorf("dial TLS %s (%s): %w", dst, hostname, err)
|
||||
}
|
||||
defer func() {
|
||||
if err := remoteTLS.Close(); err != nil {
|
||||
p.log.Debugf("close remote TLS for %s: %v", hostname, err)
|
||||
}
|
||||
}()
|
||||
|
||||
negotiatedProto := clientTLS.ConnectionState().NegotiatedProtocol
|
||||
p.log.Tracef("inspect: MITM established for %s (proto=%s)", hostname, negotiatedProto)
|
||||
|
||||
return p.inspectHTTP(ctx, clientTLS, remoteTLS, dst, sni, src, negotiatedProto)
|
||||
}
|
||||
|
||||
// dialTLS connects to the destination with TLS, verifying the real server certificate.
|
||||
func (p *Proxy) dialTLS(ctx context.Context, dst netip.AddrPort, serverName string) (net.Conn, error) {
|
||||
rawConn, err := p.dialTCP(ctx, dst)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
tlsConn := tls.Client(rawConn, &tls.Config{
|
||||
ServerName: serverName,
|
||||
NextProtos: []string{"h2", "http/1.1"},
|
||||
MinVersion: tls.VersionTLS12,
|
||||
})
|
||||
|
||||
if err := tlsConn.HandshakeContext(ctx); err != nil {
|
||||
if closeErr := rawConn.Close(); closeErr != nil {
|
||||
p.log.Debugf("close raw conn after TLS handshake failure: %v", closeErr)
|
||||
}
|
||||
return nil, fmt.Errorf("TLS handshake with %s: %w", serverName, err)
|
||||
}
|
||||
|
||||
return tlsConn, nil
|
||||
}
|
||||
|
||||
// protoFromALPN maps TLS ALPN protocol names to proxy ProtoType.
|
||||
// Falls back to ProtoHTTPS when no recognized ALPN is present.
|
||||
func protoFromALPN(alpn []string) ProtoType {
|
||||
for _, p := range alpn {
|
||||
switch p {
|
||||
case "h2":
|
||||
return ProtoH2
|
||||
case "h3": // unlikely in TLS, but handle anyway
|
||||
return ProtoH3
|
||||
}
|
||||
}
|
||||
// No ALPN or only "http/1.1": treat as HTTPS
|
||||
return ProtoHTTPS
|
||||
}
|
||||
|
||||
// relay copies data bidirectionally between client and remote until one
|
||||
// side closes or the context is cancelled.
|
||||
func relay(ctx context.Context, client, remote net.Conn) error {
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
|
||||
errCh := make(chan error, 2)
|
||||
|
||||
go func() {
|
||||
_, err := io.Copy(remote, client)
|
||||
cancel()
|
||||
errCh <- err
|
||||
}()
|
||||
|
||||
go func() {
|
||||
_, err := io.Copy(client, remote)
|
||||
cancel()
|
||||
errCh <- err
|
||||
}()
|
||||
|
||||
var firstErr error
|
||||
for range 2 {
|
||||
if err := <-errCh; err != nil && firstErr == nil {
|
||||
if !isClosedErr(err) {
|
||||
firstErr = err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return firstErr
|
||||
}
|
||||
|
||||
// evaluateAction runs rule evaluation and resolves the effective action.
|
||||
// Pass empty path for connection-level (TLS), non-empty for request-level (HTTP).
|
||||
func (p *Proxy) evaluateAction(src netip.Addr, sni domain.Domain, dst netip.AddrPort, proto ProtoType, path string) Action {
|
||||
return p.rules.Evaluate(src, sni, dst.Addr(), dst.Port(), proto, path)
|
||||
}
|
||||
|
||||
// dialTCP dials the destination, blocking connections to loopback, link-local,
|
||||
// multicast, and WG overlay network addresses.
|
||||
func (p *Proxy) dialTCP(ctx context.Context, dst netip.AddrPort) (net.Conn, error) {
|
||||
ip := dst.Addr().Unmap()
|
||||
if err := p.validateDialTarget(ip); err != nil {
|
||||
return nil, fmt.Errorf("dial %s: %w", dst, err)
|
||||
}
|
||||
return p.dialer.DialContext(ctx, "tcp", dst.String())
|
||||
}
|
||||
|
||||
// validateDialTarget blocks destinations that should never be dialed by the proxy.
|
||||
// Mirrors the route validation in systemops.validateRoute.
|
||||
func (p *Proxy) validateDialTarget(addr netip.Addr) error {
|
||||
switch {
|
||||
case !addr.IsValid():
|
||||
return fmt.Errorf("invalid address")
|
||||
case addr.IsLoopback():
|
||||
return fmt.Errorf("loopback address not allowed")
|
||||
case addr.IsLinkLocalUnicast(), addr.IsLinkLocalMulticast(), addr.IsInterfaceLocalMulticast():
|
||||
return fmt.Errorf("link-local address not allowed")
|
||||
case addr.IsMulticast():
|
||||
return fmt.Errorf("multicast address not allowed")
|
||||
case p.wgNetwork.IsValid() && p.wgNetwork.Contains(addr):
|
||||
return fmt.Errorf("overlay network address not allowed")
|
||||
case p.localIPs != nil && p.localIPs.IsLocalIP(addr):
|
||||
return fmt.Errorf("local address not allowed")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func isClosedErr(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
return err == io.EOF ||
|
||||
err == io.ErrClosedPipe ||
|
||||
err == net.ErrClosed ||
|
||||
err == context.Canceled
|
||||
}
|
||||
@@ -315,6 +315,7 @@ func (a *Auth) setSystemInfoFlags(info *system.Info) {
|
||||
a.config.RosenpassEnabled,
|
||||
a.config.RosenpassPermissive,
|
||||
a.config.ServerSSHAllowed,
|
||||
a.config.ServerVNCAllowed,
|
||||
a.config.DisableClientRoutes,
|
||||
a.config.DisableServerRoutes,
|
||||
a.config.DisableDNS,
|
||||
@@ -327,6 +328,7 @@ func (a *Auth) setSystemInfoFlags(info *system.Info) {
|
||||
a.config.EnableSSHLocalPortForwarding,
|
||||
a.config.EnableSSHRemotePortForwarding,
|
||||
a.config.DisableSSHAuth,
|
||||
a.config.DisableVNCAuth,
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
@@ -543,11 +543,13 @@ func createEngineConfig(key wgtypes.Key, config *profilemanager.Config, peerConf
|
||||
RosenpassEnabled: config.RosenpassEnabled,
|
||||
RosenpassPermissive: config.RosenpassPermissive,
|
||||
ServerSSHAllowed: util.ReturnBoolWithDefaultTrue(config.ServerSSHAllowed),
|
||||
ServerVNCAllowed: config.ServerVNCAllowed != nil && *config.ServerVNCAllowed,
|
||||
EnableSSHRoot: config.EnableSSHRoot,
|
||||
EnableSSHSFTP: config.EnableSSHSFTP,
|
||||
EnableSSHLocalPortForwarding: config.EnableSSHLocalPortForwarding,
|
||||
EnableSSHRemotePortForwarding: config.EnableSSHRemotePortForwarding,
|
||||
DisableSSHAuth: config.DisableSSHAuth,
|
||||
DisableVNCAuth: config.DisableVNCAuth,
|
||||
DNSRouteInterval: config.DNSRouteInterval,
|
||||
|
||||
DisableClientRoutes: config.DisableClientRoutes,
|
||||
@@ -562,9 +564,6 @@ func createEngineConfig(key wgtypes.Key, config *profilemanager.Config, peerConf
|
||||
MTU: selectMTU(config.MTU, peerConfig.Mtu),
|
||||
LogPath: logPath,
|
||||
|
||||
InspectionCACertPath: config.InspectionCACertPath,
|
||||
InspectionCAKeyPath: config.InspectionCAKeyPath,
|
||||
|
||||
ProfileConfig: config,
|
||||
}
|
||||
|
||||
@@ -627,6 +626,7 @@ func loginToManagement(ctx context.Context, client mgm.Client, pubSSHKey []byte,
|
||||
config.RosenpassEnabled,
|
||||
config.RosenpassPermissive,
|
||||
config.ServerSSHAllowed,
|
||||
config.ServerVNCAllowed,
|
||||
config.DisableClientRoutes,
|
||||
config.DisableServerRoutes,
|
||||
config.DisableDNS,
|
||||
@@ -639,6 +639,7 @@ func loginToManagement(ctx context.Context, client mgm.Client, pubSSHKey []byte,
|
||||
config.EnableSSHLocalPortForwarding,
|
||||
config.EnableSSHRemotePortForwarding,
|
||||
config.DisableSSHAuth,
|
||||
config.DisableVNCAuth,
|
||||
)
|
||||
return client.Login(sysInfo, pubSSHKey, config.DNSLabels)
|
||||
}
|
||||
|
||||
@@ -31,7 +31,6 @@ import (
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
nbnetstack "github.com/netbirdio/netbird/client/iface/netstack"
|
||||
"github.com/netbirdio/netbird/client/iface/udpmux"
|
||||
"github.com/netbirdio/netbird/client/inspect"
|
||||
"github.com/netbirdio/netbird/client/internal/acl"
|
||||
"github.com/netbirdio/netbird/client/internal/debug"
|
||||
"github.com/netbirdio/netbird/client/internal/dns"
|
||||
@@ -118,11 +117,13 @@ type EngineConfig struct {
|
||||
RosenpassPermissive bool
|
||||
|
||||
ServerSSHAllowed bool
|
||||
ServerVNCAllowed bool
|
||||
EnableSSHRoot *bool
|
||||
EnableSSHSFTP *bool
|
||||
EnableSSHLocalPortForwarding *bool
|
||||
EnableSSHRemotePortForwarding *bool
|
||||
DisableSSHAuth *bool
|
||||
DisableVNCAuth *bool
|
||||
|
||||
DNSRouteInterval time.Duration
|
||||
|
||||
@@ -137,12 +138,6 @@ type EngineConfig struct {
|
||||
|
||||
MTU uint16
|
||||
|
||||
// InspectionCACertPath is a local CA cert for transparent proxy MITM.
|
||||
// Takes priority over management-pushed CA.
|
||||
InspectionCACertPath string
|
||||
// InspectionCAKeyPath is the corresponding private key.
|
||||
InspectionCAKeyPath string
|
||||
|
||||
// for debug bundle generation
|
||||
ProfileConfig *profilemanager.Config
|
||||
|
||||
@@ -204,6 +199,7 @@ type Engine struct {
|
||||
networkMonitor *networkmonitor.NetworkMonitor
|
||||
|
||||
sshServer sshServer
|
||||
vncSrv vncServer
|
||||
|
||||
statusRecorder *peer.Status
|
||||
|
||||
@@ -229,10 +225,6 @@ type Engine struct {
|
||||
latestSyncResponse *mgmProto.SyncResponse
|
||||
flowManager nftypes.FlowManager
|
||||
|
||||
// transparentProxy is the transparent forward proxy for traffic inspection.
|
||||
transparentProxy *inspect.Proxy
|
||||
udpInspectionHookID string
|
||||
|
||||
// auto-update
|
||||
updateManager *updater.Manager
|
||||
|
||||
@@ -321,6 +313,10 @@ func (e *Engine) Stop() error {
|
||||
log.Warnf("failed to stop SSH server: %v", err)
|
||||
}
|
||||
|
||||
if err := e.stopVNCServer(); err != nil {
|
||||
log.Warnf("failed to stop VNC server: %v", err)
|
||||
}
|
||||
|
||||
e.cleanupSSHConfig()
|
||||
|
||||
if e.ingressGatewayMgr != nil {
|
||||
@@ -1008,6 +1004,7 @@ func (e *Engine) updateChecksIfNew(checks []*mgmProto.Checks) error {
|
||||
e.config.RosenpassEnabled,
|
||||
e.config.RosenpassPermissive,
|
||||
&e.config.ServerSSHAllowed,
|
||||
&e.config.ServerVNCAllowed,
|
||||
e.config.DisableClientRoutes,
|
||||
e.config.DisableServerRoutes,
|
||||
e.config.DisableDNS,
|
||||
@@ -1020,6 +1017,7 @@ func (e *Engine) updateChecksIfNew(checks []*mgmProto.Checks) error {
|
||||
e.config.EnableSSHLocalPortForwarding,
|
||||
e.config.EnableSSHRemotePortForwarding,
|
||||
e.config.DisableSSHAuth,
|
||||
e.config.DisableVNCAuth,
|
||||
)
|
||||
|
||||
if err := e.mgmClient.SyncMeta(info); err != nil {
|
||||
@@ -1047,6 +1045,10 @@ func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error {
|
||||
}
|
||||
}
|
||||
|
||||
if err := e.updateVNC(conf.GetSshConfig()); err != nil {
|
||||
log.Warnf("failed handling VNC server setup: %v", err)
|
||||
}
|
||||
|
||||
state := e.statusRecorder.GetLocalPeerState()
|
||||
state.IP = e.wgInterface.Address().String()
|
||||
state.PubKey = e.config.WgPrivateKey.PublicKey().String()
|
||||
@@ -1148,6 +1150,7 @@ func (e *Engine) receiveManagementEvents() {
|
||||
e.config.RosenpassEnabled,
|
||||
e.config.RosenpassPermissive,
|
||||
&e.config.ServerSSHAllowed,
|
||||
&e.config.ServerVNCAllowed,
|
||||
e.config.DisableClientRoutes,
|
||||
e.config.DisableServerRoutes,
|
||||
e.config.DisableDNS,
|
||||
@@ -1160,6 +1163,7 @@ func (e *Engine) receiveManagementEvents() {
|
||||
e.config.EnableSSHLocalPortForwarding,
|
||||
e.config.EnableSSHRemotePortForwarding,
|
||||
e.config.DisableSSHAuth,
|
||||
e.config.DisableVNCAuth,
|
||||
)
|
||||
|
||||
err = e.mgmClient.Sync(e.ctx, info, e.handleSync)
|
||||
@@ -1283,9 +1287,6 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
|
||||
fwdEntries := toRouteDomains(e.config.WgPrivateKey.PublicKey().String(), routes)
|
||||
e.updateDNSForwarder(dnsRouteFeatureFlag, fwdEntries)
|
||||
|
||||
// Transparent proxy
|
||||
e.updateTransparentProxy(networkMap.GetTransparentProxyConfig())
|
||||
|
||||
// Ingress forward rules
|
||||
forwardingRules, err := e.updateForwardRules(networkMap.GetForwardingRules())
|
||||
if err != nil {
|
||||
@@ -1337,6 +1338,11 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
|
||||
}
|
||||
|
||||
e.updateSSHServerAuth(networkMap.GetSshAuth())
|
||||
|
||||
// VNC auth: use dedicated VNCAuth if present.
|
||||
if vncAuth := networkMap.GetVncAuth(); vncAuth != nil {
|
||||
e.updateVNCServerAuth(vncAuth)
|
||||
}
|
||||
}
|
||||
|
||||
// must set the exclude list after the peers are added. Without it the manager can not figure out the peers parameters from the store
|
||||
@@ -1709,8 +1715,6 @@ func (e *Engine) parseNATExternalIPMappings() []string {
|
||||
func (e *Engine) close() {
|
||||
log.Debugf("removing Netbird interface %s", e.config.WgIfaceName)
|
||||
|
||||
e.stopTransparentProxy()
|
||||
|
||||
if e.wgInterface != nil {
|
||||
if err := e.wgInterface.Close(); err != nil {
|
||||
log.Errorf("failed closing Netbird interface %s %v", e.config.WgIfaceName, err)
|
||||
@@ -1748,6 +1752,7 @@ func (e *Engine) readInitialSettings() ([]*route.Route, *nbdns.Config, bool, err
|
||||
e.config.RosenpassEnabled,
|
||||
e.config.RosenpassPermissive,
|
||||
&e.config.ServerSSHAllowed,
|
||||
&e.config.ServerVNCAllowed,
|
||||
e.config.DisableClientRoutes,
|
||||
e.config.DisableServerRoutes,
|
||||
e.config.DisableDNS,
|
||||
@@ -1760,6 +1765,7 @@ func (e *Engine) readInitialSettings() ([]*route.Route, *nbdns.Config, bool, err
|
||||
e.config.EnableSSHLocalPortForwarding,
|
||||
e.config.EnableSSHRemotePortForwarding,
|
||||
e.config.DisableSSHAuth,
|
||||
e.config.DisableVNCAuth,
|
||||
)
|
||||
|
||||
netMap, err := e.mgmClient.GetNetworkMap(info)
|
||||
|
||||
@@ -1,571 +0,0 @@
|
||||
package internal
|
||||
|
||||
import (
|
||||
"crypto/ecdsa"
|
||||
"crypto/elliptic"
|
||||
"crypto/rand"
|
||||
"crypto/x509"
|
||||
"crypto/x509/pkix"
|
||||
"encoding/pem"
|
||||
"fmt"
|
||||
"math/big"
|
||||
"net/netip"
|
||||
"net/url"
|
||||
"os"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/firewall/uspfilter/forwarder"
|
||||
"github.com/netbirdio/netbird/client/inspect"
|
||||
"github.com/netbirdio/netbird/client/internal/acl/id"
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
|
||||
)
|
||||
|
||||
// updateTransparentProxy processes transparent proxy configuration from the network map.
|
||||
func (e *Engine) updateTransparentProxy(cfg *mgmProto.TransparentProxyConfig) {
|
||||
if cfg == nil || !cfg.Enabled {
|
||||
if cfg == nil {
|
||||
log.Tracef("inspect: config is nil")
|
||||
} else {
|
||||
log.Tracef("inspect: config disabled")
|
||||
}
|
||||
// Only stop if explicitly disabled. Don't stop on nil config to avoid
|
||||
// a gap during policy edits where management briefly pushes empty config.
|
||||
if cfg != nil && !cfg.Enabled {
|
||||
e.stopTransparentProxy()
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
log.Debugf("inspect: config received: enabled=%v mode=%v default_action=%v rules=%d has_ca=%v",
|
||||
cfg.Enabled, cfg.Mode, cfg.DefaultAction, len(cfg.Rules), len(cfg.CaCertPem) > 0)
|
||||
|
||||
// BlockInbound prevents adding TPROXY rules since kernel TPROXY bypasses ACLs.
|
||||
// The userspace forwarder path still works as it operates within the forwarder hook.
|
||||
if e.config.BlockInbound {
|
||||
log.Warnf("inspect: BlockInbound is set, skipping redirect rules (userspace path still active)")
|
||||
}
|
||||
|
||||
proxyConfig, err := toProxyConfig(cfg)
|
||||
if err != nil {
|
||||
log.Errorf("inspect: parse config: %v", err)
|
||||
e.stopTransparentProxy()
|
||||
return
|
||||
}
|
||||
|
||||
// CA priority: local config > management-pushed > auto-generated self-signed.
|
||||
// Local wins over mgmt to prevent compromised management from injecting a CA.
|
||||
e.resolveInspectionCA(&proxyConfig)
|
||||
|
||||
if e.transparentProxy != nil {
|
||||
// Mode change requires full recreate (envoy lifecycle, listener changes).
|
||||
if proxyConfig.Mode != e.transparentProxy.Mode() {
|
||||
log.Infof("inspect: mode changed to %s, recreating engine", proxyConfig.Mode)
|
||||
e.stopTransparentProxy()
|
||||
} else {
|
||||
e.transparentProxy.UpdateConfig(proxyConfig)
|
||||
e.syncTProxyRules(proxyConfig)
|
||||
e.syncUDPInspectionHook()
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if e.wgInterface != nil {
|
||||
proxyConfig.WGNetwork = e.wgInterface.Address().Network
|
||||
proxyConfig.ListenAddr = netip.AddrPortFrom(
|
||||
e.wgInterface.Address().IP.Unmap(),
|
||||
proxyConfig.ListenAddr.Port(),
|
||||
)
|
||||
}
|
||||
|
||||
// Pass local IP checker for SSRF prevention
|
||||
if checker, ok := e.firewall.(inspect.LocalIPChecker); ok {
|
||||
proxyConfig.LocalIPChecker = checker
|
||||
}
|
||||
|
||||
p, err := inspect.New(e.ctx, log.WithField("component", "inspect"), proxyConfig)
|
||||
if err != nil {
|
||||
log.Errorf("inspect: start engine: %v", err)
|
||||
return
|
||||
}
|
||||
e.transparentProxy = p
|
||||
|
||||
e.attachProxyToForwarder(p)
|
||||
e.syncTProxyRules(proxyConfig)
|
||||
e.syncUDPInspectionHook()
|
||||
|
||||
log.Infof("inspect: engine started (mode=%s, rules=%d)", proxyConfig.Mode, len(proxyConfig.Rules))
|
||||
}
|
||||
|
||||
// stopTransparentProxy shuts down the transparent proxy and removes interception.
|
||||
func (e *Engine) stopTransparentProxy() {
|
||||
if e.transparentProxy == nil {
|
||||
return
|
||||
}
|
||||
|
||||
e.attachProxyToForwarder(nil)
|
||||
e.removeTProxyRule()
|
||||
e.removeUDPInspectionHook()
|
||||
|
||||
if err := e.transparentProxy.Close(); err != nil {
|
||||
log.Debugf("inspect: close engine: %v", err)
|
||||
}
|
||||
e.transparentProxy = nil
|
||||
|
||||
log.Info("inspect: engine stopped")
|
||||
}
|
||||
|
||||
const tproxyRuleID = "tproxy-redirect"
|
||||
|
||||
// syncTProxyRules adds a TPROXY rule via the firewall manager to intercept
|
||||
// matching traffic on the WG interface and redirect it to the proxy socket.
|
||||
func (e *Engine) syncTProxyRules(config inspect.Config) {
|
||||
if e.config.BlockInbound {
|
||||
e.removeTProxyRule()
|
||||
return
|
||||
}
|
||||
|
||||
var listenPort uint16
|
||||
if e.transparentProxy != nil {
|
||||
listenPort = e.transparentProxy.ListenPort()
|
||||
}
|
||||
if listenPort == 0 {
|
||||
e.removeTProxyRule()
|
||||
return
|
||||
}
|
||||
|
||||
if e.firewall == nil {
|
||||
return
|
||||
}
|
||||
|
||||
dstPorts := make([]uint16, len(config.RedirectPorts))
|
||||
copy(dstPorts, config.RedirectPorts)
|
||||
|
||||
log.Debugf("inspect: syncing redirect rules: listen port %d, redirect ports %v, sources %v",
|
||||
listenPort, dstPorts, config.RedirectSources)
|
||||
|
||||
if err := e.firewall.AddTProxyRule(tproxyRuleID, config.RedirectSources, dstPorts, listenPort); err != nil {
|
||||
log.Errorf("inspect: add redirect rule: %v", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// removeTProxyRule removes the TPROXY redirect rule.
|
||||
func (e *Engine) removeTProxyRule() {
|
||||
if e.firewall == nil {
|
||||
return
|
||||
}
|
||||
if err := e.firewall.RemoveTProxyRule(tproxyRuleID); err != nil {
|
||||
log.Debugf("inspect: remove redirect rule: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// syncUDPInspectionHook registers a UDP packet hook on port 443 for QUIC SNI blocking.
|
||||
// The hook is called by the USP filter for each UDP packet matching the port,
|
||||
// allowing the inspection engine to extract QUIC SNI and block by domain.
|
||||
func (e *Engine) syncUDPInspectionHook() {
|
||||
e.removeUDPInspectionHook()
|
||||
|
||||
if e.firewall == nil || e.transparentProxy == nil {
|
||||
return
|
||||
}
|
||||
|
||||
p := e.transparentProxy
|
||||
hookID := e.firewall.AddUDPInspectionHook(443, func(packet []byte) bool {
|
||||
srcIP, dstIP, dstPort, udpPayload, ok := parseUDPPacket(packet)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
|
||||
src := inspect.SourceInfo{IP: srcIP}
|
||||
dst := netip.AddrPortFrom(dstIP, dstPort)
|
||||
action := p.HandleUDPPacket(udpPayload, dst, src)
|
||||
return action == inspect.ActionBlock
|
||||
})
|
||||
|
||||
e.udpInspectionHookID = hookID
|
||||
log.Debugf("inspect: registered UDP inspection hook on port 443 (id=%s)", hookID)
|
||||
}
|
||||
|
||||
// removeUDPInspectionHook removes the QUIC inspection hook.
|
||||
func (e *Engine) removeUDPInspectionHook() {
|
||||
if e.udpInspectionHookID == "" || e.firewall == nil {
|
||||
return
|
||||
}
|
||||
e.firewall.RemoveUDPInspectionHook(e.udpInspectionHookID)
|
||||
e.udpInspectionHookID = ""
|
||||
}
|
||||
|
||||
// parseUDPPacket extracts source/destination IP, destination port, and UDP
|
||||
// payload from a raw IP packet. Supports both IPv4 and IPv6.
|
||||
func parseUDPPacket(packet []byte) (srcIP, dstIP netip.Addr, dstPort uint16, payload []byte, ok bool) {
|
||||
if len(packet) < 1 {
|
||||
return srcIP, dstIP, 0, nil, false
|
||||
}
|
||||
|
||||
version := packet[0] >> 4
|
||||
|
||||
var udpOffset int
|
||||
switch version {
|
||||
case 4:
|
||||
if len(packet) < 20 {
|
||||
return srcIP, dstIP, 0, nil, false
|
||||
}
|
||||
ihl := int(packet[0]&0x0f) * 4
|
||||
if len(packet) < ihl+8 {
|
||||
return srcIP, dstIP, 0, nil, false
|
||||
}
|
||||
var srcOK, dstOK bool
|
||||
srcIP, srcOK = netip.AddrFromSlice(packet[12:16])
|
||||
dstIP, dstOK = netip.AddrFromSlice(packet[16:20])
|
||||
if !srcOK || !dstOK {
|
||||
return srcIP, dstIP, 0, nil, false
|
||||
}
|
||||
udpOffset = ihl
|
||||
|
||||
case 6:
|
||||
// IPv6 fixed header is 40 bytes. Next header must be UDP (17).
|
||||
if len(packet) < 48 { // 40 header + 8 UDP
|
||||
return srcIP, dstIP, 0, nil, false
|
||||
}
|
||||
nextHeader := packet[6]
|
||||
if nextHeader != 17 { // not UDP (may have extension headers)
|
||||
return srcIP, dstIP, 0, nil, false
|
||||
}
|
||||
var srcOK, dstOK bool
|
||||
srcIP, srcOK = netip.AddrFromSlice(packet[8:24])
|
||||
dstIP, dstOK = netip.AddrFromSlice(packet[24:40])
|
||||
if !srcOK || !dstOK {
|
||||
return srcIP, dstIP, 0, nil, false
|
||||
}
|
||||
udpOffset = 40
|
||||
|
||||
default:
|
||||
return srcIP, dstIP, 0, nil, false
|
||||
}
|
||||
|
||||
srcIP = srcIP.Unmap()
|
||||
dstIP = dstIP.Unmap()
|
||||
dstPort = uint16(packet[udpOffset+2])<<8 | uint16(packet[udpOffset+3])
|
||||
payload = packet[udpOffset+8:]
|
||||
|
||||
return srcIP, dstIP, dstPort, payload, true
|
||||
}
|
||||
|
||||
// attachProxyToForwarder sets or clears the proxy on the userspace forwarder.
|
||||
func (e *Engine) attachProxyToForwarder(p *inspect.Proxy) {
|
||||
type forwarderGetter interface {
|
||||
GetForwarder() *forwarder.Forwarder
|
||||
}
|
||||
|
||||
if fg, ok := e.firewall.(forwarderGetter); ok {
|
||||
if fwd := fg.GetForwarder(); fwd != nil {
|
||||
fwd.SetProxy(p)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// toProxyConfig converts a proto TransparentProxyConfig to the inspect.Config type.
|
||||
func toProxyConfig(cfg *mgmProto.TransparentProxyConfig) (inspect.Config, error) {
|
||||
config := inspect.Config{
|
||||
Enabled: cfg.Enabled,
|
||||
DefaultAction: toProxyAction(cfg.DefaultAction),
|
||||
}
|
||||
|
||||
switch cfg.Mode {
|
||||
case mgmProto.TransparentProxyMode_TP_MODE_ENVOY:
|
||||
config.Mode = inspect.ModeEnvoy
|
||||
case mgmProto.TransparentProxyMode_TP_MODE_EXTERNAL:
|
||||
config.Mode = inspect.ModeExternal
|
||||
default:
|
||||
config.Mode = inspect.ModeBuiltin
|
||||
}
|
||||
|
||||
if cfg.ExternalProxyUrl != "" {
|
||||
u, err := url.Parse(cfg.ExternalProxyUrl)
|
||||
if err != nil {
|
||||
return inspect.Config{}, fmt.Errorf("parse external proxy URL: %w", err)
|
||||
}
|
||||
config.ExternalURL = u
|
||||
}
|
||||
|
||||
for _, s := range cfg.RedirectSources {
|
||||
prefix, err := netip.ParsePrefix(s)
|
||||
if err != nil {
|
||||
return inspect.Config{}, fmt.Errorf("parse redirect source %q: %w", s, err)
|
||||
}
|
||||
config.RedirectSources = append(config.RedirectSources, prefix)
|
||||
}
|
||||
|
||||
for _, p := range cfg.RedirectPorts {
|
||||
config.RedirectPorts = append(config.RedirectPorts, uint16(p))
|
||||
}
|
||||
|
||||
// TPROXY listen port: fixed default, overridable via env var.
|
||||
if config.Mode == inspect.ModeBuiltin {
|
||||
port := uint16(inspect.DefaultTProxyPort)
|
||||
if v := os.Getenv("NB_TPROXY_PORT"); v != "" {
|
||||
if p, err := strconv.ParseUint(v, 10, 16); err == nil {
|
||||
port = uint16(p)
|
||||
} else {
|
||||
log.Warnf("invalid NB_TPROXY_PORT %q, using default %d", v, inspect.DefaultTProxyPort)
|
||||
}
|
||||
}
|
||||
config.ListenAddr = netip.AddrPortFrom(netip.IPv4Unspecified(), port)
|
||||
}
|
||||
|
||||
for _, r := range cfg.Rules {
|
||||
rule, err := toProxyRule(r)
|
||||
if err != nil {
|
||||
return inspect.Config{}, fmt.Errorf("parse rule %q: %w", r.Id, err)
|
||||
}
|
||||
config.Rules = append(config.Rules, rule)
|
||||
}
|
||||
|
||||
if cfg.Icap != nil {
|
||||
icapCfg, err := toICAPConfig(cfg.Icap)
|
||||
if err != nil {
|
||||
return inspect.Config{}, fmt.Errorf("parse ICAP config: %w", err)
|
||||
}
|
||||
config.ICAP = icapCfg
|
||||
}
|
||||
|
||||
if len(cfg.CaCertPem) > 0 && len(cfg.CaKeyPem) > 0 {
|
||||
tlsCfg, err := parseTLSConfig(cfg.CaCertPem, cfg.CaKeyPem)
|
||||
if err != nil {
|
||||
return inspect.Config{}, fmt.Errorf("parse TLS config: %w", err)
|
||||
}
|
||||
config.TLS = tlsCfg
|
||||
}
|
||||
|
||||
if config.Mode == inspect.ModeEnvoy {
|
||||
envCfg := &inspect.EnvoyConfig{
|
||||
BinaryPath: cfg.EnvoyBinaryPath,
|
||||
AdminPort: uint16(cfg.EnvoyAdminPort),
|
||||
}
|
||||
if cfg.EnvoySnippets != nil {
|
||||
envCfg.Snippets = &inspect.EnvoySnippets{
|
||||
HTTPFilters: cfg.EnvoySnippets.HttpFilters,
|
||||
NetworkFilters: cfg.EnvoySnippets.NetworkFilters,
|
||||
Clusters: cfg.EnvoySnippets.Clusters,
|
||||
}
|
||||
}
|
||||
config.Envoy = envCfg
|
||||
}
|
||||
|
||||
return config, nil
|
||||
}
|
||||
|
||||
func toProxyRule(r *mgmProto.TransparentProxyRule) (inspect.Rule, error) {
|
||||
rule := inspect.Rule{
|
||||
ID: id.RuleID(r.Id),
|
||||
Action: toProxyAction(r.Action),
|
||||
Priority: int(r.Priority),
|
||||
}
|
||||
|
||||
for _, d := range r.Domains {
|
||||
dom, err := domain.FromString(d)
|
||||
if err != nil {
|
||||
return inspect.Rule{}, fmt.Errorf("parse domain %q: %w", d, err)
|
||||
}
|
||||
rule.Domains = append(rule.Domains, dom)
|
||||
}
|
||||
|
||||
for _, n := range r.Networks {
|
||||
prefix, err := netip.ParsePrefix(n)
|
||||
if err != nil {
|
||||
return inspect.Rule{}, fmt.Errorf("parse network %q: %w", n, err)
|
||||
}
|
||||
rule.Networks = append(rule.Networks, prefix)
|
||||
}
|
||||
|
||||
for _, p := range r.Ports {
|
||||
rule.Ports = append(rule.Ports, uint16(p))
|
||||
}
|
||||
|
||||
for _, proto := range r.Protocols {
|
||||
rule.Protocols = append(rule.Protocols, toProxyProtoType(proto))
|
||||
}
|
||||
|
||||
rule.Paths = r.Paths
|
||||
|
||||
return rule, nil
|
||||
}
|
||||
|
||||
func toProxyProtoType(p mgmProto.TransparentProxyProtocol) inspect.ProtoType {
|
||||
switch p {
|
||||
case mgmProto.TransparentProxyProtocol_TP_PROTO_HTTP:
|
||||
return inspect.ProtoHTTP
|
||||
case mgmProto.TransparentProxyProtocol_TP_PROTO_HTTPS:
|
||||
return inspect.ProtoHTTPS
|
||||
case mgmProto.TransparentProxyProtocol_TP_PROTO_H2:
|
||||
return inspect.ProtoH2
|
||||
case mgmProto.TransparentProxyProtocol_TP_PROTO_H3:
|
||||
return inspect.ProtoH3
|
||||
case mgmProto.TransparentProxyProtocol_TP_PROTO_WEBSOCKET:
|
||||
return inspect.ProtoWebSocket
|
||||
case mgmProto.TransparentProxyProtocol_TP_PROTO_OTHER:
|
||||
return inspect.ProtoOther
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
func toProxyAction(a mgmProto.TransparentProxyAction) inspect.Action {
|
||||
switch a {
|
||||
case mgmProto.TransparentProxyAction_TP_ACTION_BLOCK:
|
||||
return inspect.ActionBlock
|
||||
case mgmProto.TransparentProxyAction_TP_ACTION_INSPECT:
|
||||
return inspect.ActionInspect
|
||||
default:
|
||||
return inspect.ActionAllow
|
||||
}
|
||||
}
|
||||
|
||||
func toICAPConfig(cfg *mgmProto.TransparentProxyICAPConfig) (*inspect.ICAPConfig, error) {
|
||||
icap := &inspect.ICAPConfig{
|
||||
MaxConnections: int(cfg.MaxConnections),
|
||||
}
|
||||
|
||||
if cfg.ReqmodUrl != "" {
|
||||
u, err := url.Parse(cfg.ReqmodUrl)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parse ICAP reqmod URL: %w", err)
|
||||
}
|
||||
icap.ReqModURL = u
|
||||
}
|
||||
|
||||
if cfg.RespmodUrl != "" {
|
||||
u, err := url.Parse(cfg.RespmodUrl)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parse ICAP respmod URL: %w", err)
|
||||
}
|
||||
icap.RespModURL = u
|
||||
}
|
||||
|
||||
return icap, nil
|
||||
}
|
||||
|
||||
func parseTLSConfig(certPEM, keyPEM []byte) (*inspect.TLSConfig, error) {
|
||||
block, _ := pem.Decode(certPEM)
|
||||
if block == nil {
|
||||
return nil, fmt.Errorf("decode CA certificate PEM")
|
||||
}
|
||||
|
||||
cert, err := x509.ParseCertificate(block.Bytes)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parse CA certificate: %w", err)
|
||||
}
|
||||
|
||||
keyBlock, _ := pem.Decode(keyPEM)
|
||||
if keyBlock == nil {
|
||||
return nil, fmt.Errorf("decode CA key PEM")
|
||||
}
|
||||
|
||||
key, err := x509.ParseECPrivateKey(keyBlock.Bytes)
|
||||
if err != nil {
|
||||
// Try PKCS8 as fallback
|
||||
pkcs8Key, pkcs8Err := x509.ParsePKCS8PrivateKey(keyBlock.Bytes)
|
||||
if pkcs8Err != nil {
|
||||
return nil, fmt.Errorf("parse CA private key (tried EC and PKCS8): %w", err)
|
||||
}
|
||||
return &inspect.TLSConfig{CA: cert, CAKey: pkcs8Key}, nil
|
||||
}
|
||||
|
||||
return &inspect.TLSConfig{CA: cert, CAKey: key}, nil
|
||||
}
|
||||
|
||||
// resolveInspectionCA sets the TLS config on the proxy config using priority:
|
||||
// 1. Local config file CA (InspectionCACertPath/InspectionCAKeyPath)
|
||||
// 2. Management-pushed CA (already parsed in toProxyConfig)
|
||||
// 3. Auto-generated self-signed CA (ephemeral, for testing)
|
||||
// Local always wins to prevent a compromised management server from injecting a CA.
|
||||
func (e *Engine) resolveInspectionCA(config *inspect.Config) {
|
||||
// 1. Local CA from config file or env vars
|
||||
certPath := e.config.InspectionCACertPath
|
||||
keyPath := e.config.InspectionCAKeyPath
|
||||
if certPath == "" {
|
||||
certPath = os.Getenv("NB_INSPECTION_CA_CERT")
|
||||
}
|
||||
if keyPath == "" {
|
||||
keyPath = os.Getenv("NB_INSPECTION_CA_KEY")
|
||||
}
|
||||
if certPath != "" && keyPath != "" {
|
||||
certPEM, err := os.ReadFile(certPath)
|
||||
if err != nil {
|
||||
log.Errorf("read local inspection CA cert %s: %v", certPath, err)
|
||||
return
|
||||
}
|
||||
keyPEM, err := os.ReadFile(keyPath)
|
||||
if err != nil {
|
||||
log.Errorf("read local inspection CA key %s: %v", keyPath, err)
|
||||
return
|
||||
}
|
||||
tlsCfg, err := parseTLSConfig(certPEM, keyPEM)
|
||||
if err != nil {
|
||||
log.Errorf("parse local inspection CA: %v", err)
|
||||
return
|
||||
}
|
||||
log.Infof("inspect: using local CA from %s", certPath)
|
||||
config.TLS = tlsCfg
|
||||
return
|
||||
}
|
||||
|
||||
// 2. Management-pushed CA (already set by toProxyConfig)
|
||||
if config.TLS != nil {
|
||||
log.Infof("inspect: using management-pushed CA")
|
||||
return
|
||||
}
|
||||
|
||||
// 3. Auto-generate self-signed CA for testing / accept-cert UX
|
||||
tlsCfg, err := generateSelfSignedCA()
|
||||
if err != nil {
|
||||
log.Errorf("generate self-signed inspection CA: %v", err)
|
||||
return
|
||||
}
|
||||
log.Infof("inspect: using auto-generated self-signed CA (clients will see certificate warnings)")
|
||||
config.TLS = tlsCfg
|
||||
}
|
||||
|
||||
// generateSelfSignedCA creates an ephemeral ECDSA P-256 CA certificate.
|
||||
// Clients will see certificate warnings but can choose to accept.
|
||||
func generateSelfSignedCA() (*inspect.TLSConfig, error) {
|
||||
key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("generate CA key: %w", err)
|
||||
}
|
||||
|
||||
serial, err := rand.Int(rand.Reader, new(big.Int).Lsh(big.NewInt(1), 128))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("generate serial: %w", err)
|
||||
}
|
||||
|
||||
template := &x509.Certificate{
|
||||
SerialNumber: serial,
|
||||
Subject: pkix.Name{
|
||||
Organization: []string{"NetBird Transparent Proxy"},
|
||||
CommonName: "NetBird Inspection CA (auto-generated)",
|
||||
},
|
||||
NotBefore: time.Now().Add(-1 * time.Hour),
|
||||
NotAfter: time.Now().Add(365 * 24 * time.Hour),
|
||||
KeyUsage: x509.KeyUsageCertSign | x509.KeyUsageCRLSign,
|
||||
BasicConstraintsValid: true,
|
||||
IsCA: true,
|
||||
MaxPathLen: 0,
|
||||
}
|
||||
|
||||
certDER, err := x509.CreateCertificate(rand.Reader, template, template, &key.PublicKey, key)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create CA certificate: %w", err)
|
||||
}
|
||||
|
||||
cert, err := x509.ParseCertificate(certDER)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parse generated CA certificate: %w", err)
|
||||
}
|
||||
|
||||
return &inspect.TLSConfig{CA: cert, CAKey: key}, nil
|
||||
}
|
||||
@@ -1,279 +0,0 @@
|
||||
package internal
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/client/inspect"
|
||||
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
|
||||
)
|
||||
|
||||
func TestToProxyConfig_Basic(t *testing.T) {
|
||||
cfg := &mgmProto.TransparentProxyConfig{
|
||||
Enabled: true,
|
||||
Mode: mgmProto.TransparentProxyMode_TP_MODE_BUILTIN,
|
||||
DefaultAction: mgmProto.TransparentProxyAction_TP_ACTION_ALLOW,
|
||||
RedirectSources: []string{
|
||||
"10.0.0.0/24",
|
||||
"192.168.1.0/24",
|
||||
},
|
||||
RedirectPorts: []uint32{80, 443},
|
||||
Rules: []*mgmProto.TransparentProxyRule{
|
||||
{
|
||||
Id: "block-evil",
|
||||
Domains: []string{"*.evil.com", "malware.example.com"},
|
||||
Action: mgmProto.TransparentProxyAction_TP_ACTION_BLOCK,
|
||||
Priority: 1,
|
||||
},
|
||||
{
|
||||
Id: "inspect-internal",
|
||||
Domains: []string{"*.internal.corp"},
|
||||
Networks: []string{"10.1.0.0/16"},
|
||||
Ports: []uint32{443, 8443},
|
||||
Action: mgmProto.TransparentProxyAction_TP_ACTION_INSPECT,
|
||||
Priority: 10,
|
||||
},
|
||||
},
|
||||
ListenPort: 8443,
|
||||
}
|
||||
|
||||
config, err := toProxyConfig(cfg)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.True(t, config.Enabled)
|
||||
assert.Equal(t, inspect.ModeBuiltin, config.Mode)
|
||||
assert.Equal(t, inspect.ActionAllow, config.DefaultAction)
|
||||
|
||||
require.Len(t, config.RedirectSources, 2)
|
||||
assert.Equal(t, "10.0.0.0/24", config.RedirectSources[0].String())
|
||||
assert.Equal(t, "192.168.1.0/24", config.RedirectSources[1].String())
|
||||
|
||||
require.Len(t, config.RedirectPorts, 2)
|
||||
assert.Equal(t, uint16(80), config.RedirectPorts[0])
|
||||
assert.Equal(t, uint16(443), config.RedirectPorts[1])
|
||||
|
||||
require.Len(t, config.Rules, 2)
|
||||
|
||||
// Rule 1: block evil domains
|
||||
assert.Equal(t, "block-evil", string(config.Rules[0].ID))
|
||||
assert.Equal(t, inspect.ActionBlock, config.Rules[0].Action)
|
||||
assert.Equal(t, 1, config.Rules[0].Priority)
|
||||
require.Len(t, config.Rules[0].Domains, 2)
|
||||
assert.Equal(t, "*.evil.com", config.Rules[0].Domains[0].PunycodeString())
|
||||
assert.Equal(t, "malware.example.com", config.Rules[0].Domains[1].PunycodeString())
|
||||
|
||||
// Rule 2: inspect internal
|
||||
assert.Equal(t, "inspect-internal", string(config.Rules[1].ID))
|
||||
assert.Equal(t, inspect.ActionInspect, config.Rules[1].Action)
|
||||
assert.Equal(t, 10, config.Rules[1].Priority)
|
||||
require.Len(t, config.Rules[1].Networks, 1)
|
||||
assert.Equal(t, "10.1.0.0/16", config.Rules[1].Networks[0].String())
|
||||
require.Len(t, config.Rules[1].Ports, 2)
|
||||
|
||||
// Listen address
|
||||
assert.True(t, config.ListenAddr.IsValid())
|
||||
assert.Equal(t, uint16(8443), config.ListenAddr.Port())
|
||||
}
|
||||
|
||||
func TestToProxyConfig_ExternalMode(t *testing.T) {
|
||||
cfg := &mgmProto.TransparentProxyConfig{
|
||||
Enabled: true,
|
||||
Mode: mgmProto.TransparentProxyMode_TP_MODE_EXTERNAL,
|
||||
ExternalProxyUrl: "http://proxy.corp:8080",
|
||||
DefaultAction: mgmProto.TransparentProxyAction_TP_ACTION_BLOCK,
|
||||
}
|
||||
|
||||
config, err := toProxyConfig(cfg)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, inspect.ModeExternal, config.Mode)
|
||||
assert.Equal(t, inspect.ActionBlock, config.DefaultAction)
|
||||
require.NotNil(t, config.ExternalURL)
|
||||
assert.Equal(t, "http", config.ExternalURL.Scheme)
|
||||
assert.Equal(t, "proxy.corp:8080", config.ExternalURL.Host)
|
||||
}
|
||||
|
||||
func TestToProxyConfig_ICAP(t *testing.T) {
|
||||
cfg := &mgmProto.TransparentProxyConfig{
|
||||
Enabled: true,
|
||||
Icap: &mgmProto.TransparentProxyICAPConfig{
|
||||
ReqmodUrl: "icap://icap-server:1344/reqmod",
|
||||
RespmodUrl: "icap://icap-server:1344/respmod",
|
||||
MaxConnections: 16,
|
||||
},
|
||||
}
|
||||
|
||||
config, err := toProxyConfig(cfg)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.NotNil(t, config.ICAP)
|
||||
assert.Equal(t, "icap", config.ICAP.ReqModURL.Scheme)
|
||||
assert.Equal(t, "icap-server:1344", config.ICAP.ReqModURL.Host)
|
||||
assert.Equal(t, "/reqmod", config.ICAP.ReqModURL.Path)
|
||||
assert.Equal(t, "/respmod", config.ICAP.RespModURL.Path)
|
||||
assert.Equal(t, 16, config.ICAP.MaxConnections)
|
||||
}
|
||||
|
||||
func TestToProxyConfig_Empty(t *testing.T) {
|
||||
cfg := &mgmProto.TransparentProxyConfig{
|
||||
Enabled: true,
|
||||
}
|
||||
|
||||
config, err := toProxyConfig(cfg)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.True(t, config.Enabled)
|
||||
assert.Equal(t, inspect.ModeBuiltin, config.Mode)
|
||||
assert.Equal(t, inspect.ActionAllow, config.DefaultAction)
|
||||
assert.Empty(t, config.RedirectSources)
|
||||
assert.Empty(t, config.RedirectPorts)
|
||||
assert.Empty(t, config.Rules)
|
||||
assert.Nil(t, config.ICAP)
|
||||
assert.Nil(t, config.TLS)
|
||||
assert.False(t, config.ListenAddr.IsValid())
|
||||
}
|
||||
|
||||
func TestToProxyConfig_InvalidSource(t *testing.T) {
|
||||
cfg := &mgmProto.TransparentProxyConfig{
|
||||
Enabled: true,
|
||||
RedirectSources: []string{"not-a-cidr"},
|
||||
}
|
||||
|
||||
_, err := toProxyConfig(cfg)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "parse redirect source")
|
||||
}
|
||||
|
||||
func TestToProxyConfig_InvalidNetwork(t *testing.T) {
|
||||
cfg := &mgmProto.TransparentProxyConfig{
|
||||
Enabled: true,
|
||||
Rules: []*mgmProto.TransparentProxyRule{
|
||||
{
|
||||
Id: "bad",
|
||||
Networks: []string{"not-a-cidr"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
_, err := toProxyConfig(cfg)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "parse network")
|
||||
}
|
||||
|
||||
func TestToProxyAction(t *testing.T) {
|
||||
assert.Equal(t, inspect.ActionAllow, toProxyAction(mgmProto.TransparentProxyAction_TP_ACTION_ALLOW))
|
||||
assert.Equal(t, inspect.ActionBlock, toProxyAction(mgmProto.TransparentProxyAction_TP_ACTION_BLOCK))
|
||||
assert.Equal(t, inspect.ActionInspect, toProxyAction(mgmProto.TransparentProxyAction_TP_ACTION_INSPECT))
|
||||
// Unknown defaults to allow
|
||||
assert.Equal(t, inspect.ActionAllow, toProxyAction(99))
|
||||
}
|
||||
|
||||
func TestParseUDPPacket_IPv4(t *testing.T) {
|
||||
// Build a minimal IPv4/UDP packet: 20-byte IPv4 header + 8-byte UDP header + payload
|
||||
packet := make([]byte, 20+8+4)
|
||||
|
||||
// IPv4 header: version=4, IHL=5 (20 bytes)
|
||||
packet[0] = 0x45
|
||||
// Protocol = UDP (17)
|
||||
packet[9] = 17
|
||||
// Source IP: 10.0.0.1
|
||||
packet[12], packet[13], packet[14], packet[15] = 10, 0, 0, 1
|
||||
// Dest IP: 192.168.1.1
|
||||
packet[16], packet[17], packet[18], packet[19] = 192, 168, 1, 1
|
||||
// UDP source port: 54321 (0xD431)
|
||||
packet[20] = 0xD4
|
||||
packet[21] = 0x31
|
||||
// UDP dest port: 443 (0x01BB)
|
||||
packet[22] = 0x01
|
||||
packet[23] = 0xBB
|
||||
// Payload
|
||||
packet[28] = 0xDE
|
||||
packet[29] = 0xAD
|
||||
packet[30] = 0xBE
|
||||
packet[31] = 0xEF
|
||||
|
||||
srcIP, dstIP, dstPort, payload, ok := parseUDPPacket(packet)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, "10.0.0.1", srcIP.String())
|
||||
assert.Equal(t, "192.168.1.1", dstIP.String())
|
||||
assert.Equal(t, uint16(443), dstPort)
|
||||
assert.Equal(t, []byte{0xDE, 0xAD, 0xBE, 0xEF}, payload)
|
||||
}
|
||||
|
||||
func TestParseUDPPacket_IPv6(t *testing.T) {
|
||||
// Build a minimal IPv6/UDP packet: 40-byte IPv6 header + 8-byte UDP header + payload
|
||||
packet := make([]byte, 40+8+4)
|
||||
|
||||
// Version = 6 (0x60 in high nibble)
|
||||
packet[0] = 0x60
|
||||
// Payload length: 8 (UDP header) + 4 (payload)
|
||||
packet[4] = 0
|
||||
packet[5] = 12
|
||||
// Next header: UDP (17)
|
||||
packet[6] = 17
|
||||
// Source: 2001:db8::1
|
||||
packet[8] = 0x20
|
||||
packet[9] = 0x01
|
||||
packet[10] = 0x0d
|
||||
packet[11] = 0xb8
|
||||
packet[23] = 0x01
|
||||
// Dest: 2001:db8::2
|
||||
packet[24] = 0x20
|
||||
packet[25] = 0x01
|
||||
packet[26] = 0x0d
|
||||
packet[27] = 0xb8
|
||||
packet[39] = 0x02
|
||||
// UDP source port: 54321 (0xD431)
|
||||
packet[40] = 0xD4
|
||||
packet[41] = 0x31
|
||||
// UDP dest port: 443 (0x01BB)
|
||||
packet[42] = 0x01
|
||||
packet[43] = 0xBB
|
||||
// Payload
|
||||
packet[48] = 0xCA
|
||||
packet[49] = 0xFE
|
||||
packet[50] = 0xBA
|
||||
packet[51] = 0xBE
|
||||
|
||||
srcIP, dstIP, dstPort, payload, ok := parseUDPPacket(packet)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, "2001:db8::1", srcIP.String())
|
||||
assert.Equal(t, "2001:db8::2", dstIP.String())
|
||||
assert.Equal(t, uint16(443), dstPort)
|
||||
assert.Equal(t, []byte{0xCA, 0xFE, 0xBA, 0xBE}, payload)
|
||||
}
|
||||
|
||||
func TestParseUDPPacket_TooShort(t *testing.T) {
|
||||
_, _, _, _, ok := parseUDPPacket(nil)
|
||||
assert.False(t, ok)
|
||||
|
||||
_, _, _, _, ok = parseUDPPacket([]byte{0x45, 0x00})
|
||||
assert.False(t, ok)
|
||||
}
|
||||
|
||||
func TestParseUDPPacket_IPv6ExtensionHeader(t *testing.T) {
|
||||
// IPv6 with next header != UDP should be rejected
|
||||
packet := make([]byte, 48)
|
||||
packet[0] = 0x60
|
||||
packet[6] = 6 // TCP, not UDP
|
||||
_, _, _, _, ok := parseUDPPacket(packet)
|
||||
assert.False(t, ok, "should reject IPv6 packets with non-UDP next header")
|
||||
}
|
||||
|
||||
func TestParseUDPPacket_IPv4MappedIPv6(t *testing.T) {
|
||||
// IPv4 packet with normal addresses should Unmap correctly
|
||||
packet := make([]byte, 28)
|
||||
packet[0] = 0x45
|
||||
packet[9] = 17
|
||||
packet[12], packet[13], packet[14], packet[15] = 127, 0, 0, 1
|
||||
packet[16], packet[17], packet[18], packet[19] = 10, 0, 0, 1
|
||||
packet[22] = 0x01
|
||||
packet[23] = 0xBB
|
||||
|
||||
srcIP, dstIP, _, _, ok := parseUDPPacket(packet)
|
||||
require.True(t, ok)
|
||||
assert.True(t, srcIP.Is4(), "should be plain IPv4, not mapped")
|
||||
assert.True(t, dstIP.Is4(), "should be plain IPv4, not mapped")
|
||||
}
|
||||
247
client/internal/engine_vnc.go
Normal file
247
client/internal/engine_vnc.go
Normal file
@@ -0,0 +1,247 @@
|
||||
package internal
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
firewallManager "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
||||
sshauth "github.com/netbirdio/netbird/client/ssh/auth"
|
||||
vncserver "github.com/netbirdio/netbird/client/vnc/server"
|
||||
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
|
||||
sshuserhash "github.com/netbirdio/netbird/shared/sshauth"
|
||||
)
|
||||
|
||||
const (
|
||||
vncExternalPort uint16 = 5900
|
||||
vncInternalPort uint16 = 25900
|
||||
)
|
||||
|
||||
type vncServer interface {
|
||||
Start(ctx context.Context, addr netip.AddrPort, network netip.Prefix) error
|
||||
Stop() error
|
||||
}
|
||||
|
||||
func (e *Engine) setupVNCPortRedirection() error {
|
||||
if e.firewall == nil || e.wgInterface == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
localAddr := e.wgInterface.Address().IP
|
||||
if !localAddr.IsValid() {
|
||||
return errors.New("invalid local NetBird address")
|
||||
}
|
||||
|
||||
if err := e.firewall.AddInboundDNAT(localAddr, firewallManager.ProtocolTCP, vncExternalPort, vncInternalPort); err != nil {
|
||||
return fmt.Errorf("add VNC port redirection: %w", err)
|
||||
}
|
||||
log.Infof("VNC port redirection: %s:%d -> %s:%d", localAddr, vncExternalPort, localAddr, vncInternalPort)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (e *Engine) cleanupVNCPortRedirection() error {
|
||||
if e.firewall == nil || e.wgInterface == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
localAddr := e.wgInterface.Address().IP
|
||||
if !localAddr.IsValid() {
|
||||
return errors.New("invalid local NetBird address")
|
||||
}
|
||||
|
||||
if err := e.firewall.RemoveInboundDNAT(localAddr, firewallManager.ProtocolTCP, vncExternalPort, vncInternalPort); err != nil {
|
||||
return fmt.Errorf("remove VNC port redirection: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// updateVNC handles starting/stopping the VNC server based on the config flag.
|
||||
// sshConf provides the JWT identity provider config (shared with SSH).
|
||||
func (e *Engine) updateVNC(sshConf *mgmProto.SSHConfig) error {
|
||||
if !e.config.ServerVNCAllowed {
|
||||
if e.vncSrv != nil {
|
||||
log.Info("VNC server disabled, stopping")
|
||||
}
|
||||
return e.stopVNCServer()
|
||||
}
|
||||
|
||||
if e.config.BlockInbound {
|
||||
log.Info("VNC server disabled because inbound connections are blocked")
|
||||
return e.stopVNCServer()
|
||||
}
|
||||
|
||||
if e.vncSrv != nil {
|
||||
// Update JWT config on existing server in case management sent new config.
|
||||
e.updateVNCServerJWT(sshConf)
|
||||
return nil
|
||||
}
|
||||
|
||||
return e.startVNCServer(sshConf)
|
||||
}
|
||||
|
||||
func (e *Engine) startVNCServer(sshConf *mgmProto.SSHConfig) error {
|
||||
if e.wgInterface == nil {
|
||||
return errors.New("wg interface not initialized")
|
||||
}
|
||||
|
||||
capturer, injector := newPlatformVNC()
|
||||
if capturer == nil || injector == nil {
|
||||
log.Debug("VNC server not supported on this platform")
|
||||
return nil
|
||||
}
|
||||
|
||||
netbirdIP := e.wgInterface.Address().IP
|
||||
|
||||
srv := vncserver.New(capturer, injector, "")
|
||||
if vncNeedsServiceMode() {
|
||||
log.Info("VNC: running in Session 0, enabling service mode (agent proxy)")
|
||||
srv.SetServiceMode(true)
|
||||
}
|
||||
|
||||
// Configure VNC authentication.
|
||||
if e.config.DisableVNCAuth != nil && *e.config.DisableVNCAuth {
|
||||
log.Info("VNC: authentication disabled by config")
|
||||
srv.SetDisableAuth(true)
|
||||
} else if protoJWT := sshConf.GetJwtConfig(); protoJWT != nil {
|
||||
audiences := protoJWT.GetAudiences()
|
||||
if len(audiences) == 0 && protoJWT.GetAudience() != "" {
|
||||
audiences = []string{protoJWT.GetAudience()}
|
||||
}
|
||||
srv.SetJWTConfig(&vncserver.JWTConfig{
|
||||
Issuer: protoJWT.GetIssuer(),
|
||||
Audiences: audiences,
|
||||
KeysLocation: protoJWT.GetKeysLocation(),
|
||||
MaxTokenAge: protoJWT.GetMaxTokenAge(),
|
||||
})
|
||||
log.Debugf("VNC: JWT authentication configured (issuer=%s)", protoJWT.GetIssuer())
|
||||
}
|
||||
|
||||
if netstackNet := e.wgInterface.GetNet(); netstackNet != nil {
|
||||
srv.SetNetstackNet(netstackNet)
|
||||
}
|
||||
|
||||
listenAddr := netip.AddrPortFrom(netbirdIP, vncInternalPort)
|
||||
network := e.wgInterface.Address().Network
|
||||
if err := srv.Start(e.ctx, listenAddr, network); err != nil {
|
||||
return fmt.Errorf("start VNC server: %w", err)
|
||||
}
|
||||
|
||||
e.vncSrv = srv
|
||||
|
||||
if registrar, ok := e.firewall.(interface {
|
||||
RegisterNetstackService(protocol nftypes.Protocol, port uint16)
|
||||
}); ok {
|
||||
registrar.RegisterNetstackService(nftypes.TCP, vncInternalPort)
|
||||
log.Debugf("registered VNC service for TCP:%d", vncInternalPort)
|
||||
}
|
||||
|
||||
if err := e.setupVNCPortRedirection(); err != nil {
|
||||
log.Warnf("setup VNC port redirection: %v", err)
|
||||
}
|
||||
|
||||
log.Info("VNC server enabled")
|
||||
return nil
|
||||
}
|
||||
|
||||
// updateVNCServerJWT configures the JWT validation for the VNC server using
|
||||
// the same JWT config as SSH (same identity provider).
|
||||
func (e *Engine) updateVNCServerJWT(sshConf *mgmProto.SSHConfig) {
|
||||
if e.vncSrv == nil {
|
||||
return
|
||||
}
|
||||
|
||||
vncSrv, ok := e.vncSrv.(*vncserver.Server)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
if e.config.DisableVNCAuth != nil && *e.config.DisableVNCAuth {
|
||||
vncSrv.SetDisableAuth(true)
|
||||
return
|
||||
}
|
||||
|
||||
protoJWT := sshConf.GetJwtConfig()
|
||||
if protoJWT == nil {
|
||||
return
|
||||
}
|
||||
|
||||
audiences := protoJWT.GetAudiences()
|
||||
if len(audiences) == 0 && protoJWT.GetAudience() != "" {
|
||||
audiences = []string{protoJWT.GetAudience()}
|
||||
}
|
||||
|
||||
vncSrv.SetJWTConfig(&vncserver.JWTConfig{
|
||||
Issuer: protoJWT.GetIssuer(),
|
||||
Audiences: audiences,
|
||||
KeysLocation: protoJWT.GetKeysLocation(),
|
||||
MaxTokenAge: protoJWT.GetMaxTokenAge(),
|
||||
})
|
||||
}
|
||||
|
||||
// updateVNCServerAuth updates VNC fine-grained access control from management.
|
||||
func (e *Engine) updateVNCServerAuth(vncAuth *mgmProto.VNCAuth) {
|
||||
if vncAuth == nil || e.vncSrv == nil {
|
||||
return
|
||||
}
|
||||
|
||||
vncSrv, ok := e.vncSrv.(*vncserver.Server)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
protoUsers := vncAuth.GetAuthorizedUsers()
|
||||
authorizedUsers := make([]sshuserhash.UserIDHash, len(protoUsers))
|
||||
for i, hash := range protoUsers {
|
||||
if len(hash) != 16 {
|
||||
log.Warnf("invalid VNC auth hash length %d, expected 16", len(hash))
|
||||
return
|
||||
}
|
||||
authorizedUsers[i] = sshuserhash.UserIDHash(hash)
|
||||
}
|
||||
|
||||
machineUsers := make(map[string][]uint32)
|
||||
for osUser, indexes := range vncAuth.GetMachineUsers() {
|
||||
machineUsers[osUser] = indexes.GetIndexes()
|
||||
}
|
||||
|
||||
vncSrv.UpdateVNCAuth(&sshauth.Config{
|
||||
UserIDClaim: vncAuth.GetUserIDClaim(),
|
||||
AuthorizedUsers: authorizedUsers,
|
||||
MachineUsers: machineUsers,
|
||||
})
|
||||
}
|
||||
|
||||
// GetVNCServerStatus returns whether the VNC server is running.
|
||||
func (e *Engine) GetVNCServerStatus() bool {
|
||||
return e.vncSrv != nil
|
||||
}
|
||||
|
||||
func (e *Engine) stopVNCServer() error {
|
||||
if e.vncSrv == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := e.cleanupVNCPortRedirection(); err != nil {
|
||||
log.Warnf("cleanup VNC port redirection: %v", err)
|
||||
}
|
||||
|
||||
if registrar, ok := e.firewall.(interface {
|
||||
UnregisterNetstackService(protocol nftypes.Protocol, port uint16)
|
||||
}); ok {
|
||||
registrar.UnregisterNetstackService(nftypes.TCP, vncInternalPort)
|
||||
}
|
||||
|
||||
log.Info("stopping VNC server")
|
||||
err := e.vncSrv.Stop()
|
||||
e.vncSrv = nil
|
||||
if err != nil {
|
||||
return fmt.Errorf("stop VNC server: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
23
client/internal/engine_vnc_darwin.go
Normal file
23
client/internal/engine_vnc_darwin.go
Normal file
@@ -0,0 +1,23 @@
|
||||
//go:build darwin && !ios
|
||||
|
||||
package internal
|
||||
|
||||
import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
vncserver "github.com/netbirdio/netbird/client/vnc/server"
|
||||
)
|
||||
|
||||
func newPlatformVNC() (vncserver.ScreenCapturer, vncserver.InputInjector) {
|
||||
capturer := vncserver.NewMacPoller()
|
||||
injector, err := vncserver.NewMacInputInjector()
|
||||
if err != nil {
|
||||
log.Debugf("VNC: macOS input injector: %v", err)
|
||||
return capturer, &vncserver.StubInputInjector{}
|
||||
}
|
||||
return capturer, injector
|
||||
}
|
||||
|
||||
func vncNeedsServiceMode() bool {
|
||||
return false
|
||||
}
|
||||
13
client/internal/engine_vnc_stub.go
Normal file
13
client/internal/engine_vnc_stub.go
Normal file
@@ -0,0 +1,13 @@
|
||||
//go:build !windows && !darwin && !freebsd && !(linux && !android)
|
||||
|
||||
package internal
|
||||
|
||||
import vncserver "github.com/netbirdio/netbird/client/vnc/server"
|
||||
|
||||
func newPlatformVNC() (vncserver.ScreenCapturer, vncserver.InputInjector) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func vncNeedsServiceMode() bool {
|
||||
return false
|
||||
}
|
||||
13
client/internal/engine_vnc_windows.go
Normal file
13
client/internal/engine_vnc_windows.go
Normal file
@@ -0,0 +1,13 @@
|
||||
//go:build windows
|
||||
|
||||
package internal
|
||||
|
||||
import vncserver "github.com/netbirdio/netbird/client/vnc/server"
|
||||
|
||||
func newPlatformVNC() (vncserver.ScreenCapturer, vncserver.InputInjector) {
|
||||
return vncserver.NewDesktopCapturer(), vncserver.NewWindowsInputInjector()
|
||||
}
|
||||
|
||||
func vncNeedsServiceMode() bool {
|
||||
return vncserver.GetCurrentSessionID() == 0
|
||||
}
|
||||
23
client/internal/engine_vnc_x11.go
Normal file
23
client/internal/engine_vnc_x11.go
Normal file
@@ -0,0 +1,23 @@
|
||||
//go:build (linux && !android) || freebsd
|
||||
|
||||
package internal
|
||||
|
||||
import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
vncserver "github.com/netbirdio/netbird/client/vnc/server"
|
||||
)
|
||||
|
||||
func newPlatformVNC() (vncserver.ScreenCapturer, vncserver.InputInjector) {
|
||||
capturer := vncserver.NewX11Poller("")
|
||||
injector, err := vncserver.NewX11InputInjector("")
|
||||
if err != nil {
|
||||
log.Debugf("VNC: X11 input injector: %v", err)
|
||||
return capturer, &vncserver.StubInputInjector{}
|
||||
}
|
||||
return capturer, injector
|
||||
}
|
||||
|
||||
func vncNeedsServiceMode() bool {
|
||||
return false
|
||||
}
|
||||
@@ -64,11 +64,13 @@ type ConfigInput struct {
|
||||
StateFilePath string
|
||||
PreSharedKey *string
|
||||
ServerSSHAllowed *bool
|
||||
ServerVNCAllowed *bool
|
||||
EnableSSHRoot *bool
|
||||
EnableSSHSFTP *bool
|
||||
EnableSSHLocalPortForwarding *bool
|
||||
EnableSSHRemotePortForwarding *bool
|
||||
DisableSSHAuth *bool
|
||||
DisableVNCAuth *bool
|
||||
SSHJWTCacheTTL *int
|
||||
NATExternalIPs []string
|
||||
CustomDNSAddress []byte
|
||||
@@ -97,9 +99,6 @@ type ConfigInput struct {
|
||||
LazyConnectionEnabled *bool
|
||||
|
||||
MTU *uint16
|
||||
|
||||
InspectionCACertPath string
|
||||
InspectionCAKeyPath string
|
||||
}
|
||||
|
||||
// Config Configuration type
|
||||
@@ -117,11 +116,13 @@ type Config struct {
|
||||
RosenpassEnabled bool
|
||||
RosenpassPermissive bool
|
||||
ServerSSHAllowed *bool
|
||||
ServerVNCAllowed *bool
|
||||
EnableSSHRoot *bool
|
||||
EnableSSHSFTP *bool
|
||||
EnableSSHLocalPortForwarding *bool
|
||||
EnableSSHRemotePortForwarding *bool
|
||||
DisableSSHAuth *bool
|
||||
DisableVNCAuth *bool
|
||||
SSHJWTCacheTTL *int
|
||||
|
||||
DisableClientRoutes bool
|
||||
@@ -174,13 +175,6 @@ type Config struct {
|
||||
LazyConnectionEnabled bool
|
||||
|
||||
MTU uint16
|
||||
|
||||
// InspectionCACertPath is the path to a PEM CA certificate for transparent proxy MITM.
|
||||
// Local CA takes priority over management-pushed CA.
|
||||
InspectionCACertPath string
|
||||
|
||||
// InspectionCAKeyPath is the path to the PEM CA private key for transparent proxy MITM.
|
||||
InspectionCAKeyPath string
|
||||
}
|
||||
|
||||
var ConfigDirOverride string
|
||||
@@ -425,6 +419,21 @@ func (config *Config) apply(input ConfigInput) (updated bool, err error) {
|
||||
updated = true
|
||||
}
|
||||
|
||||
if input.ServerVNCAllowed != nil {
|
||||
if config.ServerVNCAllowed == nil || *input.ServerVNCAllowed != *config.ServerVNCAllowed {
|
||||
if *input.ServerVNCAllowed {
|
||||
log.Infof("enabling VNC server")
|
||||
} else {
|
||||
log.Infof("disabling VNC server")
|
||||
}
|
||||
config.ServerVNCAllowed = input.ServerVNCAllowed
|
||||
updated = true
|
||||
}
|
||||
} else if config.ServerVNCAllowed == nil {
|
||||
config.ServerVNCAllowed = util.True()
|
||||
updated = true
|
||||
}
|
||||
|
||||
if input.EnableSSHRoot != nil && input.EnableSSHRoot != config.EnableSSHRoot {
|
||||
if *input.EnableSSHRoot {
|
||||
log.Infof("enabling SSH root login")
|
||||
@@ -475,6 +484,16 @@ func (config *Config) apply(input ConfigInput) (updated bool, err error) {
|
||||
updated = true
|
||||
}
|
||||
|
||||
if input.DisableVNCAuth != nil && input.DisableVNCAuth != config.DisableVNCAuth {
|
||||
if *input.DisableVNCAuth {
|
||||
log.Infof("disabling VNC authentication")
|
||||
} else {
|
||||
log.Infof("enabling VNC authentication")
|
||||
}
|
||||
config.DisableVNCAuth = input.DisableVNCAuth
|
||||
updated = true
|
||||
}
|
||||
|
||||
if input.SSHJWTCacheTTL != nil && input.SSHJWTCacheTTL != config.SSHJWTCacheTTL {
|
||||
log.Infof("updating SSH JWT cache TTL to %d seconds", *input.SSHJWTCacheTTL)
|
||||
config.SSHJWTCacheTTL = input.SSHJWTCacheTTL
|
||||
@@ -613,17 +632,6 @@ func (config *Config) apply(input ConfigInput) (updated bool, err error) {
|
||||
updated = true
|
||||
}
|
||||
|
||||
if input.InspectionCACertPath != "" && input.InspectionCACertPath != config.InspectionCACertPath {
|
||||
log.Infof("updating inspection CA cert path to %s", input.InspectionCACertPath)
|
||||
config.InspectionCACertPath = input.InspectionCACertPath
|
||||
updated = true
|
||||
}
|
||||
if input.InspectionCAKeyPath != "" && input.InspectionCAKeyPath != config.InspectionCAKeyPath {
|
||||
log.Infof("updating inspection CA key path to %s", input.InspectionCAKeyPath)
|
||||
config.InspectionCAKeyPath = input.InspectionCAKeyPath
|
||||
updated = true
|
||||
}
|
||||
|
||||
return updated, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -168,6 +168,7 @@ func (m *DefaultManager) setupAndroidRoutes(config ManagerConfig) {
|
||||
NetworkType: route.IPv4Network,
|
||||
}
|
||||
cr = append(cr, fakeIPRoute)
|
||||
m.notifier.SetFakeIPRoute(fakeIPRoute)
|
||||
}
|
||||
|
||||
m.notifier.SetInitialClientRoutes(cr, routesForComparison)
|
||||
|
||||
@@ -16,6 +16,7 @@ import (
|
||||
type Notifier struct {
|
||||
initialRoutes []*route.Route
|
||||
currentRoutes []*route.Route
|
||||
fakeIPRoute *route.Route
|
||||
|
||||
listener listener.NetworkChangeListener
|
||||
listenerMux sync.Mutex
|
||||
@@ -31,13 +32,17 @@ func (n *Notifier) SetListener(listener listener.NetworkChangeListener) {
|
||||
n.listener = listener
|
||||
}
|
||||
|
||||
// SetInitialClientRoutes stores the full initial route set (including fake IP blocks)
|
||||
// and a separate comparison set (without fake IP blocks) for diff detection.
|
||||
// SetInitialClientRoutes stores the initial route sets for TUN configuration.
|
||||
func (n *Notifier) SetInitialClientRoutes(initialRoutes []*route.Route, routesForComparison []*route.Route) {
|
||||
n.initialRoutes = filterStatic(initialRoutes)
|
||||
n.currentRoutes = filterStatic(routesForComparison)
|
||||
}
|
||||
|
||||
// SetFakeIPRoute stores the fake IP route to be included in every TUN rebuild.
|
||||
func (n *Notifier) SetFakeIPRoute(r *route.Route) {
|
||||
n.fakeIPRoute = r
|
||||
}
|
||||
|
||||
func (n *Notifier) OnNewRoutes(idMap route.HAMap) {
|
||||
var newRoutes []*route.Route
|
||||
for _, routes := range idMap {
|
||||
@@ -69,7 +74,9 @@ func (n *Notifier) notify() {
|
||||
}
|
||||
|
||||
allRoutes := slices.Clone(n.currentRoutes)
|
||||
allRoutes = append(allRoutes, n.extraInitialRoutes()...)
|
||||
if n.fakeIPRoute != nil {
|
||||
allRoutes = append(allRoutes, n.fakeIPRoute)
|
||||
}
|
||||
|
||||
routeStrings := n.routesToStrings(allRoutes)
|
||||
sort.Strings(routeStrings)
|
||||
@@ -78,23 +85,6 @@ func (n *Notifier) notify() {
|
||||
}(n.listener)
|
||||
}
|
||||
|
||||
// extraInitialRoutes returns initialRoutes whose network prefix is absent
|
||||
// from currentRoutes (e.g. the fake IP block added at setup time).
|
||||
func (n *Notifier) extraInitialRoutes() []*route.Route {
|
||||
currentNets := make(map[netip.Prefix]struct{}, len(n.currentRoutes))
|
||||
for _, r := range n.currentRoutes {
|
||||
currentNets[r.Network] = struct{}{}
|
||||
}
|
||||
|
||||
var extra []*route.Route
|
||||
for _, r := range n.initialRoutes {
|
||||
if _, ok := currentNets[r.Network]; !ok {
|
||||
extra = append(extra, r)
|
||||
}
|
||||
}
|
||||
return extra
|
||||
}
|
||||
|
||||
func filterStatic(routes []*route.Route) []*route.Route {
|
||||
out := make([]*route.Route, 0, len(routes))
|
||||
for _, r := range routes {
|
||||
|
||||
@@ -34,6 +34,10 @@ func (n *Notifier) SetInitialClientRoutes([]*route.Route, []*route.Route) {
|
||||
// iOS doesn't care about initial routes
|
||||
}
|
||||
|
||||
func (n *Notifier) SetFakeIPRoute(*route.Route) {
|
||||
// Not used on iOS
|
||||
}
|
||||
|
||||
func (n *Notifier) OnNewRoutes(route.HAMap) {
|
||||
// Not used on iOS
|
||||
}
|
||||
|
||||
@@ -23,6 +23,10 @@ func (n *Notifier) SetInitialClientRoutes([]*route.Route, []*route.Route) {
|
||||
// Not used on non-mobile platforms
|
||||
}
|
||||
|
||||
func (n *Notifier) SetFakeIPRoute(*route.Route) {
|
||||
// Not used on non-mobile platforms
|
||||
}
|
||||
|
||||
func (n *Notifier) OnNewRoutes(idMap route.HAMap) {
|
||||
// Not used on non-mobile platforms
|
||||
}
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -209,6 +209,9 @@ message LoginRequest {
|
||||
optional bool enableSSHRemotePortForwarding = 37;
|
||||
optional bool disableSSHAuth = 38;
|
||||
optional int32 sshJWTCacheTTL = 39;
|
||||
|
||||
optional bool serverVNCAllowed = 41;
|
||||
optional bool disableVNCAuth = 42;
|
||||
}
|
||||
|
||||
message LoginResponse {
|
||||
@@ -316,6 +319,10 @@ message GetConfigResponse {
|
||||
bool disableSSHAuth = 25;
|
||||
|
||||
int32 sshJWTCacheTTL = 26;
|
||||
|
||||
bool serverVNCAllowed = 28;
|
||||
|
||||
bool disableVNCAuth = 29;
|
||||
}
|
||||
|
||||
// PeerState contains the latest state of a peer
|
||||
@@ -394,6 +401,11 @@ message SSHServerState {
|
||||
repeated SSHSessionInfo sessions = 2;
|
||||
}
|
||||
|
||||
// VNCServerState contains the latest state of the VNC server
|
||||
message VNCServerState {
|
||||
bool enabled = 1;
|
||||
}
|
||||
|
||||
// FullStatus contains the full state held by the Status instance
|
||||
message FullStatus {
|
||||
ManagementState managementState = 1;
|
||||
@@ -408,6 +420,7 @@ message FullStatus {
|
||||
|
||||
bool lazyConnectionEnabled = 9;
|
||||
SSHServerState sshServerState = 10;
|
||||
VNCServerState vncServerState = 11;
|
||||
}
|
||||
|
||||
// Networks
|
||||
@@ -677,6 +690,9 @@ message SetConfigRequest {
|
||||
optional bool enableSSHRemotePortForwarding = 32;
|
||||
optional bool disableSSHAuth = 33;
|
||||
optional int32 sshJWTCacheTTL = 34;
|
||||
|
||||
optional bool serverVNCAllowed = 36;
|
||||
optional bool disableVNCAuth = 37;
|
||||
}
|
||||
|
||||
message SetConfigResponse{}
|
||||
|
||||
@@ -366,6 +366,7 @@ func (s *Server) SetConfig(callerCtx context.Context, msg *proto.SetConfigReques
|
||||
config.RosenpassPermissive = msg.RosenpassPermissive
|
||||
config.DisableAutoConnect = msg.DisableAutoConnect
|
||||
config.ServerSSHAllowed = msg.ServerSSHAllowed
|
||||
config.ServerVNCAllowed = msg.ServerVNCAllowed
|
||||
config.NetworkMonitor = msg.NetworkMonitor
|
||||
config.DisableClientRoutes = msg.DisableClientRoutes
|
||||
config.DisableServerRoutes = msg.DisableServerRoutes
|
||||
@@ -382,6 +383,9 @@ func (s *Server) SetConfig(callerCtx context.Context, msg *proto.SetConfigReques
|
||||
if msg.DisableSSHAuth != nil {
|
||||
config.DisableSSHAuth = msg.DisableSSHAuth
|
||||
}
|
||||
if msg.DisableVNCAuth != nil {
|
||||
config.DisableVNCAuth = msg.DisableVNCAuth
|
||||
}
|
||||
if msg.SshJWTCacheTTL != nil {
|
||||
ttl := int(*msg.SshJWTCacheTTL)
|
||||
config.SSHJWTCacheTTL = &ttl
|
||||
@@ -1120,6 +1124,7 @@ func (s *Server) Status(
|
||||
pbFullStatus := fullStatus.ToProto()
|
||||
pbFullStatus.Events = s.statusRecorder.GetEventHistory()
|
||||
pbFullStatus.SshServerState = s.getSSHServerState()
|
||||
pbFullStatus.VncServerState = s.getVNCServerState()
|
||||
statusResponse.FullStatus = pbFullStatus
|
||||
}
|
||||
|
||||
@@ -1159,6 +1164,26 @@ func (s *Server) getSSHServerState() *proto.SSHServerState {
|
||||
return sshServerState
|
||||
}
|
||||
|
||||
// getVNCServerState retrieves the current VNC server state.
|
||||
func (s *Server) getVNCServerState() *proto.VNCServerState {
|
||||
s.mutex.Lock()
|
||||
connectClient := s.connectClient
|
||||
s.mutex.Unlock()
|
||||
|
||||
if connectClient == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
engine := connectClient.Engine()
|
||||
if engine == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return &proto.VNCServerState{
|
||||
Enabled: engine.GetVNCServerStatus(),
|
||||
}
|
||||
}
|
||||
|
||||
// GetPeerSSHHostKey retrieves SSH host key for a specific peer
|
||||
func (s *Server) GetPeerSSHHostKey(
|
||||
ctx context.Context,
|
||||
@@ -1500,6 +1525,11 @@ func (s *Server) GetConfig(ctx context.Context, req *proto.GetConfigRequest) (*p
|
||||
disableSSHAuth = *cfg.DisableSSHAuth
|
||||
}
|
||||
|
||||
disableVNCAuth := false
|
||||
if cfg.DisableVNCAuth != nil {
|
||||
disableVNCAuth = *cfg.DisableVNCAuth
|
||||
}
|
||||
|
||||
sshJWTCacheTTL := int32(0)
|
||||
if cfg.SSHJWTCacheTTL != nil {
|
||||
sshJWTCacheTTL = int32(*cfg.SSHJWTCacheTTL)
|
||||
@@ -1514,6 +1544,7 @@ func (s *Server) GetConfig(ctx context.Context, req *proto.GetConfigRequest) (*p
|
||||
Mtu: int64(cfg.MTU),
|
||||
DisableAutoConnect: cfg.DisableAutoConnect,
|
||||
ServerSSHAllowed: *cfg.ServerSSHAllowed,
|
||||
ServerVNCAllowed: cfg.ServerVNCAllowed != nil && *cfg.ServerVNCAllowed,
|
||||
RosenpassEnabled: cfg.RosenpassEnabled,
|
||||
RosenpassPermissive: cfg.RosenpassPermissive,
|
||||
LazyConnectionEnabled: cfg.LazyConnectionEnabled,
|
||||
@@ -1529,6 +1560,7 @@ func (s *Server) GetConfig(ctx context.Context, req *proto.GetConfigRequest) (*p
|
||||
EnableSSHLocalPortForwarding: enableSSHLocalPortForwarding,
|
||||
EnableSSHRemotePortForwarding: enableSSHRemotePortForwarding,
|
||||
DisableSSHAuth: disableSSHAuth,
|
||||
DisableVNCAuth: disableVNCAuth,
|
||||
SshJWTCacheTTL: sshJWTCacheTTL,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -58,6 +58,8 @@ func TestSetConfig_AllFieldsSaved(t *testing.T) {
|
||||
rosenpassEnabled := true
|
||||
rosenpassPermissive := true
|
||||
serverSSHAllowed := true
|
||||
serverVNCAllowed := true
|
||||
disableVNCAuth := true
|
||||
interfaceName := "utun100"
|
||||
wireguardPort := int64(51820)
|
||||
preSharedKey := "test-psk"
|
||||
@@ -82,6 +84,8 @@ func TestSetConfig_AllFieldsSaved(t *testing.T) {
|
||||
RosenpassEnabled: &rosenpassEnabled,
|
||||
RosenpassPermissive: &rosenpassPermissive,
|
||||
ServerSSHAllowed: &serverSSHAllowed,
|
||||
ServerVNCAllowed: &serverVNCAllowed,
|
||||
DisableVNCAuth: &disableVNCAuth,
|
||||
InterfaceName: &interfaceName,
|
||||
WireguardPort: &wireguardPort,
|
||||
OptionalPreSharedKey: &preSharedKey,
|
||||
@@ -125,6 +129,10 @@ func TestSetConfig_AllFieldsSaved(t *testing.T) {
|
||||
require.Equal(t, rosenpassPermissive, cfg.RosenpassPermissive)
|
||||
require.NotNil(t, cfg.ServerSSHAllowed)
|
||||
require.Equal(t, serverSSHAllowed, *cfg.ServerSSHAllowed)
|
||||
require.NotNil(t, cfg.ServerVNCAllowed)
|
||||
require.Equal(t, serverVNCAllowed, *cfg.ServerVNCAllowed)
|
||||
require.NotNil(t, cfg.DisableVNCAuth)
|
||||
require.Equal(t, disableVNCAuth, *cfg.DisableVNCAuth)
|
||||
require.Equal(t, interfaceName, cfg.WgIface)
|
||||
require.Equal(t, int(wireguardPort), cfg.WgPort)
|
||||
require.Equal(t, preSharedKey, cfg.PreSharedKey)
|
||||
@@ -176,6 +184,8 @@ func verifyAllFieldsCovered(t *testing.T, req *proto.SetConfigRequest) {
|
||||
"RosenpassEnabled": true,
|
||||
"RosenpassPermissive": true,
|
||||
"ServerSSHAllowed": true,
|
||||
"ServerVNCAllowed": true,
|
||||
"DisableVNCAuth": true,
|
||||
"InterfaceName": true,
|
||||
"WireguardPort": true,
|
||||
"OptionalPreSharedKey": true,
|
||||
@@ -236,6 +246,8 @@ func TestCLIFlags_MappedToSetConfig(t *testing.T) {
|
||||
"enable-rosenpass": "RosenpassEnabled",
|
||||
"rosenpass-permissive": "RosenpassPermissive",
|
||||
"allow-server-ssh": "ServerSSHAllowed",
|
||||
"allow-server-vnc": "ServerVNCAllowed",
|
||||
"disable-vnc-auth": "DisableVNCAuth",
|
||||
"interface-name": "InterfaceName",
|
||||
"wireguard-port": "WireguardPort",
|
||||
"preshared-key": "OptionalPreSharedKey",
|
||||
|
||||
@@ -200,8 +200,8 @@ func newLsaString(s string) lsaString {
|
||||
}
|
||||
}
|
||||
|
||||
// generateS4UUserToken creates a Windows token using S4U authentication
|
||||
// This is the exact approach OpenSSH for Windows uses for public key authentication
|
||||
// generateS4UUserToken creates a Windows token using S4U authentication.
|
||||
// This is the same approach OpenSSH for Windows uses for public key authentication.
|
||||
func generateS4UUserToken(logger *log.Entry, username, domain string) (windows.Handle, error) {
|
||||
userCpn := buildUserCpn(username, domain)
|
||||
|
||||
|
||||
@@ -507,27 +507,7 @@ func (s *Server) checkTokenAge(token *gojwt.Token, jwtConfig *JWTConfig) error {
|
||||
maxTokenAge = DefaultJWTMaxTokenAge
|
||||
}
|
||||
|
||||
claims, ok := token.Claims.(gojwt.MapClaims)
|
||||
if !ok {
|
||||
userID := extractUserID(token)
|
||||
return fmt.Errorf("token has invalid claims format (user=%s)", userID)
|
||||
}
|
||||
|
||||
iat, ok := claims["iat"].(float64)
|
||||
if !ok {
|
||||
userID := extractUserID(token)
|
||||
return fmt.Errorf("token missing iat claim (user=%s)", userID)
|
||||
}
|
||||
|
||||
issuedAt := time.Unix(int64(iat), 0)
|
||||
tokenAge := time.Since(issuedAt)
|
||||
maxAge := time.Duration(maxTokenAge) * time.Second
|
||||
if tokenAge > maxAge {
|
||||
userID := getUserIDFromClaims(claims)
|
||||
return fmt.Errorf("token expired for user=%s: age=%v, max=%v", userID, tokenAge, maxAge)
|
||||
}
|
||||
|
||||
return nil
|
||||
return jwt.CheckTokenAge(token, time.Duration(maxTokenAge)*time.Second)
|
||||
}
|
||||
|
||||
func (s *Server) extractAndValidateUser(token *gojwt.Token) (*auth.UserAuth, error) {
|
||||
@@ -558,27 +538,7 @@ func (s *Server) hasSSHAccess(userAuth *auth.UserAuth) bool {
|
||||
}
|
||||
|
||||
func extractUserID(token *gojwt.Token) string {
|
||||
if token == nil {
|
||||
return "unknown"
|
||||
}
|
||||
claims, ok := token.Claims.(gojwt.MapClaims)
|
||||
if !ok {
|
||||
return "unknown"
|
||||
}
|
||||
return getUserIDFromClaims(claims)
|
||||
}
|
||||
|
||||
func getUserIDFromClaims(claims gojwt.MapClaims) string {
|
||||
if sub, ok := claims["sub"].(string); ok && sub != "" {
|
||||
return sub
|
||||
}
|
||||
if userID, ok := claims["user_id"].(string); ok && userID != "" {
|
||||
return userID
|
||||
}
|
||||
if email, ok := claims["email"].(string); ok && email != "" {
|
||||
return email
|
||||
}
|
||||
return "unknown"
|
||||
return jwt.UserIDFromToken(token)
|
||||
}
|
||||
|
||||
func (s *Server) parseTokenWithoutValidation(tokenString string) (map[string]interface{}, error) {
|
||||
|
||||
@@ -130,6 +130,10 @@ type SSHServerStateOutput struct {
|
||||
Sessions []SSHSessionOutput `json:"sessions" yaml:"sessions"`
|
||||
}
|
||||
|
||||
type VNCServerStateOutput struct {
|
||||
Enabled bool `json:"enabled" yaml:"enabled"`
|
||||
}
|
||||
|
||||
type OutputOverview struct {
|
||||
Peers PeersStateOutput `json:"peers" yaml:"peers"`
|
||||
CliVersion string `json:"cliVersion" yaml:"cliVersion"`
|
||||
@@ -151,6 +155,7 @@ type OutputOverview struct {
|
||||
LazyConnectionEnabled bool `json:"lazyConnectionEnabled" yaml:"lazyConnectionEnabled"`
|
||||
ProfileName string `json:"profileName" yaml:"profileName"`
|
||||
SSHServerState SSHServerStateOutput `json:"sshServer" yaml:"sshServer"`
|
||||
VNCServerState VNCServerStateOutput `json:"vncServer" yaml:"vncServer"`
|
||||
}
|
||||
|
||||
// ConvertToStatusOutputOverview converts protobuf status to the output overview.
|
||||
@@ -171,6 +176,9 @@ func ConvertToStatusOutputOverview(pbFullStatus *proto.FullStatus, opts ConvertO
|
||||
|
||||
relayOverview := mapRelays(pbFullStatus.GetRelays())
|
||||
sshServerOverview := mapSSHServer(pbFullStatus.GetSshServerState())
|
||||
vncServerOverview := VNCServerStateOutput{
|
||||
Enabled: pbFullStatus.GetVncServerState().GetEnabled(),
|
||||
}
|
||||
peersOverview := mapPeers(pbFullStatus.GetPeers(), opts.StatusFilter, opts.PrefixNamesFilter, opts.PrefixNamesFilterMap, opts.IPsFilter, opts.ConnectionTypeFilter)
|
||||
|
||||
overview := OutputOverview{
|
||||
@@ -194,6 +202,7 @@ func ConvertToStatusOutputOverview(pbFullStatus *proto.FullStatus, opts ConvertO
|
||||
LazyConnectionEnabled: pbFullStatus.GetLazyConnectionEnabled(),
|
||||
ProfileName: opts.ProfileName,
|
||||
SSHServerState: sshServerOverview,
|
||||
VNCServerState: vncServerOverview,
|
||||
}
|
||||
|
||||
if opts.Anonymize {
|
||||
@@ -524,6 +533,11 @@ func (o *OutputOverview) GeneralSummary(showURL bool, showRelays bool, showNameS
|
||||
}
|
||||
}
|
||||
|
||||
vncServerStatus := "Disabled"
|
||||
if o.VNCServerState.Enabled {
|
||||
vncServerStatus = "Enabled"
|
||||
}
|
||||
|
||||
peersCountString := fmt.Sprintf("%d/%d Connected", o.Peers.Connected, o.Peers.Total)
|
||||
|
||||
var forwardingRulesString string
|
||||
@@ -553,6 +567,7 @@ func (o *OutputOverview) GeneralSummary(showURL bool, showRelays bool, showNameS
|
||||
"Quantum resistance: %s\n"+
|
||||
"Lazy connection: %s\n"+
|
||||
"SSH Server: %s\n"+
|
||||
"VNC Server: %s\n"+
|
||||
"Networks: %s\n"+
|
||||
"%s"+
|
||||
"Peers count: %s\n",
|
||||
@@ -570,6 +585,7 @@ func (o *OutputOverview) GeneralSummary(showURL bool, showRelays bool, showNameS
|
||||
rosenpassEnabledStatus,
|
||||
lazyConnectionEnabledStatus,
|
||||
sshServerStatus,
|
||||
vncServerStatus,
|
||||
networks,
|
||||
forwardingRulesString,
|
||||
peersCountString,
|
||||
|
||||
@@ -398,6 +398,9 @@ func TestParsingToJSON(t *testing.T) {
|
||||
"sshServer":{
|
||||
"enabled":false,
|
||||
"sessions":[]
|
||||
},
|
||||
"vncServer":{
|
||||
"enabled":false
|
||||
}
|
||||
}`
|
||||
// @formatter:on
|
||||
@@ -505,6 +508,8 @@ profileName: ""
|
||||
sshServer:
|
||||
enabled: false
|
||||
sessions: []
|
||||
vncServer:
|
||||
enabled: false
|
||||
`
|
||||
|
||||
assert.Equal(t, expectedYAML, yaml)
|
||||
@@ -572,6 +577,7 @@ Interface type: Kernel
|
||||
Quantum resistance: false
|
||||
Lazy connection: false
|
||||
SSH Server: Disabled
|
||||
VNC Server: Disabled
|
||||
Networks: 10.10.0.0/24
|
||||
Peers count: 2/2 Connected
|
||||
`, lastConnectionUpdate1, lastHandshake1, lastConnectionUpdate2, lastHandshake2, runtime.GOOS, runtime.GOARCH, overview.CliVersion)
|
||||
@@ -596,6 +602,7 @@ Interface type: Kernel
|
||||
Quantum resistance: false
|
||||
Lazy connection: false
|
||||
SSH Server: Disabled
|
||||
VNC Server: Disabled
|
||||
Networks: 10.10.0.0/24
|
||||
Peers count: 2/2 Connected
|
||||
`
|
||||
|
||||
@@ -63,6 +63,7 @@ type Info struct {
|
||||
RosenpassEnabled bool
|
||||
RosenpassPermissive bool
|
||||
ServerSSHAllowed bool
|
||||
ServerVNCAllowed bool
|
||||
|
||||
DisableClientRoutes bool
|
||||
DisableServerRoutes bool
|
||||
@@ -78,21 +79,27 @@ type Info struct {
|
||||
EnableSSHLocalPortForwarding bool
|
||||
EnableSSHRemotePortForwarding bool
|
||||
DisableSSHAuth bool
|
||||
DisableVNCAuth bool
|
||||
}
|
||||
|
||||
func (i *Info) SetFlags(
|
||||
rosenpassEnabled, rosenpassPermissive bool,
|
||||
serverSSHAllowed *bool,
|
||||
serverVNCAllowed *bool,
|
||||
disableClientRoutes, disableServerRoutes,
|
||||
disableDNS, disableFirewall, blockLANAccess, blockInbound, lazyConnectionEnabled bool,
|
||||
enableSSHRoot, enableSSHSFTP, enableSSHLocalPortForwarding, enableSSHRemotePortForwarding *bool,
|
||||
disableSSHAuth *bool,
|
||||
disableVNCAuth *bool,
|
||||
) {
|
||||
i.RosenpassEnabled = rosenpassEnabled
|
||||
i.RosenpassPermissive = rosenpassPermissive
|
||||
if serverSSHAllowed != nil {
|
||||
i.ServerSSHAllowed = *serverSSHAllowed
|
||||
}
|
||||
if serverVNCAllowed != nil {
|
||||
i.ServerVNCAllowed = *serverVNCAllowed
|
||||
}
|
||||
|
||||
i.DisableClientRoutes = disableClientRoutes
|
||||
i.DisableServerRoutes = disableServerRoutes
|
||||
@@ -118,6 +125,9 @@ func (i *Info) SetFlags(
|
||||
if disableSSHAuth != nil {
|
||||
i.DisableSSHAuth = *disableSSHAuth
|
||||
}
|
||||
if disableVNCAuth != nil {
|
||||
i.DisableVNCAuth = *disableVNCAuth
|
||||
}
|
||||
}
|
||||
|
||||
// extractUserAgent extracts Netbird's agent (client) name and version from the outgoing context
|
||||
|
||||
474
client/vnc/server/agent_windows.go
Normal file
474
client/vnc/server/agent_windows.go
Normal file
@@ -0,0 +1,474 @@
|
||||
//go:build windows
|
||||
|
||||
package server
|
||||
|
||||
import (
|
||||
crand "crypto/rand"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"os"
|
||||
"sync"
|
||||
"time"
|
||||
"unsafe"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/sys/windows"
|
||||
)
|
||||
|
||||
const (
|
||||
agentPort = "15900"
|
||||
|
||||
// agentTokenLen is the length of the random authentication token
|
||||
// used to verify that connections to the agent come from the service.
|
||||
agentTokenLen = 32
|
||||
|
||||
stillActive = 259
|
||||
|
||||
tokenPrimary = 1
|
||||
securityImpersonation = 2
|
||||
tokenSessionID = 12
|
||||
|
||||
createUnicodeEnvironment = 0x00000400
|
||||
createNoWindow = 0x08000000
|
||||
)
|
||||
|
||||
var (
|
||||
kernel32 = windows.NewLazySystemDLL("kernel32.dll")
|
||||
advapi32 = windows.NewLazySystemDLL("advapi32.dll")
|
||||
userenv = windows.NewLazySystemDLL("userenv.dll")
|
||||
|
||||
procWTSGetActiveConsoleSessionId = kernel32.NewProc("WTSGetActiveConsoleSessionId")
|
||||
procSetTokenInformation = advapi32.NewProc("SetTokenInformation")
|
||||
procCreateEnvironmentBlock = userenv.NewProc("CreateEnvironmentBlock")
|
||||
procDestroyEnvironmentBlock = userenv.NewProc("DestroyEnvironmentBlock")
|
||||
|
||||
wtsapi32 = windows.NewLazySystemDLL("wtsapi32.dll")
|
||||
procWTSEnumerateSessionsW = wtsapi32.NewProc("WTSEnumerateSessionsW")
|
||||
procWTSFreeMemory = wtsapi32.NewProc("WTSFreeMemory")
|
||||
)
|
||||
|
||||
// GetCurrentSessionID returns the session ID of the current process.
|
||||
func GetCurrentSessionID() uint32 {
|
||||
var token windows.Token
|
||||
if err := windows.OpenProcessToken(windows.CurrentProcess(),
|
||||
windows.TOKEN_QUERY, &token); err != nil {
|
||||
return 0
|
||||
}
|
||||
defer token.Close()
|
||||
var id uint32
|
||||
var ret uint32
|
||||
_ = windows.GetTokenInformation(token, windows.TokenSessionId,
|
||||
(*byte)(unsafe.Pointer(&id)), 4, &ret)
|
||||
return id
|
||||
}
|
||||
|
||||
func getConsoleSessionID() uint32 {
|
||||
r, _, _ := procWTSGetActiveConsoleSessionId.Call()
|
||||
return uint32(r)
|
||||
}
|
||||
|
||||
const (
|
||||
wtsActive = 0
|
||||
wtsConnected = 1
|
||||
wtsDisconnected = 4
|
||||
)
|
||||
|
||||
type wtsSessionInfo struct {
|
||||
SessionID uint32
|
||||
WinStationName [66]byte // actually *uint16, but we just need the struct size
|
||||
State uint32
|
||||
}
|
||||
|
||||
// getActiveSessionID returns the session ID of the best session to attach to.
|
||||
// Prefers an active (logged-in, interactive) session over the console session.
|
||||
// This avoids kicking out an RDP user when the console is at the login screen.
|
||||
func getActiveSessionID() uint32 {
|
||||
var sessionInfo uintptr
|
||||
var count uint32
|
||||
|
||||
r, _, _ := procWTSEnumerateSessionsW.Call(
|
||||
0, // WTS_CURRENT_SERVER_HANDLE
|
||||
0, // reserved
|
||||
1, // version
|
||||
uintptr(unsafe.Pointer(&sessionInfo)),
|
||||
uintptr(unsafe.Pointer(&count)),
|
||||
)
|
||||
if r == 0 || count == 0 {
|
||||
return getConsoleSessionID()
|
||||
}
|
||||
defer procWTSFreeMemory.Call(sessionInfo)
|
||||
|
||||
type wtsSession struct {
|
||||
SessionID uint32
|
||||
Station *uint16
|
||||
State uint32
|
||||
}
|
||||
sessions := unsafe.Slice((*wtsSession)(unsafe.Pointer(sessionInfo)), count)
|
||||
|
||||
// Find the first active session (not session 0, which is the services session).
|
||||
var bestID uint32
|
||||
found := false
|
||||
for _, s := range sessions {
|
||||
if s.SessionID == 0 {
|
||||
continue
|
||||
}
|
||||
if s.State == wtsActive {
|
||||
bestID = s.SessionID
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !found {
|
||||
return getConsoleSessionID()
|
||||
}
|
||||
return bestID
|
||||
}
|
||||
|
||||
// getSystemTokenForSession duplicates the current SYSTEM token and sets its
|
||||
// session ID so the spawned process runs in the target session. Using a SYSTEM
|
||||
// token gives access to both Default and Winlogon desktops plus UIPI bypass.
|
||||
func getSystemTokenForSession(sessionID uint32) (windows.Token, error) {
|
||||
var cur windows.Token
|
||||
if err := windows.OpenProcessToken(windows.CurrentProcess(),
|
||||
windows.MAXIMUM_ALLOWED, &cur); err != nil {
|
||||
return 0, fmt.Errorf("OpenProcessToken: %w", err)
|
||||
}
|
||||
defer cur.Close()
|
||||
|
||||
var dup windows.Token
|
||||
if err := windows.DuplicateTokenEx(cur, windows.MAXIMUM_ALLOWED, nil,
|
||||
securityImpersonation, tokenPrimary, &dup); err != nil {
|
||||
return 0, fmt.Errorf("DuplicateTokenEx: %w", err)
|
||||
}
|
||||
|
||||
sid := sessionID
|
||||
r, _, err := procSetTokenInformation.Call(
|
||||
uintptr(dup),
|
||||
uintptr(tokenSessionID),
|
||||
uintptr(unsafe.Pointer(&sid)),
|
||||
unsafe.Sizeof(sid),
|
||||
)
|
||||
if r == 0 {
|
||||
dup.Close()
|
||||
return 0, fmt.Errorf("SetTokenInformation(SessionId=%d): %w", sessionID, err)
|
||||
}
|
||||
return dup, nil
|
||||
}
|
||||
|
||||
const agentTokenEnvVar = "NB_VNC_AGENT_TOKEN"
|
||||
|
||||
// injectEnvVar appends a KEY=VALUE entry to a Unicode environment block.
|
||||
// The block is a sequence of null-terminated UTF-16 strings, terminated by
|
||||
// an extra null. Returns a new block pointer with the entry added.
|
||||
func injectEnvVar(envBlock uintptr, key, value string) uintptr {
|
||||
entry := key + "=" + value
|
||||
|
||||
// Walk the existing block to find its total length.
|
||||
ptr := (*uint16)(unsafe.Pointer(envBlock))
|
||||
var totalChars int
|
||||
for {
|
||||
ch := *(*uint16)(unsafe.Pointer(uintptr(unsafe.Pointer(ptr)) + uintptr(totalChars)*2))
|
||||
if ch == 0 {
|
||||
// Check for double-null terminator.
|
||||
next := *(*uint16)(unsafe.Pointer(uintptr(unsafe.Pointer(ptr)) + uintptr(totalChars+1)*2))
|
||||
totalChars++
|
||||
if next == 0 {
|
||||
// End of block (don't count the final null yet, we'll rebuild).
|
||||
break
|
||||
}
|
||||
} else {
|
||||
totalChars++
|
||||
}
|
||||
}
|
||||
|
||||
entryUTF16, _ := windows.UTF16FromString(entry)
|
||||
// New block: existing entries + new entry (null-terminated) + final null.
|
||||
newLen := totalChars + len(entryUTF16) + 1
|
||||
newBlock := make([]uint16, newLen)
|
||||
// Copy existing entries (up to but not including the final null).
|
||||
for i := range totalChars {
|
||||
newBlock[i] = *(*uint16)(unsafe.Pointer(uintptr(unsafe.Pointer(ptr)) + uintptr(i)*2))
|
||||
}
|
||||
copy(newBlock[totalChars:], entryUTF16)
|
||||
newBlock[newLen-1] = 0 // final null terminator
|
||||
|
||||
return uintptr(unsafe.Pointer(&newBlock[0]))
|
||||
}
|
||||
|
||||
func spawnAgentInSession(sessionID uint32, port string, authToken string) (windows.Handle, error) {
|
||||
token, err := getSystemTokenForSession(sessionID)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("get SYSTEM token for session %d: %w", sessionID, err)
|
||||
}
|
||||
defer token.Close()
|
||||
|
||||
var envBlock uintptr
|
||||
r, _, _ := procCreateEnvironmentBlock.Call(
|
||||
uintptr(unsafe.Pointer(&envBlock)),
|
||||
uintptr(token),
|
||||
0,
|
||||
)
|
||||
if r != 0 {
|
||||
defer procDestroyEnvironmentBlock.Call(envBlock)
|
||||
}
|
||||
|
||||
// Inject the auth token into the environment block so it doesn't appear
|
||||
// in the process command line (visible via tasklist/wmic).
|
||||
if r != 0 {
|
||||
envBlock = injectEnvVar(envBlock, agentTokenEnvVar, authToken)
|
||||
}
|
||||
|
||||
exePath, err := os.Executable()
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("get executable path: %w", err)
|
||||
}
|
||||
|
||||
cmdLine := fmt.Sprintf(`"%s" vnc-agent --port %s`, exePath, port)
|
||||
cmdLineW, err := windows.UTF16PtrFromString(cmdLine)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("UTF16 cmdline: %w", err)
|
||||
}
|
||||
|
||||
// Create an inheritable pipe for the agent's stderr so we can relog
|
||||
// its output in the service process.
|
||||
var sa windows.SecurityAttributes
|
||||
sa.Length = uint32(unsafe.Sizeof(sa))
|
||||
sa.InheritHandle = 1
|
||||
|
||||
var stderrRead, stderrWrite windows.Handle
|
||||
if err := windows.CreatePipe(&stderrRead, &stderrWrite, &sa, 0); err != nil {
|
||||
return 0, fmt.Errorf("create stderr pipe: %w", err)
|
||||
}
|
||||
// The read end must NOT be inherited by the child.
|
||||
windows.SetHandleInformation(stderrRead, windows.HANDLE_FLAG_INHERIT, 0)
|
||||
|
||||
desktop, _ := windows.UTF16PtrFromString(`WinSta0\Default`)
|
||||
si := windows.StartupInfo{
|
||||
Cb: uint32(unsafe.Sizeof(windows.StartupInfo{})),
|
||||
Desktop: desktop,
|
||||
Flags: windows.STARTF_USESHOWWINDOW | windows.STARTF_USESTDHANDLES,
|
||||
ShowWindow: 0,
|
||||
StdErr: stderrWrite,
|
||||
StdOutput: stderrWrite,
|
||||
}
|
||||
var pi windows.ProcessInformation
|
||||
|
||||
var envPtr *uint16
|
||||
if envBlock != 0 {
|
||||
envPtr = (*uint16)(unsafe.Pointer(envBlock))
|
||||
}
|
||||
|
||||
err = windows.CreateProcessAsUser(
|
||||
token, nil, cmdLineW,
|
||||
nil, nil, true, // inheritHandles=true for the pipe
|
||||
createUnicodeEnvironment|createNoWindow,
|
||||
envPtr, nil, &si, &pi,
|
||||
)
|
||||
// Close the write end in the parent so reads will get EOF when the child exits.
|
||||
windows.CloseHandle(stderrWrite)
|
||||
if err != nil {
|
||||
windows.CloseHandle(stderrRead)
|
||||
return 0, fmt.Errorf("CreateProcessAsUser: %w", err)
|
||||
}
|
||||
windows.CloseHandle(pi.Thread)
|
||||
|
||||
// Relog agent output in the service with a [vnc-agent] prefix.
|
||||
go relogAgentOutput(stderrRead)
|
||||
|
||||
log.Infof("spawned agent PID=%d in session %d on port %s", pi.ProcessId, sessionID, port)
|
||||
return pi.Process, nil
|
||||
}
|
||||
|
||||
// sessionManager monitors the active console session and ensures a VNC agent
|
||||
// process is running in it. When the session changes (e.g., user switch, RDP
|
||||
// connect/disconnect), it kills the old agent and spawns a new one.
|
||||
type sessionManager struct {
|
||||
port string
|
||||
mu sync.Mutex
|
||||
agentProc windows.Handle
|
||||
sessionID uint32
|
||||
authToken string
|
||||
done chan struct{}
|
||||
}
|
||||
|
||||
func newSessionManager(port string) *sessionManager {
|
||||
return &sessionManager{port: port, sessionID: ^uint32(0), done: make(chan struct{})}
|
||||
}
|
||||
|
||||
// generateAuthToken creates a new random hex token for agent authentication.
|
||||
func generateAuthToken() string {
|
||||
b := make([]byte, agentTokenLen)
|
||||
if _, err := crand.Read(b); err != nil {
|
||||
log.Warnf("generate agent auth token: %v", err)
|
||||
return ""
|
||||
}
|
||||
return hex.EncodeToString(b)
|
||||
}
|
||||
|
||||
// AuthToken returns the current agent authentication token.
|
||||
func (m *sessionManager) AuthToken() string {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
return m.authToken
|
||||
}
|
||||
|
||||
// Stop signals the session manager to exit its polling loop.
|
||||
func (m *sessionManager) Stop() {
|
||||
select {
|
||||
case <-m.done:
|
||||
default:
|
||||
close(m.done)
|
||||
}
|
||||
}
|
||||
|
||||
func (m *sessionManager) run() {
|
||||
ticker := time.NewTicker(2 * time.Second)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
sid := getActiveSessionID()
|
||||
|
||||
m.mu.Lock()
|
||||
if sid != m.sessionID {
|
||||
log.Infof("active session changed: %d -> %d", m.sessionID, sid)
|
||||
m.killAgent()
|
||||
m.sessionID = sid
|
||||
}
|
||||
|
||||
if m.agentProc != 0 {
|
||||
var code uint32
|
||||
_ = windows.GetExitCodeProcess(m.agentProc, &code)
|
||||
if code != stillActive {
|
||||
log.Infof("agent exited (code=%d), respawning", code)
|
||||
windows.CloseHandle(m.agentProc)
|
||||
m.agentProc = 0
|
||||
}
|
||||
}
|
||||
|
||||
if m.agentProc == 0 && sid != 0xFFFFFFFF {
|
||||
m.authToken = generateAuthToken()
|
||||
h, err := spawnAgentInSession(sid, m.port, m.authToken)
|
||||
if err != nil {
|
||||
log.Warnf("spawn agent in session %d: %v", sid, err)
|
||||
m.authToken = ""
|
||||
} else {
|
||||
m.agentProc = h
|
||||
}
|
||||
}
|
||||
m.mu.Unlock()
|
||||
|
||||
select {
|
||||
case <-m.done:
|
||||
m.mu.Lock()
|
||||
m.killAgent()
|
||||
m.mu.Unlock()
|
||||
return
|
||||
case <-ticker.C:
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (m *sessionManager) killAgent() {
|
||||
if m.agentProc != 0 {
|
||||
_ = windows.TerminateProcess(m.agentProc, 0)
|
||||
windows.CloseHandle(m.agentProc)
|
||||
m.agentProc = 0
|
||||
log.Info("killed old agent")
|
||||
}
|
||||
}
|
||||
|
||||
// relogAgentOutput reads JSON log lines from the agent's stderr pipe and
|
||||
// relogs them at the correct level with the service's formatter.
|
||||
func relogAgentOutput(pipe windows.Handle) {
|
||||
defer windows.CloseHandle(pipe)
|
||||
f := os.NewFile(uintptr(pipe), "vnc-agent-stderr")
|
||||
defer f.Close()
|
||||
|
||||
entry := log.WithField("component", "vnc-agent")
|
||||
dec := json.NewDecoder(f)
|
||||
for dec.More() {
|
||||
var m map[string]any
|
||||
if err := dec.Decode(&m); err != nil {
|
||||
break
|
||||
}
|
||||
msg, _ := m["msg"].(string)
|
||||
if msg == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
// Forward extra fields from the agent (skip standard logrus fields).
|
||||
// Remap "caller" to "source" so it doesn't conflict with logrus internals
|
||||
// but still shows the original file/line from the agent process.
|
||||
fields := make(log.Fields)
|
||||
for k, v := range m {
|
||||
switch k {
|
||||
case "msg", "level", "time", "func":
|
||||
continue
|
||||
case "caller":
|
||||
fields["source"] = v
|
||||
default:
|
||||
fields[k] = v
|
||||
}
|
||||
}
|
||||
e := entry.WithFields(fields)
|
||||
|
||||
switch m["level"] {
|
||||
case "error":
|
||||
e.Error(msg)
|
||||
case "warning":
|
||||
e.Warn(msg)
|
||||
case "debug":
|
||||
e.Debug(msg)
|
||||
case "trace":
|
||||
e.Trace(msg)
|
||||
default:
|
||||
e.Info(msg)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// proxyToAgent connects to the agent, sends the auth token, then proxies
|
||||
// the VNC client connection bidirectionally.
|
||||
func proxyToAgent(client net.Conn, port string, authToken string) {
|
||||
defer client.Close()
|
||||
|
||||
addr := "127.0.0.1:" + port
|
||||
var agentConn net.Conn
|
||||
var err error
|
||||
for range 50 {
|
||||
agentConn, err = net.DialTimeout("tcp", addr, time.Second)
|
||||
if err == nil {
|
||||
break
|
||||
}
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
}
|
||||
if err != nil {
|
||||
log.Warnf("proxy cannot reach agent at %s: %v", addr, err)
|
||||
return
|
||||
}
|
||||
defer agentConn.Close()
|
||||
|
||||
// Send the auth token so the agent can verify this connection
|
||||
// comes from the trusted service process.
|
||||
tokenBytes, _ := hex.DecodeString(authToken)
|
||||
if _, err := agentConn.Write(tokenBytes); err != nil {
|
||||
log.Warnf("send auth token to agent: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
log.Debugf("proxy connected to agent, starting bidirectional copy")
|
||||
|
||||
done := make(chan struct{}, 2)
|
||||
cp := func(label string, dst, src net.Conn) {
|
||||
n, err := io.Copy(dst, src)
|
||||
log.Debugf("proxy %s: %d bytes, err=%v", label, n, err)
|
||||
done <- struct{}{}
|
||||
}
|
||||
go cp("client→agent", agentConn, client)
|
||||
go cp("agent→client", client, agentConn)
|
||||
<-done
|
||||
}
|
||||
274
client/vnc/server/capture_darwin.go
Normal file
274
client/vnc/server/capture_darwin.go
Normal file
@@ -0,0 +1,274 @@
|
||||
//go:build darwin && !ios
|
||||
|
||||
package server
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"image"
|
||||
"sync"
|
||||
"time"
|
||||
"unsafe"
|
||||
|
||||
"github.com/ebitengine/purego"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
var darwinCaptureOnce sync.Once
|
||||
|
||||
var (
|
||||
cgMainDisplayID func() uint32
|
||||
cgDisplayPixelsWide func(uint32) uintptr
|
||||
cgDisplayPixelsHigh func(uint32) uintptr
|
||||
cgDisplayCreateImage func(uint32) uintptr
|
||||
cgImageGetWidth func(uintptr) uintptr
|
||||
cgImageGetHeight func(uintptr) uintptr
|
||||
cgImageGetBytesPerRow func(uintptr) uintptr
|
||||
cgImageGetBitsPerPixel func(uintptr) uintptr
|
||||
cgImageGetDataProvider func(uintptr) uintptr
|
||||
cgDataProviderCopyData func(uintptr) uintptr
|
||||
cgImageRelease func(uintptr)
|
||||
cfDataGetLength func(uintptr) int64
|
||||
cfDataGetBytePtr func(uintptr) uintptr
|
||||
cfRelease func(uintptr)
|
||||
cgPreflightScreenCaptureAccess func() bool
|
||||
cgRequestScreenCaptureAccess func() bool
|
||||
darwinCaptureReady bool
|
||||
)
|
||||
|
||||
func initDarwinCapture() {
|
||||
darwinCaptureOnce.Do(func() {
|
||||
cg, err := purego.Dlopen("/System/Library/Frameworks/CoreGraphics.framework/CoreGraphics", purego.RTLD_NOW|purego.RTLD_GLOBAL)
|
||||
if err != nil {
|
||||
log.Debugf("load CoreGraphics: %v", err)
|
||||
return
|
||||
}
|
||||
cf, err := purego.Dlopen("/System/Library/Frameworks/CoreFoundation.framework/CoreFoundation", purego.RTLD_NOW|purego.RTLD_GLOBAL)
|
||||
if err != nil {
|
||||
log.Debugf("load CoreFoundation: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
purego.RegisterLibFunc(&cgMainDisplayID, cg, "CGMainDisplayID")
|
||||
purego.RegisterLibFunc(&cgDisplayPixelsWide, cg, "CGDisplayPixelsWide")
|
||||
purego.RegisterLibFunc(&cgDisplayPixelsHigh, cg, "CGDisplayPixelsHigh")
|
||||
purego.RegisterLibFunc(&cgDisplayCreateImage, cg, "CGDisplayCreateImage")
|
||||
purego.RegisterLibFunc(&cgImageGetWidth, cg, "CGImageGetWidth")
|
||||
purego.RegisterLibFunc(&cgImageGetHeight, cg, "CGImageGetHeight")
|
||||
purego.RegisterLibFunc(&cgImageGetBytesPerRow, cg, "CGImageGetBytesPerRow")
|
||||
purego.RegisterLibFunc(&cgImageGetBitsPerPixel, cg, "CGImageGetBitsPerPixel")
|
||||
purego.RegisterLibFunc(&cgImageGetDataProvider, cg, "CGImageGetDataProvider")
|
||||
purego.RegisterLibFunc(&cgDataProviderCopyData, cg, "CGDataProviderCopyData")
|
||||
purego.RegisterLibFunc(&cgImageRelease, cg, "CGImageRelease")
|
||||
purego.RegisterLibFunc(&cfDataGetLength, cf, "CFDataGetLength")
|
||||
purego.RegisterLibFunc(&cfDataGetBytePtr, cf, "CFDataGetBytePtr")
|
||||
purego.RegisterLibFunc(&cfRelease, cf, "CFRelease")
|
||||
|
||||
// Screen capture permission APIs (macOS 11+). Might not exist on older versions.
|
||||
if sym, err := purego.Dlsym(cg, "CGPreflightScreenCaptureAccess"); err == nil {
|
||||
purego.RegisterFunc(&cgPreflightScreenCaptureAccess, sym)
|
||||
}
|
||||
if sym, err := purego.Dlsym(cg, "CGRequestScreenCaptureAccess"); err == nil {
|
||||
purego.RegisterFunc(&cgRequestScreenCaptureAccess, sym)
|
||||
}
|
||||
|
||||
darwinCaptureReady = true
|
||||
})
|
||||
}
|
||||
|
||||
// CGCapturer captures the macOS main display using Core Graphics.
|
||||
type CGCapturer struct {
|
||||
displayID uint32
|
||||
w, h int
|
||||
}
|
||||
|
||||
// NewCGCapturer creates a screen capturer for the main display.
|
||||
func NewCGCapturer() (*CGCapturer, error) {
|
||||
initDarwinCapture()
|
||||
if !darwinCaptureReady {
|
||||
return nil, fmt.Errorf("CoreGraphics not available")
|
||||
}
|
||||
|
||||
// Request Screen Recording permission (shows system dialog on macOS 11+).
|
||||
if cgPreflightScreenCaptureAccess != nil && !cgPreflightScreenCaptureAccess() {
|
||||
if cgRequestScreenCaptureAccess != nil {
|
||||
cgRequestScreenCaptureAccess()
|
||||
}
|
||||
log.Warn("Screen Recording permission not granted. " +
|
||||
"Grant in System Settings > Privacy & Security > Screen Recording, then restart.")
|
||||
}
|
||||
|
||||
displayID := cgMainDisplayID()
|
||||
w := int(cgDisplayPixelsWide(displayID))
|
||||
h := int(cgDisplayPixelsHigh(displayID))
|
||||
if w == 0 || h == 0 {
|
||||
return nil, fmt.Errorf("display dimensions are zero")
|
||||
}
|
||||
|
||||
log.Infof("macOS capturer ready: %dx%d (display=%d)", w, h, displayID)
|
||||
return &CGCapturer{displayID: displayID, w: w, h: h}, nil
|
||||
}
|
||||
|
||||
// Width returns the screen width.
|
||||
func (c *CGCapturer) Width() int { return c.w }
|
||||
|
||||
// Height returns the screen height.
|
||||
func (c *CGCapturer) Height() int { return c.h }
|
||||
|
||||
// Capture returns the current screen as an RGBA image.
|
||||
func (c *CGCapturer) Capture() (*image.RGBA, error) {
|
||||
cgImage := cgDisplayCreateImage(c.displayID)
|
||||
if cgImage == 0 {
|
||||
return nil, fmt.Errorf("CGDisplayCreateImage returned nil (screen recording permission?)")
|
||||
}
|
||||
defer cgImageRelease(cgImage)
|
||||
|
||||
w := int(cgImageGetWidth(cgImage))
|
||||
h := int(cgImageGetHeight(cgImage))
|
||||
bytesPerRow := int(cgImageGetBytesPerRow(cgImage))
|
||||
bpp := int(cgImageGetBitsPerPixel(cgImage))
|
||||
|
||||
provider := cgImageGetDataProvider(cgImage)
|
||||
if provider == 0 {
|
||||
return nil, fmt.Errorf("CGImageGetDataProvider returned nil")
|
||||
}
|
||||
|
||||
cfData := cgDataProviderCopyData(provider)
|
||||
if cfData == 0 {
|
||||
return nil, fmt.Errorf("CGDataProviderCopyData returned nil")
|
||||
}
|
||||
defer cfRelease(cfData)
|
||||
|
||||
dataLen := int(cfDataGetLength(cfData))
|
||||
dataPtr := cfDataGetBytePtr(cfData)
|
||||
if dataPtr == 0 || dataLen == 0 {
|
||||
return nil, fmt.Errorf("empty image data")
|
||||
}
|
||||
|
||||
src := unsafe.Slice((*byte)(unsafe.Pointer(dataPtr)), dataLen)
|
||||
img := image.NewRGBA(image.Rect(0, 0, w, h))
|
||||
|
||||
bytesPerPixel := bpp / 8
|
||||
for row := 0; row < h; row++ {
|
||||
srcOff := row * bytesPerRow
|
||||
dstOff := row * img.Stride
|
||||
for col := 0; col < w; col++ {
|
||||
si := srcOff + col*bytesPerPixel
|
||||
di := dstOff + col*4
|
||||
img.Pix[di+0] = src[si+2] // R (from BGRA)
|
||||
img.Pix[di+1] = src[si+1] // G
|
||||
img.Pix[di+2] = src[si+0] // B
|
||||
img.Pix[di+3] = 0xff
|
||||
}
|
||||
}
|
||||
|
||||
return img, nil
|
||||
}
|
||||
|
||||
// MacPoller wraps CGCapturer in a continuous capture loop.
|
||||
type MacPoller struct {
|
||||
mu sync.Mutex
|
||||
frame *image.RGBA
|
||||
w, h int
|
||||
done chan struct{}
|
||||
}
|
||||
|
||||
// NewMacPoller creates a capturer that continuously grabs the macOS display.
|
||||
func NewMacPoller() *MacPoller {
|
||||
p := &MacPoller{done: make(chan struct{})}
|
||||
go p.loop()
|
||||
return p
|
||||
}
|
||||
|
||||
// Close stops the capture loop.
|
||||
func (p *MacPoller) Close() {
|
||||
select {
|
||||
case <-p.done:
|
||||
default:
|
||||
close(p.done)
|
||||
}
|
||||
}
|
||||
|
||||
// Width returns the screen width.
|
||||
func (p *MacPoller) Width() int {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
return p.w
|
||||
}
|
||||
|
||||
// Height returns the screen height.
|
||||
func (p *MacPoller) Height() int {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
return p.h
|
||||
}
|
||||
|
||||
// Capture returns the most recent frame.
|
||||
func (p *MacPoller) Capture() (*image.RGBA, error) {
|
||||
p.mu.Lock()
|
||||
img := p.frame
|
||||
p.mu.Unlock()
|
||||
if img != nil {
|
||||
return img, nil
|
||||
}
|
||||
return nil, fmt.Errorf("no frame available yet")
|
||||
}
|
||||
|
||||
func (p *MacPoller) loop() {
|
||||
var capturer *CGCapturer
|
||||
var initFails int
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-p.done:
|
||||
return
|
||||
default:
|
||||
}
|
||||
|
||||
if capturer == nil {
|
||||
var err error
|
||||
capturer, err = NewCGCapturer()
|
||||
if err != nil {
|
||||
initFails++
|
||||
if initFails <= maxCapturerRetries {
|
||||
log.Debugf("macOS capturer: %v (attempt %d/%d)", err, initFails, maxCapturerRetries)
|
||||
select {
|
||||
case <-p.done:
|
||||
return
|
||||
case <-time.After(2 * time.Second):
|
||||
}
|
||||
continue
|
||||
}
|
||||
log.Warnf("macOS capturer unavailable after %d attempts, stopping poller", maxCapturerRetries)
|
||||
return
|
||||
}
|
||||
initFails = 0
|
||||
p.mu.Lock()
|
||||
p.w, p.h = capturer.Width(), capturer.Height()
|
||||
p.mu.Unlock()
|
||||
}
|
||||
|
||||
img, err := capturer.Capture()
|
||||
if err != nil {
|
||||
log.Debugf("macOS capture: %v", err)
|
||||
capturer = nil
|
||||
select {
|
||||
case <-p.done:
|
||||
return
|
||||
case <-time.After(500 * time.Millisecond):
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
p.mu.Lock()
|
||||
p.frame = img
|
||||
p.mu.Unlock()
|
||||
|
||||
select {
|
||||
case <-p.done:
|
||||
return
|
||||
case <-time.After(33 * time.Millisecond): // ~30 fps
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var _ ScreenCapturer = (*MacPoller)(nil)
|
||||
99
client/vnc/server/capture_dxgi_windows.go
Normal file
99
client/vnc/server/capture_dxgi_windows.go
Normal file
@@ -0,0 +1,99 @@
|
||||
//go:build windows
|
||||
|
||||
package server
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"image"
|
||||
|
||||
"github.com/kirides/go-d3d/d3d11"
|
||||
"github.com/kirides/go-d3d/outputduplication"
|
||||
)
|
||||
|
||||
// dxgiCapturer captures the desktop using DXGI Desktop Duplication.
|
||||
// Provides GPU-accelerated capture with native dirty rect tracking.
|
||||
// Only works from the interactive user session, not Session 0.
|
||||
//
|
||||
// Uses a double-buffer: DXGI writes into img, then we copy to the current
|
||||
// output buffer and hand it out. Alternating between two output buffers
|
||||
// avoids allocating a new image.RGBA per frame (~8MB at 1080p, 30fps).
|
||||
type dxgiCapturer struct {
|
||||
dup *outputduplication.OutputDuplicator
|
||||
device *d3d11.ID3D11Device
|
||||
ctx *d3d11.ID3D11DeviceContext
|
||||
img *image.RGBA
|
||||
out [2]*image.RGBA
|
||||
outIdx int
|
||||
width int
|
||||
height int
|
||||
}
|
||||
|
||||
func newDXGICapturer() (*dxgiCapturer, error) {
|
||||
device, deviceCtx, err := d3d11.NewD3D11Device()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create D3D11 device: %w", err)
|
||||
}
|
||||
|
||||
dup, err := outputduplication.NewIDXGIOutputDuplication(device, deviceCtx, 0)
|
||||
if err != nil {
|
||||
device.Release()
|
||||
deviceCtx.Release()
|
||||
return nil, fmt.Errorf("create output duplication: %w", err)
|
||||
}
|
||||
|
||||
w, h := screenSize()
|
||||
if w == 0 || h == 0 {
|
||||
dup.Release()
|
||||
device.Release()
|
||||
deviceCtx.Release()
|
||||
return nil, fmt.Errorf("screen dimensions are zero")
|
||||
}
|
||||
|
||||
rect := image.Rect(0, 0, w, h)
|
||||
c := &dxgiCapturer{
|
||||
dup: dup,
|
||||
device: device,
|
||||
ctx: deviceCtx,
|
||||
img: image.NewRGBA(rect),
|
||||
out: [2]*image.RGBA{image.NewRGBA(rect), image.NewRGBA(rect)},
|
||||
width: w,
|
||||
height: h,
|
||||
}
|
||||
|
||||
// Grab the initial frame with a longer timeout to ensure we have
|
||||
// a valid image before returning.
|
||||
_ = dup.GetImage(c.img, 2000)
|
||||
|
||||
return c, nil
|
||||
}
|
||||
|
||||
func (c *dxgiCapturer) capture() (*image.RGBA, error) {
|
||||
err := c.dup.GetImage(c.img, 100)
|
||||
if err != nil && !errors.Is(err, outputduplication.ErrNoImageYet) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Copy into the next output buffer. The DesktopCapturer hands out the
|
||||
// returned pointer to VNC sessions that read pixels concurrently, so we
|
||||
// alternate between two pre-allocated buffers instead of allocating per frame.
|
||||
out := c.out[c.outIdx]
|
||||
c.outIdx ^= 1
|
||||
copy(out.Pix, c.img.Pix)
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (c *dxgiCapturer) close() {
|
||||
if c.dup != nil {
|
||||
c.dup.Release()
|
||||
c.dup = nil
|
||||
}
|
||||
if c.ctx != nil {
|
||||
c.ctx.Release()
|
||||
c.ctx = nil
|
||||
}
|
||||
if c.device != nil {
|
||||
c.device.Release()
|
||||
c.device = nil
|
||||
}
|
||||
}
|
||||
461
client/vnc/server/capture_windows.go
Normal file
461
client/vnc/server/capture_windows.go
Normal file
@@ -0,0 +1,461 @@
|
||||
//go:build windows
|
||||
|
||||
package server
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"image"
|
||||
"runtime"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
"unsafe"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/sys/windows"
|
||||
)
|
||||
|
||||
var (
|
||||
gdi32 = windows.NewLazySystemDLL("gdi32.dll")
|
||||
user32 = windows.NewLazySystemDLL("user32.dll")
|
||||
|
||||
procGetDC = user32.NewProc("GetDC")
|
||||
procReleaseDC = user32.NewProc("ReleaseDC")
|
||||
procCreateCompatDC = gdi32.NewProc("CreateCompatibleDC")
|
||||
procCreateDIBSection = gdi32.NewProc("CreateDIBSection")
|
||||
procSelectObject = gdi32.NewProc("SelectObject")
|
||||
procDeleteObject = gdi32.NewProc("DeleteObject")
|
||||
procDeleteDC = gdi32.NewProc("DeleteDC")
|
||||
procBitBlt = gdi32.NewProc("BitBlt")
|
||||
procGetSystemMetrics = user32.NewProc("GetSystemMetrics")
|
||||
|
||||
// Desktop switching for service/Session 0 capture.
|
||||
procOpenInputDesktop = user32.NewProc("OpenInputDesktop")
|
||||
procSetThreadDesktop = user32.NewProc("SetThreadDesktop")
|
||||
procCloseDesktop = user32.NewProc("CloseDesktop")
|
||||
procOpenWindowStation = user32.NewProc("OpenWindowStationW")
|
||||
procSetProcessWindowStation = user32.NewProc("SetProcessWindowStation")
|
||||
procCloseWindowStation = user32.NewProc("CloseWindowStation")
|
||||
procGetUserObjectInformationW = user32.NewProc("GetUserObjectInformationW")
|
||||
)
|
||||
|
||||
const uoiName = 2
|
||||
|
||||
const (
|
||||
smCxScreen = 0
|
||||
smCyScreen = 1
|
||||
srccopy = 0x00CC0020
|
||||
dibRgbColors = 0
|
||||
)
|
||||
|
||||
type bitmapInfoHeader struct {
|
||||
Size uint32
|
||||
Width int32
|
||||
Height int32
|
||||
Planes uint16
|
||||
BitCount uint16
|
||||
Compression uint32
|
||||
SizeImage uint32
|
||||
XPelsPerMeter int32
|
||||
YPelsPerMeter int32
|
||||
ClrUsed uint32
|
||||
ClrImportant uint32
|
||||
}
|
||||
|
||||
type bitmapInfo struct {
|
||||
Header bitmapInfoHeader
|
||||
}
|
||||
|
||||
// setupInteractiveWindowStation associates the current process with WinSta0,
|
||||
// the interactive window station. This is required for a SYSTEM service in
|
||||
// Session 0 to call OpenInputDesktop for screen capture and input injection.
|
||||
func setupInteractiveWindowStation() error {
|
||||
name, err := windows.UTF16PtrFromString("WinSta0")
|
||||
if err != nil {
|
||||
return fmt.Errorf("UTF16 WinSta0: %w", err)
|
||||
}
|
||||
hWinSta, _, err := procOpenWindowStation.Call(
|
||||
uintptr(unsafe.Pointer(name)),
|
||||
0,
|
||||
uintptr(windows.MAXIMUM_ALLOWED),
|
||||
)
|
||||
if hWinSta == 0 {
|
||||
return fmt.Errorf("OpenWindowStation(WinSta0): %w", err)
|
||||
}
|
||||
r, _, err := procSetProcessWindowStation.Call(hWinSta)
|
||||
if r == 0 {
|
||||
procCloseWindowStation.Call(hWinSta)
|
||||
return fmt.Errorf("SetProcessWindowStation: %w", err)
|
||||
}
|
||||
log.Info("process window station set to WinSta0 (interactive)")
|
||||
return nil
|
||||
}
|
||||
|
||||
func screenSize() (int, int) {
|
||||
w, _, _ := procGetSystemMetrics.Call(uintptr(smCxScreen))
|
||||
h, _, _ := procGetSystemMetrics.Call(uintptr(smCyScreen))
|
||||
return int(w), int(h)
|
||||
}
|
||||
|
||||
func getDesktopName(hDesk uintptr) string {
|
||||
var buf [256]uint16
|
||||
var needed uint32
|
||||
procGetUserObjectInformationW.Call(hDesk, uoiName,
|
||||
uintptr(unsafe.Pointer(&buf[0])), 512,
|
||||
uintptr(unsafe.Pointer(&needed)))
|
||||
return windows.UTF16ToString(buf[:])
|
||||
}
|
||||
|
||||
// switchToInputDesktop opens the desktop currently receiving user input
|
||||
// and sets it as the calling OS thread's desktop. Must be called from a
|
||||
// goroutine locked to its OS thread via runtime.LockOSThread().
|
||||
func switchToInputDesktop() (bool, string) {
|
||||
hDesk, _, _ := procOpenInputDesktop.Call(0, 0, uintptr(windows.MAXIMUM_ALLOWED))
|
||||
if hDesk == 0 {
|
||||
return false, ""
|
||||
}
|
||||
name := getDesktopName(hDesk)
|
||||
ret, _, _ := procSetThreadDesktop.Call(hDesk)
|
||||
procCloseDesktop.Call(hDesk)
|
||||
return ret != 0, name
|
||||
}
|
||||
|
||||
// gdiCapturer captures the desktop screen using GDI BitBlt.
|
||||
// GDI objects (DC, DIBSection) are allocated once and reused across frames.
|
||||
type gdiCapturer struct {
|
||||
mu sync.Mutex
|
||||
width int
|
||||
height int
|
||||
|
||||
// Pre-allocated GDI resources, reused across captures.
|
||||
memDC uintptr
|
||||
bmp uintptr
|
||||
bits uintptr
|
||||
}
|
||||
|
||||
func newGDICapturer() (*gdiCapturer, error) {
|
||||
w, h := screenSize()
|
||||
if w == 0 || h == 0 {
|
||||
return nil, fmt.Errorf("screen dimensions are zero")
|
||||
}
|
||||
c := &gdiCapturer{width: w, height: h}
|
||||
if err := c.allocGDI(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return c, nil
|
||||
}
|
||||
|
||||
// allocGDI pre-allocates the compatible DC and DIB section for reuse.
|
||||
func (c *gdiCapturer) allocGDI() error {
|
||||
screenDC, _, _ := procGetDC.Call(0)
|
||||
if screenDC == 0 {
|
||||
return fmt.Errorf("GetDC returned 0")
|
||||
}
|
||||
defer procReleaseDC.Call(0, screenDC)
|
||||
|
||||
memDC, _, _ := procCreateCompatDC.Call(screenDC)
|
||||
if memDC == 0 {
|
||||
return fmt.Errorf("CreateCompatibleDC returned 0")
|
||||
}
|
||||
|
||||
bi := bitmapInfo{
|
||||
Header: bitmapInfoHeader{
|
||||
Size: uint32(unsafe.Sizeof(bitmapInfoHeader{})),
|
||||
Width: int32(c.width),
|
||||
Height: -int32(c.height), // negative = top-down DIB
|
||||
Planes: 1,
|
||||
BitCount: 32,
|
||||
},
|
||||
}
|
||||
|
||||
var bits uintptr
|
||||
bmp, _, _ := procCreateDIBSection.Call(
|
||||
screenDC,
|
||||
uintptr(unsafe.Pointer(&bi)),
|
||||
dibRgbColors,
|
||||
uintptr(unsafe.Pointer(&bits)),
|
||||
0, 0,
|
||||
)
|
||||
if bmp == 0 || bits == 0 {
|
||||
procDeleteDC.Call(memDC)
|
||||
return fmt.Errorf("CreateDIBSection returned 0")
|
||||
}
|
||||
|
||||
procSelectObject.Call(memDC, bmp)
|
||||
|
||||
c.memDC = memDC
|
||||
c.bmp = bmp
|
||||
c.bits = bits
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *gdiCapturer) close() { c.freeGDI() }
|
||||
|
||||
// freeGDI releases pre-allocated GDI resources.
|
||||
func (c *gdiCapturer) freeGDI() {
|
||||
if c.bmp != 0 {
|
||||
procDeleteObject.Call(c.bmp)
|
||||
c.bmp = 0
|
||||
}
|
||||
if c.memDC != 0 {
|
||||
procDeleteDC.Call(c.memDC)
|
||||
c.memDC = 0
|
||||
}
|
||||
c.bits = 0
|
||||
}
|
||||
|
||||
func (c *gdiCapturer) capture() (*image.RGBA, error) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
if c.memDC == 0 {
|
||||
return nil, fmt.Errorf("GDI resources not allocated")
|
||||
}
|
||||
|
||||
screenDC, _, _ := procGetDC.Call(0)
|
||||
if screenDC == 0 {
|
||||
return nil, fmt.Errorf("GetDC returned 0")
|
||||
}
|
||||
defer procReleaseDC.Call(0, screenDC)
|
||||
|
||||
ret, _, _ := procBitBlt.Call(c.memDC, 0, 0, uintptr(c.width), uintptr(c.height),
|
||||
screenDC, 0, 0, srccopy)
|
||||
if ret == 0 {
|
||||
return nil, fmt.Errorf("BitBlt returned 0")
|
||||
}
|
||||
|
||||
n := c.width * c.height * 4
|
||||
raw := unsafe.Slice((*byte)(unsafe.Pointer(c.bits)), n)
|
||||
|
||||
// GDI gives BGRA, the RFB encoder expects RGBA (img.Pix layout).
|
||||
// Swap R and B in bulk using uint32 operations (one load + mask + shift
|
||||
// per pixel instead of three separate byte assignments).
|
||||
img := image.NewRGBA(image.Rect(0, 0, c.width, c.height))
|
||||
pix := img.Pix
|
||||
copy(pix, raw)
|
||||
swizzleBGRAtoRGBA(pix)
|
||||
return img, nil
|
||||
}
|
||||
|
||||
// DesktopCapturer captures the interactive desktop, handling desktop transitions
|
||||
// (login screen, UAC prompts). A dedicated OS-locked goroutine continuously
|
||||
// captures frames, which are retrieved by the VNC session on demand.
|
||||
// Capture pauses automatically when no clients are connected.
|
||||
type DesktopCapturer struct {
|
||||
mu sync.Mutex
|
||||
frame *image.RGBA
|
||||
w, h int
|
||||
|
||||
// clients tracks the number of active VNC sessions. When zero, the
|
||||
// capture loop idles instead of grabbing frames.
|
||||
clients atomic.Int32
|
||||
|
||||
// wake is signaled when a client connects and the loop should resume.
|
||||
wake chan struct{}
|
||||
// done is closed when Close is called, terminating the capture loop.
|
||||
done chan struct{}
|
||||
}
|
||||
|
||||
// NewDesktopCapturer creates a capturer that continuously grabs the active desktop.
|
||||
func NewDesktopCapturer() *DesktopCapturer {
|
||||
c := &DesktopCapturer{
|
||||
wake: make(chan struct{}, 1),
|
||||
done: make(chan struct{}),
|
||||
}
|
||||
go c.loop()
|
||||
return c
|
||||
}
|
||||
|
||||
// ClientConnect increments the active client count, resuming capture if needed.
|
||||
func (c *DesktopCapturer) ClientConnect() {
|
||||
c.clients.Add(1)
|
||||
select {
|
||||
case c.wake <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
// ClientDisconnect decrements the active client count.
|
||||
func (c *DesktopCapturer) ClientDisconnect() {
|
||||
c.clients.Add(-1)
|
||||
}
|
||||
|
||||
// Close stops the capture loop and releases resources.
|
||||
func (c *DesktopCapturer) Close() {
|
||||
select {
|
||||
case <-c.done:
|
||||
default:
|
||||
close(c.done)
|
||||
}
|
||||
}
|
||||
|
||||
// Width returns the current screen width.
|
||||
func (c *DesktopCapturer) Width() int {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
return c.w
|
||||
}
|
||||
|
||||
// Height returns the current screen height.
|
||||
func (c *DesktopCapturer) Height() int {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
return c.h
|
||||
}
|
||||
|
||||
// Capture returns the most recent desktop frame.
|
||||
func (c *DesktopCapturer) Capture() (*image.RGBA, error) {
|
||||
c.mu.Lock()
|
||||
img := c.frame
|
||||
c.mu.Unlock()
|
||||
if img != nil {
|
||||
return img, nil
|
||||
}
|
||||
return nil, fmt.Errorf("no frame available yet")
|
||||
}
|
||||
|
||||
// waitForClient blocks until a client connects or the capturer is closed.
|
||||
func (c *DesktopCapturer) waitForClient() bool {
|
||||
if c.clients.Load() > 0 {
|
||||
return true
|
||||
}
|
||||
select {
|
||||
case <-c.wake:
|
||||
return true
|
||||
case <-c.done:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func (c *DesktopCapturer) loop() {
|
||||
runtime.LockOSThread()
|
||||
|
||||
// When running as a Windows service (Session 0), we need to attach to the
|
||||
// interactive window station before OpenInputDesktop will succeed.
|
||||
if err := setupInteractiveWindowStation(); err != nil {
|
||||
log.Warnf("attach to interactive window station: %v", err)
|
||||
}
|
||||
|
||||
frameTicker := time.NewTicker(33 * time.Millisecond) // ~30 fps
|
||||
defer frameTicker.Stop()
|
||||
|
||||
retryTimer := time.NewTimer(0)
|
||||
retryTimer.Stop()
|
||||
defer retryTimer.Stop()
|
||||
|
||||
type frameCapturer interface {
|
||||
capture() (*image.RGBA, error)
|
||||
close()
|
||||
}
|
||||
|
||||
var cap frameCapturer
|
||||
var desktopFails int
|
||||
var lastDesktop string
|
||||
|
||||
createCapturer := func() (frameCapturer, error) {
|
||||
dc, err := newDXGICapturer()
|
||||
if err == nil {
|
||||
log.Info("using DXGI Desktop Duplication for capture")
|
||||
return dc, nil
|
||||
}
|
||||
log.Debugf("DXGI unavailable (%v), falling back to GDI", err)
|
||||
gc, err := newGDICapturer()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
log.Info("using GDI BitBlt for capture")
|
||||
return gc, nil
|
||||
}
|
||||
|
||||
for {
|
||||
if !c.waitForClient() {
|
||||
if cap != nil {
|
||||
cap.close()
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// No clients: release the capturer and wait.
|
||||
if c.clients.Load() <= 0 {
|
||||
if cap != nil {
|
||||
cap.close()
|
||||
cap = nil
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
ok, desk := switchToInputDesktop()
|
||||
if !ok {
|
||||
desktopFails++
|
||||
if desktopFails == 1 || desktopFails%100 == 0 {
|
||||
log.Warnf("switchToInputDesktop failed (count=%d), no interactive desktop session?", desktopFails)
|
||||
}
|
||||
retryTimer.Reset(100 * time.Millisecond)
|
||||
select {
|
||||
case <-retryTimer.C:
|
||||
case <-c.done:
|
||||
return
|
||||
}
|
||||
continue
|
||||
}
|
||||
if desktopFails > 0 {
|
||||
log.Infof("switchToInputDesktop recovered after %d failures, desktop=%q", desktopFails, desk)
|
||||
desktopFails = 0
|
||||
}
|
||||
if desk != lastDesktop {
|
||||
log.Infof("desktop changed: %q -> %q", lastDesktop, desk)
|
||||
lastDesktop = desk
|
||||
if cap != nil {
|
||||
cap.close()
|
||||
}
|
||||
cap = nil
|
||||
}
|
||||
|
||||
if cap == nil {
|
||||
fc, err := createCapturer()
|
||||
if err != nil {
|
||||
log.Warnf("create capturer: %v", err)
|
||||
retryTimer.Reset(500 * time.Millisecond)
|
||||
select {
|
||||
case <-retryTimer.C:
|
||||
case <-c.done:
|
||||
return
|
||||
}
|
||||
continue
|
||||
}
|
||||
cap = fc
|
||||
w, h := screenSize()
|
||||
c.mu.Lock()
|
||||
c.w, c.h = w, h
|
||||
c.mu.Unlock()
|
||||
log.Infof("screen capturer ready: %dx%d", w, h)
|
||||
}
|
||||
|
||||
img, err := cap.capture()
|
||||
if err != nil {
|
||||
log.Debugf("capture: %v", err)
|
||||
cap.close()
|
||||
cap = nil
|
||||
retryTimer.Reset(100 * time.Millisecond)
|
||||
select {
|
||||
case <-retryTimer.C:
|
||||
case <-c.done:
|
||||
return
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
c.mu.Lock()
|
||||
c.frame = img
|
||||
c.mu.Unlock()
|
||||
|
||||
select {
|
||||
case <-frameTicker.C:
|
||||
case <-c.done:
|
||||
if cap != nil {
|
||||
cap.close()
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
385
client/vnc/server/capture_x11.go
Normal file
385
client/vnc/server/capture_x11.go
Normal file
@@ -0,0 +1,385 @@
|
||||
//go:build (linux && !android) || freebsd
|
||||
|
||||
package server
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"image"
|
||||
"os"
|
||||
"os/exec"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/jezek/xgb"
|
||||
"github.com/jezek/xgb/xproto"
|
||||
)
|
||||
|
||||
// X11Capturer captures the screen from an X11 display using the MIT-SHM extension.
|
||||
type X11Capturer struct {
|
||||
mu sync.Mutex
|
||||
conn *xgb.Conn
|
||||
screen *xproto.ScreenInfo
|
||||
w, h int
|
||||
shmID int
|
||||
shmAddr []byte
|
||||
shmSeg uint32 // shm.Seg
|
||||
useSHM bool
|
||||
}
|
||||
|
||||
// detectX11Display finds the active X11 display and sets DISPLAY/XAUTHORITY
|
||||
// environment variables if needed. This is required when running as a system
|
||||
// service where these vars aren't set.
|
||||
func detectX11Display() {
|
||||
if os.Getenv("DISPLAY") != "" {
|
||||
return
|
||||
}
|
||||
|
||||
// Try /proc first (Linux), then ps fallback (FreeBSD and others).
|
||||
if detectX11FromProc() {
|
||||
return
|
||||
}
|
||||
if detectX11FromSockets() {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// detectX11FromProc scans /proc/*/cmdline for Xorg (Linux).
|
||||
func detectX11FromProc() bool {
|
||||
entries, err := os.ReadDir("/proc")
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
for _, e := range entries {
|
||||
if !e.IsDir() {
|
||||
continue
|
||||
}
|
||||
cmdline, err := os.ReadFile("/proc/" + e.Name() + "/cmdline")
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
if display, auth := parseXorgArgs(splitCmdline(cmdline)); display != "" {
|
||||
setDisplayEnv(display, auth)
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// detectX11FromSockets checks /tmp/.X11-unix/ for X sockets and uses ps
|
||||
// to find the auth file. Works on FreeBSD and other systems without /proc.
|
||||
func detectX11FromSockets() bool {
|
||||
entries, err := os.ReadDir("/tmp/.X11-unix")
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
// Find the lowest display number.
|
||||
for _, e := range entries {
|
||||
name := e.Name()
|
||||
if len(name) < 2 || name[0] != 'X' {
|
||||
continue
|
||||
}
|
||||
display := ":" + name[1:]
|
||||
os.Setenv("DISPLAY", display)
|
||||
log.Infof("auto-detected DISPLAY=%s (from socket)", display)
|
||||
|
||||
// Try to find -auth from ps output.
|
||||
if auth := findXorgAuthFromPS(); auth != "" {
|
||||
os.Setenv("XAUTHORITY", auth)
|
||||
log.Infof("auto-detected XAUTHORITY=%s (from ps)", auth)
|
||||
}
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// findXorgAuthFromPS runs ps to find Xorg and extract its -auth argument.
|
||||
func findXorgAuthFromPS() string {
|
||||
out, err := exec.Command("ps", "auxww").Output()
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
for _, line := range strings.Split(string(out), "\n") {
|
||||
if !strings.Contains(line, "Xorg") && !strings.Contains(line, "/X ") {
|
||||
continue
|
||||
}
|
||||
fields := strings.Fields(line)
|
||||
for i, f := range fields {
|
||||
if f == "-auth" && i+1 < len(fields) {
|
||||
return fields[i+1]
|
||||
}
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func parseXorgArgs(args []string) (display, auth string) {
|
||||
if len(args) == 0 {
|
||||
return "", ""
|
||||
}
|
||||
base := args[0]
|
||||
if !(base == "Xorg" || base == "X" || len(base) > 0 && base[len(base)-1] == 'X' ||
|
||||
strings.Contains(base, "/Xorg") || strings.Contains(base, "/X")) {
|
||||
return "", ""
|
||||
}
|
||||
for i, arg := range args[1:] {
|
||||
if len(arg) > 0 && arg[0] == ':' {
|
||||
display = arg
|
||||
}
|
||||
if arg == "-auth" && i+2 < len(args) {
|
||||
auth = args[i+2]
|
||||
}
|
||||
}
|
||||
return display, auth
|
||||
}
|
||||
|
||||
func setDisplayEnv(display, auth string) {
|
||||
os.Setenv("DISPLAY", display)
|
||||
log.Infof("auto-detected DISPLAY=%s", display)
|
||||
if auth != "" {
|
||||
os.Setenv("XAUTHORITY", auth)
|
||||
log.Infof("auto-detected XAUTHORITY=%s", auth)
|
||||
}
|
||||
}
|
||||
|
||||
func splitCmdline(data []byte) []string {
|
||||
var args []string
|
||||
for _, b := range splitNull(data) {
|
||||
if len(b) > 0 {
|
||||
args = append(args, string(b))
|
||||
}
|
||||
}
|
||||
return args
|
||||
}
|
||||
|
||||
func splitNull(data []byte) [][]byte {
|
||||
var parts [][]byte
|
||||
start := 0
|
||||
for i, b := range data {
|
||||
if b == 0 {
|
||||
parts = append(parts, data[start:i])
|
||||
start = i + 1
|
||||
}
|
||||
}
|
||||
if start < len(data) {
|
||||
parts = append(parts, data[start:])
|
||||
}
|
||||
return parts
|
||||
}
|
||||
|
||||
// NewX11Capturer connects to the X11 display and sets up shared memory capture.
|
||||
func NewX11Capturer(display string) (*X11Capturer, error) {
|
||||
detectX11Display()
|
||||
|
||||
if display == "" {
|
||||
display = os.Getenv("DISPLAY")
|
||||
}
|
||||
if display == "" {
|
||||
return nil, fmt.Errorf("DISPLAY not set and no Xorg process found")
|
||||
}
|
||||
|
||||
conn, err := xgb.NewConnDisplay(display)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("connect to X11 display %s: %w", display, err)
|
||||
}
|
||||
|
||||
setup := xproto.Setup(conn)
|
||||
if len(setup.Roots) == 0 {
|
||||
conn.Close()
|
||||
return nil, fmt.Errorf("no X11 screens")
|
||||
}
|
||||
screen := setup.Roots[0]
|
||||
|
||||
c := &X11Capturer{
|
||||
conn: conn,
|
||||
screen: &screen,
|
||||
w: int(screen.WidthInPixels),
|
||||
h: int(screen.HeightInPixels),
|
||||
}
|
||||
|
||||
if err := c.initSHM(); err != nil {
|
||||
log.Debugf("X11 SHM not available, using slow GetImage: %v", err)
|
||||
}
|
||||
|
||||
log.Infof("X11 capturer ready: %dx%d (display=%s, shm=%v)", c.w, c.h, display, c.useSHM)
|
||||
return c, nil
|
||||
}
|
||||
|
||||
// initSHM is implemented in capture_x11_shm_linux.go (requires SysV SHM).
|
||||
// On platforms without SysV SHM (FreeBSD), a stub returns an error and
|
||||
// the capturer falls back to GetImage.
|
||||
|
||||
// Width returns the screen width.
|
||||
func (c *X11Capturer) Width() int { return c.w }
|
||||
|
||||
// Height returns the screen height.
|
||||
func (c *X11Capturer) Height() int { return c.h }
|
||||
|
||||
// Capture returns the current screen as an RGBA image.
|
||||
func (c *X11Capturer) Capture() (*image.RGBA, error) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
if c.useSHM {
|
||||
return c.captureSHM()
|
||||
}
|
||||
return c.captureGetImage()
|
||||
}
|
||||
|
||||
// captureSHM is implemented in capture_x11_shm_linux.go.
|
||||
|
||||
func (c *X11Capturer) captureGetImage() (*image.RGBA, error) {
|
||||
cookie := xproto.GetImage(c.conn, xproto.ImageFormatZPixmap,
|
||||
xproto.Drawable(c.screen.Root),
|
||||
0, 0, uint16(c.w), uint16(c.h), 0xFFFFFFFF)
|
||||
|
||||
reply, err := cookie.Reply()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("GetImage: %w", err)
|
||||
}
|
||||
|
||||
img := image.NewRGBA(image.Rect(0, 0, c.w, c.h))
|
||||
data := reply.Data
|
||||
n := c.w * c.h * 4
|
||||
if len(data) < n {
|
||||
return nil, fmt.Errorf("GetImage returned %d bytes, expected %d", len(data), n)
|
||||
}
|
||||
|
||||
for i := 0; i < n; i += 4 {
|
||||
img.Pix[i+0] = data[i+2] // R
|
||||
img.Pix[i+1] = data[i+1] // G
|
||||
img.Pix[i+2] = data[i+0] // B
|
||||
img.Pix[i+3] = 0xff
|
||||
}
|
||||
return img, nil
|
||||
}
|
||||
|
||||
// Close releases X11 resources.
|
||||
func (c *X11Capturer) Close() {
|
||||
c.closeSHM()
|
||||
c.conn.Close()
|
||||
}
|
||||
|
||||
// closeSHM is implemented in capture_x11_shm_linux.go.
|
||||
|
||||
// X11Poller wraps X11Capturer in a continuous capture loop, matching the
|
||||
// DesktopCapturer pattern from Windows.
|
||||
type X11Poller struct {
|
||||
mu sync.Mutex
|
||||
frame *image.RGBA
|
||||
w, h int
|
||||
display string
|
||||
done chan struct{}
|
||||
}
|
||||
|
||||
// NewX11Poller creates a capturer that continuously grabs the X11 display.
|
||||
func NewX11Poller(display string) *X11Poller {
|
||||
p := &X11Poller{
|
||||
display: display,
|
||||
done: make(chan struct{}),
|
||||
}
|
||||
go p.loop()
|
||||
return p
|
||||
}
|
||||
|
||||
// Close stops the capture loop.
|
||||
func (p *X11Poller) Close() {
|
||||
select {
|
||||
case <-p.done:
|
||||
default:
|
||||
close(p.done)
|
||||
}
|
||||
}
|
||||
|
||||
// Width returns the screen width.
|
||||
func (p *X11Poller) Width() int {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
return p.w
|
||||
}
|
||||
|
||||
// Height returns the screen height.
|
||||
func (p *X11Poller) Height() int {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
return p.h
|
||||
}
|
||||
|
||||
// Capture returns the most recent frame.
|
||||
func (p *X11Poller) Capture() (*image.RGBA, error) {
|
||||
p.mu.Lock()
|
||||
img := p.frame
|
||||
p.mu.Unlock()
|
||||
if img != nil {
|
||||
return img, nil
|
||||
}
|
||||
return nil, fmt.Errorf("no frame available yet")
|
||||
}
|
||||
|
||||
func (p *X11Poller) loop() {
|
||||
var capturer *X11Capturer
|
||||
var initFails int
|
||||
|
||||
defer func() {
|
||||
if capturer != nil {
|
||||
capturer.Close()
|
||||
}
|
||||
}()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-p.done:
|
||||
return
|
||||
default:
|
||||
}
|
||||
|
||||
if capturer == nil {
|
||||
var err error
|
||||
capturer, err = NewX11Capturer(p.display)
|
||||
if err != nil {
|
||||
initFails++
|
||||
if initFails <= maxCapturerRetries {
|
||||
log.Debugf("X11 capturer: %v (attempt %d/%d)", err, initFails, maxCapturerRetries)
|
||||
select {
|
||||
case <-p.done:
|
||||
return
|
||||
case <-time.After(2 * time.Second):
|
||||
}
|
||||
continue
|
||||
}
|
||||
log.Warnf("X11 capturer unavailable after %d attempts, stopping poller", maxCapturerRetries)
|
||||
return
|
||||
}
|
||||
initFails = 0
|
||||
p.mu.Lock()
|
||||
p.w, p.h = capturer.Width(), capturer.Height()
|
||||
p.mu.Unlock()
|
||||
}
|
||||
|
||||
img, err := capturer.Capture()
|
||||
if err != nil {
|
||||
log.Debugf("X11 capture: %v", err)
|
||||
capturer.Close()
|
||||
capturer = nil
|
||||
select {
|
||||
case <-p.done:
|
||||
return
|
||||
case <-time.After(500 * time.Millisecond):
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
p.mu.Lock()
|
||||
p.frame = img
|
||||
p.mu.Unlock()
|
||||
|
||||
select {
|
||||
case <-p.done:
|
||||
return
|
||||
case <-time.After(33 * time.Millisecond): // ~30 fps
|
||||
}
|
||||
}
|
||||
}
|
||||
78
client/vnc/server/capture_x11_shm_linux.go
Normal file
78
client/vnc/server/capture_x11_shm_linux.go
Normal file
@@ -0,0 +1,78 @@
|
||||
//go:build linux && !android
|
||||
|
||||
package server
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"image"
|
||||
|
||||
"github.com/jezek/xgb/shm"
|
||||
"github.com/jezek/xgb/xproto"
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
func (c *X11Capturer) initSHM() error {
|
||||
if err := shm.Init(c.conn); err != nil {
|
||||
return fmt.Errorf("init SHM extension: %w", err)
|
||||
}
|
||||
|
||||
size := c.w * c.h * 4
|
||||
id, err := unix.SysvShmGet(unix.IPC_PRIVATE, size, unix.IPC_CREAT|0600)
|
||||
if err != nil {
|
||||
return fmt.Errorf("shmget: %w", err)
|
||||
}
|
||||
|
||||
addr, err := unix.SysvShmAttach(id, 0, 0)
|
||||
if err != nil {
|
||||
unix.SysvShmCtl(id, unix.IPC_RMID, nil)
|
||||
return fmt.Errorf("shmat: %w", err)
|
||||
}
|
||||
|
||||
unix.SysvShmCtl(id, unix.IPC_RMID, nil)
|
||||
|
||||
seg, err := shm.NewSegId(c.conn)
|
||||
if err != nil {
|
||||
unix.SysvShmDetach(addr)
|
||||
return fmt.Errorf("new SHM seg: %w", err)
|
||||
}
|
||||
|
||||
if err := shm.AttachChecked(c.conn, seg, uint32(id), false).Check(); err != nil {
|
||||
unix.SysvShmDetach(addr)
|
||||
return fmt.Errorf("SHM attach to X: %w", err)
|
||||
}
|
||||
|
||||
c.shmID = id
|
||||
c.shmAddr = addr
|
||||
c.shmSeg = uint32(seg)
|
||||
c.useSHM = true
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *X11Capturer) captureSHM() (*image.RGBA, error) {
|
||||
cookie := shm.GetImage(c.conn, xproto.Drawable(c.screen.Root),
|
||||
0, 0, uint16(c.w), uint16(c.h), 0xFFFFFFFF,
|
||||
xproto.ImageFormatZPixmap, shm.Seg(c.shmSeg), 0)
|
||||
|
||||
_, err := cookie.Reply()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("SHM GetImage: %w", err)
|
||||
}
|
||||
|
||||
img := image.NewRGBA(image.Rect(0, 0, c.w, c.h))
|
||||
n := c.w * c.h * 4
|
||||
|
||||
for i := 0; i < n; i += 4 {
|
||||
img.Pix[i+0] = c.shmAddr[i+2] // R
|
||||
img.Pix[i+1] = c.shmAddr[i+1] // G
|
||||
img.Pix[i+2] = c.shmAddr[i+0] // B
|
||||
img.Pix[i+3] = 0xff
|
||||
}
|
||||
return img, nil
|
||||
}
|
||||
|
||||
func (c *X11Capturer) closeSHM() {
|
||||
if c.useSHM {
|
||||
shm.Detach(c.conn, shm.Seg(c.shmSeg))
|
||||
unix.SysvShmDetach(c.shmAddr)
|
||||
}
|
||||
}
|
||||
18
client/vnc/server/capture_x11_shm_stub.go
Normal file
18
client/vnc/server/capture_x11_shm_stub.go
Normal file
@@ -0,0 +1,18 @@
|
||||
//go:build freebsd
|
||||
|
||||
package server
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"image"
|
||||
)
|
||||
|
||||
func (c *X11Capturer) initSHM() error {
|
||||
return fmt.Errorf("SysV SHM not available on this platform")
|
||||
}
|
||||
|
||||
func (c *X11Capturer) captureSHM() (*image.RGBA, error) {
|
||||
return nil, fmt.Errorf("SHM capture not available on this platform")
|
||||
}
|
||||
|
||||
func (c *X11Capturer) closeSHM() {}
|
||||
403
client/vnc/server/input_darwin.go
Normal file
403
client/vnc/server/input_darwin.go
Normal file
@@ -0,0 +1,403 @@
|
||||
//go:build darwin && !ios
|
||||
|
||||
package server
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os/exec"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/ebitengine/purego"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// Core Graphics event constants.
|
||||
const (
|
||||
kCGEventSourceStateCombinedSessionState int32 = 0
|
||||
|
||||
kCGEventLeftMouseDown int32 = 1
|
||||
kCGEventLeftMouseUp int32 = 2
|
||||
kCGEventRightMouseDown int32 = 3
|
||||
kCGEventRightMouseUp int32 = 4
|
||||
kCGEventMouseMoved int32 = 5
|
||||
kCGEventLeftMouseDragged int32 = 6
|
||||
kCGEventRightMouseDragged int32 = 7
|
||||
kCGEventKeyDown int32 = 10
|
||||
kCGEventKeyUp int32 = 11
|
||||
kCGEventOtherMouseDown int32 = 25
|
||||
kCGEventOtherMouseUp int32 = 26
|
||||
|
||||
kCGMouseButtonLeft int32 = 0
|
||||
kCGMouseButtonRight int32 = 1
|
||||
kCGMouseButtonCenter int32 = 2
|
||||
|
||||
kCGHIDEventTap int32 = 0
|
||||
)
|
||||
|
||||
var darwinInputOnce sync.Once
|
||||
|
||||
var (
|
||||
cgEventSourceCreate func(int32) uintptr
|
||||
cgEventCreateKeyboardEvent func(uintptr, uint16, bool) uintptr
|
||||
// CGEventCreateMouseEvent takes CGPoint as two separate float64 args.
|
||||
// purego can't handle array/struct types but individual float64s work.
|
||||
cgEventCreateMouseEvent func(uintptr, int32, float64, float64, int32) uintptr
|
||||
cgEventPost func(int32, uintptr)
|
||||
|
||||
// CGEventCreateScrollWheelEvent is variadic, call via SyscallN.
|
||||
cgEventCreateScrollWheelEventAddr uintptr
|
||||
|
||||
darwinInputReady bool
|
||||
darwinEventSource uintptr
|
||||
)
|
||||
|
||||
func initDarwinInput() {
|
||||
darwinInputOnce.Do(func() {
|
||||
cg, err := purego.Dlopen("/System/Library/Frameworks/CoreGraphics.framework/CoreGraphics", purego.RTLD_NOW|purego.RTLD_GLOBAL)
|
||||
if err != nil {
|
||||
log.Debugf("load CoreGraphics for input: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
purego.RegisterLibFunc(&cgEventSourceCreate, cg, "CGEventSourceCreate")
|
||||
purego.RegisterLibFunc(&cgEventCreateKeyboardEvent, cg, "CGEventCreateKeyboardEvent")
|
||||
purego.RegisterLibFunc(&cgEventCreateMouseEvent, cg, "CGEventCreateMouseEvent")
|
||||
purego.RegisterLibFunc(&cgEventPost, cg, "CGEventPost")
|
||||
|
||||
sym, err := purego.Dlsym(cg, "CGEventCreateScrollWheelEvent")
|
||||
if err == nil {
|
||||
cgEventCreateScrollWheelEventAddr = sym
|
||||
}
|
||||
|
||||
darwinInputReady = true
|
||||
})
|
||||
}
|
||||
|
||||
func ensureEventSource() uintptr {
|
||||
if darwinEventSource != 0 {
|
||||
return darwinEventSource
|
||||
}
|
||||
darwinEventSource = cgEventSourceCreate(kCGEventSourceStateCombinedSessionState)
|
||||
return darwinEventSource
|
||||
}
|
||||
|
||||
// MacInputInjector injects keyboard and mouse events via Core Graphics.
|
||||
type MacInputInjector struct {
|
||||
lastButtons uint8
|
||||
pbcopyPath string
|
||||
pbpastePath string
|
||||
}
|
||||
|
||||
// NewMacInputInjector creates a macOS input injector.
|
||||
func NewMacInputInjector() (*MacInputInjector, error) {
|
||||
initDarwinInput()
|
||||
if !darwinInputReady {
|
||||
return nil, fmt.Errorf("CoreGraphics not available for input injection")
|
||||
}
|
||||
checkMacPermissions()
|
||||
|
||||
m := &MacInputInjector{}
|
||||
if path, err := exec.LookPath("pbcopy"); err == nil {
|
||||
m.pbcopyPath = path
|
||||
}
|
||||
if path, err := exec.LookPath("pbpaste"); err == nil {
|
||||
m.pbpastePath = path
|
||||
}
|
||||
if m.pbcopyPath == "" || m.pbpastePath == "" {
|
||||
log.Debugf("clipboard tools not found (pbcopy=%q, pbpaste=%q)", m.pbcopyPath, m.pbpastePath)
|
||||
}
|
||||
|
||||
log.Info("macOS input injector ready")
|
||||
return m, nil
|
||||
}
|
||||
|
||||
// checkMacPermissions logs warnings and triggers the Accessibility prompt.
|
||||
// Screen Recording has no programmatic prompt, the user must grant it manually.
|
||||
func checkMacPermissions() {
|
||||
// Check Accessibility via osascript (triggers the system prompt dialog).
|
||||
out, err := exec.Command("osascript", "-e",
|
||||
`tell application "System Events" to return name of first process`).CombinedOutput()
|
||||
if err != nil {
|
||||
log.Warn("Accessibility permission not granted. Input injection will not work. " +
|
||||
"Grant in System Settings > Privacy & Security > Accessibility.")
|
||||
log.Debugf("accessibility check output: %s (%v)", strings.TrimSpace(string(out)), err)
|
||||
}
|
||||
|
||||
log.Info("Screen Recording permission is required for screen capture. " +
|
||||
"If the screen appears black, grant in System Settings > Privacy & Security > Screen Recording.")
|
||||
}
|
||||
|
||||
// InjectKey simulates a key press or release.
|
||||
func (m *MacInputInjector) InjectKey(keysym uint32, down bool) {
|
||||
src := ensureEventSource()
|
||||
if src == 0 {
|
||||
return
|
||||
}
|
||||
keycode := keysymToMacKeycode(keysym)
|
||||
if keycode == 0xFFFF {
|
||||
return
|
||||
}
|
||||
event := cgEventCreateKeyboardEvent(src, keycode, down)
|
||||
if event == 0 {
|
||||
return
|
||||
}
|
||||
cgEventPost(kCGHIDEventTap, event)
|
||||
cfRelease(event)
|
||||
}
|
||||
|
||||
// InjectPointer simulates mouse movement and button events.
|
||||
func (m *MacInputInjector) InjectPointer(buttonMask uint8, px, py, serverW, serverH int) {
|
||||
if serverW == 0 || serverH == 0 {
|
||||
return
|
||||
}
|
||||
src := ensureEventSource()
|
||||
if src == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
x := float64(px)
|
||||
y := float64(py)
|
||||
leftDown := buttonMask&0x01 != 0
|
||||
rightDown := buttonMask&0x04 != 0
|
||||
middleDown := buttonMask&0x02 != 0
|
||||
scrollUp := buttonMask&0x08 != 0
|
||||
scrollDown := buttonMask&0x10 != 0
|
||||
|
||||
wasLeft := m.lastButtons&0x01 != 0
|
||||
wasRight := m.lastButtons&0x04 != 0
|
||||
wasMiddle := m.lastButtons&0x02 != 0
|
||||
|
||||
if leftDown {
|
||||
m.postMouse(src, kCGEventLeftMouseDragged, x, y, kCGMouseButtonLeft)
|
||||
} else if rightDown {
|
||||
m.postMouse(src, kCGEventRightMouseDragged, x, y, kCGMouseButtonRight)
|
||||
} else {
|
||||
m.postMouse(src, kCGEventMouseMoved, x, y, kCGMouseButtonLeft)
|
||||
}
|
||||
|
||||
if leftDown && !wasLeft {
|
||||
m.postMouse(src, kCGEventLeftMouseDown, x, y, kCGMouseButtonLeft)
|
||||
} else if !leftDown && wasLeft {
|
||||
m.postMouse(src, kCGEventLeftMouseUp, x, y, kCGMouseButtonLeft)
|
||||
}
|
||||
if rightDown && !wasRight {
|
||||
m.postMouse(src, kCGEventRightMouseDown, x, y, kCGMouseButtonRight)
|
||||
} else if !rightDown && wasRight {
|
||||
m.postMouse(src, kCGEventRightMouseUp, x, y, kCGMouseButtonRight)
|
||||
}
|
||||
if middleDown && !wasMiddle {
|
||||
m.postMouse(src, kCGEventOtherMouseDown, x, y, kCGMouseButtonCenter)
|
||||
} else if !middleDown && wasMiddle {
|
||||
m.postMouse(src, kCGEventOtherMouseUp, x, y, kCGMouseButtonCenter)
|
||||
}
|
||||
|
||||
if scrollUp {
|
||||
m.postScroll(src, 3)
|
||||
}
|
||||
if scrollDown {
|
||||
m.postScroll(src, -3)
|
||||
}
|
||||
|
||||
m.lastButtons = buttonMask
|
||||
}
|
||||
|
||||
func (m *MacInputInjector) postMouse(src uintptr, eventType int32, x, y float64, button int32) {
|
||||
if cgEventCreateMouseEvent == nil {
|
||||
return
|
||||
}
|
||||
event := cgEventCreateMouseEvent(src, eventType, x, y, button)
|
||||
if event == 0 {
|
||||
return
|
||||
}
|
||||
cgEventPost(kCGHIDEventTap, event)
|
||||
cfRelease(event)
|
||||
}
|
||||
|
||||
func (m *MacInputInjector) postScroll(src uintptr, deltaY int32) {
|
||||
if cgEventCreateScrollWheelEventAddr == 0 {
|
||||
return
|
||||
}
|
||||
// CGEventCreateScrollWheelEvent(source, units, wheelCount, wheel1delta)
|
||||
// units=0 (pixel), wheelCount=1, wheel1delta=deltaY
|
||||
// Variadic C function: pass args as uintptr via SyscallN.
|
||||
r1, _, _ := purego.SyscallN(cgEventCreateScrollWheelEventAddr,
|
||||
src, 0, 1, uintptr(uint32(deltaY)))
|
||||
if r1 == 0 {
|
||||
return
|
||||
}
|
||||
cgEventPost(kCGHIDEventTap, r1)
|
||||
cfRelease(r1)
|
||||
}
|
||||
|
||||
// SetClipboard sets the macOS clipboard using pbcopy.
|
||||
func (m *MacInputInjector) SetClipboard(text string) {
|
||||
if m.pbcopyPath == "" {
|
||||
return
|
||||
}
|
||||
cmd := exec.Command(m.pbcopyPath)
|
||||
cmd.Stdin = strings.NewReader(text)
|
||||
if err := cmd.Run(); err != nil {
|
||||
log.Tracef("set clipboard via pbcopy: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// GetClipboard reads the macOS clipboard using pbpaste.
|
||||
func (m *MacInputInjector) GetClipboard() string {
|
||||
if m.pbpastePath == "" {
|
||||
return ""
|
||||
}
|
||||
out, err := exec.Command(m.pbpastePath).Output()
|
||||
if err != nil {
|
||||
log.Tracef("get clipboard via pbpaste: %v", err)
|
||||
return ""
|
||||
}
|
||||
return string(out)
|
||||
}
|
||||
|
||||
// Close is a no-op on macOS.
|
||||
func (m *MacInputInjector) Close() {}
|
||||
|
||||
func keysymToMacKeycode(keysym uint32) uint16 {
|
||||
if keysym >= 0x61 && keysym <= 0x7a {
|
||||
return asciiToMacKey[keysym-0x61]
|
||||
}
|
||||
if keysym >= 0x41 && keysym <= 0x5a {
|
||||
return asciiToMacKey[keysym-0x41]
|
||||
}
|
||||
if keysym >= 0x30 && keysym <= 0x39 {
|
||||
return digitToMacKey[keysym-0x30]
|
||||
}
|
||||
if code, ok := specialKeyMap[keysym]; ok {
|
||||
return code
|
||||
}
|
||||
return 0xFFFF
|
||||
}
|
||||
|
||||
var asciiToMacKey = [26]uint16{
|
||||
0x00, 0x0B, 0x08, 0x02, 0x0E, 0x03, 0x05, 0x04,
|
||||
0x22, 0x26, 0x28, 0x25, 0x2E, 0x2D, 0x1F, 0x23,
|
||||
0x0C, 0x0F, 0x01, 0x11, 0x20, 0x09, 0x0D, 0x07,
|
||||
0x10, 0x06,
|
||||
}
|
||||
|
||||
var digitToMacKey = [10]uint16{
|
||||
0x1D, 0x12, 0x13, 0x14, 0x15, 0x17, 0x16, 0x1A, 0x1C, 0x19,
|
||||
}
|
||||
|
||||
var specialKeyMap = map[uint32]uint16{
|
||||
// Whitespace and editing
|
||||
0x0020: 0x31, // space
|
||||
0xff08: 0x33, // BackSpace
|
||||
0xff09: 0x30, // Tab
|
||||
0xff0d: 0x24, // Return
|
||||
0xff1b: 0x35, // Escape
|
||||
0xffff: 0x75, // Delete (forward)
|
||||
|
||||
// Navigation
|
||||
0xff50: 0x73, // Home
|
||||
0xff51: 0x7B, // Left
|
||||
0xff52: 0x7E, // Up
|
||||
0xff53: 0x7C, // Right
|
||||
0xff54: 0x7D, // Down
|
||||
0xff55: 0x74, // Page_Up
|
||||
0xff56: 0x79, // Page_Down
|
||||
0xff57: 0x77, // End
|
||||
0xff63: 0x72, // Insert (Help on Mac)
|
||||
|
||||
// Modifiers
|
||||
0xffe1: 0x38, // Shift_L
|
||||
0xffe2: 0x3C, // Shift_R
|
||||
0xffe3: 0x3B, // Control_L
|
||||
0xffe4: 0x3E, // Control_R
|
||||
0xffe5: 0x39, // Caps_Lock
|
||||
0xffe9: 0x3A, // Alt_L (Option)
|
||||
0xffea: 0x3D, // Alt_R (Option)
|
||||
0xffe7: 0x37, // Meta_L (Command)
|
||||
0xffe8: 0x36, // Meta_R (Command)
|
||||
0xffeb: 0x37, // Super_L (Command) - noVNC sends this
|
||||
0xffec: 0x36, // Super_R (Command)
|
||||
|
||||
// Mode_switch / ISO_Level3_Shift (sent by noVNC for macOS Option remap)
|
||||
0xff7e: 0x3A, // Mode_switch -> Option
|
||||
0xfe03: 0x3D, // ISO_Level3_Shift -> Right Option
|
||||
|
||||
// Function keys
|
||||
0xffbe: 0x7A, // F1
|
||||
0xffbf: 0x78, // F2
|
||||
0xffc0: 0x63, // F3
|
||||
0xffc1: 0x76, // F4
|
||||
0xffc2: 0x60, // F5
|
||||
0xffc3: 0x61, // F6
|
||||
0xffc4: 0x62, // F7
|
||||
0xffc5: 0x64, // F8
|
||||
0xffc6: 0x65, // F9
|
||||
0xffc7: 0x6D, // F10
|
||||
0xffc8: 0x67, // F11
|
||||
0xffc9: 0x6F, // F12
|
||||
0xffca: 0x69, // F13
|
||||
0xffcb: 0x6B, // F14
|
||||
0xffcc: 0x71, // F15
|
||||
0xffcd: 0x6A, // F16
|
||||
0xffce: 0x40, // F17
|
||||
0xffcf: 0x4F, // F18
|
||||
0xffd0: 0x50, // F19
|
||||
0xffd1: 0x5A, // F20
|
||||
|
||||
// Punctuation (US keyboard layout, keysym = ASCII code)
|
||||
0x002d: 0x1B, // minus -
|
||||
0x003d: 0x18, // equal =
|
||||
0x005b: 0x21, // bracketleft [
|
||||
0x005d: 0x1E, // bracketright ]
|
||||
0x005c: 0x2A, // backslash
|
||||
0x003b: 0x29, // semicolon ;
|
||||
0x0027: 0x27, // apostrophe '
|
||||
0x0060: 0x32, // grave `
|
||||
0x002c: 0x2B, // comma ,
|
||||
0x002e: 0x2F, // period .
|
||||
0x002f: 0x2C, // slash /
|
||||
|
||||
// Shifted punctuation (noVNC sends these as separate keysyms)
|
||||
0x005f: 0x1B, // underscore _ (shift+minus)
|
||||
0x002b: 0x18, // plus + (shift+equal)
|
||||
0x007b: 0x21, // braceleft { (shift+[)
|
||||
0x007d: 0x1E, // braceright } (shift+])
|
||||
0x007c: 0x2A, // bar | (shift+\)
|
||||
0x003a: 0x29, // colon : (shift+;)
|
||||
0x0022: 0x27, // quotedbl " (shift+')
|
||||
0x007e: 0x32, // tilde ~ (shift+`)
|
||||
0x003c: 0x2B, // less < (shift+,)
|
||||
0x003e: 0x2F, // greater > (shift+.)
|
||||
0x003f: 0x2C, // question ? (shift+/)
|
||||
0x0021: 0x12, // exclam ! (shift+1)
|
||||
0x0040: 0x13, // at @ (shift+2)
|
||||
0x0023: 0x14, // numbersign # (shift+3)
|
||||
0x0024: 0x15, // dollar $ (shift+4)
|
||||
0x0025: 0x17, // percent % (shift+5)
|
||||
0x005e: 0x16, // asciicircum ^ (shift+6)
|
||||
0x0026: 0x1A, // ampersand & (shift+7)
|
||||
0x002a: 0x1C, // asterisk * (shift+8)
|
||||
0x0028: 0x19, // parenleft ( (shift+9)
|
||||
0x0029: 0x1D, // parenright ) (shift+0)
|
||||
|
||||
// Numpad
|
||||
0xffb0: 0x52, // KP_0
|
||||
0xffb1: 0x53, // KP_1
|
||||
0xffb2: 0x54, // KP_2
|
||||
0xffb3: 0x55, // KP_3
|
||||
0xffb4: 0x56, // KP_4
|
||||
0xffb5: 0x57, // KP_5
|
||||
0xffb6: 0x58, // KP_6
|
||||
0xffb7: 0x59, // KP_7
|
||||
0xffb8: 0x5B, // KP_8
|
||||
0xffb9: 0x5C, // KP_9
|
||||
0xffae: 0x41, // KP_Decimal
|
||||
0xffaa: 0x43, // KP_Multiply
|
||||
0xffab: 0x45, // KP_Add
|
||||
0xffad: 0x4E, // KP_Subtract
|
||||
0xffaf: 0x4B, // KP_Divide
|
||||
0xff8d: 0x4C, // KP_Enter
|
||||
0xffbd: 0x51, // KP_Equal
|
||||
}
|
||||
|
||||
var _ InputInjector = (*MacInputInjector)(nil)
|
||||
398
client/vnc/server/input_windows.go
Normal file
398
client/vnc/server/input_windows.go
Normal file
@@ -0,0 +1,398 @@
|
||||
//go:build windows
|
||||
|
||||
package server
|
||||
|
||||
import (
|
||||
"runtime"
|
||||
"unsafe"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/sys/windows"
|
||||
)
|
||||
|
||||
var (
|
||||
procOpenEventW = kernel32.NewProc("OpenEventW")
|
||||
procSendInput = user32.NewProc("SendInput")
|
||||
procVkKeyScanA = user32.NewProc("VkKeyScanA")
|
||||
)
|
||||
|
||||
const eventModifyState = 0x0002
|
||||
|
||||
const (
|
||||
inputMouse = 0
|
||||
inputKeyboard = 1
|
||||
|
||||
mouseeventfMove = 0x0001
|
||||
mouseeventfLeftDown = 0x0002
|
||||
mouseeventfLeftUp = 0x0004
|
||||
mouseeventfRightDown = 0x0008
|
||||
mouseeventfRightUp = 0x0010
|
||||
mouseeventfMiddleDown = 0x0020
|
||||
mouseeventfMiddleUp = 0x0040
|
||||
mouseeventfWheel = 0x0800
|
||||
mouseeventfAbsolute = 0x8000
|
||||
|
||||
wheelDelta = 120
|
||||
|
||||
keyeventfKeyUp = 0x0002
|
||||
keyeventfScanCode = 0x0008
|
||||
)
|
||||
|
||||
type mouseInput struct {
|
||||
Dx int32
|
||||
Dy int32
|
||||
MouseData uint32
|
||||
DwFlags uint32
|
||||
Time uint32
|
||||
DwExtraInfo uintptr
|
||||
}
|
||||
|
||||
type keybdInput struct {
|
||||
WVk uint16
|
||||
WScan uint16
|
||||
DwFlags uint32
|
||||
Time uint32
|
||||
DwExtraInfo uintptr
|
||||
_ [8]byte
|
||||
}
|
||||
|
||||
type inputUnion [32]byte
|
||||
|
||||
type winInput struct {
|
||||
Type uint32
|
||||
_ [4]byte
|
||||
Data inputUnion
|
||||
}
|
||||
|
||||
func sendMouseInput(flags uint32, dx, dy int32, mouseData uint32) {
|
||||
mi := mouseInput{
|
||||
Dx: dx,
|
||||
Dy: dy,
|
||||
MouseData: mouseData,
|
||||
DwFlags: flags,
|
||||
}
|
||||
inp := winInput{Type: inputMouse}
|
||||
copy(inp.Data[:], (*[unsafe.Sizeof(mi)]byte)(unsafe.Pointer(&mi))[:])
|
||||
r, _, err := procSendInput.Call(1, uintptr(unsafe.Pointer(&inp)), unsafe.Sizeof(inp))
|
||||
if r == 0 {
|
||||
log.Tracef("SendInput(mouse flags=0x%x): %v", flags, err)
|
||||
}
|
||||
}
|
||||
|
||||
func sendKeyInput(vk uint16, scanCode uint16, flags uint32) {
|
||||
ki := keybdInput{
|
||||
WVk: vk,
|
||||
WScan: scanCode,
|
||||
DwFlags: flags,
|
||||
}
|
||||
inp := winInput{Type: inputKeyboard}
|
||||
copy(inp.Data[:], (*[unsafe.Sizeof(ki)]byte)(unsafe.Pointer(&ki))[:])
|
||||
r, _, err := procSendInput.Call(1, uintptr(unsafe.Pointer(&inp)), unsafe.Sizeof(inp))
|
||||
if r == 0 {
|
||||
log.Tracef("SendInput(key vk=0x%x): %v", vk, err)
|
||||
}
|
||||
}
|
||||
|
||||
const sasEventName = `Global\NetBirdVNC_SAS`
|
||||
|
||||
type inputCmd struct {
|
||||
isKey bool
|
||||
keysym uint32
|
||||
down bool
|
||||
buttonMask uint8
|
||||
x, y int
|
||||
serverW int
|
||||
serverH int
|
||||
}
|
||||
|
||||
// WindowsInputInjector delivers input events from a dedicated OS thread that
|
||||
// calls switchToInputDesktop before each injection. SendInput targets the
|
||||
// calling thread's desktop, so the injection thread must be on the same
|
||||
// desktop the user sees.
|
||||
type WindowsInputInjector struct {
|
||||
ch chan inputCmd
|
||||
prevButtonMask uint8
|
||||
ctrlDown bool
|
||||
altDown bool
|
||||
}
|
||||
|
||||
// NewWindowsInputInjector creates a desktop-aware input injector.
|
||||
func NewWindowsInputInjector() *WindowsInputInjector {
|
||||
w := &WindowsInputInjector{ch: make(chan inputCmd, 64)}
|
||||
go w.loop()
|
||||
return w
|
||||
}
|
||||
|
||||
func (w *WindowsInputInjector) loop() {
|
||||
runtime.LockOSThread()
|
||||
|
||||
for cmd := range w.ch {
|
||||
// Switch to the current input desktop so SendInput reaches the right target.
|
||||
switchToInputDesktop()
|
||||
|
||||
if cmd.isKey {
|
||||
w.doInjectKey(cmd.keysym, cmd.down)
|
||||
} else {
|
||||
w.doInjectPointer(cmd.buttonMask, cmd.x, cmd.y, cmd.serverW, cmd.serverH)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// InjectKey queues a key event for injection on the input desktop thread.
|
||||
func (w *WindowsInputInjector) InjectKey(keysym uint32, down bool) {
|
||||
w.ch <- inputCmd{isKey: true, keysym: keysym, down: down}
|
||||
}
|
||||
|
||||
// InjectPointer queues a pointer event for injection on the input desktop thread.
|
||||
func (w *WindowsInputInjector) InjectPointer(buttonMask uint8, x, y, serverW, serverH int) {
|
||||
w.ch <- inputCmd{buttonMask: buttonMask, x: x, y: y, serverW: serverW, serverH: serverH}
|
||||
}
|
||||
|
||||
func (w *WindowsInputInjector) doInjectKey(keysym uint32, down bool) {
|
||||
switch keysym {
|
||||
case 0xffe3, 0xffe4:
|
||||
w.ctrlDown = down
|
||||
case 0xffe9, 0xffea:
|
||||
w.altDown = down
|
||||
}
|
||||
|
||||
if (keysym == 0xff9f || keysym == 0xffff) && w.ctrlDown && w.altDown && down {
|
||||
signalSAS()
|
||||
return
|
||||
}
|
||||
|
||||
vk, _, extended := keysym2VK(keysym)
|
||||
if vk == 0 {
|
||||
return
|
||||
}
|
||||
var flags uint32
|
||||
if !down {
|
||||
flags |= keyeventfKeyUp
|
||||
}
|
||||
if extended {
|
||||
flags |= keyeventfScanCode
|
||||
}
|
||||
sendKeyInput(vk, 0, flags)
|
||||
}
|
||||
|
||||
// signalSAS signals the SAS named event. A listener in Session 0
|
||||
// (startSASListener) calls SendSAS to trigger the Secure Attention Sequence.
|
||||
func signalSAS() {
|
||||
namePtr, err := windows.UTF16PtrFromString(sasEventName)
|
||||
if err != nil {
|
||||
log.Warnf("SAS UTF16: %v", err)
|
||||
return
|
||||
}
|
||||
h, _, lerr := procOpenEventW.Call(
|
||||
uintptr(eventModifyState),
|
||||
0,
|
||||
uintptr(unsafe.Pointer(namePtr)),
|
||||
)
|
||||
if h == 0 {
|
||||
log.Warnf("OpenEvent(%s): %v", sasEventName, lerr)
|
||||
return
|
||||
}
|
||||
ev := windows.Handle(h)
|
||||
defer windows.CloseHandle(ev)
|
||||
if err := windows.SetEvent(ev); err != nil {
|
||||
log.Warnf("SetEvent SAS: %v", err)
|
||||
} else {
|
||||
log.Info("SAS event signaled")
|
||||
}
|
||||
}
|
||||
|
||||
func (w *WindowsInputInjector) doInjectPointer(buttonMask uint8, x, y, serverW, serverH int) {
|
||||
if serverW == 0 || serverH == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
absX := int32(x * 65535 / serverW)
|
||||
absY := int32(y * 65535 / serverH)
|
||||
|
||||
sendMouseInput(mouseeventfMove|mouseeventfAbsolute, absX, absY, 0)
|
||||
|
||||
changed := buttonMask ^ w.prevButtonMask
|
||||
w.prevButtonMask = buttonMask
|
||||
|
||||
type btnMap struct {
|
||||
bit uint8
|
||||
down uint32
|
||||
up uint32
|
||||
}
|
||||
buttons := [...]btnMap{
|
||||
{0x01, mouseeventfLeftDown, mouseeventfLeftUp},
|
||||
{0x02, mouseeventfMiddleDown, mouseeventfMiddleUp},
|
||||
{0x04, mouseeventfRightDown, mouseeventfRightUp},
|
||||
}
|
||||
for _, b := range buttons {
|
||||
if changed&b.bit == 0 {
|
||||
continue
|
||||
}
|
||||
var flags uint32
|
||||
if buttonMask&b.bit != 0 {
|
||||
flags = b.down
|
||||
} else {
|
||||
flags = b.up
|
||||
}
|
||||
sendMouseInput(flags|mouseeventfAbsolute, absX, absY, 0)
|
||||
}
|
||||
|
||||
negWheelDelta := ^uint32(wheelDelta - 1)
|
||||
if changed&0x08 != 0 && buttonMask&0x08 != 0 {
|
||||
sendMouseInput(mouseeventfWheel|mouseeventfAbsolute, absX, absY, wheelDelta)
|
||||
}
|
||||
if changed&0x10 != 0 && buttonMask&0x10 != 0 {
|
||||
sendMouseInput(mouseeventfWheel|mouseeventfAbsolute, absX, absY, negWheelDelta)
|
||||
}
|
||||
}
|
||||
|
||||
// keysym2VK converts an X11 keysym to a Windows virtual key code.
|
||||
func keysym2VK(keysym uint32) (vk uint16, scan uint16, extended bool) {
|
||||
if keysym >= 0x20 && keysym <= 0x7e {
|
||||
r, _, _ := procVkKeyScanA.Call(uintptr(keysym))
|
||||
vk = uint16(r & 0xff)
|
||||
return
|
||||
}
|
||||
|
||||
if keysym >= 0xffbe && keysym <= 0xffc9 {
|
||||
vk = uint16(0x70 + keysym - 0xffbe)
|
||||
return
|
||||
}
|
||||
|
||||
switch keysym {
|
||||
case 0xff08:
|
||||
vk = 0x08 // Backspace
|
||||
case 0xff09:
|
||||
vk = 0x09 // Tab
|
||||
case 0xff0d:
|
||||
vk = 0x0d // Return
|
||||
case 0xff1b:
|
||||
vk = 0x1b // Escape
|
||||
case 0xff63:
|
||||
vk, extended = 0x2d, true // Insert
|
||||
case 0xff9f, 0xffff:
|
||||
vk, extended = 0x2e, true // Delete
|
||||
case 0xff50:
|
||||
vk, extended = 0x24, true // Home
|
||||
case 0xff57:
|
||||
vk, extended = 0x23, true // End
|
||||
case 0xff55:
|
||||
vk, extended = 0x21, true // PageUp
|
||||
case 0xff56:
|
||||
vk, extended = 0x22, true // PageDown
|
||||
case 0xff51:
|
||||
vk, extended = 0x25, true // Left
|
||||
case 0xff52:
|
||||
vk, extended = 0x26, true // Up
|
||||
case 0xff53:
|
||||
vk, extended = 0x27, true // Right
|
||||
case 0xff54:
|
||||
vk, extended = 0x28, true // Down
|
||||
case 0xffe1, 0xffe2:
|
||||
vk = 0x10 // Shift
|
||||
case 0xffe3, 0xffe4:
|
||||
vk = 0x11 // Control
|
||||
case 0xffe9, 0xffea:
|
||||
vk = 0x12 // Alt
|
||||
case 0xffe5:
|
||||
vk = 0x14 // CapsLock
|
||||
case 0xffe7, 0xffeb:
|
||||
vk, extended = 0x5B, true // Meta_L / Super_L -> Left Windows
|
||||
case 0xffe8, 0xffec:
|
||||
vk, extended = 0x5C, true // Meta_R / Super_R -> Right Windows
|
||||
case 0xff61:
|
||||
vk = 0x2c // PrintScreen
|
||||
case 0xff13:
|
||||
vk = 0x13 // Pause
|
||||
case 0xff14:
|
||||
vk = 0x91 // ScrollLock
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
var (
|
||||
procOpenClipboard = user32.NewProc("OpenClipboard")
|
||||
procCloseClipboard = user32.NewProc("CloseClipboard")
|
||||
procEmptyClipboard = user32.NewProc("EmptyClipboard")
|
||||
procSetClipboardData = user32.NewProc("SetClipboardData")
|
||||
procGetClipboardData = user32.NewProc("GetClipboardData")
|
||||
procIsClipboardFormatAvailable = user32.NewProc("IsClipboardFormatAvailable")
|
||||
|
||||
procGlobalAlloc = kernel32.NewProc("GlobalAlloc")
|
||||
procGlobalLock = kernel32.NewProc("GlobalLock")
|
||||
procGlobalUnlock = kernel32.NewProc("GlobalUnlock")
|
||||
)
|
||||
|
||||
const (
|
||||
cfUnicodeText = 13
|
||||
gmemMoveable = 0x0002
|
||||
)
|
||||
|
||||
// SetClipboard sets the Windows clipboard to the given UTF-8 text.
|
||||
func (w *WindowsInputInjector) SetClipboard(text string) {
|
||||
utf16, err := windows.UTF16FromString(text)
|
||||
if err != nil {
|
||||
log.Tracef("clipboard UTF16 encode: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
size := uintptr(len(utf16) * 2)
|
||||
hMem, _, _ := procGlobalAlloc.Call(gmemMoveable, size)
|
||||
if hMem == 0 {
|
||||
log.Tracef("GlobalAlloc for clipboard: allocation returned nil")
|
||||
return
|
||||
}
|
||||
|
||||
ptr, _, _ := procGlobalLock.Call(hMem)
|
||||
if ptr == 0 {
|
||||
log.Tracef("GlobalLock for clipboard: lock returned nil")
|
||||
return
|
||||
}
|
||||
copy(unsafe.Slice((*uint16)(unsafe.Pointer(ptr)), len(utf16)), utf16)
|
||||
procGlobalUnlock.Call(hMem)
|
||||
|
||||
r, _, lerr := procOpenClipboard.Call(0)
|
||||
if r == 0 {
|
||||
log.Tracef("OpenClipboard: %v", lerr)
|
||||
return
|
||||
}
|
||||
defer procCloseClipboard.Call()
|
||||
|
||||
procEmptyClipboard.Call()
|
||||
r, _, lerr = procSetClipboardData.Call(cfUnicodeText, hMem)
|
||||
if r == 0 {
|
||||
log.Tracef("SetClipboardData: %v", lerr)
|
||||
}
|
||||
}
|
||||
|
||||
// GetClipboard reads the Windows clipboard as UTF-8 text.
|
||||
func (w *WindowsInputInjector) GetClipboard() string {
|
||||
r, _, _ := procIsClipboardFormatAvailable.Call(cfUnicodeText)
|
||||
if r == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
r, _, lerr := procOpenClipboard.Call(0)
|
||||
if r == 0 {
|
||||
log.Tracef("OpenClipboard for read: %v", lerr)
|
||||
return ""
|
||||
}
|
||||
defer procCloseClipboard.Call()
|
||||
|
||||
hData, _, _ := procGetClipboardData.Call(cfUnicodeText)
|
||||
if hData == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
ptr, _, _ := procGlobalLock.Call(hData)
|
||||
if ptr == 0 {
|
||||
return ""
|
||||
}
|
||||
defer procGlobalUnlock.Call(hData)
|
||||
|
||||
return windows.UTF16PtrToString((*uint16)(unsafe.Pointer(ptr)))
|
||||
}
|
||||
|
||||
var _ InputInjector = (*WindowsInputInjector)(nil)
|
||||
|
||||
var _ ScreenCapturer = (*DesktopCapturer)(nil)
|
||||
242
client/vnc/server/input_x11.go
Normal file
242
client/vnc/server/input_x11.go
Normal file
@@ -0,0 +1,242 @@
|
||||
//go:build (linux && !android) || freebsd
|
||||
|
||||
package server
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"strings"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/jezek/xgb"
|
||||
"github.com/jezek/xgb/xproto"
|
||||
"github.com/jezek/xgb/xtest"
|
||||
)
|
||||
|
||||
// X11InputInjector injects keyboard and mouse events via the XTest extension.
|
||||
type X11InputInjector struct {
|
||||
conn *xgb.Conn
|
||||
root xproto.Window
|
||||
screen *xproto.ScreenInfo
|
||||
display string
|
||||
keysymMap map[uint32]byte
|
||||
lastButtons uint8
|
||||
clipboardTool string
|
||||
clipboardToolName string
|
||||
}
|
||||
|
||||
// NewX11InputInjector connects to the X11 display and initializes XTest.
|
||||
func NewX11InputInjector(display string) (*X11InputInjector, error) {
|
||||
detectX11Display()
|
||||
|
||||
if display == "" {
|
||||
display = os.Getenv("DISPLAY")
|
||||
}
|
||||
if display == "" {
|
||||
return nil, fmt.Errorf("DISPLAY not set and no Xorg process found")
|
||||
}
|
||||
|
||||
conn, err := xgb.NewConnDisplay(display)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("connect to X11 display %s: %w", display, err)
|
||||
}
|
||||
|
||||
if err := xtest.Init(conn); err != nil {
|
||||
conn.Close()
|
||||
return nil, fmt.Errorf("init XTest extension: %w", err)
|
||||
}
|
||||
|
||||
setup := xproto.Setup(conn)
|
||||
if len(setup.Roots) == 0 {
|
||||
conn.Close()
|
||||
return nil, fmt.Errorf("no X11 screens")
|
||||
}
|
||||
screen := setup.Roots[0]
|
||||
|
||||
inj := &X11InputInjector{
|
||||
conn: conn,
|
||||
root: screen.Root,
|
||||
screen: &screen,
|
||||
display: display,
|
||||
}
|
||||
inj.cacheKeyboardMapping()
|
||||
inj.resolveClipboardTool()
|
||||
|
||||
log.Infof("X11 input injector ready (display=%s)", display)
|
||||
return inj, nil
|
||||
}
|
||||
|
||||
// InjectKey simulates a key press or release. keysym is an X11 KeySym.
|
||||
func (x *X11InputInjector) InjectKey(keysym uint32, down bool) {
|
||||
keycode := x.keysymToKeycode(keysym)
|
||||
if keycode == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
var eventType byte
|
||||
if down {
|
||||
eventType = xproto.KeyPress
|
||||
} else {
|
||||
eventType = xproto.KeyRelease
|
||||
}
|
||||
|
||||
xtest.FakeInput(x.conn, eventType, keycode, 0, x.root, 0, 0, 0)
|
||||
}
|
||||
|
||||
// InjectPointer simulates mouse movement and button events.
|
||||
func (x *X11InputInjector) InjectPointer(buttonMask uint8, px, py, serverW, serverH int) {
|
||||
if serverW == 0 || serverH == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
// Scale to actual screen coordinates.
|
||||
screenW := int(x.screen.WidthInPixels)
|
||||
screenH := int(x.screen.HeightInPixels)
|
||||
absX := px * screenW / serverW
|
||||
absY := py * screenH / serverH
|
||||
|
||||
// Move pointer.
|
||||
xtest.FakeInput(x.conn, xproto.MotionNotify, 0, 0, x.root, int16(absX), int16(absY), 0)
|
||||
|
||||
// Handle button events. RFB button mask: bit0=left, bit1=middle, bit2=right,
|
||||
// bit3=scrollUp, bit4=scrollDown. X11 buttons: 1=left, 2=middle, 3=right,
|
||||
// 4=scrollUp, 5=scrollDown.
|
||||
type btnMap struct {
|
||||
rfbBit uint8
|
||||
x11Btn byte
|
||||
}
|
||||
buttons := [...]btnMap{
|
||||
{0x01, 1}, // left
|
||||
{0x02, 2}, // middle
|
||||
{0x04, 3}, // right
|
||||
{0x08, 4}, // scroll up
|
||||
{0x10, 5}, // scroll down
|
||||
}
|
||||
|
||||
for _, b := range buttons {
|
||||
pressed := buttonMask&b.rfbBit != 0
|
||||
wasPressed := x.lastButtons&b.rfbBit != 0
|
||||
if b.x11Btn >= 4 {
|
||||
// Scroll: send press+release on each scroll event.
|
||||
if pressed {
|
||||
xtest.FakeInput(x.conn, xproto.ButtonPress, b.x11Btn, 0, x.root, 0, 0, 0)
|
||||
xtest.FakeInput(x.conn, xproto.ButtonRelease, b.x11Btn, 0, x.root, 0, 0, 0)
|
||||
}
|
||||
} else {
|
||||
if pressed && !wasPressed {
|
||||
xtest.FakeInput(x.conn, xproto.ButtonPress, b.x11Btn, 0, x.root, 0, 0, 0)
|
||||
} else if !pressed && wasPressed {
|
||||
xtest.FakeInput(x.conn, xproto.ButtonRelease, b.x11Btn, 0, x.root, 0, 0, 0)
|
||||
}
|
||||
}
|
||||
}
|
||||
x.lastButtons = buttonMask
|
||||
}
|
||||
|
||||
// cacheKeyboardMapping fetches the X11 keyboard mapping once and stores it
|
||||
// as a keysym-to-keycode map, avoiding a round-trip per keystroke.
|
||||
func (x *X11InputInjector) cacheKeyboardMapping() {
|
||||
setup := xproto.Setup(x.conn)
|
||||
minKeycode := setup.MinKeycode
|
||||
maxKeycode := setup.MaxKeycode
|
||||
|
||||
reply, err := xproto.GetKeyboardMapping(x.conn, minKeycode,
|
||||
byte(maxKeycode-minKeycode+1)).Reply()
|
||||
if err != nil {
|
||||
log.Debugf("cache keyboard mapping: %v", err)
|
||||
x.keysymMap = make(map[uint32]byte)
|
||||
return
|
||||
}
|
||||
|
||||
m := make(map[uint32]byte, int(maxKeycode-minKeycode+1)*int(reply.KeysymsPerKeycode))
|
||||
keysymsPerKeycode := int(reply.KeysymsPerKeycode)
|
||||
for i := int(minKeycode); i <= int(maxKeycode); i++ {
|
||||
offset := (i - int(minKeycode)) * keysymsPerKeycode
|
||||
for j := 0; j < keysymsPerKeycode; j++ {
|
||||
ks := uint32(reply.Keysyms[offset+j])
|
||||
if ks != 0 {
|
||||
if _, exists := m[ks]; !exists {
|
||||
m[ks] = byte(i)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
x.keysymMap = m
|
||||
}
|
||||
|
||||
// keysymToKeycode looks up a cached keysym-to-keycode mapping.
|
||||
// Returns 0 if the keysym is not mapped.
|
||||
func (x *X11InputInjector) keysymToKeycode(keysym uint32) byte {
|
||||
return x.keysymMap[keysym]
|
||||
}
|
||||
|
||||
// SetClipboard sets the X11 clipboard using xclip or xsel.
|
||||
func (x *X11InputInjector) SetClipboard(text string) {
|
||||
if x.clipboardTool == "" {
|
||||
return
|
||||
}
|
||||
|
||||
var cmd *exec.Cmd
|
||||
if x.clipboardToolName == "xclip" {
|
||||
cmd = exec.Command(x.clipboardTool, "-selection", "clipboard")
|
||||
} else {
|
||||
cmd = exec.Command(x.clipboardTool, "--clipboard", "--input")
|
||||
}
|
||||
cmd.Env = x.clipboardEnv()
|
||||
cmd.Stdin = strings.NewReader(text)
|
||||
if err := cmd.Run(); err != nil {
|
||||
log.Debugf("set clipboard via %s: %v", x.clipboardToolName, err)
|
||||
}
|
||||
}
|
||||
|
||||
func (x *X11InputInjector) resolveClipboardTool() {
|
||||
for _, name := range []string{"xclip", "xsel"} {
|
||||
path, err := exec.LookPath(name)
|
||||
if err == nil {
|
||||
x.clipboardTool = path
|
||||
x.clipboardToolName = name
|
||||
log.Debugf("clipboard tool resolved to %s", path)
|
||||
return
|
||||
}
|
||||
}
|
||||
log.Debugf("no clipboard tool (xclip/xsel) found, clipboard sync disabled")
|
||||
}
|
||||
|
||||
// GetClipboard reads the X11 clipboard using xclip or xsel.
|
||||
func (x *X11InputInjector) GetClipboard() string {
|
||||
if x.clipboardTool == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
var cmd *exec.Cmd
|
||||
if x.clipboardToolName == "xclip" {
|
||||
cmd = exec.Command(x.clipboardTool, "-selection", "clipboard", "-o")
|
||||
} else {
|
||||
cmd = exec.Command(x.clipboardTool, "--clipboard", "--output")
|
||||
}
|
||||
cmd.Env = x.clipboardEnv()
|
||||
out, err := cmd.Output()
|
||||
if err != nil {
|
||||
log.Tracef("get clipboard via %s: %v", x.clipboardToolName, err)
|
||||
return ""
|
||||
}
|
||||
return string(out)
|
||||
}
|
||||
|
||||
func (x *X11InputInjector) clipboardEnv() []string {
|
||||
env := []string{"DISPLAY=" + x.display}
|
||||
if auth := os.Getenv("XAUTHORITY"); auth != "" {
|
||||
env = append(env, "XAUTHORITY="+auth)
|
||||
}
|
||||
return env
|
||||
}
|
||||
|
||||
// Close releases X11 resources.
|
||||
func (x *X11InputInjector) Close() {
|
||||
x.conn.Close()
|
||||
}
|
||||
|
||||
var _ InputInjector = (*X11InputInjector)(nil)
|
||||
var _ ScreenCapturer = (*X11Poller)(nil)
|
||||
264
client/vnc/server/rfb.go
Normal file
264
client/vnc/server/rfb.go
Normal file
@@ -0,0 +1,264 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"compress/zlib"
|
||||
"crypto/des"
|
||||
"encoding/binary"
|
||||
"image"
|
||||
)
|
||||
|
||||
const (
|
||||
rfbProtocolVersion = "RFB 003.008\n"
|
||||
|
||||
secNone = 1
|
||||
secVNCAuth = 2
|
||||
|
||||
// Client message types.
|
||||
clientSetPixelFormat = 0
|
||||
clientSetEncodings = 2
|
||||
clientFramebufferUpdateRequest = 3
|
||||
clientKeyEvent = 4
|
||||
clientPointerEvent = 5
|
||||
clientCutText = 6
|
||||
|
||||
// Server message types.
|
||||
serverFramebufferUpdate = 0
|
||||
serverCutText = 3
|
||||
|
||||
// Encoding types.
|
||||
encRaw = 0
|
||||
encZlib = 6
|
||||
)
|
||||
|
||||
// serverPixelFormat is the default pixel format advertised by the server:
|
||||
// 32bpp RGBA, big-endian, true-colour, 8 bits per channel.
|
||||
var serverPixelFormat = [16]byte{
|
||||
32, // bits-per-pixel
|
||||
24, // depth
|
||||
1, // big-endian-flag
|
||||
1, // true-colour-flag
|
||||
0, 255, // red-max
|
||||
0, 255, // green-max
|
||||
0, 255, // blue-max
|
||||
16, // red-shift
|
||||
8, // green-shift
|
||||
0, // blue-shift
|
||||
0, 0, 0, // padding
|
||||
}
|
||||
|
||||
// clientPixelFormat holds the negotiated pixel format from the client.
|
||||
type clientPixelFormat struct {
|
||||
bpp uint8
|
||||
bigEndian uint8
|
||||
rMax uint16
|
||||
gMax uint16
|
||||
bMax uint16
|
||||
rShift uint8
|
||||
gShift uint8
|
||||
bShift uint8
|
||||
}
|
||||
|
||||
func defaultClientPixelFormat() clientPixelFormat {
|
||||
return clientPixelFormat{
|
||||
bpp: serverPixelFormat[0],
|
||||
bigEndian: serverPixelFormat[2],
|
||||
rMax: binary.BigEndian.Uint16(serverPixelFormat[4:6]),
|
||||
gMax: binary.BigEndian.Uint16(serverPixelFormat[6:8]),
|
||||
bMax: binary.BigEndian.Uint16(serverPixelFormat[8:10]),
|
||||
rShift: serverPixelFormat[10],
|
||||
gShift: serverPixelFormat[11],
|
||||
bShift: serverPixelFormat[12],
|
||||
}
|
||||
}
|
||||
|
||||
func parsePixelFormat(pf []byte) clientPixelFormat {
|
||||
return clientPixelFormat{
|
||||
bpp: pf[0],
|
||||
bigEndian: pf[2],
|
||||
rMax: binary.BigEndian.Uint16(pf[4:6]),
|
||||
gMax: binary.BigEndian.Uint16(pf[6:8]),
|
||||
bMax: binary.BigEndian.Uint16(pf[8:10]),
|
||||
rShift: pf[10],
|
||||
gShift: pf[11],
|
||||
bShift: pf[12],
|
||||
}
|
||||
}
|
||||
|
||||
// encodeRawRect encodes a framebuffer region as a raw RFB rectangle.
|
||||
// The returned buffer includes the FramebufferUpdate header (1 rectangle).
|
||||
func encodeRawRect(img *image.RGBA, pf clientPixelFormat, x, y, w, h int) []byte {
|
||||
bytesPerPixel := max(int(pf.bpp)/8, 1)
|
||||
|
||||
pixelBytes := w * h * bytesPerPixel
|
||||
buf := make([]byte, 4+12+pixelBytes)
|
||||
|
||||
// FramebufferUpdate header.
|
||||
buf[0] = serverFramebufferUpdate
|
||||
buf[1] = 0 // padding
|
||||
binary.BigEndian.PutUint16(buf[2:4], 1)
|
||||
|
||||
// Rectangle header.
|
||||
binary.BigEndian.PutUint16(buf[4:6], uint16(x))
|
||||
binary.BigEndian.PutUint16(buf[6:8], uint16(y))
|
||||
binary.BigEndian.PutUint16(buf[8:10], uint16(w))
|
||||
binary.BigEndian.PutUint16(buf[10:12], uint16(h))
|
||||
binary.BigEndian.PutUint32(buf[12:16], uint32(encRaw))
|
||||
|
||||
off := 16
|
||||
stride := img.Stride
|
||||
for row := y; row < y+h; row++ {
|
||||
for col := x; col < x+w; col++ {
|
||||
p := row*stride + col*4
|
||||
r, g, b := img.Pix[p], img.Pix[p+1], img.Pix[p+2]
|
||||
|
||||
rv := uint32(r) * uint32(pf.rMax) / 255
|
||||
gv := uint32(g) * uint32(pf.gMax) / 255
|
||||
bv := uint32(b) * uint32(pf.bMax) / 255
|
||||
pixel := (rv << pf.rShift) | (gv << pf.gShift) | (bv << pf.bShift)
|
||||
|
||||
if pf.bigEndian != 0 {
|
||||
for i := range bytesPerPixel {
|
||||
buf[off+i] = byte(pixel >> uint((bytesPerPixel-1-i)*8))
|
||||
}
|
||||
} else {
|
||||
for i := range bytesPerPixel {
|
||||
buf[off+i] = byte(pixel >> uint(i*8))
|
||||
}
|
||||
}
|
||||
off += bytesPerPixel
|
||||
}
|
||||
}
|
||||
|
||||
return buf
|
||||
}
|
||||
|
||||
// vncAuthEncrypt encrypts a 16-byte challenge using the VNC DES scheme.
|
||||
func vncAuthEncrypt(challenge []byte, password string) []byte {
|
||||
key := make([]byte, 8)
|
||||
for i, c := range []byte(password) {
|
||||
if i >= 8 {
|
||||
break
|
||||
}
|
||||
key[i] = reverseBits(c)
|
||||
}
|
||||
block, _ := des.NewCipher(key)
|
||||
out := make([]byte, 16)
|
||||
block.Encrypt(out[:8], challenge[:8])
|
||||
block.Encrypt(out[8:], challenge[8:])
|
||||
return out
|
||||
}
|
||||
|
||||
func reverseBits(b byte) byte {
|
||||
var r byte
|
||||
for range 8 {
|
||||
r = (r << 1) | (b & 1)
|
||||
b >>= 1
|
||||
}
|
||||
return r
|
||||
}
|
||||
|
||||
// encodeZlibRect encodes a framebuffer region using Zlib compression.
|
||||
// The zlib stream is continuous for the entire VNC session: noVNC creates
|
||||
// one inflate context at startup and reuses it for all zlib-encoded rects.
|
||||
// We must NOT reset the zlib writer between calls.
|
||||
func encodeZlibRect(img *image.RGBA, pf clientPixelFormat, x, y, w, h int, zw *zlib.Writer, zbuf *bytes.Buffer) []byte {
|
||||
bytesPerPixel := max(int(pf.bpp)/8, 1)
|
||||
|
||||
// Clear the output buffer but keep the deflate dictionary intact.
|
||||
zbuf.Reset()
|
||||
|
||||
stride := img.Stride
|
||||
pixel := make([]byte, bytesPerPixel)
|
||||
for row := y; row < y+h; row++ {
|
||||
for col := x; col < x+w; col++ {
|
||||
p := row*stride + col*4
|
||||
r, g, b := img.Pix[p], img.Pix[p+1], img.Pix[p+2]
|
||||
|
||||
rv := uint32(r) * uint32(pf.rMax) / 255
|
||||
gv := uint32(g) * uint32(pf.gMax) / 255
|
||||
bv := uint32(b) * uint32(pf.bMax) / 255
|
||||
val := (rv << pf.rShift) | (gv << pf.gShift) | (bv << pf.bShift)
|
||||
|
||||
if pf.bigEndian != 0 {
|
||||
for i := range bytesPerPixel {
|
||||
pixel[i] = byte(val >> uint((bytesPerPixel-1-i)*8))
|
||||
}
|
||||
} else {
|
||||
for i := range bytesPerPixel {
|
||||
pixel[i] = byte(val >> uint(i*8))
|
||||
}
|
||||
}
|
||||
zw.Write(pixel)
|
||||
}
|
||||
}
|
||||
zw.Flush()
|
||||
|
||||
compressed := zbuf.Bytes()
|
||||
|
||||
// Build the FramebufferUpdate message.
|
||||
buf := make([]byte, 4+12+4+len(compressed))
|
||||
buf[0] = serverFramebufferUpdate
|
||||
buf[1] = 0
|
||||
binary.BigEndian.PutUint16(buf[2:4], 1) // 1 rectangle
|
||||
|
||||
binary.BigEndian.PutUint16(buf[4:6], uint16(x))
|
||||
binary.BigEndian.PutUint16(buf[6:8], uint16(y))
|
||||
binary.BigEndian.PutUint16(buf[8:10], uint16(w))
|
||||
binary.BigEndian.PutUint16(buf[10:12], uint16(h))
|
||||
binary.BigEndian.PutUint32(buf[12:16], uint32(encZlib))
|
||||
binary.BigEndian.PutUint32(buf[16:20], uint32(len(compressed)))
|
||||
copy(buf[20:], compressed)
|
||||
|
||||
return buf
|
||||
}
|
||||
|
||||
// diffRects compares two RGBA images and returns a list of dirty rectangles.
|
||||
// Divides the screen into tiles and checks each for changes.
|
||||
func diffRects(prev, cur *image.RGBA, w, h, tileSize int) [][4]int {
|
||||
if prev == nil {
|
||||
return [][4]int{{0, 0, w, h}}
|
||||
}
|
||||
|
||||
var rects [][4]int
|
||||
for ty := 0; ty < h; ty += tileSize {
|
||||
th := min(tileSize, h-ty)
|
||||
for tx := 0; tx < w; tx += tileSize {
|
||||
tw := min(tileSize, w-tx)
|
||||
if tileChanged(prev, cur, tx, ty, tw, th) {
|
||||
rects = append(rects, [4]int{tx, ty, tw, th})
|
||||
}
|
||||
}
|
||||
}
|
||||
return rects
|
||||
}
|
||||
|
||||
func tileChanged(prev, cur *image.RGBA, x, y, w, h int) bool {
|
||||
stride := prev.Stride
|
||||
for row := y; row < y+h; row++ {
|
||||
off := row*stride + x*4
|
||||
end := off + w*4
|
||||
prevRow := prev.Pix[off:end]
|
||||
curRow := cur.Pix[off:end]
|
||||
if !bytes.Equal(prevRow, curRow) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// zlibState holds the persistent zlib writer and buffer for a session.
|
||||
type zlibState struct {
|
||||
buf *bytes.Buffer
|
||||
w *zlib.Writer
|
||||
}
|
||||
|
||||
func newZlibState() *zlibState {
|
||||
buf := &bytes.Buffer{}
|
||||
w, _ := zlib.NewWriterLevel(buf, zlib.BestSpeed)
|
||||
return &zlibState{buf: buf, w: w}
|
||||
}
|
||||
|
||||
func (z *zlibState) Close() error {
|
||||
return z.w.Close()
|
||||
}
|
||||
605
client/vnc/server/server.go
Normal file
605
client/vnc/server/server.go
Normal file
@@ -0,0 +1,605 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/subtle"
|
||||
"encoding/binary"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"image"
|
||||
"io"
|
||||
"net"
|
||||
"net/netip"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
gojwt "github.com/golang-jwt/jwt/v5"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.zx2c4.com/wireguard/tun/netstack"
|
||||
|
||||
sshauth "github.com/netbirdio/netbird/client/ssh/auth"
|
||||
nbjwt "github.com/netbirdio/netbird/shared/auth/jwt"
|
||||
)
|
||||
|
||||
// Connection modes sent by the client in the session header.
|
||||
const (
|
||||
ModeAttach byte = 0 // Capture current display
|
||||
ModeSession byte = 1 // Virtual session as specified user
|
||||
)
|
||||
|
||||
// ScreenCapturer grabs desktop frames for the VNC server.
|
||||
type ScreenCapturer interface {
|
||||
// Width returns the current screen width in pixels.
|
||||
Width() int
|
||||
// Height returns the current screen height in pixels.
|
||||
Height() int
|
||||
// Capture returns the current desktop as an RGBA image.
|
||||
Capture() (*image.RGBA, error)
|
||||
}
|
||||
|
||||
// InputInjector delivers keyboard and mouse events to the OS.
|
||||
type InputInjector interface {
|
||||
// InjectKey simulates a key press or release. keysym is an X11 KeySym.
|
||||
InjectKey(keysym uint32, down bool)
|
||||
// InjectPointer simulates mouse movement and button state.
|
||||
InjectPointer(buttonMask uint8, x, y, serverW, serverH int)
|
||||
// SetClipboard sets the system clipboard to the given text.
|
||||
SetClipboard(text string)
|
||||
// GetClipboard returns the current system clipboard text.
|
||||
GetClipboard() string
|
||||
}
|
||||
|
||||
// JWTConfig holds JWT validation configuration for VNC auth.
|
||||
type JWTConfig struct {
|
||||
Issuer string
|
||||
KeysLocation string
|
||||
MaxTokenAge int64
|
||||
Audiences []string
|
||||
}
|
||||
|
||||
// connectionHeader is sent by the client before the RFB handshake to specify
|
||||
// the VNC session mode and authenticate.
|
||||
type connectionHeader struct {
|
||||
mode byte
|
||||
username string
|
||||
jwt string
|
||||
sessionID uint32 // Windows session ID (0 = console/auto)
|
||||
}
|
||||
|
||||
// Server is the embedded VNC server that listens on the WireGuard interface.
|
||||
// It supports two operating modes:
|
||||
// - Direct mode: captures the screen and handles VNC sessions in-process.
|
||||
// Used when running in a user session with desktop access.
|
||||
// - Service mode: proxies VNC connections to an agent process spawned in
|
||||
// the active console session. Used when running as a Windows service in
|
||||
// Session 0.
|
||||
//
|
||||
// Within direct mode, each connection can request one of two session modes
|
||||
// via the connection header:
|
||||
// - Attach: capture the current physical display.
|
||||
// - Session: start a virtual Xvfb display as the requested user.
|
||||
type Server struct {
|
||||
capturer ScreenCapturer
|
||||
injector InputInjector
|
||||
password string
|
||||
serviceMode bool
|
||||
disableAuth bool
|
||||
localAddr netip.Addr // NetBird WireGuard IP this server is bound to
|
||||
network netip.Prefix // NetBird overlay network
|
||||
log *log.Entry
|
||||
|
||||
mu sync.Mutex
|
||||
listener net.Listener
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
vmgr virtualSessionManager
|
||||
jwtConfig *JWTConfig
|
||||
jwtValidator *nbjwt.Validator
|
||||
jwtExtractor *nbjwt.ClaimsExtractor
|
||||
authorizer *sshauth.Authorizer
|
||||
netstackNet *netstack.Net
|
||||
agentToken []byte // raw token bytes for agent-mode auth
|
||||
}
|
||||
|
||||
// vncSession provides capturer and injector for a virtual display session.
|
||||
type vncSession interface {
|
||||
Capturer() ScreenCapturer
|
||||
Injector() InputInjector
|
||||
Display() string
|
||||
ClientConnect()
|
||||
ClientDisconnect()
|
||||
}
|
||||
|
||||
// virtualSessionManager is implemented by sessionManager on Linux.
|
||||
type virtualSessionManager interface {
|
||||
GetOrCreate(username string) (vncSession, error)
|
||||
StopAll()
|
||||
}
|
||||
|
||||
// New creates a VNC server with the given screen capturer and input injector.
|
||||
func New(capturer ScreenCapturer, injector InputInjector, password string) *Server {
|
||||
return &Server{
|
||||
capturer: capturer,
|
||||
injector: injector,
|
||||
password: password,
|
||||
authorizer: sshauth.NewAuthorizer(),
|
||||
log: log.WithField("component", "vnc-server"),
|
||||
}
|
||||
}
|
||||
|
||||
// SetServiceMode enables proxy-to-agent mode for Windows service operation.
|
||||
func (s *Server) SetServiceMode(enabled bool) {
|
||||
s.serviceMode = enabled
|
||||
}
|
||||
|
||||
// SetJWTConfig configures JWT authentication for VNC connections.
|
||||
// Pass nil to disable JWT (public mode).
|
||||
func (s *Server) SetJWTConfig(config *JWTConfig) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.jwtConfig = config
|
||||
s.jwtValidator = nil
|
||||
s.jwtExtractor = nil
|
||||
}
|
||||
|
||||
// SetDisableAuth disables authentication entirely.
|
||||
func (s *Server) SetDisableAuth(disable bool) {
|
||||
s.disableAuth = disable
|
||||
}
|
||||
|
||||
// SetAgentToken sets a hex-encoded token that must be presented by incoming
|
||||
// connections before any VNC data. Used in agent mode to verify that only the
|
||||
// trusted service process connects.
|
||||
func (s *Server) SetAgentToken(hexToken string) {
|
||||
if hexToken == "" {
|
||||
return
|
||||
}
|
||||
b, err := hex.DecodeString(hexToken)
|
||||
if err != nil {
|
||||
s.log.Warnf("invalid agent token: %v", err)
|
||||
return
|
||||
}
|
||||
s.agentToken = b
|
||||
}
|
||||
|
||||
// SetNetstackNet sets the netstack network for userspace-only listening.
|
||||
// When set, the VNC server listens via netstack instead of a real OS socket.
|
||||
func (s *Server) SetNetstackNet(n *netstack.Net) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.netstackNet = n
|
||||
}
|
||||
|
||||
// UpdateVNCAuth updates the fine-grained authorization configuration.
|
||||
func (s *Server) UpdateVNCAuth(config *sshauth.Config) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.jwtValidator = nil
|
||||
s.jwtExtractor = nil
|
||||
s.authorizer.Update(config)
|
||||
}
|
||||
|
||||
// Start begins listening for VNC connections on the given address.
|
||||
// network is the NetBird overlay prefix used to validate connection sources.
|
||||
func (s *Server) Start(ctx context.Context, addr netip.AddrPort, network netip.Prefix) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
if s.listener != nil {
|
||||
return fmt.Errorf("server already running")
|
||||
}
|
||||
|
||||
s.ctx, s.cancel = context.WithCancel(ctx)
|
||||
s.vmgr = s.platformSessionManager()
|
||||
s.localAddr = addr.Addr()
|
||||
s.network = network
|
||||
|
||||
var listener net.Listener
|
||||
var listenDesc string
|
||||
if s.netstackNet != nil {
|
||||
ln, err := s.netstackNet.ListenTCPAddrPort(addr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("listen on netstack %s: %w", addr, err)
|
||||
}
|
||||
listener = ln
|
||||
listenDesc = fmt.Sprintf("netstack %s", addr)
|
||||
} else {
|
||||
tcpAddr := net.TCPAddrFromAddrPort(addr)
|
||||
ln, err := net.ListenTCP("tcp", tcpAddr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("listen on %s: %w", addr, err)
|
||||
}
|
||||
listener = ln
|
||||
listenDesc = addr.String()
|
||||
}
|
||||
s.listener = listener
|
||||
|
||||
if s.serviceMode {
|
||||
s.platformInit()
|
||||
}
|
||||
|
||||
if s.serviceMode {
|
||||
go s.serviceAcceptLoop()
|
||||
} else {
|
||||
go s.acceptLoop()
|
||||
}
|
||||
|
||||
s.log.Infof("started on %s (service_mode=%v)", listenDesc, s.serviceMode)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Stop shuts down the server and closes all connections.
|
||||
func (s *Server) Stop() error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
if s.cancel != nil {
|
||||
s.cancel()
|
||||
s.cancel = nil
|
||||
}
|
||||
|
||||
if s.vmgr != nil {
|
||||
s.vmgr.StopAll()
|
||||
}
|
||||
|
||||
if c, ok := s.capturer.(interface{ Close() }); ok {
|
||||
c.Close()
|
||||
}
|
||||
|
||||
if s.listener != nil {
|
||||
err := s.listener.Close()
|
||||
s.listener = nil
|
||||
if err != nil {
|
||||
return fmt.Errorf("close VNC listener: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
s.log.Info("stopped")
|
||||
return nil
|
||||
}
|
||||
|
||||
// acceptLoop handles VNC connections directly (user session mode).
|
||||
func (s *Server) acceptLoop() {
|
||||
for {
|
||||
conn, err := s.listener.Accept()
|
||||
if err != nil {
|
||||
select {
|
||||
case <-s.ctx.Done():
|
||||
return
|
||||
default:
|
||||
}
|
||||
s.log.Debugf("accept VNC connection: %v", err)
|
||||
continue
|
||||
}
|
||||
|
||||
go s.handleConnection(conn)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) validateCapturer(cap ScreenCapturer) error {
|
||||
// Quick check first: if already ready, return immediately.
|
||||
if cap.Width() > 0 && cap.Height() > 0 {
|
||||
return nil
|
||||
}
|
||||
// Wait up to 5s for the capturer to become ready.
|
||||
for range 50 {
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
if cap.Width() > 0 && cap.Height() > 0 {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
return fmt.Errorf("no display available (X11 not running or screen recording not permitted)")
|
||||
}
|
||||
|
||||
// isAllowedSource rejects connections from outside the NetBird overlay network
|
||||
// and from the local WireGuard IP (prevents local privilege escalation).
|
||||
// Matches the SSH server's connectionValidator logic.
|
||||
func (s *Server) isAllowedSource(addr net.Addr) bool {
|
||||
tcpAddr, ok := addr.(*net.TCPAddr)
|
||||
if !ok {
|
||||
s.log.Warnf("connection rejected: non-TCP address %s", addr)
|
||||
return false
|
||||
}
|
||||
|
||||
remoteIP, ok := netip.AddrFromSlice(tcpAddr.IP)
|
||||
if !ok {
|
||||
s.log.Warnf("connection rejected: invalid remote IP %s", tcpAddr.IP)
|
||||
return false
|
||||
}
|
||||
remoteIP = remoteIP.Unmap()
|
||||
|
||||
if remoteIP.IsLoopback() && s.localAddr.IsLoopback() {
|
||||
return true
|
||||
}
|
||||
|
||||
if remoteIP == s.localAddr {
|
||||
s.log.Warnf("connection rejected from own IP %s", remoteIP)
|
||||
return false
|
||||
}
|
||||
|
||||
if s.network.IsValid() && !s.network.Contains(remoteIP) {
|
||||
s.log.Warnf("connection rejected from non-NetBird IP %s", remoteIP)
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func (s *Server) handleConnection(conn net.Conn) {
|
||||
connLog := s.log.WithField("remote", conn.RemoteAddr().String())
|
||||
|
||||
if !s.isAllowedSource(conn.RemoteAddr()) {
|
||||
conn.Close()
|
||||
return
|
||||
}
|
||||
|
||||
if len(s.agentToken) > 0 {
|
||||
buf := make([]byte, len(s.agentToken))
|
||||
if err := conn.SetReadDeadline(time.Now().Add(5 * time.Second)); err != nil {
|
||||
connLog.Debugf("set agent token deadline: %v", err)
|
||||
conn.Close()
|
||||
return
|
||||
}
|
||||
if _, err := io.ReadFull(conn, buf); err != nil {
|
||||
connLog.Warnf("agent auth: read token: %v", err)
|
||||
conn.Close()
|
||||
return
|
||||
}
|
||||
conn.SetReadDeadline(time.Time{}) //nolint:errcheck
|
||||
if subtle.ConstantTimeCompare(buf, s.agentToken) != 1 {
|
||||
connLog.Warn("agent auth: invalid token, rejecting")
|
||||
conn.Close()
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
header, err := readConnectionHeader(conn)
|
||||
if err != nil {
|
||||
connLog.Warnf("read connection header: %v", err)
|
||||
conn.Close()
|
||||
return
|
||||
}
|
||||
|
||||
if !s.disableAuth {
|
||||
if s.jwtConfig == nil {
|
||||
rejectConnection(conn, "auth enabled but no identity provider configured")
|
||||
connLog.Warn("auth rejected: no identity provider configured")
|
||||
return
|
||||
}
|
||||
jwtUserID, err := s.authenticateJWT(header)
|
||||
if err != nil {
|
||||
rejectConnection(conn, fmt.Sprintf("auth: %v", err))
|
||||
connLog.Warnf("auth rejected: %v", err)
|
||||
return
|
||||
}
|
||||
connLog = connLog.WithField("jwt_user", jwtUserID)
|
||||
}
|
||||
|
||||
var capturer ScreenCapturer
|
||||
var injector InputInjector
|
||||
|
||||
switch header.mode {
|
||||
case ModeSession:
|
||||
if s.vmgr == nil {
|
||||
rejectConnection(conn, "virtual sessions not supported on this platform")
|
||||
connLog.Warn("session rejected: not supported on this platform")
|
||||
return
|
||||
}
|
||||
if header.username == "" {
|
||||
rejectConnection(conn, "session mode requires a username")
|
||||
connLog.Warn("session rejected: no username provided")
|
||||
return
|
||||
}
|
||||
vs, err := s.vmgr.GetOrCreate(header.username)
|
||||
if err != nil {
|
||||
rejectConnection(conn, fmt.Sprintf("create virtual session: %v", err))
|
||||
connLog.Warnf("create virtual session for %s: %v", header.username, err)
|
||||
return
|
||||
}
|
||||
capturer = vs.Capturer()
|
||||
injector = vs.Injector()
|
||||
vs.ClientConnect()
|
||||
defer vs.ClientDisconnect()
|
||||
connLog = connLog.WithField("vnc_user", header.username)
|
||||
connLog.Infof("session mode: user=%s display=%s", header.username, vs.Display())
|
||||
|
||||
default:
|
||||
capturer = s.capturer
|
||||
injector = s.injector
|
||||
if cc, ok := capturer.(interface{ ClientConnect() }); ok {
|
||||
cc.ClientConnect()
|
||||
}
|
||||
defer func() {
|
||||
if cd, ok := capturer.(interface{ ClientDisconnect() }); ok {
|
||||
cd.ClientDisconnect()
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
if err := s.validateCapturer(capturer); err != nil {
|
||||
rejectConnection(conn, fmt.Sprintf("screen capturer: %v", err))
|
||||
connLog.Warnf("capturer not ready: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
sess := &session{
|
||||
conn: conn,
|
||||
capturer: capturer,
|
||||
injector: injector,
|
||||
serverW: capturer.Width(),
|
||||
serverH: capturer.Height(),
|
||||
password: s.password,
|
||||
log: connLog,
|
||||
}
|
||||
sess.serve()
|
||||
}
|
||||
|
||||
// rejectConnection sends a minimal RFB handshake with a security failure
|
||||
// reason, so VNC clients display the error message instead of a generic
|
||||
// "unexpected disconnect."
|
||||
func rejectConnection(conn net.Conn, reason string) {
|
||||
defer conn.Close()
|
||||
// RFB 3.8 server version.
|
||||
io.WriteString(conn, "RFB 003.008\n")
|
||||
// Read client version (12 bytes), ignore errors.
|
||||
var clientVer [12]byte
|
||||
conn.SetReadDeadline(time.Now().Add(2 * time.Second))
|
||||
io.ReadFull(conn, clientVer[:])
|
||||
conn.SetReadDeadline(time.Time{})
|
||||
// Send 0 security types = connection failed, followed by reason.
|
||||
msg := []byte(reason)
|
||||
buf := make([]byte, 1+4+len(msg))
|
||||
buf[0] = 0 // 0 security types = failure
|
||||
binary.BigEndian.PutUint32(buf[1:5], uint32(len(msg)))
|
||||
copy(buf[5:], msg)
|
||||
conn.Write(buf)
|
||||
}
|
||||
|
||||
const defaultJWTMaxTokenAge = 10 * 60 // 10 minutes
|
||||
|
||||
// authenticateJWT validates the JWT from the connection header and checks
|
||||
// authorization. For attach mode, just checks membership in the authorized
|
||||
// user list. For session mode, additionally validates the OS user mapping.
|
||||
func (s *Server) authenticateJWT(header *connectionHeader) (string, error) {
|
||||
if header.jwt == "" {
|
||||
return "", fmt.Errorf("JWT required but not provided")
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
if err := s.ensureJWTValidator(); err != nil {
|
||||
s.mu.Unlock()
|
||||
return "", fmt.Errorf("initialize JWT validator: %w", err)
|
||||
}
|
||||
validator := s.jwtValidator
|
||||
extractor := s.jwtExtractor
|
||||
s.mu.Unlock()
|
||||
|
||||
token, err := validator.ValidateAndParse(context.Background(), header.jwt)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("validate JWT: %w", err)
|
||||
}
|
||||
|
||||
if err := s.checkTokenAge(token); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
userAuth, err := extractor.ToUserAuth(token)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("extract user from JWT: %w", err)
|
||||
}
|
||||
if userAuth.UserId == "" {
|
||||
return "", fmt.Errorf("JWT has no user ID")
|
||||
}
|
||||
|
||||
switch header.mode {
|
||||
case ModeSession:
|
||||
// Session mode: check user + OS username mapping.
|
||||
if _, err := s.authorizer.Authorize(userAuth.UserId, header.username); err != nil {
|
||||
return "", fmt.Errorf("authorize session for %s: %w", header.username, err)
|
||||
}
|
||||
default:
|
||||
// Attach mode: just check user is in the authorized list (wildcard OS user).
|
||||
if _, err := s.authorizer.Authorize(userAuth.UserId, "*"); err != nil {
|
||||
return "", fmt.Errorf("user not authorized for VNC: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return userAuth.UserId, nil
|
||||
}
|
||||
|
||||
// ensureJWTValidator lazily initializes the JWT validator. Must be called with mu held.
|
||||
func (s *Server) ensureJWTValidator() error {
|
||||
if s.jwtValidator != nil && s.jwtExtractor != nil {
|
||||
return nil
|
||||
}
|
||||
if s.jwtConfig == nil {
|
||||
return fmt.Errorf("no JWT config")
|
||||
}
|
||||
|
||||
s.jwtValidator = nbjwt.NewValidator(
|
||||
s.jwtConfig.Issuer,
|
||||
s.jwtConfig.Audiences,
|
||||
s.jwtConfig.KeysLocation,
|
||||
false,
|
||||
)
|
||||
|
||||
opts := []nbjwt.ClaimsExtractorOption{nbjwt.WithAudience(s.jwtConfig.Audiences[0])}
|
||||
if claim := s.authorizer.GetUserIDClaim(); claim != "" {
|
||||
opts = append(opts, nbjwt.WithUserIDClaim(claim))
|
||||
}
|
||||
s.jwtExtractor = nbjwt.NewClaimsExtractor(opts...)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Server) checkTokenAge(token *gojwt.Token) error {
|
||||
maxAge := defaultJWTMaxTokenAge
|
||||
if s.jwtConfig != nil && s.jwtConfig.MaxTokenAge > 0 {
|
||||
maxAge = int(s.jwtConfig.MaxTokenAge)
|
||||
}
|
||||
return nbjwt.CheckTokenAge(token, time.Duration(maxAge)*time.Second)
|
||||
}
|
||||
|
||||
// readConnectionHeader reads the NetBird VNC session header from the connection.
|
||||
// Format: [mode: 1 byte] [username_len: 2 bytes BE] [username: N bytes]
|
||||
//
|
||||
// [jwt_len: 2 bytes BE] [jwt: N bytes]
|
||||
//
|
||||
// Uses a short timeout: our WASM proxy sends the header immediately after
|
||||
// connecting. Standard VNC clients don't send anything first (server speaks
|
||||
// first in RFB), so they time out and get the default attach mode.
|
||||
func readConnectionHeader(conn net.Conn) (*connectionHeader, error) {
|
||||
if err := conn.SetReadDeadline(time.Now().Add(2 * time.Second)); err != nil {
|
||||
return nil, fmt.Errorf("set deadline: %w", err)
|
||||
}
|
||||
defer conn.SetReadDeadline(time.Time{}) //nolint:errcheck
|
||||
|
||||
var hdr [3]byte
|
||||
if _, err := io.ReadFull(conn, hdr[:]); err != nil {
|
||||
// Timeout or error: assume no header, use attach mode.
|
||||
return &connectionHeader{mode: ModeAttach}, nil
|
||||
}
|
||||
|
||||
// Restore a longer deadline for reading variable-length fields.
|
||||
if err := conn.SetReadDeadline(time.Now().Add(5 * time.Second)); err != nil {
|
||||
return nil, fmt.Errorf("set deadline: %w", err)
|
||||
}
|
||||
|
||||
mode := hdr[0]
|
||||
usernameLen := binary.BigEndian.Uint16(hdr[1:3])
|
||||
|
||||
var username string
|
||||
if usernameLen > 0 {
|
||||
if usernameLen > 256 {
|
||||
return nil, fmt.Errorf("username too long: %d", usernameLen)
|
||||
}
|
||||
buf := make([]byte, usernameLen)
|
||||
if _, err := io.ReadFull(conn, buf); err != nil {
|
||||
return nil, fmt.Errorf("read username: %w", err)
|
||||
}
|
||||
username = string(buf)
|
||||
}
|
||||
|
||||
// Read JWT token length and data.
|
||||
var jwtLenBuf [2]byte
|
||||
var jwtToken string
|
||||
if _, err := io.ReadFull(conn, jwtLenBuf[:]); err == nil {
|
||||
jwtLen := binary.BigEndian.Uint16(jwtLenBuf[:])
|
||||
if jwtLen > 0 && jwtLen < 8192 {
|
||||
buf := make([]byte, jwtLen)
|
||||
if _, err := io.ReadFull(conn, buf); err != nil {
|
||||
return nil, fmt.Errorf("read JWT: %w", err)
|
||||
}
|
||||
jwtToken = string(buf)
|
||||
}
|
||||
}
|
||||
|
||||
// Read optional Windows session ID (4 bytes BE). Missing = 0 (console/auto).
|
||||
var sessionID uint32
|
||||
var sidBuf [4]byte
|
||||
if _, err := io.ReadFull(conn, sidBuf[:]); err == nil {
|
||||
sessionID = binary.BigEndian.Uint32(sidBuf[:])
|
||||
}
|
||||
|
||||
return &connectionHeader{mode: mode, username: username, jwt: jwtToken, sessionID: sessionID}, nil
|
||||
}
|
||||
15
client/vnc/server/server_darwin.go
Normal file
15
client/vnc/server/server_darwin.go
Normal file
@@ -0,0 +1,15 @@
|
||||
//go:build darwin && !ios
|
||||
|
||||
package server
|
||||
|
||||
func (s *Server) platformInit() {}
|
||||
|
||||
// serviceAcceptLoop is not supported on macOS.
|
||||
func (s *Server) serviceAcceptLoop() {
|
||||
s.log.Warn("service mode not supported on macOS, falling back to direct mode")
|
||||
s.acceptLoop()
|
||||
}
|
||||
|
||||
func (s *Server) platformSessionManager() virtualSessionManager {
|
||||
return nil
|
||||
}
|
||||
15
client/vnc/server/server_stub.go
Normal file
15
client/vnc/server/server_stub.go
Normal file
@@ -0,0 +1,15 @@
|
||||
//go:build !windows && !darwin && !freebsd && !(linux && !android)
|
||||
|
||||
package server
|
||||
|
||||
func (s *Server) platformInit() {}
|
||||
|
||||
// serviceAcceptLoop is not supported on non-Windows platforms.
|
||||
func (s *Server) serviceAcceptLoop() {
|
||||
s.log.Warn("service mode not supported on this platform, falling back to direct mode")
|
||||
s.acceptLoop()
|
||||
}
|
||||
|
||||
func (s *Server) platformSessionManager() virtualSessionManager {
|
||||
return nil
|
||||
}
|
||||
136
client/vnc/server/server_test.go
Normal file
136
client/vnc/server/server_test.go
Normal file
@@ -0,0 +1,136 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"image"
|
||||
"io"
|
||||
"net"
|
||||
"net/netip"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// testCapturer returns a 100x100 image for test sessions.
|
||||
type testCapturer struct{}
|
||||
|
||||
func (t *testCapturer) Width() int { return 100 }
|
||||
func (t *testCapturer) Height() int { return 100 }
|
||||
func (t *testCapturer) Capture() (*image.RGBA, error) { return image.NewRGBA(image.Rect(0, 0, 100, 100)), nil }
|
||||
|
||||
func startTestServer(t *testing.T, disableAuth bool, jwtConfig *JWTConfig) (net.Addr, *Server) {
|
||||
t.Helper()
|
||||
|
||||
srv := New(&testCapturer{}, &StubInputInjector{}, "")
|
||||
srv.SetDisableAuth(disableAuth)
|
||||
if jwtConfig != nil {
|
||||
srv.SetJWTConfig(jwtConfig)
|
||||
}
|
||||
|
||||
addr := netip.MustParseAddrPort("127.0.0.1:0")
|
||||
network := netip.MustParsePrefix("127.0.0.0/8")
|
||||
require.NoError(t, srv.Start(t.Context(), addr, network))
|
||||
// Override local address so source validation doesn't reject 127.0.0.1 as "own IP".
|
||||
srv.localAddr = netip.MustParseAddr("10.99.99.1")
|
||||
t.Cleanup(func() { _ = srv.Stop() })
|
||||
|
||||
return srv.listener.Addr(), srv
|
||||
}
|
||||
|
||||
func TestAuthEnabled_NoJWTConfig_RejectsConnection(t *testing.T) {
|
||||
addr, _ := startTestServer(t, false, nil)
|
||||
|
||||
conn, err := net.Dial("tcp", addr.String())
|
||||
require.NoError(t, err)
|
||||
defer conn.Close()
|
||||
|
||||
// Send session header: attach mode, no username, no JWT.
|
||||
header := []byte{ModeAttach, 0, 0, 0, 0}
|
||||
_, err = conn.Write(header)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Server should send RFB version then security failure.
|
||||
var version [12]byte
|
||||
_, err = io.ReadFull(conn, version[:])
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "RFB 003.008\n", string(version[:]))
|
||||
|
||||
// Write client version to proceed through handshake.
|
||||
_, err = conn.Write(version[:])
|
||||
require.NoError(t, err)
|
||||
|
||||
// Read security types: 0 means failure, followed by reason.
|
||||
var numTypes [1]byte
|
||||
_, err = io.ReadFull(conn, numTypes[:])
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, byte(0), numTypes[0], "should have 0 security types (failure)")
|
||||
|
||||
var reasonLen [4]byte
|
||||
_, err = io.ReadFull(conn, reasonLen[:])
|
||||
require.NoError(t, err)
|
||||
|
||||
reason := make([]byte, binary.BigEndian.Uint32(reasonLen[:]))
|
||||
_, err = io.ReadFull(conn, reason)
|
||||
require.NoError(t, err)
|
||||
assert.Contains(t, string(reason), "identity provider", "rejection reason should mention missing IdP config")
|
||||
}
|
||||
|
||||
func TestAuthDisabled_AllowsConnection(t *testing.T) {
|
||||
addr, _ := startTestServer(t, true, nil)
|
||||
|
||||
conn, err := net.Dial("tcp", addr.String())
|
||||
require.NoError(t, err)
|
||||
defer conn.Close()
|
||||
|
||||
// Send session header: attach mode, no username, no JWT.
|
||||
header := []byte{ModeAttach, 0, 0, 0, 0}
|
||||
_, err = conn.Write(header)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Server should send RFB version.
|
||||
var version [12]byte
|
||||
_, err = io.ReadFull(conn, version[:])
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "RFB 003.008\n", string(version[:]))
|
||||
|
||||
// Write client version.
|
||||
_, err = conn.Write(version[:])
|
||||
require.NoError(t, err)
|
||||
|
||||
// Should get security types (not 0 = failure).
|
||||
var numTypes [1]byte
|
||||
_, err = io.ReadFull(conn, numTypes[:])
|
||||
require.NoError(t, err)
|
||||
assert.NotEqual(t, byte(0), numTypes[0], "should have at least one security type (auth disabled)")
|
||||
}
|
||||
|
||||
func TestAuthEnabled_EmptyJWT_Rejected(t *testing.T) {
|
||||
// Auth enabled with a (bogus) JWT config: connections without JWT should be rejected.
|
||||
addr, _ := startTestServer(t, false, &JWTConfig{
|
||||
Issuer: "https://example.com",
|
||||
KeysLocation: "https://example.com/.well-known/jwks.json",
|
||||
Audiences: []string{"test"},
|
||||
})
|
||||
|
||||
conn, err := net.Dial("tcp", addr.String())
|
||||
require.NoError(t, err)
|
||||
defer conn.Close()
|
||||
|
||||
// Send session header with empty JWT.
|
||||
header := []byte{ModeAttach, 0, 0, 0, 0}
|
||||
_, err = conn.Write(header)
|
||||
require.NoError(t, err)
|
||||
|
||||
var version [12]byte
|
||||
_, err = io.ReadFull(conn, version[:])
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = conn.Write(version[:])
|
||||
require.NoError(t, err)
|
||||
|
||||
var numTypes [1]byte
|
||||
_, err = io.ReadFull(conn, numTypes[:])
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, byte(0), numTypes[0], "should reject with 0 security types")
|
||||
}
|
||||
223
client/vnc/server/server_windows.go
Normal file
223
client/vnc/server/server_windows.go
Normal file
@@ -0,0 +1,223 @@
|
||||
//go:build windows
|
||||
|
||||
package server
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"unsafe"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/sys/windows"
|
||||
"golang.org/x/sys/windows/registry"
|
||||
)
|
||||
|
||||
var (
|
||||
sasDLL = windows.NewLazySystemDLL("sas.dll")
|
||||
procSendSAS = sasDLL.NewProc("SendSAS")
|
||||
|
||||
procConvertStringSecurityDescriptorToSecurityDescriptor = advapi32.NewProc("ConvertStringSecurityDescriptorToSecurityDescriptorW")
|
||||
)
|
||||
|
||||
// sasSecurityAttributes builds a SECURITY_ATTRIBUTES that grants
|
||||
// EVENT_MODIFY_STATE only to the SYSTEM account, preventing unprivileged
|
||||
// local processes from triggering the Secure Attention Sequence.
|
||||
func sasSecurityAttributes() (*windows.SecurityAttributes, error) {
|
||||
// SDDL: grant full access to SYSTEM (creates/waits) and EVENT_MODIFY_STATE
|
||||
// to the interactive user (IU) so the VNC agent in the console session can
|
||||
// signal it. Other local users and network users are denied.
|
||||
sddl, err := windows.UTF16PtrFromString("D:(A;;GA;;;SY)(A;;0x0002;;;IU)")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var sd uintptr
|
||||
r, _, lerr := procConvertStringSecurityDescriptorToSecurityDescriptor.Call(
|
||||
uintptr(unsafe.Pointer(sddl)),
|
||||
1, // SDDL_REVISION_1
|
||||
uintptr(unsafe.Pointer(&sd)),
|
||||
0,
|
||||
)
|
||||
if r == 0 {
|
||||
return nil, lerr
|
||||
}
|
||||
return &windows.SecurityAttributes{
|
||||
Length: uint32(unsafe.Sizeof(windows.SecurityAttributes{})),
|
||||
SecurityDescriptor: (*windows.SECURITY_DESCRIPTOR)(unsafe.Pointer(sd)),
|
||||
InheritHandle: 0,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// enableSoftwareSAS sets the SoftwareSASGeneration registry key to allow
|
||||
// services to trigger the Secure Attention Sequence via SendSAS. Without this,
|
||||
// SendSAS silently does nothing on most Windows editions.
|
||||
func enableSoftwareSAS() {
|
||||
key, _, err := registry.CreateKey(
|
||||
registry.LOCAL_MACHINE,
|
||||
`SOFTWARE\Microsoft\Windows\CurrentVersion\Policies\System`,
|
||||
registry.SET_VALUE,
|
||||
)
|
||||
if err != nil {
|
||||
log.Warnf("open SoftwareSASGeneration registry key: %v", err)
|
||||
return
|
||||
}
|
||||
defer key.Close()
|
||||
|
||||
if err := key.SetDWordValue("SoftwareSASGeneration", 1); err != nil {
|
||||
log.Warnf("set SoftwareSASGeneration: %v", err)
|
||||
return
|
||||
}
|
||||
log.Debug("SoftwareSASGeneration registry key set to 1 (services allowed)")
|
||||
}
|
||||
|
||||
// startSASListener creates a named event with a restricted DACL and waits for
|
||||
// the VNC input injector to signal it. When signaled, it calls SendSAS(FALSE)
|
||||
// from Session 0 to trigger the Secure Attention Sequence (Ctrl+Alt+Del).
|
||||
// Only SYSTEM processes can open the event.
|
||||
func startSASListener() {
|
||||
enableSoftwareSAS()
|
||||
namePtr, err := windows.UTF16PtrFromString(sasEventName)
|
||||
if err != nil {
|
||||
log.Warnf("SAS listener UTF16: %v", err)
|
||||
return
|
||||
}
|
||||
sa, err := sasSecurityAttributes()
|
||||
if err != nil {
|
||||
log.Warnf("build SAS security descriptor: %v", err)
|
||||
return
|
||||
}
|
||||
ev, err := windows.CreateEvent(sa, 0, 0, namePtr)
|
||||
if err != nil {
|
||||
log.Warnf("SAS CreateEvent: %v", err)
|
||||
return
|
||||
}
|
||||
log.Info("SAS listener ready (Session 0)")
|
||||
go func() {
|
||||
defer windows.CloseHandle(ev)
|
||||
for {
|
||||
ret, _ := windows.WaitForSingleObject(ev, windows.INFINITE)
|
||||
if ret == windows.WAIT_OBJECT_0 {
|
||||
r, _, sasErr := procSendSAS.Call(0) // FALSE = not from service desktop
|
||||
if r == 0 {
|
||||
log.Warnf("SendSAS: %v", sasErr)
|
||||
} else {
|
||||
log.Info("SendSAS called from Session 0")
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// enablePrivilege enables a named privilege on the current process token.
|
||||
func enablePrivilege(name string) error {
|
||||
var token windows.Token
|
||||
if err := windows.OpenProcessToken(windows.CurrentProcess(),
|
||||
windows.TOKEN_ADJUST_PRIVILEGES|windows.TOKEN_QUERY, &token); err != nil {
|
||||
return err
|
||||
}
|
||||
defer token.Close()
|
||||
|
||||
var luid windows.LUID
|
||||
namePtr, _ := windows.UTF16PtrFromString(name)
|
||||
if err := windows.LookupPrivilegeValue(nil, namePtr, &luid); err != nil {
|
||||
return err
|
||||
}
|
||||
tp := windows.Tokenprivileges{PrivilegeCount: 1}
|
||||
tp.Privileges[0].Luid = luid
|
||||
tp.Privileges[0].Attributes = windows.SE_PRIVILEGE_ENABLED
|
||||
return windows.AdjustTokenPrivileges(token, false, &tp, 0, nil, nil)
|
||||
}
|
||||
|
||||
func (s *Server) platformSessionManager() virtualSessionManager {
|
||||
return nil
|
||||
}
|
||||
|
||||
// platformInit starts the SAS listener and enables privileges needed for
|
||||
// Session 0 operations (agent spawning, SendSAS).
|
||||
func (s *Server) platformInit() {
|
||||
for _, priv := range []string{"SeTcbPrivilege", "SeAssignPrimaryTokenPrivilege"} {
|
||||
if err := enablePrivilege(priv); err != nil {
|
||||
log.Debugf("enable %s: %v", priv, err)
|
||||
}
|
||||
}
|
||||
startSASListener()
|
||||
}
|
||||
|
||||
// serviceAcceptLoop runs in Session 0. It validates source IP and
|
||||
// authenticates via JWT before proxying connections to the user-session agent.
|
||||
func (s *Server) serviceAcceptLoop() {
|
||||
|
||||
sm := newSessionManager(agentPort)
|
||||
go sm.run()
|
||||
|
||||
log.Infof("service mode, proxying connections to agent on 127.0.0.1:%s", agentPort)
|
||||
|
||||
for {
|
||||
conn, err := s.listener.Accept()
|
||||
if err != nil {
|
||||
select {
|
||||
case <-s.ctx.Done():
|
||||
sm.Stop()
|
||||
return
|
||||
default:
|
||||
}
|
||||
s.log.Debugf("accept VNC connection: %v", err)
|
||||
continue
|
||||
}
|
||||
|
||||
go s.handleServiceConnection(conn, sm)
|
||||
}
|
||||
}
|
||||
|
||||
// handleServiceConnection validates the source IP and JWT, then proxies
|
||||
// the connection (with header bytes replayed) to the agent.
|
||||
func (s *Server) handleServiceConnection(conn net.Conn, sm *sessionManager) {
|
||||
connLog := s.log.WithField("remote", conn.RemoteAddr().String())
|
||||
|
||||
if !s.isAllowedSource(conn.RemoteAddr()) {
|
||||
conn.Close()
|
||||
return
|
||||
}
|
||||
|
||||
var headerBuf bytes.Buffer
|
||||
tee := io.TeeReader(conn, &headerBuf)
|
||||
teeConn := &prefixConn{Reader: tee, Conn: conn}
|
||||
|
||||
header, err := readConnectionHeader(teeConn)
|
||||
if err != nil {
|
||||
connLog.Debugf("read connection header: %v", err)
|
||||
conn.Close()
|
||||
return
|
||||
}
|
||||
|
||||
if !s.disableAuth {
|
||||
if s.jwtConfig == nil {
|
||||
rejectConnection(conn, "auth enabled but no identity provider configured")
|
||||
connLog.Warn("auth rejected: no identity provider configured")
|
||||
return
|
||||
}
|
||||
if _, err := s.authenticateJWT(header); err != nil {
|
||||
rejectConnection(conn, fmt.Sprintf("auth: %v", err))
|
||||
connLog.Warnf("auth rejected: %v", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Replay buffered header bytes + remaining stream to the agent.
|
||||
replayConn := &prefixConn{
|
||||
Reader: io.MultiReader(&headerBuf, conn),
|
||||
Conn: conn,
|
||||
}
|
||||
proxyToAgent(replayConn, agentPort, sm.AuthToken())
|
||||
}
|
||||
|
||||
// prefixConn wraps a net.Conn, overriding Read to use a different reader.
|
||||
type prefixConn struct {
|
||||
io.Reader
|
||||
net.Conn
|
||||
}
|
||||
|
||||
func (p *prefixConn) Read(b []byte) (int, error) {
|
||||
return p.Reader.Read(b)
|
||||
}
|
||||
15
client/vnc/server/server_x11.go
Normal file
15
client/vnc/server/server_x11.go
Normal file
@@ -0,0 +1,15 @@
|
||||
//go:build (linux && !android) || freebsd
|
||||
|
||||
package server
|
||||
|
||||
func (s *Server) platformInit() {}
|
||||
|
||||
// serviceAcceptLoop is not supported on Linux.
|
||||
func (s *Server) serviceAcceptLoop() {
|
||||
s.log.Warn("service mode not supported on Linux, falling back to direct mode")
|
||||
s.acceptLoop()
|
||||
}
|
||||
|
||||
func (s *Server) platformSessionManager() virtualSessionManager {
|
||||
return newSessionManager(s.log)
|
||||
}
|
||||
443
client/vnc/server/session.go
Normal file
443
client/vnc/server/session.go
Normal file
@@ -0,0 +1,443 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/rand"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"image"
|
||||
"io"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
readDeadline = 60 * time.Second
|
||||
maxCutTextBytes = 1 << 20 // 1 MiB
|
||||
)
|
||||
|
||||
const tileSize = 64 // pixels per tile for dirty-rect detection
|
||||
|
||||
type session struct {
|
||||
conn net.Conn
|
||||
capturer ScreenCapturer
|
||||
injector InputInjector
|
||||
serverW int
|
||||
serverH int
|
||||
password string
|
||||
log *log.Entry
|
||||
|
||||
writeMu sync.Mutex
|
||||
pf clientPixelFormat
|
||||
useZlib bool
|
||||
zlib *zlibState
|
||||
prevFrame *image.RGBA
|
||||
idleFrames int
|
||||
}
|
||||
|
||||
func (s *session) addr() string { return s.conn.RemoteAddr().String() }
|
||||
|
||||
// serve runs the full RFB session lifecycle.
|
||||
func (s *session) serve() {
|
||||
defer s.conn.Close()
|
||||
s.pf = defaultClientPixelFormat()
|
||||
|
||||
if err := s.handshake(); err != nil {
|
||||
s.log.Warnf("handshake with %s: %v", s.addr(), err)
|
||||
return
|
||||
}
|
||||
s.log.Infof("client connected: %s", s.addr())
|
||||
|
||||
done := make(chan struct{})
|
||||
defer close(done)
|
||||
go s.clipboardPoll(done)
|
||||
|
||||
if err := s.messageLoop(); err != nil && err != io.EOF {
|
||||
s.log.Warnf("client %s disconnected: %v", s.addr(), err)
|
||||
} else {
|
||||
s.log.Infof("client disconnected: %s", s.addr())
|
||||
}
|
||||
}
|
||||
|
||||
// clipboardPoll periodically checks the server-side clipboard and sends
|
||||
// changes to the VNC client. Only runs during active sessions.
|
||||
func (s *session) clipboardPoll(done <-chan struct{}) {
|
||||
ticker := time.NewTicker(2 * time.Second)
|
||||
defer ticker.Stop()
|
||||
|
||||
var lastClip string
|
||||
for {
|
||||
select {
|
||||
case <-done:
|
||||
return
|
||||
case <-ticker.C:
|
||||
text := s.injector.GetClipboard()
|
||||
if len(text) > maxCutTextBytes {
|
||||
text = text[:maxCutTextBytes]
|
||||
}
|
||||
if text != "" && text != lastClip {
|
||||
lastClip = text
|
||||
if err := s.sendServerCutText(text); err != nil {
|
||||
s.log.Debugf("send clipboard to client: %v", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *session) handshake() error {
|
||||
// Send protocol version.
|
||||
if _, err := io.WriteString(s.conn, rfbProtocolVersion); err != nil {
|
||||
return fmt.Errorf("send version: %w", err)
|
||||
}
|
||||
|
||||
// Read client version.
|
||||
var clientVer [12]byte
|
||||
if _, err := io.ReadFull(s.conn, clientVer[:]); err != nil {
|
||||
return fmt.Errorf("read client version: %w", err)
|
||||
}
|
||||
|
||||
// Send supported security types.
|
||||
if err := s.sendSecurityTypes(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Read chosen security type.
|
||||
var secType [1]byte
|
||||
if _, err := io.ReadFull(s.conn, secType[:]); err != nil {
|
||||
return fmt.Errorf("read security type: %w", err)
|
||||
}
|
||||
|
||||
if err := s.handleSecurity(secType[0]); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Read ClientInit.
|
||||
var clientInit [1]byte
|
||||
if _, err := io.ReadFull(s.conn, clientInit[:]); err != nil {
|
||||
return fmt.Errorf("read ClientInit: %w", err)
|
||||
}
|
||||
|
||||
return s.sendServerInit()
|
||||
}
|
||||
|
||||
func (s *session) sendSecurityTypes() error {
|
||||
if s.password == "" {
|
||||
_, err := s.conn.Write([]byte{1, secNone})
|
||||
return err
|
||||
}
|
||||
_, err := s.conn.Write([]byte{1, secVNCAuth})
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *session) handleSecurity(secType byte) error {
|
||||
switch secType {
|
||||
case secVNCAuth:
|
||||
return s.doVNCAuth()
|
||||
case secNone:
|
||||
return binary.Write(s.conn, binary.BigEndian, uint32(0))
|
||||
default:
|
||||
return fmt.Errorf("unsupported security type: %d", secType)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *session) doVNCAuth() error {
|
||||
challenge := make([]byte, 16)
|
||||
if _, err := rand.Read(challenge); err != nil {
|
||||
return fmt.Errorf("generate challenge: %w", err)
|
||||
}
|
||||
if _, err := s.conn.Write(challenge); err != nil {
|
||||
return fmt.Errorf("send challenge: %w", err)
|
||||
}
|
||||
|
||||
response := make([]byte, 16)
|
||||
if _, err := io.ReadFull(s.conn, response); err != nil {
|
||||
return fmt.Errorf("read auth response: %w", err)
|
||||
}
|
||||
|
||||
var result uint32
|
||||
if s.password != "" {
|
||||
expected := vncAuthEncrypt(challenge, s.password)
|
||||
if !bytes.Equal(expected, response) {
|
||||
result = 1
|
||||
}
|
||||
}
|
||||
|
||||
if err := binary.Write(s.conn, binary.BigEndian, result); err != nil {
|
||||
return fmt.Errorf("send auth result: %w", err)
|
||||
}
|
||||
if result != 0 {
|
||||
msg := "authentication failed"
|
||||
_ = binary.Write(s.conn, binary.BigEndian, uint32(len(msg)))
|
||||
_, _ = s.conn.Write([]byte(msg))
|
||||
return fmt.Errorf("authentication failed from %s", s.addr())
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *session) sendServerInit() error {
|
||||
name := []byte("NetBird VNC")
|
||||
buf := make([]byte, 0, 4+16+4+len(name))
|
||||
|
||||
// Framebuffer width and height.
|
||||
buf = append(buf, byte(s.serverW>>8), byte(s.serverW))
|
||||
buf = append(buf, byte(s.serverH>>8), byte(s.serverH))
|
||||
|
||||
// Server pixel format.
|
||||
buf = append(buf, serverPixelFormat[:]...)
|
||||
|
||||
// Desktop name.
|
||||
buf = append(buf,
|
||||
byte(len(name)>>24), byte(len(name)>>16),
|
||||
byte(len(name)>>8), byte(len(name)),
|
||||
)
|
||||
buf = append(buf, name...)
|
||||
|
||||
_, err := s.conn.Write(buf)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *session) messageLoop() error {
|
||||
for {
|
||||
var msgType [1]byte
|
||||
if err := s.conn.SetDeadline(time.Now().Add(readDeadline)); err != nil {
|
||||
return fmt.Errorf("set deadline: %w", err)
|
||||
}
|
||||
if _, err := io.ReadFull(s.conn, msgType[:]); err != nil {
|
||||
return err
|
||||
}
|
||||
_ = s.conn.SetDeadline(time.Time{})
|
||||
|
||||
switch msgType[0] {
|
||||
case clientSetPixelFormat:
|
||||
if err := s.handleSetPixelFormat(); err != nil {
|
||||
return err
|
||||
}
|
||||
case clientSetEncodings:
|
||||
if err := s.handleSetEncodings(); err != nil {
|
||||
return err
|
||||
}
|
||||
case clientFramebufferUpdateRequest:
|
||||
if err := s.handleFBUpdateRequest(); err != nil {
|
||||
return err
|
||||
}
|
||||
case clientKeyEvent:
|
||||
if err := s.handleKeyEvent(); err != nil {
|
||||
return err
|
||||
}
|
||||
case clientPointerEvent:
|
||||
if err := s.handlePointerEvent(); err != nil {
|
||||
return err
|
||||
}
|
||||
case clientCutText:
|
||||
if err := s.handleCutText(); err != nil {
|
||||
return err
|
||||
}
|
||||
default:
|
||||
return fmt.Errorf("unknown client message type: %d", msgType[0])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *session) handleSetPixelFormat() error {
|
||||
var buf [19]byte // 3 padding + 16 pixel format
|
||||
if _, err := io.ReadFull(s.conn, buf[:]); err != nil {
|
||||
return fmt.Errorf("read SetPixelFormat: %w", err)
|
||||
}
|
||||
s.pf = parsePixelFormat(buf[3:19])
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *session) handleSetEncodings() error {
|
||||
var header [3]byte // 1 padding + 2 number-of-encodings
|
||||
if _, err := io.ReadFull(s.conn, header[:]); err != nil {
|
||||
return fmt.Errorf("read SetEncodings header: %w", err)
|
||||
}
|
||||
numEnc := binary.BigEndian.Uint16(header[1:3])
|
||||
buf := make([]byte, int(numEnc)*4)
|
||||
if _, err := io.ReadFull(s.conn, buf); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Check if client supports zlib encoding.
|
||||
for i := range int(numEnc) {
|
||||
enc := int32(binary.BigEndian.Uint32(buf[i*4 : i*4+4]))
|
||||
if enc == encZlib {
|
||||
s.useZlib = true
|
||||
if s.zlib == nil {
|
||||
s.zlib = newZlibState()
|
||||
}
|
||||
s.log.Debugf("client supports zlib encoding")
|
||||
break
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *session) handleFBUpdateRequest() error {
|
||||
var req [9]byte
|
||||
if _, err := io.ReadFull(s.conn, req[:]); err != nil {
|
||||
return fmt.Errorf("read FBUpdateRequest: %w", err)
|
||||
}
|
||||
incremental := req[0]
|
||||
|
||||
img, err := s.capturer.Capture()
|
||||
if err != nil {
|
||||
return fmt.Errorf("capture screen: %w", err)
|
||||
}
|
||||
|
||||
if incremental == 1 && s.prevFrame != nil {
|
||||
rects := diffRects(s.prevFrame, img, s.serverW, s.serverH, tileSize)
|
||||
if len(rects) == 0 {
|
||||
// Nothing changed. Back off briefly before responding to reduce
|
||||
// CPU usage when the screen is static. The client re-requests
|
||||
// immediately after receiving our empty response, so without
|
||||
// this delay we'd spin at ~1000fps checking for changes.
|
||||
s.idleFrames++
|
||||
delay := min(s.idleFrames*5, 100) // 5ms → 100ms adaptive backoff
|
||||
time.Sleep(time.Duration(delay) * time.Millisecond)
|
||||
s.savePrevFrame(img)
|
||||
return s.sendEmptyUpdate()
|
||||
}
|
||||
s.idleFrames = 0
|
||||
s.savePrevFrame(img)
|
||||
return s.sendDirtyRects(img, rects)
|
||||
}
|
||||
|
||||
// Full update.
|
||||
s.idleFrames = 0
|
||||
s.savePrevFrame(img)
|
||||
return s.sendFullUpdate(img)
|
||||
}
|
||||
|
||||
// savePrevFrame copies img's pixel data into prevFrame. This is necessary
|
||||
// because some capturers (DXGI) reuse the same image buffer across calls,
|
||||
// so a simple pointer assignment would make prevFrame alias the live buffer
|
||||
// and diffRects would always see zero changes.
|
||||
func (s *session) savePrevFrame(img *image.RGBA) {
|
||||
if s.prevFrame == nil || s.prevFrame.Rect != img.Rect {
|
||||
s.prevFrame = image.NewRGBA(img.Rect)
|
||||
}
|
||||
copy(s.prevFrame.Pix, img.Pix)
|
||||
}
|
||||
|
||||
// sendEmptyUpdate sends a FramebufferUpdate with zero rectangles.
|
||||
func (s *session) sendEmptyUpdate() error {
|
||||
var buf [4]byte
|
||||
buf[0] = serverFramebufferUpdate
|
||||
s.writeMu.Lock()
|
||||
_, err := s.conn.Write(buf[:])
|
||||
s.writeMu.Unlock()
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *session) sendFullUpdate(img *image.RGBA) error {
|
||||
w, h := s.serverW, s.serverH
|
||||
|
||||
var buf []byte
|
||||
if s.useZlib && s.zlib != nil {
|
||||
buf = encodeZlibRect(img, s.pf, 0, 0, w, h, s.zlib.w, s.zlib.buf)
|
||||
} else {
|
||||
buf = encodeRawRect(img, s.pf, 0, 0, w, h)
|
||||
}
|
||||
|
||||
s.writeMu.Lock()
|
||||
_, err := s.conn.Write(buf)
|
||||
s.writeMu.Unlock()
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *session) sendDirtyRects(img *image.RGBA, rects [][4]int) error {
|
||||
// Build a multi-rectangle FramebufferUpdate.
|
||||
// Header: type(1) + padding(1) + numRects(2)
|
||||
header := make([]byte, 4)
|
||||
header[0] = serverFramebufferUpdate
|
||||
binary.BigEndian.PutUint16(header[2:4], uint16(len(rects)))
|
||||
|
||||
s.writeMu.Lock()
|
||||
defer s.writeMu.Unlock()
|
||||
|
||||
if _, err := s.conn.Write(header); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, r := range rects {
|
||||
x, y, w, h := r[0], r[1], r[2], r[3]
|
||||
|
||||
var rectBuf []byte
|
||||
if s.useZlib && s.zlib != nil {
|
||||
rectBuf = encodeZlibRect(img, s.pf, x, y, w, h, s.zlib.w, s.zlib.buf)
|
||||
// encodeZlibRect includes its own FBUpdate header for 1 rect.
|
||||
// For multi-rect, we need just the rect data without the FBUpdate header.
|
||||
// Skip the 4-byte FBUpdate header since we already sent ours.
|
||||
rectBuf = rectBuf[4:]
|
||||
} else {
|
||||
rectBuf = encodeRawRect(img, s.pf, x, y, w, h)
|
||||
rectBuf = rectBuf[4:] // skip FBUpdate header
|
||||
}
|
||||
|
||||
if _, err := s.conn.Write(rectBuf); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *session) handleKeyEvent() error {
|
||||
var data [7]byte
|
||||
if _, err := io.ReadFull(s.conn, data[:]); err != nil {
|
||||
return fmt.Errorf("read KeyEvent: %w", err)
|
||||
}
|
||||
down := data[0] == 1
|
||||
keysym := binary.BigEndian.Uint32(data[3:7])
|
||||
s.injector.InjectKey(keysym, down)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *session) handlePointerEvent() error {
|
||||
var data [5]byte
|
||||
if _, err := io.ReadFull(s.conn, data[:]); err != nil {
|
||||
return fmt.Errorf("read PointerEvent: %w", err)
|
||||
}
|
||||
buttonMask := data[0]
|
||||
x := int(binary.BigEndian.Uint16(data[1:3]))
|
||||
y := int(binary.BigEndian.Uint16(data[3:5]))
|
||||
s.injector.InjectPointer(buttonMask, x, y, s.serverW, s.serverH)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *session) handleCutText() error {
|
||||
var header [7]byte // 3 padding + 4 length
|
||||
if _, err := io.ReadFull(s.conn, header[:]); err != nil {
|
||||
return fmt.Errorf("read CutText header: %w", err)
|
||||
}
|
||||
length := binary.BigEndian.Uint32(header[3:7])
|
||||
if length > maxCutTextBytes {
|
||||
return fmt.Errorf("cut text too large: %d bytes", length)
|
||||
}
|
||||
buf := make([]byte, length)
|
||||
if _, err := io.ReadFull(s.conn, buf); err != nil {
|
||||
return fmt.Errorf("read CutText payload: %w", err)
|
||||
}
|
||||
s.injector.SetClipboard(string(buf))
|
||||
return nil
|
||||
}
|
||||
|
||||
// sendServerCutText sends clipboard text from the server to the client.
|
||||
func (s *session) sendServerCutText(text string) error {
|
||||
data := []byte(text)
|
||||
buf := make([]byte, 8+len(data))
|
||||
buf[0] = serverCutText
|
||||
// buf[1:4] = padding (zero)
|
||||
binary.BigEndian.PutUint32(buf[4:8], uint32(len(data)))
|
||||
copy(buf[8:], data)
|
||||
|
||||
s.writeMu.Lock()
|
||||
_, err := s.conn.Write(buf)
|
||||
s.writeMu.Unlock()
|
||||
return err
|
||||
}
|
||||
79
client/vnc/server/shutdown_state.go
Normal file
79
client/vnc/server/shutdown_state.go
Normal file
@@ -0,0 +1,79 @@
|
||||
//go:build !windows
|
||||
|
||||
package server
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
"syscall"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// ShutdownState tracks VNC virtual session processes for crash recovery.
|
||||
// Persisted by the state manager; on restart, residual processes are killed.
|
||||
type ShutdownState struct {
|
||||
// Processes maps a description to its PID (e.g., "xvfb:50" -> 1234).
|
||||
Processes map[string]int `json:"processes,omitempty"`
|
||||
}
|
||||
|
||||
// Name returns the state name for the state manager.
|
||||
func (s *ShutdownState) Name() string {
|
||||
return "vnc_sessions_state"
|
||||
}
|
||||
|
||||
// Cleanup kills any residual VNC session processes left from a crash.
|
||||
func (s *ShutdownState) Cleanup() error {
|
||||
if len(s.Processes) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
for desc, pid := range s.Processes {
|
||||
if pid <= 0 {
|
||||
continue
|
||||
}
|
||||
if !isOurProcess(pid, desc) {
|
||||
log.Debugf("cleanup:skipping PID %d (%s), not ours", pid, desc)
|
||||
continue
|
||||
}
|
||||
log.Infof("cleanup:killing residual process %d (%s)", pid, desc)
|
||||
// Kill the process group (negative PID) to get children too.
|
||||
if err := syscall.Kill(-pid, syscall.SIGTERM); err != nil {
|
||||
// Try individual process if group kill fails.
|
||||
syscall.Kill(pid, syscall.SIGKILL)
|
||||
}
|
||||
}
|
||||
|
||||
s.Processes = nil
|
||||
return nil
|
||||
}
|
||||
|
||||
// isOurProcess verifies the PID still belongs to a VNC-related process
|
||||
// by checking /proc/<pid>/cmdline (Linux) or the process name.
|
||||
func isOurProcess(pid int, desc string) bool {
|
||||
// Check if the process exists at all.
|
||||
if err := syscall.Kill(pid, 0); err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
// On Linux, verify via /proc cmdline.
|
||||
cmdline, err := os.ReadFile(fmt.Sprintf("/proc/%d/cmdline", pid))
|
||||
if err != nil {
|
||||
// No /proc (FreeBSD): trust the PID if the process exists.
|
||||
// PID reuse is unlikely in the short window between crash and restart.
|
||||
return true
|
||||
}
|
||||
|
||||
cmd := string(cmdline)
|
||||
// Match against expected process types.
|
||||
if strings.Contains(desc, "xvfb") || strings.Contains(desc, "xorg") {
|
||||
return strings.Contains(cmd, "Xvfb") || strings.Contains(cmd, "Xorg")
|
||||
}
|
||||
if strings.Contains(desc, "desktop") {
|
||||
return strings.Contains(cmd, "session") || strings.Contains(cmd, "plasma") ||
|
||||
strings.Contains(cmd, "gnome") || strings.Contains(cmd, "xfce") ||
|
||||
strings.Contains(cmd, "dbus-launch")
|
||||
}
|
||||
return false
|
||||
}
|
||||
37
client/vnc/server/stubs.go
Normal file
37
client/vnc/server/stubs.go
Normal file
@@ -0,0 +1,37 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"image"
|
||||
)
|
||||
|
||||
const maxCapturerRetries = 5
|
||||
|
||||
// StubCapturer is a placeholder for platforms without screen capture support.
|
||||
type StubCapturer struct{}
|
||||
|
||||
// Width returns 0 on unsupported platforms.
|
||||
func (c *StubCapturer) Width() int { return 0 }
|
||||
|
||||
// Height returns 0 on unsupported platforms.
|
||||
func (c *StubCapturer) Height() int { return 0 }
|
||||
|
||||
// Capture returns an error on unsupported platforms.
|
||||
func (c *StubCapturer) Capture() (*image.RGBA, error) {
|
||||
return nil, fmt.Errorf("screen capture not supported on this platform")
|
||||
}
|
||||
|
||||
// StubInputInjector is a placeholder for platforms without input injection support.
|
||||
type StubInputInjector struct{}
|
||||
|
||||
// InjectKey is a no-op on unsupported platforms.
|
||||
func (s *StubInputInjector) InjectKey(_ uint32, _ bool) {}
|
||||
|
||||
// InjectPointer is a no-op on unsupported platforms.
|
||||
func (s *StubInputInjector) InjectPointer(_ uint8, _, _, _, _ int) {}
|
||||
|
||||
// SetClipboard is a no-op on unsupported platforms.
|
||||
func (s *StubInputInjector) SetClipboard(_ string) {}
|
||||
|
||||
// GetClipboard returns empty on unsupported platforms.
|
||||
func (s *StubInputInjector) GetClipboard() string { return "" }
|
||||
19
client/vnc/server/swizzle_windows.go
Normal file
19
client/vnc/server/swizzle_windows.go
Normal file
@@ -0,0 +1,19 @@
|
||||
//go:build windows
|
||||
|
||||
package server
|
||||
|
||||
import "unsafe"
|
||||
|
||||
// swizzleBGRAtoRGBA swaps B and R channels in a BGRA pixel buffer in-place.
|
||||
// Operates on uint32 words for throughput: one read-modify-write per pixel.
|
||||
func swizzleBGRAtoRGBA(pix []byte) {
|
||||
n := len(pix) / 4
|
||||
pixels := unsafe.Slice((*uint32)(unsafe.Pointer(&pix[0])), n)
|
||||
for i := range n {
|
||||
p := pixels[i]
|
||||
// p = 0xAABBGGRR (little-endian BGRA in memory: B,G,R,A bytes)
|
||||
// We want 0xAABBGGRR -> 0xAARRGGBB (RGBA in memory: R,G,B,A bytes)
|
||||
// Swap byte 0 (B) and byte 2 (R), keep byte 1 (G) and byte 3 (A).
|
||||
pixels[i] = (p & 0xFF00FF00) | ((p & 0x00FF0000) >> 16) | ((p & 0x000000FF) << 16)
|
||||
}
|
||||
}
|
||||
634
client/vnc/server/virtual_x11.go
Normal file
634
client/vnc/server/virtual_x11.go
Normal file
@@ -0,0 +1,634 @@
|
||||
//go:build (linux && !android) || freebsd
|
||||
|
||||
package server
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"os/user"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// VirtualSession manages a virtual X11 display (Xvfb) with a desktop session
|
||||
// running as a target user. It implements ScreenCapturer and InputInjector by
|
||||
// delegating to an X11Capturer/X11InputInjector pointed at the virtual display.
|
||||
const sessionIdleTimeout = 5 * time.Minute
|
||||
|
||||
type VirtualSession struct {
|
||||
mu sync.Mutex
|
||||
display string
|
||||
user *user.User
|
||||
uid uint32
|
||||
gid uint32
|
||||
groups []uint32
|
||||
xvfb *exec.Cmd
|
||||
desktop *exec.Cmd
|
||||
poller *X11Poller
|
||||
injector *X11InputInjector
|
||||
log *log.Entry
|
||||
stopped bool
|
||||
clients int
|
||||
idleTimer *time.Timer
|
||||
onIdle func() // called when idle timeout fires or Xvfb dies
|
||||
}
|
||||
|
||||
// StartVirtualSession creates and starts a virtual X11 session for the given user.
|
||||
// Requires root privileges to create sessions as other users.
|
||||
func StartVirtualSession(username string, logger *log.Entry) (*VirtualSession, error) {
|
||||
if os.Getuid() != 0 {
|
||||
return nil, fmt.Errorf("virtual sessions require root privileges")
|
||||
}
|
||||
|
||||
if _, err := exec.LookPath("Xvfb"); err != nil {
|
||||
if _, err := exec.LookPath("Xorg"); err != nil {
|
||||
return nil, fmt.Errorf("neither Xvfb nor Xorg found (install xvfb or xserver-xorg)")
|
||||
}
|
||||
if !hasDummyDriver() {
|
||||
return nil, fmt.Errorf("Xvfb not found and Xorg dummy driver not installed (install xvfb or xf86-video-dummy)")
|
||||
}
|
||||
}
|
||||
|
||||
u, err := user.Lookup(username)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("lookup user %s: %w", username, err)
|
||||
}
|
||||
|
||||
uid, err := strconv.ParseUint(u.Uid, 10, 32)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parse uid: %w", err)
|
||||
}
|
||||
gid, err := strconv.ParseUint(u.Gid, 10, 32)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parse gid: %w", err)
|
||||
}
|
||||
|
||||
groups, err := supplementaryGroups(u)
|
||||
if err != nil {
|
||||
logger.Debugf("supplementary groups for %s: %v", username, err)
|
||||
}
|
||||
|
||||
vs := &VirtualSession{
|
||||
user: u,
|
||||
uid: uint32(uid),
|
||||
gid: uint32(gid),
|
||||
groups: groups,
|
||||
log: logger.WithField("vnc_user", username),
|
||||
}
|
||||
|
||||
if err := vs.start(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return vs, nil
|
||||
}
|
||||
|
||||
func (vs *VirtualSession) start() error {
|
||||
display, err := findFreeDisplay()
|
||||
if err != nil {
|
||||
return fmt.Errorf("find free display: %w", err)
|
||||
}
|
||||
vs.display = display
|
||||
|
||||
if err := vs.startXvfb(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
socketPath := fmt.Sprintf("/tmp/.X11-unix/X%s", vs.display[1:])
|
||||
if err := waitForPath(socketPath, 5*time.Second); err != nil {
|
||||
vs.stopXvfb()
|
||||
return fmt.Errorf("wait for X11 socket %s: %w", socketPath, err)
|
||||
}
|
||||
|
||||
// Grant the target user access to the display via xhost.
|
||||
xhostCmd := exec.Command("xhost", "+SI:localuser:"+vs.user.Username)
|
||||
xhostCmd.Env = []string{"DISPLAY=" + vs.display}
|
||||
if out, err := xhostCmd.CombinedOutput(); err != nil {
|
||||
vs.log.Debugf("xhost: %s (%v)", strings.TrimSpace(string(out)), err)
|
||||
}
|
||||
|
||||
vs.poller = NewX11Poller(vs.display)
|
||||
|
||||
injector, err := NewX11InputInjector(vs.display)
|
||||
if err != nil {
|
||||
vs.stopXvfb()
|
||||
return fmt.Errorf("create X11 injector for %s: %w", vs.display, err)
|
||||
}
|
||||
vs.injector = injector
|
||||
|
||||
if err := vs.startDesktop(); err != nil {
|
||||
vs.injector.Close()
|
||||
vs.stopXvfb()
|
||||
return fmt.Errorf("start desktop: %w", err)
|
||||
}
|
||||
|
||||
vs.log.Infof("virtual session started: display=%s user=%s", vs.display, vs.user.Username)
|
||||
return nil
|
||||
}
|
||||
|
||||
// ClientConnect increments the client count and cancels any idle timer.
|
||||
func (vs *VirtualSession) ClientConnect() {
|
||||
vs.mu.Lock()
|
||||
defer vs.mu.Unlock()
|
||||
vs.clients++
|
||||
if vs.idleTimer != nil {
|
||||
vs.idleTimer.Stop()
|
||||
vs.idleTimer = nil
|
||||
}
|
||||
}
|
||||
|
||||
// ClientDisconnect decrements the client count. When the last client
|
||||
// disconnects, starts an idle timer that destroys the session.
|
||||
func (vs *VirtualSession) ClientDisconnect() {
|
||||
vs.mu.Lock()
|
||||
defer vs.mu.Unlock()
|
||||
vs.clients--
|
||||
if vs.clients <= 0 {
|
||||
vs.clients = 0
|
||||
vs.log.Infof("no VNC clients connected, session will be destroyed in %s", sessionIdleTimeout)
|
||||
vs.idleTimer = time.AfterFunc(sessionIdleTimeout, vs.idleExpired)
|
||||
}
|
||||
}
|
||||
|
||||
// idleExpired is called by the idle timer. It stops the session and
|
||||
// notifies the session manager via onIdle so it removes us from the map.
|
||||
func (vs *VirtualSession) idleExpired() {
|
||||
vs.log.Info("idle timeout reached, destroying virtual session")
|
||||
vs.Stop()
|
||||
// onIdle acquires sessionManager.mu; safe because Stop() has released vs.mu.
|
||||
if vs.onIdle != nil {
|
||||
vs.onIdle()
|
||||
}
|
||||
}
|
||||
|
||||
// isAlive returns true if the session is running and its X server socket exists.
|
||||
func (vs *VirtualSession) isAlive() bool {
|
||||
vs.mu.Lock()
|
||||
stopped := vs.stopped
|
||||
display := vs.display
|
||||
vs.mu.Unlock()
|
||||
|
||||
if stopped {
|
||||
return false
|
||||
}
|
||||
// Verify the X socket still exists on disk.
|
||||
socketPath := fmt.Sprintf("/tmp/.X11-unix/X%s", display[1:])
|
||||
if _, err := os.Stat(socketPath); err != nil {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// Capturer returns the screen capturer for this virtual session.
|
||||
func (vs *VirtualSession) Capturer() ScreenCapturer {
|
||||
return vs.poller
|
||||
}
|
||||
|
||||
// Injector returns the input injector for this virtual session.
|
||||
func (vs *VirtualSession) Injector() InputInjector {
|
||||
return vs.injector
|
||||
}
|
||||
|
||||
// Display returns the X11 display string (e.g., ":99").
|
||||
func (vs *VirtualSession) Display() string {
|
||||
return vs.display
|
||||
}
|
||||
|
||||
// Stop terminates the virtual session, killing the desktop and Xvfb.
|
||||
func (vs *VirtualSession) Stop() {
|
||||
vs.mu.Lock()
|
||||
defer vs.mu.Unlock()
|
||||
|
||||
if vs.stopped {
|
||||
return
|
||||
}
|
||||
vs.stopped = true
|
||||
|
||||
if vs.injector != nil {
|
||||
vs.injector.Close()
|
||||
}
|
||||
|
||||
vs.stopDesktop()
|
||||
vs.stopXvfb()
|
||||
|
||||
vs.log.Info("virtual session stopped")
|
||||
}
|
||||
|
||||
func (vs *VirtualSession) startXvfb() error {
|
||||
if _, err := exec.LookPath("Xvfb"); err == nil {
|
||||
return vs.startXvfbDirect()
|
||||
}
|
||||
return vs.startXorgDummy()
|
||||
}
|
||||
|
||||
func (vs *VirtualSession) startXvfbDirect() error {
|
||||
vs.xvfb = exec.Command("Xvfb", vs.display,
|
||||
"-screen", "0", "1280x800x24",
|
||||
"-ac",
|
||||
"-nolisten", "tcp",
|
||||
)
|
||||
vs.xvfb.SysProcAttr = &syscall.SysProcAttr{Setsid: true, Pdeathsig: syscall.SIGTERM}
|
||||
|
||||
if err := vs.xvfb.Start(); err != nil {
|
||||
return fmt.Errorf("start Xvfb on %s: %w", vs.display, err)
|
||||
}
|
||||
vs.log.Infof("Xvfb started on %s (pid=%d)", vs.display, vs.xvfb.Process.Pid)
|
||||
|
||||
go vs.monitorXvfb()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// startXorgDummy starts Xorg with the dummy video driver as a fallback when
|
||||
// Xvfb is not installed. Most systems with a desktop have Xorg available.
|
||||
func (vs *VirtualSession) startXorgDummy() error {
|
||||
confPath := fmt.Sprintf("/tmp/nbvnc-dummy-%s.conf", vs.display[1:])
|
||||
conf := `Section "Device"
|
||||
Identifier "dummy"
|
||||
Driver "dummy"
|
||||
VideoRam 256000
|
||||
EndSection
|
||||
Section "Screen"
|
||||
Identifier "screen"
|
||||
Device "dummy"
|
||||
DefaultDepth 24
|
||||
SubSection "Display"
|
||||
Depth 24
|
||||
Modes "1280x800"
|
||||
EndSubSection
|
||||
EndSection
|
||||
`
|
||||
if err := os.WriteFile(confPath, []byte(conf), 0644); err != nil {
|
||||
return fmt.Errorf("write Xorg dummy config: %w", err)
|
||||
}
|
||||
|
||||
vs.xvfb = exec.Command("Xorg", vs.display,
|
||||
"-config", confPath,
|
||||
"-noreset",
|
||||
"-nolisten", "tcp",
|
||||
"-ac",
|
||||
)
|
||||
vs.xvfb.SysProcAttr = &syscall.SysProcAttr{Setsid: true, Pdeathsig: syscall.SIGTERM}
|
||||
|
||||
if err := vs.xvfb.Start(); err != nil {
|
||||
os.Remove(confPath)
|
||||
return fmt.Errorf("start Xorg dummy on %s: %w", vs.display, err)
|
||||
}
|
||||
vs.log.Infof("Xorg (dummy driver) started on %s (pid=%d)", vs.display, vs.xvfb.Process.Pid)
|
||||
|
||||
go func() {
|
||||
vs.monitorXvfb()
|
||||
os.Remove(confPath)
|
||||
}()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// monitorXvfb waits for the Xvfb/Xorg process to exit. If it exits
|
||||
// unexpectedly (not via Stop), the session is marked as dead and the
|
||||
// onIdle callback fires so the session manager removes it from the map.
|
||||
// The next GetOrCreate call for this user will create a fresh session.
|
||||
func (vs *VirtualSession) monitorXvfb() {
|
||||
if err := vs.xvfb.Wait(); err != nil {
|
||||
vs.log.Debugf("X server exited: %v", err)
|
||||
}
|
||||
|
||||
vs.mu.Lock()
|
||||
alreadyStopped := vs.stopped
|
||||
if !alreadyStopped {
|
||||
vs.log.Warn("X server exited unexpectedly, marking session as dead")
|
||||
vs.stopped = true
|
||||
if vs.idleTimer != nil {
|
||||
vs.idleTimer.Stop()
|
||||
vs.idleTimer = nil
|
||||
}
|
||||
if vs.injector != nil {
|
||||
vs.injector.Close()
|
||||
}
|
||||
vs.stopDesktop()
|
||||
}
|
||||
onIdle := vs.onIdle
|
||||
vs.mu.Unlock()
|
||||
|
||||
if !alreadyStopped && onIdle != nil {
|
||||
onIdle()
|
||||
}
|
||||
}
|
||||
|
||||
func (vs *VirtualSession) stopXvfb() {
|
||||
if vs.xvfb != nil && vs.xvfb.Process != nil {
|
||||
syscall.Kill(-vs.xvfb.Process.Pid, syscall.SIGTERM)
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
syscall.Kill(-vs.xvfb.Process.Pid, syscall.SIGKILL)
|
||||
}
|
||||
}
|
||||
|
||||
func (vs *VirtualSession) startDesktop() error {
|
||||
session := detectDesktopSession()
|
||||
|
||||
// Wrap the desktop command with dbus-launch to provide a session bus.
|
||||
// Without this, most desktop environments (XFCE, MATE, etc.) fail immediately.
|
||||
var args []string
|
||||
if _, err := exec.LookPath("dbus-launch"); err == nil {
|
||||
args = append([]string{"dbus-launch", "--exit-with-session"}, session...)
|
||||
} else {
|
||||
args = session
|
||||
}
|
||||
|
||||
vs.desktop = exec.Command(args[0], args[1:]...)
|
||||
vs.desktop.Dir = vs.user.HomeDir
|
||||
vs.desktop.Env = vs.buildUserEnv()
|
||||
vs.desktop.SysProcAttr = &syscall.SysProcAttr{
|
||||
Credential: &syscall.Credential{
|
||||
Uid: vs.uid,
|
||||
Gid: vs.gid,
|
||||
Groups: vs.groups,
|
||||
},
|
||||
Setsid: true,
|
||||
Pdeathsig: syscall.SIGTERM,
|
||||
}
|
||||
|
||||
if err := vs.desktop.Start(); err != nil {
|
||||
return fmt.Errorf("start desktop session (%v): %w", args, err)
|
||||
}
|
||||
vs.log.Infof("desktop session started: %v (pid=%d)", args, vs.desktop.Process.Pid)
|
||||
|
||||
go func() {
|
||||
if err := vs.desktop.Wait(); err != nil {
|
||||
vs.log.Debugf("desktop session exited: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (vs *VirtualSession) stopDesktop() {
|
||||
if vs.desktop != nil && vs.desktop.Process != nil {
|
||||
syscall.Kill(-vs.desktop.Process.Pid, syscall.SIGTERM)
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
syscall.Kill(-vs.desktop.Process.Pid, syscall.SIGKILL)
|
||||
}
|
||||
}
|
||||
|
||||
func (vs *VirtualSession) buildUserEnv() []string {
|
||||
return []string{
|
||||
"DISPLAY=" + vs.display,
|
||||
"HOME=" + vs.user.HomeDir,
|
||||
"USER=" + vs.user.Username,
|
||||
"LOGNAME=" + vs.user.Username,
|
||||
"SHELL=" + getUserShell(vs.user.Uid),
|
||||
"PATH=/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin",
|
||||
"XDG_RUNTIME_DIR=/run/user/" + vs.user.Uid,
|
||||
"DBUS_SESSION_BUS_ADDRESS=unix:path=/run/user/" + vs.user.Uid + "/bus",
|
||||
}
|
||||
}
|
||||
|
||||
// detectDesktopSession discovers available desktop sessions from the standard
|
||||
// /usr/share/xsessions/*.desktop files (FreeDesktop standard, used by all
|
||||
// display managers). Falls back to a hardcoded list if no .desktop files found.
|
||||
func detectDesktopSession() []string {
|
||||
// Scan xsessions directories (Linux: /usr/share, FreeBSD: /usr/local/share).
|
||||
for _, dir := range []string{"/usr/share/xsessions", "/usr/local/share/xsessions"} {
|
||||
if cmd := findXSession(dir); cmd != nil {
|
||||
return cmd
|
||||
}
|
||||
}
|
||||
|
||||
// Fallback: try common session commands directly.
|
||||
fallbacks := [][]string{
|
||||
{"startplasma-x11"},
|
||||
{"gnome-session"},
|
||||
{"xfce4-session"},
|
||||
{"mate-session"},
|
||||
{"cinnamon-session"},
|
||||
{"openbox-session"},
|
||||
{"xterm"},
|
||||
}
|
||||
for _, s := range fallbacks {
|
||||
if _, err := exec.LookPath(s[0]); err == nil {
|
||||
return s
|
||||
}
|
||||
}
|
||||
return []string{"xterm"}
|
||||
}
|
||||
|
||||
// sessionPriority defines preference order for desktop environments.
|
||||
// Lower number = higher priority. Unknown sessions get 100.
|
||||
var sessionPriority = map[string]int{
|
||||
"plasma": 1, // KDE
|
||||
"gnome": 2,
|
||||
"xfce": 3,
|
||||
"mate": 4,
|
||||
"cinnamon": 5,
|
||||
"lxqt": 6,
|
||||
"lxde": 7,
|
||||
"budgie": 8,
|
||||
"openbox": 20,
|
||||
"fluxbox": 21,
|
||||
"i3": 22,
|
||||
"xinit": 50, // generic user session
|
||||
"lightdm": 50,
|
||||
"default": 50,
|
||||
}
|
||||
|
||||
func findXSession(dir string) []string {
|
||||
entries, err := os.ReadDir(dir)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
type candidate struct {
|
||||
cmd string
|
||||
priority int
|
||||
}
|
||||
var candidates []candidate
|
||||
|
||||
for _, e := range entries {
|
||||
if !strings.HasSuffix(e.Name(), ".desktop") {
|
||||
continue
|
||||
}
|
||||
data, err := os.ReadFile(filepath.Join(dir, e.Name()))
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
execCmd := ""
|
||||
for _, line := range strings.Split(string(data), "\n") {
|
||||
if strings.HasPrefix(line, "Exec=") {
|
||||
execCmd = strings.TrimSpace(strings.TrimPrefix(line, "Exec="))
|
||||
break
|
||||
}
|
||||
}
|
||||
if execCmd == "" || execCmd == "default" {
|
||||
continue
|
||||
}
|
||||
|
||||
// Determine priority from the filename or exec command.
|
||||
pri := 100
|
||||
lower := strings.ToLower(e.Name() + " " + execCmd)
|
||||
for keyword, p := range sessionPriority {
|
||||
if strings.Contains(lower, keyword) && p < pri {
|
||||
pri = p
|
||||
}
|
||||
}
|
||||
candidates = append(candidates, candidate{cmd: execCmd, priority: pri})
|
||||
}
|
||||
|
||||
if len(candidates) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Pick the highest priority (lowest number).
|
||||
best := candidates[0]
|
||||
for _, c := range candidates[1:] {
|
||||
if c.priority < best.priority {
|
||||
best = c
|
||||
}
|
||||
}
|
||||
|
||||
// Verify the binary exists.
|
||||
parts := strings.Fields(best.cmd)
|
||||
if _, err := exec.LookPath(parts[0]); err != nil {
|
||||
return nil
|
||||
}
|
||||
return parts
|
||||
}
|
||||
|
||||
// findFreeDisplay scans for an unused X11 display number.
|
||||
func findFreeDisplay() (string, error) {
|
||||
for n := 50; n < 200; n++ {
|
||||
lockFile := fmt.Sprintf("/tmp/.X%d-lock", n)
|
||||
socketFile := fmt.Sprintf("/tmp/.X11-unix/X%d", n)
|
||||
if _, err := os.Stat(lockFile); err == nil {
|
||||
continue
|
||||
}
|
||||
if _, err := os.Stat(socketFile); err == nil {
|
||||
continue
|
||||
}
|
||||
return fmt.Sprintf(":%d", n), nil
|
||||
}
|
||||
return "", fmt.Errorf("no free X11 display found (checked :50-:199)")
|
||||
}
|
||||
|
||||
// waitForPath polls until a filesystem path exists or the timeout expires.
|
||||
func waitForPath(path string, timeout time.Duration) error {
|
||||
deadline := time.Now().Add(timeout)
|
||||
for time.Now().Before(deadline) {
|
||||
if _, err := os.Stat(path); err == nil {
|
||||
return nil
|
||||
}
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
}
|
||||
return fmt.Errorf("timeout waiting for %s", path)
|
||||
}
|
||||
|
||||
// getUserShell returns the login shell for the given UID.
|
||||
func getUserShell(uid string) string {
|
||||
data, err := os.ReadFile("/etc/passwd")
|
||||
if err != nil {
|
||||
return "/bin/sh"
|
||||
}
|
||||
for _, line := range strings.Split(string(data), "\n") {
|
||||
fields := strings.Split(line, ":")
|
||||
if len(fields) >= 7 && fields[2] == uid {
|
||||
return fields[6]
|
||||
}
|
||||
}
|
||||
return "/bin/sh"
|
||||
}
|
||||
|
||||
// supplementaryGroups returns the supplementary group IDs for a user.
|
||||
func supplementaryGroups(u *user.User) ([]uint32, error) {
|
||||
gids, err := u.GroupIds()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var groups []uint32
|
||||
for _, g := range gids {
|
||||
id, err := strconv.ParseUint(g, 10, 32)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
groups = append(groups, uint32(id))
|
||||
}
|
||||
return groups, nil
|
||||
}
|
||||
|
||||
// sessionManager tracks active virtual sessions by username.
|
||||
type sessionManager struct {
|
||||
mu sync.Mutex
|
||||
sessions map[string]*VirtualSession
|
||||
log *log.Entry
|
||||
}
|
||||
|
||||
func newSessionManager(logger *log.Entry) *sessionManager {
|
||||
return &sessionManager{
|
||||
sessions: make(map[string]*VirtualSession),
|
||||
log: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// GetOrCreate returns an existing virtual session or creates a new one.
|
||||
// If a previous session for this user is stopped or its X server died, it is replaced.
|
||||
func (sm *sessionManager) GetOrCreate(username string) (vncSession, error) {
|
||||
sm.mu.Lock()
|
||||
defer sm.mu.Unlock()
|
||||
|
||||
if vs, ok := sm.sessions[username]; ok {
|
||||
if vs.isAlive() {
|
||||
return vs, nil
|
||||
}
|
||||
sm.log.Infof("replacing dead virtual session for %s", username)
|
||||
vs.Stop()
|
||||
delete(sm.sessions, username)
|
||||
}
|
||||
|
||||
vs, err := StartVirtualSession(username, sm.log)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
vs.onIdle = func() {
|
||||
sm.mu.Lock()
|
||||
defer sm.mu.Unlock()
|
||||
if cur, ok := sm.sessions[username]; ok && cur == vs {
|
||||
delete(sm.sessions, username)
|
||||
sm.log.Infof("removed idle virtual session for %s", username)
|
||||
}
|
||||
}
|
||||
sm.sessions[username] = vs
|
||||
return vs, nil
|
||||
}
|
||||
|
||||
// hasDummyDriver checks common paths for the Xorg dummy video driver.
|
||||
func hasDummyDriver() bool {
|
||||
paths := []string{
|
||||
"/usr/lib/xorg/modules/drivers/dummy_drv.so", // Debian/Ubuntu
|
||||
"/usr/lib64/xorg/modules/drivers/dummy_drv.so", // RHEL/Fedora
|
||||
"/usr/local/lib/xorg/modules/drivers/dummy_drv.so", // FreeBSD
|
||||
"/usr/lib/x86_64-linux-gnu/xorg/modules/drivers/dummy_drv.so", // Debian multiarch
|
||||
}
|
||||
for _, p := range paths {
|
||||
if _, err := os.Stat(p); err == nil {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// StopAll terminates all active virtual sessions.
|
||||
func (sm *sessionManager) StopAll() {
|
||||
sm.mu.Lock()
|
||||
defer sm.mu.Unlock()
|
||||
|
||||
for username, vs := range sm.sessions {
|
||||
vs.Stop()
|
||||
delete(sm.sessions, username)
|
||||
sm.log.Infof("stopped virtual session for %s", username)
|
||||
}
|
||||
}
|
||||
|
||||
174
client/vnc/testpage/index.html
Normal file
174
client/vnc/testpage/index.html
Normal file
@@ -0,0 +1,174 @@
|
||||
<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<title>VNC Test</title>
|
||||
<style>
|
||||
body { margin: 0; background: #111; color: #eee; font-family: monospace; font-size: 13px; }
|
||||
#toolbar { position: fixed; top: 0; left: 0; right: 0; z-index: 10; background: #222; padding: 4px 8px; display: flex; gap: 8px; align-items: center; }
|
||||
#toolbar button { padding: 4px 12px; cursor: pointer; background: #444; color: #eee; border: 1px solid #666; border-radius: 3px; }
|
||||
#toolbar button:hover { background: #555; }
|
||||
#toolbar #status { flex: 1; }
|
||||
#vnc-container { width: 100vw; height: calc(100vh - 28px); margin-top: 28px; }
|
||||
#log { position: fixed; bottom: 0; left: 0; right: 0; max-height: 150px; overflow-y: auto; background: rgba(0,0,0,0.85); padding: 4px 8px; font-size: 11px; z-index: 10; display: none; }
|
||||
#log.visible { display: block; }
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<div id="toolbar">
|
||||
<span id="status">Loading WASM...</span>
|
||||
<button onclick="sendCAD()">Ctrl+Alt+Del</button>
|
||||
<button onclick="document.getElementById('log').classList.toggle('visible')">Log</button>
|
||||
</div>
|
||||
<div id="vnc-container"></div>
|
||||
<div id="log"></div>
|
||||
|
||||
<script>
|
||||
const params = new URLSearchParams(location.search);
|
||||
const HOST = params.get('host') || '';
|
||||
const PORT = params.get('port') || '5900';
|
||||
const MODE = params.get('mode') || 'attach'; // 'attach' or 'session'
|
||||
const USER = params.get('user') || '';
|
||||
const SETUP_KEY = params.get('setup_key') || '64BB8FF4-5A96-488F-B0AE-316555E916B0';
|
||||
const MGMT_URL = params.get('mgmt') || 'http://192.168.100.1:8080';
|
||||
|
||||
const statusEl = document.getElementById('status');
|
||||
const logEl = document.getElementById('log');
|
||||
function addLog(msg) {
|
||||
const line = document.createElement('div');
|
||||
line.textContent = `[${new Date().toISOString().slice(11,23)}] ${msg}`;
|
||||
logEl.appendChild(line);
|
||||
logEl.scrollTop = logEl.scrollHeight;
|
||||
console.log('[vnc-test]', msg);
|
||||
}
|
||||
function setStatus(s) { statusEl.textContent = s; addLog(s); }
|
||||
|
||||
let rfbInstance = null;
|
||||
window.sendCAD = () => { if (rfbInstance) { rfbInstance.sendCtrlAltDel(); addLog('Sent Ctrl+Alt+Del'); } };
|
||||
|
||||
// VNC WebSocket proxy (bridges noVNC WebSocket API to Go WASM tunnel)
|
||||
class VNCProxyWS extends EventTarget {
|
||||
constructor(url) {
|
||||
super();
|
||||
this.url = url;
|
||||
this.readyState = 0;
|
||||
this.protocol = '';
|
||||
this.extensions = '';
|
||||
this.bufferedAmount = 0;
|
||||
this.binaryType = 'arraybuffer';
|
||||
this.onopen = null; this.onclose = null; this.onerror = null; this.onmessage = null;
|
||||
const match = url.match(/vnc\.proxy\.local\/(.+)/);
|
||||
this._proxyID = match ? match[1] : 'default';
|
||||
setTimeout(() => this._connect(), 0);
|
||||
}
|
||||
get CONNECTING() { return 0; } get OPEN() { return 1; } get CLOSING() { return 2; } get CLOSED() { return 3; }
|
||||
_connect() {
|
||||
try {
|
||||
const handler = window[`handleVNCWebSocket_${this._proxyID}`];
|
||||
if (!handler) throw new Error(`No VNC handler for ${this._proxyID}`);
|
||||
handler(this);
|
||||
this.readyState = 1;
|
||||
const ev = new Event('open');
|
||||
if (this.onopen) this.onopen(ev);
|
||||
this.dispatchEvent(ev);
|
||||
} catch (err) {
|
||||
addLog(`WS proxy error: ${err.message}`);
|
||||
this.readyState = 3;
|
||||
}
|
||||
}
|
||||
receiveFromGo(data) {
|
||||
const ev = new MessageEvent('message', { data });
|
||||
if (this.onmessage) this.onmessage(ev);
|
||||
this.dispatchEvent(ev);
|
||||
}
|
||||
send(data) {
|
||||
if (this.readyState !== 1) return;
|
||||
let u8;
|
||||
if (data instanceof ArrayBuffer) u8 = new Uint8Array(data);
|
||||
else if (data instanceof Uint8Array) u8 = data;
|
||||
else if (typeof data === 'string') u8 = new TextEncoder().encode(data);
|
||||
else if (data.buffer) u8 = new Uint8Array(data.buffer, data.byteOffset, data.byteLength);
|
||||
else return;
|
||||
if (this.onGoMessage) this.onGoMessage(u8);
|
||||
}
|
||||
close(code, reason) {
|
||||
if (this.readyState >= 2) return;
|
||||
this.readyState = 2;
|
||||
if (this.onGoClose) this.onGoClose();
|
||||
setTimeout(() => {
|
||||
this.readyState = 3;
|
||||
const ev = new CloseEvent('close', { code: code||1000, reason: reason||'', wasClean: true });
|
||||
if (this.onclose) this.onclose(ev);
|
||||
this.dispatchEvent(ev);
|
||||
}, 0);
|
||||
}
|
||||
}
|
||||
|
||||
async function main() {
|
||||
if (!HOST) { setStatus('Usage: ?host=<peer_ip>&setup_key=<key>[&mode=session&user=alice]'); return; }
|
||||
|
||||
// Install WS proxy before anything creates WebSockets
|
||||
const OrigWS = window.WebSocket;
|
||||
window.WebSocket = new Proxy(OrigWS, {
|
||||
construct(target, args) {
|
||||
if (args[0] && args[0].includes('vnc.proxy.local')) return new VNCProxyWS(args[0]);
|
||||
return new target(args[0], args[1]);
|
||||
}
|
||||
});
|
||||
|
||||
// Load WASM
|
||||
setStatus('Loading WASM runtime...');
|
||||
await new Promise((resolve, reject) => {
|
||||
const s = document.createElement('script');
|
||||
s.src = '/wasm_exec.js'; s.onload = resolve; s.onerror = reject;
|
||||
document.head.appendChild(s);
|
||||
});
|
||||
|
||||
setStatus('Loading NetBird WASM...');
|
||||
const go = new Go();
|
||||
const wasm = await WebAssembly.instantiateStreaming(fetch('/netbird.wasm'), go.importObject);
|
||||
go.run(wasm.instance);
|
||||
const t0 = Date.now();
|
||||
while (!window.NetBirdClient && Date.now() - t0 < 10000) await new Promise(r => setTimeout(r, 100));
|
||||
if (!window.NetBirdClient) { setStatus('WASM init timeout'); return; }
|
||||
addLog('WASM ready');
|
||||
|
||||
// Connect NetBird with setup key
|
||||
setStatus('Connecting NetBird...');
|
||||
let client;
|
||||
try {
|
||||
client = await window.NetBirdClient({
|
||||
setupKey: SETUP_KEY,
|
||||
managementURL: MGMT_URL,
|
||||
logLevel: 'debug',
|
||||
});
|
||||
addLog('Client created, starting...');
|
||||
await client.start();
|
||||
addLog('NetBird connected');
|
||||
} catch (err) {
|
||||
setStatus('NetBird error: ' + (err && err.message ? err.message : String(err)));
|
||||
return;
|
||||
}
|
||||
|
||||
// Create VNC proxy
|
||||
setStatus(`Creating VNC proxy (mode=${MODE}${USER ? ', user=' + USER : ''})...`);
|
||||
const proxyURL = await client.createVNCProxy(HOST, PORT, MODE, USER);
|
||||
addLog(`Proxy: ${proxyURL}`);
|
||||
|
||||
// Connect noVNC
|
||||
setStatus('Connecting VNC...');
|
||||
const { default: RFB } = await import('/novnc-pkg/core/rfb.js');
|
||||
const container = document.getElementById('vnc-container');
|
||||
rfbInstance = new RFB(container, proxyURL, { wsProtocols: [] });
|
||||
rfbInstance.scaleViewport = true;
|
||||
rfbInstance.resizeSession = false;
|
||||
rfbInstance.showDotCursor = true;
|
||||
rfbInstance.addEventListener('connect', () => setStatus(`Connected: ${HOST}`));
|
||||
rfbInstance.addEventListener('disconnect', e => setStatus(`Disconnected${e.detail?.clean ? '' : ' (unexpected)'}`));
|
||||
rfbInstance.addEventListener('credentialsrequired', () => rfbInstance.sendCredentials({ password: '' }));
|
||||
window.rfb = rfbInstance;
|
||||
}
|
||||
|
||||
main().catch(err => { setStatus(`Error: ${err.message}`); console.error(err); });
|
||||
</script>
|
||||
</body>
|
||||
</html>
|
||||
44
client/vnc/testpage/serve.go
Normal file
44
client/vnc/testpage/serve.go
Normal file
@@ -0,0 +1,44 @@
|
||||
//go:build ignore
|
||||
|
||||
// Simple file server for the VNC test page.
|
||||
// Usage: go run serve.go
|
||||
// Then open: http://localhost:9090?host=100.0.23.250
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
)
|
||||
|
||||
func main() {
|
||||
// Serve from the dashboard's public dir (has wasm, noVNC, etc.)
|
||||
dashboardPublic := os.Getenv("DASHBOARD_PUBLIC")
|
||||
if dashboardPublic == "" {
|
||||
home, _ := os.UserHomeDir()
|
||||
dashboardPublic = filepath.Join(home, "dev", "dashboard", "public")
|
||||
}
|
||||
|
||||
// Serve test page index.html from this directory
|
||||
testDir, _ := os.Getwd()
|
||||
|
||||
mux := http.NewServeMux()
|
||||
// Test page itself
|
||||
mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path == "/" || r.URL.Path == "/index.html" {
|
||||
http.ServeFile(w, r, filepath.Join(testDir, "index.html"))
|
||||
return
|
||||
}
|
||||
// Everything else from dashboard public (wasm, noVNC, etc.)
|
||||
http.FileServer(http.Dir(dashboardPublic)).ServeHTTP(w, r)
|
||||
})
|
||||
|
||||
addr := ":9090"
|
||||
fmt.Printf("VNC test page: http://localhost%s?host=<peer_ip>\n", addr)
|
||||
fmt.Printf("Serving assets from: %s\n", dashboardPublic)
|
||||
if err := http.ListenAndServe(addr, mux); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "listen: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
||||
@@ -15,8 +15,8 @@ import (
|
||||
sshdetection "github.com/netbirdio/netbird/client/ssh/detection"
|
||||
nbstatus "github.com/netbirdio/netbird/client/status"
|
||||
"github.com/netbirdio/netbird/client/wasm/internal/http"
|
||||
"github.com/netbirdio/netbird/client/wasm/internal/rdp"
|
||||
"github.com/netbirdio/netbird/client/wasm/internal/ssh"
|
||||
"github.com/netbirdio/netbird/client/wasm/internal/vnc"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
)
|
||||
|
||||
@@ -317,8 +317,13 @@ func createProxyRequestMethod(client *netbird.Client) js.Func {
|
||||
})
|
||||
}
|
||||
|
||||
// createRDPProxyMethod creates the RDP proxy method
|
||||
func createRDPProxyMethod(client *netbird.Client) js.Func {
|
||||
// createVNCProxyMethod creates the VNC proxy method for raw TCP-over-WebSocket bridging.
|
||||
// JS signature: createVNCProxy(hostname, port, mode?, username?, jwt?, sessionID?)
|
||||
// mode: "attach" (default) or "session"
|
||||
// username: required when mode is "session"
|
||||
// jwt: authentication token (from OIDC session)
|
||||
// sessionID: Windows session ID (0 = console/auto)
|
||||
func createVNCProxyMethod(client *netbird.Client) js.Func {
|
||||
return js.FuncOf(func(_ js.Value, args []js.Value) any {
|
||||
if len(args) < 2 {
|
||||
return js.ValueOf("error: hostname and port required")
|
||||
@@ -335,8 +340,25 @@ func createRDPProxyMethod(client *netbird.Client) js.Func {
|
||||
})
|
||||
}
|
||||
|
||||
proxy := rdp.NewRDCleanPathProxy(client)
|
||||
return proxy.CreateProxy(args[0].String(), args[1].String())
|
||||
mode := "attach"
|
||||
username := ""
|
||||
jwtToken := ""
|
||||
var sessionID uint32
|
||||
if len(args) > 2 && args[2].Type() == js.TypeString {
|
||||
mode = args[2].String()
|
||||
}
|
||||
if len(args) > 3 && args[3].Type() == js.TypeString {
|
||||
username = args[3].String()
|
||||
}
|
||||
if len(args) > 4 && args[4].Type() == js.TypeString {
|
||||
jwtToken = args[4].String()
|
||||
}
|
||||
if len(args) > 5 && args[5].Type() == js.TypeNumber {
|
||||
sessionID = uint32(args[5].Int())
|
||||
}
|
||||
|
||||
proxy := vnc.NewVNCProxy(client)
|
||||
return proxy.CreateProxy(args[0].String(), args[1].String(), mode, username, jwtToken, sessionID)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -515,7 +537,7 @@ func createClientObject(client *netbird.Client) js.Value {
|
||||
obj["detectSSHServerType"] = createDetectSSHServerMethod(client)
|
||||
obj["createSSHConnection"] = createSSHMethod(client)
|
||||
obj["proxyRequest"] = createProxyRequestMethod(client)
|
||||
obj["createRDPProxy"] = createRDPProxyMethod(client)
|
||||
obj["createVNCProxy"] = createVNCProxyMethod(client)
|
||||
obj["status"] = createStatusMethod(client)
|
||||
obj["statusSummary"] = createStatusSummaryMethod(client)
|
||||
obj["statusDetail"] = createStatusDetailMethod(client)
|
||||
|
||||
@@ -1,107 +0,0 @@
|
||||
//go:build js
|
||||
|
||||
package rdp
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"fmt"
|
||||
"syscall/js"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
certValidationTimeout = 60 * time.Second
|
||||
)
|
||||
|
||||
func (p *RDCleanPathProxy) validateCertificateWithJS(conn *proxyConnection, certChain [][]byte) (bool, error) {
|
||||
if !conn.wsHandlers.Get("onCertificateRequest").Truthy() {
|
||||
return false, fmt.Errorf("certificate validation handler not configured")
|
||||
}
|
||||
|
||||
certInfo := js.Global().Get("Object").New()
|
||||
certInfo.Set("ServerAddr", conn.destination)
|
||||
|
||||
certArray := js.Global().Get("Array").New()
|
||||
for i, certBytes := range certChain {
|
||||
uint8Array := js.Global().Get("Uint8Array").New(len(certBytes))
|
||||
js.CopyBytesToJS(uint8Array, certBytes)
|
||||
certArray.SetIndex(i, uint8Array)
|
||||
}
|
||||
certInfo.Set("ServerCertChain", certArray)
|
||||
if len(certChain) > 0 {
|
||||
cert, err := x509.ParseCertificate(certChain[0])
|
||||
if err == nil {
|
||||
info := js.Global().Get("Object").New()
|
||||
info.Set("subject", cert.Subject.String())
|
||||
info.Set("issuer", cert.Issuer.String())
|
||||
info.Set("validFrom", cert.NotBefore.Format(time.RFC3339))
|
||||
info.Set("validTo", cert.NotAfter.Format(time.RFC3339))
|
||||
info.Set("serialNumber", cert.SerialNumber.String())
|
||||
certInfo.Set("CertificateInfo", info)
|
||||
}
|
||||
}
|
||||
|
||||
promise := conn.wsHandlers.Call("onCertificateRequest", certInfo)
|
||||
|
||||
resultChan := make(chan bool)
|
||||
errorChan := make(chan error)
|
||||
|
||||
promise.Call("then", js.FuncOf(func(this js.Value, args []js.Value) interface{} {
|
||||
result := args[0].Bool()
|
||||
resultChan <- result
|
||||
return nil
|
||||
})).Call("catch", js.FuncOf(func(this js.Value, args []js.Value) interface{} {
|
||||
errorChan <- fmt.Errorf("certificate validation failed")
|
||||
return nil
|
||||
}))
|
||||
|
||||
select {
|
||||
case result := <-resultChan:
|
||||
if result {
|
||||
log.Info("Certificate accepted by user")
|
||||
} else {
|
||||
log.Info("Certificate rejected by user")
|
||||
}
|
||||
return result, nil
|
||||
case err := <-errorChan:
|
||||
return false, err
|
||||
case <-time.After(certValidationTimeout):
|
||||
return false, fmt.Errorf("certificate validation timeout")
|
||||
}
|
||||
}
|
||||
|
||||
func (p *RDCleanPathProxy) getTLSConfigWithValidation(conn *proxyConnection, requiresCredSSP bool) *tls.Config {
|
||||
config := &tls.Config{
|
||||
InsecureSkipVerify: true, // We'll validate manually after handshake
|
||||
VerifyConnection: func(cs tls.ConnectionState) error {
|
||||
var certChain [][]byte
|
||||
for _, cert := range cs.PeerCertificates {
|
||||
certChain = append(certChain, cert.Raw)
|
||||
}
|
||||
|
||||
accepted, err := p.validateCertificateWithJS(conn, certChain)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !accepted {
|
||||
return fmt.Errorf("certificate rejected by user")
|
||||
}
|
||||
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
// CredSSP (NLA) requires TLS 1.2 - it's incompatible with TLS 1.3
|
||||
if requiresCredSSP {
|
||||
config.MinVersion = tls.VersionTLS12
|
||||
config.MaxVersion = tls.VersionTLS12
|
||||
} else {
|
||||
config.MinVersion = tls.VersionTLS12
|
||||
config.MaxVersion = tls.VersionTLS13
|
||||
}
|
||||
|
||||
return config
|
||||
}
|
||||
@@ -1,344 +0,0 @@
|
||||
//go:build js
|
||||
|
||||
package rdp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"encoding/asn1"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"sync"
|
||||
"syscall/js"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
RDCleanPathVersion = 3390
|
||||
RDCleanPathProxyHost = "rdcleanpath.proxy.local"
|
||||
RDCleanPathProxyScheme = "ws"
|
||||
|
||||
rdpDialTimeout = 15 * time.Second
|
||||
|
||||
GeneralErrorCode = 1
|
||||
WSAETimedOut = 10060
|
||||
WSAEConnRefused = 10061
|
||||
WSAEConnAborted = 10053
|
||||
WSAEConnReset = 10054
|
||||
WSAEGenericError = 10050
|
||||
)
|
||||
|
||||
type RDCleanPathPDU struct {
|
||||
Version int64 `asn1:"tag:0,explicit"`
|
||||
Error RDCleanPathErr `asn1:"tag:1,explicit,optional"`
|
||||
Destination string `asn1:"utf8,tag:2,explicit,optional"`
|
||||
ProxyAuth string `asn1:"utf8,tag:3,explicit,optional"`
|
||||
ServerAuth string `asn1:"utf8,tag:4,explicit,optional"`
|
||||
PreconnectionBlob string `asn1:"utf8,tag:5,explicit,optional"`
|
||||
X224ConnectionPDU []byte `asn1:"tag:6,explicit,optional"`
|
||||
ServerCertChain [][]byte `asn1:"tag:7,explicit,optional"`
|
||||
ServerAddr string `asn1:"utf8,tag:9,explicit,optional"`
|
||||
}
|
||||
|
||||
type RDCleanPathErr struct {
|
||||
ErrorCode int16 `asn1:"tag:0,explicit"`
|
||||
HTTPStatusCode int16 `asn1:"tag:1,explicit,optional"`
|
||||
WSALastError int16 `asn1:"tag:2,explicit,optional"`
|
||||
TLSAlertCode int8 `asn1:"tag:3,explicit,optional"`
|
||||
}
|
||||
|
||||
type RDCleanPathProxy struct {
|
||||
nbClient interface {
|
||||
Dial(ctx context.Context, network, address string) (net.Conn, error)
|
||||
}
|
||||
activeConnections map[string]*proxyConnection
|
||||
destinations map[string]string
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
type proxyConnection struct {
|
||||
id string
|
||||
destination string
|
||||
rdpConn net.Conn
|
||||
tlsConn *tls.Conn
|
||||
wsHandlers js.Value
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
}
|
||||
|
||||
// NewRDCleanPathProxy creates a new RDCleanPath proxy
|
||||
func NewRDCleanPathProxy(client interface {
|
||||
Dial(ctx context.Context, network, address string) (net.Conn, error)
|
||||
}) *RDCleanPathProxy {
|
||||
return &RDCleanPathProxy{
|
||||
nbClient: client,
|
||||
activeConnections: make(map[string]*proxyConnection),
|
||||
}
|
||||
}
|
||||
|
||||
// CreateProxy creates a new proxy endpoint for the given destination
|
||||
func (p *RDCleanPathProxy) CreateProxy(hostname, port string) js.Value {
|
||||
destination := fmt.Sprintf("%s:%s", hostname, port)
|
||||
|
||||
return js.Global().Get("Promise").New(js.FuncOf(func(_ js.Value, args []js.Value) any {
|
||||
resolve := args[0]
|
||||
|
||||
go func() {
|
||||
proxyID := fmt.Sprintf("proxy_%d", len(p.activeConnections))
|
||||
|
||||
p.mu.Lock()
|
||||
if p.destinations == nil {
|
||||
p.destinations = make(map[string]string)
|
||||
}
|
||||
p.destinations[proxyID] = destination
|
||||
p.mu.Unlock()
|
||||
|
||||
proxyURL := fmt.Sprintf("%s://%s/%s", RDCleanPathProxyScheme, RDCleanPathProxyHost, proxyID)
|
||||
|
||||
// Register the WebSocket handler for this specific proxy
|
||||
js.Global().Set(fmt.Sprintf("handleRDCleanPathWebSocket_%s", proxyID), js.FuncOf(func(_ js.Value, args []js.Value) any {
|
||||
if len(args) < 1 {
|
||||
return js.ValueOf("error: requires WebSocket argument")
|
||||
}
|
||||
|
||||
ws := args[0]
|
||||
p.HandleWebSocketConnection(ws, proxyID)
|
||||
return nil
|
||||
}))
|
||||
|
||||
log.Infof("Created RDCleanPath proxy endpoint: %s for destination: %s", proxyURL, destination)
|
||||
resolve.Invoke(proxyURL)
|
||||
}()
|
||||
|
||||
return nil
|
||||
}))
|
||||
}
|
||||
|
||||
// HandleWebSocketConnection handles incoming WebSocket connections from IronRDP
|
||||
func (p *RDCleanPathProxy) HandleWebSocketConnection(ws js.Value, proxyID string) {
|
||||
p.mu.Lock()
|
||||
destination := p.destinations[proxyID]
|
||||
p.mu.Unlock()
|
||||
|
||||
if destination == "" {
|
||||
log.Errorf("No destination found for proxy ID: %s", proxyID)
|
||||
return
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
// Don't defer cancel here - it will be called by cleanupConnection
|
||||
|
||||
conn := &proxyConnection{
|
||||
id: proxyID,
|
||||
destination: destination,
|
||||
wsHandlers: ws,
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
}
|
||||
|
||||
p.mu.Lock()
|
||||
p.activeConnections[proxyID] = conn
|
||||
p.mu.Unlock()
|
||||
|
||||
p.setupWebSocketHandlers(ws, conn)
|
||||
|
||||
log.Infof("RDCleanPath proxy WebSocket connection established for %s", proxyID)
|
||||
}
|
||||
|
||||
func (p *RDCleanPathProxy) setupWebSocketHandlers(ws js.Value, conn *proxyConnection) {
|
||||
ws.Set("onGoMessage", js.FuncOf(func(this js.Value, args []js.Value) any {
|
||||
if len(args) < 1 {
|
||||
return nil
|
||||
}
|
||||
|
||||
data := args[0]
|
||||
go p.handleWebSocketMessage(conn, data)
|
||||
return nil
|
||||
}))
|
||||
|
||||
ws.Set("onGoClose", js.FuncOf(func(_ js.Value, args []js.Value) any {
|
||||
log.Debug("WebSocket closed by JavaScript")
|
||||
conn.cancel()
|
||||
return nil
|
||||
}))
|
||||
}
|
||||
|
||||
func (p *RDCleanPathProxy) handleWebSocketMessage(conn *proxyConnection, data js.Value) {
|
||||
if !data.InstanceOf(js.Global().Get("Uint8Array")) {
|
||||
return
|
||||
}
|
||||
|
||||
length := data.Get("length").Int()
|
||||
bytes := make([]byte, length)
|
||||
js.CopyBytesToGo(bytes, data)
|
||||
|
||||
if conn.rdpConn != nil || conn.tlsConn != nil {
|
||||
p.forwardToRDP(conn, bytes)
|
||||
return
|
||||
}
|
||||
|
||||
var pdu RDCleanPathPDU
|
||||
_, err := asn1.Unmarshal(bytes, &pdu)
|
||||
if err != nil {
|
||||
log.Warnf("Failed to parse RDCleanPath PDU: %v", err)
|
||||
n := len(bytes)
|
||||
if n > 20 {
|
||||
n = 20
|
||||
}
|
||||
log.Warnf("First %d bytes: %x", n, bytes[:n])
|
||||
|
||||
if len(bytes) > 0 && bytes[0] == 0x03 {
|
||||
log.Debug("Received raw RDP packet instead of RDCleanPath PDU")
|
||||
go p.handleDirectRDP(conn, bytes)
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
go p.processRDCleanPathPDU(conn, pdu)
|
||||
}
|
||||
|
||||
func (p *RDCleanPathProxy) forwardToRDP(conn *proxyConnection, bytes []byte) {
|
||||
var writer io.Writer
|
||||
var connType string
|
||||
|
||||
if conn.tlsConn != nil {
|
||||
writer = conn.tlsConn
|
||||
connType = "TLS"
|
||||
} else if conn.rdpConn != nil {
|
||||
writer = conn.rdpConn
|
||||
connType = "TCP"
|
||||
} else {
|
||||
log.Error("No RDP connection available")
|
||||
return
|
||||
}
|
||||
|
||||
if _, err := writer.Write(bytes); err != nil {
|
||||
log.Errorf("Failed to write to %s: %v", connType, err)
|
||||
}
|
||||
}
|
||||
|
||||
func (p *RDCleanPathProxy) handleDirectRDP(conn *proxyConnection, firstPacket []byte) {
|
||||
defer p.cleanupConnection(conn)
|
||||
|
||||
destination := conn.destination
|
||||
log.Infof("Direct RDP mode: Connecting to %s via NetBird", destination)
|
||||
|
||||
ctx, cancel := context.WithTimeout(conn.ctx, rdpDialTimeout)
|
||||
defer cancel()
|
||||
|
||||
rdpConn, err := p.nbClient.Dial(ctx, "tcp", destination)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to connect to %s: %v", destination, err)
|
||||
p.sendRDCleanPathError(conn, newWSAError(err))
|
||||
return
|
||||
}
|
||||
conn.rdpConn = rdpConn
|
||||
|
||||
_, err = rdpConn.Write(firstPacket)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to write first packet: %v", err)
|
||||
p.sendRDCleanPathError(conn, newWSAError(err))
|
||||
return
|
||||
}
|
||||
|
||||
response := make([]byte, 1024)
|
||||
n, err := rdpConn.Read(response)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to read X.224 response: %v", err)
|
||||
p.sendRDCleanPathError(conn, newWSAError(err))
|
||||
return
|
||||
}
|
||||
|
||||
p.sendToWebSocket(conn, response[:n])
|
||||
|
||||
go p.forwardWSToConn(conn, conn.rdpConn, "TCP")
|
||||
go p.forwardConnToWS(conn, conn.rdpConn, "TCP")
|
||||
}
|
||||
|
||||
func (p *RDCleanPathProxy) cleanupConnection(conn *proxyConnection) {
|
||||
log.Debugf("Cleaning up connection %s", conn.id)
|
||||
conn.cancel()
|
||||
if conn.tlsConn != nil {
|
||||
log.Debug("Closing TLS connection")
|
||||
if err := conn.tlsConn.Close(); err != nil {
|
||||
log.Debugf("Error closing TLS connection: %v", err)
|
||||
}
|
||||
conn.tlsConn = nil
|
||||
}
|
||||
if conn.rdpConn != nil {
|
||||
log.Debug("Closing TCP connection")
|
||||
if err := conn.rdpConn.Close(); err != nil {
|
||||
log.Debugf("Error closing TCP connection: %v", err)
|
||||
}
|
||||
conn.rdpConn = nil
|
||||
}
|
||||
p.mu.Lock()
|
||||
delete(p.activeConnections, conn.id)
|
||||
p.mu.Unlock()
|
||||
}
|
||||
|
||||
func (p *RDCleanPathProxy) sendToWebSocket(conn *proxyConnection, data []byte) {
|
||||
if conn.wsHandlers.Get("receiveFromGo").Truthy() {
|
||||
uint8Array := js.Global().Get("Uint8Array").New(len(data))
|
||||
js.CopyBytesToJS(uint8Array, data)
|
||||
conn.wsHandlers.Call("receiveFromGo", uint8Array.Get("buffer"))
|
||||
} else if conn.wsHandlers.Get("send").Truthy() {
|
||||
uint8Array := js.Global().Get("Uint8Array").New(len(data))
|
||||
js.CopyBytesToJS(uint8Array, data)
|
||||
conn.wsHandlers.Call("send", uint8Array.Get("buffer"))
|
||||
}
|
||||
}
|
||||
|
||||
func (p *RDCleanPathProxy) sendRDCleanPathError(conn *proxyConnection, pdu RDCleanPathPDU) {
|
||||
data, err := asn1.Marshal(pdu)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to marshal error PDU: %v", err)
|
||||
return
|
||||
}
|
||||
p.sendToWebSocket(conn, data)
|
||||
}
|
||||
|
||||
func errorToWSACode(err error) int16 {
|
||||
if err == nil {
|
||||
return WSAEGenericError
|
||||
}
|
||||
var netErr *net.OpError
|
||||
if errors.As(err, &netErr) && netErr.Timeout() {
|
||||
return WSAETimedOut
|
||||
}
|
||||
if errors.Is(err, context.DeadlineExceeded) {
|
||||
return WSAETimedOut
|
||||
}
|
||||
if errors.Is(err, context.Canceled) {
|
||||
return WSAEConnAborted
|
||||
}
|
||||
if errors.Is(err, io.EOF) {
|
||||
return WSAEConnReset
|
||||
}
|
||||
return WSAEGenericError
|
||||
}
|
||||
|
||||
func newWSAError(err error) RDCleanPathPDU {
|
||||
return RDCleanPathPDU{
|
||||
Version: RDCleanPathVersion,
|
||||
Error: RDCleanPathErr{
|
||||
ErrorCode: GeneralErrorCode,
|
||||
WSALastError: errorToWSACode(err),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func newHTTPError(statusCode int16) RDCleanPathPDU {
|
||||
return RDCleanPathPDU{
|
||||
Version: RDCleanPathVersion,
|
||||
Error: RDCleanPathErr{
|
||||
ErrorCode: GeneralErrorCode,
|
||||
HTTPStatusCode: statusCode,
|
||||
},
|
||||
}
|
||||
}
|
||||
@@ -1,244 +0,0 @@
|
||||
//go:build js
|
||||
|
||||
package rdp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"encoding/asn1"
|
||||
"io"
|
||||
"syscall/js"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
// MS-RDPBCGR: confusingly named, actually means PROTOCOL_HYBRID (CredSSP)
|
||||
protocolSSL = 0x00000001
|
||||
protocolHybridEx = 0x00000008
|
||||
)
|
||||
|
||||
func (p *RDCleanPathProxy) processRDCleanPathPDU(conn *proxyConnection, pdu RDCleanPathPDU) {
|
||||
log.Infof("Processing RDCleanPath PDU: Version=%d, Destination=%s", pdu.Version, pdu.Destination)
|
||||
|
||||
if pdu.Version != RDCleanPathVersion {
|
||||
p.sendRDCleanPathError(conn, newHTTPError(400))
|
||||
return
|
||||
}
|
||||
|
||||
destination := conn.destination
|
||||
if pdu.Destination != "" {
|
||||
destination = pdu.Destination
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(conn.ctx, rdpDialTimeout)
|
||||
defer cancel()
|
||||
|
||||
rdpConn, err := p.nbClient.Dial(ctx, "tcp", destination)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to connect to %s: %v", destination, err)
|
||||
p.sendRDCleanPathError(conn, newWSAError(err))
|
||||
p.cleanupConnection(conn)
|
||||
return
|
||||
}
|
||||
conn.rdpConn = rdpConn
|
||||
|
||||
// RDP always starts with X.224 negotiation, then determines if TLS is needed
|
||||
// Modern RDP (since Windows Vista/2008) typically requires TLS
|
||||
// The X.224 Connection Confirm response will indicate if TLS is required
|
||||
// For now, we'll attempt TLS for all connections as it's the modern default
|
||||
p.setupTLSConnection(conn, pdu)
|
||||
}
|
||||
|
||||
// detectCredSSPFromX224 checks if the X.224 response indicates NLA/CredSSP is required.
|
||||
// Per MS-RDPBCGR spec: byte 11 = TYPE_RDP_NEG_RSP (0x02), bytes 15-18 = selectedProtocol flags.
|
||||
// Returns (requiresTLS12, selectedProtocol, detectionSuccessful).
|
||||
func (p *RDCleanPathProxy) detectCredSSPFromX224(x224Response []byte) (bool, uint32, bool) {
|
||||
const minResponseLength = 19
|
||||
|
||||
if len(x224Response) < minResponseLength {
|
||||
return false, 0, false
|
||||
}
|
||||
|
||||
// Per X.224 specification:
|
||||
// x224Response[0] == 0x03: Length of X.224 header (3 bytes)
|
||||
// x224Response[5] == 0xD0: X.224 Data TPDU code
|
||||
if x224Response[0] != 0x03 || x224Response[5] != 0xD0 {
|
||||
return false, 0, false
|
||||
}
|
||||
|
||||
if x224Response[11] == 0x02 {
|
||||
flags := uint32(x224Response[15]) | uint32(x224Response[16])<<8 |
|
||||
uint32(x224Response[17])<<16 | uint32(x224Response[18])<<24
|
||||
|
||||
hasNLA := (flags & (protocolSSL | protocolHybridEx)) != 0
|
||||
return hasNLA, flags, true
|
||||
}
|
||||
|
||||
return false, 0, false
|
||||
}
|
||||
|
||||
func (p *RDCleanPathProxy) setupTLSConnection(conn *proxyConnection, pdu RDCleanPathPDU) {
|
||||
var x224Response []byte
|
||||
if len(pdu.X224ConnectionPDU) > 0 {
|
||||
log.Debugf("Forwarding X.224 Connection Request (%d bytes)", len(pdu.X224ConnectionPDU))
|
||||
_, err := conn.rdpConn.Write(pdu.X224ConnectionPDU)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to write X.224 PDU: %v", err)
|
||||
p.sendRDCleanPathError(conn, newWSAError(err))
|
||||
return
|
||||
}
|
||||
|
||||
response := make([]byte, 1024)
|
||||
n, err := conn.rdpConn.Read(response)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to read X.224 response: %v", err)
|
||||
p.sendRDCleanPathError(conn, newWSAError(err))
|
||||
return
|
||||
}
|
||||
x224Response = response[:n]
|
||||
log.Debugf("Received X.224 Connection Confirm (%d bytes)", n)
|
||||
}
|
||||
|
||||
requiresCredSSP, selectedProtocol, detected := p.detectCredSSPFromX224(x224Response)
|
||||
if detected {
|
||||
if requiresCredSSP {
|
||||
log.Warnf("Detected NLA/CredSSP (selectedProtocol: 0x%08X), forcing TLS 1.2 for compatibility", selectedProtocol)
|
||||
} else {
|
||||
log.Warnf("No NLA/CredSSP detected (selectedProtocol: 0x%08X), allowing up to TLS 1.3", selectedProtocol)
|
||||
}
|
||||
} else {
|
||||
log.Warnf("Could not detect RDP security protocol, allowing up to TLS 1.3")
|
||||
}
|
||||
|
||||
tlsConfig := p.getTLSConfigWithValidation(conn, requiresCredSSP)
|
||||
|
||||
tlsConn := tls.Client(conn.rdpConn, tlsConfig)
|
||||
conn.tlsConn = tlsConn
|
||||
|
||||
if err := tlsConn.Handshake(); err != nil {
|
||||
log.Errorf("TLS handshake failed: %v", err)
|
||||
p.sendRDCleanPathError(conn, newWSAError(err))
|
||||
return
|
||||
}
|
||||
|
||||
log.Info("TLS handshake successful")
|
||||
|
||||
// Certificate validation happens during handshake via VerifyConnection callback
|
||||
var certChain [][]byte
|
||||
connState := tlsConn.ConnectionState()
|
||||
if len(connState.PeerCertificates) > 0 {
|
||||
for _, cert := range connState.PeerCertificates {
|
||||
certChain = append(certChain, cert.Raw)
|
||||
}
|
||||
log.Debugf("Extracted %d certificates from TLS connection", len(certChain))
|
||||
}
|
||||
|
||||
responsePDU := RDCleanPathPDU{
|
||||
Version: RDCleanPathVersion,
|
||||
ServerAddr: conn.destination,
|
||||
ServerCertChain: certChain,
|
||||
}
|
||||
|
||||
if len(x224Response) > 0 {
|
||||
responsePDU.X224ConnectionPDU = x224Response
|
||||
}
|
||||
|
||||
p.sendRDCleanPathPDU(conn, responsePDU)
|
||||
|
||||
log.Debug("Starting TLS forwarding")
|
||||
go p.forwardConnToWS(conn, conn.tlsConn, "TLS")
|
||||
go p.forwardWSToConn(conn, conn.tlsConn, "TLS")
|
||||
|
||||
<-conn.ctx.Done()
|
||||
log.Debug("TLS connection context done, cleaning up")
|
||||
p.cleanupConnection(conn)
|
||||
}
|
||||
|
||||
func (p *RDCleanPathProxy) sendRDCleanPathPDU(conn *proxyConnection, pdu RDCleanPathPDU) {
|
||||
data, err := asn1.Marshal(pdu)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to marshal RDCleanPath PDU: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
log.Debugf("Sending RDCleanPath PDU response (%d bytes)", len(data))
|
||||
p.sendToWebSocket(conn, data)
|
||||
}
|
||||
|
||||
func (p *RDCleanPathProxy) readWebSocketMessage(conn *proxyConnection) ([]byte, error) {
|
||||
msgChan := make(chan []byte)
|
||||
errChan := make(chan error)
|
||||
|
||||
handler := js.FuncOf(func(this js.Value, args []js.Value) interface{} {
|
||||
if len(args) < 1 {
|
||||
errChan <- io.EOF
|
||||
return nil
|
||||
}
|
||||
|
||||
data := args[0]
|
||||
if data.InstanceOf(js.Global().Get("Uint8Array")) {
|
||||
length := data.Get("length").Int()
|
||||
bytes := make([]byte, length)
|
||||
js.CopyBytesToGo(bytes, data)
|
||||
msgChan <- bytes
|
||||
}
|
||||
return nil
|
||||
})
|
||||
defer handler.Release()
|
||||
|
||||
conn.wsHandlers.Set("onceGoMessage", handler)
|
||||
|
||||
select {
|
||||
case msg := <-msgChan:
|
||||
return msg, nil
|
||||
case err := <-errChan:
|
||||
return nil, err
|
||||
case <-conn.ctx.Done():
|
||||
return nil, conn.ctx.Err()
|
||||
}
|
||||
}
|
||||
|
||||
func (p *RDCleanPathProxy) forwardWSToConn(conn *proxyConnection, dst io.Writer, connType string) {
|
||||
for {
|
||||
if conn.ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
|
||||
msg, err := p.readWebSocketMessage(conn)
|
||||
if err != nil {
|
||||
if err != io.EOF {
|
||||
log.Errorf("Failed to read from WebSocket: %v", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
_, err = dst.Write(msg)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to write to %s: %v", connType, err)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (p *RDCleanPathProxy) forwardConnToWS(conn *proxyConnection, src io.Reader, connType string) {
|
||||
buffer := make([]byte, 32*1024)
|
||||
|
||||
for {
|
||||
if conn.ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
|
||||
n, err := src.Read(buffer)
|
||||
if err != nil {
|
||||
if err != io.EOF {
|
||||
log.Errorf("Failed to read from %s: %v", connType, err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if n > 0 {
|
||||
p.sendToWebSocket(conn, buffer[:n])
|
||||
}
|
||||
}
|
||||
}
|
||||
332
client/wasm/internal/vnc/proxy.go
Normal file
332
client/wasm/internal/vnc/proxy.go
Normal file
@@ -0,0 +1,332 @@
|
||||
//go:build js
|
||||
|
||||
package vnc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"syscall/js"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
vncProxyHost = "vnc.proxy.local"
|
||||
vncProxyScheme = "ws"
|
||||
vncDialTimeout = 15 * time.Second
|
||||
|
||||
// Connection modes matching server/server.go constants.
|
||||
modeAttach byte = 0
|
||||
modeSession byte = 1
|
||||
)
|
||||
|
||||
// VNCProxy bridges WebSocket connections from noVNC in the browser
|
||||
// to TCP VNC server connections through the NetBird tunnel.
|
||||
type VNCProxy struct {
|
||||
nbClient interface {
|
||||
Dial(ctx context.Context, network, address string) (net.Conn, error)
|
||||
}
|
||||
activeConnections map[string]*vncConnection
|
||||
destinations map[string]vncDestination
|
||||
mu sync.Mutex
|
||||
nextID atomic.Uint64
|
||||
}
|
||||
|
||||
type vncDestination struct {
|
||||
address string
|
||||
mode byte
|
||||
username string
|
||||
jwt string
|
||||
sessionID uint32 // Windows session ID (0 = auto/console)
|
||||
}
|
||||
|
||||
type vncConnection struct {
|
||||
id string
|
||||
destination vncDestination
|
||||
mu sync.Mutex
|
||||
vncConn net.Conn
|
||||
wsHandlers js.Value
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
}
|
||||
|
||||
// NewVNCProxy creates a new VNC proxy.
|
||||
func NewVNCProxy(client interface {
|
||||
Dial(ctx context.Context, network, address string) (net.Conn, error)
|
||||
}) *VNCProxy {
|
||||
return &VNCProxy{
|
||||
nbClient: client,
|
||||
activeConnections: make(map[string]*vncConnection),
|
||||
}
|
||||
}
|
||||
|
||||
// CreateProxy creates a new proxy endpoint for the given VNC destination.
|
||||
// mode is "attach" (capture current display) or "session" (virtual session).
|
||||
// username is required for session mode.
|
||||
// Returns a JS Promise that resolves to the WebSocket proxy URL.
|
||||
func (p *VNCProxy) CreateProxy(hostname, port, mode, username, jwt string, sessionID uint32) js.Value {
|
||||
address := fmt.Sprintf("%s:%s", hostname, port)
|
||||
|
||||
var m byte
|
||||
if mode == "session" {
|
||||
m = modeSession
|
||||
}
|
||||
|
||||
dest := vncDestination{
|
||||
address: address,
|
||||
mode: m,
|
||||
username: username,
|
||||
jwt: jwt,
|
||||
sessionID: sessionID,
|
||||
}
|
||||
|
||||
return js.Global().Get("Promise").New(js.FuncOf(func(_ js.Value, args []js.Value) any {
|
||||
resolve := args[0]
|
||||
|
||||
go func() {
|
||||
proxyID := fmt.Sprintf("vnc_proxy_%d", p.nextID.Add(1))
|
||||
|
||||
p.mu.Lock()
|
||||
if p.destinations == nil {
|
||||
p.destinations = make(map[string]vncDestination)
|
||||
}
|
||||
p.destinations[proxyID] = dest
|
||||
p.mu.Unlock()
|
||||
|
||||
proxyURL := fmt.Sprintf("%s://%s/%s", vncProxyScheme, vncProxyHost, proxyID)
|
||||
|
||||
js.Global().Set(fmt.Sprintf("handleVNCWebSocket_%s", proxyID), js.FuncOf(func(_ js.Value, args []js.Value) any {
|
||||
if len(args) < 1 {
|
||||
return js.ValueOf("error: requires WebSocket argument")
|
||||
}
|
||||
ws := args[0]
|
||||
p.handleWebSocketConnection(ws, proxyID)
|
||||
return nil
|
||||
}))
|
||||
|
||||
log.Infof("created VNC proxy: %s -> %s (mode=%s, user=%s)", proxyURL, address, mode, username)
|
||||
resolve.Invoke(proxyURL)
|
||||
}()
|
||||
|
||||
return nil
|
||||
}))
|
||||
}
|
||||
|
||||
func (p *VNCProxy) handleWebSocketConnection(ws js.Value, proxyID string) {
|
||||
p.mu.Lock()
|
||||
dest, ok := p.destinations[proxyID]
|
||||
p.mu.Unlock()
|
||||
|
||||
if !ok {
|
||||
log.Errorf("no destination for VNC proxy %s", proxyID)
|
||||
return
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
conn := &vncConnection{
|
||||
id: proxyID,
|
||||
destination: dest,
|
||||
wsHandlers: ws,
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
}
|
||||
|
||||
p.mu.Lock()
|
||||
p.activeConnections[proxyID] = conn
|
||||
p.mu.Unlock()
|
||||
|
||||
p.setupWebSocketHandlers(ws, conn)
|
||||
go p.connectToVNC(conn)
|
||||
|
||||
log.Infof("VNC proxy WebSocket connection established for %s", proxyID)
|
||||
}
|
||||
|
||||
func (p *VNCProxy) setupWebSocketHandlers(ws js.Value, conn *vncConnection) {
|
||||
ws.Set("onGoMessage", js.FuncOf(func(_ js.Value, args []js.Value) any {
|
||||
if len(args) < 1 {
|
||||
return nil
|
||||
}
|
||||
data := args[0]
|
||||
go p.handleWebSocketMessage(conn, data)
|
||||
return nil
|
||||
}))
|
||||
|
||||
ws.Set("onGoClose", js.FuncOf(func(_ js.Value, _ []js.Value) any {
|
||||
log.Debug("VNC WebSocket closed by JavaScript")
|
||||
conn.cancel()
|
||||
return nil
|
||||
}))
|
||||
}
|
||||
|
||||
func (p *VNCProxy) handleWebSocketMessage(conn *vncConnection, data js.Value) {
|
||||
if !data.InstanceOf(js.Global().Get("Uint8Array")) {
|
||||
return
|
||||
}
|
||||
|
||||
length := data.Get("length").Int()
|
||||
buf := make([]byte, length)
|
||||
js.CopyBytesToGo(buf, data)
|
||||
|
||||
conn.mu.Lock()
|
||||
vncConn := conn.vncConn
|
||||
conn.mu.Unlock()
|
||||
|
||||
if vncConn == nil {
|
||||
return
|
||||
}
|
||||
|
||||
if _, err := vncConn.Write(buf); err != nil {
|
||||
log.Debugf("write to VNC server: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (p *VNCProxy) connectToVNC(conn *vncConnection) {
|
||||
ctx, cancel := context.WithTimeout(conn.ctx, vncDialTimeout)
|
||||
defer cancel()
|
||||
|
||||
vncConn, err := p.nbClient.Dial(ctx, "tcp", conn.destination.address)
|
||||
if err != nil {
|
||||
log.Errorf("VNC connect to %s: %v", conn.destination.address, err)
|
||||
// Close the WebSocket so noVNC fires a disconnect event.
|
||||
if conn.wsHandlers.Get("close").Truthy() {
|
||||
conn.wsHandlers.Call("close", 1006, fmt.Sprintf("connect to peer: %v", err))
|
||||
}
|
||||
p.cleanupConnection(conn)
|
||||
return
|
||||
}
|
||||
conn.mu.Lock()
|
||||
conn.vncConn = vncConn
|
||||
conn.mu.Unlock()
|
||||
|
||||
// Send the NetBird VNC session header before the RFB handshake.
|
||||
if err := p.sendSessionHeader(vncConn, conn.destination); err != nil {
|
||||
log.Errorf("send VNC session header: %v", err)
|
||||
p.cleanupConnection(conn)
|
||||
return
|
||||
}
|
||||
|
||||
// WS→TCP is handled by the onGoMessage handler set in setupWebSocketHandlers,
|
||||
// which writes directly to the VNC connection as data arrives from JS.
|
||||
// Only the TCP→WS direction needs a read loop here.
|
||||
go p.forwardConnToWS(conn)
|
||||
|
||||
<-conn.ctx.Done()
|
||||
p.cleanupConnection(conn)
|
||||
}
|
||||
|
||||
// sendSessionHeader writes mode, username, and JWT to the VNC server.
|
||||
// Format: [mode: 1 byte] [username_len: 2 bytes BE] [username: N bytes]
|
||||
//
|
||||
// [jwt_len: 2 bytes BE] [jwt: N bytes]
|
||||
func (p *VNCProxy) sendSessionHeader(conn net.Conn, dest vncDestination) error {
|
||||
usernameBytes := []byte(dest.username)
|
||||
jwtBytes := []byte(dest.jwt)
|
||||
// Format: [mode:1] [username_len:2] [username:N] [jwt_len:2] [jwt:N] [session_id:4]
|
||||
hdr := make([]byte, 3+len(usernameBytes)+2+len(jwtBytes)+4)
|
||||
hdr[0] = dest.mode
|
||||
hdr[1] = byte(len(usernameBytes) >> 8)
|
||||
hdr[2] = byte(len(usernameBytes))
|
||||
off := 3
|
||||
copy(hdr[off:], usernameBytes)
|
||||
off += len(usernameBytes)
|
||||
hdr[off] = byte(len(jwtBytes) >> 8)
|
||||
hdr[off+1] = byte(len(jwtBytes))
|
||||
off += 2
|
||||
copy(hdr[off:], jwtBytes)
|
||||
off += len(jwtBytes)
|
||||
hdr[off] = byte(dest.sessionID >> 24)
|
||||
hdr[off+1] = byte(dest.sessionID >> 16)
|
||||
hdr[off+2] = byte(dest.sessionID >> 8)
|
||||
hdr[off+3] = byte(dest.sessionID)
|
||||
|
||||
_, err := conn.Write(hdr)
|
||||
return err
|
||||
}
|
||||
|
||||
func (p *VNCProxy) forwardConnToWS(conn *vncConnection) {
|
||||
buf := make([]byte, 32*1024)
|
||||
|
||||
for {
|
||||
if conn.ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Set a read deadline so we detect dead connections instead of
|
||||
// blocking forever when the remote peer dies.
|
||||
conn.mu.Lock()
|
||||
vc := conn.vncConn
|
||||
conn.mu.Unlock()
|
||||
if vc == nil {
|
||||
return
|
||||
}
|
||||
vc.SetReadDeadline(time.Now().Add(30 * time.Second))
|
||||
|
||||
n, err := vc.Read(buf)
|
||||
if err != nil {
|
||||
if conn.ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
if netErr, ok := err.(interface{ Timeout() bool }); ok && netErr.Timeout() {
|
||||
// Read timeout: connection might be stale. Send a ping-like
|
||||
// empty read to check. If the connection is truly dead, the
|
||||
// next iteration will fail too and we'll close.
|
||||
continue
|
||||
}
|
||||
if err != io.EOF {
|
||||
log.Debugf("read from VNC connection: %v", err)
|
||||
}
|
||||
// Close the WebSocket to notify noVNC.
|
||||
if conn.wsHandlers.Get("close").Truthy() {
|
||||
conn.wsHandlers.Call("close", 1006, "VNC connection lost")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if n > 0 {
|
||||
p.sendToWebSocket(conn, buf[:n])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (p *VNCProxy) sendToWebSocket(conn *vncConnection, data []byte) {
|
||||
if conn.wsHandlers.Get("receiveFromGo").Truthy() {
|
||||
uint8Array := js.Global().Get("Uint8Array").New(len(data))
|
||||
js.CopyBytesToJS(uint8Array, data)
|
||||
conn.wsHandlers.Call("receiveFromGo", uint8Array.Get("buffer"))
|
||||
} else if conn.wsHandlers.Get("send").Truthy() {
|
||||
uint8Array := js.Global().Get("Uint8Array").New(len(data))
|
||||
js.CopyBytesToJS(uint8Array, data)
|
||||
conn.wsHandlers.Call("send", uint8Array.Get("buffer"))
|
||||
}
|
||||
}
|
||||
|
||||
func (p *VNCProxy) cleanupConnection(conn *vncConnection) {
|
||||
log.Debugf("cleaning up VNC connection %s", conn.id)
|
||||
conn.cancel()
|
||||
|
||||
conn.mu.Lock()
|
||||
vncConn := conn.vncConn
|
||||
conn.vncConn = nil
|
||||
conn.mu.Unlock()
|
||||
|
||||
if vncConn != nil {
|
||||
if err := vncConn.Close(); err != nil {
|
||||
log.Debugf("close VNC connection: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Remove the global JS handler registered in CreateProxy.
|
||||
globalName := fmt.Sprintf("handleVNCWebSocket_%s", conn.id)
|
||||
js.Global().Delete(globalName)
|
||||
|
||||
p.mu.Lock()
|
||||
delete(p.activeConnections, conn.id)
|
||||
delete(p.destinations, conn.id)
|
||||
p.mu.Unlock()
|
||||
}
|
||||
@@ -14,7 +14,6 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/connectivity"
|
||||
"google.golang.org/grpc/credentials"
|
||||
"google.golang.org/grpc/credentials/insecure"
|
||||
"google.golang.org/grpc/keepalive"
|
||||
@@ -26,11 +25,22 @@ import (
|
||||
"github.com/netbirdio/netbird/util/wsproxy"
|
||||
)
|
||||
|
||||
var ErrClientClosed = errors.New("client is closed")
|
||||
|
||||
// minHealthyDuration is the minimum time a stream must survive before a failure
|
||||
// resets the backoff timer. Streams that fail faster are considered unhealthy and
|
||||
// should not reset backoff, so that MaxElapsedTime can eventually stop retries.
|
||||
const minHealthyDuration = 5 * time.Second
|
||||
|
||||
type GRPCClient struct {
|
||||
realClient proto.FlowServiceClient
|
||||
clientConn *grpc.ClientConn
|
||||
stream proto.FlowService_EventsClient
|
||||
streamMu sync.Mutex
|
||||
target string
|
||||
opts []grpc.DialOption
|
||||
closed bool // prevent creating conn in the middle of the Close
|
||||
receiving bool // prevent concurrent Receive calls
|
||||
mu sync.Mutex // protects clientConn, realClient, stream, closed, and receiving
|
||||
}
|
||||
|
||||
func NewClient(addr, payload, signature string, interval time.Duration) (*GRPCClient, error) {
|
||||
@@ -65,7 +75,8 @@ func NewClient(addr, payload, signature string, interval time.Duration) (*GRPCCl
|
||||
grpc.WithDefaultServiceConfig(`{"healthCheckConfig": {"serviceName": ""}}`),
|
||||
)
|
||||
|
||||
conn, err := grpc.NewClient(fmt.Sprintf("%s:%s", parsedURL.Hostname(), parsedURL.Port()), opts...)
|
||||
target := parsedURL.Host
|
||||
conn, err := grpc.NewClient(target, opts...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("creating new grpc client: %w", err)
|
||||
}
|
||||
@@ -73,30 +84,73 @@ func NewClient(addr, payload, signature string, interval time.Duration) (*GRPCCl
|
||||
return &GRPCClient{
|
||||
realClient: proto.NewFlowServiceClient(conn),
|
||||
clientConn: conn,
|
||||
target: target,
|
||||
opts: opts,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *GRPCClient) Close() error {
|
||||
c.streamMu.Lock()
|
||||
defer c.streamMu.Unlock()
|
||||
|
||||
c.mu.Lock()
|
||||
c.closed = true
|
||||
c.stream = nil
|
||||
if err := c.clientConn.Close(); err != nil && !errors.Is(err, context.Canceled) {
|
||||
conn := c.clientConn
|
||||
c.clientConn = nil
|
||||
c.mu.Unlock()
|
||||
|
||||
if conn == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := conn.Close(); err != nil && !errors.Is(err, context.Canceled) {
|
||||
return fmt.Errorf("close client connection: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *GRPCClient) Send(event *proto.FlowEvent) error {
|
||||
c.mu.Lock()
|
||||
stream := c.stream
|
||||
c.mu.Unlock()
|
||||
|
||||
if stream == nil {
|
||||
return errors.New("stream not initialized")
|
||||
}
|
||||
|
||||
if err := stream.Send(event); err != nil {
|
||||
return fmt.Errorf("send flow event: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *GRPCClient) Receive(ctx context.Context, interval time.Duration, msgHandler func(msg *proto.FlowEventAck) error) error {
|
||||
c.mu.Lock()
|
||||
if c.receiving {
|
||||
c.mu.Unlock()
|
||||
return errors.New("concurrent Receive calls are not supported")
|
||||
}
|
||||
c.receiving = true
|
||||
c.mu.Unlock()
|
||||
defer func() {
|
||||
c.mu.Lock()
|
||||
c.receiving = false
|
||||
c.mu.Unlock()
|
||||
}()
|
||||
|
||||
backOff := defaultBackoff(ctx, interval)
|
||||
operation := func() error {
|
||||
if err := c.establishStreamAndReceive(ctx, msgHandler); err != nil {
|
||||
if s, ok := status.FromError(err); ok && s.Code() == codes.Canceled {
|
||||
return fmt.Errorf("receive: %w: %w", err, context.Canceled)
|
||||
}
|
||||
stream, err := c.establishStream(ctx)
|
||||
if err != nil {
|
||||
log.Errorf("failed to establish flow stream, retrying: %v", err)
|
||||
return c.handleRetryableError(err, time.Time{}, backOff)
|
||||
}
|
||||
|
||||
streamStart := time.Now()
|
||||
|
||||
if err := c.receive(stream, msgHandler); err != nil {
|
||||
log.Errorf("receive failed: %v", err)
|
||||
return fmt.Errorf("receive: %w", err)
|
||||
return c.handleRetryableError(err, streamStart, backOff)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -108,37 +162,106 @@ func (c *GRPCClient) Receive(ctx context.Context, interval time.Duration, msgHan
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *GRPCClient) establishStreamAndReceive(ctx context.Context, msgHandler func(msg *proto.FlowEventAck) error) error {
|
||||
if c.clientConn.GetState() == connectivity.Shutdown {
|
||||
return errors.New("connection to flow receiver has been shut down")
|
||||
// handleRetryableError resets the backoff timer if the stream was healthy long
|
||||
// enough and recreates the underlying ClientConn so that gRPC's internal
|
||||
// subchannel backoff does not accumulate and compete with our own retry timer.
|
||||
// A zero streamStart means the stream was never established.
|
||||
func (c *GRPCClient) handleRetryableError(err error, streamStart time.Time, backOff backoff.BackOff) error {
|
||||
if isContextDone(err) {
|
||||
return backoff.Permanent(err)
|
||||
}
|
||||
|
||||
stream, err := c.realClient.Events(ctx, grpc.WaitForReady(true))
|
||||
if err != nil {
|
||||
return fmt.Errorf("create event stream: %w", err)
|
||||
var permErr *backoff.PermanentError
|
||||
if errors.As(err, &permErr) {
|
||||
return err
|
||||
}
|
||||
|
||||
err = stream.Send(&proto.FlowEvent{IsInitiator: true})
|
||||
// Reset the backoff so the next retry starts with a short delay instead of
|
||||
// continuing the already-elapsed timer. Only do this if the stream was healthy
|
||||
// long enough; short-lived connect/drop cycles must not defeat MaxElapsedTime.
|
||||
if !streamStart.IsZero() && time.Since(streamStart) >= minHealthyDuration {
|
||||
backOff.Reset()
|
||||
}
|
||||
|
||||
if recreateErr := c.recreateConnection(); recreateErr != nil {
|
||||
log.Errorf("recreate connection: %v", recreateErr)
|
||||
return recreateErr
|
||||
}
|
||||
|
||||
log.Infof("connection recreated, retrying stream")
|
||||
return fmt.Errorf("retrying after error: %w", err)
|
||||
}
|
||||
|
||||
func (c *GRPCClient) recreateConnection() error {
|
||||
c.mu.Lock()
|
||||
if c.closed {
|
||||
c.mu.Unlock()
|
||||
return backoff.Permanent(ErrClientClosed)
|
||||
}
|
||||
|
||||
conn, err := grpc.NewClient(c.target, c.opts...)
|
||||
if err != nil {
|
||||
log.Infof("failed to send initiator message to flow receiver but will attempt to continue. Error: %s", err)
|
||||
c.mu.Unlock()
|
||||
return fmt.Errorf("create new connection: %w", err)
|
||||
}
|
||||
|
||||
old := c.clientConn
|
||||
c.clientConn = conn
|
||||
c.realClient = proto.NewFlowServiceClient(conn)
|
||||
c.stream = nil
|
||||
c.mu.Unlock()
|
||||
|
||||
_ = old.Close()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *GRPCClient) establishStream(ctx context.Context) (proto.FlowService_EventsClient, error) {
|
||||
c.mu.Lock()
|
||||
if c.closed {
|
||||
c.mu.Unlock()
|
||||
return nil, backoff.Permanent(ErrClientClosed)
|
||||
}
|
||||
cl := c.realClient
|
||||
c.mu.Unlock()
|
||||
|
||||
// open stream outside the lock — blocking operation
|
||||
stream, err := cl.Events(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create event stream: %w", err)
|
||||
}
|
||||
streamReady := false
|
||||
defer func() {
|
||||
if !streamReady {
|
||||
_ = stream.CloseSend()
|
||||
}
|
||||
}()
|
||||
|
||||
if err = stream.Send(&proto.FlowEvent{IsInitiator: true}); err != nil {
|
||||
return nil, fmt.Errorf("send initiator: %w", err)
|
||||
}
|
||||
|
||||
if err = checkHeader(stream); err != nil {
|
||||
return fmt.Errorf("check header: %w", err)
|
||||
return nil, fmt.Errorf("check header: %w", err)
|
||||
}
|
||||
|
||||
c.streamMu.Lock()
|
||||
c.mu.Lock()
|
||||
if c.closed {
|
||||
c.mu.Unlock()
|
||||
return nil, backoff.Permanent(ErrClientClosed)
|
||||
}
|
||||
c.stream = stream
|
||||
c.streamMu.Unlock()
|
||||
c.mu.Unlock()
|
||||
streamReady = true
|
||||
|
||||
return c.receive(stream, msgHandler)
|
||||
return stream, nil
|
||||
}
|
||||
|
||||
func (c *GRPCClient) receive(stream proto.FlowService_EventsClient, msgHandler func(msg *proto.FlowEventAck) error) error {
|
||||
for {
|
||||
msg, err := stream.Recv()
|
||||
if err != nil {
|
||||
return fmt.Errorf("receive from stream: %w", err)
|
||||
return err
|
||||
}
|
||||
|
||||
if msg.IsInitiator {
|
||||
@@ -169,7 +292,7 @@ func checkHeader(stream proto.FlowService_EventsClient) error {
|
||||
func defaultBackoff(ctx context.Context, interval time.Duration) backoff.BackOff {
|
||||
return backoff.WithContext(&backoff.ExponentialBackOff{
|
||||
InitialInterval: 800 * time.Millisecond,
|
||||
RandomizationFactor: 1,
|
||||
RandomizationFactor: 0.5,
|
||||
Multiplier: 1.7,
|
||||
MaxInterval: interval / 2,
|
||||
MaxElapsedTime: 3 * 30 * 24 * time.Hour, // 3 months
|
||||
@@ -178,18 +301,12 @@ func defaultBackoff(ctx context.Context, interval time.Duration) backoff.BackOff
|
||||
}, ctx)
|
||||
}
|
||||
|
||||
func (c *GRPCClient) Send(event *proto.FlowEvent) error {
|
||||
c.streamMu.Lock()
|
||||
stream := c.stream
|
||||
c.streamMu.Unlock()
|
||||
|
||||
if stream == nil {
|
||||
return errors.New("stream not initialized")
|
||||
func isContextDone(err error) bool {
|
||||
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
|
||||
return true
|
||||
}
|
||||
|
||||
if err := stream.Send(event); err != nil {
|
||||
return fmt.Errorf("send flow event: %w", err)
|
||||
if s, ok := status.FromError(err); ok {
|
||||
return s.Code() == codes.Canceled || s.Code() == codes.DeadlineExceeded
|
||||
}
|
||||
|
||||
return nil
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -2,8 +2,11 @@ package client_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"net"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -11,6 +14,8 @@ import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
|
||||
flow "github.com/netbirdio/netbird/flow/client"
|
||||
"github.com/netbirdio/netbird/flow/proto"
|
||||
@@ -18,21 +23,89 @@ import (
|
||||
|
||||
type testServer struct {
|
||||
proto.UnimplementedFlowServiceServer
|
||||
events chan *proto.FlowEvent
|
||||
acks chan *proto.FlowEventAck
|
||||
grpcSrv *grpc.Server
|
||||
addr string
|
||||
events chan *proto.FlowEvent
|
||||
acks chan *proto.FlowEventAck
|
||||
grpcSrv *grpc.Server
|
||||
addr string
|
||||
listener *connTrackListener
|
||||
closeStream chan struct{} // signal server to close the stream
|
||||
handlerDone chan struct{} // signaled each time Events() exits
|
||||
handlerStarted chan struct{} // signaled each time Events() begins
|
||||
}
|
||||
|
||||
// connTrackListener wraps a net.Listener to track accepted connections
|
||||
// so tests can forcefully close them to simulate PROTOCOL_ERROR/RST_STREAM.
|
||||
type connTrackListener struct {
|
||||
net.Listener
|
||||
mu sync.Mutex
|
||||
conns []net.Conn
|
||||
}
|
||||
|
||||
func (l *connTrackListener) Accept() (net.Conn, error) {
|
||||
c, err := l.Listener.Accept()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
l.mu.Lock()
|
||||
l.conns = append(l.conns, c)
|
||||
l.mu.Unlock()
|
||||
return c, nil
|
||||
}
|
||||
|
||||
// sendRSTStream writes a raw HTTP/2 RST_STREAM frame with PROTOCOL_ERROR
|
||||
// (error code 0x1) on every tracked connection. This produces the exact error:
|
||||
//
|
||||
// rpc error: code = Internal desc = stream terminated by RST_STREAM with error code: PROTOCOL_ERROR
|
||||
//
|
||||
// HTTP/2 RST_STREAM frame format (9-byte header + 4-byte payload):
|
||||
//
|
||||
// Length (3 bytes): 0x000004
|
||||
// Type (1 byte): 0x03 (RST_STREAM)
|
||||
// Flags (1 byte): 0x00
|
||||
// Stream ID (4 bytes): target stream (must have bit 31 clear)
|
||||
// Error Code (4 bytes): 0x00000001 (PROTOCOL_ERROR)
|
||||
func (l *connTrackListener) connCount() int {
|
||||
l.mu.Lock()
|
||||
defer l.mu.Unlock()
|
||||
return len(l.conns)
|
||||
}
|
||||
|
||||
func (l *connTrackListener) sendRSTStream(streamID uint32) {
|
||||
l.mu.Lock()
|
||||
defer l.mu.Unlock()
|
||||
|
||||
frame := make([]byte, 13) // 9-byte header + 4-byte payload
|
||||
// Length = 4 (3 bytes, big-endian)
|
||||
frame[0], frame[1], frame[2] = 0, 0, 4
|
||||
// Type = RST_STREAM (0x03)
|
||||
frame[3] = 0x03
|
||||
// Flags = 0
|
||||
frame[4] = 0x00
|
||||
// Stream ID (4 bytes, big-endian, bit 31 reserved = 0)
|
||||
binary.BigEndian.PutUint32(frame[5:9], streamID)
|
||||
// Error Code = PROTOCOL_ERROR (0x1)
|
||||
binary.BigEndian.PutUint32(frame[9:13], 0x1)
|
||||
|
||||
for _, c := range l.conns {
|
||||
_, _ = c.Write(frame)
|
||||
}
|
||||
}
|
||||
|
||||
func newTestServer(t *testing.T) *testServer {
|
||||
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
rawListener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
|
||||
listener := &connTrackListener{Listener: rawListener}
|
||||
|
||||
s := &testServer{
|
||||
events: make(chan *proto.FlowEvent, 100),
|
||||
acks: make(chan *proto.FlowEventAck, 100),
|
||||
grpcSrv: grpc.NewServer(),
|
||||
addr: listener.Addr().String(),
|
||||
events: make(chan *proto.FlowEvent, 100),
|
||||
acks: make(chan *proto.FlowEventAck, 100),
|
||||
grpcSrv: grpc.NewServer(),
|
||||
addr: rawListener.Addr().String(),
|
||||
listener: listener,
|
||||
closeStream: make(chan struct{}, 1),
|
||||
handlerDone: make(chan struct{}, 10),
|
||||
handlerStarted: make(chan struct{}, 10),
|
||||
}
|
||||
|
||||
proto.RegisterFlowServiceServer(s.grpcSrv, s)
|
||||
@@ -51,11 +124,23 @@ func newTestServer(t *testing.T) *testServer {
|
||||
}
|
||||
|
||||
func (s *testServer) Events(stream proto.FlowService_EventsServer) error {
|
||||
defer func() {
|
||||
select {
|
||||
case s.handlerDone <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
}()
|
||||
|
||||
err := stream.Send(&proto.FlowEventAck{IsInitiator: true})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
select {
|
||||
case s.handlerStarted <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(stream.Context())
|
||||
defer cancel()
|
||||
|
||||
@@ -91,6 +176,8 @@ func (s *testServer) Events(stream proto.FlowService_EventsServer) error {
|
||||
if err := stream.Send(ack); err != nil {
|
||||
return err
|
||||
}
|
||||
case <-s.closeStream:
|
||||
return status.Errorf(codes.Internal, "server closing stream")
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
}
|
||||
@@ -110,16 +197,13 @@ func TestReceive(t *testing.T) {
|
||||
assert.NoError(t, err, "failed to close flow")
|
||||
})
|
||||
|
||||
receivedAcks := make(map[string]bool)
|
||||
var ackCount atomic.Int32
|
||||
receiveDone := make(chan struct{})
|
||||
|
||||
go func() {
|
||||
err := client.Receive(ctx, 1*time.Second, func(msg *proto.FlowEventAck) error {
|
||||
if !msg.IsInitiator && len(msg.EventId) > 0 {
|
||||
id := string(msg.EventId)
|
||||
receivedAcks[id] = true
|
||||
|
||||
if len(receivedAcks) >= 3 {
|
||||
if ackCount.Add(1) >= 3 {
|
||||
close(receiveDone)
|
||||
}
|
||||
}
|
||||
@@ -130,7 +214,11 @@ func TestReceive(t *testing.T) {
|
||||
}
|
||||
}()
|
||||
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
select {
|
||||
case <-server.handlerStarted:
|
||||
case <-time.After(3 * time.Second):
|
||||
t.Fatal("timeout waiting for stream to be established")
|
||||
}
|
||||
|
||||
for i := 0; i < 3; i++ {
|
||||
eventID := uuid.New().String()
|
||||
@@ -153,7 +241,7 @@ func TestReceive(t *testing.T) {
|
||||
t.Fatal("timeout waiting for acks to be processed")
|
||||
}
|
||||
|
||||
assert.Equal(t, 3, len(receivedAcks))
|
||||
assert.Equal(t, int32(3), ackCount.Load())
|
||||
}
|
||||
|
||||
func TestReceive_ContextCancellation(t *testing.T) {
|
||||
@@ -254,3 +342,195 @@ func TestSend(t *testing.T) {
|
||||
t.Fatal("timeout waiting for ack to be received by flow")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewClient_PermanentClose(t *testing.T) {
|
||||
server := newTestServer(t)
|
||||
|
||||
client, err := flow.NewClient("http://"+server.addr, "test-payload", "test-signature", 1*time.Second)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = client.Close()
|
||||
require.NoError(t, err)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
t.Cleanup(cancel)
|
||||
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
done <- client.Receive(ctx, 1*time.Second, func(msg *proto.FlowEventAck) error {
|
||||
return nil
|
||||
})
|
||||
}()
|
||||
|
||||
select {
|
||||
case err := <-done:
|
||||
require.ErrorIs(t, err, flow.ErrClientClosed)
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("Receive did not return after Close — stuck in retry loop")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewClient_CloseVerify(t *testing.T) {
|
||||
server := newTestServer(t)
|
||||
|
||||
client, err := flow.NewClient("http://"+server.addr, "test-payload", "test-signature", 1*time.Second)
|
||||
require.NoError(t, err)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
t.Cleanup(cancel)
|
||||
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
done <- client.Receive(ctx, 1*time.Second, func(msg *proto.FlowEventAck) error {
|
||||
return nil
|
||||
})
|
||||
}()
|
||||
|
||||
closeDone := make(chan struct{}, 1)
|
||||
go func() {
|
||||
_ = client.Close()
|
||||
closeDone <- struct{}{}
|
||||
}()
|
||||
|
||||
select {
|
||||
case err := <-done:
|
||||
require.Error(t, err)
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("Receive did not return after Close — stuck in retry loop")
|
||||
}
|
||||
|
||||
select {
|
||||
case <-closeDone:
|
||||
return
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("Close did not return — blocked in retry loop")
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func TestClose_WhileReceiving(t *testing.T) {
|
||||
server := newTestServer(t)
|
||||
client, err := flow.NewClient("http://"+server.addr, "test-payload", "test-signature", 1*time.Second)
|
||||
require.NoError(t, err)
|
||||
|
||||
ctx := context.Background() // no timeout — intentional
|
||||
receiveDone := make(chan struct{})
|
||||
go func() {
|
||||
_ = client.Receive(ctx, 1*time.Second, func(msg *proto.FlowEventAck) error {
|
||||
return nil
|
||||
})
|
||||
close(receiveDone)
|
||||
}()
|
||||
|
||||
// Wait for the server-side handler to confirm the stream is established.
|
||||
select {
|
||||
case <-server.handlerStarted:
|
||||
case <-time.After(3 * time.Second):
|
||||
t.Fatal("timeout waiting for stream to be established")
|
||||
}
|
||||
|
||||
closeDone := make(chan struct{})
|
||||
go func() {
|
||||
_ = client.Close()
|
||||
close(closeDone)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-closeDone:
|
||||
// Close returned — good
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("Close blocked forever — Receive stuck in retry loop")
|
||||
}
|
||||
|
||||
select {
|
||||
case <-receiveDone:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("Receive did not exit after Close")
|
||||
}
|
||||
}
|
||||
|
||||
func TestReceive_ProtocolErrorStreamReconnect(t *testing.T) {
|
||||
server := newTestServer(t)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
t.Cleanup(cancel)
|
||||
|
||||
client, err := flow.NewClient("http://"+server.addr, "test-payload", "test-signature", 1*time.Second)
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
err := client.Close()
|
||||
assert.NoError(t, err, "failed to close flow")
|
||||
})
|
||||
|
||||
// Track acks received before and after server-side stream close
|
||||
var ackCount atomic.Int32
|
||||
receivedFirst := make(chan struct{})
|
||||
receivedAfterReconnect := make(chan struct{})
|
||||
|
||||
go func() {
|
||||
err := client.Receive(ctx, 1*time.Second, func(msg *proto.FlowEventAck) error {
|
||||
if msg.IsInitiator || len(msg.EventId) == 0 {
|
||||
return nil
|
||||
}
|
||||
n := ackCount.Add(1)
|
||||
if n == 1 {
|
||||
close(receivedFirst)
|
||||
}
|
||||
if n == 2 {
|
||||
close(receivedAfterReconnect)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err != nil && !errors.Is(err, context.Canceled) {
|
||||
t.Logf("receive error: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
// Wait for stream to be established, then send first ack
|
||||
select {
|
||||
case <-server.handlerStarted:
|
||||
case <-time.After(3 * time.Second):
|
||||
t.Fatal("timeout waiting for stream to be established")
|
||||
}
|
||||
server.acks <- &proto.FlowEventAck{EventId: []byte("before-close")}
|
||||
|
||||
select {
|
||||
case <-receivedFirst:
|
||||
case <-time.After(3 * time.Second):
|
||||
t.Fatal("timeout waiting for first ack")
|
||||
}
|
||||
|
||||
// Snapshot connection count before injecting the fault.
|
||||
connsBefore := server.listener.connCount()
|
||||
|
||||
// Send a raw HTTP/2 RST_STREAM frame with PROTOCOL_ERROR on the TCP connection.
|
||||
// gRPC multiplexes streams on stream IDs 1, 3, 5, ... (odd, client-initiated).
|
||||
// Stream ID 1 is the client's first stream (our Events bidi stream).
|
||||
// This produces the exact error the client sees in production:
|
||||
// "stream terminated by RST_STREAM with error code: PROTOCOL_ERROR"
|
||||
server.listener.sendRSTStream(1)
|
||||
|
||||
// Wait for the old Events() handler to fully exit so it can no longer
|
||||
// drain s.acks and drop our injected ack on a broken stream.
|
||||
select {
|
||||
case <-server.handlerDone:
|
||||
case <-time.After(5 * time.Second):
|
||||
t.Fatal("old Events() handler did not exit after RST_STREAM")
|
||||
}
|
||||
|
||||
require.Eventually(t, func() bool {
|
||||
return server.listener.connCount() > connsBefore
|
||||
}, 5*time.Second, 50*time.Millisecond, "client did not open a new TCP connection after RST_STREAM")
|
||||
|
||||
server.acks <- &proto.FlowEventAck{EventId: []byte("after-close")}
|
||||
|
||||
select {
|
||||
case <-receivedAfterReconnect:
|
||||
// Client successfully reconnected and received ack after server-side stream close
|
||||
case <-time.After(5 * time.Second):
|
||||
t.Fatal("timeout waiting for ack after server-side stream close — client did not reconnect")
|
||||
}
|
||||
|
||||
assert.GreaterOrEqual(t, int(ackCount.Load()), 2, "should have received acks before and after stream close")
|
||||
assert.GreaterOrEqual(t, server.listener.connCount(), 2, "client should have created at least 2 TCP connections (original + reconnect)")
|
||||
}
|
||||
|
||||
2
go.mod
2
go.mod
@@ -208,12 +208,14 @@ require (
|
||||
github.com/jackc/puddle/v2 v2.2.1 // indirect
|
||||
github.com/jackpal/go-nat-pmp v1.0.2 // indirect
|
||||
github.com/jeandeaual/go-locale v0.0.0-20250612000132-0ef82f21eade // indirect
|
||||
github.com/jezek/xgb v1.3.0 // indirect
|
||||
github.com/jinzhu/inflection v1.0.0 // indirect
|
||||
github.com/jinzhu/now v1.1.5 // indirect
|
||||
github.com/jmespath/go-jmespath v0.4.0 // indirect
|
||||
github.com/jonboulle/clockwork v0.5.0 // indirect
|
||||
github.com/jsummers/gobmp v0.0.0-20230614200233-a9de23ed2e25 // indirect
|
||||
github.com/kelseyhightower/envconfig v1.4.0 // indirect
|
||||
github.com/kirides/go-d3d v1.0.1 // indirect
|
||||
github.com/klauspost/compress v1.18.0 // indirect
|
||||
github.com/klauspost/cpuid/v2 v2.2.7 // indirect
|
||||
github.com/koron/go-ssdp v0.0.4 // indirect
|
||||
|
||||
4
go.sum
4
go.sum
@@ -309,6 +309,8 @@ github.com/jcmturner/rpc/v2 v2.0.3 h1:7FXXj8Ti1IaVFpSAziCZWNzbNuZmnvw/i6CqLNdWfZ
|
||||
github.com/jcmturner/rpc/v2 v2.0.3/go.mod h1:VUJYCIDm3PVOEHw8sgt091/20OJjskO/YJki3ELg/Hc=
|
||||
github.com/jeandeaual/go-locale v0.0.0-20250612000132-0ef82f21eade h1:FmusiCI1wHw+XQbvL9M+1r/C3SPqKrmBaIOYwVfQoDE=
|
||||
github.com/jeandeaual/go-locale v0.0.0-20250612000132-0ef82f21eade/go.mod h1:ZDXo8KHryOWSIqnsb/CiDq7hQUYryCgdVnxbj8tDG7o=
|
||||
github.com/jezek/xgb v1.3.0 h1:Wa1pn4GVtcmNVAVB6/pnQVJ7xPFZVZ/W1Tc27msDhgI=
|
||||
github.com/jezek/xgb v1.3.0/go.mod h1:nrhwO0FX/enq75I7Y7G8iN1ubpSGZEiA3v9e9GyRFlk=
|
||||
github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E=
|
||||
github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc=
|
||||
github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ=
|
||||
@@ -325,6 +327,8 @@ github.com/jsummers/gobmp v0.0.0-20230614200233-a9de23ed2e25/go.mod h1:kLgvv7o6U
|
||||
github.com/juju/gnuflag v0.0.0-20171113085948-2ce1bb71843d/go.mod h1:2PavIy+JPciBPrBUjwbNvtwB6RQlve+hkpll6QSNmOE=
|
||||
github.com/kelseyhightower/envconfig v1.4.0 h1:Im6hONhd3pLkfDFsbRgu68RDNkGF1r3dvMUtDTo2cv8=
|
||||
github.com/kelseyhightower/envconfig v1.4.0/go.mod h1:cccZRl6mQpaq41TPp5QxidR+Sa3axMbJDNb//FQX6Gg=
|
||||
github.com/kirides/go-d3d v1.0.1 h1:ZDANfvo34vskBMET1uwUUMNw8545Kbe8qYSiRwlNIuA=
|
||||
github.com/kirides/go-d3d v1.0.1/go.mod h1:99AjD+5mRTFEnkpRWkwq8UYMQDljGIIvLn2NyRdVImY=
|
||||
github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8=
|
||||
github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck=
|
||||
github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo=
|
||||
|
||||
@@ -94,10 +94,7 @@ func toPeerConfig(peer *nbpeer.Peer, network *types.Network, dnsName string, set
|
||||
|
||||
sshConfig := &proto.SSHConfig{
|
||||
SshEnabled: peer.SSHEnabled || enableSSH,
|
||||
}
|
||||
|
||||
if sshConfig.SshEnabled {
|
||||
sshConfig.JwtConfig = buildJWTConfig(httpConfig, deviceFlowConfig)
|
||||
JwtConfig: buildJWTConfig(httpConfig, deviceFlowConfig),
|
||||
}
|
||||
|
||||
return &proto.PeerConfig{
|
||||
@@ -156,19 +153,21 @@ func ToSyncResponse(ctx context.Context, config *nbconfig.Config, httpConfig *nb
|
||||
response.NetworkMap.ForwardingRules = forwardingRules
|
||||
}
|
||||
|
||||
if networkMap.TransparentProxyConfig != nil {
|
||||
response.NetworkMap.TransparentProxyConfig = networkMap.TransparentProxyConfig.ToProto()
|
||||
userIDClaim := auth.DefaultUserIDClaim
|
||||
if httpConfig != nil && httpConfig.AuthUserIDClaim != "" {
|
||||
userIDClaim = httpConfig.AuthUserIDClaim
|
||||
}
|
||||
|
||||
if networkMap.AuthorizedUsers != nil {
|
||||
hashedUsers, machineUsers := buildAuthorizedUsersProto(ctx, networkMap.AuthorizedUsers)
|
||||
userIDClaim := auth.DefaultUserIDClaim
|
||||
if httpConfig != nil && httpConfig.AuthUserIDClaim != "" {
|
||||
userIDClaim = httpConfig.AuthUserIDClaim
|
||||
}
|
||||
response.NetworkMap.SshAuth = &proto.SSHAuth{AuthorizedUsers: hashedUsers, MachineUsers: machineUsers, UserIDClaim: userIDClaim}
|
||||
}
|
||||
|
||||
if networkMap.VNCAuthorizedUsers != nil {
|
||||
hashedUsers, machineUsers := buildAuthorizedUsersProto(ctx, networkMap.VNCAuthorizedUsers)
|
||||
response.NetworkMap.VncAuth = &proto.VNCAuth{AuthorizedUsers: hashedUsers, MachineUsers: machineUsers, UserIDClaim: userIDClaim}
|
||||
}
|
||||
|
||||
return response
|
||||
}
|
||||
|
||||
|
||||
@@ -665,6 +665,7 @@ func extractPeerMeta(ctx context.Context, meta *proto.PeerSystemMeta) nbpeer.Pee
|
||||
RosenpassEnabled: meta.GetFlags().GetRosenpassEnabled(),
|
||||
RosenpassPermissive: meta.GetFlags().GetRosenpassPermissive(),
|
||||
ServerSSHAllowed: meta.GetFlags().GetServerSSHAllowed(),
|
||||
ServerVNCAllowed: meta.GetFlags().GetServerVNCAllowed(),
|
||||
DisableClientRoutes: meta.GetFlags().GetDisableClientRoutes(),
|
||||
DisableServerRoutes: meta.GetFlags().GetDisableServerRoutes(),
|
||||
DisableDNS: meta.GetFlags().GetDisableDNS(),
|
||||
@@ -672,6 +673,7 @@ func extractPeerMeta(ctx context.Context, meta *proto.PeerSystemMeta) nbpeer.Pee
|
||||
BlockLANAccess: meta.GetFlags().GetBlockLANAccess(),
|
||||
BlockInbound: meta.GetFlags().GetBlockInbound(),
|
||||
LazyConnectionEnabled: meta.GetFlags().GetLazyConnectionEnabled(),
|
||||
DisableVNCAuth: meta.GetFlags().GetDisableVNCAuth(),
|
||||
},
|
||||
Files: files,
|
||||
}
|
||||
|
||||
@@ -113,10 +113,6 @@ type Manager interface {
|
||||
SavePostureChecks(ctx context.Context, accountID, userID string, postureChecks *posture.Checks, create bool) (*posture.Checks, error)
|
||||
DeletePostureChecks(ctx context.Context, accountID, postureChecksID, userID string) error
|
||||
ListPostureChecks(ctx context.Context, accountID, userID string) ([]*posture.Checks, error)
|
||||
GetInspectionPolicy(ctx context.Context, accountID, policyID, userID string) (*types.InspectionPolicy, error)
|
||||
SaveInspectionPolicy(ctx context.Context, accountID, userID string, policy *types.InspectionPolicy, create bool) (*types.InspectionPolicy, error)
|
||||
DeleteInspectionPolicy(ctx context.Context, accountID, policyID, userID string) error
|
||||
ListInspectionPolicies(ctx context.Context, accountID, userID string) ([]*types.InspectionPolicy, error)
|
||||
GetIdpManager() idp.Manager
|
||||
UpdateIntegratedValidator(ctx context.Context, accountID, userID, validator string, groups []string) error
|
||||
GroupValidation(ctx context.Context, accountId string, groups []string) (bool, error)
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user