mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-21 08:52:07 -04:00
Compare commits
8 Commits
feature/us
...
improve-us
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
be434e1eb2 | ||
|
|
10da236dae | ||
|
|
ffac18409e | ||
|
|
3098f48b25 | ||
|
|
7f023ce801 | ||
|
|
e361126515 | ||
|
|
95213f7157 | ||
|
|
2e0e3a3601 |
2
Makefile
2
Makefile
@@ -5,7 +5,7 @@ GOLANGCI_LINT := $(shell pwd)/bin/golangci-lint
|
||||
$(GOLANGCI_LINT):
|
||||
@echo "Installing golangci-lint..."
|
||||
@mkdir -p ./bin
|
||||
@GOBIN=$(shell pwd)/bin go install github.com/golangci/golangci-lint/cmd/golangci-lint@latest
|
||||
@GOBIN=$(shell pwd)/bin go install github.com/golangci/golangci-lint/v2/cmd/golangci-lint@latest
|
||||
|
||||
# Lint only changed files (fast, for pre-push)
|
||||
lint: $(GOLANGCI_LINT)
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"os"
|
||||
"slices"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"golang.org/x/exp/maps"
|
||||
|
||||
@@ -15,6 +16,7 @@ import (
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
"github.com/netbirdio/netbird/client/internal/debug"
|
||||
"github.com/netbirdio/netbird/client/internal/dns"
|
||||
"github.com/netbirdio/netbird/client/internal/listener"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
@@ -26,6 +28,7 @@ import (
|
||||
"github.com/netbirdio/netbird/formatter"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
types "github.com/netbirdio/netbird/upload-server/types"
|
||||
)
|
||||
|
||||
// ConnectionListener export internal Listener for mobile
|
||||
@@ -68,7 +71,30 @@ type Client struct {
|
||||
uiVersion string
|
||||
networkChangeListener listener.NetworkChangeListener
|
||||
|
||||
stateMu sync.RWMutex
|
||||
connectClient *internal.ConnectClient
|
||||
config *profilemanager.Config
|
||||
cacheDir string
|
||||
}
|
||||
|
||||
func (c *Client) setState(cfg *profilemanager.Config, cacheDir string, cc *internal.ConnectClient) {
|
||||
c.stateMu.Lock()
|
||||
defer c.stateMu.Unlock()
|
||||
c.config = cfg
|
||||
c.cacheDir = cacheDir
|
||||
c.connectClient = cc
|
||||
}
|
||||
|
||||
func (c *Client) stateSnapshot() (*profilemanager.Config, string, *internal.ConnectClient) {
|
||||
c.stateMu.RLock()
|
||||
defer c.stateMu.RUnlock()
|
||||
return c.config, c.cacheDir, c.connectClient
|
||||
}
|
||||
|
||||
func (c *Client) getConnectClient() *internal.ConnectClient {
|
||||
c.stateMu.RLock()
|
||||
defer c.stateMu.RUnlock()
|
||||
return c.connectClient
|
||||
}
|
||||
|
||||
// NewClient instantiate a new Client
|
||||
@@ -93,6 +119,7 @@ func (c *Client) Run(platformFiles PlatformFiles, urlOpener URLOpener, isAndroid
|
||||
|
||||
cfgFile := platformFiles.ConfigurationFilePath()
|
||||
stateFile := platformFiles.StateFilePath()
|
||||
cacheDir := platformFiles.CacheDir()
|
||||
|
||||
log.Infof("Starting client with config: %s, state: %s", cfgFile, stateFile)
|
||||
|
||||
@@ -124,8 +151,9 @@ func (c *Client) Run(platformFiles PlatformFiles, urlOpener URLOpener, isAndroid
|
||||
|
||||
// todo do not throw error in case of cancelled context
|
||||
ctx = internal.CtxInitState(ctx)
|
||||
c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder)
|
||||
return c.connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, slices.Clone(dns.items), dnsReadyListener, stateFile)
|
||||
connectClient := internal.NewConnectClient(ctx, cfg, c.recorder)
|
||||
c.setState(cfg, cacheDir, connectClient)
|
||||
return connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, slices.Clone(dns.items), dnsReadyListener, stateFile, cacheDir)
|
||||
}
|
||||
|
||||
// RunWithoutLogin we apply this type of run function when the backed has been started without UI (i.e. after reboot).
|
||||
@@ -135,6 +163,7 @@ func (c *Client) RunWithoutLogin(platformFiles PlatformFiles, dns *DNSList, dnsR
|
||||
|
||||
cfgFile := platformFiles.ConfigurationFilePath()
|
||||
stateFile := platformFiles.StateFilePath()
|
||||
cacheDir := platformFiles.CacheDir()
|
||||
|
||||
log.Infof("Starting client without login with config: %s, state: %s", cfgFile, stateFile)
|
||||
|
||||
@@ -157,8 +186,9 @@ func (c *Client) RunWithoutLogin(platformFiles PlatformFiles, dns *DNSList, dnsR
|
||||
|
||||
// todo do not throw error in case of cancelled context
|
||||
ctx = internal.CtxInitState(ctx)
|
||||
c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder)
|
||||
return c.connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, slices.Clone(dns.items), dnsReadyListener, stateFile)
|
||||
connectClient := internal.NewConnectClient(ctx, cfg, c.recorder)
|
||||
c.setState(cfg, cacheDir, connectClient)
|
||||
return connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, slices.Clone(dns.items), dnsReadyListener, stateFile, cacheDir)
|
||||
}
|
||||
|
||||
// Stop the internal client and free the resources
|
||||
@@ -173,11 +203,12 @@ func (c *Client) Stop() {
|
||||
}
|
||||
|
||||
func (c *Client) RenewTun(fd int) error {
|
||||
if c.connectClient == nil {
|
||||
cc := c.getConnectClient()
|
||||
if cc == nil {
|
||||
return fmt.Errorf("engine not running")
|
||||
}
|
||||
|
||||
e := c.connectClient.Engine()
|
||||
e := cc.Engine()
|
||||
if e == nil {
|
||||
return fmt.Errorf("engine not initialized")
|
||||
}
|
||||
@@ -185,6 +216,73 @@ func (c *Client) RenewTun(fd int) error {
|
||||
return e.RenewTun(fd)
|
||||
}
|
||||
|
||||
// DebugBundle generates a debug bundle, uploads it, and returns the upload key.
|
||||
// It works both with and without a running engine.
|
||||
func (c *Client) DebugBundle(platformFiles PlatformFiles, anonymize bool) (string, error) {
|
||||
cfg, cacheDir, cc := c.stateSnapshot()
|
||||
|
||||
// If the engine hasn't been started, load config from disk
|
||||
if cfg == nil {
|
||||
var err error
|
||||
cfg, err = profilemanager.UpdateOrCreateConfig(profilemanager.ConfigInput{
|
||||
ConfigPath: platformFiles.ConfigurationFilePath(),
|
||||
})
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("load config: %w", err)
|
||||
}
|
||||
cacheDir = platformFiles.CacheDir()
|
||||
}
|
||||
|
||||
deps := debug.GeneratorDependencies{
|
||||
InternalConfig: cfg,
|
||||
StatusRecorder: c.recorder,
|
||||
TempDir: cacheDir,
|
||||
}
|
||||
|
||||
if cc != nil {
|
||||
resp, err := cc.GetLatestSyncResponse()
|
||||
if err != nil {
|
||||
log.Warnf("get latest sync response: %v", err)
|
||||
}
|
||||
deps.SyncResponse = resp
|
||||
|
||||
if e := cc.Engine(); e != nil {
|
||||
if cm := e.GetClientMetrics(); cm != nil {
|
||||
deps.ClientMetrics = cm
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
bundleGenerator := debug.NewBundleGenerator(
|
||||
deps,
|
||||
debug.BundleConfig{
|
||||
Anonymize: anonymize,
|
||||
IncludeSystemInfo: true,
|
||||
},
|
||||
)
|
||||
|
||||
path, err := bundleGenerator.Generate()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("generate debug bundle: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
if err := os.Remove(path); err != nil {
|
||||
log.Errorf("failed to remove debug bundle file: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
uploadCtx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
|
||||
defer cancel()
|
||||
|
||||
key, err := debug.UploadDebugBundle(uploadCtx, types.DefaultBundleURL, cfg.ManagementURL.String(), path)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("upload debug bundle: %w", err)
|
||||
}
|
||||
|
||||
log.Infof("debug bundle uploaded with key %s", key)
|
||||
return key, nil
|
||||
}
|
||||
|
||||
// SetTraceLogLevel configure the logger to trace level
|
||||
func (c *Client) SetTraceLogLevel() {
|
||||
log.SetLevel(log.TraceLevel)
|
||||
@@ -214,12 +312,13 @@ func (c *Client) PeersList() *PeerInfoArray {
|
||||
}
|
||||
|
||||
func (c *Client) Networks() *NetworkArray {
|
||||
if c.connectClient == nil {
|
||||
cc := c.getConnectClient()
|
||||
if cc == nil {
|
||||
log.Error("not connected")
|
||||
return nil
|
||||
}
|
||||
|
||||
engine := c.connectClient.Engine()
|
||||
engine := cc.Engine()
|
||||
if engine == nil {
|
||||
log.Error("could not get engine")
|
||||
return nil
|
||||
@@ -300,7 +399,7 @@ func (c *Client) toggleRoute(command routeCommand) error {
|
||||
}
|
||||
|
||||
func (c *Client) getRouteManager() (routemanager.Manager, error) {
|
||||
client := c.connectClient
|
||||
client := c.getConnectClient()
|
||||
if client == nil {
|
||||
return nil, fmt.Errorf("not connected")
|
||||
}
|
||||
|
||||
@@ -7,4 +7,5 @@ package android
|
||||
type PlatformFiles interface {
|
||||
ConfigurationFilePath() string
|
||||
StateFilePath() string
|
||||
CacheDir() string
|
||||
}
|
||||
|
||||
125
client/firewall/uspfilter/conntrack/cap_test.go
Normal file
125
client/firewall/uspfilter/conntrack/cap_test.go
Normal file
@@ -0,0 +1,125 @@
|
||||
package conntrack
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"testing"
|
||||
|
||||
"github.com/google/gopacket/layers"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestTCPCapEvicts(t *testing.T) {
|
||||
t.Setenv(EnvTCPMaxEntries, "4")
|
||||
|
||||
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||
defer tracker.Close()
|
||||
require.Equal(t, 4, tracker.maxEntries)
|
||||
|
||||
src := netip.MustParseAddr("100.64.0.1")
|
||||
dst := netip.MustParseAddr("100.64.0.2")
|
||||
|
||||
for i := 0; i < 10; i++ {
|
||||
tracker.TrackOutbound(src, dst, uint16(10000+i), 80, TCPSyn, 0)
|
||||
}
|
||||
require.LessOrEqual(t, len(tracker.connections), 4,
|
||||
"TCP table must not exceed the configured cap")
|
||||
require.Greater(t, len(tracker.connections), 0,
|
||||
"some entries must remain after eviction")
|
||||
|
||||
// The most recently admitted flow must be present: eviction must make
|
||||
// room for new entries, not silently drop them.
|
||||
require.Contains(t, tracker.connections,
|
||||
ConnKey{SrcIP: src, DstIP: dst, SrcPort: uint16(10009), DstPort: 80},
|
||||
"newest TCP flow must be admitted after eviction")
|
||||
// A pre-cap flow must have been evicted to fit the last one.
|
||||
require.NotContains(t, tracker.connections,
|
||||
ConnKey{SrcIP: src, DstIP: dst, SrcPort: uint16(10000), DstPort: 80},
|
||||
"oldest TCP flow should have been evicted")
|
||||
}
|
||||
|
||||
func TestTCPCapPrefersTombstonedForEviction(t *testing.T) {
|
||||
t.Setenv(EnvTCPMaxEntries, "3")
|
||||
|
||||
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||
defer tracker.Close()
|
||||
|
||||
src := netip.MustParseAddr("100.64.0.1")
|
||||
dst := netip.MustParseAddr("100.64.0.2")
|
||||
|
||||
// Fill to cap with 3 live connections.
|
||||
for i := 0; i < 3; i++ {
|
||||
tracker.TrackOutbound(src, dst, uint16(20000+i), 80, TCPSyn, 0)
|
||||
}
|
||||
require.Len(t, tracker.connections, 3)
|
||||
|
||||
// Tombstone one by sending RST through IsValidInbound.
|
||||
tombstonedKey := ConnKey{SrcIP: src, DstIP: dst, SrcPort: 20001, DstPort: 80}
|
||||
require.True(t, tracker.IsValidInbound(dst, src, 80, 20001, TCPRst|TCPAck, 0))
|
||||
require.True(t, tracker.connections[tombstonedKey].IsTombstone())
|
||||
|
||||
// Another live connection forces eviction. The tombstone must go first.
|
||||
tracker.TrackOutbound(src, dst, uint16(29999), 80, TCPSyn, 0)
|
||||
|
||||
_, tombstonedStillPresent := tracker.connections[tombstonedKey]
|
||||
require.False(t, tombstonedStillPresent,
|
||||
"tombstoned entry should be evicted before live entries")
|
||||
require.LessOrEqual(t, len(tracker.connections), 3)
|
||||
|
||||
// Both live pre-cap entries must survive: eviction must prefer the
|
||||
// tombstone, not just satisfy the size bound by dropping any entry.
|
||||
require.Contains(t, tracker.connections,
|
||||
ConnKey{SrcIP: src, DstIP: dst, SrcPort: uint16(20000), DstPort: 80},
|
||||
"live entries must not be evicted while a tombstone exists")
|
||||
require.Contains(t, tracker.connections,
|
||||
ConnKey{SrcIP: src, DstIP: dst, SrcPort: uint16(20002), DstPort: 80},
|
||||
"live entries must not be evicted while a tombstone exists")
|
||||
}
|
||||
|
||||
func TestUDPCapEvicts(t *testing.T) {
|
||||
t.Setenv(EnvUDPMaxEntries, "5")
|
||||
|
||||
tracker := NewUDPTracker(DefaultUDPTimeout, logger, flowLogger)
|
||||
defer tracker.Close()
|
||||
require.Equal(t, 5, tracker.maxEntries)
|
||||
|
||||
src := netip.MustParseAddr("100.64.0.1")
|
||||
dst := netip.MustParseAddr("100.64.0.2")
|
||||
|
||||
for i := 0; i < 12; i++ {
|
||||
tracker.TrackOutbound(src, dst, uint16(30000+i), 53, 0)
|
||||
}
|
||||
require.LessOrEqual(t, len(tracker.connections), 5)
|
||||
require.Greater(t, len(tracker.connections), 0)
|
||||
|
||||
require.Contains(t, tracker.connections,
|
||||
ConnKey{SrcIP: src, DstIP: dst, SrcPort: uint16(30011), DstPort: 53},
|
||||
"newest UDP flow must be admitted after eviction")
|
||||
require.NotContains(t, tracker.connections,
|
||||
ConnKey{SrcIP: src, DstIP: dst, SrcPort: uint16(30000), DstPort: 53},
|
||||
"oldest UDP flow should have been evicted")
|
||||
}
|
||||
|
||||
func TestICMPCapEvicts(t *testing.T) {
|
||||
t.Setenv(EnvICMPMaxEntries, "3")
|
||||
|
||||
tracker := NewICMPTracker(DefaultICMPTimeout, logger, flowLogger)
|
||||
defer tracker.Close()
|
||||
require.Equal(t, 3, tracker.maxEntries)
|
||||
|
||||
src := netip.MustParseAddr("100.64.0.1")
|
||||
dst := netip.MustParseAddr("100.64.0.2")
|
||||
|
||||
echoReq := layers.CreateICMPv4TypeCode(uint8(layers.ICMPv4TypeEchoRequest), 0)
|
||||
for i := 0; i < 8; i++ {
|
||||
tracker.TrackOutbound(src, dst, uint16(i), echoReq, nil, 64)
|
||||
}
|
||||
require.LessOrEqual(t, len(tracker.connections), 3)
|
||||
require.Greater(t, len(tracker.connections), 0)
|
||||
|
||||
require.Contains(t, tracker.connections,
|
||||
ICMPConnKey{SrcIP: src, DstIP: dst, ID: uint16(7)},
|
||||
"newest ICMP flow must be admitted after eviction")
|
||||
require.NotContains(t, tracker.connections,
|
||||
ICMPConnKey{SrcIP: src, DstIP: dst, ID: uint16(0)},
|
||||
"oldest ICMP flow should have been evicted")
|
||||
}
|
||||
@@ -3,14 +3,61 @@ package conntrack
|
||||
import (
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"os"
|
||||
"strconv"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
|
||||
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
|
||||
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
||||
)
|
||||
|
||||
// evictSampleSize bounds how many map entries we scan per eviction call.
|
||||
// Keeps eviction O(1) even at cap under sustained load; the sampled-LRU
|
||||
// heuristic is good enough for a conntrack table that only overflows under
|
||||
// abuse.
|
||||
const evictSampleSize = 8
|
||||
|
||||
// envDuration parses an os.Getenv(name) as a time.Duration. Falls back to
|
||||
// def on empty or invalid; logs a warning on invalid.
|
||||
func envDuration(logger *nblog.Logger, name string, def time.Duration) time.Duration {
|
||||
v := os.Getenv(name)
|
||||
if v == "" {
|
||||
return def
|
||||
}
|
||||
d, err := time.ParseDuration(v)
|
||||
if err != nil {
|
||||
logger.Warn3("invalid %s=%q: %v, using default", name, v, err)
|
||||
return def
|
||||
}
|
||||
if d <= 0 {
|
||||
logger.Warn2("invalid %s=%q: must be positive, using default", name, v)
|
||||
return def
|
||||
}
|
||||
return d
|
||||
}
|
||||
|
||||
// envInt parses an os.Getenv(name) as an int. Falls back to def on empty,
|
||||
// invalid, or non-positive. Logs a warning on invalid input.
|
||||
func envInt(logger *nblog.Logger, name string, def int) int {
|
||||
v := os.Getenv(name)
|
||||
if v == "" {
|
||||
return def
|
||||
}
|
||||
n, err := strconv.Atoi(v)
|
||||
switch {
|
||||
case err != nil:
|
||||
logger.Warn3("invalid %s=%q: %v, using default", name, v, err)
|
||||
return def
|
||||
case n <= 0:
|
||||
logger.Warn2("invalid %s=%q: must be positive, using default", name, v)
|
||||
return def
|
||||
}
|
||||
return n
|
||||
}
|
||||
|
||||
// BaseConnTrack provides common fields and locking for all connection types
|
||||
type BaseConnTrack struct {
|
||||
FlowId uuid.UUID
|
||||
|
||||
11
client/firewall/uspfilter/conntrack/defaults_desktop.go
Normal file
11
client/firewall/uspfilter/conntrack/defaults_desktop.go
Normal file
@@ -0,0 +1,11 @@
|
||||
//go:build !ios && !android
|
||||
|
||||
package conntrack
|
||||
|
||||
// Default per-tracker entry caps on desktop/server platforms. These mirror
|
||||
// typical Linux netfilter nf_conntrack_max territory with ample headroom.
|
||||
const (
|
||||
DefaultMaxTCPEntries = 65536
|
||||
DefaultMaxUDPEntries = 16384
|
||||
DefaultMaxICMPEntries = 2048
|
||||
)
|
||||
13
client/firewall/uspfilter/conntrack/defaults_mobile.go
Normal file
13
client/firewall/uspfilter/conntrack/defaults_mobile.go
Normal file
@@ -0,0 +1,13 @@
|
||||
//go:build ios || android
|
||||
|
||||
package conntrack
|
||||
|
||||
// Default per-tracker entry caps on mobile platforms. iOS network extensions
|
||||
// are capped at ~50 MB; Android runs under aggressive memory pressure. These
|
||||
// values keep conntrack footprint well under 5 MB worst case (TCPConnTrack
|
||||
// is ~200 B plus map overhead).
|
||||
const (
|
||||
DefaultMaxTCPEntries = 4096
|
||||
DefaultMaxUDPEntries = 2048
|
||||
DefaultMaxICMPEntries = 512
|
||||
)
|
||||
@@ -44,6 +44,9 @@ type ICMPConnTrack struct {
|
||||
ICMPCode uint8
|
||||
}
|
||||
|
||||
// EnvICMPMaxEntries caps the ICMP conntrack table size.
|
||||
const EnvICMPMaxEntries = "NB_CONNTRACK_ICMP_MAX"
|
||||
|
||||
// ICMPTracker manages ICMP connection states
|
||||
type ICMPTracker struct {
|
||||
logger *nblog.Logger
|
||||
@@ -52,6 +55,7 @@ type ICMPTracker struct {
|
||||
cleanupTicker *time.Ticker
|
||||
tickerCancel context.CancelFunc
|
||||
mutex sync.RWMutex
|
||||
maxEntries int
|
||||
flowLogger nftypes.FlowLogger
|
||||
}
|
||||
|
||||
@@ -135,6 +139,7 @@ func NewICMPTracker(timeout time.Duration, logger *nblog.Logger, flowLogger nfty
|
||||
timeout: timeout,
|
||||
cleanupTicker: time.NewTicker(ICMPCleanupInterval),
|
||||
tickerCancel: cancel,
|
||||
maxEntries: envInt(logger, EnvICMPMaxEntries, DefaultMaxICMPEntries),
|
||||
flowLogger: flowLogger,
|
||||
}
|
||||
|
||||
@@ -221,7 +226,9 @@ func (t *ICMPTracker) track(
|
||||
|
||||
// non echo requests don't need tracking
|
||||
if typ != uint8(layers.ICMPv4TypeEchoRequest) {
|
||||
t.logger.Trace3("New %s ICMP connection %s - %s", direction, key, icmpInfo)
|
||||
if t.logger.Enabled(nblog.LevelTrace) {
|
||||
t.logger.Trace3("New %s ICMP connection %s - %s", direction, key, icmpInfo)
|
||||
}
|
||||
t.sendStartEvent(direction, srcIP, dstIP, typ, code, ruleId, size)
|
||||
return
|
||||
}
|
||||
@@ -240,10 +247,15 @@ func (t *ICMPTracker) track(
|
||||
conn.UpdateCounters(direction, size)
|
||||
|
||||
t.mutex.Lock()
|
||||
if t.maxEntries > 0 && len(t.connections) >= t.maxEntries {
|
||||
t.evictOneLocked()
|
||||
}
|
||||
t.connections[key] = conn
|
||||
t.mutex.Unlock()
|
||||
|
||||
t.logger.Trace3("New %s ICMP connection %s - %s", direction, key, icmpInfo)
|
||||
if t.logger.Enabled(nblog.LevelTrace) {
|
||||
t.logger.Trace3("New %s ICMP connection %s - %s", direction, key, icmpInfo)
|
||||
}
|
||||
t.sendEvent(nftypes.TypeStart, conn, ruleId)
|
||||
}
|
||||
|
||||
@@ -286,6 +298,34 @@ func (t *ICMPTracker) cleanupRoutine(ctx context.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
// evictOneLocked removes one entry to make room. Caller must hold t.mutex.
|
||||
// Bounded sample scan: picks the oldest among up to evictSampleSize entries.
|
||||
func (t *ICMPTracker) evictOneLocked() {
|
||||
var candKey ICMPConnKey
|
||||
var candSeen int64
|
||||
haveCand := false
|
||||
sampled := 0
|
||||
|
||||
for k, c := range t.connections {
|
||||
seen := c.lastSeen.Load()
|
||||
if !haveCand || seen < candSeen {
|
||||
candKey = k
|
||||
candSeen = seen
|
||||
haveCand = true
|
||||
}
|
||||
sampled++
|
||||
if sampled >= evictSampleSize {
|
||||
break
|
||||
}
|
||||
}
|
||||
if haveCand {
|
||||
if evicted := t.connections[candKey]; evicted != nil {
|
||||
t.sendEvent(nftypes.TypeEnd, evicted, nil)
|
||||
}
|
||||
delete(t.connections, candKey)
|
||||
}
|
||||
}
|
||||
|
||||
func (t *ICMPTracker) cleanup() {
|
||||
t.mutex.Lock()
|
||||
defer t.mutex.Unlock()
|
||||
@@ -294,8 +334,10 @@ func (t *ICMPTracker) cleanup() {
|
||||
if conn.timeoutExceeded(t.timeout) {
|
||||
delete(t.connections, key)
|
||||
|
||||
t.logger.Trace5("Removed ICMP connection %s (timeout) [in: %d Pkts/%d B out: %d Pkts/%d B]",
|
||||
key, conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load())
|
||||
if t.logger.Enabled(nblog.LevelTrace) {
|
||||
t.logger.Trace5("Removed ICMP connection %s (timeout) [in: %d Pkts/%d B out: %d Pkts/%d B]",
|
||||
key, conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load())
|
||||
}
|
||||
t.sendEvent(nftypes.TypeEnd, conn, nil)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -38,6 +38,27 @@ const (
|
||||
TCPHandshakeTimeout = 60 * time.Second
|
||||
// TCPCleanupInterval is how often we check for stale connections
|
||||
TCPCleanupInterval = 5 * time.Minute
|
||||
// FinWaitTimeout bounds FIN_WAIT_1 / FIN_WAIT_2 / CLOSING states.
|
||||
// Matches Linux netfilter nf_conntrack_tcp_timeout_fin_wait.
|
||||
FinWaitTimeout = 60 * time.Second
|
||||
// CloseWaitTimeout bounds CLOSE_WAIT. Matches Linux default; apps
|
||||
// holding CloseWait longer than this should bump the env var.
|
||||
CloseWaitTimeout = 60 * time.Second
|
||||
// LastAckTimeout bounds LAST_ACK. Matches Linux default.
|
||||
LastAckTimeout = 30 * time.Second
|
||||
)
|
||||
|
||||
// Env vars to override per-state teardown timeouts. Values parsed by
|
||||
// time.ParseDuration (e.g. "120s", "2m"). Invalid values fall back to the
|
||||
// defaults above with a warning.
|
||||
const (
|
||||
EnvTCPFinWaitTimeout = "NB_CONNTRACK_TCP_FIN_WAIT_TIMEOUT"
|
||||
EnvTCPCloseWaitTimeout = "NB_CONNTRACK_TCP_CLOSE_WAIT_TIMEOUT"
|
||||
EnvTCPLastAckTimeout = "NB_CONNTRACK_TCP_LAST_ACK_TIMEOUT"
|
||||
|
||||
// EnvTCPMaxEntries caps the TCP conntrack table size. Oldest entries
|
||||
// (tombstones first) are evicted when the cap is reached.
|
||||
EnvTCPMaxEntries = "NB_CONNTRACK_TCP_MAX"
|
||||
)
|
||||
|
||||
// TCPState represents the state of a TCP connection
|
||||
@@ -133,14 +154,18 @@ func (t *TCPConnTrack) SetTombstone() {
|
||||
|
||||
// TCPTracker manages TCP connection states
|
||||
type TCPTracker struct {
|
||||
logger *nblog.Logger
|
||||
connections map[ConnKey]*TCPConnTrack
|
||||
mutex sync.RWMutex
|
||||
cleanupTicker *time.Ticker
|
||||
tickerCancel context.CancelFunc
|
||||
timeout time.Duration
|
||||
waitTimeout time.Duration
|
||||
flowLogger nftypes.FlowLogger
|
||||
logger *nblog.Logger
|
||||
connections map[ConnKey]*TCPConnTrack
|
||||
mutex sync.RWMutex
|
||||
cleanupTicker *time.Ticker
|
||||
tickerCancel context.CancelFunc
|
||||
timeout time.Duration
|
||||
waitTimeout time.Duration
|
||||
finWaitTimeout time.Duration
|
||||
closeWaitTimeout time.Duration
|
||||
lastAckTimeout time.Duration
|
||||
maxEntries int
|
||||
flowLogger nftypes.FlowLogger
|
||||
}
|
||||
|
||||
// NewTCPTracker creates a new TCP connection tracker
|
||||
@@ -155,13 +180,17 @@ func NewTCPTracker(timeout time.Duration, logger *nblog.Logger, flowLogger nftyp
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
tracker := &TCPTracker{
|
||||
logger: logger,
|
||||
connections: make(map[ConnKey]*TCPConnTrack),
|
||||
cleanupTicker: time.NewTicker(TCPCleanupInterval),
|
||||
tickerCancel: cancel,
|
||||
timeout: timeout,
|
||||
waitTimeout: waitTimeout,
|
||||
flowLogger: flowLogger,
|
||||
logger: logger,
|
||||
connections: make(map[ConnKey]*TCPConnTrack),
|
||||
cleanupTicker: time.NewTicker(TCPCleanupInterval),
|
||||
tickerCancel: cancel,
|
||||
timeout: timeout,
|
||||
waitTimeout: waitTimeout,
|
||||
finWaitTimeout: envDuration(logger, EnvTCPFinWaitTimeout, FinWaitTimeout),
|
||||
closeWaitTimeout: envDuration(logger, EnvTCPCloseWaitTimeout, CloseWaitTimeout),
|
||||
lastAckTimeout: envDuration(logger, EnvTCPLastAckTimeout, LastAckTimeout),
|
||||
maxEntries: envInt(logger, EnvTCPMaxEntries, DefaultMaxTCPEntries),
|
||||
flowLogger: flowLogger,
|
||||
}
|
||||
|
||||
go tracker.cleanupRoutine(ctx)
|
||||
@@ -209,6 +238,12 @@ func (t *TCPTracker) track(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, fla
|
||||
if exists || flags&TCPSyn == 0 {
|
||||
return
|
||||
}
|
||||
// Reject illegal SYN combinations (SYN+FIN, SYN+RST, …) so they don't
|
||||
// create spurious conntrack entries. Not mandated by RFC 9293 but a
|
||||
// common hardening (Linux netfilter/nftables rejects these too).
|
||||
if !isValidFlagCombination(flags) {
|
||||
return
|
||||
}
|
||||
|
||||
conn := &TCPConnTrack{
|
||||
BaseConnTrack: BaseConnTrack{
|
||||
@@ -225,20 +260,65 @@ func (t *TCPTracker) track(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, fla
|
||||
conn.state.Store(int32(TCPStateNew))
|
||||
conn.DNATOrigPort.Store(uint32(origPort))
|
||||
|
||||
if origPort != 0 {
|
||||
t.logger.Trace4("New %s TCP connection: %s (port DNAT %d -> %d)", direction, key, origPort, dstPort)
|
||||
} else {
|
||||
t.logger.Trace2("New %s TCP connection: %s", direction, key)
|
||||
if t.logger.Enabled(nblog.LevelTrace) {
|
||||
if origPort != 0 {
|
||||
t.logger.Trace4("New %s TCP connection: %s (port DNAT %d -> %d)", direction, key, origPort, dstPort)
|
||||
} else {
|
||||
t.logger.Trace2("New %s TCP connection: %s", direction, key)
|
||||
}
|
||||
}
|
||||
t.updateState(key, conn, flags, direction, size)
|
||||
|
||||
t.mutex.Lock()
|
||||
if t.maxEntries > 0 && len(t.connections) >= t.maxEntries {
|
||||
t.evictOneLocked()
|
||||
}
|
||||
t.connections[key] = conn
|
||||
t.mutex.Unlock()
|
||||
|
||||
t.sendEvent(nftypes.TypeStart, conn, ruleID)
|
||||
}
|
||||
|
||||
// evictOneLocked removes one entry to make room. Caller must hold t.mutex.
|
||||
// Bounded scan: samples up to evictSampleSize pseudo-random entries (Go map
|
||||
// iteration order is randomized), preferring a tombstone. If no tombstone
|
||||
// found in the sample, evicts the oldest among the sampled entries. O(1)
|
||||
// worst case — cheap enough to run on every insert at cap during abuse.
|
||||
func (t *TCPTracker) evictOneLocked() {
|
||||
var candKey ConnKey
|
||||
var candSeen int64
|
||||
haveCand := false
|
||||
sampled := 0
|
||||
|
||||
for k, c := range t.connections {
|
||||
if c.IsTombstone() {
|
||||
delete(t.connections, k)
|
||||
return
|
||||
}
|
||||
seen := c.lastSeen.Load()
|
||||
if !haveCand || seen < candSeen {
|
||||
candKey = k
|
||||
candSeen = seen
|
||||
haveCand = true
|
||||
}
|
||||
sampled++
|
||||
if sampled >= evictSampleSize {
|
||||
break
|
||||
}
|
||||
}
|
||||
if haveCand {
|
||||
if evicted := t.connections[candKey]; evicted != nil {
|
||||
// TypeEnd is already emitted at the state transition to
|
||||
// TimeWait and when a connection is tombstoned. Only emit
|
||||
// here when we're reaping a still-active flow.
|
||||
if evicted.GetState() != TCPStateTimeWait && !evicted.IsTombstone() {
|
||||
t.sendEvent(nftypes.TypeEnd, evicted, nil)
|
||||
}
|
||||
}
|
||||
delete(t.connections, candKey)
|
||||
}
|
||||
}
|
||||
|
||||
// IsValidInbound checks if an inbound TCP packet matches a tracked connection
|
||||
func (t *TCPTracker) IsValidInbound(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, size int) bool {
|
||||
key := ConnKey{
|
||||
@@ -256,12 +336,19 @@ func (t *TCPTracker) IsValidInbound(srcIP, dstIP netip.Addr, srcPort, dstPort ui
|
||||
return false
|
||||
}
|
||||
|
||||
// Reject illegal flag combinations regardless of state. These never belong
|
||||
// to a legitimate flow and must not advance or tear down state.
|
||||
if !isValidFlagCombination(flags) {
|
||||
if t.logger.Enabled(nblog.LevelWarn) {
|
||||
t.logger.Warn3("TCP illegal flag combination %x for connection %s (state %s)", flags, key, conn.GetState())
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
currentState := conn.GetState()
|
||||
if !t.isValidStateForFlags(currentState, flags) {
|
||||
t.logger.Warn3("TCP state %s is not valid with flags %x for connection %s", currentState, flags, key)
|
||||
// allow all flags for established for now
|
||||
if currentState == TCPStateEstablished {
|
||||
return true
|
||||
if t.logger.Enabled(nblog.LevelWarn) {
|
||||
t.logger.Warn3("TCP state %s is not valid with flags %x for connection %s", currentState, flags, key)
|
||||
}
|
||||
return false
|
||||
}
|
||||
@@ -270,116 +357,208 @@ func (t *TCPTracker) IsValidInbound(srcIP, dstIP netip.Addr, srcPort, dstPort ui
|
||||
return true
|
||||
}
|
||||
|
||||
// updateState updates the TCP connection state based on flags
|
||||
// updateState updates the TCP connection state based on flags.
|
||||
func (t *TCPTracker) updateState(key ConnKey, conn *TCPConnTrack, flags uint8, packetDir nftypes.Direction, size int) {
|
||||
conn.UpdateLastSeen()
|
||||
conn.UpdateCounters(packetDir, size)
|
||||
|
||||
// Malformed flag combinations must not refresh lastSeen or drive state,
|
||||
// otherwise spoofed packets keep a dead flow alive past its timeout.
|
||||
if !isValidFlagCombination(flags) {
|
||||
return
|
||||
}
|
||||
|
||||
conn.UpdateLastSeen()
|
||||
|
||||
currentState := conn.GetState()
|
||||
|
||||
if flags&TCPRst != 0 {
|
||||
if conn.CompareAndSwapState(currentState, TCPStateClosed) {
|
||||
conn.SetTombstone()
|
||||
t.logger.Trace6("TCP connection reset: %s (dir: %s) [in: %d Pkts/%d B, out: %d Pkts/%d B]",
|
||||
key, packetDir, conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load())
|
||||
t.sendEvent(nftypes.TypeEnd, conn, nil)
|
||||
}
|
||||
// Hardening beyond RFC 9293 §3.10.7.4: without sequence tracking we
|
||||
// cannot apply the RFC 5961 in-window RST check, so we conservatively
|
||||
// reject RSTs that the spec would accept (TIME-WAIT with in-window
|
||||
// SEQ, SynSent from same direction as own SYN, etc.).
|
||||
t.handleRst(key, conn, currentState, packetDir)
|
||||
return
|
||||
}
|
||||
|
||||
var newState TCPState
|
||||
switch currentState {
|
||||
case TCPStateNew:
|
||||
if flags&TCPSyn != 0 && flags&TCPAck == 0 {
|
||||
if conn.Direction == nftypes.Egress {
|
||||
newState = TCPStateSynSent
|
||||
} else {
|
||||
newState = TCPStateSynReceived
|
||||
}
|
||||
}
|
||||
newState := nextState(currentState, conn.Direction, packetDir, flags)
|
||||
if newState == 0 || !conn.CompareAndSwapState(currentState, newState) {
|
||||
return
|
||||
}
|
||||
t.onTransition(key, conn, currentState, newState, packetDir)
|
||||
}
|
||||
|
||||
case TCPStateSynSent:
|
||||
if flags&TCPSyn != 0 && flags&TCPAck != 0 {
|
||||
if packetDir != conn.Direction {
|
||||
newState = TCPStateEstablished
|
||||
} else {
|
||||
// Simultaneous open
|
||||
newState = TCPStateSynReceived
|
||||
}
|
||||
}
|
||||
// handleRst processes a RST segment. Late RSTs in TimeWait and spoofed RSTs
|
||||
// from the SYN direction are ignored; otherwise the flow is tombstoned.
|
||||
func (t *TCPTracker) handleRst(key ConnKey, conn *TCPConnTrack, currentState TCPState, packetDir nftypes.Direction) {
|
||||
// TimeWait exists to absorb late segments; don't let a late RST
|
||||
// tombstone the entry and break same-4-tuple reuse.
|
||||
if currentState == TCPStateTimeWait {
|
||||
return
|
||||
}
|
||||
// A RST from the same direction as the SYN cannot be a legitimate
|
||||
// response and must not tear down a half-open connection.
|
||||
if currentState == TCPStateSynSent && packetDir == conn.Direction {
|
||||
return
|
||||
}
|
||||
if !conn.CompareAndSwapState(currentState, TCPStateClosed) {
|
||||
return
|
||||
}
|
||||
conn.SetTombstone()
|
||||
if t.logger.Enabled(nblog.LevelTrace) {
|
||||
t.logger.Trace6("TCP connection reset: %s (dir: %s) [in: %d Pkts/%d B, out: %d Pkts/%d B]",
|
||||
key, packetDir, conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load())
|
||||
}
|
||||
t.sendEvent(nftypes.TypeEnd, conn, nil)
|
||||
}
|
||||
|
||||
case TCPStateSynReceived:
|
||||
if flags&TCPAck != 0 && flags&TCPSyn == 0 {
|
||||
if packetDir == conn.Direction {
|
||||
newState = TCPStateEstablished
|
||||
}
|
||||
}
|
||||
// stateTransition describes one state's transition logic. It receives the
|
||||
// packet's flags plus whether the packet direction matches the connection's
|
||||
// origin direction (same=true means same side as the SYN initiator). Return 0
|
||||
// for no transition.
|
||||
type stateTransition func(flags uint8, connDir nftypes.Direction, same bool) TCPState
|
||||
|
||||
case TCPStateEstablished:
|
||||
if flags&TCPFin != 0 {
|
||||
if packetDir == conn.Direction {
|
||||
newState = TCPStateFinWait1
|
||||
} else {
|
||||
newState = TCPStateCloseWait
|
||||
}
|
||||
}
|
||||
// stateTable maps each state to its transition function. Centralized here so
|
||||
// nextState stays trivial and each rule is easy to read in isolation.
|
||||
var stateTable = map[TCPState]stateTransition{
|
||||
TCPStateNew: transNew,
|
||||
TCPStateSynSent: transSynSent,
|
||||
TCPStateSynReceived: transSynReceived,
|
||||
TCPStateEstablished: transEstablished,
|
||||
TCPStateFinWait1: transFinWait1,
|
||||
TCPStateFinWait2: transFinWait2,
|
||||
TCPStateClosing: transClosing,
|
||||
TCPStateCloseWait: transCloseWait,
|
||||
TCPStateLastAck: transLastAck,
|
||||
}
|
||||
|
||||
case TCPStateFinWait1:
|
||||
if packetDir != conn.Direction {
|
||||
switch {
|
||||
case flags&TCPFin != 0 && flags&TCPAck != 0:
|
||||
newState = TCPStateClosing
|
||||
case flags&TCPFin != 0:
|
||||
newState = TCPStateClosing
|
||||
case flags&TCPAck != 0:
|
||||
newState = TCPStateFinWait2
|
||||
}
|
||||
}
|
||||
// nextState returns the target TCP state for the given current state and
|
||||
// packet, or 0 if the packet does not trigger a transition.
|
||||
func nextState(currentState TCPState, connDir, packetDir nftypes.Direction, flags uint8) TCPState {
|
||||
fn, ok := stateTable[currentState]
|
||||
if !ok {
|
||||
return 0
|
||||
}
|
||||
return fn(flags, connDir, packetDir == connDir)
|
||||
}
|
||||
|
||||
case TCPStateFinWait2:
|
||||
if flags&TCPFin != 0 {
|
||||
newState = TCPStateTimeWait
|
||||
func transNew(flags uint8, connDir nftypes.Direction, _ bool) TCPState {
|
||||
if flags&TCPSyn != 0 && flags&TCPAck == 0 {
|
||||
if connDir == nftypes.Egress {
|
||||
return TCPStateSynSent
|
||||
}
|
||||
return TCPStateSynReceived
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
case TCPStateClosing:
|
||||
if flags&TCPAck != 0 {
|
||||
newState = TCPStateTimeWait
|
||||
func transSynSent(flags uint8, _ nftypes.Direction, same bool) TCPState {
|
||||
if flags&TCPSyn != 0 && flags&TCPAck != 0 {
|
||||
if same {
|
||||
return TCPStateSynReceived // simultaneous open
|
||||
}
|
||||
return TCPStateEstablished
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
case TCPStateCloseWait:
|
||||
if flags&TCPFin != 0 {
|
||||
newState = TCPStateLastAck
|
||||
}
|
||||
func transSynReceived(flags uint8, _ nftypes.Direction, same bool) TCPState {
|
||||
if flags&TCPAck != 0 && flags&TCPSyn == 0 && same {
|
||||
return TCPStateEstablished
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
case TCPStateLastAck:
|
||||
if flags&TCPAck != 0 {
|
||||
newState = TCPStateClosed
|
||||
}
|
||||
func transEstablished(flags uint8, _ nftypes.Direction, same bool) TCPState {
|
||||
if flags&TCPFin == 0 {
|
||||
return 0
|
||||
}
|
||||
if same {
|
||||
return TCPStateFinWait1
|
||||
}
|
||||
return TCPStateCloseWait
|
||||
}
|
||||
|
||||
// transFinWait1 handles the active-close peer response. A FIN carrying our
|
||||
// ACK piggybacked goes straight to TIME-WAIT (RFC 9293 §3.10.7.4, FIN-WAIT-1:
|
||||
// "if our FIN has been ACKed... enter the TIME-WAIT state"); a lone FIN moves
|
||||
// to CLOSING; a pure ACK of our FIN moves to FIN-WAIT-2.
|
||||
func transFinWait1(flags uint8, _ nftypes.Direction, same bool) TCPState {
|
||||
if same {
|
||||
return 0
|
||||
}
|
||||
if flags&TCPFin != 0 && flags&TCPAck != 0 {
|
||||
return TCPStateTimeWait
|
||||
}
|
||||
switch {
|
||||
case flags&TCPFin != 0:
|
||||
return TCPStateClosing
|
||||
case flags&TCPAck != 0:
|
||||
return TCPStateFinWait2
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
// transFinWait2 ignores own-side FIN retransmits; only the peer's FIN advances.
|
||||
func transFinWait2(flags uint8, _ nftypes.Direction, same bool) TCPState {
|
||||
if flags&TCPFin != 0 && !same {
|
||||
return TCPStateTimeWait
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
// transClosing completes a simultaneous close on the peer's ACK.
|
||||
func transClosing(flags uint8, _ nftypes.Direction, same bool) TCPState {
|
||||
if flags&TCPAck != 0 && !same {
|
||||
return TCPStateTimeWait
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
// transCloseWait only advances to LastAck when WE send FIN, ignoring peer retransmits.
|
||||
func transCloseWait(flags uint8, _ nftypes.Direction, same bool) TCPState {
|
||||
if flags&TCPFin != 0 && same {
|
||||
return TCPStateLastAck
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
// transLastAck closes the flow only on the peer's ACK (not our own ACK retransmits).
|
||||
func transLastAck(flags uint8, _ nftypes.Direction, same bool) TCPState {
|
||||
if flags&TCPAck != 0 && !same {
|
||||
return TCPStateClosed
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
// onTransition handles logging and flow-event emission after a successful
|
||||
// state transition. TimeWait and Closed are terminal for flow accounting.
|
||||
func (t *TCPTracker) onTransition(key ConnKey, conn *TCPConnTrack, from, to TCPState, packetDir nftypes.Direction) {
|
||||
traceOn := t.logger.Enabled(nblog.LevelTrace)
|
||||
if traceOn {
|
||||
t.logger.Trace4("TCP connection %s transitioned from %s to %s (dir: %s)", key, from, to, packetDir)
|
||||
}
|
||||
|
||||
if newState != 0 && conn.CompareAndSwapState(currentState, newState) {
|
||||
t.logger.Trace4("TCP connection %s transitioned from %s to %s (dir: %s)", key, currentState, newState, packetDir)
|
||||
|
||||
switch newState {
|
||||
case TCPStateTimeWait:
|
||||
switch to {
|
||||
case TCPStateTimeWait:
|
||||
if traceOn {
|
||||
t.logger.Trace5("TCP connection %s completed [in: %d Pkts/%d B, out: %d Pkts/%d B]",
|
||||
key, conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load())
|
||||
t.sendEvent(nftypes.TypeEnd, conn, nil)
|
||||
|
||||
case TCPStateClosed:
|
||||
conn.SetTombstone()
|
||||
}
|
||||
t.sendEvent(nftypes.TypeEnd, conn, nil)
|
||||
case TCPStateClosed:
|
||||
conn.SetTombstone()
|
||||
if traceOn {
|
||||
t.logger.Trace5("TCP connection %s closed gracefully [in: %d Pkts/%d, B out: %d Pkts/%d B]",
|
||||
key, conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load())
|
||||
t.sendEvent(nftypes.TypeEnd, conn, nil)
|
||||
}
|
||||
t.sendEvent(nftypes.TypeEnd, conn, nil)
|
||||
}
|
||||
}
|
||||
|
||||
// isValidStateForFlags checks if the TCP flags are valid for the current connection state
|
||||
// isValidStateForFlags checks if the TCP flags are valid for the current
|
||||
// connection state. Caller must have already verified the flag combination is
|
||||
// legal via isValidFlagCombination.
|
||||
func (t *TCPTracker) isValidStateForFlags(state TCPState, flags uint8) bool {
|
||||
if !isValidFlagCombination(flags) {
|
||||
return false
|
||||
}
|
||||
if flags&TCPRst != 0 {
|
||||
if state == TCPStateSynSent {
|
||||
return flags&TCPAck != 0
|
||||
@@ -449,15 +628,24 @@ func (t *TCPTracker) cleanup() {
|
||||
timeout = t.waitTimeout
|
||||
case TCPStateEstablished:
|
||||
timeout = t.timeout
|
||||
case TCPStateFinWait1, TCPStateFinWait2, TCPStateClosing:
|
||||
timeout = t.finWaitTimeout
|
||||
case TCPStateCloseWait:
|
||||
timeout = t.closeWaitTimeout
|
||||
case TCPStateLastAck:
|
||||
timeout = t.lastAckTimeout
|
||||
default:
|
||||
// SynSent / SynReceived / New
|
||||
timeout = TCPHandshakeTimeout
|
||||
}
|
||||
|
||||
if conn.timeoutExceeded(timeout) {
|
||||
delete(t.connections, key)
|
||||
|
||||
t.logger.Trace6("Cleaned up timed-out TCP connection %s (%s) [in: %d Pkts/%d, B out: %d Pkts/%d B]",
|
||||
key, conn.GetState(), conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load())
|
||||
if t.logger.Enabled(nblog.LevelTrace) {
|
||||
t.logger.Trace6("Cleaned up timed-out TCP connection %s (%s) [in: %d Pkts/%d, B out: %d Pkts/%d B]",
|
||||
key, conn.GetState(), conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load())
|
||||
}
|
||||
|
||||
// event already handled by state change
|
||||
if currentState != TCPStateTimeWait {
|
||||
|
||||
100
client/firewall/uspfilter/conntrack/tcp_rst_bugs_test.go
Normal file
100
client/firewall/uspfilter/conntrack/tcp_rst_bugs_test.go
Normal file
@@ -0,0 +1,100 @@
|
||||
package conntrack
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// RST hygiene tests: the tracker currently closes the flow on any RST that
|
||||
// matches the 4-tuple, regardless of direction or state. These tests cover
|
||||
// the minimum checks we want (no SEQ tracking).
|
||||
|
||||
func TestTCPRstInSynSentWrongDirection(t *testing.T) {
|
||||
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||
defer tracker.Close()
|
||||
|
||||
srcIP := netip.MustParseAddr("100.64.0.1")
|
||||
dstIP := netip.MustParseAddr("100.64.0.2")
|
||||
srcPort := uint16(12345)
|
||||
dstPort := uint16(80)
|
||||
key := ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort}
|
||||
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn, 0)
|
||||
conn := tracker.connections[key]
|
||||
require.Equal(t, TCPStateSynSent, conn.GetState())
|
||||
|
||||
// A RST arriving in the same direction as the SYN (i.e. TrackOutbound)
|
||||
// cannot be a legitimate response. It must not close the connection.
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPRst|TCPAck, 0)
|
||||
require.Equal(t, TCPStateSynSent, conn.GetState(),
|
||||
"RST in same direction as SYN must not close connection")
|
||||
require.False(t, conn.IsTombstone())
|
||||
}
|
||||
|
||||
func TestTCPRstInTimeWaitIgnored(t *testing.T) {
|
||||
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||
defer tracker.Close()
|
||||
|
||||
srcIP := netip.MustParseAddr("100.64.0.1")
|
||||
dstIP := netip.MustParseAddr("100.64.0.2")
|
||||
srcPort := uint16(12345)
|
||||
dstPort := uint16(80)
|
||||
key := ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort}
|
||||
|
||||
// Drive to TIME-WAIT via active close.
|
||||
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0)
|
||||
require.True(t, tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0))
|
||||
require.True(t, tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0))
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
|
||||
|
||||
conn := tracker.connections[key]
|
||||
require.Equal(t, TCPStateTimeWait, conn.GetState())
|
||||
require.False(t, conn.IsTombstone(), "TIME-WAIT must not be tombstoned")
|
||||
|
||||
// Late RST during TIME-WAIT must not tombstone the entry (TIME-WAIT
|
||||
// exists to absorb late segments).
|
||||
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst, 0)
|
||||
require.Equal(t, TCPStateTimeWait, conn.GetState(),
|
||||
"RST in TIME-WAIT must not transition state")
|
||||
require.False(t, conn.IsTombstone(),
|
||||
"RST in TIME-WAIT must not tombstone the entry")
|
||||
}
|
||||
|
||||
func TestTCPIllegalFlagCombos(t *testing.T) {
|
||||
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||
defer tracker.Close()
|
||||
|
||||
srcIP := netip.MustParseAddr("100.64.0.1")
|
||||
dstIP := netip.MustParseAddr("100.64.0.2")
|
||||
srcPort := uint16(12345)
|
||||
dstPort := uint16(80)
|
||||
key := ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort}
|
||||
|
||||
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
|
||||
conn := tracker.connections[key]
|
||||
|
||||
// Illegal combos must be rejected and must not change state.
|
||||
combos := []struct {
|
||||
name string
|
||||
flags uint8
|
||||
}{
|
||||
{"SYN+RST", TCPSyn | TCPRst},
|
||||
{"FIN+RST", TCPFin | TCPRst},
|
||||
{"SYN+FIN", TCPSyn | TCPFin},
|
||||
{"SYN+FIN+RST", TCPSyn | TCPFin | TCPRst},
|
||||
}
|
||||
|
||||
for _, c := range combos {
|
||||
t.Run(c.name, func(t *testing.T) {
|
||||
before := conn.GetState()
|
||||
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, c.flags, 0)
|
||||
require.False(t, valid, "illegal flag combo must be rejected: %s", c.name)
|
||||
require.Equal(t, before, conn.GetState(),
|
||||
"illegal flag combo must not change state")
|
||||
require.False(t, conn.IsTombstone())
|
||||
})
|
||||
}
|
||||
}
|
||||
235
client/firewall/uspfilter/conntrack/tcp_state_bugs_test.go
Normal file
235
client/firewall/uspfilter/conntrack/tcp_state_bugs_test.go
Normal file
@@ -0,0 +1,235 @@
|
||||
package conntrack
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// These tests exercise cases where the TCP state machine currently advances
|
||||
// on retransmitted or wrong-direction segments and tears the flow down
|
||||
// prematurely. They are expected to fail until the direction checks are added.
|
||||
|
||||
func TestTCPCloseWaitRetransmittedPeerFIN(t *testing.T) {
|
||||
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||
defer tracker.Close()
|
||||
|
||||
srcIP := netip.MustParseAddr("100.64.0.1")
|
||||
dstIP := netip.MustParseAddr("100.64.0.2")
|
||||
srcPort := uint16(12345)
|
||||
dstPort := uint16(80)
|
||||
key := ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort}
|
||||
|
||||
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
|
||||
|
||||
// Peer sends FIN -> CloseWait (our app has not yet closed).
|
||||
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0)
|
||||
require.True(t, valid)
|
||||
conn := tracker.connections[key]
|
||||
require.Equal(t, TCPStateCloseWait, conn.GetState())
|
||||
|
||||
// Peer retransmits their FIN (ACK may have been delayed). We have NOT
|
||||
// sent our FIN yet, so state must remain CloseWait.
|
||||
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0)
|
||||
require.True(t, valid, "retransmitted peer FIN must still be accepted")
|
||||
require.Equal(t, TCPStateCloseWait, conn.GetState(),
|
||||
"retransmitted peer FIN must not advance CloseWait to LastAck")
|
||||
|
||||
// Our app finally closes -> LastAck.
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0)
|
||||
require.Equal(t, TCPStateLastAck, conn.GetState())
|
||||
|
||||
// Peer ACK closes.
|
||||
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)
|
||||
require.True(t, valid)
|
||||
require.Equal(t, TCPStateClosed, conn.GetState())
|
||||
}
|
||||
|
||||
func TestTCPFinWait2RetransmittedOwnFIN(t *testing.T) {
|
||||
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||
defer tracker.Close()
|
||||
|
||||
srcIP := netip.MustParseAddr("100.64.0.1")
|
||||
dstIP := netip.MustParseAddr("100.64.0.2")
|
||||
srcPort := uint16(12345)
|
||||
dstPort := uint16(80)
|
||||
key := ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort}
|
||||
|
||||
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
|
||||
|
||||
// We initiate close.
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0)
|
||||
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)
|
||||
require.True(t, valid)
|
||||
conn := tracker.connections[key]
|
||||
require.Equal(t, TCPStateFinWait2, conn.GetState())
|
||||
|
||||
// Stray retransmit of our own FIN (same direction as originator) must
|
||||
// NOT advance FinWait2 to TimeWait; only the peer's FIN should.
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0)
|
||||
require.Equal(t, TCPStateFinWait2, conn.GetState(),
|
||||
"own FIN retransmit must not advance FinWait2 to TimeWait")
|
||||
|
||||
// Peer FIN -> TimeWait.
|
||||
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0)
|
||||
require.True(t, valid)
|
||||
require.Equal(t, TCPStateTimeWait, conn.GetState())
|
||||
}
|
||||
|
||||
func TestTCPLastAckDirectionCheck(t *testing.T) {
|
||||
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||
defer tracker.Close()
|
||||
|
||||
srcIP := netip.MustParseAddr("100.64.0.1")
|
||||
dstIP := netip.MustParseAddr("100.64.0.2")
|
||||
srcPort := uint16(12345)
|
||||
dstPort := uint16(80)
|
||||
key := ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort}
|
||||
|
||||
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
|
||||
|
||||
// Drive to LastAck: peer FIN -> CloseWait, our FIN -> LastAck.
|
||||
require.True(t, tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0))
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0)
|
||||
conn := tracker.connections[key]
|
||||
require.Equal(t, TCPStateLastAck, conn.GetState())
|
||||
|
||||
// Our own ACK retransmit (same direction as originator) must NOT close.
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
|
||||
require.Equal(t, TCPStateLastAck, conn.GetState(),
|
||||
"own ACK retransmit in LastAck must not transition to Closed")
|
||||
|
||||
// Peer's ACK -> Closed.
|
||||
require.True(t, tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0))
|
||||
require.Equal(t, TCPStateClosed, conn.GetState())
|
||||
}
|
||||
|
||||
func TestTCPFinWait1OwnAckDoesNotAdvance(t *testing.T) {
|
||||
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||
defer tracker.Close()
|
||||
|
||||
srcIP := netip.MustParseAddr("100.64.0.1")
|
||||
dstIP := netip.MustParseAddr("100.64.0.2")
|
||||
srcPort := uint16(12345)
|
||||
dstPort := uint16(80)
|
||||
key := ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort}
|
||||
|
||||
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
|
||||
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0)
|
||||
conn := tracker.connections[key]
|
||||
require.Equal(t, TCPStateFinWait1, conn.GetState())
|
||||
|
||||
// Our own ACK retransmit (same direction as originator) must not advance.
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
|
||||
require.Equal(t, TCPStateFinWait1, conn.GetState(),
|
||||
"own ACK in FinWait1 must not advance to FinWait2")
|
||||
}
|
||||
|
||||
func TestTCPPerStateTeardownTimeouts(t *testing.T) {
|
||||
// Verify cleanup reaps entries in each teardown state at the configured
|
||||
// per-state timeout, not at the single handshake timeout.
|
||||
t.Setenv(EnvTCPFinWaitTimeout, "50ms")
|
||||
t.Setenv(EnvTCPCloseWaitTimeout, "80ms")
|
||||
t.Setenv(EnvTCPLastAckTimeout, "30ms")
|
||||
|
||||
dstIP := netip.MustParseAddr("100.64.0.2")
|
||||
dstPort := uint16(80)
|
||||
|
||||
// Drives a connection to the target state, forces its lastSeen well
|
||||
// beyond the configured timeout, runs cleanup, and asserts reaping.
|
||||
cases := []struct {
|
||||
name string
|
||||
// drive takes a fresh tracker and returns the conn key after
|
||||
// transitioning the flow into the intended teardown state.
|
||||
drive func(t *testing.T, tr *TCPTracker, srcIP netip.Addr, srcPort uint16) (ConnKey, TCPState)
|
||||
}{
|
||||
{
|
||||
name: "FinWait1",
|
||||
drive: func(t *testing.T, tr *TCPTracker, srcIP netip.Addr, srcPort uint16) (ConnKey, TCPState) {
|
||||
establishConnection(t, tr, srcIP, dstIP, srcPort, dstPort)
|
||||
tr.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0) // → FinWait1
|
||||
return ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort}, TCPStateFinWait1
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "FinWait2",
|
||||
drive: func(t *testing.T, tr *TCPTracker, srcIP netip.Addr, srcPort uint16) (ConnKey, TCPState) {
|
||||
establishConnection(t, tr, srcIP, dstIP, srcPort, dstPort)
|
||||
tr.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0) // FinWait1
|
||||
require.True(t, tr.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)) // → FinWait2
|
||||
return ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort}, TCPStateFinWait2
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "CloseWait",
|
||||
drive: func(t *testing.T, tr *TCPTracker, srcIP netip.Addr, srcPort uint16) (ConnKey, TCPState) {
|
||||
establishConnection(t, tr, srcIP, dstIP, srcPort, dstPort)
|
||||
require.True(t, tr.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0)) // → CloseWait
|
||||
return ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort}, TCPStateCloseWait
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "LastAck",
|
||||
drive: func(t *testing.T, tr *TCPTracker, srcIP netip.Addr, srcPort uint16) (ConnKey, TCPState) {
|
||||
establishConnection(t, tr, srcIP, dstIP, srcPort, dstPort)
|
||||
require.True(t, tr.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0)) // CloseWait
|
||||
tr.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0) // → LastAck
|
||||
return ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort}, TCPStateLastAck
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// Use a unique source port per subtest so nothing aliases.
|
||||
port := uint16(12345)
|
||||
for _, c := range cases {
|
||||
t.Run(c.name, func(t *testing.T) {
|
||||
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||
defer tracker.Close()
|
||||
|
||||
require.Equal(t, 50*time.Millisecond, tracker.finWaitTimeout)
|
||||
require.Equal(t, 80*time.Millisecond, tracker.closeWaitTimeout)
|
||||
require.Equal(t, 30*time.Millisecond, tracker.lastAckTimeout)
|
||||
|
||||
srcIP := netip.MustParseAddr("100.64.0.1")
|
||||
port++
|
||||
key, wantState := c.drive(t, tracker, srcIP, port)
|
||||
conn := tracker.connections[key]
|
||||
require.NotNil(t, conn)
|
||||
require.Equal(t, wantState, conn.GetState())
|
||||
|
||||
// Age the entry past the largest per-state timeout.
|
||||
conn.lastSeen.Store(time.Now().Add(-500 * time.Millisecond).UnixNano())
|
||||
tracker.cleanup()
|
||||
_, exists := tracker.connections[key]
|
||||
require.False(t, exists, "%s entry should be reaped", c.name)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestTCPEstablishedPSHACKInFinStates(t *testing.T) {
|
||||
// Verifies FIN|PSH|ACK and bare ACK keepalives are not dropped in FIN
|
||||
// teardown states, which some stacks emit during close.
|
||||
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||
defer tracker.Close()
|
||||
|
||||
srcIP := netip.MustParseAddr("100.64.0.1")
|
||||
dstIP := netip.MustParseAddr("100.64.0.2")
|
||||
srcPort := uint16(12345)
|
||||
dstPort := uint16(80)
|
||||
|
||||
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
|
||||
|
||||
// Peer FIN -> CloseWait.
|
||||
require.True(t, tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0))
|
||||
|
||||
// Peer pushes trailing data + FIN|PSH|ACK (legal).
|
||||
require.True(t, tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPPush|TCPAck, 100),
|
||||
"FIN|PSH|ACK in CloseWait must be accepted")
|
||||
|
||||
// Bare ACK keepalive from peer in CloseWait must be accepted.
|
||||
require.True(t, tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0),
|
||||
"bare ACK in CloseWait must be accepted")
|
||||
}
|
||||
@@ -17,6 +17,9 @@ const (
|
||||
DefaultUDPTimeout = 30 * time.Second
|
||||
// UDPCleanupInterval is how often we check for stale connections
|
||||
UDPCleanupInterval = 15 * time.Second
|
||||
|
||||
// EnvUDPMaxEntries caps the UDP conntrack table size.
|
||||
EnvUDPMaxEntries = "NB_CONNTRACK_UDP_MAX"
|
||||
)
|
||||
|
||||
// UDPConnTrack represents a UDP connection state
|
||||
@@ -34,6 +37,7 @@ type UDPTracker struct {
|
||||
cleanupTicker *time.Ticker
|
||||
tickerCancel context.CancelFunc
|
||||
mutex sync.RWMutex
|
||||
maxEntries int
|
||||
flowLogger nftypes.FlowLogger
|
||||
}
|
||||
|
||||
@@ -51,6 +55,7 @@ func NewUDPTracker(timeout time.Duration, logger *nblog.Logger, flowLogger nftyp
|
||||
timeout: timeout,
|
||||
cleanupTicker: time.NewTicker(UDPCleanupInterval),
|
||||
tickerCancel: cancel,
|
||||
maxEntries: envInt(logger, EnvUDPMaxEntries, DefaultMaxUDPEntries),
|
||||
flowLogger: flowLogger,
|
||||
}
|
||||
|
||||
@@ -117,13 +122,18 @@ func (t *UDPTracker) track(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, d
|
||||
conn.UpdateCounters(direction, size)
|
||||
|
||||
t.mutex.Lock()
|
||||
if t.maxEntries > 0 && len(t.connections) >= t.maxEntries {
|
||||
t.evictOneLocked()
|
||||
}
|
||||
t.connections[key] = conn
|
||||
t.mutex.Unlock()
|
||||
|
||||
if origPort != 0 {
|
||||
t.logger.Trace4("New %s UDP connection: %s (port DNAT %d -> %d)", direction, key, origPort, dstPort)
|
||||
} else {
|
||||
t.logger.Trace2("New %s UDP connection: %s", direction, key)
|
||||
if t.logger.Enabled(nblog.LevelTrace) {
|
||||
if origPort != 0 {
|
||||
t.logger.Trace4("New %s UDP connection: %s (port DNAT %d -> %d)", direction, key, origPort, dstPort)
|
||||
} else {
|
||||
t.logger.Trace2("New %s UDP connection: %s", direction, key)
|
||||
}
|
||||
}
|
||||
t.sendEvent(nftypes.TypeStart, conn, ruleID)
|
||||
}
|
||||
@@ -151,6 +161,34 @@ func (t *UDPTracker) IsValidInbound(srcIP netip.Addr, dstIP netip.Addr, srcPort
|
||||
return true
|
||||
}
|
||||
|
||||
// evictOneLocked removes one entry to make room. Caller must hold t.mutex.
|
||||
// Bounded sample: picks the oldest among up to evictSampleSize entries.
|
||||
func (t *UDPTracker) evictOneLocked() {
|
||||
var candKey ConnKey
|
||||
var candSeen int64
|
||||
haveCand := false
|
||||
sampled := 0
|
||||
|
||||
for k, c := range t.connections {
|
||||
seen := c.lastSeen.Load()
|
||||
if !haveCand || seen < candSeen {
|
||||
candKey = k
|
||||
candSeen = seen
|
||||
haveCand = true
|
||||
}
|
||||
sampled++
|
||||
if sampled >= evictSampleSize {
|
||||
break
|
||||
}
|
||||
}
|
||||
if haveCand {
|
||||
if evicted := t.connections[candKey]; evicted != nil {
|
||||
t.sendEvent(nftypes.TypeEnd, evicted, nil)
|
||||
}
|
||||
delete(t.connections, candKey)
|
||||
}
|
||||
}
|
||||
|
||||
// cleanupRoutine periodically removes stale connections
|
||||
func (t *UDPTracker) cleanupRoutine(ctx context.Context) {
|
||||
defer t.cleanupTicker.Stop()
|
||||
@@ -173,8 +211,10 @@ func (t *UDPTracker) cleanup() {
|
||||
if conn.timeoutExceeded(t.timeout) {
|
||||
delete(t.connections, key)
|
||||
|
||||
t.logger.Trace5("Removed UDP connection %s (timeout) [in: %d Pkts/%d B, out: %d Pkts/%d B]",
|
||||
key, conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load())
|
||||
if t.logger.Enabled(nblog.LevelTrace) {
|
||||
t.logger.Trace5("Removed UDP connection %s (timeout) [in: %d Pkts/%d B, out: %d Pkts/%d B]",
|
||||
key, conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load())
|
||||
}
|
||||
t.sendEvent(nftypes.TypeEnd, conn, nil)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -709,7 +709,9 @@ func (m *Manager) filterOutbound(packetData []byte, size int) bool {
|
||||
|
||||
srcIP, dstIP := m.extractIPs(d)
|
||||
if !srcIP.IsValid() {
|
||||
m.logger.Error1("Unknown network layer: %v", d.decoded[0])
|
||||
if m.logger.Enabled(nblog.LevelError) {
|
||||
m.logger.Error1("Unknown network layer: %v", d.decoded[0])
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -808,7 +810,9 @@ func (m *Manager) clampTCPMSS(packetData []byte, d *decoder) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
m.logger.Trace2("Clamped TCP MSS from %d to %d", currentMSS, m.mssClampValue)
|
||||
if m.logger.Enabled(nblog.LevelTrace) {
|
||||
m.logger.Trace2("Clamped TCP MSS from %d to %d", currentMSS, m.mssClampValue)
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
@@ -931,8 +935,10 @@ func (m *Manager) filterInbound(packetData []byte, size int) bool {
|
||||
|
||||
// TODO: pass fragments of routed packets to forwarder
|
||||
if fragment {
|
||||
m.logger.Trace4("packet is a fragment: src=%v dst=%v id=%v flags=%v",
|
||||
srcIP, dstIP, d.ip4.Id, d.ip4.Flags)
|
||||
if m.logger.Enabled(nblog.LevelTrace) {
|
||||
m.logger.Trace4("packet is a fragment: src=%v dst=%v id=%v flags=%v",
|
||||
srcIP, dstIP, d.ip4.Id, d.ip4.Flags)
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -974,8 +980,10 @@ func (m *Manager) handleLocalTraffic(d *decoder, srcIP, dstIP netip.Addr, packet
|
||||
pnum := getProtocolFromPacket(d)
|
||||
srcPort, dstPort := getPortsFromPacket(d)
|
||||
|
||||
m.logger.Trace6("Dropping local packet (ACL denied): rule_id=%s proto=%v src=%s:%d dst=%s:%d",
|
||||
ruleID, pnum, srcIP, srcPort, dstIP, dstPort)
|
||||
if m.logger.Enabled(nblog.LevelTrace) {
|
||||
m.logger.Trace6("Dropping local packet (ACL denied): rule_id=%s proto=%v src=%s:%d dst=%s:%d",
|
||||
ruleID, pnum, srcIP, srcPort, dstIP, dstPort)
|
||||
}
|
||||
|
||||
m.flowLogger.StoreEvent(nftypes.EventFields{
|
||||
FlowID: uuid.New(),
|
||||
@@ -1025,8 +1033,10 @@ func (m *Manager) handleForwardedLocalTraffic(packetData []byte) bool {
|
||||
func (m *Manager) handleRoutedTraffic(d *decoder, srcIP, dstIP netip.Addr, packetData []byte, size int) bool {
|
||||
// Drop if routing is disabled
|
||||
if !m.routingEnabled.Load() {
|
||||
m.logger.Trace2("Dropping routed packet (routing disabled): src=%s dst=%s",
|
||||
srcIP, dstIP)
|
||||
if m.logger.Enabled(nblog.LevelTrace) {
|
||||
m.logger.Trace2("Dropping routed packet (routing disabled): src=%s dst=%s",
|
||||
srcIP, dstIP)
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
@@ -1043,8 +1053,10 @@ func (m *Manager) handleRoutedTraffic(d *decoder, srcIP, dstIP netip.Addr, packe
|
||||
if !pass {
|
||||
proto := getProtocolFromPacket(d)
|
||||
|
||||
m.logger.Trace6("Dropping routed packet (ACL denied): rule_id=%s proto=%v src=%s:%d dst=%s:%d",
|
||||
ruleID, proto, srcIP, srcPort, dstIP, dstPort)
|
||||
if m.logger.Enabled(nblog.LevelTrace) {
|
||||
m.logger.Trace6("Dropping routed packet (ACL denied): rule_id=%s proto=%v src=%s:%d dst=%s:%d",
|
||||
ruleID, proto, srcIP, srcPort, dstIP, dstPort)
|
||||
}
|
||||
|
||||
m.flowLogger.StoreEvent(nftypes.EventFields{
|
||||
FlowID: uuid.New(),
|
||||
@@ -1126,7 +1138,9 @@ func getPortsFromPacket(d *decoder) (srcPort, dstPort uint16) {
|
||||
// It returns true, true if the packet is a fragment and valid.
|
||||
func (m *Manager) isValidPacket(d *decoder, packetData []byte) (bool, bool) {
|
||||
if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil {
|
||||
m.logger.Trace1("couldn't decode packet, err: %s", err)
|
||||
if m.logger.Enabled(nblog.LevelTrace) {
|
||||
m.logger.Trace1("couldn't decode packet, err: %s", err)
|
||||
}
|
||||
return false, false
|
||||
}
|
||||
|
||||
|
||||
@@ -13,6 +13,7 @@ import (
|
||||
"gvisor.dev/gvisor/pkg/tcpip/header"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/stack"
|
||||
|
||||
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
|
||||
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
||||
)
|
||||
|
||||
@@ -92,8 +93,10 @@ func (f *Forwarder) forwardICMPPacket(id stack.TransportEndpointID, payload []by
|
||||
return nil, fmt.Errorf("write ICMP packet: %w", err)
|
||||
}
|
||||
|
||||
f.logger.Trace3("forwarder: Forwarded ICMP packet %v type %v code %v",
|
||||
epID(id), icmpType, icmpCode)
|
||||
if f.logger.Enabled(nblog.LevelTrace) {
|
||||
f.logger.Trace3("forwarder: Forwarded ICMP packet %v type %v code %v",
|
||||
epID(id), icmpType, icmpCode)
|
||||
}
|
||||
|
||||
return conn, nil
|
||||
}
|
||||
@@ -116,8 +119,10 @@ func (f *Forwarder) handleICMPViaSocket(flowID uuid.UUID, id stack.TransportEndp
|
||||
txBytes := f.handleEchoResponse(conn, id)
|
||||
rtt := time.Since(sendTime).Round(10 * time.Microsecond)
|
||||
|
||||
f.logger.Trace4("forwarder: Forwarded ICMP echo reply %v type %v code %v (rtt=%v, raw socket)",
|
||||
epID(id), icmpType, icmpCode, rtt)
|
||||
if f.logger.Enabled(nblog.LevelTrace) {
|
||||
f.logger.Trace4("forwarder: Forwarded ICMP echo reply %v type %v code %v (rtt=%v, raw socket)",
|
||||
epID(id), icmpType, icmpCode, rtt)
|
||||
}
|
||||
|
||||
f.sendICMPEvent(nftypes.TypeEnd, flowID, id, icmpType, icmpCode, uint64(rxBytes), uint64(txBytes))
|
||||
}
|
||||
@@ -198,13 +203,17 @@ func (f *Forwarder) handleICMPViaPing(flowID uuid.UUID, id stack.TransportEndpoi
|
||||
}
|
||||
rtt := time.Since(pingStart).Round(10 * time.Microsecond)
|
||||
|
||||
f.logger.Trace3("forwarder: Forwarded ICMP echo request %v type %v code %v",
|
||||
epID(id), icmpType, icmpCode)
|
||||
if f.logger.Enabled(nblog.LevelTrace) {
|
||||
f.logger.Trace3("forwarder: Forwarded ICMP echo request %v type %v code %v",
|
||||
epID(id), icmpType, icmpCode)
|
||||
}
|
||||
|
||||
txBytes := f.synthesizeEchoReply(id, icmpData)
|
||||
|
||||
f.logger.Trace4("forwarder: Forwarded ICMP echo reply %v type %v code %v (rtt=%v, ping binary)",
|
||||
epID(id), icmpType, icmpCode, rtt)
|
||||
if f.logger.Enabled(nblog.LevelTrace) {
|
||||
f.logger.Trace4("forwarder: Forwarded ICMP echo reply %v type %v code %v (rtt=%v, ping binary)",
|
||||
epID(id), icmpType, icmpCode, rtt)
|
||||
}
|
||||
|
||||
f.sendICMPEvent(nftypes.TypeEnd, flowID, id, icmpType, icmpCode, uint64(rxBytes), uint64(txBytes))
|
||||
}
|
||||
|
||||
@@ -1,12 +1,9 @@
|
||||
package forwarder
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/netip"
|
||||
"sync"
|
||||
|
||||
"github.com/google/uuid"
|
||||
|
||||
@@ -16,7 +13,9 @@ import (
|
||||
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
|
||||
"gvisor.dev/gvisor/pkg/waiter"
|
||||
|
||||
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
|
||||
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
||||
"github.com/netbirdio/netbird/util/netrelay"
|
||||
)
|
||||
|
||||
// handleTCP is called by the TCP forwarder for new connections.
|
||||
@@ -38,7 +37,9 @@ func (f *Forwarder) handleTCP(r *tcp.ForwarderRequest) {
|
||||
outConn, err := (&net.Dialer{}).DialContext(f.ctx, "tcp", dialAddr)
|
||||
if err != nil {
|
||||
r.Complete(true)
|
||||
f.logger.Trace2("forwarder: dial error for %v: %v", epID(id), err)
|
||||
if f.logger.Enabled(nblog.LevelTrace) {
|
||||
f.logger.Trace2("forwarder: dial error for %v: %v", epID(id), err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
@@ -61,64 +62,22 @@ func (f *Forwarder) handleTCP(r *tcp.ForwarderRequest) {
|
||||
inConn := gonet.NewTCPConn(&wq, ep)
|
||||
|
||||
success = true
|
||||
f.logger.Trace1("forwarder: established TCP connection %v", epID(id))
|
||||
if f.logger.Enabled(nblog.LevelTrace) {
|
||||
f.logger.Trace1("forwarder: established TCP connection %v", epID(id))
|
||||
}
|
||||
|
||||
go f.proxyTCP(id, inConn, outConn, ep, flowID)
|
||||
}
|
||||
|
||||
func (f *Forwarder) proxyTCP(id stack.TransportEndpointID, inConn *gonet.TCPConn, outConn net.Conn, ep tcpip.Endpoint, flowID uuid.UUID) {
|
||||
// netrelay.Relay copies bidirectionally with proper half-close propagation
|
||||
// and fully closes both conns before returning.
|
||||
bytesFromInToOut, bytesFromOutToIn := netrelay.Relay(f.ctx, inConn, outConn, netrelay.Options{
|
||||
Logger: f.logger,
|
||||
})
|
||||
|
||||
ctx, cancel := context.WithCancel(f.ctx)
|
||||
defer cancel()
|
||||
|
||||
go func() {
|
||||
<-ctx.Done()
|
||||
// Close connections and endpoint.
|
||||
if err := inConn.Close(); err != nil && !isClosedError(err) {
|
||||
f.logger.Debug1("forwarder: inConn close error: %v", err)
|
||||
}
|
||||
if err := outConn.Close(); err != nil && !isClosedError(err) {
|
||||
f.logger.Debug1("forwarder: outConn close error: %v", err)
|
||||
}
|
||||
|
||||
ep.Close()
|
||||
}()
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(2)
|
||||
|
||||
var (
|
||||
bytesFromInToOut int64 // bytes from client to server (tx for client)
|
||||
bytesFromOutToIn int64 // bytes from server to client (rx for client)
|
||||
errInToOut error
|
||||
errOutToIn error
|
||||
)
|
||||
|
||||
go func() {
|
||||
bytesFromInToOut, errInToOut = io.Copy(outConn, inConn)
|
||||
cancel()
|
||||
wg.Done()
|
||||
}()
|
||||
|
||||
go func() {
|
||||
|
||||
bytesFromOutToIn, errOutToIn = io.Copy(inConn, outConn)
|
||||
cancel()
|
||||
wg.Done()
|
||||
}()
|
||||
|
||||
wg.Wait()
|
||||
|
||||
if errInToOut != nil {
|
||||
if !isClosedError(errInToOut) {
|
||||
f.logger.Error2("proxyTCP: copy error (in → out) for %s: %v", epID(id), errInToOut)
|
||||
}
|
||||
}
|
||||
if errOutToIn != nil {
|
||||
if !isClosedError(errOutToIn) {
|
||||
f.logger.Error2("proxyTCP: copy error (out → in) for %s: %v", epID(id), errOutToIn)
|
||||
}
|
||||
}
|
||||
// Close the netstack endpoint after both conns are drained.
|
||||
ep.Close()
|
||||
|
||||
var rxPackets, txPackets uint64
|
||||
if tcpStats, ok := ep.Stats().(*tcp.Stats); ok {
|
||||
@@ -127,7 +86,9 @@ func (f *Forwarder) proxyTCP(id stack.TransportEndpointID, inConn *gonet.TCPConn
|
||||
txPackets = tcpStats.SegmentsReceived.Value()
|
||||
}
|
||||
|
||||
f.logger.Trace5("forwarder: Removed TCP connection %s [in: %d Pkts/%d B, out: %d Pkts/%d B]", epID(id), rxPackets, bytesFromOutToIn, txPackets, bytesFromInToOut)
|
||||
if f.logger.Enabled(nblog.LevelTrace) {
|
||||
f.logger.Trace5("forwarder: Removed TCP connection %s [in: %d Pkts/%d B, out: %d Pkts/%d B]", epID(id), rxPackets, bytesFromOutToIn, txPackets, bytesFromInToOut)
|
||||
}
|
||||
|
||||
f.sendTCPEvent(nftypes.TypeEnd, flowID, id, uint64(bytesFromOutToIn), uint64(bytesFromInToOut), rxPackets, txPackets)
|
||||
}
|
||||
|
||||
@@ -125,7 +125,9 @@ func (f *udpForwarder) cleanup() {
|
||||
delete(f.conns, idle.id)
|
||||
f.Unlock()
|
||||
|
||||
f.logger.Trace1("forwarder: cleaned up idle UDP connection %v", epID(idle.id))
|
||||
if f.logger.Enabled(nblog.LevelTrace) {
|
||||
f.logger.Trace1("forwarder: cleaned up idle UDP connection %v", epID(idle.id))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -144,7 +146,9 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) bool {
|
||||
_, exists := f.udpForwarder.conns[id]
|
||||
f.udpForwarder.RUnlock()
|
||||
if exists {
|
||||
f.logger.Trace1("forwarder: existing UDP connection for %v", epID(id))
|
||||
if f.logger.Enabled(nblog.LevelTrace) {
|
||||
f.logger.Trace1("forwarder: existing UDP connection for %v", epID(id))
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
@@ -206,7 +210,9 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) bool {
|
||||
f.udpForwarder.Unlock()
|
||||
|
||||
success = true
|
||||
f.logger.Trace1("forwarder: established UDP connection %v", epID(id))
|
||||
if f.logger.Enabled(nblog.LevelTrace) {
|
||||
f.logger.Trace1("forwarder: established UDP connection %v", epID(id))
|
||||
}
|
||||
|
||||
go f.proxyUDP(connCtx, pConn, id, ep)
|
||||
return true
|
||||
@@ -265,7 +271,9 @@ func (f *Forwarder) proxyUDP(ctx context.Context, pConn *udpPacketConn, id stack
|
||||
txPackets = udpStats.PacketsReceived.Value()
|
||||
}
|
||||
|
||||
f.logger.Trace5("forwarder: Removed UDP connection %s [in: %d Pkts/%d B, out: %d Pkts/%d B]", epID(id), rxPackets, rxBytes, txPackets, txBytes)
|
||||
if f.logger.Enabled(nblog.LevelTrace) {
|
||||
f.logger.Trace5("forwarder: Removed UDP connection %s [in: %d Pkts/%d B, out: %d Pkts/%d B]", epID(id), rxPackets, rxBytes, txPackets, txBytes)
|
||||
}
|
||||
|
||||
f.udpForwarder.Lock()
|
||||
delete(f.udpForwarder.conns, id)
|
||||
|
||||
@@ -53,16 +53,17 @@ var levelStrings = map[Level]string{
|
||||
}
|
||||
|
||||
type logMessage struct {
|
||||
level Level
|
||||
format string
|
||||
arg1 any
|
||||
arg2 any
|
||||
arg3 any
|
||||
arg4 any
|
||||
arg5 any
|
||||
arg6 any
|
||||
arg7 any
|
||||
arg8 any
|
||||
level Level
|
||||
argCount uint8
|
||||
format string
|
||||
arg1 any
|
||||
arg2 any
|
||||
arg3 any
|
||||
arg4 any
|
||||
arg5 any
|
||||
arg6 any
|
||||
arg7 any
|
||||
arg8 any
|
||||
}
|
||||
|
||||
// Logger is a high-performance, non-blocking logger
|
||||
@@ -107,6 +108,13 @@ func (l *Logger) SetLevel(level Level) {
|
||||
log.Debugf("Set uspfilter logger loglevel to %v", levelStrings[level])
|
||||
}
|
||||
|
||||
// Enabled reports whether the given level is currently logged. Callers on the
|
||||
// hot path should guard log sites with this to avoid boxing arguments into
|
||||
// any when the level is off.
|
||||
func (l *Logger) Enabled(level Level) bool {
|
||||
return l.level.Load() >= uint32(level)
|
||||
}
|
||||
|
||||
func (l *Logger) Error(format string) {
|
||||
if l.level.Load() >= uint32(LevelError) {
|
||||
select {
|
||||
@@ -155,7 +163,7 @@ func (l *Logger) Trace(format string) {
|
||||
func (l *Logger) Error1(format string, arg1 any) {
|
||||
if l.level.Load() >= uint32(LevelError) {
|
||||
select {
|
||||
case l.msgChannel <- logMessage{level: LevelError, format: format, arg1: arg1}:
|
||||
case l.msgChannel <- logMessage{level: LevelError, argCount: 1, format: format, arg1: arg1}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
@@ -164,7 +172,16 @@ func (l *Logger) Error1(format string, arg1 any) {
|
||||
func (l *Logger) Error2(format string, arg1, arg2 any) {
|
||||
if l.level.Load() >= uint32(LevelError) {
|
||||
select {
|
||||
case l.msgChannel <- logMessage{level: LevelError, format: format, arg1: arg1, arg2: arg2}:
|
||||
case l.msgChannel <- logMessage{level: LevelError, argCount: 2, format: format, arg1: arg1, arg2: arg2}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (l *Logger) Warn2(format string, arg1, arg2 any) {
|
||||
if l.level.Load() >= uint32(LevelWarn) {
|
||||
select {
|
||||
case l.msgChannel <- logMessage{level: LevelWarn, argCount: 2, format: format, arg1: arg1, arg2: arg2}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
@@ -173,7 +190,7 @@ func (l *Logger) Error2(format string, arg1, arg2 any) {
|
||||
func (l *Logger) Warn3(format string, arg1, arg2, arg3 any) {
|
||||
if l.level.Load() >= uint32(LevelWarn) {
|
||||
select {
|
||||
case l.msgChannel <- logMessage{level: LevelWarn, format: format, arg1: arg1, arg2: arg2, arg3: arg3}:
|
||||
case l.msgChannel <- logMessage{level: LevelWarn, argCount: 3, format: format, arg1: arg1, arg2: arg2, arg3: arg3}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
@@ -182,7 +199,7 @@ func (l *Logger) Warn3(format string, arg1, arg2, arg3 any) {
|
||||
func (l *Logger) Warn4(format string, arg1, arg2, arg3, arg4 any) {
|
||||
if l.level.Load() >= uint32(LevelWarn) {
|
||||
select {
|
||||
case l.msgChannel <- logMessage{level: LevelWarn, format: format, arg1: arg1, arg2: arg2, arg3: arg3, arg4: arg4}:
|
||||
case l.msgChannel <- logMessage{level: LevelWarn, argCount: 4, format: format, arg1: arg1, arg2: arg2, arg3: arg3, arg4: arg4}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
@@ -191,7 +208,7 @@ func (l *Logger) Warn4(format string, arg1, arg2, arg3, arg4 any) {
|
||||
func (l *Logger) Debug1(format string, arg1 any) {
|
||||
if l.level.Load() >= uint32(LevelDebug) {
|
||||
select {
|
||||
case l.msgChannel <- logMessage{level: LevelDebug, format: format, arg1: arg1}:
|
||||
case l.msgChannel <- logMessage{level: LevelDebug, argCount: 1, format: format, arg1: arg1}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
@@ -200,7 +217,7 @@ func (l *Logger) Debug1(format string, arg1 any) {
|
||||
func (l *Logger) Debug2(format string, arg1, arg2 any) {
|
||||
if l.level.Load() >= uint32(LevelDebug) {
|
||||
select {
|
||||
case l.msgChannel <- logMessage{level: LevelDebug, format: format, arg1: arg1, arg2: arg2}:
|
||||
case l.msgChannel <- logMessage{level: LevelDebug, argCount: 2, format: format, arg1: arg1, arg2: arg2}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
@@ -209,16 +226,59 @@ func (l *Logger) Debug2(format string, arg1, arg2 any) {
|
||||
func (l *Logger) Debug3(format string, arg1, arg2, arg3 any) {
|
||||
if l.level.Load() >= uint32(LevelDebug) {
|
||||
select {
|
||||
case l.msgChannel <- logMessage{level: LevelDebug, format: format, arg1: arg1, arg2: arg2, arg3: arg3}:
|
||||
case l.msgChannel <- logMessage{level: LevelDebug, argCount: 3, format: format, arg1: arg1, arg2: arg2, arg3: arg3}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Debugf is the variadic shape. Dispatches to Debug/Debug1/Debug2/Debug3
|
||||
// to avoid allocating an args slice on the fast path when the arg count is
|
||||
// known (0-3). Args beyond 3 land on the general variadic path; callers on
|
||||
// the hot path should prefer DebugN for known counts.
|
||||
func (l *Logger) Debugf(format string, args ...any) {
|
||||
if l.level.Load() < uint32(LevelDebug) {
|
||||
return
|
||||
}
|
||||
switch len(args) {
|
||||
case 0:
|
||||
l.Debug(format)
|
||||
case 1:
|
||||
l.Debug1(format, args[0])
|
||||
case 2:
|
||||
l.Debug2(format, args[0], args[1])
|
||||
case 3:
|
||||
l.Debug3(format, args[0], args[1], args[2])
|
||||
default:
|
||||
l.sendVariadic(LevelDebug, format, args)
|
||||
}
|
||||
}
|
||||
|
||||
// sendVariadic packs a slice of arguments into a logMessage and non-blocking
|
||||
// enqueues it. Used for arg counts beyond the fixed-arity fast paths. Args
|
||||
// beyond the 8-arg slot limit are dropped so callers don't produce silently
|
||||
// empty log lines via uint8 wraparound in argCount.
|
||||
func (l *Logger) sendVariadic(level Level, format string, args []any) {
|
||||
const maxArgs = 8
|
||||
n := len(args)
|
||||
if n > maxArgs {
|
||||
n = maxArgs
|
||||
}
|
||||
msg := logMessage{level: level, argCount: uint8(n), format: format}
|
||||
slots := [maxArgs]*any{&msg.arg1, &msg.arg2, &msg.arg3, &msg.arg4, &msg.arg5, &msg.arg6, &msg.arg7, &msg.arg8}
|
||||
for i := 0; i < n; i++ {
|
||||
*slots[i] = args[i]
|
||||
}
|
||||
select {
|
||||
case l.msgChannel <- msg:
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
func (l *Logger) Trace1(format string, arg1 any) {
|
||||
if l.level.Load() >= uint32(LevelTrace) {
|
||||
select {
|
||||
case l.msgChannel <- logMessage{level: LevelTrace, format: format, arg1: arg1}:
|
||||
case l.msgChannel <- logMessage{level: LevelTrace, argCount: 1, format: format, arg1: arg1}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
@@ -227,7 +287,7 @@ func (l *Logger) Trace1(format string, arg1 any) {
|
||||
func (l *Logger) Trace2(format string, arg1, arg2 any) {
|
||||
if l.level.Load() >= uint32(LevelTrace) {
|
||||
select {
|
||||
case l.msgChannel <- logMessage{level: LevelTrace, format: format, arg1: arg1, arg2: arg2}:
|
||||
case l.msgChannel <- logMessage{level: LevelTrace, argCount: 2, format: format, arg1: arg1, arg2: arg2}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
@@ -236,7 +296,7 @@ func (l *Logger) Trace2(format string, arg1, arg2 any) {
|
||||
func (l *Logger) Trace3(format string, arg1, arg2, arg3 any) {
|
||||
if l.level.Load() >= uint32(LevelTrace) {
|
||||
select {
|
||||
case l.msgChannel <- logMessage{level: LevelTrace, format: format, arg1: arg1, arg2: arg2, arg3: arg3}:
|
||||
case l.msgChannel <- logMessage{level: LevelTrace, argCount: 3, format: format, arg1: arg1, arg2: arg2, arg3: arg3}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
@@ -245,7 +305,7 @@ func (l *Logger) Trace3(format string, arg1, arg2, arg3 any) {
|
||||
func (l *Logger) Trace4(format string, arg1, arg2, arg3, arg4 any) {
|
||||
if l.level.Load() >= uint32(LevelTrace) {
|
||||
select {
|
||||
case l.msgChannel <- logMessage{level: LevelTrace, format: format, arg1: arg1, arg2: arg2, arg3: arg3, arg4: arg4}:
|
||||
case l.msgChannel <- logMessage{level: LevelTrace, argCount: 4, format: format, arg1: arg1, arg2: arg2, arg3: arg3, arg4: arg4}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
@@ -254,7 +314,7 @@ func (l *Logger) Trace4(format string, arg1, arg2, arg3, arg4 any) {
|
||||
func (l *Logger) Trace5(format string, arg1, arg2, arg3, arg4, arg5 any) {
|
||||
if l.level.Load() >= uint32(LevelTrace) {
|
||||
select {
|
||||
case l.msgChannel <- logMessage{level: LevelTrace, format: format, arg1: arg1, arg2: arg2, arg3: arg3, arg4: arg4, arg5: arg5}:
|
||||
case l.msgChannel <- logMessage{level: LevelTrace, argCount: 5, format: format, arg1: arg1, arg2: arg2, arg3: arg3, arg4: arg4, arg5: arg5}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
@@ -263,7 +323,7 @@ func (l *Logger) Trace5(format string, arg1, arg2, arg3, arg4, arg5 any) {
|
||||
func (l *Logger) Trace6(format string, arg1, arg2, arg3, arg4, arg5, arg6 any) {
|
||||
if l.level.Load() >= uint32(LevelTrace) {
|
||||
select {
|
||||
case l.msgChannel <- logMessage{level: LevelTrace, format: format, arg1: arg1, arg2: arg2, arg3: arg3, arg4: arg4, arg5: arg5, arg6: arg6}:
|
||||
case l.msgChannel <- logMessage{level: LevelTrace, argCount: 6, format: format, arg1: arg1, arg2: arg2, arg3: arg3, arg4: arg4, arg5: arg5, arg6: arg6}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
@@ -273,7 +333,7 @@ func (l *Logger) Trace6(format string, arg1, arg2, arg3, arg4, arg5, arg6 any) {
|
||||
func (l *Logger) Trace8(format string, arg1, arg2, arg3, arg4, arg5, arg6, arg7, arg8 any) {
|
||||
if l.level.Load() >= uint32(LevelTrace) {
|
||||
select {
|
||||
case l.msgChannel <- logMessage{level: LevelTrace, format: format, arg1: arg1, arg2: arg2, arg3: arg3, arg4: arg4, arg5: arg5, arg6: arg6, arg7: arg7, arg8: arg8}:
|
||||
case l.msgChannel <- logMessage{level: LevelTrace, argCount: 8, format: format, arg1: arg1, arg2: arg2, arg3: arg3, arg4: arg4, arg5: arg5, arg6: arg6, arg7: arg7, arg8: arg8}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
@@ -286,35 +346,8 @@ func (l *Logger) formatMessage(buf *[]byte, msg logMessage) {
|
||||
*buf = append(*buf, levelStrings[msg.level]...)
|
||||
*buf = append(*buf, ' ')
|
||||
|
||||
// Count non-nil arguments for switch
|
||||
argCount := 0
|
||||
if msg.arg1 != nil {
|
||||
argCount++
|
||||
if msg.arg2 != nil {
|
||||
argCount++
|
||||
if msg.arg3 != nil {
|
||||
argCount++
|
||||
if msg.arg4 != nil {
|
||||
argCount++
|
||||
if msg.arg5 != nil {
|
||||
argCount++
|
||||
if msg.arg6 != nil {
|
||||
argCount++
|
||||
if msg.arg7 != nil {
|
||||
argCount++
|
||||
if msg.arg8 != nil {
|
||||
argCount++
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var formatted string
|
||||
switch argCount {
|
||||
switch msg.argCount {
|
||||
case 0:
|
||||
formatted = msg.format
|
||||
case 1:
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
"github.com/google/gopacket/layers"
|
||||
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
|
||||
)
|
||||
|
||||
var ErrIPv4Only = errors.New("only IPv4 is supported for DNAT")
|
||||
@@ -242,11 +243,15 @@ func (m *Manager) translateOutboundDNAT(packetData []byte, d *decoder) bool {
|
||||
}
|
||||
|
||||
if err := m.rewritePacketIP(packetData, d, translatedIP, destinationIPOffset); err != nil {
|
||||
m.logger.Error1("failed to rewrite packet destination: %v", err)
|
||||
if m.logger.Enabled(nblog.LevelError) {
|
||||
m.logger.Error1("failed to rewrite packet destination: %v", err)
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
m.logger.Trace2("DNAT: %s -> %s", dstIP, translatedIP)
|
||||
if m.logger.Enabled(nblog.LevelTrace) {
|
||||
m.logger.Trace2("DNAT: %s -> %s", dstIP, translatedIP)
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
@@ -264,11 +269,15 @@ func (m *Manager) translateInboundReverse(packetData []byte, d *decoder) bool {
|
||||
}
|
||||
|
||||
if err := m.rewritePacketIP(packetData, d, originalIP, sourceIPOffset); err != nil {
|
||||
m.logger.Error1("failed to rewrite packet source: %v", err)
|
||||
if m.logger.Enabled(nblog.LevelError) {
|
||||
m.logger.Error1("failed to rewrite packet source: %v", err)
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
m.logger.Trace2("Reverse DNAT: %s -> %s", srcIP, originalIP)
|
||||
if m.logger.Enabled(nblog.LevelTrace) {
|
||||
m.logger.Trace2("Reverse DNAT: %s -> %s", srcIP, originalIP)
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
@@ -521,7 +530,9 @@ func (m *Manager) applyPortRule(packetData []byte, d *decoder, srcIP, dstIP neti
|
||||
}
|
||||
|
||||
if err := rewriteFn(packetData, d, rule.targetPort, destinationPortOffset); err != nil {
|
||||
m.logger.Error1("failed to rewrite port: %v", err)
|
||||
if m.logger.Enabled(nblog.LevelError) {
|
||||
m.logger.Error1("failed to rewrite port: %v", err)
|
||||
}
|
||||
return false
|
||||
}
|
||||
d.dnatOrigPort = rule.origPort
|
||||
|
||||
@@ -217,7 +217,6 @@ func (w *WGIface) RemoveAllowedIP(peerKey string, allowedIP netip.Prefix) error
|
||||
// Close closes the tunnel interface
|
||||
func (w *WGIface) Close() error {
|
||||
w.mu.Lock()
|
||||
defer w.mu.Unlock()
|
||||
|
||||
var result *multierror.Error
|
||||
|
||||
@@ -225,7 +224,15 @@ func (w *WGIface) Close() error {
|
||||
result = multierror.Append(result, fmt.Errorf("failed to free WireGuard proxy: %w", err))
|
||||
}
|
||||
|
||||
if err := w.tun.Close(); err != nil {
|
||||
// Release w.mu before calling w.tun.Close(): the underlying
|
||||
// wireguard-go device.Close() waits for its send/receive goroutines
|
||||
// to drain. Some of those goroutines re-enter WGIface methods that
|
||||
// take w.mu (e.g. the packet filter DNS hook calls GetDevice()), so
|
||||
// holding the mutex here would deadlock the shutdown path.
|
||||
tun := w.tun
|
||||
w.mu.Unlock()
|
||||
|
||||
if err := tun.Close(); err != nil {
|
||||
result = multierror.Append(result, fmt.Errorf("failed to close wireguard interface %s: %w", w.Name(), err))
|
||||
}
|
||||
|
||||
|
||||
113
client/iface/iface_close_test.go
Normal file
113
client/iface/iface_close_test.go
Normal file
@@ -0,0 +1,113 @@
|
||||
//go:build !android
|
||||
|
||||
package iface
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
wgdevice "golang.zx2c4.com/wireguard/device"
|
||||
"golang.zx2c4.com/wireguard/tun/netstack"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
"github.com/netbirdio/netbird/client/iface/udpmux"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
"github.com/netbirdio/netbird/client/iface/wgproxy"
|
||||
)
|
||||
|
||||
// fakeTunDevice implements WGTunDevice and lets the test control when
|
||||
// Close() returns. It mimics the wireguard-go shutdown path, which blocks
|
||||
// until its goroutines drain. Some of those goroutines (e.g. the packet
|
||||
// filter DNS hook in client/internal/dns) call back into WGIface, so if
|
||||
// WGIface.Close() held w.mu across tun.Close() the shutdown would
|
||||
// deadlock.
|
||||
type fakeTunDevice struct {
|
||||
closeStarted chan struct{}
|
||||
unblockClose chan struct{}
|
||||
}
|
||||
|
||||
func (f *fakeTunDevice) Create() (device.WGConfigurer, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
func (f *fakeTunDevice) Up() (*udpmux.UniversalUDPMuxDefault, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
func (f *fakeTunDevice) UpdateAddr(wgaddr.Address) error { return nil }
|
||||
func (f *fakeTunDevice) WgAddress() wgaddr.Address { return wgaddr.Address{} }
|
||||
func (f *fakeTunDevice) MTU() uint16 { return DefaultMTU }
|
||||
func (f *fakeTunDevice) DeviceName() string { return "nb-close-test" }
|
||||
func (f *fakeTunDevice) FilteredDevice() *device.FilteredDevice { return nil }
|
||||
func (f *fakeTunDevice) Device() *wgdevice.Device { return nil }
|
||||
func (f *fakeTunDevice) GetNet() *netstack.Net { return nil }
|
||||
func (f *fakeTunDevice) GetICEBind() device.EndpointManager { return nil }
|
||||
|
||||
func (f *fakeTunDevice) Close() error {
|
||||
close(f.closeStarted)
|
||||
<-f.unblockClose
|
||||
return nil
|
||||
}
|
||||
|
||||
type fakeProxyFactory struct{}
|
||||
|
||||
func (fakeProxyFactory) GetProxy() wgproxy.Proxy { return nil }
|
||||
func (fakeProxyFactory) GetProxyPort() uint16 { return 0 }
|
||||
func (fakeProxyFactory) Free() error { return nil }
|
||||
|
||||
// TestWGIface_CloseReleasesMutexBeforeTunClose guards against a deadlock
|
||||
// that surfaces as a macOS test-timeout in
|
||||
// TestDNSPermanent_updateUpstream: WGIface.Close() used to hold w.mu
|
||||
// while waiting for the wireguard-go device goroutines to finish, and
|
||||
// one of those goroutines (the DNS filter hook) calls back into
|
||||
// WGIface.GetDevice() which needs the same mutex. The fix is to drop
|
||||
// the lock before tun.Close() returns control.
|
||||
func TestWGIface_CloseReleasesMutexBeforeTunClose(t *testing.T) {
|
||||
tun := &fakeTunDevice{
|
||||
closeStarted: make(chan struct{}),
|
||||
unblockClose: make(chan struct{}),
|
||||
}
|
||||
w := &WGIface{
|
||||
tun: tun,
|
||||
wgProxyFactory: fakeProxyFactory{},
|
||||
}
|
||||
|
||||
closeDone := make(chan error, 1)
|
||||
go func() {
|
||||
closeDone <- w.Close()
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-tun.closeStarted:
|
||||
case <-time.After(2 * time.Second):
|
||||
close(tun.unblockClose)
|
||||
t.Fatal("tun.Close() was never invoked")
|
||||
}
|
||||
|
||||
// Simulate the WireGuard read goroutine calling back into WGIface
|
||||
// via the packet filter's DNS hook. If Close() still held w.mu
|
||||
// during tun.Close(), this would block until the test timeout.
|
||||
getDeviceDone := make(chan struct{})
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
_ = w.GetDevice()
|
||||
close(getDeviceDone)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-getDeviceDone:
|
||||
case <-time.After(2 * time.Second):
|
||||
close(tun.unblockClose)
|
||||
wg.Wait()
|
||||
t.Fatal("GetDevice() deadlocked while WGIface.Close was closing the tun")
|
||||
}
|
||||
|
||||
close(tun.unblockClose)
|
||||
select {
|
||||
case <-closeDone:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("WGIface.Close() never returned after the tun was unblocked")
|
||||
}
|
||||
}
|
||||
@@ -171,7 +171,7 @@ func (u *UDPConn) performFilterCheck(addr net.Addr) error {
|
||||
}
|
||||
|
||||
if u.address.Network.Contains(a) {
|
||||
log.Warnf("Address %s is part of the NetBird network %s, refusing to write", addr, u.address)
|
||||
log.Warnf("address %s is part of the NetBird network %s, refusing to write", addr, u.address)
|
||||
return fmt.Errorf("address %s is part of the NetBird network %s, refusing to write", addr, u.address)
|
||||
}
|
||||
|
||||
@@ -181,7 +181,7 @@ func (u *UDPConn) performFilterCheck(addr net.Addr) error {
|
||||
u.addrCache.Store(addr.String(), isRouted)
|
||||
if isRouted {
|
||||
// Extra log, as the error only shows up with ICE logging enabled
|
||||
log.Infof("Address %s is part of routed network %s, refusing to write", addr, prefix)
|
||||
log.Infof("address %s is part of routed network %s, refusing to write", addr, prefix)
|
||||
return fmt.Errorf("address %s is part of routed network %s, refusing to write", addr, prefix)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -94,6 +94,7 @@ func (c *ConnectClient) RunOnAndroid(
|
||||
dnsAddresses []netip.AddrPort,
|
||||
dnsReadyListener dns.ReadyListener,
|
||||
stateFilePath string,
|
||||
cacheDir string,
|
||||
) error {
|
||||
// in case of non Android os these variables will be nil
|
||||
mobileDependency := MobileDependency{
|
||||
@@ -103,6 +104,7 @@ func (c *ConnectClient) RunOnAndroid(
|
||||
HostDNSAddresses: dnsAddresses,
|
||||
DnsReadyListener: dnsReadyListener,
|
||||
StateFilePath: stateFilePath,
|
||||
TempDir: cacheDir,
|
||||
}
|
||||
return c.run(mobileDependency, nil, "")
|
||||
}
|
||||
@@ -338,6 +340,7 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
|
||||
log.Error(err)
|
||||
return wrapErr(err)
|
||||
}
|
||||
engineConfig.TempDir = mobileDependency.TempDir
|
||||
|
||||
relayManager := relayClient.NewManager(engineCtx, relayURLs, myPrivateKey.PublicKey().String(), engineConfig.MTU)
|
||||
c.statusRecorder.SetRelayMgr(relayManager)
|
||||
|
||||
@@ -16,7 +16,6 @@ import (
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"runtime/pprof"
|
||||
"slices"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
@@ -31,7 +30,6 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal/updater/installer"
|
||||
nbstatus "github.com/netbirdio/netbird/client/status"
|
||||
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
)
|
||||
|
||||
const readmeContent = `Netbird debug bundle
|
||||
@@ -234,6 +232,7 @@ type BundleGenerator struct {
|
||||
statusRecorder *peer.Status
|
||||
syncResponse *mgmProto.SyncResponse
|
||||
logPath string
|
||||
tempDir string
|
||||
cpuProfile []byte
|
||||
refreshStatus func() // Optional callback to refresh status before bundle generation
|
||||
clientMetrics MetricsExporter
|
||||
@@ -256,6 +255,7 @@ type GeneratorDependencies struct {
|
||||
StatusRecorder *peer.Status
|
||||
SyncResponse *mgmProto.SyncResponse
|
||||
LogPath string
|
||||
TempDir string // Directory for temporary bundle zip files. If empty, os.TempDir() is used.
|
||||
CPUProfile []byte
|
||||
RefreshStatus func() // Optional callback to refresh status before bundle generation
|
||||
ClientMetrics MetricsExporter
|
||||
@@ -275,6 +275,7 @@ func NewBundleGenerator(deps GeneratorDependencies, cfg BundleConfig) *BundleGen
|
||||
statusRecorder: deps.StatusRecorder,
|
||||
syncResponse: deps.SyncResponse,
|
||||
logPath: deps.LogPath,
|
||||
tempDir: deps.TempDir,
|
||||
cpuProfile: deps.CPUProfile,
|
||||
refreshStatus: deps.RefreshStatus,
|
||||
clientMetrics: deps.ClientMetrics,
|
||||
@@ -287,7 +288,7 @@ func NewBundleGenerator(deps GeneratorDependencies, cfg BundleConfig) *BundleGen
|
||||
|
||||
// Generate creates a debug bundle and returns the location.
|
||||
func (g *BundleGenerator) Generate() (resp string, err error) {
|
||||
bundlePath, err := os.CreateTemp("", "netbird.debug.*.zip")
|
||||
bundlePath, err := os.CreateTemp(g.tempDir, "netbird.debug.*.zip")
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("create zip file: %w", err)
|
||||
}
|
||||
@@ -373,15 +374,8 @@ func (g *BundleGenerator) createArchive() error {
|
||||
log.Errorf("failed to add wg show output: %v", err)
|
||||
}
|
||||
|
||||
if g.logPath != "" && !slices.Contains(util.SpecialLogs, g.logPath) {
|
||||
if err := g.addLogfile(); err != nil {
|
||||
log.Errorf("failed to add log file to debug bundle: %v", err)
|
||||
if err := g.trySystemdLogFallback(); err != nil {
|
||||
log.Errorf("failed to add systemd logs as fallback: %v", err)
|
||||
}
|
||||
}
|
||||
} else if err := g.trySystemdLogFallback(); err != nil {
|
||||
log.Errorf("failed to add systemd logs: %v", err)
|
||||
if err := g.addPlatformLog(); err != nil {
|
||||
log.Errorf("failed to add logs to debug bundle: %v", err)
|
||||
}
|
||||
|
||||
if err := g.addUpdateLogs(); err != nil {
|
||||
|
||||
41
client/internal/debug/debug_android.go
Normal file
41
client/internal/debug/debug_android.go
Normal file
@@ -0,0 +1,41 @@
|
||||
//go:build android
|
||||
|
||||
package debug
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"os/exec"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
func (g *BundleGenerator) addPlatformLog() error {
|
||||
cmd := exec.Command("/system/bin/logcat", "-d")
|
||||
stdout, err := cmd.StdoutPipe()
|
||||
if err != nil {
|
||||
return fmt.Errorf("logcat stdout pipe: %w", err)
|
||||
}
|
||||
|
||||
if err := cmd.Start(); err != nil {
|
||||
return fmt.Errorf("start logcat: %w", err)
|
||||
}
|
||||
|
||||
var logReader io.Reader = stdout
|
||||
if g.anonymize {
|
||||
var pw *io.PipeWriter
|
||||
logReader, pw = io.Pipe()
|
||||
go anonymizeLog(stdout, pw, g.anonymizer)
|
||||
}
|
||||
|
||||
if err := g.addFileToZip(logReader, "logcat.txt"); err != nil {
|
||||
return fmt.Errorf("add logcat to zip: %w", err)
|
||||
}
|
||||
|
||||
if err := cmd.Wait(); err != nil {
|
||||
return fmt.Errorf("wait logcat: %w", err)
|
||||
}
|
||||
|
||||
log.Debug("added logcat output to debug bundle")
|
||||
return nil
|
||||
}
|
||||
25
client/internal/debug/debug_nonandroid.go
Normal file
25
client/internal/debug/debug_nonandroid.go
Normal file
@@ -0,0 +1,25 @@
|
||||
//go:build !android
|
||||
|
||||
package debug
|
||||
|
||||
import (
|
||||
"slices"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/util"
|
||||
)
|
||||
|
||||
func (g *BundleGenerator) addPlatformLog() error {
|
||||
if g.logPath != "" && !slices.Contains(util.SpecialLogs, g.logPath) {
|
||||
if err := g.addLogfile(); err != nil {
|
||||
log.Errorf("failed to add log file to debug bundle: %v", err)
|
||||
if err := g.trySystemdLogFallback(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
} else if err := g.trySystemdLogFallback(); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -140,6 +140,7 @@ type EngineConfig struct {
|
||||
ProfileConfig *profilemanager.Config
|
||||
|
||||
LogPath string
|
||||
TempDir string
|
||||
}
|
||||
|
||||
// EngineServices holds the external service dependencies required by the Engine.
|
||||
@@ -1095,6 +1096,7 @@ func (e *Engine) handleBundle(params *mgmProto.BundleParameters) (*mgmProto.JobR
|
||||
StatusRecorder: e.statusRecorder,
|
||||
SyncResponse: syncResponse,
|
||||
LogPath: e.config.LogPath,
|
||||
TempDir: e.config.TempDir,
|
||||
ClientMetrics: e.clientMetrics,
|
||||
RefreshStatus: func() {
|
||||
e.RunHealthProbes(true)
|
||||
|
||||
@@ -22,4 +22,8 @@ type MobileDependency struct {
|
||||
DnsManager dns.IosDnsManager
|
||||
FileDescriptor int32
|
||||
StateFilePath string
|
||||
|
||||
// TempDir is a writable directory for temporary files (e.g., debug bundle zip).
|
||||
// On Android, this should be set to the app's cache directory.
|
||||
TempDir string
|
||||
}
|
||||
|
||||
@@ -0,0 +1,10 @@
|
||||
//go:build (dragonfly || freebsd || netbsd || openbsd) && !darwin
|
||||
|
||||
package systemops
|
||||
|
||||
// Non-darwin BSDs don't support the IP_BOUND_IF + scoped default model. They
|
||||
// always fall through to the ref-counter exclusion-route path; these stubs
|
||||
// exist only so systemops_unix.go compiles.
|
||||
func (r *SysOps) setupAdvancedRouting() error { return nil }
|
||||
func (r *SysOps) cleanupAdvancedRouting() error { return nil }
|
||||
func (r *SysOps) flushPlatformExtras() error { return nil }
|
||||
241
client/internal/routemanager/systemops/systemops_darwin.go
Normal file
241
client/internal/routemanager/systemops/systemops_darwin.go
Normal file
@@ -0,0 +1,241 @@
|
||||
//go:build darwin && !ios
|
||||
|
||||
package systemops
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/go-multierror"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/net/route"
|
||||
"golang.org/x/sys/unix"
|
||||
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/vars"
|
||||
nbnet "github.com/netbirdio/netbird/client/net"
|
||||
)
|
||||
|
||||
// scopedRouteBudget bounds retries for the scoped default route. Installing or
|
||||
// deleting it matters enough that we're willing to spend longer waiting for the
|
||||
// kernel reply than for per-prefix exclusion routes.
|
||||
const scopedRouteBudget = 5 * time.Second
|
||||
|
||||
// setupAdvancedRouting installs an RTF_IFSCOPE default route per address family
|
||||
// pinned to the current physical egress, so IP_BOUND_IF scoped lookups can
|
||||
// resolve gateway'd destinations while the VPN's split default owns the
|
||||
// unscoped table.
|
||||
//
|
||||
// Timing note: this runs during routeManager.Init, which happens before the
|
||||
// VPN interface is created and before any peer routes propagate. The initial
|
||||
// mgmt / signal / relay TCP dials always fire before this runs, so those
|
||||
// sockets miss the IP_BOUND_IF binding and rely on the kernel's normal route
|
||||
// lookup, which at that point correctly picks the physical default. Those
|
||||
// already-established TCP flows keep their originally-selected interface for
|
||||
// their lifetime on Darwin because the kernel caches the egress route
|
||||
// per-socket at connect time; adding the VPN's 0/1 + 128/1 split default
|
||||
// afterwards does not migrate them since the original en0 default stays in
|
||||
// the table. Any subsequent reconnect via nbnet.NewDialer picks up the
|
||||
// populated bound-iface cache and gets IP_BOUND_IF set cleanly.
|
||||
func (r *SysOps) setupAdvancedRouting() error {
|
||||
// Drop any previously-cached egress interface before reinstalling. On a
|
||||
// refresh, a family that no longer resolves would otherwise keep the stale
|
||||
// binding, causing new sockets to scope to an interface without a matching
|
||||
// scoped default.
|
||||
nbnet.ClearBoundInterfaces()
|
||||
|
||||
if err := r.flushScopedDefaults(); err != nil {
|
||||
log.Warnf("flush residual scoped defaults: %v", err)
|
||||
}
|
||||
|
||||
var merr *multierror.Error
|
||||
installed := 0
|
||||
|
||||
for _, unspec := range []netip.Addr{netip.IPv4Unspecified(), netip.IPv6Unspecified()} {
|
||||
ok, err := r.installScopedDefaultFor(unspec)
|
||||
if err != nil {
|
||||
merr = multierror.Append(merr, err)
|
||||
continue
|
||||
}
|
||||
if ok {
|
||||
installed++
|
||||
}
|
||||
}
|
||||
|
||||
if installed == 0 && merr != nil {
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
if merr != nil {
|
||||
log.Warnf("advanced routing setup partially succeeded: %v", nberrors.FormatErrorOrNil(merr))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// installScopedDefaultFor resolves the physical default nexthop for the given
|
||||
// address family, installs a scoped default via it, and caches the iface for
|
||||
// subsequent IP_BOUND_IF / IPV6_BOUND_IF socket binds.
|
||||
func (r *SysOps) installScopedDefaultFor(unspec netip.Addr) (bool, error) {
|
||||
nexthop, err := GetNextHop(unspec)
|
||||
if err != nil {
|
||||
if errors.Is(err, vars.ErrRouteNotFound) {
|
||||
return false, nil
|
||||
}
|
||||
return false, fmt.Errorf("get default nexthop for %s: %w", unspec, err)
|
||||
}
|
||||
if nexthop.Intf == nil {
|
||||
return false, fmt.Errorf("unusable default nexthop for %s (no interface)", unspec)
|
||||
}
|
||||
|
||||
if err := r.addScopedDefault(unspec, nexthop); err != nil {
|
||||
return false, fmt.Errorf("add scoped default on %s: %w", nexthop.Intf.Name, err)
|
||||
}
|
||||
|
||||
af := unix.AF_INET
|
||||
if unspec.Is6() {
|
||||
af = unix.AF_INET6
|
||||
}
|
||||
nbnet.SetBoundInterface(af, nexthop.Intf)
|
||||
via := "point-to-point"
|
||||
if nexthop.IP.IsValid() {
|
||||
via = nexthop.IP.String()
|
||||
}
|
||||
log.Infof("installed scoped default route via %s on %s for %s", via, nexthop.Intf.Name, afOf(unspec))
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func (r *SysOps) cleanupAdvancedRouting() error {
|
||||
nbnet.ClearBoundInterfaces()
|
||||
return r.flushScopedDefaults()
|
||||
}
|
||||
|
||||
// flushPlatformExtras runs darwin-specific residual cleanup hooked into the
|
||||
// generic FlushMarkedRoutes path, so a crashed daemon's scoped defaults get
|
||||
// removed on the next boot regardless of whether a profile is brought up.
|
||||
func (r *SysOps) flushPlatformExtras() error {
|
||||
return r.flushScopedDefaults()
|
||||
}
|
||||
|
||||
// flushScopedDefaults removes any scoped default routes tagged with routeProtoFlag.
|
||||
// Safe to call at startup to clear residual entries from a prior session.
|
||||
func (r *SysOps) flushScopedDefaults() error {
|
||||
rib, err := retryFetchRIB()
|
||||
if err != nil {
|
||||
return fmt.Errorf("fetch routing table: %w", err)
|
||||
}
|
||||
|
||||
msgs, err := route.ParseRIB(route.RIBTypeRoute, rib)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parse routing table: %w", err)
|
||||
}
|
||||
|
||||
var merr *multierror.Error
|
||||
removed := 0
|
||||
|
||||
for _, msg := range msgs {
|
||||
rtMsg, ok := msg.(*route.RouteMessage)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if rtMsg.Flags&routeProtoFlag == 0 {
|
||||
continue
|
||||
}
|
||||
if rtMsg.Flags&unix.RTF_IFSCOPE == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
info, err := MsgToRoute(rtMsg)
|
||||
if err != nil {
|
||||
log.Debugf("skip scoped flush: %v", err)
|
||||
continue
|
||||
}
|
||||
if !info.Dst.IsValid() || info.Dst.Bits() != 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
if err := r.deleteScopedRoute(rtMsg); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("delete scoped default %s on index %d: %w",
|
||||
info.Dst, rtMsg.Index, err))
|
||||
continue
|
||||
}
|
||||
removed++
|
||||
log.Debugf("flushed residual scoped default %s on index %d", info.Dst, rtMsg.Index)
|
||||
}
|
||||
|
||||
if removed > 0 {
|
||||
log.Infof("flushed %d residual scoped default route(s)", removed)
|
||||
}
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
func (r *SysOps) addScopedDefault(unspec netip.Addr, nexthop Nexthop) error {
|
||||
return r.scopedRouteSocket(unix.RTM_ADD, unspec, nexthop)
|
||||
}
|
||||
|
||||
func (r *SysOps) deleteScopedRoute(rtMsg *route.RouteMessage) error {
|
||||
// Preserve identifying flags from the stored route (including RTF_GATEWAY
|
||||
// only if present); kernel-set bits like RTF_DONE don't belong on RTM_DELETE.
|
||||
keep := unix.RTF_UP | unix.RTF_STATIC | unix.RTF_GATEWAY | unix.RTF_IFSCOPE | routeProtoFlag
|
||||
del := &route.RouteMessage{
|
||||
Type: unix.RTM_DELETE,
|
||||
Flags: rtMsg.Flags & keep,
|
||||
Version: unix.RTM_VERSION,
|
||||
Seq: r.getSeq(),
|
||||
Index: rtMsg.Index,
|
||||
Addrs: rtMsg.Addrs,
|
||||
}
|
||||
return r.writeRouteMessage(del, scopedRouteBudget)
|
||||
}
|
||||
|
||||
func (r *SysOps) scopedRouteSocket(action int, unspec netip.Addr, nexthop Nexthop) error {
|
||||
flags := unix.RTF_UP | unix.RTF_STATIC | unix.RTF_IFSCOPE | routeProtoFlag
|
||||
|
||||
msg := &route.RouteMessage{
|
||||
Type: action,
|
||||
Flags: flags,
|
||||
Version: unix.RTM_VERSION,
|
||||
ID: uintptr(os.Getpid()),
|
||||
Seq: r.getSeq(),
|
||||
Index: nexthop.Intf.Index,
|
||||
}
|
||||
|
||||
const numAddrs = unix.RTAX_NETMASK + 1
|
||||
addrs := make([]route.Addr, numAddrs)
|
||||
|
||||
dst, err := addrToRouteAddr(unspec)
|
||||
if err != nil {
|
||||
return fmt.Errorf("build destination: %w", err)
|
||||
}
|
||||
mask, err := prefixToRouteNetmask(netip.PrefixFrom(unspec, 0))
|
||||
if err != nil {
|
||||
return fmt.Errorf("build netmask: %w", err)
|
||||
}
|
||||
addrs[unix.RTAX_DST] = dst
|
||||
addrs[unix.RTAX_NETMASK] = mask
|
||||
|
||||
if nexthop.IP.IsValid() {
|
||||
msg.Flags |= unix.RTF_GATEWAY
|
||||
gw, err := addrToRouteAddr(nexthop.IP.Unmap())
|
||||
if err != nil {
|
||||
return fmt.Errorf("build gateway: %w", err)
|
||||
}
|
||||
addrs[unix.RTAX_GATEWAY] = gw
|
||||
} else {
|
||||
addrs[unix.RTAX_GATEWAY] = &route.LinkAddr{
|
||||
Index: nexthop.Intf.Index,
|
||||
Name: nexthop.Intf.Name,
|
||||
}
|
||||
}
|
||||
msg.Addrs = addrs
|
||||
|
||||
return r.writeRouteMessage(msg, scopedRouteBudget)
|
||||
}
|
||||
|
||||
func afOf(a netip.Addr) string {
|
||||
if a.Is4() {
|
||||
return "IPv4"
|
||||
}
|
||||
return "IPv6"
|
||||
}
|
||||
@@ -21,6 +21,7 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/util"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/vars"
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
nbnet "github.com/netbirdio/netbird/client/net"
|
||||
"github.com/netbirdio/netbird/client/net/hooks"
|
||||
)
|
||||
|
||||
@@ -31,8 +32,6 @@ var splitDefaultv4_2 = netip.PrefixFrom(netip.AddrFrom4([4]byte{128}), 1)
|
||||
var splitDefaultv6_1 = netip.PrefixFrom(netip.IPv6Unspecified(), 1)
|
||||
var splitDefaultv6_2 = netip.PrefixFrom(netip.AddrFrom16([16]byte{0x80}), 1)
|
||||
|
||||
var ErrRoutingIsSeparate = errors.New("routing is separate")
|
||||
|
||||
func (r *SysOps) setupRefCounter(initAddresses []net.IP, stateManager *statemanager.Manager) error {
|
||||
stateManager.RegisterState(&ShutdownState{})
|
||||
|
||||
@@ -397,12 +396,16 @@ func ipToAddr(ip net.IP, intf *net.Interface) (netip.Addr, error) {
|
||||
}
|
||||
|
||||
// IsAddrRouted checks if the candidate address would route to the vpn, in which case it returns true and the matched prefix.
|
||||
// When advanced routing is active the WG socket is bound to the physical interface (fwmark on linux,
|
||||
// IP_UNICAST_IF on windows, IP_BOUND_IF on darwin) and bypasses the main routing table, so the check is skipped.
|
||||
func IsAddrRouted(addr netip.Addr, vpnRoutes []netip.Prefix) (bool, netip.Prefix) {
|
||||
localRoutes, err := hasSeparateRouting()
|
||||
if nbnet.AdvancedRouting() {
|
||||
return false, netip.Prefix{}
|
||||
}
|
||||
|
||||
localRoutes, err := GetRoutesFromTable()
|
||||
if err != nil {
|
||||
if !errors.Is(err, ErrRoutingIsSeparate) {
|
||||
log.Errorf("Failed to get routes: %v", err)
|
||||
}
|
||||
log.Errorf("Failed to get routes: %v", err)
|
||||
return false, netip.Prefix{}
|
||||
}
|
||||
|
||||
|
||||
@@ -22,10 +22,6 @@ func GetRoutesFromTable() ([]netip.Prefix, error) {
|
||||
return []netip.Prefix{}, nil
|
||||
}
|
||||
|
||||
func hasSeparateRouting() ([]netip.Prefix, error) {
|
||||
return []netip.Prefix{}, nil
|
||||
}
|
||||
|
||||
// GetDetailedRoutesFromTable returns empty routes for WASM.
|
||||
func GetDetailedRoutesFromTable() ([]DetailedRoute, error) {
|
||||
return []DetailedRoute{}, nil
|
||||
|
||||
@@ -894,13 +894,6 @@ func getAddressFamily(prefix netip.Prefix) int {
|
||||
return netlink.FAMILY_V6
|
||||
}
|
||||
|
||||
func hasSeparateRouting() ([]netip.Prefix, error) {
|
||||
if !nbnet.AdvancedRouting() {
|
||||
return GetRoutesFromTable()
|
||||
}
|
||||
return nil, ErrRoutingIsSeparate
|
||||
}
|
||||
|
||||
func isOpErr(err error) bool {
|
||||
// EAFTNOSUPPORT when ipv6 is disabled via sysctl, EOPNOTSUPP when disabled in boot options or otherwise not supported
|
||||
if errors.Is(err, syscall.EAFNOSUPPORT) || errors.Is(err, syscall.EOPNOTSUPP) {
|
||||
|
||||
@@ -48,10 +48,6 @@ func EnableIPForwarding() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func hasSeparateRouting() ([]netip.Prefix, error) {
|
||||
return GetRoutesFromTable()
|
||||
}
|
||||
|
||||
// GetIPRules returns IP rules for debugging (not supported on non-Linux platforms)
|
||||
func GetIPRules() ([]IPRule, error) {
|
||||
log.Infof("IP rules collection is not supported on %s", runtime.GOOS)
|
||||
|
||||
@@ -25,6 +25,9 @@ import (
|
||||
|
||||
const (
|
||||
envRouteProtoFlag = "NB_ROUTE_PROTO_FLAG"
|
||||
|
||||
// routeBudget bounds retries for per-prefix exclusion route programming.
|
||||
routeBudget = 1 * time.Second
|
||||
)
|
||||
|
||||
var routeProtoFlag int
|
||||
@@ -41,26 +44,42 @@ func init() {
|
||||
}
|
||||
|
||||
func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager, advancedRouting bool) error {
|
||||
if advancedRouting {
|
||||
return r.setupAdvancedRouting()
|
||||
}
|
||||
|
||||
log.Infof("Using legacy routing setup with ref counters")
|
||||
return r.setupRefCounter(initAddresses, stateManager)
|
||||
}
|
||||
|
||||
func (r *SysOps) CleanupRouting(stateManager *statemanager.Manager, advancedRouting bool) error {
|
||||
if advancedRouting {
|
||||
return r.cleanupAdvancedRouting()
|
||||
}
|
||||
|
||||
return r.cleanupRefCounter(stateManager)
|
||||
}
|
||||
|
||||
// FlushMarkedRoutes removes single IP exclusion routes marked with the configured RTF_PROTO flag.
|
||||
// On darwin it also flushes residual RTF_IFSCOPE scoped default routes so a
|
||||
// crashed prior session can't leave crud in the table.
|
||||
func (r *SysOps) FlushMarkedRoutes() error {
|
||||
var merr *multierror.Error
|
||||
|
||||
if err := r.flushPlatformExtras(); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("flush platform extras: %w", err))
|
||||
}
|
||||
|
||||
rib, err := retryFetchRIB()
|
||||
if err != nil {
|
||||
return fmt.Errorf("fetch routing table: %w", err)
|
||||
return nberrors.FormatErrorOrNil(multierror.Append(merr, fmt.Errorf("fetch routing table: %w", err)))
|
||||
}
|
||||
|
||||
msgs, err := route.ParseRIB(route.RIBTypeRoute, rib)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parse routing table: %w", err)
|
||||
return nberrors.FormatErrorOrNil(multierror.Append(merr, fmt.Errorf("parse routing table: %w", err)))
|
||||
}
|
||||
|
||||
var merr *multierror.Error
|
||||
flushedCount := 0
|
||||
|
||||
for _, msg := range msgs {
|
||||
@@ -117,12 +136,12 @@ func (r *SysOps) routeSocket(action int, prefix netip.Prefix, nexthop Nexthop) e
|
||||
return fmt.Errorf("invalid prefix: %s", prefix)
|
||||
}
|
||||
|
||||
expBackOff := backoff.NewExponentialBackOff()
|
||||
expBackOff.InitialInterval = 50 * time.Millisecond
|
||||
expBackOff.MaxInterval = 500 * time.Millisecond
|
||||
expBackOff.MaxElapsedTime = 1 * time.Second
|
||||
msg, err := r.buildRouteMessage(action, prefix, nexthop)
|
||||
if err != nil {
|
||||
return fmt.Errorf("build route message: %w", err)
|
||||
}
|
||||
|
||||
if err := backoff.Retry(r.routeOp(action, prefix, nexthop), expBackOff); err != nil {
|
||||
if err := r.writeRouteMessage(msg, routeBudget); err != nil {
|
||||
a := "add"
|
||||
if action == unix.RTM_DELETE {
|
||||
a = "remove"
|
||||
@@ -132,50 +151,91 @@ func (r *SysOps) routeSocket(action int, prefix netip.Prefix, nexthop Nexthop) e
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *SysOps) routeOp(action int, prefix netip.Prefix, nexthop Nexthop) func() error {
|
||||
operation := func() error {
|
||||
fd, err := unix.Socket(syscall.AF_ROUTE, syscall.SOCK_RAW, syscall.AF_UNSPEC)
|
||||
if err != nil {
|
||||
return fmt.Errorf("open routing socket: %w", err)
|
||||
// writeRouteMessage sends a route message over AF_ROUTE and waits for the
|
||||
// kernel's matching reply, retrying transient failures until budget elapses.
|
||||
// Callers do not need to manage sockets or seq numbers themselves.
|
||||
func (r *SysOps) writeRouteMessage(msg *route.RouteMessage, budget time.Duration) error {
|
||||
expBackOff := backoff.NewExponentialBackOff()
|
||||
expBackOff.InitialInterval = 50 * time.Millisecond
|
||||
expBackOff.MaxInterval = 500 * time.Millisecond
|
||||
expBackOff.MaxElapsedTime = budget
|
||||
|
||||
return backoff.Retry(func() error { return routeMessageRoundtrip(msg) }, expBackOff)
|
||||
}
|
||||
|
||||
func routeMessageRoundtrip(msg *route.RouteMessage) error {
|
||||
fd, err := unix.Socket(syscall.AF_ROUTE, syscall.SOCK_RAW, syscall.AF_UNSPEC)
|
||||
if err != nil {
|
||||
return fmt.Errorf("open routing socket: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
if err := unix.Close(fd); err != nil && !errors.Is(err, unix.EBADF) {
|
||||
log.Warnf("close routing socket: %v", err)
|
||||
}
|
||||
defer func() {
|
||||
if err := unix.Close(fd); err != nil && !errors.Is(err, unix.EBADF) {
|
||||
log.Warnf("failed to close routing socket: %v", err)
|
||||
}()
|
||||
|
||||
tv := unix.Timeval{Sec: 1}
|
||||
if err := unix.SetsockoptTimeval(fd, unix.SOL_SOCKET, unix.SO_RCVTIMEO, &tv); err != nil {
|
||||
return backoff.Permanent(fmt.Errorf("set recv timeout: %w", err))
|
||||
}
|
||||
|
||||
// AF_ROUTE is a broadcast channel: every route socket on the host sees
|
||||
// every RTM_* event. With concurrent route programming the default
|
||||
// per-socket queue overflows and our own reply gets dropped.
|
||||
if err := unix.SetsockoptInt(fd, unix.SOL_SOCKET, unix.SO_RCVBUF, 1<<20); err != nil {
|
||||
log.Debugf("set SO_RCVBUF on route socket: %v", err)
|
||||
}
|
||||
|
||||
bytes, err := msg.Marshal()
|
||||
if err != nil {
|
||||
return backoff.Permanent(fmt.Errorf("marshal: %w", err))
|
||||
}
|
||||
|
||||
if _, err = unix.Write(fd, bytes); err != nil {
|
||||
if errors.Is(err, unix.ENOBUFS) || errors.Is(err, unix.EAGAIN) {
|
||||
return fmt.Errorf("write: %w", err)
|
||||
}
|
||||
return backoff.Permanent(fmt.Errorf("write: %w", err))
|
||||
}
|
||||
return readRouteResponse(fd, msg.Type, msg.Seq)
|
||||
}
|
||||
|
||||
// readRouteResponse reads from the AF_ROUTE socket until it sees a reply
|
||||
// matching our write (same type, seq, and pid). AF_ROUTE SOCK_RAW is a
|
||||
// broadcast channel: interface up/down, third-party route changes and neighbor
|
||||
// discovery events can all land between our write and read, so we must filter.
|
||||
func readRouteResponse(fd, wantType, wantSeq int) error {
|
||||
pid := int32(os.Getpid())
|
||||
resp := make([]byte, 2048)
|
||||
deadline := time.Now().Add(time.Second)
|
||||
for {
|
||||
if time.Now().After(deadline) {
|
||||
// Transient: under concurrent pressure the kernel can drop our reply
|
||||
// from the socket buffer. Let backoff.Retry re-send with a fresh seq.
|
||||
return fmt.Errorf("read: timeout waiting for route reply type=%d seq=%d", wantType, wantSeq)
|
||||
}
|
||||
n, err := unix.Read(fd, resp)
|
||||
if err != nil {
|
||||
if errors.Is(err, unix.EAGAIN) || errors.Is(err, unix.EWOULDBLOCK) {
|
||||
// SO_RCVTIMEO fired while waiting; loop to re-check the absolute deadline.
|
||||
continue
|
||||
}
|
||||
}()
|
||||
|
||||
msg, err := r.buildRouteMessage(action, prefix, nexthop)
|
||||
if err != nil {
|
||||
return backoff.Permanent(fmt.Errorf("build route message: %w", err))
|
||||
return backoff.Permanent(fmt.Errorf("read: %w", err))
|
||||
}
|
||||
|
||||
msgBytes, err := msg.Marshal()
|
||||
if err != nil {
|
||||
return backoff.Permanent(fmt.Errorf("marshal route message: %w", err))
|
||||
if n < int(unsafe.Sizeof(unix.RtMsghdr{})) {
|
||||
continue
|
||||
}
|
||||
|
||||
if _, err = unix.Write(fd, msgBytes); err != nil {
|
||||
if errors.Is(err, unix.ENOBUFS) || errors.Is(err, unix.EAGAIN) {
|
||||
return fmt.Errorf("write: %w", err)
|
||||
}
|
||||
return backoff.Permanent(fmt.Errorf("write: %w", err))
|
||||
hdr := (*unix.RtMsghdr)(unsafe.Pointer(&resp[0]))
|
||||
// Darwin reflects the sender's pid on replies; matching (Type, Seq, Pid)
|
||||
// uniquely identifies our own reply among broadcast traffic.
|
||||
if int(hdr.Type) != wantType || int(hdr.Seq) != wantSeq || hdr.Pid != pid {
|
||||
continue
|
||||
}
|
||||
|
||||
respBuf := make([]byte, 2048)
|
||||
n, err := unix.Read(fd, respBuf)
|
||||
if err != nil {
|
||||
return backoff.Permanent(fmt.Errorf("read route response: %w", err))
|
||||
if hdr.Errno != 0 {
|
||||
return backoff.Permanent(fmt.Errorf("kernel: %w", syscall.Errno(hdr.Errno)))
|
||||
}
|
||||
|
||||
if n > 0 {
|
||||
if err := r.parseRouteResponse(respBuf[:n]); err != nil {
|
||||
return backoff.Permanent(err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
return operation
|
||||
}
|
||||
|
||||
func (r *SysOps) buildRouteMessage(action int, prefix netip.Prefix, nexthop Nexthop) (msg *route.RouteMessage, err error) {
|
||||
@@ -183,6 +243,7 @@ func (r *SysOps) buildRouteMessage(action int, prefix netip.Prefix, nexthop Next
|
||||
Type: action,
|
||||
Flags: unix.RTF_UP | routeProtoFlag,
|
||||
Version: unix.RTM_VERSION,
|
||||
ID: uintptr(os.Getpid()),
|
||||
Seq: r.getSeq(),
|
||||
}
|
||||
|
||||
@@ -221,19 +282,6 @@ func (r *SysOps) buildRouteMessage(action int, prefix netip.Prefix, nexthop Next
|
||||
return msg, nil
|
||||
}
|
||||
|
||||
func (r *SysOps) parseRouteResponse(buf []byte) error {
|
||||
if len(buf) < int(unsafe.Sizeof(unix.RtMsghdr{})) {
|
||||
return nil
|
||||
}
|
||||
|
||||
rtMsg := (*unix.RtMsghdr)(unsafe.Pointer(&buf[0]))
|
||||
if rtMsg.Errno != 0 {
|
||||
return fmt.Errorf("parse: %d", rtMsg.Errno)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// addrToRouteAddr converts a netip.Addr to the appropriate route.Addr (*route.Inet4Addr or *route.Inet6Addr).
|
||||
func addrToRouteAddr(addr netip.Addr) (route.Addr, error) {
|
||||
if addr.Is4() {
|
||||
|
||||
5
client/net/dialer_init_darwin.go
Normal file
5
client/net/dialer_init_darwin.go
Normal file
@@ -0,0 +1,5 @@
|
||||
package net
|
||||
|
||||
func (d *Dialer) init() {
|
||||
d.Dialer.Control = applyBoundIfToSocket
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
//go:build !linux && !windows
|
||||
//go:build !linux && !windows && !darwin
|
||||
|
||||
package net
|
||||
|
||||
|
||||
@@ -1,24 +0,0 @@
|
||||
//go:build android
|
||||
|
||||
package net
|
||||
|
||||
// Init initializes the network environment for Android
|
||||
func Init() {
|
||||
// No initialization needed on Android
|
||||
}
|
||||
|
||||
// AdvancedRouting reports whether routing loops can be avoided without using exclusion routes.
|
||||
// Always returns true on Android since we cannot handle routes dynamically.
|
||||
func AdvancedRouting() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
// SetVPNInterfaceName is a no-op on Android
|
||||
func SetVPNInterfaceName(name string) {
|
||||
// No-op on Android - not needed for Android VPN service
|
||||
}
|
||||
|
||||
// GetVPNInterfaceName returns empty string on Android
|
||||
func GetVPNInterfaceName() string {
|
||||
return ""
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
//go:build windows
|
||||
//go:build (darwin && !ios) || windows
|
||||
|
||||
package net
|
||||
|
||||
@@ -24,17 +24,22 @@ func Init() {
|
||||
}
|
||||
|
||||
func checkAdvancedRoutingSupport() bool {
|
||||
var err error
|
||||
var legacyRouting bool
|
||||
legacyRouting := false
|
||||
if val := os.Getenv(envUseLegacyRouting); val != "" {
|
||||
legacyRouting, err = strconv.ParseBool(val)
|
||||
parsed, err := strconv.ParseBool(val)
|
||||
if err != nil {
|
||||
log.Warnf("failed to parse %s: %v", envUseLegacyRouting, err)
|
||||
log.Warnf("ignoring unparsable %s=%q: %v", envUseLegacyRouting, val, err)
|
||||
} else {
|
||||
legacyRouting = parsed
|
||||
}
|
||||
}
|
||||
|
||||
if legacyRouting || netstack.IsEnabled() {
|
||||
log.Info("advanced routing has been requested to be disabled")
|
||||
if legacyRouting {
|
||||
log.Infof("advanced routing disabled: legacy routing requested via %s", envUseLegacyRouting)
|
||||
return false
|
||||
}
|
||||
if netstack.IsEnabled() {
|
||||
log.Info("advanced routing disabled: netstack mode is enabled")
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
//go:build !linux && !windows && !android
|
||||
//go:build !linux && !windows && !darwin
|
||||
|
||||
package net
|
||||
|
||||
|
||||
25
client/net/env_mobile.go
Normal file
25
client/net/env_mobile.go
Normal file
@@ -0,0 +1,25 @@
|
||||
//go:build ios || android
|
||||
|
||||
package net
|
||||
|
||||
// Init initializes the network environment for mobile platforms.
|
||||
func Init() {
|
||||
// no-op on mobile: routing scope is owned by the VPN extension.
|
||||
}
|
||||
|
||||
// AdvancedRouting reports whether routing loops can be avoided without using exclusion routes.
|
||||
// Always returns true on mobile since routes cannot be handled dynamically and the VPN extension
|
||||
// owns the routing scope.
|
||||
func AdvancedRouting() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
// SetVPNInterfaceName is a no-op on mobile.
|
||||
func SetVPNInterfaceName(string) {
|
||||
// no-op on mobile: the VPN extension manages the interface.
|
||||
}
|
||||
|
||||
// GetVPNInterfaceName returns an empty string on mobile.
|
||||
func GetVPNInterfaceName() string {
|
||||
return ""
|
||||
}
|
||||
5
client/net/listener_init_darwin.go
Normal file
5
client/net/listener_init_darwin.go
Normal file
@@ -0,0 +1,5 @@
|
||||
package net
|
||||
|
||||
func (l *ListenerConfig) init() {
|
||||
l.ListenConfig.Control = applyBoundIfToSocket
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
//go:build !linux && !windows
|
||||
//go:build !linux && !windows && !darwin
|
||||
|
||||
package net
|
||||
|
||||
|
||||
160
client/net/net_darwin.go
Normal file
160
client/net/net_darwin.go
Normal file
@@ -0,0 +1,160 @@
|
||||
package net
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"syscall"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
// On darwin IPV6_BOUND_IF also scopes v4-mapped egress from dual-stack
|
||||
// (IPV6_V6ONLY=0) AF_INET6 sockets, so a single setsockopt on "udp6"/"tcp6"
|
||||
// covers both families. Setting IP_BOUND_IF on an AF_INET6 socket returns
|
||||
// EINVAL regardless of V6ONLY because the IPPROTO_IP ctloutput path is
|
||||
// dispatched by socket domain (AF_INET only) not by inp_vflag.
|
||||
|
||||
// boundIface holds the physical interface chosen at routing setup time. Sockets
|
||||
// created via nbnet.NewDialer / nbnet.NewListener bind to it via IP_BOUND_IF
|
||||
// (IPv4) or IPV6_BOUND_IF (IPv6 / dual-stack) so their scoped route lookup
|
||||
// hits the RTF_IFSCOPE default installed by the routemanager, rather than
|
||||
// following the VPN's split default.
|
||||
var (
|
||||
boundIfaceMu sync.RWMutex
|
||||
boundIface4 *net.Interface
|
||||
boundIface6 *net.Interface
|
||||
)
|
||||
|
||||
// SetBoundInterface records the egress interface for an address family. Called
|
||||
// by the routemanager after a scoped default route has been installed.
|
||||
// af must be unix.AF_INET or unix.AF_INET6; other values are ignored.
|
||||
// nil iface is rejected — use ClearBoundInterfaces to clear all slots.
|
||||
func SetBoundInterface(af int, iface *net.Interface) {
|
||||
if iface == nil {
|
||||
log.Warnf("SetBoundInterface: nil iface for AF %d, ignored", af)
|
||||
return
|
||||
}
|
||||
boundIfaceMu.Lock()
|
||||
defer boundIfaceMu.Unlock()
|
||||
switch af {
|
||||
case unix.AF_INET:
|
||||
boundIface4 = iface
|
||||
case unix.AF_INET6:
|
||||
boundIface6 = iface
|
||||
default:
|
||||
log.Warnf("SetBoundInterface: unsupported address family %d", af)
|
||||
}
|
||||
}
|
||||
|
||||
// ClearBoundInterfaces resets the cached egress interfaces. Called by the
|
||||
// routemanager during cleanup.
|
||||
func ClearBoundInterfaces() {
|
||||
boundIfaceMu.Lock()
|
||||
defer boundIfaceMu.Unlock()
|
||||
boundIface4 = nil
|
||||
boundIface6 = nil
|
||||
}
|
||||
|
||||
// boundInterfaceFor returns the cached egress interface for a socket's address
|
||||
// family, falling back to the other family if the preferred slot is empty.
|
||||
// The kernel stores both IP_BOUND_IF and IPV6_BOUND_IF in inp_boundifp, so
|
||||
// either setsockopt scopes the socket; preferring same-family still matters
|
||||
// when v4 and v6 defaults egress different NICs.
|
||||
func boundInterfaceFor(network, address string) *net.Interface {
|
||||
if iface := zoneInterface(address); iface != nil {
|
||||
return iface
|
||||
}
|
||||
|
||||
boundIfaceMu.RLock()
|
||||
defer boundIfaceMu.RUnlock()
|
||||
|
||||
primary, secondary := boundIface4, boundIface6
|
||||
if isV6Network(network) {
|
||||
primary, secondary = boundIface6, boundIface4
|
||||
}
|
||||
if primary != nil {
|
||||
return primary
|
||||
}
|
||||
return secondary
|
||||
}
|
||||
|
||||
func isV6Network(network string) bool {
|
||||
return strings.HasSuffix(network, "6")
|
||||
}
|
||||
|
||||
// zoneInterface extracts an explicit interface from an IPv6 link-local zone (e.g. fe80::1%en0).
|
||||
func zoneInterface(address string) *net.Interface {
|
||||
if address == "" {
|
||||
return nil
|
||||
}
|
||||
addr, err := netip.ParseAddrPort(address)
|
||||
if err != nil {
|
||||
a, err := netip.ParseAddr(address)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
addr = netip.AddrPortFrom(a, 0)
|
||||
}
|
||||
zone := addr.Addr().Zone()
|
||||
if zone == "" {
|
||||
return nil
|
||||
}
|
||||
if iface, err := net.InterfaceByName(zone); err == nil {
|
||||
return iface
|
||||
}
|
||||
if idx, err := strconv.Atoi(zone); err == nil {
|
||||
if iface, err := net.InterfaceByIndex(idx); err == nil {
|
||||
return iface
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func setIPv4BoundIf(fd uintptr, iface *net.Interface) error {
|
||||
if err := unix.SetsockoptInt(int(fd), unix.IPPROTO_IP, unix.IP_BOUND_IF, iface.Index); err != nil {
|
||||
return fmt.Errorf("set IP_BOUND_IF: %w (interface: %s, index: %d)", err, iface.Name, iface.Index)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func setIPv6BoundIf(fd uintptr, iface *net.Interface) error {
|
||||
if err := unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_BOUND_IF, iface.Index); err != nil {
|
||||
return fmt.Errorf("set IPV6_BOUND_IF: %w (interface: %s, index: %d)", err, iface.Name, iface.Index)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// applyBoundIfToSocket binds the socket to the cached physical egress interface
|
||||
// so scoped route lookup avoids the VPN utun and egresses the underlay directly.
|
||||
func applyBoundIfToSocket(network, address string, c syscall.RawConn) error {
|
||||
if !AdvancedRouting() {
|
||||
return nil
|
||||
}
|
||||
|
||||
iface := boundInterfaceFor(network, address)
|
||||
if iface == nil {
|
||||
log.Debugf("no bound iface cached for %s to %s, skipping BOUND_IF", network, address)
|
||||
return nil
|
||||
}
|
||||
|
||||
isV6 := isV6Network(network)
|
||||
var controlErr error
|
||||
if err := c.Control(func(fd uintptr) {
|
||||
if isV6 {
|
||||
controlErr = setIPv6BoundIf(fd, iface)
|
||||
} else {
|
||||
controlErr = setIPv4BoundIf(fd, iface)
|
||||
}
|
||||
if controlErr == nil {
|
||||
log.Debugf("set BOUND_IF=%d on %s for %s to %s", iface.Index, iface.Name, network, address)
|
||||
}
|
||||
}); err != nil {
|
||||
return fmt.Errorf("control: %w", err)
|
||||
}
|
||||
return controlErr
|
||||
}
|
||||
@@ -12,7 +12,6 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
nbnet "github.com/netbirdio/netbird/client/net"
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
)
|
||||
|
||||
@@ -138,10 +137,8 @@ func restoreResidualState(ctx context.Context, statePath string) error {
|
||||
}
|
||||
|
||||
// clean up any remaining routes independently of the state file
|
||||
if !nbnet.AdvancedRouting() {
|
||||
if err := systemops.New(nil, nil).FlushMarkedRoutes(); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("flush marked routes: %w", err))
|
||||
}
|
||||
if err := systemops.New(nil, nil).FlushMarkedRoutes(); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("flush marked routes: %w", err))
|
||||
}
|
||||
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
|
||||
@@ -25,6 +25,7 @@ import (
|
||||
nbssh "github.com/netbirdio/netbird/client/ssh"
|
||||
"github.com/netbirdio/netbird/client/ssh/detection"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
"github.com/netbirdio/netbird/util/netrelay"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -536,7 +537,7 @@ func (c *Client) LocalPortForward(ctx context.Context, localAddr, remoteAddr str
|
||||
continue
|
||||
}
|
||||
|
||||
go c.handleLocalForward(localConn, remoteAddr)
|
||||
go c.handleLocalForward(ctx, localConn, remoteAddr)
|
||||
}
|
||||
}()
|
||||
|
||||
@@ -548,7 +549,7 @@ func (c *Client) LocalPortForward(ctx context.Context, localAddr, remoteAddr str
|
||||
}
|
||||
|
||||
// handleLocalForward handles a single local port forwarding connection
|
||||
func (c *Client) handleLocalForward(localConn net.Conn, remoteAddr string) {
|
||||
func (c *Client) handleLocalForward(ctx context.Context, localConn net.Conn, remoteAddr string) {
|
||||
defer func() {
|
||||
if err := localConn.Close(); err != nil {
|
||||
log.Debugf("local port forwarding: close local connection: %v", err)
|
||||
@@ -571,7 +572,7 @@ func (c *Client) handleLocalForward(localConn net.Conn, remoteAddr string) {
|
||||
}
|
||||
}()
|
||||
|
||||
nbssh.BidirectionalCopy(log.NewEntry(log.StandardLogger()), localConn, channel)
|
||||
netrelay.Relay(ctx, localConn, channel, netrelay.Options{Logger: log.NewEntry(log.StandardLogger())})
|
||||
}
|
||||
|
||||
// RemotePortForward sets up remote port forwarding, binding on remote and forwarding to localAddr
|
||||
@@ -653,16 +654,19 @@ func (c *Client) handleRemoteForwardChannels(ctx context.Context, localAddr stri
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case newChan := <-channelRequests:
|
||||
case newChan, ok := <-channelRequests:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
if newChan != nil {
|
||||
go c.handleRemoteForwardChannel(newChan, localAddr)
|
||||
go c.handleRemoteForwardChannel(ctx, newChan, localAddr)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// handleRemoteForwardChannel handles a single forwarded-tcpip channel
|
||||
func (c *Client) handleRemoteForwardChannel(newChan ssh.NewChannel, localAddr string) {
|
||||
func (c *Client) handleRemoteForwardChannel(ctx context.Context, newChan ssh.NewChannel, localAddr string) {
|
||||
channel, reqs, err := newChan.Accept()
|
||||
if err != nil {
|
||||
return
|
||||
@@ -675,8 +679,14 @@ func (c *Client) handleRemoteForwardChannel(newChan ssh.NewChannel, localAddr st
|
||||
|
||||
go ssh.DiscardRequests(reqs)
|
||||
|
||||
localConn, err := net.Dial("tcp", localAddr)
|
||||
// Bound the dial so a black-holed localAddr can't pin the accepted SSH
|
||||
// channel open indefinitely; the relay itself runs under the outer ctx.
|
||||
dialCtx, cancelDial := context.WithTimeout(ctx, 10*time.Second)
|
||||
var dialer net.Dialer
|
||||
localConn, err := dialer.DialContext(dialCtx, "tcp", localAddr)
|
||||
cancelDial()
|
||||
if err != nil {
|
||||
log.Debugf("remote port forwarding: dial %s: %v", localAddr, err)
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
@@ -685,7 +695,7 @@ func (c *Client) handleRemoteForwardChannel(newChan ssh.NewChannel, localAddr st
|
||||
}
|
||||
}()
|
||||
|
||||
nbssh.BidirectionalCopy(log.NewEntry(log.StandardLogger()), localConn, channel)
|
||||
netrelay.Relay(ctx, localConn, channel, netrelay.Options{Logger: log.NewEntry(log.StandardLogger())})
|
||||
}
|
||||
|
||||
// tcpipForwardMsg represents the structure for tcpip-forward requests
|
||||
|
||||
@@ -194,63 +194,3 @@ func buildAddressList(hostname string, remote net.Addr) []string {
|
||||
return addresses
|
||||
}
|
||||
|
||||
// BidirectionalCopy copies data bidirectionally between two io.ReadWriter connections.
|
||||
// It waits for both directions to complete before returning.
|
||||
// The caller is responsible for closing the connections.
|
||||
func BidirectionalCopy(logger *log.Entry, rw1, rw2 io.ReadWriter) {
|
||||
done := make(chan struct{}, 2)
|
||||
|
||||
go func() {
|
||||
if _, err := io.Copy(rw2, rw1); err != nil && !isExpectedCopyError(err) {
|
||||
logger.Debugf("copy error (1->2): %v", err)
|
||||
}
|
||||
done <- struct{}{}
|
||||
}()
|
||||
|
||||
go func() {
|
||||
if _, err := io.Copy(rw1, rw2); err != nil && !isExpectedCopyError(err) {
|
||||
logger.Debugf("copy error (2->1): %v", err)
|
||||
}
|
||||
done <- struct{}{}
|
||||
}()
|
||||
|
||||
<-done
|
||||
<-done
|
||||
}
|
||||
|
||||
func isExpectedCopyError(err error) bool {
|
||||
return errors.Is(err, io.EOF) || errors.Is(err, context.Canceled)
|
||||
}
|
||||
|
||||
// BidirectionalCopyWithContext copies data bidirectionally between two io.ReadWriteCloser connections.
|
||||
// It waits for both directions to complete or for context cancellation before returning.
|
||||
// Both connections are closed when the function returns.
|
||||
func BidirectionalCopyWithContext(logger *log.Entry, ctx context.Context, conn1, conn2 io.ReadWriteCloser) {
|
||||
done := make(chan struct{}, 2)
|
||||
|
||||
go func() {
|
||||
if _, err := io.Copy(conn2, conn1); err != nil && !isExpectedCopyError(err) {
|
||||
logger.Debugf("copy error (1->2): %v", err)
|
||||
}
|
||||
done <- struct{}{}
|
||||
}()
|
||||
|
||||
go func() {
|
||||
if _, err := io.Copy(conn1, conn2); err != nil && !isExpectedCopyError(err) {
|
||||
logger.Debugf("copy error (2->1): %v", err)
|
||||
}
|
||||
done <- struct{}{}
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
case <-done:
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
case <-done:
|
||||
}
|
||||
}
|
||||
|
||||
_ = conn1.Close()
|
||||
_ = conn2.Close()
|
||||
}
|
||||
|
||||
@@ -187,24 +187,23 @@ func (m *Manager) buildPeerConfig(allHostPatterns []string) (string, error) {
|
||||
return "", fmt.Errorf("get NetBird executable path: %w", err)
|
||||
}
|
||||
|
||||
hostLine := strings.Join(deduplicatedPatterns, " ")
|
||||
config := fmt.Sprintf("Host %s\n", hostLine)
|
||||
config += fmt.Sprintf(" Match exec \"%s ssh detect %%h %%p\"\n", execPath)
|
||||
config += " PreferredAuthentications password,publickey,keyboard-interactive\n"
|
||||
config += " PasswordAuthentication yes\n"
|
||||
config += " PubkeyAuthentication yes\n"
|
||||
config += " BatchMode no\n"
|
||||
config += fmt.Sprintf(" ProxyCommand %s ssh proxy %%h %%p\n", execPath)
|
||||
config += " StrictHostKeyChecking no\n"
|
||||
hostList := strings.Join(deduplicatedPatterns, ",")
|
||||
config := fmt.Sprintf("Match host \"%s\" exec \"%s ssh detect %%h %%p\"\n", hostList, execPath)
|
||||
config += " PreferredAuthentications password,publickey,keyboard-interactive\n"
|
||||
config += " PasswordAuthentication yes\n"
|
||||
config += " PubkeyAuthentication yes\n"
|
||||
config += " BatchMode no\n"
|
||||
config += fmt.Sprintf(" ProxyCommand %s ssh proxy %%h %%p\n", execPath)
|
||||
config += " StrictHostKeyChecking no\n"
|
||||
|
||||
if runtime.GOOS == "windows" {
|
||||
config += " UserKnownHostsFile NUL\n"
|
||||
config += " UserKnownHostsFile NUL\n"
|
||||
} else {
|
||||
config += " UserKnownHostsFile /dev/null\n"
|
||||
config += " UserKnownHostsFile /dev/null\n"
|
||||
}
|
||||
|
||||
config += " CheckHostIP no\n"
|
||||
config += " LogLevel ERROR\n\n"
|
||||
config += " CheckHostIP no\n"
|
||||
config += " LogLevel ERROR\n\n"
|
||||
|
||||
return config, nil
|
||||
}
|
||||
|
||||
@@ -116,6 +116,37 @@ func TestManager_PeerLimit(t *testing.T) {
|
||||
assert.True(t, os.IsNotExist(err), "SSH config should not be created with too many peers")
|
||||
}
|
||||
|
||||
func TestManager_MatchHostFormat(t *testing.T) {
|
||||
tempDir, err := os.MkdirTemp("", "netbird-ssh-config-test")
|
||||
require.NoError(t, err)
|
||||
defer func() { assert.NoError(t, os.RemoveAll(tempDir)) }()
|
||||
|
||||
manager := &Manager{
|
||||
sshConfigDir: filepath.Join(tempDir, "ssh_config.d"),
|
||||
sshConfigFile: "99-netbird.conf",
|
||||
}
|
||||
|
||||
peers := []PeerSSHInfo{
|
||||
{Hostname: "peer1", IP: "100.125.1.1", FQDN: "peer1.nb.internal"},
|
||||
{Hostname: "peer2", IP: "100.125.1.2", FQDN: "peer2.nb.internal"},
|
||||
}
|
||||
|
||||
err = manager.SetupSSHClientConfig(peers)
|
||||
require.NoError(t, err)
|
||||
|
||||
configPath := filepath.Join(manager.sshConfigDir, manager.sshConfigFile)
|
||||
content, err := os.ReadFile(configPath)
|
||||
require.NoError(t, err)
|
||||
configStr := string(content)
|
||||
|
||||
// Must use "Match host" with comma-separated patterns, not a bare "Host" directive.
|
||||
// A bare "Host" followed by "Match exec" is incorrect per ssh_config(5): the Host block
|
||||
// ends at the next Match keyword, making it a no-op and leaving the Match exec unscoped.
|
||||
assert.NotContains(t, configStr, "\nHost ", "should not use bare Host directive")
|
||||
assert.Contains(t, configStr, "Match host \"100.125.1.1,peer1.nb.internal,peer1,100.125.1.2,peer2.nb.internal,peer2\"",
|
||||
"should use Match host with comma-separated patterns")
|
||||
}
|
||||
|
||||
func TestManager_ForcedSSHConfig(t *testing.T) {
|
||||
// Set force environment variable
|
||||
t.Setenv(EnvForceSSHConfig, "true")
|
||||
|
||||
@@ -23,6 +23,7 @@ import (
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
nbssh "github.com/netbirdio/netbird/client/ssh"
|
||||
"github.com/netbirdio/netbird/client/ssh/detection"
|
||||
"github.com/netbirdio/netbird/util/netrelay"
|
||||
"github.com/netbirdio/netbird/version"
|
||||
)
|
||||
|
||||
@@ -352,7 +353,7 @@ func (p *SSHProxy) directTCPIPHandler(_ *ssh.Server, _ *cryptossh.ServerConn, ne
|
||||
}
|
||||
go cryptossh.DiscardRequests(clientReqs)
|
||||
|
||||
nbssh.BidirectionalCopyWithContext(log.NewEntry(log.StandardLogger()), sshCtx, clientChan, backendChan)
|
||||
netrelay.Relay(sshCtx, clientChan, backendChan, netrelay.Options{Logger: log.NewEntry(log.StandardLogger())})
|
||||
}
|
||||
|
||||
func (p *SSHProxy) sftpSubsystemHandler(s ssh.Session, jwtToken string) {
|
||||
@@ -591,7 +592,7 @@ func (p *SSHProxy) handleForwardedChannel(sshCtx ssh.Context, sshConn *cryptossh
|
||||
}
|
||||
go cryptossh.DiscardRequests(clientReqs)
|
||||
|
||||
nbssh.BidirectionalCopyWithContext(log.NewEntry(log.StandardLogger()), sshCtx, clientChan, backendChan)
|
||||
netrelay.Relay(sshCtx, clientChan, backendChan, netrelay.Options{Logger: log.NewEntry(log.StandardLogger())})
|
||||
}
|
||||
|
||||
func (p *SSHProxy) dialBackend(ctx context.Context, addr, user, jwtToken string) (*cryptossh.Client, error) {
|
||||
|
||||
@@ -17,7 +17,7 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
cryptossh "golang.org/x/crypto/ssh"
|
||||
|
||||
nbssh "github.com/netbirdio/netbird/client/ssh"
|
||||
"github.com/netbirdio/netbird/util/netrelay"
|
||||
)
|
||||
|
||||
const privilegedPortThreshold = 1024
|
||||
@@ -356,7 +356,7 @@ func (s *Server) handleRemoteForwardConnection(ctx ssh.Context, conn net.Conn, h
|
||||
return
|
||||
}
|
||||
|
||||
nbssh.BidirectionalCopyWithContext(logger, ctx, conn, channel)
|
||||
netrelay.Relay(ctx, conn, channel, netrelay.Options{Logger: logger})
|
||||
}
|
||||
|
||||
// openForwardChannel creates an SSH forwarded-tcpip channel
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
@@ -26,6 +27,7 @@ import (
|
||||
"github.com/netbirdio/netbird/client/ssh/detection"
|
||||
"github.com/netbirdio/netbird/shared/auth"
|
||||
"github.com/netbirdio/netbird/shared/auth/jwt"
|
||||
"github.com/netbirdio/netbird/util/netrelay"
|
||||
"github.com/netbirdio/netbird/version"
|
||||
)
|
||||
|
||||
@@ -52,6 +54,10 @@ const (
|
||||
DefaultJWTMaxTokenAge = 10 * 60
|
||||
)
|
||||
|
||||
// directTCPIPDialTimeout bounds how long relayDirectTCPIP waits on a dial to
|
||||
// the forwarded destination before rejecting the SSH channel.
|
||||
const directTCPIPDialTimeout = 30 * time.Second
|
||||
|
||||
var (
|
||||
ErrPrivilegedUserDisabled = errors.New(msgPrivilegedUserDisabled)
|
||||
ErrUserNotFound = errors.New("user not found")
|
||||
@@ -891,5 +897,29 @@ func (s *Server) directTCPIPHandler(srv *ssh.Server, conn *cryptossh.ServerConn,
|
||||
s.addConnectionPortForward(ctx.User(), ctx.RemoteAddr(), forwardAddr)
|
||||
logger.Infof("local port forwarding: %s:%d", payload.Host, payload.Port)
|
||||
|
||||
ssh.DirectTCPIPHandler(srv, conn, newChan, ctx)
|
||||
s.relayDirectTCPIP(ctx, newChan, payload.Host, int(payload.Port), logger)
|
||||
}
|
||||
|
||||
// relayDirectTCPIP is a netrelay-based replacement for gliderlabs'
|
||||
// DirectTCPIPHandler. The upstream handler closes both sides on the first
|
||||
// EOF; netrelay.Relay propagates CloseWrite so each direction drains on its
|
||||
// own terms.
|
||||
func (s *Server) relayDirectTCPIP(ctx ssh.Context, newChan cryptossh.NewChannel, host string, port int, logger *log.Entry) {
|
||||
dest := net.JoinHostPort(host, strconv.Itoa(port))
|
||||
|
||||
dialer := net.Dialer{Timeout: directTCPIPDialTimeout}
|
||||
dconn, err := dialer.DialContext(ctx, "tcp", dest)
|
||||
if err != nil {
|
||||
_ = newChan.Reject(cryptossh.ConnectionFailed, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
ch, reqs, err := newChan.Accept()
|
||||
if err != nil {
|
||||
_ = dconn.Close()
|
||||
return
|
||||
}
|
||||
go cryptossh.DiscardRequests(reqs)
|
||||
|
||||
netrelay.Relay(ctx, dconn, ch, netrelay.Options{Logger: logger})
|
||||
}
|
||||
|
||||
@@ -2,7 +2,6 @@ package system
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"net/netip"
|
||||
"strings"
|
||||
|
||||
@@ -145,59 +144,6 @@ func extractDeviceName(ctx context.Context, defaultName string) string {
|
||||
return v
|
||||
}
|
||||
|
||||
func networkAddresses() ([]NetworkAddress, error) {
|
||||
interfaces, err := net.Interfaces()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var netAddresses []NetworkAddress
|
||||
for _, iface := range interfaces {
|
||||
if iface.Flags&net.FlagUp == 0 {
|
||||
continue
|
||||
}
|
||||
if iface.HardwareAddr.String() == "" {
|
||||
continue
|
||||
}
|
||||
addrs, err := iface.Addrs()
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
for _, address := range addrs {
|
||||
ipNet, ok := address.(*net.IPNet)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
if ipNet.IP.IsLoopback() {
|
||||
continue
|
||||
}
|
||||
|
||||
netAddr := NetworkAddress{
|
||||
NetIP: netip.MustParsePrefix(ipNet.String()),
|
||||
Mac: iface.HardwareAddr.String(),
|
||||
}
|
||||
|
||||
if isDuplicated(netAddresses, netAddr) {
|
||||
continue
|
||||
}
|
||||
|
||||
netAddresses = append(netAddresses, netAddr)
|
||||
}
|
||||
}
|
||||
return netAddresses, nil
|
||||
}
|
||||
|
||||
func isDuplicated(addresses []NetworkAddress, addr NetworkAddress) bool {
|
||||
for _, duplicated := range addresses {
|
||||
if duplicated.NetIP == addr.NetIP {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// GetInfoWithChecks retrieves and parses the system information with applied checks.
|
||||
func GetInfoWithChecks(ctx context.Context, checks []*proto.Checks) (*Info, error) {
|
||||
log.Debugf("gathering system information with checks: %d", len(checks))
|
||||
|
||||
@@ -2,6 +2,8 @@ package system
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"net/netip"
|
||||
"runtime"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
@@ -42,6 +44,66 @@ func GetInfo(ctx context.Context) *Info {
|
||||
return gio
|
||||
}
|
||||
|
||||
// networkAddresses returns the list of network addresses on iOS.
|
||||
// On iOS, hardware (MAC) addresses are not available due to Apple's privacy
|
||||
// restrictions (iOS returns a fixed 02:00:00:00:00:00 placeholder), so we
|
||||
// leave Mac empty to match Android's behavior. We also skip the HardwareAddr
|
||||
// check that other platforms use and filter out link-local addresses as they
|
||||
// are not useful for posture checks.
|
||||
func networkAddresses() ([]NetworkAddress, error) {
|
||||
interfaces, err := net.Interfaces()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var netAddresses []NetworkAddress
|
||||
for _, iface := range interfaces {
|
||||
if iface.Flags&net.FlagUp == 0 {
|
||||
continue
|
||||
}
|
||||
addrs, err := iface.Addrs()
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
for _, address := range addrs {
|
||||
netAddr, ok := toNetworkAddress(address)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if isDuplicated(netAddresses, netAddr) {
|
||||
continue
|
||||
}
|
||||
netAddresses = append(netAddresses, netAddr)
|
||||
}
|
||||
}
|
||||
return netAddresses, nil
|
||||
}
|
||||
|
||||
func toNetworkAddress(address net.Addr) (NetworkAddress, bool) {
|
||||
ipNet, ok := address.(*net.IPNet)
|
||||
if !ok {
|
||||
return NetworkAddress{}, false
|
||||
}
|
||||
if ipNet.IP.IsLoopback() || ipNet.IP.IsLinkLocalUnicast() || ipNet.IP.IsMulticast() {
|
||||
return NetworkAddress{}, false
|
||||
}
|
||||
prefix, err := netip.ParsePrefix(ipNet.String())
|
||||
if err != nil {
|
||||
return NetworkAddress{}, false
|
||||
}
|
||||
return NetworkAddress{NetIP: prefix, Mac: ""}, true
|
||||
}
|
||||
|
||||
func isDuplicated(addresses []NetworkAddress, addr NetworkAddress) bool {
|
||||
for _, duplicated := range addresses {
|
||||
if duplicated.NetIP == addr.NetIP {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// checkFileAndProcess checks if the file path exists and if a process is running at that path.
|
||||
func checkFileAndProcess(paths []string) ([]File, error) {
|
||||
return []File{}, nil
|
||||
|
||||
66
client/system/network_addr.go
Normal file
66
client/system/network_addr.go
Normal file
@@ -0,0 +1,66 @@
|
||||
//go:build !ios
|
||||
|
||||
package system
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/netip"
|
||||
)
|
||||
|
||||
func networkAddresses() ([]NetworkAddress, error) {
|
||||
interfaces, err := net.Interfaces()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var netAddresses []NetworkAddress
|
||||
for _, iface := range interfaces {
|
||||
if iface.Flags&net.FlagUp == 0 {
|
||||
continue
|
||||
}
|
||||
if iface.HardwareAddr.String() == "" {
|
||||
continue
|
||||
}
|
||||
addrs, err := iface.Addrs()
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
mac := iface.HardwareAddr.String()
|
||||
for _, address := range addrs {
|
||||
netAddr, ok := toNetworkAddress(address, mac)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if isDuplicated(netAddresses, netAddr) {
|
||||
continue
|
||||
}
|
||||
netAddresses = append(netAddresses, netAddr)
|
||||
}
|
||||
}
|
||||
return netAddresses, nil
|
||||
}
|
||||
|
||||
func toNetworkAddress(address net.Addr, mac string) (NetworkAddress, bool) {
|
||||
ipNet, ok := address.(*net.IPNet)
|
||||
if !ok {
|
||||
return NetworkAddress{}, false
|
||||
}
|
||||
if ipNet.IP.IsLoopback() {
|
||||
return NetworkAddress{}, false
|
||||
}
|
||||
prefix, err := netip.ParsePrefix(ipNet.String())
|
||||
if err != nil {
|
||||
return NetworkAddress{}, false
|
||||
}
|
||||
return NetworkAddress{NetIP: prefix, Mac: mac}, true
|
||||
}
|
||||
|
||||
func isDuplicated(addresses []NetworkAddress, addr NetworkAddress) bool {
|
||||
for _, duplicated := range addresses {
|
||||
if duplicated.NetIP == addr.NetIP {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
@@ -25,6 +25,12 @@ func (c *peekedConn) Read(b []byte) (int, error) {
|
||||
return c.reader.Read(b)
|
||||
}
|
||||
|
||||
// halfCloser matches connections that support shutting down the write
|
||||
// side while keeping the read side open (e.g. *net.TCPConn).
|
||||
type halfCloser interface {
|
||||
CloseWrite() error
|
||||
}
|
||||
|
||||
// CloseWrite delegates to the underlying connection if it supports
|
||||
// half-close (e.g. *net.TCPConn). Without this, embedding net.Conn
|
||||
// as an interface hides the concrete type's CloseWrite method, making
|
||||
|
||||
@@ -1,156 +0,0 @@
|
||||
package tcp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/proxy/internal/netutil"
|
||||
)
|
||||
|
||||
// errIdleTimeout is returned when a relay connection is closed due to inactivity.
|
||||
var errIdleTimeout = errors.New("idle timeout")
|
||||
|
||||
// DefaultIdleTimeout is the default idle timeout for TCP relay connections.
|
||||
// A zero value disables idle timeout checking.
|
||||
const DefaultIdleTimeout = 5 * time.Minute
|
||||
|
||||
// halfCloser is implemented by connections that support half-close
|
||||
// (e.g. *net.TCPConn). When one copy direction finishes, we signal
|
||||
// EOF to the remote by closing the write side while keeping the read
|
||||
// side open so the other direction can drain.
|
||||
type halfCloser interface {
|
||||
CloseWrite() error
|
||||
}
|
||||
|
||||
// copyBufPool avoids allocating a new 32KB buffer per io.Copy call.
|
||||
var copyBufPool = sync.Pool{
|
||||
New: func() any {
|
||||
buf := make([]byte, 32*1024)
|
||||
return &buf
|
||||
},
|
||||
}
|
||||
|
||||
// Relay copies data bidirectionally between src and dst until both
|
||||
// sides are done or the context is canceled. When idleTimeout is
|
||||
// non-zero, each direction's read is deadline-guarded; if no data
|
||||
// flows within the timeout the connection is torn down. When one
|
||||
// direction finishes, it half-closes the write side of the
|
||||
// destination (if supported) to signal EOF, allowing the other
|
||||
// direction to drain gracefully before the full connection teardown.
|
||||
func Relay(ctx context.Context, logger *log.Entry, src, dst net.Conn, idleTimeout time.Duration) (srcToDst, dstToSrc int64) {
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
|
||||
go func() {
|
||||
<-ctx.Done()
|
||||
_ = src.Close()
|
||||
_ = dst.Close()
|
||||
}()
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(2)
|
||||
|
||||
var errSrcToDst, errDstToSrc error
|
||||
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
srcToDst, errSrcToDst = copyWithIdleTimeout(dst, src, idleTimeout)
|
||||
halfClose(dst)
|
||||
cancel()
|
||||
}()
|
||||
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
dstToSrc, errDstToSrc = copyWithIdleTimeout(src, dst, idleTimeout)
|
||||
halfClose(src)
|
||||
cancel()
|
||||
}()
|
||||
|
||||
wg.Wait()
|
||||
|
||||
if errors.Is(errSrcToDst, errIdleTimeout) || errors.Is(errDstToSrc, errIdleTimeout) {
|
||||
logger.Debug("relay closed due to idle timeout")
|
||||
}
|
||||
if errSrcToDst != nil && !isExpectedCopyError(errSrcToDst) {
|
||||
logger.Debugf("relay copy error (src→dst): %v", errSrcToDst)
|
||||
}
|
||||
if errDstToSrc != nil && !isExpectedCopyError(errDstToSrc) {
|
||||
logger.Debugf("relay copy error (dst→src): %v", errDstToSrc)
|
||||
}
|
||||
|
||||
return srcToDst, dstToSrc
|
||||
}
|
||||
|
||||
// copyWithIdleTimeout copies from src to dst using a pooled buffer.
|
||||
// When idleTimeout > 0 it sets a read deadline on src before each
|
||||
// read and treats a timeout as an idle-triggered close.
|
||||
func copyWithIdleTimeout(dst io.Writer, src io.Reader, idleTimeout time.Duration) (int64, error) {
|
||||
bufp := copyBufPool.Get().(*[]byte)
|
||||
defer copyBufPool.Put(bufp)
|
||||
|
||||
if idleTimeout <= 0 {
|
||||
return io.CopyBuffer(dst, src, *bufp)
|
||||
}
|
||||
|
||||
conn, ok := src.(net.Conn)
|
||||
if !ok {
|
||||
return io.CopyBuffer(dst, src, *bufp)
|
||||
}
|
||||
|
||||
buf := *bufp
|
||||
var total int64
|
||||
for {
|
||||
if err := conn.SetReadDeadline(time.Now().Add(idleTimeout)); err != nil {
|
||||
return total, err
|
||||
}
|
||||
nr, readErr := src.Read(buf)
|
||||
if nr > 0 {
|
||||
n, err := checkedWrite(dst, buf[:nr])
|
||||
total += n
|
||||
if err != nil {
|
||||
return total, err
|
||||
}
|
||||
}
|
||||
if readErr != nil {
|
||||
if netutil.IsTimeout(readErr) {
|
||||
return total, errIdleTimeout
|
||||
}
|
||||
return total, readErr
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// checkedWrite writes buf to dst and returns the number of bytes written.
|
||||
// It guards against short writes and negative counts per io.Copy convention.
|
||||
func checkedWrite(dst io.Writer, buf []byte) (int64, error) {
|
||||
nw, err := dst.Write(buf)
|
||||
if nw < 0 || nw > len(buf) {
|
||||
nw = 0
|
||||
}
|
||||
if err != nil {
|
||||
return int64(nw), err
|
||||
}
|
||||
if nw != len(buf) {
|
||||
return int64(nw), io.ErrShortWrite
|
||||
}
|
||||
return int64(nw), nil
|
||||
}
|
||||
|
||||
func isExpectedCopyError(err error) bool {
|
||||
return errors.Is(err, errIdleTimeout) || netutil.IsExpectedError(err)
|
||||
}
|
||||
|
||||
// halfClose attempts to half-close the write side of the connection.
|
||||
// If the connection does not support half-close, this is a no-op.
|
||||
func halfClose(conn net.Conn) {
|
||||
if hc, ok := conn.(halfCloser); ok {
|
||||
// Best-effort; the full close will follow shortly.
|
||||
_ = hc.CloseWrite()
|
||||
}
|
||||
}
|
||||
@@ -13,8 +13,13 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/proxy/internal/netutil"
|
||||
"github.com/netbirdio/netbird/util/netrelay"
|
||||
)
|
||||
|
||||
func testRelay(ctx context.Context, logger *log.Entry, src, dst net.Conn, idleTimeout time.Duration) (int64, int64) {
|
||||
return netrelay.Relay(ctx, src, dst, netrelay.Options{IdleTimeout: idleTimeout, Logger: logger})
|
||||
}
|
||||
|
||||
func TestRelay_BidirectionalCopy(t *testing.T) {
|
||||
srcClient, srcServer := net.Pipe()
|
||||
dstClient, dstServer := net.Pipe()
|
||||
@@ -41,7 +46,7 @@ func TestRelay_BidirectionalCopy(t *testing.T) {
|
||||
srcClient.Close()
|
||||
}()
|
||||
|
||||
s2d, d2s := Relay(ctx, logger, srcServer, dstServer, 0)
|
||||
s2d, d2s := testRelay(ctx, logger, srcServer, dstServer, 0)
|
||||
|
||||
assert.Equal(t, int64(len(srcData)), s2d, "bytes src→dst")
|
||||
assert.Equal(t, int64(len(dstData)), d2s, "bytes dst→src")
|
||||
@@ -58,7 +63,7 @@ func TestRelay_ContextCancellation(t *testing.T) {
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
Relay(ctx, logger, srcServer, dstServer, 0)
|
||||
testRelay(ctx, logger, srcServer, dstServer, 0)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
@@ -85,7 +90,7 @@ func TestRelay_OneSideClosed(t *testing.T) {
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
Relay(ctx, logger, srcServer, dstServer, 0)
|
||||
testRelay(ctx, logger, srcServer, dstServer, 0)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
@@ -129,7 +134,7 @@ func TestRelay_LargeTransfer(t *testing.T) {
|
||||
dstClient.Close()
|
||||
}()
|
||||
|
||||
s2d, _ := Relay(ctx, logger, srcServer, dstServer, 0)
|
||||
s2d, _ := testRelay(ctx, logger, srcServer, dstServer, 0)
|
||||
assert.Equal(t, int64(len(data)), s2d, "should transfer all bytes")
|
||||
require.NoError(t, <-errCh)
|
||||
}
|
||||
@@ -182,7 +187,7 @@ func TestRelay_IdleTimeout(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
var s2d, d2s int64
|
||||
go func() {
|
||||
s2d, d2s = Relay(ctx, logger, srcServer, dstServer, 200*time.Millisecond)
|
||||
s2d, d2s = testRelay(ctx, logger, srcServer, dstServer, 200*time.Millisecond)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
|
||||
@@ -16,6 +16,7 @@ import (
|
||||
"github.com/netbirdio/netbird/proxy/internal/accesslog"
|
||||
"github.com/netbirdio/netbird/proxy/internal/restrict"
|
||||
"github.com/netbirdio/netbird/proxy/internal/types"
|
||||
"github.com/netbirdio/netbird/util/netrelay"
|
||||
)
|
||||
|
||||
// defaultDialTimeout is the fallback dial timeout when no per-route
|
||||
@@ -528,11 +529,14 @@ func (r *Router) relayTCP(ctx context.Context, conn net.Conn, sni SNIHost, route
|
||||
|
||||
idleTimeout := route.SessionIdleTimeout
|
||||
if idleTimeout <= 0 {
|
||||
idleTimeout = DefaultIdleTimeout
|
||||
idleTimeout = netrelay.DefaultIdleTimeout
|
||||
}
|
||||
|
||||
start := time.Now()
|
||||
s2d, d2s := Relay(svcCtx, entry, conn, backend, idleTimeout)
|
||||
s2d, d2s := netrelay.Relay(svcCtx, conn, backend, netrelay.Options{
|
||||
IdleTimeout: idleTimeout,
|
||||
Logger: entry,
|
||||
})
|
||||
elapsed := time.Since(start)
|
||||
|
||||
if obs != nil {
|
||||
|
||||
227
util/netrelay/relay.go
Normal file
227
util/netrelay/relay.go
Normal file
@@ -0,0 +1,227 @@
|
||||
// Package netrelay provides a bidirectional byte-copy helper for TCP-like
|
||||
// connections with correct half-close propagation.
|
||||
//
|
||||
// When one direction reads EOF, the write side of the opposite connection is
|
||||
// half-closed (CloseWrite) so the peer sees FIN, then the second direction is
|
||||
// allowed to drain to its own EOF before both connections are fully closed.
|
||||
// This preserves TCP half-close semantics (e.g. shutdown(SHUT_WR)) that the
|
||||
// naive "cancel-both-on-first-EOF" pattern breaks.
|
||||
package netrelay
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"syscall"
|
||||
"time"
|
||||
)
|
||||
|
||||
// DebugLogger is the minimal interface netrelay uses to surface teardown
|
||||
// errors. Both *logrus.Entry and *nblog.Logger (via its Debugf method)
|
||||
// satisfy it, so callers can pass whichever they already use without an
|
||||
// adapter. Debugf is the only required method; callers with richer
|
||||
// loggers just expose this one shape here.
|
||||
type DebugLogger interface {
|
||||
Debugf(format string, args ...any)
|
||||
}
|
||||
|
||||
// DefaultIdleTimeout is a reasonable default for Options.IdleTimeout. Callers
|
||||
// that want an idle timeout but have no specific preference can use this.
|
||||
const DefaultIdleTimeout = 5 * time.Minute
|
||||
|
||||
// halfCloser is implemented by connections that support half-close
|
||||
// (e.g. *net.TCPConn, *gonet.TCPConn).
|
||||
type halfCloser interface {
|
||||
CloseWrite() error
|
||||
}
|
||||
|
||||
var copyBufPool = sync.Pool{
|
||||
New: func() any {
|
||||
buf := make([]byte, 32*1024)
|
||||
return &buf
|
||||
},
|
||||
}
|
||||
|
||||
// Options configures Relay behavior. The zero value is valid: no idle timeout,
|
||||
// no logging.
|
||||
type Options struct {
|
||||
// IdleTimeout tears down the session if no bytes flow in either
|
||||
// direction within this window. It is a connection-wide watchdog, so a
|
||||
// long unidirectional transfer on one side keeps the other side alive.
|
||||
// Zero disables idle tracking.
|
||||
IdleTimeout time.Duration
|
||||
// Logger receives debug-level copy/idle errors. Nil suppresses logging.
|
||||
// Any logger with Debug/Debugf methods is accepted (logrus.Entry,
|
||||
// uspfilter's nblog.Logger, etc.).
|
||||
Logger DebugLogger
|
||||
}
|
||||
|
||||
// Relay copies bytes in both directions between a and b until both directions
|
||||
// EOF or ctx is canceled. On each direction's EOF it half-closes the
|
||||
// opposite conn's write side (best effort) so the peer sees FIN while the
|
||||
// other direction drains. Both conns are fully closed when Relay returns.
|
||||
//
|
||||
// a and b only need to implement io.ReadWriteCloser; connections that also
|
||||
// implement CloseWrite (e.g. *net.TCPConn, ssh.Channel) get proper half-close
|
||||
// propagation. Options.IdleTimeout, when set, is enforced by a connection-wide
|
||||
// watchdog that tracks reads in either direction.
|
||||
//
|
||||
// Return values are byte counts: aToB (a.Read → b.Write) and bToA (b.Read →
|
||||
// a.Write). Errors are logged via Options.Logger when set; they are not
|
||||
// returned because a relay always terminates on some kind of EOF/cancel.
|
||||
func Relay(ctx context.Context, a, b io.ReadWriteCloser, opts Options) (aToB, bToA int64) {
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
|
||||
go func() {
|
||||
<-ctx.Done()
|
||||
_ = a.Close()
|
||||
_ = b.Close()
|
||||
}()
|
||||
|
||||
// Both sides must support CloseWrite to propagate half-close. If either
|
||||
// doesn't, a direction's EOF can't be signaled to the peer and the other
|
||||
// direction would block forever waiting for data; in that case we fall
|
||||
// back to the cancel-both-on-first-EOF behavior.
|
||||
_, aHC := a.(halfCloser)
|
||||
_, bHC := b.(halfCloser)
|
||||
halfCloseSupported := aHC && bHC
|
||||
|
||||
var (
|
||||
lastActivity atomic.Int64
|
||||
idleHit atomic.Bool
|
||||
)
|
||||
lastActivity.Store(time.Now().UnixNano())
|
||||
|
||||
if opts.IdleTimeout > 0 {
|
||||
go watchdog(ctx, cancel, &lastActivity, &idleHit, opts.IdleTimeout)
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(2)
|
||||
|
||||
var errAToB, errBToA error
|
||||
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
aToB, errAToB = copyTracked(b, a, &lastActivity)
|
||||
if halfCloseSupported && isCleanEOF(errAToB) {
|
||||
halfClose(b)
|
||||
} else {
|
||||
cancel()
|
||||
}
|
||||
}()
|
||||
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
bToA, errBToA = copyTracked(a, b, &lastActivity)
|
||||
if halfCloseSupported && isCleanEOF(errBToA) {
|
||||
halfClose(a)
|
||||
} else {
|
||||
cancel()
|
||||
}
|
||||
}()
|
||||
|
||||
wg.Wait()
|
||||
|
||||
if opts.Logger != nil {
|
||||
if idleHit.Load() {
|
||||
opts.Logger.Debugf("relay closed due to idle timeout")
|
||||
}
|
||||
if errAToB != nil && !isExpectedCopyError(errAToB) {
|
||||
opts.Logger.Debugf("relay copy error (a→b): %v", errAToB)
|
||||
}
|
||||
if errBToA != nil && !isExpectedCopyError(errBToA) {
|
||||
opts.Logger.Debugf("relay copy error (b→a): %v", errBToA)
|
||||
}
|
||||
}
|
||||
|
||||
return aToB, bToA
|
||||
}
|
||||
|
||||
// watchdog enforces a connection-wide idle timeout. It cancels ctx when no
|
||||
// activity has been seen on either direction for idle. It exits as soon as
|
||||
// ctx is canceled so it doesn't outlive the relay.
|
||||
func watchdog(ctx context.Context, cancel context.CancelFunc, lastActivity *atomic.Int64, idleHit *atomic.Bool, idle time.Duration) {
|
||||
tick := max(idle/2, 50*time.Millisecond)
|
||||
t := time.NewTicker(tick)
|
||||
defer t.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-t.C:
|
||||
last := time.Unix(0, lastActivity.Load())
|
||||
if time.Since(last) >= idle {
|
||||
idleHit.Store(true)
|
||||
cancel()
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// copyTracked copies from src to dst using a pooled buffer, updating
|
||||
// lastActivity on every successful read so a shared watchdog can enforce a
|
||||
// connection-wide idle timeout.
|
||||
func copyTracked(dst io.Writer, src io.Reader, lastActivity *atomic.Int64) (int64, error) {
|
||||
bufp := copyBufPool.Get().(*[]byte)
|
||||
defer copyBufPool.Put(bufp)
|
||||
|
||||
buf := *bufp
|
||||
var total int64
|
||||
for {
|
||||
nr, readErr := src.Read(buf)
|
||||
if nr > 0 {
|
||||
lastActivity.Store(time.Now().UnixNano())
|
||||
n, werr := checkedWrite(dst, buf[:nr])
|
||||
total += n
|
||||
if werr != nil {
|
||||
return total, werr
|
||||
}
|
||||
}
|
||||
if readErr != nil {
|
||||
return total, readErr
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func checkedWrite(dst io.Writer, buf []byte) (int64, error) {
|
||||
nw, err := dst.Write(buf)
|
||||
if nw < 0 || nw > len(buf) {
|
||||
nw = 0
|
||||
}
|
||||
if err != nil {
|
||||
return int64(nw), err
|
||||
}
|
||||
if nw != len(buf) {
|
||||
return int64(nw), io.ErrShortWrite
|
||||
}
|
||||
return int64(nw), nil
|
||||
}
|
||||
|
||||
func halfClose(conn io.ReadWriteCloser) {
|
||||
if hc, ok := conn.(halfCloser); ok {
|
||||
_ = hc.CloseWrite()
|
||||
}
|
||||
}
|
||||
|
||||
// isCleanEOF reports whether a copy terminated on a graceful end-of-stream.
|
||||
// Only in that case is it correct to propagate the EOF via CloseWrite on the
|
||||
// peer; any other error means the flow is broken and both directions should
|
||||
// tear down.
|
||||
func isCleanEOF(err error) bool {
|
||||
return err == nil || errors.Is(err, io.EOF)
|
||||
}
|
||||
|
||||
func isExpectedCopyError(err error) bool {
|
||||
return errors.Is(err, net.ErrClosed) ||
|
||||
errors.Is(err, context.Canceled) ||
|
||||
errors.Is(err, io.EOF) ||
|
||||
errors.Is(err, syscall.ECONNRESET) ||
|
||||
errors.Is(err, syscall.EPIPE) ||
|
||||
errors.Is(err, syscall.ECONNABORTED)
|
||||
}
|
||||
221
util/netrelay/relay_test.go
Normal file
221
util/netrelay/relay_test.go
Normal file
@@ -0,0 +1,221 @@
|
||||
package netrelay
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// tcpPair returns two connected loopback TCP conns.
|
||||
func tcpPair(t *testing.T) (*net.TCPConn, *net.TCPConn) {
|
||||
t.Helper()
|
||||
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
defer ln.Close()
|
||||
|
||||
type result struct {
|
||||
c *net.TCPConn
|
||||
err error
|
||||
}
|
||||
ch := make(chan result, 1)
|
||||
go func() {
|
||||
c, err := ln.Accept()
|
||||
if err != nil {
|
||||
ch <- result{nil, err}
|
||||
return
|
||||
}
|
||||
ch <- result{c.(*net.TCPConn), nil}
|
||||
}()
|
||||
|
||||
dial, err := net.Dial("tcp", ln.Addr().String())
|
||||
require.NoError(t, err)
|
||||
|
||||
r := <-ch
|
||||
require.NoError(t, r.err)
|
||||
return dial.(*net.TCPConn), r.c
|
||||
}
|
||||
|
||||
// TestRelayHalfClose exercises the shutdown(SHUT_WR) scenario that the naive
|
||||
// cancel-both-on-first-EOF pattern breaks. Client A shuts down its write
|
||||
// side; B must still be able to write a full response and A must receive
|
||||
// all of it before its read returns EOF.
|
||||
func TestRelayHalfClose(t *testing.T) {
|
||||
// Real peer pairs for each side of the relay. We relay between relayA
|
||||
// and relayB. Peer A talks through relayA; peer B talks through relayB.
|
||||
peerA, relayA := tcpPair(t)
|
||||
relayB, peerB := tcpPair(t)
|
||||
|
||||
defer peerA.Close()
|
||||
defer peerB.Close()
|
||||
|
||||
// Bound blocking reads/writes so a broken relay fails the test instead of
|
||||
// hanging the test process.
|
||||
deadline := time.Now().Add(5 * time.Second)
|
||||
require.NoError(t, peerA.SetDeadline(deadline))
|
||||
require.NoError(t, peerB.SetDeadline(deadline))
|
||||
|
||||
ctx := t.Context()
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
Relay(ctx, relayA, relayB, Options{})
|
||||
close(done)
|
||||
}()
|
||||
|
||||
// Peer A sends a request, then half-closes its write side.
|
||||
req := []byte("request-payload")
|
||||
_, err := peerA.Write(req)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, peerA.CloseWrite())
|
||||
|
||||
// Peer B reads the request to EOF (FIN must have propagated).
|
||||
got, err := io.ReadAll(peerB)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, req, got)
|
||||
|
||||
// Peer B writes its response; peer A must receive all of it even though
|
||||
// peer A's write side is already closed.
|
||||
resp := make([]byte, 64*1024)
|
||||
for i := range resp {
|
||||
resp[i] = byte(i)
|
||||
}
|
||||
_, err = peerB.Write(resp)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, peerB.Close())
|
||||
|
||||
gotResp, err := io.ReadAll(peerA)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, resp, gotResp)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("relay did not return")
|
||||
}
|
||||
}
|
||||
|
||||
// TestRelayFullDuplex verifies bidirectional copy in the simple case.
|
||||
func TestRelayFullDuplex(t *testing.T) {
|
||||
peerA, relayA := tcpPair(t)
|
||||
relayB, peerB := tcpPair(t)
|
||||
defer peerA.Close()
|
||||
defer peerB.Close()
|
||||
|
||||
// Bound blocking reads/writes so a broken relay fails the test instead of
|
||||
// hanging the test process.
|
||||
deadline := time.Now().Add(5 * time.Second)
|
||||
require.NoError(t, peerA.SetDeadline(deadline))
|
||||
require.NoError(t, peerB.SetDeadline(deadline))
|
||||
|
||||
ctx := t.Context()
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
Relay(ctx, relayA, relayB, Options{})
|
||||
close(done)
|
||||
}()
|
||||
|
||||
type result struct {
|
||||
got []byte
|
||||
err error
|
||||
}
|
||||
resA := make(chan result, 1)
|
||||
resB := make(chan result, 1)
|
||||
|
||||
msgAB := []byte("hello-from-a")
|
||||
msgBA := []byte("hello-from-b")
|
||||
|
||||
go func() {
|
||||
if _, err := peerA.Write(msgAB); err != nil {
|
||||
resA <- result{err: err}
|
||||
return
|
||||
}
|
||||
buf := make([]byte, len(msgBA))
|
||||
_, err := io.ReadFull(peerA, buf)
|
||||
resA <- result{got: buf, err: err}
|
||||
_ = peerA.Close()
|
||||
}()
|
||||
|
||||
go func() {
|
||||
if _, err := peerB.Write(msgBA); err != nil {
|
||||
resB <- result{err: err}
|
||||
return
|
||||
}
|
||||
buf := make([]byte, len(msgAB))
|
||||
_, err := io.ReadFull(peerB, buf)
|
||||
resB <- result{got: buf, err: err}
|
||||
_ = peerB.Close()
|
||||
}()
|
||||
|
||||
a, b := <-resA, <-resB
|
||||
require.NoError(t, a.err)
|
||||
require.Equal(t, msgBA, a.got)
|
||||
require.NoError(t, b.err)
|
||||
require.Equal(t, msgAB, b.got)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("relay did not return")
|
||||
}
|
||||
}
|
||||
|
||||
// TestRelayNoHalfCloseFallback ensures Relay terminates when the underlying
|
||||
// conns don't support CloseWrite (e.g. net.Pipe). Without the fallback to
|
||||
// cancel-both-on-first-EOF, the second direction would block forever.
|
||||
func TestRelayNoHalfCloseFallback(t *testing.T) {
|
||||
a1, a2 := net.Pipe()
|
||||
b1, b2 := net.Pipe()
|
||||
defer a1.Close()
|
||||
defer b1.Close()
|
||||
|
||||
ctx := t.Context()
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
Relay(ctx, a2, b2, Options{})
|
||||
close(done)
|
||||
}()
|
||||
|
||||
// Close peer A's side; a2's Read will return EOF.
|
||||
require.NoError(t, a1.Close())
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("relay did not terminate when half-close is unsupported")
|
||||
}
|
||||
}
|
||||
|
||||
// TestRelayIdleTimeout ensures the idle watchdog tears down a silent flow.
|
||||
func TestRelayIdleTimeout(t *testing.T) {
|
||||
peerA, relayA := tcpPair(t)
|
||||
relayB, peerB := tcpPair(t)
|
||||
defer peerA.Close()
|
||||
defer peerB.Close()
|
||||
|
||||
ctx := t.Context()
|
||||
|
||||
const idle = 150 * time.Millisecond
|
||||
|
||||
start := time.Now()
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
Relay(ctx, relayA, relayB, Options{IdleTimeout: idle})
|
||||
close(done)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("relay did not close on idle")
|
||||
}
|
||||
|
||||
elapsed := time.Since(start)
|
||||
require.GreaterOrEqual(t, elapsed, idle,
|
||||
"relay must not close before the idle timeout elapses")
|
||||
require.Less(t, elapsed, idle+500*time.Millisecond,
|
||||
"relay should close shortly after the idle timeout")
|
||||
}
|
||||
Reference in New Issue
Block a user