mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-22 17:12:02 -04:00
Compare commits
13 Commits
improve-us
...
main
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
5da05ecca6 | ||
|
|
801de8c68d | ||
|
|
a822a33240 | ||
|
|
57b23c5b25 | ||
|
|
1165058fad | ||
|
|
703353d354 | ||
|
|
2fb50aef6b | ||
|
|
eb3aa96257 | ||
|
|
064ec1c832 | ||
|
|
75e408f51c | ||
|
|
5a89e6621b | ||
|
|
06dfa9d4a5 | ||
|
|
45d9ee52c0 |
11
client/firewall/firewalld/firewalld.go
Normal file
11
client/firewall/firewalld/firewalld.go
Normal file
@@ -0,0 +1,11 @@
|
||||
// Package firewalld integrates with the firewalld daemon so NetBird can place
|
||||
// its wg interface into firewalld's "trusted" zone. This is required because
|
||||
// firewalld's nftables chains are created with NFT_CHAIN_OWNER on recent
|
||||
// versions, which returns EPERM to any other process that tries to insert
|
||||
// rules into them. The workaround mirrors what Tailscale does: let firewalld
|
||||
// itself add the accept rules to its own chains by trusting the interface.
|
||||
package firewalld
|
||||
|
||||
// TrustedZone is the firewalld zone name used for interfaces whose traffic
|
||||
// should bypass firewalld filtering.
|
||||
const TrustedZone = "trusted"
|
||||
260
client/firewall/firewalld/firewalld_linux.go
Normal file
260
client/firewall/firewalld/firewalld_linux.go
Normal file
@@ -0,0 +1,260 @@
|
||||
//go:build linux
|
||||
|
||||
package firewalld
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os/exec"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/godbus/dbus/v5"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
dbusDest = "org.fedoraproject.FirewallD1"
|
||||
dbusPath = "/org/fedoraproject/FirewallD1"
|
||||
dbusRootIface = "org.fedoraproject.FirewallD1"
|
||||
dbusZoneIface = "org.fedoraproject.FirewallD1.zone"
|
||||
|
||||
errZoneAlreadySet = "ZONE_ALREADY_SET"
|
||||
errAlreadyEnabled = "ALREADY_ENABLED"
|
||||
errUnknownIface = "UNKNOWN_INTERFACE"
|
||||
errNotEnabled = "NOT_ENABLED"
|
||||
|
||||
// callTimeout bounds each individual DBus or firewall-cmd invocation.
|
||||
// A fresh context is created for each call so a slow DBus probe can't
|
||||
// exhaust the deadline before the firewall-cmd fallback gets to run.
|
||||
callTimeout = 3 * time.Second
|
||||
)
|
||||
|
||||
var (
|
||||
errDBusUnavailable = errors.New("firewalld dbus unavailable")
|
||||
|
||||
// trustLogOnce ensures the "added to trusted zone" message is logged at
|
||||
// Info level only for the first successful add per process; repeat adds
|
||||
// from other init paths are quieter.
|
||||
trustLogOnce sync.Once
|
||||
|
||||
parentCtxMu sync.RWMutex
|
||||
parentCtx context.Context = context.Background()
|
||||
)
|
||||
|
||||
// SetParentContext installs a parent context whose cancellation aborts any
|
||||
// in-flight TrustInterface call. It does not affect UntrustInterface, which
|
||||
// always uses a fresh Background-rooted timeout so cleanup can still run
|
||||
// during engine shutdown when the engine context is already cancelled.
|
||||
func SetParentContext(ctx context.Context) {
|
||||
parentCtxMu.Lock()
|
||||
parentCtx = ctx
|
||||
parentCtxMu.Unlock()
|
||||
}
|
||||
|
||||
func getParentContext() context.Context {
|
||||
parentCtxMu.RLock()
|
||||
defer parentCtxMu.RUnlock()
|
||||
return parentCtx
|
||||
}
|
||||
|
||||
// TrustInterface places iface into firewalld's trusted zone if firewalld is
|
||||
// running. It is idempotent and best-effort: errors are returned so callers
|
||||
// can log, but a non-running firewalld is not an error. Only the first
|
||||
// successful call per process logs at Info. Respects the parent context set
|
||||
// via SetParentContext so startup-time cancellation unblocks it.
|
||||
func TrustInterface(iface string) error {
|
||||
parent := getParentContext()
|
||||
if !isRunning(parent) {
|
||||
return nil
|
||||
}
|
||||
if err := addTrusted(parent, iface); err != nil {
|
||||
return fmt.Errorf("add %s to firewalld trusted zone: %w", iface, err)
|
||||
}
|
||||
trustLogOnce.Do(func() {
|
||||
log.Infof("added %s to firewalld trusted zone", iface)
|
||||
})
|
||||
log.Debugf("firewalld: ensured %s is in trusted zone", iface)
|
||||
return nil
|
||||
}
|
||||
|
||||
// UntrustInterface removes iface from firewalld's trusted zone if firewalld
|
||||
// is running. Idempotent. Uses a Background-rooted timeout so it still runs
|
||||
// during shutdown after the engine context has been cancelled.
|
||||
func UntrustInterface(iface string) error {
|
||||
if !isRunning(context.Background()) {
|
||||
return nil
|
||||
}
|
||||
if err := removeTrusted(context.Background(), iface); err != nil {
|
||||
return fmt.Errorf("remove %s from firewalld trusted zone: %w", iface, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func newCallContext(parent context.Context) (context.Context, context.CancelFunc) {
|
||||
return context.WithTimeout(parent, callTimeout)
|
||||
}
|
||||
|
||||
func isRunning(parent context.Context) bool {
|
||||
ctx, cancel := newCallContext(parent)
|
||||
ok, err := isRunningDBus(ctx)
|
||||
cancel()
|
||||
if err == nil {
|
||||
return ok
|
||||
}
|
||||
if errors.Is(err, errDBusUnavailable) || errors.Is(err, context.DeadlineExceeded) {
|
||||
ctx, cancel = newCallContext(parent)
|
||||
defer cancel()
|
||||
return isRunningCLI(ctx)
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func addTrusted(parent context.Context, iface string) error {
|
||||
ctx, cancel := newCallContext(parent)
|
||||
err := addDBus(ctx, iface)
|
||||
cancel()
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
if !errors.Is(err, errDBusUnavailable) {
|
||||
log.Debugf("firewalld: dbus add failed, falling back to firewall-cmd: %v", err)
|
||||
}
|
||||
ctx, cancel = newCallContext(parent)
|
||||
defer cancel()
|
||||
return addCLI(ctx, iface)
|
||||
}
|
||||
|
||||
func removeTrusted(parent context.Context, iface string) error {
|
||||
ctx, cancel := newCallContext(parent)
|
||||
err := removeDBus(ctx, iface)
|
||||
cancel()
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
if !errors.Is(err, errDBusUnavailable) {
|
||||
log.Debugf("firewalld: dbus remove failed, falling back to firewall-cmd: %v", err)
|
||||
}
|
||||
ctx, cancel = newCallContext(parent)
|
||||
defer cancel()
|
||||
return removeCLI(ctx, iface)
|
||||
}
|
||||
|
||||
func isRunningDBus(ctx context.Context) (bool, error) {
|
||||
conn, err := dbus.SystemBus()
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("%w: %v", errDBusUnavailable, err)
|
||||
}
|
||||
obj := conn.Object(dbusDest, dbusPath)
|
||||
|
||||
var zone string
|
||||
if err := obj.CallWithContext(ctx, dbusRootIface+".getDefaultZone", 0).Store(&zone); err != nil {
|
||||
return false, fmt.Errorf("firewalld getDefaultZone: %w", err)
|
||||
}
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func isRunningCLI(ctx context.Context) bool {
|
||||
if _, err := exec.LookPath("firewall-cmd"); err != nil {
|
||||
return false
|
||||
}
|
||||
return exec.CommandContext(ctx, "firewall-cmd", "--state").Run() == nil
|
||||
}
|
||||
|
||||
func addDBus(ctx context.Context, iface string) error {
|
||||
conn, err := dbus.SystemBus()
|
||||
if err != nil {
|
||||
return fmt.Errorf("%w: %v", errDBusUnavailable, err)
|
||||
}
|
||||
obj := conn.Object(dbusDest, dbusPath)
|
||||
|
||||
call := obj.CallWithContext(ctx, dbusZoneIface+".addInterface", 0, TrustedZone, iface)
|
||||
if call.Err == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if dbusErrContains(call.Err, errAlreadyEnabled) {
|
||||
return nil
|
||||
}
|
||||
|
||||
if dbusErrContains(call.Err, errZoneAlreadySet) {
|
||||
move := obj.CallWithContext(ctx, dbusZoneIface+".changeZoneOfInterface", 0, TrustedZone, iface)
|
||||
if move.Err != nil {
|
||||
return fmt.Errorf("firewalld changeZoneOfInterface: %w", move.Err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
return fmt.Errorf("firewalld addInterface: %w", call.Err)
|
||||
}
|
||||
|
||||
func removeDBus(ctx context.Context, iface string) error {
|
||||
conn, err := dbus.SystemBus()
|
||||
if err != nil {
|
||||
return fmt.Errorf("%w: %v", errDBusUnavailable, err)
|
||||
}
|
||||
obj := conn.Object(dbusDest, dbusPath)
|
||||
|
||||
call := obj.CallWithContext(ctx, dbusZoneIface+".removeInterface", 0, TrustedZone, iface)
|
||||
if call.Err == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if dbusErrContains(call.Err, errUnknownIface) || dbusErrContains(call.Err, errNotEnabled) {
|
||||
return nil
|
||||
}
|
||||
|
||||
return fmt.Errorf("firewalld removeInterface: %w", call.Err)
|
||||
}
|
||||
|
||||
func addCLI(ctx context.Context, iface string) error {
|
||||
if _, err := exec.LookPath("firewall-cmd"); err != nil {
|
||||
return fmt.Errorf("firewall-cmd not available: %w", err)
|
||||
}
|
||||
|
||||
// --change-interface (no --permanent) binds the interface for the
|
||||
// current runtime only; we do not want membership to persist across
|
||||
// reboots because netbird re-asserts it on every startup.
|
||||
out, err := exec.CommandContext(ctx,
|
||||
"firewall-cmd", "--zone="+TrustedZone, "--change-interface="+iface,
|
||||
).CombinedOutput()
|
||||
if err != nil {
|
||||
return fmt.Errorf("firewall-cmd change-interface: %w: %s", err, strings.TrimSpace(string(out)))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func removeCLI(ctx context.Context, iface string) error {
|
||||
if _, err := exec.LookPath("firewall-cmd"); err != nil {
|
||||
return fmt.Errorf("firewall-cmd not available: %w", err)
|
||||
}
|
||||
|
||||
out, err := exec.CommandContext(ctx,
|
||||
"firewall-cmd", "--zone="+TrustedZone, "--remove-interface="+iface,
|
||||
).CombinedOutput()
|
||||
if err != nil {
|
||||
msg := strings.TrimSpace(string(out))
|
||||
if strings.Contains(msg, errUnknownIface) || strings.Contains(msg, errNotEnabled) {
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("firewall-cmd remove-interface: %w: %s", err, msg)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func dbusErrContains(err error, code string) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
var de dbus.Error
|
||||
if errors.As(err, &de) {
|
||||
for _, b := range de.Body {
|
||||
if s, ok := b.(string); ok && strings.Contains(s, code) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
return strings.Contains(err.Error(), code)
|
||||
}
|
||||
49
client/firewall/firewalld/firewalld_linux_test.go
Normal file
49
client/firewall/firewalld/firewalld_linux_test.go
Normal file
@@ -0,0 +1,49 @@
|
||||
//go:build linux
|
||||
|
||||
package firewalld
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/godbus/dbus/v5"
|
||||
)
|
||||
|
||||
func TestDBusErrContains(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
err error
|
||||
code string
|
||||
want bool
|
||||
}{
|
||||
{"nil error", nil, errZoneAlreadySet, false},
|
||||
{"plain error match", errors.New("ZONE_ALREADY_SET: wt0"), errZoneAlreadySet, true},
|
||||
{"plain error miss", errors.New("something else"), errZoneAlreadySet, false},
|
||||
{
|
||||
"dbus.Error body match",
|
||||
dbus.Error{Name: "org.fedoraproject.FirewallD1.Exception", Body: []any{"ZONE_ALREADY_SET: wt0"}},
|
||||
errZoneAlreadySet,
|
||||
true,
|
||||
},
|
||||
{
|
||||
"dbus.Error body miss",
|
||||
dbus.Error{Name: "org.fedoraproject.FirewallD1.Exception", Body: []any{"INVALID_INTERFACE"}},
|
||||
errAlreadyEnabled,
|
||||
false,
|
||||
},
|
||||
{
|
||||
"dbus.Error non-string body falls back to Error()",
|
||||
dbus.Error{Name: "x", Body: []any{123}},
|
||||
"x",
|
||||
true,
|
||||
},
|
||||
}
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
got := dbusErrContains(tc.err, tc.code)
|
||||
if got != tc.want {
|
||||
t.Fatalf("dbusErrContains(%v, %q) = %v; want %v", tc.err, tc.code, got, tc.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
25
client/firewall/firewalld/firewalld_other.go
Normal file
25
client/firewall/firewalld/firewalld_other.go
Normal file
@@ -0,0 +1,25 @@
|
||||
//go:build !linux
|
||||
|
||||
package firewalld
|
||||
|
||||
import "context"
|
||||
|
||||
// SetParentContext is a no-op on non-Linux platforms because firewalld only
|
||||
// runs on Linux.
|
||||
func SetParentContext(context.Context) {
|
||||
// intentionally empty: firewalld is a Linux-only daemon
|
||||
}
|
||||
|
||||
// TrustInterface is a no-op on non-Linux platforms because firewalld only
|
||||
// runs on Linux.
|
||||
func TrustInterface(string) error {
|
||||
// intentionally empty: firewalld is a Linux-only daemon
|
||||
return nil
|
||||
}
|
||||
|
||||
// UntrustInterface is a no-op on non-Linux platforms because firewalld only
|
||||
// runs on Linux.
|
||||
func UntrustInterface(string) error {
|
||||
// intentionally empty: firewalld is a Linux-only daemon
|
||||
return nil
|
||||
}
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
"github.com/netbirdio/netbird/client/firewall/firewalld"
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
@@ -86,6 +87,12 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error {
|
||||
log.Warnf("raw table not available, notrack rules will be disabled: %v", err)
|
||||
}
|
||||
|
||||
// Trust after all fatal init steps so a later failure doesn't leave the
|
||||
// interface in firewalld's trusted zone without a corresponding Close.
|
||||
if err := firewalld.TrustInterface(m.wgIface.Name()); err != nil {
|
||||
log.Warnf("failed to trust interface in firewalld: %v", err)
|
||||
}
|
||||
|
||||
// persist early to ensure cleanup of chains
|
||||
go func() {
|
||||
if err := stateManager.PersistState(context.Background()); err != nil {
|
||||
@@ -191,6 +198,12 @@ func (m *Manager) Close(stateManager *statemanager.Manager) error {
|
||||
merr = multierror.Append(merr, fmt.Errorf("reset router: %w", err))
|
||||
}
|
||||
|
||||
// Appending to merr intentionally blocks DeleteState below so ShutdownState
|
||||
// stays persisted and the crash-recovery path retries firewalld cleanup.
|
||||
if err := firewalld.UntrustInterface(m.wgIface.Name()); err != nil {
|
||||
merr = multierror.Append(merr, err)
|
||||
}
|
||||
|
||||
// attempt to delete state only if all other operations succeeded
|
||||
if merr == nil {
|
||||
if err := stateManager.DeleteState(&ShutdownState{}); err != nil {
|
||||
@@ -217,6 +230,11 @@ func (m *Manager) AllowNetbird() error {
|
||||
if err != nil {
|
||||
return fmt.Errorf("allow netbird interface traffic: %w", err)
|
||||
}
|
||||
|
||||
if err := firewalld.TrustInterface(m.wgIface.Name()); err != nil {
|
||||
log.Warnf("failed to trust interface in firewalld: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -14,6 +14,7 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/sys/unix"
|
||||
|
||||
"github.com/netbirdio/netbird/client/firewall/firewalld"
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
@@ -217,6 +218,10 @@ func (m *Manager) AllowNetbird() error {
|
||||
return fmt.Errorf("flush allow input netbird rules: %w", err)
|
||||
}
|
||||
|
||||
if err := firewalld.TrustInterface(m.wgIface.Name()); err != nil {
|
||||
log.Warnf("failed to trust interface in firewalld: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -19,6 +19,7 @@ import (
|
||||
"golang.org/x/sys/unix"
|
||||
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
"github.com/netbirdio/netbird/client/firewall/firewalld"
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
nbid "github.com/netbirdio/netbird/client/internal/acl/id"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/ipfwdstate"
|
||||
@@ -40,6 +41,8 @@ const (
|
||||
chainNameForward = "FORWARD"
|
||||
chainNameMangleForward = "netbird-mangle-forward"
|
||||
|
||||
firewalldTableName = "firewalld"
|
||||
|
||||
userDataAcceptForwardRuleIif = "frwacceptiif"
|
||||
userDataAcceptForwardRuleOif = "frwacceptoif"
|
||||
userDataAcceptInputRule = "inputaccept"
|
||||
@@ -133,6 +136,10 @@ func (r *router) Reset() error {
|
||||
merr = multierror.Append(merr, fmt.Errorf("remove accept filter rules: %w", err))
|
||||
}
|
||||
|
||||
if err := firewalld.UntrustInterface(r.wgIface.Name()); err != nil {
|
||||
merr = multierror.Append(merr, err)
|
||||
}
|
||||
|
||||
if err := r.removeNatPreroutingRules(); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("remove filter prerouting rules: %w", err))
|
||||
}
|
||||
@@ -280,6 +287,10 @@ func (r *router) createContainers() error {
|
||||
log.Errorf("failed to add accept rules for the forward chain: %s", err)
|
||||
}
|
||||
|
||||
if err := firewalld.TrustInterface(r.wgIface.Name()); err != nil {
|
||||
log.Warnf("failed to trust interface in firewalld: %v", err)
|
||||
}
|
||||
|
||||
if err := r.refreshRulesMap(); err != nil {
|
||||
log.Errorf("failed to refresh rules: %s", err)
|
||||
}
|
||||
@@ -1319,6 +1330,13 @@ func (r *router) isExternalChain(chain *nftables.Chain) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// Skip firewalld-owned chains. Firewalld creates its chains with the
|
||||
// NFT_CHAIN_OWNER flag, so inserting rules into them returns EPERM.
|
||||
// We delegate acceptance to firewalld by trusting the interface instead.
|
||||
if chain.Table.Name == firewalldTableName {
|
||||
return false
|
||||
}
|
||||
|
||||
// Skip all iptables-managed tables in the ip family
|
||||
if chain.Table.Family == nftables.TableFamilyIPv4 && isIptablesTable(chain.Table.Name) {
|
||||
return false
|
||||
|
||||
@@ -3,6 +3,9 @@
|
||||
package uspfilter
|
||||
|
||||
import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/firewall/firewalld"
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
)
|
||||
|
||||
@@ -16,6 +19,9 @@ func (m *Manager) Close(stateManager *statemanager.Manager) error {
|
||||
if m.nativeFirewall != nil {
|
||||
return m.nativeFirewall.Close(stateManager)
|
||||
}
|
||||
if err := firewalld.UntrustInterface(m.wgIface.Name()); err != nil {
|
||||
log.Warnf("failed to untrust interface in firewalld: %v", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -24,5 +30,8 @@ func (m *Manager) AllowNetbird() error {
|
||||
if m.nativeFirewall != nil {
|
||||
return m.nativeFirewall.AllowNetbird()
|
||||
}
|
||||
if err := firewalld.TrustInterface(m.wgIface.Name()); err != nil {
|
||||
log.Warnf("failed to trust interface in firewalld: %v", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
|
||||
// IFaceMapper defines subset methods of interface required for manager
|
||||
type IFaceMapper interface {
|
||||
Name() string
|
||||
SetFilter(device.PacketFilter) error
|
||||
Address() wgaddr.Address
|
||||
GetWGDevice() *wgdevice.Device
|
||||
|
||||
@@ -1,125 +0,0 @@
|
||||
package conntrack
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"testing"
|
||||
|
||||
"github.com/google/gopacket/layers"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestTCPCapEvicts(t *testing.T) {
|
||||
t.Setenv(EnvTCPMaxEntries, "4")
|
||||
|
||||
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||
defer tracker.Close()
|
||||
require.Equal(t, 4, tracker.maxEntries)
|
||||
|
||||
src := netip.MustParseAddr("100.64.0.1")
|
||||
dst := netip.MustParseAddr("100.64.0.2")
|
||||
|
||||
for i := 0; i < 10; i++ {
|
||||
tracker.TrackOutbound(src, dst, uint16(10000+i), 80, TCPSyn, 0)
|
||||
}
|
||||
require.LessOrEqual(t, len(tracker.connections), 4,
|
||||
"TCP table must not exceed the configured cap")
|
||||
require.Greater(t, len(tracker.connections), 0,
|
||||
"some entries must remain after eviction")
|
||||
|
||||
// The most recently admitted flow must be present: eviction must make
|
||||
// room for new entries, not silently drop them.
|
||||
require.Contains(t, tracker.connections,
|
||||
ConnKey{SrcIP: src, DstIP: dst, SrcPort: uint16(10009), DstPort: 80},
|
||||
"newest TCP flow must be admitted after eviction")
|
||||
// A pre-cap flow must have been evicted to fit the last one.
|
||||
require.NotContains(t, tracker.connections,
|
||||
ConnKey{SrcIP: src, DstIP: dst, SrcPort: uint16(10000), DstPort: 80},
|
||||
"oldest TCP flow should have been evicted")
|
||||
}
|
||||
|
||||
func TestTCPCapPrefersTombstonedForEviction(t *testing.T) {
|
||||
t.Setenv(EnvTCPMaxEntries, "3")
|
||||
|
||||
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||
defer tracker.Close()
|
||||
|
||||
src := netip.MustParseAddr("100.64.0.1")
|
||||
dst := netip.MustParseAddr("100.64.0.2")
|
||||
|
||||
// Fill to cap with 3 live connections.
|
||||
for i := 0; i < 3; i++ {
|
||||
tracker.TrackOutbound(src, dst, uint16(20000+i), 80, TCPSyn, 0)
|
||||
}
|
||||
require.Len(t, tracker.connections, 3)
|
||||
|
||||
// Tombstone one by sending RST through IsValidInbound.
|
||||
tombstonedKey := ConnKey{SrcIP: src, DstIP: dst, SrcPort: 20001, DstPort: 80}
|
||||
require.True(t, tracker.IsValidInbound(dst, src, 80, 20001, TCPRst|TCPAck, 0))
|
||||
require.True(t, tracker.connections[tombstonedKey].IsTombstone())
|
||||
|
||||
// Another live connection forces eviction. The tombstone must go first.
|
||||
tracker.TrackOutbound(src, dst, uint16(29999), 80, TCPSyn, 0)
|
||||
|
||||
_, tombstonedStillPresent := tracker.connections[tombstonedKey]
|
||||
require.False(t, tombstonedStillPresent,
|
||||
"tombstoned entry should be evicted before live entries")
|
||||
require.LessOrEqual(t, len(tracker.connections), 3)
|
||||
|
||||
// Both live pre-cap entries must survive: eviction must prefer the
|
||||
// tombstone, not just satisfy the size bound by dropping any entry.
|
||||
require.Contains(t, tracker.connections,
|
||||
ConnKey{SrcIP: src, DstIP: dst, SrcPort: uint16(20000), DstPort: 80},
|
||||
"live entries must not be evicted while a tombstone exists")
|
||||
require.Contains(t, tracker.connections,
|
||||
ConnKey{SrcIP: src, DstIP: dst, SrcPort: uint16(20002), DstPort: 80},
|
||||
"live entries must not be evicted while a tombstone exists")
|
||||
}
|
||||
|
||||
func TestUDPCapEvicts(t *testing.T) {
|
||||
t.Setenv(EnvUDPMaxEntries, "5")
|
||||
|
||||
tracker := NewUDPTracker(DefaultUDPTimeout, logger, flowLogger)
|
||||
defer tracker.Close()
|
||||
require.Equal(t, 5, tracker.maxEntries)
|
||||
|
||||
src := netip.MustParseAddr("100.64.0.1")
|
||||
dst := netip.MustParseAddr("100.64.0.2")
|
||||
|
||||
for i := 0; i < 12; i++ {
|
||||
tracker.TrackOutbound(src, dst, uint16(30000+i), 53, 0)
|
||||
}
|
||||
require.LessOrEqual(t, len(tracker.connections), 5)
|
||||
require.Greater(t, len(tracker.connections), 0)
|
||||
|
||||
require.Contains(t, tracker.connections,
|
||||
ConnKey{SrcIP: src, DstIP: dst, SrcPort: uint16(30011), DstPort: 53},
|
||||
"newest UDP flow must be admitted after eviction")
|
||||
require.NotContains(t, tracker.connections,
|
||||
ConnKey{SrcIP: src, DstIP: dst, SrcPort: uint16(30000), DstPort: 53},
|
||||
"oldest UDP flow should have been evicted")
|
||||
}
|
||||
|
||||
func TestICMPCapEvicts(t *testing.T) {
|
||||
t.Setenv(EnvICMPMaxEntries, "3")
|
||||
|
||||
tracker := NewICMPTracker(DefaultICMPTimeout, logger, flowLogger)
|
||||
defer tracker.Close()
|
||||
require.Equal(t, 3, tracker.maxEntries)
|
||||
|
||||
src := netip.MustParseAddr("100.64.0.1")
|
||||
dst := netip.MustParseAddr("100.64.0.2")
|
||||
|
||||
echoReq := layers.CreateICMPv4TypeCode(uint8(layers.ICMPv4TypeEchoRequest), 0)
|
||||
for i := 0; i < 8; i++ {
|
||||
tracker.TrackOutbound(src, dst, uint16(i), echoReq, nil, 64)
|
||||
}
|
||||
require.LessOrEqual(t, len(tracker.connections), 3)
|
||||
require.Greater(t, len(tracker.connections), 0)
|
||||
|
||||
require.Contains(t, tracker.connections,
|
||||
ICMPConnKey{SrcIP: src, DstIP: dst, ID: uint16(7)},
|
||||
"newest ICMP flow must be admitted after eviction")
|
||||
require.NotContains(t, tracker.connections,
|
||||
ICMPConnKey{SrcIP: src, DstIP: dst, ID: uint16(0)},
|
||||
"oldest ICMP flow should have been evicted")
|
||||
}
|
||||
@@ -3,61 +3,14 @@ package conntrack
|
||||
import (
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"os"
|
||||
"strconv"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
|
||||
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
|
||||
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
||||
)
|
||||
|
||||
// evictSampleSize bounds how many map entries we scan per eviction call.
|
||||
// Keeps eviction O(1) even at cap under sustained load; the sampled-LRU
|
||||
// heuristic is good enough for a conntrack table that only overflows under
|
||||
// abuse.
|
||||
const evictSampleSize = 8
|
||||
|
||||
// envDuration parses an os.Getenv(name) as a time.Duration. Falls back to
|
||||
// def on empty or invalid; logs a warning on invalid.
|
||||
func envDuration(logger *nblog.Logger, name string, def time.Duration) time.Duration {
|
||||
v := os.Getenv(name)
|
||||
if v == "" {
|
||||
return def
|
||||
}
|
||||
d, err := time.ParseDuration(v)
|
||||
if err != nil {
|
||||
logger.Warn3("invalid %s=%q: %v, using default", name, v, err)
|
||||
return def
|
||||
}
|
||||
if d <= 0 {
|
||||
logger.Warn2("invalid %s=%q: must be positive, using default", name, v)
|
||||
return def
|
||||
}
|
||||
return d
|
||||
}
|
||||
|
||||
// envInt parses an os.Getenv(name) as an int. Falls back to def on empty,
|
||||
// invalid, or non-positive. Logs a warning on invalid input.
|
||||
func envInt(logger *nblog.Logger, name string, def int) int {
|
||||
v := os.Getenv(name)
|
||||
if v == "" {
|
||||
return def
|
||||
}
|
||||
n, err := strconv.Atoi(v)
|
||||
switch {
|
||||
case err != nil:
|
||||
logger.Warn3("invalid %s=%q: %v, using default", name, v, err)
|
||||
return def
|
||||
case n <= 0:
|
||||
logger.Warn2("invalid %s=%q: must be positive, using default", name, v)
|
||||
return def
|
||||
}
|
||||
return n
|
||||
}
|
||||
|
||||
// BaseConnTrack provides common fields and locking for all connection types
|
||||
type BaseConnTrack struct {
|
||||
FlowId uuid.UUID
|
||||
|
||||
@@ -1,11 +0,0 @@
|
||||
//go:build !ios && !android
|
||||
|
||||
package conntrack
|
||||
|
||||
// Default per-tracker entry caps on desktop/server platforms. These mirror
|
||||
// typical Linux netfilter nf_conntrack_max territory with ample headroom.
|
||||
const (
|
||||
DefaultMaxTCPEntries = 65536
|
||||
DefaultMaxUDPEntries = 16384
|
||||
DefaultMaxICMPEntries = 2048
|
||||
)
|
||||
@@ -1,13 +0,0 @@
|
||||
//go:build ios || android
|
||||
|
||||
package conntrack
|
||||
|
||||
// Default per-tracker entry caps on mobile platforms. iOS network extensions
|
||||
// are capped at ~50 MB; Android runs under aggressive memory pressure. These
|
||||
// values keep conntrack footprint well under 5 MB worst case (TCPConnTrack
|
||||
// is ~200 B plus map overhead).
|
||||
const (
|
||||
DefaultMaxTCPEntries = 4096
|
||||
DefaultMaxUDPEntries = 2048
|
||||
DefaultMaxICMPEntries = 512
|
||||
)
|
||||
@@ -44,9 +44,6 @@ type ICMPConnTrack struct {
|
||||
ICMPCode uint8
|
||||
}
|
||||
|
||||
// EnvICMPMaxEntries caps the ICMP conntrack table size.
|
||||
const EnvICMPMaxEntries = "NB_CONNTRACK_ICMP_MAX"
|
||||
|
||||
// ICMPTracker manages ICMP connection states
|
||||
type ICMPTracker struct {
|
||||
logger *nblog.Logger
|
||||
@@ -55,7 +52,6 @@ type ICMPTracker struct {
|
||||
cleanupTicker *time.Ticker
|
||||
tickerCancel context.CancelFunc
|
||||
mutex sync.RWMutex
|
||||
maxEntries int
|
||||
flowLogger nftypes.FlowLogger
|
||||
}
|
||||
|
||||
@@ -139,7 +135,6 @@ func NewICMPTracker(timeout time.Duration, logger *nblog.Logger, flowLogger nfty
|
||||
timeout: timeout,
|
||||
cleanupTicker: time.NewTicker(ICMPCleanupInterval),
|
||||
tickerCancel: cancel,
|
||||
maxEntries: envInt(logger, EnvICMPMaxEntries, DefaultMaxICMPEntries),
|
||||
flowLogger: flowLogger,
|
||||
}
|
||||
|
||||
@@ -226,9 +221,7 @@ func (t *ICMPTracker) track(
|
||||
|
||||
// non echo requests don't need tracking
|
||||
if typ != uint8(layers.ICMPv4TypeEchoRequest) {
|
||||
if t.logger.Enabled(nblog.LevelTrace) {
|
||||
t.logger.Trace3("New %s ICMP connection %s - %s", direction, key, icmpInfo)
|
||||
}
|
||||
t.logger.Trace3("New %s ICMP connection %s - %s", direction, key, icmpInfo)
|
||||
t.sendStartEvent(direction, srcIP, dstIP, typ, code, ruleId, size)
|
||||
return
|
||||
}
|
||||
@@ -247,15 +240,10 @@ func (t *ICMPTracker) track(
|
||||
conn.UpdateCounters(direction, size)
|
||||
|
||||
t.mutex.Lock()
|
||||
if t.maxEntries > 0 && len(t.connections) >= t.maxEntries {
|
||||
t.evictOneLocked()
|
||||
}
|
||||
t.connections[key] = conn
|
||||
t.mutex.Unlock()
|
||||
|
||||
if t.logger.Enabled(nblog.LevelTrace) {
|
||||
t.logger.Trace3("New %s ICMP connection %s - %s", direction, key, icmpInfo)
|
||||
}
|
||||
t.logger.Trace3("New %s ICMP connection %s - %s", direction, key, icmpInfo)
|
||||
t.sendEvent(nftypes.TypeStart, conn, ruleId)
|
||||
}
|
||||
|
||||
@@ -298,34 +286,6 @@ func (t *ICMPTracker) cleanupRoutine(ctx context.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
// evictOneLocked removes one entry to make room. Caller must hold t.mutex.
|
||||
// Bounded sample scan: picks the oldest among up to evictSampleSize entries.
|
||||
func (t *ICMPTracker) evictOneLocked() {
|
||||
var candKey ICMPConnKey
|
||||
var candSeen int64
|
||||
haveCand := false
|
||||
sampled := 0
|
||||
|
||||
for k, c := range t.connections {
|
||||
seen := c.lastSeen.Load()
|
||||
if !haveCand || seen < candSeen {
|
||||
candKey = k
|
||||
candSeen = seen
|
||||
haveCand = true
|
||||
}
|
||||
sampled++
|
||||
if sampled >= evictSampleSize {
|
||||
break
|
||||
}
|
||||
}
|
||||
if haveCand {
|
||||
if evicted := t.connections[candKey]; evicted != nil {
|
||||
t.sendEvent(nftypes.TypeEnd, evicted, nil)
|
||||
}
|
||||
delete(t.connections, candKey)
|
||||
}
|
||||
}
|
||||
|
||||
func (t *ICMPTracker) cleanup() {
|
||||
t.mutex.Lock()
|
||||
defer t.mutex.Unlock()
|
||||
@@ -334,10 +294,8 @@ func (t *ICMPTracker) cleanup() {
|
||||
if conn.timeoutExceeded(t.timeout) {
|
||||
delete(t.connections, key)
|
||||
|
||||
if t.logger.Enabled(nblog.LevelTrace) {
|
||||
t.logger.Trace5("Removed ICMP connection %s (timeout) [in: %d Pkts/%d B out: %d Pkts/%d B]",
|
||||
key, conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load())
|
||||
}
|
||||
t.logger.Trace5("Removed ICMP connection %s (timeout) [in: %d Pkts/%d B out: %d Pkts/%d B]",
|
||||
key, conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load())
|
||||
t.sendEvent(nftypes.TypeEnd, conn, nil)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -38,27 +38,6 @@ const (
|
||||
TCPHandshakeTimeout = 60 * time.Second
|
||||
// TCPCleanupInterval is how often we check for stale connections
|
||||
TCPCleanupInterval = 5 * time.Minute
|
||||
// FinWaitTimeout bounds FIN_WAIT_1 / FIN_WAIT_2 / CLOSING states.
|
||||
// Matches Linux netfilter nf_conntrack_tcp_timeout_fin_wait.
|
||||
FinWaitTimeout = 60 * time.Second
|
||||
// CloseWaitTimeout bounds CLOSE_WAIT. Matches Linux default; apps
|
||||
// holding CloseWait longer than this should bump the env var.
|
||||
CloseWaitTimeout = 60 * time.Second
|
||||
// LastAckTimeout bounds LAST_ACK. Matches Linux default.
|
||||
LastAckTimeout = 30 * time.Second
|
||||
)
|
||||
|
||||
// Env vars to override per-state teardown timeouts. Values parsed by
|
||||
// time.ParseDuration (e.g. "120s", "2m"). Invalid values fall back to the
|
||||
// defaults above with a warning.
|
||||
const (
|
||||
EnvTCPFinWaitTimeout = "NB_CONNTRACK_TCP_FIN_WAIT_TIMEOUT"
|
||||
EnvTCPCloseWaitTimeout = "NB_CONNTRACK_TCP_CLOSE_WAIT_TIMEOUT"
|
||||
EnvTCPLastAckTimeout = "NB_CONNTRACK_TCP_LAST_ACK_TIMEOUT"
|
||||
|
||||
// EnvTCPMaxEntries caps the TCP conntrack table size. Oldest entries
|
||||
// (tombstones first) are evicted when the cap is reached.
|
||||
EnvTCPMaxEntries = "NB_CONNTRACK_TCP_MAX"
|
||||
)
|
||||
|
||||
// TCPState represents the state of a TCP connection
|
||||
@@ -154,18 +133,14 @@ func (t *TCPConnTrack) SetTombstone() {
|
||||
|
||||
// TCPTracker manages TCP connection states
|
||||
type TCPTracker struct {
|
||||
logger *nblog.Logger
|
||||
connections map[ConnKey]*TCPConnTrack
|
||||
mutex sync.RWMutex
|
||||
cleanupTicker *time.Ticker
|
||||
tickerCancel context.CancelFunc
|
||||
timeout time.Duration
|
||||
waitTimeout time.Duration
|
||||
finWaitTimeout time.Duration
|
||||
closeWaitTimeout time.Duration
|
||||
lastAckTimeout time.Duration
|
||||
maxEntries int
|
||||
flowLogger nftypes.FlowLogger
|
||||
logger *nblog.Logger
|
||||
connections map[ConnKey]*TCPConnTrack
|
||||
mutex sync.RWMutex
|
||||
cleanupTicker *time.Ticker
|
||||
tickerCancel context.CancelFunc
|
||||
timeout time.Duration
|
||||
waitTimeout time.Duration
|
||||
flowLogger nftypes.FlowLogger
|
||||
}
|
||||
|
||||
// NewTCPTracker creates a new TCP connection tracker
|
||||
@@ -180,17 +155,13 @@ func NewTCPTracker(timeout time.Duration, logger *nblog.Logger, flowLogger nftyp
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
tracker := &TCPTracker{
|
||||
logger: logger,
|
||||
connections: make(map[ConnKey]*TCPConnTrack),
|
||||
cleanupTicker: time.NewTicker(TCPCleanupInterval),
|
||||
tickerCancel: cancel,
|
||||
timeout: timeout,
|
||||
waitTimeout: waitTimeout,
|
||||
finWaitTimeout: envDuration(logger, EnvTCPFinWaitTimeout, FinWaitTimeout),
|
||||
closeWaitTimeout: envDuration(logger, EnvTCPCloseWaitTimeout, CloseWaitTimeout),
|
||||
lastAckTimeout: envDuration(logger, EnvTCPLastAckTimeout, LastAckTimeout),
|
||||
maxEntries: envInt(logger, EnvTCPMaxEntries, DefaultMaxTCPEntries),
|
||||
flowLogger: flowLogger,
|
||||
logger: logger,
|
||||
connections: make(map[ConnKey]*TCPConnTrack),
|
||||
cleanupTicker: time.NewTicker(TCPCleanupInterval),
|
||||
tickerCancel: cancel,
|
||||
timeout: timeout,
|
||||
waitTimeout: waitTimeout,
|
||||
flowLogger: flowLogger,
|
||||
}
|
||||
|
||||
go tracker.cleanupRoutine(ctx)
|
||||
@@ -238,12 +209,6 @@ func (t *TCPTracker) track(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, fla
|
||||
if exists || flags&TCPSyn == 0 {
|
||||
return
|
||||
}
|
||||
// Reject illegal SYN combinations (SYN+FIN, SYN+RST, …) so they don't
|
||||
// create spurious conntrack entries. Not mandated by RFC 9293 but a
|
||||
// common hardening (Linux netfilter/nftables rejects these too).
|
||||
if !isValidFlagCombination(flags) {
|
||||
return
|
||||
}
|
||||
|
||||
conn := &TCPConnTrack{
|
||||
BaseConnTrack: BaseConnTrack{
|
||||
@@ -260,65 +225,20 @@ func (t *TCPTracker) track(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, fla
|
||||
conn.state.Store(int32(TCPStateNew))
|
||||
conn.DNATOrigPort.Store(uint32(origPort))
|
||||
|
||||
if t.logger.Enabled(nblog.LevelTrace) {
|
||||
if origPort != 0 {
|
||||
t.logger.Trace4("New %s TCP connection: %s (port DNAT %d -> %d)", direction, key, origPort, dstPort)
|
||||
} else {
|
||||
t.logger.Trace2("New %s TCP connection: %s", direction, key)
|
||||
}
|
||||
if origPort != 0 {
|
||||
t.logger.Trace4("New %s TCP connection: %s (port DNAT %d -> %d)", direction, key, origPort, dstPort)
|
||||
} else {
|
||||
t.logger.Trace2("New %s TCP connection: %s", direction, key)
|
||||
}
|
||||
t.updateState(key, conn, flags, direction, size)
|
||||
|
||||
t.mutex.Lock()
|
||||
if t.maxEntries > 0 && len(t.connections) >= t.maxEntries {
|
||||
t.evictOneLocked()
|
||||
}
|
||||
t.connections[key] = conn
|
||||
t.mutex.Unlock()
|
||||
|
||||
t.sendEvent(nftypes.TypeStart, conn, ruleID)
|
||||
}
|
||||
|
||||
// evictOneLocked removes one entry to make room. Caller must hold t.mutex.
|
||||
// Bounded scan: samples up to evictSampleSize pseudo-random entries (Go map
|
||||
// iteration order is randomized), preferring a tombstone. If no tombstone
|
||||
// found in the sample, evicts the oldest among the sampled entries. O(1)
|
||||
// worst case — cheap enough to run on every insert at cap during abuse.
|
||||
func (t *TCPTracker) evictOneLocked() {
|
||||
var candKey ConnKey
|
||||
var candSeen int64
|
||||
haveCand := false
|
||||
sampled := 0
|
||||
|
||||
for k, c := range t.connections {
|
||||
if c.IsTombstone() {
|
||||
delete(t.connections, k)
|
||||
return
|
||||
}
|
||||
seen := c.lastSeen.Load()
|
||||
if !haveCand || seen < candSeen {
|
||||
candKey = k
|
||||
candSeen = seen
|
||||
haveCand = true
|
||||
}
|
||||
sampled++
|
||||
if sampled >= evictSampleSize {
|
||||
break
|
||||
}
|
||||
}
|
||||
if haveCand {
|
||||
if evicted := t.connections[candKey]; evicted != nil {
|
||||
// TypeEnd is already emitted at the state transition to
|
||||
// TimeWait and when a connection is tombstoned. Only emit
|
||||
// here when we're reaping a still-active flow.
|
||||
if evicted.GetState() != TCPStateTimeWait && !evicted.IsTombstone() {
|
||||
t.sendEvent(nftypes.TypeEnd, evicted, nil)
|
||||
}
|
||||
}
|
||||
delete(t.connections, candKey)
|
||||
}
|
||||
}
|
||||
|
||||
// IsValidInbound checks if an inbound TCP packet matches a tracked connection
|
||||
func (t *TCPTracker) IsValidInbound(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, size int) bool {
|
||||
key := ConnKey{
|
||||
@@ -336,19 +256,12 @@ func (t *TCPTracker) IsValidInbound(srcIP, dstIP netip.Addr, srcPort, dstPort ui
|
||||
return false
|
||||
}
|
||||
|
||||
// Reject illegal flag combinations regardless of state. These never belong
|
||||
// to a legitimate flow and must not advance or tear down state.
|
||||
if !isValidFlagCombination(flags) {
|
||||
if t.logger.Enabled(nblog.LevelWarn) {
|
||||
t.logger.Warn3("TCP illegal flag combination %x for connection %s (state %s)", flags, key, conn.GetState())
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
currentState := conn.GetState()
|
||||
if !t.isValidStateForFlags(currentState, flags) {
|
||||
if t.logger.Enabled(nblog.LevelWarn) {
|
||||
t.logger.Warn3("TCP state %s is not valid with flags %x for connection %s", currentState, flags, key)
|
||||
t.logger.Warn3("TCP state %s is not valid with flags %x for connection %s", currentState, flags, key)
|
||||
// allow all flags for established for now
|
||||
if currentState == TCPStateEstablished {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
@@ -357,208 +270,116 @@ func (t *TCPTracker) IsValidInbound(srcIP, dstIP netip.Addr, srcPort, dstPort ui
|
||||
return true
|
||||
}
|
||||
|
||||
// updateState updates the TCP connection state based on flags.
|
||||
// updateState updates the TCP connection state based on flags
|
||||
func (t *TCPTracker) updateState(key ConnKey, conn *TCPConnTrack, flags uint8, packetDir nftypes.Direction, size int) {
|
||||
conn.UpdateCounters(packetDir, size)
|
||||
|
||||
// Malformed flag combinations must not refresh lastSeen or drive state,
|
||||
// otherwise spoofed packets keep a dead flow alive past its timeout.
|
||||
if !isValidFlagCombination(flags) {
|
||||
return
|
||||
}
|
||||
|
||||
conn.UpdateLastSeen()
|
||||
conn.UpdateCounters(packetDir, size)
|
||||
|
||||
currentState := conn.GetState()
|
||||
|
||||
if flags&TCPRst != 0 {
|
||||
// Hardening beyond RFC 9293 §3.10.7.4: without sequence tracking we
|
||||
// cannot apply the RFC 5961 in-window RST check, so we conservatively
|
||||
// reject RSTs that the spec would accept (TIME-WAIT with in-window
|
||||
// SEQ, SynSent from same direction as own SYN, etc.).
|
||||
t.handleRst(key, conn, currentState, packetDir)
|
||||
return
|
||||
}
|
||||
|
||||
newState := nextState(currentState, conn.Direction, packetDir, flags)
|
||||
if newState == 0 || !conn.CompareAndSwapState(currentState, newState) {
|
||||
return
|
||||
}
|
||||
t.onTransition(key, conn, currentState, newState, packetDir)
|
||||
}
|
||||
|
||||
// handleRst processes a RST segment. Late RSTs in TimeWait and spoofed RSTs
|
||||
// from the SYN direction are ignored; otherwise the flow is tombstoned.
|
||||
func (t *TCPTracker) handleRst(key ConnKey, conn *TCPConnTrack, currentState TCPState, packetDir nftypes.Direction) {
|
||||
// TimeWait exists to absorb late segments; don't let a late RST
|
||||
// tombstone the entry and break same-4-tuple reuse.
|
||||
if currentState == TCPStateTimeWait {
|
||||
return
|
||||
}
|
||||
// A RST from the same direction as the SYN cannot be a legitimate
|
||||
// response and must not tear down a half-open connection.
|
||||
if currentState == TCPStateSynSent && packetDir == conn.Direction {
|
||||
return
|
||||
}
|
||||
if !conn.CompareAndSwapState(currentState, TCPStateClosed) {
|
||||
return
|
||||
}
|
||||
conn.SetTombstone()
|
||||
if t.logger.Enabled(nblog.LevelTrace) {
|
||||
t.logger.Trace6("TCP connection reset: %s (dir: %s) [in: %d Pkts/%d B, out: %d Pkts/%d B]",
|
||||
key, packetDir, conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load())
|
||||
}
|
||||
t.sendEvent(nftypes.TypeEnd, conn, nil)
|
||||
}
|
||||
|
||||
// stateTransition describes one state's transition logic. It receives the
|
||||
// packet's flags plus whether the packet direction matches the connection's
|
||||
// origin direction (same=true means same side as the SYN initiator). Return 0
|
||||
// for no transition.
|
||||
type stateTransition func(flags uint8, connDir nftypes.Direction, same bool) TCPState
|
||||
|
||||
// stateTable maps each state to its transition function. Centralized here so
|
||||
// nextState stays trivial and each rule is easy to read in isolation.
|
||||
var stateTable = map[TCPState]stateTransition{
|
||||
TCPStateNew: transNew,
|
||||
TCPStateSynSent: transSynSent,
|
||||
TCPStateSynReceived: transSynReceived,
|
||||
TCPStateEstablished: transEstablished,
|
||||
TCPStateFinWait1: transFinWait1,
|
||||
TCPStateFinWait2: transFinWait2,
|
||||
TCPStateClosing: transClosing,
|
||||
TCPStateCloseWait: transCloseWait,
|
||||
TCPStateLastAck: transLastAck,
|
||||
}
|
||||
|
||||
// nextState returns the target TCP state for the given current state and
|
||||
// packet, or 0 if the packet does not trigger a transition.
|
||||
func nextState(currentState TCPState, connDir, packetDir nftypes.Direction, flags uint8) TCPState {
|
||||
fn, ok := stateTable[currentState]
|
||||
if !ok {
|
||||
return 0
|
||||
}
|
||||
return fn(flags, connDir, packetDir == connDir)
|
||||
}
|
||||
|
||||
func transNew(flags uint8, connDir nftypes.Direction, _ bool) TCPState {
|
||||
if flags&TCPSyn != 0 && flags&TCPAck == 0 {
|
||||
if connDir == nftypes.Egress {
|
||||
return TCPStateSynSent
|
||||
if conn.CompareAndSwapState(currentState, TCPStateClosed) {
|
||||
conn.SetTombstone()
|
||||
t.logger.Trace6("TCP connection reset: %s (dir: %s) [in: %d Pkts/%d B, out: %d Pkts/%d B]",
|
||||
key, packetDir, conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load())
|
||||
t.sendEvent(nftypes.TypeEnd, conn, nil)
|
||||
}
|
||||
return TCPStateSynReceived
|
||||
return
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func transSynSent(flags uint8, _ nftypes.Direction, same bool) TCPState {
|
||||
if flags&TCPSyn != 0 && flags&TCPAck != 0 {
|
||||
if same {
|
||||
return TCPStateSynReceived // simultaneous open
|
||||
var newState TCPState
|
||||
switch currentState {
|
||||
case TCPStateNew:
|
||||
if flags&TCPSyn != 0 && flags&TCPAck == 0 {
|
||||
if conn.Direction == nftypes.Egress {
|
||||
newState = TCPStateSynSent
|
||||
} else {
|
||||
newState = TCPStateSynReceived
|
||||
}
|
||||
}
|
||||
return TCPStateEstablished
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func transSynReceived(flags uint8, _ nftypes.Direction, same bool) TCPState {
|
||||
if flags&TCPAck != 0 && flags&TCPSyn == 0 && same {
|
||||
return TCPStateEstablished
|
||||
}
|
||||
return 0
|
||||
}
|
||||
case TCPStateSynSent:
|
||||
if flags&TCPSyn != 0 && flags&TCPAck != 0 {
|
||||
if packetDir != conn.Direction {
|
||||
newState = TCPStateEstablished
|
||||
} else {
|
||||
// Simultaneous open
|
||||
newState = TCPStateSynReceived
|
||||
}
|
||||
}
|
||||
|
||||
func transEstablished(flags uint8, _ nftypes.Direction, same bool) TCPState {
|
||||
if flags&TCPFin == 0 {
|
||||
return 0
|
||||
}
|
||||
if same {
|
||||
return TCPStateFinWait1
|
||||
}
|
||||
return TCPStateCloseWait
|
||||
}
|
||||
case TCPStateSynReceived:
|
||||
if flags&TCPAck != 0 && flags&TCPSyn == 0 {
|
||||
if packetDir == conn.Direction {
|
||||
newState = TCPStateEstablished
|
||||
}
|
||||
}
|
||||
|
||||
// transFinWait1 handles the active-close peer response. A FIN carrying our
|
||||
// ACK piggybacked goes straight to TIME-WAIT (RFC 9293 §3.10.7.4, FIN-WAIT-1:
|
||||
// "if our FIN has been ACKed... enter the TIME-WAIT state"); a lone FIN moves
|
||||
// to CLOSING; a pure ACK of our FIN moves to FIN-WAIT-2.
|
||||
func transFinWait1(flags uint8, _ nftypes.Direction, same bool) TCPState {
|
||||
if same {
|
||||
return 0
|
||||
}
|
||||
if flags&TCPFin != 0 && flags&TCPAck != 0 {
|
||||
return TCPStateTimeWait
|
||||
}
|
||||
switch {
|
||||
case flags&TCPFin != 0:
|
||||
return TCPStateClosing
|
||||
case flags&TCPAck != 0:
|
||||
return TCPStateFinWait2
|
||||
}
|
||||
return 0
|
||||
}
|
||||
case TCPStateEstablished:
|
||||
if flags&TCPFin != 0 {
|
||||
if packetDir == conn.Direction {
|
||||
newState = TCPStateFinWait1
|
||||
} else {
|
||||
newState = TCPStateCloseWait
|
||||
}
|
||||
}
|
||||
|
||||
// transFinWait2 ignores own-side FIN retransmits; only the peer's FIN advances.
|
||||
func transFinWait2(flags uint8, _ nftypes.Direction, same bool) TCPState {
|
||||
if flags&TCPFin != 0 && !same {
|
||||
return TCPStateTimeWait
|
||||
}
|
||||
return 0
|
||||
}
|
||||
case TCPStateFinWait1:
|
||||
if packetDir != conn.Direction {
|
||||
switch {
|
||||
case flags&TCPFin != 0 && flags&TCPAck != 0:
|
||||
newState = TCPStateClosing
|
||||
case flags&TCPFin != 0:
|
||||
newState = TCPStateClosing
|
||||
case flags&TCPAck != 0:
|
||||
newState = TCPStateFinWait2
|
||||
}
|
||||
}
|
||||
|
||||
// transClosing completes a simultaneous close on the peer's ACK.
|
||||
func transClosing(flags uint8, _ nftypes.Direction, same bool) TCPState {
|
||||
if flags&TCPAck != 0 && !same {
|
||||
return TCPStateTimeWait
|
||||
}
|
||||
return 0
|
||||
}
|
||||
case TCPStateFinWait2:
|
||||
if flags&TCPFin != 0 {
|
||||
newState = TCPStateTimeWait
|
||||
}
|
||||
|
||||
// transCloseWait only advances to LastAck when WE send FIN, ignoring peer retransmits.
|
||||
func transCloseWait(flags uint8, _ nftypes.Direction, same bool) TCPState {
|
||||
if flags&TCPFin != 0 && same {
|
||||
return TCPStateLastAck
|
||||
}
|
||||
return 0
|
||||
}
|
||||
case TCPStateClosing:
|
||||
if flags&TCPAck != 0 {
|
||||
newState = TCPStateTimeWait
|
||||
}
|
||||
|
||||
// transLastAck closes the flow only on the peer's ACK (not our own ACK retransmits).
|
||||
func transLastAck(flags uint8, _ nftypes.Direction, same bool) TCPState {
|
||||
if flags&TCPAck != 0 && !same {
|
||||
return TCPStateClosed
|
||||
}
|
||||
return 0
|
||||
}
|
||||
case TCPStateCloseWait:
|
||||
if flags&TCPFin != 0 {
|
||||
newState = TCPStateLastAck
|
||||
}
|
||||
|
||||
// onTransition handles logging and flow-event emission after a successful
|
||||
// state transition. TimeWait and Closed are terminal for flow accounting.
|
||||
func (t *TCPTracker) onTransition(key ConnKey, conn *TCPConnTrack, from, to TCPState, packetDir nftypes.Direction) {
|
||||
traceOn := t.logger.Enabled(nblog.LevelTrace)
|
||||
if traceOn {
|
||||
t.logger.Trace4("TCP connection %s transitioned from %s to %s (dir: %s)", key, from, to, packetDir)
|
||||
case TCPStateLastAck:
|
||||
if flags&TCPAck != 0 {
|
||||
newState = TCPStateClosed
|
||||
}
|
||||
}
|
||||
|
||||
switch to {
|
||||
case TCPStateTimeWait:
|
||||
if traceOn {
|
||||
if newState != 0 && conn.CompareAndSwapState(currentState, newState) {
|
||||
t.logger.Trace4("TCP connection %s transitioned from %s to %s (dir: %s)", key, currentState, newState, packetDir)
|
||||
|
||||
switch newState {
|
||||
case TCPStateTimeWait:
|
||||
t.logger.Trace5("TCP connection %s completed [in: %d Pkts/%d B, out: %d Pkts/%d B]",
|
||||
key, conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load())
|
||||
}
|
||||
t.sendEvent(nftypes.TypeEnd, conn, nil)
|
||||
case TCPStateClosed:
|
||||
conn.SetTombstone()
|
||||
if traceOn {
|
||||
t.sendEvent(nftypes.TypeEnd, conn, nil)
|
||||
|
||||
case TCPStateClosed:
|
||||
conn.SetTombstone()
|
||||
t.logger.Trace5("TCP connection %s closed gracefully [in: %d Pkts/%d, B out: %d Pkts/%d B]",
|
||||
key, conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load())
|
||||
t.sendEvent(nftypes.TypeEnd, conn, nil)
|
||||
}
|
||||
t.sendEvent(nftypes.TypeEnd, conn, nil)
|
||||
}
|
||||
}
|
||||
|
||||
// isValidStateForFlags checks if the TCP flags are valid for the current
|
||||
// connection state. Caller must have already verified the flag combination is
|
||||
// legal via isValidFlagCombination.
|
||||
// isValidStateForFlags checks if the TCP flags are valid for the current connection state
|
||||
func (t *TCPTracker) isValidStateForFlags(state TCPState, flags uint8) bool {
|
||||
if !isValidFlagCombination(flags) {
|
||||
return false
|
||||
}
|
||||
if flags&TCPRst != 0 {
|
||||
if state == TCPStateSynSent {
|
||||
return flags&TCPAck != 0
|
||||
@@ -628,24 +449,15 @@ func (t *TCPTracker) cleanup() {
|
||||
timeout = t.waitTimeout
|
||||
case TCPStateEstablished:
|
||||
timeout = t.timeout
|
||||
case TCPStateFinWait1, TCPStateFinWait2, TCPStateClosing:
|
||||
timeout = t.finWaitTimeout
|
||||
case TCPStateCloseWait:
|
||||
timeout = t.closeWaitTimeout
|
||||
case TCPStateLastAck:
|
||||
timeout = t.lastAckTimeout
|
||||
default:
|
||||
// SynSent / SynReceived / New
|
||||
timeout = TCPHandshakeTimeout
|
||||
}
|
||||
|
||||
if conn.timeoutExceeded(timeout) {
|
||||
delete(t.connections, key)
|
||||
|
||||
if t.logger.Enabled(nblog.LevelTrace) {
|
||||
t.logger.Trace6("Cleaned up timed-out TCP connection %s (%s) [in: %d Pkts/%d, B out: %d Pkts/%d B]",
|
||||
key, conn.GetState(), conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load())
|
||||
}
|
||||
t.logger.Trace6("Cleaned up timed-out TCP connection %s (%s) [in: %d Pkts/%d, B out: %d Pkts/%d B]",
|
||||
key, conn.GetState(), conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load())
|
||||
|
||||
// event already handled by state change
|
||||
if currentState != TCPStateTimeWait {
|
||||
|
||||
@@ -1,100 +0,0 @@
|
||||
package conntrack
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// RST hygiene tests: the tracker currently closes the flow on any RST that
|
||||
// matches the 4-tuple, regardless of direction or state. These tests cover
|
||||
// the minimum checks we want (no SEQ tracking).
|
||||
|
||||
func TestTCPRstInSynSentWrongDirection(t *testing.T) {
|
||||
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||
defer tracker.Close()
|
||||
|
||||
srcIP := netip.MustParseAddr("100.64.0.1")
|
||||
dstIP := netip.MustParseAddr("100.64.0.2")
|
||||
srcPort := uint16(12345)
|
||||
dstPort := uint16(80)
|
||||
key := ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort}
|
||||
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn, 0)
|
||||
conn := tracker.connections[key]
|
||||
require.Equal(t, TCPStateSynSent, conn.GetState())
|
||||
|
||||
// A RST arriving in the same direction as the SYN (i.e. TrackOutbound)
|
||||
// cannot be a legitimate response. It must not close the connection.
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPRst|TCPAck, 0)
|
||||
require.Equal(t, TCPStateSynSent, conn.GetState(),
|
||||
"RST in same direction as SYN must not close connection")
|
||||
require.False(t, conn.IsTombstone())
|
||||
}
|
||||
|
||||
func TestTCPRstInTimeWaitIgnored(t *testing.T) {
|
||||
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||
defer tracker.Close()
|
||||
|
||||
srcIP := netip.MustParseAddr("100.64.0.1")
|
||||
dstIP := netip.MustParseAddr("100.64.0.2")
|
||||
srcPort := uint16(12345)
|
||||
dstPort := uint16(80)
|
||||
key := ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort}
|
||||
|
||||
// Drive to TIME-WAIT via active close.
|
||||
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0)
|
||||
require.True(t, tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0))
|
||||
require.True(t, tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0))
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
|
||||
|
||||
conn := tracker.connections[key]
|
||||
require.Equal(t, TCPStateTimeWait, conn.GetState())
|
||||
require.False(t, conn.IsTombstone(), "TIME-WAIT must not be tombstoned")
|
||||
|
||||
// Late RST during TIME-WAIT must not tombstone the entry (TIME-WAIT
|
||||
// exists to absorb late segments).
|
||||
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst, 0)
|
||||
require.Equal(t, TCPStateTimeWait, conn.GetState(),
|
||||
"RST in TIME-WAIT must not transition state")
|
||||
require.False(t, conn.IsTombstone(),
|
||||
"RST in TIME-WAIT must not tombstone the entry")
|
||||
}
|
||||
|
||||
func TestTCPIllegalFlagCombos(t *testing.T) {
|
||||
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||
defer tracker.Close()
|
||||
|
||||
srcIP := netip.MustParseAddr("100.64.0.1")
|
||||
dstIP := netip.MustParseAddr("100.64.0.2")
|
||||
srcPort := uint16(12345)
|
||||
dstPort := uint16(80)
|
||||
key := ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort}
|
||||
|
||||
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
|
||||
conn := tracker.connections[key]
|
||||
|
||||
// Illegal combos must be rejected and must not change state.
|
||||
combos := []struct {
|
||||
name string
|
||||
flags uint8
|
||||
}{
|
||||
{"SYN+RST", TCPSyn | TCPRst},
|
||||
{"FIN+RST", TCPFin | TCPRst},
|
||||
{"SYN+FIN", TCPSyn | TCPFin},
|
||||
{"SYN+FIN+RST", TCPSyn | TCPFin | TCPRst},
|
||||
}
|
||||
|
||||
for _, c := range combos {
|
||||
t.Run(c.name, func(t *testing.T) {
|
||||
before := conn.GetState()
|
||||
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, c.flags, 0)
|
||||
require.False(t, valid, "illegal flag combo must be rejected: %s", c.name)
|
||||
require.Equal(t, before, conn.GetState(),
|
||||
"illegal flag combo must not change state")
|
||||
require.False(t, conn.IsTombstone())
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -1,235 +0,0 @@
|
||||
package conntrack
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// These tests exercise cases where the TCP state machine currently advances
|
||||
// on retransmitted or wrong-direction segments and tears the flow down
|
||||
// prematurely. They are expected to fail until the direction checks are added.
|
||||
|
||||
func TestTCPCloseWaitRetransmittedPeerFIN(t *testing.T) {
|
||||
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||
defer tracker.Close()
|
||||
|
||||
srcIP := netip.MustParseAddr("100.64.0.1")
|
||||
dstIP := netip.MustParseAddr("100.64.0.2")
|
||||
srcPort := uint16(12345)
|
||||
dstPort := uint16(80)
|
||||
key := ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort}
|
||||
|
||||
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
|
||||
|
||||
// Peer sends FIN -> CloseWait (our app has not yet closed).
|
||||
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0)
|
||||
require.True(t, valid)
|
||||
conn := tracker.connections[key]
|
||||
require.Equal(t, TCPStateCloseWait, conn.GetState())
|
||||
|
||||
// Peer retransmits their FIN (ACK may have been delayed). We have NOT
|
||||
// sent our FIN yet, so state must remain CloseWait.
|
||||
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0)
|
||||
require.True(t, valid, "retransmitted peer FIN must still be accepted")
|
||||
require.Equal(t, TCPStateCloseWait, conn.GetState(),
|
||||
"retransmitted peer FIN must not advance CloseWait to LastAck")
|
||||
|
||||
// Our app finally closes -> LastAck.
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0)
|
||||
require.Equal(t, TCPStateLastAck, conn.GetState())
|
||||
|
||||
// Peer ACK closes.
|
||||
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)
|
||||
require.True(t, valid)
|
||||
require.Equal(t, TCPStateClosed, conn.GetState())
|
||||
}
|
||||
|
||||
func TestTCPFinWait2RetransmittedOwnFIN(t *testing.T) {
|
||||
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||
defer tracker.Close()
|
||||
|
||||
srcIP := netip.MustParseAddr("100.64.0.1")
|
||||
dstIP := netip.MustParseAddr("100.64.0.2")
|
||||
srcPort := uint16(12345)
|
||||
dstPort := uint16(80)
|
||||
key := ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort}
|
||||
|
||||
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
|
||||
|
||||
// We initiate close.
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0)
|
||||
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)
|
||||
require.True(t, valid)
|
||||
conn := tracker.connections[key]
|
||||
require.Equal(t, TCPStateFinWait2, conn.GetState())
|
||||
|
||||
// Stray retransmit of our own FIN (same direction as originator) must
|
||||
// NOT advance FinWait2 to TimeWait; only the peer's FIN should.
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0)
|
||||
require.Equal(t, TCPStateFinWait2, conn.GetState(),
|
||||
"own FIN retransmit must not advance FinWait2 to TimeWait")
|
||||
|
||||
// Peer FIN -> TimeWait.
|
||||
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0)
|
||||
require.True(t, valid)
|
||||
require.Equal(t, TCPStateTimeWait, conn.GetState())
|
||||
}
|
||||
|
||||
func TestTCPLastAckDirectionCheck(t *testing.T) {
|
||||
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||
defer tracker.Close()
|
||||
|
||||
srcIP := netip.MustParseAddr("100.64.0.1")
|
||||
dstIP := netip.MustParseAddr("100.64.0.2")
|
||||
srcPort := uint16(12345)
|
||||
dstPort := uint16(80)
|
||||
key := ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort}
|
||||
|
||||
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
|
||||
|
||||
// Drive to LastAck: peer FIN -> CloseWait, our FIN -> LastAck.
|
||||
require.True(t, tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0))
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0)
|
||||
conn := tracker.connections[key]
|
||||
require.Equal(t, TCPStateLastAck, conn.GetState())
|
||||
|
||||
// Our own ACK retransmit (same direction as originator) must NOT close.
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
|
||||
require.Equal(t, TCPStateLastAck, conn.GetState(),
|
||||
"own ACK retransmit in LastAck must not transition to Closed")
|
||||
|
||||
// Peer's ACK -> Closed.
|
||||
require.True(t, tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0))
|
||||
require.Equal(t, TCPStateClosed, conn.GetState())
|
||||
}
|
||||
|
||||
func TestTCPFinWait1OwnAckDoesNotAdvance(t *testing.T) {
|
||||
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||
defer tracker.Close()
|
||||
|
||||
srcIP := netip.MustParseAddr("100.64.0.1")
|
||||
dstIP := netip.MustParseAddr("100.64.0.2")
|
||||
srcPort := uint16(12345)
|
||||
dstPort := uint16(80)
|
||||
key := ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort}
|
||||
|
||||
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
|
||||
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0)
|
||||
conn := tracker.connections[key]
|
||||
require.Equal(t, TCPStateFinWait1, conn.GetState())
|
||||
|
||||
// Our own ACK retransmit (same direction as originator) must not advance.
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
|
||||
require.Equal(t, TCPStateFinWait1, conn.GetState(),
|
||||
"own ACK in FinWait1 must not advance to FinWait2")
|
||||
}
|
||||
|
||||
func TestTCPPerStateTeardownTimeouts(t *testing.T) {
|
||||
// Verify cleanup reaps entries in each teardown state at the configured
|
||||
// per-state timeout, not at the single handshake timeout.
|
||||
t.Setenv(EnvTCPFinWaitTimeout, "50ms")
|
||||
t.Setenv(EnvTCPCloseWaitTimeout, "80ms")
|
||||
t.Setenv(EnvTCPLastAckTimeout, "30ms")
|
||||
|
||||
dstIP := netip.MustParseAddr("100.64.0.2")
|
||||
dstPort := uint16(80)
|
||||
|
||||
// Drives a connection to the target state, forces its lastSeen well
|
||||
// beyond the configured timeout, runs cleanup, and asserts reaping.
|
||||
cases := []struct {
|
||||
name string
|
||||
// drive takes a fresh tracker and returns the conn key after
|
||||
// transitioning the flow into the intended teardown state.
|
||||
drive func(t *testing.T, tr *TCPTracker, srcIP netip.Addr, srcPort uint16) (ConnKey, TCPState)
|
||||
}{
|
||||
{
|
||||
name: "FinWait1",
|
||||
drive: func(t *testing.T, tr *TCPTracker, srcIP netip.Addr, srcPort uint16) (ConnKey, TCPState) {
|
||||
establishConnection(t, tr, srcIP, dstIP, srcPort, dstPort)
|
||||
tr.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0) // → FinWait1
|
||||
return ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort}, TCPStateFinWait1
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "FinWait2",
|
||||
drive: func(t *testing.T, tr *TCPTracker, srcIP netip.Addr, srcPort uint16) (ConnKey, TCPState) {
|
||||
establishConnection(t, tr, srcIP, dstIP, srcPort, dstPort)
|
||||
tr.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0) // FinWait1
|
||||
require.True(t, tr.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)) // → FinWait2
|
||||
return ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort}, TCPStateFinWait2
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "CloseWait",
|
||||
drive: func(t *testing.T, tr *TCPTracker, srcIP netip.Addr, srcPort uint16) (ConnKey, TCPState) {
|
||||
establishConnection(t, tr, srcIP, dstIP, srcPort, dstPort)
|
||||
require.True(t, tr.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0)) // → CloseWait
|
||||
return ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort}, TCPStateCloseWait
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "LastAck",
|
||||
drive: func(t *testing.T, tr *TCPTracker, srcIP netip.Addr, srcPort uint16) (ConnKey, TCPState) {
|
||||
establishConnection(t, tr, srcIP, dstIP, srcPort, dstPort)
|
||||
require.True(t, tr.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0)) // CloseWait
|
||||
tr.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0) // → LastAck
|
||||
return ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort}, TCPStateLastAck
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// Use a unique source port per subtest so nothing aliases.
|
||||
port := uint16(12345)
|
||||
for _, c := range cases {
|
||||
t.Run(c.name, func(t *testing.T) {
|
||||
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||
defer tracker.Close()
|
||||
|
||||
require.Equal(t, 50*time.Millisecond, tracker.finWaitTimeout)
|
||||
require.Equal(t, 80*time.Millisecond, tracker.closeWaitTimeout)
|
||||
require.Equal(t, 30*time.Millisecond, tracker.lastAckTimeout)
|
||||
|
||||
srcIP := netip.MustParseAddr("100.64.0.1")
|
||||
port++
|
||||
key, wantState := c.drive(t, tracker, srcIP, port)
|
||||
conn := tracker.connections[key]
|
||||
require.NotNil(t, conn)
|
||||
require.Equal(t, wantState, conn.GetState())
|
||||
|
||||
// Age the entry past the largest per-state timeout.
|
||||
conn.lastSeen.Store(time.Now().Add(-500 * time.Millisecond).UnixNano())
|
||||
tracker.cleanup()
|
||||
_, exists := tracker.connections[key]
|
||||
require.False(t, exists, "%s entry should be reaped", c.name)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestTCPEstablishedPSHACKInFinStates(t *testing.T) {
|
||||
// Verifies FIN|PSH|ACK and bare ACK keepalives are not dropped in FIN
|
||||
// teardown states, which some stacks emit during close.
|
||||
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||
defer tracker.Close()
|
||||
|
||||
srcIP := netip.MustParseAddr("100.64.0.1")
|
||||
dstIP := netip.MustParseAddr("100.64.0.2")
|
||||
srcPort := uint16(12345)
|
||||
dstPort := uint16(80)
|
||||
|
||||
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
|
||||
|
||||
// Peer FIN -> CloseWait.
|
||||
require.True(t, tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0))
|
||||
|
||||
// Peer pushes trailing data + FIN|PSH|ACK (legal).
|
||||
require.True(t, tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPPush|TCPAck, 100),
|
||||
"FIN|PSH|ACK in CloseWait must be accepted")
|
||||
|
||||
// Bare ACK keepalive from peer in CloseWait must be accepted.
|
||||
require.True(t, tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0),
|
||||
"bare ACK in CloseWait must be accepted")
|
||||
}
|
||||
@@ -17,9 +17,6 @@ const (
|
||||
DefaultUDPTimeout = 30 * time.Second
|
||||
// UDPCleanupInterval is how often we check for stale connections
|
||||
UDPCleanupInterval = 15 * time.Second
|
||||
|
||||
// EnvUDPMaxEntries caps the UDP conntrack table size.
|
||||
EnvUDPMaxEntries = "NB_CONNTRACK_UDP_MAX"
|
||||
)
|
||||
|
||||
// UDPConnTrack represents a UDP connection state
|
||||
@@ -37,7 +34,6 @@ type UDPTracker struct {
|
||||
cleanupTicker *time.Ticker
|
||||
tickerCancel context.CancelFunc
|
||||
mutex sync.RWMutex
|
||||
maxEntries int
|
||||
flowLogger nftypes.FlowLogger
|
||||
}
|
||||
|
||||
@@ -55,7 +51,6 @@ func NewUDPTracker(timeout time.Duration, logger *nblog.Logger, flowLogger nftyp
|
||||
timeout: timeout,
|
||||
cleanupTicker: time.NewTicker(UDPCleanupInterval),
|
||||
tickerCancel: cancel,
|
||||
maxEntries: envInt(logger, EnvUDPMaxEntries, DefaultMaxUDPEntries),
|
||||
flowLogger: flowLogger,
|
||||
}
|
||||
|
||||
@@ -122,18 +117,13 @@ func (t *UDPTracker) track(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, d
|
||||
conn.UpdateCounters(direction, size)
|
||||
|
||||
t.mutex.Lock()
|
||||
if t.maxEntries > 0 && len(t.connections) >= t.maxEntries {
|
||||
t.evictOneLocked()
|
||||
}
|
||||
t.connections[key] = conn
|
||||
t.mutex.Unlock()
|
||||
|
||||
if t.logger.Enabled(nblog.LevelTrace) {
|
||||
if origPort != 0 {
|
||||
t.logger.Trace4("New %s UDP connection: %s (port DNAT %d -> %d)", direction, key, origPort, dstPort)
|
||||
} else {
|
||||
t.logger.Trace2("New %s UDP connection: %s", direction, key)
|
||||
}
|
||||
if origPort != 0 {
|
||||
t.logger.Trace4("New %s UDP connection: %s (port DNAT %d -> %d)", direction, key, origPort, dstPort)
|
||||
} else {
|
||||
t.logger.Trace2("New %s UDP connection: %s", direction, key)
|
||||
}
|
||||
t.sendEvent(nftypes.TypeStart, conn, ruleID)
|
||||
}
|
||||
@@ -161,34 +151,6 @@ func (t *UDPTracker) IsValidInbound(srcIP netip.Addr, dstIP netip.Addr, srcPort
|
||||
return true
|
||||
}
|
||||
|
||||
// evictOneLocked removes one entry to make room. Caller must hold t.mutex.
|
||||
// Bounded sample: picks the oldest among up to evictSampleSize entries.
|
||||
func (t *UDPTracker) evictOneLocked() {
|
||||
var candKey ConnKey
|
||||
var candSeen int64
|
||||
haveCand := false
|
||||
sampled := 0
|
||||
|
||||
for k, c := range t.connections {
|
||||
seen := c.lastSeen.Load()
|
||||
if !haveCand || seen < candSeen {
|
||||
candKey = k
|
||||
candSeen = seen
|
||||
haveCand = true
|
||||
}
|
||||
sampled++
|
||||
if sampled >= evictSampleSize {
|
||||
break
|
||||
}
|
||||
}
|
||||
if haveCand {
|
||||
if evicted := t.connections[candKey]; evicted != nil {
|
||||
t.sendEvent(nftypes.TypeEnd, evicted, nil)
|
||||
}
|
||||
delete(t.connections, candKey)
|
||||
}
|
||||
}
|
||||
|
||||
// cleanupRoutine periodically removes stale connections
|
||||
func (t *UDPTracker) cleanupRoutine(ctx context.Context) {
|
||||
defer t.cleanupTicker.Stop()
|
||||
@@ -211,10 +173,8 @@ func (t *UDPTracker) cleanup() {
|
||||
if conn.timeoutExceeded(t.timeout) {
|
||||
delete(t.connections, key)
|
||||
|
||||
if t.logger.Enabled(nblog.LevelTrace) {
|
||||
t.logger.Trace5("Removed UDP connection %s (timeout) [in: %d Pkts/%d B, out: %d Pkts/%d B]",
|
||||
key, conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load())
|
||||
}
|
||||
t.logger.Trace5("Removed UDP connection %s (timeout) [in: %d Pkts/%d B, out: %d Pkts/%d B]",
|
||||
key, conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load())
|
||||
t.sendEvent(nftypes.TypeEnd, conn, nil)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -709,9 +709,7 @@ func (m *Manager) filterOutbound(packetData []byte, size int) bool {
|
||||
|
||||
srcIP, dstIP := m.extractIPs(d)
|
||||
if !srcIP.IsValid() {
|
||||
if m.logger.Enabled(nblog.LevelError) {
|
||||
m.logger.Error1("Unknown network layer: %v", d.decoded[0])
|
||||
}
|
||||
m.logger.Error1("Unknown network layer: %v", d.decoded[0])
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -810,9 +808,7 @@ func (m *Manager) clampTCPMSS(packetData []byte, d *decoder) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
if m.logger.Enabled(nblog.LevelTrace) {
|
||||
m.logger.Trace2("Clamped TCP MSS from %d to %d", currentMSS, m.mssClampValue)
|
||||
}
|
||||
m.logger.Trace2("Clamped TCP MSS from %d to %d", currentMSS, m.mssClampValue)
|
||||
return true
|
||||
}
|
||||
|
||||
@@ -935,10 +931,8 @@ func (m *Manager) filterInbound(packetData []byte, size int) bool {
|
||||
|
||||
// TODO: pass fragments of routed packets to forwarder
|
||||
if fragment {
|
||||
if m.logger.Enabled(nblog.LevelTrace) {
|
||||
m.logger.Trace4("packet is a fragment: src=%v dst=%v id=%v flags=%v",
|
||||
srcIP, dstIP, d.ip4.Id, d.ip4.Flags)
|
||||
}
|
||||
m.logger.Trace4("packet is a fragment: src=%v dst=%v id=%v flags=%v",
|
||||
srcIP, dstIP, d.ip4.Id, d.ip4.Flags)
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -980,10 +974,8 @@ func (m *Manager) handleLocalTraffic(d *decoder, srcIP, dstIP netip.Addr, packet
|
||||
pnum := getProtocolFromPacket(d)
|
||||
srcPort, dstPort := getPortsFromPacket(d)
|
||||
|
||||
if m.logger.Enabled(nblog.LevelTrace) {
|
||||
m.logger.Trace6("Dropping local packet (ACL denied): rule_id=%s proto=%v src=%s:%d dst=%s:%d",
|
||||
ruleID, pnum, srcIP, srcPort, dstIP, dstPort)
|
||||
}
|
||||
m.logger.Trace6("Dropping local packet (ACL denied): rule_id=%s proto=%v src=%s:%d dst=%s:%d",
|
||||
ruleID, pnum, srcIP, srcPort, dstIP, dstPort)
|
||||
|
||||
m.flowLogger.StoreEvent(nftypes.EventFields{
|
||||
FlowID: uuid.New(),
|
||||
@@ -1033,10 +1025,8 @@ func (m *Manager) handleForwardedLocalTraffic(packetData []byte) bool {
|
||||
func (m *Manager) handleRoutedTraffic(d *decoder, srcIP, dstIP netip.Addr, packetData []byte, size int) bool {
|
||||
// Drop if routing is disabled
|
||||
if !m.routingEnabled.Load() {
|
||||
if m.logger.Enabled(nblog.LevelTrace) {
|
||||
m.logger.Trace2("Dropping routed packet (routing disabled): src=%s dst=%s",
|
||||
srcIP, dstIP)
|
||||
}
|
||||
m.logger.Trace2("Dropping routed packet (routing disabled): src=%s dst=%s",
|
||||
srcIP, dstIP)
|
||||
return true
|
||||
}
|
||||
|
||||
@@ -1053,10 +1043,8 @@ func (m *Manager) handleRoutedTraffic(d *decoder, srcIP, dstIP netip.Addr, packe
|
||||
if !pass {
|
||||
proto := getProtocolFromPacket(d)
|
||||
|
||||
if m.logger.Enabled(nblog.LevelTrace) {
|
||||
m.logger.Trace6("Dropping routed packet (ACL denied): rule_id=%s proto=%v src=%s:%d dst=%s:%d",
|
||||
ruleID, proto, srcIP, srcPort, dstIP, dstPort)
|
||||
}
|
||||
m.logger.Trace6("Dropping routed packet (ACL denied): rule_id=%s proto=%v src=%s:%d dst=%s:%d",
|
||||
ruleID, proto, srcIP, srcPort, dstIP, dstPort)
|
||||
|
||||
m.flowLogger.StoreEvent(nftypes.EventFields{
|
||||
FlowID: uuid.New(),
|
||||
@@ -1138,9 +1126,7 @@ func getPortsFromPacket(d *decoder) (srcPort, dstPort uint16) {
|
||||
// It returns true, true if the packet is a fragment and valid.
|
||||
func (m *Manager) isValidPacket(d *decoder, packetData []byte) (bool, bool) {
|
||||
if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil {
|
||||
if m.logger.Enabled(nblog.LevelTrace) {
|
||||
m.logger.Trace1("couldn't decode packet, err: %s", err)
|
||||
}
|
||||
m.logger.Trace1("couldn't decode packet, err: %s", err)
|
||||
return false, false
|
||||
}
|
||||
|
||||
|
||||
@@ -31,12 +31,20 @@ var logger = log.NewFromLogrus(logrus.StandardLogger())
|
||||
var flowLogger = netflow.NewManager(nil, []byte{}, nil).GetLogger()
|
||||
|
||||
type IFaceMock struct {
|
||||
NameFunc func() string
|
||||
SetFilterFunc func(device.PacketFilter) error
|
||||
AddressFunc func() wgaddr.Address
|
||||
GetWGDeviceFunc func() *wgdevice.Device
|
||||
GetDeviceFunc func() *device.FilteredDevice
|
||||
}
|
||||
|
||||
func (i *IFaceMock) Name() string {
|
||||
if i.NameFunc == nil {
|
||||
return "wgtest"
|
||||
}
|
||||
return i.NameFunc()
|
||||
}
|
||||
|
||||
func (i *IFaceMock) GetWGDevice() *wgdevice.Device {
|
||||
if i.GetWGDeviceFunc == nil {
|
||||
return nil
|
||||
|
||||
@@ -13,7 +13,6 @@ import (
|
||||
"gvisor.dev/gvisor/pkg/tcpip/header"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/stack"
|
||||
|
||||
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
|
||||
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
||||
)
|
||||
|
||||
@@ -93,10 +92,8 @@ func (f *Forwarder) forwardICMPPacket(id stack.TransportEndpointID, payload []by
|
||||
return nil, fmt.Errorf("write ICMP packet: %w", err)
|
||||
}
|
||||
|
||||
if f.logger.Enabled(nblog.LevelTrace) {
|
||||
f.logger.Trace3("forwarder: Forwarded ICMP packet %v type %v code %v",
|
||||
epID(id), icmpType, icmpCode)
|
||||
}
|
||||
f.logger.Trace3("forwarder: Forwarded ICMP packet %v type %v code %v",
|
||||
epID(id), icmpType, icmpCode)
|
||||
|
||||
return conn, nil
|
||||
}
|
||||
@@ -119,10 +116,8 @@ func (f *Forwarder) handleICMPViaSocket(flowID uuid.UUID, id stack.TransportEndp
|
||||
txBytes := f.handleEchoResponse(conn, id)
|
||||
rtt := time.Since(sendTime).Round(10 * time.Microsecond)
|
||||
|
||||
if f.logger.Enabled(nblog.LevelTrace) {
|
||||
f.logger.Trace4("forwarder: Forwarded ICMP echo reply %v type %v code %v (rtt=%v, raw socket)",
|
||||
epID(id), icmpType, icmpCode, rtt)
|
||||
}
|
||||
f.logger.Trace4("forwarder: Forwarded ICMP echo reply %v type %v code %v (rtt=%v, raw socket)",
|
||||
epID(id), icmpType, icmpCode, rtt)
|
||||
|
||||
f.sendICMPEvent(nftypes.TypeEnd, flowID, id, icmpType, icmpCode, uint64(rxBytes), uint64(txBytes))
|
||||
}
|
||||
@@ -203,17 +198,13 @@ func (f *Forwarder) handleICMPViaPing(flowID uuid.UUID, id stack.TransportEndpoi
|
||||
}
|
||||
rtt := time.Since(pingStart).Round(10 * time.Microsecond)
|
||||
|
||||
if f.logger.Enabled(nblog.LevelTrace) {
|
||||
f.logger.Trace3("forwarder: Forwarded ICMP echo request %v type %v code %v",
|
||||
epID(id), icmpType, icmpCode)
|
||||
}
|
||||
f.logger.Trace3("forwarder: Forwarded ICMP echo request %v type %v code %v",
|
||||
epID(id), icmpType, icmpCode)
|
||||
|
||||
txBytes := f.synthesizeEchoReply(id, icmpData)
|
||||
|
||||
if f.logger.Enabled(nblog.LevelTrace) {
|
||||
f.logger.Trace4("forwarder: Forwarded ICMP echo reply %v type %v code %v (rtt=%v, ping binary)",
|
||||
epID(id), icmpType, icmpCode, rtt)
|
||||
}
|
||||
f.logger.Trace4("forwarder: Forwarded ICMP echo reply %v type %v code %v (rtt=%v, ping binary)",
|
||||
epID(id), icmpType, icmpCode, rtt)
|
||||
|
||||
f.sendICMPEvent(nftypes.TypeEnd, flowID, id, icmpType, icmpCode, uint64(rxBytes), uint64(txBytes))
|
||||
}
|
||||
|
||||
@@ -1,9 +1,12 @@
|
||||
package forwarder
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/netip"
|
||||
"sync"
|
||||
|
||||
"github.com/google/uuid"
|
||||
|
||||
@@ -13,9 +16,7 @@ import (
|
||||
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
|
||||
"gvisor.dev/gvisor/pkg/waiter"
|
||||
|
||||
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
|
||||
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
||||
"github.com/netbirdio/netbird/util/netrelay"
|
||||
)
|
||||
|
||||
// handleTCP is called by the TCP forwarder for new connections.
|
||||
@@ -37,9 +38,7 @@ func (f *Forwarder) handleTCP(r *tcp.ForwarderRequest) {
|
||||
outConn, err := (&net.Dialer{}).DialContext(f.ctx, "tcp", dialAddr)
|
||||
if err != nil {
|
||||
r.Complete(true)
|
||||
if f.logger.Enabled(nblog.LevelTrace) {
|
||||
f.logger.Trace2("forwarder: dial error for %v: %v", epID(id), err)
|
||||
}
|
||||
f.logger.Trace2("forwarder: dial error for %v: %v", epID(id), err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -62,22 +61,64 @@ func (f *Forwarder) handleTCP(r *tcp.ForwarderRequest) {
|
||||
inConn := gonet.NewTCPConn(&wq, ep)
|
||||
|
||||
success = true
|
||||
if f.logger.Enabled(nblog.LevelTrace) {
|
||||
f.logger.Trace1("forwarder: established TCP connection %v", epID(id))
|
||||
}
|
||||
f.logger.Trace1("forwarder: established TCP connection %v", epID(id))
|
||||
|
||||
go f.proxyTCP(id, inConn, outConn, ep, flowID)
|
||||
}
|
||||
|
||||
func (f *Forwarder) proxyTCP(id stack.TransportEndpointID, inConn *gonet.TCPConn, outConn net.Conn, ep tcpip.Endpoint, flowID uuid.UUID) {
|
||||
// netrelay.Relay copies bidirectionally with proper half-close propagation
|
||||
// and fully closes both conns before returning.
|
||||
bytesFromInToOut, bytesFromOutToIn := netrelay.Relay(f.ctx, inConn, outConn, netrelay.Options{
|
||||
Logger: f.logger,
|
||||
})
|
||||
|
||||
// Close the netstack endpoint after both conns are drained.
|
||||
ep.Close()
|
||||
ctx, cancel := context.WithCancel(f.ctx)
|
||||
defer cancel()
|
||||
|
||||
go func() {
|
||||
<-ctx.Done()
|
||||
// Close connections and endpoint.
|
||||
if err := inConn.Close(); err != nil && !isClosedError(err) {
|
||||
f.logger.Debug1("forwarder: inConn close error: %v", err)
|
||||
}
|
||||
if err := outConn.Close(); err != nil && !isClosedError(err) {
|
||||
f.logger.Debug1("forwarder: outConn close error: %v", err)
|
||||
}
|
||||
|
||||
ep.Close()
|
||||
}()
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(2)
|
||||
|
||||
var (
|
||||
bytesFromInToOut int64 // bytes from client to server (tx for client)
|
||||
bytesFromOutToIn int64 // bytes from server to client (rx for client)
|
||||
errInToOut error
|
||||
errOutToIn error
|
||||
)
|
||||
|
||||
go func() {
|
||||
bytesFromInToOut, errInToOut = io.Copy(outConn, inConn)
|
||||
cancel()
|
||||
wg.Done()
|
||||
}()
|
||||
|
||||
go func() {
|
||||
|
||||
bytesFromOutToIn, errOutToIn = io.Copy(inConn, outConn)
|
||||
cancel()
|
||||
wg.Done()
|
||||
}()
|
||||
|
||||
wg.Wait()
|
||||
|
||||
if errInToOut != nil {
|
||||
if !isClosedError(errInToOut) {
|
||||
f.logger.Error2("proxyTCP: copy error (in → out) for %s: %v", epID(id), errInToOut)
|
||||
}
|
||||
}
|
||||
if errOutToIn != nil {
|
||||
if !isClosedError(errOutToIn) {
|
||||
f.logger.Error2("proxyTCP: copy error (out → in) for %s: %v", epID(id), errOutToIn)
|
||||
}
|
||||
}
|
||||
|
||||
var rxPackets, txPackets uint64
|
||||
if tcpStats, ok := ep.Stats().(*tcp.Stats); ok {
|
||||
@@ -86,9 +127,7 @@ func (f *Forwarder) proxyTCP(id stack.TransportEndpointID, inConn *gonet.TCPConn
|
||||
txPackets = tcpStats.SegmentsReceived.Value()
|
||||
}
|
||||
|
||||
if f.logger.Enabled(nblog.LevelTrace) {
|
||||
f.logger.Trace5("forwarder: Removed TCP connection %s [in: %d Pkts/%d B, out: %d Pkts/%d B]", epID(id), rxPackets, bytesFromOutToIn, txPackets, bytesFromInToOut)
|
||||
}
|
||||
f.logger.Trace5("forwarder: Removed TCP connection %s [in: %d Pkts/%d B, out: %d Pkts/%d B]", epID(id), rxPackets, bytesFromOutToIn, txPackets, bytesFromInToOut)
|
||||
|
||||
f.sendTCPEvent(nftypes.TypeEnd, flowID, id, uint64(bytesFromOutToIn), uint64(bytesFromInToOut), rxPackets, txPackets)
|
||||
}
|
||||
|
||||
@@ -125,9 +125,7 @@ func (f *udpForwarder) cleanup() {
|
||||
delete(f.conns, idle.id)
|
||||
f.Unlock()
|
||||
|
||||
if f.logger.Enabled(nblog.LevelTrace) {
|
||||
f.logger.Trace1("forwarder: cleaned up idle UDP connection %v", epID(idle.id))
|
||||
}
|
||||
f.logger.Trace1("forwarder: cleaned up idle UDP connection %v", epID(idle.id))
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -146,9 +144,7 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) bool {
|
||||
_, exists := f.udpForwarder.conns[id]
|
||||
f.udpForwarder.RUnlock()
|
||||
if exists {
|
||||
if f.logger.Enabled(nblog.LevelTrace) {
|
||||
f.logger.Trace1("forwarder: existing UDP connection for %v", epID(id))
|
||||
}
|
||||
f.logger.Trace1("forwarder: existing UDP connection for %v", epID(id))
|
||||
return true
|
||||
}
|
||||
|
||||
@@ -210,9 +206,7 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) bool {
|
||||
f.udpForwarder.Unlock()
|
||||
|
||||
success = true
|
||||
if f.logger.Enabled(nblog.LevelTrace) {
|
||||
f.logger.Trace1("forwarder: established UDP connection %v", epID(id))
|
||||
}
|
||||
f.logger.Trace1("forwarder: established UDP connection %v", epID(id))
|
||||
|
||||
go f.proxyUDP(connCtx, pConn, id, ep)
|
||||
return true
|
||||
@@ -271,9 +265,7 @@ func (f *Forwarder) proxyUDP(ctx context.Context, pConn *udpPacketConn, id stack
|
||||
txPackets = udpStats.PacketsReceived.Value()
|
||||
}
|
||||
|
||||
if f.logger.Enabled(nblog.LevelTrace) {
|
||||
f.logger.Trace5("forwarder: Removed UDP connection %s [in: %d Pkts/%d B, out: %d Pkts/%d B]", epID(id), rxPackets, rxBytes, txPackets, txBytes)
|
||||
}
|
||||
f.logger.Trace5("forwarder: Removed UDP connection %s [in: %d Pkts/%d B, out: %d Pkts/%d B]", epID(id), rxPackets, rxBytes, txPackets, txBytes)
|
||||
|
||||
f.udpForwarder.Lock()
|
||||
delete(f.udpForwarder.conns, id)
|
||||
|
||||
@@ -53,17 +53,16 @@ var levelStrings = map[Level]string{
|
||||
}
|
||||
|
||||
type logMessage struct {
|
||||
level Level
|
||||
argCount uint8
|
||||
format string
|
||||
arg1 any
|
||||
arg2 any
|
||||
arg3 any
|
||||
arg4 any
|
||||
arg5 any
|
||||
arg6 any
|
||||
arg7 any
|
||||
arg8 any
|
||||
level Level
|
||||
format string
|
||||
arg1 any
|
||||
arg2 any
|
||||
arg3 any
|
||||
arg4 any
|
||||
arg5 any
|
||||
arg6 any
|
||||
arg7 any
|
||||
arg8 any
|
||||
}
|
||||
|
||||
// Logger is a high-performance, non-blocking logger
|
||||
@@ -108,13 +107,6 @@ func (l *Logger) SetLevel(level Level) {
|
||||
log.Debugf("Set uspfilter logger loglevel to %v", levelStrings[level])
|
||||
}
|
||||
|
||||
// Enabled reports whether the given level is currently logged. Callers on the
|
||||
// hot path should guard log sites with this to avoid boxing arguments into
|
||||
// any when the level is off.
|
||||
func (l *Logger) Enabled(level Level) bool {
|
||||
return l.level.Load() >= uint32(level)
|
||||
}
|
||||
|
||||
func (l *Logger) Error(format string) {
|
||||
if l.level.Load() >= uint32(LevelError) {
|
||||
select {
|
||||
@@ -163,7 +155,7 @@ func (l *Logger) Trace(format string) {
|
||||
func (l *Logger) Error1(format string, arg1 any) {
|
||||
if l.level.Load() >= uint32(LevelError) {
|
||||
select {
|
||||
case l.msgChannel <- logMessage{level: LevelError, argCount: 1, format: format, arg1: arg1}:
|
||||
case l.msgChannel <- logMessage{level: LevelError, format: format, arg1: arg1}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
@@ -172,16 +164,7 @@ func (l *Logger) Error1(format string, arg1 any) {
|
||||
func (l *Logger) Error2(format string, arg1, arg2 any) {
|
||||
if l.level.Load() >= uint32(LevelError) {
|
||||
select {
|
||||
case l.msgChannel <- logMessage{level: LevelError, argCount: 2, format: format, arg1: arg1, arg2: arg2}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (l *Logger) Warn2(format string, arg1, arg2 any) {
|
||||
if l.level.Load() >= uint32(LevelWarn) {
|
||||
select {
|
||||
case l.msgChannel <- logMessage{level: LevelWarn, argCount: 2, format: format, arg1: arg1, arg2: arg2}:
|
||||
case l.msgChannel <- logMessage{level: LevelError, format: format, arg1: arg1, arg2: arg2}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
@@ -190,7 +173,7 @@ func (l *Logger) Warn2(format string, arg1, arg2 any) {
|
||||
func (l *Logger) Warn3(format string, arg1, arg2, arg3 any) {
|
||||
if l.level.Load() >= uint32(LevelWarn) {
|
||||
select {
|
||||
case l.msgChannel <- logMessage{level: LevelWarn, argCount: 3, format: format, arg1: arg1, arg2: arg2, arg3: arg3}:
|
||||
case l.msgChannel <- logMessage{level: LevelWarn, format: format, arg1: arg1, arg2: arg2, arg3: arg3}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
@@ -199,7 +182,7 @@ func (l *Logger) Warn3(format string, arg1, arg2, arg3 any) {
|
||||
func (l *Logger) Warn4(format string, arg1, arg2, arg3, arg4 any) {
|
||||
if l.level.Load() >= uint32(LevelWarn) {
|
||||
select {
|
||||
case l.msgChannel <- logMessage{level: LevelWarn, argCount: 4, format: format, arg1: arg1, arg2: arg2, arg3: arg3, arg4: arg4}:
|
||||
case l.msgChannel <- logMessage{level: LevelWarn, format: format, arg1: arg1, arg2: arg2, arg3: arg3, arg4: arg4}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
@@ -208,7 +191,7 @@ func (l *Logger) Warn4(format string, arg1, arg2, arg3, arg4 any) {
|
||||
func (l *Logger) Debug1(format string, arg1 any) {
|
||||
if l.level.Load() >= uint32(LevelDebug) {
|
||||
select {
|
||||
case l.msgChannel <- logMessage{level: LevelDebug, argCount: 1, format: format, arg1: arg1}:
|
||||
case l.msgChannel <- logMessage{level: LevelDebug, format: format, arg1: arg1}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
@@ -217,7 +200,7 @@ func (l *Logger) Debug1(format string, arg1 any) {
|
||||
func (l *Logger) Debug2(format string, arg1, arg2 any) {
|
||||
if l.level.Load() >= uint32(LevelDebug) {
|
||||
select {
|
||||
case l.msgChannel <- logMessage{level: LevelDebug, argCount: 2, format: format, arg1: arg1, arg2: arg2}:
|
||||
case l.msgChannel <- logMessage{level: LevelDebug, format: format, arg1: arg1, arg2: arg2}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
@@ -226,59 +209,16 @@ func (l *Logger) Debug2(format string, arg1, arg2 any) {
|
||||
func (l *Logger) Debug3(format string, arg1, arg2, arg3 any) {
|
||||
if l.level.Load() >= uint32(LevelDebug) {
|
||||
select {
|
||||
case l.msgChannel <- logMessage{level: LevelDebug, argCount: 3, format: format, arg1: arg1, arg2: arg2, arg3: arg3}:
|
||||
case l.msgChannel <- logMessage{level: LevelDebug, format: format, arg1: arg1, arg2: arg2, arg3: arg3}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Debugf is the variadic shape. Dispatches to Debug/Debug1/Debug2/Debug3
|
||||
// to avoid allocating an args slice on the fast path when the arg count is
|
||||
// known (0-3). Args beyond 3 land on the general variadic path; callers on
|
||||
// the hot path should prefer DebugN for known counts.
|
||||
func (l *Logger) Debugf(format string, args ...any) {
|
||||
if l.level.Load() < uint32(LevelDebug) {
|
||||
return
|
||||
}
|
||||
switch len(args) {
|
||||
case 0:
|
||||
l.Debug(format)
|
||||
case 1:
|
||||
l.Debug1(format, args[0])
|
||||
case 2:
|
||||
l.Debug2(format, args[0], args[1])
|
||||
case 3:
|
||||
l.Debug3(format, args[0], args[1], args[2])
|
||||
default:
|
||||
l.sendVariadic(LevelDebug, format, args)
|
||||
}
|
||||
}
|
||||
|
||||
// sendVariadic packs a slice of arguments into a logMessage and non-blocking
|
||||
// enqueues it. Used for arg counts beyond the fixed-arity fast paths. Args
|
||||
// beyond the 8-arg slot limit are dropped so callers don't produce silently
|
||||
// empty log lines via uint8 wraparound in argCount.
|
||||
func (l *Logger) sendVariadic(level Level, format string, args []any) {
|
||||
const maxArgs = 8
|
||||
n := len(args)
|
||||
if n > maxArgs {
|
||||
n = maxArgs
|
||||
}
|
||||
msg := logMessage{level: level, argCount: uint8(n), format: format}
|
||||
slots := [maxArgs]*any{&msg.arg1, &msg.arg2, &msg.arg3, &msg.arg4, &msg.arg5, &msg.arg6, &msg.arg7, &msg.arg8}
|
||||
for i := 0; i < n; i++ {
|
||||
*slots[i] = args[i]
|
||||
}
|
||||
select {
|
||||
case l.msgChannel <- msg:
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
func (l *Logger) Trace1(format string, arg1 any) {
|
||||
if l.level.Load() >= uint32(LevelTrace) {
|
||||
select {
|
||||
case l.msgChannel <- logMessage{level: LevelTrace, argCount: 1, format: format, arg1: arg1}:
|
||||
case l.msgChannel <- logMessage{level: LevelTrace, format: format, arg1: arg1}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
@@ -287,7 +227,7 @@ func (l *Logger) Trace1(format string, arg1 any) {
|
||||
func (l *Logger) Trace2(format string, arg1, arg2 any) {
|
||||
if l.level.Load() >= uint32(LevelTrace) {
|
||||
select {
|
||||
case l.msgChannel <- logMessage{level: LevelTrace, argCount: 2, format: format, arg1: arg1, arg2: arg2}:
|
||||
case l.msgChannel <- logMessage{level: LevelTrace, format: format, arg1: arg1, arg2: arg2}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
@@ -296,7 +236,7 @@ func (l *Logger) Trace2(format string, arg1, arg2 any) {
|
||||
func (l *Logger) Trace3(format string, arg1, arg2, arg3 any) {
|
||||
if l.level.Load() >= uint32(LevelTrace) {
|
||||
select {
|
||||
case l.msgChannel <- logMessage{level: LevelTrace, argCount: 3, format: format, arg1: arg1, arg2: arg2, arg3: arg3}:
|
||||
case l.msgChannel <- logMessage{level: LevelTrace, format: format, arg1: arg1, arg2: arg2, arg3: arg3}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
@@ -305,7 +245,7 @@ func (l *Logger) Trace3(format string, arg1, arg2, arg3 any) {
|
||||
func (l *Logger) Trace4(format string, arg1, arg2, arg3, arg4 any) {
|
||||
if l.level.Load() >= uint32(LevelTrace) {
|
||||
select {
|
||||
case l.msgChannel <- logMessage{level: LevelTrace, argCount: 4, format: format, arg1: arg1, arg2: arg2, arg3: arg3, arg4: arg4}:
|
||||
case l.msgChannel <- logMessage{level: LevelTrace, format: format, arg1: arg1, arg2: arg2, arg3: arg3, arg4: arg4}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
@@ -314,7 +254,7 @@ func (l *Logger) Trace4(format string, arg1, arg2, arg3, arg4 any) {
|
||||
func (l *Logger) Trace5(format string, arg1, arg2, arg3, arg4, arg5 any) {
|
||||
if l.level.Load() >= uint32(LevelTrace) {
|
||||
select {
|
||||
case l.msgChannel <- logMessage{level: LevelTrace, argCount: 5, format: format, arg1: arg1, arg2: arg2, arg3: arg3, arg4: arg4, arg5: arg5}:
|
||||
case l.msgChannel <- logMessage{level: LevelTrace, format: format, arg1: arg1, arg2: arg2, arg3: arg3, arg4: arg4, arg5: arg5}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
@@ -323,7 +263,7 @@ func (l *Logger) Trace5(format string, arg1, arg2, arg3, arg4, arg5 any) {
|
||||
func (l *Logger) Trace6(format string, arg1, arg2, arg3, arg4, arg5, arg6 any) {
|
||||
if l.level.Load() >= uint32(LevelTrace) {
|
||||
select {
|
||||
case l.msgChannel <- logMessage{level: LevelTrace, argCount: 6, format: format, arg1: arg1, arg2: arg2, arg3: arg3, arg4: arg4, arg5: arg5, arg6: arg6}:
|
||||
case l.msgChannel <- logMessage{level: LevelTrace, format: format, arg1: arg1, arg2: arg2, arg3: arg3, arg4: arg4, arg5: arg5, arg6: arg6}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
@@ -333,7 +273,7 @@ func (l *Logger) Trace6(format string, arg1, arg2, arg3, arg4, arg5, arg6 any) {
|
||||
func (l *Logger) Trace8(format string, arg1, arg2, arg3, arg4, arg5, arg6, arg7, arg8 any) {
|
||||
if l.level.Load() >= uint32(LevelTrace) {
|
||||
select {
|
||||
case l.msgChannel <- logMessage{level: LevelTrace, argCount: 8, format: format, arg1: arg1, arg2: arg2, arg3: arg3, arg4: arg4, arg5: arg5, arg6: arg6, arg7: arg7, arg8: arg8}:
|
||||
case l.msgChannel <- logMessage{level: LevelTrace, format: format, arg1: arg1, arg2: arg2, arg3: arg3, arg4: arg4, arg5: arg5, arg6: arg6, arg7: arg7, arg8: arg8}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
@@ -346,8 +286,35 @@ func (l *Logger) formatMessage(buf *[]byte, msg logMessage) {
|
||||
*buf = append(*buf, levelStrings[msg.level]...)
|
||||
*buf = append(*buf, ' ')
|
||||
|
||||
// Count non-nil arguments for switch
|
||||
argCount := 0
|
||||
if msg.arg1 != nil {
|
||||
argCount++
|
||||
if msg.arg2 != nil {
|
||||
argCount++
|
||||
if msg.arg3 != nil {
|
||||
argCount++
|
||||
if msg.arg4 != nil {
|
||||
argCount++
|
||||
if msg.arg5 != nil {
|
||||
argCount++
|
||||
if msg.arg6 != nil {
|
||||
argCount++
|
||||
if msg.arg7 != nil {
|
||||
argCount++
|
||||
if msg.arg8 != nil {
|
||||
argCount++
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var formatted string
|
||||
switch msg.argCount {
|
||||
switch argCount {
|
||||
case 0:
|
||||
formatted = msg.format
|
||||
case 1:
|
||||
|
||||
@@ -11,7 +11,6 @@ import (
|
||||
"github.com/google/gopacket/layers"
|
||||
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
|
||||
)
|
||||
|
||||
var ErrIPv4Only = errors.New("only IPv4 is supported for DNAT")
|
||||
@@ -243,15 +242,11 @@ func (m *Manager) translateOutboundDNAT(packetData []byte, d *decoder) bool {
|
||||
}
|
||||
|
||||
if err := m.rewritePacketIP(packetData, d, translatedIP, destinationIPOffset); err != nil {
|
||||
if m.logger.Enabled(nblog.LevelError) {
|
||||
m.logger.Error1("failed to rewrite packet destination: %v", err)
|
||||
}
|
||||
m.logger.Error1("failed to rewrite packet destination: %v", err)
|
||||
return false
|
||||
}
|
||||
|
||||
if m.logger.Enabled(nblog.LevelTrace) {
|
||||
m.logger.Trace2("DNAT: %s -> %s", dstIP, translatedIP)
|
||||
}
|
||||
m.logger.Trace2("DNAT: %s -> %s", dstIP, translatedIP)
|
||||
return true
|
||||
}
|
||||
|
||||
@@ -269,15 +264,11 @@ func (m *Manager) translateInboundReverse(packetData []byte, d *decoder) bool {
|
||||
}
|
||||
|
||||
if err := m.rewritePacketIP(packetData, d, originalIP, sourceIPOffset); err != nil {
|
||||
if m.logger.Enabled(nblog.LevelError) {
|
||||
m.logger.Error1("failed to rewrite packet source: %v", err)
|
||||
}
|
||||
m.logger.Error1("failed to rewrite packet source: %v", err)
|
||||
return false
|
||||
}
|
||||
|
||||
if m.logger.Enabled(nblog.LevelTrace) {
|
||||
m.logger.Trace2("Reverse DNAT: %s -> %s", srcIP, originalIP)
|
||||
}
|
||||
m.logger.Trace2("Reverse DNAT: %s -> %s", srcIP, originalIP)
|
||||
return true
|
||||
}
|
||||
|
||||
@@ -530,9 +521,7 @@ func (m *Manager) applyPortRule(packetData []byte, d *decoder, srcIP, dstIP neti
|
||||
}
|
||||
|
||||
if err := rewriteFn(packetData, d, rule.targetPort, destinationPortOffset); err != nil {
|
||||
if m.logger.Enabled(nblog.LevelError) {
|
||||
m.logger.Error1("failed to rewrite port: %v", err)
|
||||
}
|
||||
m.logger.Error1("failed to rewrite port: %v", err)
|
||||
return false
|
||||
}
|
||||
d.dnatOrigPort = rule.origPort
|
||||
|
||||
@@ -239,8 +239,12 @@ func TestICEBind_HandlesConcurrentMixedTraffic(t *testing.T) {
|
||||
ipv6Count++
|
||||
}
|
||||
|
||||
assert.Equal(t, packetsPerFamily, ipv4Count)
|
||||
assert.Equal(t, packetsPerFamily, ipv6Count)
|
||||
// Allow some UDP packet loss under load (e.g. FreeBSD/QEMU runners). The
|
||||
// routing-correctness checks above are the real assertions; the counts
|
||||
// are a sanity bound to catch a totally silent path.
|
||||
minDelivered := packetsPerFamily * 80 / 100
|
||||
assert.GreaterOrEqual(t, ipv4Count, minDelivered, "IPv4 delivery below threshold")
|
||||
assert.GreaterOrEqual(t, ipv6Count, minDelivered, "IPv6 delivery below threshold")
|
||||
}
|
||||
|
||||
func TestICEBind_DetectsAddressFamilyFromConnection(t *testing.T) {
|
||||
|
||||
@@ -3,10 +3,12 @@ package debug
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
@@ -19,8 +21,10 @@ func TestUpload(t *testing.T) {
|
||||
t.Skip("Skipping upload test on docker ci")
|
||||
}
|
||||
testDir := t.TempDir()
|
||||
testURL := "http://localhost:8080"
|
||||
addr := reserveLoopbackPort(t)
|
||||
testURL := "http://" + addr
|
||||
t.Setenv("SERVER_URL", testURL)
|
||||
t.Setenv("SERVER_ADDRESS", addr)
|
||||
t.Setenv("STORE_DIR", testDir)
|
||||
srv := server.NewServer()
|
||||
go func() {
|
||||
@@ -33,6 +37,7 @@ func TestUpload(t *testing.T) {
|
||||
t.Errorf("Failed to stop server: %v", err)
|
||||
}
|
||||
})
|
||||
waitForServer(t, addr)
|
||||
|
||||
file := filepath.Join(t.TempDir(), "tmpfile")
|
||||
fileContent := []byte("test file content")
|
||||
@@ -47,3 +52,30 @@ func TestUpload(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, fileContent, createdFileContent)
|
||||
}
|
||||
|
||||
// reserveLoopbackPort binds an ephemeral port on loopback to learn a free
|
||||
// address, then releases it so the server under test can rebind. The close/
|
||||
// rebind window is racy in theory; on loopback with a kernel-assigned port
|
||||
// it's essentially never contended in practice.
|
||||
func reserveLoopbackPort(t *testing.T) string {
|
||||
t.Helper()
|
||||
l, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
addr := l.Addr().String()
|
||||
require.NoError(t, l.Close())
|
||||
return addr
|
||||
}
|
||||
|
||||
func waitForServer(t *testing.T, addr string) {
|
||||
t.Helper()
|
||||
deadline := time.Now().Add(5 * time.Second)
|
||||
for time.Now().Before(deadline) {
|
||||
c, err := net.DialTimeout("tcp", addr, 100*time.Millisecond)
|
||||
if err == nil {
|
||||
_ = c.Close()
|
||||
return
|
||||
}
|
||||
time.Sleep(20 * time.Millisecond)
|
||||
}
|
||||
t.Fatalf("server did not start listening on %s in time", addr)
|
||||
}
|
||||
|
||||
@@ -13,6 +13,7 @@ import (
|
||||
|
||||
const (
|
||||
defaultResolvConfPath = "/etc/resolv.conf"
|
||||
nsswitchConfPath = "/etc/nsswitch.conf"
|
||||
)
|
||||
|
||||
type resolvConf struct {
|
||||
|
||||
@@ -1,7 +1,10 @@
|
||||
package dns
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"math"
|
||||
"net"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
@@ -192,6 +195,12 @@ func (c *HandlerChain) logHandlers() {
|
||||
}
|
||||
|
||||
func (c *HandlerChain) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||
c.dispatch(w, r, math.MaxInt)
|
||||
}
|
||||
|
||||
// dispatch routes a DNS request through the chain, skipping handlers with
|
||||
// priority > maxPriority. Shared by ServeDNS and ResolveInternal.
|
||||
func (c *HandlerChain) dispatch(w dns.ResponseWriter, r *dns.Msg, maxPriority int) {
|
||||
if len(r.Question) == 0 {
|
||||
return
|
||||
}
|
||||
@@ -216,6 +225,9 @@ func (c *HandlerChain) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||
|
||||
// Try handlers in priority order
|
||||
for _, entry := range handlers {
|
||||
if entry.Priority > maxPriority {
|
||||
continue
|
||||
}
|
||||
if !c.isHandlerMatch(qname, entry) {
|
||||
continue
|
||||
}
|
||||
@@ -273,6 +285,55 @@ func (c *HandlerChain) logResponse(logger *log.Entry, cw *ResponseWriterChain, q
|
||||
cw.response.Len(), meta, time.Since(startTime))
|
||||
}
|
||||
|
||||
// ResolveInternal runs an in-process DNS query against the chain, skipping any
|
||||
// handler with priority > maxPriority. Used by internal callers (e.g. the mgmt
|
||||
// cache refresher) that must bypass themselves to avoid loops. Honors ctx
|
||||
// cancellation; on ctx.Done the dispatch goroutine is left to drain on its own
|
||||
// (bounded by the invoked handler's internal timeout).
|
||||
func (c *HandlerChain) ResolveInternal(ctx context.Context, r *dns.Msg, maxPriority int) (*dns.Msg, error) {
|
||||
if len(r.Question) == 0 {
|
||||
return nil, fmt.Errorf("empty question")
|
||||
}
|
||||
|
||||
base := &internalResponseWriter{}
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
c.dispatch(base, r, maxPriority)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-ctx.Done():
|
||||
// Prefer a completed response if dispatch finished concurrently with cancellation.
|
||||
select {
|
||||
case <-done:
|
||||
default:
|
||||
return nil, fmt.Errorf("resolve %s: %w", strings.ToLower(r.Question[0].Name), ctx.Err())
|
||||
}
|
||||
}
|
||||
|
||||
if base.response == nil || base.response.Rcode == dns.RcodeRefused {
|
||||
return nil, fmt.Errorf("no handler resolved %s at priority ≤ %d",
|
||||
strings.ToLower(r.Question[0].Name), maxPriority)
|
||||
}
|
||||
return base.response, nil
|
||||
}
|
||||
|
||||
// HasRootHandlerAtOrBelow reports whether any "." handler is registered at
|
||||
// priority ≤ maxPriority.
|
||||
func (c *HandlerChain) HasRootHandlerAtOrBelow(maxPriority int) bool {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
|
||||
for _, h := range c.handlers {
|
||||
if h.Pattern == "." && h.Priority <= maxPriority {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (c *HandlerChain) isHandlerMatch(qname string, entry HandlerEntry) bool {
|
||||
switch {
|
||||
case entry.Pattern == ".":
|
||||
@@ -291,3 +352,36 @@ func (c *HandlerChain) isHandlerMatch(qname string, entry HandlerEntry) bool {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// internalResponseWriter captures a dns.Msg for in-process chain queries.
|
||||
type internalResponseWriter struct {
|
||||
response *dns.Msg
|
||||
}
|
||||
|
||||
func (w *internalResponseWriter) WriteMsg(m *dns.Msg) error { w.response = m; return nil }
|
||||
func (w *internalResponseWriter) LocalAddr() net.Addr { return nil }
|
||||
func (w *internalResponseWriter) RemoteAddr() net.Addr { return nil }
|
||||
|
||||
// Write unpacks raw DNS bytes so handlers that call Write instead of WriteMsg
|
||||
// still surface their answer to ResolveInternal.
|
||||
func (w *internalResponseWriter) Write(p []byte) (int, error) {
|
||||
msg := new(dns.Msg)
|
||||
if err := msg.Unpack(p); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
w.response = msg
|
||||
return len(p), nil
|
||||
}
|
||||
|
||||
func (w *internalResponseWriter) Close() error { return nil }
|
||||
func (w *internalResponseWriter) TsigStatus() error { return nil }
|
||||
|
||||
// TsigTimersOnly is part of dns.ResponseWriter.
|
||||
func (w *internalResponseWriter) TsigTimersOnly(bool) {
|
||||
// no-op: in-process queries carry no TSIG state.
|
||||
}
|
||||
|
||||
// Hijack is part of dns.ResponseWriter.
|
||||
func (w *internalResponseWriter) Hijack() {
|
||||
// no-op: in-process queries have no underlying connection to hand off.
|
||||
}
|
||||
|
||||
@@ -1,11 +1,15 @@
|
||||
package dns_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/client/internal/dns"
|
||||
"github.com/netbirdio/netbird/client/internal/dns/test"
|
||||
@@ -1042,3 +1046,163 @@ func TestHandlerChain_AddRemoveRoundtrip(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// answeringHandler writes a fixed A record to ack the query. Used to verify
|
||||
// which handler ResolveInternal dispatches to.
|
||||
type answeringHandler struct {
|
||||
name string
|
||||
ip string
|
||||
}
|
||||
|
||||
func (h *answeringHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||
resp := &dns.Msg{}
|
||||
resp.SetReply(r)
|
||||
resp.Answer = []dns.RR{&dns.A{
|
||||
Hdr: dns.RR_Header{Name: r.Question[0].Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
|
||||
A: net.ParseIP(h.ip).To4(),
|
||||
}}
|
||||
_ = w.WriteMsg(resp)
|
||||
}
|
||||
|
||||
func (h *answeringHandler) String() string { return h.name }
|
||||
|
||||
func TestHandlerChain_ResolveInternal_SkipsAboveMaxPriority(t *testing.T) {
|
||||
chain := nbdns.NewHandlerChain()
|
||||
|
||||
high := &answeringHandler{name: "high", ip: "10.0.0.1"}
|
||||
low := &answeringHandler{name: "low", ip: "10.0.0.2"}
|
||||
|
||||
chain.AddHandler("example.com.", high, nbdns.PriorityMgmtCache)
|
||||
chain.AddHandler("example.com.", low, nbdns.PriorityUpstream)
|
||||
|
||||
r := new(dns.Msg)
|
||||
r.SetQuestion("example.com.", dns.TypeA)
|
||||
|
||||
resp, err := chain.ResolveInternal(context.Background(), r, nbdns.PriorityUpstream)
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, resp)
|
||||
assert.Equal(t, 1, len(resp.Answer))
|
||||
a, ok := resp.Answer[0].(*dns.A)
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, "10.0.0.2", a.A.String(), "should skip mgmtCache handler and resolve via upstream")
|
||||
}
|
||||
|
||||
func TestHandlerChain_ResolveInternal_ErrorWhenNoMatch(t *testing.T) {
|
||||
chain := nbdns.NewHandlerChain()
|
||||
high := &answeringHandler{name: "high", ip: "10.0.0.1"}
|
||||
chain.AddHandler("example.com.", high, nbdns.PriorityMgmtCache)
|
||||
|
||||
r := new(dns.Msg)
|
||||
r.SetQuestion("example.com.", dns.TypeA)
|
||||
|
||||
_, err := chain.ResolveInternal(context.Background(), r, nbdns.PriorityUpstream)
|
||||
assert.Error(t, err, "no handler at or below maxPriority should error")
|
||||
}
|
||||
|
||||
// rawWriteHandler packs a response and calls ResponseWriter.Write directly
|
||||
// (instead of WriteMsg), exercising the internalResponseWriter.Write path.
|
||||
type rawWriteHandler struct {
|
||||
ip string
|
||||
}
|
||||
|
||||
func (h *rawWriteHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||
resp := &dns.Msg{}
|
||||
resp.SetReply(r)
|
||||
resp.Answer = []dns.RR{&dns.A{
|
||||
Hdr: dns.RR_Header{Name: r.Question[0].Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
|
||||
A: net.ParseIP(h.ip).To4(),
|
||||
}}
|
||||
packed, err := resp.Pack()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
_, _ = w.Write(packed)
|
||||
}
|
||||
|
||||
func TestHandlerChain_ResolveInternal_CapturesRawWrite(t *testing.T) {
|
||||
chain := nbdns.NewHandlerChain()
|
||||
chain.AddHandler("example.com.", &rawWriteHandler{ip: "10.0.0.3"}, nbdns.PriorityUpstream)
|
||||
|
||||
r := new(dns.Msg)
|
||||
r.SetQuestion("example.com.", dns.TypeA)
|
||||
|
||||
resp, err := chain.ResolveInternal(context.Background(), r, nbdns.PriorityUpstream)
|
||||
assert.NoError(t, err)
|
||||
require.NotNil(t, resp)
|
||||
require.Len(t, resp.Answer, 1)
|
||||
a, ok := resp.Answer[0].(*dns.A)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, "10.0.0.3", a.A.String(), "handlers calling Write(packed) must still surface their answer")
|
||||
}
|
||||
|
||||
func TestHandlerChain_ResolveInternal_EmptyQuestion(t *testing.T) {
|
||||
chain := nbdns.NewHandlerChain()
|
||||
_, err := chain.ResolveInternal(context.Background(), new(dns.Msg), nbdns.PriorityUpstream)
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
// hangingHandler blocks indefinitely until closed, simulating a wedged upstream.
|
||||
type hangingHandler struct {
|
||||
block chan struct{}
|
||||
}
|
||||
|
||||
func (h *hangingHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||
<-h.block
|
||||
resp := &dns.Msg{}
|
||||
resp.SetReply(r)
|
||||
_ = w.WriteMsg(resp)
|
||||
}
|
||||
|
||||
func (h *hangingHandler) String() string { return "hangingHandler" }
|
||||
|
||||
func TestHandlerChain_ResolveInternal_HonorsContextTimeout(t *testing.T) {
|
||||
chain := nbdns.NewHandlerChain()
|
||||
h := &hangingHandler{block: make(chan struct{})}
|
||||
defer close(h.block)
|
||||
|
||||
chain.AddHandler("example.com.", h, nbdns.PriorityUpstream)
|
||||
|
||||
r := new(dns.Msg)
|
||||
r.SetQuestion("example.com.", dns.TypeA)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
start := time.Now()
|
||||
_, err := chain.ResolveInternal(ctx, r, nbdns.PriorityUpstream)
|
||||
elapsed := time.Since(start)
|
||||
|
||||
assert.Error(t, err)
|
||||
assert.ErrorIs(t, err, context.DeadlineExceeded)
|
||||
assert.Less(t, elapsed, 500*time.Millisecond, "ResolveInternal must return shortly after ctx deadline")
|
||||
}
|
||||
|
||||
func TestHandlerChain_HasRootHandlerAtOrBelow(t *testing.T) {
|
||||
chain := nbdns.NewHandlerChain()
|
||||
h := &answeringHandler{name: "h", ip: "10.0.0.1"}
|
||||
|
||||
assert.False(t, chain.HasRootHandlerAtOrBelow(nbdns.PriorityUpstream), "empty chain")
|
||||
|
||||
chain.AddHandler("example.com.", h, nbdns.PriorityUpstream)
|
||||
assert.False(t, chain.HasRootHandlerAtOrBelow(nbdns.PriorityUpstream), "non-root handler does not count")
|
||||
|
||||
chain.AddHandler(".", h, nbdns.PriorityMgmtCache)
|
||||
assert.False(t, chain.HasRootHandlerAtOrBelow(nbdns.PriorityUpstream), "root handler above threshold excluded")
|
||||
|
||||
chain.AddHandler(".", h, nbdns.PriorityDefault)
|
||||
assert.True(t, chain.HasRootHandlerAtOrBelow(nbdns.PriorityUpstream), "root handler at PriorityDefault included")
|
||||
|
||||
chain.RemoveHandler(".", nbdns.PriorityDefault)
|
||||
assert.False(t, chain.HasRootHandlerAtOrBelow(nbdns.PriorityUpstream))
|
||||
|
||||
// Primary nsgroup case: root handler lands at PriorityUpstream.
|
||||
chain.AddHandler(".", h, nbdns.PriorityUpstream)
|
||||
assert.True(t, chain.HasRootHandlerAtOrBelow(nbdns.PriorityUpstream), "root at PriorityUpstream included")
|
||||
chain.RemoveHandler(".", nbdns.PriorityUpstream)
|
||||
|
||||
// Fallback case: original /etc/resolv.conf entries land at PriorityFallback.
|
||||
chain.AddHandler(".", h, nbdns.PriorityFallback)
|
||||
assert.True(t, chain.HasRootHandlerAtOrBelow(nbdns.PriorityUpstream), "root at PriorityFallback included")
|
||||
chain.RemoveHandler(".", nbdns.PriorityFallback)
|
||||
assert.False(t, chain.HasRootHandlerAtOrBelow(nbdns.PriorityUpstream))
|
||||
}
|
||||
|
||||
@@ -46,12 +46,12 @@ type restoreHostManager interface {
|
||||
}
|
||||
|
||||
func newHostManager(wgInterface string) (hostManager, error) {
|
||||
osManager, err := getOSDNSManagerType()
|
||||
osManager, reason, err := getOSDNSManagerType()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get os dns manager type: %w", err)
|
||||
}
|
||||
|
||||
log.Infof("System DNS manager discovered: %s", osManager)
|
||||
log.Infof("System DNS manager discovered: %s (%s)", osManager, reason)
|
||||
mgr, err := newHostManagerFromType(wgInterface, osManager)
|
||||
// need to explicitly return nil mgr on error to avoid returning a non-nil interface containing a nil value
|
||||
if err != nil {
|
||||
@@ -74,17 +74,49 @@ func newHostManagerFromType(wgInterface string, osManager osManagerType) (restor
|
||||
}
|
||||
}
|
||||
|
||||
func getOSDNSManagerType() (osManagerType, error) {
|
||||
func getOSDNSManagerType() (osManagerType, string, error) {
|
||||
resolved := isSystemdResolvedRunning()
|
||||
nss := isLibnssResolveUsed()
|
||||
stub := checkStub()
|
||||
|
||||
// Prefer systemd-resolved whenever it owns libc resolution, regardless of
|
||||
// who wrote /etc/resolv.conf. File-mode rewrites do not affect lookups
|
||||
// that go through nss-resolve, and in foreign mode they can loop back
|
||||
// through resolved as an upstream.
|
||||
if resolved && (nss || stub) {
|
||||
return systemdManager, fmt.Sprintf("systemd-resolved active (nss-resolve=%t, stub=%t)", nss, stub), nil
|
||||
}
|
||||
|
||||
mgr, reason, rejected, err := scanResolvConfHeader()
|
||||
if err != nil {
|
||||
return 0, "", err
|
||||
}
|
||||
if reason != "" {
|
||||
return mgr, reason, nil
|
||||
}
|
||||
|
||||
fallback := fmt.Sprintf("no manager matched (resolved=%t, nss-resolve=%t, stub=%t)", resolved, nss, stub)
|
||||
if len(rejected) > 0 {
|
||||
fallback += "; rejected: " + strings.Join(rejected, ", ")
|
||||
}
|
||||
return fileManager, fallback, nil
|
||||
}
|
||||
|
||||
// scanResolvConfHeader walks /etc/resolv.conf header comments and returns the
|
||||
// matching manager. If reason is empty the caller should pick file mode and
|
||||
// use rejected for diagnostics.
|
||||
func scanResolvConfHeader() (osManagerType, string, []string, error) {
|
||||
file, err := os.Open(defaultResolvConfPath)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("unable to open %s for checking owner, got error: %w", defaultResolvConfPath, err)
|
||||
return 0, "", nil, fmt.Errorf("unable to open %s for checking owner, got error: %w", defaultResolvConfPath, err)
|
||||
}
|
||||
defer func() {
|
||||
if err := file.Close(); err != nil {
|
||||
log.Errorf("close file %s: %s", defaultResolvConfPath, err)
|
||||
if cerr := file.Close(); cerr != nil {
|
||||
log.Errorf("close file %s: %s", defaultResolvConfPath, cerr)
|
||||
}
|
||||
}()
|
||||
|
||||
var rejected []string
|
||||
scanner := bufio.NewScanner(file)
|
||||
for scanner.Scan() {
|
||||
text := scanner.Text()
|
||||
@@ -92,41 +124,48 @@ func getOSDNSManagerType() (osManagerType, error) {
|
||||
continue
|
||||
}
|
||||
if text[0] != '#' {
|
||||
return fileManager, nil
|
||||
break
|
||||
}
|
||||
if strings.Contains(text, fileGeneratedResolvConfContentHeader) {
|
||||
return netbirdManager, nil
|
||||
}
|
||||
if strings.Contains(text, "NetworkManager") && isDbusListenerRunning(networkManagerDest, networkManagerDbusObjectNode) && isNetworkManagerSupported() {
|
||||
return networkManager, nil
|
||||
}
|
||||
if strings.Contains(text, "systemd-resolved") && isSystemdResolvedRunning() {
|
||||
if checkStub() {
|
||||
return systemdManager, nil
|
||||
} else {
|
||||
return fileManager, nil
|
||||
}
|
||||
}
|
||||
if strings.Contains(text, "resolvconf") {
|
||||
if isSystemdResolveConfMode() {
|
||||
return systemdManager, nil
|
||||
}
|
||||
|
||||
return resolvConfManager, nil
|
||||
if mgr, reason, rej := matchResolvConfHeader(text); reason != "" {
|
||||
return mgr, reason, nil, nil
|
||||
} else if rej != "" {
|
||||
rejected = append(rejected, rej)
|
||||
}
|
||||
}
|
||||
if err := scanner.Err(); err != nil && err != io.EOF {
|
||||
return 0, fmt.Errorf("scan: %w", err)
|
||||
return 0, "", nil, fmt.Errorf("scan: %w", err)
|
||||
}
|
||||
|
||||
return fileManager, nil
|
||||
return 0, "", rejected, nil
|
||||
}
|
||||
|
||||
// checkStub checks if the stub resolver is disabled in systemd-resolved. If it is disabled, we fall back to file manager.
|
||||
// matchResolvConfHeader inspects a single comment line. Returns either a
|
||||
// definitive (manager, reason) or a non-empty rejected diagnostic.
|
||||
func matchResolvConfHeader(text string) (osManagerType, string, string) {
|
||||
if strings.Contains(text, fileGeneratedResolvConfContentHeader) {
|
||||
return netbirdManager, "netbird-managed resolv.conf header detected", ""
|
||||
}
|
||||
if strings.Contains(text, "NetworkManager") {
|
||||
if isDbusListenerRunning(networkManagerDest, networkManagerDbusObjectNode) && isNetworkManagerSupported() {
|
||||
return networkManager, "NetworkManager header + supported version on dbus", ""
|
||||
}
|
||||
return 0, "", "NetworkManager header (no dbus or unsupported version)"
|
||||
}
|
||||
if strings.Contains(text, "resolvconf") {
|
||||
if isSystemdResolveConfMode() {
|
||||
return systemdManager, "resolvconf header in systemd-resolved compatibility mode", ""
|
||||
}
|
||||
return resolvConfManager, "resolvconf header detected", ""
|
||||
}
|
||||
return 0, "", ""
|
||||
}
|
||||
|
||||
// checkStub reports whether systemd-resolved's stub (127.0.0.53) is listed
|
||||
// in /etc/resolv.conf. On parse failure we assume it is, to avoid dropping
|
||||
// into file mode while resolved is active.
|
||||
func checkStub() bool {
|
||||
rConf, err := parseDefaultResolvConf()
|
||||
if err != nil {
|
||||
log.Warnf("failed to parse resolv conf: %s", err)
|
||||
log.Warnf("failed to parse resolv conf, assuming stub is active: %s", err)
|
||||
return true
|
||||
}
|
||||
|
||||
@@ -139,3 +178,36 @@ func checkStub() bool {
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// isLibnssResolveUsed reports whether nss-resolve is listed before dns on
|
||||
// the hosts: line of /etc/nsswitch.conf. When it is, libc lookups are
|
||||
// delegated to systemd-resolved regardless of /etc/resolv.conf.
|
||||
func isLibnssResolveUsed() bool {
|
||||
bs, err := os.ReadFile(nsswitchConfPath)
|
||||
if err != nil {
|
||||
log.Debugf("read %s: %v", nsswitchConfPath, err)
|
||||
return false
|
||||
}
|
||||
return parseNsswitchResolveAhead(bs)
|
||||
}
|
||||
|
||||
func parseNsswitchResolveAhead(data []byte) bool {
|
||||
for _, line := range strings.Split(string(data), "\n") {
|
||||
if i := strings.IndexByte(line, '#'); i >= 0 {
|
||||
line = line[:i]
|
||||
}
|
||||
fields := strings.Fields(line)
|
||||
if len(fields) < 2 || fields[0] != "hosts:" {
|
||||
continue
|
||||
}
|
||||
for _, module := range fields[1:] {
|
||||
switch module {
|
||||
case "dns":
|
||||
return false
|
||||
case "resolve":
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
76
client/internal/dns/host_unix_test.go
Normal file
76
client/internal/dns/host_unix_test.go
Normal file
@@ -0,0 +1,76 @@
|
||||
//go:build (linux && !android) || freebsd
|
||||
|
||||
package dns
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestParseNsswitchResolveAhead(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
in string
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
name: "resolve before dns with action token",
|
||||
in: "hosts: mymachines resolve [!UNAVAIL=return] files myhostname dns\n",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "dns before resolve",
|
||||
in: "hosts: files mdns4_minimal [NOTFOUND=return] dns resolve\n",
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "debian default with only dns",
|
||||
in: "hosts: files mdns4_minimal [NOTFOUND=return] dns mymachines\n",
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "neither resolve nor dns",
|
||||
in: "hosts: files myhostname\n",
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "no hosts line",
|
||||
in: "passwd: files systemd\ngroup: files systemd\n",
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "empty",
|
||||
in: "",
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "comments and blank lines ignored",
|
||||
in: "# comment\n\n# another\nhosts: resolve dns\n",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "trailing inline comment",
|
||||
in: "hosts: resolve [!UNAVAIL=return] dns # fallback\n",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "hosts token must be the first field",
|
||||
in: " hosts: resolve dns\n",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "other db line mentioning resolve is ignored",
|
||||
in: "networks: resolve\nhosts: dns\n",
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "only resolve, no dns",
|
||||
in: "hosts: files resolve\n",
|
||||
want: true,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := parseNsswitchResolveAhead([]byte(tt.in)); got != tt.want {
|
||||
t.Errorf("parseNsswitchResolveAhead() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -2,40 +2,83 @@ package mgmt
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"net/url"
|
||||
"os"
|
||||
"slices"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/sync/singleflight"
|
||||
|
||||
dnsconfig "github.com/netbirdio/netbird/client/internal/dns/config"
|
||||
"github.com/netbirdio/netbird/client/internal/dns/resutil"
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
)
|
||||
|
||||
const dnsTimeout = 5 * time.Second
|
||||
const (
|
||||
dnsTimeout = 5 * time.Second
|
||||
defaultTTL = 300 * time.Second
|
||||
refreshBackoff = 30 * time.Second
|
||||
|
||||
// Resolver caches critical NetBird infrastructure domains
|
||||
// envMgmtCacheTTL overrides defaultTTL for integration/dev testing.
|
||||
envMgmtCacheTTL = "NB_MGMT_CACHE_TTL"
|
||||
)
|
||||
|
||||
// ChainResolver lets the cache refresh stale entries through the DNS handler
|
||||
// chain instead of net.DefaultResolver, avoiding loopback when NetBird is the
|
||||
// system resolver.
|
||||
type ChainResolver interface {
|
||||
ResolveInternal(ctx context.Context, msg *dns.Msg, maxPriority int) (*dns.Msg, error)
|
||||
HasRootHandlerAtOrBelow(maxPriority int) bool
|
||||
}
|
||||
|
||||
// cachedRecord holds DNS records plus timestamps used for TTL refresh.
|
||||
// records and cachedAt are set at construction and treated as immutable;
|
||||
// lastFailedRefresh and consecFailures are mutable and must be accessed under
|
||||
// Resolver.mutex.
|
||||
type cachedRecord struct {
|
||||
records []dns.RR
|
||||
cachedAt time.Time
|
||||
lastFailedRefresh time.Time
|
||||
consecFailures int
|
||||
}
|
||||
|
||||
// Resolver caches critical NetBird infrastructure domains.
|
||||
// records, refreshing, mgmtDomain and serverDomains are all guarded by mutex.
|
||||
type Resolver struct {
|
||||
records map[dns.Question][]dns.RR
|
||||
records map[dns.Question]*cachedRecord
|
||||
mgmtDomain *domain.Domain
|
||||
serverDomains *dnsconfig.ServerDomains
|
||||
mutex sync.RWMutex
|
||||
}
|
||||
|
||||
type ipsResponse struct {
|
||||
ips []netip.Addr
|
||||
err error
|
||||
chain ChainResolver
|
||||
chainMaxPriority int
|
||||
refreshGroup singleflight.Group
|
||||
|
||||
// refreshing tracks questions whose refresh is running via the OS
|
||||
// fallback path. A ServeDNS hit for a question in this map indicates
|
||||
// the OS resolver routed the recursive query back to us (loop). Only
|
||||
// the OS path arms this so chain-path refreshes don't produce false
|
||||
// positives. The atomic bool is CAS-flipped once per refresh to
|
||||
// throttle the warning log.
|
||||
refreshing map[dns.Question]*atomic.Bool
|
||||
|
||||
cacheTTL time.Duration
|
||||
}
|
||||
|
||||
// NewResolver creates a new management domains cache resolver.
|
||||
func NewResolver() *Resolver {
|
||||
return &Resolver{
|
||||
records: make(map[dns.Question][]dns.RR),
|
||||
records: make(map[dns.Question]*cachedRecord),
|
||||
refreshing: make(map[dns.Question]*atomic.Bool),
|
||||
cacheTTL: resolveCacheTTL(),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -44,7 +87,19 @@ func (m *Resolver) String() string {
|
||||
return "MgmtCacheResolver"
|
||||
}
|
||||
|
||||
// ServeDNS implements dns.Handler interface.
|
||||
// SetChainResolver wires the handler chain used to refresh stale cache entries.
|
||||
// maxPriority caps which handlers may answer refresh queries (typically
|
||||
// PriorityUpstream, so upstream/default/fallback handlers are consulted and
|
||||
// mgmt/route/local handlers are skipped).
|
||||
func (m *Resolver) SetChainResolver(chain ChainResolver, maxPriority int) {
|
||||
m.mutex.Lock()
|
||||
m.chain = chain
|
||||
m.chainMaxPriority = maxPriority
|
||||
m.mutex.Unlock()
|
||||
}
|
||||
|
||||
// ServeDNS serves cached A/AAAA records. Stale entries are returned
|
||||
// immediately and refreshed asynchronously (stale-while-revalidate).
|
||||
func (m *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||
if len(r.Question) == 0 {
|
||||
m.continueToNext(w, r)
|
||||
@@ -60,7 +115,14 @@ func (m *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||
}
|
||||
|
||||
m.mutex.RLock()
|
||||
records, found := m.records[question]
|
||||
cached, found := m.records[question]
|
||||
inflight := m.refreshing[question]
|
||||
var shouldRefresh bool
|
||||
if found {
|
||||
stale := time.Since(cached.cachedAt) > m.cacheTTL
|
||||
inBackoff := !cached.lastFailedRefresh.IsZero() && time.Since(cached.lastFailedRefresh) < refreshBackoff
|
||||
shouldRefresh = stale && !inBackoff
|
||||
}
|
||||
m.mutex.RUnlock()
|
||||
|
||||
if !found {
|
||||
@@ -68,12 +130,23 @@ func (m *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||
return
|
||||
}
|
||||
|
||||
if inflight != nil && inflight.CompareAndSwap(false, true) {
|
||||
log.Warnf("mgmt cache: possible resolver loop for domain=%s: served stale while an OS-fallback refresh was inflight (if NetBird is the system resolver, the OS-path predicate is wrong)",
|
||||
question.Name)
|
||||
}
|
||||
|
||||
// Skip scheduling a refresh goroutine if one is already inflight for
|
||||
// this question; singleflight would dedup anyway but skipping avoids
|
||||
// a parked goroutine per stale hit under bursty load.
|
||||
if shouldRefresh && inflight == nil {
|
||||
m.scheduleRefresh(question, cached)
|
||||
}
|
||||
|
||||
resp := &dns.Msg{}
|
||||
resp.SetReply(r)
|
||||
resp.Authoritative = false
|
||||
resp.RecursionAvailable = true
|
||||
|
||||
resp.Answer = append(resp.Answer, records...)
|
||||
resp.Answer = cloneRecordsWithTTL(cached.records, m.responseTTL(cached.cachedAt))
|
||||
|
||||
log.Debugf("serving %d cached records for domain=%s", len(resp.Answer), question.Name)
|
||||
|
||||
@@ -98,101 +171,260 @@ func (m *Resolver) continueToNext(w dns.ResponseWriter, r *dns.Msg) {
|
||||
}
|
||||
}
|
||||
|
||||
// AddDomain manually adds a domain to cache by resolving it.
|
||||
// AddDomain resolves a domain and stores its A/AAAA records in the cache.
|
||||
// A family that resolves NODATA (nil err, zero records) evicts any stale
|
||||
// entry for that qtype.
|
||||
func (m *Resolver) AddDomain(ctx context.Context, d domain.Domain) error {
|
||||
dnsName := strings.ToLower(dns.Fqdn(d.PunycodeString()))
|
||||
|
||||
ctx, cancel := context.WithTimeout(ctx, dnsTimeout)
|
||||
defer cancel()
|
||||
|
||||
ips, err := lookupIPWithExtraTimeout(ctx, d)
|
||||
if err != nil {
|
||||
return err
|
||||
aRecords, aaaaRecords, errA, errAAAA := m.lookupBoth(ctx, d, dnsName)
|
||||
|
||||
if errA != nil && errAAAA != nil {
|
||||
return fmt.Errorf("resolve %s: %w", d.SafeString(), errors.Join(errA, errAAAA))
|
||||
}
|
||||
|
||||
var aRecords, aaaaRecords []dns.RR
|
||||
for _, ip := range ips {
|
||||
if ip.Is4() {
|
||||
rr := &dns.A{
|
||||
Hdr: dns.RR_Header{
|
||||
Name: dnsName,
|
||||
Rrtype: dns.TypeA,
|
||||
Class: dns.ClassINET,
|
||||
Ttl: 300,
|
||||
},
|
||||
A: ip.AsSlice(),
|
||||
}
|
||||
aRecords = append(aRecords, rr)
|
||||
} else if ip.Is6() {
|
||||
rr := &dns.AAAA{
|
||||
Hdr: dns.RR_Header{
|
||||
Name: dnsName,
|
||||
Rrtype: dns.TypeAAAA,
|
||||
Class: dns.ClassINET,
|
||||
Ttl: 300,
|
||||
},
|
||||
AAAA: ip.AsSlice(),
|
||||
}
|
||||
aaaaRecords = append(aaaaRecords, rr)
|
||||
if len(aRecords) == 0 && len(aaaaRecords) == 0 {
|
||||
if err := errors.Join(errA, errAAAA); err != nil {
|
||||
return fmt.Errorf("resolve %s: no A/AAAA records: %w", d.SafeString(), err)
|
||||
}
|
||||
return fmt.Errorf("resolve %s: no A/AAAA records", d.SafeString())
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
if len(aRecords) > 0 {
|
||||
aQuestion := dns.Question{
|
||||
Name: dnsName,
|
||||
Qtype: dns.TypeA,
|
||||
Qclass: dns.ClassINET,
|
||||
}
|
||||
m.records[aQuestion] = aRecords
|
||||
}
|
||||
m.applyFamilyRecords(dnsName, dns.TypeA, aRecords, errA, now)
|
||||
m.applyFamilyRecords(dnsName, dns.TypeAAAA, aaaaRecords, errAAAA, now)
|
||||
|
||||
if len(aaaaRecords) > 0 {
|
||||
aaaaQuestion := dns.Question{
|
||||
Name: dnsName,
|
||||
Qtype: dns.TypeAAAA,
|
||||
Qclass: dns.ClassINET,
|
||||
}
|
||||
m.records[aaaaQuestion] = aaaaRecords
|
||||
}
|
||||
|
||||
m.mutex.Unlock()
|
||||
|
||||
log.Debugf("added domain=%s with %d A records and %d AAAA records",
|
||||
log.Debugf("added/updated domain=%s with %d A records and %d AAAA records",
|
||||
d.SafeString(), len(aRecords), len(aaaaRecords))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func lookupIPWithExtraTimeout(ctx context.Context, d domain.Domain) ([]netip.Addr, error) {
|
||||
log.Infof("looking up IP for mgmt domain=%s", d.SafeString())
|
||||
defer log.Infof("done looking up IP for mgmt domain=%s", d.SafeString())
|
||||
resultChan := make(chan *ipsResponse, 1)
|
||||
// applyFamilyRecords writes records, evicts on NODATA, leaves the cache
|
||||
// untouched on error. Caller holds m.mutex.
|
||||
func (m *Resolver) applyFamilyRecords(dnsName string, qtype uint16, records []dns.RR, err error, now time.Time) {
|
||||
q := dns.Question{Name: dnsName, Qtype: qtype, Qclass: dns.ClassINET}
|
||||
switch {
|
||||
case len(records) > 0:
|
||||
m.records[q] = &cachedRecord{records: records, cachedAt: now}
|
||||
case err == nil:
|
||||
delete(m.records, q)
|
||||
}
|
||||
}
|
||||
|
||||
go func() {
|
||||
ips, err := net.DefaultResolver.LookupNetIP(ctx, "ip", d.PunycodeString())
|
||||
resultChan <- &ipsResponse{
|
||||
err: err,
|
||||
ips: ips,
|
||||
// scheduleRefresh kicks off an async refresh. DoChan spawns one goroutine per
|
||||
// unique in-flight key; bursty stale hits share its channel. expected is the
|
||||
// cachedRecord pointer observed by the caller; the refresh only mutates the
|
||||
// cache if that pointer is still the one stored, so a stale in-flight refresh
|
||||
// can't clobber a newer entry written by AddDomain or a competing refresh.
|
||||
func (m *Resolver) scheduleRefresh(question dns.Question, expected *cachedRecord) {
|
||||
key := question.Name + "|" + dns.TypeToString[question.Qtype]
|
||||
_ = m.refreshGroup.DoChan(key, func() (any, error) {
|
||||
return nil, m.refreshQuestion(question, expected)
|
||||
})
|
||||
}
|
||||
|
||||
// refreshQuestion replaces the cached records on success, or marks the entry
|
||||
// failed (arming the backoff) on failure. While this runs, ServeDNS can detect
|
||||
// a resolver loop by spotting a query for this same question arriving on us.
|
||||
// expected pins the cache entry observed at schedule time; mutations only apply
|
||||
// if m.records[question] still points at it.
|
||||
func (m *Resolver) refreshQuestion(question dns.Question, expected *cachedRecord) error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), dnsTimeout)
|
||||
defer cancel()
|
||||
|
||||
d, err := domain.FromString(strings.TrimSuffix(question.Name, "."))
|
||||
if err != nil {
|
||||
m.markRefreshFailed(question, expected)
|
||||
return fmt.Errorf("parse domain: %w", err)
|
||||
}
|
||||
|
||||
records, err := m.lookupRecords(ctx, d, question)
|
||||
if err != nil {
|
||||
fails := m.markRefreshFailed(question, expected)
|
||||
logf := log.Warnf
|
||||
if fails == 0 || fails > 1 {
|
||||
logf = log.Debugf
|
||||
}
|
||||
}()
|
||||
|
||||
var resp *ipsResponse
|
||||
|
||||
select {
|
||||
case <-time.After(dnsTimeout + time.Millisecond*500):
|
||||
log.Warnf("timed out waiting for IP for mgmt domain=%s", d.SafeString())
|
||||
return nil, fmt.Errorf("timed out waiting for ips to be available for domain %s", d.SafeString())
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
case resp = <-resultChan:
|
||||
logf("refresh mgmt cache domain=%s type=%s: %v (consecutive failures=%d)",
|
||||
d.SafeString(), dns.TypeToString[question.Qtype], err, fails)
|
||||
return err
|
||||
}
|
||||
|
||||
if resp.err != nil {
|
||||
return nil, fmt.Errorf("resolve domain %s: %w", d.SafeString(), resp.err)
|
||||
// NOERROR/NODATA: family gone upstream, evict so we stop serving stale.
|
||||
if len(records) == 0 {
|
||||
m.mutex.Lock()
|
||||
if m.records[question] == expected {
|
||||
delete(m.records, question)
|
||||
m.mutex.Unlock()
|
||||
log.Infof("removed mgmt cache domain=%s type=%s: no records returned",
|
||||
d.SafeString(), dns.TypeToString[question.Qtype])
|
||||
return nil
|
||||
}
|
||||
m.mutex.Unlock()
|
||||
log.Debugf("skipping refresh evict for domain=%s type=%s: entry changed during refresh",
|
||||
d.SafeString(), dns.TypeToString[question.Qtype])
|
||||
return nil
|
||||
}
|
||||
return resp.ips, nil
|
||||
|
||||
now := time.Now()
|
||||
m.mutex.Lock()
|
||||
if m.records[question] != expected {
|
||||
m.mutex.Unlock()
|
||||
log.Debugf("skipping refresh write for domain=%s type=%s: entry changed during refresh",
|
||||
d.SafeString(), dns.TypeToString[question.Qtype])
|
||||
return nil
|
||||
}
|
||||
m.records[question] = &cachedRecord{records: records, cachedAt: now}
|
||||
m.mutex.Unlock()
|
||||
|
||||
log.Infof("refreshed mgmt cache domain=%s type=%s",
|
||||
d.SafeString(), dns.TypeToString[question.Qtype])
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Resolver) markRefreshing(question dns.Question) {
|
||||
m.mutex.Lock()
|
||||
m.refreshing[question] = &atomic.Bool{}
|
||||
m.mutex.Unlock()
|
||||
}
|
||||
|
||||
func (m *Resolver) clearRefreshing(question dns.Question) {
|
||||
m.mutex.Lock()
|
||||
delete(m.refreshing, question)
|
||||
m.mutex.Unlock()
|
||||
}
|
||||
|
||||
// markRefreshFailed arms the backoff and returns the new consecutive-failure
|
||||
// count so callers can downgrade subsequent failure logs to debug.
|
||||
func (m *Resolver) markRefreshFailed(question dns.Question, expected *cachedRecord) int {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
c, ok := m.records[question]
|
||||
if !ok || c != expected {
|
||||
return 0
|
||||
}
|
||||
c.lastFailedRefresh = time.Now()
|
||||
c.consecFailures++
|
||||
return c.consecFailures
|
||||
}
|
||||
|
||||
// lookupBoth resolves A and AAAA via chain or OS. Per-family errors let
|
||||
// callers tell records, NODATA (nil err, no records), and failure apart.
|
||||
func (m *Resolver) lookupBoth(ctx context.Context, d domain.Domain, dnsName string) (aRecords, aaaaRecords []dns.RR, errA, errAAAA error) {
|
||||
m.mutex.RLock()
|
||||
chain := m.chain
|
||||
maxPriority := m.chainMaxPriority
|
||||
m.mutex.RUnlock()
|
||||
|
||||
if chain != nil && chain.HasRootHandlerAtOrBelow(maxPriority) {
|
||||
aRecords, errA = m.lookupViaChain(ctx, chain, maxPriority, dnsName, dns.TypeA)
|
||||
aaaaRecords, errAAAA = m.lookupViaChain(ctx, chain, maxPriority, dnsName, dns.TypeAAAA)
|
||||
return
|
||||
}
|
||||
|
||||
// TODO: drop once every supported OS registers a fallback resolver. Safe
|
||||
// today: no root handler at priority ≤ PriorityUpstream means NetBird is
|
||||
// not the system resolver, so net.DefaultResolver will not loop back.
|
||||
aRecords, errA = m.osLookup(ctx, d, dnsName, dns.TypeA)
|
||||
aaaaRecords, errAAAA = m.osLookup(ctx, d, dnsName, dns.TypeAAAA)
|
||||
return
|
||||
}
|
||||
|
||||
// lookupRecords resolves a single record type via chain or OS. The OS branch
|
||||
// arms the loop detector for the duration of its call so that ServeDNS can
|
||||
// spot the OS resolver routing the recursive query back to us.
|
||||
func (m *Resolver) lookupRecords(ctx context.Context, d domain.Domain, q dns.Question) ([]dns.RR, error) {
|
||||
m.mutex.RLock()
|
||||
chain := m.chain
|
||||
maxPriority := m.chainMaxPriority
|
||||
m.mutex.RUnlock()
|
||||
|
||||
if chain != nil && chain.HasRootHandlerAtOrBelow(maxPriority) {
|
||||
return m.lookupViaChain(ctx, chain, maxPriority, q.Name, q.Qtype)
|
||||
}
|
||||
|
||||
// TODO: drop once every supported OS registers a fallback resolver.
|
||||
m.markRefreshing(q)
|
||||
defer m.clearRefreshing(q)
|
||||
|
||||
return m.osLookup(ctx, d, q.Name, q.Qtype)
|
||||
}
|
||||
|
||||
// lookupViaChain resolves via the handler chain and rewrites each RR to use
|
||||
// dnsName as owner and m.cacheTTL as TTL, so CNAME-backed domains don't cache
|
||||
// target-owned records or upstream TTLs. NODATA returns (nil, nil).
|
||||
func (m *Resolver) lookupViaChain(ctx context.Context, chain ChainResolver, maxPriority int, dnsName string, qtype uint16) ([]dns.RR, error) {
|
||||
msg := &dns.Msg{}
|
||||
msg.SetQuestion(dnsName, qtype)
|
||||
msg.RecursionDesired = true
|
||||
|
||||
resp, err := chain.ResolveInternal(ctx, msg, maxPriority)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("chain resolve: %w", err)
|
||||
}
|
||||
if resp == nil {
|
||||
return nil, fmt.Errorf("chain resolve returned nil response")
|
||||
}
|
||||
if resp.Rcode != dns.RcodeSuccess {
|
||||
return nil, fmt.Errorf("chain resolve rcode=%s", dns.RcodeToString[resp.Rcode])
|
||||
}
|
||||
|
||||
ttl := uint32(m.cacheTTL.Seconds())
|
||||
owners := cnameOwners(dnsName, resp.Answer)
|
||||
var filtered []dns.RR
|
||||
for _, rr := range resp.Answer {
|
||||
h := rr.Header()
|
||||
if h.Class != dns.ClassINET || h.Rrtype != qtype {
|
||||
continue
|
||||
}
|
||||
if !owners[strings.ToLower(dns.Fqdn(h.Name))] {
|
||||
continue
|
||||
}
|
||||
if cp := cloneIPRecord(rr, dnsName, ttl); cp != nil {
|
||||
filtered = append(filtered, cp)
|
||||
}
|
||||
}
|
||||
return filtered, nil
|
||||
}
|
||||
|
||||
// osLookup resolves a single family via net.DefaultResolver using resutil,
|
||||
// which disambiguates NODATA from NXDOMAIN and Unmaps v4-mapped-v6. NODATA
|
||||
// returns (nil, nil).
|
||||
func (m *Resolver) osLookup(ctx context.Context, d domain.Domain, dnsName string, qtype uint16) ([]dns.RR, error) {
|
||||
network := resutil.NetworkForQtype(qtype)
|
||||
if network == "" {
|
||||
return nil, fmt.Errorf("unsupported qtype %s", dns.TypeToString[qtype])
|
||||
}
|
||||
|
||||
log.Infof("looking up IP for mgmt domain=%s type=%s", d.SafeString(), dns.TypeToString[qtype])
|
||||
defer log.Infof("done looking up IP for mgmt domain=%s type=%s", d.SafeString(), dns.TypeToString[qtype])
|
||||
|
||||
result := resutil.LookupIP(ctx, net.DefaultResolver, network, d.PunycodeString(), qtype)
|
||||
if result.Rcode == dns.RcodeSuccess {
|
||||
return resutil.IPsToRRs(dnsName, result.IPs, uint32(m.cacheTTL.Seconds())), nil
|
||||
}
|
||||
|
||||
if result.Err != nil {
|
||||
return nil, fmt.Errorf("resolve %s type=%s: %w", d.SafeString(), dns.TypeToString[qtype], result.Err)
|
||||
}
|
||||
return nil, fmt.Errorf("resolve %s type=%s: rcode=%s", d.SafeString(), dns.TypeToString[qtype], dns.RcodeToString[result.Rcode])
|
||||
}
|
||||
|
||||
// responseTTL returns the remaining cache lifetime in seconds (rounded up),
|
||||
// so downstream resolvers don't cache an answer for longer than we will.
|
||||
func (m *Resolver) responseTTL(cachedAt time.Time) uint32 {
|
||||
remaining := m.cacheTTL - time.Since(cachedAt)
|
||||
if remaining <= 0 {
|
||||
return 0
|
||||
}
|
||||
return uint32((remaining + time.Second - 1) / time.Second)
|
||||
}
|
||||
|
||||
// PopulateFromConfig extracts and caches domains from the client configuration.
|
||||
@@ -224,19 +456,12 @@ func (m *Resolver) RemoveDomain(d domain.Domain) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
aQuestion := dns.Question{
|
||||
Name: dnsName,
|
||||
Qtype: dns.TypeA,
|
||||
Qclass: dns.ClassINET,
|
||||
}
|
||||
delete(m.records, aQuestion)
|
||||
|
||||
aaaaQuestion := dns.Question{
|
||||
Name: dnsName,
|
||||
Qtype: dns.TypeAAAA,
|
||||
Qclass: dns.ClassINET,
|
||||
}
|
||||
delete(m.records, aaaaQuestion)
|
||||
qA := dns.Question{Name: dnsName, Qtype: dns.TypeA, Qclass: dns.ClassINET}
|
||||
qAAAA := dns.Question{Name: dnsName, Qtype: dns.TypeAAAA, Qclass: dns.ClassINET}
|
||||
delete(m.records, qA)
|
||||
delete(m.records, qAAAA)
|
||||
delete(m.refreshing, qA)
|
||||
delete(m.refreshing, qAAAA)
|
||||
|
||||
log.Debugf("removed domain=%s from cache", d.SafeString())
|
||||
return nil
|
||||
@@ -394,3 +619,73 @@ func (m *Resolver) extractDomainsFromServerDomains(serverDomains dnsconfig.Serve
|
||||
|
||||
return domains
|
||||
}
|
||||
|
||||
// cloneIPRecord returns a deep copy of rr retargeted to owner with ttl. Non
|
||||
// A/AAAA records return nil.
|
||||
func cloneIPRecord(rr dns.RR, owner string, ttl uint32) dns.RR {
|
||||
switch r := rr.(type) {
|
||||
case *dns.A:
|
||||
cp := *r
|
||||
cp.Hdr.Name = owner
|
||||
cp.Hdr.Ttl = ttl
|
||||
cp.A = slices.Clone(r.A)
|
||||
return &cp
|
||||
case *dns.AAAA:
|
||||
cp := *r
|
||||
cp.Hdr.Name = owner
|
||||
cp.Hdr.Ttl = ttl
|
||||
cp.AAAA = slices.Clone(r.AAAA)
|
||||
return &cp
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// cloneRecordsWithTTL clones A/AAAA records preserving their owner and
|
||||
// stamping ttl so the response shares no memory with the cached slice.
|
||||
func cloneRecordsWithTTL(records []dns.RR, ttl uint32) []dns.RR {
|
||||
out := make([]dns.RR, 0, len(records))
|
||||
for _, rr := range records {
|
||||
if cp := cloneIPRecord(rr, rr.Header().Name, ttl); cp != nil {
|
||||
out = append(out, cp)
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// cnameOwners returns dnsName plus every target reachable by following CNAMEs
|
||||
// in answer, iterating until fixed point so out-of-order chains resolve.
|
||||
func cnameOwners(dnsName string, answer []dns.RR) map[string]bool {
|
||||
owners := map[string]bool{dnsName: true}
|
||||
for {
|
||||
added := false
|
||||
for _, rr := range answer {
|
||||
cname, ok := rr.(*dns.CNAME)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
name := strings.ToLower(dns.Fqdn(cname.Hdr.Name))
|
||||
if !owners[name] {
|
||||
continue
|
||||
}
|
||||
target := strings.ToLower(dns.Fqdn(cname.Target))
|
||||
if !owners[target] {
|
||||
owners[target] = true
|
||||
added = true
|
||||
}
|
||||
}
|
||||
if !added {
|
||||
return owners
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// resolveCacheTTL reads the cache TTL override env var; invalid or empty
|
||||
// values fall back to defaultTTL. Called once per Resolver from NewResolver.
|
||||
func resolveCacheTTL() time.Duration {
|
||||
if v := os.Getenv(envMgmtCacheTTL); v != "" {
|
||||
if d, err := time.ParseDuration(v); err == nil && d > 0 {
|
||||
return d
|
||||
}
|
||||
}
|
||||
return defaultTTL
|
||||
}
|
||||
|
||||
408
client/internal/dns/mgmt/mgmt_refresh_test.go
Normal file
408
client/internal/dns/mgmt/mgmt_refresh_test.go
Normal file
@@ -0,0 +1,408 @@
|
||||
package mgmt
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/dns/test"
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
)
|
||||
|
||||
type fakeChain struct {
|
||||
mu sync.Mutex
|
||||
calls map[string]int
|
||||
answers map[string][]dns.RR
|
||||
err error
|
||||
hasRoot bool
|
||||
onLookup func()
|
||||
}
|
||||
|
||||
func newFakeChain() *fakeChain {
|
||||
return &fakeChain{
|
||||
calls: map[string]int{},
|
||||
answers: map[string][]dns.RR{},
|
||||
hasRoot: true,
|
||||
}
|
||||
}
|
||||
|
||||
func (f *fakeChain) HasRootHandlerAtOrBelow(maxPriority int) bool {
|
||||
f.mu.Lock()
|
||||
defer f.mu.Unlock()
|
||||
return f.hasRoot
|
||||
}
|
||||
|
||||
func (f *fakeChain) ResolveInternal(ctx context.Context, msg *dns.Msg, maxPriority int) (*dns.Msg, error) {
|
||||
f.mu.Lock()
|
||||
q := msg.Question[0]
|
||||
key := q.Name + "|" + dns.TypeToString[q.Qtype]
|
||||
f.calls[key]++
|
||||
answers := f.answers[key]
|
||||
err := f.err
|
||||
onLookup := f.onLookup
|
||||
f.mu.Unlock()
|
||||
|
||||
if onLookup != nil {
|
||||
onLookup()
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
resp := &dns.Msg{}
|
||||
resp.SetReply(msg)
|
||||
resp.Answer = answers
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func (f *fakeChain) setAnswer(name string, qtype uint16, ip string) {
|
||||
f.mu.Lock()
|
||||
defer f.mu.Unlock()
|
||||
key := name + "|" + dns.TypeToString[qtype]
|
||||
hdr := dns.RR_Header{Name: name, Rrtype: qtype, Class: dns.ClassINET, Ttl: 60}
|
||||
switch qtype {
|
||||
case dns.TypeA:
|
||||
f.answers[key] = []dns.RR{&dns.A{Hdr: hdr, A: net.ParseIP(ip).To4()}}
|
||||
case dns.TypeAAAA:
|
||||
f.answers[key] = []dns.RR{&dns.AAAA{Hdr: hdr, AAAA: net.ParseIP(ip).To16()}}
|
||||
}
|
||||
}
|
||||
|
||||
func (f *fakeChain) callCount(name string, qtype uint16) int {
|
||||
f.mu.Lock()
|
||||
defer f.mu.Unlock()
|
||||
return f.calls[name+"|"+dns.TypeToString[qtype]]
|
||||
}
|
||||
|
||||
// waitFor polls the predicate until it returns true or the deadline passes.
|
||||
func waitFor(t *testing.T, d time.Duration, fn func() bool) {
|
||||
t.Helper()
|
||||
deadline := time.Now().Add(d)
|
||||
for time.Now().Before(deadline) {
|
||||
if fn() {
|
||||
return
|
||||
}
|
||||
time.Sleep(5 * time.Millisecond)
|
||||
}
|
||||
t.Fatalf("condition not met within %s", d)
|
||||
}
|
||||
|
||||
func queryA(t *testing.T, r *Resolver, name string) *dns.Msg {
|
||||
t.Helper()
|
||||
msg := new(dns.Msg)
|
||||
msg.SetQuestion(name, dns.TypeA)
|
||||
w := &test.MockResponseWriter{}
|
||||
r.ServeDNS(w, msg)
|
||||
return w.GetLastResponse()
|
||||
}
|
||||
|
||||
func firstA(t *testing.T, resp *dns.Msg) string {
|
||||
t.Helper()
|
||||
require.NotNil(t, resp)
|
||||
require.Greater(t, len(resp.Answer), 0, "expected at least one answer")
|
||||
a, ok := resp.Answer[0].(*dns.A)
|
||||
require.True(t, ok, "expected A record")
|
||||
return a.A.String()
|
||||
}
|
||||
|
||||
func TestResolver_CacheTTLGatesRefresh(t *testing.T) {
|
||||
// Same cached entry age, different cacheTTL values: the shorter TTL must
|
||||
// trigger a background refresh, the longer one must not. Proves that the
|
||||
// per-Resolver cacheTTL field actually drives the stale decision.
|
||||
cachedAt := time.Now().Add(-100 * time.Millisecond)
|
||||
|
||||
newRec := func() *cachedRecord {
|
||||
return &cachedRecord{
|
||||
records: []dns.RR{&dns.A{
|
||||
Hdr: dns.RR_Header{Name: "mgmt.example.com.", Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
|
||||
A: net.ParseIP("10.0.0.1").To4(),
|
||||
}},
|
||||
cachedAt: cachedAt,
|
||||
}
|
||||
}
|
||||
q := dns.Question{Name: "mgmt.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}
|
||||
|
||||
t.Run("short TTL treats entry as stale and refreshes", func(t *testing.T) {
|
||||
r := NewResolver()
|
||||
r.cacheTTL = 10 * time.Millisecond
|
||||
chain := newFakeChain()
|
||||
chain.setAnswer(q.Name, dns.TypeA, "10.0.0.2")
|
||||
r.SetChainResolver(chain, 50)
|
||||
r.records[q] = newRec()
|
||||
|
||||
resp := queryA(t, r, q.Name)
|
||||
assert.Equal(t, "10.0.0.1", firstA(t, resp), "stale entry must be served while refresh runs")
|
||||
|
||||
waitFor(t, time.Second, func() bool {
|
||||
return chain.callCount(q.Name, dns.TypeA) >= 1
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("long TTL keeps entry fresh and skips refresh", func(t *testing.T) {
|
||||
r := NewResolver()
|
||||
r.cacheTTL = time.Hour
|
||||
chain := newFakeChain()
|
||||
chain.setAnswer(q.Name, dns.TypeA, "10.0.0.2")
|
||||
r.SetChainResolver(chain, 50)
|
||||
r.records[q] = newRec()
|
||||
|
||||
resp := queryA(t, r, q.Name)
|
||||
assert.Equal(t, "10.0.0.1", firstA(t, resp))
|
||||
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
assert.Equal(t, 0, chain.callCount(q.Name, dns.TypeA), "fresh entry must not trigger refresh")
|
||||
})
|
||||
}
|
||||
|
||||
func TestResolver_ServeFresh_NoRefresh(t *testing.T) {
|
||||
r := NewResolver()
|
||||
chain := newFakeChain()
|
||||
chain.setAnswer("mgmt.example.com.", dns.TypeA, "10.0.0.2")
|
||||
r.SetChainResolver(chain, 50)
|
||||
|
||||
r.records[dns.Question{Name: "mgmt.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}] = &cachedRecord{
|
||||
records: []dns.RR{&dns.A{
|
||||
Hdr: dns.RR_Header{Name: "mgmt.example.com.", Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
|
||||
A: net.ParseIP("10.0.0.1").To4(),
|
||||
}},
|
||||
cachedAt: time.Now(), // fresh
|
||||
}
|
||||
|
||||
resp := queryA(t, r, "mgmt.example.com.")
|
||||
assert.Equal(t, "10.0.0.1", firstA(t, resp))
|
||||
|
||||
time.Sleep(20 * time.Millisecond)
|
||||
assert.Equal(t, 0, chain.callCount("mgmt.example.com.", dns.TypeA), "fresh entry must not trigger refresh")
|
||||
}
|
||||
|
||||
func TestResolver_StaleTriggersAsyncRefresh(t *testing.T) {
|
||||
r := NewResolver()
|
||||
chain := newFakeChain()
|
||||
chain.setAnswer("mgmt.example.com.", dns.TypeA, "10.0.0.2")
|
||||
r.SetChainResolver(chain, 50)
|
||||
|
||||
q := dns.Question{Name: "mgmt.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}
|
||||
r.records[q] = &cachedRecord{
|
||||
records: []dns.RR{&dns.A{
|
||||
Hdr: dns.RR_Header{Name: q.Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
|
||||
A: net.ParseIP("10.0.0.1").To4(),
|
||||
}},
|
||||
cachedAt: time.Now().Add(-2 * defaultTTL), // stale
|
||||
}
|
||||
|
||||
// First query: serves stale immediately.
|
||||
resp := queryA(t, r, "mgmt.example.com.")
|
||||
assert.Equal(t, "10.0.0.1", firstA(t, resp), "stale entry must be served while refresh runs")
|
||||
|
||||
waitFor(t, time.Second, func() bool {
|
||||
return chain.callCount("mgmt.example.com.", dns.TypeA) >= 1
|
||||
})
|
||||
|
||||
// Next query should now return the refreshed IP.
|
||||
waitFor(t, time.Second, func() bool {
|
||||
resp := queryA(t, r, "mgmt.example.com.")
|
||||
return resp != nil && len(resp.Answer) > 0 && firstA(t, resp) == "10.0.0.2"
|
||||
})
|
||||
}
|
||||
|
||||
func TestResolver_ConcurrentStaleHitsCollapseRefresh(t *testing.T) {
|
||||
r := NewResolver()
|
||||
chain := newFakeChain()
|
||||
chain.setAnswer("mgmt.example.com.", dns.TypeA, "10.0.0.2")
|
||||
|
||||
var inflight atomic.Int32
|
||||
var maxInflight atomic.Int32
|
||||
chain.onLookup = func() {
|
||||
cur := inflight.Add(1)
|
||||
defer inflight.Add(-1)
|
||||
for {
|
||||
prev := maxInflight.Load()
|
||||
if cur <= prev || maxInflight.CompareAndSwap(prev, cur) {
|
||||
break
|
||||
}
|
||||
}
|
||||
time.Sleep(50 * time.Millisecond) // hold inflight long enough to collide
|
||||
}
|
||||
|
||||
r.SetChainResolver(chain, 50)
|
||||
|
||||
q := dns.Question{Name: "mgmt.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}
|
||||
r.records[q] = &cachedRecord{
|
||||
records: []dns.RR{&dns.A{
|
||||
Hdr: dns.RR_Header{Name: q.Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
|
||||
A: net.ParseIP("10.0.0.1").To4(),
|
||||
}},
|
||||
cachedAt: time.Now().Add(-2 * defaultTTL),
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
for i := 0; i < 50; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
queryA(t, r, "mgmt.example.com.")
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
waitFor(t, 2*time.Second, func() bool {
|
||||
return inflight.Load() == 0
|
||||
})
|
||||
|
||||
calls := chain.callCount("mgmt.example.com.", dns.TypeA)
|
||||
assert.LessOrEqual(t, calls, 2, "singleflight must collapse concurrent refreshes (got %d)", calls)
|
||||
assert.Equal(t, int32(1), maxInflight.Load(), "only one refresh should run concurrently")
|
||||
}
|
||||
|
||||
func TestResolver_RefreshFailureArmsBackoff(t *testing.T) {
|
||||
r := NewResolver()
|
||||
chain := newFakeChain()
|
||||
chain.err = errors.New("boom")
|
||||
r.SetChainResolver(chain, 50)
|
||||
|
||||
q := dns.Question{Name: "mgmt.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}
|
||||
r.records[q] = &cachedRecord{
|
||||
records: []dns.RR{&dns.A{
|
||||
Hdr: dns.RR_Header{Name: q.Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
|
||||
A: net.ParseIP("10.0.0.1").To4(),
|
||||
}},
|
||||
cachedAt: time.Now().Add(-2 * defaultTTL),
|
||||
}
|
||||
|
||||
// First stale hit triggers a refresh attempt that fails.
|
||||
resp := queryA(t, r, "mgmt.example.com.")
|
||||
assert.Equal(t, "10.0.0.1", firstA(t, resp), "stale entry served while refresh fails")
|
||||
|
||||
waitFor(t, time.Second, func() bool {
|
||||
return chain.callCount("mgmt.example.com.", dns.TypeA) == 1
|
||||
})
|
||||
waitFor(t, time.Second, func() bool {
|
||||
r.mutex.RLock()
|
||||
defer r.mutex.RUnlock()
|
||||
c, ok := r.records[q]
|
||||
return ok && !c.lastFailedRefresh.IsZero()
|
||||
})
|
||||
|
||||
// Subsequent stale hits within backoff window should not schedule more refreshes.
|
||||
for i := 0; i < 10; i++ {
|
||||
queryA(t, r, "mgmt.example.com.")
|
||||
}
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
assert.Equal(t, 1, chain.callCount("mgmt.example.com.", dns.TypeA), "backoff must suppress further refreshes")
|
||||
}
|
||||
|
||||
func TestResolver_NoRootHandler_SkipsChain(t *testing.T) {
|
||||
r := NewResolver()
|
||||
chain := newFakeChain()
|
||||
chain.hasRoot = false
|
||||
chain.setAnswer("mgmt.example.com.", dns.TypeA, "10.0.0.2")
|
||||
r.SetChainResolver(chain, 50)
|
||||
|
||||
// With hasRoot=false the chain must not be consulted. Use a short
|
||||
// deadline so the OS fallback returns quickly without waiting on a
|
||||
// real network call in CI.
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
|
||||
defer cancel()
|
||||
_, _, _, _ = r.lookupBoth(ctx, domain.Domain("mgmt.example.com"), "mgmt.example.com.")
|
||||
|
||||
assert.Equal(t, 0, chain.callCount("mgmt.example.com.", dns.TypeA),
|
||||
"chain must not be used when no root handler is registered at the bound priority")
|
||||
}
|
||||
|
||||
func TestResolver_ServeDuringRefreshSetsLoopFlag(t *testing.T) {
|
||||
// ServeDNS being invoked for a question while a refresh for that question
|
||||
// is inflight indicates a resolver loop (OS resolver sent the recursive
|
||||
// query back to us). The inflightRefresh.loopLoggedOnce flag must be set.
|
||||
r := NewResolver()
|
||||
|
||||
q := dns.Question{Name: "mgmt.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}
|
||||
r.records[q] = &cachedRecord{
|
||||
records: []dns.RR{&dns.A{
|
||||
Hdr: dns.RR_Header{Name: q.Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
|
||||
A: net.ParseIP("10.0.0.1").To4(),
|
||||
}},
|
||||
cachedAt: time.Now(),
|
||||
}
|
||||
|
||||
// Simulate an inflight refresh.
|
||||
r.markRefreshing(q)
|
||||
defer r.clearRefreshing(q)
|
||||
|
||||
resp := queryA(t, r, "mgmt.example.com.")
|
||||
assert.Equal(t, "10.0.0.1", firstA(t, resp), "stale entry must still be served to avoid breaking external queries")
|
||||
|
||||
r.mutex.RLock()
|
||||
inflight := r.refreshing[q]
|
||||
r.mutex.RUnlock()
|
||||
require.NotNil(t, inflight)
|
||||
assert.True(t, inflight.Load(), "loop flag must be set once a ServeDNS during refresh was observed")
|
||||
}
|
||||
|
||||
func TestResolver_LoopFlagOnlyTrippedOncePerRefresh(t *testing.T) {
|
||||
r := NewResolver()
|
||||
|
||||
q := dns.Question{Name: "mgmt.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}
|
||||
r.records[q] = &cachedRecord{
|
||||
records: []dns.RR{&dns.A{
|
||||
Hdr: dns.RR_Header{Name: q.Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
|
||||
A: net.ParseIP("10.0.0.1").To4(),
|
||||
}},
|
||||
cachedAt: time.Now(),
|
||||
}
|
||||
|
||||
r.markRefreshing(q)
|
||||
defer r.clearRefreshing(q)
|
||||
|
||||
// Multiple ServeDNS calls during the same refresh must not re-set the flag
|
||||
// (CompareAndSwap from false -> true returns true only on the first call).
|
||||
for range 5 {
|
||||
queryA(t, r, "mgmt.example.com.")
|
||||
}
|
||||
|
||||
r.mutex.RLock()
|
||||
inflight := r.refreshing[q]
|
||||
r.mutex.RUnlock()
|
||||
assert.True(t, inflight.Load())
|
||||
}
|
||||
|
||||
func TestResolver_NoLoopFlagWhenNotRefreshing(t *testing.T) {
|
||||
r := NewResolver()
|
||||
|
||||
q := dns.Question{Name: "mgmt.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}
|
||||
r.records[q] = &cachedRecord{
|
||||
records: []dns.RR{&dns.A{
|
||||
Hdr: dns.RR_Header{Name: q.Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
|
||||
A: net.ParseIP("10.0.0.1").To4(),
|
||||
}},
|
||||
cachedAt: time.Now(),
|
||||
}
|
||||
|
||||
queryA(t, r, "mgmt.example.com.")
|
||||
|
||||
r.mutex.RLock()
|
||||
_, ok := r.refreshing[q]
|
||||
r.mutex.RUnlock()
|
||||
assert.False(t, ok, "no refresh inflight means no loop tracking")
|
||||
}
|
||||
|
||||
func TestResolver_AddDomain_UsesChainWhenRootRegistered(t *testing.T) {
|
||||
r := NewResolver()
|
||||
chain := newFakeChain()
|
||||
chain.setAnswer("mgmt.example.com.", dns.TypeA, "10.0.0.2")
|
||||
chain.setAnswer("mgmt.example.com.", dns.TypeAAAA, "fd00::2")
|
||||
r.SetChainResolver(chain, 50)
|
||||
|
||||
require.NoError(t, r.AddDomain(context.Background(), domain.Domain("mgmt.example.com")))
|
||||
|
||||
resp := queryA(t, r, "mgmt.example.com.")
|
||||
assert.Equal(t, "10.0.0.2", firstA(t, resp))
|
||||
assert.Equal(t, 1, chain.callCount("mgmt.example.com.", dns.TypeA))
|
||||
assert.Equal(t, 1, chain.callCount("mgmt.example.com.", dns.TypeAAAA))
|
||||
}
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"net/url"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
"github.com/stretchr/testify/assert"
|
||||
@@ -23,6 +24,60 @@ func TestResolver_NewResolver(t *testing.T) {
|
||||
assert.False(t, resolver.MatchSubdomains())
|
||||
}
|
||||
|
||||
func TestResolveCacheTTL(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
value string
|
||||
want time.Duration
|
||||
}{
|
||||
{"unset falls back to default", "", defaultTTL},
|
||||
{"valid duration", "45s", 45 * time.Second},
|
||||
{"valid minutes", "2m", 2 * time.Minute},
|
||||
{"malformed falls back to default", "not-a-duration", defaultTTL},
|
||||
{"zero falls back to default", "0s", defaultTTL},
|
||||
{"negative falls back to default", "-5s", defaultTTL},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Setenv(envMgmtCacheTTL, tc.value)
|
||||
got := resolveCacheTTL()
|
||||
assert.Equal(t, tc.want, got, "parsed TTL should match")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewResolver_CacheTTLFromEnv(t *testing.T) {
|
||||
t.Setenv(envMgmtCacheTTL, "7s")
|
||||
r := NewResolver()
|
||||
assert.Equal(t, 7*time.Second, r.cacheTTL, "NewResolver should evaluate cacheTTL once from env")
|
||||
}
|
||||
|
||||
func TestResolver_ResponseTTL(t *testing.T) {
|
||||
now := time.Now()
|
||||
tests := []struct {
|
||||
name string
|
||||
cacheTTL time.Duration
|
||||
cachedAt time.Time
|
||||
wantMin uint32
|
||||
wantMax uint32
|
||||
}{
|
||||
{"fresh entry returns full TTL", 60 * time.Second, now, 59, 60},
|
||||
{"half-aged entry returns half TTL", 60 * time.Second, now.Add(-30 * time.Second), 29, 31},
|
||||
{"expired entry returns zero", 60 * time.Second, now.Add(-61 * time.Second), 0, 0},
|
||||
{"exactly expired returns zero", 10 * time.Second, now.Add(-10 * time.Second), 0, 0},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
r := &Resolver{cacheTTL: tc.cacheTTL}
|
||||
got := r.responseTTL(tc.cachedAt)
|
||||
assert.GreaterOrEqual(t, got, tc.wantMin, "remaining TTL should be >= wantMin")
|
||||
assert.LessOrEqual(t, got, tc.wantMax, "remaining TTL should be <= wantMax")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolver_ExtractDomainFromURL(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
|
||||
@@ -212,6 +212,7 @@ func newDefaultServer(
|
||||
ctx, stop := context.WithCancel(ctx)
|
||||
|
||||
mgmtCacheResolver := mgmt.NewResolver()
|
||||
mgmtCacheResolver.SetChainResolver(handlerChain, PriorityUpstream)
|
||||
|
||||
defaultServer := &DefaultServer{
|
||||
ctx: ctx,
|
||||
|
||||
@@ -26,6 +26,7 @@ import (
|
||||
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
"github.com/netbirdio/netbird/client/firewall"
|
||||
"github.com/netbirdio/netbird/client/firewall/firewalld"
|
||||
firewallManager "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
@@ -570,7 +571,7 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
|
||||
e.connMgr.Start(e.ctx)
|
||||
|
||||
e.srWatcher = guard.NewSRWatcher(e.signal, e.relayManager, e.mobileDep.IFaceDiscover, iceCfg)
|
||||
e.srWatcher.Start()
|
||||
e.srWatcher.Start(peer.IsForceRelayed())
|
||||
|
||||
e.receiveSignalEvents()
|
||||
e.receiveManagementEvents()
|
||||
@@ -604,6 +605,8 @@ func (e *Engine) createFirewall() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
firewalld.SetParentContext(e.ctx)
|
||||
|
||||
var err error
|
||||
e.firewall, err = firewall.NewFirewall(e.wgInterface, e.stateManager, e.flowManager.GetLogger(), e.config.DisableServerRoutes, e.config.MTU)
|
||||
if err != nil {
|
||||
|
||||
@@ -185,17 +185,20 @@ func (conn *Conn) Open(engineCtx context.Context) error {
|
||||
|
||||
conn.workerRelay = NewWorkerRelay(conn.ctx, conn.Log, isController(conn.config), conn.config, conn, conn.relayManager)
|
||||
|
||||
relayIsSupportedLocally := conn.workerRelay.RelayIsSupportedLocally()
|
||||
workerICE, err := NewWorkerICE(conn.ctx, conn.Log, conn.config, conn, conn.signaler, conn.iFaceDiscover, conn.statusRecorder, relayIsSupportedLocally)
|
||||
if err != nil {
|
||||
return err
|
||||
forceRelay := IsForceRelayed()
|
||||
if !forceRelay {
|
||||
relayIsSupportedLocally := conn.workerRelay.RelayIsSupportedLocally()
|
||||
workerICE, err := NewWorkerICE(conn.ctx, conn.Log, conn.config, conn, conn.signaler, conn.iFaceDiscover, conn.statusRecorder, relayIsSupportedLocally)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
conn.workerICE = workerICE
|
||||
}
|
||||
conn.workerICE = workerICE
|
||||
|
||||
conn.handshaker = NewHandshaker(conn.Log, conn.config, conn.signaler, conn.workerICE, conn.workerRelay, conn.metricsStages)
|
||||
|
||||
conn.handshaker.AddRelayListener(conn.workerRelay.OnNewOffer)
|
||||
if !isForceRelayed() {
|
||||
if !forceRelay {
|
||||
conn.handshaker.AddICEListener(conn.workerICE.OnNewOffer)
|
||||
}
|
||||
|
||||
@@ -251,7 +254,9 @@ func (conn *Conn) Close(signalToRemote bool) {
|
||||
conn.wgWatcherCancel()
|
||||
}
|
||||
conn.workerRelay.CloseConn()
|
||||
conn.workerICE.Close()
|
||||
if conn.workerICE != nil {
|
||||
conn.workerICE.Close()
|
||||
}
|
||||
|
||||
if conn.wgProxyRelay != nil {
|
||||
err := conn.wgProxyRelay.CloseConn()
|
||||
@@ -294,7 +299,9 @@ func (conn *Conn) OnRemoteAnswer(answer OfferAnswer) {
|
||||
// OnRemoteCandidate Handles ICE connection Candidate provided by the remote peer.
|
||||
func (conn *Conn) OnRemoteCandidate(candidate ice.Candidate, haRoutes route.HAMap) {
|
||||
conn.dumpState.RemoteCandidate()
|
||||
conn.workerICE.OnRemoteCandidate(candidate, haRoutes)
|
||||
if conn.workerICE != nil {
|
||||
conn.workerICE.OnRemoteCandidate(candidate, haRoutes)
|
||||
}
|
||||
}
|
||||
|
||||
// SetOnConnected sets a handler function to be triggered by Conn when a new connection to a remote peer established
|
||||
@@ -712,33 +719,35 @@ func (conn *Conn) evalStatus() ConnStatus {
|
||||
return StatusConnecting
|
||||
}
|
||||
|
||||
func (conn *Conn) isConnectedOnAllWay() (connected bool) {
|
||||
// would be better to protect this with a mutex, but it could cause deadlock with Close function
|
||||
|
||||
// isConnectedOnAllWay evaluates the overall connection status based on ICE and Relay transports.
|
||||
//
|
||||
// The result is a tri-state:
|
||||
// - ConnStatusConnected: all available transports are up
|
||||
// - ConnStatusPartiallyConnected: relay is up but ICE is still pending/reconnecting
|
||||
// - ConnStatusDisconnected: no working transport
|
||||
func (conn *Conn) isConnectedOnAllWay() (status guard.ConnStatus) {
|
||||
defer func() {
|
||||
if !connected {
|
||||
if status == guard.ConnStatusDisconnected {
|
||||
conn.logTraceConnState()
|
||||
}
|
||||
}()
|
||||
|
||||
// For JS platform: only relay connection is supported
|
||||
if runtime.GOOS == "js" {
|
||||
return conn.statusRelay.Get() == worker.StatusConnected
|
||||
iceWorkerCreated := conn.workerICE != nil
|
||||
|
||||
var iceInProgress bool
|
||||
if iceWorkerCreated {
|
||||
iceInProgress = conn.workerICE.InProgress()
|
||||
}
|
||||
|
||||
// For non-JS platforms: check ICE connection status
|
||||
if conn.statusICE.Get() == worker.StatusDisconnected && !conn.workerICE.InProgress() {
|
||||
return false
|
||||
}
|
||||
|
||||
// If relay is supported with peer, it must also be connected
|
||||
if conn.workerRelay.IsRelayConnectionSupportedWithPeer() {
|
||||
if conn.statusRelay.Get() == worker.StatusDisconnected {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
return evalConnStatus(connStatusInputs{
|
||||
forceRelay: IsForceRelayed(),
|
||||
peerUsesRelay: conn.workerRelay.IsRelayConnectionSupportedWithPeer(),
|
||||
relayConnected: conn.statusRelay.Get() == worker.StatusConnected,
|
||||
remoteSupportsICE: conn.handshaker.RemoteICESupported(),
|
||||
iceWorkerCreated: iceWorkerCreated,
|
||||
iceStatusConnecting: conn.statusICE.Get() != worker.StatusDisconnected,
|
||||
iceInProgress: iceInProgress,
|
||||
})
|
||||
}
|
||||
|
||||
func (conn *Conn) enableWgWatcherIfNeeded(enabledTime time.Time) {
|
||||
@@ -926,3 +935,43 @@ func isController(config ConnConfig) bool {
|
||||
func isRosenpassEnabled(remoteRosenpassPubKey []byte) bool {
|
||||
return remoteRosenpassPubKey != nil
|
||||
}
|
||||
|
||||
func evalConnStatus(in connStatusInputs) guard.ConnStatus {
|
||||
// "Relay up and needed" — the peer uses relay and the transport is connected.
|
||||
relayUsedAndUp := in.peerUsesRelay && in.relayConnected
|
||||
|
||||
// Force-relay mode: ICE never runs. Relay is the only transport and must be up.
|
||||
if in.forceRelay {
|
||||
return boolToConnStatus(relayUsedAndUp)
|
||||
}
|
||||
|
||||
// Remote peer doesn't support ICE, or we haven't created the worker yet:
|
||||
// relay is the only possible transport.
|
||||
if !in.remoteSupportsICE || !in.iceWorkerCreated {
|
||||
return boolToConnStatus(relayUsedAndUp)
|
||||
}
|
||||
|
||||
// ICE counts as "up" when the status is anything other than Disconnected, OR
|
||||
// when a negotiation is currently in progress (so we don't spam offers while one is in flight).
|
||||
iceUp := in.iceStatusConnecting || in.iceInProgress
|
||||
|
||||
// Relay side is acceptable if the peer doesn't rely on relay, or relay is connected.
|
||||
relayOK := !in.peerUsesRelay || in.relayConnected
|
||||
|
||||
switch {
|
||||
case iceUp && relayOK:
|
||||
return guard.ConnStatusConnected
|
||||
case relayUsedAndUp:
|
||||
// Relay is up but ICE is down — partially connected.
|
||||
return guard.ConnStatusPartiallyConnected
|
||||
default:
|
||||
return guard.ConnStatusDisconnected
|
||||
}
|
||||
}
|
||||
|
||||
func boolToConnStatus(connected bool) guard.ConnStatus {
|
||||
if connected {
|
||||
return guard.ConnStatusConnected
|
||||
}
|
||||
return guard.ConnStatusDisconnected
|
||||
}
|
||||
|
||||
@@ -13,6 +13,20 @@ const (
|
||||
StatusConnected
|
||||
)
|
||||
|
||||
// connStatusInputs is the primitive-valued snapshot of the state that drives the
|
||||
// tri-state connection classification. Extracted so the decision logic can be unit-tested
|
||||
// without constructing full Worker/Handshaker objects.
|
||||
type connStatusInputs struct {
|
||||
forceRelay bool // NB_FORCE_RELAY or JS/WASM
|
||||
peerUsesRelay bool // remote peer advertises relay support AND local has relay
|
||||
relayConnected bool // statusRelay reports Connected (independent of whether peer uses relay)
|
||||
remoteSupportsICE bool // remote peer sent ICE credentials
|
||||
iceWorkerCreated bool // local WorkerICE exists (false in force-relay mode)
|
||||
iceStatusConnecting bool // statusICE is anything other than Disconnected
|
||||
iceInProgress bool // a negotiation is currently in flight
|
||||
}
|
||||
|
||||
|
||||
// ConnStatus describe the status of a peer's connection
|
||||
type ConnStatus int32
|
||||
|
||||
|
||||
201
client/internal/peer/conn_status_eval_test.go
Normal file
201
client/internal/peer/conn_status_eval_test.go
Normal file
@@ -0,0 +1,201 @@
|
||||
package peer
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/peer/guard"
|
||||
)
|
||||
|
||||
func TestEvalConnStatus_ForceRelay(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
in connStatusInputs
|
||||
want guard.ConnStatus
|
||||
}{
|
||||
{
|
||||
name: "force relay, peer uses relay, relay up",
|
||||
in: connStatusInputs{
|
||||
forceRelay: true,
|
||||
peerUsesRelay: true,
|
||||
relayConnected: true,
|
||||
},
|
||||
want: guard.ConnStatusConnected,
|
||||
},
|
||||
{
|
||||
name: "force relay, peer uses relay, relay down",
|
||||
in: connStatusInputs{
|
||||
forceRelay: true,
|
||||
peerUsesRelay: true,
|
||||
relayConnected: false,
|
||||
},
|
||||
want: guard.ConnStatusDisconnected,
|
||||
},
|
||||
{
|
||||
name: "force relay, peer does NOT use relay - disconnected forever",
|
||||
in: connStatusInputs{
|
||||
forceRelay: true,
|
||||
peerUsesRelay: false,
|
||||
relayConnected: true,
|
||||
},
|
||||
want: guard.ConnStatusDisconnected,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
if got := evalConnStatus(tc.in); got != tc.want {
|
||||
t.Fatalf("evalConnStatus = %v, want %v", got, tc.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestEvalConnStatus_ICEUnavailable(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
in connStatusInputs
|
||||
want guard.ConnStatus
|
||||
}{
|
||||
{
|
||||
name: "remote does not support ICE, peer uses relay, relay up",
|
||||
in: connStatusInputs{
|
||||
peerUsesRelay: true,
|
||||
relayConnected: true,
|
||||
remoteSupportsICE: false,
|
||||
iceWorkerCreated: true,
|
||||
},
|
||||
want: guard.ConnStatusConnected,
|
||||
},
|
||||
{
|
||||
name: "remote does not support ICE, peer uses relay, relay down",
|
||||
in: connStatusInputs{
|
||||
peerUsesRelay: true,
|
||||
relayConnected: false,
|
||||
remoteSupportsICE: false,
|
||||
iceWorkerCreated: true,
|
||||
},
|
||||
want: guard.ConnStatusDisconnected,
|
||||
},
|
||||
{
|
||||
name: "ICE worker not yet created, relay up",
|
||||
in: connStatusInputs{
|
||||
peerUsesRelay: true,
|
||||
relayConnected: true,
|
||||
remoteSupportsICE: true,
|
||||
iceWorkerCreated: false,
|
||||
},
|
||||
want: guard.ConnStatusConnected,
|
||||
},
|
||||
{
|
||||
name: "remote does not support ICE, peer does not use relay",
|
||||
in: connStatusInputs{
|
||||
peerUsesRelay: false,
|
||||
relayConnected: false,
|
||||
remoteSupportsICE: false,
|
||||
iceWorkerCreated: true,
|
||||
},
|
||||
want: guard.ConnStatusDisconnected,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
if got := evalConnStatus(tc.in); got != tc.want {
|
||||
t.Fatalf("evalConnStatus = %v, want %v", got, tc.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestEvalConnStatus_FullyAvailable(t *testing.T) {
|
||||
base := connStatusInputs{
|
||||
remoteSupportsICE: true,
|
||||
iceWorkerCreated: true,
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
mutator func(*connStatusInputs)
|
||||
want guard.ConnStatus
|
||||
}{
|
||||
{
|
||||
name: "ICE connected, relay connected, peer uses relay",
|
||||
mutator: func(in *connStatusInputs) {
|
||||
in.peerUsesRelay = true
|
||||
in.relayConnected = true
|
||||
in.iceStatusConnecting = true
|
||||
},
|
||||
want: guard.ConnStatusConnected,
|
||||
},
|
||||
{
|
||||
name: "ICE connected, peer does NOT use relay",
|
||||
mutator: func(in *connStatusInputs) {
|
||||
in.peerUsesRelay = false
|
||||
in.relayConnected = false
|
||||
in.iceStatusConnecting = true
|
||||
},
|
||||
want: guard.ConnStatusConnected,
|
||||
},
|
||||
{
|
||||
name: "ICE InProgress only, peer does NOT use relay",
|
||||
mutator: func(in *connStatusInputs) {
|
||||
in.peerUsesRelay = false
|
||||
in.iceStatusConnecting = false
|
||||
in.iceInProgress = true
|
||||
},
|
||||
want: guard.ConnStatusConnected,
|
||||
},
|
||||
{
|
||||
name: "ICE down, relay up, peer uses relay -> partial",
|
||||
mutator: func(in *connStatusInputs) {
|
||||
in.peerUsesRelay = true
|
||||
in.relayConnected = true
|
||||
in.iceStatusConnecting = false
|
||||
in.iceInProgress = false
|
||||
},
|
||||
want: guard.ConnStatusPartiallyConnected,
|
||||
},
|
||||
{
|
||||
name: "ICE down, peer does NOT use relay -> disconnected",
|
||||
mutator: func(in *connStatusInputs) {
|
||||
in.peerUsesRelay = false
|
||||
in.relayConnected = false
|
||||
in.iceStatusConnecting = false
|
||||
in.iceInProgress = false
|
||||
},
|
||||
want: guard.ConnStatusDisconnected,
|
||||
},
|
||||
{
|
||||
name: "ICE up, peer uses relay but relay down -> partial (relay required, ICE ignored)",
|
||||
mutator: func(in *connStatusInputs) {
|
||||
in.peerUsesRelay = true
|
||||
in.relayConnected = false
|
||||
in.iceStatusConnecting = true
|
||||
},
|
||||
// relayOK = false (peer uses relay but it's down), iceUp = true
|
||||
// first switch arm fails (relayOK false), relayUsedAndUp = false (relay down),
|
||||
// falls into default: Disconnected.
|
||||
want: guard.ConnStatusDisconnected,
|
||||
},
|
||||
{
|
||||
name: "ICE down, relay up but peer does not use relay -> disconnected",
|
||||
mutator: func(in *connStatusInputs) {
|
||||
in.peerUsesRelay = false
|
||||
in.relayConnected = true // not actually used since peer doesn't rely on it
|
||||
in.iceStatusConnecting = false
|
||||
in.iceInProgress = false
|
||||
},
|
||||
want: guard.ConnStatusDisconnected,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
in := base
|
||||
tc.mutator(&in)
|
||||
if got := evalConnStatus(in); got != tc.want {
|
||||
t.Fatalf("evalConnStatus = %v, want %v (inputs: %+v)", got, tc.want, in)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -10,7 +10,7 @@ const (
|
||||
EnvKeyNBForceRelay = "NB_FORCE_RELAY"
|
||||
)
|
||||
|
||||
func isForceRelayed() bool {
|
||||
func IsForceRelayed() bool {
|
||||
if runtime.GOOS == "js" {
|
||||
return true
|
||||
}
|
||||
|
||||
@@ -8,7 +8,19 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
type isConnectedFunc func() bool
|
||||
// ConnStatus represents the connection state as seen by the guard.
|
||||
type ConnStatus int
|
||||
|
||||
const (
|
||||
// ConnStatusDisconnected means neither ICE nor Relay is connected.
|
||||
ConnStatusDisconnected ConnStatus = iota
|
||||
// ConnStatusPartiallyConnected means Relay is connected but ICE is not.
|
||||
ConnStatusPartiallyConnected
|
||||
// ConnStatusConnected means all required connections are established.
|
||||
ConnStatusConnected
|
||||
)
|
||||
|
||||
type connStatusFunc func() ConnStatus
|
||||
|
||||
// Guard is responsible for the reconnection logic.
|
||||
// It will trigger to send an offer to the peer then has connection issues.
|
||||
@@ -20,14 +32,14 @@ type isConnectedFunc func() bool
|
||||
// - ICE candidate changes
|
||||
type Guard struct {
|
||||
log *log.Entry
|
||||
isConnectedOnAllWay isConnectedFunc
|
||||
isConnectedOnAllWay connStatusFunc
|
||||
timeout time.Duration
|
||||
srWatcher *SRWatcher
|
||||
relayedConnDisconnected chan struct{}
|
||||
iCEConnDisconnected chan struct{}
|
||||
}
|
||||
|
||||
func NewGuard(log *log.Entry, isConnectedFn isConnectedFunc, timeout time.Duration, srWatcher *SRWatcher) *Guard {
|
||||
func NewGuard(log *log.Entry, isConnectedFn connStatusFunc, timeout time.Duration, srWatcher *SRWatcher) *Guard {
|
||||
return &Guard{
|
||||
log: log,
|
||||
isConnectedOnAllWay: isConnectedFn,
|
||||
@@ -57,8 +69,17 @@ func (g *Guard) SetICEConnDisconnected() {
|
||||
}
|
||||
}
|
||||
|
||||
// reconnectLoopWithRetry periodically check the connection status.
|
||||
// Try to send offer while the P2P is not established or while the Relay is not connected if is it supported
|
||||
// reconnectLoopWithRetry periodically checks the connection status and sends offers to re-establish connectivity.
|
||||
//
|
||||
// Behavior depends on the connection state reported by isConnectedOnAllWay:
|
||||
// - Connected: no action, the peer is fully reachable.
|
||||
// - Disconnected (neither ICE nor Relay): retries aggressively with exponential backoff (800ms doubling
|
||||
// up to timeout), never gives up. This ensures rapid recovery when the peer has no connectivity at all.
|
||||
// - PartiallyConnected (Relay up, ICE not): retries up to 3 times with exponential backoff, then switches
|
||||
// to one attempt per hour. This limits signaling traffic when relay already provides connectivity.
|
||||
//
|
||||
// External events (relay/ICE disconnect, signal/relay reconnect, candidate changes) reset the retry
|
||||
// counter and backoff ticker, giving ICE a fresh chance after network conditions change.
|
||||
func (g *Guard) reconnectLoopWithRetry(ctx context.Context, callback func()) {
|
||||
srReconnectedChan := g.srWatcher.NewListener()
|
||||
defer g.srWatcher.RemoveListener(srReconnectedChan)
|
||||
@@ -68,36 +89,47 @@ func (g *Guard) reconnectLoopWithRetry(ctx context.Context, callback func()) {
|
||||
|
||||
tickerChannel := ticker.C
|
||||
|
||||
iceState := &iceRetryState{log: g.log}
|
||||
defer iceState.reset()
|
||||
|
||||
for {
|
||||
select {
|
||||
case t := <-tickerChannel:
|
||||
if t.IsZero() {
|
||||
g.log.Infof("retry timed out, stop periodic offer sending")
|
||||
// after backoff timeout the ticker.C will be closed. We need to a dummy channel to avoid loop
|
||||
tickerChannel = make(<-chan time.Time)
|
||||
continue
|
||||
case <-tickerChannel:
|
||||
switch g.isConnectedOnAllWay() {
|
||||
case ConnStatusConnected:
|
||||
// all good, nothing to do
|
||||
case ConnStatusDisconnected:
|
||||
callback()
|
||||
case ConnStatusPartiallyConnected:
|
||||
if iceState.shouldRetry() {
|
||||
callback()
|
||||
} else {
|
||||
iceState.enterHourlyMode()
|
||||
ticker.Stop()
|
||||
tickerChannel = iceState.hourlyC()
|
||||
}
|
||||
}
|
||||
|
||||
if !g.isConnectedOnAllWay() {
|
||||
callback()
|
||||
}
|
||||
case <-g.relayedConnDisconnected:
|
||||
g.log.Debugf("Relay connection changed, reset reconnection ticker")
|
||||
ticker.Stop()
|
||||
ticker = g.prepareExponentTicker(ctx)
|
||||
ticker = g.newReconnectTicker(ctx)
|
||||
tickerChannel = ticker.C
|
||||
iceState.reset()
|
||||
|
||||
case <-g.iCEConnDisconnected:
|
||||
g.log.Debugf("ICE connection changed, reset reconnection ticker")
|
||||
ticker.Stop()
|
||||
ticker = g.prepareExponentTicker(ctx)
|
||||
ticker = g.newReconnectTicker(ctx)
|
||||
tickerChannel = ticker.C
|
||||
iceState.reset()
|
||||
|
||||
case <-srReconnectedChan:
|
||||
g.log.Debugf("has network changes, reset reconnection ticker")
|
||||
ticker.Stop()
|
||||
ticker = g.prepareExponentTicker(ctx)
|
||||
ticker = g.newReconnectTicker(ctx)
|
||||
tickerChannel = ticker.C
|
||||
iceState.reset()
|
||||
|
||||
case <-ctx.Done():
|
||||
g.log.Debugf("context is done, stop reconnect loop")
|
||||
@@ -120,7 +152,7 @@ func (g *Guard) initialTicker(ctx context.Context) *backoff.Ticker {
|
||||
return backoff.NewTicker(bo)
|
||||
}
|
||||
|
||||
func (g *Guard) prepareExponentTicker(ctx context.Context) *backoff.Ticker {
|
||||
func (g *Guard) newReconnectTicker(ctx context.Context) *backoff.Ticker {
|
||||
bo := backoff.WithContext(&backoff.ExponentialBackOff{
|
||||
InitialInterval: 800 * time.Millisecond,
|
||||
RandomizationFactor: 0.1,
|
||||
|
||||
61
client/internal/peer/guard/ice_retry_state.go
Normal file
61
client/internal/peer/guard/ice_retry_state.go
Normal file
@@ -0,0 +1,61 @@
|
||||
package guard
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
// maxICERetries is the maximum number of ICE offer attempts when relay is connected
|
||||
maxICERetries = 3
|
||||
// iceRetryInterval is the periodic retry interval after ICE retries are exhausted
|
||||
iceRetryInterval = 1 * time.Hour
|
||||
)
|
||||
|
||||
// iceRetryState tracks the limited ICE retry attempts when relay is already connected.
|
||||
// After maxICERetries attempts it switches to a periodic hourly retry.
|
||||
type iceRetryState struct {
|
||||
log *log.Entry
|
||||
retries int
|
||||
hourly *time.Ticker
|
||||
}
|
||||
|
||||
func (s *iceRetryState) reset() {
|
||||
s.retries = 0
|
||||
if s.hourly != nil {
|
||||
s.hourly.Stop()
|
||||
s.hourly = nil
|
||||
}
|
||||
}
|
||||
|
||||
// shouldRetry reports whether the caller should send another ICE offer on this tick.
|
||||
// Returns false when the per-cycle retry budget is exhausted and the caller must switch
|
||||
// to the hourly ticker via enterHourlyMode + hourlyC.
|
||||
func (s *iceRetryState) shouldRetry() bool {
|
||||
if s.hourly != nil {
|
||||
s.log.Debugf("hourly ICE retry attempt")
|
||||
return true
|
||||
}
|
||||
|
||||
s.retries++
|
||||
if s.retries <= maxICERetries {
|
||||
s.log.Debugf("ICE retry attempt %d/%d", s.retries, maxICERetries)
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// enterHourlyMode starts the hourly retry ticker. Must be called after shouldRetry returns false.
|
||||
func (s *iceRetryState) enterHourlyMode() {
|
||||
s.log.Infof("ICE retries exhausted (%d/%d), switching to hourly retry", maxICERetries, maxICERetries)
|
||||
s.hourly = time.NewTicker(iceRetryInterval)
|
||||
}
|
||||
|
||||
func (s *iceRetryState) hourlyC() <-chan time.Time {
|
||||
if s.hourly == nil {
|
||||
return nil
|
||||
}
|
||||
return s.hourly.C
|
||||
}
|
||||
103
client/internal/peer/guard/ice_retry_state_test.go
Normal file
103
client/internal/peer/guard/ice_retry_state_test.go
Normal file
@@ -0,0 +1,103 @@
|
||||
package guard
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
func newTestRetryState() *iceRetryState {
|
||||
return &iceRetryState{log: log.NewEntry(log.StandardLogger())}
|
||||
}
|
||||
|
||||
func TestICERetryState_AllowsInitialBudget(t *testing.T) {
|
||||
s := newTestRetryState()
|
||||
|
||||
for i := 1; i <= maxICERetries; i++ {
|
||||
if !s.shouldRetry() {
|
||||
t.Fatalf("shouldRetry returned false on attempt %d, want true (budget = %d)", i, maxICERetries)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestICERetryState_ExhaustsAfterBudget(t *testing.T) {
|
||||
s := newTestRetryState()
|
||||
|
||||
for i := 0; i < maxICERetries; i++ {
|
||||
_ = s.shouldRetry()
|
||||
}
|
||||
|
||||
if s.shouldRetry() {
|
||||
t.Fatalf("shouldRetry returned true after budget exhausted, want false")
|
||||
}
|
||||
}
|
||||
|
||||
func TestICERetryState_HourlyCNilBeforeEnterHourlyMode(t *testing.T) {
|
||||
s := newTestRetryState()
|
||||
|
||||
if s.hourlyC() != nil {
|
||||
t.Fatalf("hourlyC returned non-nil channel before enterHourlyMode")
|
||||
}
|
||||
}
|
||||
|
||||
func TestICERetryState_EnterHourlyModeArmsTicker(t *testing.T) {
|
||||
s := newTestRetryState()
|
||||
for i := 0; i < maxICERetries+1; i++ {
|
||||
_ = s.shouldRetry()
|
||||
}
|
||||
|
||||
s.enterHourlyMode()
|
||||
defer s.reset()
|
||||
|
||||
if s.hourlyC() == nil {
|
||||
t.Fatalf("hourlyC returned nil after enterHourlyMode")
|
||||
}
|
||||
}
|
||||
|
||||
func TestICERetryState_ShouldRetryTrueInHourlyMode(t *testing.T) {
|
||||
s := newTestRetryState()
|
||||
s.enterHourlyMode()
|
||||
defer s.reset()
|
||||
|
||||
if !s.shouldRetry() {
|
||||
t.Fatalf("shouldRetry returned false in hourly mode, want true")
|
||||
}
|
||||
|
||||
// Subsequent calls also return true — we keep retrying on each hourly tick.
|
||||
if !s.shouldRetry() {
|
||||
t.Fatalf("second shouldRetry returned false in hourly mode, want true")
|
||||
}
|
||||
}
|
||||
|
||||
func TestICERetryState_ResetRestoresBudget(t *testing.T) {
|
||||
s := newTestRetryState()
|
||||
for i := 0; i < maxICERetries+1; i++ {
|
||||
_ = s.shouldRetry()
|
||||
}
|
||||
s.enterHourlyMode()
|
||||
|
||||
s.reset()
|
||||
|
||||
if s.hourlyC() != nil {
|
||||
t.Fatalf("hourlyC returned non-nil channel after reset")
|
||||
}
|
||||
if s.retries != 0 {
|
||||
t.Fatalf("retries = %d after reset, want 0", s.retries)
|
||||
}
|
||||
|
||||
for i := 1; i <= maxICERetries; i++ {
|
||||
if !s.shouldRetry() {
|
||||
t.Fatalf("shouldRetry returned false on attempt %d after reset, want true", i)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestICERetryState_ResetIsIdempotent(t *testing.T) {
|
||||
s := newTestRetryState()
|
||||
s.reset()
|
||||
s.reset() // second call must not panic or re-stop a nil ticker
|
||||
|
||||
if s.hourlyC() != nil {
|
||||
t.Fatalf("hourlyC non-nil after double reset")
|
||||
}
|
||||
}
|
||||
@@ -39,7 +39,7 @@ func NewSRWatcher(signalClient chNotifier, relayManager chNotifier, iFaceDiscove
|
||||
return srw
|
||||
}
|
||||
|
||||
func (w *SRWatcher) Start() {
|
||||
func (w *SRWatcher) Start(disableICEMonitor bool) {
|
||||
w.mu.Lock()
|
||||
defer w.mu.Unlock()
|
||||
|
||||
@@ -50,8 +50,10 @@ func (w *SRWatcher) Start() {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
w.cancelIceMonitor = cancel
|
||||
|
||||
iceMonitor := NewICEMonitor(w.iFaceDiscover, w.iceConfig, GetICEMonitorPeriod())
|
||||
go iceMonitor.Start(ctx, w.onICEChanged)
|
||||
if !disableICEMonitor {
|
||||
iceMonitor := NewICEMonitor(w.iFaceDiscover, w.iceConfig, GetICEMonitorPeriod())
|
||||
go iceMonitor.Start(ctx, w.onICEChanged)
|
||||
}
|
||||
w.signalClient.SetOnReconnectedListener(w.onReconnected)
|
||||
w.relayManager.SetOnReconnectedListener(w.onReconnected)
|
||||
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
@@ -43,6 +44,10 @@ type OfferAnswer struct {
|
||||
SessionID *ICESessionID
|
||||
}
|
||||
|
||||
func (o *OfferAnswer) hasICECredentials() bool {
|
||||
return o.IceCredentials.UFrag != "" && o.IceCredentials.Pwd != ""
|
||||
}
|
||||
|
||||
type Handshaker struct {
|
||||
mu sync.Mutex
|
||||
log *log.Entry
|
||||
@@ -59,6 +64,10 @@ type Handshaker struct {
|
||||
relayListener *AsyncOfferListener
|
||||
iceListener func(remoteOfferAnswer *OfferAnswer)
|
||||
|
||||
// remoteICESupported tracks whether the remote peer includes ICE credentials in its offers/answers.
|
||||
// When false, the local side skips ICE listener dispatch and suppresses ICE credentials in responses.
|
||||
remoteICESupported atomic.Bool
|
||||
|
||||
// remoteOffersCh is a channel used to wait for remote credentials to proceed with the connection
|
||||
remoteOffersCh chan OfferAnswer
|
||||
// remoteAnswerCh is a channel used to wait for remote credentials answer (confirmation of our offer) to proceed with the connection
|
||||
@@ -66,7 +75,7 @@ type Handshaker struct {
|
||||
}
|
||||
|
||||
func NewHandshaker(log *log.Entry, config ConnConfig, signaler *Signaler, ice *WorkerICE, relay *WorkerRelay, metricsStages *MetricsStages) *Handshaker {
|
||||
return &Handshaker{
|
||||
h := &Handshaker{
|
||||
log: log,
|
||||
config: config,
|
||||
signaler: signaler,
|
||||
@@ -76,6 +85,13 @@ func NewHandshaker(log *log.Entry, config ConnConfig, signaler *Signaler, ice *W
|
||||
remoteOffersCh: make(chan OfferAnswer),
|
||||
remoteAnswerCh: make(chan OfferAnswer),
|
||||
}
|
||||
// assume remote supports ICE until we learn otherwise from received offers
|
||||
h.remoteICESupported.Store(ice != nil)
|
||||
return h
|
||||
}
|
||||
|
||||
func (h *Handshaker) RemoteICESupported() bool {
|
||||
return h.remoteICESupported.Load()
|
||||
}
|
||||
|
||||
func (h *Handshaker) AddRelayListener(offer func(remoteOfferAnswer *OfferAnswer)) {
|
||||
@@ -90,18 +106,20 @@ func (h *Handshaker) Listen(ctx context.Context) {
|
||||
for {
|
||||
select {
|
||||
case remoteOfferAnswer := <-h.remoteOffersCh:
|
||||
h.log.Infof("received offer, running version %s, remote WireGuard listen port %d, session id: %s", remoteOfferAnswer.Version, remoteOfferAnswer.WgListenPort, remoteOfferAnswer.SessionIDString())
|
||||
h.log.Infof("received offer, running version %s, remote WireGuard listen port %d, session id: %s, remote ICE supported: %t", remoteOfferAnswer.Version, remoteOfferAnswer.WgListenPort, remoteOfferAnswer.SessionIDString(), remoteOfferAnswer.hasICECredentials())
|
||||
|
||||
// Record signaling received for reconnection attempts
|
||||
if h.metricsStages != nil {
|
||||
h.metricsStages.RecordSignalingReceived()
|
||||
}
|
||||
|
||||
h.updateRemoteICEState(&remoteOfferAnswer)
|
||||
|
||||
if h.relayListener != nil {
|
||||
h.relayListener.Notify(&remoteOfferAnswer)
|
||||
}
|
||||
|
||||
if h.iceListener != nil {
|
||||
if h.iceListener != nil && h.RemoteICESupported() {
|
||||
h.iceListener(&remoteOfferAnswer)
|
||||
}
|
||||
|
||||
@@ -110,18 +128,20 @@ func (h *Handshaker) Listen(ctx context.Context) {
|
||||
continue
|
||||
}
|
||||
case remoteOfferAnswer := <-h.remoteAnswerCh:
|
||||
h.log.Infof("received answer, running version %s, remote WireGuard listen port %d, session id: %s", remoteOfferAnswer.Version, remoteOfferAnswer.WgListenPort, remoteOfferAnswer.SessionIDString())
|
||||
h.log.Infof("received answer, running version %s, remote WireGuard listen port %d, session id: %s, remote ICE supported: %t", remoteOfferAnswer.Version, remoteOfferAnswer.WgListenPort, remoteOfferAnswer.SessionIDString(), remoteOfferAnswer.hasICECredentials())
|
||||
|
||||
// Record signaling received for reconnection attempts
|
||||
if h.metricsStages != nil {
|
||||
h.metricsStages.RecordSignalingReceived()
|
||||
}
|
||||
|
||||
h.updateRemoteICEState(&remoteOfferAnswer)
|
||||
|
||||
if h.relayListener != nil {
|
||||
h.relayListener.Notify(&remoteOfferAnswer)
|
||||
}
|
||||
|
||||
if h.iceListener != nil {
|
||||
if h.iceListener != nil && h.RemoteICESupported() {
|
||||
h.iceListener(&remoteOfferAnswer)
|
||||
}
|
||||
case <-ctx.Done():
|
||||
@@ -183,15 +203,18 @@ func (h *Handshaker) sendAnswer() error {
|
||||
}
|
||||
|
||||
func (h *Handshaker) buildOfferAnswer() OfferAnswer {
|
||||
uFrag, pwd := h.ice.GetLocalUserCredentials()
|
||||
sid := h.ice.SessionID()
|
||||
answer := OfferAnswer{
|
||||
IceCredentials: IceCredentials{uFrag, pwd},
|
||||
WgListenPort: h.config.LocalWgPort,
|
||||
Version: version.NetbirdVersion(),
|
||||
RosenpassPubKey: h.config.RosenpassConfig.PubKey,
|
||||
RosenpassAddr: h.config.RosenpassConfig.Addr,
|
||||
SessionID: &sid,
|
||||
}
|
||||
|
||||
if h.ice != nil && h.RemoteICESupported() {
|
||||
uFrag, pwd := h.ice.GetLocalUserCredentials()
|
||||
sid := h.ice.SessionID()
|
||||
answer.IceCredentials = IceCredentials{uFrag, pwd}
|
||||
answer.SessionID = &sid
|
||||
}
|
||||
|
||||
if addr, err := h.relay.RelayInstanceAddress(); err == nil {
|
||||
@@ -200,3 +223,18 @@ func (h *Handshaker) buildOfferAnswer() OfferAnswer {
|
||||
|
||||
return answer
|
||||
}
|
||||
|
||||
func (h *Handshaker) updateRemoteICEState(offer *OfferAnswer) {
|
||||
hasICE := offer.hasICECredentials()
|
||||
prev := h.remoteICESupported.Swap(hasICE)
|
||||
if prev != hasICE {
|
||||
if hasICE {
|
||||
h.log.Infof("remote peer started sending ICE credentials")
|
||||
} else {
|
||||
h.log.Infof("remote peer stopped sending ICE credentials")
|
||||
if h.ice != nil {
|
||||
h.ice.Close()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -46,9 +46,13 @@ func (s *Signaler) Ready() bool {
|
||||
|
||||
// SignalOfferAnswer signals either an offer or an answer to remote peer
|
||||
func (s *Signaler) signalOfferAnswer(offerAnswer OfferAnswer, remoteKey string, bodyType sProto.Body_Type) error {
|
||||
sessionIDBytes, err := offerAnswer.SessionID.Bytes()
|
||||
if err != nil {
|
||||
log.Warnf("failed to get session ID bytes: %v", err)
|
||||
var sessionIDBytes []byte
|
||||
if offerAnswer.SessionID != nil {
|
||||
var err error
|
||||
sessionIDBytes, err = offerAnswer.SessionID.Bytes()
|
||||
if err != nil {
|
||||
log.Warnf("failed to get session ID bytes: %v", err)
|
||||
}
|
||||
}
|
||||
msg, err := signal.MarshalCredential(
|
||||
s.wgPrivateKey,
|
||||
|
||||
@@ -25,7 +25,6 @@ import (
|
||||
nbssh "github.com/netbirdio/netbird/client/ssh"
|
||||
"github.com/netbirdio/netbird/client/ssh/detection"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
"github.com/netbirdio/netbird/util/netrelay"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -537,7 +536,7 @@ func (c *Client) LocalPortForward(ctx context.Context, localAddr, remoteAddr str
|
||||
continue
|
||||
}
|
||||
|
||||
go c.handleLocalForward(ctx, localConn, remoteAddr)
|
||||
go c.handleLocalForward(localConn, remoteAddr)
|
||||
}
|
||||
}()
|
||||
|
||||
@@ -549,7 +548,7 @@ func (c *Client) LocalPortForward(ctx context.Context, localAddr, remoteAddr str
|
||||
}
|
||||
|
||||
// handleLocalForward handles a single local port forwarding connection
|
||||
func (c *Client) handleLocalForward(ctx context.Context, localConn net.Conn, remoteAddr string) {
|
||||
func (c *Client) handleLocalForward(localConn net.Conn, remoteAddr string) {
|
||||
defer func() {
|
||||
if err := localConn.Close(); err != nil {
|
||||
log.Debugf("local port forwarding: close local connection: %v", err)
|
||||
@@ -572,7 +571,7 @@ func (c *Client) handleLocalForward(ctx context.Context, localConn net.Conn, rem
|
||||
}
|
||||
}()
|
||||
|
||||
netrelay.Relay(ctx, localConn, channel, netrelay.Options{Logger: log.NewEntry(log.StandardLogger())})
|
||||
nbssh.BidirectionalCopy(log.NewEntry(log.StandardLogger()), localConn, channel)
|
||||
}
|
||||
|
||||
// RemotePortForward sets up remote port forwarding, binding on remote and forwarding to localAddr
|
||||
@@ -654,19 +653,16 @@ func (c *Client) handleRemoteForwardChannels(ctx context.Context, localAddr stri
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case newChan, ok := <-channelRequests:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
case newChan := <-channelRequests:
|
||||
if newChan != nil {
|
||||
go c.handleRemoteForwardChannel(ctx, newChan, localAddr)
|
||||
go c.handleRemoteForwardChannel(newChan, localAddr)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// handleRemoteForwardChannel handles a single forwarded-tcpip channel
|
||||
func (c *Client) handleRemoteForwardChannel(ctx context.Context, newChan ssh.NewChannel, localAddr string) {
|
||||
func (c *Client) handleRemoteForwardChannel(newChan ssh.NewChannel, localAddr string) {
|
||||
channel, reqs, err := newChan.Accept()
|
||||
if err != nil {
|
||||
return
|
||||
@@ -679,14 +675,8 @@ func (c *Client) handleRemoteForwardChannel(ctx context.Context, newChan ssh.New
|
||||
|
||||
go ssh.DiscardRequests(reqs)
|
||||
|
||||
// Bound the dial so a black-holed localAddr can't pin the accepted SSH
|
||||
// channel open indefinitely; the relay itself runs under the outer ctx.
|
||||
dialCtx, cancelDial := context.WithTimeout(ctx, 10*time.Second)
|
||||
var dialer net.Dialer
|
||||
localConn, err := dialer.DialContext(dialCtx, "tcp", localAddr)
|
||||
cancelDial()
|
||||
localConn, err := net.Dial("tcp", localAddr)
|
||||
if err != nil {
|
||||
log.Debugf("remote port forwarding: dial %s: %v", localAddr, err)
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
@@ -695,7 +685,7 @@ func (c *Client) handleRemoteForwardChannel(ctx context.Context, newChan ssh.New
|
||||
}
|
||||
}()
|
||||
|
||||
netrelay.Relay(ctx, localConn, channel, netrelay.Options{Logger: log.NewEntry(log.StandardLogger())})
|
||||
nbssh.BidirectionalCopy(log.NewEntry(log.StandardLogger()), localConn, channel)
|
||||
}
|
||||
|
||||
// tcpipForwardMsg represents the structure for tcpip-forward requests
|
||||
|
||||
@@ -194,3 +194,63 @@ func buildAddressList(hostname string, remote net.Addr) []string {
|
||||
return addresses
|
||||
}
|
||||
|
||||
// BidirectionalCopy copies data bidirectionally between two io.ReadWriter connections.
|
||||
// It waits for both directions to complete before returning.
|
||||
// The caller is responsible for closing the connections.
|
||||
func BidirectionalCopy(logger *log.Entry, rw1, rw2 io.ReadWriter) {
|
||||
done := make(chan struct{}, 2)
|
||||
|
||||
go func() {
|
||||
if _, err := io.Copy(rw2, rw1); err != nil && !isExpectedCopyError(err) {
|
||||
logger.Debugf("copy error (1->2): %v", err)
|
||||
}
|
||||
done <- struct{}{}
|
||||
}()
|
||||
|
||||
go func() {
|
||||
if _, err := io.Copy(rw1, rw2); err != nil && !isExpectedCopyError(err) {
|
||||
logger.Debugf("copy error (2->1): %v", err)
|
||||
}
|
||||
done <- struct{}{}
|
||||
}()
|
||||
|
||||
<-done
|
||||
<-done
|
||||
}
|
||||
|
||||
func isExpectedCopyError(err error) bool {
|
||||
return errors.Is(err, io.EOF) || errors.Is(err, context.Canceled)
|
||||
}
|
||||
|
||||
// BidirectionalCopyWithContext copies data bidirectionally between two io.ReadWriteCloser connections.
|
||||
// It waits for both directions to complete or for context cancellation before returning.
|
||||
// Both connections are closed when the function returns.
|
||||
func BidirectionalCopyWithContext(logger *log.Entry, ctx context.Context, conn1, conn2 io.ReadWriteCloser) {
|
||||
done := make(chan struct{}, 2)
|
||||
|
||||
go func() {
|
||||
if _, err := io.Copy(conn2, conn1); err != nil && !isExpectedCopyError(err) {
|
||||
logger.Debugf("copy error (1->2): %v", err)
|
||||
}
|
||||
done <- struct{}{}
|
||||
}()
|
||||
|
||||
go func() {
|
||||
if _, err := io.Copy(conn1, conn2); err != nil && !isExpectedCopyError(err) {
|
||||
logger.Debugf("copy error (2->1): %v", err)
|
||||
}
|
||||
done <- struct{}{}
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
case <-done:
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
case <-done:
|
||||
}
|
||||
}
|
||||
|
||||
_ = conn1.Close()
|
||||
_ = conn2.Close()
|
||||
}
|
||||
|
||||
@@ -23,7 +23,6 @@ import (
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
nbssh "github.com/netbirdio/netbird/client/ssh"
|
||||
"github.com/netbirdio/netbird/client/ssh/detection"
|
||||
"github.com/netbirdio/netbird/util/netrelay"
|
||||
"github.com/netbirdio/netbird/version"
|
||||
)
|
||||
|
||||
@@ -353,7 +352,7 @@ func (p *SSHProxy) directTCPIPHandler(_ *ssh.Server, _ *cryptossh.ServerConn, ne
|
||||
}
|
||||
go cryptossh.DiscardRequests(clientReqs)
|
||||
|
||||
netrelay.Relay(sshCtx, clientChan, backendChan, netrelay.Options{Logger: log.NewEntry(log.StandardLogger())})
|
||||
nbssh.BidirectionalCopyWithContext(log.NewEntry(log.StandardLogger()), sshCtx, clientChan, backendChan)
|
||||
}
|
||||
|
||||
func (p *SSHProxy) sftpSubsystemHandler(s ssh.Session, jwtToken string) {
|
||||
@@ -592,7 +591,7 @@ func (p *SSHProxy) handleForwardedChannel(sshCtx ssh.Context, sshConn *cryptossh
|
||||
}
|
||||
go cryptossh.DiscardRequests(clientReqs)
|
||||
|
||||
netrelay.Relay(sshCtx, clientChan, backendChan, netrelay.Options{Logger: log.NewEntry(log.StandardLogger())})
|
||||
nbssh.BidirectionalCopyWithContext(log.NewEntry(log.StandardLogger()), sshCtx, clientChan, backendChan)
|
||||
}
|
||||
|
||||
func (p *SSHProxy) dialBackend(ctx context.Context, addr, user, jwtToken string) (*cryptossh.Client, error) {
|
||||
|
||||
@@ -17,7 +17,7 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
cryptossh "golang.org/x/crypto/ssh"
|
||||
|
||||
"github.com/netbirdio/netbird/util/netrelay"
|
||||
nbssh "github.com/netbirdio/netbird/client/ssh"
|
||||
)
|
||||
|
||||
const privilegedPortThreshold = 1024
|
||||
@@ -356,7 +356,7 @@ func (s *Server) handleRemoteForwardConnection(ctx ssh.Context, conn net.Conn, h
|
||||
return
|
||||
}
|
||||
|
||||
netrelay.Relay(ctx, conn, channel, netrelay.Options{Logger: logger})
|
||||
nbssh.BidirectionalCopyWithContext(logger, ctx, conn, channel)
|
||||
}
|
||||
|
||||
// openForwardChannel creates an SSH forwarded-tcpip channel
|
||||
|
||||
@@ -10,7 +10,6 @@ import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
@@ -27,7 +26,6 @@ import (
|
||||
"github.com/netbirdio/netbird/client/ssh/detection"
|
||||
"github.com/netbirdio/netbird/shared/auth"
|
||||
"github.com/netbirdio/netbird/shared/auth/jwt"
|
||||
"github.com/netbirdio/netbird/util/netrelay"
|
||||
"github.com/netbirdio/netbird/version"
|
||||
)
|
||||
|
||||
@@ -54,10 +52,6 @@ const (
|
||||
DefaultJWTMaxTokenAge = 10 * 60
|
||||
)
|
||||
|
||||
// directTCPIPDialTimeout bounds how long relayDirectTCPIP waits on a dial to
|
||||
// the forwarded destination before rejecting the SSH channel.
|
||||
const directTCPIPDialTimeout = 30 * time.Second
|
||||
|
||||
var (
|
||||
ErrPrivilegedUserDisabled = errors.New(msgPrivilegedUserDisabled)
|
||||
ErrUserNotFound = errors.New("user not found")
|
||||
@@ -897,29 +891,5 @@ func (s *Server) directTCPIPHandler(srv *ssh.Server, conn *cryptossh.ServerConn,
|
||||
s.addConnectionPortForward(ctx.User(), ctx.RemoteAddr(), forwardAddr)
|
||||
logger.Infof("local port forwarding: %s:%d", payload.Host, payload.Port)
|
||||
|
||||
s.relayDirectTCPIP(ctx, newChan, payload.Host, int(payload.Port), logger)
|
||||
}
|
||||
|
||||
// relayDirectTCPIP is a netrelay-based replacement for gliderlabs'
|
||||
// DirectTCPIPHandler. The upstream handler closes both sides on the first
|
||||
// EOF; netrelay.Relay propagates CloseWrite so each direction drains on its
|
||||
// own terms.
|
||||
func (s *Server) relayDirectTCPIP(ctx ssh.Context, newChan cryptossh.NewChannel, host string, port int, logger *log.Entry) {
|
||||
dest := net.JoinHostPort(host, strconv.Itoa(port))
|
||||
|
||||
dialer := net.Dialer{Timeout: directTCPIPDialTimeout}
|
||||
dconn, err := dialer.DialContext(ctx, "tcp", dest)
|
||||
if err != nil {
|
||||
_ = newChan.Reject(cryptossh.ConnectionFailed, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
ch, reqs, err := newChan.Accept()
|
||||
if err != nil {
|
||||
_ = dconn.Close()
|
||||
return
|
||||
}
|
||||
go cryptossh.DiscardRequests(reqs)
|
||||
|
||||
netrelay.Relay(ctx, dconn, ch, netrelay.Options{Logger: logger})
|
||||
ssh.DirectTCPIPHandler(srv, conn, newChan, ctx)
|
||||
}
|
||||
|
||||
@@ -119,6 +119,8 @@ server:
|
||||
|
||||
# Reverse proxy settings (optional)
|
||||
# reverseProxy:
|
||||
# trustedHTTPProxies: []
|
||||
# trustedHTTPProxiesCount: 0
|
||||
# trustedPeers: []
|
||||
# trustedHTTPProxies: [] # CIDRs of trusted reverse proxies (e.g. ["10.0.0.0/8"])
|
||||
# trustedHTTPProxiesCount: 0 # Number of trusted proxies in front of the server (alternative to trustedHTTPProxies)
|
||||
# trustedPeers: [] # CIDRs of trusted peer networks (e.g. ["100.64.0.0/10"])
|
||||
# accessLogRetentionDays: 7 # Days to retain HTTP access logs. 0 (or unset) defaults to 7. Negative values disable cleanup (logs kept indefinitely).
|
||||
# accessLogCleanupIntervalHours: 24 # How often (in hours) to run the access-log cleanup job. 0 (or unset) is treated as "not set" and defaults to 24 hours; cleanup remains enabled. To disable cleanup, set accessLogRetentionDays to a negative value.
|
||||
|
||||
@@ -457,6 +457,18 @@ func TestReceive_ProtocolErrorStreamReconnect(t *testing.T) {
|
||||
|
||||
client, err := flow.NewClient("http://"+server.addr, "test-payload", "test-signature", 1*time.Second)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Cleanups run LIFO: the goroutine-drain registered here runs after Close below,
|
||||
// which is when Receive has actually returned. Without this, the Receive goroutine
|
||||
// can outlive the test and call t.Logf after teardown, panicking.
|
||||
receiveDone := make(chan struct{})
|
||||
t.Cleanup(func() {
|
||||
select {
|
||||
case <-receiveDone:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Error("Receive goroutine did not exit after Close")
|
||||
}
|
||||
})
|
||||
t.Cleanup(func() {
|
||||
err := client.Close()
|
||||
assert.NoError(t, err, "failed to close flow")
|
||||
@@ -468,6 +480,7 @@ func TestReceive_ProtocolErrorStreamReconnect(t *testing.T) {
|
||||
receivedAfterReconnect := make(chan struct{})
|
||||
|
||||
go func() {
|
||||
defer close(receiveDone)
|
||||
err := client.Receive(ctx, 1*time.Second, func(msg *proto.FlowEventAck) error {
|
||||
if msg.IsInitiator || len(msg.EventId) == 0 {
|
||||
return nil
|
||||
|
||||
2
go.mod
2
go.mod
@@ -323,3 +323,5 @@ replace github.com/pion/ice/v4 => github.com/netbirdio/ice/v4 v4.0.0-20250908184
|
||||
replace github.com/libp2p/go-netroute => github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944
|
||||
|
||||
replace github.com/dexidp/dex => github.com/netbirdio/dex v0.244.0
|
||||
|
||||
replace github.com/mailru/easyjson => github.com/netbirdio/easyjson v0.9.0
|
||||
|
||||
4
go.sum
4
go.sum
@@ -400,8 +400,6 @@ github.com/lufia/plan9stats v0.0.0-20240513124658-fba389f38bae h1:dIZY4ULFcto4tA
|
||||
github.com/lufia/plan9stats v0.0.0-20240513124658-fba389f38bae/go.mod h1:ilwx/Dta8jXAgpFYFvSWEMwxmbWXyiUHkd5FwyKhb5k=
|
||||
github.com/magiconair/properties v1.8.10 h1:s31yESBquKXCV9a/ScB3ESkOjUYYv+X0rg8SYxI99mE=
|
||||
github.com/magiconair/properties v1.8.10/go.mod h1:Dhd985XPs7jluiymwWYZ0G4Z61jb3vdS329zhj2hYo0=
|
||||
github.com/mailru/easyjson v0.9.0 h1:PrnmzHw7262yW8sTBwxi1PdJA3Iw/EKBa8psRf7d9a4=
|
||||
github.com/mailru/easyjson v0.9.0/go.mod h1:1+xMtQp2MRNVL/V1bOzuP3aP8VNwRW55fQUto+XFtTU=
|
||||
github.com/mattermost/xml-roundtrip-validator v0.1.0 h1:RXbVD2UAl7A7nOTR4u7E3ILa4IbtvKBHw64LDsmu9hU=
|
||||
github.com/mattermost/xml-roundtrip-validator v0.1.0/go.mod h1:qccnGMcpgwcNaBnxqpJpWWUiPNr5H3O8eDgGV9gT5To=
|
||||
github.com/mattn/go-isatty v0.0.9/go.mod h1:YNRxwqDuOph6SZLI9vUUz6OYw3QyUt7WiY2yME+cCiQ=
|
||||
@@ -449,6 +447,8 @@ github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq
|
||||
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ=
|
||||
github.com/netbirdio/dex v0.244.0 h1:1GOvi8wnXYassnKGildzNqRHq0RbcfEUw7LKYpKIN7U=
|
||||
github.com/netbirdio/dex v0.244.0/go.mod h1:STGInJhPcAflrHmDO7vyit2kSq03PdL+8zQPoGALtcU=
|
||||
github.com/netbirdio/easyjson v0.9.0 h1:6Nw2lghSVuy8RSkAYDhDv1thBVEmfVbKZnV7T7Z6Aus=
|
||||
github.com/netbirdio/easyjson v0.9.0/go.mod h1:1+xMtQp2MRNVL/V1bOzuP3aP8VNwRW55fQUto+XFtTU=
|
||||
github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944 h1:TDtJKmM6Sf8uYFx/dMeqNOL90KUoRscdfpFZ3Im89uk=
|
||||
github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944/go.mod h1:sHA6TRxjQ6RLbnI+3R4DZo2Eseg/iKiPRfNmcuNySVQ=
|
||||
github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51 h1:Ov4qdafATOgGMB1wbSuh+0aAHcwz9hdvB6VZjh1mVMI=
|
||||
|
||||
@@ -472,7 +472,7 @@ start_services_and_show_instructions() {
|
||||
if [[ "$ENABLE_CROWDSEC" == "true" ]]; then
|
||||
echo "Registering CrowdSec bouncer..."
|
||||
local cs_retries=0
|
||||
while ! $DOCKER_COMPOSE_COMMAND exec -T crowdsec cscli capi status >/dev/null 2>&1; do
|
||||
while ! $DOCKER_COMPOSE_COMMAND exec -T crowdsec cscli lapi status >/dev/null 2>&1; do
|
||||
cs_retries=$((cs_retries + 1))
|
||||
if [[ $cs_retries -ge 30 ]]; then
|
||||
echo "WARNING: CrowdSec did not become ready. Skipping CrowdSec setup." > /dev/stderr
|
||||
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
"go.opentelemetry.io/otel/metric"
|
||||
|
||||
"github.com/netbirdio/management-integrations/integrations"
|
||||
|
||||
serverauth "github.com/netbirdio/netbird/management/server/auth"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/http/middleware/bypass"
|
||||
@@ -87,17 +88,14 @@ func (m *AuthMiddleware) Handler(h http.Handler) http.Handler {
|
||||
|
||||
switch authType {
|
||||
case "bearer":
|
||||
request, err := m.checkJWTFromRequest(r, authHeader)
|
||||
if err != nil {
|
||||
if err := m.checkJWTFromRequest(r, authHeader); err != nil {
|
||||
log.WithContext(r.Context()).Errorf("Error when validating JWT: %s", err.Error())
|
||||
util.WriteError(r.Context(), status.Errorf(status.Unauthorized, "token invalid"), w)
|
||||
return
|
||||
}
|
||||
|
||||
h.ServeHTTP(w, request)
|
||||
h.ServeHTTP(w, r)
|
||||
case "token":
|
||||
request, err := m.checkPATFromRequest(r, authHeader)
|
||||
if err != nil {
|
||||
if err := m.checkPATFromRequest(r, authHeader); err != nil {
|
||||
log.WithContext(r.Context()).Debugf("Error when validating PAT: %s", err.Error())
|
||||
// Check if it's a status error, otherwise default to Unauthorized
|
||||
if _, ok := status.FromError(err); !ok {
|
||||
@@ -106,7 +104,7 @@ func (m *AuthMiddleware) Handler(h http.Handler) http.Handler {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
h.ServeHTTP(w, request)
|
||||
h.ServeHTTP(w, r)
|
||||
default:
|
||||
util.WriteError(r.Context(), status.Errorf(status.Unauthorized, "no valid authentication provided"), w)
|
||||
return
|
||||
@@ -115,19 +113,19 @@ func (m *AuthMiddleware) Handler(h http.Handler) http.Handler {
|
||||
}
|
||||
|
||||
// CheckJWTFromRequest checks if the JWT is valid
|
||||
func (m *AuthMiddleware) checkJWTFromRequest(r *http.Request, authHeaderParts []string) (*http.Request, error) {
|
||||
func (m *AuthMiddleware) checkJWTFromRequest(r *http.Request, authHeaderParts []string) error {
|
||||
token, err := getTokenFromJWTRequest(authHeaderParts)
|
||||
|
||||
// If an error occurs, call the error handler and return an error
|
||||
if err != nil {
|
||||
return r, fmt.Errorf("error extracting token: %w", err)
|
||||
return fmt.Errorf("error extracting token: %w", err)
|
||||
}
|
||||
|
||||
ctx := r.Context()
|
||||
|
||||
userAuth, validatedToken, err := m.authManager.ValidateAndParseToken(ctx, token)
|
||||
if err != nil {
|
||||
return r, err
|
||||
return err
|
||||
}
|
||||
|
||||
if impersonate, ok := r.URL.Query()["account"]; ok && len(impersonate) == 1 {
|
||||
@@ -143,7 +141,7 @@ func (m *AuthMiddleware) checkJWTFromRequest(r *http.Request, authHeaderParts []
|
||||
// we need to call this method because if user is new, we will automatically add it to existing or create a new account
|
||||
accountId, _, err := m.ensureAccount(ctx, userAuth)
|
||||
if err != nil {
|
||||
return r, err
|
||||
return err
|
||||
}
|
||||
|
||||
if userAuth.AccountId != accountId {
|
||||
@@ -153,7 +151,7 @@ func (m *AuthMiddleware) checkJWTFromRequest(r *http.Request, authHeaderParts []
|
||||
|
||||
userAuth, err = m.authManager.EnsureUserAccessByJWTGroups(ctx, userAuth, validatedToken)
|
||||
if err != nil {
|
||||
return r, err
|
||||
return err
|
||||
}
|
||||
|
||||
err = m.syncUserJWTGroups(ctx, userAuth)
|
||||
@@ -164,17 +162,19 @@ func (m *AuthMiddleware) checkJWTFromRequest(r *http.Request, authHeaderParts []
|
||||
_, err = m.getUserFromUserAuth(ctx, userAuth)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("HTTP server failed to update user from user auth: %s", err)
|
||||
return r, err
|
||||
return err
|
||||
}
|
||||
|
||||
return nbcontext.SetUserAuthInRequest(r, userAuth), nil
|
||||
// propagates ctx change to upstream middleware
|
||||
*r = *nbcontext.SetUserAuthInRequest(r, userAuth)
|
||||
return nil
|
||||
}
|
||||
|
||||
// CheckPATFromRequest checks if the PAT is valid
|
||||
func (m *AuthMiddleware) checkPATFromRequest(r *http.Request, authHeaderParts []string) (*http.Request, error) {
|
||||
func (m *AuthMiddleware) checkPATFromRequest(r *http.Request, authHeaderParts []string) error {
|
||||
token, err := getTokenFromPATRequest(authHeaderParts)
|
||||
if err != nil {
|
||||
return r, fmt.Errorf("error extracting token: %w", err)
|
||||
return fmt.Errorf("error extracting token: %w", err)
|
||||
}
|
||||
|
||||
if m.patUsageTracker != nil {
|
||||
@@ -183,22 +183,22 @@ func (m *AuthMiddleware) checkPATFromRequest(r *http.Request, authHeaderParts []
|
||||
|
||||
if m.rateLimiter != nil && !isTerraformRequest(r) {
|
||||
if !m.rateLimiter.Allow(token) {
|
||||
return r, status.Errorf(status.TooManyRequests, "too many requests")
|
||||
return status.Errorf(status.TooManyRequests, "too many requests")
|
||||
}
|
||||
}
|
||||
|
||||
ctx := r.Context()
|
||||
user, pat, accDomain, accCategory, err := m.authManager.GetPATInfo(ctx, token)
|
||||
if err != nil {
|
||||
return r, fmt.Errorf("invalid Token: %w", err)
|
||||
return fmt.Errorf("invalid Token: %w", err)
|
||||
}
|
||||
if time.Now().After(pat.GetExpirationDate()) {
|
||||
return r, fmt.Errorf("token expired")
|
||||
return fmt.Errorf("token expired")
|
||||
}
|
||||
|
||||
err = m.authManager.MarkPATUsed(ctx, pat.ID)
|
||||
if err != nil {
|
||||
return r, err
|
||||
return err
|
||||
}
|
||||
|
||||
userAuth := auth.UserAuth{
|
||||
@@ -216,7 +216,9 @@ func (m *AuthMiddleware) checkPATFromRequest(r *http.Request, authHeaderParts []
|
||||
}
|
||||
}
|
||||
|
||||
return nbcontext.SetUserAuthInRequest(r, userAuth), nil
|
||||
// propagates ctx change to upstream middleware
|
||||
*r = *nbcontext.SetUserAuthInRequest(r, userAuth)
|
||||
return nil
|
||||
}
|
||||
|
||||
func isTerraformRequest(r *http.Request) bool {
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
_ "embed"
|
||||
|
||||
"github.com/rs/xid"
|
||||
"github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||
@@ -46,25 +47,40 @@ func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, user
|
||||
var isUpdate = policy.ID != ""
|
||||
var updateAccountPeers bool
|
||||
var action = activity.PolicyAdded
|
||||
var unchanged bool
|
||||
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
if err = validatePolicy(ctx, transaction, accountID, policy); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
updateAccountPeers, err = arePolicyChangesAffectPeers(ctx, transaction, accountID, policy, isUpdate)
|
||||
existingPolicy, err := validatePolicy(ctx, transaction, accountID, policy)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
saveFunc := transaction.CreatePolicy
|
||||
if isUpdate {
|
||||
action = activity.PolicyUpdated
|
||||
saveFunc = transaction.SavePolicy
|
||||
}
|
||||
if policy.Equal(existingPolicy) {
|
||||
logrus.WithContext(ctx).Tracef("policy update skipped because equal to stored one - policy id %s", policy.ID)
|
||||
unchanged = true
|
||||
return nil
|
||||
}
|
||||
|
||||
if err = saveFunc(ctx, policy); err != nil {
|
||||
return err
|
||||
action = activity.PolicyUpdated
|
||||
|
||||
updateAccountPeers, err = arePolicyChangesAffectPeersWithExisting(ctx, transaction, policy, existingPolicy)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err = transaction.SavePolicy(ctx, policy); err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
updateAccountPeers, err = arePolicyChangesAffectPeers(ctx, transaction, policy)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err = transaction.CreatePolicy(ctx, policy); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return transaction.IncrementNetworkSerial(ctx, accountID)
|
||||
@@ -73,6 +89,10 @@ func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, user
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if unchanged {
|
||||
return policy, nil
|
||||
}
|
||||
|
||||
am.StoreEvent(ctx, userID, policy.ID, accountID, action, policy.EventMeta())
|
||||
|
||||
if updateAccountPeers {
|
||||
@@ -101,7 +121,7 @@ func (am *DefaultAccountManager) DeletePolicy(ctx context.Context, accountID, po
|
||||
return err
|
||||
}
|
||||
|
||||
updateAccountPeers, err = arePolicyChangesAffectPeers(ctx, transaction, accountID, policy, false)
|
||||
updateAccountPeers, err = arePolicyChangesAffectPeers(ctx, transaction, policy)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -138,34 +158,37 @@ func (am *DefaultAccountManager) ListPolicies(ctx context.Context, accountID, us
|
||||
return am.Store.GetAccountPolicies(ctx, store.LockingStrengthNone, accountID)
|
||||
}
|
||||
|
||||
// arePolicyChangesAffectPeers checks if changes to a policy will affect any associated peers.
|
||||
func arePolicyChangesAffectPeers(ctx context.Context, transaction store.Store, accountID string, policy *types.Policy, isUpdate bool) (bool, error) {
|
||||
if isUpdate {
|
||||
existingPolicy, err := transaction.GetPolicyByID(ctx, store.LockingStrengthNone, accountID, policy.ID)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
if !policy.Enabled && !existingPolicy.Enabled {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
for _, rule := range existingPolicy.Rules {
|
||||
if rule.SourceResource.Type != "" || rule.DestinationResource.Type != "" {
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
|
||||
hasPeers, err := anyGroupHasPeersOrResources(ctx, transaction, policy.AccountID, existingPolicy.RuleGroups())
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
if hasPeers {
|
||||
// arePolicyChangesAffectPeers checks if a policy (being created or deleted) will affect any associated peers.
|
||||
func arePolicyChangesAffectPeers(ctx context.Context, transaction store.Store, policy *types.Policy) (bool, error) {
|
||||
for _, rule := range policy.Rules {
|
||||
if rule.SourceResource.Type != "" || rule.DestinationResource.Type != "" {
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
|
||||
return anyGroupHasPeersOrResources(ctx, transaction, policy.AccountID, policy.RuleGroups())
|
||||
}
|
||||
|
||||
func arePolicyChangesAffectPeersWithExisting(ctx context.Context, transaction store.Store, policy *types.Policy, existingPolicy *types.Policy) (bool, error) {
|
||||
if !policy.Enabled && !existingPolicy.Enabled {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
for _, rule := range existingPolicy.Rules {
|
||||
if rule.SourceResource.Type != "" || rule.DestinationResource.Type != "" {
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
|
||||
hasPeers, err := anyGroupHasPeersOrResources(ctx, transaction, policy.AccountID, existingPolicy.RuleGroups())
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
if hasPeers {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
for _, rule := range policy.Rules {
|
||||
if rule.SourceResource.Type != "" || rule.DestinationResource.Type != "" {
|
||||
return true, nil
|
||||
@@ -175,12 +198,15 @@ func arePolicyChangesAffectPeers(ctx context.Context, transaction store.Store, a
|
||||
return anyGroupHasPeersOrResources(ctx, transaction, policy.AccountID, policy.RuleGroups())
|
||||
}
|
||||
|
||||
// validatePolicy validates the policy and its rules.
|
||||
func validatePolicy(ctx context.Context, transaction store.Store, accountID string, policy *types.Policy) error {
|
||||
// validatePolicy validates the policy and its rules. For updates it returns
|
||||
// the existing policy loaded from the store so callers can avoid a second read.
|
||||
func validatePolicy(ctx context.Context, transaction store.Store, accountID string, policy *types.Policy) (*types.Policy, error) {
|
||||
var existingPolicy *types.Policy
|
||||
if policy.ID != "" {
|
||||
existingPolicy, err := transaction.GetPolicyByID(ctx, store.LockingStrengthNone, accountID, policy.ID)
|
||||
var err error
|
||||
existingPolicy, err = transaction.GetPolicyByID(ctx, store.LockingStrengthNone, accountID, policy.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// TODO: Refactor to support multiple rules per policy
|
||||
@@ -191,7 +217,7 @@ func validatePolicy(ctx context.Context, transaction store.Store, accountID stri
|
||||
|
||||
for _, rule := range policy.Rules {
|
||||
if rule.ID != "" && !existingRuleIDs[rule.ID] {
|
||||
return status.Errorf(status.InvalidArgument, "invalid rule ID: %s", rule.ID)
|
||||
return nil, status.Errorf(status.InvalidArgument, "invalid rule ID: %s", rule.ID)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
@@ -201,12 +227,12 @@ func validatePolicy(ctx context.Context, transaction store.Store, accountID stri
|
||||
|
||||
groups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthNone, accountID, policy.RuleGroups())
|
||||
if err != nil {
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
postureChecks, err := transaction.GetPostureChecksByIDs(ctx, store.LockingStrengthNone, accountID, policy.SourcePostureChecks)
|
||||
if err != nil {
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for i, rule := range policy.Rules {
|
||||
@@ -225,7 +251,7 @@ func validatePolicy(ctx context.Context, transaction store.Store, accountID stri
|
||||
policy.SourcePostureChecks = getValidPostureCheckIDs(postureChecks, policy.SourcePostureChecks)
|
||||
}
|
||||
|
||||
return nil
|
||||
return existingPolicy, nil
|
||||
}
|
||||
|
||||
// getValidPostureCheckIDs filters and returns only the valid posture check IDs from the provided list.
|
||||
|
||||
@@ -193,20 +193,12 @@ func (m *HTTPMiddleware) Handler(h http.Handler) http.Handler {
|
||||
}
|
||||
})
|
||||
|
||||
h.ServeHTTP(w, r.WithContext(ctx))
|
||||
// Hold on to req so auth's in-place ctx update is visible after ServeHTTP.
|
||||
req := r.WithContext(ctx)
|
||||
h.ServeHTTP(w, req)
|
||||
close(handlerDone)
|
||||
|
||||
userAuth, err := nbContext.GetUserAuthFromContext(r.Context())
|
||||
if err == nil {
|
||||
if userAuth.AccountId != "" {
|
||||
//nolint
|
||||
ctx = context.WithValue(ctx, nbContext.AccountIDKey, userAuth.AccountId)
|
||||
}
|
||||
if userAuth.UserId != "" {
|
||||
//nolint
|
||||
ctx = context.WithValue(ctx, nbContext.UserIDKey, userAuth.UserId)
|
||||
}
|
||||
}
|
||||
ctx = req.Context()
|
||||
|
||||
if w.Status() > 399 {
|
||||
log.WithContext(ctx).Errorf("HTTP response %v: %v %v status %v", reqID, r.Method, r.URL, w.Status())
|
||||
|
||||
@@ -93,6 +93,44 @@ func (p *Policy) Copy() *Policy {
|
||||
return c
|
||||
}
|
||||
|
||||
func (p *Policy) Equal(other *Policy) bool {
|
||||
if p == nil || other == nil {
|
||||
return p == other
|
||||
}
|
||||
|
||||
if p.ID != other.ID ||
|
||||
p.AccountID != other.AccountID ||
|
||||
p.Name != other.Name ||
|
||||
p.Description != other.Description ||
|
||||
p.Enabled != other.Enabled {
|
||||
return false
|
||||
}
|
||||
|
||||
if !stringSlicesEqualUnordered(p.SourcePostureChecks, other.SourcePostureChecks) {
|
||||
return false
|
||||
}
|
||||
|
||||
if len(p.Rules) != len(other.Rules) {
|
||||
return false
|
||||
}
|
||||
|
||||
otherRules := make(map[string]*PolicyRule, len(other.Rules))
|
||||
for _, r := range other.Rules {
|
||||
otherRules[r.ID] = r
|
||||
}
|
||||
for _, r := range p.Rules {
|
||||
otherRule, ok := otherRules[r.ID]
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
if !r.Equal(otherRule) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// EventMeta returns activity event meta related to this policy
|
||||
func (p *Policy) EventMeta() map[string]any {
|
||||
return map[string]any{"name": p.Name}
|
||||
|
||||
193
management/server/types/policy_test.go
Normal file
193
management/server/types/policy_test.go
Normal file
@@ -0,0 +1,193 @@
|
||||
package types
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestPolicyEqual_SameRulesDifferentOrder(t *testing.T) {
|
||||
a := &Policy{
|
||||
ID: "pol1",
|
||||
AccountID: "acc1",
|
||||
Name: "test",
|
||||
Enabled: true,
|
||||
Rules: []*PolicyRule{
|
||||
{ID: "r1", PolicyID: "pol1", Ports: []string{"80"}},
|
||||
{ID: "r2", PolicyID: "pol1", Ports: []string{"443"}},
|
||||
},
|
||||
}
|
||||
b := &Policy{
|
||||
ID: "pol1",
|
||||
AccountID: "acc1",
|
||||
Name: "test",
|
||||
Enabled: true,
|
||||
Rules: []*PolicyRule{
|
||||
{ID: "r2", PolicyID: "pol1", Ports: []string{"443"}},
|
||||
{ID: "r1", PolicyID: "pol1", Ports: []string{"80"}},
|
||||
},
|
||||
}
|
||||
assert.True(t, a.Equal(b))
|
||||
}
|
||||
|
||||
func TestPolicyEqual_DifferentRules(t *testing.T) {
|
||||
a := &Policy{
|
||||
ID: "pol1",
|
||||
Enabled: true,
|
||||
Rules: []*PolicyRule{
|
||||
{ID: "r1", PolicyID: "pol1", Ports: []string{"80"}},
|
||||
},
|
||||
}
|
||||
b := &Policy{
|
||||
ID: "pol1",
|
||||
Enabled: true,
|
||||
Rules: []*PolicyRule{
|
||||
{ID: "r1", PolicyID: "pol1", Ports: []string{"443"}},
|
||||
},
|
||||
}
|
||||
assert.False(t, a.Equal(b))
|
||||
}
|
||||
|
||||
func TestPolicyEqual_DifferentRuleCount(t *testing.T) {
|
||||
a := &Policy{
|
||||
ID: "pol1",
|
||||
Rules: []*PolicyRule{
|
||||
{ID: "r1", PolicyID: "pol1"},
|
||||
},
|
||||
}
|
||||
b := &Policy{
|
||||
ID: "pol1",
|
||||
Rules: []*PolicyRule{
|
||||
{ID: "r1", PolicyID: "pol1"},
|
||||
{ID: "r2", PolicyID: "pol1"},
|
||||
},
|
||||
}
|
||||
assert.False(t, a.Equal(b))
|
||||
}
|
||||
|
||||
func TestPolicyEqual_PostureChecksDifferentOrder(t *testing.T) {
|
||||
a := &Policy{
|
||||
ID: "pol1",
|
||||
SourcePostureChecks: []string{"pc3", "pc1", "pc2"},
|
||||
}
|
||||
b := &Policy{
|
||||
ID: "pol1",
|
||||
SourcePostureChecks: []string{"pc1", "pc2", "pc3"},
|
||||
}
|
||||
assert.True(t, a.Equal(b))
|
||||
}
|
||||
|
||||
func TestPolicyEqual_DifferentPostureChecks(t *testing.T) {
|
||||
a := &Policy{
|
||||
ID: "pol1",
|
||||
SourcePostureChecks: []string{"pc1", "pc2"},
|
||||
}
|
||||
b := &Policy{
|
||||
ID: "pol1",
|
||||
SourcePostureChecks: []string{"pc1", "pc3"},
|
||||
}
|
||||
assert.False(t, a.Equal(b))
|
||||
}
|
||||
|
||||
func TestPolicyEqual_DifferentScalarFields(t *testing.T) {
|
||||
base := Policy{
|
||||
ID: "pol1",
|
||||
AccountID: "acc1",
|
||||
Name: "test",
|
||||
Description: "desc",
|
||||
Enabled: true,
|
||||
}
|
||||
|
||||
other := base
|
||||
other.Name = "changed"
|
||||
assert.False(t, base.Equal(&other))
|
||||
|
||||
other = base
|
||||
other.Enabled = false
|
||||
assert.False(t, base.Equal(&other))
|
||||
|
||||
other = base
|
||||
other.Description = "changed"
|
||||
assert.False(t, base.Equal(&other))
|
||||
}
|
||||
|
||||
func TestPolicyEqual_NilCases(t *testing.T) {
|
||||
var a *Policy
|
||||
var b *Policy
|
||||
assert.True(t, a.Equal(b))
|
||||
|
||||
a = &Policy{ID: "pol1"}
|
||||
assert.False(t, a.Equal(nil))
|
||||
}
|
||||
|
||||
func TestPolicyEqual_RulesMismatchByID(t *testing.T) {
|
||||
a := &Policy{
|
||||
ID: "pol1",
|
||||
Rules: []*PolicyRule{
|
||||
{ID: "r1", PolicyID: "pol1"},
|
||||
},
|
||||
}
|
||||
b := &Policy{
|
||||
ID: "pol1",
|
||||
Rules: []*PolicyRule{
|
||||
{ID: "r2", PolicyID: "pol1"},
|
||||
},
|
||||
}
|
||||
assert.False(t, a.Equal(b))
|
||||
}
|
||||
|
||||
func TestPolicyEqual_FullScenario(t *testing.T) {
|
||||
a := &Policy{
|
||||
ID: "pol1",
|
||||
AccountID: "acc1",
|
||||
Name: "Web Access",
|
||||
Description: "Allow web access",
|
||||
Enabled: true,
|
||||
SourcePostureChecks: []string{"pc2", "pc1"},
|
||||
Rules: []*PolicyRule{
|
||||
{
|
||||
ID: "r1",
|
||||
PolicyID: "pol1",
|
||||
Name: "HTTP",
|
||||
Enabled: true,
|
||||
Action: PolicyTrafficActionAccept,
|
||||
Protocol: PolicyRuleProtocolTCP,
|
||||
Bidirectional: true,
|
||||
Sources: []string{"g2", "g1"},
|
||||
Destinations: []string{"g4", "g3"},
|
||||
Ports: []string{"443", "80", "8080"},
|
||||
PortRanges: []RulePortRange{
|
||||
{Start: 8000, End: 9000},
|
||||
{Start: 80, End: 80},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
b := &Policy{
|
||||
ID: "pol1",
|
||||
AccountID: "acc1",
|
||||
Name: "Web Access",
|
||||
Description: "Allow web access",
|
||||
Enabled: true,
|
||||
SourcePostureChecks: []string{"pc1", "pc2"},
|
||||
Rules: []*PolicyRule{
|
||||
{
|
||||
ID: "r1",
|
||||
PolicyID: "pol1",
|
||||
Name: "HTTP",
|
||||
Enabled: true,
|
||||
Action: PolicyTrafficActionAccept,
|
||||
Protocol: PolicyRuleProtocolTCP,
|
||||
Bidirectional: true,
|
||||
Sources: []string{"g1", "g2"},
|
||||
Destinations: []string{"g3", "g4"},
|
||||
Ports: []string{"80", "8080", "443"},
|
||||
PortRanges: []RulePortRange{
|
||||
{Start: 80, End: 80},
|
||||
{Start: 8000, End: 9000},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
assert.True(t, a.Equal(b))
|
||||
}
|
||||
@@ -1,6 +1,8 @@
|
||||
package types
|
||||
|
||||
import (
|
||||
"slices"
|
||||
|
||||
"github.com/netbirdio/netbird/shared/management/proto"
|
||||
)
|
||||
|
||||
@@ -118,3 +120,106 @@ func (pm *PolicyRule) Copy() *PolicyRule {
|
||||
}
|
||||
return rule
|
||||
}
|
||||
|
||||
func (pm *PolicyRule) Equal(other *PolicyRule) bool {
|
||||
if pm == nil || other == nil {
|
||||
return pm == other
|
||||
}
|
||||
|
||||
if pm.ID != other.ID ||
|
||||
pm.PolicyID != other.PolicyID ||
|
||||
pm.Name != other.Name ||
|
||||
pm.Description != other.Description ||
|
||||
pm.Enabled != other.Enabled ||
|
||||
pm.Action != other.Action ||
|
||||
pm.Bidirectional != other.Bidirectional ||
|
||||
pm.Protocol != other.Protocol ||
|
||||
pm.SourceResource != other.SourceResource ||
|
||||
pm.DestinationResource != other.DestinationResource ||
|
||||
pm.AuthorizedUser != other.AuthorizedUser {
|
||||
return false
|
||||
}
|
||||
|
||||
if !stringSlicesEqualUnordered(pm.Sources, other.Sources) {
|
||||
return false
|
||||
}
|
||||
if !stringSlicesEqualUnordered(pm.Destinations, other.Destinations) {
|
||||
return false
|
||||
}
|
||||
if !stringSlicesEqualUnordered(pm.Ports, other.Ports) {
|
||||
return false
|
||||
}
|
||||
if !portRangeSlicesEqualUnordered(pm.PortRanges, other.PortRanges) {
|
||||
return false
|
||||
}
|
||||
if !authorizedGroupsEqual(pm.AuthorizedGroups, other.AuthorizedGroups) {
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func stringSlicesEqualUnordered(a, b []string) bool {
|
||||
if len(a) != len(b) {
|
||||
return false
|
||||
}
|
||||
if len(a) == 0 {
|
||||
return true
|
||||
}
|
||||
sorted1 := make([]string, len(a))
|
||||
sorted2 := make([]string, len(b))
|
||||
copy(sorted1, a)
|
||||
copy(sorted2, b)
|
||||
slices.Sort(sorted1)
|
||||
slices.Sort(sorted2)
|
||||
return slices.Equal(sorted1, sorted2)
|
||||
}
|
||||
|
||||
func portRangeSlicesEqualUnordered(a, b []RulePortRange) bool {
|
||||
if len(a) != len(b) {
|
||||
return false
|
||||
}
|
||||
if len(a) == 0 {
|
||||
return true
|
||||
}
|
||||
cmp := func(x, y RulePortRange) int {
|
||||
if x.Start != y.Start {
|
||||
if x.Start < y.Start {
|
||||
return -1
|
||||
}
|
||||
return 1
|
||||
}
|
||||
if x.End != y.End {
|
||||
if x.End < y.End {
|
||||
return -1
|
||||
}
|
||||
return 1
|
||||
}
|
||||
return 0
|
||||
}
|
||||
sorted1 := make([]RulePortRange, len(a))
|
||||
sorted2 := make([]RulePortRange, len(b))
|
||||
copy(sorted1, a)
|
||||
copy(sorted2, b)
|
||||
slices.SortFunc(sorted1, cmp)
|
||||
slices.SortFunc(sorted2, cmp)
|
||||
return slices.EqualFunc(sorted1, sorted2, func(x, y RulePortRange) bool {
|
||||
return x.Start == y.Start && x.End == y.End
|
||||
})
|
||||
}
|
||||
|
||||
func authorizedGroupsEqual(a, b map[string][]string) bool {
|
||||
if len(a) != len(b) {
|
||||
return false
|
||||
}
|
||||
for k, va := range a {
|
||||
vb, ok := b[k]
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
if !stringSlicesEqualUnordered(va, vb) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
194
management/server/types/policyrule_test.go
Normal file
194
management/server/types/policyrule_test.go
Normal file
@@ -0,0 +1,194 @@
|
||||
package types
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestPolicyRuleEqual_SamePortsDifferentOrder(t *testing.T) {
|
||||
a := &PolicyRule{
|
||||
ID: "rule1",
|
||||
PolicyID: "pol1",
|
||||
Ports: []string{"443", "80", "22"},
|
||||
}
|
||||
b := &PolicyRule{
|
||||
ID: "rule1",
|
||||
PolicyID: "pol1",
|
||||
Ports: []string{"22", "443", "80"},
|
||||
}
|
||||
assert.True(t, a.Equal(b))
|
||||
}
|
||||
|
||||
func TestPolicyRuleEqual_DifferentPorts(t *testing.T) {
|
||||
a := &PolicyRule{
|
||||
ID: "rule1",
|
||||
PolicyID: "pol1",
|
||||
Ports: []string{"443", "80"},
|
||||
}
|
||||
b := &PolicyRule{
|
||||
ID: "rule1",
|
||||
PolicyID: "pol1",
|
||||
Ports: []string{"443", "22"},
|
||||
}
|
||||
assert.False(t, a.Equal(b))
|
||||
}
|
||||
|
||||
func TestPolicyRuleEqual_SourcesDestinationsDifferentOrder(t *testing.T) {
|
||||
a := &PolicyRule{
|
||||
ID: "rule1",
|
||||
PolicyID: "pol1",
|
||||
Sources: []string{"g1", "g2", "g3"},
|
||||
Destinations: []string{"g4", "g5"},
|
||||
}
|
||||
b := &PolicyRule{
|
||||
ID: "rule1",
|
||||
PolicyID: "pol1",
|
||||
Sources: []string{"g3", "g1", "g2"},
|
||||
Destinations: []string{"g5", "g4"},
|
||||
}
|
||||
assert.True(t, a.Equal(b))
|
||||
}
|
||||
|
||||
func TestPolicyRuleEqual_DifferentSources(t *testing.T) {
|
||||
a := &PolicyRule{
|
||||
ID: "rule1",
|
||||
PolicyID: "pol1",
|
||||
Sources: []string{"g1", "g2"},
|
||||
}
|
||||
b := &PolicyRule{
|
||||
ID: "rule1",
|
||||
PolicyID: "pol1",
|
||||
Sources: []string{"g1", "g3"},
|
||||
}
|
||||
assert.False(t, a.Equal(b))
|
||||
}
|
||||
|
||||
func TestPolicyRuleEqual_PortRangesDifferentOrder(t *testing.T) {
|
||||
a := &PolicyRule{
|
||||
ID: "rule1",
|
||||
PolicyID: "pol1",
|
||||
PortRanges: []RulePortRange{
|
||||
{Start: 8000, End: 9000},
|
||||
{Start: 80, End: 80},
|
||||
},
|
||||
}
|
||||
b := &PolicyRule{
|
||||
ID: "rule1",
|
||||
PolicyID: "pol1",
|
||||
PortRanges: []RulePortRange{
|
||||
{Start: 80, End: 80},
|
||||
{Start: 8000, End: 9000},
|
||||
},
|
||||
}
|
||||
assert.True(t, a.Equal(b))
|
||||
}
|
||||
|
||||
func TestPolicyRuleEqual_DifferentPortRanges(t *testing.T) {
|
||||
a := &PolicyRule{
|
||||
ID: "rule1",
|
||||
PolicyID: "pol1",
|
||||
PortRanges: []RulePortRange{
|
||||
{Start: 80, End: 80},
|
||||
},
|
||||
}
|
||||
b := &PolicyRule{
|
||||
ID: "rule1",
|
||||
PolicyID: "pol1",
|
||||
PortRanges: []RulePortRange{
|
||||
{Start: 80, End: 443},
|
||||
},
|
||||
}
|
||||
assert.False(t, a.Equal(b))
|
||||
}
|
||||
|
||||
func TestPolicyRuleEqual_AuthorizedGroupsDifferentValueOrder(t *testing.T) {
|
||||
a := &PolicyRule{
|
||||
ID: "rule1",
|
||||
PolicyID: "pol1",
|
||||
AuthorizedGroups: map[string][]string{
|
||||
"g1": {"u1", "u2", "u3"},
|
||||
},
|
||||
}
|
||||
b := &PolicyRule{
|
||||
ID: "rule1",
|
||||
PolicyID: "pol1",
|
||||
AuthorizedGroups: map[string][]string{
|
||||
"g1": {"u3", "u1", "u2"},
|
||||
},
|
||||
}
|
||||
assert.True(t, a.Equal(b))
|
||||
}
|
||||
|
||||
func TestPolicyRuleEqual_DifferentAuthorizedGroups(t *testing.T) {
|
||||
a := &PolicyRule{
|
||||
ID: "rule1",
|
||||
PolicyID: "pol1",
|
||||
AuthorizedGroups: map[string][]string{
|
||||
"g1": {"u1"},
|
||||
},
|
||||
}
|
||||
b := &PolicyRule{
|
||||
ID: "rule1",
|
||||
PolicyID: "pol1",
|
||||
AuthorizedGroups: map[string][]string{
|
||||
"g2": {"u1"},
|
||||
},
|
||||
}
|
||||
assert.False(t, a.Equal(b))
|
||||
}
|
||||
|
||||
func TestPolicyRuleEqual_DifferentScalarFields(t *testing.T) {
|
||||
base := PolicyRule{
|
||||
ID: "rule1",
|
||||
PolicyID: "pol1",
|
||||
Name: "test",
|
||||
Description: "desc",
|
||||
Enabled: true,
|
||||
Action: PolicyTrafficActionAccept,
|
||||
Bidirectional: true,
|
||||
Protocol: PolicyRuleProtocolTCP,
|
||||
}
|
||||
|
||||
other := base
|
||||
other.Name = "changed"
|
||||
assert.False(t, base.Equal(&other))
|
||||
|
||||
other = base
|
||||
other.Enabled = false
|
||||
assert.False(t, base.Equal(&other))
|
||||
|
||||
other = base
|
||||
other.Action = PolicyTrafficActionDrop
|
||||
assert.False(t, base.Equal(&other))
|
||||
|
||||
other = base
|
||||
other.Protocol = PolicyRuleProtocolUDP
|
||||
assert.False(t, base.Equal(&other))
|
||||
}
|
||||
|
||||
func TestPolicyRuleEqual_NilCases(t *testing.T) {
|
||||
var a *PolicyRule
|
||||
var b *PolicyRule
|
||||
assert.True(t, a.Equal(b))
|
||||
|
||||
a = &PolicyRule{ID: "rule1"}
|
||||
assert.False(t, a.Equal(nil))
|
||||
}
|
||||
|
||||
func TestPolicyRuleEqual_EmptySlices(t *testing.T) {
|
||||
a := &PolicyRule{
|
||||
ID: "rule1",
|
||||
PolicyID: "pol1",
|
||||
Ports: []string{},
|
||||
Sources: nil,
|
||||
}
|
||||
b := &PolicyRule{
|
||||
ID: "rule1",
|
||||
PolicyID: "pol1",
|
||||
Ports: nil,
|
||||
Sources: []string{},
|
||||
}
|
||||
assert.True(t, a.Equal(b))
|
||||
}
|
||||
|
||||
@@ -25,12 +25,6 @@ func (c *peekedConn) Read(b []byte) (int, error) {
|
||||
return c.reader.Read(b)
|
||||
}
|
||||
|
||||
// halfCloser matches connections that support shutting down the write
|
||||
// side while keeping the read side open (e.g. *net.TCPConn).
|
||||
type halfCloser interface {
|
||||
CloseWrite() error
|
||||
}
|
||||
|
||||
// CloseWrite delegates to the underlying connection if it supports
|
||||
// half-close (e.g. *net.TCPConn). Without this, embedding net.Conn
|
||||
// as an interface hides the concrete type's CloseWrite method, making
|
||||
|
||||
156
proxy/internal/tcp/relay.go
Normal file
156
proxy/internal/tcp/relay.go
Normal file
@@ -0,0 +1,156 @@
|
||||
package tcp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/proxy/internal/netutil"
|
||||
)
|
||||
|
||||
// errIdleTimeout is returned when a relay connection is closed due to inactivity.
|
||||
var errIdleTimeout = errors.New("idle timeout")
|
||||
|
||||
// DefaultIdleTimeout is the default idle timeout for TCP relay connections.
|
||||
// A zero value disables idle timeout checking.
|
||||
const DefaultIdleTimeout = 5 * time.Minute
|
||||
|
||||
// halfCloser is implemented by connections that support half-close
|
||||
// (e.g. *net.TCPConn). When one copy direction finishes, we signal
|
||||
// EOF to the remote by closing the write side while keeping the read
|
||||
// side open so the other direction can drain.
|
||||
type halfCloser interface {
|
||||
CloseWrite() error
|
||||
}
|
||||
|
||||
// copyBufPool avoids allocating a new 32KB buffer per io.Copy call.
|
||||
var copyBufPool = sync.Pool{
|
||||
New: func() any {
|
||||
buf := make([]byte, 32*1024)
|
||||
return &buf
|
||||
},
|
||||
}
|
||||
|
||||
// Relay copies data bidirectionally between src and dst until both
|
||||
// sides are done or the context is canceled. When idleTimeout is
|
||||
// non-zero, each direction's read is deadline-guarded; if no data
|
||||
// flows within the timeout the connection is torn down. When one
|
||||
// direction finishes, it half-closes the write side of the
|
||||
// destination (if supported) to signal EOF, allowing the other
|
||||
// direction to drain gracefully before the full connection teardown.
|
||||
func Relay(ctx context.Context, logger *log.Entry, src, dst net.Conn, idleTimeout time.Duration) (srcToDst, dstToSrc int64) {
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
|
||||
go func() {
|
||||
<-ctx.Done()
|
||||
_ = src.Close()
|
||||
_ = dst.Close()
|
||||
}()
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(2)
|
||||
|
||||
var errSrcToDst, errDstToSrc error
|
||||
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
srcToDst, errSrcToDst = copyWithIdleTimeout(dst, src, idleTimeout)
|
||||
halfClose(dst)
|
||||
cancel()
|
||||
}()
|
||||
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
dstToSrc, errDstToSrc = copyWithIdleTimeout(src, dst, idleTimeout)
|
||||
halfClose(src)
|
||||
cancel()
|
||||
}()
|
||||
|
||||
wg.Wait()
|
||||
|
||||
if errors.Is(errSrcToDst, errIdleTimeout) || errors.Is(errDstToSrc, errIdleTimeout) {
|
||||
logger.Debug("relay closed due to idle timeout")
|
||||
}
|
||||
if errSrcToDst != nil && !isExpectedCopyError(errSrcToDst) {
|
||||
logger.Debugf("relay copy error (src→dst): %v", errSrcToDst)
|
||||
}
|
||||
if errDstToSrc != nil && !isExpectedCopyError(errDstToSrc) {
|
||||
logger.Debugf("relay copy error (dst→src): %v", errDstToSrc)
|
||||
}
|
||||
|
||||
return srcToDst, dstToSrc
|
||||
}
|
||||
|
||||
// copyWithIdleTimeout copies from src to dst using a pooled buffer.
|
||||
// When idleTimeout > 0 it sets a read deadline on src before each
|
||||
// read and treats a timeout as an idle-triggered close.
|
||||
func copyWithIdleTimeout(dst io.Writer, src io.Reader, idleTimeout time.Duration) (int64, error) {
|
||||
bufp := copyBufPool.Get().(*[]byte)
|
||||
defer copyBufPool.Put(bufp)
|
||||
|
||||
if idleTimeout <= 0 {
|
||||
return io.CopyBuffer(dst, src, *bufp)
|
||||
}
|
||||
|
||||
conn, ok := src.(net.Conn)
|
||||
if !ok {
|
||||
return io.CopyBuffer(dst, src, *bufp)
|
||||
}
|
||||
|
||||
buf := *bufp
|
||||
var total int64
|
||||
for {
|
||||
if err := conn.SetReadDeadline(time.Now().Add(idleTimeout)); err != nil {
|
||||
return total, err
|
||||
}
|
||||
nr, readErr := src.Read(buf)
|
||||
if nr > 0 {
|
||||
n, err := checkedWrite(dst, buf[:nr])
|
||||
total += n
|
||||
if err != nil {
|
||||
return total, err
|
||||
}
|
||||
}
|
||||
if readErr != nil {
|
||||
if netutil.IsTimeout(readErr) {
|
||||
return total, errIdleTimeout
|
||||
}
|
||||
return total, readErr
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// checkedWrite writes buf to dst and returns the number of bytes written.
|
||||
// It guards against short writes and negative counts per io.Copy convention.
|
||||
func checkedWrite(dst io.Writer, buf []byte) (int64, error) {
|
||||
nw, err := dst.Write(buf)
|
||||
if nw < 0 || nw > len(buf) {
|
||||
nw = 0
|
||||
}
|
||||
if err != nil {
|
||||
return int64(nw), err
|
||||
}
|
||||
if nw != len(buf) {
|
||||
return int64(nw), io.ErrShortWrite
|
||||
}
|
||||
return int64(nw), nil
|
||||
}
|
||||
|
||||
func isExpectedCopyError(err error) bool {
|
||||
return errors.Is(err, errIdleTimeout) || netutil.IsExpectedError(err)
|
||||
}
|
||||
|
||||
// halfClose attempts to half-close the write side of the connection.
|
||||
// If the connection does not support half-close, this is a no-op.
|
||||
func halfClose(conn net.Conn) {
|
||||
if hc, ok := conn.(halfCloser); ok {
|
||||
// Best-effort; the full close will follow shortly.
|
||||
_ = hc.CloseWrite()
|
||||
}
|
||||
}
|
||||
@@ -13,13 +13,8 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/proxy/internal/netutil"
|
||||
"github.com/netbirdio/netbird/util/netrelay"
|
||||
)
|
||||
|
||||
func testRelay(ctx context.Context, logger *log.Entry, src, dst net.Conn, idleTimeout time.Duration) (int64, int64) {
|
||||
return netrelay.Relay(ctx, src, dst, netrelay.Options{IdleTimeout: idleTimeout, Logger: logger})
|
||||
}
|
||||
|
||||
func TestRelay_BidirectionalCopy(t *testing.T) {
|
||||
srcClient, srcServer := net.Pipe()
|
||||
dstClient, dstServer := net.Pipe()
|
||||
@@ -46,7 +41,7 @@ func TestRelay_BidirectionalCopy(t *testing.T) {
|
||||
srcClient.Close()
|
||||
}()
|
||||
|
||||
s2d, d2s := testRelay(ctx, logger, srcServer, dstServer, 0)
|
||||
s2d, d2s := Relay(ctx, logger, srcServer, dstServer, 0)
|
||||
|
||||
assert.Equal(t, int64(len(srcData)), s2d, "bytes src→dst")
|
||||
assert.Equal(t, int64(len(dstData)), d2s, "bytes dst→src")
|
||||
@@ -63,7 +58,7 @@ func TestRelay_ContextCancellation(t *testing.T) {
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
testRelay(ctx, logger, srcServer, dstServer, 0)
|
||||
Relay(ctx, logger, srcServer, dstServer, 0)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
@@ -90,7 +85,7 @@ func TestRelay_OneSideClosed(t *testing.T) {
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
testRelay(ctx, logger, srcServer, dstServer, 0)
|
||||
Relay(ctx, logger, srcServer, dstServer, 0)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
@@ -134,7 +129,7 @@ func TestRelay_LargeTransfer(t *testing.T) {
|
||||
dstClient.Close()
|
||||
}()
|
||||
|
||||
s2d, _ := testRelay(ctx, logger, srcServer, dstServer, 0)
|
||||
s2d, _ := Relay(ctx, logger, srcServer, dstServer, 0)
|
||||
assert.Equal(t, int64(len(data)), s2d, "should transfer all bytes")
|
||||
require.NoError(t, <-errCh)
|
||||
}
|
||||
@@ -187,7 +182,7 @@ func TestRelay_IdleTimeout(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
var s2d, d2s int64
|
||||
go func() {
|
||||
s2d, d2s = testRelay(ctx, logger, srcServer, dstServer, 200*time.Millisecond)
|
||||
s2d, d2s = Relay(ctx, logger, srcServer, dstServer, 200*time.Millisecond)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
|
||||
@@ -16,7 +16,6 @@ import (
|
||||
"github.com/netbirdio/netbird/proxy/internal/accesslog"
|
||||
"github.com/netbirdio/netbird/proxy/internal/restrict"
|
||||
"github.com/netbirdio/netbird/proxy/internal/types"
|
||||
"github.com/netbirdio/netbird/util/netrelay"
|
||||
)
|
||||
|
||||
// defaultDialTimeout is the fallback dial timeout when no per-route
|
||||
@@ -529,14 +528,11 @@ func (r *Router) relayTCP(ctx context.Context, conn net.Conn, sni SNIHost, route
|
||||
|
||||
idleTimeout := route.SessionIdleTimeout
|
||||
if idleTimeout <= 0 {
|
||||
idleTimeout = netrelay.DefaultIdleTimeout
|
||||
idleTimeout = DefaultIdleTimeout
|
||||
}
|
||||
|
||||
start := time.Now()
|
||||
s2d, d2s := netrelay.Relay(svcCtx, conn, backend, netrelay.Options{
|
||||
IdleTimeout: idleTimeout,
|
||||
Logger: entry,
|
||||
})
|
||||
s2d, d2s := Relay(svcCtx, entry, conn, backend, idleTimeout)
|
||||
elapsed := time.Since(start)
|
||||
|
||||
if obs != nil {
|
||||
|
||||
@@ -30,6 +30,8 @@ import (
|
||||
|
||||
const ConnectTimeout = 10 * time.Second
|
||||
|
||||
const healthCheckTimeout = 5 * time.Second
|
||||
|
||||
const (
|
||||
// EnvMaxRecvMsgSize overrides the default gRPC max receive message size (4 MB)
|
||||
// for the management client connection. Value is in bytes.
|
||||
@@ -532,7 +534,7 @@ func (c *GrpcClient) IsHealthy() bool {
|
||||
case connectivity.Ready:
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(c.ctx, 1*time.Second)
|
||||
ctx, cancel := context.WithTimeout(c.ctx, healthCheckTimeout)
|
||||
defer cancel()
|
||||
|
||||
_, err := c.realClient.GetServerKey(ctx, &proto.Empty{})
|
||||
|
||||
@@ -23,6 +23,8 @@ import (
|
||||
"github.com/netbirdio/netbird/util/wsproxy"
|
||||
)
|
||||
|
||||
const healthCheckTimeout = 5 * time.Second
|
||||
|
||||
// ConnStateNotifier is a wrapper interface of the status recorder
|
||||
type ConnStateNotifier interface {
|
||||
MarkSignalDisconnected(error)
|
||||
@@ -263,7 +265,7 @@ func (c *GrpcClient) IsHealthy() bool {
|
||||
case connectivity.Ready:
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(c.ctx, 1*time.Second)
|
||||
ctx, cancel := context.WithTimeout(c.ctx, healthCheckTimeout)
|
||||
defer cancel()
|
||||
_, err := c.realClient.Send(ctx, &proto.EncryptedMessage{
|
||||
Key: c.key.PublicKey().String(),
|
||||
|
||||
@@ -1,238 +0,0 @@
|
||||
// Package netrelay provides a bidirectional byte-copy helper for TCP-like
|
||||
// connections with correct half-close propagation.
|
||||
//
|
||||
// When one direction reads EOF, the write side of the opposite connection is
|
||||
// half-closed (CloseWrite) so the peer sees FIN, then the second direction is
|
||||
// allowed to drain to its own EOF before both connections are fully closed.
|
||||
// This preserves TCP half-close semantics (e.g. shutdown(SHUT_WR)) that the
|
||||
// naive "cancel-both-on-first-EOF" pattern breaks.
|
||||
package netrelay
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"syscall"
|
||||
"time"
|
||||
)
|
||||
|
||||
// DebugLogger is the minimal interface netrelay uses to surface teardown
|
||||
// errors. Both *logrus.Entry and *nblog.Logger (via its Debugf method)
|
||||
// satisfy it, so callers can pass whichever they already use without an
|
||||
// adapter. Debugf is the only required method; callers with richer
|
||||
// loggers just expose this one shape here.
|
||||
type DebugLogger interface {
|
||||
Debugf(format string, args ...any)
|
||||
}
|
||||
|
||||
// DefaultIdleTimeout is a reasonable default for Options.IdleTimeout. Callers
|
||||
// that want an idle timeout but have no specific preference can use this.
|
||||
const DefaultIdleTimeout = 5 * time.Minute
|
||||
|
||||
// halfCloser is implemented by connections that support half-close
|
||||
// (e.g. *net.TCPConn, *gonet.TCPConn).
|
||||
type halfCloser interface {
|
||||
CloseWrite() error
|
||||
}
|
||||
|
||||
var copyBufPool = sync.Pool{
|
||||
New: func() any {
|
||||
buf := make([]byte, 32*1024)
|
||||
return &buf
|
||||
},
|
||||
}
|
||||
|
||||
// Options configures Relay behavior. The zero value is valid: no idle timeout,
|
||||
// no logging.
|
||||
type Options struct {
|
||||
// IdleTimeout tears down the session if no bytes flow in either
|
||||
// direction within this window. It is a connection-wide watchdog, so a
|
||||
// long unidirectional transfer on one side keeps the other side alive.
|
||||
// Zero disables idle tracking.
|
||||
IdleTimeout time.Duration
|
||||
// Logger receives debug-level copy/idle errors. Nil suppresses logging.
|
||||
// Any logger with Debug/Debugf methods is accepted (logrus.Entry,
|
||||
// uspfilter's nblog.Logger, etc.).
|
||||
Logger DebugLogger
|
||||
}
|
||||
|
||||
// Relay copies bytes in both directions between a and b until both directions
|
||||
// EOF or ctx is canceled. On each direction's EOF it half-closes the
|
||||
// opposite conn's write side (best effort) so the peer sees FIN while the
|
||||
// other direction drains. Both conns are fully closed when Relay returns.
|
||||
//
|
||||
// a and b only need to implement io.ReadWriteCloser; connections that also
|
||||
// implement CloseWrite (e.g. *net.TCPConn, ssh.Channel) get proper half-close
|
||||
// propagation. Options.IdleTimeout, when set, is enforced by a connection-wide
|
||||
// watchdog that tracks reads in either direction.
|
||||
//
|
||||
// Return values are byte counts: aToB (a.Read → b.Write) and bToA (b.Read →
|
||||
// a.Write). Errors are logged via Options.Logger when set; they are not
|
||||
// returned because a relay always terminates on some kind of EOF/cancel.
|
||||
func Relay(ctx context.Context, a, b io.ReadWriteCloser, opts Options) (aToB, bToA int64) {
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
closeDone := make(chan struct{})
|
||||
defer func() {
|
||||
cancel()
|
||||
<-closeDone
|
||||
}()
|
||||
|
||||
go func() {
|
||||
<-ctx.Done()
|
||||
_ = a.Close()
|
||||
_ = b.Close()
|
||||
close(closeDone)
|
||||
}()
|
||||
|
||||
// Both sides must support CloseWrite to propagate half-close. If either
|
||||
// doesn't, a direction's EOF can't be signaled to the peer and the other
|
||||
// direction would block forever waiting for data; in that case we fall
|
||||
// back to the cancel-both-on-first-EOF behavior.
|
||||
_, aHC := a.(halfCloser)
|
||||
_, bHC := b.(halfCloser)
|
||||
halfCloseSupported := aHC && bHC
|
||||
|
||||
var (
|
||||
lastActivity atomic.Int64
|
||||
idleHit atomic.Bool
|
||||
)
|
||||
lastActivity.Store(time.Now().UnixNano())
|
||||
|
||||
if opts.IdleTimeout > 0 {
|
||||
go watchdog(ctx, cancel, &lastActivity, &idleHit, opts.IdleTimeout)
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(2)
|
||||
|
||||
var errAToB, errBToA error
|
||||
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
aToB, errAToB = copyTracked(b, a, &lastActivity)
|
||||
if halfCloseSupported && isCleanEOF(errAToB) {
|
||||
halfClose(b)
|
||||
} else {
|
||||
cancel()
|
||||
}
|
||||
}()
|
||||
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
bToA, errBToA = copyTracked(a, b, &lastActivity)
|
||||
if halfCloseSupported && isCleanEOF(errBToA) {
|
||||
halfClose(a)
|
||||
} else {
|
||||
cancel()
|
||||
}
|
||||
}()
|
||||
|
||||
wg.Wait()
|
||||
|
||||
if opts.Logger != nil {
|
||||
if idleHit.Load() {
|
||||
opts.Logger.Debugf("relay closed due to idle timeout")
|
||||
}
|
||||
if errAToB != nil && !isExpectedCopyError(errAToB) {
|
||||
opts.Logger.Debugf("relay copy error (a→b): %v", errAToB)
|
||||
}
|
||||
if errBToA != nil && !isExpectedCopyError(errBToA) {
|
||||
opts.Logger.Debugf("relay copy error (b→a): %v", errBToA)
|
||||
}
|
||||
}
|
||||
|
||||
return aToB, bToA
|
||||
}
|
||||
|
||||
// watchdog enforces a connection-wide idle timeout. It cancels ctx when no
|
||||
// activity has been seen on either direction for idle. It exits as soon as
|
||||
// ctx is canceled so it doesn't outlive the relay.
|
||||
func watchdog(ctx context.Context, cancel context.CancelFunc, lastActivity *atomic.Int64, idleHit *atomic.Bool, idle time.Duration) {
|
||||
// Cap the tick at 50ms so detection latency stays bounded regardless of
|
||||
// how large idle is, and fall back to idle/2 when that is smaller so
|
||||
// very short timeouts (mainly in tests) are still caught promptly.
|
||||
tick := min(idle/2, 50*time.Millisecond)
|
||||
if tick <= 0 {
|
||||
tick = time.Millisecond
|
||||
}
|
||||
t := time.NewTicker(tick)
|
||||
defer t.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-t.C:
|
||||
last := time.Unix(0, lastActivity.Load())
|
||||
if time.Since(last) >= idle {
|
||||
idleHit.Store(true)
|
||||
cancel()
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// copyTracked copies from src to dst using a pooled buffer, updating
|
||||
// lastActivity on every successful read so a shared watchdog can enforce a
|
||||
// connection-wide idle timeout.
|
||||
func copyTracked(dst io.Writer, src io.Reader, lastActivity *atomic.Int64) (int64, error) {
|
||||
bufp := copyBufPool.Get().(*[]byte)
|
||||
defer copyBufPool.Put(bufp)
|
||||
|
||||
buf := *bufp
|
||||
var total int64
|
||||
for {
|
||||
nr, readErr := src.Read(buf)
|
||||
if nr > 0 {
|
||||
lastActivity.Store(time.Now().UnixNano())
|
||||
n, werr := checkedWrite(dst, buf[:nr])
|
||||
total += n
|
||||
if werr != nil {
|
||||
return total, werr
|
||||
}
|
||||
}
|
||||
if readErr != nil {
|
||||
return total, readErr
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func checkedWrite(dst io.Writer, buf []byte) (int64, error) {
|
||||
nw, err := dst.Write(buf)
|
||||
if nw < 0 || nw > len(buf) {
|
||||
nw = 0
|
||||
}
|
||||
if err != nil {
|
||||
return int64(nw), err
|
||||
}
|
||||
if nw != len(buf) {
|
||||
return int64(nw), io.ErrShortWrite
|
||||
}
|
||||
return int64(nw), nil
|
||||
}
|
||||
|
||||
func halfClose(conn io.ReadWriteCloser) {
|
||||
if hc, ok := conn.(halfCloser); ok {
|
||||
_ = hc.CloseWrite()
|
||||
}
|
||||
}
|
||||
|
||||
// isCleanEOF reports whether a copy terminated on a graceful end-of-stream.
|
||||
// Only in that case is it correct to propagate the EOF via CloseWrite on the
|
||||
// peer; any other error means the flow is broken and both directions should
|
||||
// tear down.
|
||||
func isCleanEOF(err error) bool {
|
||||
return err == nil || errors.Is(err, io.EOF)
|
||||
}
|
||||
|
||||
func isExpectedCopyError(err error) bool {
|
||||
return errors.Is(err, net.ErrClosed) ||
|
||||
errors.Is(err, context.Canceled) ||
|
||||
errors.Is(err, io.EOF) ||
|
||||
errors.Is(err, syscall.ECONNRESET) ||
|
||||
errors.Is(err, syscall.EPIPE) ||
|
||||
errors.Is(err, syscall.ECONNABORTED)
|
||||
}
|
||||
@@ -1,221 +0,0 @@
|
||||
package netrelay
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// tcpPair returns two connected loopback TCP conns.
|
||||
func tcpPair(t *testing.T) (*net.TCPConn, *net.TCPConn) {
|
||||
t.Helper()
|
||||
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
defer ln.Close()
|
||||
|
||||
type result struct {
|
||||
c *net.TCPConn
|
||||
err error
|
||||
}
|
||||
ch := make(chan result, 1)
|
||||
go func() {
|
||||
c, err := ln.Accept()
|
||||
if err != nil {
|
||||
ch <- result{nil, err}
|
||||
return
|
||||
}
|
||||
ch <- result{c.(*net.TCPConn), nil}
|
||||
}()
|
||||
|
||||
dial, err := net.Dial("tcp", ln.Addr().String())
|
||||
require.NoError(t, err)
|
||||
|
||||
r := <-ch
|
||||
require.NoError(t, r.err)
|
||||
return dial.(*net.TCPConn), r.c
|
||||
}
|
||||
|
||||
// TestRelayHalfClose exercises the shutdown(SHUT_WR) scenario that the naive
|
||||
// cancel-both-on-first-EOF pattern breaks. Client A shuts down its write
|
||||
// side; B must still be able to write a full response and A must receive
|
||||
// all of it before its read returns EOF.
|
||||
func TestRelayHalfClose(t *testing.T) {
|
||||
// Real peer pairs for each side of the relay. We relay between relayA
|
||||
// and relayB. Peer A talks through relayA; peer B talks through relayB.
|
||||
peerA, relayA := tcpPair(t)
|
||||
relayB, peerB := tcpPair(t)
|
||||
|
||||
defer peerA.Close()
|
||||
defer peerB.Close()
|
||||
|
||||
// Bound blocking reads/writes so a broken relay fails the test instead of
|
||||
// hanging the test process.
|
||||
deadline := time.Now().Add(5 * time.Second)
|
||||
require.NoError(t, peerA.SetDeadline(deadline))
|
||||
require.NoError(t, peerB.SetDeadline(deadline))
|
||||
|
||||
ctx := t.Context()
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
Relay(ctx, relayA, relayB, Options{})
|
||||
close(done)
|
||||
}()
|
||||
|
||||
// Peer A sends a request, then half-closes its write side.
|
||||
req := []byte("request-payload")
|
||||
_, err := peerA.Write(req)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, peerA.CloseWrite())
|
||||
|
||||
// Peer B reads the request to EOF (FIN must have propagated).
|
||||
got, err := io.ReadAll(peerB)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, req, got)
|
||||
|
||||
// Peer B writes its response; peer A must receive all of it even though
|
||||
// peer A's write side is already closed.
|
||||
resp := make([]byte, 64*1024)
|
||||
for i := range resp {
|
||||
resp[i] = byte(i)
|
||||
}
|
||||
_, err = peerB.Write(resp)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, peerB.Close())
|
||||
|
||||
gotResp, err := io.ReadAll(peerA)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, resp, gotResp)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("relay did not return")
|
||||
}
|
||||
}
|
||||
|
||||
// TestRelayFullDuplex verifies bidirectional copy in the simple case.
|
||||
func TestRelayFullDuplex(t *testing.T) {
|
||||
peerA, relayA := tcpPair(t)
|
||||
relayB, peerB := tcpPair(t)
|
||||
defer peerA.Close()
|
||||
defer peerB.Close()
|
||||
|
||||
// Bound blocking reads/writes so a broken relay fails the test instead of
|
||||
// hanging the test process.
|
||||
deadline := time.Now().Add(5 * time.Second)
|
||||
require.NoError(t, peerA.SetDeadline(deadline))
|
||||
require.NoError(t, peerB.SetDeadline(deadline))
|
||||
|
||||
ctx := t.Context()
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
Relay(ctx, relayA, relayB, Options{})
|
||||
close(done)
|
||||
}()
|
||||
|
||||
type result struct {
|
||||
got []byte
|
||||
err error
|
||||
}
|
||||
resA := make(chan result, 1)
|
||||
resB := make(chan result, 1)
|
||||
|
||||
msgAB := []byte("hello-from-a")
|
||||
msgBA := []byte("hello-from-b")
|
||||
|
||||
go func() {
|
||||
if _, err := peerA.Write(msgAB); err != nil {
|
||||
resA <- result{err: err}
|
||||
return
|
||||
}
|
||||
buf := make([]byte, len(msgBA))
|
||||
_, err := io.ReadFull(peerA, buf)
|
||||
resA <- result{got: buf, err: err}
|
||||
_ = peerA.Close()
|
||||
}()
|
||||
|
||||
go func() {
|
||||
if _, err := peerB.Write(msgBA); err != nil {
|
||||
resB <- result{err: err}
|
||||
return
|
||||
}
|
||||
buf := make([]byte, len(msgAB))
|
||||
_, err := io.ReadFull(peerB, buf)
|
||||
resB <- result{got: buf, err: err}
|
||||
_ = peerB.Close()
|
||||
}()
|
||||
|
||||
a, b := <-resA, <-resB
|
||||
require.NoError(t, a.err)
|
||||
require.Equal(t, msgBA, a.got)
|
||||
require.NoError(t, b.err)
|
||||
require.Equal(t, msgAB, b.got)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("relay did not return")
|
||||
}
|
||||
}
|
||||
|
||||
// TestRelayNoHalfCloseFallback ensures Relay terminates when the underlying
|
||||
// conns don't support CloseWrite (e.g. net.Pipe). Without the fallback to
|
||||
// cancel-both-on-first-EOF, the second direction would block forever.
|
||||
func TestRelayNoHalfCloseFallback(t *testing.T) {
|
||||
a1, a2 := net.Pipe()
|
||||
b1, b2 := net.Pipe()
|
||||
defer a1.Close()
|
||||
defer b1.Close()
|
||||
|
||||
ctx := t.Context()
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
Relay(ctx, a2, b2, Options{})
|
||||
close(done)
|
||||
}()
|
||||
|
||||
// Close peer A's side; a2's Read will return EOF.
|
||||
require.NoError(t, a1.Close())
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("relay did not terminate when half-close is unsupported")
|
||||
}
|
||||
}
|
||||
|
||||
// TestRelayIdleTimeout ensures the idle watchdog tears down a silent flow.
|
||||
func TestRelayIdleTimeout(t *testing.T) {
|
||||
peerA, relayA := tcpPair(t)
|
||||
relayB, peerB := tcpPair(t)
|
||||
defer peerA.Close()
|
||||
defer peerB.Close()
|
||||
|
||||
ctx := t.Context()
|
||||
|
||||
const idle = 150 * time.Millisecond
|
||||
|
||||
start := time.Now()
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
Relay(ctx, relayA, relayB, Options{IdleTimeout: idle})
|
||||
close(done)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("relay did not close on idle")
|
||||
}
|
||||
|
||||
elapsed := time.Since(start)
|
||||
require.GreaterOrEqual(t, elapsed, idle,
|
||||
"relay must not close before the idle timeout elapses")
|
||||
require.Less(t, elapsed, idle+500*time.Millisecond,
|
||||
"relay should close shortly after the idle timeout")
|
||||
}
|
||||
Reference in New Issue
Block a user