[client,management] Rewrite the SSH feature (#4015)

This commit is contained in:
Viktor Liu
2025-11-17 17:10:41 +01:00
committed by GitHub
parent 0d79301141
commit d71a82769c
170 changed files with 18744 additions and 2853 deletions

View File

@@ -19,35 +19,37 @@ jobs:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: actions/checkout@v4
- name: Check for problematic license dependencies
run: |
echo "Checking for dependencies on management/, signal/, and relay/ packages..."
- name: Check for problematic license dependencies
run: |
echo "Checking for dependencies on management/, signal/, and relay/ packages..."
echo ""
# Find all directories except the problematic ones and system dirs
FOUND_ISSUES=0
find . -maxdepth 1 -type d -not -name "." -not -name "management" -not -name "signal" -not -name "relay" -not -name ".git*" | sort | while read dir; do
echo "=== Checking $dir ==="
# Search for problematic imports, excluding test files
RESULTS=$(grep -r "github.com/netbirdio/netbird/\(management\|signal\|relay\)" "$dir" --include="*.go" | grep -v "_test.go" | grep -v "test_" | grep -v "/test/" || true)
if [ ! -z "$RESULTS" ]; then
echo "❌ Found problematic dependencies:"
echo "$RESULTS"
FOUND_ISSUES=1
# Find all directories except the problematic ones and system dirs
FOUND_ISSUES=0
while IFS= read -r dir; do
echo "=== Checking $dir ==="
# Search for problematic imports, excluding test files
RESULTS=$(grep -r "github.com/netbirdio/netbird/\(management\|signal\|relay\)" "$dir" --include="*.go" 2>/dev/null | grep -v "_test.go" | grep -v "test_" | grep -v "/test/" || true)
if [ -n "$RESULTS" ]; then
echo "❌ Found problematic dependencies:"
echo "$RESULTS"
FOUND_ISSUES=1
else
echo "✓ No problematic dependencies found"
fi
done < <(find . -maxdepth 1 -type d -not -name "." -not -name "management" -not -name "signal" -not -name "relay" -not -name ".git*" | sort)
echo ""
if [ $FOUND_ISSUES -eq 1 ]; then
echo "❌ Found dependencies on management/, signal/, or relay/ packages"
echo "These packages are licensed under AGPLv3 and must not be imported by BSD-licensed code"
exit 1
else
echo "✓ No problematic dependencies found"
echo ""
echo "✅ All internal license dependencies are clean"
fi
done
if [ $FOUND_ISSUES -eq 1 ]; then
echo ""
echo "❌ Found dependencies on management/, signal/, or relay/ packages"
echo "These packages are licensed under AGPLv3 and must not be imported by BSD-licensed code"
exit 1
else
echo ""
echo "✅ All internal license dependencies are clean"
fi
check-external-licenses:
name: Check External GPL/AGPL Licenses

View File

@@ -17,9 +17,9 @@ import (
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/profilemanager"
"github.com/netbirdio/netbird/client/internal/stdnet"
"github.com/netbirdio/netbird/client/net"
"github.com/netbirdio/netbird/client/system"
"github.com/netbirdio/netbird/formatter"
"github.com/netbirdio/netbird/client/net"
)
// ConnectionListener export internal Listener for mobile

View File

@@ -201,6 +201,94 @@ func (p *Preferences) SetServerSSHAllowed(allowed bool) {
p.configInput.ServerSSHAllowed = &allowed
}
// GetEnableSSHRoot reads SSH root login setting from config file
func (p *Preferences) GetEnableSSHRoot() (bool, error) {
if p.configInput.EnableSSHRoot != nil {
return *p.configInput.EnableSSHRoot, nil
}
cfg, err := profilemanager.ReadConfig(p.configInput.ConfigPath)
if err != nil {
return false, err
}
if cfg.EnableSSHRoot == nil {
// Default to false for security on Android
return false, nil
}
return *cfg.EnableSSHRoot, err
}
// SetEnableSSHRoot stores the given value and waits for commit
func (p *Preferences) SetEnableSSHRoot(enabled bool) {
p.configInput.EnableSSHRoot = &enabled
}
// GetEnableSSHSFTP reads SSH SFTP setting from config file
func (p *Preferences) GetEnableSSHSFTP() (bool, error) {
if p.configInput.EnableSSHSFTP != nil {
return *p.configInput.EnableSSHSFTP, nil
}
cfg, err := profilemanager.ReadConfig(p.configInput.ConfigPath)
if err != nil {
return false, err
}
if cfg.EnableSSHSFTP == nil {
// Default to false for security on Android
return false, nil
}
return *cfg.EnableSSHSFTP, err
}
// SetEnableSSHSFTP stores the given value and waits for commit
func (p *Preferences) SetEnableSSHSFTP(enabled bool) {
p.configInput.EnableSSHSFTP = &enabled
}
// GetEnableSSHLocalPortForwarding reads SSH local port forwarding setting from config file
func (p *Preferences) GetEnableSSHLocalPortForwarding() (bool, error) {
if p.configInput.EnableSSHLocalPortForwarding != nil {
return *p.configInput.EnableSSHLocalPortForwarding, nil
}
cfg, err := profilemanager.ReadConfig(p.configInput.ConfigPath)
if err != nil {
return false, err
}
if cfg.EnableSSHLocalPortForwarding == nil {
// Default to false for security on Android
return false, nil
}
return *cfg.EnableSSHLocalPortForwarding, err
}
// SetEnableSSHLocalPortForwarding stores the given value and waits for commit
func (p *Preferences) SetEnableSSHLocalPortForwarding(enabled bool) {
p.configInput.EnableSSHLocalPortForwarding = &enabled
}
// GetEnableSSHRemotePortForwarding reads SSH remote port forwarding setting from config file
func (p *Preferences) GetEnableSSHRemotePortForwarding() (bool, error) {
if p.configInput.EnableSSHRemotePortForwarding != nil {
return *p.configInput.EnableSSHRemotePortForwarding, nil
}
cfg, err := profilemanager.ReadConfig(p.configInput.ConfigPath)
if err != nil {
return false, err
}
if cfg.EnableSSHRemotePortForwarding == nil {
// Default to false for security on Android
return false, nil
}
return *cfg.EnableSSHRemotePortForwarding, err
}
// SetEnableSSHRemotePortForwarding stores the given value and waits for commit
func (p *Preferences) SetEnableSSHRemotePortForwarding(enabled bool) {
p.configInput.EnableSSHRemotePortForwarding = &enabled
}
// GetBlockInbound reads block inbound setting from config file
func (p *Preferences) GetBlockInbound() (bool, error) {
if p.configInput.BlockInbound != nil {

View File

@@ -35,7 +35,6 @@ const (
wireguardPortFlag = "wireguard-port"
networkMonitorFlag = "network-monitor"
disableAutoConnectFlag = "disable-auto-connect"
serverSSHAllowedFlag = "allow-server-ssh"
extraIFaceBlackListFlag = "extra-iface-blacklist"
dnsRouteIntervalFlag = "dns-router-interval"
enableLazyConnectionFlag = "enable-lazy-connection"
@@ -64,7 +63,6 @@ var (
customDNSAddress string
rosenpassEnabled bool
rosenpassPermissive bool
serverSSHAllowed bool
interfaceName string
wireguardPort uint16
networkMonitor bool
@@ -176,7 +174,6 @@ func init() {
)
upCmd.PersistentFlags().BoolVar(&rosenpassEnabled, enableRosenpassFlag, false, "[Experimental] Enable Rosenpass feature. If enabled, the connection will be post-quantum secured via Rosenpass.")
upCmd.PersistentFlags().BoolVar(&rosenpassPermissive, rosenpassPermissiveFlag, false, "[Experimental] Enable Rosenpass in permissive mode to allow this peer to accept WireGuard connections without requiring Rosenpass functionality from peers that do not have Rosenpass enabled.")
upCmd.PersistentFlags().BoolVar(&serverSSHAllowed, serverSSHAllowedFlag, false, "Allow SSH server on peer. If enabled, the SSH server will be permitted")
upCmd.PersistentFlags().BoolVar(&autoConnectDisabled, disableAutoConnectFlag, false, "Disables auto-connect feature. If enabled, then the client won't connect automatically when the service starts.")
upCmd.PersistentFlags().BoolVar(&lazyConnEnabled, enableLazyConnectionFlag, false, "[Experimental] Enable the lazy connection feature. If enabled, the client will establish connections on-demand. Note: this setting may be overridden by management configuration.")

View File

@@ -3,125 +3,809 @@ package cmd
import (
"context"
"errors"
"flag"
"fmt"
"net"
"os"
"os/signal"
"os/user"
"slices"
"strconv"
"strings"
"syscall"
log "github.com/sirupsen/logrus"
"github.com/spf13/cobra"
"golang.org/x/crypto/ssh"
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/profilemanager"
nbssh "github.com/netbirdio/netbird/client/ssh"
sshclient "github.com/netbirdio/netbird/client/ssh/client"
"github.com/netbirdio/netbird/client/ssh/detection"
sshproxy "github.com/netbirdio/netbird/client/ssh/proxy"
sshserver "github.com/netbirdio/netbird/client/ssh/server"
"github.com/netbirdio/netbird/util"
)
var (
port int
userName = "root"
host string
const (
sshUsernameDesc = "SSH username"
hostArgumentRequired = "host argument required"
serverSSHAllowedFlag = "allow-server-ssh"
enableSSHRootFlag = "enable-ssh-root"
enableSSHSFTPFlag = "enable-ssh-sftp"
enableSSHLocalPortForwardFlag = "enable-ssh-local-port-forwarding"
enableSSHRemotePortForwardFlag = "enable-ssh-remote-port-forwarding"
disableSSHAuthFlag = "disable-ssh-auth"
sshJWTCacheTTLFlag = "ssh-jwt-cache-ttl"
)
var sshCmd = &cobra.Command{
Use: "ssh [user@]host",
Args: func(cmd *cobra.Command, args []string) error {
if len(args) < 1 {
return errors.New("requires a host argument")
}
var (
port int
username string
host string
command string
localForwards []string
remoteForwards []string
strictHostKeyChecking bool
knownHostsFile string
identityFile string
skipCachedToken bool
requestPTY bool
)
split := strings.Split(args[0], "@")
if len(split) == 2 {
userName = split[0]
host = split[1]
} else {
host = args[0]
}
var (
serverSSHAllowed bool
enableSSHRoot bool
enableSSHSFTP bool
enableSSHLocalPortForward bool
enableSSHRemotePortForward bool
disableSSHAuth bool
sshJWTCacheTTL int
)
return nil
},
Short: "Connect to a remote SSH server",
RunE: func(cmd *cobra.Command, args []string) error {
SetFlagsFromEnvVars(rootCmd)
SetFlagsFromEnvVars(cmd)
func init() {
upCmd.PersistentFlags().BoolVar(&serverSSHAllowed, serverSSHAllowedFlag, false, "Allow SSH server on peer")
upCmd.PersistentFlags().BoolVar(&enableSSHRoot, enableSSHRootFlag, false, "Enable root login for SSH server")
upCmd.PersistentFlags().BoolVar(&enableSSHSFTP, enableSSHSFTPFlag, false, "Enable SFTP subsystem for SSH server")
upCmd.PersistentFlags().BoolVar(&enableSSHLocalPortForward, enableSSHLocalPortForwardFlag, false, "Enable local port forwarding for SSH server")
upCmd.PersistentFlags().BoolVar(&enableSSHRemotePortForward, enableSSHRemotePortForwardFlag, false, "Enable remote port forwarding for SSH server")
upCmd.PersistentFlags().BoolVar(&disableSSHAuth, disableSSHAuthFlag, false, "Disable SSH authentication")
upCmd.PersistentFlags().IntVar(&sshJWTCacheTTL, sshJWTCacheTTLFlag, 0, "SSH JWT token cache TTL in seconds (0=disabled)")
cmd.SetOut(cmd.OutOrStdout())
sshCmd.PersistentFlags().IntVarP(&port, "port", "p", sshserver.DefaultSSHPort, "Remote SSH port")
sshCmd.PersistentFlags().StringVarP(&username, "user", "u", "", sshUsernameDesc)
sshCmd.PersistentFlags().StringVar(&username, "login", "", sshUsernameDesc+" (alias for --user)")
sshCmd.PersistentFlags().BoolVarP(&requestPTY, "tty", "t", false, "Force pseudo-terminal allocation")
sshCmd.PersistentFlags().BoolVar(&strictHostKeyChecking, "strict-host-key-checking", true, "Enable strict host key checking (default: true)")
sshCmd.PersistentFlags().StringVarP(&knownHostsFile, "known-hosts", "o", "", "Path to known_hosts file (default: ~/.ssh/known_hosts)")
sshCmd.PersistentFlags().StringVarP(&identityFile, "identity", "i", "", "Path to SSH private key file (deprecated)")
_ = sshCmd.PersistentFlags().MarkDeprecated("identity", "this flag is no longer used")
sshCmd.PersistentFlags().BoolVar(&skipCachedToken, "no-cache", false, "Skip cached JWT token and force fresh authentication")
err := util.InitLog(logLevel, util.LogConsole)
if err != nil {
return fmt.Errorf("failed initializing log %v", err)
}
sshCmd.PersistentFlags().StringArrayP("L", "L", []string{}, "Local port forwarding [bind_address:]port:host:hostport")
sshCmd.PersistentFlags().StringArrayP("R", "R", []string{}, "Remote port forwarding [bind_address:]port:host:hostport")
if !util.IsAdmin() {
cmd.Printf("error: you must have Administrator privileges to run this command\n")
return nil
}
ctx := internal.CtxInitState(cmd.Context())
sm := profilemanager.NewServiceManager(configPath)
activeProf, err := sm.GetActiveProfileState()
if err != nil {
return fmt.Errorf("get active profile: %v", err)
}
profPath, err := activeProf.FilePath()
if err != nil {
return fmt.Errorf("get active profile path: %v", err)
}
config, err := profilemanager.ReadConfig(profPath)
if err != nil {
return fmt.Errorf("read profile config: %v", err)
}
sig := make(chan os.Signal, 1)
signal.Notify(sig, syscall.SIGTERM, syscall.SIGINT)
sshctx, cancel := context.WithCancel(ctx)
go func() {
// blocking
if err := runSSH(sshctx, host, []byte(config.SSHKey), cmd); err != nil {
cmd.Printf("Error: %v\n", err)
os.Exit(1)
}
cancel()
}()
select {
case <-sig:
cancel()
case <-sshctx.Done():
}
return nil
},
sshCmd.AddCommand(sshSftpCmd)
sshCmd.AddCommand(sshProxyCmd)
sshCmd.AddCommand(sshDetectCmd)
}
func runSSH(ctx context.Context, addr string, pemKey []byte, cmd *cobra.Command) error {
c, err := nbssh.DialWithKey(fmt.Sprintf("%s:%d", addr, port), userName, pemKey)
if err != nil {
cmd.Printf("Error: %v\n", err)
cmd.Printf("Couldn't connect. Please check the connection status or if the ssh server is enabled on the other peer" +
"\nYou can verify the connection by running:\n\n" +
" netbird status\n\n")
return err
}
go func() {
<-ctx.Done()
err = c.Close()
if err != nil {
return
var sshCmd = &cobra.Command{
Use: "ssh [flags] [user@]host [command]",
Short: "Connect to a NetBird peer via SSH",
Long: `Connect to a NetBird peer using SSH with support for port forwarding.
Port Forwarding:
-L [bind_address:]port:host:hostport Local port forwarding
-L [bind_address:]port:/path/to/socket Local port forwarding to Unix socket
-R [bind_address:]port:host:hostport Remote port forwarding
-R [bind_address:]port:/path/to/socket Remote port forwarding to Unix socket
SSH Options:
-p, --port int Remote SSH port (default 22)
-u, --user string SSH username
--login string SSH username (alias for --user)
-t, --tty Force pseudo-terminal allocation
--strict-host-key-checking Enable strict host key checking (default: true)
-o, --known-hosts string Path to known_hosts file
Examples:
netbird ssh peer-hostname
netbird ssh root@peer-hostname
netbird ssh --login root peer-hostname
netbird ssh peer-hostname ls -la
netbird ssh peer-hostname whoami
netbird ssh -t peer-hostname tmux # Force PTY for tmux/screen
netbird ssh -t peer-hostname sudo -i # Force PTY for interactive sudo
netbird ssh -L 8080:localhost:80 peer-hostname # Local port forwarding
netbird ssh -R 9090:localhost:3000 peer-hostname # Remote port forwarding
netbird ssh -L "*:8080:localhost:80" peer-hostname # Bind to all interfaces
netbird ssh -L 8080:/tmp/socket peer-hostname # Unix socket forwarding`,
DisableFlagParsing: true,
Args: validateSSHArgsWithoutFlagParsing,
RunE: sshFn,
Aliases: []string{"ssh"},
}
func sshFn(cmd *cobra.Command, args []string) error {
for _, arg := range args {
if arg == "-h" || arg == "--help" {
return cmd.Help()
}
}
SetFlagsFromEnvVars(rootCmd)
SetFlagsFromEnvVars(cmd)
cmd.SetOut(cmd.OutOrStdout())
logOutput := "console"
if firstLogFile := util.FindFirstLogPath(logFiles); firstLogFile != "" && firstLogFile != defaultLogFile {
logOutput = firstLogFile
}
if err := util.InitLog(logLevel, logOutput); err != nil {
return fmt.Errorf("init log: %w", err)
}
ctx := internal.CtxInitState(cmd.Context())
sig := make(chan os.Signal, 1)
signal.Notify(sig, syscall.SIGTERM, syscall.SIGINT)
sshctx, cancel := context.WithCancel(ctx)
errCh := make(chan error, 1)
go func() {
if err := runSSH(sshctx, host, cmd); err != nil {
errCh <- err
}
cancel()
}()
err = c.OpenTerminal()
if err != nil {
select {
case <-sig:
cancel()
<-sshctx.Done()
return nil
case err := <-errCh:
return err
case <-sshctx.Done():
}
return nil
}
func init() {
sshCmd.PersistentFlags().IntVarP(&port, "port", "p", nbssh.DefaultSSHPort, "Sets remote SSH port. Defaults to "+fmt.Sprint(nbssh.DefaultSSHPort))
// getEnvOrDefault checks for environment variables with WT_ and NB_ prefixes
func getEnvOrDefault(flagName, defaultValue string) string {
if envValue := os.Getenv("WT_" + flagName); envValue != "" {
return envValue
}
if envValue := os.Getenv("NB_" + flagName); envValue != "" {
return envValue
}
return defaultValue
}
// resetSSHGlobals sets SSH globals to their default values
func resetSSHGlobals() {
port = sshserver.DefaultSSHPort
username = ""
host = ""
command = ""
localForwards = nil
remoteForwards = nil
strictHostKeyChecking = true
knownHostsFile = ""
identityFile = ""
}
// parseCustomSSHFlags extracts -L, -R flags and returns filtered args
func parseCustomSSHFlags(args []string) ([]string, []string, []string) {
var localForwardFlags []string
var remoteForwardFlags []string
var filteredArgs []string
for i := 0; i < len(args); i++ {
arg := args[i]
switch {
case strings.HasPrefix(arg, "-L"):
localForwardFlags, i = parseForwardFlag(arg, args, i, localForwardFlags)
case strings.HasPrefix(arg, "-R"):
remoteForwardFlags, i = parseForwardFlag(arg, args, i, remoteForwardFlags)
default:
filteredArgs = append(filteredArgs, arg)
}
}
return filteredArgs, localForwardFlags, remoteForwardFlags
}
func parseForwardFlag(arg string, args []string, i int, flags []string) ([]string, int) {
if arg == "-L" || arg == "-R" {
if i+1 < len(args) {
flags = append(flags, args[i+1])
i++
}
} else if len(arg) > 2 {
flags = append(flags, arg[2:])
}
return flags, i
}
// extractGlobalFlags parses global flags that were passed before 'ssh' command
func extractGlobalFlags(args []string) {
sshPos := findSSHCommandPosition(args)
if sshPos == -1 {
return
}
globalArgs := args[:sshPos]
parseGlobalArgs(globalArgs)
}
// findSSHCommandPosition locates the 'ssh' command in the argument list
func findSSHCommandPosition(args []string) int {
for i, arg := range args {
if arg == "ssh" {
return i
}
}
return -1
}
const (
configFlag = "config"
logLevelFlag = "log-level"
logFileFlag = "log-file"
)
// parseGlobalArgs processes the global arguments and sets the corresponding variables
func parseGlobalArgs(globalArgs []string) {
flagHandlers := map[string]func(string){
configFlag: func(value string) { configPath = value },
logLevelFlag: func(value string) { logLevel = value },
logFileFlag: func(value string) {
if !slices.Contains(logFiles, value) {
logFiles = append(logFiles, value)
}
},
}
shortFlags := map[string]string{
"c": configFlag,
"l": logLevelFlag,
}
for i := 0; i < len(globalArgs); i++ {
arg := globalArgs[i]
if handled, nextIndex := parseFlag(arg, globalArgs, i, flagHandlers, shortFlags); handled {
i = nextIndex
}
}
}
// parseFlag handles generic flag parsing for both long and short forms
func parseFlag(arg string, args []string, currentIndex int, flagHandlers map[string]func(string), shortFlags map[string]string) (bool, int) {
if parsedValue, found := parseEqualsFormat(arg, flagHandlers, shortFlags); found {
flagHandlers[parsedValue.flagName](parsedValue.value)
return true, currentIndex
}
if parsedValue, found := parseSpacedFormat(arg, args, currentIndex, flagHandlers, shortFlags); found {
flagHandlers[parsedValue.flagName](parsedValue.value)
return true, currentIndex + 1
}
return false, currentIndex
}
type parsedFlag struct {
flagName string
value string
}
// parseEqualsFormat handles --flag=value and -f=value formats
func parseEqualsFormat(arg string, flagHandlers map[string]func(string), shortFlags map[string]string) (parsedFlag, bool) {
if !strings.Contains(arg, "=") {
return parsedFlag{}, false
}
parts := strings.SplitN(arg, "=", 2)
if len(parts) != 2 {
return parsedFlag{}, false
}
if strings.HasPrefix(parts[0], "--") {
flagName := strings.TrimPrefix(parts[0], "--")
if _, exists := flagHandlers[flagName]; exists {
return parsedFlag{flagName: flagName, value: parts[1]}, true
}
}
if strings.HasPrefix(parts[0], "-") && len(parts[0]) == 2 {
shortFlag := strings.TrimPrefix(parts[0], "-")
if longFlag, exists := shortFlags[shortFlag]; exists {
if _, exists := flagHandlers[longFlag]; exists {
return parsedFlag{flagName: longFlag, value: parts[1]}, true
}
}
}
return parsedFlag{}, false
}
// parseSpacedFormat handles --flag value and -f value formats
func parseSpacedFormat(arg string, args []string, currentIndex int, flagHandlers map[string]func(string), shortFlags map[string]string) (parsedFlag, bool) {
if currentIndex+1 >= len(args) {
return parsedFlag{}, false
}
if strings.HasPrefix(arg, "--") {
flagName := strings.TrimPrefix(arg, "--")
if _, exists := flagHandlers[flagName]; exists {
return parsedFlag{flagName: flagName, value: args[currentIndex+1]}, true
}
}
if strings.HasPrefix(arg, "-") && len(arg) == 2 {
shortFlag := strings.TrimPrefix(arg, "-")
if longFlag, exists := shortFlags[shortFlag]; exists {
if _, exists := flagHandlers[longFlag]; exists {
return parsedFlag{flagName: longFlag, value: args[currentIndex+1]}, true
}
}
}
return parsedFlag{}, false
}
// createSSHFlagSet creates and configures the flag set for SSH command parsing
// sshFlags contains all SSH-related flags and parameters
type sshFlags struct {
Port int
Username string
Login string
RequestPTY bool
StrictHostKeyChecking bool
KnownHostsFile string
IdentityFile string
SkipCachedToken bool
ConfigPath string
LogLevel string
LocalForwards []string
RemoteForwards []string
Host string
Command string
}
func createSSHFlagSet() (*flag.FlagSet, *sshFlags) {
defaultConfigPath := getEnvOrDefault("CONFIG", configPath)
defaultLogLevel := getEnvOrDefault("LOG_LEVEL", logLevel)
fs := flag.NewFlagSet("ssh-flags", flag.ContinueOnError)
fs.SetOutput(nil)
flags := &sshFlags{}
fs.IntVar(&flags.Port, "p", sshserver.DefaultSSHPort, "SSH port")
fs.IntVar(&flags.Port, "port", sshserver.DefaultSSHPort, "SSH port")
fs.StringVar(&flags.Username, "u", "", sshUsernameDesc)
fs.StringVar(&flags.Username, "user", "", sshUsernameDesc)
fs.StringVar(&flags.Login, "login", "", sshUsernameDesc+" (alias for --user)")
fs.BoolVar(&flags.RequestPTY, "t", false, "Force pseudo-terminal allocation")
fs.BoolVar(&flags.RequestPTY, "tty", false, "Force pseudo-terminal allocation")
fs.BoolVar(&flags.StrictHostKeyChecking, "strict-host-key-checking", true, "Enable strict host key checking")
fs.StringVar(&flags.KnownHostsFile, "o", "", "Path to known_hosts file")
fs.StringVar(&flags.KnownHostsFile, "known-hosts", "", "Path to known_hosts file")
fs.StringVar(&flags.IdentityFile, "i", "", "Path to SSH private key file")
fs.StringVar(&flags.IdentityFile, "identity", "", "Path to SSH private key file")
fs.BoolVar(&flags.SkipCachedToken, "no-cache", false, "Skip cached JWT token and force fresh authentication")
fs.StringVar(&flags.ConfigPath, "c", defaultConfigPath, "Netbird config file location")
fs.StringVar(&flags.ConfigPath, "config", defaultConfigPath, "Netbird config file location")
fs.StringVar(&flags.LogLevel, "l", defaultLogLevel, "sets Netbird log level")
fs.StringVar(&flags.LogLevel, "log-level", defaultLogLevel, "sets Netbird log level")
return fs, flags
}
func validateSSHArgsWithoutFlagParsing(_ *cobra.Command, args []string) error {
if len(args) < 1 {
return errors.New(hostArgumentRequired)
}
resetSSHGlobals()
if len(os.Args) > 2 {
extractGlobalFlags(os.Args[1:])
}
filteredArgs, localForwardFlags, remoteForwardFlags := parseCustomSSHFlags(args)
fs, flags := createSSHFlagSet()
if err := fs.Parse(filteredArgs); err != nil {
if errors.Is(err, flag.ErrHelp) {
return nil
}
return err
}
remaining := fs.Args()
if len(remaining) < 1 {
return errors.New(hostArgumentRequired)
}
port = flags.Port
if flags.Username != "" {
username = flags.Username
} else if flags.Login != "" {
username = flags.Login
}
requestPTY = flags.RequestPTY
strictHostKeyChecking = flags.StrictHostKeyChecking
knownHostsFile = flags.KnownHostsFile
identityFile = flags.IdentityFile
skipCachedToken = flags.SkipCachedToken
if flags.ConfigPath != getEnvOrDefault("CONFIG", configPath) {
configPath = flags.ConfigPath
}
if flags.LogLevel != getEnvOrDefault("LOG_LEVEL", logLevel) {
logLevel = flags.LogLevel
}
localForwards = localForwardFlags
remoteForwards = remoteForwardFlags
return parseHostnameAndCommand(remaining)
}
func parseHostnameAndCommand(args []string) error {
if len(args) < 1 {
return errors.New(hostArgumentRequired)
}
arg := args[0]
if strings.Contains(arg, "@") {
parts := strings.SplitN(arg, "@", 2)
if len(parts) != 2 || parts[0] == "" || parts[1] == "" {
return errors.New("invalid user@host format")
}
if username == "" {
username = parts[0]
}
host = parts[1]
} else {
host = arg
}
if username == "" {
if sudoUser := os.Getenv("SUDO_USER"); sudoUser != "" {
username = sudoUser
} else if currentUser, err := user.Current(); err == nil {
username = currentUser.Username
} else {
username = "root"
}
}
// Everything after hostname becomes the command
if len(args) > 1 {
command = strings.Join(args[1:], " ")
}
return nil
}
func runSSH(ctx context.Context, addr string, cmd *cobra.Command) error {
target := fmt.Sprintf("%s:%d", addr, port)
c, err := sshclient.Dial(ctx, target, username, sshclient.DialOptions{
KnownHostsFile: knownHostsFile,
IdentityFile: identityFile,
DaemonAddr: daemonAddr,
SkipCachedToken: skipCachedToken,
InsecureSkipVerify: !strictHostKeyChecking,
})
if err != nil {
cmd.Printf("Failed to connect to %s@%s\n", username, target)
cmd.Printf("\nTroubleshooting steps:\n")
cmd.Printf(" 1. Check peer connectivity: netbird status -d\n")
cmd.Printf(" 2. Verify SSH server is enabled on the peer\n")
cmd.Printf(" 3. Ensure correct hostname/IP is used\n")
return fmt.Errorf("dial %s: %w", target, err)
}
sshCtx, cancel := context.WithCancel(ctx)
defer cancel()
go func() {
<-sshCtx.Done()
if err := c.Close(); err != nil {
cmd.Printf("Error closing SSH connection: %v\n", err)
}
}()
if err := startPortForwarding(sshCtx, c, cmd); err != nil {
return fmt.Errorf("start port forwarding: %w", err)
}
if command != "" {
return executeSSHCommand(sshCtx, c, command)
}
return openSSHTerminal(sshCtx, c)
}
// executeSSHCommand executes a command over SSH.
func executeSSHCommand(ctx context.Context, c *sshclient.Client, command string) error {
var err error
if requestPTY {
err = c.ExecuteCommandWithPTY(ctx, command)
} else {
err = c.ExecuteCommandWithIO(ctx, command)
}
if err != nil {
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
return nil
}
var exitErr *ssh.ExitError
if errors.As(err, &exitErr) {
os.Exit(exitErr.ExitStatus())
}
var exitMissingErr *ssh.ExitMissingError
if errors.As(err, &exitMissingErr) {
log.Debugf("Remote command exited without exit status: %v", err)
return nil
}
return fmt.Errorf("execute command: %w", err)
}
return nil
}
// openSSHTerminal opens an interactive SSH terminal.
func openSSHTerminal(ctx context.Context, c *sshclient.Client) error {
if err := c.OpenTerminal(ctx); err != nil {
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
return nil
}
var exitMissingErr *ssh.ExitMissingError
if errors.As(err, &exitMissingErr) {
log.Debugf("Remote terminal exited without exit status: %v", err)
return nil
}
return fmt.Errorf("open terminal: %w", err)
}
return nil
}
// startPortForwarding starts local and remote port forwarding based on command line flags
func startPortForwarding(ctx context.Context, c *sshclient.Client, cmd *cobra.Command) error {
for _, forward := range localForwards {
if err := parseAndStartLocalForward(ctx, c, forward, cmd); err != nil {
return fmt.Errorf("local port forward %s: %w", forward, err)
}
}
for _, forward := range remoteForwards {
if err := parseAndStartRemoteForward(ctx, c, forward, cmd); err != nil {
return fmt.Errorf("remote port forward %s: %w", forward, err)
}
}
return nil
}
// parseAndStartLocalForward parses and starts a local port forward (-L)
func parseAndStartLocalForward(ctx context.Context, c *sshclient.Client, forward string, cmd *cobra.Command) error {
localAddr, remoteAddr, err := parsePortForwardSpec(forward)
if err != nil {
return err
}
cmd.Printf("Local port forwarding: %s -> %s\n", localAddr, remoteAddr)
go func() {
if err := c.LocalPortForward(ctx, localAddr, remoteAddr); err != nil && !errors.Is(err, context.Canceled) {
cmd.Printf("Local port forward error: %v\n", err)
}
}()
return nil
}
// parseAndStartRemoteForward parses and starts a remote port forward (-R)
func parseAndStartRemoteForward(ctx context.Context, c *sshclient.Client, forward string, cmd *cobra.Command) error {
remoteAddr, localAddr, err := parsePortForwardSpec(forward)
if err != nil {
return err
}
cmd.Printf("Remote port forwarding: %s -> %s\n", remoteAddr, localAddr)
go func() {
if err := c.RemotePortForward(ctx, remoteAddr, localAddr); err != nil && !errors.Is(err, context.Canceled) {
cmd.Printf("Remote port forward error: %v\n", err)
}
}()
return nil
}
// parsePortForwardSpec parses port forward specifications like "8080:localhost:80" or "[::1]:8080:localhost:80".
// Also supports Unix sockets like "8080:/tmp/socket" or "127.0.0.1:8080:/tmp/socket".
func parsePortForwardSpec(spec string) (string, string, error) {
// Support formats:
// port:host:hostport -> localhost:port -> host:hostport
// host:port:host:hostport -> host:port -> host:hostport
// [host]:port:host:hostport -> [host]:port -> host:hostport
// port:unix_socket_path -> localhost:port -> unix_socket_path
// host:port:unix_socket_path -> host:port -> unix_socket_path
if strings.HasPrefix(spec, "[") && strings.Contains(spec, "]:") {
return parseIPv6ForwardSpec(spec)
}
parts := strings.Split(spec, ":")
if len(parts) < 2 {
return "", "", fmt.Errorf("invalid port forward specification: %s (expected format: [local_host:]local_port:remote_target)", spec)
}
switch len(parts) {
case 2:
return parseTwoPartForwardSpec(parts, spec)
case 3:
return parseThreePartForwardSpec(parts)
case 4:
return parseFourPartForwardSpec(parts)
default:
return "", "", fmt.Errorf("invalid port forward specification: %s", spec)
}
}
// parseTwoPartForwardSpec handles "port:unix_socket" format.
func parseTwoPartForwardSpec(parts []string, spec string) (string, string, error) {
if isUnixSocket(parts[1]) {
localAddr := "localhost:" + parts[0]
remoteAddr := parts[1]
return localAddr, remoteAddr, nil
}
return "", "", fmt.Errorf("invalid port forward specification: %s (expected format: [local_host:]local_port:remote_host:remote_port or [local_host:]local_port:unix_socket)", spec)
}
// parseThreePartForwardSpec handles "port:host:hostport" or "host:port:unix_socket" formats.
func parseThreePartForwardSpec(parts []string) (string, string, error) {
if isUnixSocket(parts[2]) {
localHost := normalizeLocalHost(parts[0])
localAddr := localHost + ":" + parts[1]
remoteAddr := parts[2]
return localAddr, remoteAddr, nil
}
localAddr := "localhost:" + parts[0]
remoteAddr := parts[1] + ":" + parts[2]
return localAddr, remoteAddr, nil
}
// parseFourPartForwardSpec handles "host:port:host:hostport" format.
func parseFourPartForwardSpec(parts []string) (string, string, error) {
localHost := normalizeLocalHost(parts[0])
localAddr := localHost + ":" + parts[1]
remoteAddr := parts[2] + ":" + parts[3]
return localAddr, remoteAddr, nil
}
// parseIPv6ForwardSpec handles "[host]:port:host:hostport" format.
func parseIPv6ForwardSpec(spec string) (string, string, error) {
idx := strings.Index(spec, "]:")
if idx == -1 {
return "", "", fmt.Errorf("invalid IPv6 port forward specification: %s", spec)
}
ipv6Host := spec[:idx+1]
remaining := spec[idx+2:]
parts := strings.Split(remaining, ":")
if len(parts) != 3 {
return "", "", fmt.Errorf("invalid IPv6 port forward specification: %s (expected [ipv6]:port:host:hostport)", spec)
}
localAddr := ipv6Host + ":" + parts[0]
remoteAddr := parts[1] + ":" + parts[2]
return localAddr, remoteAddr, nil
}
// isUnixSocket checks if a path is a Unix socket path.
func isUnixSocket(path string) bool {
return strings.HasPrefix(path, "/") || strings.HasPrefix(path, "./")
}
// normalizeLocalHost converts "*" to "0.0.0.0" for binding to all interfaces.
func normalizeLocalHost(host string) string {
if host == "*" {
return "0.0.0.0"
}
return host
}
var sshProxyCmd = &cobra.Command{
Use: "proxy <host> <port>",
Short: "Internal SSH proxy for native SSH client integration",
Long: "Internal command used by SSH ProxyCommand to handle JWT authentication",
Hidden: true,
Args: cobra.ExactArgs(2),
RunE: sshProxyFn,
}
func sshProxyFn(cmd *cobra.Command, args []string) error {
logOutput := "console"
if firstLogFile := util.FindFirstLogPath(logFiles); firstLogFile != "" && firstLogFile != defaultLogFile {
logOutput = firstLogFile
}
if err := util.InitLog(logLevel, logOutput); err != nil {
return fmt.Errorf("init log: %w", err)
}
host := args[0]
portStr := args[1]
port, err := strconv.Atoi(portStr)
if err != nil {
return fmt.Errorf("invalid port: %s", portStr)
}
proxy, err := sshproxy.New(daemonAddr, host, port, cmd.ErrOrStderr())
if err != nil {
return fmt.Errorf("create SSH proxy: %w", err)
}
defer func() {
if err := proxy.Close(); err != nil {
log.Debugf("close SSH proxy: %v", err)
}
}()
if err := proxy.Connect(cmd.Context()); err != nil {
return fmt.Errorf("SSH proxy: %w", err)
}
return nil
}
var sshDetectCmd = &cobra.Command{
Use: "detect <host> <port>",
Short: "Detect if a host is running NetBird SSH",
Long: "Internal command used by SSH Match exec to detect NetBird SSH servers. Exit codes: 0=JWT, 1=no-JWT, 2=regular SSH",
Hidden: true,
Args: cobra.ExactArgs(2),
RunE: sshDetectFn,
}
func sshDetectFn(cmd *cobra.Command, args []string) error {
if err := util.InitLog(logLevel, "console"); err != nil {
os.Exit(detection.ServerTypeRegular.ExitCode())
}
host := args[0]
portStr := args[1]
port, err := strconv.Atoi(portStr)
if err != nil {
os.Exit(detection.ServerTypeRegular.ExitCode())
}
dialer := &net.Dialer{Timeout: detection.Timeout}
serverType, err := detection.DetectSSHServerType(cmd.Context(), dialer, host, port)
if err != nil {
os.Exit(detection.ServerTypeRegular.ExitCode())
}
os.Exit(serverType.ExitCode())
return nil
}

View File

@@ -0,0 +1,74 @@
//go:build unix
package cmd
import (
"fmt"
"os"
"github.com/spf13/cobra"
sshserver "github.com/netbirdio/netbird/client/ssh/server"
)
var (
sshExecUID uint32
sshExecGID uint32
sshExecGroups []uint
sshExecWorkingDir string
sshExecShell string
sshExecCommand string
sshExecPTY bool
)
// sshExecCmd represents the hidden ssh exec subcommand for privilege dropping
var sshExecCmd = &cobra.Command{
Use: "exec",
Short: "Internal SSH execution with privilege dropping (hidden)",
Hidden: true,
RunE: runSSHExec,
}
func init() {
sshExecCmd.Flags().Uint32Var(&sshExecUID, "uid", 0, "Target user ID")
sshExecCmd.Flags().Uint32Var(&sshExecGID, "gid", 0, "Target group ID")
sshExecCmd.Flags().UintSliceVar(&sshExecGroups, "groups", nil, "Supplementary group IDs (can be repeated)")
sshExecCmd.Flags().StringVar(&sshExecWorkingDir, "working-dir", "", "Working directory")
sshExecCmd.Flags().StringVar(&sshExecShell, "shell", "/bin/sh", "Shell to execute")
sshExecCmd.Flags().BoolVar(&sshExecPTY, "pty", false, "Request PTY (will fail as executor doesn't support PTY)")
sshExecCmd.Flags().StringVar(&sshExecCommand, "cmd", "", "Command to execute")
if err := sshExecCmd.MarkFlagRequired("uid"); err != nil {
_, _ = fmt.Fprintf(os.Stderr, "failed to mark uid flag as required: %v\n", err)
os.Exit(1)
}
if err := sshExecCmd.MarkFlagRequired("gid"); err != nil {
_, _ = fmt.Fprintf(os.Stderr, "failed to mark gid flag as required: %v\n", err)
os.Exit(1)
}
sshCmd.AddCommand(sshExecCmd)
}
// runSSHExec handles the SSH exec subcommand execution.
func runSSHExec(cmd *cobra.Command, _ []string) error {
privilegeDropper := sshserver.NewPrivilegeDropper()
var groups []uint32
for _, groupInt := range sshExecGroups {
groups = append(groups, uint32(groupInt))
}
config := sshserver.ExecutorConfig{
UID: sshExecUID,
GID: sshExecGID,
Groups: groups,
WorkingDir: sshExecWorkingDir,
Shell: sshExecShell,
Command: sshExecCommand,
PTY: sshExecPTY,
}
privilegeDropper.ExecuteWithPrivilegeDrop(cmd.Context(), config)
return nil
}

View File

@@ -0,0 +1,94 @@
//go:build unix
package cmd
import (
"errors"
"io"
"os"
"github.com/pkg/sftp"
log "github.com/sirupsen/logrus"
"github.com/spf13/cobra"
sshserver "github.com/netbirdio/netbird/client/ssh/server"
)
var (
sftpUID uint32
sftpGID uint32
sftpGroupsInt []uint
sftpWorkingDir string
)
var sshSftpCmd = &cobra.Command{
Use: "sftp",
Short: "SFTP server with privilege dropping (internal use)",
Hidden: true,
RunE: sftpMain,
}
func init() {
sshSftpCmd.Flags().Uint32Var(&sftpUID, "uid", 0, "Target user ID")
sshSftpCmd.Flags().Uint32Var(&sftpGID, "gid", 0, "Target group ID")
sshSftpCmd.Flags().UintSliceVar(&sftpGroupsInt, "groups", nil, "Supplementary group IDs (can be repeated)")
sshSftpCmd.Flags().StringVar(&sftpWorkingDir, "working-dir", "", "Working directory")
}
func sftpMain(cmd *cobra.Command, _ []string) error {
privilegeDropper := sshserver.NewPrivilegeDropper()
var groups []uint32
for _, groupInt := range sftpGroupsInt {
groups = append(groups, uint32(groupInt))
}
config := sshserver.ExecutorConfig{
UID: sftpUID,
GID: sftpGID,
Groups: groups,
WorkingDir: sftpWorkingDir,
Shell: "",
Command: "",
}
log.Tracef("dropping privileges for SFTP to UID=%d, GID=%d, groups=%v", config.UID, config.GID, config.Groups)
if err := privilegeDropper.DropPrivileges(config.UID, config.GID, config.Groups); err != nil {
cmd.PrintErrf("privilege drop failed: %v\n", err)
os.Exit(sshserver.ExitCodePrivilegeDropFail)
}
if config.WorkingDir != "" {
if err := os.Chdir(config.WorkingDir); err != nil {
cmd.PrintErrf("failed to change to working directory %s: %v\n", config.WorkingDir, err)
}
}
sftpServer, err := sftp.NewServer(struct {
io.Reader
io.WriteCloser
}{
Reader: os.Stdin,
WriteCloser: os.Stdout,
})
if err != nil {
cmd.PrintErrf("SFTP server creation failed: %v\n", err)
os.Exit(sshserver.ExitCodeShellExecFail)
}
log.Tracef("starting SFTP server with dropped privileges")
if err := sftpServer.Serve(); err != nil && !errors.Is(err, io.EOF) {
cmd.PrintErrf("SFTP server error: %v\n", err)
if closeErr := sftpServer.Close(); closeErr != nil {
cmd.PrintErrf("SFTP server close error: %v\n", closeErr)
}
os.Exit(sshserver.ExitCodeShellExecFail)
}
if closeErr := sftpServer.Close(); closeErr != nil {
cmd.PrintErrf("SFTP server close error: %v\n", closeErr)
}
os.Exit(sshserver.ExitCodeSuccess)
return nil
}

View File

@@ -0,0 +1,94 @@
//go:build windows
package cmd
import (
"errors"
"fmt"
"io"
"os"
"os/user"
"strings"
"github.com/pkg/sftp"
log "github.com/sirupsen/logrus"
"github.com/spf13/cobra"
sshserver "github.com/netbirdio/netbird/client/ssh/server"
)
var (
sftpWorkingDir string
windowsUsername string
windowsDomain string
)
var sshSftpCmd = &cobra.Command{
Use: "sftp",
Short: "SFTP server with user switching for Windows (internal use)",
Hidden: true,
RunE: sftpMain,
}
func init() {
sshSftpCmd.Flags().StringVar(&sftpWorkingDir, "working-dir", "", "Working directory")
sshSftpCmd.Flags().StringVar(&windowsUsername, "windows-username", "", "Windows username for user switching")
sshSftpCmd.Flags().StringVar(&windowsDomain, "windows-domain", "", "Windows domain for user switching")
}
func sftpMain(cmd *cobra.Command, _ []string) error {
return sftpMainDirect(cmd)
}
func sftpMainDirect(cmd *cobra.Command) error {
currentUser, err := user.Current()
if err != nil {
cmd.PrintErrf("failed to get current user: %v\n", err)
os.Exit(sshserver.ExitCodeValidationFail)
}
if windowsUsername != "" {
expectedUsername := windowsUsername
if windowsDomain != "" {
expectedUsername = fmt.Sprintf(`%s\%s`, windowsDomain, windowsUsername)
}
if !strings.EqualFold(currentUser.Username, expectedUsername) && !strings.EqualFold(currentUser.Username, windowsUsername) {
cmd.PrintErrf("user switching failed\n")
os.Exit(sshserver.ExitCodeValidationFail)
}
}
log.Debugf("SFTP process running as: %s (UID: %s, Name: %s)", currentUser.Username, currentUser.Uid, currentUser.Name)
if sftpWorkingDir != "" {
if err := os.Chdir(sftpWorkingDir); err != nil {
cmd.PrintErrf("failed to change to working directory %s: %v\n", sftpWorkingDir, err)
}
}
sftpServer, err := sftp.NewServer(struct {
io.Reader
io.WriteCloser
}{
Reader: os.Stdin,
WriteCloser: os.Stdout,
})
if err != nil {
cmd.PrintErrf("SFTP server creation failed: %v\n", err)
os.Exit(sshserver.ExitCodeShellExecFail)
}
log.Debugf("starting SFTP server")
exitCode := sshserver.ExitCodeSuccess
if err := sftpServer.Serve(); err != nil && !errors.Is(err, io.EOF) {
cmd.PrintErrf("SFTP server error: %v\n", err)
exitCode = sshserver.ExitCodeShellExecFail
}
if err := sftpServer.Close(); err != nil {
log.Debugf("SFTP server close error: %v", err)
}
os.Exit(exitCode)
return nil
}

717
client/cmd/ssh_test.go Normal file
View File

@@ -0,0 +1,717 @@
package cmd
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestSSHCommand_FlagParsing(t *testing.T) {
tests := []struct {
name string
args []string
expectedHost string
expectedUser string
expectedPort int
expectedCmd string
expectError bool
}{
{
name: "basic host",
args: []string{"hostname"},
expectedHost: "hostname",
expectedUser: "",
expectedPort: 22,
expectedCmd: "",
},
{
name: "user@host format",
args: []string{"user@hostname"},
expectedHost: "hostname",
expectedUser: "user",
expectedPort: 22,
expectedCmd: "",
},
{
name: "host with command",
args: []string{"hostname", "echo", "hello"},
expectedHost: "hostname",
expectedUser: "",
expectedPort: 22,
expectedCmd: "echo hello",
},
{
name: "command with flags should be preserved",
args: []string{"hostname", "ls", "-la", "/tmp"},
expectedHost: "hostname",
expectedUser: "",
expectedPort: 22,
expectedCmd: "ls -la /tmp",
},
{
name: "double dash separator",
args: []string{"hostname", "--", "ls", "-la"},
expectedHost: "hostname",
expectedUser: "",
expectedPort: 22,
expectedCmd: "-- ls -la",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Reset global variables
host = ""
username = ""
port = 22
command = ""
// Mock command for testing
cmd := sshCmd
cmd.SetArgs(tt.args)
err := validateSSHArgsWithoutFlagParsing(cmd, tt.args)
if tt.expectError {
assert.Error(t, err)
return
}
require.NoError(t, err, "SSH args validation should succeed for valid input")
assert.Equal(t, tt.expectedHost, host, "host mismatch")
if tt.expectedUser != "" {
assert.Equal(t, tt.expectedUser, username, "username mismatch")
}
assert.Equal(t, tt.expectedPort, port, "port mismatch")
assert.Equal(t, tt.expectedCmd, command, "command mismatch")
})
}
}
func TestSSHCommand_FlagConflictPrevention(t *testing.T) {
// Test that SSH flags don't conflict with command flags
tests := []struct {
name string
args []string
expectedCmd string
description string
}{
{
name: "ls with -la flags",
args: []string{"hostname", "ls", "-la"},
expectedCmd: "ls -la",
description: "ls flags should be passed to remote command",
},
{
name: "grep with -r flag",
args: []string{"hostname", "grep", "-r", "pattern", "/path"},
expectedCmd: "grep -r pattern /path",
description: "grep flags should be passed to remote command",
},
{
name: "ps with aux flags",
args: []string{"hostname", "ps", "aux"},
expectedCmd: "ps aux",
description: "ps flags should be passed to remote command",
},
{
name: "command with double dash",
args: []string{"hostname", "--", "ls", "-la"},
expectedCmd: "-- ls -la",
description: "double dash should be preserved in command",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Reset global variables
host = ""
username = ""
port = 22
command = ""
cmd := sshCmd
err := validateSSHArgsWithoutFlagParsing(cmd, tt.args)
require.NoError(t, err, "SSH args validation should succeed for valid input")
assert.Equal(t, tt.expectedCmd, command, tt.description)
})
}
}
func TestSSHCommand_NonInteractiveExecution(t *testing.T) {
// Test that commands with arguments should execute the command and exit,
// not drop to an interactive shell
tests := []struct {
name string
args []string
expectedCmd string
shouldExit bool
description string
}{
{
name: "ls command should execute and exit",
args: []string{"hostname", "ls"},
expectedCmd: "ls",
shouldExit: true,
description: "ls command should execute and exit, not drop to shell",
},
{
name: "ls with flags should execute and exit",
args: []string{"hostname", "ls", "-la"},
expectedCmd: "ls -la",
shouldExit: true,
description: "ls with flags should execute and exit, not drop to shell",
},
{
name: "pwd command should execute and exit",
args: []string{"hostname", "pwd"},
expectedCmd: "pwd",
shouldExit: true,
description: "pwd command should execute and exit, not drop to shell",
},
{
name: "echo command should execute and exit",
args: []string{"hostname", "echo", "hello"},
expectedCmd: "echo hello",
shouldExit: true,
description: "echo command should execute and exit, not drop to shell",
},
{
name: "no command should open shell",
args: []string{"hostname"},
expectedCmd: "",
shouldExit: false,
description: "no command should open interactive shell",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Reset global variables
host = ""
username = ""
port = 22
command = ""
cmd := sshCmd
err := validateSSHArgsWithoutFlagParsing(cmd, tt.args)
require.NoError(t, err, "SSH args validation should succeed for valid input")
assert.Equal(t, tt.expectedCmd, command, tt.description)
// When command is present, it should execute the command and exit
// When command is empty, it should open interactive shell
hasCommand := command != ""
assert.Equal(t, tt.shouldExit, hasCommand, "Command presence should match expected behavior")
})
}
}
func TestSSHCommand_FlagHandling(t *testing.T) {
// Test that flags after hostname are not parsed by netbird but passed to SSH command
tests := []struct {
name string
args []string
expectedHost string
expectedCmd string
expectError bool
description string
}{
{
name: "ls with -la flag should not be parsed by netbird",
args: []string{"debian2", "ls", "-la"},
expectedHost: "debian2",
expectedCmd: "ls -la",
expectError: false,
description: "ls -la should be passed as SSH command, not parsed as netbird flags",
},
{
name: "command with netbird-like flags should be passed through",
args: []string{"hostname", "echo", "--help"},
expectedHost: "hostname",
expectedCmd: "echo --help",
expectError: false,
description: "--help should be passed to echo, not parsed by netbird",
},
{
name: "command with -p flag should not conflict with SSH port flag",
args: []string{"hostname", "ps", "-p", "1234"},
expectedHost: "hostname",
expectedCmd: "ps -p 1234",
expectError: false,
description: "ps -p should be passed to ps command, not parsed as port",
},
{
name: "tar with flags should be passed through",
args: []string{"hostname", "tar", "-czf", "backup.tar.gz", "/home"},
expectedHost: "hostname",
expectedCmd: "tar -czf backup.tar.gz /home",
expectError: false,
description: "tar flags should be passed to tar command",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Reset global variables
host = ""
username = ""
port = 22
command = ""
cmd := sshCmd
err := validateSSHArgsWithoutFlagParsing(cmd, tt.args)
if tt.expectError {
assert.Error(t, err)
return
}
require.NoError(t, err, "SSH args validation should succeed for valid input")
assert.Equal(t, tt.expectedHost, host, "host mismatch")
assert.Equal(t, tt.expectedCmd, command, tt.description)
})
}
}
func TestSSHCommand_RegressionFlagParsing(t *testing.T) {
// Regression test for the specific issue: "sudo ./netbird ssh debian2 ls -la"
// should not parse -la as netbird flags but pass them to the SSH command
tests := []struct {
name string
args []string
expectedHost string
expectedCmd string
expectError bool
description string
}{
{
name: "original issue: ls -la should be preserved",
args: []string{"debian2", "ls", "-la"},
expectedHost: "debian2",
expectedCmd: "ls -la",
expectError: false,
description: "The original failing case should now work",
},
{
name: "ls -l should be preserved",
args: []string{"hostname", "ls", "-l"},
expectedHost: "hostname",
expectedCmd: "ls -l",
expectError: false,
description: "Single letter flags should be preserved",
},
{
name: "SSH port flag should work",
args: []string{"-p", "2222", "hostname", "ls", "-la"},
expectedHost: "hostname",
expectedCmd: "ls -la",
expectError: false,
description: "SSH -p flag should be parsed, command flags preserved",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Reset global variables
host = ""
username = ""
port = 22
command = ""
cmd := sshCmd
err := validateSSHArgsWithoutFlagParsing(cmd, tt.args)
if tt.expectError {
assert.Error(t, err)
return
}
require.NoError(t, err, "SSH args validation should succeed for valid input")
assert.Equal(t, tt.expectedHost, host, "host mismatch")
assert.Equal(t, tt.expectedCmd, command, tt.description)
// Check port for the test case with -p flag
if len(tt.args) > 0 && tt.args[0] == "-p" {
assert.Equal(t, 2222, port, "port should be parsed from -p flag")
}
})
}
}
func TestSSHCommand_PortForwardingFlagParsing(t *testing.T) {
tests := []struct {
name string
args []string
expectedHost string
expectedLocal []string
expectedRemote []string
expectError bool
description string
}{
{
name: "local port forwarding -L",
args: []string{"-L", "8080:localhost:80", "hostname"},
expectedHost: "hostname",
expectedLocal: []string{"8080:localhost:80"},
expectedRemote: []string{},
expectError: false,
description: "Single -L flag should be parsed correctly",
},
{
name: "remote port forwarding -R",
args: []string{"-R", "8080:localhost:80", "hostname"},
expectedHost: "hostname",
expectedLocal: []string{},
expectedRemote: []string{"8080:localhost:80"},
expectError: false,
description: "Single -R flag should be parsed correctly",
},
{
name: "multiple local port forwards",
args: []string{"-L", "8080:localhost:80", "-L", "9090:localhost:443", "hostname"},
expectedHost: "hostname",
expectedLocal: []string{"8080:localhost:80", "9090:localhost:443"},
expectedRemote: []string{},
expectError: false,
description: "Multiple -L flags should be parsed correctly",
},
{
name: "multiple remote port forwards",
args: []string{"-R", "8080:localhost:80", "-R", "9090:localhost:443", "hostname"},
expectedHost: "hostname",
expectedLocal: []string{},
expectedRemote: []string{"8080:localhost:80", "9090:localhost:443"},
expectError: false,
description: "Multiple -R flags should be parsed correctly",
},
{
name: "mixed local and remote forwards",
args: []string{"-L", "8080:localhost:80", "-R", "9090:localhost:443", "hostname"},
expectedHost: "hostname",
expectedLocal: []string{"8080:localhost:80"},
expectedRemote: []string{"9090:localhost:443"},
expectError: false,
description: "Mixed -L and -R flags should be parsed correctly",
},
{
name: "port forwarding with bind address",
args: []string{"-L", "127.0.0.1:8080:localhost:80", "hostname"},
expectedHost: "hostname",
expectedLocal: []string{"127.0.0.1:8080:localhost:80"},
expectedRemote: []string{},
expectError: false,
description: "Port forwarding with bind address should work",
},
{
name: "port forwarding with command",
args: []string{"-L", "8080:localhost:80", "hostname", "ls", "-la"},
expectedHost: "hostname",
expectedLocal: []string{"8080:localhost:80"},
expectedRemote: []string{},
expectError: false,
description: "Port forwarding with command should work",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Reset global variables
host = ""
username = ""
port = 22
command = ""
localForwards = nil
remoteForwards = nil
cmd := sshCmd
err := validateSSHArgsWithoutFlagParsing(cmd, tt.args)
if tt.expectError {
assert.Error(t, err)
return
}
require.NoError(t, err, "SSH args validation should succeed for valid input")
assert.Equal(t, tt.expectedHost, host, "host mismatch")
// Handle nil vs empty slice comparison
if len(tt.expectedLocal) == 0 {
assert.True(t, len(localForwards) == 0, tt.description+" - local forwards should be empty")
} else {
assert.Equal(t, tt.expectedLocal, localForwards, tt.description+" - local forwards")
}
if len(tt.expectedRemote) == 0 {
assert.True(t, len(remoteForwards) == 0, tt.description+" - remote forwards should be empty")
} else {
assert.Equal(t, tt.expectedRemote, remoteForwards, tt.description+" - remote forwards")
}
})
}
}
func TestParsePortForward(t *testing.T) {
tests := []struct {
name string
spec string
expectedLocal string
expectedRemote string
expectError bool
description string
}{
{
name: "simple port forward",
spec: "8080:localhost:80",
expectedLocal: "localhost:8080",
expectedRemote: "localhost:80",
expectError: false,
description: "Simple port:host:port format should work",
},
{
name: "port forward with bind address",
spec: "127.0.0.1:8080:localhost:80",
expectedLocal: "127.0.0.1:8080",
expectedRemote: "localhost:80",
expectError: false,
description: "bind_address:port:host:port format should work",
},
{
name: "port forward to different host",
spec: "8080:example.com:443",
expectedLocal: "localhost:8080",
expectedRemote: "example.com:443",
expectError: false,
description: "Forwarding to different host should work",
},
{
name: "port forward with IPv6 (needs bracket support)",
spec: "::1:8080:localhost:80",
expectError: true,
description: "IPv6 without brackets fails as expected (feature to implement)",
},
{
name: "invalid format - too few parts",
spec: "8080:localhost",
expectError: true,
description: "Invalid format with too few parts should fail",
},
{
name: "invalid format - too many parts",
spec: "127.0.0.1:8080:localhost:80:extra",
expectError: true,
description: "Invalid format with too many parts should fail",
},
{
name: "empty spec",
spec: "",
expectError: true,
description: "Empty spec should fail",
},
{
name: "unix socket local forward",
spec: "8080:/tmp/socket",
expectedLocal: "localhost:8080",
expectedRemote: "/tmp/socket",
expectError: false,
description: "Unix socket forwarding should work",
},
{
name: "unix socket with bind address",
spec: "127.0.0.1:8080:/tmp/socket",
expectedLocal: "127.0.0.1:8080",
expectedRemote: "/tmp/socket",
expectError: false,
description: "Unix socket with bind address should work",
},
{
name: "wildcard bind all interfaces",
spec: "*:8080:localhost:80",
expectedLocal: "0.0.0.0:8080",
expectedRemote: "localhost:80",
expectError: false,
description: "Wildcard * should bind to all interfaces (0.0.0.0)",
},
{
name: "wildcard for port only",
spec: "8080:*:80",
expectedLocal: "localhost:8080",
expectedRemote: "*:80",
expectError: false,
description: "Wildcard in remote host should be preserved",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
localAddr, remoteAddr, err := parsePortForwardSpec(tt.spec)
if tt.expectError {
assert.Error(t, err, tt.description)
return
}
require.NoError(t, err, tt.description)
assert.Equal(t, tt.expectedLocal, localAddr, tt.description+" - local address")
assert.Equal(t, tt.expectedRemote, remoteAddr, tt.description+" - remote address")
})
}
}
func TestSSHCommand_IntegrationPortForwarding(t *testing.T) {
// Integration test for port forwarding with the actual SSH command implementation
tests := []struct {
name string
args []string
expectedHost string
expectedLocal []string
expectedRemote []string
expectedCmd string
description string
}{
{
name: "local forward with command",
args: []string{"-L", "8080:localhost:80", "hostname", "echo", "test"},
expectedHost: "hostname",
expectedLocal: []string{"8080:localhost:80"},
expectedRemote: []string{},
expectedCmd: "echo test",
description: "Local forwarding should work with commands",
},
{
name: "remote forward with command",
args: []string{"-R", "8080:localhost:80", "hostname", "ls", "-la"},
expectedHost: "hostname",
expectedLocal: []string{},
expectedRemote: []string{"8080:localhost:80"},
expectedCmd: "ls -la",
description: "Remote forwarding should work with commands",
},
{
name: "multiple forwards with user and command",
args: []string{"-L", "8080:localhost:80", "-R", "9090:localhost:443", "user@hostname", "ps", "aux"},
expectedHost: "hostname",
expectedLocal: []string{"8080:localhost:80"},
expectedRemote: []string{"9090:localhost:443"},
expectedCmd: "ps aux",
description: "Complex case with multiple forwards, user, and command",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Reset global variables
host = ""
username = ""
port = 22
command = ""
localForwards = nil
remoteForwards = nil
cmd := sshCmd
err := validateSSHArgsWithoutFlagParsing(cmd, tt.args)
require.NoError(t, err, "SSH args validation should succeed for valid input")
assert.Equal(t, tt.expectedHost, host, "host mismatch")
// Handle nil vs empty slice comparison
if len(tt.expectedLocal) == 0 {
assert.True(t, len(localForwards) == 0, tt.description+" - local forwards should be empty")
} else {
assert.Equal(t, tt.expectedLocal, localForwards, tt.description+" - local forwards")
}
if len(tt.expectedRemote) == 0 {
assert.True(t, len(remoteForwards) == 0, tt.description+" - remote forwards should be empty")
} else {
assert.Equal(t, tt.expectedRemote, remoteForwards, tt.description+" - remote forwards")
}
assert.Equal(t, tt.expectedCmd, command, tt.description+" - command")
})
}
}
func TestSSHCommand_ParameterIsolation(t *testing.T) {
tests := []struct {
name string
args []string
expectedCmd string
}{
{
name: "cmd flag passed as command",
args: []string{"hostname", "--cmd", "echo test"},
expectedCmd: "--cmd echo test",
},
{
name: "uid flag passed as command",
args: []string{"hostname", "--uid", "1000"},
expectedCmd: "--uid 1000",
},
{
name: "shell flag passed as command",
args: []string{"hostname", "--shell", "/bin/bash"},
expectedCmd: "--shell /bin/bash",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
host = ""
username = ""
port = 22
command = ""
err := validateSSHArgsWithoutFlagParsing(sshCmd, tt.args)
require.NoError(t, err)
assert.Equal(t, "hostname", host)
assert.Equal(t, tt.expectedCmd, command)
})
}
}
func TestSSHCommand_InvalidFlagRejection(t *testing.T) {
// Test that invalid flags are properly rejected and not misinterpreted as hostnames
tests := []struct {
name string
args []string
description string
}{
{
name: "invalid long flag before hostname",
args: []string{"--invalid-flag", "hostname"},
description: "Invalid flag should return parse error, not treat flag as hostname",
},
{
name: "invalid short flag before hostname",
args: []string{"-x", "hostname"},
description: "Invalid short flag should return parse error",
},
{
name: "invalid flag with value before hostname",
args: []string{"--invalid-option=value", "hostname"},
description: "Invalid flag with value should return parse error",
},
{
name: "typo in known flag",
args: []string{"--por", "2222", "hostname"},
description: "Typo in flag name should return parse error (not silently ignored)",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Reset global variables
host = ""
username = ""
port = 22
command = ""
err := validateSSHArgsWithoutFlagParsing(sshCmd, tt.args)
// Should return an error for invalid flags
assert.Error(t, err, tt.description)
// Should not have set host to the invalid flag
assert.NotEqual(t, tt.args[0], host, "Invalid flag should not be interpreted as hostname")
})
}
}

View File

@@ -109,7 +109,7 @@ func statusFunc(cmd *cobra.Command, args []string) error {
case yamlFlag:
statusOutputString, err = nbstatus.ParseToYAML(outputInformationHolder)
default:
statusOutputString = nbstatus.ParseGeneralSummary(outputInformationHolder, false, false, false)
statusOutputString = nbstatus.ParseGeneralSummary(outputInformationHolder, false, false, false, false)
}
if err != nil {

View File

@@ -12,6 +12,7 @@ import (
"google.golang.org/grpc"
"github.com/netbirdio/management-integrations/integrations"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
@@ -117,7 +118,7 @@ func startManagement(t *testing.T, config *config.Config, testFile string) (*grp
requestBuffer := mgmt.NewAccountRequestBuffer(ctx, store)
networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, mgmt.MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock())
accountManager, err := mgmt.BuildManager(context.Background(), store, networkMapController, nil, "", eventStore, nil, false, iv, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false)
accountManager, err := mgmt.BuildManager(context.Background(), config, store, networkMapController, nil, "", eventStore, nil, false, iv, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false)
if err != nil {
t.Fatal(err)
}

View File

@@ -355,6 +355,25 @@ func setupSetConfigReq(customDNSAddressConverted []byte, cmd *cobra.Command, pro
if cmd.Flag(serverSSHAllowedFlag).Changed {
req.ServerSSHAllowed = &serverSSHAllowed
}
if cmd.Flag(enableSSHRootFlag).Changed {
req.EnableSSHRoot = &enableSSHRoot
}
if cmd.Flag(enableSSHSFTPFlag).Changed {
req.EnableSSHSFTP = &enableSSHSFTP
}
if cmd.Flag(enableSSHLocalPortForwardFlag).Changed {
req.EnableSSHLocalPortForwarding = &enableSSHLocalPortForward
}
if cmd.Flag(enableSSHRemotePortForwardFlag).Changed {
req.EnableSSHRemotePortForwarding = &enableSSHRemotePortForward
}
if cmd.Flag(disableSSHAuthFlag).Changed {
req.DisableSSHAuth = &disableSSHAuth
}
if cmd.Flag(sshJWTCacheTTLFlag).Changed {
sshJWTCacheTTL32 := int32(sshJWTCacheTTL)
req.SshJWTCacheTTL = &sshJWTCacheTTL32
}
if cmd.Flag(interfaceNameFlag).Changed {
if err := parseInterfaceName(interfaceName); err != nil {
log.Errorf("parse interface name: %v", err)
@@ -439,6 +458,30 @@ func setupConfig(customDNSAddressConverted []byte, cmd *cobra.Command, configFil
ic.ServerSSHAllowed = &serverSSHAllowed
}
if cmd.Flag(enableSSHRootFlag).Changed {
ic.EnableSSHRoot = &enableSSHRoot
}
if cmd.Flag(enableSSHSFTPFlag).Changed {
ic.EnableSSHSFTP = &enableSSHSFTP
}
if cmd.Flag(enableSSHLocalPortForwardFlag).Changed {
ic.EnableSSHLocalPortForwarding = &enableSSHLocalPortForward
}
if cmd.Flag(enableSSHRemotePortForwardFlag).Changed {
ic.EnableSSHRemotePortForwarding = &enableSSHRemotePortForward
}
if cmd.Flag(disableSSHAuthFlag).Changed {
ic.DisableSSHAuth = &disableSSHAuth
}
if cmd.Flag(sshJWTCacheTTLFlag).Changed {
ic.SSHJWTCacheTTL = &sshJWTCacheTTL
}
if cmd.Flag(interfaceNameFlag).Changed {
if err := parseInterfaceName(interfaceName); err != nil {
return nil, err
@@ -539,6 +582,31 @@ func setupLoginRequest(providedSetupKey string, customDNSAddressConverted []byte
loginRequest.ServerSSHAllowed = &serverSSHAllowed
}
if cmd.Flag(enableSSHRootFlag).Changed {
loginRequest.EnableSSHRoot = &enableSSHRoot
}
if cmd.Flag(enableSSHSFTPFlag).Changed {
loginRequest.EnableSSHSFTP = &enableSSHSFTP
}
if cmd.Flag(enableSSHLocalPortForwardFlag).Changed {
loginRequest.EnableSSHLocalPortForwarding = &enableSSHLocalPortForward
}
if cmd.Flag(enableSSHRemotePortForwardFlag).Changed {
loginRequest.EnableSSHRemotePortForwarding = &enableSSHRemotePortForward
}
if cmd.Flag(disableSSHAuthFlag).Changed {
loginRequest.DisableSSHAuth = &disableSSHAuth
}
if cmd.Flag(sshJWTCacheTTLFlag).Changed {
sshJWTCacheTTL32 := int32(sshJWTCacheTTL)
loginRequest.SshJWTCacheTTL = &sshJWTCacheTTL32
}
if cmd.Flag(disableAutoConnectFlag).Changed {
loginRequest.DisableAutoConnect = &autoConnectDisabled
}

View File

@@ -18,12 +18,16 @@ import (
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/profilemanager"
sshcommon "github.com/netbirdio/netbird/client/ssh"
"github.com/netbirdio/netbird/client/system"
)
var ErrClientAlreadyStarted = errors.New("client already started")
var ErrClientNotStarted = errors.New("client not started")
var ErrConfigNotInitialized = errors.New("config not initialized")
var (
ErrClientAlreadyStarted = errors.New("client already started")
ErrClientNotStarted = errors.New("client not started")
ErrEngineNotStarted = errors.New("engine not started")
ErrConfigNotInitialized = errors.New("config not initialized")
)
// Client manages a netbird embedded client instance.
type Client struct {
@@ -238,17 +242,9 @@ func (c *Client) GetConfig() (profilemanager.Config, error) {
// Dial dials a network address in the netbird network.
// Not applicable if the userspace networking mode is disabled.
func (c *Client) Dial(ctx context.Context, network, address string) (net.Conn, error) {
c.mu.Lock()
connect := c.connect
if connect == nil {
c.mu.Unlock()
return nil, ErrClientNotStarted
}
c.mu.Unlock()
engine := connect.Engine()
if engine == nil {
return nil, errors.New("engine not started")
engine, err := c.getEngine()
if err != nil {
return nil, err
}
nsnet, err := engine.GetNet()
@@ -259,6 +255,11 @@ func (c *Client) Dial(ctx context.Context, network, address string) (net.Conn, e
return nsnet.DialContext(ctx, network, address)
}
// DialContext dials a network address in the netbird network with context
func (c *Client) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
return c.Dial(ctx, network, address)
}
// ListenTCP listens on the given address in the netbird network.
// Not applicable if the userspace networking mode is disabled.
func (c *Client) ListenTCP(address string) (net.Listener, error) {
@@ -314,18 +315,47 @@ func (c *Client) NewHTTPClient() *http.Client {
}
}
func (c *Client) getNet() (*wgnetstack.Net, netip.Addr, error) {
// VerifySSHHostKey verifies an SSH host key against stored peer keys.
// Returns nil if the key matches, ErrPeerNotFound if peer is not in network,
// ErrNoStoredKey if peer has no stored key, or an error for verification failures.
func (c *Client) VerifySSHHostKey(peerAddress string, key []byte) error {
engine, err := c.getEngine()
if err != nil {
return err
}
storedKey, found := engine.GetPeerSSHKey(peerAddress)
if !found {
return sshcommon.ErrPeerNotFound
}
return sshcommon.VerifyHostKey(storedKey, key, peerAddress)
}
// getEngine safely retrieves the engine from the client with proper locking.
// Returns ErrClientNotStarted if the client is not started.
// Returns ErrEngineNotStarted if the engine is not available.
func (c *Client) getEngine() (*internal.Engine, error) {
c.mu.Lock()
connect := c.connect
if connect == nil {
c.mu.Unlock()
return nil, netip.Addr{}, errors.New("client not started")
}
c.mu.Unlock()
if connect == nil {
return nil, ErrClientNotStarted
}
engine := connect.Engine()
if engine == nil {
return nil, netip.Addr{}, errors.New("engine not started")
return nil, ErrEngineNotStarted
}
return engine, nil
}
func (c *Client) getNet() (*wgnetstack.Net, netip.Addr, error) {
engine, err := c.getEngine()
if err != nil {
return nil, netip.Addr{}, err
}
addr, err := engine.Address()

View File

@@ -35,6 +35,12 @@ const (
ipTCPHeaderMinSize = 40
)
// serviceKey represents a protocol/port combination for netstack service registry
type serviceKey struct {
protocol gopacket.LayerType
port uint16
}
const (
// EnvDisableConntrack disables the stateful filter, replies to outbound traffic won't be allowed.
EnvDisableConntrack = "NB_DISABLE_CONNTRACK"
@@ -59,12 +65,6 @@ const (
var errNatNotSupported = errors.New("nat not supported with userspace firewall")
// serviceKey represents a protocol/port combination for netstack service registry
type serviceKey struct {
protocol gopacket.LayerType
port uint16
}
// RuleSet is a set of rules grouped by a string key
type RuleSet map[string]PeerRule

View File

@@ -22,6 +22,7 @@ import (
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/internal/netflow"
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
"github.com/netbirdio/netbird/shared/management/domain"
)
@@ -1114,3 +1115,138 @@ func generateTCPPacketWithFlags(tb testing.TB, srcIP, dstIP net.IP, srcPort, dst
return buf.Bytes()
}
func TestShouldForward(t *testing.T) {
// Set up test addresses
wgIP := netip.MustParseAddr("100.10.0.1")
otherIP := netip.MustParseAddr("100.10.0.2")
// Create test manager with mock interface
ifaceMock := &IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}
// Set the mock to return our test WG IP
ifaceMock.AddressFunc = func() wgaddr.Address {
return wgaddr.Address{IP: wgIP, Network: netip.PrefixFrom(wgIP, 24)}
}
manager, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU)
require.NoError(t, err)
defer func() {
require.NoError(t, manager.Close(nil))
}()
// Helper to create decoder with TCP packet
createTCPDecoder := func(dstPort uint16) *decoder {
ipv4 := &layers.IPv4{
Version: 4,
Protocol: layers.IPProtocolTCP,
SrcIP: net.ParseIP("192.168.1.100"),
DstIP: wgIP.AsSlice(),
}
tcp := &layers.TCP{
SrcPort: 54321,
DstPort: layers.TCPPort(dstPort),
}
err := tcp.SetNetworkLayerForChecksum(ipv4)
require.NoError(t, err)
buf := gopacket.NewSerializeBuffer()
opts := gopacket.SerializeOptions{ComputeChecksums: true, FixLengths: true}
err = gopacket.SerializeLayers(buf, opts, ipv4, tcp, gopacket.Payload("test"))
require.NoError(t, err)
d := &decoder{
decoded: []gopacket.LayerType{},
}
d.parser = gopacket.NewDecodingLayerParser(
layers.LayerTypeIPv4,
&d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp,
)
d.parser.IgnoreUnsupported = true
err = d.parser.DecodeLayers(buf.Bytes(), &d.decoded)
require.NoError(t, err)
return d
}
tests := []struct {
name string
localForwarding bool
netstack bool
dstIP netip.Addr
serviceRegistered bool
servicePort uint16
expected bool
description string
}{
{
name: "no local forwarding",
localForwarding: false,
netstack: true,
dstIP: wgIP,
expected: false,
description: "should never forward when local forwarding disabled",
},
{
name: "traffic to other local interface",
localForwarding: true,
netstack: false,
dstIP: otherIP,
expected: true,
description: "should forward traffic to our other local interfaces (not NetBird IP)",
},
{
name: "traffic to NetBird IP, no netstack",
localForwarding: true,
netstack: false,
dstIP: wgIP,
expected: false,
description: "should send to netstack listeners (final return false path)",
},
{
name: "traffic to our IP, netstack mode, no service",
localForwarding: true,
netstack: true,
dstIP: wgIP,
expected: true,
description: "should forward when in netstack mode with no matching service",
},
{
name: "traffic to our IP, netstack mode, with service",
localForwarding: true,
netstack: true,
dstIP: wgIP,
serviceRegistered: true,
servicePort: 22,
expected: false,
description: "should send to netstack listeners when service is registered",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Configure manager
manager.localForwarding = tt.localForwarding
manager.netstack = tt.netstack
// Register service if needed
if tt.serviceRegistered {
manager.RegisterNetstackService(nftypes.TCP, tt.servicePort)
defer manager.UnregisterNetstackService(nftypes.TCP, tt.servicePort)
}
// Create decoder for the test
decoder := createTCPDecoder(tt.servicePort)
if !tt.serviceRegistered {
decoder = createTCPDecoder(8080) // Use non-registered port
}
// Test the method
result := manager.shouldForward(decoder, tt.dstIP)
require.Equal(t, tt.expected, result, tt.description)
})
}
}

View File

@@ -0,0 +1,85 @@
package uspfilter
import (
"net/netip"
"testing"
"github.com/google/gopacket/layers"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/device"
)
// TestPortDNATBasic tests basic port DNAT functionality
func TestPortDNATBasic(t *testing.T) {
manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger, iface.DefaultMTU)
require.NoError(t, err)
defer func() {
require.NoError(t, manager.Close(nil))
}()
// Define peer IPs
peerA := netip.MustParseAddr("100.10.0.50")
peerB := netip.MustParseAddr("100.10.0.51")
// Add SSH port redirection rule for peer B (the target)
err = manager.addPortRedirection(peerB, layers.LayerTypeTCP, 22, 22022)
require.NoError(t, err)
// Scenario: Peer A connects to Peer B on port 22 (should get NAT)
packetAtoB := generateDNATTestPacket(t, peerA, peerB, layers.IPProtocolTCP, 54321, 22)
d := parsePacket(t, packetAtoB)
translatedAtoB := manager.translateInboundPortDNAT(packetAtoB, d, peerA, peerB)
require.True(t, translatedAtoB, "Peer A to Peer B should be translated (NAT applied)")
// Verify port was translated to 22022
d = parsePacket(t, packetAtoB)
require.Equal(t, uint16(22022), uint16(d.tcp.DstPort), "Port should be rewritten to 22022")
// Scenario: Return traffic from Peer B to Peer A should NOT be translated
// (prevents double NAT - original port stored in conntrack)
returnPacket := generateDNATTestPacket(t, peerB, peerA, layers.IPProtocolTCP, 22022, 54321)
d2 := parsePacket(t, returnPacket)
translatedReturn := manager.translateInboundPortDNAT(returnPacket, d2, peerB, peerA)
require.False(t, translatedReturn, "Return traffic from same IP should not be translated")
}
// TestPortDNATMultipleRules tests multiple port DNAT rules
func TestPortDNATMultipleRules(t *testing.T) {
manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger, iface.DefaultMTU)
require.NoError(t, err)
defer func() {
require.NoError(t, manager.Close(nil))
}()
// Define peer IPs
peerA := netip.MustParseAddr("100.10.0.50")
peerB := netip.MustParseAddr("100.10.0.51")
// Add SSH port redirection rules for both peers
err = manager.addPortRedirection(peerA, layers.LayerTypeTCP, 22, 22022)
require.NoError(t, err)
err = manager.addPortRedirection(peerB, layers.LayerTypeTCP, 22, 22022)
require.NoError(t, err)
// Test traffic to peer B gets translated
packetToB := generateDNATTestPacket(t, peerA, peerB, layers.IPProtocolTCP, 54321, 22)
d1 := parsePacket(t, packetToB)
translatedToB := manager.translateInboundPortDNAT(packetToB, d1, peerA, peerB)
require.True(t, translatedToB, "Traffic to peer B should be translated")
d1 = parsePacket(t, packetToB)
require.Equal(t, uint16(22022), uint16(d1.tcp.DstPort), "Port should be 22022")
// Test traffic to peer A gets translated
packetToA := generateDNATTestPacket(t, peerB, peerA, layers.IPProtocolTCP, 54322, 22)
d2 := parsePacket(t, packetToA)
translatedToA := manager.translateInboundPortDNAT(packetToA, d2, peerB, peerA)
require.True(t, translatedToA, "Traffic to peer A should be translated")
d2 = parsePacket(t, packetToA)
require.Equal(t, uint16(22022), uint16(d2.tcp.DstPort), "Port should be 22022")
}

View File

@@ -17,7 +17,6 @@ import (
nberrors "github.com/netbirdio/netbird/client/errors"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/internal/acl/id"
"github.com/netbirdio/netbird/client/ssh"
"github.com/netbirdio/netbird/shared/management/domain"
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
)
@@ -83,22 +82,6 @@ func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap, dnsRout
func (d *DefaultManager) applyPeerACLs(networkMap *mgmProto.NetworkMap) {
rules := networkMap.FirewallRules
enableSSH := networkMap.PeerConfig != nil &&
networkMap.PeerConfig.SshConfig != nil &&
networkMap.PeerConfig.SshConfig.SshEnabled
// If SSH enabled, add default firewall rule which accepts connection to any peer
// in the network by SSH (TCP port defined by ssh.DefaultSSHPort).
if enableSSH {
rules = append(rules, &mgmProto.FirewallRule{
PeerIP: "0.0.0.0",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
Port: strconv.Itoa(ssh.DefaultSSHPort),
})
}
// if we got empty rules list but management not set networkMap.FirewallRulesIsEmpty flag
// we have old version of management without rules handling, we should allow all traffic
if len(networkMap.FirewallRules) == 0 && !networkMap.FirewallRulesIsEmpty {

View File

@@ -272,70 +272,3 @@ func TestPortInfoEmpty(t *testing.T) {
})
}
}
func TestDefaultManagerEnableSSHRules(t *testing.T) {
networkMap := &mgmProto.NetworkMap{
PeerConfig: &mgmProto.PeerConfig{
SshConfig: &mgmProto.SSHConfig{
SshEnabled: true,
},
},
RemotePeers: []*mgmProto.RemotePeerConfig{
{AllowedIps: []string{"10.93.0.1"}},
{AllowedIps: []string{"10.93.0.2"}},
{AllowedIps: []string{"10.93.0.3"}},
},
FirewallRules: []*mgmProto.FirewallRule{
{
PeerIP: "10.93.0.1",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
},
{
PeerIP: "10.93.0.2",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
},
{
PeerIP: "10.93.0.3",
Direction: mgmProto.RuleDirection_OUT,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_UDP,
},
},
}
ctrl := gomock.NewController(t)
defer ctrl.Finish()
ifaceMock := mocks.NewMockIFaceMapper(ctrl)
ifaceMock.EXPECT().IsUserspaceBind().Return(true).AnyTimes()
ifaceMock.EXPECT().SetFilter(gomock.Any())
network := netip.MustParsePrefix("172.0.0.1/32")
ifaceMock.EXPECT().Name().Return("lo").AnyTimes()
ifaceMock.EXPECT().Address().Return(wgaddr.Address{
IP: network.Addr(),
Network: network,
}).AnyTimes()
ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes()
fw, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false, iface.DefaultMTU)
require.NoError(t, err)
defer func() {
err = fw.Close(nil)
require.NoError(t, err)
}()
acl := NewDefaultManager(fw)
acl.ApplyFiltering(networkMap, false)
expectedRules := 3
if fw.IsStateful() {
expectedRules = 3 // 2 inbound rules + SSH rule
}
assert.Equal(t, expectedRules, len(acl.peerRulesPairs))
}

View File

@@ -416,20 +416,25 @@ func createEngineConfig(key wgtypes.Key, config *profilemanager.Config, peerConf
nm = *config.NetworkMonitor
}
engineConf := &EngineConfig{
WgIfaceName: config.WgIface,
WgAddr: peerConfig.Address,
IFaceBlackList: config.IFaceBlackList,
DisableIPv6Discovery: config.DisableIPv6Discovery,
WgPrivateKey: key,
WgPort: config.WgPort,
NetworkMonitor: nm,
SSHKey: []byte(config.SSHKey),
NATExternalIPs: config.NATExternalIPs,
CustomDNSAddress: config.CustomDNSAddress,
RosenpassEnabled: config.RosenpassEnabled,
RosenpassPermissive: config.RosenpassPermissive,
ServerSSHAllowed: util.ReturnBoolWithDefaultTrue(config.ServerSSHAllowed),
DNSRouteInterval: config.DNSRouteInterval,
WgIfaceName: config.WgIface,
WgAddr: peerConfig.Address,
IFaceBlackList: config.IFaceBlackList,
DisableIPv6Discovery: config.DisableIPv6Discovery,
WgPrivateKey: key,
WgPort: config.WgPort,
NetworkMonitor: nm,
SSHKey: []byte(config.SSHKey),
NATExternalIPs: config.NATExternalIPs,
CustomDNSAddress: config.CustomDNSAddress,
RosenpassEnabled: config.RosenpassEnabled,
RosenpassPermissive: config.RosenpassPermissive,
ServerSSHAllowed: util.ReturnBoolWithDefaultTrue(config.ServerSSHAllowed),
EnableSSHRoot: config.EnableSSHRoot,
EnableSSHSFTP: config.EnableSSHSFTP,
EnableSSHLocalPortForwarding: config.EnableSSHLocalPortForwarding,
EnableSSHRemotePortForwarding: config.EnableSSHRemotePortForwarding,
DisableSSHAuth: config.DisableSSHAuth,
DNSRouteInterval: config.DNSRouteInterval,
DisableClientRoutes: config.DisableClientRoutes,
DisableServerRoutes: config.DisableServerRoutes || config.BlockInbound,
@@ -515,6 +520,11 @@ func loginToManagement(ctx context.Context, client mgm.Client, pubSSHKey []byte,
config.BlockLANAccess,
config.BlockInbound,
config.LazyConnectionEnabled,
config.EnableSSHRoot,
config.EnableSSHSFTP,
config.EnableSSHLocalPortForwarding,
config.EnableSSHRemotePortForwarding,
config.DisableSSHAuth,
)
loginResp, err := client.Login(*serverPublicKey, sysInfo, pubSSHKey, config.DNSLabels)
if err != nil {

View File

@@ -453,6 +453,18 @@ func (g *BundleGenerator) addCommonConfigFields(configContent *strings.Builder)
if g.internalConfig.ServerSSHAllowed != nil {
configContent.WriteString(fmt.Sprintf("ServerSSHAllowed: %v\n", *g.internalConfig.ServerSSHAllowed))
}
if g.internalConfig.EnableSSHRoot != nil {
configContent.WriteString(fmt.Sprintf("EnableSSHRoot: %v\n", *g.internalConfig.EnableSSHRoot))
}
if g.internalConfig.EnableSSHSFTP != nil {
configContent.WriteString(fmt.Sprintf("EnableSSHSFTP: %v\n", *g.internalConfig.EnableSSHSFTP))
}
if g.internalConfig.EnableSSHLocalPortForwarding != nil {
configContent.WriteString(fmt.Sprintf("EnableSSHLocalPortForwarding: %v\n", *g.internalConfig.EnableSSHLocalPortForwarding))
}
if g.internalConfig.EnableSSHRemotePortForwarding != nil {
configContent.WriteString(fmt.Sprintf("EnableSSHRemotePortForwarding: %v\n", *g.internalConfig.EnableSSHRemotePortForwarding))
}
configContent.WriteString(fmt.Sprintf("DisableClientRoutes: %v\n", g.internalConfig.DisableClientRoutes))
configContent.WriteString(fmt.Sprintf("DisableServerRoutes: %v\n", g.internalConfig.DisableServerRoutes))

View File

@@ -9,7 +9,6 @@ import (
"net/netip"
"net/url"
"os"
"reflect"
"runtime"
"slices"
"sort"
@@ -30,7 +29,6 @@ import (
firewallManager "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/device"
nbnetstack "github.com/netbirdio/netbird/client/iface/netstack"
"github.com/netbirdio/netbird/client/iface/udpmux"
"github.com/netbirdio/netbird/client/internal/acl"
"github.com/netbirdio/netbird/client/internal/dns"
@@ -51,10 +49,10 @@ import (
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
"github.com/netbirdio/netbird/client/internal/statemanager"
cProto "github.com/netbirdio/netbird/client/proto"
sshconfig "github.com/netbirdio/netbird/client/ssh/config"
"github.com/netbirdio/netbird/shared/management/domain"
semaphoregroup "github.com/netbirdio/netbird/util/semaphore-group"
nbssh "github.com/netbirdio/netbird/client/ssh"
"github.com/netbirdio/netbird/client/system"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/route"
@@ -115,7 +113,12 @@ type EngineConfig struct {
RosenpassEnabled bool
RosenpassPermissive bool
ServerSSHAllowed bool
ServerSSHAllowed bool
EnableSSHRoot *bool
EnableSSHSFTP *bool
EnableSSHLocalPortForwarding *bool
EnableSSHRemotePortForwarding *bool
DisableSSHAuth *bool
DNSRouteInterval time.Duration
@@ -148,8 +151,6 @@ type Engine struct {
// syncMsgMux is used to guarantee sequential Management Service message processing
syncMsgMux *sync.Mutex
// sshMux protects sshServer field access
sshMux sync.Mutex
config *EngineConfig
mobileDep MobileDependency
@@ -175,8 +176,7 @@ type Engine struct {
networkMonitor *networkmonitor.NetworkMonitor
sshServerFunc func(hostKeyPEM []byte, addr string) (nbssh.Server, error)
sshServer nbssh.Server
sshServer sshServer
statusRecorder *peer.Status
@@ -246,7 +246,6 @@ func NewEngine(
STUNs: []*stun.URI{},
TURNs: []*stun.URI{},
networkSerial: 0,
sshServerFunc: nbssh.DefaultSSHServer,
statusRecorder: statusRecorder,
checks: checks,
connSemaphore: semaphoregroup.NewSemaphoreGroup(connInitLimit),
@@ -268,6 +267,7 @@ func NewEngine(
path = mobileDep.StateFilePath
}
engine.stateManager = statemanager.New(path)
engine.stateManager.RegisterState(&sshconfig.ShutdownState{})
log.Infof("I am: %s", config.WgPrivateKey.PublicKey().String())
return engine
@@ -292,6 +292,12 @@ func (e *Engine) Stop() error {
}
log.Info("Network monitor: stopped")
if err := e.stopSSHServer(); err != nil {
log.Warnf("failed to stop SSH server: %v", err)
}
e.cleanupSSHConfig()
// stop/restore DNS first so dbus and friends don't complain because of a missing interface
e.stopDNSServer()
@@ -703,16 +709,10 @@ func (e *Engine) removeAllPeers() error {
return nil
}
// removePeer closes an existing peer connection, removes a peer, and clears authorized key of the SSH server
// removePeer closes an existing peer connection and removes a peer
func (e *Engine) removePeer(peerKey string) error {
log.Debugf("removing peer from engine %s", peerKey)
e.sshMux.Lock()
if !isNil(e.sshServer) {
e.sshServer.RemoveAuthorizedKey(peerKey)
}
e.sshMux.Unlock()
e.connMgr.RemovePeerConn(peerKey)
err := e.statusRecorder.RemovePeer(peerKey)
@@ -884,6 +884,11 @@ func (e *Engine) updateChecksIfNew(checks []*mgmProto.Checks) error {
e.config.BlockLANAccess,
e.config.BlockInbound,
e.config.LazyConnectionEnabled,
e.config.EnableSSHRoot,
e.config.EnableSSHSFTP,
e.config.EnableSSHLocalPortForwarding,
e.config.EnableSSHRemotePortForwarding,
e.config.DisableSSHAuth,
)
if err := e.mgmClient.SyncMeta(info); err != nil {
@@ -893,74 +898,6 @@ func (e *Engine) updateChecksIfNew(checks []*mgmProto.Checks) error {
return nil
}
func isNil(server nbssh.Server) bool {
return server == nil || reflect.ValueOf(server).IsNil()
}
func (e *Engine) updateSSH(sshConf *mgmProto.SSHConfig) error {
if e.config.BlockInbound {
log.Infof("SSH server is disabled because inbound connections are blocked")
return nil
}
if !e.config.ServerSSHAllowed {
log.Info("SSH server is not enabled")
return nil
}
if sshConf.GetSshEnabled() {
if runtime.GOOS == "windows" {
log.Warnf("running SSH server on %s is not supported", runtime.GOOS)
return nil
}
e.sshMux.Lock()
// start SSH server if it wasn't running
if isNil(e.sshServer) {
listenAddr := fmt.Sprintf("%s:%d", e.wgInterface.Address().IP.String(), nbssh.DefaultSSHPort)
if nbnetstack.IsEnabled() {
listenAddr = fmt.Sprintf("127.0.0.1:%d", nbssh.DefaultSSHPort)
}
// nil sshServer means it has not yet been started
server, err := e.sshServerFunc(e.config.SSHKey, listenAddr)
if err != nil {
e.sshMux.Unlock()
return fmt.Errorf("create ssh server: %w", err)
}
e.sshServer = server
e.sshMux.Unlock()
go func() {
// blocking
err = server.Start()
if err != nil {
// will throw error when we stop it even if it is a graceful stop
log.Debugf("stopped SSH server with error %v", err)
}
e.sshMux.Lock()
e.sshServer = nil
e.sshMux.Unlock()
log.Infof("stopped SSH server")
}()
} else {
e.sshMux.Unlock()
log.Debugf("SSH server is already running")
}
} else {
e.sshMux.Lock()
if !isNil(e.sshServer) {
// Disable SSH server request, so stop it if it was running
err := e.sshServer.Stop()
if err != nil {
log.Warnf("failed to stop SSH server %v", err)
}
e.sshServer = nil
}
e.sshMux.Unlock()
}
return nil
}
func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error {
if e.wgInterface == nil {
return errors.New("wireguard interface is not initialized")
@@ -973,8 +910,7 @@ func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error {
}
if conf.GetSshConfig() != nil {
err := e.updateSSH(conf.GetSshConfig())
if err != nil {
if err := e.updateSSH(conf.GetSshConfig()); err != nil {
log.Warnf("failed handling SSH server setup: %v", err)
}
}
@@ -1012,6 +948,11 @@ func (e *Engine) receiveManagementEvents() {
e.config.BlockLANAccess,
e.config.BlockInbound,
e.config.LazyConnectionEnabled,
e.config.EnableSSHRoot,
e.config.EnableSSHSFTP,
e.config.EnableSSHLocalPortForwarding,
e.config.EnableSSHRemotePortForwarding,
e.config.DisableSSHAuth,
)
err = e.mgmClient.Sync(e.ctx, info, e.handleSync)
@@ -1170,19 +1111,11 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
e.statusRecorder.FinishPeerListModifications()
// update SSHServer by adding remote peer SSH keys
e.sshMux.Lock()
if !isNil(e.sshServer) {
for _, config := range networkMap.GetRemotePeers() {
if config.GetSshConfig() != nil && config.GetSshConfig().GetSshPubKey() != nil {
err := e.sshServer.AddAuthorizedKey(config.WgPubKey, string(config.GetSshConfig().GetSshPubKey()))
if err != nil {
log.Warnf("failed adding authorized key to SSH DefaultServer %v", err)
}
}
}
e.updatePeerSSHHostKeys(networkMap.GetRemotePeers())
if err := e.updateSSHClientConfig(networkMap.GetRemotePeers()); err != nil {
log.Warnf("failed to update SSH client config: %v", err)
}
e.sshMux.Unlock()
}
// must set the exclude list after the peers are added. Without it the manager can not figure out the peers parameters from the store
@@ -1544,15 +1477,6 @@ func (e *Engine) close() {
e.statusRecorder.SetWgIface(nil)
}
e.sshMux.Lock()
if !isNil(e.sshServer) {
err := e.sshServer.Stop()
if err != nil {
log.Warnf("failed stopping the SSH server: %v", err)
}
}
e.sshMux.Unlock()
if e.firewall != nil {
err := e.firewall.Close(e.stateManager)
if err != nil {
@@ -1583,6 +1507,11 @@ func (e *Engine) readInitialSettings() ([]*route.Route, *nbdns.Config, bool, err
e.config.BlockLANAccess,
e.config.BlockInbound,
e.config.LazyConnectionEnabled,
e.config.EnableSSHRoot,
e.config.EnableSSHSFTP,
e.config.EnableSSHLocalPortForwarding,
e.config.EnableSSHRemotePortForwarding,
e.config.DisableSSHAuth,
)
netMap, err := e.mgmClient.GetNetworkMap(info)

View File

@@ -0,0 +1,355 @@
package internal
import (
"context"
"errors"
"fmt"
"net/netip"
"strings"
log "github.com/sirupsen/logrus"
firewallManager "github.com/netbirdio/netbird/client/firewall/manager"
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
sshconfig "github.com/netbirdio/netbird/client/ssh/config"
sshserver "github.com/netbirdio/netbird/client/ssh/server"
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
)
type sshServer interface {
Start(ctx context.Context, addr netip.AddrPort) error
Stop() error
GetStatus() (bool, []sshserver.SessionInfo)
}
func (e *Engine) setupSSHPortRedirection() error {
if e.firewall == nil || e.wgInterface == nil {
return nil
}
localAddr := e.wgInterface.Address().IP
if !localAddr.IsValid() {
return errors.New("invalid local NetBird address")
}
if err := e.firewall.AddInboundDNAT(localAddr, firewallManager.ProtocolTCP, 22, 22022); err != nil {
return fmt.Errorf("add SSH port redirection: %w", err)
}
log.Infof("SSH port redirection enabled: %s:22 -> %s:22022", localAddr, localAddr)
return nil
}
func (e *Engine) updateSSH(sshConf *mgmProto.SSHConfig) error {
if e.config.BlockInbound {
log.Info("SSH server is disabled because inbound connections are blocked")
return e.stopSSHServer()
}
if !e.config.ServerSSHAllowed {
log.Info("SSH server is disabled in config")
return e.stopSSHServer()
}
if !sshConf.GetSshEnabled() {
if e.config.ServerSSHAllowed {
log.Info("SSH server is locally allowed but disabled by management server")
}
return e.stopSSHServer()
}
if e.sshServer != nil {
log.Debug("SSH server is already running")
return nil
}
if e.config.DisableSSHAuth != nil && *e.config.DisableSSHAuth {
log.Info("starting SSH server without JWT authentication (authentication disabled by config)")
return e.startSSHServer(nil)
}
if protoJWT := sshConf.GetJwtConfig(); protoJWT != nil {
jwtConfig := &sshserver.JWTConfig{
Issuer: protoJWT.GetIssuer(),
Audience: protoJWT.GetAudience(),
KeysLocation: protoJWT.GetKeysLocation(),
MaxTokenAge: protoJWT.GetMaxTokenAge(),
}
return e.startSSHServer(jwtConfig)
}
return errors.New("SSH server requires valid JWT configuration")
}
// updateSSHClientConfig updates the SSH client configuration with peer information
func (e *Engine) updateSSHClientConfig(remotePeers []*mgmProto.RemotePeerConfig) error {
peerInfo := e.extractPeerSSHInfo(remotePeers)
if len(peerInfo) == 0 {
log.Debug("no SSH-enabled peers found, skipping SSH config update")
return nil
}
configMgr := sshconfig.New()
if err := configMgr.SetupSSHClientConfig(peerInfo); err != nil {
log.Warnf("failed to update SSH client config: %v", err)
return nil // Don't fail engine startup on SSH config issues
}
log.Debugf("updated SSH client config with %d peers", len(peerInfo))
if err := e.stateManager.UpdateState(&sshconfig.ShutdownState{
SSHConfigDir: configMgr.GetSSHConfigDir(),
SSHConfigFile: configMgr.GetSSHConfigFile(),
}); err != nil {
log.Warnf("failed to update SSH config state: %v", err)
}
return nil
}
// extractPeerSSHInfo extracts SSH information from peer configurations
func (e *Engine) extractPeerSSHInfo(remotePeers []*mgmProto.RemotePeerConfig) []sshconfig.PeerSSHInfo {
var peerInfo []sshconfig.PeerSSHInfo
for _, peerConfig := range remotePeers {
if peerConfig.GetSshConfig() == nil {
continue
}
sshPubKeyBytes := peerConfig.GetSshConfig().GetSshPubKey()
if len(sshPubKeyBytes) == 0 {
continue
}
peerIP := e.extractPeerIP(peerConfig)
hostname := e.extractHostname(peerConfig)
peerInfo = append(peerInfo, sshconfig.PeerSSHInfo{
Hostname: hostname,
IP: peerIP,
FQDN: peerConfig.GetFqdn(),
})
}
return peerInfo
}
// extractPeerIP extracts IP address from peer's allowed IPs
func (e *Engine) extractPeerIP(peerConfig *mgmProto.RemotePeerConfig) string {
if len(peerConfig.GetAllowedIps()) == 0 {
return ""
}
if prefix, err := netip.ParsePrefix(peerConfig.GetAllowedIps()[0]); err == nil {
return prefix.Addr().String()
}
return ""
}
// extractHostname extracts short hostname from FQDN
func (e *Engine) extractHostname(peerConfig *mgmProto.RemotePeerConfig) string {
fqdn := peerConfig.GetFqdn()
if fqdn == "" {
return ""
}
parts := strings.Split(fqdn, ".")
if len(parts) > 0 && parts[0] != "" {
return parts[0]
}
return ""
}
// updatePeerSSHHostKeys updates peer SSH host keys in the status recorder for daemon API access
func (e *Engine) updatePeerSSHHostKeys(remotePeers []*mgmProto.RemotePeerConfig) {
for _, peerConfig := range remotePeers {
if peerConfig.GetSshConfig() == nil {
continue
}
sshPubKeyBytes := peerConfig.GetSshConfig().GetSshPubKey()
if len(sshPubKeyBytes) == 0 {
continue
}
if err := e.statusRecorder.UpdatePeerSSHHostKey(peerConfig.GetWgPubKey(), sshPubKeyBytes); err != nil {
log.Warnf("failed to update SSH host key for peer %s: %v", peerConfig.GetWgPubKey(), err)
}
}
log.Debugf("updated peer SSH host keys for daemon API access")
}
// GetPeerSSHKey returns the SSH host key for a specific peer by IP or FQDN
func (e *Engine) GetPeerSSHKey(peerAddress string) ([]byte, bool) {
e.syncMsgMux.Lock()
statusRecorder := e.statusRecorder
e.syncMsgMux.Unlock()
if statusRecorder == nil {
return nil, false
}
fullStatus := statusRecorder.GetFullStatus()
for _, peerState := range fullStatus.Peers {
if peerState.IP == peerAddress || peerState.FQDN == peerAddress {
if len(peerState.SSHHostKey) > 0 {
return peerState.SSHHostKey, true
}
return nil, false
}
}
return nil, false
}
// cleanupSSHConfig removes NetBird SSH client configuration on shutdown
func (e *Engine) cleanupSSHConfig() {
configMgr := sshconfig.New()
if err := configMgr.RemoveSSHClientConfig(); err != nil {
log.Warnf("failed to remove SSH client config: %v", err)
} else {
log.Debugf("SSH client config cleanup completed")
}
}
// startSSHServer initializes and starts the SSH server with proper configuration.
func (e *Engine) startSSHServer(jwtConfig *sshserver.JWTConfig) error {
if e.wgInterface == nil {
return errors.New("wg interface not initialized")
}
serverConfig := &sshserver.Config{
HostKeyPEM: e.config.SSHKey,
JWT: jwtConfig,
}
server := sshserver.New(serverConfig)
wgAddr := e.wgInterface.Address()
server.SetNetworkValidation(wgAddr)
netbirdIP := wgAddr.IP
listenAddr := netip.AddrPortFrom(netbirdIP, sshserver.InternalSSHPort)
if netstackNet := e.wgInterface.GetNet(); netstackNet != nil {
server.SetNetstackNet(netstackNet)
}
e.configureSSHServer(server)
if err := server.Start(e.ctx, listenAddr); err != nil {
return fmt.Errorf("start SSH server: %w", err)
}
e.sshServer = server
if netstackNet := e.wgInterface.GetNet(); netstackNet != nil {
if registrar, ok := e.firewall.(interface {
RegisterNetstackService(protocol nftypes.Protocol, port uint16)
}); ok {
registrar.RegisterNetstackService(nftypes.TCP, sshserver.InternalSSHPort)
log.Debugf("registered SSH service with netstack for TCP:%d", sshserver.InternalSSHPort)
}
}
if err := e.setupSSHPortRedirection(); err != nil {
log.Warnf("failed to setup SSH port redirection: %v", err)
}
return nil
}
// configureSSHServer applies SSH configuration options to the server.
func (e *Engine) configureSSHServer(server *sshserver.Server) {
if e.config.EnableSSHRoot != nil && *e.config.EnableSSHRoot {
server.SetAllowRootLogin(true)
log.Info("SSH root login enabled")
} else {
server.SetAllowRootLogin(false)
log.Info("SSH root login disabled (default)")
}
if e.config.EnableSSHSFTP != nil && *e.config.EnableSSHSFTP {
server.SetAllowSFTP(true)
log.Info("SSH SFTP subsystem enabled")
} else {
server.SetAllowSFTP(false)
log.Info("SSH SFTP subsystem disabled (default)")
}
if e.config.EnableSSHLocalPortForwarding != nil && *e.config.EnableSSHLocalPortForwarding {
server.SetAllowLocalPortForwarding(true)
log.Info("SSH local port forwarding enabled")
} else {
server.SetAllowLocalPortForwarding(false)
log.Info("SSH local port forwarding disabled (default)")
}
if e.config.EnableSSHRemotePortForwarding != nil && *e.config.EnableSSHRemotePortForwarding {
server.SetAllowRemotePortForwarding(true)
log.Info("SSH remote port forwarding enabled")
} else {
server.SetAllowRemotePortForwarding(false)
log.Info("SSH remote port forwarding disabled (default)")
}
}
func (e *Engine) cleanupSSHPortRedirection() error {
if e.firewall == nil || e.wgInterface == nil {
return nil
}
localAddr := e.wgInterface.Address().IP
if !localAddr.IsValid() {
return errors.New("invalid local NetBird address")
}
if err := e.firewall.RemoveInboundDNAT(localAddr, firewallManager.ProtocolTCP, 22, 22022); err != nil {
return fmt.Errorf("remove SSH port redirection: %w", err)
}
log.Debugf("SSH port redirection removed: %s:22 -> %s:22022", localAddr, localAddr)
return nil
}
func (e *Engine) stopSSHServer() error {
if e.sshServer == nil {
return nil
}
if err := e.cleanupSSHPortRedirection(); err != nil {
log.Warnf("failed to cleanup SSH port redirection: %v", err)
}
if netstackNet := e.wgInterface.GetNet(); netstackNet != nil {
if registrar, ok := e.firewall.(interface {
UnregisterNetstackService(protocol nftypes.Protocol, port uint16)
}); ok {
registrar.UnregisterNetstackService(nftypes.TCP, sshserver.InternalSSHPort)
log.Debugf("unregistered SSH service from netstack for TCP:%d", sshserver.InternalSSHPort)
}
}
log.Info("stopping SSH server")
err := e.sshServer.Stop()
e.sshServer = nil
if err != nil {
return fmt.Errorf("stop: %w", err)
}
return nil
}
// GetSSHServerStatus returns the SSH server status and active sessions
func (e *Engine) GetSSHServerStatus() (enabled bool, sessions []sshserver.SessionInfo) {
e.syncMsgMux.Lock()
sshServer := e.sshServer
e.syncMsgMux.Unlock()
if sshServer == nil {
return false, nil
}
return sshServer.GetStatus()
}

View File

@@ -14,7 +14,6 @@ import (
"github.com/golang/mock/gomock"
"github.com/google/uuid"
"github.com/netbirdio/netbird/client/internal/stdnet"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@@ -25,7 +24,10 @@ import (
"google.golang.org/grpc"
"google.golang.org/grpc/keepalive"
"github.com/netbirdio/netbird/client/internal/stdnet"
"github.com/netbirdio/management-integrations/integrations"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
@@ -46,7 +48,7 @@ import (
icemaker "github.com/netbirdio/netbird/client/internal/peer/ice"
"github.com/netbirdio/netbird/client/internal/profilemanager"
"github.com/netbirdio/netbird/client/internal/routemanager"
"github.com/netbirdio/netbird/client/ssh"
nbssh "github.com/netbirdio/netbird/client/ssh"
"github.com/netbirdio/netbird/client/system"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/server"
@@ -214,11 +216,13 @@ func TestMain(m *testing.M) {
}
func TestEngine_SSH(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("skipping TestEngine_SSH")
key, err := wgtypes.GeneratePrivateKey()
if err != nil {
t.Fatal(err)
return
}
key, err := wgtypes.GeneratePrivateKey()
sshKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
if err != nil {
t.Fatal(err)
return
@@ -240,6 +244,7 @@ func TestEngine_SSH(t *testing.T) {
WgPort: 33100,
ServerSSHAllowed: true,
MTU: iface.DefaultMTU,
SSHKey: sshKey,
},
MobileDependency{},
peer.NewRecorder("https://mgm"),
@@ -250,35 +255,8 @@ func TestEngine_SSH(t *testing.T) {
UpdateDNSServerFunc: func(serial uint64, update nbdns.Config) error { return nil },
}
var sshKeysAdded []string
var sshPeersRemoved []string
sshCtx, cancel := context.WithCancel(context.Background())
engine.sshServerFunc = func(hostKeyPEM []byte, addr string) (ssh.Server, error) {
return &ssh.MockServer{
Ctx: sshCtx,
StopFunc: func() error {
cancel()
return nil
},
StartFunc: func() error {
<-ctx.Done()
return ctx.Err()
},
AddAuthorizedKeyFunc: func(peer, newKey string) error {
sshKeysAdded = append(sshKeysAdded, newKey)
return nil
},
RemoveAuthorizedKeyFunc: func(peer string) {
sshPeersRemoved = append(sshPeersRemoved, peer)
},
}, nil
}
err = engine.Start(nil, nil)
if err != nil {
t.Fatal(err)
}
require.NoError(t, err)
defer func() {
err := engine.Stop()
@@ -304,9 +282,7 @@ func TestEngine_SSH(t *testing.T) {
}
err = engine.updateNetworkMap(networkMap)
if err != nil {
t.Fatal(err)
}
require.NoError(t, err)
assert.Nil(t, engine.sshServer)
@@ -314,19 +290,24 @@ func TestEngine_SSH(t *testing.T) {
networkMap = &mgmtProto.NetworkMap{
Serial: 7,
PeerConfig: &mgmtProto.PeerConfig{Address: "100.64.0.1/24",
SshConfig: &mgmtProto.SSHConfig{SshEnabled: true}},
SshConfig: &mgmtProto.SSHConfig{
SshEnabled: true,
JwtConfig: &mgmtProto.JWTConfig{
Issuer: "test-issuer",
Audience: "test-audience",
KeysLocation: "test-keys",
MaxTokenAge: 3600,
},
}},
RemotePeers: []*mgmtProto.RemotePeerConfig{peerWithSSH},
RemotePeersIsEmpty: false,
}
err = engine.updateNetworkMap(networkMap)
if err != nil {
t.Fatal(err)
}
require.NoError(t, err)
time.Sleep(250 * time.Millisecond)
assert.NotNil(t, engine.sshServer)
assert.Contains(t, sshKeysAdded, "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIFATYCqaQw/9id1Qkq3n16JYhDhXraI6Pc1fgB8ynEfQ")
// now remove peer
networkMap = &mgmtProto.NetworkMap{
@@ -336,13 +317,10 @@ func TestEngine_SSH(t *testing.T) {
}
err = engine.updateNetworkMap(networkMap)
if err != nil {
t.Fatal(err)
}
require.NoError(t, err)
// time.Sleep(250 * time.Millisecond)
assert.NotNil(t, engine.sshServer)
assert.Contains(t, sshPeersRemoved, "MNHf3Ma6z6mdLbriAJbqhX7+nM/B71lgw2+91q3LfhU=")
// now disable SSH server
networkMap = &mgmtProto.NetworkMap{
@@ -354,12 +332,70 @@ func TestEngine_SSH(t *testing.T) {
}
err = engine.updateNetworkMap(networkMap)
if err != nil {
t.Fatal(err)
}
require.NoError(t, err)
assert.Nil(t, engine.sshServer)
}
func TestEngine_SSHUpdateLogic(t *testing.T) {
// Test that SSH server start/stop logic works based on config
engine := &Engine{
config: &EngineConfig{
ServerSSHAllowed: false, // Start with SSH disabled
},
syncMsgMux: &sync.Mutex{},
}
// Test SSH disabled config
sshConfig := &mgmtProto.SSHConfig{SshEnabled: false}
err := engine.updateSSH(sshConfig)
assert.NoError(t, err)
assert.Nil(t, engine.sshServer)
// Test inbound blocked
engine.config.BlockInbound = true
err = engine.updateSSH(&mgmtProto.SSHConfig{SshEnabled: true})
assert.NoError(t, err)
assert.Nil(t, engine.sshServer)
engine.config.BlockInbound = false
// Test with server SSH not allowed
err = engine.updateSSH(&mgmtProto.SSHConfig{SshEnabled: true})
assert.NoError(t, err)
assert.Nil(t, engine.sshServer)
}
func TestEngine_SSHServerConsistency(t *testing.T) {
t.Run("server set only on successful creation", func(t *testing.T) {
engine := &Engine{
config: &EngineConfig{
ServerSSHAllowed: true,
SSHKey: []byte("test-key"),
},
syncMsgMux: &sync.Mutex{},
}
engine.wgInterface = nil
err := engine.updateSSH(&mgmtProto.SSHConfig{SshEnabled: true})
assert.Error(t, err)
assert.Nil(t, engine.sshServer)
})
t.Run("cleanup handles nil gracefully", func(t *testing.T) {
engine := &Engine{
config: &EngineConfig{
ServerSSHAllowed: false,
},
syncMsgMux: &sync.Mutex{},
}
err := engine.stopSSHServer()
assert.NoError(t, err)
assert.Nil(t, engine.sshServer)
})
}
func TestEngine_UpdateNetworkMap(t *testing.T) {
@@ -1589,7 +1625,7 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri
updateManager := update_channel.NewPeersUpdateManager(metrics)
requestBuffer := server.NewAccountRequestBuffer(context.Background(), store)
networkMapController := controller.NewController(context.Background(), store, metrics, updateManager, requestBuffer, server.MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock())
accountManager, err := server.BuildManager(context.Background(), store, networkMapController, nil, "", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
accountManager, err := server.BuildManager(context.Background(), config, store, networkMapController, nil, "", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
if err != nil {
return nil, "", err
}

View File

@@ -124,6 +124,11 @@ func doMgmLogin(ctx context.Context, mgmClient *mgm.GrpcClient, pubSSHKey []byte
config.BlockLANAccess,
config.BlockInbound,
config.LazyConnectionEnabled,
config.EnableSSHRoot,
config.EnableSSHSFTP,
config.EnableSSHLocalPortForwarding,
config.EnableSSHRemotePortForwarding,
config.DisableSSHAuth,
)
loginResp, err := mgmClient.Login(*serverKey, sysInfo, pubSSHKey, config.DNSLabels)
return serverKey, loginResp, err
@@ -150,6 +155,11 @@ func registerPeer(ctx context.Context, serverPublicKey wgtypes.Key, client *mgm.
config.BlockLANAccess,
config.BlockInbound,
config.LazyConnectionEnabled,
config.EnableSSHRoot,
config.EnableSSHSFTP,
config.EnableSSHLocalPortForwarding,
config.EnableSSHRemotePortForwarding,
config.DisableSSHAuth,
)
loginResp, err := client.Register(serverPublicKey, validSetupKey.String(), jwtToken, info, pubSSHKey, config.DNSLabels)
if err != nil {

View File

@@ -666,7 +666,7 @@ func (conn *Conn) isConnectedOnAllWay() (connected bool) {
}
}()
if conn.statusICE.Get() == worker.StatusDisconnected && !conn.workerICE.InProgress() {
if runtime.GOOS != "js" && conn.statusICE.Get() == worker.StatusDisconnected && !conn.workerICE.InProgress() {
return false
}

View File

@@ -2,6 +2,7 @@ package peer
import (
"os"
"runtime"
"strings"
)
@@ -10,5 +11,8 @@ const (
)
func isForceRelayed() bool {
if runtime.GOOS == "js" {
return true
}
return strings.EqualFold(os.Getenv(EnvKeyNBForceRelay), "true")
}

View File

@@ -21,9 +21,9 @@ import (
"github.com/netbirdio/netbird/client/internal/ingressgw"
"github.com/netbirdio/netbird/client/internal/relay"
"github.com/netbirdio/netbird/client/proto"
"github.com/netbirdio/netbird/route"
"github.com/netbirdio/netbird/shared/management/domain"
relayClient "github.com/netbirdio/netbird/shared/relay/client"
"github.com/netbirdio/netbird/route"
)
const eventQueueSize = 10
@@ -67,6 +67,7 @@ type State struct {
BytesRx int64
Latency time.Duration
RosenpassEnabled bool
SSHHostKey []byte
routes map[string]struct{}
}
@@ -572,6 +573,22 @@ func (d *Status) UpdatePeerFQDN(peerPubKey, fqdn string) error {
return nil
}
// UpdatePeerSSHHostKey updates peer's SSH host key
func (d *Status) UpdatePeerSSHHostKey(peerPubKey string, sshHostKey []byte) error {
d.mux.Lock()
defer d.mux.Unlock()
peerState, ok := d.peers[peerPubKey]
if !ok {
return errors.New("peer doesn't exist")
}
peerState.SSHHostKey = sshHostKey
d.peers[peerPubKey] = peerState
return nil
}
// FinishPeerListModifications this event invoke the notification
func (d *Status) FinishPeerListModifications() {
d.mux.Lock()

View File

@@ -44,24 +44,30 @@ var DefaultInterfaceBlacklist = []string{
// ConfigInput carries configuration changes to the client
type ConfigInput struct {
ManagementURL string
AdminURL string
ConfigPath string
StateFilePath string
PreSharedKey *string
ServerSSHAllowed *bool
NATExternalIPs []string
CustomDNSAddress []byte
RosenpassEnabled *bool
RosenpassPermissive *bool
InterfaceName *string
WireguardPort *int
NetworkMonitor *bool
DisableAutoConnect *bool
ExtraIFaceBlackList []string
DNSRouteInterval *time.Duration
ClientCertPath string
ClientCertKeyPath string
ManagementURL string
AdminURL string
ConfigPath string
StateFilePath string
PreSharedKey *string
ServerSSHAllowed *bool
EnableSSHRoot *bool
EnableSSHSFTP *bool
EnableSSHLocalPortForwarding *bool
EnableSSHRemotePortForwarding *bool
DisableSSHAuth *bool
SSHJWTCacheTTL *int
NATExternalIPs []string
CustomDNSAddress []byte
RosenpassEnabled *bool
RosenpassPermissive *bool
InterfaceName *string
WireguardPort *int
NetworkMonitor *bool
DisableAutoConnect *bool
ExtraIFaceBlackList []string
DNSRouteInterval *time.Duration
ClientCertPath string
ClientCertKeyPath string
DisableClientRoutes *bool
DisableServerRoutes *bool
@@ -82,18 +88,24 @@ type ConfigInput struct {
// Config Configuration type
type Config struct {
// Wireguard private key of local peer
PrivateKey string
PreSharedKey string
ManagementURL *url.URL
AdminURL *url.URL
WgIface string
WgPort int
NetworkMonitor *bool
IFaceBlackList []string
DisableIPv6Discovery bool
RosenpassEnabled bool
RosenpassPermissive bool
ServerSSHAllowed *bool
PrivateKey string
PreSharedKey string
ManagementURL *url.URL
AdminURL *url.URL
WgIface string
WgPort int
NetworkMonitor *bool
IFaceBlackList []string
DisableIPv6Discovery bool
RosenpassEnabled bool
RosenpassPermissive bool
ServerSSHAllowed *bool
EnableSSHRoot *bool
EnableSSHSFTP *bool
EnableSSHLocalPortForwarding *bool
EnableSSHRemotePortForwarding *bool
DisableSSHAuth *bool
SSHJWTCacheTTL *int
DisableClientRoutes bool
DisableServerRoutes bool
@@ -376,6 +388,62 @@ func (config *Config) apply(input ConfigInput) (updated bool, err error) {
updated = true
}
if input.EnableSSHRoot != nil && input.EnableSSHRoot != config.EnableSSHRoot {
if *input.EnableSSHRoot {
log.Infof("enabling SSH root login")
} else {
log.Infof("disabling SSH root login")
}
config.EnableSSHRoot = input.EnableSSHRoot
updated = true
}
if input.EnableSSHSFTP != nil && input.EnableSSHSFTP != config.EnableSSHSFTP {
if *input.EnableSSHSFTP {
log.Infof("enabling SSH SFTP subsystem")
} else {
log.Infof("disabling SSH SFTP subsystem")
}
config.EnableSSHSFTP = input.EnableSSHSFTP
updated = true
}
if input.EnableSSHLocalPortForwarding != nil && input.EnableSSHLocalPortForwarding != config.EnableSSHLocalPortForwarding {
if *input.EnableSSHLocalPortForwarding {
log.Infof("enabling SSH local port forwarding")
} else {
log.Infof("disabling SSH local port forwarding")
}
config.EnableSSHLocalPortForwarding = input.EnableSSHLocalPortForwarding
updated = true
}
if input.EnableSSHRemotePortForwarding != nil && input.EnableSSHRemotePortForwarding != config.EnableSSHRemotePortForwarding {
if *input.EnableSSHRemotePortForwarding {
log.Infof("enabling SSH remote port forwarding")
} else {
log.Infof("disabling SSH remote port forwarding")
}
config.EnableSSHRemotePortForwarding = input.EnableSSHRemotePortForwarding
updated = true
}
if input.DisableSSHAuth != nil && input.DisableSSHAuth != config.DisableSSHAuth {
if *input.DisableSSHAuth {
log.Infof("disabling SSH authentication")
} else {
log.Infof("enabling SSH authentication")
}
config.DisableSSHAuth = input.DisableSSHAuth
updated = true
}
if input.SSHJWTCacheTTL != nil && input.SSHJWTCacheTTL != config.SSHJWTCacheTTL {
log.Infof("updating SSH JWT cache TTL to %d seconds", *input.SSHJWTCacheTTL)
config.SSHJWTCacheTTL = input.SSHJWTCacheTTL
updated = true
}
if input.DNSRouteInterval != nil && *input.DNSRouteInterval != config.DNSRouteInterval {
log.Infof("updating DNS route interval to %s (old value %s)",
input.DNSRouteInterval.String(), config.DNSRouteInterval.String())

View File

@@ -193,10 +193,10 @@ func TestWireguardPortZeroExplicit(t *testing.T) {
func TestWireguardPortDefaultVsExplicit(t *testing.T) {
tests := []struct {
name string
wireguardPort *int
expectedPort int
description string
name string
wireguardPort *int
expectedPort int
description string
}{
{
name: "no port specified uses default",

View File

@@ -132,3 +132,21 @@ func (pm *ProfileManager) setActiveProfileState(profileName string) error {
return nil
}
// GetLoginHint retrieves the email from the active profile to use as login_hint.
func GetLoginHint() string {
pm := NewProfileManager()
activeProf, err := pm.GetActiveProfile()
if err != nil {
log.Debugf("failed to get active profile for login hint: %v", err)
return ""
}
profileState, err := pm.GetProfileState(activeProf.Name)
if err != nil {
log.Debugf("failed to get profile state for login hint: %v", err)
return ""
}
return profileState.Email
}

View File

@@ -18,8 +18,8 @@ import (
"github.com/netbirdio/netbird/client/internal/routemanager/iface"
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
"github.com/netbirdio/netbird/client/internal/routemanager/util"
"github.com/netbirdio/netbird/shared/management/domain"
"github.com/netbirdio/netbird/route"
"github.com/netbirdio/netbird/shared/management/domain"
)
const (

View File

@@ -24,7 +24,6 @@ import (
"github.com/netbirdio/netbird/client/iface/netstack"
"github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/listener"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/peerstore"
"github.com/netbirdio/netbird/client/internal/routemanager/client"
@@ -39,6 +38,7 @@ import (
"github.com/netbirdio/netbird/client/internal/routeselector"
"github.com/netbirdio/netbird/client/internal/statemanager"
nbnet "github.com/netbirdio/netbird/client/net"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/route"
relayClient "github.com/netbirdio/netbird/shared/relay/client"
"github.com/netbirdio/netbird/version"

View File

@@ -20,8 +20,8 @@ import (
"github.com/netbirdio/netbird/client/internal/profilemanager"
"github.com/netbirdio/netbird/client/system"
"github.com/netbirdio/netbird/formatter"
"github.com/netbirdio/netbird/shared/management/domain"
"github.com/netbirdio/netbird/route"
"github.com/netbirdio/netbird/shared/management/domain"
)
// ConnectionListener export internal Listener for mobile

File diff suppressed because it is too large Load Diff

View File

@@ -84,6 +84,15 @@ service DaemonService {
rpc Logout(LogoutRequest) returns (LogoutResponse) {}
rpc GetFeatures(GetFeaturesRequest) returns (GetFeaturesResponse) {}
// GetPeerSSHHostKey retrieves SSH host key for a specific peer
rpc GetPeerSSHHostKey(GetPeerSSHHostKeyRequest) returns (GetPeerSSHHostKeyResponse) {}
// RequestJWTAuth initiates JWT authentication flow for SSH
rpc RequestJWTAuth(RequestJWTAuthRequest) returns (RequestJWTAuthResponse) {}
// WaitJWTToken waits for JWT authentication completion
rpc WaitJWTToken(WaitJWTTokenRequest) returns (WaitJWTTokenResponse) {}
}
@@ -161,6 +170,13 @@ message LoginRequest {
// hint is used to pre-fill the email/username field during SSO authentication
optional string hint = 33;
optional bool enableSSHRoot = 34;
optional bool enableSSHSFTP = 35;
optional bool enableSSHLocalPortForwarding = 36;
optional bool enableSSHRemotePortForwarding = 37;
optional bool disableSSHAuth = 38;
optional int32 sshJWTCacheTTL = 39;
}
message LoginResponse {
@@ -188,9 +204,9 @@ message UpResponse {}
message StatusRequest{
bool getFullPeerStatus = 1;
bool shouldRunProbes = 2;
bool shouldRunProbes = 2;
// the UI do not using this yet, but CLIs could use it to wait until the status is ready
optional bool waitForReady = 3;
optional bool waitForReady = 3;
}
message StatusResponse{
@@ -255,6 +271,18 @@ message GetConfigResponse {
bool disable_server_routes = 19;
bool block_lan_access = 20;
bool enableSSHRoot = 21;
bool enableSSHSFTP = 24;
bool enableSSHLocalPortForwarding = 22;
bool enableSSHRemotePortForwarding = 23;
bool disableSSHAuth = 25;
int32 sshJWTCacheTTL = 26;
}
// PeerState contains the latest state of a peer
@@ -276,6 +304,7 @@ message PeerState {
repeated string networks = 16;
google.protobuf.Duration latency = 17;
string relayAddress = 18;
bytes sshHostKey = 19;
}
// LocalPeerState contains the latest state of the local peer
@@ -317,6 +346,20 @@ message NSGroupState {
string error = 4;
}
// SSHSessionInfo contains information about an active SSH session
message SSHSessionInfo {
string username = 1;
string remoteAddress = 2;
string command = 3;
string jwtUsername = 4;
}
// SSHServerState contains the latest state of the SSH server
message SSHServerState {
bool enabled = 1;
repeated SSHSessionInfo sessions = 2;
}
// FullStatus contains the full state held by the Status instance
message FullStatus {
ManagementState managementState = 1;
@@ -330,6 +373,7 @@ message FullStatus {
repeated SystemEvent events = 7;
bool lazyConnectionEnabled = 9;
SSHServerState sshServerState = 10;
}
// Networks
@@ -543,56 +587,63 @@ message SwitchProfileRequest {
message SwitchProfileResponse {}
message SetConfigRequest {
string username = 1;
string profileName = 2;
// managementUrl to authenticate.
string managementUrl = 3;
string username = 1;
string profileName = 2;
// managementUrl to authenticate.
string managementUrl = 3;
// adminUrl to manage keys.
string adminURL = 4;
// adminUrl to manage keys.
string adminURL = 4;
optional bool rosenpassEnabled = 5;
optional bool rosenpassEnabled = 5;
optional string interfaceName = 6;
optional string interfaceName = 6;
optional int64 wireguardPort = 7;
optional int64 wireguardPort = 7;
optional string optionalPreSharedKey = 8;
optional string optionalPreSharedKey = 8;
optional bool disableAutoConnect = 9;
optional bool disableAutoConnect = 9;
optional bool serverSSHAllowed = 10;
optional bool serverSSHAllowed = 10;
optional bool rosenpassPermissive = 11;
optional bool rosenpassPermissive = 11;
optional bool networkMonitor = 12;
optional bool networkMonitor = 12;
optional bool disable_client_routes = 13;
optional bool disable_server_routes = 14;
optional bool disable_dns = 15;
optional bool disable_firewall = 16;
optional bool block_lan_access = 17;
optional bool disable_client_routes = 13;
optional bool disable_server_routes = 14;
optional bool disable_dns = 15;
optional bool disable_firewall = 16;
optional bool block_lan_access = 17;
optional bool disable_notifications = 18;
optional bool disable_notifications = 18;
optional bool lazyConnectionEnabled = 19;
optional bool lazyConnectionEnabled = 19;
optional bool block_inbound = 20;
optional bool block_inbound = 20;
repeated string natExternalIPs = 21;
bool cleanNATExternalIPs = 22;
repeated string natExternalIPs = 21;
bool cleanNATExternalIPs = 22;
bytes customDNSAddress = 23;
bytes customDNSAddress = 23;
repeated string extraIFaceBlacklist = 24;
repeated string extraIFaceBlacklist = 24;
repeated string dns_labels = 25;
// cleanDNSLabels clean map list of DNS labels.
bool cleanDNSLabels = 26;
repeated string dns_labels = 25;
// cleanDNSLabels clean map list of DNS labels.
bool cleanDNSLabels = 26;
optional google.protobuf.Duration dnsRouteInterval = 27;
optional google.protobuf.Duration dnsRouteInterval = 27;
optional int64 mtu = 28;
optional int64 mtu = 28;
optional bool enableSSHRoot = 29;
optional bool enableSSHSFTP = 30;
optional bool enableSSHLocalPortForwarding = 31;
optional bool enableSSHRemotePortForwarding = 32;
optional bool disableSSHAuth = 33;
optional int32 sshJWTCacheTTL = 34;
}
message SetConfigResponse{}
@@ -644,3 +695,63 @@ message GetFeaturesResponse{
bool disable_profiles = 1;
bool disable_update_settings = 2;
}
// GetPeerSSHHostKeyRequest for retrieving SSH host key for a specific peer
message GetPeerSSHHostKeyRequest {
// peer IP address or FQDN to get SSH host key for
string peerAddress = 1;
}
// GetPeerSSHHostKeyResponse contains the SSH host key for the requested peer
message GetPeerSSHHostKeyResponse {
// SSH host key in SSH public key format (e.g., "ssh-ed25519 AAAAC3... hostname")
bytes sshHostKey = 1;
// peer IP address
string peerIP = 2;
// peer FQDN
string peerFQDN = 3;
// indicates if the SSH host key was found
bool found = 4;
}
// RequestJWTAuthRequest for initiating JWT authentication flow
message RequestJWTAuthRequest {
// hint for OIDC login_hint parameter (typically email address)
optional string hint = 1;
}
// RequestJWTAuthResponse contains authentication flow information
message RequestJWTAuthResponse {
// verification URI for user authentication
string verificationURI = 1;
// complete verification URI (with embedded user code)
string verificationURIComplete = 2;
// user code to enter on verification URI
string userCode = 3;
// device code for polling
string deviceCode = 4;
// expiration time in seconds
int64 expiresIn = 5;
// if a cached token is available, it will be returned here
string cachedToken = 6;
// maximum age of JWT tokens in seconds (from management server)
int64 maxTokenAge = 7;
}
// WaitJWTTokenRequest for waiting for authentication completion
message WaitJWTTokenRequest {
// device code from RequestJWTAuthResponse
string deviceCode = 1;
// user code for verification
string userCode = 2;
}
// WaitJWTTokenResponse contains the JWT token after authentication
message WaitJWTTokenResponse {
// JWT token (access token or ID token)
string token = 1;
// token type (e.g., "Bearer")
string tokenType = 2;
// expiration time in seconds
int64 expiresIn = 3;
}

View File

@@ -64,6 +64,12 @@ type DaemonServiceClient interface {
// Logout disconnects from the network and deletes the peer from the management server
Logout(ctx context.Context, in *LogoutRequest, opts ...grpc.CallOption) (*LogoutResponse, error)
GetFeatures(ctx context.Context, in *GetFeaturesRequest, opts ...grpc.CallOption) (*GetFeaturesResponse, error)
// GetPeerSSHHostKey retrieves SSH host key for a specific peer
GetPeerSSHHostKey(ctx context.Context, in *GetPeerSSHHostKeyRequest, opts ...grpc.CallOption) (*GetPeerSSHHostKeyResponse, error)
// RequestJWTAuth initiates JWT authentication flow for SSH
RequestJWTAuth(ctx context.Context, in *RequestJWTAuthRequest, opts ...grpc.CallOption) (*RequestJWTAuthResponse, error)
// WaitJWTToken waits for JWT authentication completion
WaitJWTToken(ctx context.Context, in *WaitJWTTokenRequest, opts ...grpc.CallOption) (*WaitJWTTokenResponse, error)
}
type daemonServiceClient struct {
@@ -349,6 +355,33 @@ func (c *daemonServiceClient) GetFeatures(ctx context.Context, in *GetFeaturesRe
return out, nil
}
func (c *daemonServiceClient) GetPeerSSHHostKey(ctx context.Context, in *GetPeerSSHHostKeyRequest, opts ...grpc.CallOption) (*GetPeerSSHHostKeyResponse, error) {
out := new(GetPeerSSHHostKeyResponse)
err := c.cc.Invoke(ctx, "/daemon.DaemonService/GetPeerSSHHostKey", in, out, opts...)
if err != nil {
return nil, err
}
return out, nil
}
func (c *daemonServiceClient) RequestJWTAuth(ctx context.Context, in *RequestJWTAuthRequest, opts ...grpc.CallOption) (*RequestJWTAuthResponse, error) {
out := new(RequestJWTAuthResponse)
err := c.cc.Invoke(ctx, "/daemon.DaemonService/RequestJWTAuth", in, out, opts...)
if err != nil {
return nil, err
}
return out, nil
}
func (c *daemonServiceClient) WaitJWTToken(ctx context.Context, in *WaitJWTTokenRequest, opts ...grpc.CallOption) (*WaitJWTTokenResponse, error) {
out := new(WaitJWTTokenResponse)
err := c.cc.Invoke(ctx, "/daemon.DaemonService/WaitJWTToken", in, out, opts...)
if err != nil {
return nil, err
}
return out, nil
}
// DaemonServiceServer is the server API for DaemonService service.
// All implementations must embed UnimplementedDaemonServiceServer
// for forward compatibility
@@ -399,6 +432,12 @@ type DaemonServiceServer interface {
// Logout disconnects from the network and deletes the peer from the management server
Logout(context.Context, *LogoutRequest) (*LogoutResponse, error)
GetFeatures(context.Context, *GetFeaturesRequest) (*GetFeaturesResponse, error)
// GetPeerSSHHostKey retrieves SSH host key for a specific peer
GetPeerSSHHostKey(context.Context, *GetPeerSSHHostKeyRequest) (*GetPeerSSHHostKeyResponse, error)
// RequestJWTAuth initiates JWT authentication flow for SSH
RequestJWTAuth(context.Context, *RequestJWTAuthRequest) (*RequestJWTAuthResponse, error)
// WaitJWTToken waits for JWT authentication completion
WaitJWTToken(context.Context, *WaitJWTTokenRequest) (*WaitJWTTokenResponse, error)
mustEmbedUnimplementedDaemonServiceServer()
}
@@ -490,6 +529,15 @@ func (UnimplementedDaemonServiceServer) Logout(context.Context, *LogoutRequest)
func (UnimplementedDaemonServiceServer) GetFeatures(context.Context, *GetFeaturesRequest) (*GetFeaturesResponse, error) {
return nil, status.Errorf(codes.Unimplemented, "method GetFeatures not implemented")
}
func (UnimplementedDaemonServiceServer) GetPeerSSHHostKey(context.Context, *GetPeerSSHHostKeyRequest) (*GetPeerSSHHostKeyResponse, error) {
return nil, status.Errorf(codes.Unimplemented, "method GetPeerSSHHostKey not implemented")
}
func (UnimplementedDaemonServiceServer) RequestJWTAuth(context.Context, *RequestJWTAuthRequest) (*RequestJWTAuthResponse, error) {
return nil, status.Errorf(codes.Unimplemented, "method RequestJWTAuth not implemented")
}
func (UnimplementedDaemonServiceServer) WaitJWTToken(context.Context, *WaitJWTTokenRequest) (*WaitJWTTokenResponse, error) {
return nil, status.Errorf(codes.Unimplemented, "method WaitJWTToken not implemented")
}
func (UnimplementedDaemonServiceServer) mustEmbedUnimplementedDaemonServiceServer() {}
// UnsafeDaemonServiceServer may be embedded to opt out of forward compatibility for this service.
@@ -1010,6 +1058,60 @@ func _DaemonService_GetFeatures_Handler(srv interface{}, ctx context.Context, de
return interceptor(ctx, in, info, handler)
}
func _DaemonService_GetPeerSSHHostKey_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(GetPeerSSHHostKeyRequest)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(DaemonServiceServer).GetPeerSSHHostKey(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: "/daemon.DaemonService/GetPeerSSHHostKey",
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(DaemonServiceServer).GetPeerSSHHostKey(ctx, req.(*GetPeerSSHHostKeyRequest))
}
return interceptor(ctx, in, info, handler)
}
func _DaemonService_RequestJWTAuth_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(RequestJWTAuthRequest)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(DaemonServiceServer).RequestJWTAuth(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: "/daemon.DaemonService/RequestJWTAuth",
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(DaemonServiceServer).RequestJWTAuth(ctx, req.(*RequestJWTAuthRequest))
}
return interceptor(ctx, in, info, handler)
}
func _DaemonService_WaitJWTToken_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(WaitJWTTokenRequest)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(DaemonServiceServer).WaitJWTToken(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: "/daemon.DaemonService/WaitJWTToken",
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(DaemonServiceServer).WaitJWTToken(ctx, req.(*WaitJWTTokenRequest))
}
return interceptor(ctx, in, info, handler)
}
// DaemonService_ServiceDesc is the grpc.ServiceDesc for DaemonService service.
// It's only intended for direct use with grpc.RegisterService,
// and not to be introspected or modified (even as a copy)
@@ -1125,6 +1227,18 @@ var DaemonService_ServiceDesc = grpc.ServiceDesc{
MethodName: "GetFeatures",
Handler: _DaemonService_GetFeatures_Handler,
},
{
MethodName: "GetPeerSSHHostKey",
Handler: _DaemonService_GetPeerSSHHostKey_Handler,
},
{
MethodName: "RequestJWTAuth",
Handler: _DaemonService_RequestJWTAuth_Handler,
},
{
MethodName: "WaitJWTToken",
Handler: _DaemonService_WaitJWTToken_Handler,
},
},
Streams: []grpc.StreamDesc{
{

View File

@@ -0,0 +1,79 @@
package server
import (
"sync"
"time"
"github.com/awnumar/memguard"
log "github.com/sirupsen/logrus"
)
type jwtCache struct {
mu sync.RWMutex
enclave *memguard.Enclave
expiresAt time.Time
timer *time.Timer
maxTokenSize int
}
func newJWTCache() *jwtCache {
return &jwtCache{
maxTokenSize: 8192,
}
}
func (c *jwtCache) store(token string, maxAge time.Duration) {
c.mu.Lock()
defer c.mu.Unlock()
c.cleanup()
if c.timer != nil {
c.timer.Stop()
}
tokenBytes := []byte(token)
c.enclave = memguard.NewEnclave(tokenBytes)
c.expiresAt = time.Now().Add(maxAge)
var timer *time.Timer
timer = time.AfterFunc(maxAge, func() {
c.mu.Lock()
defer c.mu.Unlock()
if c.timer != timer {
return
}
c.cleanup()
c.timer = nil
log.Debugf("JWT token cache expired after %v, securely wiped from memory", maxAge)
})
c.timer = timer
}
func (c *jwtCache) get() (string, bool) {
c.mu.RLock()
defer c.mu.RUnlock()
if c.enclave == nil || time.Now().After(c.expiresAt) {
return "", false
}
buffer, err := c.enclave.Open()
if err != nil {
log.Debugf("Failed to open JWT token enclave: %v", err)
return "", false
}
defer buffer.Destroy()
token := string(buffer.Bytes())
return token, true
}
// cleanup destroys the secure enclave, must be called with lock held
func (c *jwtCache) cleanup() {
if c.enclave != nil {
c.enclave = nil
}
c.expiresAt = time.Time{}
}

View File

@@ -11,8 +11,8 @@ import (
"golang.org/x/exp/maps"
"github.com/netbirdio/netbird/client/proto"
"github.com/netbirdio/netbird/shared/management/domain"
"github.com/netbirdio/netbird/route"
"github.com/netbirdio/netbird/shared/management/domain"
)
type selectRoute struct {

View File

@@ -46,6 +46,9 @@ const (
defaultMaxRetryTime = 14 * 24 * time.Hour
defaultRetryMultiplier = 1.7
// JWT token cache TTL for the client daemon (disabled by default)
defaultJWTCacheTTL = 0
errRestoreResidualState = "failed to restore residual state: %v"
errProfilesDisabled = "profiles are disabled, you cannot use this feature without profiles enabled"
errUpdateSettingsDisabled = "update settings are disabled, you cannot use this feature without update settings enabled"
@@ -81,6 +84,8 @@ type Server struct {
profileManager *profilemanager.ServiceManager
profilesDisabled bool
updateSettingsDisabled bool
jwtCache *jwtCache
}
type oauthAuthFlow struct {
@@ -100,6 +105,7 @@ func New(ctx context.Context, logFile string, configFile string, profilesDisable
profileManager: profilemanager.NewServiceManager(configFile),
profilesDisabled: profilesDisabled,
updateSettingsDisabled: updateSettingsDisabled,
jwtCache: newJWTCache(),
}
}
@@ -373,6 +379,17 @@ func (s *Server) SetConfig(callerCtx context.Context, msg *proto.SetConfigReques
config.DisableNotifications = msg.DisableNotifications
config.LazyConnectionEnabled = msg.LazyConnectionEnabled
config.BlockInbound = msg.BlockInbound
config.EnableSSHRoot = msg.EnableSSHRoot
config.EnableSSHSFTP = msg.EnableSSHSFTP
config.EnableSSHLocalPortForwarding = msg.EnableSSHLocalPortForwarding
config.EnableSSHRemotePortForwarding = msg.EnableSSHRemotePortForwarding
if msg.DisableSSHAuth != nil {
config.DisableSSHAuth = msg.DisableSSHAuth
}
if msg.SshJWTCacheTTL != nil {
ttl := int(*msg.SshJWTCacheTTL)
config.SSHJWTCacheTTL = &ttl
}
if msg.Mtu != nil {
mtu := uint16(*msg.Mtu)
@@ -493,7 +510,7 @@ func (s *Server) Login(callerCtx context.Context, msg *proto.LoginRequest) (*pro
return nil, err
}
if s.oauthAuthFlow.flow != nil && s.oauthAuthFlow.flow.GetClientID(ctx) == oAuthFlow.GetClientID(context.TODO()) {
if s.oauthAuthFlow.flow != nil && s.oauthAuthFlow.flow.GetClientID(ctx) == oAuthFlow.GetClientID(ctx) {
if s.oauthAuthFlow.expiresAt.After(time.Now().Add(90 * time.Second)) {
log.Debugf("using previous oauth flow info")
return &proto.LoginResponse{
@@ -510,7 +527,7 @@ func (s *Server) Login(callerCtx context.Context, msg *proto.LoginRequest) (*pro
}
}
authInfo, err := oAuthFlow.RequestAuthInfo(context.TODO())
authInfo, err := oAuthFlow.RequestAuthInfo(ctx)
if err != nil {
log.Errorf("getting a request OAuth flow failed: %v", err)
return nil, err
@@ -1065,12 +1082,235 @@ func (s *Server) Status(
fullStatus := s.statusRecorder.GetFullStatus()
pbFullStatus := toProtoFullStatus(fullStatus)
pbFullStatus.Events = s.statusRecorder.GetEventHistory()
pbFullStatus.SshServerState = s.getSSHServerState()
statusResponse.FullStatus = pbFullStatus
}
return &statusResponse, nil
}
// getSSHServerState retrieves the current SSH server state including enabled status and active sessions
func (s *Server) getSSHServerState() *proto.SSHServerState {
s.mutex.Lock()
connectClient := s.connectClient
s.mutex.Unlock()
if connectClient == nil {
return nil
}
engine := connectClient.Engine()
if engine == nil {
return nil
}
enabled, sessions := engine.GetSSHServerStatus()
sshServerState := &proto.SSHServerState{
Enabled: enabled,
}
for _, session := range sessions {
sshServerState.Sessions = append(sshServerState.Sessions, &proto.SSHSessionInfo{
Username: session.Username,
RemoteAddress: session.RemoteAddress,
Command: session.Command,
JwtUsername: session.JWTUsername,
})
}
return sshServerState
}
// GetPeerSSHHostKey retrieves SSH host key for a specific peer
func (s *Server) GetPeerSSHHostKey(
ctx context.Context,
req *proto.GetPeerSSHHostKeyRequest,
) (*proto.GetPeerSSHHostKeyResponse, error) {
if ctx.Err() != nil {
return nil, ctx.Err()
}
s.mutex.Lock()
connectClient := s.connectClient
statusRecorder := s.statusRecorder
s.mutex.Unlock()
if connectClient == nil {
return nil, errors.New("client not initialized")
}
engine := connectClient.Engine()
if engine == nil {
return nil, errors.New("engine not started")
}
peerAddress := req.GetPeerAddress()
hostKey, found := engine.GetPeerSSHKey(peerAddress)
response := &proto.GetPeerSSHHostKeyResponse{
Found: found,
}
if !found {
return response, nil
}
response.SshHostKey = hostKey
if statusRecorder == nil {
return response, nil
}
fullStatus := statusRecorder.GetFullStatus()
for _, peerState := range fullStatus.Peers {
if peerState.IP == peerAddress || peerState.FQDN == peerAddress {
response.PeerIP = peerState.IP
response.PeerFQDN = peerState.FQDN
break
}
}
return response, nil
}
// getJWTCacheTTL returns the JWT cache TTL from config or default (disabled)
func (s *Server) getJWTCacheTTL() time.Duration {
s.mutex.Lock()
config := s.config
s.mutex.Unlock()
if config == nil || config.SSHJWTCacheTTL == nil {
return defaultJWTCacheTTL
}
seconds := *config.SSHJWTCacheTTL
if seconds == 0 {
log.Debug("SSH JWT cache disabled (configured to 0)")
return 0
}
ttl := time.Duration(seconds) * time.Second
log.Debugf("SSH JWT cache TTL set to %v from config", ttl)
return ttl
}
// RequestJWTAuth initiates JWT authentication flow for SSH
func (s *Server) RequestJWTAuth(
ctx context.Context,
msg *proto.RequestJWTAuthRequest,
) (*proto.RequestJWTAuthResponse, error) {
if ctx.Err() != nil {
return nil, ctx.Err()
}
s.mutex.Lock()
config := s.config
s.mutex.Unlock()
if config == nil {
return nil, gstatus.Errorf(codes.FailedPrecondition, "client is not configured")
}
jwtCacheTTL := s.getJWTCacheTTL()
if jwtCacheTTL > 0 {
if cachedToken, found := s.jwtCache.get(); found {
log.Debugf("JWT token found in cache, returning cached token for SSH authentication")
return &proto.RequestJWTAuthResponse{
CachedToken: cachedToken,
MaxTokenAge: int64(jwtCacheTTL.Seconds()),
}, nil
}
}
hint := ""
if msg.Hint != nil {
hint = *msg.Hint
}
if hint == "" {
hint = profilemanager.GetLoginHint()
}
isDesktop := isUnixRunningDesktop()
oAuthFlow, err := auth.NewOAuthFlow(ctx, config, isDesktop, hint)
if err != nil {
return nil, gstatus.Errorf(codes.Internal, "failed to create OAuth flow: %v", err)
}
authInfo, err := oAuthFlow.RequestAuthInfo(ctx)
if err != nil {
return nil, gstatus.Errorf(codes.Internal, "failed to request auth info: %v", err)
}
s.mutex.Lock()
s.oauthAuthFlow.flow = oAuthFlow
s.oauthAuthFlow.info = authInfo
s.oauthAuthFlow.expiresAt = time.Now().Add(time.Duration(authInfo.ExpiresIn) * time.Second)
s.mutex.Unlock()
return &proto.RequestJWTAuthResponse{
VerificationURI: authInfo.VerificationURI,
VerificationURIComplete: authInfo.VerificationURIComplete,
UserCode: authInfo.UserCode,
DeviceCode: authInfo.DeviceCode,
ExpiresIn: int64(authInfo.ExpiresIn),
MaxTokenAge: int64(jwtCacheTTL.Seconds()),
}, nil
}
// WaitJWTToken waits for JWT authentication completion
func (s *Server) WaitJWTToken(
ctx context.Context,
req *proto.WaitJWTTokenRequest,
) (*proto.WaitJWTTokenResponse, error) {
if ctx.Err() != nil {
return nil, ctx.Err()
}
s.mutex.Lock()
oAuthFlow := s.oauthAuthFlow.flow
authInfo := s.oauthAuthFlow.info
s.mutex.Unlock()
if oAuthFlow == nil || authInfo.DeviceCode != req.DeviceCode {
return nil, gstatus.Errorf(codes.InvalidArgument, "invalid device code or no active auth flow")
}
tokenInfo, err := oAuthFlow.WaitToken(ctx, authInfo)
if err != nil {
return nil, gstatus.Errorf(codes.Internal, "failed to get token: %v", err)
}
token := tokenInfo.GetTokenToUse()
jwtCacheTTL := s.getJWTCacheTTL()
if jwtCacheTTL > 0 {
s.jwtCache.store(token, jwtCacheTTL)
log.Debugf("JWT token cached for SSH authentication, TTL: %v", jwtCacheTTL)
} else {
log.Debug("JWT caching disabled, not storing token")
}
s.mutex.Lock()
s.oauthAuthFlow = oauthAuthFlow{}
s.mutex.Unlock()
return &proto.WaitJWTTokenResponse{
Token: tokenInfo.GetTokenToUse(),
TokenType: tokenInfo.TokenType,
ExpiresIn: int64(tokenInfo.ExpiresIn),
}, nil
}
func isUnixRunningDesktop() bool {
if runtime.GOOS != "linux" && runtime.GOOS != "freebsd" {
return false
}
return os.Getenv("DESKTOP_SESSION") != "" || os.Getenv("XDG_CURRENT_DESKTOP") != ""
}
func (s *Server) runProbes(waitForProbeResult bool) {
if s.connectClient == nil {
return
@@ -1136,25 +1376,61 @@ func (s *Server) GetConfig(ctx context.Context, req *proto.GetConfigRequest) (*p
disableServerRoutes := cfg.DisableServerRoutes
blockLANAccess := cfg.BlockLANAccess
enableSSHRoot := false
if cfg.EnableSSHRoot != nil {
enableSSHRoot = *cfg.EnableSSHRoot
}
enableSSHSFTP := false
if cfg.EnableSSHSFTP != nil {
enableSSHSFTP = *cfg.EnableSSHSFTP
}
enableSSHLocalPortForwarding := false
if cfg.EnableSSHLocalPortForwarding != nil {
enableSSHLocalPortForwarding = *cfg.EnableSSHLocalPortForwarding
}
enableSSHRemotePortForwarding := false
if cfg.EnableSSHRemotePortForwarding != nil {
enableSSHRemotePortForwarding = *cfg.EnableSSHRemotePortForwarding
}
disableSSHAuth := false
if cfg.DisableSSHAuth != nil {
disableSSHAuth = *cfg.DisableSSHAuth
}
sshJWTCacheTTL := int32(0)
if cfg.SSHJWTCacheTTL != nil {
sshJWTCacheTTL = int32(*cfg.SSHJWTCacheTTL)
}
return &proto.GetConfigResponse{
ManagementUrl: managementURL.String(),
PreSharedKey: preSharedKey,
AdminURL: adminURL.String(),
InterfaceName: cfg.WgIface,
WireguardPort: int64(cfg.WgPort),
Mtu: int64(cfg.MTU),
DisableAutoConnect: cfg.DisableAutoConnect,
ServerSSHAllowed: *cfg.ServerSSHAllowed,
RosenpassEnabled: cfg.RosenpassEnabled,
RosenpassPermissive: cfg.RosenpassPermissive,
LazyConnectionEnabled: cfg.LazyConnectionEnabled,
BlockInbound: cfg.BlockInbound,
DisableNotifications: disableNotifications,
NetworkMonitor: networkMonitor,
DisableDns: disableDNS,
DisableClientRoutes: disableClientRoutes,
DisableServerRoutes: disableServerRoutes,
BlockLanAccess: blockLANAccess,
ManagementUrl: managementURL.String(),
PreSharedKey: preSharedKey,
AdminURL: adminURL.String(),
InterfaceName: cfg.WgIface,
WireguardPort: int64(cfg.WgPort),
Mtu: int64(cfg.MTU),
DisableAutoConnect: cfg.DisableAutoConnect,
ServerSSHAllowed: *cfg.ServerSSHAllowed,
RosenpassEnabled: cfg.RosenpassEnabled,
RosenpassPermissive: cfg.RosenpassPermissive,
LazyConnectionEnabled: cfg.LazyConnectionEnabled,
BlockInbound: cfg.BlockInbound,
DisableNotifications: disableNotifications,
NetworkMonitor: networkMonitor,
DisableDns: disableDNS,
DisableClientRoutes: disableClientRoutes,
DisableServerRoutes: disableServerRoutes,
BlockLanAccess: blockLANAccess,
EnableSSHRoot: enableSSHRoot,
EnableSSHSFTP: enableSSHSFTP,
EnableSSHLocalPortForwarding: enableSSHLocalPortForwarding,
EnableSSHRemotePortForwarding: enableSSHRemotePortForwarding,
DisableSSHAuth: disableSSHAuth,
SshJWTCacheTTL: sshJWTCacheTTL,
}, nil
}
@@ -1385,6 +1661,7 @@ func toProtoFullStatus(fullStatus peer.FullStatus) *proto.FullStatus {
RosenpassEnabled: peerState.RosenpassEnabled,
Networks: maps.Keys(peerState.GetRoutes()),
Latency: durationpb.New(peerState.Latency),
SshHostKey: peerState.SSHHostKey,
}
pbFullStatus.Peers = append(pbFullStatus.Peers, pbPeerState)
}

View File

@@ -14,6 +14,7 @@ import (
"go.opentelemetry.io/otel"
"github.com/netbirdio/management-integrations/integrations"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
@@ -316,7 +317,7 @@ func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Serve
requestBuffer := server.NewAccountRequestBuffer(context.Background(), store)
peersUpdateManager := update_channel.NewPeersUpdateManager(metrics)
networkMapController := controller.NewController(context.Background(), store, metrics, peersUpdateManager, requestBuffer, server.MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock())
accountManager, err := server.BuildManager(context.Background(), store, networkMapController, nil, "", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false)
accountManager, err := server.BuildManager(context.Background(), config, store, networkMapController, nil, "", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false)
if err != nil {
return nil, "", err
}

View File

@@ -72,6 +72,7 @@ func TestSetConfig_AllFieldsSaved(t *testing.T) {
lazyConnectionEnabled := true
blockInbound := true
mtu := int64(1280)
sshJWTCacheTTL := int32(300)
req := &proto.SetConfigRequest{
ProfileName: profName,
@@ -102,6 +103,7 @@ func TestSetConfig_AllFieldsSaved(t *testing.T) {
CleanDNSLabels: false,
DnsRouteInterval: durationpb.New(2 * time.Minute),
Mtu: &mtu,
SshJWTCacheTTL: &sshJWTCacheTTL,
}
_, err = s.SetConfig(ctx, req)
@@ -146,6 +148,8 @@ func TestSetConfig_AllFieldsSaved(t *testing.T) {
require.Equal(t, []string{"label1", "label2"}, cfg.DNSLabels.ToPunycodeList())
require.Equal(t, 2*time.Minute, cfg.DNSRouteInterval)
require.Equal(t, uint16(mtu), cfg.MTU)
require.NotNil(t, cfg.SSHJWTCacheTTL)
require.Equal(t, int(sshJWTCacheTTL), *cfg.SSHJWTCacheTTL)
verifyAllFieldsCovered(t, req)
}
@@ -167,30 +171,36 @@ func verifyAllFieldsCovered(t *testing.T, req *proto.SetConfigRequest) {
}
expectedFields := map[string]bool{
"ManagementUrl": true,
"AdminURL": true,
"RosenpassEnabled": true,
"RosenpassPermissive": true,
"ServerSSHAllowed": true,
"InterfaceName": true,
"WireguardPort": true,
"OptionalPreSharedKey": true,
"DisableAutoConnect": true,
"NetworkMonitor": true,
"DisableClientRoutes": true,
"DisableServerRoutes": true,
"DisableDns": true,
"DisableFirewall": true,
"BlockLanAccess": true,
"DisableNotifications": true,
"LazyConnectionEnabled": true,
"BlockInbound": true,
"NatExternalIPs": true,
"CustomDNSAddress": true,
"ExtraIFaceBlacklist": true,
"DnsLabels": true,
"DnsRouteInterval": true,
"Mtu": true,
"ManagementUrl": true,
"AdminURL": true,
"RosenpassEnabled": true,
"RosenpassPermissive": true,
"ServerSSHAllowed": true,
"InterfaceName": true,
"WireguardPort": true,
"OptionalPreSharedKey": true,
"DisableAutoConnect": true,
"NetworkMonitor": true,
"DisableClientRoutes": true,
"DisableServerRoutes": true,
"DisableDns": true,
"DisableFirewall": true,
"BlockLanAccess": true,
"DisableNotifications": true,
"LazyConnectionEnabled": true,
"BlockInbound": true,
"NatExternalIPs": true,
"CustomDNSAddress": true,
"ExtraIFaceBlacklist": true,
"DnsLabels": true,
"DnsRouteInterval": true,
"Mtu": true,
"EnableSSHRoot": true,
"EnableSSHSFTP": true,
"EnableSSHLocalPortForwarding": true,
"EnableSSHRemotePortForwarding": true,
"DisableSSHAuth": true,
"SshJWTCacheTTL": true,
}
val := reflect.ValueOf(req).Elem()
@@ -221,29 +231,35 @@ func TestCLIFlags_MappedToSetConfig(t *testing.T) {
// Map of CLI flag names to their corresponding SetConfigRequest field names.
// This map must be updated when adding new config-related CLI flags.
flagToField := map[string]string{
"management-url": "ManagementUrl",
"admin-url": "AdminURL",
"enable-rosenpass": "RosenpassEnabled",
"rosenpass-permissive": "RosenpassPermissive",
"allow-server-ssh": "ServerSSHAllowed",
"interface-name": "InterfaceName",
"wireguard-port": "WireguardPort",
"preshared-key": "OptionalPreSharedKey",
"disable-auto-connect": "DisableAutoConnect",
"network-monitor": "NetworkMonitor",
"disable-client-routes": "DisableClientRoutes",
"disable-server-routes": "DisableServerRoutes",
"disable-dns": "DisableDns",
"disable-firewall": "DisableFirewall",
"block-lan-access": "BlockLanAccess",
"block-inbound": "BlockInbound",
"enable-lazy-connection": "LazyConnectionEnabled",
"external-ip-map": "NatExternalIPs",
"dns-resolver-address": "CustomDNSAddress",
"extra-iface-blacklist": "ExtraIFaceBlacklist",
"extra-dns-labels": "DnsLabels",
"dns-router-interval": "DnsRouteInterval",
"mtu": "Mtu",
"management-url": "ManagementUrl",
"admin-url": "AdminURL",
"enable-rosenpass": "RosenpassEnabled",
"rosenpass-permissive": "RosenpassPermissive",
"allow-server-ssh": "ServerSSHAllowed",
"interface-name": "InterfaceName",
"wireguard-port": "WireguardPort",
"preshared-key": "OptionalPreSharedKey",
"disable-auto-connect": "DisableAutoConnect",
"network-monitor": "NetworkMonitor",
"disable-client-routes": "DisableClientRoutes",
"disable-server-routes": "DisableServerRoutes",
"disable-dns": "DisableDns",
"disable-firewall": "DisableFirewall",
"block-lan-access": "BlockLanAccess",
"block-inbound": "BlockInbound",
"enable-lazy-connection": "LazyConnectionEnabled",
"external-ip-map": "NatExternalIPs",
"dns-resolver-address": "CustomDNSAddress",
"extra-iface-blacklist": "ExtraIFaceBlacklist",
"extra-dns-labels": "DnsLabels",
"dns-router-interval": "DnsRouteInterval",
"mtu": "Mtu",
"enable-ssh-root": "EnableSSHRoot",
"enable-ssh-sftp": "EnableSSHSFTP",
"enable-ssh-local-port-forwarding": "EnableSSHLocalPortForwarding",
"enable-ssh-remote-port-forwarding": "EnableSSHRemotePortForwarding",
"disable-ssh-auth": "DisableSSHAuth",
"ssh-jwt-cache-ttl": "SshJWTCacheTTL",
}
// SetConfigRequest fields that don't have CLI flags (settable only via UI or other means).

View File

@@ -6,9 +6,11 @@ import (
"github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
"github.com/netbirdio/netbird/client/internal/statemanager"
"github.com/netbirdio/netbird/client/ssh/config"
)
func registerStates(mgr *statemanager.Manager) {
mgr.RegisterState(&dns.ShutdownState{})
mgr.RegisterState(&systemops.ShutdownState{})
mgr.RegisterState(&config.ShutdownState{})
}

View File

@@ -8,6 +8,7 @@ import (
"github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
"github.com/netbirdio/netbird/client/internal/statemanager"
"github.com/netbirdio/netbird/client/ssh/config"
)
func registerStates(mgr *statemanager.Manager) {
@@ -15,4 +16,5 @@ func registerStates(mgr *statemanager.Manager) {
mgr.RegisterState(&systemops.ShutdownState{})
mgr.RegisterState(&nftables.ShutdownState{})
mgr.RegisterState(&iptables.ShutdownState{})
mgr.RegisterState(&config.ShutdownState{})
}

View File

@@ -1,118 +0,0 @@
//go:build !js
package ssh
import (
"fmt"
"net"
"os"
"time"
"golang.org/x/crypto/ssh"
"golang.org/x/term"
)
// Client wraps crypto/ssh Client to simplify usage
type Client struct {
client *ssh.Client
}
// Close closes the wrapped SSH Client
func (c *Client) Close() error {
return c.client.Close()
}
// OpenTerminal starts an interactive terminal session with the remote SSH server
func (c *Client) OpenTerminal() error {
session, err := c.client.NewSession()
if err != nil {
return fmt.Errorf("failed to open new session: %v", err)
}
defer func() {
err := session.Close()
if err != nil {
return
}
}()
fd := int(os.Stdout.Fd())
state, err := term.MakeRaw(fd)
if err != nil {
return fmt.Errorf("failed to run raw terminal: %s", err)
}
defer func() {
err := term.Restore(fd, state)
if err != nil {
return
}
}()
w, h, err := term.GetSize(fd)
if err != nil {
return fmt.Errorf("terminal get size: %s", err)
}
modes := ssh.TerminalModes{
ssh.ECHO: 1,
ssh.TTY_OP_ISPEED: 14400,
ssh.TTY_OP_OSPEED: 14400,
}
terminal := os.Getenv("TERM")
if terminal == "" {
terminal = "xterm-256color"
}
if err := session.RequestPty(terminal, h, w, modes); err != nil {
return fmt.Errorf("failed requesting pty session with xterm: %s", err)
}
session.Stdout = os.Stdout
session.Stderr = os.Stderr
session.Stdin = os.Stdin
if err := session.Shell(); err != nil {
return fmt.Errorf("failed to start login shell on the remote host: %s", err)
}
if err := session.Wait(); err != nil {
if e, ok := err.(*ssh.ExitError); ok {
if e.ExitStatus() == 130 {
return nil
}
}
return fmt.Errorf("failed running SSH session: %s", err)
}
return nil
}
// DialWithKey connects to the remote SSH server with a provided private key file (PEM).
func DialWithKey(addr, user string, privateKey []byte) (*Client, error) {
signer, err := ssh.ParsePrivateKey(privateKey)
if err != nil {
return nil, err
}
config := &ssh.ClientConfig{
User: user,
Timeout: 5 * time.Second,
Auth: []ssh.AuthMethod{
ssh.PublicKeys(signer),
},
HostKeyCallback: ssh.HostKeyCallback(func(hostname string, remote net.Addr, key ssh.PublicKey) error { return nil }),
}
return Dial("tcp", addr, config)
}
// Dial connects to the remote SSH server.
func Dial(network, addr string, config *ssh.ClientConfig) (*Client, error) {
client, err := ssh.Dial(network, addr, config)
if err != nil {
return nil, err
}
return &Client{
client: client,
}, nil
}

699
client/ssh/client/client.go Normal file
View File

@@ -0,0 +1,699 @@
package client
import (
"context"
"errors"
"fmt"
"io"
"net"
"os"
"path/filepath"
"runtime"
"strconv"
"strings"
"time"
log "github.com/sirupsen/logrus"
"golang.org/x/crypto/ssh"
"golang.org/x/crypto/ssh/knownhosts"
"golang.org/x/term"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure"
"github.com/netbirdio/netbird/client/internal/profilemanager"
"github.com/netbirdio/netbird/client/proto"
nbssh "github.com/netbirdio/netbird/client/ssh"
"github.com/netbirdio/netbird/client/ssh/detection"
)
const (
// DefaultDaemonAddr is the default address for the NetBird daemon
DefaultDaemonAddr = "unix:///var/run/netbird.sock"
// DefaultDaemonAddrWindows is the default address for the NetBird daemon on Windows
DefaultDaemonAddrWindows = "tcp://127.0.0.1:41731"
)
// Client wraps crypto/ssh Client for simplified SSH operations
type Client struct {
client *ssh.Client
terminalState *term.State
terminalFd int
windowsStdoutMode uint32 // nolint:unused
windowsStdinMode uint32 // nolint:unused
}
func (c *Client) Close() error {
return c.client.Close()
}
func (c *Client) OpenTerminal(ctx context.Context) error {
session, err := c.client.NewSession()
if err != nil {
return fmt.Errorf("new session: %w", err)
}
defer func() {
if err := session.Close(); err != nil {
log.Debugf("session close error: %v", err)
}
}()
if err := c.setupTerminalMode(ctx, session); err != nil {
return err
}
c.setupSessionIO(session)
if err := session.Shell(); err != nil {
return fmt.Errorf("start shell: %w", err)
}
return c.waitForSession(ctx, session)
}
// setupSessionIO connects session streams to local terminal
func (c *Client) setupSessionIO(session *ssh.Session) {
session.Stdout = os.Stdout
session.Stderr = os.Stderr
session.Stdin = os.Stdin
}
// waitForSession waits for the session to complete with context cancellation
func (c *Client) waitForSession(ctx context.Context, session *ssh.Session) error {
done := make(chan error, 1)
go func() {
done <- session.Wait()
}()
defer c.restoreTerminal()
select {
case <-ctx.Done():
return ctx.Err()
case err := <-done:
return c.handleSessionError(err)
}
}
// handleSessionError processes session termination errors
func (c *Client) handleSessionError(err error) error {
if err == nil {
return nil
}
var e *ssh.ExitError
var em *ssh.ExitMissingError
if !errors.As(err, &e) && !errors.As(err, &em) {
return fmt.Errorf("session wait: %w", err)
}
return nil
}
// restoreTerminal restores the terminal to its original state
func (c *Client) restoreTerminal() {
if c.terminalState != nil {
_ = term.Restore(c.terminalFd, c.terminalState)
c.terminalState = nil
c.terminalFd = 0
}
if err := c.restoreWindowsConsoleState(); err != nil {
log.Debugf("restore Windows console state: %v", err)
}
}
// ExecuteCommand executes a command on the remote host and returns the output
func (c *Client) ExecuteCommand(ctx context.Context, command string) ([]byte, error) {
session, cleanup, err := c.createSession(ctx)
if err != nil {
return nil, err
}
defer cleanup()
output, err := session.CombinedOutput(command)
if err != nil {
var e *ssh.ExitError
var em *ssh.ExitMissingError
if !errors.As(err, &e) && !errors.As(err, &em) {
return output, fmt.Errorf("execute command: %w", err)
}
}
return output, nil
}
// ExecuteCommandWithIO executes a command with interactive I/O connected to local terminal
func (c *Client) ExecuteCommandWithIO(ctx context.Context, command string) error {
session, cleanup, err := c.createSession(ctx)
if err != nil {
return fmt.Errorf("create session: %w", err)
}
defer cleanup()
c.setupSessionIO(session)
if err := session.Start(command); err != nil {
return fmt.Errorf("start command: %w", err)
}
done := make(chan error, 1)
go func() {
done <- session.Wait()
}()
select {
case <-ctx.Done():
_ = session.Signal(ssh.SIGTERM)
select {
case <-done:
return ctx.Err()
case <-time.After(100 * time.Millisecond):
return ctx.Err()
}
case err := <-done:
return c.handleCommandError(err)
}
}
// ExecuteCommandWithPTY executes a command with a pseudo-terminal for interactive sessions
func (c *Client) ExecuteCommandWithPTY(ctx context.Context, command string) error {
session, cleanup, err := c.createSession(ctx)
if err != nil {
return fmt.Errorf("create session: %w", err)
}
defer cleanup()
if err := c.setupTerminalMode(ctx, session); err != nil {
return fmt.Errorf("setup terminal mode: %w", err)
}
c.setupSessionIO(session)
if err := session.Start(command); err != nil {
return fmt.Errorf("start command: %w", err)
}
defer c.restoreTerminal()
done := make(chan error, 1)
go func() {
done <- session.Wait()
}()
select {
case <-ctx.Done():
_ = session.Signal(ssh.SIGTERM)
select {
case <-done:
return ctx.Err()
case <-time.After(100 * time.Millisecond):
return ctx.Err()
}
case err := <-done:
return c.handleCommandError(err)
}
}
// handleCommandError processes command execution errors
func (c *Client) handleCommandError(err error) error {
if err == nil {
return nil
}
var e *ssh.ExitError
var em *ssh.ExitMissingError
if errors.As(err, &e) || errors.As(err, &em) {
return err
}
return fmt.Errorf("execute command: %w", err)
}
// setupContextCancellation sets up context cancellation for a session
func (c *Client) setupContextCancellation(ctx context.Context, session *ssh.Session) func() {
done := make(chan struct{})
go func() {
select {
case <-ctx.Done():
_ = session.Signal(ssh.SIGTERM)
_ = session.Close()
case <-done:
}
}()
return func() { close(done) }
}
// createSession creates a new SSH session with context cancellation setup
func (c *Client) createSession(ctx context.Context) (*ssh.Session, func(), error) {
session, err := c.client.NewSession()
if err != nil {
return nil, nil, fmt.Errorf("new session: %w", err)
}
cancel := c.setupContextCancellation(ctx, session)
cleanup := func() {
cancel()
_ = session.Close()
}
return session, cleanup, nil
}
// getDefaultDaemonAddr returns the daemon address from environment or default for the OS
func getDefaultDaemonAddr() string {
if addr := os.Getenv("NB_DAEMON_ADDR"); addr != "" {
return addr
}
if runtime.GOOS == "windows" {
return DefaultDaemonAddrWindows
}
return DefaultDaemonAddr
}
// DialOptions contains options for SSH connections
type DialOptions struct {
KnownHostsFile string
IdentityFile string
DaemonAddr string
SkipCachedToken bool
InsecureSkipVerify bool
}
// Dial connects to the given ssh server with specified options
func Dial(ctx context.Context, addr, user string, opts DialOptions) (*Client, error) {
daemonAddr := opts.DaemonAddr
if daemonAddr == "" {
daemonAddr = getDefaultDaemonAddr()
}
opts.DaemonAddr = daemonAddr
hostKeyCallback, err := createHostKeyCallback(opts)
if err != nil {
return nil, fmt.Errorf("create host key callback: %w", err)
}
config := &ssh.ClientConfig{
User: user,
Timeout: 30 * time.Second,
HostKeyCallback: hostKeyCallback,
}
if opts.IdentityFile != "" {
authMethod, err := createSSHKeyAuth(opts.IdentityFile)
if err != nil {
return nil, fmt.Errorf("create SSH key auth: %w", err)
}
config.Auth = append(config.Auth, authMethod)
}
return dialWithJWT(ctx, "tcp", addr, config, daemonAddr, opts.SkipCachedToken)
}
// dialSSH establishes an SSH connection without JWT authentication
func dialSSH(ctx context.Context, network, addr string, config *ssh.ClientConfig) (*Client, error) {
dialer := &net.Dialer{}
conn, err := dialer.DialContext(ctx, network, addr)
if err != nil {
return nil, fmt.Errorf("dial %s: %w", addr, err)
}
clientConn, chans, reqs, err := ssh.NewClientConn(conn, addr, config)
if err != nil {
if closeErr := conn.Close(); closeErr != nil {
log.Debugf("connection close after handshake failure: %v", closeErr)
}
return nil, fmt.Errorf("ssh handshake: %w", err)
}
client := ssh.NewClient(clientConn, chans, reqs)
return &Client{
client: client,
}, nil
}
// dialWithJWT establishes an SSH connection with optional JWT authentication based on server detection
func dialWithJWT(ctx context.Context, network, addr string, config *ssh.ClientConfig, daemonAddr string, skipCache bool) (*Client, error) {
host, portStr, err := net.SplitHostPort(addr)
if err != nil {
return nil, fmt.Errorf("parse address %s: %w", addr, err)
}
port, err := strconv.Atoi(portStr)
if err != nil {
return nil, fmt.Errorf("parse port %s: %w", portStr, err)
}
dialer := &net.Dialer{Timeout: detection.Timeout}
serverType, err := detection.DetectSSHServerType(ctx, dialer, host, port)
if err != nil {
return nil, fmt.Errorf("SSH server detection failed: %w", err)
}
if !serverType.RequiresJWT() {
return dialSSH(ctx, network, addr, config)
}
jwtCtx, cancel := context.WithTimeout(ctx, config.Timeout)
defer cancel()
jwtToken, err := requestJWTToken(jwtCtx, daemonAddr, skipCache)
if err != nil {
return nil, fmt.Errorf("request JWT token: %w", err)
}
configWithJWT := nbssh.AddJWTAuth(config, jwtToken)
return dialSSH(ctx, network, addr, configWithJWT)
}
// requestJWTToken requests a JWT token from the NetBird daemon
func requestJWTToken(ctx context.Context, daemonAddr string, skipCache bool) (string, error) {
hint := profilemanager.GetLoginHint()
conn, err := connectToDaemon(daemonAddr)
if err != nil {
return "", fmt.Errorf("connect to daemon: %w", err)
}
defer conn.Close()
client := proto.NewDaemonServiceClient(conn)
return nbssh.RequestJWTToken(ctx, client, os.Stdout, os.Stderr, !skipCache, hint)
}
// verifyHostKeyViaDaemon verifies SSH host key by querying the NetBird daemon
func verifyHostKeyViaDaemon(hostname string, remote net.Addr, key ssh.PublicKey, daemonAddr string) error {
conn, err := connectToDaemon(daemonAddr)
if err != nil {
return err
}
defer func() {
if err := conn.Close(); err != nil {
log.Debugf("daemon connection close error: %v", err)
}
}()
client := proto.NewDaemonServiceClient(conn)
verifier := nbssh.NewDaemonHostKeyVerifier(client)
callback := nbssh.CreateHostKeyCallback(verifier)
return callback(hostname, remote, key)
}
func connectToDaemon(daemonAddr string) (*grpc.ClientConn, error) {
addr := strings.TrimPrefix(daemonAddr, "tcp://")
conn, err := grpc.NewClient(
addr,
grpc.WithTransportCredentials(insecure.NewCredentials()),
)
if err != nil {
log.Debugf("failed to create gRPC client for NetBird daemon at %s: %v", daemonAddr, err)
return nil, fmt.Errorf("failed to connect to NetBird daemon: %w", err)
}
return conn, nil
}
// getKnownHostsFiles returns paths to known_hosts files in order of preference
func getKnownHostsFiles() []string {
var files []string
// User's known_hosts file (highest priority)
if homeDir, err := os.UserHomeDir(); err == nil {
userKnownHosts := filepath.Join(homeDir, ".ssh", "known_hosts")
files = append(files, userKnownHosts)
}
// NetBird managed known_hosts files
if runtime.GOOS == "windows" {
programData := os.Getenv("PROGRAMDATA")
if programData == "" {
programData = `C:\ProgramData`
}
netbirdKnownHosts := filepath.Join(programData, "ssh", "ssh_known_hosts.d", "99-netbird")
files = append(files, netbirdKnownHosts)
} else {
files = append(files, "/etc/ssh/ssh_known_hosts.d/99-netbird")
files = append(files, "/etc/ssh/ssh_known_hosts")
}
return files
}
// createHostKeyCallback creates a host key verification callback
func createHostKeyCallback(opts DialOptions) (ssh.HostKeyCallback, error) {
if opts.InsecureSkipVerify {
return ssh.InsecureIgnoreHostKey(), nil // #nosec G106 - User explicitly requested insecure mode
}
return func(hostname string, remote net.Addr, key ssh.PublicKey) error {
if err := tryDaemonVerification(hostname, remote, key, opts.DaemonAddr); err == nil {
return nil
}
return tryKnownHostsVerification(hostname, remote, key, opts.KnownHostsFile)
}, nil
}
func tryDaemonVerification(hostname string, remote net.Addr, key ssh.PublicKey, daemonAddr string) error {
if daemonAddr == "" {
return fmt.Errorf("no daemon address")
}
return verifyHostKeyViaDaemon(hostname, remote, key, daemonAddr)
}
func tryKnownHostsVerification(hostname string, remote net.Addr, key ssh.PublicKey, knownHostsFile string) error {
knownHostsFiles := getKnownHostsFilesList(knownHostsFile)
hostKeyCallbacks := buildHostKeyCallbacks(knownHostsFiles)
for _, callback := range hostKeyCallbacks {
if err := callback(hostname, remote, key); err == nil {
return nil
}
}
return fmt.Errorf("host key verification failed: key for %s not found in any known_hosts file", hostname)
}
func getKnownHostsFilesList(knownHostsFile string) []string {
if knownHostsFile != "" {
return []string{knownHostsFile}
}
return getKnownHostsFiles()
}
func buildHostKeyCallbacks(knownHostsFiles []string) []ssh.HostKeyCallback {
var hostKeyCallbacks []ssh.HostKeyCallback
for _, file := range knownHostsFiles {
if callback, err := knownhosts.New(file); err == nil {
hostKeyCallbacks = append(hostKeyCallbacks, callback)
}
}
return hostKeyCallbacks
}
// createSSHKeyAuth creates SSH key authentication from a private key file
func createSSHKeyAuth(keyFile string) (ssh.AuthMethod, error) {
keyData, err := os.ReadFile(keyFile)
if err != nil {
return nil, fmt.Errorf("read SSH key file %s: %w", keyFile, err)
}
signer, err := ssh.ParsePrivateKey(keyData)
if err != nil {
return nil, fmt.Errorf("parse SSH private key: %w", err)
}
return ssh.PublicKeys(signer), nil
}
// LocalPortForward sets up local port forwarding, binding to localAddr and forwarding to remoteAddr
func (c *Client) LocalPortForward(ctx context.Context, localAddr, remoteAddr string) error {
localListener, err := net.Listen("tcp", localAddr)
if err != nil {
return fmt.Errorf("listen on %s: %w", localAddr, err)
}
go func() {
defer func() {
if err := localListener.Close(); err != nil && !errors.Is(err, net.ErrClosed) {
log.Debugf("local listener close error: %v", err)
}
}()
for {
localConn, err := localListener.Accept()
if err != nil {
if ctx.Err() != nil {
return
}
continue
}
go c.handleLocalForward(localConn, remoteAddr)
}
}()
<-ctx.Done()
if err := localListener.Close(); err != nil && !errors.Is(err, net.ErrClosed) {
log.Debugf("local listener close error: %v", err)
}
return ctx.Err()
}
// handleLocalForward handles a single local port forwarding connection
func (c *Client) handleLocalForward(localConn net.Conn, remoteAddr string) {
defer func() {
if err := localConn.Close(); err != nil {
log.Debugf("local connection close error: %v", err)
}
}()
channel, err := c.client.Dial("tcp", remoteAddr)
if err != nil {
if strings.Contains(err.Error(), "administratively prohibited") {
_, _ = fmt.Fprintf(os.Stderr, "channel open failed: administratively prohibited: port forwarding is disabled\n")
} else {
log.Debugf("local port forwarding to %s failed: %v", remoteAddr, err)
}
return
}
defer func() {
if err := channel.Close(); err != nil {
log.Debugf("remote channel close error: %v", err)
}
}()
go func() {
if _, err := io.Copy(channel, localConn); err != nil {
log.Debugf("local forward copy error (local->remote): %v", err)
}
}()
if _, err := io.Copy(localConn, channel); err != nil {
log.Debugf("local forward copy error (remote->local): %v", err)
}
}
// RemotePortForward sets up remote port forwarding, binding on remote and forwarding to localAddr
func (c *Client) RemotePortForward(ctx context.Context, remoteAddr, localAddr string) error {
host, port, err := c.parseRemoteAddress(remoteAddr)
if err != nil {
return fmt.Errorf("parse remote address: %w", err)
}
req := c.buildTCPIPForwardRequest(host, port)
if err := c.sendTCPIPForwardRequest(req); err != nil {
return fmt.Errorf("setup remote forward: %w", err)
}
go c.handleRemoteForwardChannels(ctx, localAddr)
<-ctx.Done()
if err := c.cancelTCPIPForwardRequest(req); err != nil {
return fmt.Errorf("cancel tcpip-forward: %w", err)
}
return ctx.Err()
}
// parseRemoteAddress parses host and port from remote address string
func (c *Client) parseRemoteAddress(remoteAddr string) (string, uint32, error) {
host, portStr, err := net.SplitHostPort(remoteAddr)
if err != nil {
return "", 0, fmt.Errorf("parse remote address %s: %w", remoteAddr, err)
}
port, err := strconv.Atoi(portStr)
if err != nil {
return "", 0, fmt.Errorf("parse remote port %s: %w", portStr, err)
}
return host, uint32(port), nil
}
// buildTCPIPForwardRequest creates a tcpip-forward request message
func (c *Client) buildTCPIPForwardRequest(host string, port uint32) tcpipForwardMsg {
return tcpipForwardMsg{
Host: host,
Port: port,
}
}
// sendTCPIPForwardRequest sends the tcpip-forward request to establish remote port forwarding
func (c *Client) sendTCPIPForwardRequest(req tcpipForwardMsg) error {
ok, _, err := c.client.SendRequest("tcpip-forward", true, ssh.Marshal(&req))
if err != nil {
return fmt.Errorf("send tcpip-forward request: %w", err)
}
if !ok {
return fmt.Errorf("remote port forwarding denied by server (check if --allow-ssh-remote-port-forwarding is enabled)")
}
return nil
}
// cancelTCPIPForwardRequest cancels the tcpip-forward request
func (c *Client) cancelTCPIPForwardRequest(req tcpipForwardMsg) error {
_, _, err := c.client.SendRequest("cancel-tcpip-forward", true, ssh.Marshal(&req))
if err != nil {
return fmt.Errorf("send cancel-tcpip-forward request: %w", err)
}
return nil
}
// handleRemoteForwardChannels handles incoming forwarded-tcpip channels
func (c *Client) handleRemoteForwardChannels(ctx context.Context, localAddr string) {
// Get the channel once - subsequent calls return nil!
channelRequests := c.client.HandleChannelOpen("forwarded-tcpip")
if channelRequests == nil {
log.Debugf("forwarded-tcpip channel type already being handled")
return
}
for {
select {
case <-ctx.Done():
return
case newChan := <-channelRequests:
if newChan != nil {
go c.handleRemoteForwardChannel(newChan, localAddr)
}
}
}
}
// handleRemoteForwardChannel handles a single forwarded-tcpip channel
func (c *Client) handleRemoteForwardChannel(newChan ssh.NewChannel, localAddr string) {
channel, reqs, err := newChan.Accept()
if err != nil {
return
}
defer func() {
if err := channel.Close(); err != nil {
log.Debugf("remote channel close error: %v", err)
}
}()
go ssh.DiscardRequests(reqs)
localConn, err := net.Dial("tcp", localAddr)
if err != nil {
return
}
defer func() {
if err := localConn.Close(); err != nil {
log.Debugf("local connection close error: %v", err)
}
}()
go func() {
if _, err := io.Copy(localConn, channel); err != nil {
log.Debugf("remote forward copy error (remote->local): %v", err)
}
}()
if _, err := io.Copy(channel, localConn); err != nil {
log.Debugf("remote forward copy error (local->remote): %v", err)
}
}
// tcpipForwardMsg represents the structure for tcpip-forward requests
type tcpipForwardMsg struct {
Host string
Port uint32
}

View File

@@ -0,0 +1,512 @@
package client
import (
"context"
"errors"
"fmt"
"io"
"net"
"os"
"os/user"
"runtime"
"strings"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
cryptossh "golang.org/x/crypto/ssh"
"github.com/netbirdio/netbird/client/ssh"
sshserver "github.com/netbirdio/netbird/client/ssh/server"
"github.com/netbirdio/netbird/client/ssh/testutil"
)
// TestMain handles package-level setup and cleanup
func TestMain(m *testing.M) {
// Guard against infinite recursion when test binary is called as "netbird ssh exec"
// This happens when running tests as non-privileged user with fallback
if len(os.Args) > 2 && os.Args[1] == "ssh" && os.Args[2] == "exec" {
// Just exit with error to break the recursion
fmt.Fprintf(os.Stderr, "Test binary called as 'ssh exec' - preventing infinite recursion\n")
os.Exit(1)
}
// Run tests
code := m.Run()
// Cleanup any created test users
testutil.CleanupTestUsers()
os.Exit(code)
}
func TestSSHClient_DialWithKey(t *testing.T) {
// Generate host key for server
hostKey, err := ssh.GeneratePrivateKey(ssh.ED25519)
require.NoError(t, err)
// Create and start server
serverConfig := &sshserver.Config{
HostKeyPEM: hostKey,
JWT: nil,
}
server := sshserver.New(serverConfig)
server.SetAllowRootLogin(true) // Allow root/admin login for tests
serverAddr := sshserver.StartTestServer(t, server)
defer func() {
err := server.Stop()
require.NoError(t, err)
}()
// Test Dial
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
currentUser := testutil.GetTestUsername(t)
client, err := Dial(ctx, serverAddr, currentUser, DialOptions{
InsecureSkipVerify: true,
})
require.NoError(t, err)
defer func() {
err := client.Close()
assert.NoError(t, err)
}()
// Verify client is connected
assert.NotNil(t, client.client)
}
func TestSSHClient_CommandExecution(t *testing.T) {
if runtime.GOOS == "windows" && testutil.IsCI() {
t.Skip("Skipping Windows command execution tests in CI due to S4U authentication issues")
}
server, _, client := setupTestSSHServerAndClient(t)
defer func() {
err := server.Stop()
require.NoError(t, err)
}()
defer func() {
err := client.Close()
assert.NoError(t, err)
}()
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
defer cancel()
t.Run("ExecuteCommand captures output", func(t *testing.T) {
output, err := client.ExecuteCommand(ctx, "echo hello")
assert.NoError(t, err)
assert.Contains(t, string(output), "hello")
})
t.Run("ExecuteCommandWithIO streams output", func(t *testing.T) {
err := client.ExecuteCommandWithIO(ctx, "echo world")
assert.NoError(t, err)
})
t.Run("commands with flags work", func(t *testing.T) {
output, err := client.ExecuteCommand(ctx, "echo -n test_flag")
assert.NoError(t, err)
assert.Equal(t, "test_flag", strings.TrimSpace(string(output)))
})
t.Run("non-zero exit codes don't return errors", func(t *testing.T) {
var testCmd string
if runtime.GOOS == "windows" {
testCmd = "echo hello | Select-String notfound"
} else {
testCmd = "echo 'hello' | grep 'notfound'"
}
_, err := client.ExecuteCommand(ctx, testCmd)
assert.NoError(t, err)
})
}
func TestSSHClient_ConnectionHandling(t *testing.T) {
server, serverAddr, _ := setupTestSSHServerAndClient(t)
defer func() {
err := server.Stop()
require.NoError(t, err)
}()
// Generate client key for multiple connections
const numClients = 3
clients := make([]*Client, numClients)
currentUser := testutil.GetTestUsername(t)
for i := 0; i < numClients; i++ {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
client, err := Dial(ctx, serverAddr, currentUser, DialOptions{
InsecureSkipVerify: true,
})
cancel()
require.NoError(t, err, "Client %d should connect successfully", i)
clients[i] = client
}
for i, client := range clients {
err := client.Close()
assert.NoError(t, err, "Client %d should close without error", i)
}
}
func TestSSHClient_ContextCancellation(t *testing.T) {
server, serverAddr, _ := setupTestSSHServerAndClient(t)
defer func() {
err := server.Stop()
require.NoError(t, err)
}()
t.Run("connection with short timeout", func(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Millisecond)
defer cancel()
currentUser := testutil.GetTestUsername(t)
_, err := Dial(ctx, serverAddr, currentUser, DialOptions{
InsecureSkipVerify: true,
})
if err != nil {
// Check for actual timeout-related errors rather than string matching
assert.True(t,
errors.Is(err, context.DeadlineExceeded) ||
errors.Is(err, context.Canceled) ||
strings.Contains(err.Error(), "timeout"),
"Expected timeout-related error, got: %v", err)
}
})
t.Run("command execution cancellation", func(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
currentUser := testutil.GetTestUsername(t)
client, err := Dial(ctx, serverAddr, currentUser, DialOptions{
InsecureSkipVerify: true,
})
require.NoError(t, err)
defer func() {
if err := client.Close(); err != nil {
t.Logf("client close error: %v", err)
}
}()
cmdCtx, cmdCancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer cmdCancel()
err = client.ExecuteCommandWithPTY(cmdCtx, "sleep 10")
if err != nil {
var exitMissingErr *cryptossh.ExitMissingError
isValidCancellation := errors.Is(err, context.DeadlineExceeded) ||
errors.Is(err, context.Canceled) ||
errors.As(err, &exitMissingErr)
assert.True(t, isValidCancellation, "Should handle command cancellation properly")
}
})
}
func TestSSHClient_NoAuthMode(t *testing.T) {
hostKey, err := ssh.GeneratePrivateKey(ssh.ED25519)
require.NoError(t, err)
serverConfig := &sshserver.Config{
HostKeyPEM: hostKey,
JWT: nil,
}
server := sshserver.New(serverConfig)
server.SetAllowRootLogin(true) // Allow root/admin login for tests
serverAddr := sshserver.StartTestServer(t, server)
defer func() {
err := server.Stop()
require.NoError(t, err)
}()
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
currentUser := testutil.GetTestUsername(t)
t.Run("any key succeeds in no-auth mode", func(t *testing.T) {
client, err := Dial(ctx, serverAddr, currentUser, DialOptions{
InsecureSkipVerify: true,
})
assert.NoError(t, err)
if client != nil {
require.NoError(t, client.Close(), "Client should close without error")
}
})
}
func TestSSHClient_TerminalState(t *testing.T) {
server, _, client := setupTestSSHServerAndClient(t)
defer func() {
err := server.Stop()
require.NoError(t, err)
}()
defer func() {
err := client.Close()
assert.NoError(t, err)
}()
assert.Nil(t, client.terminalState)
assert.Equal(t, 0, client.terminalFd)
client.restoreTerminal()
assert.Nil(t, client.terminalState)
ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
defer cancel()
err := client.OpenTerminal(ctx)
// In test environment without a real terminal, this may complete quickly or timeout
// Both behaviors are acceptable for testing terminal state management
if err != nil {
if runtime.GOOS == "windows" {
assert.True(t,
strings.Contains(err.Error(), "context deadline exceeded") ||
strings.Contains(err.Error(), "console"),
"Should timeout or have console error on Windows")
} else {
// On Unix systems in test environment, we may get various errors
// including timeouts or terminal-related errors
assert.True(t,
strings.Contains(err.Error(), "context deadline exceeded") ||
strings.Contains(err.Error(), "terminal") ||
strings.Contains(err.Error(), "pty"),
"Expected timeout or terminal-related error, got: %v", err)
}
}
}
func setupTestSSHServerAndClient(t *testing.T) (*sshserver.Server, string, *Client) {
hostKey, err := ssh.GeneratePrivateKey(ssh.ED25519)
require.NoError(t, err)
serverConfig := &sshserver.Config{
HostKeyPEM: hostKey,
JWT: nil,
}
server := sshserver.New(serverConfig)
server.SetAllowRootLogin(true) // Allow root/admin login for tests
serverAddr := sshserver.StartTestServer(t, server)
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
currentUser := testutil.GetTestUsername(t)
client, err := Dial(ctx, serverAddr, currentUser, DialOptions{
InsecureSkipVerify: true,
})
require.NoError(t, err)
return server, serverAddr, client
}
func TestSSHClient_PortForwarding(t *testing.T) {
server, _, client := setupTestSSHServerAndClient(t)
defer func() {
err := server.Stop()
require.NoError(t, err)
}()
defer func() {
err := client.Close()
assert.NoError(t, err)
}()
t.Run("local forwarding times out gracefully", func(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer cancel()
err := client.LocalPortForward(ctx, "127.0.0.1:0", "127.0.0.1:8080")
assert.Error(t, err)
assert.True(t,
errors.Is(err, context.DeadlineExceeded) ||
errors.Is(err, context.Canceled) ||
strings.Contains(err.Error(), "connection"),
"Expected context or connection error")
})
t.Run("remote forwarding denied", func(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
defer cancel()
err := client.RemotePortForward(ctx, "127.0.0.1:0", "127.0.0.1:8080")
assert.Error(t, err)
assert.True(t,
strings.Contains(err.Error(), "denied") ||
strings.Contains(err.Error(), "disabled"),
"Should be denied by default")
})
t.Run("invalid addresses fail", func(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
defer cancel()
err := client.LocalPortForward(ctx, "invalid:address", "127.0.0.1:8080")
assert.Error(t, err)
err = client.LocalPortForward(ctx, "127.0.0.1:0", "invalid:address")
assert.Error(t, err)
})
}
func TestSSHClient_PortForwardingDataTransfer(t *testing.T) {
if testing.Short() {
t.Skip("Skipping data transfer test in short mode")
}
hostKey, err := ssh.GeneratePrivateKey(ssh.ED25519)
require.NoError(t, err)
serverConfig := &sshserver.Config{
HostKeyPEM: hostKey,
JWT: nil,
}
server := sshserver.New(serverConfig)
server.SetAllowLocalPortForwarding(true)
server.SetAllowRootLogin(true) // Allow root/admin login for tests
serverAddr := sshserver.StartTestServer(t, server)
defer func() {
err := server.Stop()
require.NoError(t, err)
}()
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
// Port forwarding requires the actual current user, not test user
realUser, err := getRealCurrentUser()
require.NoError(t, err)
// Skip if running as system account that can't do port forwarding
if testutil.IsSystemAccount(realUser) {
t.Skipf("Skipping port forwarding test - running as system account: %s", realUser)
}
client, err := Dial(ctx, serverAddr, realUser, DialOptions{
InsecureSkipVerify: true, // Skip host key verification for test
})
require.NoError(t, err)
defer func() {
if err := client.Close(); err != nil {
t.Logf("client close error: %v", err)
}
}()
testServer, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
defer func() {
if err := testServer.Close(); err != nil {
t.Logf("test server close error: %v", err)
}
}()
testServerAddr := testServer.Addr().String()
expectedResponse := "Hello, World!"
go func() {
for {
conn, err := testServer.Accept()
if err != nil {
return
}
go func(c net.Conn) {
defer func() {
if err := c.Close(); err != nil {
t.Logf("connection close error: %v", err)
}
}()
buf := make([]byte, 1024)
if _, err := c.Read(buf); err != nil {
t.Logf("connection read error: %v", err)
return
}
if _, err := c.Write([]byte(expectedResponse)); err != nil {
t.Logf("connection write error: %v", err)
}
}(conn)
}
}()
localListener, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
localAddr := localListener.Addr().String()
if err := localListener.Close(); err != nil {
t.Logf("local listener close error: %v", err)
}
ctx, cancel = context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
go func() {
err := client.LocalPortForward(ctx, localAddr, testServerAddr)
if err != nil && !errors.Is(err, context.Canceled) {
if isWindowsPrivilegeError(err) {
t.Logf("Port forward failed due to Windows privilege restrictions: %v", err)
} else {
t.Logf("Port forward error: %v", err)
}
}
}()
time.Sleep(100 * time.Millisecond)
conn, err := net.DialTimeout("tcp", localAddr, 2*time.Second)
require.NoError(t, err)
defer func() {
if err := conn.Close(); err != nil {
t.Logf("connection close error: %v", err)
}
}()
_, err = conn.Write([]byte("test"))
require.NoError(t, err)
if err := conn.SetReadDeadline(time.Now().Add(2 * time.Second)); err != nil {
t.Logf("set read deadline error: %v", err)
}
response := make([]byte, len(expectedResponse))
n, err := io.ReadFull(conn, response)
require.NoError(t, err)
assert.Equal(t, len(expectedResponse), n)
assert.Equal(t, expectedResponse, string(response))
}
// getRealCurrentUser returns the actual current user (not test user) for features like port forwarding
func getRealCurrentUser() (string, error) {
if runtime.GOOS == "windows" {
if currentUser, err := user.Current(); err == nil {
return currentUser.Username, nil
}
}
if username := os.Getenv("USER"); username != "" {
return username, nil
}
if currentUser, err := user.Current(); err == nil {
return currentUser.Username, nil
}
return "", fmt.Errorf("unable to determine current user")
}
// isWindowsPrivilegeError checks if an error is related to Windows privilege restrictions
func isWindowsPrivilegeError(err error) bool {
if err == nil {
return false
}
errStr := strings.ToLower(err.Error())
return strings.Contains(errStr, "ntstatus=0xc0000062") || // STATUS_PRIVILEGE_NOT_HELD
strings.Contains(errStr, "0xc0000041") || // STATUS_PRIVILEGE_NOT_HELD (LsaRegisterLogonProcess)
strings.Contains(errStr, "0xc0000062") || // STATUS_PRIVILEGE_NOT_HELD (LsaLogonUser)
strings.Contains(errStr, "privilege") ||
strings.Contains(errStr, "access denied") ||
strings.Contains(errStr, "user authentication failed")
}

View File

@@ -0,0 +1,127 @@
//go:build !windows
package client
import (
"context"
"fmt"
"os"
"os/signal"
"syscall"
log "github.com/sirupsen/logrus"
"golang.org/x/crypto/ssh"
"golang.org/x/term"
)
func (c *Client) setupTerminalMode(ctx context.Context, session *ssh.Session) error {
stdinFd := int(os.Stdin.Fd())
if !term.IsTerminal(stdinFd) {
return c.setupNonTerminalMode(ctx, session)
}
fd := int(os.Stdin.Fd())
state, err := term.MakeRaw(fd)
if err != nil {
return c.setupNonTerminalMode(ctx, session)
}
if err := c.setupTerminal(session, fd); err != nil {
if restoreErr := term.Restore(fd, state); restoreErr != nil {
log.Debugf("restore terminal state: %v", restoreErr)
}
return err
}
c.terminalState = state
c.terminalFd = fd
sigChan := make(chan os.Signal, 1)
signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM, syscall.SIGQUIT)
go func() {
defer signal.Stop(sigChan)
select {
case <-ctx.Done():
if err := term.Restore(fd, state); err != nil {
log.Debugf("restore terminal state: %v", err)
}
case sig := <-sigChan:
if err := term.Restore(fd, state); err != nil {
log.Debugf("restore terminal state: %v", err)
}
signal.Reset(sig)
s, ok := sig.(syscall.Signal)
if !ok {
log.Debugf("signal %v is not a syscall.Signal: %T", sig, sig)
return
}
if err := syscall.Kill(syscall.Getpid(), s); err != nil {
log.Debugf("kill process with signal %v: %v", s, err)
}
}
}()
return nil
}
func (c *Client) setupNonTerminalMode(_ context.Context, session *ssh.Session) error {
return nil
}
// restoreWindowsConsoleState is a no-op on Unix systems
func (c *Client) restoreWindowsConsoleState() error {
return nil
}
func (c *Client) setupTerminal(session *ssh.Session, fd int) error {
w, h, err := term.GetSize(fd)
if err != nil {
return fmt.Errorf("get terminal size: %w", err)
}
modes := ssh.TerminalModes{
ssh.ECHO: 1,
ssh.TTY_OP_ISPEED: 14400,
ssh.TTY_OP_OSPEED: 14400,
// Ctrl+C
ssh.VINTR: 3,
// Ctrl+\
ssh.VQUIT: 28,
// Backspace
ssh.VERASE: 127,
// Ctrl+U
ssh.VKILL: 21,
// Ctrl+D
ssh.VEOF: 4,
ssh.VEOL: 0,
ssh.VEOL2: 0,
// Ctrl+Q
ssh.VSTART: 17,
// Ctrl+S
ssh.VSTOP: 19,
// Ctrl+Z
ssh.VSUSP: 26,
// Ctrl+O
ssh.VDISCARD: 15,
// Ctrl+R
ssh.VREPRINT: 18,
// Ctrl+W
ssh.VWERASE: 23,
// Ctrl+V
ssh.VLNEXT: 22,
}
terminal := os.Getenv("TERM")
if terminal == "" {
terminal = "xterm-256color"
}
if err := session.RequestPty(terminal, h, w, modes); err != nil {
return fmt.Errorf("request pty: %w", err)
}
return nil
}

View File

@@ -0,0 +1,265 @@
package client
import (
"context"
"errors"
"fmt"
"os"
"syscall"
"unsafe"
log "github.com/sirupsen/logrus"
"golang.org/x/crypto/ssh"
)
const (
enableProcessedInput = 0x0001
enableLineInput = 0x0002
enableEchoInput = 0x0004 // Input mode: ENABLE_ECHO_INPUT
enableVirtualTerminalProcessing = 0x0004 // Output mode: ENABLE_VIRTUAL_TERMINAL_PROCESSING (same value, different mode)
enableVirtualTerminalInput = 0x0200
)
var (
kernel32 = syscall.NewLazyDLL("kernel32.dll")
procGetConsoleMode = kernel32.NewProc("GetConsoleMode")
procSetConsoleMode = kernel32.NewProc("SetConsoleMode")
procGetConsoleScreenBufferInfo = kernel32.NewProc("GetConsoleScreenBufferInfo")
)
// ConsoleUnavailableError indicates that Windows console handles are not available
// (e.g., in CI environments where stdout/stdin are redirected)
type ConsoleUnavailableError struct {
Operation string
Err error
}
func (e *ConsoleUnavailableError) Error() string {
return fmt.Sprintf("console unavailable for %s: %v", e.Operation, e.Err)
}
func (e *ConsoleUnavailableError) Unwrap() error {
return e.Err
}
type coord struct {
x, y int16
}
type smallRect struct {
left, top, right, bottom int16
}
type consoleScreenBufferInfo struct {
size coord
cursorPosition coord
attributes uint16
window smallRect
maximumWindowSize coord
}
func (c *Client) setupTerminalMode(_ context.Context, session *ssh.Session) error {
if err := c.saveWindowsConsoleState(); err != nil {
var consoleErr *ConsoleUnavailableError
if errors.As(err, &consoleErr) {
log.Debugf("console unavailable, not requesting PTY: %v", err)
return nil
}
return fmt.Errorf("save console state: %w", err)
}
if err := c.enableWindowsVirtualTerminal(); err != nil {
var consoleErr *ConsoleUnavailableError
if errors.As(err, &consoleErr) {
log.Debugf("virtual terminal unavailable: %v", err)
} else {
return fmt.Errorf("failed to enable virtual terminal: %w", err)
}
}
w, h := c.getWindowsConsoleSize()
modes := ssh.TerminalModes{
ssh.ECHO: 1,
ssh.TTY_OP_ISPEED: 14400,
ssh.TTY_OP_OSPEED: 14400,
ssh.ICRNL: 1,
ssh.OPOST: 1,
ssh.ONLCR: 1,
ssh.ISIG: 1,
ssh.ICANON: 1,
ssh.VINTR: 3, // Ctrl+C
ssh.VQUIT: 28, // Ctrl+\
ssh.VERASE: 127, // Backspace
ssh.VKILL: 21, // Ctrl+U
ssh.VEOF: 4, // Ctrl+D
ssh.VEOL: 0,
ssh.VEOL2: 0,
ssh.VSTART: 17, // Ctrl+Q
ssh.VSTOP: 19, // Ctrl+S
ssh.VSUSP: 26, // Ctrl+Z
ssh.VDISCARD: 15, // Ctrl+O
ssh.VWERASE: 23, // Ctrl+W
ssh.VLNEXT: 22, // Ctrl+V
ssh.VREPRINT: 18, // Ctrl+R
}
if err := session.RequestPty("xterm-256color", h, w, modes); err != nil {
if restoreErr := c.restoreWindowsConsoleState(); restoreErr != nil {
log.Debugf("restore Windows console state: %v", restoreErr)
}
return fmt.Errorf("request pty: %w", err)
}
return nil
}
func (c *Client) saveWindowsConsoleState() error {
defer func() {
if r := recover(); r != nil {
log.Debugf("panic in saveWindowsConsoleState: %v", r)
}
}()
stdout := syscall.Handle(os.Stdout.Fd())
stdin := syscall.Handle(os.Stdin.Fd())
var stdoutMode, stdinMode uint32
ret, _, err := procGetConsoleMode.Call(uintptr(stdout), uintptr(unsafe.Pointer(&stdoutMode)))
if ret == 0 {
log.Debugf("failed to get stdout console mode: %v", err)
return &ConsoleUnavailableError{
Operation: "get stdout console mode",
Err: err,
}
}
ret, _, err = procGetConsoleMode.Call(uintptr(stdin), uintptr(unsafe.Pointer(&stdinMode)))
if ret == 0 {
log.Debugf("failed to get stdin console mode: %v", err)
return &ConsoleUnavailableError{
Operation: "get stdin console mode",
Err: err,
}
}
c.terminalFd = 1
c.windowsStdoutMode = stdoutMode
c.windowsStdinMode = stdinMode
log.Debugf("saved Windows console state - stdout: 0x%04x, stdin: 0x%04x", stdoutMode, stdinMode)
return nil
}
func (c *Client) enableWindowsVirtualTerminal() (err error) {
defer func() {
if r := recover(); r != nil {
err = fmt.Errorf("panic in enableWindowsVirtualTerminal: %v", r)
}
}()
stdout := syscall.Handle(os.Stdout.Fd())
stdin := syscall.Handle(os.Stdin.Fd())
var mode uint32
ret, _, winErr := procGetConsoleMode.Call(uintptr(stdout), uintptr(unsafe.Pointer(&mode)))
if ret == 0 {
return &ConsoleUnavailableError{
Operation: "get stdout console mode for VT",
Err: winErr,
}
}
mode |= enableVirtualTerminalProcessing
ret, _, winErr = procSetConsoleMode.Call(uintptr(stdout), uintptr(mode))
if ret == 0 {
return &ConsoleUnavailableError{
Operation: "enable virtual terminal processing",
Err: winErr,
}
}
ret, _, winErr = procGetConsoleMode.Call(uintptr(stdin), uintptr(unsafe.Pointer(&mode)))
if ret == 0 {
return &ConsoleUnavailableError{
Operation: "get stdin console mode for VT",
Err: winErr,
}
}
mode &= ^uint32(enableLineInput | enableEchoInput | enableProcessedInput)
mode |= enableVirtualTerminalInput
ret, _, winErr = procSetConsoleMode.Call(uintptr(stdin), uintptr(mode))
if ret == 0 {
return &ConsoleUnavailableError{
Operation: "set stdin raw mode",
Err: winErr,
}
}
log.Debugf("enabled Windows virtual terminal processing")
return nil
}
func (c *Client) getWindowsConsoleSize() (int, int) {
defer func() {
if r := recover(); r != nil {
log.Debugf("panic in getWindowsConsoleSize: %v", r)
}
}()
stdout := syscall.Handle(os.Stdout.Fd())
var csbi consoleScreenBufferInfo
ret, _, err := procGetConsoleScreenBufferInfo.Call(uintptr(stdout), uintptr(unsafe.Pointer(&csbi)))
if ret == 0 {
log.Debugf("failed to get console buffer info, using defaults: %v", err)
return 80, 24
}
width := int(csbi.window.right - csbi.window.left + 1)
height := int(csbi.window.bottom - csbi.window.top + 1)
log.Debugf("Windows console size: %dx%d", width, height)
return width, height
}
func (c *Client) restoreWindowsConsoleState() error {
var err error
defer func() {
if r := recover(); r != nil {
err = fmt.Errorf("panic in restoreWindowsConsoleState: %v", r)
}
}()
if c.terminalFd != 1 {
return nil
}
stdout := syscall.Handle(os.Stdout.Fd())
stdin := syscall.Handle(os.Stdin.Fd())
ret, _, winErr := procSetConsoleMode.Call(uintptr(stdout), uintptr(c.windowsStdoutMode))
if ret == 0 {
log.Debugf("failed to restore stdout console mode: %v", winErr)
if err == nil {
err = fmt.Errorf("restore stdout console mode: %w", winErr)
}
}
ret, _, winErr = procSetConsoleMode.Call(uintptr(stdin), uintptr(c.windowsStdinMode))
if ret == 0 {
log.Debugf("failed to restore stdin console mode: %v", winErr)
if err == nil {
err = fmt.Errorf("restore stdin console mode: %w", winErr)
}
}
c.terminalFd = 0
c.windowsStdoutMode = 0
c.windowsStdinMode = 0
log.Debugf("restored Windows console state")
return err
}

171
client/ssh/common.go Normal file
View File

@@ -0,0 +1,171 @@
package ssh
import (
"bytes"
"context"
"errors"
"fmt"
"io"
"net"
"time"
log "github.com/sirupsen/logrus"
"golang.org/x/crypto/ssh"
"github.com/netbirdio/netbird/client/proto"
)
const (
NetBirdSSHConfigFile = "99-netbird.conf"
UnixSSHConfigDir = "/etc/ssh/ssh_config.d"
WindowsSSHConfigDir = "ssh/ssh_config.d"
)
var (
// ErrPeerNotFound indicates the peer was not found in the network
ErrPeerNotFound = errors.New("peer not found in network")
// ErrNoStoredKey indicates the peer has no stored SSH host key
ErrNoStoredKey = errors.New("peer has no stored SSH host key")
)
// HostKeyVerifier provides SSH host key verification
type HostKeyVerifier interface {
VerifySSHHostKey(peerAddress string, key []byte) error
}
// DaemonHostKeyVerifier implements HostKeyVerifier using the NetBird daemon
type DaemonHostKeyVerifier struct {
client proto.DaemonServiceClient
}
// NewDaemonHostKeyVerifier creates a new daemon-based host key verifier
func NewDaemonHostKeyVerifier(client proto.DaemonServiceClient) *DaemonHostKeyVerifier {
return &DaemonHostKeyVerifier{
client: client,
}
}
// VerifySSHHostKey verifies an SSH host key by querying the NetBird daemon
func (d *DaemonHostKeyVerifier) VerifySSHHostKey(peerAddress string, presentedKey []byte) error {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
response, err := d.client.GetPeerSSHHostKey(ctx, &proto.GetPeerSSHHostKeyRequest{
PeerAddress: peerAddress,
})
if err != nil {
return err
}
if !response.GetFound() {
return ErrPeerNotFound
}
storedKeyData := response.GetSshHostKey()
return VerifyHostKey(storedKeyData, presentedKey, peerAddress)
}
// RequestJWTToken requests or retrieves a JWT token for SSH authentication
func RequestJWTToken(ctx context.Context, client proto.DaemonServiceClient, stdout, stderr io.Writer, useCache bool, hint string) (string, error) {
req := &proto.RequestJWTAuthRequest{}
if hint != "" {
req.Hint = &hint
}
authResponse, err := client.RequestJWTAuth(ctx, req)
if err != nil {
return "", fmt.Errorf("request JWT auth: %w", err)
}
if useCache && authResponse.CachedToken != "" {
log.Debug("Using cached authentication token")
return authResponse.CachedToken, nil
}
if stderr != nil {
_, _ = fmt.Fprintln(stderr, "SSH authentication required.")
_, _ = fmt.Fprintf(stderr, "Please visit: %s\n", authResponse.VerificationURIComplete)
if authResponse.UserCode != "" {
_, _ = fmt.Fprintf(stderr, "Or visit: %s and enter code: %s\n", authResponse.VerificationURI, authResponse.UserCode)
}
_, _ = fmt.Fprintln(stderr, "Waiting for authentication...")
}
tokenResponse, err := client.WaitJWTToken(ctx, &proto.WaitJWTTokenRequest{
DeviceCode: authResponse.DeviceCode,
UserCode: authResponse.UserCode,
})
if err != nil {
return "", fmt.Errorf("wait for JWT token: %w", err)
}
if stdout != nil {
_, _ = fmt.Fprintln(stdout, "Authentication successful!")
}
return tokenResponse.Token, nil
}
// VerifyHostKey verifies an SSH host key against stored peer key data.
// Returns nil only if the presented key matches the stored key.
// Returns ErrNoStoredKey if storedKeyData is empty.
// Returns an error if the keys don't match or if parsing fails.
func VerifyHostKey(storedKeyData []byte, presentedKey []byte, peerAddress string) error {
if len(storedKeyData) == 0 {
return ErrNoStoredKey
}
storedPubKey, _, _, _, err := ssh.ParseAuthorizedKey(storedKeyData)
if err != nil {
return fmt.Errorf("parse stored SSH key for %s: %w", peerAddress, err)
}
if !bytes.Equal(presentedKey, storedPubKey.Marshal()) {
return fmt.Errorf("SSH host key mismatch for %s", peerAddress)
}
return nil
}
// AddJWTAuth prepends JWT password authentication to existing auth methods.
// This ensures JWT auth is tried first while preserving any existing auth methods.
func AddJWTAuth(config *ssh.ClientConfig, jwtToken string) *ssh.ClientConfig {
configWithJWT := *config
configWithJWT.Auth = append([]ssh.AuthMethod{ssh.Password(jwtToken)}, config.Auth...)
return &configWithJWT
}
// CreateHostKeyCallback creates an SSH host key verification callback using the provided verifier.
// It tries multiple addresses (hostname, IP) for the peer before failing.
func CreateHostKeyCallback(verifier HostKeyVerifier) ssh.HostKeyCallback {
return func(hostname string, remote net.Addr, key ssh.PublicKey) error {
addresses := buildAddressList(hostname, remote)
presentedKey := key.Marshal()
for _, addr := range addresses {
if err := verifier.VerifySSHHostKey(addr, presentedKey); err != nil {
if errors.Is(err, ErrPeerNotFound) {
// Try other addresses for this peer
continue
}
return err
}
// Verified
return nil
}
return fmt.Errorf("SSH host key verification failed: peer %s not found in network", hostname)
}
}
// buildAddressList creates a list of addresses to check for host key verification.
// It includes the original hostname and extracts the host part from the remote address if different.
func buildAddressList(hostname string, remote net.Addr) []string {
addresses := []string{hostname}
if host, _, err := net.SplitHostPort(remote.String()); err == nil {
if host != hostname {
addresses = append(addresses, host)
}
}
return addresses
}

View File

@@ -0,0 +1,282 @@
package config
import (
"context"
"fmt"
"os"
"path/filepath"
"runtime"
"strconv"
"strings"
"time"
log "github.com/sirupsen/logrus"
nbssh "github.com/netbirdio/netbird/client/ssh"
)
const (
EnvDisableSSHConfig = "NB_DISABLE_SSH_CONFIG"
EnvForceSSHConfig = "NB_FORCE_SSH_CONFIG"
MaxPeersForSSHConfig = 200
fileWriteTimeout = 2 * time.Second
)
func isSSHConfigDisabled() bool {
value := os.Getenv(EnvDisableSSHConfig)
if value == "" {
return false
}
disabled, err := strconv.ParseBool(value)
if err != nil {
return true
}
return disabled
}
func isSSHConfigForced() bool {
value := os.Getenv(EnvForceSSHConfig)
if value == "" {
return false
}
forced, err := strconv.ParseBool(value)
if err != nil {
return true
}
return forced
}
// shouldGenerateSSHConfig checks if SSH config should be generated based on peer count
func shouldGenerateSSHConfig(peerCount int) bool {
if isSSHConfigDisabled() {
return false
}
if isSSHConfigForced() {
return true
}
return peerCount <= MaxPeersForSSHConfig
}
// writeFileWithTimeout writes data to a file with a timeout
func writeFileWithTimeout(filename string, data []byte, perm os.FileMode) error {
ctx, cancel := context.WithTimeout(context.Background(), fileWriteTimeout)
defer cancel()
done := make(chan error, 1)
go func() {
done <- os.WriteFile(filename, data, perm)
}()
select {
case err := <-done:
return err
case <-ctx.Done():
return fmt.Errorf("file write timeout after %v: %s", fileWriteTimeout, filename)
}
}
// Manager handles SSH client configuration for NetBird peers
type Manager struct {
sshConfigDir string
sshConfigFile string
}
// PeerSSHInfo represents a peer's SSH configuration information
type PeerSSHInfo struct {
Hostname string
IP string
FQDN string
}
// New creates a new SSH config manager
func New() *Manager {
sshConfigDir := getSystemSSHConfigDir()
return &Manager{
sshConfigDir: sshConfigDir,
sshConfigFile: nbssh.NetBirdSSHConfigFile,
}
}
// getSystemSSHConfigDir returns platform-specific SSH configuration directory
func getSystemSSHConfigDir() string {
if runtime.GOOS == "windows" {
return getWindowsSSHConfigDir()
}
return nbssh.UnixSSHConfigDir
}
func getWindowsSSHConfigDir() string {
programData := os.Getenv("PROGRAMDATA")
if programData == "" {
programData = `C:\ProgramData`
}
return filepath.Join(programData, nbssh.WindowsSSHConfigDir)
}
// SetupSSHClientConfig creates SSH client configuration for NetBird peers
func (m *Manager) SetupSSHClientConfig(peers []PeerSSHInfo) error {
if !shouldGenerateSSHConfig(len(peers)) {
m.logSkipReason(len(peers))
return nil
}
sshConfig, err := m.buildSSHConfig(peers)
if err != nil {
return fmt.Errorf("build SSH config: %w", err)
}
return m.writeSSHConfig(sshConfig)
}
func (m *Manager) logSkipReason(peerCount int) {
if isSSHConfigDisabled() {
log.Debugf("SSH config management disabled via %s", EnvDisableSSHConfig)
} else {
log.Infof("SSH config generation skipped: too many peers (%d > %d). Use %s=true to force.",
peerCount, MaxPeersForSSHConfig, EnvForceSSHConfig)
}
}
func (m *Manager) buildSSHConfig(peers []PeerSSHInfo) (string, error) {
sshConfig := m.buildConfigHeader()
var allHostPatterns []string
for _, peer := range peers {
hostPatterns := m.buildHostPatterns(peer)
allHostPatterns = append(allHostPatterns, hostPatterns...)
}
if len(allHostPatterns) > 0 {
peerConfig, err := m.buildPeerConfig(allHostPatterns)
if err != nil {
return "", err
}
sshConfig += peerConfig
}
return sshConfig, nil
}
func (m *Manager) buildConfigHeader() string {
return "# NetBird SSH client configuration\n" +
"# Generated automatically - do not edit manually\n" +
"#\n" +
"# To disable SSH config management, use:\n" +
"# netbird service reconfigure --service-env NB_DISABLE_SSH_CONFIG=true\n" +
"#\n\n"
}
func (m *Manager) buildPeerConfig(allHostPatterns []string) (string, error) {
uniquePatterns := make(map[string]bool)
var deduplicatedPatterns []string
for _, pattern := range allHostPatterns {
if !uniquePatterns[pattern] {
uniquePatterns[pattern] = true
deduplicatedPatterns = append(deduplicatedPatterns, pattern)
}
}
execPath, err := m.getNetBirdExecutablePath()
if err != nil {
return "", fmt.Errorf("get NetBird executable path: %w", err)
}
hostLine := strings.Join(deduplicatedPatterns, " ")
config := fmt.Sprintf("Host %s\n", hostLine)
if runtime.GOOS == "windows" {
config += fmt.Sprintf(" Match exec \"%s ssh detect %%h %%p\"\n", execPath)
} else {
config += fmt.Sprintf(" Match exec \"%s ssh detect %%h %%p 2>/dev/null\"\n", execPath)
}
config += " PreferredAuthentications password,publickey,keyboard-interactive\n"
config += " PasswordAuthentication yes\n"
config += " PubkeyAuthentication yes\n"
config += " BatchMode no\n"
config += fmt.Sprintf(" ProxyCommand %s ssh proxy %%h %%p\n", execPath)
config += " StrictHostKeyChecking no\n"
if runtime.GOOS == "windows" {
config += " UserKnownHostsFile NUL\n"
} else {
config += " UserKnownHostsFile /dev/null\n"
}
config += " CheckHostIP no\n"
config += " LogLevel ERROR\n\n"
return config, nil
}
func (m *Manager) buildHostPatterns(peer PeerSSHInfo) []string {
var hostPatterns []string
if peer.IP != "" {
hostPatterns = append(hostPatterns, peer.IP)
}
if peer.FQDN != "" {
hostPatterns = append(hostPatterns, peer.FQDN)
}
if peer.Hostname != "" && peer.Hostname != peer.FQDN {
hostPatterns = append(hostPatterns, peer.Hostname)
}
return hostPatterns
}
func (m *Manager) writeSSHConfig(sshConfig string) error {
sshConfigPath := filepath.Join(m.sshConfigDir, m.sshConfigFile)
if err := os.MkdirAll(m.sshConfigDir, 0755); err != nil {
return fmt.Errorf("create SSH config directory %s: %w", m.sshConfigDir, err)
}
if err := writeFileWithTimeout(sshConfigPath, []byte(sshConfig), 0644); err != nil {
return fmt.Errorf("write SSH config file %s: %w", sshConfigPath, err)
}
log.Infof("Created NetBird SSH client config: %s", sshConfigPath)
return nil
}
// RemoveSSHClientConfig removes NetBird SSH configuration
func (m *Manager) RemoveSSHClientConfig() error {
sshConfigPath := filepath.Join(m.sshConfigDir, m.sshConfigFile)
err := os.Remove(sshConfigPath)
if err != nil && !os.IsNotExist(err) {
return fmt.Errorf("remove SSH config %s: %w", sshConfigPath, err)
}
if err == nil {
log.Infof("Removed NetBird SSH config: %s", sshConfigPath)
}
return nil
}
func (m *Manager) getNetBirdExecutablePath() (string, error) {
execPath, err := os.Executable()
if err != nil {
return "", fmt.Errorf("retrieve executable path: %w", err)
}
realPath, err := filepath.EvalSymlinks(execPath)
if err != nil {
log.Debugf("symlink resolution failed: %v", err)
return execPath, nil
}
return realPath, nil
}
// GetSSHConfigDir returns the SSH config directory path
func (m *Manager) GetSSHConfigDir() string {
return m.sshConfigDir
}
// GetSSHConfigFile returns the SSH config file name
func (m *Manager) GetSSHConfigFile() string {
return m.sshConfigFile
}

View File

@@ -0,0 +1,159 @@
package config
import (
"fmt"
"os"
"path/filepath"
"runtime"
"strings"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestManager_SetupSSHClientConfig(t *testing.T) {
// Create temporary directory for test
tempDir, err := os.MkdirTemp("", "netbird-ssh-config-test")
require.NoError(t, err)
defer func() { assert.NoError(t, os.RemoveAll(tempDir)) }()
// Override manager paths to use temp directory
manager := &Manager{
sshConfigDir: filepath.Join(tempDir, "ssh_config.d"),
sshConfigFile: "99-netbird.conf",
}
// Test SSH config generation with peers
peers := []PeerSSHInfo{
{
Hostname: "peer1",
IP: "100.125.1.1",
FQDN: "peer1.nb.internal",
},
{
Hostname: "peer2",
IP: "100.125.1.2",
FQDN: "peer2.nb.internal",
},
}
err = manager.SetupSSHClientConfig(peers)
require.NoError(t, err)
// Read generated config
configPath := filepath.Join(manager.sshConfigDir, manager.sshConfigFile)
content, err := os.ReadFile(configPath)
require.NoError(t, err)
configStr := string(content)
// Verify the basic SSH config structure exists
assert.Contains(t, configStr, "# NetBird SSH client configuration")
assert.Contains(t, configStr, "Generated automatically - do not edit manually")
// Check that peer hostnames are included
assert.Contains(t, configStr, "100.125.1.1")
assert.Contains(t, configStr, "100.125.1.2")
assert.Contains(t, configStr, "peer1.nb.internal")
assert.Contains(t, configStr, "peer2.nb.internal")
// Check platform-specific UserKnownHostsFile
if runtime.GOOS == "windows" {
assert.Contains(t, configStr, "UserKnownHostsFile NUL")
} else {
assert.Contains(t, configStr, "UserKnownHostsFile /dev/null")
}
}
func TestGetSystemSSHConfigDir(t *testing.T) {
configDir := getSystemSSHConfigDir()
// Path should not be empty
assert.NotEmpty(t, configDir)
// Should be an absolute path
assert.True(t, filepath.IsAbs(configDir))
// On Unix systems, should start with /etc
// On Windows, should contain ProgramData
if runtime.GOOS == "windows" {
assert.Contains(t, strings.ToLower(configDir), "programdata")
} else {
assert.Contains(t, configDir, "/etc/ssh")
}
}
func TestManager_PeerLimit(t *testing.T) {
// Create temporary directory for test
tempDir, err := os.MkdirTemp("", "netbird-ssh-config-test")
require.NoError(t, err)
defer func() { assert.NoError(t, os.RemoveAll(tempDir)) }()
// Override manager paths to use temp directory
manager := &Manager{
sshConfigDir: filepath.Join(tempDir, "ssh_config.d"),
sshConfigFile: "99-netbird.conf",
}
// Generate many peers (more than limit)
var peers []PeerSSHInfo
for i := 0; i < MaxPeersForSSHConfig+10; i++ {
peers = append(peers, PeerSSHInfo{
Hostname: fmt.Sprintf("peer%d", i),
IP: fmt.Sprintf("100.125.1.%d", i%254+1),
FQDN: fmt.Sprintf("peer%d.nb.internal", i),
})
}
// Test that SSH config generation is skipped when too many peers
err = manager.SetupSSHClientConfig(peers)
require.NoError(t, err)
// Config should not be created due to peer limit
configPath := filepath.Join(manager.sshConfigDir, manager.sshConfigFile)
_, err = os.Stat(configPath)
assert.True(t, os.IsNotExist(err), "SSH config should not be created with too many peers")
}
func TestManager_ForcedSSHConfig(t *testing.T) {
// Set force environment variable
t.Setenv(EnvForceSSHConfig, "true")
// Create temporary directory for test
tempDir, err := os.MkdirTemp("", "netbird-ssh-config-test")
require.NoError(t, err)
defer func() { assert.NoError(t, os.RemoveAll(tempDir)) }()
// Override manager paths to use temp directory
manager := &Manager{
sshConfigDir: filepath.Join(tempDir, "ssh_config.d"),
sshConfigFile: "99-netbird.conf",
}
// Generate many peers (more than limit)
var peers []PeerSSHInfo
for i := 0; i < MaxPeersForSSHConfig+10; i++ {
peers = append(peers, PeerSSHInfo{
Hostname: fmt.Sprintf("peer%d", i),
IP: fmt.Sprintf("100.125.1.%d", i%254+1),
FQDN: fmt.Sprintf("peer%d.nb.internal", i),
})
}
// Test that SSH config generation is forced despite many peers
err = manager.SetupSSHClientConfig(peers)
require.NoError(t, err)
// Config should be created despite peer limit due to force flag
configPath := filepath.Join(manager.sshConfigDir, manager.sshConfigFile)
_, err = os.Stat(configPath)
require.NoError(t, err, "SSH config should be created when forced")
// Verify config contains peer hostnames
content, err := os.ReadFile(configPath)
require.NoError(t, err)
configStr := string(content)
assert.Contains(t, configStr, "peer0.nb.internal")
assert.Contains(t, configStr, "peer1.nb.internal")
}

View File

@@ -0,0 +1,22 @@
package config
// ShutdownState represents SSH configuration state that needs to be cleaned up.
type ShutdownState struct {
SSHConfigDir string
SSHConfigFile string
}
// Name returns the state name for the state manager.
func (s *ShutdownState) Name() string {
return "ssh_config_state"
}
// Cleanup removes SSH client configuration files.
func (s *ShutdownState) Cleanup() error {
manager := &Manager{
sshConfigDir: s.SSHConfigDir,
sshConfigFile: s.SSHConfigFile,
}
return manager.RemoveSSHClientConfig()
}

View File

@@ -0,0 +1,99 @@
package detection
import (
"bufio"
"context"
"net"
"strconv"
"strings"
"time"
log "github.com/sirupsen/logrus"
)
const (
// ServerIdentifier is the base response for NetBird SSH servers
ServerIdentifier = "NetBird-SSH-Server"
// ProxyIdentifier is the base response for NetBird SSH proxy
ProxyIdentifier = "NetBird-SSH-Proxy"
// JWTRequiredMarker is appended to responses when JWT is required
JWTRequiredMarker = "NetBird-JWT-Required"
// Timeout is the timeout for SSH server detection
Timeout = 5 * time.Second
)
type ServerType string
const (
ServerTypeNetBirdJWT ServerType = "netbird-jwt"
ServerTypeNetBirdNoJWT ServerType = "netbird-no-jwt"
ServerTypeRegular ServerType = "regular"
)
// Dialer provides network connection capabilities
type Dialer interface {
DialContext(ctx context.Context, network, address string) (net.Conn, error)
}
// RequiresJWT checks if the server type requires JWT authentication
func (s ServerType) RequiresJWT() bool {
return s == ServerTypeNetBirdJWT
}
// ExitCode returns the exit code for the detect command
func (s ServerType) ExitCode() int {
switch s {
case ServerTypeNetBirdJWT:
return 0
case ServerTypeNetBirdNoJWT:
return 1
case ServerTypeRegular:
return 2
default:
return 2
}
}
// DetectSSHServerType detects SSH server type using the provided dialer
func DetectSSHServerType(ctx context.Context, dialer Dialer, host string, port int) (ServerType, error) {
targetAddr := net.JoinHostPort(host, strconv.Itoa(port))
conn, err := dialer.DialContext(ctx, "tcp", targetAddr)
if err != nil {
log.Debugf("SSH connection failed for detection: %v", err)
return ServerTypeRegular, nil
}
defer conn.Close()
if err := conn.SetReadDeadline(time.Now().Add(Timeout)); err != nil {
log.Debugf("set read deadline: %v", err)
return ServerTypeRegular, nil
}
reader := bufio.NewReader(conn)
serverBanner, err := reader.ReadString('\n')
if err != nil {
log.Debugf("read SSH banner: %v", err)
return ServerTypeRegular, nil
}
serverBanner = strings.TrimSpace(serverBanner)
log.Debugf("SSH server banner: %s", serverBanner)
if !strings.HasPrefix(serverBanner, "SSH-") {
log.Debugf("Invalid SSH banner")
return ServerTypeRegular, nil
}
if !strings.Contains(serverBanner, ServerIdentifier) {
log.Debugf("Server banner does not contain identifier '%s'", ServerIdentifier)
return ServerTypeRegular, nil
}
if strings.Contains(serverBanner, JWTRequiredMarker) {
return ServerTypeNetBirdJWT, nil
}
return ServerTypeNetBirdNoJWT, nil
}

View File

@@ -1,53 +0,0 @@
//go:build !js
package ssh
import (
"fmt"
"net"
"net/netip"
"os"
"os/exec"
"runtime"
"github.com/netbirdio/netbird/util"
)
func isRoot() bool {
return os.Geteuid() == 0
}
func getLoginCmd(user string, remoteAddr net.Addr) (loginPath string, args []string, err error) {
if !isRoot() {
shell := getUserShell(user)
if shell == "" {
shell = "/bin/sh"
}
return shell, []string{"-l"}, nil
}
loginPath, err = exec.LookPath("login")
if err != nil {
return "", nil, err
}
addrPort, err := netip.ParseAddrPort(remoteAddr.String())
if err != nil {
return "", nil, err
}
switch runtime.GOOS {
case "linux":
if util.FileExists("/etc/arch-release") && !util.FileExists("/etc/pam.d/remote") {
return loginPath, []string{"-f", user, "-p"}, nil
}
return loginPath, []string{"-f", user, "-h", addrPort.Addr().String(), "-p"}, nil
case "darwin":
return loginPath, []string{"-fp", "-h", addrPort.Addr().String(), user}, nil
case "freebsd":
return loginPath, []string{"-f", user, "-h", addrPort.Addr().String(), "-p"}, nil
default:
return "", nil, fmt.Errorf("unsupported platform: %s", runtime.GOOS)
}
}

View File

@@ -1,14 +0,0 @@
//go:build !darwin
// +build !darwin
package ssh
import "os/user"
func userNameLookup(username string) (*user.User, error) {
if username == "" || (username == "root" && !isRoot()) {
return user.Current()
}
return user.Lookup(username)
}

View File

@@ -1,51 +0,0 @@
//go:build darwin
// +build darwin
package ssh
import (
"bytes"
"fmt"
"os/exec"
"os/user"
"strings"
)
func userNameLookup(username string) (*user.User, error) {
if username == "" || (username == "root" && !isRoot()) {
return user.Current()
}
var userObject *user.User
userObject, err := user.Lookup(username)
if err != nil && err.Error() == user.UnknownUserError(username).Error() {
return idUserNameLookup(username)
} else if err != nil {
return nil, err
}
return userObject, nil
}
func idUserNameLookup(username string) (*user.User, error) {
cmd := exec.Command("id", "-P", username)
out, err := cmd.CombinedOutput()
if err != nil {
return nil, fmt.Errorf("error while retrieving user with id -P command, error: %v", err)
}
colon := ":"
if !bytes.Contains(out, []byte(username+colon)) {
return nil, fmt.Errorf("unable to find user in returned string")
}
// netbird:********:501:20::0:0:netbird:/Users/netbird:/bin/zsh
parts := strings.SplitN(string(out), colon, 10)
userObject := &user.User{
Username: parts[0],
Uid: parts[2],
Gid: parts[3],
Name: parts[7],
HomeDir: parts[8],
}
return userObject, nil
}

392
client/ssh/proxy/proxy.go Normal file
View File

@@ -0,0 +1,392 @@
package proxy
import (
"context"
"errors"
"fmt"
"io"
"net"
"os"
"strconv"
"strings"
"sync"
"time"
"github.com/gliderlabs/ssh"
log "github.com/sirupsen/logrus"
cryptossh "golang.org/x/crypto/ssh"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure"
"github.com/netbirdio/netbird/client/internal/profilemanager"
"github.com/netbirdio/netbird/client/proto"
nbssh "github.com/netbirdio/netbird/client/ssh"
"github.com/netbirdio/netbird/client/ssh/detection"
"github.com/netbirdio/netbird/version"
)
const (
// sshConnectionTimeout is the timeout for SSH TCP connection establishment
sshConnectionTimeout = 120 * time.Second
// sshHandshakeTimeout is the timeout for SSH handshake completion
sshHandshakeTimeout = 30 * time.Second
jwtAuthErrorMsg = "JWT authentication: %w"
)
type SSHProxy struct {
daemonAddr string
targetHost string
targetPort int
stderr io.Writer
conn *grpc.ClientConn
daemonClient proto.DaemonServiceClient
}
func New(daemonAddr, targetHost string, targetPort int, stderr io.Writer) (*SSHProxy, error) {
grpcAddr := strings.TrimPrefix(daemonAddr, "tcp://")
grpcConn, err := grpc.NewClient(grpcAddr, grpc.WithTransportCredentials(insecure.NewCredentials()))
if err != nil {
return nil, fmt.Errorf("connect to daemon: %w", err)
}
return &SSHProxy{
daemonAddr: daemonAddr,
targetHost: targetHost,
targetPort: targetPort,
stderr: stderr,
conn: grpcConn,
daemonClient: proto.NewDaemonServiceClient(grpcConn),
}, nil
}
func (p *SSHProxy) Close() error {
if p.conn != nil {
return p.conn.Close()
}
return nil
}
func (p *SSHProxy) Connect(ctx context.Context) error {
hint := profilemanager.GetLoginHint()
jwtToken, err := nbssh.RequestJWTToken(ctx, p.daemonClient, nil, p.stderr, true, hint)
if err != nil {
return fmt.Errorf(jwtAuthErrorMsg, err)
}
return p.runProxySSHServer(ctx, jwtToken)
}
func (p *SSHProxy) runProxySSHServer(ctx context.Context, jwtToken string) error {
serverVersion := fmt.Sprintf("%s-%s", detection.ProxyIdentifier, version.NetbirdVersion())
sshServer := &ssh.Server{
Handler: func(s ssh.Session) {
p.handleSSHSession(ctx, s, jwtToken)
},
ChannelHandlers: map[string]ssh.ChannelHandler{
"session": ssh.DefaultSessionHandler,
"direct-tcpip": p.directTCPIPHandler,
},
SubsystemHandlers: map[string]ssh.SubsystemHandler{
"sftp": func(s ssh.Session) {
p.sftpSubsystemHandler(s, jwtToken)
},
},
RequestHandlers: map[string]ssh.RequestHandler{
"tcpip-forward": p.tcpipForwardHandler,
"cancel-tcpip-forward": p.cancelTcpipForwardHandler,
},
Version: serverVersion,
}
hostKey, err := generateHostKey()
if err != nil {
return fmt.Errorf("generate host key: %w", err)
}
sshServer.HostSigners = []ssh.Signer{hostKey}
conn := &stdioConn{
stdin: os.Stdin,
stdout: os.Stdout,
}
sshServer.HandleConn(conn)
return nil
}
func (p *SSHProxy) handleSSHSession(ctx context.Context, session ssh.Session, jwtToken string) {
targetAddr := net.JoinHostPort(p.targetHost, strconv.Itoa(p.targetPort))
sshClient, err := p.dialBackend(ctx, targetAddr, session.User(), jwtToken)
if err != nil {
_, _ = fmt.Fprintf(p.stderr, "SSH connection to NetBird server failed: %v\n", err)
return
}
defer func() { _ = sshClient.Close() }()
serverSession, err := sshClient.NewSession()
if err != nil {
_, _ = fmt.Fprintf(p.stderr, "create server session: %v\n", err)
return
}
defer func() { _ = serverSession.Close() }()
serverSession.Stdin = session
serverSession.Stdout = session
serverSession.Stderr = session.Stderr()
ptyReq, winCh, isPty := session.Pty()
if isPty {
if err := serverSession.RequestPty(ptyReq.Term, ptyReq.Window.Width, ptyReq.Window.Height, nil); err != nil {
log.Debugf("PTY request to backend: %v", err)
}
go func() {
for win := range winCh {
if err := serverSession.WindowChange(win.Height, win.Width); err != nil {
log.Debugf("window change: %v", err)
}
}
}()
}
if len(session.Command()) > 0 {
if err := serverSession.Run(strings.Join(session.Command(), " ")); err != nil {
log.Debugf("run command: %v", err)
p.handleProxyExitCode(session, err)
}
return
}
if err = serverSession.Shell(); err != nil {
log.Debugf("start shell: %v", err)
return
}
if err := serverSession.Wait(); err != nil {
log.Debugf("session wait: %v", err)
p.handleProxyExitCode(session, err)
}
}
func (p *SSHProxy) handleProxyExitCode(session ssh.Session, err error) {
var exitErr *cryptossh.ExitError
if errors.As(err, &exitErr) {
if exitErr := session.Exit(exitErr.ExitStatus()); exitErr != nil {
log.Debugf("set exit status: %v", exitErr)
}
}
}
func generateHostKey() (ssh.Signer, error) {
keyPEM, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
if err != nil {
return nil, fmt.Errorf("generate ED25519 key: %w", err)
}
signer, err := cryptossh.ParsePrivateKey(keyPEM)
if err != nil {
return nil, fmt.Errorf("parse private key: %w", err)
}
return signer, nil
}
type stdioConn struct {
stdin io.Reader
stdout io.Writer
closed bool
mu sync.Mutex
}
func (c *stdioConn) Read(b []byte) (n int, err error) {
c.mu.Lock()
if c.closed {
c.mu.Unlock()
return 0, io.EOF
}
c.mu.Unlock()
return c.stdin.Read(b)
}
func (c *stdioConn) Write(b []byte) (n int, err error) {
c.mu.Lock()
if c.closed {
c.mu.Unlock()
return 0, io.ErrClosedPipe
}
c.mu.Unlock()
return c.stdout.Write(b)
}
func (c *stdioConn) Close() error {
c.mu.Lock()
defer c.mu.Unlock()
c.closed = true
return nil
}
func (c *stdioConn) LocalAddr() net.Addr {
return &net.UnixAddr{Name: "stdio", Net: "unix"}
}
func (c *stdioConn) RemoteAddr() net.Addr {
return &net.UnixAddr{Name: "stdio", Net: "unix"}
}
func (c *stdioConn) SetDeadline(_ time.Time) error {
return nil
}
func (c *stdioConn) SetReadDeadline(_ time.Time) error {
return nil
}
func (c *stdioConn) SetWriteDeadline(_ time.Time) error {
return nil
}
func (p *SSHProxy) directTCPIPHandler(_ *ssh.Server, _ *cryptossh.ServerConn, newChan cryptossh.NewChannel, _ ssh.Context) {
_ = newChan.Reject(cryptossh.Prohibited, "port forwarding not supported in proxy")
}
func (p *SSHProxy) sftpSubsystemHandler(s ssh.Session, jwtToken string) {
ctx, cancel := context.WithCancel(s.Context())
defer cancel()
targetAddr := net.JoinHostPort(p.targetHost, strconv.Itoa(p.targetPort))
sshClient, err := p.dialBackend(ctx, targetAddr, s.User(), jwtToken)
if err != nil {
_, _ = fmt.Fprintf(s, "SSH connection failed: %v\n", err)
_ = s.Exit(1)
return
}
defer func() {
if err := sshClient.Close(); err != nil {
log.Debugf("close SSH client: %v", err)
}
}()
serverSession, err := sshClient.NewSession()
if err != nil {
_, _ = fmt.Fprintf(s, "create server session: %v\n", err)
_ = s.Exit(1)
return
}
defer func() {
if err := serverSession.Close(); err != nil {
log.Debugf("close server session: %v", err)
}
}()
stdin, stdout, err := p.setupSFTPPipes(serverSession)
if err != nil {
log.Debugf("setup SFTP pipes: %v", err)
_ = s.Exit(1)
return
}
if err := serverSession.RequestSubsystem("sftp"); err != nil {
_, _ = fmt.Fprintf(s, "SFTP subsystem request failed: %v\n", err)
_ = s.Exit(1)
return
}
p.runSFTPBridge(ctx, s, stdin, stdout, serverSession)
}
func (p *SSHProxy) setupSFTPPipes(serverSession *cryptossh.Session) (io.WriteCloser, io.Reader, error) {
stdin, err := serverSession.StdinPipe()
if err != nil {
return nil, nil, fmt.Errorf("get stdin pipe: %w", err)
}
stdout, err := serverSession.StdoutPipe()
if err != nil {
return nil, nil, fmt.Errorf("get stdout pipe: %w", err)
}
return stdin, stdout, nil
}
func (p *SSHProxy) runSFTPBridge(ctx context.Context, s ssh.Session, stdin io.WriteCloser, stdout io.Reader, serverSession *cryptossh.Session) {
copyErrCh := make(chan error, 2)
go func() {
_, err := io.Copy(stdin, s)
if err != nil {
log.Debugf("SFTP client to server copy: %v", err)
}
if err := stdin.Close(); err != nil {
log.Debugf("close stdin: %v", err)
}
copyErrCh <- err
}()
go func() {
_, err := io.Copy(s, stdout)
if err != nil {
log.Debugf("SFTP server to client copy: %v", err)
}
copyErrCh <- err
}()
go func() {
<-ctx.Done()
if err := serverSession.Close(); err != nil {
log.Debugf("force close server session on context cancellation: %v", err)
}
}()
for i := 0; i < 2; i++ {
if err := <-copyErrCh; err != nil && !errors.Is(err, io.EOF) {
log.Debugf("SFTP copy error: %v", err)
}
}
if err := serverSession.Wait(); err != nil {
log.Debugf("SFTP session ended: %v", err)
}
}
func (p *SSHProxy) tcpipForwardHandler(_ ssh.Context, _ *ssh.Server, _ *cryptossh.Request) (bool, []byte) {
return false, []byte("port forwarding not supported in proxy")
}
func (p *SSHProxy) cancelTcpipForwardHandler(_ ssh.Context, _ *ssh.Server, _ *cryptossh.Request) (bool, []byte) {
return true, nil
}
func (p *SSHProxy) dialBackend(ctx context.Context, addr, user, jwtToken string) (*cryptossh.Client, error) {
config := &cryptossh.ClientConfig{
User: user,
Auth: []cryptossh.AuthMethod{cryptossh.Password(jwtToken)},
Timeout: sshHandshakeTimeout,
HostKeyCallback: p.verifyHostKey,
}
dialer := &net.Dialer{
Timeout: sshConnectionTimeout,
}
conn, err := dialer.DialContext(ctx, "tcp", addr)
if err != nil {
return nil, fmt.Errorf("connect to server: %w", err)
}
clientConn, chans, reqs, err := cryptossh.NewClientConn(conn, addr, config)
if err != nil {
_ = conn.Close()
return nil, fmt.Errorf("SSH handshake: %w", err)
}
return cryptossh.NewClient(clientConn, chans, reqs), nil
}
func (p *SSHProxy) verifyHostKey(hostname string, remote net.Addr, key cryptossh.PublicKey) error {
verifier := nbssh.NewDaemonHostKeyVerifier(p.daemonClient)
callback := nbssh.CreateHostKeyCallback(verifier)
return callback(hostname, remote, key)
}

View File

@@ -0,0 +1,367 @@
package proxy
import (
"context"
"crypto/rand"
"crypto/rsa"
"encoding/base64"
"encoding/json"
"fmt"
"io"
"math/big"
"net"
"net/http"
"net/http/httptest"
"os"
"runtime"
"strconv"
"testing"
"time"
"github.com/golang-jwt/jwt/v5"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
cryptossh "golang.org/x/crypto/ssh"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure"
"github.com/netbirdio/netbird/client/proto"
nbssh "github.com/netbirdio/netbird/client/ssh"
"github.com/netbirdio/netbird/client/ssh/server"
"github.com/netbirdio/netbird/client/ssh/testutil"
nbjwt "github.com/netbirdio/netbird/shared/auth/jwt"
)
func TestMain(m *testing.M) {
if len(os.Args) > 2 && os.Args[1] == "ssh" {
if os.Args[2] == "exec" {
if len(os.Args) > 3 {
cmd := os.Args[3]
if cmd == "echo" && len(os.Args) > 4 {
fmt.Fprintln(os.Stdout, os.Args[4])
os.Exit(0)
}
}
fmt.Fprintf(os.Stderr, "Test binary called as 'ssh exec' with args: %v - preventing infinite recursion\n", os.Args)
os.Exit(1)
}
}
code := m.Run()
testutil.CleanupTestUsers()
os.Exit(code)
}
func TestSSHProxy_verifyHostKey(t *testing.T) {
t.Run("calls daemon to verify host key", func(t *testing.T) {
mockDaemon := startMockDaemon(t)
defer mockDaemon.stop()
grpcConn, err := grpc.NewClient(mockDaemon.addr, grpc.WithTransportCredentials(insecure.NewCredentials()))
require.NoError(t, err)
defer func() { _ = grpcConn.Close() }()
proxy := &SSHProxy{
daemonAddr: mockDaemon.addr,
daemonClient: proto.NewDaemonServiceClient(grpcConn),
}
testKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
require.NoError(t, err)
testPubKey, err := nbssh.GeneratePublicKey(testKey)
require.NoError(t, err)
mockDaemon.setHostKey("test-host", testPubKey)
err = proxy.verifyHostKey("test-host", &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 22}, mustParsePublicKey(t, testPubKey))
assert.NoError(t, err)
})
t.Run("rejects unknown host key", func(t *testing.T) {
mockDaemon := startMockDaemon(t)
defer mockDaemon.stop()
grpcConn, err := grpc.NewClient(mockDaemon.addr, grpc.WithTransportCredentials(insecure.NewCredentials()))
require.NoError(t, err)
defer func() { _ = grpcConn.Close() }()
proxy := &SSHProxy{
daemonAddr: mockDaemon.addr,
daemonClient: proto.NewDaemonServiceClient(grpcConn),
}
unknownKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
require.NoError(t, err)
unknownPubKey, err := nbssh.GeneratePublicKey(unknownKey)
require.NoError(t, err)
err = proxy.verifyHostKey("unknown-host", &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 22}, mustParsePublicKey(t, unknownPubKey))
assert.Error(t, err)
assert.Contains(t, err.Error(), "peer unknown-host not found in network")
})
}
func TestSSHProxy_Connect(t *testing.T) {
if testing.Short() {
t.Skip("Skipping integration test in short mode")
}
// TODO: Windows test times out - user switching and command execution tested on Linux
if runtime.GOOS == "windows" {
t.Skip("Skipping on Windows - covered by Linux tests")
}
const (
issuer = "https://test-issuer.example.com"
audience = "test-audience"
)
jwksServer, privateKey, jwksURL := setupJWKSServer(t)
defer jwksServer.Close()
hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
require.NoError(t, err)
hostPubKey, err := nbssh.GeneratePublicKey(hostKey)
require.NoError(t, err)
serverConfig := &server.Config{
HostKeyPEM: hostKey,
JWT: &server.JWTConfig{
Issuer: issuer,
Audience: audience,
KeysLocation: jwksURL,
},
}
sshServer := server.New(serverConfig)
sshServer.SetAllowRootLogin(true)
sshServerAddr := server.StartTestServer(t, sshServer)
defer func() { _ = sshServer.Stop() }()
mockDaemon := startMockDaemon(t)
defer mockDaemon.stop()
host, portStr, err := net.SplitHostPort(sshServerAddr)
require.NoError(t, err)
port, err := strconv.Atoi(portStr)
require.NoError(t, err)
mockDaemon.setHostKey(host, hostPubKey)
validToken := generateValidJWT(t, privateKey, issuer, audience)
mockDaemon.setJWTToken(validToken)
proxyInstance, err := New(mockDaemon.addr, host, port, nil)
require.NoError(t, err)
clientConn, proxyConn := net.Pipe()
defer func() { _ = clientConn.Close() }()
origStdin := os.Stdin
origStdout := os.Stdout
defer func() {
os.Stdin = origStdin
os.Stdout = origStdout
}()
stdinReader, stdinWriter, err := os.Pipe()
require.NoError(t, err)
stdoutReader, stdoutWriter, err := os.Pipe()
require.NoError(t, err)
os.Stdin = stdinReader
os.Stdout = stdoutWriter
go func() {
_, _ = io.Copy(stdinWriter, proxyConn)
}()
go func() {
_, _ = io.Copy(proxyConn, stdoutReader)
}()
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
connectErrCh := make(chan error, 1)
go func() {
connectErrCh <- proxyInstance.Connect(ctx)
}()
sshConfig := &cryptossh.ClientConfig{
User: testutil.GetTestUsername(t),
Auth: []cryptossh.AuthMethod{},
HostKeyCallback: cryptossh.InsecureIgnoreHostKey(),
Timeout: 3 * time.Second,
}
sshClientConn, chans, reqs, err := cryptossh.NewClientConn(clientConn, "test", sshConfig)
require.NoError(t, err, "Should connect to proxy server")
defer func() { _ = sshClientConn.Close() }()
sshClient := cryptossh.NewClient(sshClientConn, chans, reqs)
session, err := sshClient.NewSession()
require.NoError(t, err, "Should create session through full proxy to backend")
outputCh := make(chan []byte, 1)
errCh := make(chan error, 1)
go func() {
output, err := session.Output("echo hello-from-proxy")
outputCh <- output
errCh <- err
}()
select {
case output := <-outputCh:
err := <-errCh
require.NoError(t, err, "Command should execute successfully through proxy")
assert.Contains(t, string(output), "hello-from-proxy", "Should receive command output through proxy")
case <-time.After(3 * time.Second):
t.Fatal("Command execution timed out")
}
_ = session.Close()
_ = sshClient.Close()
_ = clientConn.Close()
cancel()
}
type mockDaemonServer struct {
proto.UnimplementedDaemonServiceServer
hostKeys map[string][]byte
jwtToken string
}
func (m *mockDaemonServer) GetPeerSSHHostKey(ctx context.Context, req *proto.GetPeerSSHHostKeyRequest) (*proto.GetPeerSSHHostKeyResponse, error) {
key, found := m.hostKeys[req.PeerAddress]
return &proto.GetPeerSSHHostKeyResponse{
Found: found,
SshHostKey: key,
}, nil
}
func (m *mockDaemonServer) RequestJWTAuth(ctx context.Context, req *proto.RequestJWTAuthRequest) (*proto.RequestJWTAuthResponse, error) {
return &proto.RequestJWTAuthResponse{
CachedToken: m.jwtToken,
}, nil
}
func (m *mockDaemonServer) WaitJWTToken(ctx context.Context, req *proto.WaitJWTTokenRequest) (*proto.WaitJWTTokenResponse, error) {
return &proto.WaitJWTTokenResponse{
Token: m.jwtToken,
}, nil
}
type mockDaemon struct {
addr string
server *grpc.Server
impl *mockDaemonServer
}
func startMockDaemon(t *testing.T) *mockDaemon {
t.Helper()
listener, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
impl := &mockDaemonServer{
hostKeys: make(map[string][]byte),
jwtToken: "test-jwt-token",
}
grpcServer := grpc.NewServer()
proto.RegisterDaemonServiceServer(grpcServer, impl)
go func() {
_ = grpcServer.Serve(listener)
}()
return &mockDaemon{
addr: listener.Addr().String(),
server: grpcServer,
impl: impl,
}
}
func (m *mockDaemon) setHostKey(addr string, pubKey []byte) {
m.impl.hostKeys[addr] = pubKey
}
func (m *mockDaemon) setJWTToken(token string) {
m.impl.jwtToken = token
}
func (m *mockDaemon) stop() {
if m.server != nil {
m.server.Stop()
}
}
func mustParsePublicKey(t *testing.T, pubKeyBytes []byte) cryptossh.PublicKey {
t.Helper()
pubKey, _, _, _, err := cryptossh.ParseAuthorizedKey(pubKeyBytes)
require.NoError(t, err)
return pubKey
}
func setupJWKSServer(t *testing.T) (*httptest.Server, *rsa.PrivateKey, string) {
t.Helper()
privateKey, jwksJSON := generateTestJWKS(t)
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
if _, err := w.Write(jwksJSON); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
}
}))
return server, privateKey, server.URL
}
func generateTestJWKS(t *testing.T) (*rsa.PrivateKey, []byte) {
t.Helper()
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
require.NoError(t, err)
publicKey := &privateKey.PublicKey
n := publicKey.N.Bytes()
e := publicKey.E
jwk := nbjwt.JSONWebKey{
Kty: "RSA",
Kid: "test-key-id",
Use: "sig",
N: base64.RawURLEncoding.EncodeToString(n),
E: base64.RawURLEncoding.EncodeToString(big.NewInt(int64(e)).Bytes()),
}
jwks := nbjwt.Jwks{
Keys: []nbjwt.JSONWebKey{jwk},
}
jwksJSON, err := json.Marshal(jwks)
require.NoError(t, err)
return privateKey, jwksJSON
}
func generateValidJWT(t *testing.T, privateKey *rsa.PrivateKey, issuer, audience string) string {
t.Helper()
claims := jwt.MapClaims{
"iss": issuer,
"aud": audience,
"sub": "test-user",
"exp": time.Now().Add(time.Hour).Unix(),
"iat": time.Now().Unix(),
}
token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims)
token.Header["kid"] = "test-key-id"
tokenString, err := token.SignedString(privateKey)
require.NoError(t, err)
return tokenString
}

View File

@@ -1,280 +0,0 @@
//go:build !js
package ssh
import (
"fmt"
"io"
"net"
"os"
"os/exec"
"os/user"
"runtime"
"strings"
"sync"
"time"
"github.com/creack/pty"
"github.com/gliderlabs/ssh"
log "github.com/sirupsen/logrus"
)
// DefaultSSHPort is the default SSH port of the NetBird's embedded SSH server
const DefaultSSHPort = 44338
// TerminalTimeout is the timeout for terminal session to be ready
const TerminalTimeout = 10 * time.Second
// TerminalBackoffDelay is the delay between terminal session readiness checks
const TerminalBackoffDelay = 500 * time.Millisecond
// DefaultSSHServer is a function that creates DefaultServer
func DefaultSSHServer(hostKeyPEM []byte, addr string) (Server, error) {
return newDefaultServer(hostKeyPEM, addr)
}
// Server is an interface of SSH server
type Server interface {
// Stop stops SSH server.
Stop() error
// Start starts SSH server. Blocking
Start() error
// RemoveAuthorizedKey removes SSH key of a given peer from the authorized keys
RemoveAuthorizedKey(peer string)
// AddAuthorizedKey add a given peer key to server authorized keys
AddAuthorizedKey(peer, newKey string) error
}
// DefaultServer is the embedded NetBird SSH server
type DefaultServer struct {
listener net.Listener
// authorizedKeys is ssh pub key indexed by peer WireGuard public key
authorizedKeys map[string]ssh.PublicKey
mu sync.Mutex
hostKeyPEM []byte
sessions []ssh.Session
}
// newDefaultServer creates new server with provided host key
func newDefaultServer(hostKeyPEM []byte, addr string) (*DefaultServer, error) {
ln, err := net.Listen("tcp", addr)
if err != nil {
return nil, err
}
allowedKeys := make(map[string]ssh.PublicKey)
return &DefaultServer{listener: ln, mu: sync.Mutex{}, hostKeyPEM: hostKeyPEM, authorizedKeys: allowedKeys, sessions: make([]ssh.Session, 0)}, nil
}
// RemoveAuthorizedKey removes SSH key of a given peer from the authorized keys
func (srv *DefaultServer) RemoveAuthorizedKey(peer string) {
srv.mu.Lock()
defer srv.mu.Unlock()
delete(srv.authorizedKeys, peer)
}
// AddAuthorizedKey add a given peer key to server authorized keys
func (srv *DefaultServer) AddAuthorizedKey(peer, newKey string) error {
srv.mu.Lock()
defer srv.mu.Unlock()
parsedKey, _, _, _, err := ssh.ParseAuthorizedKey([]byte(newKey))
if err != nil {
return err
}
srv.authorizedKeys[peer] = parsedKey
return nil
}
// Stop stops SSH server.
func (srv *DefaultServer) Stop() error {
srv.mu.Lock()
defer srv.mu.Unlock()
err := srv.listener.Close()
if err != nil {
return err
}
for _, session := range srv.sessions {
err := session.Close()
if err != nil {
log.Warnf("failed closing SSH session from %v", err)
}
}
return nil
}
func (srv *DefaultServer) publicKeyHandler(ctx ssh.Context, key ssh.PublicKey) bool {
srv.mu.Lock()
defer srv.mu.Unlock()
for _, allowed := range srv.authorizedKeys {
if ssh.KeysEqual(allowed, key) {
return true
}
}
return false
}
func prepareUserEnv(user *user.User, shell string) []string {
return []string{
fmt.Sprint("SHELL=" + shell),
fmt.Sprint("USER=" + user.Username),
fmt.Sprint("HOME=" + user.HomeDir),
}
}
func acceptEnv(s string) bool {
split := strings.Split(s, "=")
if len(split) != 2 {
return false
}
return split[0] == "TERM" || split[0] == "LANG" || strings.HasPrefix(split[0], "LC_")
}
// sessionHandler handles SSH session post auth
func (srv *DefaultServer) sessionHandler(session ssh.Session) {
srv.mu.Lock()
srv.sessions = append(srv.sessions, session)
srv.mu.Unlock()
defer func() {
err := session.Close()
if err != nil {
return
}
}()
log.Infof("Establishing SSH session for %s from host %s", session.User(), session.RemoteAddr().String())
localUser, err := userNameLookup(session.User())
if err != nil {
_, err = fmt.Fprintf(session, "remote SSH server couldn't find local user %s\n", session.User()) //nolint
err = session.Exit(1)
if err != nil {
return
}
log.Warnf("failed SSH session from %v, user %s", session.RemoteAddr(), session.User())
return
}
ptyReq, winCh, isPty := session.Pty()
if isPty {
loginCmd, loginArgs, err := getLoginCmd(localUser.Username, session.RemoteAddr())
if err != nil {
log.Warnf("failed logging-in user %s from remote IP %s", localUser.Username, session.RemoteAddr().String())
return
}
cmd := exec.Command(loginCmd, loginArgs...)
go func() {
<-session.Context().Done()
if cmd.Process == nil {
return
}
err := cmd.Process.Kill()
if err != nil {
log.Debugf("failed killing SSH process %v", err)
return
}
}()
cmd.Dir = localUser.HomeDir
cmd.Env = append(cmd.Env, fmt.Sprintf("TERM=%s", ptyReq.Term))
cmd.Env = append(cmd.Env, prepareUserEnv(localUser, getUserShell(localUser.Uid))...)
for _, v := range session.Environ() {
if acceptEnv(v) {
cmd.Env = append(cmd.Env, v)
}
}
log.Debugf("Login command: %s", cmd.String())
file, err := pty.Start(cmd)
if err != nil {
log.Errorf("failed starting SSH server: %v", err)
}
go func() {
for win := range winCh {
setWinSize(file, win.Width, win.Height)
}
}()
srv.stdInOut(file, session)
err = cmd.Wait()
if err != nil {
return
}
} else {
_, err := io.WriteString(session, "only PTY is supported.\n")
if err != nil {
return
}
err = session.Exit(1)
if err != nil {
return
}
}
log.Debugf("SSH session ended")
}
func (srv *DefaultServer) stdInOut(file *os.File, session ssh.Session) {
go func() {
// stdin
_, err := io.Copy(file, session)
if err != nil {
_ = session.Exit(1)
return
}
}()
// AWS Linux 2 machines need some time to open the terminal so we need to wait for it
timer := time.NewTimer(TerminalTimeout)
for {
select {
case <-timer.C:
_, _ = session.Write([]byte("Reached timeout while opening connection\n"))
_ = session.Exit(1)
return
default:
// stdout
writtenBytes, err := io.Copy(session, file)
if err != nil && writtenBytes != 0 {
_ = session.Exit(0)
return
}
time.Sleep(TerminalBackoffDelay)
}
}
}
// Start starts SSH server. Blocking
func (srv *DefaultServer) Start() error {
log.Infof("starting SSH server on addr: %s", srv.listener.Addr().String())
publicKeyOption := ssh.PublicKeyAuth(srv.publicKeyHandler)
hostKeyPEM := ssh.HostKeyPEM(srv.hostKeyPEM)
err := ssh.Serve(srv.listener, srv.sessionHandler, publicKeyOption, hostKeyPEM)
if err != nil {
return err
}
return nil
}
func getUserShell(userID string) string {
if runtime.GOOS == "linux" {
output, _ := exec.Command("getent", "passwd", userID).Output()
line := strings.SplitN(string(output), ":", 10)
if len(line) > 6 {
return strings.TrimSpace(line[6])
}
}
shell := os.Getenv("SHELL")
if shell == "" {
shell = "/bin/sh"
}
return shell
}

View File

@@ -0,0 +1,206 @@
package server
import (
"errors"
"fmt"
"io"
"os"
"os/exec"
"time"
"github.com/gliderlabs/ssh"
log "github.com/sirupsen/logrus"
)
// handleCommand executes an SSH command with privilege validation
func (s *Server) handleCommand(logger *log.Entry, session ssh.Session, privilegeResult PrivilegeCheckResult, winCh <-chan ssh.Window) {
hasPty := winCh != nil
commandType := "command"
if hasPty {
commandType = "Pty command"
}
logger.Infof("executing %s: %s", commandType, safeLogCommand(session.Command()))
execCmd, cleanup, err := s.createCommand(privilegeResult, session, hasPty)
if err != nil {
logger.Errorf("%s creation failed: %v", commandType, err)
errorMsg := fmt.Sprintf("Cannot create %s - platform may not support user switching", commandType)
if hasPty {
errorMsg += " with Pty"
}
errorMsg += "\n"
if _, writeErr := fmt.Fprint(session.Stderr(), errorMsg); writeErr != nil {
logger.Debugf(errWriteSession, writeErr)
}
if err := session.Exit(1); err != nil {
logSessionExitError(logger, err)
}
return
}
if !hasPty {
if s.executeCommand(logger, session, execCmd, cleanup) {
logger.Debugf("%s execution completed", commandType)
}
return
}
defer cleanup()
ptyReq, _, _ := session.Pty()
if s.executeCommandWithPty(logger, session, execCmd, privilegeResult, ptyReq, winCh) {
logger.Debugf("%s execution completed", commandType)
}
}
func (s *Server) createCommand(privilegeResult PrivilegeCheckResult, session ssh.Session, hasPty bool) (*exec.Cmd, func(), error) {
localUser := privilegeResult.User
if localUser == nil {
return nil, nil, errors.New("no user in privilege result")
}
// If PTY requested but su doesn't support --pty, skip su and use executor
// This ensures PTY functionality is provided (executor runs within our allocated PTY)
if hasPty && !s.suSupportsPty {
log.Debugf("PTY requested but su doesn't support --pty, using executor for PTY functionality")
cmd, cleanup, err := s.createExecutorCommand(session, localUser, hasPty)
if err != nil {
return nil, nil, fmt.Errorf("create command with privileges: %w", err)
}
cmd.Env = s.prepareCommandEnv(localUser, session)
return cmd, cleanup, nil
}
// Try su first for system integration (PAM/audit) when privileged
cmd, err := s.createSuCommand(session, localUser, hasPty)
if err != nil || privilegeResult.UsedFallback {
log.Debugf("su command failed, falling back to executor: %v", err)
cmd, cleanup, err := s.createExecutorCommand(session, localUser, hasPty)
if err != nil {
return nil, nil, fmt.Errorf("create command with privileges: %w", err)
}
cmd.Env = s.prepareCommandEnv(localUser, session)
return cmd, cleanup, nil
}
cmd.Env = s.prepareCommandEnv(localUser, session)
return cmd, func() {}, nil
}
// executeCommand executes the command and handles I/O and exit codes
func (s *Server) executeCommand(logger *log.Entry, session ssh.Session, execCmd *exec.Cmd, cleanup func()) bool {
defer cleanup()
s.setupProcessGroup(execCmd)
stdinPipe, err := execCmd.StdinPipe()
if err != nil {
logger.Errorf("create stdin pipe: %v", err)
if err := session.Exit(1); err != nil {
logSessionExitError(logger, err)
}
return false
}
execCmd.Stdout = session
execCmd.Stderr = session.Stderr()
if execCmd.Dir != "" {
if _, err := os.Stat(execCmd.Dir); err != nil {
logger.Warnf("working directory does not exist: %s (%v)", execCmd.Dir, err)
execCmd.Dir = "/"
}
}
if err := execCmd.Start(); err != nil {
logger.Errorf("command start failed: %v", err)
// no user message for exec failure, just exit
if err := session.Exit(1); err != nil {
logSessionExitError(logger, err)
}
return false
}
go s.handleCommandIO(logger, stdinPipe, session)
return s.waitForCommandCleanup(logger, session, execCmd)
}
// handleCommandIO manages stdin/stdout copying in a goroutine
func (s *Server) handleCommandIO(logger *log.Entry, stdinPipe io.WriteCloser, session ssh.Session) {
defer func() {
if err := stdinPipe.Close(); err != nil {
logger.Debugf("stdin pipe close error: %v", err)
}
}()
if _, err := io.Copy(stdinPipe, session); err != nil {
logger.Debugf("stdin copy error: %v", err)
}
}
// waitForCommandCleanup waits for command completion with session disconnect handling
func (s *Server) waitForCommandCleanup(logger *log.Entry, session ssh.Session, execCmd *exec.Cmd) bool {
ctx := session.Context()
done := make(chan error, 1)
go func() {
done <- execCmd.Wait()
}()
select {
case <-ctx.Done():
logger.Debugf("session cancelled, terminating command")
s.killProcessGroup(execCmd)
select {
case err := <-done:
logger.Tracef("command terminated after session cancellation: %v", err)
case <-time.After(5 * time.Second):
logger.Warnf("command did not terminate within 5 seconds after session cancellation")
}
if err := session.Exit(130); err != nil {
logSessionExitError(logger, err)
}
return false
case err := <-done:
return s.handleCommandCompletion(logger, session, err)
}
}
// handleCommandCompletion handles command completion
func (s *Server) handleCommandCompletion(logger *log.Entry, session ssh.Session, err error) bool {
if err != nil {
logger.Debugf("command execution failed: %v", err)
s.handleSessionExit(session, err, logger)
return false
}
s.handleSessionExit(session, nil, logger)
return true
}
// handleSessionExit handles command errors and sets appropriate exit codes
func (s *Server) handleSessionExit(session ssh.Session, err error, logger *log.Entry) {
if err == nil {
if err := session.Exit(0); err != nil {
logSessionExitError(logger, err)
}
return
}
var exitError *exec.ExitError
if errors.As(err, &exitError) {
if err := session.Exit(exitError.ExitCode()); err != nil {
logSessionExitError(logger, err)
}
} else {
logger.Debugf("non-exit error in command execution: %v", err)
if err := session.Exit(1); err != nil {
logSessionExitError(logger, err)
}
}
}

View File

@@ -0,0 +1,52 @@
//go:build js
package server
import (
"context"
"errors"
"os/exec"
"os/user"
"github.com/gliderlabs/ssh"
log "github.com/sirupsen/logrus"
)
var errNotSupported = errors.New("SSH server command execution not supported on WASM/JS platform")
// createSuCommand is not supported on JS/WASM
func (s *Server) createSuCommand(_ ssh.Session, _ *user.User, _ bool) (*exec.Cmd, error) {
return nil, errNotSupported
}
// createExecutorCommand is not supported on JS/WASM
func (s *Server) createExecutorCommand(_ ssh.Session, _ *user.User, _ bool) (*exec.Cmd, func(), error) {
return nil, nil, errNotSupported
}
// prepareCommandEnv is not supported on JS/WASM
func (s *Server) prepareCommandEnv(_ *user.User, _ ssh.Session) []string {
return nil
}
// setupProcessGroup is not supported on JS/WASM
func (s *Server) setupProcessGroup(_ *exec.Cmd) {
}
// killProcessGroup is not supported on JS/WASM
func (s *Server) killProcessGroup(*exec.Cmd) {
}
// detectSuPtySupport always returns false on JS/WASM
func (s *Server) detectSuPtySupport(context.Context) bool {
return false
}
// executeCommandWithPty is not supported on JS/WASM
func (s *Server) executeCommandWithPty(logger *log.Entry, session ssh.Session, execCmd *exec.Cmd, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, winCh <-chan ssh.Window) bool {
logger.Errorf("PTY command execution not supported on JS/WASM")
if err := session.Exit(1); err != nil {
logSessionExitError(logger, err)
}
return false
}

View File

@@ -0,0 +1,329 @@
//go:build unix
package server
import (
"context"
"errors"
"fmt"
"io"
"os"
"os/exec"
"os/user"
"strings"
"sync"
"syscall"
"time"
"github.com/creack/pty"
"github.com/gliderlabs/ssh"
log "github.com/sirupsen/logrus"
)
// ptyManager manages Pty file operations with thread safety
type ptyManager struct {
file *os.File
mu sync.RWMutex
closed bool
closeErr error
once sync.Once
}
func newPtyManager(file *os.File) *ptyManager {
return &ptyManager{file: file}
}
func (pm *ptyManager) Close() error {
pm.once.Do(func() {
pm.mu.Lock()
pm.closed = true
pm.closeErr = pm.file.Close()
pm.mu.Unlock()
})
pm.mu.RLock()
defer pm.mu.RUnlock()
return pm.closeErr
}
func (pm *ptyManager) Setsize(ws *pty.Winsize) error {
pm.mu.RLock()
defer pm.mu.RUnlock()
if pm.closed {
return errors.New("pty is closed")
}
return pty.Setsize(pm.file, ws)
}
func (pm *ptyManager) File() *os.File {
return pm.file
}
// detectSuPtySupport checks if su supports the --pty flag
func (s *Server) detectSuPtySupport(ctx context.Context) bool {
ctx, cancel := context.WithTimeout(ctx, 500*time.Millisecond)
defer cancel()
cmd := exec.CommandContext(ctx, "su", "--help")
output, err := cmd.CombinedOutput()
if err != nil {
log.Debugf("su --help failed (may not support --help): %v", err)
return false
}
supported := strings.Contains(string(output), "--pty")
log.Debugf("su --pty support detected: %v", supported)
return supported
}
// createSuCommand creates a command using su -l -c for privilege switching
func (s *Server) createSuCommand(session ssh.Session, localUser *user.User, hasPty bool) (*exec.Cmd, error) {
suPath, err := exec.LookPath("su")
if err != nil {
return nil, fmt.Errorf("su command not available: %w", err)
}
command := session.RawCommand()
if command == "" {
return nil, fmt.Errorf("no command specified for su execution")
}
args := []string{"-l"}
if hasPty && s.suSupportsPty {
args = append(args, "--pty")
}
args = append(args, localUser.Username, "-c", command)
cmd := exec.CommandContext(session.Context(), suPath, args...)
cmd.Dir = localUser.HomeDir
return cmd, nil
}
// getShellCommandArgs returns the shell command and arguments for executing a command string
func (s *Server) getShellCommandArgs(shell, cmdString string) []string {
if cmdString == "" {
return []string{shell, "-l"}
}
return []string{shell, "-l", "-c", cmdString}
}
// prepareCommandEnv prepares environment variables for command execution on Unix
func (s *Server) prepareCommandEnv(localUser *user.User, session ssh.Session) []string {
env := prepareUserEnv(localUser, getUserShell(localUser.Uid))
env = append(env, prepareSSHEnv(session)...)
for _, v := range session.Environ() {
if acceptEnv(v) {
env = append(env, v)
}
}
return env
}
// executeCommandWithPty executes a command with PTY allocation
func (s *Server) executeCommandWithPty(logger *log.Entry, session ssh.Session, execCmd *exec.Cmd, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, winCh <-chan ssh.Window) bool {
termType := ptyReq.Term
if termType == "" {
termType = "xterm-256color"
}
execCmd.Env = append(execCmd.Env, fmt.Sprintf("TERM=%s", termType))
return s.runPtyCommand(logger, session, execCmd, ptyReq, winCh)
}
func (s *Server) handlePty(logger *log.Entry, session ssh.Session, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, winCh <-chan ssh.Window) bool {
execCmd, err := s.createPtyCommand(privilegeResult, ptyReq, session)
if err != nil {
logger.Errorf("Pty command creation failed: %v", err)
errorMsg := "User switching failed - login command not available\r\n"
if _, writeErr := fmt.Fprint(session.Stderr(), errorMsg); writeErr != nil {
logger.Debugf(errWriteSession, writeErr)
}
if err := session.Exit(1); err != nil {
logSessionExitError(logger, err)
}
return false
}
logger.Infof("starting interactive shell: %s", execCmd.Path)
return s.runPtyCommand(logger, session, execCmd, ptyReq, winCh)
}
// runPtyCommand runs a command with PTY management (common code for interactive and command execution)
func (s *Server) runPtyCommand(logger *log.Entry, session ssh.Session, execCmd *exec.Cmd, ptyReq ssh.Pty, winCh <-chan ssh.Window) bool {
ptmx, err := s.startPtyCommandWithSize(execCmd, ptyReq)
if err != nil {
logger.Errorf("Pty start failed: %v", err)
if err := session.Exit(1); err != nil {
logSessionExitError(logger, err)
}
return false
}
ptyMgr := newPtyManager(ptmx)
defer func() {
if err := ptyMgr.Close(); err != nil {
logger.Debugf("Pty close error: %v", err)
}
}()
go s.handlePtyWindowResize(logger, session, ptyMgr, winCh)
s.handlePtyIO(logger, session, ptyMgr)
s.waitForPtyCompletion(logger, session, execCmd, ptyMgr)
return true
}
func (s *Server) startPtyCommandWithSize(execCmd *exec.Cmd, ptyReq ssh.Pty) (*os.File, error) {
winSize := &pty.Winsize{
Cols: uint16(ptyReq.Window.Width),
Rows: uint16(ptyReq.Window.Height),
}
if winSize.Cols == 0 {
winSize.Cols = 80
}
if winSize.Rows == 0 {
winSize.Rows = 24
}
ptmx, err := pty.StartWithSize(execCmd, winSize)
if err != nil {
return nil, fmt.Errorf("start Pty: %w", err)
}
return ptmx, nil
}
func (s *Server) handlePtyWindowResize(logger *log.Entry, session ssh.Session, ptyMgr *ptyManager, winCh <-chan ssh.Window) {
for {
select {
case <-session.Context().Done():
return
case win, ok := <-winCh:
if !ok {
return
}
if err := ptyMgr.Setsize(&pty.Winsize{Rows: uint16(win.Height), Cols: uint16(win.Width)}); err != nil {
logger.Debugf("Pty resize to %dx%d: %v", win.Width, win.Height, err)
}
}
}
}
func (s *Server) handlePtyIO(logger *log.Entry, session ssh.Session, ptyMgr *ptyManager) {
ptmx := ptyMgr.File()
go func() {
if _, err := io.Copy(ptmx, session); err != nil {
if !errors.Is(err, io.EOF) && !errors.Is(err, syscall.EIO) {
logger.Warnf("Pty input copy error: %v", err)
}
}
}()
go func() {
defer func() {
if err := session.Close(); err != nil && !errors.Is(err, io.EOF) {
logger.Debugf("session close error: %v", err)
}
}()
if _, err := io.Copy(session, ptmx); err != nil {
if !errors.Is(err, io.EOF) && !errors.Is(err, syscall.EIO) {
logger.Warnf("Pty output copy error: %v", err)
}
}
}()
}
func (s *Server) waitForPtyCompletion(logger *log.Entry, session ssh.Session, execCmd *exec.Cmd, ptyMgr *ptyManager) {
ctx := session.Context()
done := make(chan error, 1)
go func() {
done <- execCmd.Wait()
}()
select {
case <-ctx.Done():
s.handlePtySessionCancellation(logger, session, execCmd, ptyMgr, done)
case err := <-done:
s.handlePtyCommandCompletion(logger, session, err)
}
}
func (s *Server) handlePtySessionCancellation(logger *log.Entry, session ssh.Session, execCmd *exec.Cmd, ptyMgr *ptyManager, done <-chan error) {
logger.Debugf("Pty session cancelled, terminating command")
if err := ptyMgr.Close(); err != nil {
logger.Debugf("Pty close during session cancellation: %v", err)
}
s.killProcessGroup(execCmd)
select {
case err := <-done:
if err != nil {
logger.Debugf("Pty command terminated after session cancellation with error: %v", err)
} else {
logger.Debugf("Pty command terminated after session cancellation")
}
case <-time.After(5 * time.Second):
logger.Warnf("Pty command did not terminate within 5 seconds after session cancellation")
}
if err := session.Exit(130); err != nil {
logSessionExitError(logger, err)
}
}
func (s *Server) handlePtyCommandCompletion(logger *log.Entry, session ssh.Session, err error) {
if err != nil {
logger.Debugf("Pty command execution failed: %v", err)
s.handleSessionExit(session, err, logger)
return
}
// Normal completion
logger.Debugf("Pty command completed successfully")
if err := session.Exit(0); err != nil {
logSessionExitError(logger, err)
}
}
func (s *Server) setupProcessGroup(cmd *exec.Cmd) {
cmd.SysProcAttr = &syscall.SysProcAttr{
Setpgid: true,
}
}
func (s *Server) killProcessGroup(cmd *exec.Cmd) {
if cmd.Process == nil {
return
}
logger := log.WithField("pid", cmd.Process.Pid)
pgid := cmd.Process.Pid
if err := syscall.Kill(-pgid, syscall.SIGTERM); err != nil {
logger.Debugf("kill process group SIGTERM: %v", err)
return
}
const gracePeriod = 500 * time.Millisecond
const checkInterval = 50 * time.Millisecond
ticker := time.NewTicker(checkInterval)
defer ticker.Stop()
timeout := time.After(gracePeriod)
for {
select {
case <-timeout:
if err := syscall.Kill(-pgid, syscall.SIGKILL); err != nil {
logger.Debugf("kill process group SIGKILL: %v", err)
}
return
case <-ticker.C:
if err := syscall.Kill(-pgid, 0); err != nil {
return
}
}
}
}

View File

@@ -0,0 +1,430 @@
package server
import (
"context"
"fmt"
"os"
"os/exec"
"os/user"
"path/filepath"
"strings"
"unsafe"
"github.com/gliderlabs/ssh"
log "github.com/sirupsen/logrus"
"golang.org/x/sys/windows"
"golang.org/x/sys/windows/registry"
"github.com/netbirdio/netbird/client/ssh/server/winpty"
)
// getUserEnvironment retrieves the Windows environment for the target user.
// Follows OpenSSH's resilient approach with graceful degradation on failures.
func (s *Server) getUserEnvironment(username, domain string) ([]string, error) {
userToken, err := s.getUserToken(username, domain)
if err != nil {
return nil, fmt.Errorf("get user token: %w", err)
}
defer func() {
if err := windows.CloseHandle(userToken); err != nil {
log.Debugf("close user token: %v", err)
}
}()
return s.getUserEnvironmentWithToken(userToken, username, domain)
}
// getUserEnvironmentWithToken retrieves the Windows environment using an existing token.
func (s *Server) getUserEnvironmentWithToken(userToken windows.Handle, username, domain string) ([]string, error) {
userProfile, err := s.loadUserProfile(userToken, username, domain)
if err != nil {
log.Debugf("failed to load user profile for %s\\%s: %v", domain, username, err)
userProfile = fmt.Sprintf("C:\\Users\\%s", username)
}
envMap := make(map[string]string)
if err := s.loadSystemEnvironment(envMap); err != nil {
log.Debugf("failed to load system environment from registry: %v", err)
}
s.setUserEnvironmentVariables(envMap, userProfile, username, domain)
var env []string
for key, value := range envMap {
env = append(env, key+"="+value)
}
return env, nil
}
// getUserToken creates a user token for the specified user.
func (s *Server) getUserToken(username, domain string) (windows.Handle, error) {
privilegeDropper := NewPrivilegeDropper()
token, err := privilegeDropper.createToken(username, domain)
if err != nil {
return 0, fmt.Errorf("generate S4U user token: %w", err)
}
return token, nil
}
// loadUserProfile loads the Windows user profile and returns the profile path.
func (s *Server) loadUserProfile(userToken windows.Handle, username, domain string) (string, error) {
usernamePtr, err := windows.UTF16PtrFromString(username)
if err != nil {
return "", fmt.Errorf("convert username to UTF-16: %w", err)
}
var domainUTF16 *uint16
if domain != "" && domain != "." {
domainUTF16, err = windows.UTF16PtrFromString(domain)
if err != nil {
return "", fmt.Errorf("convert domain to UTF-16: %w", err)
}
}
type profileInfo struct {
dwSize uint32
dwFlags uint32
lpUserName *uint16
lpProfilePath *uint16
lpDefaultPath *uint16
lpServerName *uint16
lpPolicyPath *uint16
hProfile windows.Handle
}
const PI_NOUI = 0x00000001
profile := profileInfo{
dwSize: uint32(unsafe.Sizeof(profileInfo{})),
dwFlags: PI_NOUI,
lpUserName: usernamePtr,
lpServerName: domainUTF16,
}
userenv := windows.NewLazySystemDLL("userenv.dll")
loadUserProfileW := userenv.NewProc("LoadUserProfileW")
ret, _, err := loadUserProfileW.Call(
uintptr(userToken),
uintptr(unsafe.Pointer(&profile)),
)
if ret == 0 {
return "", fmt.Errorf("LoadUserProfileW: %w", err)
}
if profile.lpProfilePath == nil {
return "", fmt.Errorf("LoadUserProfileW returned null profile path")
}
profilePath := windows.UTF16PtrToString(profile.lpProfilePath)
return profilePath, nil
}
// loadSystemEnvironment loads system-wide environment variables from registry.
func (s *Server) loadSystemEnvironment(envMap map[string]string) error {
key, err := registry.OpenKey(registry.LOCAL_MACHINE,
`SYSTEM\CurrentControlSet\Control\Session Manager\Environment`,
registry.QUERY_VALUE)
if err != nil {
return fmt.Errorf("open system environment registry key: %w", err)
}
defer func() {
if err := key.Close(); err != nil {
log.Debugf("close registry key: %v", err)
}
}()
return s.readRegistryEnvironment(key, envMap)
}
// readRegistryEnvironment reads environment variables from a registry key.
func (s *Server) readRegistryEnvironment(key registry.Key, envMap map[string]string) error {
names, err := key.ReadValueNames(0)
if err != nil {
return fmt.Errorf("read registry value names: %w", err)
}
for _, name := range names {
value, valueType, err := key.GetStringValue(name)
if err != nil {
log.Debugf("failed to read registry value %s: %v", name, err)
continue
}
finalValue := s.expandRegistryValue(value, valueType, name)
s.setEnvironmentVariable(envMap, name, finalValue)
}
return nil
}
// expandRegistryValue expands registry values if they contain environment variables.
func (s *Server) expandRegistryValue(value string, valueType uint32, name string) string {
if valueType != registry.EXPAND_SZ {
return value
}
sourcePtr := windows.StringToUTF16Ptr(value)
expandedBuffer := make([]uint16, 1024)
expandedLen, err := windows.ExpandEnvironmentStrings(sourcePtr, &expandedBuffer[0], uint32(len(expandedBuffer)))
if err != nil {
log.Debugf("failed to expand environment string for %s: %v", name, err)
return value
}
// If buffer was too small, retry with larger buffer
if expandedLen > uint32(len(expandedBuffer)) {
expandedBuffer = make([]uint16, expandedLen)
expandedLen, err = windows.ExpandEnvironmentStrings(sourcePtr, &expandedBuffer[0], uint32(len(expandedBuffer)))
if err != nil {
log.Debugf("failed to expand environment string for %s on retry: %v", name, err)
return value
}
}
if expandedLen > 0 && expandedLen <= uint32(len(expandedBuffer)) {
return windows.UTF16ToString(expandedBuffer[:expandedLen-1])
}
return value
}
// setEnvironmentVariable sets an environment variable with special handling for PATH.
func (s *Server) setEnvironmentVariable(envMap map[string]string, name, value string) {
upperName := strings.ToUpper(name)
if upperName == "PATH" {
if existing, exists := envMap["PATH"]; exists && existing != value {
envMap["PATH"] = existing + ";" + value
} else {
envMap["PATH"] = value
}
} else {
envMap[upperName] = value
}
}
// setUserEnvironmentVariables sets critical user-specific environment variables.
func (s *Server) setUserEnvironmentVariables(envMap map[string]string, userProfile, username, domain string) {
envMap["USERPROFILE"] = userProfile
if len(userProfile) >= 2 && userProfile[1] == ':' {
envMap["HOMEDRIVE"] = userProfile[:2]
envMap["HOMEPATH"] = userProfile[2:]
}
envMap["APPDATA"] = filepath.Join(userProfile, "AppData", "Roaming")
envMap["LOCALAPPDATA"] = filepath.Join(userProfile, "AppData", "Local")
tempDir := filepath.Join(userProfile, "AppData", "Local", "Temp")
envMap["TEMP"] = tempDir
envMap["TMP"] = tempDir
envMap["USERNAME"] = username
if domain != "" && domain != "." {
envMap["USERDOMAIN"] = domain
envMap["USERDNSDOMAIN"] = domain
}
systemVars := []string{
"PROCESSOR_ARCHITECTURE", "PROCESSOR_IDENTIFIER", "PROCESSOR_LEVEL", "PROCESSOR_REVISION",
"SYSTEMDRIVE", "SYSTEMROOT", "WINDIR", "COMPUTERNAME", "OS", "PATHEXT",
"PROGRAMFILES", "PROGRAMDATA", "ALLUSERSPROFILE", "COMSPEC",
}
for _, sysVar := range systemVars {
if sysValue := os.Getenv(sysVar); sysValue != "" {
envMap[sysVar] = sysValue
}
}
}
// prepareCommandEnv prepares environment variables for command execution on Windows
func (s *Server) prepareCommandEnv(localUser *user.User, session ssh.Session) []string {
username, domain := s.parseUsername(localUser.Username)
userEnv, err := s.getUserEnvironment(username, domain)
if err != nil {
log.Debugf("failed to get user environment for %s\\%s, using fallback: %v", domain, username, err)
env := prepareUserEnv(localUser, getUserShell(localUser.Uid))
env = append(env, prepareSSHEnv(session)...)
for _, v := range session.Environ() {
if acceptEnv(v) {
env = append(env, v)
}
}
return env
}
env := userEnv
env = append(env, prepareSSHEnv(session)...)
for _, v := range session.Environ() {
if acceptEnv(v) {
env = append(env, v)
}
}
return env
}
func (s *Server) handlePty(logger *log.Entry, session ssh.Session, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, winCh <-chan ssh.Window) bool {
if privilegeResult.User == nil {
logger.Errorf("no user in privilege result")
return false
}
cmd := session.Command()
shell := getUserShell(privilegeResult.User.Uid)
if len(cmd) == 0 {
logger.Infof("starting interactive shell: %s", shell)
} else {
logger.Infof("executing command: %s", safeLogCommand(cmd))
}
s.handlePtyWithUserSwitching(logger, session, privilegeResult, ptyReq, winCh, cmd)
return true
}
// getShellCommandArgs returns the shell command and arguments for executing a command string
func (s *Server) getShellCommandArgs(shell, cmdString string) []string {
if cmdString == "" {
return []string{shell, "-NoLogo"}
}
return []string{shell, "-Command", cmdString}
}
func (s *Server) handlePtyWithUserSwitching(logger *log.Entry, session ssh.Session, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, _ <-chan ssh.Window, _ []string) {
logger.Info("starting interactive shell")
s.executeConPtyCommand(logger, session, privilegeResult, ptyReq, session.RawCommand())
}
type PtyExecutionRequest struct {
Shell string
Command string
Width int
Height int
Username string
Domain string
}
func executePtyCommandWithUserToken(ctx context.Context, session ssh.Session, req PtyExecutionRequest) error {
log.Tracef("executing Windows ConPty command with user switching: shell=%s, command=%s, user=%s\\%s, size=%dx%d",
req.Shell, req.Command, req.Domain, req.Username, req.Width, req.Height)
privilegeDropper := NewPrivilegeDropper()
userToken, err := privilegeDropper.createToken(req.Username, req.Domain)
if err != nil {
return fmt.Errorf("create user token: %w", err)
}
defer func() {
if err := windows.CloseHandle(userToken); err != nil {
log.Debugf("close user token: %v", err)
}
}()
server := &Server{}
userEnv, err := server.getUserEnvironmentWithToken(userToken, req.Username, req.Domain)
if err != nil {
log.Debugf("failed to get user environment for %s\\%s, using system environment: %v", req.Domain, req.Username, err)
userEnv = os.Environ()
}
workingDir := getUserHomeFromEnv(userEnv)
if workingDir == "" {
workingDir = fmt.Sprintf(`C:\Users\%s`, req.Username)
}
ptyConfig := winpty.PtyConfig{
Shell: req.Shell,
Command: req.Command,
Width: req.Width,
Height: req.Height,
WorkingDir: workingDir,
}
userConfig := winpty.UserConfig{
Token: userToken,
Environment: userEnv,
}
log.Debugf("executePtyCommandWithUserToken: calling winpty execution with working dir: %s", workingDir)
return winpty.ExecutePtyWithUserToken(ctx, session, ptyConfig, userConfig)
}
func getUserHomeFromEnv(env []string) string {
for _, envVar := range env {
if len(envVar) > 12 && envVar[:12] == "USERPROFILE=" {
return envVar[12:]
}
}
return ""
}
func (s *Server) setupProcessGroup(_ *exec.Cmd) {
// Windows doesn't support process groups in the same way as Unix
// Process creation groups are handled differently
}
func (s *Server) killProcessGroup(cmd *exec.Cmd) {
if cmd.Process == nil {
return
}
logger := log.WithField("pid", cmd.Process.Pid)
if err := cmd.Process.Kill(); err != nil {
logger.Debugf("kill process failed: %v", err)
}
}
// detectSuPtySupport always returns false on Windows as su is not available
func (s *Server) detectSuPtySupport(context.Context) bool {
return false
}
// executeCommandWithPty executes a command with PTY allocation on Windows using ConPty
func (s *Server) executeCommandWithPty(logger *log.Entry, session ssh.Session, execCmd *exec.Cmd, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, winCh <-chan ssh.Window) bool {
command := session.RawCommand()
if command == "" {
logger.Error("no command specified for PTY execution")
if err := session.Exit(1); err != nil {
logSessionExitError(logger, err)
}
return false
}
return s.executeConPtyCommand(logger, session, privilegeResult, ptyReq, command)
}
// executeConPtyCommand executes a command using ConPty (common for interactive and command execution)
func (s *Server) executeConPtyCommand(logger *log.Entry, session ssh.Session, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, command string) bool {
localUser := privilegeResult.User
if localUser == nil {
logger.Errorf("no user in privilege result")
return false
}
username, domain := s.parseUsername(localUser.Username)
shell := getUserShell(localUser.Uid)
req := PtyExecutionRequest{
Shell: shell,
Command: command,
Width: ptyReq.Window.Width,
Height: ptyReq.Window.Height,
Username: username,
Domain: domain,
}
if err := executePtyCommandWithUserToken(session.Context(), session, req); err != nil {
logger.Errorf("ConPty execution failed: %v", err)
if err := session.Exit(1); err != nil {
logSessionExitError(logger, err)
}
return false
}
logger.Debug("ConPty execution completed")
return true
}

View File

@@ -0,0 +1,722 @@
package server
import (
"context"
"crypto/ed25519"
"crypto/rand"
"fmt"
"io"
"net"
"os"
"os/exec"
"runtime"
"strings"
"testing"
"time"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/crypto/ssh"
nbssh "github.com/netbirdio/netbird/client/ssh"
"github.com/netbirdio/netbird/client/ssh/testutil"
)
// TestMain handles package-level setup and cleanup
func TestMain(m *testing.M) {
// Guard against infinite recursion when test binary is called as "netbird ssh exec"
// This happens when running tests as non-privileged user with fallback
if len(os.Args) > 2 && os.Args[1] == "ssh" && os.Args[2] == "exec" {
// Just exit with error to break the recursion
fmt.Fprintf(os.Stderr, "Test binary called as 'ssh exec' - preventing infinite recursion\n")
os.Exit(1)
}
// Run tests
code := m.Run()
// Cleanup any created test users
testutil.CleanupTestUsers()
os.Exit(code)
}
// TestSSHServerCompatibility tests that our SSH server is compatible with the system SSH client
func TestSSHServerCompatibility(t *testing.T) {
if testing.Short() {
t.Skip("Skipping SSH compatibility tests in short mode")
}
// Check if ssh binary is available
if !isSSHClientAvailable() {
t.Skip("SSH client not available on this system")
}
// Set up SSH server - use our existing key generation for server
hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
require.NoError(t, err)
// Generate OpenSSH-compatible keys for client
clientPrivKeyOpenSSH, _, err := generateOpenSSHKey(t)
require.NoError(t, err)
serverConfig := &Config{
HostKeyPEM: hostKey,
JWT: nil,
}
server := New(serverConfig)
server.SetAllowRootLogin(true)
serverAddr := StartTestServer(t, server)
defer func() {
err := server.Stop()
require.NoError(t, err)
}()
// Create temporary key files for SSH client
clientKeyFile, cleanupKey := createTempKeyFileFromBytes(t, clientPrivKeyOpenSSH)
defer cleanupKey()
// Extract host and port from server address
host, portStr, err := net.SplitHostPort(serverAddr)
require.NoError(t, err)
// Get appropriate user for SSH connection (handle system accounts)
username := testutil.GetTestUsername(t)
t.Run("basic command execution", func(t *testing.T) {
testSSHCommandExecutionWithUser(t, host, portStr, clientKeyFile, username)
})
t.Run("interactive command", func(t *testing.T) {
testSSHInteractiveCommand(t, host, portStr, clientKeyFile)
})
t.Run("port forwarding", func(t *testing.T) {
testSSHPortForwarding(t, host, portStr, clientKeyFile)
})
}
// testSSHCommandExecutionWithUser tests basic command execution with system SSH client using specified user.
func testSSHCommandExecutionWithUser(t *testing.T, host, port, keyFile, username string) {
cmd := exec.Command("ssh",
"-i", keyFile,
"-p", port,
"-o", "StrictHostKeyChecking=no",
"-o", "UserKnownHostsFile=/dev/null",
"-o", "ConnectTimeout=5",
fmt.Sprintf("%s@%s", username, host),
"echo", "hello_world")
output, err := cmd.CombinedOutput()
if err != nil {
t.Logf("SSH command failed: %v", err)
t.Logf("Output: %s", string(output))
return
}
assert.Contains(t, string(output), "hello_world", "SSH command should execute successfully")
}
// testSSHInteractiveCommand tests interactive shell session.
func testSSHInteractiveCommand(t *testing.T, host, port, keyFile string) {
// Get appropriate user for SSH connection
username := testutil.GetTestUsername(t)
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
cmd := exec.CommandContext(ctx, "ssh",
"-i", keyFile,
"-p", port,
"-o", "StrictHostKeyChecking=no",
"-o", "UserKnownHostsFile=/dev/null",
"-o", "ConnectTimeout=5",
fmt.Sprintf("%s@%s", username, host))
stdin, err := cmd.StdinPipe()
if err != nil {
t.Skipf("Cannot create stdin pipe: %v", err)
return
}
stdout, err := cmd.StdoutPipe()
if err != nil {
t.Skipf("Cannot create stdout pipe: %v", err)
return
}
err = cmd.Start()
if err != nil {
t.Logf("Cannot start SSH session: %v", err)
return
}
go func() {
defer func() {
if err := stdin.Close(); err != nil {
t.Logf("stdin close error: %v", err)
}
}()
time.Sleep(100 * time.Millisecond)
if _, err := stdin.Write([]byte("echo interactive_test\n")); err != nil {
t.Logf("stdin write error: %v", err)
}
time.Sleep(100 * time.Millisecond)
if _, err := stdin.Write([]byte("exit\n")); err != nil {
t.Logf("stdin write error: %v", err)
}
}()
output, err := io.ReadAll(stdout)
if err != nil {
t.Logf("Cannot read SSH output: %v", err)
}
err = cmd.Wait()
if err != nil {
t.Logf("SSH interactive session error: %v", err)
t.Logf("Output: %s", string(output))
return
}
assert.Contains(t, string(output), "interactive_test", "Interactive SSH session should work")
}
// testSSHPortForwarding tests port forwarding compatibility.
func testSSHPortForwarding(t *testing.T, host, port, keyFile string) {
// Get appropriate user for SSH connection
username := testutil.GetTestUsername(t)
testServer, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
defer testServer.Close()
testServerAddr := testServer.Addr().String()
expectedResponse := "HTTP/1.1 200 OK\r\nContent-Length: 21\r\n\r\nCompatibility Test OK"
go func() {
for {
conn, err := testServer.Accept()
if err != nil {
return
}
go func(c net.Conn) {
defer func() {
if err := c.Close(); err != nil {
t.Logf("test server connection close error: %v", err)
}
}()
buf := make([]byte, 1024)
if _, err := c.Read(buf); err != nil {
t.Logf("Test server read error: %v", err)
}
if _, err := c.Write([]byte(expectedResponse)); err != nil {
t.Logf("Test server write error: %v", err)
}
}(conn)
}
}()
localListener, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
localAddr := localListener.Addr().String()
localListener.Close()
_, localPort, err := net.SplitHostPort(localAddr)
require.NoError(t, err)
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
defer cancel()
forwardSpec := fmt.Sprintf("%s:%s", localPort, testServerAddr)
cmd := exec.CommandContext(ctx, "ssh",
"-i", keyFile,
"-p", port,
"-L", forwardSpec,
"-o", "StrictHostKeyChecking=no",
"-o", "UserKnownHostsFile=/dev/null",
"-o", "ConnectTimeout=5",
"-N",
fmt.Sprintf("%s@%s", username, host))
err = cmd.Start()
if err != nil {
t.Logf("Cannot start SSH port forwarding: %v", err)
return
}
defer func() {
if cmd.Process != nil {
if err := cmd.Process.Kill(); err != nil {
t.Logf("process kill error: %v", err)
}
}
if err := cmd.Wait(); err != nil {
t.Logf("process wait after kill: %v", err)
}
}()
time.Sleep(500 * time.Millisecond)
conn, err := net.DialTimeout("tcp", localAddr, 3*time.Second)
if err != nil {
t.Logf("Cannot connect to forwarded port: %v", err)
return
}
defer func() {
if err := conn.Close(); err != nil {
t.Logf("forwarded connection close error: %v", err)
}
}()
request := "GET / HTTP/1.1\r\nHost: localhost\r\nConnection: close\r\n\r\n"
_, err = conn.Write([]byte(request))
require.NoError(t, err)
if err := conn.SetReadDeadline(time.Now().Add(3 * time.Second)); err != nil {
log.Debugf("failed to set read deadline: %v", err)
}
response := make([]byte, len(expectedResponse))
n, err := io.ReadFull(conn, response)
if err != nil {
t.Logf("Cannot read forwarded response: %v", err)
return
}
assert.Equal(t, len(expectedResponse), n, "Should read expected number of bytes")
assert.Equal(t, expectedResponse, string(response), "Should get correct HTTP response through SSH port forwarding")
}
// isSSHClientAvailable checks if the ssh binary is available
func isSSHClientAvailable() bool {
_, err := exec.LookPath("ssh")
return err == nil
}
// generateOpenSSHKey generates an ED25519 key in OpenSSH format that the system SSH client can use.
func generateOpenSSHKey(t *testing.T) ([]byte, []byte, error) {
// Check if ssh-keygen is available
if _, err := exec.LookPath("ssh-keygen"); err != nil {
// Fall back to our existing key generation and try to convert
return generateOpenSSHKeyFallback()
}
// Create temporary file for ssh-keygen
tempFile, err := os.CreateTemp("", "ssh_keygen_*")
if err != nil {
return nil, nil, fmt.Errorf("create temp file: %w", err)
}
keyPath := tempFile.Name()
tempFile.Close()
// Remove the temp file so ssh-keygen can create it
if err := os.Remove(keyPath); err != nil {
t.Logf("failed to remove key file: %v", err)
}
// Clean up temp files
defer func() {
if err := os.Remove(keyPath); err != nil {
t.Logf("failed to cleanup key file: %v", err)
}
if err := os.Remove(keyPath + ".pub"); err != nil {
t.Logf("failed to cleanup public key file: %v", err)
}
}()
// Generate key using ssh-keygen
cmd := exec.Command("ssh-keygen", "-t", "ed25519", "-f", keyPath, "-N", "", "-q")
output, err := cmd.CombinedOutput()
if err != nil {
return nil, nil, fmt.Errorf("ssh-keygen failed: %w, output: %s", err, string(output))
}
// Read private key
privKeyBytes, err := os.ReadFile(keyPath)
if err != nil {
return nil, nil, fmt.Errorf("read private key: %w", err)
}
// Read public key
pubKeyBytes, err := os.ReadFile(keyPath + ".pub")
if err != nil {
return nil, nil, fmt.Errorf("read public key: %w", err)
}
return privKeyBytes, pubKeyBytes, nil
}
// generateOpenSSHKeyFallback falls back to generating keys using our existing method
func generateOpenSSHKeyFallback() ([]byte, []byte, error) {
// Generate shared.ED25519 key pair using our existing method
_, privKey, err := ed25519.GenerateKey(rand.Reader)
if err != nil {
return nil, nil, fmt.Errorf("generate key: %w", err)
}
// Convert to SSH format
sshPrivKey, err := ssh.NewSignerFromKey(privKey)
if err != nil {
return nil, nil, fmt.Errorf("create signer: %w", err)
}
// For the fallback, just use our PKCS#8 format and hope it works
// This won't be in OpenSSH format but might still work with some SSH clients
hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
if err != nil {
return nil, nil, fmt.Errorf("generate fallback key: %w", err)
}
// Get public key in SSH format
sshPubKey := ssh.MarshalAuthorizedKey(sshPrivKey.PublicKey())
return hostKey, sshPubKey, nil
}
// createTempKeyFileFromBytes creates a temporary SSH private key file from raw bytes
func createTempKeyFileFromBytes(t *testing.T, keyBytes []byte) (string, func()) {
t.Helper()
tempFile, err := os.CreateTemp("", "ssh_test_key_*")
require.NoError(t, err)
_, err = tempFile.Write(keyBytes)
require.NoError(t, err)
err = tempFile.Close()
require.NoError(t, err)
// Set proper permissions for SSH key (readable by owner only)
err = os.Chmod(tempFile.Name(), 0600)
require.NoError(t, err)
cleanup := func() {
_ = os.Remove(tempFile.Name())
}
return tempFile.Name(), cleanup
}
// createTempKeyFile creates a temporary SSH private key file (for backward compatibility)
func createTempKeyFile(t *testing.T, privateKey []byte) (string, func()) {
return createTempKeyFileFromBytes(t, privateKey)
}
// TestSSHServerFeatureCompatibility tests specific SSH features for compatibility
func TestSSHServerFeatureCompatibility(t *testing.T) {
if testing.Short() {
t.Skip("Skipping SSH feature compatibility tests in short mode")
}
if runtime.GOOS == "windows" && testutil.IsCI() {
t.Skip("Skipping Windows SSH compatibility tests in CI due to S4U authentication issues")
}
if !isSSHClientAvailable() {
t.Skip("SSH client not available on this system")
}
// Test various SSH features
testCases := []struct {
name string
testFunc func(t *testing.T, host, port, keyFile string)
description string
}{
{
name: "command_with_flags",
testFunc: testCommandWithFlags,
description: "Commands with flags should work like standard SSH",
},
{
name: "environment_variables",
testFunc: testEnvironmentVariables,
description: "Environment variables should be available",
},
{
name: "exit_codes",
testFunc: testExitCodes,
description: "Exit codes should be properly handled",
},
}
// Set up SSH server
hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
require.NoError(t, err)
clientPrivKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
require.NoError(t, err)
serverConfig := &Config{
HostKeyPEM: hostKey,
JWT: nil,
}
server := New(serverConfig)
server.SetAllowRootLogin(true)
serverAddr := StartTestServer(t, server)
defer func() {
err := server.Stop()
require.NoError(t, err)
}()
clientKeyFile, cleanupKey := createTempKeyFile(t, clientPrivKey)
defer cleanupKey()
host, portStr, err := net.SplitHostPort(serverAddr)
require.NoError(t, err)
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
tc.testFunc(t, host, portStr, clientKeyFile)
})
}
}
// testCommandWithFlags tests that commands with flags work properly
func testCommandWithFlags(t *testing.T, host, port, keyFile string) {
// Get appropriate user for SSH connection
username := testutil.GetTestUsername(t)
// Test ls with flags
cmd := exec.Command("ssh",
"-i", keyFile,
"-p", port,
"-o", "StrictHostKeyChecking=no",
"-o", "UserKnownHostsFile=/dev/null",
"-o", "ConnectTimeout=5",
fmt.Sprintf("%s@%s", username, host),
"ls", "-la", "/tmp")
output, err := cmd.CombinedOutput()
if err != nil {
t.Logf("Command with flags failed: %v", err)
t.Logf("Output: %s", string(output))
return
}
// Should not be empty and should not contain error messages
assert.NotEmpty(t, string(output), "ls -la should produce output")
assert.NotContains(t, strings.ToLower(string(output)), "command not found", "Command should be executed")
}
// testEnvironmentVariables tests that environment is properly set up
func testEnvironmentVariables(t *testing.T, host, port, keyFile string) {
// Get appropriate user for SSH connection
username := testutil.GetTestUsername(t)
cmd := exec.Command("ssh",
"-i", keyFile,
"-p", port,
"-o", "StrictHostKeyChecking=no",
"-o", "UserKnownHostsFile=/dev/null",
"-o", "ConnectTimeout=5",
fmt.Sprintf("%s@%s", username, host),
"echo", "$HOME")
output, err := cmd.CombinedOutput()
if err != nil {
t.Logf("Environment test failed: %v", err)
t.Logf("Output: %s", string(output))
return
}
// HOME environment variable should be available
homeOutput := strings.TrimSpace(string(output))
assert.NotEmpty(t, homeOutput, "HOME environment variable should be set")
assert.NotEqual(t, "$HOME", homeOutput, "Environment variable should be expanded")
}
// testExitCodes tests that exit codes are properly handled
func testExitCodes(t *testing.T, host, port, keyFile string) {
// Get appropriate user for SSH connection
username := testutil.GetTestUsername(t)
// Test successful command (exit code 0)
cmd := exec.Command("ssh",
"-i", keyFile,
"-p", port,
"-o", "StrictHostKeyChecking=no",
"-o", "UserKnownHostsFile=/dev/null",
"-o", "ConnectTimeout=5",
fmt.Sprintf("%s@%s", username, host),
"true") // always succeeds
err := cmd.Run()
assert.NoError(t, err, "Command with exit code 0 should succeed")
// Test failing command (exit code 1)
cmd = exec.Command("ssh",
"-i", keyFile,
"-p", port,
"-o", "StrictHostKeyChecking=no",
"-o", "UserKnownHostsFile=/dev/null",
"-o", "ConnectTimeout=5",
fmt.Sprintf("%s@%s", username, host),
"false") // always fails
err = cmd.Run()
assert.Error(t, err, "Command with exit code 1 should fail")
// Check if it's the right kind of error
if exitError, ok := err.(*exec.ExitError); ok {
assert.Equal(t, 1, exitError.ExitCode(), "Exit code should be preserved")
}
}
// TestSSHServerSecurityFeatures tests security-related SSH features
func TestSSHServerSecurityFeatures(t *testing.T) {
if testing.Short() {
t.Skip("Skipping SSH security tests in short mode")
}
if !isSSHClientAvailable() {
t.Skip("SSH client not available on this system")
}
// Get appropriate user for SSH connection
username := testutil.GetTestUsername(t)
// Set up SSH server with specific security settings
hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
require.NoError(t, err)
clientPrivKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
require.NoError(t, err)
serverConfig := &Config{
HostKeyPEM: hostKey,
JWT: nil,
}
server := New(serverConfig)
server.SetAllowRootLogin(true)
serverAddr := StartTestServer(t, server)
defer func() {
err := server.Stop()
require.NoError(t, err)
}()
clientKeyFile, cleanupKey := createTempKeyFile(t, clientPrivKey)
defer cleanupKey()
host, portStr, err := net.SplitHostPort(serverAddr)
require.NoError(t, err)
t.Run("key_authentication", func(t *testing.T) {
// Test that key authentication works
cmd := exec.Command("ssh",
"-i", clientKeyFile,
"-p", portStr,
"-o", "StrictHostKeyChecking=no",
"-o", "UserKnownHostsFile=/dev/null",
"-o", "ConnectTimeout=5",
"-o", "PasswordAuthentication=no",
fmt.Sprintf("%s@%s", username, host),
"echo", "auth_success")
output, err := cmd.CombinedOutput()
if err != nil {
t.Logf("Key authentication failed: %v", err)
t.Logf("Output: %s", string(output))
return
}
assert.Contains(t, string(output), "auth_success", "Key authentication should work")
})
t.Run("any_key_accepted_in_no_auth_mode", func(t *testing.T) {
// Create a different key that shouldn't be accepted
wrongKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
require.NoError(t, err)
wrongKeyFile, cleanupWrongKey := createTempKeyFile(t, wrongKey)
defer cleanupWrongKey()
// Test that wrong key is rejected
cmd := exec.Command("ssh",
"-i", wrongKeyFile,
"-p", portStr,
"-o", "StrictHostKeyChecking=no",
"-o", "UserKnownHostsFile=/dev/null",
"-o", "ConnectTimeout=5",
"-o", "PasswordAuthentication=no",
fmt.Sprintf("%s@%s", username, host),
"echo", "should_not_work")
err = cmd.Run()
assert.NoError(t, err, "Any key should work in no-auth mode")
})
}
// TestCrossPlatformCompatibility tests cross-platform behavior
func TestCrossPlatformCompatibility(t *testing.T) {
if testing.Short() {
t.Skip("Skipping cross-platform compatibility tests in short mode")
}
if !isSSHClientAvailable() {
t.Skip("SSH client not available on this system")
}
// Get appropriate user for SSH connection
username := testutil.GetTestUsername(t)
// Set up SSH server
hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
require.NoError(t, err)
clientPrivKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
require.NoError(t, err)
serverConfig := &Config{
HostKeyPEM: hostKey,
JWT: nil,
}
server := New(serverConfig)
server.SetAllowRootLogin(true)
serverAddr := StartTestServer(t, server)
defer func() {
err := server.Stop()
require.NoError(t, err)
}()
clientKeyFile, cleanupKey := createTempKeyFile(t, clientPrivKey)
defer cleanupKey()
host, portStr, err := net.SplitHostPort(serverAddr)
require.NoError(t, err)
// Test platform-specific commands
var testCommand string
switch runtime.GOOS {
case "windows":
testCommand = "echo %OS%"
default:
testCommand = "uname"
}
cmd := exec.Command("ssh",
"-i", clientKeyFile,
"-p", portStr,
"-o", "StrictHostKeyChecking=no",
"-o", "UserKnownHostsFile=/dev/null",
"-o", "ConnectTimeout=5",
fmt.Sprintf("%s@%s", username, host),
testCommand)
output, err := cmd.CombinedOutput()
if err != nil {
t.Logf("Platform-specific command failed: %v", err)
t.Logf("Output: %s", string(output))
return
}
outputStr := strings.TrimSpace(string(output))
t.Logf("Platform command output: %s", outputStr)
assert.NotEmpty(t, outputStr, "Platform-specific command should produce output")
}

View File

@@ -0,0 +1,253 @@
//go:build unix
package server
import (
"context"
"errors"
"fmt"
"os"
"os/exec"
"runtime"
"strings"
"syscall"
log "github.com/sirupsen/logrus"
)
// Exit codes for executor process communication
const (
ExitCodeSuccess = 0
ExitCodePrivilegeDropFail = 10
ExitCodeShellExecFail = 11
ExitCodeValidationFail = 12
)
// ExecutorConfig holds configuration for the executor process
type ExecutorConfig struct {
UID uint32
GID uint32
Groups []uint32
WorkingDir string
Shell string
Command string
PTY bool
}
// PrivilegeDropper handles secure privilege dropping in child processes
type PrivilegeDropper struct{}
// NewPrivilegeDropper creates a new privilege dropper
func NewPrivilegeDropper() *PrivilegeDropper {
return &PrivilegeDropper{}
}
// CreateExecutorCommand creates a command that spawns netbird ssh exec for privilege dropping
func (pd *PrivilegeDropper) CreateExecutorCommand(ctx context.Context, config ExecutorConfig) (*exec.Cmd, error) {
netbirdPath, err := os.Executable()
if err != nil {
return nil, fmt.Errorf("get netbird executable path: %w", err)
}
if err := pd.validatePrivileges(config.UID, config.GID); err != nil {
return nil, fmt.Errorf("invalid privileges: %w", err)
}
args := []string{
"ssh", "exec",
"--uid", fmt.Sprintf("%d", config.UID),
"--gid", fmt.Sprintf("%d", config.GID),
"--working-dir", config.WorkingDir,
"--shell", config.Shell,
}
for _, group := range config.Groups {
args = append(args, "--groups", fmt.Sprintf("%d", group))
}
if config.PTY {
args = append(args, "--pty")
}
if config.Command != "" {
args = append(args, "--cmd", config.Command)
}
// Log executor args safely - show all args except hide the command value
safeArgs := make([]string, len(args))
copy(safeArgs, args)
for i := 0; i < len(safeArgs)-1; i++ {
if safeArgs[i] == "--cmd" {
cmdParts := strings.Fields(safeArgs[i+1])
safeArgs[i+1] = safeLogCommand(cmdParts)
break
}
}
log.Tracef("creating executor command: %s %v", netbirdPath, safeArgs)
return exec.CommandContext(ctx, netbirdPath, args...), nil
}
// DropPrivileges performs privilege dropping with thread locking for security
func (pd *PrivilegeDropper) DropPrivileges(targetUID, targetGID uint32, supplementaryGroups []uint32) error {
if err := pd.validatePrivileges(targetUID, targetGID); err != nil {
return fmt.Errorf("invalid privileges: %w", err)
}
runtime.LockOSThread()
defer runtime.UnlockOSThread()
originalUID := os.Geteuid()
originalGID := os.Getegid()
if originalUID != int(targetUID) || originalGID != int(targetGID) {
if err := pd.setGroupsAndIDs(targetUID, targetGID, supplementaryGroups); err != nil {
return fmt.Errorf("set groups and IDs: %w", err)
}
}
if err := pd.validatePrivilegeDropSuccess(targetUID, targetGID, originalUID, originalGID); err != nil {
return err
}
log.Tracef("successfully dropped privileges to UID=%d, GID=%d", targetUID, targetGID)
return nil
}
// setGroupsAndIDs sets the supplementary groups, GID, and UID
func (pd *PrivilegeDropper) setGroupsAndIDs(targetUID, targetGID uint32, supplementaryGroups []uint32) error {
groups := make([]int, len(supplementaryGroups))
for i, g := range supplementaryGroups {
groups[i] = int(g)
}
if runtime.GOOS == "darwin" || runtime.GOOS == "freebsd" {
if len(groups) == 0 || groups[0] != int(targetGID) {
groups = append([]int{int(targetGID)}, groups...)
}
}
if err := syscall.Setgroups(groups); err != nil {
return fmt.Errorf("setgroups to %v: %w", groups, err)
}
if err := syscall.Setgid(int(targetGID)); err != nil {
return fmt.Errorf("setgid to %d: %w", targetGID, err)
}
if err := syscall.Setuid(int(targetUID)); err != nil {
return fmt.Errorf("setuid to %d: %w", targetUID, err)
}
return nil
}
// validatePrivilegeDropSuccess validates that privilege dropping was successful
func (pd *PrivilegeDropper) validatePrivilegeDropSuccess(targetUID, targetGID uint32, originalUID, originalGID int) error {
if err := pd.validatePrivilegeDropReversibility(targetUID, targetGID, originalUID, originalGID); err != nil {
return err
}
if err := pd.validateCurrentPrivileges(targetUID, targetGID); err != nil {
return err
}
return nil
}
// validatePrivilegeDropReversibility ensures privileges cannot be restored
func (pd *PrivilegeDropper) validatePrivilegeDropReversibility(targetUID, targetGID uint32, originalUID, originalGID int) error {
if originalGID != int(targetGID) {
if err := syscall.Setegid(originalGID); err == nil {
return fmt.Errorf("privilege drop validation failed: able to restore original GID %d", originalGID)
}
}
if originalUID != int(targetUID) {
if err := syscall.Seteuid(originalUID); err == nil {
return fmt.Errorf("privilege drop validation failed: able to restore original UID %d", originalUID)
}
}
return nil
}
// validateCurrentPrivileges validates the current UID and GID match the target
func (pd *PrivilegeDropper) validateCurrentPrivileges(targetUID, targetGID uint32) error {
currentUID := os.Geteuid()
if currentUID != int(targetUID) {
return fmt.Errorf("privilege drop validation failed: current UID %d, expected %d", currentUID, targetUID)
}
currentGID := os.Getegid()
if currentGID != int(targetGID) {
return fmt.Errorf("privilege drop validation failed: current GID %d, expected %d", currentGID, targetGID)
}
return nil
}
// ExecuteWithPrivilegeDrop executes a command with privilege dropping, using exit codes to signal specific failures
func (pd *PrivilegeDropper) ExecuteWithPrivilegeDrop(ctx context.Context, config ExecutorConfig) {
log.Tracef("dropping privileges to UID=%d, GID=%d, groups=%v", config.UID, config.GID, config.Groups)
// TODO: Implement Pty support for executor path
if config.PTY {
config.PTY = false
}
if err := pd.DropPrivileges(config.UID, config.GID, config.Groups); err != nil {
_, _ = fmt.Fprintf(os.Stderr, "privilege drop failed: %v\n", err)
os.Exit(ExitCodePrivilegeDropFail)
}
if config.WorkingDir != "" {
if err := os.Chdir(config.WorkingDir); err != nil {
log.Debugf("failed to change to working directory %s, continuing with current directory: %v", config.WorkingDir, err)
}
}
var execCmd *exec.Cmd
if config.Command == "" {
os.Exit(ExitCodeSuccess)
}
execCmd = exec.CommandContext(ctx, config.Shell, "-c", config.Command)
execCmd.Stdin = os.Stdin
execCmd.Stdout = os.Stdout
execCmd.Stderr = os.Stderr
cmdParts := strings.Fields(config.Command)
safeCmd := safeLogCommand(cmdParts)
log.Tracef("executing %s -c %s", execCmd.Path, safeCmd)
if err := execCmd.Run(); err != nil {
var exitError *exec.ExitError
if errors.As(err, &exitError) {
// Normal command exit with non-zero code - not an SSH execution error
log.Tracef("command exited with code %d", exitError.ExitCode())
os.Exit(exitError.ExitCode())
}
// Actual execution failure (command not found, permission denied, etc.)
log.Debugf("command execution failed: %v", err)
os.Exit(ExitCodeShellExecFail)
}
os.Exit(ExitCodeSuccess)
}
// validatePrivileges validates that privilege dropping to the target UID/GID is allowed
func (pd *PrivilegeDropper) validatePrivileges(uid, gid uint32) error {
currentUID := uint32(os.Geteuid())
currentGID := uint32(os.Getegid())
// Allow same-user operations (no privilege dropping needed)
if uid == currentUID && gid == currentGID {
return nil
}
// Only root can drop privileges to other users
if currentUID != 0 {
return fmt.Errorf("cannot drop privileges from non-root user (UID %d) to UID %d", currentUID, uid)
}
// Root can drop to any user (including root itself)
return nil
}

View File

@@ -0,0 +1,262 @@
//go:build unix
package server
import (
"context"
"fmt"
"os"
"os/exec"
"os/user"
"strconv"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestPrivilegeDropper_ValidatePrivileges(t *testing.T) {
pd := NewPrivilegeDropper()
currentUID := uint32(os.Geteuid())
currentGID := uint32(os.Getegid())
tests := []struct {
name string
uid uint32
gid uint32
wantErr bool
}{
{
name: "same user - no privilege drop needed",
uid: currentUID,
gid: currentGID,
wantErr: false,
},
{
name: "non-root to different user should fail",
uid: currentUID + 1, // Use a different UID to ensure it's actually different
gid: currentGID + 1, // Use a different GID to ensure it's actually different
wantErr: currentUID != 0, // Only fail if current user is not root
},
{
name: "root can drop to any user",
uid: 1000,
gid: 1000,
wantErr: false,
},
{
name: "root can stay as root",
uid: 0,
gid: 0,
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Skip non-root tests when running as root, and root tests when not root
if tt.name == "non-root to different user should fail" && currentUID == 0 {
t.Skip("Skipping non-root test when running as root")
}
if (tt.name == "root can drop to any user" || tt.name == "root can stay as root") && currentUID != 0 {
t.Skip("Skipping root test when not running as root")
}
err := pd.validatePrivileges(tt.uid, tt.gid)
if tt.wantErr {
assert.Error(t, err)
} else {
assert.NoError(t, err)
}
})
}
}
func TestPrivilegeDropper_CreateExecutorCommand(t *testing.T) {
pd := NewPrivilegeDropper()
config := ExecutorConfig{
UID: 1000,
GID: 1000,
Groups: []uint32{1000, 1001},
WorkingDir: "/home/testuser",
Shell: "/bin/bash",
Command: "ls -la",
}
cmd, err := pd.CreateExecutorCommand(context.Background(), config)
require.NoError(t, err)
require.NotNil(t, cmd)
// Verify the command is calling netbird ssh exec
assert.Contains(t, cmd.Args, "ssh")
assert.Contains(t, cmd.Args, "exec")
assert.Contains(t, cmd.Args, "--uid")
assert.Contains(t, cmd.Args, "1000")
assert.Contains(t, cmd.Args, "--gid")
assert.Contains(t, cmd.Args, "1000")
assert.Contains(t, cmd.Args, "--groups")
assert.Contains(t, cmd.Args, "1000")
assert.Contains(t, cmd.Args, "1001")
assert.Contains(t, cmd.Args, "--working-dir")
assert.Contains(t, cmd.Args, "/home/testuser")
assert.Contains(t, cmd.Args, "--shell")
assert.Contains(t, cmd.Args, "/bin/bash")
assert.Contains(t, cmd.Args, "--cmd")
assert.Contains(t, cmd.Args, "ls -la")
}
func TestPrivilegeDropper_CreateExecutorCommandInteractive(t *testing.T) {
pd := NewPrivilegeDropper()
config := ExecutorConfig{
UID: 1000,
GID: 1000,
Groups: []uint32{1000},
WorkingDir: "/home/testuser",
Shell: "/bin/bash",
Command: "",
}
cmd, err := pd.CreateExecutorCommand(context.Background(), config)
require.NoError(t, err)
require.NotNil(t, cmd)
// Verify no command mode (command is empty so no --cmd flag)
assert.NotContains(t, cmd.Args, "--cmd")
assert.NotContains(t, cmd.Args, "--interactive")
}
// TestPrivilegeDropper_ActualPrivilegeDrop tests actual privilege dropping
// This test requires root privileges and will be skipped if not running as root
func TestPrivilegeDropper_ActualPrivilegeDrop(t *testing.T) {
if os.Geteuid() != 0 {
t.Skip("This test requires root privileges")
}
// Find a non-root user to test with
testUser, err := findNonRootUser()
if err != nil {
t.Skip("No suitable non-root user found for testing")
}
// Verify the user actually exists by looking it up again
_, err = user.LookupId(testUser.Uid)
if err != nil {
t.Skipf("Test user %s (UID %s) does not exist on this system: %v", testUser.Username, testUser.Uid, err)
}
uid64, err := strconv.ParseUint(testUser.Uid, 10, 32)
require.NoError(t, err)
targetUID := uint32(uid64)
gid64, err := strconv.ParseUint(testUser.Gid, 10, 32)
require.NoError(t, err)
targetGID := uint32(gid64)
// Test in a child process to avoid affecting the test runner
if os.Getenv("TEST_PRIVILEGE_DROP") == "1" {
pd := NewPrivilegeDropper()
// This should succeed
err := pd.DropPrivileges(targetUID, targetGID, []uint32{targetGID})
require.NoError(t, err)
// Verify we are now running as the target user
currentUID := uint32(os.Geteuid())
currentGID := uint32(os.Getegid())
assert.Equal(t, targetUID, currentUID, "UID should match target")
assert.Equal(t, targetGID, currentGID, "GID should match target")
assert.NotEqual(t, uint32(0), currentUID, "Should not be running as root")
assert.NotEqual(t, uint32(0), currentGID, "Should not be running as root group")
return
}
// Fork a child process to test privilege dropping
cmd := os.Args[0]
args := []string{"-test.run=TestPrivilegeDropper_ActualPrivilegeDrop"}
env := append(os.Environ(), "TEST_PRIVILEGE_DROP=1")
execCmd := exec.Command(cmd, args...)
execCmd.Env = env
err = execCmd.Run()
require.NoError(t, err, "Child process should succeed")
}
// findNonRootUser finds any non-root user on the system for testing
func findNonRootUser() (*user.User, error) {
// Try common non-root users, but avoid "nobody" on macOS due to negative UID issues
commonUsers := []string{"daemon", "bin", "sys", "sync", "games", "man", "lp", "mail", "news", "uucp", "proxy", "www-data", "backup", "list", "irc"}
for _, username := range commonUsers {
if u, err := user.Lookup(username); err == nil {
// Parse as signed integer first to handle negative UIDs
uid64, err := strconv.ParseInt(u.Uid, 10, 32)
if err != nil {
continue
}
// Skip negative UIDs (like nobody=-2 on macOS) and root
if uid64 > 0 && uid64 != 0 {
return u, nil
}
}
}
// If no common users found, try to find any regular user with UID > 100
// This helps on macOS where regular users start at UID 501
allUsers := []string{"vma", "user", "test", "admin"}
for _, username := range allUsers {
if u, err := user.Lookup(username); err == nil {
uid64, err := strconv.ParseInt(u.Uid, 10, 32)
if err != nil {
continue
}
if uid64 > 100 { // Regular user
return u, nil
}
}
}
// If no common users found, return an error
return nil, fmt.Errorf("no suitable non-root user found on this system")
}
func TestPrivilegeDropper_ExecuteWithPrivilegeDrop_Validation(t *testing.T) {
pd := NewPrivilegeDropper()
currentUID := uint32(os.Geteuid())
if currentUID == 0 {
// When running as root, test that root can create commands for any user
config := ExecutorConfig{
UID: 1000, // Target non-root user
GID: 1000,
Groups: []uint32{1000},
WorkingDir: "/tmp",
Shell: "/bin/sh",
Command: "echo test",
}
cmd, err := pd.CreateExecutorCommand(context.Background(), config)
assert.NoError(t, err, "Root should be able to create commands for any user")
assert.NotNil(t, cmd)
} else {
// When running as non-root, test that we can't drop to a different user
config := ExecutorConfig{
UID: 0, // Try to target root
GID: 0,
Groups: []uint32{0},
WorkingDir: "/tmp",
Shell: "/bin/sh",
Command: "echo test",
}
_, err := pd.CreateExecutorCommand(context.Background(), config)
assert.Error(t, err)
assert.Contains(t, err.Error(), "cannot drop privileges")
}
}

View File

@@ -0,0 +1,570 @@
//go:build windows
package server
import (
"context"
"errors"
"fmt"
"os"
"os/exec"
"os/user"
"strings"
"syscall"
"unsafe"
"github.com/gliderlabs/ssh"
log "github.com/sirupsen/logrus"
"golang.org/x/sys/windows"
)
const (
ExitCodeSuccess = 0
ExitCodeLogonFail = 10
ExitCodeCreateProcessFail = 11
ExitCodeWorkingDirFail = 12
ExitCodeShellExecFail = 13
ExitCodeValidationFail = 14
)
type WindowsExecutorConfig struct {
Username string
Domain string
WorkingDir string
Shell string
Command string
Args []string
Interactive bool
Pty bool
PtyWidth int
PtyHeight int
}
type PrivilegeDropper struct{}
func NewPrivilegeDropper() *PrivilegeDropper {
return &PrivilegeDropper{}
}
var (
advapi32 = windows.NewLazyDLL("advapi32.dll")
procAllocateLocallyUniqueId = advapi32.NewProc("AllocateLocallyUniqueId")
)
const (
logon32LogonNetwork = 3 // Network logon - no password required for authenticated users
// Common error messages
commandFlag = "-Command"
closeTokenErrorMsg = "close token error: %v" // #nosec G101 -- This is an error message template, not credentials
convertUsernameError = "convert username to UTF16: %w"
convertDomainError = "convert domain to UTF16: %w"
)
// CreateWindowsExecutorCommand creates a Windows command with privilege dropping.
// The caller must close the returned token handle after starting the process.
func (pd *PrivilegeDropper) CreateWindowsExecutorCommand(ctx context.Context, config WindowsExecutorConfig) (*exec.Cmd, windows.Token, error) {
if config.Username == "" {
return nil, 0, errors.New("username cannot be empty")
}
if config.Shell == "" {
return nil, 0, errors.New("shell cannot be empty")
}
shell := config.Shell
var shellArgs []string
if config.Command != "" {
shellArgs = []string{shell, commandFlag, config.Command}
} else {
shellArgs = []string{shell}
}
log.Tracef("creating Windows direct shell command: %s %v", shellArgs[0], shellArgs)
cmd, token, err := pd.CreateWindowsProcessAsUser(
ctx, shellArgs[0], shellArgs, config.Username, config.Domain, config.WorkingDir)
if err != nil {
return nil, 0, fmt.Errorf("create Windows process as user: %w", err)
}
return cmd, token, nil
}
const (
// StatusSuccess represents successful LSA operation
StatusSuccess = 0
// KerbS4ULogonType message type for domain users with Kerberos
KerbS4ULogonType = 12
// Msv10s4ulogontype message type for local users with MSV1_0
Msv10s4ulogontype = 12
// MicrosoftKerberosNameA is the authentication package name for Kerberos
MicrosoftKerberosNameA = "Kerberos"
// Msv10packagename is the authentication package name for MSV1_0
Msv10packagename = "MICROSOFT_AUTHENTICATION_PACKAGE_V1_0"
NameSamCompatible = 2
NameUserPrincipal = 8
NameCanonical = 7
maxUPNLen = 1024
)
// kerbS4ULogon structure for S4U authentication (domain users)
type kerbS4ULogon struct {
MessageType uint32
Flags uint32
ClientUpn unicodeString
ClientRealm unicodeString
}
// msv10s4ulogon structure for S4U authentication (local users)
type msv10s4ulogon struct {
MessageType uint32
Flags uint32
UserPrincipalName unicodeString
DomainName unicodeString
}
// unicodeString structure
type unicodeString struct {
Length uint16
MaximumLength uint16
Buffer *uint16
}
// lsaString structure
type lsaString struct {
Length uint16
MaximumLength uint16
Buffer *byte
}
// tokenSource structure
type tokenSource struct {
SourceName [8]byte
SourceIdentifier windows.LUID
}
// quotaLimits structure
type quotaLimits struct {
PagedPoolLimit uint32
NonPagedPoolLimit uint32
MinimumWorkingSetSize uint32
MaximumWorkingSetSize uint32
PagefileLimit uint32
TimeLimit int64
}
var (
secur32 = windows.NewLazyDLL("secur32.dll")
procLsaRegisterLogonProcess = secur32.NewProc("LsaRegisterLogonProcess")
procLsaLookupAuthenticationPackage = secur32.NewProc("LsaLookupAuthenticationPackage")
procLsaLogonUser = secur32.NewProc("LsaLogonUser")
procLsaFreeReturnBuffer = secur32.NewProc("LsaFreeReturnBuffer")
procLsaDeregisterLogonProcess = secur32.NewProc("LsaDeregisterLogonProcess")
procTranslateNameW = secur32.NewProc("TranslateNameW")
)
// newLsaString creates an LsaString from a Go string
func newLsaString(s string) lsaString {
b := append([]byte(s), 0)
return lsaString{
Length: uint16(len(s)),
MaximumLength: uint16(len(b)),
Buffer: &b[0],
}
}
// generateS4UUserToken creates a Windows token using S4U authentication
// This is the exact approach OpenSSH for Windows uses for public key authentication
func generateS4UUserToken(username, domain string) (windows.Handle, error) {
userCpn := buildUserCpn(username, domain)
pd := NewPrivilegeDropper()
isDomainUser := !pd.isLocalUser(domain)
lsaHandle, err := initializeLsaConnection()
if err != nil {
return 0, err
}
defer cleanupLsaConnection(lsaHandle)
authPackageId, err := lookupAuthenticationPackage(lsaHandle, isDomainUser)
if err != nil {
return 0, err
}
logonInfo, logonInfoSize, err := prepareS4ULogonStructure(username, domain, isDomainUser)
if err != nil {
return 0, err
}
return performS4ULogon(lsaHandle, authPackageId, logonInfo, logonInfoSize, userCpn, isDomainUser)
}
// buildUserCpn constructs the user principal name
func buildUserCpn(username, domain string) string {
if domain != "" && domain != "." {
return fmt.Sprintf(`%s\%s`, domain, username)
}
return username
}
// initializeLsaConnection establishes connection to LSA
func initializeLsaConnection() (windows.Handle, error) {
processName := newLsaString("NetBird")
var mode uint32
var lsaHandle windows.Handle
ret, _, _ := procLsaRegisterLogonProcess.Call(
uintptr(unsafe.Pointer(&processName)),
uintptr(unsafe.Pointer(&lsaHandle)),
uintptr(unsafe.Pointer(&mode)),
)
if ret != StatusSuccess {
return 0, fmt.Errorf("LsaRegisterLogonProcess: 0x%x", ret)
}
return lsaHandle, nil
}
// cleanupLsaConnection closes the LSA connection
func cleanupLsaConnection(lsaHandle windows.Handle) {
if ret, _, _ := procLsaDeregisterLogonProcess.Call(uintptr(lsaHandle)); ret != StatusSuccess {
log.Debugf("LsaDeregisterLogonProcess failed: 0x%x", ret)
}
}
// lookupAuthenticationPackage finds the correct authentication package
func lookupAuthenticationPackage(lsaHandle windows.Handle, isDomainUser bool) (uint32, error) {
var authPackageName lsaString
if isDomainUser {
authPackageName = newLsaString(MicrosoftKerberosNameA)
} else {
authPackageName = newLsaString(Msv10packagename)
}
var authPackageId uint32
ret, _, _ := procLsaLookupAuthenticationPackage.Call(
uintptr(lsaHandle),
uintptr(unsafe.Pointer(&authPackageName)),
uintptr(unsafe.Pointer(&authPackageId)),
)
if ret != StatusSuccess {
return 0, fmt.Errorf("LsaLookupAuthenticationPackage: 0x%x", ret)
}
return authPackageId, nil
}
// lookupPrincipalName converts DOMAIN\username to username@domain.fqdn (UPN format)
func lookupPrincipalName(username, domain string) (string, error) {
samAccountName := fmt.Sprintf(`%s\%s`, domain, username)
samAccountNameUtf16, err := windows.UTF16PtrFromString(samAccountName)
if err != nil {
return "", fmt.Errorf("convert SAM account name to UTF-16: %w", err)
}
upnBuf := make([]uint16, maxUPNLen+1)
upnSize := uint32(len(upnBuf))
ret, _, _ := procTranslateNameW.Call(
uintptr(unsafe.Pointer(samAccountNameUtf16)),
uintptr(NameSamCompatible),
uintptr(NameUserPrincipal),
uintptr(unsafe.Pointer(&upnBuf[0])),
uintptr(unsafe.Pointer(&upnSize)),
)
if ret != 0 {
upn := windows.UTF16ToString(upnBuf[:upnSize])
log.Debugf("Translated %s to explicit UPN: %s", samAccountName, upn)
return upn, nil
}
upnSize = uint32(len(upnBuf))
ret, _, _ = procTranslateNameW.Call(
uintptr(unsafe.Pointer(samAccountNameUtf16)),
uintptr(NameSamCompatible),
uintptr(NameCanonical),
uintptr(unsafe.Pointer(&upnBuf[0])),
uintptr(unsafe.Pointer(&upnSize)),
)
if ret != 0 {
canonical := windows.UTF16ToString(upnBuf[:upnSize])
slashIdx := strings.IndexByte(canonical, '/')
if slashIdx > 0 {
fqdn := canonical[:slashIdx]
upn := fmt.Sprintf("%s@%s", username, fqdn)
log.Debugf("Translated %s to implicit UPN: %s (from canonical: %s)", samAccountName, upn, canonical)
return upn, nil
}
}
log.Debugf("Could not translate %s to UPN, using SAM format", samAccountName)
return samAccountName, nil
}
// prepareS4ULogonStructure creates the appropriate S4U logon structure
func prepareS4ULogonStructure(username, domain string, isDomainUser bool) (unsafe.Pointer, uintptr, error) {
if isDomainUser {
return prepareDomainS4ULogon(username, domain)
}
return prepareLocalS4ULogon(username)
}
// prepareDomainS4ULogon creates S4U logon structure for domain users
func prepareDomainS4ULogon(username, domain string) (unsafe.Pointer, uintptr, error) {
upn, err := lookupPrincipalName(username, domain)
if err != nil {
return nil, 0, fmt.Errorf("lookup principal name: %w", err)
}
log.Debugf("using KerbS4ULogon for domain user with UPN: %s", upn)
upnUtf16, err := windows.UTF16FromString(upn)
if err != nil {
return nil, 0, fmt.Errorf(convertUsernameError, err)
}
structSize := unsafe.Sizeof(kerbS4ULogon{})
upnByteSize := len(upnUtf16) * 2
logonInfoSize := structSize + uintptr(upnByteSize)
buffer := make([]byte, logonInfoSize)
logonInfo := unsafe.Pointer(&buffer[0])
s4uLogon := (*kerbS4ULogon)(logonInfo)
s4uLogon.MessageType = KerbS4ULogonType
s4uLogon.Flags = 0
upnOffset := structSize
upnBuffer := (*uint16)(unsafe.Pointer(uintptr(logonInfo) + upnOffset))
copy((*[1025]uint16)(unsafe.Pointer(upnBuffer))[:len(upnUtf16)], upnUtf16)
s4uLogon.ClientUpn = unicodeString{
Length: uint16((len(upnUtf16) - 1) * 2),
MaximumLength: uint16(len(upnUtf16) * 2),
Buffer: upnBuffer,
}
s4uLogon.ClientRealm = unicodeString{}
return logonInfo, logonInfoSize, nil
}
// prepareLocalS4ULogon creates S4U logon structure for local users
func prepareLocalS4ULogon(username string) (unsafe.Pointer, uintptr, error) {
log.Debugf("using Msv1_0S4ULogon for local user: %s", username)
usernameUtf16, err := windows.UTF16FromString(username)
if err != nil {
return nil, 0, fmt.Errorf(convertUsernameError, err)
}
domainUtf16, err := windows.UTF16FromString(".")
if err != nil {
return nil, 0, fmt.Errorf(convertDomainError, err)
}
structSize := unsafe.Sizeof(msv10s4ulogon{})
usernameByteSize := len(usernameUtf16) * 2
domainByteSize := len(domainUtf16) * 2
logonInfoSize := structSize + uintptr(usernameByteSize) + uintptr(domainByteSize)
buffer := make([]byte, logonInfoSize)
logonInfo := unsafe.Pointer(&buffer[0])
s4uLogon := (*msv10s4ulogon)(logonInfo)
s4uLogon.MessageType = Msv10s4ulogontype
s4uLogon.Flags = 0x0
usernameOffset := structSize
usernameBuffer := (*uint16)(unsafe.Pointer(uintptr(logonInfo) + usernameOffset))
copy((*[256]uint16)(unsafe.Pointer(usernameBuffer))[:len(usernameUtf16)], usernameUtf16)
s4uLogon.UserPrincipalName = unicodeString{
Length: uint16((len(usernameUtf16) - 1) * 2),
MaximumLength: uint16(len(usernameUtf16) * 2),
Buffer: usernameBuffer,
}
domainOffset := usernameOffset + uintptr(usernameByteSize)
domainBuffer := (*uint16)(unsafe.Pointer(uintptr(logonInfo) + domainOffset))
copy((*[16]uint16)(unsafe.Pointer(domainBuffer))[:len(domainUtf16)], domainUtf16)
s4uLogon.DomainName = unicodeString{
Length: uint16((len(domainUtf16) - 1) * 2),
MaximumLength: uint16(len(domainUtf16) * 2),
Buffer: domainBuffer,
}
return logonInfo, logonInfoSize, nil
}
// performS4ULogon executes the S4U logon operation
func performS4ULogon(lsaHandle windows.Handle, authPackageId uint32, logonInfo unsafe.Pointer, logonInfoSize uintptr, userCpn string, isDomainUser bool) (windows.Handle, error) {
var tokenSource tokenSource
copy(tokenSource.SourceName[:], "netbird")
if ret, _, _ := procAllocateLocallyUniqueId.Call(uintptr(unsafe.Pointer(&tokenSource.SourceIdentifier))); ret == 0 {
log.Debugf("AllocateLocallyUniqueId failed")
}
originName := newLsaString("netbird")
var profile uintptr
var profileSize uint32
var logonId windows.LUID
var token windows.Handle
var quotas quotaLimits
var subStatus int32
ret, _, _ := procLsaLogonUser.Call(
uintptr(lsaHandle),
uintptr(unsafe.Pointer(&originName)),
logon32LogonNetwork,
uintptr(authPackageId),
uintptr(logonInfo),
logonInfoSize,
0,
uintptr(unsafe.Pointer(&tokenSource)),
uintptr(unsafe.Pointer(&profile)),
uintptr(unsafe.Pointer(&profileSize)),
uintptr(unsafe.Pointer(&logonId)),
uintptr(unsafe.Pointer(&token)),
uintptr(unsafe.Pointer(&quotas)),
uintptr(unsafe.Pointer(&subStatus)),
)
if profile != 0 {
if ret, _, _ := procLsaFreeReturnBuffer.Call(profile); ret != StatusSuccess {
log.Debugf("LsaFreeReturnBuffer failed: 0x%x", ret)
}
}
if ret != StatusSuccess {
return 0, fmt.Errorf("LsaLogonUser S4U for %s: NTSTATUS=0x%x, SubStatus=0x%x", userCpn, ret, subStatus)
}
log.Debugf("created S4U %s token for user %s",
map[bool]string{true: "domain", false: "local"}[isDomainUser], userCpn)
return token, nil
}
// createToken implements NetBird trust-based authentication using S4U
func (pd *PrivilegeDropper) createToken(username, domain string) (windows.Handle, error) {
fullUsername := buildUserCpn(username, domain)
if err := userExists(fullUsername, username, domain); err != nil {
return 0, err
}
isLocalUser := pd.isLocalUser(domain)
if isLocalUser {
return pd.authenticateLocalUser(username, fullUsername)
}
return pd.authenticateDomainUser(username, domain, fullUsername)
}
// userExists checks if the target useVerifier exists on the system
func userExists(fullUsername, username, domain string) error {
if _, err := lookupUser(fullUsername); err != nil {
log.Debugf("User %s not found: %v", fullUsername, err)
if domain != "" && domain != "." {
_, err = lookupUser(username)
}
if err != nil {
return fmt.Errorf("target user %s not found: %w", fullUsername, err)
}
}
return nil
}
// isLocalUser determines if this is a local user vs domain user
func (pd *PrivilegeDropper) isLocalUser(domain string) bool {
hostname, err := os.Hostname()
if err != nil {
hostname = "localhost"
}
return domain == "" || domain == "." ||
strings.EqualFold(domain, hostname)
}
// authenticateLocalUser handles authentication for local users
func (pd *PrivilegeDropper) authenticateLocalUser(username, fullUsername string) (windows.Handle, error) {
log.Debugf("using S4U authentication for local user %s", fullUsername)
token, err := generateS4UUserToken(username, ".")
if err != nil {
return 0, fmt.Errorf("S4U authentication for local user %s: %w", fullUsername, err)
}
return token, nil
}
// authenticateDomainUser handles authentication for domain users
func (pd *PrivilegeDropper) authenticateDomainUser(username, domain, fullUsername string) (windows.Handle, error) {
log.Debugf("using S4U authentication for domain user %s", fullUsername)
token, err := generateS4UUserToken(username, domain)
if err != nil {
return 0, fmt.Errorf("S4U authentication for domain user %s: %w", fullUsername, err)
}
log.Debugf("Successfully created S4U token for domain user %s", fullUsername)
return token, nil
}
// CreateWindowsProcessAsUser creates a process as user with safe argument passing (for SFTP and executables).
// The caller must close the returned token handle after starting the process.
func (pd *PrivilegeDropper) CreateWindowsProcessAsUser(ctx context.Context, executablePath string, args []string, username, domain, workingDir string) (*exec.Cmd, windows.Token, error) {
token, err := pd.createToken(username, domain)
if err != nil {
return nil, 0, fmt.Errorf("user authentication: %w", err)
}
defer func() {
if err := windows.CloseHandle(token); err != nil {
log.Debugf("close impersonation token: %v", err)
}
}()
cmd, primaryToken, err := pd.createProcessWithToken(ctx, windows.Token(token), executablePath, args, workingDir)
if err != nil {
return nil, 0, err
}
return cmd, primaryToken, nil
}
// createProcessWithToken creates process with the specified token and executable path.
// The caller must close the returned token handle after starting the process.
func (pd *PrivilegeDropper) createProcessWithToken(ctx context.Context, sourceToken windows.Token, executablePath string, args []string, workingDir string) (*exec.Cmd, windows.Token, error) {
cmd := exec.CommandContext(ctx, executablePath, args[1:]...)
cmd.Dir = workingDir
var primaryToken windows.Token
err := windows.DuplicateTokenEx(
sourceToken,
windows.TOKEN_ALL_ACCESS,
nil,
windows.SecurityIdentification,
windows.TokenPrimary,
&primaryToken,
)
if err != nil {
return nil, 0, fmt.Errorf("duplicate token to primary token: %w", err)
}
cmd.SysProcAttr = &syscall.SysProcAttr{
Token: syscall.Token(primaryToken),
}
return cmd, primaryToken, nil
}
// createSuCommand creates a command using su -l -c for privilege switching (Windows stub)
func (s *Server) createSuCommand(ssh.Session, *user.User, bool) (*exec.Cmd, error) {
return nil, fmt.Errorf("su command not available on Windows")
}

View File

@@ -0,0 +1,629 @@
package server
import (
"context"
"crypto/rand"
"crypto/rsa"
"encoding/base64"
"encoding/json"
"io"
"math/big"
"net"
"net/http"
"net/http/httptest"
"runtime"
"strconv"
"testing"
"time"
"github.com/golang-jwt/jwt/v5"
cryptossh "golang.org/x/crypto/ssh"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
nbssh "github.com/netbirdio/netbird/client/ssh"
"github.com/netbirdio/netbird/client/ssh/client"
"github.com/netbirdio/netbird/client/ssh/detection"
"github.com/netbirdio/netbird/client/ssh/testutil"
nbjwt "github.com/netbirdio/netbird/shared/auth/jwt"
)
func TestJWTEnforcement(t *testing.T) {
if testing.Short() {
t.Skip("Skipping JWT enforcement tests in short mode")
}
// Set up SSH server
hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
require.NoError(t, err)
t.Run("blocks_without_jwt", func(t *testing.T) {
jwtConfig := &JWTConfig{
Issuer: "test-issuer",
Audience: "test-audience",
KeysLocation: "test-keys",
}
serverConfig := &Config{
HostKeyPEM: hostKey,
JWT: jwtConfig,
}
server := New(serverConfig)
server.SetAllowRootLogin(true)
serverAddr := StartTestServer(t, server)
defer require.NoError(t, server.Stop())
host, portStr, err := net.SplitHostPort(serverAddr)
require.NoError(t, err)
port, err := strconv.Atoi(portStr)
require.NoError(t, err)
dialer := &net.Dialer{Timeout: detection.Timeout}
serverType, err := detection.DetectSSHServerType(context.Background(), dialer, host, port)
if err != nil {
t.Logf("Detection failed: %v", err)
}
t.Logf("Detected server type: %s", serverType)
config := &cryptossh.ClientConfig{
User: testutil.GetTestUsername(t),
Auth: []cryptossh.AuthMethod{},
HostKeyCallback: cryptossh.InsecureIgnoreHostKey(),
Timeout: 2 * time.Second,
}
_, err = cryptossh.Dial("tcp", net.JoinHostPort(host, portStr), config)
assert.Error(t, err, "SSH connection should fail when JWT is required but not provided")
})
t.Run("allows_when_disabled", func(t *testing.T) {
serverConfigNoJWT := &Config{
HostKeyPEM: hostKey,
JWT: nil,
}
serverNoJWT := New(serverConfigNoJWT)
require.False(t, serverNoJWT.jwtEnabled, "JWT should be disabled without config")
serverNoJWT.SetAllowRootLogin(true)
serverAddrNoJWT := StartTestServer(t, serverNoJWT)
defer require.NoError(t, serverNoJWT.Stop())
hostNoJWT, portStrNoJWT, err := net.SplitHostPort(serverAddrNoJWT)
require.NoError(t, err)
portNoJWT, err := strconv.Atoi(portStrNoJWT)
require.NoError(t, err)
dialer := &net.Dialer{Timeout: detection.Timeout}
serverType, err := detection.DetectSSHServerType(context.Background(), dialer, hostNoJWT, portNoJWT)
require.NoError(t, err)
assert.Equal(t, detection.ServerTypeNetBirdNoJWT, serverType)
assert.False(t, serverType.RequiresJWT())
client, err := connectWithNetBirdClient(t, hostNoJWT, portNoJWT)
require.NoError(t, err)
defer client.Close()
})
}
// setupJWKSServer creates a test HTTP server serving JWKS and returns the server, private key, and URL
func setupJWKSServer(t *testing.T) (*httptest.Server, *rsa.PrivateKey, string) {
privateKey, jwksJSON := generateTestJWKS(t)
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
if _, err := w.Write(jwksJSON); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
}
}))
return server, privateKey, server.URL
}
// generateTestJWKS creates a test RSA key pair and returns private key and JWKS JSON
func generateTestJWKS(t *testing.T) (*rsa.PrivateKey, []byte) {
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
require.NoError(t, err)
publicKey := &privateKey.PublicKey
n := publicKey.N.Bytes()
e := publicKey.E
jwk := nbjwt.JSONWebKey{
Kty: "RSA",
Kid: "test-key-id",
Use: "sig",
N: base64RawURLEncode(n),
E: base64RawURLEncode(big.NewInt(int64(e)).Bytes()),
}
jwks := nbjwt.Jwks{
Keys: []nbjwt.JSONWebKey{jwk},
}
jwksJSON, err := json.Marshal(jwks)
require.NoError(t, err)
return privateKey, jwksJSON
}
func base64RawURLEncode(data []byte) string {
return base64.RawURLEncoding.EncodeToString(data)
}
// generateValidJWT creates a valid JWT token for testing
func generateValidJWT(t *testing.T, privateKey *rsa.PrivateKey, issuer, audience string) string {
claims := jwt.MapClaims{
"iss": issuer,
"aud": audience,
"sub": "test-user",
"exp": time.Now().Add(time.Hour).Unix(),
"iat": time.Now().Unix(),
}
token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims)
token.Header["kid"] = "test-key-id"
tokenString, err := token.SignedString(privateKey)
require.NoError(t, err)
return tokenString
}
// connectWithNetBirdClient connects to SSH server using NetBird's SSH client
func connectWithNetBirdClient(t *testing.T, host string, port int) (*client.Client, error) {
t.Helper()
addr := net.JoinHostPort(host, strconv.Itoa(port))
ctx := context.Background()
return client.Dial(ctx, addr, testutil.GetTestUsername(t), client.DialOptions{
InsecureSkipVerify: true,
})
}
// TestJWTDetection tests that server detection correctly identifies JWT-enabled servers
func TestJWTDetection(t *testing.T) {
if testing.Short() {
t.Skip("Skipping JWT detection test in short mode")
}
jwksServer, _, jwksURL := setupJWKSServer(t)
defer jwksServer.Close()
hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
require.NoError(t, err)
const (
issuer = "https://test-issuer.example.com"
audience = "test-audience"
)
jwtConfig := &JWTConfig{
Issuer: issuer,
Audience: audience,
KeysLocation: jwksURL,
}
serverConfig := &Config{
HostKeyPEM: hostKey,
JWT: jwtConfig,
}
server := New(serverConfig)
server.SetAllowRootLogin(true)
serverAddr := StartTestServer(t, server)
defer require.NoError(t, server.Stop())
host, portStr, err := net.SplitHostPort(serverAddr)
require.NoError(t, err)
port, err := strconv.Atoi(portStr)
require.NoError(t, err)
dialer := &net.Dialer{Timeout: detection.Timeout}
serverType, err := detection.DetectSSHServerType(context.Background(), dialer, host, port)
require.NoError(t, err)
assert.Equal(t, detection.ServerTypeNetBirdJWT, serverType)
assert.True(t, serverType.RequiresJWT())
}
func TestJWTFailClose(t *testing.T) {
if testing.Short() {
t.Skip("Skipping JWT fail-close tests in short mode")
}
jwksServer, privateKey, jwksURL := setupJWKSServer(t)
defer jwksServer.Close()
const (
issuer = "https://test-issuer.example.com"
audience = "test-audience"
)
hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
require.NoError(t, err)
testCases := []struct {
name string
tokenClaims jwt.MapClaims
}{
{
name: "blocks_token_missing_iat",
tokenClaims: jwt.MapClaims{
"iss": issuer,
"aud": audience,
"sub": "test-user",
"exp": time.Now().Add(time.Hour).Unix(),
},
},
{
name: "blocks_token_missing_sub",
tokenClaims: jwt.MapClaims{
"iss": issuer,
"aud": audience,
"exp": time.Now().Add(time.Hour).Unix(),
"iat": time.Now().Unix(),
},
},
{
name: "blocks_token_missing_iss",
tokenClaims: jwt.MapClaims{
"aud": audience,
"sub": "test-user",
"exp": time.Now().Add(time.Hour).Unix(),
"iat": time.Now().Unix(),
},
},
{
name: "blocks_token_missing_aud",
tokenClaims: jwt.MapClaims{
"iss": issuer,
"sub": "test-user",
"exp": time.Now().Add(time.Hour).Unix(),
"iat": time.Now().Unix(),
},
},
{
name: "blocks_token_wrong_issuer",
tokenClaims: jwt.MapClaims{
"iss": "wrong-issuer",
"aud": audience,
"sub": "test-user",
"exp": time.Now().Add(time.Hour).Unix(),
"iat": time.Now().Unix(),
},
},
{
name: "blocks_token_wrong_audience",
tokenClaims: jwt.MapClaims{
"iss": issuer,
"aud": "wrong-audience",
"sub": "test-user",
"exp": time.Now().Add(time.Hour).Unix(),
"iat": time.Now().Unix(),
},
},
{
name: "blocks_expired_token",
tokenClaims: jwt.MapClaims{
"iss": issuer,
"aud": audience,
"sub": "test-user",
"exp": time.Now().Add(-time.Hour).Unix(),
"iat": time.Now().Add(-2 * time.Hour).Unix(),
},
},
{
name: "blocks_token_exceeding_max_age",
tokenClaims: jwt.MapClaims{
"iss": issuer,
"aud": audience,
"sub": "test-user",
"exp": time.Now().Add(time.Hour).Unix(),
"iat": time.Now().Add(-2 * time.Hour).Unix(),
},
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
jwtConfig := &JWTConfig{
Issuer: issuer,
Audience: audience,
KeysLocation: jwksURL,
MaxTokenAge: 3600,
}
serverConfig := &Config{
HostKeyPEM: hostKey,
JWT: jwtConfig,
}
server := New(serverConfig)
server.SetAllowRootLogin(true)
serverAddr := StartTestServer(t, server)
defer require.NoError(t, server.Stop())
host, portStr, err := net.SplitHostPort(serverAddr)
require.NoError(t, err)
token := jwt.NewWithClaims(jwt.SigningMethodRS256, tc.tokenClaims)
token.Header["kid"] = "test-key-id"
tokenString, err := token.SignedString(privateKey)
require.NoError(t, err)
config := &cryptossh.ClientConfig{
User: testutil.GetTestUsername(t),
Auth: []cryptossh.AuthMethod{
cryptossh.Password(tokenString),
},
HostKeyCallback: cryptossh.InsecureIgnoreHostKey(),
Timeout: 2 * time.Second,
}
conn, err := cryptossh.Dial("tcp", net.JoinHostPort(host, portStr), config)
if conn != nil {
defer func() {
if err := conn.Close(); err != nil {
t.Logf("close connection: %v", err)
}
}()
}
assert.Error(t, err, "Authentication should fail (fail-close)")
})
}
}
// TestJWTAuthentication tests JWT authentication with valid/invalid tokens and enforcement for various connection types
func TestJWTAuthentication(t *testing.T) {
if testing.Short() {
t.Skip("Skipping JWT authentication tests in short mode")
}
jwksServer, privateKey, jwksURL := setupJWKSServer(t)
defer jwksServer.Close()
const (
issuer = "https://test-issuer.example.com"
audience = "test-audience"
)
hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
require.NoError(t, err)
testCases := []struct {
name string
token string
wantAuthOK bool
setupServer func(*Server)
testOperation func(*testing.T, *cryptossh.Client, string) error
wantOpSuccess bool
}{
{
name: "allows_shell_with_jwt",
token: "valid",
wantAuthOK: true,
setupServer: func(s *Server) {
s.SetAllowRootLogin(true)
},
testOperation: func(t *testing.T, conn *cryptossh.Client, _ string) error {
session, err := conn.NewSession()
require.NoError(t, err)
defer session.Close()
return session.Shell()
},
wantOpSuccess: true,
},
{
name: "rejects_invalid_token",
token: "invalid",
wantAuthOK: false,
setupServer: func(s *Server) {
s.SetAllowRootLogin(true)
},
testOperation: func(t *testing.T, conn *cryptossh.Client, _ string) error {
session, err := conn.NewSession()
require.NoError(t, err)
defer session.Close()
output, err := session.CombinedOutput("echo test")
if err != nil {
t.Logf("Command output: %s", string(output))
return err
}
return nil
},
wantOpSuccess: false,
},
{
name: "blocks_shell_without_jwt",
token: "",
wantAuthOK: false,
setupServer: func(s *Server) {
s.SetAllowRootLogin(true)
},
testOperation: func(t *testing.T, conn *cryptossh.Client, _ string) error {
session, err := conn.NewSession()
require.NoError(t, err)
defer session.Close()
output, err := session.CombinedOutput("echo test")
if err != nil {
t.Logf("Command output: %s", string(output))
return err
}
return nil
},
wantOpSuccess: false,
},
{
name: "blocks_command_without_jwt",
token: "",
wantAuthOK: false,
setupServer: func(s *Server) {
s.SetAllowRootLogin(true)
},
testOperation: func(t *testing.T, conn *cryptossh.Client, _ string) error {
session, err := conn.NewSession()
require.NoError(t, err)
defer session.Close()
output, err := session.CombinedOutput("ls")
if err != nil {
t.Logf("Command output: %s", string(output))
return err
}
return nil
},
wantOpSuccess: false,
},
{
name: "allows_sftp_with_jwt",
token: "valid",
wantAuthOK: true,
setupServer: func(s *Server) {
s.SetAllowRootLogin(true)
s.SetAllowSFTP(true)
},
testOperation: func(t *testing.T, conn *cryptossh.Client, _ string) error {
session, err := conn.NewSession()
require.NoError(t, err)
defer session.Close()
session.Stdout = io.Discard
session.Stderr = io.Discard
return session.RequestSubsystem("sftp")
},
wantOpSuccess: true,
},
{
name: "blocks_sftp_without_jwt",
token: "",
wantAuthOK: false,
setupServer: func(s *Server) {
s.SetAllowRootLogin(true)
s.SetAllowSFTP(true)
},
testOperation: func(t *testing.T, conn *cryptossh.Client, _ string) error {
session, err := conn.NewSession()
require.NoError(t, err)
defer session.Close()
session.Stdout = io.Discard
session.Stderr = io.Discard
err = session.RequestSubsystem("sftp")
if err == nil {
err = session.Wait()
}
return err
},
wantOpSuccess: false,
},
{
name: "allows_port_forward_with_jwt",
token: "valid",
wantAuthOK: true,
setupServer: func(s *Server) {
s.SetAllowRootLogin(true)
s.SetAllowRemotePortForwarding(true)
},
testOperation: func(t *testing.T, conn *cryptossh.Client, _ string) error {
ln, err := conn.Listen("tcp", "127.0.0.1:0")
if ln != nil {
defer ln.Close()
}
return err
},
wantOpSuccess: true,
},
{
name: "blocks_port_forward_without_jwt",
token: "",
wantAuthOK: false,
setupServer: func(s *Server) {
s.SetAllowRootLogin(true)
s.SetAllowLocalPortForwarding(true)
},
testOperation: func(t *testing.T, conn *cryptossh.Client, _ string) error {
ln, err := conn.Listen("tcp", "127.0.0.1:0")
if ln != nil {
defer ln.Close()
}
return err
},
wantOpSuccess: false,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
// TODO: Skip port forwarding tests on Windows - user switching not supported
// These features are tested on Linux/Unix platforms
if runtime.GOOS == "windows" &&
(tc.name == "allows_port_forward_with_jwt" ||
tc.name == "blocks_port_forward_without_jwt") {
t.Skip("Skipping port forwarding test on Windows - covered by Linux tests")
}
jwtConfig := &JWTConfig{
Issuer: issuer,
Audience: audience,
KeysLocation: jwksURL,
}
serverConfig := &Config{
HostKeyPEM: hostKey,
JWT: jwtConfig,
}
server := New(serverConfig)
if tc.setupServer != nil {
tc.setupServer(server)
}
serverAddr := StartTestServer(t, server)
defer require.NoError(t, server.Stop())
host, portStr, err := net.SplitHostPort(serverAddr)
require.NoError(t, err)
var authMethods []cryptossh.AuthMethod
if tc.token == "valid" {
token := generateValidJWT(t, privateKey, issuer, audience)
authMethods = []cryptossh.AuthMethod{
cryptossh.Password(token),
}
} else if tc.token == "invalid" {
invalidToken := "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDIyfQ.invalid"
authMethods = []cryptossh.AuthMethod{
cryptossh.Password(invalidToken),
}
}
config := &cryptossh.ClientConfig{
User: testutil.GetTestUsername(t),
Auth: authMethods,
HostKeyCallback: cryptossh.InsecureIgnoreHostKey(),
Timeout: 2 * time.Second,
}
conn, err := cryptossh.Dial("tcp", net.JoinHostPort(host, portStr), config)
if tc.wantAuthOK {
require.NoError(t, err, "JWT authentication should succeed")
} else if err != nil {
t.Logf("Connection failed as expected: %v", err)
return
}
if conn != nil {
defer func() {
if err := conn.Close(); err != nil {
t.Logf("close connection: %v", err)
}
}()
}
err = tc.testOperation(t, conn, serverAddr)
if tc.wantOpSuccess {
require.NoError(t, err, "Operation should succeed")
} else {
assert.Error(t, err, "Operation should fail")
}
})
}
}

View File

@@ -0,0 +1,386 @@
package server
import (
"encoding/binary"
"fmt"
"io"
"net"
"strconv"
"github.com/gliderlabs/ssh"
log "github.com/sirupsen/logrus"
cryptossh "golang.org/x/crypto/ssh"
)
// SessionKey uniquely identifies an SSH session
type SessionKey string
// ConnectionKey uniquely identifies a port forwarding connection within a session
type ConnectionKey string
// ForwardKey uniquely identifies a port forwarding listener
type ForwardKey string
// tcpipForwardMsg represents the structure for tcpip-forward SSH requests
type tcpipForwardMsg struct {
Host string
Port uint32
}
// SetAllowLocalPortForwarding configures local port forwarding
func (s *Server) SetAllowLocalPortForwarding(allow bool) {
s.mu.Lock()
defer s.mu.Unlock()
s.allowLocalPortForwarding = allow
}
// SetAllowRemotePortForwarding configures remote port forwarding
func (s *Server) SetAllowRemotePortForwarding(allow bool) {
s.mu.Lock()
defer s.mu.Unlock()
s.allowRemotePortForwarding = allow
}
// configurePortForwarding sets up port forwarding callbacks
func (s *Server) configurePortForwarding(server *ssh.Server) {
allowLocal := s.allowLocalPortForwarding
allowRemote := s.allowRemotePortForwarding
server.LocalPortForwardingCallback = func(ctx ssh.Context, dstHost string, dstPort uint32) bool {
if !allowLocal {
log.Warnf("local port forwarding denied for %s from %s: disabled by configuration",
net.JoinHostPort(dstHost, fmt.Sprintf("%d", dstPort)), ctx.RemoteAddr())
return false
}
if err := s.checkPortForwardingPrivileges(ctx, "local", dstPort); err != nil {
log.Warnf("local port forwarding denied for %s:%d from %s: %v", dstHost, dstPort, ctx.RemoteAddr(), err)
return false
}
log.Debugf("local port forwarding allowed: %s:%d", dstHost, dstPort)
return true
}
server.ReversePortForwardingCallback = func(ctx ssh.Context, bindHost string, bindPort uint32) bool {
if !allowRemote {
log.Warnf("remote port forwarding denied for %s from %s: disabled by configuration",
net.JoinHostPort(bindHost, fmt.Sprintf("%d", bindPort)), ctx.RemoteAddr())
return false
}
if err := s.checkPortForwardingPrivileges(ctx, "remote", bindPort); err != nil {
log.Warnf("remote port forwarding denied for %s:%d from %s: %v", bindHost, bindPort, ctx.RemoteAddr(), err)
return false
}
log.Debugf("remote port forwarding allowed: %s:%d", bindHost, bindPort)
return true
}
log.Debugf("SSH server configured with local_forwarding=%v, remote_forwarding=%v", allowLocal, allowRemote)
}
// checkPortForwardingPrivileges validates privilege requirements for port forwarding operations.
// Returns nil if allowed, error if denied.
func (s *Server) checkPortForwardingPrivileges(ctx ssh.Context, forwardType string, port uint32) error {
if ctx == nil {
return fmt.Errorf("%s port forwarding denied: no context", forwardType)
}
username := ctx.User()
remoteAddr := "unknown"
if ctx.RemoteAddr() != nil {
remoteAddr = ctx.RemoteAddr().String()
}
logger := log.WithFields(log.Fields{"user": username, "remote": remoteAddr, "port": port})
result := s.CheckPrivileges(PrivilegeCheckRequest{
RequestedUsername: username,
FeatureSupportsUserSwitch: false,
FeatureName: forwardType + " port forwarding",
})
if !result.Allowed {
return result.Error
}
logger.Debugf("%s port forwarding allowed: user %s validated (port %d)",
forwardType, result.User.Username, port)
return nil
}
// tcpipForwardHandler handles tcpip-forward requests for remote port forwarding.
func (s *Server) tcpipForwardHandler(ctx ssh.Context, _ *ssh.Server, req *cryptossh.Request) (bool, []byte) {
logger := s.getRequestLogger(ctx)
if !s.isRemotePortForwardingAllowed() {
logger.Warnf("tcpip-forward request denied: remote port forwarding disabled")
return false, nil
}
payload, err := s.parseTcpipForwardRequest(req)
if err != nil {
logger.Errorf("tcpip-forward unmarshal error: %v", err)
return false, nil
}
if err := s.checkPortForwardingPrivileges(ctx, "tcpip-forward", payload.Port); err != nil {
logger.Warnf("tcpip-forward denied: %v", err)
return false, nil
}
logger.Debugf("tcpip-forward request: %s:%d", payload.Host, payload.Port)
sshConn, err := s.getSSHConnection(ctx)
if err != nil {
logger.Warnf("tcpip-forward request denied: %v", err)
return false, nil
}
return s.setupDirectForward(ctx, logger, sshConn, payload)
}
// cancelTcpipForwardHandler handles cancel-tcpip-forward requests.
func (s *Server) cancelTcpipForwardHandler(ctx ssh.Context, _ *ssh.Server, req *cryptossh.Request) (bool, []byte) {
logger := s.getRequestLogger(ctx)
var payload tcpipForwardMsg
if err := cryptossh.Unmarshal(req.Payload, &payload); err != nil {
logger.Errorf("cancel-tcpip-forward unmarshal error: %v", err)
return false, nil
}
key := ForwardKey(fmt.Sprintf("%s:%d", payload.Host, payload.Port))
if s.removeRemoteForwardListener(key) {
logger.Infof("remote port forwarding cancelled: %s:%d", payload.Host, payload.Port)
return true, nil
}
logger.Warnf("cancel-tcpip-forward failed: no listener found for %s:%d", payload.Host, payload.Port)
return false, nil
}
// handleRemoteForwardListener handles incoming connections for remote port forwarding.
func (s *Server) handleRemoteForwardListener(ctx ssh.Context, ln net.Listener, host string, port uint32) {
log.Debugf("starting remote forward listener handler for %s:%d", host, port)
defer func() {
log.Debugf("cleaning up remote forward listener for %s:%d", host, port)
if err := ln.Close(); err != nil {
log.Debugf("remote forward listener close error: %v", err)
} else {
log.Debugf("remote forward listener closed successfully for %s:%d", host, port)
}
}()
acceptChan := make(chan acceptResult, 1)
go func() {
for {
conn, err := ln.Accept()
select {
case acceptChan <- acceptResult{conn: conn, err: err}:
if err != nil {
return
}
case <-ctx.Done():
return
}
}
}()
for {
select {
case result := <-acceptChan:
if result.err != nil {
log.Debugf("remote forward accept error: %v", result.err)
return
}
go s.handleRemoteForwardConnection(ctx, result.conn, host, port)
case <-ctx.Done():
log.Debugf("remote forward listener shutting down due to context cancellation for %s:%d", host, port)
return
}
}
}
// getRequestLogger creates a logger with user and remote address context
func (s *Server) getRequestLogger(ctx ssh.Context) *log.Entry {
remoteAddr := "unknown"
username := "unknown"
if ctx != nil {
if ctx.RemoteAddr() != nil {
remoteAddr = ctx.RemoteAddr().String()
}
username = ctx.User()
}
return log.WithFields(log.Fields{"user": username, "remote": remoteAddr})
}
// isRemotePortForwardingAllowed checks if remote port forwarding is enabled
func (s *Server) isRemotePortForwardingAllowed() bool {
s.mu.RLock()
defer s.mu.RUnlock()
return s.allowRemotePortForwarding
}
// parseTcpipForwardRequest parses the SSH request payload
func (s *Server) parseTcpipForwardRequest(req *cryptossh.Request) (*tcpipForwardMsg, error) {
var payload tcpipForwardMsg
err := cryptossh.Unmarshal(req.Payload, &payload)
return &payload, err
}
// getSSHConnection extracts SSH connection from context
func (s *Server) getSSHConnection(ctx ssh.Context) (*cryptossh.ServerConn, error) {
if ctx == nil {
return nil, fmt.Errorf("no context")
}
sshConnValue := ctx.Value(ssh.ContextKeyConn)
if sshConnValue == nil {
return nil, fmt.Errorf("no SSH connection in context")
}
sshConn, ok := sshConnValue.(*cryptossh.ServerConn)
if !ok || sshConn == nil {
return nil, fmt.Errorf("invalid SSH connection in context")
}
return sshConn, nil
}
// setupDirectForward sets up a direct port forward
func (s *Server) setupDirectForward(ctx ssh.Context, logger *log.Entry, sshConn *cryptossh.ServerConn, payload *tcpipForwardMsg) (bool, []byte) {
bindAddr := net.JoinHostPort(payload.Host, strconv.FormatUint(uint64(payload.Port), 10))
ln, err := net.Listen("tcp", bindAddr)
if err != nil {
logger.Errorf("tcpip-forward listen failed on %s: %v", bindAddr, err)
return false, nil
}
actualPort := payload.Port
if payload.Port == 0 {
tcpAddr := ln.Addr().(*net.TCPAddr)
actualPort = uint32(tcpAddr.Port)
logger.Debugf("tcpip-forward allocated port %d for %s", actualPort, payload.Host)
}
key := ForwardKey(fmt.Sprintf("%s:%d", payload.Host, payload.Port))
s.storeRemoteForwardListener(key, ln)
s.markConnectionActivePortForward(sshConn, ctx.User(), ctx.RemoteAddr().String())
go s.handleRemoteForwardListener(ctx, ln, payload.Host, actualPort)
response := make([]byte, 4)
binary.BigEndian.PutUint32(response, actualPort)
logger.Infof("remote port forwarding established: %s:%d", payload.Host, actualPort)
return true, response
}
// acceptResult holds the result of a listener Accept() call
type acceptResult struct {
conn net.Conn
err error
}
// handleRemoteForwardConnection handles a single remote port forwarding connection
func (s *Server) handleRemoteForwardConnection(ctx ssh.Context, conn net.Conn, host string, port uint32) {
sessionKey := s.findSessionKeyByContext(ctx)
connID := fmt.Sprintf("pf-%s->%s:%d", conn.RemoteAddr(), host, port)
logger := log.WithFields(log.Fields{
"session": sessionKey,
"conn": connID,
})
defer func() {
if err := conn.Close(); err != nil {
logger.Debugf("connection close error: %v", err)
}
}()
sshConn := ctx.Value(ssh.ContextKeyConn).(*cryptossh.ServerConn)
if sshConn == nil {
logger.Debugf("remote forward: no SSH connection in context")
return
}
remoteAddr, ok := conn.RemoteAddr().(*net.TCPAddr)
if !ok {
logger.Warnf("remote forward: non-TCP connection type: %T", conn.RemoteAddr())
return
}
channel, err := s.openForwardChannel(sshConn, host, port, remoteAddr, logger)
if err != nil {
logger.Debugf("open forward channel: %v", err)
return
}
s.proxyForwardConnection(ctx, logger, conn, channel)
}
// openForwardChannel creates an SSH forwarded-tcpip channel
func (s *Server) openForwardChannel(sshConn *cryptossh.ServerConn, host string, port uint32, remoteAddr *net.TCPAddr, logger *log.Entry) (cryptossh.Channel, error) {
logger.Tracef("opening forwarded-tcpip channel for %s:%d", host, port)
payload := struct {
ConnectedAddress string
ConnectedPort uint32
OriginatorAddress string
OriginatorPort uint32
}{
ConnectedAddress: host,
ConnectedPort: port,
OriginatorAddress: remoteAddr.IP.String(),
OriginatorPort: uint32(remoteAddr.Port),
}
channel, reqs, err := sshConn.OpenChannel("forwarded-tcpip", cryptossh.Marshal(&payload))
if err != nil {
return nil, fmt.Errorf("open SSH channel: %w", err)
}
go cryptossh.DiscardRequests(reqs)
return channel, nil
}
// proxyForwardConnection handles bidirectional data transfer between connection and SSH channel
func (s *Server) proxyForwardConnection(ctx ssh.Context, logger *log.Entry, conn net.Conn, channel cryptossh.Channel) {
done := make(chan struct{}, 2)
go func() {
if _, err := io.Copy(channel, conn); err != nil {
logger.Debugf("copy error (conn->channel): %v", err)
}
done <- struct{}{}
}()
go func() {
if _, err := io.Copy(conn, channel); err != nil {
logger.Debugf("copy error (channel->conn): %v", err)
}
done <- struct{}{}
}()
select {
case <-ctx.Done():
logger.Debugf("session ended, closing connections")
case <-done:
// First copy finished, wait for second copy or context cancellation
select {
case <-ctx.Done():
logger.Debugf("session ended, closing connections")
case <-done:
}
}
if err := channel.Close(); err != nil {
logger.Debugf("channel close error: %v", err)
}
if err := conn.Close(); err != nil {
logger.Debugf("connection close error: %v", err)
}
}

712
client/ssh/server/server.go Normal file
View File

@@ -0,0 +1,712 @@
package server
import (
"context"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"io"
"net"
"net/netip"
"strings"
"sync"
"time"
"github.com/gliderlabs/ssh"
gojwt "github.com/golang-jwt/jwt/v5"
log "github.com/sirupsen/logrus"
cryptossh "golang.org/x/crypto/ssh"
"golang.org/x/exp/maps"
"golang.zx2c4.com/wireguard/tun/netstack"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/ssh/detection"
"github.com/netbirdio/netbird/shared/auth"
"github.com/netbirdio/netbird/shared/auth/jwt"
"github.com/netbirdio/netbird/version"
)
// DefaultSSHPort is the default SSH port of the NetBird's embedded SSH server
const DefaultSSHPort = 22
// InternalSSHPort is the port SSH server listens on and is redirected to
const InternalSSHPort = 22022
const (
errWriteSession = "write session error: %v"
errExitSession = "exit session error: %v"
msgPrivilegedUserDisabled = "privileged user login is disabled"
// DefaultJWTMaxTokenAge is the default maximum age for JWT tokens accepted by the SSH server
DefaultJWTMaxTokenAge = 5 * 60
)
var (
ErrPrivilegedUserDisabled = errors.New(msgPrivilegedUserDisabled)
ErrUserNotFound = errors.New("user not found")
)
// PrivilegedUserError represents an error when privileged user login is disabled
type PrivilegedUserError struct {
Username string
}
func (e *PrivilegedUserError) Error() string {
return fmt.Sprintf("%s for user: %s", msgPrivilegedUserDisabled, e.Username)
}
func (e *PrivilegedUserError) Is(target error) bool {
return target == ErrPrivilegedUserDisabled
}
// UserNotFoundError represents an error when a user cannot be found
type UserNotFoundError struct {
Username string
Cause error
}
func (e *UserNotFoundError) Error() string {
if e.Cause != nil {
return fmt.Sprintf("user %s not found: %v", e.Username, e.Cause)
}
return fmt.Sprintf("user %s not found", e.Username)
}
func (e *UserNotFoundError) Is(target error) bool {
return target == ErrUserNotFound
}
func (e *UserNotFoundError) Unwrap() error {
return e.Cause
}
// logSessionExitError logs session exit errors, ignoring EOF (normal close) errors
func logSessionExitError(logger *log.Entry, err error) {
if err != nil && !errors.Is(err, io.EOF) {
logger.Warnf(errExitSession, err)
}
}
// safeLogCommand returns a safe representation of the command for logging
func safeLogCommand(cmd []string) string {
if len(cmd) == 0 {
return "<interactive shell>"
}
if len(cmd) == 1 {
return cmd[0]
}
return fmt.Sprintf("%s [%d args]", cmd[0], len(cmd)-1)
}
type sshConnectionState struct {
hasActivePortForward bool
username string
remoteAddr string
}
type authKey string
func newAuthKey(username string, remoteAddr net.Addr) authKey {
return authKey(fmt.Sprintf("%s@%s", username, remoteAddr.String()))
}
type Server struct {
sshServer *ssh.Server
mu sync.RWMutex
hostKeyPEM []byte
sessions map[SessionKey]ssh.Session
sessionCancels map[ConnectionKey]context.CancelFunc
sessionJWTUsers map[SessionKey]string
pendingAuthJWT map[authKey]string
allowLocalPortForwarding bool
allowRemotePortForwarding bool
allowRootLogin bool
allowSFTP bool
jwtEnabled bool
netstackNet *netstack.Net
wgAddress wgaddr.Address
remoteForwardListeners map[ForwardKey]net.Listener
sshConnections map[*cryptossh.ServerConn]*sshConnectionState
jwtValidator *jwt.Validator
jwtExtractor *jwt.ClaimsExtractor
jwtConfig *JWTConfig
suSupportsPty bool
}
type JWTConfig struct {
Issuer string
Audience string
KeysLocation string
MaxTokenAge int64
}
// Config contains all SSH server configuration options
type Config struct {
// JWT authentication configuration. If nil, JWT authentication is disabled
JWT *JWTConfig
// HostKey is the SSH server host key in PEM format
HostKeyPEM []byte
}
// SessionInfo contains information about an active SSH session
type SessionInfo struct {
Username string
RemoteAddress string
Command string
JWTUsername string
}
// New creates an SSH server instance with the provided host key and optional JWT configuration
// If jwtConfig is nil, JWT authentication is disabled
func New(config *Config) *Server {
s := &Server{
mu: sync.RWMutex{},
hostKeyPEM: config.HostKeyPEM,
sessions: make(map[SessionKey]ssh.Session),
sessionJWTUsers: make(map[SessionKey]string),
pendingAuthJWT: make(map[authKey]string),
remoteForwardListeners: make(map[ForwardKey]net.Listener),
sshConnections: make(map[*cryptossh.ServerConn]*sshConnectionState),
jwtEnabled: config.JWT != nil,
jwtConfig: config.JWT,
}
return s
}
// Start runs the SSH server
func (s *Server) Start(ctx context.Context, addr netip.AddrPort) error {
s.mu.Lock()
defer s.mu.Unlock()
if s.sshServer != nil {
return errors.New("SSH server is already running")
}
s.suSupportsPty = s.detectSuPtySupport(ctx)
ln, addrDesc, err := s.createListener(ctx, addr)
if err != nil {
return fmt.Errorf("create listener: %w", err)
}
sshServer, err := s.createSSHServer(ln.Addr())
if err != nil {
s.closeListener(ln)
return fmt.Errorf("create SSH server: %w", err)
}
s.sshServer = sshServer
log.Infof("SSH server started on %s", addrDesc)
go func() {
if err := sshServer.Serve(ln); err != nil && !errors.Is(err, ssh.ErrServerClosed) {
log.Errorf("SSH server error: %v", err)
}
}()
return nil
}
func (s *Server) createListener(ctx context.Context, addr netip.AddrPort) (net.Listener, string, error) {
if s.netstackNet != nil {
ln, err := s.netstackNet.ListenTCPAddrPort(addr)
if err != nil {
return nil, "", fmt.Errorf("listen on netstack: %w", err)
}
return ln, fmt.Sprintf("netstack %s", addr), nil
}
tcpAddr := net.TCPAddrFromAddrPort(addr)
lc := net.ListenConfig{}
ln, err := lc.Listen(ctx, "tcp", tcpAddr.String())
if err != nil {
return nil, "", fmt.Errorf("listen: %w", err)
}
return ln, addr.String(), nil
}
func (s *Server) closeListener(ln net.Listener) {
if ln == nil {
return
}
if err := ln.Close(); err != nil {
log.Debugf("listener close error: %v", err)
}
}
// Stop closes the SSH server
func (s *Server) Stop() error {
s.mu.Lock()
defer s.mu.Unlock()
if s.sshServer == nil {
return nil
}
if err := s.sshServer.Close(); err != nil {
log.Debugf("close SSH server: %v", err)
}
s.sshServer = nil
maps.Clear(s.sessions)
maps.Clear(s.sessionJWTUsers)
maps.Clear(s.pendingAuthJWT)
maps.Clear(s.sshConnections)
for _, cancelFunc := range s.sessionCancels {
cancelFunc()
}
maps.Clear(s.sessionCancels)
for _, listener := range s.remoteForwardListeners {
if err := listener.Close(); err != nil {
log.Debugf("close remote forward listener: %v", err)
}
}
maps.Clear(s.remoteForwardListeners)
return nil
}
// GetStatus returns the current status of the SSH server and active sessions
func (s *Server) GetStatus() (enabled bool, sessions []SessionInfo) {
s.mu.RLock()
defer s.mu.RUnlock()
enabled = s.sshServer != nil
for sessionKey, session := range s.sessions {
cmd := "<interactive shell>"
if len(session.Command()) > 0 {
cmd = safeLogCommand(session.Command())
}
jwtUsername := s.sessionJWTUsers[sessionKey]
sessions = append(sessions, SessionInfo{
Username: session.User(),
RemoteAddress: session.RemoteAddr().String(),
Command: cmd,
JWTUsername: jwtUsername,
})
}
return enabled, sessions
}
// SetNetstackNet sets the netstack network for userspace networking
func (s *Server) SetNetstackNet(net *netstack.Net) {
s.mu.Lock()
defer s.mu.Unlock()
s.netstackNet = net
}
// SetNetworkValidation configures network-based connection filtering
func (s *Server) SetNetworkValidation(addr wgaddr.Address) {
s.mu.Lock()
defer s.mu.Unlock()
s.wgAddress = addr
}
// ensureJWTValidator initializes the JWT validator and extractor if not already initialized
func (s *Server) ensureJWTValidator() error {
s.mu.RLock()
if s.jwtValidator != nil && s.jwtExtractor != nil {
s.mu.RUnlock()
return nil
}
config := s.jwtConfig
s.mu.RUnlock()
if config == nil {
return fmt.Errorf("JWT config not set")
}
log.Debugf("Initializing JWT validator (issuer: %s, audience: %s)", config.Issuer, config.Audience)
validator := jwt.NewValidator(
config.Issuer,
[]string{config.Audience},
config.KeysLocation,
true,
)
extractor := jwt.NewClaimsExtractor(
jwt.WithAudience(config.Audience),
)
s.mu.Lock()
defer s.mu.Unlock()
if s.jwtValidator != nil && s.jwtExtractor != nil {
return nil
}
s.jwtValidator = validator
s.jwtExtractor = extractor
log.Infof("JWT validator initialized successfully")
return nil
}
func (s *Server) validateJWTToken(tokenString string) (*gojwt.Token, error) {
s.mu.RLock()
jwtValidator := s.jwtValidator
jwtConfig := s.jwtConfig
s.mu.RUnlock()
if jwtValidator == nil {
return nil, fmt.Errorf("JWT validator not initialized")
}
token, err := jwtValidator.ValidateAndParse(context.Background(), tokenString)
if err != nil {
if jwtConfig != nil {
if claims, parseErr := s.parseTokenWithoutValidation(tokenString); parseErr == nil {
return nil, fmt.Errorf("validate token (expected issuer=%s, audience=%s, actual issuer=%v, audience=%v): %w",
jwtConfig.Issuer, jwtConfig.Audience, claims["iss"], claims["aud"], err)
}
}
return nil, fmt.Errorf("validate token: %w", err)
}
if err := s.checkTokenAge(token, jwtConfig); err != nil {
return nil, err
}
return token, nil
}
func (s *Server) checkTokenAge(token *gojwt.Token, jwtConfig *JWTConfig) error {
if jwtConfig == nil {
return nil
}
maxTokenAge := jwtConfig.MaxTokenAge
if maxTokenAge <= 0 {
maxTokenAge = DefaultJWTMaxTokenAge
}
claims, ok := token.Claims.(gojwt.MapClaims)
if !ok {
userID := extractUserID(token)
return fmt.Errorf("token has invalid claims format (user=%s)", userID)
}
iat, ok := claims["iat"].(float64)
if !ok {
userID := extractUserID(token)
return fmt.Errorf("token missing iat claim (user=%s)", userID)
}
issuedAt := time.Unix(int64(iat), 0)
tokenAge := time.Since(issuedAt)
maxAge := time.Duration(maxTokenAge) * time.Second
if tokenAge > maxAge {
userID := getUserIDFromClaims(claims)
return fmt.Errorf("token expired for user=%s: age=%v, max=%v", userID, tokenAge, maxAge)
}
return nil
}
func (s *Server) extractAndValidateUser(token *gojwt.Token) (*auth.UserAuth, error) {
s.mu.RLock()
jwtExtractor := s.jwtExtractor
s.mu.RUnlock()
if jwtExtractor == nil {
userID := extractUserID(token)
return nil, fmt.Errorf("JWT extractor not initialized (user=%s)", userID)
}
userAuth, err := jwtExtractor.ToUserAuth(token)
if err != nil {
userID := extractUserID(token)
return nil, fmt.Errorf("extract user from token (user=%s): %w", userID, err)
}
if !s.hasSSHAccess(&userAuth) {
return nil, fmt.Errorf("user %s does not have SSH access permissions", userAuth.UserId)
}
return &userAuth, nil
}
func (s *Server) hasSSHAccess(userAuth *auth.UserAuth) bool {
return userAuth.UserId != ""
}
func extractUserID(token *gojwt.Token) string {
if token == nil {
return "unknown"
}
claims, ok := token.Claims.(gojwt.MapClaims)
if !ok {
return "unknown"
}
return getUserIDFromClaims(claims)
}
func getUserIDFromClaims(claims gojwt.MapClaims) string {
if sub, ok := claims["sub"].(string); ok && sub != "" {
return sub
}
if userID, ok := claims["user_id"].(string); ok && userID != "" {
return userID
}
if email, ok := claims["email"].(string); ok && email != "" {
return email
}
return "unknown"
}
func (s *Server) parseTokenWithoutValidation(tokenString string) (map[string]interface{}, error) {
parts := strings.Split(tokenString, ".")
if len(parts) != 3 {
return nil, fmt.Errorf("invalid token format")
}
payload, err := base64.RawURLEncoding.DecodeString(parts[1])
if err != nil {
return nil, fmt.Errorf("decode payload: %w", err)
}
var claims map[string]interface{}
if err := json.Unmarshal(payload, &claims); err != nil {
return nil, fmt.Errorf("parse claims: %w", err)
}
return claims, nil
}
func (s *Server) passwordHandler(ctx ssh.Context, password string) bool {
if err := s.ensureJWTValidator(); err != nil {
log.Errorf("JWT validator initialization failed for user %s from %s: %v", ctx.User(), ctx.RemoteAddr(), err)
return false
}
token, err := s.validateJWTToken(password)
if err != nil {
log.Warnf("JWT authentication failed for user %s from %s: %v", ctx.User(), ctx.RemoteAddr(), err)
return false
}
userAuth, err := s.extractAndValidateUser(token)
if err != nil {
log.Warnf("User validation failed for user %s from %s: %v", ctx.User(), ctx.RemoteAddr(), err)
return false
}
key := newAuthKey(ctx.User(), ctx.RemoteAddr())
s.mu.Lock()
s.pendingAuthJWT[key] = userAuth.UserId
s.mu.Unlock()
log.Infof("JWT authentication successful for user %s (JWT user ID: %s) from %s", ctx.User(), userAuth.UserId, ctx.RemoteAddr())
return true
}
func (s *Server) markConnectionActivePortForward(sshConn *cryptossh.ServerConn, username, remoteAddr string) {
s.mu.Lock()
defer s.mu.Unlock()
if state, exists := s.sshConnections[sshConn]; exists {
state.hasActivePortForward = true
} else {
s.sshConnections[sshConn] = &sshConnectionState{
hasActivePortForward: true,
username: username,
remoteAddr: remoteAddr,
}
}
}
func (s *Server) connectionCloseHandler(conn net.Conn, err error) {
// We can't extract the SSH connection from net.Conn directly
// Connection cleanup will happen during session cleanup or via timeout
log.Debugf("SSH connection failed for %s: %v", conn.RemoteAddr(), err)
}
func (s *Server) findSessionKeyByContext(ctx ssh.Context) SessionKey {
if ctx == nil {
return "unknown"
}
// Try to match by SSH connection
sshConn := ctx.Value(ssh.ContextKeyConn)
if sshConn == nil {
return "unknown"
}
s.mu.RLock()
defer s.mu.RUnlock()
// Look through sessions to find one with matching connection
for sessionKey, session := range s.sessions {
if session.Context().Value(ssh.ContextKeyConn) == sshConn {
return sessionKey
}
}
// If no session found, this might be during early connection setup
// Return a temporary key that we'll fix up later
if ctx.User() != "" && ctx.RemoteAddr() != nil {
tempKey := SessionKey(fmt.Sprintf("%s@%s", ctx.User(), ctx.RemoteAddr().String()))
log.Debugf("Using temporary session key for early port forward tracking: %s (will be updated when session established)", tempKey)
return tempKey
}
return "unknown"
}
func (s *Server) connectionValidator(_ ssh.Context, conn net.Conn) net.Conn {
s.mu.RLock()
netbirdNetwork := s.wgAddress.Network
localIP := s.wgAddress.IP
s.mu.RUnlock()
if !netbirdNetwork.IsValid() || !localIP.IsValid() {
return conn
}
remoteAddr := conn.RemoteAddr()
tcpAddr, ok := remoteAddr.(*net.TCPAddr)
if !ok {
log.Warnf("SSH connection rejected: non-TCP address %s", remoteAddr)
return nil
}
remoteIP, ok := netip.AddrFromSlice(tcpAddr.IP)
if !ok {
log.Warnf("SSH connection rejected: invalid remote IP %s", tcpAddr.IP)
return nil
}
// Block connections from our own IP (prevent local apps from connecting to ourselves)
if remoteIP == localIP {
log.Warnf("SSH connection rejected from own IP %s", remoteIP)
return nil
}
if !netbirdNetwork.Contains(remoteIP) {
log.Warnf("SSH connection rejected from non-NetBird IP %s", remoteIP)
return nil
}
log.Infof("SSH connection from NetBird peer %s allowed", tcpAddr)
return conn
}
func (s *Server) createSSHServer(addr net.Addr) (*ssh.Server, error) {
if err := enableUserSwitching(); err != nil {
log.Warnf("failed to enable user switching: %v", err)
}
serverVersion := fmt.Sprintf("%s-%s", detection.ServerIdentifier, version.NetbirdVersion())
if s.jwtEnabled {
serverVersion += " " + detection.JWTRequiredMarker
}
server := &ssh.Server{
Addr: addr.String(),
Handler: s.sessionHandler,
SubsystemHandlers: map[string]ssh.SubsystemHandler{
"sftp": s.sftpSubsystemHandler,
},
HostSigners: []ssh.Signer{},
ChannelHandlers: map[string]ssh.ChannelHandler{
"session": ssh.DefaultSessionHandler,
"direct-tcpip": s.directTCPIPHandler,
},
RequestHandlers: map[string]ssh.RequestHandler{
"tcpip-forward": s.tcpipForwardHandler,
"cancel-tcpip-forward": s.cancelTcpipForwardHandler,
},
ConnCallback: s.connectionValidator,
ConnectionFailedCallback: s.connectionCloseHandler,
Version: serverVersion,
}
if s.jwtEnabled {
server.PasswordHandler = s.passwordHandler
}
hostKeyPEM := ssh.HostKeyPEM(s.hostKeyPEM)
if err := server.SetOption(hostKeyPEM); err != nil {
return nil, fmt.Errorf("set host key: %w", err)
}
s.configurePortForwarding(server)
return server, nil
}
func (s *Server) storeRemoteForwardListener(key ForwardKey, ln net.Listener) {
s.mu.Lock()
defer s.mu.Unlock()
s.remoteForwardListeners[key] = ln
}
func (s *Server) removeRemoteForwardListener(key ForwardKey) bool {
s.mu.Lock()
defer s.mu.Unlock()
ln, exists := s.remoteForwardListeners[key]
if !exists {
return false
}
delete(s.remoteForwardListeners, key)
if err := ln.Close(); err != nil {
log.Debugf("remote forward listener close error: %v", err)
}
return true
}
func (s *Server) directTCPIPHandler(srv *ssh.Server, conn *cryptossh.ServerConn, newChan cryptossh.NewChannel, ctx ssh.Context) {
var payload struct {
Host string
Port uint32
OriginatorAddr string
OriginatorPort uint32
}
if err := cryptossh.Unmarshal(newChan.ExtraData(), &payload); err != nil {
if err := newChan.Reject(cryptossh.ConnectionFailed, "parse payload"); err != nil {
log.Debugf("channel reject error: %v", err)
}
return
}
s.mu.RLock()
allowLocal := s.allowLocalPortForwarding
s.mu.RUnlock()
if !allowLocal {
log.Warnf("local port forwarding denied for %s:%d: disabled by configuration", payload.Host, payload.Port)
_ = newChan.Reject(cryptossh.Prohibited, "local port forwarding disabled")
return
}
// Check privilege requirements for the destination port
if err := s.checkPortForwardingPrivileges(ctx, "local", payload.Port); err != nil {
log.Warnf("local port forwarding denied for %s:%d: %v", payload.Host, payload.Port, err)
_ = newChan.Reject(cryptossh.Prohibited, "insufficient privileges")
return
}
log.Infof("local port forwarding: %s:%d", payload.Host, payload.Port)
ssh.DirectTCPIPHandler(srv, conn, newChan, ctx)
}

View File

@@ -0,0 +1,394 @@
package server
import (
"context"
"fmt"
"net"
"os/user"
"runtime"
"strings"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/client/ssh"
sshclient "github.com/netbirdio/netbird/client/ssh/client"
)
func TestServer_RootLoginRestriction(t *testing.T) {
// Generate host key for server
hostKey, err := ssh.GeneratePrivateKey(ssh.ED25519)
require.NoError(t, err)
tests := []struct {
name string
allowRoot bool
username string
expectError bool
description string
}{
{
name: "root login allowed",
allowRoot: true,
username: "root",
expectError: false,
description: "Root login should succeed when allowed",
},
{
name: "root login denied",
allowRoot: false,
username: "root",
expectError: true,
description: "Root login should fail when disabled",
},
{
name: "regular user login always allowed",
allowRoot: false,
username: "testuser",
expectError: false,
description: "Regular user login should work regardless of root setting",
},
}
// Add Windows Administrator tests if on Windows
if runtime.GOOS == "windows" {
tests = append(tests, []struct {
name string
allowRoot bool
username string
expectError bool
description string
}{
{
name: "Administrator login allowed",
allowRoot: true,
username: "Administrator",
expectError: false,
description: "Administrator login should succeed when allowed",
},
{
name: "Administrator login denied",
allowRoot: false,
username: "Administrator",
expectError: true,
description: "Administrator login should fail when disabled",
},
{
name: "administrator login denied (lowercase)",
allowRoot: false,
username: "administrator",
expectError: true,
description: "administrator login should fail when disabled (case insensitive)",
},
}...)
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Mock privileged environment to test root access controls
// Set up mock users based on platform
mockUsers := map[string]*user.User{
"root": createTestUser("root", "0", "0", "/root"),
"testuser": createTestUser("testuser", "1000", "1000", "/home/testuser"),
}
// Add Windows-specific users for Administrator tests
if runtime.GOOS == "windows" {
mockUsers["Administrator"] = createTestUser("Administrator", "500", "544", "C:\\Users\\Administrator")
mockUsers["administrator"] = createTestUser("administrator", "500", "544", "C:\\Users\\administrator")
}
cleanup := setupTestDependencies(
createTestUser("root", "0", "0", "/root"), // Running as root
nil,
runtime.GOOS,
0, // euid 0 (root)
mockUsers,
nil,
)
defer cleanup()
// Create server with specific configuration
serverConfig := &Config{
HostKeyPEM: hostKey,
JWT: nil,
}
server := New(serverConfig)
server.SetAllowRootLogin(tt.allowRoot)
// Test the userNameLookup method directly
user, err := server.userNameLookup(tt.username)
if tt.expectError {
assert.Error(t, err, tt.description)
if tt.username == "root" || strings.ToLower(tt.username) == "administrator" {
// Check for appropriate error message based on platform capabilities
errorMsg := err.Error()
// Either privileged user restriction OR user switching limitation
hasPrivilegedError := strings.Contains(errorMsg, "privileged user")
hasSwitchingError := strings.Contains(errorMsg, "cannot switch") || strings.Contains(errorMsg, "user switching not supported")
assert.True(t, hasPrivilegedError || hasSwitchingError,
"Expected privileged user or user switching error, got: %s", errorMsg)
}
} else {
if tt.username == "root" || strings.ToLower(tt.username) == "administrator" {
// For privileged users, we expect either success or a different error
// (like user not found), but not the "login disabled" error
if err != nil {
assert.NotContains(t, err.Error(), "privileged user login is disabled")
}
} else {
// For regular users, lookup should generally succeed or fall back gracefully
// Note: may return current user as fallback
assert.NotNil(t, user)
}
}
})
}
}
func TestServer_PortForwardingRestriction(t *testing.T) {
// Test that the port forwarding callbacks properly respect configuration flags
// This is a unit test of the callback logic, not a full integration test
// Generate host key for server
hostKey, err := ssh.GeneratePrivateKey(ssh.ED25519)
require.NoError(t, err)
tests := []struct {
name string
allowLocalForwarding bool
allowRemoteForwarding bool
description string
}{
{
name: "all forwarding allowed",
allowLocalForwarding: true,
allowRemoteForwarding: true,
description: "Both local and remote forwarding should be allowed",
},
{
name: "local forwarding disabled",
allowLocalForwarding: false,
allowRemoteForwarding: true,
description: "Local forwarding should be denied when disabled",
},
{
name: "remote forwarding disabled",
allowLocalForwarding: true,
allowRemoteForwarding: false,
description: "Remote forwarding should be denied when disabled",
},
{
name: "all forwarding disabled",
allowLocalForwarding: false,
allowRemoteForwarding: false,
description: "Both forwarding types should be denied when disabled",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Create server with specific configuration
serverConfig := &Config{
HostKeyPEM: hostKey,
JWT: nil,
}
server := New(serverConfig)
server.SetAllowLocalPortForwarding(tt.allowLocalForwarding)
server.SetAllowRemotePortForwarding(tt.allowRemoteForwarding)
// We need to access the internal configuration to simulate the callback tests
// Since the callbacks are created inside the Start method, we'll test the logic directly
// Test the configuration values are set correctly
server.mu.RLock()
allowLocal := server.allowLocalPortForwarding
allowRemote := server.allowRemotePortForwarding
server.mu.RUnlock()
assert.Equal(t, tt.allowLocalForwarding, allowLocal, "Local forwarding configuration should be set correctly")
assert.Equal(t, tt.allowRemoteForwarding, allowRemote, "Remote forwarding configuration should be set correctly")
// Simulate the callback logic
localResult := allowLocal // This would be the callback return value
remoteResult := allowRemote // This would be the callback return value
assert.Equal(t, tt.allowLocalForwarding, localResult,
"Local port forwarding callback should return correct value")
assert.Equal(t, tt.allowRemoteForwarding, remoteResult,
"Remote port forwarding callback should return correct value")
})
}
}
func TestServer_PortConflictHandling(t *testing.T) {
// Test that multiple sessions requesting the same local port are handled naturally by the OS
// Get current user for SSH connection
currentUser, err := user.Current()
require.NoError(t, err, "Should be able to get current user")
// Generate host key for server
hostKey, err := ssh.GeneratePrivateKey(ssh.ED25519)
require.NoError(t, err)
// Create server
serverConfig := &Config{
HostKeyPEM: hostKey,
JWT: nil,
}
server := New(serverConfig)
server.SetAllowRootLogin(true)
serverAddr := StartTestServer(t, server)
defer func() {
err := server.Stop()
require.NoError(t, err)
}()
// Get a free port for testing
ln, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
testPort := ln.Addr().(*net.TCPAddr).Port
err = ln.Close()
require.NoError(t, err)
// Connect first client
ctx1, cancel1 := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel1()
client1, err := sshclient.Dial(ctx1, serverAddr, currentUser.Username, sshclient.DialOptions{
InsecureSkipVerify: true,
})
require.NoError(t, err)
defer func() {
err := client1.Close()
assert.NoError(t, err)
}()
// Connect second client
ctx2, cancel2 := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel2()
client2, err := sshclient.Dial(ctx2, serverAddr, currentUser.Username, sshclient.DialOptions{
InsecureSkipVerify: true,
})
require.NoError(t, err)
defer func() {
err := client2.Close()
assert.NoError(t, err)
}()
// First client binds to the test port
localAddr1 := fmt.Sprintf("127.0.0.1:%d", testPort)
remoteAddr := "127.0.0.1:80"
// Start first client's port forwarding
done1 := make(chan error, 1)
go func() {
// This should succeed and hold the port
err := client1.LocalPortForward(ctx1, localAddr1, remoteAddr)
done1 <- err
}()
// Give first client time to bind
time.Sleep(200 * time.Millisecond)
// Second client tries to bind to same port
localAddr2 := fmt.Sprintf("127.0.0.1:%d", testPort)
shortCtx, shortCancel := context.WithTimeout(context.Background(), 1*time.Second)
defer shortCancel()
err = client2.LocalPortForward(shortCtx, localAddr2, remoteAddr)
// Second client should fail due to "address already in use"
assert.Error(t, err, "Second client should fail to bind to same port")
if err != nil {
// The error should indicate the address is already in use
errMsg := strings.ToLower(err.Error())
if runtime.GOOS == "windows" {
assert.Contains(t, errMsg, "only one usage of each socket address",
"Error should indicate port conflict")
} else {
assert.Contains(t, errMsg, "address already in use",
"Error should indicate port conflict")
}
}
// Cancel first client's context and wait for it to finish
cancel1()
select {
case err1 := <-done1:
// Should get context cancelled or deadline exceeded
assert.Error(t, err1, "First client should exit when context cancelled")
case <-time.After(2 * time.Second):
t.Error("First client did not exit within timeout")
}
}
func TestServer_IsPrivilegedUser(t *testing.T) {
tests := []struct {
username string
expected bool
description string
}{
{
username: "root",
expected: true,
description: "root should be considered privileged",
},
{
username: "regular",
expected: false,
description: "regular user should not be privileged",
},
{
username: "",
expected: false,
description: "empty username should not be privileged",
},
}
// Add Windows-specific tests
if runtime.GOOS == "windows" {
tests = append(tests, []struct {
username string
expected bool
description string
}{
{
username: "Administrator",
expected: true,
description: "Administrator should be considered privileged on Windows",
},
{
username: "administrator",
expected: true,
description: "administrator should be considered privileged on Windows (case insensitive)",
},
}...)
} else {
// On non-Windows systems, Administrator should not be privileged
tests = append(tests, []struct {
username string
expected bool
description string
}{
{
username: "Administrator",
expected: false,
description: "Administrator should not be privileged on non-Windows systems",
},
}...)
}
for _, tt := range tests {
t.Run(tt.description, func(t *testing.T) {
result := isPrivilegedUsername(tt.username)
assert.Equal(t, tt.expected, result, tt.description)
})
}
}

View File

@@ -0,0 +1,441 @@
package server
import (
"context"
"fmt"
"net"
"net/netip"
"os/user"
"runtime"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
cryptossh "golang.org/x/crypto/ssh"
nbssh "github.com/netbirdio/netbird/client/ssh"
)
func TestServer_StartStop(t *testing.T) {
key, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
require.NoError(t, err)
serverConfig := &Config{
HostKeyPEM: key,
JWT: nil,
}
server := New(serverConfig)
err = server.Stop()
assert.NoError(t, err)
}
func TestSSHServerIntegration(t *testing.T) {
// Generate host key for server
hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
require.NoError(t, err)
// Generate client key pair
clientPrivKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
require.NoError(t, err)
// Create server with random port
serverConfig := &Config{
HostKeyPEM: hostKey,
JWT: nil,
}
server := New(serverConfig)
// Start server in background
serverAddr := "127.0.0.1:0"
started := make(chan string, 1)
errChan := make(chan error, 1)
go func() {
// Get a free port
ln, err := net.Listen("tcp", serverAddr)
if err != nil {
errChan <- err
return
}
actualAddr := ln.Addr().String()
if err := ln.Close(); err != nil {
errChan <- fmt.Errorf("close temp listener: %w", err)
return
}
addrPort, _ := netip.ParseAddrPort(actualAddr)
if err := server.Start(context.Background(), addrPort); err != nil {
errChan <- err
return
}
started <- actualAddr
}()
select {
case actualAddr := <-started:
serverAddr = actualAddr
case err := <-errChan:
t.Fatalf("Server failed to start: %v", err)
case <-time.After(5 * time.Second):
t.Fatal("Server start timeout")
}
defer func() {
err := server.Stop()
require.NoError(t, err)
}()
// Parse client private key
signer, err := cryptossh.ParsePrivateKey(clientPrivKey)
require.NoError(t, err)
// Parse server host key for verification
hostPrivParsed, err := cryptossh.ParsePrivateKey(hostKey)
require.NoError(t, err)
hostPubKey := hostPrivParsed.PublicKey()
// Get current user for SSH connection
currentUser, err := user.Current()
require.NoError(t, err, "Should be able to get current user for test")
// Create SSH client config
config := &cryptossh.ClientConfig{
User: currentUser.Username,
Auth: []cryptossh.AuthMethod{
cryptossh.PublicKeys(signer),
},
HostKeyCallback: cryptossh.FixedHostKey(hostPubKey),
Timeout: 3 * time.Second,
}
// Connect to SSH server
client, err := cryptossh.Dial("tcp", serverAddr, config)
require.NoError(t, err)
defer func() {
if err := client.Close(); err != nil {
t.Logf("close client: %v", err)
}
}()
// Test creating a session
session, err := client.NewSession()
require.NoError(t, err)
defer func() {
if err := session.Close(); err != nil {
t.Logf("close session: %v", err)
}
}()
// Note: Since we don't have a real shell environment in tests,
// we can't test actual command execution, but we can verify
// the connection and authentication work
t.Log("SSH connection and authentication successful")
}
func TestSSHServerMultipleConnections(t *testing.T) {
// Generate host key for server
hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
require.NoError(t, err)
// Generate client key pair
clientPrivKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
require.NoError(t, err)
// Create server
serverConfig := &Config{
HostKeyPEM: hostKey,
JWT: nil,
}
server := New(serverConfig)
// Start server
serverAddr := "127.0.0.1:0"
started := make(chan string, 1)
errChan := make(chan error, 1)
go func() {
ln, err := net.Listen("tcp", serverAddr)
if err != nil {
errChan <- err
return
}
actualAddr := ln.Addr().String()
if err := ln.Close(); err != nil {
errChan <- fmt.Errorf("close temp listener: %w", err)
return
}
addrPort, _ := netip.ParseAddrPort(actualAddr)
if err := server.Start(context.Background(), addrPort); err != nil {
errChan <- err
return
}
started <- actualAddr
}()
select {
case actualAddr := <-started:
serverAddr = actualAddr
case err := <-errChan:
t.Fatalf("Server failed to start: %v", err)
case <-time.After(5 * time.Second):
t.Fatal("Server start timeout")
}
defer func() {
err := server.Stop()
require.NoError(t, err)
}()
// Parse client private key
signer, err := cryptossh.ParsePrivateKey(clientPrivKey)
require.NoError(t, err)
// Parse server host key
hostPrivParsed, err := cryptossh.ParsePrivateKey(hostKey)
require.NoError(t, err)
hostPubKey := hostPrivParsed.PublicKey()
// Get current user for SSH connection
currentUser, err := user.Current()
require.NoError(t, err, "Should be able to get current user for test")
config := &cryptossh.ClientConfig{
User: currentUser.Username,
Auth: []cryptossh.AuthMethod{
cryptossh.PublicKeys(signer),
},
HostKeyCallback: cryptossh.FixedHostKey(hostPubKey),
Timeout: 3 * time.Second,
}
// Test multiple concurrent connections
const numConnections = 5
results := make(chan error, numConnections)
for i := 0; i < numConnections; i++ {
go func(id int) {
client, err := cryptossh.Dial("tcp", serverAddr, config)
if err != nil {
results <- fmt.Errorf("connection %d failed: %w", id, err)
return
}
defer func() {
_ = client.Close() // Ignore error in test goroutine
}()
session, err := client.NewSession()
if err != nil {
results <- fmt.Errorf("session %d failed: %w", id, err)
return
}
defer func() {
_ = session.Close() // Ignore error in test goroutine
}()
results <- nil
}(i)
}
// Wait for all connections to complete
for i := 0; i < numConnections; i++ {
select {
case err := <-results:
assert.NoError(t, err)
case <-time.After(10 * time.Second):
t.Fatalf("Connection %d timed out", i)
}
}
}
func TestSSHServerNoAuthMode(t *testing.T) {
// Generate host key for server
hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
require.NoError(t, err)
// Create server
serverConfig := &Config{
HostKeyPEM: hostKey,
JWT: nil,
}
server := New(serverConfig)
// Start server
serverAddr := "127.0.0.1:0"
started := make(chan string, 1)
errChan := make(chan error, 1)
go func() {
ln, err := net.Listen("tcp", serverAddr)
if err != nil {
errChan <- err
return
}
actualAddr := ln.Addr().String()
if err := ln.Close(); err != nil {
errChan <- fmt.Errorf("close temp listener: %w", err)
return
}
addrPort, _ := netip.ParseAddrPort(actualAddr)
if err := server.Start(context.Background(), addrPort); err != nil {
errChan <- err
return
}
started <- actualAddr
}()
select {
case actualAddr := <-started:
serverAddr = actualAddr
case err := <-errChan:
t.Fatalf("Server failed to start: %v", err)
case <-time.After(5 * time.Second):
t.Fatal("Server start timeout")
}
defer func() {
err := server.Stop()
require.NoError(t, err)
}()
// Generate a client private key for SSH protocol (server doesn't check it)
clientPrivKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
require.NoError(t, err)
clientSigner, err := cryptossh.ParsePrivateKey(clientPrivKey)
require.NoError(t, err)
// Parse server host key
hostPrivParsed, err := cryptossh.ParsePrivateKey(hostKey)
require.NoError(t, err)
hostPubKey := hostPrivParsed.PublicKey()
// Get current user for SSH connection
currentUser, err := user.Current()
require.NoError(t, err, "Should be able to get current user for test")
// Try to connect with client key
config := &cryptossh.ClientConfig{
User: currentUser.Username,
Auth: []cryptossh.AuthMethod{
cryptossh.PublicKeys(clientSigner),
},
HostKeyCallback: cryptossh.FixedHostKey(hostPubKey),
Timeout: 3 * time.Second,
}
// This should succeed in no-auth mode (server doesn't verify keys)
conn, err := cryptossh.Dial("tcp", serverAddr, config)
assert.NoError(t, err, "Connection should succeed in no-auth mode")
if conn != nil {
assert.NoError(t, conn.Close())
}
}
func TestSSHServerStartStopCycle(t *testing.T) {
hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
require.NoError(t, err)
serverConfig := &Config{
HostKeyPEM: hostKey,
JWT: nil,
}
server := New(serverConfig)
serverAddr := "127.0.0.1:0"
// Test multiple start/stop cycles
for i := 0; i < 3; i++ {
t.Logf("Start/stop cycle %d", i+1)
started := make(chan string, 1)
errChan := make(chan error, 1)
go func() {
ln, err := net.Listen("tcp", serverAddr)
if err != nil {
errChan <- err
return
}
actualAddr := ln.Addr().String()
if err := ln.Close(); err != nil {
errChan <- fmt.Errorf("close temp listener: %w", err)
return
}
addrPort, _ := netip.ParseAddrPort(actualAddr)
if err := server.Start(context.Background(), addrPort); err != nil {
errChan <- err
return
}
started <- actualAddr
}()
select {
case <-started:
case err := <-errChan:
t.Fatalf("Cycle %d: Server failed to start: %v", i+1, err)
case <-time.After(5 * time.Second):
t.Fatalf("Cycle %d: Server start timeout", i+1)
}
err = server.Stop()
require.NoError(t, err, "Cycle %d: Stop should succeed", i+1)
}
}
func TestSSHServer_WindowsShellHandling(t *testing.T) {
if testing.Short() {
t.Skip("Skipping Windows shell test in short mode")
}
server := &Server{}
if runtime.GOOS == "windows" {
// Test Windows cmd.exe shell behavior
args := server.getShellCommandArgs("cmd.exe", "echo test")
assert.Equal(t, "cmd.exe", args[0])
assert.Equal(t, "-Command", args[1])
assert.Equal(t, "echo test", args[2])
// Test PowerShell behavior
args = server.getShellCommandArgs("powershell.exe", "echo test")
assert.Equal(t, "powershell.exe", args[0])
assert.Equal(t, "-Command", args[1])
assert.Equal(t, "echo test", args[2])
} else {
// Test Unix shell behavior
args := server.getShellCommandArgs("/bin/sh", "echo test")
assert.Equal(t, "/bin/sh", args[0])
assert.Equal(t, "-l", args[1])
assert.Equal(t, "-c", args[2])
assert.Equal(t, "echo test", args[3])
}
}
func TestSSHServer_PortForwardingConfiguration(t *testing.T) {
hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
require.NoError(t, err)
serverConfig1 := &Config{
HostKeyPEM: hostKey,
JWT: nil,
}
server1 := New(serverConfig1)
serverConfig2 := &Config{
HostKeyPEM: hostKey,
JWT: nil,
}
server2 := New(serverConfig2)
assert.False(t, server1.allowLocalPortForwarding, "Local port forwarding should be disabled by default for security")
assert.False(t, server1.allowRemotePortForwarding, "Remote port forwarding should be disabled by default for security")
server2.SetAllowLocalPortForwarding(true)
server2.SetAllowRemotePortForwarding(true)
assert.True(t, server2.allowLocalPortForwarding, "Local port forwarding should be enabled when explicitly set")
assert.True(t, server2.allowRemotePortForwarding, "Remote port forwarding should be enabled when explicitly set")
}

View File

@@ -0,0 +1,168 @@
package server
import (
"crypto/sha256"
"encoding/hex"
"errors"
"fmt"
"io"
"strings"
"time"
"github.com/gliderlabs/ssh"
log "github.com/sirupsen/logrus"
cryptossh "golang.org/x/crypto/ssh"
)
// sessionHandler handles SSH sessions
func (s *Server) sessionHandler(session ssh.Session) {
sessionKey := s.registerSession(session)
key := newAuthKey(session.User(), session.RemoteAddr())
s.mu.Lock()
jwtUsername := s.pendingAuthJWT[key]
if jwtUsername != "" {
s.sessionJWTUsers[sessionKey] = jwtUsername
delete(s.pendingAuthJWT, key)
}
s.mu.Unlock()
logger := log.WithField("session", sessionKey)
if jwtUsername != "" {
logger = logger.WithField("jwt_user", jwtUsername)
logger.Infof("SSH session started (JWT user: %s)", jwtUsername)
} else {
logger.Infof("SSH session started")
}
sessionStart := time.Now()
defer s.unregisterSession(sessionKey, session)
defer func() {
duration := time.Since(sessionStart).Round(time.Millisecond)
if err := session.Close(); err != nil && !errors.Is(err, io.EOF) {
logger.Warnf("close session after %v: %v", duration, err)
}
logger.Infof("SSH session closed after %v", duration)
}()
privilegeResult, err := s.userPrivilegeCheck(session.User())
if err != nil {
s.handlePrivError(logger, session, err)
return
}
ptyReq, winCh, isPty := session.Pty()
hasCommand := len(session.Command()) > 0
switch {
case isPty && hasCommand:
// ssh -t <host> <cmd> - Pty command execution
s.handleCommand(logger, session, privilegeResult, winCh)
case isPty:
// ssh <host> - Pty interactive session (login)
s.handlePty(logger, session, privilegeResult, ptyReq, winCh)
case hasCommand:
// ssh <host> <cmd> - non-Pty command execution
s.handleCommand(logger, session, privilegeResult, nil)
default:
s.rejectInvalidSession(logger, session)
}
}
func (s *Server) rejectInvalidSession(logger *log.Entry, session ssh.Session) {
if _, err := io.WriteString(session, "no command specified and Pty not requested\n"); err != nil {
logger.Debugf(errWriteSession, err)
}
if err := session.Exit(1); err != nil {
logSessionExitError(logger, err)
}
logger.Infof("rejected non-Pty session without command from %s", session.RemoteAddr())
}
func (s *Server) registerSession(session ssh.Session) SessionKey {
sessionID := session.Context().Value(ssh.ContextKeySessionID)
if sessionID == nil {
sessionID = fmt.Sprintf("%p", session)
}
// Create a short 4-byte identifier from the full session ID
hasher := sha256.New()
hasher.Write([]byte(fmt.Sprintf("%v", sessionID)))
hash := hasher.Sum(nil)
shortID := hex.EncodeToString(hash[:4])
remoteAddr := session.RemoteAddr().String()
username := session.User()
sessionKey := SessionKey(fmt.Sprintf("%s@%s-%s", username, remoteAddr, shortID))
s.mu.Lock()
s.sessions[sessionKey] = session
s.mu.Unlock()
return sessionKey
}
func (s *Server) unregisterSession(sessionKey SessionKey, session ssh.Session) {
s.mu.Lock()
delete(s.sessions, sessionKey)
delete(s.sessionJWTUsers, sessionKey)
// Cancel all port forwarding connections for this session
var connectionsToCancel []ConnectionKey
for key := range s.sessionCancels {
if strings.HasPrefix(string(key), string(sessionKey)+"-") {
connectionsToCancel = append(connectionsToCancel, key)
}
}
for _, key := range connectionsToCancel {
if cancelFunc, exists := s.sessionCancels[key]; exists {
log.WithField("session", sessionKey).Debugf("cancelling port forwarding context: %s", key)
cancelFunc()
delete(s.sessionCancels, key)
}
}
if sshConnValue := session.Context().Value(ssh.ContextKeyConn); sshConnValue != nil {
if sshConn, ok := sshConnValue.(*cryptossh.ServerConn); ok {
delete(s.sshConnections, sshConn)
}
}
s.mu.Unlock()
}
func (s *Server) handlePrivError(logger *log.Entry, session ssh.Session, err error) {
logger.Warnf("user privilege check failed: %v", err)
errorMsg := s.buildUserLookupErrorMessage(err)
if _, writeErr := fmt.Fprint(session, errorMsg); writeErr != nil {
logger.Debugf(errWriteSession, writeErr)
}
if exitErr := session.Exit(1); exitErr != nil {
logSessionExitError(logger, exitErr)
}
}
// buildUserLookupErrorMessage creates appropriate user-facing error messages based on error type
func (s *Server) buildUserLookupErrorMessage(err error) string {
var privilegedErr *PrivilegedUserError
switch {
case errors.As(err, &privilegedErr):
if privilegedErr.Username == "root" {
return "root login is disabled on this SSH server\n"
}
return "privileged user access is disabled on this SSH server\n"
case errors.Is(err, ErrPrivilegeRequired):
return "Windows user switching failed - NetBird must run with elevated privileges for user switching\n"
case errors.Is(err, ErrPrivilegedUserSwitch):
return "Cannot switch to privileged user - current user lacks required privileges\n"
default:
return "User authentication failed\n"
}
}

View File

@@ -0,0 +1,22 @@
//go:build js
package server
import (
"fmt"
"github.com/gliderlabs/ssh"
log "github.com/sirupsen/logrus"
)
// handlePty is not supported on JS/WASM
func (s *Server) handlePty(logger *log.Entry, session ssh.Session, _ PrivilegeCheckResult, _ ssh.Pty, _ <-chan ssh.Window) bool {
errorMsg := "PTY sessions are not supported on WASM/JS platform\n"
if _, err := fmt.Fprint(session.Stderr(), errorMsg); err != nil {
logger.Debugf(errWriteSession, err)
}
if err := session.Exit(1); err != nil {
logSessionExitError(logger, err)
}
return false
}

81
client/ssh/server/sftp.go Normal file
View File

@@ -0,0 +1,81 @@
package server
import (
"fmt"
"io"
"github.com/gliderlabs/ssh"
"github.com/pkg/sftp"
log "github.com/sirupsen/logrus"
)
// SetAllowSFTP enables or disables SFTP support
func (s *Server) SetAllowSFTP(allow bool) {
s.mu.Lock()
defer s.mu.Unlock()
s.allowSFTP = allow
}
// sftpSubsystemHandler handles SFTP subsystem requests
func (s *Server) sftpSubsystemHandler(sess ssh.Session) {
s.mu.RLock()
allowSFTP := s.allowSFTP
s.mu.RUnlock()
if !allowSFTP {
log.Debugf("SFTP subsystem request denied: SFTP disabled")
if err := sess.Exit(1); err != nil {
log.Debugf("SFTP session exit failed: %v", err)
}
return
}
result := s.CheckPrivileges(PrivilegeCheckRequest{
RequestedUsername: sess.User(),
FeatureSupportsUserSwitch: true,
FeatureName: FeatureSFTP,
})
if !result.Allowed {
log.Warnf("SFTP access denied for user %s from %s: %v", sess.User(), sess.RemoteAddr(), result.Error)
if err := sess.Exit(1); err != nil {
log.Debugf("exit SFTP session: %v", err)
}
return
}
log.Debugf("SFTP subsystem request from user %s (effective user %s)", sess.User(), result.User.Username)
if !result.RequiresUserSwitching {
if err := s.executeSftpDirect(sess); err != nil {
log.Errorf("SFTP direct execution: %v", err)
}
return
}
if err := s.executeSftpWithPrivilegeDrop(sess, result.User); err != nil {
log.Errorf("SFTP privilege drop execution: %v", err)
}
}
// executeSftpDirect executes SFTP directly without privilege dropping
func (s *Server) executeSftpDirect(sess ssh.Session) error {
log.Debugf("starting SFTP session for user %s (no privilege dropping)", sess.User())
sftpServer, err := sftp.NewServer(sess)
if err != nil {
return fmt.Errorf("SFTP server creation: %w", err)
}
defer func() {
if err := sftpServer.Close(); err != nil {
log.Debugf("failed to close sftp server: %v", err)
}
}()
if err := sftpServer.Serve(); err != nil && err != io.EOF {
return fmt.Errorf("serve: %w", err)
}
return nil
}

View File

@@ -0,0 +1,12 @@
//go:build js
package server
import (
"os/user"
)
// parseUserCredentials is not supported on JS/WASM
func (s *Server) parseUserCredentials(_ *user.User) (uint32, uint32, []uint32, error) {
return 0, 0, nil, errNotSupported
}

View File

@@ -0,0 +1,228 @@
package server
import (
"context"
"fmt"
"net"
"net/netip"
"os"
"os/user"
"testing"
"time"
"github.com/pkg/sftp"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
cryptossh "golang.org/x/crypto/ssh"
"github.com/netbirdio/netbird/client/ssh"
)
func TestSSHServer_SFTPSubsystem(t *testing.T) {
// Skip SFTP test when running as root due to protocol issues in some environments
if os.Geteuid() == 0 {
t.Skip("Skipping SFTP test when running as root - may have protocol compatibility issues")
}
// Get current user for SSH connection
currentUser, err := user.Current()
require.NoError(t, err, "Should be able to get current user")
// Generate host key for server
hostKey, err := ssh.GeneratePrivateKey(ssh.ED25519)
require.NoError(t, err)
// Generate client key pair
clientPrivKey, err := ssh.GeneratePrivateKey(ssh.ED25519)
require.NoError(t, err)
// Create server with SFTP enabled
serverConfig := &Config{
HostKeyPEM: hostKey,
JWT: nil,
}
server := New(serverConfig)
server.SetAllowSFTP(true)
server.SetAllowRootLogin(true)
// Start server
serverAddr := "127.0.0.1:0"
started := make(chan string, 1)
errChan := make(chan error, 1)
go func() {
ln, err := net.Listen("tcp", serverAddr)
if err != nil {
errChan <- err
return
}
actualAddr := ln.Addr().String()
if err := ln.Close(); err != nil {
errChan <- fmt.Errorf("close temp listener: %w", err)
return
}
addrPort, _ := netip.ParseAddrPort(actualAddr)
if err := server.Start(context.Background(), addrPort); err != nil {
errChan <- err
return
}
started <- actualAddr
}()
select {
case actualAddr := <-started:
serverAddr = actualAddr
case err := <-errChan:
t.Fatalf("Server failed to start: %v", err)
case <-time.After(5 * time.Second):
t.Fatal("Server start timeout")
}
defer func() {
err := server.Stop()
require.NoError(t, err)
}()
// Parse client private key
signer, err := cryptossh.ParsePrivateKey(clientPrivKey)
require.NoError(t, err)
// Parse server host key
hostPrivParsed, err := cryptossh.ParsePrivateKey(hostKey)
require.NoError(t, err)
hostPubKey := hostPrivParsed.PublicKey()
// (currentUser already obtained at function start)
// Create SSH client connection
clientConfig := &cryptossh.ClientConfig{
User: currentUser.Username,
Auth: []cryptossh.AuthMethod{
cryptossh.PublicKeys(signer),
},
HostKeyCallback: cryptossh.FixedHostKey(hostPubKey),
Timeout: 5 * time.Second,
}
conn, err := cryptossh.Dial("tcp", serverAddr, clientConfig)
require.NoError(t, err, "SSH connection should succeed")
defer func() {
if err := conn.Close(); err != nil {
t.Logf("connection close error: %v", err)
}
}()
// Create SFTP client
sftpClient, err := sftp.NewClient(conn)
require.NoError(t, err, "SFTP client creation should succeed")
defer func() {
if err := sftpClient.Close(); err != nil {
t.Logf("SFTP client close error: %v", err)
}
}()
// Test basic SFTP operations
workingDir, err := sftpClient.Getwd()
assert.NoError(t, err, "Should be able to get working directory")
assert.NotEmpty(t, workingDir, "Working directory should not be empty")
// Test directory listing
files, err := sftpClient.ReadDir(".")
assert.NoError(t, err, "Should be able to list current directory")
assert.NotNil(t, files, "File list should not be nil")
}
func TestSSHServer_SFTPDisabled(t *testing.T) {
// Get current user for SSH connection
currentUser, err := user.Current()
require.NoError(t, err, "Should be able to get current user")
// Generate host key for server
hostKey, err := ssh.GeneratePrivateKey(ssh.ED25519)
require.NoError(t, err)
// Generate client key pair
clientPrivKey, err := ssh.GeneratePrivateKey(ssh.ED25519)
require.NoError(t, err)
// Create server with SFTP disabled
serverConfig := &Config{
HostKeyPEM: hostKey,
JWT: nil,
}
server := New(serverConfig)
server.SetAllowSFTP(false)
// Start server
serverAddr := "127.0.0.1:0"
started := make(chan string, 1)
errChan := make(chan error, 1)
go func() {
ln, err := net.Listen("tcp", serverAddr)
if err != nil {
errChan <- err
return
}
actualAddr := ln.Addr().String()
if err := ln.Close(); err != nil {
errChan <- fmt.Errorf("close temp listener: %w", err)
return
}
addrPort, _ := netip.ParseAddrPort(actualAddr)
if err := server.Start(context.Background(), addrPort); err != nil {
errChan <- err
return
}
started <- actualAddr
}()
select {
case actualAddr := <-started:
serverAddr = actualAddr
case err := <-errChan:
t.Fatalf("Server failed to start: %v", err)
case <-time.After(5 * time.Second):
t.Fatal("Server start timeout")
}
defer func() {
err := server.Stop()
require.NoError(t, err)
}()
// Parse client private key
signer, err := cryptossh.ParsePrivateKey(clientPrivKey)
require.NoError(t, err)
// Parse server host key
hostPrivParsed, err := cryptossh.ParsePrivateKey(hostKey)
require.NoError(t, err)
hostPubKey := hostPrivParsed.PublicKey()
// (currentUser already obtained at function start)
// Create SSH client connection
clientConfig := &cryptossh.ClientConfig{
User: currentUser.Username,
Auth: []cryptossh.AuthMethod{
cryptossh.PublicKeys(signer),
},
HostKeyCallback: cryptossh.FixedHostKey(hostPubKey),
Timeout: 5 * time.Second,
}
conn, err := cryptossh.Dial("tcp", serverAddr, clientConfig)
require.NoError(t, err, "SSH connection should succeed")
defer func() {
if err := conn.Close(); err != nil {
t.Logf("connection close error: %v", err)
}
}()
// Try to create SFTP client - should fail when SFTP is disabled
_, err = sftp.NewClient(conn)
assert.Error(t, err, "SFTP client creation should fail when SFTP is disabled")
}

View File

@@ -0,0 +1,71 @@
//go:build !windows
package server
import (
"errors"
"fmt"
"os"
"os/exec"
"os/user"
"strconv"
"github.com/gliderlabs/ssh"
log "github.com/sirupsen/logrus"
)
// executeSftpWithPrivilegeDrop executes SFTP using Unix privilege dropping
func (s *Server) executeSftpWithPrivilegeDrop(sess ssh.Session, targetUser *user.User) error {
uid, gid, groups, err := s.parseUserCredentials(targetUser)
if err != nil {
return fmt.Errorf("parse user credentials: %w", err)
}
sftpCmd, err := s.createSftpExecutorCommand(sess, uid, gid, groups, targetUser.HomeDir)
if err != nil {
return fmt.Errorf("create executor: %w", err)
}
sftpCmd.Stdin = sess
sftpCmd.Stdout = sess
sftpCmd.Stderr = sess.Stderr()
log.Tracef("starting SFTP with privilege dropping to user %s (UID=%d, GID=%d)", targetUser.Username, uid, gid)
if err := sftpCmd.Start(); err != nil {
return fmt.Errorf("starting SFTP executor: %w", err)
}
if err := sftpCmd.Wait(); err != nil {
var exitError *exec.ExitError
if errors.As(err, &exitError) {
log.Tracef("SFTP process exited with code %d", exitError.ExitCode())
return nil
}
return fmt.Errorf("exec: %w", err)
}
return nil
}
// createSftpExecutorCommand creates a command that spawns netbird ssh sftp for privilege dropping
func (s *Server) createSftpExecutorCommand(sess ssh.Session, uid, gid uint32, groups []uint32, workingDir string) (*exec.Cmd, error) {
netbirdPath, err := os.Executable()
if err != nil {
return nil, err
}
args := []string{
"ssh", "sftp",
"--uid", strconv.FormatUint(uint64(uid), 10),
"--gid", strconv.FormatUint(uint64(gid), 10),
"--working-dir", workingDir,
}
for _, group := range groups {
args = append(args, "--groups", strconv.FormatUint(uint64(group), 10))
}
log.Tracef("creating SFTP executor command: %s %v", netbirdPath, args)
return exec.CommandContext(sess.Context(), netbirdPath, args...), nil
}

View File

@@ -0,0 +1,91 @@
//go:build windows
package server
import (
"errors"
"fmt"
"os"
"os/exec"
"os/user"
"github.com/gliderlabs/ssh"
log "github.com/sirupsen/logrus"
"golang.org/x/sys/windows"
)
// createSftpCommand creates a Windows SFTP command with user switching.
// The caller must close the returned token handle after starting the process.
func (s *Server) createSftpCommand(targetUser *user.User, sess ssh.Session) (*exec.Cmd, windows.Token, error) {
username, domain := s.parseUsername(targetUser.Username)
netbirdPath, err := os.Executable()
if err != nil {
return nil, 0, fmt.Errorf("get netbird executable path: %w", err)
}
args := []string{
"ssh", "sftp",
"--working-dir", targetUser.HomeDir,
"--windows-username", username,
"--windows-domain", domain,
}
pd := NewPrivilegeDropper()
token, err := pd.createToken(username, domain)
if err != nil {
return nil, 0, fmt.Errorf("create token: %w", err)
}
defer func() {
if err := windows.CloseHandle(token); err != nil {
log.Warnf("failed to close impersonation token: %v", err)
}
}()
cmd, primaryToken, err := pd.createProcessWithToken(sess.Context(), windows.Token(token), netbirdPath, append([]string{netbirdPath}, args...), targetUser.HomeDir)
if err != nil {
return nil, 0, fmt.Errorf("create SFTP command: %w", err)
}
log.Debugf("Created Windows SFTP command with user switching for %s", targetUser.Username)
return cmd, primaryToken, nil
}
// executeSftpCommand executes a Windows SFTP command with proper I/O handling
func (s *Server) executeSftpCommand(sess ssh.Session, sftpCmd *exec.Cmd, token windows.Token) error {
defer func() {
if err := windows.CloseHandle(windows.Handle(token)); err != nil {
log.Debugf("close primary token: %v", err)
}
}()
sftpCmd.Stdin = sess
sftpCmd.Stdout = sess
sftpCmd.Stderr = sess.Stderr()
if err := sftpCmd.Start(); err != nil {
return fmt.Errorf("starting sftp executor: %w", err)
}
if err := sftpCmd.Wait(); err != nil {
var exitError *exec.ExitError
if errors.As(err, &exitError) {
log.Tracef("sftp process exited with code %d", exitError.ExitCode())
return nil
}
return fmt.Errorf("exec sftp: %w", err)
}
return nil
}
// executeSftpWithPrivilegeDrop executes SFTP using Windows privilege dropping
func (s *Server) executeSftpWithPrivilegeDrop(sess ssh.Session, targetUser *user.User) error {
sftpCmd, token, err := s.createSftpCommand(targetUser, sess)
if err != nil {
return fmt.Errorf("create sftp: %w", err)
}
return s.executeSftpCommand(sess, sftpCmd, token)
}

180
client/ssh/server/shell.go Normal file
View File

@@ -0,0 +1,180 @@
package server
import (
"bufio"
"fmt"
"net"
"os"
"os/exec"
"os/user"
"runtime"
"strconv"
"strings"
"github.com/gliderlabs/ssh"
log "github.com/sirupsen/logrus"
)
const (
defaultUnixShell = "/bin/sh"
pwshExe = "pwsh.exe" // #nosec G101 - This is not a credential, just executable name
powershellExe = "powershell.exe"
)
// getUserShell returns the appropriate shell for the given user ID
// Handles all platform-specific logic and fallbacks consistently
func getUserShell(userID string) string {
switch runtime.GOOS {
case "windows":
return getWindowsUserShell()
default:
return getUnixUserShell(userID)
}
}
// getWindowsUserShell returns the best shell for Windows users.
// We intentionally do not support cmd.exe or COMSPEC fallbacks to avoid command injection
// vulnerabilities that arise from cmd.exe's complex command line parsing and special characters.
// PowerShell provides safer argument handling and is available on all modern Windows systems.
// Order: pwsh.exe -> powershell.exe
func getWindowsUserShell() string {
if path, err := exec.LookPath(pwshExe); err == nil {
return path
}
if path, err := exec.LookPath(powershellExe); err == nil {
return path
}
return `C:\Windows\System32\WindowsPowerShell\v1.0\powershell.exe`
}
// getUnixUserShell returns the shell for Unix-like systems
func getUnixUserShell(userID string) string {
shell := getShellFromPasswd(userID)
if shell != "" {
return shell
}
if shell := os.Getenv("SHELL"); shell != "" {
return shell
}
return defaultUnixShell
}
// getShellFromPasswd reads the shell from /etc/passwd for the given user ID
func getShellFromPasswd(userID string) string {
file, err := os.Open("/etc/passwd")
if err != nil {
return ""
}
defer func() {
if err := file.Close(); err != nil {
log.Warnf("close /etc/passwd file: %v", err)
}
}()
scanner := bufio.NewScanner(file)
for scanner.Scan() {
line := scanner.Text()
fields := strings.Split(line, ":")
if len(fields) < 7 {
continue
}
// field 2 is UID
if fields[2] == userID {
shell := strings.TrimSpace(fields[6])
return shell
}
}
if err := scanner.Err(); err != nil {
log.Warnf("error reading /etc/passwd: %v", err)
}
return ""
}
// prepareUserEnv prepares environment variables for user execution
func prepareUserEnv(user *user.User, shell string) []string {
pathValue := "/usr/local/bin:/usr/bin:/bin:/usr/local/games:/usr/games"
if runtime.GOOS == "windows" {
pathValue = `C:\Windows\System32;C:\Windows;C:\Windows\System32\Wbem;C:\Windows\System32\WindowsPowerShell\v1.0`
}
return []string{
fmt.Sprint("SHELL=" + shell),
fmt.Sprint("USER=" + user.Username),
fmt.Sprint("LOGNAME=" + user.Username),
fmt.Sprint("HOME=" + user.HomeDir),
"PATH=" + pathValue,
}
}
// acceptEnv checks if environment variable from SSH client should be accepted
// This is a whitelist of variables that SSH clients can send to the server
func acceptEnv(envVar string) bool {
varName := envVar
if idx := strings.Index(envVar, "="); idx != -1 {
varName = envVar[:idx]
}
exactMatches := []string{
"LANG",
"LANGUAGE",
"TERM",
"COLORTERM",
"EDITOR",
"VISUAL",
"PAGER",
"LESS",
"LESSCHARSET",
"TZ",
}
prefixMatches := []string{
"LC_",
}
for _, exact := range exactMatches {
if varName == exact {
return true
}
}
for _, prefix := range prefixMatches {
if strings.HasPrefix(varName, prefix) {
return true
}
}
return false
}
// prepareSSHEnv prepares SSH protocol-specific environment variables
// These variables provide information about the SSH connection itself
func prepareSSHEnv(session ssh.Session) []string {
remoteAddr := session.RemoteAddr()
localAddr := session.LocalAddr()
remoteHost, remotePort, err := net.SplitHostPort(remoteAddr.String())
if err != nil {
remoteHost = remoteAddr.String()
remotePort = "0"
}
localHost, localPort, err := net.SplitHostPort(localAddr.String())
if err != nil {
localHost = localAddr.String()
localPort = strconv.Itoa(InternalSSHPort)
}
return []string{
// SSH_CLIENT format: "client_ip client_port server_port"
fmt.Sprintf("SSH_CLIENT=%s %s %s", remoteHost, remotePort, localPort),
// SSH_CONNECTION format: "client_ip client_port server_ip server_port"
fmt.Sprintf("SSH_CONNECTION=%s %s %s %s", remoteHost, remotePort, localHost, localPort),
}
}

45
client/ssh/server/test.go Normal file
View File

@@ -0,0 +1,45 @@
package server
import (
"context"
"fmt"
"net"
"net/netip"
"testing"
"time"
)
func StartTestServer(t *testing.T, server *Server) string {
started := make(chan string, 1)
errChan := make(chan error, 1)
go func() {
ln, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
errChan <- err
return
}
actualAddr := ln.Addr().String()
if err := ln.Close(); err != nil {
errChan <- fmt.Errorf("close temp listener: %w", err)
return
}
addrPort := netip.MustParseAddrPort(actualAddr)
if err := server.Start(context.Background(), addrPort); err != nil {
errChan <- err
return
}
started <- actualAddr
}()
select {
case actualAddr := <-started:
return actualAddr
case err := <-errChan:
t.Fatalf("Server failed to start: %v", err)
case <-time.After(5 * time.Second):
t.Fatal("Server start timeout")
}
return ""
}

View File

@@ -0,0 +1,411 @@
package server
import (
"errors"
"fmt"
"os"
"os/user"
"runtime"
"strings"
log "github.com/sirupsen/logrus"
)
var (
ErrPrivilegeRequired = errors.New("SeAssignPrimaryTokenPrivilege required for user switching - NetBird must run with elevated privileges")
ErrPrivilegedUserSwitch = errors.New("cannot switch to privileged user - current user lacks required privileges")
)
// isPlatformUnix returns true for Unix-like platforms (Linux, macOS, etc.)
func isPlatformUnix() bool {
return getCurrentOS() != "windows"
}
// Dependency injection variables for testing - allows mocking dynamic runtime checks
var (
getCurrentUser = user.Current
lookupUser = user.Lookup
getCurrentOS = func() string { return runtime.GOOS }
getIsProcessPrivileged = isCurrentProcessPrivileged
getEuid = os.Geteuid
)
const (
// FeatureSSHLogin represents SSH login operations for privilege checking
FeatureSSHLogin = "SSH login"
// FeatureSFTP represents SFTP operations for privilege checking
FeatureSFTP = "SFTP"
)
// PrivilegeCheckRequest represents a privilege check request
type PrivilegeCheckRequest struct {
// Username being requested (empty = current user)
RequestedUsername string
FeatureSupportsUserSwitch bool // Does this feature/operation support user switching?
FeatureName string
}
// PrivilegeCheckResult represents the result of a privilege check
type PrivilegeCheckResult struct {
// Allowed indicates whether the privilege check passed
Allowed bool
// User is the effective user to use for the operation (nil if not allowed)
User *user.User
// Error contains the reason for denial (nil if allowed)
Error error
// UsedFallback indicates we fell back to current user instead of requested user.
// This happens on Unix when running as an unprivileged user (e.g., in containers)
// where there's no point in user switching since we lack privileges anyway.
// When true, all privilege checks have already been performed and no additional
// privilege dropping or root checks are needed - the current user is the target.
UsedFallback bool
// RequiresUserSwitching indicates whether user switching will actually occur
// (false for fallback cases where no actual switching happens)
RequiresUserSwitching bool
}
// CheckPrivileges performs comprehensive privilege checking for all SSH features.
// This is the single source of truth for privilege decisions across the SSH server.
func (s *Server) CheckPrivileges(req PrivilegeCheckRequest) PrivilegeCheckResult {
context, err := s.buildPrivilegeCheckContext(req.FeatureName)
if err != nil {
return PrivilegeCheckResult{Allowed: false, Error: err}
}
// Handle empty username case - but still check root access controls
if req.RequestedUsername == "" {
if isPrivilegedUsername(context.currentUser.Username) && !context.allowRoot {
return PrivilegeCheckResult{
Allowed: false,
Error: &PrivilegedUserError{Username: context.currentUser.Username},
}
}
return PrivilegeCheckResult{
Allowed: true,
User: context.currentUser,
RequiresUserSwitching: false,
}
}
return s.checkUserRequest(context, req)
}
// buildPrivilegeCheckContext gathers all the context needed for privilege checking
func (s *Server) buildPrivilegeCheckContext(featureName string) (*privilegeCheckContext, error) {
currentUser, err := getCurrentUser()
if err != nil {
return nil, fmt.Errorf("get current user for %s: %w", featureName, err)
}
s.mu.RLock()
allowRoot := s.allowRootLogin
s.mu.RUnlock()
return &privilegeCheckContext{
currentUser: currentUser,
currentUserPrivileged: getIsProcessPrivileged(),
allowRoot: allowRoot,
}, nil
}
// checkUserRequest handles normal privilege checking flow for specific usernames
func (s *Server) checkUserRequest(ctx *privilegeCheckContext, req PrivilegeCheckRequest) PrivilegeCheckResult {
if !ctx.currentUserPrivileged && isPlatformUnix() {
log.Debugf("Unix non-privileged shortcut: falling back to current user %s for %s (requested: %s)",
ctx.currentUser.Username, req.FeatureName, req.RequestedUsername)
return PrivilegeCheckResult{
Allowed: true,
User: ctx.currentUser,
UsedFallback: true,
RequiresUserSwitching: false,
}
}
resolvedUser, err := s.resolveRequestedUser(req.RequestedUsername)
if err != nil {
// Calculate if user switching would be required even if lookup failed
needsUserSwitching := !isSameUser(req.RequestedUsername, ctx.currentUser.Username)
return PrivilegeCheckResult{
Allowed: false,
Error: err,
RequiresUserSwitching: needsUserSwitching,
}
}
needsUserSwitching := !isSameResolvedUser(resolvedUser, ctx.currentUser)
if isPrivilegedUsername(resolvedUser.Username) && !ctx.allowRoot {
return PrivilegeCheckResult{
Allowed: false,
Error: &PrivilegedUserError{Username: resolvedUser.Username},
RequiresUserSwitching: needsUserSwitching,
}
}
if needsUserSwitching && !req.FeatureSupportsUserSwitch {
return PrivilegeCheckResult{
Allowed: false,
Error: fmt.Errorf("%s: user switching not supported by this feature", req.FeatureName),
RequiresUserSwitching: needsUserSwitching,
}
}
return PrivilegeCheckResult{
Allowed: true,
User: resolvedUser,
RequiresUserSwitching: needsUserSwitching,
}
}
// resolveRequestedUser resolves a username to its canonical user identity
func (s *Server) resolveRequestedUser(requestedUsername string) (*user.User, error) {
if requestedUsername == "" {
return getCurrentUser()
}
if err := validateUsername(requestedUsername); err != nil {
return nil, fmt.Errorf("invalid username %q: %w", requestedUsername, err)
}
u, err := lookupUser(requestedUsername)
if err != nil {
return nil, &UserNotFoundError{Username: requestedUsername, Cause: err}
}
return u, nil
}
// isSameResolvedUser compares two resolved user identities
func isSameResolvedUser(user1, user2 *user.User) bool {
if user1 == nil || user2 == nil {
return user1 == user2
}
return user1.Uid == user2.Uid
}
// privilegeCheckContext holds all context needed for privilege checking
type privilegeCheckContext struct {
currentUser *user.User
currentUserPrivileged bool
allowRoot bool
}
// isSameUser checks if two usernames refer to the same user
// SECURITY: This function must be conservative - it should only return true
// when we're certain both usernames refer to the exact same user identity
func isSameUser(requestedUsername, currentUsername string) bool {
// Empty requested username means current user
if requestedUsername == "" {
return true
}
// Exact match (most common case)
if getCurrentOS() == "windows" {
if strings.EqualFold(requestedUsername, currentUsername) {
return true
}
} else {
if requestedUsername == currentUsername {
return true
}
}
// Windows domain resolution: only allow domain stripping when comparing
// a bare username against the current user's domain-qualified name
if getCurrentOS() == "windows" {
return isWindowsSameUser(requestedUsername, currentUsername)
}
return false
}
// isWindowsSameUser handles Windows-specific user comparison with domain logic
func isWindowsSameUser(requestedUsername, currentUsername string) bool {
// Extract domain and username parts
extractParts := func(name string) (domain, user string) {
// Handle DOMAIN\username format
if idx := strings.LastIndex(name, `\`); idx != -1 {
return name[:idx], name[idx+1:]
}
// Handle user@domain.com format
if idx := strings.Index(name, "@"); idx != -1 {
return name[idx+1:], name[:idx]
}
// No domain specified - local machine
return "", name
}
reqDomain, reqUser := extractParts(requestedUsername)
curDomain, curUser := extractParts(currentUsername)
// Case-insensitive username comparison
if !strings.EqualFold(reqUser, curUser) {
return false
}
// If requested username has no domain, it refers to local machine user
// Allow this to match the current user regardless of current user's domain
if reqDomain == "" {
return true
}
// If both have domains, they must match exactly (case-insensitive)
return strings.EqualFold(reqDomain, curDomain)
}
// SetAllowRootLogin configures root login access
func (s *Server) SetAllowRootLogin(allow bool) {
s.mu.Lock()
defer s.mu.Unlock()
s.allowRootLogin = allow
}
// userNameLookup performs user lookup with root login permission check
func (s *Server) userNameLookup(username string) (*user.User, error) {
result := s.CheckPrivileges(PrivilegeCheckRequest{
RequestedUsername: username,
FeatureSupportsUserSwitch: true,
FeatureName: FeatureSSHLogin,
})
if !result.Allowed {
return nil, result.Error
}
return result.User, nil
}
// userPrivilegeCheck performs user lookup with full privilege check result
func (s *Server) userPrivilegeCheck(username string) (PrivilegeCheckResult, error) {
result := s.CheckPrivileges(PrivilegeCheckRequest{
RequestedUsername: username,
FeatureSupportsUserSwitch: true,
FeatureName: FeatureSSHLogin,
})
if !result.Allowed {
return result, result.Error
}
return result, nil
}
// isPrivilegedUsername checks if the given username represents a privileged user across platforms.
// On Unix: root
// On Windows: Administrator, SYSTEM (case-insensitive)
// Handles domain-qualified usernames like "DOMAIN\Administrator" or "user@domain.com"
func isPrivilegedUsername(username string) bool {
if getCurrentOS() != "windows" {
return username == "root"
}
bareUsername := username
// Handle Windows domain format: DOMAIN\username
if idx := strings.LastIndex(username, `\`); idx != -1 {
bareUsername = username[idx+1:]
}
// Handle email-style format: username@domain.com
if idx := strings.Index(bareUsername, "@"); idx != -1 {
bareUsername = bareUsername[:idx]
}
return isWindowsPrivilegedUser(bareUsername)
}
// isWindowsPrivilegedUser checks if a bare username (domain already stripped) represents a Windows privileged account
func isWindowsPrivilegedUser(bareUsername string) bool {
// common privileged usernames (case insensitive)
privilegedNames := []string{
"administrator",
"admin",
"root",
"system",
"localsystem",
"networkservice",
"localservice",
}
usernameLower := strings.ToLower(bareUsername)
for _, privilegedName := range privilegedNames {
if usernameLower == privilegedName {
return true
}
}
// computer accounts (ending with $) are not privileged by themselves
// They only gain privileges through group membership or specific SIDs
if targetUser, err := lookupUser(bareUsername); err == nil {
return isWindowsPrivilegedSID(targetUser.Uid)
}
return false
}
// isWindowsPrivilegedSID checks if a Windows SID represents a privileged account
func isWindowsPrivilegedSID(sid string) bool {
privilegedSIDs := []string{
"S-1-5-18", // Local System (SYSTEM)
"S-1-5-19", // Local Service (NT AUTHORITY\LOCAL SERVICE)
"S-1-5-20", // Network Service (NT AUTHORITY\NETWORK SERVICE)
"S-1-5-32-544", // Administrators group (BUILTIN\Administrators)
"S-1-5-500", // Built-in Administrator account (local machine RID 500)
}
for _, privilegedSID := range privilegedSIDs {
if sid == privilegedSID {
return true
}
}
// Check for domain administrator accounts (RID 500 in any domain)
// Format: S-1-5-21-domain-domain-domain-500
// This is reliable as RID 500 is reserved for the domain Administrator account
if strings.HasPrefix(sid, "S-1-5-21-") && strings.HasSuffix(sid, "-500") {
return true
}
// Check for other well-known privileged RIDs in domain contexts
// RID 512 = Domain Admins group, RID 516 = Domain Controllers group
if strings.HasPrefix(sid, "S-1-5-21-") {
if strings.HasSuffix(sid, "-512") || // Domain Admins group
strings.HasSuffix(sid, "-516") || // Domain Controllers group
strings.HasSuffix(sid, "-519") { // Enterprise Admins group
return true
}
}
return false
}
// isCurrentProcessPrivileged checks if the current process is running with elevated privileges.
// On Unix systems, this means running as root (UID 0).
// On Windows, this means running as Administrator or SYSTEM.
func isCurrentProcessPrivileged() bool {
if getCurrentOS() == "windows" {
return isWindowsElevated()
}
return getEuid() == 0
}
// isWindowsElevated checks if the current process is running with elevated privileges on Windows
func isWindowsElevated() bool {
currentUser, err := getCurrentUser()
if err != nil {
log.Errorf("failed to get current user for privilege check, assuming non-privileged: %v", err)
return false
}
if isWindowsPrivilegedSID(currentUser.Uid) {
log.Debugf("Windows user switching supported: running as privileged SID %s", currentUser.Uid)
return true
}
if isPrivilegedUsername(currentUser.Username) {
log.Debugf("Windows user switching supported: running as privileged username %s", currentUser.Username)
return true
}
log.Debugf("Windows user switching not supported: not running as privileged user (current: %s)", currentUser.Uid)
return false
}

View File

@@ -0,0 +1,8 @@
//go:build js
package server
// validateUsername is not supported on JS/WASM
func validateUsername(_ string) error {
return errNotSupported
}

View File

@@ -0,0 +1,908 @@
package server
import (
"errors"
"os/user"
"runtime"
"testing"
"github.com/stretchr/testify/assert"
)
// Test helper functions
func createTestUser(username, uid, gid, homeDir string) *user.User {
return &user.User{
Uid: uid,
Gid: gid,
Username: username,
Name: username,
HomeDir: homeDir,
}
}
// Test dependency injection setup - injects platform dependencies to test real logic
func setupTestDependencies(currentUser *user.User, currentUserErr error, os string, euid int, lookupUsers map[string]*user.User, lookupErrors map[string]error) func() {
// Store originals
originalGetCurrentUser := getCurrentUser
originalLookupUser := lookupUser
originalGetCurrentOS := getCurrentOS
originalGetEuid := getEuid
// Reset caches to ensure clean test state
// Set test values - inject platform dependencies
getCurrentUser = func() (*user.User, error) {
return currentUser, currentUserErr
}
lookupUser = func(username string) (*user.User, error) {
if err, exists := lookupErrors[username]; exists {
return nil, err
}
if userObj, exists := lookupUsers[username]; exists {
return userObj, nil
}
return nil, errors.New("user: unknown user " + username)
}
getCurrentOS = func() string {
return os
}
getEuid = func() int {
return euid
}
// Mock privilege detection based on the test user
getIsProcessPrivileged = func() bool {
if currentUser == nil {
return false
}
// Check both username and SID for Windows systems
if os == "windows" && isWindowsPrivilegedSID(currentUser.Uid) {
return true
}
return isPrivilegedUsername(currentUser.Username)
}
// Return cleanup function
return func() {
getCurrentUser = originalGetCurrentUser
lookupUser = originalLookupUser
getCurrentOS = originalGetCurrentOS
getEuid = originalGetEuid
getIsProcessPrivileged = isCurrentProcessPrivileged
// Reset caches after test
}
}
func TestCheckPrivileges_ComprehensiveMatrix(t *testing.T) {
tests := []struct {
name string
os string
euid int
currentUser *user.User
requestedUsername string
featureSupportsUserSwitch bool
allowRoot bool
lookupUsers map[string]*user.User
expectedAllowed bool
expectedRequiresSwitch bool
}{
{
name: "linux_root_can_switch_to_alice",
os: "linux",
euid: 0, // Root process
currentUser: createTestUser("root", "0", "0", "/root"),
requestedUsername: "alice",
featureSupportsUserSwitch: true,
allowRoot: true,
lookupUsers: map[string]*user.User{
"alice": createTestUser("alice", "1000", "1000", "/home/alice"),
},
expectedAllowed: true,
expectedRequiresSwitch: true,
},
{
name: "linux_non_root_fallback_to_current_user",
os: "linux",
euid: 1000, // Non-root process
currentUser: createTestUser("alice", "1000", "1000", "/home/alice"),
requestedUsername: "bob",
featureSupportsUserSwitch: true,
allowRoot: true,
expectedAllowed: true, // Should fallback to current user (alice)
expectedRequiresSwitch: false, // Fallback means no actual switching
},
{
name: "windows_admin_can_switch_to_alice",
os: "windows",
euid: 1000, // Irrelevant on Windows
currentUser: createTestUser("Administrator", "S-1-5-21-123456789-123456789-123456789-500", "S-1-5-32-544", "C:\\Users\\Administrator"),
requestedUsername: "alice",
featureSupportsUserSwitch: true,
allowRoot: true,
lookupUsers: map[string]*user.User{
"alice": createTestUser("alice", "S-1-5-21-123456789-123456789-123456789-1001", "S-1-5-21-123456789-123456789-123456789-513", "C:\\Users\\alice"),
},
expectedAllowed: true,
expectedRequiresSwitch: true,
},
{
name: "windows_non_admin_no_fallback_hard_failure",
os: "windows",
euid: 1000, // Irrelevant on Windows
currentUser: createTestUser("alice", "1001", "1001", "C:\\Users\\alice"),
requestedUsername: "bob",
featureSupportsUserSwitch: true,
allowRoot: true,
lookupUsers: map[string]*user.User{
"bob": createTestUser("bob", "S-1-5-21-123456789-123456789-123456789-1002", "S-1-5-21-123456789-123456789-123456789-513", "C:\\Users\\bob"),
},
expectedAllowed: true, // Let OS decide - deferred security check
expectedRequiresSwitch: true, // Different user was requested
},
// Comprehensive test matrix: non-root linux with different allowRoot settings
{
name: "linux_non_root_request_root_allowRoot_false",
os: "linux",
euid: 1000,
currentUser: createTestUser("alice", "1000", "1000", "/home/alice"),
requestedUsername: "root",
featureSupportsUserSwitch: true,
allowRoot: false,
expectedAllowed: true, // Fallback allows access regardless of root setting
expectedRequiresSwitch: false, // Fallback case, no switching
},
{
name: "linux_non_root_request_root_allowRoot_true",
os: "linux",
euid: 1000,
currentUser: createTestUser("alice", "1000", "1000", "/home/alice"),
requestedUsername: "root",
featureSupportsUserSwitch: true,
allowRoot: true,
expectedAllowed: true, // Should fallback to alice (non-privileged process)
expectedRequiresSwitch: false, // Fallback means no actual switching
},
// Windows admin test matrix
{
name: "windows_admin_request_root_allowRoot_false",
os: "windows",
euid: 1000,
currentUser: createTestUser("Administrator", "S-1-5-21-123456789-123456789-123456789-500", "S-1-5-32-544", "C:\\Users\\Administrator"),
requestedUsername: "root",
featureSupportsUserSwitch: true,
allowRoot: false,
expectedAllowed: false, // Root not allowed
expectedRequiresSwitch: true,
},
{
name: "windows_admin_request_root_allowRoot_true",
os: "windows",
euid: 1000,
currentUser: createTestUser("Administrator", "S-1-5-21-123456789-123456789-123456789-500", "S-1-5-32-544", "C:\\Users\\Administrator"),
requestedUsername: "root",
featureSupportsUserSwitch: true,
allowRoot: true,
lookupUsers: map[string]*user.User{
"root": createTestUser("root", "0", "0", "/root"),
},
expectedAllowed: true, // Windows user switching should work like Unix
expectedRequiresSwitch: true,
},
// Windows non-admin test matrix
{
name: "windows_non_admin_request_root_allowRoot_false",
os: "windows",
euid: 1000,
currentUser: createTestUser("alice", "S-1-5-21-123456789-123456789-123456789-1001", "S-1-5-21-123456789-123456789-123456789-513", "C:\\Users\\alice"),
requestedUsername: "root",
featureSupportsUserSwitch: true,
allowRoot: false,
expectedAllowed: false, // Root not allowed (allowRoot=false takes precedence)
expectedRequiresSwitch: true,
},
{
name: "windows_system_account_allowRoot_false",
os: "windows",
euid: 1000,
currentUser: createTestUser("NETBIRD\\WIN2K19-C2$", "S-1-5-18", "S-1-5-18", "C:\\Windows\\System32"),
requestedUsername: "root",
featureSupportsUserSwitch: true,
allowRoot: false,
expectedAllowed: false, // Root not allowed
expectedRequiresSwitch: true,
},
{
name: "windows_system_account_allowRoot_true",
os: "windows",
euid: 1000,
currentUser: createTestUser("NETBIRD\\WIN2K19-C2$", "S-1-5-18", "S-1-5-18", "C:\\Windows\\System32"),
requestedUsername: "root",
featureSupportsUserSwitch: true,
allowRoot: true,
lookupUsers: map[string]*user.User{
"root": createTestUser("root", "0", "0", "/root"),
},
expectedAllowed: true, // SYSTEM can switch to root
expectedRequiresSwitch: true,
},
{
name: "windows_non_admin_request_root_allowRoot_true",
os: "windows",
euid: 1000,
currentUser: createTestUser("alice", "S-1-5-21-123456789-123456789-123456789-1001", "S-1-5-21-123456789-123456789-123456789-513", "C:\\Users\\alice"),
requestedUsername: "root",
featureSupportsUserSwitch: true,
allowRoot: true,
lookupUsers: map[string]*user.User{
"root": createTestUser("root", "0", "0", "/root"),
},
expectedAllowed: true, // Let OS decide - deferred security check
expectedRequiresSwitch: true,
},
// Feature doesn't support user switching scenarios
{
name: "linux_root_feature_no_user_switching_same_user",
os: "linux",
euid: 0,
currentUser: createTestUser("root", "0", "0", "/root"),
requestedUsername: "root", // Same user
featureSupportsUserSwitch: false,
allowRoot: true,
lookupUsers: map[string]*user.User{
"root": createTestUser("root", "0", "0", "/root"),
},
expectedAllowed: true, // Same user should work regardless of feature support
expectedRequiresSwitch: false,
},
{
name: "linux_root_feature_no_user_switching_different_user",
os: "linux",
euid: 0,
currentUser: createTestUser("root", "0", "0", "/root"),
requestedUsername: "alice",
featureSupportsUserSwitch: false, // Feature doesn't support switching
allowRoot: true,
lookupUsers: map[string]*user.User{
"alice": createTestUser("alice", "1000", "1000", "/home/alice"),
},
expectedAllowed: false, // Should deny because feature doesn't support switching
expectedRequiresSwitch: true,
},
// Empty username (current user) scenarios
{
name: "linux_non_root_current_user_empty_username",
os: "linux",
euid: 1000,
currentUser: createTestUser("alice", "1000", "1000", "/home/alice"),
requestedUsername: "", // Empty = current user
featureSupportsUserSwitch: true,
allowRoot: false,
expectedAllowed: true, // Current user should always work
expectedRequiresSwitch: false,
},
{
name: "linux_root_current_user_empty_username_root_not_allowed",
os: "linux",
euid: 0,
currentUser: createTestUser("root", "0", "0", "/root"),
requestedUsername: "", // Empty = current user (root)
featureSupportsUserSwitch: true,
allowRoot: false, // Root not allowed
expectedAllowed: false, // Should deny root even when it's current user
expectedRequiresSwitch: false,
},
// User not found scenarios
{
name: "linux_root_user_not_found",
os: "linux",
euid: 0,
currentUser: createTestUser("root", "0", "0", "/root"),
requestedUsername: "nonexistent",
featureSupportsUserSwitch: true,
allowRoot: true,
lookupUsers: map[string]*user.User{}, // No users defined = user not found
expectedAllowed: false, // Should fail due to user not found
expectedRequiresSwitch: true,
},
// Windows feature doesn't support user switching
{
name: "windows_admin_feature_no_user_switching_different_user",
os: "windows",
euid: 1000,
currentUser: createTestUser("Administrator", "S-1-5-21-123456789-123456789-123456789-500", "S-1-5-32-544", "C:\\Users\\Administrator"),
requestedUsername: "alice",
featureSupportsUserSwitch: false, // Feature doesn't support switching
allowRoot: true,
lookupUsers: map[string]*user.User{
"alice": createTestUser("alice", "S-1-5-21-123456789-123456789-123456789-1001", "S-1-5-21-123456789-123456789-123456789-513", "C:\\Users\\alice"),
},
expectedAllowed: false, // Should deny because feature doesn't support switching
expectedRequiresSwitch: true,
},
// Windows regular user scenarios (non-admin)
{
name: "windows_regular_user_same_user",
os: "windows",
euid: 1000,
currentUser: createTestUser("alice", "S-1-5-21-123456789-123456789-123456789-1001", "S-1-5-21-123456789-123456789-123456789-513", "C:\\Users\\alice"),
requestedUsername: "alice", // Same user
featureSupportsUserSwitch: true,
allowRoot: false,
lookupUsers: map[string]*user.User{
"alice": createTestUser("alice", "S-1-5-21-123456789-123456789-123456789-1001", "S-1-5-21-123456789-123456789-123456789-513", "C:\\Users\\alice"),
},
expectedAllowed: true, // Regular user accessing themselves should work
expectedRequiresSwitch: false, // No switching for same user
},
{
name: "windows_regular_user_empty_username",
os: "windows",
euid: 1000,
currentUser: createTestUser("alice", "S-1-5-21-123456789-123456789-123456789-1001", "S-1-5-21-123456789-123456789-123456789-513", "C:\\Users\\alice"),
requestedUsername: "", // Empty = current user
featureSupportsUserSwitch: true,
allowRoot: false,
expectedAllowed: true, // Current user should always work
expectedRequiresSwitch: false, // No switching for current user
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Inject platform dependencies to test real logic
cleanup := setupTestDependencies(tt.currentUser, nil, tt.os, tt.euid, tt.lookupUsers, nil)
defer cleanup()
server := &Server{allowRootLogin: tt.allowRoot}
result := server.CheckPrivileges(PrivilegeCheckRequest{
RequestedUsername: tt.requestedUsername,
FeatureSupportsUserSwitch: tt.featureSupportsUserSwitch,
FeatureName: "SSH login",
})
assert.Equal(t, tt.expectedAllowed, result.Allowed)
assert.Equal(t, tt.expectedRequiresSwitch, result.RequiresUserSwitching)
})
}
}
func TestUsedFallback_MeansNoPrivilegeDropping(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("Fallback mechanism is Unix-specific")
}
// Create test scenario where fallback should occur
server := &Server{allowRootLogin: true}
// Mock dependencies to simulate non-privileged user
originalGetCurrentUser := getCurrentUser
originalGetIsProcessPrivileged := getIsProcessPrivileged
defer func() {
getCurrentUser = originalGetCurrentUser
getIsProcessPrivileged = originalGetIsProcessPrivileged
}()
// Set up mocks for fallback scenario
getCurrentUser = func() (*user.User, error) {
return createTestUser("netbird", "1000", "1000", "/var/lib/netbird"), nil
}
getIsProcessPrivileged = func() bool { return false } // Non-privileged
// Request different user - should fallback
result := server.CheckPrivileges(PrivilegeCheckRequest{
RequestedUsername: "alice",
FeatureSupportsUserSwitch: true,
FeatureName: "SSH login",
})
// Verify fallback occurred
assert.True(t, result.Allowed, "Should allow with fallback")
assert.True(t, result.UsedFallback, "Should indicate fallback was used")
assert.Equal(t, "netbird", result.User.Username, "Should return current user")
assert.False(t, result.RequiresUserSwitching, "Should not require switching when fallback is used")
// Key assertion: When UsedFallback is true, no privilege dropping should be needed
// because all privilege checks have already been performed and we're using current user
t.Logf("UsedFallback=true means: current user (%s) is the target, no privilege dropping needed",
result.User.Username)
}
func TestPrivilegedUsernameDetection(t *testing.T) {
tests := []struct {
name string
username string
platform string
privileged bool
}{
// Unix/Linux tests
{"unix_root", "root", "linux", true},
{"unix_regular_user", "alice", "linux", false},
{"unix_root_capital", "Root", "linux", false}, // Case-sensitive
// Windows tests
{"windows_administrator", "Administrator", "windows", true},
{"windows_system", "SYSTEM", "windows", true},
{"windows_admin", "admin", "windows", true},
{"windows_admin_lowercase", "administrator", "windows", true}, // Case-insensitive
{"windows_domain_admin", "DOMAIN\\Administrator", "windows", true},
{"windows_email_admin", "admin@domain.com", "windows", true},
{"windows_regular_user", "alice", "windows", false},
{"windows_domain_user", "DOMAIN\\alice", "windows", false},
{"windows_localsystem", "localsystem", "windows", true},
{"windows_networkservice", "networkservice", "windows", true},
{"windows_localservice", "localservice", "windows", true},
// Computer accounts (these depend on current user context in real implementation)
{"windows_computer_account", "WIN2K19-C2$", "windows", false}, // Computer account by itself not privileged
{"windows_domain_computer", "DOMAIN\\COMPUTER$", "windows", false}, // Domain computer account
// Cross-platform
{"root_on_windows", "root", "windows", true}, // Root should be privileged everywhere
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Mock the platform for this test
cleanup := setupTestDependencies(nil, nil, tt.platform, 1000, nil, nil)
defer cleanup()
result := isPrivilegedUsername(tt.username)
assert.Equal(t, tt.privileged, result)
})
}
}
func TestWindowsPrivilegedSIDDetection(t *testing.T) {
tests := []struct {
name string
sid string
privileged bool
description string
}{
// Well-known system accounts
{"system_account", "S-1-5-18", true, "Local System (SYSTEM)"},
{"local_service", "S-1-5-19", true, "Local Service"},
{"network_service", "S-1-5-20", true, "Network Service"},
{"administrators_group", "S-1-5-32-544", true, "Administrators group"},
{"builtin_administrator", "S-1-5-500", true, "Built-in Administrator"},
// Domain accounts
{"domain_administrator", "S-1-5-21-1234567890-1234567890-1234567890-500", true, "Domain Administrator (RID 500)"},
{"domain_admins_group", "S-1-5-21-1234567890-1234567890-1234567890-512", true, "Domain Admins group"},
{"domain_controllers_group", "S-1-5-21-1234567890-1234567890-1234567890-516", true, "Domain Controllers group"},
{"enterprise_admins_group", "S-1-5-21-1234567890-1234567890-1234567890-519", true, "Enterprise Admins group"},
// Regular users
{"regular_user", "S-1-5-21-1234567890-1234567890-1234567890-1001", false, "Regular domain user"},
{"another_regular_user", "S-1-5-21-1234567890-1234567890-1234567890-1234", false, "Another regular user"},
{"local_user", "S-1-5-21-1234567890-1234567890-1234567890-1000", false, "Local regular user"},
// Groups that are not privileged
{"domain_users", "S-1-5-21-1234567890-1234567890-1234567890-513", false, "Domain Users group"},
{"power_users", "S-1-5-32-547", false, "Power Users group"},
// Invalid SIDs
{"malformed_sid", "S-1-5-invalid", false, "Malformed SID"},
{"empty_sid", "", false, "Empty SID"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := isWindowsPrivilegedSID(tt.sid)
assert.Equal(t, tt.privileged, result, "Failed for %s: %s", tt.description, tt.sid)
})
}
}
func TestIsSameUser(t *testing.T) {
tests := []struct {
name string
user1 string
user2 string
os string
expected bool
}{
// Basic cases
{"same_username", "alice", "alice", "linux", true},
{"different_username", "alice", "bob", "linux", false},
// Linux (no domain processing)
{"linux_domain_vs_bare", "DOMAIN\\alice", "alice", "linux", false},
{"linux_email_vs_bare", "alice@domain.com", "alice", "linux", false},
{"linux_same_literal", "DOMAIN\\alice", "DOMAIN\\alice", "linux", true},
// Windows (with domain processing) - Note: parameter order is (requested, current, os, expected)
{"windows_domain_vs_bare", "alice", "DOMAIN\\alice", "windows", true}, // bare username matches domain current user
{"windows_email_vs_bare", "alice", "alice@domain.com", "windows", true}, // bare username matches email current user
{"windows_different_domains_same_user", "DOMAIN1\\alice", "DOMAIN2\\alice", "windows", false}, // SECURITY: different domains = different users
{"windows_case_insensitive", "Alice", "alice", "windows", true},
{"windows_different_users", "DOMAIN\\alice", "DOMAIN\\bob", "windows", false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Set up OS mock
cleanup := setupTestDependencies(nil, nil, tt.os, 1000, nil, nil)
defer cleanup()
result := isSameUser(tt.user1, tt.user2)
assert.Equal(t, tt.expected, result)
})
}
}
func TestUsernameValidation_Unix(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("Unix-specific username validation tests")
}
tests := []struct {
name string
username string
wantErr bool
errMsg string
}{
// Valid usernames (Unix/POSIX)
{"valid_alphanumeric", "user123", false, ""},
{"valid_with_dots", "user.name", false, ""},
{"valid_with_hyphens", "user-name", false, ""},
{"valid_with_underscores", "user_name", false, ""},
{"valid_uppercase", "UserName", false, ""},
{"valid_starting_with_digit", "123user", false, ""},
{"valid_starting_with_dot", ".hidden", false, ""},
// Invalid usernames (Unix/POSIX)
{"empty_username", "", true, "username cannot be empty"},
{"username_too_long", "thisusernameiswaytoolongandexceedsthe32characterlimit", true, "username too long"},
{"username_starting_with_hyphen", "-user", true, "invalid characters"}, // POSIX restriction
{"username_with_spaces", "user name", true, "invalid characters"},
{"username_with_shell_metacharacters", "user;rm", true, "invalid characters"},
{"username_with_command_injection", "user`rm -rf /`", true, "invalid characters"},
{"username_with_pipe", "user|rm", true, "invalid characters"},
{"username_with_ampersand", "user&rm", true, "invalid characters"},
{"username_with_quotes", "user\"name", true, "invalid characters"},
{"username_with_newline", "user\nname", true, "invalid characters"},
{"reserved_dot", ".", true, "cannot be '.' or '..'"},
{"reserved_dotdot", "..", true, "cannot be '.' or '..'"},
{"username_with_at_symbol", "user@domain", true, "invalid characters"}, // Not allowed in bare Unix usernames
{"username_with_backslash", "user\\name", true, "invalid characters"}, // Not allowed in Unix usernames
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := validateUsername(tt.username)
if tt.wantErr {
assert.Error(t, err, "Should reject invalid username")
if tt.errMsg != "" {
assert.Contains(t, err.Error(), tt.errMsg, "Error message should contain expected text")
}
} else {
assert.NoError(t, err, "Should accept valid username")
}
})
}
}
func TestUsernameValidation_Windows(t *testing.T) {
if runtime.GOOS != "windows" {
t.Skip("Windows-specific username validation tests")
}
tests := []struct {
name string
username string
wantErr bool
errMsg string
}{
// Valid usernames (Windows)
{"valid_alphanumeric", "user123", false, ""},
{"valid_with_dots", "user.name", false, ""},
{"valid_with_hyphens", "user-name", false, ""},
{"valid_with_underscores", "user_name", false, ""},
{"valid_uppercase", "UserName", false, ""},
{"valid_starting_with_digit", "123user", false, ""},
{"valid_starting_with_dot", ".hidden", false, ""},
{"valid_starting_with_hyphen", "-user", false, ""}, // Windows allows this
{"valid_domain_username", "DOMAIN\\user", false, ""}, // Windows domain format
{"valid_email_username", "user@domain.com", false, ""}, // Windows email format
{"valid_machine_username", "MACHINE\\user", false, ""}, // Windows machine format
// Invalid usernames (Windows)
{"empty_username", "", true, "username cannot be empty"},
{"username_too_long", "thisusernameiswaytoolongandexceedsthe32characterlimit", true, "username too long"},
{"username_with_spaces", "user name", true, "invalid characters"},
{"username_with_shell_metacharacters", "user;rm", true, "invalid characters"},
{"username_with_command_injection", "user`rm -rf /`", true, "invalid characters"},
{"username_with_pipe", "user|rm", true, "invalid characters"},
{"username_with_ampersand", "user&rm", true, "invalid characters"},
{"username_with_quotes", "user\"name", true, "invalid characters"},
{"username_with_newline", "user\nname", true, "invalid characters"},
{"username_with_brackets", "user[name]", true, "invalid characters"},
{"username_with_colon", "user:name", true, "invalid characters"},
{"username_with_semicolon", "user;name", true, "invalid characters"},
{"username_with_equals", "user=name", true, "invalid characters"},
{"username_with_comma", "user,name", true, "invalid characters"},
{"username_with_plus", "user+name", true, "invalid characters"},
{"username_with_asterisk", "user*name", true, "invalid characters"},
{"username_with_question", "user?name", true, "invalid characters"},
{"username_with_angles", "user<name>", true, "invalid characters"},
{"reserved_dot", ".", true, "cannot be '.' or '..'"},
{"reserved_dotdot", "..", true, "cannot be '.' or '..'"},
{"username_ending_with_period", "user.", true, "cannot end with a period"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := validateUsername(tt.username)
if tt.wantErr {
assert.Error(t, err, "Should reject invalid username")
if tt.errMsg != "" {
assert.Contains(t, err.Error(), tt.errMsg, "Error message should contain expected text")
}
} else {
assert.NoError(t, err, "Should accept valid username")
}
})
}
}
// Test real-world integration scenarios with actual platform capabilities
func TestCheckPrivileges_RealWorldScenarios(t *testing.T) {
tests := []struct {
name string
feature string
featureSupportsUserSwitch bool
requestedUsername string
allowRoot bool
expectedBehaviorPattern string
}{
{"SSH_login_current_user", "SSH login", true, "", true, "should_allow_current_user"},
{"SFTP_current_user", "SFTP", true, "", true, "should_allow_current_user"},
{"port_forwarding_current_user", "port forwarding", false, "", true, "should_allow_current_user"},
{"SSH_login_root_not_allowed", "SSH login", true, "root", false, "should_deny_root"},
{"port_forwarding_different_user", "port forwarding", false, "differentuser", true, "should_deny_switching"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Mock privileged environment to ensure consistent test behavior across environments
cleanup := setupTestDependencies(
createTestUser("root", "0", "0", "/root"), // Running as root
nil,
runtime.GOOS,
0, // euid 0 (root)
map[string]*user.User{
"root": createTestUser("root", "0", "0", "/root"),
"differentuser": createTestUser("differentuser", "1000", "1000", "/home/differentuser"),
},
nil,
)
defer cleanup()
server := &Server{allowRootLogin: tt.allowRoot}
result := server.CheckPrivileges(PrivilegeCheckRequest{
RequestedUsername: tt.requestedUsername,
FeatureSupportsUserSwitch: tt.featureSupportsUserSwitch,
FeatureName: tt.feature,
})
switch tt.expectedBehaviorPattern {
case "should_allow_current_user":
assert.True(t, result.Allowed, "Should allow current user access")
assert.False(t, result.RequiresUserSwitching, "Current user should not require switching")
case "should_deny_root":
assert.False(t, result.Allowed, "Should deny root when not allowed")
assert.Contains(t, result.Error.Error(), "root", "Should mention root in error")
case "should_deny_switching":
assert.False(t, result.Allowed, "Should deny when feature doesn't support switching")
assert.Contains(t, result.Error.Error(), "user switching not supported", "Should mention switching in error")
}
})
}
}
// Test with actual platform capabilities - no mocking
func TestCheckPrivileges_ActualPlatform(t *testing.T) {
// This test uses the REAL platform capabilities
server := &Server{allowRootLogin: true}
// Test current user access - should always work
result := server.CheckPrivileges(PrivilegeCheckRequest{
RequestedUsername: "", // Current user
FeatureSupportsUserSwitch: true,
FeatureName: "SSH login",
})
assert.True(t, result.Allowed, "Current user should always be allowed")
assert.False(t, result.RequiresUserSwitching, "Current user should not require switching")
assert.NotNil(t, result.User, "Should return current user")
// Test user switching capability based on actual platform
actualIsPrivileged := isCurrentProcessPrivileged() // REAL check
actualOS := runtime.GOOS // REAL check
t.Logf("Platform capabilities: OS=%s, isPrivileged=%v, supportsUserSwitching=%v",
actualOS, actualIsPrivileged, actualIsPrivileged)
// Test requesting different user
result = server.CheckPrivileges(PrivilegeCheckRequest{
RequestedUsername: "nonexistentuser",
FeatureSupportsUserSwitch: true,
FeatureName: "SSH login",
})
switch {
case actualOS == "windows":
// Windows supports user switching but should fail on nonexistent user
assert.False(t, result.Allowed, "Windows should deny nonexistent user")
assert.True(t, result.RequiresUserSwitching, "Should indicate switching is needed")
assert.Contains(t, result.Error.Error(), "not found",
"Should indicate user not found")
case !actualIsPrivileged:
// Non-privileged Unix processes should fallback to current user
assert.True(t, result.Allowed, "Non-privileged Unix process should fallback to current user")
assert.False(t, result.RequiresUserSwitching, "Fallback means no switching actually happens")
assert.True(t, result.UsedFallback, "Should indicate fallback was used")
assert.NotNil(t, result.User, "Should return current user")
default:
// Privileged Unix processes should attempt user lookup
assert.False(t, result.Allowed, "Should fail due to nonexistent user")
assert.True(t, result.RequiresUserSwitching, "Should indicate switching is needed")
assert.Contains(t, result.Error.Error(), "nonexistentuser",
"Should indicate user not found")
}
}
// Test platform detection logic with dependency injection
func TestPlatformLogic_DependencyInjection(t *testing.T) {
tests := []struct {
name string
os string
euid int
currentUser *user.User
expectedIsProcessPrivileged bool
expectedSupportsUserSwitching bool
}{
{
name: "linux_root_process",
os: "linux",
euid: 0,
currentUser: createTestUser("root", "0", "0", "/root"),
expectedIsProcessPrivileged: true,
expectedSupportsUserSwitching: true,
},
{
name: "linux_non_root_process",
os: "linux",
euid: 1000,
currentUser: createTestUser("alice", "1000", "1000", "/home/alice"),
expectedIsProcessPrivileged: false,
expectedSupportsUserSwitching: false,
},
{
name: "windows_admin_process",
os: "windows",
euid: 1000, // euid ignored on Windows
currentUser: createTestUser("Administrator", "S-1-5-21-123456789-123456789-123456789-500", "S-1-5-32-544", "C:\\Users\\Administrator"),
expectedIsProcessPrivileged: true,
expectedSupportsUserSwitching: true, // Windows supports user switching when privileged
},
{
name: "windows_regular_process",
os: "windows",
euid: 1000, // euid ignored on Windows
currentUser: createTestUser("alice", "1001", "1001", "C:\\Users\\alice"),
expectedIsProcessPrivileged: false,
expectedSupportsUserSwitching: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Inject platform dependencies and test REAL logic
cleanup := setupTestDependencies(tt.currentUser, nil, tt.os, tt.euid, nil, nil)
defer cleanup()
// Test the actual functions with injected dependencies
actualIsPrivileged := isCurrentProcessPrivileged()
actualSupportsUserSwitching := actualIsPrivileged
assert.Equal(t, tt.expectedIsProcessPrivileged, actualIsPrivileged,
"isCurrentProcessPrivileged() result mismatch")
assert.Equal(t, tt.expectedSupportsUserSwitching, actualSupportsUserSwitching,
"supportsUserSwitching() result mismatch")
t.Logf("Platform: %s, EUID: %d, User: %s", tt.os, tt.euid, tt.currentUser.Username)
t.Logf("Results: isPrivileged=%v, supportsUserSwitching=%v",
actualIsPrivileged, actualSupportsUserSwitching)
})
}
}
func TestCheckPrivileges_WindowsElevatedUserSwitching(t *testing.T) {
// Test Windows elevated user switching scenarios with simplified privilege logic
tests := []struct {
name string
currentUser *user.User
requestedUsername string
allowRoot bool
expectedAllowed bool
expectedErrorContains string
}{
{
name: "windows_admin_can_switch_to_alice",
currentUser: createTestUser("administrator", "S-1-5-21-123456789-123456789-123456789-500", "S-1-5-32-544", "C:\\\\Users\\\\Administrator"),
requestedUsername: "alice",
allowRoot: true,
expectedAllowed: true,
},
{
name: "windows_non_admin_can_try_switch",
currentUser: createTestUser("alice", "S-1-5-21-123456789-123456789-123456789-1001", "S-1-5-21-123456789-123456789-123456789-513", "C:\\\\Users\\\\alice"),
requestedUsername: "bob",
allowRoot: true,
expectedAllowed: true, // Privilege check allows it, OS will reject during execution
},
{
name: "windows_system_can_switch_to_alice",
currentUser: createTestUser("SYSTEM", "S-1-5-18", "S-1-5-18", "C:\\\\Windows\\\\system32\\\\config\\\\systemprofile"),
requestedUsername: "alice",
allowRoot: true,
expectedAllowed: true,
},
{
name: "windows_admin_root_not_allowed",
currentUser: createTestUser("administrator", "S-1-5-21-123456789-123456789-123456789-500", "S-1-5-32-544", "C:\\\\Users\\\\Administrator"),
requestedUsername: "root",
allowRoot: false,
expectedAllowed: false,
expectedErrorContains: "privileged user login is disabled",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Setup test dependencies with Windows OS and specified privileges
lookupUsers := map[string]*user.User{
tt.requestedUsername: createTestUser(tt.requestedUsername, "1002", "1002", "C:\\\\Users\\\\"+tt.requestedUsername),
}
cleanup := setupTestDependencies(tt.currentUser, nil, "windows", 1000, lookupUsers, nil)
defer cleanup()
server := &Server{allowRootLogin: tt.allowRoot}
result := server.CheckPrivileges(PrivilegeCheckRequest{
RequestedUsername: tt.requestedUsername,
FeatureSupportsUserSwitch: true,
FeatureName: "SSH login",
})
assert.Equal(t, tt.expectedAllowed, result.Allowed,
"Privilege check result should match expected for %s", tt.name)
if !tt.expectedAllowed && tt.expectedErrorContains != "" {
assert.NotNil(t, result.Error, "Should have error when not allowed")
assert.Contains(t, result.Error.Error(), tt.expectedErrorContains,
"Error should contain expected message")
}
if tt.expectedAllowed && tt.requestedUsername != "" && tt.currentUser.Username != tt.requestedUsername {
assert.True(t, result.RequiresUserSwitching, "Should require user switching for different user")
}
})
}
}

View File

@@ -0,0 +1,8 @@
//go:build js
package server
// enableUserSwitching is not supported on JS/WASM
func enableUserSwitching() error {
return errNotSupported
}

View File

@@ -0,0 +1,233 @@
//go:build unix
package server
import (
"errors"
"fmt"
"net"
"net/netip"
"os"
"os/exec"
"os/user"
"regexp"
"runtime"
"strconv"
"github.com/gliderlabs/ssh"
log "github.com/sirupsen/logrus"
)
// POSIX portable filename character set regex: [a-zA-Z0-9._-]
// First character cannot be hyphen (POSIX requirement)
var posixUsernameRegex = regexp.MustCompile(`^[a-zA-Z0-9._][a-zA-Z0-9._-]*$`)
// validateUsername validates that a username conforms to POSIX standards with security considerations
func validateUsername(username string) error {
if username == "" {
return errors.New("username cannot be empty")
}
// POSIX allows up to 256 characters, but practical limit is 32 for compatibility
if len(username) > 32 {
return errors.New("username too long (max 32 characters)")
}
if !posixUsernameRegex.MatchString(username) {
return errors.New("username contains invalid characters (must match POSIX portable filename character set)")
}
if username == "." || username == ".." {
return fmt.Errorf("username cannot be '.' or '..'")
}
// Warn if username is fully numeric (can cause issues with UID/username ambiguity)
if isFullyNumeric(username) {
log.Warnf("fully numeric username '%s' may cause issues with some commands", username)
}
return nil
}
// isFullyNumeric checks if username contains only digits
func isFullyNumeric(username string) bool {
for _, char := range username {
if char < '0' || char > '9' {
return false
}
}
return true
}
// createPtyLoginCommand creates a Pty command using login for privileged processes
func (s *Server) createPtyLoginCommand(localUser *user.User, ptyReq ssh.Pty, session ssh.Session) (*exec.Cmd, error) {
loginPath, args, err := s.getLoginCmd(localUser.Username, session.RemoteAddr())
if err != nil {
return nil, fmt.Errorf("get login command: %w", err)
}
execCmd := exec.CommandContext(session.Context(), loginPath, args...)
execCmd.Dir = localUser.HomeDir
execCmd.Env = s.preparePtyEnv(localUser, ptyReq, session)
return execCmd, nil
}
// getLoginCmd returns the login command and args for privileged Pty user switching
func (s *Server) getLoginCmd(username string, remoteAddr net.Addr) (string, []string, error) {
loginPath, err := exec.LookPath("login")
if err != nil {
return "", nil, fmt.Errorf("login command not available: %w", err)
}
addrPort, err := netip.ParseAddrPort(remoteAddr.String())
if err != nil {
return "", nil, fmt.Errorf("parse remote address: %w", err)
}
switch runtime.GOOS {
case "linux":
// Special handling for Arch Linux without /etc/pam.d/remote
if s.fileExists("/etc/arch-release") && !s.fileExists("/etc/pam.d/remote") {
return loginPath, []string{"-f", username, "-p"}, nil
}
return loginPath, []string{"-f", username, "-h", addrPort.Addr().String(), "-p"}, nil
case "darwin", "freebsd", "openbsd", "netbsd", "dragonfly":
return loginPath, []string{"-fp", "-h", addrPort.Addr().String(), username}, nil
default:
return "", nil, fmt.Errorf("unsupported Unix platform for login command: %s", runtime.GOOS)
}
}
// fileExists checks if a file exists (helper for login command logic)
func (s *Server) fileExists(path string) bool {
_, err := os.Stat(path)
return err == nil
}
// parseUserCredentials extracts numeric UID, GID, and supplementary groups
func (s *Server) parseUserCredentials(localUser *user.User) (uint32, uint32, []uint32, error) {
uid64, err := strconv.ParseUint(localUser.Uid, 10, 32)
if err != nil {
return 0, 0, nil, fmt.Errorf("invalid UID %s: %w", localUser.Uid, err)
}
uid := uint32(uid64)
gid64, err := strconv.ParseUint(localUser.Gid, 10, 32)
if err != nil {
return 0, 0, nil, fmt.Errorf("invalid GID %s: %w", localUser.Gid, err)
}
gid := uint32(gid64)
groups, err := s.getSupplementaryGroups(localUser.Username)
if err != nil {
log.Warnf("failed to get supplementary groups for user %s: %v", localUser.Username, err)
groups = []uint32{gid}
}
return uid, gid, groups, nil
}
// getSupplementaryGroups retrieves supplementary group IDs for a user
func (s *Server) getSupplementaryGroups(username string) ([]uint32, error) {
u, err := user.Lookup(username)
if err != nil {
return nil, fmt.Errorf("lookup user %s: %w", username, err)
}
groupIDStrings, err := u.GroupIds()
if err != nil {
return nil, fmt.Errorf("get group IDs for user %s: %w", username, err)
}
groups := make([]uint32, len(groupIDStrings))
for i, gidStr := range groupIDStrings {
gid64, err := strconv.ParseUint(gidStr, 10, 32)
if err != nil {
return nil, fmt.Errorf("invalid group ID %s for user %s: %w", gidStr, username, err)
}
groups[i] = uint32(gid64)
}
return groups, nil
}
// createExecutorCommand creates a command that spawns netbird ssh exec for privilege dropping.
// Returns the command and a cleanup function (no-op on Unix).
func (s *Server) createExecutorCommand(session ssh.Session, localUser *user.User, hasPty bool) (*exec.Cmd, func(), error) {
log.Debugf("creating executor command for user %s (Pty: %v)", localUser.Username, hasPty)
if err := validateUsername(localUser.Username); err != nil {
return nil, nil, fmt.Errorf("invalid username %q: %w", localUser.Username, err)
}
uid, gid, groups, err := s.parseUserCredentials(localUser)
if err != nil {
return nil, nil, fmt.Errorf("parse user credentials: %w", err)
}
privilegeDropper := NewPrivilegeDropper()
config := ExecutorConfig{
UID: uid,
GID: gid,
Groups: groups,
WorkingDir: localUser.HomeDir,
Shell: getUserShell(localUser.Uid),
Command: session.RawCommand(),
PTY: hasPty,
}
cmd, err := privilegeDropper.CreateExecutorCommand(session.Context(), config)
return cmd, func() {}, err
}
// enableUserSwitching is a no-op on Unix systems
func enableUserSwitching() error {
return nil
}
// createPtyCommand creates the exec.Cmd for Pty execution respecting privilege check results
func (s *Server) createPtyCommand(privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, session ssh.Session) (*exec.Cmd, error) {
localUser := privilegeResult.User
if localUser == nil {
return nil, errors.New("no user in privilege result")
}
if privilegeResult.UsedFallback {
return s.createDirectPtyCommand(session, localUser, ptyReq), nil
}
return s.createPtyLoginCommand(localUser, ptyReq, session)
}
// createDirectPtyCommand creates a direct Pty command without privilege dropping
func (s *Server) createDirectPtyCommand(session ssh.Session, localUser *user.User, ptyReq ssh.Pty) *exec.Cmd {
log.Debugf("creating direct Pty command for user %s (no user switching needed)", localUser.Username)
shell := getUserShell(localUser.Uid)
args := s.getShellCommandArgs(shell, session.RawCommand())
cmd := exec.CommandContext(session.Context(), args[0], args[1:]...)
cmd.Dir = localUser.HomeDir
cmd.Env = s.preparePtyEnv(localUser, ptyReq, session)
return cmd
}
// preparePtyEnv prepares environment variables for Pty execution
func (s *Server) preparePtyEnv(localUser *user.User, ptyReq ssh.Pty, session ssh.Session) []string {
termType := ptyReq.Term
if termType == "" {
termType = "xterm-256color"
}
env := prepareUserEnv(localUser, getUserShell(localUser.Uid))
env = append(env, prepareSSHEnv(session)...)
env = append(env, fmt.Sprintf("TERM=%s", termType))
for _, v := range session.Environ() {
if acceptEnv(v) {
env = append(env, v)
}
}
return env
}

View File

@@ -0,0 +1,274 @@
//go:build windows
package server
import (
"errors"
"fmt"
"os/exec"
"os/user"
"strings"
"unsafe"
"github.com/gliderlabs/ssh"
log "github.com/sirupsen/logrus"
"golang.org/x/sys/windows"
)
// validateUsername validates Windows usernames according to SAM Account Name rules
func validateUsername(username string) error {
if username == "" {
return fmt.Errorf("username cannot be empty")
}
usernameToValidate := extractUsernameFromDomain(username)
if err := validateUsernameLength(usernameToValidate); err != nil {
return err
}
if err := validateUsernameCharacters(usernameToValidate); err != nil {
return err
}
if err := validateUsernameFormat(usernameToValidate); err != nil {
return err
}
return nil
}
// extractUsernameFromDomain extracts the username part from domain\username or username@domain format
func extractUsernameFromDomain(username string) string {
if idx := strings.LastIndex(username, `\`); idx != -1 {
return username[idx+1:]
}
if idx := strings.Index(username, "@"); idx != -1 {
return username[:idx]
}
return username
}
// validateUsernameLength checks if username length is within Windows limits
func validateUsernameLength(username string) error {
if len(username) > 20 {
return fmt.Errorf("username too long (max 20 characters for Windows)")
}
return nil
}
// validateUsernameCharacters checks for invalid characters in Windows usernames
func validateUsernameCharacters(username string) error {
invalidChars := []rune{'"', '/', '[', ']', ':', ';', '|', '=', ',', '+', '*', '?', '<', '>', ' ', '`', '&', '\n'}
for _, char := range username {
for _, invalid := range invalidChars {
if char == invalid {
return fmt.Errorf("username contains invalid characters")
}
}
if char < 32 || char == 127 {
return fmt.Errorf("username contains control characters")
}
}
return nil
}
// validateUsernameFormat checks for invalid username formats and patterns
func validateUsernameFormat(username string) error {
if username == "." || username == ".." {
return fmt.Errorf("username cannot be '.' or '..'")
}
if strings.HasSuffix(username, ".") {
return fmt.Errorf("username cannot end with a period")
}
return nil
}
// createExecutorCommand creates a command using Windows executor for privilege dropping.
// Returns the command and a cleanup function that must be called after starting the process.
func (s *Server) createExecutorCommand(session ssh.Session, localUser *user.User, hasPty bool) (*exec.Cmd, func(), error) {
log.Debugf("creating Windows executor command for user %s (Pty: %v)", localUser.Username, hasPty)
username, _ := s.parseUsername(localUser.Username)
if err := validateUsername(username); err != nil {
return nil, nil, fmt.Errorf("invalid username %q: %w", username, err)
}
return s.createUserSwitchCommand(localUser, session, hasPty)
}
// createUserSwitchCommand creates a command with Windows user switching.
// Returns the command and a cleanup function that must be called after starting the process.
func (s *Server) createUserSwitchCommand(localUser *user.User, session ssh.Session, interactive bool) (*exec.Cmd, func(), error) {
username, domain := s.parseUsername(localUser.Username)
shell := getUserShell(localUser.Uid)
rawCmd := session.RawCommand()
var command string
if rawCmd != "" {
command = rawCmd
}
config := WindowsExecutorConfig{
Username: username,
Domain: domain,
WorkingDir: localUser.HomeDir,
Shell: shell,
Command: command,
Interactive: interactive || (rawCmd == ""),
}
dropper := NewPrivilegeDropper()
cmd, token, err := dropper.CreateWindowsExecutorCommand(session.Context(), config)
if err != nil {
return nil, nil, err
}
cleanup := func() {
if token != 0 {
if err := windows.CloseHandle(windows.Handle(token)); err != nil {
log.Debugf("close primary token: %v", err)
}
}
}
return cmd, cleanup, nil
}
// parseUsername extracts username and domain from a Windows username
func (s *Server) parseUsername(fullUsername string) (username, domain string) {
// Handle DOMAIN\username format
if idx := strings.LastIndex(fullUsername, `\`); idx != -1 {
domain = fullUsername[:idx]
username = fullUsername[idx+1:]
return username, domain
}
// Handle username@domain format
if username, domain, ok := strings.Cut(fullUsername, "@"); ok {
return username, domain
}
// Local user (no domain)
return fullUsername, "."
}
// hasPrivilege checks if the current process has a specific privilege
func hasPrivilege(token windows.Handle, privilegeName string) (bool, error) {
var luid windows.LUID
if err := windows.LookupPrivilegeValue(nil, windows.StringToUTF16Ptr(privilegeName), &luid); err != nil {
return false, fmt.Errorf("lookup privilege value: %w", err)
}
var returnLength uint32
err := windows.GetTokenInformation(
windows.Token(token),
windows.TokenPrivileges,
nil, // null buffer to get size
0,
&returnLength,
)
if err != nil && !errors.Is(err, windows.ERROR_INSUFFICIENT_BUFFER) {
return false, fmt.Errorf("get token information size: %w", err)
}
buffer := make([]byte, returnLength)
err = windows.GetTokenInformation(
windows.Token(token),
windows.TokenPrivileges,
&buffer[0],
returnLength,
&returnLength,
)
if err != nil {
return false, fmt.Errorf("get token information: %w", err)
}
privileges := (*windows.Tokenprivileges)(unsafe.Pointer(&buffer[0]))
// Check if the privilege is present and enabled
for i := uint32(0); i < privileges.PrivilegeCount; i++ {
privilege := (*windows.LUIDAndAttributes)(unsafe.Pointer(
uintptr(unsafe.Pointer(&privileges.Privileges[0])) +
uintptr(i)*unsafe.Sizeof(windows.LUIDAndAttributes{}),
))
if privilege.Luid == luid {
return (privilege.Attributes & windows.SE_PRIVILEGE_ENABLED) != 0, nil
}
}
return false, nil
}
// enablePrivilege enables a specific privilege for the current process token
// This is required because privileges like SeAssignPrimaryTokenPrivilege are present
// but disabled by default, even for the SYSTEM account
func enablePrivilege(token windows.Handle, privilegeName string) error {
var luid windows.LUID
if err := windows.LookupPrivilegeValue(nil, windows.StringToUTF16Ptr(privilegeName), &luid); err != nil {
return fmt.Errorf("lookup privilege value for %s: %w", privilegeName, err)
}
privileges := windows.Tokenprivileges{
PrivilegeCount: 1,
Privileges: [1]windows.LUIDAndAttributes{
{
Luid: luid,
Attributes: windows.SE_PRIVILEGE_ENABLED,
},
},
}
err := windows.AdjustTokenPrivileges(
windows.Token(token),
false,
&privileges,
0,
nil,
nil,
)
if err != nil {
return fmt.Errorf("adjust token privileges for %s: %w", privilegeName, err)
}
hasPriv, err := hasPrivilege(token, privilegeName)
if err != nil {
return fmt.Errorf("verify privilege %s after enabling: %w", privilegeName, err)
}
if !hasPriv {
return fmt.Errorf("privilege %s could not be enabled (may not be granted to account)", privilegeName)
}
log.Debugf("Successfully enabled privilege %s for current process", privilegeName)
return nil
}
// enableUserSwitching enables required privileges for Windows user switching
func enableUserSwitching() error {
process := windows.CurrentProcess()
var token windows.Token
err := windows.OpenProcessToken(
process,
windows.TOKEN_ADJUST_PRIVILEGES|windows.TOKEN_QUERY,
&token,
)
if err != nil {
return fmt.Errorf("open process token: %w", err)
}
defer func() {
if err := windows.CloseHandle(windows.Handle(token)); err != nil {
log.Debugf("Failed to close process token: %v", err)
}
}()
if err := enablePrivilege(windows.Handle(token), "SeAssignPrimaryTokenPrivilege"); err != nil {
return fmt.Errorf("enable SeAssignPrimaryTokenPrivilege: %w", err)
}
log.Infof("Windows user switching privileges enabled successfully")
return nil
}

View File

@@ -0,0 +1,487 @@
//go:build windows
package winpty
import (
"context"
"errors"
"fmt"
"io"
"strings"
"sync"
"syscall"
"unsafe"
"github.com/gliderlabs/ssh"
log "github.com/sirupsen/logrus"
"golang.org/x/sys/windows"
)
var (
ErrEmptyEnvironment = errors.New("empty environment")
)
const (
extendedStartupInfoPresent = 0x00080000
createUnicodeEnvironment = 0x00000400
procThreadAttributePseudoConsole = 0x00020016
PowerShellCommandFlag = "-Command"
errCloseInputRead = "close input read handle: %v"
errCloseConPtyCleanup = "close ConPty handle during cleanup"
)
// PtyConfig holds configuration for Pty execution.
type PtyConfig struct {
Shell string
Command string
Width int
Height int
WorkingDir string
}
// UserConfig holds user execution configuration.
type UserConfig struct {
Token windows.Handle
Environment []string
}
var (
kernel32 = windows.NewLazySystemDLL("kernel32.dll")
procClosePseudoConsole = kernel32.NewProc("ClosePseudoConsole")
procInitializeProcThreadAttributeList = kernel32.NewProc("InitializeProcThreadAttributeList")
procUpdateProcThreadAttribute = kernel32.NewProc("UpdateProcThreadAttribute")
procDeleteProcThreadAttributeList = kernel32.NewProc("DeleteProcThreadAttributeList")
)
// ExecutePtyWithUserToken executes a command with ConPty using user token.
func ExecutePtyWithUserToken(ctx context.Context, session ssh.Session, ptyConfig PtyConfig, userConfig UserConfig) error {
args := buildShellArgs(ptyConfig.Shell, ptyConfig.Command)
commandLine := buildCommandLine(args)
config := ExecutionConfig{
Pty: ptyConfig,
User: userConfig,
Session: session,
Context: ctx,
}
return executeConPtyWithConfig(commandLine, config)
}
// ExecutionConfig holds all execution configuration.
type ExecutionConfig struct {
Pty PtyConfig
User UserConfig
Session ssh.Session
Context context.Context
}
// executeConPtyWithConfig creates ConPty and executes process with configuration.
func executeConPtyWithConfig(commandLine string, config ExecutionConfig) error {
ctx := config.Context
session := config.Session
width := config.Pty.Width
height := config.Pty.Height
userToken := config.User.Token
userEnv := config.User.Environment
workingDir := config.Pty.WorkingDir
inputRead, inputWrite, outputRead, outputWrite, err := createConPtyPipes()
if err != nil {
return fmt.Errorf("create ConPty pipes: %w", err)
}
hPty, err := createConPty(width, height, inputRead, outputWrite)
if err != nil {
return fmt.Errorf("create ConPty: %w", err)
}
primaryToken, err := duplicateToPrimaryToken(userToken)
if err != nil {
if closeErr, _, _ := procClosePseudoConsole.Call(uintptr(hPty)); closeErr == 0 {
log.Debugf(errCloseConPtyCleanup)
}
closeHandles(inputRead, inputWrite, outputRead, outputWrite)
return fmt.Errorf("duplicate to primary token: %w", err)
}
defer func() {
if err := windows.CloseHandle(primaryToken); err != nil {
log.Debugf("close primary token: %v", err)
}
}()
siEx, err := setupConPtyStartupInfo(hPty)
if err != nil {
if closeErr, _, _ := procClosePseudoConsole.Call(uintptr(hPty)); closeErr == 0 {
log.Debugf(errCloseConPtyCleanup)
}
closeHandles(inputRead, inputWrite, outputRead, outputWrite)
return fmt.Errorf("setup startup info: %w", err)
}
defer func() {
_, _, _ = procDeleteProcThreadAttributeList.Call(uintptr(unsafe.Pointer(siEx.ProcThreadAttributeList)))
}()
pi, err := createConPtyProcess(commandLine, primaryToken, userEnv, workingDir, siEx)
if err != nil {
if closeErr, _, _ := procClosePseudoConsole.Call(uintptr(hPty)); closeErr == 0 {
log.Debugf(errCloseConPtyCleanup)
}
closeHandles(inputRead, inputWrite, outputRead, outputWrite)
return fmt.Errorf("create process as user with ConPty: %w", err)
}
defer closeProcessInfo(pi)
if err := windows.CloseHandle(inputRead); err != nil {
log.Debugf(errCloseInputRead, err)
}
if err := windows.CloseHandle(outputWrite); err != nil {
log.Debugf("close output write handle: %v", err)
}
return bridgeConPtyIO(ctx, hPty, inputWrite, outputRead, session, session, session, pi.Process)
}
// createConPtyPipes creates input/output pipes for ConPty.
func createConPtyPipes() (inputRead, inputWrite, outputRead, outputWrite windows.Handle, err error) {
if err := windows.CreatePipe(&inputRead, &inputWrite, nil, 0); err != nil {
return 0, 0, 0, 0, fmt.Errorf("create input pipe: %w", err)
}
if err := windows.CreatePipe(&outputRead, &outputWrite, nil, 0); err != nil {
if closeErr := windows.CloseHandle(inputRead); closeErr != nil {
log.Debugf(errCloseInputRead, closeErr)
}
if closeErr := windows.CloseHandle(inputWrite); closeErr != nil {
log.Debugf("close input write handle: %v", closeErr)
}
return 0, 0, 0, 0, fmt.Errorf("create output pipe: %w", err)
}
return inputRead, inputWrite, outputRead, outputWrite, nil
}
// createConPty creates a Windows ConPty with the specified size and pipe handles.
func createConPty(width, height int, inputRead, outputWrite windows.Handle) (windows.Handle, error) {
size := windows.Coord{X: int16(width), Y: int16(height)}
var hPty windows.Handle
if err := windows.CreatePseudoConsole(size, inputRead, outputWrite, 0, &hPty); err != nil {
return 0, fmt.Errorf("CreatePseudoConsole: %w", err)
}
return hPty, nil
}
// setupConPtyStartupInfo prepares the STARTUPINFOEX with ConPty attributes.
func setupConPtyStartupInfo(hPty windows.Handle) (*windows.StartupInfoEx, error) {
var siEx windows.StartupInfoEx
siEx.StartupInfo.Cb = uint32(unsafe.Sizeof(siEx))
var attrListSize uintptr
ret, _, _ := procInitializeProcThreadAttributeList.Call(0, 1, 0, uintptr(unsafe.Pointer(&attrListSize)))
if ret == 0 && attrListSize == 0 {
return nil, fmt.Errorf("get attribute list size")
}
attrListBytes := make([]byte, attrListSize)
siEx.ProcThreadAttributeList = (*windows.ProcThreadAttributeList)(unsafe.Pointer(&attrListBytes[0]))
ret, _, err := procInitializeProcThreadAttributeList.Call(
uintptr(unsafe.Pointer(siEx.ProcThreadAttributeList)),
1,
0,
uintptr(unsafe.Pointer(&attrListSize)),
)
if ret == 0 {
return nil, fmt.Errorf("initialize attribute list: %w", err)
}
ret, _, err = procUpdateProcThreadAttribute.Call(
uintptr(unsafe.Pointer(siEx.ProcThreadAttributeList)),
0,
procThreadAttributePseudoConsole,
uintptr(hPty),
unsafe.Sizeof(hPty),
0,
0,
)
if ret == 0 {
return nil, fmt.Errorf("update thread attribute: %w", err)
}
return &siEx, nil
}
// createConPtyProcess creates the actual process with ConPty.
func createConPtyProcess(commandLine string, userToken windows.Handle, userEnv []string, workingDir string, siEx *windows.StartupInfoEx) (*windows.ProcessInformation, error) {
var pi windows.ProcessInformation
creationFlags := uint32(extendedStartupInfoPresent | createUnicodeEnvironment)
commandLinePtr, err := windows.UTF16PtrFromString(commandLine)
if err != nil {
return nil, fmt.Errorf("convert command line to UTF16: %w", err)
}
envPtr, err := convertEnvironmentToUTF16(userEnv)
if err != nil {
return nil, err
}
var workingDirPtr *uint16
if workingDir != "" {
workingDirPtr, err = windows.UTF16PtrFromString(workingDir)
if err != nil {
return nil, fmt.Errorf("convert working directory to UTF16: %w", err)
}
}
siEx.StartupInfo.Flags |= windows.STARTF_USESTDHANDLES
siEx.StartupInfo.StdInput = windows.Handle(0)
siEx.StartupInfo.StdOutput = windows.Handle(0)
siEx.StartupInfo.StdErr = siEx.StartupInfo.StdOutput
if userToken != windows.InvalidHandle {
err = windows.CreateProcessAsUser(
windows.Token(userToken),
nil,
commandLinePtr,
nil,
nil,
true,
creationFlags,
envPtr,
workingDirPtr,
&siEx.StartupInfo,
&pi,
)
} else {
err = windows.CreateProcess(
nil,
commandLinePtr,
nil,
nil,
true,
creationFlags,
envPtr,
workingDirPtr,
&siEx.StartupInfo,
&pi,
)
}
if err != nil {
return nil, fmt.Errorf("create process: %w", err)
}
return &pi, nil
}
// convertEnvironmentToUTF16 converts environment variables to Windows UTF16 format.
func convertEnvironmentToUTF16(userEnv []string) (*uint16, error) {
if len(userEnv) == 0 {
// Return nil pointer for empty environment - Windows API will inherit parent environment
return nil, nil //nolint:nilnil // Intentional nil,nil for empty environment
}
var envUTF16 []uint16
for _, envVar := range userEnv {
if envVar != "" {
utf16Str, err := windows.UTF16FromString(envVar)
if err != nil {
log.Debugf("skipping invalid environment variable: %s (error: %v)", envVar, err)
continue
}
envUTF16 = append(envUTF16, utf16Str[:len(utf16Str)-1]...)
envUTF16 = append(envUTF16, 0)
}
}
envUTF16 = append(envUTF16, 0)
if len(envUTF16) > 0 {
return &envUTF16[0], nil
}
// Return nil pointer when no valid environment variables found
return nil, nil //nolint:nilnil // Intentional nil,nil for empty environment
}
// duplicateToPrimaryToken converts an impersonation token to a primary token.
func duplicateToPrimaryToken(token windows.Handle) (windows.Handle, error) {
var primaryToken windows.Handle
if err := windows.DuplicateTokenEx(
windows.Token(token),
windows.TOKEN_ALL_ACCESS,
nil,
windows.SecurityImpersonation,
windows.TokenPrimary,
(*windows.Token)(&primaryToken),
); err != nil {
return 0, fmt.Errorf("duplicate token: %w", err)
}
return primaryToken, nil
}
// SessionExiter provides the Exit method for reporting process exit status.
type SessionExiter interface {
Exit(code int) error
}
// bridgeConPtyIO handles I/O bridging between ConPty and readers/writers.
func bridgeConPtyIO(ctx context.Context, hPty, inputWrite, outputRead windows.Handle, reader io.ReadCloser, writer io.Writer, session SessionExiter, process windows.Handle) error {
if err := ctx.Err(); err != nil {
return err
}
var wg sync.WaitGroup
startIOBridging(ctx, &wg, inputWrite, outputRead, reader, writer)
processErr := waitForProcess(ctx, process)
if processErr != nil {
return processErr
}
var exitCode uint32
if err := windows.GetExitCodeProcess(process, &exitCode); err != nil {
log.Debugf("get exit code: %v", err)
} else {
if err := session.Exit(int(exitCode)); err != nil {
log.Debugf("report exit code: %v", err)
}
}
// Clean up in the original order after process completes
if err := reader.Close(); err != nil {
log.Debugf("close reader: %v", err)
}
ret, _, err := procClosePseudoConsole.Call(uintptr(hPty))
if ret == 0 {
log.Debugf("close ConPty handle: %v", err)
}
wg.Wait()
if err := windows.CloseHandle(outputRead); err != nil {
log.Debugf("close output read handle: %v", err)
}
return nil
}
// startIOBridging starts the I/O bridging goroutines.
func startIOBridging(ctx context.Context, wg *sync.WaitGroup, inputWrite, outputRead windows.Handle, reader io.ReadCloser, writer io.Writer) {
wg.Add(2)
// Input: reader (SSH session) -> inputWrite (ConPty)
go func() {
defer wg.Done()
defer func() {
if err := windows.CloseHandle(inputWrite); err != nil {
log.Debugf("close input write handle in goroutine: %v", err)
}
}()
if _, err := io.Copy(&windowsHandleWriter{handle: inputWrite}, reader); err != nil {
log.Debugf("input copy ended with error: %v", err)
}
}()
// Output: outputRead (ConPty) -> writer (SSH session)
go func() {
defer wg.Done()
if _, err := io.Copy(writer, &windowsHandleReader{handle: outputRead}); err != nil {
log.Debugf("output copy ended with error: %v", err)
}
}()
}
// waitForProcess waits for process completion with context cancellation.
func waitForProcess(ctx context.Context, process windows.Handle) error {
if _, err := windows.WaitForSingleObject(process, windows.INFINITE); err != nil {
return fmt.Errorf("wait for process %d: %w", process, err)
}
return nil
}
// buildShellArgs builds shell arguments for ConPty execution.
func buildShellArgs(shell, command string) []string {
if command != "" {
return []string{shell, PowerShellCommandFlag, command}
}
return []string{shell}
}
// buildCommandLine builds a Windows command line from arguments using proper escaping.
func buildCommandLine(args []string) string {
if len(args) == 0 {
return ""
}
var result strings.Builder
for i, arg := range args {
if i > 0 {
result.WriteString(" ")
}
result.WriteString(syscall.EscapeArg(arg))
}
return result.String()
}
// closeHandles closes multiple Windows handles.
func closeHandles(handles ...windows.Handle) {
for _, handle := range handles {
if handle != windows.InvalidHandle {
if err := windows.CloseHandle(handle); err != nil {
log.Debugf("close handle: %v", err)
}
}
}
}
// closeProcessInfo closes process and thread handles.
func closeProcessInfo(pi *windows.ProcessInformation) {
if pi != nil {
if err := windows.CloseHandle(pi.Process); err != nil {
log.Debugf("close process handle: %v", err)
}
if err := windows.CloseHandle(pi.Thread); err != nil {
log.Debugf("close thread handle: %v", err)
}
}
}
// windowsHandleReader wraps a Windows handle for reading.
type windowsHandleReader struct {
handle windows.Handle
}
func (r *windowsHandleReader) Read(p []byte) (n int, err error) {
var bytesRead uint32
if err := windows.ReadFile(r.handle, p, &bytesRead, nil); err != nil {
return 0, err
}
return int(bytesRead), nil
}
func (r *windowsHandleReader) Close() error {
return windows.CloseHandle(r.handle)
}
// windowsHandleWriter wraps a Windows handle for writing.
type windowsHandleWriter struct {
handle windows.Handle
}
func (w *windowsHandleWriter) Write(p []byte) (n int, err error) {
var bytesWritten uint32
if err := windows.WriteFile(w.handle, p, &bytesWritten, nil); err != nil {
return 0, err
}
return int(bytesWritten), nil
}
func (w *windowsHandleWriter) Close() error {
return windows.CloseHandle(w.handle)
}

View File

@@ -0,0 +1,290 @@
//go:build windows
package winpty
import (
"testing"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/sys/windows"
)
func TestBuildShellArgs(t *testing.T) {
tests := []struct {
name string
shell string
command string
expected []string
}{
{
name: "Shell with command",
shell: "powershell.exe",
command: "Get-Process",
expected: []string{"powershell.exe", "-Command", "Get-Process"},
},
{
name: "CMD with command",
shell: "cmd.exe",
command: "dir",
expected: []string{"cmd.exe", "-Command", "dir"},
},
{
name: "Shell interactive",
shell: "powershell.exe",
command: "",
expected: []string{"powershell.exe"},
},
{
name: "CMD interactive",
shell: "cmd.exe",
command: "",
expected: []string{"cmd.exe"},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := buildShellArgs(tt.shell, tt.command)
assert.Equal(t, tt.expected, result)
})
}
}
func TestBuildCommandLine(t *testing.T) {
tests := []struct {
name string
args []string
expected string
}{
{
name: "Simple args",
args: []string{"cmd.exe", "/c", "echo"},
expected: "cmd.exe /c echo",
},
{
name: "Args with spaces",
args: []string{"Program Files\\app.exe", "arg with spaces"},
expected: `"Program Files\app.exe" "arg with spaces"`,
},
{
name: "Args with quotes",
args: []string{"cmd.exe", "/c", `echo "hello world"`},
expected: `cmd.exe /c "echo \"hello world\""`,
},
{
name: "PowerShell calling PowerShell",
args: []string{"powershell.exe", "-Command", `powershell.exe -Command "Get-Process | Where-Object {$_.Name -eq 'notepad'}"`},
expected: `powershell.exe -Command "powershell.exe -Command \"Get-Process | Where-Object {$_.Name -eq 'notepad'}\""`,
},
{
name: "Complex nested quotes",
args: []string{"cmd.exe", "/c", `echo "He said \"Hello\" to me"`},
expected: `cmd.exe /c "echo \"He said \\\"Hello\\\" to me\""`,
},
{
name: "Path with spaces and args",
args: []string{`C:\Program Files\MyApp\app.exe`, "--config", `C:\My Config\settings.json`},
expected: `"C:\Program Files\MyApp\app.exe" --config "C:\My Config\settings.json"`,
},
{
name: "Empty argument",
args: []string{"cmd.exe", "/c", "echo", ""},
expected: `cmd.exe /c echo ""`,
},
{
name: "Argument with backslashes",
args: []string{"robocopy", `C:\Source\`, `C:\Dest\`, "/E"},
expected: `robocopy C:\Source\ C:\Dest\ /E`,
},
{
name: "Empty args",
args: []string{},
expected: "",
},
{
name: "Single arg with space",
args: []string{"path with spaces"},
expected: `"path with spaces"`,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := buildCommandLine(tt.args)
assert.Equal(t, tt.expected, result)
})
}
}
func TestCreateConPtyPipes(t *testing.T) {
inputRead, inputWrite, outputRead, outputWrite, err := createConPtyPipes()
require.NoError(t, err, "Should create ConPty pipes successfully")
// Verify all handles are valid
assert.NotEqual(t, windows.InvalidHandle, inputRead, "Input read handle should be valid")
assert.NotEqual(t, windows.InvalidHandle, inputWrite, "Input write handle should be valid")
assert.NotEqual(t, windows.InvalidHandle, outputRead, "Output read handle should be valid")
assert.NotEqual(t, windows.InvalidHandle, outputWrite, "Output write handle should be valid")
// Clean up handles
closeHandles(inputRead, inputWrite, outputRead, outputWrite)
}
func TestCreateConPty(t *testing.T) {
inputRead, inputWrite, outputRead, outputWrite, err := createConPtyPipes()
require.NoError(t, err, "Should create ConPty pipes successfully")
defer closeHandles(inputRead, inputWrite, outputRead, outputWrite)
hPty, err := createConPty(80, 24, inputRead, outputWrite)
require.NoError(t, err, "Should create ConPty successfully")
assert.NotEqual(t, windows.InvalidHandle, hPty, "ConPty handle should be valid")
// Clean up ConPty
ret, _, _ := procClosePseudoConsole.Call(uintptr(hPty))
assert.NotEqual(t, uintptr(0), ret, "Should close ConPty successfully")
}
func TestConvertEnvironmentToUTF16(t *testing.T) {
tests := []struct {
name string
userEnv []string
hasError bool
}{
{
name: "Valid environment variables",
userEnv: []string{"PATH=C:\\Windows", "USER=testuser", "HOME=C:\\Users\\testuser"},
hasError: false,
},
{
name: "Empty environment",
userEnv: []string{},
hasError: false,
},
{
name: "Environment with empty strings",
userEnv: []string{"PATH=C:\\Windows", "", "USER=testuser"},
hasError: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result, err := convertEnvironmentToUTF16(tt.userEnv)
if tt.hasError {
assert.Error(t, err)
} else {
assert.NoError(t, err)
if len(tt.userEnv) == 0 {
assert.Nil(t, result, "Empty environment should return nil")
} else {
assert.NotNil(t, result, "Non-empty environment should return valid pointer")
}
}
})
}
}
func TestDuplicateToPrimaryToken(t *testing.T) {
if testing.Short() {
t.Skip("Skipping token tests in short mode")
}
// Get current process token for testing
var token windows.Token
err := windows.OpenProcessToken(windows.CurrentProcess(), windows.TOKEN_ALL_ACCESS, &token)
require.NoError(t, err, "Should open current process token")
defer func() {
if err := windows.CloseHandle(windows.Handle(token)); err != nil {
t.Logf("Failed to close token: %v", err)
}
}()
primaryToken, err := duplicateToPrimaryToken(windows.Handle(token))
require.NoError(t, err, "Should duplicate token to primary")
assert.NotEqual(t, windows.InvalidHandle, primaryToken, "Primary token should be valid")
// Clean up
err = windows.CloseHandle(primaryToken)
assert.NoError(t, err, "Should close primary token")
}
func TestWindowsHandleReader(t *testing.T) {
// Create a pipe for testing
var readHandle, writeHandle windows.Handle
err := windows.CreatePipe(&readHandle, &writeHandle, nil, 0)
require.NoError(t, err, "Should create pipe for testing")
defer closeHandles(readHandle, writeHandle)
// Write test data
testData := []byte("Hello, Windows Handle Reader!")
var bytesWritten uint32
err = windows.WriteFile(writeHandle, testData, &bytesWritten, nil)
require.NoError(t, err, "Should write test data")
require.Equal(t, uint32(len(testData)), bytesWritten, "Should write all test data")
// Close write handle to signal EOF
if err := windows.CloseHandle(writeHandle); err != nil {
t.Fatalf("Should close write handle: %v", err)
}
writeHandle = windows.InvalidHandle
// Test reading
reader := &windowsHandleReader{handle: readHandle}
buffer := make([]byte, len(testData))
n, err := reader.Read(buffer)
require.NoError(t, err, "Should read from handle")
assert.Equal(t, len(testData), n, "Should read expected number of bytes")
assert.Equal(t, testData, buffer, "Should read expected data")
}
func TestWindowsHandleWriter(t *testing.T) {
// Create a pipe for testing
var readHandle, writeHandle windows.Handle
err := windows.CreatePipe(&readHandle, &writeHandle, nil, 0)
require.NoError(t, err, "Should create pipe for testing")
defer closeHandles(readHandle, writeHandle)
// Test writing
testData := []byte("Hello, Windows Handle Writer!")
writer := &windowsHandleWriter{handle: writeHandle}
n, err := writer.Write(testData)
require.NoError(t, err, "Should write to handle")
assert.Equal(t, len(testData), n, "Should write expected number of bytes")
// Close write handle
if err := windows.CloseHandle(writeHandle); err != nil {
t.Fatalf("Should close write handle: %v", err)
}
// Verify data was written by reading it back
buffer := make([]byte, len(testData))
var bytesRead uint32
err = windows.ReadFile(readHandle, buffer, &bytesRead, nil)
require.NoError(t, err, "Should read back written data")
assert.Equal(t, uint32(len(testData)), bytesRead, "Should read back expected number of bytes")
assert.Equal(t, testData, buffer, "Should read back expected data")
}
// BenchmarkConPtyCreation benchmarks ConPty creation performance
func BenchmarkConPtyCreation(b *testing.B) {
for i := 0; i < b.N; i++ {
inputRead, inputWrite, outputRead, outputWrite, err := createConPtyPipes()
if err != nil {
b.Fatal(err)
}
hPty, err := createConPty(80, 24, inputRead, outputWrite)
if err != nil {
closeHandles(inputRead, inputWrite, outputRead, outputWrite)
b.Fatal(err)
}
// Clean up
if ret, _, err := procClosePseudoConsole.Call(uintptr(hPty)); ret == 0 {
log.Debugf("ClosePseudoConsole failed: %v", err)
}
closeHandles(inputRead, inputWrite, outputRead, outputWrite)
}
}

View File

@@ -1,46 +0,0 @@
//go:build !js
package ssh
import "context"
// MockServer mocks ssh.Server
type MockServer struct {
Ctx context.Context
StopFunc func() error
StartFunc func() error
AddAuthorizedKeyFunc func(peer, newKey string) error
RemoveAuthorizedKeyFunc func(peer string)
}
// RemoveAuthorizedKey removes SSH key of a given peer from the authorized keys
func (srv *MockServer) RemoveAuthorizedKey(peer string) {
if srv.RemoveAuthorizedKeyFunc == nil {
return
}
srv.RemoveAuthorizedKeyFunc(peer)
}
// AddAuthorizedKey add a given peer key to server authorized keys
func (srv *MockServer) AddAuthorizedKey(peer, newKey string) error {
if srv.AddAuthorizedKeyFunc == nil {
return nil
}
return srv.AddAuthorizedKeyFunc(peer, newKey)
}
// Stop stops SSH server.
func (srv *MockServer) Stop() error {
if srv.StopFunc == nil {
return nil
}
return srv.StopFunc()
}
// Start starts SSH server. Blocking
func (srv *MockServer) Start() error {
if srv.StartFunc == nil {
return nil
}
return srv.StartFunc()
}

View File

@@ -1,123 +0,0 @@
//go:build !js
package ssh
import (
"fmt"
"github.com/stretchr/testify/assert"
"golang.org/x/crypto/ssh"
"strings"
"testing"
)
func TestServer_AddAuthorizedKey(t *testing.T) {
key, err := GeneratePrivateKey(ED25519)
if err != nil {
t.Fatal(err)
}
server, err := newDefaultServer(key, "localhost:")
if err != nil {
t.Fatal(err)
}
// add multiple keys
keys := map[string][]byte{}
for i := 0; i < 10; i++ {
peer := fmt.Sprintf("%s-%d", "remotePeer", i)
remotePrivKey, err := GeneratePrivateKey(ED25519)
if err != nil {
t.Fatal(err)
}
remotePubKey, err := GeneratePublicKey(remotePrivKey)
if err != nil {
t.Fatal(err)
}
err = server.AddAuthorizedKey(peer, string(remotePubKey))
if err != nil {
t.Error(err)
}
keys[peer] = remotePubKey
}
// make sure that all keys have been added
for peer, remotePubKey := range keys {
k, ok := server.authorizedKeys[peer]
assert.True(t, ok, "expecting remotePeer key to be found in authorizedKeys")
assert.Equal(t, string(remotePubKey), strings.TrimSpace(string(ssh.MarshalAuthorizedKey(k))))
}
}
func TestServer_RemoveAuthorizedKey(t *testing.T) {
key, err := GeneratePrivateKey(ED25519)
if err != nil {
t.Fatal(err)
}
server, err := newDefaultServer(key, "localhost:")
if err != nil {
t.Fatal(err)
}
remotePrivKey, err := GeneratePrivateKey(ED25519)
if err != nil {
t.Fatal(err)
}
remotePubKey, err := GeneratePublicKey(remotePrivKey)
if err != nil {
t.Fatal(err)
}
err = server.AddAuthorizedKey("remotePeer", string(remotePubKey))
if err != nil {
t.Error(err)
}
server.RemoveAuthorizedKey("remotePeer")
_, ok := server.authorizedKeys["remotePeer"]
assert.False(t, ok, "expecting remotePeer's SSH key to be removed")
}
func TestServer_PubKeyHandler(t *testing.T) {
key, err := GeneratePrivateKey(ED25519)
if err != nil {
t.Fatal(err)
}
server, err := newDefaultServer(key, "localhost:")
if err != nil {
t.Fatal(err)
}
var keys []ssh.PublicKey
for i := 0; i < 10; i++ {
peer := fmt.Sprintf("%s-%d", "remotePeer", i)
remotePrivKey, err := GeneratePrivateKey(ED25519)
if err != nil {
t.Fatal(err)
}
remotePubKey, err := GeneratePublicKey(remotePrivKey)
if err != nil {
t.Fatal(err)
}
remoteParsedPubKey, _, _, _, err := ssh.ParseAuthorizedKey(remotePubKey)
if err != nil {
t.Fatal(err)
}
err = server.AddAuthorizedKey(peer, string(remotePubKey))
if err != nil {
t.Error(err)
}
keys = append(keys, remoteParsedPubKey)
}
for _, key := range keys {
accepted := server.publicKeyHandler(nil, key)
assert.Truef(t, accepted, "expecting SSH connection to be accepted for a given SSH key %s", string(ssh.MarshalAuthorizedKey(key)))
}
}

View File

@@ -32,9 +32,8 @@ const RSA KeyType = "rsa"
// RSAKeySize is a size of newly generated RSA key
const RSAKeySize = 2048
// GeneratePrivateKey creates RSA Private Key of specified byte size
// GeneratePrivateKey creates a private key of the specified type.
func GeneratePrivateKey(keyType KeyType) ([]byte, error) {
var key crypto.Signer
var err error
switch keyType {
@@ -59,7 +58,7 @@ func GeneratePrivateKey(keyType KeyType) ([]byte, error) {
return pemBytes, nil
}
// GeneratePublicKey returns the public part of the private key
// GeneratePublicKey returns the public part of the private key.
func GeneratePublicKey(key []byte) ([]byte, error) {
signer, err := gossh.ParsePrivateKey(key)
if err != nil {
@@ -70,20 +69,17 @@ func GeneratePublicKey(key []byte) ([]byte, error) {
return []byte(strKey), nil
}
// EncodePrivateKeyToPEM encodes Private Key from RSA to PEM format
// EncodePrivateKeyToPEM encodes a private key to PEM format.
func EncodePrivateKeyToPEM(privateKey crypto.Signer) ([]byte, error) {
mk, err := x509.MarshalPKCS8PrivateKey(privateKey)
if err != nil {
return nil, err
}
// pem.Block
privBlock := pem.Block{
Type: "PRIVATE KEY",
Bytes: mk,
}
// Private key in PEM format
privatePEM := pem.EncodeToMemory(&privBlock)
return privatePEM, nil
}

View File

@@ -0,0 +1,172 @@
package testutil
import (
"fmt"
"log"
"os"
"os/exec"
"os/user"
"runtime"
"strings"
"testing"
"github.com/stretchr/testify/require"
)
var testCreatedUsers = make(map[string]bool)
var testUsersToCleanup []string
// GetTestUsername returns an appropriate username for testing
func GetTestUsername(t *testing.T) string {
if runtime.GOOS == "windows" {
currentUser, err := user.Current()
require.NoError(t, err, "Should be able to get current user")
if IsSystemAccount(currentUser.Username) {
if IsCI() {
if testUser := GetOrCreateTestUser(t); testUser != "" {
return testUser
}
} else {
if _, err := user.Lookup("Administrator"); err == nil {
return "Administrator"
}
if testUser := GetOrCreateTestUser(t); testUser != "" {
return testUser
}
}
}
return currentUser.Username
}
currentUser, err := user.Current()
require.NoError(t, err, "Should be able to get current user")
return currentUser.Username
}
// IsCI checks if we're running in a CI environment
func IsCI() bool {
if os.Getenv("GITHUB_ACTIONS") == "true" || os.Getenv("CI") == "true" {
return true
}
hostname, err := os.Hostname()
if err == nil && strings.HasPrefix(hostname, "runner") {
return true
}
return false
}
// IsSystemAccount checks if the user is a system account that can't authenticate
func IsSystemAccount(username string) bool {
systemAccounts := []string{
"system",
"NT AUTHORITY\\SYSTEM",
"NT AUTHORITY\\LOCAL SERVICE",
"NT AUTHORITY\\NETWORK SERVICE",
}
for _, sysAccount := range systemAccounts {
if strings.EqualFold(username, sysAccount) {
return true
}
}
return false
}
// RegisterTestUserCleanup registers a test user for cleanup
func RegisterTestUserCleanup(username string) {
if !testCreatedUsers[username] {
testCreatedUsers[username] = true
testUsersToCleanup = append(testUsersToCleanup, username)
}
}
// CleanupTestUsers removes all created test users
func CleanupTestUsers() {
for _, username := range testUsersToCleanup {
RemoveWindowsTestUser(username)
}
testUsersToCleanup = nil
testCreatedUsers = make(map[string]bool)
}
// GetOrCreateTestUser creates a test user on Windows if needed
func GetOrCreateTestUser(t *testing.T) string {
testUsername := "netbird-test-user"
if _, err := user.Lookup(testUsername); err == nil {
return testUsername
}
if CreateWindowsTestUser(t, testUsername) {
RegisterTestUserCleanup(testUsername)
return testUsername
}
return ""
}
// RemoveWindowsTestUser removes a local user on Windows using PowerShell
func RemoveWindowsTestUser(username string) {
if runtime.GOOS != "windows" {
return
}
psCmd := fmt.Sprintf(`
try {
Remove-LocalUser -Name "%s" -ErrorAction Stop
Write-Output "User removed successfully"
} catch {
if ($_.Exception.Message -like "*cannot be found*") {
Write-Output "User not found (already removed)"
} else {
Write-Error $_.Exception.Message
}
}
`, username)
cmd := exec.Command("powershell", "-Command", psCmd)
output, err := cmd.CombinedOutput()
if err != nil {
log.Printf("Failed to remove test user %s: %v, output: %s", username, err, string(output))
} else {
log.Printf("Test user %s cleanup result: %s", username, string(output))
}
}
// CreateWindowsTestUser creates a local user on Windows using PowerShell
func CreateWindowsTestUser(t *testing.T, username string) bool {
if runtime.GOOS != "windows" {
return false
}
psCmd := fmt.Sprintf(`
try {
$password = ConvertTo-SecureString "TestPassword123!" -AsPlainText -Force
New-LocalUser -Name "%s" -Password $password -Description "NetBird test user" -UserMayNotChangePassword -PasswordNeverExpires
Add-LocalGroupMember -Group "Users" -Member "%s"
Write-Output "User created successfully"
} catch {
if ($_.Exception.Message -like "*already exists*") {
Write-Output "User already exists"
} else {
Write-Error $_.Exception.Message
exit 1
}
}
`, username, username)
cmd := exec.Command("powershell", "-Command", psCmd)
output, err := cmd.CombinedOutput()
if err != nil {
t.Logf("Failed to create test user: %v, output: %s", err, string(output))
return false
}
t.Logf("Test user creation result: %s", string(output))
return true
}

View File

@@ -1,10 +0,0 @@
//go:build freebsd
package ssh
import (
"os"
)
func setWinSize(file *os.File, width, height int) {
}

View File

@@ -1,14 +0,0 @@
//go:build linux || darwin
package ssh
import (
"os"
"syscall"
"unsafe"
)
func setWinSize(file *os.File, width, height int) {
syscall.Syscall(syscall.SYS_IOCTL, file.Fd(), uintptr(syscall.TIOCSWINSZ), //nolint
uintptr(unsafe.Pointer(&struct{ h, w, x, y uint16 }{uint16(height), uint16(width), 0, 0})))
}

View File

@@ -1,9 +0,0 @@
package ssh
import (
"os"
)
func setWinSize(file *os.File, width, height int) {
}

View File

@@ -81,6 +81,18 @@ type NsServerGroupStateOutput struct {
Error string `json:"error" yaml:"error"`
}
type SSHSessionOutput struct {
Username string `json:"username" yaml:"username"`
RemoteAddress string `json:"remoteAddress" yaml:"remoteAddress"`
Command string `json:"command" yaml:"command"`
JWTUsername string `json:"jwtUsername,omitempty" yaml:"jwtUsername,omitempty"`
}
type SSHServerStateOutput struct {
Enabled bool `json:"enabled" yaml:"enabled"`
Sessions []SSHSessionOutput `json:"sessions" yaml:"sessions"`
}
type OutputOverview struct {
Peers PeersStateOutput `json:"peers" yaml:"peers"`
CliVersion string `json:"cliVersion" yaml:"cliVersion"`
@@ -100,6 +112,7 @@ type OutputOverview struct {
Events []SystemEventOutput `json:"events" yaml:"events"`
LazyConnectionEnabled bool `json:"lazyConnectionEnabled" yaml:"lazyConnectionEnabled"`
ProfileName string `json:"profileName" yaml:"profileName"`
SSHServerState SSHServerStateOutput `json:"sshServer" yaml:"sshServer"`
}
func ConvertToStatusOutputOverview(resp *proto.StatusResponse, anon bool, statusFilter string, prefixNamesFilter []string, prefixNamesFilterMap map[string]struct{}, ipsFilter map[string]struct{}, connectionTypeFilter string, profName string) OutputOverview {
@@ -121,6 +134,7 @@ func ConvertToStatusOutputOverview(resp *proto.StatusResponse, anon bool, status
relayOverview := mapRelays(pbFullStatus.GetRelays())
peersOverview := mapPeers(resp.GetFullStatus().GetPeers(), statusFilter, prefixNamesFilter, prefixNamesFilterMap, ipsFilter, connectionTypeFilter)
sshServerOverview := mapSSHServer(pbFullStatus.GetSshServerState())
overview := OutputOverview{
Peers: peersOverview,
@@ -141,6 +155,7 @@ func ConvertToStatusOutputOverview(resp *proto.StatusResponse, anon bool, status
Events: mapEvents(pbFullStatus.GetEvents()),
LazyConnectionEnabled: pbFullStatus.GetLazyConnectionEnabled(),
ProfileName: profName,
SSHServerState: sshServerOverview,
}
if anon {
@@ -190,6 +205,30 @@ func mapNSGroups(servers []*proto.NSGroupState) []NsServerGroupStateOutput {
return mappedNSGroups
}
func mapSSHServer(sshServerState *proto.SSHServerState) SSHServerStateOutput {
if sshServerState == nil {
return SSHServerStateOutput{
Enabled: false,
Sessions: []SSHSessionOutput{},
}
}
sessions := make([]SSHSessionOutput, 0, len(sshServerState.GetSessions()))
for _, session := range sshServerState.GetSessions() {
sessions = append(sessions, SSHSessionOutput{
Username: session.GetUsername(),
RemoteAddress: session.GetRemoteAddress(),
Command: session.GetCommand(),
JWTUsername: session.GetJwtUsername(),
})
}
return SSHServerStateOutput{
Enabled: sshServerState.GetEnabled(),
Sessions: sessions,
}
}
func mapPeers(
peers []*proto.PeerState,
statusFilter string,
@@ -300,7 +339,7 @@ func ParseToYAML(overview OutputOverview) (string, error) {
return string(yamlBytes), nil
}
func ParseGeneralSummary(overview OutputOverview, showURL bool, showRelays bool, showNameServers bool) string {
func ParseGeneralSummary(overview OutputOverview, showURL bool, showRelays bool, showNameServers bool, showSSHSessions bool) string {
var managementConnString string
if overview.ManagementState.Connected {
managementConnString = "Connected"
@@ -405,6 +444,41 @@ func ParseGeneralSummary(overview OutputOverview, showURL bool, showRelays bool,
lazyConnectionEnabledStatus = "true"
}
sshServerStatus := "Disabled"
if overview.SSHServerState.Enabled {
sessionCount := len(overview.SSHServerState.Sessions)
if sessionCount > 0 {
sessionWord := "session"
if sessionCount > 1 {
sessionWord = "sessions"
}
sshServerStatus = fmt.Sprintf("Enabled (%d active %s)", sessionCount, sessionWord)
} else {
sshServerStatus = "Enabled"
}
if showSSHSessions && sessionCount > 0 {
for _, session := range overview.SSHServerState.Sessions {
var sessionDisplay string
if session.JWTUsername != "" {
sessionDisplay = fmt.Sprintf("[%s@%s -> %s] %s",
session.JWTUsername,
session.RemoteAddress,
session.Username,
session.Command,
)
} else {
sessionDisplay = fmt.Sprintf("[%s@%s] %s",
session.Username,
session.RemoteAddress,
session.Command,
)
}
sshServerStatus += "\n " + sessionDisplay
}
}
}
peersCountString := fmt.Sprintf("%d/%d Connected", overview.Peers.Connected, overview.Peers.Total)
goos := runtime.GOOS
@@ -428,6 +502,7 @@ func ParseGeneralSummary(overview OutputOverview, showURL bool, showRelays bool,
"Interface type: %s\n"+
"Quantum resistance: %s\n"+
"Lazy connection: %s\n"+
"SSH Server: %s\n"+
"Networks: %s\n"+
"Forwarding rules: %d\n"+
"Peers count: %s\n",
@@ -444,6 +519,7 @@ func ParseGeneralSummary(overview OutputOverview, showURL bool, showRelays bool,
interfaceTypeString,
rosenpassEnabledStatus,
lazyConnectionEnabledStatus,
sshServerStatus,
networks,
overview.NumberOfForwardingRules,
peersCountString,
@@ -454,7 +530,7 @@ func ParseGeneralSummary(overview OutputOverview, showURL bool, showRelays bool,
func ParseToFullDetailSummary(overview OutputOverview) string {
parsedPeersString := parsePeers(overview.Peers, overview.RosenpassEnabled, overview.RosenpassPermissive)
parsedEventsString := parseEvents(overview.Events)
summary := ParseGeneralSummary(overview, true, true, true)
summary := ParseGeneralSummary(overview, true, true, true, true)
return fmt.Sprintf(
"Peers detail:"+
@@ -746,4 +822,13 @@ func anonymizeOverview(a *anonymize.Anonymizer, overview *OutputOverview) {
event.Metadata[k] = a.AnonymizeString(v)
}
}
for i, session := range overview.SSHServerState.Sessions {
if host, port, err := net.SplitHostPort(session.RemoteAddress); err == nil {
overview.SSHServerState.Sessions[i].RemoteAddress = fmt.Sprintf("%s:%s", a.AnonymizeIPString(host), port)
} else {
overview.SSHServerState.Sessions[i].RemoteAddress = a.AnonymizeIPString(session.RemoteAddress)
}
overview.SSHServerState.Sessions[i].Command = a.AnonymizeString(session.Command)
}
}

View File

@@ -231,6 +231,10 @@ var overview = OutputOverview{
Networks: []string{
"10.10.0.0/24",
},
SSHServerState: SSHServerStateOutput{
Enabled: false,
Sessions: []SSHSessionOutput{},
},
}
func TestConversionFromFullStatusToOutputOverview(t *testing.T) {
@@ -385,7 +389,11 @@ func TestParsingToJSON(t *testing.T) {
],
"events": [],
"lazyConnectionEnabled": false,
"profileName":""
"profileName":"",
"sshServer":{
"enabled":false,
"sessions":[]
}
}`
// @formatter:on
@@ -488,6 +496,9 @@ dnsServers:
events: []
lazyConnectionEnabled: false
profileName: ""
sshServer:
enabled: false
sessions: []
`
assert.Equal(t, expectedYAML, yaml)
@@ -554,6 +565,7 @@ NetBird IP: 192.168.178.100/16
Interface type: Kernel
Quantum resistance: false
Lazy connection: false
SSH Server: Disabled
Networks: 10.10.0.0/24
Forwarding rules: 0
Peers count: 2/2 Connected
@@ -563,7 +575,7 @@ Peers count: 2/2 Connected
}
func TestParsingToShortVersion(t *testing.T) {
shortVersion := ParseGeneralSummary(overview, false, false, false)
shortVersion := ParseGeneralSummary(overview, false, false, false, false)
expectedString := fmt.Sprintf("OS: %s/%s", runtime.GOOS, runtime.GOARCH) + `
Daemon version: 0.14.1
@@ -578,6 +590,7 @@ NetBird IP: 192.168.178.100/16
Interface type: Kernel
Quantum resistance: false
Lazy connection: false
SSH Server: Disabled
Networks: 10.10.0.0/24
Forwarding rules: 0
Peers count: 2/2 Connected

View File

@@ -72,6 +72,12 @@ type Info struct {
BlockInbound bool
LazyConnectionEnabled bool
EnableSSHRoot bool
EnableSSHSFTP bool
EnableSSHLocalPortForwarding bool
EnableSSHRemotePortForwarding bool
DisableSSHAuth bool
}
func (i *Info) SetFlags(
@@ -79,6 +85,8 @@ func (i *Info) SetFlags(
serverSSHAllowed *bool,
disableClientRoutes, disableServerRoutes,
disableDNS, disableFirewall, blockLANAccess, blockInbound, lazyConnectionEnabled bool,
enableSSHRoot, enableSSHSFTP, enableSSHLocalPortForwarding, enableSSHRemotePortForwarding *bool,
disableSSHAuth *bool,
) {
i.RosenpassEnabled = rosenpassEnabled
i.RosenpassPermissive = rosenpassPermissive
@@ -94,6 +102,22 @@ func (i *Info) SetFlags(
i.BlockInbound = blockInbound
i.LazyConnectionEnabled = lazyConnectionEnabled
if enableSSHRoot != nil {
i.EnableSSHRoot = *enableSSHRoot
}
if enableSSHSFTP != nil {
i.EnableSSHSFTP = *enableSSHSFTP
}
if enableSSHLocalPortForwarding != nil {
i.EnableSSHLocalPortForwarding = *enableSSHLocalPortForwarding
}
if enableSSHRemotePortForwarding != nil {
i.EnableSSHRemotePortForwarding = *enableSSHRemotePortForwarding
}
if disableSSHAuth != nil {
i.DisableSSHAuth = *disableSSHAuth
}
}
// extractUserAgent extracts Netbird's agent (client) name and version from the outgoing context

View File

@@ -55,6 +55,7 @@ const (
const (
censoredPreSharedKey = "**********"
maxSSHJWTCacheTTL = 86_400 // 24 hours in seconds
)
func main() {
@@ -265,25 +266,38 @@ type serviceClient struct {
iMTU *widget.Entry
// switch elements for settings form
sRosenpassPermissive *widget.Check
sNetworkMonitor *widget.Check
sDisableDNS *widget.Check
sDisableClientRoutes *widget.Check
sDisableServerRoutes *widget.Check
sBlockLANAccess *widget.Check
sRosenpassPermissive *widget.Check
sNetworkMonitor *widget.Check
sDisableDNS *widget.Check
sDisableClientRoutes *widget.Check
sDisableServerRoutes *widget.Check
sBlockLANAccess *widget.Check
sEnableSSHRoot *widget.Check
sEnableSSHSFTP *widget.Check
sEnableSSHLocalPortForward *widget.Check
sEnableSSHRemotePortForward *widget.Check
sDisableSSHAuth *widget.Check
iSSHJWTCacheTTL *widget.Entry
// observable settings over corresponding iMngURL and iPreSharedKey values.
managementURL string
preSharedKey string
RosenpassPermissive bool
interfaceName string
interfacePort int
mtu uint16
networkMonitor bool
disableDNS bool
disableClientRoutes bool
disableServerRoutes bool
blockLANAccess bool
managementURL string
preSharedKey string
RosenpassPermissive bool
interfaceName string
interfacePort int
mtu uint16
networkMonitor bool
disableDNS bool
disableClientRoutes bool
disableServerRoutes bool
blockLANAccess bool
enableSSHRoot bool
enableSSHSFTP bool
enableSSHLocalPortForward bool
enableSSHRemotePortForward bool
disableSSHAuth bool
sshJWTCacheTTL int
connected bool
update *version.Update
@@ -435,18 +449,22 @@ func (s *serviceClient) showSettingsUI() {
s.sDisableClientRoutes = widget.NewCheck("This peer won't route traffic to other peers", nil)
s.sDisableServerRoutes = widget.NewCheck("This peer won't act as router for others", nil)
s.sBlockLANAccess = widget.NewCheck("Blocks local network access when used as exit node", nil)
s.sEnableSSHRoot = widget.NewCheck("Enable SSH Root Login", nil)
s.sEnableSSHSFTP = widget.NewCheck("Enable SSH SFTP", nil)
s.sEnableSSHLocalPortForward = widget.NewCheck("Enable SSH Local Port Forwarding", nil)
s.sEnableSSHRemotePortForward = widget.NewCheck("Enable SSH Remote Port Forwarding", nil)
s.sDisableSSHAuth = widget.NewCheck("Disable SSH Authentication", nil)
s.iSSHJWTCacheTTL = widget.NewEntry()
s.wSettings.SetContent(s.getSettingsForm())
s.wSettings.Resize(fyne.NewSize(600, 500))
s.wSettings.Resize(fyne.NewSize(600, 400))
s.wSettings.SetFixedSize(true)
s.getSrvConfig()
s.wSettings.Show()
}
// getSettingsForm to embed it into settings window.
func (s *serviceClient) getSettingsForm() *widget.Form {
func (s *serviceClient) getConnectionForm() *widget.Form {
var activeProfName string
activeProf, err := s.profileManager.GetActiveProfile()
if err != nil {
@@ -457,153 +475,277 @@ func (s *serviceClient) getSettingsForm() *widget.Form {
return &widget.Form{
Items: []*widget.FormItem{
{Text: "Profile", Widget: widget.NewLabel(activeProfName)},
{Text: "Management URL", Widget: s.iMngURL},
{Text: "Pre-shared Key", Widget: s.iPreSharedKey},
{Text: "Quantum-Resistance", Widget: s.sRosenpassPermissive},
{Text: "Interface Name", Widget: s.iInterfaceName},
{Text: "Interface Port", Widget: s.iInterfacePort},
{Text: "MTU", Widget: s.iMTU},
{Text: "Management URL", Widget: s.iMngURL},
{Text: "Pre-shared Key", Widget: s.iPreSharedKey},
{Text: "Log File", Widget: s.iLogFile},
},
}
}
func (s *serviceClient) saveSettings() {
// Check if update settings are disabled by daemon
features, err := s.getFeatures()
if err != nil {
log.Errorf("failed to get features from daemon: %v", err)
// Continue with default behavior if features can't be retrieved
} else if features != nil && features.DisableUpdateSettings {
log.Warn("Configuration updates are disabled by daemon")
dialog.ShowError(fmt.Errorf("Configuration updates are disabled by daemon"), s.wSettings)
return
}
if err := s.validateSettings(); err != nil {
dialog.ShowError(err, s.wSettings)
return
}
port, mtu, err := s.parseNumericSettings()
if err != nil {
dialog.ShowError(err, s.wSettings)
return
}
iMngURL := strings.TrimSpace(s.iMngURL.Text)
if s.hasSettingsChanged(iMngURL, port, mtu) {
if err := s.applySettingsChanges(iMngURL, port, mtu); err != nil {
dialog.ShowError(err, s.wSettings)
return
}
}
s.wSettings.Close()
}
func (s *serviceClient) validateSettings() error {
if s.iPreSharedKey.Text != "" && s.iPreSharedKey.Text != censoredPreSharedKey {
if _, err := wgtypes.ParseKey(s.iPreSharedKey.Text); err != nil {
return fmt.Errorf("Invalid Pre-shared Key Value")
}
}
return nil
}
func (s *serviceClient) parseNumericSettings() (int64, int64, error) {
port, err := strconv.ParseInt(s.iInterfacePort.Text, 10, 64)
if err != nil {
return 0, 0, errors.New("Invalid interface port")
}
if port < 1 || port > 65535 {
return 0, 0, errors.New("Invalid interface port: out of range 1-65535")
}
var mtu int64
mtuText := strings.TrimSpace(s.iMTU.Text)
if mtuText != "" {
mtu, err = strconv.ParseInt(mtuText, 10, 64)
if err != nil {
return 0, 0, errors.New("Invalid MTU value")
}
if mtu < iface.MinMTU || mtu > iface.MaxMTU {
return 0, 0, fmt.Errorf("MTU must be between %d and %d bytes", iface.MinMTU, iface.MaxMTU)
}
}
return port, mtu, nil
}
func (s *serviceClient) hasSettingsChanged(iMngURL string, port, mtu int64) bool {
return s.managementURL != iMngURL ||
s.preSharedKey != s.iPreSharedKey.Text ||
s.RosenpassPermissive != s.sRosenpassPermissive.Checked ||
s.interfaceName != s.iInterfaceName.Text ||
s.interfacePort != int(port) ||
s.mtu != uint16(mtu) ||
s.networkMonitor != s.sNetworkMonitor.Checked ||
s.disableDNS != s.sDisableDNS.Checked ||
s.disableClientRoutes != s.sDisableClientRoutes.Checked ||
s.disableServerRoutes != s.sDisableServerRoutes.Checked ||
s.blockLANAccess != s.sBlockLANAccess.Checked ||
s.hasSSHChanges()
}
func (s *serviceClient) applySettingsChanges(iMngURL string, port, mtu int64) error {
s.managementURL = iMngURL
s.preSharedKey = s.iPreSharedKey.Text
s.mtu = uint16(mtu)
req, err := s.buildSetConfigRequest(iMngURL, port, mtu)
if err != nil {
return fmt.Errorf("build config request: %w", err)
}
if err := s.sendConfigUpdate(req); err != nil {
return fmt.Errorf("set configuration: %w", err)
}
return nil
}
func (s *serviceClient) buildSetConfigRequest(iMngURL string, port, mtu int64) (*proto.SetConfigRequest, error) {
currUser, err := user.Current()
if err != nil {
return nil, fmt.Errorf("get current user: %w", err)
}
activeProf, err := s.profileManager.GetActiveProfile()
if err != nil {
return nil, fmt.Errorf("get active profile: %w", err)
}
req := &proto.SetConfigRequest{
ProfileName: activeProf.Name,
Username: currUser.Username,
}
if iMngURL != "" {
req.ManagementUrl = iMngURL
}
req.RosenpassPermissive = &s.sRosenpassPermissive.Checked
req.InterfaceName = &s.iInterfaceName.Text
req.WireguardPort = &port
if mtu > 0 {
req.Mtu = &mtu
}
req.NetworkMonitor = &s.sNetworkMonitor.Checked
req.DisableDns = &s.sDisableDNS.Checked
req.DisableClientRoutes = &s.sDisableClientRoutes.Checked
req.DisableServerRoutes = &s.sDisableServerRoutes.Checked
req.BlockLanAccess = &s.sBlockLANAccess.Checked
req.EnableSSHRoot = &s.sEnableSSHRoot.Checked
req.EnableSSHSFTP = &s.sEnableSSHSFTP.Checked
req.EnableSSHLocalPortForwarding = &s.sEnableSSHLocalPortForward.Checked
req.EnableSSHRemotePortForwarding = &s.sEnableSSHRemotePortForward.Checked
req.DisableSSHAuth = &s.sDisableSSHAuth.Checked
sshJWTCacheTTLText := strings.TrimSpace(s.iSSHJWTCacheTTL.Text)
if sshJWTCacheTTLText != "" {
sshJWTCacheTTL, err := strconv.ParseInt(sshJWTCacheTTLText, 10, 32)
if err != nil {
return nil, errors.New("Invalid SSH JWT Cache TTL value")
}
if sshJWTCacheTTL < 0 || sshJWTCacheTTL > maxSSHJWTCacheTTL {
return nil, fmt.Errorf("SSH JWT Cache TTL must be between 0 and %d seconds", maxSSHJWTCacheTTL)
}
sshJWTCacheTTL32 := int32(sshJWTCacheTTL)
req.SshJWTCacheTTL = &sshJWTCacheTTL32
}
if s.iPreSharedKey.Text != censoredPreSharedKey {
req.OptionalPreSharedKey = &s.iPreSharedKey.Text
}
return req, nil
}
func (s *serviceClient) sendConfigUpdate(req *proto.SetConfigRequest) error {
conn, err := s.getSrvClient(failFastTimeout)
if err != nil {
return fmt.Errorf("get client: %w", err)
}
_, err = conn.SetConfig(s.ctx, req)
if err != nil {
return fmt.Errorf("set config: %w", err)
}
// Reconnect if connected to apply the new settings
go func() {
status, err := conn.Status(s.ctx, &proto.StatusRequest{})
if err != nil {
log.Errorf("get service status: %v", err)
return
}
if status.Status == string(internal.StatusConnected) {
// run down & up
_, err = conn.Down(s.ctx, &proto.DownRequest{})
if err != nil {
log.Errorf("down service: %v", err)
}
_, err = conn.Up(s.ctx, &proto.UpRequest{})
if err != nil {
log.Errorf("up service: %v", err)
return
}
}
}()
return nil
}
func (s *serviceClient) getSettingsForm() fyne.CanvasObject {
connectionForm := s.getConnectionForm()
networkForm := s.getNetworkForm()
sshForm := s.getSSHForm()
tabs := container.NewAppTabs(
container.NewTabItem("Connection", connectionForm),
container.NewTabItem("Network", networkForm),
container.NewTabItem("SSH", sshForm),
)
saveButton := widget.NewButtonWithIcon("Save", theme.ConfirmIcon(), s.saveSettings)
saveButton.Importance = widget.HighImportance
cancelButton := widget.NewButtonWithIcon("Cancel", theme.CancelIcon(), func() {
s.wSettings.Close()
})
buttonContainer := container.NewHBox(
layout.NewSpacer(),
cancelButton,
saveButton,
)
return container.NewBorder(nil, buttonContainer, nil, nil, tabs)
}
func (s *serviceClient) getNetworkForm() *widget.Form {
return &widget.Form{
Items: []*widget.FormItem{
{Text: "Network Monitor", Widget: s.sNetworkMonitor},
{Text: "Disable DNS", Widget: s.sDisableDNS},
{Text: "Disable Client Routes", Widget: s.sDisableClientRoutes},
{Text: "Disable Server Routes", Widget: s.sDisableServerRoutes},
{Text: "Disable LAN Access", Widget: s.sBlockLANAccess},
},
SubmitText: "Save",
OnSubmit: func() {
// Check if update settings are disabled by daemon
features, err := s.getFeatures()
if err != nil {
log.Errorf("failed to get features from daemon: %v", err)
// Continue with default behavior if features can't be retrieved
} else if features != nil && features.DisableUpdateSettings {
log.Warn("Configuration updates are disabled by daemon")
dialog.ShowError(fmt.Errorf("Configuration updates are disabled by daemon"), s.wSettings)
return
}
}
}
if s.iPreSharedKey.Text != "" && s.iPreSharedKey.Text != censoredPreSharedKey {
// validate preSharedKey if it added
if _, err := wgtypes.ParseKey(s.iPreSharedKey.Text); err != nil {
dialog.ShowError(fmt.Errorf("Invalid Pre-shared Key Value"), s.wSettings)
return
}
}
port, err := strconv.ParseInt(s.iInterfacePort.Text, 10, 64)
if err != nil {
dialog.ShowError(errors.New("Invalid interface port"), s.wSettings)
return
}
var mtu int64
mtuText := strings.TrimSpace(s.iMTU.Text)
if mtuText != "" {
var err error
mtu, err = strconv.ParseInt(mtuText, 10, 64)
if err != nil {
dialog.ShowError(errors.New("Invalid MTU value"), s.wSettings)
return
}
if mtu < iface.MinMTU || mtu > iface.MaxMTU {
dialog.ShowError(fmt.Errorf("MTU must be between %d and %d bytes", iface.MinMTU, iface.MaxMTU), s.wSettings)
return
}
}
iMngURL := strings.TrimSpace(s.iMngURL.Text)
defer s.wSettings.Close()
// Check if any settings have changed
if s.managementURL != iMngURL || s.preSharedKey != s.iPreSharedKey.Text ||
s.RosenpassPermissive != s.sRosenpassPermissive.Checked ||
s.interfaceName != s.iInterfaceName.Text || s.interfacePort != int(port) ||
s.mtu != uint16(mtu) ||
s.networkMonitor != s.sNetworkMonitor.Checked ||
s.disableDNS != s.sDisableDNS.Checked ||
s.disableClientRoutes != s.sDisableClientRoutes.Checked ||
s.disableServerRoutes != s.sDisableServerRoutes.Checked ||
s.blockLANAccess != s.sBlockLANAccess.Checked {
s.managementURL = iMngURL
s.preSharedKey = s.iPreSharedKey.Text
s.mtu = uint16(mtu)
currUser, err := user.Current()
if err != nil {
log.Errorf("get current user: %v", err)
return
}
var req proto.SetConfigRequest
req.ProfileName = activeProf.Name
req.Username = currUser.Username
if iMngURL != "" {
req.ManagementUrl = iMngURL
}
req.RosenpassPermissive = &s.sRosenpassPermissive.Checked
req.InterfaceName = &s.iInterfaceName.Text
req.WireguardPort = &port
if mtu > 0 {
req.Mtu = &mtu
}
req.NetworkMonitor = &s.sNetworkMonitor.Checked
req.DisableDns = &s.sDisableDNS.Checked
req.DisableClientRoutes = &s.sDisableClientRoutes.Checked
req.DisableServerRoutes = &s.sDisableServerRoutes.Checked
req.BlockLanAccess = &s.sBlockLANAccess.Checked
if s.iPreSharedKey.Text != censoredPreSharedKey {
req.OptionalPreSharedKey = &s.iPreSharedKey.Text
}
conn, err := s.getSrvClient(failFastTimeout)
if err != nil {
log.Errorf("get client: %v", err)
dialog.ShowError(fmt.Errorf("Failed to connect to the service: %v", err), s.wSettings)
return
}
_, err = conn.SetConfig(s.ctx, &req)
if err != nil {
log.Errorf("set config: %v", err)
dialog.ShowError(fmt.Errorf("Failed to set configuration: %v", err), s.wSettings)
return
}
go func() {
status, err := conn.Status(s.ctx, &proto.StatusRequest{})
if err != nil {
log.Errorf("get service status: %v", err)
dialog.ShowError(fmt.Errorf("Failed to get service status: %v", err), s.wSettings)
return
}
if status.Status == string(internal.StatusConnected) {
// run down & up
_, err = conn.Down(s.ctx, &proto.DownRequest{})
if err != nil {
log.Errorf("down service: %v", err)
}
_, err = conn.Up(s.ctx, &proto.UpRequest{})
if err != nil {
log.Errorf("up service: %v", err)
dialog.ShowError(fmt.Errorf("Failed to reconnect: %v", err), s.wSettings)
return
}
}
}()
}
},
OnCancel: func() {
s.wSettings.Close()
func (s *serviceClient) getSSHForm() *widget.Form {
return &widget.Form{
Items: []*widget.FormItem{
{Text: "Enable SSH Root Login", Widget: s.sEnableSSHRoot},
{Text: "Enable SSH SFTP", Widget: s.sEnableSSHSFTP},
{Text: "Enable SSH Local Port Forwarding", Widget: s.sEnableSSHLocalPortForward},
{Text: "Enable SSH Remote Port Forwarding", Widget: s.sEnableSSHRemotePortForward},
{Text: "Disable SSH Authentication", Widget: s.sDisableSSHAuth},
{Text: "JWT Cache TTL (seconds, 0=disabled)", Widget: s.iSSHJWTCacheTTL},
},
}
}
func (s *serviceClient) hasSSHChanges() bool {
currentSSHJWTCacheTTL := s.sshJWTCacheTTL
if text := strings.TrimSpace(s.iSSHJWTCacheTTL.Text); text != "" {
val, err := strconv.Atoi(text)
if err != nil {
return true
}
currentSSHJWTCacheTTL = val
}
return s.enableSSHRoot != s.sEnableSSHRoot.Checked ||
s.enableSSHSFTP != s.sEnableSSHSFTP.Checked ||
s.enableSSHLocalPortForward != s.sEnableSSHLocalPortForward.Checked ||
s.enableSSHRemotePortForward != s.sEnableSSHRemotePortForward.Checked ||
s.disableSSHAuth != s.sDisableSSHAuth.Checked ||
s.sshJWTCacheTTL != currentSSHJWTCacheTTL
}
func (s *serviceClient) login(ctx context.Context, openURL bool) (*proto.LoginResponse, error) {
conn, err := s.getSrvClient(defaultFailTimeout)
if err != nil {
@@ -1123,6 +1265,25 @@ func (s *serviceClient) getSrvConfig() {
s.disableServerRoutes = cfg.DisableServerRoutes
s.blockLANAccess = cfg.BlockLANAccess
if cfg.EnableSSHRoot != nil {
s.enableSSHRoot = *cfg.EnableSSHRoot
}
if cfg.EnableSSHSFTP != nil {
s.enableSSHSFTP = *cfg.EnableSSHSFTP
}
if cfg.EnableSSHLocalPortForwarding != nil {
s.enableSSHLocalPortForward = *cfg.EnableSSHLocalPortForwarding
}
if cfg.EnableSSHRemotePortForwarding != nil {
s.enableSSHRemotePortForward = *cfg.EnableSSHRemotePortForwarding
}
if cfg.DisableSSHAuth != nil {
s.disableSSHAuth = *cfg.DisableSSHAuth
}
if cfg.SSHJWTCacheTTL != nil {
s.sshJWTCacheTTL = *cfg.SSHJWTCacheTTL
}
if s.showAdvancedSettings {
s.iMngURL.SetText(s.managementURL)
s.iPreSharedKey.SetText(cfg.PreSharedKey)
@@ -1143,6 +1304,24 @@ func (s *serviceClient) getSrvConfig() {
s.sDisableClientRoutes.SetChecked(cfg.DisableClientRoutes)
s.sDisableServerRoutes.SetChecked(cfg.DisableServerRoutes)
s.sBlockLANAccess.SetChecked(cfg.BlockLANAccess)
if cfg.EnableSSHRoot != nil {
s.sEnableSSHRoot.SetChecked(*cfg.EnableSSHRoot)
}
if cfg.EnableSSHSFTP != nil {
s.sEnableSSHSFTP.SetChecked(*cfg.EnableSSHSFTP)
}
if cfg.EnableSSHLocalPortForwarding != nil {
s.sEnableSSHLocalPortForward.SetChecked(*cfg.EnableSSHLocalPortForwarding)
}
if cfg.EnableSSHRemotePortForwarding != nil {
s.sEnableSSHRemotePortForward.SetChecked(*cfg.EnableSSHRemotePortForwarding)
}
if cfg.DisableSSHAuth != nil {
s.sDisableSSHAuth.SetChecked(*cfg.DisableSSHAuth)
}
if cfg.SSHJWTCacheTTL != nil {
s.iSSHJWTCacheTTL.SetText(strconv.Itoa(*cfg.SSHJWTCacheTTL))
}
}
if s.mNotifications == nil {
@@ -1213,6 +1392,15 @@ func protoConfigToConfig(cfg *proto.GetConfigResponse) *profilemanager.Config {
config.DisableServerRoutes = cfg.DisableServerRoutes
config.BlockLANAccess = cfg.BlockLanAccess
config.EnableSSHRoot = &cfg.EnableSSHRoot
config.EnableSSHSFTP = &cfg.EnableSSHSFTP
config.EnableSSHLocalPortForwarding = &cfg.EnableSSHLocalPortForwarding
config.EnableSSHRemotePortForwarding = &cfg.EnableSSHRemotePortForwarding
config.DisableSSHAuth = &cfg.DisableSSHAuth
ttl := int(cfg.SshJWTCacheTTL)
config.SSHJWTCacheTTL = &ttl
return &config
}

Some files were not shown because too many files have changed in this diff Show More