Files
netbird/client/ssh/config/manager.go
2025-11-17 17:10:41 +01:00

283 lines
7.0 KiB
Go

package config
import (
"context"
"fmt"
"os"
"path/filepath"
"runtime"
"strconv"
"strings"
"time"
log "github.com/sirupsen/logrus"
nbssh "github.com/netbirdio/netbird/client/ssh"
)
// Environment variable names and limits that control SSH config generation.
const (
// EnvDisableSSHConfig names the environment variable that switches SSH
// config generation off; it is read by isSSHConfigDisabled.
EnvDisableSSHConfig = "NB_DISABLE_SSH_CONFIG"
// EnvForceSSHConfig names the environment variable that forces generation
// regardless of peer count; it is read by isSSHConfigForced.
EnvForceSSHConfig = "NB_FORCE_SSH_CONFIG"
// MaxPeersForSSHConfig is the largest peer count for which the config is
// still generated when generation is neither disabled nor forced.
MaxPeersForSSHConfig = 200
// fileWriteTimeout bounds how long writeFileWithTimeout waits for a write.
fileWriteTimeout = 2 * time.Second
)
// isSSHConfigDisabled reports whether SSH config management has been switched
// off through the EnvDisableSSHConfig environment variable.
//
// An unset or empty variable means "not disabled". A non-empty value that
// strconv.ParseBool cannot interpret is treated as a request to disable.
func isSSHConfigDisabled() bool {
	raw := os.Getenv(EnvDisableSSHConfig)
	if raw == "" {
		return false
	}
	if parsed, err := strconv.ParseBool(raw); err == nil {
		return parsed
	}
	// Non-empty but not a recognised boolean: err on the side of disabling.
	return true
}
// isSSHConfigForced reports whether SSH config generation is forced through
// the EnvForceSSHConfig environment variable.
//
// An unset or empty variable means "not forced". A non-empty value that
// strconv.ParseBool cannot interpret counts as forced.
func isSSHConfigForced() bool {
	setting := os.Getenv(EnvForceSSHConfig)
	if setting == "" {
		return false
	}
	forced, parseErr := strconv.ParseBool(setting)
	// An unparsable value is treated the same as an explicit "true".
	return parseErr != nil || forced
}
// shouldGenerateSSHConfig decides whether the SSH config should be generated
// for the given number of peers.
//
// The disable switch is checked first and always wins, then the force switch;
// only when neither applies is peerCount compared against
// MaxPeersForSSHConfig (inclusive).
func shouldGenerateSSHConfig(peerCount int) bool {
	switch {
	case isSSHConfigDisabled():
		return false
	case isSSHConfigForced():
		return true
	default:
		return peerCount <= MaxPeersForSSHConfig
	}
}
// writeFileWithTimeout writes data to filename with permissions perm, giving
// up after fileWriteTimeout.
//
// The write itself runs in a separate goroutine via os.WriteFile. On success
// or failure within the deadline, the os.WriteFile error is returned as is.
// If the deadline passes first, a timeout error naming the file is returned;
// the goroutine is not interrupted and may still finish the write afterwards.
func writeFileWithTimeout(filename string, data []byte, perm os.FileMode) error {
	deadlineCtx, stop := context.WithTimeout(context.Background(), fileWriteTimeout)
	defer stop()

	// Capacity 1 lets the writer goroutine deliver its result and exit even
	// when nobody is receiving anymore because the deadline already fired.
	result := make(chan error, 1)
	go func() {
		result <- os.WriteFile(filename, data, perm)
	}()

	select {
	case <-deadlineCtx.Done():
		return fmt.Errorf("file write timeout after %v: %s", fileWriteTimeout, filename)
	case writeErr := <-result:
		return writeErr
	}
}
// Manager handles SSH client configuration for NetBird peers.
// It records where the generated SSH client config file lives so that it can
// be written and removed.
type Manager struct {
// sshConfigDir is the directory that holds the generated config file.
sshConfigDir string
// sshConfigFile is the config file's name; it is joined with sshConfigDir
// to form the full path.
sshConfigFile string
}
// PeerSSHInfo represents a peer's SSH configuration information.
// Its non-empty fields become host patterns in the generated config.
type PeerSSHInfo struct {
// Hostname is the peer's host name; it is not emitted as a pattern when it
// is empty or identical to FQDN.
Hostname string
// IP is the peer's IP address as a string.
IP string
// FQDN is the peer's fully qualified domain name.
FQDN string
}
// New creates a new SSH config manager.
//
// The manager targets the platform-specific directory returned by
// getSystemSSHConfigDir and the file name nbssh.NetBirdSSHConfigFile.
func New() *Manager {
	mgr := new(Manager)
	mgr.sshConfigDir = getSystemSSHConfigDir()
	mgr.sshConfigFile = nbssh.NetBirdSSHConfigFile
	return mgr
}
// getSystemSSHConfigDir returns the platform-specific SSH configuration
// directory: the Windows location from getWindowsSSHConfigDir when running on
// Windows, nbssh.UnixSSHConfigDir on every other OS.
func getSystemSSHConfigDir() string {
	switch runtime.GOOS {
	case "windows":
		return getWindowsSSHConfigDir()
	default:
		return nbssh.UnixSSHConfigDir
	}
}
// getWindowsSSHConfigDir returns the SSH config directory on Windows, built by
// joining the PROGRAMDATA environment variable with nbssh.WindowsSSHConfigDir.
// When PROGRAMDATA is unset or empty, C:\ProgramData is used as the base.
func getWindowsSSHConfigDir() string {
	const fallbackProgramData = `C:\ProgramData`

	base := fallbackProgramData
	if fromEnv := os.Getenv("PROGRAMDATA"); fromEnv != "" {
		base = fromEnv
	}
	return filepath.Join(base, nbssh.WindowsSSHConfigDir)
}
// SetupSSHClientConfig creates the SSH client configuration for the given
// NetBird peers.
//
// When shouldGenerateSSHConfig rejects the peer count (or generation is
// disabled), the reason is logged and nil is returned without touching the
// file system. Otherwise the config is built and written; a build failure is
// returned wrapped, a write failure is returned as produced by writeSSHConfig.
func (m *Manager) SetupSSHClientConfig(peers []PeerSSHInfo) error {
	peerCount := len(peers)
	if !shouldGenerateSSHConfig(peerCount) {
		m.logSkipReason(peerCount)
		return nil
	}

	content, buildErr := m.buildSSHConfig(peers)
	if buildErr != nil {
		return fmt.Errorf("build SSH config: %w", buildErr)
	}

	return m.writeSSHConfig(content)
}
// logSkipReason logs why SSH config generation was skipped: at debug level
// when it is disabled via EnvDisableSSHConfig, otherwise at info level with
// the peer count, the limit, and a hint about EnvForceSSHConfig.
func (m *Manager) logSkipReason(peerCount int) {
	if !isSSHConfigDisabled() {
		log.Infof("SSH config generation skipped: too many peers (%d > %d). Use %s=true to force.",
			peerCount, MaxPeersForSSHConfig, EnvForceSSHConfig)
		return
	}
	log.Debugf("SSH config management disabled via %s", EnvDisableSSHConfig)
}
// buildSSHConfig renders the full SSH config text for the given peers: the
// fixed header followed by one peer section covering the host patterns of all
// peers. When no peer yields any pattern, only the header is returned.
// An error from buildPeerConfig is returned unchanged with an empty string.
func (m *Manager) buildSSHConfig(peers []PeerSSHInfo) (string, error) {
	header := m.buildConfigHeader()

	var patterns []string
	for i := range peers {
		patterns = append(patterns, m.buildHostPatterns(peers[i])...)
	}
	if len(patterns) == 0 {
		return header, nil
	}

	peerSection, err := m.buildPeerConfig(patterns)
	if err != nil {
		return "", err
	}
	return header + peerSection, nil
}
// buildConfigHeader returns the fixed comment header placed at the top of the
// generated SSH config, including the hint on how to disable config
// management. The header ends with a blank line.
func (m *Manager) buildConfigHeader() string {
	const header = "# NetBird SSH client configuration\n" +
		"# Generated automatically - do not edit manually\n" +
		"#\n" +
		"# To disable SSH config management, use:\n" +
		"# netbird service reconfigure --service-env NB_DISABLE_SSH_CONFIG=true\n" +
		"#\n\n"
	return header
}
// buildPeerConfig renders the peer section of the SSH config.
//
// Duplicate entries in allHostPatterns are dropped (first occurrence wins,
// order preserved) and the remainder is emitted on a single Host line. It is
// followed by a Match exec line that invokes the NetBird executable with
// "ssh detect", and by the authentication, ProxyCommand and host-key options.
// On non-Windows systems the detect command's stderr is sent to /dev/null and
// the known-hosts file is /dev/null; on Windows the known-hosts file is NUL.
//
// An error is returned only when the executable path cannot be determined.
//
// NOTE(review): in ssh_config a Match keyword starts a new conditional block,
// so the options below appear to be governed by the Match exec test rather
// than by the Host line above it — confirm this is the intended layout.
func (m *Manager) buildPeerConfig(allHostPatterns []string) (string, error) {
	seen := make(map[string]struct{}, len(allHostPatterns))
	var unique []string
	for _, candidate := range allHostPatterns {
		if _, dup := seen[candidate]; dup {
			continue
		}
		seen[candidate] = struct{}{}
		unique = append(unique, candidate)
	}

	exe, err := m.getNetBirdExecutablePath()
	if err != nil {
		return "", fmt.Errorf("get NetBird executable path: %w", err)
	}

	onWindows := runtime.GOOS == "windows"

	var out strings.Builder
	fmt.Fprintf(&out, "Host %s\n", strings.Join(unique, " "))

	if onWindows {
		fmt.Fprintf(&out, " Match exec \"%s ssh detect %%h %%p\"\n", exe)
	} else {
		fmt.Fprintf(&out, " Match exec \"%s ssh detect %%h %%p 2>/dev/null\"\n", exe)
	}

	out.WriteString(" PreferredAuthentications password,publickey,keyboard-interactive\n")
	out.WriteString(" PasswordAuthentication yes\n")
	out.WriteString(" PubkeyAuthentication yes\n")
	out.WriteString(" BatchMode no\n")
	fmt.Fprintf(&out, " ProxyCommand %s ssh proxy %%h %%p\n", exe)
	out.WriteString(" StrictHostKeyChecking no\n")

	if onWindows {
		out.WriteString(" UserKnownHostsFile NUL\n")
	} else {
		out.WriteString(" UserKnownHostsFile /dev/null\n")
	}

	out.WriteString(" CheckHostIP no\n")
	out.WriteString(" LogLevel ERROR\n\n")

	return out.String(), nil
}
// buildHostPatterns returns the host patterns for one peer, in the order IP,
// FQDN, Hostname. Empty fields are skipped, and Hostname is also skipped when
// it equals FQDN. The result is nil when the peer has no usable field.
func (m *Manager) buildHostPatterns(peer PeerSSHInfo) []string {
	var patterns []string

	for _, candidate := range [...]string{peer.IP, peer.FQDN} {
		if candidate != "" {
			patterns = append(patterns, candidate)
		}
	}

	// Avoid listing the same name twice when the host name is already the FQDN.
	if name := peer.Hostname; name != "" && name != peer.FQDN {
		patterns = append(patterns, name)
	}

	return patterns
}
// writeSSHConfig writes sshConfig to the manager's config file.
//
// The config directory is created (mode 0755, including parents) if needed,
// then the file is written with mode 0644 through writeFileWithTimeout.
// Failures of either step are returned wrapped with the affected path; on
// success the written path is logged at info level.
func (m *Manager) writeSSHConfig(sshConfig string) error {
	dir := m.sshConfigDir
	target := filepath.Join(dir, m.sshConfigFile)

	if mkErr := os.MkdirAll(dir, 0755); mkErr != nil {
		return fmt.Errorf("create SSH config directory %s: %w", dir, mkErr)
	}

	payload := []byte(sshConfig)
	if writeErr := writeFileWithTimeout(target, payload, 0644); writeErr != nil {
		return fmt.Errorf("write SSH config file %s: %w", target, writeErr)
	}

	log.Infof("Created NetBird SSH client config: %s", target)
	return nil
}
// RemoveSSHClientConfig removes the NetBird SSH configuration file.
//
// A file that does not exist is not an error: nil is returned silently.
// A successful removal is logged at info level. Any other failure is returned
// wrapped with the file path.
func (m *Manager) RemoveSSHClientConfig() error {
	target := filepath.Join(m.sshConfigDir, m.sshConfigFile)

	switch err := os.Remove(target); {
	case err == nil:
		log.Infof("Removed NetBird SSH config: %s", target)
	case !os.IsNotExist(err):
		return fmt.Errorf("remove SSH config %s: %w", target, err)
	}
	return nil
}
// getNetBirdExecutablePath returns the path of the running executable with
// symlinks resolved.
//
// If the executable path cannot be obtained, a wrapped error is returned. If
// only the symlink resolution fails, the failure is logged at debug level and
// the unresolved path is returned without an error.
func (m *Manager) getNetBirdExecutablePath() (string, error) {
	exe, err := os.Executable()
	if err != nil {
		return "", fmt.Errorf("retrieve executable path: %w", err)
	}

	resolved, linkErr := filepath.EvalSymlinks(exe)
	if linkErr == nil {
		return resolved, nil
	}

	// Best effort: fall back to the path as reported by the OS.
	log.Debugf("symlink resolution failed: %v", linkErr)
	return exe, nil
}
// GetSSHConfigDir returns the directory in which the manager places the
// generated SSH config file.
func (m *Manager) GetSSHConfigDir() string {
return m.sshConfigDir
}
// GetSSHConfigFile returns the name of the generated SSH config file, without
// its directory (see GetSSHConfigDir for the directory).
func (m *Manager) GetSSHConfigFile() string {
return m.sshConfigFile
}