Files
netbird/client/ssh/server/shell.go

185 lines
4.4 KiB
Go

package server
import (
"bufio"
"fmt"
"net"
"os"
"os/exec"
"os/user"
"runtime"
"strconv"
"strings"
"github.com/gliderlabs/ssh"
log "github.com/sirupsen/logrus"
)
// Shell executables considered when choosing the shell for a session.
const (
	// powershellExe is Windows PowerShell, tried when pwshExe cannot be found.
	powershellExe = "powershell.exe"

	// pwshExe is PowerShell (Core), the first choice on Windows.
	pwshExe = "pwsh.exe" // #nosec G101 - This is not a credential, just executable name

	// defaultUnixShell is the last-resort shell on Unix-like systems when no
	// shell could be determined for the user.
	defaultUnixShell = "/bin/sh"
)
// getUserShell resolves the shell to run for the user identified by userID,
// delegating to the platform-specific lookup. On Windows, userID is not used.
func getUserShell(userID string) string {
	if runtime.GOOS == "windows" {
		return getWindowsUserShell()
	}
	return getUnixUserShell(userID)
}
// getWindowsUserShell picks the shell used for Windows sessions.
//
// Only PowerShell is considered. cmd.exe and COMSPEC fallbacks are
// intentionally unsupported: cmd.exe's complex command-line parsing and
// special characters make command injection hard to rule out, whereas
// PowerShell handles arguments more safely and is available on all modern
// Windows systems.
//
// Candidates are resolved through PATH in order (pwsh.exe, then
// powershell.exe). If neither is found, the standard Windows PowerShell
// location is returned without checking that the file exists.
func getWindowsUserShell() string {
	for _, candidate := range []string{pwshExe, powershellExe} {
		resolved, lookErr := exec.LookPath(candidate)
		if lookErr == nil {
			return resolved
		}
	}
	return `C:\Windows\System32\WindowsPowerShell\v1.0\powershell.exe`
}
// getUnixUserShell resolves a shell on Unix-like systems. It returns the first
// non-empty result from these sources, tried in order:
//
//  1. the user's entry in /etc/passwd (read directly, no subprocess);
//  2. getent, which also covers users supplied through NSS;
//  3. the SHELL variable of the current process;
//  4. defaultUnixShell.
func getUnixUserShell(userID string) string {
	resolved := getShellFromPasswd(userID)
	if resolved == "" {
		resolved = getShellFromGetent(userID)
	}
	if resolved == "" {
		resolved = os.Getenv("SHELL")
	}
	if resolved == "" {
		resolved = defaultUnixShell
	}
	return resolved
}
// getShellFromPasswd scans /etc/passwd for the first entry whose UID field
// (the third colon-separated field) equals userID. It returns that entry's
// login-shell field (the seventh) with surrounding whitespace trimmed.
//
// It returns "" when the file cannot be opened, when no entry matches, or when
// the matching entry's shell field is empty. Lines with fewer than seven
// fields are skipped. Close errors, and read errors hit before a match is
// found, are logged as warnings and not returned.
func getShellFromPasswd(userID string) string {
	passwd, openErr := os.Open("/etc/passwd")
	if openErr != nil {
		return ""
	}
	defer func() {
		if closeErr := passwd.Close(); closeErr != nil {
			log.Warnf("close /etc/passwd file: %v", closeErr)
		}
	}()

	lines := bufio.NewScanner(passwd)
	for lines.Scan() {
		entry := strings.Split(lines.Text(), ":")
		// entry[2] is the UID, entry[6] the login shell.
		if len(entry) >= 7 && entry[2] == userID {
			return strings.TrimSpace(entry[6])
		}
	}

	if scanErr := lines.Err(); scanErr != nil {
		log.Warnf("error reading /etc/passwd: %v", scanErr)
	}
	return ""
}
// prepareUserEnv builds the baseline environment for a process run on behalf
// of user. It returns, in order, SHELL, USER, LOGNAME, HOME and PATH entries
// in "NAME=value" form.
//
// USER and LOGNAME are both set from user.Username, and HOME from
// user.HomeDir. PATH is a fixed, platform-specific value and is not derived
// from the server's own environment.
func prepareUserEnv(user *user.User, shell string) []string {
	// The parameter name shadows the os/user package inside this function.
	pathValue := "/usr/local/bin:/usr/bin:/bin:/usr/local/games:/usr/games"
	if runtime.GOOS == "windows" {
		pathValue = `C:\Windows\System32;C:\Windows;C:\Windows\System32\Wbem;C:\Windows\System32\WindowsPowerShell\v1.0`
	}

	// Plain concatenation is enough here. The previous fmt.Sprint wrappers
	// only boxed an already-built string (staticcheck S1039).
	return []string{
		"SHELL=" + shell,
		"USER=" + user.Username,
		"LOGNAME=" + user.Username,
		"HOME=" + user.HomeDir,
		"PATH=" + pathValue,
	}
}
// acceptEnv reports whether an environment variable sent by an SSH client may
// be passed through to the session. envVar may be "NAME=value" or a bare
// "NAME"; only the text before the first '=' is checked, case-sensitively,
// against an allow-list. Any LC_* locale variable is also accepted.
func acceptEnv(envVar string) bool {
	name, _, _ := strings.Cut(envVar, "=")

	switch name {
	case "LANG", "LANGUAGE",
		"TERM", "COLORTERM",
		"EDITOR", "VISUAL",
		"PAGER", "LESS", "LESSCHARSET",
		"TZ":
		return true
	}

	// Locale categories (LC_ALL, LC_CTYPE, ...) are allowed by prefix.
	return strings.HasPrefix(name, "LC_")
}
// prepareSSHEnv returns the SSH_CLIENT and SSH_CONNECTION variables that
// describe the session's underlying connection, in these space-separated
// formats:
//
//	SSH_CLIENT=client_ip client_port server_port
//	SSH_CONNECTION=client_ip client_port server_ip server_port
//
// If an address cannot be split into host and port, its full string form is
// used as the host and a fallback port is substituted: "0" for the client
// (remote) side and InternalSSHPort for the server (local) side.
func prepareSSHEnv(session ssh.Session) []string {
	splitAddr := func(addr net.Addr, fallbackPort string) (string, string) {
		raw := addr.String()
		host, port, splitErr := net.SplitHostPort(raw)
		if splitErr != nil {
			return raw, fallbackPort
		}
		return host, port
	}

	clientHost, clientPort := splitAddr(session.RemoteAddr(), "0")
	serverHost, serverPort := splitAddr(session.LocalAddr(), strconv.Itoa(InternalSSHPort))

	sshClient := fmt.Sprintf("SSH_CLIENT=%s %s %s", clientHost, clientPort, serverPort)
	sshConnection := fmt.Sprintf("SSH_CONNECTION=%s %s %s %s", clientHost, clientPort, serverHost, serverPort)
	return []string{sshClient, sshConnection}
}