Add IPv6 support to SSH server, client config, and netflow logger

This commit is contained in:
Viktor Liu
2026-03-24 12:06:58 +01:00
parent 71962f88f8
commit d81cd5d154
10 changed files with 136 additions and 44 deletions

View File

@@ -41,6 +41,14 @@ func (e *Engine) setupSSHPortRedirection() error {
}
log.Infof("SSH port redirection enabled: %s:22 -> %s:22022", localAddr, localAddr)
if v6 := e.wgInterface.Address().IPv6; v6.IsValid() {
if err := e.firewall.AddInboundDNAT(v6, firewallManager.ProtocolTCP, 22, 22022); err != nil {
log.Warnf("failed to add IPv6 SSH port redirection: %v", err)
} else {
log.Infof("SSH port redirection enabled: [%s]:22 -> [%s]:22022", v6, v6)
}
}
return nil
}
@@ -137,12 +145,13 @@ func (e *Engine) extractPeerSSHInfo(remotePeers []*mgmProto.RemotePeerConfig) []
continue
}
peerIP := e.extractPeerIP(peerConfig)
peerIP, peerIPv6 := e.extractPeerIPs(peerConfig)
hostname := e.extractHostname(peerConfig)
peerInfo = append(peerInfo, sshconfig.PeerSSHInfo{
Hostname: hostname,
IP: peerIP,
IPv6: peerIPv6,
FQDN: peerConfig.GetFqdn(),
})
}
@@ -150,16 +159,26 @@ func (e *Engine) extractPeerSSHInfo(remotePeers []*mgmProto.RemotePeerConfig) []
return peerInfo
}
// extractPeerIP extracts IP address from peer's allowed IPs
func (e *Engine) extractPeerIP(peerConfig *mgmProto.RemotePeerConfig) string {
if len(peerConfig.GetAllowedIps()) == 0 {
return ""
// extractPeerIPs extracts IPv4 and IPv6 overlay addresses from peer's allowed IPs.
// Only considers host routes (/32, /128) within the overlay networks to avoid
// picking up routed prefixes or static routes like 2620:fe::fe/128.
func (e *Engine) extractPeerIPs(peerConfig *mgmProto.RemotePeerConfig) (v4, v6 netip.Addr) {
wgAddr := e.wgInterface.Address()
for _, allowedIP := range peerConfig.GetAllowedIps() {
prefix, err := netip.ParsePrefix(allowedIP)
if err != nil {
log.Warnf("failed to parse AllowedIP %q: %v", allowedIP, err)
continue
}
addr := prefix.Addr().Unmap()
switch {
case addr.Is4() && prefix.Bits() == 32 && wgAddr.Network.Contains(addr) && !v4.IsValid():
v4 = addr
case addr.Is6() && prefix.Bits() == 128 && wgAddr.IPv6Net.IsValid() && wgAddr.IPv6Net.Contains(addr) && !v6.IsValid():
v6 = addr
}
}
if prefix, err := netip.ParsePrefix(peerConfig.GetAllowedIps()[0]); err == nil {
return prefix.Addr().String()
}
return ""
return v4, v6
}
// extractHostname extracts short hostname from FQDN
@@ -208,7 +227,7 @@ func (e *Engine) GetPeerSSHKey(peerAddress string) ([]byte, bool) {
fullStatus := statusRecorder.GetFullStatus()
for _, peerState := range fullStatus.Peers {
if peerState.IP == peerAddress || peerState.FQDN == peerAddress {
if peerState.IP == peerAddress || peerState.FQDN == peerAddress || peerState.IPv6 == peerAddress {
if len(peerState.SSHHostKey) > 0 {
return peerState.SSHHostKey, true
}
@@ -262,6 +281,13 @@ func (e *Engine) startSSHServer(jwtConfig *sshserver.JWTConfig) error {
return fmt.Errorf("start SSH server: %w", err)
}
if v6 := wgAddr.IPv6; v6.IsValid() {
v6Addr := netip.AddrPortFrom(v6, sshserver.InternalSSHPort)
if err := server.AddListener(e.ctx, v6Addr); err != nil {
log.Warnf("failed to add IPv6 SSH listener: %v", err)
}
}
e.sshServer = server
if netstackNet := e.wgInterface.GetNet(); netstackNet != nil {
@@ -330,6 +356,12 @@ func (e *Engine) cleanupSSHPortRedirection() error {
}
log.Debugf("SSH port redirection removed: %s:22 -> %s:22022", localAddr, localAddr)
if v6 := e.wgInterface.Address().IPv6; v6.IsValid() {
if err := e.firewall.RemoveInboundDNAT(v6, firewallManager.ProtocolTCP, 22, 22022); err != nil {
log.Debugf("failed to remove IPv6 SSH port redirection: %v", err)
}
}
return nil
}

View File

@@ -188,7 +188,7 @@ func (c *ConnTrack) handleEvent(event nfct.Event) {
case nftypes.TCP, nftypes.UDP, nftypes.SCTP:
srcPort = flow.TupleOrig.Proto.SourcePort
dstPort = flow.TupleOrig.Proto.DestinationPort
case nftypes.ICMP:
case nftypes.ICMP, nftypes.ICMPv6:
icmpType = flow.TupleOrig.Proto.ICMPType
icmpCode = flow.TupleOrig.Proto.ICMPCode
}
@@ -231,8 +231,14 @@ func (c *ConnTrack) relevantFlow(mark uint32, srcIP, dstIP netip.Addr) bool {
}
// fallback if mark rules are not in place
wgnet := c.iface.Address().Network
return wgnet.Contains(srcIP) || wgnet.Contains(dstIP)
addr := c.iface.Address()
if addr.Network.Contains(srcIP) || addr.Network.Contains(dstIP) {
return true
}
if addr.IPv6Net.IsValid() {
return addr.IPv6Net.Contains(srcIP) || addr.IPv6Net.Contains(dstIP)
}
return false
}
// mapRxPackets maps packet counts to RX based on flow direction
@@ -291,17 +297,16 @@ func (c *ConnTrack) inferDirection(mark uint32, srcIP, dstIP netip.Addr) nftypes
}
// fallback if marks are not set
wgaddr := c.iface.Address().IP
wgnetwork := c.iface.Address().Network
addr := c.iface.Address()
switch {
case wgaddr == srcIP:
case addr.IP == srcIP || (addr.IPv6.IsValid() && addr.IPv6 == srcIP):
return nftypes.Egress
case wgaddr == dstIP:
case addr.IP == dstIP || (addr.IPv6.IsValid() && addr.IPv6 == dstIP):
return nftypes.Ingress
case wgnetwork.Contains(srcIP):
case addr.Network.Contains(srcIP) || (addr.IPv6Net.IsValid() && addr.IPv6Net.Contains(srcIP)):
// netbird network -> resource network
return nftypes.Ingress
case wgnetwork.Contains(dstIP):
case addr.Network.Contains(dstIP) || (addr.IPv6Net.IsValid() && addr.IPv6Net.Contains(dstIP)):
// resource network -> netbird network
return nftypes.Egress
}

View File

@@ -24,15 +24,17 @@ type Logger struct {
cancel context.CancelFunc
statusRecorder *peer.Status
wgIfaceNet netip.Prefix
wgIfaceNetV6 netip.Prefix
dnsCollection atomic.Bool
exitNodeCollection atomic.Bool
Store types.Store
}
func New(statusRecorder *peer.Status, wgIfaceIPNet netip.Prefix) *Logger {
func New(statusRecorder *peer.Status, wgIfaceIPNet, wgIfaceIPNetV6 netip.Prefix) *Logger {
return &Logger{
statusRecorder: statusRecorder,
wgIfaceNet: wgIfaceIPNet,
wgIfaceNetV6: wgIfaceIPNetV6,
Store: store.NewMemoryStore(),
}
}
@@ -88,11 +90,11 @@ func (l *Logger) startReceiver() {
var isSrcExitNode bool
var isDestExitNode bool
if !l.wgIfaceNet.Contains(event.SourceIP) {
if !l.isOverlayIP(event.SourceIP) {
event.SourceResourceID, isSrcExitNode = l.statusRecorder.CheckRoutes(event.SourceIP)
}
if !l.wgIfaceNet.Contains(event.DestIP) {
if !l.isOverlayIP(event.DestIP) {
event.DestResourceID, isDestExitNode = l.statusRecorder.CheckRoutes(event.DestIP)
}
@@ -136,6 +138,10 @@ func (l *Logger) UpdateConfig(dnsCollection, exitNodeCollection bool) {
l.exitNodeCollection.Store(exitNodeCollection)
}
func (l *Logger) isOverlayIP(ip netip.Addr) bool {
return l.wgIfaceNet.Contains(ip) || (l.wgIfaceNetV6.IsValid() && l.wgIfaceNetV6.Contains(ip))
}
func (l *Logger) shouldStore(event *types.EventFields, isExitNode bool) bool {
// check dns collection
if !l.dnsCollection.Load() && event.Protocol == types.UDP &&

View File

@@ -12,7 +12,7 @@ import (
)
func TestStore(t *testing.T) {
logger := logger.New(nil, netip.Prefix{})
logger := logger.New(nil, netip.Prefix{}, netip.Prefix{})
logger.Enable()
event := types.EventFields{

View File

@@ -35,11 +35,12 @@ type Manager struct {
// NewManager creates a new netflow manager
func NewManager(iface nftypes.IFaceMapper, publicKey []byte, statusRecorder *peer.Status) *Manager {
var prefix netip.Prefix
var prefix, prefixV6 netip.Prefix
if iface != nil {
prefix = iface.Address().Network
prefixV6 = iface.Address().IPv6Net
}
flowLogger := logger.New(statusRecorder, prefix)
flowLogger := logger.New(statusRecorder, prefix, prefixV6)
var ct nftypes.ConnTracker
if runtime.GOOS == "linux" && iface != nil && !iface.IsUserspaceBind() {
@@ -269,7 +270,7 @@ func toProtoEvent(publicKey []byte, event *nftypes.Event) *proto.FlowEvent {
},
}
if event.Protocol == nftypes.ICMP {
if event.Protocol == nftypes.ICMP || event.Protocol == nftypes.ICMPv6 {
protoEvent.FlowFields.ConnectionInfo = &proto.FlowFields_IcmpInfo{
IcmpInfo: &proto.ICMPInfo{
IcmpType: uint32(event.ICMPType),

View File

@@ -19,6 +19,7 @@ const (
ICMP = Protocol(1)
TCP = Protocol(6)
UDP = Protocol(17)
ICMPv6 = Protocol(58)
SCTP = Protocol(132)
)
@@ -30,6 +31,8 @@ func (p Protocol) String() string {
return "TCP"
case 17:
return "UDP"
case 58:
return "ICMPv6"
case 132:
return "SCTP"
default:

View File

@@ -75,7 +75,7 @@ func (m *Manager) addPeer(rosenpassPubKey []byte, rosenpassAddr string, wireGuar
if err != nil {
return fmt.Errorf("failed to parse rosenpass address: %w", err)
}
peerAddr := fmt.Sprintf("%s:%s", wireGuardIP, strPort)
peerAddr := net.JoinHostPort(wireGuardIP, strPort)
if pcfg.Endpoint, err = net.ResolveUDPAddr("udp", peerAddr); err != nil {
return fmt.Errorf("failed to resolve peer endpoint address: %w", err)
}

View File

@@ -3,6 +3,7 @@ package config
import (
"context"
"fmt"
"net/netip"
"os"
"path/filepath"
"runtime"
@@ -91,7 +92,8 @@ type Manager struct {
// PeerSSHInfo represents a peer's SSH configuration information
type PeerSSHInfo struct {
Hostname string
IP string
IP netip.Addr
IPv6 netip.Addr
FQDN string
}
@@ -211,8 +213,11 @@ func (m *Manager) buildPeerConfig(allHostPatterns []string) (string, error) {
func (m *Manager) buildHostPatterns(peer PeerSSHInfo) []string {
var hostPatterns []string
if peer.IP != "" {
hostPatterns = append(hostPatterns, peer.IP)
if peer.IP.IsValid() {
hostPatterns = append(hostPatterns, peer.IP.String())
}
if peer.IPv6.IsValid() {
hostPatterns = append(hostPatterns, peer.IPv6.String())
}
if peer.FQDN != "" {
hostPatterns = append(hostPatterns, peer.FQDN)

View File

@@ -2,6 +2,7 @@ package config
import (
"fmt"
"net/netip"
"os"
"path/filepath"
"runtime"
@@ -28,12 +29,12 @@ func TestManager_SetupSSHClientConfig(t *testing.T) {
peers := []PeerSSHInfo{
{
Hostname: "peer1",
IP: "100.125.1.1",
IP: netip.MustParseAddr("100.125.1.1"),
FQDN: "peer1.nb.internal",
},
{
Hostname: "peer2",
IP: "100.125.1.2",
IP: netip.MustParseAddr("100.125.1.2"),
FQDN: "peer2.nb.internal",
},
}
@@ -101,7 +102,7 @@ func TestManager_PeerLimit(t *testing.T) {
for i := 0; i < MaxPeersForSSHConfig+10; i++ {
peers = append(peers, PeerSSHInfo{
Hostname: fmt.Sprintf("peer%d", i),
IP: fmt.Sprintf("100.125.1.%d", i%254+1),
IP: netip.MustParseAddr(fmt.Sprintf("100.125.1.%d", i%254+1)),
FQDN: fmt.Sprintf("peer%d.nb.internal", i),
})
}
@@ -136,7 +137,7 @@ func TestManager_ForcedSSHConfig(t *testing.T) {
for i := 0; i < MaxPeersForSSHConfig+10; i++ {
peers = append(peers, PeerSSHInfo{
Hostname: fmt.Sprintf("peer%d", i),
IP: fmt.Sprintf("100.125.1.%d", i%254+1),
IP: netip.MustParseAddr(fmt.Sprintf("100.125.1.%d", i%254+1)),
FQDN: fmt.Sprintf("peer%d.nb.internal", i),
})
}

View File

@@ -137,10 +137,11 @@ type sessionState struct {
}
type Server struct {
sshServer *ssh.Server
listener net.Listener
mu sync.RWMutex
hostKeyPEM []byte
sshServer *ssh.Server
listener net.Listener
extraListeners []net.Listener
mu sync.RWMutex
hostKeyPEM []byte
// sessions tracks active SSH sessions (shell, command, SFTP).
// These are created when a client opens a session channel and requests shell/exec/subsystem.
@@ -254,6 +255,35 @@ func (s *Server) Start(ctx context.Context, addr netip.AddrPort) error {
return nil
}
// AddListener starts serving SSH on an additional address (e.g. IPv6).
// Must be called after Start.
func (s *Server) AddListener(ctx context.Context, addr netip.AddrPort) error {
s.mu.Lock()
srv := s.sshServer
if srv == nil {
s.mu.Unlock()
return errors.New("SSH server is not running")
}
ln, addrDesc, err := s.createListener(ctx, addr)
if err != nil {
s.mu.Unlock()
return fmt.Errorf("create listener: %w", err)
}
s.extraListeners = append(s.extraListeners, ln)
s.mu.Unlock()
log.Infof("SSH server also listening on %s", addrDesc)
go func() {
if err := srv.Serve(ln); err != nil && !errors.Is(err, ssh.ErrServerClosed) {
log.Errorf("SSH server error on %s: %v", addrDesc, err)
}
}()
return nil
}
func (s *Server) createListener(ctx context.Context, addr netip.AddrPort) (net.Listener, string, error) {
if s.netstackNet != nil {
ln, err := s.netstackNet.ListenTCPAddrPort(addr)
@@ -294,6 +324,13 @@ func (s *Server) Stop() error {
log.Debugf("close SSH server: %v", err)
}
for _, ln := range s.extraListeners {
if err := ln.Close(); err != nil {
log.Debugf("close extra SSH listener: %v", err)
}
}
s.extraListeners = nil
s.sshServer = nil
s.listener = nil
@@ -746,11 +783,10 @@ func (s *Server) findSessionKeyByContext(ctx ssh.Context) sessionKey {
func (s *Server) connectionValidator(_ ssh.Context, conn net.Conn) net.Conn {
s.mu.RLock()
netbirdNetwork := s.wgAddress.Network
localIP := s.wgAddress.IP
wgAddr := s.wgAddress
s.mu.RUnlock()
if !netbirdNetwork.IsValid() || !localIP.IsValid() {
if !wgAddr.Network.IsValid() || !wgAddr.IP.IsValid() {
return conn
}
@@ -766,14 +802,17 @@ func (s *Server) connectionValidator(_ ssh.Context, conn net.Conn) net.Conn {
log.Warnf("SSH connection rejected: invalid remote IP %s", tcpAddr.IP)
return nil
}
remoteIP = remoteIP.Unmap()
// Block connections from our own IP (prevent local apps from connecting to ourselves)
if remoteIP == localIP {
if remoteIP == wgAddr.IP || wgAddr.IPv6.IsValid() && remoteIP == wgAddr.IPv6 {
log.Warnf("SSH connection rejected from own IP %s", remoteIP)
return nil
}
if !netbirdNetwork.Contains(remoteIP) {
inV4 := wgAddr.Network.Contains(remoteIP)
inV6 := wgAddr.IPv6Net.IsValid() && wgAddr.IPv6Net.Contains(remoteIP)
if !inV4 && !inV6 {
log.Warnf("SSH connection rejected from non-NetBird IP %s", remoteIP)
return nil
}