Files
netbird/client/ssh/server/server_test.go
2025-11-17 17:10:41 +01:00

442 lines
11 KiB
Go

package server
import (
"context"
"fmt"
"net"
"net/netip"
"os/user"
"runtime"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
cryptossh "golang.org/x/crypto/ssh"
nbssh "github.com/netbirdio/netbird/client/ssh"
)
// TestServer_StartStop checks that Stop returns no error for a server that
// was built with New but never started.
func TestServer_StartStop(t *testing.T) {
	hostKeyPEM, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
	require.NoError(t, err)

	srv := New(&Config{
		HostKeyPEM: hostKeyPEM,
		JWT:        nil,
	})

	assert.NoError(t, srv.Stop())
}
// TestSSHServerIntegration starts a server on a loopback port and checks that
// a client can connect with public-key auth, verify the pinned host key, and
// open a session.
func TestSSHServerIntegration(t *testing.T) {
	// Generate host key for server
	hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
	require.NoError(t, err)

	// Generate client key pair
	clientPrivKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
	require.NoError(t, err)

	// Create server with random port
	serverConfig := &Config{
		HostKeyPEM: hostKey,
		JWT:        nil,
	}
	server := New(serverConfig)

	// Start server in background
	serverAddr := "127.0.0.1:0"
	started := make(chan string, 1)
	errChan := make(chan error, 1)

	go func() {
		// Reserve a free loopback port by binding to :0, then release it so
		// the server can bind the same address. Another process could claim
		// the port in between; that window is accepted for this test.
		ln, err := net.Listen("tcp", serverAddr)
		if err != nil {
			errChan <- err
			return
		}
		actualAddr := ln.Addr().String()
		if err := ln.Close(); err != nil {
			errChan <- fmt.Errorf("close temp listener: %w", err)
			return
		}

		// Surface a parse failure instead of starting the server on a zero AddrPort.
		addrPort, err := netip.ParseAddrPort(actualAddr)
		if err != nil {
			errChan <- fmt.Errorf("parse listener address %q: %w", actualAddr, err)
			return
		}
		if err := server.Start(context.Background(), addrPort); err != nil {
			errChan <- err
			return
		}
		started <- actualAddr
	}()

	select {
	case actualAddr := <-started:
		serverAddr = actualAddr
	case err := <-errChan:
		t.Fatalf("Server failed to start: %v", err)
	case <-time.After(5 * time.Second):
		t.Fatal("Server start timeout")
	}

	defer func() {
		err := server.Stop()
		require.NoError(t, err)
	}()

	// Parse client private key
	signer, err := cryptossh.ParsePrivateKey(clientPrivKey)
	require.NoError(t, err)

	// Parse server host key for verification
	hostPrivParsed, err := cryptossh.ParsePrivateKey(hostKey)
	require.NoError(t, err)
	hostPubKey := hostPrivParsed.PublicKey()

	// Get current user for SSH connection
	currentUser, err := user.Current()
	require.NoError(t, err, "Should be able to get current user for test")

	// Create SSH client config
	config := &cryptossh.ClientConfig{
		User: currentUser.Username,
		Auth: []cryptossh.AuthMethod{
			cryptossh.PublicKeys(signer),
		},
		HostKeyCallback: cryptossh.FixedHostKey(hostPubKey),
		Timeout:         3 * time.Second,
	}

	// Connect to SSH server
	client, err := cryptossh.Dial("tcp", serverAddr, config)
	require.NoError(t, err)
	defer func() {
		if err := client.Close(); err != nil {
			t.Logf("close client: %v", err)
		}
	}()

	// Test creating a session
	session, err := client.NewSession()
	require.NoError(t, err)
	defer func() {
		if err := session.Close(); err != nil {
			t.Logf("close session: %v", err)
		}
	}()

	// Note: Since we don't have a real shell environment in tests,
	// we can't test actual command execution, but we can verify
	// the connection and authentication work
	t.Log("SSH connection and authentication successful")
}
// TestSSHServerMultipleConnections starts a server and opens several client
// connections concurrently, each of which must authenticate and open a session.
func TestSSHServerMultipleConnections(t *testing.T) {
	// Generate host key for server
	hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
	require.NoError(t, err)

	// Generate client key pair
	clientPrivKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
	require.NoError(t, err)

	// Create server
	serverConfig := &Config{
		HostKeyPEM: hostKey,
		JWT:        nil,
	}
	server := New(serverConfig)

	// Start server
	serverAddr := "127.0.0.1:0"
	started := make(chan string, 1)
	errChan := make(chan error, 1)

	go func() {
		// Reserve a free loopback port by binding to :0, then release it so
		// the server can bind the same address. Another process could claim
		// the port in between; that window is accepted for this test.
		ln, err := net.Listen("tcp", serverAddr)
		if err != nil {
			errChan <- err
			return
		}
		actualAddr := ln.Addr().String()
		if err := ln.Close(); err != nil {
			errChan <- fmt.Errorf("close temp listener: %w", err)
			return
		}

		// Surface a parse failure instead of starting the server on a zero AddrPort.
		addrPort, err := netip.ParseAddrPort(actualAddr)
		if err != nil {
			errChan <- fmt.Errorf("parse listener address %q: %w", actualAddr, err)
			return
		}
		if err := server.Start(context.Background(), addrPort); err != nil {
			errChan <- err
			return
		}
		started <- actualAddr
	}()

	select {
	case actualAddr := <-started:
		serverAddr = actualAddr
	case err := <-errChan:
		t.Fatalf("Server failed to start: %v", err)
	case <-time.After(5 * time.Second):
		t.Fatal("Server start timeout")
	}

	defer func() {
		err := server.Stop()
		require.NoError(t, err)
	}()

	// Parse client private key
	signer, err := cryptossh.ParsePrivateKey(clientPrivKey)
	require.NoError(t, err)

	// Parse server host key
	hostPrivParsed, err := cryptossh.ParsePrivateKey(hostKey)
	require.NoError(t, err)
	hostPubKey := hostPrivParsed.PublicKey()

	// Get current user for SSH connection
	currentUser, err := user.Current()
	require.NoError(t, err, "Should be able to get current user for test")

	config := &cryptossh.ClientConfig{
		User: currentUser.Username,
		Auth: []cryptossh.AuthMethod{
			cryptossh.PublicKeys(signer),
		},
		HostKeyCallback: cryptossh.FixedHostKey(hostPubKey),
		Timeout:         3 * time.Second,
	}

	// Test multiple concurrent connections. The results channel is buffered
	// for every goroutine, so none of them blocks on send even if the test
	// gives up early on a timeout.
	const numConnections = 5
	results := make(chan error, numConnections)

	for i := 0; i < numConnections; i++ {
		go func(id int) {
			client, err := cryptossh.Dial("tcp", serverAddr, config)
			if err != nil {
				results <- fmt.Errorf("connection %d failed: %w", id, err)
				return
			}
			defer func() {
				_ = client.Close() // Ignore error in test goroutine
			}()

			session, err := client.NewSession()
			if err != nil {
				results <- fmt.Errorf("session %d failed: %w", id, err)
				return
			}
			defer func() {
				_ = session.Close() // Ignore error in test goroutine
			}()

			results <- nil
		}(i)
	}

	// Wait for all connections to complete
	for i := 0; i < numConnections; i++ {
		select {
		case err := <-results:
			assert.NoError(t, err)
		case <-time.After(10 * time.Second):
			t.Fatalf("Connection %d timed out", i)
		}
	}
}
// TestSSHServerNoAuthMode starts a server with JWT set to nil and checks that
// a client presenting a freshly generated key is allowed to connect.
func TestSSHServerNoAuthMode(t *testing.T) {
	// Generate host key for server
	hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
	require.NoError(t, err)

	// Create server
	serverConfig := &Config{
		HostKeyPEM: hostKey,
		JWT:        nil,
	}
	server := New(serverConfig)

	// Start server
	serverAddr := "127.0.0.1:0"
	started := make(chan string, 1)
	errChan := make(chan error, 1)

	go func() {
		// Reserve a free loopback port by binding to :0, then release it so
		// the server can bind the same address. Another process could claim
		// the port in between; that window is accepted for this test.
		ln, err := net.Listen("tcp", serverAddr)
		if err != nil {
			errChan <- err
			return
		}
		actualAddr := ln.Addr().String()
		if err := ln.Close(); err != nil {
			errChan <- fmt.Errorf("close temp listener: %w", err)
			return
		}

		// Surface a parse failure instead of starting the server on a zero AddrPort.
		addrPort, err := netip.ParseAddrPort(actualAddr)
		if err != nil {
			errChan <- fmt.Errorf("parse listener address %q: %w", actualAddr, err)
			return
		}
		if err := server.Start(context.Background(), addrPort); err != nil {
			errChan <- err
			return
		}
		started <- actualAddr
	}()

	select {
	case actualAddr := <-started:
		serverAddr = actualAddr
	case err := <-errChan:
		t.Fatalf("Server failed to start: %v", err)
	case <-time.After(5 * time.Second):
		t.Fatal("Server start timeout")
	}

	defer func() {
		err := server.Stop()
		require.NoError(t, err)
	}()

	// Generate a client private key for SSH protocol (server doesn't check it)
	clientPrivKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
	require.NoError(t, err)
	clientSigner, err := cryptossh.ParsePrivateKey(clientPrivKey)
	require.NoError(t, err)

	// Parse server host key
	hostPrivParsed, err := cryptossh.ParsePrivateKey(hostKey)
	require.NoError(t, err)
	hostPubKey := hostPrivParsed.PublicKey()

	// Get current user for SSH connection
	currentUser, err := user.Current()
	require.NoError(t, err, "Should be able to get current user for test")

	// Try to connect with client key
	config := &cryptossh.ClientConfig{
		User: currentUser.Username,
		Auth: []cryptossh.AuthMethod{
			cryptossh.PublicKeys(clientSigner),
		},
		HostKeyCallback: cryptossh.FixedHostKey(hostPubKey),
		Timeout:         3 * time.Second,
	}

	// This should succeed in no-auth mode (server doesn't verify keys)
	conn, err := cryptossh.Dial("tcp", serverAddr, config)
	assert.NoError(t, err, "Connection should succeed in no-auth mode")
	if conn != nil {
		assert.NoError(t, conn.Close())
	}
}
// TestSSHServerStartStopCycle starts and stops the same Server instance
// several times, each time on a freshly reserved loopback port.
func TestSSHServerStartStopCycle(t *testing.T) {
	hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
	require.NoError(t, err)

	serverConfig := &Config{
		HostKeyPEM: hostKey,
		JWT:        nil,
	}
	server := New(serverConfig)
	serverAddr := "127.0.0.1:0"

	// Test multiple start/stop cycles
	for i := 0; i < 3; i++ {
		t.Logf("Start/stop cycle %d", i+1)

		started := make(chan string, 1)
		errChan := make(chan error, 1)

		go func() {
			// Reserve a free loopback port, release it, then let the server
			// bind the same address. Another process could claim the port in
			// between; that window is accepted for this test.
			ln, err := net.Listen("tcp", serverAddr)
			if err != nil {
				errChan <- err
				return
			}
			actualAddr := ln.Addr().String()
			if err := ln.Close(); err != nil {
				errChan <- fmt.Errorf("close temp listener: %w", err)
				return
			}

			// Surface a parse failure instead of starting the server on a zero AddrPort.
			addrPort, err := netip.ParseAddrPort(actualAddr)
			if err != nil {
				errChan <- fmt.Errorf("parse listener address %q: %w", actualAddr, err)
				return
			}
			if err := server.Start(context.Background(), addrPort); err != nil {
				errChan <- err
				return
			}
			started <- actualAddr
		}()

		select {
		case <-started:
		case err := <-errChan:
			t.Fatalf("Cycle %d: Server failed to start: %v", i+1, err)
		case <-time.After(5 * time.Second):
			t.Fatalf("Cycle %d: Server start timeout", i+1)
		}

		err = server.Stop()
		require.NoError(t, err, "Cycle %d: Stop should succeed", i+1)
	}
}
// TestSSHServer_WindowsShellHandling checks the leading elements of the
// argument vector that getShellCommandArgs builds on the current platform.
// On Windows it exercises cmd.exe and powershell.exe; elsewhere it exercises /bin/sh.
func TestSSHServer_WindowsShellHandling(t *testing.T) {
	if testing.Short() {
		t.Skip("Skipping Windows shell test in short mode")
	}

	server := &Server{}

	if runtime.GOOS == "windows" {
		// Test Windows cmd.exe shell behavior.
		// The length guards below turn a too-short slice into a test failure
		// instead of an index-out-of-range panic.
		args := server.getShellCommandArgs("cmd.exe", "echo test")
		require.GreaterOrEqual(t, len(args), 3, "cmd.exe args: %v", args)
		assert.Equal(t, "cmd.exe", args[0])
		assert.Equal(t, "-Command", args[1])
		assert.Equal(t, "echo test", args[2])

		// Test PowerShell behavior
		args = server.getShellCommandArgs("powershell.exe", "echo test")
		require.GreaterOrEqual(t, len(args), 3, "powershell.exe args: %v", args)
		assert.Equal(t, "powershell.exe", args[0])
		assert.Equal(t, "-Command", args[1])
		assert.Equal(t, "echo test", args[2])
		return
	}

	// Test Unix shell behavior
	args := server.getShellCommandArgs("/bin/sh", "echo test")
	require.GreaterOrEqual(t, len(args), 4, "/bin/sh args: %v", args)
	assert.Equal(t, "/bin/sh", args[0])
	assert.Equal(t, "-l", args[1])
	assert.Equal(t, "-c", args[2])
	assert.Equal(t, "echo test", args[3])
}
// TestSSHServer_PortForwardingConfiguration verifies that a server created by
// New has local and remote port forwarding turned off, and that the
// SetAllow*PortForwarding setters turn each one on.
func TestSSHServer_PortForwardingConfiguration(t *testing.T) {
	hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
	require.NoError(t, err)

	defaultSrv := New(&Config{HostKeyPEM: hostKey, JWT: nil})
	enabledSrv := New(&Config{HostKeyPEM: hostKey, JWT: nil})

	// Defaults: both forwarding directions off.
	assert.False(t, defaultSrv.allowLocalPortForwarding, "Local port forwarding should be disabled by default for security")
	assert.False(t, defaultSrv.allowRemotePortForwarding, "Remote port forwarding should be disabled by default for security")

	// Explicit opt-in via the setters.
	enabledSrv.SetAllowLocalPortForwarding(true)
	enabledSrv.SetAllowRemotePortForwarding(true)

	assert.True(t, enabledSrv.allowLocalPortForwarding, "Local port forwarding should be enabled when explicitly set")
	assert.True(t, enabledSrv.allowRemotePortForwarding, "Remote port forwarding should be enabled when explicitly set")
}