mirror of
https://github.com/netbirdio/netbird.git
synced 2026-03-31 06:34:14 -04:00
[client,management] Rewrite the SSH feature (#4015)
This commit is contained in:
52
.github/workflows/check-license-dependencies.yml
vendored
52
.github/workflows/check-license-dependencies.yml
vendored
@@ -19,35 +19,37 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- name: Check for problematic license dependencies
|
||||
run: |
|
||||
echo "Checking for dependencies on management/, signal/, and relay/ packages..."
|
||||
- name: Check for problematic license dependencies
|
||||
run: |
|
||||
echo "Checking for dependencies on management/, signal/, and relay/ packages..."
|
||||
echo ""
|
||||
|
||||
# Find all directories except the problematic ones and system dirs
|
||||
FOUND_ISSUES=0
|
||||
find . -maxdepth 1 -type d -not -name "." -not -name "management" -not -name "signal" -not -name "relay" -not -name ".git*" | sort | while read dir; do
|
||||
echo "=== Checking $dir ==="
|
||||
# Search for problematic imports, excluding test files
|
||||
RESULTS=$(grep -r "github.com/netbirdio/netbird/\(management\|signal\|relay\)" "$dir" --include="*.go" | grep -v "_test.go" | grep -v "test_" | grep -v "/test/" || true)
|
||||
if [ ! -z "$RESULTS" ]; then
|
||||
echo "❌ Found problematic dependencies:"
|
||||
echo "$RESULTS"
|
||||
FOUND_ISSUES=1
|
||||
# Find all directories except the problematic ones and system dirs
|
||||
FOUND_ISSUES=0
|
||||
while IFS= read -r dir; do
|
||||
echo "=== Checking $dir ==="
|
||||
# Search for problematic imports, excluding test files
|
||||
RESULTS=$(grep -r "github.com/netbirdio/netbird/\(management\|signal\|relay\)" "$dir" --include="*.go" 2>/dev/null | grep -v "_test.go" | grep -v "test_" | grep -v "/test/" || true)
|
||||
if [ -n "$RESULTS" ]; then
|
||||
echo "❌ Found problematic dependencies:"
|
||||
echo "$RESULTS"
|
||||
FOUND_ISSUES=1
|
||||
else
|
||||
echo "✓ No problematic dependencies found"
|
||||
fi
|
||||
done < <(find . -maxdepth 1 -type d -not -name "." -not -name "management" -not -name "signal" -not -name "relay" -not -name ".git*" | sort)
|
||||
|
||||
echo ""
|
||||
if [ $FOUND_ISSUES -eq 1 ]; then
|
||||
echo "❌ Found dependencies on management/, signal/, or relay/ packages"
|
||||
echo "These packages are licensed under AGPLv3 and must not be imported by BSD-licensed code"
|
||||
exit 1
|
||||
else
|
||||
echo "✓ No problematic dependencies found"
|
||||
echo ""
|
||||
echo "✅ All internal license dependencies are clean"
|
||||
fi
|
||||
done
|
||||
if [ $FOUND_ISSUES -eq 1 ]; then
|
||||
echo ""
|
||||
echo "❌ Found dependencies on management/, signal/, or relay/ packages"
|
||||
echo "These packages are licensed under AGPLv3 and must not be imported by BSD-licensed code"
|
||||
exit 1
|
||||
else
|
||||
echo ""
|
||||
echo "✅ All internal license dependencies are clean"
|
||||
fi
|
||||
|
||||
check-external-licenses:
|
||||
name: Check External GPL/AGPL Licenses
|
||||
|
||||
@@ -17,9 +17,9 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||
"github.com/netbirdio/netbird/client/net"
|
||||
"github.com/netbirdio/netbird/client/system"
|
||||
"github.com/netbirdio/netbird/formatter"
|
||||
"github.com/netbirdio/netbird/client/net"
|
||||
)
|
||||
|
||||
// ConnectionListener export internal Listener for mobile
|
||||
|
||||
@@ -201,6 +201,94 @@ func (p *Preferences) SetServerSSHAllowed(allowed bool) {
|
||||
p.configInput.ServerSSHAllowed = &allowed
|
||||
}
|
||||
|
||||
// GetEnableSSHRoot reads SSH root login setting from config file
|
||||
func (p *Preferences) GetEnableSSHRoot() (bool, error) {
|
||||
if p.configInput.EnableSSHRoot != nil {
|
||||
return *p.configInput.EnableSSHRoot, nil
|
||||
}
|
||||
|
||||
cfg, err := profilemanager.ReadConfig(p.configInput.ConfigPath)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
if cfg.EnableSSHRoot == nil {
|
||||
// Default to false for security on Android
|
||||
return false, nil
|
||||
}
|
||||
return *cfg.EnableSSHRoot, err
|
||||
}
|
||||
|
||||
// SetEnableSSHRoot stores the given value and waits for commit
|
||||
func (p *Preferences) SetEnableSSHRoot(enabled bool) {
|
||||
p.configInput.EnableSSHRoot = &enabled
|
||||
}
|
||||
|
||||
// GetEnableSSHSFTP reads SSH SFTP setting from config file
|
||||
func (p *Preferences) GetEnableSSHSFTP() (bool, error) {
|
||||
if p.configInput.EnableSSHSFTP != nil {
|
||||
return *p.configInput.EnableSSHSFTP, nil
|
||||
}
|
||||
|
||||
cfg, err := profilemanager.ReadConfig(p.configInput.ConfigPath)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
if cfg.EnableSSHSFTP == nil {
|
||||
// Default to false for security on Android
|
||||
return false, nil
|
||||
}
|
||||
return *cfg.EnableSSHSFTP, err
|
||||
}
|
||||
|
||||
// SetEnableSSHSFTP stores the given value and waits for commit
|
||||
func (p *Preferences) SetEnableSSHSFTP(enabled bool) {
|
||||
p.configInput.EnableSSHSFTP = &enabled
|
||||
}
|
||||
|
||||
// GetEnableSSHLocalPortForwarding reads SSH local port forwarding setting from config file
|
||||
func (p *Preferences) GetEnableSSHLocalPortForwarding() (bool, error) {
|
||||
if p.configInput.EnableSSHLocalPortForwarding != nil {
|
||||
return *p.configInput.EnableSSHLocalPortForwarding, nil
|
||||
}
|
||||
|
||||
cfg, err := profilemanager.ReadConfig(p.configInput.ConfigPath)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
if cfg.EnableSSHLocalPortForwarding == nil {
|
||||
// Default to false for security on Android
|
||||
return false, nil
|
||||
}
|
||||
return *cfg.EnableSSHLocalPortForwarding, err
|
||||
}
|
||||
|
||||
// SetEnableSSHLocalPortForwarding stores the given value and waits for commit
|
||||
func (p *Preferences) SetEnableSSHLocalPortForwarding(enabled bool) {
|
||||
p.configInput.EnableSSHLocalPortForwarding = &enabled
|
||||
}
|
||||
|
||||
// GetEnableSSHRemotePortForwarding reads SSH remote port forwarding setting from config file
|
||||
func (p *Preferences) GetEnableSSHRemotePortForwarding() (bool, error) {
|
||||
if p.configInput.EnableSSHRemotePortForwarding != nil {
|
||||
return *p.configInput.EnableSSHRemotePortForwarding, nil
|
||||
}
|
||||
|
||||
cfg, err := profilemanager.ReadConfig(p.configInput.ConfigPath)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
if cfg.EnableSSHRemotePortForwarding == nil {
|
||||
// Default to false for security on Android
|
||||
return false, nil
|
||||
}
|
||||
return *cfg.EnableSSHRemotePortForwarding, err
|
||||
}
|
||||
|
||||
// SetEnableSSHRemotePortForwarding stores the given value and waits for commit
|
||||
func (p *Preferences) SetEnableSSHRemotePortForwarding(enabled bool) {
|
||||
p.configInput.EnableSSHRemotePortForwarding = &enabled
|
||||
}
|
||||
|
||||
// GetBlockInbound reads block inbound setting from config file
|
||||
func (p *Preferences) GetBlockInbound() (bool, error) {
|
||||
if p.configInput.BlockInbound != nil {
|
||||
|
||||
@@ -35,7 +35,6 @@ const (
|
||||
wireguardPortFlag = "wireguard-port"
|
||||
networkMonitorFlag = "network-monitor"
|
||||
disableAutoConnectFlag = "disable-auto-connect"
|
||||
serverSSHAllowedFlag = "allow-server-ssh"
|
||||
extraIFaceBlackListFlag = "extra-iface-blacklist"
|
||||
dnsRouteIntervalFlag = "dns-router-interval"
|
||||
enableLazyConnectionFlag = "enable-lazy-connection"
|
||||
@@ -64,7 +63,6 @@ var (
|
||||
customDNSAddress string
|
||||
rosenpassEnabled bool
|
||||
rosenpassPermissive bool
|
||||
serverSSHAllowed bool
|
||||
interfaceName string
|
||||
wireguardPort uint16
|
||||
networkMonitor bool
|
||||
@@ -176,7 +174,6 @@ func init() {
|
||||
)
|
||||
upCmd.PersistentFlags().BoolVar(&rosenpassEnabled, enableRosenpassFlag, false, "[Experimental] Enable Rosenpass feature. If enabled, the connection will be post-quantum secured via Rosenpass.")
|
||||
upCmd.PersistentFlags().BoolVar(&rosenpassPermissive, rosenpassPermissiveFlag, false, "[Experimental] Enable Rosenpass in permissive mode to allow this peer to accept WireGuard connections without requiring Rosenpass functionality from peers that do not have Rosenpass enabled.")
|
||||
upCmd.PersistentFlags().BoolVar(&serverSSHAllowed, serverSSHAllowedFlag, false, "Allow SSH server on peer. If enabled, the SSH server will be permitted")
|
||||
upCmd.PersistentFlags().BoolVar(&autoConnectDisabled, disableAutoConnectFlag, false, "Disables auto-connect feature. If enabled, then the client won't connect automatically when the service starts.")
|
||||
upCmd.PersistentFlags().BoolVar(&lazyConnEnabled, enableLazyConnectionFlag, false, "[Experimental] Enable the lazy connection feature. If enabled, the client will establish connections on-demand. Note: this setting may be overridden by management configuration.")
|
||||
|
||||
|
||||
@@ -3,125 +3,809 @@ package cmd
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"flag"
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"os/signal"
|
||||
"os/user"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
"syscall"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/spf13/cobra"
|
||||
"golang.org/x/crypto/ssh"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||
nbssh "github.com/netbirdio/netbird/client/ssh"
|
||||
sshclient "github.com/netbirdio/netbird/client/ssh/client"
|
||||
"github.com/netbirdio/netbird/client/ssh/detection"
|
||||
sshproxy "github.com/netbirdio/netbird/client/ssh/proxy"
|
||||
sshserver "github.com/netbirdio/netbird/client/ssh/server"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
)
|
||||
|
||||
var (
|
||||
port int
|
||||
userName = "root"
|
||||
host string
|
||||
const (
|
||||
sshUsernameDesc = "SSH username"
|
||||
hostArgumentRequired = "host argument required"
|
||||
|
||||
serverSSHAllowedFlag = "allow-server-ssh"
|
||||
enableSSHRootFlag = "enable-ssh-root"
|
||||
enableSSHSFTPFlag = "enable-ssh-sftp"
|
||||
enableSSHLocalPortForwardFlag = "enable-ssh-local-port-forwarding"
|
||||
enableSSHRemotePortForwardFlag = "enable-ssh-remote-port-forwarding"
|
||||
disableSSHAuthFlag = "disable-ssh-auth"
|
||||
sshJWTCacheTTLFlag = "ssh-jwt-cache-ttl"
|
||||
)
|
||||
|
||||
var sshCmd = &cobra.Command{
|
||||
Use: "ssh [user@]host",
|
||||
Args: func(cmd *cobra.Command, args []string) error {
|
||||
if len(args) < 1 {
|
||||
return errors.New("requires a host argument")
|
||||
}
|
||||
var (
|
||||
port int
|
||||
username string
|
||||
host string
|
||||
command string
|
||||
localForwards []string
|
||||
remoteForwards []string
|
||||
strictHostKeyChecking bool
|
||||
knownHostsFile string
|
||||
identityFile string
|
||||
skipCachedToken bool
|
||||
requestPTY bool
|
||||
)
|
||||
|
||||
split := strings.Split(args[0], "@")
|
||||
if len(split) == 2 {
|
||||
userName = split[0]
|
||||
host = split[1]
|
||||
} else {
|
||||
host = args[0]
|
||||
}
|
||||
var (
|
||||
serverSSHAllowed bool
|
||||
enableSSHRoot bool
|
||||
enableSSHSFTP bool
|
||||
enableSSHLocalPortForward bool
|
||||
enableSSHRemotePortForward bool
|
||||
disableSSHAuth bool
|
||||
sshJWTCacheTTL int
|
||||
)
|
||||
|
||||
return nil
|
||||
},
|
||||
Short: "Connect to a remote SSH server",
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
SetFlagsFromEnvVars(rootCmd)
|
||||
SetFlagsFromEnvVars(cmd)
|
||||
func init() {
|
||||
upCmd.PersistentFlags().BoolVar(&serverSSHAllowed, serverSSHAllowedFlag, false, "Allow SSH server on peer")
|
||||
upCmd.PersistentFlags().BoolVar(&enableSSHRoot, enableSSHRootFlag, false, "Enable root login for SSH server")
|
||||
upCmd.PersistentFlags().BoolVar(&enableSSHSFTP, enableSSHSFTPFlag, false, "Enable SFTP subsystem for SSH server")
|
||||
upCmd.PersistentFlags().BoolVar(&enableSSHLocalPortForward, enableSSHLocalPortForwardFlag, false, "Enable local port forwarding for SSH server")
|
||||
upCmd.PersistentFlags().BoolVar(&enableSSHRemotePortForward, enableSSHRemotePortForwardFlag, false, "Enable remote port forwarding for SSH server")
|
||||
upCmd.PersistentFlags().BoolVar(&disableSSHAuth, disableSSHAuthFlag, false, "Disable SSH authentication")
|
||||
upCmd.PersistentFlags().IntVar(&sshJWTCacheTTL, sshJWTCacheTTLFlag, 0, "SSH JWT token cache TTL in seconds (0=disabled)")
|
||||
|
||||
cmd.SetOut(cmd.OutOrStdout())
|
||||
sshCmd.PersistentFlags().IntVarP(&port, "port", "p", sshserver.DefaultSSHPort, "Remote SSH port")
|
||||
sshCmd.PersistentFlags().StringVarP(&username, "user", "u", "", sshUsernameDesc)
|
||||
sshCmd.PersistentFlags().StringVar(&username, "login", "", sshUsernameDesc+" (alias for --user)")
|
||||
sshCmd.PersistentFlags().BoolVarP(&requestPTY, "tty", "t", false, "Force pseudo-terminal allocation")
|
||||
sshCmd.PersistentFlags().BoolVar(&strictHostKeyChecking, "strict-host-key-checking", true, "Enable strict host key checking (default: true)")
|
||||
sshCmd.PersistentFlags().StringVarP(&knownHostsFile, "known-hosts", "o", "", "Path to known_hosts file (default: ~/.ssh/known_hosts)")
|
||||
sshCmd.PersistentFlags().StringVarP(&identityFile, "identity", "i", "", "Path to SSH private key file (deprecated)")
|
||||
_ = sshCmd.PersistentFlags().MarkDeprecated("identity", "this flag is no longer used")
|
||||
sshCmd.PersistentFlags().BoolVar(&skipCachedToken, "no-cache", false, "Skip cached JWT token and force fresh authentication")
|
||||
|
||||
err := util.InitLog(logLevel, util.LogConsole)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed initializing log %v", err)
|
||||
}
|
||||
sshCmd.PersistentFlags().StringArrayP("L", "L", []string{}, "Local port forwarding [bind_address:]port:host:hostport")
|
||||
sshCmd.PersistentFlags().StringArrayP("R", "R", []string{}, "Remote port forwarding [bind_address:]port:host:hostport")
|
||||
|
||||
if !util.IsAdmin() {
|
||||
cmd.Printf("error: you must have Administrator privileges to run this command\n")
|
||||
return nil
|
||||
}
|
||||
|
||||
ctx := internal.CtxInitState(cmd.Context())
|
||||
|
||||
sm := profilemanager.NewServiceManager(configPath)
|
||||
activeProf, err := sm.GetActiveProfileState()
|
||||
if err != nil {
|
||||
return fmt.Errorf("get active profile: %v", err)
|
||||
}
|
||||
profPath, err := activeProf.FilePath()
|
||||
if err != nil {
|
||||
return fmt.Errorf("get active profile path: %v", err)
|
||||
}
|
||||
|
||||
config, err := profilemanager.ReadConfig(profPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("read profile config: %v", err)
|
||||
}
|
||||
|
||||
sig := make(chan os.Signal, 1)
|
||||
signal.Notify(sig, syscall.SIGTERM, syscall.SIGINT)
|
||||
sshctx, cancel := context.WithCancel(ctx)
|
||||
|
||||
go func() {
|
||||
// blocking
|
||||
if err := runSSH(sshctx, host, []byte(config.SSHKey), cmd); err != nil {
|
||||
cmd.Printf("Error: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
cancel()
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-sig:
|
||||
cancel()
|
||||
case <-sshctx.Done():
|
||||
}
|
||||
|
||||
return nil
|
||||
},
|
||||
sshCmd.AddCommand(sshSftpCmd)
|
||||
sshCmd.AddCommand(sshProxyCmd)
|
||||
sshCmd.AddCommand(sshDetectCmd)
|
||||
}
|
||||
|
||||
func runSSH(ctx context.Context, addr string, pemKey []byte, cmd *cobra.Command) error {
|
||||
c, err := nbssh.DialWithKey(fmt.Sprintf("%s:%d", addr, port), userName, pemKey)
|
||||
if err != nil {
|
||||
cmd.Printf("Error: %v\n", err)
|
||||
cmd.Printf("Couldn't connect. Please check the connection status or if the ssh server is enabled on the other peer" +
|
||||
"\nYou can verify the connection by running:\n\n" +
|
||||
" netbird status\n\n")
|
||||
return err
|
||||
}
|
||||
go func() {
|
||||
<-ctx.Done()
|
||||
err = c.Close()
|
||||
if err != nil {
|
||||
return
|
||||
var sshCmd = &cobra.Command{
|
||||
Use: "ssh [flags] [user@]host [command]",
|
||||
Short: "Connect to a NetBird peer via SSH",
|
||||
Long: `Connect to a NetBird peer using SSH with support for port forwarding.
|
||||
|
||||
Port Forwarding:
|
||||
-L [bind_address:]port:host:hostport Local port forwarding
|
||||
-L [bind_address:]port:/path/to/socket Local port forwarding to Unix socket
|
||||
-R [bind_address:]port:host:hostport Remote port forwarding
|
||||
-R [bind_address:]port:/path/to/socket Remote port forwarding to Unix socket
|
||||
|
||||
SSH Options:
|
||||
-p, --port int Remote SSH port (default 22)
|
||||
-u, --user string SSH username
|
||||
--login string SSH username (alias for --user)
|
||||
-t, --tty Force pseudo-terminal allocation
|
||||
--strict-host-key-checking Enable strict host key checking (default: true)
|
||||
-o, --known-hosts string Path to known_hosts file
|
||||
|
||||
Examples:
|
||||
netbird ssh peer-hostname
|
||||
netbird ssh root@peer-hostname
|
||||
netbird ssh --login root peer-hostname
|
||||
netbird ssh peer-hostname ls -la
|
||||
netbird ssh peer-hostname whoami
|
||||
netbird ssh -t peer-hostname tmux # Force PTY for tmux/screen
|
||||
netbird ssh -t peer-hostname sudo -i # Force PTY for interactive sudo
|
||||
netbird ssh -L 8080:localhost:80 peer-hostname # Local port forwarding
|
||||
netbird ssh -R 9090:localhost:3000 peer-hostname # Remote port forwarding
|
||||
netbird ssh -L "*:8080:localhost:80" peer-hostname # Bind to all interfaces
|
||||
netbird ssh -L 8080:/tmp/socket peer-hostname # Unix socket forwarding`,
|
||||
DisableFlagParsing: true,
|
||||
Args: validateSSHArgsWithoutFlagParsing,
|
||||
RunE: sshFn,
|
||||
Aliases: []string{"ssh"},
|
||||
}
|
||||
|
||||
func sshFn(cmd *cobra.Command, args []string) error {
|
||||
for _, arg := range args {
|
||||
if arg == "-h" || arg == "--help" {
|
||||
return cmd.Help()
|
||||
}
|
||||
}
|
||||
|
||||
SetFlagsFromEnvVars(rootCmd)
|
||||
SetFlagsFromEnvVars(cmd)
|
||||
|
||||
cmd.SetOut(cmd.OutOrStdout())
|
||||
|
||||
logOutput := "console"
|
||||
if firstLogFile := util.FindFirstLogPath(logFiles); firstLogFile != "" && firstLogFile != defaultLogFile {
|
||||
logOutput = firstLogFile
|
||||
}
|
||||
if err := util.InitLog(logLevel, logOutput); err != nil {
|
||||
return fmt.Errorf("init log: %w", err)
|
||||
}
|
||||
|
||||
ctx := internal.CtxInitState(cmd.Context())
|
||||
|
||||
sig := make(chan os.Signal, 1)
|
||||
signal.Notify(sig, syscall.SIGTERM, syscall.SIGINT)
|
||||
sshctx, cancel := context.WithCancel(ctx)
|
||||
|
||||
errCh := make(chan error, 1)
|
||||
go func() {
|
||||
if err := runSSH(sshctx, host, cmd); err != nil {
|
||||
errCh <- err
|
||||
}
|
||||
cancel()
|
||||
}()
|
||||
|
||||
err = c.OpenTerminal()
|
||||
if err != nil {
|
||||
select {
|
||||
case <-sig:
|
||||
cancel()
|
||||
<-sshctx.Done()
|
||||
return nil
|
||||
case err := <-errCh:
|
||||
return err
|
||||
case <-sshctx.Done():
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func init() {
|
||||
sshCmd.PersistentFlags().IntVarP(&port, "port", "p", nbssh.DefaultSSHPort, "Sets remote SSH port. Defaults to "+fmt.Sprint(nbssh.DefaultSSHPort))
|
||||
// getEnvOrDefault checks for environment variables with WT_ and NB_ prefixes
|
||||
func getEnvOrDefault(flagName, defaultValue string) string {
|
||||
if envValue := os.Getenv("WT_" + flagName); envValue != "" {
|
||||
return envValue
|
||||
}
|
||||
if envValue := os.Getenv("NB_" + flagName); envValue != "" {
|
||||
return envValue
|
||||
}
|
||||
return defaultValue
|
||||
}
|
||||
|
||||
// resetSSHGlobals sets SSH globals to their default values
|
||||
func resetSSHGlobals() {
|
||||
port = sshserver.DefaultSSHPort
|
||||
username = ""
|
||||
host = ""
|
||||
command = ""
|
||||
localForwards = nil
|
||||
remoteForwards = nil
|
||||
strictHostKeyChecking = true
|
||||
knownHostsFile = ""
|
||||
identityFile = ""
|
||||
}
|
||||
|
||||
// parseCustomSSHFlags extracts -L, -R flags and returns filtered args
|
||||
func parseCustomSSHFlags(args []string) ([]string, []string, []string) {
|
||||
var localForwardFlags []string
|
||||
var remoteForwardFlags []string
|
||||
var filteredArgs []string
|
||||
|
||||
for i := 0; i < len(args); i++ {
|
||||
arg := args[i]
|
||||
switch {
|
||||
case strings.HasPrefix(arg, "-L"):
|
||||
localForwardFlags, i = parseForwardFlag(arg, args, i, localForwardFlags)
|
||||
case strings.HasPrefix(arg, "-R"):
|
||||
remoteForwardFlags, i = parseForwardFlag(arg, args, i, remoteForwardFlags)
|
||||
default:
|
||||
filteredArgs = append(filteredArgs, arg)
|
||||
}
|
||||
}
|
||||
|
||||
return filteredArgs, localForwardFlags, remoteForwardFlags
|
||||
}
|
||||
|
||||
func parseForwardFlag(arg string, args []string, i int, flags []string) ([]string, int) {
|
||||
if arg == "-L" || arg == "-R" {
|
||||
if i+1 < len(args) {
|
||||
flags = append(flags, args[i+1])
|
||||
i++
|
||||
}
|
||||
} else if len(arg) > 2 {
|
||||
flags = append(flags, arg[2:])
|
||||
}
|
||||
return flags, i
|
||||
}
|
||||
|
||||
// extractGlobalFlags parses global flags that were passed before 'ssh' command
|
||||
func extractGlobalFlags(args []string) {
|
||||
sshPos := findSSHCommandPosition(args)
|
||||
if sshPos == -1 {
|
||||
return
|
||||
}
|
||||
|
||||
globalArgs := args[:sshPos]
|
||||
parseGlobalArgs(globalArgs)
|
||||
}
|
||||
|
||||
// findSSHCommandPosition locates the 'ssh' command in the argument list
|
||||
func findSSHCommandPosition(args []string) int {
|
||||
for i, arg := range args {
|
||||
if arg == "ssh" {
|
||||
return i
|
||||
}
|
||||
}
|
||||
return -1
|
||||
}
|
||||
|
||||
const (
|
||||
configFlag = "config"
|
||||
logLevelFlag = "log-level"
|
||||
logFileFlag = "log-file"
|
||||
)
|
||||
|
||||
// parseGlobalArgs processes the global arguments and sets the corresponding variables
|
||||
func parseGlobalArgs(globalArgs []string) {
|
||||
flagHandlers := map[string]func(string){
|
||||
configFlag: func(value string) { configPath = value },
|
||||
logLevelFlag: func(value string) { logLevel = value },
|
||||
logFileFlag: func(value string) {
|
||||
if !slices.Contains(logFiles, value) {
|
||||
logFiles = append(logFiles, value)
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
shortFlags := map[string]string{
|
||||
"c": configFlag,
|
||||
"l": logLevelFlag,
|
||||
}
|
||||
|
||||
for i := 0; i < len(globalArgs); i++ {
|
||||
arg := globalArgs[i]
|
||||
|
||||
if handled, nextIndex := parseFlag(arg, globalArgs, i, flagHandlers, shortFlags); handled {
|
||||
i = nextIndex
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// parseFlag handles generic flag parsing for both long and short forms
|
||||
func parseFlag(arg string, args []string, currentIndex int, flagHandlers map[string]func(string), shortFlags map[string]string) (bool, int) {
|
||||
if parsedValue, found := parseEqualsFormat(arg, flagHandlers, shortFlags); found {
|
||||
flagHandlers[parsedValue.flagName](parsedValue.value)
|
||||
return true, currentIndex
|
||||
}
|
||||
|
||||
if parsedValue, found := parseSpacedFormat(arg, args, currentIndex, flagHandlers, shortFlags); found {
|
||||
flagHandlers[parsedValue.flagName](parsedValue.value)
|
||||
return true, currentIndex + 1
|
||||
}
|
||||
|
||||
return false, currentIndex
|
||||
}
|
||||
|
||||
type parsedFlag struct {
|
||||
flagName string
|
||||
value string
|
||||
}
|
||||
|
||||
// parseEqualsFormat handles --flag=value and -f=value formats
|
||||
func parseEqualsFormat(arg string, flagHandlers map[string]func(string), shortFlags map[string]string) (parsedFlag, bool) {
|
||||
if !strings.Contains(arg, "=") {
|
||||
return parsedFlag{}, false
|
||||
}
|
||||
|
||||
parts := strings.SplitN(arg, "=", 2)
|
||||
if len(parts) != 2 {
|
||||
return parsedFlag{}, false
|
||||
}
|
||||
|
||||
if strings.HasPrefix(parts[0], "--") {
|
||||
flagName := strings.TrimPrefix(parts[0], "--")
|
||||
if _, exists := flagHandlers[flagName]; exists {
|
||||
return parsedFlag{flagName: flagName, value: parts[1]}, true
|
||||
}
|
||||
}
|
||||
|
||||
if strings.HasPrefix(parts[0], "-") && len(parts[0]) == 2 {
|
||||
shortFlag := strings.TrimPrefix(parts[0], "-")
|
||||
if longFlag, exists := shortFlags[shortFlag]; exists {
|
||||
if _, exists := flagHandlers[longFlag]; exists {
|
||||
return parsedFlag{flagName: longFlag, value: parts[1]}, true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return parsedFlag{}, false
|
||||
}
|
||||
|
||||
// parseSpacedFormat handles --flag value and -f value formats
|
||||
func parseSpacedFormat(arg string, args []string, currentIndex int, flagHandlers map[string]func(string), shortFlags map[string]string) (parsedFlag, bool) {
|
||||
if currentIndex+1 >= len(args) {
|
||||
return parsedFlag{}, false
|
||||
}
|
||||
|
||||
if strings.HasPrefix(arg, "--") {
|
||||
flagName := strings.TrimPrefix(arg, "--")
|
||||
if _, exists := flagHandlers[flagName]; exists {
|
||||
return parsedFlag{flagName: flagName, value: args[currentIndex+1]}, true
|
||||
}
|
||||
}
|
||||
|
||||
if strings.HasPrefix(arg, "-") && len(arg) == 2 {
|
||||
shortFlag := strings.TrimPrefix(arg, "-")
|
||||
if longFlag, exists := shortFlags[shortFlag]; exists {
|
||||
if _, exists := flagHandlers[longFlag]; exists {
|
||||
return parsedFlag{flagName: longFlag, value: args[currentIndex+1]}, true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return parsedFlag{}, false
|
||||
}
|
||||
|
||||
// createSSHFlagSet creates and configures the flag set for SSH command parsing
|
||||
// sshFlags contains all SSH-related flags and parameters
|
||||
type sshFlags struct {
|
||||
Port int
|
||||
Username string
|
||||
Login string
|
||||
RequestPTY bool
|
||||
StrictHostKeyChecking bool
|
||||
KnownHostsFile string
|
||||
IdentityFile string
|
||||
SkipCachedToken bool
|
||||
ConfigPath string
|
||||
LogLevel string
|
||||
LocalForwards []string
|
||||
RemoteForwards []string
|
||||
Host string
|
||||
Command string
|
||||
}
|
||||
|
||||
func createSSHFlagSet() (*flag.FlagSet, *sshFlags) {
|
||||
defaultConfigPath := getEnvOrDefault("CONFIG", configPath)
|
||||
defaultLogLevel := getEnvOrDefault("LOG_LEVEL", logLevel)
|
||||
|
||||
fs := flag.NewFlagSet("ssh-flags", flag.ContinueOnError)
|
||||
fs.SetOutput(nil)
|
||||
|
||||
flags := &sshFlags{}
|
||||
|
||||
fs.IntVar(&flags.Port, "p", sshserver.DefaultSSHPort, "SSH port")
|
||||
fs.IntVar(&flags.Port, "port", sshserver.DefaultSSHPort, "SSH port")
|
||||
fs.StringVar(&flags.Username, "u", "", sshUsernameDesc)
|
||||
fs.StringVar(&flags.Username, "user", "", sshUsernameDesc)
|
||||
fs.StringVar(&flags.Login, "login", "", sshUsernameDesc+" (alias for --user)")
|
||||
fs.BoolVar(&flags.RequestPTY, "t", false, "Force pseudo-terminal allocation")
|
||||
fs.BoolVar(&flags.RequestPTY, "tty", false, "Force pseudo-terminal allocation")
|
||||
|
||||
fs.BoolVar(&flags.StrictHostKeyChecking, "strict-host-key-checking", true, "Enable strict host key checking")
|
||||
fs.StringVar(&flags.KnownHostsFile, "o", "", "Path to known_hosts file")
|
||||
fs.StringVar(&flags.KnownHostsFile, "known-hosts", "", "Path to known_hosts file")
|
||||
fs.StringVar(&flags.IdentityFile, "i", "", "Path to SSH private key file")
|
||||
fs.StringVar(&flags.IdentityFile, "identity", "", "Path to SSH private key file")
|
||||
fs.BoolVar(&flags.SkipCachedToken, "no-cache", false, "Skip cached JWT token and force fresh authentication")
|
||||
|
||||
fs.StringVar(&flags.ConfigPath, "c", defaultConfigPath, "Netbird config file location")
|
||||
fs.StringVar(&flags.ConfigPath, "config", defaultConfigPath, "Netbird config file location")
|
||||
fs.StringVar(&flags.LogLevel, "l", defaultLogLevel, "sets Netbird log level")
|
||||
fs.StringVar(&flags.LogLevel, "log-level", defaultLogLevel, "sets Netbird log level")
|
||||
|
||||
return fs, flags
|
||||
}
|
||||
|
||||
func validateSSHArgsWithoutFlagParsing(_ *cobra.Command, args []string) error {
|
||||
if len(args) < 1 {
|
||||
return errors.New(hostArgumentRequired)
|
||||
}
|
||||
|
||||
resetSSHGlobals()
|
||||
|
||||
if len(os.Args) > 2 {
|
||||
extractGlobalFlags(os.Args[1:])
|
||||
}
|
||||
|
||||
filteredArgs, localForwardFlags, remoteForwardFlags := parseCustomSSHFlags(args)
|
||||
|
||||
fs, flags := createSSHFlagSet()
|
||||
|
||||
if err := fs.Parse(filteredArgs); err != nil {
|
||||
if errors.Is(err, flag.ErrHelp) {
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
remaining := fs.Args()
|
||||
if len(remaining) < 1 {
|
||||
return errors.New(hostArgumentRequired)
|
||||
}
|
||||
|
||||
port = flags.Port
|
||||
if flags.Username != "" {
|
||||
username = flags.Username
|
||||
} else if flags.Login != "" {
|
||||
username = flags.Login
|
||||
}
|
||||
|
||||
requestPTY = flags.RequestPTY
|
||||
strictHostKeyChecking = flags.StrictHostKeyChecking
|
||||
knownHostsFile = flags.KnownHostsFile
|
||||
identityFile = flags.IdentityFile
|
||||
skipCachedToken = flags.SkipCachedToken
|
||||
|
||||
if flags.ConfigPath != getEnvOrDefault("CONFIG", configPath) {
|
||||
configPath = flags.ConfigPath
|
||||
}
|
||||
if flags.LogLevel != getEnvOrDefault("LOG_LEVEL", logLevel) {
|
||||
logLevel = flags.LogLevel
|
||||
}
|
||||
|
||||
localForwards = localForwardFlags
|
||||
remoteForwards = remoteForwardFlags
|
||||
|
||||
return parseHostnameAndCommand(remaining)
|
||||
}
|
||||
|
||||
func parseHostnameAndCommand(args []string) error {
|
||||
if len(args) < 1 {
|
||||
return errors.New(hostArgumentRequired)
|
||||
}
|
||||
|
||||
arg := args[0]
|
||||
if strings.Contains(arg, "@") {
|
||||
parts := strings.SplitN(arg, "@", 2)
|
||||
if len(parts) != 2 || parts[0] == "" || parts[1] == "" {
|
||||
return errors.New("invalid user@host format")
|
||||
}
|
||||
if username == "" {
|
||||
username = parts[0]
|
||||
}
|
||||
host = parts[1]
|
||||
} else {
|
||||
host = arg
|
||||
}
|
||||
|
||||
if username == "" {
|
||||
if sudoUser := os.Getenv("SUDO_USER"); sudoUser != "" {
|
||||
username = sudoUser
|
||||
} else if currentUser, err := user.Current(); err == nil {
|
||||
username = currentUser.Username
|
||||
} else {
|
||||
username = "root"
|
||||
}
|
||||
}
|
||||
|
||||
// Everything after hostname becomes the command
|
||||
if len(args) > 1 {
|
||||
command = strings.Join(args[1:], " ")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func runSSH(ctx context.Context, addr string, cmd *cobra.Command) error {
|
||||
target := fmt.Sprintf("%s:%d", addr, port)
|
||||
c, err := sshclient.Dial(ctx, target, username, sshclient.DialOptions{
|
||||
KnownHostsFile: knownHostsFile,
|
||||
IdentityFile: identityFile,
|
||||
DaemonAddr: daemonAddr,
|
||||
SkipCachedToken: skipCachedToken,
|
||||
InsecureSkipVerify: !strictHostKeyChecking,
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
cmd.Printf("Failed to connect to %s@%s\n", username, target)
|
||||
cmd.Printf("\nTroubleshooting steps:\n")
|
||||
cmd.Printf(" 1. Check peer connectivity: netbird status -d\n")
|
||||
cmd.Printf(" 2. Verify SSH server is enabled on the peer\n")
|
||||
cmd.Printf(" 3. Ensure correct hostname/IP is used\n")
|
||||
return fmt.Errorf("dial %s: %w", target, err)
|
||||
}
|
||||
|
||||
sshCtx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
|
||||
go func() {
|
||||
<-sshCtx.Done()
|
||||
if err := c.Close(); err != nil {
|
||||
cmd.Printf("Error closing SSH connection: %v\n", err)
|
||||
}
|
||||
}()
|
||||
|
||||
if err := startPortForwarding(sshCtx, c, cmd); err != nil {
|
||||
return fmt.Errorf("start port forwarding: %w", err)
|
||||
}
|
||||
|
||||
if command != "" {
|
||||
return executeSSHCommand(sshCtx, c, command)
|
||||
}
|
||||
return openSSHTerminal(sshCtx, c)
|
||||
}
|
||||
|
||||
// executeSSHCommand executes a command over SSH.
|
||||
func executeSSHCommand(ctx context.Context, c *sshclient.Client, command string) error {
|
||||
var err error
|
||||
if requestPTY {
|
||||
err = c.ExecuteCommandWithPTY(ctx, command)
|
||||
} else {
|
||||
err = c.ExecuteCommandWithIO(ctx, command)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
|
||||
return nil
|
||||
}
|
||||
|
||||
var exitErr *ssh.ExitError
|
||||
if errors.As(err, &exitErr) {
|
||||
os.Exit(exitErr.ExitStatus())
|
||||
}
|
||||
|
||||
var exitMissingErr *ssh.ExitMissingError
|
||||
if errors.As(err, &exitMissingErr) {
|
||||
log.Debugf("Remote command exited without exit status: %v", err)
|
||||
return nil
|
||||
}
|
||||
|
||||
return fmt.Errorf("execute command: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// openSSHTerminal opens an interactive SSH terminal.
|
||||
func openSSHTerminal(ctx context.Context, c *sshclient.Client) error {
|
||||
if err := c.OpenTerminal(ctx); err != nil {
|
||||
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
|
||||
return nil
|
||||
}
|
||||
|
||||
var exitMissingErr *ssh.ExitMissingError
|
||||
if errors.As(err, &exitMissingErr) {
|
||||
log.Debugf("Remote terminal exited without exit status: %v", err)
|
||||
return nil
|
||||
}
|
||||
|
||||
return fmt.Errorf("open terminal: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// startPortForwarding starts local and remote port forwarding based on command line flags
|
||||
func startPortForwarding(ctx context.Context, c *sshclient.Client, cmd *cobra.Command) error {
|
||||
for _, forward := range localForwards {
|
||||
if err := parseAndStartLocalForward(ctx, c, forward, cmd); err != nil {
|
||||
return fmt.Errorf("local port forward %s: %w", forward, err)
|
||||
}
|
||||
}
|
||||
|
||||
for _, forward := range remoteForwards {
|
||||
if err := parseAndStartRemoteForward(ctx, c, forward, cmd); err != nil {
|
||||
return fmt.Errorf("remote port forward %s: %w", forward, err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// parseAndStartLocalForward parses and starts a local port forward (-L)
|
||||
func parseAndStartLocalForward(ctx context.Context, c *sshclient.Client, forward string, cmd *cobra.Command) error {
|
||||
localAddr, remoteAddr, err := parsePortForwardSpec(forward)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
cmd.Printf("Local port forwarding: %s -> %s\n", localAddr, remoteAddr)
|
||||
|
||||
go func() {
|
||||
if err := c.LocalPortForward(ctx, localAddr, remoteAddr); err != nil && !errors.Is(err, context.Canceled) {
|
||||
cmd.Printf("Local port forward error: %v\n", err)
|
||||
}
|
||||
}()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// parseAndStartRemoteForward parses and starts a remote port forward (-R)
|
||||
func parseAndStartRemoteForward(ctx context.Context, c *sshclient.Client, forward string, cmd *cobra.Command) error {
|
||||
remoteAddr, localAddr, err := parsePortForwardSpec(forward)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
cmd.Printf("Remote port forwarding: %s -> %s\n", remoteAddr, localAddr)
|
||||
|
||||
go func() {
|
||||
if err := c.RemotePortForward(ctx, remoteAddr, localAddr); err != nil && !errors.Is(err, context.Canceled) {
|
||||
cmd.Printf("Remote port forward error: %v\n", err)
|
||||
}
|
||||
}()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// parsePortForwardSpec parses port forward specifications like "8080:localhost:80" or "[::1]:8080:localhost:80".
|
||||
// Also supports Unix sockets like "8080:/tmp/socket" or "127.0.0.1:8080:/tmp/socket".
|
||||
func parsePortForwardSpec(spec string) (string, string, error) {
|
||||
// Support formats:
|
||||
// port:host:hostport -> localhost:port -> host:hostport
|
||||
// host:port:host:hostport -> host:port -> host:hostport
|
||||
// [host]:port:host:hostport -> [host]:port -> host:hostport
|
||||
// port:unix_socket_path -> localhost:port -> unix_socket_path
|
||||
// host:port:unix_socket_path -> host:port -> unix_socket_path
|
||||
|
||||
if strings.HasPrefix(spec, "[") && strings.Contains(spec, "]:") {
|
||||
return parseIPv6ForwardSpec(spec)
|
||||
}
|
||||
|
||||
parts := strings.Split(spec, ":")
|
||||
if len(parts) < 2 {
|
||||
return "", "", fmt.Errorf("invalid port forward specification: %s (expected format: [local_host:]local_port:remote_target)", spec)
|
||||
}
|
||||
|
||||
switch len(parts) {
|
||||
case 2:
|
||||
return parseTwoPartForwardSpec(parts, spec)
|
||||
case 3:
|
||||
return parseThreePartForwardSpec(parts)
|
||||
case 4:
|
||||
return parseFourPartForwardSpec(parts)
|
||||
default:
|
||||
return "", "", fmt.Errorf("invalid port forward specification: %s", spec)
|
||||
}
|
||||
}
|
||||
|
||||
// parseTwoPartForwardSpec handles "port:unix_socket" format.
|
||||
func parseTwoPartForwardSpec(parts []string, spec string) (string, string, error) {
|
||||
if isUnixSocket(parts[1]) {
|
||||
localAddr := "localhost:" + parts[0]
|
||||
remoteAddr := parts[1]
|
||||
return localAddr, remoteAddr, nil
|
||||
}
|
||||
return "", "", fmt.Errorf("invalid port forward specification: %s (expected format: [local_host:]local_port:remote_host:remote_port or [local_host:]local_port:unix_socket)", spec)
|
||||
}
|
||||
|
||||
// parseThreePartForwardSpec handles "port:host:hostport" or "host:port:unix_socket" formats.
|
||||
func parseThreePartForwardSpec(parts []string) (string, string, error) {
|
||||
if isUnixSocket(parts[2]) {
|
||||
localHost := normalizeLocalHost(parts[0])
|
||||
localAddr := localHost + ":" + parts[1]
|
||||
remoteAddr := parts[2]
|
||||
return localAddr, remoteAddr, nil
|
||||
}
|
||||
localAddr := "localhost:" + parts[0]
|
||||
remoteAddr := parts[1] + ":" + parts[2]
|
||||
return localAddr, remoteAddr, nil
|
||||
}
|
||||
|
||||
// parseFourPartForwardSpec handles "host:port:host:hostport" format.
|
||||
func parseFourPartForwardSpec(parts []string) (string, string, error) {
|
||||
localHost := normalizeLocalHost(parts[0])
|
||||
localAddr := localHost + ":" + parts[1]
|
||||
remoteAddr := parts[2] + ":" + parts[3]
|
||||
return localAddr, remoteAddr, nil
|
||||
}
|
||||
|
||||
// parseIPv6ForwardSpec handles "[host]:port:host:hostport" format.
|
||||
func parseIPv6ForwardSpec(spec string) (string, string, error) {
|
||||
idx := strings.Index(spec, "]:")
|
||||
if idx == -1 {
|
||||
return "", "", fmt.Errorf("invalid IPv6 port forward specification: %s", spec)
|
||||
}
|
||||
|
||||
ipv6Host := spec[:idx+1]
|
||||
remaining := spec[idx+2:]
|
||||
|
||||
parts := strings.Split(remaining, ":")
|
||||
if len(parts) != 3 {
|
||||
return "", "", fmt.Errorf("invalid IPv6 port forward specification: %s (expected [ipv6]:port:host:hostport)", spec)
|
||||
}
|
||||
|
||||
localAddr := ipv6Host + ":" + parts[0]
|
||||
remoteAddr := parts[1] + ":" + parts[2]
|
||||
return localAddr, remoteAddr, nil
|
||||
}
|
||||
|
||||
// isUnixSocket checks if a path is a Unix socket path.
|
||||
func isUnixSocket(path string) bool {
|
||||
return strings.HasPrefix(path, "/") || strings.HasPrefix(path, "./")
|
||||
}
|
||||
|
||||
// normalizeLocalHost converts "*" to "0.0.0.0" for binding to all interfaces.
|
||||
func normalizeLocalHost(host string) string {
|
||||
if host == "*" {
|
||||
return "0.0.0.0"
|
||||
}
|
||||
return host
|
||||
}
|
||||
|
||||
var sshProxyCmd = &cobra.Command{
|
||||
Use: "proxy <host> <port>",
|
||||
Short: "Internal SSH proxy for native SSH client integration",
|
||||
Long: "Internal command used by SSH ProxyCommand to handle JWT authentication",
|
||||
Hidden: true,
|
||||
Args: cobra.ExactArgs(2),
|
||||
RunE: sshProxyFn,
|
||||
}
|
||||
|
||||
func sshProxyFn(cmd *cobra.Command, args []string) error {
|
||||
logOutput := "console"
|
||||
if firstLogFile := util.FindFirstLogPath(logFiles); firstLogFile != "" && firstLogFile != defaultLogFile {
|
||||
logOutput = firstLogFile
|
||||
}
|
||||
if err := util.InitLog(logLevel, logOutput); err != nil {
|
||||
return fmt.Errorf("init log: %w", err)
|
||||
}
|
||||
|
||||
host := args[0]
|
||||
portStr := args[1]
|
||||
|
||||
port, err := strconv.Atoi(portStr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid port: %s", portStr)
|
||||
}
|
||||
|
||||
proxy, err := sshproxy.New(daemonAddr, host, port, cmd.ErrOrStderr())
|
||||
if err != nil {
|
||||
return fmt.Errorf("create SSH proxy: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
if err := proxy.Close(); err != nil {
|
||||
log.Debugf("close SSH proxy: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
if err := proxy.Connect(cmd.Context()); err != nil {
|
||||
return fmt.Errorf("SSH proxy: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
var sshDetectCmd = &cobra.Command{
|
||||
Use: "detect <host> <port>",
|
||||
Short: "Detect if a host is running NetBird SSH",
|
||||
Long: "Internal command used by SSH Match exec to detect NetBird SSH servers. Exit codes: 0=JWT, 1=no-JWT, 2=regular SSH",
|
||||
Hidden: true,
|
||||
Args: cobra.ExactArgs(2),
|
||||
RunE: sshDetectFn,
|
||||
}
|
||||
|
||||
func sshDetectFn(cmd *cobra.Command, args []string) error {
|
||||
if err := util.InitLog(logLevel, "console"); err != nil {
|
||||
os.Exit(detection.ServerTypeRegular.ExitCode())
|
||||
}
|
||||
|
||||
host := args[0]
|
||||
portStr := args[1]
|
||||
|
||||
port, err := strconv.Atoi(portStr)
|
||||
if err != nil {
|
||||
os.Exit(detection.ServerTypeRegular.ExitCode())
|
||||
}
|
||||
|
||||
dialer := &net.Dialer{Timeout: detection.Timeout}
|
||||
serverType, err := detection.DetectSSHServerType(cmd.Context(), dialer, host, port)
|
||||
if err != nil {
|
||||
os.Exit(detection.ServerTypeRegular.ExitCode())
|
||||
}
|
||||
|
||||
os.Exit(serverType.ExitCode())
|
||||
return nil
|
||||
}
|
||||
|
||||
74
client/cmd/ssh_exec_unix.go
Normal file
74
client/cmd/ssh_exec_unix.go
Normal file
@@ -0,0 +1,74 @@
|
||||
//go:build unix
|
||||
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
|
||||
sshserver "github.com/netbirdio/netbird/client/ssh/server"
|
||||
)
|
||||
|
||||
var (
|
||||
sshExecUID uint32
|
||||
sshExecGID uint32
|
||||
sshExecGroups []uint
|
||||
sshExecWorkingDir string
|
||||
sshExecShell string
|
||||
sshExecCommand string
|
||||
sshExecPTY bool
|
||||
)
|
||||
|
||||
// sshExecCmd represents the hidden ssh exec subcommand for privilege dropping
|
||||
var sshExecCmd = &cobra.Command{
|
||||
Use: "exec",
|
||||
Short: "Internal SSH execution with privilege dropping (hidden)",
|
||||
Hidden: true,
|
||||
RunE: runSSHExec,
|
||||
}
|
||||
|
||||
func init() {
|
||||
sshExecCmd.Flags().Uint32Var(&sshExecUID, "uid", 0, "Target user ID")
|
||||
sshExecCmd.Flags().Uint32Var(&sshExecGID, "gid", 0, "Target group ID")
|
||||
sshExecCmd.Flags().UintSliceVar(&sshExecGroups, "groups", nil, "Supplementary group IDs (can be repeated)")
|
||||
sshExecCmd.Flags().StringVar(&sshExecWorkingDir, "working-dir", "", "Working directory")
|
||||
sshExecCmd.Flags().StringVar(&sshExecShell, "shell", "/bin/sh", "Shell to execute")
|
||||
sshExecCmd.Flags().BoolVar(&sshExecPTY, "pty", false, "Request PTY (will fail as executor doesn't support PTY)")
|
||||
sshExecCmd.Flags().StringVar(&sshExecCommand, "cmd", "", "Command to execute")
|
||||
|
||||
if err := sshExecCmd.MarkFlagRequired("uid"); err != nil {
|
||||
_, _ = fmt.Fprintf(os.Stderr, "failed to mark uid flag as required: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
if err := sshExecCmd.MarkFlagRequired("gid"); err != nil {
|
||||
_, _ = fmt.Fprintf(os.Stderr, "failed to mark gid flag as required: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
sshCmd.AddCommand(sshExecCmd)
|
||||
}
|
||||
|
||||
// runSSHExec handles the SSH exec subcommand execution.
|
||||
func runSSHExec(cmd *cobra.Command, _ []string) error {
|
||||
privilegeDropper := sshserver.NewPrivilegeDropper()
|
||||
|
||||
var groups []uint32
|
||||
for _, groupInt := range sshExecGroups {
|
||||
groups = append(groups, uint32(groupInt))
|
||||
}
|
||||
|
||||
config := sshserver.ExecutorConfig{
|
||||
UID: sshExecUID,
|
||||
GID: sshExecGID,
|
||||
Groups: groups,
|
||||
WorkingDir: sshExecWorkingDir,
|
||||
Shell: sshExecShell,
|
||||
Command: sshExecCommand,
|
||||
PTY: sshExecPTY,
|
||||
}
|
||||
|
||||
privilegeDropper.ExecuteWithPrivilegeDrop(cmd.Context(), config)
|
||||
return nil
|
||||
}
|
||||
94
client/cmd/ssh_sftp_unix.go
Normal file
94
client/cmd/ssh_sftp_unix.go
Normal file
@@ -0,0 +1,94 @@
|
||||
//go:build unix
|
||||
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"io"
|
||||
"os"
|
||||
|
||||
"github.com/pkg/sftp"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/spf13/cobra"
|
||||
|
||||
sshserver "github.com/netbirdio/netbird/client/ssh/server"
|
||||
)
|
||||
|
||||
var (
|
||||
sftpUID uint32
|
||||
sftpGID uint32
|
||||
sftpGroupsInt []uint
|
||||
sftpWorkingDir string
|
||||
)
|
||||
|
||||
var sshSftpCmd = &cobra.Command{
|
||||
Use: "sftp",
|
||||
Short: "SFTP server with privilege dropping (internal use)",
|
||||
Hidden: true,
|
||||
RunE: sftpMain,
|
||||
}
|
||||
|
||||
func init() {
|
||||
sshSftpCmd.Flags().Uint32Var(&sftpUID, "uid", 0, "Target user ID")
|
||||
sshSftpCmd.Flags().Uint32Var(&sftpGID, "gid", 0, "Target group ID")
|
||||
sshSftpCmd.Flags().UintSliceVar(&sftpGroupsInt, "groups", nil, "Supplementary group IDs (can be repeated)")
|
||||
sshSftpCmd.Flags().StringVar(&sftpWorkingDir, "working-dir", "", "Working directory")
|
||||
}
|
||||
|
||||
func sftpMain(cmd *cobra.Command, _ []string) error {
|
||||
privilegeDropper := sshserver.NewPrivilegeDropper()
|
||||
|
||||
var groups []uint32
|
||||
for _, groupInt := range sftpGroupsInt {
|
||||
groups = append(groups, uint32(groupInt))
|
||||
}
|
||||
|
||||
config := sshserver.ExecutorConfig{
|
||||
UID: sftpUID,
|
||||
GID: sftpGID,
|
||||
Groups: groups,
|
||||
WorkingDir: sftpWorkingDir,
|
||||
Shell: "",
|
||||
Command: "",
|
||||
}
|
||||
|
||||
log.Tracef("dropping privileges for SFTP to UID=%d, GID=%d, groups=%v", config.UID, config.GID, config.Groups)
|
||||
|
||||
if err := privilegeDropper.DropPrivileges(config.UID, config.GID, config.Groups); err != nil {
|
||||
cmd.PrintErrf("privilege drop failed: %v\n", err)
|
||||
os.Exit(sshserver.ExitCodePrivilegeDropFail)
|
||||
}
|
||||
|
||||
if config.WorkingDir != "" {
|
||||
if err := os.Chdir(config.WorkingDir); err != nil {
|
||||
cmd.PrintErrf("failed to change to working directory %s: %v\n", config.WorkingDir, err)
|
||||
}
|
||||
}
|
||||
|
||||
sftpServer, err := sftp.NewServer(struct {
|
||||
io.Reader
|
||||
io.WriteCloser
|
||||
}{
|
||||
Reader: os.Stdin,
|
||||
WriteCloser: os.Stdout,
|
||||
})
|
||||
if err != nil {
|
||||
cmd.PrintErrf("SFTP server creation failed: %v\n", err)
|
||||
os.Exit(sshserver.ExitCodeShellExecFail)
|
||||
}
|
||||
|
||||
log.Tracef("starting SFTP server with dropped privileges")
|
||||
if err := sftpServer.Serve(); err != nil && !errors.Is(err, io.EOF) {
|
||||
cmd.PrintErrf("SFTP server error: %v\n", err)
|
||||
if closeErr := sftpServer.Close(); closeErr != nil {
|
||||
cmd.PrintErrf("SFTP server close error: %v\n", closeErr)
|
||||
}
|
||||
os.Exit(sshserver.ExitCodeShellExecFail)
|
||||
}
|
||||
|
||||
if closeErr := sftpServer.Close(); closeErr != nil {
|
||||
cmd.PrintErrf("SFTP server close error: %v\n", closeErr)
|
||||
}
|
||||
os.Exit(sshserver.ExitCodeSuccess)
|
||||
return nil
|
||||
}
|
||||
94
client/cmd/ssh_sftp_windows.go
Normal file
94
client/cmd/ssh_sftp_windows.go
Normal file
@@ -0,0 +1,94 @@
|
||||
//go:build windows
|
||||
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"os/user"
|
||||
"strings"
|
||||
|
||||
"github.com/pkg/sftp"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/spf13/cobra"
|
||||
|
||||
sshserver "github.com/netbirdio/netbird/client/ssh/server"
|
||||
)
|
||||
|
||||
var (
|
||||
sftpWorkingDir string
|
||||
windowsUsername string
|
||||
windowsDomain string
|
||||
)
|
||||
|
||||
var sshSftpCmd = &cobra.Command{
|
||||
Use: "sftp",
|
||||
Short: "SFTP server with user switching for Windows (internal use)",
|
||||
Hidden: true,
|
||||
RunE: sftpMain,
|
||||
}
|
||||
|
||||
func init() {
|
||||
sshSftpCmd.Flags().StringVar(&sftpWorkingDir, "working-dir", "", "Working directory")
|
||||
sshSftpCmd.Flags().StringVar(&windowsUsername, "windows-username", "", "Windows username for user switching")
|
||||
sshSftpCmd.Flags().StringVar(&windowsDomain, "windows-domain", "", "Windows domain for user switching")
|
||||
}
|
||||
|
||||
func sftpMain(cmd *cobra.Command, _ []string) error {
|
||||
return sftpMainDirect(cmd)
|
||||
}
|
||||
|
||||
func sftpMainDirect(cmd *cobra.Command) error {
|
||||
currentUser, err := user.Current()
|
||||
if err != nil {
|
||||
cmd.PrintErrf("failed to get current user: %v\n", err)
|
||||
os.Exit(sshserver.ExitCodeValidationFail)
|
||||
}
|
||||
|
||||
if windowsUsername != "" {
|
||||
expectedUsername := windowsUsername
|
||||
if windowsDomain != "" {
|
||||
expectedUsername = fmt.Sprintf(`%s\%s`, windowsDomain, windowsUsername)
|
||||
}
|
||||
if !strings.EqualFold(currentUser.Username, expectedUsername) && !strings.EqualFold(currentUser.Username, windowsUsername) {
|
||||
cmd.PrintErrf("user switching failed\n")
|
||||
os.Exit(sshserver.ExitCodeValidationFail)
|
||||
}
|
||||
}
|
||||
|
||||
log.Debugf("SFTP process running as: %s (UID: %s, Name: %s)", currentUser.Username, currentUser.Uid, currentUser.Name)
|
||||
|
||||
if sftpWorkingDir != "" {
|
||||
if err := os.Chdir(sftpWorkingDir); err != nil {
|
||||
cmd.PrintErrf("failed to change to working directory %s: %v\n", sftpWorkingDir, err)
|
||||
}
|
||||
}
|
||||
|
||||
sftpServer, err := sftp.NewServer(struct {
|
||||
io.Reader
|
||||
io.WriteCloser
|
||||
}{
|
||||
Reader: os.Stdin,
|
||||
WriteCloser: os.Stdout,
|
||||
})
|
||||
if err != nil {
|
||||
cmd.PrintErrf("SFTP server creation failed: %v\n", err)
|
||||
os.Exit(sshserver.ExitCodeShellExecFail)
|
||||
}
|
||||
|
||||
log.Debugf("starting SFTP server")
|
||||
exitCode := sshserver.ExitCodeSuccess
|
||||
if err := sftpServer.Serve(); err != nil && !errors.Is(err, io.EOF) {
|
||||
cmd.PrintErrf("SFTP server error: %v\n", err)
|
||||
exitCode = sshserver.ExitCodeShellExecFail
|
||||
}
|
||||
|
||||
if err := sftpServer.Close(); err != nil {
|
||||
log.Debugf("SFTP server close error: %v", err)
|
||||
}
|
||||
|
||||
os.Exit(exitCode)
|
||||
return nil
|
||||
}
|
||||
717
client/cmd/ssh_test.go
Normal file
717
client/cmd/ssh_test.go
Normal file
@@ -0,0 +1,717 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestSSHCommand_FlagParsing(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
args []string
|
||||
expectedHost string
|
||||
expectedUser string
|
||||
expectedPort int
|
||||
expectedCmd string
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
name: "basic host",
|
||||
args: []string{"hostname"},
|
||||
expectedHost: "hostname",
|
||||
expectedUser: "",
|
||||
expectedPort: 22,
|
||||
expectedCmd: "",
|
||||
},
|
||||
{
|
||||
name: "user@host format",
|
||||
args: []string{"user@hostname"},
|
||||
expectedHost: "hostname",
|
||||
expectedUser: "user",
|
||||
expectedPort: 22,
|
||||
expectedCmd: "",
|
||||
},
|
||||
{
|
||||
name: "host with command",
|
||||
args: []string{"hostname", "echo", "hello"},
|
||||
expectedHost: "hostname",
|
||||
expectedUser: "",
|
||||
expectedPort: 22,
|
||||
expectedCmd: "echo hello",
|
||||
},
|
||||
{
|
||||
name: "command with flags should be preserved",
|
||||
args: []string{"hostname", "ls", "-la", "/tmp"},
|
||||
expectedHost: "hostname",
|
||||
expectedUser: "",
|
||||
expectedPort: 22,
|
||||
expectedCmd: "ls -la /tmp",
|
||||
},
|
||||
{
|
||||
name: "double dash separator",
|
||||
args: []string{"hostname", "--", "ls", "-la"},
|
||||
expectedHost: "hostname",
|
||||
expectedUser: "",
|
||||
expectedPort: 22,
|
||||
expectedCmd: "-- ls -la",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Reset global variables
|
||||
host = ""
|
||||
username = ""
|
||||
port = 22
|
||||
command = ""
|
||||
|
||||
// Mock command for testing
|
||||
cmd := sshCmd
|
||||
cmd.SetArgs(tt.args)
|
||||
|
||||
err := validateSSHArgsWithoutFlagParsing(cmd, tt.args)
|
||||
|
||||
if tt.expectError {
|
||||
assert.Error(t, err)
|
||||
return
|
||||
}
|
||||
|
||||
require.NoError(t, err, "SSH args validation should succeed for valid input")
|
||||
assert.Equal(t, tt.expectedHost, host, "host mismatch")
|
||||
if tt.expectedUser != "" {
|
||||
assert.Equal(t, tt.expectedUser, username, "username mismatch")
|
||||
}
|
||||
assert.Equal(t, tt.expectedPort, port, "port mismatch")
|
||||
assert.Equal(t, tt.expectedCmd, command, "command mismatch")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSSHCommand_FlagConflictPrevention(t *testing.T) {
|
||||
// Test that SSH flags don't conflict with command flags
|
||||
tests := []struct {
|
||||
name string
|
||||
args []string
|
||||
expectedCmd string
|
||||
description string
|
||||
}{
|
||||
{
|
||||
name: "ls with -la flags",
|
||||
args: []string{"hostname", "ls", "-la"},
|
||||
expectedCmd: "ls -la",
|
||||
description: "ls flags should be passed to remote command",
|
||||
},
|
||||
{
|
||||
name: "grep with -r flag",
|
||||
args: []string{"hostname", "grep", "-r", "pattern", "/path"},
|
||||
expectedCmd: "grep -r pattern /path",
|
||||
description: "grep flags should be passed to remote command",
|
||||
},
|
||||
{
|
||||
name: "ps with aux flags",
|
||||
args: []string{"hostname", "ps", "aux"},
|
||||
expectedCmd: "ps aux",
|
||||
description: "ps flags should be passed to remote command",
|
||||
},
|
||||
{
|
||||
name: "command with double dash",
|
||||
args: []string{"hostname", "--", "ls", "-la"},
|
||||
expectedCmd: "-- ls -la",
|
||||
description: "double dash should be preserved in command",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Reset global variables
|
||||
host = ""
|
||||
username = ""
|
||||
port = 22
|
||||
command = ""
|
||||
|
||||
cmd := sshCmd
|
||||
err := validateSSHArgsWithoutFlagParsing(cmd, tt.args)
|
||||
require.NoError(t, err, "SSH args validation should succeed for valid input")
|
||||
|
||||
assert.Equal(t, tt.expectedCmd, command, tt.description)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSSHCommand_NonInteractiveExecution(t *testing.T) {
|
||||
// Test that commands with arguments should execute the command and exit,
|
||||
// not drop to an interactive shell
|
||||
tests := []struct {
|
||||
name string
|
||||
args []string
|
||||
expectedCmd string
|
||||
shouldExit bool
|
||||
description string
|
||||
}{
|
||||
{
|
||||
name: "ls command should execute and exit",
|
||||
args: []string{"hostname", "ls"},
|
||||
expectedCmd: "ls",
|
||||
shouldExit: true,
|
||||
description: "ls command should execute and exit, not drop to shell",
|
||||
},
|
||||
{
|
||||
name: "ls with flags should execute and exit",
|
||||
args: []string{"hostname", "ls", "-la"},
|
||||
expectedCmd: "ls -la",
|
||||
shouldExit: true,
|
||||
description: "ls with flags should execute and exit, not drop to shell",
|
||||
},
|
||||
{
|
||||
name: "pwd command should execute and exit",
|
||||
args: []string{"hostname", "pwd"},
|
||||
expectedCmd: "pwd",
|
||||
shouldExit: true,
|
||||
description: "pwd command should execute and exit, not drop to shell",
|
||||
},
|
||||
{
|
||||
name: "echo command should execute and exit",
|
||||
args: []string{"hostname", "echo", "hello"},
|
||||
expectedCmd: "echo hello",
|
||||
shouldExit: true,
|
||||
description: "echo command should execute and exit, not drop to shell",
|
||||
},
|
||||
{
|
||||
name: "no command should open shell",
|
||||
args: []string{"hostname"},
|
||||
expectedCmd: "",
|
||||
shouldExit: false,
|
||||
description: "no command should open interactive shell",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Reset global variables
|
||||
host = ""
|
||||
username = ""
|
||||
port = 22
|
||||
command = ""
|
||||
|
||||
cmd := sshCmd
|
||||
err := validateSSHArgsWithoutFlagParsing(cmd, tt.args)
|
||||
require.NoError(t, err, "SSH args validation should succeed for valid input")
|
||||
|
||||
assert.Equal(t, tt.expectedCmd, command, tt.description)
|
||||
|
||||
// When command is present, it should execute the command and exit
|
||||
// When command is empty, it should open interactive shell
|
||||
hasCommand := command != ""
|
||||
assert.Equal(t, tt.shouldExit, hasCommand, "Command presence should match expected behavior")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSSHCommand_FlagHandling(t *testing.T) {
|
||||
// Test that flags after hostname are not parsed by netbird but passed to SSH command
|
||||
tests := []struct {
|
||||
name string
|
||||
args []string
|
||||
expectedHost string
|
||||
expectedCmd string
|
||||
expectError bool
|
||||
description string
|
||||
}{
|
||||
{
|
||||
name: "ls with -la flag should not be parsed by netbird",
|
||||
args: []string{"debian2", "ls", "-la"},
|
||||
expectedHost: "debian2",
|
||||
expectedCmd: "ls -la",
|
||||
expectError: false,
|
||||
description: "ls -la should be passed as SSH command, not parsed as netbird flags",
|
||||
},
|
||||
{
|
||||
name: "command with netbird-like flags should be passed through",
|
||||
args: []string{"hostname", "echo", "--help"},
|
||||
expectedHost: "hostname",
|
||||
expectedCmd: "echo --help",
|
||||
expectError: false,
|
||||
description: "--help should be passed to echo, not parsed by netbird",
|
||||
},
|
||||
{
|
||||
name: "command with -p flag should not conflict with SSH port flag",
|
||||
args: []string{"hostname", "ps", "-p", "1234"},
|
||||
expectedHost: "hostname",
|
||||
expectedCmd: "ps -p 1234",
|
||||
expectError: false,
|
||||
description: "ps -p should be passed to ps command, not parsed as port",
|
||||
},
|
||||
{
|
||||
name: "tar with flags should be passed through",
|
||||
args: []string{"hostname", "tar", "-czf", "backup.tar.gz", "/home"},
|
||||
expectedHost: "hostname",
|
||||
expectedCmd: "tar -czf backup.tar.gz /home",
|
||||
expectError: false,
|
||||
description: "tar flags should be passed to tar command",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Reset global variables
|
||||
host = ""
|
||||
username = ""
|
||||
port = 22
|
||||
command = ""
|
||||
|
||||
cmd := sshCmd
|
||||
err := validateSSHArgsWithoutFlagParsing(cmd, tt.args)
|
||||
|
||||
if tt.expectError {
|
||||
assert.Error(t, err)
|
||||
return
|
||||
}
|
||||
|
||||
require.NoError(t, err, "SSH args validation should succeed for valid input")
|
||||
assert.Equal(t, tt.expectedHost, host, "host mismatch")
|
||||
assert.Equal(t, tt.expectedCmd, command, tt.description)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSSHCommand_RegressionFlagParsing(t *testing.T) {
|
||||
// Regression test for the specific issue: "sudo ./netbird ssh debian2 ls -la"
|
||||
// should not parse -la as netbird flags but pass them to the SSH command
|
||||
tests := []struct {
|
||||
name string
|
||||
args []string
|
||||
expectedHost string
|
||||
expectedCmd string
|
||||
expectError bool
|
||||
description string
|
||||
}{
|
||||
{
|
||||
name: "original issue: ls -la should be preserved",
|
||||
args: []string{"debian2", "ls", "-la"},
|
||||
expectedHost: "debian2",
|
||||
expectedCmd: "ls -la",
|
||||
expectError: false,
|
||||
description: "The original failing case should now work",
|
||||
},
|
||||
{
|
||||
name: "ls -l should be preserved",
|
||||
args: []string{"hostname", "ls", "-l"},
|
||||
expectedHost: "hostname",
|
||||
expectedCmd: "ls -l",
|
||||
expectError: false,
|
||||
description: "Single letter flags should be preserved",
|
||||
},
|
||||
{
|
||||
name: "SSH port flag should work",
|
||||
args: []string{"-p", "2222", "hostname", "ls", "-la"},
|
||||
expectedHost: "hostname",
|
||||
expectedCmd: "ls -la",
|
||||
expectError: false,
|
||||
description: "SSH -p flag should be parsed, command flags preserved",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Reset global variables
|
||||
host = ""
|
||||
username = ""
|
||||
port = 22
|
||||
command = ""
|
||||
|
||||
cmd := sshCmd
|
||||
err := validateSSHArgsWithoutFlagParsing(cmd, tt.args)
|
||||
|
||||
if tt.expectError {
|
||||
assert.Error(t, err)
|
||||
return
|
||||
}
|
||||
|
||||
require.NoError(t, err, "SSH args validation should succeed for valid input")
|
||||
assert.Equal(t, tt.expectedHost, host, "host mismatch")
|
||||
assert.Equal(t, tt.expectedCmd, command, tt.description)
|
||||
|
||||
// Check port for the test case with -p flag
|
||||
if len(tt.args) > 0 && tt.args[0] == "-p" {
|
||||
assert.Equal(t, 2222, port, "port should be parsed from -p flag")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSSHCommand_PortForwardingFlagParsing(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
args []string
|
||||
expectedHost string
|
||||
expectedLocal []string
|
||||
expectedRemote []string
|
||||
expectError bool
|
||||
description string
|
||||
}{
|
||||
{
|
||||
name: "local port forwarding -L",
|
||||
args: []string{"-L", "8080:localhost:80", "hostname"},
|
||||
expectedHost: "hostname",
|
||||
expectedLocal: []string{"8080:localhost:80"},
|
||||
expectedRemote: []string{},
|
||||
expectError: false,
|
||||
description: "Single -L flag should be parsed correctly",
|
||||
},
|
||||
{
|
||||
name: "remote port forwarding -R",
|
||||
args: []string{"-R", "8080:localhost:80", "hostname"},
|
||||
expectedHost: "hostname",
|
||||
expectedLocal: []string{},
|
||||
expectedRemote: []string{"8080:localhost:80"},
|
||||
expectError: false,
|
||||
description: "Single -R flag should be parsed correctly",
|
||||
},
|
||||
{
|
||||
name: "multiple local port forwards",
|
||||
args: []string{"-L", "8080:localhost:80", "-L", "9090:localhost:443", "hostname"},
|
||||
expectedHost: "hostname",
|
||||
expectedLocal: []string{"8080:localhost:80", "9090:localhost:443"},
|
||||
expectedRemote: []string{},
|
||||
expectError: false,
|
||||
description: "Multiple -L flags should be parsed correctly",
|
||||
},
|
||||
{
|
||||
name: "multiple remote port forwards",
|
||||
args: []string{"-R", "8080:localhost:80", "-R", "9090:localhost:443", "hostname"},
|
||||
expectedHost: "hostname",
|
||||
expectedLocal: []string{},
|
||||
expectedRemote: []string{"8080:localhost:80", "9090:localhost:443"},
|
||||
expectError: false,
|
||||
description: "Multiple -R flags should be parsed correctly",
|
||||
},
|
||||
{
|
||||
name: "mixed local and remote forwards",
|
||||
args: []string{"-L", "8080:localhost:80", "-R", "9090:localhost:443", "hostname"},
|
||||
expectedHost: "hostname",
|
||||
expectedLocal: []string{"8080:localhost:80"},
|
||||
expectedRemote: []string{"9090:localhost:443"},
|
||||
expectError: false,
|
||||
description: "Mixed -L and -R flags should be parsed correctly",
|
||||
},
|
||||
{
|
||||
name: "port forwarding with bind address",
|
||||
args: []string{"-L", "127.0.0.1:8080:localhost:80", "hostname"},
|
||||
expectedHost: "hostname",
|
||||
expectedLocal: []string{"127.0.0.1:8080:localhost:80"},
|
||||
expectedRemote: []string{},
|
||||
expectError: false,
|
||||
description: "Port forwarding with bind address should work",
|
||||
},
|
||||
{
|
||||
name: "port forwarding with command",
|
||||
args: []string{"-L", "8080:localhost:80", "hostname", "ls", "-la"},
|
||||
expectedHost: "hostname",
|
||||
expectedLocal: []string{"8080:localhost:80"},
|
||||
expectedRemote: []string{},
|
||||
expectError: false,
|
||||
description: "Port forwarding with command should work",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Reset global variables
|
||||
host = ""
|
||||
username = ""
|
||||
port = 22
|
||||
command = ""
|
||||
localForwards = nil
|
||||
remoteForwards = nil
|
||||
|
||||
cmd := sshCmd
|
||||
err := validateSSHArgsWithoutFlagParsing(cmd, tt.args)
|
||||
|
||||
if tt.expectError {
|
||||
assert.Error(t, err)
|
||||
return
|
||||
}
|
||||
|
||||
require.NoError(t, err, "SSH args validation should succeed for valid input")
|
||||
assert.Equal(t, tt.expectedHost, host, "host mismatch")
|
||||
// Handle nil vs empty slice comparison
|
||||
if len(tt.expectedLocal) == 0 {
|
||||
assert.True(t, len(localForwards) == 0, tt.description+" - local forwards should be empty")
|
||||
} else {
|
||||
assert.Equal(t, tt.expectedLocal, localForwards, tt.description+" - local forwards")
|
||||
}
|
||||
if len(tt.expectedRemote) == 0 {
|
||||
assert.True(t, len(remoteForwards) == 0, tt.description+" - remote forwards should be empty")
|
||||
} else {
|
||||
assert.Equal(t, tt.expectedRemote, remoteForwards, tt.description+" - remote forwards")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParsePortForward(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
spec string
|
||||
expectedLocal string
|
||||
expectedRemote string
|
||||
expectError bool
|
||||
description string
|
||||
}{
|
||||
{
|
||||
name: "simple port forward",
|
||||
spec: "8080:localhost:80",
|
||||
expectedLocal: "localhost:8080",
|
||||
expectedRemote: "localhost:80",
|
||||
expectError: false,
|
||||
description: "Simple port:host:port format should work",
|
||||
},
|
||||
{
|
||||
name: "port forward with bind address",
|
||||
spec: "127.0.0.1:8080:localhost:80",
|
||||
expectedLocal: "127.0.0.1:8080",
|
||||
expectedRemote: "localhost:80",
|
||||
expectError: false,
|
||||
description: "bind_address:port:host:port format should work",
|
||||
},
|
||||
{
|
||||
name: "port forward to different host",
|
||||
spec: "8080:example.com:443",
|
||||
expectedLocal: "localhost:8080",
|
||||
expectedRemote: "example.com:443",
|
||||
expectError: false,
|
||||
description: "Forwarding to different host should work",
|
||||
},
|
||||
{
|
||||
name: "port forward with IPv6 (needs bracket support)",
|
||||
spec: "::1:8080:localhost:80",
|
||||
expectError: true,
|
||||
description: "IPv6 without brackets fails as expected (feature to implement)",
|
||||
},
|
||||
{
|
||||
name: "invalid format - too few parts",
|
||||
spec: "8080:localhost",
|
||||
expectError: true,
|
||||
description: "Invalid format with too few parts should fail",
|
||||
},
|
||||
{
|
||||
name: "invalid format - too many parts",
|
||||
spec: "127.0.0.1:8080:localhost:80:extra",
|
||||
expectError: true,
|
||||
description: "Invalid format with too many parts should fail",
|
||||
},
|
||||
{
|
||||
name: "empty spec",
|
||||
spec: "",
|
||||
expectError: true,
|
||||
description: "Empty spec should fail",
|
||||
},
|
||||
{
|
||||
name: "unix socket local forward",
|
||||
spec: "8080:/tmp/socket",
|
||||
expectedLocal: "localhost:8080",
|
||||
expectedRemote: "/tmp/socket",
|
||||
expectError: false,
|
||||
description: "Unix socket forwarding should work",
|
||||
},
|
||||
{
|
||||
name: "unix socket with bind address",
|
||||
spec: "127.0.0.1:8080:/tmp/socket",
|
||||
expectedLocal: "127.0.0.1:8080",
|
||||
expectedRemote: "/tmp/socket",
|
||||
expectError: false,
|
||||
description: "Unix socket with bind address should work",
|
||||
},
|
||||
{
|
||||
name: "wildcard bind all interfaces",
|
||||
spec: "*:8080:localhost:80",
|
||||
expectedLocal: "0.0.0.0:8080",
|
||||
expectedRemote: "localhost:80",
|
||||
expectError: false,
|
||||
description: "Wildcard * should bind to all interfaces (0.0.0.0)",
|
||||
},
|
||||
{
|
||||
name: "wildcard for port only",
|
||||
spec: "8080:*:80",
|
||||
expectedLocal: "localhost:8080",
|
||||
expectedRemote: "*:80",
|
||||
expectError: false,
|
||||
description: "Wildcard in remote host should be preserved",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
localAddr, remoteAddr, err := parsePortForwardSpec(tt.spec)
|
||||
|
||||
if tt.expectError {
|
||||
assert.Error(t, err, tt.description)
|
||||
return
|
||||
}
|
||||
|
||||
require.NoError(t, err, tt.description)
|
||||
assert.Equal(t, tt.expectedLocal, localAddr, tt.description+" - local address")
|
||||
assert.Equal(t, tt.expectedRemote, remoteAddr, tt.description+" - remote address")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSSHCommand_IntegrationPortForwarding(t *testing.T) {
|
||||
// Integration test for port forwarding with the actual SSH command implementation
|
||||
tests := []struct {
|
||||
name string
|
||||
args []string
|
||||
expectedHost string
|
||||
expectedLocal []string
|
||||
expectedRemote []string
|
||||
expectedCmd string
|
||||
description string
|
||||
}{
|
||||
{
|
||||
name: "local forward with command",
|
||||
args: []string{"-L", "8080:localhost:80", "hostname", "echo", "test"},
|
||||
expectedHost: "hostname",
|
||||
expectedLocal: []string{"8080:localhost:80"},
|
||||
expectedRemote: []string{},
|
||||
expectedCmd: "echo test",
|
||||
description: "Local forwarding should work with commands",
|
||||
},
|
||||
{
|
||||
name: "remote forward with command",
|
||||
args: []string{"-R", "8080:localhost:80", "hostname", "ls", "-la"},
|
||||
expectedHost: "hostname",
|
||||
expectedLocal: []string{},
|
||||
expectedRemote: []string{"8080:localhost:80"},
|
||||
expectedCmd: "ls -la",
|
||||
description: "Remote forwarding should work with commands",
|
||||
},
|
||||
{
|
||||
name: "multiple forwards with user and command",
|
||||
args: []string{"-L", "8080:localhost:80", "-R", "9090:localhost:443", "user@hostname", "ps", "aux"},
|
||||
expectedHost: "hostname",
|
||||
expectedLocal: []string{"8080:localhost:80"},
|
||||
expectedRemote: []string{"9090:localhost:443"},
|
||||
expectedCmd: "ps aux",
|
||||
description: "Complex case with multiple forwards, user, and command",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Reset global variables
|
||||
host = ""
|
||||
username = ""
|
||||
port = 22
|
||||
command = ""
|
||||
localForwards = nil
|
||||
remoteForwards = nil
|
||||
|
||||
cmd := sshCmd
|
||||
err := validateSSHArgsWithoutFlagParsing(cmd, tt.args)
|
||||
require.NoError(t, err, "SSH args validation should succeed for valid input")
|
||||
|
||||
assert.Equal(t, tt.expectedHost, host, "host mismatch")
|
||||
// Handle nil vs empty slice comparison
|
||||
if len(tt.expectedLocal) == 0 {
|
||||
assert.True(t, len(localForwards) == 0, tt.description+" - local forwards should be empty")
|
||||
} else {
|
||||
assert.Equal(t, tt.expectedLocal, localForwards, tt.description+" - local forwards")
|
||||
}
|
||||
if len(tt.expectedRemote) == 0 {
|
||||
assert.True(t, len(remoteForwards) == 0, tt.description+" - remote forwards should be empty")
|
||||
} else {
|
||||
assert.Equal(t, tt.expectedRemote, remoteForwards, tt.description+" - remote forwards")
|
||||
}
|
||||
assert.Equal(t, tt.expectedCmd, command, tt.description+" - command")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSSHCommand_ParameterIsolation(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
args []string
|
||||
expectedCmd string
|
||||
}{
|
||||
{
|
||||
name: "cmd flag passed as command",
|
||||
args: []string{"hostname", "--cmd", "echo test"},
|
||||
expectedCmd: "--cmd echo test",
|
||||
},
|
||||
{
|
||||
name: "uid flag passed as command",
|
||||
args: []string{"hostname", "--uid", "1000"},
|
||||
expectedCmd: "--uid 1000",
|
||||
},
|
||||
{
|
||||
name: "shell flag passed as command",
|
||||
args: []string{"hostname", "--shell", "/bin/bash"},
|
||||
expectedCmd: "--shell /bin/bash",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
host = ""
|
||||
username = ""
|
||||
port = 22
|
||||
command = ""
|
||||
|
||||
err := validateSSHArgsWithoutFlagParsing(sshCmd, tt.args)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, "hostname", host)
|
||||
assert.Equal(t, tt.expectedCmd, command)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSSHCommand_InvalidFlagRejection(t *testing.T) {
|
||||
// Test that invalid flags are properly rejected and not misinterpreted as hostnames
|
||||
tests := []struct {
|
||||
name string
|
||||
args []string
|
||||
description string
|
||||
}{
|
||||
{
|
||||
name: "invalid long flag before hostname",
|
||||
args: []string{"--invalid-flag", "hostname"},
|
||||
description: "Invalid flag should return parse error, not treat flag as hostname",
|
||||
},
|
||||
{
|
||||
name: "invalid short flag before hostname",
|
||||
args: []string{"-x", "hostname"},
|
||||
description: "Invalid short flag should return parse error",
|
||||
},
|
||||
{
|
||||
name: "invalid flag with value before hostname",
|
||||
args: []string{"--invalid-option=value", "hostname"},
|
||||
description: "Invalid flag with value should return parse error",
|
||||
},
|
||||
{
|
||||
name: "typo in known flag",
|
||||
args: []string{"--por", "2222", "hostname"},
|
||||
description: "Typo in flag name should return parse error (not silently ignored)",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Reset global variables
|
||||
host = ""
|
||||
username = ""
|
||||
port = 22
|
||||
command = ""
|
||||
|
||||
err := validateSSHArgsWithoutFlagParsing(sshCmd, tt.args)
|
||||
|
||||
// Should return an error for invalid flags
|
||||
assert.Error(t, err, tt.description)
|
||||
|
||||
// Should not have set host to the invalid flag
|
||||
assert.NotEqual(t, tt.args[0], host, "Invalid flag should not be interpreted as hostname")
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -109,7 +109,7 @@ func statusFunc(cmd *cobra.Command, args []string) error {
|
||||
case yamlFlag:
|
||||
statusOutputString, err = nbstatus.ParseToYAML(outputInformationHolder)
|
||||
default:
|
||||
statusOutputString = nbstatus.ParseGeneralSummary(outputInformationHolder, false, false, false)
|
||||
statusOutputString = nbstatus.ParseGeneralSummary(outputInformationHolder, false, false, false, false)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
"google.golang.org/grpc"
|
||||
|
||||
"github.com/netbirdio/management-integrations/integrations"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
|
||||
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
|
||||
@@ -117,7 +118,7 @@ func startManagement(t *testing.T, config *config.Config, testFile string) (*grp
|
||||
requestBuffer := mgmt.NewAccountRequestBuffer(ctx, store)
|
||||
networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, mgmt.MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock())
|
||||
|
||||
accountManager, err := mgmt.BuildManager(context.Background(), store, networkMapController, nil, "", eventStore, nil, false, iv, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false)
|
||||
accountManager, err := mgmt.BuildManager(context.Background(), config, store, networkMapController, nil, "", eventStore, nil, false, iv, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
@@ -355,6 +355,25 @@ func setupSetConfigReq(customDNSAddressConverted []byte, cmd *cobra.Command, pro
|
||||
if cmd.Flag(serverSSHAllowedFlag).Changed {
|
||||
req.ServerSSHAllowed = &serverSSHAllowed
|
||||
}
|
||||
if cmd.Flag(enableSSHRootFlag).Changed {
|
||||
req.EnableSSHRoot = &enableSSHRoot
|
||||
}
|
||||
if cmd.Flag(enableSSHSFTPFlag).Changed {
|
||||
req.EnableSSHSFTP = &enableSSHSFTP
|
||||
}
|
||||
if cmd.Flag(enableSSHLocalPortForwardFlag).Changed {
|
||||
req.EnableSSHLocalPortForwarding = &enableSSHLocalPortForward
|
||||
}
|
||||
if cmd.Flag(enableSSHRemotePortForwardFlag).Changed {
|
||||
req.EnableSSHRemotePortForwarding = &enableSSHRemotePortForward
|
||||
}
|
||||
if cmd.Flag(disableSSHAuthFlag).Changed {
|
||||
req.DisableSSHAuth = &disableSSHAuth
|
||||
}
|
||||
if cmd.Flag(sshJWTCacheTTLFlag).Changed {
|
||||
sshJWTCacheTTL32 := int32(sshJWTCacheTTL)
|
||||
req.SshJWTCacheTTL = &sshJWTCacheTTL32
|
||||
}
|
||||
if cmd.Flag(interfaceNameFlag).Changed {
|
||||
if err := parseInterfaceName(interfaceName); err != nil {
|
||||
log.Errorf("parse interface name: %v", err)
|
||||
@@ -439,6 +458,30 @@ func setupConfig(customDNSAddressConverted []byte, cmd *cobra.Command, configFil
|
||||
ic.ServerSSHAllowed = &serverSSHAllowed
|
||||
}
|
||||
|
||||
if cmd.Flag(enableSSHRootFlag).Changed {
|
||||
ic.EnableSSHRoot = &enableSSHRoot
|
||||
}
|
||||
|
||||
if cmd.Flag(enableSSHSFTPFlag).Changed {
|
||||
ic.EnableSSHSFTP = &enableSSHSFTP
|
||||
}
|
||||
|
||||
if cmd.Flag(enableSSHLocalPortForwardFlag).Changed {
|
||||
ic.EnableSSHLocalPortForwarding = &enableSSHLocalPortForward
|
||||
}
|
||||
|
||||
if cmd.Flag(enableSSHRemotePortForwardFlag).Changed {
|
||||
ic.EnableSSHRemotePortForwarding = &enableSSHRemotePortForward
|
||||
}
|
||||
|
||||
if cmd.Flag(disableSSHAuthFlag).Changed {
|
||||
ic.DisableSSHAuth = &disableSSHAuth
|
||||
}
|
||||
|
||||
if cmd.Flag(sshJWTCacheTTLFlag).Changed {
|
||||
ic.SSHJWTCacheTTL = &sshJWTCacheTTL
|
||||
}
|
||||
|
||||
if cmd.Flag(interfaceNameFlag).Changed {
|
||||
if err := parseInterfaceName(interfaceName); err != nil {
|
||||
return nil, err
|
||||
@@ -539,6 +582,31 @@ func setupLoginRequest(providedSetupKey string, customDNSAddressConverted []byte
|
||||
loginRequest.ServerSSHAllowed = &serverSSHAllowed
|
||||
}
|
||||
|
||||
if cmd.Flag(enableSSHRootFlag).Changed {
|
||||
loginRequest.EnableSSHRoot = &enableSSHRoot
|
||||
}
|
||||
|
||||
if cmd.Flag(enableSSHSFTPFlag).Changed {
|
||||
loginRequest.EnableSSHSFTP = &enableSSHSFTP
|
||||
}
|
||||
|
||||
if cmd.Flag(enableSSHLocalPortForwardFlag).Changed {
|
||||
loginRequest.EnableSSHLocalPortForwarding = &enableSSHLocalPortForward
|
||||
}
|
||||
|
||||
if cmd.Flag(enableSSHRemotePortForwardFlag).Changed {
|
||||
loginRequest.EnableSSHRemotePortForwarding = &enableSSHRemotePortForward
|
||||
}
|
||||
|
||||
if cmd.Flag(disableSSHAuthFlag).Changed {
|
||||
loginRequest.DisableSSHAuth = &disableSSHAuth
|
||||
}
|
||||
|
||||
if cmd.Flag(sshJWTCacheTTLFlag).Changed {
|
||||
sshJWTCacheTTL32 := int32(sshJWTCacheTTL)
|
||||
loginRequest.SshJWTCacheTTL = &sshJWTCacheTTL32
|
||||
}
|
||||
|
||||
if cmd.Flag(disableAutoConnectFlag).Changed {
|
||||
loginRequest.DisableAutoConnect = &autoConnectDisabled
|
||||
}
|
||||
|
||||
@@ -18,12 +18,16 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||
sshcommon "github.com/netbirdio/netbird/client/ssh"
|
||||
"github.com/netbirdio/netbird/client/system"
|
||||
)
|
||||
|
||||
var ErrClientAlreadyStarted = errors.New("client already started")
|
||||
var ErrClientNotStarted = errors.New("client not started")
|
||||
var ErrConfigNotInitialized = errors.New("config not initialized")
|
||||
var (
|
||||
ErrClientAlreadyStarted = errors.New("client already started")
|
||||
ErrClientNotStarted = errors.New("client not started")
|
||||
ErrEngineNotStarted = errors.New("engine not started")
|
||||
ErrConfigNotInitialized = errors.New("config not initialized")
|
||||
)
|
||||
|
||||
// Client manages a netbird embedded client instance.
|
||||
type Client struct {
|
||||
@@ -238,17 +242,9 @@ func (c *Client) GetConfig() (profilemanager.Config, error) {
|
||||
// Dial dials a network address in the netbird network.
|
||||
// Not applicable if the userspace networking mode is disabled.
|
||||
func (c *Client) Dial(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
c.mu.Lock()
|
||||
connect := c.connect
|
||||
if connect == nil {
|
||||
c.mu.Unlock()
|
||||
return nil, ErrClientNotStarted
|
||||
}
|
||||
c.mu.Unlock()
|
||||
|
||||
engine := connect.Engine()
|
||||
if engine == nil {
|
||||
return nil, errors.New("engine not started")
|
||||
engine, err := c.getEngine()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
nsnet, err := engine.GetNet()
|
||||
@@ -259,6 +255,11 @@ func (c *Client) Dial(ctx context.Context, network, address string) (net.Conn, e
|
||||
return nsnet.DialContext(ctx, network, address)
|
||||
}
|
||||
|
||||
// DialContext dials a network address in the netbird network with context
|
||||
func (c *Client) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
return c.Dial(ctx, network, address)
|
||||
}
|
||||
|
||||
// ListenTCP listens on the given address in the netbird network.
|
||||
// Not applicable if the userspace networking mode is disabled.
|
||||
func (c *Client) ListenTCP(address string) (net.Listener, error) {
|
||||
@@ -314,18 +315,47 @@ func (c *Client) NewHTTPClient() *http.Client {
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Client) getNet() (*wgnetstack.Net, netip.Addr, error) {
|
||||
// VerifySSHHostKey verifies an SSH host key against stored peer keys.
|
||||
// Returns nil if the key matches, ErrPeerNotFound if peer is not in network,
|
||||
// ErrNoStoredKey if peer has no stored key, or an error for verification failures.
|
||||
func (c *Client) VerifySSHHostKey(peerAddress string, key []byte) error {
|
||||
engine, err := c.getEngine()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
storedKey, found := engine.GetPeerSSHKey(peerAddress)
|
||||
if !found {
|
||||
return sshcommon.ErrPeerNotFound
|
||||
}
|
||||
|
||||
return sshcommon.VerifyHostKey(storedKey, key, peerAddress)
|
||||
}
|
||||
|
||||
// getEngine safely retrieves the engine from the client with proper locking.
|
||||
// Returns ErrClientNotStarted if the client is not started.
|
||||
// Returns ErrEngineNotStarted if the engine is not available.
|
||||
func (c *Client) getEngine() (*internal.Engine, error) {
|
||||
c.mu.Lock()
|
||||
connect := c.connect
|
||||
if connect == nil {
|
||||
c.mu.Unlock()
|
||||
return nil, netip.Addr{}, errors.New("client not started")
|
||||
}
|
||||
c.mu.Unlock()
|
||||
|
||||
if connect == nil {
|
||||
return nil, ErrClientNotStarted
|
||||
}
|
||||
|
||||
engine := connect.Engine()
|
||||
if engine == nil {
|
||||
return nil, netip.Addr{}, errors.New("engine not started")
|
||||
return nil, ErrEngineNotStarted
|
||||
}
|
||||
|
||||
return engine, nil
|
||||
}
|
||||
|
||||
func (c *Client) getNet() (*wgnetstack.Net, netip.Addr, error) {
|
||||
engine, err := c.getEngine()
|
||||
if err != nil {
|
||||
return nil, netip.Addr{}, err
|
||||
}
|
||||
|
||||
addr, err := engine.Address()
|
||||
|
||||
@@ -35,6 +35,12 @@ const (
|
||||
ipTCPHeaderMinSize = 40
|
||||
)
|
||||
|
||||
// serviceKey represents a protocol/port combination for netstack service registry
|
||||
type serviceKey struct {
|
||||
protocol gopacket.LayerType
|
||||
port uint16
|
||||
}
|
||||
|
||||
const (
|
||||
// EnvDisableConntrack disables the stateful filter, replies to outbound traffic won't be allowed.
|
||||
EnvDisableConntrack = "NB_DISABLE_CONNTRACK"
|
||||
@@ -59,12 +65,6 @@ const (
|
||||
|
||||
var errNatNotSupported = errors.New("nat not supported with userspace firewall")
|
||||
|
||||
// serviceKey represents a protocol/port combination for netstack service registry
|
||||
type serviceKey struct {
|
||||
protocol gopacket.LayerType
|
||||
port uint16
|
||||
}
|
||||
|
||||
// RuleSet is a set of rules grouped by a string key
|
||||
type RuleSet map[string]PeerRule
|
||||
|
||||
|
||||
@@ -22,6 +22,7 @@ import (
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
"github.com/netbirdio/netbird/client/internal/netflow"
|
||||
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
)
|
||||
|
||||
@@ -1114,3 +1115,138 @@ func generateTCPPacketWithFlags(tb testing.TB, srcIP, dstIP net.IP, srcPort, dst
|
||||
|
||||
return buf.Bytes()
|
||||
}
|
||||
|
||||
func TestShouldForward(t *testing.T) {
|
||||
// Set up test addresses
|
||||
wgIP := netip.MustParseAddr("100.10.0.1")
|
||||
otherIP := netip.MustParseAddr("100.10.0.2")
|
||||
|
||||
// Create test manager with mock interface
|
||||
ifaceMock := &IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}
|
||||
// Set the mock to return our test WG IP
|
||||
ifaceMock.AddressFunc = func() wgaddr.Address {
|
||||
return wgaddr.Address{IP: wgIP, Network: netip.PrefixFrom(wgIP, 24)}
|
||||
}
|
||||
|
||||
manager, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU)
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
require.NoError(t, manager.Close(nil))
|
||||
}()
|
||||
|
||||
// Helper to create decoder with TCP packet
|
||||
createTCPDecoder := func(dstPort uint16) *decoder {
|
||||
ipv4 := &layers.IPv4{
|
||||
Version: 4,
|
||||
Protocol: layers.IPProtocolTCP,
|
||||
SrcIP: net.ParseIP("192.168.1.100"),
|
||||
DstIP: wgIP.AsSlice(),
|
||||
}
|
||||
tcp := &layers.TCP{
|
||||
SrcPort: 54321,
|
||||
DstPort: layers.TCPPort(dstPort),
|
||||
}
|
||||
|
||||
err := tcp.SetNetworkLayerForChecksum(ipv4)
|
||||
require.NoError(t, err)
|
||||
|
||||
buf := gopacket.NewSerializeBuffer()
|
||||
opts := gopacket.SerializeOptions{ComputeChecksums: true, FixLengths: true}
|
||||
err = gopacket.SerializeLayers(buf, opts, ipv4, tcp, gopacket.Payload("test"))
|
||||
require.NoError(t, err)
|
||||
|
||||
d := &decoder{
|
||||
decoded: []gopacket.LayerType{},
|
||||
}
|
||||
d.parser = gopacket.NewDecodingLayerParser(
|
||||
layers.LayerTypeIPv4,
|
||||
&d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp,
|
||||
)
|
||||
d.parser.IgnoreUnsupported = true
|
||||
|
||||
err = d.parser.DecodeLayers(buf.Bytes(), &d.decoded)
|
||||
require.NoError(t, err)
|
||||
|
||||
return d
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
localForwarding bool
|
||||
netstack bool
|
||||
dstIP netip.Addr
|
||||
serviceRegistered bool
|
||||
servicePort uint16
|
||||
expected bool
|
||||
description string
|
||||
}{
|
||||
{
|
||||
name: "no local forwarding",
|
||||
localForwarding: false,
|
||||
netstack: true,
|
||||
dstIP: wgIP,
|
||||
expected: false,
|
||||
description: "should never forward when local forwarding disabled",
|
||||
},
|
||||
{
|
||||
name: "traffic to other local interface",
|
||||
localForwarding: true,
|
||||
netstack: false,
|
||||
dstIP: otherIP,
|
||||
expected: true,
|
||||
description: "should forward traffic to our other local interfaces (not NetBird IP)",
|
||||
},
|
||||
{
|
||||
name: "traffic to NetBird IP, no netstack",
|
||||
localForwarding: true,
|
||||
netstack: false,
|
||||
dstIP: wgIP,
|
||||
expected: false,
|
||||
description: "should send to netstack listeners (final return false path)",
|
||||
},
|
||||
{
|
||||
name: "traffic to our IP, netstack mode, no service",
|
||||
localForwarding: true,
|
||||
netstack: true,
|
||||
dstIP: wgIP,
|
||||
expected: true,
|
||||
description: "should forward when in netstack mode with no matching service",
|
||||
},
|
||||
{
|
||||
name: "traffic to our IP, netstack mode, with service",
|
||||
localForwarding: true,
|
||||
netstack: true,
|
||||
dstIP: wgIP,
|
||||
serviceRegistered: true,
|
||||
servicePort: 22,
|
||||
expected: false,
|
||||
description: "should send to netstack listeners when service is registered",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Configure manager
|
||||
manager.localForwarding = tt.localForwarding
|
||||
manager.netstack = tt.netstack
|
||||
|
||||
// Register service if needed
|
||||
if tt.serviceRegistered {
|
||||
manager.RegisterNetstackService(nftypes.TCP, tt.servicePort)
|
||||
defer manager.UnregisterNetstackService(nftypes.TCP, tt.servicePort)
|
||||
}
|
||||
|
||||
// Create decoder for the test
|
||||
decoder := createTCPDecoder(tt.servicePort)
|
||||
if !tt.serviceRegistered {
|
||||
decoder = createTCPDecoder(8080) // Use non-registered port
|
||||
}
|
||||
|
||||
// Test the method
|
||||
result := manager.shouldForward(decoder, tt.dstIP)
|
||||
require.Equal(t, tt.expected, result, tt.description)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
85
client/firewall/uspfilter/nat_stateful_test.go
Normal file
85
client/firewall/uspfilter/nat_stateful_test.go
Normal file
@@ -0,0 +1,85 @@
|
||||
package uspfilter
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"testing"
|
||||
|
||||
"github.com/google/gopacket/layers"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
)
|
||||
|
||||
// TestPortDNATBasic tests basic port DNAT functionality
|
||||
func TestPortDNATBasic(t *testing.T) {
|
||||
manager, err := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}, false, flowLogger, iface.DefaultMTU)
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
require.NoError(t, manager.Close(nil))
|
||||
}()
|
||||
|
||||
// Define peer IPs
|
||||
peerA := netip.MustParseAddr("100.10.0.50")
|
||||
peerB := netip.MustParseAddr("100.10.0.51")
|
||||
|
||||
// Add SSH port redirection rule for peer B (the target)
|
||||
err = manager.addPortRedirection(peerB, layers.LayerTypeTCP, 22, 22022)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Scenario: Peer A connects to Peer B on port 22 (should get NAT)
|
||||
packetAtoB := generateDNATTestPacket(t, peerA, peerB, layers.IPProtocolTCP, 54321, 22)
|
||||
d := parsePacket(t, packetAtoB)
|
||||
translatedAtoB := manager.translateInboundPortDNAT(packetAtoB, d, peerA, peerB)
|
||||
require.True(t, translatedAtoB, "Peer A to Peer B should be translated (NAT applied)")
|
||||
|
||||
// Verify port was translated to 22022
|
||||
d = parsePacket(t, packetAtoB)
|
||||
require.Equal(t, uint16(22022), uint16(d.tcp.DstPort), "Port should be rewritten to 22022")
|
||||
|
||||
// Scenario: Return traffic from Peer B to Peer A should NOT be translated
|
||||
// (prevents double NAT - original port stored in conntrack)
|
||||
returnPacket := generateDNATTestPacket(t, peerB, peerA, layers.IPProtocolTCP, 22022, 54321)
|
||||
d2 := parsePacket(t, returnPacket)
|
||||
translatedReturn := manager.translateInboundPortDNAT(returnPacket, d2, peerB, peerA)
|
||||
require.False(t, translatedReturn, "Return traffic from same IP should not be translated")
|
||||
}
|
||||
|
||||
// TestPortDNATMultipleRules tests multiple port DNAT rules
|
||||
func TestPortDNATMultipleRules(t *testing.T) {
|
||||
manager, err := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}, false, flowLogger, iface.DefaultMTU)
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
require.NoError(t, manager.Close(nil))
|
||||
}()
|
||||
|
||||
// Define peer IPs
|
||||
peerA := netip.MustParseAddr("100.10.0.50")
|
||||
peerB := netip.MustParseAddr("100.10.0.51")
|
||||
|
||||
// Add SSH port redirection rules for both peers
|
||||
err = manager.addPortRedirection(peerA, layers.LayerTypeTCP, 22, 22022)
|
||||
require.NoError(t, err)
|
||||
err = manager.addPortRedirection(peerB, layers.LayerTypeTCP, 22, 22022)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Test traffic to peer B gets translated
|
||||
packetToB := generateDNATTestPacket(t, peerA, peerB, layers.IPProtocolTCP, 54321, 22)
|
||||
d1 := parsePacket(t, packetToB)
|
||||
translatedToB := manager.translateInboundPortDNAT(packetToB, d1, peerA, peerB)
|
||||
require.True(t, translatedToB, "Traffic to peer B should be translated")
|
||||
d1 = parsePacket(t, packetToB)
|
||||
require.Equal(t, uint16(22022), uint16(d1.tcp.DstPort), "Port should be 22022")
|
||||
|
||||
// Test traffic to peer A gets translated
|
||||
packetToA := generateDNATTestPacket(t, peerB, peerA, layers.IPProtocolTCP, 54322, 22)
|
||||
d2 := parsePacket(t, packetToA)
|
||||
translatedToA := manager.translateInboundPortDNAT(packetToA, d2, peerB, peerA)
|
||||
require.True(t, translatedToA, "Traffic to peer A should be translated")
|
||||
d2 = parsePacket(t, packetToA)
|
||||
require.Equal(t, uint16(22022), uint16(d2.tcp.DstPort), "Port should be 22022")
|
||||
}
|
||||
@@ -17,7 +17,6 @@ import (
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/internal/acl/id"
|
||||
"github.com/netbirdio/netbird/client/ssh"
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
|
||||
)
|
||||
@@ -83,22 +82,6 @@ func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap, dnsRout
|
||||
func (d *DefaultManager) applyPeerACLs(networkMap *mgmProto.NetworkMap) {
|
||||
rules := networkMap.FirewallRules
|
||||
|
||||
enableSSH := networkMap.PeerConfig != nil &&
|
||||
networkMap.PeerConfig.SshConfig != nil &&
|
||||
networkMap.PeerConfig.SshConfig.SshEnabled
|
||||
|
||||
// If SSH enabled, add default firewall rule which accepts connection to any peer
|
||||
// in the network by SSH (TCP port defined by ssh.DefaultSSHPort).
|
||||
if enableSSH {
|
||||
rules = append(rules, &mgmProto.FirewallRule{
|
||||
PeerIP: "0.0.0.0",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
Port: strconv.Itoa(ssh.DefaultSSHPort),
|
||||
})
|
||||
}
|
||||
|
||||
// if we got empty rules list but management not set networkMap.FirewallRulesIsEmpty flag
|
||||
// we have old version of management without rules handling, we should allow all traffic
|
||||
if len(networkMap.FirewallRules) == 0 && !networkMap.FirewallRulesIsEmpty {
|
||||
|
||||
@@ -272,70 +272,3 @@ func TestPortInfoEmpty(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefaultManagerEnableSSHRules(t *testing.T) {
|
||||
networkMap := &mgmProto.NetworkMap{
|
||||
PeerConfig: &mgmProto.PeerConfig{
|
||||
SshConfig: &mgmProto.SSHConfig{
|
||||
SshEnabled: true,
|
||||
},
|
||||
},
|
||||
RemotePeers: []*mgmProto.RemotePeerConfig{
|
||||
{AllowedIps: []string{"10.93.0.1"}},
|
||||
{AllowedIps: []string{"10.93.0.2"}},
|
||||
{AllowedIps: []string{"10.93.0.3"}},
|
||||
},
|
||||
FirewallRules: []*mgmProto.FirewallRule{
|
||||
{
|
||||
PeerIP: "10.93.0.1",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
},
|
||||
{
|
||||
PeerIP: "10.93.0.2",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
},
|
||||
{
|
||||
PeerIP: "10.93.0.3",
|
||||
Direction: mgmProto.RuleDirection_OUT,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_UDP,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
ifaceMock := mocks.NewMockIFaceMapper(ctrl)
|
||||
ifaceMock.EXPECT().IsUserspaceBind().Return(true).AnyTimes()
|
||||
ifaceMock.EXPECT().SetFilter(gomock.Any())
|
||||
network := netip.MustParsePrefix("172.0.0.1/32")
|
||||
|
||||
ifaceMock.EXPECT().Name().Return("lo").AnyTimes()
|
||||
ifaceMock.EXPECT().Address().Return(wgaddr.Address{
|
||||
IP: network.Addr(),
|
||||
Network: network,
|
||||
}).AnyTimes()
|
||||
ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes()
|
||||
|
||||
fw, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false, iface.DefaultMTU)
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
err = fw.Close(nil)
|
||||
require.NoError(t, err)
|
||||
}()
|
||||
|
||||
acl := NewDefaultManager(fw)
|
||||
|
||||
acl.ApplyFiltering(networkMap, false)
|
||||
|
||||
expectedRules := 3
|
||||
if fw.IsStateful() {
|
||||
expectedRules = 3 // 2 inbound rules + SSH rule
|
||||
}
|
||||
assert.Equal(t, expectedRules, len(acl.peerRulesPairs))
|
||||
}
|
||||
|
||||
@@ -416,20 +416,25 @@ func createEngineConfig(key wgtypes.Key, config *profilemanager.Config, peerConf
|
||||
nm = *config.NetworkMonitor
|
||||
}
|
||||
engineConf := &EngineConfig{
|
||||
WgIfaceName: config.WgIface,
|
||||
WgAddr: peerConfig.Address,
|
||||
IFaceBlackList: config.IFaceBlackList,
|
||||
DisableIPv6Discovery: config.DisableIPv6Discovery,
|
||||
WgPrivateKey: key,
|
||||
WgPort: config.WgPort,
|
||||
NetworkMonitor: nm,
|
||||
SSHKey: []byte(config.SSHKey),
|
||||
NATExternalIPs: config.NATExternalIPs,
|
||||
CustomDNSAddress: config.CustomDNSAddress,
|
||||
RosenpassEnabled: config.RosenpassEnabled,
|
||||
RosenpassPermissive: config.RosenpassPermissive,
|
||||
ServerSSHAllowed: util.ReturnBoolWithDefaultTrue(config.ServerSSHAllowed),
|
||||
DNSRouteInterval: config.DNSRouteInterval,
|
||||
WgIfaceName: config.WgIface,
|
||||
WgAddr: peerConfig.Address,
|
||||
IFaceBlackList: config.IFaceBlackList,
|
||||
DisableIPv6Discovery: config.DisableIPv6Discovery,
|
||||
WgPrivateKey: key,
|
||||
WgPort: config.WgPort,
|
||||
NetworkMonitor: nm,
|
||||
SSHKey: []byte(config.SSHKey),
|
||||
NATExternalIPs: config.NATExternalIPs,
|
||||
CustomDNSAddress: config.CustomDNSAddress,
|
||||
RosenpassEnabled: config.RosenpassEnabled,
|
||||
RosenpassPermissive: config.RosenpassPermissive,
|
||||
ServerSSHAllowed: util.ReturnBoolWithDefaultTrue(config.ServerSSHAllowed),
|
||||
EnableSSHRoot: config.EnableSSHRoot,
|
||||
EnableSSHSFTP: config.EnableSSHSFTP,
|
||||
EnableSSHLocalPortForwarding: config.EnableSSHLocalPortForwarding,
|
||||
EnableSSHRemotePortForwarding: config.EnableSSHRemotePortForwarding,
|
||||
DisableSSHAuth: config.DisableSSHAuth,
|
||||
DNSRouteInterval: config.DNSRouteInterval,
|
||||
|
||||
DisableClientRoutes: config.DisableClientRoutes,
|
||||
DisableServerRoutes: config.DisableServerRoutes || config.BlockInbound,
|
||||
@@ -515,6 +520,11 @@ func loginToManagement(ctx context.Context, client mgm.Client, pubSSHKey []byte,
|
||||
config.BlockLANAccess,
|
||||
config.BlockInbound,
|
||||
config.LazyConnectionEnabled,
|
||||
config.EnableSSHRoot,
|
||||
config.EnableSSHSFTP,
|
||||
config.EnableSSHLocalPortForwarding,
|
||||
config.EnableSSHRemotePortForwarding,
|
||||
config.DisableSSHAuth,
|
||||
)
|
||||
loginResp, err := client.Login(*serverPublicKey, sysInfo, pubSSHKey, config.DNSLabels)
|
||||
if err != nil {
|
||||
|
||||
@@ -453,6 +453,18 @@ func (g *BundleGenerator) addCommonConfigFields(configContent *strings.Builder)
|
||||
if g.internalConfig.ServerSSHAllowed != nil {
|
||||
configContent.WriteString(fmt.Sprintf("ServerSSHAllowed: %v\n", *g.internalConfig.ServerSSHAllowed))
|
||||
}
|
||||
if g.internalConfig.EnableSSHRoot != nil {
|
||||
configContent.WriteString(fmt.Sprintf("EnableSSHRoot: %v\n", *g.internalConfig.EnableSSHRoot))
|
||||
}
|
||||
if g.internalConfig.EnableSSHSFTP != nil {
|
||||
configContent.WriteString(fmt.Sprintf("EnableSSHSFTP: %v\n", *g.internalConfig.EnableSSHSFTP))
|
||||
}
|
||||
if g.internalConfig.EnableSSHLocalPortForwarding != nil {
|
||||
configContent.WriteString(fmt.Sprintf("EnableSSHLocalPortForwarding: %v\n", *g.internalConfig.EnableSSHLocalPortForwarding))
|
||||
}
|
||||
if g.internalConfig.EnableSSHRemotePortForwarding != nil {
|
||||
configContent.WriteString(fmt.Sprintf("EnableSSHRemotePortForwarding: %v\n", *g.internalConfig.EnableSSHRemotePortForwarding))
|
||||
}
|
||||
|
||||
configContent.WriteString(fmt.Sprintf("DisableClientRoutes: %v\n", g.internalConfig.DisableClientRoutes))
|
||||
configContent.WriteString(fmt.Sprintf("DisableServerRoutes: %v\n", g.internalConfig.DisableServerRoutes))
|
||||
|
||||
@@ -9,7 +9,6 @@ import (
|
||||
"net/netip"
|
||||
"net/url"
|
||||
"os"
|
||||
"reflect"
|
||||
"runtime"
|
||||
"slices"
|
||||
"sort"
|
||||
@@ -30,7 +29,6 @@ import (
|
||||
firewallManager "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
nbnetstack "github.com/netbirdio/netbird/client/iface/netstack"
|
||||
"github.com/netbirdio/netbird/client/iface/udpmux"
|
||||
"github.com/netbirdio/netbird/client/internal/acl"
|
||||
"github.com/netbirdio/netbird/client/internal/dns"
|
||||
@@ -51,10 +49,10 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
cProto "github.com/netbirdio/netbird/client/proto"
|
||||
sshconfig "github.com/netbirdio/netbird/client/ssh/config"
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
semaphoregroup "github.com/netbirdio/netbird/util/semaphore-group"
|
||||
|
||||
nbssh "github.com/netbirdio/netbird/client/ssh"
|
||||
"github.com/netbirdio/netbird/client/system"
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
@@ -115,7 +113,12 @@ type EngineConfig struct {
|
||||
RosenpassEnabled bool
|
||||
RosenpassPermissive bool
|
||||
|
||||
ServerSSHAllowed bool
|
||||
ServerSSHAllowed bool
|
||||
EnableSSHRoot *bool
|
||||
EnableSSHSFTP *bool
|
||||
EnableSSHLocalPortForwarding *bool
|
||||
EnableSSHRemotePortForwarding *bool
|
||||
DisableSSHAuth *bool
|
||||
|
||||
DNSRouteInterval time.Duration
|
||||
|
||||
@@ -148,8 +151,6 @@ type Engine struct {
|
||||
|
||||
// syncMsgMux is used to guarantee sequential Management Service message processing
|
||||
syncMsgMux *sync.Mutex
|
||||
// sshMux protects sshServer field access
|
||||
sshMux sync.Mutex
|
||||
|
||||
config *EngineConfig
|
||||
mobileDep MobileDependency
|
||||
@@ -175,8 +176,7 @@ type Engine struct {
|
||||
|
||||
networkMonitor *networkmonitor.NetworkMonitor
|
||||
|
||||
sshServerFunc func(hostKeyPEM []byte, addr string) (nbssh.Server, error)
|
||||
sshServer nbssh.Server
|
||||
sshServer sshServer
|
||||
|
||||
statusRecorder *peer.Status
|
||||
|
||||
@@ -246,7 +246,6 @@ func NewEngine(
|
||||
STUNs: []*stun.URI{},
|
||||
TURNs: []*stun.URI{},
|
||||
networkSerial: 0,
|
||||
sshServerFunc: nbssh.DefaultSSHServer,
|
||||
statusRecorder: statusRecorder,
|
||||
checks: checks,
|
||||
connSemaphore: semaphoregroup.NewSemaphoreGroup(connInitLimit),
|
||||
@@ -268,6 +267,7 @@ func NewEngine(
|
||||
path = mobileDep.StateFilePath
|
||||
}
|
||||
engine.stateManager = statemanager.New(path)
|
||||
engine.stateManager.RegisterState(&sshconfig.ShutdownState{})
|
||||
|
||||
log.Infof("I am: %s", config.WgPrivateKey.PublicKey().String())
|
||||
return engine
|
||||
@@ -292,6 +292,12 @@ func (e *Engine) Stop() error {
|
||||
}
|
||||
log.Info("Network monitor: stopped")
|
||||
|
||||
if err := e.stopSSHServer(); err != nil {
|
||||
log.Warnf("failed to stop SSH server: %v", err)
|
||||
}
|
||||
|
||||
e.cleanupSSHConfig()
|
||||
|
||||
// stop/restore DNS first so dbus and friends don't complain because of a missing interface
|
||||
e.stopDNSServer()
|
||||
|
||||
@@ -703,16 +709,10 @@ func (e *Engine) removeAllPeers() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// removePeer closes an existing peer connection, removes a peer, and clears authorized key of the SSH server
|
||||
// removePeer closes an existing peer connection and removes a peer
|
||||
func (e *Engine) removePeer(peerKey string) error {
|
||||
log.Debugf("removing peer from engine %s", peerKey)
|
||||
|
||||
e.sshMux.Lock()
|
||||
if !isNil(e.sshServer) {
|
||||
e.sshServer.RemoveAuthorizedKey(peerKey)
|
||||
}
|
||||
e.sshMux.Unlock()
|
||||
|
||||
e.connMgr.RemovePeerConn(peerKey)
|
||||
|
||||
err := e.statusRecorder.RemovePeer(peerKey)
|
||||
@@ -884,6 +884,11 @@ func (e *Engine) updateChecksIfNew(checks []*mgmProto.Checks) error {
|
||||
e.config.BlockLANAccess,
|
||||
e.config.BlockInbound,
|
||||
e.config.LazyConnectionEnabled,
|
||||
e.config.EnableSSHRoot,
|
||||
e.config.EnableSSHSFTP,
|
||||
e.config.EnableSSHLocalPortForwarding,
|
||||
e.config.EnableSSHRemotePortForwarding,
|
||||
e.config.DisableSSHAuth,
|
||||
)
|
||||
|
||||
if err := e.mgmClient.SyncMeta(info); err != nil {
|
||||
@@ -893,74 +898,6 @@ func (e *Engine) updateChecksIfNew(checks []*mgmProto.Checks) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func isNil(server nbssh.Server) bool {
|
||||
return server == nil || reflect.ValueOf(server).IsNil()
|
||||
}
|
||||
|
||||
func (e *Engine) updateSSH(sshConf *mgmProto.SSHConfig) error {
|
||||
if e.config.BlockInbound {
|
||||
log.Infof("SSH server is disabled because inbound connections are blocked")
|
||||
return nil
|
||||
}
|
||||
|
||||
if !e.config.ServerSSHAllowed {
|
||||
log.Info("SSH server is not enabled")
|
||||
return nil
|
||||
}
|
||||
|
||||
if sshConf.GetSshEnabled() {
|
||||
if runtime.GOOS == "windows" {
|
||||
log.Warnf("running SSH server on %s is not supported", runtime.GOOS)
|
||||
return nil
|
||||
}
|
||||
e.sshMux.Lock()
|
||||
// start SSH server if it wasn't running
|
||||
if isNil(e.sshServer) {
|
||||
listenAddr := fmt.Sprintf("%s:%d", e.wgInterface.Address().IP.String(), nbssh.DefaultSSHPort)
|
||||
if nbnetstack.IsEnabled() {
|
||||
listenAddr = fmt.Sprintf("127.0.0.1:%d", nbssh.DefaultSSHPort)
|
||||
}
|
||||
// nil sshServer means it has not yet been started
|
||||
server, err := e.sshServerFunc(e.config.SSHKey, listenAddr)
|
||||
if err != nil {
|
||||
e.sshMux.Unlock()
|
||||
return fmt.Errorf("create ssh server: %w", err)
|
||||
}
|
||||
|
||||
e.sshServer = server
|
||||
e.sshMux.Unlock()
|
||||
|
||||
go func() {
|
||||
// blocking
|
||||
err = server.Start()
|
||||
if err != nil {
|
||||
// will throw error when we stop it even if it is a graceful stop
|
||||
log.Debugf("stopped SSH server with error %v", err)
|
||||
}
|
||||
e.sshMux.Lock()
|
||||
e.sshServer = nil
|
||||
e.sshMux.Unlock()
|
||||
log.Infof("stopped SSH server")
|
||||
}()
|
||||
} else {
|
||||
e.sshMux.Unlock()
|
||||
log.Debugf("SSH server is already running")
|
||||
}
|
||||
} else {
|
||||
e.sshMux.Lock()
|
||||
if !isNil(e.sshServer) {
|
||||
// Disable SSH server request, so stop it if it was running
|
||||
err := e.sshServer.Stop()
|
||||
if err != nil {
|
||||
log.Warnf("failed to stop SSH server %v", err)
|
||||
}
|
||||
e.sshServer = nil
|
||||
}
|
||||
e.sshMux.Unlock()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error {
|
||||
if e.wgInterface == nil {
|
||||
return errors.New("wireguard interface is not initialized")
|
||||
@@ -973,8 +910,7 @@ func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error {
|
||||
}
|
||||
|
||||
if conf.GetSshConfig() != nil {
|
||||
err := e.updateSSH(conf.GetSshConfig())
|
||||
if err != nil {
|
||||
if err := e.updateSSH(conf.GetSshConfig()); err != nil {
|
||||
log.Warnf("failed handling SSH server setup: %v", err)
|
||||
}
|
||||
}
|
||||
@@ -1012,6 +948,11 @@ func (e *Engine) receiveManagementEvents() {
|
||||
e.config.BlockLANAccess,
|
||||
e.config.BlockInbound,
|
||||
e.config.LazyConnectionEnabled,
|
||||
e.config.EnableSSHRoot,
|
||||
e.config.EnableSSHSFTP,
|
||||
e.config.EnableSSHLocalPortForwarding,
|
||||
e.config.EnableSSHRemotePortForwarding,
|
||||
e.config.DisableSSHAuth,
|
||||
)
|
||||
|
||||
err = e.mgmClient.Sync(e.ctx, info, e.handleSync)
|
||||
@@ -1170,19 +1111,11 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
|
||||
|
||||
e.statusRecorder.FinishPeerListModifications()
|
||||
|
||||
// update SSHServer by adding remote peer SSH keys
|
||||
e.sshMux.Lock()
|
||||
if !isNil(e.sshServer) {
|
||||
for _, config := range networkMap.GetRemotePeers() {
|
||||
if config.GetSshConfig() != nil && config.GetSshConfig().GetSshPubKey() != nil {
|
||||
err := e.sshServer.AddAuthorizedKey(config.WgPubKey, string(config.GetSshConfig().GetSshPubKey()))
|
||||
if err != nil {
|
||||
log.Warnf("failed adding authorized key to SSH DefaultServer %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
e.updatePeerSSHHostKeys(networkMap.GetRemotePeers())
|
||||
|
||||
if err := e.updateSSHClientConfig(networkMap.GetRemotePeers()); err != nil {
|
||||
log.Warnf("failed to update SSH client config: %v", err)
|
||||
}
|
||||
e.sshMux.Unlock()
|
||||
}
|
||||
|
||||
// must set the exclude list after the peers are added. Without it the manager can not figure out the peers parameters from the store
|
||||
@@ -1544,15 +1477,6 @@ func (e *Engine) close() {
|
||||
e.statusRecorder.SetWgIface(nil)
|
||||
}
|
||||
|
||||
e.sshMux.Lock()
|
||||
if !isNil(e.sshServer) {
|
||||
err := e.sshServer.Stop()
|
||||
if err != nil {
|
||||
log.Warnf("failed stopping the SSH server: %v", err)
|
||||
}
|
||||
}
|
||||
e.sshMux.Unlock()
|
||||
|
||||
if e.firewall != nil {
|
||||
err := e.firewall.Close(e.stateManager)
|
||||
if err != nil {
|
||||
@@ -1583,6 +1507,11 @@ func (e *Engine) readInitialSettings() ([]*route.Route, *nbdns.Config, bool, err
|
||||
e.config.BlockLANAccess,
|
||||
e.config.BlockInbound,
|
||||
e.config.LazyConnectionEnabled,
|
||||
e.config.EnableSSHRoot,
|
||||
e.config.EnableSSHSFTP,
|
||||
e.config.EnableSSHLocalPortForwarding,
|
||||
e.config.EnableSSHRemotePortForwarding,
|
||||
e.config.DisableSSHAuth,
|
||||
)
|
||||
|
||||
netMap, err := e.mgmClient.GetNetworkMap(info)
|
||||
|
||||
355
client/internal/engine_ssh.go
Normal file
355
client/internal/engine_ssh.go
Normal file
@@ -0,0 +1,355 @@
|
||||
package internal
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"strings"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
firewallManager "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
||||
sshconfig "github.com/netbirdio/netbird/client/ssh/config"
|
||||
sshserver "github.com/netbirdio/netbird/client/ssh/server"
|
||||
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
|
||||
)
|
||||
|
||||
type sshServer interface {
|
||||
Start(ctx context.Context, addr netip.AddrPort) error
|
||||
Stop() error
|
||||
GetStatus() (bool, []sshserver.SessionInfo)
|
||||
}
|
||||
|
||||
func (e *Engine) setupSSHPortRedirection() error {
|
||||
if e.firewall == nil || e.wgInterface == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
localAddr := e.wgInterface.Address().IP
|
||||
if !localAddr.IsValid() {
|
||||
return errors.New("invalid local NetBird address")
|
||||
}
|
||||
|
||||
if err := e.firewall.AddInboundDNAT(localAddr, firewallManager.ProtocolTCP, 22, 22022); err != nil {
|
||||
return fmt.Errorf("add SSH port redirection: %w", err)
|
||||
}
|
||||
log.Infof("SSH port redirection enabled: %s:22 -> %s:22022", localAddr, localAddr)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (e *Engine) updateSSH(sshConf *mgmProto.SSHConfig) error {
|
||||
if e.config.BlockInbound {
|
||||
log.Info("SSH server is disabled because inbound connections are blocked")
|
||||
return e.stopSSHServer()
|
||||
}
|
||||
|
||||
if !e.config.ServerSSHAllowed {
|
||||
log.Info("SSH server is disabled in config")
|
||||
return e.stopSSHServer()
|
||||
}
|
||||
|
||||
if !sshConf.GetSshEnabled() {
|
||||
if e.config.ServerSSHAllowed {
|
||||
log.Info("SSH server is locally allowed but disabled by management server")
|
||||
}
|
||||
return e.stopSSHServer()
|
||||
}
|
||||
|
||||
if e.sshServer != nil {
|
||||
log.Debug("SSH server is already running")
|
||||
return nil
|
||||
}
|
||||
|
||||
if e.config.DisableSSHAuth != nil && *e.config.DisableSSHAuth {
|
||||
log.Info("starting SSH server without JWT authentication (authentication disabled by config)")
|
||||
return e.startSSHServer(nil)
|
||||
}
|
||||
|
||||
if protoJWT := sshConf.GetJwtConfig(); protoJWT != nil {
|
||||
jwtConfig := &sshserver.JWTConfig{
|
||||
Issuer: protoJWT.GetIssuer(),
|
||||
Audience: protoJWT.GetAudience(),
|
||||
KeysLocation: protoJWT.GetKeysLocation(),
|
||||
MaxTokenAge: protoJWT.GetMaxTokenAge(),
|
||||
}
|
||||
|
||||
return e.startSSHServer(jwtConfig)
|
||||
}
|
||||
|
||||
return errors.New("SSH server requires valid JWT configuration")
|
||||
}
|
||||
|
||||
// updateSSHClientConfig updates the SSH client configuration with peer information
|
||||
func (e *Engine) updateSSHClientConfig(remotePeers []*mgmProto.RemotePeerConfig) error {
|
||||
peerInfo := e.extractPeerSSHInfo(remotePeers)
|
||||
if len(peerInfo) == 0 {
|
||||
log.Debug("no SSH-enabled peers found, skipping SSH config update")
|
||||
return nil
|
||||
}
|
||||
|
||||
configMgr := sshconfig.New()
|
||||
if err := configMgr.SetupSSHClientConfig(peerInfo); err != nil {
|
||||
log.Warnf("failed to update SSH client config: %v", err)
|
||||
return nil // Don't fail engine startup on SSH config issues
|
||||
}
|
||||
|
||||
log.Debugf("updated SSH client config with %d peers", len(peerInfo))
|
||||
|
||||
if err := e.stateManager.UpdateState(&sshconfig.ShutdownState{
|
||||
SSHConfigDir: configMgr.GetSSHConfigDir(),
|
||||
SSHConfigFile: configMgr.GetSSHConfigFile(),
|
||||
}); err != nil {
|
||||
log.Warnf("failed to update SSH config state: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// extractPeerSSHInfo extracts SSH information from peer configurations
|
||||
func (e *Engine) extractPeerSSHInfo(remotePeers []*mgmProto.RemotePeerConfig) []sshconfig.PeerSSHInfo {
|
||||
var peerInfo []sshconfig.PeerSSHInfo
|
||||
|
||||
for _, peerConfig := range remotePeers {
|
||||
if peerConfig.GetSshConfig() == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
sshPubKeyBytes := peerConfig.GetSshConfig().GetSshPubKey()
|
||||
if len(sshPubKeyBytes) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
peerIP := e.extractPeerIP(peerConfig)
|
||||
hostname := e.extractHostname(peerConfig)
|
||||
|
||||
peerInfo = append(peerInfo, sshconfig.PeerSSHInfo{
|
||||
Hostname: hostname,
|
||||
IP: peerIP,
|
||||
FQDN: peerConfig.GetFqdn(),
|
||||
})
|
||||
}
|
||||
|
||||
return peerInfo
|
||||
}
|
||||
|
||||
// extractPeerIP extracts IP address from peer's allowed IPs
|
||||
func (e *Engine) extractPeerIP(peerConfig *mgmProto.RemotePeerConfig) string {
|
||||
if len(peerConfig.GetAllowedIps()) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
if prefix, err := netip.ParsePrefix(peerConfig.GetAllowedIps()[0]); err == nil {
|
||||
return prefix.Addr().String()
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// extractHostname extracts short hostname from FQDN
|
||||
func (e *Engine) extractHostname(peerConfig *mgmProto.RemotePeerConfig) string {
|
||||
fqdn := peerConfig.GetFqdn()
|
||||
if fqdn == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
parts := strings.Split(fqdn, ".")
|
||||
if len(parts) > 0 && parts[0] != "" {
|
||||
return parts[0]
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// updatePeerSSHHostKeys updates peer SSH host keys in the status recorder for daemon API access
|
||||
func (e *Engine) updatePeerSSHHostKeys(remotePeers []*mgmProto.RemotePeerConfig) {
|
||||
for _, peerConfig := range remotePeers {
|
||||
if peerConfig.GetSshConfig() == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
sshPubKeyBytes := peerConfig.GetSshConfig().GetSshPubKey()
|
||||
if len(sshPubKeyBytes) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
if err := e.statusRecorder.UpdatePeerSSHHostKey(peerConfig.GetWgPubKey(), sshPubKeyBytes); err != nil {
|
||||
log.Warnf("failed to update SSH host key for peer %s: %v", peerConfig.GetWgPubKey(), err)
|
||||
}
|
||||
}
|
||||
|
||||
log.Debugf("updated peer SSH host keys for daemon API access")
|
||||
}
|
||||
|
||||
// GetPeerSSHKey returns the SSH host key for a specific peer by IP or FQDN
|
||||
func (e *Engine) GetPeerSSHKey(peerAddress string) ([]byte, bool) {
|
||||
e.syncMsgMux.Lock()
|
||||
statusRecorder := e.statusRecorder
|
||||
e.syncMsgMux.Unlock()
|
||||
|
||||
if statusRecorder == nil {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
fullStatus := statusRecorder.GetFullStatus()
|
||||
for _, peerState := range fullStatus.Peers {
|
||||
if peerState.IP == peerAddress || peerState.FQDN == peerAddress {
|
||||
if len(peerState.SSHHostKey) > 0 {
|
||||
return peerState.SSHHostKey, true
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
}
|
||||
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// cleanupSSHConfig removes NetBird SSH client configuration on shutdown
|
||||
func (e *Engine) cleanupSSHConfig() {
|
||||
configMgr := sshconfig.New()
|
||||
|
||||
if err := configMgr.RemoveSSHClientConfig(); err != nil {
|
||||
log.Warnf("failed to remove SSH client config: %v", err)
|
||||
} else {
|
||||
log.Debugf("SSH client config cleanup completed")
|
||||
}
|
||||
}
|
||||
|
||||
// startSSHServer initializes and starts the SSH server with proper configuration.
|
||||
func (e *Engine) startSSHServer(jwtConfig *sshserver.JWTConfig) error {
|
||||
if e.wgInterface == nil {
|
||||
return errors.New("wg interface not initialized")
|
||||
}
|
||||
|
||||
serverConfig := &sshserver.Config{
|
||||
HostKeyPEM: e.config.SSHKey,
|
||||
JWT: jwtConfig,
|
||||
}
|
||||
server := sshserver.New(serverConfig)
|
||||
|
||||
wgAddr := e.wgInterface.Address()
|
||||
server.SetNetworkValidation(wgAddr)
|
||||
|
||||
netbirdIP := wgAddr.IP
|
||||
listenAddr := netip.AddrPortFrom(netbirdIP, sshserver.InternalSSHPort)
|
||||
|
||||
if netstackNet := e.wgInterface.GetNet(); netstackNet != nil {
|
||||
server.SetNetstackNet(netstackNet)
|
||||
}
|
||||
|
||||
e.configureSSHServer(server)
|
||||
|
||||
if err := server.Start(e.ctx, listenAddr); err != nil {
|
||||
return fmt.Errorf("start SSH server: %w", err)
|
||||
}
|
||||
|
||||
e.sshServer = server
|
||||
|
||||
if netstackNet := e.wgInterface.GetNet(); netstackNet != nil {
|
||||
if registrar, ok := e.firewall.(interface {
|
||||
RegisterNetstackService(protocol nftypes.Protocol, port uint16)
|
||||
}); ok {
|
||||
registrar.RegisterNetstackService(nftypes.TCP, sshserver.InternalSSHPort)
|
||||
log.Debugf("registered SSH service with netstack for TCP:%d", sshserver.InternalSSHPort)
|
||||
}
|
||||
}
|
||||
|
||||
if err := e.setupSSHPortRedirection(); err != nil {
|
||||
log.Warnf("failed to setup SSH port redirection: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// configureSSHServer applies SSH configuration options to the server.
|
||||
func (e *Engine) configureSSHServer(server *sshserver.Server) {
|
||||
if e.config.EnableSSHRoot != nil && *e.config.EnableSSHRoot {
|
||||
server.SetAllowRootLogin(true)
|
||||
log.Info("SSH root login enabled")
|
||||
} else {
|
||||
server.SetAllowRootLogin(false)
|
||||
log.Info("SSH root login disabled (default)")
|
||||
}
|
||||
|
||||
if e.config.EnableSSHSFTP != nil && *e.config.EnableSSHSFTP {
|
||||
server.SetAllowSFTP(true)
|
||||
log.Info("SSH SFTP subsystem enabled")
|
||||
} else {
|
||||
server.SetAllowSFTP(false)
|
||||
log.Info("SSH SFTP subsystem disabled (default)")
|
||||
}
|
||||
|
||||
if e.config.EnableSSHLocalPortForwarding != nil && *e.config.EnableSSHLocalPortForwarding {
|
||||
server.SetAllowLocalPortForwarding(true)
|
||||
log.Info("SSH local port forwarding enabled")
|
||||
} else {
|
||||
server.SetAllowLocalPortForwarding(false)
|
||||
log.Info("SSH local port forwarding disabled (default)")
|
||||
}
|
||||
|
||||
if e.config.EnableSSHRemotePortForwarding != nil && *e.config.EnableSSHRemotePortForwarding {
|
||||
server.SetAllowRemotePortForwarding(true)
|
||||
log.Info("SSH remote port forwarding enabled")
|
||||
} else {
|
||||
server.SetAllowRemotePortForwarding(false)
|
||||
log.Info("SSH remote port forwarding disabled (default)")
|
||||
}
|
||||
}
|
||||
|
||||
func (e *Engine) cleanupSSHPortRedirection() error {
|
||||
if e.firewall == nil || e.wgInterface == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
localAddr := e.wgInterface.Address().IP
|
||||
if !localAddr.IsValid() {
|
||||
return errors.New("invalid local NetBird address")
|
||||
}
|
||||
|
||||
if err := e.firewall.RemoveInboundDNAT(localAddr, firewallManager.ProtocolTCP, 22, 22022); err != nil {
|
||||
return fmt.Errorf("remove SSH port redirection: %w", err)
|
||||
}
|
||||
log.Debugf("SSH port redirection removed: %s:22 -> %s:22022", localAddr, localAddr)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (e *Engine) stopSSHServer() error {
|
||||
if e.sshServer == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := e.cleanupSSHPortRedirection(); err != nil {
|
||||
log.Warnf("failed to cleanup SSH port redirection: %v", err)
|
||||
}
|
||||
|
||||
if netstackNet := e.wgInterface.GetNet(); netstackNet != nil {
|
||||
if registrar, ok := e.firewall.(interface {
|
||||
UnregisterNetstackService(protocol nftypes.Protocol, port uint16)
|
||||
}); ok {
|
||||
registrar.UnregisterNetstackService(nftypes.TCP, sshserver.InternalSSHPort)
|
||||
log.Debugf("unregistered SSH service from netstack for TCP:%d", sshserver.InternalSSHPort)
|
||||
}
|
||||
}
|
||||
|
||||
log.Info("stopping SSH server")
|
||||
err := e.sshServer.Stop()
|
||||
e.sshServer = nil
|
||||
if err != nil {
|
||||
return fmt.Errorf("stop: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetSSHServerStatus returns the SSH server status and active sessions
|
||||
func (e *Engine) GetSSHServerStatus() (enabled bool, sessions []sshserver.SessionInfo) {
|
||||
e.syncMsgMux.Lock()
|
||||
sshServer := e.sshServer
|
||||
e.syncMsgMux.Unlock()
|
||||
|
||||
if sshServer == nil {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
return sshServer.GetStatus()
|
||||
}
|
||||
@@ -14,7 +14,6 @@ import (
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/google/uuid"
|
||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
@@ -25,7 +24,10 @@ import (
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/keepalive"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||
|
||||
"github.com/netbirdio/management-integrations/integrations"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
|
||||
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
|
||||
@@ -46,7 +48,7 @@ import (
|
||||
icemaker "github.com/netbirdio/netbird/client/internal/peer/ice"
|
||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager"
|
||||
"github.com/netbirdio/netbird/client/ssh"
|
||||
nbssh "github.com/netbirdio/netbird/client/ssh"
|
||||
"github.com/netbirdio/netbird/client/system"
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/management/server"
|
||||
@@ -214,11 +216,13 @@ func TestMain(m *testing.M) {
|
||||
}
|
||||
|
||||
func TestEngine_SSH(t *testing.T) {
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("skipping TestEngine_SSH")
|
||||
key, err := wgtypes.GeneratePrivateKey()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
}
|
||||
|
||||
key, err := wgtypes.GeneratePrivateKey()
|
||||
sshKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
@@ -240,6 +244,7 @@ func TestEngine_SSH(t *testing.T) {
|
||||
WgPort: 33100,
|
||||
ServerSSHAllowed: true,
|
||||
MTU: iface.DefaultMTU,
|
||||
SSHKey: sshKey,
|
||||
},
|
||||
MobileDependency{},
|
||||
peer.NewRecorder("https://mgm"),
|
||||
@@ -250,35 +255,8 @@ func TestEngine_SSH(t *testing.T) {
|
||||
UpdateDNSServerFunc: func(serial uint64, update nbdns.Config) error { return nil },
|
||||
}
|
||||
|
||||
var sshKeysAdded []string
|
||||
var sshPeersRemoved []string
|
||||
|
||||
sshCtx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
engine.sshServerFunc = func(hostKeyPEM []byte, addr string) (ssh.Server, error) {
|
||||
return &ssh.MockServer{
|
||||
Ctx: sshCtx,
|
||||
StopFunc: func() error {
|
||||
cancel()
|
||||
return nil
|
||||
},
|
||||
StartFunc: func() error {
|
||||
<-ctx.Done()
|
||||
return ctx.Err()
|
||||
},
|
||||
AddAuthorizedKeyFunc: func(peer, newKey string) error {
|
||||
sshKeysAdded = append(sshKeysAdded, newKey)
|
||||
return nil
|
||||
},
|
||||
RemoveAuthorizedKeyFunc: func(peer string) {
|
||||
sshPeersRemoved = append(sshPeersRemoved, peer)
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
err = engine.Start(nil, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(t, err)
|
||||
|
||||
defer func() {
|
||||
err := engine.Stop()
|
||||
@@ -304,9 +282,7 @@ func TestEngine_SSH(t *testing.T) {
|
||||
}
|
||||
|
||||
err = engine.updateNetworkMap(networkMap)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Nil(t, engine.sshServer)
|
||||
|
||||
@@ -314,19 +290,24 @@ func TestEngine_SSH(t *testing.T) {
|
||||
networkMap = &mgmtProto.NetworkMap{
|
||||
Serial: 7,
|
||||
PeerConfig: &mgmtProto.PeerConfig{Address: "100.64.0.1/24",
|
||||
SshConfig: &mgmtProto.SSHConfig{SshEnabled: true}},
|
||||
SshConfig: &mgmtProto.SSHConfig{
|
||||
SshEnabled: true,
|
||||
JwtConfig: &mgmtProto.JWTConfig{
|
||||
Issuer: "test-issuer",
|
||||
Audience: "test-audience",
|
||||
KeysLocation: "test-keys",
|
||||
MaxTokenAge: 3600,
|
||||
},
|
||||
}},
|
||||
RemotePeers: []*mgmtProto.RemotePeerConfig{peerWithSSH},
|
||||
RemotePeersIsEmpty: false,
|
||||
}
|
||||
|
||||
err = engine.updateNetworkMap(networkMap)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(t, err)
|
||||
|
||||
time.Sleep(250 * time.Millisecond)
|
||||
assert.NotNil(t, engine.sshServer)
|
||||
assert.Contains(t, sshKeysAdded, "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIFATYCqaQw/9id1Qkq3n16JYhDhXraI6Pc1fgB8ynEfQ")
|
||||
|
||||
// now remove peer
|
||||
networkMap = &mgmtProto.NetworkMap{
|
||||
@@ -336,13 +317,10 @@ func TestEngine_SSH(t *testing.T) {
|
||||
}
|
||||
|
||||
err = engine.updateNetworkMap(networkMap)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(t, err)
|
||||
|
||||
// time.Sleep(250 * time.Millisecond)
|
||||
assert.NotNil(t, engine.sshServer)
|
||||
assert.Contains(t, sshPeersRemoved, "MNHf3Ma6z6mdLbriAJbqhX7+nM/B71lgw2+91q3LfhU=")
|
||||
|
||||
// now disable SSH server
|
||||
networkMap = &mgmtProto.NetworkMap{
|
||||
@@ -354,12 +332,70 @@ func TestEngine_SSH(t *testing.T) {
|
||||
}
|
||||
|
||||
err = engine.updateNetworkMap(networkMap)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Nil(t, engine.sshServer)
|
||||
}
|
||||
|
||||
func TestEngine_SSHUpdateLogic(t *testing.T) {
|
||||
// Test that SSH server start/stop logic works based on config
|
||||
engine := &Engine{
|
||||
config: &EngineConfig{
|
||||
ServerSSHAllowed: false, // Start with SSH disabled
|
||||
},
|
||||
syncMsgMux: &sync.Mutex{},
|
||||
}
|
||||
|
||||
// Test SSH disabled config
|
||||
sshConfig := &mgmtProto.SSHConfig{SshEnabled: false}
|
||||
err := engine.updateSSH(sshConfig)
|
||||
assert.NoError(t, err)
|
||||
assert.Nil(t, engine.sshServer)
|
||||
|
||||
// Test inbound blocked
|
||||
engine.config.BlockInbound = true
|
||||
err = engine.updateSSH(&mgmtProto.SSHConfig{SshEnabled: true})
|
||||
assert.NoError(t, err)
|
||||
assert.Nil(t, engine.sshServer)
|
||||
engine.config.BlockInbound = false
|
||||
|
||||
// Test with server SSH not allowed
|
||||
err = engine.updateSSH(&mgmtProto.SSHConfig{SshEnabled: true})
|
||||
assert.NoError(t, err)
|
||||
assert.Nil(t, engine.sshServer)
|
||||
}
|
||||
|
||||
func TestEngine_SSHServerConsistency(t *testing.T) {
|
||||
|
||||
t.Run("server set only on successful creation", func(t *testing.T) {
|
||||
engine := &Engine{
|
||||
config: &EngineConfig{
|
||||
ServerSSHAllowed: true,
|
||||
SSHKey: []byte("test-key"),
|
||||
},
|
||||
syncMsgMux: &sync.Mutex{},
|
||||
}
|
||||
|
||||
engine.wgInterface = nil
|
||||
|
||||
err := engine.updateSSH(&mgmtProto.SSHConfig{SshEnabled: true})
|
||||
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, engine.sshServer)
|
||||
})
|
||||
|
||||
t.Run("cleanup handles nil gracefully", func(t *testing.T) {
|
||||
engine := &Engine{
|
||||
config: &EngineConfig{
|
||||
ServerSSHAllowed: false,
|
||||
},
|
||||
syncMsgMux: &sync.Mutex{},
|
||||
}
|
||||
|
||||
err := engine.stopSSHServer()
|
||||
assert.NoError(t, err)
|
||||
assert.Nil(t, engine.sshServer)
|
||||
})
|
||||
}
|
||||
|
||||
func TestEngine_UpdateNetworkMap(t *testing.T) {
|
||||
@@ -1589,7 +1625,7 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri
|
||||
updateManager := update_channel.NewPeersUpdateManager(metrics)
|
||||
requestBuffer := server.NewAccountRequestBuffer(context.Background(), store)
|
||||
networkMapController := controller.NewController(context.Background(), store, metrics, updateManager, requestBuffer, server.MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock())
|
||||
accountManager, err := server.BuildManager(context.Background(), store, networkMapController, nil, "", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
|
||||
accountManager, err := server.BuildManager(context.Background(), config, store, networkMapController, nil, "", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
@@ -124,6 +124,11 @@ func doMgmLogin(ctx context.Context, mgmClient *mgm.GrpcClient, pubSSHKey []byte
|
||||
config.BlockLANAccess,
|
||||
config.BlockInbound,
|
||||
config.LazyConnectionEnabled,
|
||||
config.EnableSSHRoot,
|
||||
config.EnableSSHSFTP,
|
||||
config.EnableSSHLocalPortForwarding,
|
||||
config.EnableSSHRemotePortForwarding,
|
||||
config.DisableSSHAuth,
|
||||
)
|
||||
loginResp, err := mgmClient.Login(*serverKey, sysInfo, pubSSHKey, config.DNSLabels)
|
||||
return serverKey, loginResp, err
|
||||
@@ -150,6 +155,11 @@ func registerPeer(ctx context.Context, serverPublicKey wgtypes.Key, client *mgm.
|
||||
config.BlockLANAccess,
|
||||
config.BlockInbound,
|
||||
config.LazyConnectionEnabled,
|
||||
config.EnableSSHRoot,
|
||||
config.EnableSSHSFTP,
|
||||
config.EnableSSHLocalPortForwarding,
|
||||
config.EnableSSHRemotePortForwarding,
|
||||
config.DisableSSHAuth,
|
||||
)
|
||||
loginResp, err := client.Register(serverPublicKey, validSetupKey.String(), jwtToken, info, pubSSHKey, config.DNSLabels)
|
||||
if err != nil {
|
||||
|
||||
@@ -666,7 +666,7 @@ func (conn *Conn) isConnectedOnAllWay() (connected bool) {
|
||||
}
|
||||
}()
|
||||
|
||||
if conn.statusICE.Get() == worker.StatusDisconnected && !conn.workerICE.InProgress() {
|
||||
if runtime.GOOS != "js" && conn.statusICE.Get() == worker.StatusDisconnected && !conn.workerICE.InProgress() {
|
||||
return false
|
||||
}
|
||||
|
||||
|
||||
@@ -2,6 +2,7 @@ package peer
|
||||
|
||||
import (
|
||||
"os"
|
||||
"runtime"
|
||||
"strings"
|
||||
)
|
||||
|
||||
@@ -10,5 +11,8 @@ const (
|
||||
)
|
||||
|
||||
func isForceRelayed() bool {
|
||||
if runtime.GOOS == "js" {
|
||||
return true
|
||||
}
|
||||
return strings.EqualFold(os.Getenv(EnvKeyNBForceRelay), "true")
|
||||
}
|
||||
|
||||
@@ -21,9 +21,9 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal/ingressgw"
|
||||
"github.com/netbirdio/netbird/client/internal/relay"
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
relayClient "github.com/netbirdio/netbird/shared/relay/client"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
)
|
||||
|
||||
const eventQueueSize = 10
|
||||
@@ -67,6 +67,7 @@ type State struct {
|
||||
BytesRx int64
|
||||
Latency time.Duration
|
||||
RosenpassEnabled bool
|
||||
SSHHostKey []byte
|
||||
routes map[string]struct{}
|
||||
}
|
||||
|
||||
@@ -572,6 +573,22 @@ func (d *Status) UpdatePeerFQDN(peerPubKey, fqdn string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpdatePeerSSHHostKey updates peer's SSH host key
|
||||
func (d *Status) UpdatePeerSSHHostKey(peerPubKey string, sshHostKey []byte) error {
|
||||
d.mux.Lock()
|
||||
defer d.mux.Unlock()
|
||||
|
||||
peerState, ok := d.peers[peerPubKey]
|
||||
if !ok {
|
||||
return errors.New("peer doesn't exist")
|
||||
}
|
||||
|
||||
peerState.SSHHostKey = sshHostKey
|
||||
d.peers[peerPubKey] = peerState
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// FinishPeerListModifications this event invoke the notification
|
||||
func (d *Status) FinishPeerListModifications() {
|
||||
d.mux.Lock()
|
||||
|
||||
@@ -44,24 +44,30 @@ var DefaultInterfaceBlacklist = []string{
|
||||
|
||||
// ConfigInput carries configuration changes to the client
|
||||
type ConfigInput struct {
|
||||
ManagementURL string
|
||||
AdminURL string
|
||||
ConfigPath string
|
||||
StateFilePath string
|
||||
PreSharedKey *string
|
||||
ServerSSHAllowed *bool
|
||||
NATExternalIPs []string
|
||||
CustomDNSAddress []byte
|
||||
RosenpassEnabled *bool
|
||||
RosenpassPermissive *bool
|
||||
InterfaceName *string
|
||||
WireguardPort *int
|
||||
NetworkMonitor *bool
|
||||
DisableAutoConnect *bool
|
||||
ExtraIFaceBlackList []string
|
||||
DNSRouteInterval *time.Duration
|
||||
ClientCertPath string
|
||||
ClientCertKeyPath string
|
||||
ManagementURL string
|
||||
AdminURL string
|
||||
ConfigPath string
|
||||
StateFilePath string
|
||||
PreSharedKey *string
|
||||
ServerSSHAllowed *bool
|
||||
EnableSSHRoot *bool
|
||||
EnableSSHSFTP *bool
|
||||
EnableSSHLocalPortForwarding *bool
|
||||
EnableSSHRemotePortForwarding *bool
|
||||
DisableSSHAuth *bool
|
||||
SSHJWTCacheTTL *int
|
||||
NATExternalIPs []string
|
||||
CustomDNSAddress []byte
|
||||
RosenpassEnabled *bool
|
||||
RosenpassPermissive *bool
|
||||
InterfaceName *string
|
||||
WireguardPort *int
|
||||
NetworkMonitor *bool
|
||||
DisableAutoConnect *bool
|
||||
ExtraIFaceBlackList []string
|
||||
DNSRouteInterval *time.Duration
|
||||
ClientCertPath string
|
||||
ClientCertKeyPath string
|
||||
|
||||
DisableClientRoutes *bool
|
||||
DisableServerRoutes *bool
|
||||
@@ -82,18 +88,24 @@ type ConfigInput struct {
|
||||
// Config Configuration type
|
||||
type Config struct {
|
||||
// Wireguard private key of local peer
|
||||
PrivateKey string
|
||||
PreSharedKey string
|
||||
ManagementURL *url.URL
|
||||
AdminURL *url.URL
|
||||
WgIface string
|
||||
WgPort int
|
||||
NetworkMonitor *bool
|
||||
IFaceBlackList []string
|
||||
DisableIPv6Discovery bool
|
||||
RosenpassEnabled bool
|
||||
RosenpassPermissive bool
|
||||
ServerSSHAllowed *bool
|
||||
PrivateKey string
|
||||
PreSharedKey string
|
||||
ManagementURL *url.URL
|
||||
AdminURL *url.URL
|
||||
WgIface string
|
||||
WgPort int
|
||||
NetworkMonitor *bool
|
||||
IFaceBlackList []string
|
||||
DisableIPv6Discovery bool
|
||||
RosenpassEnabled bool
|
||||
RosenpassPermissive bool
|
||||
ServerSSHAllowed *bool
|
||||
EnableSSHRoot *bool
|
||||
EnableSSHSFTP *bool
|
||||
EnableSSHLocalPortForwarding *bool
|
||||
EnableSSHRemotePortForwarding *bool
|
||||
DisableSSHAuth *bool
|
||||
SSHJWTCacheTTL *int
|
||||
|
||||
DisableClientRoutes bool
|
||||
DisableServerRoutes bool
|
||||
@@ -376,6 +388,62 @@ func (config *Config) apply(input ConfigInput) (updated bool, err error) {
|
||||
updated = true
|
||||
}
|
||||
|
||||
if input.EnableSSHRoot != nil && input.EnableSSHRoot != config.EnableSSHRoot {
|
||||
if *input.EnableSSHRoot {
|
||||
log.Infof("enabling SSH root login")
|
||||
} else {
|
||||
log.Infof("disabling SSH root login")
|
||||
}
|
||||
config.EnableSSHRoot = input.EnableSSHRoot
|
||||
updated = true
|
||||
}
|
||||
|
||||
if input.EnableSSHSFTP != nil && input.EnableSSHSFTP != config.EnableSSHSFTP {
|
||||
if *input.EnableSSHSFTP {
|
||||
log.Infof("enabling SSH SFTP subsystem")
|
||||
} else {
|
||||
log.Infof("disabling SSH SFTP subsystem")
|
||||
}
|
||||
config.EnableSSHSFTP = input.EnableSSHSFTP
|
||||
updated = true
|
||||
}
|
||||
|
||||
if input.EnableSSHLocalPortForwarding != nil && input.EnableSSHLocalPortForwarding != config.EnableSSHLocalPortForwarding {
|
||||
if *input.EnableSSHLocalPortForwarding {
|
||||
log.Infof("enabling SSH local port forwarding")
|
||||
} else {
|
||||
log.Infof("disabling SSH local port forwarding")
|
||||
}
|
||||
config.EnableSSHLocalPortForwarding = input.EnableSSHLocalPortForwarding
|
||||
updated = true
|
||||
}
|
||||
|
||||
if input.EnableSSHRemotePortForwarding != nil && input.EnableSSHRemotePortForwarding != config.EnableSSHRemotePortForwarding {
|
||||
if *input.EnableSSHRemotePortForwarding {
|
||||
log.Infof("enabling SSH remote port forwarding")
|
||||
} else {
|
||||
log.Infof("disabling SSH remote port forwarding")
|
||||
}
|
||||
config.EnableSSHRemotePortForwarding = input.EnableSSHRemotePortForwarding
|
||||
updated = true
|
||||
}
|
||||
|
||||
if input.DisableSSHAuth != nil && input.DisableSSHAuth != config.DisableSSHAuth {
|
||||
if *input.DisableSSHAuth {
|
||||
log.Infof("disabling SSH authentication")
|
||||
} else {
|
||||
log.Infof("enabling SSH authentication")
|
||||
}
|
||||
config.DisableSSHAuth = input.DisableSSHAuth
|
||||
updated = true
|
||||
}
|
||||
|
||||
if input.SSHJWTCacheTTL != nil && input.SSHJWTCacheTTL != config.SSHJWTCacheTTL {
|
||||
log.Infof("updating SSH JWT cache TTL to %d seconds", *input.SSHJWTCacheTTL)
|
||||
config.SSHJWTCacheTTL = input.SSHJWTCacheTTL
|
||||
updated = true
|
||||
}
|
||||
|
||||
if input.DNSRouteInterval != nil && *input.DNSRouteInterval != config.DNSRouteInterval {
|
||||
log.Infof("updating DNS route interval to %s (old value %s)",
|
||||
input.DNSRouteInterval.String(), config.DNSRouteInterval.String())
|
||||
|
||||
@@ -193,10 +193,10 @@ func TestWireguardPortZeroExplicit(t *testing.T) {
|
||||
|
||||
func TestWireguardPortDefaultVsExplicit(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
wireguardPort *int
|
||||
expectedPort int
|
||||
description string
|
||||
name string
|
||||
wireguardPort *int
|
||||
expectedPort int
|
||||
description string
|
||||
}{
|
||||
{
|
||||
name: "no port specified uses default",
|
||||
|
||||
@@ -132,3 +132,21 @@ func (pm *ProfileManager) setActiveProfileState(profileName string) error {
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetLoginHint retrieves the email from the active profile to use as login_hint.
|
||||
func GetLoginHint() string {
|
||||
pm := NewProfileManager()
|
||||
activeProf, err := pm.GetActiveProfile()
|
||||
if err != nil {
|
||||
log.Debugf("failed to get active profile for login hint: %v", err)
|
||||
return ""
|
||||
}
|
||||
|
||||
profileState, err := pm.GetProfileState(activeProf.Name)
|
||||
if err != nil {
|
||||
log.Debugf("failed to get profile state for login hint: %v", err)
|
||||
return ""
|
||||
}
|
||||
|
||||
return profileState.Email
|
||||
}
|
||||
|
||||
@@ -18,8 +18,8 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/iface"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/util"
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
)
|
||||
|
||||
const (
|
||||
|
||||
@@ -24,7 +24,6 @@ import (
|
||||
"github.com/netbirdio/netbird/client/iface/netstack"
|
||||
"github.com/netbirdio/netbird/client/internal/dns"
|
||||
"github.com/netbirdio/netbird/client/internal/listener"
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/internal/peerstore"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/client"
|
||||
@@ -39,6 +38,7 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal/routeselector"
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
nbnet "github.com/netbirdio/netbird/client/net"
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
relayClient "github.com/netbirdio/netbird/shared/relay/client"
|
||||
"github.com/netbirdio/netbird/version"
|
||||
|
||||
@@ -20,8 +20,8 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||
"github.com/netbirdio/netbird/client/system"
|
||||
"github.com/netbirdio/netbird/formatter"
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
)
|
||||
|
||||
// ConnectionListener export internal Listener for mobile
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -84,6 +84,15 @@ service DaemonService {
|
||||
rpc Logout(LogoutRequest) returns (LogoutResponse) {}
|
||||
|
||||
rpc GetFeatures(GetFeaturesRequest) returns (GetFeaturesResponse) {}
|
||||
|
||||
// GetPeerSSHHostKey retrieves SSH host key for a specific peer
|
||||
rpc GetPeerSSHHostKey(GetPeerSSHHostKeyRequest) returns (GetPeerSSHHostKeyResponse) {}
|
||||
|
||||
// RequestJWTAuth initiates JWT authentication flow for SSH
|
||||
rpc RequestJWTAuth(RequestJWTAuthRequest) returns (RequestJWTAuthResponse) {}
|
||||
|
||||
// WaitJWTToken waits for JWT authentication completion
|
||||
rpc WaitJWTToken(WaitJWTTokenRequest) returns (WaitJWTTokenResponse) {}
|
||||
}
|
||||
|
||||
|
||||
@@ -161,6 +170,13 @@ message LoginRequest {
|
||||
|
||||
// hint is used to pre-fill the email/username field during SSO authentication
|
||||
optional string hint = 33;
|
||||
|
||||
optional bool enableSSHRoot = 34;
|
||||
optional bool enableSSHSFTP = 35;
|
||||
optional bool enableSSHLocalPortForwarding = 36;
|
||||
optional bool enableSSHRemotePortForwarding = 37;
|
||||
optional bool disableSSHAuth = 38;
|
||||
optional int32 sshJWTCacheTTL = 39;
|
||||
}
|
||||
|
||||
message LoginResponse {
|
||||
@@ -188,9 +204,9 @@ message UpResponse {}
|
||||
|
||||
message StatusRequest{
|
||||
bool getFullPeerStatus = 1;
|
||||
bool shouldRunProbes = 2;
|
||||
bool shouldRunProbes = 2;
|
||||
// the UI do not using this yet, but CLIs could use it to wait until the status is ready
|
||||
optional bool waitForReady = 3;
|
||||
optional bool waitForReady = 3;
|
||||
}
|
||||
|
||||
message StatusResponse{
|
||||
@@ -255,6 +271,18 @@ message GetConfigResponse {
|
||||
bool disable_server_routes = 19;
|
||||
|
||||
bool block_lan_access = 20;
|
||||
|
||||
bool enableSSHRoot = 21;
|
||||
|
||||
bool enableSSHSFTP = 24;
|
||||
|
||||
bool enableSSHLocalPortForwarding = 22;
|
||||
|
||||
bool enableSSHRemotePortForwarding = 23;
|
||||
|
||||
bool disableSSHAuth = 25;
|
||||
|
||||
int32 sshJWTCacheTTL = 26;
|
||||
}
|
||||
|
||||
// PeerState contains the latest state of a peer
|
||||
@@ -276,6 +304,7 @@ message PeerState {
|
||||
repeated string networks = 16;
|
||||
google.protobuf.Duration latency = 17;
|
||||
string relayAddress = 18;
|
||||
bytes sshHostKey = 19;
|
||||
}
|
||||
|
||||
// LocalPeerState contains the latest state of the local peer
|
||||
@@ -317,6 +346,20 @@ message NSGroupState {
|
||||
string error = 4;
|
||||
}
|
||||
|
||||
// SSHSessionInfo contains information about an active SSH session
|
||||
message SSHSessionInfo {
|
||||
string username = 1;
|
||||
string remoteAddress = 2;
|
||||
string command = 3;
|
||||
string jwtUsername = 4;
|
||||
}
|
||||
|
||||
// SSHServerState contains the latest state of the SSH server
|
||||
message SSHServerState {
|
||||
bool enabled = 1;
|
||||
repeated SSHSessionInfo sessions = 2;
|
||||
}
|
||||
|
||||
// FullStatus contains the full state held by the Status instance
|
||||
message FullStatus {
|
||||
ManagementState managementState = 1;
|
||||
@@ -330,6 +373,7 @@ message FullStatus {
|
||||
repeated SystemEvent events = 7;
|
||||
|
||||
bool lazyConnectionEnabled = 9;
|
||||
SSHServerState sshServerState = 10;
|
||||
}
|
||||
|
||||
// Networks
|
||||
@@ -543,56 +587,63 @@ message SwitchProfileRequest {
|
||||
message SwitchProfileResponse {}
|
||||
|
||||
message SetConfigRequest {
|
||||
string username = 1;
|
||||
string profileName = 2;
|
||||
// managementUrl to authenticate.
|
||||
string managementUrl = 3;
|
||||
string username = 1;
|
||||
string profileName = 2;
|
||||
// managementUrl to authenticate.
|
||||
string managementUrl = 3;
|
||||
|
||||
// adminUrl to manage keys.
|
||||
string adminURL = 4;
|
||||
// adminUrl to manage keys.
|
||||
string adminURL = 4;
|
||||
|
||||
optional bool rosenpassEnabled = 5;
|
||||
optional bool rosenpassEnabled = 5;
|
||||
|
||||
optional string interfaceName = 6;
|
||||
optional string interfaceName = 6;
|
||||
|
||||
optional int64 wireguardPort = 7;
|
||||
optional int64 wireguardPort = 7;
|
||||
|
||||
optional string optionalPreSharedKey = 8;
|
||||
optional string optionalPreSharedKey = 8;
|
||||
|
||||
optional bool disableAutoConnect = 9;
|
||||
optional bool disableAutoConnect = 9;
|
||||
|
||||
optional bool serverSSHAllowed = 10;
|
||||
optional bool serverSSHAllowed = 10;
|
||||
|
||||
optional bool rosenpassPermissive = 11;
|
||||
optional bool rosenpassPermissive = 11;
|
||||
|
||||
optional bool networkMonitor = 12;
|
||||
optional bool networkMonitor = 12;
|
||||
|
||||
optional bool disable_client_routes = 13;
|
||||
optional bool disable_server_routes = 14;
|
||||
optional bool disable_dns = 15;
|
||||
optional bool disable_firewall = 16;
|
||||
optional bool block_lan_access = 17;
|
||||
optional bool disable_client_routes = 13;
|
||||
optional bool disable_server_routes = 14;
|
||||
optional bool disable_dns = 15;
|
||||
optional bool disable_firewall = 16;
|
||||
optional bool block_lan_access = 17;
|
||||
|
||||
optional bool disable_notifications = 18;
|
||||
optional bool disable_notifications = 18;
|
||||
|
||||
optional bool lazyConnectionEnabled = 19;
|
||||
optional bool lazyConnectionEnabled = 19;
|
||||
|
||||
optional bool block_inbound = 20;
|
||||
optional bool block_inbound = 20;
|
||||
|
||||
repeated string natExternalIPs = 21;
|
||||
bool cleanNATExternalIPs = 22;
|
||||
repeated string natExternalIPs = 21;
|
||||
bool cleanNATExternalIPs = 22;
|
||||
|
||||
bytes customDNSAddress = 23;
|
||||
bytes customDNSAddress = 23;
|
||||
|
||||
repeated string extraIFaceBlacklist = 24;
|
||||
repeated string extraIFaceBlacklist = 24;
|
||||
|
||||
repeated string dns_labels = 25;
|
||||
// cleanDNSLabels clean map list of DNS labels.
|
||||
bool cleanDNSLabels = 26;
|
||||
repeated string dns_labels = 25;
|
||||
// cleanDNSLabels clean map list of DNS labels.
|
||||
bool cleanDNSLabels = 26;
|
||||
|
||||
optional google.protobuf.Duration dnsRouteInterval = 27;
|
||||
optional google.protobuf.Duration dnsRouteInterval = 27;
|
||||
|
||||
optional int64 mtu = 28;
|
||||
optional int64 mtu = 28;
|
||||
|
||||
optional bool enableSSHRoot = 29;
|
||||
optional bool enableSSHSFTP = 30;
|
||||
optional bool enableSSHLocalPortForwarding = 31;
|
||||
optional bool enableSSHRemotePortForwarding = 32;
|
||||
optional bool disableSSHAuth = 33;
|
||||
optional int32 sshJWTCacheTTL = 34;
|
||||
}
|
||||
|
||||
message SetConfigResponse{}
|
||||
@@ -644,3 +695,63 @@ message GetFeaturesResponse{
|
||||
bool disable_profiles = 1;
|
||||
bool disable_update_settings = 2;
|
||||
}
|
||||
|
||||
// GetPeerSSHHostKeyRequest for retrieving SSH host key for a specific peer
|
||||
message GetPeerSSHHostKeyRequest {
|
||||
// peer IP address or FQDN to get SSH host key for
|
||||
string peerAddress = 1;
|
||||
}
|
||||
|
||||
// GetPeerSSHHostKeyResponse contains the SSH host key for the requested peer
|
||||
message GetPeerSSHHostKeyResponse {
|
||||
// SSH host key in SSH public key format (e.g., "ssh-ed25519 AAAAC3... hostname")
|
||||
bytes sshHostKey = 1;
|
||||
// peer IP address
|
||||
string peerIP = 2;
|
||||
// peer FQDN
|
||||
string peerFQDN = 3;
|
||||
// indicates if the SSH host key was found
|
||||
bool found = 4;
|
||||
}
|
||||
|
||||
// RequestJWTAuthRequest for initiating JWT authentication flow
|
||||
message RequestJWTAuthRequest {
|
||||
// hint for OIDC login_hint parameter (typically email address)
|
||||
optional string hint = 1;
|
||||
}
|
||||
|
||||
// RequestJWTAuthResponse contains authentication flow information
|
||||
message RequestJWTAuthResponse {
|
||||
// verification URI for user authentication
|
||||
string verificationURI = 1;
|
||||
// complete verification URI (with embedded user code)
|
||||
string verificationURIComplete = 2;
|
||||
// user code to enter on verification URI
|
||||
string userCode = 3;
|
||||
// device code for polling
|
||||
string deviceCode = 4;
|
||||
// expiration time in seconds
|
||||
int64 expiresIn = 5;
|
||||
// if a cached token is available, it will be returned here
|
||||
string cachedToken = 6;
|
||||
// maximum age of JWT tokens in seconds (from management server)
|
||||
int64 maxTokenAge = 7;
|
||||
}
|
||||
|
||||
// WaitJWTTokenRequest for waiting for authentication completion
|
||||
message WaitJWTTokenRequest {
|
||||
// device code from RequestJWTAuthResponse
|
||||
string deviceCode = 1;
|
||||
// user code for verification
|
||||
string userCode = 2;
|
||||
}
|
||||
|
||||
// WaitJWTTokenResponse contains the JWT token after authentication
|
||||
message WaitJWTTokenResponse {
|
||||
// JWT token (access token or ID token)
|
||||
string token = 1;
|
||||
// token type (e.g., "Bearer")
|
||||
string tokenType = 2;
|
||||
// expiration time in seconds
|
||||
int64 expiresIn = 3;
|
||||
}
|
||||
|
||||
@@ -64,6 +64,12 @@ type DaemonServiceClient interface {
|
||||
// Logout disconnects from the network and deletes the peer from the management server
|
||||
Logout(ctx context.Context, in *LogoutRequest, opts ...grpc.CallOption) (*LogoutResponse, error)
|
||||
GetFeatures(ctx context.Context, in *GetFeaturesRequest, opts ...grpc.CallOption) (*GetFeaturesResponse, error)
|
||||
// GetPeerSSHHostKey retrieves SSH host key for a specific peer
|
||||
GetPeerSSHHostKey(ctx context.Context, in *GetPeerSSHHostKeyRequest, opts ...grpc.CallOption) (*GetPeerSSHHostKeyResponse, error)
|
||||
// RequestJWTAuth initiates JWT authentication flow for SSH
|
||||
RequestJWTAuth(ctx context.Context, in *RequestJWTAuthRequest, opts ...grpc.CallOption) (*RequestJWTAuthResponse, error)
|
||||
// WaitJWTToken waits for JWT authentication completion
|
||||
WaitJWTToken(ctx context.Context, in *WaitJWTTokenRequest, opts ...grpc.CallOption) (*WaitJWTTokenResponse, error)
|
||||
}
|
||||
|
||||
type daemonServiceClient struct {
|
||||
@@ -349,6 +355,33 @@ func (c *daemonServiceClient) GetFeatures(ctx context.Context, in *GetFeaturesRe
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (c *daemonServiceClient) GetPeerSSHHostKey(ctx context.Context, in *GetPeerSSHHostKeyRequest, opts ...grpc.CallOption) (*GetPeerSSHHostKeyResponse, error) {
|
||||
out := new(GetPeerSSHHostKeyResponse)
|
||||
err := c.cc.Invoke(ctx, "/daemon.DaemonService/GetPeerSSHHostKey", in, out, opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (c *daemonServiceClient) RequestJWTAuth(ctx context.Context, in *RequestJWTAuthRequest, opts ...grpc.CallOption) (*RequestJWTAuthResponse, error) {
|
||||
out := new(RequestJWTAuthResponse)
|
||||
err := c.cc.Invoke(ctx, "/daemon.DaemonService/RequestJWTAuth", in, out, opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (c *daemonServiceClient) WaitJWTToken(ctx context.Context, in *WaitJWTTokenRequest, opts ...grpc.CallOption) (*WaitJWTTokenResponse, error) {
|
||||
out := new(WaitJWTTokenResponse)
|
||||
err := c.cc.Invoke(ctx, "/daemon.DaemonService/WaitJWTToken", in, out, opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// DaemonServiceServer is the server API for DaemonService service.
|
||||
// All implementations must embed UnimplementedDaemonServiceServer
|
||||
// for forward compatibility
|
||||
@@ -399,6 +432,12 @@ type DaemonServiceServer interface {
|
||||
// Logout disconnects from the network and deletes the peer from the management server
|
||||
Logout(context.Context, *LogoutRequest) (*LogoutResponse, error)
|
||||
GetFeatures(context.Context, *GetFeaturesRequest) (*GetFeaturesResponse, error)
|
||||
// GetPeerSSHHostKey retrieves SSH host key for a specific peer
|
||||
GetPeerSSHHostKey(context.Context, *GetPeerSSHHostKeyRequest) (*GetPeerSSHHostKeyResponse, error)
|
||||
// RequestJWTAuth initiates JWT authentication flow for SSH
|
||||
RequestJWTAuth(context.Context, *RequestJWTAuthRequest) (*RequestJWTAuthResponse, error)
|
||||
// WaitJWTToken waits for JWT authentication completion
|
||||
WaitJWTToken(context.Context, *WaitJWTTokenRequest) (*WaitJWTTokenResponse, error)
|
||||
mustEmbedUnimplementedDaemonServiceServer()
|
||||
}
|
||||
|
||||
@@ -490,6 +529,15 @@ func (UnimplementedDaemonServiceServer) Logout(context.Context, *LogoutRequest)
|
||||
func (UnimplementedDaemonServiceServer) GetFeatures(context.Context, *GetFeaturesRequest) (*GetFeaturesResponse, error) {
|
||||
return nil, status.Errorf(codes.Unimplemented, "method GetFeatures not implemented")
|
||||
}
|
||||
func (UnimplementedDaemonServiceServer) GetPeerSSHHostKey(context.Context, *GetPeerSSHHostKeyRequest) (*GetPeerSSHHostKeyResponse, error) {
|
||||
return nil, status.Errorf(codes.Unimplemented, "method GetPeerSSHHostKey not implemented")
|
||||
}
|
||||
func (UnimplementedDaemonServiceServer) RequestJWTAuth(context.Context, *RequestJWTAuthRequest) (*RequestJWTAuthResponse, error) {
|
||||
return nil, status.Errorf(codes.Unimplemented, "method RequestJWTAuth not implemented")
|
||||
}
|
||||
func (UnimplementedDaemonServiceServer) WaitJWTToken(context.Context, *WaitJWTTokenRequest) (*WaitJWTTokenResponse, error) {
|
||||
return nil, status.Errorf(codes.Unimplemented, "method WaitJWTToken not implemented")
|
||||
}
|
||||
func (UnimplementedDaemonServiceServer) mustEmbedUnimplementedDaemonServiceServer() {}
|
||||
|
||||
// UnsafeDaemonServiceServer may be embedded to opt out of forward compatibility for this service.
|
||||
@@ -1010,6 +1058,60 @@ func _DaemonService_GetFeatures_Handler(srv interface{}, ctx context.Context, de
|
||||
return interceptor(ctx, in, info, handler)
|
||||
}
|
||||
|
||||
func _DaemonService_GetPeerSSHHostKey_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
|
||||
in := new(GetPeerSSHHostKeyRequest)
|
||||
if err := dec(in); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if interceptor == nil {
|
||||
return srv.(DaemonServiceServer).GetPeerSSHHostKey(ctx, in)
|
||||
}
|
||||
info := &grpc.UnaryServerInfo{
|
||||
Server: srv,
|
||||
FullMethod: "/daemon.DaemonService/GetPeerSSHHostKey",
|
||||
}
|
||||
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
return srv.(DaemonServiceServer).GetPeerSSHHostKey(ctx, req.(*GetPeerSSHHostKeyRequest))
|
||||
}
|
||||
return interceptor(ctx, in, info, handler)
|
||||
}
|
||||
|
||||
func _DaemonService_RequestJWTAuth_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
|
||||
in := new(RequestJWTAuthRequest)
|
||||
if err := dec(in); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if interceptor == nil {
|
||||
return srv.(DaemonServiceServer).RequestJWTAuth(ctx, in)
|
||||
}
|
||||
info := &grpc.UnaryServerInfo{
|
||||
Server: srv,
|
||||
FullMethod: "/daemon.DaemonService/RequestJWTAuth",
|
||||
}
|
||||
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
return srv.(DaemonServiceServer).RequestJWTAuth(ctx, req.(*RequestJWTAuthRequest))
|
||||
}
|
||||
return interceptor(ctx, in, info, handler)
|
||||
}
|
||||
|
||||
func _DaemonService_WaitJWTToken_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
|
||||
in := new(WaitJWTTokenRequest)
|
||||
if err := dec(in); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if interceptor == nil {
|
||||
return srv.(DaemonServiceServer).WaitJWTToken(ctx, in)
|
||||
}
|
||||
info := &grpc.UnaryServerInfo{
|
||||
Server: srv,
|
||||
FullMethod: "/daemon.DaemonService/WaitJWTToken",
|
||||
}
|
||||
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
return srv.(DaemonServiceServer).WaitJWTToken(ctx, req.(*WaitJWTTokenRequest))
|
||||
}
|
||||
return interceptor(ctx, in, info, handler)
|
||||
}
|
||||
|
||||
// DaemonService_ServiceDesc is the grpc.ServiceDesc for DaemonService service.
|
||||
// It's only intended for direct use with grpc.RegisterService,
|
||||
// and not to be introspected or modified (even as a copy)
|
||||
@@ -1125,6 +1227,18 @@ var DaemonService_ServiceDesc = grpc.ServiceDesc{
|
||||
MethodName: "GetFeatures",
|
||||
Handler: _DaemonService_GetFeatures_Handler,
|
||||
},
|
||||
{
|
||||
MethodName: "GetPeerSSHHostKey",
|
||||
Handler: _DaemonService_GetPeerSSHHostKey_Handler,
|
||||
},
|
||||
{
|
||||
MethodName: "RequestJWTAuth",
|
||||
Handler: _DaemonService_RequestJWTAuth_Handler,
|
||||
},
|
||||
{
|
||||
MethodName: "WaitJWTToken",
|
||||
Handler: _DaemonService_WaitJWTToken_Handler,
|
||||
},
|
||||
},
|
||||
Streams: []grpc.StreamDesc{
|
||||
{
|
||||
|
||||
79
client/server/jwt_cache.go
Normal file
79
client/server/jwt_cache.go
Normal file
@@ -0,0 +1,79 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/awnumar/memguard"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
type jwtCache struct {
|
||||
mu sync.RWMutex
|
||||
enclave *memguard.Enclave
|
||||
expiresAt time.Time
|
||||
timer *time.Timer
|
||||
maxTokenSize int
|
||||
}
|
||||
|
||||
func newJWTCache() *jwtCache {
|
||||
return &jwtCache{
|
||||
maxTokenSize: 8192,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *jwtCache) store(token string, maxAge time.Duration) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
c.cleanup()
|
||||
|
||||
if c.timer != nil {
|
||||
c.timer.Stop()
|
||||
}
|
||||
|
||||
tokenBytes := []byte(token)
|
||||
c.enclave = memguard.NewEnclave(tokenBytes)
|
||||
|
||||
c.expiresAt = time.Now().Add(maxAge)
|
||||
|
||||
var timer *time.Timer
|
||||
timer = time.AfterFunc(maxAge, func() {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
if c.timer != timer {
|
||||
return
|
||||
}
|
||||
c.cleanup()
|
||||
c.timer = nil
|
||||
log.Debugf("JWT token cache expired after %v, securely wiped from memory", maxAge)
|
||||
})
|
||||
c.timer = timer
|
||||
}
|
||||
|
||||
func (c *jwtCache) get() (string, bool) {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
|
||||
if c.enclave == nil || time.Now().After(c.expiresAt) {
|
||||
return "", false
|
||||
}
|
||||
|
||||
buffer, err := c.enclave.Open()
|
||||
if err != nil {
|
||||
log.Debugf("Failed to open JWT token enclave: %v", err)
|
||||
return "", false
|
||||
}
|
||||
defer buffer.Destroy()
|
||||
|
||||
token := string(buffer.Bytes())
|
||||
return token, true
|
||||
}
|
||||
|
||||
// cleanup destroys the secure enclave, must be called with lock held
|
||||
func (c *jwtCache) cleanup() {
|
||||
if c.enclave != nil {
|
||||
c.enclave = nil
|
||||
}
|
||||
c.expiresAt = time.Time{}
|
||||
}
|
||||
@@ -11,8 +11,8 @@ import (
|
||||
"golang.org/x/exp/maps"
|
||||
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
)
|
||||
|
||||
type selectRoute struct {
|
||||
|
||||
@@ -46,6 +46,9 @@ const (
|
||||
defaultMaxRetryTime = 14 * 24 * time.Hour
|
||||
defaultRetryMultiplier = 1.7
|
||||
|
||||
// JWT token cache TTL for the client daemon (disabled by default)
|
||||
defaultJWTCacheTTL = 0
|
||||
|
||||
errRestoreResidualState = "failed to restore residual state: %v"
|
||||
errProfilesDisabled = "profiles are disabled, you cannot use this feature without profiles enabled"
|
||||
errUpdateSettingsDisabled = "update settings are disabled, you cannot use this feature without update settings enabled"
|
||||
@@ -81,6 +84,8 @@ type Server struct {
|
||||
profileManager *profilemanager.ServiceManager
|
||||
profilesDisabled bool
|
||||
updateSettingsDisabled bool
|
||||
|
||||
jwtCache *jwtCache
|
||||
}
|
||||
|
||||
type oauthAuthFlow struct {
|
||||
@@ -100,6 +105,7 @@ func New(ctx context.Context, logFile string, configFile string, profilesDisable
|
||||
profileManager: profilemanager.NewServiceManager(configFile),
|
||||
profilesDisabled: profilesDisabled,
|
||||
updateSettingsDisabled: updateSettingsDisabled,
|
||||
jwtCache: newJWTCache(),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -373,6 +379,17 @@ func (s *Server) SetConfig(callerCtx context.Context, msg *proto.SetConfigReques
|
||||
config.DisableNotifications = msg.DisableNotifications
|
||||
config.LazyConnectionEnabled = msg.LazyConnectionEnabled
|
||||
config.BlockInbound = msg.BlockInbound
|
||||
config.EnableSSHRoot = msg.EnableSSHRoot
|
||||
config.EnableSSHSFTP = msg.EnableSSHSFTP
|
||||
config.EnableSSHLocalPortForwarding = msg.EnableSSHLocalPortForwarding
|
||||
config.EnableSSHRemotePortForwarding = msg.EnableSSHRemotePortForwarding
|
||||
if msg.DisableSSHAuth != nil {
|
||||
config.DisableSSHAuth = msg.DisableSSHAuth
|
||||
}
|
||||
if msg.SshJWTCacheTTL != nil {
|
||||
ttl := int(*msg.SshJWTCacheTTL)
|
||||
config.SSHJWTCacheTTL = &ttl
|
||||
}
|
||||
|
||||
if msg.Mtu != nil {
|
||||
mtu := uint16(*msg.Mtu)
|
||||
@@ -493,7 +510,7 @@ func (s *Server) Login(callerCtx context.Context, msg *proto.LoginRequest) (*pro
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if s.oauthAuthFlow.flow != nil && s.oauthAuthFlow.flow.GetClientID(ctx) == oAuthFlow.GetClientID(context.TODO()) {
|
||||
if s.oauthAuthFlow.flow != nil && s.oauthAuthFlow.flow.GetClientID(ctx) == oAuthFlow.GetClientID(ctx) {
|
||||
if s.oauthAuthFlow.expiresAt.After(time.Now().Add(90 * time.Second)) {
|
||||
log.Debugf("using previous oauth flow info")
|
||||
return &proto.LoginResponse{
|
||||
@@ -510,7 +527,7 @@ func (s *Server) Login(callerCtx context.Context, msg *proto.LoginRequest) (*pro
|
||||
}
|
||||
}
|
||||
|
||||
authInfo, err := oAuthFlow.RequestAuthInfo(context.TODO())
|
||||
authInfo, err := oAuthFlow.RequestAuthInfo(ctx)
|
||||
if err != nil {
|
||||
log.Errorf("getting a request OAuth flow failed: %v", err)
|
||||
return nil, err
|
||||
@@ -1065,12 +1082,235 @@ func (s *Server) Status(
|
||||
fullStatus := s.statusRecorder.GetFullStatus()
|
||||
pbFullStatus := toProtoFullStatus(fullStatus)
|
||||
pbFullStatus.Events = s.statusRecorder.GetEventHistory()
|
||||
|
||||
pbFullStatus.SshServerState = s.getSSHServerState()
|
||||
|
||||
statusResponse.FullStatus = pbFullStatus
|
||||
}
|
||||
|
||||
return &statusResponse, nil
|
||||
}
|
||||
|
||||
// getSSHServerState retrieves the current SSH server state including enabled status and active sessions
|
||||
func (s *Server) getSSHServerState() *proto.SSHServerState {
|
||||
s.mutex.Lock()
|
||||
connectClient := s.connectClient
|
||||
s.mutex.Unlock()
|
||||
|
||||
if connectClient == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
engine := connectClient.Engine()
|
||||
if engine == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
enabled, sessions := engine.GetSSHServerStatus()
|
||||
sshServerState := &proto.SSHServerState{
|
||||
Enabled: enabled,
|
||||
}
|
||||
|
||||
for _, session := range sessions {
|
||||
sshServerState.Sessions = append(sshServerState.Sessions, &proto.SSHSessionInfo{
|
||||
Username: session.Username,
|
||||
RemoteAddress: session.RemoteAddress,
|
||||
Command: session.Command,
|
||||
JwtUsername: session.JWTUsername,
|
||||
})
|
||||
}
|
||||
|
||||
return sshServerState
|
||||
}
|
||||
|
||||
// GetPeerSSHHostKey retrieves SSH host key for a specific peer
|
||||
func (s *Server) GetPeerSSHHostKey(
|
||||
ctx context.Context,
|
||||
req *proto.GetPeerSSHHostKeyRequest,
|
||||
) (*proto.GetPeerSSHHostKeyResponse, error) {
|
||||
if ctx.Err() != nil {
|
||||
return nil, ctx.Err()
|
||||
}
|
||||
|
||||
s.mutex.Lock()
|
||||
connectClient := s.connectClient
|
||||
statusRecorder := s.statusRecorder
|
||||
s.mutex.Unlock()
|
||||
|
||||
if connectClient == nil {
|
||||
return nil, errors.New("client not initialized")
|
||||
}
|
||||
|
||||
engine := connectClient.Engine()
|
||||
if engine == nil {
|
||||
return nil, errors.New("engine not started")
|
||||
}
|
||||
|
||||
peerAddress := req.GetPeerAddress()
|
||||
hostKey, found := engine.GetPeerSSHKey(peerAddress)
|
||||
|
||||
response := &proto.GetPeerSSHHostKeyResponse{
|
||||
Found: found,
|
||||
}
|
||||
|
||||
if !found {
|
||||
return response, nil
|
||||
}
|
||||
|
||||
response.SshHostKey = hostKey
|
||||
|
||||
if statusRecorder == nil {
|
||||
return response, nil
|
||||
}
|
||||
|
||||
fullStatus := statusRecorder.GetFullStatus()
|
||||
for _, peerState := range fullStatus.Peers {
|
||||
if peerState.IP == peerAddress || peerState.FQDN == peerAddress {
|
||||
response.PeerIP = peerState.IP
|
||||
response.PeerFQDN = peerState.FQDN
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return response, nil
|
||||
}
|
||||
|
||||
// getJWTCacheTTL returns the JWT cache TTL from config or default (disabled)
|
||||
func (s *Server) getJWTCacheTTL() time.Duration {
|
||||
s.mutex.Lock()
|
||||
config := s.config
|
||||
s.mutex.Unlock()
|
||||
|
||||
if config == nil || config.SSHJWTCacheTTL == nil {
|
||||
return defaultJWTCacheTTL
|
||||
}
|
||||
|
||||
seconds := *config.SSHJWTCacheTTL
|
||||
if seconds == 0 {
|
||||
log.Debug("SSH JWT cache disabled (configured to 0)")
|
||||
return 0
|
||||
}
|
||||
|
||||
ttl := time.Duration(seconds) * time.Second
|
||||
log.Debugf("SSH JWT cache TTL set to %v from config", ttl)
|
||||
return ttl
|
||||
}
|
||||
|
||||
// RequestJWTAuth initiates JWT authentication flow for SSH
|
||||
func (s *Server) RequestJWTAuth(
|
||||
ctx context.Context,
|
||||
msg *proto.RequestJWTAuthRequest,
|
||||
) (*proto.RequestJWTAuthResponse, error) {
|
||||
if ctx.Err() != nil {
|
||||
return nil, ctx.Err()
|
||||
}
|
||||
|
||||
s.mutex.Lock()
|
||||
config := s.config
|
||||
s.mutex.Unlock()
|
||||
|
||||
if config == nil {
|
||||
return nil, gstatus.Errorf(codes.FailedPrecondition, "client is not configured")
|
||||
}
|
||||
|
||||
jwtCacheTTL := s.getJWTCacheTTL()
|
||||
if jwtCacheTTL > 0 {
|
||||
if cachedToken, found := s.jwtCache.get(); found {
|
||||
log.Debugf("JWT token found in cache, returning cached token for SSH authentication")
|
||||
|
||||
return &proto.RequestJWTAuthResponse{
|
||||
CachedToken: cachedToken,
|
||||
MaxTokenAge: int64(jwtCacheTTL.Seconds()),
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
|
||||
hint := ""
|
||||
if msg.Hint != nil {
|
||||
hint = *msg.Hint
|
||||
}
|
||||
|
||||
if hint == "" {
|
||||
hint = profilemanager.GetLoginHint()
|
||||
}
|
||||
|
||||
isDesktop := isUnixRunningDesktop()
|
||||
oAuthFlow, err := auth.NewOAuthFlow(ctx, config, isDesktop, hint)
|
||||
if err != nil {
|
||||
return nil, gstatus.Errorf(codes.Internal, "failed to create OAuth flow: %v", err)
|
||||
}
|
||||
|
||||
authInfo, err := oAuthFlow.RequestAuthInfo(ctx)
|
||||
if err != nil {
|
||||
return nil, gstatus.Errorf(codes.Internal, "failed to request auth info: %v", err)
|
||||
}
|
||||
|
||||
s.mutex.Lock()
|
||||
s.oauthAuthFlow.flow = oAuthFlow
|
||||
s.oauthAuthFlow.info = authInfo
|
||||
s.oauthAuthFlow.expiresAt = time.Now().Add(time.Duration(authInfo.ExpiresIn) * time.Second)
|
||||
s.mutex.Unlock()
|
||||
|
||||
return &proto.RequestJWTAuthResponse{
|
||||
VerificationURI: authInfo.VerificationURI,
|
||||
VerificationURIComplete: authInfo.VerificationURIComplete,
|
||||
UserCode: authInfo.UserCode,
|
||||
DeviceCode: authInfo.DeviceCode,
|
||||
ExpiresIn: int64(authInfo.ExpiresIn),
|
||||
MaxTokenAge: int64(jwtCacheTTL.Seconds()),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// WaitJWTToken waits for JWT authentication completion
|
||||
func (s *Server) WaitJWTToken(
|
||||
ctx context.Context,
|
||||
req *proto.WaitJWTTokenRequest,
|
||||
) (*proto.WaitJWTTokenResponse, error) {
|
||||
if ctx.Err() != nil {
|
||||
return nil, ctx.Err()
|
||||
}
|
||||
|
||||
s.mutex.Lock()
|
||||
oAuthFlow := s.oauthAuthFlow.flow
|
||||
authInfo := s.oauthAuthFlow.info
|
||||
s.mutex.Unlock()
|
||||
|
||||
if oAuthFlow == nil || authInfo.DeviceCode != req.DeviceCode {
|
||||
return nil, gstatus.Errorf(codes.InvalidArgument, "invalid device code or no active auth flow")
|
||||
}
|
||||
|
||||
tokenInfo, err := oAuthFlow.WaitToken(ctx, authInfo)
|
||||
if err != nil {
|
||||
return nil, gstatus.Errorf(codes.Internal, "failed to get token: %v", err)
|
||||
}
|
||||
|
||||
token := tokenInfo.GetTokenToUse()
|
||||
|
||||
jwtCacheTTL := s.getJWTCacheTTL()
|
||||
if jwtCacheTTL > 0 {
|
||||
s.jwtCache.store(token, jwtCacheTTL)
|
||||
log.Debugf("JWT token cached for SSH authentication, TTL: %v", jwtCacheTTL)
|
||||
} else {
|
||||
log.Debug("JWT caching disabled, not storing token")
|
||||
}
|
||||
|
||||
s.mutex.Lock()
|
||||
s.oauthAuthFlow = oauthAuthFlow{}
|
||||
s.mutex.Unlock()
|
||||
return &proto.WaitJWTTokenResponse{
|
||||
Token: tokenInfo.GetTokenToUse(),
|
||||
TokenType: tokenInfo.TokenType,
|
||||
ExpiresIn: int64(tokenInfo.ExpiresIn),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func isUnixRunningDesktop() bool {
|
||||
if runtime.GOOS != "linux" && runtime.GOOS != "freebsd" {
|
||||
return false
|
||||
}
|
||||
return os.Getenv("DESKTOP_SESSION") != "" || os.Getenv("XDG_CURRENT_DESKTOP") != ""
|
||||
}
|
||||
|
||||
func (s *Server) runProbes(waitForProbeResult bool) {
|
||||
if s.connectClient == nil {
|
||||
return
|
||||
@@ -1136,25 +1376,61 @@ func (s *Server) GetConfig(ctx context.Context, req *proto.GetConfigRequest) (*p
|
||||
disableServerRoutes := cfg.DisableServerRoutes
|
||||
blockLANAccess := cfg.BlockLANAccess
|
||||
|
||||
enableSSHRoot := false
|
||||
if cfg.EnableSSHRoot != nil {
|
||||
enableSSHRoot = *cfg.EnableSSHRoot
|
||||
}
|
||||
|
||||
enableSSHSFTP := false
|
||||
if cfg.EnableSSHSFTP != nil {
|
||||
enableSSHSFTP = *cfg.EnableSSHSFTP
|
||||
}
|
||||
|
||||
enableSSHLocalPortForwarding := false
|
||||
if cfg.EnableSSHLocalPortForwarding != nil {
|
||||
enableSSHLocalPortForwarding = *cfg.EnableSSHLocalPortForwarding
|
||||
}
|
||||
|
||||
enableSSHRemotePortForwarding := false
|
||||
if cfg.EnableSSHRemotePortForwarding != nil {
|
||||
enableSSHRemotePortForwarding = *cfg.EnableSSHRemotePortForwarding
|
||||
}
|
||||
|
||||
disableSSHAuth := false
|
||||
if cfg.DisableSSHAuth != nil {
|
||||
disableSSHAuth = *cfg.DisableSSHAuth
|
||||
}
|
||||
|
||||
sshJWTCacheTTL := int32(0)
|
||||
if cfg.SSHJWTCacheTTL != nil {
|
||||
sshJWTCacheTTL = int32(*cfg.SSHJWTCacheTTL)
|
||||
}
|
||||
|
||||
return &proto.GetConfigResponse{
|
||||
ManagementUrl: managementURL.String(),
|
||||
PreSharedKey: preSharedKey,
|
||||
AdminURL: adminURL.String(),
|
||||
InterfaceName: cfg.WgIface,
|
||||
WireguardPort: int64(cfg.WgPort),
|
||||
Mtu: int64(cfg.MTU),
|
||||
DisableAutoConnect: cfg.DisableAutoConnect,
|
||||
ServerSSHAllowed: *cfg.ServerSSHAllowed,
|
||||
RosenpassEnabled: cfg.RosenpassEnabled,
|
||||
RosenpassPermissive: cfg.RosenpassPermissive,
|
||||
LazyConnectionEnabled: cfg.LazyConnectionEnabled,
|
||||
BlockInbound: cfg.BlockInbound,
|
||||
DisableNotifications: disableNotifications,
|
||||
NetworkMonitor: networkMonitor,
|
||||
DisableDns: disableDNS,
|
||||
DisableClientRoutes: disableClientRoutes,
|
||||
DisableServerRoutes: disableServerRoutes,
|
||||
BlockLanAccess: blockLANAccess,
|
||||
ManagementUrl: managementURL.String(),
|
||||
PreSharedKey: preSharedKey,
|
||||
AdminURL: adminURL.String(),
|
||||
InterfaceName: cfg.WgIface,
|
||||
WireguardPort: int64(cfg.WgPort),
|
||||
Mtu: int64(cfg.MTU),
|
||||
DisableAutoConnect: cfg.DisableAutoConnect,
|
||||
ServerSSHAllowed: *cfg.ServerSSHAllowed,
|
||||
RosenpassEnabled: cfg.RosenpassEnabled,
|
||||
RosenpassPermissive: cfg.RosenpassPermissive,
|
||||
LazyConnectionEnabled: cfg.LazyConnectionEnabled,
|
||||
BlockInbound: cfg.BlockInbound,
|
||||
DisableNotifications: disableNotifications,
|
||||
NetworkMonitor: networkMonitor,
|
||||
DisableDns: disableDNS,
|
||||
DisableClientRoutes: disableClientRoutes,
|
||||
DisableServerRoutes: disableServerRoutes,
|
||||
BlockLanAccess: blockLANAccess,
|
||||
EnableSSHRoot: enableSSHRoot,
|
||||
EnableSSHSFTP: enableSSHSFTP,
|
||||
EnableSSHLocalPortForwarding: enableSSHLocalPortForwarding,
|
||||
EnableSSHRemotePortForwarding: enableSSHRemotePortForwarding,
|
||||
DisableSSHAuth: disableSSHAuth,
|
||||
SshJWTCacheTTL: sshJWTCacheTTL,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -1385,6 +1661,7 @@ func toProtoFullStatus(fullStatus peer.FullStatus) *proto.FullStatus {
|
||||
RosenpassEnabled: peerState.RosenpassEnabled,
|
||||
Networks: maps.Keys(peerState.GetRoutes()),
|
||||
Latency: durationpb.New(peerState.Latency),
|
||||
SshHostKey: peerState.SSHHostKey,
|
||||
}
|
||||
pbFullStatus.Peers = append(pbFullStatus.Peers, pbPeerState)
|
||||
}
|
||||
|
||||
@@ -14,6 +14,7 @@ import (
|
||||
"go.opentelemetry.io/otel"
|
||||
|
||||
"github.com/netbirdio/management-integrations/integrations"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
|
||||
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
|
||||
@@ -316,7 +317,7 @@ func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Serve
|
||||
requestBuffer := server.NewAccountRequestBuffer(context.Background(), store)
|
||||
peersUpdateManager := update_channel.NewPeersUpdateManager(metrics)
|
||||
networkMapController := controller.NewController(context.Background(), store, metrics, peersUpdateManager, requestBuffer, server.MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock())
|
||||
accountManager, err := server.BuildManager(context.Background(), store, networkMapController, nil, "", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false)
|
||||
accountManager, err := server.BuildManager(context.Background(), config, store, networkMapController, nil, "", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
@@ -72,6 +72,7 @@ func TestSetConfig_AllFieldsSaved(t *testing.T) {
|
||||
lazyConnectionEnabled := true
|
||||
blockInbound := true
|
||||
mtu := int64(1280)
|
||||
sshJWTCacheTTL := int32(300)
|
||||
|
||||
req := &proto.SetConfigRequest{
|
||||
ProfileName: profName,
|
||||
@@ -102,6 +103,7 @@ func TestSetConfig_AllFieldsSaved(t *testing.T) {
|
||||
CleanDNSLabels: false,
|
||||
DnsRouteInterval: durationpb.New(2 * time.Minute),
|
||||
Mtu: &mtu,
|
||||
SshJWTCacheTTL: &sshJWTCacheTTL,
|
||||
}
|
||||
|
||||
_, err = s.SetConfig(ctx, req)
|
||||
@@ -146,6 +148,8 @@ func TestSetConfig_AllFieldsSaved(t *testing.T) {
|
||||
require.Equal(t, []string{"label1", "label2"}, cfg.DNSLabels.ToPunycodeList())
|
||||
require.Equal(t, 2*time.Minute, cfg.DNSRouteInterval)
|
||||
require.Equal(t, uint16(mtu), cfg.MTU)
|
||||
require.NotNil(t, cfg.SSHJWTCacheTTL)
|
||||
require.Equal(t, int(sshJWTCacheTTL), *cfg.SSHJWTCacheTTL)
|
||||
|
||||
verifyAllFieldsCovered(t, req)
|
||||
}
|
||||
@@ -167,30 +171,36 @@ func verifyAllFieldsCovered(t *testing.T, req *proto.SetConfigRequest) {
|
||||
}
|
||||
|
||||
expectedFields := map[string]bool{
|
||||
"ManagementUrl": true,
|
||||
"AdminURL": true,
|
||||
"RosenpassEnabled": true,
|
||||
"RosenpassPermissive": true,
|
||||
"ServerSSHAllowed": true,
|
||||
"InterfaceName": true,
|
||||
"WireguardPort": true,
|
||||
"OptionalPreSharedKey": true,
|
||||
"DisableAutoConnect": true,
|
||||
"NetworkMonitor": true,
|
||||
"DisableClientRoutes": true,
|
||||
"DisableServerRoutes": true,
|
||||
"DisableDns": true,
|
||||
"DisableFirewall": true,
|
||||
"BlockLanAccess": true,
|
||||
"DisableNotifications": true,
|
||||
"LazyConnectionEnabled": true,
|
||||
"BlockInbound": true,
|
||||
"NatExternalIPs": true,
|
||||
"CustomDNSAddress": true,
|
||||
"ExtraIFaceBlacklist": true,
|
||||
"DnsLabels": true,
|
||||
"DnsRouteInterval": true,
|
||||
"Mtu": true,
|
||||
"ManagementUrl": true,
|
||||
"AdminURL": true,
|
||||
"RosenpassEnabled": true,
|
||||
"RosenpassPermissive": true,
|
||||
"ServerSSHAllowed": true,
|
||||
"InterfaceName": true,
|
||||
"WireguardPort": true,
|
||||
"OptionalPreSharedKey": true,
|
||||
"DisableAutoConnect": true,
|
||||
"NetworkMonitor": true,
|
||||
"DisableClientRoutes": true,
|
||||
"DisableServerRoutes": true,
|
||||
"DisableDns": true,
|
||||
"DisableFirewall": true,
|
||||
"BlockLanAccess": true,
|
||||
"DisableNotifications": true,
|
||||
"LazyConnectionEnabled": true,
|
||||
"BlockInbound": true,
|
||||
"NatExternalIPs": true,
|
||||
"CustomDNSAddress": true,
|
||||
"ExtraIFaceBlacklist": true,
|
||||
"DnsLabels": true,
|
||||
"DnsRouteInterval": true,
|
||||
"Mtu": true,
|
||||
"EnableSSHRoot": true,
|
||||
"EnableSSHSFTP": true,
|
||||
"EnableSSHLocalPortForwarding": true,
|
||||
"EnableSSHRemotePortForwarding": true,
|
||||
"DisableSSHAuth": true,
|
||||
"SshJWTCacheTTL": true,
|
||||
}
|
||||
|
||||
val := reflect.ValueOf(req).Elem()
|
||||
@@ -221,29 +231,35 @@ func TestCLIFlags_MappedToSetConfig(t *testing.T) {
|
||||
// Map of CLI flag names to their corresponding SetConfigRequest field names.
|
||||
// This map must be updated when adding new config-related CLI flags.
|
||||
flagToField := map[string]string{
|
||||
"management-url": "ManagementUrl",
|
||||
"admin-url": "AdminURL",
|
||||
"enable-rosenpass": "RosenpassEnabled",
|
||||
"rosenpass-permissive": "RosenpassPermissive",
|
||||
"allow-server-ssh": "ServerSSHAllowed",
|
||||
"interface-name": "InterfaceName",
|
||||
"wireguard-port": "WireguardPort",
|
||||
"preshared-key": "OptionalPreSharedKey",
|
||||
"disable-auto-connect": "DisableAutoConnect",
|
||||
"network-monitor": "NetworkMonitor",
|
||||
"disable-client-routes": "DisableClientRoutes",
|
||||
"disable-server-routes": "DisableServerRoutes",
|
||||
"disable-dns": "DisableDns",
|
||||
"disable-firewall": "DisableFirewall",
|
||||
"block-lan-access": "BlockLanAccess",
|
||||
"block-inbound": "BlockInbound",
|
||||
"enable-lazy-connection": "LazyConnectionEnabled",
|
||||
"external-ip-map": "NatExternalIPs",
|
||||
"dns-resolver-address": "CustomDNSAddress",
|
||||
"extra-iface-blacklist": "ExtraIFaceBlacklist",
|
||||
"extra-dns-labels": "DnsLabels",
|
||||
"dns-router-interval": "DnsRouteInterval",
|
||||
"mtu": "Mtu",
|
||||
"management-url": "ManagementUrl",
|
||||
"admin-url": "AdminURL",
|
||||
"enable-rosenpass": "RosenpassEnabled",
|
||||
"rosenpass-permissive": "RosenpassPermissive",
|
||||
"allow-server-ssh": "ServerSSHAllowed",
|
||||
"interface-name": "InterfaceName",
|
||||
"wireguard-port": "WireguardPort",
|
||||
"preshared-key": "OptionalPreSharedKey",
|
||||
"disable-auto-connect": "DisableAutoConnect",
|
||||
"network-monitor": "NetworkMonitor",
|
||||
"disable-client-routes": "DisableClientRoutes",
|
||||
"disable-server-routes": "DisableServerRoutes",
|
||||
"disable-dns": "DisableDns",
|
||||
"disable-firewall": "DisableFirewall",
|
||||
"block-lan-access": "BlockLanAccess",
|
||||
"block-inbound": "BlockInbound",
|
||||
"enable-lazy-connection": "LazyConnectionEnabled",
|
||||
"external-ip-map": "NatExternalIPs",
|
||||
"dns-resolver-address": "CustomDNSAddress",
|
||||
"extra-iface-blacklist": "ExtraIFaceBlacklist",
|
||||
"extra-dns-labels": "DnsLabels",
|
||||
"dns-router-interval": "DnsRouteInterval",
|
||||
"mtu": "Mtu",
|
||||
"enable-ssh-root": "EnableSSHRoot",
|
||||
"enable-ssh-sftp": "EnableSSHSFTP",
|
||||
"enable-ssh-local-port-forwarding": "EnableSSHLocalPortForwarding",
|
||||
"enable-ssh-remote-port-forwarding": "EnableSSHRemotePortForwarding",
|
||||
"disable-ssh-auth": "DisableSSHAuth",
|
||||
"ssh-jwt-cache-ttl": "SshJWTCacheTTL",
|
||||
}
|
||||
|
||||
// SetConfigRequest fields that don't have CLI flags (settable only via UI or other means).
|
||||
|
||||
@@ -6,9 +6,11 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal/dns"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
"github.com/netbirdio/netbird/client/ssh/config"
|
||||
)
|
||||
|
||||
func registerStates(mgr *statemanager.Manager) {
|
||||
mgr.RegisterState(&dns.ShutdownState{})
|
||||
mgr.RegisterState(&systemops.ShutdownState{})
|
||||
mgr.RegisterState(&config.ShutdownState{})
|
||||
}
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal/dns"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
"github.com/netbirdio/netbird/client/ssh/config"
|
||||
)
|
||||
|
||||
func registerStates(mgr *statemanager.Manager) {
|
||||
@@ -15,4 +16,5 @@ func registerStates(mgr *statemanager.Manager) {
|
||||
mgr.RegisterState(&systemops.ShutdownState{})
|
||||
mgr.RegisterState(&nftables.ShutdownState{})
|
||||
mgr.RegisterState(&iptables.ShutdownState{})
|
||||
mgr.RegisterState(&config.ShutdownState{})
|
||||
}
|
||||
|
||||
@@ -1,118 +0,0 @@
|
||||
//go:build !js
|
||||
|
||||
package ssh
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"golang.org/x/crypto/ssh"
|
||||
"golang.org/x/term"
|
||||
)
|
||||
|
||||
// Client wraps crypto/ssh Client to simplify usage
|
||||
type Client struct {
|
||||
client *ssh.Client
|
||||
}
|
||||
|
||||
// Close closes the wrapped SSH Client
|
||||
func (c *Client) Close() error {
|
||||
return c.client.Close()
|
||||
}
|
||||
|
||||
// OpenTerminal starts an interactive terminal session with the remote SSH server
|
||||
func (c *Client) OpenTerminal() error {
|
||||
session, err := c.client.NewSession()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to open new session: %v", err)
|
||||
}
|
||||
defer func() {
|
||||
err := session.Close()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}()
|
||||
|
||||
fd := int(os.Stdout.Fd())
|
||||
state, err := term.MakeRaw(fd)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to run raw terminal: %s", err)
|
||||
}
|
||||
defer func() {
|
||||
err := term.Restore(fd, state)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}()
|
||||
|
||||
w, h, err := term.GetSize(fd)
|
||||
if err != nil {
|
||||
return fmt.Errorf("terminal get size: %s", err)
|
||||
}
|
||||
|
||||
modes := ssh.TerminalModes{
|
||||
ssh.ECHO: 1,
|
||||
ssh.TTY_OP_ISPEED: 14400,
|
||||
ssh.TTY_OP_OSPEED: 14400,
|
||||
}
|
||||
|
||||
terminal := os.Getenv("TERM")
|
||||
if terminal == "" {
|
||||
terminal = "xterm-256color"
|
||||
}
|
||||
if err := session.RequestPty(terminal, h, w, modes); err != nil {
|
||||
return fmt.Errorf("failed requesting pty session with xterm: %s", err)
|
||||
}
|
||||
|
||||
session.Stdout = os.Stdout
|
||||
session.Stderr = os.Stderr
|
||||
session.Stdin = os.Stdin
|
||||
|
||||
if err := session.Shell(); err != nil {
|
||||
return fmt.Errorf("failed to start login shell on the remote host: %s", err)
|
||||
}
|
||||
|
||||
if err := session.Wait(); err != nil {
|
||||
if e, ok := err.(*ssh.ExitError); ok {
|
||||
if e.ExitStatus() == 130 {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
return fmt.Errorf("failed running SSH session: %s", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// DialWithKey connects to the remote SSH server with a provided private key file (PEM).
|
||||
func DialWithKey(addr, user string, privateKey []byte) (*Client, error) {
|
||||
|
||||
signer, err := ssh.ParsePrivateKey(privateKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
config := &ssh.ClientConfig{
|
||||
User: user,
|
||||
Timeout: 5 * time.Second,
|
||||
Auth: []ssh.AuthMethod{
|
||||
ssh.PublicKeys(signer),
|
||||
},
|
||||
HostKeyCallback: ssh.HostKeyCallback(func(hostname string, remote net.Addr, key ssh.PublicKey) error { return nil }),
|
||||
}
|
||||
|
||||
return Dial("tcp", addr, config)
|
||||
}
|
||||
|
||||
// Dial connects to the remote SSH server.
|
||||
func Dial(network, addr string, config *ssh.ClientConfig) (*Client, error) {
|
||||
client, err := ssh.Dial(network, addr, config)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &Client{
|
||||
client: client,
|
||||
}, nil
|
||||
}
|
||||
699
client/ssh/client/client.go
Normal file
699
client/ssh/client/client.go
Normal file
@@ -0,0 +1,699 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/crypto/ssh"
|
||||
"golang.org/x/crypto/ssh/knownhosts"
|
||||
"golang.org/x/term"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/credentials/insecure"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
nbssh "github.com/netbirdio/netbird/client/ssh"
|
||||
"github.com/netbirdio/netbird/client/ssh/detection"
|
||||
)
|
||||
|
||||
const (
|
||||
// DefaultDaemonAddr is the default address for the NetBird daemon
|
||||
DefaultDaemonAddr = "unix:///var/run/netbird.sock"
|
||||
// DefaultDaemonAddrWindows is the default address for the NetBird daemon on Windows
|
||||
DefaultDaemonAddrWindows = "tcp://127.0.0.1:41731"
|
||||
)
|
||||
|
||||
// Client wraps crypto/ssh Client for simplified SSH operations
|
||||
type Client struct {
|
||||
client *ssh.Client
|
||||
terminalState *term.State
|
||||
terminalFd int
|
||||
|
||||
windowsStdoutMode uint32 // nolint:unused
|
||||
windowsStdinMode uint32 // nolint:unused
|
||||
}
|
||||
|
||||
func (c *Client) Close() error {
|
||||
return c.client.Close()
|
||||
}
|
||||
|
||||
func (c *Client) OpenTerminal(ctx context.Context) error {
|
||||
session, err := c.client.NewSession()
|
||||
if err != nil {
|
||||
return fmt.Errorf("new session: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
if err := session.Close(); err != nil {
|
||||
log.Debugf("session close error: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
if err := c.setupTerminalMode(ctx, session); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
c.setupSessionIO(session)
|
||||
|
||||
if err := session.Shell(); err != nil {
|
||||
return fmt.Errorf("start shell: %w", err)
|
||||
}
|
||||
|
||||
return c.waitForSession(ctx, session)
|
||||
}
|
||||
|
||||
// setupSessionIO connects session streams to local terminal
|
||||
func (c *Client) setupSessionIO(session *ssh.Session) {
|
||||
session.Stdout = os.Stdout
|
||||
session.Stderr = os.Stderr
|
||||
session.Stdin = os.Stdin
|
||||
}
|
||||
|
||||
// waitForSession waits for the session to complete with context cancellation
|
||||
func (c *Client) waitForSession(ctx context.Context, session *ssh.Session) error {
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
done <- session.Wait()
|
||||
}()
|
||||
|
||||
defer c.restoreTerminal()
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case err := <-done:
|
||||
return c.handleSessionError(err)
|
||||
}
|
||||
}
|
||||
|
||||
// handleSessionError processes session termination errors
|
||||
func (c *Client) handleSessionError(err error) error {
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
var e *ssh.ExitError
|
||||
var em *ssh.ExitMissingError
|
||||
if !errors.As(err, &e) && !errors.As(err, &em) {
|
||||
return fmt.Errorf("session wait: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// restoreTerminal restores the terminal to its original state
|
||||
func (c *Client) restoreTerminal() {
|
||||
if c.terminalState != nil {
|
||||
_ = term.Restore(c.terminalFd, c.terminalState)
|
||||
c.terminalState = nil
|
||||
c.terminalFd = 0
|
||||
}
|
||||
|
||||
if err := c.restoreWindowsConsoleState(); err != nil {
|
||||
log.Debugf("restore Windows console state: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// ExecuteCommand executes a command on the remote host and returns the output
|
||||
func (c *Client) ExecuteCommand(ctx context.Context, command string) ([]byte, error) {
|
||||
session, cleanup, err := c.createSession(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer cleanup()
|
||||
|
||||
output, err := session.CombinedOutput(command)
|
||||
if err != nil {
|
||||
var e *ssh.ExitError
|
||||
var em *ssh.ExitMissingError
|
||||
if !errors.As(err, &e) && !errors.As(err, &em) {
|
||||
return output, fmt.Errorf("execute command: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return output, nil
|
||||
}
|
||||
|
||||
// ExecuteCommandWithIO executes a command with interactive I/O connected to local terminal
|
||||
func (c *Client) ExecuteCommandWithIO(ctx context.Context, command string) error {
|
||||
session, cleanup, err := c.createSession(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create session: %w", err)
|
||||
}
|
||||
defer cleanup()
|
||||
|
||||
c.setupSessionIO(session)
|
||||
|
||||
if err := session.Start(command); err != nil {
|
||||
return fmt.Errorf("start command: %w", err)
|
||||
}
|
||||
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
done <- session.Wait()
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
_ = session.Signal(ssh.SIGTERM)
|
||||
select {
|
||||
case <-done:
|
||||
return ctx.Err()
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
return ctx.Err()
|
||||
}
|
||||
case err := <-done:
|
||||
return c.handleCommandError(err)
|
||||
}
|
||||
}
|
||||
|
||||
// ExecuteCommandWithPTY executes a command with a pseudo-terminal for interactive sessions
|
||||
func (c *Client) ExecuteCommandWithPTY(ctx context.Context, command string) error {
|
||||
session, cleanup, err := c.createSession(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create session: %w", err)
|
||||
}
|
||||
defer cleanup()
|
||||
|
||||
if err := c.setupTerminalMode(ctx, session); err != nil {
|
||||
return fmt.Errorf("setup terminal mode: %w", err)
|
||||
}
|
||||
|
||||
c.setupSessionIO(session)
|
||||
|
||||
if err := session.Start(command); err != nil {
|
||||
return fmt.Errorf("start command: %w", err)
|
||||
}
|
||||
|
||||
defer c.restoreTerminal()
|
||||
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
done <- session.Wait()
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
_ = session.Signal(ssh.SIGTERM)
|
||||
select {
|
||||
case <-done:
|
||||
return ctx.Err()
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
return ctx.Err()
|
||||
}
|
||||
case err := <-done:
|
||||
return c.handleCommandError(err)
|
||||
}
|
||||
}
|
||||
|
||||
// handleCommandError processes command execution errors
|
||||
func (c *Client) handleCommandError(err error) error {
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
var e *ssh.ExitError
|
||||
var em *ssh.ExitMissingError
|
||||
if errors.As(err, &e) || errors.As(err, &em) {
|
||||
return err
|
||||
}
|
||||
|
||||
return fmt.Errorf("execute command: %w", err)
|
||||
}
|
||||
|
||||
// setupContextCancellation sets up context cancellation for a session
|
||||
func (c *Client) setupContextCancellation(ctx context.Context, session *ssh.Session) func() {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
_ = session.Signal(ssh.SIGTERM)
|
||||
_ = session.Close()
|
||||
case <-done:
|
||||
}
|
||||
}()
|
||||
return func() { close(done) }
|
||||
}
|
||||
|
||||
// createSession creates a new SSH session with context cancellation setup
|
||||
func (c *Client) createSession(ctx context.Context) (*ssh.Session, func(), error) {
|
||||
session, err := c.client.NewSession()
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("new session: %w", err)
|
||||
}
|
||||
|
||||
cancel := c.setupContextCancellation(ctx, session)
|
||||
cleanup := func() {
|
||||
cancel()
|
||||
_ = session.Close()
|
||||
}
|
||||
|
||||
return session, cleanup, nil
|
||||
}
|
||||
|
||||
// getDefaultDaemonAddr returns the daemon address from environment or default for the OS
|
||||
func getDefaultDaemonAddr() string {
|
||||
if addr := os.Getenv("NB_DAEMON_ADDR"); addr != "" {
|
||||
return addr
|
||||
}
|
||||
if runtime.GOOS == "windows" {
|
||||
return DefaultDaemonAddrWindows
|
||||
}
|
||||
return DefaultDaemonAddr
|
||||
}
|
||||
|
||||
// DialOptions contains options for SSH connections
|
||||
type DialOptions struct {
|
||||
KnownHostsFile string
|
||||
IdentityFile string
|
||||
DaemonAddr string
|
||||
SkipCachedToken bool
|
||||
InsecureSkipVerify bool
|
||||
}
|
||||
|
||||
// Dial connects to the given ssh server with specified options
|
||||
func Dial(ctx context.Context, addr, user string, opts DialOptions) (*Client, error) {
|
||||
daemonAddr := opts.DaemonAddr
|
||||
if daemonAddr == "" {
|
||||
daemonAddr = getDefaultDaemonAddr()
|
||||
}
|
||||
opts.DaemonAddr = daemonAddr
|
||||
|
||||
hostKeyCallback, err := createHostKeyCallback(opts)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create host key callback: %w", err)
|
||||
}
|
||||
|
||||
config := &ssh.ClientConfig{
|
||||
User: user,
|
||||
Timeout: 30 * time.Second,
|
||||
HostKeyCallback: hostKeyCallback,
|
||||
}
|
||||
|
||||
if opts.IdentityFile != "" {
|
||||
authMethod, err := createSSHKeyAuth(opts.IdentityFile)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create SSH key auth: %w", err)
|
||||
}
|
||||
config.Auth = append(config.Auth, authMethod)
|
||||
}
|
||||
|
||||
return dialWithJWT(ctx, "tcp", addr, config, daemonAddr, opts.SkipCachedToken)
|
||||
}
|
||||
|
||||
// dialSSH establishes an SSH connection without JWT authentication
|
||||
func dialSSH(ctx context.Context, network, addr string, config *ssh.ClientConfig) (*Client, error) {
|
||||
dialer := &net.Dialer{}
|
||||
conn, err := dialer.DialContext(ctx, network, addr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("dial %s: %w", addr, err)
|
||||
}
|
||||
|
||||
clientConn, chans, reqs, err := ssh.NewClientConn(conn, addr, config)
|
||||
if err != nil {
|
||||
if closeErr := conn.Close(); closeErr != nil {
|
||||
log.Debugf("connection close after handshake failure: %v", closeErr)
|
||||
}
|
||||
return nil, fmt.Errorf("ssh handshake: %w", err)
|
||||
}
|
||||
|
||||
client := ssh.NewClient(clientConn, chans, reqs)
|
||||
return &Client{
|
||||
client: client,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// dialWithJWT establishes an SSH connection with optional JWT authentication based on server detection
|
||||
func dialWithJWT(ctx context.Context, network, addr string, config *ssh.ClientConfig, daemonAddr string, skipCache bool) (*Client, error) {
|
||||
host, portStr, err := net.SplitHostPort(addr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parse address %s: %w", addr, err)
|
||||
}
|
||||
port, err := strconv.Atoi(portStr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parse port %s: %w", portStr, err)
|
||||
}
|
||||
|
||||
dialer := &net.Dialer{Timeout: detection.Timeout}
|
||||
serverType, err := detection.DetectSSHServerType(ctx, dialer, host, port)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("SSH server detection failed: %w", err)
|
||||
}
|
||||
|
||||
if !serverType.RequiresJWT() {
|
||||
return dialSSH(ctx, network, addr, config)
|
||||
}
|
||||
|
||||
jwtCtx, cancel := context.WithTimeout(ctx, config.Timeout)
|
||||
defer cancel()
|
||||
|
||||
jwtToken, err := requestJWTToken(jwtCtx, daemonAddr, skipCache)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("request JWT token: %w", err)
|
||||
}
|
||||
|
||||
configWithJWT := nbssh.AddJWTAuth(config, jwtToken)
|
||||
return dialSSH(ctx, network, addr, configWithJWT)
|
||||
}
|
||||
|
||||
// requestJWTToken requests a JWT token from the NetBird daemon
|
||||
func requestJWTToken(ctx context.Context, daemonAddr string, skipCache bool) (string, error) {
|
||||
hint := profilemanager.GetLoginHint()
|
||||
|
||||
conn, err := connectToDaemon(daemonAddr)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("connect to daemon: %w", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
client := proto.NewDaemonServiceClient(conn)
|
||||
return nbssh.RequestJWTToken(ctx, client, os.Stdout, os.Stderr, !skipCache, hint)
|
||||
}
|
||||
|
||||
// verifyHostKeyViaDaemon verifies SSH host key by querying the NetBird daemon
|
||||
func verifyHostKeyViaDaemon(hostname string, remote net.Addr, key ssh.PublicKey, daemonAddr string) error {
|
||||
conn, err := connectToDaemon(daemonAddr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() {
|
||||
if err := conn.Close(); err != nil {
|
||||
log.Debugf("daemon connection close error: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
client := proto.NewDaemonServiceClient(conn)
|
||||
verifier := nbssh.NewDaemonHostKeyVerifier(client)
|
||||
callback := nbssh.CreateHostKeyCallback(verifier)
|
||||
return callback(hostname, remote, key)
|
||||
}
|
||||
|
||||
func connectToDaemon(daemonAddr string) (*grpc.ClientConn, error) {
|
||||
addr := strings.TrimPrefix(daemonAddr, "tcp://")
|
||||
|
||||
conn, err := grpc.NewClient(
|
||||
addr,
|
||||
grpc.WithTransportCredentials(insecure.NewCredentials()),
|
||||
)
|
||||
if err != nil {
|
||||
log.Debugf("failed to create gRPC client for NetBird daemon at %s: %v", daemonAddr, err)
|
||||
return nil, fmt.Errorf("failed to connect to NetBird daemon: %w", err)
|
||||
}
|
||||
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
// getKnownHostsFiles returns paths to known_hosts files in order of preference
|
||||
func getKnownHostsFiles() []string {
|
||||
var files []string
|
||||
|
||||
// User's known_hosts file (highest priority)
|
||||
if homeDir, err := os.UserHomeDir(); err == nil {
|
||||
userKnownHosts := filepath.Join(homeDir, ".ssh", "known_hosts")
|
||||
files = append(files, userKnownHosts)
|
||||
}
|
||||
|
||||
// NetBird managed known_hosts files
|
||||
if runtime.GOOS == "windows" {
|
||||
programData := os.Getenv("PROGRAMDATA")
|
||||
if programData == "" {
|
||||
programData = `C:\ProgramData`
|
||||
}
|
||||
netbirdKnownHosts := filepath.Join(programData, "ssh", "ssh_known_hosts.d", "99-netbird")
|
||||
files = append(files, netbirdKnownHosts)
|
||||
} else {
|
||||
files = append(files, "/etc/ssh/ssh_known_hosts.d/99-netbird")
|
||||
files = append(files, "/etc/ssh/ssh_known_hosts")
|
||||
}
|
||||
|
||||
return files
|
||||
}
|
||||
|
||||
// createHostKeyCallback creates a host key verification callback
|
||||
func createHostKeyCallback(opts DialOptions) (ssh.HostKeyCallback, error) {
|
||||
if opts.InsecureSkipVerify {
|
||||
return ssh.InsecureIgnoreHostKey(), nil // #nosec G106 - User explicitly requested insecure mode
|
||||
}
|
||||
|
||||
return func(hostname string, remote net.Addr, key ssh.PublicKey) error {
|
||||
if err := tryDaemonVerification(hostname, remote, key, opts.DaemonAddr); err == nil {
|
||||
return nil
|
||||
}
|
||||
return tryKnownHostsVerification(hostname, remote, key, opts.KnownHostsFile)
|
||||
}, nil
|
||||
}
|
||||
|
||||
func tryDaemonVerification(hostname string, remote net.Addr, key ssh.PublicKey, daemonAddr string) error {
|
||||
if daemonAddr == "" {
|
||||
return fmt.Errorf("no daemon address")
|
||||
}
|
||||
return verifyHostKeyViaDaemon(hostname, remote, key, daemonAddr)
|
||||
}
|
||||
|
||||
func tryKnownHostsVerification(hostname string, remote net.Addr, key ssh.PublicKey, knownHostsFile string) error {
|
||||
knownHostsFiles := getKnownHostsFilesList(knownHostsFile)
|
||||
hostKeyCallbacks := buildHostKeyCallbacks(knownHostsFiles)
|
||||
|
||||
for _, callback := range hostKeyCallbacks {
|
||||
if err := callback(hostname, remote, key); err == nil {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
return fmt.Errorf("host key verification failed: key for %s not found in any known_hosts file", hostname)
|
||||
}
|
||||
|
||||
func getKnownHostsFilesList(knownHostsFile string) []string {
|
||||
if knownHostsFile != "" {
|
||||
return []string{knownHostsFile}
|
||||
}
|
||||
return getKnownHostsFiles()
|
||||
}
|
||||
|
||||
func buildHostKeyCallbacks(knownHostsFiles []string) []ssh.HostKeyCallback {
|
||||
var hostKeyCallbacks []ssh.HostKeyCallback
|
||||
for _, file := range knownHostsFiles {
|
||||
if callback, err := knownhosts.New(file); err == nil {
|
||||
hostKeyCallbacks = append(hostKeyCallbacks, callback)
|
||||
}
|
||||
}
|
||||
return hostKeyCallbacks
|
||||
}
|
||||
|
||||
// createSSHKeyAuth creates SSH key authentication from a private key file
|
||||
func createSSHKeyAuth(keyFile string) (ssh.AuthMethod, error) {
|
||||
keyData, err := os.ReadFile(keyFile)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read SSH key file %s: %w", keyFile, err)
|
||||
}
|
||||
|
||||
signer, err := ssh.ParsePrivateKey(keyData)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parse SSH private key: %w", err)
|
||||
}
|
||||
|
||||
return ssh.PublicKeys(signer), nil
|
||||
}
|
||||
|
||||
// LocalPortForward sets up local port forwarding, binding to localAddr and forwarding to remoteAddr
|
||||
func (c *Client) LocalPortForward(ctx context.Context, localAddr, remoteAddr string) error {
|
||||
localListener, err := net.Listen("tcp", localAddr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("listen on %s: %w", localAddr, err)
|
||||
}
|
||||
|
||||
go func() {
|
||||
defer func() {
|
||||
if err := localListener.Close(); err != nil && !errors.Is(err, net.ErrClosed) {
|
||||
log.Debugf("local listener close error: %v", err)
|
||||
}
|
||||
}()
|
||||
for {
|
||||
localConn, err := localListener.Accept()
|
||||
if err != nil {
|
||||
if ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
go c.handleLocalForward(localConn, remoteAddr)
|
||||
}
|
||||
}()
|
||||
|
||||
<-ctx.Done()
|
||||
if err := localListener.Close(); err != nil && !errors.Is(err, net.ErrClosed) {
|
||||
log.Debugf("local listener close error: %v", err)
|
||||
}
|
||||
return ctx.Err()
|
||||
}
|
||||
|
||||
// handleLocalForward handles a single local port forwarding connection
|
||||
func (c *Client) handleLocalForward(localConn net.Conn, remoteAddr string) {
|
||||
defer func() {
|
||||
if err := localConn.Close(); err != nil {
|
||||
log.Debugf("local connection close error: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
channel, err := c.client.Dial("tcp", remoteAddr)
|
||||
if err != nil {
|
||||
if strings.Contains(err.Error(), "administratively prohibited") {
|
||||
_, _ = fmt.Fprintf(os.Stderr, "channel open failed: administratively prohibited: port forwarding is disabled\n")
|
||||
} else {
|
||||
log.Debugf("local port forwarding to %s failed: %v", remoteAddr, err)
|
||||
}
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
if err := channel.Close(); err != nil {
|
||||
log.Debugf("remote channel close error: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
go func() {
|
||||
if _, err := io.Copy(channel, localConn); err != nil {
|
||||
log.Debugf("local forward copy error (local->remote): %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
if _, err := io.Copy(localConn, channel); err != nil {
|
||||
log.Debugf("local forward copy error (remote->local): %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// RemotePortForward sets up remote port forwarding, binding on remote and forwarding to localAddr
|
||||
func (c *Client) RemotePortForward(ctx context.Context, remoteAddr, localAddr string) error {
|
||||
host, port, err := c.parseRemoteAddress(remoteAddr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parse remote address: %w", err)
|
||||
}
|
||||
|
||||
req := c.buildTCPIPForwardRequest(host, port)
|
||||
if err := c.sendTCPIPForwardRequest(req); err != nil {
|
||||
return fmt.Errorf("setup remote forward: %w", err)
|
||||
}
|
||||
|
||||
go c.handleRemoteForwardChannels(ctx, localAddr)
|
||||
|
||||
<-ctx.Done()
|
||||
|
||||
if err := c.cancelTCPIPForwardRequest(req); err != nil {
|
||||
return fmt.Errorf("cancel tcpip-forward: %w", err)
|
||||
}
|
||||
return ctx.Err()
|
||||
}
|
||||
|
||||
// parseRemoteAddress parses host and port from remote address string
|
||||
func (c *Client) parseRemoteAddress(remoteAddr string) (string, uint32, error) {
|
||||
host, portStr, err := net.SplitHostPort(remoteAddr)
|
||||
if err != nil {
|
||||
return "", 0, fmt.Errorf("parse remote address %s: %w", remoteAddr, err)
|
||||
}
|
||||
|
||||
port, err := strconv.Atoi(portStr)
|
||||
if err != nil {
|
||||
return "", 0, fmt.Errorf("parse remote port %s: %w", portStr, err)
|
||||
}
|
||||
|
||||
return host, uint32(port), nil
|
||||
}
|
||||
|
||||
// buildTCPIPForwardRequest creates a tcpip-forward request message
|
||||
func (c *Client) buildTCPIPForwardRequest(host string, port uint32) tcpipForwardMsg {
|
||||
return tcpipForwardMsg{
|
||||
Host: host,
|
||||
Port: port,
|
||||
}
|
||||
}
|
||||
|
||||
// sendTCPIPForwardRequest sends the tcpip-forward request to establish remote port forwarding
|
||||
func (c *Client) sendTCPIPForwardRequest(req tcpipForwardMsg) error {
|
||||
ok, _, err := c.client.SendRequest("tcpip-forward", true, ssh.Marshal(&req))
|
||||
if err != nil {
|
||||
return fmt.Errorf("send tcpip-forward request: %w", err)
|
||||
}
|
||||
if !ok {
|
||||
return fmt.Errorf("remote port forwarding denied by server (check if --allow-ssh-remote-port-forwarding is enabled)")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// cancelTCPIPForwardRequest cancels the tcpip-forward request
|
||||
func (c *Client) cancelTCPIPForwardRequest(req tcpipForwardMsg) error {
|
||||
_, _, err := c.client.SendRequest("cancel-tcpip-forward", true, ssh.Marshal(&req))
|
||||
if err != nil {
|
||||
return fmt.Errorf("send cancel-tcpip-forward request: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// handleRemoteForwardChannels handles incoming forwarded-tcpip channels
|
||||
func (c *Client) handleRemoteForwardChannels(ctx context.Context, localAddr string) {
|
||||
// Get the channel once - subsequent calls return nil!
|
||||
channelRequests := c.client.HandleChannelOpen("forwarded-tcpip")
|
||||
if channelRequests == nil {
|
||||
log.Debugf("forwarded-tcpip channel type already being handled")
|
||||
return
|
||||
}
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case newChan := <-channelRequests:
|
||||
if newChan != nil {
|
||||
go c.handleRemoteForwardChannel(newChan, localAddr)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// handleRemoteForwardChannel handles a single forwarded-tcpip channel
|
||||
func (c *Client) handleRemoteForwardChannel(newChan ssh.NewChannel, localAddr string) {
|
||||
channel, reqs, err := newChan.Accept()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
if err := channel.Close(); err != nil {
|
||||
log.Debugf("remote channel close error: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
go ssh.DiscardRequests(reqs)
|
||||
|
||||
localConn, err := net.Dial("tcp", localAddr)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
if err := localConn.Close(); err != nil {
|
||||
log.Debugf("local connection close error: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
go func() {
|
||||
if _, err := io.Copy(localConn, channel); err != nil {
|
||||
log.Debugf("remote forward copy error (remote->local): %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
if _, err := io.Copy(channel, localConn); err != nil {
|
||||
log.Debugf("remote forward copy error (local->remote): %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// tcpipForwardMsg represents the structure for tcpip-forward requests
|
||||
type tcpipForwardMsg struct {
|
||||
Host string
|
||||
Port uint32
|
||||
}
|
||||
512
client/ssh/client/client_test.go
Normal file
512
client/ssh/client/client_test.go
Normal file
@@ -0,0 +1,512 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"os"
|
||||
"os/user"
|
||||
"runtime"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
cryptossh "golang.org/x/crypto/ssh"
|
||||
|
||||
"github.com/netbirdio/netbird/client/ssh"
|
||||
sshserver "github.com/netbirdio/netbird/client/ssh/server"
|
||||
"github.com/netbirdio/netbird/client/ssh/testutil"
|
||||
)
|
||||
|
||||
// TestMain handles package-level setup and cleanup
|
||||
func TestMain(m *testing.M) {
|
||||
// Guard against infinite recursion when test binary is called as "netbird ssh exec"
|
||||
// This happens when running tests as non-privileged user with fallback
|
||||
if len(os.Args) > 2 && os.Args[1] == "ssh" && os.Args[2] == "exec" {
|
||||
// Just exit with error to break the recursion
|
||||
fmt.Fprintf(os.Stderr, "Test binary called as 'ssh exec' - preventing infinite recursion\n")
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
// Run tests
|
||||
code := m.Run()
|
||||
|
||||
// Cleanup any created test users
|
||||
testutil.CleanupTestUsers()
|
||||
|
||||
os.Exit(code)
|
||||
}
|
||||
|
||||
func TestSSHClient_DialWithKey(t *testing.T) {
|
||||
// Generate host key for server
|
||||
hostKey, err := ssh.GeneratePrivateKey(ssh.ED25519)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create and start server
|
||||
serverConfig := &sshserver.Config{
|
||||
HostKeyPEM: hostKey,
|
||||
JWT: nil,
|
||||
}
|
||||
server := sshserver.New(serverConfig)
|
||||
server.SetAllowRootLogin(true) // Allow root/admin login for tests
|
||||
|
||||
serverAddr := sshserver.StartTestServer(t, server)
|
||||
defer func() {
|
||||
err := server.Stop()
|
||||
require.NoError(t, err)
|
||||
}()
|
||||
|
||||
// Test Dial
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
currentUser := testutil.GetTestUsername(t)
|
||||
client, err := Dial(ctx, serverAddr, currentUser, DialOptions{
|
||||
InsecureSkipVerify: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
err := client.Close()
|
||||
assert.NoError(t, err)
|
||||
}()
|
||||
|
||||
// Verify client is connected
|
||||
assert.NotNil(t, client.client)
|
||||
}
|
||||
|
||||
func TestSSHClient_CommandExecution(t *testing.T) {
|
||||
if runtime.GOOS == "windows" && testutil.IsCI() {
|
||||
t.Skip("Skipping Windows command execution tests in CI due to S4U authentication issues")
|
||||
}
|
||||
|
||||
server, _, client := setupTestSSHServerAndClient(t)
|
||||
defer func() {
|
||||
err := server.Stop()
|
||||
require.NoError(t, err)
|
||||
}()
|
||||
defer func() {
|
||||
err := client.Close()
|
||||
assert.NoError(t, err)
|
||||
}()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
defer cancel()
|
||||
|
||||
t.Run("ExecuteCommand captures output", func(t *testing.T) {
|
||||
output, err := client.ExecuteCommand(ctx, "echo hello")
|
||||
assert.NoError(t, err)
|
||||
assert.Contains(t, string(output), "hello")
|
||||
})
|
||||
|
||||
t.Run("ExecuteCommandWithIO streams output", func(t *testing.T) {
|
||||
err := client.ExecuteCommandWithIO(ctx, "echo world")
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("commands with flags work", func(t *testing.T) {
|
||||
output, err := client.ExecuteCommand(ctx, "echo -n test_flag")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "test_flag", strings.TrimSpace(string(output)))
|
||||
})
|
||||
|
||||
t.Run("non-zero exit codes don't return errors", func(t *testing.T) {
|
||||
var testCmd string
|
||||
if runtime.GOOS == "windows" {
|
||||
testCmd = "echo hello | Select-String notfound"
|
||||
} else {
|
||||
testCmd = "echo 'hello' | grep 'notfound'"
|
||||
}
|
||||
_, err := client.ExecuteCommand(ctx, testCmd)
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestSSHClient_ConnectionHandling(t *testing.T) {
|
||||
server, serverAddr, _ := setupTestSSHServerAndClient(t)
|
||||
defer func() {
|
||||
err := server.Stop()
|
||||
require.NoError(t, err)
|
||||
}()
|
||||
|
||||
// Generate client key for multiple connections
|
||||
|
||||
const numClients = 3
|
||||
clients := make([]*Client, numClients)
|
||||
|
||||
currentUser := testutil.GetTestUsername(t)
|
||||
for i := 0; i < numClients; i++ {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
client, err := Dial(ctx, serverAddr, currentUser, DialOptions{
|
||||
InsecureSkipVerify: true,
|
||||
})
|
||||
cancel()
|
||||
require.NoError(t, err, "Client %d should connect successfully", i)
|
||||
clients[i] = client
|
||||
}
|
||||
|
||||
for i, client := range clients {
|
||||
err := client.Close()
|
||||
assert.NoError(t, err, "Client %d should close without error", i)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSSHClient_ContextCancellation(t *testing.T) {
|
||||
server, serverAddr, _ := setupTestSSHServerAndClient(t)
|
||||
defer func() {
|
||||
err := server.Stop()
|
||||
require.NoError(t, err)
|
||||
}()
|
||||
|
||||
t.Run("connection with short timeout", func(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
currentUser := testutil.GetTestUsername(t)
|
||||
_, err := Dial(ctx, serverAddr, currentUser, DialOptions{
|
||||
InsecureSkipVerify: true,
|
||||
})
|
||||
if err != nil {
|
||||
// Check for actual timeout-related errors rather than string matching
|
||||
assert.True(t,
|
||||
errors.Is(err, context.DeadlineExceeded) ||
|
||||
errors.Is(err, context.Canceled) ||
|
||||
strings.Contains(err.Error(), "timeout"),
|
||||
"Expected timeout-related error, got: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("command execution cancellation", func(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
currentUser := testutil.GetTestUsername(t)
|
||||
client, err := Dial(ctx, serverAddr, currentUser, DialOptions{
|
||||
InsecureSkipVerify: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
if err := client.Close(); err != nil {
|
||||
t.Logf("client close error: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
cmdCtx, cmdCancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||
defer cmdCancel()
|
||||
|
||||
err = client.ExecuteCommandWithPTY(cmdCtx, "sleep 10")
|
||||
if err != nil {
|
||||
var exitMissingErr *cryptossh.ExitMissingError
|
||||
isValidCancellation := errors.Is(err, context.DeadlineExceeded) ||
|
||||
errors.Is(err, context.Canceled) ||
|
||||
errors.As(err, &exitMissingErr)
|
||||
assert.True(t, isValidCancellation, "Should handle command cancellation properly")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestSSHClient_NoAuthMode(t *testing.T) {
|
||||
hostKey, err := ssh.GeneratePrivateKey(ssh.ED25519)
|
||||
require.NoError(t, err)
|
||||
|
||||
serverConfig := &sshserver.Config{
|
||||
HostKeyPEM: hostKey,
|
||||
JWT: nil,
|
||||
}
|
||||
server := sshserver.New(serverConfig)
|
||||
server.SetAllowRootLogin(true) // Allow root/admin login for tests
|
||||
|
||||
serverAddr := sshserver.StartTestServer(t, server)
|
||||
defer func() {
|
||||
err := server.Stop()
|
||||
require.NoError(t, err)
|
||||
}()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
currentUser := testutil.GetTestUsername(t)
|
||||
|
||||
t.Run("any key succeeds in no-auth mode", func(t *testing.T) {
|
||||
client, err := Dial(ctx, serverAddr, currentUser, DialOptions{
|
||||
InsecureSkipVerify: true,
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
if client != nil {
|
||||
require.NoError(t, client.Close(), "Client should close without error")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestSSHClient_TerminalState(t *testing.T) {
|
||||
server, _, client := setupTestSSHServerAndClient(t)
|
||||
defer func() {
|
||||
err := server.Stop()
|
||||
require.NoError(t, err)
|
||||
}()
|
||||
defer func() {
|
||||
err := client.Close()
|
||||
assert.NoError(t, err)
|
||||
}()
|
||||
|
||||
assert.Nil(t, client.terminalState)
|
||||
assert.Equal(t, 0, client.terminalFd)
|
||||
|
||||
client.restoreTerminal()
|
||||
assert.Nil(t, client.terminalState)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
err := client.OpenTerminal(ctx)
|
||||
// In test environment without a real terminal, this may complete quickly or timeout
|
||||
// Both behaviors are acceptable for testing terminal state management
|
||||
if err != nil {
|
||||
if runtime.GOOS == "windows" {
|
||||
assert.True(t,
|
||||
strings.Contains(err.Error(), "context deadline exceeded") ||
|
||||
strings.Contains(err.Error(), "console"),
|
||||
"Should timeout or have console error on Windows")
|
||||
} else {
|
||||
// On Unix systems in test environment, we may get various errors
|
||||
// including timeouts or terminal-related errors
|
||||
assert.True(t,
|
||||
strings.Contains(err.Error(), "context deadline exceeded") ||
|
||||
strings.Contains(err.Error(), "terminal") ||
|
||||
strings.Contains(err.Error(), "pty"),
|
||||
"Expected timeout or terminal-related error, got: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func setupTestSSHServerAndClient(t *testing.T) (*sshserver.Server, string, *Client) {
|
||||
hostKey, err := ssh.GeneratePrivateKey(ssh.ED25519)
|
||||
require.NoError(t, err)
|
||||
|
||||
serverConfig := &sshserver.Config{
|
||||
HostKeyPEM: hostKey,
|
||||
JWT: nil,
|
||||
}
|
||||
server := sshserver.New(serverConfig)
|
||||
server.SetAllowRootLogin(true) // Allow root/admin login for tests
|
||||
|
||||
serverAddr := sshserver.StartTestServer(t, server)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
currentUser := testutil.GetTestUsername(t)
|
||||
client, err := Dial(ctx, serverAddr, currentUser, DialOptions{
|
||||
InsecureSkipVerify: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
return server, serverAddr, client
|
||||
}
|
||||
|
||||
func TestSSHClient_PortForwarding(t *testing.T) {
|
||||
server, _, client := setupTestSSHServerAndClient(t)
|
||||
defer func() {
|
||||
err := server.Stop()
|
||||
require.NoError(t, err)
|
||||
}()
|
||||
defer func() {
|
||||
err := client.Close()
|
||||
assert.NoError(t, err)
|
||||
}()
|
||||
|
||||
t.Run("local forwarding times out gracefully", func(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
err := client.LocalPortForward(ctx, "127.0.0.1:0", "127.0.0.1:8080")
|
||||
assert.Error(t, err)
|
||||
assert.True(t,
|
||||
errors.Is(err, context.DeadlineExceeded) ||
|
||||
errors.Is(err, context.Canceled) ||
|
||||
strings.Contains(err.Error(), "connection"),
|
||||
"Expected context or connection error")
|
||||
})
|
||||
|
||||
t.Run("remote forwarding denied", func(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
|
||||
defer cancel()
|
||||
|
||||
err := client.RemotePortForward(ctx, "127.0.0.1:0", "127.0.0.1:8080")
|
||||
assert.Error(t, err)
|
||||
assert.True(t,
|
||||
strings.Contains(err.Error(), "denied") ||
|
||||
strings.Contains(err.Error(), "disabled"),
|
||||
"Should be denied by default")
|
||||
})
|
||||
|
||||
t.Run("invalid addresses fail", func(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
|
||||
defer cancel()
|
||||
|
||||
err := client.LocalPortForward(ctx, "invalid:address", "127.0.0.1:8080")
|
||||
assert.Error(t, err)
|
||||
|
||||
err = client.LocalPortForward(ctx, "127.0.0.1:0", "invalid:address")
|
||||
assert.Error(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestSSHClient_PortForwardingDataTransfer(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("Skipping data transfer test in short mode")
|
||||
}
|
||||
|
||||
hostKey, err := ssh.GeneratePrivateKey(ssh.ED25519)
|
||||
require.NoError(t, err)
|
||||
|
||||
serverConfig := &sshserver.Config{
|
||||
HostKeyPEM: hostKey,
|
||||
JWT: nil,
|
||||
}
|
||||
server := sshserver.New(serverConfig)
|
||||
server.SetAllowLocalPortForwarding(true)
|
||||
server.SetAllowRootLogin(true) // Allow root/admin login for tests
|
||||
|
||||
serverAddr := sshserver.StartTestServer(t, server)
|
||||
defer func() {
|
||||
err := server.Stop()
|
||||
require.NoError(t, err)
|
||||
}()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
// Port forwarding requires the actual current user, not test user
|
||||
realUser, err := getRealCurrentUser()
|
||||
require.NoError(t, err)
|
||||
|
||||
// Skip if running as system account that can't do port forwarding
|
||||
if testutil.IsSystemAccount(realUser) {
|
||||
t.Skipf("Skipping port forwarding test - running as system account: %s", realUser)
|
||||
}
|
||||
|
||||
client, err := Dial(ctx, serverAddr, realUser, DialOptions{
|
||||
InsecureSkipVerify: true, // Skip host key verification for test
|
||||
})
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
if err := client.Close(); err != nil {
|
||||
t.Logf("client close error: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
testServer, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
if err := testServer.Close(); err != nil {
|
||||
t.Logf("test server close error: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
testServerAddr := testServer.Addr().String()
|
||||
expectedResponse := "Hello, World!"
|
||||
|
||||
go func() {
|
||||
for {
|
||||
conn, err := testServer.Accept()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
go func(c net.Conn) {
|
||||
defer func() {
|
||||
if err := c.Close(); err != nil {
|
||||
t.Logf("connection close error: %v", err)
|
||||
}
|
||||
}()
|
||||
buf := make([]byte, 1024)
|
||||
if _, err := c.Read(buf); err != nil {
|
||||
t.Logf("connection read error: %v", err)
|
||||
return
|
||||
}
|
||||
if _, err := c.Write([]byte(expectedResponse)); err != nil {
|
||||
t.Logf("connection write error: %v", err)
|
||||
}
|
||||
}(conn)
|
||||
}
|
||||
}()
|
||||
|
||||
localListener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
localAddr := localListener.Addr().String()
|
||||
if err := localListener.Close(); err != nil {
|
||||
t.Logf("local listener close error: %v", err)
|
||||
}
|
||||
|
||||
ctx, cancel = context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
go func() {
|
||||
err := client.LocalPortForward(ctx, localAddr, testServerAddr)
|
||||
if err != nil && !errors.Is(err, context.Canceled) {
|
||||
if isWindowsPrivilegeError(err) {
|
||||
t.Logf("Port forward failed due to Windows privilege restrictions: %v", err)
|
||||
} else {
|
||||
t.Logf("Port forward error: %v", err)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
conn, err := net.DialTimeout("tcp", localAddr, 2*time.Second)
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
if err := conn.Close(); err != nil {
|
||||
t.Logf("connection close error: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
_, err = conn.Write([]byte("test"))
|
||||
require.NoError(t, err)
|
||||
|
||||
if err := conn.SetReadDeadline(time.Now().Add(2 * time.Second)); err != nil {
|
||||
t.Logf("set read deadline error: %v", err)
|
||||
}
|
||||
response := make([]byte, len(expectedResponse))
|
||||
n, err := io.ReadFull(conn, response)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, len(expectedResponse), n)
|
||||
assert.Equal(t, expectedResponse, string(response))
|
||||
}
|
||||
|
||||
// getRealCurrentUser returns the actual current user (not test user) for features like port forwarding
|
||||
func getRealCurrentUser() (string, error) {
|
||||
if runtime.GOOS == "windows" {
|
||||
if currentUser, err := user.Current(); err == nil {
|
||||
return currentUser.Username, nil
|
||||
}
|
||||
}
|
||||
|
||||
if username := os.Getenv("USER"); username != "" {
|
||||
return username, nil
|
||||
}
|
||||
|
||||
if currentUser, err := user.Current(); err == nil {
|
||||
return currentUser.Username, nil
|
||||
}
|
||||
|
||||
return "", fmt.Errorf("unable to determine current user")
|
||||
}
|
||||
|
||||
// isWindowsPrivilegeError checks if an error is related to Windows privilege restrictions
|
||||
func isWindowsPrivilegeError(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
errStr := strings.ToLower(err.Error())
|
||||
return strings.Contains(errStr, "ntstatus=0xc0000062") || // STATUS_PRIVILEGE_NOT_HELD
|
||||
strings.Contains(errStr, "0xc0000041") || // STATUS_PRIVILEGE_NOT_HELD (LsaRegisterLogonProcess)
|
||||
strings.Contains(errStr, "0xc0000062") || // STATUS_PRIVILEGE_NOT_HELD (LsaLogonUser)
|
||||
strings.Contains(errStr, "privilege") ||
|
||||
strings.Contains(errStr, "access denied") ||
|
||||
strings.Contains(errStr, "user authentication failed")
|
||||
}
|
||||
127
client/ssh/client/terminal_unix.go
Normal file
127
client/ssh/client/terminal_unix.go
Normal file
@@ -0,0 +1,127 @@
|
||||
//go:build !windows
|
||||
|
||||
package client
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/signal"
|
||||
"syscall"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/crypto/ssh"
|
||||
"golang.org/x/term"
|
||||
)
|
||||
|
||||
func (c *Client) setupTerminalMode(ctx context.Context, session *ssh.Session) error {
|
||||
stdinFd := int(os.Stdin.Fd())
|
||||
|
||||
if !term.IsTerminal(stdinFd) {
|
||||
return c.setupNonTerminalMode(ctx, session)
|
||||
}
|
||||
|
||||
fd := int(os.Stdin.Fd())
|
||||
|
||||
state, err := term.MakeRaw(fd)
|
||||
if err != nil {
|
||||
return c.setupNonTerminalMode(ctx, session)
|
||||
}
|
||||
|
||||
if err := c.setupTerminal(session, fd); err != nil {
|
||||
if restoreErr := term.Restore(fd, state); restoreErr != nil {
|
||||
log.Debugf("restore terminal state: %v", restoreErr)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
c.terminalState = state
|
||||
c.terminalFd = fd
|
||||
|
||||
sigChan := make(chan os.Signal, 1)
|
||||
signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM, syscall.SIGQUIT)
|
||||
|
||||
go func() {
|
||||
defer signal.Stop(sigChan)
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
if err := term.Restore(fd, state); err != nil {
|
||||
log.Debugf("restore terminal state: %v", err)
|
||||
}
|
||||
case sig := <-sigChan:
|
||||
if err := term.Restore(fd, state); err != nil {
|
||||
log.Debugf("restore terminal state: %v", err)
|
||||
}
|
||||
signal.Reset(sig)
|
||||
s, ok := sig.(syscall.Signal)
|
||||
if !ok {
|
||||
log.Debugf("signal %v is not a syscall.Signal: %T", sig, sig)
|
||||
return
|
||||
}
|
||||
if err := syscall.Kill(syscall.Getpid(), s); err != nil {
|
||||
log.Debugf("kill process with signal %v: %v", s, err)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Client) setupNonTerminalMode(_ context.Context, session *ssh.Session) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// restoreWindowsConsoleState is a no-op on Unix systems
|
||||
func (c *Client) restoreWindowsConsoleState() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Client) setupTerminal(session *ssh.Session, fd int) error {
|
||||
w, h, err := term.GetSize(fd)
|
||||
if err != nil {
|
||||
return fmt.Errorf("get terminal size: %w", err)
|
||||
}
|
||||
|
||||
modes := ssh.TerminalModes{
|
||||
ssh.ECHO: 1,
|
||||
ssh.TTY_OP_ISPEED: 14400,
|
||||
ssh.TTY_OP_OSPEED: 14400,
|
||||
// Ctrl+C
|
||||
ssh.VINTR: 3,
|
||||
// Ctrl+\
|
||||
ssh.VQUIT: 28,
|
||||
// Backspace
|
||||
ssh.VERASE: 127,
|
||||
// Ctrl+U
|
||||
ssh.VKILL: 21,
|
||||
// Ctrl+D
|
||||
ssh.VEOF: 4,
|
||||
ssh.VEOL: 0,
|
||||
ssh.VEOL2: 0,
|
||||
// Ctrl+Q
|
||||
ssh.VSTART: 17,
|
||||
// Ctrl+S
|
||||
ssh.VSTOP: 19,
|
||||
// Ctrl+Z
|
||||
ssh.VSUSP: 26,
|
||||
// Ctrl+O
|
||||
ssh.VDISCARD: 15,
|
||||
// Ctrl+R
|
||||
ssh.VREPRINT: 18,
|
||||
// Ctrl+W
|
||||
ssh.VWERASE: 23,
|
||||
// Ctrl+V
|
||||
ssh.VLNEXT: 22,
|
||||
}
|
||||
|
||||
terminal := os.Getenv("TERM")
|
||||
if terminal == "" {
|
||||
terminal = "xterm-256color"
|
||||
}
|
||||
|
||||
if err := session.RequestPty(terminal, h, w, modes); err != nil {
|
||||
return fmt.Errorf("request pty: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
265
client/ssh/client/terminal_windows.go
Normal file
265
client/ssh/client/terminal_windows.go
Normal file
@@ -0,0 +1,265 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"syscall"
|
||||
"unsafe"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
const (
|
||||
enableProcessedInput = 0x0001
|
||||
enableLineInput = 0x0002
|
||||
enableEchoInput = 0x0004 // Input mode: ENABLE_ECHO_INPUT
|
||||
enableVirtualTerminalProcessing = 0x0004 // Output mode: ENABLE_VIRTUAL_TERMINAL_PROCESSING (same value, different mode)
|
||||
enableVirtualTerminalInput = 0x0200
|
||||
)
|
||||
|
||||
var (
|
||||
kernel32 = syscall.NewLazyDLL("kernel32.dll")
|
||||
procGetConsoleMode = kernel32.NewProc("GetConsoleMode")
|
||||
procSetConsoleMode = kernel32.NewProc("SetConsoleMode")
|
||||
procGetConsoleScreenBufferInfo = kernel32.NewProc("GetConsoleScreenBufferInfo")
|
||||
)
|
||||
|
||||
// ConsoleUnavailableError indicates that Windows console handles are not available
|
||||
// (e.g., in CI environments where stdout/stdin are redirected)
|
||||
type ConsoleUnavailableError struct {
|
||||
Operation string
|
||||
Err error
|
||||
}
|
||||
|
||||
func (e *ConsoleUnavailableError) Error() string {
|
||||
return fmt.Sprintf("console unavailable for %s: %v", e.Operation, e.Err)
|
||||
}
|
||||
|
||||
func (e *ConsoleUnavailableError) Unwrap() error {
|
||||
return e.Err
|
||||
}
|
||||
|
||||
type coord struct {
|
||||
x, y int16
|
||||
}
|
||||
|
||||
type smallRect struct {
|
||||
left, top, right, bottom int16
|
||||
}
|
||||
|
||||
type consoleScreenBufferInfo struct {
|
||||
size coord
|
||||
cursorPosition coord
|
||||
attributes uint16
|
||||
window smallRect
|
||||
maximumWindowSize coord
|
||||
}
|
||||
|
||||
func (c *Client) setupTerminalMode(_ context.Context, session *ssh.Session) error {
|
||||
if err := c.saveWindowsConsoleState(); err != nil {
|
||||
var consoleErr *ConsoleUnavailableError
|
||||
if errors.As(err, &consoleErr) {
|
||||
log.Debugf("console unavailable, not requesting PTY: %v", err)
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("save console state: %w", err)
|
||||
}
|
||||
|
||||
if err := c.enableWindowsVirtualTerminal(); err != nil {
|
||||
var consoleErr *ConsoleUnavailableError
|
||||
if errors.As(err, &consoleErr) {
|
||||
log.Debugf("virtual terminal unavailable: %v", err)
|
||||
} else {
|
||||
return fmt.Errorf("failed to enable virtual terminal: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
w, h := c.getWindowsConsoleSize()
|
||||
|
||||
modes := ssh.TerminalModes{
|
||||
ssh.ECHO: 1,
|
||||
ssh.TTY_OP_ISPEED: 14400,
|
||||
ssh.TTY_OP_OSPEED: 14400,
|
||||
ssh.ICRNL: 1,
|
||||
ssh.OPOST: 1,
|
||||
ssh.ONLCR: 1,
|
||||
ssh.ISIG: 1,
|
||||
ssh.ICANON: 1,
|
||||
ssh.VINTR: 3, // Ctrl+C
|
||||
ssh.VQUIT: 28, // Ctrl+\
|
||||
ssh.VERASE: 127, // Backspace
|
||||
ssh.VKILL: 21, // Ctrl+U
|
||||
ssh.VEOF: 4, // Ctrl+D
|
||||
ssh.VEOL: 0,
|
||||
ssh.VEOL2: 0,
|
||||
ssh.VSTART: 17, // Ctrl+Q
|
||||
ssh.VSTOP: 19, // Ctrl+S
|
||||
ssh.VSUSP: 26, // Ctrl+Z
|
||||
ssh.VDISCARD: 15, // Ctrl+O
|
||||
ssh.VWERASE: 23, // Ctrl+W
|
||||
ssh.VLNEXT: 22, // Ctrl+V
|
||||
ssh.VREPRINT: 18, // Ctrl+R
|
||||
}
|
||||
|
||||
if err := session.RequestPty("xterm-256color", h, w, modes); err != nil {
|
||||
if restoreErr := c.restoreWindowsConsoleState(); restoreErr != nil {
|
||||
log.Debugf("restore Windows console state: %v", restoreErr)
|
||||
}
|
||||
return fmt.Errorf("request pty: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Client) saveWindowsConsoleState() error {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
log.Debugf("panic in saveWindowsConsoleState: %v", r)
|
||||
}
|
||||
}()
|
||||
|
||||
stdout := syscall.Handle(os.Stdout.Fd())
|
||||
stdin := syscall.Handle(os.Stdin.Fd())
|
||||
|
||||
var stdoutMode, stdinMode uint32
|
||||
|
||||
ret, _, err := procGetConsoleMode.Call(uintptr(stdout), uintptr(unsafe.Pointer(&stdoutMode)))
|
||||
if ret == 0 {
|
||||
log.Debugf("failed to get stdout console mode: %v", err)
|
||||
return &ConsoleUnavailableError{
|
||||
Operation: "get stdout console mode",
|
||||
Err: err,
|
||||
}
|
||||
}
|
||||
|
||||
ret, _, err = procGetConsoleMode.Call(uintptr(stdin), uintptr(unsafe.Pointer(&stdinMode)))
|
||||
if ret == 0 {
|
||||
log.Debugf("failed to get stdin console mode: %v", err)
|
||||
return &ConsoleUnavailableError{
|
||||
Operation: "get stdin console mode",
|
||||
Err: err,
|
||||
}
|
||||
}
|
||||
|
||||
c.terminalFd = 1
|
||||
c.windowsStdoutMode = stdoutMode
|
||||
c.windowsStdinMode = stdinMode
|
||||
|
||||
log.Debugf("saved Windows console state - stdout: 0x%04x, stdin: 0x%04x", stdoutMode, stdinMode)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Client) enableWindowsVirtualTerminal() (err error) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
err = fmt.Errorf("panic in enableWindowsVirtualTerminal: %v", r)
|
||||
}
|
||||
}()
|
||||
|
||||
stdout := syscall.Handle(os.Stdout.Fd())
|
||||
stdin := syscall.Handle(os.Stdin.Fd())
|
||||
var mode uint32
|
||||
|
||||
ret, _, winErr := procGetConsoleMode.Call(uintptr(stdout), uintptr(unsafe.Pointer(&mode)))
|
||||
if ret == 0 {
|
||||
return &ConsoleUnavailableError{
|
||||
Operation: "get stdout console mode for VT",
|
||||
Err: winErr,
|
||||
}
|
||||
}
|
||||
|
||||
mode |= enableVirtualTerminalProcessing
|
||||
ret, _, winErr = procSetConsoleMode.Call(uintptr(stdout), uintptr(mode))
|
||||
if ret == 0 {
|
||||
return &ConsoleUnavailableError{
|
||||
Operation: "enable virtual terminal processing",
|
||||
Err: winErr,
|
||||
}
|
||||
}
|
||||
|
||||
ret, _, winErr = procGetConsoleMode.Call(uintptr(stdin), uintptr(unsafe.Pointer(&mode)))
|
||||
if ret == 0 {
|
||||
return &ConsoleUnavailableError{
|
||||
Operation: "get stdin console mode for VT",
|
||||
Err: winErr,
|
||||
}
|
||||
}
|
||||
|
||||
mode &= ^uint32(enableLineInput | enableEchoInput | enableProcessedInput)
|
||||
mode |= enableVirtualTerminalInput
|
||||
ret, _, winErr = procSetConsoleMode.Call(uintptr(stdin), uintptr(mode))
|
||||
if ret == 0 {
|
||||
return &ConsoleUnavailableError{
|
||||
Operation: "set stdin raw mode",
|
||||
Err: winErr,
|
||||
}
|
||||
}
|
||||
|
||||
log.Debugf("enabled Windows virtual terminal processing")
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Client) getWindowsConsoleSize() (int, int) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
log.Debugf("panic in getWindowsConsoleSize: %v", r)
|
||||
}
|
||||
}()
|
||||
|
||||
stdout := syscall.Handle(os.Stdout.Fd())
|
||||
var csbi consoleScreenBufferInfo
|
||||
|
||||
ret, _, err := procGetConsoleScreenBufferInfo.Call(uintptr(stdout), uintptr(unsafe.Pointer(&csbi)))
|
||||
if ret == 0 {
|
||||
log.Debugf("failed to get console buffer info, using defaults: %v", err)
|
||||
return 80, 24
|
||||
}
|
||||
|
||||
width := int(csbi.window.right - csbi.window.left + 1)
|
||||
height := int(csbi.window.bottom - csbi.window.top + 1)
|
||||
|
||||
log.Debugf("Windows console size: %dx%d", width, height)
|
||||
return width, height
|
||||
}
|
||||
|
||||
func (c *Client) restoreWindowsConsoleState() error {
|
||||
var err error
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
err = fmt.Errorf("panic in restoreWindowsConsoleState: %v", r)
|
||||
}
|
||||
}()
|
||||
|
||||
if c.terminalFd != 1 {
|
||||
return nil
|
||||
}
|
||||
|
||||
stdout := syscall.Handle(os.Stdout.Fd())
|
||||
stdin := syscall.Handle(os.Stdin.Fd())
|
||||
|
||||
ret, _, winErr := procSetConsoleMode.Call(uintptr(stdout), uintptr(c.windowsStdoutMode))
|
||||
if ret == 0 {
|
||||
log.Debugf("failed to restore stdout console mode: %v", winErr)
|
||||
if err == nil {
|
||||
err = fmt.Errorf("restore stdout console mode: %w", winErr)
|
||||
}
|
||||
}
|
||||
|
||||
ret, _, winErr = procSetConsoleMode.Call(uintptr(stdin), uintptr(c.windowsStdinMode))
|
||||
if ret == 0 {
|
||||
log.Debugf("failed to restore stdin console mode: %v", winErr)
|
||||
if err == nil {
|
||||
err = fmt.Errorf("restore stdin console mode: %w", winErr)
|
||||
}
|
||||
}
|
||||
|
||||
c.terminalFd = 0
|
||||
c.windowsStdoutMode = 0
|
||||
c.windowsStdinMode = 0
|
||||
|
||||
log.Debugf("restored Windows console state")
|
||||
return err
|
||||
}
|
||||
171
client/ssh/common.go
Normal file
171
client/ssh/common.go
Normal file
@@ -0,0 +1,171 @@
|
||||
package ssh
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/crypto/ssh"
|
||||
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
)
|
||||
|
||||
const (
|
||||
NetBirdSSHConfigFile = "99-netbird.conf"
|
||||
|
||||
UnixSSHConfigDir = "/etc/ssh/ssh_config.d"
|
||||
WindowsSSHConfigDir = "ssh/ssh_config.d"
|
||||
)
|
||||
|
||||
var (
|
||||
// ErrPeerNotFound indicates the peer was not found in the network
|
||||
ErrPeerNotFound = errors.New("peer not found in network")
|
||||
// ErrNoStoredKey indicates the peer has no stored SSH host key
|
||||
ErrNoStoredKey = errors.New("peer has no stored SSH host key")
|
||||
)
|
||||
|
||||
// HostKeyVerifier provides SSH host key verification
|
||||
type HostKeyVerifier interface {
|
||||
VerifySSHHostKey(peerAddress string, key []byte) error
|
||||
}
|
||||
|
||||
// DaemonHostKeyVerifier implements HostKeyVerifier using the NetBird daemon
|
||||
type DaemonHostKeyVerifier struct {
|
||||
client proto.DaemonServiceClient
|
||||
}
|
||||
|
||||
// NewDaemonHostKeyVerifier creates a new daemon-based host key verifier
|
||||
func NewDaemonHostKeyVerifier(client proto.DaemonServiceClient) *DaemonHostKeyVerifier {
|
||||
return &DaemonHostKeyVerifier{
|
||||
client: client,
|
||||
}
|
||||
}
|
||||
|
||||
// VerifySSHHostKey verifies an SSH host key by querying the NetBird daemon
|
||||
func (d *DaemonHostKeyVerifier) VerifySSHHostKey(peerAddress string, presentedKey []byte) error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
response, err := d.client.GetPeerSSHHostKey(ctx, &proto.GetPeerSSHHostKeyRequest{
|
||||
PeerAddress: peerAddress,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !response.GetFound() {
|
||||
return ErrPeerNotFound
|
||||
}
|
||||
|
||||
storedKeyData := response.GetSshHostKey()
|
||||
|
||||
return VerifyHostKey(storedKeyData, presentedKey, peerAddress)
|
||||
}
|
||||
|
||||
// RequestJWTToken requests or retrieves a JWT token for SSH authentication
|
||||
func RequestJWTToken(ctx context.Context, client proto.DaemonServiceClient, stdout, stderr io.Writer, useCache bool, hint string) (string, error) {
|
||||
req := &proto.RequestJWTAuthRequest{}
|
||||
if hint != "" {
|
||||
req.Hint = &hint
|
||||
}
|
||||
authResponse, err := client.RequestJWTAuth(ctx, req)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("request JWT auth: %w", err)
|
||||
}
|
||||
|
||||
if useCache && authResponse.CachedToken != "" {
|
||||
log.Debug("Using cached authentication token")
|
||||
return authResponse.CachedToken, nil
|
||||
}
|
||||
|
||||
if stderr != nil {
|
||||
_, _ = fmt.Fprintln(stderr, "SSH authentication required.")
|
||||
_, _ = fmt.Fprintf(stderr, "Please visit: %s\n", authResponse.VerificationURIComplete)
|
||||
if authResponse.UserCode != "" {
|
||||
_, _ = fmt.Fprintf(stderr, "Or visit: %s and enter code: %s\n", authResponse.VerificationURI, authResponse.UserCode)
|
||||
}
|
||||
_, _ = fmt.Fprintln(stderr, "Waiting for authentication...")
|
||||
}
|
||||
|
||||
tokenResponse, err := client.WaitJWTToken(ctx, &proto.WaitJWTTokenRequest{
|
||||
DeviceCode: authResponse.DeviceCode,
|
||||
UserCode: authResponse.UserCode,
|
||||
})
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("wait for JWT token: %w", err)
|
||||
}
|
||||
|
||||
if stdout != nil {
|
||||
_, _ = fmt.Fprintln(stdout, "Authentication successful!")
|
||||
}
|
||||
return tokenResponse.Token, nil
|
||||
}
|
||||
|
||||
// VerifyHostKey verifies an SSH host key against stored peer key data.
|
||||
// Returns nil only if the presented key matches the stored key.
|
||||
// Returns ErrNoStoredKey if storedKeyData is empty.
|
||||
// Returns an error if the keys don't match or if parsing fails.
|
||||
func VerifyHostKey(storedKeyData []byte, presentedKey []byte, peerAddress string) error {
|
||||
if len(storedKeyData) == 0 {
|
||||
return ErrNoStoredKey
|
||||
}
|
||||
|
||||
storedPubKey, _, _, _, err := ssh.ParseAuthorizedKey(storedKeyData)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parse stored SSH key for %s: %w", peerAddress, err)
|
||||
}
|
||||
|
||||
if !bytes.Equal(presentedKey, storedPubKey.Marshal()) {
|
||||
return fmt.Errorf("SSH host key mismatch for %s", peerAddress)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// AddJWTAuth prepends JWT password authentication to existing auth methods.
|
||||
// This ensures JWT auth is tried first while preserving any existing auth methods.
|
||||
func AddJWTAuth(config *ssh.ClientConfig, jwtToken string) *ssh.ClientConfig {
|
||||
configWithJWT := *config
|
||||
configWithJWT.Auth = append([]ssh.AuthMethod{ssh.Password(jwtToken)}, config.Auth...)
|
||||
return &configWithJWT
|
||||
}
|
||||
|
||||
// CreateHostKeyCallback creates an SSH host key verification callback using the provided verifier.
|
||||
// It tries multiple addresses (hostname, IP) for the peer before failing.
|
||||
func CreateHostKeyCallback(verifier HostKeyVerifier) ssh.HostKeyCallback {
|
||||
return func(hostname string, remote net.Addr, key ssh.PublicKey) error {
|
||||
addresses := buildAddressList(hostname, remote)
|
||||
presentedKey := key.Marshal()
|
||||
|
||||
for _, addr := range addresses {
|
||||
if err := verifier.VerifySSHHostKey(addr, presentedKey); err != nil {
|
||||
if errors.Is(err, ErrPeerNotFound) {
|
||||
// Try other addresses for this peer
|
||||
continue
|
||||
}
|
||||
return err
|
||||
}
|
||||
// Verified
|
||||
return nil
|
||||
}
|
||||
|
||||
return fmt.Errorf("SSH host key verification failed: peer %s not found in network", hostname)
|
||||
}
|
||||
}
|
||||
|
||||
// buildAddressList creates a list of addresses to check for host key verification.
|
||||
// It includes the original hostname and extracts the host part from the remote address if different.
|
||||
func buildAddressList(hostname string, remote net.Addr) []string {
|
||||
addresses := []string{hostname}
|
||||
if host, _, err := net.SplitHostPort(remote.String()); err == nil {
|
||||
if host != hostname {
|
||||
addresses = append(addresses, host)
|
||||
}
|
||||
}
|
||||
return addresses
|
||||
}
|
||||
282
client/ssh/config/manager.go
Normal file
282
client/ssh/config/manager.go
Normal file
@@ -0,0 +1,282 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
nbssh "github.com/netbirdio/netbird/client/ssh"
|
||||
)
|
||||
|
||||
const (
|
||||
EnvDisableSSHConfig = "NB_DISABLE_SSH_CONFIG"
|
||||
|
||||
EnvForceSSHConfig = "NB_FORCE_SSH_CONFIG"
|
||||
|
||||
MaxPeersForSSHConfig = 200
|
||||
|
||||
fileWriteTimeout = 2 * time.Second
|
||||
)
|
||||
|
||||
func isSSHConfigDisabled() bool {
|
||||
value := os.Getenv(EnvDisableSSHConfig)
|
||||
if value == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
disabled, err := strconv.ParseBool(value)
|
||||
if err != nil {
|
||||
return true
|
||||
}
|
||||
return disabled
|
||||
}
|
||||
|
||||
func isSSHConfigForced() bool {
|
||||
value := os.Getenv(EnvForceSSHConfig)
|
||||
if value == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
forced, err := strconv.ParseBool(value)
|
||||
if err != nil {
|
||||
return true
|
||||
}
|
||||
return forced
|
||||
}
|
||||
|
||||
// shouldGenerateSSHConfig checks if SSH config should be generated based on peer count
|
||||
func shouldGenerateSSHConfig(peerCount int) bool {
|
||||
if isSSHConfigDisabled() {
|
||||
return false
|
||||
}
|
||||
|
||||
if isSSHConfigForced() {
|
||||
return true
|
||||
}
|
||||
|
||||
return peerCount <= MaxPeersForSSHConfig
|
||||
}
|
||||
|
||||
// writeFileWithTimeout writes data to a file with a timeout
|
||||
func writeFileWithTimeout(filename string, data []byte, perm os.FileMode) error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), fileWriteTimeout)
|
||||
defer cancel()
|
||||
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
done <- os.WriteFile(filename, data, perm)
|
||||
}()
|
||||
|
||||
select {
|
||||
case err := <-done:
|
||||
return err
|
||||
case <-ctx.Done():
|
||||
return fmt.Errorf("file write timeout after %v: %s", fileWriteTimeout, filename)
|
||||
}
|
||||
}
|
||||
|
||||
// Manager handles SSH client configuration for NetBird peers
|
||||
type Manager struct {
|
||||
sshConfigDir string
|
||||
sshConfigFile string
|
||||
}
|
||||
|
||||
// PeerSSHInfo represents a peer's SSH configuration information
|
||||
type PeerSSHInfo struct {
|
||||
Hostname string
|
||||
IP string
|
||||
FQDN string
|
||||
}
|
||||
|
||||
// New creates a new SSH config manager
|
||||
func New() *Manager {
|
||||
sshConfigDir := getSystemSSHConfigDir()
|
||||
return &Manager{
|
||||
sshConfigDir: sshConfigDir,
|
||||
sshConfigFile: nbssh.NetBirdSSHConfigFile,
|
||||
}
|
||||
}
|
||||
|
||||
// getSystemSSHConfigDir returns platform-specific SSH configuration directory
|
||||
func getSystemSSHConfigDir() string {
|
||||
if runtime.GOOS == "windows" {
|
||||
return getWindowsSSHConfigDir()
|
||||
}
|
||||
return nbssh.UnixSSHConfigDir
|
||||
}
|
||||
|
||||
func getWindowsSSHConfigDir() string {
|
||||
programData := os.Getenv("PROGRAMDATA")
|
||||
if programData == "" {
|
||||
programData = `C:\ProgramData`
|
||||
}
|
||||
return filepath.Join(programData, nbssh.WindowsSSHConfigDir)
|
||||
}
|
||||
|
||||
// SetupSSHClientConfig creates SSH client configuration for NetBird peers
|
||||
func (m *Manager) SetupSSHClientConfig(peers []PeerSSHInfo) error {
|
||||
if !shouldGenerateSSHConfig(len(peers)) {
|
||||
m.logSkipReason(len(peers))
|
||||
return nil
|
||||
}
|
||||
|
||||
sshConfig, err := m.buildSSHConfig(peers)
|
||||
if err != nil {
|
||||
return fmt.Errorf("build SSH config: %w", err)
|
||||
}
|
||||
return m.writeSSHConfig(sshConfig)
|
||||
}
|
||||
|
||||
func (m *Manager) logSkipReason(peerCount int) {
|
||||
if isSSHConfigDisabled() {
|
||||
log.Debugf("SSH config management disabled via %s", EnvDisableSSHConfig)
|
||||
} else {
|
||||
log.Infof("SSH config generation skipped: too many peers (%d > %d). Use %s=true to force.",
|
||||
peerCount, MaxPeersForSSHConfig, EnvForceSSHConfig)
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Manager) buildSSHConfig(peers []PeerSSHInfo) (string, error) {
|
||||
sshConfig := m.buildConfigHeader()
|
||||
|
||||
var allHostPatterns []string
|
||||
for _, peer := range peers {
|
||||
hostPatterns := m.buildHostPatterns(peer)
|
||||
allHostPatterns = append(allHostPatterns, hostPatterns...)
|
||||
}
|
||||
|
||||
if len(allHostPatterns) > 0 {
|
||||
peerConfig, err := m.buildPeerConfig(allHostPatterns)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
sshConfig += peerConfig
|
||||
}
|
||||
|
||||
return sshConfig, nil
|
||||
}
|
||||
|
||||
func (m *Manager) buildConfigHeader() string {
|
||||
return "# NetBird SSH client configuration\n" +
|
||||
"# Generated automatically - do not edit manually\n" +
|
||||
"#\n" +
|
||||
"# To disable SSH config management, use:\n" +
|
||||
"# netbird service reconfigure --service-env NB_DISABLE_SSH_CONFIG=true\n" +
|
||||
"#\n\n"
|
||||
}
|
||||
|
||||
func (m *Manager) buildPeerConfig(allHostPatterns []string) (string, error) {
|
||||
uniquePatterns := make(map[string]bool)
|
||||
var deduplicatedPatterns []string
|
||||
for _, pattern := range allHostPatterns {
|
||||
if !uniquePatterns[pattern] {
|
||||
uniquePatterns[pattern] = true
|
||||
deduplicatedPatterns = append(deduplicatedPatterns, pattern)
|
||||
}
|
||||
}
|
||||
|
||||
execPath, err := m.getNetBirdExecutablePath()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("get NetBird executable path: %w", err)
|
||||
}
|
||||
|
||||
hostLine := strings.Join(deduplicatedPatterns, " ")
|
||||
config := fmt.Sprintf("Host %s\n", hostLine)
|
||||
|
||||
if runtime.GOOS == "windows" {
|
||||
config += fmt.Sprintf(" Match exec \"%s ssh detect %%h %%p\"\n", execPath)
|
||||
} else {
|
||||
config += fmt.Sprintf(" Match exec \"%s ssh detect %%h %%p 2>/dev/null\"\n", execPath)
|
||||
}
|
||||
config += " PreferredAuthentications password,publickey,keyboard-interactive\n"
|
||||
config += " PasswordAuthentication yes\n"
|
||||
config += " PubkeyAuthentication yes\n"
|
||||
config += " BatchMode no\n"
|
||||
config += fmt.Sprintf(" ProxyCommand %s ssh proxy %%h %%p\n", execPath)
|
||||
config += " StrictHostKeyChecking no\n"
|
||||
|
||||
if runtime.GOOS == "windows" {
|
||||
config += " UserKnownHostsFile NUL\n"
|
||||
} else {
|
||||
config += " UserKnownHostsFile /dev/null\n"
|
||||
}
|
||||
|
||||
config += " CheckHostIP no\n"
|
||||
config += " LogLevel ERROR\n\n"
|
||||
|
||||
return config, nil
|
||||
}
|
||||
|
||||
func (m *Manager) buildHostPatterns(peer PeerSSHInfo) []string {
|
||||
var hostPatterns []string
|
||||
if peer.IP != "" {
|
||||
hostPatterns = append(hostPatterns, peer.IP)
|
||||
}
|
||||
if peer.FQDN != "" {
|
||||
hostPatterns = append(hostPatterns, peer.FQDN)
|
||||
}
|
||||
if peer.Hostname != "" && peer.Hostname != peer.FQDN {
|
||||
hostPatterns = append(hostPatterns, peer.Hostname)
|
||||
}
|
||||
return hostPatterns
|
||||
}
|
||||
|
||||
func (m *Manager) writeSSHConfig(sshConfig string) error {
|
||||
sshConfigPath := filepath.Join(m.sshConfigDir, m.sshConfigFile)
|
||||
|
||||
if err := os.MkdirAll(m.sshConfigDir, 0755); err != nil {
|
||||
return fmt.Errorf("create SSH config directory %s: %w", m.sshConfigDir, err)
|
||||
}
|
||||
|
||||
if err := writeFileWithTimeout(sshConfigPath, []byte(sshConfig), 0644); err != nil {
|
||||
return fmt.Errorf("write SSH config file %s: %w", sshConfigPath, err)
|
||||
}
|
||||
|
||||
log.Infof("Created NetBird SSH client config: %s", sshConfigPath)
|
||||
return nil
|
||||
}
|
||||
|
||||
// RemoveSSHClientConfig removes NetBird SSH configuration
|
||||
func (m *Manager) RemoveSSHClientConfig() error {
|
||||
sshConfigPath := filepath.Join(m.sshConfigDir, m.sshConfigFile)
|
||||
err := os.Remove(sshConfigPath)
|
||||
if err != nil && !os.IsNotExist(err) {
|
||||
return fmt.Errorf("remove SSH config %s: %w", sshConfigPath, err)
|
||||
}
|
||||
if err == nil {
|
||||
log.Infof("Removed NetBird SSH config: %s", sshConfigPath)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Manager) getNetBirdExecutablePath() (string, error) {
|
||||
execPath, err := os.Executable()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("retrieve executable path: %w", err)
|
||||
}
|
||||
|
||||
realPath, err := filepath.EvalSymlinks(execPath)
|
||||
if err != nil {
|
||||
log.Debugf("symlink resolution failed: %v", err)
|
||||
return execPath, nil
|
||||
}
|
||||
|
||||
return realPath, nil
|
||||
}
|
||||
|
||||
// GetSSHConfigDir returns the SSH config directory path
|
||||
func (m *Manager) GetSSHConfigDir() string {
|
||||
return m.sshConfigDir
|
||||
}
|
||||
|
||||
// GetSSHConfigFile returns the SSH config file name
|
||||
func (m *Manager) GetSSHConfigFile() string {
|
||||
return m.sshConfigFile
|
||||
}
|
||||
159
client/ssh/config/manager_test.go
Normal file
159
client/ssh/config/manager_test.go
Normal file
@@ -0,0 +1,159 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestManager_SetupSSHClientConfig(t *testing.T) {
|
||||
// Create temporary directory for test
|
||||
tempDir, err := os.MkdirTemp("", "netbird-ssh-config-test")
|
||||
require.NoError(t, err)
|
||||
defer func() { assert.NoError(t, os.RemoveAll(tempDir)) }()
|
||||
|
||||
// Override manager paths to use temp directory
|
||||
manager := &Manager{
|
||||
sshConfigDir: filepath.Join(tempDir, "ssh_config.d"),
|
||||
sshConfigFile: "99-netbird.conf",
|
||||
}
|
||||
|
||||
// Test SSH config generation with peers
|
||||
peers := []PeerSSHInfo{
|
||||
{
|
||||
Hostname: "peer1",
|
||||
IP: "100.125.1.1",
|
||||
FQDN: "peer1.nb.internal",
|
||||
},
|
||||
{
|
||||
Hostname: "peer2",
|
||||
IP: "100.125.1.2",
|
||||
FQDN: "peer2.nb.internal",
|
||||
},
|
||||
}
|
||||
|
||||
err = manager.SetupSSHClientConfig(peers)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Read generated config
|
||||
configPath := filepath.Join(manager.sshConfigDir, manager.sshConfigFile)
|
||||
content, err := os.ReadFile(configPath)
|
||||
require.NoError(t, err)
|
||||
|
||||
configStr := string(content)
|
||||
|
||||
// Verify the basic SSH config structure exists
|
||||
assert.Contains(t, configStr, "# NetBird SSH client configuration")
|
||||
assert.Contains(t, configStr, "Generated automatically - do not edit manually")
|
||||
|
||||
// Check that peer hostnames are included
|
||||
assert.Contains(t, configStr, "100.125.1.1")
|
||||
assert.Contains(t, configStr, "100.125.1.2")
|
||||
assert.Contains(t, configStr, "peer1.nb.internal")
|
||||
assert.Contains(t, configStr, "peer2.nb.internal")
|
||||
|
||||
// Check platform-specific UserKnownHostsFile
|
||||
if runtime.GOOS == "windows" {
|
||||
assert.Contains(t, configStr, "UserKnownHostsFile NUL")
|
||||
} else {
|
||||
assert.Contains(t, configStr, "UserKnownHostsFile /dev/null")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetSystemSSHConfigDir(t *testing.T) {
|
||||
configDir := getSystemSSHConfigDir()
|
||||
|
||||
// Path should not be empty
|
||||
assert.NotEmpty(t, configDir)
|
||||
|
||||
// Should be an absolute path
|
||||
assert.True(t, filepath.IsAbs(configDir))
|
||||
|
||||
// On Unix systems, should start with /etc
|
||||
// On Windows, should contain ProgramData
|
||||
if runtime.GOOS == "windows" {
|
||||
assert.Contains(t, strings.ToLower(configDir), "programdata")
|
||||
} else {
|
||||
assert.Contains(t, configDir, "/etc/ssh")
|
||||
}
|
||||
}
|
||||
|
||||
func TestManager_PeerLimit(t *testing.T) {
|
||||
// Create temporary directory for test
|
||||
tempDir, err := os.MkdirTemp("", "netbird-ssh-config-test")
|
||||
require.NoError(t, err)
|
||||
defer func() { assert.NoError(t, os.RemoveAll(tempDir)) }()
|
||||
|
||||
// Override manager paths to use temp directory
|
||||
manager := &Manager{
|
||||
sshConfigDir: filepath.Join(tempDir, "ssh_config.d"),
|
||||
sshConfigFile: "99-netbird.conf",
|
||||
}
|
||||
|
||||
// Generate many peers (more than limit)
|
||||
var peers []PeerSSHInfo
|
||||
for i := 0; i < MaxPeersForSSHConfig+10; i++ {
|
||||
peers = append(peers, PeerSSHInfo{
|
||||
Hostname: fmt.Sprintf("peer%d", i),
|
||||
IP: fmt.Sprintf("100.125.1.%d", i%254+1),
|
||||
FQDN: fmt.Sprintf("peer%d.nb.internal", i),
|
||||
})
|
||||
}
|
||||
|
||||
// Test that SSH config generation is skipped when too many peers
|
||||
err = manager.SetupSSHClientConfig(peers)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Config should not be created due to peer limit
|
||||
configPath := filepath.Join(manager.sshConfigDir, manager.sshConfigFile)
|
||||
_, err = os.Stat(configPath)
|
||||
assert.True(t, os.IsNotExist(err), "SSH config should not be created with too many peers")
|
||||
}
|
||||
|
||||
func TestManager_ForcedSSHConfig(t *testing.T) {
|
||||
// Set force environment variable
|
||||
t.Setenv(EnvForceSSHConfig, "true")
|
||||
|
||||
// Create temporary directory for test
|
||||
tempDir, err := os.MkdirTemp("", "netbird-ssh-config-test")
|
||||
require.NoError(t, err)
|
||||
defer func() { assert.NoError(t, os.RemoveAll(tempDir)) }()
|
||||
|
||||
// Override manager paths to use temp directory
|
||||
manager := &Manager{
|
||||
sshConfigDir: filepath.Join(tempDir, "ssh_config.d"),
|
||||
sshConfigFile: "99-netbird.conf",
|
||||
}
|
||||
|
||||
// Generate many peers (more than limit)
|
||||
var peers []PeerSSHInfo
|
||||
for i := 0; i < MaxPeersForSSHConfig+10; i++ {
|
||||
peers = append(peers, PeerSSHInfo{
|
||||
Hostname: fmt.Sprintf("peer%d", i),
|
||||
IP: fmt.Sprintf("100.125.1.%d", i%254+1),
|
||||
FQDN: fmt.Sprintf("peer%d.nb.internal", i),
|
||||
})
|
||||
}
|
||||
|
||||
// Test that SSH config generation is forced despite many peers
|
||||
err = manager.SetupSSHClientConfig(peers)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Config should be created despite peer limit due to force flag
|
||||
configPath := filepath.Join(manager.sshConfigDir, manager.sshConfigFile)
|
||||
_, err = os.Stat(configPath)
|
||||
require.NoError(t, err, "SSH config should be created when forced")
|
||||
|
||||
// Verify config contains peer hostnames
|
||||
content, err := os.ReadFile(configPath)
|
||||
require.NoError(t, err)
|
||||
configStr := string(content)
|
||||
assert.Contains(t, configStr, "peer0.nb.internal")
|
||||
assert.Contains(t, configStr, "peer1.nb.internal")
|
||||
}
|
||||
22
client/ssh/config/shutdown_state.go
Normal file
22
client/ssh/config/shutdown_state.go
Normal file
@@ -0,0 +1,22 @@
|
||||
package config
|
||||
|
||||
// ShutdownState represents SSH configuration state that needs to be cleaned up.
|
||||
type ShutdownState struct {
|
||||
SSHConfigDir string
|
||||
SSHConfigFile string
|
||||
}
|
||||
|
||||
// Name returns the state name for the state manager.
|
||||
func (s *ShutdownState) Name() string {
|
||||
return "ssh_config_state"
|
||||
}
|
||||
|
||||
// Cleanup removes SSH client configuration files.
|
||||
func (s *ShutdownState) Cleanup() error {
|
||||
manager := &Manager{
|
||||
sshConfigDir: s.SSHConfigDir,
|
||||
sshConfigFile: s.SSHConfigFile,
|
||||
}
|
||||
|
||||
return manager.RemoveSSHClientConfig()
|
||||
}
|
||||
99
client/ssh/detection/detection.go
Normal file
99
client/ssh/detection/detection.go
Normal file
@@ -0,0 +1,99 @@
|
||||
package detection
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"net"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
// ServerIdentifier is the base response for NetBird SSH servers
|
||||
ServerIdentifier = "NetBird-SSH-Server"
|
||||
// ProxyIdentifier is the base response for NetBird SSH proxy
|
||||
ProxyIdentifier = "NetBird-SSH-Proxy"
|
||||
// JWTRequiredMarker is appended to responses when JWT is required
|
||||
JWTRequiredMarker = "NetBird-JWT-Required"
|
||||
|
||||
// Timeout is the timeout for SSH server detection
|
||||
Timeout = 5 * time.Second
|
||||
)
|
||||
|
||||
type ServerType string
|
||||
|
||||
const (
|
||||
ServerTypeNetBirdJWT ServerType = "netbird-jwt"
|
||||
ServerTypeNetBirdNoJWT ServerType = "netbird-no-jwt"
|
||||
ServerTypeRegular ServerType = "regular"
|
||||
)
|
||||
|
||||
// Dialer provides network connection capabilities
|
||||
type Dialer interface {
|
||||
DialContext(ctx context.Context, network, address string) (net.Conn, error)
|
||||
}
|
||||
|
||||
// RequiresJWT checks if the server type requires JWT authentication
|
||||
func (s ServerType) RequiresJWT() bool {
|
||||
return s == ServerTypeNetBirdJWT
|
||||
}
|
||||
|
||||
// ExitCode returns the exit code for the detect command
|
||||
func (s ServerType) ExitCode() int {
|
||||
switch s {
|
||||
case ServerTypeNetBirdJWT:
|
||||
return 0
|
||||
case ServerTypeNetBirdNoJWT:
|
||||
return 1
|
||||
case ServerTypeRegular:
|
||||
return 2
|
||||
default:
|
||||
return 2
|
||||
}
|
||||
}
|
||||
|
||||
// DetectSSHServerType detects SSH server type using the provided dialer
|
||||
func DetectSSHServerType(ctx context.Context, dialer Dialer, host string, port int) (ServerType, error) {
|
||||
targetAddr := net.JoinHostPort(host, strconv.Itoa(port))
|
||||
|
||||
conn, err := dialer.DialContext(ctx, "tcp", targetAddr)
|
||||
if err != nil {
|
||||
log.Debugf("SSH connection failed for detection: %v", err)
|
||||
return ServerTypeRegular, nil
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
if err := conn.SetReadDeadline(time.Now().Add(Timeout)); err != nil {
|
||||
log.Debugf("set read deadline: %v", err)
|
||||
return ServerTypeRegular, nil
|
||||
}
|
||||
|
||||
reader := bufio.NewReader(conn)
|
||||
serverBanner, err := reader.ReadString('\n')
|
||||
if err != nil {
|
||||
log.Debugf("read SSH banner: %v", err)
|
||||
return ServerTypeRegular, nil
|
||||
}
|
||||
|
||||
serverBanner = strings.TrimSpace(serverBanner)
|
||||
log.Debugf("SSH server banner: %s", serverBanner)
|
||||
|
||||
if !strings.HasPrefix(serverBanner, "SSH-") {
|
||||
log.Debugf("Invalid SSH banner")
|
||||
return ServerTypeRegular, nil
|
||||
}
|
||||
|
||||
if !strings.Contains(serverBanner, ServerIdentifier) {
|
||||
log.Debugf("Server banner does not contain identifier '%s'", ServerIdentifier)
|
||||
return ServerTypeRegular, nil
|
||||
}
|
||||
|
||||
if strings.Contains(serverBanner, JWTRequiredMarker) {
|
||||
return ServerTypeNetBirdJWT, nil
|
||||
}
|
||||
|
||||
return ServerTypeNetBirdNoJWT, nil
|
||||
}
|
||||
@@ -1,53 +0,0 @@
|
||||
//go:build !js
|
||||
|
||||
package ssh
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
"os/exec"
|
||||
"runtime"
|
||||
|
||||
"github.com/netbirdio/netbird/util"
|
||||
)
|
||||
|
||||
func isRoot() bool {
|
||||
return os.Geteuid() == 0
|
||||
}
|
||||
|
||||
func getLoginCmd(user string, remoteAddr net.Addr) (loginPath string, args []string, err error) {
|
||||
if !isRoot() {
|
||||
shell := getUserShell(user)
|
||||
if shell == "" {
|
||||
shell = "/bin/sh"
|
||||
}
|
||||
|
||||
return shell, []string{"-l"}, nil
|
||||
}
|
||||
|
||||
loginPath, err = exec.LookPath("login")
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
|
||||
addrPort, err := netip.ParseAddrPort(remoteAddr.String())
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
|
||||
switch runtime.GOOS {
|
||||
case "linux":
|
||||
if util.FileExists("/etc/arch-release") && !util.FileExists("/etc/pam.d/remote") {
|
||||
return loginPath, []string{"-f", user, "-p"}, nil
|
||||
}
|
||||
return loginPath, []string{"-f", user, "-h", addrPort.Addr().String(), "-p"}, nil
|
||||
case "darwin":
|
||||
return loginPath, []string{"-fp", "-h", addrPort.Addr().String(), user}, nil
|
||||
case "freebsd":
|
||||
return loginPath, []string{"-f", user, "-h", addrPort.Addr().String(), "-p"}, nil
|
||||
default:
|
||||
return "", nil, fmt.Errorf("unsupported platform: %s", runtime.GOOS)
|
||||
}
|
||||
}
|
||||
@@ -1,14 +0,0 @@
|
||||
//go:build !darwin
|
||||
// +build !darwin
|
||||
|
||||
package ssh
|
||||
|
||||
import "os/user"
|
||||
|
||||
func userNameLookup(username string) (*user.User, error) {
|
||||
if username == "" || (username == "root" && !isRoot()) {
|
||||
return user.Current()
|
||||
}
|
||||
|
||||
return user.Lookup(username)
|
||||
}
|
||||
@@ -1,51 +0,0 @@
|
||||
//go:build darwin
|
||||
// +build darwin
|
||||
|
||||
package ssh
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"os/exec"
|
||||
"os/user"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func userNameLookup(username string) (*user.User, error) {
|
||||
if username == "" || (username == "root" && !isRoot()) {
|
||||
return user.Current()
|
||||
}
|
||||
|
||||
var userObject *user.User
|
||||
userObject, err := user.Lookup(username)
|
||||
if err != nil && err.Error() == user.UnknownUserError(username).Error() {
|
||||
return idUserNameLookup(username)
|
||||
} else if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return userObject, nil
|
||||
}
|
||||
|
||||
func idUserNameLookup(username string) (*user.User, error) {
|
||||
cmd := exec.Command("id", "-P", username)
|
||||
out, err := cmd.CombinedOutput()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error while retrieving user with id -P command, error: %v", err)
|
||||
}
|
||||
colon := ":"
|
||||
|
||||
if !bytes.Contains(out, []byte(username+colon)) {
|
||||
return nil, fmt.Errorf("unable to find user in returned string")
|
||||
}
|
||||
// netbird:********:501:20::0:0:netbird:/Users/netbird:/bin/zsh
|
||||
parts := strings.SplitN(string(out), colon, 10)
|
||||
userObject := &user.User{
|
||||
Username: parts[0],
|
||||
Uid: parts[2],
|
||||
Gid: parts[3],
|
||||
Name: parts[7],
|
||||
HomeDir: parts[8],
|
||||
}
|
||||
return userObject, nil
|
||||
}
|
||||
392
client/ssh/proxy/proxy.go
Normal file
392
client/ssh/proxy/proxy.go
Normal file
@@ -0,0 +1,392 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gliderlabs/ssh"
|
||||
log "github.com/sirupsen/logrus"
|
||||
cryptossh "golang.org/x/crypto/ssh"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/credentials/insecure"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
nbssh "github.com/netbirdio/netbird/client/ssh"
|
||||
"github.com/netbirdio/netbird/client/ssh/detection"
|
||||
"github.com/netbirdio/netbird/version"
|
||||
)
|
||||
|
||||
const (
|
||||
// sshConnectionTimeout is the timeout for SSH TCP connection establishment
|
||||
sshConnectionTimeout = 120 * time.Second
|
||||
// sshHandshakeTimeout is the timeout for SSH handshake completion
|
||||
sshHandshakeTimeout = 30 * time.Second
|
||||
|
||||
jwtAuthErrorMsg = "JWT authentication: %w"
|
||||
)
|
||||
|
||||
type SSHProxy struct {
|
||||
daemonAddr string
|
||||
targetHost string
|
||||
targetPort int
|
||||
stderr io.Writer
|
||||
conn *grpc.ClientConn
|
||||
daemonClient proto.DaemonServiceClient
|
||||
}
|
||||
|
||||
func New(daemonAddr, targetHost string, targetPort int, stderr io.Writer) (*SSHProxy, error) {
|
||||
grpcAddr := strings.TrimPrefix(daemonAddr, "tcp://")
|
||||
grpcConn, err := grpc.NewClient(grpcAddr, grpc.WithTransportCredentials(insecure.NewCredentials()))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("connect to daemon: %w", err)
|
||||
}
|
||||
|
||||
return &SSHProxy{
|
||||
daemonAddr: daemonAddr,
|
||||
targetHost: targetHost,
|
||||
targetPort: targetPort,
|
||||
stderr: stderr,
|
||||
conn: grpcConn,
|
||||
daemonClient: proto.NewDaemonServiceClient(grpcConn),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (p *SSHProxy) Close() error {
|
||||
if p.conn != nil {
|
||||
return p.conn.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *SSHProxy) Connect(ctx context.Context) error {
|
||||
hint := profilemanager.GetLoginHint()
|
||||
|
||||
jwtToken, err := nbssh.RequestJWTToken(ctx, p.daemonClient, nil, p.stderr, true, hint)
|
||||
if err != nil {
|
||||
return fmt.Errorf(jwtAuthErrorMsg, err)
|
||||
}
|
||||
|
||||
return p.runProxySSHServer(ctx, jwtToken)
|
||||
}
|
||||
|
||||
func (p *SSHProxy) runProxySSHServer(ctx context.Context, jwtToken string) error {
|
||||
serverVersion := fmt.Sprintf("%s-%s", detection.ProxyIdentifier, version.NetbirdVersion())
|
||||
|
||||
sshServer := &ssh.Server{
|
||||
Handler: func(s ssh.Session) {
|
||||
p.handleSSHSession(ctx, s, jwtToken)
|
||||
},
|
||||
ChannelHandlers: map[string]ssh.ChannelHandler{
|
||||
"session": ssh.DefaultSessionHandler,
|
||||
"direct-tcpip": p.directTCPIPHandler,
|
||||
},
|
||||
SubsystemHandlers: map[string]ssh.SubsystemHandler{
|
||||
"sftp": func(s ssh.Session) {
|
||||
p.sftpSubsystemHandler(s, jwtToken)
|
||||
},
|
||||
},
|
||||
RequestHandlers: map[string]ssh.RequestHandler{
|
||||
"tcpip-forward": p.tcpipForwardHandler,
|
||||
"cancel-tcpip-forward": p.cancelTcpipForwardHandler,
|
||||
},
|
||||
Version: serverVersion,
|
||||
}
|
||||
|
||||
hostKey, err := generateHostKey()
|
||||
if err != nil {
|
||||
return fmt.Errorf("generate host key: %w", err)
|
||||
}
|
||||
sshServer.HostSigners = []ssh.Signer{hostKey}
|
||||
|
||||
conn := &stdioConn{
|
||||
stdin: os.Stdin,
|
||||
stdout: os.Stdout,
|
||||
}
|
||||
|
||||
sshServer.HandleConn(conn)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *SSHProxy) handleSSHSession(ctx context.Context, session ssh.Session, jwtToken string) {
|
||||
targetAddr := net.JoinHostPort(p.targetHost, strconv.Itoa(p.targetPort))
|
||||
|
||||
sshClient, err := p.dialBackend(ctx, targetAddr, session.User(), jwtToken)
|
||||
if err != nil {
|
||||
_, _ = fmt.Fprintf(p.stderr, "SSH connection to NetBird server failed: %v\n", err)
|
||||
return
|
||||
}
|
||||
defer func() { _ = sshClient.Close() }()
|
||||
|
||||
serverSession, err := sshClient.NewSession()
|
||||
if err != nil {
|
||||
_, _ = fmt.Fprintf(p.stderr, "create server session: %v\n", err)
|
||||
return
|
||||
}
|
||||
defer func() { _ = serverSession.Close() }()
|
||||
|
||||
serverSession.Stdin = session
|
||||
serverSession.Stdout = session
|
||||
serverSession.Stderr = session.Stderr()
|
||||
|
||||
ptyReq, winCh, isPty := session.Pty()
|
||||
if isPty {
|
||||
if err := serverSession.RequestPty(ptyReq.Term, ptyReq.Window.Width, ptyReq.Window.Height, nil); err != nil {
|
||||
log.Debugf("PTY request to backend: %v", err)
|
||||
}
|
||||
|
||||
go func() {
|
||||
for win := range winCh {
|
||||
if err := serverSession.WindowChange(win.Height, win.Width); err != nil {
|
||||
log.Debugf("window change: %v", err)
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
if len(session.Command()) > 0 {
|
||||
if err := serverSession.Run(strings.Join(session.Command(), " ")); err != nil {
|
||||
log.Debugf("run command: %v", err)
|
||||
p.handleProxyExitCode(session, err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if err = serverSession.Shell(); err != nil {
|
||||
log.Debugf("start shell: %v", err)
|
||||
return
|
||||
}
|
||||
if err := serverSession.Wait(); err != nil {
|
||||
log.Debugf("session wait: %v", err)
|
||||
p.handleProxyExitCode(session, err)
|
||||
}
|
||||
}
|
||||
|
||||
func (p *SSHProxy) handleProxyExitCode(session ssh.Session, err error) {
|
||||
var exitErr *cryptossh.ExitError
|
||||
if errors.As(err, &exitErr) {
|
||||
if exitErr := session.Exit(exitErr.ExitStatus()); exitErr != nil {
|
||||
log.Debugf("set exit status: %v", exitErr)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func generateHostKey() (ssh.Signer, error) {
|
||||
keyPEM, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("generate ED25519 key: %w", err)
|
||||
}
|
||||
|
||||
signer, err := cryptossh.ParsePrivateKey(keyPEM)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parse private key: %w", err)
|
||||
}
|
||||
|
||||
return signer, nil
|
||||
}
|
||||
|
||||
type stdioConn struct {
|
||||
stdin io.Reader
|
||||
stdout io.Writer
|
||||
closed bool
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
func (c *stdioConn) Read(b []byte) (n int, err error) {
|
||||
c.mu.Lock()
|
||||
if c.closed {
|
||||
c.mu.Unlock()
|
||||
return 0, io.EOF
|
||||
}
|
||||
c.mu.Unlock()
|
||||
return c.stdin.Read(b)
|
||||
}
|
||||
|
||||
func (c *stdioConn) Write(b []byte) (n int, err error) {
|
||||
c.mu.Lock()
|
||||
if c.closed {
|
||||
c.mu.Unlock()
|
||||
return 0, io.ErrClosedPipe
|
||||
}
|
||||
c.mu.Unlock()
|
||||
return c.stdout.Write(b)
|
||||
}
|
||||
|
||||
func (c *stdioConn) Close() error {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
c.closed = true
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *stdioConn) LocalAddr() net.Addr {
|
||||
return &net.UnixAddr{Name: "stdio", Net: "unix"}
|
||||
}
|
||||
|
||||
func (c *stdioConn) RemoteAddr() net.Addr {
|
||||
return &net.UnixAddr{Name: "stdio", Net: "unix"}
|
||||
}
|
||||
|
||||
func (c *stdioConn) SetDeadline(_ time.Time) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *stdioConn) SetReadDeadline(_ time.Time) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *stdioConn) SetWriteDeadline(_ time.Time) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *SSHProxy) directTCPIPHandler(_ *ssh.Server, _ *cryptossh.ServerConn, newChan cryptossh.NewChannel, _ ssh.Context) {
|
||||
_ = newChan.Reject(cryptossh.Prohibited, "port forwarding not supported in proxy")
|
||||
}
|
||||
|
||||
func (p *SSHProxy) sftpSubsystemHandler(s ssh.Session, jwtToken string) {
|
||||
ctx, cancel := context.WithCancel(s.Context())
|
||||
defer cancel()
|
||||
|
||||
targetAddr := net.JoinHostPort(p.targetHost, strconv.Itoa(p.targetPort))
|
||||
|
||||
sshClient, err := p.dialBackend(ctx, targetAddr, s.User(), jwtToken)
|
||||
if err != nil {
|
||||
_, _ = fmt.Fprintf(s, "SSH connection failed: %v\n", err)
|
||||
_ = s.Exit(1)
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
if err := sshClient.Close(); err != nil {
|
||||
log.Debugf("close SSH client: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
serverSession, err := sshClient.NewSession()
|
||||
if err != nil {
|
||||
_, _ = fmt.Fprintf(s, "create server session: %v\n", err)
|
||||
_ = s.Exit(1)
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
if err := serverSession.Close(); err != nil {
|
||||
log.Debugf("close server session: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
stdin, stdout, err := p.setupSFTPPipes(serverSession)
|
||||
if err != nil {
|
||||
log.Debugf("setup SFTP pipes: %v", err)
|
||||
_ = s.Exit(1)
|
||||
return
|
||||
}
|
||||
|
||||
if err := serverSession.RequestSubsystem("sftp"); err != nil {
|
||||
_, _ = fmt.Fprintf(s, "SFTP subsystem request failed: %v\n", err)
|
||||
_ = s.Exit(1)
|
||||
return
|
||||
}
|
||||
|
||||
p.runSFTPBridge(ctx, s, stdin, stdout, serverSession)
|
||||
}
|
||||
|
||||
func (p *SSHProxy) setupSFTPPipes(serverSession *cryptossh.Session) (io.WriteCloser, io.Reader, error) {
|
||||
stdin, err := serverSession.StdinPipe()
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("get stdin pipe: %w", err)
|
||||
}
|
||||
|
||||
stdout, err := serverSession.StdoutPipe()
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("get stdout pipe: %w", err)
|
||||
}
|
||||
|
||||
return stdin, stdout, nil
|
||||
}
|
||||
|
||||
func (p *SSHProxy) runSFTPBridge(ctx context.Context, s ssh.Session, stdin io.WriteCloser, stdout io.Reader, serverSession *cryptossh.Session) {
|
||||
copyErrCh := make(chan error, 2)
|
||||
|
||||
go func() {
|
||||
_, err := io.Copy(stdin, s)
|
||||
if err != nil {
|
||||
log.Debugf("SFTP client to server copy: %v", err)
|
||||
}
|
||||
if err := stdin.Close(); err != nil {
|
||||
log.Debugf("close stdin: %v", err)
|
||||
}
|
||||
copyErrCh <- err
|
||||
}()
|
||||
|
||||
go func() {
|
||||
_, err := io.Copy(s, stdout)
|
||||
if err != nil {
|
||||
log.Debugf("SFTP server to client copy: %v", err)
|
||||
}
|
||||
copyErrCh <- err
|
||||
}()
|
||||
|
||||
go func() {
|
||||
<-ctx.Done()
|
||||
if err := serverSession.Close(); err != nil {
|
||||
log.Debugf("force close server session on context cancellation: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
for i := 0; i < 2; i++ {
|
||||
if err := <-copyErrCh; err != nil && !errors.Is(err, io.EOF) {
|
||||
log.Debugf("SFTP copy error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
if err := serverSession.Wait(); err != nil {
|
||||
log.Debugf("SFTP session ended: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (p *SSHProxy) tcpipForwardHandler(_ ssh.Context, _ *ssh.Server, _ *cryptossh.Request) (bool, []byte) {
|
||||
return false, []byte("port forwarding not supported in proxy")
|
||||
}
|
||||
|
||||
func (p *SSHProxy) cancelTcpipForwardHandler(_ ssh.Context, _ *ssh.Server, _ *cryptossh.Request) (bool, []byte) {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func (p *SSHProxy) dialBackend(ctx context.Context, addr, user, jwtToken string) (*cryptossh.Client, error) {
|
||||
config := &cryptossh.ClientConfig{
|
||||
User: user,
|
||||
Auth: []cryptossh.AuthMethod{cryptossh.Password(jwtToken)},
|
||||
Timeout: sshHandshakeTimeout,
|
||||
HostKeyCallback: p.verifyHostKey,
|
||||
}
|
||||
|
||||
dialer := &net.Dialer{
|
||||
Timeout: sshConnectionTimeout,
|
||||
}
|
||||
conn, err := dialer.DialContext(ctx, "tcp", addr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("connect to server: %w", err)
|
||||
}
|
||||
|
||||
clientConn, chans, reqs, err := cryptossh.NewClientConn(conn, addr, config)
|
||||
if err != nil {
|
||||
_ = conn.Close()
|
||||
return nil, fmt.Errorf("SSH handshake: %w", err)
|
||||
}
|
||||
|
||||
return cryptossh.NewClient(clientConn, chans, reqs), nil
|
||||
}
|
||||
|
||||
func (p *SSHProxy) verifyHostKey(hostname string, remote net.Addr, key cryptossh.PublicKey) error {
|
||||
verifier := nbssh.NewDaemonHostKeyVerifier(p.daemonClient)
|
||||
callback := nbssh.CreateHostKeyCallback(verifier)
|
||||
return callback(hostname, remote, key)
|
||||
}
|
||||
367
client/ssh/proxy/proxy_test.go
Normal file
367
client/ssh/proxy/proxy_test.go
Normal file
@@ -0,0 +1,367 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"math/big"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
cryptossh "golang.org/x/crypto/ssh"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/credentials/insecure"
|
||||
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
nbssh "github.com/netbirdio/netbird/client/ssh"
|
||||
"github.com/netbirdio/netbird/client/ssh/server"
|
||||
"github.com/netbirdio/netbird/client/ssh/testutil"
|
||||
nbjwt "github.com/netbirdio/netbird/shared/auth/jwt"
|
||||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
if len(os.Args) > 2 && os.Args[1] == "ssh" {
|
||||
if os.Args[2] == "exec" {
|
||||
if len(os.Args) > 3 {
|
||||
cmd := os.Args[3]
|
||||
if cmd == "echo" && len(os.Args) > 4 {
|
||||
fmt.Fprintln(os.Stdout, os.Args[4])
|
||||
os.Exit(0)
|
||||
}
|
||||
}
|
||||
fmt.Fprintf(os.Stderr, "Test binary called as 'ssh exec' with args: %v - preventing infinite recursion\n", os.Args)
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
||||
|
||||
code := m.Run()
|
||||
|
||||
testutil.CleanupTestUsers()
|
||||
|
||||
os.Exit(code)
|
||||
}
|
||||
|
||||
func TestSSHProxy_verifyHostKey(t *testing.T) {
|
||||
t.Run("calls daemon to verify host key", func(t *testing.T) {
|
||||
mockDaemon := startMockDaemon(t)
|
||||
defer mockDaemon.stop()
|
||||
|
||||
grpcConn, err := grpc.NewClient(mockDaemon.addr, grpc.WithTransportCredentials(insecure.NewCredentials()))
|
||||
require.NoError(t, err)
|
||||
defer func() { _ = grpcConn.Close() }()
|
||||
|
||||
proxy := &SSHProxy{
|
||||
daemonAddr: mockDaemon.addr,
|
||||
daemonClient: proto.NewDaemonServiceClient(grpcConn),
|
||||
}
|
||||
|
||||
testKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
||||
require.NoError(t, err)
|
||||
testPubKey, err := nbssh.GeneratePublicKey(testKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
mockDaemon.setHostKey("test-host", testPubKey)
|
||||
|
||||
err = proxy.verifyHostKey("test-host", &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 22}, mustParsePublicKey(t, testPubKey))
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("rejects unknown host key", func(t *testing.T) {
|
||||
mockDaemon := startMockDaemon(t)
|
||||
defer mockDaemon.stop()
|
||||
|
||||
grpcConn, err := grpc.NewClient(mockDaemon.addr, grpc.WithTransportCredentials(insecure.NewCredentials()))
|
||||
require.NoError(t, err)
|
||||
defer func() { _ = grpcConn.Close() }()
|
||||
|
||||
proxy := &SSHProxy{
|
||||
daemonAddr: mockDaemon.addr,
|
||||
daemonClient: proto.NewDaemonServiceClient(grpcConn),
|
||||
}
|
||||
|
||||
unknownKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
||||
require.NoError(t, err)
|
||||
unknownPubKey, err := nbssh.GeneratePublicKey(unknownKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = proxy.verifyHostKey("unknown-host", &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 22}, mustParsePublicKey(t, unknownPubKey))
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "peer unknown-host not found in network")
|
||||
})
|
||||
}
|
||||
|
||||
func TestSSHProxy_Connect(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("Skipping integration test in short mode")
|
||||
}
|
||||
|
||||
// TODO: Windows test times out - user switching and command execution tested on Linux
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("Skipping on Windows - covered by Linux tests")
|
||||
}
|
||||
|
||||
const (
|
||||
issuer = "https://test-issuer.example.com"
|
||||
audience = "test-audience"
|
||||
)
|
||||
|
||||
jwksServer, privateKey, jwksURL := setupJWKSServer(t)
|
||||
defer jwksServer.Close()
|
||||
|
||||
hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
||||
require.NoError(t, err)
|
||||
hostPubKey, err := nbssh.GeneratePublicKey(hostKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
serverConfig := &server.Config{
|
||||
HostKeyPEM: hostKey,
|
||||
JWT: &server.JWTConfig{
|
||||
Issuer: issuer,
|
||||
Audience: audience,
|
||||
KeysLocation: jwksURL,
|
||||
},
|
||||
}
|
||||
sshServer := server.New(serverConfig)
|
||||
sshServer.SetAllowRootLogin(true)
|
||||
|
||||
sshServerAddr := server.StartTestServer(t, sshServer)
|
||||
defer func() { _ = sshServer.Stop() }()
|
||||
|
||||
mockDaemon := startMockDaemon(t)
|
||||
defer mockDaemon.stop()
|
||||
|
||||
host, portStr, err := net.SplitHostPort(sshServerAddr)
|
||||
require.NoError(t, err)
|
||||
port, err := strconv.Atoi(portStr)
|
||||
require.NoError(t, err)
|
||||
|
||||
mockDaemon.setHostKey(host, hostPubKey)
|
||||
|
||||
validToken := generateValidJWT(t, privateKey, issuer, audience)
|
||||
mockDaemon.setJWTToken(validToken)
|
||||
|
||||
proxyInstance, err := New(mockDaemon.addr, host, port, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
clientConn, proxyConn := net.Pipe()
|
||||
defer func() { _ = clientConn.Close() }()
|
||||
|
||||
origStdin := os.Stdin
|
||||
origStdout := os.Stdout
|
||||
defer func() {
|
||||
os.Stdin = origStdin
|
||||
os.Stdout = origStdout
|
||||
}()
|
||||
|
||||
stdinReader, stdinWriter, err := os.Pipe()
|
||||
require.NoError(t, err)
|
||||
stdoutReader, stdoutWriter, err := os.Pipe()
|
||||
require.NoError(t, err)
|
||||
|
||||
os.Stdin = stdinReader
|
||||
os.Stdout = stdoutWriter
|
||||
|
||||
go func() {
|
||||
_, _ = io.Copy(stdinWriter, proxyConn)
|
||||
}()
|
||||
go func() {
|
||||
_, _ = io.Copy(proxyConn, stdoutReader)
|
||||
}()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
connectErrCh := make(chan error, 1)
|
||||
go func() {
|
||||
connectErrCh <- proxyInstance.Connect(ctx)
|
||||
}()
|
||||
|
||||
sshConfig := &cryptossh.ClientConfig{
|
||||
User: testutil.GetTestUsername(t),
|
||||
Auth: []cryptossh.AuthMethod{},
|
||||
HostKeyCallback: cryptossh.InsecureIgnoreHostKey(),
|
||||
Timeout: 3 * time.Second,
|
||||
}
|
||||
|
||||
sshClientConn, chans, reqs, err := cryptossh.NewClientConn(clientConn, "test", sshConfig)
|
||||
require.NoError(t, err, "Should connect to proxy server")
|
||||
defer func() { _ = sshClientConn.Close() }()
|
||||
|
||||
sshClient := cryptossh.NewClient(sshClientConn, chans, reqs)
|
||||
|
||||
session, err := sshClient.NewSession()
|
||||
require.NoError(t, err, "Should create session through full proxy to backend")
|
||||
|
||||
outputCh := make(chan []byte, 1)
|
||||
errCh := make(chan error, 1)
|
||||
go func() {
|
||||
output, err := session.Output("echo hello-from-proxy")
|
||||
outputCh <- output
|
||||
errCh <- err
|
||||
}()
|
||||
|
||||
select {
|
||||
case output := <-outputCh:
|
||||
err := <-errCh
|
||||
require.NoError(t, err, "Command should execute successfully through proxy")
|
||||
assert.Contains(t, string(output), "hello-from-proxy", "Should receive command output through proxy")
|
||||
case <-time.After(3 * time.Second):
|
||||
t.Fatal("Command execution timed out")
|
||||
}
|
||||
|
||||
_ = session.Close()
|
||||
_ = sshClient.Close()
|
||||
_ = clientConn.Close()
|
||||
cancel()
|
||||
}
|
||||
|
||||
type mockDaemonServer struct {
|
||||
proto.UnimplementedDaemonServiceServer
|
||||
hostKeys map[string][]byte
|
||||
jwtToken string
|
||||
}
|
||||
|
||||
func (m *mockDaemonServer) GetPeerSSHHostKey(ctx context.Context, req *proto.GetPeerSSHHostKeyRequest) (*proto.GetPeerSSHHostKeyResponse, error) {
|
||||
key, found := m.hostKeys[req.PeerAddress]
|
||||
return &proto.GetPeerSSHHostKeyResponse{
|
||||
Found: found,
|
||||
SshHostKey: key,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (m *mockDaemonServer) RequestJWTAuth(ctx context.Context, req *proto.RequestJWTAuthRequest) (*proto.RequestJWTAuthResponse, error) {
|
||||
return &proto.RequestJWTAuthResponse{
|
||||
CachedToken: m.jwtToken,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (m *mockDaemonServer) WaitJWTToken(ctx context.Context, req *proto.WaitJWTTokenRequest) (*proto.WaitJWTTokenResponse, error) {
|
||||
return &proto.WaitJWTTokenResponse{
|
||||
Token: m.jwtToken,
|
||||
}, nil
|
||||
}
|
||||
|
||||
type mockDaemon struct {
|
||||
addr string
|
||||
server *grpc.Server
|
||||
impl *mockDaemonServer
|
||||
}
|
||||
|
||||
func startMockDaemon(t *testing.T) *mockDaemon {
|
||||
t.Helper()
|
||||
|
||||
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
|
||||
impl := &mockDaemonServer{
|
||||
hostKeys: make(map[string][]byte),
|
||||
jwtToken: "test-jwt-token",
|
||||
}
|
||||
|
||||
grpcServer := grpc.NewServer()
|
||||
proto.RegisterDaemonServiceServer(grpcServer, impl)
|
||||
|
||||
go func() {
|
||||
_ = grpcServer.Serve(listener)
|
||||
}()
|
||||
|
||||
return &mockDaemon{
|
||||
addr: listener.Addr().String(),
|
||||
server: grpcServer,
|
||||
impl: impl,
|
||||
}
|
||||
}
|
||||
|
||||
func (m *mockDaemon) setHostKey(addr string, pubKey []byte) {
|
||||
m.impl.hostKeys[addr] = pubKey
|
||||
}
|
||||
|
||||
func (m *mockDaemon) setJWTToken(token string) {
|
||||
m.impl.jwtToken = token
|
||||
}
|
||||
|
||||
func (m *mockDaemon) stop() {
|
||||
if m.server != nil {
|
||||
m.server.Stop()
|
||||
}
|
||||
}
|
||||
|
||||
func mustParsePublicKey(t *testing.T, pubKeyBytes []byte) cryptossh.PublicKey {
|
||||
t.Helper()
|
||||
pubKey, _, _, _, err := cryptossh.ParseAuthorizedKey(pubKeyBytes)
|
||||
require.NoError(t, err)
|
||||
return pubKey
|
||||
}
|
||||
|
||||
func setupJWKSServer(t *testing.T) (*httptest.Server, *rsa.PrivateKey, string) {
|
||||
t.Helper()
|
||||
privateKey, jwksJSON := generateTestJWKS(t)
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
if _, err := w.Write(jwksJSON); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
}
|
||||
}))
|
||||
|
||||
return server, privateKey, server.URL
|
||||
}
|
||||
|
||||
func generateTestJWKS(t *testing.T) (*rsa.PrivateKey, []byte) {
|
||||
t.Helper()
|
||||
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
require.NoError(t, err)
|
||||
|
||||
publicKey := &privateKey.PublicKey
|
||||
n := publicKey.N.Bytes()
|
||||
e := publicKey.E
|
||||
|
||||
jwk := nbjwt.JSONWebKey{
|
||||
Kty: "RSA",
|
||||
Kid: "test-key-id",
|
||||
Use: "sig",
|
||||
N: base64.RawURLEncoding.EncodeToString(n),
|
||||
E: base64.RawURLEncoding.EncodeToString(big.NewInt(int64(e)).Bytes()),
|
||||
}
|
||||
|
||||
jwks := nbjwt.Jwks{
|
||||
Keys: []nbjwt.JSONWebKey{jwk},
|
||||
}
|
||||
|
||||
jwksJSON, err := json.Marshal(jwks)
|
||||
require.NoError(t, err)
|
||||
|
||||
return privateKey, jwksJSON
|
||||
}
|
||||
|
||||
func generateValidJWT(t *testing.T, privateKey *rsa.PrivateKey, issuer, audience string) string {
|
||||
t.Helper()
|
||||
claims := jwt.MapClaims{
|
||||
"iss": issuer,
|
||||
"aud": audience,
|
||||
"sub": "test-user",
|
||||
"exp": time.Now().Add(time.Hour).Unix(),
|
||||
"iat": time.Now().Unix(),
|
||||
}
|
||||
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims)
|
||||
token.Header["kid"] = "test-key-id"
|
||||
|
||||
tokenString, err := token.SignedString(privateKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
return tokenString
|
||||
}
|
||||
@@ -1,280 +0,0 @@
|
||||
//go:build !js
|
||||
|
||||
package ssh
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"os"
|
||||
"os/exec"
|
||||
"os/user"
|
||||
"runtime"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/creack/pty"
|
||||
"github.com/gliderlabs/ssh"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// DefaultSSHPort is the default SSH port of the NetBird's embedded SSH server
|
||||
const DefaultSSHPort = 44338
|
||||
|
||||
// TerminalTimeout is the timeout for terminal session to be ready
|
||||
const TerminalTimeout = 10 * time.Second
|
||||
|
||||
// TerminalBackoffDelay is the delay between terminal session readiness checks
|
||||
const TerminalBackoffDelay = 500 * time.Millisecond
|
||||
|
||||
// DefaultSSHServer is a function that creates DefaultServer
|
||||
func DefaultSSHServer(hostKeyPEM []byte, addr string) (Server, error) {
|
||||
return newDefaultServer(hostKeyPEM, addr)
|
||||
}
|
||||
|
||||
// Server is an interface of SSH server
|
||||
type Server interface {
|
||||
// Stop stops SSH server.
|
||||
Stop() error
|
||||
// Start starts SSH server. Blocking
|
||||
Start() error
|
||||
// RemoveAuthorizedKey removes SSH key of a given peer from the authorized keys
|
||||
RemoveAuthorizedKey(peer string)
|
||||
// AddAuthorizedKey add a given peer key to server authorized keys
|
||||
AddAuthorizedKey(peer, newKey string) error
|
||||
}
|
||||
|
||||
// DefaultServer is the embedded NetBird SSH server
|
||||
type DefaultServer struct {
|
||||
listener net.Listener
|
||||
// authorizedKeys is ssh pub key indexed by peer WireGuard public key
|
||||
authorizedKeys map[string]ssh.PublicKey
|
||||
mu sync.Mutex
|
||||
hostKeyPEM []byte
|
||||
sessions []ssh.Session
|
||||
}
|
||||
|
||||
// newDefaultServer creates new server with provided host key
|
||||
func newDefaultServer(hostKeyPEM []byte, addr string) (*DefaultServer, error) {
|
||||
ln, err := net.Listen("tcp", addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
allowedKeys := make(map[string]ssh.PublicKey)
|
||||
return &DefaultServer{listener: ln, mu: sync.Mutex{}, hostKeyPEM: hostKeyPEM, authorizedKeys: allowedKeys, sessions: make([]ssh.Session, 0)}, nil
|
||||
}
|
||||
|
||||
// RemoveAuthorizedKey removes SSH key of a given peer from the authorized keys
|
||||
func (srv *DefaultServer) RemoveAuthorizedKey(peer string) {
|
||||
srv.mu.Lock()
|
||||
defer srv.mu.Unlock()
|
||||
|
||||
delete(srv.authorizedKeys, peer)
|
||||
}
|
||||
|
||||
// AddAuthorizedKey add a given peer key to server authorized keys
|
||||
func (srv *DefaultServer) AddAuthorizedKey(peer, newKey string) error {
|
||||
srv.mu.Lock()
|
||||
defer srv.mu.Unlock()
|
||||
|
||||
parsedKey, _, _, _, err := ssh.ParseAuthorizedKey([]byte(newKey))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
srv.authorizedKeys[peer] = parsedKey
|
||||
return nil
|
||||
}
|
||||
|
||||
// Stop stops SSH server.
|
||||
func (srv *DefaultServer) Stop() error {
|
||||
srv.mu.Lock()
|
||||
defer srv.mu.Unlock()
|
||||
err := srv.listener.Close()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for _, session := range srv.sessions {
|
||||
err := session.Close()
|
||||
if err != nil {
|
||||
log.Warnf("failed closing SSH session from %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (srv *DefaultServer) publicKeyHandler(ctx ssh.Context, key ssh.PublicKey) bool {
|
||||
srv.mu.Lock()
|
||||
defer srv.mu.Unlock()
|
||||
|
||||
for _, allowed := range srv.authorizedKeys {
|
||||
if ssh.KeysEqual(allowed, key) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func prepareUserEnv(user *user.User, shell string) []string {
|
||||
return []string{
|
||||
fmt.Sprint("SHELL=" + shell),
|
||||
fmt.Sprint("USER=" + user.Username),
|
||||
fmt.Sprint("HOME=" + user.HomeDir),
|
||||
}
|
||||
}
|
||||
|
||||
func acceptEnv(s string) bool {
|
||||
split := strings.Split(s, "=")
|
||||
if len(split) != 2 {
|
||||
return false
|
||||
}
|
||||
return split[0] == "TERM" || split[0] == "LANG" || strings.HasPrefix(split[0], "LC_")
|
||||
}
|
||||
|
||||
// sessionHandler handles SSH session post auth
|
||||
func (srv *DefaultServer) sessionHandler(session ssh.Session) {
|
||||
srv.mu.Lock()
|
||||
srv.sessions = append(srv.sessions, session)
|
||||
srv.mu.Unlock()
|
||||
|
||||
defer func() {
|
||||
err := session.Close()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}()
|
||||
|
||||
log.Infof("Establishing SSH session for %s from host %s", session.User(), session.RemoteAddr().String())
|
||||
|
||||
localUser, err := userNameLookup(session.User())
|
||||
if err != nil {
|
||||
_, err = fmt.Fprintf(session, "remote SSH server couldn't find local user %s\n", session.User()) //nolint
|
||||
err = session.Exit(1)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
log.Warnf("failed SSH session from %v, user %s", session.RemoteAddr(), session.User())
|
||||
return
|
||||
}
|
||||
|
||||
ptyReq, winCh, isPty := session.Pty()
|
||||
if isPty {
|
||||
loginCmd, loginArgs, err := getLoginCmd(localUser.Username, session.RemoteAddr())
|
||||
if err != nil {
|
||||
log.Warnf("failed logging-in user %s from remote IP %s", localUser.Username, session.RemoteAddr().String())
|
||||
return
|
||||
}
|
||||
cmd := exec.Command(loginCmd, loginArgs...)
|
||||
go func() {
|
||||
<-session.Context().Done()
|
||||
if cmd.Process == nil {
|
||||
return
|
||||
}
|
||||
err := cmd.Process.Kill()
|
||||
if err != nil {
|
||||
log.Debugf("failed killing SSH process %v", err)
|
||||
return
|
||||
}
|
||||
}()
|
||||
cmd.Dir = localUser.HomeDir
|
||||
cmd.Env = append(cmd.Env, fmt.Sprintf("TERM=%s", ptyReq.Term))
|
||||
cmd.Env = append(cmd.Env, prepareUserEnv(localUser, getUserShell(localUser.Uid))...)
|
||||
for _, v := range session.Environ() {
|
||||
if acceptEnv(v) {
|
||||
cmd.Env = append(cmd.Env, v)
|
||||
}
|
||||
}
|
||||
|
||||
log.Debugf("Login command: %s", cmd.String())
|
||||
file, err := pty.Start(cmd)
|
||||
if err != nil {
|
||||
log.Errorf("failed starting SSH server: %v", err)
|
||||
}
|
||||
|
||||
go func() {
|
||||
for win := range winCh {
|
||||
setWinSize(file, win.Width, win.Height)
|
||||
}
|
||||
}()
|
||||
|
||||
srv.stdInOut(file, session)
|
||||
|
||||
err = cmd.Wait()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
} else {
|
||||
_, err := io.WriteString(session, "only PTY is supported.\n")
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
err = session.Exit(1)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
log.Debugf("SSH session ended")
|
||||
}
|
||||
|
||||
func (srv *DefaultServer) stdInOut(file *os.File, session ssh.Session) {
|
||||
go func() {
|
||||
// stdin
|
||||
_, err := io.Copy(file, session)
|
||||
if err != nil {
|
||||
_ = session.Exit(1)
|
||||
return
|
||||
}
|
||||
}()
|
||||
|
||||
// AWS Linux 2 machines need some time to open the terminal so we need to wait for it
|
||||
timer := time.NewTimer(TerminalTimeout)
|
||||
for {
|
||||
select {
|
||||
case <-timer.C:
|
||||
_, _ = session.Write([]byte("Reached timeout while opening connection\n"))
|
||||
_ = session.Exit(1)
|
||||
return
|
||||
default:
|
||||
// stdout
|
||||
writtenBytes, err := io.Copy(session, file)
|
||||
if err != nil && writtenBytes != 0 {
|
||||
_ = session.Exit(0)
|
||||
return
|
||||
}
|
||||
time.Sleep(TerminalBackoffDelay)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Start starts SSH server. Blocking
|
||||
func (srv *DefaultServer) Start() error {
|
||||
log.Infof("starting SSH server on addr: %s", srv.listener.Addr().String())
|
||||
|
||||
publicKeyOption := ssh.PublicKeyAuth(srv.publicKeyHandler)
|
||||
hostKeyPEM := ssh.HostKeyPEM(srv.hostKeyPEM)
|
||||
err := ssh.Serve(srv.listener, srv.sessionHandler, publicKeyOption, hostKeyPEM)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func getUserShell(userID string) string {
|
||||
if runtime.GOOS == "linux" {
|
||||
output, _ := exec.Command("getent", "passwd", userID).Output()
|
||||
line := strings.SplitN(string(output), ":", 10)
|
||||
if len(line) > 6 {
|
||||
return strings.TrimSpace(line[6])
|
||||
}
|
||||
}
|
||||
|
||||
shell := os.Getenv("SHELL")
|
||||
if shell == "" {
|
||||
shell = "/bin/sh"
|
||||
}
|
||||
return shell
|
||||
}
|
||||
206
client/ssh/server/command_execution.go
Normal file
206
client/ssh/server/command_execution.go
Normal file
@@ -0,0 +1,206 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"os/exec"
|
||||
"time"
|
||||
|
||||
"github.com/gliderlabs/ssh"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// handleCommand executes an SSH command with privilege validation
|
||||
func (s *Server) handleCommand(logger *log.Entry, session ssh.Session, privilegeResult PrivilegeCheckResult, winCh <-chan ssh.Window) {
|
||||
hasPty := winCh != nil
|
||||
|
||||
commandType := "command"
|
||||
if hasPty {
|
||||
commandType = "Pty command"
|
||||
}
|
||||
|
||||
logger.Infof("executing %s: %s", commandType, safeLogCommand(session.Command()))
|
||||
|
||||
execCmd, cleanup, err := s.createCommand(privilegeResult, session, hasPty)
|
||||
if err != nil {
|
||||
logger.Errorf("%s creation failed: %v", commandType, err)
|
||||
|
||||
errorMsg := fmt.Sprintf("Cannot create %s - platform may not support user switching", commandType)
|
||||
if hasPty {
|
||||
errorMsg += " with Pty"
|
||||
}
|
||||
errorMsg += "\n"
|
||||
|
||||
if _, writeErr := fmt.Fprint(session.Stderr(), errorMsg); writeErr != nil {
|
||||
logger.Debugf(errWriteSession, writeErr)
|
||||
}
|
||||
if err := session.Exit(1); err != nil {
|
||||
logSessionExitError(logger, err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if !hasPty {
|
||||
if s.executeCommand(logger, session, execCmd, cleanup) {
|
||||
logger.Debugf("%s execution completed", commandType)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
defer cleanup()
|
||||
|
||||
ptyReq, _, _ := session.Pty()
|
||||
if s.executeCommandWithPty(logger, session, execCmd, privilegeResult, ptyReq, winCh) {
|
||||
logger.Debugf("%s execution completed", commandType)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) createCommand(privilegeResult PrivilegeCheckResult, session ssh.Session, hasPty bool) (*exec.Cmd, func(), error) {
|
||||
localUser := privilegeResult.User
|
||||
if localUser == nil {
|
||||
return nil, nil, errors.New("no user in privilege result")
|
||||
}
|
||||
|
||||
// If PTY requested but su doesn't support --pty, skip su and use executor
|
||||
// This ensures PTY functionality is provided (executor runs within our allocated PTY)
|
||||
if hasPty && !s.suSupportsPty {
|
||||
log.Debugf("PTY requested but su doesn't support --pty, using executor for PTY functionality")
|
||||
cmd, cleanup, err := s.createExecutorCommand(session, localUser, hasPty)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("create command with privileges: %w", err)
|
||||
}
|
||||
cmd.Env = s.prepareCommandEnv(localUser, session)
|
||||
return cmd, cleanup, nil
|
||||
}
|
||||
|
||||
// Try su first for system integration (PAM/audit) when privileged
|
||||
cmd, err := s.createSuCommand(session, localUser, hasPty)
|
||||
if err != nil || privilegeResult.UsedFallback {
|
||||
log.Debugf("su command failed, falling back to executor: %v", err)
|
||||
cmd, cleanup, err := s.createExecutorCommand(session, localUser, hasPty)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("create command with privileges: %w", err)
|
||||
}
|
||||
cmd.Env = s.prepareCommandEnv(localUser, session)
|
||||
return cmd, cleanup, nil
|
||||
}
|
||||
|
||||
cmd.Env = s.prepareCommandEnv(localUser, session)
|
||||
return cmd, func() {}, nil
|
||||
}
|
||||
|
||||
// executeCommand executes the command and handles I/O and exit codes
|
||||
func (s *Server) executeCommand(logger *log.Entry, session ssh.Session, execCmd *exec.Cmd, cleanup func()) bool {
|
||||
defer cleanup()
|
||||
|
||||
s.setupProcessGroup(execCmd)
|
||||
|
||||
stdinPipe, err := execCmd.StdinPipe()
|
||||
if err != nil {
|
||||
logger.Errorf("create stdin pipe: %v", err)
|
||||
if err := session.Exit(1); err != nil {
|
||||
logSessionExitError(logger, err)
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
execCmd.Stdout = session
|
||||
execCmd.Stderr = session.Stderr()
|
||||
|
||||
if execCmd.Dir != "" {
|
||||
if _, err := os.Stat(execCmd.Dir); err != nil {
|
||||
logger.Warnf("working directory does not exist: %s (%v)", execCmd.Dir, err)
|
||||
execCmd.Dir = "/"
|
||||
}
|
||||
}
|
||||
|
||||
if err := execCmd.Start(); err != nil {
|
||||
logger.Errorf("command start failed: %v", err)
|
||||
// no user message for exec failure, just exit
|
||||
if err := session.Exit(1); err != nil {
|
||||
logSessionExitError(logger, err)
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
go s.handleCommandIO(logger, stdinPipe, session)
|
||||
return s.waitForCommandCleanup(logger, session, execCmd)
|
||||
}
|
||||
|
||||
// handleCommandIO manages stdin/stdout copying in a goroutine
|
||||
func (s *Server) handleCommandIO(logger *log.Entry, stdinPipe io.WriteCloser, session ssh.Session) {
|
||||
defer func() {
|
||||
if err := stdinPipe.Close(); err != nil {
|
||||
logger.Debugf("stdin pipe close error: %v", err)
|
||||
}
|
||||
}()
|
||||
if _, err := io.Copy(stdinPipe, session); err != nil {
|
||||
logger.Debugf("stdin copy error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// waitForCommandCleanup waits for command completion with session disconnect handling
|
||||
func (s *Server) waitForCommandCleanup(logger *log.Entry, session ssh.Session, execCmd *exec.Cmd) bool {
|
||||
ctx := session.Context()
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
done <- execCmd.Wait()
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
logger.Debugf("session cancelled, terminating command")
|
||||
s.killProcessGroup(execCmd)
|
||||
|
||||
select {
|
||||
case err := <-done:
|
||||
logger.Tracef("command terminated after session cancellation: %v", err)
|
||||
case <-time.After(5 * time.Second):
|
||||
logger.Warnf("command did not terminate within 5 seconds after session cancellation")
|
||||
}
|
||||
|
||||
if err := session.Exit(130); err != nil {
|
||||
logSessionExitError(logger, err)
|
||||
}
|
||||
return false
|
||||
|
||||
case err := <-done:
|
||||
return s.handleCommandCompletion(logger, session, err)
|
||||
}
|
||||
}
|
||||
|
||||
// handleCommandCompletion handles command completion
|
||||
func (s *Server) handleCommandCompletion(logger *log.Entry, session ssh.Session, err error) bool {
|
||||
if err != nil {
|
||||
logger.Debugf("command execution failed: %v", err)
|
||||
s.handleSessionExit(session, err, logger)
|
||||
return false
|
||||
}
|
||||
|
||||
s.handleSessionExit(session, nil, logger)
|
||||
return true
|
||||
}
|
||||
|
||||
// handleSessionExit handles command errors and sets appropriate exit codes
|
||||
func (s *Server) handleSessionExit(session ssh.Session, err error, logger *log.Entry) {
|
||||
if err == nil {
|
||||
if err := session.Exit(0); err != nil {
|
||||
logSessionExitError(logger, err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
var exitError *exec.ExitError
|
||||
if errors.As(err, &exitError) {
|
||||
if err := session.Exit(exitError.ExitCode()); err != nil {
|
||||
logSessionExitError(logger, err)
|
||||
}
|
||||
} else {
|
||||
logger.Debugf("non-exit error in command execution: %v", err)
|
||||
if err := session.Exit(1); err != nil {
|
||||
logSessionExitError(logger, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
52
client/ssh/server/command_execution_js.go
Normal file
52
client/ssh/server/command_execution_js.go
Normal file
@@ -0,0 +1,52 @@
|
||||
//go:build js
|
||||
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"os/exec"
|
||||
"os/user"
|
||||
|
||||
"github.com/gliderlabs/ssh"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
var errNotSupported = errors.New("SSH server command execution not supported on WASM/JS platform")
|
||||
|
||||
// createSuCommand is not supported on JS/WASM
|
||||
func (s *Server) createSuCommand(_ ssh.Session, _ *user.User, _ bool) (*exec.Cmd, error) {
|
||||
return nil, errNotSupported
|
||||
}
|
||||
|
||||
// createExecutorCommand is not supported on JS/WASM
|
||||
func (s *Server) createExecutorCommand(_ ssh.Session, _ *user.User, _ bool) (*exec.Cmd, func(), error) {
|
||||
return nil, nil, errNotSupported
|
||||
}
|
||||
|
||||
// prepareCommandEnv is not supported on JS/WASM
|
||||
func (s *Server) prepareCommandEnv(_ *user.User, _ ssh.Session) []string {
|
||||
return nil
|
||||
}
|
||||
|
||||
// setupProcessGroup is not supported on JS/WASM
|
||||
func (s *Server) setupProcessGroup(_ *exec.Cmd) {
|
||||
}
|
||||
|
||||
// killProcessGroup is not supported on JS/WASM
|
||||
func (s *Server) killProcessGroup(*exec.Cmd) {
|
||||
}
|
||||
|
||||
// detectSuPtySupport always returns false on JS/WASM
|
||||
func (s *Server) detectSuPtySupport(context.Context) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// executeCommandWithPty is not supported on JS/WASM
|
||||
func (s *Server) executeCommandWithPty(logger *log.Entry, session ssh.Session, execCmd *exec.Cmd, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, winCh <-chan ssh.Window) bool {
|
||||
logger.Errorf("PTY command execution not supported on JS/WASM")
|
||||
if err := session.Exit(1); err != nil {
|
||||
logSessionExitError(logger, err)
|
||||
}
|
||||
return false
|
||||
}
|
||||
329
client/ssh/server/command_execution_unix.go
Normal file
329
client/ssh/server/command_execution_unix.go
Normal file
@@ -0,0 +1,329 @@
|
||||
//go:build unix
|
||||
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"os/exec"
|
||||
"os/user"
|
||||
"strings"
|
||||
"sync"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/creack/pty"
|
||||
"github.com/gliderlabs/ssh"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// ptyManager manages Pty file operations with thread safety
|
||||
type ptyManager struct {
|
||||
file *os.File
|
||||
mu sync.RWMutex
|
||||
closed bool
|
||||
closeErr error
|
||||
once sync.Once
|
||||
}
|
||||
|
||||
func newPtyManager(file *os.File) *ptyManager {
|
||||
return &ptyManager{file: file}
|
||||
}
|
||||
|
||||
func (pm *ptyManager) Close() error {
|
||||
pm.once.Do(func() {
|
||||
pm.mu.Lock()
|
||||
pm.closed = true
|
||||
pm.closeErr = pm.file.Close()
|
||||
pm.mu.Unlock()
|
||||
})
|
||||
pm.mu.RLock()
|
||||
defer pm.mu.RUnlock()
|
||||
return pm.closeErr
|
||||
}
|
||||
|
||||
func (pm *ptyManager) Setsize(ws *pty.Winsize) error {
|
||||
pm.mu.RLock()
|
||||
defer pm.mu.RUnlock()
|
||||
if pm.closed {
|
||||
return errors.New("pty is closed")
|
||||
}
|
||||
return pty.Setsize(pm.file, ws)
|
||||
}
|
||||
|
||||
func (pm *ptyManager) File() *os.File {
|
||||
return pm.file
|
||||
}
|
||||
|
||||
// detectSuPtySupport checks if su supports the --pty flag
|
||||
func (s *Server) detectSuPtySupport(ctx context.Context) bool {
|
||||
ctx, cancel := context.WithTimeout(ctx, 500*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
cmd := exec.CommandContext(ctx, "su", "--help")
|
||||
output, err := cmd.CombinedOutput()
|
||||
if err != nil {
|
||||
log.Debugf("su --help failed (may not support --help): %v", err)
|
||||
return false
|
||||
}
|
||||
|
||||
supported := strings.Contains(string(output), "--pty")
|
||||
log.Debugf("su --pty support detected: %v", supported)
|
||||
return supported
|
||||
}
|
||||
|
||||
// createSuCommand creates a command using su -l -c for privilege switching
|
||||
func (s *Server) createSuCommand(session ssh.Session, localUser *user.User, hasPty bool) (*exec.Cmd, error) {
|
||||
suPath, err := exec.LookPath("su")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("su command not available: %w", err)
|
||||
}
|
||||
|
||||
command := session.RawCommand()
|
||||
if command == "" {
|
||||
return nil, fmt.Errorf("no command specified for su execution")
|
||||
}
|
||||
|
||||
args := []string{"-l"}
|
||||
if hasPty && s.suSupportsPty {
|
||||
args = append(args, "--pty")
|
||||
}
|
||||
args = append(args, localUser.Username, "-c", command)
|
||||
|
||||
cmd := exec.CommandContext(session.Context(), suPath, args...)
|
||||
cmd.Dir = localUser.HomeDir
|
||||
|
||||
return cmd, nil
|
||||
}
|
||||
|
||||
// getShellCommandArgs returns the shell command and arguments for executing a command string
|
||||
func (s *Server) getShellCommandArgs(shell, cmdString string) []string {
|
||||
if cmdString == "" {
|
||||
return []string{shell, "-l"}
|
||||
}
|
||||
return []string{shell, "-l", "-c", cmdString}
|
||||
}
|
||||
|
||||
// prepareCommandEnv prepares environment variables for command execution on Unix
|
||||
func (s *Server) prepareCommandEnv(localUser *user.User, session ssh.Session) []string {
|
||||
env := prepareUserEnv(localUser, getUserShell(localUser.Uid))
|
||||
env = append(env, prepareSSHEnv(session)...)
|
||||
for _, v := range session.Environ() {
|
||||
if acceptEnv(v) {
|
||||
env = append(env, v)
|
||||
}
|
||||
}
|
||||
return env
|
||||
}
|
||||
|
||||
// executeCommandWithPty executes a command with PTY allocation
|
||||
func (s *Server) executeCommandWithPty(logger *log.Entry, session ssh.Session, execCmd *exec.Cmd, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, winCh <-chan ssh.Window) bool {
|
||||
termType := ptyReq.Term
|
||||
if termType == "" {
|
||||
termType = "xterm-256color"
|
||||
}
|
||||
execCmd.Env = append(execCmd.Env, fmt.Sprintf("TERM=%s", termType))
|
||||
|
||||
return s.runPtyCommand(logger, session, execCmd, ptyReq, winCh)
|
||||
}
|
||||
|
||||
func (s *Server) handlePty(logger *log.Entry, session ssh.Session, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, winCh <-chan ssh.Window) bool {
|
||||
execCmd, err := s.createPtyCommand(privilegeResult, ptyReq, session)
|
||||
if err != nil {
|
||||
logger.Errorf("Pty command creation failed: %v", err)
|
||||
errorMsg := "User switching failed - login command not available\r\n"
|
||||
if _, writeErr := fmt.Fprint(session.Stderr(), errorMsg); writeErr != nil {
|
||||
logger.Debugf(errWriteSession, writeErr)
|
||||
}
|
||||
if err := session.Exit(1); err != nil {
|
||||
logSessionExitError(logger, err)
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
logger.Infof("starting interactive shell: %s", execCmd.Path)
|
||||
return s.runPtyCommand(logger, session, execCmd, ptyReq, winCh)
|
||||
}
|
||||
|
||||
// runPtyCommand runs a command with PTY management (common code for interactive and command execution)
|
||||
func (s *Server) runPtyCommand(logger *log.Entry, session ssh.Session, execCmd *exec.Cmd, ptyReq ssh.Pty, winCh <-chan ssh.Window) bool {
|
||||
ptmx, err := s.startPtyCommandWithSize(execCmd, ptyReq)
|
||||
if err != nil {
|
||||
logger.Errorf("Pty start failed: %v", err)
|
||||
if err := session.Exit(1); err != nil {
|
||||
logSessionExitError(logger, err)
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
ptyMgr := newPtyManager(ptmx)
|
||||
defer func() {
|
||||
if err := ptyMgr.Close(); err != nil {
|
||||
logger.Debugf("Pty close error: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
go s.handlePtyWindowResize(logger, session, ptyMgr, winCh)
|
||||
s.handlePtyIO(logger, session, ptyMgr)
|
||||
s.waitForPtyCompletion(logger, session, execCmd, ptyMgr)
|
||||
return true
|
||||
}
|
||||
|
||||
func (s *Server) startPtyCommandWithSize(execCmd *exec.Cmd, ptyReq ssh.Pty) (*os.File, error) {
|
||||
winSize := &pty.Winsize{
|
||||
Cols: uint16(ptyReq.Window.Width),
|
||||
Rows: uint16(ptyReq.Window.Height),
|
||||
}
|
||||
if winSize.Cols == 0 {
|
||||
winSize.Cols = 80
|
||||
}
|
||||
if winSize.Rows == 0 {
|
||||
winSize.Rows = 24
|
||||
}
|
||||
|
||||
ptmx, err := pty.StartWithSize(execCmd, winSize)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("start Pty: %w", err)
|
||||
}
|
||||
|
||||
return ptmx, nil
|
||||
}
|
||||
|
||||
func (s *Server) handlePtyWindowResize(logger *log.Entry, session ssh.Session, ptyMgr *ptyManager, winCh <-chan ssh.Window) {
|
||||
for {
|
||||
select {
|
||||
case <-session.Context().Done():
|
||||
return
|
||||
case win, ok := <-winCh:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
if err := ptyMgr.Setsize(&pty.Winsize{Rows: uint16(win.Height), Cols: uint16(win.Width)}); err != nil {
|
||||
logger.Debugf("Pty resize to %dx%d: %v", win.Width, win.Height, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) handlePtyIO(logger *log.Entry, session ssh.Session, ptyMgr *ptyManager) {
|
||||
ptmx := ptyMgr.File()
|
||||
|
||||
go func() {
|
||||
if _, err := io.Copy(ptmx, session); err != nil {
|
||||
if !errors.Is(err, io.EOF) && !errors.Is(err, syscall.EIO) {
|
||||
logger.Warnf("Pty input copy error: %v", err)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
go func() {
|
||||
defer func() {
|
||||
if err := session.Close(); err != nil && !errors.Is(err, io.EOF) {
|
||||
logger.Debugf("session close error: %v", err)
|
||||
}
|
||||
}()
|
||||
if _, err := io.Copy(session, ptmx); err != nil {
|
||||
if !errors.Is(err, io.EOF) && !errors.Is(err, syscall.EIO) {
|
||||
logger.Warnf("Pty output copy error: %v", err)
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func (s *Server) waitForPtyCompletion(logger *log.Entry, session ssh.Session, execCmd *exec.Cmd, ptyMgr *ptyManager) {
|
||||
ctx := session.Context()
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
done <- execCmd.Wait()
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
s.handlePtySessionCancellation(logger, session, execCmd, ptyMgr, done)
|
||||
case err := <-done:
|
||||
s.handlePtyCommandCompletion(logger, session, err)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) handlePtySessionCancellation(logger *log.Entry, session ssh.Session, execCmd *exec.Cmd, ptyMgr *ptyManager, done <-chan error) {
|
||||
logger.Debugf("Pty session cancelled, terminating command")
|
||||
if err := ptyMgr.Close(); err != nil {
|
||||
logger.Debugf("Pty close during session cancellation: %v", err)
|
||||
}
|
||||
|
||||
s.killProcessGroup(execCmd)
|
||||
|
||||
select {
|
||||
case err := <-done:
|
||||
if err != nil {
|
||||
logger.Debugf("Pty command terminated after session cancellation with error: %v", err)
|
||||
} else {
|
||||
logger.Debugf("Pty command terminated after session cancellation")
|
||||
}
|
||||
case <-time.After(5 * time.Second):
|
||||
logger.Warnf("Pty command did not terminate within 5 seconds after session cancellation")
|
||||
}
|
||||
|
||||
if err := session.Exit(130); err != nil {
|
||||
logSessionExitError(logger, err)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) handlePtyCommandCompletion(logger *log.Entry, session ssh.Session, err error) {
|
||||
if err != nil {
|
||||
logger.Debugf("Pty command execution failed: %v", err)
|
||||
s.handleSessionExit(session, err, logger)
|
||||
return
|
||||
}
|
||||
|
||||
// Normal completion
|
||||
logger.Debugf("Pty command completed successfully")
|
||||
if err := session.Exit(0); err != nil {
|
||||
logSessionExitError(logger, err)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) setupProcessGroup(cmd *exec.Cmd) {
|
||||
cmd.SysProcAttr = &syscall.SysProcAttr{
|
||||
Setpgid: true,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) killProcessGroup(cmd *exec.Cmd) {
|
||||
if cmd.Process == nil {
|
||||
return
|
||||
}
|
||||
|
||||
logger := log.WithField("pid", cmd.Process.Pid)
|
||||
pgid := cmd.Process.Pid
|
||||
|
||||
if err := syscall.Kill(-pgid, syscall.SIGTERM); err != nil {
|
||||
logger.Debugf("kill process group SIGTERM: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
const gracePeriod = 500 * time.Millisecond
|
||||
const checkInterval = 50 * time.Millisecond
|
||||
|
||||
ticker := time.NewTicker(checkInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
timeout := time.After(gracePeriod)
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-timeout:
|
||||
if err := syscall.Kill(-pgid, syscall.SIGKILL); err != nil {
|
||||
logger.Debugf("kill process group SIGKILL: %v", err)
|
||||
}
|
||||
return
|
||||
case <-ticker.C:
|
||||
if err := syscall.Kill(-pgid, 0); err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
430
client/ssh/server/command_execution_windows.go
Normal file
430
client/ssh/server/command_execution_windows.go
Normal file
@@ -0,0 +1,430 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"os/user"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"unsafe"
|
||||
|
||||
"github.com/gliderlabs/ssh"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/sys/windows"
|
||||
"golang.org/x/sys/windows/registry"
|
||||
|
||||
"github.com/netbirdio/netbird/client/ssh/server/winpty"
|
||||
)
|
||||
|
||||
// getUserEnvironment retrieves the Windows environment for the target user.
|
||||
// Follows OpenSSH's resilient approach with graceful degradation on failures.
|
||||
func (s *Server) getUserEnvironment(username, domain string) ([]string, error) {
|
||||
userToken, err := s.getUserToken(username, domain)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get user token: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
if err := windows.CloseHandle(userToken); err != nil {
|
||||
log.Debugf("close user token: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
return s.getUserEnvironmentWithToken(userToken, username, domain)
|
||||
}
|
||||
|
||||
// getUserEnvironmentWithToken retrieves the Windows environment using an existing token.
|
||||
func (s *Server) getUserEnvironmentWithToken(userToken windows.Handle, username, domain string) ([]string, error) {
|
||||
userProfile, err := s.loadUserProfile(userToken, username, domain)
|
||||
if err != nil {
|
||||
log.Debugf("failed to load user profile for %s\\%s: %v", domain, username, err)
|
||||
userProfile = fmt.Sprintf("C:\\Users\\%s", username)
|
||||
}
|
||||
|
||||
envMap := make(map[string]string)
|
||||
|
||||
if err := s.loadSystemEnvironment(envMap); err != nil {
|
||||
log.Debugf("failed to load system environment from registry: %v", err)
|
||||
}
|
||||
|
||||
s.setUserEnvironmentVariables(envMap, userProfile, username, domain)
|
||||
|
||||
var env []string
|
||||
for key, value := range envMap {
|
||||
env = append(env, key+"="+value)
|
||||
}
|
||||
|
||||
return env, nil
|
||||
}
|
||||
|
||||
// getUserToken creates a user token for the specified user.
|
||||
func (s *Server) getUserToken(username, domain string) (windows.Handle, error) {
|
||||
privilegeDropper := NewPrivilegeDropper()
|
||||
token, err := privilegeDropper.createToken(username, domain)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("generate S4U user token: %w", err)
|
||||
}
|
||||
return token, nil
|
||||
}
|
||||
|
||||
// loadUserProfile loads the Windows user profile and returns the profile path.
|
||||
func (s *Server) loadUserProfile(userToken windows.Handle, username, domain string) (string, error) {
|
||||
usernamePtr, err := windows.UTF16PtrFromString(username)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("convert username to UTF-16: %w", err)
|
||||
}
|
||||
|
||||
var domainUTF16 *uint16
|
||||
if domain != "" && domain != "." {
|
||||
domainUTF16, err = windows.UTF16PtrFromString(domain)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("convert domain to UTF-16: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
type profileInfo struct {
|
||||
dwSize uint32
|
||||
dwFlags uint32
|
||||
lpUserName *uint16
|
||||
lpProfilePath *uint16
|
||||
lpDefaultPath *uint16
|
||||
lpServerName *uint16
|
||||
lpPolicyPath *uint16
|
||||
hProfile windows.Handle
|
||||
}
|
||||
|
||||
const PI_NOUI = 0x00000001
|
||||
|
||||
profile := profileInfo{
|
||||
dwSize: uint32(unsafe.Sizeof(profileInfo{})),
|
||||
dwFlags: PI_NOUI,
|
||||
lpUserName: usernamePtr,
|
||||
lpServerName: domainUTF16,
|
||||
}
|
||||
|
||||
userenv := windows.NewLazySystemDLL("userenv.dll")
|
||||
loadUserProfileW := userenv.NewProc("LoadUserProfileW")
|
||||
|
||||
ret, _, err := loadUserProfileW.Call(
|
||||
uintptr(userToken),
|
||||
uintptr(unsafe.Pointer(&profile)),
|
||||
)
|
||||
|
||||
if ret == 0 {
|
||||
return "", fmt.Errorf("LoadUserProfileW: %w", err)
|
||||
}
|
||||
|
||||
if profile.lpProfilePath == nil {
|
||||
return "", fmt.Errorf("LoadUserProfileW returned null profile path")
|
||||
}
|
||||
|
||||
profilePath := windows.UTF16PtrToString(profile.lpProfilePath)
|
||||
return profilePath, nil
|
||||
}
|
||||
|
||||
// loadSystemEnvironment loads system-wide environment variables from registry.
|
||||
func (s *Server) loadSystemEnvironment(envMap map[string]string) error {
|
||||
key, err := registry.OpenKey(registry.LOCAL_MACHINE,
|
||||
`SYSTEM\CurrentControlSet\Control\Session Manager\Environment`,
|
||||
registry.QUERY_VALUE)
|
||||
if err != nil {
|
||||
return fmt.Errorf("open system environment registry key: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
if err := key.Close(); err != nil {
|
||||
log.Debugf("close registry key: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
return s.readRegistryEnvironment(key, envMap)
|
||||
}
|
||||
|
||||
// readRegistryEnvironment reads environment variables from a registry key.
|
||||
func (s *Server) readRegistryEnvironment(key registry.Key, envMap map[string]string) error {
|
||||
names, err := key.ReadValueNames(0)
|
||||
if err != nil {
|
||||
return fmt.Errorf("read registry value names: %w", err)
|
||||
}
|
||||
|
||||
for _, name := range names {
|
||||
value, valueType, err := key.GetStringValue(name)
|
||||
if err != nil {
|
||||
log.Debugf("failed to read registry value %s: %v", name, err)
|
||||
continue
|
||||
}
|
||||
|
||||
finalValue := s.expandRegistryValue(value, valueType, name)
|
||||
s.setEnvironmentVariable(envMap, name, finalValue)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// expandRegistryValue expands registry values if they contain environment variables.
|
||||
func (s *Server) expandRegistryValue(value string, valueType uint32, name string) string {
|
||||
if valueType != registry.EXPAND_SZ {
|
||||
return value
|
||||
}
|
||||
|
||||
sourcePtr := windows.StringToUTF16Ptr(value)
|
||||
expandedBuffer := make([]uint16, 1024)
|
||||
expandedLen, err := windows.ExpandEnvironmentStrings(sourcePtr, &expandedBuffer[0], uint32(len(expandedBuffer)))
|
||||
if err != nil {
|
||||
log.Debugf("failed to expand environment string for %s: %v", name, err)
|
||||
return value
|
||||
}
|
||||
|
||||
// If buffer was too small, retry with larger buffer
|
||||
if expandedLen > uint32(len(expandedBuffer)) {
|
||||
expandedBuffer = make([]uint16, expandedLen)
|
||||
expandedLen, err = windows.ExpandEnvironmentStrings(sourcePtr, &expandedBuffer[0], uint32(len(expandedBuffer)))
|
||||
if err != nil {
|
||||
log.Debugf("failed to expand environment string for %s on retry: %v", name, err)
|
||||
return value
|
||||
}
|
||||
}
|
||||
|
||||
if expandedLen > 0 && expandedLen <= uint32(len(expandedBuffer)) {
|
||||
return windows.UTF16ToString(expandedBuffer[:expandedLen-1])
|
||||
}
|
||||
return value
|
||||
}
|
||||
|
||||
// setEnvironmentVariable sets an environment variable with special handling for PATH.
|
||||
func (s *Server) setEnvironmentVariable(envMap map[string]string, name, value string) {
|
||||
upperName := strings.ToUpper(name)
|
||||
|
||||
if upperName == "PATH" {
|
||||
if existing, exists := envMap["PATH"]; exists && existing != value {
|
||||
envMap["PATH"] = existing + ";" + value
|
||||
} else {
|
||||
envMap["PATH"] = value
|
||||
}
|
||||
} else {
|
||||
envMap[upperName] = value
|
||||
}
|
||||
}
|
||||
|
||||
// setUserEnvironmentVariables sets critical user-specific environment variables.
|
||||
func (s *Server) setUserEnvironmentVariables(envMap map[string]string, userProfile, username, domain string) {
|
||||
envMap["USERPROFILE"] = userProfile
|
||||
|
||||
if len(userProfile) >= 2 && userProfile[1] == ':' {
|
||||
envMap["HOMEDRIVE"] = userProfile[:2]
|
||||
envMap["HOMEPATH"] = userProfile[2:]
|
||||
}
|
||||
|
||||
envMap["APPDATA"] = filepath.Join(userProfile, "AppData", "Roaming")
|
||||
envMap["LOCALAPPDATA"] = filepath.Join(userProfile, "AppData", "Local")
|
||||
|
||||
tempDir := filepath.Join(userProfile, "AppData", "Local", "Temp")
|
||||
envMap["TEMP"] = tempDir
|
||||
envMap["TMP"] = tempDir
|
||||
|
||||
envMap["USERNAME"] = username
|
||||
if domain != "" && domain != "." {
|
||||
envMap["USERDOMAIN"] = domain
|
||||
envMap["USERDNSDOMAIN"] = domain
|
||||
}
|
||||
|
||||
systemVars := []string{
|
||||
"PROCESSOR_ARCHITECTURE", "PROCESSOR_IDENTIFIER", "PROCESSOR_LEVEL", "PROCESSOR_REVISION",
|
||||
"SYSTEMDRIVE", "SYSTEMROOT", "WINDIR", "COMPUTERNAME", "OS", "PATHEXT",
|
||||
"PROGRAMFILES", "PROGRAMDATA", "ALLUSERSPROFILE", "COMSPEC",
|
||||
}
|
||||
|
||||
for _, sysVar := range systemVars {
|
||||
if sysValue := os.Getenv(sysVar); sysValue != "" {
|
||||
envMap[sysVar] = sysValue
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// prepareCommandEnv prepares environment variables for command execution on Windows
|
||||
func (s *Server) prepareCommandEnv(localUser *user.User, session ssh.Session) []string {
|
||||
username, domain := s.parseUsername(localUser.Username)
|
||||
userEnv, err := s.getUserEnvironment(username, domain)
|
||||
if err != nil {
|
||||
log.Debugf("failed to get user environment for %s\\%s, using fallback: %v", domain, username, err)
|
||||
env := prepareUserEnv(localUser, getUserShell(localUser.Uid))
|
||||
env = append(env, prepareSSHEnv(session)...)
|
||||
for _, v := range session.Environ() {
|
||||
if acceptEnv(v) {
|
||||
env = append(env, v)
|
||||
}
|
||||
}
|
||||
return env
|
||||
}
|
||||
|
||||
env := userEnv
|
||||
env = append(env, prepareSSHEnv(session)...)
|
||||
for _, v := range session.Environ() {
|
||||
if acceptEnv(v) {
|
||||
env = append(env, v)
|
||||
}
|
||||
}
|
||||
return env
|
||||
}
|
||||
|
||||
func (s *Server) handlePty(logger *log.Entry, session ssh.Session, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, winCh <-chan ssh.Window) bool {
|
||||
if privilegeResult.User == nil {
|
||||
logger.Errorf("no user in privilege result")
|
||||
return false
|
||||
}
|
||||
|
||||
cmd := session.Command()
|
||||
shell := getUserShell(privilegeResult.User.Uid)
|
||||
|
||||
if len(cmd) == 0 {
|
||||
logger.Infof("starting interactive shell: %s", shell)
|
||||
} else {
|
||||
logger.Infof("executing command: %s", safeLogCommand(cmd))
|
||||
}
|
||||
|
||||
s.handlePtyWithUserSwitching(logger, session, privilegeResult, ptyReq, winCh, cmd)
|
||||
return true
|
||||
}
|
||||
|
||||
// getShellCommandArgs returns the shell command and arguments for executing a command string
|
||||
func (s *Server) getShellCommandArgs(shell, cmdString string) []string {
|
||||
if cmdString == "" {
|
||||
return []string{shell, "-NoLogo"}
|
||||
}
|
||||
return []string{shell, "-Command", cmdString}
|
||||
}
|
||||
|
||||
func (s *Server) handlePtyWithUserSwitching(logger *log.Entry, session ssh.Session, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, _ <-chan ssh.Window, _ []string) {
|
||||
logger.Info("starting interactive shell")
|
||||
s.executeConPtyCommand(logger, session, privilegeResult, ptyReq, session.RawCommand())
|
||||
}
|
||||
|
||||
type PtyExecutionRequest struct {
|
||||
Shell string
|
||||
Command string
|
||||
Width int
|
||||
Height int
|
||||
Username string
|
||||
Domain string
|
||||
}
|
||||
|
||||
func executePtyCommandWithUserToken(ctx context.Context, session ssh.Session, req PtyExecutionRequest) error {
|
||||
log.Tracef("executing Windows ConPty command with user switching: shell=%s, command=%s, user=%s\\%s, size=%dx%d",
|
||||
req.Shell, req.Command, req.Domain, req.Username, req.Width, req.Height)
|
||||
|
||||
privilegeDropper := NewPrivilegeDropper()
|
||||
userToken, err := privilegeDropper.createToken(req.Username, req.Domain)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create user token: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
if err := windows.CloseHandle(userToken); err != nil {
|
||||
log.Debugf("close user token: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
server := &Server{}
|
||||
userEnv, err := server.getUserEnvironmentWithToken(userToken, req.Username, req.Domain)
|
||||
if err != nil {
|
||||
log.Debugf("failed to get user environment for %s\\%s, using system environment: %v", req.Domain, req.Username, err)
|
||||
userEnv = os.Environ()
|
||||
}
|
||||
|
||||
workingDir := getUserHomeFromEnv(userEnv)
|
||||
if workingDir == "" {
|
||||
workingDir = fmt.Sprintf(`C:\Users\%s`, req.Username)
|
||||
}
|
||||
|
||||
ptyConfig := winpty.PtyConfig{
|
||||
Shell: req.Shell,
|
||||
Command: req.Command,
|
||||
Width: req.Width,
|
||||
Height: req.Height,
|
||||
WorkingDir: workingDir,
|
||||
}
|
||||
|
||||
userConfig := winpty.UserConfig{
|
||||
Token: userToken,
|
||||
Environment: userEnv,
|
||||
}
|
||||
|
||||
log.Debugf("executePtyCommandWithUserToken: calling winpty execution with working dir: %s", workingDir)
|
||||
return winpty.ExecutePtyWithUserToken(ctx, session, ptyConfig, userConfig)
|
||||
}
|
||||
|
||||
func getUserHomeFromEnv(env []string) string {
|
||||
for _, envVar := range env {
|
||||
if len(envVar) > 12 && envVar[:12] == "USERPROFILE=" {
|
||||
return envVar[12:]
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (s *Server) setupProcessGroup(_ *exec.Cmd) {
|
||||
// Windows doesn't support process groups in the same way as Unix
|
||||
// Process creation groups are handled differently
|
||||
}
|
||||
|
||||
func (s *Server) killProcessGroup(cmd *exec.Cmd) {
|
||||
if cmd.Process == nil {
|
||||
return
|
||||
}
|
||||
|
||||
logger := log.WithField("pid", cmd.Process.Pid)
|
||||
|
||||
if err := cmd.Process.Kill(); err != nil {
|
||||
logger.Debugf("kill process failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// detectSuPtySupport always returns false on Windows as su is not available
|
||||
func (s *Server) detectSuPtySupport(context.Context) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// executeCommandWithPty executes a command with PTY allocation on Windows using ConPty
|
||||
func (s *Server) executeCommandWithPty(logger *log.Entry, session ssh.Session, execCmd *exec.Cmd, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, winCh <-chan ssh.Window) bool {
|
||||
command := session.RawCommand()
|
||||
if command == "" {
|
||||
logger.Error("no command specified for PTY execution")
|
||||
if err := session.Exit(1); err != nil {
|
||||
logSessionExitError(logger, err)
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
return s.executeConPtyCommand(logger, session, privilegeResult, ptyReq, command)
|
||||
}
|
||||
|
||||
// executeConPtyCommand executes a command using ConPty (common for interactive and command execution)
|
||||
func (s *Server) executeConPtyCommand(logger *log.Entry, session ssh.Session, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, command string) bool {
|
||||
localUser := privilegeResult.User
|
||||
if localUser == nil {
|
||||
logger.Errorf("no user in privilege result")
|
||||
return false
|
||||
}
|
||||
|
||||
username, domain := s.parseUsername(localUser.Username)
|
||||
shell := getUserShell(localUser.Uid)
|
||||
|
||||
req := PtyExecutionRequest{
|
||||
Shell: shell,
|
||||
Command: command,
|
||||
Width: ptyReq.Window.Width,
|
||||
Height: ptyReq.Window.Height,
|
||||
Username: username,
|
||||
Domain: domain,
|
||||
}
|
||||
|
||||
if err := executePtyCommandWithUserToken(session.Context(), session, req); err != nil {
|
||||
logger.Errorf("ConPty execution failed: %v", err)
|
||||
if err := session.Exit(1); err != nil {
|
||||
logSessionExitError(logger, err)
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
logger.Debug("ConPty execution completed")
|
||||
return true
|
||||
}
|
||||
722
client/ssh/server/compatibility_test.go
Normal file
722
client/ssh/server/compatibility_test.go
Normal file
@@ -0,0 +1,722 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/ed25519"
|
||||
"crypto/rand"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"os"
|
||||
"os/exec"
|
||||
"runtime"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/crypto/ssh"
|
||||
|
||||
nbssh "github.com/netbirdio/netbird/client/ssh"
|
||||
"github.com/netbirdio/netbird/client/ssh/testutil"
|
||||
)
|
||||
|
||||
// TestMain handles package-level setup and cleanup
|
||||
func TestMain(m *testing.M) {
|
||||
// Guard against infinite recursion when test binary is called as "netbird ssh exec"
|
||||
// This happens when running tests as non-privileged user with fallback
|
||||
if len(os.Args) > 2 && os.Args[1] == "ssh" && os.Args[2] == "exec" {
|
||||
// Just exit with error to break the recursion
|
||||
fmt.Fprintf(os.Stderr, "Test binary called as 'ssh exec' - preventing infinite recursion\n")
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
// Run tests
|
||||
code := m.Run()
|
||||
|
||||
// Cleanup any created test users
|
||||
testutil.CleanupTestUsers()
|
||||
|
||||
os.Exit(code)
|
||||
}
|
||||
|
||||
// TestSSHServerCompatibility tests that our SSH server is compatible with the system SSH client
|
||||
func TestSSHServerCompatibility(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("Skipping SSH compatibility tests in short mode")
|
||||
}
|
||||
|
||||
// Check if ssh binary is available
|
||||
if !isSSHClientAvailable() {
|
||||
t.Skip("SSH client not available on this system")
|
||||
}
|
||||
|
||||
// Set up SSH server - use our existing key generation for server
|
||||
hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Generate OpenSSH-compatible keys for client
|
||||
clientPrivKeyOpenSSH, _, err := generateOpenSSHKey(t)
|
||||
require.NoError(t, err)
|
||||
|
||||
serverConfig := &Config{
|
||||
HostKeyPEM: hostKey,
|
||||
JWT: nil,
|
||||
}
|
||||
server := New(serverConfig)
|
||||
server.SetAllowRootLogin(true)
|
||||
|
||||
serverAddr := StartTestServer(t, server)
|
||||
defer func() {
|
||||
err := server.Stop()
|
||||
require.NoError(t, err)
|
||||
}()
|
||||
|
||||
// Create temporary key files for SSH client
|
||||
clientKeyFile, cleanupKey := createTempKeyFileFromBytes(t, clientPrivKeyOpenSSH)
|
||||
defer cleanupKey()
|
||||
|
||||
// Extract host and port from server address
|
||||
host, portStr, err := net.SplitHostPort(serverAddr)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Get appropriate user for SSH connection (handle system accounts)
|
||||
username := testutil.GetTestUsername(t)
|
||||
|
||||
t.Run("basic command execution", func(t *testing.T) {
|
||||
testSSHCommandExecutionWithUser(t, host, portStr, clientKeyFile, username)
|
||||
})
|
||||
|
||||
t.Run("interactive command", func(t *testing.T) {
|
||||
testSSHInteractiveCommand(t, host, portStr, clientKeyFile)
|
||||
})
|
||||
|
||||
t.Run("port forwarding", func(t *testing.T) {
|
||||
testSSHPortForwarding(t, host, portStr, clientKeyFile)
|
||||
})
|
||||
}
|
||||
|
||||
// testSSHCommandExecutionWithUser tests basic command execution with system SSH client using specified user.
|
||||
func testSSHCommandExecutionWithUser(t *testing.T, host, port, keyFile, username string) {
|
||||
cmd := exec.Command("ssh",
|
||||
"-i", keyFile,
|
||||
"-p", port,
|
||||
"-o", "StrictHostKeyChecking=no",
|
||||
"-o", "UserKnownHostsFile=/dev/null",
|
||||
"-o", "ConnectTimeout=5",
|
||||
fmt.Sprintf("%s@%s", username, host),
|
||||
"echo", "hello_world")
|
||||
|
||||
output, err := cmd.CombinedOutput()
|
||||
|
||||
if err != nil {
|
||||
t.Logf("SSH command failed: %v", err)
|
||||
t.Logf("Output: %s", string(output))
|
||||
return
|
||||
}
|
||||
|
||||
assert.Contains(t, string(output), "hello_world", "SSH command should execute successfully")
|
||||
}
|
||||
|
||||
// testSSHInteractiveCommand tests interactive shell session.
|
||||
func testSSHInteractiveCommand(t *testing.T, host, port, keyFile string) {
|
||||
// Get appropriate user for SSH connection
|
||||
username := testutil.GetTestUsername(t)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
cmd := exec.CommandContext(ctx, "ssh",
|
||||
"-i", keyFile,
|
||||
"-p", port,
|
||||
"-o", "StrictHostKeyChecking=no",
|
||||
"-o", "UserKnownHostsFile=/dev/null",
|
||||
"-o", "ConnectTimeout=5",
|
||||
fmt.Sprintf("%s@%s", username, host))
|
||||
|
||||
stdin, err := cmd.StdinPipe()
|
||||
if err != nil {
|
||||
t.Skipf("Cannot create stdin pipe: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
stdout, err := cmd.StdoutPipe()
|
||||
if err != nil {
|
||||
t.Skipf("Cannot create stdout pipe: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
err = cmd.Start()
|
||||
if err != nil {
|
||||
t.Logf("Cannot start SSH session: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
go func() {
|
||||
defer func() {
|
||||
if err := stdin.Close(); err != nil {
|
||||
t.Logf("stdin close error: %v", err)
|
||||
}
|
||||
}()
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
if _, err := stdin.Write([]byte("echo interactive_test\n")); err != nil {
|
||||
t.Logf("stdin write error: %v", err)
|
||||
}
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
if _, err := stdin.Write([]byte("exit\n")); err != nil {
|
||||
t.Logf("stdin write error: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
output, err := io.ReadAll(stdout)
|
||||
if err != nil {
|
||||
t.Logf("Cannot read SSH output: %v", err)
|
||||
}
|
||||
|
||||
err = cmd.Wait()
|
||||
if err != nil {
|
||||
t.Logf("SSH interactive session error: %v", err)
|
||||
t.Logf("Output: %s", string(output))
|
||||
return
|
||||
}
|
||||
|
||||
assert.Contains(t, string(output), "interactive_test", "Interactive SSH session should work")
|
||||
}
|
||||
|
||||
// testSSHPortForwarding tests port forwarding compatibility.
|
||||
func testSSHPortForwarding(t *testing.T, host, port, keyFile string) {
|
||||
// Get appropriate user for SSH connection
|
||||
username := testutil.GetTestUsername(t)
|
||||
|
||||
testServer, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
defer testServer.Close()
|
||||
|
||||
testServerAddr := testServer.Addr().String()
|
||||
expectedResponse := "HTTP/1.1 200 OK\r\nContent-Length: 21\r\n\r\nCompatibility Test OK"
|
||||
|
||||
go func() {
|
||||
for {
|
||||
conn, err := testServer.Accept()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
go func(c net.Conn) {
|
||||
defer func() {
|
||||
if err := c.Close(); err != nil {
|
||||
t.Logf("test server connection close error: %v", err)
|
||||
}
|
||||
}()
|
||||
buf := make([]byte, 1024)
|
||||
if _, err := c.Read(buf); err != nil {
|
||||
t.Logf("Test server read error: %v", err)
|
||||
}
|
||||
if _, err := c.Write([]byte(expectedResponse)); err != nil {
|
||||
t.Logf("Test server write error: %v", err)
|
||||
}
|
||||
}(conn)
|
||||
}
|
||||
}()
|
||||
|
||||
localListener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
localAddr := localListener.Addr().String()
|
||||
localListener.Close()
|
||||
|
||||
_, localPort, err := net.SplitHostPort(localAddr)
|
||||
require.NoError(t, err)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
|
||||
defer cancel()
|
||||
|
||||
forwardSpec := fmt.Sprintf("%s:%s", localPort, testServerAddr)
|
||||
cmd := exec.CommandContext(ctx, "ssh",
|
||||
"-i", keyFile,
|
||||
"-p", port,
|
||||
"-L", forwardSpec,
|
||||
"-o", "StrictHostKeyChecking=no",
|
||||
"-o", "UserKnownHostsFile=/dev/null",
|
||||
"-o", "ConnectTimeout=5",
|
||||
"-N",
|
||||
fmt.Sprintf("%s@%s", username, host))
|
||||
|
||||
err = cmd.Start()
|
||||
if err != nil {
|
||||
t.Logf("Cannot start SSH port forwarding: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
defer func() {
|
||||
if cmd.Process != nil {
|
||||
if err := cmd.Process.Kill(); err != nil {
|
||||
t.Logf("process kill error: %v", err)
|
||||
}
|
||||
}
|
||||
if err := cmd.Wait(); err != nil {
|
||||
t.Logf("process wait after kill: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
|
||||
conn, err := net.DialTimeout("tcp", localAddr, 3*time.Second)
|
||||
if err != nil {
|
||||
t.Logf("Cannot connect to forwarded port: %v", err)
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
if err := conn.Close(); err != nil {
|
||||
t.Logf("forwarded connection close error: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
request := "GET / HTTP/1.1\r\nHost: localhost\r\nConnection: close\r\n\r\n"
|
||||
_, err = conn.Write([]byte(request))
|
||||
require.NoError(t, err)
|
||||
|
||||
if err := conn.SetReadDeadline(time.Now().Add(3 * time.Second)); err != nil {
|
||||
log.Debugf("failed to set read deadline: %v", err)
|
||||
}
|
||||
response := make([]byte, len(expectedResponse))
|
||||
n, err := io.ReadFull(conn, response)
|
||||
if err != nil {
|
||||
t.Logf("Cannot read forwarded response: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
assert.Equal(t, len(expectedResponse), n, "Should read expected number of bytes")
|
||||
assert.Equal(t, expectedResponse, string(response), "Should get correct HTTP response through SSH port forwarding")
|
||||
}
|
||||
|
||||
// isSSHClientAvailable checks if the ssh binary is available
|
||||
func isSSHClientAvailable() bool {
|
||||
_, err := exec.LookPath("ssh")
|
||||
return err == nil
|
||||
}
|
||||
|
||||
// generateOpenSSHKey generates an ED25519 key in OpenSSH format that the system SSH client can use.
|
||||
func generateOpenSSHKey(t *testing.T) ([]byte, []byte, error) {
|
||||
// Check if ssh-keygen is available
|
||||
if _, err := exec.LookPath("ssh-keygen"); err != nil {
|
||||
// Fall back to our existing key generation and try to convert
|
||||
return generateOpenSSHKeyFallback()
|
||||
}
|
||||
|
||||
// Create temporary file for ssh-keygen
|
||||
tempFile, err := os.CreateTemp("", "ssh_keygen_*")
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("create temp file: %w", err)
|
||||
}
|
||||
keyPath := tempFile.Name()
|
||||
tempFile.Close()
|
||||
|
||||
// Remove the temp file so ssh-keygen can create it
|
||||
if err := os.Remove(keyPath); err != nil {
|
||||
t.Logf("failed to remove key file: %v", err)
|
||||
}
|
||||
|
||||
// Clean up temp files
|
||||
defer func() {
|
||||
if err := os.Remove(keyPath); err != nil {
|
||||
t.Logf("failed to cleanup key file: %v", err)
|
||||
}
|
||||
if err := os.Remove(keyPath + ".pub"); err != nil {
|
||||
t.Logf("failed to cleanup public key file: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
// Generate key using ssh-keygen
|
||||
cmd := exec.Command("ssh-keygen", "-t", "ed25519", "-f", keyPath, "-N", "", "-q")
|
||||
output, err := cmd.CombinedOutput()
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("ssh-keygen failed: %w, output: %s", err, string(output))
|
||||
}
|
||||
|
||||
// Read private key
|
||||
privKeyBytes, err := os.ReadFile(keyPath)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("read private key: %w", err)
|
||||
}
|
||||
|
||||
// Read public key
|
||||
pubKeyBytes, err := os.ReadFile(keyPath + ".pub")
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("read public key: %w", err)
|
||||
}
|
||||
|
||||
return privKeyBytes, pubKeyBytes, nil
|
||||
}
|
||||
|
||||
// generateOpenSSHKeyFallback falls back to generating keys using our existing method
|
||||
func generateOpenSSHKeyFallback() ([]byte, []byte, error) {
|
||||
// Generate shared.ED25519 key pair using our existing method
|
||||
_, privKey, err := ed25519.GenerateKey(rand.Reader)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("generate key: %w", err)
|
||||
}
|
||||
|
||||
// Convert to SSH format
|
||||
sshPrivKey, err := ssh.NewSignerFromKey(privKey)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("create signer: %w", err)
|
||||
}
|
||||
|
||||
// For the fallback, just use our PKCS#8 format and hope it works
|
||||
// This won't be in OpenSSH format but might still work with some SSH clients
|
||||
hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("generate fallback key: %w", err)
|
||||
}
|
||||
|
||||
// Get public key in SSH format
|
||||
sshPubKey := ssh.MarshalAuthorizedKey(sshPrivKey.PublicKey())
|
||||
|
||||
return hostKey, sshPubKey, nil
|
||||
}
|
||||
|
||||
// createTempKeyFileFromBytes creates a temporary SSH private key file from raw bytes
|
||||
func createTempKeyFileFromBytes(t *testing.T, keyBytes []byte) (string, func()) {
|
||||
t.Helper()
|
||||
|
||||
tempFile, err := os.CreateTemp("", "ssh_test_key_*")
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = tempFile.Write(keyBytes)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = tempFile.Close()
|
||||
require.NoError(t, err)
|
||||
|
||||
// Set proper permissions for SSH key (readable by owner only)
|
||||
err = os.Chmod(tempFile.Name(), 0600)
|
||||
require.NoError(t, err)
|
||||
|
||||
cleanup := func() {
|
||||
_ = os.Remove(tempFile.Name())
|
||||
}
|
||||
|
||||
return tempFile.Name(), cleanup
|
||||
}
|
||||
|
||||
// createTempKeyFile creates a temporary SSH private key file (for backward compatibility)
|
||||
func createTempKeyFile(t *testing.T, privateKey []byte) (string, func()) {
|
||||
return createTempKeyFileFromBytes(t, privateKey)
|
||||
}
|
||||
|
||||
// TestSSHServerFeatureCompatibility tests specific SSH features for compatibility
|
||||
func TestSSHServerFeatureCompatibility(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("Skipping SSH feature compatibility tests in short mode")
|
||||
}
|
||||
|
||||
if runtime.GOOS == "windows" && testutil.IsCI() {
|
||||
t.Skip("Skipping Windows SSH compatibility tests in CI due to S4U authentication issues")
|
||||
}
|
||||
|
||||
if !isSSHClientAvailable() {
|
||||
t.Skip("SSH client not available on this system")
|
||||
}
|
||||
|
||||
// Test various SSH features
|
||||
testCases := []struct {
|
||||
name string
|
||||
testFunc func(t *testing.T, host, port, keyFile string)
|
||||
description string
|
||||
}{
|
||||
{
|
||||
name: "command_with_flags",
|
||||
testFunc: testCommandWithFlags,
|
||||
description: "Commands with flags should work like standard SSH",
|
||||
},
|
||||
{
|
||||
name: "environment_variables",
|
||||
testFunc: testEnvironmentVariables,
|
||||
description: "Environment variables should be available",
|
||||
},
|
||||
{
|
||||
name: "exit_codes",
|
||||
testFunc: testExitCodes,
|
||||
description: "Exit codes should be properly handled",
|
||||
},
|
||||
}
|
||||
|
||||
// Set up SSH server
|
||||
hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
||||
require.NoError(t, err)
|
||||
|
||||
clientPrivKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
||||
require.NoError(t, err)
|
||||
|
||||
serverConfig := &Config{
|
||||
HostKeyPEM: hostKey,
|
||||
JWT: nil,
|
||||
}
|
||||
server := New(serverConfig)
|
||||
server.SetAllowRootLogin(true)
|
||||
|
||||
serverAddr := StartTestServer(t, server)
|
||||
defer func() {
|
||||
err := server.Stop()
|
||||
require.NoError(t, err)
|
||||
}()
|
||||
|
||||
clientKeyFile, cleanupKey := createTempKeyFile(t, clientPrivKey)
|
||||
defer cleanupKey()
|
||||
|
||||
host, portStr, err := net.SplitHostPort(serverAddr)
|
||||
require.NoError(t, err)
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
tc.testFunc(t, host, portStr, clientKeyFile)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// testCommandWithFlags tests that commands with flags work properly
|
||||
func testCommandWithFlags(t *testing.T, host, port, keyFile string) {
|
||||
// Get appropriate user for SSH connection
|
||||
username := testutil.GetTestUsername(t)
|
||||
|
||||
// Test ls with flags
|
||||
cmd := exec.Command("ssh",
|
||||
"-i", keyFile,
|
||||
"-p", port,
|
||||
"-o", "StrictHostKeyChecking=no",
|
||||
"-o", "UserKnownHostsFile=/dev/null",
|
||||
"-o", "ConnectTimeout=5",
|
||||
fmt.Sprintf("%s@%s", username, host),
|
||||
"ls", "-la", "/tmp")
|
||||
|
||||
output, err := cmd.CombinedOutput()
|
||||
if err != nil {
|
||||
t.Logf("Command with flags failed: %v", err)
|
||||
t.Logf("Output: %s", string(output))
|
||||
return
|
||||
}
|
||||
|
||||
// Should not be empty and should not contain error messages
|
||||
assert.NotEmpty(t, string(output), "ls -la should produce output")
|
||||
assert.NotContains(t, strings.ToLower(string(output)), "command not found", "Command should be executed")
|
||||
}
|
||||
|
||||
// testEnvironmentVariables tests that environment is properly set up
|
||||
func testEnvironmentVariables(t *testing.T, host, port, keyFile string) {
|
||||
// Get appropriate user for SSH connection
|
||||
username := testutil.GetTestUsername(t)
|
||||
|
||||
cmd := exec.Command("ssh",
|
||||
"-i", keyFile,
|
||||
"-p", port,
|
||||
"-o", "StrictHostKeyChecking=no",
|
||||
"-o", "UserKnownHostsFile=/dev/null",
|
||||
"-o", "ConnectTimeout=5",
|
||||
fmt.Sprintf("%s@%s", username, host),
|
||||
"echo", "$HOME")
|
||||
|
||||
output, err := cmd.CombinedOutput()
|
||||
if err != nil {
|
||||
t.Logf("Environment test failed: %v", err)
|
||||
t.Logf("Output: %s", string(output))
|
||||
return
|
||||
}
|
||||
|
||||
// HOME environment variable should be available
|
||||
homeOutput := strings.TrimSpace(string(output))
|
||||
assert.NotEmpty(t, homeOutput, "HOME environment variable should be set")
|
||||
assert.NotEqual(t, "$HOME", homeOutput, "Environment variable should be expanded")
|
||||
}
|
||||
|
||||
// testExitCodes tests that exit codes are properly handled
|
||||
func testExitCodes(t *testing.T, host, port, keyFile string) {
|
||||
// Get appropriate user for SSH connection
|
||||
username := testutil.GetTestUsername(t)
|
||||
|
||||
// Test successful command (exit code 0)
|
||||
cmd := exec.Command("ssh",
|
||||
"-i", keyFile,
|
||||
"-p", port,
|
||||
"-o", "StrictHostKeyChecking=no",
|
||||
"-o", "UserKnownHostsFile=/dev/null",
|
||||
"-o", "ConnectTimeout=5",
|
||||
fmt.Sprintf("%s@%s", username, host),
|
||||
"true") // always succeeds
|
||||
|
||||
err := cmd.Run()
|
||||
assert.NoError(t, err, "Command with exit code 0 should succeed")
|
||||
|
||||
// Test failing command (exit code 1)
|
||||
cmd = exec.Command("ssh",
|
||||
"-i", keyFile,
|
||||
"-p", port,
|
||||
"-o", "StrictHostKeyChecking=no",
|
||||
"-o", "UserKnownHostsFile=/dev/null",
|
||||
"-o", "ConnectTimeout=5",
|
||||
fmt.Sprintf("%s@%s", username, host),
|
||||
"false") // always fails
|
||||
|
||||
err = cmd.Run()
|
||||
assert.Error(t, err, "Command with exit code 1 should fail")
|
||||
|
||||
// Check if it's the right kind of error
|
||||
if exitError, ok := err.(*exec.ExitError); ok {
|
||||
assert.Equal(t, 1, exitError.ExitCode(), "Exit code should be preserved")
|
||||
}
|
||||
}
|
||||
|
||||
// TestSSHServerSecurityFeatures tests security-related SSH features
|
||||
func TestSSHServerSecurityFeatures(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("Skipping SSH security tests in short mode")
|
||||
}
|
||||
|
||||
if !isSSHClientAvailable() {
|
||||
t.Skip("SSH client not available on this system")
|
||||
}
|
||||
|
||||
// Get appropriate user for SSH connection
|
||||
username := testutil.GetTestUsername(t)
|
||||
|
||||
// Set up SSH server with specific security settings
|
||||
hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
||||
require.NoError(t, err)
|
||||
|
||||
clientPrivKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
||||
require.NoError(t, err)
|
||||
|
||||
serverConfig := &Config{
|
||||
HostKeyPEM: hostKey,
|
||||
JWT: nil,
|
||||
}
|
||||
server := New(serverConfig)
|
||||
server.SetAllowRootLogin(true)
|
||||
|
||||
serverAddr := StartTestServer(t, server)
|
||||
defer func() {
|
||||
err := server.Stop()
|
||||
require.NoError(t, err)
|
||||
}()
|
||||
|
||||
clientKeyFile, cleanupKey := createTempKeyFile(t, clientPrivKey)
|
||||
defer cleanupKey()
|
||||
|
||||
host, portStr, err := net.SplitHostPort(serverAddr)
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Run("key_authentication", func(t *testing.T) {
|
||||
// Test that key authentication works
|
||||
cmd := exec.Command("ssh",
|
||||
"-i", clientKeyFile,
|
||||
"-p", portStr,
|
||||
"-o", "StrictHostKeyChecking=no",
|
||||
"-o", "UserKnownHostsFile=/dev/null",
|
||||
"-o", "ConnectTimeout=5",
|
||||
"-o", "PasswordAuthentication=no",
|
||||
fmt.Sprintf("%s@%s", username, host),
|
||||
"echo", "auth_success")
|
||||
|
||||
output, err := cmd.CombinedOutput()
|
||||
if err != nil {
|
||||
t.Logf("Key authentication failed: %v", err)
|
||||
t.Logf("Output: %s", string(output))
|
||||
return
|
||||
}
|
||||
|
||||
assert.Contains(t, string(output), "auth_success", "Key authentication should work")
|
||||
})
|
||||
|
||||
t.Run("any_key_accepted_in_no_auth_mode", func(t *testing.T) {
|
||||
// Create a different key that shouldn't be accepted
|
||||
wrongKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
||||
require.NoError(t, err)
|
||||
|
||||
wrongKeyFile, cleanupWrongKey := createTempKeyFile(t, wrongKey)
|
||||
defer cleanupWrongKey()
|
||||
|
||||
// Test that wrong key is rejected
|
||||
cmd := exec.Command("ssh",
|
||||
"-i", wrongKeyFile,
|
||||
"-p", portStr,
|
||||
"-o", "StrictHostKeyChecking=no",
|
||||
"-o", "UserKnownHostsFile=/dev/null",
|
||||
"-o", "ConnectTimeout=5",
|
||||
"-o", "PasswordAuthentication=no",
|
||||
fmt.Sprintf("%s@%s", username, host),
|
||||
"echo", "should_not_work")
|
||||
|
||||
err = cmd.Run()
|
||||
assert.NoError(t, err, "Any key should work in no-auth mode")
|
||||
})
|
||||
}
|
||||
|
||||
// TestCrossPlatformCompatibility tests cross-platform behavior
|
||||
func TestCrossPlatformCompatibility(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("Skipping cross-platform compatibility tests in short mode")
|
||||
}
|
||||
|
||||
if !isSSHClientAvailable() {
|
||||
t.Skip("SSH client not available on this system")
|
||||
}
|
||||
|
||||
// Get appropriate user for SSH connection
|
||||
username := testutil.GetTestUsername(t)
|
||||
|
||||
// Set up SSH server
|
||||
hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
||||
require.NoError(t, err)
|
||||
|
||||
clientPrivKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
||||
require.NoError(t, err)
|
||||
|
||||
serverConfig := &Config{
|
||||
HostKeyPEM: hostKey,
|
||||
JWT: nil,
|
||||
}
|
||||
server := New(serverConfig)
|
||||
server.SetAllowRootLogin(true)
|
||||
|
||||
serverAddr := StartTestServer(t, server)
|
||||
defer func() {
|
||||
err := server.Stop()
|
||||
require.NoError(t, err)
|
||||
}()
|
||||
|
||||
clientKeyFile, cleanupKey := createTempKeyFile(t, clientPrivKey)
|
||||
defer cleanupKey()
|
||||
|
||||
host, portStr, err := net.SplitHostPort(serverAddr)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Test platform-specific commands
|
||||
var testCommand string
|
||||
|
||||
switch runtime.GOOS {
|
||||
case "windows":
|
||||
testCommand = "echo %OS%"
|
||||
default:
|
||||
testCommand = "uname"
|
||||
}
|
||||
|
||||
cmd := exec.Command("ssh",
|
||||
"-i", clientKeyFile,
|
||||
"-p", portStr,
|
||||
"-o", "StrictHostKeyChecking=no",
|
||||
"-o", "UserKnownHostsFile=/dev/null",
|
||||
"-o", "ConnectTimeout=5",
|
||||
fmt.Sprintf("%s@%s", username, host),
|
||||
testCommand)
|
||||
|
||||
output, err := cmd.CombinedOutput()
|
||||
if err != nil {
|
||||
t.Logf("Platform-specific command failed: %v", err)
|
||||
t.Logf("Output: %s", string(output))
|
||||
return
|
||||
}
|
||||
|
||||
outputStr := strings.TrimSpace(string(output))
|
||||
t.Logf("Platform command output: %s", outputStr)
|
||||
assert.NotEmpty(t, outputStr, "Platform-specific command should produce output")
|
||||
}
|
||||
253
client/ssh/server/executor_unix.go
Normal file
253
client/ssh/server/executor_unix.go
Normal file
@@ -0,0 +1,253 @@
|
||||
//go:build unix
|
||||
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"runtime"
|
||||
"strings"
|
||||
"syscall"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// Exit codes for executor process communication
|
||||
const (
|
||||
ExitCodeSuccess = 0
|
||||
ExitCodePrivilegeDropFail = 10
|
||||
ExitCodeShellExecFail = 11
|
||||
ExitCodeValidationFail = 12
|
||||
)
|
||||
|
||||
// ExecutorConfig holds configuration for the executor process
|
||||
type ExecutorConfig struct {
|
||||
UID uint32
|
||||
GID uint32
|
||||
Groups []uint32
|
||||
WorkingDir string
|
||||
Shell string
|
||||
Command string
|
||||
PTY bool
|
||||
}
|
||||
|
||||
// PrivilegeDropper handles secure privilege dropping in child processes
|
||||
type PrivilegeDropper struct{}
|
||||
|
||||
// NewPrivilegeDropper creates a new privilege dropper
|
||||
func NewPrivilegeDropper() *PrivilegeDropper {
|
||||
return &PrivilegeDropper{}
|
||||
}
|
||||
|
||||
// CreateExecutorCommand creates a command that spawns netbird ssh exec for privilege dropping
|
||||
func (pd *PrivilegeDropper) CreateExecutorCommand(ctx context.Context, config ExecutorConfig) (*exec.Cmd, error) {
|
||||
netbirdPath, err := os.Executable()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get netbird executable path: %w", err)
|
||||
}
|
||||
|
||||
if err := pd.validatePrivileges(config.UID, config.GID); err != nil {
|
||||
return nil, fmt.Errorf("invalid privileges: %w", err)
|
||||
}
|
||||
|
||||
args := []string{
|
||||
"ssh", "exec",
|
||||
"--uid", fmt.Sprintf("%d", config.UID),
|
||||
"--gid", fmt.Sprintf("%d", config.GID),
|
||||
"--working-dir", config.WorkingDir,
|
||||
"--shell", config.Shell,
|
||||
}
|
||||
|
||||
for _, group := range config.Groups {
|
||||
args = append(args, "--groups", fmt.Sprintf("%d", group))
|
||||
}
|
||||
|
||||
if config.PTY {
|
||||
args = append(args, "--pty")
|
||||
}
|
||||
|
||||
if config.Command != "" {
|
||||
args = append(args, "--cmd", config.Command)
|
||||
}
|
||||
|
||||
// Log executor args safely - show all args except hide the command value
|
||||
safeArgs := make([]string, len(args))
|
||||
copy(safeArgs, args)
|
||||
for i := 0; i < len(safeArgs)-1; i++ {
|
||||
if safeArgs[i] == "--cmd" {
|
||||
cmdParts := strings.Fields(safeArgs[i+1])
|
||||
safeArgs[i+1] = safeLogCommand(cmdParts)
|
||||
break
|
||||
}
|
||||
}
|
||||
log.Tracef("creating executor command: %s %v", netbirdPath, safeArgs)
|
||||
return exec.CommandContext(ctx, netbirdPath, args...), nil
|
||||
}
|
||||
|
||||
// DropPrivileges performs privilege dropping with thread locking for security
|
||||
func (pd *PrivilegeDropper) DropPrivileges(targetUID, targetGID uint32, supplementaryGroups []uint32) error {
|
||||
if err := pd.validatePrivileges(targetUID, targetGID); err != nil {
|
||||
return fmt.Errorf("invalid privileges: %w", err)
|
||||
}
|
||||
|
||||
runtime.LockOSThread()
|
||||
defer runtime.UnlockOSThread()
|
||||
|
||||
originalUID := os.Geteuid()
|
||||
originalGID := os.Getegid()
|
||||
|
||||
if originalUID != int(targetUID) || originalGID != int(targetGID) {
|
||||
if err := pd.setGroupsAndIDs(targetUID, targetGID, supplementaryGroups); err != nil {
|
||||
return fmt.Errorf("set groups and IDs: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if err := pd.validatePrivilegeDropSuccess(targetUID, targetGID, originalUID, originalGID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
log.Tracef("successfully dropped privileges to UID=%d, GID=%d", targetUID, targetGID)
|
||||
return nil
|
||||
}
|
||||
|
||||
// setGroupsAndIDs sets the supplementary groups, GID, and UID
|
||||
func (pd *PrivilegeDropper) setGroupsAndIDs(targetUID, targetGID uint32, supplementaryGroups []uint32) error {
|
||||
groups := make([]int, len(supplementaryGroups))
|
||||
for i, g := range supplementaryGroups {
|
||||
groups[i] = int(g)
|
||||
}
|
||||
|
||||
if runtime.GOOS == "darwin" || runtime.GOOS == "freebsd" {
|
||||
if len(groups) == 0 || groups[0] != int(targetGID) {
|
||||
groups = append([]int{int(targetGID)}, groups...)
|
||||
}
|
||||
}
|
||||
|
||||
if err := syscall.Setgroups(groups); err != nil {
|
||||
return fmt.Errorf("setgroups to %v: %w", groups, err)
|
||||
}
|
||||
|
||||
if err := syscall.Setgid(int(targetGID)); err != nil {
|
||||
return fmt.Errorf("setgid to %d: %w", targetGID, err)
|
||||
}
|
||||
|
||||
if err := syscall.Setuid(int(targetUID)); err != nil {
|
||||
return fmt.Errorf("setuid to %d: %w", targetUID, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// validatePrivilegeDropSuccess validates that privilege dropping was successful
|
||||
func (pd *PrivilegeDropper) validatePrivilegeDropSuccess(targetUID, targetGID uint32, originalUID, originalGID int) error {
|
||||
if err := pd.validatePrivilegeDropReversibility(targetUID, targetGID, originalUID, originalGID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := pd.validateCurrentPrivileges(targetUID, targetGID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// validatePrivilegeDropReversibility ensures privileges cannot be restored
|
||||
func (pd *PrivilegeDropper) validatePrivilegeDropReversibility(targetUID, targetGID uint32, originalUID, originalGID int) error {
|
||||
if originalGID != int(targetGID) {
|
||||
if err := syscall.Setegid(originalGID); err == nil {
|
||||
return fmt.Errorf("privilege drop validation failed: able to restore original GID %d", originalGID)
|
||||
}
|
||||
}
|
||||
if originalUID != int(targetUID) {
|
||||
if err := syscall.Seteuid(originalUID); err == nil {
|
||||
return fmt.Errorf("privilege drop validation failed: able to restore original UID %d", originalUID)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// validateCurrentPrivileges validates the current UID and GID match the target
|
||||
func (pd *PrivilegeDropper) validateCurrentPrivileges(targetUID, targetGID uint32) error {
|
||||
currentUID := os.Geteuid()
|
||||
if currentUID != int(targetUID) {
|
||||
return fmt.Errorf("privilege drop validation failed: current UID %d, expected %d", currentUID, targetUID)
|
||||
}
|
||||
|
||||
currentGID := os.Getegid()
|
||||
if currentGID != int(targetGID) {
|
||||
return fmt.Errorf("privilege drop validation failed: current GID %d, expected %d", currentGID, targetGID)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ExecuteWithPrivilegeDrop executes a command with privilege dropping, using exit codes to signal specific failures
|
||||
func (pd *PrivilegeDropper) ExecuteWithPrivilegeDrop(ctx context.Context, config ExecutorConfig) {
|
||||
log.Tracef("dropping privileges to UID=%d, GID=%d, groups=%v", config.UID, config.GID, config.Groups)
|
||||
|
||||
// TODO: Implement Pty support for executor path
|
||||
if config.PTY {
|
||||
config.PTY = false
|
||||
}
|
||||
|
||||
if err := pd.DropPrivileges(config.UID, config.GID, config.Groups); err != nil {
|
||||
_, _ = fmt.Fprintf(os.Stderr, "privilege drop failed: %v\n", err)
|
||||
os.Exit(ExitCodePrivilegeDropFail)
|
||||
}
|
||||
|
||||
if config.WorkingDir != "" {
|
||||
if err := os.Chdir(config.WorkingDir); err != nil {
|
||||
log.Debugf("failed to change to working directory %s, continuing with current directory: %v", config.WorkingDir, err)
|
||||
}
|
||||
}
|
||||
|
||||
var execCmd *exec.Cmd
|
||||
if config.Command == "" {
|
||||
os.Exit(ExitCodeSuccess)
|
||||
}
|
||||
|
||||
execCmd = exec.CommandContext(ctx, config.Shell, "-c", config.Command)
|
||||
execCmd.Stdin = os.Stdin
|
||||
execCmd.Stdout = os.Stdout
|
||||
execCmd.Stderr = os.Stderr
|
||||
|
||||
cmdParts := strings.Fields(config.Command)
|
||||
safeCmd := safeLogCommand(cmdParts)
|
||||
log.Tracef("executing %s -c %s", execCmd.Path, safeCmd)
|
||||
if err := execCmd.Run(); err != nil {
|
||||
var exitError *exec.ExitError
|
||||
if errors.As(err, &exitError) {
|
||||
// Normal command exit with non-zero code - not an SSH execution error
|
||||
log.Tracef("command exited with code %d", exitError.ExitCode())
|
||||
os.Exit(exitError.ExitCode())
|
||||
}
|
||||
|
||||
// Actual execution failure (command not found, permission denied, etc.)
|
||||
log.Debugf("command execution failed: %v", err)
|
||||
os.Exit(ExitCodeShellExecFail)
|
||||
}
|
||||
|
||||
os.Exit(ExitCodeSuccess)
|
||||
}
|
||||
|
||||
// validatePrivileges validates that privilege dropping to the target UID/GID is allowed
|
||||
func (pd *PrivilegeDropper) validatePrivileges(uid, gid uint32) error {
|
||||
currentUID := uint32(os.Geteuid())
|
||||
currentGID := uint32(os.Getegid())
|
||||
|
||||
// Allow same-user operations (no privilege dropping needed)
|
||||
if uid == currentUID && gid == currentGID {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Only root can drop privileges to other users
|
||||
if currentUID != 0 {
|
||||
return fmt.Errorf("cannot drop privileges from non-root user (UID %d) to UID %d", currentUID, uid)
|
||||
}
|
||||
|
||||
// Root can drop to any user (including root itself)
|
||||
return nil
|
||||
}
|
||||
262
client/ssh/server/executor_unix_test.go
Normal file
262
client/ssh/server/executor_unix_test.go
Normal file
@@ -0,0 +1,262 @@
|
||||
//go:build unix
|
||||
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"os/user"
|
||||
"strconv"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestPrivilegeDropper_ValidatePrivileges(t *testing.T) {
|
||||
pd := NewPrivilegeDropper()
|
||||
|
||||
currentUID := uint32(os.Geteuid())
|
||||
currentGID := uint32(os.Getegid())
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
uid uint32
|
||||
gid uint32
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "same user - no privilege drop needed",
|
||||
uid: currentUID,
|
||||
gid: currentGID,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "non-root to different user should fail",
|
||||
uid: currentUID + 1, // Use a different UID to ensure it's actually different
|
||||
gid: currentGID + 1, // Use a different GID to ensure it's actually different
|
||||
wantErr: currentUID != 0, // Only fail if current user is not root
|
||||
},
|
||||
{
|
||||
name: "root can drop to any user",
|
||||
uid: 1000,
|
||||
gid: 1000,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "root can stay as root",
|
||||
uid: 0,
|
||||
gid: 0,
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Skip non-root tests when running as root, and root tests when not root
|
||||
if tt.name == "non-root to different user should fail" && currentUID == 0 {
|
||||
t.Skip("Skipping non-root test when running as root")
|
||||
}
|
||||
if (tt.name == "root can drop to any user" || tt.name == "root can stay as root") && currentUID != 0 {
|
||||
t.Skip("Skipping root test when not running as root")
|
||||
}
|
||||
|
||||
err := pd.validatePrivileges(tt.uid, tt.gid)
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPrivilegeDropper_CreateExecutorCommand(t *testing.T) {
|
||||
pd := NewPrivilegeDropper()
|
||||
|
||||
config := ExecutorConfig{
|
||||
UID: 1000,
|
||||
GID: 1000,
|
||||
Groups: []uint32{1000, 1001},
|
||||
WorkingDir: "/home/testuser",
|
||||
Shell: "/bin/bash",
|
||||
Command: "ls -la",
|
||||
}
|
||||
|
||||
cmd, err := pd.CreateExecutorCommand(context.Background(), config)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, cmd)
|
||||
|
||||
// Verify the command is calling netbird ssh exec
|
||||
assert.Contains(t, cmd.Args, "ssh")
|
||||
assert.Contains(t, cmd.Args, "exec")
|
||||
assert.Contains(t, cmd.Args, "--uid")
|
||||
assert.Contains(t, cmd.Args, "1000")
|
||||
assert.Contains(t, cmd.Args, "--gid")
|
||||
assert.Contains(t, cmd.Args, "1000")
|
||||
assert.Contains(t, cmd.Args, "--groups")
|
||||
assert.Contains(t, cmd.Args, "1000")
|
||||
assert.Contains(t, cmd.Args, "1001")
|
||||
assert.Contains(t, cmd.Args, "--working-dir")
|
||||
assert.Contains(t, cmd.Args, "/home/testuser")
|
||||
assert.Contains(t, cmd.Args, "--shell")
|
||||
assert.Contains(t, cmd.Args, "/bin/bash")
|
||||
assert.Contains(t, cmd.Args, "--cmd")
|
||||
assert.Contains(t, cmd.Args, "ls -la")
|
||||
}
|
||||
|
||||
func TestPrivilegeDropper_CreateExecutorCommandInteractive(t *testing.T) {
|
||||
pd := NewPrivilegeDropper()
|
||||
|
||||
config := ExecutorConfig{
|
||||
UID: 1000,
|
||||
GID: 1000,
|
||||
Groups: []uint32{1000},
|
||||
WorkingDir: "/home/testuser",
|
||||
Shell: "/bin/bash",
|
||||
Command: "",
|
||||
}
|
||||
|
||||
cmd, err := pd.CreateExecutorCommand(context.Background(), config)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, cmd)
|
||||
|
||||
// Verify no command mode (command is empty so no --cmd flag)
|
||||
assert.NotContains(t, cmd.Args, "--cmd")
|
||||
assert.NotContains(t, cmd.Args, "--interactive")
|
||||
}
|
||||
|
||||
// TestPrivilegeDropper_ActualPrivilegeDrop tests actual privilege dropping
|
||||
// This test requires root privileges and will be skipped if not running as root
|
||||
func TestPrivilegeDropper_ActualPrivilegeDrop(t *testing.T) {
|
||||
if os.Geteuid() != 0 {
|
||||
t.Skip("This test requires root privileges")
|
||||
}
|
||||
|
||||
// Find a non-root user to test with
|
||||
testUser, err := findNonRootUser()
|
||||
if err != nil {
|
||||
t.Skip("No suitable non-root user found for testing")
|
||||
}
|
||||
|
||||
// Verify the user actually exists by looking it up again
|
||||
_, err = user.LookupId(testUser.Uid)
|
||||
if err != nil {
|
||||
t.Skipf("Test user %s (UID %s) does not exist on this system: %v", testUser.Username, testUser.Uid, err)
|
||||
}
|
||||
|
||||
uid64, err := strconv.ParseUint(testUser.Uid, 10, 32)
|
||||
require.NoError(t, err)
|
||||
targetUID := uint32(uid64)
|
||||
|
||||
gid64, err := strconv.ParseUint(testUser.Gid, 10, 32)
|
||||
require.NoError(t, err)
|
||||
targetGID := uint32(gid64)
|
||||
|
||||
// Test in a child process to avoid affecting the test runner
|
||||
if os.Getenv("TEST_PRIVILEGE_DROP") == "1" {
|
||||
pd := NewPrivilegeDropper()
|
||||
|
||||
// This should succeed
|
||||
err := pd.DropPrivileges(targetUID, targetGID, []uint32{targetGID})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify we are now running as the target user
|
||||
currentUID := uint32(os.Geteuid())
|
||||
currentGID := uint32(os.Getegid())
|
||||
|
||||
assert.Equal(t, targetUID, currentUID, "UID should match target")
|
||||
assert.Equal(t, targetGID, currentGID, "GID should match target")
|
||||
assert.NotEqual(t, uint32(0), currentUID, "Should not be running as root")
|
||||
assert.NotEqual(t, uint32(0), currentGID, "Should not be running as root group")
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// Fork a child process to test privilege dropping
|
||||
cmd := os.Args[0]
|
||||
args := []string{"-test.run=TestPrivilegeDropper_ActualPrivilegeDrop"}
|
||||
|
||||
env := append(os.Environ(), "TEST_PRIVILEGE_DROP=1")
|
||||
|
||||
execCmd := exec.Command(cmd, args...)
|
||||
execCmd.Env = env
|
||||
|
||||
err = execCmd.Run()
|
||||
require.NoError(t, err, "Child process should succeed")
|
||||
}
|
||||
|
||||
// findNonRootUser finds any non-root user on the system for testing
|
||||
func findNonRootUser() (*user.User, error) {
|
||||
// Try common non-root users, but avoid "nobody" on macOS due to negative UID issues
|
||||
commonUsers := []string{"daemon", "bin", "sys", "sync", "games", "man", "lp", "mail", "news", "uucp", "proxy", "www-data", "backup", "list", "irc"}
|
||||
|
||||
for _, username := range commonUsers {
|
||||
if u, err := user.Lookup(username); err == nil {
|
||||
// Parse as signed integer first to handle negative UIDs
|
||||
uid64, err := strconv.ParseInt(u.Uid, 10, 32)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
// Skip negative UIDs (like nobody=-2 on macOS) and root
|
||||
if uid64 > 0 && uid64 != 0 {
|
||||
return u, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// If no common users found, try to find any regular user with UID > 100
|
||||
// This helps on macOS where regular users start at UID 501
|
||||
allUsers := []string{"vma", "user", "test", "admin"}
|
||||
for _, username := range allUsers {
|
||||
if u, err := user.Lookup(username); err == nil {
|
||||
uid64, err := strconv.ParseInt(u.Uid, 10, 32)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
if uid64 > 100 { // Regular user
|
||||
return u, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// If no common users found, return an error
|
||||
return nil, fmt.Errorf("no suitable non-root user found on this system")
|
||||
}
|
||||
|
||||
func TestPrivilegeDropper_ExecuteWithPrivilegeDrop_Validation(t *testing.T) {
|
||||
pd := NewPrivilegeDropper()
|
||||
currentUID := uint32(os.Geteuid())
|
||||
|
||||
if currentUID == 0 {
|
||||
// When running as root, test that root can create commands for any user
|
||||
config := ExecutorConfig{
|
||||
UID: 1000, // Target non-root user
|
||||
GID: 1000,
|
||||
Groups: []uint32{1000},
|
||||
WorkingDir: "/tmp",
|
||||
Shell: "/bin/sh",
|
||||
Command: "echo test",
|
||||
}
|
||||
|
||||
cmd, err := pd.CreateExecutorCommand(context.Background(), config)
|
||||
assert.NoError(t, err, "Root should be able to create commands for any user")
|
||||
assert.NotNil(t, cmd)
|
||||
} else {
|
||||
// When running as non-root, test that we can't drop to a different user
|
||||
config := ExecutorConfig{
|
||||
UID: 0, // Try to target root
|
||||
GID: 0,
|
||||
Groups: []uint32{0},
|
||||
WorkingDir: "/tmp",
|
||||
Shell: "/bin/sh",
|
||||
Command: "echo test",
|
||||
}
|
||||
|
||||
_, err := pd.CreateExecutorCommand(context.Background(), config)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "cannot drop privileges")
|
||||
}
|
||||
}
|
||||
570
client/ssh/server/executor_windows.go
Normal file
570
client/ssh/server/executor_windows.go
Normal file
@@ -0,0 +1,570 @@
|
||||
//go:build windows
|
||||
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"os/user"
|
||||
"strings"
|
||||
"syscall"
|
||||
"unsafe"
|
||||
|
||||
"github.com/gliderlabs/ssh"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/sys/windows"
|
||||
)
|
||||
|
||||
const (
|
||||
ExitCodeSuccess = 0
|
||||
ExitCodeLogonFail = 10
|
||||
ExitCodeCreateProcessFail = 11
|
||||
ExitCodeWorkingDirFail = 12
|
||||
ExitCodeShellExecFail = 13
|
||||
ExitCodeValidationFail = 14
|
||||
)
|
||||
|
||||
type WindowsExecutorConfig struct {
|
||||
Username string
|
||||
Domain string
|
||||
WorkingDir string
|
||||
Shell string
|
||||
Command string
|
||||
Args []string
|
||||
Interactive bool
|
||||
Pty bool
|
||||
PtyWidth int
|
||||
PtyHeight int
|
||||
}
|
||||
|
||||
type PrivilegeDropper struct{}
|
||||
|
||||
func NewPrivilegeDropper() *PrivilegeDropper {
|
||||
return &PrivilegeDropper{}
|
||||
}
|
||||
|
||||
var (
|
||||
advapi32 = windows.NewLazyDLL("advapi32.dll")
|
||||
procAllocateLocallyUniqueId = advapi32.NewProc("AllocateLocallyUniqueId")
|
||||
)
|
||||
|
||||
const (
|
||||
logon32LogonNetwork = 3 // Network logon - no password required for authenticated users
|
||||
|
||||
// Common error messages
|
||||
commandFlag = "-Command"
|
||||
closeTokenErrorMsg = "close token error: %v" // #nosec G101 -- This is an error message template, not credentials
|
||||
convertUsernameError = "convert username to UTF16: %w"
|
||||
convertDomainError = "convert domain to UTF16: %w"
|
||||
)
|
||||
|
||||
// CreateWindowsExecutorCommand creates a Windows command with privilege dropping.
|
||||
// The caller must close the returned token handle after starting the process.
|
||||
func (pd *PrivilegeDropper) CreateWindowsExecutorCommand(ctx context.Context, config WindowsExecutorConfig) (*exec.Cmd, windows.Token, error) {
|
||||
if config.Username == "" {
|
||||
return nil, 0, errors.New("username cannot be empty")
|
||||
}
|
||||
if config.Shell == "" {
|
||||
return nil, 0, errors.New("shell cannot be empty")
|
||||
}
|
||||
|
||||
shell := config.Shell
|
||||
|
||||
var shellArgs []string
|
||||
if config.Command != "" {
|
||||
shellArgs = []string{shell, commandFlag, config.Command}
|
||||
} else {
|
||||
shellArgs = []string{shell}
|
||||
}
|
||||
|
||||
log.Tracef("creating Windows direct shell command: %s %v", shellArgs[0], shellArgs)
|
||||
|
||||
cmd, token, err := pd.CreateWindowsProcessAsUser(
|
||||
ctx, shellArgs[0], shellArgs, config.Username, config.Domain, config.WorkingDir)
|
||||
if err != nil {
|
||||
return nil, 0, fmt.Errorf("create Windows process as user: %w", err)
|
||||
}
|
||||
|
||||
return cmd, token, nil
|
||||
}
|
||||
|
||||
const (
|
||||
// StatusSuccess represents successful LSA operation
|
||||
StatusSuccess = 0
|
||||
|
||||
// KerbS4ULogonType message type for domain users with Kerberos
|
||||
KerbS4ULogonType = 12
|
||||
// Msv10s4ulogontype message type for local users with MSV1_0
|
||||
Msv10s4ulogontype = 12
|
||||
|
||||
// MicrosoftKerberosNameA is the authentication package name for Kerberos
|
||||
MicrosoftKerberosNameA = "Kerberos"
|
||||
// Msv10packagename is the authentication package name for MSV1_0
|
||||
Msv10packagename = "MICROSOFT_AUTHENTICATION_PACKAGE_V1_0"
|
||||
|
||||
NameSamCompatible = 2
|
||||
NameUserPrincipal = 8
|
||||
NameCanonical = 7
|
||||
|
||||
maxUPNLen = 1024
|
||||
)
|
||||
|
||||
// kerbS4ULogon structure for S4U authentication (domain users)
|
||||
type kerbS4ULogon struct {
|
||||
MessageType uint32
|
||||
Flags uint32
|
||||
ClientUpn unicodeString
|
||||
ClientRealm unicodeString
|
||||
}
|
||||
|
||||
// msv10s4ulogon structure for S4U authentication (local users)
|
||||
type msv10s4ulogon struct {
|
||||
MessageType uint32
|
||||
Flags uint32
|
||||
UserPrincipalName unicodeString
|
||||
DomainName unicodeString
|
||||
}
|
||||
|
||||
// unicodeString structure
|
||||
type unicodeString struct {
|
||||
Length uint16
|
||||
MaximumLength uint16
|
||||
Buffer *uint16
|
||||
}
|
||||
|
||||
// lsaString structure
|
||||
type lsaString struct {
|
||||
Length uint16
|
||||
MaximumLength uint16
|
||||
Buffer *byte
|
||||
}
|
||||
|
||||
// tokenSource structure
|
||||
type tokenSource struct {
|
||||
SourceName [8]byte
|
||||
SourceIdentifier windows.LUID
|
||||
}
|
||||
|
||||
// quotaLimits structure
|
||||
type quotaLimits struct {
|
||||
PagedPoolLimit uint32
|
||||
NonPagedPoolLimit uint32
|
||||
MinimumWorkingSetSize uint32
|
||||
MaximumWorkingSetSize uint32
|
||||
PagefileLimit uint32
|
||||
TimeLimit int64
|
||||
}
|
||||
|
||||
var (
|
||||
secur32 = windows.NewLazyDLL("secur32.dll")
|
||||
procLsaRegisterLogonProcess = secur32.NewProc("LsaRegisterLogonProcess")
|
||||
procLsaLookupAuthenticationPackage = secur32.NewProc("LsaLookupAuthenticationPackage")
|
||||
procLsaLogonUser = secur32.NewProc("LsaLogonUser")
|
||||
procLsaFreeReturnBuffer = secur32.NewProc("LsaFreeReturnBuffer")
|
||||
procLsaDeregisterLogonProcess = secur32.NewProc("LsaDeregisterLogonProcess")
|
||||
procTranslateNameW = secur32.NewProc("TranslateNameW")
|
||||
)
|
||||
|
||||
// newLsaString creates an LsaString from a Go string
|
||||
func newLsaString(s string) lsaString {
|
||||
b := append([]byte(s), 0)
|
||||
return lsaString{
|
||||
Length: uint16(len(s)),
|
||||
MaximumLength: uint16(len(b)),
|
||||
Buffer: &b[0],
|
||||
}
|
||||
}
|
||||
|
||||
// generateS4UUserToken creates a Windows token using S4U authentication
|
||||
// This is the exact approach OpenSSH for Windows uses for public key authentication
|
||||
func generateS4UUserToken(username, domain string) (windows.Handle, error) {
|
||||
userCpn := buildUserCpn(username, domain)
|
||||
|
||||
pd := NewPrivilegeDropper()
|
||||
isDomainUser := !pd.isLocalUser(domain)
|
||||
|
||||
lsaHandle, err := initializeLsaConnection()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
defer cleanupLsaConnection(lsaHandle)
|
||||
|
||||
authPackageId, err := lookupAuthenticationPackage(lsaHandle, isDomainUser)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
logonInfo, logonInfoSize, err := prepareS4ULogonStructure(username, domain, isDomainUser)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return performS4ULogon(lsaHandle, authPackageId, logonInfo, logonInfoSize, userCpn, isDomainUser)
|
||||
}
|
||||
|
||||
// buildUserCpn constructs the user principal name
|
||||
func buildUserCpn(username, domain string) string {
|
||||
if domain != "" && domain != "." {
|
||||
return fmt.Sprintf(`%s\%s`, domain, username)
|
||||
}
|
||||
return username
|
||||
}
|
||||
|
||||
// initializeLsaConnection establishes connection to LSA
|
||||
func initializeLsaConnection() (windows.Handle, error) {
|
||||
|
||||
processName := newLsaString("NetBird")
|
||||
var mode uint32
|
||||
var lsaHandle windows.Handle
|
||||
ret, _, _ := procLsaRegisterLogonProcess.Call(
|
||||
uintptr(unsafe.Pointer(&processName)),
|
||||
uintptr(unsafe.Pointer(&lsaHandle)),
|
||||
uintptr(unsafe.Pointer(&mode)),
|
||||
)
|
||||
if ret != StatusSuccess {
|
||||
return 0, fmt.Errorf("LsaRegisterLogonProcess: 0x%x", ret)
|
||||
}
|
||||
|
||||
return lsaHandle, nil
|
||||
}
|
||||
|
||||
// cleanupLsaConnection closes the LSA connection
|
||||
func cleanupLsaConnection(lsaHandle windows.Handle) {
|
||||
if ret, _, _ := procLsaDeregisterLogonProcess.Call(uintptr(lsaHandle)); ret != StatusSuccess {
|
||||
log.Debugf("LsaDeregisterLogonProcess failed: 0x%x", ret)
|
||||
}
|
||||
}
|
||||
|
||||
// lookupAuthenticationPackage finds the correct authentication package
|
||||
func lookupAuthenticationPackage(lsaHandle windows.Handle, isDomainUser bool) (uint32, error) {
|
||||
var authPackageName lsaString
|
||||
if isDomainUser {
|
||||
authPackageName = newLsaString(MicrosoftKerberosNameA)
|
||||
} else {
|
||||
authPackageName = newLsaString(Msv10packagename)
|
||||
}
|
||||
|
||||
var authPackageId uint32
|
||||
ret, _, _ := procLsaLookupAuthenticationPackage.Call(
|
||||
uintptr(lsaHandle),
|
||||
uintptr(unsafe.Pointer(&authPackageName)),
|
||||
uintptr(unsafe.Pointer(&authPackageId)),
|
||||
)
|
||||
if ret != StatusSuccess {
|
||||
return 0, fmt.Errorf("LsaLookupAuthenticationPackage: 0x%x", ret)
|
||||
}
|
||||
|
||||
return authPackageId, nil
|
||||
}
|
||||
|
||||
// lookupPrincipalName converts DOMAIN\username to username@domain.fqdn (UPN format)
|
||||
func lookupPrincipalName(username, domain string) (string, error) {
|
||||
samAccountName := fmt.Sprintf(`%s\%s`, domain, username)
|
||||
samAccountNameUtf16, err := windows.UTF16PtrFromString(samAccountName)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("convert SAM account name to UTF-16: %w", err)
|
||||
}
|
||||
|
||||
upnBuf := make([]uint16, maxUPNLen+1)
|
||||
upnSize := uint32(len(upnBuf))
|
||||
|
||||
ret, _, _ := procTranslateNameW.Call(
|
||||
uintptr(unsafe.Pointer(samAccountNameUtf16)),
|
||||
uintptr(NameSamCompatible),
|
||||
uintptr(NameUserPrincipal),
|
||||
uintptr(unsafe.Pointer(&upnBuf[0])),
|
||||
uintptr(unsafe.Pointer(&upnSize)),
|
||||
)
|
||||
|
||||
if ret != 0 {
|
||||
upn := windows.UTF16ToString(upnBuf[:upnSize])
|
||||
log.Debugf("Translated %s to explicit UPN: %s", samAccountName, upn)
|
||||
return upn, nil
|
||||
}
|
||||
|
||||
upnSize = uint32(len(upnBuf))
|
||||
ret, _, _ = procTranslateNameW.Call(
|
||||
uintptr(unsafe.Pointer(samAccountNameUtf16)),
|
||||
uintptr(NameSamCompatible),
|
||||
uintptr(NameCanonical),
|
||||
uintptr(unsafe.Pointer(&upnBuf[0])),
|
||||
uintptr(unsafe.Pointer(&upnSize)),
|
||||
)
|
||||
|
||||
if ret != 0 {
|
||||
canonical := windows.UTF16ToString(upnBuf[:upnSize])
|
||||
slashIdx := strings.IndexByte(canonical, '/')
|
||||
if slashIdx > 0 {
|
||||
fqdn := canonical[:slashIdx]
|
||||
upn := fmt.Sprintf("%s@%s", username, fqdn)
|
||||
log.Debugf("Translated %s to implicit UPN: %s (from canonical: %s)", samAccountName, upn, canonical)
|
||||
return upn, nil
|
||||
}
|
||||
}
|
||||
|
||||
log.Debugf("Could not translate %s to UPN, using SAM format", samAccountName)
|
||||
return samAccountName, nil
|
||||
}
|
||||
|
||||
// prepareS4ULogonStructure creates the appropriate S4U logon structure
|
||||
func prepareS4ULogonStructure(username, domain string, isDomainUser bool) (unsafe.Pointer, uintptr, error) {
|
||||
if isDomainUser {
|
||||
return prepareDomainS4ULogon(username, domain)
|
||||
}
|
||||
return prepareLocalS4ULogon(username)
|
||||
}
|
||||
|
||||
// prepareDomainS4ULogon creates S4U logon structure for domain users
|
||||
func prepareDomainS4ULogon(username, domain string) (unsafe.Pointer, uintptr, error) {
|
||||
upn, err := lookupPrincipalName(username, domain)
|
||||
if err != nil {
|
||||
return nil, 0, fmt.Errorf("lookup principal name: %w", err)
|
||||
}
|
||||
|
||||
log.Debugf("using KerbS4ULogon for domain user with UPN: %s", upn)
|
||||
|
||||
upnUtf16, err := windows.UTF16FromString(upn)
|
||||
if err != nil {
|
||||
return nil, 0, fmt.Errorf(convertUsernameError, err)
|
||||
}
|
||||
|
||||
structSize := unsafe.Sizeof(kerbS4ULogon{})
|
||||
upnByteSize := len(upnUtf16) * 2
|
||||
logonInfoSize := structSize + uintptr(upnByteSize)
|
||||
|
||||
buffer := make([]byte, logonInfoSize)
|
||||
logonInfo := unsafe.Pointer(&buffer[0])
|
||||
|
||||
s4uLogon := (*kerbS4ULogon)(logonInfo)
|
||||
s4uLogon.MessageType = KerbS4ULogonType
|
||||
s4uLogon.Flags = 0
|
||||
|
||||
upnOffset := structSize
|
||||
upnBuffer := (*uint16)(unsafe.Pointer(uintptr(logonInfo) + upnOffset))
|
||||
copy((*[1025]uint16)(unsafe.Pointer(upnBuffer))[:len(upnUtf16)], upnUtf16)
|
||||
|
||||
s4uLogon.ClientUpn = unicodeString{
|
||||
Length: uint16((len(upnUtf16) - 1) * 2),
|
||||
MaximumLength: uint16(len(upnUtf16) * 2),
|
||||
Buffer: upnBuffer,
|
||||
}
|
||||
s4uLogon.ClientRealm = unicodeString{}
|
||||
|
||||
return logonInfo, logonInfoSize, nil
|
||||
}
|
||||
|
||||
// prepareLocalS4ULogon creates S4U logon structure for local users
|
||||
func prepareLocalS4ULogon(username string) (unsafe.Pointer, uintptr, error) {
|
||||
log.Debugf("using Msv1_0S4ULogon for local user: %s", username)
|
||||
|
||||
usernameUtf16, err := windows.UTF16FromString(username)
|
||||
if err != nil {
|
||||
return nil, 0, fmt.Errorf(convertUsernameError, err)
|
||||
}
|
||||
|
||||
domainUtf16, err := windows.UTF16FromString(".")
|
||||
if err != nil {
|
||||
return nil, 0, fmt.Errorf(convertDomainError, err)
|
||||
}
|
||||
|
||||
structSize := unsafe.Sizeof(msv10s4ulogon{})
|
||||
usernameByteSize := len(usernameUtf16) * 2
|
||||
domainByteSize := len(domainUtf16) * 2
|
||||
logonInfoSize := structSize + uintptr(usernameByteSize) + uintptr(domainByteSize)
|
||||
|
||||
buffer := make([]byte, logonInfoSize)
|
||||
logonInfo := unsafe.Pointer(&buffer[0])
|
||||
|
||||
s4uLogon := (*msv10s4ulogon)(logonInfo)
|
||||
s4uLogon.MessageType = Msv10s4ulogontype
|
||||
s4uLogon.Flags = 0x0
|
||||
|
||||
usernameOffset := structSize
|
||||
usernameBuffer := (*uint16)(unsafe.Pointer(uintptr(logonInfo) + usernameOffset))
|
||||
copy((*[256]uint16)(unsafe.Pointer(usernameBuffer))[:len(usernameUtf16)], usernameUtf16)
|
||||
|
||||
s4uLogon.UserPrincipalName = unicodeString{
|
||||
Length: uint16((len(usernameUtf16) - 1) * 2),
|
||||
MaximumLength: uint16(len(usernameUtf16) * 2),
|
||||
Buffer: usernameBuffer,
|
||||
}
|
||||
|
||||
domainOffset := usernameOffset + uintptr(usernameByteSize)
|
||||
domainBuffer := (*uint16)(unsafe.Pointer(uintptr(logonInfo) + domainOffset))
|
||||
copy((*[16]uint16)(unsafe.Pointer(domainBuffer))[:len(domainUtf16)], domainUtf16)
|
||||
|
||||
s4uLogon.DomainName = unicodeString{
|
||||
Length: uint16((len(domainUtf16) - 1) * 2),
|
||||
MaximumLength: uint16(len(domainUtf16) * 2),
|
||||
Buffer: domainBuffer,
|
||||
}
|
||||
|
||||
return logonInfo, logonInfoSize, nil
|
||||
}
|
||||
|
||||
// performS4ULogon executes the S4U logon operation
|
||||
func performS4ULogon(lsaHandle windows.Handle, authPackageId uint32, logonInfo unsafe.Pointer, logonInfoSize uintptr, userCpn string, isDomainUser bool) (windows.Handle, error) {
|
||||
var tokenSource tokenSource
|
||||
copy(tokenSource.SourceName[:], "netbird")
|
||||
if ret, _, _ := procAllocateLocallyUniqueId.Call(uintptr(unsafe.Pointer(&tokenSource.SourceIdentifier))); ret == 0 {
|
||||
log.Debugf("AllocateLocallyUniqueId failed")
|
||||
}
|
||||
|
||||
originName := newLsaString("netbird")
|
||||
|
||||
var profile uintptr
|
||||
var profileSize uint32
|
||||
var logonId windows.LUID
|
||||
var token windows.Handle
|
||||
var quotas quotaLimits
|
||||
var subStatus int32
|
||||
|
||||
ret, _, _ := procLsaLogonUser.Call(
|
||||
uintptr(lsaHandle),
|
||||
uintptr(unsafe.Pointer(&originName)),
|
||||
logon32LogonNetwork,
|
||||
uintptr(authPackageId),
|
||||
uintptr(logonInfo),
|
||||
logonInfoSize,
|
||||
0,
|
||||
uintptr(unsafe.Pointer(&tokenSource)),
|
||||
uintptr(unsafe.Pointer(&profile)),
|
||||
uintptr(unsafe.Pointer(&profileSize)),
|
||||
uintptr(unsafe.Pointer(&logonId)),
|
||||
uintptr(unsafe.Pointer(&token)),
|
||||
uintptr(unsafe.Pointer("as)),
|
||||
uintptr(unsafe.Pointer(&subStatus)),
|
||||
)
|
||||
|
||||
if profile != 0 {
|
||||
if ret, _, _ := procLsaFreeReturnBuffer.Call(profile); ret != StatusSuccess {
|
||||
log.Debugf("LsaFreeReturnBuffer failed: 0x%x", ret)
|
||||
}
|
||||
}
|
||||
|
||||
if ret != StatusSuccess {
|
||||
return 0, fmt.Errorf("LsaLogonUser S4U for %s: NTSTATUS=0x%x, SubStatus=0x%x", userCpn, ret, subStatus)
|
||||
}
|
||||
|
||||
log.Debugf("created S4U %s token for user %s",
|
||||
map[bool]string{true: "domain", false: "local"}[isDomainUser], userCpn)
|
||||
return token, nil
|
||||
}
|
||||
|
||||
// createToken implements NetBird trust-based authentication using S4U
|
||||
func (pd *PrivilegeDropper) createToken(username, domain string) (windows.Handle, error) {
|
||||
fullUsername := buildUserCpn(username, domain)
|
||||
|
||||
if err := userExists(fullUsername, username, domain); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
isLocalUser := pd.isLocalUser(domain)
|
||||
|
||||
if isLocalUser {
|
||||
return pd.authenticateLocalUser(username, fullUsername)
|
||||
}
|
||||
return pd.authenticateDomainUser(username, domain, fullUsername)
|
||||
}
|
||||
|
||||
// userExists checks if the target useVerifier exists on the system
|
||||
func userExists(fullUsername, username, domain string) error {
|
||||
if _, err := lookupUser(fullUsername); err != nil {
|
||||
log.Debugf("User %s not found: %v", fullUsername, err)
|
||||
if domain != "" && domain != "." {
|
||||
_, err = lookupUser(username)
|
||||
}
|
||||
if err != nil {
|
||||
return fmt.Errorf("target user %s not found: %w", fullUsername, err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// isLocalUser determines if this is a local user vs domain user
|
||||
func (pd *PrivilegeDropper) isLocalUser(domain string) bool {
|
||||
hostname, err := os.Hostname()
|
||||
if err != nil {
|
||||
hostname = "localhost"
|
||||
}
|
||||
|
||||
return domain == "" || domain == "." ||
|
||||
strings.EqualFold(domain, hostname)
|
||||
}
|
||||
|
||||
// authenticateLocalUser handles authentication for local users
|
||||
func (pd *PrivilegeDropper) authenticateLocalUser(username, fullUsername string) (windows.Handle, error) {
|
||||
log.Debugf("using S4U authentication for local user %s", fullUsername)
|
||||
token, err := generateS4UUserToken(username, ".")
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("S4U authentication for local user %s: %w", fullUsername, err)
|
||||
}
|
||||
return token, nil
|
||||
}
|
||||
|
||||
// authenticateDomainUser handles authentication for domain users
|
||||
func (pd *PrivilegeDropper) authenticateDomainUser(username, domain, fullUsername string) (windows.Handle, error) {
|
||||
log.Debugf("using S4U authentication for domain user %s", fullUsername)
|
||||
token, err := generateS4UUserToken(username, domain)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("S4U authentication for domain user %s: %w", fullUsername, err)
|
||||
}
|
||||
log.Debugf("Successfully created S4U token for domain user %s", fullUsername)
|
||||
return token, nil
|
||||
}
|
||||
|
||||
// CreateWindowsProcessAsUser creates a process as user with safe argument passing (for SFTP and executables).
|
||||
// The caller must close the returned token handle after starting the process.
|
||||
func (pd *PrivilegeDropper) CreateWindowsProcessAsUser(ctx context.Context, executablePath string, args []string, username, domain, workingDir string) (*exec.Cmd, windows.Token, error) {
|
||||
token, err := pd.createToken(username, domain)
|
||||
if err != nil {
|
||||
return nil, 0, fmt.Errorf("user authentication: %w", err)
|
||||
}
|
||||
|
||||
defer func() {
|
||||
if err := windows.CloseHandle(token); err != nil {
|
||||
log.Debugf("close impersonation token: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
cmd, primaryToken, err := pd.createProcessWithToken(ctx, windows.Token(token), executablePath, args, workingDir)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
return cmd, primaryToken, nil
|
||||
}
|
||||
|
||||
// createProcessWithToken creates process with the specified token and executable path.
|
||||
// The caller must close the returned token handle after starting the process.
|
||||
func (pd *PrivilegeDropper) createProcessWithToken(ctx context.Context, sourceToken windows.Token, executablePath string, args []string, workingDir string) (*exec.Cmd, windows.Token, error) {
|
||||
cmd := exec.CommandContext(ctx, executablePath, args[1:]...)
|
||||
cmd.Dir = workingDir
|
||||
|
||||
var primaryToken windows.Token
|
||||
err := windows.DuplicateTokenEx(
|
||||
sourceToken,
|
||||
windows.TOKEN_ALL_ACCESS,
|
||||
nil,
|
||||
windows.SecurityIdentification,
|
||||
windows.TokenPrimary,
|
||||
&primaryToken,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, 0, fmt.Errorf("duplicate token to primary token: %w", err)
|
||||
}
|
||||
|
||||
cmd.SysProcAttr = &syscall.SysProcAttr{
|
||||
Token: syscall.Token(primaryToken),
|
||||
}
|
||||
|
||||
return cmd, primaryToken, nil
|
||||
}
|
||||
|
||||
// createSuCommand creates a command using su -l -c for privilege switching (Windows stub)
|
||||
func (s *Server) createSuCommand(ssh.Session, *user.User, bool) (*exec.Cmd, error) {
|
||||
return nil, fmt.Errorf("su command not available on Windows")
|
||||
}
|
||||
629
client/ssh/server/jwt_test.go
Normal file
629
client/ssh/server/jwt_test.go
Normal file
@@ -0,0 +1,629 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"math/big"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
cryptossh "golang.org/x/crypto/ssh"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
nbssh "github.com/netbirdio/netbird/client/ssh"
|
||||
"github.com/netbirdio/netbird/client/ssh/client"
|
||||
"github.com/netbirdio/netbird/client/ssh/detection"
|
||||
"github.com/netbirdio/netbird/client/ssh/testutil"
|
||||
nbjwt "github.com/netbirdio/netbird/shared/auth/jwt"
|
||||
)
|
||||
|
||||
func TestJWTEnforcement(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("Skipping JWT enforcement tests in short mode")
|
||||
}
|
||||
|
||||
// Set up SSH server
|
||||
hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Run("blocks_without_jwt", func(t *testing.T) {
|
||||
jwtConfig := &JWTConfig{
|
||||
Issuer: "test-issuer",
|
||||
Audience: "test-audience",
|
||||
KeysLocation: "test-keys",
|
||||
}
|
||||
serverConfig := &Config{
|
||||
HostKeyPEM: hostKey,
|
||||
JWT: jwtConfig,
|
||||
}
|
||||
server := New(serverConfig)
|
||||
server.SetAllowRootLogin(true)
|
||||
|
||||
serverAddr := StartTestServer(t, server)
|
||||
defer require.NoError(t, server.Stop())
|
||||
|
||||
host, portStr, err := net.SplitHostPort(serverAddr)
|
||||
require.NoError(t, err)
|
||||
port, err := strconv.Atoi(portStr)
|
||||
require.NoError(t, err)
|
||||
dialer := &net.Dialer{Timeout: detection.Timeout}
|
||||
serverType, err := detection.DetectSSHServerType(context.Background(), dialer, host, port)
|
||||
if err != nil {
|
||||
t.Logf("Detection failed: %v", err)
|
||||
}
|
||||
t.Logf("Detected server type: %s", serverType)
|
||||
|
||||
config := &cryptossh.ClientConfig{
|
||||
User: testutil.GetTestUsername(t),
|
||||
Auth: []cryptossh.AuthMethod{},
|
||||
HostKeyCallback: cryptossh.InsecureIgnoreHostKey(),
|
||||
Timeout: 2 * time.Second,
|
||||
}
|
||||
|
||||
_, err = cryptossh.Dial("tcp", net.JoinHostPort(host, portStr), config)
|
||||
assert.Error(t, err, "SSH connection should fail when JWT is required but not provided")
|
||||
})
|
||||
|
||||
t.Run("allows_when_disabled", func(t *testing.T) {
|
||||
serverConfigNoJWT := &Config{
|
||||
HostKeyPEM: hostKey,
|
||||
JWT: nil,
|
||||
}
|
||||
serverNoJWT := New(serverConfigNoJWT)
|
||||
require.False(t, serverNoJWT.jwtEnabled, "JWT should be disabled without config")
|
||||
serverNoJWT.SetAllowRootLogin(true)
|
||||
|
||||
serverAddrNoJWT := StartTestServer(t, serverNoJWT)
|
||||
defer require.NoError(t, serverNoJWT.Stop())
|
||||
|
||||
hostNoJWT, portStrNoJWT, err := net.SplitHostPort(serverAddrNoJWT)
|
||||
require.NoError(t, err)
|
||||
portNoJWT, err := strconv.Atoi(portStrNoJWT)
|
||||
require.NoError(t, err)
|
||||
|
||||
dialer := &net.Dialer{Timeout: detection.Timeout}
|
||||
serverType, err := detection.DetectSSHServerType(context.Background(), dialer, hostNoJWT, portNoJWT)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, detection.ServerTypeNetBirdNoJWT, serverType)
|
||||
assert.False(t, serverType.RequiresJWT())
|
||||
|
||||
client, err := connectWithNetBirdClient(t, hostNoJWT, portNoJWT)
|
||||
require.NoError(t, err)
|
||||
defer client.Close()
|
||||
})
|
||||
|
||||
}
|
||||
|
||||
// setupJWKSServer creates a test HTTP server serving JWKS and returns the server, private key, and URL
|
||||
func setupJWKSServer(t *testing.T) (*httptest.Server, *rsa.PrivateKey, string) {
|
||||
privateKey, jwksJSON := generateTestJWKS(t)
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
if _, err := w.Write(jwksJSON); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
}
|
||||
}))
|
||||
|
||||
return server, privateKey, server.URL
|
||||
}
|
||||
|
||||
// generateTestJWKS creates a test RSA key pair and returns private key and JWKS JSON
|
||||
func generateTestJWKS(t *testing.T) (*rsa.PrivateKey, []byte) {
|
||||
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
require.NoError(t, err)
|
||||
|
||||
publicKey := &privateKey.PublicKey
|
||||
n := publicKey.N.Bytes()
|
||||
e := publicKey.E
|
||||
|
||||
jwk := nbjwt.JSONWebKey{
|
||||
Kty: "RSA",
|
||||
Kid: "test-key-id",
|
||||
Use: "sig",
|
||||
N: base64RawURLEncode(n),
|
||||
E: base64RawURLEncode(big.NewInt(int64(e)).Bytes()),
|
||||
}
|
||||
|
||||
jwks := nbjwt.Jwks{
|
||||
Keys: []nbjwt.JSONWebKey{jwk},
|
||||
}
|
||||
|
||||
jwksJSON, err := json.Marshal(jwks)
|
||||
require.NoError(t, err)
|
||||
|
||||
return privateKey, jwksJSON
|
||||
}
|
||||
|
||||
func base64RawURLEncode(data []byte) string {
|
||||
return base64.RawURLEncoding.EncodeToString(data)
|
||||
}
|
||||
|
||||
// generateValidJWT creates a valid JWT token for testing
|
||||
func generateValidJWT(t *testing.T, privateKey *rsa.PrivateKey, issuer, audience string) string {
|
||||
claims := jwt.MapClaims{
|
||||
"iss": issuer,
|
||||
"aud": audience,
|
||||
"sub": "test-user",
|
||||
"exp": time.Now().Add(time.Hour).Unix(),
|
||||
"iat": time.Now().Unix(),
|
||||
}
|
||||
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims)
|
||||
token.Header["kid"] = "test-key-id"
|
||||
|
||||
tokenString, err := token.SignedString(privateKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
return tokenString
|
||||
}
|
||||
|
||||
// connectWithNetBirdClient connects to SSH server using NetBird's SSH client
|
||||
func connectWithNetBirdClient(t *testing.T, host string, port int) (*client.Client, error) {
|
||||
t.Helper()
|
||||
addr := net.JoinHostPort(host, strconv.Itoa(port))
|
||||
|
||||
ctx := context.Background()
|
||||
return client.Dial(ctx, addr, testutil.GetTestUsername(t), client.DialOptions{
|
||||
InsecureSkipVerify: true,
|
||||
})
|
||||
}
|
||||
|
||||
// TestJWTDetection tests that server detection correctly identifies JWT-enabled servers
|
||||
func TestJWTDetection(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("Skipping JWT detection test in short mode")
|
||||
}
|
||||
|
||||
jwksServer, _, jwksURL := setupJWKSServer(t)
|
||||
defer jwksServer.Close()
|
||||
|
||||
hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
||||
require.NoError(t, err)
|
||||
|
||||
const (
|
||||
issuer = "https://test-issuer.example.com"
|
||||
audience = "test-audience"
|
||||
)
|
||||
|
||||
jwtConfig := &JWTConfig{
|
||||
Issuer: issuer,
|
||||
Audience: audience,
|
||||
KeysLocation: jwksURL,
|
||||
}
|
||||
serverConfig := &Config{
|
||||
HostKeyPEM: hostKey,
|
||||
JWT: jwtConfig,
|
||||
}
|
||||
server := New(serverConfig)
|
||||
server.SetAllowRootLogin(true)
|
||||
|
||||
serverAddr := StartTestServer(t, server)
|
||||
defer require.NoError(t, server.Stop())
|
||||
|
||||
host, portStr, err := net.SplitHostPort(serverAddr)
|
||||
require.NoError(t, err)
|
||||
port, err := strconv.Atoi(portStr)
|
||||
require.NoError(t, err)
|
||||
|
||||
dialer := &net.Dialer{Timeout: detection.Timeout}
|
||||
serverType, err := detection.DetectSSHServerType(context.Background(), dialer, host, port)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, detection.ServerTypeNetBirdJWT, serverType)
|
||||
assert.True(t, serverType.RequiresJWT())
|
||||
}
|
||||
|
||||
func TestJWTFailClose(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("Skipping JWT fail-close tests in short mode")
|
||||
}
|
||||
|
||||
jwksServer, privateKey, jwksURL := setupJWKSServer(t)
|
||||
defer jwksServer.Close()
|
||||
|
||||
const (
|
||||
issuer = "https://test-issuer.example.com"
|
||||
audience = "test-audience"
|
||||
)
|
||||
|
||||
hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
||||
require.NoError(t, err)
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
tokenClaims jwt.MapClaims
|
||||
}{
|
||||
{
|
||||
name: "blocks_token_missing_iat",
|
||||
tokenClaims: jwt.MapClaims{
|
||||
"iss": issuer,
|
||||
"aud": audience,
|
||||
"sub": "test-user",
|
||||
"exp": time.Now().Add(time.Hour).Unix(),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "blocks_token_missing_sub",
|
||||
tokenClaims: jwt.MapClaims{
|
||||
"iss": issuer,
|
||||
"aud": audience,
|
||||
"exp": time.Now().Add(time.Hour).Unix(),
|
||||
"iat": time.Now().Unix(),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "blocks_token_missing_iss",
|
||||
tokenClaims: jwt.MapClaims{
|
||||
"aud": audience,
|
||||
"sub": "test-user",
|
||||
"exp": time.Now().Add(time.Hour).Unix(),
|
||||
"iat": time.Now().Unix(),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "blocks_token_missing_aud",
|
||||
tokenClaims: jwt.MapClaims{
|
||||
"iss": issuer,
|
||||
"sub": "test-user",
|
||||
"exp": time.Now().Add(time.Hour).Unix(),
|
||||
"iat": time.Now().Unix(),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "blocks_token_wrong_issuer",
|
||||
tokenClaims: jwt.MapClaims{
|
||||
"iss": "wrong-issuer",
|
||||
"aud": audience,
|
||||
"sub": "test-user",
|
||||
"exp": time.Now().Add(time.Hour).Unix(),
|
||||
"iat": time.Now().Unix(),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "blocks_token_wrong_audience",
|
||||
tokenClaims: jwt.MapClaims{
|
||||
"iss": issuer,
|
||||
"aud": "wrong-audience",
|
||||
"sub": "test-user",
|
||||
"exp": time.Now().Add(time.Hour).Unix(),
|
||||
"iat": time.Now().Unix(),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "blocks_expired_token",
|
||||
tokenClaims: jwt.MapClaims{
|
||||
"iss": issuer,
|
||||
"aud": audience,
|
||||
"sub": "test-user",
|
||||
"exp": time.Now().Add(-time.Hour).Unix(),
|
||||
"iat": time.Now().Add(-2 * time.Hour).Unix(),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "blocks_token_exceeding_max_age",
|
||||
tokenClaims: jwt.MapClaims{
|
||||
"iss": issuer,
|
||||
"aud": audience,
|
||||
"sub": "test-user",
|
||||
"exp": time.Now().Add(time.Hour).Unix(),
|
||||
"iat": time.Now().Add(-2 * time.Hour).Unix(),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
jwtConfig := &JWTConfig{
|
||||
Issuer: issuer,
|
||||
Audience: audience,
|
||||
KeysLocation: jwksURL,
|
||||
MaxTokenAge: 3600,
|
||||
}
|
||||
serverConfig := &Config{
|
||||
HostKeyPEM: hostKey,
|
||||
JWT: jwtConfig,
|
||||
}
|
||||
server := New(serverConfig)
|
||||
server.SetAllowRootLogin(true)
|
||||
|
||||
serverAddr := StartTestServer(t, server)
|
||||
defer require.NoError(t, server.Stop())
|
||||
|
||||
host, portStr, err := net.SplitHostPort(serverAddr)
|
||||
require.NoError(t, err)
|
||||
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodRS256, tc.tokenClaims)
|
||||
token.Header["kid"] = "test-key-id"
|
||||
tokenString, err := token.SignedString(privateKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
config := &cryptossh.ClientConfig{
|
||||
User: testutil.GetTestUsername(t),
|
||||
Auth: []cryptossh.AuthMethod{
|
||||
cryptossh.Password(tokenString),
|
||||
},
|
||||
HostKeyCallback: cryptossh.InsecureIgnoreHostKey(),
|
||||
Timeout: 2 * time.Second,
|
||||
}
|
||||
|
||||
conn, err := cryptossh.Dial("tcp", net.JoinHostPort(host, portStr), config)
|
||||
if conn != nil {
|
||||
defer func() {
|
||||
if err := conn.Close(); err != nil {
|
||||
t.Logf("close connection: %v", err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
assert.Error(t, err, "Authentication should fail (fail-close)")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestJWTAuthentication tests JWT authentication with valid/invalid tokens and enforcement for various connection types
|
||||
func TestJWTAuthentication(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("Skipping JWT authentication tests in short mode")
|
||||
}
|
||||
|
||||
jwksServer, privateKey, jwksURL := setupJWKSServer(t)
|
||||
defer jwksServer.Close()
|
||||
|
||||
const (
|
||||
issuer = "https://test-issuer.example.com"
|
||||
audience = "test-audience"
|
||||
)
|
||||
|
||||
hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
||||
require.NoError(t, err)
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
token string
|
||||
wantAuthOK bool
|
||||
setupServer func(*Server)
|
||||
testOperation func(*testing.T, *cryptossh.Client, string) error
|
||||
wantOpSuccess bool
|
||||
}{
|
||||
{
|
||||
name: "allows_shell_with_jwt",
|
||||
token: "valid",
|
||||
wantAuthOK: true,
|
||||
setupServer: func(s *Server) {
|
||||
s.SetAllowRootLogin(true)
|
||||
},
|
||||
testOperation: func(t *testing.T, conn *cryptossh.Client, _ string) error {
|
||||
session, err := conn.NewSession()
|
||||
require.NoError(t, err)
|
||||
defer session.Close()
|
||||
return session.Shell()
|
||||
},
|
||||
wantOpSuccess: true,
|
||||
},
|
||||
{
|
||||
name: "rejects_invalid_token",
|
||||
token: "invalid",
|
||||
wantAuthOK: false,
|
||||
setupServer: func(s *Server) {
|
||||
s.SetAllowRootLogin(true)
|
||||
},
|
||||
testOperation: func(t *testing.T, conn *cryptossh.Client, _ string) error {
|
||||
session, err := conn.NewSession()
|
||||
require.NoError(t, err)
|
||||
defer session.Close()
|
||||
|
||||
output, err := session.CombinedOutput("echo test")
|
||||
if err != nil {
|
||||
t.Logf("Command output: %s", string(output))
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
},
|
||||
wantOpSuccess: false,
|
||||
},
|
||||
{
|
||||
name: "blocks_shell_without_jwt",
|
||||
token: "",
|
||||
wantAuthOK: false,
|
||||
setupServer: func(s *Server) {
|
||||
s.SetAllowRootLogin(true)
|
||||
},
|
||||
testOperation: func(t *testing.T, conn *cryptossh.Client, _ string) error {
|
||||
session, err := conn.NewSession()
|
||||
require.NoError(t, err)
|
||||
defer session.Close()
|
||||
|
||||
output, err := session.CombinedOutput("echo test")
|
||||
if err != nil {
|
||||
t.Logf("Command output: %s", string(output))
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
},
|
||||
wantOpSuccess: false,
|
||||
},
|
||||
{
|
||||
name: "blocks_command_without_jwt",
|
||||
token: "",
|
||||
wantAuthOK: false,
|
||||
setupServer: func(s *Server) {
|
||||
s.SetAllowRootLogin(true)
|
||||
},
|
||||
testOperation: func(t *testing.T, conn *cryptossh.Client, _ string) error {
|
||||
session, err := conn.NewSession()
|
||||
require.NoError(t, err)
|
||||
defer session.Close()
|
||||
|
||||
output, err := session.CombinedOutput("ls")
|
||||
if err != nil {
|
||||
t.Logf("Command output: %s", string(output))
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
},
|
||||
wantOpSuccess: false,
|
||||
},
|
||||
{
|
||||
name: "allows_sftp_with_jwt",
|
||||
token: "valid",
|
||||
wantAuthOK: true,
|
||||
setupServer: func(s *Server) {
|
||||
s.SetAllowRootLogin(true)
|
||||
s.SetAllowSFTP(true)
|
||||
},
|
||||
testOperation: func(t *testing.T, conn *cryptossh.Client, _ string) error {
|
||||
session, err := conn.NewSession()
|
||||
require.NoError(t, err)
|
||||
defer session.Close()
|
||||
|
||||
session.Stdout = io.Discard
|
||||
session.Stderr = io.Discard
|
||||
return session.RequestSubsystem("sftp")
|
||||
},
|
||||
wantOpSuccess: true,
|
||||
},
|
||||
{
|
||||
name: "blocks_sftp_without_jwt",
|
||||
token: "",
|
||||
wantAuthOK: false,
|
||||
setupServer: func(s *Server) {
|
||||
s.SetAllowRootLogin(true)
|
||||
s.SetAllowSFTP(true)
|
||||
},
|
||||
testOperation: func(t *testing.T, conn *cryptossh.Client, _ string) error {
|
||||
session, err := conn.NewSession()
|
||||
require.NoError(t, err)
|
||||
defer session.Close()
|
||||
|
||||
session.Stdout = io.Discard
|
||||
session.Stderr = io.Discard
|
||||
err = session.RequestSubsystem("sftp")
|
||||
if err == nil {
|
||||
err = session.Wait()
|
||||
}
|
||||
return err
|
||||
},
|
||||
wantOpSuccess: false,
|
||||
},
|
||||
{
|
||||
name: "allows_port_forward_with_jwt",
|
||||
token: "valid",
|
||||
wantAuthOK: true,
|
||||
setupServer: func(s *Server) {
|
||||
s.SetAllowRootLogin(true)
|
||||
s.SetAllowRemotePortForwarding(true)
|
||||
},
|
||||
testOperation: func(t *testing.T, conn *cryptossh.Client, _ string) error {
|
||||
ln, err := conn.Listen("tcp", "127.0.0.1:0")
|
||||
if ln != nil {
|
||||
defer ln.Close()
|
||||
}
|
||||
return err
|
||||
},
|
||||
wantOpSuccess: true,
|
||||
},
|
||||
{
|
||||
name: "blocks_port_forward_without_jwt",
|
||||
token: "",
|
||||
wantAuthOK: false,
|
||||
setupServer: func(s *Server) {
|
||||
s.SetAllowRootLogin(true)
|
||||
s.SetAllowLocalPortForwarding(true)
|
||||
},
|
||||
testOperation: func(t *testing.T, conn *cryptossh.Client, _ string) error {
|
||||
ln, err := conn.Listen("tcp", "127.0.0.1:0")
|
||||
if ln != nil {
|
||||
defer ln.Close()
|
||||
}
|
||||
return err
|
||||
},
|
||||
wantOpSuccess: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
// TODO: Skip port forwarding tests on Windows - user switching not supported
|
||||
// These features are tested on Linux/Unix platforms
|
||||
if runtime.GOOS == "windows" &&
|
||||
(tc.name == "allows_port_forward_with_jwt" ||
|
||||
tc.name == "blocks_port_forward_without_jwt") {
|
||||
t.Skip("Skipping port forwarding test on Windows - covered by Linux tests")
|
||||
}
|
||||
|
||||
jwtConfig := &JWTConfig{
|
||||
Issuer: issuer,
|
||||
Audience: audience,
|
||||
KeysLocation: jwksURL,
|
||||
}
|
||||
serverConfig := &Config{
|
||||
HostKeyPEM: hostKey,
|
||||
JWT: jwtConfig,
|
||||
}
|
||||
server := New(serverConfig)
|
||||
if tc.setupServer != nil {
|
||||
tc.setupServer(server)
|
||||
}
|
||||
|
||||
serverAddr := StartTestServer(t, server)
|
||||
defer require.NoError(t, server.Stop())
|
||||
|
||||
host, portStr, err := net.SplitHostPort(serverAddr)
|
||||
require.NoError(t, err)
|
||||
|
||||
var authMethods []cryptossh.AuthMethod
|
||||
if tc.token == "valid" {
|
||||
token := generateValidJWT(t, privateKey, issuer, audience)
|
||||
authMethods = []cryptossh.AuthMethod{
|
||||
cryptossh.Password(token),
|
||||
}
|
||||
} else if tc.token == "invalid" {
|
||||
invalidToken := "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDIyfQ.invalid"
|
||||
authMethods = []cryptossh.AuthMethod{
|
||||
cryptossh.Password(invalidToken),
|
||||
}
|
||||
}
|
||||
|
||||
config := &cryptossh.ClientConfig{
|
||||
User: testutil.GetTestUsername(t),
|
||||
Auth: authMethods,
|
||||
HostKeyCallback: cryptossh.InsecureIgnoreHostKey(),
|
||||
Timeout: 2 * time.Second,
|
||||
}
|
||||
|
||||
conn, err := cryptossh.Dial("tcp", net.JoinHostPort(host, portStr), config)
|
||||
if tc.wantAuthOK {
|
||||
require.NoError(t, err, "JWT authentication should succeed")
|
||||
} else if err != nil {
|
||||
t.Logf("Connection failed as expected: %v", err)
|
||||
return
|
||||
}
|
||||
if conn != nil {
|
||||
defer func() {
|
||||
if err := conn.Close(); err != nil {
|
||||
t.Logf("close connection: %v", err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
err = tc.testOperation(t, conn, serverAddr)
|
||||
if tc.wantOpSuccess {
|
||||
require.NoError(t, err, "Operation should succeed")
|
||||
} else {
|
||||
assert.Error(t, err, "Operation should fail")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
386
client/ssh/server/port_forwarding.go
Normal file
386
client/ssh/server/port_forwarding.go
Normal file
@@ -0,0 +1,386 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"strconv"
|
||||
|
||||
"github.com/gliderlabs/ssh"
|
||||
log "github.com/sirupsen/logrus"
|
||||
cryptossh "golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
// SessionKey uniquely identifies an SSH session
|
||||
type SessionKey string
|
||||
|
||||
// ConnectionKey uniquely identifies a port forwarding connection within a session
|
||||
type ConnectionKey string
|
||||
|
||||
// ForwardKey uniquely identifies a port forwarding listener
|
||||
type ForwardKey string
|
||||
|
||||
// tcpipForwardMsg represents the structure for tcpip-forward SSH requests
|
||||
type tcpipForwardMsg struct {
|
||||
Host string
|
||||
Port uint32
|
||||
}
|
||||
|
||||
// SetAllowLocalPortForwarding configures local port forwarding
|
||||
func (s *Server) SetAllowLocalPortForwarding(allow bool) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.allowLocalPortForwarding = allow
|
||||
}
|
||||
|
||||
// SetAllowRemotePortForwarding configures remote port forwarding
|
||||
func (s *Server) SetAllowRemotePortForwarding(allow bool) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.allowRemotePortForwarding = allow
|
||||
}
|
||||
|
||||
// configurePortForwarding sets up port forwarding callbacks
|
||||
func (s *Server) configurePortForwarding(server *ssh.Server) {
|
||||
allowLocal := s.allowLocalPortForwarding
|
||||
allowRemote := s.allowRemotePortForwarding
|
||||
|
||||
server.LocalPortForwardingCallback = func(ctx ssh.Context, dstHost string, dstPort uint32) bool {
|
||||
if !allowLocal {
|
||||
log.Warnf("local port forwarding denied for %s from %s: disabled by configuration",
|
||||
net.JoinHostPort(dstHost, fmt.Sprintf("%d", dstPort)), ctx.RemoteAddr())
|
||||
return false
|
||||
}
|
||||
|
||||
if err := s.checkPortForwardingPrivileges(ctx, "local", dstPort); err != nil {
|
||||
log.Warnf("local port forwarding denied for %s:%d from %s: %v", dstHost, dstPort, ctx.RemoteAddr(), err)
|
||||
return false
|
||||
}
|
||||
|
||||
log.Debugf("local port forwarding allowed: %s:%d", dstHost, dstPort)
|
||||
return true
|
||||
}
|
||||
|
||||
server.ReversePortForwardingCallback = func(ctx ssh.Context, bindHost string, bindPort uint32) bool {
|
||||
if !allowRemote {
|
||||
log.Warnf("remote port forwarding denied for %s from %s: disabled by configuration",
|
||||
net.JoinHostPort(bindHost, fmt.Sprintf("%d", bindPort)), ctx.RemoteAddr())
|
||||
return false
|
||||
}
|
||||
|
||||
if err := s.checkPortForwardingPrivileges(ctx, "remote", bindPort); err != nil {
|
||||
log.Warnf("remote port forwarding denied for %s:%d from %s: %v", bindHost, bindPort, ctx.RemoteAddr(), err)
|
||||
return false
|
||||
}
|
||||
|
||||
log.Debugf("remote port forwarding allowed: %s:%d", bindHost, bindPort)
|
||||
return true
|
||||
}
|
||||
|
||||
log.Debugf("SSH server configured with local_forwarding=%v, remote_forwarding=%v", allowLocal, allowRemote)
|
||||
}
|
||||
|
||||
// checkPortForwardingPrivileges validates privilege requirements for port forwarding operations.
|
||||
// Returns nil if allowed, error if denied.
|
||||
func (s *Server) checkPortForwardingPrivileges(ctx ssh.Context, forwardType string, port uint32) error {
|
||||
if ctx == nil {
|
||||
return fmt.Errorf("%s port forwarding denied: no context", forwardType)
|
||||
}
|
||||
|
||||
username := ctx.User()
|
||||
remoteAddr := "unknown"
|
||||
if ctx.RemoteAddr() != nil {
|
||||
remoteAddr = ctx.RemoteAddr().String()
|
||||
}
|
||||
|
||||
logger := log.WithFields(log.Fields{"user": username, "remote": remoteAddr, "port": port})
|
||||
|
||||
result := s.CheckPrivileges(PrivilegeCheckRequest{
|
||||
RequestedUsername: username,
|
||||
FeatureSupportsUserSwitch: false,
|
||||
FeatureName: forwardType + " port forwarding",
|
||||
})
|
||||
|
||||
if !result.Allowed {
|
||||
return result.Error
|
||||
}
|
||||
|
||||
logger.Debugf("%s port forwarding allowed: user %s validated (port %d)",
|
||||
forwardType, result.User.Username, port)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// tcpipForwardHandler handles tcpip-forward requests for remote port forwarding.
|
||||
func (s *Server) tcpipForwardHandler(ctx ssh.Context, _ *ssh.Server, req *cryptossh.Request) (bool, []byte) {
|
||||
logger := s.getRequestLogger(ctx)
|
||||
|
||||
if !s.isRemotePortForwardingAllowed() {
|
||||
logger.Warnf("tcpip-forward request denied: remote port forwarding disabled")
|
||||
return false, nil
|
||||
}
|
||||
|
||||
payload, err := s.parseTcpipForwardRequest(req)
|
||||
if err != nil {
|
||||
logger.Errorf("tcpip-forward unmarshal error: %v", err)
|
||||
return false, nil
|
||||
}
|
||||
|
||||
if err := s.checkPortForwardingPrivileges(ctx, "tcpip-forward", payload.Port); err != nil {
|
||||
logger.Warnf("tcpip-forward denied: %v", err)
|
||||
return false, nil
|
||||
}
|
||||
|
||||
logger.Debugf("tcpip-forward request: %s:%d", payload.Host, payload.Port)
|
||||
|
||||
sshConn, err := s.getSSHConnection(ctx)
|
||||
if err != nil {
|
||||
logger.Warnf("tcpip-forward request denied: %v", err)
|
||||
return false, nil
|
||||
}
|
||||
|
||||
return s.setupDirectForward(ctx, logger, sshConn, payload)
|
||||
}
|
||||
|
||||
// cancelTcpipForwardHandler handles cancel-tcpip-forward requests.
|
||||
func (s *Server) cancelTcpipForwardHandler(ctx ssh.Context, _ *ssh.Server, req *cryptossh.Request) (bool, []byte) {
|
||||
logger := s.getRequestLogger(ctx)
|
||||
|
||||
var payload tcpipForwardMsg
|
||||
if err := cryptossh.Unmarshal(req.Payload, &payload); err != nil {
|
||||
logger.Errorf("cancel-tcpip-forward unmarshal error: %v", err)
|
||||
return false, nil
|
||||
}
|
||||
|
||||
key := ForwardKey(fmt.Sprintf("%s:%d", payload.Host, payload.Port))
|
||||
if s.removeRemoteForwardListener(key) {
|
||||
logger.Infof("remote port forwarding cancelled: %s:%d", payload.Host, payload.Port)
|
||||
return true, nil
|
||||
}
|
||||
|
||||
logger.Warnf("cancel-tcpip-forward failed: no listener found for %s:%d", payload.Host, payload.Port)
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// handleRemoteForwardListener handles incoming connections for remote port forwarding.
|
||||
func (s *Server) handleRemoteForwardListener(ctx ssh.Context, ln net.Listener, host string, port uint32) {
|
||||
log.Debugf("starting remote forward listener handler for %s:%d", host, port)
|
||||
|
||||
defer func() {
|
||||
log.Debugf("cleaning up remote forward listener for %s:%d", host, port)
|
||||
if err := ln.Close(); err != nil {
|
||||
log.Debugf("remote forward listener close error: %v", err)
|
||||
} else {
|
||||
log.Debugf("remote forward listener closed successfully for %s:%d", host, port)
|
||||
}
|
||||
}()
|
||||
|
||||
acceptChan := make(chan acceptResult, 1)
|
||||
|
||||
go func() {
|
||||
for {
|
||||
conn, err := ln.Accept()
|
||||
select {
|
||||
case acceptChan <- acceptResult{conn: conn, err: err}:
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
for {
|
||||
select {
|
||||
case result := <-acceptChan:
|
||||
if result.err != nil {
|
||||
log.Debugf("remote forward accept error: %v", result.err)
|
||||
return
|
||||
}
|
||||
go s.handleRemoteForwardConnection(ctx, result.conn, host, port)
|
||||
case <-ctx.Done():
|
||||
log.Debugf("remote forward listener shutting down due to context cancellation for %s:%d", host, port)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// getRequestLogger creates a logger with user and remote address context
|
||||
func (s *Server) getRequestLogger(ctx ssh.Context) *log.Entry {
|
||||
remoteAddr := "unknown"
|
||||
username := "unknown"
|
||||
if ctx != nil {
|
||||
if ctx.RemoteAddr() != nil {
|
||||
remoteAddr = ctx.RemoteAddr().String()
|
||||
}
|
||||
username = ctx.User()
|
||||
}
|
||||
return log.WithFields(log.Fields{"user": username, "remote": remoteAddr})
|
||||
}
|
||||
|
||||
// isRemotePortForwardingAllowed checks if remote port forwarding is enabled
|
||||
func (s *Server) isRemotePortForwardingAllowed() bool {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
return s.allowRemotePortForwarding
|
||||
}
|
||||
|
||||
// parseTcpipForwardRequest parses the SSH request payload
|
||||
func (s *Server) parseTcpipForwardRequest(req *cryptossh.Request) (*tcpipForwardMsg, error) {
|
||||
var payload tcpipForwardMsg
|
||||
err := cryptossh.Unmarshal(req.Payload, &payload)
|
||||
return &payload, err
|
||||
}
|
||||
|
||||
// getSSHConnection extracts SSH connection from context
|
||||
func (s *Server) getSSHConnection(ctx ssh.Context) (*cryptossh.ServerConn, error) {
|
||||
if ctx == nil {
|
||||
return nil, fmt.Errorf("no context")
|
||||
}
|
||||
sshConnValue := ctx.Value(ssh.ContextKeyConn)
|
||||
if sshConnValue == nil {
|
||||
return nil, fmt.Errorf("no SSH connection in context")
|
||||
}
|
||||
sshConn, ok := sshConnValue.(*cryptossh.ServerConn)
|
||||
if !ok || sshConn == nil {
|
||||
return nil, fmt.Errorf("invalid SSH connection in context")
|
||||
}
|
||||
return sshConn, nil
|
||||
}
|
||||
|
||||
// setupDirectForward sets up a direct port forward
|
||||
func (s *Server) setupDirectForward(ctx ssh.Context, logger *log.Entry, sshConn *cryptossh.ServerConn, payload *tcpipForwardMsg) (bool, []byte) {
|
||||
bindAddr := net.JoinHostPort(payload.Host, strconv.FormatUint(uint64(payload.Port), 10))
|
||||
|
||||
ln, err := net.Listen("tcp", bindAddr)
|
||||
if err != nil {
|
||||
logger.Errorf("tcpip-forward listen failed on %s: %v", bindAddr, err)
|
||||
return false, nil
|
||||
}
|
||||
|
||||
actualPort := payload.Port
|
||||
if payload.Port == 0 {
|
||||
tcpAddr := ln.Addr().(*net.TCPAddr)
|
||||
actualPort = uint32(tcpAddr.Port)
|
||||
logger.Debugf("tcpip-forward allocated port %d for %s", actualPort, payload.Host)
|
||||
}
|
||||
|
||||
key := ForwardKey(fmt.Sprintf("%s:%d", payload.Host, payload.Port))
|
||||
s.storeRemoteForwardListener(key, ln)
|
||||
|
||||
s.markConnectionActivePortForward(sshConn, ctx.User(), ctx.RemoteAddr().String())
|
||||
go s.handleRemoteForwardListener(ctx, ln, payload.Host, actualPort)
|
||||
|
||||
response := make([]byte, 4)
|
||||
binary.BigEndian.PutUint32(response, actualPort)
|
||||
|
||||
logger.Infof("remote port forwarding established: %s:%d", payload.Host, actualPort)
|
||||
return true, response
|
||||
}
|
||||
|
||||
// acceptResult holds the result of a listener Accept() call
|
||||
type acceptResult struct {
|
||||
conn net.Conn
|
||||
err error
|
||||
}
|
||||
|
||||
// handleRemoteForwardConnection handles a single remote port forwarding connection
|
||||
func (s *Server) handleRemoteForwardConnection(ctx ssh.Context, conn net.Conn, host string, port uint32) {
|
||||
sessionKey := s.findSessionKeyByContext(ctx)
|
||||
connID := fmt.Sprintf("pf-%s->%s:%d", conn.RemoteAddr(), host, port)
|
||||
logger := log.WithFields(log.Fields{
|
||||
"session": sessionKey,
|
||||
"conn": connID,
|
||||
})
|
||||
|
||||
defer func() {
|
||||
if err := conn.Close(); err != nil {
|
||||
logger.Debugf("connection close error: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
sshConn := ctx.Value(ssh.ContextKeyConn).(*cryptossh.ServerConn)
|
||||
if sshConn == nil {
|
||||
logger.Debugf("remote forward: no SSH connection in context")
|
||||
return
|
||||
}
|
||||
|
||||
remoteAddr, ok := conn.RemoteAddr().(*net.TCPAddr)
|
||||
if !ok {
|
||||
logger.Warnf("remote forward: non-TCP connection type: %T", conn.RemoteAddr())
|
||||
return
|
||||
}
|
||||
|
||||
channel, err := s.openForwardChannel(sshConn, host, port, remoteAddr, logger)
|
||||
if err != nil {
|
||||
logger.Debugf("open forward channel: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
s.proxyForwardConnection(ctx, logger, conn, channel)
|
||||
}
|
||||
|
||||
// openForwardChannel creates an SSH forwarded-tcpip channel
|
||||
func (s *Server) openForwardChannel(sshConn *cryptossh.ServerConn, host string, port uint32, remoteAddr *net.TCPAddr, logger *log.Entry) (cryptossh.Channel, error) {
|
||||
logger.Tracef("opening forwarded-tcpip channel for %s:%d", host, port)
|
||||
|
||||
payload := struct {
|
||||
ConnectedAddress string
|
||||
ConnectedPort uint32
|
||||
OriginatorAddress string
|
||||
OriginatorPort uint32
|
||||
}{
|
||||
ConnectedAddress: host,
|
||||
ConnectedPort: port,
|
||||
OriginatorAddress: remoteAddr.IP.String(),
|
||||
OriginatorPort: uint32(remoteAddr.Port),
|
||||
}
|
||||
|
||||
channel, reqs, err := sshConn.OpenChannel("forwarded-tcpip", cryptossh.Marshal(&payload))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("open SSH channel: %w", err)
|
||||
}
|
||||
|
||||
go cryptossh.DiscardRequests(reqs)
|
||||
return channel, nil
|
||||
}
|
||||
|
||||
// proxyForwardConnection handles bidirectional data transfer between connection and SSH channel
|
||||
func (s *Server) proxyForwardConnection(ctx ssh.Context, logger *log.Entry, conn net.Conn, channel cryptossh.Channel) {
|
||||
done := make(chan struct{}, 2)
|
||||
|
||||
go func() {
|
||||
if _, err := io.Copy(channel, conn); err != nil {
|
||||
logger.Debugf("copy error (conn->channel): %v", err)
|
||||
}
|
||||
done <- struct{}{}
|
||||
}()
|
||||
|
||||
go func() {
|
||||
if _, err := io.Copy(conn, channel); err != nil {
|
||||
logger.Debugf("copy error (channel->conn): %v", err)
|
||||
}
|
||||
done <- struct{}{}
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
logger.Debugf("session ended, closing connections")
|
||||
case <-done:
|
||||
// First copy finished, wait for second copy or context cancellation
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
logger.Debugf("session ended, closing connections")
|
||||
case <-done:
|
||||
}
|
||||
}
|
||||
|
||||
if err := channel.Close(); err != nil {
|
||||
logger.Debugf("channel close error: %v", err)
|
||||
}
|
||||
if err := conn.Close(); err != nil {
|
||||
logger.Debugf("connection close error: %v", err)
|
||||
}
|
||||
}
|
||||
712
client/ssh/server/server.go
Normal file
712
client/ssh/server/server.go
Normal file
@@ -0,0 +1,712 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/netip"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gliderlabs/ssh"
|
||||
gojwt "github.com/golang-jwt/jwt/v5"
|
||||
log "github.com/sirupsen/logrus"
|
||||
cryptossh "golang.org/x/crypto/ssh"
|
||||
"golang.org/x/exp/maps"
|
||||
"golang.zx2c4.com/wireguard/tun/netstack"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
"github.com/netbirdio/netbird/client/ssh/detection"
|
||||
"github.com/netbirdio/netbird/shared/auth"
|
||||
"github.com/netbirdio/netbird/shared/auth/jwt"
|
||||
"github.com/netbirdio/netbird/version"
|
||||
)
|
||||
|
||||
// DefaultSSHPort is the default SSH port of the NetBird's embedded SSH server
|
||||
const DefaultSSHPort = 22
|
||||
|
||||
// InternalSSHPort is the port SSH server listens on and is redirected to
|
||||
const InternalSSHPort = 22022
|
||||
|
||||
const (
|
||||
errWriteSession = "write session error: %v"
|
||||
errExitSession = "exit session error: %v"
|
||||
|
||||
msgPrivilegedUserDisabled = "privileged user login is disabled"
|
||||
|
||||
// DefaultJWTMaxTokenAge is the default maximum age for JWT tokens accepted by the SSH server
|
||||
DefaultJWTMaxTokenAge = 5 * 60
|
||||
)
|
||||
|
||||
var (
|
||||
ErrPrivilegedUserDisabled = errors.New(msgPrivilegedUserDisabled)
|
||||
ErrUserNotFound = errors.New("user not found")
|
||||
)
|
||||
|
||||
// PrivilegedUserError represents an error when privileged user login is disabled
|
||||
type PrivilegedUserError struct {
|
||||
Username string
|
||||
}
|
||||
|
||||
func (e *PrivilegedUserError) Error() string {
|
||||
return fmt.Sprintf("%s for user: %s", msgPrivilegedUserDisabled, e.Username)
|
||||
}
|
||||
|
||||
func (e *PrivilegedUserError) Is(target error) bool {
|
||||
return target == ErrPrivilegedUserDisabled
|
||||
}
|
||||
|
||||
// UserNotFoundError represents an error when a user cannot be found
|
||||
type UserNotFoundError struct {
|
||||
Username string
|
||||
Cause error
|
||||
}
|
||||
|
||||
func (e *UserNotFoundError) Error() string {
|
||||
if e.Cause != nil {
|
||||
return fmt.Sprintf("user %s not found: %v", e.Username, e.Cause)
|
||||
}
|
||||
return fmt.Sprintf("user %s not found", e.Username)
|
||||
}
|
||||
|
||||
func (e *UserNotFoundError) Is(target error) bool {
|
||||
return target == ErrUserNotFound
|
||||
}
|
||||
|
||||
func (e *UserNotFoundError) Unwrap() error {
|
||||
return e.Cause
|
||||
}
|
||||
|
||||
// logSessionExitError logs session exit errors, ignoring EOF (normal close) errors
|
||||
func logSessionExitError(logger *log.Entry, err error) {
|
||||
if err != nil && !errors.Is(err, io.EOF) {
|
||||
logger.Warnf(errExitSession, err)
|
||||
}
|
||||
}
|
||||
|
||||
// safeLogCommand returns a safe representation of the command for logging
|
||||
func safeLogCommand(cmd []string) string {
|
||||
if len(cmd) == 0 {
|
||||
return "<interactive shell>"
|
||||
}
|
||||
if len(cmd) == 1 {
|
||||
return cmd[0]
|
||||
}
|
||||
return fmt.Sprintf("%s [%d args]", cmd[0], len(cmd)-1)
|
||||
}
|
||||
|
||||
type sshConnectionState struct {
|
||||
hasActivePortForward bool
|
||||
username string
|
||||
remoteAddr string
|
||||
}
|
||||
|
||||
type authKey string
|
||||
|
||||
func newAuthKey(username string, remoteAddr net.Addr) authKey {
|
||||
return authKey(fmt.Sprintf("%s@%s", username, remoteAddr.String()))
|
||||
}
|
||||
|
||||
type Server struct {
|
||||
sshServer *ssh.Server
|
||||
mu sync.RWMutex
|
||||
hostKeyPEM []byte
|
||||
sessions map[SessionKey]ssh.Session
|
||||
sessionCancels map[ConnectionKey]context.CancelFunc
|
||||
sessionJWTUsers map[SessionKey]string
|
||||
pendingAuthJWT map[authKey]string
|
||||
|
||||
allowLocalPortForwarding bool
|
||||
allowRemotePortForwarding bool
|
||||
allowRootLogin bool
|
||||
allowSFTP bool
|
||||
jwtEnabled bool
|
||||
|
||||
netstackNet *netstack.Net
|
||||
|
||||
wgAddress wgaddr.Address
|
||||
|
||||
remoteForwardListeners map[ForwardKey]net.Listener
|
||||
sshConnections map[*cryptossh.ServerConn]*sshConnectionState
|
||||
|
||||
jwtValidator *jwt.Validator
|
||||
jwtExtractor *jwt.ClaimsExtractor
|
||||
jwtConfig *JWTConfig
|
||||
|
||||
suSupportsPty bool
|
||||
}
|
||||
|
||||
type JWTConfig struct {
|
||||
Issuer string
|
||||
Audience string
|
||||
KeysLocation string
|
||||
MaxTokenAge int64
|
||||
}
|
||||
|
||||
// Config contains all SSH server configuration options
|
||||
type Config struct {
|
||||
// JWT authentication configuration. If nil, JWT authentication is disabled
|
||||
JWT *JWTConfig
|
||||
|
||||
// HostKey is the SSH server host key in PEM format
|
||||
HostKeyPEM []byte
|
||||
}
|
||||
|
||||
// SessionInfo contains information about an active SSH session
|
||||
type SessionInfo struct {
|
||||
Username string
|
||||
RemoteAddress string
|
||||
Command string
|
||||
JWTUsername string
|
||||
}
|
||||
|
||||
// New creates an SSH server instance with the provided host key and optional JWT configuration
|
||||
// If jwtConfig is nil, JWT authentication is disabled
|
||||
func New(config *Config) *Server {
|
||||
s := &Server{
|
||||
mu: sync.RWMutex{},
|
||||
hostKeyPEM: config.HostKeyPEM,
|
||||
sessions: make(map[SessionKey]ssh.Session),
|
||||
sessionJWTUsers: make(map[SessionKey]string),
|
||||
pendingAuthJWT: make(map[authKey]string),
|
||||
remoteForwardListeners: make(map[ForwardKey]net.Listener),
|
||||
sshConnections: make(map[*cryptossh.ServerConn]*sshConnectionState),
|
||||
jwtEnabled: config.JWT != nil,
|
||||
jwtConfig: config.JWT,
|
||||
}
|
||||
|
||||
return s
|
||||
}
|
||||
|
||||
// Start runs the SSH server
|
||||
func (s *Server) Start(ctx context.Context, addr netip.AddrPort) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
if s.sshServer != nil {
|
||||
return errors.New("SSH server is already running")
|
||||
}
|
||||
|
||||
s.suSupportsPty = s.detectSuPtySupport(ctx)
|
||||
|
||||
ln, addrDesc, err := s.createListener(ctx, addr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create listener: %w", err)
|
||||
}
|
||||
|
||||
sshServer, err := s.createSSHServer(ln.Addr())
|
||||
if err != nil {
|
||||
s.closeListener(ln)
|
||||
return fmt.Errorf("create SSH server: %w", err)
|
||||
}
|
||||
|
||||
s.sshServer = sshServer
|
||||
log.Infof("SSH server started on %s", addrDesc)
|
||||
|
||||
go func() {
|
||||
if err := sshServer.Serve(ln); err != nil && !errors.Is(err, ssh.ErrServerClosed) {
|
||||
log.Errorf("SSH server error: %v", err)
|
||||
}
|
||||
}()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Server) createListener(ctx context.Context, addr netip.AddrPort) (net.Listener, string, error) {
|
||||
if s.netstackNet != nil {
|
||||
ln, err := s.netstackNet.ListenTCPAddrPort(addr)
|
||||
if err != nil {
|
||||
return nil, "", fmt.Errorf("listen on netstack: %w", err)
|
||||
}
|
||||
return ln, fmt.Sprintf("netstack %s", addr), nil
|
||||
}
|
||||
|
||||
tcpAddr := net.TCPAddrFromAddrPort(addr)
|
||||
lc := net.ListenConfig{}
|
||||
ln, err := lc.Listen(ctx, "tcp", tcpAddr.String())
|
||||
if err != nil {
|
||||
return nil, "", fmt.Errorf("listen: %w", err)
|
||||
}
|
||||
return ln, addr.String(), nil
|
||||
}
|
||||
|
||||
func (s *Server) closeListener(ln net.Listener) {
|
||||
if ln == nil {
|
||||
return
|
||||
}
|
||||
if err := ln.Close(); err != nil {
|
||||
log.Debugf("listener close error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Stop closes the SSH server
|
||||
func (s *Server) Stop() error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
if s.sshServer == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := s.sshServer.Close(); err != nil {
|
||||
log.Debugf("close SSH server: %v", err)
|
||||
}
|
||||
|
||||
s.sshServer = nil
|
||||
|
||||
maps.Clear(s.sessions)
|
||||
maps.Clear(s.sessionJWTUsers)
|
||||
maps.Clear(s.pendingAuthJWT)
|
||||
maps.Clear(s.sshConnections)
|
||||
|
||||
for _, cancelFunc := range s.sessionCancels {
|
||||
cancelFunc()
|
||||
}
|
||||
maps.Clear(s.sessionCancels)
|
||||
|
||||
for _, listener := range s.remoteForwardListeners {
|
||||
if err := listener.Close(); err != nil {
|
||||
log.Debugf("close remote forward listener: %v", err)
|
||||
}
|
||||
}
|
||||
maps.Clear(s.remoteForwardListeners)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetStatus returns the current status of the SSH server and active sessions
|
||||
func (s *Server) GetStatus() (enabled bool, sessions []SessionInfo) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
enabled = s.sshServer != nil
|
||||
|
||||
for sessionKey, session := range s.sessions {
|
||||
cmd := "<interactive shell>"
|
||||
if len(session.Command()) > 0 {
|
||||
cmd = safeLogCommand(session.Command())
|
||||
}
|
||||
|
||||
jwtUsername := s.sessionJWTUsers[sessionKey]
|
||||
|
||||
sessions = append(sessions, SessionInfo{
|
||||
Username: session.User(),
|
||||
RemoteAddress: session.RemoteAddr().String(),
|
||||
Command: cmd,
|
||||
JWTUsername: jwtUsername,
|
||||
})
|
||||
}
|
||||
|
||||
return enabled, sessions
|
||||
}
|
||||
|
||||
// SetNetstackNet sets the netstack network for userspace networking
|
||||
func (s *Server) SetNetstackNet(net *netstack.Net) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.netstackNet = net
|
||||
}
|
||||
|
||||
// SetNetworkValidation configures network-based connection filtering
|
||||
func (s *Server) SetNetworkValidation(addr wgaddr.Address) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.wgAddress = addr
|
||||
}
|
||||
|
||||
// ensureJWTValidator initializes the JWT validator and extractor if not already initialized
|
||||
func (s *Server) ensureJWTValidator() error {
|
||||
s.mu.RLock()
|
||||
if s.jwtValidator != nil && s.jwtExtractor != nil {
|
||||
s.mu.RUnlock()
|
||||
return nil
|
||||
}
|
||||
config := s.jwtConfig
|
||||
s.mu.RUnlock()
|
||||
|
||||
if config == nil {
|
||||
return fmt.Errorf("JWT config not set")
|
||||
}
|
||||
|
||||
log.Debugf("Initializing JWT validator (issuer: %s, audience: %s)", config.Issuer, config.Audience)
|
||||
|
||||
validator := jwt.NewValidator(
|
||||
config.Issuer,
|
||||
[]string{config.Audience},
|
||||
config.KeysLocation,
|
||||
true,
|
||||
)
|
||||
|
||||
extractor := jwt.NewClaimsExtractor(
|
||||
jwt.WithAudience(config.Audience),
|
||||
)
|
||||
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
if s.jwtValidator != nil && s.jwtExtractor != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
s.jwtValidator = validator
|
||||
s.jwtExtractor = extractor
|
||||
|
||||
log.Infof("JWT validator initialized successfully")
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Server) validateJWTToken(tokenString string) (*gojwt.Token, error) {
|
||||
s.mu.RLock()
|
||||
jwtValidator := s.jwtValidator
|
||||
jwtConfig := s.jwtConfig
|
||||
s.mu.RUnlock()
|
||||
|
||||
if jwtValidator == nil {
|
||||
return nil, fmt.Errorf("JWT validator not initialized")
|
||||
}
|
||||
|
||||
token, err := jwtValidator.ValidateAndParse(context.Background(), tokenString)
|
||||
if err != nil {
|
||||
if jwtConfig != nil {
|
||||
if claims, parseErr := s.parseTokenWithoutValidation(tokenString); parseErr == nil {
|
||||
return nil, fmt.Errorf("validate token (expected issuer=%s, audience=%s, actual issuer=%v, audience=%v): %w",
|
||||
jwtConfig.Issuer, jwtConfig.Audience, claims["iss"], claims["aud"], err)
|
||||
}
|
||||
}
|
||||
return nil, fmt.Errorf("validate token: %w", err)
|
||||
}
|
||||
|
||||
if err := s.checkTokenAge(token, jwtConfig); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return token, nil
|
||||
}
|
||||
|
||||
func (s *Server) checkTokenAge(token *gojwt.Token, jwtConfig *JWTConfig) error {
|
||||
if jwtConfig == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
maxTokenAge := jwtConfig.MaxTokenAge
|
||||
if maxTokenAge <= 0 {
|
||||
maxTokenAge = DefaultJWTMaxTokenAge
|
||||
}
|
||||
|
||||
claims, ok := token.Claims.(gojwt.MapClaims)
|
||||
if !ok {
|
||||
userID := extractUserID(token)
|
||||
return fmt.Errorf("token has invalid claims format (user=%s)", userID)
|
||||
}
|
||||
|
||||
iat, ok := claims["iat"].(float64)
|
||||
if !ok {
|
||||
userID := extractUserID(token)
|
||||
return fmt.Errorf("token missing iat claim (user=%s)", userID)
|
||||
}
|
||||
|
||||
issuedAt := time.Unix(int64(iat), 0)
|
||||
tokenAge := time.Since(issuedAt)
|
||||
maxAge := time.Duration(maxTokenAge) * time.Second
|
||||
if tokenAge > maxAge {
|
||||
userID := getUserIDFromClaims(claims)
|
||||
return fmt.Errorf("token expired for user=%s: age=%v, max=%v", userID, tokenAge, maxAge)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Server) extractAndValidateUser(token *gojwt.Token) (*auth.UserAuth, error) {
|
||||
s.mu.RLock()
|
||||
jwtExtractor := s.jwtExtractor
|
||||
s.mu.RUnlock()
|
||||
|
||||
if jwtExtractor == nil {
|
||||
userID := extractUserID(token)
|
||||
return nil, fmt.Errorf("JWT extractor not initialized (user=%s)", userID)
|
||||
}
|
||||
|
||||
userAuth, err := jwtExtractor.ToUserAuth(token)
|
||||
if err != nil {
|
||||
userID := extractUserID(token)
|
||||
return nil, fmt.Errorf("extract user from token (user=%s): %w", userID, err)
|
||||
}
|
||||
|
||||
if !s.hasSSHAccess(&userAuth) {
|
||||
return nil, fmt.Errorf("user %s does not have SSH access permissions", userAuth.UserId)
|
||||
}
|
||||
|
||||
return &userAuth, nil
|
||||
}
|
||||
|
||||
func (s *Server) hasSSHAccess(userAuth *auth.UserAuth) bool {
|
||||
return userAuth.UserId != ""
|
||||
}
|
||||
|
||||
func extractUserID(token *gojwt.Token) string {
|
||||
if token == nil {
|
||||
return "unknown"
|
||||
}
|
||||
claims, ok := token.Claims.(gojwt.MapClaims)
|
||||
if !ok {
|
||||
return "unknown"
|
||||
}
|
||||
return getUserIDFromClaims(claims)
|
||||
}
|
||||
|
||||
func getUserIDFromClaims(claims gojwt.MapClaims) string {
|
||||
if sub, ok := claims["sub"].(string); ok && sub != "" {
|
||||
return sub
|
||||
}
|
||||
if userID, ok := claims["user_id"].(string); ok && userID != "" {
|
||||
return userID
|
||||
}
|
||||
if email, ok := claims["email"].(string); ok && email != "" {
|
||||
return email
|
||||
}
|
||||
return "unknown"
|
||||
}
|
||||
|
||||
func (s *Server) parseTokenWithoutValidation(tokenString string) (map[string]interface{}, error) {
|
||||
parts := strings.Split(tokenString, ".")
|
||||
if len(parts) != 3 {
|
||||
return nil, fmt.Errorf("invalid token format")
|
||||
}
|
||||
|
||||
payload, err := base64.RawURLEncoding.DecodeString(parts[1])
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("decode payload: %w", err)
|
||||
}
|
||||
|
||||
var claims map[string]interface{}
|
||||
if err := json.Unmarshal(payload, &claims); err != nil {
|
||||
return nil, fmt.Errorf("parse claims: %w", err)
|
||||
}
|
||||
|
||||
return claims, nil
|
||||
}
|
||||
|
||||
func (s *Server) passwordHandler(ctx ssh.Context, password string) bool {
|
||||
if err := s.ensureJWTValidator(); err != nil {
|
||||
log.Errorf("JWT validator initialization failed for user %s from %s: %v", ctx.User(), ctx.RemoteAddr(), err)
|
||||
return false
|
||||
}
|
||||
|
||||
token, err := s.validateJWTToken(password)
|
||||
if err != nil {
|
||||
log.Warnf("JWT authentication failed for user %s from %s: %v", ctx.User(), ctx.RemoteAddr(), err)
|
||||
return false
|
||||
}
|
||||
|
||||
userAuth, err := s.extractAndValidateUser(token)
|
||||
if err != nil {
|
||||
log.Warnf("User validation failed for user %s from %s: %v", ctx.User(), ctx.RemoteAddr(), err)
|
||||
return false
|
||||
}
|
||||
|
||||
key := newAuthKey(ctx.User(), ctx.RemoteAddr())
|
||||
s.mu.Lock()
|
||||
s.pendingAuthJWT[key] = userAuth.UserId
|
||||
s.mu.Unlock()
|
||||
|
||||
log.Infof("JWT authentication successful for user %s (JWT user ID: %s) from %s", ctx.User(), userAuth.UserId, ctx.RemoteAddr())
|
||||
return true
|
||||
}
|
||||
|
||||
func (s *Server) markConnectionActivePortForward(sshConn *cryptossh.ServerConn, username, remoteAddr string) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
if state, exists := s.sshConnections[sshConn]; exists {
|
||||
state.hasActivePortForward = true
|
||||
} else {
|
||||
s.sshConnections[sshConn] = &sshConnectionState{
|
||||
hasActivePortForward: true,
|
||||
username: username,
|
||||
remoteAddr: remoteAddr,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) connectionCloseHandler(conn net.Conn, err error) {
|
||||
// We can't extract the SSH connection from net.Conn directly
|
||||
// Connection cleanup will happen during session cleanup or via timeout
|
||||
log.Debugf("SSH connection failed for %s: %v", conn.RemoteAddr(), err)
|
||||
}
|
||||
|
||||
func (s *Server) findSessionKeyByContext(ctx ssh.Context) SessionKey {
|
||||
if ctx == nil {
|
||||
return "unknown"
|
||||
}
|
||||
|
||||
// Try to match by SSH connection
|
||||
sshConn := ctx.Value(ssh.ContextKeyConn)
|
||||
if sshConn == nil {
|
||||
return "unknown"
|
||||
}
|
||||
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
// Look through sessions to find one with matching connection
|
||||
for sessionKey, session := range s.sessions {
|
||||
if session.Context().Value(ssh.ContextKeyConn) == sshConn {
|
||||
return sessionKey
|
||||
}
|
||||
}
|
||||
|
||||
// If no session found, this might be during early connection setup
|
||||
// Return a temporary key that we'll fix up later
|
||||
if ctx.User() != "" && ctx.RemoteAddr() != nil {
|
||||
tempKey := SessionKey(fmt.Sprintf("%s@%s", ctx.User(), ctx.RemoteAddr().String()))
|
||||
log.Debugf("Using temporary session key for early port forward tracking: %s (will be updated when session established)", tempKey)
|
||||
return tempKey
|
||||
}
|
||||
|
||||
return "unknown"
|
||||
}
|
||||
|
||||
func (s *Server) connectionValidator(_ ssh.Context, conn net.Conn) net.Conn {
|
||||
s.mu.RLock()
|
||||
netbirdNetwork := s.wgAddress.Network
|
||||
localIP := s.wgAddress.IP
|
||||
s.mu.RUnlock()
|
||||
|
||||
if !netbirdNetwork.IsValid() || !localIP.IsValid() {
|
||||
return conn
|
||||
}
|
||||
|
||||
remoteAddr := conn.RemoteAddr()
|
||||
tcpAddr, ok := remoteAddr.(*net.TCPAddr)
|
||||
if !ok {
|
||||
log.Warnf("SSH connection rejected: non-TCP address %s", remoteAddr)
|
||||
return nil
|
||||
}
|
||||
|
||||
remoteIP, ok := netip.AddrFromSlice(tcpAddr.IP)
|
||||
if !ok {
|
||||
log.Warnf("SSH connection rejected: invalid remote IP %s", tcpAddr.IP)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Block connections from our own IP (prevent local apps from connecting to ourselves)
|
||||
if remoteIP == localIP {
|
||||
log.Warnf("SSH connection rejected from own IP %s", remoteIP)
|
||||
return nil
|
||||
}
|
||||
|
||||
if !netbirdNetwork.Contains(remoteIP) {
|
||||
log.Warnf("SSH connection rejected from non-NetBird IP %s", remoteIP)
|
||||
return nil
|
||||
}
|
||||
|
||||
log.Infof("SSH connection from NetBird peer %s allowed", tcpAddr)
|
||||
return conn
|
||||
}
|
||||
|
||||
func (s *Server) createSSHServer(addr net.Addr) (*ssh.Server, error) {
|
||||
if err := enableUserSwitching(); err != nil {
|
||||
log.Warnf("failed to enable user switching: %v", err)
|
||||
}
|
||||
|
||||
serverVersion := fmt.Sprintf("%s-%s", detection.ServerIdentifier, version.NetbirdVersion())
|
||||
if s.jwtEnabled {
|
||||
serverVersion += " " + detection.JWTRequiredMarker
|
||||
}
|
||||
|
||||
server := &ssh.Server{
|
||||
Addr: addr.String(),
|
||||
Handler: s.sessionHandler,
|
||||
SubsystemHandlers: map[string]ssh.SubsystemHandler{
|
||||
"sftp": s.sftpSubsystemHandler,
|
||||
},
|
||||
HostSigners: []ssh.Signer{},
|
||||
ChannelHandlers: map[string]ssh.ChannelHandler{
|
||||
"session": ssh.DefaultSessionHandler,
|
||||
"direct-tcpip": s.directTCPIPHandler,
|
||||
},
|
||||
RequestHandlers: map[string]ssh.RequestHandler{
|
||||
"tcpip-forward": s.tcpipForwardHandler,
|
||||
"cancel-tcpip-forward": s.cancelTcpipForwardHandler,
|
||||
},
|
||||
ConnCallback: s.connectionValidator,
|
||||
ConnectionFailedCallback: s.connectionCloseHandler,
|
||||
Version: serverVersion,
|
||||
}
|
||||
|
||||
if s.jwtEnabled {
|
||||
server.PasswordHandler = s.passwordHandler
|
||||
}
|
||||
|
||||
hostKeyPEM := ssh.HostKeyPEM(s.hostKeyPEM)
|
||||
if err := server.SetOption(hostKeyPEM); err != nil {
|
||||
return nil, fmt.Errorf("set host key: %w", err)
|
||||
}
|
||||
|
||||
s.configurePortForwarding(server)
|
||||
return server, nil
|
||||
}
|
||||
|
||||
func (s *Server) storeRemoteForwardListener(key ForwardKey, ln net.Listener) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.remoteForwardListeners[key] = ln
|
||||
}
|
||||
|
||||
func (s *Server) removeRemoteForwardListener(key ForwardKey) bool {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
ln, exists := s.remoteForwardListeners[key]
|
||||
if !exists {
|
||||
return false
|
||||
}
|
||||
|
||||
delete(s.remoteForwardListeners, key)
|
||||
if err := ln.Close(); err != nil {
|
||||
log.Debugf("remote forward listener close error: %v", err)
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func (s *Server) directTCPIPHandler(srv *ssh.Server, conn *cryptossh.ServerConn, newChan cryptossh.NewChannel, ctx ssh.Context) {
|
||||
var payload struct {
|
||||
Host string
|
||||
Port uint32
|
||||
OriginatorAddr string
|
||||
OriginatorPort uint32
|
||||
}
|
||||
|
||||
if err := cryptossh.Unmarshal(newChan.ExtraData(), &payload); err != nil {
|
||||
if err := newChan.Reject(cryptossh.ConnectionFailed, "parse payload"); err != nil {
|
||||
log.Debugf("channel reject error: %v", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
s.mu.RLock()
|
||||
allowLocal := s.allowLocalPortForwarding
|
||||
s.mu.RUnlock()
|
||||
|
||||
if !allowLocal {
|
||||
log.Warnf("local port forwarding denied for %s:%d: disabled by configuration", payload.Host, payload.Port)
|
||||
_ = newChan.Reject(cryptossh.Prohibited, "local port forwarding disabled")
|
||||
return
|
||||
}
|
||||
|
||||
// Check privilege requirements for the destination port
|
||||
if err := s.checkPortForwardingPrivileges(ctx, "local", payload.Port); err != nil {
|
||||
log.Warnf("local port forwarding denied for %s:%d: %v", payload.Host, payload.Port, err)
|
||||
_ = newChan.Reject(cryptossh.Prohibited, "insufficient privileges")
|
||||
return
|
||||
}
|
||||
|
||||
log.Infof("local port forwarding: %s:%d", payload.Host, payload.Port)
|
||||
|
||||
ssh.DirectTCPIPHandler(srv, conn, newChan, ctx)
|
||||
}
|
||||
394
client/ssh/server/server_config_test.go
Normal file
394
client/ssh/server/server_config_test.go
Normal file
@@ -0,0 +1,394 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"os/user"
|
||||
"runtime"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/client/ssh"
|
||||
sshclient "github.com/netbirdio/netbird/client/ssh/client"
|
||||
)
|
||||
|
||||
func TestServer_RootLoginRestriction(t *testing.T) {
|
||||
// Generate host key for server
|
||||
hostKey, err := ssh.GeneratePrivateKey(ssh.ED25519)
|
||||
require.NoError(t, err)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
allowRoot bool
|
||||
username string
|
||||
expectError bool
|
||||
description string
|
||||
}{
|
||||
{
|
||||
name: "root login allowed",
|
||||
allowRoot: true,
|
||||
username: "root",
|
||||
expectError: false,
|
||||
description: "Root login should succeed when allowed",
|
||||
},
|
||||
{
|
||||
name: "root login denied",
|
||||
allowRoot: false,
|
||||
username: "root",
|
||||
expectError: true,
|
||||
description: "Root login should fail when disabled",
|
||||
},
|
||||
{
|
||||
name: "regular user login always allowed",
|
||||
allowRoot: false,
|
||||
username: "testuser",
|
||||
expectError: false,
|
||||
description: "Regular user login should work regardless of root setting",
|
||||
},
|
||||
}
|
||||
|
||||
// Add Windows Administrator tests if on Windows
|
||||
if runtime.GOOS == "windows" {
|
||||
tests = append(tests, []struct {
|
||||
name string
|
||||
allowRoot bool
|
||||
username string
|
||||
expectError bool
|
||||
description string
|
||||
}{
|
||||
{
|
||||
name: "Administrator login allowed",
|
||||
allowRoot: true,
|
||||
username: "Administrator",
|
||||
expectError: false,
|
||||
description: "Administrator login should succeed when allowed",
|
||||
},
|
||||
{
|
||||
name: "Administrator login denied",
|
||||
allowRoot: false,
|
||||
username: "Administrator",
|
||||
expectError: true,
|
||||
description: "Administrator login should fail when disabled",
|
||||
},
|
||||
{
|
||||
name: "administrator login denied (lowercase)",
|
||||
allowRoot: false,
|
||||
username: "administrator",
|
||||
expectError: true,
|
||||
description: "administrator login should fail when disabled (case insensitive)",
|
||||
},
|
||||
}...)
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Mock privileged environment to test root access controls
|
||||
// Set up mock users based on platform
|
||||
mockUsers := map[string]*user.User{
|
||||
"root": createTestUser("root", "0", "0", "/root"),
|
||||
"testuser": createTestUser("testuser", "1000", "1000", "/home/testuser"),
|
||||
}
|
||||
|
||||
// Add Windows-specific users for Administrator tests
|
||||
if runtime.GOOS == "windows" {
|
||||
mockUsers["Administrator"] = createTestUser("Administrator", "500", "544", "C:\\Users\\Administrator")
|
||||
mockUsers["administrator"] = createTestUser("administrator", "500", "544", "C:\\Users\\administrator")
|
||||
}
|
||||
|
||||
cleanup := setupTestDependencies(
|
||||
createTestUser("root", "0", "0", "/root"), // Running as root
|
||||
nil,
|
||||
runtime.GOOS,
|
||||
0, // euid 0 (root)
|
||||
mockUsers,
|
||||
nil,
|
||||
)
|
||||
defer cleanup()
|
||||
|
||||
// Create server with specific configuration
|
||||
serverConfig := &Config{
|
||||
HostKeyPEM: hostKey,
|
||||
JWT: nil,
|
||||
}
|
||||
server := New(serverConfig)
|
||||
server.SetAllowRootLogin(tt.allowRoot)
|
||||
|
||||
// Test the userNameLookup method directly
|
||||
user, err := server.userNameLookup(tt.username)
|
||||
|
||||
if tt.expectError {
|
||||
assert.Error(t, err, tt.description)
|
||||
if tt.username == "root" || strings.ToLower(tt.username) == "administrator" {
|
||||
// Check for appropriate error message based on platform capabilities
|
||||
errorMsg := err.Error()
|
||||
// Either privileged user restriction OR user switching limitation
|
||||
hasPrivilegedError := strings.Contains(errorMsg, "privileged user")
|
||||
hasSwitchingError := strings.Contains(errorMsg, "cannot switch") || strings.Contains(errorMsg, "user switching not supported")
|
||||
assert.True(t, hasPrivilegedError || hasSwitchingError,
|
||||
"Expected privileged user or user switching error, got: %s", errorMsg)
|
||||
}
|
||||
} else {
|
||||
if tt.username == "root" || strings.ToLower(tt.username) == "administrator" {
|
||||
// For privileged users, we expect either success or a different error
|
||||
// (like user not found), but not the "login disabled" error
|
||||
if err != nil {
|
||||
assert.NotContains(t, err.Error(), "privileged user login is disabled")
|
||||
}
|
||||
} else {
|
||||
// For regular users, lookup should generally succeed or fall back gracefully
|
||||
// Note: may return current user as fallback
|
||||
assert.NotNil(t, user)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestServer_PortForwardingRestriction(t *testing.T) {
|
||||
// Test that the port forwarding callbacks properly respect configuration flags
|
||||
// This is a unit test of the callback logic, not a full integration test
|
||||
|
||||
// Generate host key for server
|
||||
hostKey, err := ssh.GeneratePrivateKey(ssh.ED25519)
|
||||
require.NoError(t, err)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
allowLocalForwarding bool
|
||||
allowRemoteForwarding bool
|
||||
description string
|
||||
}{
|
||||
{
|
||||
name: "all forwarding allowed",
|
||||
allowLocalForwarding: true,
|
||||
allowRemoteForwarding: true,
|
||||
description: "Both local and remote forwarding should be allowed",
|
||||
},
|
||||
{
|
||||
name: "local forwarding disabled",
|
||||
allowLocalForwarding: false,
|
||||
allowRemoteForwarding: true,
|
||||
description: "Local forwarding should be denied when disabled",
|
||||
},
|
||||
{
|
||||
name: "remote forwarding disabled",
|
||||
allowLocalForwarding: true,
|
||||
allowRemoteForwarding: false,
|
||||
description: "Remote forwarding should be denied when disabled",
|
||||
},
|
||||
{
|
||||
name: "all forwarding disabled",
|
||||
allowLocalForwarding: false,
|
||||
allowRemoteForwarding: false,
|
||||
description: "Both forwarding types should be denied when disabled",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Create server with specific configuration
|
||||
serverConfig := &Config{
|
||||
HostKeyPEM: hostKey,
|
||||
JWT: nil,
|
||||
}
|
||||
server := New(serverConfig)
|
||||
server.SetAllowLocalPortForwarding(tt.allowLocalForwarding)
|
||||
server.SetAllowRemotePortForwarding(tt.allowRemoteForwarding)
|
||||
|
||||
// We need to access the internal configuration to simulate the callback tests
|
||||
// Since the callbacks are created inside the Start method, we'll test the logic directly
|
||||
|
||||
// Test the configuration values are set correctly
|
||||
server.mu.RLock()
|
||||
allowLocal := server.allowLocalPortForwarding
|
||||
allowRemote := server.allowRemotePortForwarding
|
||||
server.mu.RUnlock()
|
||||
|
||||
assert.Equal(t, tt.allowLocalForwarding, allowLocal, "Local forwarding configuration should be set correctly")
|
||||
assert.Equal(t, tt.allowRemoteForwarding, allowRemote, "Remote forwarding configuration should be set correctly")
|
||||
|
||||
// Simulate the callback logic
|
||||
localResult := allowLocal // This would be the callback return value
|
||||
remoteResult := allowRemote // This would be the callback return value
|
||||
|
||||
assert.Equal(t, tt.allowLocalForwarding, localResult,
|
||||
"Local port forwarding callback should return correct value")
|
||||
assert.Equal(t, tt.allowRemoteForwarding, remoteResult,
|
||||
"Remote port forwarding callback should return correct value")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestServer_PortConflictHandling(t *testing.T) {
|
||||
// Test that multiple sessions requesting the same local port are handled naturally by the OS
|
||||
// Get current user for SSH connection
|
||||
currentUser, err := user.Current()
|
||||
require.NoError(t, err, "Should be able to get current user")
|
||||
|
||||
// Generate host key for server
|
||||
hostKey, err := ssh.GeneratePrivateKey(ssh.ED25519)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create server
|
||||
serverConfig := &Config{
|
||||
HostKeyPEM: hostKey,
|
||||
JWT: nil,
|
||||
}
|
||||
server := New(serverConfig)
|
||||
server.SetAllowRootLogin(true)
|
||||
|
||||
serverAddr := StartTestServer(t, server)
|
||||
defer func() {
|
||||
err := server.Stop()
|
||||
require.NoError(t, err)
|
||||
}()
|
||||
|
||||
// Get a free port for testing
|
||||
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
testPort := ln.Addr().(*net.TCPAddr).Port
|
||||
err = ln.Close()
|
||||
require.NoError(t, err)
|
||||
|
||||
// Connect first client
|
||||
ctx1, cancel1 := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel1()
|
||||
|
||||
client1, err := sshclient.Dial(ctx1, serverAddr, currentUser.Username, sshclient.DialOptions{
|
||||
InsecureSkipVerify: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
err := client1.Close()
|
||||
assert.NoError(t, err)
|
||||
}()
|
||||
|
||||
// Connect second client
|
||||
ctx2, cancel2 := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel2()
|
||||
|
||||
client2, err := sshclient.Dial(ctx2, serverAddr, currentUser.Username, sshclient.DialOptions{
|
||||
InsecureSkipVerify: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
err := client2.Close()
|
||||
assert.NoError(t, err)
|
||||
}()
|
||||
|
||||
// First client binds to the test port
|
||||
localAddr1 := fmt.Sprintf("127.0.0.1:%d", testPort)
|
||||
remoteAddr := "127.0.0.1:80"
|
||||
|
||||
// Start first client's port forwarding
|
||||
done1 := make(chan error, 1)
|
||||
go func() {
|
||||
// This should succeed and hold the port
|
||||
err := client1.LocalPortForward(ctx1, localAddr1, remoteAddr)
|
||||
done1 <- err
|
||||
}()
|
||||
|
||||
// Give first client time to bind
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
|
||||
// Second client tries to bind to same port
|
||||
localAddr2 := fmt.Sprintf("127.0.0.1:%d", testPort)
|
||||
|
||||
shortCtx, shortCancel := context.WithTimeout(context.Background(), 1*time.Second)
|
||||
defer shortCancel()
|
||||
|
||||
err = client2.LocalPortForward(shortCtx, localAddr2, remoteAddr)
|
||||
// Second client should fail due to "address already in use"
|
||||
assert.Error(t, err, "Second client should fail to bind to same port")
|
||||
if err != nil {
|
||||
// The error should indicate the address is already in use
|
||||
errMsg := strings.ToLower(err.Error())
|
||||
if runtime.GOOS == "windows" {
|
||||
assert.Contains(t, errMsg, "only one usage of each socket address",
|
||||
"Error should indicate port conflict")
|
||||
} else {
|
||||
assert.Contains(t, errMsg, "address already in use",
|
||||
"Error should indicate port conflict")
|
||||
}
|
||||
}
|
||||
|
||||
// Cancel first client's context and wait for it to finish
|
||||
cancel1()
|
||||
select {
|
||||
case err1 := <-done1:
|
||||
// Should get context cancelled or deadline exceeded
|
||||
assert.Error(t, err1, "First client should exit when context cancelled")
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Error("First client did not exit within timeout")
|
||||
}
|
||||
}
|
||||
|
||||
func TestServer_IsPrivilegedUser(t *testing.T) {
|
||||
|
||||
tests := []struct {
|
||||
username string
|
||||
expected bool
|
||||
description string
|
||||
}{
|
||||
{
|
||||
username: "root",
|
||||
expected: true,
|
||||
description: "root should be considered privileged",
|
||||
},
|
||||
{
|
||||
username: "regular",
|
||||
expected: false,
|
||||
description: "regular user should not be privileged",
|
||||
},
|
||||
{
|
||||
username: "",
|
||||
expected: false,
|
||||
description: "empty username should not be privileged",
|
||||
},
|
||||
}
|
||||
|
||||
// Add Windows-specific tests
|
||||
if runtime.GOOS == "windows" {
|
||||
tests = append(tests, []struct {
|
||||
username string
|
||||
expected bool
|
||||
description string
|
||||
}{
|
||||
{
|
||||
username: "Administrator",
|
||||
expected: true,
|
||||
description: "Administrator should be considered privileged on Windows",
|
||||
},
|
||||
{
|
||||
username: "administrator",
|
||||
expected: true,
|
||||
description: "administrator should be considered privileged on Windows (case insensitive)",
|
||||
},
|
||||
}...)
|
||||
} else {
|
||||
// On non-Windows systems, Administrator should not be privileged
|
||||
tests = append(tests, []struct {
|
||||
username string
|
||||
expected bool
|
||||
description string
|
||||
}{
|
||||
{
|
||||
username: "Administrator",
|
||||
expected: false,
|
||||
description: "Administrator should not be privileged on non-Windows systems",
|
||||
},
|
||||
}...)
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.description, func(t *testing.T) {
|
||||
result := isPrivilegedUsername(tt.username)
|
||||
assert.Equal(t, tt.expected, result, tt.description)
|
||||
})
|
||||
}
|
||||
}
|
||||
441
client/ssh/server/server_test.go
Normal file
441
client/ssh/server/server_test.go
Normal file
@@ -0,0 +1,441 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"os/user"
|
||||
"runtime"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
cryptossh "golang.org/x/crypto/ssh"
|
||||
|
||||
nbssh "github.com/netbirdio/netbird/client/ssh"
|
||||
)
|
||||
|
||||
func TestServer_StartStop(t *testing.T) {
|
||||
key, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
||||
require.NoError(t, err)
|
||||
|
||||
serverConfig := &Config{
|
||||
HostKeyPEM: key,
|
||||
JWT: nil,
|
||||
}
|
||||
server := New(serverConfig)
|
||||
|
||||
err = server.Stop()
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestSSHServerIntegration(t *testing.T) {
|
||||
// Generate host key for server
|
||||
hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Generate client key pair
|
||||
clientPrivKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create server with random port
|
||||
serverConfig := &Config{
|
||||
HostKeyPEM: hostKey,
|
||||
JWT: nil,
|
||||
}
|
||||
server := New(serverConfig)
|
||||
|
||||
// Start server in background
|
||||
serverAddr := "127.0.0.1:0"
|
||||
started := make(chan string, 1)
|
||||
errChan := make(chan error, 1)
|
||||
|
||||
go func() {
|
||||
// Get a free port
|
||||
ln, err := net.Listen("tcp", serverAddr)
|
||||
if err != nil {
|
||||
errChan <- err
|
||||
return
|
||||
}
|
||||
actualAddr := ln.Addr().String()
|
||||
if err := ln.Close(); err != nil {
|
||||
errChan <- fmt.Errorf("close temp listener: %w", err)
|
||||
return
|
||||
}
|
||||
|
||||
addrPort, _ := netip.ParseAddrPort(actualAddr)
|
||||
if err := server.Start(context.Background(), addrPort); err != nil {
|
||||
errChan <- err
|
||||
return
|
||||
}
|
||||
started <- actualAddr
|
||||
}()
|
||||
|
||||
select {
|
||||
case actualAddr := <-started:
|
||||
serverAddr = actualAddr
|
||||
case err := <-errChan:
|
||||
t.Fatalf("Server failed to start: %v", err)
|
||||
case <-time.After(5 * time.Second):
|
||||
t.Fatal("Server start timeout")
|
||||
}
|
||||
|
||||
defer func() {
|
||||
err := server.Stop()
|
||||
require.NoError(t, err)
|
||||
}()
|
||||
|
||||
// Parse client private key
|
||||
signer, err := cryptossh.ParsePrivateKey(clientPrivKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Parse server host key for verification
|
||||
hostPrivParsed, err := cryptossh.ParsePrivateKey(hostKey)
|
||||
require.NoError(t, err)
|
||||
hostPubKey := hostPrivParsed.PublicKey()
|
||||
|
||||
// Get current user for SSH connection
|
||||
currentUser, err := user.Current()
|
||||
require.NoError(t, err, "Should be able to get current user for test")
|
||||
|
||||
// Create SSH client config
|
||||
config := &cryptossh.ClientConfig{
|
||||
User: currentUser.Username,
|
||||
Auth: []cryptossh.AuthMethod{
|
||||
cryptossh.PublicKeys(signer),
|
||||
},
|
||||
HostKeyCallback: cryptossh.FixedHostKey(hostPubKey),
|
||||
Timeout: 3 * time.Second,
|
||||
}
|
||||
|
||||
// Connect to SSH server
|
||||
client, err := cryptossh.Dial("tcp", serverAddr, config)
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
if err := client.Close(); err != nil {
|
||||
t.Logf("close client: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
// Test creating a session
|
||||
session, err := client.NewSession()
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
if err := session.Close(); err != nil {
|
||||
t.Logf("close session: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
// Note: Since we don't have a real shell environment in tests,
|
||||
// we can't test actual command execution, but we can verify
|
||||
// the connection and authentication work
|
||||
t.Log("SSH connection and authentication successful")
|
||||
}
|
||||
|
||||
func TestSSHServerMultipleConnections(t *testing.T) {
|
||||
// Generate host key for server
|
||||
hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Generate client key pair
|
||||
clientPrivKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create server
|
||||
serverConfig := &Config{
|
||||
HostKeyPEM: hostKey,
|
||||
JWT: nil,
|
||||
}
|
||||
server := New(serverConfig)
|
||||
|
||||
// Start server
|
||||
serverAddr := "127.0.0.1:0"
|
||||
started := make(chan string, 1)
|
||||
errChan := make(chan error, 1)
|
||||
|
||||
go func() {
|
||||
ln, err := net.Listen("tcp", serverAddr)
|
||||
if err != nil {
|
||||
errChan <- err
|
||||
return
|
||||
}
|
||||
actualAddr := ln.Addr().String()
|
||||
if err := ln.Close(); err != nil {
|
||||
errChan <- fmt.Errorf("close temp listener: %w", err)
|
||||
return
|
||||
}
|
||||
|
||||
addrPort, _ := netip.ParseAddrPort(actualAddr)
|
||||
if err := server.Start(context.Background(), addrPort); err != nil {
|
||||
errChan <- err
|
||||
return
|
||||
}
|
||||
started <- actualAddr
|
||||
}()
|
||||
|
||||
select {
|
||||
case actualAddr := <-started:
|
||||
serverAddr = actualAddr
|
||||
case err := <-errChan:
|
||||
t.Fatalf("Server failed to start: %v", err)
|
||||
case <-time.After(5 * time.Second):
|
||||
t.Fatal("Server start timeout")
|
||||
}
|
||||
|
||||
defer func() {
|
||||
err := server.Stop()
|
||||
require.NoError(t, err)
|
||||
}()
|
||||
|
||||
// Parse client private key
|
||||
signer, err := cryptossh.ParsePrivateKey(clientPrivKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Parse server host key
|
||||
hostPrivParsed, err := cryptossh.ParsePrivateKey(hostKey)
|
||||
require.NoError(t, err)
|
||||
hostPubKey := hostPrivParsed.PublicKey()
|
||||
|
||||
// Get current user for SSH connection
|
||||
currentUser, err := user.Current()
|
||||
require.NoError(t, err, "Should be able to get current user for test")
|
||||
|
||||
config := &cryptossh.ClientConfig{
|
||||
User: currentUser.Username,
|
||||
Auth: []cryptossh.AuthMethod{
|
||||
cryptossh.PublicKeys(signer),
|
||||
},
|
||||
HostKeyCallback: cryptossh.FixedHostKey(hostPubKey),
|
||||
Timeout: 3 * time.Second,
|
||||
}
|
||||
|
||||
// Test multiple concurrent connections
|
||||
const numConnections = 5
|
||||
results := make(chan error, numConnections)
|
||||
|
||||
for i := 0; i < numConnections; i++ {
|
||||
go func(id int) {
|
||||
client, err := cryptossh.Dial("tcp", serverAddr, config)
|
||||
if err != nil {
|
||||
results <- fmt.Errorf("connection %d failed: %w", id, err)
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
_ = client.Close() // Ignore error in test goroutine
|
||||
}()
|
||||
|
||||
session, err := client.NewSession()
|
||||
if err != nil {
|
||||
results <- fmt.Errorf("session %d failed: %w", id, err)
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
_ = session.Close() // Ignore error in test goroutine
|
||||
}()
|
||||
|
||||
results <- nil
|
||||
}(i)
|
||||
}
|
||||
|
||||
// Wait for all connections to complete
|
||||
for i := 0; i < numConnections; i++ {
|
||||
select {
|
||||
case err := <-results:
|
||||
assert.NoError(t, err)
|
||||
case <-time.After(10 * time.Second):
|
||||
t.Fatalf("Connection %d timed out", i)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestSSHServerNoAuthMode(t *testing.T) {
|
||||
// Generate host key for server
|
||||
hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create server
|
||||
serverConfig := &Config{
|
||||
HostKeyPEM: hostKey,
|
||||
JWT: nil,
|
||||
}
|
||||
server := New(serverConfig)
|
||||
|
||||
// Start server
|
||||
serverAddr := "127.0.0.1:0"
|
||||
started := make(chan string, 1)
|
||||
errChan := make(chan error, 1)
|
||||
|
||||
go func() {
|
||||
ln, err := net.Listen("tcp", serverAddr)
|
||||
if err != nil {
|
||||
errChan <- err
|
||||
return
|
||||
}
|
||||
actualAddr := ln.Addr().String()
|
||||
if err := ln.Close(); err != nil {
|
||||
errChan <- fmt.Errorf("close temp listener: %w", err)
|
||||
return
|
||||
}
|
||||
|
||||
addrPort, _ := netip.ParseAddrPort(actualAddr)
|
||||
if err := server.Start(context.Background(), addrPort); err != nil {
|
||||
errChan <- err
|
||||
return
|
||||
}
|
||||
started <- actualAddr
|
||||
}()
|
||||
|
||||
select {
|
||||
case actualAddr := <-started:
|
||||
serverAddr = actualAddr
|
||||
case err := <-errChan:
|
||||
t.Fatalf("Server failed to start: %v", err)
|
||||
case <-time.After(5 * time.Second):
|
||||
t.Fatal("Server start timeout")
|
||||
}
|
||||
|
||||
defer func() {
|
||||
err := server.Stop()
|
||||
require.NoError(t, err)
|
||||
}()
|
||||
|
||||
// Generate a client private key for SSH protocol (server doesn't check it)
|
||||
clientPrivKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
||||
require.NoError(t, err)
|
||||
clientSigner, err := cryptossh.ParsePrivateKey(clientPrivKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Parse server host key
|
||||
hostPrivParsed, err := cryptossh.ParsePrivateKey(hostKey)
|
||||
require.NoError(t, err)
|
||||
hostPubKey := hostPrivParsed.PublicKey()
|
||||
|
||||
// Get current user for SSH connection
|
||||
currentUser, err := user.Current()
|
||||
require.NoError(t, err, "Should be able to get current user for test")
|
||||
|
||||
// Try to connect with client key
|
||||
config := &cryptossh.ClientConfig{
|
||||
User: currentUser.Username,
|
||||
Auth: []cryptossh.AuthMethod{
|
||||
cryptossh.PublicKeys(clientSigner),
|
||||
},
|
||||
HostKeyCallback: cryptossh.FixedHostKey(hostPubKey),
|
||||
Timeout: 3 * time.Second,
|
||||
}
|
||||
|
||||
// This should succeed in no-auth mode (server doesn't verify keys)
|
||||
conn, err := cryptossh.Dial("tcp", serverAddr, config)
|
||||
assert.NoError(t, err, "Connection should succeed in no-auth mode")
|
||||
if conn != nil {
|
||||
assert.NoError(t, conn.Close())
|
||||
}
|
||||
}
|
||||
|
||||
func TestSSHServerStartStopCycle(t *testing.T) {
|
||||
hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
||||
require.NoError(t, err)
|
||||
|
||||
serverConfig := &Config{
|
||||
HostKeyPEM: hostKey,
|
||||
JWT: nil,
|
||||
}
|
||||
server := New(serverConfig)
|
||||
serverAddr := "127.0.0.1:0"
|
||||
|
||||
// Test multiple start/stop cycles
|
||||
for i := 0; i < 3; i++ {
|
||||
t.Logf("Start/stop cycle %d", i+1)
|
||||
|
||||
started := make(chan string, 1)
|
||||
errChan := make(chan error, 1)
|
||||
|
||||
go func() {
|
||||
ln, err := net.Listen("tcp", serverAddr)
|
||||
if err != nil {
|
||||
errChan <- err
|
||||
return
|
||||
}
|
||||
actualAddr := ln.Addr().String()
|
||||
if err := ln.Close(); err != nil {
|
||||
errChan <- fmt.Errorf("close temp listener: %w", err)
|
||||
return
|
||||
}
|
||||
|
||||
addrPort, _ := netip.ParseAddrPort(actualAddr)
|
||||
if err := server.Start(context.Background(), addrPort); err != nil {
|
||||
errChan <- err
|
||||
return
|
||||
}
|
||||
started <- actualAddr
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-started:
|
||||
case err := <-errChan:
|
||||
t.Fatalf("Cycle %d: Server failed to start: %v", i+1, err)
|
||||
case <-time.After(5 * time.Second):
|
||||
t.Fatalf("Cycle %d: Server start timeout", i+1)
|
||||
}
|
||||
|
||||
err = server.Stop()
|
||||
require.NoError(t, err, "Cycle %d: Stop should succeed", i+1)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSSHServer_WindowsShellHandling(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("Skipping Windows shell test in short mode")
|
||||
}
|
||||
|
||||
server := &Server{}
|
||||
|
||||
if runtime.GOOS == "windows" {
|
||||
// Test Windows cmd.exe shell behavior
|
||||
args := server.getShellCommandArgs("cmd.exe", "echo test")
|
||||
assert.Equal(t, "cmd.exe", args[0])
|
||||
assert.Equal(t, "-Command", args[1])
|
||||
assert.Equal(t, "echo test", args[2])
|
||||
|
||||
// Test PowerShell behavior
|
||||
args = server.getShellCommandArgs("powershell.exe", "echo test")
|
||||
assert.Equal(t, "powershell.exe", args[0])
|
||||
assert.Equal(t, "-Command", args[1])
|
||||
assert.Equal(t, "echo test", args[2])
|
||||
} else {
|
||||
// Test Unix shell behavior
|
||||
args := server.getShellCommandArgs("/bin/sh", "echo test")
|
||||
assert.Equal(t, "/bin/sh", args[0])
|
||||
assert.Equal(t, "-l", args[1])
|
||||
assert.Equal(t, "-c", args[2])
|
||||
assert.Equal(t, "echo test", args[3])
|
||||
}
|
||||
}
|
||||
|
||||
func TestSSHServer_PortForwardingConfiguration(t *testing.T) {
|
||||
hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
||||
require.NoError(t, err)
|
||||
|
||||
serverConfig1 := &Config{
|
||||
HostKeyPEM: hostKey,
|
||||
JWT: nil,
|
||||
}
|
||||
server1 := New(serverConfig1)
|
||||
|
||||
serverConfig2 := &Config{
|
||||
HostKeyPEM: hostKey,
|
||||
JWT: nil,
|
||||
}
|
||||
server2 := New(serverConfig2)
|
||||
|
||||
assert.False(t, server1.allowLocalPortForwarding, "Local port forwarding should be disabled by default for security")
|
||||
assert.False(t, server1.allowRemotePortForwarding, "Remote port forwarding should be disabled by default for security")
|
||||
|
||||
server2.SetAllowLocalPortForwarding(true)
|
||||
server2.SetAllowRemotePortForwarding(true)
|
||||
|
||||
assert.True(t, server2.allowLocalPortForwarding, "Local port forwarding should be enabled when explicitly set")
|
||||
assert.True(t, server2.allowRemotePortForwarding, "Remote port forwarding should be enabled when explicitly set")
|
||||
}
|
||||
168
client/ssh/server/session_handlers.go
Normal file
168
client/ssh/server/session_handlers.go
Normal file
@@ -0,0 +1,168 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gliderlabs/ssh"
|
||||
log "github.com/sirupsen/logrus"
|
||||
cryptossh "golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
// sessionHandler handles SSH sessions
|
||||
func (s *Server) sessionHandler(session ssh.Session) {
|
||||
sessionKey := s.registerSession(session)
|
||||
|
||||
key := newAuthKey(session.User(), session.RemoteAddr())
|
||||
s.mu.Lock()
|
||||
jwtUsername := s.pendingAuthJWT[key]
|
||||
if jwtUsername != "" {
|
||||
s.sessionJWTUsers[sessionKey] = jwtUsername
|
||||
delete(s.pendingAuthJWT, key)
|
||||
}
|
||||
s.mu.Unlock()
|
||||
|
||||
logger := log.WithField("session", sessionKey)
|
||||
if jwtUsername != "" {
|
||||
logger = logger.WithField("jwt_user", jwtUsername)
|
||||
logger.Infof("SSH session started (JWT user: %s)", jwtUsername)
|
||||
} else {
|
||||
logger.Infof("SSH session started")
|
||||
}
|
||||
sessionStart := time.Now()
|
||||
|
||||
defer s.unregisterSession(sessionKey, session)
|
||||
defer func() {
|
||||
duration := time.Since(sessionStart).Round(time.Millisecond)
|
||||
if err := session.Close(); err != nil && !errors.Is(err, io.EOF) {
|
||||
logger.Warnf("close session after %v: %v", duration, err)
|
||||
}
|
||||
logger.Infof("SSH session closed after %v", duration)
|
||||
}()
|
||||
|
||||
privilegeResult, err := s.userPrivilegeCheck(session.User())
|
||||
if err != nil {
|
||||
s.handlePrivError(logger, session, err)
|
||||
return
|
||||
}
|
||||
|
||||
ptyReq, winCh, isPty := session.Pty()
|
||||
hasCommand := len(session.Command()) > 0
|
||||
|
||||
switch {
|
||||
case isPty && hasCommand:
|
||||
// ssh -t <host> <cmd> - Pty command execution
|
||||
s.handleCommand(logger, session, privilegeResult, winCh)
|
||||
case isPty:
|
||||
// ssh <host> - Pty interactive session (login)
|
||||
s.handlePty(logger, session, privilegeResult, ptyReq, winCh)
|
||||
case hasCommand:
|
||||
// ssh <host> <cmd> - non-Pty command execution
|
||||
s.handleCommand(logger, session, privilegeResult, nil)
|
||||
default:
|
||||
s.rejectInvalidSession(logger, session)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) rejectInvalidSession(logger *log.Entry, session ssh.Session) {
|
||||
if _, err := io.WriteString(session, "no command specified and Pty not requested\n"); err != nil {
|
||||
logger.Debugf(errWriteSession, err)
|
||||
}
|
||||
if err := session.Exit(1); err != nil {
|
||||
logSessionExitError(logger, err)
|
||||
}
|
||||
logger.Infof("rejected non-Pty session without command from %s", session.RemoteAddr())
|
||||
}
|
||||
|
||||
func (s *Server) registerSession(session ssh.Session) SessionKey {
|
||||
sessionID := session.Context().Value(ssh.ContextKeySessionID)
|
||||
if sessionID == nil {
|
||||
sessionID = fmt.Sprintf("%p", session)
|
||||
}
|
||||
|
||||
// Create a short 4-byte identifier from the full session ID
|
||||
hasher := sha256.New()
|
||||
hasher.Write([]byte(fmt.Sprintf("%v", sessionID)))
|
||||
hash := hasher.Sum(nil)
|
||||
shortID := hex.EncodeToString(hash[:4])
|
||||
|
||||
remoteAddr := session.RemoteAddr().String()
|
||||
username := session.User()
|
||||
sessionKey := SessionKey(fmt.Sprintf("%s@%s-%s", username, remoteAddr, shortID))
|
||||
|
||||
s.mu.Lock()
|
||||
s.sessions[sessionKey] = session
|
||||
s.mu.Unlock()
|
||||
|
||||
return sessionKey
|
||||
}
|
||||
|
||||
func (s *Server) unregisterSession(sessionKey SessionKey, session ssh.Session) {
|
||||
s.mu.Lock()
|
||||
delete(s.sessions, sessionKey)
|
||||
delete(s.sessionJWTUsers, sessionKey)
|
||||
|
||||
// Cancel all port forwarding connections for this session
|
||||
var connectionsToCancel []ConnectionKey
|
||||
for key := range s.sessionCancels {
|
||||
if strings.HasPrefix(string(key), string(sessionKey)+"-") {
|
||||
connectionsToCancel = append(connectionsToCancel, key)
|
||||
}
|
||||
}
|
||||
|
||||
for _, key := range connectionsToCancel {
|
||||
if cancelFunc, exists := s.sessionCancels[key]; exists {
|
||||
log.WithField("session", sessionKey).Debugf("cancelling port forwarding context: %s", key)
|
||||
cancelFunc()
|
||||
delete(s.sessionCancels, key)
|
||||
}
|
||||
}
|
||||
|
||||
if sshConnValue := session.Context().Value(ssh.ContextKeyConn); sshConnValue != nil {
|
||||
if sshConn, ok := sshConnValue.(*cryptossh.ServerConn); ok {
|
||||
delete(s.sshConnections, sshConn)
|
||||
}
|
||||
}
|
||||
|
||||
s.mu.Unlock()
|
||||
}
|
||||
|
||||
func (s *Server) handlePrivError(logger *log.Entry, session ssh.Session, err error) {
|
||||
logger.Warnf("user privilege check failed: %v", err)
|
||||
|
||||
errorMsg := s.buildUserLookupErrorMessage(err)
|
||||
|
||||
if _, writeErr := fmt.Fprint(session, errorMsg); writeErr != nil {
|
||||
logger.Debugf(errWriteSession, writeErr)
|
||||
}
|
||||
if exitErr := session.Exit(1); exitErr != nil {
|
||||
logSessionExitError(logger, exitErr)
|
||||
}
|
||||
}
|
||||
|
||||
// buildUserLookupErrorMessage creates appropriate user-facing error messages based on error type
|
||||
func (s *Server) buildUserLookupErrorMessage(err error) string {
|
||||
var privilegedErr *PrivilegedUserError
|
||||
|
||||
switch {
|
||||
case errors.As(err, &privilegedErr):
|
||||
if privilegedErr.Username == "root" {
|
||||
return "root login is disabled on this SSH server\n"
|
||||
}
|
||||
return "privileged user access is disabled on this SSH server\n"
|
||||
|
||||
case errors.Is(err, ErrPrivilegeRequired):
|
||||
return "Windows user switching failed - NetBird must run with elevated privileges for user switching\n"
|
||||
|
||||
case errors.Is(err, ErrPrivilegedUserSwitch):
|
||||
return "Cannot switch to privileged user - current user lacks required privileges\n"
|
||||
|
||||
default:
|
||||
return "User authentication failed\n"
|
||||
}
|
||||
}
|
||||
22
client/ssh/server/session_handlers_js.go
Normal file
22
client/ssh/server/session_handlers_js.go
Normal file
@@ -0,0 +1,22 @@
|
||||
//go:build js
|
||||
|
||||
package server
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/gliderlabs/ssh"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// handlePty is not supported on JS/WASM
|
||||
func (s *Server) handlePty(logger *log.Entry, session ssh.Session, _ PrivilegeCheckResult, _ ssh.Pty, _ <-chan ssh.Window) bool {
|
||||
errorMsg := "PTY sessions are not supported on WASM/JS platform\n"
|
||||
if _, err := fmt.Fprint(session.Stderr(), errorMsg); err != nil {
|
||||
logger.Debugf(errWriteSession, err)
|
||||
}
|
||||
if err := session.Exit(1); err != nil {
|
||||
logSessionExitError(logger, err)
|
||||
}
|
||||
return false
|
||||
}
|
||||
81
client/ssh/server/sftp.go
Normal file
81
client/ssh/server/sftp.go
Normal file
@@ -0,0 +1,81 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
"github.com/gliderlabs/ssh"
|
||||
"github.com/pkg/sftp"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// SetAllowSFTP enables or disables SFTP support
|
||||
func (s *Server) SetAllowSFTP(allow bool) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.allowSFTP = allow
|
||||
}
|
||||
|
||||
// sftpSubsystemHandler handles SFTP subsystem requests
|
||||
func (s *Server) sftpSubsystemHandler(sess ssh.Session) {
|
||||
s.mu.RLock()
|
||||
allowSFTP := s.allowSFTP
|
||||
s.mu.RUnlock()
|
||||
|
||||
if !allowSFTP {
|
||||
log.Debugf("SFTP subsystem request denied: SFTP disabled")
|
||||
if err := sess.Exit(1); err != nil {
|
||||
log.Debugf("SFTP session exit failed: %v", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
result := s.CheckPrivileges(PrivilegeCheckRequest{
|
||||
RequestedUsername: sess.User(),
|
||||
FeatureSupportsUserSwitch: true,
|
||||
FeatureName: FeatureSFTP,
|
||||
})
|
||||
|
||||
if !result.Allowed {
|
||||
log.Warnf("SFTP access denied for user %s from %s: %v", sess.User(), sess.RemoteAddr(), result.Error)
|
||||
if err := sess.Exit(1); err != nil {
|
||||
log.Debugf("exit SFTP session: %v", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
log.Debugf("SFTP subsystem request from user %s (effective user %s)", sess.User(), result.User.Username)
|
||||
|
||||
if !result.RequiresUserSwitching {
|
||||
if err := s.executeSftpDirect(sess); err != nil {
|
||||
log.Errorf("SFTP direct execution: %v", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if err := s.executeSftpWithPrivilegeDrop(sess, result.User); err != nil {
|
||||
log.Errorf("SFTP privilege drop execution: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// executeSftpDirect executes SFTP directly without privilege dropping
|
||||
func (s *Server) executeSftpDirect(sess ssh.Session) error {
|
||||
log.Debugf("starting SFTP session for user %s (no privilege dropping)", sess.User())
|
||||
|
||||
sftpServer, err := sftp.NewServer(sess)
|
||||
if err != nil {
|
||||
return fmt.Errorf("SFTP server creation: %w", err)
|
||||
}
|
||||
|
||||
defer func() {
|
||||
if err := sftpServer.Close(); err != nil {
|
||||
log.Debugf("failed to close sftp server: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
if err := sftpServer.Serve(); err != nil && err != io.EOF {
|
||||
return fmt.Errorf("serve: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
12
client/ssh/server/sftp_js.go
Normal file
12
client/ssh/server/sftp_js.go
Normal file
@@ -0,0 +1,12 @@
|
||||
//go:build js
|
||||
|
||||
package server
|
||||
|
||||
import (
|
||||
"os/user"
|
||||
)
|
||||
|
||||
// parseUserCredentials is not supported on JS/WASM
|
||||
func (s *Server) parseUserCredentials(_ *user.User) (uint32, uint32, []uint32, error) {
|
||||
return 0, 0, nil, errNotSupported
|
||||
}
|
||||
228
client/ssh/server/sftp_test.go
Normal file
228
client/ssh/server/sftp_test.go
Normal file
@@ -0,0 +1,228 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
"os/user"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/pkg/sftp"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
cryptossh "golang.org/x/crypto/ssh"
|
||||
|
||||
"github.com/netbirdio/netbird/client/ssh"
|
||||
)
|
||||
|
||||
func TestSSHServer_SFTPSubsystem(t *testing.T) {
|
||||
// Skip SFTP test when running as root due to protocol issues in some environments
|
||||
if os.Geteuid() == 0 {
|
||||
t.Skip("Skipping SFTP test when running as root - may have protocol compatibility issues")
|
||||
}
|
||||
|
||||
// Get current user for SSH connection
|
||||
currentUser, err := user.Current()
|
||||
require.NoError(t, err, "Should be able to get current user")
|
||||
|
||||
// Generate host key for server
|
||||
hostKey, err := ssh.GeneratePrivateKey(ssh.ED25519)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Generate client key pair
|
||||
clientPrivKey, err := ssh.GeneratePrivateKey(ssh.ED25519)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create server with SFTP enabled
|
||||
serverConfig := &Config{
|
||||
HostKeyPEM: hostKey,
|
||||
JWT: nil,
|
||||
}
|
||||
server := New(serverConfig)
|
||||
server.SetAllowSFTP(true)
|
||||
server.SetAllowRootLogin(true)
|
||||
|
||||
// Start server
|
||||
serverAddr := "127.0.0.1:0"
|
||||
started := make(chan string, 1)
|
||||
errChan := make(chan error, 1)
|
||||
|
||||
go func() {
|
||||
ln, err := net.Listen("tcp", serverAddr)
|
||||
if err != nil {
|
||||
errChan <- err
|
||||
return
|
||||
}
|
||||
actualAddr := ln.Addr().String()
|
||||
if err := ln.Close(); err != nil {
|
||||
errChan <- fmt.Errorf("close temp listener: %w", err)
|
||||
return
|
||||
}
|
||||
|
||||
addrPort, _ := netip.ParseAddrPort(actualAddr)
|
||||
if err := server.Start(context.Background(), addrPort); err != nil {
|
||||
errChan <- err
|
||||
return
|
||||
}
|
||||
started <- actualAddr
|
||||
}()
|
||||
|
||||
select {
|
||||
case actualAddr := <-started:
|
||||
serverAddr = actualAddr
|
||||
case err := <-errChan:
|
||||
t.Fatalf("Server failed to start: %v", err)
|
||||
case <-time.After(5 * time.Second):
|
||||
t.Fatal("Server start timeout")
|
||||
}
|
||||
|
||||
defer func() {
|
||||
err := server.Stop()
|
||||
require.NoError(t, err)
|
||||
}()
|
||||
|
||||
// Parse client private key
|
||||
signer, err := cryptossh.ParsePrivateKey(clientPrivKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Parse server host key
|
||||
hostPrivParsed, err := cryptossh.ParsePrivateKey(hostKey)
|
||||
require.NoError(t, err)
|
||||
hostPubKey := hostPrivParsed.PublicKey()
|
||||
|
||||
// (currentUser already obtained at function start)
|
||||
|
||||
// Create SSH client connection
|
||||
clientConfig := &cryptossh.ClientConfig{
|
||||
User: currentUser.Username,
|
||||
Auth: []cryptossh.AuthMethod{
|
||||
cryptossh.PublicKeys(signer),
|
||||
},
|
||||
HostKeyCallback: cryptossh.FixedHostKey(hostPubKey),
|
||||
Timeout: 5 * time.Second,
|
||||
}
|
||||
|
||||
conn, err := cryptossh.Dial("tcp", serverAddr, clientConfig)
|
||||
require.NoError(t, err, "SSH connection should succeed")
|
||||
defer func() {
|
||||
if err := conn.Close(); err != nil {
|
||||
t.Logf("connection close error: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
// Create SFTP client
|
||||
sftpClient, err := sftp.NewClient(conn)
|
||||
require.NoError(t, err, "SFTP client creation should succeed")
|
||||
defer func() {
|
||||
if err := sftpClient.Close(); err != nil {
|
||||
t.Logf("SFTP client close error: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
// Test basic SFTP operations
|
||||
workingDir, err := sftpClient.Getwd()
|
||||
assert.NoError(t, err, "Should be able to get working directory")
|
||||
assert.NotEmpty(t, workingDir, "Working directory should not be empty")
|
||||
|
||||
// Test directory listing
|
||||
files, err := sftpClient.ReadDir(".")
|
||||
assert.NoError(t, err, "Should be able to list current directory")
|
||||
assert.NotNil(t, files, "File list should not be nil")
|
||||
}
|
||||
|
||||
func TestSSHServer_SFTPDisabled(t *testing.T) {
|
||||
// Get current user for SSH connection
|
||||
currentUser, err := user.Current()
|
||||
require.NoError(t, err, "Should be able to get current user")
|
||||
|
||||
// Generate host key for server
|
||||
hostKey, err := ssh.GeneratePrivateKey(ssh.ED25519)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Generate client key pair
|
||||
clientPrivKey, err := ssh.GeneratePrivateKey(ssh.ED25519)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create server with SFTP disabled
|
||||
serverConfig := &Config{
|
||||
HostKeyPEM: hostKey,
|
||||
JWT: nil,
|
||||
}
|
||||
server := New(serverConfig)
|
||||
server.SetAllowSFTP(false)
|
||||
|
||||
// Start server
|
||||
serverAddr := "127.0.0.1:0"
|
||||
started := make(chan string, 1)
|
||||
errChan := make(chan error, 1)
|
||||
|
||||
go func() {
|
||||
ln, err := net.Listen("tcp", serverAddr)
|
||||
if err != nil {
|
||||
errChan <- err
|
||||
return
|
||||
}
|
||||
actualAddr := ln.Addr().String()
|
||||
if err := ln.Close(); err != nil {
|
||||
errChan <- fmt.Errorf("close temp listener: %w", err)
|
||||
return
|
||||
}
|
||||
|
||||
addrPort, _ := netip.ParseAddrPort(actualAddr)
|
||||
if err := server.Start(context.Background(), addrPort); err != nil {
|
||||
errChan <- err
|
||||
return
|
||||
}
|
||||
started <- actualAddr
|
||||
}()
|
||||
|
||||
select {
|
||||
case actualAddr := <-started:
|
||||
serverAddr = actualAddr
|
||||
case err := <-errChan:
|
||||
t.Fatalf("Server failed to start: %v", err)
|
||||
case <-time.After(5 * time.Second):
|
||||
t.Fatal("Server start timeout")
|
||||
}
|
||||
|
||||
defer func() {
|
||||
err := server.Stop()
|
||||
require.NoError(t, err)
|
||||
}()
|
||||
|
||||
// Parse client private key
|
||||
signer, err := cryptossh.ParsePrivateKey(clientPrivKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Parse server host key
|
||||
hostPrivParsed, err := cryptossh.ParsePrivateKey(hostKey)
|
||||
require.NoError(t, err)
|
||||
hostPubKey := hostPrivParsed.PublicKey()
|
||||
|
||||
// (currentUser already obtained at function start)
|
||||
|
||||
// Create SSH client connection
|
||||
clientConfig := &cryptossh.ClientConfig{
|
||||
User: currentUser.Username,
|
||||
Auth: []cryptossh.AuthMethod{
|
||||
cryptossh.PublicKeys(signer),
|
||||
},
|
||||
HostKeyCallback: cryptossh.FixedHostKey(hostPubKey),
|
||||
Timeout: 5 * time.Second,
|
||||
}
|
||||
|
||||
conn, err := cryptossh.Dial("tcp", serverAddr, clientConfig)
|
||||
require.NoError(t, err, "SSH connection should succeed")
|
||||
defer func() {
|
||||
if err := conn.Close(); err != nil {
|
||||
t.Logf("connection close error: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
// Try to create SFTP client - should fail when SFTP is disabled
|
||||
_, err = sftp.NewClient(conn)
|
||||
assert.Error(t, err, "SFTP client creation should fail when SFTP is disabled")
|
||||
}
|
||||
71
client/ssh/server/sftp_unix.go
Normal file
71
client/ssh/server/sftp_unix.go
Normal file
@@ -0,0 +1,71 @@
|
||||
//go:build !windows
|
||||
|
||||
package server
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"os/user"
|
||||
"strconv"
|
||||
|
||||
"github.com/gliderlabs/ssh"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// executeSftpWithPrivilegeDrop executes SFTP using Unix privilege dropping
|
||||
func (s *Server) executeSftpWithPrivilegeDrop(sess ssh.Session, targetUser *user.User) error {
|
||||
uid, gid, groups, err := s.parseUserCredentials(targetUser)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parse user credentials: %w", err)
|
||||
}
|
||||
|
||||
sftpCmd, err := s.createSftpExecutorCommand(sess, uid, gid, groups, targetUser.HomeDir)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create executor: %w", err)
|
||||
}
|
||||
|
||||
sftpCmd.Stdin = sess
|
||||
sftpCmd.Stdout = sess
|
||||
sftpCmd.Stderr = sess.Stderr()
|
||||
|
||||
log.Tracef("starting SFTP with privilege dropping to user %s (UID=%d, GID=%d)", targetUser.Username, uid, gid)
|
||||
|
||||
if err := sftpCmd.Start(); err != nil {
|
||||
return fmt.Errorf("starting SFTP executor: %w", err)
|
||||
}
|
||||
|
||||
if err := sftpCmd.Wait(); err != nil {
|
||||
var exitError *exec.ExitError
|
||||
if errors.As(err, &exitError) {
|
||||
log.Tracef("SFTP process exited with code %d", exitError.ExitCode())
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("exec: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// createSftpExecutorCommand creates a command that spawns netbird ssh sftp for privilege dropping
|
||||
func (s *Server) createSftpExecutorCommand(sess ssh.Session, uid, gid uint32, groups []uint32, workingDir string) (*exec.Cmd, error) {
|
||||
netbirdPath, err := os.Executable()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
args := []string{
|
||||
"ssh", "sftp",
|
||||
"--uid", strconv.FormatUint(uint64(uid), 10),
|
||||
"--gid", strconv.FormatUint(uint64(gid), 10),
|
||||
"--working-dir", workingDir,
|
||||
}
|
||||
|
||||
for _, group := range groups {
|
||||
args = append(args, "--groups", strconv.FormatUint(uint64(group), 10))
|
||||
}
|
||||
|
||||
log.Tracef("creating SFTP executor command: %s %v", netbirdPath, args)
|
||||
return exec.CommandContext(sess.Context(), netbirdPath, args...), nil
|
||||
}
|
||||
91
client/ssh/server/sftp_windows.go
Normal file
91
client/ssh/server/sftp_windows.go
Normal file
@@ -0,0 +1,91 @@
|
||||
//go:build windows
|
||||
|
||||
package server
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"os/user"
|
||||
|
||||
"github.com/gliderlabs/ssh"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/sys/windows"
|
||||
)
|
||||
|
||||
// createSftpCommand creates a Windows SFTP command with user switching.
|
||||
// The caller must close the returned token handle after starting the process.
|
||||
func (s *Server) createSftpCommand(targetUser *user.User, sess ssh.Session) (*exec.Cmd, windows.Token, error) {
|
||||
username, domain := s.parseUsername(targetUser.Username)
|
||||
|
||||
netbirdPath, err := os.Executable()
|
||||
if err != nil {
|
||||
return nil, 0, fmt.Errorf("get netbird executable path: %w", err)
|
||||
}
|
||||
|
||||
args := []string{
|
||||
"ssh", "sftp",
|
||||
"--working-dir", targetUser.HomeDir,
|
||||
"--windows-username", username,
|
||||
"--windows-domain", domain,
|
||||
}
|
||||
|
||||
pd := NewPrivilegeDropper()
|
||||
token, err := pd.createToken(username, domain)
|
||||
if err != nil {
|
||||
return nil, 0, fmt.Errorf("create token: %w", err)
|
||||
}
|
||||
|
||||
defer func() {
|
||||
if err := windows.CloseHandle(token); err != nil {
|
||||
log.Warnf("failed to close impersonation token: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
cmd, primaryToken, err := pd.createProcessWithToken(sess.Context(), windows.Token(token), netbirdPath, append([]string{netbirdPath}, args...), targetUser.HomeDir)
|
||||
if err != nil {
|
||||
return nil, 0, fmt.Errorf("create SFTP command: %w", err)
|
||||
}
|
||||
|
||||
log.Debugf("Created Windows SFTP command with user switching for %s", targetUser.Username)
|
||||
return cmd, primaryToken, nil
|
||||
}
|
||||
|
||||
// executeSftpCommand executes a Windows SFTP command with proper I/O handling
|
||||
func (s *Server) executeSftpCommand(sess ssh.Session, sftpCmd *exec.Cmd, token windows.Token) error {
|
||||
defer func() {
|
||||
if err := windows.CloseHandle(windows.Handle(token)); err != nil {
|
||||
log.Debugf("close primary token: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
sftpCmd.Stdin = sess
|
||||
sftpCmd.Stdout = sess
|
||||
sftpCmd.Stderr = sess.Stderr()
|
||||
|
||||
if err := sftpCmd.Start(); err != nil {
|
||||
return fmt.Errorf("starting sftp executor: %w", err)
|
||||
}
|
||||
|
||||
if err := sftpCmd.Wait(); err != nil {
|
||||
var exitError *exec.ExitError
|
||||
if errors.As(err, &exitError) {
|
||||
log.Tracef("sftp process exited with code %d", exitError.ExitCode())
|
||||
return nil
|
||||
}
|
||||
|
||||
return fmt.Errorf("exec sftp: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// executeSftpWithPrivilegeDrop executes SFTP using Windows privilege dropping
|
||||
func (s *Server) executeSftpWithPrivilegeDrop(sess ssh.Session, targetUser *user.User) error {
|
||||
sftpCmd, token, err := s.createSftpCommand(targetUser, sess)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create sftp: %w", err)
|
||||
}
|
||||
return s.executeSftpCommand(sess, sftpCmd, token)
|
||||
}
|
||||
180
client/ssh/server/shell.go
Normal file
180
client/ssh/server/shell.go
Normal file
@@ -0,0 +1,180 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"os/exec"
|
||||
"os/user"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/gliderlabs/ssh"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultUnixShell = "/bin/sh"
|
||||
|
||||
pwshExe = "pwsh.exe" // #nosec G101 - This is not a credential, just executable name
|
||||
powershellExe = "powershell.exe"
|
||||
)
|
||||
|
||||
// getUserShell returns the appropriate shell for the given user ID
|
||||
// Handles all platform-specific logic and fallbacks consistently
|
||||
func getUserShell(userID string) string {
|
||||
switch runtime.GOOS {
|
||||
case "windows":
|
||||
return getWindowsUserShell()
|
||||
default:
|
||||
return getUnixUserShell(userID)
|
||||
}
|
||||
}
|
||||
|
||||
// getWindowsUserShell returns the best shell for Windows users.
|
||||
// We intentionally do not support cmd.exe or COMSPEC fallbacks to avoid command injection
|
||||
// vulnerabilities that arise from cmd.exe's complex command line parsing and special characters.
|
||||
// PowerShell provides safer argument handling and is available on all modern Windows systems.
|
||||
// Order: pwsh.exe -> powershell.exe
|
||||
func getWindowsUserShell() string {
|
||||
if path, err := exec.LookPath(pwshExe); err == nil {
|
||||
return path
|
||||
}
|
||||
if path, err := exec.LookPath(powershellExe); err == nil {
|
||||
return path
|
||||
}
|
||||
|
||||
return `C:\Windows\System32\WindowsPowerShell\v1.0\powershell.exe`
|
||||
}
|
||||
|
||||
// getUnixUserShell returns the shell for Unix-like systems
|
||||
func getUnixUserShell(userID string) string {
|
||||
shell := getShellFromPasswd(userID)
|
||||
if shell != "" {
|
||||
return shell
|
||||
}
|
||||
|
||||
if shell := os.Getenv("SHELL"); shell != "" {
|
||||
return shell
|
||||
}
|
||||
|
||||
return defaultUnixShell
|
||||
}
|
||||
|
||||
// getShellFromPasswd reads the shell from /etc/passwd for the given user ID
|
||||
func getShellFromPasswd(userID string) string {
|
||||
file, err := os.Open("/etc/passwd")
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
defer func() {
|
||||
if err := file.Close(); err != nil {
|
||||
log.Warnf("close /etc/passwd file: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
scanner := bufio.NewScanner(file)
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
fields := strings.Split(line, ":")
|
||||
if len(fields) < 7 {
|
||||
continue
|
||||
}
|
||||
|
||||
// field 2 is UID
|
||||
if fields[2] == userID {
|
||||
shell := strings.TrimSpace(fields[6])
|
||||
return shell
|
||||
}
|
||||
}
|
||||
|
||||
if err := scanner.Err(); err != nil {
|
||||
log.Warnf("error reading /etc/passwd: %v", err)
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
// prepareUserEnv prepares environment variables for user execution
|
||||
func prepareUserEnv(user *user.User, shell string) []string {
|
||||
pathValue := "/usr/local/bin:/usr/bin:/bin:/usr/local/games:/usr/games"
|
||||
if runtime.GOOS == "windows" {
|
||||
pathValue = `C:\Windows\System32;C:\Windows;C:\Windows\System32\Wbem;C:\Windows\System32\WindowsPowerShell\v1.0`
|
||||
}
|
||||
|
||||
return []string{
|
||||
fmt.Sprint("SHELL=" + shell),
|
||||
fmt.Sprint("USER=" + user.Username),
|
||||
fmt.Sprint("LOGNAME=" + user.Username),
|
||||
fmt.Sprint("HOME=" + user.HomeDir),
|
||||
"PATH=" + pathValue,
|
||||
}
|
||||
}
|
||||
|
||||
// acceptEnv checks if environment variable from SSH client should be accepted
|
||||
// This is a whitelist of variables that SSH clients can send to the server
|
||||
func acceptEnv(envVar string) bool {
|
||||
varName := envVar
|
||||
if idx := strings.Index(envVar, "="); idx != -1 {
|
||||
varName = envVar[:idx]
|
||||
}
|
||||
|
||||
exactMatches := []string{
|
||||
"LANG",
|
||||
"LANGUAGE",
|
||||
"TERM",
|
||||
"COLORTERM",
|
||||
"EDITOR",
|
||||
"VISUAL",
|
||||
"PAGER",
|
||||
"LESS",
|
||||
"LESSCHARSET",
|
||||
"TZ",
|
||||
}
|
||||
|
||||
prefixMatches := []string{
|
||||
"LC_",
|
||||
}
|
||||
|
||||
for _, exact := range exactMatches {
|
||||
if varName == exact {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
for _, prefix := range prefixMatches {
|
||||
if strings.HasPrefix(varName, prefix) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// prepareSSHEnv prepares SSH protocol-specific environment variables
|
||||
// These variables provide information about the SSH connection itself
|
||||
func prepareSSHEnv(session ssh.Session) []string {
|
||||
remoteAddr := session.RemoteAddr()
|
||||
localAddr := session.LocalAddr()
|
||||
|
||||
remoteHost, remotePort, err := net.SplitHostPort(remoteAddr.String())
|
||||
if err != nil {
|
||||
remoteHost = remoteAddr.String()
|
||||
remotePort = "0"
|
||||
}
|
||||
|
||||
localHost, localPort, err := net.SplitHostPort(localAddr.String())
|
||||
if err != nil {
|
||||
localHost = localAddr.String()
|
||||
localPort = strconv.Itoa(InternalSSHPort)
|
||||
}
|
||||
|
||||
return []string{
|
||||
// SSH_CLIENT format: "client_ip client_port server_port"
|
||||
fmt.Sprintf("SSH_CLIENT=%s %s %s", remoteHost, remotePort, localPort),
|
||||
// SSH_CONNECTION format: "client_ip client_port server_ip server_port"
|
||||
fmt.Sprintf("SSH_CONNECTION=%s %s %s %s", remoteHost, remotePort, localHost, localPort),
|
||||
}
|
||||
}
|
||||
45
client/ssh/server/test.go
Normal file
45
client/ssh/server/test.go
Normal file
@@ -0,0 +1,45 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func StartTestServer(t *testing.T, server *Server) string {
|
||||
started := make(chan string, 1)
|
||||
errChan := make(chan error, 1)
|
||||
|
||||
go func() {
|
||||
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
errChan <- err
|
||||
return
|
||||
}
|
||||
actualAddr := ln.Addr().String()
|
||||
if err := ln.Close(); err != nil {
|
||||
errChan <- fmt.Errorf("close temp listener: %w", err)
|
||||
return
|
||||
}
|
||||
|
||||
addrPort := netip.MustParseAddrPort(actualAddr)
|
||||
if err := server.Start(context.Background(), addrPort); err != nil {
|
||||
errChan <- err
|
||||
return
|
||||
}
|
||||
started <- actualAddr
|
||||
}()
|
||||
|
||||
select {
|
||||
case actualAddr := <-started:
|
||||
return actualAddr
|
||||
case err := <-errChan:
|
||||
t.Fatalf("Server failed to start: %v", err)
|
||||
case <-time.After(5 * time.Second):
|
||||
t.Fatal("Server start timeout")
|
||||
}
|
||||
return ""
|
||||
}
|
||||
411
client/ssh/server/user_utils.go
Normal file
411
client/ssh/server/user_utils.go
Normal file
@@ -0,0 +1,411 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/user"
|
||||
"runtime"
|
||||
"strings"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrPrivilegeRequired = errors.New("SeAssignPrimaryTokenPrivilege required for user switching - NetBird must run with elevated privileges")
|
||||
ErrPrivilegedUserSwitch = errors.New("cannot switch to privileged user - current user lacks required privileges")
|
||||
)
|
||||
|
||||
// isPlatformUnix returns true for Unix-like platforms (Linux, macOS, etc.)
|
||||
func isPlatformUnix() bool {
|
||||
return getCurrentOS() != "windows"
|
||||
}
|
||||
|
||||
// Dependency injection variables for testing - allows mocking dynamic runtime checks
|
||||
var (
|
||||
getCurrentUser = user.Current
|
||||
lookupUser = user.Lookup
|
||||
getCurrentOS = func() string { return runtime.GOOS }
|
||||
getIsProcessPrivileged = isCurrentProcessPrivileged
|
||||
|
||||
getEuid = os.Geteuid
|
||||
)
|
||||
|
||||
const (
|
||||
// FeatureSSHLogin represents SSH login operations for privilege checking
|
||||
FeatureSSHLogin = "SSH login"
|
||||
// FeatureSFTP represents SFTP operations for privilege checking
|
||||
FeatureSFTP = "SFTP"
|
||||
)
|
||||
|
||||
// PrivilegeCheckRequest represents a privilege check request
|
||||
type PrivilegeCheckRequest struct {
|
||||
// Username being requested (empty = current user)
|
||||
RequestedUsername string
|
||||
FeatureSupportsUserSwitch bool // Does this feature/operation support user switching?
|
||||
FeatureName string
|
||||
}
|
||||
|
||||
// PrivilegeCheckResult represents the result of a privilege check
|
||||
type PrivilegeCheckResult struct {
|
||||
// Allowed indicates whether the privilege check passed
|
||||
Allowed bool
|
||||
// User is the effective user to use for the operation (nil if not allowed)
|
||||
User *user.User
|
||||
// Error contains the reason for denial (nil if allowed)
|
||||
Error error
|
||||
// UsedFallback indicates we fell back to current user instead of requested user.
|
||||
// This happens on Unix when running as an unprivileged user (e.g., in containers)
|
||||
// where there's no point in user switching since we lack privileges anyway.
|
||||
// When true, all privilege checks have already been performed and no additional
|
||||
// privilege dropping or root checks are needed - the current user is the target.
|
||||
UsedFallback bool
|
||||
// RequiresUserSwitching indicates whether user switching will actually occur
|
||||
// (false for fallback cases where no actual switching happens)
|
||||
RequiresUserSwitching bool
|
||||
}
|
||||
|
||||
// CheckPrivileges performs comprehensive privilege checking for all SSH features.
|
||||
// This is the single source of truth for privilege decisions across the SSH server.
|
||||
func (s *Server) CheckPrivileges(req PrivilegeCheckRequest) PrivilegeCheckResult {
|
||||
context, err := s.buildPrivilegeCheckContext(req.FeatureName)
|
||||
if err != nil {
|
||||
return PrivilegeCheckResult{Allowed: false, Error: err}
|
||||
}
|
||||
|
||||
// Handle empty username case - but still check root access controls
|
||||
if req.RequestedUsername == "" {
|
||||
if isPrivilegedUsername(context.currentUser.Username) && !context.allowRoot {
|
||||
return PrivilegeCheckResult{
|
||||
Allowed: false,
|
||||
Error: &PrivilegedUserError{Username: context.currentUser.Username},
|
||||
}
|
||||
}
|
||||
return PrivilegeCheckResult{
|
||||
Allowed: true,
|
||||
User: context.currentUser,
|
||||
RequiresUserSwitching: false,
|
||||
}
|
||||
}
|
||||
|
||||
return s.checkUserRequest(context, req)
|
||||
}
|
||||
|
||||
// buildPrivilegeCheckContext gathers all the context needed for privilege checking
|
||||
func (s *Server) buildPrivilegeCheckContext(featureName string) (*privilegeCheckContext, error) {
|
||||
currentUser, err := getCurrentUser()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get current user for %s: %w", featureName, err)
|
||||
}
|
||||
|
||||
s.mu.RLock()
|
||||
allowRoot := s.allowRootLogin
|
||||
s.mu.RUnlock()
|
||||
|
||||
return &privilegeCheckContext{
|
||||
currentUser: currentUser,
|
||||
currentUserPrivileged: getIsProcessPrivileged(),
|
||||
allowRoot: allowRoot,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// checkUserRequest handles normal privilege checking flow for specific usernames
|
||||
func (s *Server) checkUserRequest(ctx *privilegeCheckContext, req PrivilegeCheckRequest) PrivilegeCheckResult {
|
||||
if !ctx.currentUserPrivileged && isPlatformUnix() {
|
||||
log.Debugf("Unix non-privileged shortcut: falling back to current user %s for %s (requested: %s)",
|
||||
ctx.currentUser.Username, req.FeatureName, req.RequestedUsername)
|
||||
return PrivilegeCheckResult{
|
||||
Allowed: true,
|
||||
User: ctx.currentUser,
|
||||
UsedFallback: true,
|
||||
RequiresUserSwitching: false,
|
||||
}
|
||||
}
|
||||
|
||||
resolvedUser, err := s.resolveRequestedUser(req.RequestedUsername)
|
||||
if err != nil {
|
||||
// Calculate if user switching would be required even if lookup failed
|
||||
needsUserSwitching := !isSameUser(req.RequestedUsername, ctx.currentUser.Username)
|
||||
return PrivilegeCheckResult{
|
||||
Allowed: false,
|
||||
Error: err,
|
||||
RequiresUserSwitching: needsUserSwitching,
|
||||
}
|
||||
}
|
||||
|
||||
needsUserSwitching := !isSameResolvedUser(resolvedUser, ctx.currentUser)
|
||||
|
||||
if isPrivilegedUsername(resolvedUser.Username) && !ctx.allowRoot {
|
||||
return PrivilegeCheckResult{
|
||||
Allowed: false,
|
||||
Error: &PrivilegedUserError{Username: resolvedUser.Username},
|
||||
RequiresUserSwitching: needsUserSwitching,
|
||||
}
|
||||
}
|
||||
|
||||
if needsUserSwitching && !req.FeatureSupportsUserSwitch {
|
||||
return PrivilegeCheckResult{
|
||||
Allowed: false,
|
||||
Error: fmt.Errorf("%s: user switching not supported by this feature", req.FeatureName),
|
||||
RequiresUserSwitching: needsUserSwitching,
|
||||
}
|
||||
}
|
||||
|
||||
return PrivilegeCheckResult{
|
||||
Allowed: true,
|
||||
User: resolvedUser,
|
||||
RequiresUserSwitching: needsUserSwitching,
|
||||
}
|
||||
}
|
||||
|
||||
// resolveRequestedUser resolves a username to its canonical user identity
|
||||
func (s *Server) resolveRequestedUser(requestedUsername string) (*user.User, error) {
|
||||
if requestedUsername == "" {
|
||||
return getCurrentUser()
|
||||
}
|
||||
|
||||
if err := validateUsername(requestedUsername); err != nil {
|
||||
return nil, fmt.Errorf("invalid username %q: %w", requestedUsername, err)
|
||||
}
|
||||
|
||||
u, err := lookupUser(requestedUsername)
|
||||
if err != nil {
|
||||
return nil, &UserNotFoundError{Username: requestedUsername, Cause: err}
|
||||
}
|
||||
return u, nil
|
||||
}
|
||||
|
||||
// isSameResolvedUser compares two resolved user identities
|
||||
func isSameResolvedUser(user1, user2 *user.User) bool {
|
||||
if user1 == nil || user2 == nil {
|
||||
return user1 == user2
|
||||
}
|
||||
return user1.Uid == user2.Uid
|
||||
}
|
||||
|
||||
// privilegeCheckContext holds all context needed for privilege checking
|
||||
type privilegeCheckContext struct {
|
||||
currentUser *user.User
|
||||
currentUserPrivileged bool
|
||||
allowRoot bool
|
||||
}
|
||||
|
||||
// isSameUser checks if two usernames refer to the same user
|
||||
// SECURITY: This function must be conservative - it should only return true
|
||||
// when we're certain both usernames refer to the exact same user identity
|
||||
func isSameUser(requestedUsername, currentUsername string) bool {
|
||||
// Empty requested username means current user
|
||||
if requestedUsername == "" {
|
||||
return true
|
||||
}
|
||||
|
||||
// Exact match (most common case)
|
||||
if getCurrentOS() == "windows" {
|
||||
if strings.EqualFold(requestedUsername, currentUsername) {
|
||||
return true
|
||||
}
|
||||
} else {
|
||||
if requestedUsername == currentUsername {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// Windows domain resolution: only allow domain stripping when comparing
|
||||
// a bare username against the current user's domain-qualified name
|
||||
if getCurrentOS() == "windows" {
|
||||
return isWindowsSameUser(requestedUsername, currentUsername)
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// isWindowsSameUser handles Windows-specific user comparison with domain logic
|
||||
func isWindowsSameUser(requestedUsername, currentUsername string) bool {
|
||||
// Extract domain and username parts
|
||||
extractParts := func(name string) (domain, user string) {
|
||||
// Handle DOMAIN\username format
|
||||
if idx := strings.LastIndex(name, `\`); idx != -1 {
|
||||
return name[:idx], name[idx+1:]
|
||||
}
|
||||
// Handle user@domain.com format
|
||||
if idx := strings.Index(name, "@"); idx != -1 {
|
||||
return name[idx+1:], name[:idx]
|
||||
}
|
||||
// No domain specified - local machine
|
||||
return "", name
|
||||
}
|
||||
|
||||
reqDomain, reqUser := extractParts(requestedUsername)
|
||||
curDomain, curUser := extractParts(currentUsername)
|
||||
|
||||
// Case-insensitive username comparison
|
||||
if !strings.EqualFold(reqUser, curUser) {
|
||||
return false
|
||||
}
|
||||
|
||||
// If requested username has no domain, it refers to local machine user
|
||||
// Allow this to match the current user regardless of current user's domain
|
||||
if reqDomain == "" {
|
||||
return true
|
||||
}
|
||||
|
||||
// If both have domains, they must match exactly (case-insensitive)
|
||||
return strings.EqualFold(reqDomain, curDomain)
|
||||
}
|
||||
|
||||
// SetAllowRootLogin configures root login access
|
||||
func (s *Server) SetAllowRootLogin(allow bool) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.allowRootLogin = allow
|
||||
}
|
||||
|
||||
// userNameLookup performs user lookup with root login permission check
|
||||
func (s *Server) userNameLookup(username string) (*user.User, error) {
|
||||
result := s.CheckPrivileges(PrivilegeCheckRequest{
|
||||
RequestedUsername: username,
|
||||
FeatureSupportsUserSwitch: true,
|
||||
FeatureName: FeatureSSHLogin,
|
||||
})
|
||||
|
||||
if !result.Allowed {
|
||||
return nil, result.Error
|
||||
}
|
||||
|
||||
return result.User, nil
|
||||
}
|
||||
|
||||
// userPrivilegeCheck performs user lookup with full privilege check result
|
||||
func (s *Server) userPrivilegeCheck(username string) (PrivilegeCheckResult, error) {
|
||||
result := s.CheckPrivileges(PrivilegeCheckRequest{
|
||||
RequestedUsername: username,
|
||||
FeatureSupportsUserSwitch: true,
|
||||
FeatureName: FeatureSSHLogin,
|
||||
})
|
||||
|
||||
if !result.Allowed {
|
||||
return result, result.Error
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// isPrivilegedUsername checks if the given username represents a privileged user across platforms.
|
||||
// On Unix: root
|
||||
// On Windows: Administrator, SYSTEM (case-insensitive)
|
||||
// Handles domain-qualified usernames like "DOMAIN\Administrator" or "user@domain.com"
|
||||
func isPrivilegedUsername(username string) bool {
|
||||
if getCurrentOS() != "windows" {
|
||||
return username == "root"
|
||||
}
|
||||
|
||||
bareUsername := username
|
||||
// Handle Windows domain format: DOMAIN\username
|
||||
if idx := strings.LastIndex(username, `\`); idx != -1 {
|
||||
bareUsername = username[idx+1:]
|
||||
}
|
||||
// Handle email-style format: username@domain.com
|
||||
if idx := strings.Index(bareUsername, "@"); idx != -1 {
|
||||
bareUsername = bareUsername[:idx]
|
||||
}
|
||||
|
||||
return isWindowsPrivilegedUser(bareUsername)
|
||||
}
|
||||
|
||||
// isWindowsPrivilegedUser checks if a bare username (domain already stripped) represents a Windows privileged account
|
||||
func isWindowsPrivilegedUser(bareUsername string) bool {
|
||||
// common privileged usernames (case insensitive)
|
||||
privilegedNames := []string{
|
||||
"administrator",
|
||||
"admin",
|
||||
"root",
|
||||
"system",
|
||||
"localsystem",
|
||||
"networkservice",
|
||||
"localservice",
|
||||
}
|
||||
|
||||
usernameLower := strings.ToLower(bareUsername)
|
||||
for _, privilegedName := range privilegedNames {
|
||||
if usernameLower == privilegedName {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// computer accounts (ending with $) are not privileged by themselves
|
||||
// They only gain privileges through group membership or specific SIDs
|
||||
|
||||
if targetUser, err := lookupUser(bareUsername); err == nil {
|
||||
return isWindowsPrivilegedSID(targetUser.Uid)
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// isWindowsPrivilegedSID checks if a Windows SID represents a privileged account
|
||||
func isWindowsPrivilegedSID(sid string) bool {
|
||||
privilegedSIDs := []string{
|
||||
"S-1-5-18", // Local System (SYSTEM)
|
||||
"S-1-5-19", // Local Service (NT AUTHORITY\LOCAL SERVICE)
|
||||
"S-1-5-20", // Network Service (NT AUTHORITY\NETWORK SERVICE)
|
||||
"S-1-5-32-544", // Administrators group (BUILTIN\Administrators)
|
||||
"S-1-5-500", // Built-in Administrator account (local machine RID 500)
|
||||
}
|
||||
|
||||
for _, privilegedSID := range privilegedSIDs {
|
||||
if sid == privilegedSID {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// Check for domain administrator accounts (RID 500 in any domain)
|
||||
// Format: S-1-5-21-domain-domain-domain-500
|
||||
// This is reliable as RID 500 is reserved for the domain Administrator account
|
||||
if strings.HasPrefix(sid, "S-1-5-21-") && strings.HasSuffix(sid, "-500") {
|
||||
return true
|
||||
}
|
||||
|
||||
// Check for other well-known privileged RIDs in domain contexts
|
||||
// RID 512 = Domain Admins group, RID 516 = Domain Controllers group
|
||||
if strings.HasPrefix(sid, "S-1-5-21-") {
|
||||
if strings.HasSuffix(sid, "-512") || // Domain Admins group
|
||||
strings.HasSuffix(sid, "-516") || // Domain Controllers group
|
||||
strings.HasSuffix(sid, "-519") { // Enterprise Admins group
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// isCurrentProcessPrivileged checks if the current process is running with elevated privileges.
|
||||
// On Unix systems, this means running as root (UID 0).
|
||||
// On Windows, this means running as Administrator or SYSTEM.
|
||||
func isCurrentProcessPrivileged() bool {
|
||||
if getCurrentOS() == "windows" {
|
||||
return isWindowsElevated()
|
||||
}
|
||||
return getEuid() == 0
|
||||
}
|
||||
|
||||
// isWindowsElevated checks if the current process is running with elevated privileges on Windows
|
||||
func isWindowsElevated() bool {
|
||||
currentUser, err := getCurrentUser()
|
||||
if err != nil {
|
||||
log.Errorf("failed to get current user for privilege check, assuming non-privileged: %v", err)
|
||||
return false
|
||||
}
|
||||
|
||||
if isWindowsPrivilegedSID(currentUser.Uid) {
|
||||
log.Debugf("Windows user switching supported: running as privileged SID %s", currentUser.Uid)
|
||||
return true
|
||||
}
|
||||
|
||||
if isPrivilegedUsername(currentUser.Username) {
|
||||
log.Debugf("Windows user switching supported: running as privileged username %s", currentUser.Username)
|
||||
return true
|
||||
}
|
||||
|
||||
log.Debugf("Windows user switching not supported: not running as privileged user (current: %s)", currentUser.Uid)
|
||||
return false
|
||||
}
|
||||
8
client/ssh/server/user_utils_js.go
Normal file
8
client/ssh/server/user_utils_js.go
Normal file
@@ -0,0 +1,8 @@
|
||||
//go:build js
|
||||
|
||||
package server
|
||||
|
||||
// validateUsername is not supported on JS/WASM
|
||||
func validateUsername(_ string) error {
|
||||
return errNotSupported
|
||||
}
|
||||
908
client/ssh/server/user_utils_test.go
Normal file
908
client/ssh/server/user_utils_test.go
Normal file
@@ -0,0 +1,908 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"os/user"
|
||||
"runtime"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
// Test helper functions
|
||||
func createTestUser(username, uid, gid, homeDir string) *user.User {
|
||||
return &user.User{
|
||||
Uid: uid,
|
||||
Gid: gid,
|
||||
Username: username,
|
||||
Name: username,
|
||||
HomeDir: homeDir,
|
||||
}
|
||||
}
|
||||
|
||||
// Test dependency injection setup - injects platform dependencies to test real logic
|
||||
func setupTestDependencies(currentUser *user.User, currentUserErr error, os string, euid int, lookupUsers map[string]*user.User, lookupErrors map[string]error) func() {
|
||||
// Store originals
|
||||
originalGetCurrentUser := getCurrentUser
|
||||
originalLookupUser := lookupUser
|
||||
originalGetCurrentOS := getCurrentOS
|
||||
originalGetEuid := getEuid
|
||||
|
||||
// Reset caches to ensure clean test state
|
||||
|
||||
// Set test values - inject platform dependencies
|
||||
getCurrentUser = func() (*user.User, error) {
|
||||
return currentUser, currentUserErr
|
||||
}
|
||||
|
||||
lookupUser = func(username string) (*user.User, error) {
|
||||
if err, exists := lookupErrors[username]; exists {
|
||||
return nil, err
|
||||
}
|
||||
if userObj, exists := lookupUsers[username]; exists {
|
||||
return userObj, nil
|
||||
}
|
||||
return nil, errors.New("user: unknown user " + username)
|
||||
}
|
||||
|
||||
getCurrentOS = func() string {
|
||||
return os
|
||||
}
|
||||
|
||||
getEuid = func() int {
|
||||
return euid
|
||||
}
|
||||
|
||||
// Mock privilege detection based on the test user
|
||||
getIsProcessPrivileged = func() bool {
|
||||
if currentUser == nil {
|
||||
return false
|
||||
}
|
||||
// Check both username and SID for Windows systems
|
||||
if os == "windows" && isWindowsPrivilegedSID(currentUser.Uid) {
|
||||
return true
|
||||
}
|
||||
return isPrivilegedUsername(currentUser.Username)
|
||||
}
|
||||
|
||||
// Return cleanup function
|
||||
return func() {
|
||||
getCurrentUser = originalGetCurrentUser
|
||||
lookupUser = originalLookupUser
|
||||
getCurrentOS = originalGetCurrentOS
|
||||
getEuid = originalGetEuid
|
||||
|
||||
getIsProcessPrivileged = isCurrentProcessPrivileged
|
||||
|
||||
// Reset caches after test
|
||||
}
|
||||
}
|
||||
|
||||
func TestCheckPrivileges_ComprehensiveMatrix(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
os string
|
||||
euid int
|
||||
currentUser *user.User
|
||||
requestedUsername string
|
||||
featureSupportsUserSwitch bool
|
||||
allowRoot bool
|
||||
lookupUsers map[string]*user.User
|
||||
expectedAllowed bool
|
||||
expectedRequiresSwitch bool
|
||||
}{
|
||||
{
|
||||
name: "linux_root_can_switch_to_alice",
|
||||
os: "linux",
|
||||
euid: 0, // Root process
|
||||
currentUser: createTestUser("root", "0", "0", "/root"),
|
||||
requestedUsername: "alice",
|
||||
featureSupportsUserSwitch: true,
|
||||
allowRoot: true,
|
||||
lookupUsers: map[string]*user.User{
|
||||
"alice": createTestUser("alice", "1000", "1000", "/home/alice"),
|
||||
},
|
||||
expectedAllowed: true,
|
||||
expectedRequiresSwitch: true,
|
||||
},
|
||||
{
|
||||
name: "linux_non_root_fallback_to_current_user",
|
||||
os: "linux",
|
||||
euid: 1000, // Non-root process
|
||||
currentUser: createTestUser("alice", "1000", "1000", "/home/alice"),
|
||||
requestedUsername: "bob",
|
||||
featureSupportsUserSwitch: true,
|
||||
allowRoot: true,
|
||||
expectedAllowed: true, // Should fallback to current user (alice)
|
||||
expectedRequiresSwitch: false, // Fallback means no actual switching
|
||||
},
|
||||
{
|
||||
name: "windows_admin_can_switch_to_alice",
|
||||
os: "windows",
|
||||
euid: 1000, // Irrelevant on Windows
|
||||
currentUser: createTestUser("Administrator", "S-1-5-21-123456789-123456789-123456789-500", "S-1-5-32-544", "C:\\Users\\Administrator"),
|
||||
requestedUsername: "alice",
|
||||
featureSupportsUserSwitch: true,
|
||||
allowRoot: true,
|
||||
lookupUsers: map[string]*user.User{
|
||||
"alice": createTestUser("alice", "S-1-5-21-123456789-123456789-123456789-1001", "S-1-5-21-123456789-123456789-123456789-513", "C:\\Users\\alice"),
|
||||
},
|
||||
expectedAllowed: true,
|
||||
expectedRequiresSwitch: true,
|
||||
},
|
||||
{
|
||||
name: "windows_non_admin_no_fallback_hard_failure",
|
||||
os: "windows",
|
||||
euid: 1000, // Irrelevant on Windows
|
||||
currentUser: createTestUser("alice", "1001", "1001", "C:\\Users\\alice"),
|
||||
requestedUsername: "bob",
|
||||
featureSupportsUserSwitch: true,
|
||||
allowRoot: true,
|
||||
lookupUsers: map[string]*user.User{
|
||||
"bob": createTestUser("bob", "S-1-5-21-123456789-123456789-123456789-1002", "S-1-5-21-123456789-123456789-123456789-513", "C:\\Users\\bob"),
|
||||
},
|
||||
expectedAllowed: true, // Let OS decide - deferred security check
|
||||
expectedRequiresSwitch: true, // Different user was requested
|
||||
},
|
||||
// Comprehensive test matrix: non-root linux with different allowRoot settings
|
||||
{
|
||||
name: "linux_non_root_request_root_allowRoot_false",
|
||||
os: "linux",
|
||||
euid: 1000,
|
||||
currentUser: createTestUser("alice", "1000", "1000", "/home/alice"),
|
||||
requestedUsername: "root",
|
||||
featureSupportsUserSwitch: true,
|
||||
allowRoot: false,
|
||||
expectedAllowed: true, // Fallback allows access regardless of root setting
|
||||
expectedRequiresSwitch: false, // Fallback case, no switching
|
||||
},
|
||||
{
|
||||
name: "linux_non_root_request_root_allowRoot_true",
|
||||
os: "linux",
|
||||
euid: 1000,
|
||||
currentUser: createTestUser("alice", "1000", "1000", "/home/alice"),
|
||||
requestedUsername: "root",
|
||||
featureSupportsUserSwitch: true,
|
||||
allowRoot: true,
|
||||
expectedAllowed: true, // Should fallback to alice (non-privileged process)
|
||||
expectedRequiresSwitch: false, // Fallback means no actual switching
|
||||
},
|
||||
// Windows admin test matrix
|
||||
{
|
||||
name: "windows_admin_request_root_allowRoot_false",
|
||||
os: "windows",
|
||||
euid: 1000,
|
||||
currentUser: createTestUser("Administrator", "S-1-5-21-123456789-123456789-123456789-500", "S-1-5-32-544", "C:\\Users\\Administrator"),
|
||||
requestedUsername: "root",
|
||||
featureSupportsUserSwitch: true,
|
||||
allowRoot: false,
|
||||
expectedAllowed: false, // Root not allowed
|
||||
expectedRequiresSwitch: true,
|
||||
},
|
||||
{
|
||||
name: "windows_admin_request_root_allowRoot_true",
|
||||
os: "windows",
|
||||
euid: 1000,
|
||||
currentUser: createTestUser("Administrator", "S-1-5-21-123456789-123456789-123456789-500", "S-1-5-32-544", "C:\\Users\\Administrator"),
|
||||
requestedUsername: "root",
|
||||
featureSupportsUserSwitch: true,
|
||||
allowRoot: true,
|
||||
lookupUsers: map[string]*user.User{
|
||||
"root": createTestUser("root", "0", "0", "/root"),
|
||||
},
|
||||
expectedAllowed: true, // Windows user switching should work like Unix
|
||||
expectedRequiresSwitch: true,
|
||||
},
|
||||
// Windows non-admin test matrix
|
||||
{
|
||||
name: "windows_non_admin_request_root_allowRoot_false",
|
||||
os: "windows",
|
||||
euid: 1000,
|
||||
currentUser: createTestUser("alice", "S-1-5-21-123456789-123456789-123456789-1001", "S-1-5-21-123456789-123456789-123456789-513", "C:\\Users\\alice"),
|
||||
requestedUsername: "root",
|
||||
featureSupportsUserSwitch: true,
|
||||
allowRoot: false,
|
||||
expectedAllowed: false, // Root not allowed (allowRoot=false takes precedence)
|
||||
expectedRequiresSwitch: true,
|
||||
},
|
||||
{
|
||||
name: "windows_system_account_allowRoot_false",
|
||||
os: "windows",
|
||||
euid: 1000,
|
||||
currentUser: createTestUser("NETBIRD\\WIN2K19-C2$", "S-1-5-18", "S-1-5-18", "C:\\Windows\\System32"),
|
||||
requestedUsername: "root",
|
||||
featureSupportsUserSwitch: true,
|
||||
allowRoot: false,
|
||||
expectedAllowed: false, // Root not allowed
|
||||
expectedRequiresSwitch: true,
|
||||
},
|
||||
{
|
||||
name: "windows_system_account_allowRoot_true",
|
||||
os: "windows",
|
||||
euid: 1000,
|
||||
currentUser: createTestUser("NETBIRD\\WIN2K19-C2$", "S-1-5-18", "S-1-5-18", "C:\\Windows\\System32"),
|
||||
requestedUsername: "root",
|
||||
featureSupportsUserSwitch: true,
|
||||
allowRoot: true,
|
||||
lookupUsers: map[string]*user.User{
|
||||
"root": createTestUser("root", "0", "0", "/root"),
|
||||
},
|
||||
expectedAllowed: true, // SYSTEM can switch to root
|
||||
expectedRequiresSwitch: true,
|
||||
},
|
||||
{
|
||||
name: "windows_non_admin_request_root_allowRoot_true",
|
||||
os: "windows",
|
||||
euid: 1000,
|
||||
currentUser: createTestUser("alice", "S-1-5-21-123456789-123456789-123456789-1001", "S-1-5-21-123456789-123456789-123456789-513", "C:\\Users\\alice"),
|
||||
requestedUsername: "root",
|
||||
featureSupportsUserSwitch: true,
|
||||
allowRoot: true,
|
||||
lookupUsers: map[string]*user.User{
|
||||
"root": createTestUser("root", "0", "0", "/root"),
|
||||
},
|
||||
expectedAllowed: true, // Let OS decide - deferred security check
|
||||
expectedRequiresSwitch: true,
|
||||
},
|
||||
|
||||
// Feature doesn't support user switching scenarios
|
||||
{
|
||||
name: "linux_root_feature_no_user_switching_same_user",
|
||||
os: "linux",
|
||||
euid: 0,
|
||||
currentUser: createTestUser("root", "0", "0", "/root"),
|
||||
requestedUsername: "root", // Same user
|
||||
featureSupportsUserSwitch: false,
|
||||
allowRoot: true,
|
||||
lookupUsers: map[string]*user.User{
|
||||
"root": createTestUser("root", "0", "0", "/root"),
|
||||
},
|
||||
expectedAllowed: true, // Same user should work regardless of feature support
|
||||
expectedRequiresSwitch: false,
|
||||
},
|
||||
{
|
||||
name: "linux_root_feature_no_user_switching_different_user",
|
||||
os: "linux",
|
||||
euid: 0,
|
||||
currentUser: createTestUser("root", "0", "0", "/root"),
|
||||
requestedUsername: "alice",
|
||||
featureSupportsUserSwitch: false, // Feature doesn't support switching
|
||||
allowRoot: true,
|
||||
lookupUsers: map[string]*user.User{
|
||||
"alice": createTestUser("alice", "1000", "1000", "/home/alice"),
|
||||
},
|
||||
expectedAllowed: false, // Should deny because feature doesn't support switching
|
||||
expectedRequiresSwitch: true,
|
||||
},
|
||||
|
||||
// Empty username (current user) scenarios
|
||||
{
|
||||
name: "linux_non_root_current_user_empty_username",
|
||||
os: "linux",
|
||||
euid: 1000,
|
||||
currentUser: createTestUser("alice", "1000", "1000", "/home/alice"),
|
||||
requestedUsername: "", // Empty = current user
|
||||
featureSupportsUserSwitch: true,
|
||||
allowRoot: false,
|
||||
expectedAllowed: true, // Current user should always work
|
||||
expectedRequiresSwitch: false,
|
||||
},
|
||||
{
|
||||
name: "linux_root_current_user_empty_username_root_not_allowed",
|
||||
os: "linux",
|
||||
euid: 0,
|
||||
currentUser: createTestUser("root", "0", "0", "/root"),
|
||||
requestedUsername: "", // Empty = current user (root)
|
||||
featureSupportsUserSwitch: true,
|
||||
allowRoot: false, // Root not allowed
|
||||
expectedAllowed: false, // Should deny root even when it's current user
|
||||
expectedRequiresSwitch: false,
|
||||
},
|
||||
|
||||
// User not found scenarios
|
||||
{
|
||||
name: "linux_root_user_not_found",
|
||||
os: "linux",
|
||||
euid: 0,
|
||||
currentUser: createTestUser("root", "0", "0", "/root"),
|
||||
requestedUsername: "nonexistent",
|
||||
featureSupportsUserSwitch: true,
|
||||
allowRoot: true,
|
||||
lookupUsers: map[string]*user.User{}, // No users defined = user not found
|
||||
expectedAllowed: false, // Should fail due to user not found
|
||||
expectedRequiresSwitch: true,
|
||||
},
|
||||
|
||||
// Windows feature doesn't support user switching
|
||||
{
|
||||
name: "windows_admin_feature_no_user_switching_different_user",
|
||||
os: "windows",
|
||||
euid: 1000,
|
||||
currentUser: createTestUser("Administrator", "S-1-5-21-123456789-123456789-123456789-500", "S-1-5-32-544", "C:\\Users\\Administrator"),
|
||||
requestedUsername: "alice",
|
||||
featureSupportsUserSwitch: false, // Feature doesn't support switching
|
||||
allowRoot: true,
|
||||
lookupUsers: map[string]*user.User{
|
||||
"alice": createTestUser("alice", "S-1-5-21-123456789-123456789-123456789-1001", "S-1-5-21-123456789-123456789-123456789-513", "C:\\Users\\alice"),
|
||||
},
|
||||
expectedAllowed: false, // Should deny because feature doesn't support switching
|
||||
expectedRequiresSwitch: true,
|
||||
},
|
||||
|
||||
// Windows regular user scenarios (non-admin)
|
||||
{
|
||||
name: "windows_regular_user_same_user",
|
||||
os: "windows",
|
||||
euid: 1000,
|
||||
currentUser: createTestUser("alice", "S-1-5-21-123456789-123456789-123456789-1001", "S-1-5-21-123456789-123456789-123456789-513", "C:\\Users\\alice"),
|
||||
requestedUsername: "alice", // Same user
|
||||
featureSupportsUserSwitch: true,
|
||||
allowRoot: false,
|
||||
lookupUsers: map[string]*user.User{
|
||||
"alice": createTestUser("alice", "S-1-5-21-123456789-123456789-123456789-1001", "S-1-5-21-123456789-123456789-123456789-513", "C:\\Users\\alice"),
|
||||
},
|
||||
expectedAllowed: true, // Regular user accessing themselves should work
|
||||
expectedRequiresSwitch: false, // No switching for same user
|
||||
},
|
||||
{
|
||||
name: "windows_regular_user_empty_username",
|
||||
os: "windows",
|
||||
euid: 1000,
|
||||
currentUser: createTestUser("alice", "S-1-5-21-123456789-123456789-123456789-1001", "S-1-5-21-123456789-123456789-123456789-513", "C:\\Users\\alice"),
|
||||
requestedUsername: "", // Empty = current user
|
||||
featureSupportsUserSwitch: true,
|
||||
allowRoot: false,
|
||||
expectedAllowed: true, // Current user should always work
|
||||
expectedRequiresSwitch: false, // No switching for current user
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Inject platform dependencies to test real logic
|
||||
cleanup := setupTestDependencies(tt.currentUser, nil, tt.os, tt.euid, tt.lookupUsers, nil)
|
||||
defer cleanup()
|
||||
|
||||
server := &Server{allowRootLogin: tt.allowRoot}
|
||||
|
||||
result := server.CheckPrivileges(PrivilegeCheckRequest{
|
||||
RequestedUsername: tt.requestedUsername,
|
||||
FeatureSupportsUserSwitch: tt.featureSupportsUserSwitch,
|
||||
FeatureName: "SSH login",
|
||||
})
|
||||
|
||||
assert.Equal(t, tt.expectedAllowed, result.Allowed)
|
||||
assert.Equal(t, tt.expectedRequiresSwitch, result.RequiresUserSwitching)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestUsedFallback_MeansNoPrivilegeDropping(t *testing.T) {
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("Fallback mechanism is Unix-specific")
|
||||
}
|
||||
|
||||
// Create test scenario where fallback should occur
|
||||
server := &Server{allowRootLogin: true}
|
||||
|
||||
// Mock dependencies to simulate non-privileged user
|
||||
originalGetCurrentUser := getCurrentUser
|
||||
originalGetIsProcessPrivileged := getIsProcessPrivileged
|
||||
|
||||
defer func() {
|
||||
getCurrentUser = originalGetCurrentUser
|
||||
getIsProcessPrivileged = originalGetIsProcessPrivileged
|
||||
|
||||
}()
|
||||
|
||||
// Set up mocks for fallback scenario
|
||||
getCurrentUser = func() (*user.User, error) {
|
||||
return createTestUser("netbird", "1000", "1000", "/var/lib/netbird"), nil
|
||||
}
|
||||
getIsProcessPrivileged = func() bool { return false } // Non-privileged
|
||||
|
||||
// Request different user - should fallback
|
||||
result := server.CheckPrivileges(PrivilegeCheckRequest{
|
||||
RequestedUsername: "alice",
|
||||
FeatureSupportsUserSwitch: true,
|
||||
FeatureName: "SSH login",
|
||||
})
|
||||
|
||||
// Verify fallback occurred
|
||||
assert.True(t, result.Allowed, "Should allow with fallback")
|
||||
assert.True(t, result.UsedFallback, "Should indicate fallback was used")
|
||||
assert.Equal(t, "netbird", result.User.Username, "Should return current user")
|
||||
assert.False(t, result.RequiresUserSwitching, "Should not require switching when fallback is used")
|
||||
|
||||
// Key assertion: When UsedFallback is true, no privilege dropping should be needed
|
||||
// because all privilege checks have already been performed and we're using current user
|
||||
t.Logf("UsedFallback=true means: current user (%s) is the target, no privilege dropping needed",
|
||||
result.User.Username)
|
||||
}
|
||||
|
||||
func TestPrivilegedUsernameDetection(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
username string
|
||||
platform string
|
||||
privileged bool
|
||||
}{
|
||||
// Unix/Linux tests
|
||||
{"unix_root", "root", "linux", true},
|
||||
{"unix_regular_user", "alice", "linux", false},
|
||||
{"unix_root_capital", "Root", "linux", false}, // Case-sensitive
|
||||
|
||||
// Windows tests
|
||||
{"windows_administrator", "Administrator", "windows", true},
|
||||
{"windows_system", "SYSTEM", "windows", true},
|
||||
{"windows_admin", "admin", "windows", true},
|
||||
{"windows_admin_lowercase", "administrator", "windows", true}, // Case-insensitive
|
||||
{"windows_domain_admin", "DOMAIN\\Administrator", "windows", true},
|
||||
{"windows_email_admin", "admin@domain.com", "windows", true},
|
||||
{"windows_regular_user", "alice", "windows", false},
|
||||
{"windows_domain_user", "DOMAIN\\alice", "windows", false},
|
||||
{"windows_localsystem", "localsystem", "windows", true},
|
||||
{"windows_networkservice", "networkservice", "windows", true},
|
||||
{"windows_localservice", "localservice", "windows", true},
|
||||
|
||||
// Computer accounts (these depend on current user context in real implementation)
|
||||
{"windows_computer_account", "WIN2K19-C2$", "windows", false}, // Computer account by itself not privileged
|
||||
{"windows_domain_computer", "DOMAIN\\COMPUTER$", "windows", false}, // Domain computer account
|
||||
|
||||
// Cross-platform
|
||||
{"root_on_windows", "root", "windows", true}, // Root should be privileged everywhere
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Mock the platform for this test
|
||||
cleanup := setupTestDependencies(nil, nil, tt.platform, 1000, nil, nil)
|
||||
defer cleanup()
|
||||
|
||||
result := isPrivilegedUsername(tt.username)
|
||||
assert.Equal(t, tt.privileged, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestWindowsPrivilegedSIDDetection(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
sid string
|
||||
privileged bool
|
||||
description string
|
||||
}{
|
||||
// Well-known system accounts
|
||||
{"system_account", "S-1-5-18", true, "Local System (SYSTEM)"},
|
||||
{"local_service", "S-1-5-19", true, "Local Service"},
|
||||
{"network_service", "S-1-5-20", true, "Network Service"},
|
||||
{"administrators_group", "S-1-5-32-544", true, "Administrators group"},
|
||||
{"builtin_administrator", "S-1-5-500", true, "Built-in Administrator"},
|
||||
|
||||
// Domain accounts
|
||||
{"domain_administrator", "S-1-5-21-1234567890-1234567890-1234567890-500", true, "Domain Administrator (RID 500)"},
|
||||
{"domain_admins_group", "S-1-5-21-1234567890-1234567890-1234567890-512", true, "Domain Admins group"},
|
||||
{"domain_controllers_group", "S-1-5-21-1234567890-1234567890-1234567890-516", true, "Domain Controllers group"},
|
||||
{"enterprise_admins_group", "S-1-5-21-1234567890-1234567890-1234567890-519", true, "Enterprise Admins group"},
|
||||
|
||||
// Regular users
|
||||
{"regular_user", "S-1-5-21-1234567890-1234567890-1234567890-1001", false, "Regular domain user"},
|
||||
{"another_regular_user", "S-1-5-21-1234567890-1234567890-1234567890-1234", false, "Another regular user"},
|
||||
{"local_user", "S-1-5-21-1234567890-1234567890-1234567890-1000", false, "Local regular user"},
|
||||
|
||||
// Groups that are not privileged
|
||||
{"domain_users", "S-1-5-21-1234567890-1234567890-1234567890-513", false, "Domain Users group"},
|
||||
{"power_users", "S-1-5-32-547", false, "Power Users group"},
|
||||
|
||||
// Invalid SIDs
|
||||
{"malformed_sid", "S-1-5-invalid", false, "Malformed SID"},
|
||||
{"empty_sid", "", false, "Empty SID"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := isWindowsPrivilegedSID(tt.sid)
|
||||
assert.Equal(t, tt.privileged, result, "Failed for %s: %s", tt.description, tt.sid)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsSameUser(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
user1 string
|
||||
user2 string
|
||||
os string
|
||||
expected bool
|
||||
}{
|
||||
// Basic cases
|
||||
{"same_username", "alice", "alice", "linux", true},
|
||||
{"different_username", "alice", "bob", "linux", false},
|
||||
|
||||
// Linux (no domain processing)
|
||||
{"linux_domain_vs_bare", "DOMAIN\\alice", "alice", "linux", false},
|
||||
{"linux_email_vs_bare", "alice@domain.com", "alice", "linux", false},
|
||||
{"linux_same_literal", "DOMAIN\\alice", "DOMAIN\\alice", "linux", true},
|
||||
|
||||
// Windows (with domain processing) - Note: parameter order is (requested, current, os, expected)
|
||||
{"windows_domain_vs_bare", "alice", "DOMAIN\\alice", "windows", true}, // bare username matches domain current user
|
||||
{"windows_email_vs_bare", "alice", "alice@domain.com", "windows", true}, // bare username matches email current user
|
||||
{"windows_different_domains_same_user", "DOMAIN1\\alice", "DOMAIN2\\alice", "windows", false}, // SECURITY: different domains = different users
|
||||
{"windows_case_insensitive", "Alice", "alice", "windows", true},
|
||||
{"windows_different_users", "DOMAIN\\alice", "DOMAIN\\bob", "windows", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Set up OS mock
|
||||
cleanup := setupTestDependencies(nil, nil, tt.os, 1000, nil, nil)
|
||||
defer cleanup()
|
||||
|
||||
result := isSameUser(tt.user1, tt.user2)
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestUsernameValidation_Unix(t *testing.T) {
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("Unix-specific username validation tests")
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
username string
|
||||
wantErr bool
|
||||
errMsg string
|
||||
}{
|
||||
// Valid usernames (Unix/POSIX)
|
||||
{"valid_alphanumeric", "user123", false, ""},
|
||||
{"valid_with_dots", "user.name", false, ""},
|
||||
{"valid_with_hyphens", "user-name", false, ""},
|
||||
{"valid_with_underscores", "user_name", false, ""},
|
||||
{"valid_uppercase", "UserName", false, ""},
|
||||
{"valid_starting_with_digit", "123user", false, ""},
|
||||
{"valid_starting_with_dot", ".hidden", false, ""},
|
||||
|
||||
// Invalid usernames (Unix/POSIX)
|
||||
{"empty_username", "", true, "username cannot be empty"},
|
||||
{"username_too_long", "thisusernameiswaytoolongandexceedsthe32characterlimit", true, "username too long"},
|
||||
{"username_starting_with_hyphen", "-user", true, "invalid characters"}, // POSIX restriction
|
||||
{"username_with_spaces", "user name", true, "invalid characters"},
|
||||
{"username_with_shell_metacharacters", "user;rm", true, "invalid characters"},
|
||||
{"username_with_command_injection", "user`rm -rf /`", true, "invalid characters"},
|
||||
{"username_with_pipe", "user|rm", true, "invalid characters"},
|
||||
{"username_with_ampersand", "user&rm", true, "invalid characters"},
|
||||
{"username_with_quotes", "user\"name", true, "invalid characters"},
|
||||
{"username_with_newline", "user\nname", true, "invalid characters"},
|
||||
{"reserved_dot", ".", true, "cannot be '.' or '..'"},
|
||||
{"reserved_dotdot", "..", true, "cannot be '.' or '..'"},
|
||||
{"username_with_at_symbol", "user@domain", true, "invalid characters"}, // Not allowed in bare Unix usernames
|
||||
{"username_with_backslash", "user\\name", true, "invalid characters"}, // Not allowed in Unix usernames
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := validateUsername(tt.username)
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err, "Should reject invalid username")
|
||||
if tt.errMsg != "" {
|
||||
assert.Contains(t, err.Error(), tt.errMsg, "Error message should contain expected text")
|
||||
}
|
||||
} else {
|
||||
assert.NoError(t, err, "Should accept valid username")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestUsernameValidation_Windows(t *testing.T) {
|
||||
if runtime.GOOS != "windows" {
|
||||
t.Skip("Windows-specific username validation tests")
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
username string
|
||||
wantErr bool
|
||||
errMsg string
|
||||
}{
|
||||
// Valid usernames (Windows)
|
||||
{"valid_alphanumeric", "user123", false, ""},
|
||||
{"valid_with_dots", "user.name", false, ""},
|
||||
{"valid_with_hyphens", "user-name", false, ""},
|
||||
{"valid_with_underscores", "user_name", false, ""},
|
||||
{"valid_uppercase", "UserName", false, ""},
|
||||
{"valid_starting_with_digit", "123user", false, ""},
|
||||
{"valid_starting_with_dot", ".hidden", false, ""},
|
||||
{"valid_starting_with_hyphen", "-user", false, ""}, // Windows allows this
|
||||
{"valid_domain_username", "DOMAIN\\user", false, ""}, // Windows domain format
|
||||
{"valid_email_username", "user@domain.com", false, ""}, // Windows email format
|
||||
{"valid_machine_username", "MACHINE\\user", false, ""}, // Windows machine format
|
||||
|
||||
// Invalid usernames (Windows)
|
||||
{"empty_username", "", true, "username cannot be empty"},
|
||||
{"username_too_long", "thisusernameiswaytoolongandexceedsthe32characterlimit", true, "username too long"},
|
||||
{"username_with_spaces", "user name", true, "invalid characters"},
|
||||
{"username_with_shell_metacharacters", "user;rm", true, "invalid characters"},
|
||||
{"username_with_command_injection", "user`rm -rf /`", true, "invalid characters"},
|
||||
{"username_with_pipe", "user|rm", true, "invalid characters"},
|
||||
{"username_with_ampersand", "user&rm", true, "invalid characters"},
|
||||
{"username_with_quotes", "user\"name", true, "invalid characters"},
|
||||
{"username_with_newline", "user\nname", true, "invalid characters"},
|
||||
{"username_with_brackets", "user[name]", true, "invalid characters"},
|
||||
{"username_with_colon", "user:name", true, "invalid characters"},
|
||||
{"username_with_semicolon", "user;name", true, "invalid characters"},
|
||||
{"username_with_equals", "user=name", true, "invalid characters"},
|
||||
{"username_with_comma", "user,name", true, "invalid characters"},
|
||||
{"username_with_plus", "user+name", true, "invalid characters"},
|
||||
{"username_with_asterisk", "user*name", true, "invalid characters"},
|
||||
{"username_with_question", "user?name", true, "invalid characters"},
|
||||
{"username_with_angles", "user<name>", true, "invalid characters"},
|
||||
{"reserved_dot", ".", true, "cannot be '.' or '..'"},
|
||||
{"reserved_dotdot", "..", true, "cannot be '.' or '..'"},
|
||||
{"username_ending_with_period", "user.", true, "cannot end with a period"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := validateUsername(tt.username)
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err, "Should reject invalid username")
|
||||
if tt.errMsg != "" {
|
||||
assert.Contains(t, err.Error(), tt.errMsg, "Error message should contain expected text")
|
||||
}
|
||||
} else {
|
||||
assert.NoError(t, err, "Should accept valid username")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Test real-world integration scenarios with actual platform capabilities
|
||||
func TestCheckPrivileges_RealWorldScenarios(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
feature string
|
||||
featureSupportsUserSwitch bool
|
||||
requestedUsername string
|
||||
allowRoot bool
|
||||
expectedBehaviorPattern string
|
||||
}{
|
||||
{"SSH_login_current_user", "SSH login", true, "", true, "should_allow_current_user"},
|
||||
{"SFTP_current_user", "SFTP", true, "", true, "should_allow_current_user"},
|
||||
{"port_forwarding_current_user", "port forwarding", false, "", true, "should_allow_current_user"},
|
||||
{"SSH_login_root_not_allowed", "SSH login", true, "root", false, "should_deny_root"},
|
||||
{"port_forwarding_different_user", "port forwarding", false, "differentuser", true, "should_deny_switching"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Mock privileged environment to ensure consistent test behavior across environments
|
||||
cleanup := setupTestDependencies(
|
||||
createTestUser("root", "0", "0", "/root"), // Running as root
|
||||
nil,
|
||||
runtime.GOOS,
|
||||
0, // euid 0 (root)
|
||||
map[string]*user.User{
|
||||
"root": createTestUser("root", "0", "0", "/root"),
|
||||
"differentuser": createTestUser("differentuser", "1000", "1000", "/home/differentuser"),
|
||||
},
|
||||
nil,
|
||||
)
|
||||
defer cleanup()
|
||||
|
||||
server := &Server{allowRootLogin: tt.allowRoot}
|
||||
|
||||
result := server.CheckPrivileges(PrivilegeCheckRequest{
|
||||
RequestedUsername: tt.requestedUsername,
|
||||
FeatureSupportsUserSwitch: tt.featureSupportsUserSwitch,
|
||||
FeatureName: tt.feature,
|
||||
})
|
||||
|
||||
switch tt.expectedBehaviorPattern {
|
||||
case "should_allow_current_user":
|
||||
assert.True(t, result.Allowed, "Should allow current user access")
|
||||
assert.False(t, result.RequiresUserSwitching, "Current user should not require switching")
|
||||
case "should_deny_root":
|
||||
assert.False(t, result.Allowed, "Should deny root when not allowed")
|
||||
assert.Contains(t, result.Error.Error(), "root", "Should mention root in error")
|
||||
case "should_deny_switching":
|
||||
assert.False(t, result.Allowed, "Should deny when feature doesn't support switching")
|
||||
assert.Contains(t, result.Error.Error(), "user switching not supported", "Should mention switching in error")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Test with actual platform capabilities - no mocking
|
||||
func TestCheckPrivileges_ActualPlatform(t *testing.T) {
|
||||
// This test uses the REAL platform capabilities
|
||||
server := &Server{allowRootLogin: true}
|
||||
|
||||
// Test current user access - should always work
|
||||
result := server.CheckPrivileges(PrivilegeCheckRequest{
|
||||
RequestedUsername: "", // Current user
|
||||
FeatureSupportsUserSwitch: true,
|
||||
FeatureName: "SSH login",
|
||||
})
|
||||
|
||||
assert.True(t, result.Allowed, "Current user should always be allowed")
|
||||
assert.False(t, result.RequiresUserSwitching, "Current user should not require switching")
|
||||
assert.NotNil(t, result.User, "Should return current user")
|
||||
|
||||
// Test user switching capability based on actual platform
|
||||
actualIsPrivileged := isCurrentProcessPrivileged() // REAL check
|
||||
actualOS := runtime.GOOS // REAL check
|
||||
|
||||
t.Logf("Platform capabilities: OS=%s, isPrivileged=%v, supportsUserSwitching=%v",
|
||||
actualOS, actualIsPrivileged, actualIsPrivileged)
|
||||
|
||||
// Test requesting different user
|
||||
result = server.CheckPrivileges(PrivilegeCheckRequest{
|
||||
RequestedUsername: "nonexistentuser",
|
||||
FeatureSupportsUserSwitch: true,
|
||||
FeatureName: "SSH login",
|
||||
})
|
||||
|
||||
switch {
|
||||
case actualOS == "windows":
|
||||
// Windows supports user switching but should fail on nonexistent user
|
||||
assert.False(t, result.Allowed, "Windows should deny nonexistent user")
|
||||
assert.True(t, result.RequiresUserSwitching, "Should indicate switching is needed")
|
||||
assert.Contains(t, result.Error.Error(), "not found",
|
||||
"Should indicate user not found")
|
||||
case !actualIsPrivileged:
|
||||
// Non-privileged Unix processes should fallback to current user
|
||||
assert.True(t, result.Allowed, "Non-privileged Unix process should fallback to current user")
|
||||
assert.False(t, result.RequiresUserSwitching, "Fallback means no switching actually happens")
|
||||
assert.True(t, result.UsedFallback, "Should indicate fallback was used")
|
||||
assert.NotNil(t, result.User, "Should return current user")
|
||||
default:
|
||||
// Privileged Unix processes should attempt user lookup
|
||||
assert.False(t, result.Allowed, "Should fail due to nonexistent user")
|
||||
assert.True(t, result.RequiresUserSwitching, "Should indicate switching is needed")
|
||||
assert.Contains(t, result.Error.Error(), "nonexistentuser",
|
||||
"Should indicate user not found")
|
||||
}
|
||||
}
|
||||
|
||||
// Test platform detection logic with dependency injection
|
||||
func TestPlatformLogic_DependencyInjection(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
os string
|
||||
euid int
|
||||
currentUser *user.User
|
||||
expectedIsProcessPrivileged bool
|
||||
expectedSupportsUserSwitching bool
|
||||
}{
|
||||
{
|
||||
name: "linux_root_process",
|
||||
os: "linux",
|
||||
euid: 0,
|
||||
currentUser: createTestUser("root", "0", "0", "/root"),
|
||||
expectedIsProcessPrivileged: true,
|
||||
expectedSupportsUserSwitching: true,
|
||||
},
|
||||
{
|
||||
name: "linux_non_root_process",
|
||||
os: "linux",
|
||||
euid: 1000,
|
||||
currentUser: createTestUser("alice", "1000", "1000", "/home/alice"),
|
||||
expectedIsProcessPrivileged: false,
|
||||
expectedSupportsUserSwitching: false,
|
||||
},
|
||||
{
|
||||
name: "windows_admin_process",
|
||||
os: "windows",
|
||||
euid: 1000, // euid ignored on Windows
|
||||
currentUser: createTestUser("Administrator", "S-1-5-21-123456789-123456789-123456789-500", "S-1-5-32-544", "C:\\Users\\Administrator"),
|
||||
expectedIsProcessPrivileged: true,
|
||||
expectedSupportsUserSwitching: true, // Windows supports user switching when privileged
|
||||
},
|
||||
{
|
||||
name: "windows_regular_process",
|
||||
os: "windows",
|
||||
euid: 1000, // euid ignored on Windows
|
||||
currentUser: createTestUser("alice", "1001", "1001", "C:\\Users\\alice"),
|
||||
expectedIsProcessPrivileged: false,
|
||||
expectedSupportsUserSwitching: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Inject platform dependencies and test REAL logic
|
||||
cleanup := setupTestDependencies(tt.currentUser, nil, tt.os, tt.euid, nil, nil)
|
||||
defer cleanup()
|
||||
|
||||
// Test the actual functions with injected dependencies
|
||||
actualIsPrivileged := isCurrentProcessPrivileged()
|
||||
actualSupportsUserSwitching := actualIsPrivileged
|
||||
|
||||
assert.Equal(t, tt.expectedIsProcessPrivileged, actualIsPrivileged,
|
||||
"isCurrentProcessPrivileged() result mismatch")
|
||||
assert.Equal(t, tt.expectedSupportsUserSwitching, actualSupportsUserSwitching,
|
||||
"supportsUserSwitching() result mismatch")
|
||||
|
||||
t.Logf("Platform: %s, EUID: %d, User: %s", tt.os, tt.euid, tt.currentUser.Username)
|
||||
t.Logf("Results: isPrivileged=%v, supportsUserSwitching=%v",
|
||||
actualIsPrivileged, actualSupportsUserSwitching)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCheckPrivileges_WindowsElevatedUserSwitching(t *testing.T) {
|
||||
// Test Windows elevated user switching scenarios with simplified privilege logic
|
||||
tests := []struct {
|
||||
name string
|
||||
currentUser *user.User
|
||||
requestedUsername string
|
||||
allowRoot bool
|
||||
expectedAllowed bool
|
||||
expectedErrorContains string
|
||||
}{
|
||||
{
|
||||
name: "windows_admin_can_switch_to_alice",
|
||||
currentUser: createTestUser("administrator", "S-1-5-21-123456789-123456789-123456789-500", "S-1-5-32-544", "C:\\\\Users\\\\Administrator"),
|
||||
requestedUsername: "alice",
|
||||
allowRoot: true,
|
||||
expectedAllowed: true,
|
||||
},
|
||||
{
|
||||
name: "windows_non_admin_can_try_switch",
|
||||
currentUser: createTestUser("alice", "S-1-5-21-123456789-123456789-123456789-1001", "S-1-5-21-123456789-123456789-123456789-513", "C:\\\\Users\\\\alice"),
|
||||
requestedUsername: "bob",
|
||||
allowRoot: true,
|
||||
expectedAllowed: true, // Privilege check allows it, OS will reject during execution
|
||||
},
|
||||
{
|
||||
name: "windows_system_can_switch_to_alice",
|
||||
currentUser: createTestUser("SYSTEM", "S-1-5-18", "S-1-5-18", "C:\\\\Windows\\\\system32\\\\config\\\\systemprofile"),
|
||||
requestedUsername: "alice",
|
||||
allowRoot: true,
|
||||
expectedAllowed: true,
|
||||
},
|
||||
{
|
||||
name: "windows_admin_root_not_allowed",
|
||||
currentUser: createTestUser("administrator", "S-1-5-21-123456789-123456789-123456789-500", "S-1-5-32-544", "C:\\\\Users\\\\Administrator"),
|
||||
requestedUsername: "root",
|
||||
allowRoot: false,
|
||||
expectedAllowed: false,
|
||||
expectedErrorContains: "privileged user login is disabled",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Setup test dependencies with Windows OS and specified privileges
|
||||
lookupUsers := map[string]*user.User{
|
||||
tt.requestedUsername: createTestUser(tt.requestedUsername, "1002", "1002", "C:\\\\Users\\\\"+tt.requestedUsername),
|
||||
}
|
||||
cleanup := setupTestDependencies(tt.currentUser, nil, "windows", 1000, lookupUsers, nil)
|
||||
defer cleanup()
|
||||
|
||||
server := &Server{allowRootLogin: tt.allowRoot}
|
||||
|
||||
result := server.CheckPrivileges(PrivilegeCheckRequest{
|
||||
RequestedUsername: tt.requestedUsername,
|
||||
FeatureSupportsUserSwitch: true,
|
||||
FeatureName: "SSH login",
|
||||
})
|
||||
|
||||
assert.Equal(t, tt.expectedAllowed, result.Allowed,
|
||||
"Privilege check result should match expected for %s", tt.name)
|
||||
|
||||
if !tt.expectedAllowed && tt.expectedErrorContains != "" {
|
||||
assert.NotNil(t, result.Error, "Should have error when not allowed")
|
||||
assert.Contains(t, result.Error.Error(), tt.expectedErrorContains,
|
||||
"Error should contain expected message")
|
||||
}
|
||||
|
||||
if tt.expectedAllowed && tt.requestedUsername != "" && tt.currentUser.Username != tt.requestedUsername {
|
||||
assert.True(t, result.RequiresUserSwitching, "Should require user switching for different user")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
8
client/ssh/server/userswitching_js.go
Normal file
8
client/ssh/server/userswitching_js.go
Normal file
@@ -0,0 +1,8 @@
|
||||
//go:build js
|
||||
|
||||
package server
|
||||
|
||||
// enableUserSwitching is not supported on JS/WASM
|
||||
func enableUserSwitching() error {
|
||||
return errNotSupported
|
||||
}
|
||||
233
client/ssh/server/userswitching_unix.go
Normal file
233
client/ssh/server/userswitching_unix.go
Normal file
@@ -0,0 +1,233 @@
|
||||
//go:build unix
|
||||
|
||||
package server
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
"os/exec"
|
||||
"os/user"
|
||||
"regexp"
|
||||
"runtime"
|
||||
"strconv"
|
||||
|
||||
"github.com/gliderlabs/ssh"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// POSIX portable filename character set regex: [a-zA-Z0-9._-]
|
||||
// First character cannot be hyphen (POSIX requirement)
|
||||
var posixUsernameRegex = regexp.MustCompile(`^[a-zA-Z0-9._][a-zA-Z0-9._-]*$`)
|
||||
|
||||
// validateUsername validates that a username conforms to POSIX standards with security considerations
|
||||
func validateUsername(username string) error {
|
||||
if username == "" {
|
||||
return errors.New("username cannot be empty")
|
||||
}
|
||||
|
||||
// POSIX allows up to 256 characters, but practical limit is 32 for compatibility
|
||||
if len(username) > 32 {
|
||||
return errors.New("username too long (max 32 characters)")
|
||||
}
|
||||
|
||||
if !posixUsernameRegex.MatchString(username) {
|
||||
return errors.New("username contains invalid characters (must match POSIX portable filename character set)")
|
||||
}
|
||||
|
||||
if username == "." || username == ".." {
|
||||
return fmt.Errorf("username cannot be '.' or '..'")
|
||||
}
|
||||
|
||||
// Warn if username is fully numeric (can cause issues with UID/username ambiguity)
|
||||
if isFullyNumeric(username) {
|
||||
log.Warnf("fully numeric username '%s' may cause issues with some commands", username)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// isFullyNumeric checks if username contains only digits
|
||||
func isFullyNumeric(username string) bool {
|
||||
for _, char := range username {
|
||||
if char < '0' || char > '9' {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// createPtyLoginCommand creates a Pty command using login for privileged processes
|
||||
func (s *Server) createPtyLoginCommand(localUser *user.User, ptyReq ssh.Pty, session ssh.Session) (*exec.Cmd, error) {
|
||||
loginPath, args, err := s.getLoginCmd(localUser.Username, session.RemoteAddr())
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get login command: %w", err)
|
||||
}
|
||||
|
||||
execCmd := exec.CommandContext(session.Context(), loginPath, args...)
|
||||
execCmd.Dir = localUser.HomeDir
|
||||
execCmd.Env = s.preparePtyEnv(localUser, ptyReq, session)
|
||||
|
||||
return execCmd, nil
|
||||
}
|
||||
|
||||
// getLoginCmd returns the login command and args for privileged Pty user switching
|
||||
func (s *Server) getLoginCmd(username string, remoteAddr net.Addr) (string, []string, error) {
|
||||
loginPath, err := exec.LookPath("login")
|
||||
if err != nil {
|
||||
return "", nil, fmt.Errorf("login command not available: %w", err)
|
||||
}
|
||||
|
||||
addrPort, err := netip.ParseAddrPort(remoteAddr.String())
|
||||
if err != nil {
|
||||
return "", nil, fmt.Errorf("parse remote address: %w", err)
|
||||
}
|
||||
|
||||
switch runtime.GOOS {
|
||||
case "linux":
|
||||
// Special handling for Arch Linux without /etc/pam.d/remote
|
||||
if s.fileExists("/etc/arch-release") && !s.fileExists("/etc/pam.d/remote") {
|
||||
return loginPath, []string{"-f", username, "-p"}, nil
|
||||
}
|
||||
return loginPath, []string{"-f", username, "-h", addrPort.Addr().String(), "-p"}, nil
|
||||
case "darwin", "freebsd", "openbsd", "netbsd", "dragonfly":
|
||||
return loginPath, []string{"-fp", "-h", addrPort.Addr().String(), username}, nil
|
||||
default:
|
||||
return "", nil, fmt.Errorf("unsupported Unix platform for login command: %s", runtime.GOOS)
|
||||
}
|
||||
}
|
||||
|
||||
// fileExists checks if a file exists (helper for login command logic)
|
||||
func (s *Server) fileExists(path string) bool {
|
||||
_, err := os.Stat(path)
|
||||
return err == nil
|
||||
}
|
||||
|
||||
// parseUserCredentials extracts numeric UID, GID, and supplementary groups
|
||||
func (s *Server) parseUserCredentials(localUser *user.User) (uint32, uint32, []uint32, error) {
|
||||
uid64, err := strconv.ParseUint(localUser.Uid, 10, 32)
|
||||
if err != nil {
|
||||
return 0, 0, nil, fmt.Errorf("invalid UID %s: %w", localUser.Uid, err)
|
||||
}
|
||||
uid := uint32(uid64)
|
||||
|
||||
gid64, err := strconv.ParseUint(localUser.Gid, 10, 32)
|
||||
if err != nil {
|
||||
return 0, 0, nil, fmt.Errorf("invalid GID %s: %w", localUser.Gid, err)
|
||||
}
|
||||
gid := uint32(gid64)
|
||||
|
||||
groups, err := s.getSupplementaryGroups(localUser.Username)
|
||||
if err != nil {
|
||||
log.Warnf("failed to get supplementary groups for user %s: %v", localUser.Username, err)
|
||||
groups = []uint32{gid}
|
||||
}
|
||||
|
||||
return uid, gid, groups, nil
|
||||
}
|
||||
|
||||
// getSupplementaryGroups retrieves supplementary group IDs for a user
|
||||
func (s *Server) getSupplementaryGroups(username string) ([]uint32, error) {
|
||||
u, err := user.Lookup(username)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("lookup user %s: %w", username, err)
|
||||
}
|
||||
|
||||
groupIDStrings, err := u.GroupIds()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get group IDs for user %s: %w", username, err)
|
||||
}
|
||||
|
||||
groups := make([]uint32, len(groupIDStrings))
|
||||
for i, gidStr := range groupIDStrings {
|
||||
gid64, err := strconv.ParseUint(gidStr, 10, 32)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid group ID %s for user %s: %w", gidStr, username, err)
|
||||
}
|
||||
groups[i] = uint32(gid64)
|
||||
}
|
||||
|
||||
return groups, nil
|
||||
}
|
||||
|
||||
// createExecutorCommand creates a command that spawns netbird ssh exec for privilege dropping.
|
||||
// Returns the command and a cleanup function (no-op on Unix).
|
||||
func (s *Server) createExecutorCommand(session ssh.Session, localUser *user.User, hasPty bool) (*exec.Cmd, func(), error) {
|
||||
log.Debugf("creating executor command for user %s (Pty: %v)", localUser.Username, hasPty)
|
||||
|
||||
if err := validateUsername(localUser.Username); err != nil {
|
||||
return nil, nil, fmt.Errorf("invalid username %q: %w", localUser.Username, err)
|
||||
}
|
||||
|
||||
uid, gid, groups, err := s.parseUserCredentials(localUser)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("parse user credentials: %w", err)
|
||||
}
|
||||
privilegeDropper := NewPrivilegeDropper()
|
||||
config := ExecutorConfig{
|
||||
UID: uid,
|
||||
GID: gid,
|
||||
Groups: groups,
|
||||
WorkingDir: localUser.HomeDir,
|
||||
Shell: getUserShell(localUser.Uid),
|
||||
Command: session.RawCommand(),
|
||||
PTY: hasPty,
|
||||
}
|
||||
|
||||
cmd, err := privilegeDropper.CreateExecutorCommand(session.Context(), config)
|
||||
return cmd, func() {}, err
|
||||
}
|
||||
|
||||
// enableUserSwitching is a no-op on Unix systems
|
||||
func enableUserSwitching() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// createPtyCommand creates the exec.Cmd for Pty execution respecting privilege check results
|
||||
func (s *Server) createPtyCommand(privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, session ssh.Session) (*exec.Cmd, error) {
|
||||
localUser := privilegeResult.User
|
||||
if localUser == nil {
|
||||
return nil, errors.New("no user in privilege result")
|
||||
}
|
||||
|
||||
if privilegeResult.UsedFallback {
|
||||
return s.createDirectPtyCommand(session, localUser, ptyReq), nil
|
||||
}
|
||||
|
||||
return s.createPtyLoginCommand(localUser, ptyReq, session)
|
||||
}
|
||||
|
||||
// createDirectPtyCommand creates a direct Pty command without privilege dropping
|
||||
func (s *Server) createDirectPtyCommand(session ssh.Session, localUser *user.User, ptyReq ssh.Pty) *exec.Cmd {
|
||||
log.Debugf("creating direct Pty command for user %s (no user switching needed)", localUser.Username)
|
||||
|
||||
shell := getUserShell(localUser.Uid)
|
||||
args := s.getShellCommandArgs(shell, session.RawCommand())
|
||||
|
||||
cmd := exec.CommandContext(session.Context(), args[0], args[1:]...)
|
||||
cmd.Dir = localUser.HomeDir
|
||||
cmd.Env = s.preparePtyEnv(localUser, ptyReq, session)
|
||||
|
||||
return cmd
|
||||
}
|
||||
|
||||
// preparePtyEnv prepares environment variables for Pty execution
|
||||
func (s *Server) preparePtyEnv(localUser *user.User, ptyReq ssh.Pty, session ssh.Session) []string {
|
||||
termType := ptyReq.Term
|
||||
if termType == "" {
|
||||
termType = "xterm-256color"
|
||||
}
|
||||
|
||||
env := prepareUserEnv(localUser, getUserShell(localUser.Uid))
|
||||
env = append(env, prepareSSHEnv(session)...)
|
||||
env = append(env, fmt.Sprintf("TERM=%s", termType))
|
||||
|
||||
for _, v := range session.Environ() {
|
||||
if acceptEnv(v) {
|
||||
env = append(env, v)
|
||||
}
|
||||
}
|
||||
return env
|
||||
}
|
||||
274
client/ssh/server/userswitching_windows.go
Normal file
274
client/ssh/server/userswitching_windows.go
Normal file
@@ -0,0 +1,274 @@
|
||||
//go:build windows
|
||||
|
||||
package server
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"os/exec"
|
||||
"os/user"
|
||||
"strings"
|
||||
"unsafe"
|
||||
|
||||
"github.com/gliderlabs/ssh"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/sys/windows"
|
||||
)
|
||||
|
||||
// validateUsername validates Windows usernames according to SAM Account Name rules
|
||||
func validateUsername(username string) error {
|
||||
if username == "" {
|
||||
return fmt.Errorf("username cannot be empty")
|
||||
}
|
||||
|
||||
usernameToValidate := extractUsernameFromDomain(username)
|
||||
|
||||
if err := validateUsernameLength(usernameToValidate); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := validateUsernameCharacters(usernameToValidate); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := validateUsernameFormat(usernameToValidate); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// extractUsernameFromDomain extracts the username part from domain\username or username@domain format
|
||||
func extractUsernameFromDomain(username string) string {
|
||||
if idx := strings.LastIndex(username, `\`); idx != -1 {
|
||||
return username[idx+1:]
|
||||
}
|
||||
if idx := strings.Index(username, "@"); idx != -1 {
|
||||
return username[:idx]
|
||||
}
|
||||
return username
|
||||
}
|
||||
|
||||
// validateUsernameLength checks if username length is within Windows limits
|
||||
func validateUsernameLength(username string) error {
|
||||
if len(username) > 20 {
|
||||
return fmt.Errorf("username too long (max 20 characters for Windows)")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// validateUsernameCharacters checks for invalid characters in Windows usernames
|
||||
func validateUsernameCharacters(username string) error {
|
||||
invalidChars := []rune{'"', '/', '[', ']', ':', ';', '|', '=', ',', '+', '*', '?', '<', '>', ' ', '`', '&', '\n'}
|
||||
for _, char := range username {
|
||||
for _, invalid := range invalidChars {
|
||||
if char == invalid {
|
||||
return fmt.Errorf("username contains invalid characters")
|
||||
}
|
||||
}
|
||||
if char < 32 || char == 127 {
|
||||
return fmt.Errorf("username contains control characters")
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// validateUsernameFormat checks for invalid username formats and patterns
|
||||
func validateUsernameFormat(username string) error {
|
||||
if username == "." || username == ".." {
|
||||
return fmt.Errorf("username cannot be '.' or '..'")
|
||||
}
|
||||
|
||||
if strings.HasSuffix(username, ".") {
|
||||
return fmt.Errorf("username cannot end with a period")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// createExecutorCommand creates a command using Windows executor for privilege dropping.
|
||||
// Returns the command and a cleanup function that must be called after starting the process.
|
||||
func (s *Server) createExecutorCommand(session ssh.Session, localUser *user.User, hasPty bool) (*exec.Cmd, func(), error) {
|
||||
log.Debugf("creating Windows executor command for user %s (Pty: %v)", localUser.Username, hasPty)
|
||||
|
||||
username, _ := s.parseUsername(localUser.Username)
|
||||
if err := validateUsername(username); err != nil {
|
||||
return nil, nil, fmt.Errorf("invalid username %q: %w", username, err)
|
||||
}
|
||||
|
||||
return s.createUserSwitchCommand(localUser, session, hasPty)
|
||||
}
|
||||
|
||||
// createUserSwitchCommand creates a command with Windows user switching.
|
||||
// Returns the command and a cleanup function that must be called after starting the process.
|
||||
func (s *Server) createUserSwitchCommand(localUser *user.User, session ssh.Session, interactive bool) (*exec.Cmd, func(), error) {
|
||||
username, domain := s.parseUsername(localUser.Username)
|
||||
|
||||
shell := getUserShell(localUser.Uid)
|
||||
|
||||
rawCmd := session.RawCommand()
|
||||
var command string
|
||||
if rawCmd != "" {
|
||||
command = rawCmd
|
||||
}
|
||||
|
||||
config := WindowsExecutorConfig{
|
||||
Username: username,
|
||||
Domain: domain,
|
||||
WorkingDir: localUser.HomeDir,
|
||||
Shell: shell,
|
||||
Command: command,
|
||||
Interactive: interactive || (rawCmd == ""),
|
||||
}
|
||||
|
||||
dropper := NewPrivilegeDropper()
|
||||
cmd, token, err := dropper.CreateWindowsExecutorCommand(session.Context(), config)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
cleanup := func() {
|
||||
if token != 0 {
|
||||
if err := windows.CloseHandle(windows.Handle(token)); err != nil {
|
||||
log.Debugf("close primary token: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return cmd, cleanup, nil
|
||||
}
|
||||
|
||||
// parseUsername extracts username and domain from a Windows username
|
||||
func (s *Server) parseUsername(fullUsername string) (username, domain string) {
|
||||
// Handle DOMAIN\username format
|
||||
if idx := strings.LastIndex(fullUsername, `\`); idx != -1 {
|
||||
domain = fullUsername[:idx]
|
||||
username = fullUsername[idx+1:]
|
||||
return username, domain
|
||||
}
|
||||
|
||||
// Handle username@domain format
|
||||
if username, domain, ok := strings.Cut(fullUsername, "@"); ok {
|
||||
return username, domain
|
||||
}
|
||||
|
||||
// Local user (no domain)
|
||||
return fullUsername, "."
|
||||
}
|
||||
|
||||
// hasPrivilege checks if the current process has a specific privilege
|
||||
func hasPrivilege(token windows.Handle, privilegeName string) (bool, error) {
|
||||
var luid windows.LUID
|
||||
if err := windows.LookupPrivilegeValue(nil, windows.StringToUTF16Ptr(privilegeName), &luid); err != nil {
|
||||
return false, fmt.Errorf("lookup privilege value: %w", err)
|
||||
}
|
||||
|
||||
var returnLength uint32
|
||||
err := windows.GetTokenInformation(
|
||||
windows.Token(token),
|
||||
windows.TokenPrivileges,
|
||||
nil, // null buffer to get size
|
||||
0,
|
||||
&returnLength,
|
||||
)
|
||||
|
||||
if err != nil && !errors.Is(err, windows.ERROR_INSUFFICIENT_BUFFER) {
|
||||
return false, fmt.Errorf("get token information size: %w", err)
|
||||
}
|
||||
|
||||
buffer := make([]byte, returnLength)
|
||||
err = windows.GetTokenInformation(
|
||||
windows.Token(token),
|
||||
windows.TokenPrivileges,
|
||||
&buffer[0],
|
||||
returnLength,
|
||||
&returnLength,
|
||||
)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("get token information: %w", err)
|
||||
}
|
||||
|
||||
privileges := (*windows.Tokenprivileges)(unsafe.Pointer(&buffer[0]))
|
||||
|
||||
// Check if the privilege is present and enabled
|
||||
for i := uint32(0); i < privileges.PrivilegeCount; i++ {
|
||||
privilege := (*windows.LUIDAndAttributes)(unsafe.Pointer(
|
||||
uintptr(unsafe.Pointer(&privileges.Privileges[0])) +
|
||||
uintptr(i)*unsafe.Sizeof(windows.LUIDAndAttributes{}),
|
||||
))
|
||||
if privilege.Luid == luid {
|
||||
return (privilege.Attributes & windows.SE_PRIVILEGE_ENABLED) != 0, nil
|
||||
}
|
||||
}
|
||||
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// enablePrivilege enables a specific privilege for the current process token
|
||||
// This is required because privileges like SeAssignPrimaryTokenPrivilege are present
|
||||
// but disabled by default, even for the SYSTEM account
|
||||
func enablePrivilege(token windows.Handle, privilegeName string) error {
|
||||
var luid windows.LUID
|
||||
if err := windows.LookupPrivilegeValue(nil, windows.StringToUTF16Ptr(privilegeName), &luid); err != nil {
|
||||
return fmt.Errorf("lookup privilege value for %s: %w", privilegeName, err)
|
||||
}
|
||||
|
||||
privileges := windows.Tokenprivileges{
|
||||
PrivilegeCount: 1,
|
||||
Privileges: [1]windows.LUIDAndAttributes{
|
||||
{
|
||||
Luid: luid,
|
||||
Attributes: windows.SE_PRIVILEGE_ENABLED,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
err := windows.AdjustTokenPrivileges(
|
||||
windows.Token(token),
|
||||
false,
|
||||
&privileges,
|
||||
0,
|
||||
nil,
|
||||
nil,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("adjust token privileges for %s: %w", privilegeName, err)
|
||||
}
|
||||
|
||||
hasPriv, err := hasPrivilege(token, privilegeName)
|
||||
if err != nil {
|
||||
return fmt.Errorf("verify privilege %s after enabling: %w", privilegeName, err)
|
||||
}
|
||||
if !hasPriv {
|
||||
return fmt.Errorf("privilege %s could not be enabled (may not be granted to account)", privilegeName)
|
||||
}
|
||||
|
||||
log.Debugf("Successfully enabled privilege %s for current process", privilegeName)
|
||||
return nil
|
||||
}
|
||||
|
||||
// enableUserSwitching enables required privileges for Windows user switching
|
||||
func enableUserSwitching() error {
|
||||
process := windows.CurrentProcess()
|
||||
|
||||
var token windows.Token
|
||||
err := windows.OpenProcessToken(
|
||||
process,
|
||||
windows.TOKEN_ADJUST_PRIVILEGES|windows.TOKEN_QUERY,
|
||||
&token,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("open process token: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
if err := windows.CloseHandle(windows.Handle(token)); err != nil {
|
||||
log.Debugf("Failed to close process token: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
if err := enablePrivilege(windows.Handle(token), "SeAssignPrimaryTokenPrivilege"); err != nil {
|
||||
return fmt.Errorf("enable SeAssignPrimaryTokenPrivilege: %w", err)
|
||||
}
|
||||
log.Infof("Windows user switching privileges enabled successfully")
|
||||
return nil
|
||||
}
|
||||
487
client/ssh/server/winpty/conpty.go
Normal file
487
client/ssh/server/winpty/conpty.go
Normal file
@@ -0,0 +1,487 @@
|
||||
//go:build windows
|
||||
|
||||
package winpty
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"strings"
|
||||
"sync"
|
||||
"syscall"
|
||||
"unsafe"
|
||||
|
||||
"github.com/gliderlabs/ssh"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/sys/windows"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrEmptyEnvironment = errors.New("empty environment")
|
||||
)
|
||||
|
||||
const (
|
||||
extendedStartupInfoPresent = 0x00080000
|
||||
createUnicodeEnvironment = 0x00000400
|
||||
procThreadAttributePseudoConsole = 0x00020016
|
||||
|
||||
PowerShellCommandFlag = "-Command"
|
||||
|
||||
errCloseInputRead = "close input read handle: %v"
|
||||
errCloseConPtyCleanup = "close ConPty handle during cleanup"
|
||||
)
|
||||
|
||||
// PtyConfig holds configuration for Pty execution.
|
||||
type PtyConfig struct {
|
||||
Shell string
|
||||
Command string
|
||||
Width int
|
||||
Height int
|
||||
WorkingDir string
|
||||
}
|
||||
|
||||
// UserConfig holds user execution configuration.
|
||||
type UserConfig struct {
|
||||
Token windows.Handle
|
||||
Environment []string
|
||||
}
|
||||
|
||||
var (
|
||||
kernel32 = windows.NewLazySystemDLL("kernel32.dll")
|
||||
procClosePseudoConsole = kernel32.NewProc("ClosePseudoConsole")
|
||||
procInitializeProcThreadAttributeList = kernel32.NewProc("InitializeProcThreadAttributeList")
|
||||
procUpdateProcThreadAttribute = kernel32.NewProc("UpdateProcThreadAttribute")
|
||||
procDeleteProcThreadAttributeList = kernel32.NewProc("DeleteProcThreadAttributeList")
|
||||
)
|
||||
|
||||
// ExecutePtyWithUserToken executes a command with ConPty using user token.
|
||||
func ExecutePtyWithUserToken(ctx context.Context, session ssh.Session, ptyConfig PtyConfig, userConfig UserConfig) error {
|
||||
args := buildShellArgs(ptyConfig.Shell, ptyConfig.Command)
|
||||
commandLine := buildCommandLine(args)
|
||||
|
||||
config := ExecutionConfig{
|
||||
Pty: ptyConfig,
|
||||
User: userConfig,
|
||||
Session: session,
|
||||
Context: ctx,
|
||||
}
|
||||
|
||||
return executeConPtyWithConfig(commandLine, config)
|
||||
}
|
||||
|
||||
// ExecutionConfig holds all execution configuration.
|
||||
type ExecutionConfig struct {
|
||||
Pty PtyConfig
|
||||
User UserConfig
|
||||
Session ssh.Session
|
||||
Context context.Context
|
||||
}
|
||||
|
||||
// executeConPtyWithConfig creates ConPty and executes process with configuration.
|
||||
func executeConPtyWithConfig(commandLine string, config ExecutionConfig) error {
|
||||
ctx := config.Context
|
||||
session := config.Session
|
||||
width := config.Pty.Width
|
||||
height := config.Pty.Height
|
||||
userToken := config.User.Token
|
||||
userEnv := config.User.Environment
|
||||
workingDir := config.Pty.WorkingDir
|
||||
|
||||
inputRead, inputWrite, outputRead, outputWrite, err := createConPtyPipes()
|
||||
if err != nil {
|
||||
return fmt.Errorf("create ConPty pipes: %w", err)
|
||||
}
|
||||
|
||||
hPty, err := createConPty(width, height, inputRead, outputWrite)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create ConPty: %w", err)
|
||||
}
|
||||
|
||||
primaryToken, err := duplicateToPrimaryToken(userToken)
|
||||
if err != nil {
|
||||
if closeErr, _, _ := procClosePseudoConsole.Call(uintptr(hPty)); closeErr == 0 {
|
||||
log.Debugf(errCloseConPtyCleanup)
|
||||
}
|
||||
closeHandles(inputRead, inputWrite, outputRead, outputWrite)
|
||||
return fmt.Errorf("duplicate to primary token: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
if err := windows.CloseHandle(primaryToken); err != nil {
|
||||
log.Debugf("close primary token: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
siEx, err := setupConPtyStartupInfo(hPty)
|
||||
if err != nil {
|
||||
if closeErr, _, _ := procClosePseudoConsole.Call(uintptr(hPty)); closeErr == 0 {
|
||||
log.Debugf(errCloseConPtyCleanup)
|
||||
}
|
||||
closeHandles(inputRead, inputWrite, outputRead, outputWrite)
|
||||
return fmt.Errorf("setup startup info: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
_, _, _ = procDeleteProcThreadAttributeList.Call(uintptr(unsafe.Pointer(siEx.ProcThreadAttributeList)))
|
||||
}()
|
||||
|
||||
pi, err := createConPtyProcess(commandLine, primaryToken, userEnv, workingDir, siEx)
|
||||
if err != nil {
|
||||
if closeErr, _, _ := procClosePseudoConsole.Call(uintptr(hPty)); closeErr == 0 {
|
||||
log.Debugf(errCloseConPtyCleanup)
|
||||
}
|
||||
closeHandles(inputRead, inputWrite, outputRead, outputWrite)
|
||||
return fmt.Errorf("create process as user with ConPty: %w", err)
|
||||
}
|
||||
defer closeProcessInfo(pi)
|
||||
|
||||
if err := windows.CloseHandle(inputRead); err != nil {
|
||||
log.Debugf(errCloseInputRead, err)
|
||||
}
|
||||
if err := windows.CloseHandle(outputWrite); err != nil {
|
||||
log.Debugf("close output write handle: %v", err)
|
||||
}
|
||||
|
||||
return bridgeConPtyIO(ctx, hPty, inputWrite, outputRead, session, session, session, pi.Process)
|
||||
}
|
||||
|
||||
// createConPtyPipes creates input/output pipes for ConPty.
|
||||
func createConPtyPipes() (inputRead, inputWrite, outputRead, outputWrite windows.Handle, err error) {
|
||||
if err := windows.CreatePipe(&inputRead, &inputWrite, nil, 0); err != nil {
|
||||
return 0, 0, 0, 0, fmt.Errorf("create input pipe: %w", err)
|
||||
}
|
||||
|
||||
if err := windows.CreatePipe(&outputRead, &outputWrite, nil, 0); err != nil {
|
||||
if closeErr := windows.CloseHandle(inputRead); closeErr != nil {
|
||||
log.Debugf(errCloseInputRead, closeErr)
|
||||
}
|
||||
if closeErr := windows.CloseHandle(inputWrite); closeErr != nil {
|
||||
log.Debugf("close input write handle: %v", closeErr)
|
||||
}
|
||||
return 0, 0, 0, 0, fmt.Errorf("create output pipe: %w", err)
|
||||
}
|
||||
|
||||
return inputRead, inputWrite, outputRead, outputWrite, nil
|
||||
}
|
||||
|
||||
// createConPty creates a Windows ConPty with the specified size and pipe handles.
|
||||
func createConPty(width, height int, inputRead, outputWrite windows.Handle) (windows.Handle, error) {
|
||||
size := windows.Coord{X: int16(width), Y: int16(height)}
|
||||
|
||||
var hPty windows.Handle
|
||||
if err := windows.CreatePseudoConsole(size, inputRead, outputWrite, 0, &hPty); err != nil {
|
||||
return 0, fmt.Errorf("CreatePseudoConsole: %w", err)
|
||||
}
|
||||
|
||||
return hPty, nil
|
||||
}
|
||||
|
||||
// setupConPtyStartupInfo prepares the STARTUPINFOEX with ConPty attributes.
|
||||
func setupConPtyStartupInfo(hPty windows.Handle) (*windows.StartupInfoEx, error) {
|
||||
var siEx windows.StartupInfoEx
|
||||
siEx.StartupInfo.Cb = uint32(unsafe.Sizeof(siEx))
|
||||
|
||||
var attrListSize uintptr
|
||||
ret, _, _ := procInitializeProcThreadAttributeList.Call(0, 1, 0, uintptr(unsafe.Pointer(&attrListSize)))
|
||||
if ret == 0 && attrListSize == 0 {
|
||||
return nil, fmt.Errorf("get attribute list size")
|
||||
}
|
||||
|
||||
attrListBytes := make([]byte, attrListSize)
|
||||
siEx.ProcThreadAttributeList = (*windows.ProcThreadAttributeList)(unsafe.Pointer(&attrListBytes[0]))
|
||||
|
||||
ret, _, err := procInitializeProcThreadAttributeList.Call(
|
||||
uintptr(unsafe.Pointer(siEx.ProcThreadAttributeList)),
|
||||
1,
|
||||
0,
|
||||
uintptr(unsafe.Pointer(&attrListSize)),
|
||||
)
|
||||
if ret == 0 {
|
||||
return nil, fmt.Errorf("initialize attribute list: %w", err)
|
||||
}
|
||||
|
||||
ret, _, err = procUpdateProcThreadAttribute.Call(
|
||||
uintptr(unsafe.Pointer(siEx.ProcThreadAttributeList)),
|
||||
0,
|
||||
procThreadAttributePseudoConsole,
|
||||
uintptr(hPty),
|
||||
unsafe.Sizeof(hPty),
|
||||
0,
|
||||
0,
|
||||
)
|
||||
if ret == 0 {
|
||||
return nil, fmt.Errorf("update thread attribute: %w", err)
|
||||
}
|
||||
|
||||
return &siEx, nil
|
||||
}
|
||||
|
||||
// createConPtyProcess creates the actual process with ConPty.
|
||||
func createConPtyProcess(commandLine string, userToken windows.Handle, userEnv []string, workingDir string, siEx *windows.StartupInfoEx) (*windows.ProcessInformation, error) {
|
||||
var pi windows.ProcessInformation
|
||||
creationFlags := uint32(extendedStartupInfoPresent | createUnicodeEnvironment)
|
||||
|
||||
commandLinePtr, err := windows.UTF16PtrFromString(commandLine)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("convert command line to UTF16: %w", err)
|
||||
}
|
||||
|
||||
envPtr, err := convertEnvironmentToUTF16(userEnv)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var workingDirPtr *uint16
|
||||
if workingDir != "" {
|
||||
workingDirPtr, err = windows.UTF16PtrFromString(workingDir)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("convert working directory to UTF16: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
siEx.StartupInfo.Flags |= windows.STARTF_USESTDHANDLES
|
||||
siEx.StartupInfo.StdInput = windows.Handle(0)
|
||||
siEx.StartupInfo.StdOutput = windows.Handle(0)
|
||||
siEx.StartupInfo.StdErr = siEx.StartupInfo.StdOutput
|
||||
|
||||
if userToken != windows.InvalidHandle {
|
||||
err = windows.CreateProcessAsUser(
|
||||
windows.Token(userToken),
|
||||
nil,
|
||||
commandLinePtr,
|
||||
nil,
|
||||
nil,
|
||||
true,
|
||||
creationFlags,
|
||||
envPtr,
|
||||
workingDirPtr,
|
||||
&siEx.StartupInfo,
|
||||
&pi,
|
||||
)
|
||||
} else {
|
||||
err = windows.CreateProcess(
|
||||
nil,
|
||||
commandLinePtr,
|
||||
nil,
|
||||
nil,
|
||||
true,
|
||||
creationFlags,
|
||||
envPtr,
|
||||
workingDirPtr,
|
||||
&siEx.StartupInfo,
|
||||
&pi,
|
||||
)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create process: %w", err)
|
||||
}
|
||||
|
||||
return &pi, nil
|
||||
}
|
||||
|
||||
// convertEnvironmentToUTF16 converts environment variables to Windows UTF16 format.
|
||||
func convertEnvironmentToUTF16(userEnv []string) (*uint16, error) {
|
||||
if len(userEnv) == 0 {
|
||||
// Return nil pointer for empty environment - Windows API will inherit parent environment
|
||||
return nil, nil //nolint:nilnil // Intentional nil,nil for empty environment
|
||||
}
|
||||
|
||||
var envUTF16 []uint16
|
||||
for _, envVar := range userEnv {
|
||||
if envVar != "" {
|
||||
utf16Str, err := windows.UTF16FromString(envVar)
|
||||
if err != nil {
|
||||
log.Debugf("skipping invalid environment variable: %s (error: %v)", envVar, err)
|
||||
continue
|
||||
}
|
||||
envUTF16 = append(envUTF16, utf16Str[:len(utf16Str)-1]...)
|
||||
envUTF16 = append(envUTF16, 0)
|
||||
}
|
||||
}
|
||||
envUTF16 = append(envUTF16, 0)
|
||||
|
||||
if len(envUTF16) > 0 {
|
||||
return &envUTF16[0], nil
|
||||
}
|
||||
// Return nil pointer when no valid environment variables found
|
||||
return nil, nil //nolint:nilnil // Intentional nil,nil for empty environment
|
||||
}
|
||||
|
||||
// duplicateToPrimaryToken converts an impersonation token to a primary token.
|
||||
func duplicateToPrimaryToken(token windows.Handle) (windows.Handle, error) {
|
||||
var primaryToken windows.Handle
|
||||
if err := windows.DuplicateTokenEx(
|
||||
windows.Token(token),
|
||||
windows.TOKEN_ALL_ACCESS,
|
||||
nil,
|
||||
windows.SecurityImpersonation,
|
||||
windows.TokenPrimary,
|
||||
(*windows.Token)(&primaryToken),
|
||||
); err != nil {
|
||||
return 0, fmt.Errorf("duplicate token: %w", err)
|
||||
}
|
||||
return primaryToken, nil
|
||||
}
|
||||
|
||||
// SessionExiter provides the Exit method for reporting process exit status.
|
||||
type SessionExiter interface {
|
||||
Exit(code int) error
|
||||
}
|
||||
|
||||
// bridgeConPtyIO handles I/O bridging between ConPty and readers/writers.
|
||||
func bridgeConPtyIO(ctx context.Context, hPty, inputWrite, outputRead windows.Handle, reader io.ReadCloser, writer io.Writer, session SessionExiter, process windows.Handle) error {
|
||||
if err := ctx.Err(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
startIOBridging(ctx, &wg, inputWrite, outputRead, reader, writer)
|
||||
|
||||
processErr := waitForProcess(ctx, process)
|
||||
if processErr != nil {
|
||||
return processErr
|
||||
}
|
||||
|
||||
var exitCode uint32
|
||||
if err := windows.GetExitCodeProcess(process, &exitCode); err != nil {
|
||||
log.Debugf("get exit code: %v", err)
|
||||
} else {
|
||||
if err := session.Exit(int(exitCode)); err != nil {
|
||||
log.Debugf("report exit code: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Clean up in the original order after process completes
|
||||
if err := reader.Close(); err != nil {
|
||||
log.Debugf("close reader: %v", err)
|
||||
}
|
||||
|
||||
ret, _, err := procClosePseudoConsole.Call(uintptr(hPty))
|
||||
if ret == 0 {
|
||||
log.Debugf("close ConPty handle: %v", err)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
if err := windows.CloseHandle(outputRead); err != nil {
|
||||
log.Debugf("close output read handle: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// startIOBridging starts the I/O bridging goroutines.
|
||||
func startIOBridging(ctx context.Context, wg *sync.WaitGroup, inputWrite, outputRead windows.Handle, reader io.ReadCloser, writer io.Writer) {
|
||||
wg.Add(2)
|
||||
|
||||
// Input: reader (SSH session) -> inputWrite (ConPty)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
defer func() {
|
||||
if err := windows.CloseHandle(inputWrite); err != nil {
|
||||
log.Debugf("close input write handle in goroutine: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
if _, err := io.Copy(&windowsHandleWriter{handle: inputWrite}, reader); err != nil {
|
||||
log.Debugf("input copy ended with error: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
// Output: outputRead (ConPty) -> writer (SSH session)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
if _, err := io.Copy(writer, &windowsHandleReader{handle: outputRead}); err != nil {
|
||||
log.Debugf("output copy ended with error: %v", err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// waitForProcess waits for process completion with context cancellation.
|
||||
func waitForProcess(ctx context.Context, process windows.Handle) error {
|
||||
if _, err := windows.WaitForSingleObject(process, windows.INFINITE); err != nil {
|
||||
return fmt.Errorf("wait for process %d: %w", process, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// buildShellArgs builds shell arguments for ConPty execution.
|
||||
func buildShellArgs(shell, command string) []string {
|
||||
if command != "" {
|
||||
return []string{shell, PowerShellCommandFlag, command}
|
||||
}
|
||||
return []string{shell}
|
||||
}
|
||||
|
||||
// buildCommandLine builds a Windows command line from arguments using proper escaping.
|
||||
func buildCommandLine(args []string) string {
|
||||
if len(args) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
var result strings.Builder
|
||||
for i, arg := range args {
|
||||
if i > 0 {
|
||||
result.WriteString(" ")
|
||||
}
|
||||
result.WriteString(syscall.EscapeArg(arg))
|
||||
}
|
||||
return result.String()
|
||||
}
|
||||
|
||||
// closeHandles closes multiple Windows handles.
|
||||
func closeHandles(handles ...windows.Handle) {
|
||||
for _, handle := range handles {
|
||||
if handle != windows.InvalidHandle {
|
||||
if err := windows.CloseHandle(handle); err != nil {
|
||||
log.Debugf("close handle: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// closeProcessInfo closes process and thread handles.
|
||||
func closeProcessInfo(pi *windows.ProcessInformation) {
|
||||
if pi != nil {
|
||||
if err := windows.CloseHandle(pi.Process); err != nil {
|
||||
log.Debugf("close process handle: %v", err)
|
||||
}
|
||||
if err := windows.CloseHandle(pi.Thread); err != nil {
|
||||
log.Debugf("close thread handle: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// windowsHandleReader wraps a Windows handle for reading.
|
||||
type windowsHandleReader struct {
|
||||
handle windows.Handle
|
||||
}
|
||||
|
||||
func (r *windowsHandleReader) Read(p []byte) (n int, err error) {
|
||||
var bytesRead uint32
|
||||
if err := windows.ReadFile(r.handle, p, &bytesRead, nil); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return int(bytesRead), nil
|
||||
}
|
||||
|
||||
func (r *windowsHandleReader) Close() error {
|
||||
return windows.CloseHandle(r.handle)
|
||||
}
|
||||
|
||||
// windowsHandleWriter wraps a Windows handle for writing.
|
||||
type windowsHandleWriter struct {
|
||||
handle windows.Handle
|
||||
}
|
||||
|
||||
func (w *windowsHandleWriter) Write(p []byte) (n int, err error) {
|
||||
var bytesWritten uint32
|
||||
if err := windows.WriteFile(w.handle, p, &bytesWritten, nil); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return int(bytesWritten), nil
|
||||
}
|
||||
|
||||
func (w *windowsHandleWriter) Close() error {
|
||||
return windows.CloseHandle(w.handle)
|
||||
}
|
||||
290
client/ssh/server/winpty/conpty_test.go
Normal file
290
client/ssh/server/winpty/conpty_test.go
Normal file
@@ -0,0 +1,290 @@
|
||||
//go:build windows
|
||||
|
||||
package winpty
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/sys/windows"
|
||||
)
|
||||
|
||||
func TestBuildShellArgs(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
shell string
|
||||
command string
|
||||
expected []string
|
||||
}{
|
||||
{
|
||||
name: "Shell with command",
|
||||
shell: "powershell.exe",
|
||||
command: "Get-Process",
|
||||
expected: []string{"powershell.exe", "-Command", "Get-Process"},
|
||||
},
|
||||
{
|
||||
name: "CMD with command",
|
||||
shell: "cmd.exe",
|
||||
command: "dir",
|
||||
expected: []string{"cmd.exe", "-Command", "dir"},
|
||||
},
|
||||
{
|
||||
name: "Shell interactive",
|
||||
shell: "powershell.exe",
|
||||
command: "",
|
||||
expected: []string{"powershell.exe"},
|
||||
},
|
||||
{
|
||||
name: "CMD interactive",
|
||||
shell: "cmd.exe",
|
||||
command: "",
|
||||
expected: []string{"cmd.exe"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := buildShellArgs(tt.shell, tt.command)
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildCommandLine(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
args []string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "Simple args",
|
||||
args: []string{"cmd.exe", "/c", "echo"},
|
||||
expected: "cmd.exe /c echo",
|
||||
},
|
||||
{
|
||||
name: "Args with spaces",
|
||||
args: []string{"Program Files\\app.exe", "arg with spaces"},
|
||||
expected: `"Program Files\app.exe" "arg with spaces"`,
|
||||
},
|
||||
{
|
||||
name: "Args with quotes",
|
||||
args: []string{"cmd.exe", "/c", `echo "hello world"`},
|
||||
expected: `cmd.exe /c "echo \"hello world\""`,
|
||||
},
|
||||
{
|
||||
name: "PowerShell calling PowerShell",
|
||||
args: []string{"powershell.exe", "-Command", `powershell.exe -Command "Get-Process | Where-Object {$_.Name -eq 'notepad'}"`},
|
||||
expected: `powershell.exe -Command "powershell.exe -Command \"Get-Process | Where-Object {$_.Name -eq 'notepad'}\""`,
|
||||
},
|
||||
{
|
||||
name: "Complex nested quotes",
|
||||
args: []string{"cmd.exe", "/c", `echo "He said \"Hello\" to me"`},
|
||||
expected: `cmd.exe /c "echo \"He said \\\"Hello\\\" to me\""`,
|
||||
},
|
||||
{
|
||||
name: "Path with spaces and args",
|
||||
args: []string{`C:\Program Files\MyApp\app.exe`, "--config", `C:\My Config\settings.json`},
|
||||
expected: `"C:\Program Files\MyApp\app.exe" --config "C:\My Config\settings.json"`,
|
||||
},
|
||||
{
|
||||
name: "Empty argument",
|
||||
args: []string{"cmd.exe", "/c", "echo", ""},
|
||||
expected: `cmd.exe /c echo ""`,
|
||||
},
|
||||
{
|
||||
name: "Argument with backslashes",
|
||||
args: []string{"robocopy", `C:\Source\`, `C:\Dest\`, "/E"},
|
||||
expected: `robocopy C:\Source\ C:\Dest\ /E`,
|
||||
},
|
||||
{
|
||||
name: "Empty args",
|
||||
args: []string{},
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "Single arg with space",
|
||||
args: []string{"path with spaces"},
|
||||
expected: `"path with spaces"`,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := buildCommandLine(tt.args)
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateConPtyPipes(t *testing.T) {
|
||||
inputRead, inputWrite, outputRead, outputWrite, err := createConPtyPipes()
|
||||
require.NoError(t, err, "Should create ConPty pipes successfully")
|
||||
|
||||
// Verify all handles are valid
|
||||
assert.NotEqual(t, windows.InvalidHandle, inputRead, "Input read handle should be valid")
|
||||
assert.NotEqual(t, windows.InvalidHandle, inputWrite, "Input write handle should be valid")
|
||||
assert.NotEqual(t, windows.InvalidHandle, outputRead, "Output read handle should be valid")
|
||||
assert.NotEqual(t, windows.InvalidHandle, outputWrite, "Output write handle should be valid")
|
||||
|
||||
// Clean up handles
|
||||
closeHandles(inputRead, inputWrite, outputRead, outputWrite)
|
||||
}
|
||||
|
||||
func TestCreateConPty(t *testing.T) {
|
||||
inputRead, inputWrite, outputRead, outputWrite, err := createConPtyPipes()
|
||||
require.NoError(t, err, "Should create ConPty pipes successfully")
|
||||
defer closeHandles(inputRead, inputWrite, outputRead, outputWrite)
|
||||
|
||||
hPty, err := createConPty(80, 24, inputRead, outputWrite)
|
||||
require.NoError(t, err, "Should create ConPty successfully")
|
||||
assert.NotEqual(t, windows.InvalidHandle, hPty, "ConPty handle should be valid")
|
||||
|
||||
// Clean up ConPty
|
||||
ret, _, _ := procClosePseudoConsole.Call(uintptr(hPty))
|
||||
assert.NotEqual(t, uintptr(0), ret, "Should close ConPty successfully")
|
||||
}
|
||||
|
||||
func TestConvertEnvironmentToUTF16(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
userEnv []string
|
||||
hasError bool
|
||||
}{
|
||||
{
|
||||
name: "Valid environment variables",
|
||||
userEnv: []string{"PATH=C:\\Windows", "USER=testuser", "HOME=C:\\Users\\testuser"},
|
||||
hasError: false,
|
||||
},
|
||||
{
|
||||
name: "Empty environment",
|
||||
userEnv: []string{},
|
||||
hasError: false,
|
||||
},
|
||||
{
|
||||
name: "Environment with empty strings",
|
||||
userEnv: []string{"PATH=C:\\Windows", "", "USER=testuser"},
|
||||
hasError: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result, err := convertEnvironmentToUTF16(tt.userEnv)
|
||||
if tt.hasError {
|
||||
assert.Error(t, err)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
if len(tt.userEnv) == 0 {
|
||||
assert.Nil(t, result, "Empty environment should return nil")
|
||||
} else {
|
||||
assert.NotNil(t, result, "Non-empty environment should return valid pointer")
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDuplicateToPrimaryToken(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("Skipping token tests in short mode")
|
||||
}
|
||||
|
||||
// Get current process token for testing
|
||||
var token windows.Token
|
||||
err := windows.OpenProcessToken(windows.CurrentProcess(), windows.TOKEN_ALL_ACCESS, &token)
|
||||
require.NoError(t, err, "Should open current process token")
|
||||
defer func() {
|
||||
if err := windows.CloseHandle(windows.Handle(token)); err != nil {
|
||||
t.Logf("Failed to close token: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
primaryToken, err := duplicateToPrimaryToken(windows.Handle(token))
|
||||
require.NoError(t, err, "Should duplicate token to primary")
|
||||
assert.NotEqual(t, windows.InvalidHandle, primaryToken, "Primary token should be valid")
|
||||
|
||||
// Clean up
|
||||
err = windows.CloseHandle(primaryToken)
|
||||
assert.NoError(t, err, "Should close primary token")
|
||||
}
|
||||
|
||||
func TestWindowsHandleReader(t *testing.T) {
|
||||
// Create a pipe for testing
|
||||
var readHandle, writeHandle windows.Handle
|
||||
err := windows.CreatePipe(&readHandle, &writeHandle, nil, 0)
|
||||
require.NoError(t, err, "Should create pipe for testing")
|
||||
defer closeHandles(readHandle, writeHandle)
|
||||
|
||||
// Write test data
|
||||
testData := []byte("Hello, Windows Handle Reader!")
|
||||
var bytesWritten uint32
|
||||
err = windows.WriteFile(writeHandle, testData, &bytesWritten, nil)
|
||||
require.NoError(t, err, "Should write test data")
|
||||
require.Equal(t, uint32(len(testData)), bytesWritten, "Should write all test data")
|
||||
|
||||
// Close write handle to signal EOF
|
||||
if err := windows.CloseHandle(writeHandle); err != nil {
|
||||
t.Fatalf("Should close write handle: %v", err)
|
||||
}
|
||||
writeHandle = windows.InvalidHandle
|
||||
|
||||
// Test reading
|
||||
reader := &windowsHandleReader{handle: readHandle}
|
||||
buffer := make([]byte, len(testData))
|
||||
n, err := reader.Read(buffer)
|
||||
require.NoError(t, err, "Should read from handle")
|
||||
assert.Equal(t, len(testData), n, "Should read expected number of bytes")
|
||||
assert.Equal(t, testData, buffer, "Should read expected data")
|
||||
}
|
||||
|
||||
func TestWindowsHandleWriter(t *testing.T) {
|
||||
// Create a pipe for testing
|
||||
var readHandle, writeHandle windows.Handle
|
||||
err := windows.CreatePipe(&readHandle, &writeHandle, nil, 0)
|
||||
require.NoError(t, err, "Should create pipe for testing")
|
||||
defer closeHandles(readHandle, writeHandle)
|
||||
|
||||
// Test writing
|
||||
testData := []byte("Hello, Windows Handle Writer!")
|
||||
writer := &windowsHandleWriter{handle: writeHandle}
|
||||
n, err := writer.Write(testData)
|
||||
require.NoError(t, err, "Should write to handle")
|
||||
assert.Equal(t, len(testData), n, "Should write expected number of bytes")
|
||||
|
||||
// Close write handle
|
||||
if err := windows.CloseHandle(writeHandle); err != nil {
|
||||
t.Fatalf("Should close write handle: %v", err)
|
||||
}
|
||||
|
||||
// Verify data was written by reading it back
|
||||
buffer := make([]byte, len(testData))
|
||||
var bytesRead uint32
|
||||
err = windows.ReadFile(readHandle, buffer, &bytesRead, nil)
|
||||
require.NoError(t, err, "Should read back written data")
|
||||
assert.Equal(t, uint32(len(testData)), bytesRead, "Should read back expected number of bytes")
|
||||
assert.Equal(t, testData, buffer, "Should read back expected data")
|
||||
}
|
||||
|
||||
// BenchmarkConPtyCreation benchmarks ConPty creation performance
|
||||
func BenchmarkConPtyCreation(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
inputRead, inputWrite, outputRead, outputWrite, err := createConPtyPipes()
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
|
||||
hPty, err := createConPty(80, 24, inputRead, outputWrite)
|
||||
if err != nil {
|
||||
closeHandles(inputRead, inputWrite, outputRead, outputWrite)
|
||||
b.Fatal(err)
|
||||
}
|
||||
|
||||
// Clean up
|
||||
if ret, _, err := procClosePseudoConsole.Call(uintptr(hPty)); ret == 0 {
|
||||
log.Debugf("ClosePseudoConsole failed: %v", err)
|
||||
}
|
||||
closeHandles(inputRead, inputWrite, outputRead, outputWrite)
|
||||
}
|
||||
}
|
||||
@@ -1,46 +0,0 @@
|
||||
//go:build !js
|
||||
|
||||
package ssh
|
||||
|
||||
import "context"
|
||||
|
||||
// MockServer mocks ssh.Server
|
||||
type MockServer struct {
|
||||
Ctx context.Context
|
||||
StopFunc func() error
|
||||
StartFunc func() error
|
||||
AddAuthorizedKeyFunc func(peer, newKey string) error
|
||||
RemoveAuthorizedKeyFunc func(peer string)
|
||||
}
|
||||
|
||||
// RemoveAuthorizedKey removes SSH key of a given peer from the authorized keys
|
||||
func (srv *MockServer) RemoveAuthorizedKey(peer string) {
|
||||
if srv.RemoveAuthorizedKeyFunc == nil {
|
||||
return
|
||||
}
|
||||
srv.RemoveAuthorizedKeyFunc(peer)
|
||||
}
|
||||
|
||||
// AddAuthorizedKey add a given peer key to server authorized keys
|
||||
func (srv *MockServer) AddAuthorizedKey(peer, newKey string) error {
|
||||
if srv.AddAuthorizedKeyFunc == nil {
|
||||
return nil
|
||||
}
|
||||
return srv.AddAuthorizedKeyFunc(peer, newKey)
|
||||
}
|
||||
|
||||
// Stop stops SSH server.
|
||||
func (srv *MockServer) Stop() error {
|
||||
if srv.StopFunc == nil {
|
||||
return nil
|
||||
}
|
||||
return srv.StopFunc()
|
||||
}
|
||||
|
||||
// Start starts SSH server. Blocking
|
||||
func (srv *MockServer) Start() error {
|
||||
if srv.StartFunc == nil {
|
||||
return nil
|
||||
}
|
||||
return srv.StartFunc()
|
||||
}
|
||||
@@ -1,123 +0,0 @@
|
||||
//go:build !js
|
||||
|
||||
package ssh
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"golang.org/x/crypto/ssh"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestServer_AddAuthorizedKey(t *testing.T) {
|
||||
key, err := GeneratePrivateKey(ED25519)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
server, err := newDefaultServer(key, "localhost:")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// add multiple keys
|
||||
keys := map[string][]byte{}
|
||||
for i := 0; i < 10; i++ {
|
||||
peer := fmt.Sprintf("%s-%d", "remotePeer", i)
|
||||
remotePrivKey, err := GeneratePrivateKey(ED25519)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
remotePubKey, err := GeneratePublicKey(remotePrivKey)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = server.AddAuthorizedKey(peer, string(remotePubKey))
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
keys[peer] = remotePubKey
|
||||
}
|
||||
|
||||
// make sure that all keys have been added
|
||||
for peer, remotePubKey := range keys {
|
||||
k, ok := server.authorizedKeys[peer]
|
||||
assert.True(t, ok, "expecting remotePeer key to be found in authorizedKeys")
|
||||
|
||||
assert.Equal(t, string(remotePubKey), strings.TrimSpace(string(ssh.MarshalAuthorizedKey(k))))
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func TestServer_RemoveAuthorizedKey(t *testing.T) {
|
||||
key, err := GeneratePrivateKey(ED25519)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
server, err := newDefaultServer(key, "localhost:")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
remotePrivKey, err := GeneratePrivateKey(ED25519)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
remotePubKey, err := GeneratePublicKey(remotePrivKey)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = server.AddAuthorizedKey("remotePeer", string(remotePubKey))
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
server.RemoveAuthorizedKey("remotePeer")
|
||||
|
||||
_, ok := server.authorizedKeys["remotePeer"]
|
||||
assert.False(t, ok, "expecting remotePeer's SSH key to be removed")
|
||||
}
|
||||
|
||||
func TestServer_PubKeyHandler(t *testing.T) {
|
||||
key, err := GeneratePrivateKey(ED25519)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
server, err := newDefaultServer(key, "localhost:")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
var keys []ssh.PublicKey
|
||||
for i := 0; i < 10; i++ {
|
||||
peer := fmt.Sprintf("%s-%d", "remotePeer", i)
|
||||
remotePrivKey, err := GeneratePrivateKey(ED25519)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
remotePubKey, err := GeneratePublicKey(remotePrivKey)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
remoteParsedPubKey, _, _, _, err := ssh.ParseAuthorizedKey(remotePubKey)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = server.AddAuthorizedKey(peer, string(remotePubKey))
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
keys = append(keys, remoteParsedPubKey)
|
||||
}
|
||||
|
||||
for _, key := range keys {
|
||||
accepted := server.publicKeyHandler(nil, key)
|
||||
|
||||
assert.Truef(t, accepted, "expecting SSH connection to be accepted for a given SSH key %s", string(ssh.MarshalAuthorizedKey(key)))
|
||||
}
|
||||
|
||||
}
|
||||
@@ -32,9 +32,8 @@ const RSA KeyType = "rsa"
|
||||
// RSAKeySize is a size of newly generated RSA key
|
||||
const RSAKeySize = 2048
|
||||
|
||||
// GeneratePrivateKey creates RSA Private Key of specified byte size
|
||||
// GeneratePrivateKey creates a private key of the specified type.
|
||||
func GeneratePrivateKey(keyType KeyType) ([]byte, error) {
|
||||
|
||||
var key crypto.Signer
|
||||
var err error
|
||||
switch keyType {
|
||||
@@ -59,7 +58,7 @@ func GeneratePrivateKey(keyType KeyType) ([]byte, error) {
|
||||
return pemBytes, nil
|
||||
}
|
||||
|
||||
// GeneratePublicKey returns the public part of the private key
|
||||
// GeneratePublicKey returns the public part of the private key.
|
||||
func GeneratePublicKey(key []byte) ([]byte, error) {
|
||||
signer, err := gossh.ParsePrivateKey(key)
|
||||
if err != nil {
|
||||
@@ -70,20 +69,17 @@ func GeneratePublicKey(key []byte) ([]byte, error) {
|
||||
return []byte(strKey), nil
|
||||
}
|
||||
|
||||
// EncodePrivateKeyToPEM encodes Private Key from RSA to PEM format
|
||||
// EncodePrivateKeyToPEM encodes a private key to PEM format.
|
||||
func EncodePrivateKeyToPEM(privateKey crypto.Signer) ([]byte, error) {
|
||||
mk, err := x509.MarshalPKCS8PrivateKey(privateKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// pem.Block
|
||||
privBlock := pem.Block{
|
||||
Type: "PRIVATE KEY",
|
||||
Bytes: mk,
|
||||
}
|
||||
|
||||
// Private key in PEM format
|
||||
privatePEM := pem.EncodeToMemory(&privBlock)
|
||||
return privatePEM, nil
|
||||
}
|
||||
172
client/ssh/testutil/user_helpers.go
Normal file
172
client/ssh/testutil/user_helpers.go
Normal file
@@ -0,0 +1,172 @@
|
||||
package testutil
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"os/exec"
|
||||
"os/user"
|
||||
"runtime"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
var testCreatedUsers = make(map[string]bool)
|
||||
var testUsersToCleanup []string
|
||||
|
||||
// GetTestUsername returns an appropriate username for testing
|
||||
func GetTestUsername(t *testing.T) string {
|
||||
if runtime.GOOS == "windows" {
|
||||
currentUser, err := user.Current()
|
||||
require.NoError(t, err, "Should be able to get current user")
|
||||
|
||||
if IsSystemAccount(currentUser.Username) {
|
||||
if IsCI() {
|
||||
if testUser := GetOrCreateTestUser(t); testUser != "" {
|
||||
return testUser
|
||||
}
|
||||
} else {
|
||||
if _, err := user.Lookup("Administrator"); err == nil {
|
||||
return "Administrator"
|
||||
}
|
||||
if testUser := GetOrCreateTestUser(t); testUser != "" {
|
||||
return testUser
|
||||
}
|
||||
}
|
||||
}
|
||||
return currentUser.Username
|
||||
}
|
||||
|
||||
currentUser, err := user.Current()
|
||||
require.NoError(t, err, "Should be able to get current user")
|
||||
return currentUser.Username
|
||||
}
|
||||
|
||||
// IsCI checks if we're running in a CI environment
|
||||
func IsCI() bool {
|
||||
if os.Getenv("GITHUB_ACTIONS") == "true" || os.Getenv("CI") == "true" {
|
||||
return true
|
||||
}
|
||||
|
||||
hostname, err := os.Hostname()
|
||||
if err == nil && strings.HasPrefix(hostname, "runner") {
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// IsSystemAccount checks if the user is a system account that can't authenticate
|
||||
func IsSystemAccount(username string) bool {
|
||||
systemAccounts := []string{
|
||||
"system",
|
||||
"NT AUTHORITY\\SYSTEM",
|
||||
"NT AUTHORITY\\LOCAL SERVICE",
|
||||
"NT AUTHORITY\\NETWORK SERVICE",
|
||||
}
|
||||
|
||||
for _, sysAccount := range systemAccounts {
|
||||
if strings.EqualFold(username, sysAccount) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// RegisterTestUserCleanup registers a test user for cleanup
|
||||
func RegisterTestUserCleanup(username string) {
|
||||
if !testCreatedUsers[username] {
|
||||
testCreatedUsers[username] = true
|
||||
testUsersToCleanup = append(testUsersToCleanup, username)
|
||||
}
|
||||
}
|
||||
|
||||
// CleanupTestUsers removes all created test users
|
||||
func CleanupTestUsers() {
|
||||
for _, username := range testUsersToCleanup {
|
||||
RemoveWindowsTestUser(username)
|
||||
}
|
||||
testUsersToCleanup = nil
|
||||
testCreatedUsers = make(map[string]bool)
|
||||
}
|
||||
|
||||
// GetOrCreateTestUser creates a test user on Windows if needed
|
||||
func GetOrCreateTestUser(t *testing.T) string {
|
||||
testUsername := "netbird-test-user"
|
||||
|
||||
if _, err := user.Lookup(testUsername); err == nil {
|
||||
return testUsername
|
||||
}
|
||||
|
||||
if CreateWindowsTestUser(t, testUsername) {
|
||||
RegisterTestUserCleanup(testUsername)
|
||||
return testUsername
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
// RemoveWindowsTestUser removes a local user on Windows using PowerShell
|
||||
func RemoveWindowsTestUser(username string) {
|
||||
if runtime.GOOS != "windows" {
|
||||
return
|
||||
}
|
||||
|
||||
psCmd := fmt.Sprintf(`
|
||||
try {
|
||||
Remove-LocalUser -Name "%s" -ErrorAction Stop
|
||||
Write-Output "User removed successfully"
|
||||
} catch {
|
||||
if ($_.Exception.Message -like "*cannot be found*") {
|
||||
Write-Output "User not found (already removed)"
|
||||
} else {
|
||||
Write-Error $_.Exception.Message
|
||||
}
|
||||
}
|
||||
`, username)
|
||||
|
||||
cmd := exec.Command("powershell", "-Command", psCmd)
|
||||
output, err := cmd.CombinedOutput()
|
||||
|
||||
if err != nil {
|
||||
log.Printf("Failed to remove test user %s: %v, output: %s", username, err, string(output))
|
||||
} else {
|
||||
log.Printf("Test user %s cleanup result: %s", username, string(output))
|
||||
}
|
||||
}
|
||||
|
||||
// CreateWindowsTestUser creates a local user on Windows using PowerShell
|
||||
func CreateWindowsTestUser(t *testing.T, username string) bool {
|
||||
if runtime.GOOS != "windows" {
|
||||
return false
|
||||
}
|
||||
|
||||
psCmd := fmt.Sprintf(`
|
||||
try {
|
||||
$password = ConvertTo-SecureString "TestPassword123!" -AsPlainText -Force
|
||||
New-LocalUser -Name "%s" -Password $password -Description "NetBird test user" -UserMayNotChangePassword -PasswordNeverExpires
|
||||
Add-LocalGroupMember -Group "Users" -Member "%s"
|
||||
Write-Output "User created successfully"
|
||||
} catch {
|
||||
if ($_.Exception.Message -like "*already exists*") {
|
||||
Write-Output "User already exists"
|
||||
} else {
|
||||
Write-Error $_.Exception.Message
|
||||
exit 1
|
||||
}
|
||||
}
|
||||
`, username, username)
|
||||
|
||||
cmd := exec.Command("powershell", "-Command", psCmd)
|
||||
output, err := cmd.CombinedOutput()
|
||||
|
||||
if err != nil {
|
||||
t.Logf("Failed to create test user: %v, output: %s", err, string(output))
|
||||
return false
|
||||
}
|
||||
|
||||
t.Logf("Test user creation result: %s", string(output))
|
||||
return true
|
||||
}
|
||||
@@ -1,10 +0,0 @@
|
||||
//go:build freebsd
|
||||
|
||||
package ssh
|
||||
|
||||
import (
|
||||
"os"
|
||||
)
|
||||
|
||||
func setWinSize(file *os.File, width, height int) {
|
||||
}
|
||||
@@ -1,14 +0,0 @@
|
||||
//go:build linux || darwin
|
||||
|
||||
package ssh
|
||||
|
||||
import (
|
||||
"os"
|
||||
"syscall"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
func setWinSize(file *os.File, width, height int) {
|
||||
syscall.Syscall(syscall.SYS_IOCTL, file.Fd(), uintptr(syscall.TIOCSWINSZ), //nolint
|
||||
uintptr(unsafe.Pointer(&struct{ h, w, x, y uint16 }{uint16(height), uint16(width), 0, 0})))
|
||||
}
|
||||
@@ -1,9 +0,0 @@
|
||||
package ssh
|
||||
|
||||
import (
|
||||
"os"
|
||||
)
|
||||
|
||||
func setWinSize(file *os.File, width, height int) {
|
||||
|
||||
}
|
||||
@@ -81,6 +81,18 @@ type NsServerGroupStateOutput struct {
|
||||
Error string `json:"error" yaml:"error"`
|
||||
}
|
||||
|
||||
type SSHSessionOutput struct {
|
||||
Username string `json:"username" yaml:"username"`
|
||||
RemoteAddress string `json:"remoteAddress" yaml:"remoteAddress"`
|
||||
Command string `json:"command" yaml:"command"`
|
||||
JWTUsername string `json:"jwtUsername,omitempty" yaml:"jwtUsername,omitempty"`
|
||||
}
|
||||
|
||||
type SSHServerStateOutput struct {
|
||||
Enabled bool `json:"enabled" yaml:"enabled"`
|
||||
Sessions []SSHSessionOutput `json:"sessions" yaml:"sessions"`
|
||||
}
|
||||
|
||||
type OutputOverview struct {
|
||||
Peers PeersStateOutput `json:"peers" yaml:"peers"`
|
||||
CliVersion string `json:"cliVersion" yaml:"cliVersion"`
|
||||
@@ -100,6 +112,7 @@ type OutputOverview struct {
|
||||
Events []SystemEventOutput `json:"events" yaml:"events"`
|
||||
LazyConnectionEnabled bool `json:"lazyConnectionEnabled" yaml:"lazyConnectionEnabled"`
|
||||
ProfileName string `json:"profileName" yaml:"profileName"`
|
||||
SSHServerState SSHServerStateOutput `json:"sshServer" yaml:"sshServer"`
|
||||
}
|
||||
|
||||
func ConvertToStatusOutputOverview(resp *proto.StatusResponse, anon bool, statusFilter string, prefixNamesFilter []string, prefixNamesFilterMap map[string]struct{}, ipsFilter map[string]struct{}, connectionTypeFilter string, profName string) OutputOverview {
|
||||
@@ -121,6 +134,7 @@ func ConvertToStatusOutputOverview(resp *proto.StatusResponse, anon bool, status
|
||||
|
||||
relayOverview := mapRelays(pbFullStatus.GetRelays())
|
||||
peersOverview := mapPeers(resp.GetFullStatus().GetPeers(), statusFilter, prefixNamesFilter, prefixNamesFilterMap, ipsFilter, connectionTypeFilter)
|
||||
sshServerOverview := mapSSHServer(pbFullStatus.GetSshServerState())
|
||||
|
||||
overview := OutputOverview{
|
||||
Peers: peersOverview,
|
||||
@@ -141,6 +155,7 @@ func ConvertToStatusOutputOverview(resp *proto.StatusResponse, anon bool, status
|
||||
Events: mapEvents(pbFullStatus.GetEvents()),
|
||||
LazyConnectionEnabled: pbFullStatus.GetLazyConnectionEnabled(),
|
||||
ProfileName: profName,
|
||||
SSHServerState: sshServerOverview,
|
||||
}
|
||||
|
||||
if anon {
|
||||
@@ -190,6 +205,30 @@ func mapNSGroups(servers []*proto.NSGroupState) []NsServerGroupStateOutput {
|
||||
return mappedNSGroups
|
||||
}
|
||||
|
||||
func mapSSHServer(sshServerState *proto.SSHServerState) SSHServerStateOutput {
|
||||
if sshServerState == nil {
|
||||
return SSHServerStateOutput{
|
||||
Enabled: false,
|
||||
Sessions: []SSHSessionOutput{},
|
||||
}
|
||||
}
|
||||
|
||||
sessions := make([]SSHSessionOutput, 0, len(sshServerState.GetSessions()))
|
||||
for _, session := range sshServerState.GetSessions() {
|
||||
sessions = append(sessions, SSHSessionOutput{
|
||||
Username: session.GetUsername(),
|
||||
RemoteAddress: session.GetRemoteAddress(),
|
||||
Command: session.GetCommand(),
|
||||
JWTUsername: session.GetJwtUsername(),
|
||||
})
|
||||
}
|
||||
|
||||
return SSHServerStateOutput{
|
||||
Enabled: sshServerState.GetEnabled(),
|
||||
Sessions: sessions,
|
||||
}
|
||||
}
|
||||
|
||||
func mapPeers(
|
||||
peers []*proto.PeerState,
|
||||
statusFilter string,
|
||||
@@ -300,7 +339,7 @@ func ParseToYAML(overview OutputOverview) (string, error) {
|
||||
return string(yamlBytes), nil
|
||||
}
|
||||
|
||||
func ParseGeneralSummary(overview OutputOverview, showURL bool, showRelays bool, showNameServers bool) string {
|
||||
func ParseGeneralSummary(overview OutputOverview, showURL bool, showRelays bool, showNameServers bool, showSSHSessions bool) string {
|
||||
var managementConnString string
|
||||
if overview.ManagementState.Connected {
|
||||
managementConnString = "Connected"
|
||||
@@ -405,6 +444,41 @@ func ParseGeneralSummary(overview OutputOverview, showURL bool, showRelays bool,
|
||||
lazyConnectionEnabledStatus = "true"
|
||||
}
|
||||
|
||||
sshServerStatus := "Disabled"
|
||||
if overview.SSHServerState.Enabled {
|
||||
sessionCount := len(overview.SSHServerState.Sessions)
|
||||
if sessionCount > 0 {
|
||||
sessionWord := "session"
|
||||
if sessionCount > 1 {
|
||||
sessionWord = "sessions"
|
||||
}
|
||||
sshServerStatus = fmt.Sprintf("Enabled (%d active %s)", sessionCount, sessionWord)
|
||||
} else {
|
||||
sshServerStatus = "Enabled"
|
||||
}
|
||||
|
||||
if showSSHSessions && sessionCount > 0 {
|
||||
for _, session := range overview.SSHServerState.Sessions {
|
||||
var sessionDisplay string
|
||||
if session.JWTUsername != "" {
|
||||
sessionDisplay = fmt.Sprintf("[%s@%s -> %s] %s",
|
||||
session.JWTUsername,
|
||||
session.RemoteAddress,
|
||||
session.Username,
|
||||
session.Command,
|
||||
)
|
||||
} else {
|
||||
sessionDisplay = fmt.Sprintf("[%s@%s] %s",
|
||||
session.Username,
|
||||
session.RemoteAddress,
|
||||
session.Command,
|
||||
)
|
||||
}
|
||||
sshServerStatus += "\n " + sessionDisplay
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
peersCountString := fmt.Sprintf("%d/%d Connected", overview.Peers.Connected, overview.Peers.Total)
|
||||
|
||||
goos := runtime.GOOS
|
||||
@@ -428,6 +502,7 @@ func ParseGeneralSummary(overview OutputOverview, showURL bool, showRelays bool,
|
||||
"Interface type: %s\n"+
|
||||
"Quantum resistance: %s\n"+
|
||||
"Lazy connection: %s\n"+
|
||||
"SSH Server: %s\n"+
|
||||
"Networks: %s\n"+
|
||||
"Forwarding rules: %d\n"+
|
||||
"Peers count: %s\n",
|
||||
@@ -444,6 +519,7 @@ func ParseGeneralSummary(overview OutputOverview, showURL bool, showRelays bool,
|
||||
interfaceTypeString,
|
||||
rosenpassEnabledStatus,
|
||||
lazyConnectionEnabledStatus,
|
||||
sshServerStatus,
|
||||
networks,
|
||||
overview.NumberOfForwardingRules,
|
||||
peersCountString,
|
||||
@@ -454,7 +530,7 @@ func ParseGeneralSummary(overview OutputOverview, showURL bool, showRelays bool,
|
||||
func ParseToFullDetailSummary(overview OutputOverview) string {
|
||||
parsedPeersString := parsePeers(overview.Peers, overview.RosenpassEnabled, overview.RosenpassPermissive)
|
||||
parsedEventsString := parseEvents(overview.Events)
|
||||
summary := ParseGeneralSummary(overview, true, true, true)
|
||||
summary := ParseGeneralSummary(overview, true, true, true, true)
|
||||
|
||||
return fmt.Sprintf(
|
||||
"Peers detail:"+
|
||||
@@ -746,4 +822,13 @@ func anonymizeOverview(a *anonymize.Anonymizer, overview *OutputOverview) {
|
||||
event.Metadata[k] = a.AnonymizeString(v)
|
||||
}
|
||||
}
|
||||
|
||||
for i, session := range overview.SSHServerState.Sessions {
|
||||
if host, port, err := net.SplitHostPort(session.RemoteAddress); err == nil {
|
||||
overview.SSHServerState.Sessions[i].RemoteAddress = fmt.Sprintf("%s:%s", a.AnonymizeIPString(host), port)
|
||||
} else {
|
||||
overview.SSHServerState.Sessions[i].RemoteAddress = a.AnonymizeIPString(session.RemoteAddress)
|
||||
}
|
||||
overview.SSHServerState.Sessions[i].Command = a.AnonymizeString(session.Command)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -231,6 +231,10 @@ var overview = OutputOverview{
|
||||
Networks: []string{
|
||||
"10.10.0.0/24",
|
||||
},
|
||||
SSHServerState: SSHServerStateOutput{
|
||||
Enabled: false,
|
||||
Sessions: []SSHSessionOutput{},
|
||||
},
|
||||
}
|
||||
|
||||
func TestConversionFromFullStatusToOutputOverview(t *testing.T) {
|
||||
@@ -385,7 +389,11 @@ func TestParsingToJSON(t *testing.T) {
|
||||
],
|
||||
"events": [],
|
||||
"lazyConnectionEnabled": false,
|
||||
"profileName":""
|
||||
"profileName":"",
|
||||
"sshServer":{
|
||||
"enabled":false,
|
||||
"sessions":[]
|
||||
}
|
||||
}`
|
||||
// @formatter:on
|
||||
|
||||
@@ -488,6 +496,9 @@ dnsServers:
|
||||
events: []
|
||||
lazyConnectionEnabled: false
|
||||
profileName: ""
|
||||
sshServer:
|
||||
enabled: false
|
||||
sessions: []
|
||||
`
|
||||
|
||||
assert.Equal(t, expectedYAML, yaml)
|
||||
@@ -554,6 +565,7 @@ NetBird IP: 192.168.178.100/16
|
||||
Interface type: Kernel
|
||||
Quantum resistance: false
|
||||
Lazy connection: false
|
||||
SSH Server: Disabled
|
||||
Networks: 10.10.0.0/24
|
||||
Forwarding rules: 0
|
||||
Peers count: 2/2 Connected
|
||||
@@ -563,7 +575,7 @@ Peers count: 2/2 Connected
|
||||
}
|
||||
|
||||
func TestParsingToShortVersion(t *testing.T) {
|
||||
shortVersion := ParseGeneralSummary(overview, false, false, false)
|
||||
shortVersion := ParseGeneralSummary(overview, false, false, false, false)
|
||||
|
||||
expectedString := fmt.Sprintf("OS: %s/%s", runtime.GOOS, runtime.GOARCH) + `
|
||||
Daemon version: 0.14.1
|
||||
@@ -578,6 +590,7 @@ NetBird IP: 192.168.178.100/16
|
||||
Interface type: Kernel
|
||||
Quantum resistance: false
|
||||
Lazy connection: false
|
||||
SSH Server: Disabled
|
||||
Networks: 10.10.0.0/24
|
||||
Forwarding rules: 0
|
||||
Peers count: 2/2 Connected
|
||||
|
||||
@@ -72,6 +72,12 @@ type Info struct {
|
||||
BlockInbound bool
|
||||
|
||||
LazyConnectionEnabled bool
|
||||
|
||||
EnableSSHRoot bool
|
||||
EnableSSHSFTP bool
|
||||
EnableSSHLocalPortForwarding bool
|
||||
EnableSSHRemotePortForwarding bool
|
||||
DisableSSHAuth bool
|
||||
}
|
||||
|
||||
func (i *Info) SetFlags(
|
||||
@@ -79,6 +85,8 @@ func (i *Info) SetFlags(
|
||||
serverSSHAllowed *bool,
|
||||
disableClientRoutes, disableServerRoutes,
|
||||
disableDNS, disableFirewall, blockLANAccess, blockInbound, lazyConnectionEnabled bool,
|
||||
enableSSHRoot, enableSSHSFTP, enableSSHLocalPortForwarding, enableSSHRemotePortForwarding *bool,
|
||||
disableSSHAuth *bool,
|
||||
) {
|
||||
i.RosenpassEnabled = rosenpassEnabled
|
||||
i.RosenpassPermissive = rosenpassPermissive
|
||||
@@ -94,6 +102,22 @@ func (i *Info) SetFlags(
|
||||
i.BlockInbound = blockInbound
|
||||
|
||||
i.LazyConnectionEnabled = lazyConnectionEnabled
|
||||
|
||||
if enableSSHRoot != nil {
|
||||
i.EnableSSHRoot = *enableSSHRoot
|
||||
}
|
||||
if enableSSHSFTP != nil {
|
||||
i.EnableSSHSFTP = *enableSSHSFTP
|
||||
}
|
||||
if enableSSHLocalPortForwarding != nil {
|
||||
i.EnableSSHLocalPortForwarding = *enableSSHLocalPortForwarding
|
||||
}
|
||||
if enableSSHRemotePortForwarding != nil {
|
||||
i.EnableSSHRemotePortForwarding = *enableSSHRemotePortForwarding
|
||||
}
|
||||
if disableSSHAuth != nil {
|
||||
i.DisableSSHAuth = *disableSSHAuth
|
||||
}
|
||||
}
|
||||
|
||||
// extractUserAgent extracts Netbird's agent (client) name and version from the outgoing context
|
||||
|
||||
@@ -55,6 +55,7 @@ const (
|
||||
|
||||
const (
|
||||
censoredPreSharedKey = "**********"
|
||||
maxSSHJWTCacheTTL = 86_400 // 24 hours in seconds
|
||||
)
|
||||
|
||||
func main() {
|
||||
@@ -265,25 +266,38 @@ type serviceClient struct {
|
||||
iMTU *widget.Entry
|
||||
|
||||
// switch elements for settings form
|
||||
sRosenpassPermissive *widget.Check
|
||||
sNetworkMonitor *widget.Check
|
||||
sDisableDNS *widget.Check
|
||||
sDisableClientRoutes *widget.Check
|
||||
sDisableServerRoutes *widget.Check
|
||||
sBlockLANAccess *widget.Check
|
||||
sRosenpassPermissive *widget.Check
|
||||
sNetworkMonitor *widget.Check
|
||||
sDisableDNS *widget.Check
|
||||
sDisableClientRoutes *widget.Check
|
||||
sDisableServerRoutes *widget.Check
|
||||
sBlockLANAccess *widget.Check
|
||||
sEnableSSHRoot *widget.Check
|
||||
sEnableSSHSFTP *widget.Check
|
||||
sEnableSSHLocalPortForward *widget.Check
|
||||
sEnableSSHRemotePortForward *widget.Check
|
||||
sDisableSSHAuth *widget.Check
|
||||
iSSHJWTCacheTTL *widget.Entry
|
||||
|
||||
// observable settings over corresponding iMngURL and iPreSharedKey values.
|
||||
managementURL string
|
||||
preSharedKey string
|
||||
RosenpassPermissive bool
|
||||
interfaceName string
|
||||
interfacePort int
|
||||
mtu uint16
|
||||
networkMonitor bool
|
||||
disableDNS bool
|
||||
disableClientRoutes bool
|
||||
disableServerRoutes bool
|
||||
blockLANAccess bool
|
||||
managementURL string
|
||||
preSharedKey string
|
||||
|
||||
RosenpassPermissive bool
|
||||
interfaceName string
|
||||
interfacePort int
|
||||
mtu uint16
|
||||
networkMonitor bool
|
||||
disableDNS bool
|
||||
disableClientRoutes bool
|
||||
disableServerRoutes bool
|
||||
blockLANAccess bool
|
||||
enableSSHRoot bool
|
||||
enableSSHSFTP bool
|
||||
enableSSHLocalPortForward bool
|
||||
enableSSHRemotePortForward bool
|
||||
disableSSHAuth bool
|
||||
sshJWTCacheTTL int
|
||||
|
||||
connected bool
|
||||
update *version.Update
|
||||
@@ -435,18 +449,22 @@ func (s *serviceClient) showSettingsUI() {
|
||||
s.sDisableClientRoutes = widget.NewCheck("This peer won't route traffic to other peers", nil)
|
||||
s.sDisableServerRoutes = widget.NewCheck("This peer won't act as router for others", nil)
|
||||
s.sBlockLANAccess = widget.NewCheck("Blocks local network access when used as exit node", nil)
|
||||
s.sEnableSSHRoot = widget.NewCheck("Enable SSH Root Login", nil)
|
||||
s.sEnableSSHSFTP = widget.NewCheck("Enable SSH SFTP", nil)
|
||||
s.sEnableSSHLocalPortForward = widget.NewCheck("Enable SSH Local Port Forwarding", nil)
|
||||
s.sEnableSSHRemotePortForward = widget.NewCheck("Enable SSH Remote Port Forwarding", nil)
|
||||
s.sDisableSSHAuth = widget.NewCheck("Disable SSH Authentication", nil)
|
||||
s.iSSHJWTCacheTTL = widget.NewEntry()
|
||||
|
||||
s.wSettings.SetContent(s.getSettingsForm())
|
||||
s.wSettings.Resize(fyne.NewSize(600, 500))
|
||||
s.wSettings.Resize(fyne.NewSize(600, 400))
|
||||
s.wSettings.SetFixedSize(true)
|
||||
|
||||
s.getSrvConfig()
|
||||
s.wSettings.Show()
|
||||
}
|
||||
|
||||
// getSettingsForm to embed it into settings window.
|
||||
func (s *serviceClient) getSettingsForm() *widget.Form {
|
||||
|
||||
func (s *serviceClient) getConnectionForm() *widget.Form {
|
||||
var activeProfName string
|
||||
activeProf, err := s.profileManager.GetActiveProfile()
|
||||
if err != nil {
|
||||
@@ -457,153 +475,277 @@ func (s *serviceClient) getSettingsForm() *widget.Form {
|
||||
return &widget.Form{
|
||||
Items: []*widget.FormItem{
|
||||
{Text: "Profile", Widget: widget.NewLabel(activeProfName)},
|
||||
{Text: "Management URL", Widget: s.iMngURL},
|
||||
{Text: "Pre-shared Key", Widget: s.iPreSharedKey},
|
||||
{Text: "Quantum-Resistance", Widget: s.sRosenpassPermissive},
|
||||
{Text: "Interface Name", Widget: s.iInterfaceName},
|
||||
{Text: "Interface Port", Widget: s.iInterfacePort},
|
||||
{Text: "MTU", Widget: s.iMTU},
|
||||
{Text: "Management URL", Widget: s.iMngURL},
|
||||
{Text: "Pre-shared Key", Widget: s.iPreSharedKey},
|
||||
{Text: "Log File", Widget: s.iLogFile},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (s *serviceClient) saveSettings() {
|
||||
// Check if update settings are disabled by daemon
|
||||
features, err := s.getFeatures()
|
||||
if err != nil {
|
||||
log.Errorf("failed to get features from daemon: %v", err)
|
||||
// Continue with default behavior if features can't be retrieved
|
||||
} else if features != nil && features.DisableUpdateSettings {
|
||||
log.Warn("Configuration updates are disabled by daemon")
|
||||
dialog.ShowError(fmt.Errorf("Configuration updates are disabled by daemon"), s.wSettings)
|
||||
return
|
||||
}
|
||||
|
||||
if err := s.validateSettings(); err != nil {
|
||||
dialog.ShowError(err, s.wSettings)
|
||||
return
|
||||
}
|
||||
|
||||
port, mtu, err := s.parseNumericSettings()
|
||||
if err != nil {
|
||||
dialog.ShowError(err, s.wSettings)
|
||||
return
|
||||
}
|
||||
|
||||
iMngURL := strings.TrimSpace(s.iMngURL.Text)
|
||||
|
||||
if s.hasSettingsChanged(iMngURL, port, mtu) {
|
||||
if err := s.applySettingsChanges(iMngURL, port, mtu); err != nil {
|
||||
dialog.ShowError(err, s.wSettings)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
s.wSettings.Close()
|
||||
}
|
||||
|
||||
func (s *serviceClient) validateSettings() error {
|
||||
if s.iPreSharedKey.Text != "" && s.iPreSharedKey.Text != censoredPreSharedKey {
|
||||
if _, err := wgtypes.ParseKey(s.iPreSharedKey.Text); err != nil {
|
||||
return fmt.Errorf("Invalid Pre-shared Key Value")
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *serviceClient) parseNumericSettings() (int64, int64, error) {
|
||||
port, err := strconv.ParseInt(s.iInterfacePort.Text, 10, 64)
|
||||
if err != nil {
|
||||
return 0, 0, errors.New("Invalid interface port")
|
||||
}
|
||||
if port < 1 || port > 65535 {
|
||||
return 0, 0, errors.New("Invalid interface port: out of range 1-65535")
|
||||
}
|
||||
|
||||
var mtu int64
|
||||
mtuText := strings.TrimSpace(s.iMTU.Text)
|
||||
if mtuText != "" {
|
||||
mtu, err = strconv.ParseInt(mtuText, 10, 64)
|
||||
if err != nil {
|
||||
return 0, 0, errors.New("Invalid MTU value")
|
||||
}
|
||||
if mtu < iface.MinMTU || mtu > iface.MaxMTU {
|
||||
return 0, 0, fmt.Errorf("MTU must be between %d and %d bytes", iface.MinMTU, iface.MaxMTU)
|
||||
}
|
||||
}
|
||||
|
||||
return port, mtu, nil
|
||||
}
|
||||
|
||||
func (s *serviceClient) hasSettingsChanged(iMngURL string, port, mtu int64) bool {
|
||||
return s.managementURL != iMngURL ||
|
||||
s.preSharedKey != s.iPreSharedKey.Text ||
|
||||
s.RosenpassPermissive != s.sRosenpassPermissive.Checked ||
|
||||
s.interfaceName != s.iInterfaceName.Text ||
|
||||
s.interfacePort != int(port) ||
|
||||
s.mtu != uint16(mtu) ||
|
||||
s.networkMonitor != s.sNetworkMonitor.Checked ||
|
||||
s.disableDNS != s.sDisableDNS.Checked ||
|
||||
s.disableClientRoutes != s.sDisableClientRoutes.Checked ||
|
||||
s.disableServerRoutes != s.sDisableServerRoutes.Checked ||
|
||||
s.blockLANAccess != s.sBlockLANAccess.Checked ||
|
||||
s.hasSSHChanges()
|
||||
}
|
||||
|
||||
func (s *serviceClient) applySettingsChanges(iMngURL string, port, mtu int64) error {
|
||||
s.managementURL = iMngURL
|
||||
s.preSharedKey = s.iPreSharedKey.Text
|
||||
s.mtu = uint16(mtu)
|
||||
|
||||
req, err := s.buildSetConfigRequest(iMngURL, port, mtu)
|
||||
if err != nil {
|
||||
return fmt.Errorf("build config request: %w", err)
|
||||
}
|
||||
|
||||
if err := s.sendConfigUpdate(req); err != nil {
|
||||
return fmt.Errorf("set configuration: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *serviceClient) buildSetConfigRequest(iMngURL string, port, mtu int64) (*proto.SetConfigRequest, error) {
|
||||
currUser, err := user.Current()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get current user: %w", err)
|
||||
}
|
||||
|
||||
activeProf, err := s.profileManager.GetActiveProfile()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get active profile: %w", err)
|
||||
}
|
||||
|
||||
req := &proto.SetConfigRequest{
|
||||
ProfileName: activeProf.Name,
|
||||
Username: currUser.Username,
|
||||
}
|
||||
|
||||
if iMngURL != "" {
|
||||
req.ManagementUrl = iMngURL
|
||||
}
|
||||
|
||||
req.RosenpassPermissive = &s.sRosenpassPermissive.Checked
|
||||
req.InterfaceName = &s.iInterfaceName.Text
|
||||
req.WireguardPort = &port
|
||||
if mtu > 0 {
|
||||
req.Mtu = &mtu
|
||||
}
|
||||
|
||||
req.NetworkMonitor = &s.sNetworkMonitor.Checked
|
||||
req.DisableDns = &s.sDisableDNS.Checked
|
||||
req.DisableClientRoutes = &s.sDisableClientRoutes.Checked
|
||||
req.DisableServerRoutes = &s.sDisableServerRoutes.Checked
|
||||
req.BlockLanAccess = &s.sBlockLANAccess.Checked
|
||||
|
||||
req.EnableSSHRoot = &s.sEnableSSHRoot.Checked
|
||||
req.EnableSSHSFTP = &s.sEnableSSHSFTP.Checked
|
||||
req.EnableSSHLocalPortForwarding = &s.sEnableSSHLocalPortForward.Checked
|
||||
req.EnableSSHRemotePortForwarding = &s.sEnableSSHRemotePortForward.Checked
|
||||
req.DisableSSHAuth = &s.sDisableSSHAuth.Checked
|
||||
|
||||
sshJWTCacheTTLText := strings.TrimSpace(s.iSSHJWTCacheTTL.Text)
|
||||
if sshJWTCacheTTLText != "" {
|
||||
sshJWTCacheTTL, err := strconv.ParseInt(sshJWTCacheTTLText, 10, 32)
|
||||
if err != nil {
|
||||
return nil, errors.New("Invalid SSH JWT Cache TTL value")
|
||||
}
|
||||
if sshJWTCacheTTL < 0 || sshJWTCacheTTL > maxSSHJWTCacheTTL {
|
||||
return nil, fmt.Errorf("SSH JWT Cache TTL must be between 0 and %d seconds", maxSSHJWTCacheTTL)
|
||||
}
|
||||
sshJWTCacheTTL32 := int32(sshJWTCacheTTL)
|
||||
req.SshJWTCacheTTL = &sshJWTCacheTTL32
|
||||
}
|
||||
|
||||
if s.iPreSharedKey.Text != censoredPreSharedKey {
|
||||
req.OptionalPreSharedKey = &s.iPreSharedKey.Text
|
||||
}
|
||||
|
||||
return req, nil
|
||||
}
|
||||
|
||||
func (s *serviceClient) sendConfigUpdate(req *proto.SetConfigRequest) error {
|
||||
conn, err := s.getSrvClient(failFastTimeout)
|
||||
if err != nil {
|
||||
return fmt.Errorf("get client: %w", err)
|
||||
}
|
||||
|
||||
_, err = conn.SetConfig(s.ctx, req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("set config: %w", err)
|
||||
}
|
||||
|
||||
// Reconnect if connected to apply the new settings
|
||||
go func() {
|
||||
status, err := conn.Status(s.ctx, &proto.StatusRequest{})
|
||||
if err != nil {
|
||||
log.Errorf("get service status: %v", err)
|
||||
return
|
||||
}
|
||||
if status.Status == string(internal.StatusConnected) {
|
||||
// run down & up
|
||||
_, err = conn.Down(s.ctx, &proto.DownRequest{})
|
||||
if err != nil {
|
||||
log.Errorf("down service: %v", err)
|
||||
}
|
||||
|
||||
_, err = conn.Up(s.ctx, &proto.UpRequest{})
|
||||
if err != nil {
|
||||
log.Errorf("up service: %v", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *serviceClient) getSettingsForm() fyne.CanvasObject {
|
||||
connectionForm := s.getConnectionForm()
|
||||
networkForm := s.getNetworkForm()
|
||||
sshForm := s.getSSHForm()
|
||||
tabs := container.NewAppTabs(
|
||||
container.NewTabItem("Connection", connectionForm),
|
||||
container.NewTabItem("Network", networkForm),
|
||||
container.NewTabItem("SSH", sshForm),
|
||||
)
|
||||
saveButton := widget.NewButtonWithIcon("Save", theme.ConfirmIcon(), s.saveSettings)
|
||||
saveButton.Importance = widget.HighImportance
|
||||
cancelButton := widget.NewButtonWithIcon("Cancel", theme.CancelIcon(), func() {
|
||||
s.wSettings.Close()
|
||||
})
|
||||
buttonContainer := container.NewHBox(
|
||||
layout.NewSpacer(),
|
||||
cancelButton,
|
||||
saveButton,
|
||||
)
|
||||
return container.NewBorder(nil, buttonContainer, nil, nil, tabs)
|
||||
}
|
||||
|
||||
func (s *serviceClient) getNetworkForm() *widget.Form {
|
||||
return &widget.Form{
|
||||
Items: []*widget.FormItem{
|
||||
{Text: "Network Monitor", Widget: s.sNetworkMonitor},
|
||||
{Text: "Disable DNS", Widget: s.sDisableDNS},
|
||||
{Text: "Disable Client Routes", Widget: s.sDisableClientRoutes},
|
||||
{Text: "Disable Server Routes", Widget: s.sDisableServerRoutes},
|
||||
{Text: "Disable LAN Access", Widget: s.sBlockLANAccess},
|
||||
},
|
||||
SubmitText: "Save",
|
||||
OnSubmit: func() {
|
||||
// Check if update settings are disabled by daemon
|
||||
features, err := s.getFeatures()
|
||||
if err != nil {
|
||||
log.Errorf("failed to get features from daemon: %v", err)
|
||||
// Continue with default behavior if features can't be retrieved
|
||||
} else if features != nil && features.DisableUpdateSettings {
|
||||
log.Warn("Configuration updates are disabled by daemon")
|
||||
dialog.ShowError(fmt.Errorf("Configuration updates are disabled by daemon"), s.wSettings)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if s.iPreSharedKey.Text != "" && s.iPreSharedKey.Text != censoredPreSharedKey {
|
||||
// validate preSharedKey if it added
|
||||
if _, err := wgtypes.ParseKey(s.iPreSharedKey.Text); err != nil {
|
||||
dialog.ShowError(fmt.Errorf("Invalid Pre-shared Key Value"), s.wSettings)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
port, err := strconv.ParseInt(s.iInterfacePort.Text, 10, 64)
|
||||
if err != nil {
|
||||
dialog.ShowError(errors.New("Invalid interface port"), s.wSettings)
|
||||
return
|
||||
}
|
||||
|
||||
var mtu int64
|
||||
mtuText := strings.TrimSpace(s.iMTU.Text)
|
||||
if mtuText != "" {
|
||||
var err error
|
||||
mtu, err = strconv.ParseInt(mtuText, 10, 64)
|
||||
if err != nil {
|
||||
dialog.ShowError(errors.New("Invalid MTU value"), s.wSettings)
|
||||
return
|
||||
}
|
||||
if mtu < iface.MinMTU || mtu > iface.MaxMTU {
|
||||
dialog.ShowError(fmt.Errorf("MTU must be between %d and %d bytes", iface.MinMTU, iface.MaxMTU), s.wSettings)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
iMngURL := strings.TrimSpace(s.iMngURL.Text)
|
||||
|
||||
defer s.wSettings.Close()
|
||||
|
||||
// Check if any settings have changed
|
||||
if s.managementURL != iMngURL || s.preSharedKey != s.iPreSharedKey.Text ||
|
||||
s.RosenpassPermissive != s.sRosenpassPermissive.Checked ||
|
||||
s.interfaceName != s.iInterfaceName.Text || s.interfacePort != int(port) ||
|
||||
s.mtu != uint16(mtu) ||
|
||||
s.networkMonitor != s.sNetworkMonitor.Checked ||
|
||||
s.disableDNS != s.sDisableDNS.Checked ||
|
||||
s.disableClientRoutes != s.sDisableClientRoutes.Checked ||
|
||||
s.disableServerRoutes != s.sDisableServerRoutes.Checked ||
|
||||
s.blockLANAccess != s.sBlockLANAccess.Checked {
|
||||
|
||||
s.managementURL = iMngURL
|
||||
s.preSharedKey = s.iPreSharedKey.Text
|
||||
s.mtu = uint16(mtu)
|
||||
|
||||
currUser, err := user.Current()
|
||||
if err != nil {
|
||||
log.Errorf("get current user: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
var req proto.SetConfigRequest
|
||||
req.ProfileName = activeProf.Name
|
||||
req.Username = currUser.Username
|
||||
|
||||
if iMngURL != "" {
|
||||
req.ManagementUrl = iMngURL
|
||||
}
|
||||
|
||||
req.RosenpassPermissive = &s.sRosenpassPermissive.Checked
|
||||
req.InterfaceName = &s.iInterfaceName.Text
|
||||
req.WireguardPort = &port
|
||||
if mtu > 0 {
|
||||
req.Mtu = &mtu
|
||||
}
|
||||
req.NetworkMonitor = &s.sNetworkMonitor.Checked
|
||||
req.DisableDns = &s.sDisableDNS.Checked
|
||||
req.DisableClientRoutes = &s.sDisableClientRoutes.Checked
|
||||
req.DisableServerRoutes = &s.sDisableServerRoutes.Checked
|
||||
req.BlockLanAccess = &s.sBlockLANAccess.Checked
|
||||
|
||||
if s.iPreSharedKey.Text != censoredPreSharedKey {
|
||||
req.OptionalPreSharedKey = &s.iPreSharedKey.Text
|
||||
}
|
||||
|
||||
conn, err := s.getSrvClient(failFastTimeout)
|
||||
if err != nil {
|
||||
log.Errorf("get client: %v", err)
|
||||
dialog.ShowError(fmt.Errorf("Failed to connect to the service: %v", err), s.wSettings)
|
||||
return
|
||||
}
|
||||
_, err = conn.SetConfig(s.ctx, &req)
|
||||
if err != nil {
|
||||
log.Errorf("set config: %v", err)
|
||||
dialog.ShowError(fmt.Errorf("Failed to set configuration: %v", err), s.wSettings)
|
||||
return
|
||||
}
|
||||
|
||||
go func() {
|
||||
status, err := conn.Status(s.ctx, &proto.StatusRequest{})
|
||||
if err != nil {
|
||||
log.Errorf("get service status: %v", err)
|
||||
dialog.ShowError(fmt.Errorf("Failed to get service status: %v", err), s.wSettings)
|
||||
return
|
||||
}
|
||||
if status.Status == string(internal.StatusConnected) {
|
||||
// run down & up
|
||||
_, err = conn.Down(s.ctx, &proto.DownRequest{})
|
||||
if err != nil {
|
||||
log.Errorf("down service: %v", err)
|
||||
}
|
||||
|
||||
_, err = conn.Up(s.ctx, &proto.UpRequest{})
|
||||
if err != nil {
|
||||
log.Errorf("up service: %v", err)
|
||||
dialog.ShowError(fmt.Errorf("Failed to reconnect: %v", err), s.wSettings)
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
},
|
||||
OnCancel: func() {
|
||||
s.wSettings.Close()
|
||||
func (s *serviceClient) getSSHForm() *widget.Form {
|
||||
return &widget.Form{
|
||||
Items: []*widget.FormItem{
|
||||
{Text: "Enable SSH Root Login", Widget: s.sEnableSSHRoot},
|
||||
{Text: "Enable SSH SFTP", Widget: s.sEnableSSHSFTP},
|
||||
{Text: "Enable SSH Local Port Forwarding", Widget: s.sEnableSSHLocalPortForward},
|
||||
{Text: "Enable SSH Remote Port Forwarding", Widget: s.sEnableSSHRemotePortForward},
|
||||
{Text: "Disable SSH Authentication", Widget: s.sDisableSSHAuth},
|
||||
{Text: "JWT Cache TTL (seconds, 0=disabled)", Widget: s.iSSHJWTCacheTTL},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (s *serviceClient) hasSSHChanges() bool {
|
||||
currentSSHJWTCacheTTL := s.sshJWTCacheTTL
|
||||
if text := strings.TrimSpace(s.iSSHJWTCacheTTL.Text); text != "" {
|
||||
val, err := strconv.Atoi(text)
|
||||
if err != nil {
|
||||
return true
|
||||
}
|
||||
currentSSHJWTCacheTTL = val
|
||||
}
|
||||
|
||||
return s.enableSSHRoot != s.sEnableSSHRoot.Checked ||
|
||||
s.enableSSHSFTP != s.sEnableSSHSFTP.Checked ||
|
||||
s.enableSSHLocalPortForward != s.sEnableSSHLocalPortForward.Checked ||
|
||||
s.enableSSHRemotePortForward != s.sEnableSSHRemotePortForward.Checked ||
|
||||
s.disableSSHAuth != s.sDisableSSHAuth.Checked ||
|
||||
s.sshJWTCacheTTL != currentSSHJWTCacheTTL
|
||||
}
|
||||
|
||||
func (s *serviceClient) login(ctx context.Context, openURL bool) (*proto.LoginResponse, error) {
|
||||
conn, err := s.getSrvClient(defaultFailTimeout)
|
||||
if err != nil {
|
||||
@@ -1123,6 +1265,25 @@ func (s *serviceClient) getSrvConfig() {
|
||||
s.disableServerRoutes = cfg.DisableServerRoutes
|
||||
s.blockLANAccess = cfg.BlockLANAccess
|
||||
|
||||
if cfg.EnableSSHRoot != nil {
|
||||
s.enableSSHRoot = *cfg.EnableSSHRoot
|
||||
}
|
||||
if cfg.EnableSSHSFTP != nil {
|
||||
s.enableSSHSFTP = *cfg.EnableSSHSFTP
|
||||
}
|
||||
if cfg.EnableSSHLocalPortForwarding != nil {
|
||||
s.enableSSHLocalPortForward = *cfg.EnableSSHLocalPortForwarding
|
||||
}
|
||||
if cfg.EnableSSHRemotePortForwarding != nil {
|
||||
s.enableSSHRemotePortForward = *cfg.EnableSSHRemotePortForwarding
|
||||
}
|
||||
if cfg.DisableSSHAuth != nil {
|
||||
s.disableSSHAuth = *cfg.DisableSSHAuth
|
||||
}
|
||||
if cfg.SSHJWTCacheTTL != nil {
|
||||
s.sshJWTCacheTTL = *cfg.SSHJWTCacheTTL
|
||||
}
|
||||
|
||||
if s.showAdvancedSettings {
|
||||
s.iMngURL.SetText(s.managementURL)
|
||||
s.iPreSharedKey.SetText(cfg.PreSharedKey)
|
||||
@@ -1143,6 +1304,24 @@ func (s *serviceClient) getSrvConfig() {
|
||||
s.sDisableClientRoutes.SetChecked(cfg.DisableClientRoutes)
|
||||
s.sDisableServerRoutes.SetChecked(cfg.DisableServerRoutes)
|
||||
s.sBlockLANAccess.SetChecked(cfg.BlockLANAccess)
|
||||
if cfg.EnableSSHRoot != nil {
|
||||
s.sEnableSSHRoot.SetChecked(*cfg.EnableSSHRoot)
|
||||
}
|
||||
if cfg.EnableSSHSFTP != nil {
|
||||
s.sEnableSSHSFTP.SetChecked(*cfg.EnableSSHSFTP)
|
||||
}
|
||||
if cfg.EnableSSHLocalPortForwarding != nil {
|
||||
s.sEnableSSHLocalPortForward.SetChecked(*cfg.EnableSSHLocalPortForwarding)
|
||||
}
|
||||
if cfg.EnableSSHRemotePortForwarding != nil {
|
||||
s.sEnableSSHRemotePortForward.SetChecked(*cfg.EnableSSHRemotePortForwarding)
|
||||
}
|
||||
if cfg.DisableSSHAuth != nil {
|
||||
s.sDisableSSHAuth.SetChecked(*cfg.DisableSSHAuth)
|
||||
}
|
||||
if cfg.SSHJWTCacheTTL != nil {
|
||||
s.iSSHJWTCacheTTL.SetText(strconv.Itoa(*cfg.SSHJWTCacheTTL))
|
||||
}
|
||||
}
|
||||
|
||||
if s.mNotifications == nil {
|
||||
@@ -1213,6 +1392,15 @@ func protoConfigToConfig(cfg *proto.GetConfigResponse) *profilemanager.Config {
|
||||
config.DisableServerRoutes = cfg.DisableServerRoutes
|
||||
config.BlockLANAccess = cfg.BlockLanAccess
|
||||
|
||||
config.EnableSSHRoot = &cfg.EnableSSHRoot
|
||||
config.EnableSSHSFTP = &cfg.EnableSSHSFTP
|
||||
config.EnableSSHLocalPortForwarding = &cfg.EnableSSHLocalPortForwarding
|
||||
config.EnableSSHRemotePortForwarding = &cfg.EnableSSHRemotePortForwarding
|
||||
config.DisableSSHAuth = &cfg.DisableSSHAuth
|
||||
|
||||
ttl := int(cfg.SshJWTCacheTTL)
|
||||
config.SSHJWTCacheTTL = &ttl
|
||||
|
||||
return &config
|
||||
}
|
||||
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user