mirror of
https://github.com/netbirdio/netbird.git
synced 2026-03-31 06:24:18 -04:00
442 lines
11 KiB
Go
442 lines
11 KiB
Go
package server
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"net"
|
|
"net/netip"
|
|
"os/user"
|
|
"runtime"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
cryptossh "golang.org/x/crypto/ssh"
|
|
|
|
nbssh "github.com/netbirdio/netbird/client/ssh"
|
|
)
|
|
|
|
func TestServer_StartStop(t *testing.T) {
|
|
key, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
|
require.NoError(t, err)
|
|
|
|
serverConfig := &Config{
|
|
HostKeyPEM: key,
|
|
JWT: nil,
|
|
}
|
|
server := New(serverConfig)
|
|
|
|
err = server.Stop()
|
|
assert.NoError(t, err)
|
|
}
|
|
|
|
func TestSSHServerIntegration(t *testing.T) {
|
|
// Generate host key for server
|
|
hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
|
require.NoError(t, err)
|
|
|
|
// Generate client key pair
|
|
clientPrivKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
|
require.NoError(t, err)
|
|
|
|
// Create server with random port
|
|
serverConfig := &Config{
|
|
HostKeyPEM: hostKey,
|
|
JWT: nil,
|
|
}
|
|
server := New(serverConfig)
|
|
|
|
// Start server in background
|
|
serverAddr := "127.0.0.1:0"
|
|
started := make(chan string, 1)
|
|
errChan := make(chan error, 1)
|
|
|
|
go func() {
|
|
// Get a free port
|
|
ln, err := net.Listen("tcp", serverAddr)
|
|
if err != nil {
|
|
errChan <- err
|
|
return
|
|
}
|
|
actualAddr := ln.Addr().String()
|
|
if err := ln.Close(); err != nil {
|
|
errChan <- fmt.Errorf("close temp listener: %w", err)
|
|
return
|
|
}
|
|
|
|
addrPort, _ := netip.ParseAddrPort(actualAddr)
|
|
if err := server.Start(context.Background(), addrPort); err != nil {
|
|
errChan <- err
|
|
return
|
|
}
|
|
started <- actualAddr
|
|
}()
|
|
|
|
select {
|
|
case actualAddr := <-started:
|
|
serverAddr = actualAddr
|
|
case err := <-errChan:
|
|
t.Fatalf("Server failed to start: %v", err)
|
|
case <-time.After(5 * time.Second):
|
|
t.Fatal("Server start timeout")
|
|
}
|
|
|
|
defer func() {
|
|
err := server.Stop()
|
|
require.NoError(t, err)
|
|
}()
|
|
|
|
// Parse client private key
|
|
signer, err := cryptossh.ParsePrivateKey(clientPrivKey)
|
|
require.NoError(t, err)
|
|
|
|
// Parse server host key for verification
|
|
hostPrivParsed, err := cryptossh.ParsePrivateKey(hostKey)
|
|
require.NoError(t, err)
|
|
hostPubKey := hostPrivParsed.PublicKey()
|
|
|
|
// Get current user for SSH connection
|
|
currentUser, err := user.Current()
|
|
require.NoError(t, err, "Should be able to get current user for test")
|
|
|
|
// Create SSH client config
|
|
config := &cryptossh.ClientConfig{
|
|
User: currentUser.Username,
|
|
Auth: []cryptossh.AuthMethod{
|
|
cryptossh.PublicKeys(signer),
|
|
},
|
|
HostKeyCallback: cryptossh.FixedHostKey(hostPubKey),
|
|
Timeout: 3 * time.Second,
|
|
}
|
|
|
|
// Connect to SSH server
|
|
client, err := cryptossh.Dial("tcp", serverAddr, config)
|
|
require.NoError(t, err)
|
|
defer func() {
|
|
if err := client.Close(); err != nil {
|
|
t.Logf("close client: %v", err)
|
|
}
|
|
}()
|
|
|
|
// Test creating a session
|
|
session, err := client.NewSession()
|
|
require.NoError(t, err)
|
|
defer func() {
|
|
if err := session.Close(); err != nil {
|
|
t.Logf("close session: %v", err)
|
|
}
|
|
}()
|
|
|
|
// Note: Since we don't have a real shell environment in tests,
|
|
// we can't test actual command execution, but we can verify
|
|
// the connection and authentication work
|
|
t.Log("SSH connection and authentication successful")
|
|
}
|
|
|
|
func TestSSHServerMultipleConnections(t *testing.T) {
|
|
// Generate host key for server
|
|
hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
|
require.NoError(t, err)
|
|
|
|
// Generate client key pair
|
|
clientPrivKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
|
require.NoError(t, err)
|
|
|
|
// Create server
|
|
serverConfig := &Config{
|
|
HostKeyPEM: hostKey,
|
|
JWT: nil,
|
|
}
|
|
server := New(serverConfig)
|
|
|
|
// Start server
|
|
serverAddr := "127.0.0.1:0"
|
|
started := make(chan string, 1)
|
|
errChan := make(chan error, 1)
|
|
|
|
go func() {
|
|
ln, err := net.Listen("tcp", serverAddr)
|
|
if err != nil {
|
|
errChan <- err
|
|
return
|
|
}
|
|
actualAddr := ln.Addr().String()
|
|
if err := ln.Close(); err != nil {
|
|
errChan <- fmt.Errorf("close temp listener: %w", err)
|
|
return
|
|
}
|
|
|
|
addrPort, _ := netip.ParseAddrPort(actualAddr)
|
|
if err := server.Start(context.Background(), addrPort); err != nil {
|
|
errChan <- err
|
|
return
|
|
}
|
|
started <- actualAddr
|
|
}()
|
|
|
|
select {
|
|
case actualAddr := <-started:
|
|
serverAddr = actualAddr
|
|
case err := <-errChan:
|
|
t.Fatalf("Server failed to start: %v", err)
|
|
case <-time.After(5 * time.Second):
|
|
t.Fatal("Server start timeout")
|
|
}
|
|
|
|
defer func() {
|
|
err := server.Stop()
|
|
require.NoError(t, err)
|
|
}()
|
|
|
|
// Parse client private key
|
|
signer, err := cryptossh.ParsePrivateKey(clientPrivKey)
|
|
require.NoError(t, err)
|
|
|
|
// Parse server host key
|
|
hostPrivParsed, err := cryptossh.ParsePrivateKey(hostKey)
|
|
require.NoError(t, err)
|
|
hostPubKey := hostPrivParsed.PublicKey()
|
|
|
|
// Get current user for SSH connection
|
|
currentUser, err := user.Current()
|
|
require.NoError(t, err, "Should be able to get current user for test")
|
|
|
|
config := &cryptossh.ClientConfig{
|
|
User: currentUser.Username,
|
|
Auth: []cryptossh.AuthMethod{
|
|
cryptossh.PublicKeys(signer),
|
|
},
|
|
HostKeyCallback: cryptossh.FixedHostKey(hostPubKey),
|
|
Timeout: 3 * time.Second,
|
|
}
|
|
|
|
// Test multiple concurrent connections
|
|
const numConnections = 5
|
|
results := make(chan error, numConnections)
|
|
|
|
for i := 0; i < numConnections; i++ {
|
|
go func(id int) {
|
|
client, err := cryptossh.Dial("tcp", serverAddr, config)
|
|
if err != nil {
|
|
results <- fmt.Errorf("connection %d failed: %w", id, err)
|
|
return
|
|
}
|
|
defer func() {
|
|
_ = client.Close() // Ignore error in test goroutine
|
|
}()
|
|
|
|
session, err := client.NewSession()
|
|
if err != nil {
|
|
results <- fmt.Errorf("session %d failed: %w", id, err)
|
|
return
|
|
}
|
|
defer func() {
|
|
_ = session.Close() // Ignore error in test goroutine
|
|
}()
|
|
|
|
results <- nil
|
|
}(i)
|
|
}
|
|
|
|
// Wait for all connections to complete
|
|
for i := 0; i < numConnections; i++ {
|
|
select {
|
|
case err := <-results:
|
|
assert.NoError(t, err)
|
|
case <-time.After(10 * time.Second):
|
|
t.Fatalf("Connection %d timed out", i)
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestSSHServerNoAuthMode(t *testing.T) {
|
|
// Generate host key for server
|
|
hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
|
require.NoError(t, err)
|
|
|
|
// Create server
|
|
serverConfig := &Config{
|
|
HostKeyPEM: hostKey,
|
|
JWT: nil,
|
|
}
|
|
server := New(serverConfig)
|
|
|
|
// Start server
|
|
serverAddr := "127.0.0.1:0"
|
|
started := make(chan string, 1)
|
|
errChan := make(chan error, 1)
|
|
|
|
go func() {
|
|
ln, err := net.Listen("tcp", serverAddr)
|
|
if err != nil {
|
|
errChan <- err
|
|
return
|
|
}
|
|
actualAddr := ln.Addr().String()
|
|
if err := ln.Close(); err != nil {
|
|
errChan <- fmt.Errorf("close temp listener: %w", err)
|
|
return
|
|
}
|
|
|
|
addrPort, _ := netip.ParseAddrPort(actualAddr)
|
|
if err := server.Start(context.Background(), addrPort); err != nil {
|
|
errChan <- err
|
|
return
|
|
}
|
|
started <- actualAddr
|
|
}()
|
|
|
|
select {
|
|
case actualAddr := <-started:
|
|
serverAddr = actualAddr
|
|
case err := <-errChan:
|
|
t.Fatalf("Server failed to start: %v", err)
|
|
case <-time.After(5 * time.Second):
|
|
t.Fatal("Server start timeout")
|
|
}
|
|
|
|
defer func() {
|
|
err := server.Stop()
|
|
require.NoError(t, err)
|
|
}()
|
|
|
|
// Generate a client private key for SSH protocol (server doesn't check it)
|
|
clientPrivKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
|
require.NoError(t, err)
|
|
clientSigner, err := cryptossh.ParsePrivateKey(clientPrivKey)
|
|
require.NoError(t, err)
|
|
|
|
// Parse server host key
|
|
hostPrivParsed, err := cryptossh.ParsePrivateKey(hostKey)
|
|
require.NoError(t, err)
|
|
hostPubKey := hostPrivParsed.PublicKey()
|
|
|
|
// Get current user for SSH connection
|
|
currentUser, err := user.Current()
|
|
require.NoError(t, err, "Should be able to get current user for test")
|
|
|
|
// Try to connect with client key
|
|
config := &cryptossh.ClientConfig{
|
|
User: currentUser.Username,
|
|
Auth: []cryptossh.AuthMethod{
|
|
cryptossh.PublicKeys(clientSigner),
|
|
},
|
|
HostKeyCallback: cryptossh.FixedHostKey(hostPubKey),
|
|
Timeout: 3 * time.Second,
|
|
}
|
|
|
|
// This should succeed in no-auth mode (server doesn't verify keys)
|
|
conn, err := cryptossh.Dial("tcp", serverAddr, config)
|
|
assert.NoError(t, err, "Connection should succeed in no-auth mode")
|
|
if conn != nil {
|
|
assert.NoError(t, conn.Close())
|
|
}
|
|
}
|
|
|
|
func TestSSHServerStartStopCycle(t *testing.T) {
|
|
hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
|
require.NoError(t, err)
|
|
|
|
serverConfig := &Config{
|
|
HostKeyPEM: hostKey,
|
|
JWT: nil,
|
|
}
|
|
server := New(serverConfig)
|
|
serverAddr := "127.0.0.1:0"
|
|
|
|
// Test multiple start/stop cycles
|
|
for i := 0; i < 3; i++ {
|
|
t.Logf("Start/stop cycle %d", i+1)
|
|
|
|
started := make(chan string, 1)
|
|
errChan := make(chan error, 1)
|
|
|
|
go func() {
|
|
ln, err := net.Listen("tcp", serverAddr)
|
|
if err != nil {
|
|
errChan <- err
|
|
return
|
|
}
|
|
actualAddr := ln.Addr().String()
|
|
if err := ln.Close(); err != nil {
|
|
errChan <- fmt.Errorf("close temp listener: %w", err)
|
|
return
|
|
}
|
|
|
|
addrPort, _ := netip.ParseAddrPort(actualAddr)
|
|
if err := server.Start(context.Background(), addrPort); err != nil {
|
|
errChan <- err
|
|
return
|
|
}
|
|
started <- actualAddr
|
|
}()
|
|
|
|
select {
|
|
case <-started:
|
|
case err := <-errChan:
|
|
t.Fatalf("Cycle %d: Server failed to start: %v", i+1, err)
|
|
case <-time.After(5 * time.Second):
|
|
t.Fatalf("Cycle %d: Server start timeout", i+1)
|
|
}
|
|
|
|
err = server.Stop()
|
|
require.NoError(t, err, "Cycle %d: Stop should succeed", i+1)
|
|
}
|
|
}
|
|
|
|
func TestSSHServer_WindowsShellHandling(t *testing.T) {
|
|
if testing.Short() {
|
|
t.Skip("Skipping Windows shell test in short mode")
|
|
}
|
|
|
|
server := &Server{}
|
|
|
|
if runtime.GOOS == "windows" {
|
|
// Test Windows cmd.exe shell behavior
|
|
args := server.getShellCommandArgs("cmd.exe", "echo test")
|
|
assert.Equal(t, "cmd.exe", args[0])
|
|
assert.Equal(t, "-Command", args[1])
|
|
assert.Equal(t, "echo test", args[2])
|
|
|
|
// Test PowerShell behavior
|
|
args = server.getShellCommandArgs("powershell.exe", "echo test")
|
|
assert.Equal(t, "powershell.exe", args[0])
|
|
assert.Equal(t, "-Command", args[1])
|
|
assert.Equal(t, "echo test", args[2])
|
|
} else {
|
|
// Test Unix shell behavior
|
|
args := server.getShellCommandArgs("/bin/sh", "echo test")
|
|
assert.Equal(t, "/bin/sh", args[0])
|
|
assert.Equal(t, "-l", args[1])
|
|
assert.Equal(t, "-c", args[2])
|
|
assert.Equal(t, "echo test", args[3])
|
|
}
|
|
}
|
|
|
|
func TestSSHServer_PortForwardingConfiguration(t *testing.T) {
|
|
hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
|
require.NoError(t, err)
|
|
|
|
serverConfig1 := &Config{
|
|
HostKeyPEM: hostKey,
|
|
JWT: nil,
|
|
}
|
|
server1 := New(serverConfig1)
|
|
|
|
serverConfig2 := &Config{
|
|
HostKeyPEM: hostKey,
|
|
JWT: nil,
|
|
}
|
|
server2 := New(serverConfig2)
|
|
|
|
assert.False(t, server1.allowLocalPortForwarding, "Local port forwarding should be disabled by default for security")
|
|
assert.False(t, server1.allowRemotePortForwarding, "Remote port forwarding should be disabled by default for security")
|
|
|
|
server2.SetAllowLocalPortForwarding(true)
|
|
server2.SetAllowRemotePortForwarding(true)
|
|
|
|
assert.True(t, server2.allowLocalPortForwarding, "Local port forwarding should be enabled when explicitly set")
|
|
assert.True(t, server2.allowRemotePortForwarding, "Remote port forwarding should be enabled when explicitly set")
|
|
}
|