mirror of
https://github.com/netbirdio/netbird.git
synced 2026-03-31 06:34:14 -04:00
196 lines
5.9 KiB
Go
196 lines
5.9 KiB
Go
package ssh
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"net"
|
|
"time"
|
|
|
|
log "github.com/sirupsen/logrus"
|
|
"golang.org/x/crypto/ssh"
|
|
|
|
"github.com/netbirdio/netbird/client/proto"
|
|
)
|
|
|
|
// SSH client configuration file name and directories.
// NOTE(review): how these values are combined into a full path is decided by
// callers outside this file — confirm there before changing them.
const (
	// NetBirdSSHConfigFile is the file name of the NetBird SSH configuration
	// file (presumably a drop-in placed in one of the directories below).
	NetBirdSSHConfigFile = "99-netbird.conf"

	// UnixSSHConfigDir is the absolute ssh_config.d directory used on Unix-like systems.
	UnixSSHConfigDir = "/etc/ssh/ssh_config.d"
	// WindowsSSHConfigDir is the ssh_config.d directory used on Windows. It is a
	// relative path; presumably callers join it onto a base directory — TODO confirm.
	WindowsSSHConfigDir = "ssh/ssh_config.d"
)
|
|
|
|
// Sentinel errors reported during SSH host key verification.
// Compare against them with errors.Is.
var (
	// ErrPeerNotFound is returned when the queried address does not belong to
	// any known peer. CreateHostKeyCallback treats it as a signal to try the
	// next candidate address rather than as a hard failure.
	ErrPeerNotFound = errors.New("peer not found in network")
	// ErrNoStoredKey is returned by VerifyHostKey when the stored host key
	// data for a peer is empty, so there is nothing to compare against.
	ErrNoStoredKey = errors.New("peer has no stored SSH host key")
)
|
|
|
|
// HostKeyVerifier decides whether an SSH host key presented by a server is
// acceptable for a given peer address.
type HostKeyVerifier interface {
	// VerifySSHHostKey returns nil when key is accepted for peerAddress and a
	// non-nil error otherwise. CreateHostKeyCallback passes the marshaled
	// public key as key and interprets ErrPeerNotFound as "this address is
	// unknown, try another one"; any other error aborts verification.
	VerifySSHHostKey(peerAddress string, key []byte) error
}
|
|
|
|
// DaemonHostKeyVerifier implements HostKeyVerifier using the NetBird daemon
|
|
type DaemonHostKeyVerifier struct {
|
|
client proto.DaemonServiceClient
|
|
}
|
|
|
|
// NewDaemonHostKeyVerifier creates a new daemon-based host key verifier
|
|
func NewDaemonHostKeyVerifier(client proto.DaemonServiceClient) *DaemonHostKeyVerifier {
|
|
return &DaemonHostKeyVerifier{
|
|
client: client,
|
|
}
|
|
}
|
|
|
|
// VerifySSHHostKey verifies an SSH host key by querying the NetBird daemon
|
|
func (d *DaemonHostKeyVerifier) VerifySSHHostKey(peerAddress string, presentedKey []byte) error {
|
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
|
defer cancel()
|
|
|
|
response, err := d.client.GetPeerSSHHostKey(ctx, &proto.GetPeerSSHHostKeyRequest{
|
|
PeerAddress: peerAddress,
|
|
})
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if !response.GetFound() {
|
|
return ErrPeerNotFound
|
|
}
|
|
|
|
storedKeyData := response.GetSshHostKey()
|
|
|
|
return VerifyHostKey(storedKeyData, presentedKey, peerAddress)
|
|
}
|
|
|
|
// printAuthInstructions prints authentication instructions to stderr
|
|
func printAuthInstructions(stderr io.Writer, authResponse *proto.RequestJWTAuthResponse, browserWillOpen bool) {
|
|
_, _ = fmt.Fprintln(stderr, "SSH authentication required.")
|
|
|
|
if browserWillOpen {
|
|
_, _ = fmt.Fprintln(stderr, "Please do the SSO login in your browser.")
|
|
_, _ = fmt.Fprintln(stderr, "If your browser didn't open automatically, use this URL to log in:")
|
|
_, _ = fmt.Fprintln(stderr)
|
|
}
|
|
|
|
_, _ = fmt.Fprintf(stderr, "%s\n", authResponse.VerificationURIComplete)
|
|
|
|
if authResponse.UserCode != "" {
|
|
_, _ = fmt.Fprintf(stderr, "Or visit: %s and enter code: %s\n", authResponse.VerificationURI, authResponse.UserCode)
|
|
}
|
|
|
|
if browserWillOpen {
|
|
_, _ = fmt.Fprintln(stderr)
|
|
}
|
|
|
|
_, _ = fmt.Fprintln(stderr, "Waiting for authentication...")
|
|
}
|
|
|
|
// RequestJWTToken requests or retrieves a JWT token for SSH authentication
|
|
func RequestJWTToken(ctx context.Context, client proto.DaemonServiceClient, stdout, stderr io.Writer, useCache bool, hint string, openBrowser func(string) error) (string, error) {
|
|
req := &proto.RequestJWTAuthRequest{}
|
|
if hint != "" {
|
|
req.Hint = &hint
|
|
}
|
|
authResponse, err := client.RequestJWTAuth(ctx, req)
|
|
if err != nil {
|
|
return "", fmt.Errorf("request JWT auth: %w", err)
|
|
}
|
|
|
|
if useCache && authResponse.CachedToken != "" {
|
|
log.Debug("Using cached authentication token")
|
|
return authResponse.CachedToken, nil
|
|
}
|
|
|
|
if stderr != nil {
|
|
printAuthInstructions(stderr, authResponse, openBrowser != nil)
|
|
}
|
|
|
|
if openBrowser != nil {
|
|
if err := openBrowser(authResponse.VerificationURIComplete); err != nil {
|
|
log.Debugf("open browser: %v", err)
|
|
}
|
|
}
|
|
|
|
tokenResponse, err := client.WaitJWTToken(ctx, &proto.WaitJWTTokenRequest{
|
|
DeviceCode: authResponse.DeviceCode,
|
|
UserCode: authResponse.UserCode,
|
|
})
|
|
if err != nil {
|
|
return "", fmt.Errorf("wait for JWT token: %w", err)
|
|
}
|
|
|
|
if stdout != nil {
|
|
_, _ = fmt.Fprintln(stdout, "Authentication successful!")
|
|
}
|
|
return tokenResponse.Token, nil
|
|
}
|
|
|
|
// VerifyHostKey verifies an SSH host key against stored peer key data.
|
|
// Returns nil only if the presented key matches the stored key.
|
|
// Returns ErrNoStoredKey if storedKeyData is empty.
|
|
// Returns an error if the keys don't match or if parsing fails.
|
|
func VerifyHostKey(storedKeyData []byte, presentedKey []byte, peerAddress string) error {
|
|
if len(storedKeyData) == 0 {
|
|
return ErrNoStoredKey
|
|
}
|
|
|
|
storedPubKey, _, _, _, err := ssh.ParseAuthorizedKey(storedKeyData)
|
|
if err != nil {
|
|
return fmt.Errorf("parse stored SSH key for %s: %w", peerAddress, err)
|
|
}
|
|
|
|
if !bytes.Equal(presentedKey, storedPubKey.Marshal()) {
|
|
return fmt.Errorf("SSH host key mismatch for %s", peerAddress)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// AddJWTAuth prepends JWT password authentication to existing auth methods.
|
|
// This ensures JWT auth is tried first while preserving any existing auth methods.
|
|
func AddJWTAuth(config *ssh.ClientConfig, jwtToken string) *ssh.ClientConfig {
|
|
configWithJWT := *config
|
|
configWithJWT.Auth = append([]ssh.AuthMethod{ssh.Password(jwtToken)}, config.Auth...)
|
|
return &configWithJWT
|
|
}
|
|
|
|
// CreateHostKeyCallback creates an SSH host key verification callback using the provided verifier.
|
|
// It tries multiple addresses (hostname, IP) for the peer before failing.
|
|
func CreateHostKeyCallback(verifier HostKeyVerifier) ssh.HostKeyCallback {
|
|
return func(hostname string, remote net.Addr, key ssh.PublicKey) error {
|
|
addresses := buildAddressList(hostname, remote)
|
|
presentedKey := key.Marshal()
|
|
|
|
for _, addr := range addresses {
|
|
if err := verifier.VerifySSHHostKey(addr, presentedKey); err != nil {
|
|
if errors.Is(err, ErrPeerNotFound) {
|
|
// Try other addresses for this peer
|
|
continue
|
|
}
|
|
return err
|
|
}
|
|
// Verified
|
|
return nil
|
|
}
|
|
|
|
return fmt.Errorf("SSH host key verification failed: peer %s not found in network", hostname)
|
|
}
|
|
}
|
|
|
|
// buildAddressList returns the candidate addresses to try for host key
// verification, in priority order.
//
// The list always starts with hostname. If remote is non-nil and its string
// form splits into host and port, the host part is appended when it differs
// from hostname. A nil remote, or one whose string form has no host:port
// shape, yields just the hostname.
func buildAddressList(hostname string, remote net.Addr) []string {
	addresses := []string{hostname}

	// Guard against a nil address: calling String() on it would panic.
	if remote == nil {
		return addresses
	}

	host, _, err := net.SplitHostPort(remote.String())
	if err != nil || host == hostname {
		return addresses
	}
	return append(addresses, host)
}
|