mirror of
https://github.com/netbirdio/netbird.git
synced 2026-03-31 22:53:53 -04:00
Compare commits
91 Commits
feature/ch
...
move-licen
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
24b66fb406 | ||
|
|
9378b6b0a3 | ||
|
|
3779a3385f | ||
|
|
b5d75ad9c4 | ||
|
|
8db91abfdf | ||
|
|
6f817cad6d | ||
|
|
e3bb8c1b7b | ||
|
|
107066fa3d | ||
|
|
a7a85d4dc8 | ||
|
|
576b4a779c | ||
|
|
e6854dfd99 | ||
|
|
6f14134988 | ||
|
|
4fd64379da | ||
|
|
c20202a6c3 | ||
|
|
4386a21956 | ||
|
|
5882daf5d9 | ||
|
|
11d71e6e22 | ||
|
|
4dadcfd9bd | ||
|
|
34b55c600e | ||
|
|
316c0afa9a | ||
|
|
cf97799db8 | ||
|
|
4d297205c3 | ||
|
|
559f6aeeaf | ||
|
|
7216c201da | ||
|
|
4d89d0f115 | ||
|
|
610c880ec9 | ||
|
|
19adcb5f63 | ||
|
|
f3d31698da | ||
|
|
d9efe4e944 | ||
|
|
7e0bbaaa3c | ||
|
|
b3c7b3c7b2 | ||
|
|
66483ab48d | ||
|
|
5272fc2b18 | ||
|
|
4c53372815 | ||
|
|
79d28b71ee | ||
|
|
77a352763d | ||
|
|
cdd5c6c005 | ||
|
|
b1a9242c98 | ||
|
|
b43ef4f17b | ||
|
|
758a97c352 | ||
|
|
d93b7c2f38 | ||
|
|
fa893aa0a4 | ||
|
|
ac7120871b | ||
|
|
9a7daa132e | ||
|
|
cdded8c22e | ||
|
|
e4e0b8fff9 | ||
|
|
a4b067553d | ||
|
|
088956645f | ||
|
|
aa30b7afe8 | ||
|
|
f1bb4d2ac3 | ||
|
|
982841e25b | ||
|
|
a476b8d12f | ||
|
|
a21f924b26 | ||
|
|
9e51d2e8fb | ||
|
|
3e490d974c | ||
|
|
04bb314426 | ||
|
|
6e15882c11 | ||
|
|
76f9e11b29 | ||
|
|
612de2c784 | ||
|
|
1fdde66c31 | ||
|
|
5970591d24 | ||
|
|
0d5408baec | ||
|
|
96084e3a02 | ||
|
|
4bbca28eb6 | ||
|
|
279b77dee0 | ||
|
|
9d1554f9f7 | ||
|
|
f56075ca15 | ||
|
|
6ed846ae29 | ||
|
|
520f2cfdb4 | ||
|
|
0f79a8942d | ||
|
|
5299e9fda3 | ||
|
|
11bdf5b3a5 | ||
|
|
5fc95d4a0c | ||
|
|
c7884039b8 | ||
|
|
26fc32f1be | ||
|
|
a79cb1c11b | ||
|
|
306d75fe1a | ||
|
|
9468e69c8c | ||
|
|
f51ce7cee5 | ||
|
|
d47c6b624e | ||
|
|
471f90e8db | ||
|
|
1a3b04d2fe | ||
|
|
51b9e93eb9 | ||
|
|
2952669e97 | ||
|
|
7cd44a9a3c | ||
|
|
8684981b57 | ||
|
|
8e94d85d14 | ||
|
|
631b77dc3c | ||
|
|
50ac3d437e | ||
|
|
49bbd90557 | ||
|
|
bb74e903cd |
15
.github/workflows/check-license-dependencies.yml
vendored
15
.github/workflows/check-license-dependencies.yml
vendored
@@ -15,27 +15,28 @@ jobs:
|
||||
- name: Check for problematic license dependencies
|
||||
run: |
|
||||
echo "Checking for dependencies on management/, signal/, and relay/ packages..."
|
||||
echo ""
|
||||
|
||||
# Find all directories except the problematic ones and system dirs
|
||||
FOUND_ISSUES=0
|
||||
find . -maxdepth 1 -type d -not -name "." -not -name "management" -not -name "signal" -not -name "relay" -not -name ".git*" | sort | while read dir; do
|
||||
while IFS= read -r dir; do
|
||||
echo "=== Checking $dir ==="
|
||||
# Search for problematic imports, excluding test files
|
||||
RESULTS=$(grep -r "github.com/netbirdio/netbird/\(management\|signal\|relay\)" "$dir" --include="*.go" | grep -v "_test.go" | grep -v "test_" | grep -v "/test/" || true)
|
||||
if [ ! -z "$RESULTS" ]; then
|
||||
RESULTS=$(grep -r "github.com/netbirdio/netbird/\(management\|signal\|relay\)" "$dir" --include="*.go" 2>/dev/null | grep -v "_test.go" | grep -v "test_" | grep -v "/test/" || true)
|
||||
if [ -n "$RESULTS" ]; then
|
||||
echo "❌ Found problematic dependencies:"
|
||||
echo "$RESULTS"
|
||||
FOUND_ISSUES=1
|
||||
else
|
||||
echo "✓ No problematic dependencies found"
|
||||
fi
|
||||
done
|
||||
done < <(find . -maxdepth 1 -type d -not -name "." -not -name "management" -not -name "signal" -not -name "relay" -not -name ".git*" | sort)
|
||||
|
||||
echo ""
|
||||
if [ $FOUND_ISSUES -eq 1 ]; then
|
||||
echo ""
|
||||
echo "❌ Found dependencies on management/, signal/, or relay/ packages"
|
||||
echo "These packages will change license and should not be imported by client or shared code"
|
||||
echo "These packages are licensed under AGPLv3 and must not be imported by BSD-licensed code"
|
||||
exit 1
|
||||
else
|
||||
echo ""
|
||||
echo "✅ All license dependencies are clean"
|
||||
fi
|
||||
|
||||
@@ -17,9 +17,9 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||
"github.com/netbirdio/netbird/client/net"
|
||||
"github.com/netbirdio/netbird/client/system"
|
||||
"github.com/netbirdio/netbird/formatter"
|
||||
"github.com/netbirdio/netbird/client/net"
|
||||
)
|
||||
|
||||
// ConnectionListener export internal Listener for mobile
|
||||
|
||||
@@ -200,7 +200,7 @@ func (a *Auth) login(urlOpener URLOpener) error {
|
||||
}
|
||||
|
||||
func (a *Auth) foregroundGetTokenInfo(urlOpener URLOpener) (*auth.TokenInfo, error) {
|
||||
oAuthFlow, err := auth.NewOAuthFlow(a.ctx, a.config, false, "")
|
||||
oAuthFlow, err := auth.NewOAuthFlow(a.ctx, a.config, false)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -201,6 +201,94 @@ func (p *Preferences) SetServerSSHAllowed(allowed bool) {
|
||||
p.configInput.ServerSSHAllowed = &allowed
|
||||
}
|
||||
|
||||
// GetEnableSSHRoot reads SSH root login setting from config file
|
||||
func (p *Preferences) GetEnableSSHRoot() (bool, error) {
|
||||
if p.configInput.EnableSSHRoot != nil {
|
||||
return *p.configInput.EnableSSHRoot, nil
|
||||
}
|
||||
|
||||
cfg, err := profilemanager.ReadConfig(p.configInput.ConfigPath)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
if cfg.EnableSSHRoot == nil {
|
||||
// Default to false for security on Android
|
||||
return false, nil
|
||||
}
|
||||
return *cfg.EnableSSHRoot, err
|
||||
}
|
||||
|
||||
// SetEnableSSHRoot stores the given value and waits for commit
|
||||
func (p *Preferences) SetEnableSSHRoot(enabled bool) {
|
||||
p.configInput.EnableSSHRoot = &enabled
|
||||
}
|
||||
|
||||
// GetEnableSSHSFTP reads SSH SFTP setting from config file
|
||||
func (p *Preferences) GetEnableSSHSFTP() (bool, error) {
|
||||
if p.configInput.EnableSSHSFTP != nil {
|
||||
return *p.configInput.EnableSSHSFTP, nil
|
||||
}
|
||||
|
||||
cfg, err := profilemanager.ReadConfig(p.configInput.ConfigPath)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
if cfg.EnableSSHSFTP == nil {
|
||||
// Default to false for security on Android
|
||||
return false, nil
|
||||
}
|
||||
return *cfg.EnableSSHSFTP, err
|
||||
}
|
||||
|
||||
// SetEnableSSHSFTP stores the given value and waits for commit
|
||||
func (p *Preferences) SetEnableSSHSFTP(enabled bool) {
|
||||
p.configInput.EnableSSHSFTP = &enabled
|
||||
}
|
||||
|
||||
// GetEnableSSHLocalPortForwarding reads SSH local port forwarding setting from config file
|
||||
func (p *Preferences) GetEnableSSHLocalPortForwarding() (bool, error) {
|
||||
if p.configInput.EnableSSHLocalPortForwarding != nil {
|
||||
return *p.configInput.EnableSSHLocalPortForwarding, nil
|
||||
}
|
||||
|
||||
cfg, err := profilemanager.ReadConfig(p.configInput.ConfigPath)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
if cfg.EnableSSHLocalPortForwarding == nil {
|
||||
// Default to false for security on Android
|
||||
return false, nil
|
||||
}
|
||||
return *cfg.EnableSSHLocalPortForwarding, err
|
||||
}
|
||||
|
||||
// SetEnableSSHLocalPortForwarding stores the given value and waits for commit
|
||||
func (p *Preferences) SetEnableSSHLocalPortForwarding(enabled bool) {
|
||||
p.configInput.EnableSSHLocalPortForwarding = &enabled
|
||||
}
|
||||
|
||||
// GetEnableSSHRemotePortForwarding reads SSH remote port forwarding setting from config file
|
||||
func (p *Preferences) GetEnableSSHRemotePortForwarding() (bool, error) {
|
||||
if p.configInput.EnableSSHRemotePortForwarding != nil {
|
||||
return *p.configInput.EnableSSHRemotePortForwarding, nil
|
||||
}
|
||||
|
||||
cfg, err := profilemanager.ReadConfig(p.configInput.ConfigPath)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
if cfg.EnableSSHRemotePortForwarding == nil {
|
||||
// Default to false for security on Android
|
||||
return false, nil
|
||||
}
|
||||
return *cfg.EnableSSHRemotePortForwarding, err
|
||||
}
|
||||
|
||||
// SetEnableSSHRemotePortForwarding stores the given value and waits for commit
|
||||
func (p *Preferences) SetEnableSSHRemotePortForwarding(enabled bool) {
|
||||
p.configInput.EnableSSHRemotePortForwarding = &enabled
|
||||
}
|
||||
|
||||
// GetBlockInbound reads block inbound setting from config file
|
||||
func (p *Preferences) GetBlockInbound() (bool, error) {
|
||||
if p.configInput.BlockInbound != nil {
|
||||
|
||||
@@ -106,13 +106,6 @@ func doDaemonLogin(ctx context.Context, cmd *cobra.Command, providedSetupKey str
|
||||
Username: &username,
|
||||
}
|
||||
|
||||
profileState, err := pm.GetProfileState(activeProf.Name)
|
||||
if err != nil {
|
||||
log.Debugf("failed to get profile state for login hint: %v", err)
|
||||
} else if profileState.Email != "" {
|
||||
loginRequest.Hint = &profileState.Email
|
||||
}
|
||||
|
||||
if rootCmd.PersistentFlags().Changed(preSharedKeyFlag) {
|
||||
loginRequest.OptionalPreSharedKey = &preSharedKey
|
||||
}
|
||||
@@ -248,7 +241,7 @@ func doForegroundLogin(ctx context.Context, cmd *cobra.Command, setupKey string,
|
||||
return fmt.Errorf("read config file %s: %v", configFilePath, err)
|
||||
}
|
||||
|
||||
err = foregroundLogin(ctx, cmd, config, setupKey, activeProf.Name)
|
||||
err = foregroundLogin(ctx, cmd, config, setupKey)
|
||||
if err != nil {
|
||||
return fmt.Errorf("foreground login failed: %v", err)
|
||||
}
|
||||
@@ -276,7 +269,7 @@ func handleSSOLogin(ctx context.Context, cmd *cobra.Command, loginResp *proto.Lo
|
||||
return nil
|
||||
}
|
||||
|
||||
func foregroundLogin(ctx context.Context, cmd *cobra.Command, config *profilemanager.Config, setupKey, profileName string) error {
|
||||
func foregroundLogin(ctx context.Context, cmd *cobra.Command, config *profilemanager.Config, setupKey string) error {
|
||||
needsLogin := false
|
||||
|
||||
err := WithBackOff(func() error {
|
||||
@@ -293,7 +286,7 @@ func foregroundLogin(ctx context.Context, cmd *cobra.Command, config *profileman
|
||||
|
||||
jwtToken := ""
|
||||
if setupKey == "" && needsLogin {
|
||||
tokenInfo, err := foregroundGetTokenInfo(ctx, cmd, config, profileName)
|
||||
tokenInfo, err := foregroundGetTokenInfo(ctx, cmd, config)
|
||||
if err != nil {
|
||||
return fmt.Errorf("interactive sso login failed: %v", err)
|
||||
}
|
||||
@@ -322,17 +315,8 @@ func foregroundLogin(ctx context.Context, cmd *cobra.Command, config *profileman
|
||||
return nil
|
||||
}
|
||||
|
||||
func foregroundGetTokenInfo(ctx context.Context, cmd *cobra.Command, config *profilemanager.Config, profileName string) (*auth.TokenInfo, error) {
|
||||
hint := ""
|
||||
pm := profilemanager.NewProfileManager()
|
||||
profileState, err := pm.GetProfileState(profileName)
|
||||
if err != nil {
|
||||
log.Debugf("failed to get profile state for login hint: %v", err)
|
||||
} else if profileState.Email != "" {
|
||||
hint = profileState.Email
|
||||
}
|
||||
|
||||
oAuthFlow, err := auth.NewOAuthFlow(ctx, config, isUnixRunningDesktop(), hint)
|
||||
func foregroundGetTokenInfo(ctx context.Context, cmd *cobra.Command, config *profilemanager.Config) (*auth.TokenInfo, error) {
|
||||
oAuthFlow, err := auth.NewOAuthFlow(ctx, config, isUnixRunningDesktop())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -35,7 +35,6 @@ const (
|
||||
wireguardPortFlag = "wireguard-port"
|
||||
networkMonitorFlag = "network-monitor"
|
||||
disableAutoConnectFlag = "disable-auto-connect"
|
||||
serverSSHAllowedFlag = "allow-server-ssh"
|
||||
extraIFaceBlackListFlag = "extra-iface-blacklist"
|
||||
dnsRouteIntervalFlag = "dns-router-interval"
|
||||
enableLazyConnectionFlag = "enable-lazy-connection"
|
||||
@@ -64,7 +63,6 @@ var (
|
||||
customDNSAddress string
|
||||
rosenpassEnabled bool
|
||||
rosenpassPermissive bool
|
||||
serverSSHAllowed bool
|
||||
interfaceName string
|
||||
wireguardPort uint16
|
||||
networkMonitor bool
|
||||
@@ -176,7 +174,6 @@ func init() {
|
||||
)
|
||||
upCmd.PersistentFlags().BoolVar(&rosenpassEnabled, enableRosenpassFlag, false, "[Experimental] Enable Rosenpass feature. If enabled, the connection will be post-quantum secured via Rosenpass.")
|
||||
upCmd.PersistentFlags().BoolVar(&rosenpassPermissive, rosenpassPermissiveFlag, false, "[Experimental] Enable Rosenpass in permissive mode to allow this peer to accept WireGuard connections without requiring Rosenpass functionality from peers that do not have Rosenpass enabled.")
|
||||
upCmd.PersistentFlags().BoolVar(&serverSSHAllowed, serverSSHAllowedFlag, false, "Allow SSH server on peer. If enabled, the SSH server will be permitted")
|
||||
upCmd.PersistentFlags().BoolVar(&autoConnectDisabled, disableAutoConnectFlag, false, "Disables auto-connect feature. If enabled, then the client won't connect automatically when the service starts.")
|
||||
upCmd.PersistentFlags().BoolVar(&lazyConnEnabled, enableLazyConnectionFlag, false, "[Experimental] Enable the lazy connection feature. If enabled, the client will establish connections on-demand. Note: this setting may be overridden by management configuration.")
|
||||
|
||||
|
||||
@@ -259,7 +259,6 @@ func isServiceRunning() (bool, error) {
|
||||
}
|
||||
|
||||
const (
|
||||
networkdConf = "/etc/systemd/networkd.conf"
|
||||
networkdConfDir = "/etc/systemd/networkd.conf.d"
|
||||
networkdConfFile = "/etc/systemd/networkd.conf.d/99-netbird.conf"
|
||||
networkdConfContent = `# Created by NetBird to prevent systemd-networkd from removing
|
||||
@@ -274,16 +273,12 @@ ManageForeignRoutingPolicyRules=no
|
||||
// configureSystemdNetworkd creates a drop-in configuration file to prevent
|
||||
// systemd-networkd from removing NetBird's routes and policy rules.
|
||||
func configureSystemdNetworkd() error {
|
||||
if _, err := os.Stat(networkdConf); os.IsNotExist(err) {
|
||||
log.Debug("systemd-networkd not in use, skipping configuration")
|
||||
parentDir := filepath.Dir(networkdConfDir)
|
||||
if _, err := os.Stat(parentDir); os.IsNotExist(err) {
|
||||
log.Debug("systemd networkd.conf.d parent directory does not exist, skipping configuration")
|
||||
return nil
|
||||
}
|
||||
|
||||
// nolint:gosec // standard networkd permissions
|
||||
if err := os.MkdirAll(networkdConfDir, 0755); err != nil {
|
||||
return fmt.Errorf("create networkd.conf.d directory: %w", err)
|
||||
}
|
||||
|
||||
// nolint:gosec // standard networkd permissions
|
||||
if err := os.WriteFile(networkdConfFile, []byte(networkdConfContent), 0644); err != nil {
|
||||
return fmt.Errorf("write networkd configuration: %w", err)
|
||||
|
||||
@@ -3,125 +3,757 @@ package cmd
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"flag"
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"os/signal"
|
||||
"os/user"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
"syscall"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||
nbssh "github.com/netbirdio/netbird/client/ssh"
|
||||
sshclient "github.com/netbirdio/netbird/client/ssh/client"
|
||||
"github.com/netbirdio/netbird/client/ssh/detection"
|
||||
sshproxy "github.com/netbirdio/netbird/client/ssh/proxy"
|
||||
sshserver "github.com/netbirdio/netbird/client/ssh/server"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
)
|
||||
|
||||
var (
|
||||
port int
|
||||
userName = "root"
|
||||
host string
|
||||
const (
|
||||
sshUsernameDesc = "SSH username"
|
||||
hostArgumentRequired = "host argument required"
|
||||
|
||||
serverSSHAllowedFlag = "allow-server-ssh"
|
||||
enableSSHRootFlag = "enable-ssh-root"
|
||||
enableSSHSFTPFlag = "enable-ssh-sftp"
|
||||
enableSSHLocalPortForwardFlag = "enable-ssh-local-port-forwarding"
|
||||
enableSSHRemotePortForwardFlag = "enable-ssh-remote-port-forwarding"
|
||||
disableSSHAuthFlag = "disable-ssh-auth"
|
||||
)
|
||||
|
||||
var sshCmd = &cobra.Command{
|
||||
Use: "ssh [user@]host",
|
||||
Args: func(cmd *cobra.Command, args []string) error {
|
||||
if len(args) < 1 {
|
||||
return errors.New("requires a host argument")
|
||||
}
|
||||
var (
|
||||
port int
|
||||
username string
|
||||
host string
|
||||
command string
|
||||
localForwards []string
|
||||
remoteForwards []string
|
||||
strictHostKeyChecking bool
|
||||
knownHostsFile string
|
||||
identityFile string
|
||||
skipCachedToken bool
|
||||
)
|
||||
|
||||
split := strings.Split(args[0], "@")
|
||||
if len(split) == 2 {
|
||||
userName = split[0]
|
||||
host = split[1]
|
||||
} else {
|
||||
host = args[0]
|
||||
}
|
||||
var (
|
||||
serverSSHAllowed bool
|
||||
enableSSHRoot bool
|
||||
enableSSHSFTP bool
|
||||
enableSSHLocalPortForward bool
|
||||
enableSSHRemotePortForward bool
|
||||
disableSSHAuth bool
|
||||
)
|
||||
|
||||
return nil
|
||||
},
|
||||
Short: "Connect to a remote SSH server",
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
SetFlagsFromEnvVars(rootCmd)
|
||||
SetFlagsFromEnvVars(cmd)
|
||||
func init() {
|
||||
upCmd.PersistentFlags().BoolVar(&serverSSHAllowed, serverSSHAllowedFlag, false, "Allow SSH server on peer")
|
||||
upCmd.PersistentFlags().BoolVar(&enableSSHRoot, enableSSHRootFlag, false, "Enable root login for SSH server")
|
||||
upCmd.PersistentFlags().BoolVar(&enableSSHSFTP, enableSSHSFTPFlag, false, "Enable SFTP subsystem for SSH server")
|
||||
upCmd.PersistentFlags().BoolVar(&enableSSHLocalPortForward, enableSSHLocalPortForwardFlag, false, "Enable local port forwarding for SSH server")
|
||||
upCmd.PersistentFlags().BoolVar(&enableSSHRemotePortForward, enableSSHRemotePortForwardFlag, false, "Enable remote port forwarding for SSH server")
|
||||
upCmd.PersistentFlags().BoolVar(&disableSSHAuth, disableSSHAuthFlag, false, "Disable SSH authentication")
|
||||
|
||||
cmd.SetOut(cmd.OutOrStdout())
|
||||
sshCmd.PersistentFlags().IntVarP(&port, "port", "p", sshserver.DefaultSSHPort, "Remote SSH port")
|
||||
sshCmd.PersistentFlags().StringVarP(&username, "user", "u", "", sshUsernameDesc)
|
||||
sshCmd.PersistentFlags().StringVar(&username, "login", "", sshUsernameDesc+" (alias for --user)")
|
||||
sshCmd.PersistentFlags().BoolVar(&strictHostKeyChecking, "strict-host-key-checking", true, "Enable strict host key checking (default: true)")
|
||||
sshCmd.PersistentFlags().StringVarP(&knownHostsFile, "known-hosts", "o", "", "Path to known_hosts file (default: ~/.ssh/known_hosts)")
|
||||
sshCmd.PersistentFlags().StringVarP(&identityFile, "identity", "i", "", "Path to SSH private key file")
|
||||
sshCmd.PersistentFlags().BoolVar(&skipCachedToken, "no-cache", false, "Skip cached JWT token and force fresh authentication")
|
||||
|
||||
err := util.InitLog(logLevel, util.LogConsole)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed initializing log %v", err)
|
||||
}
|
||||
sshCmd.PersistentFlags().StringArrayP("L", "L", []string{}, "Local port forwarding [bind_address:]port:host:hostport")
|
||||
sshCmd.PersistentFlags().StringArrayP("R", "R", []string{}, "Remote port forwarding [bind_address:]port:host:hostport")
|
||||
|
||||
if !util.IsAdmin() {
|
||||
cmd.Printf("error: you must have Administrator privileges to run this command\n")
|
||||
return nil
|
||||
}
|
||||
|
||||
ctx := internal.CtxInitState(cmd.Context())
|
||||
|
||||
sm := profilemanager.NewServiceManager(configPath)
|
||||
activeProf, err := sm.GetActiveProfileState()
|
||||
if err != nil {
|
||||
return fmt.Errorf("get active profile: %v", err)
|
||||
}
|
||||
profPath, err := activeProf.FilePath()
|
||||
if err != nil {
|
||||
return fmt.Errorf("get active profile path: %v", err)
|
||||
}
|
||||
|
||||
config, err := profilemanager.ReadConfig(profPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("read profile config: %v", err)
|
||||
}
|
||||
|
||||
sig := make(chan os.Signal, 1)
|
||||
signal.Notify(sig, syscall.SIGTERM, syscall.SIGINT)
|
||||
sshctx, cancel := context.WithCancel(ctx)
|
||||
|
||||
go func() {
|
||||
// blocking
|
||||
if err := runSSH(sshctx, host, []byte(config.SSHKey), cmd); err != nil {
|
||||
cmd.Printf("Error: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
cancel()
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-sig:
|
||||
cancel()
|
||||
case <-sshctx.Done():
|
||||
}
|
||||
|
||||
return nil
|
||||
},
|
||||
sshCmd.AddCommand(sshSftpCmd)
|
||||
sshCmd.AddCommand(sshProxyCmd)
|
||||
sshCmd.AddCommand(sshDetectCmd)
|
||||
}
|
||||
|
||||
func runSSH(ctx context.Context, addr string, pemKey []byte, cmd *cobra.Command) error {
|
||||
c, err := nbssh.DialWithKey(fmt.Sprintf("%s:%d", addr, port), userName, pemKey)
|
||||
if err != nil {
|
||||
cmd.Printf("Error: %v\n", err)
|
||||
cmd.Printf("Couldn't connect. Please check the connection status or if the ssh server is enabled on the other peer" +
|
||||
"\nYou can verify the connection by running:\n\n" +
|
||||
" netbird status\n\n")
|
||||
return err
|
||||
}
|
||||
go func() {
|
||||
<-ctx.Done()
|
||||
err = c.Close()
|
||||
if err != nil {
|
||||
return
|
||||
var sshCmd = &cobra.Command{
|
||||
Use: "ssh [flags] [user@]host [command]",
|
||||
Short: "Connect to a NetBird peer via SSH",
|
||||
Long: `Connect to a NetBird peer using SSH with support for port forwarding.
|
||||
|
||||
Port Forwarding:
|
||||
-L [bind_address:]port:host:hostport Local port forwarding
|
||||
-L [bind_address:]port:/path/to/socket Local port forwarding to Unix socket
|
||||
-R [bind_address:]port:host:hostport Remote port forwarding
|
||||
-R [bind_address:]port:/path/to/socket Remote port forwarding to Unix socket
|
||||
|
||||
SSH Options:
|
||||
-p, --port int Remote SSH port (default 22)
|
||||
-u, --user string SSH username
|
||||
--login string SSH username (alias for --user)
|
||||
--strict-host-key-checking Enable strict host key checking (default: true)
|
||||
-o, --known-hosts string Path to known_hosts file
|
||||
-i, --identity string Path to SSH private key file
|
||||
|
||||
Examples:
|
||||
netbird ssh peer-hostname
|
||||
netbird ssh root@peer-hostname
|
||||
netbird ssh --login root peer-hostname
|
||||
netbird ssh peer-hostname ls -la
|
||||
netbird ssh peer-hostname whoami
|
||||
netbird ssh -L 8080:localhost:80 peer-hostname # Local port forwarding
|
||||
netbird ssh -R 9090:localhost:3000 peer-hostname # Remote port forwarding
|
||||
netbird ssh -L "*:8080:localhost:80" peer-hostname # Bind to all interfaces
|
||||
netbird ssh -L 8080:/tmp/socket peer-hostname # Unix socket forwarding`,
|
||||
DisableFlagParsing: true,
|
||||
Args: validateSSHArgsWithoutFlagParsing,
|
||||
RunE: sshFn,
|
||||
Aliases: []string{"ssh"},
|
||||
}
|
||||
|
||||
func sshFn(cmd *cobra.Command, args []string) error {
|
||||
for _, arg := range args {
|
||||
if arg == "-h" || arg == "--help" {
|
||||
return cmd.Help()
|
||||
}
|
||||
}
|
||||
|
||||
SetFlagsFromEnvVars(rootCmd)
|
||||
SetFlagsFromEnvVars(cmd)
|
||||
|
||||
cmd.SetOut(cmd.OutOrStdout())
|
||||
|
||||
logOutput := "console"
|
||||
if firstLogFile := util.FindFirstLogPath(logFiles); firstLogFile != "" && firstLogFile != defaultLogFile {
|
||||
logOutput = firstLogFile
|
||||
}
|
||||
if err := util.InitLog(logLevel, logOutput); err != nil {
|
||||
return fmt.Errorf("init log: %w", err)
|
||||
}
|
||||
|
||||
ctx := internal.CtxInitState(cmd.Context())
|
||||
|
||||
sig := make(chan os.Signal, 1)
|
||||
signal.Notify(sig, syscall.SIGTERM, syscall.SIGINT)
|
||||
sshctx, cancel := context.WithCancel(ctx)
|
||||
|
||||
go func() {
|
||||
if err := runSSH(sshctx, host, cmd); err != nil {
|
||||
cmd.Printf("Error: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
cancel()
|
||||
}()
|
||||
|
||||
err = c.OpenTerminal()
|
||||
if err != nil {
|
||||
return err
|
||||
select {
|
||||
case <-sig:
|
||||
cancel()
|
||||
case <-sshctx.Done():
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func init() {
|
||||
sshCmd.PersistentFlags().IntVarP(&port, "port", "p", nbssh.DefaultSSHPort, "Sets remote SSH port. Defaults to "+fmt.Sprint(nbssh.DefaultSSHPort))
|
||||
// getEnvOrDefault checks for environment variables with WT_ and NB_ prefixes
|
||||
func getEnvOrDefault(flagName, defaultValue string) string {
|
||||
if envValue := os.Getenv("WT_" + flagName); envValue != "" {
|
||||
return envValue
|
||||
}
|
||||
if envValue := os.Getenv("NB_" + flagName); envValue != "" {
|
||||
return envValue
|
||||
}
|
||||
return defaultValue
|
||||
}
|
||||
|
||||
// resetSSHGlobals sets SSH globals to their default values
|
||||
func resetSSHGlobals() {
|
||||
port = sshserver.DefaultSSHPort
|
||||
username = ""
|
||||
host = ""
|
||||
command = ""
|
||||
localForwards = nil
|
||||
remoteForwards = nil
|
||||
strictHostKeyChecking = true
|
||||
knownHostsFile = ""
|
||||
identityFile = ""
|
||||
}
|
||||
|
||||
// parseCustomSSHFlags extracts -L, -R flags and returns filtered args
|
||||
func parseCustomSSHFlags(args []string) ([]string, []string, []string) {
|
||||
var localForwardFlags []string
|
||||
var remoteForwardFlags []string
|
||||
var filteredArgs []string
|
||||
|
||||
for i := 0; i < len(args); i++ {
|
||||
arg := args[i]
|
||||
switch {
|
||||
case strings.HasPrefix(arg, "-L"):
|
||||
localForwardFlags, i = parseForwardFlag(arg, args, i, localForwardFlags)
|
||||
case strings.HasPrefix(arg, "-R"):
|
||||
remoteForwardFlags, i = parseForwardFlag(arg, args, i, remoteForwardFlags)
|
||||
default:
|
||||
filteredArgs = append(filteredArgs, arg)
|
||||
}
|
||||
}
|
||||
|
||||
return filteredArgs, localForwardFlags, remoteForwardFlags
|
||||
}
|
||||
|
||||
func parseForwardFlag(arg string, args []string, i int, flags []string) ([]string, int) {
|
||||
if arg == "-L" || arg == "-R" {
|
||||
if i+1 < len(args) {
|
||||
flags = append(flags, args[i+1])
|
||||
i++
|
||||
}
|
||||
} else if len(arg) > 2 {
|
||||
flags = append(flags, arg[2:])
|
||||
}
|
||||
return flags, i
|
||||
}
|
||||
|
||||
// extractGlobalFlags parses global flags that were passed before 'ssh' command
|
||||
func extractGlobalFlags(args []string) {
|
||||
sshPos := findSSHCommandPosition(args)
|
||||
if sshPos == -1 {
|
||||
return
|
||||
}
|
||||
|
||||
globalArgs := args[:sshPos]
|
||||
parseGlobalArgs(globalArgs)
|
||||
}
|
||||
|
||||
// findSSHCommandPosition locates the 'ssh' command in the argument list
|
||||
func findSSHCommandPosition(args []string) int {
|
||||
for i, arg := range args {
|
||||
if arg == "ssh" {
|
||||
return i
|
||||
}
|
||||
}
|
||||
return -1
|
||||
}
|
||||
|
||||
const (
|
||||
configFlag = "config"
|
||||
logLevelFlag = "log-level"
|
||||
logFileFlag = "log-file"
|
||||
)
|
||||
|
||||
// parseGlobalArgs processes the global arguments and sets the corresponding variables
|
||||
func parseGlobalArgs(globalArgs []string) {
|
||||
flagHandlers := map[string]func(string){
|
||||
configFlag: func(value string) { configPath = value },
|
||||
logLevelFlag: func(value string) { logLevel = value },
|
||||
logFileFlag: func(value string) {
|
||||
if !slices.Contains(logFiles, value) {
|
||||
logFiles = append(logFiles, value)
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
shortFlags := map[string]string{
|
||||
"c": configFlag,
|
||||
"l": logLevelFlag,
|
||||
}
|
||||
|
||||
for i := 0; i < len(globalArgs); i++ {
|
||||
arg := globalArgs[i]
|
||||
|
||||
if handled, nextIndex := parseFlag(arg, globalArgs, i, flagHandlers, shortFlags); handled {
|
||||
i = nextIndex
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// parseFlag handles generic flag parsing for both long and short forms
|
||||
func parseFlag(arg string, args []string, currentIndex int, flagHandlers map[string]func(string), shortFlags map[string]string) (bool, int) {
|
||||
if parsedValue, found := parseEqualsFormat(arg, flagHandlers, shortFlags); found {
|
||||
flagHandlers[parsedValue.flagName](parsedValue.value)
|
||||
return true, currentIndex
|
||||
}
|
||||
|
||||
if parsedValue, found := parseSpacedFormat(arg, args, currentIndex, flagHandlers, shortFlags); found {
|
||||
flagHandlers[parsedValue.flagName](parsedValue.value)
|
||||
return true, currentIndex + 1
|
||||
}
|
||||
|
||||
return false, currentIndex
|
||||
}
|
||||
|
||||
type parsedFlag struct {
|
||||
flagName string
|
||||
value string
|
||||
}
|
||||
|
||||
// parseEqualsFormat handles --flag=value and -f=value formats
|
||||
func parseEqualsFormat(arg string, flagHandlers map[string]func(string), shortFlags map[string]string) (parsedFlag, bool) {
|
||||
if !strings.Contains(arg, "=") {
|
||||
return parsedFlag{}, false
|
||||
}
|
||||
|
||||
parts := strings.SplitN(arg, "=", 2)
|
||||
if len(parts) != 2 {
|
||||
return parsedFlag{}, false
|
||||
}
|
||||
|
||||
if strings.HasPrefix(parts[0], "--") {
|
||||
flagName := strings.TrimPrefix(parts[0], "--")
|
||||
if _, exists := flagHandlers[flagName]; exists {
|
||||
return parsedFlag{flagName: flagName, value: parts[1]}, true
|
||||
}
|
||||
}
|
||||
|
||||
if strings.HasPrefix(parts[0], "-") && len(parts[0]) == 2 {
|
||||
shortFlag := strings.TrimPrefix(parts[0], "-")
|
||||
if longFlag, exists := shortFlags[shortFlag]; exists {
|
||||
if _, exists := flagHandlers[longFlag]; exists {
|
||||
return parsedFlag{flagName: longFlag, value: parts[1]}, true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return parsedFlag{}, false
|
||||
}
|
||||
|
||||
// parseSpacedFormat handles --flag value and -f value formats
|
||||
func parseSpacedFormat(arg string, args []string, currentIndex int, flagHandlers map[string]func(string), shortFlags map[string]string) (parsedFlag, bool) {
|
||||
if currentIndex+1 >= len(args) {
|
||||
return parsedFlag{}, false
|
||||
}
|
||||
|
||||
if strings.HasPrefix(arg, "--") {
|
||||
flagName := strings.TrimPrefix(arg, "--")
|
||||
if _, exists := flagHandlers[flagName]; exists {
|
||||
return parsedFlag{flagName: flagName, value: args[currentIndex+1]}, true
|
||||
}
|
||||
}
|
||||
|
||||
if strings.HasPrefix(arg, "-") && len(arg) == 2 {
|
||||
shortFlag := strings.TrimPrefix(arg, "-")
|
||||
if longFlag, exists := shortFlags[shortFlag]; exists {
|
||||
if _, exists := flagHandlers[longFlag]; exists {
|
||||
return parsedFlag{flagName: longFlag, value: args[currentIndex+1]}, true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return parsedFlag{}, false
|
||||
}
|
||||
|
||||
// createSSHFlagSet creates and configures the flag set for SSH command parsing
|
||||
// sshFlags contains all SSH-related flags and parameters
|
||||
type sshFlags struct {
|
||||
Port int
|
||||
Username string
|
||||
Login string
|
||||
StrictHostKeyChecking bool
|
||||
KnownHostsFile string
|
||||
IdentityFile string
|
||||
SkipCachedToken bool
|
||||
ConfigPath string
|
||||
LogLevel string
|
||||
LocalForwards []string
|
||||
RemoteForwards []string
|
||||
Host string
|
||||
Command string
|
||||
}
|
||||
|
||||
func createSSHFlagSet() (*flag.FlagSet, *sshFlags) {
|
||||
defaultConfigPath := getEnvOrDefault("CONFIG", configPath)
|
||||
defaultLogLevel := getEnvOrDefault("LOG_LEVEL", logLevel)
|
||||
|
||||
fs := flag.NewFlagSet("ssh-flags", flag.ContinueOnError)
|
||||
fs.SetOutput(nil)
|
||||
|
||||
flags := &sshFlags{}
|
||||
|
||||
fs.IntVar(&flags.Port, "p", sshserver.DefaultSSHPort, "SSH port")
|
||||
fs.Int("port", sshserver.DefaultSSHPort, "SSH port")
|
||||
fs.StringVar(&flags.Username, "u", "", sshUsernameDesc)
|
||||
fs.String("user", "", sshUsernameDesc)
|
||||
fs.StringVar(&flags.Login, "login", "", sshUsernameDesc+" (alias for --user)")
|
||||
|
||||
fs.BoolVar(&flags.StrictHostKeyChecking, "strict-host-key-checking", true, "Enable strict host key checking")
|
||||
fs.StringVar(&flags.KnownHostsFile, "o", "", "Path to known_hosts file")
|
||||
fs.String("known-hosts", "", "Path to known_hosts file")
|
||||
fs.StringVar(&flags.IdentityFile, "i", "", "Path to SSH private key file")
|
||||
fs.String("identity", "", "Path to SSH private key file")
|
||||
fs.BoolVar(&flags.SkipCachedToken, "no-cache", false, "Skip cached JWT token and force fresh authentication")
|
||||
|
||||
fs.StringVar(&flags.ConfigPath, "c", defaultConfigPath, "Netbird config file location")
|
||||
fs.String("config", defaultConfigPath, "Netbird config file location")
|
||||
fs.StringVar(&flags.LogLevel, "l", defaultLogLevel, "sets Netbird log level")
|
||||
fs.String("log-level", defaultLogLevel, "sets Netbird log level")
|
||||
|
||||
return fs, flags
|
||||
}
|
||||
|
||||
func validateSSHArgsWithoutFlagParsing(_ *cobra.Command, args []string) error {
|
||||
if len(args) < 1 {
|
||||
return errors.New(hostArgumentRequired)
|
||||
}
|
||||
|
||||
resetSSHGlobals()
|
||||
|
||||
if len(os.Args) > 2 {
|
||||
extractGlobalFlags(os.Args[1:])
|
||||
}
|
||||
|
||||
filteredArgs, localForwardFlags, remoteForwardFlags := parseCustomSSHFlags(args)
|
||||
|
||||
fs, flags := createSSHFlagSet()
|
||||
|
||||
if err := fs.Parse(filteredArgs); err != nil {
|
||||
return parseHostnameAndCommand(filteredArgs)
|
||||
}
|
||||
|
||||
remaining := fs.Args()
|
||||
if len(remaining) < 1 {
|
||||
return errors.New(hostArgumentRequired)
|
||||
}
|
||||
|
||||
port = flags.Port
|
||||
if flags.Username != "" {
|
||||
username = flags.Username
|
||||
} else if flags.Login != "" {
|
||||
username = flags.Login
|
||||
}
|
||||
|
||||
strictHostKeyChecking = flags.StrictHostKeyChecking
|
||||
knownHostsFile = flags.KnownHostsFile
|
||||
identityFile = flags.IdentityFile
|
||||
skipCachedToken = flags.SkipCachedToken
|
||||
|
||||
if flags.ConfigPath != getEnvOrDefault("CONFIG", configPath) {
|
||||
configPath = flags.ConfigPath
|
||||
}
|
||||
if flags.LogLevel != getEnvOrDefault("LOG_LEVEL", logLevel) {
|
||||
logLevel = flags.LogLevel
|
||||
}
|
||||
|
||||
localForwards = localForwardFlags
|
||||
remoteForwards = remoteForwardFlags
|
||||
|
||||
return parseHostnameAndCommand(remaining)
|
||||
}
|
||||
|
||||
func parseHostnameAndCommand(args []string) error {
|
||||
if len(args) < 1 {
|
||||
return errors.New(hostArgumentRequired)
|
||||
}
|
||||
|
||||
arg := args[0]
|
||||
if strings.Contains(arg, "@") {
|
||||
parts := strings.SplitN(arg, "@", 2)
|
||||
if len(parts) != 2 || parts[0] == "" || parts[1] == "" {
|
||||
return errors.New("invalid user@host format")
|
||||
}
|
||||
if username == "" {
|
||||
username = parts[0]
|
||||
}
|
||||
host = parts[1]
|
||||
} else {
|
||||
host = arg
|
||||
}
|
||||
|
||||
if username == "" {
|
||||
if sudoUser := os.Getenv("SUDO_USER"); sudoUser != "" {
|
||||
username = sudoUser
|
||||
} else if currentUser, err := user.Current(); err == nil {
|
||||
username = currentUser.Username
|
||||
} else {
|
||||
username = "root"
|
||||
}
|
||||
}
|
||||
|
||||
// Everything after hostname becomes the command
|
||||
if len(args) > 1 {
|
||||
command = strings.Join(args[1:], " ")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func runSSH(ctx context.Context, addr string, cmd *cobra.Command) error {
|
||||
target := fmt.Sprintf("%s:%d", addr, port)
|
||||
c, err := sshclient.Dial(ctx, target, username, sshclient.DialOptions{
|
||||
KnownHostsFile: knownHostsFile,
|
||||
IdentityFile: identityFile,
|
||||
DaemonAddr: daemonAddr,
|
||||
SkipCachedToken: skipCachedToken,
|
||||
InsecureSkipVerify: !strictHostKeyChecking,
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
cmd.Printf("Failed to connect to %s@%s\n", username, target)
|
||||
cmd.Printf("\nTroubleshooting steps:\n")
|
||||
cmd.Printf(" 1. Check peer connectivity: netbird status -d\n")
|
||||
cmd.Printf(" 2. Verify SSH server is enabled on the peer\n")
|
||||
cmd.Printf(" 3. Ensure correct hostname/IP is used\n")
|
||||
return fmt.Errorf("dial %s: %w", target, err)
|
||||
}
|
||||
|
||||
sshCtx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
|
||||
go func() {
|
||||
<-sshCtx.Done()
|
||||
if err := c.Close(); err != nil {
|
||||
cmd.Printf("Error closing SSH connection: %v\n", err)
|
||||
}
|
||||
}()
|
||||
|
||||
if err := startPortForwarding(sshCtx, c, cmd); err != nil {
|
||||
return fmt.Errorf("start port forwarding: %w", err)
|
||||
}
|
||||
|
||||
if command != "" {
|
||||
return executeSSHCommand(sshCtx, c, command)
|
||||
}
|
||||
return openSSHTerminal(sshCtx, c)
|
||||
}
|
||||
|
||||
// executeSSHCommand executes a command over SSH.
|
||||
func executeSSHCommand(ctx context.Context, c *sshclient.Client, command string) error {
|
||||
if err := c.ExecuteCommandWithIO(ctx, command); err != nil {
|
||||
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("execute command: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// openSSHTerminal opens an interactive SSH terminal.
|
||||
func openSSHTerminal(ctx context.Context, c *sshclient.Client) error {
|
||||
if err := c.OpenTerminal(ctx); err != nil {
|
||||
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("open terminal: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// startPortForwarding starts local and remote port forwarding based on command line flags
|
||||
func startPortForwarding(ctx context.Context, c *sshclient.Client, cmd *cobra.Command) error {
|
||||
for _, forward := range localForwards {
|
||||
if err := parseAndStartLocalForward(ctx, c, forward, cmd); err != nil {
|
||||
return fmt.Errorf("local port forward %s: %w", forward, err)
|
||||
}
|
||||
}
|
||||
|
||||
for _, forward := range remoteForwards {
|
||||
if err := parseAndStartRemoteForward(ctx, c, forward, cmd); err != nil {
|
||||
return fmt.Errorf("remote port forward %s: %w", forward, err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// parseAndStartLocalForward parses and starts a local port forward (-L)
|
||||
func parseAndStartLocalForward(ctx context.Context, c *sshclient.Client, forward string, cmd *cobra.Command) error {
|
||||
localAddr, remoteAddr, err := parsePortForwardSpec(forward)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
cmd.Printf("Local port forwarding: %s -> %s\n", localAddr, remoteAddr)
|
||||
|
||||
go func() {
|
||||
if err := c.LocalPortForward(ctx, localAddr, remoteAddr); err != nil && !errors.Is(err, context.Canceled) {
|
||||
cmd.Printf("Local port forward error: %v\n", err)
|
||||
}
|
||||
}()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// parseAndStartRemoteForward parses and starts a remote port forward (-R)
|
||||
func parseAndStartRemoteForward(ctx context.Context, c *sshclient.Client, forward string, cmd *cobra.Command) error {
|
||||
remoteAddr, localAddr, err := parsePortForwardSpec(forward)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
cmd.Printf("Remote port forwarding: %s -> %s\n", remoteAddr, localAddr)
|
||||
|
||||
go func() {
|
||||
if err := c.RemotePortForward(ctx, remoteAddr, localAddr); err != nil && !errors.Is(err, context.Canceled) {
|
||||
cmd.Printf("Remote port forward error: %v\n", err)
|
||||
}
|
||||
}()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// parsePortForwardSpec parses port forward specifications like "8080:localhost:80" or "[::1]:8080:localhost:80".
|
||||
// Also supports Unix sockets like "8080:/tmp/socket" or "127.0.0.1:8080:/tmp/socket".
|
||||
func parsePortForwardSpec(spec string) (string, string, error) {
|
||||
// Support formats:
|
||||
// port:host:hostport -> localhost:port -> host:hostport
|
||||
// host:port:host:hostport -> host:port -> host:hostport
|
||||
// [host]:port:host:hostport -> [host]:port -> host:hostport
|
||||
// port:unix_socket_path -> localhost:port -> unix_socket_path
|
||||
// host:port:unix_socket_path -> host:port -> unix_socket_path
|
||||
|
||||
if strings.HasPrefix(spec, "[") && strings.Contains(spec, "]:") {
|
||||
return parseIPv6ForwardSpec(spec)
|
||||
}
|
||||
|
||||
parts := strings.Split(spec, ":")
|
||||
if len(parts) < 2 {
|
||||
return "", "", fmt.Errorf("invalid port forward specification: %s (expected format: [local_host:]local_port:remote_target)", spec)
|
||||
}
|
||||
|
||||
switch len(parts) {
|
||||
case 2:
|
||||
return parseTwoPartForwardSpec(parts, spec)
|
||||
case 3:
|
||||
return parseThreePartForwardSpec(parts)
|
||||
case 4:
|
||||
return parseFourPartForwardSpec(parts)
|
||||
default:
|
||||
return "", "", fmt.Errorf("invalid port forward specification: %s", spec)
|
||||
}
|
||||
}
|
||||
|
||||
// parseTwoPartForwardSpec handles "port:unix_socket" format.
|
||||
func parseTwoPartForwardSpec(parts []string, spec string) (string, string, error) {
|
||||
if isUnixSocket(parts[1]) {
|
||||
localAddr := "localhost:" + parts[0]
|
||||
remoteAddr := parts[1]
|
||||
return localAddr, remoteAddr, nil
|
||||
}
|
||||
return "", "", fmt.Errorf("invalid port forward specification: %s (expected format: [local_host:]local_port:remote_host:remote_port or [local_host:]local_port:unix_socket)", spec)
|
||||
}
|
||||
|
||||
// parseThreePartForwardSpec handles "port:host:hostport" or "host:port:unix_socket" formats.
|
||||
func parseThreePartForwardSpec(parts []string) (string, string, error) {
|
||||
if isUnixSocket(parts[2]) {
|
||||
localHost := normalizeLocalHost(parts[0])
|
||||
localAddr := localHost + ":" + parts[1]
|
||||
remoteAddr := parts[2]
|
||||
return localAddr, remoteAddr, nil
|
||||
}
|
||||
localAddr := "localhost:" + parts[0]
|
||||
remoteAddr := parts[1] + ":" + parts[2]
|
||||
return localAddr, remoteAddr, nil
|
||||
}
|
||||
|
||||
// parseFourPartForwardSpec handles "host:port:host:hostport" format.
|
||||
func parseFourPartForwardSpec(parts []string) (string, string, error) {
|
||||
localHost := normalizeLocalHost(parts[0])
|
||||
localAddr := localHost + ":" + parts[1]
|
||||
remoteAddr := parts[2] + ":" + parts[3]
|
||||
return localAddr, remoteAddr, nil
|
||||
}
|
||||
|
||||
// parseIPv6ForwardSpec handles "[host]:port:host:hostport" format.
|
||||
func parseIPv6ForwardSpec(spec string) (string, string, error) {
|
||||
idx := strings.Index(spec, "]:")
|
||||
if idx == -1 {
|
||||
return "", "", fmt.Errorf("invalid IPv6 port forward specification: %s", spec)
|
||||
}
|
||||
|
||||
ipv6Host := spec[:idx+1]
|
||||
remaining := spec[idx+2:]
|
||||
|
||||
parts := strings.Split(remaining, ":")
|
||||
if len(parts) != 3 {
|
||||
return "", "", fmt.Errorf("invalid IPv6 port forward specification: %s (expected [ipv6]:port:host:hostport)", spec)
|
||||
}
|
||||
|
||||
localAddr := ipv6Host + ":" + parts[0]
|
||||
remoteAddr := parts[1] + ":" + parts[2]
|
||||
return localAddr, remoteAddr, nil
|
||||
}
|
||||
|
||||
// isUnixSocket checks if a path is a Unix socket path.
|
||||
func isUnixSocket(path string) bool {
|
||||
return strings.HasPrefix(path, "/") || strings.HasPrefix(path, "./")
|
||||
}
|
||||
|
||||
// normalizeLocalHost converts "*" to "0.0.0.0" for binding to all interfaces.
|
||||
func normalizeLocalHost(host string) string {
|
||||
if host == "*" {
|
||||
return "0.0.0.0"
|
||||
}
|
||||
return host
|
||||
}
|
||||
|
||||
var sshProxyCmd = &cobra.Command{
|
||||
Use: "proxy <host> <port>",
|
||||
Short: "Internal SSH proxy for native SSH client integration",
|
||||
Long: "Internal command used by SSH ProxyCommand to handle JWT authentication",
|
||||
Hidden: true,
|
||||
Args: cobra.ExactArgs(2),
|
||||
RunE: sshProxyFn,
|
||||
}
|
||||
|
||||
func sshProxyFn(cmd *cobra.Command, args []string) error {
|
||||
logOutput := "console"
|
||||
if firstLogFile := util.FindFirstLogPath(logFiles); firstLogFile != "" && firstLogFile != defaultLogFile {
|
||||
logOutput = firstLogFile
|
||||
}
|
||||
if err := util.InitLog(logLevel, logOutput); err != nil {
|
||||
return fmt.Errorf("init log: %w", err)
|
||||
}
|
||||
|
||||
host := args[0]
|
||||
portStr := args[1]
|
||||
|
||||
port, err := strconv.Atoi(portStr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid port: %s", portStr)
|
||||
}
|
||||
|
||||
proxy, err := sshproxy.New(daemonAddr, host, port, cmd.ErrOrStderr())
|
||||
if err != nil {
|
||||
return fmt.Errorf("create SSH proxy: %w", err)
|
||||
}
|
||||
|
||||
if err := proxy.Connect(cmd.Context()); err != nil {
|
||||
return fmt.Errorf("SSH proxy: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
var sshDetectCmd = &cobra.Command{
|
||||
Use: "detect <host> <port>",
|
||||
Short: "Detect if a host is running NetBird SSH",
|
||||
Long: "Internal command used by SSH Match exec to detect NetBird SSH servers. Exit codes: 0=JWT, 1=no-JWT, 2=regular SSH",
|
||||
Hidden: true,
|
||||
Args: cobra.ExactArgs(2),
|
||||
RunE: sshDetectFn,
|
||||
}
|
||||
|
||||
func sshDetectFn(cmd *cobra.Command, args []string) error {
|
||||
if err := util.InitLog(logLevel, "console"); err != nil {
|
||||
os.Exit(detection.ServerTypeRegular.ExitCode())
|
||||
}
|
||||
|
||||
host := args[0]
|
||||
portStr := args[1]
|
||||
|
||||
port, err := strconv.Atoi(portStr)
|
||||
if err != nil {
|
||||
os.Exit(detection.ServerTypeRegular.ExitCode())
|
||||
}
|
||||
|
||||
dialer := &net.Dialer{Timeout: detection.Timeout}
|
||||
serverType, err := detection.DetectSSHServerType(cmd.Context(), dialer, host, port)
|
||||
if err != nil {
|
||||
os.Exit(detection.ServerTypeRegular.ExitCode())
|
||||
}
|
||||
|
||||
os.Exit(serverType.ExitCode())
|
||||
return nil
|
||||
}
|
||||
|
||||
74
client/cmd/ssh_exec_unix.go
Normal file
74
client/cmd/ssh_exec_unix.go
Normal file
@@ -0,0 +1,74 @@
|
||||
//go:build unix
|
||||
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
|
||||
sshserver "github.com/netbirdio/netbird/client/ssh/server"
|
||||
)
|
||||
|
||||
var (
|
||||
sshExecUID uint32
|
||||
sshExecGID uint32
|
||||
sshExecGroups []uint
|
||||
sshExecWorkingDir string
|
||||
sshExecShell string
|
||||
sshExecCommand string
|
||||
sshExecPTY bool
|
||||
)
|
||||
|
||||
// sshExecCmd represents the hidden ssh exec subcommand for privilege dropping
|
||||
var sshExecCmd = &cobra.Command{
|
||||
Use: "exec",
|
||||
Short: "Internal SSH execution with privilege dropping (hidden)",
|
||||
Hidden: true,
|
||||
RunE: runSSHExec,
|
||||
}
|
||||
|
||||
func init() {
|
||||
sshExecCmd.Flags().Uint32Var(&sshExecUID, "uid", 0, "Target user ID")
|
||||
sshExecCmd.Flags().Uint32Var(&sshExecGID, "gid", 0, "Target group ID")
|
||||
sshExecCmd.Flags().UintSliceVar(&sshExecGroups, "groups", nil, "Supplementary group IDs (can be repeated)")
|
||||
sshExecCmd.Flags().StringVar(&sshExecWorkingDir, "working-dir", "", "Working directory")
|
||||
sshExecCmd.Flags().StringVar(&sshExecShell, "shell", "/bin/sh", "Shell to execute")
|
||||
sshExecCmd.Flags().BoolVar(&sshExecPTY, "pty", false, "Request PTY (will fail as executor doesn't support PTY)")
|
||||
sshExecCmd.Flags().StringVar(&sshExecCommand, "cmd", "", "Command to execute")
|
||||
|
||||
if err := sshExecCmd.MarkFlagRequired("uid"); err != nil {
|
||||
_, _ = fmt.Fprintf(os.Stderr, "failed to mark uid flag as required: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
if err := sshExecCmd.MarkFlagRequired("gid"); err != nil {
|
||||
_, _ = fmt.Fprintf(os.Stderr, "failed to mark gid flag as required: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
sshCmd.AddCommand(sshExecCmd)
|
||||
}
|
||||
|
||||
// runSSHExec handles the SSH exec subcommand execution.
|
||||
func runSSHExec(cmd *cobra.Command, _ []string) error {
|
||||
privilegeDropper := sshserver.NewPrivilegeDropper()
|
||||
|
||||
var groups []uint32
|
||||
for _, groupInt := range sshExecGroups {
|
||||
groups = append(groups, uint32(groupInt))
|
||||
}
|
||||
|
||||
config := sshserver.ExecutorConfig{
|
||||
UID: sshExecUID,
|
||||
GID: sshExecGID,
|
||||
Groups: groups,
|
||||
WorkingDir: sshExecWorkingDir,
|
||||
Shell: sshExecShell,
|
||||
Command: sshExecCommand,
|
||||
PTY: sshExecPTY,
|
||||
}
|
||||
|
||||
privilegeDropper.ExecuteWithPrivilegeDrop(cmd.Context(), config)
|
||||
return nil
|
||||
}
|
||||
94
client/cmd/ssh_sftp_unix.go
Normal file
94
client/cmd/ssh_sftp_unix.go
Normal file
@@ -0,0 +1,94 @@
|
||||
//go:build unix
|
||||
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"io"
|
||||
"os"
|
||||
|
||||
"github.com/pkg/sftp"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/spf13/cobra"
|
||||
|
||||
sshserver "github.com/netbirdio/netbird/client/ssh/server"
|
||||
)
|
||||
|
||||
var (
|
||||
sftpUID uint32
|
||||
sftpGID uint32
|
||||
sftpGroupsInt []uint
|
||||
sftpWorkingDir string
|
||||
)
|
||||
|
||||
var sshSftpCmd = &cobra.Command{
|
||||
Use: "sftp",
|
||||
Short: "SFTP server with privilege dropping (internal use)",
|
||||
Hidden: true,
|
||||
RunE: sftpMain,
|
||||
}
|
||||
|
||||
func init() {
|
||||
sshSftpCmd.Flags().Uint32Var(&sftpUID, "uid", 0, "Target user ID")
|
||||
sshSftpCmd.Flags().Uint32Var(&sftpGID, "gid", 0, "Target group ID")
|
||||
sshSftpCmd.Flags().UintSliceVar(&sftpGroupsInt, "groups", nil, "Supplementary group IDs (can be repeated)")
|
||||
sshSftpCmd.Flags().StringVar(&sftpWorkingDir, "working-dir", "", "Working directory")
|
||||
}
|
||||
|
||||
func sftpMain(cmd *cobra.Command, _ []string) error {
|
||||
privilegeDropper := sshserver.NewPrivilegeDropper()
|
||||
|
||||
var groups []uint32
|
||||
for _, groupInt := range sftpGroupsInt {
|
||||
groups = append(groups, uint32(groupInt))
|
||||
}
|
||||
|
||||
config := sshserver.ExecutorConfig{
|
||||
UID: sftpUID,
|
||||
GID: sftpGID,
|
||||
Groups: groups,
|
||||
WorkingDir: sftpWorkingDir,
|
||||
Shell: "",
|
||||
Command: "",
|
||||
}
|
||||
|
||||
log.Tracef("dropping privileges for SFTP to UID=%d, GID=%d, groups=%v", config.UID, config.GID, config.Groups)
|
||||
|
||||
if err := privilegeDropper.DropPrivileges(config.UID, config.GID, config.Groups); err != nil {
|
||||
cmd.PrintErrf("privilege drop failed: %v\n", err)
|
||||
os.Exit(sshserver.ExitCodePrivilegeDropFail)
|
||||
}
|
||||
|
||||
if config.WorkingDir != "" {
|
||||
if err := os.Chdir(config.WorkingDir); err != nil {
|
||||
cmd.PrintErrf("failed to change to working directory %s: %v\n", config.WorkingDir, err)
|
||||
}
|
||||
}
|
||||
|
||||
sftpServer, err := sftp.NewServer(struct {
|
||||
io.Reader
|
||||
io.WriteCloser
|
||||
}{
|
||||
Reader: os.Stdin,
|
||||
WriteCloser: os.Stdout,
|
||||
})
|
||||
if err != nil {
|
||||
cmd.PrintErrf("SFTP server creation failed: %v\n", err)
|
||||
os.Exit(sshserver.ExitCodeShellExecFail)
|
||||
}
|
||||
|
||||
log.Tracef("starting SFTP server with dropped privileges")
|
||||
if err := sftpServer.Serve(); err != nil && !errors.Is(err, io.EOF) {
|
||||
cmd.PrintErrf("SFTP server error: %v\n", err)
|
||||
if closeErr := sftpServer.Close(); closeErr != nil {
|
||||
cmd.PrintErrf("SFTP server close error: %v\n", closeErr)
|
||||
}
|
||||
os.Exit(sshserver.ExitCodeShellExecFail)
|
||||
}
|
||||
|
||||
if closeErr := sftpServer.Close(); closeErr != nil {
|
||||
cmd.PrintErrf("SFTP server close error: %v\n", closeErr)
|
||||
}
|
||||
os.Exit(sshserver.ExitCodeSuccess)
|
||||
return nil
|
||||
}
|
||||
93
client/cmd/ssh_sftp_windows.go
Normal file
93
client/cmd/ssh_sftp_windows.go
Normal file
@@ -0,0 +1,93 @@
|
||||
//go:build windows
|
||||
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"os/user"
|
||||
|
||||
"github.com/pkg/sftp"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/spf13/cobra"
|
||||
|
||||
sshserver "github.com/netbirdio/netbird/client/ssh/server"
|
||||
)
|
||||
|
||||
var (
|
||||
sftpWorkingDir string
|
||||
windowsUsername string
|
||||
windowsDomain string
|
||||
)
|
||||
|
||||
var sshSftpCmd = &cobra.Command{
|
||||
Use: "sftp",
|
||||
Short: "SFTP server with user switching for Windows (internal use)",
|
||||
Hidden: true,
|
||||
RunE: sftpMain,
|
||||
}
|
||||
|
||||
func init() {
|
||||
sshSftpCmd.Flags().StringVar(&sftpWorkingDir, "working-dir", "", "Working directory")
|
||||
sshSftpCmd.Flags().StringVar(&windowsUsername, "windows-username", "", "Windows username for user switching")
|
||||
sshSftpCmd.Flags().StringVar(&windowsDomain, "windows-domain", "", "Windows domain for user switching")
|
||||
}
|
||||
|
||||
func sftpMain(cmd *cobra.Command, _ []string) error {
|
||||
return sftpMainDirect(cmd)
|
||||
}
|
||||
|
||||
func sftpMainDirect(cmd *cobra.Command) error {
|
||||
currentUser, err := user.Current()
|
||||
if err != nil {
|
||||
cmd.PrintErrf("failed to get current user: %v\n", err)
|
||||
os.Exit(sshserver.ExitCodeValidationFail)
|
||||
}
|
||||
|
||||
if windowsUsername != "" {
|
||||
expectedUsername := windowsUsername
|
||||
if windowsDomain != "" {
|
||||
expectedUsername = fmt.Sprintf(`%s\%s`, windowsDomain, windowsUsername)
|
||||
}
|
||||
if currentUser.Username != expectedUsername && currentUser.Username != windowsUsername {
|
||||
cmd.PrintErrf("user switching failed\n")
|
||||
os.Exit(sshserver.ExitCodeValidationFail)
|
||||
}
|
||||
}
|
||||
|
||||
log.Debugf("SFTP process running as: %s (UID: %s, Name: %s)", currentUser.Username, currentUser.Uid, currentUser.Name)
|
||||
|
||||
if sftpWorkingDir != "" {
|
||||
if err := os.Chdir(sftpWorkingDir); err != nil {
|
||||
cmd.PrintErrf("failed to change to working directory %s: %v\n", sftpWorkingDir, err)
|
||||
}
|
||||
}
|
||||
|
||||
sftpServer, err := sftp.NewServer(struct {
|
||||
io.Reader
|
||||
io.WriteCloser
|
||||
}{
|
||||
Reader: os.Stdin,
|
||||
WriteCloser: os.Stdout,
|
||||
})
|
||||
if err != nil {
|
||||
cmd.PrintErrf("SFTP server creation failed: %v\n", err)
|
||||
os.Exit(sshserver.ExitCodeShellExecFail)
|
||||
}
|
||||
|
||||
log.Debugf("starting SFTP server")
|
||||
exitCode := sshserver.ExitCodeSuccess
|
||||
if err := sftpServer.Serve(); err != nil && !errors.Is(err, io.EOF) {
|
||||
cmd.PrintErrf("SFTP server error: %v\n", err)
|
||||
exitCode = sshserver.ExitCodeShellExecFail
|
||||
}
|
||||
|
||||
if err := sftpServer.Close(); err != nil {
|
||||
log.Debugf("SFTP server close error: %v", err)
|
||||
}
|
||||
|
||||
os.Exit(exitCode)
|
||||
return nil
|
||||
}
|
||||
669
client/cmd/ssh_test.go
Normal file
669
client/cmd/ssh_test.go
Normal file
@@ -0,0 +1,669 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestSSHCommand_FlagParsing(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
args []string
|
||||
expectedHost string
|
||||
expectedUser string
|
||||
expectedPort int
|
||||
expectedCmd string
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
name: "basic host",
|
||||
args: []string{"hostname"},
|
||||
expectedHost: "hostname",
|
||||
expectedUser: "",
|
||||
expectedPort: 22,
|
||||
expectedCmd: "",
|
||||
},
|
||||
{
|
||||
name: "user@host format",
|
||||
args: []string{"user@hostname"},
|
||||
expectedHost: "hostname",
|
||||
expectedUser: "user",
|
||||
expectedPort: 22,
|
||||
expectedCmd: "",
|
||||
},
|
||||
{
|
||||
name: "host with command",
|
||||
args: []string{"hostname", "echo", "hello"},
|
||||
expectedHost: "hostname",
|
||||
expectedUser: "",
|
||||
expectedPort: 22,
|
||||
expectedCmd: "echo hello",
|
||||
},
|
||||
{
|
||||
name: "command with flags should be preserved",
|
||||
args: []string{"hostname", "ls", "-la", "/tmp"},
|
||||
expectedHost: "hostname",
|
||||
expectedUser: "",
|
||||
expectedPort: 22,
|
||||
expectedCmd: "ls -la /tmp",
|
||||
},
|
||||
{
|
||||
name: "double dash separator",
|
||||
args: []string{"hostname", "--", "ls", "-la"},
|
||||
expectedHost: "hostname",
|
||||
expectedUser: "",
|
||||
expectedPort: 22,
|
||||
expectedCmd: "-- ls -la",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Reset global variables
|
||||
host = ""
|
||||
username = ""
|
||||
port = 22
|
||||
command = ""
|
||||
|
||||
// Mock command for testing
|
||||
cmd := sshCmd
|
||||
cmd.SetArgs(tt.args)
|
||||
|
||||
err := validateSSHArgsWithoutFlagParsing(cmd, tt.args)
|
||||
|
||||
if tt.expectError {
|
||||
assert.Error(t, err)
|
||||
return
|
||||
}
|
||||
|
||||
require.NoError(t, err, "SSH args validation should succeed for valid input")
|
||||
assert.Equal(t, tt.expectedHost, host, "host mismatch")
|
||||
if tt.expectedUser != "" {
|
||||
assert.Equal(t, tt.expectedUser, username, "username mismatch")
|
||||
}
|
||||
assert.Equal(t, tt.expectedPort, port, "port mismatch")
|
||||
assert.Equal(t, tt.expectedCmd, command, "command mismatch")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSSHCommand_FlagConflictPrevention(t *testing.T) {
|
||||
// Test that SSH flags don't conflict with command flags
|
||||
tests := []struct {
|
||||
name string
|
||||
args []string
|
||||
expectedCmd string
|
||||
description string
|
||||
}{
|
||||
{
|
||||
name: "ls with -la flags",
|
||||
args: []string{"hostname", "ls", "-la"},
|
||||
expectedCmd: "ls -la",
|
||||
description: "ls flags should be passed to remote command",
|
||||
},
|
||||
{
|
||||
name: "grep with -r flag",
|
||||
args: []string{"hostname", "grep", "-r", "pattern", "/path"},
|
||||
expectedCmd: "grep -r pattern /path",
|
||||
description: "grep flags should be passed to remote command",
|
||||
},
|
||||
{
|
||||
name: "ps with aux flags",
|
||||
args: []string{"hostname", "ps", "aux"},
|
||||
expectedCmd: "ps aux",
|
||||
description: "ps flags should be passed to remote command",
|
||||
},
|
||||
{
|
||||
name: "command with double dash",
|
||||
args: []string{"hostname", "--", "ls", "-la"},
|
||||
expectedCmd: "-- ls -la",
|
||||
description: "double dash should be preserved in command",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Reset global variables
|
||||
host = ""
|
||||
username = ""
|
||||
port = 22
|
||||
command = ""
|
||||
|
||||
cmd := sshCmd
|
||||
err := validateSSHArgsWithoutFlagParsing(cmd, tt.args)
|
||||
require.NoError(t, err, "SSH args validation should succeed for valid input")
|
||||
|
||||
assert.Equal(t, tt.expectedCmd, command, tt.description)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSSHCommand_NonInteractiveExecution(t *testing.T) {
|
||||
// Test that commands with arguments should execute the command and exit,
|
||||
// not drop to an interactive shell
|
||||
tests := []struct {
|
||||
name string
|
||||
args []string
|
||||
expectedCmd string
|
||||
shouldExit bool
|
||||
description string
|
||||
}{
|
||||
{
|
||||
name: "ls command should execute and exit",
|
||||
args: []string{"hostname", "ls"},
|
||||
expectedCmd: "ls",
|
||||
shouldExit: true,
|
||||
description: "ls command should execute and exit, not drop to shell",
|
||||
},
|
||||
{
|
||||
name: "ls with flags should execute and exit",
|
||||
args: []string{"hostname", "ls", "-la"},
|
||||
expectedCmd: "ls -la",
|
||||
shouldExit: true,
|
||||
description: "ls with flags should execute and exit, not drop to shell",
|
||||
},
|
||||
{
|
||||
name: "pwd command should execute and exit",
|
||||
args: []string{"hostname", "pwd"},
|
||||
expectedCmd: "pwd",
|
||||
shouldExit: true,
|
||||
description: "pwd command should execute and exit, not drop to shell",
|
||||
},
|
||||
{
|
||||
name: "echo command should execute and exit",
|
||||
args: []string{"hostname", "echo", "hello"},
|
||||
expectedCmd: "echo hello",
|
||||
shouldExit: true,
|
||||
description: "echo command should execute and exit, not drop to shell",
|
||||
},
|
||||
{
|
||||
name: "no command should open shell",
|
||||
args: []string{"hostname"},
|
||||
expectedCmd: "",
|
||||
shouldExit: false,
|
||||
description: "no command should open interactive shell",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Reset global variables
|
||||
host = ""
|
||||
username = ""
|
||||
port = 22
|
||||
command = ""
|
||||
|
||||
cmd := sshCmd
|
||||
err := validateSSHArgsWithoutFlagParsing(cmd, tt.args)
|
||||
require.NoError(t, err, "SSH args validation should succeed for valid input")
|
||||
|
||||
assert.Equal(t, tt.expectedCmd, command, tt.description)
|
||||
|
||||
// When command is present, it should execute the command and exit
|
||||
// When command is empty, it should open interactive shell
|
||||
hasCommand := command != ""
|
||||
assert.Equal(t, tt.shouldExit, hasCommand, "Command presence should match expected behavior")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSSHCommand_FlagHandling(t *testing.T) {
|
||||
// Test that flags after hostname are not parsed by netbird but passed to SSH command
|
||||
tests := []struct {
|
||||
name string
|
||||
args []string
|
||||
expectedHost string
|
||||
expectedCmd string
|
||||
expectError bool
|
||||
description string
|
||||
}{
|
||||
{
|
||||
name: "ls with -la flag should not be parsed by netbird",
|
||||
args: []string{"debian2", "ls", "-la"},
|
||||
expectedHost: "debian2",
|
||||
expectedCmd: "ls -la",
|
||||
expectError: false,
|
||||
description: "ls -la should be passed as SSH command, not parsed as netbird flags",
|
||||
},
|
||||
{
|
||||
name: "command with netbird-like flags should be passed through",
|
||||
args: []string{"hostname", "echo", "--help"},
|
||||
expectedHost: "hostname",
|
||||
expectedCmd: "echo --help",
|
||||
expectError: false,
|
||||
description: "--help should be passed to echo, not parsed by netbird",
|
||||
},
|
||||
{
|
||||
name: "command with -p flag should not conflict with SSH port flag",
|
||||
args: []string{"hostname", "ps", "-p", "1234"},
|
||||
expectedHost: "hostname",
|
||||
expectedCmd: "ps -p 1234",
|
||||
expectError: false,
|
||||
description: "ps -p should be passed to ps command, not parsed as port",
|
||||
},
|
||||
{
|
||||
name: "tar with flags should be passed through",
|
||||
args: []string{"hostname", "tar", "-czf", "backup.tar.gz", "/home"},
|
||||
expectedHost: "hostname",
|
||||
expectedCmd: "tar -czf backup.tar.gz /home",
|
||||
expectError: false,
|
||||
description: "tar flags should be passed to tar command",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Reset global variables
|
||||
host = ""
|
||||
username = ""
|
||||
port = 22
|
||||
command = ""
|
||||
|
||||
cmd := sshCmd
|
||||
err := validateSSHArgsWithoutFlagParsing(cmd, tt.args)
|
||||
|
||||
if tt.expectError {
|
||||
assert.Error(t, err)
|
||||
return
|
||||
}
|
||||
|
||||
require.NoError(t, err, "SSH args validation should succeed for valid input")
|
||||
assert.Equal(t, tt.expectedHost, host, "host mismatch")
|
||||
assert.Equal(t, tt.expectedCmd, command, tt.description)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSSHCommand_RegressionFlagParsing(t *testing.T) {
|
||||
// Regression test for the specific issue: "sudo ./netbird ssh debian2 ls -la"
|
||||
// should not parse -la as netbird flags but pass them to the SSH command
|
||||
tests := []struct {
|
||||
name string
|
||||
args []string
|
||||
expectedHost string
|
||||
expectedCmd string
|
||||
expectError bool
|
||||
description string
|
||||
}{
|
||||
{
|
||||
name: "original issue: ls -la should be preserved",
|
||||
args: []string{"debian2", "ls", "-la"},
|
||||
expectedHost: "debian2",
|
||||
expectedCmd: "ls -la",
|
||||
expectError: false,
|
||||
description: "The original failing case should now work",
|
||||
},
|
||||
{
|
||||
name: "ls -l should be preserved",
|
||||
args: []string{"hostname", "ls", "-l"},
|
||||
expectedHost: "hostname",
|
||||
expectedCmd: "ls -l",
|
||||
expectError: false,
|
||||
description: "Single letter flags should be preserved",
|
||||
},
|
||||
{
|
||||
name: "SSH port flag should work",
|
||||
args: []string{"-p", "2222", "hostname", "ls", "-la"},
|
||||
expectedHost: "hostname",
|
||||
expectedCmd: "ls -la",
|
||||
expectError: false,
|
||||
description: "SSH -p flag should be parsed, command flags preserved",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Reset global variables
|
||||
host = ""
|
||||
username = ""
|
||||
port = 22
|
||||
command = ""
|
||||
|
||||
cmd := sshCmd
|
||||
err := validateSSHArgsWithoutFlagParsing(cmd, tt.args)
|
||||
|
||||
if tt.expectError {
|
||||
assert.Error(t, err)
|
||||
return
|
||||
}
|
||||
|
||||
require.NoError(t, err, "SSH args validation should succeed for valid input")
|
||||
assert.Equal(t, tt.expectedHost, host, "host mismatch")
|
||||
assert.Equal(t, tt.expectedCmd, command, tt.description)
|
||||
|
||||
// Check port for the test case with -p flag
|
||||
if len(tt.args) > 0 && tt.args[0] == "-p" {
|
||||
assert.Equal(t, 2222, port, "port should be parsed from -p flag")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSSHCommand_PortForwardingFlagParsing(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
args []string
|
||||
expectedHost string
|
||||
expectedLocal []string
|
||||
expectedRemote []string
|
||||
expectError bool
|
||||
description string
|
||||
}{
|
||||
{
|
||||
name: "local port forwarding -L",
|
||||
args: []string{"-L", "8080:localhost:80", "hostname"},
|
||||
expectedHost: "hostname",
|
||||
expectedLocal: []string{"8080:localhost:80"},
|
||||
expectedRemote: []string{},
|
||||
expectError: false,
|
||||
description: "Single -L flag should be parsed correctly",
|
||||
},
|
||||
{
|
||||
name: "remote port forwarding -R",
|
||||
args: []string{"-R", "8080:localhost:80", "hostname"},
|
||||
expectedHost: "hostname",
|
||||
expectedLocal: []string{},
|
||||
expectedRemote: []string{"8080:localhost:80"},
|
||||
expectError: false,
|
||||
description: "Single -R flag should be parsed correctly",
|
||||
},
|
||||
{
|
||||
name: "multiple local port forwards",
|
||||
args: []string{"-L", "8080:localhost:80", "-L", "9090:localhost:443", "hostname"},
|
||||
expectedHost: "hostname",
|
||||
expectedLocal: []string{"8080:localhost:80", "9090:localhost:443"},
|
||||
expectedRemote: []string{},
|
||||
expectError: false,
|
||||
description: "Multiple -L flags should be parsed correctly",
|
||||
},
|
||||
{
|
||||
name: "multiple remote port forwards",
|
||||
args: []string{"-R", "8080:localhost:80", "-R", "9090:localhost:443", "hostname"},
|
||||
expectedHost: "hostname",
|
||||
expectedLocal: []string{},
|
||||
expectedRemote: []string{"8080:localhost:80", "9090:localhost:443"},
|
||||
expectError: false,
|
||||
description: "Multiple -R flags should be parsed correctly",
|
||||
},
|
||||
{
|
||||
name: "mixed local and remote forwards",
|
||||
args: []string{"-L", "8080:localhost:80", "-R", "9090:localhost:443", "hostname"},
|
||||
expectedHost: "hostname",
|
||||
expectedLocal: []string{"8080:localhost:80"},
|
||||
expectedRemote: []string{"9090:localhost:443"},
|
||||
expectError: false,
|
||||
description: "Mixed -L and -R flags should be parsed correctly",
|
||||
},
|
||||
{
|
||||
name: "port forwarding with bind address",
|
||||
args: []string{"-L", "127.0.0.1:8080:localhost:80", "hostname"},
|
||||
expectedHost: "hostname",
|
||||
expectedLocal: []string{"127.0.0.1:8080:localhost:80"},
|
||||
expectedRemote: []string{},
|
||||
expectError: false,
|
||||
description: "Port forwarding with bind address should work",
|
||||
},
|
||||
{
|
||||
name: "port forwarding with command",
|
||||
args: []string{"-L", "8080:localhost:80", "hostname", "ls", "-la"},
|
||||
expectedHost: "hostname",
|
||||
expectedLocal: []string{"8080:localhost:80"},
|
||||
expectedRemote: []string{},
|
||||
expectError: false,
|
||||
description: "Port forwarding with command should work",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Reset global variables
|
||||
host = ""
|
||||
username = ""
|
||||
port = 22
|
||||
command = ""
|
||||
localForwards = nil
|
||||
remoteForwards = nil
|
||||
|
||||
cmd := sshCmd
|
||||
err := validateSSHArgsWithoutFlagParsing(cmd, tt.args)
|
||||
|
||||
if tt.expectError {
|
||||
assert.Error(t, err)
|
||||
return
|
||||
}
|
||||
|
||||
require.NoError(t, err, "SSH args validation should succeed for valid input")
|
||||
assert.Equal(t, tt.expectedHost, host, "host mismatch")
|
||||
// Handle nil vs empty slice comparison
|
||||
if len(tt.expectedLocal) == 0 {
|
||||
assert.True(t, len(localForwards) == 0, tt.description+" - local forwards should be empty")
|
||||
} else {
|
||||
assert.Equal(t, tt.expectedLocal, localForwards, tt.description+" - local forwards")
|
||||
}
|
||||
if len(tt.expectedRemote) == 0 {
|
||||
assert.True(t, len(remoteForwards) == 0, tt.description+" - remote forwards should be empty")
|
||||
} else {
|
||||
assert.Equal(t, tt.expectedRemote, remoteForwards, tt.description+" - remote forwards")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParsePortForward(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
spec string
|
||||
expectedLocal string
|
||||
expectedRemote string
|
||||
expectError bool
|
||||
description string
|
||||
}{
|
||||
{
|
||||
name: "simple port forward",
|
||||
spec: "8080:localhost:80",
|
||||
expectedLocal: "localhost:8080",
|
||||
expectedRemote: "localhost:80",
|
||||
expectError: false,
|
||||
description: "Simple port:host:port format should work",
|
||||
},
|
||||
{
|
||||
name: "port forward with bind address",
|
||||
spec: "127.0.0.1:8080:localhost:80",
|
||||
expectedLocal: "127.0.0.1:8080",
|
||||
expectedRemote: "localhost:80",
|
||||
expectError: false,
|
||||
description: "bind_address:port:host:port format should work",
|
||||
},
|
||||
{
|
||||
name: "port forward to different host",
|
||||
spec: "8080:example.com:443",
|
||||
expectedLocal: "localhost:8080",
|
||||
expectedRemote: "example.com:443",
|
||||
expectError: false,
|
||||
description: "Forwarding to different host should work",
|
||||
},
|
||||
{
|
||||
name: "port forward with IPv6 (needs bracket support)",
|
||||
spec: "::1:8080:localhost:80",
|
||||
expectError: true,
|
||||
description: "IPv6 without brackets fails as expected (feature to implement)",
|
||||
},
|
||||
{
|
||||
name: "invalid format - too few parts",
|
||||
spec: "8080:localhost",
|
||||
expectError: true,
|
||||
description: "Invalid format with too few parts should fail",
|
||||
},
|
||||
{
|
||||
name: "invalid format - too many parts",
|
||||
spec: "127.0.0.1:8080:localhost:80:extra",
|
||||
expectError: true,
|
||||
description: "Invalid format with too many parts should fail",
|
||||
},
|
||||
{
|
||||
name: "empty spec",
|
||||
spec: "",
|
||||
expectError: true,
|
||||
description: "Empty spec should fail",
|
||||
},
|
||||
{
|
||||
name: "unix socket local forward",
|
||||
spec: "8080:/tmp/socket",
|
||||
expectedLocal: "localhost:8080",
|
||||
expectedRemote: "/tmp/socket",
|
||||
expectError: false,
|
||||
description: "Unix socket forwarding should work",
|
||||
},
|
||||
{
|
||||
name: "unix socket with bind address",
|
||||
spec: "127.0.0.1:8080:/tmp/socket",
|
||||
expectedLocal: "127.0.0.1:8080",
|
||||
expectedRemote: "/tmp/socket",
|
||||
expectError: false,
|
||||
description: "Unix socket with bind address should work",
|
||||
},
|
||||
{
|
||||
name: "wildcard bind all interfaces",
|
||||
spec: "*:8080:localhost:80",
|
||||
expectedLocal: "0.0.0.0:8080",
|
||||
expectedRemote: "localhost:80",
|
||||
expectError: false,
|
||||
description: "Wildcard * should bind to all interfaces (0.0.0.0)",
|
||||
},
|
||||
{
|
||||
name: "wildcard for port only",
|
||||
spec: "8080:*:80",
|
||||
expectedLocal: "localhost:8080",
|
||||
expectedRemote: "*:80",
|
||||
expectError: false,
|
||||
description: "Wildcard in remote host should be preserved",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
localAddr, remoteAddr, err := parsePortForwardSpec(tt.spec)
|
||||
|
||||
if tt.expectError {
|
||||
assert.Error(t, err, tt.description)
|
||||
return
|
||||
}
|
||||
|
||||
require.NoError(t, err, tt.description)
|
||||
assert.Equal(t, tt.expectedLocal, localAddr, tt.description+" - local address")
|
||||
assert.Equal(t, tt.expectedRemote, remoteAddr, tt.description+" - remote address")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSSHCommand_IntegrationPortForwarding(t *testing.T) {
|
||||
// Integration test for port forwarding with the actual SSH command implementation
|
||||
tests := []struct {
|
||||
name string
|
||||
args []string
|
||||
expectedHost string
|
||||
expectedLocal []string
|
||||
expectedRemote []string
|
||||
expectedCmd string
|
||||
description string
|
||||
}{
|
||||
{
|
||||
name: "local forward with command",
|
||||
args: []string{"-L", "8080:localhost:80", "hostname", "echo", "test"},
|
||||
expectedHost: "hostname",
|
||||
expectedLocal: []string{"8080:localhost:80"},
|
||||
expectedRemote: []string{},
|
||||
expectedCmd: "echo test",
|
||||
description: "Local forwarding should work with commands",
|
||||
},
|
||||
{
|
||||
name: "remote forward with command",
|
||||
args: []string{"-R", "8080:localhost:80", "hostname", "ls", "-la"},
|
||||
expectedHost: "hostname",
|
||||
expectedLocal: []string{},
|
||||
expectedRemote: []string{"8080:localhost:80"},
|
||||
expectedCmd: "ls -la",
|
||||
description: "Remote forwarding should work with commands",
|
||||
},
|
||||
{
|
||||
name: "multiple forwards with user and command",
|
||||
args: []string{"-L", "8080:localhost:80", "-R", "9090:localhost:443", "user@hostname", "ps", "aux"},
|
||||
expectedHost: "hostname",
|
||||
expectedLocal: []string{"8080:localhost:80"},
|
||||
expectedRemote: []string{"9090:localhost:443"},
|
||||
expectedCmd: "ps aux",
|
||||
description: "Complex case with multiple forwards, user, and command",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Reset global variables
|
||||
host = ""
|
||||
username = ""
|
||||
port = 22
|
||||
command = ""
|
||||
localForwards = nil
|
||||
remoteForwards = nil
|
||||
|
||||
cmd := sshCmd
|
||||
err := validateSSHArgsWithoutFlagParsing(cmd, tt.args)
|
||||
require.NoError(t, err, "SSH args validation should succeed for valid input")
|
||||
|
||||
assert.Equal(t, tt.expectedHost, host, "host mismatch")
|
||||
// Handle nil vs empty slice comparison
|
||||
if len(tt.expectedLocal) == 0 {
|
||||
assert.True(t, len(localForwards) == 0, tt.description+" - local forwards should be empty")
|
||||
} else {
|
||||
assert.Equal(t, tt.expectedLocal, localForwards, tt.description+" - local forwards")
|
||||
}
|
||||
if len(tt.expectedRemote) == 0 {
|
||||
assert.True(t, len(remoteForwards) == 0, tt.description+" - remote forwards should be empty")
|
||||
} else {
|
||||
assert.Equal(t, tt.expectedRemote, remoteForwards, tt.description+" - remote forwards")
|
||||
}
|
||||
assert.Equal(t, tt.expectedCmd, command, tt.description+" - command")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSSHCommand_ParameterIsolation(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
args []string
|
||||
expectedCmd string
|
||||
}{
|
||||
{
|
||||
name: "cmd flag passed as command",
|
||||
args: []string{"hostname", "--cmd", "echo test"},
|
||||
expectedCmd: "--cmd echo test",
|
||||
},
|
||||
{
|
||||
name: "uid flag passed as command",
|
||||
args: []string{"hostname", "--uid", "1000"},
|
||||
expectedCmd: "--uid 1000",
|
||||
},
|
||||
{
|
||||
name: "shell flag passed as command",
|
||||
args: []string{"hostname", "--shell", "/bin/bash"},
|
||||
expectedCmd: "--shell /bin/bash",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
host = ""
|
||||
username = ""
|
||||
port = 22
|
||||
command = ""
|
||||
|
||||
err := validateSSHArgsWithoutFlagParsing(sshCmd, tt.args)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, "hostname", host)
|
||||
assert.Equal(t, tt.expectedCmd, command)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -12,9 +12,6 @@ import (
|
||||
"google.golang.org/grpc"
|
||||
|
||||
"github.com/netbirdio/management-integrations/integrations"
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
|
||||
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
|
||||
|
||||
clientProto "github.com/netbirdio/netbird/client/proto"
|
||||
client "github.com/netbirdio/netbird/client/server"
|
||||
@@ -87,6 +84,7 @@ func startManagement(t *testing.T, config *config.Config, testFile string) (*grp
|
||||
}
|
||||
t.Cleanup(cleanUp)
|
||||
|
||||
peersUpdateManager := mgmt.NewPeersUpdateManager(nil)
|
||||
eventStore := &activity.InMemoryEventStore{}
|
||||
if err != nil {
|
||||
return nil, nil
|
||||
@@ -112,18 +110,13 @@ func startManagement(t *testing.T, config *config.Config, testFile string) (*grp
|
||||
Return(&types.Settings{}, nil).
|
||||
AnyTimes()
|
||||
|
||||
ctx := context.Background()
|
||||
updateManager := update_channel.NewPeersUpdateManager(metrics)
|
||||
requestBuffer := mgmt.NewAccountRequestBuffer(ctx, store)
|
||||
networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, mgmt.MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock())
|
||||
|
||||
accountManager, err := mgmt.BuildManager(context.Background(), store, networkMapController, nil, "", eventStore, nil, false, iv, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false)
|
||||
accountManager, err := mgmt.BuildManager(context.Background(), config, store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, iv, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
secretsManager := nbgrpc.NewTimeBasedAuthSecretsManager(updateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
|
||||
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, updateManager, secretsManager, nil, &manager.EphemeralManager{}, nil, &mgmt.MockIntegratedValidator{}, networkMapController)
|
||||
secretsManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
|
||||
mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, &manager.EphemeralManager{}, nil, &mgmt.MockIntegratedValidator{})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
@@ -185,7 +185,7 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command, activeProf *pr
|
||||
|
||||
_, _ = profilemanager.UpdateOldManagementURL(ctx, config, configFilePath)
|
||||
|
||||
err = foregroundLogin(ctx, cmd, config, providedSetupKey, activeProf.Name)
|
||||
err = foregroundLogin(ctx, cmd, config, providedSetupKey)
|
||||
if err != nil {
|
||||
return fmt.Errorf("foreground login failed: %v", err)
|
||||
}
|
||||
@@ -286,13 +286,6 @@ func doDaemonUp(ctx context.Context, cmd *cobra.Command, client proto.DaemonServ
|
||||
loginRequest.ProfileName = &activeProf.Name
|
||||
loginRequest.Username = &username
|
||||
|
||||
profileState, err := pm.GetProfileState(activeProf.Name)
|
||||
if err != nil {
|
||||
log.Debugf("failed to get profile state for login hint: %v", err)
|
||||
} else if profileState.Email != "" {
|
||||
loginRequest.Hint = &profileState.Email
|
||||
}
|
||||
|
||||
var loginErr error
|
||||
var loginResp *proto.LoginResponse
|
||||
|
||||
@@ -355,6 +348,21 @@ func setupSetConfigReq(customDNSAddressConverted []byte, cmd *cobra.Command, pro
|
||||
if cmd.Flag(serverSSHAllowedFlag).Changed {
|
||||
req.ServerSSHAllowed = &serverSSHAllowed
|
||||
}
|
||||
if cmd.Flag(enableSSHRootFlag).Changed {
|
||||
req.EnableSSHRoot = &enableSSHRoot
|
||||
}
|
||||
if cmd.Flag(enableSSHSFTPFlag).Changed {
|
||||
req.EnableSSHSFTP = &enableSSHSFTP
|
||||
}
|
||||
if cmd.Flag(enableSSHLocalPortForwardFlag).Changed {
|
||||
req.EnableSSHLocalPortForward = &enableSSHLocalPortForward
|
||||
}
|
||||
if cmd.Flag(enableSSHRemotePortForwardFlag).Changed {
|
||||
req.EnableSSHRemotePortForward = &enableSSHRemotePortForward
|
||||
}
|
||||
if cmd.Flag(disableSSHAuthFlag).Changed {
|
||||
req.DisableSSHAuth = &disableSSHAuth
|
||||
}
|
||||
if cmd.Flag(interfaceNameFlag).Changed {
|
||||
if err := parseInterfaceName(interfaceName); err != nil {
|
||||
log.Errorf("parse interface name: %v", err)
|
||||
@@ -439,6 +447,26 @@ func setupConfig(customDNSAddressConverted []byte, cmd *cobra.Command, configFil
|
||||
ic.ServerSSHAllowed = &serverSSHAllowed
|
||||
}
|
||||
|
||||
if cmd.Flag(enableSSHRootFlag).Changed {
|
||||
ic.EnableSSHRoot = &enableSSHRoot
|
||||
}
|
||||
|
||||
if cmd.Flag(enableSSHSFTPFlag).Changed {
|
||||
ic.EnableSSHSFTP = &enableSSHSFTP
|
||||
}
|
||||
|
||||
if cmd.Flag(enableSSHLocalPortForwardFlag).Changed {
|
||||
ic.EnableSSHLocalPortForwarding = &enableSSHLocalPortForward
|
||||
}
|
||||
|
||||
if cmd.Flag(enableSSHRemotePortForwardFlag).Changed {
|
||||
ic.EnableSSHRemotePortForwarding = &enableSSHRemotePortForward
|
||||
}
|
||||
|
||||
if cmd.Flag(disableSSHAuthFlag).Changed {
|
||||
ic.DisableSSHAuth = &disableSSHAuth
|
||||
}
|
||||
|
||||
if cmd.Flag(interfaceNameFlag).Changed {
|
||||
if err := parseInterfaceName(interfaceName); err != nil {
|
||||
return nil, err
|
||||
@@ -539,6 +567,26 @@ func setupLoginRequest(providedSetupKey string, customDNSAddressConverted []byte
|
||||
loginRequest.ServerSSHAllowed = &serverSSHAllowed
|
||||
}
|
||||
|
||||
if cmd.Flag(enableSSHRootFlag).Changed {
|
||||
loginRequest.EnableSSHRoot = &enableSSHRoot
|
||||
}
|
||||
|
||||
if cmd.Flag(enableSSHSFTPFlag).Changed {
|
||||
loginRequest.EnableSSHSFTP = &enableSSHSFTP
|
||||
}
|
||||
|
||||
if cmd.Flag(enableSSHLocalPortForwardFlag).Changed {
|
||||
loginRequest.EnableSSHLocalPortForwarding = &enableSSHLocalPortForward
|
||||
}
|
||||
|
||||
if cmd.Flag(enableSSHRemotePortForwardFlag).Changed {
|
||||
loginRequest.EnableSSHRemotePortForwarding = &enableSSHRemotePortForward
|
||||
}
|
||||
|
||||
if cmd.Flag(disableSSHAuthFlag).Changed {
|
||||
loginRequest.DisableSSHAuth = &disableSSHAuth
|
||||
}
|
||||
|
||||
if cmd.Flag(disableAutoConnectFlag).Changed {
|
||||
loginRequest.DisableAutoConnect = &autoConnectDisabled
|
||||
}
|
||||
|
||||
@@ -18,12 +18,16 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||
sshcommon "github.com/netbirdio/netbird/client/ssh"
|
||||
"github.com/netbirdio/netbird/client/system"
|
||||
)
|
||||
|
||||
var ErrClientAlreadyStarted = errors.New("client already started")
|
||||
var ErrClientNotStarted = errors.New("client not started")
|
||||
var ErrConfigNotInitialized = errors.New("config not initialized")
|
||||
var (
|
||||
ErrClientAlreadyStarted = errors.New("client already started")
|
||||
ErrClientNotStarted = errors.New("client not started")
|
||||
ErrEngineNotStarted = errors.New("engine not started")
|
||||
ErrConfigNotInitialized = errors.New("config not initialized")
|
||||
)
|
||||
|
||||
// Client manages a netbird embedded client instance.
|
||||
type Client struct {
|
||||
@@ -238,17 +242,9 @@ func (c *Client) GetConfig() (profilemanager.Config, error) {
|
||||
// Dial dials a network address in the netbird network.
|
||||
// Not applicable if the userspace networking mode is disabled.
|
||||
func (c *Client) Dial(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
c.mu.Lock()
|
||||
connect := c.connect
|
||||
if connect == nil {
|
||||
c.mu.Unlock()
|
||||
return nil, ErrClientNotStarted
|
||||
}
|
||||
c.mu.Unlock()
|
||||
|
||||
engine := connect.Engine()
|
||||
if engine == nil {
|
||||
return nil, errors.New("engine not started")
|
||||
engine, err := c.getEngine()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
nsnet, err := engine.GetNet()
|
||||
@@ -259,6 +255,11 @@ func (c *Client) Dial(ctx context.Context, network, address string) (net.Conn, e
|
||||
return nsnet.DialContext(ctx, network, address)
|
||||
}
|
||||
|
||||
// DialContext dials a network address in the netbird network with context
|
||||
func (c *Client) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
return c.Dial(ctx, network, address)
|
||||
}
|
||||
|
||||
// ListenTCP listens on the given address in the netbird network.
|
||||
// Not applicable if the userspace networking mode is disabled.
|
||||
func (c *Client) ListenTCP(address string) (net.Listener, error) {
|
||||
@@ -314,18 +315,47 @@ func (c *Client) NewHTTPClient() *http.Client {
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Client) getNet() (*wgnetstack.Net, netip.Addr, error) {
|
||||
// VerifySSHHostKey verifies an SSH host key against stored peer keys.
|
||||
// Returns nil if the key matches, ErrPeerNotFound if peer is not in network,
|
||||
// ErrNoStoredKey if peer has no stored key, or an error for verification failures.
|
||||
func (c *Client) VerifySSHHostKey(peerAddress string, key []byte) error {
|
||||
engine, err := c.getEngine()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
storedKey, found := engine.GetPeerSSHKey(peerAddress)
|
||||
if !found {
|
||||
return sshcommon.ErrPeerNotFound
|
||||
}
|
||||
|
||||
return sshcommon.VerifyHostKey(storedKey, key, peerAddress)
|
||||
}
|
||||
|
||||
// getEngine safely retrieves the engine from the client with proper locking.
|
||||
// Returns ErrClientNotStarted if the client is not started.
|
||||
// Returns ErrEngineNotStarted if the engine is not available.
|
||||
func (c *Client) getEngine() (*internal.Engine, error) {
|
||||
c.mu.Lock()
|
||||
connect := c.connect
|
||||
if connect == nil {
|
||||
c.mu.Unlock()
|
||||
return nil, netip.Addr{}, errors.New("client not started")
|
||||
}
|
||||
c.mu.Unlock()
|
||||
|
||||
if connect == nil {
|
||||
return nil, ErrClientNotStarted
|
||||
}
|
||||
|
||||
engine := connect.Engine()
|
||||
if engine == nil {
|
||||
return nil, netip.Addr{}, errors.New("engine not started")
|
||||
return nil, ErrEngineNotStarted
|
||||
}
|
||||
|
||||
return engine, nil
|
||||
}
|
||||
|
||||
func (c *Client) getNet() (*wgnetstack.Net, netip.Addr, error) {
|
||||
engine, err := c.getEngine()
|
||||
if err != nil {
|
||||
return nil, netip.Addr{}, err
|
||||
}
|
||||
|
||||
addr, err := engine.Address()
|
||||
|
||||
@@ -35,6 +35,12 @@ const (
|
||||
ipTCPHeaderMinSize = 40
|
||||
)
|
||||
|
||||
// serviceKey represents a protocol/port combination for netstack service registry
|
||||
type serviceKey struct {
|
||||
protocol gopacket.LayerType
|
||||
port uint16
|
||||
}
|
||||
|
||||
const (
|
||||
// EnvDisableConntrack disables the stateful filter, replies to outbound traffic won't be allowed.
|
||||
EnvDisableConntrack = "NB_DISABLE_CONNTRACK"
|
||||
@@ -59,12 +65,6 @@ const (
|
||||
|
||||
var errNatNotSupported = errors.New("nat not supported with userspace firewall")
|
||||
|
||||
// serviceKey represents a protocol/port combination for netstack service registry
|
||||
type serviceKey struct {
|
||||
protocol gopacket.LayerType
|
||||
port uint16
|
||||
}
|
||||
|
||||
// RuleSet is a set of rules grouped by a string key
|
||||
type RuleSet map[string]PeerRule
|
||||
|
||||
|
||||
@@ -22,6 +22,7 @@ import (
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
"github.com/netbirdio/netbird/client/internal/netflow"
|
||||
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
)
|
||||
|
||||
@@ -1114,3 +1115,138 @@ func generateTCPPacketWithFlags(tb testing.TB, srcIP, dstIP net.IP, srcPort, dst
|
||||
|
||||
return buf.Bytes()
|
||||
}
|
||||
|
||||
func TestShouldForward(t *testing.T) {
|
||||
// Set up test addresses
|
||||
wgIP := netip.MustParseAddr("100.10.0.1")
|
||||
otherIP := netip.MustParseAddr("100.10.0.2")
|
||||
|
||||
// Create test manager with mock interface
|
||||
ifaceMock := &IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}
|
||||
// Set the mock to return our test WG IP
|
||||
ifaceMock.AddressFunc = func() wgaddr.Address {
|
||||
return wgaddr.Address{IP: wgIP, Network: netip.PrefixFrom(wgIP, 24)}
|
||||
}
|
||||
|
||||
manager, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU)
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
require.NoError(t, manager.Close(nil))
|
||||
}()
|
||||
|
||||
// Helper to create decoder with TCP packet
|
||||
createTCPDecoder := func(dstPort uint16) *decoder {
|
||||
ipv4 := &layers.IPv4{
|
||||
Version: 4,
|
||||
Protocol: layers.IPProtocolTCP,
|
||||
SrcIP: net.ParseIP("192.168.1.100"),
|
||||
DstIP: wgIP.AsSlice(),
|
||||
}
|
||||
tcp := &layers.TCP{
|
||||
SrcPort: 54321,
|
||||
DstPort: layers.TCPPort(dstPort),
|
||||
}
|
||||
|
||||
err := tcp.SetNetworkLayerForChecksum(ipv4)
|
||||
require.NoError(t, err)
|
||||
|
||||
buf := gopacket.NewSerializeBuffer()
|
||||
opts := gopacket.SerializeOptions{ComputeChecksums: true, FixLengths: true}
|
||||
err = gopacket.SerializeLayers(buf, opts, ipv4, tcp, gopacket.Payload("test"))
|
||||
require.NoError(t, err)
|
||||
|
||||
d := &decoder{
|
||||
decoded: []gopacket.LayerType{},
|
||||
}
|
||||
d.parser = gopacket.NewDecodingLayerParser(
|
||||
layers.LayerTypeIPv4,
|
||||
&d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp,
|
||||
)
|
||||
d.parser.IgnoreUnsupported = true
|
||||
|
||||
err = d.parser.DecodeLayers(buf.Bytes(), &d.decoded)
|
||||
require.NoError(t, err)
|
||||
|
||||
return d
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
localForwarding bool
|
||||
netstack bool
|
||||
dstIP netip.Addr
|
||||
serviceRegistered bool
|
||||
servicePort uint16
|
||||
expected bool
|
||||
description string
|
||||
}{
|
||||
{
|
||||
name: "no local forwarding",
|
||||
localForwarding: false,
|
||||
netstack: true,
|
||||
dstIP: wgIP,
|
||||
expected: false,
|
||||
description: "should never forward when local forwarding disabled",
|
||||
},
|
||||
{
|
||||
name: "traffic to other local interface",
|
||||
localForwarding: true,
|
||||
netstack: false,
|
||||
dstIP: otherIP,
|
||||
expected: true,
|
||||
description: "should forward traffic to our other local interfaces (not NetBird IP)",
|
||||
},
|
||||
{
|
||||
name: "traffic to NetBird IP, no netstack",
|
||||
localForwarding: true,
|
||||
netstack: false,
|
||||
dstIP: wgIP,
|
||||
expected: false,
|
||||
description: "should send to netstack listeners (final return false path)",
|
||||
},
|
||||
{
|
||||
name: "traffic to our IP, netstack mode, no service",
|
||||
localForwarding: true,
|
||||
netstack: true,
|
||||
dstIP: wgIP,
|
||||
expected: true,
|
||||
description: "should forward when in netstack mode with no matching service",
|
||||
},
|
||||
{
|
||||
name: "traffic to our IP, netstack mode, with service",
|
||||
localForwarding: true,
|
||||
netstack: true,
|
||||
dstIP: wgIP,
|
||||
serviceRegistered: true,
|
||||
servicePort: 22,
|
||||
expected: false,
|
||||
description: "should send to netstack listeners when service is registered",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Configure manager
|
||||
manager.localForwarding = tt.localForwarding
|
||||
manager.netstack = tt.netstack
|
||||
|
||||
// Register service if needed
|
||||
if tt.serviceRegistered {
|
||||
manager.RegisterNetstackService(nftypes.TCP, tt.servicePort)
|
||||
defer manager.UnregisterNetstackService(nftypes.TCP, tt.servicePort)
|
||||
}
|
||||
|
||||
// Create decoder for the test
|
||||
decoder := createTCPDecoder(tt.servicePort)
|
||||
if !tt.serviceRegistered {
|
||||
decoder = createTCPDecoder(8080) // Use non-registered port
|
||||
}
|
||||
|
||||
// Test the method
|
||||
result := manager.shouldForward(decoder, tt.dstIP)
|
||||
require.Equal(t, tt.expected, result, tt.description)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
85
client/firewall/uspfilter/nat_stateful_test.go
Normal file
85
client/firewall/uspfilter/nat_stateful_test.go
Normal file
@@ -0,0 +1,85 @@
|
||||
package uspfilter
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"testing"
|
||||
|
||||
"github.com/google/gopacket/layers"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
)
|
||||
|
||||
// TestPortDNATBasic tests basic port DNAT functionality
|
||||
func TestPortDNATBasic(t *testing.T) {
|
||||
manager, err := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}, false, flowLogger, iface.DefaultMTU)
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
require.NoError(t, manager.Close(nil))
|
||||
}()
|
||||
|
||||
// Define peer IPs
|
||||
peerA := netip.MustParseAddr("100.10.0.50")
|
||||
peerB := netip.MustParseAddr("100.10.0.51")
|
||||
|
||||
// Add SSH port redirection rule for peer B (the target)
|
||||
err = manager.addPortRedirection(peerB, layers.LayerTypeTCP, 22, 22022)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Scenario: Peer A connects to Peer B on port 22 (should get NAT)
|
||||
packetAtoB := generateDNATTestPacket(t, peerA, peerB, layers.IPProtocolTCP, 54321, 22)
|
||||
d := parsePacket(t, packetAtoB)
|
||||
translatedAtoB := manager.translateInboundPortDNAT(packetAtoB, d, peerA, peerB)
|
||||
require.True(t, translatedAtoB, "Peer A to Peer B should be translated (NAT applied)")
|
||||
|
||||
// Verify port was translated to 22022
|
||||
d = parsePacket(t, packetAtoB)
|
||||
require.Equal(t, uint16(22022), uint16(d.tcp.DstPort), "Port should be rewritten to 22022")
|
||||
|
||||
// Scenario: Return traffic from Peer B to Peer A should NOT be translated
|
||||
// (prevents double NAT - original port stored in conntrack)
|
||||
returnPacket := generateDNATTestPacket(t, peerB, peerA, layers.IPProtocolTCP, 22022, 54321)
|
||||
d2 := parsePacket(t, returnPacket)
|
||||
translatedReturn := manager.translateInboundPortDNAT(returnPacket, d2, peerB, peerA)
|
||||
require.False(t, translatedReturn, "Return traffic from same IP should not be translated")
|
||||
}
|
||||
|
||||
// TestPortDNATMultipleRules tests multiple port DNAT rules
|
||||
func TestPortDNATMultipleRules(t *testing.T) {
|
||||
manager, err := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}, false, flowLogger, iface.DefaultMTU)
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
require.NoError(t, manager.Close(nil))
|
||||
}()
|
||||
|
||||
// Define peer IPs
|
||||
peerA := netip.MustParseAddr("100.10.0.50")
|
||||
peerB := netip.MustParseAddr("100.10.0.51")
|
||||
|
||||
// Add SSH port redirection rules for both peers
|
||||
err = manager.addPortRedirection(peerA, layers.LayerTypeTCP, 22, 22022)
|
||||
require.NoError(t, err)
|
||||
err = manager.addPortRedirection(peerB, layers.LayerTypeTCP, 22, 22022)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Test traffic to peer B gets translated
|
||||
packetToB := generateDNATTestPacket(t, peerA, peerB, layers.IPProtocolTCP, 54321, 22)
|
||||
d1 := parsePacket(t, packetToB)
|
||||
translatedToB := manager.translateInboundPortDNAT(packetToB, d1, peerA, peerB)
|
||||
require.True(t, translatedToB, "Traffic to peer B should be translated")
|
||||
d1 = parsePacket(t, packetToB)
|
||||
require.Equal(t, uint16(22022), uint16(d1.tcp.DstPort), "Port should be 22022")
|
||||
|
||||
// Test traffic to peer A gets translated
|
||||
packetToA := generateDNATTestPacket(t, peerB, peerA, layers.IPProtocolTCP, 54322, 22)
|
||||
d2 := parsePacket(t, packetToA)
|
||||
translatedToA := manager.translateInboundPortDNAT(packetToA, d2, peerB, peerA)
|
||||
require.True(t, translatedToA, "Traffic to peer A should be translated")
|
||||
d2 = parsePacket(t, packetToA)
|
||||
require.Equal(t, uint16(22022), uint16(d2.tcp.DstPort), "Port should be 22022")
|
||||
}
|
||||
@@ -17,7 +17,6 @@ import (
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/internal/acl/id"
|
||||
"github.com/netbirdio/netbird/client/ssh"
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
|
||||
)
|
||||
@@ -83,22 +82,6 @@ func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap, dnsRout
|
||||
func (d *DefaultManager) applyPeerACLs(networkMap *mgmProto.NetworkMap) {
|
||||
rules := networkMap.FirewallRules
|
||||
|
||||
enableSSH := networkMap.PeerConfig != nil &&
|
||||
networkMap.PeerConfig.SshConfig != nil &&
|
||||
networkMap.PeerConfig.SshConfig.SshEnabled
|
||||
|
||||
// If SSH enabled, add default firewall rule which accepts connection to any peer
|
||||
// in the network by SSH (TCP port defined by ssh.DefaultSSHPort).
|
||||
if enableSSH {
|
||||
rules = append(rules, &mgmProto.FirewallRule{
|
||||
PeerIP: "0.0.0.0",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
Port: strconv.Itoa(ssh.DefaultSSHPort),
|
||||
})
|
||||
}
|
||||
|
||||
// if we got empty rules list but management not set networkMap.FirewallRulesIsEmpty flag
|
||||
// we have old version of management without rules handling, we should allow all traffic
|
||||
if len(networkMap.FirewallRules) == 0 && !networkMap.FirewallRulesIsEmpty {
|
||||
|
||||
@@ -272,70 +272,3 @@ func TestPortInfoEmpty(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefaultManagerEnableSSHRules(t *testing.T) {
|
||||
networkMap := &mgmProto.NetworkMap{
|
||||
PeerConfig: &mgmProto.PeerConfig{
|
||||
SshConfig: &mgmProto.SSHConfig{
|
||||
SshEnabled: true,
|
||||
},
|
||||
},
|
||||
RemotePeers: []*mgmProto.RemotePeerConfig{
|
||||
{AllowedIps: []string{"10.93.0.1"}},
|
||||
{AllowedIps: []string{"10.93.0.2"}},
|
||||
{AllowedIps: []string{"10.93.0.3"}},
|
||||
},
|
||||
FirewallRules: []*mgmProto.FirewallRule{
|
||||
{
|
||||
PeerIP: "10.93.0.1",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
},
|
||||
{
|
||||
PeerIP: "10.93.0.2",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
},
|
||||
{
|
||||
PeerIP: "10.93.0.3",
|
||||
Direction: mgmProto.RuleDirection_OUT,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_UDP,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
ifaceMock := mocks.NewMockIFaceMapper(ctrl)
|
||||
ifaceMock.EXPECT().IsUserspaceBind().Return(true).AnyTimes()
|
||||
ifaceMock.EXPECT().SetFilter(gomock.Any())
|
||||
network := netip.MustParsePrefix("172.0.0.1/32")
|
||||
|
||||
ifaceMock.EXPECT().Name().Return("lo").AnyTimes()
|
||||
ifaceMock.EXPECT().Address().Return(wgaddr.Address{
|
||||
IP: network.Addr(),
|
||||
Network: network,
|
||||
}).AnyTimes()
|
||||
ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes()
|
||||
|
||||
fw, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false, iface.DefaultMTU)
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
err = fw.Close(nil)
|
||||
require.NoError(t, err)
|
||||
}()
|
||||
|
||||
acl := NewDefaultManager(fw)
|
||||
|
||||
acl.ApplyFiltering(networkMap, false)
|
||||
|
||||
expectedRules := 3
|
||||
if fw.IsStateful() {
|
||||
expectedRules = 3 // 2 inbound rules + SSH rule
|
||||
}
|
||||
assert.Equal(t, expectedRules, len(acl.peerRulesPairs))
|
||||
}
|
||||
|
||||
@@ -128,34 +128,9 @@ func (d *DeviceAuthorizationFlow) RequestAuthInfo(ctx context.Context) (AuthFlow
|
||||
deviceCode.VerificationURIComplete = deviceCode.VerificationURI
|
||||
}
|
||||
|
||||
if d.providerConfig.LoginHint != "" {
|
||||
deviceCode.VerificationURIComplete = appendLoginHint(deviceCode.VerificationURIComplete, d.providerConfig.LoginHint)
|
||||
if deviceCode.VerificationURI != "" {
|
||||
deviceCode.VerificationURI = appendLoginHint(deviceCode.VerificationURI, d.providerConfig.LoginHint)
|
||||
}
|
||||
}
|
||||
|
||||
return deviceCode, err
|
||||
}
|
||||
|
||||
func appendLoginHint(uri, loginHint string) string {
|
||||
if uri == "" || loginHint == "" {
|
||||
return uri
|
||||
}
|
||||
|
||||
parsedURL, err := url.Parse(uri)
|
||||
if err != nil {
|
||||
log.Debugf("failed to parse verification URI for login_hint: %v", err)
|
||||
return uri
|
||||
}
|
||||
|
||||
query := parsedURL.Query()
|
||||
query.Set("login_hint", loginHint)
|
||||
parsedURL.RawQuery = query.Encode()
|
||||
|
||||
return parsedURL.String()
|
||||
}
|
||||
|
||||
func (d *DeviceAuthorizationFlow) requestToken(info AuthFlowInfo) (TokenRequestResponse, error) {
|
||||
form := url.Values{}
|
||||
form.Add("client_id", d.providerConfig.ClientID)
|
||||
|
||||
@@ -66,34 +66,32 @@ func (t TokenInfo) GetTokenToUse() string {
|
||||
// and if that also fails, the authentication process is deemed unsuccessful
|
||||
//
|
||||
// On Linux distros without desktop environment support, it only tries to initialize the Device Code Flow
|
||||
func NewOAuthFlow(ctx context.Context, config *profilemanager.Config, isUnixDesktopClient bool, hint string) (OAuthFlow, error) {
|
||||
func NewOAuthFlow(ctx context.Context, config *profilemanager.Config, isUnixDesktopClient bool) (OAuthFlow, error) {
|
||||
if (runtime.GOOS == "linux" || runtime.GOOS == "freebsd") && !isUnixDesktopClient {
|
||||
return authenticateWithDeviceCodeFlow(ctx, config, hint)
|
||||
return authenticateWithDeviceCodeFlow(ctx, config)
|
||||
}
|
||||
|
||||
pkceFlow, err := authenticateWithPKCEFlow(ctx, config, hint)
|
||||
pkceFlow, err := authenticateWithPKCEFlow(ctx, config)
|
||||
if err != nil {
|
||||
// fallback to device code flow
|
||||
log.Debugf("failed to initialize pkce authentication with error: %v\n", err)
|
||||
log.Debug("falling back to device code flow")
|
||||
return authenticateWithDeviceCodeFlow(ctx, config, hint)
|
||||
return authenticateWithDeviceCodeFlow(ctx, config)
|
||||
}
|
||||
return pkceFlow, nil
|
||||
}
|
||||
|
||||
// authenticateWithPKCEFlow initializes the Proof Key for Code Exchange flow auth flow
|
||||
func authenticateWithPKCEFlow(ctx context.Context, config *profilemanager.Config, hint string) (OAuthFlow, error) {
|
||||
func authenticateWithPKCEFlow(ctx context.Context, config *profilemanager.Config) (OAuthFlow, error) {
|
||||
pkceFlowInfo, err := internal.GetPKCEAuthorizationFlowInfo(ctx, config.PrivateKey, config.ManagementURL, config.ClientCertKeyPair)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("getting pkce authorization flow info failed with error: %v", err)
|
||||
}
|
||||
|
||||
pkceFlowInfo.ProviderConfig.LoginHint = hint
|
||||
|
||||
return NewPKCEAuthorizationFlow(pkceFlowInfo.ProviderConfig)
|
||||
}
|
||||
|
||||
// authenticateWithDeviceCodeFlow initializes the Device Code auth Flow
|
||||
func authenticateWithDeviceCodeFlow(ctx context.Context, config *profilemanager.Config, hint string) (OAuthFlow, error) {
|
||||
func authenticateWithDeviceCodeFlow(ctx context.Context, config *profilemanager.Config) (OAuthFlow, error) {
|
||||
deviceFlowInfo, err := internal.GetDeviceAuthorizationFlowInfo(ctx, config.PrivateKey, config.ManagementURL)
|
||||
if err != nil {
|
||||
switch s, ok := gstatus.FromError(err); {
|
||||
@@ -109,7 +107,5 @@ func authenticateWithDeviceCodeFlow(ctx context.Context, config *profilemanager.
|
||||
}
|
||||
}
|
||||
|
||||
deviceFlowInfo.ProviderConfig.LoginHint = hint
|
||||
|
||||
return NewDeviceAuthorizationFlow(deviceFlowInfo.ProviderConfig)
|
||||
}
|
||||
|
||||
@@ -109,9 +109,6 @@ func (p *PKCEAuthorizationFlow) RequestAuthInfo(ctx context.Context) (AuthFlowIn
|
||||
params = append(params, oauth2.SetAuthURLParam("max_age", "0"))
|
||||
}
|
||||
}
|
||||
if p.providerConfig.LoginHint != "" {
|
||||
params = append(params, oauth2.SetAuthURLParam("login_hint", p.providerConfig.LoginHint))
|
||||
}
|
||||
|
||||
authURL := p.oAuthConfig.AuthCodeURL(state, params...)
|
||||
|
||||
|
||||
@@ -416,20 +416,25 @@ func createEngineConfig(key wgtypes.Key, config *profilemanager.Config, peerConf
|
||||
nm = *config.NetworkMonitor
|
||||
}
|
||||
engineConf := &EngineConfig{
|
||||
WgIfaceName: config.WgIface,
|
||||
WgAddr: peerConfig.Address,
|
||||
IFaceBlackList: config.IFaceBlackList,
|
||||
DisableIPv6Discovery: config.DisableIPv6Discovery,
|
||||
WgPrivateKey: key,
|
||||
WgPort: config.WgPort,
|
||||
NetworkMonitor: nm,
|
||||
SSHKey: []byte(config.SSHKey),
|
||||
NATExternalIPs: config.NATExternalIPs,
|
||||
CustomDNSAddress: config.CustomDNSAddress,
|
||||
RosenpassEnabled: config.RosenpassEnabled,
|
||||
RosenpassPermissive: config.RosenpassPermissive,
|
||||
ServerSSHAllowed: util.ReturnBoolWithDefaultTrue(config.ServerSSHAllowed),
|
||||
DNSRouteInterval: config.DNSRouteInterval,
|
||||
WgIfaceName: config.WgIface,
|
||||
WgAddr: peerConfig.Address,
|
||||
IFaceBlackList: config.IFaceBlackList,
|
||||
DisableIPv6Discovery: config.DisableIPv6Discovery,
|
||||
WgPrivateKey: key,
|
||||
WgPort: config.WgPort,
|
||||
NetworkMonitor: nm,
|
||||
SSHKey: []byte(config.SSHKey),
|
||||
NATExternalIPs: config.NATExternalIPs,
|
||||
CustomDNSAddress: config.CustomDNSAddress,
|
||||
RosenpassEnabled: config.RosenpassEnabled,
|
||||
RosenpassPermissive: config.RosenpassPermissive,
|
||||
ServerSSHAllowed: util.ReturnBoolWithDefaultTrue(config.ServerSSHAllowed),
|
||||
EnableSSHRoot: config.EnableSSHRoot,
|
||||
EnableSSHSFTP: config.EnableSSHSFTP,
|
||||
EnableSSHLocalPortForwarding: config.EnableSSHLocalPortForwarding,
|
||||
EnableSSHRemotePortForwarding: config.EnableSSHRemotePortForwarding,
|
||||
DisableSSHAuth: config.DisableSSHAuth,
|
||||
DNSRouteInterval: config.DNSRouteInterval,
|
||||
|
||||
DisableClientRoutes: config.DisableClientRoutes,
|
||||
DisableServerRoutes: config.DisableServerRoutes || config.BlockInbound,
|
||||
@@ -515,6 +520,11 @@ func loginToManagement(ctx context.Context, client mgm.Client, pubSSHKey []byte,
|
||||
config.BlockLANAccess,
|
||||
config.BlockInbound,
|
||||
config.LazyConnectionEnabled,
|
||||
config.EnableSSHRoot,
|
||||
config.EnableSSHSFTP,
|
||||
config.EnableSSHLocalPortForwarding,
|
||||
config.EnableSSHRemotePortForwarding,
|
||||
config.DisableSSHAuth,
|
||||
)
|
||||
loginResp, err := client.Login(*serverPublicKey, sysInfo, pubSSHKey, config.DNSLabels)
|
||||
if err != nil {
|
||||
|
||||
@@ -44,8 +44,6 @@ interfaces.txt: Anonymized network interface information, if --system-info flag
|
||||
ip_rules.txt: Detailed IP routing rules in tabular format including priority, source, destination, interfaces, table, and action information (Linux only), if --system-info flag was provided.
|
||||
iptables.txt: Anonymized iptables rules with packet counters, if --system-info flag was provided.
|
||||
nftables.txt: Anonymized nftables rules with packet counters, if --system-info flag was provided.
|
||||
resolv.conf: DNS resolver configuration from /etc/resolv.conf (Unix systems only), if --system-info flag was provided.
|
||||
scutil_dns.txt: DNS configuration from scutil --dns (macOS only), if --system-info flag was provided.
|
||||
resolved_domains.txt: Anonymized resolved domain IP addresses from the status recorder.
|
||||
config.txt: Anonymized configuration information of the NetBird client.
|
||||
network_map.json: Anonymized sync response containing peer configurations, routes, DNS settings, and firewall rules.
|
||||
@@ -186,20 +184,6 @@ The ip_rules.txt file contains detailed IP routing rule information:
|
||||
The table format provides comprehensive visibility into the IP routing decision process, including how traffic is directed to different routing tables based on various criteria. This is valuable for troubleshooting advanced routing configurations and policy-based routing.
|
||||
|
||||
For anonymized rules, IP addresses and prefixes are replaced as described above. Interface names are anonymized using string anonymization. Table names, actions, and other non-sensitive information remain unchanged.
|
||||
|
||||
DNS Configuration
|
||||
The debug bundle includes platform-specific DNS configuration files:
|
||||
|
||||
resolv.conf (Unix systems):
|
||||
- Contains DNS resolver configuration from /etc/resolv.conf
|
||||
- Includes nameserver entries, search domains, and resolver options
|
||||
- All IP addresses and domain names are anonymized following the same rules as other files
|
||||
|
||||
scutil_dns.txt (macOS only):
|
||||
- Contains detailed DNS configuration from scutil --dns
|
||||
- Shows DNS configuration for all network interfaces
|
||||
- Includes search domains, nameservers, and DNS resolver settings
|
||||
- All IP addresses and domain names are anonymized
|
||||
`
|
||||
|
||||
const (
|
||||
@@ -373,10 +357,6 @@ func (g *BundleGenerator) addSystemInfo() {
|
||||
if err := g.addFirewallRules(); err != nil {
|
||||
log.Errorf("failed to add firewall rules to debug bundle: %v", err)
|
||||
}
|
||||
|
||||
if err := g.addDNSInfo(); err != nil {
|
||||
log.Errorf("failed to add DNS info to debug bundle: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (g *BundleGenerator) addReadme() error {
|
||||
@@ -453,6 +433,18 @@ func (g *BundleGenerator) addCommonConfigFields(configContent *strings.Builder)
|
||||
if g.internalConfig.ServerSSHAllowed != nil {
|
||||
configContent.WriteString(fmt.Sprintf("ServerSSHAllowed: %v\n", *g.internalConfig.ServerSSHAllowed))
|
||||
}
|
||||
if g.internalConfig.EnableSSHRoot != nil {
|
||||
configContent.WriteString(fmt.Sprintf("EnableSSHRoot: %v\n", *g.internalConfig.EnableSSHRoot))
|
||||
}
|
||||
if g.internalConfig.EnableSSHSFTP != nil {
|
||||
configContent.WriteString(fmt.Sprintf("EnableSSHSFTP: %v\n", *g.internalConfig.EnableSSHSFTP))
|
||||
}
|
||||
if g.internalConfig.EnableSSHLocalPortForwarding != nil {
|
||||
configContent.WriteString(fmt.Sprintf("EnableSSHLocalPortForwarding: %v\n", *g.internalConfig.EnableSSHLocalPortForwarding))
|
||||
}
|
||||
if g.internalConfig.EnableSSHRemotePortForwarding != nil {
|
||||
configContent.WriteString(fmt.Sprintf("EnableSSHRemotePortForwarding: %v\n", *g.internalConfig.EnableSSHRemotePortForwarding))
|
||||
}
|
||||
|
||||
configContent.WriteString(fmt.Sprintf("DisableClientRoutes: %v\n", g.internalConfig.DisableClientRoutes))
|
||||
configContent.WriteString(fmt.Sprintf("DisableServerRoutes: %v\n", g.internalConfig.DisableServerRoutes))
|
||||
|
||||
@@ -1,53 +0,0 @@
|
||||
//go:build darwin && !ios
|
||||
|
||||
package debug
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"os/exec"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// addDNSInfo collects and adds DNS configuration information to the archive
|
||||
func (g *BundleGenerator) addDNSInfo() error {
|
||||
if err := g.addResolvConf(); err != nil {
|
||||
log.Errorf("failed to add resolv.conf: %v", err)
|
||||
}
|
||||
|
||||
if err := g.addScutilDNS(); err != nil {
|
||||
log.Errorf("failed to add scutil DNS output: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (g *BundleGenerator) addScutilDNS() error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
cmd := exec.CommandContext(ctx, "scutil", "--dns")
|
||||
output, err := cmd.CombinedOutput()
|
||||
if err != nil {
|
||||
return fmt.Errorf("execute scutil --dns: %w", err)
|
||||
}
|
||||
|
||||
if len(bytes.TrimSpace(output)) == 0 {
|
||||
return fmt.Errorf("no scutil DNS output")
|
||||
}
|
||||
|
||||
content := string(output)
|
||||
if g.anonymize {
|
||||
content = g.anonymizer.AnonymizeString(content)
|
||||
}
|
||||
|
||||
if err := g.addFileToZip(strings.NewReader(content), "scutil_dns.txt"); err != nil {
|
||||
return fmt.Errorf("add scutil DNS output to zip: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -5,7 +5,3 @@ package debug
|
||||
func (g *BundleGenerator) addRoutes() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (g *BundleGenerator) addDNSInfo() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -1,16 +0,0 @@
|
||||
//go:build unix && !darwin && !android
|
||||
|
||||
package debug
|
||||
|
||||
import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// addDNSInfo collects and adds DNS configuration information to the archive
|
||||
func (g *BundleGenerator) addDNSInfo() error {
|
||||
if err := g.addResolvConf(); err != nil {
|
||||
log.Errorf("failed to add resolv.conf: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -1,7 +0,0 @@
|
||||
//go:build !unix
|
||||
|
||||
package debug
|
||||
|
||||
func (g *BundleGenerator) addDNSInfo() error {
|
||||
return nil
|
||||
}
|
||||
@@ -1,29 +0,0 @@
|
||||
//go:build unix && !android
|
||||
|
||||
package debug
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const resolvConfPath = "/etc/resolv.conf"
|
||||
|
||||
func (g *BundleGenerator) addResolvConf() error {
|
||||
data, err := os.ReadFile(resolvConfPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("read %s: %w", resolvConfPath, err)
|
||||
}
|
||||
|
||||
content := string(data)
|
||||
if g.anonymize {
|
||||
content = g.anonymizer.AnonymizeString(content)
|
||||
}
|
||||
|
||||
if err := g.addFileToZip(strings.NewReader(content), "resolv.conf"); err != nil {
|
||||
return fmt.Errorf("add resolv.conf to zip: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -38,8 +38,6 @@ type DeviceAuthProviderConfig struct {
|
||||
Scope string
|
||||
// UseIDToken indicates if the id token should be used for authentication
|
||||
UseIDToken bool
|
||||
// LoginHint is used to pre-fill the email/username field during authentication
|
||||
LoginHint string
|
||||
}
|
||||
|
||||
// GetDeviceAuthorizationFlowInfo initialize a DeviceAuthorizationFlow instance and return with it
|
||||
|
||||
@@ -9,7 +9,6 @@ import (
|
||||
"net/netip"
|
||||
"net/url"
|
||||
"os"
|
||||
"reflect"
|
||||
"runtime"
|
||||
"slices"
|
||||
"sort"
|
||||
@@ -30,7 +29,6 @@ import (
|
||||
firewallManager "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
nbnetstack "github.com/netbirdio/netbird/client/iface/netstack"
|
||||
"github.com/netbirdio/netbird/client/iface/udpmux"
|
||||
"github.com/netbirdio/netbird/client/internal/acl"
|
||||
"github.com/netbirdio/netbird/client/internal/dns"
|
||||
@@ -51,10 +49,10 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
cProto "github.com/netbirdio/netbird/client/proto"
|
||||
sshconfig "github.com/netbirdio/netbird/client/ssh/config"
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
semaphoregroup "github.com/netbirdio/netbird/util/semaphore-group"
|
||||
|
||||
nbssh "github.com/netbirdio/netbird/client/ssh"
|
||||
"github.com/netbirdio/netbird/client/system"
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
@@ -115,7 +113,12 @@ type EngineConfig struct {
|
||||
RosenpassEnabled bool
|
||||
RosenpassPermissive bool
|
||||
|
||||
ServerSSHAllowed bool
|
||||
ServerSSHAllowed bool
|
||||
EnableSSHRoot *bool
|
||||
EnableSSHSFTP *bool
|
||||
EnableSSHLocalPortForwarding *bool
|
||||
EnableSSHRemotePortForwarding *bool
|
||||
DisableSSHAuth *bool
|
||||
|
||||
DNSRouteInterval time.Duration
|
||||
|
||||
@@ -148,8 +151,6 @@ type Engine struct {
|
||||
|
||||
// syncMsgMux is used to guarantee sequential Management Service message processing
|
||||
syncMsgMux *sync.Mutex
|
||||
// sshMux protects sshServer field access
|
||||
sshMux sync.Mutex
|
||||
|
||||
config *EngineConfig
|
||||
mobileDep MobileDependency
|
||||
@@ -175,8 +176,7 @@ type Engine struct {
|
||||
|
||||
networkMonitor *networkmonitor.NetworkMonitor
|
||||
|
||||
sshServerFunc func(hostKeyPEM []byte, addr string) (nbssh.Server, error)
|
||||
sshServer nbssh.Server
|
||||
sshServer sshServer
|
||||
|
||||
statusRecorder *peer.Status
|
||||
|
||||
@@ -246,7 +246,6 @@ func NewEngine(
|
||||
STUNs: []*stun.URI{},
|
||||
TURNs: []*stun.URI{},
|
||||
networkSerial: 0,
|
||||
sshServerFunc: nbssh.DefaultSSHServer,
|
||||
statusRecorder: statusRecorder,
|
||||
checks: checks,
|
||||
connSemaphore: semaphoregroup.NewSemaphoreGroup(connInitLimit),
|
||||
@@ -268,6 +267,7 @@ func NewEngine(
|
||||
path = mobileDep.StateFilePath
|
||||
}
|
||||
engine.stateManager = statemanager.New(path)
|
||||
engine.stateManager.RegisterState(&sshconfig.ShutdownState{})
|
||||
|
||||
log.Infof("I am: %s", config.WgPrivateKey.PublicKey().String())
|
||||
return engine
|
||||
@@ -292,6 +292,12 @@ func (e *Engine) Stop() error {
|
||||
}
|
||||
log.Info("Network monitor: stopped")
|
||||
|
||||
if err := e.stopSSHServer(); err != nil {
|
||||
log.Warnf("failed to stop SSH server: %v", err)
|
||||
}
|
||||
|
||||
e.cleanupSSHConfig()
|
||||
|
||||
// stop/restore DNS first so dbus and friends don't complain because of a missing interface
|
||||
e.stopDNSServer()
|
||||
|
||||
@@ -703,16 +709,10 @@ func (e *Engine) removeAllPeers() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// removePeer closes an existing peer connection, removes a peer, and clears authorized key of the SSH server
|
||||
// removePeer closes an existing peer connection and removes a peer
|
||||
func (e *Engine) removePeer(peerKey string) error {
|
||||
log.Debugf("removing peer from engine %s", peerKey)
|
||||
|
||||
e.sshMux.Lock()
|
||||
if !isNil(e.sshServer) {
|
||||
e.sshServer.RemoveAuthorizedKey(peerKey)
|
||||
}
|
||||
e.sshMux.Unlock()
|
||||
|
||||
e.connMgr.RemovePeerConn(peerKey)
|
||||
|
||||
err := e.statusRecorder.RemovePeer(peerKey)
|
||||
@@ -884,6 +884,11 @@ func (e *Engine) updateChecksIfNew(checks []*mgmProto.Checks) error {
|
||||
e.config.BlockLANAccess,
|
||||
e.config.BlockInbound,
|
||||
e.config.LazyConnectionEnabled,
|
||||
e.config.EnableSSHRoot,
|
||||
e.config.EnableSSHSFTP,
|
||||
e.config.EnableSSHLocalPortForwarding,
|
||||
e.config.EnableSSHRemotePortForwarding,
|
||||
e.config.DisableSSHAuth,
|
||||
)
|
||||
|
||||
if err := e.mgmClient.SyncMeta(info); err != nil {
|
||||
@@ -893,74 +898,6 @@ func (e *Engine) updateChecksIfNew(checks []*mgmProto.Checks) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func isNil(server nbssh.Server) bool {
|
||||
return server == nil || reflect.ValueOf(server).IsNil()
|
||||
}
|
||||
|
||||
func (e *Engine) updateSSH(sshConf *mgmProto.SSHConfig) error {
|
||||
if e.config.BlockInbound {
|
||||
log.Infof("SSH server is disabled because inbound connections are blocked")
|
||||
return nil
|
||||
}
|
||||
|
||||
if !e.config.ServerSSHAllowed {
|
||||
log.Info("SSH server is not enabled")
|
||||
return nil
|
||||
}
|
||||
|
||||
if sshConf.GetSshEnabled() {
|
||||
if runtime.GOOS == "windows" {
|
||||
log.Warnf("running SSH server on %s is not supported", runtime.GOOS)
|
||||
return nil
|
||||
}
|
||||
e.sshMux.Lock()
|
||||
// start SSH server if it wasn't running
|
||||
if isNil(e.sshServer) {
|
||||
listenAddr := fmt.Sprintf("%s:%d", e.wgInterface.Address().IP.String(), nbssh.DefaultSSHPort)
|
||||
if nbnetstack.IsEnabled() {
|
||||
listenAddr = fmt.Sprintf("127.0.0.1:%d", nbssh.DefaultSSHPort)
|
||||
}
|
||||
// nil sshServer means it has not yet been started
|
||||
server, err := e.sshServerFunc(e.config.SSHKey, listenAddr)
|
||||
if err != nil {
|
||||
e.sshMux.Unlock()
|
||||
return fmt.Errorf("create ssh server: %w", err)
|
||||
}
|
||||
|
||||
e.sshServer = server
|
||||
e.sshMux.Unlock()
|
||||
|
||||
go func() {
|
||||
// blocking
|
||||
err = server.Start()
|
||||
if err != nil {
|
||||
// will throw error when we stop it even if it is a graceful stop
|
||||
log.Debugf("stopped SSH server with error %v", err)
|
||||
}
|
||||
e.sshMux.Lock()
|
||||
e.sshServer = nil
|
||||
e.sshMux.Unlock()
|
||||
log.Infof("stopped SSH server")
|
||||
}()
|
||||
} else {
|
||||
e.sshMux.Unlock()
|
||||
log.Debugf("SSH server is already running")
|
||||
}
|
||||
} else {
|
||||
e.sshMux.Lock()
|
||||
if !isNil(e.sshServer) {
|
||||
// Disable SSH server request, so stop it if it was running
|
||||
err := e.sshServer.Stop()
|
||||
if err != nil {
|
||||
log.Warnf("failed to stop SSH server %v", err)
|
||||
}
|
||||
e.sshServer = nil
|
||||
}
|
||||
e.sshMux.Unlock()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error {
|
||||
if e.wgInterface == nil {
|
||||
return errors.New("wireguard interface is not initialized")
|
||||
@@ -973,8 +910,7 @@ func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error {
|
||||
}
|
||||
|
||||
if conf.GetSshConfig() != nil {
|
||||
err := e.updateSSH(conf.GetSshConfig())
|
||||
if err != nil {
|
||||
if err := e.updateSSH(conf.GetSshConfig()); err != nil {
|
||||
log.Warnf("failed handling SSH server setup: %v", err)
|
||||
}
|
||||
}
|
||||
@@ -1012,6 +948,11 @@ func (e *Engine) receiveManagementEvents() {
|
||||
e.config.BlockLANAccess,
|
||||
e.config.BlockInbound,
|
||||
e.config.LazyConnectionEnabled,
|
||||
e.config.EnableSSHRoot,
|
||||
e.config.EnableSSHSFTP,
|
||||
e.config.EnableSSHLocalPortForwarding,
|
||||
e.config.EnableSSHRemotePortForwarding,
|
||||
e.config.DisableSSHAuth,
|
||||
)
|
||||
|
||||
err = e.mgmClient.Sync(e.ctx, info, e.handleSync)
|
||||
@@ -1170,19 +1111,11 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
|
||||
|
||||
e.statusRecorder.FinishPeerListModifications()
|
||||
|
||||
// update SSHServer by adding remote peer SSH keys
|
||||
e.sshMux.Lock()
|
||||
if !isNil(e.sshServer) {
|
||||
for _, config := range networkMap.GetRemotePeers() {
|
||||
if config.GetSshConfig() != nil && config.GetSshConfig().GetSshPubKey() != nil {
|
||||
err := e.sshServer.AddAuthorizedKey(config.WgPubKey, string(config.GetSshConfig().GetSshPubKey()))
|
||||
if err != nil {
|
||||
log.Warnf("failed adding authorized key to SSH DefaultServer %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
e.updatePeerSSHHostKeys(networkMap.GetRemotePeers())
|
||||
|
||||
if err := e.updateSSHClientConfig(networkMap.GetRemotePeers()); err != nil {
|
||||
log.Warnf("failed to update SSH client config: %v", err)
|
||||
}
|
||||
e.sshMux.Unlock()
|
||||
}
|
||||
|
||||
// must set the exclude list after the peers are added. Without it the manager can not figure out the peers parameters from the store
|
||||
@@ -1544,15 +1477,6 @@ func (e *Engine) close() {
|
||||
e.statusRecorder.SetWgIface(nil)
|
||||
}
|
||||
|
||||
e.sshMux.Lock()
|
||||
if !isNil(e.sshServer) {
|
||||
err := e.sshServer.Stop()
|
||||
if err != nil {
|
||||
log.Warnf("failed stopping the SSH server: %v", err)
|
||||
}
|
||||
}
|
||||
e.sshMux.Unlock()
|
||||
|
||||
if e.firewall != nil {
|
||||
err := e.firewall.Close(e.stateManager)
|
||||
if err != nil {
|
||||
@@ -1583,6 +1507,11 @@ func (e *Engine) readInitialSettings() ([]*route.Route, *nbdns.Config, bool, err
|
||||
e.config.BlockLANAccess,
|
||||
e.config.BlockInbound,
|
||||
e.config.LazyConnectionEnabled,
|
||||
e.config.EnableSSHRoot,
|
||||
e.config.EnableSSHSFTP,
|
||||
e.config.EnableSSHLocalPortForwarding,
|
||||
e.config.EnableSSHRemotePortForwarding,
|
||||
e.config.DisableSSHAuth,
|
||||
)
|
||||
|
||||
netMap, err := e.mgmClient.GetNetworkMap(info)
|
||||
|
||||
338
client/internal/engine_ssh.go
Normal file
338
client/internal/engine_ssh.go
Normal file
@@ -0,0 +1,338 @@
|
||||
package internal
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"strings"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
firewallManager "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
||||
sshconfig "github.com/netbirdio/netbird/client/ssh/config"
|
||||
sshserver "github.com/netbirdio/netbird/client/ssh/server"
|
||||
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
|
||||
)
|
||||
|
||||
type sshServer interface {
|
||||
Start(ctx context.Context, addr netip.AddrPort) error
|
||||
Stop() error
|
||||
}
|
||||
|
||||
func (e *Engine) setupSSHPortRedirection() error {
|
||||
if e.firewall == nil || e.wgInterface == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
localAddr := e.wgInterface.Address().IP
|
||||
if !localAddr.IsValid() {
|
||||
return errors.New("invalid local NetBird address")
|
||||
}
|
||||
|
||||
if err := e.firewall.AddInboundDNAT(localAddr, firewallManager.ProtocolTCP, 22, 22022); err != nil {
|
||||
return fmt.Errorf("add SSH port redirection: %w", err)
|
||||
}
|
||||
log.Infof("SSH port redirection enabled: %s:22 -> %s:22022", localAddr, localAddr)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (e *Engine) updateSSH(sshConf *mgmProto.SSHConfig) error {
|
||||
if e.config.BlockInbound {
|
||||
log.Info("SSH server is disabled because inbound connections are blocked")
|
||||
return e.stopSSHServer()
|
||||
}
|
||||
|
||||
if !e.config.ServerSSHAllowed {
|
||||
log.Info("SSH server is disabled in config")
|
||||
return e.stopSSHServer()
|
||||
}
|
||||
|
||||
if !sshConf.GetSshEnabled() {
|
||||
if e.config.ServerSSHAllowed {
|
||||
log.Info("SSH server is locally allowed but disabled by management server")
|
||||
}
|
||||
return e.stopSSHServer()
|
||||
}
|
||||
|
||||
if e.sshServer != nil {
|
||||
log.Debug("SSH server is already running")
|
||||
return nil
|
||||
}
|
||||
|
||||
if e.config.DisableSSHAuth != nil && *e.config.DisableSSHAuth {
|
||||
log.Info("starting SSH server without JWT authentication (authentication disabled by config)")
|
||||
return e.startSSHServer(nil)
|
||||
}
|
||||
|
||||
if protoJWT := sshConf.GetJwtConfig(); protoJWT != nil {
|
||||
jwtConfig := &sshserver.JWTConfig{
|
||||
Issuer: protoJWT.GetIssuer(),
|
||||
Audience: protoJWT.GetAudience(),
|
||||
KeysLocation: protoJWT.GetKeysLocation(),
|
||||
MaxTokenAge: protoJWT.GetMaxTokenAge(),
|
||||
}
|
||||
|
||||
return e.startSSHServer(jwtConfig)
|
||||
}
|
||||
|
||||
return errors.New("SSH server requires valid JWT configuration")
|
||||
}
|
||||
|
||||
// updateSSHClientConfig updates the SSH client configuration with peer information
|
||||
func (e *Engine) updateSSHClientConfig(remotePeers []*mgmProto.RemotePeerConfig) error {
|
||||
peerInfo := e.extractPeerSSHInfo(remotePeers)
|
||||
if len(peerInfo) == 0 {
|
||||
log.Debug("no SSH-enabled peers found, skipping SSH config update")
|
||||
return nil
|
||||
}
|
||||
|
||||
configMgr := sshconfig.New()
|
||||
if err := configMgr.SetupSSHClientConfig(peerInfo); err != nil {
|
||||
log.Warnf("failed to update SSH client config: %v", err)
|
||||
return nil // Don't fail engine startup on SSH config issues
|
||||
}
|
||||
|
||||
log.Debugf("updated SSH client config with %d peers", len(peerInfo))
|
||||
|
||||
if err := e.stateManager.UpdateState(&sshconfig.ShutdownState{
|
||||
SSHConfigDir: configMgr.GetSSHConfigDir(),
|
||||
SSHConfigFile: configMgr.GetSSHConfigFile(),
|
||||
}); err != nil {
|
||||
log.Warnf("failed to update SSH config state: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// extractPeerSSHInfo extracts SSH information from peer configurations
|
||||
func (e *Engine) extractPeerSSHInfo(remotePeers []*mgmProto.RemotePeerConfig) []sshconfig.PeerSSHInfo {
|
||||
var peerInfo []sshconfig.PeerSSHInfo
|
||||
|
||||
for _, peerConfig := range remotePeers {
|
||||
if peerConfig.GetSshConfig() == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
sshPubKeyBytes := peerConfig.GetSshConfig().GetSshPubKey()
|
||||
if len(sshPubKeyBytes) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
peerIP := e.extractPeerIP(peerConfig)
|
||||
hostname := e.extractHostname(peerConfig)
|
||||
|
||||
peerInfo = append(peerInfo, sshconfig.PeerSSHInfo{
|
||||
Hostname: hostname,
|
||||
IP: peerIP,
|
||||
FQDN: peerConfig.GetFqdn(),
|
||||
})
|
||||
}
|
||||
|
||||
return peerInfo
|
||||
}
|
||||
|
||||
// extractPeerIP extracts IP address from peer's allowed IPs
|
||||
func (e *Engine) extractPeerIP(peerConfig *mgmProto.RemotePeerConfig) string {
|
||||
if len(peerConfig.GetAllowedIps()) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
if prefix, err := netip.ParsePrefix(peerConfig.GetAllowedIps()[0]); err == nil {
|
||||
return prefix.Addr().String()
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// extractHostname extracts short hostname from FQDN
|
||||
func (e *Engine) extractHostname(peerConfig *mgmProto.RemotePeerConfig) string {
|
||||
fqdn := peerConfig.GetFqdn()
|
||||
if fqdn == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
parts := strings.Split(fqdn, ".")
|
||||
if len(parts) > 0 && parts[0] != "" {
|
||||
return parts[0]
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// updatePeerSSHHostKeys updates peer SSH host keys in the status recorder for daemon API access
|
||||
func (e *Engine) updatePeerSSHHostKeys(remotePeers []*mgmProto.RemotePeerConfig) {
|
||||
for _, peerConfig := range remotePeers {
|
||||
if peerConfig.GetSshConfig() == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
sshPubKeyBytes := peerConfig.GetSshConfig().GetSshPubKey()
|
||||
if len(sshPubKeyBytes) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
if err := e.statusRecorder.UpdatePeerSSHHostKey(peerConfig.GetWgPubKey(), sshPubKeyBytes); err != nil {
|
||||
log.Warnf("failed to update SSH host key for peer %s: %v", peerConfig.GetWgPubKey(), err)
|
||||
}
|
||||
}
|
||||
|
||||
log.Debugf("updated peer SSH host keys for daemon API access")
|
||||
}
|
||||
|
||||
// GetPeerSSHKey returns the SSH host key for a specific peer by IP or FQDN
|
||||
func (e *Engine) GetPeerSSHKey(peerAddress string) ([]byte, bool) {
|
||||
e.syncMsgMux.Lock()
|
||||
statusRecorder := e.statusRecorder
|
||||
e.syncMsgMux.Unlock()
|
||||
|
||||
if statusRecorder == nil {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
fullStatus := statusRecorder.GetFullStatus()
|
||||
for _, peerState := range fullStatus.Peers {
|
||||
if peerState.IP == peerAddress || peerState.FQDN == peerAddress {
|
||||
if len(peerState.SSHHostKey) > 0 {
|
||||
return peerState.SSHHostKey, true
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
}
|
||||
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// cleanupSSHConfig removes NetBird SSH client configuration on shutdown
|
||||
func (e *Engine) cleanupSSHConfig() {
|
||||
configMgr := sshconfig.New()
|
||||
|
||||
if err := configMgr.RemoveSSHClientConfig(); err != nil {
|
||||
log.Warnf("failed to remove SSH client config: %v", err)
|
||||
} else {
|
||||
log.Debugf("SSH client config cleanup completed")
|
||||
}
|
||||
}
|
||||
|
||||
// startSSHServer initializes and starts the SSH server with proper configuration.
|
||||
func (e *Engine) startSSHServer(jwtConfig *sshserver.JWTConfig) error {
|
||||
if e.wgInterface == nil {
|
||||
return errors.New("wg interface not initialized")
|
||||
}
|
||||
|
||||
serverConfig := &sshserver.Config{
|
||||
HostKeyPEM: e.config.SSHKey,
|
||||
JWT: jwtConfig,
|
||||
}
|
||||
server := sshserver.New(serverConfig)
|
||||
|
||||
wgAddr := e.wgInterface.Address()
|
||||
server.SetNetworkValidation(wgAddr)
|
||||
|
||||
netbirdIP := wgAddr.IP
|
||||
listenAddr := netip.AddrPortFrom(netbirdIP, sshserver.InternalSSHPort)
|
||||
|
||||
if netstackNet := e.wgInterface.GetNet(); netstackNet != nil {
|
||||
server.SetNetstackNet(netstackNet)
|
||||
|
||||
if registrar, ok := e.firewall.(interface {
|
||||
RegisterNetstackService(protocol nftypes.Protocol, port uint16)
|
||||
}); ok {
|
||||
registrar.RegisterNetstackService(nftypes.TCP, sshserver.InternalSSHPort)
|
||||
log.Debugf("registered SSH service with netstack for TCP:%d", sshserver.InternalSSHPort)
|
||||
}
|
||||
}
|
||||
|
||||
e.configureSSHServer(server)
|
||||
e.sshServer = server
|
||||
|
||||
if err := e.setupSSHPortRedirection(); err != nil {
|
||||
log.Warnf("failed to setup SSH port redirection: %v", err)
|
||||
}
|
||||
|
||||
if err := server.Start(e.ctx, listenAddr); err != nil {
|
||||
return fmt.Errorf("start SSH server: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// configureSSHServer applies SSH configuration options to the server.
|
||||
func (e *Engine) configureSSHServer(server *sshserver.Server) {
|
||||
if e.config.EnableSSHRoot != nil && *e.config.EnableSSHRoot {
|
||||
server.SetAllowRootLogin(true)
|
||||
log.Info("SSH root login enabled")
|
||||
} else {
|
||||
server.SetAllowRootLogin(false)
|
||||
log.Info("SSH root login disabled (default)")
|
||||
}
|
||||
|
||||
if e.config.EnableSSHSFTP != nil && *e.config.EnableSSHSFTP {
|
||||
server.SetAllowSFTP(true)
|
||||
log.Info("SSH SFTP subsystem enabled")
|
||||
} else {
|
||||
server.SetAllowSFTP(false)
|
||||
log.Info("SSH SFTP subsystem disabled (default)")
|
||||
}
|
||||
|
||||
if e.config.EnableSSHLocalPortForwarding != nil && *e.config.EnableSSHLocalPortForwarding {
|
||||
server.SetAllowLocalPortForwarding(true)
|
||||
log.Info("SSH local port forwarding enabled")
|
||||
} else {
|
||||
server.SetAllowLocalPortForwarding(false)
|
||||
log.Info("SSH local port forwarding disabled (default)")
|
||||
}
|
||||
|
||||
if e.config.EnableSSHRemotePortForwarding != nil && *e.config.EnableSSHRemotePortForwarding {
|
||||
server.SetAllowRemotePortForwarding(true)
|
||||
log.Info("SSH remote port forwarding enabled")
|
||||
} else {
|
||||
server.SetAllowRemotePortForwarding(false)
|
||||
log.Info("SSH remote port forwarding disabled (default)")
|
||||
}
|
||||
}
|
||||
|
||||
func (e *Engine) cleanupSSHPortRedirection() error {
|
||||
if e.firewall == nil || e.wgInterface == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
localAddr := e.wgInterface.Address().IP
|
||||
if !localAddr.IsValid() {
|
||||
return errors.New("invalid local NetBird address")
|
||||
}
|
||||
|
||||
if err := e.firewall.RemoveInboundDNAT(localAddr, firewallManager.ProtocolTCP, 22, 22022); err != nil {
|
||||
return fmt.Errorf("remove SSH port redirection: %w", err)
|
||||
}
|
||||
log.Debugf("SSH port redirection removed: %s:22 -> %s:22022", localAddr, localAddr)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (e *Engine) stopSSHServer() error {
|
||||
if e.sshServer == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := e.cleanupSSHPortRedirection(); err != nil {
|
||||
log.Warnf("failed to cleanup SSH port redirection: %v", err)
|
||||
}
|
||||
|
||||
if netstackNet := e.wgInterface.GetNet(); netstackNet != nil {
|
||||
if registrar, ok := e.firewall.(interface {
|
||||
UnregisterNetstackService(protocol nftypes.Protocol, port uint16)
|
||||
}); ok {
|
||||
registrar.UnregisterNetstackService(nftypes.TCP, sshserver.InternalSSHPort)
|
||||
log.Debugf("unregistered SSH service from netstack for TCP:%d", sshserver.InternalSSHPort)
|
||||
}
|
||||
}
|
||||
|
||||
log.Info("stopping SSH server")
|
||||
err := e.sshServer.Stop()
|
||||
e.sshServer = nil
|
||||
if err != nil {
|
||||
return fmt.Errorf("stop: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -26,9 +26,6 @@ import (
|
||||
"google.golang.org/grpc/keepalive"
|
||||
|
||||
"github.com/netbirdio/management-integrations/integrations"
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
|
||||
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/server/config"
|
||||
"github.com/netbirdio/netbird/management/server/groups"
|
||||
@@ -46,7 +43,7 @@ import (
|
||||
icemaker "github.com/netbirdio/netbird/client/internal/peer/ice"
|
||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager"
|
||||
"github.com/netbirdio/netbird/client/ssh"
|
||||
nbssh "github.com/netbirdio/netbird/client/ssh"
|
||||
"github.com/netbirdio/netbird/client/system"
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/management/server"
|
||||
@@ -214,11 +211,13 @@ func TestMain(m *testing.M) {
|
||||
}
|
||||
|
||||
func TestEngine_SSH(t *testing.T) {
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("skipping TestEngine_SSH")
|
||||
key, err := wgtypes.GeneratePrivateKey()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
}
|
||||
|
||||
key, err := wgtypes.GeneratePrivateKey()
|
||||
sshKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
@@ -240,6 +239,7 @@ func TestEngine_SSH(t *testing.T) {
|
||||
WgPort: 33100,
|
||||
ServerSSHAllowed: true,
|
||||
MTU: iface.DefaultMTU,
|
||||
SSHKey: sshKey,
|
||||
},
|
||||
MobileDependency{},
|
||||
peer.NewRecorder("https://mgm"),
|
||||
@@ -250,35 +250,8 @@ func TestEngine_SSH(t *testing.T) {
|
||||
UpdateDNSServerFunc: func(serial uint64, update nbdns.Config) error { return nil },
|
||||
}
|
||||
|
||||
var sshKeysAdded []string
|
||||
var sshPeersRemoved []string
|
||||
|
||||
sshCtx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
engine.sshServerFunc = func(hostKeyPEM []byte, addr string) (ssh.Server, error) {
|
||||
return &ssh.MockServer{
|
||||
Ctx: sshCtx,
|
||||
StopFunc: func() error {
|
||||
cancel()
|
||||
return nil
|
||||
},
|
||||
StartFunc: func() error {
|
||||
<-ctx.Done()
|
||||
return ctx.Err()
|
||||
},
|
||||
AddAuthorizedKeyFunc: func(peer, newKey string) error {
|
||||
sshKeysAdded = append(sshKeysAdded, newKey)
|
||||
return nil
|
||||
},
|
||||
RemoveAuthorizedKeyFunc: func(peer string) {
|
||||
sshPeersRemoved = append(sshPeersRemoved, peer)
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
err = engine.Start(nil, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(t, err)
|
||||
|
||||
defer func() {
|
||||
err := engine.Stop()
|
||||
@@ -304,9 +277,7 @@ func TestEngine_SSH(t *testing.T) {
|
||||
}
|
||||
|
||||
err = engine.updateNetworkMap(networkMap)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Nil(t, engine.sshServer)
|
||||
|
||||
@@ -314,19 +285,24 @@ func TestEngine_SSH(t *testing.T) {
|
||||
networkMap = &mgmtProto.NetworkMap{
|
||||
Serial: 7,
|
||||
PeerConfig: &mgmtProto.PeerConfig{Address: "100.64.0.1/24",
|
||||
SshConfig: &mgmtProto.SSHConfig{SshEnabled: true}},
|
||||
SshConfig: &mgmtProto.SSHConfig{
|
||||
SshEnabled: true,
|
||||
JwtConfig: &mgmtProto.JWTConfig{
|
||||
Issuer: "test-issuer",
|
||||
Audience: "test-audience",
|
||||
KeysLocation: "test-keys",
|
||||
MaxTokenAge: 3600,
|
||||
},
|
||||
}},
|
||||
RemotePeers: []*mgmtProto.RemotePeerConfig{peerWithSSH},
|
||||
RemotePeersIsEmpty: false,
|
||||
}
|
||||
|
||||
err = engine.updateNetworkMap(networkMap)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(t, err)
|
||||
|
||||
time.Sleep(250 * time.Millisecond)
|
||||
assert.NotNil(t, engine.sshServer)
|
||||
assert.Contains(t, sshKeysAdded, "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIFATYCqaQw/9id1Qkq3n16JYhDhXraI6Pc1fgB8ynEfQ")
|
||||
|
||||
// now remove peer
|
||||
networkMap = &mgmtProto.NetworkMap{
|
||||
@@ -336,13 +312,10 @@ func TestEngine_SSH(t *testing.T) {
|
||||
}
|
||||
|
||||
err = engine.updateNetworkMap(networkMap)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(t, err)
|
||||
|
||||
// time.Sleep(250 * time.Millisecond)
|
||||
assert.NotNil(t, engine.sshServer)
|
||||
assert.Contains(t, sshPeersRemoved, "MNHf3Ma6z6mdLbriAJbqhX7+nM/B71lgw2+91q3LfhU=")
|
||||
|
||||
// now disable SSH server
|
||||
networkMap = &mgmtProto.NetworkMap{
|
||||
@@ -354,12 +327,70 @@ func TestEngine_SSH(t *testing.T) {
|
||||
}
|
||||
|
||||
err = engine.updateNetworkMap(networkMap)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Nil(t, engine.sshServer)
|
||||
}
|
||||
|
||||
func TestEngine_SSHUpdateLogic(t *testing.T) {
|
||||
// Test that SSH server start/stop logic works based on config
|
||||
engine := &Engine{
|
||||
config: &EngineConfig{
|
||||
ServerSSHAllowed: false, // Start with SSH disabled
|
||||
},
|
||||
syncMsgMux: &sync.Mutex{},
|
||||
}
|
||||
|
||||
// Test SSH disabled config
|
||||
sshConfig := &mgmtProto.SSHConfig{SshEnabled: false}
|
||||
err := engine.updateSSH(sshConfig)
|
||||
assert.NoError(t, err)
|
||||
assert.Nil(t, engine.sshServer)
|
||||
|
||||
// Test inbound blocked
|
||||
engine.config.BlockInbound = true
|
||||
err = engine.updateSSH(&mgmtProto.SSHConfig{SshEnabled: true})
|
||||
assert.NoError(t, err)
|
||||
assert.Nil(t, engine.sshServer)
|
||||
engine.config.BlockInbound = false
|
||||
|
||||
// Test with server SSH not allowed
|
||||
err = engine.updateSSH(&mgmtProto.SSHConfig{SshEnabled: true})
|
||||
assert.NoError(t, err)
|
||||
assert.Nil(t, engine.sshServer)
|
||||
}
|
||||
|
||||
func TestEngine_SSHServerConsistency(t *testing.T) {
|
||||
|
||||
t.Run("server set only on successful creation", func(t *testing.T) {
|
||||
engine := &Engine{
|
||||
config: &EngineConfig{
|
||||
ServerSSHAllowed: true,
|
||||
SSHKey: []byte("test-key"),
|
||||
},
|
||||
syncMsgMux: &sync.Mutex{},
|
||||
}
|
||||
|
||||
engine.wgInterface = nil
|
||||
|
||||
err := engine.updateSSH(&mgmtProto.SSHConfig{SshEnabled: true})
|
||||
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, engine.sshServer)
|
||||
})
|
||||
|
||||
t.Run("cleanup handles nil gracefully", func(t *testing.T) {
|
||||
engine := &Engine{
|
||||
config: &EngineConfig{
|
||||
ServerSSHAllowed: false,
|
||||
},
|
||||
syncMsgMux: &sync.Mutex{},
|
||||
}
|
||||
|
||||
err := engine.stopSSHServer()
|
||||
assert.NoError(t, err)
|
||||
assert.Nil(t, engine.sshServer)
|
||||
})
|
||||
}
|
||||
|
||||
func TestEngine_UpdateNetworkMap(t *testing.T) {
|
||||
@@ -1559,6 +1590,7 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri
|
||||
}
|
||||
t.Cleanup(cleanUp)
|
||||
|
||||
peersUpdateManager := server.NewPeersUpdateManager(nil)
|
||||
eventStore := &activity.InMemoryEventStore{}
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
@@ -1586,16 +1618,13 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri
|
||||
|
||||
groupsManager := groups.NewManagerMock()
|
||||
|
||||
updateManager := update_channel.NewPeersUpdateManager(metrics)
|
||||
requestBuffer := server.NewAccountRequestBuffer(context.Background(), store)
|
||||
networkMapController := controller.NewController(context.Background(), store, metrics, updateManager, requestBuffer, server.MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock())
|
||||
accountManager, err := server.BuildManager(context.Background(), store, networkMapController, nil, "", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
|
||||
accountManager, err := server.BuildManager(context.Background(), config, store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
secretsManager := nbgrpc.NewTimeBasedAuthSecretsManager(updateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
|
||||
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, updateManager, secretsManager, nil, &manager.EphemeralManager{}, nil, &server.MockIntegratedValidator{}, networkMapController)
|
||||
secretsManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
|
||||
mgmtServer, err := server.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, &manager.EphemeralManager{}, nil, &server.MockIntegratedValidator{})
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
@@ -124,6 +124,11 @@ func doMgmLogin(ctx context.Context, mgmClient *mgm.GrpcClient, pubSSHKey []byte
|
||||
config.BlockLANAccess,
|
||||
config.BlockInbound,
|
||||
config.LazyConnectionEnabled,
|
||||
config.EnableSSHRoot,
|
||||
config.EnableSSHSFTP,
|
||||
config.EnableSSHLocalPortForwarding,
|
||||
config.EnableSSHRemotePortForwarding,
|
||||
config.DisableSSHAuth,
|
||||
)
|
||||
loginResp, err := mgmClient.Login(*serverKey, sysInfo, pubSSHKey, config.DNSLabels)
|
||||
return serverKey, loginResp, err
|
||||
@@ -150,6 +155,11 @@ func registerPeer(ctx context.Context, serverPublicKey wgtypes.Key, client *mgm.
|
||||
config.BlockLANAccess,
|
||||
config.BlockInbound,
|
||||
config.LazyConnectionEnabled,
|
||||
config.EnableSSHRoot,
|
||||
config.EnableSSHSFTP,
|
||||
config.EnableSSHLocalPortForwarding,
|
||||
config.EnableSSHRemotePortForwarding,
|
||||
config.DisableSSHAuth,
|
||||
)
|
||||
loginResp, err := client.Register(serverPublicKey, validSetupKey.String(), jwtToken, info, pubSSHKey, config.DNSLabels)
|
||||
if err != nil {
|
||||
|
||||
@@ -666,7 +666,7 @@ func (conn *Conn) isConnectedOnAllWay() (connected bool) {
|
||||
}
|
||||
}()
|
||||
|
||||
if conn.statusICE.Get() == worker.StatusDisconnected && !conn.workerICE.InProgress() {
|
||||
if runtime.GOOS != "js" && conn.statusICE.Get() == worker.StatusDisconnected && !conn.workerICE.InProgress() {
|
||||
return false
|
||||
}
|
||||
|
||||
|
||||
@@ -2,6 +2,7 @@ package peer
|
||||
|
||||
import (
|
||||
"os"
|
||||
"runtime"
|
||||
"strings"
|
||||
)
|
||||
|
||||
@@ -10,5 +11,8 @@ const (
|
||||
)
|
||||
|
||||
func isForceRelayed() bool {
|
||||
if runtime.GOOS == "js" {
|
||||
return true
|
||||
}
|
||||
return strings.EqualFold(os.Getenv(EnvKeyNBForceRelay), "true")
|
||||
}
|
||||
|
||||
@@ -21,9 +21,9 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal/ingressgw"
|
||||
"github.com/netbirdio/netbird/client/internal/relay"
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
relayClient "github.com/netbirdio/netbird/shared/relay/client"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
)
|
||||
|
||||
const eventQueueSize = 10
|
||||
@@ -67,6 +67,7 @@ type State struct {
|
||||
BytesRx int64
|
||||
Latency time.Duration
|
||||
RosenpassEnabled bool
|
||||
SSHHostKey []byte
|
||||
routes map[string]struct{}
|
||||
}
|
||||
|
||||
@@ -572,6 +573,22 @@ func (d *Status) UpdatePeerFQDN(peerPubKey, fqdn string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpdatePeerSSHHostKey updates peer's SSH host key
|
||||
func (d *Status) UpdatePeerSSHHostKey(peerPubKey string, sshHostKey []byte) error {
|
||||
d.mux.Lock()
|
||||
defer d.mux.Unlock()
|
||||
|
||||
peerState, ok := d.peers[peerPubKey]
|
||||
if !ok {
|
||||
return errors.New("peer doesn't exist")
|
||||
}
|
||||
|
||||
peerState.SSHHostKey = sshHostKey
|
||||
d.peers[peerPubKey] = peerState
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// FinishPeerListModifications this event invoke the notification
|
||||
func (d *Status) FinishPeerListModifications() {
|
||||
d.mux.Lock()
|
||||
|
||||
@@ -411,7 +411,7 @@ func (w *WorkerICE) onConnectionStateChange(agent *icemaker.ThreadSafeAgent, dia
|
||||
|
||||
func (w *WorkerICE) turnAgentDial(ctx context.Context, agent *icemaker.ThreadSafeAgent, remoteOfferAnswer *OfferAnswer) (*ice.Conn, error) {
|
||||
if isController(w.config) {
|
||||
return agent.Dial(ctx, remoteOfferAnswer.IceCredentials.UFrag, remoteOfferAnswer.IceCredentials.Pwd)
|
||||
return w.agent.Dial(ctx, remoteOfferAnswer.IceCredentials.UFrag, remoteOfferAnswer.IceCredentials.Pwd)
|
||||
} else {
|
||||
return agent.Accept(ctx, remoteOfferAnswer.IceCredentials.UFrag, remoteOfferAnswer.IceCredentials.Pwd)
|
||||
}
|
||||
|
||||
@@ -44,8 +44,6 @@ type PKCEAuthProviderConfig struct {
|
||||
DisablePromptLogin bool
|
||||
// LoginFlag is used to configure the PKCE flow login behavior
|
||||
LoginFlag common.LoginFlag
|
||||
// LoginHint is used to pre-fill the email/username field during authentication
|
||||
LoginHint string
|
||||
}
|
||||
|
||||
// GetPKCEAuthorizationFlowInfo initialize a PKCEAuthorizationFlow instance and return with it
|
||||
|
||||
@@ -44,24 +44,29 @@ var DefaultInterfaceBlacklist = []string{
|
||||
|
||||
// ConfigInput carries configuration changes to the client
|
||||
type ConfigInput struct {
|
||||
ManagementURL string
|
||||
AdminURL string
|
||||
ConfigPath string
|
||||
StateFilePath string
|
||||
PreSharedKey *string
|
||||
ServerSSHAllowed *bool
|
||||
NATExternalIPs []string
|
||||
CustomDNSAddress []byte
|
||||
RosenpassEnabled *bool
|
||||
RosenpassPermissive *bool
|
||||
InterfaceName *string
|
||||
WireguardPort *int
|
||||
NetworkMonitor *bool
|
||||
DisableAutoConnect *bool
|
||||
ExtraIFaceBlackList []string
|
||||
DNSRouteInterval *time.Duration
|
||||
ClientCertPath string
|
||||
ClientCertKeyPath string
|
||||
ManagementURL string
|
||||
AdminURL string
|
||||
ConfigPath string
|
||||
StateFilePath string
|
||||
PreSharedKey *string
|
||||
ServerSSHAllowed *bool
|
||||
EnableSSHRoot *bool
|
||||
EnableSSHSFTP *bool
|
||||
EnableSSHLocalPortForwarding *bool
|
||||
EnableSSHRemotePortForwarding *bool
|
||||
DisableSSHAuth *bool
|
||||
NATExternalIPs []string
|
||||
CustomDNSAddress []byte
|
||||
RosenpassEnabled *bool
|
||||
RosenpassPermissive *bool
|
||||
InterfaceName *string
|
||||
WireguardPort *int
|
||||
NetworkMonitor *bool
|
||||
DisableAutoConnect *bool
|
||||
ExtraIFaceBlackList []string
|
||||
DNSRouteInterval *time.Duration
|
||||
ClientCertPath string
|
||||
ClientCertKeyPath string
|
||||
|
||||
DisableClientRoutes *bool
|
||||
DisableServerRoutes *bool
|
||||
@@ -82,18 +87,23 @@ type ConfigInput struct {
|
||||
// Config Configuration type
|
||||
type Config struct {
|
||||
// Wireguard private key of local peer
|
||||
PrivateKey string
|
||||
PreSharedKey string
|
||||
ManagementURL *url.URL
|
||||
AdminURL *url.URL
|
||||
WgIface string
|
||||
WgPort int
|
||||
NetworkMonitor *bool
|
||||
IFaceBlackList []string
|
||||
DisableIPv6Discovery bool
|
||||
RosenpassEnabled bool
|
||||
RosenpassPermissive bool
|
||||
ServerSSHAllowed *bool
|
||||
PrivateKey string
|
||||
PreSharedKey string
|
||||
ManagementURL *url.URL
|
||||
AdminURL *url.URL
|
||||
WgIface string
|
||||
WgPort int
|
||||
NetworkMonitor *bool
|
||||
IFaceBlackList []string
|
||||
DisableIPv6Discovery bool
|
||||
RosenpassEnabled bool
|
||||
RosenpassPermissive bool
|
||||
ServerSSHAllowed *bool
|
||||
EnableSSHRoot *bool
|
||||
EnableSSHSFTP *bool
|
||||
EnableSSHLocalPortForwarding *bool
|
||||
EnableSSHRemotePortForwarding *bool
|
||||
DisableSSHAuth *bool
|
||||
|
||||
DisableClientRoutes bool
|
||||
DisableServerRoutes bool
|
||||
@@ -376,6 +386,56 @@ func (config *Config) apply(input ConfigInput) (updated bool, err error) {
|
||||
updated = true
|
||||
}
|
||||
|
||||
if input.EnableSSHRoot != nil && input.EnableSSHRoot != config.EnableSSHRoot {
|
||||
if *input.EnableSSHRoot {
|
||||
log.Infof("enabling SSH root login")
|
||||
} else {
|
||||
log.Infof("disabling SSH root login")
|
||||
}
|
||||
config.EnableSSHRoot = input.EnableSSHRoot
|
||||
updated = true
|
||||
}
|
||||
|
||||
if input.EnableSSHSFTP != nil && input.EnableSSHSFTP != config.EnableSSHSFTP {
|
||||
if *input.EnableSSHSFTP {
|
||||
log.Infof("enabling SSH SFTP subsystem")
|
||||
} else {
|
||||
log.Infof("disabling SSH SFTP subsystem")
|
||||
}
|
||||
config.EnableSSHSFTP = input.EnableSSHSFTP
|
||||
updated = true
|
||||
}
|
||||
|
||||
if input.EnableSSHLocalPortForwarding != nil && input.EnableSSHLocalPortForwarding != config.EnableSSHLocalPortForwarding {
|
||||
if *input.EnableSSHLocalPortForwarding {
|
||||
log.Infof("enabling SSH local port forwarding")
|
||||
} else {
|
||||
log.Infof("disabling SSH local port forwarding")
|
||||
}
|
||||
config.EnableSSHLocalPortForwarding = input.EnableSSHLocalPortForwarding
|
||||
updated = true
|
||||
}
|
||||
|
||||
if input.EnableSSHRemotePortForwarding != nil && input.EnableSSHRemotePortForwarding != config.EnableSSHRemotePortForwarding {
|
||||
if *input.EnableSSHRemotePortForwarding {
|
||||
log.Infof("enabling SSH remote port forwarding")
|
||||
} else {
|
||||
log.Infof("disabling SSH remote port forwarding")
|
||||
}
|
||||
config.EnableSSHRemotePortForwarding = input.EnableSSHRemotePortForwarding
|
||||
updated = true
|
||||
}
|
||||
|
||||
if input.DisableSSHAuth != nil && input.DisableSSHAuth != config.DisableSSHAuth {
|
||||
if *input.DisableSSHAuth {
|
||||
log.Infof("disabling SSH authentication")
|
||||
} else {
|
||||
log.Infof("enabling SSH authentication")
|
||||
}
|
||||
config.DisableSSHAuth = input.DisableSSHAuth
|
||||
updated = true
|
||||
}
|
||||
|
||||
if input.DNSRouteInterval != nil && *input.DNSRouteInterval != config.DNSRouteInterval {
|
||||
log.Infof("updating DNS route interval to %s (old value %s)",
|
||||
input.DNSRouteInterval.String(), config.DNSRouteInterval.String())
|
||||
|
||||
@@ -193,10 +193,10 @@ func TestWireguardPortZeroExplicit(t *testing.T) {
|
||||
|
||||
func TestWireguardPortDefaultVsExplicit(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
wireguardPort *int
|
||||
expectedPort int
|
||||
description string
|
||||
name string
|
||||
wireguardPort *int
|
||||
expectedPort int
|
||||
description string
|
||||
}{
|
||||
{
|
||||
name: "no port specified uses default",
|
||||
|
||||
@@ -18,8 +18,8 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/iface"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/util"
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
)
|
||||
|
||||
const (
|
||||
|
||||
@@ -20,8 +20,8 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||
"github.com/netbirdio/netbird/client/system"
|
||||
"github.com/netbirdio/netbird/formatter"
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
)
|
||||
|
||||
// ConnectionListener export internal Listener for mobile
|
||||
@@ -228,7 +228,7 @@ func (c *Client) LoginForMobile() string {
|
||||
ConfigPath: c.cfgFile,
|
||||
})
|
||||
|
||||
oAuthFlow, err := auth.NewOAuthFlow(ctx, cfg, false, "")
|
||||
oAuthFlow, err := auth.NewOAuthFlow(ctx, cfg, false)
|
||||
if err != nil {
|
||||
return err.Error()
|
||||
}
|
||||
|
||||
@@ -273,16 +273,19 @@ type LoginRequest struct {
|
||||
// cleanDNSLabels clean map list of DNS labels.
|
||||
// This is needed because the generated code
|
||||
// omits initialized empty slices due to omitempty tags
|
||||
CleanDNSLabels bool `protobuf:"varint,27,opt,name=cleanDNSLabels,proto3" json:"cleanDNSLabels,omitempty"`
|
||||
LazyConnectionEnabled *bool `protobuf:"varint,28,opt,name=lazyConnectionEnabled,proto3,oneof" json:"lazyConnectionEnabled,omitempty"`
|
||||
BlockInbound *bool `protobuf:"varint,29,opt,name=block_inbound,json=blockInbound,proto3,oneof" json:"block_inbound,omitempty"`
|
||||
ProfileName *string `protobuf:"bytes,30,opt,name=profileName,proto3,oneof" json:"profileName,omitempty"`
|
||||
Username *string `protobuf:"bytes,31,opt,name=username,proto3,oneof" json:"username,omitempty"`
|
||||
Mtu *int64 `protobuf:"varint,32,opt,name=mtu,proto3,oneof" json:"mtu,omitempty"`
|
||||
// hint is used to pre-fill the email/username field during SSO authentication
|
||||
Hint *string `protobuf:"bytes,33,opt,name=hint,proto3,oneof" json:"hint,omitempty"`
|
||||
unknownFields protoimpl.UnknownFields
|
||||
sizeCache protoimpl.SizeCache
|
||||
CleanDNSLabels bool `protobuf:"varint,27,opt,name=cleanDNSLabels,proto3" json:"cleanDNSLabels,omitempty"`
|
||||
LazyConnectionEnabled *bool `protobuf:"varint,28,opt,name=lazyConnectionEnabled,proto3,oneof" json:"lazyConnectionEnabled,omitempty"`
|
||||
BlockInbound *bool `protobuf:"varint,29,opt,name=block_inbound,json=blockInbound,proto3,oneof" json:"block_inbound,omitempty"`
|
||||
ProfileName *string `protobuf:"bytes,30,opt,name=profileName,proto3,oneof" json:"profileName,omitempty"`
|
||||
Username *string `protobuf:"bytes,31,opt,name=username,proto3,oneof" json:"username,omitempty"`
|
||||
Mtu *int64 `protobuf:"varint,32,opt,name=mtu,proto3,oneof" json:"mtu,omitempty"`
|
||||
EnableSSHRoot *bool `protobuf:"varint,33,opt,name=enableSSHRoot,proto3,oneof" json:"enableSSHRoot,omitempty"`
|
||||
EnableSSHSFTP *bool `protobuf:"varint,34,opt,name=enableSSHSFTP,proto3,oneof" json:"enableSSHSFTP,omitempty"`
|
||||
EnableSSHLocalPortForwarding *bool `protobuf:"varint,35,opt,name=enableSSHLocalPortForwarding,proto3,oneof" json:"enableSSHLocalPortForwarding,omitempty"`
|
||||
EnableSSHRemotePortForwarding *bool `protobuf:"varint,36,opt,name=enableSSHRemotePortForwarding,proto3,oneof" json:"enableSSHRemotePortForwarding,omitempty"`
|
||||
DisableSSHAuth *bool `protobuf:"varint,37,opt,name=disableSSHAuth,proto3,oneof" json:"disableSSHAuth,omitempty"`
|
||||
unknownFields protoimpl.UnknownFields
|
||||
sizeCache protoimpl.SizeCache
|
||||
}
|
||||
|
||||
func (x *LoginRequest) Reset() {
|
||||
@@ -540,11 +543,39 @@ func (x *LoginRequest) GetMtu() int64 {
|
||||
return 0
|
||||
}
|
||||
|
||||
func (x *LoginRequest) GetHint() string {
|
||||
if x != nil && x.Hint != nil {
|
||||
return *x.Hint
|
||||
func (x *LoginRequest) GetEnableSSHRoot() bool {
|
||||
if x != nil && x.EnableSSHRoot != nil {
|
||||
return *x.EnableSSHRoot
|
||||
}
|
||||
return ""
|
||||
return false
|
||||
}
|
||||
|
||||
func (x *LoginRequest) GetEnableSSHSFTP() bool {
|
||||
if x != nil && x.EnableSSHSFTP != nil {
|
||||
return *x.EnableSSHSFTP
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (x *LoginRequest) GetEnableSSHLocalPortForwarding() bool {
|
||||
if x != nil && x.EnableSSHLocalPortForwarding != nil {
|
||||
return *x.EnableSSHLocalPortForwarding
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (x *LoginRequest) GetEnableSSHRemotePortForwarding() bool {
|
||||
if x != nil && x.EnableSSHRemotePortForwarding != nil {
|
||||
return *x.EnableSSHRemotePortForwarding
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (x *LoginRequest) GetDisableSSHAuth() bool {
|
||||
if x != nil && x.DisableSSHAuth != nil {
|
||||
return *x.DisableSSHAuth
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
type LoginResponse struct {
|
||||
@@ -1057,24 +1088,29 @@ type GetConfigResponse struct {
|
||||
// preSharedKey settings value.
|
||||
PreSharedKey string `protobuf:"bytes,4,opt,name=preSharedKey,proto3" json:"preSharedKey,omitempty"`
|
||||
// adminURL settings value.
|
||||
AdminURL string `protobuf:"bytes,5,opt,name=adminURL,proto3" json:"adminURL,omitempty"`
|
||||
InterfaceName string `protobuf:"bytes,6,opt,name=interfaceName,proto3" json:"interfaceName,omitempty"`
|
||||
WireguardPort int64 `protobuf:"varint,7,opt,name=wireguardPort,proto3" json:"wireguardPort,omitempty"`
|
||||
Mtu int64 `protobuf:"varint,8,opt,name=mtu,proto3" json:"mtu,omitempty"`
|
||||
DisableAutoConnect bool `protobuf:"varint,9,opt,name=disableAutoConnect,proto3" json:"disableAutoConnect,omitempty"`
|
||||
ServerSSHAllowed bool `protobuf:"varint,10,opt,name=serverSSHAllowed,proto3" json:"serverSSHAllowed,omitempty"`
|
||||
RosenpassEnabled bool `protobuf:"varint,11,opt,name=rosenpassEnabled,proto3" json:"rosenpassEnabled,omitempty"`
|
||||
RosenpassPermissive bool `protobuf:"varint,12,opt,name=rosenpassPermissive,proto3" json:"rosenpassPermissive,omitempty"`
|
||||
DisableNotifications bool `protobuf:"varint,13,opt,name=disable_notifications,json=disableNotifications,proto3" json:"disable_notifications,omitempty"`
|
||||
LazyConnectionEnabled bool `protobuf:"varint,14,opt,name=lazyConnectionEnabled,proto3" json:"lazyConnectionEnabled,omitempty"`
|
||||
BlockInbound bool `protobuf:"varint,15,opt,name=blockInbound,proto3" json:"blockInbound,omitempty"`
|
||||
NetworkMonitor bool `protobuf:"varint,16,opt,name=networkMonitor,proto3" json:"networkMonitor,omitempty"`
|
||||
DisableDns bool `protobuf:"varint,17,opt,name=disable_dns,json=disableDns,proto3" json:"disable_dns,omitempty"`
|
||||
DisableClientRoutes bool `protobuf:"varint,18,opt,name=disable_client_routes,json=disableClientRoutes,proto3" json:"disable_client_routes,omitempty"`
|
||||
DisableServerRoutes bool `protobuf:"varint,19,opt,name=disable_server_routes,json=disableServerRoutes,proto3" json:"disable_server_routes,omitempty"`
|
||||
BlockLanAccess bool `protobuf:"varint,20,opt,name=block_lan_access,json=blockLanAccess,proto3" json:"block_lan_access,omitempty"`
|
||||
unknownFields protoimpl.UnknownFields
|
||||
sizeCache protoimpl.SizeCache
|
||||
AdminURL string `protobuf:"bytes,5,opt,name=adminURL,proto3" json:"adminURL,omitempty"`
|
||||
InterfaceName string `protobuf:"bytes,6,opt,name=interfaceName,proto3" json:"interfaceName,omitempty"`
|
||||
WireguardPort int64 `protobuf:"varint,7,opt,name=wireguardPort,proto3" json:"wireguardPort,omitempty"`
|
||||
Mtu int64 `protobuf:"varint,8,opt,name=mtu,proto3" json:"mtu,omitempty"`
|
||||
DisableAutoConnect bool `protobuf:"varint,9,opt,name=disableAutoConnect,proto3" json:"disableAutoConnect,omitempty"`
|
||||
ServerSSHAllowed bool `protobuf:"varint,10,opt,name=serverSSHAllowed,proto3" json:"serverSSHAllowed,omitempty"`
|
||||
RosenpassEnabled bool `protobuf:"varint,11,opt,name=rosenpassEnabled,proto3" json:"rosenpassEnabled,omitempty"`
|
||||
RosenpassPermissive bool `protobuf:"varint,12,opt,name=rosenpassPermissive,proto3" json:"rosenpassPermissive,omitempty"`
|
||||
DisableNotifications bool `protobuf:"varint,13,opt,name=disable_notifications,json=disableNotifications,proto3" json:"disable_notifications,omitempty"`
|
||||
LazyConnectionEnabled bool `protobuf:"varint,14,opt,name=lazyConnectionEnabled,proto3" json:"lazyConnectionEnabled,omitempty"`
|
||||
BlockInbound bool `protobuf:"varint,15,opt,name=blockInbound,proto3" json:"blockInbound,omitempty"`
|
||||
NetworkMonitor bool `protobuf:"varint,16,opt,name=networkMonitor,proto3" json:"networkMonitor,omitempty"`
|
||||
DisableDns bool `protobuf:"varint,17,opt,name=disable_dns,json=disableDns,proto3" json:"disable_dns,omitempty"`
|
||||
DisableClientRoutes bool `protobuf:"varint,18,opt,name=disable_client_routes,json=disableClientRoutes,proto3" json:"disable_client_routes,omitempty"`
|
||||
DisableServerRoutes bool `protobuf:"varint,19,opt,name=disable_server_routes,json=disableServerRoutes,proto3" json:"disable_server_routes,omitempty"`
|
||||
BlockLanAccess bool `protobuf:"varint,20,opt,name=block_lan_access,json=blockLanAccess,proto3" json:"block_lan_access,omitempty"`
|
||||
EnableSSHRoot bool `protobuf:"varint,21,opt,name=enableSSHRoot,proto3" json:"enableSSHRoot,omitempty"`
|
||||
EnableSSHSFTP bool `protobuf:"varint,24,opt,name=enableSSHSFTP,proto3" json:"enableSSHSFTP,omitempty"`
|
||||
EnableSSHLocalPortForwarding bool `protobuf:"varint,22,opt,name=enableSSHLocalPortForwarding,proto3" json:"enableSSHLocalPortForwarding,omitempty"`
|
||||
EnableSSHRemotePortForwarding bool `protobuf:"varint,23,opt,name=enableSSHRemotePortForwarding,proto3" json:"enableSSHRemotePortForwarding,omitempty"`
|
||||
DisableSSHAuth bool `protobuf:"varint,25,opt,name=disableSSHAuth,proto3" json:"disableSSHAuth,omitempty"`
|
||||
unknownFields protoimpl.UnknownFields
|
||||
sizeCache protoimpl.SizeCache
|
||||
}
|
||||
|
||||
func (x *GetConfigResponse) Reset() {
|
||||
@@ -1247,6 +1283,41 @@ func (x *GetConfigResponse) GetBlockLanAccess() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (x *GetConfigResponse) GetEnableSSHRoot() bool {
|
||||
if x != nil {
|
||||
return x.EnableSSHRoot
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (x *GetConfigResponse) GetEnableSSHSFTP() bool {
|
||||
if x != nil {
|
||||
return x.EnableSSHSFTP
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (x *GetConfigResponse) GetEnableSSHLocalPortForwarding() bool {
|
||||
if x != nil {
|
||||
return x.EnableSSHLocalPortForwarding
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (x *GetConfigResponse) GetEnableSSHRemotePortForwarding() bool {
|
||||
if x != nil {
|
||||
return x.EnableSSHRemotePortForwarding
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (x *GetConfigResponse) GetDisableSSHAuth() bool {
|
||||
if x != nil {
|
||||
return x.DisableSSHAuth
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// PeerState contains the latest state of a peer
|
||||
type PeerState struct {
|
||||
state protoimpl.MessageState `protogen:"open.v1"`
|
||||
@@ -1267,6 +1338,7 @@ type PeerState struct {
|
||||
Networks []string `protobuf:"bytes,16,rep,name=networks,proto3" json:"networks,omitempty"`
|
||||
Latency *durationpb.Duration `protobuf:"bytes,17,opt,name=latency,proto3" json:"latency,omitempty"`
|
||||
RelayAddress string `protobuf:"bytes,18,opt,name=relayAddress,proto3" json:"relayAddress,omitempty"`
|
||||
SshHostKey []byte `protobuf:"bytes,19,opt,name=sshHostKey,proto3" json:"sshHostKey,omitempty"`
|
||||
unknownFields protoimpl.UnknownFields
|
||||
sizeCache protoimpl.SizeCache
|
||||
}
|
||||
@@ -1420,6 +1492,13 @@ func (x *PeerState) GetRelayAddress() string {
|
||||
return ""
|
||||
}
|
||||
|
||||
func (x *PeerState) GetSshHostKey() []byte {
|
||||
if x != nil {
|
||||
return x.SshHostKey
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// LocalPeerState contains the latest state of the local peer
|
||||
type LocalPeerState struct {
|
||||
state protoimpl.MessageState `protogen:"open.v1"`
|
||||
@@ -3711,11 +3790,16 @@ type SetConfigRequest struct {
|
||||
ExtraIFaceBlacklist []string `protobuf:"bytes,24,rep,name=extraIFaceBlacklist,proto3" json:"extraIFaceBlacklist,omitempty"`
|
||||
DnsLabels []string `protobuf:"bytes,25,rep,name=dns_labels,json=dnsLabels,proto3" json:"dns_labels,omitempty"`
|
||||
// cleanDNSLabels clean map list of DNS labels.
|
||||
CleanDNSLabels bool `protobuf:"varint,26,opt,name=cleanDNSLabels,proto3" json:"cleanDNSLabels,omitempty"`
|
||||
DnsRouteInterval *durationpb.Duration `protobuf:"bytes,27,opt,name=dnsRouteInterval,proto3,oneof" json:"dnsRouteInterval,omitempty"`
|
||||
Mtu *int64 `protobuf:"varint,28,opt,name=mtu,proto3,oneof" json:"mtu,omitempty"`
|
||||
unknownFields protoimpl.UnknownFields
|
||||
sizeCache protoimpl.SizeCache
|
||||
CleanDNSLabels bool `protobuf:"varint,26,opt,name=cleanDNSLabels,proto3" json:"cleanDNSLabels,omitempty"`
|
||||
DnsRouteInterval *durationpb.Duration `protobuf:"bytes,27,opt,name=dnsRouteInterval,proto3,oneof" json:"dnsRouteInterval,omitempty"`
|
||||
Mtu *int64 `protobuf:"varint,28,opt,name=mtu,proto3,oneof" json:"mtu,omitempty"`
|
||||
EnableSSHRoot *bool `protobuf:"varint,29,opt,name=enableSSHRoot,proto3,oneof" json:"enableSSHRoot,omitempty"`
|
||||
EnableSSHSFTP *bool `protobuf:"varint,30,opt,name=enableSSHSFTP,proto3,oneof" json:"enableSSHSFTP,omitempty"`
|
||||
EnableSSHLocalPortForward *bool `protobuf:"varint,31,opt,name=enableSSHLocalPortForward,proto3,oneof" json:"enableSSHLocalPortForward,omitempty"`
|
||||
EnableSSHRemotePortForward *bool `protobuf:"varint,32,opt,name=enableSSHRemotePortForward,proto3,oneof" json:"enableSSHRemotePortForward,omitempty"`
|
||||
DisableSSHAuth *bool `protobuf:"varint,33,opt,name=disableSSHAuth,proto3,oneof" json:"disableSSHAuth,omitempty"`
|
||||
unknownFields protoimpl.UnknownFields
|
||||
sizeCache protoimpl.SizeCache
|
||||
}
|
||||
|
||||
func (x *SetConfigRequest) Reset() {
|
||||
@@ -3944,6 +4028,41 @@ func (x *SetConfigRequest) GetMtu() int64 {
|
||||
return 0
|
||||
}
|
||||
|
||||
func (x *SetConfigRequest) GetEnableSSHRoot() bool {
|
||||
if x != nil && x.EnableSSHRoot != nil {
|
||||
return *x.EnableSSHRoot
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (x *SetConfigRequest) GetEnableSSHSFTP() bool {
|
||||
if x != nil && x.EnableSSHSFTP != nil {
|
||||
return *x.EnableSSHSFTP
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (x *SetConfigRequest) GetEnableSSHLocalPortForward() bool {
|
||||
if x != nil && x.EnableSSHLocalPortForward != nil {
|
||||
return *x.EnableSSHLocalPortForward
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (x *SetConfigRequest) GetEnableSSHRemotePortForward() bool {
|
||||
if x != nil && x.EnableSSHRemotePortForward != nil {
|
||||
return *x.EnableSSHRemotePortForward
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (x *SetConfigRequest) GetDisableSSHAuth() bool {
|
||||
if x != nil && x.DisableSSHAuth != nil {
|
||||
return *x.DisableSSHAuth
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
type SetConfigResponse struct {
|
||||
state protoimpl.MessageState `protogen:"open.v1"`
|
||||
unknownFields protoimpl.UnknownFields
|
||||
@@ -4560,6 +4679,381 @@ func (x *GetFeaturesResponse) GetDisableUpdateSettings() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// GetPeerSSHHostKeyRequest for retrieving SSH host key for a specific peer
|
||||
type GetPeerSSHHostKeyRequest struct {
|
||||
state protoimpl.MessageState `protogen:"open.v1"`
|
||||
// peer IP address or FQDN to get SSH host key for
|
||||
PeerAddress string `protobuf:"bytes,1,opt,name=peerAddress,proto3" json:"peerAddress,omitempty"`
|
||||
unknownFields protoimpl.UnknownFields
|
||||
sizeCache protoimpl.SizeCache
|
||||
}
|
||||
|
||||
func (x *GetPeerSSHHostKeyRequest) Reset() {
|
||||
*x = GetPeerSSHHostKeyRequest{}
|
||||
mi := &file_daemon_proto_msgTypes[69]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
|
||||
func (x *GetPeerSSHHostKeyRequest) String() string {
|
||||
return protoimpl.X.MessageStringOf(x)
|
||||
}
|
||||
|
||||
func (*GetPeerSSHHostKeyRequest) ProtoMessage() {}
|
||||
|
||||
func (x *GetPeerSSHHostKeyRequest) ProtoReflect() protoreflect.Message {
|
||||
mi := &file_daemon_proto_msgTypes[69]
|
||||
if x != nil {
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
if ms.LoadMessageInfo() == nil {
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
return ms
|
||||
}
|
||||
return mi.MessageOf(x)
|
||||
}
|
||||
|
||||
// Deprecated: Use GetPeerSSHHostKeyRequest.ProtoReflect.Descriptor instead.
|
||||
func (*GetPeerSSHHostKeyRequest) Descriptor() ([]byte, []int) {
|
||||
return file_daemon_proto_rawDescGZIP(), []int{69}
|
||||
}
|
||||
|
||||
func (x *GetPeerSSHHostKeyRequest) GetPeerAddress() string {
|
||||
if x != nil {
|
||||
return x.PeerAddress
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// GetPeerSSHHostKeyResponse contains the SSH host key for the requested peer
|
||||
type GetPeerSSHHostKeyResponse struct {
|
||||
state protoimpl.MessageState `protogen:"open.v1"`
|
||||
// SSH host key in SSH public key format (e.g., "ssh-ed25519 AAAAC3... hostname")
|
||||
SshHostKey []byte `protobuf:"bytes,1,opt,name=sshHostKey,proto3" json:"sshHostKey,omitempty"`
|
||||
// peer IP address
|
||||
PeerIP string `protobuf:"bytes,2,opt,name=peerIP,proto3" json:"peerIP,omitempty"`
|
||||
// peer FQDN
|
||||
PeerFQDN string `protobuf:"bytes,3,opt,name=peerFQDN,proto3" json:"peerFQDN,omitempty"`
|
||||
// indicates if the SSH host key was found
|
||||
Found bool `protobuf:"varint,4,opt,name=found,proto3" json:"found,omitempty"`
|
||||
unknownFields protoimpl.UnknownFields
|
||||
sizeCache protoimpl.SizeCache
|
||||
}
|
||||
|
||||
func (x *GetPeerSSHHostKeyResponse) Reset() {
|
||||
*x = GetPeerSSHHostKeyResponse{}
|
||||
mi := &file_daemon_proto_msgTypes[70]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
|
||||
func (x *GetPeerSSHHostKeyResponse) String() string {
|
||||
return protoimpl.X.MessageStringOf(x)
|
||||
}
|
||||
|
||||
func (*GetPeerSSHHostKeyResponse) ProtoMessage() {}
|
||||
|
||||
func (x *GetPeerSSHHostKeyResponse) ProtoReflect() protoreflect.Message {
|
||||
mi := &file_daemon_proto_msgTypes[70]
|
||||
if x != nil {
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
if ms.LoadMessageInfo() == nil {
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
return ms
|
||||
}
|
||||
return mi.MessageOf(x)
|
||||
}
|
||||
|
||||
// Deprecated: Use GetPeerSSHHostKeyResponse.ProtoReflect.Descriptor instead.
|
||||
func (*GetPeerSSHHostKeyResponse) Descriptor() ([]byte, []int) {
|
||||
return file_daemon_proto_rawDescGZIP(), []int{70}
|
||||
}
|
||||
|
||||
func (x *GetPeerSSHHostKeyResponse) GetSshHostKey() []byte {
|
||||
if x != nil {
|
||||
return x.SshHostKey
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (x *GetPeerSSHHostKeyResponse) GetPeerIP() string {
|
||||
if x != nil {
|
||||
return x.PeerIP
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (x *GetPeerSSHHostKeyResponse) GetPeerFQDN() string {
|
||||
if x != nil {
|
||||
return x.PeerFQDN
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (x *GetPeerSSHHostKeyResponse) GetFound() bool {
|
||||
if x != nil {
|
||||
return x.Found
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// RequestJWTAuthRequest for initiating JWT authentication flow
|
||||
type RequestJWTAuthRequest struct {
|
||||
state protoimpl.MessageState `protogen:"open.v1"`
|
||||
unknownFields protoimpl.UnknownFields
|
||||
sizeCache protoimpl.SizeCache
|
||||
}
|
||||
|
||||
func (x *RequestJWTAuthRequest) Reset() {
|
||||
*x = RequestJWTAuthRequest{}
|
||||
mi := &file_daemon_proto_msgTypes[71]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
|
||||
func (x *RequestJWTAuthRequest) String() string {
|
||||
return protoimpl.X.MessageStringOf(x)
|
||||
}
|
||||
|
||||
func (*RequestJWTAuthRequest) ProtoMessage() {}
|
||||
|
||||
func (x *RequestJWTAuthRequest) ProtoReflect() protoreflect.Message {
|
||||
mi := &file_daemon_proto_msgTypes[71]
|
||||
if x != nil {
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
if ms.LoadMessageInfo() == nil {
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
return ms
|
||||
}
|
||||
return mi.MessageOf(x)
|
||||
}
|
||||
|
||||
// Deprecated: Use RequestJWTAuthRequest.ProtoReflect.Descriptor instead.
|
||||
func (*RequestJWTAuthRequest) Descriptor() ([]byte, []int) {
|
||||
return file_daemon_proto_rawDescGZIP(), []int{71}
|
||||
}
|
||||
|
||||
// RequestJWTAuthResponse contains authentication flow information
|
||||
type RequestJWTAuthResponse struct {
|
||||
state protoimpl.MessageState `protogen:"open.v1"`
|
||||
// verification URI for user authentication
|
||||
VerificationURI string `protobuf:"bytes,1,opt,name=verificationURI,proto3" json:"verificationURI,omitempty"`
|
||||
// complete verification URI (with embedded user code)
|
||||
VerificationURIComplete string `protobuf:"bytes,2,opt,name=verificationURIComplete,proto3" json:"verificationURIComplete,omitempty"`
|
||||
// user code to enter on verification URI
|
||||
UserCode string `protobuf:"bytes,3,opt,name=userCode,proto3" json:"userCode,omitempty"`
|
||||
// device code for polling
|
||||
DeviceCode string `protobuf:"bytes,4,opt,name=deviceCode,proto3" json:"deviceCode,omitempty"`
|
||||
// expiration time in seconds
|
||||
ExpiresIn int64 `protobuf:"varint,5,opt,name=expiresIn,proto3" json:"expiresIn,omitempty"`
|
||||
// if a cached token is available, it will be returned here
|
||||
CachedToken string `protobuf:"bytes,6,opt,name=cachedToken,proto3" json:"cachedToken,omitempty"`
|
||||
// maximum age of JWT tokens in seconds (from management server)
|
||||
MaxTokenAge int64 `protobuf:"varint,7,opt,name=maxTokenAge,proto3" json:"maxTokenAge,omitempty"`
|
||||
unknownFields protoimpl.UnknownFields
|
||||
sizeCache protoimpl.SizeCache
|
||||
}
|
||||
|
||||
func (x *RequestJWTAuthResponse) Reset() {
|
||||
*x = RequestJWTAuthResponse{}
|
||||
mi := &file_daemon_proto_msgTypes[72]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
|
||||
func (x *RequestJWTAuthResponse) String() string {
|
||||
return protoimpl.X.MessageStringOf(x)
|
||||
}
|
||||
|
||||
func (*RequestJWTAuthResponse) ProtoMessage() {}
|
||||
|
||||
func (x *RequestJWTAuthResponse) ProtoReflect() protoreflect.Message {
|
||||
mi := &file_daemon_proto_msgTypes[72]
|
||||
if x != nil {
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
if ms.LoadMessageInfo() == nil {
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
return ms
|
||||
}
|
||||
return mi.MessageOf(x)
|
||||
}
|
||||
|
||||
// Deprecated: Use RequestJWTAuthResponse.ProtoReflect.Descriptor instead.
|
||||
func (*RequestJWTAuthResponse) Descriptor() ([]byte, []int) {
|
||||
return file_daemon_proto_rawDescGZIP(), []int{72}
|
||||
}
|
||||
|
||||
func (x *RequestJWTAuthResponse) GetVerificationURI() string {
|
||||
if x != nil {
|
||||
return x.VerificationURI
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (x *RequestJWTAuthResponse) GetVerificationURIComplete() string {
|
||||
if x != nil {
|
||||
return x.VerificationURIComplete
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (x *RequestJWTAuthResponse) GetUserCode() string {
|
||||
if x != nil {
|
||||
return x.UserCode
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (x *RequestJWTAuthResponse) GetDeviceCode() string {
|
||||
if x != nil {
|
||||
return x.DeviceCode
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (x *RequestJWTAuthResponse) GetExpiresIn() int64 {
|
||||
if x != nil {
|
||||
return x.ExpiresIn
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func (x *RequestJWTAuthResponse) GetCachedToken() string {
|
||||
if x != nil {
|
||||
return x.CachedToken
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (x *RequestJWTAuthResponse) GetMaxTokenAge() int64 {
|
||||
if x != nil {
|
||||
return x.MaxTokenAge
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
// WaitJWTTokenRequest for waiting for authentication completion
|
||||
type WaitJWTTokenRequest struct {
|
||||
state protoimpl.MessageState `protogen:"open.v1"`
|
||||
// device code from RequestJWTAuthResponse
|
||||
DeviceCode string `protobuf:"bytes,1,opt,name=deviceCode,proto3" json:"deviceCode,omitempty"`
|
||||
// user code for verification
|
||||
UserCode string `protobuf:"bytes,2,opt,name=userCode,proto3" json:"userCode,omitempty"`
|
||||
unknownFields protoimpl.UnknownFields
|
||||
sizeCache protoimpl.SizeCache
|
||||
}
|
||||
|
||||
func (x *WaitJWTTokenRequest) Reset() {
|
||||
*x = WaitJWTTokenRequest{}
|
||||
mi := &file_daemon_proto_msgTypes[73]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
|
||||
func (x *WaitJWTTokenRequest) String() string {
|
||||
return protoimpl.X.MessageStringOf(x)
|
||||
}
|
||||
|
||||
func (*WaitJWTTokenRequest) ProtoMessage() {}
|
||||
|
||||
func (x *WaitJWTTokenRequest) ProtoReflect() protoreflect.Message {
|
||||
mi := &file_daemon_proto_msgTypes[73]
|
||||
if x != nil {
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
if ms.LoadMessageInfo() == nil {
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
return ms
|
||||
}
|
||||
return mi.MessageOf(x)
|
||||
}
|
||||
|
||||
// Deprecated: Use WaitJWTTokenRequest.ProtoReflect.Descriptor instead.
|
||||
func (*WaitJWTTokenRequest) Descriptor() ([]byte, []int) {
|
||||
return file_daemon_proto_rawDescGZIP(), []int{73}
|
||||
}
|
||||
|
||||
func (x *WaitJWTTokenRequest) GetDeviceCode() string {
|
||||
if x != nil {
|
||||
return x.DeviceCode
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (x *WaitJWTTokenRequest) GetUserCode() string {
|
||||
if x != nil {
|
||||
return x.UserCode
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// WaitJWTTokenResponse contains the JWT token after authentication
|
||||
type WaitJWTTokenResponse struct {
|
||||
state protoimpl.MessageState `protogen:"open.v1"`
|
||||
// JWT token (access token or ID token)
|
||||
Token string `protobuf:"bytes,1,opt,name=token,proto3" json:"token,omitempty"`
|
||||
// token type (e.g., "Bearer")
|
||||
TokenType string `protobuf:"bytes,2,opt,name=tokenType,proto3" json:"tokenType,omitempty"`
|
||||
// expiration time in seconds
|
||||
ExpiresIn int64 `protobuf:"varint,3,opt,name=expiresIn,proto3" json:"expiresIn,omitempty"`
|
||||
unknownFields protoimpl.UnknownFields
|
||||
sizeCache protoimpl.SizeCache
|
||||
}
|
||||
|
||||
func (x *WaitJWTTokenResponse) Reset() {
|
||||
*x = WaitJWTTokenResponse{}
|
||||
mi := &file_daemon_proto_msgTypes[74]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
|
||||
func (x *WaitJWTTokenResponse) String() string {
|
||||
return protoimpl.X.MessageStringOf(x)
|
||||
}
|
||||
|
||||
func (*WaitJWTTokenResponse) ProtoMessage() {}
|
||||
|
||||
func (x *WaitJWTTokenResponse) ProtoReflect() protoreflect.Message {
|
||||
mi := &file_daemon_proto_msgTypes[74]
|
||||
if x != nil {
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
if ms.LoadMessageInfo() == nil {
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
return ms
|
||||
}
|
||||
return mi.MessageOf(x)
|
||||
}
|
||||
|
||||
// Deprecated: Use WaitJWTTokenResponse.ProtoReflect.Descriptor instead.
|
||||
func (*WaitJWTTokenResponse) Descriptor() ([]byte, []int) {
|
||||
return file_daemon_proto_rawDescGZIP(), []int{74}
|
||||
}
|
||||
|
||||
func (x *WaitJWTTokenResponse) GetToken() string {
|
||||
if x != nil {
|
||||
return x.Token
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (x *WaitJWTTokenResponse) GetTokenType() string {
|
||||
if x != nil {
|
||||
return x.TokenType
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (x *WaitJWTTokenResponse) GetExpiresIn() int64 {
|
||||
if x != nil {
|
||||
return x.ExpiresIn
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
type PortInfo_Range struct {
|
||||
state protoimpl.MessageState `protogen:"open.v1"`
|
||||
Start uint32 `protobuf:"varint,1,opt,name=start,proto3" json:"start,omitempty"`
|
||||
@@ -4570,7 +5064,7 @@ type PortInfo_Range struct {
|
||||
|
||||
func (x *PortInfo_Range) Reset() {
|
||||
*x = PortInfo_Range{}
|
||||
mi := &file_daemon_proto_msgTypes[70]
|
||||
mi := &file_daemon_proto_msgTypes[76]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
@@ -4582,7 +5076,7 @@ func (x *PortInfo_Range) String() string {
|
||||
func (*PortInfo_Range) ProtoMessage() {}
|
||||
|
||||
func (x *PortInfo_Range) ProtoReflect() protoreflect.Message {
|
||||
mi := &file_daemon_proto_msgTypes[70]
|
||||
mi := &file_daemon_proto_msgTypes[76]
|
||||
if x != nil {
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
if ms.LoadMessageInfo() == nil {
|
||||
@@ -4617,7 +5111,7 @@ var File_daemon_proto protoreflect.FileDescriptor
|
||||
const file_daemon_proto_rawDesc = "" +
|
||||
"\n" +
|
||||
"\fdaemon.proto\x12\x06daemon\x1a google/protobuf/descriptor.proto\x1a\x1fgoogle/protobuf/timestamp.proto\x1a\x1egoogle/protobuf/duration.proto\"\x0e\n" +
|
||||
"\fEmptyRequest\"\xe5\x0e\n" +
|
||||
"\fEmptyRequest\"\xd4\x11\n" +
|
||||
"\fLoginRequest\x12\x1a\n" +
|
||||
"\bsetupKey\x18\x01 \x01(\tR\bsetupKey\x12&\n" +
|
||||
"\fpreSharedKey\x18\x02 \x01(\tB\x02\x18\x01R\fpreSharedKey\x12$\n" +
|
||||
@@ -4654,8 +5148,12 @@ const file_daemon_proto_rawDesc = "" +
|
||||
"\rblock_inbound\x18\x1d \x01(\bH\x10R\fblockInbound\x88\x01\x01\x12%\n" +
|
||||
"\vprofileName\x18\x1e \x01(\tH\x11R\vprofileName\x88\x01\x01\x12\x1f\n" +
|
||||
"\busername\x18\x1f \x01(\tH\x12R\busername\x88\x01\x01\x12\x15\n" +
|
||||
"\x03mtu\x18 \x01(\x03H\x13R\x03mtu\x88\x01\x01\x12\x17\n" +
|
||||
"\x04hint\x18! \x01(\tH\x14R\x04hint\x88\x01\x01B\x13\n" +
|
||||
"\x03mtu\x18 \x01(\x03H\x13R\x03mtu\x88\x01\x01\x12)\n" +
|
||||
"\renableSSHRoot\x18! \x01(\bH\x14R\renableSSHRoot\x88\x01\x01\x12)\n" +
|
||||
"\renableSSHSFTP\x18\" \x01(\bH\x15R\renableSSHSFTP\x88\x01\x01\x12G\n" +
|
||||
"\x1cenableSSHLocalPortForwarding\x18# \x01(\bH\x16R\x1cenableSSHLocalPortForwarding\x88\x01\x01\x12I\n" +
|
||||
"\x1denableSSHRemotePortForwarding\x18$ \x01(\bH\x17R\x1denableSSHRemotePortForwarding\x88\x01\x01\x12+\n" +
|
||||
"\x0edisableSSHAuth\x18% \x01(\bH\x18R\x0edisableSSHAuth\x88\x01\x01B\x13\n" +
|
||||
"\x11_rosenpassEnabledB\x10\n" +
|
||||
"\x0e_interfaceNameB\x10\n" +
|
||||
"\x0e_wireguardPortB\x17\n" +
|
||||
@@ -4675,8 +5173,12 @@ const file_daemon_proto_rawDesc = "" +
|
||||
"\x0e_block_inboundB\x0e\n" +
|
||||
"\f_profileNameB\v\n" +
|
||||
"\t_usernameB\x06\n" +
|
||||
"\x04_mtuB\a\n" +
|
||||
"\x05_hint\"\xb5\x01\n" +
|
||||
"\x04_mtuB\x10\n" +
|
||||
"\x0e_enableSSHRootB\x10\n" +
|
||||
"\x0e_enableSSHSFTPB\x1f\n" +
|
||||
"\x1d_enableSSHLocalPortForwardingB \n" +
|
||||
"\x1e_enableSSHRemotePortForwardingB\x11\n" +
|
||||
"\x0f_disableSSHAuth\"\xb5\x01\n" +
|
||||
"\rLoginResponse\x12$\n" +
|
||||
"\rneedsSSOLogin\x18\x01 \x01(\bR\rneedsSSOLogin\x12\x1a\n" +
|
||||
"\buserCode\x18\x02 \x01(\tR\buserCode\x12(\n" +
|
||||
@@ -4709,7 +5211,7 @@ const file_daemon_proto_rawDesc = "" +
|
||||
"\fDownResponse\"P\n" +
|
||||
"\x10GetConfigRequest\x12 \n" +
|
||||
"\vprofileName\x18\x01 \x01(\tR\vprofileName\x12\x1a\n" +
|
||||
"\busername\x18\x02 \x01(\tR\busername\"\xb5\x06\n" +
|
||||
"\busername\x18\x02 \x01(\tR\busername\"\xb3\b\n" +
|
||||
"\x11GetConfigResponse\x12$\n" +
|
||||
"\rmanagementUrl\x18\x01 \x01(\tR\rmanagementUrl\x12\x1e\n" +
|
||||
"\n" +
|
||||
@@ -4734,7 +5236,12 @@ const file_daemon_proto_rawDesc = "" +
|
||||
"disableDns\x122\n" +
|
||||
"\x15disable_client_routes\x18\x12 \x01(\bR\x13disableClientRoutes\x122\n" +
|
||||
"\x15disable_server_routes\x18\x13 \x01(\bR\x13disableServerRoutes\x12(\n" +
|
||||
"\x10block_lan_access\x18\x14 \x01(\bR\x0eblockLanAccess\"\xde\x05\n" +
|
||||
"\x10block_lan_access\x18\x14 \x01(\bR\x0eblockLanAccess\x12$\n" +
|
||||
"\renableSSHRoot\x18\x15 \x01(\bR\renableSSHRoot\x12$\n" +
|
||||
"\renableSSHSFTP\x18\x18 \x01(\bR\renableSSHSFTP\x12B\n" +
|
||||
"\x1cenableSSHLocalPortForwarding\x18\x16 \x01(\bR\x1cenableSSHLocalPortForwarding\x12D\n" +
|
||||
"\x1denableSSHRemotePortForwarding\x18\x17 \x01(\bR\x1denableSSHRemotePortForwarding\x12&\n" +
|
||||
"\x0edisableSSHAuth\x18\x19 \x01(\bR\x0edisableSSHAuth\"\xfe\x05\n" +
|
||||
"\tPeerState\x12\x0e\n" +
|
||||
"\x02IP\x18\x01 \x01(\tR\x02IP\x12\x16\n" +
|
||||
"\x06pubKey\x18\x02 \x01(\tR\x06pubKey\x12\x1e\n" +
|
||||
@@ -4755,7 +5262,10 @@ const file_daemon_proto_rawDesc = "" +
|
||||
"\x10rosenpassEnabled\x18\x0f \x01(\bR\x10rosenpassEnabled\x12\x1a\n" +
|
||||
"\bnetworks\x18\x10 \x03(\tR\bnetworks\x123\n" +
|
||||
"\alatency\x18\x11 \x01(\v2\x19.google.protobuf.DurationR\alatency\x12\"\n" +
|
||||
"\frelayAddress\x18\x12 \x01(\tR\frelayAddress\"\xf0\x01\n" +
|
||||
"\frelayAddress\x18\x12 \x01(\tR\frelayAddress\x12\x1e\n" +
|
||||
"\n" +
|
||||
"sshHostKey\x18\x13 \x01(\fR\n" +
|
||||
"sshHostKey\"\xf0\x01\n" +
|
||||
"\x0eLocalPeerState\x12\x0e\n" +
|
||||
"\x02IP\x18\x01 \x01(\tR\x02IP\x12\x16\n" +
|
||||
"\x06pubKey\x18\x02 \x01(\tR\x06pubKey\x12(\n" +
|
||||
@@ -4934,7 +5444,7 @@ const file_daemon_proto_rawDesc = "" +
|
||||
"\busername\x18\x02 \x01(\tH\x01R\busername\x88\x01\x01B\x0e\n" +
|
||||
"\f_profileNameB\v\n" +
|
||||
"\t_username\"\x17\n" +
|
||||
"\x15SwitchProfileResponse\"\x8e\r\n" +
|
||||
"\x15SwitchProfileResponse\"\x8d\x10\n" +
|
||||
"\x10SetConfigRequest\x12\x1a\n" +
|
||||
"\busername\x18\x01 \x01(\tR\busername\x12 \n" +
|
||||
"\vprofileName\x18\x02 \x01(\tR\vprofileName\x12$\n" +
|
||||
@@ -4967,7 +5477,12 @@ const file_daemon_proto_rawDesc = "" +
|
||||
"dns_labels\x18\x19 \x03(\tR\tdnsLabels\x12&\n" +
|
||||
"\x0ecleanDNSLabels\x18\x1a \x01(\bR\x0ecleanDNSLabels\x12J\n" +
|
||||
"\x10dnsRouteInterval\x18\x1b \x01(\v2\x19.google.protobuf.DurationH\x10R\x10dnsRouteInterval\x88\x01\x01\x12\x15\n" +
|
||||
"\x03mtu\x18\x1c \x01(\x03H\x11R\x03mtu\x88\x01\x01B\x13\n" +
|
||||
"\x03mtu\x18\x1c \x01(\x03H\x11R\x03mtu\x88\x01\x01\x12)\n" +
|
||||
"\renableSSHRoot\x18\x1d \x01(\bH\x12R\renableSSHRoot\x88\x01\x01\x12)\n" +
|
||||
"\renableSSHSFTP\x18\x1e \x01(\bH\x13R\renableSSHSFTP\x88\x01\x01\x12A\n" +
|
||||
"\x19enableSSHLocalPortForward\x18\x1f \x01(\bH\x14R\x19enableSSHLocalPortForward\x88\x01\x01\x12C\n" +
|
||||
"\x1aenableSSHRemotePortForward\x18 \x01(\bH\x15R\x1aenableSSHRemotePortForward\x88\x01\x01\x12+\n" +
|
||||
"\x0edisableSSHAuth\x18! \x01(\bH\x16R\x0edisableSSHAuth\x88\x01\x01B\x13\n" +
|
||||
"\x11_rosenpassEnabledB\x10\n" +
|
||||
"\x0e_interfaceNameB\x10\n" +
|
||||
"\x0e_wireguardPortB\x17\n" +
|
||||
@@ -4985,7 +5500,12 @@ const file_daemon_proto_rawDesc = "" +
|
||||
"\x16_lazyConnectionEnabledB\x10\n" +
|
||||
"\x0e_block_inboundB\x13\n" +
|
||||
"\x11_dnsRouteIntervalB\x06\n" +
|
||||
"\x04_mtu\"\x13\n" +
|
||||
"\x04_mtuB\x10\n" +
|
||||
"\x0e_enableSSHRootB\x10\n" +
|
||||
"\x0e_enableSSHSFTPB\x1c\n" +
|
||||
"\x1a_enableSSHLocalPortForwardB\x1d\n" +
|
||||
"\x1b_enableSSHRemotePortForwardB\x11\n" +
|
||||
"\x0f_disableSSHAuth\"\x13\n" +
|
||||
"\x11SetConfigResponse\"Q\n" +
|
||||
"\x11AddProfileRequest\x12\x1a\n" +
|
||||
"\busername\x18\x01 \x01(\tR\busername\x12 \n" +
|
||||
@@ -5015,7 +5535,36 @@ const file_daemon_proto_rawDesc = "" +
|
||||
"\x12GetFeaturesRequest\"x\n" +
|
||||
"\x13GetFeaturesResponse\x12)\n" +
|
||||
"\x10disable_profiles\x18\x01 \x01(\bR\x0fdisableProfiles\x126\n" +
|
||||
"\x17disable_update_settings\x18\x02 \x01(\bR\x15disableUpdateSettings*b\n" +
|
||||
"\x17disable_update_settings\x18\x02 \x01(\bR\x15disableUpdateSettings\"<\n" +
|
||||
"\x18GetPeerSSHHostKeyRequest\x12 \n" +
|
||||
"\vpeerAddress\x18\x01 \x01(\tR\vpeerAddress\"\x85\x01\n" +
|
||||
"\x19GetPeerSSHHostKeyResponse\x12\x1e\n" +
|
||||
"\n" +
|
||||
"sshHostKey\x18\x01 \x01(\fR\n" +
|
||||
"sshHostKey\x12\x16\n" +
|
||||
"\x06peerIP\x18\x02 \x01(\tR\x06peerIP\x12\x1a\n" +
|
||||
"\bpeerFQDN\x18\x03 \x01(\tR\bpeerFQDN\x12\x14\n" +
|
||||
"\x05found\x18\x04 \x01(\bR\x05found\"\x17\n" +
|
||||
"\x15RequestJWTAuthRequest\"\x9a\x02\n" +
|
||||
"\x16RequestJWTAuthResponse\x12(\n" +
|
||||
"\x0fverificationURI\x18\x01 \x01(\tR\x0fverificationURI\x128\n" +
|
||||
"\x17verificationURIComplete\x18\x02 \x01(\tR\x17verificationURIComplete\x12\x1a\n" +
|
||||
"\buserCode\x18\x03 \x01(\tR\buserCode\x12\x1e\n" +
|
||||
"\n" +
|
||||
"deviceCode\x18\x04 \x01(\tR\n" +
|
||||
"deviceCode\x12\x1c\n" +
|
||||
"\texpiresIn\x18\x05 \x01(\x03R\texpiresIn\x12 \n" +
|
||||
"\vcachedToken\x18\x06 \x01(\tR\vcachedToken\x12 \n" +
|
||||
"\vmaxTokenAge\x18\a \x01(\x03R\vmaxTokenAge\"Q\n" +
|
||||
"\x13WaitJWTTokenRequest\x12\x1e\n" +
|
||||
"\n" +
|
||||
"deviceCode\x18\x01 \x01(\tR\n" +
|
||||
"deviceCode\x12\x1a\n" +
|
||||
"\buserCode\x18\x02 \x01(\tR\buserCode\"h\n" +
|
||||
"\x14WaitJWTTokenResponse\x12\x14\n" +
|
||||
"\x05token\x18\x01 \x01(\tR\x05token\x12\x1c\n" +
|
||||
"\ttokenType\x18\x02 \x01(\tR\ttokenType\x12\x1c\n" +
|
||||
"\texpiresIn\x18\x03 \x01(\x03R\texpiresIn*b\n" +
|
||||
"\bLogLevel\x12\v\n" +
|
||||
"\aUNKNOWN\x10\x00\x12\t\n" +
|
||||
"\x05PANIC\x10\x01\x12\t\n" +
|
||||
@@ -5024,7 +5573,7 @@ const file_daemon_proto_rawDesc = "" +
|
||||
"\x04WARN\x10\x04\x12\b\n" +
|
||||
"\x04INFO\x10\x05\x12\t\n" +
|
||||
"\x05DEBUG\x10\x06\x12\t\n" +
|
||||
"\x05TRACE\x10\a2\x8f\x10\n" +
|
||||
"\x05TRACE\x10\a2\x8b\x12\n" +
|
||||
"\rDaemonService\x126\n" +
|
||||
"\x05Login\x12\x14.daemon.LoginRequest\x1a\x15.daemon.LoginResponse\"\x00\x12K\n" +
|
||||
"\fWaitSSOLogin\x12\x1b.daemon.WaitSSOLoginRequest\x1a\x1c.daemon.WaitSSOLoginResponse\"\x00\x12-\n" +
|
||||
@@ -5056,7 +5605,10 @@ const file_daemon_proto_rawDesc = "" +
|
||||
"\fListProfiles\x12\x1b.daemon.ListProfilesRequest\x1a\x1c.daemon.ListProfilesResponse\"\x00\x12W\n" +
|
||||
"\x10GetActiveProfile\x12\x1f.daemon.GetActiveProfileRequest\x1a .daemon.GetActiveProfileResponse\"\x00\x129\n" +
|
||||
"\x06Logout\x12\x15.daemon.LogoutRequest\x1a\x16.daemon.LogoutResponse\"\x00\x12H\n" +
|
||||
"\vGetFeatures\x12\x1a.daemon.GetFeaturesRequest\x1a\x1b.daemon.GetFeaturesResponse\"\x00B\bZ\x06/protob\x06proto3"
|
||||
"\vGetFeatures\x12\x1a.daemon.GetFeaturesRequest\x1a\x1b.daemon.GetFeaturesResponse\"\x00\x12Z\n" +
|
||||
"\x11GetPeerSSHHostKey\x12 .daemon.GetPeerSSHHostKeyRequest\x1a!.daemon.GetPeerSSHHostKeyResponse\"\x00\x12Q\n" +
|
||||
"\x0eRequestJWTAuth\x12\x1d.daemon.RequestJWTAuthRequest\x1a\x1e.daemon.RequestJWTAuthResponse\"\x00\x12K\n" +
|
||||
"\fWaitJWTToken\x12\x1b.daemon.WaitJWTTokenRequest\x1a\x1c.daemon.WaitJWTTokenResponse\"\x00B\bZ\x06/protob\x06proto3"
|
||||
|
||||
var (
|
||||
file_daemon_proto_rawDescOnce sync.Once
|
||||
@@ -5071,7 +5623,7 @@ func file_daemon_proto_rawDescGZIP() []byte {
|
||||
}
|
||||
|
||||
var file_daemon_proto_enumTypes = make([]protoimpl.EnumInfo, 3)
|
||||
var file_daemon_proto_msgTypes = make([]protoimpl.MessageInfo, 72)
|
||||
var file_daemon_proto_msgTypes = make([]protoimpl.MessageInfo, 78)
|
||||
var file_daemon_proto_goTypes = []any{
|
||||
(LogLevel)(0), // 0: daemon.LogLevel
|
||||
(SystemEvent_Severity)(0), // 1: daemon.SystemEvent.Severity
|
||||
@@ -5145,18 +5697,24 @@ var file_daemon_proto_goTypes = []any{
|
||||
(*LogoutResponse)(nil), // 69: daemon.LogoutResponse
|
||||
(*GetFeaturesRequest)(nil), // 70: daemon.GetFeaturesRequest
|
||||
(*GetFeaturesResponse)(nil), // 71: daemon.GetFeaturesResponse
|
||||
nil, // 72: daemon.Network.ResolvedIPsEntry
|
||||
(*PortInfo_Range)(nil), // 73: daemon.PortInfo.Range
|
||||
nil, // 74: daemon.SystemEvent.MetadataEntry
|
||||
(*durationpb.Duration)(nil), // 75: google.protobuf.Duration
|
||||
(*timestamppb.Timestamp)(nil), // 76: google.protobuf.Timestamp
|
||||
(*GetPeerSSHHostKeyRequest)(nil), // 72: daemon.GetPeerSSHHostKeyRequest
|
||||
(*GetPeerSSHHostKeyResponse)(nil), // 73: daemon.GetPeerSSHHostKeyResponse
|
||||
(*RequestJWTAuthRequest)(nil), // 74: daemon.RequestJWTAuthRequest
|
||||
(*RequestJWTAuthResponse)(nil), // 75: daemon.RequestJWTAuthResponse
|
||||
(*WaitJWTTokenRequest)(nil), // 76: daemon.WaitJWTTokenRequest
|
||||
(*WaitJWTTokenResponse)(nil), // 77: daemon.WaitJWTTokenResponse
|
||||
nil, // 78: daemon.Network.ResolvedIPsEntry
|
||||
(*PortInfo_Range)(nil), // 79: daemon.PortInfo.Range
|
||||
nil, // 80: daemon.SystemEvent.MetadataEntry
|
||||
(*durationpb.Duration)(nil), // 81: google.protobuf.Duration
|
||||
(*timestamppb.Timestamp)(nil), // 82: google.protobuf.Timestamp
|
||||
}
|
||||
var file_daemon_proto_depIdxs = []int32{
|
||||
75, // 0: daemon.LoginRequest.dnsRouteInterval:type_name -> google.protobuf.Duration
|
||||
81, // 0: daemon.LoginRequest.dnsRouteInterval:type_name -> google.protobuf.Duration
|
||||
22, // 1: daemon.StatusResponse.fullStatus:type_name -> daemon.FullStatus
|
||||
76, // 2: daemon.PeerState.connStatusUpdate:type_name -> google.protobuf.Timestamp
|
||||
76, // 3: daemon.PeerState.lastWireguardHandshake:type_name -> google.protobuf.Timestamp
|
||||
75, // 4: daemon.PeerState.latency:type_name -> google.protobuf.Duration
|
||||
82, // 2: daemon.PeerState.connStatusUpdate:type_name -> google.protobuf.Timestamp
|
||||
82, // 3: daemon.PeerState.lastWireguardHandshake:type_name -> google.protobuf.Timestamp
|
||||
81, // 4: daemon.PeerState.latency:type_name -> google.protobuf.Duration
|
||||
19, // 5: daemon.FullStatus.managementState:type_name -> daemon.ManagementState
|
||||
18, // 6: daemon.FullStatus.signalState:type_name -> daemon.SignalState
|
||||
17, // 7: daemon.FullStatus.localPeerState:type_name -> daemon.LocalPeerState
|
||||
@@ -5165,8 +5723,8 @@ var file_daemon_proto_depIdxs = []int32{
|
||||
21, // 10: daemon.FullStatus.dns_servers:type_name -> daemon.NSGroupState
|
||||
52, // 11: daemon.FullStatus.events:type_name -> daemon.SystemEvent
|
||||
28, // 12: daemon.ListNetworksResponse.routes:type_name -> daemon.Network
|
||||
72, // 13: daemon.Network.resolvedIPs:type_name -> daemon.Network.ResolvedIPsEntry
|
||||
73, // 14: daemon.PortInfo.range:type_name -> daemon.PortInfo.Range
|
||||
78, // 13: daemon.Network.resolvedIPs:type_name -> daemon.Network.ResolvedIPsEntry
|
||||
79, // 14: daemon.PortInfo.range:type_name -> daemon.PortInfo.Range
|
||||
29, // 15: daemon.ForwardingRule.destinationPort:type_name -> daemon.PortInfo
|
||||
29, // 16: daemon.ForwardingRule.translatedPort:type_name -> daemon.PortInfo
|
||||
30, // 17: daemon.ForwardingRulesResponse.rules:type_name -> daemon.ForwardingRule
|
||||
@@ -5177,10 +5735,10 @@ var file_daemon_proto_depIdxs = []int32{
|
||||
49, // 22: daemon.TracePacketResponse.stages:type_name -> daemon.TraceStage
|
||||
1, // 23: daemon.SystemEvent.severity:type_name -> daemon.SystemEvent.Severity
|
||||
2, // 24: daemon.SystemEvent.category:type_name -> daemon.SystemEvent.Category
|
||||
76, // 25: daemon.SystemEvent.timestamp:type_name -> google.protobuf.Timestamp
|
||||
74, // 26: daemon.SystemEvent.metadata:type_name -> daemon.SystemEvent.MetadataEntry
|
||||
82, // 25: daemon.SystemEvent.timestamp:type_name -> google.protobuf.Timestamp
|
||||
80, // 26: daemon.SystemEvent.metadata:type_name -> daemon.SystemEvent.MetadataEntry
|
||||
52, // 27: daemon.GetEventsResponse.events:type_name -> daemon.SystemEvent
|
||||
75, // 28: daemon.SetConfigRequest.dnsRouteInterval:type_name -> google.protobuf.Duration
|
||||
81, // 28: daemon.SetConfigRequest.dnsRouteInterval:type_name -> google.protobuf.Duration
|
||||
65, // 29: daemon.ListProfilesResponse.profiles:type_name -> daemon.Profile
|
||||
27, // 30: daemon.Network.ResolvedIPsEntry.value:type_name -> daemon.IPList
|
||||
4, // 31: daemon.DaemonService.Login:input_type -> daemon.LoginRequest
|
||||
@@ -5211,36 +5769,42 @@ var file_daemon_proto_depIdxs = []int32{
|
||||
66, // 56: daemon.DaemonService.GetActiveProfile:input_type -> daemon.GetActiveProfileRequest
|
||||
68, // 57: daemon.DaemonService.Logout:input_type -> daemon.LogoutRequest
|
||||
70, // 58: daemon.DaemonService.GetFeatures:input_type -> daemon.GetFeaturesRequest
|
||||
5, // 59: daemon.DaemonService.Login:output_type -> daemon.LoginResponse
|
||||
7, // 60: daemon.DaemonService.WaitSSOLogin:output_type -> daemon.WaitSSOLoginResponse
|
||||
9, // 61: daemon.DaemonService.Up:output_type -> daemon.UpResponse
|
||||
11, // 62: daemon.DaemonService.Status:output_type -> daemon.StatusResponse
|
||||
13, // 63: daemon.DaemonService.Down:output_type -> daemon.DownResponse
|
||||
15, // 64: daemon.DaemonService.GetConfig:output_type -> daemon.GetConfigResponse
|
||||
24, // 65: daemon.DaemonService.ListNetworks:output_type -> daemon.ListNetworksResponse
|
||||
26, // 66: daemon.DaemonService.SelectNetworks:output_type -> daemon.SelectNetworksResponse
|
||||
26, // 67: daemon.DaemonService.DeselectNetworks:output_type -> daemon.SelectNetworksResponse
|
||||
31, // 68: daemon.DaemonService.ForwardingRules:output_type -> daemon.ForwardingRulesResponse
|
||||
33, // 69: daemon.DaemonService.DebugBundle:output_type -> daemon.DebugBundleResponse
|
||||
35, // 70: daemon.DaemonService.GetLogLevel:output_type -> daemon.GetLogLevelResponse
|
||||
37, // 71: daemon.DaemonService.SetLogLevel:output_type -> daemon.SetLogLevelResponse
|
||||
40, // 72: daemon.DaemonService.ListStates:output_type -> daemon.ListStatesResponse
|
||||
42, // 73: daemon.DaemonService.CleanState:output_type -> daemon.CleanStateResponse
|
||||
44, // 74: daemon.DaemonService.DeleteState:output_type -> daemon.DeleteStateResponse
|
||||
46, // 75: daemon.DaemonService.SetSyncResponsePersistence:output_type -> daemon.SetSyncResponsePersistenceResponse
|
||||
50, // 76: daemon.DaemonService.TracePacket:output_type -> daemon.TracePacketResponse
|
||||
52, // 77: daemon.DaemonService.SubscribeEvents:output_type -> daemon.SystemEvent
|
||||
54, // 78: daemon.DaemonService.GetEvents:output_type -> daemon.GetEventsResponse
|
||||
56, // 79: daemon.DaemonService.SwitchProfile:output_type -> daemon.SwitchProfileResponse
|
||||
58, // 80: daemon.DaemonService.SetConfig:output_type -> daemon.SetConfigResponse
|
||||
60, // 81: daemon.DaemonService.AddProfile:output_type -> daemon.AddProfileResponse
|
||||
62, // 82: daemon.DaemonService.RemoveProfile:output_type -> daemon.RemoveProfileResponse
|
||||
64, // 83: daemon.DaemonService.ListProfiles:output_type -> daemon.ListProfilesResponse
|
||||
67, // 84: daemon.DaemonService.GetActiveProfile:output_type -> daemon.GetActiveProfileResponse
|
||||
69, // 85: daemon.DaemonService.Logout:output_type -> daemon.LogoutResponse
|
||||
71, // 86: daemon.DaemonService.GetFeatures:output_type -> daemon.GetFeaturesResponse
|
||||
59, // [59:87] is the sub-list for method output_type
|
||||
31, // [31:59] is the sub-list for method input_type
|
||||
72, // 59: daemon.DaemonService.GetPeerSSHHostKey:input_type -> daemon.GetPeerSSHHostKeyRequest
|
||||
74, // 60: daemon.DaemonService.RequestJWTAuth:input_type -> daemon.RequestJWTAuthRequest
|
||||
76, // 61: daemon.DaemonService.WaitJWTToken:input_type -> daemon.WaitJWTTokenRequest
|
||||
5, // 62: daemon.DaemonService.Login:output_type -> daemon.LoginResponse
|
||||
7, // 63: daemon.DaemonService.WaitSSOLogin:output_type -> daemon.WaitSSOLoginResponse
|
||||
9, // 64: daemon.DaemonService.Up:output_type -> daemon.UpResponse
|
||||
11, // 65: daemon.DaemonService.Status:output_type -> daemon.StatusResponse
|
||||
13, // 66: daemon.DaemonService.Down:output_type -> daemon.DownResponse
|
||||
15, // 67: daemon.DaemonService.GetConfig:output_type -> daemon.GetConfigResponse
|
||||
24, // 68: daemon.DaemonService.ListNetworks:output_type -> daemon.ListNetworksResponse
|
||||
26, // 69: daemon.DaemonService.SelectNetworks:output_type -> daemon.SelectNetworksResponse
|
||||
26, // 70: daemon.DaemonService.DeselectNetworks:output_type -> daemon.SelectNetworksResponse
|
||||
31, // 71: daemon.DaemonService.ForwardingRules:output_type -> daemon.ForwardingRulesResponse
|
||||
33, // 72: daemon.DaemonService.DebugBundle:output_type -> daemon.DebugBundleResponse
|
||||
35, // 73: daemon.DaemonService.GetLogLevel:output_type -> daemon.GetLogLevelResponse
|
||||
37, // 74: daemon.DaemonService.SetLogLevel:output_type -> daemon.SetLogLevelResponse
|
||||
40, // 75: daemon.DaemonService.ListStates:output_type -> daemon.ListStatesResponse
|
||||
42, // 76: daemon.DaemonService.CleanState:output_type -> daemon.CleanStateResponse
|
||||
44, // 77: daemon.DaemonService.DeleteState:output_type -> daemon.DeleteStateResponse
|
||||
46, // 78: daemon.DaemonService.SetSyncResponsePersistence:output_type -> daemon.SetSyncResponsePersistenceResponse
|
||||
50, // 79: daemon.DaemonService.TracePacket:output_type -> daemon.TracePacketResponse
|
||||
52, // 80: daemon.DaemonService.SubscribeEvents:output_type -> daemon.SystemEvent
|
||||
54, // 81: daemon.DaemonService.GetEvents:output_type -> daemon.GetEventsResponse
|
||||
56, // 82: daemon.DaemonService.SwitchProfile:output_type -> daemon.SwitchProfileResponse
|
||||
58, // 83: daemon.DaemonService.SetConfig:output_type -> daemon.SetConfigResponse
|
||||
60, // 84: daemon.DaemonService.AddProfile:output_type -> daemon.AddProfileResponse
|
||||
62, // 85: daemon.DaemonService.RemoveProfile:output_type -> daemon.RemoveProfileResponse
|
||||
64, // 86: daemon.DaemonService.ListProfiles:output_type -> daemon.ListProfilesResponse
|
||||
67, // 87: daemon.DaemonService.GetActiveProfile:output_type -> daemon.GetActiveProfileResponse
|
||||
69, // 88: daemon.DaemonService.Logout:output_type -> daemon.LogoutResponse
|
||||
71, // 89: daemon.DaemonService.GetFeatures:output_type -> daemon.GetFeaturesResponse
|
||||
73, // 90: daemon.DaemonService.GetPeerSSHHostKey:output_type -> daemon.GetPeerSSHHostKeyResponse
|
||||
75, // 91: daemon.DaemonService.RequestJWTAuth:output_type -> daemon.RequestJWTAuthResponse
|
||||
77, // 92: daemon.DaemonService.WaitJWTToken:output_type -> daemon.WaitJWTTokenResponse
|
||||
62, // [62:93] is the sub-list for method output_type
|
||||
31, // [31:62] is the sub-list for method input_type
|
||||
31, // [31:31] is the sub-list for extension type_name
|
||||
31, // [31:31] is the sub-list for extension extendee
|
||||
0, // [0:31] is the sub-list for field type_name
|
||||
@@ -5269,7 +5833,7 @@ func file_daemon_proto_init() {
|
||||
GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
|
||||
RawDescriptor: unsafe.Slice(unsafe.StringData(file_daemon_proto_rawDesc), len(file_daemon_proto_rawDesc)),
|
||||
NumEnums: 3,
|
||||
NumMessages: 72,
|
||||
NumMessages: 78,
|
||||
NumExtensions: 0,
|
||||
NumServices: 1,
|
||||
},
|
||||
|
||||
@@ -84,6 +84,15 @@ service DaemonService {
|
||||
rpc Logout(LogoutRequest) returns (LogoutResponse) {}
|
||||
|
||||
rpc GetFeatures(GetFeaturesRequest) returns (GetFeaturesResponse) {}
|
||||
|
||||
// GetPeerSSHHostKey retrieves SSH host key for a specific peer
|
||||
rpc GetPeerSSHHostKey(GetPeerSSHHostKeyRequest) returns (GetPeerSSHHostKeyResponse) {}
|
||||
|
||||
// RequestJWTAuth initiates JWT authentication flow for SSH
|
||||
rpc RequestJWTAuth(RequestJWTAuthRequest) returns (RequestJWTAuthResponse) {}
|
||||
|
||||
// WaitJWTToken waits for JWT authentication completion
|
||||
rpc WaitJWTToken(WaitJWTTokenRequest) returns (WaitJWTTokenResponse) {}
|
||||
}
|
||||
|
||||
|
||||
@@ -159,8 +168,11 @@ message LoginRequest {
|
||||
|
||||
optional int64 mtu = 32;
|
||||
|
||||
// hint is used to pre-fill the email/username field during SSO authentication
|
||||
optional string hint = 33;
|
||||
optional bool enableSSHRoot = 33;
|
||||
optional bool enableSSHSFTP = 34;
|
||||
optional bool enableSSHLocalPortForwarding = 35;
|
||||
optional bool enableSSHRemotePortForwarding = 36;
|
||||
optional bool disableSSHAuth = 37;
|
||||
}
|
||||
|
||||
message LoginResponse {
|
||||
@@ -188,7 +200,7 @@ message UpResponse {}
|
||||
|
||||
message StatusRequest{
|
||||
bool getFullPeerStatus = 1;
|
||||
bool shouldRunProbes = 2;
|
||||
bool shouldRunProbes = 2;
|
||||
// the UI do not using this yet, but CLIs could use it to wait until the status is ready
|
||||
optional bool waitForReady = 3;
|
||||
}
|
||||
@@ -255,6 +267,16 @@ message GetConfigResponse {
|
||||
bool disable_server_routes = 19;
|
||||
|
||||
bool block_lan_access = 20;
|
||||
|
||||
bool enableSSHRoot = 21;
|
||||
|
||||
bool enableSSHSFTP = 24;
|
||||
|
||||
bool enableSSHLocalPortForwarding = 22;
|
||||
|
||||
bool enableSSHRemotePortForwarding = 23;
|
||||
|
||||
bool disableSSHAuth = 25;
|
||||
}
|
||||
|
||||
// PeerState contains the latest state of a peer
|
||||
@@ -276,6 +298,7 @@ message PeerState {
|
||||
repeated string networks = 16;
|
||||
google.protobuf.Duration latency = 17;
|
||||
string relayAddress = 18;
|
||||
bytes sshHostKey = 19;
|
||||
}
|
||||
|
||||
// LocalPeerState contains the latest state of the local peer
|
||||
@@ -543,56 +566,62 @@ message SwitchProfileRequest {
|
||||
message SwitchProfileResponse {}
|
||||
|
||||
message SetConfigRequest {
|
||||
string username = 1;
|
||||
string profileName = 2;
|
||||
// managementUrl to authenticate.
|
||||
string managementUrl = 3;
|
||||
string username = 1;
|
||||
string profileName = 2;
|
||||
// managementUrl to authenticate.
|
||||
string managementUrl = 3;
|
||||
|
||||
// adminUrl to manage keys.
|
||||
string adminURL = 4;
|
||||
// adminUrl to manage keys.
|
||||
string adminURL = 4;
|
||||
|
||||
optional bool rosenpassEnabled = 5;
|
||||
optional bool rosenpassEnabled = 5;
|
||||
|
||||
optional string interfaceName = 6;
|
||||
optional string interfaceName = 6;
|
||||
|
||||
optional int64 wireguardPort = 7;
|
||||
optional int64 wireguardPort = 7;
|
||||
|
||||
optional string optionalPreSharedKey = 8;
|
||||
optional string optionalPreSharedKey = 8;
|
||||
|
||||
optional bool disableAutoConnect = 9;
|
||||
optional bool disableAutoConnect = 9;
|
||||
|
||||
optional bool serverSSHAllowed = 10;
|
||||
optional bool serverSSHAllowed = 10;
|
||||
|
||||
optional bool rosenpassPermissive = 11;
|
||||
optional bool rosenpassPermissive = 11;
|
||||
|
||||
optional bool networkMonitor = 12;
|
||||
optional bool networkMonitor = 12;
|
||||
|
||||
optional bool disable_client_routes = 13;
|
||||
optional bool disable_server_routes = 14;
|
||||
optional bool disable_dns = 15;
|
||||
optional bool disable_firewall = 16;
|
||||
optional bool block_lan_access = 17;
|
||||
optional bool disable_client_routes = 13;
|
||||
optional bool disable_server_routes = 14;
|
||||
optional bool disable_dns = 15;
|
||||
optional bool disable_firewall = 16;
|
||||
optional bool block_lan_access = 17;
|
||||
|
||||
optional bool disable_notifications = 18;
|
||||
optional bool disable_notifications = 18;
|
||||
|
||||
optional bool lazyConnectionEnabled = 19;
|
||||
optional bool lazyConnectionEnabled = 19;
|
||||
|
||||
optional bool block_inbound = 20;
|
||||
optional bool block_inbound = 20;
|
||||
|
||||
repeated string natExternalIPs = 21;
|
||||
bool cleanNATExternalIPs = 22;
|
||||
repeated string natExternalIPs = 21;
|
||||
bool cleanNATExternalIPs = 22;
|
||||
|
||||
bytes customDNSAddress = 23;
|
||||
bytes customDNSAddress = 23;
|
||||
|
||||
repeated string extraIFaceBlacklist = 24;
|
||||
repeated string extraIFaceBlacklist = 24;
|
||||
|
||||
repeated string dns_labels = 25;
|
||||
// cleanDNSLabels clean map list of DNS labels.
|
||||
bool cleanDNSLabels = 26;
|
||||
repeated string dns_labels = 25;
|
||||
// cleanDNSLabels clean map list of DNS labels.
|
||||
bool cleanDNSLabels = 26;
|
||||
|
||||
optional google.protobuf.Duration dnsRouteInterval = 27;
|
||||
optional google.protobuf.Duration dnsRouteInterval = 27;
|
||||
|
||||
optional int64 mtu = 28;
|
||||
optional int64 mtu = 28;
|
||||
|
||||
optional bool enableSSHRoot = 29;
|
||||
optional bool enableSSHSFTP = 30;
|
||||
optional bool enableSSHLocalPortForward = 31;
|
||||
optional bool enableSSHRemotePortForward = 32;
|
||||
optional bool disableSSHAuth = 33;
|
||||
}
|
||||
|
||||
message SetConfigResponse{}
|
||||
@@ -644,3 +673,61 @@ message GetFeaturesResponse{
|
||||
bool disable_profiles = 1;
|
||||
bool disable_update_settings = 2;
|
||||
}
|
||||
|
||||
// GetPeerSSHHostKeyRequest for retrieving SSH host key for a specific peer
|
||||
message GetPeerSSHHostKeyRequest {
|
||||
// peer IP address or FQDN to get SSH host key for
|
||||
string peerAddress = 1;
|
||||
}
|
||||
|
||||
// GetPeerSSHHostKeyResponse contains the SSH host key for the requested peer
|
||||
message GetPeerSSHHostKeyResponse {
|
||||
// SSH host key in SSH public key format (e.g., "ssh-ed25519 AAAAC3... hostname")
|
||||
bytes sshHostKey = 1;
|
||||
// peer IP address
|
||||
string peerIP = 2;
|
||||
// peer FQDN
|
||||
string peerFQDN = 3;
|
||||
// indicates if the SSH host key was found
|
||||
bool found = 4;
|
||||
}
|
||||
|
||||
// RequestJWTAuthRequest for initiating JWT authentication flow
|
||||
message RequestJWTAuthRequest {
|
||||
}
|
||||
|
||||
// RequestJWTAuthResponse contains authentication flow information
|
||||
message RequestJWTAuthResponse {
|
||||
// verification URI for user authentication
|
||||
string verificationURI = 1;
|
||||
// complete verification URI (with embedded user code)
|
||||
string verificationURIComplete = 2;
|
||||
// user code to enter on verification URI
|
||||
string userCode = 3;
|
||||
// device code for polling
|
||||
string deviceCode = 4;
|
||||
// expiration time in seconds
|
||||
int64 expiresIn = 5;
|
||||
// if a cached token is available, it will be returned here
|
||||
string cachedToken = 6;
|
||||
// maximum age of JWT tokens in seconds (from management server)
|
||||
int64 maxTokenAge = 7;
|
||||
}
|
||||
|
||||
// WaitJWTTokenRequest for waiting for authentication completion
|
||||
message WaitJWTTokenRequest {
|
||||
// device code from RequestJWTAuthResponse
|
||||
string deviceCode = 1;
|
||||
// user code for verification
|
||||
string userCode = 2;
|
||||
}
|
||||
|
||||
// WaitJWTTokenResponse contains the JWT token after authentication
|
||||
message WaitJWTTokenResponse {
|
||||
// JWT token (access token or ID token)
|
||||
string token = 1;
|
||||
// token type (e.g., "Bearer")
|
||||
string tokenType = 2;
|
||||
// expiration time in seconds
|
||||
int64 expiresIn = 3;
|
||||
}
|
||||
|
||||
@@ -64,6 +64,12 @@ type DaemonServiceClient interface {
|
||||
// Logout disconnects from the network and deletes the peer from the management server
|
||||
Logout(ctx context.Context, in *LogoutRequest, opts ...grpc.CallOption) (*LogoutResponse, error)
|
||||
GetFeatures(ctx context.Context, in *GetFeaturesRequest, opts ...grpc.CallOption) (*GetFeaturesResponse, error)
|
||||
// GetPeerSSHHostKey retrieves SSH host key for a specific peer
|
||||
GetPeerSSHHostKey(ctx context.Context, in *GetPeerSSHHostKeyRequest, opts ...grpc.CallOption) (*GetPeerSSHHostKeyResponse, error)
|
||||
// RequestJWTAuth initiates JWT authentication flow for SSH
|
||||
RequestJWTAuth(ctx context.Context, in *RequestJWTAuthRequest, opts ...grpc.CallOption) (*RequestJWTAuthResponse, error)
|
||||
// WaitJWTToken waits for JWT authentication completion
|
||||
WaitJWTToken(ctx context.Context, in *WaitJWTTokenRequest, opts ...grpc.CallOption) (*WaitJWTTokenResponse, error)
|
||||
}
|
||||
|
||||
type daemonServiceClient struct {
|
||||
@@ -349,6 +355,33 @@ func (c *daemonServiceClient) GetFeatures(ctx context.Context, in *GetFeaturesRe
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (c *daemonServiceClient) GetPeerSSHHostKey(ctx context.Context, in *GetPeerSSHHostKeyRequest, opts ...grpc.CallOption) (*GetPeerSSHHostKeyResponse, error) {
|
||||
out := new(GetPeerSSHHostKeyResponse)
|
||||
err := c.cc.Invoke(ctx, "/daemon.DaemonService/GetPeerSSHHostKey", in, out, opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (c *daemonServiceClient) RequestJWTAuth(ctx context.Context, in *RequestJWTAuthRequest, opts ...grpc.CallOption) (*RequestJWTAuthResponse, error) {
|
||||
out := new(RequestJWTAuthResponse)
|
||||
err := c.cc.Invoke(ctx, "/daemon.DaemonService/RequestJWTAuth", in, out, opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (c *daemonServiceClient) WaitJWTToken(ctx context.Context, in *WaitJWTTokenRequest, opts ...grpc.CallOption) (*WaitJWTTokenResponse, error) {
|
||||
out := new(WaitJWTTokenResponse)
|
||||
err := c.cc.Invoke(ctx, "/daemon.DaemonService/WaitJWTToken", in, out, opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// DaemonServiceServer is the server API for DaemonService service.
|
||||
// All implementations must embed UnimplementedDaemonServiceServer
|
||||
// for forward compatibility
|
||||
@@ -399,6 +432,12 @@ type DaemonServiceServer interface {
|
||||
// Logout disconnects from the network and deletes the peer from the management server
|
||||
Logout(context.Context, *LogoutRequest) (*LogoutResponse, error)
|
||||
GetFeatures(context.Context, *GetFeaturesRequest) (*GetFeaturesResponse, error)
|
||||
// GetPeerSSHHostKey retrieves SSH host key for a specific peer
|
||||
GetPeerSSHHostKey(context.Context, *GetPeerSSHHostKeyRequest) (*GetPeerSSHHostKeyResponse, error)
|
||||
// RequestJWTAuth initiates JWT authentication flow for SSH
|
||||
RequestJWTAuth(context.Context, *RequestJWTAuthRequest) (*RequestJWTAuthResponse, error)
|
||||
// WaitJWTToken waits for JWT authentication completion
|
||||
WaitJWTToken(context.Context, *WaitJWTTokenRequest) (*WaitJWTTokenResponse, error)
|
||||
mustEmbedUnimplementedDaemonServiceServer()
|
||||
}
|
||||
|
||||
@@ -490,6 +529,15 @@ func (UnimplementedDaemonServiceServer) Logout(context.Context, *LogoutRequest)
|
||||
func (UnimplementedDaemonServiceServer) GetFeatures(context.Context, *GetFeaturesRequest) (*GetFeaturesResponse, error) {
|
||||
return nil, status.Errorf(codes.Unimplemented, "method GetFeatures not implemented")
|
||||
}
|
||||
func (UnimplementedDaemonServiceServer) GetPeerSSHHostKey(context.Context, *GetPeerSSHHostKeyRequest) (*GetPeerSSHHostKeyResponse, error) {
|
||||
return nil, status.Errorf(codes.Unimplemented, "method GetPeerSSHHostKey not implemented")
|
||||
}
|
||||
func (UnimplementedDaemonServiceServer) RequestJWTAuth(context.Context, *RequestJWTAuthRequest) (*RequestJWTAuthResponse, error) {
|
||||
return nil, status.Errorf(codes.Unimplemented, "method RequestJWTAuth not implemented")
|
||||
}
|
||||
func (UnimplementedDaemonServiceServer) WaitJWTToken(context.Context, *WaitJWTTokenRequest) (*WaitJWTTokenResponse, error) {
|
||||
return nil, status.Errorf(codes.Unimplemented, "method WaitJWTToken not implemented")
|
||||
}
|
||||
func (UnimplementedDaemonServiceServer) mustEmbedUnimplementedDaemonServiceServer() {}
|
||||
|
||||
// UnsafeDaemonServiceServer may be embedded to opt out of forward compatibility for this service.
|
||||
@@ -1010,6 +1058,60 @@ func _DaemonService_GetFeatures_Handler(srv interface{}, ctx context.Context, de
|
||||
return interceptor(ctx, in, info, handler)
|
||||
}
|
||||
|
||||
func _DaemonService_GetPeerSSHHostKey_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
|
||||
in := new(GetPeerSSHHostKeyRequest)
|
||||
if err := dec(in); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if interceptor == nil {
|
||||
return srv.(DaemonServiceServer).GetPeerSSHHostKey(ctx, in)
|
||||
}
|
||||
info := &grpc.UnaryServerInfo{
|
||||
Server: srv,
|
||||
FullMethod: "/daemon.DaemonService/GetPeerSSHHostKey",
|
||||
}
|
||||
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
return srv.(DaemonServiceServer).GetPeerSSHHostKey(ctx, req.(*GetPeerSSHHostKeyRequest))
|
||||
}
|
||||
return interceptor(ctx, in, info, handler)
|
||||
}
|
||||
|
||||
func _DaemonService_RequestJWTAuth_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
|
||||
in := new(RequestJWTAuthRequest)
|
||||
if err := dec(in); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if interceptor == nil {
|
||||
return srv.(DaemonServiceServer).RequestJWTAuth(ctx, in)
|
||||
}
|
||||
info := &grpc.UnaryServerInfo{
|
||||
Server: srv,
|
||||
FullMethod: "/daemon.DaemonService/RequestJWTAuth",
|
||||
}
|
||||
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
return srv.(DaemonServiceServer).RequestJWTAuth(ctx, req.(*RequestJWTAuthRequest))
|
||||
}
|
||||
return interceptor(ctx, in, info, handler)
|
||||
}
|
||||
|
||||
func _DaemonService_WaitJWTToken_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
|
||||
in := new(WaitJWTTokenRequest)
|
||||
if err := dec(in); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if interceptor == nil {
|
||||
return srv.(DaemonServiceServer).WaitJWTToken(ctx, in)
|
||||
}
|
||||
info := &grpc.UnaryServerInfo{
|
||||
Server: srv,
|
||||
FullMethod: "/daemon.DaemonService/WaitJWTToken",
|
||||
}
|
||||
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
return srv.(DaemonServiceServer).WaitJWTToken(ctx, req.(*WaitJWTTokenRequest))
|
||||
}
|
||||
return interceptor(ctx, in, info, handler)
|
||||
}
|
||||
|
||||
// DaemonService_ServiceDesc is the grpc.ServiceDesc for DaemonService service.
|
||||
// It's only intended for direct use with grpc.RegisterService,
|
||||
// and not to be introspected or modified (even as a copy)
|
||||
@@ -1125,6 +1227,18 @@ var DaemonService_ServiceDesc = grpc.ServiceDesc{
|
||||
MethodName: "GetFeatures",
|
||||
Handler: _DaemonService_GetFeatures_Handler,
|
||||
},
|
||||
{
|
||||
MethodName: "GetPeerSSHHostKey",
|
||||
Handler: _DaemonService_GetPeerSSHHostKey_Handler,
|
||||
},
|
||||
{
|
||||
MethodName: "RequestJWTAuth",
|
||||
Handler: _DaemonService_RequestJWTAuth_Handler,
|
||||
},
|
||||
{
|
||||
MethodName: "WaitJWTToken",
|
||||
Handler: _DaemonService_WaitJWTToken_Handler,
|
||||
},
|
||||
},
|
||||
Streams: []grpc.StreamDesc{
|
||||
{
|
||||
|
||||
73
client/server/jwt_cache.go
Normal file
73
client/server/jwt_cache.go
Normal file
@@ -0,0 +1,73 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/awnumar/memguard"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
type jwtCache struct {
|
||||
mu sync.RWMutex
|
||||
enclave *memguard.Enclave
|
||||
expiresAt time.Time
|
||||
timer *time.Timer
|
||||
maxTokenSize int
|
||||
}
|
||||
|
||||
func newJWTCache() *jwtCache {
|
||||
return &jwtCache{
|
||||
maxTokenSize: 8192,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *jwtCache) store(token string, maxAge time.Duration) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
c.cleanup()
|
||||
|
||||
if c.timer != nil {
|
||||
c.timer.Stop()
|
||||
}
|
||||
|
||||
tokenBytes := []byte(token)
|
||||
c.enclave = memguard.NewEnclave(tokenBytes)
|
||||
|
||||
c.expiresAt = time.Now().Add(maxAge)
|
||||
|
||||
c.timer = time.AfterFunc(maxAge, func() {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
c.cleanup()
|
||||
c.timer = nil
|
||||
log.Debugf("JWT token cache expired after %v, securely wiped from memory", maxAge)
|
||||
})
|
||||
}
|
||||
|
||||
func (c *jwtCache) get() (string, bool) {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
|
||||
if c.enclave == nil || time.Now().After(c.expiresAt) {
|
||||
return "", false
|
||||
}
|
||||
|
||||
buffer, err := c.enclave.Open()
|
||||
if err != nil {
|
||||
log.Debugf("Failed to open JWT token enclave: %v", err)
|
||||
return "", false
|
||||
}
|
||||
defer buffer.Destroy()
|
||||
|
||||
token := string(buffer.Bytes())
|
||||
return token, true
|
||||
}
|
||||
|
||||
// cleanup destroys the secure enclave, must be called with lock held
|
||||
func (c *jwtCache) cleanup() {
|
||||
if c.enclave != nil {
|
||||
c.enclave = nil
|
||||
}
|
||||
}
|
||||
@@ -11,8 +11,8 @@ import (
|
||||
"golang.org/x/exp/maps"
|
||||
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
)
|
||||
|
||||
type selectRoute struct {
|
||||
|
||||
@@ -46,6 +46,9 @@ const (
|
||||
defaultMaxRetryTime = 14 * 24 * time.Hour
|
||||
defaultRetryMultiplier = 1.7
|
||||
|
||||
// JWT token cache TTL for the client daemon
|
||||
defaultJWTCacheTTL = 5 * time.Minute
|
||||
|
||||
errRestoreResidualState = "failed to restore residual state: %v"
|
||||
errProfilesDisabled = "profiles are disabled, you cannot use this feature without profiles enabled"
|
||||
errUpdateSettingsDisabled = "update settings are disabled, you cannot use this feature without update settings enabled"
|
||||
@@ -81,6 +84,8 @@ type Server struct {
|
||||
profileManager *profilemanager.ServiceManager
|
||||
profilesDisabled bool
|
||||
updateSettingsDisabled bool
|
||||
|
||||
jwtCache *jwtCache
|
||||
}
|
||||
|
||||
type oauthAuthFlow struct {
|
||||
@@ -100,6 +105,7 @@ func New(ctx context.Context, logFile string, configFile string, profilesDisable
|
||||
profileManager: profilemanager.NewServiceManager(configFile),
|
||||
profilesDisabled: profilesDisabled,
|
||||
updateSettingsDisabled: updateSettingsDisabled,
|
||||
jwtCache: newJWTCache(),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -373,6 +379,13 @@ func (s *Server) SetConfig(callerCtx context.Context, msg *proto.SetConfigReques
|
||||
config.DisableNotifications = msg.DisableNotifications
|
||||
config.LazyConnectionEnabled = msg.LazyConnectionEnabled
|
||||
config.BlockInbound = msg.BlockInbound
|
||||
config.EnableSSHRoot = msg.EnableSSHRoot
|
||||
config.EnableSSHSFTP = msg.EnableSSHSFTP
|
||||
config.EnableSSHLocalPortForwarding = msg.EnableSSHLocalPortForward
|
||||
config.EnableSSHRemotePortForwarding = msg.EnableSSHRemotePortForward
|
||||
if msg.DisableSSHAuth != nil {
|
||||
config.DisableSSHAuth = msg.DisableSSHAuth
|
||||
}
|
||||
|
||||
if msg.Mtu != nil {
|
||||
mtu := uint16(*msg.Mtu)
|
||||
@@ -483,17 +496,13 @@ func (s *Server) Login(callerCtx context.Context, msg *proto.LoginRequest) (*pro
|
||||
state.Set(internal.StatusConnecting)
|
||||
|
||||
if msg.SetupKey == "" {
|
||||
hint := ""
|
||||
if msg.Hint != nil {
|
||||
hint = *msg.Hint
|
||||
}
|
||||
oAuthFlow, err := auth.NewOAuthFlow(ctx, config, msg.IsUnixDesktopClient, hint)
|
||||
oAuthFlow, err := auth.NewOAuthFlow(ctx, config, msg.IsUnixDesktopClient)
|
||||
if err != nil {
|
||||
state.Set(internal.StatusLoginFailed)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if s.oauthAuthFlow.flow != nil && s.oauthAuthFlow.flow.GetClientID(ctx) == oAuthFlow.GetClientID(context.TODO()) {
|
||||
if s.oauthAuthFlow.flow != nil && s.oauthAuthFlow.flow.GetClientID(ctx) == oAuthFlow.GetClientID(ctx) {
|
||||
if s.oauthAuthFlow.expiresAt.After(time.Now().Add(90 * time.Second)) {
|
||||
log.Debugf("using previous oauth flow info")
|
||||
return &proto.LoginResponse{
|
||||
@@ -510,7 +519,7 @@ func (s *Server) Login(callerCtx context.Context, msg *proto.LoginRequest) (*pro
|
||||
}
|
||||
}
|
||||
|
||||
authInfo, err := oAuthFlow.RequestAuthInfo(context.TODO())
|
||||
authInfo, err := oAuthFlow.RequestAuthInfo(ctx)
|
||||
if err != nil {
|
||||
log.Errorf("getting a request OAuth flow failed: %v", err)
|
||||
return nil, err
|
||||
@@ -1071,6 +1080,189 @@ func (s *Server) Status(
|
||||
return &statusResponse, nil
|
||||
}
|
||||
|
||||
// GetPeerSSHHostKey retrieves SSH host key for a specific peer
|
||||
func (s *Server) GetPeerSSHHostKey(
|
||||
ctx context.Context,
|
||||
req *proto.GetPeerSSHHostKeyRequest,
|
||||
) (*proto.GetPeerSSHHostKeyResponse, error) {
|
||||
if ctx.Err() != nil {
|
||||
return nil, ctx.Err()
|
||||
}
|
||||
|
||||
s.mutex.Lock()
|
||||
connectClient := s.connectClient
|
||||
statusRecorder := s.statusRecorder
|
||||
s.mutex.Unlock()
|
||||
|
||||
if connectClient == nil {
|
||||
return nil, errors.New("client not initialized")
|
||||
}
|
||||
|
||||
engine := connectClient.Engine()
|
||||
if engine == nil {
|
||||
return nil, errors.New("engine not started")
|
||||
}
|
||||
|
||||
peerAddress := req.GetPeerAddress()
|
||||
hostKey, found := engine.GetPeerSSHKey(peerAddress)
|
||||
|
||||
response := &proto.GetPeerSSHHostKeyResponse{
|
||||
Found: found,
|
||||
}
|
||||
|
||||
if !found {
|
||||
return response, nil
|
||||
}
|
||||
|
||||
response.SshHostKey = hostKey
|
||||
|
||||
if statusRecorder == nil {
|
||||
return response, nil
|
||||
}
|
||||
|
||||
fullStatus := statusRecorder.GetFullStatus()
|
||||
for _, peerState := range fullStatus.Peers {
|
||||
if peerState.IP == peerAddress || peerState.FQDN == peerAddress {
|
||||
response.PeerIP = peerState.IP
|
||||
response.PeerFQDN = peerState.FQDN
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return response, nil
|
||||
}
|
||||
|
||||
// getJWTCacheTTL returns the JWT cache TTL from environment variable or default
|
||||
// NB_SSH_JWT_CACHE_TTL=0 disables caching
|
||||
// NB_SSH_JWT_CACHE_TTL=<seconds> sets custom cache TTL
|
||||
func getJWTCacheTTL() time.Duration {
|
||||
envValue := os.Getenv("NB_SSH_JWT_CACHE_TTL")
|
||||
if envValue == "" {
|
||||
return defaultJWTCacheTTL
|
||||
}
|
||||
|
||||
seconds, err := strconv.Atoi(envValue)
|
||||
if err != nil {
|
||||
log.Warnf("invalid NB_SSH_JWT_CACHE_TTL value %s, using default: %v", envValue, defaultJWTCacheTTL)
|
||||
return defaultJWTCacheTTL
|
||||
}
|
||||
|
||||
if seconds == 0 {
|
||||
log.Info("SSH JWT cache disabled via NB_SSH_JWT_CACHE_TTL=0")
|
||||
return 0
|
||||
}
|
||||
|
||||
ttl := time.Duration(seconds) * time.Second
|
||||
log.Infof("SSH JWT cache TTL set to %v via NB_SSH_JWT_CACHE_TTL", ttl)
|
||||
return ttl
|
||||
}
|
||||
|
||||
// RequestJWTAuth initiates JWT authentication flow for SSH
|
||||
func (s *Server) RequestJWTAuth(
|
||||
ctx context.Context,
|
||||
_ *proto.RequestJWTAuthRequest,
|
||||
) (*proto.RequestJWTAuthResponse, error) {
|
||||
if ctx.Err() != nil {
|
||||
return nil, ctx.Err()
|
||||
}
|
||||
|
||||
s.mutex.Lock()
|
||||
config := s.config
|
||||
s.mutex.Unlock()
|
||||
|
||||
if config == nil {
|
||||
return nil, gstatus.Errorf(codes.FailedPrecondition, "client is not configured")
|
||||
}
|
||||
|
||||
jwtCacheTTL := getJWTCacheTTL()
|
||||
if jwtCacheTTL > 0 {
|
||||
if cachedToken, found := s.jwtCache.get(); found {
|
||||
log.Debugf("JWT token found in cache, returning cached token for SSH authentication")
|
||||
|
||||
return &proto.RequestJWTAuthResponse{
|
||||
CachedToken: cachedToken,
|
||||
MaxTokenAge: int64(jwtCacheTTL.Seconds()),
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
|
||||
isDesktop := isUnixRunningDesktop()
|
||||
oAuthFlow, err := auth.NewOAuthFlow(ctx, config, isDesktop)
|
||||
if err != nil {
|
||||
return nil, gstatus.Errorf(codes.Internal, "failed to create OAuth flow: %v", err)
|
||||
}
|
||||
|
||||
authInfo, err := oAuthFlow.RequestAuthInfo(ctx)
|
||||
if err != nil {
|
||||
return nil, gstatus.Errorf(codes.Internal, "failed to request auth info: %v", err)
|
||||
}
|
||||
|
||||
s.mutex.Lock()
|
||||
s.oauthAuthFlow.flow = oAuthFlow
|
||||
s.oauthAuthFlow.info = authInfo
|
||||
s.oauthAuthFlow.expiresAt = time.Now().Add(time.Duration(authInfo.ExpiresIn) * time.Second)
|
||||
s.mutex.Unlock()
|
||||
|
||||
return &proto.RequestJWTAuthResponse{
|
||||
VerificationURI: authInfo.VerificationURI,
|
||||
VerificationURIComplete: authInfo.VerificationURIComplete,
|
||||
UserCode: authInfo.UserCode,
|
||||
DeviceCode: authInfo.DeviceCode,
|
||||
ExpiresIn: int64(authInfo.ExpiresIn),
|
||||
MaxTokenAge: int64(jwtCacheTTL.Seconds()),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// WaitJWTToken waits for JWT authentication completion
|
||||
func (s *Server) WaitJWTToken(
|
||||
ctx context.Context,
|
||||
req *proto.WaitJWTTokenRequest,
|
||||
) (*proto.WaitJWTTokenResponse, error) {
|
||||
if ctx.Err() != nil {
|
||||
return nil, ctx.Err()
|
||||
}
|
||||
|
||||
s.mutex.Lock()
|
||||
oAuthFlow := s.oauthAuthFlow.flow
|
||||
authInfo := s.oauthAuthFlow.info
|
||||
s.mutex.Unlock()
|
||||
|
||||
if oAuthFlow == nil || authInfo.DeviceCode != req.DeviceCode {
|
||||
return nil, gstatus.Errorf(codes.InvalidArgument, "invalid device code or no active auth flow")
|
||||
}
|
||||
|
||||
tokenInfo, err := oAuthFlow.WaitToken(ctx, authInfo)
|
||||
if err != nil {
|
||||
return nil, gstatus.Errorf(codes.Internal, "failed to get token: %v", err)
|
||||
}
|
||||
|
||||
token := tokenInfo.GetTokenToUse()
|
||||
|
||||
jwtCacheTTL := getJWTCacheTTL()
|
||||
if jwtCacheTTL > 0 {
|
||||
s.jwtCache.store(token, jwtCacheTTL)
|
||||
log.Debugf("JWT token cached for SSH authentication, TTL: %v", jwtCacheTTL)
|
||||
} else {
|
||||
log.Debug("JWT caching disabled, not storing token")
|
||||
}
|
||||
|
||||
s.mutex.Lock()
|
||||
s.oauthAuthFlow = oauthAuthFlow{}
|
||||
s.mutex.Unlock()
|
||||
return &proto.WaitJWTTokenResponse{
|
||||
Token: tokenInfo.GetTokenToUse(),
|
||||
TokenType: tokenInfo.TokenType,
|
||||
ExpiresIn: int64(tokenInfo.ExpiresIn),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func isUnixRunningDesktop() bool {
|
||||
if runtime.GOOS != "linux" && runtime.GOOS != "freebsd" {
|
||||
return false
|
||||
}
|
||||
return os.Getenv("DESKTOP_SESSION") != "" || os.Getenv("XDG_CURRENT_DESKTOP") != ""
|
||||
}
|
||||
|
||||
func (s *Server) runProbes(waitForProbeResult bool) {
|
||||
if s.connectClient == nil {
|
||||
return
|
||||
@@ -1136,25 +1328,55 @@ func (s *Server) GetConfig(ctx context.Context, req *proto.GetConfigRequest) (*p
|
||||
disableServerRoutes := cfg.DisableServerRoutes
|
||||
blockLANAccess := cfg.BlockLANAccess
|
||||
|
||||
enableSSHRoot := false
|
||||
if s.config.EnableSSHRoot != nil {
|
||||
enableSSHRoot = *s.config.EnableSSHRoot
|
||||
}
|
||||
|
||||
enableSSHSFTP := false
|
||||
if s.config.EnableSSHSFTP != nil {
|
||||
enableSSHSFTP = *s.config.EnableSSHSFTP
|
||||
}
|
||||
|
||||
enableSSHLocalPortForwarding := false
|
||||
if s.config.EnableSSHLocalPortForwarding != nil {
|
||||
enableSSHLocalPortForwarding = *s.config.EnableSSHLocalPortForwarding
|
||||
}
|
||||
|
||||
enableSSHRemotePortForwarding := false
|
||||
if s.config.EnableSSHRemotePortForwarding != nil {
|
||||
enableSSHRemotePortForwarding = *s.config.EnableSSHRemotePortForwarding
|
||||
}
|
||||
|
||||
disableSSHAuth := false
|
||||
if s.config.DisableSSHAuth != nil {
|
||||
disableSSHAuth = *s.config.DisableSSHAuth
|
||||
}
|
||||
|
||||
return &proto.GetConfigResponse{
|
||||
ManagementUrl: managementURL.String(),
|
||||
PreSharedKey: preSharedKey,
|
||||
AdminURL: adminURL.String(),
|
||||
InterfaceName: cfg.WgIface,
|
||||
WireguardPort: int64(cfg.WgPort),
|
||||
Mtu: int64(cfg.MTU),
|
||||
DisableAutoConnect: cfg.DisableAutoConnect,
|
||||
ServerSSHAllowed: *cfg.ServerSSHAllowed,
|
||||
RosenpassEnabled: cfg.RosenpassEnabled,
|
||||
RosenpassPermissive: cfg.RosenpassPermissive,
|
||||
LazyConnectionEnabled: cfg.LazyConnectionEnabled,
|
||||
BlockInbound: cfg.BlockInbound,
|
||||
DisableNotifications: disableNotifications,
|
||||
NetworkMonitor: networkMonitor,
|
||||
DisableDns: disableDNS,
|
||||
DisableClientRoutes: disableClientRoutes,
|
||||
DisableServerRoutes: disableServerRoutes,
|
||||
BlockLanAccess: blockLANAccess,
|
||||
ManagementUrl: managementURL.String(),
|
||||
PreSharedKey: preSharedKey,
|
||||
AdminURL: adminURL.String(),
|
||||
InterfaceName: cfg.WgIface,
|
||||
WireguardPort: int64(cfg.WgPort),
|
||||
Mtu: int64(cfg.MTU),
|
||||
DisableAutoConnect: cfg.DisableAutoConnect,
|
||||
ServerSSHAllowed: *cfg.ServerSSHAllowed,
|
||||
RosenpassEnabled: cfg.RosenpassEnabled,
|
||||
RosenpassPermissive: cfg.RosenpassPermissive,
|
||||
LazyConnectionEnabled: cfg.LazyConnectionEnabled,
|
||||
BlockInbound: cfg.BlockInbound,
|
||||
DisableNotifications: disableNotifications,
|
||||
NetworkMonitor: networkMonitor,
|
||||
DisableDns: disableDNS,
|
||||
DisableClientRoutes: disableClientRoutes,
|
||||
DisableServerRoutes: disableServerRoutes,
|
||||
BlockLanAccess: blockLANAccess,
|
||||
EnableSSHRoot: enableSSHRoot,
|
||||
EnableSSHSFTP: enableSSHSFTP,
|
||||
EnableSSHLocalPortForwarding: enableSSHLocalPortForwarding,
|
||||
EnableSSHRemotePortForwarding: enableSSHRemotePortForwarding,
|
||||
DisableSSHAuth: disableSSHAuth,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -1385,6 +1607,7 @@ func toProtoFullStatus(fullStatus peer.FullStatus) *proto.FullStatus {
|
||||
RosenpassEnabled: peerState.RosenpassEnabled,
|
||||
Networks: maps.Keys(peerState.GetRoutes()),
|
||||
Latency: durationpb.New(peerState.Latency),
|
||||
SshHostKey: peerState.SSHHostKey,
|
||||
}
|
||||
pbFullStatus.Peers = append(pbFullStatus.Peers, pbPeerState)
|
||||
}
|
||||
|
||||
@@ -14,9 +14,6 @@ import (
|
||||
"go.opentelemetry.io/otel"
|
||||
|
||||
"github.com/netbirdio/management-integrations/integrations"
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
|
||||
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/server/config"
|
||||
"github.com/netbirdio/netbird/management/server/groups"
|
||||
@@ -293,6 +290,7 @@ func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Serve
|
||||
}
|
||||
t.Cleanup(cleanUp)
|
||||
|
||||
peersUpdateManager := server.NewPeersUpdateManager(nil)
|
||||
eventStore := &activity.InMemoryEventStore{}
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
@@ -313,16 +311,13 @@ func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Serve
|
||||
settingsMockManager := settings.NewMockManager(ctrl)
|
||||
groupsManager := groups.NewManagerMock()
|
||||
|
||||
requestBuffer := server.NewAccountRequestBuffer(context.Background(), store)
|
||||
peersUpdateManager := update_channel.NewPeersUpdateManager(metrics)
|
||||
networkMapController := controller.NewController(context.Background(), store, metrics, peersUpdateManager, requestBuffer, server.MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock())
|
||||
accountManager, err := server.BuildManager(context.Background(), store, networkMapController, nil, "", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false)
|
||||
accountManager, err := server.BuildManager(context.Background(), config, store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
secretsManager := nbgrpc.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
|
||||
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, &manager.EphemeralManager{}, nil, &server.MockIntegratedValidator{}, networkMapController)
|
||||
secretsManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
|
||||
mgmtServer, err := server.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, &manager.EphemeralManager{}, nil, &server.MockIntegratedValidator{})
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
@@ -167,30 +167,35 @@ func verifyAllFieldsCovered(t *testing.T, req *proto.SetConfigRequest) {
|
||||
}
|
||||
|
||||
expectedFields := map[string]bool{
|
||||
"ManagementUrl": true,
|
||||
"AdminURL": true,
|
||||
"RosenpassEnabled": true,
|
||||
"RosenpassPermissive": true,
|
||||
"ServerSSHAllowed": true,
|
||||
"InterfaceName": true,
|
||||
"WireguardPort": true,
|
||||
"OptionalPreSharedKey": true,
|
||||
"DisableAutoConnect": true,
|
||||
"NetworkMonitor": true,
|
||||
"DisableClientRoutes": true,
|
||||
"DisableServerRoutes": true,
|
||||
"DisableDns": true,
|
||||
"DisableFirewall": true,
|
||||
"BlockLanAccess": true,
|
||||
"DisableNotifications": true,
|
||||
"LazyConnectionEnabled": true,
|
||||
"BlockInbound": true,
|
||||
"NatExternalIPs": true,
|
||||
"CustomDNSAddress": true,
|
||||
"ExtraIFaceBlacklist": true,
|
||||
"DnsLabels": true,
|
||||
"DnsRouteInterval": true,
|
||||
"Mtu": true,
|
||||
"ManagementUrl": true,
|
||||
"AdminURL": true,
|
||||
"RosenpassEnabled": true,
|
||||
"RosenpassPermissive": true,
|
||||
"ServerSSHAllowed": true,
|
||||
"InterfaceName": true,
|
||||
"WireguardPort": true,
|
||||
"OptionalPreSharedKey": true,
|
||||
"DisableAutoConnect": true,
|
||||
"NetworkMonitor": true,
|
||||
"DisableClientRoutes": true,
|
||||
"DisableServerRoutes": true,
|
||||
"DisableDns": true,
|
||||
"DisableFirewall": true,
|
||||
"BlockLanAccess": true,
|
||||
"DisableNotifications": true,
|
||||
"LazyConnectionEnabled": true,
|
||||
"BlockInbound": true,
|
||||
"NatExternalIPs": true,
|
||||
"CustomDNSAddress": true,
|
||||
"ExtraIFaceBlacklist": true,
|
||||
"DnsLabels": true,
|
||||
"DnsRouteInterval": true,
|
||||
"Mtu": true,
|
||||
"EnableSSHRoot": true,
|
||||
"EnableSSHSFTP": true,
|
||||
"EnableSSHLocalPortForward": true,
|
||||
"EnableSSHRemotePortForward": true,
|
||||
"DisableSSHAuth": true,
|
||||
}
|
||||
|
||||
val := reflect.ValueOf(req).Elem()
|
||||
@@ -221,29 +226,34 @@ func TestCLIFlags_MappedToSetConfig(t *testing.T) {
|
||||
// Map of CLI flag names to their corresponding SetConfigRequest field names.
|
||||
// This map must be updated when adding new config-related CLI flags.
|
||||
flagToField := map[string]string{
|
||||
"management-url": "ManagementUrl",
|
||||
"admin-url": "AdminURL",
|
||||
"enable-rosenpass": "RosenpassEnabled",
|
||||
"rosenpass-permissive": "RosenpassPermissive",
|
||||
"allow-server-ssh": "ServerSSHAllowed",
|
||||
"interface-name": "InterfaceName",
|
||||
"wireguard-port": "WireguardPort",
|
||||
"preshared-key": "OptionalPreSharedKey",
|
||||
"disable-auto-connect": "DisableAutoConnect",
|
||||
"network-monitor": "NetworkMonitor",
|
||||
"disable-client-routes": "DisableClientRoutes",
|
||||
"disable-server-routes": "DisableServerRoutes",
|
||||
"disable-dns": "DisableDns",
|
||||
"disable-firewall": "DisableFirewall",
|
||||
"block-lan-access": "BlockLanAccess",
|
||||
"block-inbound": "BlockInbound",
|
||||
"enable-lazy-connection": "LazyConnectionEnabled",
|
||||
"external-ip-map": "NatExternalIPs",
|
||||
"dns-resolver-address": "CustomDNSAddress",
|
||||
"extra-iface-blacklist": "ExtraIFaceBlacklist",
|
||||
"extra-dns-labels": "DnsLabels",
|
||||
"dns-router-interval": "DnsRouteInterval",
|
||||
"mtu": "Mtu",
|
||||
"management-url": "ManagementUrl",
|
||||
"admin-url": "AdminURL",
|
||||
"enable-rosenpass": "RosenpassEnabled",
|
||||
"rosenpass-permissive": "RosenpassPermissive",
|
||||
"allow-server-ssh": "ServerSSHAllowed",
|
||||
"interface-name": "InterfaceName",
|
||||
"wireguard-port": "WireguardPort",
|
||||
"preshared-key": "OptionalPreSharedKey",
|
||||
"disable-auto-connect": "DisableAutoConnect",
|
||||
"network-monitor": "NetworkMonitor",
|
||||
"disable-client-routes": "DisableClientRoutes",
|
||||
"disable-server-routes": "DisableServerRoutes",
|
||||
"disable-dns": "DisableDns",
|
||||
"disable-firewall": "DisableFirewall",
|
||||
"block-lan-access": "BlockLanAccess",
|
||||
"block-inbound": "BlockInbound",
|
||||
"enable-lazy-connection": "LazyConnectionEnabled",
|
||||
"external-ip-map": "NatExternalIPs",
|
||||
"dns-resolver-address": "CustomDNSAddress",
|
||||
"extra-iface-blacklist": "ExtraIFaceBlacklist",
|
||||
"extra-dns-labels": "DnsLabels",
|
||||
"dns-router-interval": "DnsRouteInterval",
|
||||
"mtu": "Mtu",
|
||||
"enable-ssh-root": "EnableSSHRoot",
|
||||
"enable-ssh-sftp": "EnableSSHSFTP",
|
||||
"enable-ssh-local-port-forwarding": "EnableSSHLocalPortForward",
|
||||
"enable-ssh-remote-port-forwarding": "EnableSSHRemotePortForward",
|
||||
"disable-ssh-auth": "DisableSSHAuth",
|
||||
}
|
||||
|
||||
// SetConfigRequest fields that don't have CLI flags (settable only via UI or other means).
|
||||
|
||||
@@ -6,9 +6,11 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal/dns"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
"github.com/netbirdio/netbird/client/ssh/config"
|
||||
)
|
||||
|
||||
func registerStates(mgr *statemanager.Manager) {
|
||||
mgr.RegisterState(&dns.ShutdownState{})
|
||||
mgr.RegisterState(&systemops.ShutdownState{})
|
||||
mgr.RegisterState(&config.ShutdownState{})
|
||||
}
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal/dns"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
"github.com/netbirdio/netbird/client/ssh/config"
|
||||
)
|
||||
|
||||
func registerStates(mgr *statemanager.Manager) {
|
||||
@@ -15,4 +16,5 @@ func registerStates(mgr *statemanager.Manager) {
|
||||
mgr.RegisterState(&systemops.ShutdownState{})
|
||||
mgr.RegisterState(&nftables.ShutdownState{})
|
||||
mgr.RegisterState(&iptables.ShutdownState{})
|
||||
mgr.RegisterState(&config.ShutdownState{})
|
||||
}
|
||||
|
||||
@@ -1,118 +0,0 @@
|
||||
//go:build !js
|
||||
|
||||
package ssh
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"golang.org/x/crypto/ssh"
|
||||
"golang.org/x/term"
|
||||
)
|
||||
|
||||
// Client wraps crypto/ssh Client to simplify usage
|
||||
type Client struct {
|
||||
client *ssh.Client
|
||||
}
|
||||
|
||||
// Close closes the wrapped SSH Client
|
||||
func (c *Client) Close() error {
|
||||
return c.client.Close()
|
||||
}
|
||||
|
||||
// OpenTerminal starts an interactive terminal session with the remote SSH server
|
||||
func (c *Client) OpenTerminal() error {
|
||||
session, err := c.client.NewSession()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to open new session: %v", err)
|
||||
}
|
||||
defer func() {
|
||||
err := session.Close()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}()
|
||||
|
||||
fd := int(os.Stdout.Fd())
|
||||
state, err := term.MakeRaw(fd)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to run raw terminal: %s", err)
|
||||
}
|
||||
defer func() {
|
||||
err := term.Restore(fd, state)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}()
|
||||
|
||||
w, h, err := term.GetSize(fd)
|
||||
if err != nil {
|
||||
return fmt.Errorf("terminal get size: %s", err)
|
||||
}
|
||||
|
||||
modes := ssh.TerminalModes{
|
||||
ssh.ECHO: 1,
|
||||
ssh.TTY_OP_ISPEED: 14400,
|
||||
ssh.TTY_OP_OSPEED: 14400,
|
||||
}
|
||||
|
||||
terminal := os.Getenv("TERM")
|
||||
if terminal == "" {
|
||||
terminal = "xterm-256color"
|
||||
}
|
||||
if err := session.RequestPty(terminal, h, w, modes); err != nil {
|
||||
return fmt.Errorf("failed requesting pty session with xterm: %s", err)
|
||||
}
|
||||
|
||||
session.Stdout = os.Stdout
|
||||
session.Stderr = os.Stderr
|
||||
session.Stdin = os.Stdin
|
||||
|
||||
if err := session.Shell(); err != nil {
|
||||
return fmt.Errorf("failed to start login shell on the remote host: %s", err)
|
||||
}
|
||||
|
||||
if err := session.Wait(); err != nil {
|
||||
if e, ok := err.(*ssh.ExitError); ok {
|
||||
if e.ExitStatus() == 130 {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
return fmt.Errorf("failed running SSH session: %s", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// DialWithKey connects to the remote SSH server with a provided private key file (PEM).
|
||||
func DialWithKey(addr, user string, privateKey []byte) (*Client, error) {
|
||||
|
||||
signer, err := ssh.ParsePrivateKey(privateKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
config := &ssh.ClientConfig{
|
||||
User: user,
|
||||
Timeout: 5 * time.Second,
|
||||
Auth: []ssh.AuthMethod{
|
||||
ssh.PublicKeys(signer),
|
||||
},
|
||||
HostKeyCallback: ssh.HostKeyCallback(func(hostname string, remote net.Addr, key ssh.PublicKey) error { return nil }),
|
||||
}
|
||||
|
||||
return Dial("tcp", addr, config)
|
||||
}
|
||||
|
||||
// Dial connects to the remote SSH server.
|
||||
func Dial(network, addr string, config *ssh.ClientConfig) (*Client, error) {
|
||||
client, err := ssh.Dial(network, addr, config)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &Client{
|
||||
client: client,
|
||||
}, nil
|
||||
}
|
||||
692
client/ssh/client/client.go
Normal file
692
client/ssh/client/client.go
Normal file
@@ -0,0 +1,692 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/crypto/ssh"
|
||||
"golang.org/x/crypto/ssh/knownhosts"
|
||||
"golang.org/x/term"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/credentials/insecure"
|
||||
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
nbssh "github.com/netbirdio/netbird/client/ssh"
|
||||
"github.com/netbirdio/netbird/client/ssh/detection"
|
||||
)
|
||||
|
||||
const (
|
||||
// DefaultDaemonAddr is the default address for the NetBird daemon
|
||||
DefaultDaemonAddr = "unix:///var/run/netbird.sock"
|
||||
// DefaultDaemonAddrWindows is the default address for the NetBird daemon on Windows
|
||||
DefaultDaemonAddrWindows = "tcp://127.0.0.1:41731"
|
||||
)
|
||||
|
||||
// Client wraps crypto/ssh Client for simplified SSH operations
|
||||
type Client struct {
|
||||
client *ssh.Client
|
||||
terminalState *term.State
|
||||
terminalFd int
|
||||
|
||||
windowsStdoutMode uint32 // nolint:unused
|
||||
windowsStdinMode uint32 // nolint:unused
|
||||
}
|
||||
|
||||
func (c *Client) Close() error {
|
||||
return c.client.Close()
|
||||
}
|
||||
|
||||
func (c *Client) OpenTerminal(ctx context.Context) error {
|
||||
session, err := c.client.NewSession()
|
||||
if err != nil {
|
||||
return fmt.Errorf("new session: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
if err := session.Close(); err != nil {
|
||||
log.Debugf("session close error: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
if err := c.setupTerminalMode(ctx, session); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
c.setupSessionIO(session)
|
||||
|
||||
if err := session.Shell(); err != nil {
|
||||
return fmt.Errorf("start shell: %w", err)
|
||||
}
|
||||
|
||||
return c.waitForSession(ctx, session)
|
||||
}
|
||||
|
||||
// setupSessionIO connects session streams to local terminal
|
||||
func (c *Client) setupSessionIO(session *ssh.Session) {
|
||||
session.Stdout = os.Stdout
|
||||
session.Stderr = os.Stderr
|
||||
session.Stdin = os.Stdin
|
||||
}
|
||||
|
||||
// waitForSession waits for the session to complete with context cancellation
|
||||
func (c *Client) waitForSession(ctx context.Context, session *ssh.Session) error {
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
done <- session.Wait()
|
||||
}()
|
||||
|
||||
defer c.restoreTerminal()
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case err := <-done:
|
||||
return c.handleSessionError(err)
|
||||
}
|
||||
}
|
||||
|
||||
// handleSessionError processes session termination errors
|
||||
func (c *Client) handleSessionError(err error) error {
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
var e *ssh.ExitError
|
||||
var em *ssh.ExitMissingError
|
||||
if !errors.As(err, &e) && !errors.As(err, &em) {
|
||||
return fmt.Errorf("session wait: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// restoreTerminal restores the terminal to its original state
|
||||
func (c *Client) restoreTerminal() {
|
||||
if c.terminalState != nil {
|
||||
_ = term.Restore(c.terminalFd, c.terminalState)
|
||||
c.terminalState = nil
|
||||
c.terminalFd = 0
|
||||
}
|
||||
|
||||
if err := c.restoreWindowsConsoleState(); err != nil {
|
||||
log.Debugf("restore Windows console state: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// ExecuteCommand executes a command on the remote host and returns the output
|
||||
func (c *Client) ExecuteCommand(ctx context.Context, command string) ([]byte, error) {
|
||||
session, cleanup, err := c.createSession(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer cleanup()
|
||||
|
||||
output, err := session.CombinedOutput(command)
|
||||
if err != nil {
|
||||
var e *ssh.ExitError
|
||||
var em *ssh.ExitMissingError
|
||||
if !errors.As(err, &e) && !errors.As(err, &em) {
|
||||
return output, fmt.Errorf("execute command: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return output, nil
|
||||
}
|
||||
|
||||
// ExecuteCommandWithIO executes a command with interactive I/O connected to local terminal
|
||||
func (c *Client) ExecuteCommandWithIO(ctx context.Context, command string) error {
|
||||
session, cleanup, err := c.createSession(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create session: %w", err)
|
||||
}
|
||||
defer cleanup()
|
||||
|
||||
c.setupSessionIO(session)
|
||||
|
||||
if err := session.Start(command); err != nil {
|
||||
return fmt.Errorf("start command: %w", err)
|
||||
}
|
||||
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
done <- session.Wait()
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
_ = session.Signal(ssh.SIGTERM)
|
||||
select {
|
||||
case <-done:
|
||||
return ctx.Err()
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
return ctx.Err()
|
||||
}
|
||||
case err := <-done:
|
||||
return c.handleCommandError(err)
|
||||
}
|
||||
}
|
||||
|
||||
// ExecuteCommandWithPTY executes a command with a pseudo-terminal for interactive sessions
|
||||
func (c *Client) ExecuteCommandWithPTY(ctx context.Context, command string) error {
|
||||
session, cleanup, err := c.createSession(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create session: %w", err)
|
||||
}
|
||||
defer cleanup()
|
||||
|
||||
if err := c.setupTerminalMode(ctx, session); err != nil {
|
||||
return fmt.Errorf("setup terminal mode: %w", err)
|
||||
}
|
||||
|
||||
c.setupSessionIO(session)
|
||||
|
||||
if err := session.Start(command); err != nil {
|
||||
return fmt.Errorf("start command: %w", err)
|
||||
}
|
||||
|
||||
defer c.restoreTerminal()
|
||||
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
done <- session.Wait()
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
_ = session.Signal(ssh.SIGTERM)
|
||||
select {
|
||||
case <-done:
|
||||
return ctx.Err()
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
return ctx.Err()
|
||||
}
|
||||
case err := <-done:
|
||||
return c.handleCommandError(err)
|
||||
}
|
||||
}
|
||||
|
||||
// handleCommandError processes command execution errors, treating exit codes as normal
|
||||
func (c *Client) handleCommandError(err error) error {
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
var e *ssh.ExitError
|
||||
var em *ssh.ExitMissingError
|
||||
if !errors.As(err, &e) && !errors.As(err, &em) {
|
||||
return fmt.Errorf("execute command: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// setupContextCancellation sets up context cancellation for a session
|
||||
func (c *Client) setupContextCancellation(ctx context.Context, session *ssh.Session) func() {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
_ = session.Signal(ssh.SIGTERM)
|
||||
_ = session.Close()
|
||||
case <-done:
|
||||
}
|
||||
}()
|
||||
return func() { close(done) }
|
||||
}
|
||||
|
||||
// createSession creates a new SSH session with context cancellation setup
|
||||
func (c *Client) createSession(ctx context.Context) (*ssh.Session, func(), error) {
|
||||
session, err := c.client.NewSession()
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("new session: %w", err)
|
||||
}
|
||||
|
||||
cancel := c.setupContextCancellation(ctx, session)
|
||||
cleanup := func() {
|
||||
cancel()
|
||||
_ = session.Close()
|
||||
}
|
||||
|
||||
return session, cleanup, nil
|
||||
}
|
||||
|
||||
// getDefaultDaemonAddr returns the daemon address from environment or default for the OS
|
||||
func getDefaultDaemonAddr() string {
|
||||
if addr := os.Getenv("NB_DAEMON_ADDR"); addr != "" {
|
||||
return addr
|
||||
}
|
||||
if runtime.GOOS == "windows" {
|
||||
return DefaultDaemonAddrWindows
|
||||
}
|
||||
return DefaultDaemonAddr
|
||||
}
|
||||
|
||||
// DialOptions contains options for SSH connections
|
||||
type DialOptions struct {
|
||||
KnownHostsFile string
|
||||
IdentityFile string
|
||||
DaemonAddr string
|
||||
SkipCachedToken bool
|
||||
InsecureSkipVerify bool
|
||||
}
|
||||
|
||||
// Dial connects to the given ssh server with specified options
|
||||
func Dial(ctx context.Context, addr, user string, opts DialOptions) (*Client, error) {
|
||||
hostKeyCallback, err := createHostKeyCallback(opts)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create host key callback: %w", err)
|
||||
}
|
||||
|
||||
config := &ssh.ClientConfig{
|
||||
User: user,
|
||||
Timeout: 30 * time.Second,
|
||||
HostKeyCallback: hostKeyCallback,
|
||||
}
|
||||
|
||||
if opts.IdentityFile != "" {
|
||||
authMethod, err := createSSHKeyAuth(opts.IdentityFile)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create SSH key auth: %w", err)
|
||||
}
|
||||
config.Auth = append(config.Auth, authMethod)
|
||||
}
|
||||
|
||||
daemonAddr := opts.DaemonAddr
|
||||
if daemonAddr == "" {
|
||||
daemonAddr = getDefaultDaemonAddr()
|
||||
}
|
||||
|
||||
return dialWithJWT(ctx, "tcp", addr, config, daemonAddr, opts.SkipCachedToken)
|
||||
}
|
||||
|
||||
// dialSSH establishes an SSH connection without JWT authentication
|
||||
func dialSSH(ctx context.Context, network, addr string, config *ssh.ClientConfig) (*Client, error) {
|
||||
dialer := &net.Dialer{}
|
||||
conn, err := dialer.DialContext(ctx, network, addr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("dial %s: %w", addr, err)
|
||||
}
|
||||
|
||||
clientConn, chans, reqs, err := ssh.NewClientConn(conn, addr, config)
|
||||
if err != nil {
|
||||
if closeErr := conn.Close(); closeErr != nil {
|
||||
log.Debugf("connection close after handshake failure: %v", closeErr)
|
||||
}
|
||||
return nil, fmt.Errorf("ssh handshake: %w", err)
|
||||
}
|
||||
|
||||
client := ssh.NewClient(clientConn, chans, reqs)
|
||||
return &Client{
|
||||
client: client,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// dialWithJWT establishes an SSH connection with optional JWT authentication based on server detection
|
||||
func dialWithJWT(ctx context.Context, network, addr string, config *ssh.ClientConfig, daemonAddr string, skipCache bool) (*Client, error) {
|
||||
host, portStr, err := net.SplitHostPort(addr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parse address %s: %w", addr, err)
|
||||
}
|
||||
port, err := strconv.Atoi(portStr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parse port %s: %w", portStr, err)
|
||||
}
|
||||
|
||||
dialer := &net.Dialer{Timeout: detection.Timeout}
|
||||
serverType, err := detection.DetectSSHServerType(ctx, dialer, host, port)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("SSH server detection failed: %w", err)
|
||||
}
|
||||
|
||||
if !serverType.RequiresJWT() {
|
||||
return dialSSH(ctx, network, addr, config)
|
||||
}
|
||||
|
||||
jwtCtx, cancel := context.WithTimeout(ctx, config.Timeout)
|
||||
defer cancel()
|
||||
|
||||
jwtToken, err := requestJWTToken(jwtCtx, daemonAddr, skipCache)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("request JWT token: %w", err)
|
||||
}
|
||||
|
||||
configWithJWT := nbssh.AddJWTAuth(config, jwtToken)
|
||||
return dialSSH(ctx, network, addr, configWithJWT)
|
||||
}
|
||||
|
||||
// requestJWTToken requests a JWT token from the NetBird daemon
|
||||
func requestJWTToken(ctx context.Context, daemonAddr string, skipCache bool) (string, error) {
|
||||
conn, err := connectToDaemon(daemonAddr)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("connect to daemon: %w", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
client := proto.NewDaemonServiceClient(conn)
|
||||
return nbssh.RequestJWTToken(ctx, client, os.Stdout, os.Stderr, !skipCache)
|
||||
}
|
||||
|
||||
// verifyHostKeyViaDaemon verifies SSH host key by querying the NetBird daemon
|
||||
func verifyHostKeyViaDaemon(hostname string, remote net.Addr, key ssh.PublicKey, daemonAddr string) error {
|
||||
conn, err := connectToDaemon(daemonAddr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() {
|
||||
if err := conn.Close(); err != nil {
|
||||
log.Debugf("daemon connection close error: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
client := proto.NewDaemonServiceClient(conn)
|
||||
verifier := nbssh.NewDaemonHostKeyVerifier(client)
|
||||
callback := nbssh.CreateHostKeyCallback(verifier)
|
||||
return callback(hostname, remote, key)
|
||||
}
|
||||
|
||||
func connectToDaemon(daemonAddr string) (*grpc.ClientConn, error) {
|
||||
addr := strings.TrimPrefix(daemonAddr, "tcp://")
|
||||
|
||||
conn, err := grpc.NewClient(
|
||||
addr,
|
||||
grpc.WithTransportCredentials(insecure.NewCredentials()),
|
||||
)
|
||||
if err != nil {
|
||||
log.Debugf("failed to create gRPC client for NetBird daemon at %s: %v", daemonAddr, err)
|
||||
return nil, fmt.Errorf("failed to connect to NetBird daemon: %w", err)
|
||||
}
|
||||
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
// getKnownHostsFiles returns paths to known_hosts files in order of preference
|
||||
func getKnownHostsFiles() []string {
|
||||
var files []string
|
||||
|
||||
// User's known_hosts file (highest priority)
|
||||
if homeDir, err := os.UserHomeDir(); err == nil {
|
||||
userKnownHosts := filepath.Join(homeDir, ".ssh", "known_hosts")
|
||||
files = append(files, userKnownHosts)
|
||||
}
|
||||
|
||||
// NetBird managed known_hosts files
|
||||
if runtime.GOOS == "windows" {
|
||||
programData := os.Getenv("PROGRAMDATA")
|
||||
if programData == "" {
|
||||
programData = `C:\ProgramData`
|
||||
}
|
||||
netbirdKnownHosts := filepath.Join(programData, "ssh", "ssh_known_hosts.d", "99-netbird")
|
||||
files = append(files, netbirdKnownHosts)
|
||||
} else {
|
||||
files = append(files, "/etc/ssh/ssh_known_hosts.d/99-netbird")
|
||||
files = append(files, "/etc/ssh/ssh_known_hosts")
|
||||
}
|
||||
|
||||
return files
|
||||
}
|
||||
|
||||
// createHostKeyCallback creates a host key verification callback
|
||||
func createHostKeyCallback(opts DialOptions) (ssh.HostKeyCallback, error) {
|
||||
if opts.InsecureSkipVerify {
|
||||
return ssh.InsecureIgnoreHostKey(), nil // #nosec G106 - User explicitly requested insecure mode
|
||||
}
|
||||
|
||||
return func(hostname string, remote net.Addr, key ssh.PublicKey) error {
|
||||
if err := tryDaemonVerification(hostname, remote, key, opts.DaemonAddr); err == nil {
|
||||
return nil
|
||||
}
|
||||
return tryKnownHostsVerification(hostname, remote, key, opts.KnownHostsFile)
|
||||
}, nil
|
||||
}
|
||||
|
||||
func tryDaemonVerification(hostname string, remote net.Addr, key ssh.PublicKey, daemonAddr string) error {
|
||||
if daemonAddr == "" {
|
||||
return fmt.Errorf("no daemon address")
|
||||
}
|
||||
return verifyHostKeyViaDaemon(hostname, remote, key, daemonAddr)
|
||||
}
|
||||
|
||||
func tryKnownHostsVerification(hostname string, remote net.Addr, key ssh.PublicKey, knownHostsFile string) error {
|
||||
knownHostsFiles := getKnownHostsFilesList(knownHostsFile)
|
||||
hostKeyCallbacks := buildHostKeyCallbacks(knownHostsFiles)
|
||||
|
||||
for _, callback := range hostKeyCallbacks {
|
||||
if err := callback(hostname, remote, key); err == nil {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
return fmt.Errorf("host key verification failed: key not found in NetBird daemon or any known_hosts file")
|
||||
}
|
||||
|
||||
func getKnownHostsFilesList(knownHostsFile string) []string {
|
||||
if knownHostsFile != "" {
|
||||
return []string{knownHostsFile}
|
||||
}
|
||||
return getKnownHostsFiles()
|
||||
}
|
||||
|
||||
func buildHostKeyCallbacks(knownHostsFiles []string) []ssh.HostKeyCallback {
|
||||
var hostKeyCallbacks []ssh.HostKeyCallback
|
||||
for _, file := range knownHostsFiles {
|
||||
if callback, err := knownhosts.New(file); err == nil {
|
||||
hostKeyCallbacks = append(hostKeyCallbacks, callback)
|
||||
}
|
||||
}
|
||||
return hostKeyCallbacks
|
||||
}
|
||||
|
||||
// createSSHKeyAuth creates SSH key authentication from a private key file
|
||||
func createSSHKeyAuth(keyFile string) (ssh.AuthMethod, error) {
|
||||
keyData, err := os.ReadFile(keyFile)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read SSH key file %s: %w", keyFile, err)
|
||||
}
|
||||
|
||||
signer, err := ssh.ParsePrivateKey(keyData)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parse SSH private key: %w", err)
|
||||
}
|
||||
|
||||
return ssh.PublicKeys(signer), nil
|
||||
}
|
||||
|
||||
// LocalPortForward sets up local port forwarding, binding to localAddr and forwarding to remoteAddr
|
||||
func (c *Client) LocalPortForward(ctx context.Context, localAddr, remoteAddr string) error {
|
||||
localListener, err := net.Listen("tcp", localAddr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("listen on %s: %w", localAddr, err)
|
||||
}
|
||||
|
||||
go func() {
|
||||
defer func() {
|
||||
if err := localListener.Close(); err != nil {
|
||||
log.Debugf("local listener close error: %v", err)
|
||||
}
|
||||
}()
|
||||
for {
|
||||
localConn, err := localListener.Accept()
|
||||
if err != nil {
|
||||
if ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
go c.handleLocalForward(localConn, remoteAddr)
|
||||
}
|
||||
}()
|
||||
|
||||
<-ctx.Done()
|
||||
return ctx.Err()
|
||||
}
|
||||
|
||||
// handleLocalForward handles a single local port forwarding connection
|
||||
func (c *Client) handleLocalForward(localConn net.Conn, remoteAddr string) {
|
||||
defer func() {
|
||||
if err := localConn.Close(); err != nil {
|
||||
log.Debugf("local connection close error: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
channel, err := c.client.Dial("tcp", remoteAddr)
|
||||
if err != nil {
|
||||
if strings.Contains(err.Error(), "administratively prohibited") {
|
||||
_, _ = fmt.Fprintf(os.Stderr, "channel open failed: administratively prohibited: port forwarding is disabled\n")
|
||||
} else {
|
||||
log.Debugf("local port forwarding to %s failed: %v", remoteAddr, err)
|
||||
}
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
if err := channel.Close(); err != nil {
|
||||
log.Debugf("remote channel close error: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
go func() {
|
||||
if _, err := io.Copy(channel, localConn); err != nil {
|
||||
log.Debugf("local forward copy error (local->remote): %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
if _, err := io.Copy(localConn, channel); err != nil {
|
||||
log.Debugf("local forward copy error (remote->local): %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// RemotePortForward sets up remote port forwarding, binding on remote and forwarding to localAddr
|
||||
func (c *Client) RemotePortForward(ctx context.Context, remoteAddr, localAddr string) error {
|
||||
host, port, err := c.parseRemoteAddress(remoteAddr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parse remote address: %w", err)
|
||||
}
|
||||
|
||||
req := c.buildTCPIPForwardRequest(host, port)
|
||||
if err := c.sendTCPIPForwardRequest(req); err != nil {
|
||||
return fmt.Errorf("setup remote forward: %w", err)
|
||||
}
|
||||
|
||||
go c.handleRemoteForwardChannels(ctx, localAddr)
|
||||
|
||||
<-ctx.Done()
|
||||
|
||||
if err := c.cancelTCPIPForwardRequest(req); err != nil {
|
||||
return fmt.Errorf("cancel tcpip-forward: %w", err)
|
||||
}
|
||||
return ctx.Err()
|
||||
}
|
||||
|
||||
// parseRemoteAddress parses host and port from remote address string
|
||||
func (c *Client) parseRemoteAddress(remoteAddr string) (string, uint32, error) {
|
||||
host, portStr, err := net.SplitHostPort(remoteAddr)
|
||||
if err != nil {
|
||||
return "", 0, fmt.Errorf("parse remote address %s: %w", remoteAddr, err)
|
||||
}
|
||||
|
||||
port, err := strconv.Atoi(portStr)
|
||||
if err != nil {
|
||||
return "", 0, fmt.Errorf("parse remote port %s: %w", portStr, err)
|
||||
}
|
||||
|
||||
return host, uint32(port), nil
|
||||
}
|
||||
|
||||
// buildTCPIPForwardRequest creates a tcpip-forward request message
|
||||
func (c *Client) buildTCPIPForwardRequest(host string, port uint32) tcpipForwardMsg {
|
||||
return tcpipForwardMsg{
|
||||
Host: host,
|
||||
Port: port,
|
||||
}
|
||||
}
|
||||
|
||||
// sendTCPIPForwardRequest sends the tcpip-forward request to establish remote port forwarding
|
||||
func (c *Client) sendTCPIPForwardRequest(req tcpipForwardMsg) error {
|
||||
ok, _, err := c.client.SendRequest("tcpip-forward", true, ssh.Marshal(&req))
|
||||
if err != nil {
|
||||
return fmt.Errorf("send tcpip-forward request: %w", err)
|
||||
}
|
||||
if !ok {
|
||||
return fmt.Errorf("remote port forwarding denied by server (check if --allow-ssh-remote-port-forwarding is enabled)")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// cancelTCPIPForwardRequest cancels the tcpip-forward request
|
||||
func (c *Client) cancelTCPIPForwardRequest(req tcpipForwardMsg) error {
|
||||
_, _, err := c.client.SendRequest("cancel-tcpip-forward", true, ssh.Marshal(&req))
|
||||
if err != nil {
|
||||
return fmt.Errorf("send cancel-tcpip-forward request: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// handleRemoteForwardChannels handles incoming forwarded-tcpip channels
|
||||
func (c *Client) handleRemoteForwardChannels(ctx context.Context, localAddr string) {
|
||||
// Get the channel once - subsequent calls return nil!
|
||||
channelRequests := c.client.HandleChannelOpen("forwarded-tcpip")
|
||||
if channelRequests == nil {
|
||||
log.Debugf("forwarded-tcpip channel type already being handled")
|
||||
return
|
||||
}
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case newChan := <-channelRequests:
|
||||
if newChan != nil {
|
||||
go c.handleRemoteForwardChannel(newChan, localAddr)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// handleRemoteForwardChannel handles a single forwarded-tcpip channel
|
||||
func (c *Client) handleRemoteForwardChannel(newChan ssh.NewChannel, localAddr string) {
|
||||
channel, reqs, err := newChan.Accept()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
if err := channel.Close(); err != nil {
|
||||
log.Debugf("remote channel close error: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
go ssh.DiscardRequests(reqs)
|
||||
|
||||
localConn, err := net.Dial("tcp", localAddr)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
if err := localConn.Close(); err != nil {
|
||||
log.Debugf("local connection close error: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
go func() {
|
||||
if _, err := io.Copy(localConn, channel); err != nil {
|
||||
log.Debugf("remote forward copy error (remote->local): %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
if _, err := io.Copy(channel, localConn); err != nil {
|
||||
log.Debugf("remote forward copy error (local->remote): %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// tcpipForwardMsg represents the structure for tcpip-forward requests
|
||||
type tcpipForwardMsg struct {
|
||||
Host string
|
||||
Port uint32
|
||||
}
|
||||
512
client/ssh/client/client_test.go
Normal file
512
client/ssh/client/client_test.go
Normal file
@@ -0,0 +1,512 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"os"
|
||||
"os/user"
|
||||
"runtime"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
cryptossh "golang.org/x/crypto/ssh"
|
||||
|
||||
"github.com/netbirdio/netbird/client/ssh"
|
||||
sshserver "github.com/netbirdio/netbird/client/ssh/server"
|
||||
"github.com/netbirdio/netbird/client/ssh/testutil"
|
||||
)
|
||||
|
||||
// TestMain handles package-level setup and cleanup
|
||||
func TestMain(m *testing.M) {
|
||||
// Guard against infinite recursion when test binary is called as "netbird ssh exec"
|
||||
// This happens when running tests as non-privileged user with fallback
|
||||
if len(os.Args) > 2 && os.Args[1] == "ssh" && os.Args[2] == "exec" {
|
||||
// Just exit with error to break the recursion
|
||||
fmt.Fprintf(os.Stderr, "Test binary called as 'ssh exec' - preventing infinite recursion\n")
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
// Run tests
|
||||
code := m.Run()
|
||||
|
||||
// Cleanup any created test users
|
||||
testutil.CleanupTestUsers()
|
||||
|
||||
os.Exit(code)
|
||||
}
|
||||
|
||||
func TestSSHClient_DialWithKey(t *testing.T) {
|
||||
// Generate host key for server
|
||||
hostKey, err := ssh.GeneratePrivateKey(ssh.ED25519)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create and start server
|
||||
serverConfig := &sshserver.Config{
|
||||
HostKeyPEM: hostKey,
|
||||
JWT: nil,
|
||||
}
|
||||
server := sshserver.New(serverConfig)
|
||||
server.SetAllowRootLogin(true) // Allow root/admin login for tests
|
||||
|
||||
serverAddr := sshserver.StartTestServer(t, server)
|
||||
defer func() {
|
||||
err := server.Stop()
|
||||
require.NoError(t, err)
|
||||
}()
|
||||
|
||||
// Test Dial
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
currentUser := testutil.GetTestUsername(t)
|
||||
client, err := Dial(ctx, serverAddr, currentUser, DialOptions{
|
||||
InsecureSkipVerify: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
err := client.Close()
|
||||
assert.NoError(t, err)
|
||||
}()
|
||||
|
||||
// Verify client is connected
|
||||
assert.NotNil(t, client.client)
|
||||
}
|
||||
|
||||
func TestSSHClient_CommandExecution(t *testing.T) {
|
||||
if runtime.GOOS == "windows" && testutil.IsCI() {
|
||||
t.Skip("Skipping Windows command execution tests in CI due to S4U authentication issues")
|
||||
}
|
||||
|
||||
server, _, client := setupTestSSHServerAndClient(t)
|
||||
defer func() {
|
||||
err := server.Stop()
|
||||
require.NoError(t, err)
|
||||
}()
|
||||
defer func() {
|
||||
err := client.Close()
|
||||
assert.NoError(t, err)
|
||||
}()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
defer cancel()
|
||||
|
||||
t.Run("ExecuteCommand captures output", func(t *testing.T) {
|
||||
output, err := client.ExecuteCommand(ctx, "echo hello")
|
||||
assert.NoError(t, err)
|
||||
assert.Contains(t, string(output), "hello")
|
||||
})
|
||||
|
||||
t.Run("ExecuteCommandWithIO streams output", func(t *testing.T) {
|
||||
err := client.ExecuteCommandWithIO(ctx, "echo world")
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("commands with flags work", func(t *testing.T) {
|
||||
output, err := client.ExecuteCommand(ctx, "echo -n test_flag")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "test_flag", strings.TrimSpace(string(output)))
|
||||
})
|
||||
|
||||
t.Run("non-zero exit codes don't return errors", func(t *testing.T) {
|
||||
var testCmd string
|
||||
if runtime.GOOS == "windows" {
|
||||
testCmd = "echo hello | Select-String notfound"
|
||||
} else {
|
||||
testCmd = "echo 'hello' | grep 'notfound'"
|
||||
}
|
||||
_, err := client.ExecuteCommand(ctx, testCmd)
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestSSHClient_ConnectionHandling(t *testing.T) {
|
||||
server, serverAddr, _ := setupTestSSHServerAndClient(t)
|
||||
defer func() {
|
||||
err := server.Stop()
|
||||
require.NoError(t, err)
|
||||
}()
|
||||
|
||||
// Generate client key for multiple connections
|
||||
|
||||
const numClients = 3
|
||||
clients := make([]*Client, numClients)
|
||||
|
||||
for i := 0; i < numClients; i++ {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
currentUser := testutil.GetTestUsername(t)
|
||||
client, err := Dial(ctx, serverAddr, fmt.Sprintf("%s-%d", currentUser, i), DialOptions{
|
||||
InsecureSkipVerify: true,
|
||||
})
|
||||
cancel()
|
||||
require.NoError(t, err, "Client %d should connect successfully", i)
|
||||
clients[i] = client
|
||||
}
|
||||
|
||||
for i, client := range clients {
|
||||
err := client.Close()
|
||||
assert.NoError(t, err, "Client %d should close without error", i)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSSHClient_ContextCancellation(t *testing.T) {
|
||||
server, serverAddr, _ := setupTestSSHServerAndClient(t)
|
||||
defer func() {
|
||||
err := server.Stop()
|
||||
require.NoError(t, err)
|
||||
}()
|
||||
|
||||
t.Run("connection with short timeout", func(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
currentUser := testutil.GetTestUsername(t)
|
||||
_, err := Dial(ctx, serverAddr, currentUser, DialOptions{
|
||||
InsecureSkipVerify: true,
|
||||
})
|
||||
if err != nil {
|
||||
// Check for actual timeout-related errors rather than string matching
|
||||
assert.True(t,
|
||||
errors.Is(err, context.DeadlineExceeded) ||
|
||||
errors.Is(err, context.Canceled) ||
|
||||
strings.Contains(err.Error(), "timeout"),
|
||||
"Expected timeout-related error, got: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("command execution cancellation", func(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
currentUser := testutil.GetTestUsername(t)
|
||||
client, err := Dial(ctx, serverAddr, currentUser, DialOptions{
|
||||
InsecureSkipVerify: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
if err := client.Close(); err != nil {
|
||||
t.Logf("client close error: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
cmdCtx, cmdCancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||
defer cmdCancel()
|
||||
|
||||
err = client.ExecuteCommandWithPTY(cmdCtx, "sleep 10")
|
||||
if err != nil {
|
||||
var exitMissingErr *cryptossh.ExitMissingError
|
||||
isValidCancellation := errors.Is(err, context.DeadlineExceeded) ||
|
||||
errors.Is(err, context.Canceled) ||
|
||||
errors.As(err, &exitMissingErr)
|
||||
assert.True(t, isValidCancellation, "Should handle command cancellation properly")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestSSHClient_NoAuthMode(t *testing.T) {
|
||||
hostKey, err := ssh.GeneratePrivateKey(ssh.ED25519)
|
||||
require.NoError(t, err)
|
||||
|
||||
serverConfig := &sshserver.Config{
|
||||
HostKeyPEM: hostKey,
|
||||
JWT: nil,
|
||||
}
|
||||
server := sshserver.New(serverConfig)
|
||||
server.SetAllowRootLogin(true) // Allow root/admin login for tests
|
||||
|
||||
serverAddr := sshserver.StartTestServer(t, server)
|
||||
defer func() {
|
||||
err := server.Stop()
|
||||
require.NoError(t, err)
|
||||
}()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
currentUser := testutil.GetTestUsername(t)
|
||||
|
||||
t.Run("any key succeeds in no-auth mode", func(t *testing.T) {
|
||||
client, err := Dial(ctx, serverAddr, currentUser, DialOptions{
|
||||
InsecureSkipVerify: true,
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
if client != nil {
|
||||
require.NoError(t, client.Close(), "Client should close without error")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestSSHClient_TerminalState(t *testing.T) {
|
||||
server, _, client := setupTestSSHServerAndClient(t)
|
||||
defer func() {
|
||||
err := server.Stop()
|
||||
require.NoError(t, err)
|
||||
}()
|
||||
defer func() {
|
||||
err := client.Close()
|
||||
assert.NoError(t, err)
|
||||
}()
|
||||
|
||||
assert.Nil(t, client.terminalState)
|
||||
assert.Equal(t, 0, client.terminalFd)
|
||||
|
||||
client.restoreTerminal()
|
||||
assert.Nil(t, client.terminalState)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
err := client.OpenTerminal(ctx)
|
||||
// In test environment without a real terminal, this may complete quickly or timeout
|
||||
// Both behaviors are acceptable for testing terminal state management
|
||||
if err != nil {
|
||||
if runtime.GOOS == "windows" {
|
||||
assert.True(t,
|
||||
strings.Contains(err.Error(), "context deadline exceeded") ||
|
||||
strings.Contains(err.Error(), "console"),
|
||||
"Should timeout or have console error on Windows")
|
||||
} else {
|
||||
// On Unix systems in test environment, we may get various errors
|
||||
// including timeouts or terminal-related errors
|
||||
assert.True(t,
|
||||
strings.Contains(err.Error(), "context deadline exceeded") ||
|
||||
strings.Contains(err.Error(), "terminal") ||
|
||||
strings.Contains(err.Error(), "pty"),
|
||||
"Expected timeout or terminal-related error, got: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func setupTestSSHServerAndClient(t *testing.T) (*sshserver.Server, string, *Client) {
|
||||
hostKey, err := ssh.GeneratePrivateKey(ssh.ED25519)
|
||||
require.NoError(t, err)
|
||||
|
||||
serverConfig := &sshserver.Config{
|
||||
HostKeyPEM: hostKey,
|
||||
JWT: nil,
|
||||
}
|
||||
server := sshserver.New(serverConfig)
|
||||
server.SetAllowRootLogin(true) // Allow root/admin login for tests
|
||||
|
||||
serverAddr := sshserver.StartTestServer(t, server)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
currentUser := testutil.GetTestUsername(t)
|
||||
client, err := Dial(ctx, serverAddr, currentUser, DialOptions{
|
||||
InsecureSkipVerify: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
return server, serverAddr, client
|
||||
}
|
||||
|
||||
func TestSSHClient_PortForwarding(t *testing.T) {
|
||||
server, _, client := setupTestSSHServerAndClient(t)
|
||||
defer func() {
|
||||
err := server.Stop()
|
||||
require.NoError(t, err)
|
||||
}()
|
||||
defer func() {
|
||||
err := client.Close()
|
||||
assert.NoError(t, err)
|
||||
}()
|
||||
|
||||
t.Run("local forwarding times out gracefully", func(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
err := client.LocalPortForward(ctx, "127.0.0.1:0", "127.0.0.1:8080")
|
||||
assert.Error(t, err)
|
||||
assert.True(t,
|
||||
errors.Is(err, context.DeadlineExceeded) ||
|
||||
errors.Is(err, context.Canceled) ||
|
||||
strings.Contains(err.Error(), "connection"),
|
||||
"Expected context or connection error")
|
||||
})
|
||||
|
||||
t.Run("remote forwarding denied", func(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
|
||||
defer cancel()
|
||||
|
||||
err := client.RemotePortForward(ctx, "127.0.0.1:0", "127.0.0.1:8080")
|
||||
assert.Error(t, err)
|
||||
assert.True(t,
|
||||
strings.Contains(err.Error(), "denied") ||
|
||||
strings.Contains(err.Error(), "disabled"),
|
||||
"Should be denied by default")
|
||||
})
|
||||
|
||||
t.Run("invalid addresses fail", func(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
|
||||
defer cancel()
|
||||
|
||||
err := client.LocalPortForward(ctx, "invalid:address", "127.0.0.1:8080")
|
||||
assert.Error(t, err)
|
||||
|
||||
err = client.LocalPortForward(ctx, "127.0.0.1:0", "invalid:address")
|
||||
assert.Error(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestSSHClient_PortForwardingDataTransfer(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("Skipping data transfer test in short mode")
|
||||
}
|
||||
|
||||
hostKey, err := ssh.GeneratePrivateKey(ssh.ED25519)
|
||||
require.NoError(t, err)
|
||||
|
||||
serverConfig := &sshserver.Config{
|
||||
HostKeyPEM: hostKey,
|
||||
JWT: nil,
|
||||
}
|
||||
server := sshserver.New(serverConfig)
|
||||
server.SetAllowLocalPortForwarding(true)
|
||||
server.SetAllowRootLogin(true) // Allow root/admin login for tests
|
||||
|
||||
serverAddr := sshserver.StartTestServer(t, server)
|
||||
defer func() {
|
||||
err := server.Stop()
|
||||
require.NoError(t, err)
|
||||
}()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
// Port forwarding requires the actual current user, not test user
|
||||
realUser, err := getRealCurrentUser()
|
||||
require.NoError(t, err)
|
||||
|
||||
// Skip if running as system account that can't do port forwarding
|
||||
if testutil.IsSystemAccount(realUser) {
|
||||
t.Skipf("Skipping port forwarding test - running as system account: %s", realUser)
|
||||
}
|
||||
|
||||
client, err := Dial(ctx, serverAddr, realUser, DialOptions{
|
||||
InsecureSkipVerify: true, // Skip host key verification for test
|
||||
})
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
if err := client.Close(); err != nil {
|
||||
t.Logf("client close error: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
testServer, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
if err := testServer.Close(); err != nil {
|
||||
t.Logf("test server close error: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
testServerAddr := testServer.Addr().String()
|
||||
expectedResponse := "Hello, World!"
|
||||
|
||||
go func() {
|
||||
for {
|
||||
conn, err := testServer.Accept()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
go func(c net.Conn) {
|
||||
defer func() {
|
||||
if err := c.Close(); err != nil {
|
||||
t.Logf("connection close error: %v", err)
|
||||
}
|
||||
}()
|
||||
buf := make([]byte, 1024)
|
||||
if _, err := c.Read(buf); err != nil {
|
||||
t.Logf("connection read error: %v", err)
|
||||
return
|
||||
}
|
||||
if _, err := c.Write([]byte(expectedResponse)); err != nil {
|
||||
t.Logf("connection write error: %v", err)
|
||||
}
|
||||
}(conn)
|
||||
}
|
||||
}()
|
||||
|
||||
localListener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
localAddr := localListener.Addr().String()
|
||||
if err := localListener.Close(); err != nil {
|
||||
t.Logf("local listener close error: %v", err)
|
||||
}
|
||||
|
||||
ctx, cancel = context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
go func() {
|
||||
err := client.LocalPortForward(ctx, localAddr, testServerAddr)
|
||||
if err != nil && !errors.Is(err, context.Canceled) {
|
||||
if isWindowsPrivilegeError(err) {
|
||||
t.Logf("Port forward failed due to Windows privilege restrictions: %v", err)
|
||||
} else {
|
||||
t.Logf("Port forward error: %v", err)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
conn, err := net.DialTimeout("tcp", localAddr, 2*time.Second)
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
if err := conn.Close(); err != nil {
|
||||
t.Logf("connection close error: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
_, err = conn.Write([]byte("test"))
|
||||
require.NoError(t, err)
|
||||
|
||||
if err := conn.SetReadDeadline(time.Now().Add(2 * time.Second)); err != nil {
|
||||
t.Logf("set read deadline error: %v", err)
|
||||
}
|
||||
response := make([]byte, len(expectedResponse))
|
||||
n, err := io.ReadFull(conn, response)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, len(expectedResponse), n)
|
||||
assert.Equal(t, expectedResponse, string(response))
|
||||
}
|
||||
|
||||
// getRealCurrentUser returns the actual current user (not test user) for features like port forwarding
|
||||
func getRealCurrentUser() (string, error) {
|
||||
if runtime.GOOS == "windows" {
|
||||
if currentUser, err := user.Current(); err == nil {
|
||||
return currentUser.Username, nil
|
||||
}
|
||||
}
|
||||
|
||||
if username := os.Getenv("USER"); username != "" {
|
||||
return username, nil
|
||||
}
|
||||
|
||||
if currentUser, err := user.Current(); err == nil {
|
||||
return currentUser.Username, nil
|
||||
}
|
||||
|
||||
return "", fmt.Errorf("unable to determine current user")
|
||||
}
|
||||
|
||||
// isWindowsPrivilegeError checks if an error is related to Windows privilege restrictions
|
||||
func isWindowsPrivilegeError(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
errStr := strings.ToLower(err.Error())
|
||||
return strings.Contains(errStr, "ntstatus=0xc0000062") || // STATUS_PRIVILEGE_NOT_HELD
|
||||
strings.Contains(errStr, "0xc0000041") || // STATUS_PRIVILEGE_NOT_HELD (LsaRegisterLogonProcess)
|
||||
strings.Contains(errStr, "0xc0000062") || // STATUS_PRIVILEGE_NOT_HELD (LsaLogonUser)
|
||||
strings.Contains(errStr, "privilege") ||
|
||||
strings.Contains(errStr, "access denied") ||
|
||||
strings.Contains(errStr, "user authentication failed")
|
||||
}
|
||||
135
client/ssh/client/terminal_unix.go
Normal file
135
client/ssh/client/terminal_unix.go
Normal file
@@ -0,0 +1,135 @@
|
||||
//go:build !windows
|
||||
|
||||
package client
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/signal"
|
||||
"syscall"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/crypto/ssh"
|
||||
"golang.org/x/term"
|
||||
)
|
||||
|
||||
func (c *Client) setupTerminalMode(ctx context.Context, session *ssh.Session) error {
|
||||
fd := int(os.Stdout.Fd())
|
||||
|
||||
if !term.IsTerminal(fd) {
|
||||
return c.setupNonTerminalMode(ctx, session)
|
||||
}
|
||||
|
||||
state, err := term.MakeRaw(fd)
|
||||
if err != nil {
|
||||
return c.setupNonTerminalMode(ctx, session)
|
||||
}
|
||||
|
||||
c.terminalState = state
|
||||
c.terminalFd = fd
|
||||
|
||||
sigChan := make(chan os.Signal, 1)
|
||||
signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM, syscall.SIGQUIT)
|
||||
|
||||
go func() {
|
||||
defer signal.Stop(sigChan)
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
if err := term.Restore(fd, state); err != nil {
|
||||
log.Debugf("restore terminal state: %v", err)
|
||||
}
|
||||
case sig := <-sigChan:
|
||||
if err := term.Restore(fd, state); err != nil {
|
||||
log.Debugf("restore terminal state: %v", err)
|
||||
}
|
||||
signal.Reset(sig)
|
||||
s, ok := sig.(syscall.Signal)
|
||||
if !ok {
|
||||
log.Debugf("signal %v is not a syscall.Signal: %T", sig, sig)
|
||||
return
|
||||
}
|
||||
if err := syscall.Kill(syscall.Getpid(), s); err != nil {
|
||||
log.Debugf("kill process with signal %v: %v", s, err)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
return c.setupTerminal(session, fd)
|
||||
}
|
||||
|
||||
func (c *Client) setupNonTerminalMode(_ context.Context, session *ssh.Session) error {
|
||||
w, h := 80, 24
|
||||
|
||||
modes := ssh.TerminalModes{
|
||||
ssh.ECHO: 1,
|
||||
ssh.TTY_OP_ISPEED: 14400,
|
||||
ssh.TTY_OP_OSPEED: 14400,
|
||||
}
|
||||
|
||||
terminal := os.Getenv("TERM")
|
||||
if terminal == "" {
|
||||
terminal = "xterm-256color"
|
||||
}
|
||||
|
||||
if err := session.RequestPty(terminal, h, w, modes); err != nil {
|
||||
return fmt.Errorf("request pty: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// restoreWindowsConsoleState is a no-op on Unix systems
|
||||
func (c *Client) restoreWindowsConsoleState() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Client) setupTerminal(session *ssh.Session, fd int) error {
|
||||
w, h, err := term.GetSize(fd)
|
||||
if err != nil {
|
||||
return fmt.Errorf("get terminal size: %w", err)
|
||||
}
|
||||
|
||||
modes := ssh.TerminalModes{
|
||||
ssh.ECHO: 1,
|
||||
ssh.TTY_OP_ISPEED: 14400,
|
||||
ssh.TTY_OP_OSPEED: 14400,
|
||||
// Ctrl+C
|
||||
ssh.VINTR: 3,
|
||||
// Ctrl+\
|
||||
ssh.VQUIT: 28,
|
||||
// Backspace
|
||||
ssh.VERASE: 127,
|
||||
// Ctrl+U
|
||||
ssh.VKILL: 21,
|
||||
// Ctrl+D
|
||||
ssh.VEOF: 4,
|
||||
ssh.VEOL: 0,
|
||||
ssh.VEOL2: 0,
|
||||
// Ctrl+Q
|
||||
ssh.VSTART: 17,
|
||||
// Ctrl+S
|
||||
ssh.VSTOP: 19,
|
||||
// Ctrl+Z
|
||||
ssh.VSUSP: 26,
|
||||
// Ctrl+O
|
||||
ssh.VDISCARD: 15,
|
||||
// Ctrl+R
|
||||
ssh.VREPRINT: 18,
|
||||
// Ctrl+W
|
||||
ssh.VWERASE: 23,
|
||||
// Ctrl+V
|
||||
ssh.VLNEXT: 22,
|
||||
}
|
||||
|
||||
terminal := os.Getenv("TERM")
|
||||
if terminal == "" {
|
||||
terminal = "xterm-256color"
|
||||
}
|
||||
|
||||
if err := session.RequestPty(terminal, h, w, modes); err != nil {
|
||||
return fmt.Errorf("request pty: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
259
client/ssh/client/terminal_windows.go
Normal file
259
client/ssh/client/terminal_windows.go
Normal file
@@ -0,0 +1,259 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"syscall"
|
||||
"unsafe"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
const (
|
||||
enableProcessedInput = 0x0001
|
||||
enableLineInput = 0x0002
|
||||
enableEchoInput = 0x0004 // Input mode: ENABLE_ECHO_INPUT
|
||||
enableVirtualTerminalProcessing = 0x0004 // Output mode: ENABLE_VIRTUAL_TERMINAL_PROCESSING (same value, different mode)
|
||||
enableVirtualTerminalInput = 0x0200
|
||||
)
|
||||
|
||||
var (
|
||||
kernel32 = syscall.NewLazyDLL("kernel32.dll")
|
||||
procGetConsoleMode = kernel32.NewProc("GetConsoleMode")
|
||||
procSetConsoleMode = kernel32.NewProc("SetConsoleMode")
|
||||
procGetConsoleScreenBufferInfo = kernel32.NewProc("GetConsoleScreenBufferInfo")
|
||||
)
|
||||
|
||||
// ConsoleUnavailableError indicates that Windows console handles are not available
|
||||
// (e.g., in CI environments where stdout/stdin are redirected)
|
||||
type ConsoleUnavailableError struct {
|
||||
Operation string
|
||||
Err error
|
||||
}
|
||||
|
||||
func (e *ConsoleUnavailableError) Error() string {
|
||||
return fmt.Sprintf("console unavailable for %s: %v", e.Operation, e.Err)
|
||||
}
|
||||
|
||||
func (e *ConsoleUnavailableError) Unwrap() error {
|
||||
return e.Err
|
||||
}
|
||||
|
||||
type coord struct {
|
||||
x, y int16
|
||||
}
|
||||
|
||||
type smallRect struct {
|
||||
left, top, right, bottom int16
|
||||
}
|
||||
|
||||
type consoleScreenBufferInfo struct {
|
||||
size coord
|
||||
cursorPosition coord
|
||||
attributes uint16
|
||||
window smallRect
|
||||
maximumWindowSize coord
|
||||
}
|
||||
|
||||
func (c *Client) setupTerminalMode(_ context.Context, session *ssh.Session) error {
|
||||
if err := c.saveWindowsConsoleState(); err != nil {
|
||||
var consoleErr *ConsoleUnavailableError
|
||||
if errors.As(err, &consoleErr) {
|
||||
log.Debugf("console unavailable, continuing with defaults: %v", err)
|
||||
c.terminalFd = 0
|
||||
} else {
|
||||
return fmt.Errorf("save console state: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if err := c.enableWindowsVirtualTerminal(); err != nil {
|
||||
var consoleErr *ConsoleUnavailableError
|
||||
if errors.As(err, &consoleErr) {
|
||||
log.Debugf("virtual terminal unavailable: %v", err)
|
||||
} else {
|
||||
return fmt.Errorf("failed to enable virtual terminal: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
w, h := c.getWindowsConsoleSize()
|
||||
|
||||
modes := ssh.TerminalModes{
|
||||
ssh.ECHO: 1,
|
||||
ssh.TTY_OP_ISPEED: 14400,
|
||||
ssh.TTY_OP_OSPEED: 14400,
|
||||
ssh.ICRNL: 1,
|
||||
ssh.OPOST: 1,
|
||||
ssh.ONLCR: 1,
|
||||
ssh.ISIG: 1,
|
||||
ssh.ICANON: 1,
|
||||
ssh.VINTR: 3, // Ctrl+C
|
||||
ssh.VQUIT: 28, // Ctrl+\
|
||||
ssh.VERASE: 127, // Backspace
|
||||
ssh.VKILL: 21, // Ctrl+U
|
||||
ssh.VEOF: 4, // Ctrl+D
|
||||
ssh.VEOL: 0,
|
||||
ssh.VEOL2: 0,
|
||||
ssh.VSTART: 17, // Ctrl+Q
|
||||
ssh.VSTOP: 19, // Ctrl+S
|
||||
ssh.VSUSP: 26, // Ctrl+Z
|
||||
ssh.VDISCARD: 15, // Ctrl+O
|
||||
ssh.VWERASE: 23, // Ctrl+W
|
||||
ssh.VLNEXT: 22, // Ctrl+V
|
||||
ssh.VREPRINT: 18, // Ctrl+R
|
||||
}
|
||||
|
||||
return session.RequestPty("xterm-256color", h, w, modes)
|
||||
}
|
||||
|
||||
func (c *Client) saveWindowsConsoleState() error {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
log.Debugf("panic in saveWindowsConsoleState: %v", r)
|
||||
}
|
||||
}()
|
||||
|
||||
stdout := syscall.Handle(os.Stdout.Fd())
|
||||
stdin := syscall.Handle(os.Stdin.Fd())
|
||||
|
||||
var stdoutMode, stdinMode uint32
|
||||
|
||||
ret, _, err := procGetConsoleMode.Call(uintptr(stdout), uintptr(unsafe.Pointer(&stdoutMode)))
|
||||
if ret == 0 {
|
||||
log.Debugf("failed to get stdout console mode: %v", err)
|
||||
return &ConsoleUnavailableError{
|
||||
Operation: "get stdout console mode",
|
||||
Err: err,
|
||||
}
|
||||
}
|
||||
|
||||
ret, _, err = procGetConsoleMode.Call(uintptr(stdin), uintptr(unsafe.Pointer(&stdinMode)))
|
||||
if ret == 0 {
|
||||
log.Debugf("failed to get stdin console mode: %v", err)
|
||||
return &ConsoleUnavailableError{
|
||||
Operation: "get stdin console mode",
|
||||
Err: err,
|
||||
}
|
||||
}
|
||||
|
||||
c.terminalFd = 1
|
||||
c.windowsStdoutMode = stdoutMode
|
||||
c.windowsStdinMode = stdinMode
|
||||
|
||||
log.Debugf("saved Windows console state - stdout: 0x%04x, stdin: 0x%04x", stdoutMode, stdinMode)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Client) enableWindowsVirtualTerminal() (err error) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
err = fmt.Errorf("panic in enableWindowsVirtualTerminal: %v", r)
|
||||
}
|
||||
}()
|
||||
|
||||
stdout := syscall.Handle(os.Stdout.Fd())
|
||||
stdin := syscall.Handle(os.Stdin.Fd())
|
||||
var mode uint32
|
||||
|
||||
ret, _, winErr := procGetConsoleMode.Call(uintptr(stdout), uintptr(unsafe.Pointer(&mode)))
|
||||
if ret == 0 {
|
||||
return &ConsoleUnavailableError{
|
||||
Operation: "get stdout console mode for VT",
|
||||
Err: winErr,
|
||||
}
|
||||
}
|
||||
|
||||
mode |= enableVirtualTerminalProcessing
|
||||
ret, _, winErr = procSetConsoleMode.Call(uintptr(stdout), uintptr(mode))
|
||||
if ret == 0 {
|
||||
return &ConsoleUnavailableError{
|
||||
Operation: "enable virtual terminal processing",
|
||||
Err: winErr,
|
||||
}
|
||||
}
|
||||
|
||||
ret, _, winErr = procGetConsoleMode.Call(uintptr(stdin), uintptr(unsafe.Pointer(&mode)))
|
||||
if ret == 0 {
|
||||
return &ConsoleUnavailableError{
|
||||
Operation: "get stdin console mode for VT",
|
||||
Err: winErr,
|
||||
}
|
||||
}
|
||||
|
||||
mode &= ^uint32(enableLineInput | enableEchoInput | enableProcessedInput)
|
||||
mode |= enableVirtualTerminalInput
|
||||
ret, _, winErr = procSetConsoleMode.Call(uintptr(stdin), uintptr(mode))
|
||||
if ret == 0 {
|
||||
return &ConsoleUnavailableError{
|
||||
Operation: "set stdin raw mode",
|
||||
Err: winErr,
|
||||
}
|
||||
}
|
||||
|
||||
log.Debugf("enabled Windows virtual terminal processing")
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Client) getWindowsConsoleSize() (int, int) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
log.Debugf("panic in getWindowsConsoleSize: %v", r)
|
||||
}
|
||||
}()
|
||||
|
||||
stdout := syscall.Handle(os.Stdout.Fd())
|
||||
var csbi consoleScreenBufferInfo
|
||||
|
||||
ret, _, err := procGetConsoleScreenBufferInfo.Call(uintptr(stdout), uintptr(unsafe.Pointer(&csbi)))
|
||||
if ret == 0 {
|
||||
log.Debugf("failed to get console buffer info, using defaults: %v", err)
|
||||
return 80, 24
|
||||
}
|
||||
|
||||
width := int(csbi.window.right - csbi.window.left + 1)
|
||||
height := int(csbi.window.bottom - csbi.window.top + 1)
|
||||
|
||||
log.Debugf("Windows console size: %dx%d", width, height)
|
||||
return width, height
|
||||
}
|
||||
|
||||
func (c *Client) restoreWindowsConsoleState() error {
|
||||
var err error
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
err = fmt.Errorf("panic in restoreWindowsConsoleState: %v", r)
|
||||
}
|
||||
}()
|
||||
|
||||
if c.terminalFd != 1 {
|
||||
return nil
|
||||
}
|
||||
|
||||
stdout := syscall.Handle(os.Stdout.Fd())
|
||||
stdin := syscall.Handle(os.Stdin.Fd())
|
||||
|
||||
ret, _, winErr := procSetConsoleMode.Call(uintptr(stdout), uintptr(c.windowsStdoutMode))
|
||||
if ret == 0 {
|
||||
log.Debugf("failed to restore stdout console mode: %v", winErr)
|
||||
if err == nil {
|
||||
err = fmt.Errorf("restore stdout console mode: %w", winErr)
|
||||
}
|
||||
}
|
||||
|
||||
ret, _, winErr = procSetConsoleMode.Call(uintptr(stdin), uintptr(c.windowsStdinMode))
|
||||
if ret == 0 {
|
||||
log.Debugf("failed to restore stdin console mode: %v", winErr)
|
||||
if err == nil {
|
||||
err = fmt.Errorf("restore stdin console mode: %w", winErr)
|
||||
}
|
||||
}
|
||||
|
||||
c.terminalFd = 0
|
||||
c.windowsStdoutMode = 0
|
||||
c.windowsStdinMode = 0
|
||||
|
||||
log.Debugf("restored Windows console state")
|
||||
return err
|
||||
}
|
||||
167
client/ssh/common.go
Normal file
167
client/ssh/common.go
Normal file
@@ -0,0 +1,167 @@
|
||||
package ssh
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/crypto/ssh"
|
||||
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
)
|
||||
|
||||
const (
|
||||
NetBirdSSHConfigFile = "99-netbird.conf"
|
||||
|
||||
UnixSSHConfigDir = "/etc/ssh/ssh_config.d"
|
||||
WindowsSSHConfigDir = "ssh/ssh_config.d"
|
||||
)
|
||||
|
||||
var (
|
||||
// ErrPeerNotFound indicates the peer was not found in the network
|
||||
ErrPeerNotFound = errors.New("peer not found in network")
|
||||
// ErrNoStoredKey indicates the peer has no stored SSH host key
|
||||
ErrNoStoredKey = errors.New("peer has no stored SSH host key")
|
||||
)
|
||||
|
||||
// HostKeyVerifier provides SSH host key verification
|
||||
type HostKeyVerifier interface {
|
||||
VerifySSHHostKey(peerAddress string, key []byte) error
|
||||
}
|
||||
|
||||
// DaemonHostKeyVerifier implements HostKeyVerifier using the NetBird daemon
|
||||
type DaemonHostKeyVerifier struct {
|
||||
client proto.DaemonServiceClient
|
||||
}
|
||||
|
||||
// NewDaemonHostKeyVerifier creates a new daemon-based host key verifier
|
||||
func NewDaemonHostKeyVerifier(client proto.DaemonServiceClient) *DaemonHostKeyVerifier {
|
||||
return &DaemonHostKeyVerifier{
|
||||
client: client,
|
||||
}
|
||||
}
|
||||
|
||||
// VerifySSHHostKey verifies an SSH host key by querying the NetBird daemon
|
||||
func (d *DaemonHostKeyVerifier) VerifySSHHostKey(peerAddress string, presentedKey []byte) error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
response, err := d.client.GetPeerSSHHostKey(ctx, &proto.GetPeerSSHHostKeyRequest{
|
||||
PeerAddress: peerAddress,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !response.GetFound() {
|
||||
return ErrPeerNotFound
|
||||
}
|
||||
|
||||
storedKeyData := response.GetSshHostKey()
|
||||
|
||||
return VerifyHostKey(storedKeyData, presentedKey, peerAddress)
|
||||
}
|
||||
|
||||
// RequestJWTToken requests or retrieves a JWT token for SSH authentication
|
||||
func RequestJWTToken(ctx context.Context, client proto.DaemonServiceClient, stdout, stderr io.Writer, useCache bool) (string, error) {
|
||||
authResponse, err := client.RequestJWTAuth(ctx, &proto.RequestJWTAuthRequest{})
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("request JWT auth: %w", err)
|
||||
}
|
||||
|
||||
if useCache && authResponse.CachedToken != "" {
|
||||
log.Debug("Using cached authentication token")
|
||||
return authResponse.CachedToken, nil
|
||||
}
|
||||
|
||||
if stderr != nil {
|
||||
_, _ = fmt.Fprintln(stderr, "SSH authentication required.")
|
||||
_, _ = fmt.Fprintf(stderr, "Please visit: %s\n", authResponse.VerificationURIComplete)
|
||||
if authResponse.UserCode != "" {
|
||||
_, _ = fmt.Fprintf(stderr, "Or visit: %s and enter code: %s\n", authResponse.VerificationURI, authResponse.UserCode)
|
||||
}
|
||||
_, _ = fmt.Fprintln(stderr, "Waiting for authentication...")
|
||||
}
|
||||
|
||||
tokenResponse, err := client.WaitJWTToken(ctx, &proto.WaitJWTTokenRequest{
|
||||
DeviceCode: authResponse.DeviceCode,
|
||||
UserCode: authResponse.UserCode,
|
||||
})
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("wait for JWT token: %w", err)
|
||||
}
|
||||
|
||||
if stdout != nil {
|
||||
_, _ = fmt.Fprintln(stdout, "Authentication successful!")
|
||||
}
|
||||
return tokenResponse.Token, nil
|
||||
}
|
||||
|
||||
// VerifyHostKey verifies an SSH host key against stored peer key data.
|
||||
// Returns nil only if the presented key matches the stored key.
|
||||
// Returns ErrNoStoredKey if storedKeyData is empty.
|
||||
// Returns an error if the keys don't match or if parsing fails.
|
||||
func VerifyHostKey(storedKeyData []byte, presentedKey []byte, peerAddress string) error {
|
||||
if len(storedKeyData) == 0 {
|
||||
return ErrNoStoredKey
|
||||
}
|
||||
|
||||
storedPubKey, _, _, _, err := ssh.ParseAuthorizedKey(storedKeyData)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parse stored SSH key for %s: %w", peerAddress, err)
|
||||
}
|
||||
|
||||
if !bytes.Equal(presentedKey, storedPubKey.Marshal()) {
|
||||
return fmt.Errorf("SSH host key mismatch for %s", peerAddress)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// AddJWTAuth prepends JWT password authentication to existing auth methods.
|
||||
// This ensures JWT auth is tried first while preserving any existing auth methods.
|
||||
func AddJWTAuth(config *ssh.ClientConfig, jwtToken string) *ssh.ClientConfig {
|
||||
configWithJWT := *config
|
||||
configWithJWT.Auth = append([]ssh.AuthMethod{ssh.Password(jwtToken)}, config.Auth...)
|
||||
return &configWithJWT
|
||||
}
|
||||
|
||||
// CreateHostKeyCallback creates an SSH host key verification callback using the provided verifier.
|
||||
// It tries multiple addresses (hostname, IP) for the peer before failing.
|
||||
func CreateHostKeyCallback(verifier HostKeyVerifier) ssh.HostKeyCallback {
|
||||
return func(hostname string, remote net.Addr, key ssh.PublicKey) error {
|
||||
addresses := buildAddressList(hostname, remote)
|
||||
presentedKey := key.Marshal()
|
||||
|
||||
for _, addr := range addresses {
|
||||
if err := verifier.VerifySSHHostKey(addr, presentedKey); err != nil {
|
||||
if errors.Is(err, ErrPeerNotFound) {
|
||||
// Try other addresses for this peer
|
||||
continue
|
||||
}
|
||||
return err
|
||||
}
|
||||
// Verified
|
||||
return nil
|
||||
}
|
||||
|
||||
return fmt.Errorf("SSH host key verification failed: peer %s not found in network", hostname)
|
||||
}
|
||||
}
|
||||
|
||||
// buildAddressList creates a list of addresses to check for host key verification.
|
||||
// It includes the original hostname and extracts the host part from the remote address if different.
|
||||
func buildAddressList(hostname string, remote net.Addr) []string {
|
||||
addresses := []string{hostname}
|
||||
if host, _, err := net.SplitHostPort(remote.String()); err == nil {
|
||||
if host != hostname {
|
||||
addresses = append(addresses, host)
|
||||
}
|
||||
}
|
||||
return addresses
|
||||
}
|
||||
282
client/ssh/config/manager.go
Normal file
282
client/ssh/config/manager.go
Normal file
@@ -0,0 +1,282 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
nbssh "github.com/netbirdio/netbird/client/ssh"
|
||||
)
|
||||
|
||||
const (
|
||||
EnvDisableSSHConfig = "NB_DISABLE_SSH_CONFIG"
|
||||
|
||||
EnvForceSSHConfig = "NB_FORCE_SSH_CONFIG"
|
||||
|
||||
MaxPeersForSSHConfig = 200
|
||||
|
||||
fileWriteTimeout = 2 * time.Second
|
||||
)
|
||||
|
||||
func isSSHConfigDisabled() bool {
|
||||
value := os.Getenv(EnvDisableSSHConfig)
|
||||
if value == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
disabled, err := strconv.ParseBool(value)
|
||||
if err != nil {
|
||||
return true
|
||||
}
|
||||
return disabled
|
||||
}
|
||||
|
||||
func isSSHConfigForced() bool {
|
||||
value := os.Getenv(EnvForceSSHConfig)
|
||||
if value == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
forced, err := strconv.ParseBool(value)
|
||||
if err != nil {
|
||||
return true
|
||||
}
|
||||
return forced
|
||||
}
|
||||
|
||||
// shouldGenerateSSHConfig checks if SSH config should be generated based on peer count
|
||||
func shouldGenerateSSHConfig(peerCount int) bool {
|
||||
if isSSHConfigDisabled() {
|
||||
return false
|
||||
}
|
||||
|
||||
if isSSHConfigForced() {
|
||||
return true
|
||||
}
|
||||
|
||||
return peerCount <= MaxPeersForSSHConfig
|
||||
}
|
||||
|
||||
// writeFileWithTimeout writes data to a file with a timeout
|
||||
func writeFileWithTimeout(filename string, data []byte, perm os.FileMode) error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), fileWriteTimeout)
|
||||
defer cancel()
|
||||
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
done <- os.WriteFile(filename, data, perm)
|
||||
}()
|
||||
|
||||
select {
|
||||
case err := <-done:
|
||||
return err
|
||||
case <-ctx.Done():
|
||||
return fmt.Errorf("file write timeout after %v: %s", fileWriteTimeout, filename)
|
||||
}
|
||||
}
|
||||
|
||||
// Manager handles SSH client configuration for NetBird peers
|
||||
type Manager struct {
|
||||
sshConfigDir string
|
||||
sshConfigFile string
|
||||
}
|
||||
|
||||
// PeerSSHInfo represents a peer's SSH configuration information
|
||||
type PeerSSHInfo struct {
|
||||
Hostname string
|
||||
IP string
|
||||
FQDN string
|
||||
}
|
||||
|
||||
// New creates a new SSH config manager
|
||||
func New() *Manager {
|
||||
sshConfigDir := getSystemSSHConfigDir()
|
||||
return &Manager{
|
||||
sshConfigDir: sshConfigDir,
|
||||
sshConfigFile: nbssh.NetBirdSSHConfigFile,
|
||||
}
|
||||
}
|
||||
|
||||
// getSystemSSHConfigDir returns platform-specific SSH configuration directory
|
||||
func getSystemSSHConfigDir() string {
|
||||
if runtime.GOOS == "windows" {
|
||||
return getWindowsSSHConfigDir()
|
||||
}
|
||||
return nbssh.UnixSSHConfigDir
|
||||
}
|
||||
|
||||
func getWindowsSSHConfigDir() string {
|
||||
programData := os.Getenv("PROGRAMDATA")
|
||||
if programData == "" {
|
||||
programData = `C:\ProgramData`
|
||||
}
|
||||
return filepath.Join(programData, nbssh.WindowsSSHConfigDir)
|
||||
}
|
||||
|
||||
// SetupSSHClientConfig creates SSH client configuration for NetBird peers
|
||||
func (m *Manager) SetupSSHClientConfig(peers []PeerSSHInfo) error {
|
||||
if !shouldGenerateSSHConfig(len(peers)) {
|
||||
m.logSkipReason(len(peers))
|
||||
return nil
|
||||
}
|
||||
|
||||
sshConfig, err := m.buildSSHConfig(peers)
|
||||
if err != nil {
|
||||
return fmt.Errorf("build SSH config: %w", err)
|
||||
}
|
||||
return m.writeSSHConfig(sshConfig)
|
||||
}
|
||||
|
||||
func (m *Manager) logSkipReason(peerCount int) {
|
||||
if isSSHConfigDisabled() {
|
||||
log.Debugf("SSH config management disabled via %s", EnvDisableSSHConfig)
|
||||
} else {
|
||||
log.Infof("SSH config generation skipped: too many peers (%d > %d). Use %s=true to force.",
|
||||
peerCount, MaxPeersForSSHConfig, EnvForceSSHConfig)
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Manager) buildSSHConfig(peers []PeerSSHInfo) (string, error) {
|
||||
sshConfig := m.buildConfigHeader()
|
||||
|
||||
var allHostPatterns []string
|
||||
for _, peer := range peers {
|
||||
hostPatterns := m.buildHostPatterns(peer)
|
||||
allHostPatterns = append(allHostPatterns, hostPatterns...)
|
||||
}
|
||||
|
||||
if len(allHostPatterns) > 0 {
|
||||
peerConfig, err := m.buildPeerConfig(allHostPatterns)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
sshConfig += peerConfig
|
||||
}
|
||||
|
||||
return sshConfig, nil
|
||||
}
|
||||
|
||||
func (m *Manager) buildConfigHeader() string {
|
||||
return "# NetBird SSH client configuration\n" +
|
||||
"# Generated automatically - do not edit manually\n" +
|
||||
"#\n" +
|
||||
"# To disable SSH config management, use:\n" +
|
||||
"# netbird service reconfigure --service-env NB_DISABLE_SSH_CONFIG=true\n" +
|
||||
"#\n\n"
|
||||
}
|
||||
|
||||
func (m *Manager) buildPeerConfig(allHostPatterns []string) (string, error) {
|
||||
uniquePatterns := make(map[string]bool)
|
||||
var deduplicatedPatterns []string
|
||||
for _, pattern := range allHostPatterns {
|
||||
if !uniquePatterns[pattern] {
|
||||
uniquePatterns[pattern] = true
|
||||
deduplicatedPatterns = append(deduplicatedPatterns, pattern)
|
||||
}
|
||||
}
|
||||
|
||||
execPath, err := m.getNetBirdExecutablePath()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("get NetBird executable path: %w", err)
|
||||
}
|
||||
|
||||
hostLine := strings.Join(deduplicatedPatterns, " ")
|
||||
config := fmt.Sprintf("Host %s\n", hostLine)
|
||||
|
||||
if runtime.GOOS == "windows" {
|
||||
config += fmt.Sprintf(" Match exec \"%s ssh detect %%h %%p\"\n", execPath)
|
||||
} else {
|
||||
config += fmt.Sprintf(" Match exec \"%s ssh detect %%h %%p 2>/dev/null\"\n", execPath)
|
||||
}
|
||||
config += " PreferredAuthentications password,publickey,keyboard-interactive\n"
|
||||
config += " PasswordAuthentication yes\n"
|
||||
config += " PubkeyAuthentication yes\n"
|
||||
config += " BatchMode no\n"
|
||||
config += fmt.Sprintf(" ProxyCommand %s ssh proxy %%h %%p\n", execPath)
|
||||
config += " StrictHostKeyChecking no\n"
|
||||
|
||||
if runtime.GOOS == "windows" {
|
||||
config += " UserKnownHostsFile NUL\n"
|
||||
} else {
|
||||
config += " UserKnownHostsFile /dev/null\n"
|
||||
}
|
||||
|
||||
config += " CheckHostIP no\n"
|
||||
config += " LogLevel ERROR\n\n"
|
||||
|
||||
return config, nil
|
||||
}
|
||||
|
||||
func (m *Manager) buildHostPatterns(peer PeerSSHInfo) []string {
|
||||
var hostPatterns []string
|
||||
if peer.IP != "" {
|
||||
hostPatterns = append(hostPatterns, peer.IP)
|
||||
}
|
||||
if peer.FQDN != "" {
|
||||
hostPatterns = append(hostPatterns, peer.FQDN)
|
||||
}
|
||||
if peer.Hostname != "" && peer.Hostname != peer.FQDN {
|
||||
hostPatterns = append(hostPatterns, peer.Hostname)
|
||||
}
|
||||
return hostPatterns
|
||||
}
|
||||
|
||||
func (m *Manager) writeSSHConfig(sshConfig string) error {
|
||||
sshConfigPath := filepath.Join(m.sshConfigDir, m.sshConfigFile)
|
||||
|
||||
if err := os.MkdirAll(m.sshConfigDir, 0755); err != nil {
|
||||
return fmt.Errorf("create SSH config directory %s: %w", m.sshConfigDir, err)
|
||||
}
|
||||
|
||||
if err := writeFileWithTimeout(sshConfigPath, []byte(sshConfig), 0644); err != nil {
|
||||
return fmt.Errorf("write SSH config file %s: %w", sshConfigPath, err)
|
||||
}
|
||||
|
||||
log.Infof("Created NetBird SSH client config: %s", sshConfigPath)
|
||||
return nil
|
||||
}
|
||||
|
||||
// RemoveSSHClientConfig removes NetBird SSH configuration
|
||||
func (m *Manager) RemoveSSHClientConfig() error {
|
||||
sshConfigPath := filepath.Join(m.sshConfigDir, m.sshConfigFile)
|
||||
err := os.Remove(sshConfigPath)
|
||||
if err != nil && !os.IsNotExist(err) {
|
||||
return fmt.Errorf("remove SSH config %s: %w", sshConfigPath, err)
|
||||
}
|
||||
if err == nil {
|
||||
log.Infof("Removed NetBird SSH config: %s", sshConfigPath)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Manager) getNetBirdExecutablePath() (string, error) {
|
||||
execPath, err := os.Executable()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("retrieve executable path: %w", err)
|
||||
}
|
||||
|
||||
realPath, err := filepath.EvalSymlinks(execPath)
|
||||
if err != nil {
|
||||
log.Debugf("symlink resolution failed: %v", err)
|
||||
return execPath, nil
|
||||
}
|
||||
|
||||
return realPath, nil
|
||||
}
|
||||
|
||||
// GetSSHConfigDir returns the SSH config directory path
|
||||
func (m *Manager) GetSSHConfigDir() string {
|
||||
return m.sshConfigDir
|
||||
}
|
||||
|
||||
// GetSSHConfigFile returns the SSH config file name
|
||||
func (m *Manager) GetSSHConfigFile() string {
|
||||
return m.sshConfigFile
|
||||
}
|
||||
159
client/ssh/config/manager_test.go
Normal file
159
client/ssh/config/manager_test.go
Normal file
@@ -0,0 +1,159 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestManager_SetupSSHClientConfig(t *testing.T) {
|
||||
// Create temporary directory for test
|
||||
tempDir, err := os.MkdirTemp("", "netbird-ssh-config-test")
|
||||
require.NoError(t, err)
|
||||
defer func() { assert.NoError(t, os.RemoveAll(tempDir)) }()
|
||||
|
||||
// Override manager paths to use temp directory
|
||||
manager := &Manager{
|
||||
sshConfigDir: filepath.Join(tempDir, "ssh_config.d"),
|
||||
sshConfigFile: "99-netbird.conf",
|
||||
}
|
||||
|
||||
// Test SSH config generation with peers
|
||||
peers := []PeerSSHInfo{
|
||||
{
|
||||
Hostname: "peer1",
|
||||
IP: "100.125.1.1",
|
||||
FQDN: "peer1.nb.internal",
|
||||
},
|
||||
{
|
||||
Hostname: "peer2",
|
||||
IP: "100.125.1.2",
|
||||
FQDN: "peer2.nb.internal",
|
||||
},
|
||||
}
|
||||
|
||||
err = manager.SetupSSHClientConfig(peers)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Read generated config
|
||||
configPath := filepath.Join(manager.sshConfigDir, manager.sshConfigFile)
|
||||
content, err := os.ReadFile(configPath)
|
||||
require.NoError(t, err)
|
||||
|
||||
configStr := string(content)
|
||||
|
||||
// Verify the basic SSH config structure exists
|
||||
assert.Contains(t, configStr, "# NetBird SSH client configuration")
|
||||
assert.Contains(t, configStr, "Generated automatically - do not edit manually")
|
||||
|
||||
// Check that peer hostnames are included
|
||||
assert.Contains(t, configStr, "100.125.1.1")
|
||||
assert.Contains(t, configStr, "100.125.1.2")
|
||||
assert.Contains(t, configStr, "peer1.nb.internal")
|
||||
assert.Contains(t, configStr, "peer2.nb.internal")
|
||||
|
||||
// Check platform-specific UserKnownHostsFile
|
||||
if runtime.GOOS == "windows" {
|
||||
assert.Contains(t, configStr, "UserKnownHostsFile NUL")
|
||||
} else {
|
||||
assert.Contains(t, configStr, "UserKnownHostsFile /dev/null")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetSystemSSHConfigDir(t *testing.T) {
|
||||
configDir := getSystemSSHConfigDir()
|
||||
|
||||
// Path should not be empty
|
||||
assert.NotEmpty(t, configDir)
|
||||
|
||||
// Should be an absolute path
|
||||
assert.True(t, filepath.IsAbs(configDir))
|
||||
|
||||
// On Unix systems, should start with /etc
|
||||
// On Windows, should contain ProgramData
|
||||
if runtime.GOOS == "windows" {
|
||||
assert.Contains(t, strings.ToLower(configDir), "programdata")
|
||||
} else {
|
||||
assert.Contains(t, configDir, "/etc/ssh")
|
||||
}
|
||||
}
|
||||
|
||||
func TestManager_PeerLimit(t *testing.T) {
|
||||
// Create temporary directory for test
|
||||
tempDir, err := os.MkdirTemp("", "netbird-ssh-config-test")
|
||||
require.NoError(t, err)
|
||||
defer func() { assert.NoError(t, os.RemoveAll(tempDir)) }()
|
||||
|
||||
// Override manager paths to use temp directory
|
||||
manager := &Manager{
|
||||
sshConfigDir: filepath.Join(tempDir, "ssh_config.d"),
|
||||
sshConfigFile: "99-netbird.conf",
|
||||
}
|
||||
|
||||
// Generate many peers (more than limit)
|
||||
var peers []PeerSSHInfo
|
||||
for i := 0; i < MaxPeersForSSHConfig+10; i++ {
|
||||
peers = append(peers, PeerSSHInfo{
|
||||
Hostname: fmt.Sprintf("peer%d", i),
|
||||
IP: fmt.Sprintf("100.125.1.%d", i%254+1),
|
||||
FQDN: fmt.Sprintf("peer%d.nb.internal", i),
|
||||
})
|
||||
}
|
||||
|
||||
// Test that SSH config generation is skipped when too many peers
|
||||
err = manager.SetupSSHClientConfig(peers)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Config should not be created due to peer limit
|
||||
configPath := filepath.Join(manager.sshConfigDir, manager.sshConfigFile)
|
||||
_, err = os.Stat(configPath)
|
||||
assert.True(t, os.IsNotExist(err), "SSH config should not be created with too many peers")
|
||||
}
|
||||
|
||||
func TestManager_ForcedSSHConfig(t *testing.T) {
|
||||
// Set force environment variable
|
||||
t.Setenv(EnvForceSSHConfig, "true")
|
||||
|
||||
// Create temporary directory for test
|
||||
tempDir, err := os.MkdirTemp("", "netbird-ssh-config-test")
|
||||
require.NoError(t, err)
|
||||
defer func() { assert.NoError(t, os.RemoveAll(tempDir)) }()
|
||||
|
||||
// Override manager paths to use temp directory
|
||||
manager := &Manager{
|
||||
sshConfigDir: filepath.Join(tempDir, "ssh_config.d"),
|
||||
sshConfigFile: "99-netbird.conf",
|
||||
}
|
||||
|
||||
// Generate many peers (more than limit)
|
||||
var peers []PeerSSHInfo
|
||||
for i := 0; i < MaxPeersForSSHConfig+10; i++ {
|
||||
peers = append(peers, PeerSSHInfo{
|
||||
Hostname: fmt.Sprintf("peer%d", i),
|
||||
IP: fmt.Sprintf("100.125.1.%d", i%254+1),
|
||||
FQDN: fmt.Sprintf("peer%d.nb.internal", i),
|
||||
})
|
||||
}
|
||||
|
||||
// Test that SSH config generation is forced despite many peers
|
||||
err = manager.SetupSSHClientConfig(peers)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Config should be created despite peer limit due to force flag
|
||||
configPath := filepath.Join(manager.sshConfigDir, manager.sshConfigFile)
|
||||
_, err = os.Stat(configPath)
|
||||
require.NoError(t, err, "SSH config should be created when forced")
|
||||
|
||||
// Verify config contains peer hostnames
|
||||
content, err := os.ReadFile(configPath)
|
||||
require.NoError(t, err)
|
||||
configStr := string(content)
|
||||
assert.Contains(t, configStr, "peer0.nb.internal")
|
||||
assert.Contains(t, configStr, "peer1.nb.internal")
|
||||
}
|
||||
22
client/ssh/config/shutdown_state.go
Normal file
22
client/ssh/config/shutdown_state.go
Normal file
@@ -0,0 +1,22 @@
|
||||
package config
|
||||
|
||||
// ShutdownState represents SSH configuration state that needs to be cleaned up.
|
||||
type ShutdownState struct {
|
||||
SSHConfigDir string
|
||||
SSHConfigFile string
|
||||
}
|
||||
|
||||
// Name returns the state name for the state manager.
|
||||
func (s *ShutdownState) Name() string {
|
||||
return "ssh_config_state"
|
||||
}
|
||||
|
||||
// Cleanup removes SSH client configuration files.
|
||||
func (s *ShutdownState) Cleanup() error {
|
||||
manager := &Manager{
|
||||
sshConfigDir: s.SSHConfigDir,
|
||||
sshConfigFile: s.SSHConfigFile,
|
||||
}
|
||||
|
||||
return manager.RemoveSSHClientConfig()
|
||||
}
|
||||
99
client/ssh/detection/detection.go
Normal file
99
client/ssh/detection/detection.go
Normal file
@@ -0,0 +1,99 @@
|
||||
package detection
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"net"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
// ServerIdentifier is the base response for NetBird SSH servers
|
||||
ServerIdentifier = "NetBird-SSH-Server"
|
||||
// ProxyIdentifier is the base response for NetBird SSH proxy
|
||||
ProxyIdentifier = "NetBird-SSH-Proxy"
|
||||
// JWTRequiredMarker is appended to responses when JWT is required
|
||||
JWTRequiredMarker = "NetBird-JWT-Required"
|
||||
|
||||
// Timeout is the timeout for SSH server detection
|
||||
Timeout = 5 * time.Second
|
||||
)
|
||||
|
||||
type ServerType string
|
||||
|
||||
const (
|
||||
ServerTypeNetBirdJWT ServerType = "netbird-jwt"
|
||||
ServerTypeNetBirdNoJWT ServerType = "netbird-no-jwt"
|
||||
ServerTypeRegular ServerType = "regular"
|
||||
)
|
||||
|
||||
// Dialer provides network connection capabilities
|
||||
type Dialer interface {
|
||||
DialContext(ctx context.Context, network, address string) (net.Conn, error)
|
||||
}
|
||||
|
||||
// RequiresJWT checks if the server type requires JWT authentication
|
||||
func (s ServerType) RequiresJWT() bool {
|
||||
return s == ServerTypeNetBirdJWT
|
||||
}
|
||||
|
||||
// ExitCode returns the exit code for the detect command
|
||||
func (s ServerType) ExitCode() int {
|
||||
switch s {
|
||||
case ServerTypeNetBirdJWT:
|
||||
return 0
|
||||
case ServerTypeNetBirdNoJWT:
|
||||
return 1
|
||||
case ServerTypeRegular:
|
||||
return 2
|
||||
default:
|
||||
return 2
|
||||
}
|
||||
}
|
||||
|
||||
// DetectSSHServerType detects SSH server type using the provided dialer
|
||||
func DetectSSHServerType(ctx context.Context, dialer Dialer, host string, port int) (ServerType, error) {
|
||||
targetAddr := net.JoinHostPort(host, strconv.Itoa(port))
|
||||
|
||||
conn, err := dialer.DialContext(ctx, "tcp", targetAddr)
|
||||
if err != nil {
|
||||
log.Debugf("SSH connection failed for detection: %v", err)
|
||||
return ServerTypeRegular, nil
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
if err := conn.SetReadDeadline(time.Now().Add(Timeout)); err != nil {
|
||||
log.Debugf("set read deadline: %v", err)
|
||||
return ServerTypeRegular, nil
|
||||
}
|
||||
|
||||
reader := bufio.NewReader(conn)
|
||||
serverBanner, err := reader.ReadString('\n')
|
||||
if err != nil {
|
||||
log.Debugf("read SSH banner: %v", err)
|
||||
return ServerTypeRegular, nil
|
||||
}
|
||||
|
||||
serverBanner = strings.TrimSpace(serverBanner)
|
||||
log.Debugf("SSH server banner: %s", serverBanner)
|
||||
|
||||
if !strings.HasPrefix(serverBanner, "SSH-") {
|
||||
log.Debugf("Invalid SSH banner")
|
||||
return ServerTypeRegular, nil
|
||||
}
|
||||
|
||||
if !strings.Contains(serverBanner, ServerIdentifier) {
|
||||
log.Debugf("Server banner does not contain identifier '%s'", ServerIdentifier)
|
||||
return ServerTypeRegular, nil
|
||||
}
|
||||
|
||||
if strings.Contains(serverBanner, JWTRequiredMarker) {
|
||||
return ServerTypeNetBirdJWT, nil
|
||||
}
|
||||
|
||||
return ServerTypeNetBirdNoJWT, nil
|
||||
}
|
||||
@@ -1,53 +0,0 @@
|
||||
//go:build !js
|
||||
|
||||
package ssh
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
"os/exec"
|
||||
"runtime"
|
||||
|
||||
"github.com/netbirdio/netbird/util"
|
||||
)
|
||||
|
||||
func isRoot() bool {
|
||||
return os.Geteuid() == 0
|
||||
}
|
||||
|
||||
func getLoginCmd(user string, remoteAddr net.Addr) (loginPath string, args []string, err error) {
|
||||
if !isRoot() {
|
||||
shell := getUserShell(user)
|
||||
if shell == "" {
|
||||
shell = "/bin/sh"
|
||||
}
|
||||
|
||||
return shell, []string{"-l"}, nil
|
||||
}
|
||||
|
||||
loginPath, err = exec.LookPath("login")
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
|
||||
addrPort, err := netip.ParseAddrPort(remoteAddr.String())
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
|
||||
switch runtime.GOOS {
|
||||
case "linux":
|
||||
if util.FileExists("/etc/arch-release") && !util.FileExists("/etc/pam.d/remote") {
|
||||
return loginPath, []string{"-f", user, "-p"}, nil
|
||||
}
|
||||
return loginPath, []string{"-f", user, "-h", addrPort.Addr().String(), "-p"}, nil
|
||||
case "darwin":
|
||||
return loginPath, []string{"-fp", "-h", addrPort.Addr().String(), user}, nil
|
||||
case "freebsd":
|
||||
return loginPath, []string{"-f", user, "-h", addrPort.Addr().String(), "-p"}, nil
|
||||
default:
|
||||
return "", nil, fmt.Errorf("unsupported platform: %s", runtime.GOOS)
|
||||
}
|
||||
}
|
||||
@@ -1,14 +0,0 @@
|
||||
//go:build !darwin
|
||||
// +build !darwin
|
||||
|
||||
package ssh
|
||||
|
||||
import "os/user"
|
||||
|
||||
func userNameLookup(username string) (*user.User, error) {
|
||||
if username == "" || (username == "root" && !isRoot()) {
|
||||
return user.Current()
|
||||
}
|
||||
|
||||
return user.Lookup(username)
|
||||
}
|
||||
@@ -1,51 +0,0 @@
|
||||
//go:build darwin
|
||||
// +build darwin
|
||||
|
||||
package ssh
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"os/exec"
|
||||
"os/user"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func userNameLookup(username string) (*user.User, error) {
|
||||
if username == "" || (username == "root" && !isRoot()) {
|
||||
return user.Current()
|
||||
}
|
||||
|
||||
var userObject *user.User
|
||||
userObject, err := user.Lookup(username)
|
||||
if err != nil && err.Error() == user.UnknownUserError(username).Error() {
|
||||
return idUserNameLookup(username)
|
||||
} else if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return userObject, nil
|
||||
}
|
||||
|
||||
func idUserNameLookup(username string) (*user.User, error) {
|
||||
cmd := exec.Command("id", "-P", username)
|
||||
out, err := cmd.CombinedOutput()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error while retrieving user with id -P command, error: %v", err)
|
||||
}
|
||||
colon := ":"
|
||||
|
||||
if !bytes.Contains(out, []byte(username+colon)) {
|
||||
return nil, fmt.Errorf("unable to find user in returned string")
|
||||
}
|
||||
// netbird:********:501:20::0:0:netbird:/Users/netbird:/bin/zsh
|
||||
parts := strings.SplitN(string(out), colon, 10)
|
||||
userObject := &user.User{
|
||||
Username: parts[0],
|
||||
Uid: parts[2],
|
||||
Gid: parts[3],
|
||||
Name: parts[7],
|
||||
HomeDir: parts[8],
|
||||
}
|
||||
return userObject, nil
|
||||
}
|
||||
369
client/ssh/proxy/proxy.go
Normal file
369
client/ssh/proxy/proxy.go
Normal file
@@ -0,0 +1,369 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gliderlabs/ssh"
|
||||
log "github.com/sirupsen/logrus"
|
||||
cryptossh "golang.org/x/crypto/ssh"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/credentials/insecure"
|
||||
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
nbssh "github.com/netbirdio/netbird/client/ssh"
|
||||
"github.com/netbirdio/netbird/client/ssh/detection"
|
||||
"github.com/netbirdio/netbird/version"
|
||||
)
|
||||
|
||||
const (
|
||||
// sshConnectionTimeout is the timeout for SSH TCP connection establishment
|
||||
sshConnectionTimeout = 120 * time.Second
|
||||
// sshHandshakeTimeout is the timeout for SSH handshake completion
|
||||
sshHandshakeTimeout = 30 * time.Second
|
||||
|
||||
jwtAuthErrorMsg = "JWT authentication: %w"
|
||||
)
|
||||
|
||||
type SSHProxy struct {
|
||||
daemonAddr string
|
||||
targetHost string
|
||||
targetPort int
|
||||
stderr io.Writer
|
||||
daemonClient proto.DaemonServiceClient
|
||||
}
|
||||
|
||||
func New(daemonAddr, targetHost string, targetPort int, stderr io.Writer) (*SSHProxy, error) {
|
||||
grpcAddr := strings.TrimPrefix(daemonAddr, "tcp://")
|
||||
grpcConn, err := grpc.NewClient(grpcAddr, grpc.WithTransportCredentials(insecure.NewCredentials()))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("connect to daemon: %w", err)
|
||||
}
|
||||
|
||||
return &SSHProxy{
|
||||
daemonAddr: daemonAddr,
|
||||
targetHost: targetHost,
|
||||
targetPort: targetPort,
|
||||
stderr: stderr,
|
||||
daemonClient: proto.NewDaemonServiceClient(grpcConn),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (p *SSHProxy) Connect(ctx context.Context) error {
|
||||
jwtToken, err := nbssh.RequestJWTToken(ctx, p.daemonClient, nil, p.stderr, true)
|
||||
if err != nil {
|
||||
return fmt.Errorf(jwtAuthErrorMsg, err)
|
||||
}
|
||||
|
||||
return p.runProxySSHServer(ctx, jwtToken)
|
||||
}
|
||||
|
||||
func (p *SSHProxy) runProxySSHServer(ctx context.Context, jwtToken string) error {
|
||||
serverVersion := fmt.Sprintf("%s-%s", detection.ProxyIdentifier, version.NetbirdVersion())
|
||||
|
||||
sshServer := &ssh.Server{
|
||||
Handler: func(s ssh.Session) {
|
||||
p.handleSSHSession(ctx, s, jwtToken)
|
||||
},
|
||||
ChannelHandlers: map[string]ssh.ChannelHandler{
|
||||
"session": ssh.DefaultSessionHandler,
|
||||
"direct-tcpip": p.directTCPIPHandler,
|
||||
},
|
||||
SubsystemHandlers: map[string]ssh.SubsystemHandler{
|
||||
"sftp": func(s ssh.Session) {
|
||||
p.sftpSubsystemHandler(s, jwtToken)
|
||||
},
|
||||
},
|
||||
RequestHandlers: map[string]ssh.RequestHandler{
|
||||
"tcpip-forward": p.tcpipForwardHandler,
|
||||
"cancel-tcpip-forward": p.cancelTcpipForwardHandler,
|
||||
},
|
||||
Version: serverVersion,
|
||||
}
|
||||
|
||||
hostKey, err := generateHostKey()
|
||||
if err != nil {
|
||||
return fmt.Errorf("generate host key: %w", err)
|
||||
}
|
||||
sshServer.HostSigners = []ssh.Signer{hostKey}
|
||||
|
||||
conn := &stdioConn{
|
||||
stdin: os.Stdin,
|
||||
stdout: os.Stdout,
|
||||
}
|
||||
|
||||
sshServer.HandleConn(conn)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *SSHProxy) handleSSHSession(ctx context.Context, session ssh.Session, jwtToken string) {
|
||||
targetAddr := net.JoinHostPort(p.targetHost, strconv.Itoa(p.targetPort))
|
||||
|
||||
sshClient, err := p.dialBackend(ctx, targetAddr, session.User(), jwtToken)
|
||||
if err != nil {
|
||||
_, _ = fmt.Fprintf(p.stderr, "SSH connection to NetBird server failed: %v\n", err)
|
||||
return
|
||||
}
|
||||
defer func() { _ = sshClient.Close() }()
|
||||
|
||||
serverSession, err := sshClient.NewSession()
|
||||
if err != nil {
|
||||
_, _ = fmt.Fprintf(p.stderr, "create server session: %v\n", err)
|
||||
return
|
||||
}
|
||||
defer func() { _ = serverSession.Close() }()
|
||||
|
||||
serverSession.Stdin = session
|
||||
serverSession.Stdout = session
|
||||
serverSession.Stderr = session.Stderr()
|
||||
|
||||
ptyReq, winCh, isPty := session.Pty()
|
||||
if isPty {
|
||||
if err := serverSession.RequestPty(ptyReq.Term, ptyReq.Window.Width, ptyReq.Window.Height, nil); err != nil {
|
||||
log.Debugf("PTY request to backend: %v", err)
|
||||
}
|
||||
|
||||
go func() {
|
||||
for win := range winCh {
|
||||
if err := serverSession.WindowChange(win.Height, win.Width); err != nil {
|
||||
log.Debugf("window change: %v", err)
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
if len(session.Command()) > 0 {
|
||||
if err := serverSession.Run(strings.Join(session.Command(), " ")); err != nil {
|
||||
log.Debugf("run command: %v", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if err = serverSession.Shell(); err != nil {
|
||||
log.Debugf("start shell: %v", err)
|
||||
return
|
||||
}
|
||||
if err := serverSession.Wait(); err != nil {
|
||||
log.Debugf("session wait: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func generateHostKey() (ssh.Signer, error) {
|
||||
keyPEM, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("generate ED25519 key: %w", err)
|
||||
}
|
||||
|
||||
signer, err := cryptossh.ParsePrivateKey(keyPEM)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parse private key: %w", err)
|
||||
}
|
||||
|
||||
return signer, nil
|
||||
}
|
||||
|
||||
type stdioConn struct {
|
||||
stdin io.Reader
|
||||
stdout io.Writer
|
||||
closed bool
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
func (c *stdioConn) Read(b []byte) (n int, err error) {
|
||||
c.mu.Lock()
|
||||
if c.closed {
|
||||
c.mu.Unlock()
|
||||
return 0, io.EOF
|
||||
}
|
||||
c.mu.Unlock()
|
||||
return c.stdin.Read(b)
|
||||
}
|
||||
|
||||
func (c *stdioConn) Write(b []byte) (n int, err error) {
|
||||
c.mu.Lock()
|
||||
if c.closed {
|
||||
c.mu.Unlock()
|
||||
return 0, io.ErrClosedPipe
|
||||
}
|
||||
c.mu.Unlock()
|
||||
return c.stdout.Write(b)
|
||||
}
|
||||
|
||||
func (c *stdioConn) Close() error {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
c.closed = true
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *stdioConn) LocalAddr() net.Addr {
|
||||
return &net.UnixAddr{Name: "stdio", Net: "unix"}
|
||||
}
|
||||
|
||||
func (c *stdioConn) RemoteAddr() net.Addr {
|
||||
return &net.UnixAddr{Name: "stdio", Net: "unix"}
|
||||
}
|
||||
|
||||
func (c *stdioConn) SetDeadline(_ time.Time) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *stdioConn) SetReadDeadline(_ time.Time) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *stdioConn) SetWriteDeadline(_ time.Time) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *SSHProxy) directTCPIPHandler(_ *ssh.Server, _ *cryptossh.ServerConn, newChan cryptossh.NewChannel, _ ssh.Context) {
|
||||
_ = newChan.Reject(cryptossh.Prohibited, "port forwarding not supported in proxy")
|
||||
}
|
||||
|
||||
func (p *SSHProxy) sftpSubsystemHandler(s ssh.Session, jwtToken string) {
|
||||
ctx, cancel := context.WithCancel(s.Context())
|
||||
defer cancel()
|
||||
|
||||
targetAddr := net.JoinHostPort(p.targetHost, strconv.Itoa(p.targetPort))
|
||||
|
||||
sshClient, err := p.dialBackend(ctx, targetAddr, s.User(), jwtToken)
|
||||
if err != nil {
|
||||
_, _ = fmt.Fprintf(s, "SSH connection failed: %v\n", err)
|
||||
_ = s.Exit(1)
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
if err := sshClient.Close(); err != nil {
|
||||
log.Debugf("close SSH client: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
serverSession, err := sshClient.NewSession()
|
||||
if err != nil {
|
||||
_, _ = fmt.Fprintf(s, "create server session: %v\n", err)
|
||||
_ = s.Exit(1)
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
if err := serverSession.Close(); err != nil {
|
||||
log.Debugf("close server session: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
stdin, stdout, err := p.setupSFTPPipes(serverSession)
|
||||
if err != nil {
|
||||
log.Debugf("setup SFTP pipes: %v", err)
|
||||
_ = s.Exit(1)
|
||||
return
|
||||
}
|
||||
|
||||
if err := serverSession.RequestSubsystem("sftp"); err != nil {
|
||||
_, _ = fmt.Fprintf(s, "SFTP subsystem request failed: %v\n", err)
|
||||
_ = s.Exit(1)
|
||||
return
|
||||
}
|
||||
|
||||
p.runSFTPBridge(ctx, s, stdin, stdout, serverSession)
|
||||
}
|
||||
|
||||
func (p *SSHProxy) setupSFTPPipes(serverSession *cryptossh.Session) (io.WriteCloser, io.Reader, error) {
|
||||
stdin, err := serverSession.StdinPipe()
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("get stdin pipe: %w", err)
|
||||
}
|
||||
|
||||
stdout, err := serverSession.StdoutPipe()
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("get stdout pipe: %w", err)
|
||||
}
|
||||
|
||||
return stdin, stdout, nil
|
||||
}
|
||||
|
||||
func (p *SSHProxy) runSFTPBridge(ctx context.Context, s ssh.Session, stdin io.WriteCloser, stdout io.Reader, serverSession *cryptossh.Session) {
|
||||
copyErrCh := make(chan error, 2)
|
||||
|
||||
go func() {
|
||||
_, err := io.Copy(stdin, s)
|
||||
if err != nil {
|
||||
log.Debugf("SFTP client to server copy: %v", err)
|
||||
}
|
||||
if err := stdin.Close(); err != nil {
|
||||
log.Debugf("close stdin: %v", err)
|
||||
}
|
||||
copyErrCh <- err
|
||||
}()
|
||||
|
||||
go func() {
|
||||
_, err := io.Copy(s, stdout)
|
||||
if err != nil {
|
||||
log.Debugf("SFTP server to client copy: %v", err)
|
||||
}
|
||||
copyErrCh <- err
|
||||
}()
|
||||
|
||||
go func() {
|
||||
<-ctx.Done()
|
||||
if err := serverSession.Close(); err != nil {
|
||||
log.Debugf("force close server session on context cancellation: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
for i := 0; i < 2; i++ {
|
||||
if err := <-copyErrCh; err != nil && !errors.Is(err, io.EOF) {
|
||||
log.Debugf("SFTP copy error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
if err := serverSession.Wait(); err != nil {
|
||||
log.Debugf("SFTP session ended: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (p *SSHProxy) tcpipForwardHandler(_ ssh.Context, _ *ssh.Server, _ *cryptossh.Request) (bool, []byte) {
|
||||
return false, []byte("port forwarding not supported in proxy")
|
||||
}
|
||||
|
||||
func (p *SSHProxy) cancelTcpipForwardHandler(_ ssh.Context, _ *ssh.Server, _ *cryptossh.Request) (bool, []byte) {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func (p *SSHProxy) dialBackend(ctx context.Context, addr, user, jwtToken string) (*cryptossh.Client, error) {
|
||||
config := &cryptossh.ClientConfig{
|
||||
User: user,
|
||||
Auth: []cryptossh.AuthMethod{cryptossh.Password(jwtToken)},
|
||||
Timeout: sshHandshakeTimeout,
|
||||
HostKeyCallback: p.verifyHostKey,
|
||||
}
|
||||
|
||||
dialer := &net.Dialer{
|
||||
Timeout: sshConnectionTimeout,
|
||||
}
|
||||
conn, err := dialer.DialContext(ctx, "tcp", addr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("connect to server: %w", err)
|
||||
}
|
||||
|
||||
clientConn, chans, reqs, err := cryptossh.NewClientConn(conn, addr, config)
|
||||
if err != nil {
|
||||
_ = conn.Close()
|
||||
return nil, fmt.Errorf("SSH handshake: %w", err)
|
||||
}
|
||||
|
||||
return cryptossh.NewClient(clientConn, chans, reqs), nil
|
||||
}
|
||||
|
||||
func (p *SSHProxy) verifyHostKey(hostname string, remote net.Addr, key cryptossh.PublicKey) error {
|
||||
verifier := nbssh.NewDaemonHostKeyVerifier(p.daemonClient)
|
||||
callback := nbssh.CreateHostKeyCallback(verifier)
|
||||
return callback(hostname, remote, key)
|
||||
}
|
||||
367
client/ssh/proxy/proxy_test.go
Normal file
367
client/ssh/proxy/proxy_test.go
Normal file
@@ -0,0 +1,367 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"math/big"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
cryptossh "golang.org/x/crypto/ssh"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/credentials/insecure"
|
||||
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
nbssh "github.com/netbirdio/netbird/client/ssh"
|
||||
"github.com/netbirdio/netbird/client/ssh/server"
|
||||
"github.com/netbirdio/netbird/client/ssh/testutil"
|
||||
nbjwt "github.com/netbirdio/netbird/shared/auth/jwt"
|
||||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
if len(os.Args) > 2 && os.Args[1] == "ssh" {
|
||||
if os.Args[2] == "exec" {
|
||||
if len(os.Args) > 3 {
|
||||
cmd := os.Args[3]
|
||||
if cmd == "echo" && len(os.Args) > 4 {
|
||||
fmt.Fprintln(os.Stdout, os.Args[4])
|
||||
os.Exit(0)
|
||||
}
|
||||
}
|
||||
fmt.Fprintf(os.Stderr, "Test binary called as 'ssh exec' with args: %v - preventing infinite recursion\n", os.Args)
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
||||
|
||||
code := m.Run()
|
||||
|
||||
testutil.CleanupTestUsers()
|
||||
|
||||
os.Exit(code)
|
||||
}
|
||||
|
||||
func TestSSHProxy_verifyHostKey(t *testing.T) {
|
||||
t.Run("calls daemon to verify host key", func(t *testing.T) {
|
||||
mockDaemon := startMockDaemon(t)
|
||||
defer mockDaemon.stop()
|
||||
|
||||
grpcConn, err := grpc.NewClient(mockDaemon.addr, grpc.WithTransportCredentials(insecure.NewCredentials()))
|
||||
require.NoError(t, err)
|
||||
defer func() { _ = grpcConn.Close() }()
|
||||
|
||||
proxy := &SSHProxy{
|
||||
daemonAddr: mockDaemon.addr,
|
||||
daemonClient: proto.NewDaemonServiceClient(grpcConn),
|
||||
}
|
||||
|
||||
testKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
||||
require.NoError(t, err)
|
||||
testPubKey, err := nbssh.GeneratePublicKey(testKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
mockDaemon.setHostKey("test-host", testPubKey)
|
||||
|
||||
err = proxy.verifyHostKey("test-host", &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 22}, mustParsePublicKey(t, testPubKey))
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("rejects unknown host key", func(t *testing.T) {
|
||||
mockDaemon := startMockDaemon(t)
|
||||
defer mockDaemon.stop()
|
||||
|
||||
grpcConn, err := grpc.NewClient(mockDaemon.addr, grpc.WithTransportCredentials(insecure.NewCredentials()))
|
||||
require.NoError(t, err)
|
||||
defer func() { _ = grpcConn.Close() }()
|
||||
|
||||
proxy := &SSHProxy{
|
||||
daemonAddr: mockDaemon.addr,
|
||||
daemonClient: proto.NewDaemonServiceClient(grpcConn),
|
||||
}
|
||||
|
||||
unknownKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
||||
require.NoError(t, err)
|
||||
unknownPubKey, err := nbssh.GeneratePublicKey(unknownKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = proxy.verifyHostKey("unknown-host", &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 22}, mustParsePublicKey(t, unknownPubKey))
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "peer unknown-host not found in network")
|
||||
})
|
||||
}
|
||||
|
||||
func TestSSHProxy_Connect(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("Skipping integration test in short mode")
|
||||
}
|
||||
|
||||
// TODO: Windows test times out - user switching and command execution tested on Linux
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("Skipping on Windows - covered by Linux tests")
|
||||
}
|
||||
|
||||
const (
|
||||
issuer = "https://test-issuer.example.com"
|
||||
audience = "test-audience"
|
||||
)
|
||||
|
||||
jwksServer, privateKey, jwksURL := setupJWKSServer(t)
|
||||
defer jwksServer.Close()
|
||||
|
||||
hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
||||
require.NoError(t, err)
|
||||
hostPubKey, err := nbssh.GeneratePublicKey(hostKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
serverConfig := &server.Config{
|
||||
HostKeyPEM: hostKey,
|
||||
JWT: &server.JWTConfig{
|
||||
Issuer: issuer,
|
||||
Audience: audience,
|
||||
KeysLocation: jwksURL,
|
||||
},
|
||||
}
|
||||
sshServer := server.New(serverConfig)
|
||||
sshServer.SetAllowRootLogin(true)
|
||||
|
||||
sshServerAddr := server.StartTestServer(t, sshServer)
|
||||
defer func() { _ = sshServer.Stop() }()
|
||||
|
||||
mockDaemon := startMockDaemon(t)
|
||||
defer mockDaemon.stop()
|
||||
|
||||
host, portStr, err := net.SplitHostPort(sshServerAddr)
|
||||
require.NoError(t, err)
|
||||
port, err := strconv.Atoi(portStr)
|
||||
require.NoError(t, err)
|
||||
|
||||
mockDaemon.setHostKey(host, hostPubKey)
|
||||
|
||||
validToken := generateValidJWT(t, privateKey, issuer, audience)
|
||||
mockDaemon.setJWTToken(validToken)
|
||||
|
||||
proxyInstance, err := New(mockDaemon.addr, host, port, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
clientConn, proxyConn := net.Pipe()
|
||||
defer func() { _ = clientConn.Close() }()
|
||||
|
||||
origStdin := os.Stdin
|
||||
origStdout := os.Stdout
|
||||
defer func() {
|
||||
os.Stdin = origStdin
|
||||
os.Stdout = origStdout
|
||||
}()
|
||||
|
||||
stdinReader, stdinWriter, err := os.Pipe()
|
||||
require.NoError(t, err)
|
||||
stdoutReader, stdoutWriter, err := os.Pipe()
|
||||
require.NoError(t, err)
|
||||
|
||||
os.Stdin = stdinReader
|
||||
os.Stdout = stdoutWriter
|
||||
|
||||
go func() {
|
||||
_, _ = io.Copy(stdinWriter, proxyConn)
|
||||
}()
|
||||
go func() {
|
||||
_, _ = io.Copy(proxyConn, stdoutReader)
|
||||
}()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
connectErrCh := make(chan error, 1)
|
||||
go func() {
|
||||
connectErrCh <- proxyInstance.Connect(ctx)
|
||||
}()
|
||||
|
||||
sshConfig := &cryptossh.ClientConfig{
|
||||
User: testutil.GetTestUsername(t),
|
||||
Auth: []cryptossh.AuthMethod{},
|
||||
HostKeyCallback: cryptossh.InsecureIgnoreHostKey(),
|
||||
Timeout: 3 * time.Second,
|
||||
}
|
||||
|
||||
sshClientConn, chans, reqs, err := cryptossh.NewClientConn(clientConn, "test", sshConfig)
|
||||
require.NoError(t, err, "Should connect to proxy server")
|
||||
defer func() { _ = sshClientConn.Close() }()
|
||||
|
||||
sshClient := cryptossh.NewClient(sshClientConn, chans, reqs)
|
||||
|
||||
session, err := sshClient.NewSession()
|
||||
require.NoError(t, err, "Should create session through full proxy to backend")
|
||||
|
||||
outputCh := make(chan []byte, 1)
|
||||
errCh := make(chan error, 1)
|
||||
go func() {
|
||||
output, err := session.Output("echo hello-from-proxy")
|
||||
outputCh <- output
|
||||
errCh <- err
|
||||
}()
|
||||
|
||||
select {
|
||||
case output := <-outputCh:
|
||||
err := <-errCh
|
||||
require.NoError(t, err, "Command should execute successfully through proxy")
|
||||
assert.Contains(t, string(output), "hello-from-proxy", "Should receive command output through proxy")
|
||||
case <-time.After(3 * time.Second):
|
||||
t.Fatal("Command execution timed out")
|
||||
}
|
||||
|
||||
_ = session.Close()
|
||||
_ = sshClient.Close()
|
||||
_ = clientConn.Close()
|
||||
cancel()
|
||||
}
|
||||
|
||||
type mockDaemonServer struct {
|
||||
proto.UnimplementedDaemonServiceServer
|
||||
hostKeys map[string][]byte
|
||||
jwtToken string
|
||||
}
|
||||
|
||||
func (m *mockDaemonServer) GetPeerSSHHostKey(ctx context.Context, req *proto.GetPeerSSHHostKeyRequest) (*proto.GetPeerSSHHostKeyResponse, error) {
|
||||
key, found := m.hostKeys[req.PeerAddress]
|
||||
return &proto.GetPeerSSHHostKeyResponse{
|
||||
Found: found,
|
||||
SshHostKey: key,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (m *mockDaemonServer) RequestJWTAuth(ctx context.Context, req *proto.RequestJWTAuthRequest) (*proto.RequestJWTAuthResponse, error) {
|
||||
return &proto.RequestJWTAuthResponse{
|
||||
CachedToken: m.jwtToken,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (m *mockDaemonServer) WaitJWTToken(ctx context.Context, req *proto.WaitJWTTokenRequest) (*proto.WaitJWTTokenResponse, error) {
|
||||
return &proto.WaitJWTTokenResponse{
|
||||
Token: m.jwtToken,
|
||||
}, nil
|
||||
}
|
||||
|
||||
type mockDaemon struct {
|
||||
addr string
|
||||
server *grpc.Server
|
||||
impl *mockDaemonServer
|
||||
}
|
||||
|
||||
func startMockDaemon(t *testing.T) *mockDaemon {
|
||||
t.Helper()
|
||||
|
||||
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
|
||||
impl := &mockDaemonServer{
|
||||
hostKeys: make(map[string][]byte),
|
||||
jwtToken: "test-jwt-token",
|
||||
}
|
||||
|
||||
grpcServer := grpc.NewServer()
|
||||
proto.RegisterDaemonServiceServer(grpcServer, impl)
|
||||
|
||||
go func() {
|
||||
_ = grpcServer.Serve(listener)
|
||||
}()
|
||||
|
||||
return &mockDaemon{
|
||||
addr: listener.Addr().String(),
|
||||
server: grpcServer,
|
||||
impl: impl,
|
||||
}
|
||||
}
|
||||
|
||||
func (m *mockDaemon) setHostKey(addr string, pubKey []byte) {
|
||||
m.impl.hostKeys[addr] = pubKey
|
||||
}
|
||||
|
||||
func (m *mockDaemon) setJWTToken(token string) {
|
||||
m.impl.jwtToken = token
|
||||
}
|
||||
|
||||
func (m *mockDaemon) stop() {
|
||||
if m.server != nil {
|
||||
m.server.Stop()
|
||||
}
|
||||
}
|
||||
|
||||
func mustParsePublicKey(t *testing.T, pubKeyBytes []byte) cryptossh.PublicKey {
|
||||
t.Helper()
|
||||
pubKey, _, _, _, err := cryptossh.ParseAuthorizedKey(pubKeyBytes)
|
||||
require.NoError(t, err)
|
||||
return pubKey
|
||||
}
|
||||
|
||||
func setupJWKSServer(t *testing.T) (*httptest.Server, *rsa.PrivateKey, string) {
|
||||
t.Helper()
|
||||
privateKey, jwksJSON := generateTestJWKS(t)
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
if _, err := w.Write(jwksJSON); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
}
|
||||
}))
|
||||
|
||||
return server, privateKey, server.URL
|
||||
}
|
||||
|
||||
func generateTestJWKS(t *testing.T) (*rsa.PrivateKey, []byte) {
|
||||
t.Helper()
|
||||
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
require.NoError(t, err)
|
||||
|
||||
publicKey := &privateKey.PublicKey
|
||||
n := publicKey.N.Bytes()
|
||||
e := publicKey.E
|
||||
|
||||
jwk := nbjwt.JSONWebKey{
|
||||
Kty: "RSA",
|
||||
Kid: "test-key-id",
|
||||
Use: "sig",
|
||||
N: base64.RawURLEncoding.EncodeToString(n),
|
||||
E: base64.RawURLEncoding.EncodeToString(big.NewInt(int64(e)).Bytes()),
|
||||
}
|
||||
|
||||
jwks := nbjwt.Jwks{
|
||||
Keys: []nbjwt.JSONWebKey{jwk},
|
||||
}
|
||||
|
||||
jwksJSON, err := json.Marshal(jwks)
|
||||
require.NoError(t, err)
|
||||
|
||||
return privateKey, jwksJSON
|
||||
}
|
||||
|
||||
func generateValidJWT(t *testing.T, privateKey *rsa.PrivateKey, issuer, audience string) string {
|
||||
t.Helper()
|
||||
claims := jwt.MapClaims{
|
||||
"iss": issuer,
|
||||
"aud": audience,
|
||||
"sub": "test-user",
|
||||
"exp": time.Now().Add(time.Hour).Unix(),
|
||||
"iat": time.Now().Unix(),
|
||||
}
|
||||
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims)
|
||||
token.Header["kid"] = "test-key-id"
|
||||
|
||||
tokenString, err := token.SignedString(privateKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
return tokenString
|
||||
}
|
||||
@@ -1,280 +0,0 @@
|
||||
//go:build !js
|
||||
|
||||
package ssh
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"os"
|
||||
"os/exec"
|
||||
"os/user"
|
||||
"runtime"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/creack/pty"
|
||||
"github.com/gliderlabs/ssh"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// DefaultSSHPort is the default SSH port of the NetBird's embedded SSH server
|
||||
const DefaultSSHPort = 44338
|
||||
|
||||
// TerminalTimeout is the timeout for terminal session to be ready
|
||||
const TerminalTimeout = 10 * time.Second
|
||||
|
||||
// TerminalBackoffDelay is the delay between terminal session readiness checks
|
||||
const TerminalBackoffDelay = 500 * time.Millisecond
|
||||
|
||||
// DefaultSSHServer is a function that creates DefaultServer
|
||||
func DefaultSSHServer(hostKeyPEM []byte, addr string) (Server, error) {
|
||||
return newDefaultServer(hostKeyPEM, addr)
|
||||
}
|
||||
|
||||
// Server is an interface of SSH server
|
||||
type Server interface {
|
||||
// Stop stops SSH server.
|
||||
Stop() error
|
||||
// Start starts SSH server. Blocking
|
||||
Start() error
|
||||
// RemoveAuthorizedKey removes SSH key of a given peer from the authorized keys
|
||||
RemoveAuthorizedKey(peer string)
|
||||
// AddAuthorizedKey add a given peer key to server authorized keys
|
||||
AddAuthorizedKey(peer, newKey string) error
|
||||
}
|
||||
|
||||
// DefaultServer is the embedded NetBird SSH server
|
||||
type DefaultServer struct {
|
||||
listener net.Listener
|
||||
// authorizedKeys is ssh pub key indexed by peer WireGuard public key
|
||||
authorizedKeys map[string]ssh.PublicKey
|
||||
mu sync.Mutex
|
||||
hostKeyPEM []byte
|
||||
sessions []ssh.Session
|
||||
}
|
||||
|
||||
// newDefaultServer creates new server with provided host key
|
||||
func newDefaultServer(hostKeyPEM []byte, addr string) (*DefaultServer, error) {
|
||||
ln, err := net.Listen("tcp", addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
allowedKeys := make(map[string]ssh.PublicKey)
|
||||
return &DefaultServer{listener: ln, mu: sync.Mutex{}, hostKeyPEM: hostKeyPEM, authorizedKeys: allowedKeys, sessions: make([]ssh.Session, 0)}, nil
|
||||
}
|
||||
|
||||
// RemoveAuthorizedKey removes SSH key of a given peer from the authorized keys
|
||||
func (srv *DefaultServer) RemoveAuthorizedKey(peer string) {
|
||||
srv.mu.Lock()
|
||||
defer srv.mu.Unlock()
|
||||
|
||||
delete(srv.authorizedKeys, peer)
|
||||
}
|
||||
|
||||
// AddAuthorizedKey add a given peer key to server authorized keys
|
||||
func (srv *DefaultServer) AddAuthorizedKey(peer, newKey string) error {
|
||||
srv.mu.Lock()
|
||||
defer srv.mu.Unlock()
|
||||
|
||||
parsedKey, _, _, _, err := ssh.ParseAuthorizedKey([]byte(newKey))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
srv.authorizedKeys[peer] = parsedKey
|
||||
return nil
|
||||
}
|
||||
|
||||
// Stop stops SSH server.
|
||||
func (srv *DefaultServer) Stop() error {
|
||||
srv.mu.Lock()
|
||||
defer srv.mu.Unlock()
|
||||
err := srv.listener.Close()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for _, session := range srv.sessions {
|
||||
err := session.Close()
|
||||
if err != nil {
|
||||
log.Warnf("failed closing SSH session from %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (srv *DefaultServer) publicKeyHandler(ctx ssh.Context, key ssh.PublicKey) bool {
|
||||
srv.mu.Lock()
|
||||
defer srv.mu.Unlock()
|
||||
|
||||
for _, allowed := range srv.authorizedKeys {
|
||||
if ssh.KeysEqual(allowed, key) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func prepareUserEnv(user *user.User, shell string) []string {
|
||||
return []string{
|
||||
fmt.Sprint("SHELL=" + shell),
|
||||
fmt.Sprint("USER=" + user.Username),
|
||||
fmt.Sprint("HOME=" + user.HomeDir),
|
||||
}
|
||||
}
|
||||
|
||||
func acceptEnv(s string) bool {
|
||||
split := strings.Split(s, "=")
|
||||
if len(split) != 2 {
|
||||
return false
|
||||
}
|
||||
return split[0] == "TERM" || split[0] == "LANG" || strings.HasPrefix(split[0], "LC_")
|
||||
}
|
||||
|
||||
// sessionHandler handles SSH session post auth
|
||||
func (srv *DefaultServer) sessionHandler(session ssh.Session) {
|
||||
srv.mu.Lock()
|
||||
srv.sessions = append(srv.sessions, session)
|
||||
srv.mu.Unlock()
|
||||
|
||||
defer func() {
|
||||
err := session.Close()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}()
|
||||
|
||||
log.Infof("Establishing SSH session for %s from host %s", session.User(), session.RemoteAddr().String())
|
||||
|
||||
localUser, err := userNameLookup(session.User())
|
||||
if err != nil {
|
||||
_, err = fmt.Fprintf(session, "remote SSH server couldn't find local user %s\n", session.User()) //nolint
|
||||
err = session.Exit(1)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
log.Warnf("failed SSH session from %v, user %s", session.RemoteAddr(), session.User())
|
||||
return
|
||||
}
|
||||
|
||||
ptyReq, winCh, isPty := session.Pty()
|
||||
if isPty {
|
||||
loginCmd, loginArgs, err := getLoginCmd(localUser.Username, session.RemoteAddr())
|
||||
if err != nil {
|
||||
log.Warnf("failed logging-in user %s from remote IP %s", localUser.Username, session.RemoteAddr().String())
|
||||
return
|
||||
}
|
||||
cmd := exec.Command(loginCmd, loginArgs...)
|
||||
go func() {
|
||||
<-session.Context().Done()
|
||||
if cmd.Process == nil {
|
||||
return
|
||||
}
|
||||
err := cmd.Process.Kill()
|
||||
if err != nil {
|
||||
log.Debugf("failed killing SSH process %v", err)
|
||||
return
|
||||
}
|
||||
}()
|
||||
cmd.Dir = localUser.HomeDir
|
||||
cmd.Env = append(cmd.Env, fmt.Sprintf("TERM=%s", ptyReq.Term))
|
||||
cmd.Env = append(cmd.Env, prepareUserEnv(localUser, getUserShell(localUser.Uid))...)
|
||||
for _, v := range session.Environ() {
|
||||
if acceptEnv(v) {
|
||||
cmd.Env = append(cmd.Env, v)
|
||||
}
|
||||
}
|
||||
|
||||
log.Debugf("Login command: %s", cmd.String())
|
||||
file, err := pty.Start(cmd)
|
||||
if err != nil {
|
||||
log.Errorf("failed starting SSH server: %v", err)
|
||||
}
|
||||
|
||||
go func() {
|
||||
for win := range winCh {
|
||||
setWinSize(file, win.Width, win.Height)
|
||||
}
|
||||
}()
|
||||
|
||||
srv.stdInOut(file, session)
|
||||
|
||||
err = cmd.Wait()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
} else {
|
||||
_, err := io.WriteString(session, "only PTY is supported.\n")
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
err = session.Exit(1)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
log.Debugf("SSH session ended")
|
||||
}
|
||||
|
||||
func (srv *DefaultServer) stdInOut(file *os.File, session ssh.Session) {
|
||||
go func() {
|
||||
// stdin
|
||||
_, err := io.Copy(file, session)
|
||||
if err != nil {
|
||||
_ = session.Exit(1)
|
||||
return
|
||||
}
|
||||
}()
|
||||
|
||||
// AWS Linux 2 machines need some time to open the terminal so we need to wait for it
|
||||
timer := time.NewTimer(TerminalTimeout)
|
||||
for {
|
||||
select {
|
||||
case <-timer.C:
|
||||
_, _ = session.Write([]byte("Reached timeout while opening connection\n"))
|
||||
_ = session.Exit(1)
|
||||
return
|
||||
default:
|
||||
// stdout
|
||||
writtenBytes, err := io.Copy(session, file)
|
||||
if err != nil && writtenBytes != 0 {
|
||||
_ = session.Exit(0)
|
||||
return
|
||||
}
|
||||
time.Sleep(TerminalBackoffDelay)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Start starts SSH server. Blocking
|
||||
func (srv *DefaultServer) Start() error {
|
||||
log.Infof("starting SSH server on addr: %s", srv.listener.Addr().String())
|
||||
|
||||
publicKeyOption := ssh.PublicKeyAuth(srv.publicKeyHandler)
|
||||
hostKeyPEM := ssh.HostKeyPEM(srv.hostKeyPEM)
|
||||
err := ssh.Serve(srv.listener, srv.sessionHandler, publicKeyOption, hostKeyPEM)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func getUserShell(userID string) string {
|
||||
if runtime.GOOS == "linux" {
|
||||
output, _ := exec.Command("getent", "passwd", userID).Output()
|
||||
line := strings.SplitN(string(output), ":", 10)
|
||||
if len(line) > 6 {
|
||||
return strings.TrimSpace(line[6])
|
||||
}
|
||||
}
|
||||
|
||||
shell := os.Getenv("SHELL")
|
||||
if shell == "" {
|
||||
shell = "/bin/sh"
|
||||
}
|
||||
return shell
|
||||
}
|
||||
178
client/ssh/server/command_execution.go
Normal file
178
client/ssh/server/command_execution.go
Normal file
@@ -0,0 +1,178 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"os/exec"
|
||||
"time"
|
||||
|
||||
"github.com/gliderlabs/ssh"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// handleCommand executes an SSH command with privilege validation
|
||||
func (s *Server) handleCommand(logger *log.Entry, session ssh.Session, privilegeResult PrivilegeCheckResult, winCh <-chan ssh.Window) {
|
||||
hasPty := winCh != nil
|
||||
|
||||
commandType := "command"
|
||||
if hasPty {
|
||||
commandType = "Pty command"
|
||||
}
|
||||
|
||||
logger.Infof("executing %s: %s", commandType, safeLogCommand(session.Command()))
|
||||
|
||||
execCmd, err := s.createCommand(privilegeResult, session, hasPty)
|
||||
if err != nil {
|
||||
logger.Errorf("%s creation failed: %v", commandType, err)
|
||||
|
||||
errorMsg := fmt.Sprintf("Cannot create %s - platform may not support user switching", commandType)
|
||||
if hasPty {
|
||||
errorMsg += " with Pty"
|
||||
}
|
||||
errorMsg += "\n"
|
||||
|
||||
if _, writeErr := fmt.Fprint(session.Stderr(), errorMsg); writeErr != nil {
|
||||
logger.Debugf(errWriteSession, writeErr)
|
||||
}
|
||||
if err := session.Exit(1); err != nil {
|
||||
logSessionExitError(logger, err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if s.executeCommand(logger, session, execCmd) {
|
||||
logger.Debugf("%s execution completed", commandType)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) createCommand(privilegeResult PrivilegeCheckResult, session ssh.Session, hasPty bool) (*exec.Cmd, error) {
|
||||
localUser := privilegeResult.User
|
||||
|
||||
// Try su first for system integration (PAM/audit) when privileged
|
||||
cmd, err := s.createSuCommand(session, localUser, hasPty)
|
||||
if err != nil || privilegeResult.UsedFallback {
|
||||
log.Debugf("su command failed, falling back to executor: %v", err)
|
||||
cmd, err = s.createExecutorCommand(session, localUser, hasPty)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create command with privileges: %w", err)
|
||||
}
|
||||
|
||||
cmd.Env = s.prepareCommandEnv(localUser, session)
|
||||
return cmd, nil
|
||||
}
|
||||
|
||||
// executeCommand executes the command and handles I/O and exit codes
|
||||
func (s *Server) executeCommand(logger *log.Entry, session ssh.Session, execCmd *exec.Cmd) bool {
|
||||
s.setupProcessGroup(execCmd)
|
||||
|
||||
stdinPipe, err := execCmd.StdinPipe()
|
||||
if err != nil {
|
||||
logger.Errorf("create stdin pipe: %v", err)
|
||||
if err := session.Exit(1); err != nil {
|
||||
logSessionExitError(logger, err)
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
execCmd.Stdout = session
|
||||
execCmd.Stderr = session
|
||||
|
||||
if execCmd.Dir != "" {
|
||||
if _, err := os.Stat(execCmd.Dir); err != nil {
|
||||
logger.Warnf("working directory does not exist: %s (%v)", execCmd.Dir, err)
|
||||
execCmd.Dir = "/"
|
||||
}
|
||||
}
|
||||
|
||||
if err := execCmd.Start(); err != nil {
|
||||
logger.Errorf("command start failed: %v", err)
|
||||
// no user message for exec failure, just exit
|
||||
if err := session.Exit(1); err != nil {
|
||||
logSessionExitError(logger, err)
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
go s.handleCommandIO(logger, stdinPipe, session)
|
||||
return s.waitForCommandCleanup(logger, session, execCmd)
|
||||
}
|
||||
|
||||
// handleCommandIO manages stdin/stdout copying in a goroutine
|
||||
func (s *Server) handleCommandIO(logger *log.Entry, stdinPipe io.WriteCloser, session ssh.Session) {
|
||||
defer func() {
|
||||
if err := stdinPipe.Close(); err != nil {
|
||||
logger.Debugf("stdin pipe close error: %v", err)
|
||||
}
|
||||
}()
|
||||
if _, err := io.Copy(stdinPipe, session); err != nil {
|
||||
logger.Debugf("stdin copy error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// waitForCommandCleanup waits for command completion with session disconnect handling
|
||||
func (s *Server) waitForCommandCleanup(logger *log.Entry, session ssh.Session, execCmd *exec.Cmd) bool {
|
||||
ctx := session.Context()
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
done <- execCmd.Wait()
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
logger.Debugf("session cancelled, terminating command")
|
||||
s.killProcessGroup(execCmd)
|
||||
|
||||
select {
|
||||
case err := <-done:
|
||||
logger.Tracef("command terminated after session cancellation: %v", err)
|
||||
case <-time.After(5 * time.Second):
|
||||
logger.Warnf("command did not terminate within 5 seconds after session cancellation")
|
||||
}
|
||||
|
||||
if err := session.Exit(130); err != nil {
|
||||
logSessionExitError(logger, err)
|
||||
}
|
||||
return false
|
||||
|
||||
case err := <-done:
|
||||
return s.handleCommandCompletion(logger, session, err)
|
||||
}
|
||||
}
|
||||
|
||||
// handleCommandCompletion handles command completion
|
||||
func (s *Server) handleCommandCompletion(logger *log.Entry, session ssh.Session, err error) bool {
|
||||
if err != nil {
|
||||
logger.Debugf("command execution failed: %v", err)
|
||||
s.handleSessionExit(session, err, logger)
|
||||
return false
|
||||
}
|
||||
|
||||
s.handleSessionExit(session, nil, logger)
|
||||
return true
|
||||
}
|
||||
|
||||
// handleSessionExit handles command errors and sets appropriate exit codes
|
||||
func (s *Server) handleSessionExit(session ssh.Session, err error, logger *log.Entry) {
|
||||
if err == nil {
|
||||
if err := session.Exit(0); err != nil {
|
||||
logSessionExitError(logger, err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
var exitError *exec.ExitError
|
||||
if errors.As(err, &exitError) {
|
||||
if err := session.Exit(exitError.ExitCode()); err != nil {
|
||||
logSessionExitError(logger, err)
|
||||
}
|
||||
} else {
|
||||
logger.Debugf("non-exit error in command execution: %v", err)
|
||||
if err := session.Exit(1); err != nil {
|
||||
logSessionExitError(logger, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
36
client/ssh/server/command_execution_js.go
Normal file
36
client/ssh/server/command_execution_js.go
Normal file
@@ -0,0 +1,36 @@
|
||||
//go:build js
|
||||
|
||||
package server
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"os/exec"
|
||||
"os/user"
|
||||
|
||||
"github.com/gliderlabs/ssh"
|
||||
)
|
||||
|
||||
var errNotSupported = errors.New("SSH server command execution not supported on WASM/JS platform")
|
||||
|
||||
// createSuCommand is not supported on JS/WASM
|
||||
func (s *Server) createSuCommand(_ ssh.Session, _ *user.User, _ bool) (*exec.Cmd, error) {
|
||||
return nil, errNotSupported
|
||||
}
|
||||
|
||||
// createExecutorCommand is not supported on JS/WASM
|
||||
func (s *Server) createExecutorCommand(_ ssh.Session, _ *user.User, _ bool) (*exec.Cmd, error) {
|
||||
return nil, errNotSupported
|
||||
}
|
||||
|
||||
// prepareCommandEnv is not supported on JS/WASM
|
||||
func (s *Server) prepareCommandEnv(_ *user.User, _ ssh.Session) []string {
|
||||
return nil
|
||||
}
|
||||
|
||||
// setupProcessGroup is not supported on JS/WASM
|
||||
func (s *Server) setupProcessGroup(_ *exec.Cmd) {
|
||||
}
|
||||
|
||||
// killProcessGroup is not supported on JS/WASM
|
||||
func (s *Server) killProcessGroup(_ *exec.Cmd) {
|
||||
}
|
||||
278
client/ssh/server/command_execution_unix.go
Normal file
278
client/ssh/server/command_execution_unix.go
Normal file
@@ -0,0 +1,278 @@
|
||||
//go:build unix
|
||||
|
||||
package server
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"os/exec"
|
||||
"os/user"
|
||||
"sync"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/creack/pty"
|
||||
"github.com/gliderlabs/ssh"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// createSuCommand creates a command using su -l -c for privilege switching
|
||||
func (s *Server) createSuCommand(session ssh.Session, localUser *user.User, hasPty bool) (*exec.Cmd, error) {
|
||||
suPath, err := exec.LookPath("su")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("su command not available: %w", err)
|
||||
}
|
||||
|
||||
command := session.RawCommand()
|
||||
if command == "" {
|
||||
return nil, fmt.Errorf("no command specified for su execution")
|
||||
}
|
||||
|
||||
// TODO: handle pty flag if available
|
||||
args := []string{"-l", localUser.Username, "-c", command}
|
||||
|
||||
cmd := exec.CommandContext(session.Context(), suPath, args...)
|
||||
cmd.Dir = localUser.HomeDir
|
||||
|
||||
return cmd, nil
|
||||
}
|
||||
|
||||
// getShellCommandArgs returns the shell command and arguments for executing a command string
|
||||
func (s *Server) getShellCommandArgs(shell, cmdString string) []string {
|
||||
if cmdString == "" {
|
||||
return []string{shell, "-l"}
|
||||
}
|
||||
return []string{shell, "-l", "-c", cmdString}
|
||||
}
|
||||
|
||||
// prepareCommandEnv prepares environment variables for command execution on Unix
|
||||
func (s *Server) prepareCommandEnv(localUser *user.User, session ssh.Session) []string {
|
||||
env := prepareUserEnv(localUser, getUserShell(localUser.Uid))
|
||||
env = append(env, prepareSSHEnv(session)...)
|
||||
for _, v := range session.Environ() {
|
||||
if acceptEnv(v) {
|
||||
env = append(env, v)
|
||||
}
|
||||
}
|
||||
return env
|
||||
}
|
||||
|
||||
// ptyManager manages Pty file operations with thread safety
|
||||
type ptyManager struct {
|
||||
file *os.File
|
||||
mu sync.RWMutex
|
||||
closed bool
|
||||
closeErr error
|
||||
once sync.Once
|
||||
}
|
||||
|
||||
func newPtyManager(file *os.File) *ptyManager {
|
||||
return &ptyManager{file: file}
|
||||
}
|
||||
|
||||
func (pm *ptyManager) Close() error {
|
||||
pm.once.Do(func() {
|
||||
pm.mu.Lock()
|
||||
pm.closed = true
|
||||
pm.closeErr = pm.file.Close()
|
||||
pm.mu.Unlock()
|
||||
})
|
||||
pm.mu.RLock()
|
||||
defer pm.mu.RUnlock()
|
||||
return pm.closeErr
|
||||
}
|
||||
|
||||
func (pm *ptyManager) Setsize(ws *pty.Winsize) error {
|
||||
pm.mu.RLock()
|
||||
defer pm.mu.RUnlock()
|
||||
if pm.closed {
|
||||
return errors.New("Pty is closed")
|
||||
}
|
||||
return pty.Setsize(pm.file, ws)
|
||||
}
|
||||
|
||||
func (pm *ptyManager) File() *os.File {
|
||||
return pm.file
|
||||
}
|
||||
|
||||
func (s *Server) handlePty(logger *log.Entry, session ssh.Session, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, winCh <-chan ssh.Window) bool {
|
||||
execCmd, err := s.createPtyCommand(privilegeResult, ptyReq, session)
|
||||
if err != nil {
|
||||
logger.Errorf("Pty command creation failed: %v", err)
|
||||
errorMsg := "User switching failed - login command not available\r\n"
|
||||
if _, writeErr := fmt.Fprint(session.Stderr(), errorMsg); writeErr != nil {
|
||||
logger.Debugf(errWriteSession, writeErr)
|
||||
}
|
||||
if err := session.Exit(1); err != nil {
|
||||
logSessionExitError(logger, err)
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
shell := execCmd.Path
|
||||
cmd := session.Command()
|
||||
if len(cmd) == 0 {
|
||||
logger.Infof("starting interactive shell: %s", shell)
|
||||
} else {
|
||||
logger.Infof("executing command: %s", safeLogCommand(cmd))
|
||||
}
|
||||
|
||||
ptmx, err := s.startPtyCommandWithSize(execCmd, ptyReq)
|
||||
if err != nil {
|
||||
logger.Errorf("Pty start failed: %v", err)
|
||||
if err := session.Exit(1); err != nil {
|
||||
logSessionExitError(logger, err)
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
ptyMgr := newPtyManager(ptmx)
|
||||
defer func() {
|
||||
if err := ptyMgr.Close(); err != nil {
|
||||
logger.Debugf("Pty close error: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
go s.handlePtyWindowResize(logger, session, ptyMgr, winCh)
|
||||
s.handlePtyIO(logger, session, ptyMgr)
|
||||
s.waitForPtyCompletion(logger, session, execCmd, ptyMgr)
|
||||
return true
|
||||
}
|
||||
|
||||
func (s *Server) startPtyCommandWithSize(execCmd *exec.Cmd, ptyReq ssh.Pty) (*os.File, error) {
|
||||
winSize := &pty.Winsize{
|
||||
Cols: uint16(ptyReq.Window.Width),
|
||||
Rows: uint16(ptyReq.Window.Height),
|
||||
}
|
||||
if winSize.Cols == 0 {
|
||||
winSize.Cols = 80
|
||||
}
|
||||
if winSize.Rows == 0 {
|
||||
winSize.Rows = 24
|
||||
}
|
||||
|
||||
ptmx, err := pty.StartWithSize(execCmd, winSize)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("start Pty: %w", err)
|
||||
}
|
||||
|
||||
return ptmx, nil
|
||||
}
|
||||
|
||||
func (s *Server) handlePtyWindowResize(logger *log.Entry, session ssh.Session, ptyMgr *ptyManager, winCh <-chan ssh.Window) {
|
||||
for {
|
||||
select {
|
||||
case <-session.Context().Done():
|
||||
return
|
||||
case win, ok := <-winCh:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
if err := ptyMgr.Setsize(&pty.Winsize{Rows: uint16(win.Height), Cols: uint16(win.Width)}); err != nil {
|
||||
logger.Debugf("Pty resize to %dx%d: %v", win.Width, win.Height, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) handlePtyIO(logger *log.Entry, session ssh.Session, ptyMgr *ptyManager) {
|
||||
ptmx := ptyMgr.File()
|
||||
|
||||
go func() {
|
||||
if _, err := io.Copy(ptmx, session); err != nil {
|
||||
if !errors.Is(err, io.EOF) && !errors.Is(err, syscall.EIO) {
|
||||
logger.Warnf("Pty input copy error: %v", err)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
go func() {
|
||||
defer func() {
|
||||
if err := session.Close(); err != nil && !errors.Is(err, io.EOF) {
|
||||
logger.Debugf("session close error: %v", err)
|
||||
}
|
||||
}()
|
||||
if _, err := io.Copy(session, ptmx); err != nil {
|
||||
if !errors.Is(err, io.EOF) && !errors.Is(err, syscall.EIO) {
|
||||
logger.Warnf("Pty output copy error: %v", err)
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func (s *Server) waitForPtyCompletion(logger *log.Entry, session ssh.Session, execCmd *exec.Cmd, ptyMgr *ptyManager) {
|
||||
ctx := session.Context()
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
done <- execCmd.Wait()
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
s.handlePtySessionCancellation(logger, session, execCmd, ptyMgr, done)
|
||||
case err := <-done:
|
||||
s.handlePtyCommandCompletion(logger, session, err)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) handlePtySessionCancellation(logger *log.Entry, session ssh.Session, execCmd *exec.Cmd, ptyMgr *ptyManager, done <-chan error) {
|
||||
logger.Debugf("Pty session cancelled, terminating command")
|
||||
if err := ptyMgr.Close(); err != nil {
|
||||
logger.Debugf("Pty close during session cancellation: %v", err)
|
||||
}
|
||||
|
||||
s.killProcessGroup(execCmd)
|
||||
|
||||
select {
|
||||
case err := <-done:
|
||||
if err != nil {
|
||||
logger.Debugf("Pty command terminated after session cancellation with error: %v", err)
|
||||
} else {
|
||||
logger.Debugf("Pty command terminated after session cancellation")
|
||||
}
|
||||
case <-time.After(5 * time.Second):
|
||||
logger.Warnf("Pty command did not terminate within 5 seconds after session cancellation")
|
||||
}
|
||||
|
||||
if err := session.Exit(130); err != nil {
|
||||
logSessionExitError(logger, err)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) handlePtyCommandCompletion(logger *log.Entry, session ssh.Session, err error) {
|
||||
if err != nil {
|
||||
logger.Debugf("Pty command execution failed: %v", err)
|
||||
s.handleSessionExit(session, err, logger)
|
||||
return
|
||||
}
|
||||
|
||||
// Normal completion
|
||||
logger.Debugf("Pty command completed successfully")
|
||||
if err := session.Exit(0); err != nil {
|
||||
logSessionExitError(logger, err)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) setupProcessGroup(cmd *exec.Cmd) {
|
||||
cmd.SysProcAttr = &syscall.SysProcAttr{
|
||||
Setpgid: true,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) killProcessGroup(cmd *exec.Cmd) {
|
||||
if cmd.Process == nil {
|
||||
return
|
||||
}
|
||||
|
||||
logger := log.WithField("pid", cmd.Process.Pid)
|
||||
pgid := cmd.Process.Pid
|
||||
|
||||
if err := syscall.Kill(-pgid, syscall.SIGTERM); err != nil {
|
||||
logger.Debugf("kill process group SIGTERM failed: %v", err)
|
||||
if err := syscall.Kill(-pgid, syscall.SIGKILL); err != nil {
|
||||
logger.Debugf("kill process group SIGKILL failed: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
410
client/ssh/server/command_execution_windows.go
Normal file
410
client/ssh/server/command_execution_windows.go
Normal file
@@ -0,0 +1,410 @@
|
||||
//go:build windows
|
||||
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"os/user"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strings"
|
||||
"unsafe"
|
||||
|
||||
"github.com/gliderlabs/ssh"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/sys/windows"
|
||||
"golang.org/x/sys/windows/registry"
|
||||
|
||||
"github.com/netbirdio/netbird/client/ssh/server/winpty"
|
||||
)
|
||||
|
||||
// getUserEnvironment retrieves the Windows environment for the target user.
|
||||
// Follows OpenSSH's resilient approach with graceful degradation on failures.
|
||||
func (s *Server) getUserEnvironment(username, domain string) ([]string, error) {
|
||||
userToken, err := s.getUserToken(username, domain)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get user token: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
if err := windows.CloseHandle(userToken); err != nil {
|
||||
log.Debugf("close user token: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
userProfile, err := s.loadUserProfile(userToken, username, domain)
|
||||
if err != nil {
|
||||
log.Debugf("failed to load user profile for %s\\%s: %v", domain, username, err)
|
||||
userProfile = fmt.Sprintf("C:\\Users\\%s", username)
|
||||
}
|
||||
|
||||
envMap := make(map[string]string)
|
||||
|
||||
if err := s.loadSystemEnvironment(envMap); err != nil {
|
||||
log.Debugf("failed to load system environment from registry: %v", err)
|
||||
}
|
||||
|
||||
s.setUserEnvironmentVariables(envMap, userProfile, username, domain)
|
||||
|
||||
var env []string
|
||||
for key, value := range envMap {
|
||||
env = append(env, key+"="+value)
|
||||
}
|
||||
|
||||
return env, nil
|
||||
}
|
||||
|
||||
// getUserToken creates a user token for the specified user.
|
||||
func (s *Server) getUserToken(username, domain string) (windows.Handle, error) {
|
||||
privilegeDropper := NewPrivilegeDropper()
|
||||
token, err := privilegeDropper.createToken(username, domain)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("generate S4U user token: %w", err)
|
||||
}
|
||||
return token, nil
|
||||
}
|
||||
|
||||
// loadUserProfile loads the Windows user profile and returns the profile path.
|
||||
func (s *Server) loadUserProfile(userToken windows.Handle, username, domain string) (string, error) {
|
||||
usernamePtr, err := windows.UTF16PtrFromString(username)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("convert username to UTF-16: %w", err)
|
||||
}
|
||||
|
||||
var domainUTF16 *uint16
|
||||
if domain != "" && domain != "." {
|
||||
domainUTF16, err = windows.UTF16PtrFromString(domain)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("convert domain to UTF-16: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
type profileInfo struct {
|
||||
dwSize uint32
|
||||
dwFlags uint32
|
||||
lpUserName *uint16
|
||||
lpProfilePath *uint16
|
||||
lpDefaultPath *uint16
|
||||
lpServerName *uint16
|
||||
lpPolicyPath *uint16
|
||||
hProfile windows.Handle
|
||||
}
|
||||
|
||||
const PI_NOUI = 0x00000001
|
||||
|
||||
profile := profileInfo{
|
||||
dwSize: uint32(unsafe.Sizeof(profileInfo{})),
|
||||
dwFlags: PI_NOUI,
|
||||
lpUserName: usernamePtr,
|
||||
lpServerName: domainUTF16,
|
||||
}
|
||||
|
||||
userenv := windows.NewLazySystemDLL("userenv.dll")
|
||||
loadUserProfileW := userenv.NewProc("LoadUserProfileW")
|
||||
|
||||
ret, _, err := loadUserProfileW.Call(
|
||||
uintptr(userToken),
|
||||
uintptr(unsafe.Pointer(&profile)),
|
||||
)
|
||||
|
||||
if ret == 0 {
|
||||
return "", fmt.Errorf("LoadUserProfileW: %w", err)
|
||||
}
|
||||
|
||||
if profile.lpProfilePath == nil {
|
||||
return "", fmt.Errorf("LoadUserProfileW returned null profile path")
|
||||
}
|
||||
|
||||
profilePath := windows.UTF16PtrToString(profile.lpProfilePath)
|
||||
return profilePath, nil
|
||||
}
|
||||
|
||||
// loadSystemEnvironment loads system-wide environment variables from registry.
|
||||
func (s *Server) loadSystemEnvironment(envMap map[string]string) error {
|
||||
key, err := registry.OpenKey(registry.LOCAL_MACHINE,
|
||||
`SYSTEM\CurrentControlSet\Control\Session Manager\Environment`,
|
||||
registry.QUERY_VALUE)
|
||||
if err != nil {
|
||||
return fmt.Errorf("open system environment registry key: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
if err := key.Close(); err != nil {
|
||||
log.Debugf("close registry key: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
return s.readRegistryEnvironment(key, envMap)
|
||||
}
|
||||
|
||||
// readRegistryEnvironment reads environment variables from a registry key.
|
||||
func (s *Server) readRegistryEnvironment(key registry.Key, envMap map[string]string) error {
|
||||
names, err := key.ReadValueNames(0)
|
||||
if err != nil {
|
||||
return fmt.Errorf("read registry value names: %w", err)
|
||||
}
|
||||
|
||||
for _, name := range names {
|
||||
value, valueType, err := key.GetStringValue(name)
|
||||
if err != nil {
|
||||
log.Debugf("failed to read registry value %s: %v", name, err)
|
||||
continue
|
||||
}
|
||||
|
||||
finalValue := s.expandRegistryValue(value, valueType, name)
|
||||
s.setEnvironmentVariable(envMap, name, finalValue)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// expandRegistryValue expands registry values if they contain environment variables.
|
||||
func (s *Server) expandRegistryValue(value string, valueType uint32, name string) string {
|
||||
if valueType != registry.EXPAND_SZ {
|
||||
return value
|
||||
}
|
||||
|
||||
sourcePtr := windows.StringToUTF16Ptr(value)
|
||||
expandedBuffer := make([]uint16, 1024)
|
||||
expandedLen, err := windows.ExpandEnvironmentStrings(sourcePtr, &expandedBuffer[0], uint32(len(expandedBuffer)))
|
||||
if err != nil {
|
||||
log.Debugf("failed to expand environment string for %s: %v", name, err)
|
||||
return value
|
||||
}
|
||||
|
||||
// If buffer was too small, retry with larger buffer
|
||||
if expandedLen > uint32(len(expandedBuffer)) {
|
||||
expandedBuffer = make([]uint16, expandedLen)
|
||||
expandedLen, err = windows.ExpandEnvironmentStrings(sourcePtr, &expandedBuffer[0], uint32(len(expandedBuffer)))
|
||||
if err != nil {
|
||||
log.Debugf("failed to expand environment string for %s on retry: %v", name, err)
|
||||
return value
|
||||
}
|
||||
}
|
||||
|
||||
if expandedLen > 0 && expandedLen <= uint32(len(expandedBuffer)) {
|
||||
return windows.UTF16ToString(expandedBuffer[:expandedLen-1])
|
||||
}
|
||||
return value
|
||||
}
|
||||
|
||||
// setEnvironmentVariable sets an environment variable with special handling for PATH.
|
||||
func (s *Server) setEnvironmentVariable(envMap map[string]string, name, value string) {
|
||||
upperName := strings.ToUpper(name)
|
||||
|
||||
if upperName == "PATH" {
|
||||
if existing, exists := envMap["PATH"]; exists && existing != value {
|
||||
envMap["PATH"] = existing + ";" + value
|
||||
} else {
|
||||
envMap["PATH"] = value
|
||||
}
|
||||
} else {
|
||||
envMap[upperName] = value
|
||||
}
|
||||
}
|
||||
|
||||
// setUserEnvironmentVariables sets critical user-specific environment variables.
|
||||
func (s *Server) setUserEnvironmentVariables(envMap map[string]string, userProfile, username, domain string) {
|
||||
envMap["USERPROFILE"] = userProfile
|
||||
|
||||
if len(userProfile) >= 2 && userProfile[1] == ':' {
|
||||
envMap["HOMEDRIVE"] = userProfile[:2]
|
||||
envMap["HOMEPATH"] = userProfile[2:]
|
||||
}
|
||||
|
||||
envMap["APPDATA"] = filepath.Join(userProfile, "AppData", "Roaming")
|
||||
envMap["LOCALAPPDATA"] = filepath.Join(userProfile, "AppData", "Local")
|
||||
|
||||
tempDir := filepath.Join(userProfile, "AppData", "Local", "Temp")
|
||||
envMap["TEMP"] = tempDir
|
||||
envMap["TMP"] = tempDir
|
||||
|
||||
envMap["USERNAME"] = username
|
||||
if domain != "" && domain != "." {
|
||||
envMap["USERDOMAIN"] = domain
|
||||
envMap["USERDNSDOMAIN"] = domain
|
||||
}
|
||||
|
||||
systemVars := []string{
|
||||
"PROCESSOR_ARCHITECTURE", "PROCESSOR_IDENTIFIER", "PROCESSOR_LEVEL", "PROCESSOR_REVISION",
|
||||
"SYSTEMDRIVE", "SYSTEMROOT", "WINDIR", "COMPUTERNAME", "OS", "PATHEXT",
|
||||
"PROGRAMFILES", "PROGRAMDATA", "ALLUSERSPROFILE", "COMSPEC",
|
||||
}
|
||||
|
||||
for _, sysVar := range systemVars {
|
||||
if sysValue := os.Getenv(sysVar); sysValue != "" {
|
||||
envMap[sysVar] = sysValue
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// prepareCommandEnv prepares environment variables for command execution on Windows
|
||||
func (s *Server) prepareCommandEnv(localUser *user.User, session ssh.Session) []string {
|
||||
username, domain := s.parseUsername(localUser.Username)
|
||||
userEnv, err := s.getUserEnvironment(username, domain)
|
||||
if err != nil {
|
||||
log.Debugf("failed to get user environment for %s\\%s, using fallback: %v", domain, username, err)
|
||||
env := prepareUserEnv(localUser, getUserShell(localUser.Uid))
|
||||
env = append(env, prepareSSHEnv(session)...)
|
||||
for _, v := range session.Environ() {
|
||||
if acceptEnv(v) {
|
||||
env = append(env, v)
|
||||
}
|
||||
}
|
||||
return env
|
||||
}
|
||||
|
||||
env := userEnv
|
||||
env = append(env, prepareSSHEnv(session)...)
|
||||
for _, v := range session.Environ() {
|
||||
if acceptEnv(v) {
|
||||
env = append(env, v)
|
||||
}
|
||||
}
|
||||
return env
|
||||
}
|
||||
|
||||
func (s *Server) handlePty(logger *log.Entry, session ssh.Session, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, winCh <-chan ssh.Window) bool {
|
||||
cmd := session.Command()
|
||||
shell := getUserShell(privilegeResult.User.Uid)
|
||||
|
||||
if len(cmd) == 0 {
|
||||
logger.Infof("starting interactive shell: %s", shell)
|
||||
} else {
|
||||
logger.Infof("executing command: %s", safeLogCommand(cmd))
|
||||
}
|
||||
|
||||
// Always use user switching on Windows - no direct execution
|
||||
s.handlePtyWithUserSwitching(logger, session, privilegeResult, ptyReq, winCh, cmd)
|
||||
return true
|
||||
}
|
||||
|
||||
// getShellCommandArgs returns the shell command and arguments for executing a command string
|
||||
func (s *Server) getShellCommandArgs(shell, cmdString string) []string {
|
||||
if cmdString == "" {
|
||||
return []string{shell, "-NoLogo"}
|
||||
}
|
||||
return []string{shell, "-Command", cmdString}
|
||||
}
|
||||
|
||||
func (s *Server) handlePtyWithUserSwitching(logger *log.Entry, session ssh.Session, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, _ <-chan ssh.Window, _ []string) {
|
||||
localUser := privilegeResult.User
|
||||
|
||||
username, domain := s.parseUsername(localUser.Username)
|
||||
shell := getUserShell(localUser.Uid)
|
||||
|
||||
var command string
|
||||
rawCmd := session.RawCommand()
|
||||
if rawCmd != "" {
|
||||
command = rawCmd
|
||||
}
|
||||
|
||||
req := PtyExecutionRequest{
|
||||
Shell: shell,
|
||||
Command: command,
|
||||
Width: ptyReq.Window.Width,
|
||||
Height: ptyReq.Window.Height,
|
||||
Username: username,
|
||||
Domain: domain,
|
||||
}
|
||||
err := executePtyCommandWithUserToken(session.Context(), session, req)
|
||||
|
||||
if err != nil {
|
||||
logger.Errorf("Windows ConPty with user switching failed: %v", err)
|
||||
var errorMsg string
|
||||
if runtime.GOOS == "windows" {
|
||||
errorMsg = "Windows user switching failed - NetBird must run as a Windows service or with elevated privileges for user switching\r\n"
|
||||
} else {
|
||||
errorMsg = "User switching failed - login command not available\r\n"
|
||||
}
|
||||
if _, writeErr := fmt.Fprint(session.Stderr(), errorMsg); writeErr != nil {
|
||||
logger.Debugf(errWriteSession, writeErr)
|
||||
}
|
||||
if err := session.Exit(1); err != nil {
|
||||
logSessionExitError(logger, err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
logger.Debugf("Windows ConPty command execution with user switching completed")
|
||||
}
|
||||
|
||||
type PtyExecutionRequest struct {
|
||||
Shell string
|
||||
Command string
|
||||
Width int
|
||||
Height int
|
||||
Username string
|
||||
Domain string
|
||||
}
|
||||
|
||||
func executePtyCommandWithUserToken(ctx context.Context, session ssh.Session, req PtyExecutionRequest) error {
|
||||
log.Tracef("executing Windows ConPty command with user switching: shell=%s, command=%s, user=%s\\%s, size=%dx%d",
|
||||
req.Shell, req.Command, req.Domain, req.Username, req.Width, req.Height)
|
||||
|
||||
privilegeDropper := NewPrivilegeDropper()
|
||||
userToken, err := privilegeDropper.createToken(req.Username, req.Domain)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create user token: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
if err := windows.CloseHandle(userToken); err != nil {
|
||||
log.Debugf("close user token: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
server := &Server{}
|
||||
userEnv, err := server.getUserEnvironment(req.Username, req.Domain)
|
||||
if err != nil {
|
||||
log.Debugf("failed to get user environment for %s\\%s, using system environment: %v", req.Domain, req.Username, err)
|
||||
userEnv = os.Environ()
|
||||
}
|
||||
|
||||
workingDir := getUserHomeFromEnv(userEnv)
|
||||
if workingDir == "" {
|
||||
workingDir = fmt.Sprintf(`C:\Users\%s`, req.Username)
|
||||
}
|
||||
|
||||
ptyConfig := winpty.PtyConfig{
|
||||
Shell: req.Shell,
|
||||
Command: req.Command,
|
||||
Width: req.Width,
|
||||
Height: req.Height,
|
||||
WorkingDir: workingDir,
|
||||
}
|
||||
|
||||
userConfig := winpty.UserConfig{
|
||||
Token: userToken,
|
||||
Environment: userEnv,
|
||||
}
|
||||
|
||||
log.Debugf("executePtyCommandWithUserToken: calling winpty execution with working dir: %s", workingDir)
|
||||
return winpty.ExecutePtyWithUserToken(ctx, session, ptyConfig, userConfig)
|
||||
}
|
||||
|
||||
func getUserHomeFromEnv(env []string) string {
|
||||
for _, envVar := range env {
|
||||
if len(envVar) > 12 && envVar[:12] == "USERPROFILE=" {
|
||||
return envVar[12:]
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (s *Server) setupProcessGroup(_ *exec.Cmd) {
|
||||
// Windows doesn't support process groups in the same way as Unix
|
||||
// Process creation groups are handled differently
|
||||
}
|
||||
|
||||
func (s *Server) killProcessGroup(cmd *exec.Cmd) {
|
||||
if cmd.Process == nil {
|
||||
return
|
||||
}
|
||||
|
||||
logger := log.WithField("pid", cmd.Process.Pid)
|
||||
|
||||
if err := cmd.Process.Kill(); err != nil {
|
||||
logger.Debugf("kill process failed: %v", err)
|
||||
}
|
||||
}
|
||||
722
client/ssh/server/compatibility_test.go
Normal file
722
client/ssh/server/compatibility_test.go
Normal file
@@ -0,0 +1,722 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/ed25519"
|
||||
"crypto/rand"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"os"
|
||||
"os/exec"
|
||||
"runtime"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/crypto/ssh"
|
||||
|
||||
nbssh "github.com/netbirdio/netbird/client/ssh"
|
||||
"github.com/netbirdio/netbird/client/ssh/testutil"
|
||||
)
|
||||
|
||||
// TestMain handles package-level setup and cleanup
|
||||
func TestMain(m *testing.M) {
|
||||
// Guard against infinite recursion when test binary is called as "netbird ssh exec"
|
||||
// This happens when running tests as non-privileged user with fallback
|
||||
if len(os.Args) > 2 && os.Args[1] == "ssh" && os.Args[2] == "exec" {
|
||||
// Just exit with error to break the recursion
|
||||
fmt.Fprintf(os.Stderr, "Test binary called as 'ssh exec' - preventing infinite recursion\n")
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
// Run tests
|
||||
code := m.Run()
|
||||
|
||||
// Cleanup any created test users
|
||||
testutil.CleanupTestUsers()
|
||||
|
||||
os.Exit(code)
|
||||
}
|
||||
|
||||
// TestSSHServerCompatibility tests that our SSH server is compatible with the system SSH client
|
||||
func TestSSHServerCompatibility(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("Skipping SSH compatibility tests in short mode")
|
||||
}
|
||||
|
||||
// Check if ssh binary is available
|
||||
if !isSSHClientAvailable() {
|
||||
t.Skip("SSH client not available on this system")
|
||||
}
|
||||
|
||||
// Set up SSH server - use our existing key generation for server
|
||||
hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Generate OpenSSH-compatible keys for client
|
||||
clientPrivKeyOpenSSH, _, err := generateOpenSSHKey(t)
|
||||
require.NoError(t, err)
|
||||
|
||||
serverConfig := &Config{
|
||||
HostKeyPEM: hostKey,
|
||||
JWT: nil,
|
||||
}
|
||||
server := New(serverConfig)
|
||||
server.SetAllowRootLogin(true)
|
||||
|
||||
serverAddr := StartTestServer(t, server)
|
||||
defer func() {
|
||||
err := server.Stop()
|
||||
require.NoError(t, err)
|
||||
}()
|
||||
|
||||
// Create temporary key files for SSH client
|
||||
clientKeyFile, cleanupKey := createTempKeyFileFromBytes(t, clientPrivKeyOpenSSH)
|
||||
defer cleanupKey()
|
||||
|
||||
// Extract host and port from server address
|
||||
host, portStr, err := net.SplitHostPort(serverAddr)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Get appropriate user for SSH connection (handle system accounts)
|
||||
username := testutil.GetTestUsername(t)
|
||||
|
||||
t.Run("basic command execution", func(t *testing.T) {
|
||||
testSSHCommandExecutionWithUser(t, host, portStr, clientKeyFile, username)
|
||||
})
|
||||
|
||||
t.Run("interactive command", func(t *testing.T) {
|
||||
testSSHInteractiveCommand(t, host, portStr, clientKeyFile)
|
||||
})
|
||||
|
||||
t.Run("port forwarding", func(t *testing.T) {
|
||||
testSSHPortForwarding(t, host, portStr, clientKeyFile)
|
||||
})
|
||||
}
|
||||
|
||||
// testSSHCommandExecutionWithUser tests basic command execution with system SSH client using specified user.
|
||||
func testSSHCommandExecutionWithUser(t *testing.T, host, port, keyFile, username string) {
|
||||
cmd := exec.Command("ssh",
|
||||
"-i", keyFile,
|
||||
"-p", port,
|
||||
"-o", "StrictHostKeyChecking=no",
|
||||
"-o", "UserKnownHostsFile=/dev/null",
|
||||
"-o", "ConnectTimeout=5",
|
||||
fmt.Sprintf("%s@%s", username, host),
|
||||
"echo", "hello_world")
|
||||
|
||||
output, err := cmd.CombinedOutput()
|
||||
|
||||
if err != nil {
|
||||
t.Logf("SSH command failed: %v", err)
|
||||
t.Logf("Output: %s", string(output))
|
||||
return
|
||||
}
|
||||
|
||||
assert.Contains(t, string(output), "hello_world", "SSH command should execute successfully")
|
||||
}
|
||||
|
||||
// testSSHInteractiveCommand tests interactive shell session.
|
||||
func testSSHInteractiveCommand(t *testing.T, host, port, keyFile string) {
|
||||
// Get appropriate user for SSH connection
|
||||
username := testutil.GetTestUsername(t)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
cmd := exec.CommandContext(ctx, "ssh",
|
||||
"-i", keyFile,
|
||||
"-p", port,
|
||||
"-o", "StrictHostKeyChecking=no",
|
||||
"-o", "UserKnownHostsFile=/dev/null",
|
||||
"-o", "ConnectTimeout=5",
|
||||
fmt.Sprintf("%s@%s", username, host))
|
||||
|
||||
stdin, err := cmd.StdinPipe()
|
||||
if err != nil {
|
||||
t.Skipf("Cannot create stdin pipe: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
stdout, err := cmd.StdoutPipe()
|
||||
if err != nil {
|
||||
t.Skipf("Cannot create stdout pipe: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
err = cmd.Start()
|
||||
if err != nil {
|
||||
t.Logf("Cannot start SSH session: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
go func() {
|
||||
defer func() {
|
||||
if err := stdin.Close(); err != nil {
|
||||
t.Logf("stdin close error: %v", err)
|
||||
}
|
||||
}()
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
if _, err := stdin.Write([]byte("echo interactive_test\n")); err != nil {
|
||||
t.Logf("stdin write error: %v", err)
|
||||
}
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
if _, err := stdin.Write([]byte("exit\n")); err != nil {
|
||||
t.Logf("stdin write error: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
output, err := io.ReadAll(stdout)
|
||||
if err != nil {
|
||||
t.Logf("Cannot read SSH output: %v", err)
|
||||
}
|
||||
|
||||
err = cmd.Wait()
|
||||
if err != nil {
|
||||
t.Logf("SSH interactive session error: %v", err)
|
||||
t.Logf("Output: %s", string(output))
|
||||
return
|
||||
}
|
||||
|
||||
assert.Contains(t, string(output), "interactive_test", "Interactive SSH session should work")
|
||||
}
|
||||
|
||||
// testSSHPortForwarding tests port forwarding compatibility.
|
||||
func testSSHPortForwarding(t *testing.T, host, port, keyFile string) {
|
||||
// Get appropriate user for SSH connection
|
||||
username := testutil.GetTestUsername(t)
|
||||
|
||||
testServer, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
defer testServer.Close()
|
||||
|
||||
testServerAddr := testServer.Addr().String()
|
||||
expectedResponse := "HTTP/1.1 200 OK\r\nContent-Length: 21\r\n\r\nCompatibility Test OK"
|
||||
|
||||
go func() {
|
||||
for {
|
||||
conn, err := testServer.Accept()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
go func(c net.Conn) {
|
||||
defer func() {
|
||||
if err := c.Close(); err != nil {
|
||||
t.Logf("test server connection close error: %v", err)
|
||||
}
|
||||
}()
|
||||
buf := make([]byte, 1024)
|
||||
if _, err := c.Read(buf); err != nil {
|
||||
t.Logf("Test server read error: %v", err)
|
||||
}
|
||||
if _, err := c.Write([]byte(expectedResponse)); err != nil {
|
||||
t.Logf("Test server write error: %v", err)
|
||||
}
|
||||
}(conn)
|
||||
}
|
||||
}()
|
||||
|
||||
localListener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
localAddr := localListener.Addr().String()
|
||||
localListener.Close()
|
||||
|
||||
_, localPort, err := net.SplitHostPort(localAddr)
|
||||
require.NoError(t, err)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
|
||||
defer cancel()
|
||||
|
||||
forwardSpec := fmt.Sprintf("%s:%s", localPort, testServerAddr)
|
||||
cmd := exec.CommandContext(ctx, "ssh",
|
||||
"-i", keyFile,
|
||||
"-p", port,
|
||||
"-L", forwardSpec,
|
||||
"-o", "StrictHostKeyChecking=no",
|
||||
"-o", "UserKnownHostsFile=/dev/null",
|
||||
"-o", "ConnectTimeout=5",
|
||||
"-N",
|
||||
fmt.Sprintf("%s@%s", username, host))
|
||||
|
||||
err = cmd.Start()
|
||||
if err != nil {
|
||||
t.Logf("Cannot start SSH port forwarding: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
defer func() {
|
||||
if cmd.Process != nil {
|
||||
if err := cmd.Process.Kill(); err != nil {
|
||||
t.Logf("process kill error: %v", err)
|
||||
}
|
||||
}
|
||||
if err := cmd.Wait(); err != nil {
|
||||
t.Logf("process wait after kill: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
|
||||
conn, err := net.DialTimeout("tcp", localAddr, 3*time.Second)
|
||||
if err != nil {
|
||||
t.Logf("Cannot connect to forwarded port: %v", err)
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
if err := conn.Close(); err != nil {
|
||||
t.Logf("forwarded connection close error: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
request := "GET / HTTP/1.1\r\nHost: localhost\r\nConnection: close\r\n\r\n"
|
||||
_, err = conn.Write([]byte(request))
|
||||
require.NoError(t, err)
|
||||
|
||||
if err := conn.SetReadDeadline(time.Now().Add(3 * time.Second)); err != nil {
|
||||
log.Debugf("failed to set read deadline: %v", err)
|
||||
}
|
||||
response := make([]byte, len(expectedResponse))
|
||||
n, err := io.ReadFull(conn, response)
|
||||
if err != nil {
|
||||
t.Logf("Cannot read forwarded response: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
assert.Equal(t, len(expectedResponse), n, "Should read expected number of bytes")
|
||||
assert.Equal(t, expectedResponse, string(response), "Should get correct HTTP response through SSH port forwarding")
|
||||
}
|
||||
|
||||
// isSSHClientAvailable checks if the ssh binary is available
|
||||
func isSSHClientAvailable() bool {
|
||||
_, err := exec.LookPath("ssh")
|
||||
return err == nil
|
||||
}
|
||||
|
||||
// generateOpenSSHKey generates an ED25519 key in OpenSSH format that the system SSH client can use.
|
||||
func generateOpenSSHKey(t *testing.T) ([]byte, []byte, error) {
|
||||
// Check if ssh-keygen is available
|
||||
if _, err := exec.LookPath("ssh-keygen"); err != nil {
|
||||
// Fall back to our existing key generation and try to convert
|
||||
return generateOpenSSHKeyFallback()
|
||||
}
|
||||
|
||||
// Create temporary file for ssh-keygen
|
||||
tempFile, err := os.CreateTemp("", "ssh_keygen_*")
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("create temp file: %w", err)
|
||||
}
|
||||
keyPath := tempFile.Name()
|
||||
tempFile.Close()
|
||||
|
||||
// Remove the temp file so ssh-keygen can create it
|
||||
if err := os.Remove(keyPath); err != nil {
|
||||
t.Logf("failed to remove key file: %v", err)
|
||||
}
|
||||
|
||||
// Clean up temp files
|
||||
defer func() {
|
||||
if err := os.Remove(keyPath); err != nil {
|
||||
t.Logf("failed to cleanup key file: %v", err)
|
||||
}
|
||||
if err := os.Remove(keyPath + ".pub"); err != nil {
|
||||
t.Logf("failed to cleanup public key file: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
// Generate key using ssh-keygen
|
||||
cmd := exec.Command("ssh-keygen", "-t", "ed25519", "-f", keyPath, "-N", "", "-q")
|
||||
output, err := cmd.CombinedOutput()
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("ssh-keygen failed: %w, output: %s", err, string(output))
|
||||
}
|
||||
|
||||
// Read private key
|
||||
privKeyBytes, err := os.ReadFile(keyPath)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("read private key: %w", err)
|
||||
}
|
||||
|
||||
// Read public key
|
||||
pubKeyBytes, err := os.ReadFile(keyPath + ".pub")
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("read public key: %w", err)
|
||||
}
|
||||
|
||||
return privKeyBytes, pubKeyBytes, nil
|
||||
}
|
||||
|
||||
// generateOpenSSHKeyFallback falls back to generating keys using our existing method
|
||||
func generateOpenSSHKeyFallback() ([]byte, []byte, error) {
|
||||
// Generate shared.ED25519 key pair using our existing method
|
||||
_, privKey, err := ed25519.GenerateKey(rand.Reader)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("generate key: %w", err)
|
||||
}
|
||||
|
||||
// Convert to SSH format
|
||||
sshPrivKey, err := ssh.NewSignerFromKey(privKey)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("create signer: %w", err)
|
||||
}
|
||||
|
||||
// For the fallback, just use our PKCS#8 format and hope it works
|
||||
// This won't be in OpenSSH format but might still work with some SSH clients
|
||||
hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("generate fallback key: %w", err)
|
||||
}
|
||||
|
||||
// Get public key in SSH format
|
||||
sshPubKey := ssh.MarshalAuthorizedKey(sshPrivKey.PublicKey())
|
||||
|
||||
return hostKey, sshPubKey, nil
|
||||
}
|
||||
|
||||
// createTempKeyFileFromBytes creates a temporary SSH private key file from raw bytes
|
||||
func createTempKeyFileFromBytes(t *testing.T, keyBytes []byte) (string, func()) {
|
||||
t.Helper()
|
||||
|
||||
tempFile, err := os.CreateTemp("", "ssh_test_key_*")
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = tempFile.Write(keyBytes)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = tempFile.Close()
|
||||
require.NoError(t, err)
|
||||
|
||||
// Set proper permissions for SSH key (readable by owner only)
|
||||
err = os.Chmod(tempFile.Name(), 0600)
|
||||
require.NoError(t, err)
|
||||
|
||||
cleanup := func() {
|
||||
_ = os.Remove(tempFile.Name())
|
||||
}
|
||||
|
||||
return tempFile.Name(), cleanup
|
||||
}
|
||||
|
||||
// createTempKeyFile creates a temporary SSH private key file (for backward compatibility)
|
||||
func createTempKeyFile(t *testing.T, privateKey []byte) (string, func()) {
|
||||
return createTempKeyFileFromBytes(t, privateKey)
|
||||
}
|
||||
|
||||
// TestSSHServerFeatureCompatibility tests specific SSH features for compatibility
|
||||
func TestSSHServerFeatureCompatibility(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("Skipping SSH feature compatibility tests in short mode")
|
||||
}
|
||||
|
||||
if runtime.GOOS == "windows" && testutil.IsCI() {
|
||||
t.Skip("Skipping Windows SSH compatibility tests in CI due to S4U authentication issues")
|
||||
}
|
||||
|
||||
if !isSSHClientAvailable() {
|
||||
t.Skip("SSH client not available on this system")
|
||||
}
|
||||
|
||||
// Test various SSH features
|
||||
testCases := []struct {
|
||||
name string
|
||||
testFunc func(t *testing.T, host, port, keyFile string)
|
||||
description string
|
||||
}{
|
||||
{
|
||||
name: "command_with_flags",
|
||||
testFunc: testCommandWithFlags,
|
||||
description: "Commands with flags should work like standard SSH",
|
||||
},
|
||||
{
|
||||
name: "environment_variables",
|
||||
testFunc: testEnvironmentVariables,
|
||||
description: "Environment variables should be available",
|
||||
},
|
||||
{
|
||||
name: "exit_codes",
|
||||
testFunc: testExitCodes,
|
||||
description: "Exit codes should be properly handled",
|
||||
},
|
||||
}
|
||||
|
||||
// Set up SSH server
|
||||
hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
||||
require.NoError(t, err)
|
||||
|
||||
clientPrivKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
||||
require.NoError(t, err)
|
||||
|
||||
serverConfig := &Config{
|
||||
HostKeyPEM: hostKey,
|
||||
JWT: nil,
|
||||
}
|
||||
server := New(serverConfig)
|
||||
server.SetAllowRootLogin(true)
|
||||
|
||||
serverAddr := StartTestServer(t, server)
|
||||
defer func() {
|
||||
err := server.Stop()
|
||||
require.NoError(t, err)
|
||||
}()
|
||||
|
||||
clientKeyFile, cleanupKey := createTempKeyFile(t, clientPrivKey)
|
||||
defer cleanupKey()
|
||||
|
||||
host, portStr, err := net.SplitHostPort(serverAddr)
|
||||
require.NoError(t, err)
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
tc.testFunc(t, host, portStr, clientKeyFile)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// testCommandWithFlags tests that commands with flags work properly
|
||||
func testCommandWithFlags(t *testing.T, host, port, keyFile string) {
|
||||
// Get appropriate user for SSH connection
|
||||
username := testutil.GetTestUsername(t)
|
||||
|
||||
// Test ls with flags
|
||||
cmd := exec.Command("ssh",
|
||||
"-i", keyFile,
|
||||
"-p", port,
|
||||
"-o", "StrictHostKeyChecking=no",
|
||||
"-o", "UserKnownHostsFile=/dev/null",
|
||||
"-o", "ConnectTimeout=5",
|
||||
fmt.Sprintf("%s@%s", username, host),
|
||||
"ls", "-la", "/tmp")
|
||||
|
||||
output, err := cmd.CombinedOutput()
|
||||
if err != nil {
|
||||
t.Logf("Command with flags failed: %v", err)
|
||||
t.Logf("Output: %s", string(output))
|
||||
return
|
||||
}
|
||||
|
||||
// Should not be empty and should not contain error messages
|
||||
assert.NotEmpty(t, string(output), "ls -la should produce output")
|
||||
assert.NotContains(t, strings.ToLower(string(output)), "command not found", "Command should be executed")
|
||||
}
|
||||
|
||||
// testEnvironmentVariables tests that environment is properly set up
|
||||
func testEnvironmentVariables(t *testing.T, host, port, keyFile string) {
|
||||
// Get appropriate user for SSH connection
|
||||
username := testutil.GetTestUsername(t)
|
||||
|
||||
cmd := exec.Command("ssh",
|
||||
"-i", keyFile,
|
||||
"-p", port,
|
||||
"-o", "StrictHostKeyChecking=no",
|
||||
"-o", "UserKnownHostsFile=/dev/null",
|
||||
"-o", "ConnectTimeout=5",
|
||||
fmt.Sprintf("%s@%s", username, host),
|
||||
"echo", "$HOME")
|
||||
|
||||
output, err := cmd.CombinedOutput()
|
||||
if err != nil {
|
||||
t.Logf("Environment test failed: %v", err)
|
||||
t.Logf("Output: %s", string(output))
|
||||
return
|
||||
}
|
||||
|
||||
// HOME environment variable should be available
|
||||
homeOutput := strings.TrimSpace(string(output))
|
||||
assert.NotEmpty(t, homeOutput, "HOME environment variable should be set")
|
||||
assert.NotEqual(t, "$HOME", homeOutput, "Environment variable should be expanded")
|
||||
}
|
||||
|
||||
// testExitCodes tests that exit codes are properly handled
|
||||
func testExitCodes(t *testing.T, host, port, keyFile string) {
|
||||
// Get appropriate user for SSH connection
|
||||
username := testutil.GetTestUsername(t)
|
||||
|
||||
// Test successful command (exit code 0)
|
||||
cmd := exec.Command("ssh",
|
||||
"-i", keyFile,
|
||||
"-p", port,
|
||||
"-o", "StrictHostKeyChecking=no",
|
||||
"-o", "UserKnownHostsFile=/dev/null",
|
||||
"-o", "ConnectTimeout=5",
|
||||
fmt.Sprintf("%s@%s", username, host),
|
||||
"true") // always succeeds
|
||||
|
||||
err := cmd.Run()
|
||||
assert.NoError(t, err, "Command with exit code 0 should succeed")
|
||||
|
||||
// Test failing command (exit code 1)
|
||||
cmd = exec.Command("ssh",
|
||||
"-i", keyFile,
|
||||
"-p", port,
|
||||
"-o", "StrictHostKeyChecking=no",
|
||||
"-o", "UserKnownHostsFile=/dev/null",
|
||||
"-o", "ConnectTimeout=5",
|
||||
fmt.Sprintf("%s@%s", username, host),
|
||||
"false") // always fails
|
||||
|
||||
err = cmd.Run()
|
||||
assert.Error(t, err, "Command with exit code 1 should fail")
|
||||
|
||||
// Check if it's the right kind of error
|
||||
if exitError, ok := err.(*exec.ExitError); ok {
|
||||
assert.Equal(t, 1, exitError.ExitCode(), "Exit code should be preserved")
|
||||
}
|
||||
}
|
||||
|
||||
// TestSSHServerSecurityFeatures tests security-related SSH features
|
||||
func TestSSHServerSecurityFeatures(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("Skipping SSH security tests in short mode")
|
||||
}
|
||||
|
||||
if !isSSHClientAvailable() {
|
||||
t.Skip("SSH client not available on this system")
|
||||
}
|
||||
|
||||
// Get appropriate user for SSH connection
|
||||
username := testutil.GetTestUsername(t)
|
||||
|
||||
// Set up SSH server with specific security settings
|
||||
hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
||||
require.NoError(t, err)
|
||||
|
||||
clientPrivKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
||||
require.NoError(t, err)
|
||||
|
||||
serverConfig := &Config{
|
||||
HostKeyPEM: hostKey,
|
||||
JWT: nil,
|
||||
}
|
||||
server := New(serverConfig)
|
||||
server.SetAllowRootLogin(true)
|
||||
|
||||
serverAddr := StartTestServer(t, server)
|
||||
defer func() {
|
||||
err := server.Stop()
|
||||
require.NoError(t, err)
|
||||
}()
|
||||
|
||||
clientKeyFile, cleanupKey := createTempKeyFile(t, clientPrivKey)
|
||||
defer cleanupKey()
|
||||
|
||||
host, portStr, err := net.SplitHostPort(serverAddr)
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Run("key_authentication", func(t *testing.T) {
|
||||
// Test that key authentication works
|
||||
cmd := exec.Command("ssh",
|
||||
"-i", clientKeyFile,
|
||||
"-p", portStr,
|
||||
"-o", "StrictHostKeyChecking=no",
|
||||
"-o", "UserKnownHostsFile=/dev/null",
|
||||
"-o", "ConnectTimeout=5",
|
||||
"-o", "PasswordAuthentication=no",
|
||||
fmt.Sprintf("%s@%s", username, host),
|
||||
"echo", "auth_success")
|
||||
|
||||
output, err := cmd.CombinedOutput()
|
||||
if err != nil {
|
||||
t.Logf("Key authentication failed: %v", err)
|
||||
t.Logf("Output: %s", string(output))
|
||||
return
|
||||
}
|
||||
|
||||
assert.Contains(t, string(output), "auth_success", "Key authentication should work")
|
||||
})
|
||||
|
||||
t.Run("any_key_accepted_in_no_auth_mode", func(t *testing.T) {
|
||||
// Create a different key that shouldn't be accepted
|
||||
wrongKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
||||
require.NoError(t, err)
|
||||
|
||||
wrongKeyFile, cleanupWrongKey := createTempKeyFile(t, wrongKey)
|
||||
defer cleanupWrongKey()
|
||||
|
||||
// Test that wrong key is rejected
|
||||
cmd := exec.Command("ssh",
|
||||
"-i", wrongKeyFile,
|
||||
"-p", portStr,
|
||||
"-o", "StrictHostKeyChecking=no",
|
||||
"-o", "UserKnownHostsFile=/dev/null",
|
||||
"-o", "ConnectTimeout=5",
|
||||
"-o", "PasswordAuthentication=no",
|
||||
fmt.Sprintf("%s@%s", username, host),
|
||||
"echo", "should_not_work")
|
||||
|
||||
err = cmd.Run()
|
||||
assert.NoError(t, err, "Any key should work in no-auth mode")
|
||||
})
|
||||
}
|
||||
|
||||
// TestCrossPlatformCompatibility tests cross-platform behavior
|
||||
func TestCrossPlatformCompatibility(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("Skipping cross-platform compatibility tests in short mode")
|
||||
}
|
||||
|
||||
if !isSSHClientAvailable() {
|
||||
t.Skip("SSH client not available on this system")
|
||||
}
|
||||
|
||||
// Get appropriate user for SSH connection
|
||||
username := testutil.GetTestUsername(t)
|
||||
|
||||
// Set up SSH server
|
||||
hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
||||
require.NoError(t, err)
|
||||
|
||||
clientPrivKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
||||
require.NoError(t, err)
|
||||
|
||||
serverConfig := &Config{
|
||||
HostKeyPEM: hostKey,
|
||||
JWT: nil,
|
||||
}
|
||||
server := New(serverConfig)
|
||||
server.SetAllowRootLogin(true)
|
||||
|
||||
serverAddr := StartTestServer(t, server)
|
||||
defer func() {
|
||||
err := server.Stop()
|
||||
require.NoError(t, err)
|
||||
}()
|
||||
|
||||
clientKeyFile, cleanupKey := createTempKeyFile(t, clientPrivKey)
|
||||
defer cleanupKey()
|
||||
|
||||
host, portStr, err := net.SplitHostPort(serverAddr)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Test platform-specific commands
|
||||
var testCommand string
|
||||
|
||||
switch runtime.GOOS {
|
||||
case "windows":
|
||||
testCommand = "echo %OS%"
|
||||
default:
|
||||
testCommand = "uname"
|
||||
}
|
||||
|
||||
cmd := exec.Command("ssh",
|
||||
"-i", clientKeyFile,
|
||||
"-p", portStr,
|
||||
"-o", "StrictHostKeyChecking=no",
|
||||
"-o", "UserKnownHostsFile=/dev/null",
|
||||
"-o", "ConnectTimeout=5",
|
||||
fmt.Sprintf("%s@%s", username, host),
|
||||
testCommand)
|
||||
|
||||
output, err := cmd.CombinedOutput()
|
||||
if err != nil {
|
||||
t.Logf("Platform-specific command failed: %v", err)
|
||||
t.Logf("Output: %s", string(output))
|
||||
return
|
||||
}
|
||||
|
||||
outputStr := strings.TrimSpace(string(output))
|
||||
t.Logf("Platform command output: %s", outputStr)
|
||||
assert.NotEmpty(t, outputStr, "Platform-specific command should produce output")
|
||||
}
|
||||
253
client/ssh/server/executor_unix.go
Normal file
253
client/ssh/server/executor_unix.go
Normal file
@@ -0,0 +1,253 @@
|
||||
//go:build unix
|
||||
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"runtime"
|
||||
"strings"
|
||||
"syscall"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// Exit codes for executor process communication
|
||||
const (
|
||||
ExitCodeSuccess = 0
|
||||
ExitCodePrivilegeDropFail = 10
|
||||
ExitCodeShellExecFail = 11
|
||||
ExitCodeValidationFail = 12
|
||||
)
|
||||
|
||||
// ExecutorConfig holds configuration for the executor process
|
||||
type ExecutorConfig struct {
|
||||
UID uint32
|
||||
GID uint32
|
||||
Groups []uint32
|
||||
WorkingDir string
|
||||
Shell string
|
||||
Command string
|
||||
PTY bool
|
||||
}
|
||||
|
||||
// PrivilegeDropper handles secure privilege dropping in child processes
|
||||
type PrivilegeDropper struct{}
|
||||
|
||||
// NewPrivilegeDropper creates a new privilege dropper
|
||||
func NewPrivilegeDropper() *PrivilegeDropper {
|
||||
return &PrivilegeDropper{}
|
||||
}
|
||||
|
||||
// CreateExecutorCommand creates a command that spawns netbird ssh exec for privilege dropping
|
||||
func (pd *PrivilegeDropper) CreateExecutorCommand(ctx context.Context, config ExecutorConfig) (*exec.Cmd, error) {
|
||||
netbirdPath, err := os.Executable()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get netbird executable path: %w", err)
|
||||
}
|
||||
|
||||
if err := pd.validatePrivileges(config.UID, config.GID); err != nil {
|
||||
return nil, fmt.Errorf("invalid privileges: %w", err)
|
||||
}
|
||||
|
||||
args := []string{
|
||||
"ssh", "exec",
|
||||
"--uid", fmt.Sprintf("%d", config.UID),
|
||||
"--gid", fmt.Sprintf("%d", config.GID),
|
||||
"--working-dir", config.WorkingDir,
|
||||
"--shell", config.Shell,
|
||||
}
|
||||
|
||||
for _, group := range config.Groups {
|
||||
args = append(args, "--groups", fmt.Sprintf("%d", group))
|
||||
}
|
||||
|
||||
if config.PTY {
|
||||
args = append(args, "--pty")
|
||||
}
|
||||
|
||||
if config.Command != "" {
|
||||
args = append(args, "--cmd", config.Command)
|
||||
}
|
||||
|
||||
// Log executor args safely - show all args except hide the command value
|
||||
safeArgs := make([]string, len(args))
|
||||
copy(safeArgs, args)
|
||||
for i := 0; i < len(safeArgs)-1; i++ {
|
||||
if safeArgs[i] == "--cmd" {
|
||||
cmdParts := strings.Fields(safeArgs[i+1])
|
||||
safeArgs[i+1] = safeLogCommand(cmdParts)
|
||||
break
|
||||
}
|
||||
}
|
||||
log.Tracef("creating executor command: %s %v", netbirdPath, safeArgs)
|
||||
return exec.CommandContext(ctx, netbirdPath, args...), nil
|
||||
}
|
||||
|
||||
// DropPrivileges performs privilege dropping with thread locking for security
|
||||
func (pd *PrivilegeDropper) DropPrivileges(targetUID, targetGID uint32, supplementaryGroups []uint32) error {
|
||||
if err := pd.validatePrivileges(targetUID, targetGID); err != nil {
|
||||
return fmt.Errorf("invalid privileges: %w", err)
|
||||
}
|
||||
|
||||
runtime.LockOSThread()
|
||||
defer runtime.UnlockOSThread()
|
||||
|
||||
originalUID := os.Geteuid()
|
||||
originalGID := os.Getegid()
|
||||
|
||||
if originalUID != int(targetUID) || originalGID != int(targetGID) {
|
||||
if err := pd.setGroupsAndIDs(targetUID, targetGID, supplementaryGroups); err != nil {
|
||||
return fmt.Errorf("set groups and IDs: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if err := pd.validatePrivilegeDropSuccess(targetUID, targetGID, originalUID, originalGID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
log.Tracef("successfully dropped privileges to UID=%d, GID=%d", targetUID, targetGID)
|
||||
return nil
|
||||
}
|
||||
|
||||
// setGroupsAndIDs sets the supplementary groups, GID, and UID
|
||||
func (pd *PrivilegeDropper) setGroupsAndIDs(targetUID, targetGID uint32, supplementaryGroups []uint32) error {
|
||||
groups := make([]int, len(supplementaryGroups))
|
||||
for i, g := range supplementaryGroups {
|
||||
groups[i] = int(g)
|
||||
}
|
||||
|
||||
if runtime.GOOS == "darwin" || runtime.GOOS == "freebsd" {
|
||||
if len(groups) == 0 || groups[0] != int(targetGID) {
|
||||
groups = append([]int{int(targetGID)}, groups...)
|
||||
}
|
||||
}
|
||||
|
||||
if err := syscall.Setgroups(groups); err != nil {
|
||||
return fmt.Errorf("setgroups to %v: %w", groups, err)
|
||||
}
|
||||
|
||||
if err := syscall.Setgid(int(targetGID)); err != nil {
|
||||
return fmt.Errorf("setgid to %d: %w", targetGID, err)
|
||||
}
|
||||
|
||||
if err := syscall.Setuid(int(targetUID)); err != nil {
|
||||
return fmt.Errorf("setuid to %d: %w", targetUID, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// validatePrivilegeDropSuccess validates that privilege dropping was successful
|
||||
func (pd *PrivilegeDropper) validatePrivilegeDropSuccess(targetUID, targetGID uint32, originalUID, originalGID int) error {
|
||||
if err := pd.validatePrivilegeDropReversibility(targetUID, targetGID, originalUID, originalGID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := pd.validateCurrentPrivileges(targetUID, targetGID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// validatePrivilegeDropReversibility ensures privileges cannot be restored
|
||||
func (pd *PrivilegeDropper) validatePrivilegeDropReversibility(targetUID, targetGID uint32, originalUID, originalGID int) error {
|
||||
if originalGID != int(targetGID) {
|
||||
if err := syscall.Setegid(originalGID); err == nil {
|
||||
return fmt.Errorf("privilege drop validation failed: able to restore original GID %d", originalGID)
|
||||
}
|
||||
}
|
||||
if originalUID != int(targetUID) {
|
||||
if err := syscall.Seteuid(originalUID); err == nil {
|
||||
return fmt.Errorf("privilege drop validation failed: able to restore original UID %d", originalUID)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// validateCurrentPrivileges validates the current UID and GID match the target
|
||||
func (pd *PrivilegeDropper) validateCurrentPrivileges(targetUID, targetGID uint32) error {
|
||||
currentUID := os.Geteuid()
|
||||
if currentUID != int(targetUID) {
|
||||
return fmt.Errorf("privilege drop validation failed: current UID %d, expected %d", currentUID, targetUID)
|
||||
}
|
||||
|
||||
currentGID := os.Getegid()
|
||||
if currentGID != int(targetGID) {
|
||||
return fmt.Errorf("privilege drop validation failed: current GID %d, expected %d", currentGID, targetGID)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ExecuteWithPrivilegeDrop executes a command with privilege dropping, using exit codes to signal specific failures
|
||||
func (pd *PrivilegeDropper) ExecuteWithPrivilegeDrop(ctx context.Context, config ExecutorConfig) {
|
||||
log.Tracef("dropping privileges to UID=%d, GID=%d, groups=%v", config.UID, config.GID, config.Groups)
|
||||
|
||||
// TODO: Implement Pty support for executor path
|
||||
if config.PTY {
|
||||
config.PTY = false
|
||||
}
|
||||
|
||||
if err := pd.DropPrivileges(config.UID, config.GID, config.Groups); err != nil {
|
||||
_, _ = fmt.Fprintf(os.Stderr, "privilege drop failed: %v\n", err)
|
||||
os.Exit(ExitCodePrivilegeDropFail)
|
||||
}
|
||||
|
||||
if config.WorkingDir != "" {
|
||||
if err := os.Chdir(config.WorkingDir); err != nil {
|
||||
log.Debugf("failed to change to working directory %s, continuing with current directory: %v", config.WorkingDir, err)
|
||||
}
|
||||
}
|
||||
|
||||
var execCmd *exec.Cmd
|
||||
if config.Command == "" {
|
||||
os.Exit(ExitCodeSuccess)
|
||||
}
|
||||
|
||||
execCmd = exec.CommandContext(ctx, config.Shell, "-c", config.Command)
|
||||
execCmd.Stdin = os.Stdin
|
||||
execCmd.Stdout = os.Stdout
|
||||
execCmd.Stderr = os.Stderr
|
||||
|
||||
cmdParts := strings.Fields(config.Command)
|
||||
safeCmd := safeLogCommand(cmdParts)
|
||||
log.Tracef("executing %s -c %s", execCmd.Path, safeCmd)
|
||||
if err := execCmd.Run(); err != nil {
|
||||
var exitError *exec.ExitError
|
||||
if errors.As(err, &exitError) {
|
||||
// Normal command exit with non-zero code - not an SSH execution error
|
||||
log.Tracef("command exited with code %d", exitError.ExitCode())
|
||||
os.Exit(exitError.ExitCode())
|
||||
}
|
||||
|
||||
// Actual execution failure (command not found, permission denied, etc.)
|
||||
log.Debugf("command execution failed: %v", err)
|
||||
os.Exit(ExitCodeShellExecFail)
|
||||
}
|
||||
|
||||
os.Exit(ExitCodeSuccess)
|
||||
}
|
||||
|
||||
// validatePrivileges validates that privilege dropping to the target UID/GID is allowed
|
||||
func (pd *PrivilegeDropper) validatePrivileges(uid, gid uint32) error {
|
||||
currentUID := uint32(os.Geteuid())
|
||||
currentGID := uint32(os.Getegid())
|
||||
|
||||
// Allow same-user operations (no privilege dropping needed)
|
||||
if uid == currentUID && gid == currentGID {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Only root can drop privileges to other users
|
||||
if currentUID != 0 {
|
||||
return fmt.Errorf("cannot drop privileges from non-root user (UID %d) to UID %d", currentUID, uid)
|
||||
}
|
||||
|
||||
// Root can drop to any user (including root itself)
|
||||
return nil
|
||||
}
|
||||
262
client/ssh/server/executor_unix_test.go
Normal file
262
client/ssh/server/executor_unix_test.go
Normal file
@@ -0,0 +1,262 @@
|
||||
//go:build unix
|
||||
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"os/user"
|
||||
"strconv"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestPrivilegeDropper_ValidatePrivileges(t *testing.T) {
|
||||
pd := NewPrivilegeDropper()
|
||||
|
||||
currentUID := uint32(os.Geteuid())
|
||||
currentGID := uint32(os.Getegid())
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
uid uint32
|
||||
gid uint32
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "same user - no privilege drop needed",
|
||||
uid: currentUID,
|
||||
gid: currentGID,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "non-root to different user should fail",
|
||||
uid: currentUID + 1, // Use a different UID to ensure it's actually different
|
||||
gid: currentGID + 1, // Use a different GID to ensure it's actually different
|
||||
wantErr: currentUID != 0, // Only fail if current user is not root
|
||||
},
|
||||
{
|
||||
name: "root can drop to any user",
|
||||
uid: 1000,
|
||||
gid: 1000,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "root can stay as root",
|
||||
uid: 0,
|
||||
gid: 0,
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Skip non-root tests when running as root, and root tests when not root
|
||||
if tt.name == "non-root to different user should fail" && currentUID == 0 {
|
||||
t.Skip("Skipping non-root test when running as root")
|
||||
}
|
||||
if (tt.name == "root can drop to any user" || tt.name == "root can stay as root") && currentUID != 0 {
|
||||
t.Skip("Skipping root test when not running as root")
|
||||
}
|
||||
|
||||
err := pd.validatePrivileges(tt.uid, tt.gid)
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPrivilegeDropper_CreateExecutorCommand(t *testing.T) {
|
||||
pd := NewPrivilegeDropper()
|
||||
|
||||
config := ExecutorConfig{
|
||||
UID: 1000,
|
||||
GID: 1000,
|
||||
Groups: []uint32{1000, 1001},
|
||||
WorkingDir: "/home/testuser",
|
||||
Shell: "/bin/bash",
|
||||
Command: "ls -la",
|
||||
}
|
||||
|
||||
cmd, err := pd.CreateExecutorCommand(context.Background(), config)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, cmd)
|
||||
|
||||
// Verify the command is calling netbird ssh exec
|
||||
assert.Contains(t, cmd.Args, "ssh")
|
||||
assert.Contains(t, cmd.Args, "exec")
|
||||
assert.Contains(t, cmd.Args, "--uid")
|
||||
assert.Contains(t, cmd.Args, "1000")
|
||||
assert.Contains(t, cmd.Args, "--gid")
|
||||
assert.Contains(t, cmd.Args, "1000")
|
||||
assert.Contains(t, cmd.Args, "--groups")
|
||||
assert.Contains(t, cmd.Args, "1000")
|
||||
assert.Contains(t, cmd.Args, "1001")
|
||||
assert.Contains(t, cmd.Args, "--working-dir")
|
||||
assert.Contains(t, cmd.Args, "/home/testuser")
|
||||
assert.Contains(t, cmd.Args, "--shell")
|
||||
assert.Contains(t, cmd.Args, "/bin/bash")
|
||||
assert.Contains(t, cmd.Args, "--cmd")
|
||||
assert.Contains(t, cmd.Args, "ls -la")
|
||||
}
|
||||
|
||||
func TestPrivilegeDropper_CreateExecutorCommandInteractive(t *testing.T) {
|
||||
pd := NewPrivilegeDropper()
|
||||
|
||||
config := ExecutorConfig{
|
||||
UID: 1000,
|
||||
GID: 1000,
|
||||
Groups: []uint32{1000},
|
||||
WorkingDir: "/home/testuser",
|
||||
Shell: "/bin/bash",
|
||||
Command: "",
|
||||
}
|
||||
|
||||
cmd, err := pd.CreateExecutorCommand(context.Background(), config)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, cmd)
|
||||
|
||||
// Verify no command mode (command is empty so no --cmd flag)
|
||||
assert.NotContains(t, cmd.Args, "--cmd")
|
||||
assert.NotContains(t, cmd.Args, "--interactive")
|
||||
}
|
||||
|
||||
// TestPrivilegeDropper_ActualPrivilegeDrop tests actual privilege dropping
|
||||
// This test requires root privileges and will be skipped if not running as root
|
||||
func TestPrivilegeDropper_ActualPrivilegeDrop(t *testing.T) {
|
||||
if os.Geteuid() != 0 {
|
||||
t.Skip("This test requires root privileges")
|
||||
}
|
||||
|
||||
// Find a non-root user to test with
|
||||
testUser, err := findNonRootUser()
|
||||
if err != nil {
|
||||
t.Skip("No suitable non-root user found for testing")
|
||||
}
|
||||
|
||||
// Verify the user actually exists by looking it up again
|
||||
_, err = user.LookupId(testUser.Uid)
|
||||
if err != nil {
|
||||
t.Skipf("Test user %s (UID %s) does not exist on this system: %v", testUser.Username, testUser.Uid, err)
|
||||
}
|
||||
|
||||
uid64, err := strconv.ParseUint(testUser.Uid, 10, 32)
|
||||
require.NoError(t, err)
|
||||
targetUID := uint32(uid64)
|
||||
|
||||
gid64, err := strconv.ParseUint(testUser.Gid, 10, 32)
|
||||
require.NoError(t, err)
|
||||
targetGID := uint32(gid64)
|
||||
|
||||
// Test in a child process to avoid affecting the test runner
|
||||
if os.Getenv("TEST_PRIVILEGE_DROP") == "1" {
|
||||
pd := NewPrivilegeDropper()
|
||||
|
||||
// This should succeed
|
||||
err := pd.DropPrivileges(targetUID, targetGID, []uint32{targetGID})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify we are now running as the target user
|
||||
currentUID := uint32(os.Geteuid())
|
||||
currentGID := uint32(os.Getegid())
|
||||
|
||||
assert.Equal(t, targetUID, currentUID, "UID should match target")
|
||||
assert.Equal(t, targetGID, currentGID, "GID should match target")
|
||||
assert.NotEqual(t, uint32(0), currentUID, "Should not be running as root")
|
||||
assert.NotEqual(t, uint32(0), currentGID, "Should not be running as root group")
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// Fork a child process to test privilege dropping
|
||||
cmd := os.Args[0]
|
||||
args := []string{"-test.run=TestPrivilegeDropper_ActualPrivilegeDrop"}
|
||||
|
||||
env := append(os.Environ(), "TEST_PRIVILEGE_DROP=1")
|
||||
|
||||
execCmd := exec.Command(cmd, args...)
|
||||
execCmd.Env = env
|
||||
|
||||
err = execCmd.Run()
|
||||
require.NoError(t, err, "Child process should succeed")
|
||||
}
|
||||
|
||||
// findNonRootUser finds any non-root user on the system for testing
|
||||
func findNonRootUser() (*user.User, error) {
|
||||
// Try common non-root users, but avoid "nobody" on macOS due to negative UID issues
|
||||
commonUsers := []string{"daemon", "bin", "sys", "sync", "games", "man", "lp", "mail", "news", "uucp", "proxy", "www-data", "backup", "list", "irc"}
|
||||
|
||||
for _, username := range commonUsers {
|
||||
if u, err := user.Lookup(username); err == nil {
|
||||
// Parse as signed integer first to handle negative UIDs
|
||||
uid64, err := strconv.ParseInt(u.Uid, 10, 32)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
// Skip negative UIDs (like nobody=-2 on macOS) and root
|
||||
if uid64 > 0 && uid64 != 0 {
|
||||
return u, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// If no common users found, try to find any regular user with UID > 100
|
||||
// This helps on macOS where regular users start at UID 501
|
||||
allUsers := []string{"vma", "user", "test", "admin"}
|
||||
for _, username := range allUsers {
|
||||
if u, err := user.Lookup(username); err == nil {
|
||||
uid64, err := strconv.ParseInt(u.Uid, 10, 32)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
if uid64 > 100 { // Regular user
|
||||
return u, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// If no common users found, return an error
|
||||
return nil, fmt.Errorf("no suitable non-root user found on this system")
|
||||
}
|
||||
|
||||
func TestPrivilegeDropper_ExecuteWithPrivilegeDrop_Validation(t *testing.T) {
|
||||
pd := NewPrivilegeDropper()
|
||||
currentUID := uint32(os.Geteuid())
|
||||
|
||||
if currentUID == 0 {
|
||||
// When running as root, test that root can create commands for any user
|
||||
config := ExecutorConfig{
|
||||
UID: 1000, // Target non-root user
|
||||
GID: 1000,
|
||||
Groups: []uint32{1000},
|
||||
WorkingDir: "/tmp",
|
||||
Shell: "/bin/sh",
|
||||
Command: "echo test",
|
||||
}
|
||||
|
||||
cmd, err := pd.CreateExecutorCommand(context.Background(), config)
|
||||
assert.NoError(t, err, "Root should be able to create commands for any user")
|
||||
assert.NotNil(t, cmd)
|
||||
} else {
|
||||
// When running as non-root, test that we can't drop to a different user
|
||||
config := ExecutorConfig{
|
||||
UID: 0, // Try to target root
|
||||
GID: 0,
|
||||
Groups: []uint32{0},
|
||||
WorkingDir: "/tmp",
|
||||
Shell: "/bin/sh",
|
||||
Command: "echo test",
|
||||
}
|
||||
|
||||
_, err := pd.CreateExecutorCommand(context.Background(), config)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "cannot drop privileges")
|
||||
}
|
||||
}
|
||||
566
client/ssh/server/executor_windows.go
Normal file
566
client/ssh/server/executor_windows.go
Normal file
@@ -0,0 +1,566 @@
|
||||
//go:build windows
|
||||
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"os/user"
|
||||
"strings"
|
||||
"syscall"
|
||||
"unsafe"
|
||||
|
||||
"github.com/gliderlabs/ssh"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/sys/windows"
|
||||
)
|
||||
|
||||
const (
|
||||
ExitCodeSuccess = 0
|
||||
ExitCodeLogonFail = 10
|
||||
ExitCodeCreateProcessFail = 11
|
||||
ExitCodeWorkingDirFail = 12
|
||||
ExitCodeShellExecFail = 13
|
||||
ExitCodeValidationFail = 14
|
||||
)
|
||||
|
||||
type WindowsExecutorConfig struct {
|
||||
Username string
|
||||
Domain string
|
||||
WorkingDir string
|
||||
Shell string
|
||||
Command string
|
||||
Args []string
|
||||
Interactive bool
|
||||
Pty bool
|
||||
PtyWidth int
|
||||
PtyHeight int
|
||||
}
|
||||
|
||||
type PrivilegeDropper struct{}
|
||||
|
||||
func NewPrivilegeDropper() *PrivilegeDropper {
|
||||
return &PrivilegeDropper{}
|
||||
}
|
||||
|
||||
var (
|
||||
advapi32 = windows.NewLazyDLL("advapi32.dll")
|
||||
procAllocateLocallyUniqueId = advapi32.NewProc("AllocateLocallyUniqueId")
|
||||
)
|
||||
|
||||
const (
|
||||
logon32LogonNetwork = 3 // Network logon - no password required for authenticated users
|
||||
|
||||
// Common error messages
|
||||
commandFlag = "-Command"
|
||||
closeTokenErrorMsg = "close token error: %v" // #nosec G101 -- This is an error message template, not credentials
|
||||
convertUsernameError = "convert username to UTF16: %w"
|
||||
convertDomainError = "convert domain to UTF16: %w"
|
||||
)
|
||||
|
||||
func (pd *PrivilegeDropper) CreateWindowsExecutorCommand(ctx context.Context, config WindowsExecutorConfig) (*exec.Cmd, error) {
|
||||
if config.Username == "" {
|
||||
return nil, errors.New("username cannot be empty")
|
||||
}
|
||||
if config.Shell == "" {
|
||||
return nil, errors.New("shell cannot be empty")
|
||||
}
|
||||
|
||||
shell := config.Shell
|
||||
|
||||
var shellArgs []string
|
||||
if config.Command != "" {
|
||||
shellArgs = []string{shell, commandFlag, config.Command}
|
||||
} else {
|
||||
shellArgs = []string{shell}
|
||||
}
|
||||
|
||||
log.Tracef("creating Windows direct shell command: %s %v", shellArgs[0], shellArgs)
|
||||
|
||||
cmd, err := pd.CreateWindowsProcessAsUser(
|
||||
ctx, shellArgs[0], shellArgs, config.Username, config.Domain, config.WorkingDir)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create Windows process as user: %w", err)
|
||||
}
|
||||
|
||||
return cmd, nil
|
||||
}
|
||||
|
||||
const (
|
||||
// StatusSuccess represents successful LSA operation
|
||||
StatusSuccess = 0
|
||||
|
||||
// KerbS4ULogonType message type for domain users with Kerberos
|
||||
KerbS4ULogonType = 12
|
||||
// Msv10s4ulogontype message type for local users with MSV1_0
|
||||
Msv10s4ulogontype = 12
|
||||
|
||||
// MicrosoftKerberosNameA is the authentication package name for Kerberos
|
||||
MicrosoftKerberosNameA = "Kerberos"
|
||||
// Msv10packagename is the authentication package name for MSV1_0
|
||||
Msv10packagename = "MICROSOFT_AUTHENTICATION_PACKAGE_V1_0"
|
||||
|
||||
NameSamCompatible = 2
|
||||
NameUserPrincipal = 8
|
||||
NameCanonical = 7
|
||||
|
||||
maxUPNLen = 1024
|
||||
)
|
||||
|
||||
// kerbS4ULogon structure for S4U authentication (domain users)
|
||||
type kerbS4ULogon struct {
|
||||
MessageType uint32
|
||||
Flags uint32
|
||||
ClientUpn unicodeString
|
||||
ClientRealm unicodeString
|
||||
}
|
||||
|
||||
// msv10s4ulogon structure for S4U authentication (local users)
|
||||
type msv10s4ulogon struct {
|
||||
MessageType uint32
|
||||
Flags uint32
|
||||
UserPrincipalName unicodeString
|
||||
DomainName unicodeString
|
||||
}
|
||||
|
||||
// unicodeString structure
|
||||
type unicodeString struct {
|
||||
Length uint16
|
||||
MaximumLength uint16
|
||||
Buffer *uint16
|
||||
}
|
||||
|
||||
// lsaString structure
|
||||
type lsaString struct {
|
||||
Length uint16
|
||||
MaximumLength uint16
|
||||
Buffer *byte
|
||||
}
|
||||
|
||||
// tokenSource structure
|
||||
type tokenSource struct {
|
||||
SourceName [8]byte
|
||||
SourceIdentifier windows.LUID
|
||||
}
|
||||
|
||||
// quotaLimits structure
|
||||
type quotaLimits struct {
|
||||
PagedPoolLimit uint32
|
||||
NonPagedPoolLimit uint32
|
||||
MinimumWorkingSetSize uint32
|
||||
MaximumWorkingSetSize uint32
|
||||
PagefileLimit uint32
|
||||
TimeLimit int64
|
||||
}
|
||||
|
||||
var (
|
||||
secur32 = windows.NewLazyDLL("secur32.dll")
|
||||
procLsaRegisterLogonProcess = secur32.NewProc("LsaRegisterLogonProcess")
|
||||
procLsaLookupAuthenticationPackage = secur32.NewProc("LsaLookupAuthenticationPackage")
|
||||
procLsaLogonUser = secur32.NewProc("LsaLogonUser")
|
||||
procLsaFreeReturnBuffer = secur32.NewProc("LsaFreeReturnBuffer")
|
||||
procLsaDeregisterLogonProcess = secur32.NewProc("LsaDeregisterLogonProcess")
|
||||
procTranslateNameW = secur32.NewProc("TranslateNameW")
|
||||
)
|
||||
|
||||
// newLsaString creates an LsaString from a Go string
|
||||
func newLsaString(s string) lsaString {
|
||||
b := append([]byte(s), 0)
|
||||
return lsaString{
|
||||
Length: uint16(len(s)),
|
||||
MaximumLength: uint16(len(b)),
|
||||
Buffer: &b[0],
|
||||
}
|
||||
}
|
||||
|
||||
// generateS4UUserToken creates a Windows token using S4U authentication
|
||||
// This is the exact approach OpenSSH for Windows uses for public key authentication
|
||||
func generateS4UUserToken(username, domain string) (windows.Handle, error) {
|
||||
userCpn := buildUserCpn(username, domain)
|
||||
|
||||
// Use proper domain detection logic instead of simple string check
|
||||
pd := NewPrivilegeDropper()
|
||||
isDomainUser := !pd.isLocalUser(domain)
|
||||
|
||||
lsaHandle, err := initializeLsaConnection()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
defer cleanupLsaConnection(lsaHandle)
|
||||
|
||||
authPackageId, err := lookupAuthenticationPackage(lsaHandle, isDomainUser)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
logonInfo, logonInfoSize, err := prepareS4ULogonStructure(username, domain, isDomainUser)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return performS4ULogon(lsaHandle, authPackageId, logonInfo, logonInfoSize, userCpn, isDomainUser)
|
||||
}
|
||||
|
||||
// buildUserCpn constructs the user principal name
|
||||
func buildUserCpn(username, domain string) string {
|
||||
if domain != "" && domain != "." {
|
||||
return fmt.Sprintf(`%s\%s`, domain, username)
|
||||
}
|
||||
return username
|
||||
}
|
||||
|
||||
// initializeLsaConnection establishes connection to LSA
|
||||
func initializeLsaConnection() (windows.Handle, error) {
|
||||
|
||||
processName := newLsaString("NetBird")
|
||||
var mode uint32
|
||||
var lsaHandle windows.Handle
|
||||
ret, _, _ := procLsaRegisterLogonProcess.Call(
|
||||
uintptr(unsafe.Pointer(&processName)),
|
||||
uintptr(unsafe.Pointer(&lsaHandle)),
|
||||
uintptr(unsafe.Pointer(&mode)),
|
||||
)
|
||||
if ret != StatusSuccess {
|
||||
return 0, fmt.Errorf("LsaRegisterLogonProcess: 0x%x", ret)
|
||||
}
|
||||
|
||||
return lsaHandle, nil
|
||||
}
|
||||
|
||||
// cleanupLsaConnection closes the LSA connection
|
||||
func cleanupLsaConnection(lsaHandle windows.Handle) {
|
||||
if ret, _, _ := procLsaDeregisterLogonProcess.Call(uintptr(lsaHandle)); ret != StatusSuccess {
|
||||
log.Debugf("LsaDeregisterLogonProcess failed: 0x%x", ret)
|
||||
}
|
||||
}
|
||||
|
||||
// lookupAuthenticationPackage finds the correct authentication package
|
||||
func lookupAuthenticationPackage(lsaHandle windows.Handle, isDomainUser bool) (uint32, error) {
|
||||
var authPackageName lsaString
|
||||
if isDomainUser {
|
||||
authPackageName = newLsaString(MicrosoftKerberosNameA)
|
||||
} else {
|
||||
authPackageName = newLsaString(Msv10packagename)
|
||||
}
|
||||
|
||||
var authPackageId uint32
|
||||
ret, _, _ := procLsaLookupAuthenticationPackage.Call(
|
||||
uintptr(lsaHandle),
|
||||
uintptr(unsafe.Pointer(&authPackageName)),
|
||||
uintptr(unsafe.Pointer(&authPackageId)),
|
||||
)
|
||||
if ret != StatusSuccess {
|
||||
return 0, fmt.Errorf("LsaLookupAuthenticationPackage: 0x%x", ret)
|
||||
}
|
||||
|
||||
return authPackageId, nil
|
||||
}
|
||||
|
||||
// lookupPrincipalName converts DOMAIN\username to username@domain.fqdn (UPN format)
|
||||
func lookupPrincipalName(username, domain string) (string, error) {
|
||||
samAccountName := fmt.Sprintf(`%s\%s`, domain, username)
|
||||
samAccountNameUtf16, err := windows.UTF16PtrFromString(samAccountName)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("convert SAM account name to UTF-16: %w", err)
|
||||
}
|
||||
|
||||
upnBuf := make([]uint16, maxUPNLen+1)
|
||||
upnSize := uint32(len(upnBuf))
|
||||
|
||||
ret, _, _ := procTranslateNameW.Call(
|
||||
uintptr(unsafe.Pointer(samAccountNameUtf16)),
|
||||
uintptr(NameSamCompatible),
|
||||
uintptr(NameUserPrincipal),
|
||||
uintptr(unsafe.Pointer(&upnBuf[0])),
|
||||
uintptr(unsafe.Pointer(&upnSize)),
|
||||
)
|
||||
|
||||
if ret != 0 {
|
||||
upn := windows.UTF16ToString(upnBuf[:upnSize])
|
||||
log.Debugf("Translated %s to explicit UPN: %s", samAccountName, upn)
|
||||
return upn, nil
|
||||
}
|
||||
|
||||
upnSize = uint32(len(upnBuf))
|
||||
ret, _, _ = procTranslateNameW.Call(
|
||||
uintptr(unsafe.Pointer(samAccountNameUtf16)),
|
||||
uintptr(NameSamCompatible),
|
||||
uintptr(NameCanonical),
|
||||
uintptr(unsafe.Pointer(&upnBuf[0])),
|
||||
uintptr(unsafe.Pointer(&upnSize)),
|
||||
)
|
||||
|
||||
if ret != 0 {
|
||||
canonical := windows.UTF16ToString(upnBuf[:upnSize])
|
||||
slashIdx := strings.IndexByte(canonical, '/')
|
||||
if slashIdx > 0 {
|
||||
fqdn := canonical[:slashIdx]
|
||||
upn := fmt.Sprintf("%s@%s", username, fqdn)
|
||||
log.Debugf("Translated %s to implicit UPN: %s (from canonical: %s)", samAccountName, upn, canonical)
|
||||
return upn, nil
|
||||
}
|
||||
}
|
||||
|
||||
log.Debugf("Could not translate %s to UPN, using SAM format", samAccountName)
|
||||
return samAccountName, nil
|
||||
}
|
||||
|
||||
// prepareS4ULogonStructure creates the appropriate S4U logon structure
|
||||
func prepareS4ULogonStructure(username, domain string, isDomainUser bool) (unsafe.Pointer, uintptr, error) {
|
||||
if isDomainUser {
|
||||
return prepareDomainS4ULogon(username, domain)
|
||||
}
|
||||
return prepareLocalS4ULogon(username)
|
||||
}
|
||||
|
||||
// prepareDomainS4ULogon creates S4U logon structure for domain users
|
||||
func prepareDomainS4ULogon(username, domain string) (unsafe.Pointer, uintptr, error) {
|
||||
upn, err := lookupPrincipalName(username, domain)
|
||||
if err != nil {
|
||||
return nil, 0, fmt.Errorf("lookup principal name: %w", err)
|
||||
}
|
||||
|
||||
log.Debugf("using KerbS4ULogon for domain user with UPN: %s", upn)
|
||||
|
||||
upnUtf16, err := windows.UTF16FromString(upn)
|
||||
if err != nil {
|
||||
return nil, 0, fmt.Errorf(convertUsernameError, err)
|
||||
}
|
||||
|
||||
structSize := unsafe.Sizeof(kerbS4ULogon{})
|
||||
upnByteSize := len(upnUtf16) * 2
|
||||
logonInfoSize := structSize + uintptr(upnByteSize)
|
||||
|
||||
buffer := make([]byte, logonInfoSize)
|
||||
logonInfo := unsafe.Pointer(&buffer[0])
|
||||
|
||||
s4uLogon := (*kerbS4ULogon)(logonInfo)
|
||||
s4uLogon.MessageType = KerbS4ULogonType
|
||||
s4uLogon.Flags = 0
|
||||
|
||||
upnOffset := structSize
|
||||
upnBuffer := (*uint16)(unsafe.Pointer(uintptr(logonInfo) + upnOffset))
|
||||
copy((*[512]uint16)(unsafe.Pointer(upnBuffer))[:len(upnUtf16)], upnUtf16)
|
||||
|
||||
s4uLogon.ClientUpn = unicodeString{
|
||||
Length: uint16((len(upnUtf16) - 1) * 2),
|
||||
MaximumLength: uint16(len(upnUtf16) * 2),
|
||||
Buffer: upnBuffer,
|
||||
}
|
||||
s4uLogon.ClientRealm = unicodeString{}
|
||||
|
||||
return logonInfo, logonInfoSize, nil
|
||||
}
|
||||
|
||||
// prepareLocalS4ULogon creates S4U logon structure for local users
|
||||
func prepareLocalS4ULogon(username string) (unsafe.Pointer, uintptr, error) {
|
||||
log.Debugf("using Msv1_0S4ULogon for local user: %s", username)
|
||||
|
||||
usernameUtf16, err := windows.UTF16FromString(username)
|
||||
if err != nil {
|
||||
return nil, 0, fmt.Errorf(convertUsernameError, err)
|
||||
}
|
||||
|
||||
domainUtf16, err := windows.UTF16FromString(".")
|
||||
if err != nil {
|
||||
return nil, 0, fmt.Errorf(convertDomainError, err)
|
||||
}
|
||||
|
||||
structSize := unsafe.Sizeof(msv10s4ulogon{})
|
||||
usernameByteSize := len(usernameUtf16) * 2
|
||||
domainByteSize := len(domainUtf16) * 2
|
||||
logonInfoSize := structSize + uintptr(usernameByteSize) + uintptr(domainByteSize)
|
||||
|
||||
buffer := make([]byte, logonInfoSize)
|
||||
logonInfo := unsafe.Pointer(&buffer[0])
|
||||
|
||||
s4uLogon := (*msv10s4ulogon)(logonInfo)
|
||||
s4uLogon.MessageType = Msv10s4ulogontype
|
||||
s4uLogon.Flags = 0x0
|
||||
|
||||
usernameOffset := structSize
|
||||
usernameBuffer := (*uint16)(unsafe.Pointer(uintptr(logonInfo) + usernameOffset))
|
||||
copy((*[256]uint16)(unsafe.Pointer(usernameBuffer))[:len(usernameUtf16)], usernameUtf16)
|
||||
|
||||
s4uLogon.UserPrincipalName = unicodeString{
|
||||
Length: uint16((len(usernameUtf16) - 1) * 2),
|
||||
MaximumLength: uint16(len(usernameUtf16) * 2),
|
||||
Buffer: usernameBuffer,
|
||||
}
|
||||
|
||||
domainOffset := usernameOffset + uintptr(usernameByteSize)
|
||||
domainBuffer := (*uint16)(unsafe.Pointer(uintptr(logonInfo) + domainOffset))
|
||||
copy((*[16]uint16)(unsafe.Pointer(domainBuffer))[:len(domainUtf16)], domainUtf16)
|
||||
|
||||
s4uLogon.DomainName = unicodeString{
|
||||
Length: uint16((len(domainUtf16) - 1) * 2),
|
||||
MaximumLength: uint16(len(domainUtf16) * 2),
|
||||
Buffer: domainBuffer,
|
||||
}
|
||||
|
||||
return logonInfo, logonInfoSize, nil
|
||||
}
|
||||
|
||||
// performS4ULogon executes the S4U logon operation
|
||||
func performS4ULogon(lsaHandle windows.Handle, authPackageId uint32, logonInfo unsafe.Pointer, logonInfoSize uintptr, userCpn string, isDomainUser bool) (windows.Handle, error) {
|
||||
var tokenSource tokenSource
|
||||
copy(tokenSource.SourceName[:], "netbird")
|
||||
if ret, _, _ := procAllocateLocallyUniqueId.Call(uintptr(unsafe.Pointer(&tokenSource.SourceIdentifier))); ret == 0 {
|
||||
log.Debugf("AllocateLocallyUniqueId failed")
|
||||
}
|
||||
|
||||
originName := newLsaString("netbird")
|
||||
|
||||
var profile uintptr
|
||||
var profileSize uint32
|
||||
var logonId windows.LUID
|
||||
var token windows.Handle
|
||||
var quotas quotaLimits
|
||||
var subStatus int32
|
||||
|
||||
ret, _, _ := procLsaLogonUser.Call(
|
||||
uintptr(lsaHandle),
|
||||
uintptr(unsafe.Pointer(&originName)),
|
||||
logon32LogonNetwork,
|
||||
uintptr(authPackageId),
|
||||
uintptr(logonInfo),
|
||||
logonInfoSize,
|
||||
0,
|
||||
uintptr(unsafe.Pointer(&tokenSource)),
|
||||
uintptr(unsafe.Pointer(&profile)),
|
||||
uintptr(unsafe.Pointer(&profileSize)),
|
||||
uintptr(unsafe.Pointer(&logonId)),
|
||||
uintptr(unsafe.Pointer(&token)),
|
||||
uintptr(unsafe.Pointer("as)),
|
||||
uintptr(unsafe.Pointer(&subStatus)),
|
||||
)
|
||||
|
||||
if profile != 0 {
|
||||
if ret, _, _ := procLsaFreeReturnBuffer.Call(profile); ret != StatusSuccess {
|
||||
log.Debugf("LsaFreeReturnBuffer failed: 0x%x", ret)
|
||||
}
|
||||
}
|
||||
|
||||
if ret != StatusSuccess {
|
||||
return 0, fmt.Errorf("LsaLogonUser S4U for %s: NTSTATUS=0x%x, SubStatus=0x%x", userCpn, ret, subStatus)
|
||||
}
|
||||
|
||||
log.Debugf("created S4U %s token for user %s",
|
||||
map[bool]string{true: "domain", false: "local"}[isDomainUser], userCpn)
|
||||
return token, nil
|
||||
}
|
||||
|
||||
// createToken implements NetBird trust-based authentication using S4U
|
||||
func (pd *PrivilegeDropper) createToken(username, domain string) (windows.Handle, error) {
|
||||
fullUsername := buildUserCpn(username, domain)
|
||||
|
||||
if err := userExists(fullUsername, username, domain); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
isLocalUser := pd.isLocalUser(domain)
|
||||
|
||||
if isLocalUser {
|
||||
return pd.authenticateLocalUser(username, fullUsername)
|
||||
}
|
||||
return pd.authenticateDomainUser(username, domain, fullUsername)
|
||||
}
|
||||
|
||||
// userExists checks if the target useVerifier exists on the system
|
||||
func userExists(fullUsername, username, domain string) error {
|
||||
if _, err := lookupUser(fullUsername); err != nil {
|
||||
log.Debugf("User %s not found: %v", fullUsername, err)
|
||||
if domain != "" && domain != "." {
|
||||
_, err = lookupUser(username)
|
||||
}
|
||||
if err != nil {
|
||||
return fmt.Errorf("target user %s not found: %w", fullUsername, err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// isLocalUser determines if this is a local user vs domain user
|
||||
func (pd *PrivilegeDropper) isLocalUser(domain string) bool {
|
||||
hostname, err := os.Hostname()
|
||||
if err != nil {
|
||||
hostname = "localhost"
|
||||
}
|
||||
|
||||
return domain == "" || domain == "." ||
|
||||
strings.EqualFold(domain, hostname)
|
||||
}
|
||||
|
||||
// authenticateLocalUser handles authentication for local users
|
||||
func (pd *PrivilegeDropper) authenticateLocalUser(username, fullUsername string) (windows.Handle, error) {
|
||||
log.Debugf("using S4U authentication for local user %s", fullUsername)
|
||||
token, err := generateS4UUserToken(username, ".")
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("S4U authentication for local user %s: %w", fullUsername, err)
|
||||
}
|
||||
return token, nil
|
||||
}
|
||||
|
||||
// authenticateDomainUser handles authentication for domain users
|
||||
func (pd *PrivilegeDropper) authenticateDomainUser(username, domain, fullUsername string) (windows.Handle, error) {
|
||||
log.Debugf("using S4U authentication for domain user %s", fullUsername)
|
||||
token, err := generateS4UUserToken(username, domain)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("S4U authentication for domain user %s: %w", fullUsername, err)
|
||||
}
|
||||
log.Debugf("Successfully created S4U token for domain user %s", fullUsername)
|
||||
return token, nil
|
||||
}
|
||||
|
||||
// CreateWindowsProcessAsUser creates a process as user with safe argument passing (for SFTP and executables)
|
||||
func (pd *PrivilegeDropper) CreateWindowsProcessAsUser(ctx context.Context, executablePath string, args []string, username, domain, workingDir string) (*exec.Cmd, error) {
|
||||
fullUsername := buildUserCpn(username, domain)
|
||||
|
||||
token, err := pd.createToken(username, domain)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("user authentication: %w", err)
|
||||
}
|
||||
|
||||
log.Debugf("using S4U authentication for user %s", fullUsername)
|
||||
defer func() {
|
||||
if err := windows.CloseHandle(token); err != nil {
|
||||
log.Debugf("close impersonation token error: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
return pd.createProcessWithToken(ctx, windows.Token(token), executablePath, args, workingDir)
|
||||
}
|
||||
|
||||
// createProcessWithToken creates process with the specified token and executable path
|
||||
func (pd *PrivilegeDropper) createProcessWithToken(ctx context.Context, sourceToken windows.Token, executablePath string, args []string, workingDir string) (*exec.Cmd, error) {
|
||||
cmd := exec.CommandContext(ctx, executablePath, args[1:]...)
|
||||
cmd.Dir = workingDir
|
||||
|
||||
// Duplicate the token to create a primary token that can be used to start a new process
|
||||
var primaryToken windows.Token
|
||||
err := windows.DuplicateTokenEx(
|
||||
sourceToken,
|
||||
windows.TOKEN_ALL_ACCESS,
|
||||
nil,
|
||||
windows.SecurityIdentification,
|
||||
windows.TokenPrimary,
|
||||
&primaryToken,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("duplicate token to primary token: %w", err)
|
||||
}
|
||||
|
||||
cmd.SysProcAttr = &syscall.SysProcAttr{
|
||||
Token: syscall.Token(primaryToken),
|
||||
}
|
||||
|
||||
return cmd, nil
|
||||
}
|
||||
|
||||
// createSuCommand creates a command using su -l -c for privilege switching (Windows stub)
|
||||
func (s *Server) createSuCommand(ssh.Session, *user.User, bool) (*exec.Cmd, error) {
|
||||
return nil, fmt.Errorf("su command not available on Windows")
|
||||
}
|
||||
619
client/ssh/server/jwt_test.go
Normal file
619
client/ssh/server/jwt_test.go
Normal file
@@ -0,0 +1,619 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"math/big"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
cryptossh "golang.org/x/crypto/ssh"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
nbssh "github.com/netbirdio/netbird/client/ssh"
|
||||
"github.com/netbirdio/netbird/client/ssh/client"
|
||||
"github.com/netbirdio/netbird/client/ssh/detection"
|
||||
"github.com/netbirdio/netbird/client/ssh/testutil"
|
||||
nbjwt "github.com/netbirdio/netbird/shared/auth/jwt"
|
||||
)
|
||||
|
||||
func TestJWTEnforcement(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("Skipping JWT enforcement tests in short mode")
|
||||
}
|
||||
|
||||
// Set up SSH server
|
||||
hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Run("blocks_without_jwt", func(t *testing.T) {
|
||||
jwtConfig := &JWTConfig{
|
||||
Issuer: "test-issuer",
|
||||
Audience: "test-audience",
|
||||
KeysLocation: "test-keys",
|
||||
}
|
||||
serverConfig := &Config{
|
||||
HostKeyPEM: hostKey,
|
||||
JWT: jwtConfig,
|
||||
}
|
||||
server := New(serverConfig)
|
||||
server.SetAllowRootLogin(true)
|
||||
|
||||
serverAddr := StartTestServer(t, server)
|
||||
defer require.NoError(t, server.Stop())
|
||||
|
||||
host, portStr, err := net.SplitHostPort(serverAddr)
|
||||
require.NoError(t, err)
|
||||
port, err := strconv.Atoi(portStr)
|
||||
require.NoError(t, err)
|
||||
dialer := &net.Dialer{Timeout: detection.Timeout}
|
||||
serverType, err := detection.DetectSSHServerType(context.Background(), dialer, host, port)
|
||||
if err != nil {
|
||||
t.Logf("Detection failed: %v", err)
|
||||
}
|
||||
t.Logf("Detected server type: %s", serverType)
|
||||
|
||||
config := &cryptossh.ClientConfig{
|
||||
User: testutil.GetTestUsername(t),
|
||||
Auth: []cryptossh.AuthMethod{},
|
||||
HostKeyCallback: cryptossh.InsecureIgnoreHostKey(),
|
||||
Timeout: 2 * time.Second,
|
||||
}
|
||||
|
||||
_, err = cryptossh.Dial("tcp", net.JoinHostPort(host, portStr), config)
|
||||
assert.Error(t, err, "SSH connection should fail when JWT is required but not provided")
|
||||
})
|
||||
|
||||
t.Run("allows_when_disabled", func(t *testing.T) {
|
||||
serverConfigNoJWT := &Config{
|
||||
HostKeyPEM: hostKey,
|
||||
JWT: nil,
|
||||
}
|
||||
serverNoJWT := New(serverConfigNoJWT)
|
||||
require.False(t, serverNoJWT.jwtEnabled, "JWT should be disabled without config")
|
||||
serverNoJWT.SetAllowRootLogin(true)
|
||||
|
||||
serverAddrNoJWT := StartTestServer(t, serverNoJWT)
|
||||
defer require.NoError(t, serverNoJWT.Stop())
|
||||
|
||||
hostNoJWT, portStrNoJWT, err := net.SplitHostPort(serverAddrNoJWT)
|
||||
require.NoError(t, err)
|
||||
portNoJWT, err := strconv.Atoi(portStrNoJWT)
|
||||
require.NoError(t, err)
|
||||
|
||||
dialer := &net.Dialer{Timeout: detection.Timeout}
|
||||
serverType, err := detection.DetectSSHServerType(context.Background(), dialer, hostNoJWT, portNoJWT)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, detection.ServerTypeNetBirdNoJWT, serverType)
|
||||
assert.False(t, serverType.RequiresJWT())
|
||||
|
||||
client, err := connectWithNetBirdClient(t, hostNoJWT, portNoJWT)
|
||||
require.NoError(t, err)
|
||||
defer client.Close()
|
||||
})
|
||||
|
||||
}
|
||||
|
||||
// setupJWKSServer creates a test HTTP server serving JWKS and returns the server, private key, and URL
|
||||
func setupJWKSServer(t *testing.T) (*httptest.Server, *rsa.PrivateKey, string) {
|
||||
privateKey, jwksJSON := generateTestJWKS(t)
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
if _, err := w.Write(jwksJSON); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
}
|
||||
}))
|
||||
|
||||
return server, privateKey, server.URL
|
||||
}
|
||||
|
||||
// generateTestJWKS creates a test RSA key pair and returns private key and JWKS JSON
|
||||
func generateTestJWKS(t *testing.T) (*rsa.PrivateKey, []byte) {
|
||||
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
require.NoError(t, err)
|
||||
|
||||
publicKey := &privateKey.PublicKey
|
||||
n := publicKey.N.Bytes()
|
||||
e := publicKey.E
|
||||
|
||||
jwk := nbjwt.JSONWebKey{
|
||||
Kty: "RSA",
|
||||
Kid: "test-key-id",
|
||||
Use: "sig",
|
||||
N: base64RawURLEncode(n),
|
||||
E: base64RawURLEncode(big.NewInt(int64(e)).Bytes()),
|
||||
}
|
||||
|
||||
jwks := nbjwt.Jwks{
|
||||
Keys: []nbjwt.JSONWebKey{jwk},
|
||||
}
|
||||
|
||||
jwksJSON, err := json.Marshal(jwks)
|
||||
require.NoError(t, err)
|
||||
|
||||
return privateKey, jwksJSON
|
||||
}
|
||||
|
||||
func base64RawURLEncode(data []byte) string {
|
||||
return base64.RawURLEncoding.EncodeToString(data)
|
||||
}
|
||||
|
||||
// generateValidJWT creates a valid JWT token for testing
|
||||
func generateValidJWT(t *testing.T, privateKey *rsa.PrivateKey, issuer, audience string) string {
|
||||
claims := jwt.MapClaims{
|
||||
"iss": issuer,
|
||||
"aud": audience,
|
||||
"sub": "test-user",
|
||||
"exp": time.Now().Add(time.Hour).Unix(),
|
||||
"iat": time.Now().Unix(),
|
||||
}
|
||||
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims)
|
||||
token.Header["kid"] = "test-key-id"
|
||||
|
||||
tokenString, err := token.SignedString(privateKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
return tokenString
|
||||
}
|
||||
|
||||
// connectWithNetBirdClient connects to SSH server using NetBird's SSH client
|
||||
func connectWithNetBirdClient(t *testing.T, host string, port int) (*client.Client, error) {
|
||||
t.Helper()
|
||||
addr := net.JoinHostPort(host, strconv.Itoa(port))
|
||||
|
||||
ctx := context.Background()
|
||||
return client.Dial(ctx, addr, testutil.GetTestUsername(t), client.DialOptions{
|
||||
InsecureSkipVerify: true,
|
||||
})
|
||||
}
|
||||
|
||||
// TestJWTDetection tests that server detection correctly identifies JWT-enabled servers
|
||||
func TestJWTDetection(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("Skipping JWT detection test in short mode")
|
||||
}
|
||||
|
||||
jwksServer, _, jwksURL := setupJWKSServer(t)
|
||||
defer jwksServer.Close()
|
||||
|
||||
hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
||||
require.NoError(t, err)
|
||||
|
||||
const (
|
||||
issuer = "https://test-issuer.example.com"
|
||||
audience = "test-audience"
|
||||
)
|
||||
|
||||
jwtConfig := &JWTConfig{
|
||||
Issuer: issuer,
|
||||
Audience: audience,
|
||||
KeysLocation: jwksURL,
|
||||
}
|
||||
serverConfig := &Config{
|
||||
HostKeyPEM: hostKey,
|
||||
JWT: jwtConfig,
|
||||
}
|
||||
server := New(serverConfig)
|
||||
server.SetAllowRootLogin(true)
|
||||
|
||||
serverAddr := StartTestServer(t, server)
|
||||
defer require.NoError(t, server.Stop())
|
||||
|
||||
host, portStr, err := net.SplitHostPort(serverAddr)
|
||||
require.NoError(t, err)
|
||||
port, err := strconv.Atoi(portStr)
|
||||
require.NoError(t, err)
|
||||
|
||||
dialer := &net.Dialer{Timeout: detection.Timeout}
|
||||
serverType, err := detection.DetectSSHServerType(context.Background(), dialer, host, port)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, detection.ServerTypeNetBirdJWT, serverType)
|
||||
assert.True(t, serverType.RequiresJWT())
|
||||
}
|
||||
|
||||
func TestJWTFailClose(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("Skipping JWT fail-close tests in short mode")
|
||||
}
|
||||
|
||||
jwksServer, privateKey, jwksURL := setupJWKSServer(t)
|
||||
defer jwksServer.Close()
|
||||
|
||||
const (
|
||||
issuer = "https://test-issuer.example.com"
|
||||
audience = "test-audience"
|
||||
)
|
||||
|
||||
hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
||||
require.NoError(t, err)
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
tokenClaims jwt.MapClaims
|
||||
}{
|
||||
{
|
||||
name: "blocks_token_missing_iat",
|
||||
tokenClaims: jwt.MapClaims{
|
||||
"iss": issuer,
|
||||
"aud": audience,
|
||||
"sub": "test-user",
|
||||
"exp": time.Now().Add(time.Hour).Unix(),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "blocks_token_missing_sub",
|
||||
tokenClaims: jwt.MapClaims{
|
||||
"iss": issuer,
|
||||
"aud": audience,
|
||||
"exp": time.Now().Add(time.Hour).Unix(),
|
||||
"iat": time.Now().Unix(),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "blocks_token_missing_iss",
|
||||
tokenClaims: jwt.MapClaims{
|
||||
"aud": audience,
|
||||
"sub": "test-user",
|
||||
"exp": time.Now().Add(time.Hour).Unix(),
|
||||
"iat": time.Now().Unix(),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "blocks_token_missing_aud",
|
||||
tokenClaims: jwt.MapClaims{
|
||||
"iss": issuer,
|
||||
"sub": "test-user",
|
||||
"exp": time.Now().Add(time.Hour).Unix(),
|
||||
"iat": time.Now().Unix(),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "blocks_token_wrong_issuer",
|
||||
tokenClaims: jwt.MapClaims{
|
||||
"iss": "wrong-issuer",
|
||||
"aud": audience,
|
||||
"sub": "test-user",
|
||||
"exp": time.Now().Add(time.Hour).Unix(),
|
||||
"iat": time.Now().Unix(),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "blocks_token_wrong_audience",
|
||||
tokenClaims: jwt.MapClaims{
|
||||
"iss": issuer,
|
||||
"aud": "wrong-audience",
|
||||
"sub": "test-user",
|
||||
"exp": time.Now().Add(time.Hour).Unix(),
|
||||
"iat": time.Now().Unix(),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "blocks_expired_token",
|
||||
tokenClaims: jwt.MapClaims{
|
||||
"iss": issuer,
|
||||
"aud": audience,
|
||||
"sub": "test-user",
|
||||
"exp": time.Now().Add(-time.Hour).Unix(),
|
||||
"iat": time.Now().Add(-2 * time.Hour).Unix(),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
jwtConfig := &JWTConfig{
|
||||
Issuer: issuer,
|
||||
Audience: audience,
|
||||
KeysLocation: jwksURL,
|
||||
MaxTokenAge: 3600,
|
||||
}
|
||||
serverConfig := &Config{
|
||||
HostKeyPEM: hostKey,
|
||||
JWT: jwtConfig,
|
||||
}
|
||||
server := New(serverConfig)
|
||||
server.SetAllowRootLogin(true)
|
||||
|
||||
serverAddr := StartTestServer(t, server)
|
||||
defer require.NoError(t, server.Stop())
|
||||
|
||||
host, portStr, err := net.SplitHostPort(serverAddr)
|
||||
require.NoError(t, err)
|
||||
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodRS256, tc.tokenClaims)
|
||||
token.Header["kid"] = "test-key-id"
|
||||
tokenString, err := token.SignedString(privateKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
config := &cryptossh.ClientConfig{
|
||||
User: testutil.GetTestUsername(t),
|
||||
Auth: []cryptossh.AuthMethod{
|
||||
cryptossh.Password(tokenString),
|
||||
},
|
||||
HostKeyCallback: cryptossh.InsecureIgnoreHostKey(),
|
||||
Timeout: 2 * time.Second,
|
||||
}
|
||||
|
||||
conn, err := cryptossh.Dial("tcp", net.JoinHostPort(host, portStr), config)
|
||||
if conn != nil {
|
||||
defer func() {
|
||||
if err := conn.Close(); err != nil {
|
||||
t.Logf("close connection: %v", err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
assert.Error(t, err, "Authentication should fail (fail-close)")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestJWTAuthentication tests JWT authentication with valid/invalid tokens and enforcement for various connection types
|
||||
func TestJWTAuthentication(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("Skipping JWT authentication tests in short mode")
|
||||
}
|
||||
|
||||
jwksServer, privateKey, jwksURL := setupJWKSServer(t)
|
||||
defer jwksServer.Close()
|
||||
|
||||
const (
|
||||
issuer = "https://test-issuer.example.com"
|
||||
audience = "test-audience"
|
||||
)
|
||||
|
||||
hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
||||
require.NoError(t, err)
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
token string
|
||||
wantAuthOK bool
|
||||
setupServer func(*Server)
|
||||
testOperation func(*testing.T, *cryptossh.Client, string) error
|
||||
wantOpSuccess bool
|
||||
}{
|
||||
{
|
||||
name: "allows_shell_with_jwt",
|
||||
token: "valid",
|
||||
wantAuthOK: true,
|
||||
setupServer: func(s *Server) {
|
||||
s.SetAllowRootLogin(true)
|
||||
},
|
||||
testOperation: func(t *testing.T, conn *cryptossh.Client, _ string) error {
|
||||
session, err := conn.NewSession()
|
||||
require.NoError(t, err)
|
||||
defer session.Close()
|
||||
return session.Shell()
|
||||
},
|
||||
wantOpSuccess: true,
|
||||
},
|
||||
{
|
||||
name: "rejects_invalid_token",
|
||||
token: "invalid",
|
||||
wantAuthOK: false,
|
||||
setupServer: func(s *Server) {
|
||||
s.SetAllowRootLogin(true)
|
||||
},
|
||||
testOperation: func(t *testing.T, conn *cryptossh.Client, _ string) error {
|
||||
session, err := conn.NewSession()
|
||||
require.NoError(t, err)
|
||||
defer session.Close()
|
||||
|
||||
output, err := session.CombinedOutput("echo test")
|
||||
if err != nil {
|
||||
t.Logf("Command output: %s", string(output))
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
},
|
||||
wantOpSuccess: false,
|
||||
},
|
||||
{
|
||||
name: "blocks_shell_without_jwt",
|
||||
token: "",
|
||||
wantAuthOK: false,
|
||||
setupServer: func(s *Server) {
|
||||
s.SetAllowRootLogin(true)
|
||||
},
|
||||
testOperation: func(t *testing.T, conn *cryptossh.Client, _ string) error {
|
||||
session, err := conn.NewSession()
|
||||
require.NoError(t, err)
|
||||
defer session.Close()
|
||||
|
||||
output, err := session.CombinedOutput("echo test")
|
||||
if err != nil {
|
||||
t.Logf("Command output: %s", string(output))
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
},
|
||||
wantOpSuccess: false,
|
||||
},
|
||||
{
|
||||
name: "blocks_command_without_jwt",
|
||||
token: "",
|
||||
wantAuthOK: false,
|
||||
setupServer: func(s *Server) {
|
||||
s.SetAllowRootLogin(true)
|
||||
},
|
||||
testOperation: func(t *testing.T, conn *cryptossh.Client, _ string) error {
|
||||
session, err := conn.NewSession()
|
||||
require.NoError(t, err)
|
||||
defer session.Close()
|
||||
|
||||
output, err := session.CombinedOutput("ls")
|
||||
if err != nil {
|
||||
t.Logf("Command output: %s", string(output))
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
},
|
||||
wantOpSuccess: false,
|
||||
},
|
||||
{
|
||||
name: "allows_sftp_with_jwt",
|
||||
token: "valid",
|
||||
wantAuthOK: true,
|
||||
setupServer: func(s *Server) {
|
||||
s.SetAllowRootLogin(true)
|
||||
s.SetAllowSFTP(true)
|
||||
},
|
||||
testOperation: func(t *testing.T, conn *cryptossh.Client, _ string) error {
|
||||
session, err := conn.NewSession()
|
||||
require.NoError(t, err)
|
||||
defer session.Close()
|
||||
|
||||
session.Stdout = io.Discard
|
||||
session.Stderr = io.Discard
|
||||
return session.RequestSubsystem("sftp")
|
||||
},
|
||||
wantOpSuccess: true,
|
||||
},
|
||||
{
|
||||
name: "blocks_sftp_without_jwt",
|
||||
token: "",
|
||||
wantAuthOK: false,
|
||||
setupServer: func(s *Server) {
|
||||
s.SetAllowRootLogin(true)
|
||||
s.SetAllowSFTP(true)
|
||||
},
|
||||
testOperation: func(t *testing.T, conn *cryptossh.Client, _ string) error {
|
||||
session, err := conn.NewSession()
|
||||
require.NoError(t, err)
|
||||
defer session.Close()
|
||||
|
||||
session.Stdout = io.Discard
|
||||
session.Stderr = io.Discard
|
||||
err = session.RequestSubsystem("sftp")
|
||||
if err == nil {
|
||||
err = session.Wait()
|
||||
}
|
||||
return err
|
||||
},
|
||||
wantOpSuccess: false,
|
||||
},
|
||||
{
|
||||
name: "allows_port_forward_with_jwt",
|
||||
token: "valid",
|
||||
wantAuthOK: true,
|
||||
setupServer: func(s *Server) {
|
||||
s.SetAllowRootLogin(true)
|
||||
s.SetAllowRemotePortForwarding(true)
|
||||
},
|
||||
testOperation: func(t *testing.T, conn *cryptossh.Client, _ string) error {
|
||||
ln, err := conn.Listen("tcp", "127.0.0.1:0")
|
||||
if ln != nil {
|
||||
defer ln.Close()
|
||||
}
|
||||
return err
|
||||
},
|
||||
wantOpSuccess: true,
|
||||
},
|
||||
{
|
||||
name: "blocks_port_forward_without_jwt",
|
||||
token: "",
|
||||
wantAuthOK: false,
|
||||
setupServer: func(s *Server) {
|
||||
s.SetAllowRootLogin(true)
|
||||
s.SetAllowLocalPortForwarding(true)
|
||||
},
|
||||
testOperation: func(t *testing.T, conn *cryptossh.Client, _ string) error {
|
||||
ln, err := conn.Listen("tcp", "127.0.0.1:0")
|
||||
if ln != nil {
|
||||
defer ln.Close()
|
||||
}
|
||||
return err
|
||||
},
|
||||
wantOpSuccess: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
// TODO: Skip port forwarding tests on Windows - user switching not supported
|
||||
// These features are tested on Linux/Unix platforms
|
||||
if runtime.GOOS == "windows" &&
|
||||
(tc.name == "allows_port_forward_with_jwt" ||
|
||||
tc.name == "blocks_port_forward_without_jwt") {
|
||||
t.Skip("Skipping port forwarding test on Windows - covered by Linux tests")
|
||||
}
|
||||
|
||||
jwtConfig := &JWTConfig{
|
||||
Issuer: issuer,
|
||||
Audience: audience,
|
||||
KeysLocation: jwksURL,
|
||||
}
|
||||
serverConfig := &Config{
|
||||
HostKeyPEM: hostKey,
|
||||
JWT: jwtConfig,
|
||||
}
|
||||
server := New(serverConfig)
|
||||
if tc.setupServer != nil {
|
||||
tc.setupServer(server)
|
||||
}
|
||||
|
||||
serverAddr := StartTestServer(t, server)
|
||||
defer require.NoError(t, server.Stop())
|
||||
|
||||
host, portStr, err := net.SplitHostPort(serverAddr)
|
||||
require.NoError(t, err)
|
||||
|
||||
var authMethods []cryptossh.AuthMethod
|
||||
if tc.token == "valid" {
|
||||
token := generateValidJWT(t, privateKey, issuer, audience)
|
||||
authMethods = []cryptossh.AuthMethod{
|
||||
cryptossh.Password(token),
|
||||
}
|
||||
} else if tc.token == "invalid" {
|
||||
invalidToken := "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDIyfQ.invalid"
|
||||
authMethods = []cryptossh.AuthMethod{
|
||||
cryptossh.Password(invalidToken),
|
||||
}
|
||||
}
|
||||
|
||||
config := &cryptossh.ClientConfig{
|
||||
User: testutil.GetTestUsername(t),
|
||||
Auth: authMethods,
|
||||
HostKeyCallback: cryptossh.InsecureIgnoreHostKey(),
|
||||
Timeout: 2 * time.Second,
|
||||
}
|
||||
|
||||
conn, err := cryptossh.Dial("tcp", net.JoinHostPort(host, portStr), config)
|
||||
if tc.wantAuthOK {
|
||||
require.NoError(t, err, "JWT authentication should succeed")
|
||||
} else if err != nil {
|
||||
t.Logf("Connection failed as expected: %v", err)
|
||||
return
|
||||
}
|
||||
if conn != nil {
|
||||
defer func() {
|
||||
if err := conn.Close(); err != nil {
|
||||
t.Logf("close connection: %v", err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
err = tc.testOperation(t, conn, serverAddr)
|
||||
if tc.wantOpSuccess {
|
||||
require.NoError(t, err, "Operation should succeed")
|
||||
} else {
|
||||
assert.Error(t, err, "Operation should fail")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
386
client/ssh/server/port_forwarding.go
Normal file
386
client/ssh/server/port_forwarding.go
Normal file
@@ -0,0 +1,386 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"strconv"
|
||||
|
||||
"github.com/gliderlabs/ssh"
|
||||
log "github.com/sirupsen/logrus"
|
||||
cryptossh "golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
// SessionKey uniquely identifies an SSH session
|
||||
type SessionKey string
|
||||
|
||||
// ConnectionKey uniquely identifies a port forwarding connection within a session
|
||||
type ConnectionKey string
|
||||
|
||||
// ForwardKey uniquely identifies a port forwarding listener
|
||||
type ForwardKey string
|
||||
|
||||
// tcpipForwardMsg represents the structure for tcpip-forward SSH requests
|
||||
type tcpipForwardMsg struct {
|
||||
Host string
|
||||
Port uint32
|
||||
}
|
||||
|
||||
// SetAllowLocalPortForwarding configures local port forwarding
|
||||
func (s *Server) SetAllowLocalPortForwarding(allow bool) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.allowLocalPortForwarding = allow
|
||||
}
|
||||
|
||||
// SetAllowRemotePortForwarding configures remote port forwarding
|
||||
func (s *Server) SetAllowRemotePortForwarding(allow bool) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.allowRemotePortForwarding = allow
|
||||
}
|
||||
|
||||
// configurePortForwarding sets up port forwarding callbacks
|
||||
func (s *Server) configurePortForwarding(server *ssh.Server) {
|
||||
allowLocal := s.allowLocalPortForwarding
|
||||
allowRemote := s.allowRemotePortForwarding
|
||||
|
||||
server.LocalPortForwardingCallback = func(ctx ssh.Context, dstHost string, dstPort uint32) bool {
|
||||
if !allowLocal {
|
||||
log.Warnf("local port forwarding denied for %s from %s: disabled by configuration",
|
||||
net.JoinHostPort(dstHost, fmt.Sprintf("%d", dstPort)), ctx.RemoteAddr())
|
||||
return false
|
||||
}
|
||||
|
||||
if err := s.checkPortForwardingPrivileges(ctx, "local", dstPort); err != nil {
|
||||
log.Warnf("local port forwarding denied for %s:%d from %s: %v", dstHost, dstPort, ctx.RemoteAddr(), err)
|
||||
return false
|
||||
}
|
||||
|
||||
log.Debugf("local port forwarding allowed: %s:%d", dstHost, dstPort)
|
||||
return true
|
||||
}
|
||||
|
||||
server.ReversePortForwardingCallback = func(ctx ssh.Context, bindHost string, bindPort uint32) bool {
|
||||
if !allowRemote {
|
||||
log.Warnf("remote port forwarding denied for %s from %s: disabled by configuration",
|
||||
net.JoinHostPort(bindHost, fmt.Sprintf("%d", bindPort)), ctx.RemoteAddr())
|
||||
return false
|
||||
}
|
||||
|
||||
if err := s.checkPortForwardingPrivileges(ctx, "remote", bindPort); err != nil {
|
||||
log.Warnf("remote port forwarding denied for %s:%d from %s: %v", bindHost, bindPort, ctx.RemoteAddr(), err)
|
||||
return false
|
||||
}
|
||||
|
||||
log.Debugf("remote port forwarding allowed: %s:%d", bindHost, bindPort)
|
||||
return true
|
||||
}
|
||||
|
||||
log.Debugf("SSH server configured with local_forwarding=%v, remote_forwarding=%v", allowLocal, allowRemote)
|
||||
}
|
||||
|
||||
// checkPortForwardingPrivileges validates privilege requirements for port forwarding operations.
|
||||
// Returns nil if allowed, error if denied.
|
||||
func (s *Server) checkPortForwardingPrivileges(ctx ssh.Context, forwardType string, port uint32) error {
|
||||
if ctx == nil {
|
||||
return fmt.Errorf("%s port forwarding denied: no context", forwardType)
|
||||
}
|
||||
|
||||
username := ctx.User()
|
||||
remoteAddr := "unknown"
|
||||
if ctx.RemoteAddr() != nil {
|
||||
remoteAddr = ctx.RemoteAddr().String()
|
||||
}
|
||||
|
||||
logger := log.WithFields(log.Fields{"user": username, "remote": remoteAddr, "port": port})
|
||||
|
||||
result := s.CheckPrivileges(PrivilegeCheckRequest{
|
||||
RequestedUsername: username,
|
||||
FeatureSupportsUserSwitch: false,
|
||||
FeatureName: forwardType + " port forwarding",
|
||||
})
|
||||
|
||||
if !result.Allowed {
|
||||
return result.Error
|
||||
}
|
||||
|
||||
logger.Debugf("%s port forwarding allowed: user %s validated (port %d)",
|
||||
forwardType, result.User.Username, port)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// tcpipForwardHandler handles tcpip-forward requests for remote port forwarding.
|
||||
func (s *Server) tcpipForwardHandler(ctx ssh.Context, _ *ssh.Server, req *cryptossh.Request) (bool, []byte) {
|
||||
logger := s.getRequestLogger(ctx)
|
||||
|
||||
if !s.isRemotePortForwardingAllowed() {
|
||||
logger.Warnf("tcpip-forward request denied: remote port forwarding disabled")
|
||||
return false, nil
|
||||
}
|
||||
|
||||
payload, err := s.parseTcpipForwardRequest(req)
|
||||
if err != nil {
|
||||
logger.Errorf("tcpip-forward unmarshal error: %v", err)
|
||||
return false, nil
|
||||
}
|
||||
|
||||
if err := s.checkPortForwardingPrivileges(ctx, "tcpip-forward", payload.Port); err != nil {
|
||||
logger.Warnf("tcpip-forward denied: %v", err)
|
||||
return false, nil
|
||||
}
|
||||
|
||||
logger.Debugf("tcpip-forward request: %s:%d", payload.Host, payload.Port)
|
||||
|
||||
sshConn, err := s.getSSHConnection(ctx)
|
||||
if err != nil {
|
||||
logger.Warnf("tcpip-forward request denied: %v", err)
|
||||
return false, nil
|
||||
}
|
||||
|
||||
return s.setupDirectForward(ctx, logger, sshConn, payload)
|
||||
}
|
||||
|
||||
// cancelTcpipForwardHandler handles cancel-tcpip-forward requests.
|
||||
func (s *Server) cancelTcpipForwardHandler(ctx ssh.Context, _ *ssh.Server, req *cryptossh.Request) (bool, []byte) {
|
||||
logger := s.getRequestLogger(ctx)
|
||||
|
||||
var payload tcpipForwardMsg
|
||||
if err := cryptossh.Unmarshal(req.Payload, &payload); err != nil {
|
||||
logger.Errorf("cancel-tcpip-forward unmarshal error: %v", err)
|
||||
return false, nil
|
||||
}
|
||||
|
||||
key := ForwardKey(fmt.Sprintf("%s:%d", payload.Host, payload.Port))
|
||||
if s.removeRemoteForwardListener(key) {
|
||||
logger.Infof("remote port forwarding cancelled: %s:%d", payload.Host, payload.Port)
|
||||
return true, nil
|
||||
}
|
||||
|
||||
logger.Warnf("cancel-tcpip-forward failed: no listener found for %s:%d", payload.Host, payload.Port)
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// handleRemoteForwardListener handles incoming connections for remote port forwarding.
|
||||
func (s *Server) handleRemoteForwardListener(ctx ssh.Context, ln net.Listener, host string, port uint32) {
|
||||
log.Debugf("starting remote forward listener handler for %s:%d", host, port)
|
||||
|
||||
defer func() {
|
||||
log.Debugf("cleaning up remote forward listener for %s:%d", host, port)
|
||||
if err := ln.Close(); err != nil {
|
||||
log.Debugf("remote forward listener close error: %v", err)
|
||||
} else {
|
||||
log.Debugf("remote forward listener closed successfully for %s:%d", host, port)
|
||||
}
|
||||
}()
|
||||
|
||||
acceptChan := make(chan acceptResult, 1)
|
||||
|
||||
go func() {
|
||||
for {
|
||||
conn, err := ln.Accept()
|
||||
select {
|
||||
case acceptChan <- acceptResult{conn: conn, err: err}:
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
for {
|
||||
select {
|
||||
case result := <-acceptChan:
|
||||
if result.err != nil {
|
||||
log.Debugf("remote forward accept error: %v", result.err)
|
||||
return
|
||||
}
|
||||
go s.handleRemoteForwardConnection(ctx, result.conn, host, port)
|
||||
case <-ctx.Done():
|
||||
log.Debugf("remote forward listener shutting down due to context cancellation for %s:%d", host, port)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// getRequestLogger creates a logger with user and remote address context
|
||||
func (s *Server) getRequestLogger(ctx ssh.Context) *log.Entry {
|
||||
remoteAddr := "unknown"
|
||||
username := "unknown"
|
||||
if ctx != nil {
|
||||
if ctx.RemoteAddr() != nil {
|
||||
remoteAddr = ctx.RemoteAddr().String()
|
||||
}
|
||||
username = ctx.User()
|
||||
}
|
||||
return log.WithFields(log.Fields{"user": username, "remote": remoteAddr})
|
||||
}
|
||||
|
||||
// isRemotePortForwardingAllowed checks if remote port forwarding is enabled
|
||||
func (s *Server) isRemotePortForwardingAllowed() bool {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
return s.allowRemotePortForwarding
|
||||
}
|
||||
|
||||
// parseTcpipForwardRequest parses the SSH request payload
|
||||
func (s *Server) parseTcpipForwardRequest(req *cryptossh.Request) (*tcpipForwardMsg, error) {
|
||||
var payload tcpipForwardMsg
|
||||
err := cryptossh.Unmarshal(req.Payload, &payload)
|
||||
return &payload, err
|
||||
}
|
||||
|
||||
// getSSHConnection extracts SSH connection from context
|
||||
func (s *Server) getSSHConnection(ctx ssh.Context) (*cryptossh.ServerConn, error) {
|
||||
if ctx == nil {
|
||||
return nil, fmt.Errorf("no context")
|
||||
}
|
||||
sshConnValue := ctx.Value(ssh.ContextKeyConn)
|
||||
if sshConnValue == nil {
|
||||
return nil, fmt.Errorf("no SSH connection in context")
|
||||
}
|
||||
sshConn, ok := sshConnValue.(*cryptossh.ServerConn)
|
||||
if !ok || sshConn == nil {
|
||||
return nil, fmt.Errorf("invalid SSH connection in context")
|
||||
}
|
||||
return sshConn, nil
|
||||
}
|
||||
|
||||
// setupDirectForward sets up a direct port forward
|
||||
func (s *Server) setupDirectForward(ctx ssh.Context, logger *log.Entry, sshConn *cryptossh.ServerConn, payload *tcpipForwardMsg) (bool, []byte) {
|
||||
bindAddr := net.JoinHostPort(payload.Host, strconv.FormatUint(uint64(payload.Port), 10))
|
||||
|
||||
ln, err := net.Listen("tcp", bindAddr)
|
||||
if err != nil {
|
||||
logger.Errorf("tcpip-forward listen failed on %s: %v", bindAddr, err)
|
||||
return false, nil
|
||||
}
|
||||
|
||||
actualPort := payload.Port
|
||||
if payload.Port == 0 {
|
||||
tcpAddr := ln.Addr().(*net.TCPAddr)
|
||||
actualPort = uint32(tcpAddr.Port)
|
||||
logger.Debugf("tcpip-forward allocated port %d for %s", actualPort, payload.Host)
|
||||
}
|
||||
|
||||
key := ForwardKey(fmt.Sprintf("%s:%d", payload.Host, payload.Port))
|
||||
s.storeRemoteForwardListener(key, ln)
|
||||
|
||||
s.markConnectionActivePortForward(sshConn, ctx.User(), ctx.RemoteAddr().String())
|
||||
go s.handleRemoteForwardListener(ctx, ln, payload.Host, actualPort)
|
||||
|
||||
response := make([]byte, 4)
|
||||
binary.BigEndian.PutUint32(response, actualPort)
|
||||
|
||||
logger.Infof("remote port forwarding established: %s:%d", payload.Host, actualPort)
|
||||
return true, response
|
||||
}
|
||||
|
||||
// acceptResult holds the result of a listener Accept() call
|
||||
type acceptResult struct {
|
||||
conn net.Conn
|
||||
err error
|
||||
}
|
||||
|
||||
// handleRemoteForwardConnection handles a single remote port forwarding connection
|
||||
func (s *Server) handleRemoteForwardConnection(ctx ssh.Context, conn net.Conn, host string, port uint32) {
|
||||
sessionKey := s.findSessionKeyByContext(ctx)
|
||||
connID := fmt.Sprintf("pf-%s->%s:%d", conn.RemoteAddr(), host, port)
|
||||
logger := log.WithFields(log.Fields{
|
||||
"session": sessionKey,
|
||||
"conn": connID,
|
||||
})
|
||||
|
||||
defer func() {
|
||||
if err := conn.Close(); err != nil {
|
||||
logger.Debugf("connection close error: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
sshConn := ctx.Value(ssh.ContextKeyConn).(*cryptossh.ServerConn)
|
||||
if sshConn == nil {
|
||||
logger.Debugf("remote forward: no SSH connection in context")
|
||||
return
|
||||
}
|
||||
|
||||
remoteAddr, ok := conn.RemoteAddr().(*net.TCPAddr)
|
||||
if !ok {
|
||||
logger.Warnf("remote forward: non-TCP connection type: %T", conn.RemoteAddr())
|
||||
return
|
||||
}
|
||||
|
||||
channel, err := s.openForwardChannel(sshConn, host, port, remoteAddr, logger)
|
||||
if err != nil {
|
||||
logger.Debugf("open forward channel: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
s.proxyForwardConnection(ctx, logger, conn, channel)
|
||||
}
|
||||
|
||||
// openForwardChannel creates an SSH forwarded-tcpip channel
|
||||
func (s *Server) openForwardChannel(sshConn *cryptossh.ServerConn, host string, port uint32, remoteAddr *net.TCPAddr, logger *log.Entry) (cryptossh.Channel, error) {
|
||||
logger.Tracef("opening forwarded-tcpip channel for %s:%d", host, port)
|
||||
|
||||
payload := struct {
|
||||
ConnectedAddress string
|
||||
ConnectedPort uint32
|
||||
OriginatorAddress string
|
||||
OriginatorPort uint32
|
||||
}{
|
||||
ConnectedAddress: host,
|
||||
ConnectedPort: port,
|
||||
OriginatorAddress: remoteAddr.IP.String(),
|
||||
OriginatorPort: uint32(remoteAddr.Port),
|
||||
}
|
||||
|
||||
channel, reqs, err := sshConn.OpenChannel("forwarded-tcpip", cryptossh.Marshal(&payload))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("open SSH channel: %w", err)
|
||||
}
|
||||
|
||||
go cryptossh.DiscardRequests(reqs)
|
||||
return channel, nil
|
||||
}
|
||||
|
||||
// proxyForwardConnection handles bidirectional data transfer between connection and SSH channel
|
||||
func (s *Server) proxyForwardConnection(ctx ssh.Context, logger *log.Entry, conn net.Conn, channel cryptossh.Channel) {
|
||||
done := make(chan struct{}, 2)
|
||||
|
||||
go func() {
|
||||
if _, err := io.Copy(channel, conn); err != nil {
|
||||
logger.Debugf("copy error (conn->channel): %v", err)
|
||||
}
|
||||
done <- struct{}{}
|
||||
}()
|
||||
|
||||
go func() {
|
||||
if _, err := io.Copy(conn, channel); err != nil {
|
||||
logger.Debugf("copy error (channel->conn): %v", err)
|
||||
}
|
||||
done <- struct{}{}
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
logger.Debugf("session ended, closing connections")
|
||||
case <-done:
|
||||
// First copy finished, wait for second copy or context cancellation
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
logger.Debugf("session ended, closing connections")
|
||||
case <-done:
|
||||
}
|
||||
}
|
||||
|
||||
if err := channel.Close(); err != nil {
|
||||
logger.Debugf("channel close error: %v", err)
|
||||
}
|
||||
if err := conn.Close(); err != nil {
|
||||
logger.Debugf("connection close error: %v", err)
|
||||
}
|
||||
}
|
||||
649
client/ssh/server/server.go
Normal file
649
client/ssh/server/server.go
Normal file
@@ -0,0 +1,649 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/netip"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gliderlabs/ssh"
|
||||
gojwt "github.com/golang-jwt/jwt/v5"
|
||||
log "github.com/sirupsen/logrus"
|
||||
cryptossh "golang.org/x/crypto/ssh"
|
||||
"golang.zx2c4.com/wireguard/tun/netstack"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
"github.com/netbirdio/netbird/client/ssh/detection"
|
||||
"github.com/netbirdio/netbird/shared/auth"
|
||||
"github.com/netbirdio/netbird/shared/auth/jwt"
|
||||
"github.com/netbirdio/netbird/version"
|
||||
)
|
||||
|
||||
// DefaultSSHPort is the default SSH port of the NetBird's embedded SSH server
|
||||
const DefaultSSHPort = 22
|
||||
|
||||
// InternalSSHPort is the port SSH server listens on and is redirected to
|
||||
const InternalSSHPort = 22022
|
||||
|
||||
const (
|
||||
errWriteSession = "write session error: %v"
|
||||
errExitSession = "exit session error: %v"
|
||||
|
||||
msgPrivilegedUserDisabled = "privileged user login is disabled"
|
||||
|
||||
// DefaultJWTMaxTokenAge is the default maximum age for JWT tokens accepted by the SSH server
|
||||
DefaultJWTMaxTokenAge = 5 * 60
|
||||
)
|
||||
|
||||
var (
|
||||
ErrPrivilegedUserDisabled = errors.New(msgPrivilegedUserDisabled)
|
||||
ErrUserNotFound = errors.New("user not found")
|
||||
)
|
||||
|
||||
// PrivilegedUserError represents an error when privileged user login is disabled
|
||||
type PrivilegedUserError struct {
|
||||
Username string
|
||||
}
|
||||
|
||||
func (e *PrivilegedUserError) Error() string {
|
||||
return fmt.Sprintf("%s for user: %s", msgPrivilegedUserDisabled, e.Username)
|
||||
}
|
||||
|
||||
func (e *PrivilegedUserError) Is(target error) bool {
|
||||
return target == ErrPrivilegedUserDisabled
|
||||
}
|
||||
|
||||
// UserNotFoundError represents an error when a user cannot be found
|
||||
type UserNotFoundError struct {
|
||||
Username string
|
||||
Cause error
|
||||
}
|
||||
|
||||
func (e *UserNotFoundError) Error() string {
|
||||
if e.Cause != nil {
|
||||
return fmt.Sprintf("user %s not found: %v", e.Username, e.Cause)
|
||||
}
|
||||
return fmt.Sprintf("user %s not found", e.Username)
|
||||
}
|
||||
|
||||
func (e *UserNotFoundError) Is(target error) bool {
|
||||
return target == ErrUserNotFound
|
||||
}
|
||||
|
||||
func (e *UserNotFoundError) Unwrap() error {
|
||||
return e.Cause
|
||||
}
|
||||
|
||||
// logSessionExitError logs session exit errors, ignoring EOF (normal close) errors
|
||||
func logSessionExitError(logger *log.Entry, err error) {
|
||||
if err != nil && !errors.Is(err, io.EOF) {
|
||||
logger.Warnf(errExitSession, err)
|
||||
}
|
||||
}
|
||||
|
||||
// safeLogCommand returns a safe representation of the command for logging
|
||||
func safeLogCommand(cmd []string) string {
|
||||
if len(cmd) == 0 {
|
||||
return "<interactive shell>"
|
||||
}
|
||||
if len(cmd) == 1 {
|
||||
return cmd[0]
|
||||
}
|
||||
return fmt.Sprintf("%s [%d args]", cmd[0], len(cmd)-1)
|
||||
}
|
||||
|
||||
type sshConnectionState struct {
|
||||
hasActivePortForward bool
|
||||
username string
|
||||
remoteAddr string
|
||||
}
|
||||
|
||||
type Server struct {
|
||||
sshServer *ssh.Server
|
||||
mu sync.RWMutex
|
||||
hostKeyPEM []byte
|
||||
sessions map[SessionKey]ssh.Session
|
||||
sessionCancels map[ConnectionKey]context.CancelFunc
|
||||
|
||||
allowLocalPortForwarding bool
|
||||
allowRemotePortForwarding bool
|
||||
allowRootLogin bool
|
||||
allowSFTP bool
|
||||
jwtEnabled bool
|
||||
|
||||
netstackNet *netstack.Net
|
||||
|
||||
wgAddress wgaddr.Address
|
||||
|
||||
remoteForwardListeners map[ForwardKey]net.Listener
|
||||
sshConnections map[*cryptossh.ServerConn]*sshConnectionState
|
||||
|
||||
jwtValidator *jwt.Validator
|
||||
jwtExtractor *jwt.ClaimsExtractor
|
||||
jwtConfig *JWTConfig
|
||||
}
|
||||
|
||||
type JWTConfig struct {
|
||||
Issuer string
|
||||
Audience string
|
||||
KeysLocation string
|
||||
MaxTokenAge int64
|
||||
}
|
||||
|
||||
// Config contains all SSH server configuration options
|
||||
type Config struct {
|
||||
// JWT authentication configuration. If nil, JWT authentication is disabled
|
||||
JWT *JWTConfig
|
||||
|
||||
// HostKey is the SSH server host key in PEM format
|
||||
HostKeyPEM []byte
|
||||
}
|
||||
|
||||
// New creates an SSH server instance with the provided host key and optional JWT configuration
|
||||
// If jwtConfig is nil, JWT authentication is disabled
|
||||
func New(config *Config) *Server {
|
||||
s := &Server{
|
||||
mu: sync.RWMutex{},
|
||||
hostKeyPEM: config.HostKeyPEM,
|
||||
sessions: make(map[SessionKey]ssh.Session),
|
||||
remoteForwardListeners: make(map[ForwardKey]net.Listener),
|
||||
sshConnections: make(map[*cryptossh.ServerConn]*sshConnectionState),
|
||||
jwtEnabled: config.JWT != nil,
|
||||
jwtConfig: config.JWT,
|
||||
}
|
||||
|
||||
return s
|
||||
}
|
||||
|
||||
// Start runs the SSH server
|
||||
func (s *Server) Start(ctx context.Context, addr netip.AddrPort) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
if s.sshServer != nil {
|
||||
return errors.New("SSH server is already running")
|
||||
}
|
||||
|
||||
ln, addrDesc, err := s.createListener(ctx, addr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create listener: %w", err)
|
||||
}
|
||||
|
||||
sshServer, err := s.createSSHServer(ln.Addr())
|
||||
if err != nil {
|
||||
s.closeListener(ln)
|
||||
return fmt.Errorf("create SSH server: %w", err)
|
||||
}
|
||||
|
||||
s.sshServer = sshServer
|
||||
log.Infof("SSH server started on %s", addrDesc)
|
||||
|
||||
go func() {
|
||||
if err := sshServer.Serve(ln); !isShutdownError(err) {
|
||||
log.Errorf("SSH server error: %v", err)
|
||||
}
|
||||
}()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Server) createListener(ctx context.Context, addr netip.AddrPort) (net.Listener, string, error) {
|
||||
if s.netstackNet != nil {
|
||||
ln, err := s.netstackNet.ListenTCPAddrPort(addr)
|
||||
if err != nil {
|
||||
return nil, "", fmt.Errorf("listen on netstack: %w", err)
|
||||
}
|
||||
return ln, fmt.Sprintf("netstack %s", addr), nil
|
||||
}
|
||||
|
||||
tcpAddr := net.TCPAddrFromAddrPort(addr)
|
||||
lc := net.ListenConfig{}
|
||||
ln, err := lc.Listen(ctx, "tcp", tcpAddr.String())
|
||||
if err != nil {
|
||||
return nil, "", fmt.Errorf("listen: %w", err)
|
||||
}
|
||||
return ln, addr.String(), nil
|
||||
}
|
||||
|
||||
func (s *Server) closeListener(ln net.Listener) {
|
||||
if ln == nil {
|
||||
return
|
||||
}
|
||||
if err := ln.Close(); err != nil {
|
||||
log.Debugf("listener close error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Stop closes the SSH server
|
||||
func (s *Server) Stop() error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
if s.sshServer == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := s.sshServer.Close(); err != nil && !isShutdownError(err) {
|
||||
return fmt.Errorf("shutdown SSH server: %w", err)
|
||||
}
|
||||
|
||||
s.sshServer = nil
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetNetstackNet sets the netstack network for userspace networking
|
||||
func (s *Server) SetNetstackNet(net *netstack.Net) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.netstackNet = net
|
||||
}
|
||||
|
||||
// SetNetworkValidation configures network-based connection filtering
|
||||
func (s *Server) SetNetworkValidation(addr wgaddr.Address) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.wgAddress = addr
|
||||
}
|
||||
|
||||
// ensureJWTValidator initializes the JWT validator and extractor if not already initialized
|
||||
func (s *Server) ensureJWTValidator() error {
|
||||
s.mu.RLock()
|
||||
if s.jwtValidator != nil && s.jwtExtractor != nil {
|
||||
s.mu.RUnlock()
|
||||
return nil
|
||||
}
|
||||
config := s.jwtConfig
|
||||
s.mu.RUnlock()
|
||||
|
||||
if config == nil {
|
||||
return fmt.Errorf("JWT config not set")
|
||||
}
|
||||
|
||||
log.Debugf("Initializing JWT validator (issuer: %s, audience: %s)", config.Issuer, config.Audience)
|
||||
|
||||
validator := jwt.NewValidator(
|
||||
config.Issuer,
|
||||
[]string{config.Audience},
|
||||
config.KeysLocation,
|
||||
true,
|
||||
)
|
||||
|
||||
extractor := jwt.NewClaimsExtractor(
|
||||
jwt.WithAudience(config.Audience),
|
||||
)
|
||||
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
if s.jwtValidator != nil && s.jwtExtractor != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
s.jwtValidator = validator
|
||||
s.jwtExtractor = extractor
|
||||
|
||||
log.Infof("JWT validator initialized successfully")
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Server) validateJWTToken(tokenString string) (*gojwt.Token, error) {
|
||||
s.mu.RLock()
|
||||
jwtValidator := s.jwtValidator
|
||||
jwtConfig := s.jwtConfig
|
||||
s.mu.RUnlock()
|
||||
|
||||
if jwtValidator == nil {
|
||||
return nil, fmt.Errorf("JWT validator not initialized")
|
||||
}
|
||||
|
||||
token, err := jwtValidator.ValidateAndParse(context.Background(), tokenString)
|
||||
if err != nil {
|
||||
if jwtConfig != nil {
|
||||
if claims, parseErr := s.parseTokenWithoutValidation(tokenString); parseErr == nil {
|
||||
return nil, fmt.Errorf("validate token (expected issuer=%s, audience=%s, actual issuer=%v, audience=%v): %w",
|
||||
jwtConfig.Issuer, jwtConfig.Audience, claims["iss"], claims["aud"], err)
|
||||
}
|
||||
}
|
||||
return nil, fmt.Errorf("validate token: %w", err)
|
||||
}
|
||||
|
||||
if err := s.checkTokenAge(token, jwtConfig); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return token, nil
|
||||
}
|
||||
|
||||
func (s *Server) checkTokenAge(token *gojwt.Token, jwtConfig *JWTConfig) error {
|
||||
if jwtConfig == nil || jwtConfig.MaxTokenAge <= 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
claims, ok := token.Claims.(gojwt.MapClaims)
|
||||
if !ok {
|
||||
userID := extractUserID(token)
|
||||
return fmt.Errorf("token has invalid claims format (user=%s)", userID)
|
||||
}
|
||||
|
||||
iat, ok := claims["iat"].(float64)
|
||||
if !ok {
|
||||
userID := extractUserID(token)
|
||||
return fmt.Errorf("token missing iat claim (user=%s)", userID)
|
||||
}
|
||||
|
||||
issuedAt := time.Unix(int64(iat), 0)
|
||||
tokenAge := time.Since(issuedAt)
|
||||
maxAge := time.Duration(jwtConfig.MaxTokenAge) * time.Second
|
||||
if tokenAge > maxAge {
|
||||
userID := getUserIDFromClaims(claims)
|
||||
return fmt.Errorf("token expired for user=%s: age=%v, max=%v", userID, tokenAge, maxAge)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Server) extractAndValidateUser(token *gojwt.Token) (*auth.UserAuth, error) {
|
||||
s.mu.RLock()
|
||||
jwtExtractor := s.jwtExtractor
|
||||
s.mu.RUnlock()
|
||||
|
||||
if jwtExtractor == nil {
|
||||
userID := extractUserID(token)
|
||||
return nil, fmt.Errorf("JWT extractor not initialized (user=%s)", userID)
|
||||
}
|
||||
|
||||
userAuth, err := jwtExtractor.ToUserAuth(token)
|
||||
if err != nil {
|
||||
userID := extractUserID(token)
|
||||
return nil, fmt.Errorf("extract user from token (user=%s): %w", userID, err)
|
||||
}
|
||||
|
||||
if !s.hasSSHAccess(&userAuth) {
|
||||
return nil, fmt.Errorf("user %s does not have SSH access permissions", userAuth.UserId)
|
||||
}
|
||||
|
||||
return &userAuth, nil
|
||||
}
|
||||
|
||||
func (s *Server) hasSSHAccess(userAuth *auth.UserAuth) bool {
|
||||
return userAuth.UserId != ""
|
||||
}
|
||||
|
||||
func extractUserID(token *gojwt.Token) string {
|
||||
if token == nil {
|
||||
return "unknown"
|
||||
}
|
||||
claims, ok := token.Claims.(gojwt.MapClaims)
|
||||
if !ok {
|
||||
return "unknown"
|
||||
}
|
||||
return getUserIDFromClaims(claims)
|
||||
}
|
||||
|
||||
func getUserIDFromClaims(claims gojwt.MapClaims) string {
|
||||
if sub, ok := claims["sub"].(string); ok && sub != "" {
|
||||
return sub
|
||||
}
|
||||
if userID, ok := claims["user_id"].(string); ok && userID != "" {
|
||||
return userID
|
||||
}
|
||||
if email, ok := claims["email"].(string); ok && email != "" {
|
||||
return email
|
||||
}
|
||||
return "unknown"
|
||||
}
|
||||
|
||||
func (s *Server) parseTokenWithoutValidation(tokenString string) (map[string]interface{}, error) {
|
||||
parts := strings.Split(tokenString, ".")
|
||||
if len(parts) != 3 {
|
||||
return nil, fmt.Errorf("invalid token format")
|
||||
}
|
||||
|
||||
payload, err := base64.RawURLEncoding.DecodeString(parts[1])
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("decode payload: %w", err)
|
||||
}
|
||||
|
||||
var claims map[string]interface{}
|
||||
if err := json.Unmarshal(payload, &claims); err != nil {
|
||||
return nil, fmt.Errorf("parse claims: %w", err)
|
||||
}
|
||||
|
||||
return claims, nil
|
||||
}
|
||||
|
||||
func (s *Server) passwordHandler(ctx ssh.Context, password string) bool {
|
||||
if err := s.ensureJWTValidator(); err != nil {
|
||||
log.Errorf("JWT validator initialization failed for user %s from %s: %v", ctx.User(), ctx.RemoteAddr(), err)
|
||||
return false
|
||||
}
|
||||
|
||||
token, err := s.validateJWTToken(password)
|
||||
if err != nil {
|
||||
log.Warnf("JWT authentication failed for user %s from %s: %v", ctx.User(), ctx.RemoteAddr(), err)
|
||||
return false
|
||||
}
|
||||
|
||||
userAuth, err := s.extractAndValidateUser(token)
|
||||
if err != nil {
|
||||
log.Warnf("User validation failed for user %s from %s: %v", ctx.User(), ctx.RemoteAddr(), err)
|
||||
return false
|
||||
}
|
||||
|
||||
log.Infof("JWT authentication successful for user %s (JWT user ID: %s) from %s", ctx.User(), userAuth.UserId, ctx.RemoteAddr())
|
||||
return true
|
||||
}
|
||||
|
||||
func (s *Server) markConnectionActivePortForward(sshConn *cryptossh.ServerConn, username, remoteAddr string) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
if state, exists := s.sshConnections[sshConn]; exists {
|
||||
state.hasActivePortForward = true
|
||||
} else {
|
||||
s.sshConnections[sshConn] = &sshConnectionState{
|
||||
hasActivePortForward: true,
|
||||
username: username,
|
||||
remoteAddr: remoteAddr,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) connectionCloseHandler(conn net.Conn, err error) {
|
||||
// We can't extract the SSH connection from net.Conn directly
|
||||
// Connection cleanup will happen during session cleanup or via timeout
|
||||
log.Debugf("SSH connection failed for %s: %v", conn.RemoteAddr(), err)
|
||||
}
|
||||
|
||||
func (s *Server) findSessionKeyByContext(ctx ssh.Context) SessionKey {
|
||||
if ctx == nil {
|
||||
return "unknown"
|
||||
}
|
||||
|
||||
// Try to match by SSH connection
|
||||
sshConn := ctx.Value(ssh.ContextKeyConn)
|
||||
if sshConn == nil {
|
||||
return "unknown"
|
||||
}
|
||||
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
// Look through sessions to find one with matching connection
|
||||
for sessionKey, session := range s.sessions {
|
||||
if session.Context().Value(ssh.ContextKeyConn) == sshConn {
|
||||
return sessionKey
|
||||
}
|
||||
}
|
||||
|
||||
// If no session found, this might be during early connection setup
|
||||
// Return a temporary key that we'll fix up later
|
||||
if ctx.User() != "" && ctx.RemoteAddr() != nil {
|
||||
tempKey := SessionKey(fmt.Sprintf("%s@%s", ctx.User(), ctx.RemoteAddr().String()))
|
||||
log.Debugf("Using temporary session key for early port forward tracking: %s (will be updated when session established)", tempKey)
|
||||
return tempKey
|
||||
}
|
||||
|
||||
return "unknown"
|
||||
}
|
||||
|
||||
func (s *Server) connectionValidator(_ ssh.Context, conn net.Conn) net.Conn {
|
||||
s.mu.RLock()
|
||||
netbirdNetwork := s.wgAddress.Network
|
||||
localIP := s.wgAddress.IP
|
||||
s.mu.RUnlock()
|
||||
|
||||
if !netbirdNetwork.IsValid() || !localIP.IsValid() {
|
||||
return conn
|
||||
}
|
||||
|
||||
remoteAddr := conn.RemoteAddr()
|
||||
tcpAddr, ok := remoteAddr.(*net.TCPAddr)
|
||||
if !ok {
|
||||
log.Warnf("SSH connection rejected: non-TCP address %s", remoteAddr)
|
||||
return nil
|
||||
}
|
||||
|
||||
remoteIP, ok := netip.AddrFromSlice(tcpAddr.IP)
|
||||
if !ok {
|
||||
log.Warnf("SSH connection rejected: invalid remote IP %s", tcpAddr.IP)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Block connections from our own IP (prevent local apps from connecting to ourselves)
|
||||
if remoteIP == localIP {
|
||||
log.Warnf("SSH connection rejected from own IP %s", remoteIP)
|
||||
return nil
|
||||
}
|
||||
|
||||
if !netbirdNetwork.Contains(remoteIP) {
|
||||
log.Warnf("SSH connection rejected from non-NetBird IP %s", remoteIP)
|
||||
return nil
|
||||
}
|
||||
|
||||
log.Infof("SSH connection from NetBird peer %s allowed", tcpAddr)
|
||||
return conn
|
||||
}
|
||||
|
||||
func isShutdownError(err error) bool {
|
||||
if errors.Is(err, net.ErrClosed) {
|
||||
return true
|
||||
}
|
||||
|
||||
var opErr *net.OpError
|
||||
if errors.As(err, &opErr) && opErr.Op == "accept" {
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func (s *Server) createSSHServer(addr net.Addr) (*ssh.Server, error) {
|
||||
if err := enableUserSwitching(); err != nil {
|
||||
log.Warnf("failed to enable user switching: %v", err)
|
||||
}
|
||||
|
||||
serverVersion := fmt.Sprintf("%s-%s", detection.ServerIdentifier, version.NetbirdVersion())
|
||||
if s.jwtEnabled {
|
||||
serverVersion += " " + detection.JWTRequiredMarker
|
||||
}
|
||||
|
||||
server := &ssh.Server{
|
||||
Addr: addr.String(),
|
||||
Handler: s.sessionHandler,
|
||||
SubsystemHandlers: map[string]ssh.SubsystemHandler{
|
||||
"sftp": s.sftpSubsystemHandler,
|
||||
},
|
||||
HostSigners: []ssh.Signer{},
|
||||
ChannelHandlers: map[string]ssh.ChannelHandler{
|
||||
"session": ssh.DefaultSessionHandler,
|
||||
"direct-tcpip": s.directTCPIPHandler,
|
||||
},
|
||||
RequestHandlers: map[string]ssh.RequestHandler{
|
||||
"tcpip-forward": s.tcpipForwardHandler,
|
||||
"cancel-tcpip-forward": s.cancelTcpipForwardHandler,
|
||||
},
|
||||
ConnCallback: s.connectionValidator,
|
||||
ConnectionFailedCallback: s.connectionCloseHandler,
|
||||
Version: serverVersion,
|
||||
}
|
||||
|
||||
if s.jwtEnabled {
|
||||
server.PasswordHandler = s.passwordHandler
|
||||
}
|
||||
|
||||
hostKeyPEM := ssh.HostKeyPEM(s.hostKeyPEM)
|
||||
if err := server.SetOption(hostKeyPEM); err != nil {
|
||||
return nil, fmt.Errorf("set host key: %w", err)
|
||||
}
|
||||
|
||||
s.configurePortForwarding(server)
|
||||
return server, nil
|
||||
}
|
||||
|
||||
func (s *Server) storeRemoteForwardListener(key ForwardKey, ln net.Listener) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.remoteForwardListeners[key] = ln
|
||||
}
|
||||
|
||||
func (s *Server) removeRemoteForwardListener(key ForwardKey) bool {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
ln, exists := s.remoteForwardListeners[key]
|
||||
if !exists {
|
||||
return false
|
||||
}
|
||||
|
||||
delete(s.remoteForwardListeners, key)
|
||||
if err := ln.Close(); err != nil {
|
||||
log.Debugf("remote forward listener close error: %v", err)
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func (s *Server) directTCPIPHandler(srv *ssh.Server, conn *cryptossh.ServerConn, newChan cryptossh.NewChannel, ctx ssh.Context) {
|
||||
var payload struct {
|
||||
Host string
|
||||
Port uint32
|
||||
OriginatorAddr string
|
||||
OriginatorPort uint32
|
||||
}
|
||||
|
||||
if err := cryptossh.Unmarshal(newChan.ExtraData(), &payload); err != nil {
|
||||
if err := newChan.Reject(cryptossh.ConnectionFailed, "parse payload"); err != nil {
|
||||
log.Debugf("channel reject error: %v", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
s.mu.RLock()
|
||||
allowLocal := s.allowLocalPortForwarding
|
||||
s.mu.RUnlock()
|
||||
|
||||
if !allowLocal {
|
||||
log.Warnf("local port forwarding denied for %s:%d: disabled by configuration", payload.Host, payload.Port)
|
||||
_ = newChan.Reject(cryptossh.Prohibited, "local port forwarding disabled")
|
||||
return
|
||||
}
|
||||
|
||||
// Check privilege requirements for the destination port
|
||||
if err := s.checkPortForwardingPrivileges(ctx, "local", payload.Port); err != nil {
|
||||
log.Warnf("local port forwarding denied for %s:%d: %v", payload.Host, payload.Port, err)
|
||||
_ = newChan.Reject(cryptossh.Prohibited, "insufficient privileges")
|
||||
return
|
||||
}
|
||||
|
||||
log.Infof("local port forwarding: %s:%d", payload.Host, payload.Port)
|
||||
|
||||
ssh.DirectTCPIPHandler(srv, conn, newChan, ctx)
|
||||
}
|
||||
394
client/ssh/server/server_config_test.go
Normal file
394
client/ssh/server/server_config_test.go
Normal file
@@ -0,0 +1,394 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"os/user"
|
||||
"runtime"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/client/ssh"
|
||||
sshclient "github.com/netbirdio/netbird/client/ssh/client"
|
||||
)
|
||||
|
||||
func TestServer_RootLoginRestriction(t *testing.T) {
|
||||
// Generate host key for server
|
||||
hostKey, err := ssh.GeneratePrivateKey(ssh.ED25519)
|
||||
require.NoError(t, err)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
allowRoot bool
|
||||
username string
|
||||
expectError bool
|
||||
description string
|
||||
}{
|
||||
{
|
||||
name: "root login allowed",
|
||||
allowRoot: true,
|
||||
username: "root",
|
||||
expectError: false,
|
||||
description: "Root login should succeed when allowed",
|
||||
},
|
||||
{
|
||||
name: "root login denied",
|
||||
allowRoot: false,
|
||||
username: "root",
|
||||
expectError: true,
|
||||
description: "Root login should fail when disabled",
|
||||
},
|
||||
{
|
||||
name: "regular user login always allowed",
|
||||
allowRoot: false,
|
||||
username: "testuser",
|
||||
expectError: false,
|
||||
description: "Regular user login should work regardless of root setting",
|
||||
},
|
||||
}
|
||||
|
||||
// Add Windows Administrator tests if on Windows
|
||||
if runtime.GOOS == "windows" {
|
||||
tests = append(tests, []struct {
|
||||
name string
|
||||
allowRoot bool
|
||||
username string
|
||||
expectError bool
|
||||
description string
|
||||
}{
|
||||
{
|
||||
name: "Administrator login allowed",
|
||||
allowRoot: true,
|
||||
username: "Administrator",
|
||||
expectError: false,
|
||||
description: "Administrator login should succeed when allowed",
|
||||
},
|
||||
{
|
||||
name: "Administrator login denied",
|
||||
allowRoot: false,
|
||||
username: "Administrator",
|
||||
expectError: true,
|
||||
description: "Administrator login should fail when disabled",
|
||||
},
|
||||
{
|
||||
name: "administrator login denied (lowercase)",
|
||||
allowRoot: false,
|
||||
username: "administrator",
|
||||
expectError: true,
|
||||
description: "administrator login should fail when disabled (case insensitive)",
|
||||
},
|
||||
}...)
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Mock privileged environment to test root access controls
|
||||
// Set up mock users based on platform
|
||||
mockUsers := map[string]*user.User{
|
||||
"root": createTestUser("root", "0", "0", "/root"),
|
||||
"testuser": createTestUser("testuser", "1000", "1000", "/home/testuser"),
|
||||
}
|
||||
|
||||
// Add Windows-specific users for Administrator tests
|
||||
if runtime.GOOS == "windows" {
|
||||
mockUsers["Administrator"] = createTestUser("Administrator", "500", "544", "C:\\Users\\Administrator")
|
||||
mockUsers["administrator"] = createTestUser("administrator", "500", "544", "C:\\Users\\administrator")
|
||||
}
|
||||
|
||||
cleanup := setupTestDependencies(
|
||||
createTestUser("root", "0", "0", "/root"), // Running as root
|
||||
nil,
|
||||
runtime.GOOS,
|
||||
0, // euid 0 (root)
|
||||
mockUsers,
|
||||
nil,
|
||||
)
|
||||
defer cleanup()
|
||||
|
||||
// Create server with specific configuration
|
||||
serverConfig := &Config{
|
||||
HostKeyPEM: hostKey,
|
||||
JWT: nil,
|
||||
}
|
||||
server := New(serverConfig)
|
||||
server.SetAllowRootLogin(tt.allowRoot)
|
||||
|
||||
// Test the userNameLookup method directly
|
||||
user, err := server.userNameLookup(tt.username)
|
||||
|
||||
if tt.expectError {
|
||||
assert.Error(t, err, tt.description)
|
||||
if tt.username == "root" || strings.ToLower(tt.username) == "administrator" {
|
||||
// Check for appropriate error message based on platform capabilities
|
||||
errorMsg := err.Error()
|
||||
// Either privileged user restriction OR user switching limitation
|
||||
hasPrivilegedError := strings.Contains(errorMsg, "privileged user")
|
||||
hasSwitchingError := strings.Contains(errorMsg, "cannot switch") || strings.Contains(errorMsg, "user switching not supported")
|
||||
assert.True(t, hasPrivilegedError || hasSwitchingError,
|
||||
"Expected privileged user or user switching error, got: %s", errorMsg)
|
||||
}
|
||||
} else {
|
||||
if tt.username == "root" || strings.ToLower(tt.username) == "administrator" {
|
||||
// For privileged users, we expect either success or a different error
|
||||
// (like user not found), but not the "login disabled" error
|
||||
if err != nil {
|
||||
assert.NotContains(t, err.Error(), "privileged user login is disabled")
|
||||
}
|
||||
} else {
|
||||
// For regular users, lookup should generally succeed or fall back gracefully
|
||||
// Note: may return current user as fallback
|
||||
assert.NotNil(t, user)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestServer_PortForwardingRestriction(t *testing.T) {
|
||||
// Test that the port forwarding callbacks properly respect configuration flags
|
||||
// This is a unit test of the callback logic, not a full integration test
|
||||
|
||||
// Generate host key for server
|
||||
hostKey, err := ssh.GeneratePrivateKey(ssh.ED25519)
|
||||
require.NoError(t, err)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
allowLocalForwarding bool
|
||||
allowRemoteForwarding bool
|
||||
description string
|
||||
}{
|
||||
{
|
||||
name: "all forwarding allowed",
|
||||
allowLocalForwarding: true,
|
||||
allowRemoteForwarding: true,
|
||||
description: "Both local and remote forwarding should be allowed",
|
||||
},
|
||||
{
|
||||
name: "local forwarding disabled",
|
||||
allowLocalForwarding: false,
|
||||
allowRemoteForwarding: true,
|
||||
description: "Local forwarding should be denied when disabled",
|
||||
},
|
||||
{
|
||||
name: "remote forwarding disabled",
|
||||
allowLocalForwarding: true,
|
||||
allowRemoteForwarding: false,
|
||||
description: "Remote forwarding should be denied when disabled",
|
||||
},
|
||||
{
|
||||
name: "all forwarding disabled",
|
||||
allowLocalForwarding: false,
|
||||
allowRemoteForwarding: false,
|
||||
description: "Both forwarding types should be denied when disabled",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Create server with specific configuration
|
||||
serverConfig := &Config{
|
||||
HostKeyPEM: hostKey,
|
||||
JWT: nil,
|
||||
}
|
||||
server := New(serverConfig)
|
||||
server.SetAllowLocalPortForwarding(tt.allowLocalForwarding)
|
||||
server.SetAllowRemotePortForwarding(tt.allowRemoteForwarding)
|
||||
|
||||
// We need to access the internal configuration to simulate the callback tests
|
||||
// Since the callbacks are created inside the Start method, we'll test the logic directly
|
||||
|
||||
// Test the configuration values are set correctly
|
||||
server.mu.RLock()
|
||||
allowLocal := server.allowLocalPortForwarding
|
||||
allowRemote := server.allowRemotePortForwarding
|
||||
server.mu.RUnlock()
|
||||
|
||||
assert.Equal(t, tt.allowLocalForwarding, allowLocal, "Local forwarding configuration should be set correctly")
|
||||
assert.Equal(t, tt.allowRemoteForwarding, allowRemote, "Remote forwarding configuration should be set correctly")
|
||||
|
||||
// Simulate the callback logic
|
||||
localResult := allowLocal // This would be the callback return value
|
||||
remoteResult := allowRemote // This would be the callback return value
|
||||
|
||||
assert.Equal(t, tt.allowLocalForwarding, localResult,
|
||||
"Local port forwarding callback should return correct value")
|
||||
assert.Equal(t, tt.allowRemoteForwarding, remoteResult,
|
||||
"Remote port forwarding callback should return correct value")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestServer_PortConflictHandling(t *testing.T) {
|
||||
// Test that multiple sessions requesting the same local port are handled naturally by the OS
|
||||
// Get current user for SSH connection
|
||||
currentUser, err := user.Current()
|
||||
require.NoError(t, err, "Should be able to get current user")
|
||||
|
||||
// Generate host key for server
|
||||
hostKey, err := ssh.GeneratePrivateKey(ssh.ED25519)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create server
|
||||
serverConfig := &Config{
|
||||
HostKeyPEM: hostKey,
|
||||
JWT: nil,
|
||||
}
|
||||
server := New(serverConfig)
|
||||
server.SetAllowRootLogin(true)
|
||||
|
||||
serverAddr := StartTestServer(t, server)
|
||||
defer func() {
|
||||
err := server.Stop()
|
||||
require.NoError(t, err)
|
||||
}()
|
||||
|
||||
// Get a free port for testing
|
||||
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
testPort := ln.Addr().(*net.TCPAddr).Port
|
||||
err = ln.Close()
|
||||
require.NoError(t, err)
|
||||
|
||||
// Connect first client
|
||||
ctx1, cancel1 := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel1()
|
||||
|
||||
client1, err := sshclient.Dial(ctx1, serverAddr, currentUser.Username, sshclient.DialOptions{
|
||||
InsecureSkipVerify: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
err := client1.Close()
|
||||
assert.NoError(t, err)
|
||||
}()
|
||||
|
||||
// Connect second client
|
||||
ctx2, cancel2 := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel2()
|
||||
|
||||
client2, err := sshclient.Dial(ctx2, serverAddr, currentUser.Username, sshclient.DialOptions{
|
||||
InsecureSkipVerify: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
err := client2.Close()
|
||||
assert.NoError(t, err)
|
||||
}()
|
||||
|
||||
// First client binds to the test port
|
||||
localAddr1 := fmt.Sprintf("127.0.0.1:%d", testPort)
|
||||
remoteAddr := "127.0.0.1:80"
|
||||
|
||||
// Start first client's port forwarding
|
||||
done1 := make(chan error, 1)
|
||||
go func() {
|
||||
// This should succeed and hold the port
|
||||
err := client1.LocalPortForward(ctx1, localAddr1, remoteAddr)
|
||||
done1 <- err
|
||||
}()
|
||||
|
||||
// Give first client time to bind
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
|
||||
// Second client tries to bind to same port
|
||||
localAddr2 := fmt.Sprintf("127.0.0.1:%d", testPort)
|
||||
|
||||
shortCtx, shortCancel := context.WithTimeout(context.Background(), 1*time.Second)
|
||||
defer shortCancel()
|
||||
|
||||
err = client2.LocalPortForward(shortCtx, localAddr2, remoteAddr)
|
||||
// Second client should fail due to "address already in use"
|
||||
assert.Error(t, err, "Second client should fail to bind to same port")
|
||||
if err != nil {
|
||||
// The error should indicate the address is already in use
|
||||
errMsg := strings.ToLower(err.Error())
|
||||
if runtime.GOOS == "windows" {
|
||||
assert.Contains(t, errMsg, "only one usage of each socket address",
|
||||
"Error should indicate port conflict")
|
||||
} else {
|
||||
assert.Contains(t, errMsg, "address already in use",
|
||||
"Error should indicate port conflict")
|
||||
}
|
||||
}
|
||||
|
||||
// Cancel first client's context and wait for it to finish
|
||||
cancel1()
|
||||
select {
|
||||
case err1 := <-done1:
|
||||
// Should get context cancelled or deadline exceeded
|
||||
assert.Error(t, err1, "First client should exit when context cancelled")
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Error("First client did not exit within timeout")
|
||||
}
|
||||
}
|
||||
|
||||
func TestServer_IsPrivilegedUser(t *testing.T) {
|
||||
|
||||
tests := []struct {
|
||||
username string
|
||||
expected bool
|
||||
description string
|
||||
}{
|
||||
{
|
||||
username: "root",
|
||||
expected: true,
|
||||
description: "root should be considered privileged",
|
||||
},
|
||||
{
|
||||
username: "regular",
|
||||
expected: false,
|
||||
description: "regular user should not be privileged",
|
||||
},
|
||||
{
|
||||
username: "",
|
||||
expected: false,
|
||||
description: "empty username should not be privileged",
|
||||
},
|
||||
}
|
||||
|
||||
// Add Windows-specific tests
|
||||
if runtime.GOOS == "windows" {
|
||||
tests = append(tests, []struct {
|
||||
username string
|
||||
expected bool
|
||||
description string
|
||||
}{
|
||||
{
|
||||
username: "Administrator",
|
||||
expected: true,
|
||||
description: "Administrator should be considered privileged on Windows",
|
||||
},
|
||||
{
|
||||
username: "administrator",
|
||||
expected: true,
|
||||
description: "administrator should be considered privileged on Windows (case insensitive)",
|
||||
},
|
||||
}...)
|
||||
} else {
|
||||
// On non-Windows systems, Administrator should not be privileged
|
||||
tests = append(tests, []struct {
|
||||
username string
|
||||
expected bool
|
||||
description string
|
||||
}{
|
||||
{
|
||||
username: "Administrator",
|
||||
expected: false,
|
||||
description: "Administrator should not be privileged on non-Windows systems",
|
||||
},
|
||||
}...)
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.description, func(t *testing.T) {
|
||||
result := isPrivilegedUsername(tt.username)
|
||||
assert.Equal(t, tt.expected, result, tt.description)
|
||||
})
|
||||
}
|
||||
}
|
||||
435
client/ssh/server/server_test.go
Normal file
435
client/ssh/server/server_test.go
Normal file
@@ -0,0 +1,435 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"os/user"
|
||||
"runtime"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
cryptossh "golang.org/x/crypto/ssh"
|
||||
|
||||
nbssh "github.com/netbirdio/netbird/client/ssh"
|
||||
)
|
||||
|
||||
func TestServer_StartStop(t *testing.T) {
|
||||
key, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
||||
require.NoError(t, err)
|
||||
|
||||
serverConfig := &Config{
|
||||
HostKeyPEM: key,
|
||||
JWT: nil,
|
||||
}
|
||||
server := New(serverConfig)
|
||||
|
||||
err = server.Stop()
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestSSHServerIntegration(t *testing.T) {
|
||||
// Generate host key for server
|
||||
hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Generate client key pair
|
||||
clientPrivKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create server with random port
|
||||
serverConfig := &Config{
|
||||
HostKeyPEM: hostKey,
|
||||
JWT: nil,
|
||||
}
|
||||
server := New(serverConfig)
|
||||
|
||||
// Start server in background
|
||||
serverAddr := "127.0.0.1:0"
|
||||
started := make(chan string, 1)
|
||||
errChan := make(chan error, 1)
|
||||
|
||||
go func() {
|
||||
// Get a free port
|
||||
ln, err := net.Listen("tcp", serverAddr)
|
||||
if err != nil {
|
||||
errChan <- err
|
||||
return
|
||||
}
|
||||
actualAddr := ln.Addr().String()
|
||||
if err := ln.Close(); err != nil {
|
||||
errChan <- fmt.Errorf("close temp listener: %w", err)
|
||||
return
|
||||
}
|
||||
|
||||
started <- actualAddr
|
||||
addrPort, _ := netip.ParseAddrPort(actualAddr)
|
||||
errChan <- server.Start(context.Background(), addrPort)
|
||||
}()
|
||||
|
||||
select {
|
||||
case actualAddr := <-started:
|
||||
serverAddr = actualAddr
|
||||
case err := <-errChan:
|
||||
t.Fatalf("Server failed to start: %v", err)
|
||||
case <-time.After(5 * time.Second):
|
||||
t.Fatal("Server start timeout")
|
||||
}
|
||||
|
||||
// Server is ready when we get the started signal
|
||||
|
||||
defer func() {
|
||||
err := server.Stop()
|
||||
require.NoError(t, err)
|
||||
}()
|
||||
|
||||
// Parse client private key
|
||||
signer, err := cryptossh.ParsePrivateKey(clientPrivKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Parse server host key for verification
|
||||
hostPrivParsed, err := cryptossh.ParsePrivateKey(hostKey)
|
||||
require.NoError(t, err)
|
||||
hostPubKey := hostPrivParsed.PublicKey()
|
||||
|
||||
// Get current user for SSH connection
|
||||
currentUser, err := user.Current()
|
||||
require.NoError(t, err, "Should be able to get current user for test")
|
||||
|
||||
// Create SSH client config
|
||||
config := &cryptossh.ClientConfig{
|
||||
User: currentUser.Username,
|
||||
Auth: []cryptossh.AuthMethod{
|
||||
cryptossh.PublicKeys(signer),
|
||||
},
|
||||
HostKeyCallback: cryptossh.FixedHostKey(hostPubKey),
|
||||
Timeout: 3 * time.Second,
|
||||
}
|
||||
|
||||
// Connect to SSH server
|
||||
client, err := cryptossh.Dial("tcp", serverAddr, config)
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
if err := client.Close(); err != nil {
|
||||
t.Logf("close client: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
// Test creating a session
|
||||
session, err := client.NewSession()
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
if err := session.Close(); err != nil {
|
||||
t.Logf("close session: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
// Note: Since we don't have a real shell environment in tests,
|
||||
// we can't test actual command execution, but we can verify
|
||||
// the connection and authentication work
|
||||
t.Log("SSH connection and authentication successful")
|
||||
}
|
||||
|
||||
func TestSSHServerMultipleConnections(t *testing.T) {
|
||||
// Generate host key for server
|
||||
hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Generate client key pair
|
||||
clientPrivKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create server
|
||||
serverConfig := &Config{
|
||||
HostKeyPEM: hostKey,
|
||||
JWT: nil,
|
||||
}
|
||||
server := New(serverConfig)
|
||||
|
||||
// Start server
|
||||
serverAddr := "127.0.0.1:0"
|
||||
started := make(chan string, 1)
|
||||
errChan := make(chan error, 1)
|
||||
|
||||
go func() {
|
||||
ln, err := net.Listen("tcp", serverAddr)
|
||||
if err != nil {
|
||||
errChan <- err
|
||||
return
|
||||
}
|
||||
actualAddr := ln.Addr().String()
|
||||
if err := ln.Close(); err != nil {
|
||||
errChan <- fmt.Errorf("close temp listener: %w", err)
|
||||
return
|
||||
}
|
||||
|
||||
started <- actualAddr
|
||||
addrPort, _ := netip.ParseAddrPort(actualAddr)
|
||||
errChan <- server.Start(context.Background(), addrPort)
|
||||
}()
|
||||
|
||||
select {
|
||||
case actualAddr := <-started:
|
||||
serverAddr = actualAddr
|
||||
case err := <-errChan:
|
||||
t.Fatalf("Server failed to start: %v", err)
|
||||
case <-time.After(5 * time.Second):
|
||||
t.Fatal("Server start timeout")
|
||||
}
|
||||
|
||||
// Server is ready when we get the started signal
|
||||
|
||||
defer func() {
|
||||
err := server.Stop()
|
||||
require.NoError(t, err)
|
||||
}()
|
||||
|
||||
// Parse client private key
|
||||
signer, err := cryptossh.ParsePrivateKey(clientPrivKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Parse server host key
|
||||
hostPrivParsed, err := cryptossh.ParsePrivateKey(hostKey)
|
||||
require.NoError(t, err)
|
||||
hostPubKey := hostPrivParsed.PublicKey()
|
||||
|
||||
// Get current user for SSH connection
|
||||
currentUser, err := user.Current()
|
||||
require.NoError(t, err, "Should be able to get current user for test")
|
||||
|
||||
config := &cryptossh.ClientConfig{
|
||||
User: currentUser.Username,
|
||||
Auth: []cryptossh.AuthMethod{
|
||||
cryptossh.PublicKeys(signer),
|
||||
},
|
||||
HostKeyCallback: cryptossh.FixedHostKey(hostPubKey),
|
||||
Timeout: 3 * time.Second,
|
||||
}
|
||||
|
||||
// Test multiple concurrent connections
|
||||
const numConnections = 5
|
||||
results := make(chan error, numConnections)
|
||||
|
||||
for i := 0; i < numConnections; i++ {
|
||||
go func(id int) {
|
||||
client, err := cryptossh.Dial("tcp", serverAddr, config)
|
||||
if err != nil {
|
||||
results <- fmt.Errorf("connection %d failed: %w", id, err)
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
_ = client.Close() // Ignore error in test goroutine
|
||||
}()
|
||||
|
||||
session, err := client.NewSession()
|
||||
if err != nil {
|
||||
results <- fmt.Errorf("session %d failed: %w", id, err)
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
_ = session.Close() // Ignore error in test goroutine
|
||||
}()
|
||||
|
||||
results <- nil
|
||||
}(i)
|
||||
}
|
||||
|
||||
// Wait for all connections to complete
|
||||
for i := 0; i < numConnections; i++ {
|
||||
select {
|
||||
case err := <-results:
|
||||
assert.NoError(t, err)
|
||||
case <-time.After(10 * time.Second):
|
||||
t.Fatalf("Connection %d timed out", i)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestSSHServerNoAuthMode(t *testing.T) {
|
||||
// Generate host key for server
|
||||
hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create server
|
||||
serverConfig := &Config{
|
||||
HostKeyPEM: hostKey,
|
||||
JWT: nil,
|
||||
}
|
||||
server := New(serverConfig)
|
||||
|
||||
// Start server
|
||||
serverAddr := "127.0.0.1:0"
|
||||
started := make(chan string, 1)
|
||||
errChan := make(chan error, 1)
|
||||
|
||||
go func() {
|
||||
ln, err := net.Listen("tcp", serverAddr)
|
||||
if err != nil {
|
||||
errChan <- err
|
||||
return
|
||||
}
|
||||
actualAddr := ln.Addr().String()
|
||||
if err := ln.Close(); err != nil {
|
||||
errChan <- fmt.Errorf("close temp listener: %w", err)
|
||||
return
|
||||
}
|
||||
|
||||
started <- actualAddr
|
||||
addrPort, _ := netip.ParseAddrPort(actualAddr)
|
||||
errChan <- server.Start(context.Background(), addrPort)
|
||||
}()
|
||||
|
||||
select {
|
||||
case actualAddr := <-started:
|
||||
serverAddr = actualAddr
|
||||
case err := <-errChan:
|
||||
t.Fatalf("Server failed to start: %v", err)
|
||||
case <-time.After(5 * time.Second):
|
||||
t.Fatal("Server start timeout")
|
||||
}
|
||||
|
||||
// Server is ready when we get the started signal
|
||||
|
||||
defer func() {
|
||||
err := server.Stop()
|
||||
require.NoError(t, err)
|
||||
}()
|
||||
|
||||
// Generate a client private key for SSH protocol (server doesn't check it)
|
||||
clientPrivKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
||||
require.NoError(t, err)
|
||||
clientSigner, err := cryptossh.ParsePrivateKey(clientPrivKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Parse server host key
|
||||
hostPrivParsed, err := cryptossh.ParsePrivateKey(hostKey)
|
||||
require.NoError(t, err)
|
||||
hostPubKey := hostPrivParsed.PublicKey()
|
||||
|
||||
// Get current user for SSH connection
|
||||
currentUser, err := user.Current()
|
||||
require.NoError(t, err, "Should be able to get current user for test")
|
||||
|
||||
// Try to connect with client key
|
||||
config := &cryptossh.ClientConfig{
|
||||
User: currentUser.Username,
|
||||
Auth: []cryptossh.AuthMethod{
|
||||
cryptossh.PublicKeys(clientSigner),
|
||||
},
|
||||
HostKeyCallback: cryptossh.FixedHostKey(hostPubKey),
|
||||
Timeout: 3 * time.Second,
|
||||
}
|
||||
|
||||
// This should succeed in no-auth mode (server doesn't verify keys)
|
||||
conn, err := cryptossh.Dial("tcp", serverAddr, config)
|
||||
assert.NoError(t, err, "Connection should succeed in no-auth mode")
|
||||
if conn != nil {
|
||||
assert.NoError(t, conn.Close())
|
||||
}
|
||||
}
|
||||
|
||||
func TestSSHServerStartStopCycle(t *testing.T) {
|
||||
hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
||||
require.NoError(t, err)
|
||||
|
||||
serverConfig := &Config{
|
||||
HostKeyPEM: hostKey,
|
||||
JWT: nil,
|
||||
}
|
||||
server := New(serverConfig)
|
||||
serverAddr := "127.0.0.1:0"
|
||||
|
||||
// Test multiple start/stop cycles
|
||||
for i := 0; i < 3; i++ {
|
||||
t.Logf("Start/stop cycle %d", i+1)
|
||||
|
||||
started := make(chan string, 1)
|
||||
errChan := make(chan error, 1)
|
||||
|
||||
go func() {
|
||||
ln, err := net.Listen("tcp", serverAddr)
|
||||
if err != nil {
|
||||
errChan <- err
|
||||
return
|
||||
}
|
||||
actualAddr := ln.Addr().String()
|
||||
if err := ln.Close(); err != nil {
|
||||
errChan <- fmt.Errorf("close temp listener: %w", err)
|
||||
return
|
||||
}
|
||||
|
||||
started <- actualAddr
|
||||
addrPort, _ := netip.ParseAddrPort(actualAddr)
|
||||
errChan <- server.Start(context.Background(), addrPort)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-started:
|
||||
case err := <-errChan:
|
||||
t.Fatalf("Cycle %d: Server failed to start: %v", i+1, err)
|
||||
case <-time.After(5 * time.Second):
|
||||
t.Fatalf("Cycle %d: Server start timeout", i+1)
|
||||
}
|
||||
|
||||
err = server.Stop()
|
||||
require.NoError(t, err, "Cycle %d: Stop should succeed", i+1)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSSHServer_WindowsShellHandling(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("Skipping Windows shell test in short mode")
|
||||
}
|
||||
|
||||
server := &Server{}
|
||||
|
||||
if runtime.GOOS == "windows" {
|
||||
// Test Windows cmd.exe shell behavior
|
||||
args := server.getShellCommandArgs("cmd.exe", "echo test")
|
||||
assert.Equal(t, "cmd.exe", args[0])
|
||||
assert.Equal(t, "-Command", args[1])
|
||||
assert.Equal(t, "echo test", args[2])
|
||||
|
||||
// Test PowerShell behavior
|
||||
args = server.getShellCommandArgs("powershell.exe", "echo test")
|
||||
assert.Equal(t, "powershell.exe", args[0])
|
||||
assert.Equal(t, "-Command", args[1])
|
||||
assert.Equal(t, "echo test", args[2])
|
||||
} else {
|
||||
// Test Unix shell behavior
|
||||
args := server.getShellCommandArgs("/bin/sh", "echo test")
|
||||
assert.Equal(t, "/bin/sh", args[0])
|
||||
assert.Equal(t, "-l", args[1])
|
||||
assert.Equal(t, "-c", args[2])
|
||||
assert.Equal(t, "echo test", args[3])
|
||||
}
|
||||
}
|
||||
|
||||
func TestSSHServer_PortForwardingConfiguration(t *testing.T) {
|
||||
hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
||||
require.NoError(t, err)
|
||||
|
||||
serverConfig1 := &Config{
|
||||
HostKeyPEM: hostKey,
|
||||
JWT: nil,
|
||||
}
|
||||
server1 := New(serverConfig1)
|
||||
|
||||
serverConfig2 := &Config{
|
||||
HostKeyPEM: hostKey,
|
||||
JWT: nil,
|
||||
}
|
||||
server2 := New(serverConfig2)
|
||||
|
||||
assert.False(t, server1.allowLocalPortForwarding, "Local port forwarding should be disabled by default for security")
|
||||
assert.False(t, server1.allowRemotePortForwarding, "Remote port forwarding should be disabled by default for security")
|
||||
|
||||
server2.SetAllowLocalPortForwarding(true)
|
||||
server2.SetAllowRemotePortForwarding(true)
|
||||
|
||||
assert.True(t, server2.allowLocalPortForwarding, "Local port forwarding should be enabled when explicitly set")
|
||||
assert.True(t, server2.allowRemotePortForwarding, "Remote port forwarding should be enabled when explicitly set")
|
||||
}
|
||||
145
client/ssh/server/session_handlers.go
Normal file
145
client/ssh/server/session_handlers.go
Normal file
@@ -0,0 +1,145 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gliderlabs/ssh"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// sessionHandler handles SSH sessions
|
||||
func (s *Server) sessionHandler(session ssh.Session) {
|
||||
sessionKey := s.registerSession(session)
|
||||
logger := log.WithField("session", sessionKey)
|
||||
logger.Infof("SSH session started")
|
||||
sessionStart := time.Now()
|
||||
|
||||
defer s.unregisterSession(sessionKey, session)
|
||||
defer func() {
|
||||
duration := time.Since(sessionStart).Round(time.Millisecond)
|
||||
if err := session.Close(); err != nil && !errors.Is(err, io.EOF) {
|
||||
logger.Warnf("close session after %v: %v", duration, err)
|
||||
}
|
||||
logger.Infof("SSH session closed after %v", duration)
|
||||
}()
|
||||
|
||||
privilegeResult, err := s.userPrivilegeCheck(session.User())
|
||||
if err != nil {
|
||||
s.handlePrivError(logger, session, err)
|
||||
return
|
||||
}
|
||||
|
||||
ptyReq, winCh, isPty := session.Pty()
|
||||
hasCommand := len(session.Command()) > 0
|
||||
|
||||
switch {
|
||||
case isPty && hasCommand:
|
||||
// ssh -t <host> <cmd> - Pty command execution
|
||||
s.handleCommand(logger, session, privilegeResult, winCh)
|
||||
case isPty:
|
||||
// ssh <host> - Pty interactive session (login)
|
||||
s.handlePty(logger, session, privilegeResult, ptyReq, winCh)
|
||||
case hasCommand:
|
||||
// ssh <host> <cmd> - non-Pty command execution
|
||||
s.handleCommand(logger, session, privilegeResult, nil)
|
||||
default:
|
||||
s.rejectInvalidSession(logger, session)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) rejectInvalidSession(logger *log.Entry, session ssh.Session) {
|
||||
if _, err := io.WriteString(session, "no command specified and Pty not requested\n"); err != nil {
|
||||
logger.Debugf(errWriteSession, err)
|
||||
}
|
||||
if err := session.Exit(1); err != nil {
|
||||
logSessionExitError(logger, err)
|
||||
}
|
||||
logger.Infof("rejected non-Pty session without command from %s", session.RemoteAddr())
|
||||
}
|
||||
|
||||
func (s *Server) registerSession(session ssh.Session) SessionKey {
|
||||
sessionID := session.Context().Value(ssh.ContextKeySessionID)
|
||||
if sessionID == nil {
|
||||
sessionID = fmt.Sprintf("%p", session)
|
||||
}
|
||||
|
||||
// Create a short 4-byte identifier from the full session ID
|
||||
hasher := sha256.New()
|
||||
hasher.Write([]byte(fmt.Sprintf("%v", sessionID)))
|
||||
hash := hasher.Sum(nil)
|
||||
shortID := hex.EncodeToString(hash[:4])
|
||||
|
||||
remoteAddr := session.RemoteAddr().String()
|
||||
username := session.User()
|
||||
sessionKey := SessionKey(fmt.Sprintf("%s@%s-%s", username, remoteAddr, shortID))
|
||||
|
||||
s.mu.Lock()
|
||||
s.sessions[sessionKey] = session
|
||||
s.mu.Unlock()
|
||||
|
||||
return sessionKey
|
||||
}
|
||||
|
||||
func (s *Server) unregisterSession(sessionKey SessionKey, _ ssh.Session) {
|
||||
s.mu.Lock()
|
||||
delete(s.sessions, sessionKey)
|
||||
|
||||
// Cancel all port forwarding connections for this session
|
||||
var connectionsToCancel []ConnectionKey
|
||||
for key := range s.sessionCancels {
|
||||
if strings.HasPrefix(string(key), string(sessionKey)+"-") {
|
||||
connectionsToCancel = append(connectionsToCancel, key)
|
||||
}
|
||||
}
|
||||
|
||||
for _, key := range connectionsToCancel {
|
||||
if cancelFunc, exists := s.sessionCancels[key]; exists {
|
||||
log.WithField("session", sessionKey).Debugf("cancelling port forwarding context: %s", key)
|
||||
cancelFunc()
|
||||
delete(s.sessionCancels, key)
|
||||
}
|
||||
}
|
||||
|
||||
s.mu.Unlock()
|
||||
}
|
||||
|
||||
func (s *Server) handlePrivError(logger *log.Entry, session ssh.Session, err error) {
|
||||
logger.Warnf("user privilege check failed: %v", err)
|
||||
|
||||
errorMsg := s.buildUserLookupErrorMessage(err)
|
||||
|
||||
if _, writeErr := fmt.Fprint(session, errorMsg); writeErr != nil {
|
||||
logger.Debugf(errWriteSession, writeErr)
|
||||
}
|
||||
if exitErr := session.Exit(1); exitErr != nil {
|
||||
logSessionExitError(logger, exitErr)
|
||||
}
|
||||
}
|
||||
|
||||
// buildUserLookupErrorMessage creates appropriate user-facing error messages based on error type
|
||||
func (s *Server) buildUserLookupErrorMessage(err error) string {
|
||||
var privilegedErr *PrivilegedUserError
|
||||
|
||||
switch {
|
||||
case errors.As(err, &privilegedErr):
|
||||
if privilegedErr.Username == "root" {
|
||||
return "root login is disabled on this SSH server\n"
|
||||
}
|
||||
return "privileged user access is disabled on this SSH server\n"
|
||||
|
||||
case errors.Is(err, ErrPrivilegeRequired):
|
||||
return "Windows user switching failed - NetBird must run with elevated privileges for user switching\n"
|
||||
|
||||
case errors.Is(err, ErrPrivilegedUserSwitch):
|
||||
return "Cannot switch to privileged user - current user lacks required privileges\n"
|
||||
|
||||
default:
|
||||
return "User authentication failed\n"
|
||||
}
|
||||
}
|
||||
22
client/ssh/server/session_handlers_js.go
Normal file
22
client/ssh/server/session_handlers_js.go
Normal file
@@ -0,0 +1,22 @@
|
||||
//go:build js
|
||||
|
||||
package server
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/gliderlabs/ssh"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// handlePty is not supported on JS/WASM
|
||||
func (s *Server) handlePty(logger *log.Entry, session ssh.Session, _ PrivilegeCheckResult, _ ssh.Pty, _ <-chan ssh.Window) bool {
|
||||
errorMsg := "PTY sessions are not supported on WASM/JS platform\n"
|
||||
if _, err := fmt.Fprint(session.Stderr(), errorMsg); err != nil {
|
||||
logger.Debugf(errWriteSession, err)
|
||||
}
|
||||
if err := session.Exit(1); err != nil {
|
||||
logSessionExitError(logger, err)
|
||||
}
|
||||
return false
|
||||
}
|
||||
81
client/ssh/server/sftp.go
Normal file
81
client/ssh/server/sftp.go
Normal file
@@ -0,0 +1,81 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
"github.com/gliderlabs/ssh"
|
||||
"github.com/pkg/sftp"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// SetAllowSFTP enables or disables SFTP support
|
||||
func (s *Server) SetAllowSFTP(allow bool) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.allowSFTP = allow
|
||||
}
|
||||
|
||||
// sftpSubsystemHandler handles SFTP subsystem requests
|
||||
func (s *Server) sftpSubsystemHandler(sess ssh.Session) {
|
||||
s.mu.RLock()
|
||||
allowSFTP := s.allowSFTP
|
||||
s.mu.RUnlock()
|
||||
|
||||
if !allowSFTP {
|
||||
log.Debugf("SFTP subsystem request denied: SFTP disabled")
|
||||
if err := sess.Exit(1); err != nil {
|
||||
log.Debugf("SFTP session exit failed: %v", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
result := s.CheckPrivileges(PrivilegeCheckRequest{
|
||||
RequestedUsername: sess.User(),
|
||||
FeatureSupportsUserSwitch: true,
|
||||
FeatureName: FeatureSFTP,
|
||||
})
|
||||
|
||||
if !result.Allowed {
|
||||
log.Warnf("SFTP access denied for user %s from %s: %v", sess.User(), sess.RemoteAddr(), result.Error)
|
||||
if err := sess.Exit(1); err != nil {
|
||||
log.Debugf("exit SFTP session: %v", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
log.Debugf("SFTP subsystem request from user %s (effective user %s)", sess.User(), result.User.Username)
|
||||
|
||||
if !result.RequiresUserSwitching {
|
||||
if err := s.executeSftpDirect(sess); err != nil {
|
||||
log.Errorf("SFTP direct execution: %v", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if err := s.executeSftpWithPrivilegeDrop(sess, result.User); err != nil {
|
||||
log.Errorf("SFTP privilege drop execution: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// executeSftpDirect executes SFTP directly without privilege dropping
|
||||
func (s *Server) executeSftpDirect(sess ssh.Session) error {
|
||||
log.Debugf("starting SFTP session for user %s (no privilege dropping)", sess.User())
|
||||
|
||||
sftpServer, err := sftp.NewServer(sess)
|
||||
if err != nil {
|
||||
return fmt.Errorf("SFTP server creation: %w", err)
|
||||
}
|
||||
|
||||
defer func() {
|
||||
if err := sftpServer.Close(); err != nil {
|
||||
log.Debugf("failed to close sftp server: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
if err := sftpServer.Serve(); err != nil && err != io.EOF {
|
||||
return fmt.Errorf("serve: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
12
client/ssh/server/sftp_js.go
Normal file
12
client/ssh/server/sftp_js.go
Normal file
@@ -0,0 +1,12 @@
|
||||
//go:build js
|
||||
|
||||
package server
|
||||
|
||||
import (
|
||||
"os/user"
|
||||
)
|
||||
|
||||
// parseUserCredentials is not supported on JS/WASM
|
||||
func (s *Server) parseUserCredentials(_ *user.User) (uint32, uint32, []uint32, error) {
|
||||
return 0, 0, nil, errNotSupported
|
||||
}
|
||||
222
client/ssh/server/sftp_test.go
Normal file
222
client/ssh/server/sftp_test.go
Normal file
@@ -0,0 +1,222 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
"os/user"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/pkg/sftp"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
cryptossh "golang.org/x/crypto/ssh"
|
||||
|
||||
"github.com/netbirdio/netbird/client/ssh"
|
||||
)
|
||||
|
||||
func TestSSHServer_SFTPSubsystem(t *testing.T) {
|
||||
// Skip SFTP test when running as root due to protocol issues in some environments
|
||||
if os.Geteuid() == 0 {
|
||||
t.Skip("Skipping SFTP test when running as root - may have protocol compatibility issues")
|
||||
}
|
||||
|
||||
// Get current user for SSH connection
|
||||
currentUser, err := user.Current()
|
||||
require.NoError(t, err, "Should be able to get current user")
|
||||
|
||||
// Generate host key for server
|
||||
hostKey, err := ssh.GeneratePrivateKey(ssh.ED25519)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Generate client key pair
|
||||
clientPrivKey, err := ssh.GeneratePrivateKey(ssh.ED25519)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create server with SFTP enabled
|
||||
serverConfig := &Config{
|
||||
HostKeyPEM: hostKey,
|
||||
JWT: nil,
|
||||
}
|
||||
server := New(serverConfig)
|
||||
server.SetAllowSFTP(true)
|
||||
server.SetAllowRootLogin(true)
|
||||
|
||||
// Start server
|
||||
serverAddr := "127.0.0.1:0"
|
||||
started := make(chan string, 1)
|
||||
errChan := make(chan error, 1)
|
||||
|
||||
go func() {
|
||||
ln, err := net.Listen("tcp", serverAddr)
|
||||
if err != nil {
|
||||
errChan <- err
|
||||
return
|
||||
}
|
||||
actualAddr := ln.Addr().String()
|
||||
if err := ln.Close(); err != nil {
|
||||
errChan <- fmt.Errorf("close temp listener: %w", err)
|
||||
return
|
||||
}
|
||||
|
||||
started <- actualAddr
|
||||
addrPort, _ := netip.ParseAddrPort(actualAddr)
|
||||
errChan <- server.Start(context.Background(), addrPort)
|
||||
}()
|
||||
|
||||
select {
|
||||
case actualAddr := <-started:
|
||||
serverAddr = actualAddr
|
||||
case err := <-errChan:
|
||||
t.Fatalf("Server failed to start: %v", err)
|
||||
case <-time.After(5 * time.Second):
|
||||
t.Fatal("Server start timeout")
|
||||
}
|
||||
|
||||
defer func() {
|
||||
err := server.Stop()
|
||||
require.NoError(t, err)
|
||||
}()
|
||||
|
||||
// Parse client private key
|
||||
signer, err := cryptossh.ParsePrivateKey(clientPrivKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Parse server host key
|
||||
hostPrivParsed, err := cryptossh.ParsePrivateKey(hostKey)
|
||||
require.NoError(t, err)
|
||||
hostPubKey := hostPrivParsed.PublicKey()
|
||||
|
||||
// (currentUser already obtained at function start)
|
||||
|
||||
// Create SSH client connection
|
||||
clientConfig := &cryptossh.ClientConfig{
|
||||
User: currentUser.Username,
|
||||
Auth: []cryptossh.AuthMethod{
|
||||
cryptossh.PublicKeys(signer),
|
||||
},
|
||||
HostKeyCallback: cryptossh.FixedHostKey(hostPubKey),
|
||||
Timeout: 5 * time.Second,
|
||||
}
|
||||
|
||||
conn, err := cryptossh.Dial("tcp", serverAddr, clientConfig)
|
||||
require.NoError(t, err, "SSH connection should succeed")
|
||||
defer func() {
|
||||
if err := conn.Close(); err != nil {
|
||||
t.Logf("connection close error: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
// Create SFTP client
|
||||
sftpClient, err := sftp.NewClient(conn)
|
||||
require.NoError(t, err, "SFTP client creation should succeed")
|
||||
defer func() {
|
||||
if err := sftpClient.Close(); err != nil {
|
||||
t.Logf("SFTP client close error: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
// Test basic SFTP operations
|
||||
workingDir, err := sftpClient.Getwd()
|
||||
assert.NoError(t, err, "Should be able to get working directory")
|
||||
assert.NotEmpty(t, workingDir, "Working directory should not be empty")
|
||||
|
||||
// Test directory listing
|
||||
files, err := sftpClient.ReadDir(".")
|
||||
assert.NoError(t, err, "Should be able to list current directory")
|
||||
assert.NotNil(t, files, "File list should not be nil")
|
||||
}
|
||||
|
||||
func TestSSHServer_SFTPDisabled(t *testing.T) {
|
||||
// Get current user for SSH connection
|
||||
currentUser, err := user.Current()
|
||||
require.NoError(t, err, "Should be able to get current user")
|
||||
|
||||
// Generate host key for server
|
||||
hostKey, err := ssh.GeneratePrivateKey(ssh.ED25519)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Generate client key pair
|
||||
clientPrivKey, err := ssh.GeneratePrivateKey(ssh.ED25519)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create server with SFTP disabled
|
||||
serverConfig := &Config{
|
||||
HostKeyPEM: hostKey,
|
||||
JWT: nil,
|
||||
}
|
||||
server := New(serverConfig)
|
||||
server.SetAllowSFTP(false)
|
||||
|
||||
// Start server
|
||||
serverAddr := "127.0.0.1:0"
|
||||
started := make(chan string, 1)
|
||||
errChan := make(chan error, 1)
|
||||
|
||||
go func() {
|
||||
ln, err := net.Listen("tcp", serverAddr)
|
||||
if err != nil {
|
||||
errChan <- err
|
||||
return
|
||||
}
|
||||
actualAddr := ln.Addr().String()
|
||||
if err := ln.Close(); err != nil {
|
||||
errChan <- fmt.Errorf("close temp listener: %w", err)
|
||||
return
|
||||
}
|
||||
|
||||
started <- actualAddr
|
||||
addrPort, _ := netip.ParseAddrPort(actualAddr)
|
||||
errChan <- server.Start(context.Background(), addrPort)
|
||||
}()
|
||||
|
||||
select {
|
||||
case actualAddr := <-started:
|
||||
serverAddr = actualAddr
|
||||
case err := <-errChan:
|
||||
t.Fatalf("Server failed to start: %v", err)
|
||||
case <-time.After(5 * time.Second):
|
||||
t.Fatal("Server start timeout")
|
||||
}
|
||||
|
||||
defer func() {
|
||||
err := server.Stop()
|
||||
require.NoError(t, err)
|
||||
}()
|
||||
|
||||
// Parse client private key
|
||||
signer, err := cryptossh.ParsePrivateKey(clientPrivKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Parse server host key
|
||||
hostPrivParsed, err := cryptossh.ParsePrivateKey(hostKey)
|
||||
require.NoError(t, err)
|
||||
hostPubKey := hostPrivParsed.PublicKey()
|
||||
|
||||
// (currentUser already obtained at function start)
|
||||
|
||||
// Create SSH client connection
|
||||
clientConfig := &cryptossh.ClientConfig{
|
||||
User: currentUser.Username,
|
||||
Auth: []cryptossh.AuthMethod{
|
||||
cryptossh.PublicKeys(signer),
|
||||
},
|
||||
HostKeyCallback: cryptossh.FixedHostKey(hostPubKey),
|
||||
Timeout: 5 * time.Second,
|
||||
}
|
||||
|
||||
conn, err := cryptossh.Dial("tcp", serverAddr, clientConfig)
|
||||
require.NoError(t, err, "SSH connection should succeed")
|
||||
defer func() {
|
||||
if err := conn.Close(); err != nil {
|
||||
t.Logf("connection close error: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
// Try to create SFTP client - should fail when SFTP is disabled
|
||||
_, err = sftp.NewClient(conn)
|
||||
assert.Error(t, err, "SFTP client creation should fail when SFTP is disabled")
|
||||
}
|
||||
71
client/ssh/server/sftp_unix.go
Normal file
71
client/ssh/server/sftp_unix.go
Normal file
@@ -0,0 +1,71 @@
|
||||
//go:build !windows
|
||||
|
||||
package server
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"os/user"
|
||||
"strconv"
|
||||
|
||||
"github.com/gliderlabs/ssh"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// executeSftpWithPrivilegeDrop executes SFTP using Unix privilege dropping
|
||||
func (s *Server) executeSftpWithPrivilegeDrop(sess ssh.Session, targetUser *user.User) error {
|
||||
uid, gid, groups, err := s.parseUserCredentials(targetUser)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parse user credentials: %w", err)
|
||||
}
|
||||
|
||||
sftpCmd, err := s.createSftpExecutorCommand(sess, uid, gid, groups, targetUser.HomeDir)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create executor: %w", err)
|
||||
}
|
||||
|
||||
sftpCmd.Stdin = sess
|
||||
sftpCmd.Stdout = sess
|
||||
sftpCmd.Stderr = sess.Stderr()
|
||||
|
||||
log.Tracef("starting SFTP with privilege dropping to user %s (UID=%d, GID=%d)", targetUser.Username, uid, gid)
|
||||
|
||||
if err := sftpCmd.Start(); err != nil {
|
||||
return fmt.Errorf("starting SFTP executor: %w", err)
|
||||
}
|
||||
|
||||
if err := sftpCmd.Wait(); err != nil {
|
||||
var exitError *exec.ExitError
|
||||
if errors.As(err, &exitError) {
|
||||
log.Tracef("SFTP process exited with code %d", exitError.ExitCode())
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("exec: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// createSftpExecutorCommand creates a command that spawns netbird ssh sftp for privilege dropping
|
||||
func (s *Server) createSftpExecutorCommand(sess ssh.Session, uid, gid uint32, groups []uint32, workingDir string) (*exec.Cmd, error) {
|
||||
netbirdPath, err := os.Executable()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
args := []string{
|
||||
"ssh", "sftp",
|
||||
"--uid", strconv.FormatUint(uint64(uid), 10),
|
||||
"--gid", strconv.FormatUint(uint64(gid), 10),
|
||||
"--working-dir", workingDir,
|
||||
}
|
||||
|
||||
for _, group := range groups {
|
||||
args = append(args, "--groups", strconv.FormatUint(uint64(group), 10))
|
||||
}
|
||||
|
||||
log.Tracef("creating SFTP executor command: %s %v", netbirdPath, args)
|
||||
return exec.CommandContext(sess.Context(), netbirdPath, args...), nil
|
||||
}
|
||||
85
client/ssh/server/sftp_windows.go
Normal file
85
client/ssh/server/sftp_windows.go
Normal file
@@ -0,0 +1,85 @@
|
||||
//go:build windows
|
||||
|
||||
package server
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"os/user"
|
||||
|
||||
"github.com/gliderlabs/ssh"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/sys/windows"
|
||||
)
|
||||
|
||||
// createSftpCommand creates a Windows SFTP command with user switching
|
||||
func (s *Server) createSftpCommand(targetUser *user.User, sess ssh.Session) (*exec.Cmd, error) {
|
||||
username, domain := s.parseUsername(targetUser.Username)
|
||||
|
||||
netbirdPath, err := os.Executable()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get netbird executable path: %w", err)
|
||||
}
|
||||
|
||||
args := []string{
|
||||
"ssh", "sftp",
|
||||
"--working-dir", targetUser.HomeDir,
|
||||
"--windows-username", username,
|
||||
"--windows-domain", domain,
|
||||
}
|
||||
|
||||
pd := NewPrivilegeDropper()
|
||||
token, err := pd.createToken(username, domain)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create token: %w", err)
|
||||
}
|
||||
|
||||
defer func() {
|
||||
if err := windows.CloseHandle(token); err != nil {
|
||||
log.Warnf("failed to close Windows token handle: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
cmd, err := pd.createProcessWithToken(sess.Context(), windows.Token(token), netbirdPath, append([]string{netbirdPath}, args...), targetUser.HomeDir)
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create SFTP command: %w", err)
|
||||
}
|
||||
|
||||
log.Debugf("Created Windows SFTP command with user switching for %s", targetUser.Username)
|
||||
return cmd, nil
|
||||
}
|
||||
|
||||
// executeSftpCommand executes a Windows SFTP command with proper I/O handling
|
||||
func (s *Server) executeSftpCommand(sess ssh.Session, sftpCmd *exec.Cmd) error {
|
||||
sftpCmd.Stdin = sess
|
||||
sftpCmd.Stdout = sess
|
||||
sftpCmd.Stderr = sess.Stderr()
|
||||
|
||||
if err := sftpCmd.Start(); err != nil {
|
||||
return fmt.Errorf("starting sftp executor: %w", err)
|
||||
}
|
||||
|
||||
if err := sftpCmd.Wait(); err != nil {
|
||||
var exitError *exec.ExitError
|
||||
if errors.As(err, &exitError) {
|
||||
log.Tracef("sftp process exited with code %d", exitError.ExitCode())
|
||||
return nil
|
||||
}
|
||||
|
||||
return fmt.Errorf("exec sftp: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// executeSftpWithPrivilegeDrop executes SFTP using Windows privilege dropping
|
||||
func (s *Server) executeSftpWithPrivilegeDrop(sess ssh.Session, targetUser *user.User) error {
|
||||
sftpCmd, err := s.createSftpCommand(targetUser, sess)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create sftp: %w", err)
|
||||
}
|
||||
return s.executeSftpCommand(sess, sftpCmd)
|
||||
}
|
||||
175
client/ssh/server/shell.go
Normal file
175
client/ssh/server/shell.go
Normal file
@@ -0,0 +1,175 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"os/exec"
|
||||
"os/user"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/gliderlabs/ssh"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultUnixShell = "/bin/sh"
|
||||
|
||||
pwshExe = "pwsh.exe" // #nosec G101 - This is not a credential, just executable name
|
||||
powershellExe = "powershell.exe"
|
||||
)
|
||||
|
||||
// getUserShell returns the appropriate shell for the given user ID
|
||||
// Handles all platform-specific logic and fallbacks consistently
|
||||
func getUserShell(userID string) string {
|
||||
switch runtime.GOOS {
|
||||
case "windows":
|
||||
return getWindowsUserShell()
|
||||
default:
|
||||
return getUnixUserShell(userID)
|
||||
}
|
||||
}
|
||||
|
||||
// getWindowsUserShell returns the best shell for Windows users.
|
||||
// We intentionally do not support cmd.exe or COMSPEC fallbacks to avoid command injection
|
||||
// vulnerabilities that arise from cmd.exe's complex command line parsing and special characters.
|
||||
// PowerShell provides safer argument handling and is available on all modern Windows systems.
|
||||
// Order: pwsh.exe -> powershell.exe
|
||||
func getWindowsUserShell() string {
|
||||
if path, err := exec.LookPath(pwshExe); err == nil {
|
||||
return path
|
||||
}
|
||||
if path, err := exec.LookPath(powershellExe); err == nil {
|
||||
return path
|
||||
}
|
||||
|
||||
return `C:\Windows\System32\WindowsPowerShell\v1.0\powershell.exe`
|
||||
}
|
||||
|
||||
// getUnixUserShell returns the shell for Unix-like systems
|
||||
func getUnixUserShell(userID string) string {
|
||||
shell := getShellFromPasswd(userID)
|
||||
if shell != "" {
|
||||
return shell
|
||||
}
|
||||
|
||||
if shell := os.Getenv("SHELL"); shell != "" {
|
||||
return shell
|
||||
}
|
||||
|
||||
return defaultUnixShell
|
||||
}
|
||||
|
||||
// getShellFromPasswd reads the shell from /etc/passwd for the given user ID
|
||||
func getShellFromPasswd(userID string) string {
|
||||
file, err := os.Open("/etc/passwd")
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
defer func() {
|
||||
if err := file.Close(); err != nil {
|
||||
log.Warnf("close /etc/passwd file: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
scanner := bufio.NewScanner(file)
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
fields := strings.Split(line, ":")
|
||||
if len(fields) < 7 {
|
||||
continue
|
||||
}
|
||||
|
||||
// field 2 is UID
|
||||
if fields[2] == userID {
|
||||
shell := strings.TrimSpace(fields[6])
|
||||
return shell
|
||||
}
|
||||
}
|
||||
|
||||
if err := scanner.Err(); err != nil {
|
||||
log.Warnf("error reading /etc/passwd: %v", err)
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
// prepareUserEnv prepares environment variables for user execution
|
||||
func prepareUserEnv(user *user.User, shell string) []string {
|
||||
return []string{
|
||||
fmt.Sprint("SHELL=" + shell),
|
||||
fmt.Sprint("USER=" + user.Username),
|
||||
fmt.Sprint("LOGNAME=" + user.Username),
|
||||
fmt.Sprint("HOME=" + user.HomeDir),
|
||||
"PATH=/usr/local/bin:/usr/bin:/bin:/usr/local/games:/usr/games",
|
||||
}
|
||||
}
|
||||
|
||||
// acceptEnv checks if environment variable from SSH client should be accepted
|
||||
// This is a whitelist of variables that SSH clients can send to the server
|
||||
func acceptEnv(envVar string) bool {
|
||||
varName := envVar
|
||||
if idx := strings.Index(envVar, "="); idx != -1 {
|
||||
varName = envVar[:idx]
|
||||
}
|
||||
|
||||
exactMatches := []string{
|
||||
"LANG",
|
||||
"LANGUAGE",
|
||||
"TERM",
|
||||
"COLORTERM",
|
||||
"EDITOR",
|
||||
"VISUAL",
|
||||
"PAGER",
|
||||
"LESS",
|
||||
"LESSCHARSET",
|
||||
"TZ",
|
||||
}
|
||||
|
||||
prefixMatches := []string{
|
||||
"LC_",
|
||||
}
|
||||
|
||||
for _, exact := range exactMatches {
|
||||
if varName == exact {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
for _, prefix := range prefixMatches {
|
||||
if strings.HasPrefix(varName, prefix) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// prepareSSHEnv prepares SSH protocol-specific environment variables
|
||||
// These variables provide information about the SSH connection itself
|
||||
func prepareSSHEnv(session ssh.Session) []string {
|
||||
remoteAddr := session.RemoteAddr()
|
||||
localAddr := session.LocalAddr()
|
||||
|
||||
remoteHost, remotePort, err := net.SplitHostPort(remoteAddr.String())
|
||||
if err != nil {
|
||||
remoteHost = remoteAddr.String()
|
||||
remotePort = "0"
|
||||
}
|
||||
|
||||
localHost, localPort, err := net.SplitHostPort(localAddr.String())
|
||||
if err != nil {
|
||||
localHost = localAddr.String()
|
||||
localPort = strconv.Itoa(InternalSSHPort)
|
||||
}
|
||||
|
||||
return []string{
|
||||
// SSH_CLIENT format: "client_ip client_port server_port"
|
||||
fmt.Sprintf("SSH_CLIENT=%s %s %s", remoteHost, remotePort, localPort),
|
||||
// SSH_CONNECTION format: "client_ip client_port server_ip server_port"
|
||||
fmt.Sprintf("SSH_CONNECTION=%s %s %s %s", remoteHost, remotePort, localHost, localPort),
|
||||
}
|
||||
}
|
||||
45
client/ssh/server/test.go
Normal file
45
client/ssh/server/test.go
Normal file
@@ -0,0 +1,45 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func StartTestServer(t *testing.T, server *Server) string {
|
||||
started := make(chan string, 1)
|
||||
errChan := make(chan error, 1)
|
||||
|
||||
go func() {
|
||||
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
errChan <- err
|
||||
return
|
||||
}
|
||||
actualAddr := ln.Addr().String()
|
||||
if err := ln.Close(); err != nil {
|
||||
errChan <- fmt.Errorf("close temp listener: %w", err)
|
||||
return
|
||||
}
|
||||
|
||||
addrPort := netip.MustParseAddrPort(actualAddr)
|
||||
if err := server.Start(context.Background(), addrPort); err != nil {
|
||||
errChan <- err
|
||||
return
|
||||
}
|
||||
started <- actualAddr
|
||||
}()
|
||||
|
||||
select {
|
||||
case actualAddr := <-started:
|
||||
return actualAddr
|
||||
case err := <-errChan:
|
||||
t.Fatalf("Server failed to start: %v", err)
|
||||
case <-time.After(5 * time.Second):
|
||||
t.Fatal("Server start timeout")
|
||||
}
|
||||
return ""
|
||||
}
|
||||
411
client/ssh/server/user_utils.go
Normal file
411
client/ssh/server/user_utils.go
Normal file
@@ -0,0 +1,411 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/user"
|
||||
"runtime"
|
||||
"strings"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrPrivilegeRequired = errors.New("SeAssignPrimaryTokenPrivilege required for user switching - NetBird must run with elevated privileges")
|
||||
ErrPrivilegedUserSwitch = errors.New("cannot switch to privileged user - current user lacks required privileges")
|
||||
)
|
||||
|
||||
// isPlatformUnix returns true for Unix-like platforms (Linux, macOS, etc.)
|
||||
func isPlatformUnix() bool {
|
||||
return getCurrentOS() != "windows"
|
||||
}
|
||||
|
||||
// Dependency injection variables for testing - allows mocking dynamic runtime checks
|
||||
var (
|
||||
getCurrentUser = user.Current
|
||||
lookupUser = user.Lookup
|
||||
getCurrentOS = func() string { return runtime.GOOS }
|
||||
getIsProcessPrivileged = isCurrentProcessPrivileged
|
||||
|
||||
getEuid = os.Geteuid
|
||||
)
|
||||
|
||||
const (
|
||||
// FeatureSSHLogin represents SSH login operations for privilege checking
|
||||
FeatureSSHLogin = "SSH login"
|
||||
// FeatureSFTP represents SFTP operations for privilege checking
|
||||
FeatureSFTP = "SFTP"
|
||||
)
|
||||
|
||||
// PrivilegeCheckRequest represents a privilege check request
|
||||
type PrivilegeCheckRequest struct {
|
||||
// Username being requested (empty = current user)
|
||||
RequestedUsername string
|
||||
FeatureSupportsUserSwitch bool // Does this feature/operation support user switching?
|
||||
FeatureName string
|
||||
}
|
||||
|
||||
// PrivilegeCheckResult represents the result of a privilege check
|
||||
type PrivilegeCheckResult struct {
|
||||
// Allowed indicates whether the privilege check passed
|
||||
Allowed bool
|
||||
// User is the effective user to use for the operation (nil if not allowed)
|
||||
User *user.User
|
||||
// Error contains the reason for denial (nil if allowed)
|
||||
Error error
|
||||
// UsedFallback indicates we fell back to current user instead of requested user.
|
||||
// This happens on Unix when running as an unprivileged user (e.g., in containers)
|
||||
// where there's no point in user switching since we lack privileges anyway.
|
||||
// When true, all privilege checks have already been performed and no additional
|
||||
// privilege dropping or root checks are needed - the current user is the target.
|
||||
UsedFallback bool
|
||||
// RequiresUserSwitching indicates whether user switching will actually occur
|
||||
// (false for fallback cases where no actual switching happens)
|
||||
RequiresUserSwitching bool
|
||||
}
|
||||
|
||||
// CheckPrivileges performs comprehensive privilege checking for all SSH features.
|
||||
// This is the single source of truth for privilege decisions across the SSH server.
|
||||
func (s *Server) CheckPrivileges(req PrivilegeCheckRequest) PrivilegeCheckResult {
|
||||
context, err := s.buildPrivilegeCheckContext(req.FeatureName)
|
||||
if err != nil {
|
||||
return PrivilegeCheckResult{Allowed: false, Error: err}
|
||||
}
|
||||
|
||||
// Handle empty username case - but still check root access controls
|
||||
if req.RequestedUsername == "" {
|
||||
if isPrivilegedUsername(context.currentUser.Username) && !context.allowRoot {
|
||||
return PrivilegeCheckResult{
|
||||
Allowed: false,
|
||||
Error: &PrivilegedUserError{Username: context.currentUser.Username},
|
||||
}
|
||||
}
|
||||
return PrivilegeCheckResult{
|
||||
Allowed: true,
|
||||
User: context.currentUser,
|
||||
RequiresUserSwitching: false,
|
||||
}
|
||||
}
|
||||
|
||||
return s.checkUserRequest(context, req)
|
||||
}
|
||||
|
||||
// buildPrivilegeCheckContext gathers all the context needed for privilege checking
|
||||
func (s *Server) buildPrivilegeCheckContext(featureName string) (*privilegeCheckContext, error) {
|
||||
currentUser, err := getCurrentUser()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get current user for %s: %w", featureName, err)
|
||||
}
|
||||
|
||||
s.mu.RLock()
|
||||
allowRoot := s.allowRootLogin
|
||||
s.mu.RUnlock()
|
||||
|
||||
return &privilegeCheckContext{
|
||||
currentUser: currentUser,
|
||||
currentUserPrivileged: getIsProcessPrivileged(),
|
||||
allowRoot: allowRoot,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// checkUserRequest handles normal privilege checking flow for specific usernames
|
||||
func (s *Server) checkUserRequest(ctx *privilegeCheckContext, req PrivilegeCheckRequest) PrivilegeCheckResult {
|
||||
if !ctx.currentUserPrivileged && isPlatformUnix() {
|
||||
log.Debugf("Unix non-privileged shortcut: falling back to current user %s for %s (requested: %s)",
|
||||
ctx.currentUser.Username, req.FeatureName, req.RequestedUsername)
|
||||
return PrivilegeCheckResult{
|
||||
Allowed: true,
|
||||
User: ctx.currentUser,
|
||||
UsedFallback: true,
|
||||
RequiresUserSwitching: false,
|
||||
}
|
||||
}
|
||||
|
||||
resolvedUser, err := s.resolveRequestedUser(req.RequestedUsername)
|
||||
if err != nil {
|
||||
// Calculate if user switching would be required even if lookup failed
|
||||
needsUserSwitching := !isSameUser(req.RequestedUsername, ctx.currentUser.Username)
|
||||
return PrivilegeCheckResult{
|
||||
Allowed: false,
|
||||
Error: err,
|
||||
RequiresUserSwitching: needsUserSwitching,
|
||||
}
|
||||
}
|
||||
|
||||
needsUserSwitching := !isSameResolvedUser(resolvedUser, ctx.currentUser)
|
||||
|
||||
if isPrivilegedUsername(resolvedUser.Username) && !ctx.allowRoot {
|
||||
return PrivilegeCheckResult{
|
||||
Allowed: false,
|
||||
Error: &PrivilegedUserError{Username: resolvedUser.Username},
|
||||
RequiresUserSwitching: needsUserSwitching,
|
||||
}
|
||||
}
|
||||
|
||||
if needsUserSwitching && !req.FeatureSupportsUserSwitch {
|
||||
return PrivilegeCheckResult{
|
||||
Allowed: false,
|
||||
Error: fmt.Errorf("%s: user switching not supported by this feature", req.FeatureName),
|
||||
RequiresUserSwitching: needsUserSwitching,
|
||||
}
|
||||
}
|
||||
|
||||
return PrivilegeCheckResult{
|
||||
Allowed: true,
|
||||
User: resolvedUser,
|
||||
RequiresUserSwitching: needsUserSwitching,
|
||||
}
|
||||
}
|
||||
|
||||
// resolveRequestedUser resolves a username to its canonical user identity
|
||||
func (s *Server) resolveRequestedUser(requestedUsername string) (*user.User, error) {
|
||||
if requestedUsername == "" {
|
||||
return getCurrentUser()
|
||||
}
|
||||
|
||||
if err := validateUsername(requestedUsername); err != nil {
|
||||
return nil, fmt.Errorf("invalid username: %w", err)
|
||||
}
|
||||
|
||||
u, err := lookupUser(requestedUsername)
|
||||
if err != nil {
|
||||
return nil, &UserNotFoundError{Username: requestedUsername, Cause: err}
|
||||
}
|
||||
return u, nil
|
||||
}
|
||||
|
||||
// isSameResolvedUser compares two resolved user identities
|
||||
func isSameResolvedUser(user1, user2 *user.User) bool {
|
||||
if user1 == nil || user2 == nil {
|
||||
return user1 == user2
|
||||
}
|
||||
return user1.Uid == user2.Uid
|
||||
}
|
||||
|
||||
// privilegeCheckContext holds all context needed for privilege checking
|
||||
type privilegeCheckContext struct {
|
||||
currentUser *user.User
|
||||
currentUserPrivileged bool
|
||||
allowRoot bool
|
||||
}
|
||||
|
||||
// isSameUser checks if two usernames refer to the same user
|
||||
// SECURITY: This function must be conservative - it should only return true
|
||||
// when we're certain both usernames refer to the exact same user identity
|
||||
func isSameUser(requestedUsername, currentUsername string) bool {
|
||||
// Empty requested username means current user
|
||||
if requestedUsername == "" {
|
||||
return true
|
||||
}
|
||||
|
||||
// Exact match (most common case)
|
||||
if getCurrentOS() == "windows" {
|
||||
if strings.EqualFold(requestedUsername, currentUsername) {
|
||||
return true
|
||||
}
|
||||
} else {
|
||||
if requestedUsername == currentUsername {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// Windows domain resolution: only allow domain stripping when comparing
|
||||
// a bare username against the current user's domain-qualified name
|
||||
if getCurrentOS() == "windows" {
|
||||
return isWindowsSameUser(requestedUsername, currentUsername)
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// isWindowsSameUser handles Windows-specific user comparison with domain logic
|
||||
func isWindowsSameUser(requestedUsername, currentUsername string) bool {
|
||||
// Extract domain and username parts
|
||||
extractParts := func(name string) (domain, user string) {
|
||||
// Handle DOMAIN\username format
|
||||
if idx := strings.LastIndex(name, `\`); idx != -1 {
|
||||
return name[:idx], name[idx+1:]
|
||||
}
|
||||
// Handle user@domain.com format
|
||||
if idx := strings.Index(name, "@"); idx != -1 {
|
||||
return name[idx+1:], name[:idx]
|
||||
}
|
||||
// No domain specified - local machine
|
||||
return "", name
|
||||
}
|
||||
|
||||
reqDomain, reqUser := extractParts(requestedUsername)
|
||||
curDomain, curUser := extractParts(currentUsername)
|
||||
|
||||
// Case-insensitive username comparison
|
||||
if !strings.EqualFold(reqUser, curUser) {
|
||||
return false
|
||||
}
|
||||
|
||||
// If requested username has no domain, it refers to local machine user
|
||||
// Allow this to match the current user regardless of current user's domain
|
||||
if reqDomain == "" {
|
||||
return true
|
||||
}
|
||||
|
||||
// If both have domains, they must match exactly (case-insensitive)
|
||||
return strings.EqualFold(reqDomain, curDomain)
|
||||
}
|
||||
|
||||
// SetAllowRootLogin configures root login access
|
||||
func (s *Server) SetAllowRootLogin(allow bool) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.allowRootLogin = allow
|
||||
}
|
||||
|
||||
// userNameLookup performs user lookup with root login permission check
|
||||
func (s *Server) userNameLookup(username string) (*user.User, error) {
|
||||
result := s.CheckPrivileges(PrivilegeCheckRequest{
|
||||
RequestedUsername: username,
|
||||
FeatureSupportsUserSwitch: true,
|
||||
FeatureName: FeatureSSHLogin,
|
||||
})
|
||||
|
||||
if !result.Allowed {
|
||||
return nil, result.Error
|
||||
}
|
||||
|
||||
return result.User, nil
|
||||
}
|
||||
|
||||
// userPrivilegeCheck performs user lookup with full privilege check result
|
||||
func (s *Server) userPrivilegeCheck(username string) (PrivilegeCheckResult, error) {
|
||||
result := s.CheckPrivileges(PrivilegeCheckRequest{
|
||||
RequestedUsername: username,
|
||||
FeatureSupportsUserSwitch: true,
|
||||
FeatureName: FeatureSSHLogin,
|
||||
})
|
||||
|
||||
if !result.Allowed {
|
||||
return result, result.Error
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// isPrivilegedUsername checks if the given username represents a privileged user across platforms.
|
||||
// On Unix: root
|
||||
// On Windows: Administrator, SYSTEM (case-insensitive)
|
||||
// Handles domain-qualified usernames like "DOMAIN\Administrator" or "user@domain.com"
|
||||
func isPrivilegedUsername(username string) bool {
|
||||
if getCurrentOS() != "windows" {
|
||||
return username == "root"
|
||||
}
|
||||
|
||||
bareUsername := username
|
||||
// Handle Windows domain format: DOMAIN\username
|
||||
if idx := strings.LastIndex(username, `\`); idx != -1 {
|
||||
bareUsername = username[idx+1:]
|
||||
}
|
||||
// Handle email-style format: username@domain.com
|
||||
if idx := strings.Index(bareUsername, "@"); idx != -1 {
|
||||
bareUsername = bareUsername[:idx]
|
||||
}
|
||||
|
||||
return isWindowsPrivilegedUser(bareUsername)
|
||||
}
|
||||
|
||||
// isWindowsPrivilegedUser checks if a bare username (domain already stripped) represents a Windows privileged account
|
||||
func isWindowsPrivilegedUser(bareUsername string) bool {
|
||||
// common privileged usernames (case insensitive)
|
||||
privilegedNames := []string{
|
||||
"administrator",
|
||||
"admin",
|
||||
"root",
|
||||
"system",
|
||||
"localsystem",
|
||||
"networkservice",
|
||||
"localservice",
|
||||
}
|
||||
|
||||
usernameLower := strings.ToLower(bareUsername)
|
||||
for _, privilegedName := range privilegedNames {
|
||||
if usernameLower == privilegedName {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// computer accounts (ending with $) are not privileged by themselves
|
||||
// They only gain privileges through group membership or specific SIDs
|
||||
|
||||
if targetUser, err := lookupUser(bareUsername); err == nil {
|
||||
return isWindowsPrivilegedSID(targetUser.Uid)
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// isWindowsPrivilegedSID checks if a Windows SID represents a privileged account
|
||||
func isWindowsPrivilegedSID(sid string) bool {
|
||||
privilegedSIDs := []string{
|
||||
"S-1-5-18", // Local System (SYSTEM)
|
||||
"S-1-5-19", // Local Service (NT AUTHORITY\LOCAL SERVICE)
|
||||
"S-1-5-20", // Network Service (NT AUTHORITY\NETWORK SERVICE)
|
||||
"S-1-5-32-544", // Administrators group (BUILTIN\Administrators)
|
||||
"S-1-5-500", // Built-in Administrator account (local machine RID 500)
|
||||
}
|
||||
|
||||
for _, privilegedSID := range privilegedSIDs {
|
||||
if sid == privilegedSID {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// Check for domain administrator accounts (RID 500 in any domain)
|
||||
// Format: S-1-5-21-domain-domain-domain-500
|
||||
// This is reliable as RID 500 is reserved for the domain Administrator account
|
||||
if strings.HasPrefix(sid, "S-1-5-21-") && strings.HasSuffix(sid, "-500") {
|
||||
return true
|
||||
}
|
||||
|
||||
// Check for other well-known privileged RIDs in domain contexts
|
||||
// RID 512 = Domain Admins group, RID 516 = Domain Controllers group
|
||||
if strings.HasPrefix(sid, "S-1-5-21-") {
|
||||
if strings.HasSuffix(sid, "-512") || // Domain Admins group
|
||||
strings.HasSuffix(sid, "-516") || // Domain Controllers group
|
||||
strings.HasSuffix(sid, "-519") { // Enterprise Admins group
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// isCurrentProcessPrivileged checks if the current process is running with elevated privileges.
|
||||
// On Unix systems, this means running as root (UID 0).
|
||||
// On Windows, this means running as Administrator or SYSTEM.
|
||||
func isCurrentProcessPrivileged() bool {
|
||||
if getCurrentOS() == "windows" {
|
||||
return isWindowsElevated()
|
||||
}
|
||||
return getEuid() == 0
|
||||
}
|
||||
|
||||
// isWindowsElevated checks if the current process is running with elevated privileges on Windows
|
||||
func isWindowsElevated() bool {
|
||||
currentUser, err := getCurrentUser()
|
||||
if err != nil {
|
||||
log.Errorf("failed to get current user for privilege check, assuming non-privileged: %v", err)
|
||||
return false
|
||||
}
|
||||
|
||||
if isWindowsPrivilegedSID(currentUser.Uid) {
|
||||
log.Debugf("Windows user switching supported: running as privileged SID %s", currentUser.Uid)
|
||||
return true
|
||||
}
|
||||
|
||||
if isPrivilegedUsername(currentUser.Username) {
|
||||
log.Debugf("Windows user switching supported: running as privileged username %s", currentUser.Username)
|
||||
return true
|
||||
}
|
||||
|
||||
log.Debugf("Windows user switching not supported: not running as privileged user (current: %s)", currentUser.Uid)
|
||||
return false
|
||||
}
|
||||
8
client/ssh/server/user_utils_js.go
Normal file
8
client/ssh/server/user_utils_js.go
Normal file
@@ -0,0 +1,8 @@
|
||||
//go:build js
|
||||
|
||||
package server
|
||||
|
||||
// validateUsername is not supported on JS/WASM
|
||||
func validateUsername(_ string) error {
|
||||
return errNotSupported
|
||||
}
|
||||
908
client/ssh/server/user_utils_test.go
Normal file
908
client/ssh/server/user_utils_test.go
Normal file
@@ -0,0 +1,908 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"os/user"
|
||||
"runtime"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
// Test helper functions
|
||||
func createTestUser(username, uid, gid, homeDir string) *user.User {
|
||||
return &user.User{
|
||||
Uid: uid,
|
||||
Gid: gid,
|
||||
Username: username,
|
||||
Name: username,
|
||||
HomeDir: homeDir,
|
||||
}
|
||||
}
|
||||
|
||||
// Test dependency injection setup - injects platform dependencies to test real logic
|
||||
func setupTestDependencies(currentUser *user.User, currentUserErr error, os string, euid int, lookupUsers map[string]*user.User, lookupErrors map[string]error) func() {
|
||||
// Store originals
|
||||
originalGetCurrentUser := getCurrentUser
|
||||
originalLookupUser := lookupUser
|
||||
originalGetCurrentOS := getCurrentOS
|
||||
originalGetEuid := getEuid
|
||||
|
||||
// Reset caches to ensure clean test state
|
||||
|
||||
// Set test values - inject platform dependencies
|
||||
getCurrentUser = func() (*user.User, error) {
|
||||
return currentUser, currentUserErr
|
||||
}
|
||||
|
||||
lookupUser = func(username string) (*user.User, error) {
|
||||
if err, exists := lookupErrors[username]; exists {
|
||||
return nil, err
|
||||
}
|
||||
if userObj, exists := lookupUsers[username]; exists {
|
||||
return userObj, nil
|
||||
}
|
||||
return nil, errors.New("user: unknown user " + username)
|
||||
}
|
||||
|
||||
getCurrentOS = func() string {
|
||||
return os
|
||||
}
|
||||
|
||||
getEuid = func() int {
|
||||
return euid
|
||||
}
|
||||
|
||||
// Mock privilege detection based on the test user
|
||||
getIsProcessPrivileged = func() bool {
|
||||
if currentUser == nil {
|
||||
return false
|
||||
}
|
||||
// Check both username and SID for Windows systems
|
||||
if os == "windows" && isWindowsPrivilegedSID(currentUser.Uid) {
|
||||
return true
|
||||
}
|
||||
return isPrivilegedUsername(currentUser.Username)
|
||||
}
|
||||
|
||||
// Return cleanup function
|
||||
return func() {
|
||||
getCurrentUser = originalGetCurrentUser
|
||||
lookupUser = originalLookupUser
|
||||
getCurrentOS = originalGetCurrentOS
|
||||
getEuid = originalGetEuid
|
||||
|
||||
getIsProcessPrivileged = isCurrentProcessPrivileged
|
||||
|
||||
// Reset caches after test
|
||||
}
|
||||
}
|
||||
|
||||
func TestCheckPrivileges_ComprehensiveMatrix(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
os string
|
||||
euid int
|
||||
currentUser *user.User
|
||||
requestedUsername string
|
||||
featureSupportsUserSwitch bool
|
||||
allowRoot bool
|
||||
lookupUsers map[string]*user.User
|
||||
expectedAllowed bool
|
||||
expectedRequiresSwitch bool
|
||||
}{
|
||||
{
|
||||
name: "linux_root_can_switch_to_alice",
|
||||
os: "linux",
|
||||
euid: 0, // Root process
|
||||
currentUser: createTestUser("root", "0", "0", "/root"),
|
||||
requestedUsername: "alice",
|
||||
featureSupportsUserSwitch: true,
|
||||
allowRoot: true,
|
||||
lookupUsers: map[string]*user.User{
|
||||
"alice": createTestUser("alice", "1000", "1000", "/home/alice"),
|
||||
},
|
||||
expectedAllowed: true,
|
||||
expectedRequiresSwitch: true,
|
||||
},
|
||||
{
|
||||
name: "linux_non_root_fallback_to_current_user",
|
||||
os: "linux",
|
||||
euid: 1000, // Non-root process
|
||||
currentUser: createTestUser("alice", "1000", "1000", "/home/alice"),
|
||||
requestedUsername: "bob",
|
||||
featureSupportsUserSwitch: true,
|
||||
allowRoot: true,
|
||||
expectedAllowed: true, // Should fallback to current user (alice)
|
||||
expectedRequiresSwitch: false, // Fallback means no actual switching
|
||||
},
|
||||
{
|
||||
name: "windows_admin_can_switch_to_alice",
|
||||
os: "windows",
|
||||
euid: 1000, // Irrelevant on Windows
|
||||
currentUser: createTestUser("Administrator", "S-1-5-21-123456789-123456789-123456789-500", "S-1-5-32-544", "C:\\Users\\Administrator"),
|
||||
requestedUsername: "alice",
|
||||
featureSupportsUserSwitch: true,
|
||||
allowRoot: true,
|
||||
lookupUsers: map[string]*user.User{
|
||||
"alice": createTestUser("alice", "S-1-5-21-123456789-123456789-123456789-1001", "S-1-5-21-123456789-123456789-123456789-513", "C:\\Users\\alice"),
|
||||
},
|
||||
expectedAllowed: true,
|
||||
expectedRequiresSwitch: true,
|
||||
},
|
||||
{
|
||||
name: "windows_non_admin_no_fallback_hard_failure",
|
||||
os: "windows",
|
||||
euid: 1000, // Irrelevant on Windows
|
||||
currentUser: createTestUser("alice", "1001", "1001", "C:\\Users\\alice"),
|
||||
requestedUsername: "bob",
|
||||
featureSupportsUserSwitch: true,
|
||||
allowRoot: true,
|
||||
lookupUsers: map[string]*user.User{
|
||||
"bob": createTestUser("bob", "S-1-5-21-123456789-123456789-123456789-1002", "S-1-5-21-123456789-123456789-123456789-513", "C:\\Users\\bob"),
|
||||
},
|
||||
expectedAllowed: true, // Let OS decide - deferred security check
|
||||
expectedRequiresSwitch: true, // Different user was requested
|
||||
},
|
||||
// Comprehensive test matrix: non-root linux with different allowRoot settings
|
||||
{
|
||||
name: "linux_non_root_request_root_allowRoot_false",
|
||||
os: "linux",
|
||||
euid: 1000,
|
||||
currentUser: createTestUser("alice", "1000", "1000", "/home/alice"),
|
||||
requestedUsername: "root",
|
||||
featureSupportsUserSwitch: true,
|
||||
allowRoot: false,
|
||||
expectedAllowed: true, // Fallback allows access regardless of root setting
|
||||
expectedRequiresSwitch: false, // Fallback case, no switching
|
||||
},
|
||||
{
|
||||
name: "linux_non_root_request_root_allowRoot_true",
|
||||
os: "linux",
|
||||
euid: 1000,
|
||||
currentUser: createTestUser("alice", "1000", "1000", "/home/alice"),
|
||||
requestedUsername: "root",
|
||||
featureSupportsUserSwitch: true,
|
||||
allowRoot: true,
|
||||
expectedAllowed: true, // Should fallback to alice (non-privileged process)
|
||||
expectedRequiresSwitch: false, // Fallback means no actual switching
|
||||
},
|
||||
// Windows admin test matrix
|
||||
{
|
||||
name: "windows_admin_request_root_allowRoot_false",
|
||||
os: "windows",
|
||||
euid: 1000,
|
||||
currentUser: createTestUser("Administrator", "S-1-5-21-123456789-123456789-123456789-500", "S-1-5-32-544", "C:\\Users\\Administrator"),
|
||||
requestedUsername: "root",
|
||||
featureSupportsUserSwitch: true,
|
||||
allowRoot: false,
|
||||
expectedAllowed: false, // Root not allowed
|
||||
expectedRequiresSwitch: true,
|
||||
},
|
||||
{
|
||||
name: "windows_admin_request_root_allowRoot_true",
|
||||
os: "windows",
|
||||
euid: 1000,
|
||||
currentUser: createTestUser("Administrator", "S-1-5-21-123456789-123456789-123456789-500", "S-1-5-32-544", "C:\\Users\\Administrator"),
|
||||
requestedUsername: "root",
|
||||
featureSupportsUserSwitch: true,
|
||||
allowRoot: true,
|
||||
lookupUsers: map[string]*user.User{
|
||||
"root": createTestUser("root", "0", "0", "/root"),
|
||||
},
|
||||
expectedAllowed: true, // Windows user switching should work like Unix
|
||||
expectedRequiresSwitch: true,
|
||||
},
|
||||
// Windows non-admin test matrix
|
||||
{
|
||||
name: "windows_non_admin_request_root_allowRoot_false",
|
||||
os: "windows",
|
||||
euid: 1000,
|
||||
currentUser: createTestUser("alice", "S-1-5-21-123456789-123456789-123456789-1001", "S-1-5-21-123456789-123456789-123456789-513", "C:\\Users\\alice"),
|
||||
requestedUsername: "root",
|
||||
featureSupportsUserSwitch: true,
|
||||
allowRoot: false,
|
||||
expectedAllowed: false, // Root not allowed (allowRoot=false takes precedence)
|
||||
expectedRequiresSwitch: true,
|
||||
},
|
||||
{
|
||||
name: "windows_system_account_allowRoot_false",
|
||||
os: "windows",
|
||||
euid: 1000,
|
||||
currentUser: createTestUser("NETBIRD\\WIN2K19-C2$", "S-1-5-18", "S-1-5-18", "C:\\Windows\\System32"),
|
||||
requestedUsername: "root",
|
||||
featureSupportsUserSwitch: true,
|
||||
allowRoot: false,
|
||||
expectedAllowed: false, // Root not allowed
|
||||
expectedRequiresSwitch: true,
|
||||
},
|
||||
{
|
||||
name: "windows_system_account_allowRoot_true",
|
||||
os: "windows",
|
||||
euid: 1000,
|
||||
currentUser: createTestUser("NETBIRD\\WIN2K19-C2$", "S-1-5-18", "S-1-5-18", "C:\\Windows\\System32"),
|
||||
requestedUsername: "root",
|
||||
featureSupportsUserSwitch: true,
|
||||
allowRoot: true,
|
||||
lookupUsers: map[string]*user.User{
|
||||
"root": createTestUser("root", "0", "0", "/root"),
|
||||
},
|
||||
expectedAllowed: true, // SYSTEM can switch to root
|
||||
expectedRequiresSwitch: true,
|
||||
},
|
||||
{
|
||||
name: "windows_non_admin_request_root_allowRoot_true",
|
||||
os: "windows",
|
||||
euid: 1000,
|
||||
currentUser: createTestUser("alice", "S-1-5-21-123456789-123456789-123456789-1001", "S-1-5-21-123456789-123456789-123456789-513", "C:\\Users\\alice"),
|
||||
requestedUsername: "root",
|
||||
featureSupportsUserSwitch: true,
|
||||
allowRoot: true,
|
||||
lookupUsers: map[string]*user.User{
|
||||
"root": createTestUser("root", "0", "0", "/root"),
|
||||
},
|
||||
expectedAllowed: true, // Let OS decide - deferred security check
|
||||
expectedRequiresSwitch: true,
|
||||
},
|
||||
|
||||
// Feature doesn't support user switching scenarios
|
||||
{
|
||||
name: "linux_root_feature_no_user_switching_same_user",
|
||||
os: "linux",
|
||||
euid: 0,
|
||||
currentUser: createTestUser("root", "0", "0", "/root"),
|
||||
requestedUsername: "root", // Same user
|
||||
featureSupportsUserSwitch: false,
|
||||
allowRoot: true,
|
||||
lookupUsers: map[string]*user.User{
|
||||
"root": createTestUser("root", "0", "0", "/root"),
|
||||
},
|
||||
expectedAllowed: true, // Same user should work regardless of feature support
|
||||
expectedRequiresSwitch: false,
|
||||
},
|
||||
{
|
||||
name: "linux_root_feature_no_user_switching_different_user",
|
||||
os: "linux",
|
||||
euid: 0,
|
||||
currentUser: createTestUser("root", "0", "0", "/root"),
|
||||
requestedUsername: "alice",
|
||||
featureSupportsUserSwitch: false, // Feature doesn't support switching
|
||||
allowRoot: true,
|
||||
lookupUsers: map[string]*user.User{
|
||||
"alice": createTestUser("alice", "1000", "1000", "/home/alice"),
|
||||
},
|
||||
expectedAllowed: false, // Should deny because feature doesn't support switching
|
||||
expectedRequiresSwitch: true,
|
||||
},
|
||||
|
||||
// Empty username (current user) scenarios
|
||||
{
|
||||
name: "linux_non_root_current_user_empty_username",
|
||||
os: "linux",
|
||||
euid: 1000,
|
||||
currentUser: createTestUser("alice", "1000", "1000", "/home/alice"),
|
||||
requestedUsername: "", // Empty = current user
|
||||
featureSupportsUserSwitch: true,
|
||||
allowRoot: false,
|
||||
expectedAllowed: true, // Current user should always work
|
||||
expectedRequiresSwitch: false,
|
||||
},
|
||||
{
|
||||
name: "linux_root_current_user_empty_username_root_not_allowed",
|
||||
os: "linux",
|
||||
euid: 0,
|
||||
currentUser: createTestUser("root", "0", "0", "/root"),
|
||||
requestedUsername: "", // Empty = current user (root)
|
||||
featureSupportsUserSwitch: true,
|
||||
allowRoot: false, // Root not allowed
|
||||
expectedAllowed: false, // Should deny root even when it's current user
|
||||
expectedRequiresSwitch: false,
|
||||
},
|
||||
|
||||
// User not found scenarios
|
||||
{
|
||||
name: "linux_root_user_not_found",
|
||||
os: "linux",
|
||||
euid: 0,
|
||||
currentUser: createTestUser("root", "0", "0", "/root"),
|
||||
requestedUsername: "nonexistent",
|
||||
featureSupportsUserSwitch: true,
|
||||
allowRoot: true,
|
||||
lookupUsers: map[string]*user.User{}, // No users defined = user not found
|
||||
expectedAllowed: false, // Should fail due to user not found
|
||||
expectedRequiresSwitch: true,
|
||||
},
|
||||
|
||||
// Windows feature doesn't support user switching
|
||||
{
|
||||
name: "windows_admin_feature_no_user_switching_different_user",
|
||||
os: "windows",
|
||||
euid: 1000,
|
||||
currentUser: createTestUser("Administrator", "S-1-5-21-123456789-123456789-123456789-500", "S-1-5-32-544", "C:\\Users\\Administrator"),
|
||||
requestedUsername: "alice",
|
||||
featureSupportsUserSwitch: false, // Feature doesn't support switching
|
||||
allowRoot: true,
|
||||
lookupUsers: map[string]*user.User{
|
||||
"alice": createTestUser("alice", "S-1-5-21-123456789-123456789-123456789-1001", "S-1-5-21-123456789-123456789-123456789-513", "C:\\Users\\alice"),
|
||||
},
|
||||
expectedAllowed: false, // Should deny because feature doesn't support switching
|
||||
expectedRequiresSwitch: true,
|
||||
},
|
||||
|
||||
// Windows regular user scenarios (non-admin)
|
||||
{
|
||||
name: "windows_regular_user_same_user",
|
||||
os: "windows",
|
||||
euid: 1000,
|
||||
currentUser: createTestUser("alice", "S-1-5-21-123456789-123456789-123456789-1001", "S-1-5-21-123456789-123456789-123456789-513", "C:\\Users\\alice"),
|
||||
requestedUsername: "alice", // Same user
|
||||
featureSupportsUserSwitch: true,
|
||||
allowRoot: false,
|
||||
lookupUsers: map[string]*user.User{
|
||||
"alice": createTestUser("alice", "S-1-5-21-123456789-123456789-123456789-1001", "S-1-5-21-123456789-123456789-123456789-513", "C:\\Users\\alice"),
|
||||
},
|
||||
expectedAllowed: true, // Regular user accessing themselves should work
|
||||
expectedRequiresSwitch: false, // No switching for same user
|
||||
},
|
||||
{
|
||||
name: "windows_regular_user_empty_username",
|
||||
os: "windows",
|
||||
euid: 1000,
|
||||
currentUser: createTestUser("alice", "S-1-5-21-123456789-123456789-123456789-1001", "S-1-5-21-123456789-123456789-123456789-513", "C:\\Users\\alice"),
|
||||
requestedUsername: "", // Empty = current user
|
||||
featureSupportsUserSwitch: true,
|
||||
allowRoot: false,
|
||||
expectedAllowed: true, // Current user should always work
|
||||
expectedRequiresSwitch: false, // No switching for current user
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Inject platform dependencies to test real logic
|
||||
cleanup := setupTestDependencies(tt.currentUser, nil, tt.os, tt.euid, tt.lookupUsers, nil)
|
||||
defer cleanup()
|
||||
|
||||
server := &Server{allowRootLogin: tt.allowRoot}
|
||||
|
||||
result := server.CheckPrivileges(PrivilegeCheckRequest{
|
||||
RequestedUsername: tt.requestedUsername,
|
||||
FeatureSupportsUserSwitch: tt.featureSupportsUserSwitch,
|
||||
FeatureName: "SSH login",
|
||||
})
|
||||
|
||||
assert.Equal(t, tt.expectedAllowed, result.Allowed)
|
||||
assert.Equal(t, tt.expectedRequiresSwitch, result.RequiresUserSwitching)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestUsedFallback_MeansNoPrivilegeDropping(t *testing.T) {
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("Fallback mechanism is Unix-specific")
|
||||
}
|
||||
|
||||
// Create test scenario where fallback should occur
|
||||
server := &Server{allowRootLogin: true}
|
||||
|
||||
// Mock dependencies to simulate non-privileged user
|
||||
originalGetCurrentUser := getCurrentUser
|
||||
originalGetIsProcessPrivileged := getIsProcessPrivileged
|
||||
|
||||
defer func() {
|
||||
getCurrentUser = originalGetCurrentUser
|
||||
getIsProcessPrivileged = originalGetIsProcessPrivileged
|
||||
|
||||
}()
|
||||
|
||||
// Set up mocks for fallback scenario
|
||||
getCurrentUser = func() (*user.User, error) {
|
||||
return createTestUser("netbird", "1000", "1000", "/var/lib/netbird"), nil
|
||||
}
|
||||
getIsProcessPrivileged = func() bool { return false } // Non-privileged
|
||||
|
||||
// Request different user - should fallback
|
||||
result := server.CheckPrivileges(PrivilegeCheckRequest{
|
||||
RequestedUsername: "alice",
|
||||
FeatureSupportsUserSwitch: true,
|
||||
FeatureName: "SSH login",
|
||||
})
|
||||
|
||||
// Verify fallback occurred
|
||||
assert.True(t, result.Allowed, "Should allow with fallback")
|
||||
assert.True(t, result.UsedFallback, "Should indicate fallback was used")
|
||||
assert.Equal(t, "netbird", result.User.Username, "Should return current user")
|
||||
assert.False(t, result.RequiresUserSwitching, "Should not require switching when fallback is used")
|
||||
|
||||
// Key assertion: When UsedFallback is true, no privilege dropping should be needed
|
||||
// because all privilege checks have already been performed and we're using current user
|
||||
t.Logf("UsedFallback=true means: current user (%s) is the target, no privilege dropping needed",
|
||||
result.User.Username)
|
||||
}
|
||||
|
||||
func TestPrivilegedUsernameDetection(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
username string
|
||||
platform string
|
||||
privileged bool
|
||||
}{
|
||||
// Unix/Linux tests
|
||||
{"unix_root", "root", "linux", true},
|
||||
{"unix_regular_user", "alice", "linux", false},
|
||||
{"unix_root_capital", "Root", "linux", false}, // Case-sensitive
|
||||
|
||||
// Windows tests
|
||||
{"windows_administrator", "Administrator", "windows", true},
|
||||
{"windows_system", "SYSTEM", "windows", true},
|
||||
{"windows_admin", "admin", "windows", true},
|
||||
{"windows_admin_lowercase", "administrator", "windows", true}, // Case-insensitive
|
||||
{"windows_domain_admin", "DOMAIN\\Administrator", "windows", true},
|
||||
{"windows_email_admin", "admin@domain.com", "windows", true},
|
||||
{"windows_regular_user", "alice", "windows", false},
|
||||
{"windows_domain_user", "DOMAIN\\alice", "windows", false},
|
||||
{"windows_localsystem", "localsystem", "windows", true},
|
||||
{"windows_networkservice", "networkservice", "windows", true},
|
||||
{"windows_localservice", "localservice", "windows", true},
|
||||
|
||||
// Computer accounts (these depend on current user context in real implementation)
|
||||
{"windows_computer_account", "WIN2K19-C2$", "windows", false}, // Computer account by itself not privileged
|
||||
{"windows_domain_computer", "DOMAIN\\COMPUTER$", "windows", false}, // Domain computer account
|
||||
|
||||
// Cross-platform
|
||||
{"root_on_windows", "root", "windows", true}, // Root should be privileged everywhere
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Mock the platform for this test
|
||||
cleanup := setupTestDependencies(nil, nil, tt.platform, 1000, nil, nil)
|
||||
defer cleanup()
|
||||
|
||||
result := isPrivilegedUsername(tt.username)
|
||||
assert.Equal(t, tt.privileged, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestWindowsPrivilegedSIDDetection(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
sid string
|
||||
privileged bool
|
||||
description string
|
||||
}{
|
||||
// Well-known system accounts
|
||||
{"system_account", "S-1-5-18", true, "Local System (SYSTEM)"},
|
||||
{"local_service", "S-1-5-19", true, "Local Service"},
|
||||
{"network_service", "S-1-5-20", true, "Network Service"},
|
||||
{"administrators_group", "S-1-5-32-544", true, "Administrators group"},
|
||||
{"builtin_administrator", "S-1-5-500", true, "Built-in Administrator"},
|
||||
|
||||
// Domain accounts
|
||||
{"domain_administrator", "S-1-5-21-1234567890-1234567890-1234567890-500", true, "Domain Administrator (RID 500)"},
|
||||
{"domain_admins_group", "S-1-5-21-1234567890-1234567890-1234567890-512", true, "Domain Admins group"},
|
||||
{"domain_controllers_group", "S-1-5-21-1234567890-1234567890-1234567890-516", true, "Domain Controllers group"},
|
||||
{"enterprise_admins_group", "S-1-5-21-1234567890-1234567890-1234567890-519", true, "Enterprise Admins group"},
|
||||
|
||||
// Regular users
|
||||
{"regular_user", "S-1-5-21-1234567890-1234567890-1234567890-1001", false, "Regular domain user"},
|
||||
{"another_regular_user", "S-1-5-21-1234567890-1234567890-1234567890-1234", false, "Another regular user"},
|
||||
{"local_user", "S-1-5-21-1234567890-1234567890-1234567890-1000", false, "Local regular user"},
|
||||
|
||||
// Groups that are not privileged
|
||||
{"domain_users", "S-1-5-21-1234567890-1234567890-1234567890-513", false, "Domain Users group"},
|
||||
{"power_users", "S-1-5-32-547", false, "Power Users group"},
|
||||
|
||||
// Invalid SIDs
|
||||
{"malformed_sid", "S-1-5-invalid", false, "Malformed SID"},
|
||||
{"empty_sid", "", false, "Empty SID"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := isWindowsPrivilegedSID(tt.sid)
|
||||
assert.Equal(t, tt.privileged, result, "Failed for %s: %s", tt.description, tt.sid)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsSameUser(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
user1 string
|
||||
user2 string
|
||||
os string
|
||||
expected bool
|
||||
}{
|
||||
// Basic cases
|
||||
{"same_username", "alice", "alice", "linux", true},
|
||||
{"different_username", "alice", "bob", "linux", false},
|
||||
|
||||
// Linux (no domain processing)
|
||||
{"linux_domain_vs_bare", "DOMAIN\\alice", "alice", "linux", false},
|
||||
{"linux_email_vs_bare", "alice@domain.com", "alice", "linux", false},
|
||||
{"linux_same_literal", "DOMAIN\\alice", "DOMAIN\\alice", "linux", true},
|
||||
|
||||
// Windows (with domain processing) - Note: parameter order is (requested, current, os, expected)
|
||||
{"windows_domain_vs_bare", "alice", "DOMAIN\\alice", "windows", true}, // bare username matches domain current user
|
||||
{"windows_email_vs_bare", "alice", "alice@domain.com", "windows", true}, // bare username matches email current user
|
||||
{"windows_different_domains_same_user", "DOMAIN1\\alice", "DOMAIN2\\alice", "windows", false}, // SECURITY: different domains = different users
|
||||
{"windows_case_insensitive", "Alice", "alice", "windows", true},
|
||||
{"windows_different_users", "DOMAIN\\alice", "DOMAIN\\bob", "windows", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Set up OS mock
|
||||
cleanup := setupTestDependencies(nil, nil, tt.os, 1000, nil, nil)
|
||||
defer cleanup()
|
||||
|
||||
result := isSameUser(tt.user1, tt.user2)
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestUsernameValidation_Unix(t *testing.T) {
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("Unix-specific username validation tests")
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
username string
|
||||
wantErr bool
|
||||
errMsg string
|
||||
}{
|
||||
// Valid usernames (Unix/POSIX)
|
||||
{"valid_alphanumeric", "user123", false, ""},
|
||||
{"valid_with_dots", "user.name", false, ""},
|
||||
{"valid_with_hyphens", "user-name", false, ""},
|
||||
{"valid_with_underscores", "user_name", false, ""},
|
||||
{"valid_uppercase", "UserName", false, ""},
|
||||
{"valid_starting_with_digit", "123user", false, ""},
|
||||
{"valid_starting_with_dot", ".hidden", false, ""},
|
||||
|
||||
// Invalid usernames (Unix/POSIX)
|
||||
{"empty_username", "", true, "username cannot be empty"},
|
||||
{"username_too_long", "thisusernameiswaytoolongandexceedsthe32characterlimit", true, "username too long"},
|
||||
{"username_starting_with_hyphen", "-user", true, "invalid characters"}, // POSIX restriction
|
||||
{"username_with_spaces", "user name", true, "invalid characters"},
|
||||
{"username_with_shell_metacharacters", "user;rm", true, "invalid characters"},
|
||||
{"username_with_command_injection", "user`rm -rf /`", true, "invalid characters"},
|
||||
{"username_with_pipe", "user|rm", true, "invalid characters"},
|
||||
{"username_with_ampersand", "user&rm", true, "invalid characters"},
|
||||
{"username_with_quotes", "user\"name", true, "invalid characters"},
|
||||
{"username_with_newline", "user\nname", true, "invalid characters"},
|
||||
{"reserved_dot", ".", true, "cannot be '.' or '..'"},
|
||||
{"reserved_dotdot", "..", true, "cannot be '.' or '..'"},
|
||||
{"username_with_at_symbol", "user@domain", true, "invalid characters"}, // Not allowed in bare Unix usernames
|
||||
{"username_with_backslash", "user\\name", true, "invalid characters"}, // Not allowed in Unix usernames
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := validateUsername(tt.username)
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err, "Should reject invalid username")
|
||||
if tt.errMsg != "" {
|
||||
assert.Contains(t, err.Error(), tt.errMsg, "Error message should contain expected text")
|
||||
}
|
||||
} else {
|
||||
assert.NoError(t, err, "Should accept valid username")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestUsernameValidation_Windows(t *testing.T) {
|
||||
if runtime.GOOS != "windows" {
|
||||
t.Skip("Windows-specific username validation tests")
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
username string
|
||||
wantErr bool
|
||||
errMsg string
|
||||
}{
|
||||
// Valid usernames (Windows)
|
||||
{"valid_alphanumeric", "user123", false, ""},
|
||||
{"valid_with_dots", "user.name", false, ""},
|
||||
{"valid_with_hyphens", "user-name", false, ""},
|
||||
{"valid_with_underscores", "user_name", false, ""},
|
||||
{"valid_uppercase", "UserName", false, ""},
|
||||
{"valid_starting_with_digit", "123user", false, ""},
|
||||
{"valid_starting_with_dot", ".hidden", false, ""},
|
||||
{"valid_starting_with_hyphen", "-user", false, ""}, // Windows allows this
|
||||
{"valid_domain_username", "DOMAIN\\user", false, ""}, // Windows domain format
|
||||
{"valid_email_username", "user@domain.com", false, ""}, // Windows email format
|
||||
{"valid_machine_username", "MACHINE\\user", false, ""}, // Windows machine format
|
||||
|
||||
// Invalid usernames (Windows)
|
||||
{"empty_username", "", true, "username cannot be empty"},
|
||||
{"username_too_long", "thisusernameiswaytoolongandexceedsthe32characterlimit", true, "username too long"},
|
||||
{"username_with_spaces", "user name", true, "invalid characters"},
|
||||
{"username_with_shell_metacharacters", "user;rm", true, "invalid characters"},
|
||||
{"username_with_command_injection", "user`rm -rf /`", true, "invalid characters"},
|
||||
{"username_with_pipe", "user|rm", true, "invalid characters"},
|
||||
{"username_with_ampersand", "user&rm", true, "invalid characters"},
|
||||
{"username_with_quotes", "user\"name", true, "invalid characters"},
|
||||
{"username_with_newline", "user\nname", true, "invalid characters"},
|
||||
{"username_with_brackets", "user[name]", true, "invalid characters"},
|
||||
{"username_with_colon", "user:name", true, "invalid characters"},
|
||||
{"username_with_semicolon", "user;name", true, "invalid characters"},
|
||||
{"username_with_equals", "user=name", true, "invalid characters"},
|
||||
{"username_with_comma", "user,name", true, "invalid characters"},
|
||||
{"username_with_plus", "user+name", true, "invalid characters"},
|
||||
{"username_with_asterisk", "user*name", true, "invalid characters"},
|
||||
{"username_with_question", "user?name", true, "invalid characters"},
|
||||
{"username_with_angles", "user<name>", true, "invalid characters"},
|
||||
{"reserved_dot", ".", true, "cannot be '.' or '..'"},
|
||||
{"reserved_dotdot", "..", true, "cannot be '.' or '..'"},
|
||||
{"username_ending_with_period", "user.", true, "cannot end with a period"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := validateUsername(tt.username)
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err, "Should reject invalid username")
|
||||
if tt.errMsg != "" {
|
||||
assert.Contains(t, err.Error(), tt.errMsg, "Error message should contain expected text")
|
||||
}
|
||||
} else {
|
||||
assert.NoError(t, err, "Should accept valid username")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Test real-world integration scenarios with actual platform capabilities
|
||||
func TestCheckPrivileges_RealWorldScenarios(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
feature string
|
||||
featureSupportsUserSwitch bool
|
||||
requestedUsername string
|
||||
allowRoot bool
|
||||
expectedBehaviorPattern string
|
||||
}{
|
||||
{"SSH_login_current_user", "SSH login", true, "", true, "should_allow_current_user"},
|
||||
{"SFTP_current_user", "SFTP", true, "", true, "should_allow_current_user"},
|
||||
{"port_forwarding_current_user", "port forwarding", false, "", true, "should_allow_current_user"},
|
||||
{"SSH_login_root_not_allowed", "SSH login", true, "root", false, "should_deny_root"},
|
||||
{"port_forwarding_different_user", "port forwarding", false, "differentuser", true, "should_deny_switching"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Mock privileged environment to ensure consistent test behavior across environments
|
||||
cleanup := setupTestDependencies(
|
||||
createTestUser("root", "0", "0", "/root"), // Running as root
|
||||
nil,
|
||||
runtime.GOOS,
|
||||
0, // euid 0 (root)
|
||||
map[string]*user.User{
|
||||
"root": createTestUser("root", "0", "0", "/root"),
|
||||
"differentuser": createTestUser("differentuser", "1000", "1000", "/home/differentuser"),
|
||||
},
|
||||
nil,
|
||||
)
|
||||
defer cleanup()
|
||||
|
||||
server := &Server{allowRootLogin: tt.allowRoot}
|
||||
|
||||
result := server.CheckPrivileges(PrivilegeCheckRequest{
|
||||
RequestedUsername: tt.requestedUsername,
|
||||
FeatureSupportsUserSwitch: tt.featureSupportsUserSwitch,
|
||||
FeatureName: tt.feature,
|
||||
})
|
||||
|
||||
switch tt.expectedBehaviorPattern {
|
||||
case "should_allow_current_user":
|
||||
assert.True(t, result.Allowed, "Should allow current user access")
|
||||
assert.False(t, result.RequiresUserSwitching, "Current user should not require switching")
|
||||
case "should_deny_root":
|
||||
assert.False(t, result.Allowed, "Should deny root when not allowed")
|
||||
assert.Contains(t, result.Error.Error(), "root", "Should mention root in error")
|
||||
case "should_deny_switching":
|
||||
assert.False(t, result.Allowed, "Should deny when feature doesn't support switching")
|
||||
assert.Contains(t, result.Error.Error(), "user switching not supported", "Should mention switching in error")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Test with actual platform capabilities - no mocking
|
||||
func TestCheckPrivileges_ActualPlatform(t *testing.T) {
|
||||
// This test uses the REAL platform capabilities
|
||||
server := &Server{allowRootLogin: true}
|
||||
|
||||
// Test current user access - should always work
|
||||
result := server.CheckPrivileges(PrivilegeCheckRequest{
|
||||
RequestedUsername: "", // Current user
|
||||
FeatureSupportsUserSwitch: true,
|
||||
FeatureName: "SSH login",
|
||||
})
|
||||
|
||||
assert.True(t, result.Allowed, "Current user should always be allowed")
|
||||
assert.False(t, result.RequiresUserSwitching, "Current user should not require switching")
|
||||
assert.NotNil(t, result.User, "Should return current user")
|
||||
|
||||
// Test user switching capability based on actual platform
|
||||
actualIsPrivileged := isCurrentProcessPrivileged() // REAL check
|
||||
actualOS := runtime.GOOS // REAL check
|
||||
|
||||
t.Logf("Platform capabilities: OS=%s, isPrivileged=%v, supportsUserSwitching=%v",
|
||||
actualOS, actualIsPrivileged, actualIsPrivileged)
|
||||
|
||||
// Test requesting different user
|
||||
result = server.CheckPrivileges(PrivilegeCheckRequest{
|
||||
RequestedUsername: "nonexistentuser",
|
||||
FeatureSupportsUserSwitch: true,
|
||||
FeatureName: "SSH login",
|
||||
})
|
||||
|
||||
switch {
|
||||
case actualOS == "windows":
|
||||
// Windows supports user switching but should fail on nonexistent user
|
||||
assert.False(t, result.Allowed, "Windows should deny nonexistent user")
|
||||
assert.True(t, result.RequiresUserSwitching, "Should indicate switching is needed")
|
||||
assert.Contains(t, result.Error.Error(), "not found",
|
||||
"Should indicate user not found")
|
||||
case !actualIsPrivileged:
|
||||
// Non-privileged Unix processes should fallback to current user
|
||||
assert.True(t, result.Allowed, "Non-privileged Unix process should fallback to current user")
|
||||
assert.False(t, result.RequiresUserSwitching, "Fallback means no switching actually happens")
|
||||
assert.True(t, result.UsedFallback, "Should indicate fallback was used")
|
||||
assert.NotNil(t, result.User, "Should return current user")
|
||||
default:
|
||||
// Privileged Unix processes should attempt user lookup
|
||||
assert.False(t, result.Allowed, "Should fail due to nonexistent user")
|
||||
assert.True(t, result.RequiresUserSwitching, "Should indicate switching is needed")
|
||||
assert.Contains(t, result.Error.Error(), "nonexistentuser",
|
||||
"Should indicate user not found")
|
||||
}
|
||||
}
|
||||
|
||||
// Test platform detection logic with dependency injection
|
||||
func TestPlatformLogic_DependencyInjection(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
os string
|
||||
euid int
|
||||
currentUser *user.User
|
||||
expectedIsProcessPrivileged bool
|
||||
expectedSupportsUserSwitching bool
|
||||
}{
|
||||
{
|
||||
name: "linux_root_process",
|
||||
os: "linux",
|
||||
euid: 0,
|
||||
currentUser: createTestUser("root", "0", "0", "/root"),
|
||||
expectedIsProcessPrivileged: true,
|
||||
expectedSupportsUserSwitching: true,
|
||||
},
|
||||
{
|
||||
name: "linux_non_root_process",
|
||||
os: "linux",
|
||||
euid: 1000,
|
||||
currentUser: createTestUser("alice", "1000", "1000", "/home/alice"),
|
||||
expectedIsProcessPrivileged: false,
|
||||
expectedSupportsUserSwitching: false,
|
||||
},
|
||||
{
|
||||
name: "windows_admin_process",
|
||||
os: "windows",
|
||||
euid: 1000, // euid ignored on Windows
|
||||
currentUser: createTestUser("Administrator", "S-1-5-21-123456789-123456789-123456789-500", "S-1-5-32-544", "C:\\Users\\Administrator"),
|
||||
expectedIsProcessPrivileged: true,
|
||||
expectedSupportsUserSwitching: true, // Windows supports user switching when privileged
|
||||
},
|
||||
{
|
||||
name: "windows_regular_process",
|
||||
os: "windows",
|
||||
euid: 1000, // euid ignored on Windows
|
||||
currentUser: createTestUser("alice", "1001", "1001", "C:\\Users\\alice"),
|
||||
expectedIsProcessPrivileged: false,
|
||||
expectedSupportsUserSwitching: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Inject platform dependencies and test REAL logic
|
||||
cleanup := setupTestDependencies(tt.currentUser, nil, tt.os, tt.euid, nil, nil)
|
||||
defer cleanup()
|
||||
|
||||
// Test the actual functions with injected dependencies
|
||||
actualIsPrivileged := isCurrentProcessPrivileged()
|
||||
actualSupportsUserSwitching := actualIsPrivileged
|
||||
|
||||
assert.Equal(t, tt.expectedIsProcessPrivileged, actualIsPrivileged,
|
||||
"isCurrentProcessPrivileged() result mismatch")
|
||||
assert.Equal(t, tt.expectedSupportsUserSwitching, actualSupportsUserSwitching,
|
||||
"supportsUserSwitching() result mismatch")
|
||||
|
||||
t.Logf("Platform: %s, EUID: %d, User: %s", tt.os, tt.euid, tt.currentUser.Username)
|
||||
t.Logf("Results: isPrivileged=%v, supportsUserSwitching=%v",
|
||||
actualIsPrivileged, actualSupportsUserSwitching)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCheckPrivileges_WindowsElevatedUserSwitching(t *testing.T) {
|
||||
// Test Windows elevated user switching scenarios with simplified privilege logic
|
||||
tests := []struct {
|
||||
name string
|
||||
currentUser *user.User
|
||||
requestedUsername string
|
||||
allowRoot bool
|
||||
expectedAllowed bool
|
||||
expectedErrorContains string
|
||||
}{
|
||||
{
|
||||
name: "windows_admin_can_switch_to_alice",
|
||||
currentUser: createTestUser("administrator", "S-1-5-21-123456789-123456789-123456789-500", "S-1-5-32-544", "C:\\\\Users\\\\Administrator"),
|
||||
requestedUsername: "alice",
|
||||
allowRoot: true,
|
||||
expectedAllowed: true,
|
||||
},
|
||||
{
|
||||
name: "windows_non_admin_can_try_switch",
|
||||
currentUser: createTestUser("alice", "S-1-5-21-123456789-123456789-123456789-1001", "S-1-5-21-123456789-123456789-123456789-513", "C:\\\\Users\\\\alice"),
|
||||
requestedUsername: "bob",
|
||||
allowRoot: true,
|
||||
expectedAllowed: true, // Privilege check allows it, OS will reject during execution
|
||||
},
|
||||
{
|
||||
name: "windows_system_can_switch_to_alice",
|
||||
currentUser: createTestUser("SYSTEM", "S-1-5-18", "S-1-5-18", "C:\\\\Windows\\\\system32\\\\config\\\\systemprofile"),
|
||||
requestedUsername: "alice",
|
||||
allowRoot: true,
|
||||
expectedAllowed: true,
|
||||
},
|
||||
{
|
||||
name: "windows_admin_root_not_allowed",
|
||||
currentUser: createTestUser("administrator", "S-1-5-21-123456789-123456789-123456789-500", "S-1-5-32-544", "C:\\\\Users\\\\Administrator"),
|
||||
requestedUsername: "root",
|
||||
allowRoot: false,
|
||||
expectedAllowed: false,
|
||||
expectedErrorContains: "privileged user login is disabled",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Setup test dependencies with Windows OS and specified privileges
|
||||
lookupUsers := map[string]*user.User{
|
||||
tt.requestedUsername: createTestUser(tt.requestedUsername, "1002", "1002", "C:\\\\Users\\\\"+tt.requestedUsername),
|
||||
}
|
||||
cleanup := setupTestDependencies(tt.currentUser, nil, "windows", 1000, lookupUsers, nil)
|
||||
defer cleanup()
|
||||
|
||||
server := &Server{allowRootLogin: tt.allowRoot}
|
||||
|
||||
result := server.CheckPrivileges(PrivilegeCheckRequest{
|
||||
RequestedUsername: tt.requestedUsername,
|
||||
FeatureSupportsUserSwitch: true,
|
||||
FeatureName: "SSH login",
|
||||
})
|
||||
|
||||
assert.Equal(t, tt.expectedAllowed, result.Allowed,
|
||||
"Privilege check result should match expected for %s", tt.name)
|
||||
|
||||
if !tt.expectedAllowed && tt.expectedErrorContains != "" {
|
||||
assert.NotNil(t, result.Error, "Should have error when not allowed")
|
||||
assert.Contains(t, result.Error.Error(), tt.expectedErrorContains,
|
||||
"Error should contain expected message")
|
||||
}
|
||||
|
||||
if tt.expectedAllowed && tt.requestedUsername != "" && tt.currentUser.Username != tt.requestedUsername {
|
||||
assert.True(t, result.RequiresUserSwitching, "Should require user switching for different user")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
8
client/ssh/server/userswitching_js.go
Normal file
8
client/ssh/server/userswitching_js.go
Normal file
@@ -0,0 +1,8 @@
|
||||
//go:build js
|
||||
|
||||
package server
|
||||
|
||||
// enableUserSwitching is not supported on JS/WASM
|
||||
func enableUserSwitching() error {
|
||||
return errNotSupported
|
||||
}
|
||||
228
client/ssh/server/userswitching_unix.go
Normal file
228
client/ssh/server/userswitching_unix.go
Normal file
@@ -0,0 +1,228 @@
|
||||
//go:build unix
|
||||
|
||||
package server
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
"os/exec"
|
||||
"os/user"
|
||||
"regexp"
|
||||
"runtime"
|
||||
"strconv"
|
||||
|
||||
"github.com/gliderlabs/ssh"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// POSIX portable filename character set regex: [a-zA-Z0-9._-]
|
||||
// First character cannot be hyphen (POSIX requirement)
|
||||
var posixUsernameRegex = regexp.MustCompile(`^[a-zA-Z0-9._][a-zA-Z0-9._-]*$`)
|
||||
|
||||
// validateUsername validates that a username conforms to POSIX standards with security considerations
|
||||
func validateUsername(username string) error {
|
||||
if username == "" {
|
||||
return errors.New("username cannot be empty")
|
||||
}
|
||||
|
||||
// POSIX allows up to 256 characters, but practical limit is 32 for compatibility
|
||||
if len(username) > 32 {
|
||||
return errors.New("username too long (max 32 characters)")
|
||||
}
|
||||
|
||||
if !posixUsernameRegex.MatchString(username) {
|
||||
return errors.New("username contains invalid characters (must match POSIX portable filename character set)")
|
||||
}
|
||||
|
||||
if username == "." || username == ".." {
|
||||
return fmt.Errorf("username cannot be '.' or '..'")
|
||||
}
|
||||
|
||||
// Warn if username is fully numeric (can cause issues with UID/username ambiguity)
|
||||
if isFullyNumeric(username) {
|
||||
log.Warnf("fully numeric username '%s' may cause issues with some commands", username)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// isFullyNumeric checks if username contains only digits
|
||||
func isFullyNumeric(username string) bool {
|
||||
for _, char := range username {
|
||||
if char < '0' || char > '9' {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// createPtyLoginCommand creates a Pty command using login for privileged processes
|
||||
func (s *Server) createPtyLoginCommand(localUser *user.User, ptyReq ssh.Pty, session ssh.Session) (*exec.Cmd, error) {
|
||||
loginPath, args, err := s.getLoginCmd(localUser.Username, session.RemoteAddr())
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get login command: %w", err)
|
||||
}
|
||||
|
||||
execCmd := exec.CommandContext(session.Context(), loginPath, args...)
|
||||
execCmd.Dir = localUser.HomeDir
|
||||
execCmd.Env = s.preparePtyEnv(localUser, ptyReq, session)
|
||||
|
||||
return execCmd, nil
|
||||
}
|
||||
|
||||
// getLoginCmd returns the login command and args for privileged Pty user switching
|
||||
func (s *Server) getLoginCmd(username string, remoteAddr net.Addr) (string, []string, error) {
|
||||
loginPath, err := exec.LookPath("login")
|
||||
if err != nil {
|
||||
return "", nil, fmt.Errorf("login command not available: %w", err)
|
||||
}
|
||||
|
||||
addrPort, err := netip.ParseAddrPort(remoteAddr.String())
|
||||
if err != nil {
|
||||
return "", nil, fmt.Errorf("parse remote address: %w", err)
|
||||
}
|
||||
|
||||
switch runtime.GOOS {
|
||||
case "linux":
|
||||
// Special handling for Arch Linux without /etc/pam.d/remote
|
||||
if s.fileExists("/etc/arch-release") && !s.fileExists("/etc/pam.d/remote") {
|
||||
return loginPath, []string{"-f", username, "-p"}, nil
|
||||
}
|
||||
return loginPath, []string{"-f", username, "-h", addrPort.Addr().String(), "-p"}, nil
|
||||
case "darwin", "freebsd", "openbsd", "netbsd", "dragonfly":
|
||||
return loginPath, []string{"-fp", "-h", addrPort.Addr().String(), username}, nil
|
||||
default:
|
||||
return "", nil, fmt.Errorf("unsupported Unix platform for login command: %s", runtime.GOOS)
|
||||
}
|
||||
}
|
||||
|
||||
// fileExists checks if a file exists (helper for login command logic)
|
||||
func (s *Server) fileExists(path string) bool {
|
||||
_, err := os.Stat(path)
|
||||
return err == nil
|
||||
}
|
||||
|
||||
// parseUserCredentials extracts numeric UID, GID, and supplementary groups
|
||||
func (s *Server) parseUserCredentials(localUser *user.User) (uint32, uint32, []uint32, error) {
|
||||
uid64, err := strconv.ParseUint(localUser.Uid, 10, 32)
|
||||
if err != nil {
|
||||
return 0, 0, nil, fmt.Errorf("invalid UID %s: %w", localUser.Uid, err)
|
||||
}
|
||||
uid := uint32(uid64)
|
||||
|
||||
gid64, err := strconv.ParseUint(localUser.Gid, 10, 32)
|
||||
if err != nil {
|
||||
return 0, 0, nil, fmt.Errorf("invalid GID %s: %w", localUser.Gid, err)
|
||||
}
|
||||
gid := uint32(gid64)
|
||||
|
||||
groups, err := s.getSupplementaryGroups(localUser.Username)
|
||||
if err != nil {
|
||||
log.Warnf("failed to get supplementary groups for user %s: %v", localUser.Username, err)
|
||||
groups = []uint32{gid}
|
||||
}
|
||||
|
||||
return uid, gid, groups, nil
|
||||
}
|
||||
|
||||
// getSupplementaryGroups retrieves supplementary group IDs for a user
|
||||
func (s *Server) getSupplementaryGroups(username string) ([]uint32, error) {
|
||||
u, err := user.Lookup(username)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("lookup user %s: %w", username, err)
|
||||
}
|
||||
|
||||
groupIDStrings, err := u.GroupIds()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get group IDs for user %s: %w", username, err)
|
||||
}
|
||||
|
||||
groups := make([]uint32, len(groupIDStrings))
|
||||
for i, gidStr := range groupIDStrings {
|
||||
gid64, err := strconv.ParseUint(gidStr, 10, 32)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid group ID %s for user %s: %w", gidStr, username, err)
|
||||
}
|
||||
groups[i] = uint32(gid64)
|
||||
}
|
||||
|
||||
return groups, nil
|
||||
}
|
||||
|
||||
// createExecutorCommand creates a command that spawns netbird ssh exec for privilege dropping
|
||||
func (s *Server) createExecutorCommand(session ssh.Session, localUser *user.User, hasPty bool) (*exec.Cmd, error) {
|
||||
log.Debugf("creating executor command for user %s (Pty: %v)", localUser.Username, hasPty)
|
||||
|
||||
if err := validateUsername(localUser.Username); err != nil {
|
||||
return nil, fmt.Errorf("invalid username: %w", err)
|
||||
}
|
||||
|
||||
uid, gid, groups, err := s.parseUserCredentials(localUser)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parse user credentials: %w", err)
|
||||
}
|
||||
privilegeDropper := NewPrivilegeDropper()
|
||||
config := ExecutorConfig{
|
||||
UID: uid,
|
||||
GID: gid,
|
||||
Groups: groups,
|
||||
WorkingDir: localUser.HomeDir,
|
||||
Shell: getUserShell(localUser.Uid),
|
||||
Command: session.RawCommand(),
|
||||
PTY: hasPty,
|
||||
}
|
||||
|
||||
return privilegeDropper.CreateExecutorCommand(session.Context(), config)
|
||||
}
|
||||
|
||||
// enableUserSwitching is a no-op on Unix systems
|
||||
func enableUserSwitching() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// createPtyCommand creates the exec.Cmd for Pty execution respecting privilege check results
|
||||
func (s *Server) createPtyCommand(privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, session ssh.Session) (*exec.Cmd, error) {
|
||||
localUser := privilegeResult.User
|
||||
|
||||
if privilegeResult.UsedFallback {
|
||||
return s.createDirectPtyCommand(session, localUser, ptyReq), nil
|
||||
}
|
||||
|
||||
return s.createPtyLoginCommand(localUser, ptyReq, session)
|
||||
}
|
||||
|
||||
// createDirectPtyCommand creates a direct Pty command without privilege dropping
|
||||
func (s *Server) createDirectPtyCommand(session ssh.Session, localUser *user.User, ptyReq ssh.Pty) *exec.Cmd {
|
||||
log.Debugf("creating direct Pty command for user %s (no user switching needed)", localUser.Username)
|
||||
|
||||
shell := getUserShell(localUser.Uid)
|
||||
args := s.getShellCommandArgs(shell, session.RawCommand())
|
||||
|
||||
cmd := exec.CommandContext(session.Context(), args[0], args[1:]...)
|
||||
cmd.Dir = localUser.HomeDir
|
||||
cmd.Env = s.preparePtyEnv(localUser, ptyReq, session)
|
||||
|
||||
return cmd
|
||||
}
|
||||
|
||||
// preparePtyEnv prepares environment variables for Pty execution
|
||||
func (s *Server) preparePtyEnv(localUser *user.User, ptyReq ssh.Pty, session ssh.Session) []string {
|
||||
termType := ptyReq.Term
|
||||
if termType == "" {
|
||||
termType = "xterm-256color"
|
||||
}
|
||||
|
||||
env := prepareUserEnv(localUser, getUserShell(localUser.Uid))
|
||||
env = append(env, prepareSSHEnv(session)...)
|
||||
env = append(env, fmt.Sprintf("TERM=%s", termType))
|
||||
|
||||
for _, v := range session.Environ() {
|
||||
if acceptEnv(v) {
|
||||
env = append(env, v)
|
||||
}
|
||||
}
|
||||
return env
|
||||
}
|
||||
259
client/ssh/server/userswitching_windows.go
Normal file
259
client/ssh/server/userswitching_windows.go
Normal file
@@ -0,0 +1,259 @@
|
||||
//go:build windows
|
||||
|
||||
package server
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"os/exec"
|
||||
"os/user"
|
||||
"strings"
|
||||
"unsafe"
|
||||
|
||||
"github.com/gliderlabs/ssh"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/sys/windows"
|
||||
)
|
||||
|
||||
// validateUsername validates Windows usernames according to SAM Account Name rules
|
||||
func validateUsername(username string) error {
|
||||
if username == "" {
|
||||
return fmt.Errorf("username cannot be empty")
|
||||
}
|
||||
|
||||
usernameToValidate := extractUsernameFromDomain(username)
|
||||
|
||||
if err := validateUsernameLength(usernameToValidate); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := validateUsernameCharacters(usernameToValidate); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := validateUsernameFormat(usernameToValidate); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// extractUsernameFromDomain extracts the username part from domain\username or username@domain format
|
||||
func extractUsernameFromDomain(username string) string {
|
||||
if idx := strings.LastIndex(username, `\`); idx != -1 {
|
||||
return username[idx+1:]
|
||||
}
|
||||
if idx := strings.Index(username, "@"); idx != -1 {
|
||||
return username[:idx]
|
||||
}
|
||||
return username
|
||||
}
|
||||
|
||||
// validateUsernameLength checks if username length is within Windows limits
|
||||
func validateUsernameLength(username string) error {
|
||||
if len(username) > 20 {
|
||||
return fmt.Errorf("username too long (max 20 characters for Windows)")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// validateUsernameCharacters checks for invalid characters in Windows usernames
|
||||
func validateUsernameCharacters(username string) error {
|
||||
invalidChars := []rune{'"', '/', '[', ']', ':', ';', '|', '=', ',', '+', '*', '?', '<', '>', ' ', '`', '&', '\n'}
|
||||
for _, char := range username {
|
||||
for _, invalid := range invalidChars {
|
||||
if char == invalid {
|
||||
return fmt.Errorf("username contains invalid characters")
|
||||
}
|
||||
}
|
||||
if char < 32 || char == 127 {
|
||||
return fmt.Errorf("username contains control characters")
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// validateUsernameFormat checks for invalid username formats and patterns
|
||||
func validateUsernameFormat(username string) error {
|
||||
if username == "." || username == ".." {
|
||||
return fmt.Errorf("username cannot be '.' or '..'")
|
||||
}
|
||||
|
||||
if strings.HasSuffix(username, ".") {
|
||||
return fmt.Errorf("username cannot end with a period")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// createExecutorCommand creates a command using Windows executor for privilege dropping
|
||||
func (s *Server) createExecutorCommand(session ssh.Session, localUser *user.User, hasPty bool) (*exec.Cmd, error) {
|
||||
log.Debugf("creating Windows executor command for user %s (Pty: %v)", localUser.Username, hasPty)
|
||||
|
||||
username, _ := s.parseUsername(localUser.Username)
|
||||
if err := validateUsername(username); err != nil {
|
||||
return nil, fmt.Errorf("invalid username: %w", err)
|
||||
}
|
||||
|
||||
return s.createUserSwitchCommand(localUser, session, hasPty)
|
||||
}
|
||||
|
||||
// createUserSwitchCommand creates a command with Windows user switching
|
||||
func (s *Server) createUserSwitchCommand(localUser *user.User, session ssh.Session, interactive bool) (*exec.Cmd, error) {
|
||||
username, domain := s.parseUsername(localUser.Username)
|
||||
|
||||
shell := getUserShell(localUser.Uid)
|
||||
|
||||
rawCmd := session.RawCommand()
|
||||
var command string
|
||||
if rawCmd != "" {
|
||||
command = rawCmd
|
||||
}
|
||||
|
||||
config := WindowsExecutorConfig{
|
||||
Username: username,
|
||||
Domain: domain,
|
||||
WorkingDir: localUser.HomeDir,
|
||||
Shell: shell,
|
||||
Command: command,
|
||||
Interactive: interactive || (rawCmd == ""),
|
||||
}
|
||||
|
||||
dropper := NewPrivilegeDropper()
|
||||
return dropper.CreateWindowsExecutorCommand(session.Context(), config)
|
||||
}
|
||||
|
||||
// parseUsername extracts username and domain from a Windows username
|
||||
func (s *Server) parseUsername(fullUsername string) (username, domain string) {
|
||||
// Handle DOMAIN\username format
|
||||
if idx := strings.LastIndex(fullUsername, `\`); idx != -1 {
|
||||
domain = fullUsername[:idx]
|
||||
username = fullUsername[idx+1:]
|
||||
return username, domain
|
||||
}
|
||||
|
||||
// Handle username@domain format
|
||||
if username, domain, ok := strings.Cut(fullUsername, "@"); ok {
|
||||
return username, domain
|
||||
}
|
||||
|
||||
// Local user (no domain)
|
||||
return fullUsername, "."
|
||||
}
|
||||
|
||||
// hasPrivilege checks if the current process has a specific privilege
|
||||
func hasPrivilege(token windows.Handle, privilegeName string) (bool, error) {
|
||||
var luid windows.LUID
|
||||
if err := windows.LookupPrivilegeValue(nil, windows.StringToUTF16Ptr(privilegeName), &luid); err != nil {
|
||||
return false, fmt.Errorf("lookup privilege value: %w", err)
|
||||
}
|
||||
|
||||
var returnLength uint32
|
||||
err := windows.GetTokenInformation(
|
||||
windows.Token(token),
|
||||
windows.TokenPrivileges,
|
||||
nil, // null buffer to get size
|
||||
0,
|
||||
&returnLength,
|
||||
)
|
||||
|
||||
if err != nil && !errors.Is(err, windows.ERROR_INSUFFICIENT_BUFFER) {
|
||||
return false, fmt.Errorf("get token information size: %w", err)
|
||||
}
|
||||
|
||||
buffer := make([]byte, returnLength)
|
||||
err = windows.GetTokenInformation(
|
||||
windows.Token(token),
|
||||
windows.TokenPrivileges,
|
||||
&buffer[0],
|
||||
returnLength,
|
||||
&returnLength,
|
||||
)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("get token information: %w", err)
|
||||
}
|
||||
|
||||
privileges := (*windows.Tokenprivileges)(unsafe.Pointer(&buffer[0]))
|
||||
|
||||
// Check if the privilege is present and enabled
|
||||
for i := uint32(0); i < privileges.PrivilegeCount; i++ {
|
||||
privilege := (*windows.LUIDAndAttributes)(unsafe.Pointer(
|
||||
uintptr(unsafe.Pointer(&privileges.Privileges[0])) +
|
||||
uintptr(i)*unsafe.Sizeof(windows.LUIDAndAttributes{}),
|
||||
))
|
||||
if privilege.Luid == luid {
|
||||
return (privilege.Attributes & windows.SE_PRIVILEGE_ENABLED) != 0, nil
|
||||
}
|
||||
}
|
||||
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// enablePrivilege enables a specific privilege for the current process token
|
||||
// This is required because privileges like SeAssignPrimaryTokenPrivilege are present
|
||||
// but disabled by default, even for the SYSTEM account
|
||||
func enablePrivilege(token windows.Handle, privilegeName string) error {
|
||||
var luid windows.LUID
|
||||
if err := windows.LookupPrivilegeValue(nil, windows.StringToUTF16Ptr(privilegeName), &luid); err != nil {
|
||||
return fmt.Errorf("lookup privilege value for %s: %w", privilegeName, err)
|
||||
}
|
||||
|
||||
privileges := windows.Tokenprivileges{
|
||||
PrivilegeCount: 1,
|
||||
Privileges: [1]windows.LUIDAndAttributes{
|
||||
{
|
||||
Luid: luid,
|
||||
Attributes: windows.SE_PRIVILEGE_ENABLED,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
err := windows.AdjustTokenPrivileges(
|
||||
windows.Token(token),
|
||||
false,
|
||||
&privileges,
|
||||
0,
|
||||
nil,
|
||||
nil,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("adjust token privileges for %s: %w", privilegeName, err)
|
||||
}
|
||||
|
||||
hasPriv, err := hasPrivilege(token, privilegeName)
|
||||
if err != nil {
|
||||
return fmt.Errorf("verify privilege %s after enabling: %w", privilegeName, err)
|
||||
}
|
||||
if !hasPriv {
|
||||
return fmt.Errorf("privilege %s could not be enabled (may not be granted to account)", privilegeName)
|
||||
}
|
||||
|
||||
log.Debugf("Successfully enabled privilege %s for current process", privilegeName)
|
||||
return nil
|
||||
}
|
||||
|
||||
// enableUserSwitching enables required privileges for Windows user switching
|
||||
func enableUserSwitching() error {
|
||||
process := windows.CurrentProcess()
|
||||
|
||||
var token windows.Token
|
||||
err := windows.OpenProcessToken(
|
||||
process,
|
||||
windows.TOKEN_ADJUST_PRIVILEGES|windows.TOKEN_QUERY,
|
||||
&token,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("open process token: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
if err := windows.CloseHandle(windows.Handle(token)); err != nil {
|
||||
log.Debugf("Failed to close process token: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
if err := enablePrivilege(windows.Handle(token), "SeAssignPrimaryTokenPrivilege"); err != nil {
|
||||
return fmt.Errorf("enable SeAssignPrimaryTokenPrivilege: %w", err)
|
||||
}
|
||||
log.Infof("Windows user switching privileges enabled successfully")
|
||||
return nil
|
||||
}
|
||||
473
client/ssh/server/winpty/conpty.go
Normal file
473
client/ssh/server/winpty/conpty.go
Normal file
@@ -0,0 +1,473 @@
|
||||
//go:build windows
|
||||
|
||||
package winpty
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"strings"
|
||||
"sync"
|
||||
"syscall"
|
||||
"unsafe"
|
||||
|
||||
"github.com/gliderlabs/ssh"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/sys/windows"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrEmptyEnvironment = errors.New("empty environment")
|
||||
)
|
||||
|
||||
const (
|
||||
extendedStartupInfoPresent = 0x00080000
|
||||
createUnicodeEnvironment = 0x00000400
|
||||
procThreadAttributePseudoConsole = 0x00020016
|
||||
|
||||
PowerShellCommandFlag = "-Command"
|
||||
|
||||
errCloseInputRead = "close input read handle: %v"
|
||||
errCloseConPtyCleanup = "close ConPty handle during cleanup"
|
||||
)
|
||||
|
||||
// PtyConfig holds configuration for Pty execution.
|
||||
type PtyConfig struct {
|
||||
Shell string
|
||||
Command string
|
||||
Width int
|
||||
Height int
|
||||
WorkingDir string
|
||||
}
|
||||
|
||||
// UserConfig holds user execution configuration.
|
||||
type UserConfig struct {
|
||||
Token windows.Handle
|
||||
Environment []string
|
||||
}
|
||||
|
||||
var (
|
||||
kernel32 = windows.NewLazySystemDLL("kernel32.dll")
|
||||
procClosePseudoConsole = kernel32.NewProc("ClosePseudoConsole")
|
||||
procInitializeProcThreadAttributeList = kernel32.NewProc("InitializeProcThreadAttributeList")
|
||||
procUpdateProcThreadAttribute = kernel32.NewProc("UpdateProcThreadAttribute")
|
||||
procDeleteProcThreadAttributeList = kernel32.NewProc("DeleteProcThreadAttributeList")
|
||||
)
|
||||
|
||||
// ExecutePtyWithUserToken executes a command with ConPty using user token.
|
||||
func ExecutePtyWithUserToken(ctx context.Context, session ssh.Session, ptyConfig PtyConfig, userConfig UserConfig) error {
|
||||
args := buildShellArgs(ptyConfig.Shell, ptyConfig.Command)
|
||||
commandLine := buildCommandLine(args)
|
||||
|
||||
config := ExecutionConfig{
|
||||
Pty: ptyConfig,
|
||||
User: userConfig,
|
||||
Session: session,
|
||||
Context: ctx,
|
||||
}
|
||||
|
||||
return executeConPtyWithConfig(commandLine, config)
|
||||
}
|
||||
|
||||
// ExecutionConfig holds all execution configuration.
|
||||
type ExecutionConfig struct {
|
||||
Pty PtyConfig
|
||||
User UserConfig
|
||||
Session ssh.Session
|
||||
Context context.Context
|
||||
}
|
||||
|
||||
// executeConPtyWithConfig creates ConPty and executes process with configuration.
|
||||
func executeConPtyWithConfig(commandLine string, config ExecutionConfig) error {
|
||||
ctx := config.Context
|
||||
session := config.Session
|
||||
width := config.Pty.Width
|
||||
height := config.Pty.Height
|
||||
userToken := config.User.Token
|
||||
userEnv := config.User.Environment
|
||||
workingDir := config.Pty.WorkingDir
|
||||
|
||||
inputRead, inputWrite, outputRead, outputWrite, err := createConPtyPipes()
|
||||
if err != nil {
|
||||
return fmt.Errorf("create ConPty pipes: %w", err)
|
||||
}
|
||||
|
||||
hPty, err := createConPty(width, height, inputRead, outputWrite)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create ConPty: %w", err)
|
||||
}
|
||||
|
||||
primaryToken, err := duplicateToPrimaryToken(userToken)
|
||||
if err != nil {
|
||||
if closeErr, _, _ := procClosePseudoConsole.Call(uintptr(hPty)); closeErr == 0 {
|
||||
log.Debugf(errCloseConPtyCleanup)
|
||||
}
|
||||
closeHandles(inputRead, inputWrite, outputRead, outputWrite)
|
||||
return fmt.Errorf("duplicate to primary token: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
if err := windows.CloseHandle(primaryToken); err != nil {
|
||||
log.Debugf("close primary token: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
siEx, err := setupConPtyStartupInfo(hPty)
|
||||
if err != nil {
|
||||
if closeErr, _, _ := procClosePseudoConsole.Call(uintptr(hPty)); closeErr == 0 {
|
||||
log.Debugf(errCloseConPtyCleanup)
|
||||
}
|
||||
closeHandles(inputRead, inputWrite, outputRead, outputWrite)
|
||||
return fmt.Errorf("setup startup info: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
_, _, _ = procDeleteProcThreadAttributeList.Call(uintptr(unsafe.Pointer(siEx.ProcThreadAttributeList)))
|
||||
}()
|
||||
|
||||
pi, err := createConPtyProcess(commandLine, primaryToken, userEnv, workingDir, siEx)
|
||||
if err != nil {
|
||||
if closeErr, _, _ := procClosePseudoConsole.Call(uintptr(hPty)); closeErr == 0 {
|
||||
log.Debugf(errCloseConPtyCleanup)
|
||||
}
|
||||
closeHandles(inputRead, inputWrite, outputRead, outputWrite)
|
||||
return fmt.Errorf("create process as user with ConPty: %w", err)
|
||||
}
|
||||
defer closeProcessInfo(pi)
|
||||
|
||||
if err := windows.CloseHandle(inputRead); err != nil {
|
||||
log.Debugf(errCloseInputRead, err)
|
||||
}
|
||||
if err := windows.CloseHandle(outputWrite); err != nil {
|
||||
log.Debugf("close output write handle: %v", err)
|
||||
}
|
||||
|
||||
return bridgeConPtyIO(ctx, hPty, inputWrite, outputRead, session, session, pi.Process)
|
||||
}
|
||||
|
||||
// createConPtyPipes creates input/output pipes for ConPty.
|
||||
func createConPtyPipes() (inputRead, inputWrite, outputRead, outputWrite windows.Handle, err error) {
|
||||
if err := windows.CreatePipe(&inputRead, &inputWrite, nil, 0); err != nil {
|
||||
return 0, 0, 0, 0, fmt.Errorf("create input pipe: %w", err)
|
||||
}
|
||||
|
||||
if err := windows.CreatePipe(&outputRead, &outputWrite, nil, 0); err != nil {
|
||||
if closeErr := windows.CloseHandle(inputRead); closeErr != nil {
|
||||
log.Debugf(errCloseInputRead, closeErr)
|
||||
}
|
||||
if closeErr := windows.CloseHandle(inputWrite); closeErr != nil {
|
||||
log.Debugf("close input write handle: %v", closeErr)
|
||||
}
|
||||
return 0, 0, 0, 0, fmt.Errorf("create output pipe: %w", err)
|
||||
}
|
||||
|
||||
return inputRead, inputWrite, outputRead, outputWrite, nil
|
||||
}
|
||||
|
||||
// createConPty creates a Windows ConPty with the specified size and pipe handles.
|
||||
func createConPty(width, height int, inputRead, outputWrite windows.Handle) (windows.Handle, error) {
|
||||
size := windows.Coord{X: int16(width), Y: int16(height)}
|
||||
|
||||
var hPty windows.Handle
|
||||
if err := windows.CreatePseudoConsole(size, inputRead, outputWrite, 0, &hPty); err != nil {
|
||||
return 0, fmt.Errorf("CreatePseudoConsole: %w", err)
|
||||
}
|
||||
|
||||
return hPty, nil
|
||||
}
|
||||
|
||||
// setupConPtyStartupInfo prepares the STARTUPINFOEX with ConPty attributes.
|
||||
func setupConPtyStartupInfo(hPty windows.Handle) (*windows.StartupInfoEx, error) {
|
||||
var siEx windows.StartupInfoEx
|
||||
siEx.StartupInfo.Cb = uint32(unsafe.Sizeof(siEx))
|
||||
|
||||
var attrListSize uintptr
|
||||
ret, _, _ := procInitializeProcThreadAttributeList.Call(0, 1, 0, uintptr(unsafe.Pointer(&attrListSize)))
|
||||
if ret == 0 && attrListSize == 0 {
|
||||
return nil, fmt.Errorf("get attribute list size")
|
||||
}
|
||||
|
||||
attrListBytes := make([]byte, attrListSize)
|
||||
siEx.ProcThreadAttributeList = (*windows.ProcThreadAttributeList)(unsafe.Pointer(&attrListBytes[0]))
|
||||
|
||||
ret, _, err := procInitializeProcThreadAttributeList.Call(
|
||||
uintptr(unsafe.Pointer(siEx.ProcThreadAttributeList)),
|
||||
1,
|
||||
0,
|
||||
uintptr(unsafe.Pointer(&attrListSize)),
|
||||
)
|
||||
if ret == 0 {
|
||||
return nil, fmt.Errorf("initialize attribute list: %w", err)
|
||||
}
|
||||
|
||||
ret, _, err = procUpdateProcThreadAttribute.Call(
|
||||
uintptr(unsafe.Pointer(siEx.ProcThreadAttributeList)),
|
||||
0,
|
||||
procThreadAttributePseudoConsole,
|
||||
uintptr(hPty),
|
||||
unsafe.Sizeof(hPty),
|
||||
0,
|
||||
0,
|
||||
)
|
||||
if ret == 0 {
|
||||
return nil, fmt.Errorf("update thread attribute: %w", err)
|
||||
}
|
||||
|
||||
return &siEx, nil
|
||||
}
|
||||
|
||||
// createConPtyProcess creates the actual process with ConPty.
|
||||
func createConPtyProcess(commandLine string, userToken windows.Handle, userEnv []string, workingDir string, siEx *windows.StartupInfoEx) (*windows.ProcessInformation, error) {
|
||||
var pi windows.ProcessInformation
|
||||
creationFlags := uint32(extendedStartupInfoPresent | createUnicodeEnvironment)
|
||||
|
||||
commandLinePtr, err := windows.UTF16PtrFromString(commandLine)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("convert command line to UTF16: %w", err)
|
||||
}
|
||||
|
||||
envPtr, err := convertEnvironmentToUTF16(userEnv)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var workingDirPtr *uint16
|
||||
if workingDir != "" {
|
||||
workingDirPtr, err = windows.UTF16PtrFromString(workingDir)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("convert working directory to UTF16: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
siEx.StartupInfo.Flags |= windows.STARTF_USESTDHANDLES
|
||||
siEx.StartupInfo.StdInput = windows.Handle(0)
|
||||
siEx.StartupInfo.StdOutput = windows.Handle(0)
|
||||
siEx.StartupInfo.StdErr = siEx.StartupInfo.StdOutput
|
||||
|
||||
if userToken != windows.InvalidHandle {
|
||||
err = windows.CreateProcessAsUser(
|
||||
windows.Token(userToken),
|
||||
nil,
|
||||
commandLinePtr,
|
||||
nil,
|
||||
nil,
|
||||
true,
|
||||
creationFlags,
|
||||
envPtr,
|
||||
workingDirPtr,
|
||||
&siEx.StartupInfo,
|
||||
&pi,
|
||||
)
|
||||
} else {
|
||||
err = windows.CreateProcess(
|
||||
nil,
|
||||
commandLinePtr,
|
||||
nil,
|
||||
nil,
|
||||
true,
|
||||
creationFlags,
|
||||
envPtr,
|
||||
workingDirPtr,
|
||||
&siEx.StartupInfo,
|
||||
&pi,
|
||||
)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create process: %w", err)
|
||||
}
|
||||
|
||||
return &pi, nil
|
||||
}
|
||||
|
||||
// convertEnvironmentToUTF16 converts environment variables to Windows UTF16 format.
|
||||
func convertEnvironmentToUTF16(userEnv []string) (*uint16, error) {
|
||||
if len(userEnv) == 0 {
|
||||
// Return nil pointer for empty environment - Windows API will inherit parent environment
|
||||
return nil, nil //nolint:nilnil // Intentional nil,nil for empty environment
|
||||
}
|
||||
|
||||
var envUTF16 []uint16
|
||||
for _, envVar := range userEnv {
|
||||
if envVar != "" {
|
||||
utf16Str, err := windows.UTF16FromString(envVar)
|
||||
if err != nil {
|
||||
log.Debugf("skipping invalid environment variable: %s (error: %v)", envVar, err)
|
||||
continue
|
||||
}
|
||||
envUTF16 = append(envUTF16, utf16Str[:len(utf16Str)-1]...)
|
||||
envUTF16 = append(envUTF16, 0)
|
||||
}
|
||||
}
|
||||
envUTF16 = append(envUTF16, 0)
|
||||
|
||||
if len(envUTF16) > 0 {
|
||||
return &envUTF16[0], nil
|
||||
}
|
||||
// Return nil pointer when no valid environment variables found
|
||||
return nil, nil //nolint:nilnil // Intentional nil,nil for empty environment
|
||||
}
|
||||
|
||||
// duplicateToPrimaryToken converts an impersonation token to a primary token.
|
||||
func duplicateToPrimaryToken(token windows.Handle) (windows.Handle, error) {
|
||||
var primaryToken windows.Handle
|
||||
if err := windows.DuplicateTokenEx(
|
||||
windows.Token(token),
|
||||
windows.TOKEN_ALL_ACCESS,
|
||||
nil,
|
||||
windows.SecurityImpersonation,
|
||||
windows.TokenPrimary,
|
||||
(*windows.Token)(&primaryToken),
|
||||
); err != nil {
|
||||
return 0, fmt.Errorf("duplicate token: %w", err)
|
||||
}
|
||||
return primaryToken, nil
|
||||
}
|
||||
|
||||
// bridgeConPtyIO handles I/O bridging between ConPty and readers/writers.
|
||||
func bridgeConPtyIO(ctx context.Context, hPty, inputWrite, outputRead windows.Handle, reader io.ReadCloser, writer io.Writer, process windows.Handle) error {
|
||||
if err := ctx.Err(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
startIOBridging(ctx, &wg, inputWrite, outputRead, reader, writer)
|
||||
|
||||
processErr := waitForProcess(ctx, process)
|
||||
if processErr != nil {
|
||||
return processErr
|
||||
}
|
||||
|
||||
// Clean up in the original order after process completes
|
||||
if err := reader.Close(); err != nil {
|
||||
log.Debugf("close reader: %v", err)
|
||||
}
|
||||
|
||||
ret, _, err := procClosePseudoConsole.Call(uintptr(hPty))
|
||||
if ret == 0 {
|
||||
log.Debugf("close ConPty handle: %v", err)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
if err := windows.CloseHandle(outputRead); err != nil {
|
||||
log.Debugf("close output read handle: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// startIOBridging starts the I/O bridging goroutines.
|
||||
func startIOBridging(ctx context.Context, wg *sync.WaitGroup, inputWrite, outputRead windows.Handle, reader io.ReadCloser, writer io.Writer) {
|
||||
wg.Add(2)
|
||||
|
||||
// Input: reader (SSH session) -> inputWrite (ConPty)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
defer func() {
|
||||
if err := windows.CloseHandle(inputWrite); err != nil {
|
||||
log.Debugf("close input write handle in goroutine: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
if _, err := io.Copy(&windowsHandleWriter{handle: inputWrite}, reader); err != nil {
|
||||
log.Debugf("input copy ended with error: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
// Output: outputRead (ConPty) -> writer (SSH session)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
if _, err := io.Copy(writer, &windowsHandleReader{handle: outputRead}); err != nil {
|
||||
log.Debugf("output copy ended with error: %v", err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// waitForProcess waits for process completion with context cancellation.
|
||||
func waitForProcess(ctx context.Context, process windows.Handle) error {
|
||||
if _, err := windows.WaitForSingleObject(process, windows.INFINITE); err != nil {
|
||||
return fmt.Errorf("wait for process %d: %w", process, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// buildShellArgs builds shell arguments for ConPty execution.
|
||||
func buildShellArgs(shell, command string) []string {
|
||||
if command != "" {
|
||||
return []string{shell, PowerShellCommandFlag, command}
|
||||
}
|
||||
return []string{shell}
|
||||
}
|
||||
|
||||
// buildCommandLine builds a Windows command line from arguments using proper escaping.
|
||||
func buildCommandLine(args []string) string {
|
||||
if len(args) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
var result strings.Builder
|
||||
for i, arg := range args {
|
||||
if i > 0 {
|
||||
result.WriteString(" ")
|
||||
}
|
||||
result.WriteString(syscall.EscapeArg(arg))
|
||||
}
|
||||
return result.String()
|
||||
}
|
||||
|
||||
// closeHandles closes multiple Windows handles.
|
||||
func closeHandles(handles ...windows.Handle) {
|
||||
for _, handle := range handles {
|
||||
if handle != windows.InvalidHandle {
|
||||
if err := windows.CloseHandle(handle); err != nil {
|
||||
log.Debugf("close handle: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// closeProcessInfo closes process and thread handles.
|
||||
func closeProcessInfo(pi *windows.ProcessInformation) {
|
||||
if pi != nil {
|
||||
if err := windows.CloseHandle(pi.Process); err != nil {
|
||||
log.Debugf("close process handle: %v", err)
|
||||
}
|
||||
if err := windows.CloseHandle(pi.Thread); err != nil {
|
||||
log.Debugf("close thread handle: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// windowsHandleReader wraps a Windows handle for reading.
|
||||
type windowsHandleReader struct {
|
||||
handle windows.Handle
|
||||
}
|
||||
|
||||
func (r *windowsHandleReader) Read(p []byte) (n int, err error) {
|
||||
var bytesRead uint32
|
||||
if err := windows.ReadFile(r.handle, p, &bytesRead, nil); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return int(bytesRead), nil
|
||||
}
|
||||
|
||||
func (r *windowsHandleReader) Close() error {
|
||||
return windows.CloseHandle(r.handle)
|
||||
}
|
||||
|
||||
// windowsHandleWriter wraps a Windows handle for writing.
|
||||
type windowsHandleWriter struct {
|
||||
handle windows.Handle
|
||||
}
|
||||
|
||||
func (w *windowsHandleWriter) Write(p []byte) (n int, err error) {
|
||||
var bytesWritten uint32
|
||||
if err := windows.WriteFile(w.handle, p, &bytesWritten, nil); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return int(bytesWritten), nil
|
||||
}
|
||||
|
||||
func (w *windowsHandleWriter) Close() error {
|
||||
return windows.CloseHandle(w.handle)
|
||||
}
|
||||
289
client/ssh/server/winpty/conpty_test.go
Normal file
289
client/ssh/server/winpty/conpty_test.go
Normal file
@@ -0,0 +1,289 @@
|
||||
//go:build windows
|
||||
|
||||
package winpty
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/sys/windows"
|
||||
)
|
||||
|
||||
func TestBuildShellArgs(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
shell string
|
||||
command string
|
||||
expected []string
|
||||
}{
|
||||
{
|
||||
name: "Shell with command",
|
||||
shell: "powershell.exe",
|
||||
command: "Get-Process",
|
||||
expected: []string{"powershell.exe", "-Command", "Get-Process"},
|
||||
},
|
||||
{
|
||||
name: "CMD with command",
|
||||
shell: "cmd.exe",
|
||||
command: "dir",
|
||||
expected: []string{"cmd.exe", "-Command", "dir"},
|
||||
},
|
||||
{
|
||||
name: "Shell interactive",
|
||||
shell: "powershell.exe",
|
||||
command: "",
|
||||
expected: []string{"powershell.exe"},
|
||||
},
|
||||
{
|
||||
name: "CMD interactive",
|
||||
shell: "cmd.exe",
|
||||
command: "",
|
||||
expected: []string{"cmd.exe"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := buildShellArgs(tt.shell, tt.command)
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildCommandLine(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
args []string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "Simple args",
|
||||
args: []string{"cmd.exe", "/c", "echo"},
|
||||
expected: "cmd.exe /c echo",
|
||||
},
|
||||
{
|
||||
name: "Args with spaces",
|
||||
args: []string{"Program Files\\app.exe", "arg with spaces"},
|
||||
expected: `"Program Files\app.exe" "arg with spaces"`,
|
||||
},
|
||||
{
|
||||
name: "Args with quotes",
|
||||
args: []string{"cmd.exe", "/c", `echo "hello world"`},
|
||||
expected: `cmd.exe /c "echo \"hello world\""`,
|
||||
},
|
||||
{
|
||||
name: "PowerShell calling PowerShell",
|
||||
args: []string{"powershell.exe", "-Command", `powershell.exe -Command "Get-Process | Where-Object {$_.Name -eq 'notepad'}"`},
|
||||
expected: `powershell.exe -Command "powershell.exe -Command \"Get-Process | Where-Object {$_.Name -eq 'notepad'}\""`,
|
||||
},
|
||||
{
|
||||
name: "Complex nested quotes",
|
||||
args: []string{"cmd.exe", "/c", `echo "He said \"Hello\" to me"`},
|
||||
expected: `cmd.exe /c "echo \"He said \\\"Hello\\\" to me\""`,
|
||||
},
|
||||
{
|
||||
name: "Path with spaces and args",
|
||||
args: []string{`C:\Program Files\MyApp\app.exe`, "--config", `C:\My Config\settings.json`},
|
||||
expected: `"C:\Program Files\MyApp\app.exe" --config "C:\My Config\settings.json"`,
|
||||
},
|
||||
{
|
||||
name: "Empty argument",
|
||||
args: []string{"cmd.exe", "/c", "echo", ""},
|
||||
expected: `cmd.exe /c echo ""`,
|
||||
},
|
||||
{
|
||||
name: "Argument with backslashes",
|
||||
args: []string{"robocopy", `C:\Source\`, `C:\Dest\`, "/E"},
|
||||
expected: `robocopy C:\Source\ C:\Dest\ /E`,
|
||||
},
|
||||
{
|
||||
name: "Empty args",
|
||||
args: []string{},
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "Single arg with space",
|
||||
args: []string{"path with spaces"},
|
||||
expected: `"path with spaces"`,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := buildCommandLine(tt.args)
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateConPtyPipes(t *testing.T) {
|
||||
inputRead, inputWrite, outputRead, outputWrite, err := createConPtyPipes()
|
||||
require.NoError(t, err, "Should create ConPty pipes successfully")
|
||||
|
||||
// Verify all handles are valid
|
||||
assert.NotEqual(t, windows.InvalidHandle, inputRead, "Input read handle should be valid")
|
||||
assert.NotEqual(t, windows.InvalidHandle, inputWrite, "Input write handle should be valid")
|
||||
assert.NotEqual(t, windows.InvalidHandle, outputRead, "Output read handle should be valid")
|
||||
assert.NotEqual(t, windows.InvalidHandle, outputWrite, "Output write handle should be valid")
|
||||
|
||||
// Clean up handles
|
||||
closeHandles(inputRead, inputWrite, outputRead, outputWrite)
|
||||
}
|
||||
|
||||
func TestCreateConPty(t *testing.T) {
|
||||
inputRead, inputWrite, outputRead, outputWrite, err := createConPtyPipes()
|
||||
require.NoError(t, err, "Should create ConPty pipes successfully")
|
||||
defer closeHandles(inputRead, inputWrite, outputRead, outputWrite)
|
||||
|
||||
hPty, err := createConPty(80, 24, inputRead, outputWrite)
|
||||
require.NoError(t, err, "Should create ConPty successfully")
|
||||
assert.NotEqual(t, windows.InvalidHandle, hPty, "ConPty handle should be valid")
|
||||
|
||||
// Clean up ConPty
|
||||
ret, _, _ := procClosePseudoConsole.Call(uintptr(hPty))
|
||||
assert.NotEqual(t, uintptr(0), ret, "Should close ConPty successfully")
|
||||
}
|
||||
|
||||
func TestConvertEnvironmentToUTF16(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
userEnv []string
|
||||
hasError bool
|
||||
}{
|
||||
{
|
||||
name: "Valid environment variables",
|
||||
userEnv: []string{"PATH=C:\\Windows", "USER=testuser", "HOME=C:\\Users\\testuser"},
|
||||
hasError: false,
|
||||
},
|
||||
{
|
||||
name: "Empty environment",
|
||||
userEnv: []string{},
|
||||
hasError: false,
|
||||
},
|
||||
{
|
||||
name: "Environment with empty strings",
|
||||
userEnv: []string{"PATH=C:\\Windows", "", "USER=testuser"},
|
||||
hasError: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result, err := convertEnvironmentToUTF16(tt.userEnv)
|
||||
if tt.hasError {
|
||||
assert.Error(t, err)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
if len(tt.userEnv) == 0 {
|
||||
assert.Nil(t, result, "Empty environment should return nil")
|
||||
} else {
|
||||
assert.NotNil(t, result, "Non-empty environment should return valid pointer")
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDuplicateToPrimaryToken(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("Skipping token tests in short mode")
|
||||
}
|
||||
|
||||
// Get current process token for testing
|
||||
var token windows.Token
|
||||
err := windows.OpenProcessToken(windows.CurrentProcess(), windows.TOKEN_ALL_ACCESS, &token)
|
||||
require.NoError(t, err, "Should open current process token")
|
||||
defer func() {
|
||||
if err := windows.CloseHandle(windows.Handle(token)); err != nil {
|
||||
t.Logf("Failed to close token: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
primaryToken, err := duplicateToPrimaryToken(windows.Handle(token))
|
||||
require.NoError(t, err, "Should duplicate token to primary")
|
||||
assert.NotEqual(t, windows.InvalidHandle, primaryToken, "Primary token should be valid")
|
||||
|
||||
// Clean up
|
||||
err = windows.CloseHandle(primaryToken)
|
||||
assert.NoError(t, err, "Should close primary token")
|
||||
}
|
||||
|
||||
func TestWindowsHandleReader(t *testing.T) {
|
||||
// Create a pipe for testing
|
||||
var readHandle, writeHandle windows.Handle
|
||||
err := windows.CreatePipe(&readHandle, &writeHandle, nil, 0)
|
||||
require.NoError(t, err, "Should create pipe for testing")
|
||||
defer closeHandles(readHandle, writeHandle)
|
||||
|
||||
// Write test data
|
||||
testData := []byte("Hello, Windows Handle Reader!")
|
||||
var bytesWritten uint32
|
||||
err = windows.WriteFile(writeHandle, testData, &bytesWritten, nil)
|
||||
require.NoError(t, err, "Should write test data")
|
||||
require.Equal(t, uint32(len(testData)), bytesWritten, "Should write all test data")
|
||||
|
||||
// Close write handle to signal EOF
|
||||
if err := windows.CloseHandle(writeHandle); err != nil {
|
||||
t.Fatalf("Should close write handle: %v", err)
|
||||
}
|
||||
|
||||
// Test reading
|
||||
reader := &windowsHandleReader{handle: readHandle}
|
||||
buffer := make([]byte, len(testData))
|
||||
n, err := reader.Read(buffer)
|
||||
require.NoError(t, err, "Should read from handle")
|
||||
assert.Equal(t, len(testData), n, "Should read expected number of bytes")
|
||||
assert.Equal(t, testData, buffer, "Should read expected data")
|
||||
}
|
||||
|
||||
func TestWindowsHandleWriter(t *testing.T) {
|
||||
// Create a pipe for testing
|
||||
var readHandle, writeHandle windows.Handle
|
||||
err := windows.CreatePipe(&readHandle, &writeHandle, nil, 0)
|
||||
require.NoError(t, err, "Should create pipe for testing")
|
||||
defer closeHandles(readHandle, writeHandle)
|
||||
|
||||
// Test writing
|
||||
testData := []byte("Hello, Windows Handle Writer!")
|
||||
writer := &windowsHandleWriter{handle: writeHandle}
|
||||
n, err := writer.Write(testData)
|
||||
require.NoError(t, err, "Should write to handle")
|
||||
assert.Equal(t, len(testData), n, "Should write expected number of bytes")
|
||||
|
||||
// Close write handle
|
||||
if err := windows.CloseHandle(writeHandle); err != nil {
|
||||
t.Fatalf("Should close write handle: %v", err)
|
||||
}
|
||||
|
||||
// Verify data was written by reading it back
|
||||
buffer := make([]byte, len(testData))
|
||||
var bytesRead uint32
|
||||
err = windows.ReadFile(readHandle, buffer, &bytesRead, nil)
|
||||
require.NoError(t, err, "Should read back written data")
|
||||
assert.Equal(t, uint32(len(testData)), bytesRead, "Should read back expected number of bytes")
|
||||
assert.Equal(t, testData, buffer, "Should read back expected data")
|
||||
}
|
||||
|
||||
// BenchmarkConPtyCreation benchmarks ConPty creation performance
|
||||
func BenchmarkConPtyCreation(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
inputRead, inputWrite, outputRead, outputWrite, err := createConPtyPipes()
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
|
||||
hPty, err := createConPty(80, 24, inputRead, outputWrite)
|
||||
if err != nil {
|
||||
closeHandles(inputRead, inputWrite, outputRead, outputWrite)
|
||||
b.Fatal(err)
|
||||
}
|
||||
|
||||
// Clean up
|
||||
if ret, _, err := procClosePseudoConsole.Call(uintptr(hPty)); ret == 0 {
|
||||
log.Debugf("ClosePseudoConsole failed: %v", err)
|
||||
}
|
||||
closeHandles(inputRead, inputWrite, outputRead, outputWrite)
|
||||
}
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user