Compare commits

...

22 Commits

Author SHA1 Message Date
Viktor Liu
94a8b7325e Filter DNS fallback upstreams matching our server IP to prevent loops 2026-05-17 15:49:31 +02:00
Viktor Liu
3f91f49277 Clean up legacy 32-bit and HKCU registry entries on Windows install (#6176) 2026-05-16 16:52:57 +02:00
Maycon Santos
347c5bf317 Avoid context cancellation in cancelPeerRoutines (#6175)
When closing go routines and handling peer disconnect, we should avoid canceling the flow due to parent gRPC context cancellation.

This change triggers disconnection handling with a context that is not bound to the parent gRPC cancellation.
2026-05-16 16:29:01 +02:00
Viktor Liu
22e2519d71 [management] Avoid peer IP reallocation when account settings update preserves the network range (#6173) 2026-05-16 15:51:48 +02:00
Vlad
e916f12cca [proxy] auth token generation on mapping (#6157)
* [management / proxy] auth token generation on mapping

* fix tests
2026-05-15 19:13:44 +02:00
Viktor Liu
9ed2e2a5b4 [client] Drop DNS probes for passive health projection (#5971) 2026-05-15 17:07:38 +02:00
Viktor Liu
2ccae7ec47 [client] Mirror v4 exit selection onto v6 pair and honour SkipAutoApply per route (#6150) 2026-05-15 16:58:47 +02:00
Viktor Liu
07e5450117 [management] Bracket IPv6 reverse-proxy target hosts when building URL Host field (#6141) 2026-05-14 16:42:40 +02:00
Viktor Liu
3f914090cb [client] Bracket IPv6 in embed listeners, expand debug bundle (#6134) 2026-05-14 16:22:53 +02:00
Viktor Liu
ea9fab4396 [management] Allocate and preserve IPv6 overlay addresses for embedded proxy peers (#6132) 2026-05-14 16:05:33 +02:00
Vlad
77b479286e [management] fix offline statuses for public proxy clusters (#6133) 2026-05-14 13:27:50 +02:00
Maycon Santos
ab2a8794e7 [client] Add short flags for status command options (#6137)
* [client] Add short flags for status command options

* uppercase filters
2026-05-14 12:30:42 +02:00
Viktor Liu
9126a192ca [client] Set 0644 perms on SSH client config after os.CreateTemp (#6126) 2026-05-12 15:05:53 +02:00
Viktor Liu
1224d6e1ee [client] Persist management URL and pre-shared key overrides on login (#6065) 2026-05-12 14:52:56 +02:00
Nicolas Frati
96672dd1f8 [management] chores: update dex version (#6124)
* chores: update dex version

* chore: update dex fork
2026-05-12 13:50:35 +02:00
Viktor Liu
946ce4c3da [client] Fix --config flag default to point at profile path (#6122) 2026-05-11 17:48:21 +02:00
Vlad
07cbfdbede [proxy] feature: bring your own proxy (#5627) 2026-05-11 14:31:38 +02:00
Viktor Liu
a4114a5e45 [client] Skip DNS upstream failover on definitive EDE (#6089) 2026-05-11 10:00:23 +02:00
Viktor Liu
6b08e89c7b [relay] Preserve non-standard port in WS dialer URL prep (#6061) 2026-05-11 09:59:33 +02:00
Viktor Liu
a852b3bd34 [client, proxy] Harden uspfilter conntrack and share TCP relay (#5936) 2026-05-11 09:59:13 +02:00
Viktor Liu
afb83b3049 [client] Use unique temp file and clean up on failure when writing ssh config (#6064) 2026-05-11 09:58:49 +02:00
Nicolas Frati
e89aad09f5 [management] Enable MFA for local users (#5804)
* wip: totp for local users

* fix providers not getting populated

* polished UI and fix post_login_redirect_uri

* fix: make sure logout is only prompted from oidc flow

Signed-off-by: jnfrati <nicofrati@gmail.com>

* update templates

Signed-off-by: jnfrati <nicofrati@gmail.com>

* deps: update dex dependency

Signed-off-by: jnfrati <nicofrati@gmail.com>

* fix qube issues

Signed-off-by: jnfrati <nicofrati@gmail.com>

* replace window with globalThis on home html

Signed-off-by: jnfrati <nicofrati@gmail.com>

* fixed coderabbit comments

Signed-off-by: jnfrati <nicofrati@gmail.com>

* debug

* remove unused config and rename totp issuer

* deps: update dex reference to latest

* add dashboard post logout redirect uri to embedded config

* implemented api for mfa configuration

* update docs and config parsing

* catch error on idp manager init mfa

* fix tests

* Add remember me  for MFA

* Add cookie encryption and session share between tabs

* fixed logout showing non actionable error and session cookie encription key

* fixed missing mfa settings on sql query for account

* fix code index for mfa activity

---------

Signed-off-by: jnfrati <nicofrati@gmail.com>
Co-authored-by: braginini <bangvalo@gmail.com>
2026-05-08 16:31:20 +02:00
127 changed files with 8519 additions and 1768 deletions

View File

@@ -19,7 +19,7 @@ jobs:
- name: codespell
uses: codespell-project/actions-codespell@v2
with:
ignore_words_list: erro,clienta,hastable,iif,groupd,testin,groupe,cros,ans,deriver,te,userA
ignore_words_list: erro,clienta,hastable,iif,groupd,testin,groupe,cros,ans,deriver,te,userA,ede,additionals
skip: go.mod,go.sum,**/proxy/web/**
golangci:
strategy:

View File

@@ -143,7 +143,7 @@ func init() {
rootCmd.PersistentFlags().StringVar(&preSharedKey, preSharedKeyFlag, "", "Sets WireGuard PreSharedKey property. If set, then only peers that have the same key can communicate.")
rootCmd.PersistentFlags().StringVarP(&hostName, "hostname", "n", "", "Sets a custom hostname for the device")
rootCmd.PersistentFlags().BoolVarP(&anonymizeFlag, "anonymize", "A", false, "anonymize IP addresses and non-netbird.io domains in logs and status output")
rootCmd.PersistentFlags().StringVarP(&configPath, "config", "c", defaultConfigPath, "Overrides the default profile file location")
rootCmd.PersistentFlags().StringVarP(&configPath, "config", "c", profilemanager.DefaultConfigPath, "Overrides the default profile file location")
rootCmd.AddCommand(upCmd)
rootCmd.AddCommand(downCmd)

View File

@@ -43,16 +43,16 @@ func init() {
ipsFilterMap = make(map[string]struct{})
prefixNamesFilterMap = make(map[string]struct{})
statusCmd.PersistentFlags().BoolVarP(&detailFlag, "detail", "d", false, "display detailed status information in human-readable format")
statusCmd.PersistentFlags().BoolVar(&jsonFlag, "json", false, "display detailed status information in json format")
statusCmd.PersistentFlags().BoolVar(&yamlFlag, "yaml", false, "display detailed status information in yaml format")
statusCmd.PersistentFlags().BoolVar(&ipv4Flag, "ipv4", false, "display only NetBird IPv4 of this peer, e.g., --ipv4 will output 100.64.0.33")
statusCmd.PersistentFlags().BoolVar(&ipv6Flag, "ipv6", false, "display only NetBird IPv6 of this peer")
statusCmd.PersistentFlags().BoolVarP(&jsonFlag, "json", "j", false, "display detailed status information in json format")
statusCmd.PersistentFlags().BoolVarP(&yamlFlag, "yaml", "y", false, "display detailed status information in yaml format")
statusCmd.PersistentFlags().BoolVarP(&ipv4Flag, "ipv4", "4", false, "display only NetBird IPv4 of this peer, e.g., --ipv4 will output 100.64.0.33")
statusCmd.PersistentFlags().BoolVarP(&ipv6Flag, "ipv6", "6", false, "display only NetBird IPv6 of this peer")
statusCmd.MarkFlagsMutuallyExclusive("detail", "json", "yaml", "ipv4", "ipv6")
statusCmd.PersistentFlags().StringSliceVar(&ipsFilter, "filter-by-ips", []string{}, "filters the detailed output by a list of one or more IPs (v4 or v6), e.g., --filter-by-ips 100.64.0.100,fd00::1")
statusCmd.PersistentFlags().StringSliceVar(&prefixNamesFilter, "filter-by-names", []string{}, "filters the detailed output by a list of one or more peer FQDN or hostnames, e.g., --filter-by-names peer-a,peer-b.netbird.cloud")
statusCmd.PersistentFlags().StringVar(&statusFilter, "filter-by-status", "", "filters the detailed output by connection status(idle|connecting|connected), e.g., --filter-by-status connected")
statusCmd.PersistentFlags().StringVar(&connectionTypeFilter, "filter-by-connection-type", "", "filters the detailed output by connection type (P2P|Relayed), e.g., --filter-by-connection-type P2P")
statusCmd.PersistentFlags().StringVar(&checkFlag, "check", "", "run a health check and exit with code 0 on success, 1 on failure (live|ready|startup)")
statusCmd.PersistentFlags().StringSliceVarP(&ipsFilter, "filter-by-ips", "I", []string{}, "filters the detailed output by a list of one or more IPs (v4 or v6), e.g., --filter-by-ips 100.64.0.100,fd00::1")
statusCmd.PersistentFlags().StringSliceVarP(&prefixNamesFilter, "filter-by-names", "N", []string{}, "filters the detailed output by a list of one or more peer FQDN or hostnames, e.g., --filter-by-names peer-a,peer-b.netbird.cloud")
statusCmd.PersistentFlags().StringVarP(&statusFilter, "filter-by-status", "S", "", "filters the detailed output by connection status(idle|connecting|connected), e.g., --filter-by-status connected")
statusCmd.PersistentFlags().StringVarP(&connectionTypeFilter, "filter-by-connection-type", "T", "", "filters the detailed output by connection type (P2P|Relayed), e.g., --filter-by-connection-type P2P")
statusCmd.PersistentFlags().StringVarP(&checkFlag, "check", "C", "", "run a health check and exit with code 0 on success, 1 on failure (live|ready|startup)")
}
func statusFunc(cmd *cobra.Command, args []string) error {

View File

@@ -336,7 +336,7 @@ func (c *Client) ListenTCP(address string) (net.Listener, error) {
if err != nil {
return nil, fmt.Errorf("split host port: %w", err)
}
listenAddr := fmt.Sprintf("%s:%s", addr, port)
listenAddr := net.JoinHostPort(addr.String(), port)
tcpAddr, err := net.ResolveTCPAddr("tcp", listenAddr)
if err != nil {
@@ -357,7 +357,7 @@ func (c *Client) ListenUDP(address string) (net.PacketConn, error) {
if err != nil {
return nil, fmt.Errorf("split host port: %w", err)
}
listenAddr := fmt.Sprintf("%s:%s", addr, port)
listenAddr := net.JoinHostPort(addr.String(), port)
udpAddr, err := net.ResolveUDPAddr("udp", listenAddr)
if err != nil {

View File

@@ -0,0 +1,125 @@
package conntrack
import (
"net/netip"
"testing"
"github.com/google/gopacket/layers"
"github.com/stretchr/testify/require"
)
func TestTCPCapEvicts(t *testing.T) {
t.Setenv(EnvTCPMaxEntries, "4")
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
defer tracker.Close()
require.Equal(t, 4, tracker.maxEntries)
src := netip.MustParseAddr("100.64.0.1")
dst := netip.MustParseAddr("100.64.0.2")
for i := 0; i < 10; i++ {
tracker.TrackOutbound(src, dst, uint16(10000+i), 80, TCPSyn, 0)
}
require.LessOrEqual(t, len(tracker.connections), 4,
"TCP table must not exceed the configured cap")
require.Greater(t, len(tracker.connections), 0,
"some entries must remain after eviction")
// The most recently admitted flow must be present: eviction must make
// room for new entries, not silently drop them.
require.Contains(t, tracker.connections,
ConnKey{SrcIP: src, DstIP: dst, SrcPort: uint16(10009), DstPort: 80},
"newest TCP flow must be admitted after eviction")
// A pre-cap flow must have been evicted to fit the last one.
require.NotContains(t, tracker.connections,
ConnKey{SrcIP: src, DstIP: dst, SrcPort: uint16(10000), DstPort: 80},
"oldest TCP flow should have been evicted")
}
func TestTCPCapPrefersTombstonedForEviction(t *testing.T) {
t.Setenv(EnvTCPMaxEntries, "3")
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
defer tracker.Close()
src := netip.MustParseAddr("100.64.0.1")
dst := netip.MustParseAddr("100.64.0.2")
// Fill to cap with 3 live connections.
for i := 0; i < 3; i++ {
tracker.TrackOutbound(src, dst, uint16(20000+i), 80, TCPSyn, 0)
}
require.Len(t, tracker.connections, 3)
// Tombstone one by sending RST through IsValidInbound.
tombstonedKey := ConnKey{SrcIP: src, DstIP: dst, SrcPort: 20001, DstPort: 80}
require.True(t, tracker.IsValidInbound(dst, src, 80, 20001, TCPRst|TCPAck, 0))
require.True(t, tracker.connections[tombstonedKey].IsTombstone())
// Another live connection forces eviction. The tombstone must go first.
tracker.TrackOutbound(src, dst, uint16(29999), 80, TCPSyn, 0)
_, tombstonedStillPresent := tracker.connections[tombstonedKey]
require.False(t, tombstonedStillPresent,
"tombstoned entry should be evicted before live entries")
require.LessOrEqual(t, len(tracker.connections), 3)
// Both live pre-cap entries must survive: eviction must prefer the
// tombstone, not just satisfy the size bound by dropping any entry.
require.Contains(t, tracker.connections,
ConnKey{SrcIP: src, DstIP: dst, SrcPort: uint16(20000), DstPort: 80},
"live entries must not be evicted while a tombstone exists")
require.Contains(t, tracker.connections,
ConnKey{SrcIP: src, DstIP: dst, SrcPort: uint16(20002), DstPort: 80},
"live entries must not be evicted while a tombstone exists")
}
func TestUDPCapEvicts(t *testing.T) {
t.Setenv(EnvUDPMaxEntries, "5")
tracker := NewUDPTracker(DefaultUDPTimeout, logger, flowLogger)
defer tracker.Close()
require.Equal(t, 5, tracker.maxEntries)
src := netip.MustParseAddr("100.64.0.1")
dst := netip.MustParseAddr("100.64.0.2")
for i := 0; i < 12; i++ {
tracker.TrackOutbound(src, dst, uint16(30000+i), 53, 0)
}
require.LessOrEqual(t, len(tracker.connections), 5)
require.Greater(t, len(tracker.connections), 0)
require.Contains(t, tracker.connections,
ConnKey{SrcIP: src, DstIP: dst, SrcPort: uint16(30011), DstPort: 53},
"newest UDP flow must be admitted after eviction")
require.NotContains(t, tracker.connections,
ConnKey{SrcIP: src, DstIP: dst, SrcPort: uint16(30000), DstPort: 53},
"oldest UDP flow should have been evicted")
}
func TestICMPCapEvicts(t *testing.T) {
t.Setenv(EnvICMPMaxEntries, "3")
tracker := NewICMPTracker(DefaultICMPTimeout, logger, flowLogger)
defer tracker.Close()
require.Equal(t, 3, tracker.maxEntries)
src := netip.MustParseAddr("100.64.0.1")
dst := netip.MustParseAddr("100.64.0.2")
echoReq := layers.CreateICMPv4TypeCode(uint8(layers.ICMPv4TypeEchoRequest), 0)
for i := 0; i < 8; i++ {
tracker.TrackOutbound(src, dst, uint16(i), echoReq, nil, 64)
}
require.LessOrEqual(t, len(tracker.connections), 3)
require.Greater(t, len(tracker.connections), 0)
require.Contains(t, tracker.connections,
ICMPConnKey{SrcIP: src, DstIP: dst, ID: uint16(7)},
"newest ICMP flow must be admitted after eviction")
require.NotContains(t, tracker.connections,
ICMPConnKey{SrcIP: src, DstIP: dst, ID: uint16(0)},
"oldest ICMP flow should have been evicted")
}

View File

@@ -3,15 +3,61 @@ package conntrack
import (
"net"
"net/netip"
"os"
"strconv"
"sync/atomic"
"time"
"github.com/google/uuid"
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
)
// evictSampleSize bounds how many map entries we scan per eviction call.
// Keeps eviction O(1) even at cap under sustained load; the sampled-LRU
// heuristic is good enough for a conntrack table that only overflows under
// abuse.
const evictSampleSize = 8
// envDuration parses an os.Getenv(name) as a time.Duration. Falls back to
// def on empty or invalid; logs a warning on invalid.
func envDuration(logger *nblog.Logger, name string, def time.Duration) time.Duration {
v := os.Getenv(name)
if v == "" {
return def
}
d, err := time.ParseDuration(v)
if err != nil {
logger.Warn3("invalid %s=%q: %v, using default", name, v, err)
return def
}
if d <= 0 {
logger.Warn2("invalid %s=%q: must be positive, using default", name, v)
return def
}
return d
}
// envInt parses an os.Getenv(name) as an int. Falls back to def on empty,
// invalid, or non-positive. Logs a warning on invalid input.
func envInt(logger *nblog.Logger, name string, def int) int {
v := os.Getenv(name)
if v == "" {
return def
}
n, err := strconv.Atoi(v)
switch {
case err != nil:
logger.Warn3("invalid %s=%q: %v, using default", name, v, err)
return def
case n <= 0:
logger.Warn2("invalid %s=%q: must be positive, using default", name, v)
return def
}
return n
}
// BaseConnTrack provides common fields and locking for all connection types
type BaseConnTrack struct {
FlowId uuid.UUID

View File

@@ -0,0 +1,11 @@
//go:build !ios && !android
package conntrack
// Default per-tracker entry caps on desktop/server platforms. These mirror
// typical Linux netfilter nf_conntrack_max territory with ample headroom.
const (
DefaultMaxTCPEntries = 65536
DefaultMaxUDPEntries = 16384
DefaultMaxICMPEntries = 2048
)

View File

@@ -0,0 +1,13 @@
//go:build ios || android
package conntrack
// Default per-tracker entry caps on mobile platforms. iOS network extensions
// are capped at ~50 MB; Android runs under aggressive memory pressure. These
// values keep conntrack footprint well under 5 MB worst case (TCPConnTrack
// is ~200 B plus map overhead).
const (
DefaultMaxTCPEntries = 4096
DefaultMaxUDPEntries = 2048
DefaultMaxICMPEntries = 512
)

View File

@@ -50,6 +50,9 @@ type ICMPConnTrack struct {
ICMPCode uint8
}
// EnvICMPMaxEntries caps the ICMP conntrack table size.
const EnvICMPMaxEntries = "NB_CONNTRACK_ICMP_MAX"
// ICMPTracker manages ICMP connection states
type ICMPTracker struct {
logger *nblog.Logger
@@ -58,6 +61,7 @@ type ICMPTracker struct {
cleanupTicker *time.Ticker
tickerCancel context.CancelFunc
mutex sync.RWMutex
maxEntries int
flowLogger nftypes.FlowLogger
}
@@ -171,6 +175,7 @@ func NewICMPTracker(timeout time.Duration, logger *nblog.Logger, flowLogger nfty
timeout: timeout,
cleanupTicker: time.NewTicker(ICMPCleanupInterval),
tickerCancel: cancel,
maxEntries: envInt(logger, EnvICMPMaxEntries, DefaultMaxICMPEntries),
flowLogger: flowLogger,
}
@@ -257,7 +262,9 @@ func (t *ICMPTracker) track(
// non echo requests don't need tracking
if typ != uint8(layers.ICMPv4TypeEchoRequest) {
t.logger.Trace3("New %s ICMP connection %s - %s", direction, key, icmpInfo)
if t.logger.Enabled(nblog.LevelTrace) {
t.logger.Trace3("New %s ICMP connection %s - %s", direction, key, icmpInfo)
}
t.sendStartEvent(direction, srcIP, dstIP, typ, code, ruleId, size)
return
}
@@ -276,10 +283,15 @@ func (t *ICMPTracker) track(
conn.UpdateCounters(direction, size)
t.mutex.Lock()
if t.maxEntries > 0 && len(t.connections) >= t.maxEntries {
t.evictOneLocked()
}
t.connections[key] = conn
t.mutex.Unlock()
t.logger.Trace3("New %s ICMP connection %s - %s", direction, key, icmpInfo)
if t.logger.Enabled(nblog.LevelTrace) {
t.logger.Trace3("New %s ICMP connection %s - %s", direction, key, icmpInfo)
}
t.sendEvent(nftypes.TypeStart, conn, ruleId)
}
@@ -323,6 +335,34 @@ func (t *ICMPTracker) cleanupRoutine(ctx context.Context) {
}
}
// evictOneLocked removes one entry to make room. Caller must hold t.mutex.
// Bounded sample scan: picks the oldest among up to evictSampleSize entries.
func (t *ICMPTracker) evictOneLocked() {
var candKey ICMPConnKey
var candSeen int64
haveCand := false
sampled := 0
for k, c := range t.connections {
seen := c.lastSeen.Load()
if !haveCand || seen < candSeen {
candKey = k
candSeen = seen
haveCand = true
}
sampled++
if sampled >= evictSampleSize {
break
}
}
if haveCand {
if evicted := t.connections[candKey]; evicted != nil {
t.sendEvent(nftypes.TypeEnd, evicted, nil)
}
delete(t.connections, candKey)
}
}
func (t *ICMPTracker) cleanup() {
t.mutex.Lock()
defer t.mutex.Unlock()
@@ -331,8 +371,10 @@ func (t *ICMPTracker) cleanup() {
if conn.timeoutExceeded(t.timeout) {
delete(t.connections, key)
t.logger.Trace5("Removed ICMP connection %s (timeout) [in: %d Pkts/%d B out: %d Pkts/%d B]",
key, conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load())
if t.logger.Enabled(nblog.LevelTrace) {
t.logger.Trace5("Removed ICMP connection %s (timeout) [in: %d Pkts/%d B out: %d Pkts/%d B]",
key, conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load())
}
t.sendEvent(nftypes.TypeEnd, conn, nil)
}
}

View File

@@ -38,6 +38,27 @@ const (
TCPHandshakeTimeout = 60 * time.Second
// TCPCleanupInterval is how often we check for stale connections
TCPCleanupInterval = 5 * time.Minute
// FinWaitTimeout bounds FIN_WAIT_1 / FIN_WAIT_2 / CLOSING states.
// Matches Linux netfilter nf_conntrack_tcp_timeout_fin_wait.
FinWaitTimeout = 60 * time.Second
// CloseWaitTimeout bounds CLOSE_WAIT. Matches Linux default; apps
// holding CloseWait longer than this should bump the env var.
CloseWaitTimeout = 60 * time.Second
// LastAckTimeout bounds LAST_ACK. Matches Linux default.
LastAckTimeout = 30 * time.Second
)
// Env vars to override per-state teardown timeouts. Values parsed by
// time.ParseDuration (e.g. "120s", "2m"). Invalid values fall back to the
// defaults above with a warning.
const (
EnvTCPFinWaitTimeout = "NB_CONNTRACK_TCP_FIN_WAIT_TIMEOUT"
EnvTCPCloseWaitTimeout = "NB_CONNTRACK_TCP_CLOSE_WAIT_TIMEOUT"
EnvTCPLastAckTimeout = "NB_CONNTRACK_TCP_LAST_ACK_TIMEOUT"
// EnvTCPMaxEntries caps the TCP conntrack table size. Oldest entries
// (tombstones first) are evicted when the cap is reached.
EnvTCPMaxEntries = "NB_CONNTRACK_TCP_MAX"
)
// TCPState represents the state of a TCP connection
@@ -133,14 +154,18 @@ func (t *TCPConnTrack) SetTombstone() {
// TCPTracker manages TCP connection states
type TCPTracker struct {
logger *nblog.Logger
connections map[ConnKey]*TCPConnTrack
mutex sync.RWMutex
cleanupTicker *time.Ticker
tickerCancel context.CancelFunc
timeout time.Duration
waitTimeout time.Duration
flowLogger nftypes.FlowLogger
logger *nblog.Logger
connections map[ConnKey]*TCPConnTrack
mutex sync.RWMutex
cleanupTicker *time.Ticker
tickerCancel context.CancelFunc
timeout time.Duration
waitTimeout time.Duration
finWaitTimeout time.Duration
closeWaitTimeout time.Duration
lastAckTimeout time.Duration
maxEntries int
flowLogger nftypes.FlowLogger
}
// NewTCPTracker creates a new TCP connection tracker
@@ -155,13 +180,17 @@ func NewTCPTracker(timeout time.Duration, logger *nblog.Logger, flowLogger nftyp
ctx, cancel := context.WithCancel(context.Background())
tracker := &TCPTracker{
logger: logger,
connections: make(map[ConnKey]*TCPConnTrack),
cleanupTicker: time.NewTicker(TCPCleanupInterval),
tickerCancel: cancel,
timeout: timeout,
waitTimeout: waitTimeout,
flowLogger: flowLogger,
logger: logger,
connections: make(map[ConnKey]*TCPConnTrack),
cleanupTicker: time.NewTicker(TCPCleanupInterval),
tickerCancel: cancel,
timeout: timeout,
waitTimeout: waitTimeout,
finWaitTimeout: envDuration(logger, EnvTCPFinWaitTimeout, FinWaitTimeout),
closeWaitTimeout: envDuration(logger, EnvTCPCloseWaitTimeout, CloseWaitTimeout),
lastAckTimeout: envDuration(logger, EnvTCPLastAckTimeout, LastAckTimeout),
maxEntries: envInt(logger, EnvTCPMaxEntries, DefaultMaxTCPEntries),
flowLogger: flowLogger,
}
go tracker.cleanupRoutine(ctx)
@@ -209,6 +238,12 @@ func (t *TCPTracker) track(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, fla
if exists || flags&TCPSyn == 0 {
return
}
// Reject illegal SYN combinations (SYN+FIN, SYN+RST, …) so they don't
// create spurious conntrack entries. Not mandated by RFC 9293 but a
// common hardening (Linux netfilter/nftables rejects these too).
if !isValidFlagCombination(flags) {
return
}
conn := &TCPConnTrack{
BaseConnTrack: BaseConnTrack{
@@ -225,20 +260,65 @@ func (t *TCPTracker) track(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, fla
conn.state.Store(int32(TCPStateNew))
conn.DNATOrigPort.Store(uint32(origPort))
if origPort != 0 {
t.logger.Trace4("New %s TCP connection: %s (port DNAT %d -> %d)", direction, key, origPort, dstPort)
} else {
t.logger.Trace2("New %s TCP connection: %s", direction, key)
if t.logger.Enabled(nblog.LevelTrace) {
if origPort != 0 {
t.logger.Trace4("New %s TCP connection: %s (port DNAT %d -> %d)", direction, key, origPort, dstPort)
} else {
t.logger.Trace2("New %s TCP connection: %s", direction, key)
}
}
t.updateState(key, conn, flags, direction, size)
t.mutex.Lock()
if t.maxEntries > 0 && len(t.connections) >= t.maxEntries {
t.evictOneLocked()
}
t.connections[key] = conn
t.mutex.Unlock()
t.sendEvent(nftypes.TypeStart, conn, ruleID)
}
// evictOneLocked removes one entry to make room. Caller must hold t.mutex.
// Bounded scan: samples up to evictSampleSize pseudo-random entries (Go map
// iteration order is randomized), preferring a tombstone. If no tombstone
// found in the sample, evicts the oldest among the sampled entries. O(1)
// worst case — cheap enough to run on every insert at cap during abuse.
func (t *TCPTracker) evictOneLocked() {
var candKey ConnKey
var candSeen int64
haveCand := false
sampled := 0
for k, c := range t.connections {
if c.IsTombstone() {
delete(t.connections, k)
return
}
seen := c.lastSeen.Load()
if !haveCand || seen < candSeen {
candKey = k
candSeen = seen
haveCand = true
}
sampled++
if sampled >= evictSampleSize {
break
}
}
if haveCand {
if evicted := t.connections[candKey]; evicted != nil {
// TypeEnd is already emitted at the state transition to
// TimeWait and when a connection is tombstoned. Only emit
// here when we're reaping a still-active flow.
if evicted.GetState() != TCPStateTimeWait && !evicted.IsTombstone() {
t.sendEvent(nftypes.TypeEnd, evicted, nil)
}
}
delete(t.connections, candKey)
}
}
// IsValidInbound checks if an inbound TCP packet matches a tracked connection
func (t *TCPTracker) IsValidInbound(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, size int) bool {
key := ConnKey{
@@ -256,12 +336,19 @@ func (t *TCPTracker) IsValidInbound(srcIP, dstIP netip.Addr, srcPort, dstPort ui
return false
}
// Reject illegal flag combinations regardless of state. These never belong
// to a legitimate flow and must not advance or tear down state.
if !isValidFlagCombination(flags) {
if t.logger.Enabled(nblog.LevelWarn) {
t.logger.Warn3("TCP illegal flag combination %x for connection %s (state %s)", flags, key, conn.GetState())
}
return false
}
currentState := conn.GetState()
if !t.isValidStateForFlags(currentState, flags) {
t.logger.Warn3("TCP state %s is not valid with flags %x for connection %s", currentState, flags, key)
// allow all flags for established for now
if currentState == TCPStateEstablished {
return true
if t.logger.Enabled(nblog.LevelWarn) {
t.logger.Warn3("TCP state %s is not valid with flags %x for connection %s", currentState, flags, key)
}
return false
}
@@ -270,116 +357,208 @@ func (t *TCPTracker) IsValidInbound(srcIP, dstIP netip.Addr, srcPort, dstPort ui
return true
}
// updateState updates the TCP connection state based on flags
// updateState updates the TCP connection state based on flags.
func (t *TCPTracker) updateState(key ConnKey, conn *TCPConnTrack, flags uint8, packetDir nftypes.Direction, size int) {
conn.UpdateLastSeen()
conn.UpdateCounters(packetDir, size)
// Malformed flag combinations must not refresh lastSeen or drive state,
// otherwise spoofed packets keep a dead flow alive past its timeout.
if !isValidFlagCombination(flags) {
return
}
conn.UpdateLastSeen()
currentState := conn.GetState()
if flags&TCPRst != 0 {
if conn.CompareAndSwapState(currentState, TCPStateClosed) {
conn.SetTombstone()
t.logger.Trace6("TCP connection reset: %s (dir: %s) [in: %d Pkts/%d B, out: %d Pkts/%d B]",
key, packetDir, conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load())
t.sendEvent(nftypes.TypeEnd, conn, nil)
}
// Hardening beyond RFC 9293 §3.10.7.4: without sequence tracking we
// cannot apply the RFC 5961 in-window RST check, so we conservatively
// reject RSTs that the spec would accept (TIME-WAIT with in-window
// SEQ, SynSent from same direction as own SYN, etc.).
t.handleRst(key, conn, currentState, packetDir)
return
}
var newState TCPState
switch currentState {
case TCPStateNew:
if flags&TCPSyn != 0 && flags&TCPAck == 0 {
if conn.Direction == nftypes.Egress {
newState = TCPStateSynSent
} else {
newState = TCPStateSynReceived
}
}
newState := nextState(currentState, conn.Direction, packetDir, flags)
if newState == 0 || !conn.CompareAndSwapState(currentState, newState) {
return
}
t.onTransition(key, conn, currentState, newState, packetDir)
}
case TCPStateSynSent:
if flags&TCPSyn != 0 && flags&TCPAck != 0 {
if packetDir != conn.Direction {
newState = TCPStateEstablished
} else {
// Simultaneous open
newState = TCPStateSynReceived
}
}
// handleRst processes a RST segment. Late RSTs in TimeWait and spoofed RSTs
// from the SYN direction are ignored; otherwise the flow is tombstoned.
func (t *TCPTracker) handleRst(key ConnKey, conn *TCPConnTrack, currentState TCPState, packetDir nftypes.Direction) {
// TimeWait exists to absorb late segments; don't let a late RST
// tombstone the entry and break same-4-tuple reuse.
if currentState == TCPStateTimeWait {
return
}
// A RST from the same direction as the SYN cannot be a legitimate
// response and must not tear down a half-open connection.
if currentState == TCPStateSynSent && packetDir == conn.Direction {
return
}
if !conn.CompareAndSwapState(currentState, TCPStateClosed) {
return
}
conn.SetTombstone()
if t.logger.Enabled(nblog.LevelTrace) {
t.logger.Trace6("TCP connection reset: %s (dir: %s) [in: %d Pkts/%d B, out: %d Pkts/%d B]",
key, packetDir, conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load())
}
t.sendEvent(nftypes.TypeEnd, conn, nil)
}
case TCPStateSynReceived:
if flags&TCPAck != 0 && flags&TCPSyn == 0 {
if packetDir == conn.Direction {
newState = TCPStateEstablished
}
}
// stateTransition describes one state's transition logic. It receives the
// packet's flags plus whether the packet direction matches the connection's
// origin direction (same=true means same side as the SYN initiator). Return 0
// for no transition.
type stateTransition func(flags uint8, connDir nftypes.Direction, same bool) TCPState
case TCPStateEstablished:
if flags&TCPFin != 0 {
if packetDir == conn.Direction {
newState = TCPStateFinWait1
} else {
newState = TCPStateCloseWait
}
}
// stateTable maps each state to its transition function. Centralized here so
// nextState stays trivial and each rule is easy to read in isolation.
var stateTable = map[TCPState]stateTransition{
TCPStateNew: transNew,
TCPStateSynSent: transSynSent,
TCPStateSynReceived: transSynReceived,
TCPStateEstablished: transEstablished,
TCPStateFinWait1: transFinWait1,
TCPStateFinWait2: transFinWait2,
TCPStateClosing: transClosing,
TCPStateCloseWait: transCloseWait,
TCPStateLastAck: transLastAck,
}
case TCPStateFinWait1:
if packetDir != conn.Direction {
switch {
case flags&TCPFin != 0 && flags&TCPAck != 0:
newState = TCPStateClosing
case flags&TCPFin != 0:
newState = TCPStateClosing
case flags&TCPAck != 0:
newState = TCPStateFinWait2
}
}
// nextState returns the target TCP state for the given current state and
// packet, or 0 if the packet does not trigger a transition.
func nextState(currentState TCPState, connDir, packetDir nftypes.Direction, flags uint8) TCPState {
fn, ok := stateTable[currentState]
if !ok {
return 0
}
return fn(flags, connDir, packetDir == connDir)
}
case TCPStateFinWait2:
if flags&TCPFin != 0 {
newState = TCPStateTimeWait
func transNew(flags uint8, connDir nftypes.Direction, _ bool) TCPState {
if flags&TCPSyn != 0 && flags&TCPAck == 0 {
if connDir == nftypes.Egress {
return TCPStateSynSent
}
return TCPStateSynReceived
}
return 0
}
case TCPStateClosing:
if flags&TCPAck != 0 {
newState = TCPStateTimeWait
func transSynSent(flags uint8, _ nftypes.Direction, same bool) TCPState {
if flags&TCPSyn != 0 && flags&TCPAck != 0 {
if same {
return TCPStateSynReceived // simultaneous open
}
return TCPStateEstablished
}
return 0
}
case TCPStateCloseWait:
if flags&TCPFin != 0 {
newState = TCPStateLastAck
}
func transSynReceived(flags uint8, _ nftypes.Direction, same bool) TCPState {
if flags&TCPAck != 0 && flags&TCPSyn == 0 && same {
return TCPStateEstablished
}
return 0
}
case TCPStateLastAck:
if flags&TCPAck != 0 {
newState = TCPStateClosed
}
func transEstablished(flags uint8, _ nftypes.Direction, same bool) TCPState {
if flags&TCPFin == 0 {
return 0
}
if same {
return TCPStateFinWait1
}
return TCPStateCloseWait
}
// transFinWait1 handles the active-close peer response. A FIN carrying our
// ACK piggybacked goes straight to TIME-WAIT (RFC 9293 §3.10.7.4, FIN-WAIT-1:
// "if our FIN has been ACKed... enter the TIME-WAIT state"); a lone FIN moves
// to CLOSING; a pure ACK of our FIN moves to FIN-WAIT-2.
func transFinWait1(flags uint8, _ nftypes.Direction, same bool) TCPState {
if same {
return 0
}
if flags&TCPFin != 0 && flags&TCPAck != 0 {
return TCPStateTimeWait
}
switch {
case flags&TCPFin != 0:
return TCPStateClosing
case flags&TCPAck != 0:
return TCPStateFinWait2
}
return 0
}
// transFinWait2 ignores own-side FIN retransmits; only the peer's FIN advances.
func transFinWait2(flags uint8, _ nftypes.Direction, same bool) TCPState {
if flags&TCPFin != 0 && !same {
return TCPStateTimeWait
}
return 0
}
// transClosing completes a simultaneous close on the peer's ACK.
func transClosing(flags uint8, _ nftypes.Direction, same bool) TCPState {
if flags&TCPAck != 0 && !same {
return TCPStateTimeWait
}
return 0
}
// transCloseWait only advances to LastAck when WE send FIN, ignoring peer retransmits.
func transCloseWait(flags uint8, _ nftypes.Direction, same bool) TCPState {
if flags&TCPFin != 0 && same {
return TCPStateLastAck
}
return 0
}
// transLastAck closes the flow only on the peer's ACK (not our own ACK retransmits).
func transLastAck(flags uint8, _ nftypes.Direction, same bool) TCPState {
if flags&TCPAck != 0 && !same {
return TCPStateClosed
}
return 0
}
// onTransition handles logging and flow-event emission after a successful
// state transition. TimeWait and Closed are terminal for flow accounting.
func (t *TCPTracker) onTransition(key ConnKey, conn *TCPConnTrack, from, to TCPState, packetDir nftypes.Direction) {
traceOn := t.logger.Enabled(nblog.LevelTrace)
if traceOn {
t.logger.Trace4("TCP connection %s transitioned from %s to %s (dir: %s)", key, from, to, packetDir)
}
if newState != 0 && conn.CompareAndSwapState(currentState, newState) {
t.logger.Trace4("TCP connection %s transitioned from %s to %s (dir: %s)", key, currentState, newState, packetDir)
switch newState {
case TCPStateTimeWait:
switch to {
case TCPStateTimeWait:
if traceOn {
t.logger.Trace5("TCP connection %s completed [in: %d Pkts/%d B, out: %d Pkts/%d B]",
key, conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load())
t.sendEvent(nftypes.TypeEnd, conn, nil)
case TCPStateClosed:
conn.SetTombstone()
}
t.sendEvent(nftypes.TypeEnd, conn, nil)
case TCPStateClosed:
conn.SetTombstone()
if traceOn {
t.logger.Trace5("TCP connection %s closed gracefully [in: %d Pkts/%d, B out: %d Pkts/%d B]",
key, conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load())
t.sendEvent(nftypes.TypeEnd, conn, nil)
}
t.sendEvent(nftypes.TypeEnd, conn, nil)
}
}
// isValidStateForFlags checks if the TCP flags are valid for the current connection state
// isValidStateForFlags checks if the TCP flags are valid for the current
// connection state. Caller must have already verified the flag combination is
// legal via isValidFlagCombination.
func (t *TCPTracker) isValidStateForFlags(state TCPState, flags uint8) bool {
if !isValidFlagCombination(flags) {
return false
}
if flags&TCPRst != 0 {
if state == TCPStateSynSent {
return flags&TCPAck != 0
@@ -449,15 +628,24 @@ func (t *TCPTracker) cleanup() {
timeout = t.waitTimeout
case TCPStateEstablished:
timeout = t.timeout
case TCPStateFinWait1, TCPStateFinWait2, TCPStateClosing:
timeout = t.finWaitTimeout
case TCPStateCloseWait:
timeout = t.closeWaitTimeout
case TCPStateLastAck:
timeout = t.lastAckTimeout
default:
// SynSent / SynReceived / New
timeout = TCPHandshakeTimeout
}
if conn.timeoutExceeded(timeout) {
delete(t.connections, key)
t.logger.Trace6("Cleaned up timed-out TCP connection %s (%s) [in: %d Pkts/%d, B out: %d Pkts/%d B]",
key, conn.GetState(), conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load())
if t.logger.Enabled(nblog.LevelTrace) {
t.logger.Trace6("Cleaned up timed-out TCP connection %s (%s) [in: %d Pkts/%d, B out: %d Pkts/%d B]",
key, conn.GetState(), conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load())
}
// event already handled by state change
if currentState != TCPStateTimeWait {

View File

@@ -0,0 +1,100 @@
package conntrack
import (
"net/netip"
"testing"
"github.com/stretchr/testify/require"
)
// RST hygiene tests: the tracker currently closes the flow on any RST that
// matches the 4-tuple, regardless of direction or state. These tests cover
// the minimum checks we want (no SEQ tracking).
func TestTCPRstInSynSentWrongDirection(t *testing.T) {
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
defer tracker.Close()
srcIP := netip.MustParseAddr("100.64.0.1")
dstIP := netip.MustParseAddr("100.64.0.2")
srcPort := uint16(12345)
dstPort := uint16(80)
key := ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort}
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn, 0)
conn := tracker.connections[key]
require.Equal(t, TCPStateSynSent, conn.GetState())
// A RST arriving in the same direction as the SYN (i.e. TrackOutbound)
// cannot be a legitimate response. It must not close the connection.
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPRst|TCPAck, 0)
require.Equal(t, TCPStateSynSent, conn.GetState(),
"RST in same direction as SYN must not close connection")
require.False(t, conn.IsTombstone())
}
func TestTCPRstInTimeWaitIgnored(t *testing.T) {
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
defer tracker.Close()
srcIP := netip.MustParseAddr("100.64.0.1")
dstIP := netip.MustParseAddr("100.64.0.2")
srcPort := uint16(12345)
dstPort := uint16(80)
key := ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort}
// Drive to TIME-WAIT via active close.
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0)
require.True(t, tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0))
require.True(t, tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0))
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
conn := tracker.connections[key]
require.Equal(t, TCPStateTimeWait, conn.GetState())
require.False(t, conn.IsTombstone(), "TIME-WAIT must not be tombstoned")
// Late RST during TIME-WAIT must not tombstone the entry (TIME-WAIT
// exists to absorb late segments).
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst, 0)
require.Equal(t, TCPStateTimeWait, conn.GetState(),
"RST in TIME-WAIT must not transition state")
require.False(t, conn.IsTombstone(),
"RST in TIME-WAIT must not tombstone the entry")
}
func TestTCPIllegalFlagCombos(t *testing.T) {
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
defer tracker.Close()
srcIP := netip.MustParseAddr("100.64.0.1")
dstIP := netip.MustParseAddr("100.64.0.2")
srcPort := uint16(12345)
dstPort := uint16(80)
key := ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort}
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
conn := tracker.connections[key]
// Illegal combos must be rejected and must not change state.
combos := []struct {
name string
flags uint8
}{
{"SYN+RST", TCPSyn | TCPRst},
{"FIN+RST", TCPFin | TCPRst},
{"SYN+FIN", TCPSyn | TCPFin},
{"SYN+FIN+RST", TCPSyn | TCPFin | TCPRst},
}
for _, c := range combos {
t.Run(c.name, func(t *testing.T) {
before := conn.GetState()
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, c.flags, 0)
require.False(t, valid, "illegal flag combo must be rejected: %s", c.name)
require.Equal(t, before, conn.GetState(),
"illegal flag combo must not change state")
require.False(t, conn.IsTombstone())
})
}
}

View File

@@ -0,0 +1,235 @@
package conntrack
import (
"net/netip"
"testing"
"time"
"github.com/stretchr/testify/require"
)
// These tests exercise cases where the TCP state machine currently advances
// on retransmitted or wrong-direction segments and tears the flow down
// prematurely. They are expected to fail until the direction checks are added.
func TestTCPCloseWaitRetransmittedPeerFIN(t *testing.T) {
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
defer tracker.Close()
srcIP := netip.MustParseAddr("100.64.0.1")
dstIP := netip.MustParseAddr("100.64.0.2")
srcPort := uint16(12345)
dstPort := uint16(80)
key := ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort}
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
// Peer sends FIN -> CloseWait (our app has not yet closed).
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0)
require.True(t, valid)
conn := tracker.connections[key]
require.Equal(t, TCPStateCloseWait, conn.GetState())
// Peer retransmits their FIN (ACK may have been delayed). We have NOT
// sent our FIN yet, so state must remain CloseWait.
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0)
require.True(t, valid, "retransmitted peer FIN must still be accepted")
require.Equal(t, TCPStateCloseWait, conn.GetState(),
"retransmitted peer FIN must not advance CloseWait to LastAck")
// Our app finally closes -> LastAck.
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0)
require.Equal(t, TCPStateLastAck, conn.GetState())
// Peer ACK closes.
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)
require.True(t, valid)
require.Equal(t, TCPStateClosed, conn.GetState())
}
func TestTCPFinWait2RetransmittedOwnFIN(t *testing.T) {
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
defer tracker.Close()
srcIP := netip.MustParseAddr("100.64.0.1")
dstIP := netip.MustParseAddr("100.64.0.2")
srcPort := uint16(12345)
dstPort := uint16(80)
key := ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort}
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
// We initiate close.
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0)
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)
require.True(t, valid)
conn := tracker.connections[key]
require.Equal(t, TCPStateFinWait2, conn.GetState())
// Stray retransmit of our own FIN (same direction as originator) must
// NOT advance FinWait2 to TimeWait; only the peer's FIN should.
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0)
require.Equal(t, TCPStateFinWait2, conn.GetState(),
"own FIN retransmit must not advance FinWait2 to TimeWait")
// Peer FIN -> TimeWait.
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0)
require.True(t, valid)
require.Equal(t, TCPStateTimeWait, conn.GetState())
}
func TestTCPLastAckDirectionCheck(t *testing.T) {
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
defer tracker.Close()
srcIP := netip.MustParseAddr("100.64.0.1")
dstIP := netip.MustParseAddr("100.64.0.2")
srcPort := uint16(12345)
dstPort := uint16(80)
key := ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort}
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
// Drive to LastAck: peer FIN -> CloseWait, our FIN -> LastAck.
require.True(t, tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0))
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0)
conn := tracker.connections[key]
require.Equal(t, TCPStateLastAck, conn.GetState())
// Our own ACK retransmit (same direction as originator) must NOT close.
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
require.Equal(t, TCPStateLastAck, conn.GetState(),
"own ACK retransmit in LastAck must not transition to Closed")
// Peer's ACK -> Closed.
require.True(t, tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0))
require.Equal(t, TCPStateClosed, conn.GetState())
}
func TestTCPFinWait1OwnAckDoesNotAdvance(t *testing.T) {
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
defer tracker.Close()
srcIP := netip.MustParseAddr("100.64.0.1")
dstIP := netip.MustParseAddr("100.64.0.2")
srcPort := uint16(12345)
dstPort := uint16(80)
key := ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort}
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0)
conn := tracker.connections[key]
require.Equal(t, TCPStateFinWait1, conn.GetState())
// Our own ACK retransmit (same direction as originator) must not advance.
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
require.Equal(t, TCPStateFinWait1, conn.GetState(),
"own ACK in FinWait1 must not advance to FinWait2")
}
func TestTCPPerStateTeardownTimeouts(t *testing.T) {
// Verify cleanup reaps entries in each teardown state at the configured
// per-state timeout, not at the single handshake timeout.
t.Setenv(EnvTCPFinWaitTimeout, "50ms")
t.Setenv(EnvTCPCloseWaitTimeout, "80ms")
t.Setenv(EnvTCPLastAckTimeout, "30ms")
dstIP := netip.MustParseAddr("100.64.0.2")
dstPort := uint16(80)
// Drives a connection to the target state, forces its lastSeen well
// beyond the configured timeout, runs cleanup, and asserts reaping.
cases := []struct {
name string
// drive takes a fresh tracker and returns the conn key after
// transitioning the flow into the intended teardown state.
drive func(t *testing.T, tr *TCPTracker, srcIP netip.Addr, srcPort uint16) (ConnKey, TCPState)
}{
{
name: "FinWait1",
drive: func(t *testing.T, tr *TCPTracker, srcIP netip.Addr, srcPort uint16) (ConnKey, TCPState) {
establishConnection(t, tr, srcIP, dstIP, srcPort, dstPort)
tr.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0) // → FinWait1
return ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort}, TCPStateFinWait1
},
},
{
name: "FinWait2",
drive: func(t *testing.T, tr *TCPTracker, srcIP netip.Addr, srcPort uint16) (ConnKey, TCPState) {
establishConnection(t, tr, srcIP, dstIP, srcPort, dstPort)
tr.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0) // FinWait1
require.True(t, tr.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)) // → FinWait2
return ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort}, TCPStateFinWait2
},
},
{
name: "CloseWait",
drive: func(t *testing.T, tr *TCPTracker, srcIP netip.Addr, srcPort uint16) (ConnKey, TCPState) {
establishConnection(t, tr, srcIP, dstIP, srcPort, dstPort)
require.True(t, tr.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0)) // → CloseWait
return ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort}, TCPStateCloseWait
},
},
{
name: "LastAck",
drive: func(t *testing.T, tr *TCPTracker, srcIP netip.Addr, srcPort uint16) (ConnKey, TCPState) {
establishConnection(t, tr, srcIP, dstIP, srcPort, dstPort)
require.True(t, tr.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0)) // CloseWait
tr.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0) // → LastAck
return ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort}, TCPStateLastAck
},
},
}
// Use a unique source port per subtest so nothing aliases.
port := uint16(12345)
for _, c := range cases {
t.Run(c.name, func(t *testing.T) {
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
defer tracker.Close()
require.Equal(t, 50*time.Millisecond, tracker.finWaitTimeout)
require.Equal(t, 80*time.Millisecond, tracker.closeWaitTimeout)
require.Equal(t, 30*time.Millisecond, tracker.lastAckTimeout)
srcIP := netip.MustParseAddr("100.64.0.1")
port++
key, wantState := c.drive(t, tracker, srcIP, port)
conn := tracker.connections[key]
require.NotNil(t, conn)
require.Equal(t, wantState, conn.GetState())
// Age the entry past the largest per-state timeout.
conn.lastSeen.Store(time.Now().Add(-500 * time.Millisecond).UnixNano())
tracker.cleanup()
_, exists := tracker.connections[key]
require.False(t, exists, "%s entry should be reaped", c.name)
})
}
}
func TestTCPEstablishedPSHACKInFinStates(t *testing.T) {
// Verifies FIN|PSH|ACK and bare ACK keepalives are not dropped in FIN
// teardown states, which some stacks emit during close.
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
defer tracker.Close()
srcIP := netip.MustParseAddr("100.64.0.1")
dstIP := netip.MustParseAddr("100.64.0.2")
srcPort := uint16(12345)
dstPort := uint16(80)
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
// Peer FIN -> CloseWait.
require.True(t, tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0))
// Peer pushes trailing data + FIN|PSH|ACK (legal).
require.True(t, tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPPush|TCPAck, 100),
"FIN|PSH|ACK in CloseWait must be accepted")
// Bare ACK keepalive from peer in CloseWait must be accepted.
require.True(t, tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0),
"bare ACK in CloseWait must be accepted")
}

View File

@@ -17,6 +17,9 @@ const (
DefaultUDPTimeout = 30 * time.Second
// UDPCleanupInterval is how often we check for stale connections
UDPCleanupInterval = 15 * time.Second
// EnvUDPMaxEntries caps the UDP conntrack table size.
EnvUDPMaxEntries = "NB_CONNTRACK_UDP_MAX"
)
// UDPConnTrack represents a UDP connection state
@@ -34,6 +37,7 @@ type UDPTracker struct {
cleanupTicker *time.Ticker
tickerCancel context.CancelFunc
mutex sync.RWMutex
maxEntries int
flowLogger nftypes.FlowLogger
}
@@ -51,6 +55,7 @@ func NewUDPTracker(timeout time.Duration, logger *nblog.Logger, flowLogger nftyp
timeout: timeout,
cleanupTicker: time.NewTicker(UDPCleanupInterval),
tickerCancel: cancel,
maxEntries: envInt(logger, EnvUDPMaxEntries, DefaultMaxUDPEntries),
flowLogger: flowLogger,
}
@@ -117,13 +122,18 @@ func (t *UDPTracker) track(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, d
conn.UpdateCounters(direction, size)
t.mutex.Lock()
if t.maxEntries > 0 && len(t.connections) >= t.maxEntries {
t.evictOneLocked()
}
t.connections[key] = conn
t.mutex.Unlock()
if origPort != 0 {
t.logger.Trace4("New %s UDP connection: %s (port DNAT %d -> %d)", direction, key, origPort, dstPort)
} else {
t.logger.Trace2("New %s UDP connection: %s", direction, key)
if t.logger.Enabled(nblog.LevelTrace) {
if origPort != 0 {
t.logger.Trace4("New %s UDP connection: %s (port DNAT %d -> %d)", direction, key, origPort, dstPort)
} else {
t.logger.Trace2("New %s UDP connection: %s", direction, key)
}
}
t.sendEvent(nftypes.TypeStart, conn, ruleID)
}
@@ -151,6 +161,34 @@ func (t *UDPTracker) IsValidInbound(srcIP netip.Addr, dstIP netip.Addr, srcPort
return true
}
// evictOneLocked removes one entry to make room. Caller must hold t.mutex.
// Bounded sample: picks the oldest among up to evictSampleSize entries.
func (t *UDPTracker) evictOneLocked() {
var candKey ConnKey
var candSeen int64
haveCand := false
sampled := 0
for k, c := range t.connections {
seen := c.lastSeen.Load()
if !haveCand || seen < candSeen {
candKey = k
candSeen = seen
haveCand = true
}
sampled++
if sampled >= evictSampleSize {
break
}
}
if haveCand {
if evicted := t.connections[candKey]; evicted != nil {
t.sendEvent(nftypes.TypeEnd, evicted, nil)
}
delete(t.connections, candKey)
}
}
// cleanupRoutine periodically removes stale connections
func (t *UDPTracker) cleanupRoutine(ctx context.Context) {
defer t.cleanupTicker.Stop()
@@ -173,8 +211,10 @@ func (t *UDPTracker) cleanup() {
if conn.timeoutExceeded(t.timeout) {
delete(t.connections, key)
t.logger.Trace5("Removed UDP connection %s (timeout) [in: %d Pkts/%d B, out: %d Pkts/%d B]",
key, conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load())
if t.logger.Enabled(nblog.LevelTrace) {
t.logger.Trace5("Removed UDP connection %s (timeout) [in: %d Pkts/%d B, out: %d Pkts/%d B]",
key, conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load())
}
t.sendEvent(nftypes.TypeEnd, conn, nil)
}
}

View File

@@ -787,7 +787,9 @@ func (m *Manager) filterOutbound(packetData []byte, size int) bool {
srcIP, dstIP := m.extractIPs(d)
if !srcIP.IsValid() {
m.logger.Error1("Unknown network layer: %v", d.decoded[0])
if m.logger.Enabled(nblog.LevelError) {
m.logger.Error1("Unknown network layer: %v", d.decoded[0])
}
return false
}
@@ -901,7 +903,9 @@ func (m *Manager) clampTCPMSS(packetData []byte, d *decoder) bool {
return false
}
m.logger.Trace2("Clamped TCP MSS from %d to %d", currentMSS, mssClampValue)
if m.logger.Enabled(nblog.LevelTrace) {
m.logger.Trace2("Clamped TCP MSS from %d to %d", currentMSS, mssClampValue)
}
return true
}
@@ -1044,11 +1048,13 @@ func (m *Manager) filterInbound(packetData []byte, size int) bool {
// TODO: pass fragments of routed packets to forwarder
if fragment {
if d.decoded[0] == layers.LayerTypeIPv4 {
m.logger.Trace4("packet is a fragment: src=%v dst=%v id=%v flags=%v",
srcIP, dstIP, d.ip4.Id, d.ip4.Flags)
} else {
m.logger.Trace2("packet is an IPv6 fragment: src=%v dst=%v", srcIP, dstIP)
if m.logger.Enabled(nblog.LevelTrace) {
if d.decoded[0] == layers.LayerTypeIPv4 {
m.logger.Trace4("packet is a fragment: src=%v dst=%v id=%v flags=%v",
srcIP, dstIP, d.ip4.Id, d.ip4.Flags)
} else {
m.logger.Trace2("packet is an IPv6 fragment: src=%v dst=%v", srcIP, dstIP)
}
}
return false
}
@@ -1091,8 +1097,10 @@ func (m *Manager) handleLocalTraffic(d *decoder, srcIP, dstIP netip.Addr, packet
pnum := getProtocolFromPacket(d)
srcPort, dstPort := getPortsFromPacket(d)
m.logger.Trace6("Dropping local packet (ACL denied): rule_id=%s proto=%v src=%s:%d dst=%s:%d",
ruleID, pnum, srcIP, srcPort, dstIP, dstPort)
if m.logger.Enabled(nblog.LevelTrace) {
m.logger.Trace6("Dropping local packet (ACL denied): rule_id=%s proto=%v src=%s:%d dst=%s:%d",
ruleID, pnum, srcIP, srcPort, dstIP, dstPort)
}
m.flowLogger.StoreEvent(nftypes.EventFields{
FlowID: uuid.New(),
@@ -1142,8 +1150,10 @@ func (m *Manager) handleForwardedLocalTraffic(packetData []byte) bool {
func (m *Manager) handleRoutedTraffic(d *decoder, srcIP, dstIP netip.Addr, packetData []byte, size int) bool {
// Drop if routing is disabled
if !m.routingEnabled.Load() {
m.logger.Trace2("Dropping routed packet (routing disabled): src=%s dst=%s",
srcIP, dstIP)
if m.logger.Enabled(nblog.LevelTrace) {
m.logger.Trace2("Dropping routed packet (routing disabled): src=%s dst=%s",
srcIP, dstIP)
}
return true
}
@@ -1160,8 +1170,10 @@ func (m *Manager) handleRoutedTraffic(d *decoder, srcIP, dstIP netip.Addr, packe
if !pass {
proto := getProtocolFromPacket(d)
m.logger.Trace6("Dropping routed packet (ACL denied): rule_id=%s proto=%v src=%s:%d dst=%s:%d",
ruleID, proto, srcIP, srcPort, dstIP, dstPort)
if m.logger.Enabled(nblog.LevelTrace) {
m.logger.Trace6("Dropping routed packet (ACL denied): rule_id=%s proto=%v src=%s:%d dst=%s:%d",
ruleID, proto, srcIP, srcPort, dstIP, dstPort)
}
m.flowLogger.StoreEvent(nftypes.EventFields{
FlowID: uuid.New(),
@@ -1287,7 +1299,9 @@ func getPortsFromPacket(d *decoder) (srcPort, dstPort uint16) {
// It returns true, true if the packet is a fragment and valid.
func (m *Manager) isValidPacket(d *decoder, packetData []byte) (bool, bool) {
if err := d.decodePacket(packetData); err != nil {
m.logger.Trace1("couldn't decode packet, err: %s", err)
if m.logger.Enabled(nblog.LevelTrace) {
m.logger.Trace1("couldn't decode packet, err: %s", err)
}
return false, false
}

View File

@@ -13,6 +13,7 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/stack"
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
)
@@ -97,8 +98,10 @@ func (f *Forwarder) forwardICMPPacket(id stack.TransportEndpointID, payload []by
return nil, fmt.Errorf("write ICMP packet: %w", err)
}
f.logger.Trace3("forwarder: Forwarded ICMP packet %v type %v code %v",
epID(id), icmpType, icmpCode)
if f.logger.Enabled(nblog.LevelTrace) {
f.logger.Trace3("forwarder: Forwarded ICMP packet %v type %v code %v",
epID(id), icmpType, icmpCode)
}
return conn, nil
}
@@ -121,12 +124,14 @@ func (f *Forwarder) handleICMPViaSocket(flowID uuid.UUID, id stack.TransportEndp
txBytes := f.handleEchoResponse(conn, id, v6)
rtt := time.Since(sendTime).Round(10 * time.Microsecond)
proto := "ICMP"
if v6 {
proto = "ICMPv6"
if f.logger.Enabled(nblog.LevelTrace) {
proto := "ICMP"
if v6 {
proto = "ICMPv6"
}
f.logger.Trace5("forwarder: Forwarded %s echo reply %v type %v code %v (rtt=%v, raw socket)",
proto, epID(id), icmpType, icmpCode, rtt)
}
f.logger.Trace5("forwarder: Forwarded %s echo reply %v type %v code %v (rtt=%v, raw socket)",
proto, epID(id), icmpType, icmpCode, rtt)
f.sendICMPEvent(nftypes.TypeEnd, flowID, id, icmpType, icmpCode, uint64(rxBytes), uint64(txBytes))
}
@@ -224,13 +229,17 @@ func (f *Forwarder) handleICMPViaPing(flowID uuid.UUID, id stack.TransportEndpoi
}
rtt := time.Since(pingStart).Round(10 * time.Microsecond)
f.logger.Trace3("forwarder: Forwarded ICMP echo request %v type %v code %v",
epID(id), icmpType, icmpCode)
if f.logger.Enabled(nblog.LevelTrace) {
f.logger.Trace3("forwarder: Forwarded ICMP echo request %v type %v code %v",
epID(id), icmpType, icmpCode)
}
txBytes := f.synthesizeEchoReply(id, icmpData)
f.logger.Trace4("forwarder: Forwarded ICMP echo reply %v type %v code %v (rtt=%v, ping binary)",
epID(id), icmpType, icmpCode, rtt)
if f.logger.Enabled(nblog.LevelTrace) {
f.logger.Trace4("forwarder: Forwarded ICMP echo reply %v type %v code %v (rtt=%v, ping binary)",
epID(id), icmpType, icmpCode, rtt)
}
f.sendICMPEvent(nftypes.TypeEnd, flowID, id, icmpType, icmpCode, uint64(rxBytes), uint64(txBytes))
}

View File

@@ -1,11 +1,8 @@
package forwarder
import (
"context"
"io"
"net"
"strconv"
"sync"
"github.com/google/uuid"
@@ -15,7 +12,9 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
"gvisor.dev/gvisor/pkg/waiter"
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
"github.com/netbirdio/netbird/util/netrelay"
)
// handleTCP is called by the TCP forwarder for new connections.
@@ -37,7 +36,9 @@ func (f *Forwarder) handleTCP(r *tcp.ForwarderRequest) {
outConn, err := (&net.Dialer{}).DialContext(f.ctx, "tcp", dialAddr)
if err != nil {
r.Complete(true)
f.logger.Trace2("forwarder: dial error for %v: %v", epID(id), err)
if f.logger.Enabled(nblog.LevelTrace) {
f.logger.Trace2("forwarder: dial error for %v: %v", epID(id), err)
}
return
}
@@ -60,64 +61,22 @@ func (f *Forwarder) handleTCP(r *tcp.ForwarderRequest) {
inConn := gonet.NewTCPConn(&wq, ep)
success = true
f.logger.Trace1("forwarder: established TCP connection %v", epID(id))
if f.logger.Enabled(nblog.LevelTrace) {
f.logger.Trace1("forwarder: established TCP connection %v", epID(id))
}
go f.proxyTCP(id, inConn, outConn, ep, flowID)
}
func (f *Forwarder) proxyTCP(id stack.TransportEndpointID, inConn *gonet.TCPConn, outConn net.Conn, ep tcpip.Endpoint, flowID uuid.UUID) {
// netrelay.Relay copies bidirectionally with proper half-close propagation
// and fully closes both conns before returning.
bytesFromInToOut, bytesFromOutToIn := netrelay.Relay(f.ctx, inConn, outConn, netrelay.Options{
Logger: f.logger,
})
ctx, cancel := context.WithCancel(f.ctx)
defer cancel()
go func() {
<-ctx.Done()
// Close connections and endpoint.
if err := inConn.Close(); err != nil && !isClosedError(err) {
f.logger.Debug1("forwarder: inConn close error: %v", err)
}
if err := outConn.Close(); err != nil && !isClosedError(err) {
f.logger.Debug1("forwarder: outConn close error: %v", err)
}
ep.Close()
}()
var wg sync.WaitGroup
wg.Add(2)
var (
bytesFromInToOut int64 // bytes from client to server (tx for client)
bytesFromOutToIn int64 // bytes from server to client (rx for client)
errInToOut error
errOutToIn error
)
go func() {
bytesFromInToOut, errInToOut = io.Copy(outConn, inConn)
cancel()
wg.Done()
}()
go func() {
bytesFromOutToIn, errOutToIn = io.Copy(inConn, outConn)
cancel()
wg.Done()
}()
wg.Wait()
if errInToOut != nil {
if !isClosedError(errInToOut) {
f.logger.Error2("proxyTCP: copy error (in → out) for %s: %v", epID(id), errInToOut)
}
}
if errOutToIn != nil {
if !isClosedError(errOutToIn) {
f.logger.Error2("proxyTCP: copy error (out → in) for %s: %v", epID(id), errOutToIn)
}
}
// Close the netstack endpoint after both conns are drained.
ep.Close()
var rxPackets, txPackets uint64
if tcpStats, ok := ep.Stats().(*tcp.Stats); ok {
@@ -126,7 +85,9 @@ func (f *Forwarder) proxyTCP(id stack.TransportEndpointID, inConn *gonet.TCPConn
txPackets = tcpStats.SegmentsReceived.Value()
}
f.logger.Trace5("forwarder: Removed TCP connection %s [in: %d Pkts/%d B, out: %d Pkts/%d B]", epID(id), rxPackets, bytesFromOutToIn, txPackets, bytesFromInToOut)
if f.logger.Enabled(nblog.LevelTrace) {
f.logger.Trace5("forwarder: Removed TCP connection %s [in: %d Pkts/%d B, out: %d Pkts/%d B]", epID(id), rxPackets, bytesFromOutToIn, txPackets, bytesFromInToOut)
}
f.sendTCPEvent(nftypes.TypeEnd, flowID, id, uint64(bytesFromOutToIn), uint64(bytesFromInToOut), rxPackets, txPackets)
}

View File

@@ -125,7 +125,9 @@ func (f *udpForwarder) cleanup() {
delete(f.conns, idle.id)
f.Unlock()
f.logger.Trace1("forwarder: cleaned up idle UDP connection %v", epID(idle.id))
if f.logger.Enabled(nblog.LevelTrace) {
f.logger.Trace1("forwarder: cleaned up idle UDP connection %v", epID(idle.id))
}
}
}
}
@@ -144,7 +146,9 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) bool {
_, exists := f.udpForwarder.conns[id]
f.udpForwarder.RUnlock()
if exists {
f.logger.Trace1("forwarder: existing UDP connection for %v", epID(id))
if f.logger.Enabled(nblog.LevelTrace) {
f.logger.Trace1("forwarder: existing UDP connection for %v", epID(id))
}
return true
}
@@ -206,7 +210,9 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) bool {
f.udpForwarder.Unlock()
success = true
f.logger.Trace1("forwarder: established UDP connection %v", epID(id))
if f.logger.Enabled(nblog.LevelTrace) {
f.logger.Trace1("forwarder: established UDP connection %v", epID(id))
}
go f.proxyUDP(connCtx, pConn, id, ep)
return true
@@ -265,7 +271,9 @@ func (f *Forwarder) proxyUDP(ctx context.Context, pConn *udpPacketConn, id stack
txPackets = udpStats.PacketsReceived.Value()
}
f.logger.Trace5("forwarder: Removed UDP connection %s [in: %d Pkts/%d B, out: %d Pkts/%d B]", epID(id), rxPackets, rxBytes, txPackets, txBytes)
if f.logger.Enabled(nblog.LevelTrace) {
f.logger.Trace5("forwarder: Removed UDP connection %s [in: %d Pkts/%d B, out: %d Pkts/%d B]", epID(id), rxPackets, rxBytes, txPackets, txBytes)
}
f.udpForwarder.Lock()
delete(f.udpForwarder.conns, id)

View File

@@ -53,16 +53,17 @@ var levelStrings = map[Level]string{
}
type logMessage struct {
level Level
format string
arg1 any
arg2 any
arg3 any
arg4 any
arg5 any
arg6 any
arg7 any
arg8 any
level Level
argCount uint8
format string
arg1 any
arg2 any
arg3 any
arg4 any
arg5 any
arg6 any
arg7 any
arg8 any
}
// Logger is a high-performance, non-blocking logger
@@ -107,6 +108,13 @@ func (l *Logger) SetLevel(level Level) {
log.Debugf("Set uspfilter logger loglevel to %v", levelStrings[level])
}
// Enabled reports whether the given level is currently logged. Callers on the
// hot path should guard log sites with this to avoid boxing arguments into
// any when the level is off.
func (l *Logger) Enabled(level Level) bool {
return l.level.Load() >= uint32(level)
}
func (l *Logger) Error(format string) {
if l.level.Load() >= uint32(LevelError) {
select {
@@ -155,7 +163,7 @@ func (l *Logger) Trace(format string) {
func (l *Logger) Error1(format string, arg1 any) {
if l.level.Load() >= uint32(LevelError) {
select {
case l.msgChannel <- logMessage{level: LevelError, format: format, arg1: arg1}:
case l.msgChannel <- logMessage{level: LevelError, argCount: 1, format: format, arg1: arg1}:
default:
}
}
@@ -164,7 +172,16 @@ func (l *Logger) Error1(format string, arg1 any) {
func (l *Logger) Error2(format string, arg1, arg2 any) {
if l.level.Load() >= uint32(LevelError) {
select {
case l.msgChannel <- logMessage{level: LevelError, format: format, arg1: arg1, arg2: arg2}:
case l.msgChannel <- logMessage{level: LevelError, argCount: 2, format: format, arg1: arg1, arg2: arg2}:
default:
}
}
}
func (l *Logger) Warn2(format string, arg1, arg2 any) {
if l.level.Load() >= uint32(LevelWarn) {
select {
case l.msgChannel <- logMessage{level: LevelWarn, argCount: 2, format: format, arg1: arg1, arg2: arg2}:
default:
}
}
@@ -173,7 +190,7 @@ func (l *Logger) Error2(format string, arg1, arg2 any) {
func (l *Logger) Warn3(format string, arg1, arg2, arg3 any) {
if l.level.Load() >= uint32(LevelWarn) {
select {
case l.msgChannel <- logMessage{level: LevelWarn, format: format, arg1: arg1, arg2: arg2, arg3: arg3}:
case l.msgChannel <- logMessage{level: LevelWarn, argCount: 3, format: format, arg1: arg1, arg2: arg2, arg3: arg3}:
default:
}
}
@@ -182,7 +199,7 @@ func (l *Logger) Warn3(format string, arg1, arg2, arg3 any) {
func (l *Logger) Warn4(format string, arg1, arg2, arg3, arg4 any) {
if l.level.Load() >= uint32(LevelWarn) {
select {
case l.msgChannel <- logMessage{level: LevelWarn, format: format, arg1: arg1, arg2: arg2, arg3: arg3, arg4: arg4}:
case l.msgChannel <- logMessage{level: LevelWarn, argCount: 4, format: format, arg1: arg1, arg2: arg2, arg3: arg3, arg4: arg4}:
default:
}
}
@@ -191,7 +208,7 @@ func (l *Logger) Warn4(format string, arg1, arg2, arg3, arg4 any) {
func (l *Logger) Debug1(format string, arg1 any) {
if l.level.Load() >= uint32(LevelDebug) {
select {
case l.msgChannel <- logMessage{level: LevelDebug, format: format, arg1: arg1}:
case l.msgChannel <- logMessage{level: LevelDebug, argCount: 1, format: format, arg1: arg1}:
default:
}
}
@@ -200,7 +217,7 @@ func (l *Logger) Debug1(format string, arg1 any) {
func (l *Logger) Debug2(format string, arg1, arg2 any) {
if l.level.Load() >= uint32(LevelDebug) {
select {
case l.msgChannel <- logMessage{level: LevelDebug, format: format, arg1: arg1, arg2: arg2}:
case l.msgChannel <- logMessage{level: LevelDebug, argCount: 2, format: format, arg1: arg1, arg2: arg2}:
default:
}
}
@@ -209,16 +226,59 @@ func (l *Logger) Debug2(format string, arg1, arg2 any) {
func (l *Logger) Debug3(format string, arg1, arg2, arg3 any) {
if l.level.Load() >= uint32(LevelDebug) {
select {
case l.msgChannel <- logMessage{level: LevelDebug, format: format, arg1: arg1, arg2: arg2, arg3: arg3}:
case l.msgChannel <- logMessage{level: LevelDebug, argCount: 3, format: format, arg1: arg1, arg2: arg2, arg3: arg3}:
default:
}
}
}
// Debugf is the variadic shape. Dispatches to Debug/Debug1/Debug2/Debug3
// to avoid allocating an args slice on the fast path when the arg count is
// known (0-3). Args beyond 3 land on the general variadic path; callers on
// the hot path should prefer DebugN for known counts.
func (l *Logger) Debugf(format string, args ...any) {
if l.level.Load() < uint32(LevelDebug) {
return
}
switch len(args) {
case 0:
l.Debug(format)
case 1:
l.Debug1(format, args[0])
case 2:
l.Debug2(format, args[0], args[1])
case 3:
l.Debug3(format, args[0], args[1], args[2])
default:
l.sendVariadic(LevelDebug, format, args)
}
}
// sendVariadic packs a slice of arguments into a logMessage and non-blocking
// enqueues it. Used for arg counts beyond the fixed-arity fast paths. Args
// beyond the 8-arg slot limit are dropped so callers don't produce silently
// empty log lines via uint8 wraparound in argCount.
func (l *Logger) sendVariadic(level Level, format string, args []any) {
const maxArgs = 8
n := len(args)
if n > maxArgs {
n = maxArgs
}
msg := logMessage{level: level, argCount: uint8(n), format: format}
slots := [maxArgs]*any{&msg.arg1, &msg.arg2, &msg.arg3, &msg.arg4, &msg.arg5, &msg.arg6, &msg.arg7, &msg.arg8}
for i := 0; i < n; i++ {
*slots[i] = args[i]
}
select {
case l.msgChannel <- msg:
default:
}
}
func (l *Logger) Trace1(format string, arg1 any) {
if l.level.Load() >= uint32(LevelTrace) {
select {
case l.msgChannel <- logMessage{level: LevelTrace, format: format, arg1: arg1}:
case l.msgChannel <- logMessage{level: LevelTrace, argCount: 1, format: format, arg1: arg1}:
default:
}
}
@@ -227,7 +287,7 @@ func (l *Logger) Trace1(format string, arg1 any) {
func (l *Logger) Trace2(format string, arg1, arg2 any) {
if l.level.Load() >= uint32(LevelTrace) {
select {
case l.msgChannel <- logMessage{level: LevelTrace, format: format, arg1: arg1, arg2: arg2}:
case l.msgChannel <- logMessage{level: LevelTrace, argCount: 2, format: format, arg1: arg1, arg2: arg2}:
default:
}
}
@@ -236,7 +296,7 @@ func (l *Logger) Trace2(format string, arg1, arg2 any) {
func (l *Logger) Trace3(format string, arg1, arg2, arg3 any) {
if l.level.Load() >= uint32(LevelTrace) {
select {
case l.msgChannel <- logMessage{level: LevelTrace, format: format, arg1: arg1, arg2: arg2, arg3: arg3}:
case l.msgChannel <- logMessage{level: LevelTrace, argCount: 3, format: format, arg1: arg1, arg2: arg2, arg3: arg3}:
default:
}
}
@@ -245,7 +305,7 @@ func (l *Logger) Trace3(format string, arg1, arg2, arg3 any) {
func (l *Logger) Trace4(format string, arg1, arg2, arg3, arg4 any) {
if l.level.Load() >= uint32(LevelTrace) {
select {
case l.msgChannel <- logMessage{level: LevelTrace, format: format, arg1: arg1, arg2: arg2, arg3: arg3, arg4: arg4}:
case l.msgChannel <- logMessage{level: LevelTrace, argCount: 4, format: format, arg1: arg1, arg2: arg2, arg3: arg3, arg4: arg4}:
default:
}
}
@@ -254,7 +314,7 @@ func (l *Logger) Trace4(format string, arg1, arg2, arg3, arg4 any) {
func (l *Logger) Trace5(format string, arg1, arg2, arg3, arg4, arg5 any) {
if l.level.Load() >= uint32(LevelTrace) {
select {
case l.msgChannel <- logMessage{level: LevelTrace, format: format, arg1: arg1, arg2: arg2, arg3: arg3, arg4: arg4, arg5: arg5}:
case l.msgChannel <- logMessage{level: LevelTrace, argCount: 5, format: format, arg1: arg1, arg2: arg2, arg3: arg3, arg4: arg4, arg5: arg5}:
default:
}
}
@@ -263,7 +323,7 @@ func (l *Logger) Trace5(format string, arg1, arg2, arg3, arg4, arg5 any) {
func (l *Logger) Trace6(format string, arg1, arg2, arg3, arg4, arg5, arg6 any) {
if l.level.Load() >= uint32(LevelTrace) {
select {
case l.msgChannel <- logMessage{level: LevelTrace, format: format, arg1: arg1, arg2: arg2, arg3: arg3, arg4: arg4, arg5: arg5, arg6: arg6}:
case l.msgChannel <- logMessage{level: LevelTrace, argCount: 6, format: format, arg1: arg1, arg2: arg2, arg3: arg3, arg4: arg4, arg5: arg5, arg6: arg6}:
default:
}
}
@@ -273,7 +333,7 @@ func (l *Logger) Trace6(format string, arg1, arg2, arg3, arg4, arg5, arg6 any) {
func (l *Logger) Trace8(format string, arg1, arg2, arg3, arg4, arg5, arg6, arg7, arg8 any) {
if l.level.Load() >= uint32(LevelTrace) {
select {
case l.msgChannel <- logMessage{level: LevelTrace, format: format, arg1: arg1, arg2: arg2, arg3: arg3, arg4: arg4, arg5: arg5, arg6: arg6, arg7: arg7, arg8: arg8}:
case l.msgChannel <- logMessage{level: LevelTrace, argCount: 8, format: format, arg1: arg1, arg2: arg2, arg3: arg3, arg4: arg4, arg5: arg5, arg6: arg6, arg7: arg7, arg8: arg8}:
default:
}
}
@@ -286,35 +346,8 @@ func (l *Logger) formatMessage(buf *[]byte, msg logMessage) {
*buf = append(*buf, levelStrings[msg.level]...)
*buf = append(*buf, ' ')
// Count non-nil arguments for switch
argCount := 0
if msg.arg1 != nil {
argCount++
if msg.arg2 != nil {
argCount++
if msg.arg3 != nil {
argCount++
if msg.arg4 != nil {
argCount++
if msg.arg5 != nil {
argCount++
if msg.arg6 != nil {
argCount++
if msg.arg7 != nil {
argCount++
if msg.arg8 != nil {
argCount++
}
}
}
}
}
}
}
}
var formatted string
switch argCount {
switch msg.argCount {
case 0:
formatted = msg.format
case 1:

View File

@@ -11,6 +11,7 @@ import (
"github.com/google/gopacket/layers"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
)
var (
@@ -262,11 +263,15 @@ func (m *Manager) translateOutboundDNAT(packetData []byte, d *decoder) bool {
}
if err := m.rewritePacketIP(packetData, d, translatedIP, false); err != nil {
m.logger.Error1("failed to rewrite packet destination: %v", err)
if m.logger.Enabled(nblog.LevelError) {
m.logger.Error1("failed to rewrite packet destination: %v", err)
}
return false
}
m.logger.Trace2("DNAT: %s -> %s", dstIP, translatedIP)
if m.logger.Enabled(nblog.LevelTrace) {
m.logger.Trace2("DNAT: %s -> %s", dstIP, translatedIP)
}
return true
}
@@ -283,11 +288,15 @@ func (m *Manager) translateInboundReverse(packetData []byte, d *decoder) bool {
}
if err := m.rewritePacketIP(packetData, d, originalIP, true); err != nil {
m.logger.Error1("failed to rewrite packet source: %v", err)
if m.logger.Enabled(nblog.LevelError) {
m.logger.Error1("failed to rewrite packet source: %v", err)
}
return false
}
m.logger.Trace2("Reverse DNAT: %s -> %s", srcIP, originalIP)
if m.logger.Enabled(nblog.LevelTrace) {
m.logger.Trace2("Reverse DNAT: %s -> %s", srcIP, originalIP)
}
return true
}
@@ -612,7 +621,9 @@ func (m *Manager) applyPortRule(packetData []byte, d *decoder, srcIP, dstIP neti
}
if err := rewriteFn(packetData, d, rule.targetPort, destinationPortOffset); err != nil {
m.logger.Error1("failed to rewrite port: %v", err)
if m.logger.Enabled(nblog.LevelError) {
m.logger.Error1("failed to rewrite port: %v", err)
}
return false
}
d.dnatOrigPort = rule.origPort

View File

@@ -260,15 +260,23 @@ WriteRegStr ${REG_ROOT} "${UNINSTALL_PATH}" "Publisher" "${COMP_NAME}"
WriteRegStr ${REG_ROOT} "${UI_REG_APP_PATH}" "" "$INSTDIR\${UI_APP_EXE}"
; Create autostart registry entry based on checkbox
; Drop Run, App Paths and Uninstall entries left in the 32-bit registry view
; or HKCU by legacy installers.
DetailPrint "Cleaning legacy 32-bit / HKCU entries..."
DeleteRegValue HKCU "${AUTOSTART_REG_KEY}" "${APP_NAME}"
SetRegView 32
DeleteRegValue HKLM "${AUTOSTART_REG_KEY}" "${APP_NAME}"
DeleteRegKey HKLM "${REG_APP_PATH}"
DeleteRegKey HKLM "${UI_REG_APP_PATH}"
DeleteRegKey HKLM "${UNINSTALL_PATH}"
SetRegView 64
DetailPrint "Autostart enabled: $AutostartEnabled"
${If} $AutostartEnabled == "1"
WriteRegStr HKLM "${AUTOSTART_REG_KEY}" "${APP_NAME}" '"$INSTDIR\${UI_APP_EXE}.exe"'
DetailPrint "Added autostart registry entry: $INSTDIR\${UI_APP_EXE}.exe"
${Else}
DeleteRegValue HKLM "${AUTOSTART_REG_KEY}" "${APP_NAME}"
; Legacy: pre-HKLM installs wrote to HKCU; clean that up too.
DeleteRegValue HKCU "${AUTOSTART_REG_KEY}" "${APP_NAME}"
DetailPrint "Autostart not enabled by user"
${EndIf}
@@ -299,11 +307,16 @@ ExecWait '"$INSTDIR\${MAIN_APP_EXE}" service uninstall'
DetailPrint "Terminating Netbird UI process..."
ExecWait `taskkill /im ${UI_APP_EXE}.exe /f`
; Remove autostart registry entry
; Remove autostart entries from every view a previous installer may have used.
DetailPrint "Removing autostart registry entry if exists..."
DeleteRegValue HKLM "${AUTOSTART_REG_KEY}" "${APP_NAME}"
; Legacy: pre-HKLM installs wrote to HKCU; clean that up too.
DeleteRegValue HKCU "${AUTOSTART_REG_KEY}" "${APP_NAME}"
SetRegView 32
DeleteRegValue HKLM "${AUTOSTART_REG_KEY}" "${APP_NAME}"
DeleteRegKey HKLM "${REG_APP_PATH}"
DeleteRegKey HKLM "${UI_REG_APP_PATH}"
DeleteRegKey HKLM "${UNINSTALL_PATH}"
SetRegView 64
; Handle data deletion based on checkbox
DetailPrint "Checking if user requested data deletion..."

View File

@@ -116,7 +116,6 @@ func (c *ConnectClient) RunOniOS(
fileDescriptor int32,
networkChangeListener listener.NetworkChangeListener,
dnsManager dns.IosDnsManager,
dnsAddresses []netip.AddrPort,
stateFilePath string,
) error {
// Set GC percent to 5% to reduce memory usage as iOS only allows 50MB of memory for the extension.
@@ -126,7 +125,6 @@ func (c *ConnectClient) RunOniOS(
FileDescriptor: fileDescriptor,
NetworkChangeListener: networkChangeListener,
DnsManager: dnsManager,
HostDNSAddresses: dnsAddresses,
StateFilePath: stateFilePath,
}
return c.run(mobileDependency, nil, "")

View File

@@ -45,8 +45,11 @@ netbird.out: Most recent, anonymized stdout log file of the NetBird client.
routes.txt: Detailed system routing table in tabular format including destination, gateway, interface, metrics, and protocol information, if --system-info flag was provided.
interfaces.txt: Anonymized network interface information, if --system-info flag was provided.
ip_rules.txt: Detailed IP routing rules in tabular format including priority, source, destination, interfaces, table, and action information (Linux only), if --system-info flag was provided.
iptables.txt: Anonymized iptables rules with packet counters, if --system-info flag was provided.
nftables.txt: Anonymized nftables rules with packet counters, if --system-info flag was provided.
iptables.txt: Anonymized iptables (IPv4) rules with packet counters, if --system-info flag was provided.
ip6tables.txt: Anonymized ip6tables (IPv6) rules with packet counters, if --system-info flag was provided.
ipset.txt: Anonymized ipset list output, if --system-info flag was provided.
nftables.txt: Anonymized nftables rules with packet counters across all families (ip, ip6, inet, etc.), if --system-info flag was provided.
sysctls.txt: Forwarding, reverse-path filter, source-validation, and conntrack accounting sysctl values that the NetBird client may read or modify, if --system-info flag was provided (Linux only).
resolv.conf: DNS resolver configuration from /etc/resolv.conf (Unix systems only), if --system-info flag was provided.
scutil_dns.txt: DNS configuration from scutil --dns (macOS only), if --system-info flag was provided.
resolved_domains.txt: Anonymized resolved domain IP addresses from the status recorder.
@@ -165,22 +168,33 @@ The config.txt file contains anonymized configuration information of the NetBird
Other non-sensitive configuration options are included without anonymization.
Firewall Rules (Linux only)
The bundle includes two separate firewall rule files:
The bundle includes the following firewall-related files:
iptables.txt:
- Complete iptables ruleset with packet counters using 'iptables -v -n -L'
- IPv4 iptables ruleset with packet counters using 'iptables-save' and 'iptables -v -n -L'
- Includes all tables (filter, nat, mangle, raw, security)
- Shows packet and byte counters for each rule
- All IP addresses are anonymized
- Chain names, table names, and other non-sensitive information remain unchanged
ip6tables.txt:
- IPv6 ip6tables ruleset with packet counters using 'ip6tables-save' and 'ip6tables -v -n -L'
- Same table coverage and anonymization as iptables.txt
- Omitted when ip6tables is not installed or no IPv6 rules are present
ipset.txt:
- Output of 'ipset list' (family-agnostic)
- IP addresses are anonymized; set names and types remain unchanged
nftables.txt:
- Complete nftables ruleset obtained via 'nft -a list ruleset'
- Complete nftables ruleset across all families (ip, ip6, inet, arp, bridge, netdev) via 'nft -a list ruleset'
- Includes rule handle numbers and packet counters
- All tables, chains, and rules are included
- Shows packet and byte counters for each rule
- All IP addresses are anonymized
- Chain names, table names, and other non-sensitive information remain unchanged
- All IP addresses are anonymized; chain/table names remain unchanged
sysctls.txt:
- Forwarding (IPv4 + IPv6, global and per-interface), reverse-path filter, source-validation, conntrack accounting, and TCP-related sysctls that netbird may read or modify
- Per-interface keys are enumerated from /proc/sys/net/ipv{4,6}/conf
- Interface names anonymized when --anonymize is set
IP Rules (Linux only)
The ip_rules.txt file contains detailed IP routing rule information:
@@ -412,6 +426,10 @@ func (g *BundleGenerator) addSystemInfo() {
log.Errorf("failed to add firewall rules to debug bundle: %v", err)
}
if err := g.addSysctls(); err != nil {
log.Errorf("failed to add sysctls to debug bundle: %v", err)
}
if err := g.addDNSInfo(); err != nil {
log.Errorf("failed to add DNS info to debug bundle: %v", err)
}

View File

@@ -124,15 +124,18 @@ func getSystemdLogs(serviceName string) (string, error) {
// addFirewallRules collects and adds firewall rules to the archive
func (g *BundleGenerator) addFirewallRules() error {
log.Info("Collecting firewall rules")
iptablesRules, err := collectIPTablesRules()
g.addIPTablesRulesToBundle("iptables-save", "iptables", "iptables.txt")
g.addIPTablesRulesToBundle("ip6tables-save", "ip6tables", "ip6tables.txt")
ipsetOutput, err := collectIPSets()
if err != nil {
log.Warnf("Failed to collect iptables rules: %v", err)
log.Warnf("Failed to collect ipset information: %v", err)
} else {
if g.anonymize {
iptablesRules = g.anonymizer.AnonymizeString(iptablesRules)
ipsetOutput = g.anonymizer.AnonymizeString(ipsetOutput)
}
if err := g.addFileToZip(strings.NewReader(iptablesRules), "iptables.txt"); err != nil {
log.Warnf("Failed to add iptables rules to bundle: %v", err)
if err := g.addFileToZip(strings.NewReader(ipsetOutput), "ipset.txt"); err != nil {
log.Warnf("Failed to add ipset output to bundle: %v", err)
}
}
@@ -151,44 +154,65 @@ func (g *BundleGenerator) addFirewallRules() error {
return nil
}
// collectIPTablesRules collects rules using both iptables-save and verbose listing
func collectIPTablesRules() (string, error) {
var builder strings.Builder
saveOutput, err := collectIPTablesSave()
// addIPTablesRulesToBundle collects iptables/ip6tables rules and writes them to the bundle.
func (g *BundleGenerator) addIPTablesRulesToBundle(saveBin, listBin, filename string) {
rules, err := collectIPTablesRules(saveBin, listBin)
if err != nil {
log.Warnf("Failed to collect iptables rules using iptables-save: %v", err)
} else {
builder.WriteString("=== iptables-save output ===\n")
log.Warnf("Failed to collect %s rules: %v", listBin, err)
return
}
if g.anonymize {
rules = g.anonymizer.AnonymizeString(rules)
}
if err := g.addFileToZip(strings.NewReader(rules), filename); err != nil {
log.Warnf("Failed to add %s rules to bundle: %v", listBin, err)
}
}
// collectIPTablesRules collects rules using both <saveBin> and verbose listing via <listBin>.
// Returns an error when neither command produced any output (e.g. the binary is missing),
// so the caller can skip writing an empty file.
func collectIPTablesRules(saveBin, listBin string) (string, error) {
var builder strings.Builder
var collected bool
var firstErr error
saveOutput, err := runCommand(saveBin)
switch {
case err != nil:
firstErr = err
log.Warnf("Failed to collect %s output: %v", saveBin, err)
case strings.TrimSpace(saveOutput) == "":
log.Debugf("%s produced no output, skipping", saveBin)
default:
builder.WriteString(fmt.Sprintf("=== %s output ===\n", saveBin))
builder.WriteString(saveOutput)
builder.WriteString("\n")
collected = true
}
ipsetOutput, err := collectIPSets()
if err != nil {
log.Warnf("Failed to collect ipset information: %v", err)
} else {
builder.WriteString("=== ipset list output ===\n")
builder.WriteString(ipsetOutput)
builder.WriteString("\n")
}
builder.WriteString("=== iptables -v -n -L output ===\n")
listHeader := fmt.Sprintf("=== %s -v -n -L output ===\n", listBin)
builder.WriteString(listHeader)
tables := []string{"filter", "nat", "mangle", "raw", "security"}
for _, table := range tables {
builder.WriteString(fmt.Sprintf("*%s\n", table))
stats, err := getTableStatistics(table)
stats, err := runCommand(listBin, "-v", "-n", "-L", "-t", table)
if err != nil {
log.Warnf("Failed to get statistics for table %s: %v", table, err)
if firstErr == nil {
firstErr = err
}
log.Warnf("Failed to get %s statistics for table %s: %v", listBin, table, err)
continue
}
builder.WriteString(fmt.Sprintf("*%s\n", table))
builder.WriteString(stats)
builder.WriteString("\n")
collected = true
}
if !collected {
return "", fmt.Errorf("collect %s rules: %w", listBin, firstErr)
}
return builder.String(), nil
}
@@ -214,34 +238,15 @@ func collectIPSets() (string, error) {
return ipsets, nil
}
// collectIPTablesSave uses iptables-save to get rule definitions
func collectIPTablesSave() (string, error) {
cmd := exec.Command("iptables-save")
// runCommand executes a command and returns its stdout, wrapping stderr in the error on failure.
func runCommand(name string, args ...string) (string, error) {
cmd := exec.Command(name, args...)
var stdout, stderr bytes.Buffer
cmd.Stdout = &stdout
cmd.Stderr = &stderr
if err := cmd.Run(); err != nil {
return "", fmt.Errorf("execute iptables-save: %w (stderr: %s)", err, stderr.String())
}
rules := stdout.String()
if strings.TrimSpace(rules) == "" {
return "", fmt.Errorf("no iptables rules found")
}
return rules, nil
}
// getTableStatistics gets verbose statistics for an entire table using iptables command
func getTableStatistics(table string) (string, error) {
cmd := exec.Command("iptables", "-v", "-n", "-L", "-t", table)
var stdout, stderr bytes.Buffer
cmd.Stdout = &stdout
cmd.Stderr = &stderr
if err := cmd.Run(); err != nil {
return "", fmt.Errorf("execute iptables -v -n -L: %w (stderr: %s)", err, stderr.String())
return "", fmt.Errorf("execute %s: %w (stderr: %s)", name, err, stderr.String())
}
return stdout.String(), nil
@@ -804,3 +809,91 @@ func formatSetKeyType(keyType nftables.SetDatatype) string {
return fmt.Sprintf("type-%v", keyType)
}
}
// addSysctls collects forwarding and netbird-managed sysctl values and writes them to the bundle.
func (g *BundleGenerator) addSysctls() error {
log.Info("Collecting sysctls")
content := collectSysctls()
if g.anonymize {
content = g.anonymizer.AnonymizeString(content)
}
if err := g.addFileToZip(strings.NewReader(content), "sysctls.txt"); err != nil {
return fmt.Errorf("add sysctls to bundle: %w", err)
}
return nil
}
// collectSysctls reads every sysctl that the netbird client may modify, plus
// global IPv4/IPv6 forwarding, and returns a formatted dump grouped by topic.
// Per-interface values are enumerated by listing /proc/sys/net/ipv{4,6}/conf.
func collectSysctls() string {
var builder strings.Builder
writeSysctlGroup(&builder, "forwarding", []string{
"net.ipv4.ip_forward",
"net.ipv6.conf.all.forwarding",
"net.ipv6.conf.default.forwarding",
})
writeSysctlGroup(&builder, "ipv4 per-interface forwarding", listInterfaceSysctls("ipv4", "forwarding"))
writeSysctlGroup(&builder, "ipv6 per-interface forwarding", listInterfaceSysctls("ipv6", "forwarding"))
writeSysctlGroup(&builder, "rp_filter", append(
[]string{"net.ipv4.conf.all.rp_filter", "net.ipv4.conf.default.rp_filter"},
listInterfaceSysctls("ipv4", "rp_filter")...,
))
writeSysctlGroup(&builder, "src_valid_mark", append(
[]string{"net.ipv4.conf.all.src_valid_mark", "net.ipv4.conf.default.src_valid_mark"},
listInterfaceSysctls("ipv4", "src_valid_mark")...,
))
writeSysctlGroup(&builder, "conntrack", []string{
"net.netfilter.nf_conntrack_acct",
"net.netfilter.nf_conntrack_tcp_loose",
})
writeSysctlGroup(&builder, "tcp", []string{
"net.ipv4.tcp_tw_reuse",
})
return builder.String()
}
func writeSysctlGroup(builder *strings.Builder, title string, keys []string) {
builder.WriteString(fmt.Sprintf("=== %s ===\n", title))
for _, key := range keys {
value, err := readSysctl(key)
if err != nil {
builder.WriteString(fmt.Sprintf("%s = <error: %v>\n", key, err))
continue
}
builder.WriteString(fmt.Sprintf("%s = %s\n", key, value))
}
builder.WriteString("\n")
}
// listInterfaceSysctls returns net.ipvX.conf.<iface>.<leaf> keys for every
// interface present in /proc/sys/net/ipvX/conf, skipping "all" and "default"
// (callers add those explicitly so they appear first).
func listInterfaceSysctls(family, leaf string) []string {
dir := fmt.Sprintf("/proc/sys/net/%s/conf", family)
entries, err := os.ReadDir(dir)
if err != nil {
return nil
}
var keys []string
for _, e := range entries {
name := e.Name()
if name == "all" || name == "default" {
continue
}
keys = append(keys, fmt.Sprintf("net.%s.conf.%s.%s", family, name, leaf))
}
sort.Strings(keys)
return keys
}
func readSysctl(key string) (string, error) {
path := fmt.Sprintf("/proc/sys/%s", strings.ReplaceAll(key, ".", "/"))
value, err := os.ReadFile(path)
if err != nil {
return "", err
}
return strings.TrimSpace(string(value)), nil
}

View File

@@ -17,3 +17,8 @@ func (g *BundleGenerator) addIPRules() error {
// IP rules are only supported on Linux
return nil
}
func (g *BundleGenerator) addSysctls() error {
// Sysctl collection is only supported on Linux
return nil
}

View File

@@ -16,6 +16,10 @@ type hostManager interface {
restoreHostDNS() error
supportCustomPort() bool
string() string
// getOriginalNameservers returns the OS-side resolvers used as PriorityFallback
// upstreams: pre-takeover snapshots on desktop, the OS-pushed list on Android,
// hardcoded Quad9 on iOS, nil for noop / mock.
getOriginalNameservers() []netip.Addr
}
type SystemDNSSettings struct {
@@ -131,3 +135,11 @@ func (n noopHostConfigurator) supportCustomPort() bool {
func (n noopHostConfigurator) string() string {
return "noop"
}
func (n noopHostConfigurator) getOriginalNameservers() []netip.Addr {
return nil
}
func (m *mockHostConfigurator) getOriginalNameservers() []netip.Addr {
return nil
}

View File

@@ -1,14 +1,20 @@
package dns
import (
"net/netip"
"github.com/netbirdio/netbird/client/internal/statemanager"
)
// androidHostManager is a noop on the OS side (Android's VPN service handles
// DNS for us) but tracks the OS-reported resolver list pushed via
// OnUpdatedHostDNSServer so it can serve as the fallback nameserver source.
type androidHostManager struct {
holder *hostsDNSHolder
}
func newHostManager() (*androidHostManager, error) {
return &androidHostManager{}, nil
func newHostManager(holder *hostsDNSHolder) (*androidHostManager, error) {
return &androidHostManager{holder: holder}, nil
}
func (a androidHostManager) applyDNSConfig(HostDNSConfig, *statemanager.Manager) error {
@@ -26,3 +32,12 @@ func (a androidHostManager) supportCustomPort() bool {
func (a androidHostManager) string() string {
return "none"
}
func (a androidHostManager) getOriginalNameservers() []netip.Addr {
hosts := a.holder.get()
out := make([]netip.Addr, 0, len(hosts))
for ap := range hosts {
out = append(out, ap.Addr())
}
return out
}

View File

@@ -3,6 +3,7 @@ package dns
import (
"encoding/json"
"fmt"
"net/netip"
log "github.com/sirupsen/logrus"
@@ -20,6 +21,14 @@ func newHostManager(dnsManager IosDnsManager) (*iosHostManager, error) {
}, nil
}
func (a iosHostManager) getOriginalNameservers() []netip.Addr {
// Quad9 v4+v6: 9.9.9.9, 2620:fe::fe.
return []netip.Addr{
netip.AddrFrom4([4]byte{9, 9, 9, 9}),
netip.AddrFrom16([16]byte{0x26, 0x20, 0x00, 0xfe, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0xfe}),
}
}
func (a iosHostManager) applyDNSConfig(config HostDNSConfig, _ *statemanager.Manager) error {
jsonData, err := json.Marshal(config)
if err != nil {

View File

@@ -7,6 +7,7 @@ import (
"io"
"net/netip"
"os/exec"
"slices"
"strings"
"syscall"
"time"
@@ -44,9 +45,11 @@ const (
nrptMaxDomainsPerRule = 50
interfaceConfigPath = `SYSTEM\CurrentControlSet\Services\Tcpip\Parameters\Interfaces`
interfaceConfigNameServerKey = "NameServer"
interfaceConfigSearchListKey = "SearchList"
interfaceConfigPath = `SYSTEM\CurrentControlSet\Services\Tcpip\Parameters\Interfaces`
interfaceConfigPathV6 = `SYSTEM\CurrentControlSet\Services\Tcpip6\Parameters\Interfaces`
interfaceConfigNameServerKey = "NameServer"
interfaceConfigDhcpNameSrvKey = "DhcpNameServer"
interfaceConfigSearchListKey = "SearchList"
// Network interface DNS registration settings
disableDynamicUpdateKey = "DisableDynamicUpdate"
@@ -67,10 +70,11 @@ const (
)
type registryConfigurator struct {
guid string
routingAll bool
gpo bool
nrptEntryCount int
guid string
routingAll bool
gpo bool
nrptEntryCount int
origNameservers []netip.Addr
}
func newHostManager(wgInterface WGIface) (*registryConfigurator, error) {
@@ -94,6 +98,17 @@ func newHostManager(wgInterface WGIface) (*registryConfigurator, error) {
gpo: useGPO,
}
origNameservers, err := configurator.captureOriginalNameservers()
switch {
case err != nil:
log.Warnf("capture original nameservers from non-WG adapters: %v", err)
case len(origNameservers) == 0:
log.Warnf("no original nameservers captured from non-WG adapters; DNS fallback will be empty")
default:
log.Debugf("captured %d original nameservers from non-WG adapters: %v", len(origNameservers), origNameservers)
}
configurator.origNameservers = origNameservers
if err := configurator.configureInterface(); err != nil {
log.Errorf("failed to configure interface settings: %v", err)
}
@@ -101,6 +116,98 @@ func newHostManager(wgInterface WGIface) (*registryConfigurator, error) {
return configurator, nil
}
// captureOriginalNameservers reads DNS addresses from every Tcpip(6) interface
// registry key except the WG adapter. v4 and v6 servers live in separate
// hives (Tcpip vs Tcpip6) keyed by the same interface GUID.
func (r *registryConfigurator) captureOriginalNameservers() ([]netip.Addr, error) {
seen := make(map[netip.Addr]struct{})
var out []netip.Addr
var merr *multierror.Error
for _, root := range []string{interfaceConfigPath, interfaceConfigPathV6} {
addrs, err := r.captureFromTcpipRoot(root)
if err != nil {
merr = multierror.Append(merr, fmt.Errorf("%s: %w", root, err))
continue
}
for _, addr := range addrs {
if _, dup := seen[addr]; dup {
continue
}
seen[addr] = struct{}{}
out = append(out, addr)
}
}
return out, nberrors.FormatErrorOrNil(merr)
}
func (r *registryConfigurator) captureFromTcpipRoot(rootPath string) ([]netip.Addr, error) {
root, err := registry.OpenKey(registry.LOCAL_MACHINE, rootPath, registry.READ)
if err != nil {
return nil, fmt.Errorf("open key: %w", err)
}
defer closer(root)
guids, err := root.ReadSubKeyNames(-1)
if err != nil {
return nil, fmt.Errorf("read subkeys: %w", err)
}
var out []netip.Addr
for _, guid := range guids {
if strings.EqualFold(guid, r.guid) {
continue
}
out = append(out, readInterfaceNameservers(rootPath, guid)...)
}
return out, nil
}
func readInterfaceNameservers(rootPath, guid string) []netip.Addr {
keyPath := rootPath + "\\" + guid
k, err := registry.OpenKey(registry.LOCAL_MACHINE, keyPath, registry.QUERY_VALUE)
if err != nil {
return nil
}
defer closer(k)
// Static NameServer wins over DhcpNameServer for actual resolution.
for _, name := range []string{interfaceConfigNameServerKey, interfaceConfigDhcpNameSrvKey} {
raw, _, err := k.GetStringValue(name)
if err != nil || raw == "" {
continue
}
if out := parseRegistryNameservers(raw); len(out) > 0 {
return out
}
}
return nil
}
func parseRegistryNameservers(raw string) []netip.Addr {
var out []netip.Addr
for _, field := range strings.FieldsFunc(raw, func(r rune) bool { return r == ',' || r == ' ' || r == '\t' }) {
addr, err := netip.ParseAddr(strings.TrimSpace(field))
if err != nil {
continue
}
addr = addr.Unmap()
if !addr.IsValid() || addr.IsUnspecified() {
continue
}
// Drop unzoned link-local: not routable without a scope id. If
// the user wrote "fe80::1%eth0" ParseAddr preserves the zone.
if addr.IsLinkLocalUnicast() && addr.Zone() == "" {
continue
}
out = append(out, addr)
}
return out
}
func (r *registryConfigurator) getOriginalNameservers() []netip.Addr {
return slices.Clone(r.origNameservers)
}
func (r *registryConfigurator) supportCustomPort() bool {
return false
}

View File

@@ -25,6 +25,7 @@ func (h *hostsDNSHolder) set(list []netip.AddrPort) {
h.mutex.Unlock()
}
//nolint:unused
func (h *hostsDNSHolder) get() map[netip.AddrPort]struct{} {
h.mutex.RLock()
l := h.unprotectedDNSList

View File

@@ -76,8 +76,6 @@ func (d *Resolver) ID() types.HandlerID {
return "local-resolver"
}
func (d *Resolver) ProbeAvailability(context.Context) {}
// ServeDNS handles a DNS request
func (d *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
logger := log.WithFields(log.Fields{

View File

@@ -9,6 +9,7 @@ import (
dnsconfig "github.com/netbirdio/netbird/client/internal/dns/config"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/route"
"github.com/netbirdio/netbird/shared/management/domain"
)
@@ -70,10 +71,6 @@ func (m *MockServer) SearchDomains() []string {
return make([]string, 0)
}
// ProbeAvailability mocks implementation of ProbeAvailability from the Server interface
func (m *MockServer) ProbeAvailability() {
}
func (m *MockServer) UpdateServerConfig(domains dnsconfig.ServerDomains) error {
if m.UpdateServerConfigFunc != nil {
return m.UpdateServerConfigFunc(domains)
@@ -85,8 +82,8 @@ func (m *MockServer) PopulateManagementDomain(mgmtURL *url.URL) error {
return nil
}
// SetRouteChecker mock implementation of SetRouteChecker from Server interface
func (m *MockServer) SetRouteChecker(func(netip.Addr) bool) {
// SetRouteSources mock implementation of SetRouteSources from Server interface
func (m *MockServer) SetRouteSources(selected, active func() route.HAMap) {
// Mock implementation - no-op
}

View File

@@ -8,6 +8,7 @@ import (
"errors"
"fmt"
"net/netip"
"slices"
"strings"
"time"
@@ -32,6 +33,15 @@ const (
networkManagerDbusDeviceGetAppliedConnectionMethod = networkManagerDbusDeviceInterface + ".GetAppliedConnection"
networkManagerDbusDeviceReapplyMethod = networkManagerDbusDeviceInterface + ".Reapply"
networkManagerDbusDeviceDeleteMethod = networkManagerDbusDeviceInterface + ".Delete"
networkManagerDbusDeviceIp4ConfigProperty = networkManagerDbusDeviceInterface + ".Ip4Config"
networkManagerDbusDeviceIp6ConfigProperty = networkManagerDbusDeviceInterface + ".Ip6Config"
networkManagerDbusDeviceIfaceProperty = networkManagerDbusDeviceInterface + ".Interface"
networkManagerDbusGetDevicesMethod = networkManagerDest + ".GetDevices"
networkManagerDbusIp4ConfigInterface = "org.freedesktop.NetworkManager.IP4Config"
networkManagerDbusIp6ConfigInterface = "org.freedesktop.NetworkManager.IP6Config"
networkManagerDbusIp4ConfigNameserverDataProperty = networkManagerDbusIp4ConfigInterface + ".NameserverData"
networkManagerDbusIp4ConfigNameserversProperty = networkManagerDbusIp4ConfigInterface + ".Nameservers"
networkManagerDbusIp6ConfigNameserversProperty = networkManagerDbusIp6ConfigInterface + ".Nameservers"
networkManagerDbusDefaultBehaviorFlag networkManagerConfigBehavior = 0
networkManagerDbusIPv4Key = "ipv4"
networkManagerDbusIPv6Key = "ipv6"
@@ -51,9 +61,10 @@ var supportedNetworkManagerVersionConstraints = []string{
}
type networkManagerDbusConfigurator struct {
dbusLinkObject dbus.ObjectPath
routingAll bool
ifaceName string
dbusLinkObject dbus.ObjectPath
routingAll bool
ifaceName string
origNameservers []netip.Addr
}
// the types below are based on dbus specification, each field is mapped to a dbus type
@@ -92,10 +103,200 @@ func newNetworkManagerDbusConfigurator(wgInterface string) (*networkManagerDbusC
log.Debugf("got network manager dbus Link Object: %s from net interface %s", s, wgInterface)
return &networkManagerDbusConfigurator{
c := &networkManagerDbusConfigurator{
dbusLinkObject: dbus.ObjectPath(s),
ifaceName: wgInterface,
}, nil
}
origNameservers, err := c.captureOriginalNameservers()
switch {
case err != nil:
log.Warnf("capture original nameservers from NetworkManager: %v", err)
case len(origNameservers) == 0:
log.Warnf("no original nameservers captured from non-WG NetworkManager devices; DNS fallback will be empty")
default:
log.Debugf("captured %d original nameservers from non-WG NetworkManager devices: %v", len(origNameservers), origNameservers)
}
c.origNameservers = origNameservers
return c, nil
}
// captureOriginalNameservers reads DNS servers from every NM device's
// IP4Config / IP6Config except our WG device.
func (n *networkManagerDbusConfigurator) captureOriginalNameservers() ([]netip.Addr, error) {
devices, err := networkManagerListDevices()
if err != nil {
return nil, fmt.Errorf("list devices: %w", err)
}
seen := make(map[netip.Addr]struct{})
var out []netip.Addr
for _, dev := range devices {
if dev == n.dbusLinkObject {
continue
}
ifaceName := readNetworkManagerDeviceInterface(dev)
for _, addr := range readNetworkManagerDeviceDNS(dev) {
addr = addr.Unmap()
if !addr.IsValid() || addr.IsUnspecified() {
continue
}
// IP6Config.Nameservers is a byte slice without zone info;
// reattach the device's interface name so a captured fe80::…
// stays routable.
if addr.IsLinkLocalUnicast() && ifaceName != "" {
addr = addr.WithZone(ifaceName)
}
if _, dup := seen[addr]; dup {
continue
}
seen[addr] = struct{}{}
out = append(out, addr)
}
}
return out, nil
}
func readNetworkManagerDeviceInterface(devicePath dbus.ObjectPath) string {
obj, closeConn, err := getDbusObject(networkManagerDest, devicePath)
if err != nil {
return ""
}
defer closeConn()
v, err := obj.GetProperty(networkManagerDbusDeviceIfaceProperty)
if err != nil {
return ""
}
s, _ := v.Value().(string)
return s
}
func networkManagerListDevices() ([]dbus.ObjectPath, error) {
obj, closeConn, err := getDbusObject(networkManagerDest, networkManagerDbusObjectNode)
if err != nil {
return nil, fmt.Errorf("dbus NetworkManager: %w", err)
}
defer closeConn()
var devs []dbus.ObjectPath
if err := obj.Call(networkManagerDbusGetDevicesMethod, dbusDefaultFlag).Store(&devs); err != nil {
return nil, err
}
return devs, nil
}
func readNetworkManagerDeviceDNS(devicePath dbus.ObjectPath) []netip.Addr {
obj, closeConn, err := getDbusObject(networkManagerDest, devicePath)
if err != nil {
return nil
}
defer closeConn()
var out []netip.Addr
if path := readNetworkManagerConfigPath(obj, networkManagerDbusDeviceIp4ConfigProperty); path != "" {
out = append(out, readIPv4ConfigDNS(path)...)
}
if path := readNetworkManagerConfigPath(obj, networkManagerDbusDeviceIp6ConfigProperty); path != "" {
out = append(out, readIPv6ConfigDNS(path)...)
}
return out
}
func readNetworkManagerConfigPath(obj dbus.BusObject, property string) dbus.ObjectPath {
v, err := obj.GetProperty(property)
if err != nil {
return ""
}
path, ok := v.Value().(dbus.ObjectPath)
if !ok || path == "/" {
return ""
}
return path
}
func readIPv4ConfigDNS(path dbus.ObjectPath) []netip.Addr {
obj, closeConn, err := getDbusObject(networkManagerDest, path)
if err != nil {
return nil
}
defer closeConn()
// NameserverData (NM 1.13+) carries strings; older NMs only expose the
// legacy uint32 Nameservers property.
if out := readIPv4NameserverData(obj); len(out) > 0 {
return out
}
return readIPv4LegacyNameservers(obj)
}
func readIPv4NameserverData(obj dbus.BusObject) []netip.Addr {
v, err := obj.GetProperty(networkManagerDbusIp4ConfigNameserverDataProperty)
if err != nil {
return nil
}
entries, ok := v.Value().([]map[string]dbus.Variant)
if !ok {
return nil
}
var out []netip.Addr
for _, entry := range entries {
addrVar, ok := entry["address"]
if !ok {
continue
}
s, ok := addrVar.Value().(string)
if !ok {
continue
}
if a, err := netip.ParseAddr(s); err == nil {
out = append(out, a)
}
}
return out
}
func readIPv4LegacyNameservers(obj dbus.BusObject) []netip.Addr {
v, err := obj.GetProperty(networkManagerDbusIp4ConfigNameserversProperty)
if err != nil {
return nil
}
raw, ok := v.Value().([]uint32)
if !ok {
return nil
}
out := make([]netip.Addr, 0, len(raw))
for _, n := range raw {
var b [4]byte
binary.LittleEndian.PutUint32(b[:], n)
out = append(out, netip.AddrFrom4(b))
}
return out
}
func readIPv6ConfigDNS(path dbus.ObjectPath) []netip.Addr {
obj, closeConn, err := getDbusObject(networkManagerDest, path)
if err != nil {
return nil
}
defer closeConn()
v, err := obj.GetProperty(networkManagerDbusIp6ConfigNameserversProperty)
if err != nil {
return nil
}
raw, ok := v.Value().([][]byte)
if !ok {
return nil
}
out := make([]netip.Addr, 0, len(raw))
for _, b := range raw {
if a, ok := netip.AddrFromSlice(b); ok {
out = append(out, a)
}
}
return out
}
func (n *networkManagerDbusConfigurator) getOriginalNameservers() []netip.Addr {
return slices.Clone(n.origNameservers)
}
func (n *networkManagerDbusConfigurator) supportCustomPort() bool {

File diff suppressed because it is too large Load Diff

View File

@@ -1,5 +1,5 @@
package dns
func (s *DefaultServer) initialize() (manager hostManager, err error) {
return newHostManager()
return newHostManager(s.hostsDNSHolder)
}

View File

@@ -6,7 +6,7 @@ import (
"net"
"net/netip"
"os"
"strings"
"runtime"
"testing"
"time"
@@ -15,6 +15,7 @@ import (
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
"golang.zx2c4.com/wireguard/tun/netstack"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
@@ -31,8 +32,10 @@ import (
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/statemanager"
"github.com/netbirdio/netbird/client/internal/stdnet"
"github.com/netbirdio/netbird/client/proto"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/formatter"
"github.com/netbirdio/netbird/route"
"github.com/netbirdio/netbird/shared/management/domain"
)
@@ -101,16 +104,17 @@ func init() {
formatter.SetTextFormatter(log.StandardLogger())
}
func generateDummyHandler(domain string, servers []nbdns.NameServer) *upstreamResolverBase {
func generateDummyHandler(d string, servers []nbdns.NameServer) *upstreamResolverBase {
var srvs []netip.AddrPort
for _, srv := range servers {
srvs = append(srvs, srv.AddrPort())
}
return &upstreamResolverBase{
domain: domain,
upstreamServers: srvs,
cancel: func() {},
u := &upstreamResolverBase{
domain: domain.Domain(d),
cancel: func() {},
}
u.addRace(srvs)
return u
}
func TestUpdateDNSServer(t *testing.T) {
@@ -653,74 +657,8 @@ func TestDNSServerStartStop(t *testing.T) {
}
}
func TestDNSServerUpstreamDeactivateCallback(t *testing.T) {
hostManager := &mockHostConfigurator{}
server := DefaultServer{
ctx: context.Background(),
service: NewServiceViaMemory(&mocWGIface{}),
localResolver: local.NewResolver(),
handlerChain: NewHandlerChain(),
hostManager: hostManager,
currentConfig: HostDNSConfig{
Domains: []DomainConfig{
{false, "domain0", false},
{false, "domain1", false},
{false, "domain2", false},
},
},
statusRecorder: peer.NewRecorder("mgm"),
}
var domainsUpdate string
hostManager.applyDNSConfigFunc = func(config HostDNSConfig, statemanager *statemanager.Manager) error {
domains := []string{}
for _, item := range config.Domains {
if item.Disabled {
continue
}
domains = append(domains, item.Domain)
}
domainsUpdate = strings.Join(domains, ",")
return nil
}
deactivate, reactivate := server.upstreamCallbacks(&nbdns.NameServerGroup{
Domains: []string{"domain1"},
NameServers: []nbdns.NameServer{
{IP: netip.MustParseAddr("8.8.0.0"), NSType: nbdns.UDPNameServerType, Port: 53},
},
}, nil, 0)
deactivate(nil)
expected := "domain0,domain2"
domains := []string{}
for _, item := range server.currentConfig.Domains {
if item.Disabled {
continue
}
domains = append(domains, item.Domain)
}
got := strings.Join(domains, ",")
if expected != got {
t.Errorf("expected domains list: %q, got %q", expected, got)
}
reactivate()
expected = "domain0,domain1,domain2"
domains = []string{}
for _, item := range server.currentConfig.Domains {
if item.Disabled {
continue
}
domains = append(domains, item.Domain)
}
got = strings.Join(domains, ",")
if expected != got {
t.Errorf("expected domains list: %q, got %q", expected, domainsUpdate)
}
}
func TestDNSPermanent_updateHostDNS_emptyUpstream(t *testing.T) {
skipUnlessAndroid(t)
wgIFace, err := createWgInterfaceWithBind(t)
if err != nil {
t.Fatal("failed to initialize wg interface")
@@ -748,6 +686,7 @@ func TestDNSPermanent_updateHostDNS_emptyUpstream(t *testing.T) {
}
func TestDNSPermanent_updateUpstream(t *testing.T) {
skipUnlessAndroid(t)
wgIFace, err := createWgInterfaceWithBind(t)
if err != nil {
t.Fatal("failed to initialize wg interface")
@@ -841,6 +780,7 @@ func TestDNSPermanent_updateUpstream(t *testing.T) {
}
func TestDNSPermanent_matchOnly(t *testing.T) {
skipUnlessAndroid(t)
wgIFace, err := createWgInterfaceWithBind(t)
if err != nil {
t.Fatal("failed to initialize wg interface")
@@ -913,6 +853,18 @@ func TestDNSPermanent_matchOnly(t *testing.T) {
}
}
// skipUnlessAndroid marks tests that exercise the mobile-permanent DNS path,
// which only matches a real production setup on android (NewDefaultServerPermanentUpstream
// + androidHostManager). On non-android the desktop host manager replaces it
// during Initialize and the assertion stops making sense. Skipped here until we
// have an android CI runner.
func skipUnlessAndroid(t *testing.T) {
t.Helper()
if runtime.GOOS != "android" {
t.Skip("requires android runner; mobile-permanent path doesn't match production on this OS")
}
}
func createWgInterfaceWithBind(t *testing.T) (*iface.WGIface, error) {
t.Helper()
ov := os.Getenv("NB_WG_KERNEL_DISABLED")
@@ -1065,7 +1017,6 @@ type mockHandler struct {
func (m *mockHandler) ServeDNS(dns.ResponseWriter, *dns.Msg) {}
func (m *mockHandler) Stop() {}
func (m *mockHandler) ProbeAvailability(context.Context) {}
func (m *mockHandler) ID() types.HandlerID { return types.HandlerID(m.Id) }
type mockService struct{}
@@ -2085,6 +2036,598 @@ func TestLocalResolverPriorityConstants(t *testing.T) {
assert.Equal(t, "local.example.com", localMuxUpdates[0].domain)
}
// TestBuildUpstreamHandler_MergesGroupsPerDomain verifies that multiple
// admin-defined nameserver groups targeting the same domain collapse into a
// single handler with each group preserved as a sequential inner list.
func TestBuildUpstreamHandler_MergesGroupsPerDomain(t *testing.T) {
wgInterface := &mocWGIface{}
service := NewServiceViaMemory(wgInterface)
server := &DefaultServer{
ctx: context.Background(),
wgInterface: wgInterface,
service: service,
localResolver: local.NewResolver(),
handlerChain: NewHandlerChain(),
hostManager: &noopHostConfigurator{},
dnsMuxMap: make(registeredHandlerMap),
}
groups := []*nbdns.NameServerGroup{
{
NameServers: []nbdns.NameServer{
{IP: netip.MustParseAddr("192.0.2.1"), NSType: nbdns.UDPNameServerType, Port: 53},
},
Domains: []string{"example.com"},
},
{
NameServers: []nbdns.NameServer{
{IP: netip.MustParseAddr("192.0.2.2"), NSType: nbdns.UDPNameServerType, Port: 53},
{IP: netip.MustParseAddr("192.0.2.3"), NSType: nbdns.UDPNameServerType, Port: 53},
},
Domains: []string{"example.com"},
},
}
muxUpdates, err := server.buildUpstreamHandlerUpdate(groups)
require.NoError(t, err)
require.Len(t, muxUpdates, 1, "same-domain groups should merge into one handler")
assert.Equal(t, "example.com", muxUpdates[0].domain)
assert.Equal(t, PriorityUpstream, muxUpdates[0].priority)
handler := muxUpdates[0].handler.(*upstreamResolver)
require.Len(t, handler.upstreamServers, 2, "handler should have two groups")
assert.Equal(t, upstreamRace{netip.MustParseAddrPort("192.0.2.1:53")}, handler.upstreamServers[0])
assert.Equal(t, upstreamRace{
netip.MustParseAddrPort("192.0.2.2:53"),
netip.MustParseAddrPort("192.0.2.3:53"),
}, handler.upstreamServers[1])
}
// TestEvaluateNSGroupHealth covers the records-only verdict. The gate
// (overlay route selected-but-no-active-peer) is intentionally NOT an
// input to the evaluator anymore: the verdict drives the Enabled flag,
// which must always reflect what we actually observed. Gate-aware event
// suppression is tested separately in the projection test.
//
// Matrix per upstream: {no record, fresh Ok, fresh Fail, stale Fail,
// stale Ok, Ok newer than Fail, Fail newer than Ok}.
// Group verdict: any fresh-working → Healthy; any fresh-broken with no
// fresh-working → Unhealthy; otherwise Undecided.
func TestEvaluateNSGroupHealth(t *testing.T) {
now := time.Now()
a := netip.MustParseAddrPort("192.0.2.1:53")
b := netip.MustParseAddrPort("192.0.2.2:53")
recentOk := UpstreamHealth{LastOk: now.Add(-2 * time.Second)}
recentFail := UpstreamHealth{LastFail: now.Add(-1 * time.Second), LastErr: "timeout"}
staleOk := UpstreamHealth{LastOk: now.Add(-10 * time.Minute)}
staleFail := UpstreamHealth{LastFail: now.Add(-10 * time.Minute), LastErr: "timeout"}
okThenFail := UpstreamHealth{
LastOk: now.Add(-10 * time.Second),
LastFail: now.Add(-1 * time.Second),
LastErr: "timeout",
}
failThenOk := UpstreamHealth{
LastOk: now.Add(-1 * time.Second),
LastFail: now.Add(-10 * time.Second),
LastErr: "timeout",
}
tests := []struct {
name string
health map[netip.AddrPort]UpstreamHealth
servers []netip.AddrPort
wantVerdict nsGroupVerdict
wantErrSubst string
}{
{
name: "no record, undecided",
servers: []netip.AddrPort{a},
wantVerdict: nsVerdictUndecided,
},
{
name: "fresh success, healthy",
health: map[netip.AddrPort]UpstreamHealth{a: recentOk},
servers: []netip.AddrPort{a},
wantVerdict: nsVerdictHealthy,
},
{
name: "fresh failure, unhealthy",
health: map[netip.AddrPort]UpstreamHealth{a: recentFail},
servers: []netip.AddrPort{a},
wantVerdict: nsVerdictUnhealthy,
wantErrSubst: "timeout",
},
{
name: "only stale success, undecided",
health: map[netip.AddrPort]UpstreamHealth{a: staleOk},
servers: []netip.AddrPort{a},
wantVerdict: nsVerdictUndecided,
},
{
name: "only stale failure, undecided",
health: map[netip.AddrPort]UpstreamHealth{a: staleFail},
servers: []netip.AddrPort{a},
wantVerdict: nsVerdictUndecided,
},
{
name: "both fresh, fail newer, unhealthy",
health: map[netip.AddrPort]UpstreamHealth{a: okThenFail},
servers: []netip.AddrPort{a},
wantVerdict: nsVerdictUnhealthy,
wantErrSubst: "timeout",
},
{
name: "both fresh, ok newer, healthy",
health: map[netip.AddrPort]UpstreamHealth{a: failThenOk},
servers: []netip.AddrPort{a},
wantVerdict: nsVerdictHealthy,
},
{
name: "two upstreams, one success wins",
health: map[netip.AddrPort]UpstreamHealth{
a: recentFail,
b: recentOk,
},
servers: []netip.AddrPort{a, b},
wantVerdict: nsVerdictHealthy,
},
{
name: "two upstreams, one fail one unseen, unhealthy",
health: map[netip.AddrPort]UpstreamHealth{
a: recentFail,
},
servers: []netip.AddrPort{a, b},
wantVerdict: nsVerdictUnhealthy,
wantErrSubst: "timeout",
},
{
name: "two upstreams, all recent failures, unhealthy",
health: map[netip.AddrPort]UpstreamHealth{
a: {LastFail: now.Add(-5 * time.Second), LastErr: "timeout"},
b: {LastFail: now.Add(-1 * time.Second), LastErr: "SERVFAIL"},
},
servers: []netip.AddrPort{a, b},
wantVerdict: nsVerdictUnhealthy,
wantErrSubst: "SERVFAIL",
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
verdict, err := evaluateNSGroupHealth(tc.health, tc.servers, now)
assert.Equal(t, tc.wantVerdict, verdict, "verdict mismatch")
if tc.wantErrSubst != "" {
require.Error(t, err)
assert.Contains(t, err.Error(), tc.wantErrSubst)
} else {
assert.NoError(t, err)
}
})
}
}
// healthStubHandler is a minimal dnsMuxMap entry that exposes a fixed
// UpstreamHealth snapshot, letting tests drive recomputeNSGroupStates
// without spinning up real handlers.
type healthStubHandler struct {
health map[netip.AddrPort]UpstreamHealth
}
func (h *healthStubHandler) ServeDNS(dns.ResponseWriter, *dns.Msg) {}
func (h *healthStubHandler) Stop() {}
func (h *healthStubHandler) ID() types.HandlerID { return "health-stub" }
func (h *healthStubHandler) UpstreamHealth() map[netip.AddrPort]UpstreamHealth {
return h.health
}
// TestProjection_SteadyStateIsSilent guards against duplicate events:
// while a group stays Unhealthy tick after tick, only the first
// Unhealthy transition may emit. Same for staying Healthy.
func TestProjection_SteadyStateIsSilent(t *testing.T) {
fx := newProjTestFixture(t)
fx.setHealth(UpstreamHealth{LastFail: time.Now(), LastErr: "timeout"})
fx.tick()
fx.expectEvent("unreachable", "first fail emits warning")
fx.setHealth(UpstreamHealth{LastFail: time.Now(), LastErr: "timeout"})
fx.tick()
fx.tick()
fx.expectNoEvent("staying unhealthy must not re-emit")
fx.setHealth(UpstreamHealth{LastOk: time.Now()})
fx.tick()
fx.expectEvent("recovered", "recovery on transition")
fx.tick()
fx.tick()
fx.expectNoEvent("staying healthy must not re-emit")
}
// projTestFixture is the common setup for the projection tests: a
// single-upstream group whose route classification the test can flip by
// assigning to selected/active. Callers drive failures/successes by
// mutating stub.health and calling refreshHealth.
type projTestFixture struct {
t *testing.T
recorder *peer.Status
events <-chan *proto.SystemEvent
server *DefaultServer
stub *healthStubHandler
group *nbdns.NameServerGroup
srv netip.AddrPort
selected route.HAMap
active route.HAMap
}
func newProjTestFixture(t *testing.T) *projTestFixture {
t.Helper()
recorder := peer.NewRecorder("mgm")
sub := recorder.SubscribeToEvents()
t.Cleanup(func() { recorder.UnsubscribeFromEvents(sub) })
srv := netip.MustParseAddrPort("100.64.0.1:53")
fx := &projTestFixture{
t: t,
recorder: recorder,
events: sub.Events(),
stub: &healthStubHandler{health: map[netip.AddrPort]UpstreamHealth{}},
srv: srv,
group: &nbdns.NameServerGroup{
Domains: []string{"example.com"},
NameServers: []nbdns.NameServer{{IP: srv.Addr(), NSType: nbdns.UDPNameServerType, Port: int(srv.Port())}},
},
}
fx.server = &DefaultServer{
ctx: context.Background(),
wgInterface: &mocWGIface{},
statusRecorder: recorder,
dnsMuxMap: make(registeredHandlerMap),
selectedRoutes: func() route.HAMap { return fx.selected },
activeRoutes: func() route.HAMap { return fx.active },
warningDelayBase: defaultWarningDelayBase,
}
fx.server.dnsMuxMap["example.com"] = handlerWrapper{domain: "example.com", handler: fx.stub, priority: PriorityUpstream}
fx.server.mux.Lock()
fx.server.updateNSGroupStates([]*nbdns.NameServerGroup{fx.group})
fx.server.mux.Unlock()
return fx
}
func (f *projTestFixture) setHealth(h UpstreamHealth) {
f.stub.health = map[netip.AddrPort]UpstreamHealth{f.srv: h}
}
func (f *projTestFixture) tick() []peer.NSGroupState {
f.server.refreshHealth()
return f.recorder.GetDNSStates()
}
func (f *projTestFixture) expectNoEvent(why string) {
f.t.Helper()
select {
case evt := <-f.events:
f.t.Fatalf("unexpected event (%s): %+v", why, evt)
case <-time.After(100 * time.Millisecond):
}
}
func (f *projTestFixture) expectEvent(substr, why string) *proto.SystemEvent {
f.t.Helper()
select {
case evt := <-f.events:
assert.Contains(f.t, evt.Message, substr, why)
return evt
case <-time.After(time.Second):
f.t.Fatalf("expected event (%s) with %q", why, substr)
return nil
}
}
var overlayNetForTest = netip.MustParsePrefix("100.64.0.0/16")
var overlayMapForTest = route.HAMap{"overlay": {{Network: overlayNetForTest}}}
// TestProjection_PublicFailEmitsImmediately covers rule 1: an upstream
// that is not inside any selected route (public DNS) fires the warning
// on the first Unhealthy tick, no grace period.
func TestProjection_PublicFailEmitsImmediately(t *testing.T) {
fx := newProjTestFixture(t)
fx.setHealth(UpstreamHealth{LastFail: time.Now(), LastErr: "timeout"})
states := fx.tick()
require.Len(t, states, 1)
assert.False(t, states[0].Enabled)
fx.expectEvent("unreachable", "public DNS failure")
}
// TestProjection_OverlayConnectedFailEmitsImmediately covers rule 2:
// the upstream is inside a selected route AND the route has a Connected
// peer. Tunnel is up, failure is real, emit immediately.
func TestProjection_OverlayConnectedFailEmitsImmediately(t *testing.T) {
fx := newProjTestFixture(t)
fx.selected = overlayMapForTest
fx.active = overlayMapForTest
fx.setHealth(UpstreamHealth{LastFail: time.Now(), LastErr: "timeout"})
states := fx.tick()
require.Len(t, states, 1)
assert.False(t, states[0].Enabled)
fx.expectEvent("unreachable", "overlay + connected failure")
}
// TestProjection_OverlayNotConnectedDelaysWarning covers rule 3: the
// upstream is routed but no peer is Connected (Connecting/Idle/missing).
// First tick: Unhealthy display, no warning. After the grace window
// elapses with no recovery, the warning fires.
func TestProjection_OverlayNotConnectedDelaysWarning(t *testing.T) {
grace := 50 * time.Millisecond
fx := newProjTestFixture(t)
fx.server.warningDelayBase = grace
fx.selected = overlayMapForTest
// active stays nil: routed but not connected.
fx.setHealth(UpstreamHealth{LastFail: time.Now(), LastErr: "timeout"})
states := fx.tick()
require.Len(t, states, 1)
assert.False(t, states[0].Enabled, "display must reflect failure even during grace window")
fx.expectNoEvent("first fail tick within grace window")
time.Sleep(grace + 10*time.Millisecond)
fx.setHealth(UpstreamHealth{LastFail: time.Now(), LastErr: "timeout"})
fx.tick()
fx.expectEvent("unreachable", "warning after grace window")
}
// TestProjection_OverlayAddrNoRouteDelaysWarning covers an upstream
// whose address is inside the WireGuard overlay range but is not
// covered by any selected route (peer-to-peer DNS without an explicit
// route). Until a peer reports Connected for that address, startup
// failures must be held just like the routed case.
func TestProjection_OverlayAddrNoRouteDelaysWarning(t *testing.T) {
recorder := peer.NewRecorder("mgm")
sub := recorder.SubscribeToEvents()
t.Cleanup(func() { recorder.UnsubscribeFromEvents(sub) })
overlayPeer := netip.MustParseAddrPort("100.66.100.5:53")
server := &DefaultServer{
ctx: context.Background(),
wgInterface: &mocWGIface{},
statusRecorder: recorder,
dnsMuxMap: make(registeredHandlerMap),
selectedRoutes: func() route.HAMap { return nil },
activeRoutes: func() route.HAMap { return nil },
warningDelayBase: 50 * time.Millisecond,
}
group := &nbdns.NameServerGroup{
Domains: []string{"example.com"},
NameServers: []nbdns.NameServer{{IP: overlayPeer.Addr(), NSType: nbdns.UDPNameServerType, Port: int(overlayPeer.Port())}},
}
stub := &healthStubHandler{health: map[netip.AddrPort]UpstreamHealth{
overlayPeer: {LastFail: time.Now(), LastErr: "timeout"},
}}
server.dnsMuxMap["example.com"] = handlerWrapper{domain: "example.com", handler: stub, priority: PriorityUpstream}
server.mux.Lock()
server.updateNSGroupStates([]*nbdns.NameServerGroup{group})
server.mux.Unlock()
server.refreshHealth()
select {
case evt := <-sub.Events():
t.Fatalf("unexpected event during grace window: %+v", evt)
case <-time.After(100 * time.Millisecond):
}
time.Sleep(60 * time.Millisecond)
stub.health = map[netip.AddrPort]UpstreamHealth{overlayPeer: {LastFail: time.Now(), LastErr: "timeout"}}
server.refreshHealth()
select {
case evt := <-sub.Events():
assert.Contains(t, evt.Message, "unreachable")
case <-time.After(time.Second):
t.Fatal("expected warning after grace window")
}
}
// TestProjection_StopClearsHealthState verifies that Stop wipes the
// per-group projection state so a subsequent Start doesn't inherit
// sticky flags (notably everHealthy) that would bypass the grace
// window during the next peer handshake.
func TestProjection_StopClearsHealthState(t *testing.T) {
wgIface := &mocWGIface{}
server := &DefaultServer{
ctx: context.Background(),
wgInterface: wgIface,
service: NewServiceViaMemory(wgIface),
hostManager: &noopHostConfigurator{},
extraDomains: map[domain.Domain]int{},
dnsMuxMap: make(registeredHandlerMap),
statusRecorder: peer.NewRecorder("mgm"),
selectedRoutes: func() route.HAMap { return nil },
activeRoutes: func() route.HAMap { return nil },
warningDelayBase: defaultWarningDelayBase,
currentConfigHash: ^uint64(0),
}
server.ctx, server.ctxCancel = context.WithCancel(context.Background())
srv := netip.MustParseAddrPort("8.8.8.8:53")
group := &nbdns.NameServerGroup{
Domains: []string{"example.com"},
NameServers: []nbdns.NameServer{{IP: srv.Addr(), NSType: nbdns.UDPNameServerType, Port: int(srv.Port())}},
}
stub := &healthStubHandler{health: map[netip.AddrPort]UpstreamHealth{srv: {LastOk: time.Now()}}}
server.dnsMuxMap["example.com"] = handlerWrapper{domain: "example.com", handler: stub, priority: PriorityUpstream}
server.mux.Lock()
server.updateNSGroupStates([]*nbdns.NameServerGroup{group})
server.mux.Unlock()
server.refreshHealth()
server.healthProjectMu.Lock()
p, ok := server.nsGroupProj[generateGroupKey(group)]
server.healthProjectMu.Unlock()
require.True(t, ok, "projection state should exist after tick")
require.True(t, p.everHealthy, "tick with success must set everHealthy")
server.Stop()
server.healthProjectMu.Lock()
cleared := server.nsGroupProj == nil
server.healthProjectMu.Unlock()
assert.True(t, cleared, "Stop must clear nsGroupProj")
}
// TestProjection_OverlayRecoversDuringGrace covers the happy path of
// rule 3: startup failures while the peer is handshaking, then the peer
// comes up and a query succeeds before the grace window elapses. No
// warning should ever have fired, and no recovery either.
func TestProjection_OverlayRecoversDuringGrace(t *testing.T) {
fx := newProjTestFixture(t)
fx.server.warningDelayBase = 200 * time.Millisecond
fx.selected = overlayMapForTest
fx.setHealth(UpstreamHealth{LastFail: time.Now(), LastErr: "timeout"})
fx.tick()
fx.expectNoEvent("fail within grace, warning suppressed")
fx.active = overlayMapForTest
fx.setHealth(UpstreamHealth{LastOk: time.Now()})
states := fx.tick()
require.Len(t, states, 1)
assert.True(t, states[0].Enabled)
fx.expectNoEvent("recovery without prior warning must not emit")
}
// TestProjection_RecoveryOnlyAfterWarning enforces the invariant the
// whole design leans on: recovery events only appear when a warning
// event was actually emitted for the current streak. A Healthy verdict
// without a prior warning is silent, so the user never sees "recovered"
// out of thin air.
func TestProjection_RecoveryOnlyAfterWarning(t *testing.T) {
fx := newProjTestFixture(t)
fx.setHealth(UpstreamHealth{LastOk: time.Now()})
states := fx.tick()
require.Len(t, states, 1)
assert.True(t, states[0].Enabled)
fx.expectNoEvent("first healthy tick should not recover anything")
fx.setHealth(UpstreamHealth{LastFail: time.Now(), LastErr: "timeout"})
fx.tick()
fx.expectEvent("unreachable", "public fail emits immediately")
fx.setHealth(UpstreamHealth{LastOk: time.Now()})
fx.tick()
fx.expectEvent("recovered", "recovery follows real warning")
fx.setHealth(UpstreamHealth{LastFail: time.Now(), LastErr: "timeout"})
fx.tick()
fx.expectEvent("unreachable", "second cycle warning")
fx.setHealth(UpstreamHealth{LastOk: time.Now()})
fx.tick()
fx.expectEvent("recovered", "second cycle recovery")
}
// TestProjection_EverHealthyOverridesDelay covers rule 4: once a group
// has ever been Healthy, subsequent failures skip the grace window even
// if classification says "routed + not connected". The system has
// proved it can work, so any new failure is real.
func TestProjection_EverHealthyOverridesDelay(t *testing.T) {
fx := newProjTestFixture(t)
// Large base so any emission must come from the everHealthy bypass, not elapsed time.
fx.server.warningDelayBase = time.Hour
fx.selected = overlayMapForTest
fx.active = overlayMapForTest
// Establish "ever healthy".
fx.setHealth(UpstreamHealth{LastOk: time.Now()})
fx.tick()
fx.expectNoEvent("first healthy tick")
// Peer drops. Query fails. Routed + not connected → normally grace,
// but everHealthy flag bypasses it.
fx.active = nil
fx.setHealth(UpstreamHealth{LastFail: time.Now(), LastErr: "timeout"})
fx.tick()
fx.expectEvent("unreachable", "failure after ever-healthy must be immediate")
}
// TestProjection_ReconnectBlipEmitsPair covers the explicit tradeoff
// from the design discussion: once a group has been healthy, a brief
// reconnect that produces a failing tick will fire warning + recovery.
// This is by design: user-visible blips are accurate signal, not noise.
func TestProjection_ReconnectBlipEmitsPair(t *testing.T) {
fx := newProjTestFixture(t)
fx.selected = overlayMapForTest
fx.active = overlayMapForTest
fx.setHealth(UpstreamHealth{LastOk: time.Now()})
fx.tick()
fx.setHealth(UpstreamHealth{LastFail: time.Now(), LastErr: "timeout"})
fx.tick()
fx.expectEvent("unreachable", "blip warning")
fx.setHealth(UpstreamHealth{LastOk: time.Now()})
fx.tick()
fx.expectEvent("recovered", "blip recovery")
}
// TestProjection_MixedGroupEmitsImmediately covers the multi-upstream
// rule: a group with at least one public upstream is in the "immediate"
// category regardless of the other upstreams' routing, because the
// public one has no peer-startup excuse. Prevents public-DNS failures
// from being hidden behind a routed sibling.
func TestProjection_MixedGroupEmitsImmediately(t *testing.T) {
recorder := peer.NewRecorder("mgm")
sub := recorder.SubscribeToEvents()
t.Cleanup(func() { recorder.UnsubscribeFromEvents(sub) })
events := sub.Events()
public := netip.MustParseAddrPort("8.8.8.8:53")
overlay := netip.MustParseAddrPort("100.64.0.1:53")
overlayMap := route.HAMap{"overlay": {{Network: netip.MustParsePrefix("100.64.0.0/16")}}}
server := &DefaultServer{
ctx: context.Background(),
statusRecorder: recorder,
dnsMuxMap: make(registeredHandlerMap),
selectedRoutes: func() route.HAMap { return overlayMap },
activeRoutes: func() route.HAMap { return nil },
warningDelayBase: time.Hour,
}
group := &nbdns.NameServerGroup{
Domains: []string{"example.com"},
NameServers: []nbdns.NameServer{
{IP: public.Addr(), NSType: nbdns.UDPNameServerType, Port: int(public.Port())},
{IP: overlay.Addr(), NSType: nbdns.UDPNameServerType, Port: int(overlay.Port())},
},
}
stub := &healthStubHandler{
health: map[netip.AddrPort]UpstreamHealth{
public: {LastFail: time.Now(), LastErr: "servfail"},
overlay: {LastFail: time.Now(), LastErr: "timeout"},
},
}
server.dnsMuxMap["example.com"] = handlerWrapper{domain: "example.com", handler: stub, priority: PriorityUpstream}
server.mux.Lock()
server.updateNSGroupStates([]*nbdns.NameServerGroup{group})
server.mux.Unlock()
server.refreshHealth()
select {
case evt := <-events:
assert.Contains(t, evt.Message, "unreachable")
case <-time.After(time.Second):
t.Fatal("expected immediate warning because group contains a public upstream")
}
}
func TestDNSLoopPrevention(t *testing.T) {
wgInterface := &mocWGIface{}
service := NewServiceViaMemory(wgInterface)
@@ -2183,17 +2726,18 @@ func TestDNSLoopPrevention(t *testing.T) {
if tt.expectedHandlers > 0 {
handler := muxUpdates[0].handler.(*upstreamResolver)
assert.Len(t, handler.upstreamServers, len(tt.expectedServers))
flat := handler.flatUpstreams()
assert.Len(t, flat, len(tt.expectedServers))
if tt.shouldFilterOwnIP {
for _, upstream := range handler.upstreamServers {
for _, upstream := range flat {
assert.NotEqual(t, dnsServerIP, upstream.Addr())
}
}
for _, expected := range tt.expectedServers {
found := false
for _, upstream := range handler.upstreamServers {
for _, upstream := range flat {
if upstream.Addr() == expected {
found = true
break

View File

@@ -8,6 +8,7 @@ import (
"fmt"
"net"
"net/netip"
"slices"
"time"
"github.com/godbus/dbus/v5"
@@ -40,10 +41,17 @@ const (
)
type systemdDbusConfigurator struct {
dbusLinkObject dbus.ObjectPath
ifaceName string
dbusLinkObject dbus.ObjectPath
ifaceName string
wgIndex int
origNameservers []netip.Addr
}
const (
systemdDbusLinkDNSProperty = systemdDbusLinkInterface + ".DNS"
systemdDbusLinkDefaultRouteProperty = systemdDbusLinkInterface + ".DefaultRoute"
)
// the types below are based on dbus specification, each field is mapped to a dbus type
// see https://dbus.freedesktop.org/doc/dbus-specification.html#basic-types for more details on dbus types
// see https://www.freedesktop.org/software/systemd/man/org.freedesktop.resolve1.html on resolve1 input types
@@ -79,10 +87,145 @@ func newSystemdDbusConfigurator(wgInterface string) (*systemdDbusConfigurator, e
log.Debugf("got dbus Link interface: %s from net interface %s and index %d", s, iface.Name, iface.Index)
return &systemdDbusConfigurator{
c := &systemdDbusConfigurator{
dbusLinkObject: dbus.ObjectPath(s),
ifaceName: wgInterface,
}, nil
wgIndex: iface.Index,
}
origNameservers, err := c.captureOriginalNameservers()
switch {
case err != nil:
log.Warnf("capture original nameservers from systemd-resolved: %v", err)
case len(origNameservers) == 0:
log.Warnf("no original nameservers captured from systemd-resolved default-route links; DNS fallback will be empty")
default:
log.Debugf("captured %d original nameservers from systemd-resolved default-route links: %v", len(origNameservers), origNameservers)
}
c.origNameservers = origNameservers
return c, nil
}
// captureOriginalNameservers reads per-link DNS from systemd-resolved for
// every default-route link except our own WG link. Non-default-route links
// (VPNs, docker bridges) are skipped because their upstreams wouldn't
// actually serve host queries.
func (s *systemdDbusConfigurator) captureOriginalNameservers() ([]netip.Addr, error) {
ifaces, err := net.Interfaces()
if err != nil {
return nil, fmt.Errorf("list interfaces: %w", err)
}
seen := make(map[netip.Addr]struct{})
var out []netip.Addr
for _, iface := range ifaces {
if !s.isCandidateLink(iface) {
continue
}
linkPath, err := getSystemdLinkPath(iface.Index)
if err != nil || !isSystemdLinkDefaultRoute(linkPath) {
continue
}
for _, addr := range readSystemdLinkDNS(linkPath) {
addr = normalizeSystemdAddr(addr, iface.Name)
if !addr.IsValid() {
continue
}
if _, dup := seen[addr]; dup {
continue
}
seen[addr] = struct{}{}
out = append(out, addr)
}
}
return out, nil
}
func (s *systemdDbusConfigurator) isCandidateLink(iface net.Interface) bool {
if iface.Index == s.wgIndex {
return false
}
if iface.Flags&net.FlagLoopback != 0 || iface.Flags&net.FlagUp == 0 {
return false
}
return true
}
// normalizeSystemdAddr unmaps v4-mapped-v6, drops unspecified, and reattaches
// the link's iface name as zone for link-local v6 (Link.DNS strips it).
// Returns the zero Addr to signal "skip this entry".
func normalizeSystemdAddr(addr netip.Addr, ifaceName string) netip.Addr {
addr = addr.Unmap()
if !addr.IsValid() || addr.IsUnspecified() {
return netip.Addr{}
}
if addr.IsLinkLocalUnicast() {
return addr.WithZone(ifaceName)
}
return addr
}
func getSystemdLinkPath(ifIndex int) (dbus.ObjectPath, error) {
obj, closeConn, err := getDbusObject(systemdResolvedDest, systemdDbusObjectNode)
if err != nil {
return "", fmt.Errorf("dbus resolve1: %w", err)
}
defer closeConn()
var p string
if err := obj.Call(systemdDbusGetLinkMethod, dbusDefaultFlag, int32(ifIndex)).Store(&p); err != nil {
return "", err
}
return dbus.ObjectPath(p), nil
}
func isSystemdLinkDefaultRoute(linkPath dbus.ObjectPath) bool {
obj, closeConn, err := getDbusObject(systemdResolvedDest, linkPath)
if err != nil {
return false
}
defer closeConn()
v, err := obj.GetProperty(systemdDbusLinkDefaultRouteProperty)
if err != nil {
return false
}
b, ok := v.Value().(bool)
return ok && b
}
func readSystemdLinkDNS(linkPath dbus.ObjectPath) []netip.Addr {
obj, closeConn, err := getDbusObject(systemdResolvedDest, linkPath)
if err != nil {
return nil
}
defer closeConn()
v, err := obj.GetProperty(systemdDbusLinkDNSProperty)
if err != nil {
return nil
}
entries, ok := v.Value().([][]any)
if !ok {
return nil
}
var out []netip.Addr
for _, entry := range entries {
if len(entry) < 2 {
continue
}
raw, ok := entry[1].([]byte)
if !ok {
continue
}
addr, ok := netip.AddrFromSlice(raw)
if !ok {
continue
}
out = append(out, addr)
}
return out
}
func (s *systemdDbusConfigurator) getOriginalNameservers() []netip.Addr {
return slices.Clone(s.origNameservers)
}
func (s *systemdDbusConfigurator) supportCustomPort() bool {

View File

@@ -1,3 +1,32 @@
// Package dns implements the client-side DNS stack: listener/service on the
// peer's tunnel address, handler chain that routes questions by domain and
// priority, and upstream resolvers that forward what remains to configured
// nameservers.
//
// # Upstream resolution and the race model
//
// When two or more nameserver groups target the same domain, DefaultServer
// merges them into one upstream handler whose state is:
//
// upstreamResolverBase
// └── upstreamServers []upstreamRace // one entry per source NS group
// └── []netip.AddrPort // primary, fallback, ...
//
// Each source nameserver group contributes one upstreamRace. Within a race
// upstreams are tried in order: the next is used only on failure (timeout,
// SERVFAIL, REFUSED, no response). NXDOMAIN is a valid answer and stops
// the walk. When more than one race exists, ServeDNS fans out one
// goroutine per race and returns the first valid answer, cancelling the
// rest. A handler with a single race skips the fan-out.
//
// # Health projection
//
// Query outcomes are recorded per-upstream in UpstreamHealth. The server
// periodically merges these snapshots across handlers and projects them
// into peer.NSGroupState. There is no active probing: a group is marked
// unhealthy only when every seen upstream has a recent failure and none
// has a recent success. Healthy→unhealthy fires a single
// SystemEvent_WARNING; steady-state refreshes do not duplicate it.
package dns
import (
@@ -11,11 +40,8 @@ import (
"slices"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/cenkalti/backoff/v4"
"github.com/hashicorp/go-multierror"
"github.com/miekg/dns"
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/tun/netstack"
@@ -25,11 +51,33 @@ import (
"github.com/netbirdio/netbird/client/internal/dns/resutil"
"github.com/netbirdio/netbird/client/internal/dns/types"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/proto"
"github.com/netbirdio/netbird/route"
"github.com/netbirdio/netbird/shared/management/domain"
)
var currentMTU uint16 = iface.DefaultMTU
// nonRetryableEDECodes lists EDE info codes (RFC 8914) for which a SERVFAIL
// from one upstream means another upstream would return the same answer:
// DNSSEC validation outcomes and policy-based blocks. Transient errors
// (network, cached, not ready) are not included.
var nonRetryableEDECodes = map[uint16]struct{}{
dns.ExtendedErrorCodeUnsupportedDNSKEYAlgorithm: {},
dns.ExtendedErrorCodeUnsupportedDSDigestType: {},
dns.ExtendedErrorCodeDNSSECIndeterminate: {},
dns.ExtendedErrorCodeDNSBogus: {},
dns.ExtendedErrorCodeSignatureExpired: {},
dns.ExtendedErrorCodeSignatureNotYetValid: {},
dns.ExtendedErrorCodeDNSKEYMissing: {},
dns.ExtendedErrorCodeRRSIGsMissing: {},
dns.ExtendedErrorCodeNoZoneKeyBitSet: {},
dns.ExtendedErrorCodeNSECMissing: {},
dns.ExtendedErrorCodeBlocked: {},
dns.ExtendedErrorCodeCensored: {},
dns.ExtendedErrorCodeFiltered: {},
dns.ExtendedErrorCodeProhibited: {},
}
// privateClientIface is the subset of the WireGuard interface needed by GetClientPrivate.
type privateClientIface interface {
Name() string
@@ -46,15 +94,17 @@ const (
// Set longer than UpstreamTimeout to ensure context timeout takes precedence
ClientTimeout = 5 * time.Second
reactivatePeriod = 30 * time.Second
probeTimeout = 2 * time.Second
// ipv6HeaderSize + udpHeaderSize, used to derive the maximum DNS UDP
// payload from the tunnel MTU.
ipUDPHeaderSize = 60 + 8
)
const testRecord = "com."
// raceMaxTotalTimeout caps the combined time spent walking all upstreams
// within one race, so a slow primary can't eat the whole race budget.
raceMaxTotalTimeout = 5 * time.Second
// raceMinPerUpstreamTimeout is the floor applied when dividing
// raceMaxTotalTimeout across upstreams within a race.
raceMinPerUpstreamTimeout = 2 * time.Second
)
const (
protoUDP = "udp"
@@ -63,6 +113,69 @@ const (
type dnsProtocolKey struct{}
type upstreamProtocolKey struct{}
// upstreamProtocolResult holds the protocol used for the upstream exchange.
// Stored as a pointer in context so the exchange function can set it.
type upstreamProtocolResult struct {
protocol string
}
type upstreamClient interface {
exchange(ctx context.Context, upstream string, r *dns.Msg) (*dns.Msg, time.Duration, error)
}
type UpstreamResolver interface {
serveDNS(r *dns.Msg) (*dns.Msg, time.Duration, error)
upstreamExchange(upstream string, r *dns.Msg) (*dns.Msg, time.Duration, error)
}
// upstreamRace is an ordered list of upstreams derived from one configured
// nameserver group. Order matters: the first upstream is tried first, the
// second only on failure, and so on. Multiple upstreamRace values coexist
// inside one resolver when overlapping nameserver groups target the same
// domain; those races run in parallel and the first valid answer wins.
type upstreamRace []netip.AddrPort
// UpstreamHealth is the last query-path outcome for a single upstream,
// consumed by nameserver-group status projection.
type UpstreamHealth struct {
LastOk time.Time
LastFail time.Time
LastErr string
}
type upstreamResolverBase struct {
ctx context.Context
cancel context.CancelFunc
upstreamClient upstreamClient
upstreamServers []upstreamRace
domain domain.Domain
upstreamTimeout time.Duration
healthMu sync.RWMutex
health map[netip.AddrPort]*UpstreamHealth
statusRecorder *peer.Status
// selectedRoutes returns the current set of client routes the admin
// has enabled. Called lazily from the query hot path when an upstream
// might need a tunnel-bound client (iOS) and from health projection.
selectedRoutes func() route.HAMap
}
type upstreamFailure struct {
upstream netip.AddrPort
reason string
}
type raceResult struct {
msg *dns.Msg
upstream netip.AddrPort
protocol string
ede string
failures []upstreamFailure
}
// contextWithDNSProtocol stores the inbound DNS protocol ("udp" or "tcp") in context.
func contextWithDNSProtocol(ctx context.Context, network string) context.Context {
return context.WithValue(ctx, dnsProtocolKey{}, network)
@@ -79,16 +192,8 @@ func dnsProtocolFromContext(ctx context.Context) string {
return ""
}
type upstreamProtocolKey struct{}
// upstreamProtocolResult holds the protocol used for the upstream exchange.
// Stored as a pointer in context so the exchange function can set it.
type upstreamProtocolResult struct {
protocol string
}
// contextWithupstreamProtocolResult stores a mutable result holder in the context.
func contextWithupstreamProtocolResult(ctx context.Context) (context.Context, *upstreamProtocolResult) {
// contextWithUpstreamProtocolResult stores a mutable result holder in the context.
func contextWithUpstreamProtocolResult(ctx context.Context) (context.Context, *upstreamProtocolResult) {
r := &upstreamProtocolResult{}
return context.WithValue(ctx, upstreamProtocolKey{}, r), r
}
@@ -103,67 +208,37 @@ func setUpstreamProtocol(ctx context.Context, protocol string) {
}
}
type upstreamClient interface {
exchange(ctx context.Context, upstream string, r *dns.Msg) (*dns.Msg, time.Duration, error)
}
type UpstreamResolver interface {
serveDNS(r *dns.Msg) (*dns.Msg, time.Duration, error)
upstreamExchange(upstream string, r *dns.Msg) (*dns.Msg, time.Duration, error)
}
type upstreamResolverBase struct {
ctx context.Context
cancel context.CancelFunc
upstreamClient upstreamClient
upstreamServers []netip.AddrPort
domain string
disabled bool
successCount atomic.Int32
mutex sync.Mutex
reactivatePeriod time.Duration
upstreamTimeout time.Duration
wg sync.WaitGroup
deactivate func(error)
reactivate func()
statusRecorder *peer.Status
routeMatch func(netip.Addr) bool
}
type upstreamFailure struct {
upstream netip.AddrPort
reason string
}
func newUpstreamResolverBase(ctx context.Context, statusRecorder *peer.Status, domain string) *upstreamResolverBase {
func newUpstreamResolverBase(ctx context.Context, statusRecorder *peer.Status, d domain.Domain) *upstreamResolverBase {
ctx, cancel := context.WithCancel(ctx)
return &upstreamResolverBase{
ctx: ctx,
cancel: cancel,
domain: domain,
upstreamTimeout: UpstreamTimeout,
reactivatePeriod: reactivatePeriod,
statusRecorder: statusRecorder,
ctx: ctx,
cancel: cancel,
domain: d,
upstreamTimeout: UpstreamTimeout,
statusRecorder: statusRecorder,
}
}
// String returns a string representation of the upstream resolver
func (u *upstreamResolverBase) String() string {
return fmt.Sprintf("Upstream %s", u.upstreamServers)
return fmt.Sprintf("Upstream %s", u.flatUpstreams())
}
// ID returns the unique handler ID
// ID returns the unique handler ID. Race groupings and within-race
// ordering are both part of the identity: [[A,B]] and [[A],[B]] query
// the same servers but with different semantics (serial fallback vs
// parallel race), so their handlers must not collide.
func (u *upstreamResolverBase) ID() types.HandlerID {
servers := slices.Clone(u.upstreamServers)
slices.SortFunc(servers, func(a, b netip.AddrPort) int { return a.Compare(b) })
hash := sha256.New()
hash.Write([]byte(u.domain + ":"))
for _, s := range servers {
hash.Write([]byte(s.String()))
hash.Write([]byte("|"))
hash.Write([]byte(u.domain.PunycodeString() + ":"))
for _, race := range u.upstreamServers {
hash.Write([]byte("["))
for _, s := range race {
hash.Write([]byte(s.String()))
hash.Write([]byte("|"))
}
hash.Write([]byte("]"))
}
return types.HandlerID("upstream-" + hex.EncodeToString(hash.Sum(nil)[:8]))
}
@@ -173,13 +248,31 @@ func (u *upstreamResolverBase) MatchSubdomains() bool {
}
func (u *upstreamResolverBase) Stop() {
log.Debugf("stopping serving DNS for upstreams %s", u.upstreamServers)
log.Debugf("stopping serving DNS for upstreams %s", u.flatUpstreams())
u.cancel()
}
u.mutex.Lock()
u.wg.Wait()
u.mutex.Unlock()
// flatUpstreams is for logging and ID hashing only, not for dispatch.
func (u *upstreamResolverBase) flatUpstreams() []netip.AddrPort {
var out []netip.AddrPort
for _, g := range u.upstreamServers {
out = append(out, g...)
}
return out
}
// setSelectedRoutes swaps the accessor used to classify overlay-routed
// upstreams. Called when route sources are wired after the handler was
// built (permanent / iOS constructors).
func (u *upstreamResolverBase) setSelectedRoutes(selected func() route.HAMap) {
u.selectedRoutes = selected
}
func (u *upstreamResolverBase) addRace(servers []netip.AddrPort) {
if len(servers) == 0 {
return
}
u.upstreamServers = append(u.upstreamServers, slices.Clone(servers))
}
// ServeDNS handles a DNS request
@@ -221,59 +314,226 @@ func (u *upstreamResolverBase) prepareRequest(r *dns.Msg) {
}
func (u *upstreamResolverBase) tryUpstreamServers(ctx context.Context, w dns.ResponseWriter, r *dns.Msg, logger *log.Entry) (bool, []upstreamFailure) {
timeout := u.upstreamTimeout
if len(u.upstreamServers) > 1 {
maxTotal := 5 * time.Second
minPerUpstream := 2 * time.Second
scaledTimeout := maxTotal / time.Duration(len(u.upstreamServers))
if scaledTimeout > minPerUpstream {
timeout = scaledTimeout
} else {
timeout = minPerUpstream
}
groups := u.upstreamServers
switch len(groups) {
case 0:
return false, nil
case 1:
return u.tryOnlyRace(ctx, w, r, groups[0], logger)
default:
return u.raceAll(ctx, w, r, groups, logger)
}
}
func (u *upstreamResolverBase) tryOnlyRace(ctx context.Context, w dns.ResponseWriter, r *dns.Msg, group upstreamRace, logger *log.Entry) (bool, []upstreamFailure) {
res := u.tryRace(ctx, r, group)
if res.msg == nil {
return false, res.failures
}
if res.ede != "" {
resutil.SetMeta(w, "ede", res.ede)
}
u.writeSuccessResponse(w, res.msg, res.upstream, r.Question[0].Name, res.protocol, logger)
return true, res.failures
}
// raceAll runs one worker per group in parallel, taking the first valid
// answer and cancelling the rest.
func (u *upstreamResolverBase) raceAll(ctx context.Context, w dns.ResponseWriter, r *dns.Msg, groups []upstreamRace, logger *log.Entry) (bool, []upstreamFailure) {
raceCtx, cancel := context.WithCancel(ctx)
defer cancel()
// Buffer sized to len(groups) so workers never block on send, even
// after the coordinator has returned.
results := make(chan raceResult, len(groups))
for _, g := range groups {
// tryRace clones the request per attempt, so workers never share
// a *dns.Msg and concurrent EDNS0 mutations can't race.
go func(g upstreamRace) {
results <- u.tryRace(raceCtx, r, g)
}(g)
}
var failures []upstreamFailure
for _, upstream := range u.upstreamServers {
if failure := u.queryUpstream(ctx, w, r, upstream, timeout, logger); failure != nil {
failures = append(failures, *failure)
} else {
return true, failures
for range groups {
select {
case res := <-results:
failures = append(failures, res.failures...)
if res.msg != nil {
if res.ede != "" {
resutil.SetMeta(w, "ede", res.ede)
}
u.writeSuccessResponse(w, res.msg, res.upstream, r.Question[0].Name, res.protocol, logger)
return true, failures
}
case <-ctx.Done():
return false, failures
}
}
return false, failures
}
// queryUpstream queries a single upstream server. Returns nil on success, or failure info to try next upstream.
func (u *upstreamResolverBase) queryUpstream(parentCtx context.Context, w dns.ResponseWriter, r *dns.Msg, upstream netip.AddrPort, timeout time.Duration, logger *log.Entry) *upstreamFailure {
var rm *dns.Msg
var t time.Duration
var err error
var startTime time.Time
var upstreamProto *upstreamProtocolResult
func() {
ctx, cancel := context.WithTimeout(parentCtx, timeout)
func (u *upstreamResolverBase) tryRace(ctx context.Context, r *dns.Msg, group upstreamRace) raceResult {
timeout := u.upstreamTimeout
if len(group) > 1 {
// Cap the whole walk at raceMaxTotalTimeout: per-upstream timeouts
// still honor raceMinPerUpstreamTimeout as a floor for correctness
// on slow links, but the outer context ensures the combined walk
// cannot exceed the cap regardless of group size.
timeout = max(raceMaxTotalTimeout/time.Duration(len(group)), raceMinPerUpstreamTimeout)
var cancel context.CancelFunc
ctx, cancel = context.WithTimeout(ctx, raceMaxTotalTimeout)
defer cancel()
ctx, upstreamProto = contextWithupstreamProtocolResult(ctx)
startTime = time.Now()
rm, t, err = u.upstreamClient.exchange(ctx, upstream.String(), r)
}()
}
var failures []upstreamFailure
for _, upstream := range group {
if ctx.Err() != nil {
return raceResult{failures: failures}
}
// Clone the request per attempt: the exchange path mutates EDNS0
// options in-place, so reusing the same *dns.Msg across sequential
// upstreams would carry those mutations (e.g. a reduced UDP size)
// into the next attempt.
res, failure := u.queryUpstream(ctx, r.Copy(), upstream, timeout)
if failure != nil {
failures = append(failures, *failure)
continue
}
res.failures = failures
return res
}
return raceResult{failures: failures}
}
func (u *upstreamResolverBase) queryUpstream(parentCtx context.Context, r *dns.Msg, upstream netip.AddrPort, timeout time.Duration) (raceResult, *upstreamFailure) {
ctx, cancel := context.WithTimeout(parentCtx, timeout)
defer cancel()
ctx, upstreamProto := contextWithUpstreamProtocolResult(ctx)
// Advertise EDNS0 so the upstream may include Extended DNS Errors
// (RFC 8914) in failure responses; we use those to short-circuit
// failover for definitive answers like DNSSEC validation failures.
// The caller already passed a per-attempt copy, so we can mutate r
// directly; hadEdns reflects the original client request's state and
// controls whether we strip the OPT from the response.
hadEdns := r.IsEdns0() != nil
if !hadEdns {
r.SetEdns0(upstreamUDPSize(), false)
}
startTime := time.Now()
rm, _, err := u.upstreamClient.exchange(ctx, upstream.String(), r)
if err != nil {
return u.handleUpstreamError(err, upstream, startTime)
// A parent cancellation (e.g., another race won and the coordinator
// cancelled the losers) is not an upstream failure. Check both the
// error chain and the parent context: a transport may surface the
// cancellation as a read/deadline error rather than context.Canceled.
if errors.Is(err, context.Canceled) || errors.Is(parentCtx.Err(), context.Canceled) {
return raceResult{}, &upstreamFailure{upstream: upstream, reason: "canceled"}
}
failure := u.handleUpstreamError(err, upstream, startTime)
u.markUpstreamFail(upstream, failure.reason)
return raceResult{}, failure
}
if rm == nil || !rm.Response {
return &upstreamFailure{upstream: upstream, reason: "no response"}
u.markUpstreamFail(upstream, "no response")
return raceResult{}, &upstreamFailure{upstream: upstream, reason: "no response"}
}
proto := ""
if upstreamProto != nil {
proto = upstreamProto.protocol
}
if rm.Rcode == dns.RcodeServerFailure || rm.Rcode == dns.RcodeRefused {
return &upstreamFailure{upstream: upstream, reason: dns.RcodeToString[rm.Rcode]}
if code, ok := nonRetryableEDE(rm); ok {
if !hadEdns {
stripOPT(rm)
}
u.markUpstreamOk(upstream)
return raceResult{msg: rm, upstream: upstream, protocol: proto, ede: edeName(code)}, nil
}
reason := dns.RcodeToString[rm.Rcode]
u.markUpstreamFail(upstream, reason)
return raceResult{}, &upstreamFailure{upstream: upstream, reason: reason}
}
u.writeSuccessResponse(w, rm, upstream, r.Question[0].Name, t, upstreamProto, logger)
return nil
if !hadEdns {
stripOPT(rm)
}
u.markUpstreamOk(upstream)
return raceResult{msg: rm, upstream: upstream, protocol: proto}, nil
}
// healthEntry returns the mutable health record for addr, lazily creating
// the map and the entry. Caller must hold u.healthMu.
func (u *upstreamResolverBase) healthEntry(addr netip.AddrPort) *UpstreamHealth {
if u.health == nil {
u.health = make(map[netip.AddrPort]*UpstreamHealth)
}
h := u.health[addr]
if h == nil {
h = &UpstreamHealth{}
u.health[addr] = h
}
return h
}
func (u *upstreamResolverBase) markUpstreamOk(addr netip.AddrPort) {
u.healthMu.Lock()
defer u.healthMu.Unlock()
h := u.healthEntry(addr)
h.LastOk = time.Now()
h.LastFail = time.Time{}
h.LastErr = ""
}
func (u *upstreamResolverBase) markUpstreamFail(addr netip.AddrPort, reason string) {
u.healthMu.Lock()
defer u.healthMu.Unlock()
h := u.healthEntry(addr)
h.LastFail = time.Now()
h.LastErr = reason
}
// UpstreamHealth returns a snapshot of per-upstream query outcomes.
func (u *upstreamResolverBase) UpstreamHealth() map[netip.AddrPort]UpstreamHealth {
u.healthMu.RLock()
defer u.healthMu.RUnlock()
out := make(map[netip.AddrPort]UpstreamHealth, len(u.health))
for k, v := range u.health {
out[k] = *v
}
return out
}
// upstreamUDPSize returns the EDNS0 UDP buffer size we advertise to upstreams,
// derived from the tunnel MTU and bounded against underflow.
func upstreamUDPSize() uint16 {
if currentMTU > ipUDPHeaderSize {
return currentMTU - ipUDPHeaderSize
}
return dns.MinMsgSize
}
// stripOPT removes any OPT pseudo-RRs from the response's Extra section so
// the response complies with RFC 6891 when the client did not advertise EDNS0.
func stripOPT(rm *dns.Msg) {
if len(rm.Extra) == 0 {
return
}
out := rm.Extra[:0]
for _, rr := range rm.Extra {
if _, ok := rr.(*dns.OPT); ok {
continue
}
out = append(out, rr)
}
rm.Extra = out
}
func (u *upstreamResolverBase) handleUpstreamError(err error, upstream netip.AddrPort, startTime time.Time) *upstreamFailure {
@@ -289,12 +549,23 @@ func (u *upstreamResolverBase) handleUpstreamError(err error, upstream netip.Add
return &upstreamFailure{upstream: upstream, reason: reason}
}
func (u *upstreamResolverBase) writeSuccessResponse(w dns.ResponseWriter, rm *dns.Msg, upstream netip.AddrPort, domain string, t time.Duration, upstreamProto *upstreamProtocolResult, logger *log.Entry) bool {
u.successCount.Add(1)
func (u *upstreamResolverBase) debugUpstreamTimeout(upstream netip.AddrPort) string {
if u.statusRecorder == nil {
return ""
}
peerInfo := findPeerForIP(upstream.Addr(), u.statusRecorder)
if peerInfo == nil {
return ""
}
return fmt.Sprintf("(routes through NetBird peer %s)", FormatPeerStatus(peerInfo))
}
func (u *upstreamResolverBase) writeSuccessResponse(w dns.ResponseWriter, rm *dns.Msg, upstream netip.AddrPort, domain string, proto string, logger *log.Entry) {
resutil.SetMeta(w, "upstream", upstream.String())
if upstreamProto != nil && upstreamProto.protocol != "" {
resutil.SetMeta(w, "upstream_protocol", upstreamProto.protocol)
if proto != "" {
resutil.SetMeta(w, "upstream_protocol", proto)
}
// Clear Zero bit from external responses to prevent upstream servers from
@@ -303,14 +574,11 @@ func (u *upstreamResolverBase) writeSuccessResponse(w dns.ResponseWriter, rm *dn
if err := w.WriteMsg(rm); err != nil {
logger.Errorf("failed to write DNS response for question domain=%s: %s", domain, err)
return true
}
return true
}
func (u *upstreamResolverBase) logUpstreamFailures(domain string, failures []upstreamFailure, succeeded bool, logger *log.Entry) {
totalUpstreams := len(u.upstreamServers)
totalUpstreams := len(u.flatUpstreams())
failedCount := len(failures)
failureSummary := formatFailures(failures)
@@ -337,117 +605,32 @@ func formatFailures(failures []upstreamFailure) string {
return strings.Join(parts, ", ")
}
// ProbeAvailability tests all upstream servers simultaneously and
// disables the resolver if none work
func (u *upstreamResolverBase) ProbeAvailability(ctx context.Context) {
u.mutex.Lock()
defer u.mutex.Unlock()
// avoid probe if upstreams could resolve at least one query
if u.successCount.Load() > 0 {
return
// nonRetryableEDE returns the first non-retryable EDE code carried in the
// response, if any.
func nonRetryableEDE(rm *dns.Msg) (uint16, bool) {
opt := rm.IsEdns0()
if opt == nil {
return 0, false
}
var success bool
var mu sync.Mutex
var wg sync.WaitGroup
var errs *multierror.Error
for _, upstream := range u.upstreamServers {
wg.Add(1)
go func(upstream netip.AddrPort) {
defer wg.Done()
err := u.testNameserver(u.ctx, ctx, upstream, 500*time.Millisecond)
if err != nil {
mu.Lock()
errs = multierror.Append(errs, err)
mu.Unlock()
log.Warnf("probing upstream nameserver %s: %s", upstream, err)
return
}
mu.Lock()
success = true
mu.Unlock()
}(upstream)
}
wg.Wait()
select {
case <-ctx.Done():
return
case <-u.ctx.Done():
return
default:
}
// didn't find a working upstream server, let's disable and try later
if !success {
u.disable(errs.ErrorOrNil())
if u.statusRecorder == nil {
return
for _, o := range opt.Option {
ede, ok := o.(*dns.EDNS0_EDE)
if !ok {
continue
}
if _, ok := nonRetryableEDECodes[ede.InfoCode]; ok {
return ede.InfoCode, true
}
u.statusRecorder.PublishEvent(
proto.SystemEvent_WARNING,
proto.SystemEvent_DNS,
"All upstream servers failed (probe failed)",
"Unable to reach one or more DNS servers. This might affect your ability to connect to some services.",
map[string]string{"upstreams": u.upstreamServersString()},
)
}
return 0, false
}
// waitUntilResponse retries, in an exponential interval, querying the upstream servers until it gets a positive response
func (u *upstreamResolverBase) waitUntilResponse() {
exponentialBackOff := &backoff.ExponentialBackOff{
InitialInterval: 500 * time.Millisecond,
RandomizationFactor: 0.5,
Multiplier: 1.1,
MaxInterval: u.reactivatePeriod,
MaxElapsedTime: 0,
Stop: backoff.Stop,
Clock: backoff.SystemClock,
// edeName returns a human-readable name for an EDE code, falling back to
// the numeric code when unknown.
func edeName(code uint16) string {
if name, ok := dns.ExtendedErrorCodeToString[code]; ok {
return name
}
operation := func() error {
select {
case <-u.ctx.Done():
return backoff.Permanent(fmt.Errorf("exiting upstream retry loop for upstreams %s: parent context has been canceled", u.upstreamServersString()))
default:
}
for _, upstream := range u.upstreamServers {
if err := u.testNameserver(u.ctx, nil, upstream, probeTimeout); err != nil {
log.Tracef("upstream check for %s: %s", upstream, err)
} else {
// at least one upstream server is available, stop probing
return nil
}
}
log.Tracef("checking connectivity with upstreams %s failed. Retrying in %s", u.upstreamServersString(), exponentialBackOff.NextBackOff())
return fmt.Errorf("upstream check call error")
}
err := backoff.Retry(operation, backoff.WithContext(exponentialBackOff, u.ctx))
if err != nil {
if errors.Is(err, context.Canceled) {
log.Debugf("upstream retry loop exited for upstreams %s", u.upstreamServersString())
} else {
log.Warnf("upstream retry loop exited for upstreams %s: %v", u.upstreamServersString(), err)
}
return
}
log.Infof("upstreams %s are responsive again. Adding them back to system", u.upstreamServersString())
u.successCount.Add(1)
u.reactivate()
u.mutex.Lock()
u.disabled = false
u.mutex.Unlock()
return fmt.Sprintf("EDE %d", code)
}
// isTimeout returns true if the given error is a network timeout error.
@@ -461,45 +644,6 @@ func isTimeout(err error) bool {
return false
}
func (u *upstreamResolverBase) disable(err error) {
if u.disabled {
return
}
log.Warnf("Upstream resolving is Disabled for %v", reactivatePeriod)
u.successCount.Store(0)
u.deactivate(err)
u.disabled = true
u.wg.Add(1)
go func() {
defer u.wg.Done()
u.waitUntilResponse()
}()
}
func (u *upstreamResolverBase) upstreamServersString() string {
var servers []string
for _, server := range u.upstreamServers {
servers = append(servers, server.String())
}
return strings.Join(servers, ", ")
}
func (u *upstreamResolverBase) testNameserver(baseCtx context.Context, externalCtx context.Context, server netip.AddrPort, timeout time.Duration) error {
mergedCtx, cancel := context.WithTimeout(baseCtx, timeout)
defer cancel()
if externalCtx != nil {
stop2 := context.AfterFunc(externalCtx, cancel)
defer stop2()
}
r := new(dns.Msg).SetQuestion(testRecord, dns.TypeSOA)
_, _, err := u.upstreamClient.exchange(mergedCtx, server.String(), r)
return err
}
// clientUDPMaxSize returns the maximum UDP response size the client accepts.
func clientUDPMaxSize(r *dns.Msg) int {
if opt := r.IsEdns0(); opt != nil {
@@ -511,13 +655,10 @@ func clientUDPMaxSize(r *dns.Msg) int {
// ExchangeWithFallback exchanges a DNS message with the upstream server.
// It first tries to use UDP, and if it is truncated, it falls back to TCP.
// If the inbound request came over TCP (via context), it skips the UDP attempt.
// If the passed context is nil, this will use Exchange instead of ExchangeContext.
func ExchangeWithFallback(ctx context.Context, client *dns.Client, r *dns.Msg, upstream string) (*dns.Msg, time.Duration, error) {
// If the request came in over TCP, go straight to TCP upstream.
if dnsProtocolFromContext(ctx) == protoTCP {
tcpClient := *client
tcpClient.Net = protoTCP
rm, t, err := tcpClient.ExchangeContext(ctx, r, upstream)
rm, t, err := toTCPClient(client).ExchangeContext(ctx, r, upstream)
if err != nil {
return nil, t, fmt.Errorf("with tcp: %w", err)
}
@@ -537,18 +678,7 @@ func ExchangeWithFallback(ctx context.Context, client *dns.Client, r *dns.Msg, u
opt.SetUDPSize(maxUDPPayload)
}
var (
rm *dns.Msg
t time.Duration
err error
)
if ctx == nil {
rm, t, err = client.Exchange(r, upstream)
} else {
rm, t, err = client.ExchangeContext(ctx, r, upstream)
}
rm, t, err := client.ExchangeContext(ctx, r, upstream)
if err != nil {
return nil, t, fmt.Errorf("with udp: %w", err)
}
@@ -562,15 +692,7 @@ func ExchangeWithFallback(ctx context.Context, client *dns.Client, r *dns.Msg, u
// data than the client's buffer, we could truncate locally and skip
// the TCP retry.
tcpClient := *client
tcpClient.Net = protoTCP
if ctx == nil {
rm, t, err = tcpClient.Exchange(r, upstream)
} else {
rm, t, err = tcpClient.ExchangeContext(ctx, r, upstream)
}
rm, t, err = toTCPClient(client).ExchangeContext(ctx, r, upstream)
if err != nil {
return nil, t, fmt.Errorf("with tcp: %w", err)
}
@@ -584,6 +706,25 @@ func ExchangeWithFallback(ctx context.Context, client *dns.Client, r *dns.Msg, u
return rm, t, nil
}
// toTCPClient returns a copy of c configured for TCP. If c's Dialer has a
// *net.UDPAddr bound as LocalAddr (iOS does this to keep the source IP on
// the tunnel interface), it is converted to the equivalent *net.TCPAddr
// so net.Dialer doesn't reject the TCP dial with "mismatched local
// address type".
func toTCPClient(c *dns.Client) *dns.Client {
tcp := *c
tcp.Net = protoTCP
if tcp.Dialer == nil {
return &tcp
}
d := *tcp.Dialer
if ua, ok := d.LocalAddr.(*net.UDPAddr); ok {
d.LocalAddr = &net.TCPAddr{IP: ua.IP, Port: ua.Port, Zone: ua.Zone}
}
tcp.Dialer = &d
return &tcp
}
// ExchangeWithNetstack performs a DNS exchange using netstack for dialing.
// This is needed when netstack is enabled to reach peer IPs through the tunnel.
func ExchangeWithNetstack(ctx context.Context, nsNet *netstack.Net, r *dns.Msg, upstream string) (*dns.Msg, error) {
@@ -725,15 +866,36 @@ func findPeerForIP(ip netip.Addr, statusRecorder *peer.Status) *peer.State {
return bestMatch
}
func (u *upstreamResolverBase) debugUpstreamTimeout(upstream netip.AddrPort) string {
if u.statusRecorder == nil {
return ""
// haMapRouteCount returns the total number of routes across all HA
// groups in the map. route.HAMap is keyed by HAUniqueID with slices of
// routes per key, so len(hm) is the number of HA groups, not routes.
func haMapRouteCount(hm route.HAMap) int {
total := 0
for _, routes := range hm {
total += len(routes)
}
peerInfo := findPeerForIP(upstream.Addr(), u.statusRecorder)
if peerInfo == nil {
return ""
}
return fmt.Sprintf("(routes through NetBird peer %s)", FormatPeerStatus(peerInfo))
return total
}
// haMapContains checks whether ip is covered by any concrete prefix in
// the HA map. haveDynamic is reported separately: dynamic (domain-based)
// routes carry a placeholder Network that can't be prefix-checked, so we
// can't know at this point whether ip is reached through one. Callers
// decide how to interpret the unknown: health projection treats it as
// "possibly routed" to avoid emitting false-positive warnings during
// startup, while iOS dial selection requires a concrete match before
// binding to the tunnel.
func haMapContains(hm route.HAMap, ip netip.Addr) (matched, haveDynamic bool) {
for _, routes := range hm {
for _, r := range routes {
if r.IsDynamic() {
haveDynamic = true
continue
}
if r.Network.Contains(ip) {
return true, haveDynamic
}
}
}
return false, haveDynamic
}

View File

@@ -11,6 +11,7 @@ import (
"github.com/netbirdio/netbird/client/internal/peer"
nbnet "github.com/netbirdio/netbird/client/net"
"github.com/netbirdio/netbird/shared/management/domain"
)
type upstreamResolver struct {
@@ -26,9 +27,9 @@ func newUpstreamResolver(
_ WGIface,
statusRecorder *peer.Status,
hostsDNSHolder *hostsDNSHolder,
domain string,
d domain.Domain,
) (*upstreamResolver, error) {
upstreamResolverBase := newUpstreamResolverBase(ctx, statusRecorder, domain)
upstreamResolverBase := newUpstreamResolverBase(ctx, statusRecorder, d)
c := &upstreamResolver{
upstreamResolverBase: upstreamResolverBase,
hostsDNSHolder: hostsDNSHolder,

View File

@@ -12,6 +12,7 @@ import (
"golang.zx2c4.com/wireguard/tun/netstack"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/shared/management/domain"
)
type upstreamResolver struct {
@@ -24,9 +25,9 @@ func newUpstreamResolver(
wgIface WGIface,
statusRecorder *peer.Status,
_ *hostsDNSHolder,
domain string,
d domain.Domain,
) (*upstreamResolver, error) {
upstreamResolverBase := newUpstreamResolverBase(ctx, statusRecorder, domain)
upstreamResolverBase := newUpstreamResolverBase(ctx, statusRecorder, d)
nonIOS := &upstreamResolver{
upstreamResolverBase: upstreamResolverBase,
nsNet: wgIface.GetNet(),

View File

@@ -15,6 +15,7 @@ import (
"golang.org/x/sys/unix"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/shared/management/domain"
)
type upstreamResolverIOS struct {
@@ -27,9 +28,9 @@ func newUpstreamResolver(
wgIface WGIface,
statusRecorder *peer.Status,
_ *hostsDNSHolder,
domain string,
d domain.Domain,
) (*upstreamResolverIOS, error) {
upstreamResolverBase := newUpstreamResolverBase(ctx, statusRecorder, domain)
upstreamResolverBase := newUpstreamResolverBase(ctx, statusRecorder, d)
ios := &upstreamResolverIOS{
upstreamResolverBase: upstreamResolverBase,
@@ -62,9 +63,16 @@ func (u *upstreamResolverIOS) exchange(ctx context.Context, upstream string, r *
upstreamIP = upstreamIP.Unmap()
}
addr := u.wgIface.Address()
var routed bool
if u.selectedRoutes != nil {
// Only a concrete prefix match binds to the tunnel: dialing
// through a private client for an upstream we can't prove is
// routed would break public resolvers.
routed, _ = haMapContains(u.selectedRoutes(), upstreamIP)
}
needsPrivate := addr.Network.Contains(upstreamIP) ||
addr.IPv6Net.Contains(upstreamIP) ||
(u.routeMatch != nil && u.routeMatch(upstreamIP))
routed
if needsPrivate {
log.Debugf("using private client to query %s via upstream %s", r.Question[0].Name, upstream)
client, err = GetClientPrivate(u.wgIface, upstreamIP, timeout)
@@ -73,8 +81,7 @@ func (u *upstreamResolverIOS) exchange(ctx context.Context, upstream string, r *
}
}
// Cannot use client.ExchangeContext because it overwrites our Dialer
return ExchangeWithFallback(nil, client, r, upstream)
return ExchangeWithFallback(ctx, client, r, upstream)
}
// GetClientPrivate returns a new DNS client bound to the local IP of the Netbird interface.

View File

@@ -6,6 +6,7 @@ import (
"net"
"net/netip"
"strings"
"sync/atomic"
"testing"
"time"
@@ -73,7 +74,7 @@ func TestUpstreamResolver_ServeDNS(t *testing.T) {
servers = append(servers, netip.AddrPortFrom(addrPort.Addr().Unmap(), addrPort.Port()))
}
}
resolver.upstreamServers = servers
resolver.addRace(servers)
resolver.upstreamTimeout = testCase.timeout
if testCase.cancelCTX {
cancel()
@@ -132,20 +133,10 @@ func (m *mockNetstackProvider) GetInterfaceGUIDString() (string, error) {
return "", nil
}
type mockUpstreamResolver struct {
r *dns.Msg
rtt time.Duration
err error
}
// exchange mock implementation of exchange from upstreamResolver
func (c mockUpstreamResolver) exchange(_ context.Context, _ string, _ *dns.Msg) (*dns.Msg, time.Duration, error) {
return c.r, c.rtt, c.err
}
type mockUpstreamResponse struct {
msg *dns.Msg
err error
msg *dns.Msg
err error
delay time.Duration
}
type mockUpstreamResolverPerServer struct {
@@ -153,63 +144,19 @@ type mockUpstreamResolverPerServer struct {
rtt time.Duration
}
func (c mockUpstreamResolverPerServer) exchange(_ context.Context, upstream string, _ *dns.Msg) (*dns.Msg, time.Duration, error) {
if r, ok := c.responses[upstream]; ok {
return r.msg, c.rtt, r.err
func (c mockUpstreamResolverPerServer) exchange(ctx context.Context, upstream string, _ *dns.Msg) (*dns.Msg, time.Duration, error) {
r, ok := c.responses[upstream]
if !ok {
return nil, c.rtt, fmt.Errorf("no mock response for %s", upstream)
}
return nil, c.rtt, fmt.Errorf("no mock response for %s", upstream)
}
func TestUpstreamResolver_DeactivationReactivation(t *testing.T) {
mockClient := &mockUpstreamResolver{
err: dns.ErrTime,
r: new(dns.Msg),
rtt: time.Millisecond,
}
resolver := &upstreamResolverBase{
ctx: context.TODO(),
upstreamClient: mockClient,
upstreamTimeout: UpstreamTimeout,
reactivatePeriod: time.Microsecond * 100,
}
addrPort, _ := netip.ParseAddrPort("0.0.0.0:1") // Use valid port for parsing, test will still fail on connection
resolver.upstreamServers = []netip.AddrPort{netip.AddrPortFrom(addrPort.Addr().Unmap(), addrPort.Port())}
failed := false
resolver.deactivate = func(error) {
failed = true
// After deactivation, make the mock client work again
mockClient.err = nil
}
reactivated := false
resolver.reactivate = func() {
reactivated = true
}
resolver.ProbeAvailability(context.TODO())
if !failed {
t.Errorf("expected that resolving was deactivated")
return
}
if !resolver.disabled {
t.Errorf("resolver should be Disabled")
return
}
time.Sleep(time.Millisecond * 200)
if !reactivated {
t.Errorf("expected that resolving was reactivated")
return
}
if resolver.disabled {
t.Errorf("should be enabled")
if r.delay > 0 {
select {
case <-time.After(r.delay):
case <-ctx.Done():
return nil, c.rtt, ctx.Err()
}
}
return r.msg, c.rtt, r.err
}
func TestUpstreamResolver_Failover(t *testing.T) {
@@ -339,9 +286,9 @@ func TestUpstreamResolver_Failover(t *testing.T) {
resolver := &upstreamResolverBase{
ctx: ctx,
upstreamClient: trackingClient,
upstreamServers: []netip.AddrPort{upstream1, upstream2},
upstreamTimeout: UpstreamTimeout,
}
resolver.addRace([]netip.AddrPort{upstream1, upstream2})
var responseMSG *dns.Msg
responseWriter := &test.MockResponseWriter{
@@ -421,9 +368,9 @@ func TestUpstreamResolver_SingleUpstreamFailure(t *testing.T) {
resolver := &upstreamResolverBase{
ctx: ctx,
upstreamClient: mockClient,
upstreamServers: []netip.AddrPort{upstream},
upstreamTimeout: UpstreamTimeout,
}
resolver.addRace([]netip.AddrPort{upstream})
var responseMSG *dns.Msg
responseWriter := &test.MockResponseWriter{
@@ -440,6 +387,136 @@ func TestUpstreamResolver_SingleUpstreamFailure(t *testing.T) {
assert.Equal(t, dns.RcodeServerFailure, responseMSG.Rcode, "single upstream SERVFAIL should return SERVFAIL")
}
// TestUpstreamResolver_RaceAcrossGroups covers two nameserver groups
// configured for the same domain, with one broken group. The merge+race
// path should answer as fast as the working group and not pay the timeout
// of the broken one on every query.
func TestUpstreamResolver_RaceAcrossGroups(t *testing.T) {
broken := netip.MustParseAddrPort("192.0.2.1:53")
working := netip.MustParseAddrPort("192.0.2.2:53")
successAnswer := "192.0.2.100"
timeoutErr := &net.OpError{Op: "read", Err: fmt.Errorf("i/o timeout")}
mockClient := &mockUpstreamResolverPerServer{
responses: map[string]mockUpstreamResponse{
// Force the broken upstream to only unblock via timeout /
// cancellation so the assertion below can't pass if races
// were run serially.
broken.String(): {err: timeoutErr, delay: 500 * time.Millisecond},
working.String(): {msg: buildMockResponse(dns.RcodeSuccess, successAnswer)},
},
rtt: time.Millisecond,
}
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
resolver := &upstreamResolverBase{
ctx: ctx,
upstreamClient: mockClient,
upstreamTimeout: 250 * time.Millisecond,
}
resolver.addRace([]netip.AddrPort{broken})
resolver.addRace([]netip.AddrPort{working})
var responseMSG *dns.Msg
responseWriter := &test.MockResponseWriter{
WriteMsgFunc: func(m *dns.Msg) error {
responseMSG = m
return nil
},
}
inputMSG := new(dns.Msg).SetQuestion("example.com.", dns.TypeA)
start := time.Now()
resolver.ServeDNS(responseWriter, inputMSG)
elapsed := time.Since(start)
require.NotNil(t, responseMSG, "should write a response")
assert.Equal(t, dns.RcodeSuccess, responseMSG.Rcode)
require.NotEmpty(t, responseMSG.Answer)
assert.Contains(t, responseMSG.Answer[0].String(), successAnswer)
// Working group answers in a single RTT; the broken group's
// timeout (100ms) must not block the response.
assert.Less(t, elapsed, 100*time.Millisecond, "race must not wait for broken group's timeout")
}
// TestUpstreamResolver_AllGroupsFail checks that when every group fails the
// resolver returns SERVFAIL rather than leaking a partial response.
func TestUpstreamResolver_AllGroupsFail(t *testing.T) {
a := netip.MustParseAddrPort("192.0.2.1:53")
b := netip.MustParseAddrPort("192.0.2.2:53")
mockClient := &mockUpstreamResolverPerServer{
responses: map[string]mockUpstreamResponse{
a.String(): {msg: buildMockResponse(dns.RcodeServerFailure, "")},
b.String(): {msg: buildMockResponse(dns.RcodeServerFailure, "")},
},
rtt: time.Millisecond,
}
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
resolver := &upstreamResolverBase{
ctx: ctx,
upstreamClient: mockClient,
upstreamTimeout: UpstreamTimeout,
}
resolver.addRace([]netip.AddrPort{a})
resolver.addRace([]netip.AddrPort{b})
var responseMSG *dns.Msg
responseWriter := &test.MockResponseWriter{
WriteMsgFunc: func(m *dns.Msg) error {
responseMSG = m
return nil
},
}
resolver.ServeDNS(responseWriter, new(dns.Msg).SetQuestion("example.com.", dns.TypeA))
require.NotNil(t, responseMSG)
assert.Equal(t, dns.RcodeServerFailure, responseMSG.Rcode)
}
// TestUpstreamResolver_HealthTracking verifies that query-path results are
// recorded into per-upstream health, which is what projects back to
// NSGroupState for status reporting.
func TestUpstreamResolver_HealthTracking(t *testing.T) {
ok := netip.MustParseAddrPort("192.0.2.10:53")
bad := netip.MustParseAddrPort("192.0.2.11:53")
mockClient := &mockUpstreamResolverPerServer{
responses: map[string]mockUpstreamResponse{
ok.String(): {msg: buildMockResponse(dns.RcodeSuccess, "192.0.2.100")},
bad.String(): {msg: buildMockResponse(dns.RcodeServerFailure, "")},
},
rtt: time.Millisecond,
}
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
resolver := &upstreamResolverBase{
ctx: ctx,
upstreamClient: mockClient,
upstreamTimeout: UpstreamTimeout,
}
resolver.addRace([]netip.AddrPort{ok, bad})
responseWriter := &test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { return nil }}
resolver.ServeDNS(responseWriter, new(dns.Msg).SetQuestion("example.com.", dns.TypeA))
health := resolver.UpstreamHealth()
require.Contains(t, health, ok)
assert.False(t, health[ok].LastOk.IsZero(), "ok upstream should have LastOk set")
assert.Empty(t, health[ok].LastErr)
// bad upstream was never tried because ok answered first; its health
// should remain unset.
assert.NotContains(t, health, bad, "sibling upstream should not be queried when primary answers")
}
func TestFormatFailures(t *testing.T) {
testCases := []struct {
name string
@@ -665,10 +742,10 @@ func TestExchangeWithFallback_EDNS0Capped(t *testing.T) {
// Verify that a client EDNS0 larger than our MTU-derived limit gets
// capped in the outgoing request so the upstream doesn't send a
// response larger than our read buffer.
var receivedUDPSize uint16
var receivedUDPSize atomic.Uint32
udpHandler := dns.HandlerFunc(func(w dns.ResponseWriter, r *dns.Msg) {
if opt := r.IsEdns0(); opt != nil {
receivedUDPSize = opt.UDPSize()
receivedUDPSize.Store(uint32(opt.UDPSize()))
}
m := new(dns.Msg)
m.SetReply(r)
@@ -699,7 +776,7 @@ func TestExchangeWithFallback_EDNS0Capped(t *testing.T) {
require.NotNil(t, rm)
expectedMax := uint16(currentMTU - ipUDPHeaderSize)
assert.Equal(t, expectedMax, receivedUDPSize,
assert.Equal(t, expectedMax, uint16(receivedUDPSize.Load()),
"upstream should see capped EDNS0, not the client's 4096")
}
@@ -770,3 +847,132 @@ func TestExchangeWithFallback_TCPTruncatesToClientSize(t *testing.T) {
assert.Less(t, len(rm2.Answer), 20, "small EDNS0 client should get fewer records")
assert.True(t, rm2.Truncated, "response should be truncated for small buffer client")
}
func msgWithEDE(rcode int, codes ...uint16) *dns.Msg {
m := new(dns.Msg)
m.Response = true
m.Rcode = rcode
if len(codes) == 0 {
return m
}
opt := &dns.OPT{Hdr: dns.RR_Header{Name: ".", Rrtype: dns.TypeOPT}}
opt.SetUDPSize(dns.MinMsgSize)
for _, c := range codes {
opt.Option = append(opt.Option, &dns.EDNS0_EDE{InfoCode: c})
}
m.Extra = append(m.Extra, opt)
return m
}
func TestNonRetryableEDE(t *testing.T) {
tests := []struct {
name string
msg *dns.Msg
wantOK bool
wantCode uint16
}{
{name: "no edns0", msg: msgWithEDE(dns.RcodeServerFailure)},
{
name: "opt without ede",
msg: func() *dns.Msg {
m := msgWithEDE(dns.RcodeServerFailure)
opt := &dns.OPT{Hdr: dns.RR_Header{Name: ".", Rrtype: dns.TypeOPT}}
opt.Option = append(opt.Option, &dns.EDNS0_NSID{Code: dns.EDNS0NSID})
m.Extra = []dns.RR{opt}
return m
}(),
},
{name: "ede dnsbogus", msg: msgWithEDE(dns.RcodeServerFailure, dns.ExtendedErrorCodeDNSBogus), wantOK: true, wantCode: dns.ExtendedErrorCodeDNSBogus},
{name: "ede signature expired", msg: msgWithEDE(dns.RcodeServerFailure, dns.ExtendedErrorCodeSignatureExpired), wantOK: true, wantCode: dns.ExtendedErrorCodeSignatureExpired},
{name: "ede blocked", msg: msgWithEDE(dns.RcodeServerFailure, dns.ExtendedErrorCodeBlocked), wantOK: true, wantCode: dns.ExtendedErrorCodeBlocked},
{name: "ede prohibited", msg: msgWithEDE(dns.RcodeRefused, dns.ExtendedErrorCodeProhibited), wantOK: true, wantCode: dns.ExtendedErrorCodeProhibited},
{name: "ede cached error retryable", msg: msgWithEDE(dns.RcodeServerFailure, dns.ExtendedErrorCodeCachedError)},
{name: "ede network error retryable", msg: msgWithEDE(dns.RcodeServerFailure, dns.ExtendedErrorCodeNetworkError)},
{name: "ede not ready retryable", msg: msgWithEDE(dns.RcodeServerFailure, dns.ExtendedErrorCodeNotReady)},
{
name: "first non-retryable wins",
msg: msgWithEDE(dns.RcodeServerFailure, dns.ExtendedErrorCodeNetworkError, dns.ExtendedErrorCodeDNSBogus),
wantOK: true,
wantCode: dns.ExtendedErrorCodeDNSBogus,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
code, ok := nonRetryableEDE(tc.msg)
assert.Equal(t, tc.wantOK, ok, "ok should match")
if tc.wantOK {
assert.Equal(t, tc.wantCode, code, "code should match")
}
})
}
}
func TestEDEName(t *testing.T) {
assert.Equal(t, "DNSSEC Bogus", edeName(dns.ExtendedErrorCodeDNSBogus))
assert.Equal(t, "Signature Expired", edeName(dns.ExtendedErrorCodeSignatureExpired))
assert.Equal(t, "EDE 9999", edeName(9999), "unknown code falls back to numeric")
}
func TestStripOPT(t *testing.T) {
rm := &dns.Msg{
Extra: []dns.RR{
&dns.OPT{Hdr: dns.RR_Header{Name: ".", Rrtype: dns.TypeOPT}},
&dns.A{Hdr: dns.RR_Header{Name: "x.", Rrtype: dns.TypeA}, A: net.IPv4(1, 2, 3, 4)},
},
}
stripOPT(rm)
assert.Len(t, rm.Extra, 1, "OPT should be removed, A kept")
_, isOPT := rm.Extra[0].(*dns.OPT)
assert.False(t, isOPT, "remaining record must not be OPT")
}
func TestUpstreamResolver_NonRetryableEDEShortCircuits(t *testing.T) {
upstream1 := netip.MustParseAddrPort("192.0.2.1:53")
upstream2 := netip.MustParseAddrPort("192.0.2.2:53")
servfailWithEDE := msgWithEDE(dns.RcodeServerFailure, dns.ExtendedErrorCodeDNSBogus)
successResp := buildMockResponse(dns.RcodeSuccess, "192.0.2.100")
var queried []string
tracking := &trackingMockClient{
inner: &mockUpstreamResolverPerServer{
responses: map[string]mockUpstreamResponse{
upstream1.String(): {msg: servfailWithEDE},
upstream2.String(): {msg: successResp},
},
rtt: time.Millisecond,
},
queriedUpstreams: &queried,
}
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
resolver := &upstreamResolverBase{
ctx: ctx,
upstreamClient: tracking,
upstreamServers: []upstreamRace{{upstream1, upstream2}},
upstreamTimeout: UpstreamTimeout,
}
var written *dns.Msg
w := &test.MockResponseWriter{
WriteMsgFunc: func(m *dns.Msg) error {
written = m
return nil
},
}
// Client query without EDNS0 must not see an OPT in the response.
q := new(dns.Msg).SetQuestion("example.com.", dns.TypeA)
resolver.ServeDNS(w, q)
require.NotNil(t, written, "response must be written")
assert.Equal(t, dns.RcodeServerFailure, written.Rcode, "SERVFAIL must propagate")
assert.Len(t, queried, 1, "only first upstream should be queried")
assert.Equal(t, upstream1.String(), queried[0])
for _, rr := range written.Extra {
_, isOPT := rr.(*dns.OPT)
assert.False(t, isOPT, "synthetic OPT must not leak to a non-EDNS0 client")
}
}

View File

@@ -512,16 +512,7 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
e.routeManager.SetRouteChangeListener(e.mobileDep.NetworkChangeListener)
e.dnsServer.SetRouteChecker(func(ip netip.Addr) bool {
for _, routes := range e.routeManager.GetSelectedClientRoutes() {
for _, r := range routes {
if r.Network.Contains(ip) {
return true
}
}
}
return false
})
e.dnsServer.SetRouteSources(e.routeManager.GetSelectedClientRoutes, e.routeManager.GetActiveClientRoutes)
if err = e.wgInterfaceCreate(); err != nil {
log.Errorf("failed creating tunnel interface %s: [%s]", e.config.WgIfaceName, err.Error())
@@ -1386,9 +1377,6 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
e.networkSerial = serial
// Test received (upstream) servers for availability right away instead of upon usage.
// If no server of a server group responds this will disable the respective handler and retry later.
go e.dnsServer.ProbeAvailability()
return nil
}
@@ -1932,7 +1920,7 @@ func (e *Engine) newDnsServer(dnsConfig *nbdns.Config) (dns.Server, error) {
return dnsServer, nil
case "ios":
dnsServer := dns.NewDefaultServerIos(e.ctx, e.wgInterface, e.mobileDep.DnsManager, e.mobileDep.HostDNSAddresses, e.statusRecorder, e.config.DisableDNS)
dnsServer := dns.NewDefaultServerIos(e.ctx, e.wgInterface, e.mobileDep.DnsManager, e.statusRecorder, e.config.DisableDNS)
return dnsServer, nil
default:

View File

@@ -53,6 +53,7 @@ type Manager interface {
GetRouteSelector() *routeselector.RouteSelector
GetClientRoutes() route.HAMap
GetSelectedClientRoutes() route.HAMap
GetActiveClientRoutes() route.HAMap
GetClientRoutesWithNetID() map[route.NetID][]*route.Route
SetRouteChangeListener(listener listener.NetworkChangeListener)
InitialRouteRange() []string
@@ -485,6 +486,39 @@ func (m *DefaultManager) GetSelectedClientRoutes() route.HAMap {
return m.routeSelector.FilterSelectedExitNodes(maps.Clone(m.clientRoutes))
}
// GetActiveClientRoutes returns the subset of selected client routes
// that are currently reachable: the route's peer is Connected and is
// the one actively carrying the route (not just an HA sibling).
func (m *DefaultManager) GetActiveClientRoutes() route.HAMap {
m.mux.Lock()
selected := m.routeSelector.FilterSelectedExitNodes(maps.Clone(m.clientRoutes))
recorder := m.statusRecorder
m.mux.Unlock()
if recorder == nil {
return selected
}
out := make(route.HAMap, len(selected))
for id, routes := range selected {
for _, r := range routes {
st, err := recorder.GetPeer(r.Peer)
if err != nil {
continue
}
if st.ConnStatus != peer.StatusConnected {
continue
}
if _, hasRoute := st.GetRoutes()[r.Network.String()]; !hasRoute {
continue
}
out[id] = routes
break
}
}
return out
}
// GetClientRoutesWithNetID returns the current routes from the route map, but the keys consist of the network ID only
func (m *DefaultManager) GetClientRoutesWithNetID() map[route.NetID][]*route.Route {
m.mux.Lock()
@@ -704,7 +738,10 @@ func (m *DefaultManager) collectExitNodeInfo(clientRoutes route.HAMap) exitNodeI
}
func (m *DefaultManager) isExitNodeRoute(routes []*route.Route) bool {
return len(routes) > 0 && routes[0].Network.String() == vars.ExitNodeCIDR
if len(routes) == 0 {
return false
}
return route.IsV4DefaultRoute(routes[0].Network) || route.IsV6DefaultRoute(routes[0].Network)
}
func (m *DefaultManager) categorizeUserSelection(netID route.NetID, info *exitNodeInfo) {

View File

@@ -19,6 +19,7 @@ type MockManager struct {
GetRouteSelectorFunc func() *routeselector.RouteSelector
GetClientRoutesFunc func() route.HAMap
GetSelectedClientRoutesFunc func() route.HAMap
GetActiveClientRoutesFunc func() route.HAMap
GetClientRoutesWithNetIDFunc func() map[route.NetID][]*route.Route
StopFunc func(manager *statemanager.Manager)
}
@@ -78,6 +79,14 @@ func (m *MockManager) GetSelectedClientRoutes() route.HAMap {
return nil
}
// GetActiveClientRoutes mock implementation of GetActiveClientRoutes from the Manager interface
func (m *MockManager) GetActiveClientRoutes() route.HAMap {
if m.GetActiveClientRoutesFunc != nil {
return m.GetActiveClientRoutesFunc()
}
return nil
}
// GetClientRoutesWithNetID mock implementation of GetClientRoutesWithNetID from Manager interface
func (m *MockManager) GetClientRoutesWithNetID() map[route.NetID][]*route.Route {
if m.GetClientRoutesWithNetIDFunc != nil {

View File

@@ -4,6 +4,7 @@ import (
"encoding/json"
"fmt"
"slices"
"strings"
"sync"
"github.com/hashicorp/go-multierror"
@@ -12,10 +13,6 @@ import (
"github.com/netbirdio/netbird/route"
)
const (
exitNodeCIDR = "0.0.0.0/0"
)
type RouteSelector struct {
mu sync.RWMutex
deselectedRoutes map[route.NetID]struct{}
@@ -124,13 +121,7 @@ func (rs *RouteSelector) IsSelected(routeID route.NetID) bool {
rs.mu.RLock()
defer rs.mu.RUnlock()
if rs.deselectAll {
return false
}
_, deselected := rs.deselectedRoutes[routeID]
isSelected := !deselected
return isSelected
return rs.isSelectedLocked(routeID)
}
// FilterSelected removes unselected routes from the provided map.
@@ -144,23 +135,22 @@ func (rs *RouteSelector) FilterSelected(routes route.HAMap) route.HAMap {
filtered := route.HAMap{}
for id, rt := range routes {
netID := id.NetID()
_, deselected := rs.deselectedRoutes[netID]
if !deselected {
if !rs.isDeselectedLocked(id.NetID()) {
filtered[id] = rt
}
}
return filtered
}
// HasUserSelectionForRoute returns true if the user has explicitly selected or deselected this specific route
// HasUserSelectionForRoute returns true if the user has explicitly selected or deselected this route.
// Intended for exit-node code paths: a v6 exit-node pair (e.g. "MyExit-v6") with no explicit state of
// its own inherits its v4 base's state, so legacy persisted selections that predate v6 pairing
// transparently apply to the synthesized v6 entry.
func (rs *RouteSelector) HasUserSelectionForRoute(routeID route.NetID) bool {
rs.mu.RLock()
defer rs.mu.RUnlock()
_, selected := rs.selectedRoutes[routeID]
_, deselected := rs.deselectedRoutes[routeID]
return selected || deselected
return rs.hasUserSelectionForRouteLocked(rs.effectiveNetID(routeID))
}
func (rs *RouteSelector) FilterSelectedExitNodes(routes route.HAMap) route.HAMap {
@@ -174,7 +164,7 @@ func (rs *RouteSelector) FilterSelectedExitNodes(routes route.HAMap) route.HAMap
filtered := make(route.HAMap, len(routes))
for id, rt := range routes {
netID := id.NetID()
if rs.isDeselected(netID) {
if rs.isDeselectedLocked(netID) {
continue
}
@@ -189,13 +179,48 @@ func (rs *RouteSelector) FilterSelectedExitNodes(routes route.HAMap) route.HAMap
return filtered
}
func (rs *RouteSelector) isDeselected(netID route.NetID) bool {
// effectiveNetID returns the v4 base for a "-v6" exit pair entry that has no explicit
// state of its own, so selections made on the v4 entry govern the v6 entry automatically.
// Only call this from exit-node-specific code paths: applying it to a non-exit "-v6" route
// would make it inherit unrelated v4 state. Must be called with rs.mu held.
func (rs *RouteSelector) effectiveNetID(id route.NetID) route.NetID {
name := string(id)
if !strings.HasSuffix(name, route.V6ExitSuffix) {
return id
}
if _, ok := rs.selectedRoutes[id]; ok {
return id
}
if _, ok := rs.deselectedRoutes[id]; ok {
return id
}
return route.NetID(strings.TrimSuffix(name, route.V6ExitSuffix))
}
func (rs *RouteSelector) isSelectedLocked(routeID route.NetID) bool {
if rs.deselectAll {
return false
}
_, deselected := rs.deselectedRoutes[routeID]
return !deselected
}
func (rs *RouteSelector) isDeselectedLocked(netID route.NetID) bool {
if rs.deselectAll {
return true
}
_, deselected := rs.deselectedRoutes[netID]
return deselected || rs.deselectAll
return deselected
}
func (rs *RouteSelector) hasUserSelectionForRouteLocked(routeID route.NetID) bool {
_, selected := rs.selectedRoutes[routeID]
_, deselected := rs.deselectedRoutes[routeID]
return selected || deselected
}
func isExitNode(rt []*route.Route) bool {
return len(rt) > 0 && rt[0].Network.String() == exitNodeCIDR
return len(rt) > 0 && (route.IsV4DefaultRoute(rt[0].Network) || route.IsV6DefaultRoute(rt[0].Network))
}
func (rs *RouteSelector) applyExitNodeFilter(
@@ -204,26 +229,23 @@ func (rs *RouteSelector) applyExitNodeFilter(
rt []*route.Route,
out route.HAMap,
) {
if rs.hasUserSelections() {
// user made explicit selects/deselects
if rs.IsSelected(netID) {
// Exit-node path: apply the v4/v6 pair mirror so a deselect on the v4 base also
// drops the synthesized v6 entry that lacks its own explicit state.
effective := rs.effectiveNetID(netID)
if rs.hasUserSelectionForRouteLocked(effective) {
if rs.isSelectedLocked(effective) {
out[id] = rt
}
return
}
// no explicit selections: only include routes marked !SkipAutoApply (=AutoApply)
// no explicit selection for this route: defer to management's SkipAutoApply flag
sel := collectSelected(rt)
if len(sel) > 0 {
out[id] = sel
}
}
func (rs *RouteSelector) hasUserSelections() bool {
return len(rs.selectedRoutes) > 0 || len(rs.deselectedRoutes) > 0
}
func collectSelected(rt []*route.Route) []*route.Route {
var sel []*route.Route
for _, r := range rt {

View File

@@ -330,6 +330,137 @@ func TestRouteSelector_FilterSelectedExitNodes(t *testing.T) {
assert.Len(t, filtered, 0) // No routes should be selected
}
// TestRouteSelector_V6ExitPairInherits covers the v4/v6 exit-node pair selection
// mirror. The mirror is scoped to exit-node code paths: HasUserSelectionForRoute
// and FilterSelectedExitNodes resolve a "-v6" entry without explicit state to its
// v4 base, so legacy persisted selections that predate v6 pairing transparently
// apply to the synthesized v6 entry. General lookups (IsSelected, FilterSelected)
// stay literal so unrelated routes named "*-v6" don't inherit unrelated state.
func TestRouteSelector_V6ExitPairInherits(t *testing.T) {
all := []route.NetID{"exit1", "exit1-v6", "exit2", "exit2-v6", "corp", "corp-v6"}
t.Run("HasUserSelectionForRoute mirrors deselected v4 base", func(t *testing.T) {
rs := routeselector.NewRouteSelector()
require.NoError(t, rs.DeselectRoutes([]route.NetID{"exit1"}, all))
assert.True(t, rs.HasUserSelectionForRoute("exit1-v6"), "v6 pair sees v4 base's user selection")
// unrelated v6 with no v4 base touched is unaffected
assert.False(t, rs.HasUserSelectionForRoute("exit2-v6"))
})
t.Run("IsSelected stays literal for non-exit lookups", func(t *testing.T) {
rs := routeselector.NewRouteSelector()
require.NoError(t, rs.DeselectRoutes([]route.NetID{"corp"}, all))
// A non-exit route literally named "corp-v6" must not inherit "corp"'s state
// via the mirror; the mirror only applies in exit-node code paths.
assert.False(t, rs.IsSelected("corp"))
assert.True(t, rs.IsSelected("corp-v6"), "non-exit *-v6 routes must not inherit unrelated v4 state")
})
t.Run("explicit v6 state overrides v4 base in filter", func(t *testing.T) {
rs := routeselector.NewRouteSelector()
require.NoError(t, rs.DeselectRoutes([]route.NetID{"exit1"}, all))
require.NoError(t, rs.SelectRoutes([]route.NetID{"exit1-v6"}, true, all))
v4Route := &route.Route{NetID: "exit1", Network: netip.MustParsePrefix("0.0.0.0/0")}
v6Route := &route.Route{NetID: "exit1-v6", Network: netip.MustParsePrefix("::/0")}
routes := route.HAMap{
"exit1|0.0.0.0/0": {v4Route},
"exit1-v6|::/0": {v6Route},
}
filtered := rs.FilterSelectedExitNodes(routes)
assert.NotContains(t, filtered, route.HAUniqueID("exit1|0.0.0.0/0"))
assert.Contains(t, filtered, route.HAUniqueID("exit1-v6|::/0"), "explicit v6 select wins over v4 base")
})
t.Run("non-v6-suffix routes unaffected", func(t *testing.T) {
rs := routeselector.NewRouteSelector()
require.NoError(t, rs.DeselectRoutes([]route.NetID{"exit1"}, all))
// A route literally named "exit1-something" must not pair-resolve.
assert.False(t, rs.HasUserSelectionForRoute("exit1-something"))
})
t.Run("filter v6 paired with deselected v4 base", func(t *testing.T) {
rs := routeselector.NewRouteSelector()
require.NoError(t, rs.DeselectRoutes([]route.NetID{"exit1"}, all))
v4Route := &route.Route{NetID: "exit1", Network: netip.MustParsePrefix("0.0.0.0/0")}
v6Route := &route.Route{NetID: "exit1-v6", Network: netip.MustParsePrefix("::/0")}
routes := route.HAMap{
"exit1|0.0.0.0/0": {v4Route},
"exit1-v6|::/0": {v6Route},
}
filtered := rs.FilterSelectedExitNodes(routes)
assert.Empty(t, filtered, "deselecting v4 base must also drop the v6 pair")
})
t.Run("non-exit *-v6 routes pass through FilterSelectedExitNodes", func(t *testing.T) {
rs := routeselector.NewRouteSelector()
require.NoError(t, rs.DeselectRoutes([]route.NetID{"corp"}, all))
// A non-default-route entry named "corp-v6" is not an exit node and
// must not be skipped because its v4 base "corp" is deselected.
corpV6 := &route.Route{NetID: "corp-v6", Network: netip.MustParsePrefix("10.0.0.0/8")}
routes := route.HAMap{
"corp-v6|10.0.0.0/8": {corpV6},
}
filtered := rs.FilterSelectedExitNodes(routes)
assert.Contains(t, filtered, route.HAUniqueID("corp-v6|10.0.0.0/8"),
"non-exit *-v6 routes must not inherit unrelated v4 state in FilterSelectedExitNodes")
})
}
// TestRouteSelector_SkipAutoApplyPerRoute verifies that management's
// SkipAutoApply flag governs each untouched route independently, even when
// the user has explicit selections on other routes.
func TestRouteSelector_SkipAutoApplyPerRoute(t *testing.T) {
autoApplied := &route.Route{
NetID: "Auto",
Network: netip.MustParsePrefix("0.0.0.0/0"),
SkipAutoApply: false,
}
skipApply := &route.Route{
NetID: "Skip",
Network: netip.MustParsePrefix("0.0.0.0/0"),
SkipAutoApply: true,
}
routes := route.HAMap{
"Auto|0.0.0.0/0": {autoApplied},
"Skip|0.0.0.0/0": {skipApply},
}
rs := routeselector.NewRouteSelector()
// User makes an unrelated explicit selection elsewhere.
require.NoError(t, rs.DeselectRoutes([]route.NetID{"Unrelated"}, []route.NetID{"Auto", "Skip", "Unrelated"}))
filtered := rs.FilterSelectedExitNodes(routes)
assert.Contains(t, filtered, route.HAUniqueID("Auto|0.0.0.0/0"), "AutoApply route should be included")
assert.NotContains(t, filtered, route.HAUniqueID("Skip|0.0.0.0/0"), "SkipAutoApply route should be excluded without explicit user selection")
}
// TestRouteSelector_V6ExitIsExitNode verifies that ::/0 routes are recognized
// as exit nodes by the selector's filter path.
func TestRouteSelector_V6ExitIsExitNode(t *testing.T) {
v6Exit := &route.Route{
NetID: "V6Only",
Network: netip.MustParsePrefix("::/0"),
SkipAutoApply: true,
}
routes := route.HAMap{
"V6Only|::/0": {v6Exit},
}
rs := routeselector.NewRouteSelector()
filtered := rs.FilterSelectedExitNodes(routes)
assert.Empty(t, filtered, "::/0 should be treated as an exit node and respect SkipAutoApply")
}
func TestRouteSelector_NewRoutesBehavior(t *testing.T) {
initialRoutes := []route.NetID{"route1", "route2", "route3"}
newRoutes := []route.NetID{"route1", "route2", "route3", "route4", "route5"}

View File

@@ -162,11 +162,7 @@ func (c *Client) Run(fd int32, interfaceName string, envList *EnvList) error {
cfg.WgIface = interfaceName
c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder)
hostDNS := []netip.AddrPort{
netip.MustParseAddrPort("9.9.9.9:53"),
netip.MustParseAddrPort("149.112.112.112:53"),
}
return c.connectClient.RunOniOS(fd, c.networkChangeListener, c.dnsManager, hostDNS, c.stateFile)
return c.connectClient.RunOniOS(fd, c.networkChangeListener, c.dnsManager, c.stateFile)
}
// Stop the internal client and free the resources

View File

@@ -64,6 +64,13 @@
<RegistryValue Name="InstalledByMSI" Type="integer" Value="1" KeyPath="yes" />
</RegistryKey>
</Component>
<!-- Drop the HKCU Run\Netbird value written by legacy NSIS installers. -->
<Component Id="NetbirdLegacyHKCUCleanup" Guid="*">
<RegistryValue Root="HKCU" Key="Software\NetBird GmbH\Installer"
Name="LegacyHKCUCleanup" Type="integer" Value="1" KeyPath="yes" />
<RemoveRegistryValue Root="HKCU"
Key="Software\Microsoft\Windows\CurrentVersion\Run" Name="Netbird" />
</Component>
</StandardDirectory>
<StandardDirectory Id="CommonAppDataFolder">
@@ -76,10 +83,28 @@
</Directory>
</StandardDirectory>
<!-- Drop Run, App Paths and Uninstall entries written by legacy NSIS
installers into the 32-bit registry view (HKLM\Software\Wow6432Node). -->
<Component Id="NetbirdLegacyWow6432Cleanup" Directory="NetbirdInstallDir"
Guid="bda5d628-16bd-4086-b2c1-5099d8d51763" Bitness="always32">
<RegistryValue Root="HKLM" Key="Software\NetBird GmbH\Installer"
Name="LegacyWow6432Cleanup" Type="integer" Value="1" KeyPath="yes" />
<RemoveRegistryValue Root="HKLM"
Key="Software\Microsoft\Windows\CurrentVersion\Run" Name="Netbird" />
<RemoveRegistryKey Action="removeOnInstall" Root="HKLM"
Key="Software\Microsoft\Windows\CurrentVersion\App Paths\Netbird" />
<RemoveRegistryKey Action="removeOnInstall" Root="HKLM"
Key="Software\Microsoft\Windows\CurrentVersion\App Paths\Netbird-ui" />
<RemoveRegistryKey Action="removeOnInstall" Root="HKLM"
Key="Software\Microsoft\Windows\CurrentVersion\Uninstall\Netbird" />
</Component>
<ComponentGroup Id="NetbirdFilesComponent">
<ComponentRef Id="NetbirdFiles" />
<ComponentRef Id="NetbirdAumidRegistry" />
<ComponentRef Id="NetbirdAutoStart" />
<ComponentRef Id="NetbirdLegacyHKCUCleanup" />
<ComponentRef Id="NetbirdLegacyWow6432Cleanup" />
</ComponentGroup>
<util:CloseApplication Id="CloseNetBird" CloseMessage="no" Target="netbird.exe" RebootPrompt="no" />

View File

@@ -0,0 +1,93 @@
package server
import (
"path/filepath"
"testing"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/client/internal/profilemanager"
)
func TestPersistLoginOverrides(t *testing.T) {
strPtr := func(s string) *string { return &s }
tests := []struct {
name string
initialMgmtURL string
initialPSK string
newMgmtURL string
newPSK *string
wantMgmtURL string
wantPSK string
}{
{
name: "persist new management URL",
initialMgmtURL: "https://old.example.com:33073",
newMgmtURL: "https://new.example.com:33073",
wantMgmtURL: "https://new.example.com:33073",
},
{
name: "persist new pre-shared key",
initialMgmtURL: "https://existing.example.com:33073",
initialPSK: "old-key",
newPSK: strPtr("new-key"),
wantMgmtURL: "https://existing.example.com:33073",
wantPSK: "new-key",
},
{
name: "persist both",
initialMgmtURL: "https://old.example.com:33073",
initialPSK: "old-key",
newMgmtURL: "https://new.example.com:33073",
newPSK: strPtr("new-key"),
wantMgmtURL: "https://new.example.com:33073",
wantPSK: "new-key",
},
{
name: "no inputs preserves existing",
initialMgmtURL: "https://existing.example.com:33073",
initialPSK: "existing-key",
wantMgmtURL: "https://existing.example.com:33073",
wantPSK: "existing-key",
},
{
name: "empty PSK pointer is ignored",
initialMgmtURL: "https://existing.example.com:33073",
initialPSK: "existing-key",
newPSK: strPtr(""),
wantMgmtURL: "https://existing.example.com:33073",
wantPSK: "existing-key",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
origDefault := profilemanager.DefaultConfigPath
t.Cleanup(func() { profilemanager.DefaultConfigPath = origDefault })
dir := t.TempDir()
profilemanager.DefaultConfigPath = filepath.Join(dir, "default.json")
seed := profilemanager.ConfigInput{
ConfigPath: profilemanager.DefaultConfigPath,
ManagementURL: tt.initialMgmtURL,
}
if tt.initialPSK != "" {
seed.PreSharedKey = strPtr(tt.initialPSK)
}
_, err := profilemanager.UpdateOrCreateConfig(seed)
require.NoError(t, err, "seed config")
activeProf := &profilemanager.ActiveProfileState{Name: "default"}
err = persistLoginOverrides(activeProf, tt.newMgmtURL, tt.newPSK)
require.NoError(t, err, "persistLoginOverrides")
cfg, err := profilemanager.ReadConfig(profilemanager.DefaultConfigPath)
require.NoError(t, err, "read back config")
require.Equal(t, tt.wantMgmtURL, cfg.ManagementURL.String(), "management URL")
require.Equal(t, tt.wantPSK, cfg.PreSharedKey, "pre-shared key")
})
}
}

View File

@@ -490,6 +490,11 @@ func (s *Server) Login(callerCtx context.Context, msg *proto.LoginRequest) (*pro
s.mutex.Unlock()
if err := persistLoginOverrides(activeProf, msg.ManagementUrl, msg.OptionalPreSharedKey); err != nil {
log.Errorf("failed to persist login overrides: %v", err)
return nil, fmt.Errorf("persist login overrides: %w", err)
}
config, _, err := s.getConfig(activeProf)
if err != nil {
log.Errorf("failed to get active profile config: %v", err)
@@ -964,7 +969,7 @@ func (s *Server) handleActiveProfileLogout(ctx context.Context) (*proto.LogoutRe
return &proto.LogoutResponse{}, nil
}
// GetConfig reads config file and returns Config and whether the config file already existed. Errors out if it does not exist
// getConfig reads config file and returns Config and whether the config file already existed. Errors out if it does not exist
func (s *Server) getConfig(activeProf *profilemanager.ActiveProfileState) (*profilemanager.Config, bool, error) {
cfgPath, err := activeProf.FilePath()
if err != nil {
@@ -1766,3 +1771,29 @@ func sendTerminalNotification() error {
return wallCmd.Wait()
}
// persistLoginOverrides writes management URL and pre-shared key from a LoginRequest to the
// active profile config so that subsequent reads pick them up. Empty/nil values are ignored.
func persistLoginOverrides(activeProf *profilemanager.ActiveProfileState, managementURL string, preSharedKey *string) error {
if preSharedKey != nil && *preSharedKey == "" {
preSharedKey = nil
}
if managementURL == "" && preSharedKey == nil {
return nil
}
cfgPath, err := activeProf.FilePath()
if err != nil {
return fmt.Errorf("active profile file path: %w", err)
}
input := profilemanager.ConfigInput{
ConfigPath: cfgPath,
ManagementURL: managementURL,
PreSharedKey: preSharedKey,
}
if _, err := profilemanager.UpdateOrCreateConfig(input); err != nil {
return fmt.Errorf("update config: %w", err)
}
return nil
}

View File

@@ -25,6 +25,7 @@ import (
nbssh "github.com/netbirdio/netbird/client/ssh"
"github.com/netbirdio/netbird/client/ssh/detection"
"github.com/netbirdio/netbird/util"
"github.com/netbirdio/netbird/util/netrelay"
)
const (
@@ -536,7 +537,7 @@ func (c *Client) LocalPortForward(ctx context.Context, localAddr, remoteAddr str
continue
}
go c.handleLocalForward(localConn, remoteAddr)
go c.handleLocalForward(ctx, localConn, remoteAddr)
}
}()
@@ -548,7 +549,7 @@ func (c *Client) LocalPortForward(ctx context.Context, localAddr, remoteAddr str
}
// handleLocalForward handles a single local port forwarding connection
func (c *Client) handleLocalForward(localConn net.Conn, remoteAddr string) {
func (c *Client) handleLocalForward(ctx context.Context, localConn net.Conn, remoteAddr string) {
defer func() {
if err := localConn.Close(); err != nil {
log.Debugf("local port forwarding: close local connection: %v", err)
@@ -571,7 +572,7 @@ func (c *Client) handleLocalForward(localConn net.Conn, remoteAddr string) {
}
}()
nbssh.BidirectionalCopy(log.NewEntry(log.StandardLogger()), localConn, channel)
netrelay.Relay(ctx, localConn, channel, netrelay.Options{Logger: log.NewEntry(log.StandardLogger())})
}
// RemotePortForward sets up remote port forwarding, binding on remote and forwarding to localAddr
@@ -653,16 +654,19 @@ func (c *Client) handleRemoteForwardChannels(ctx context.Context, localAddr stri
select {
case <-ctx.Done():
return
case newChan := <-channelRequests:
case newChan, ok := <-channelRequests:
if !ok {
return
}
if newChan != nil {
go c.handleRemoteForwardChannel(newChan, localAddr)
go c.handleRemoteForwardChannel(ctx, newChan, localAddr)
}
}
}
}
// handleRemoteForwardChannel handles a single forwarded-tcpip channel
func (c *Client) handleRemoteForwardChannel(newChan ssh.NewChannel, localAddr string) {
func (c *Client) handleRemoteForwardChannel(ctx context.Context, newChan ssh.NewChannel, localAddr string) {
channel, reqs, err := newChan.Accept()
if err != nil {
return
@@ -675,8 +679,14 @@ func (c *Client) handleRemoteForwardChannel(newChan ssh.NewChannel, localAddr st
go ssh.DiscardRequests(reqs)
localConn, err := net.Dial("tcp", localAddr)
// Bound the dial so a black-holed localAddr can't pin the accepted SSH
// channel open indefinitely; the relay itself runs under the outer ctx.
dialCtx, cancelDial := context.WithTimeout(ctx, 10*time.Second)
var dialer net.Dialer
localConn, err := dialer.DialContext(dialCtx, "tcp", localAddr)
cancelDial()
if err != nil {
log.Debugf("remote port forwarding: dial %s: %v", localAddr, err)
return
}
defer func() {
@@ -685,7 +695,7 @@ func (c *Client) handleRemoteForwardChannel(newChan ssh.NewChannel, localAddr st
}
}()
nbssh.BidirectionalCopy(log.NewEntry(log.StandardLogger()), localConn, channel)
netrelay.Relay(ctx, localConn, channel, netrelay.Options{Logger: log.NewEntry(log.StandardLogger())})
}
// tcpipForwardMsg represents the structure for tcpip-forward requests

View File

@@ -194,63 +194,3 @@ func buildAddressList(hostname string, remote net.Addr) []string {
return addresses
}
// BidirectionalCopy copies data bidirectionally between two io.ReadWriter connections.
// It waits for both directions to complete before returning.
// The caller is responsible for closing the connections.
func BidirectionalCopy(logger *log.Entry, rw1, rw2 io.ReadWriter) {
done := make(chan struct{}, 2)
go func() {
if _, err := io.Copy(rw2, rw1); err != nil && !isExpectedCopyError(err) {
logger.Debugf("copy error (1->2): %v", err)
}
done <- struct{}{}
}()
go func() {
if _, err := io.Copy(rw1, rw2); err != nil && !isExpectedCopyError(err) {
logger.Debugf("copy error (2->1): %v", err)
}
done <- struct{}{}
}()
<-done
<-done
}
func isExpectedCopyError(err error) bool {
return errors.Is(err, io.EOF) || errors.Is(err, context.Canceled)
}
// BidirectionalCopyWithContext copies data bidirectionally between two io.ReadWriteCloser connections.
// It waits for both directions to complete or for context cancellation before returning.
// Both connections are closed when the function returns.
func BidirectionalCopyWithContext(logger *log.Entry, ctx context.Context, conn1, conn2 io.ReadWriteCloser) {
done := make(chan struct{}, 2)
go func() {
if _, err := io.Copy(conn2, conn1); err != nil && !isExpectedCopyError(err) {
logger.Debugf("copy error (1->2): %v", err)
}
done <- struct{}{}
}()
go func() {
if _, err := io.Copy(conn1, conn2); err != nil && !isExpectedCopyError(err) {
logger.Debugf("copy error (2->1): %v", err)
}
done <- struct{}{}
}()
select {
case <-ctx.Done():
case <-done:
select {
case <-ctx.Done():
case <-done:
}
}
_ = conn1.Close()
_ = conn2.Close()
}

View File

@@ -229,18 +229,35 @@ func (m *Manager) buildHostPatterns(peer PeerSSHInfo) []string {
func (m *Manager) writeSSHConfig(sshConfig string) error {
sshConfigPath := filepath.Join(m.sshConfigDir, m.sshConfigFile)
sshConfigPathTmp := sshConfigPath + ".tmp"
if err := os.MkdirAll(m.sshConfigDir, 0755); err != nil {
return fmt.Errorf("create SSH config directory %s: %w", m.sshConfigDir, err)
}
if err := writeFileWithTimeout(sshConfigPathTmp, []byte(sshConfig), 0644); err != nil {
return fmt.Errorf("write SSH config file %s: %w", sshConfigPath, err)
tmp, err := os.CreateTemp(m.sshConfigDir, m.sshConfigFile+".*.tmp")
if err != nil {
return fmt.Errorf("create temp SSH config: %w", err)
}
tmpPath := tmp.Name()
defer func() {
if err := os.Remove(tmpPath); err != nil && !os.IsNotExist(err) {
log.Debugf("remove temp SSH config %s: %v", tmpPath, err)
}
}()
if err := tmp.Close(); err != nil {
return fmt.Errorf("close temp SSH config %s: %w", tmpPath, err)
}
if err := os.Rename(sshConfigPathTmp, sshConfigPath); err != nil {
return fmt.Errorf("rename ssh config %s -> %s: %w", sshConfigPathTmp, sshConfigPath, err)
if err := writeFileWithTimeout(tmpPath, []byte(sshConfig), 0644); err != nil {
return fmt.Errorf("write SSH config file %s: %w", tmpPath, err)
}
if err := os.Chmod(tmpPath, 0644); err != nil {
return fmt.Errorf("chmod SSH config file %s: %w", tmpPath, err)
}
if err := os.Rename(tmpPath, sshConfigPath); err != nil {
return fmt.Errorf("rename SSH config %s -> %s: %w", tmpPath, sshConfigPath, err)
}
log.Infof("Created NetBird SSH client config: %s", sshConfigPath)

View File

@@ -23,6 +23,7 @@ import (
"github.com/netbirdio/netbird/client/proto"
nbssh "github.com/netbirdio/netbird/client/ssh"
"github.com/netbirdio/netbird/client/ssh/detection"
"github.com/netbirdio/netbird/util/netrelay"
"github.com/netbirdio/netbird/version"
)
@@ -352,7 +353,7 @@ func (p *SSHProxy) directTCPIPHandler(_ *ssh.Server, _ *cryptossh.ServerConn, ne
}
go cryptossh.DiscardRequests(clientReqs)
nbssh.BidirectionalCopyWithContext(log.NewEntry(log.StandardLogger()), sshCtx, clientChan, backendChan)
netrelay.Relay(sshCtx, clientChan, backendChan, netrelay.Options{Logger: log.NewEntry(log.StandardLogger())})
}
func (p *SSHProxy) sftpSubsystemHandler(s ssh.Session, jwtToken string) {
@@ -591,7 +592,7 @@ func (p *SSHProxy) handleForwardedChannel(sshCtx ssh.Context, sshConn *cryptossh
}
go cryptossh.DiscardRequests(clientReqs)
nbssh.BidirectionalCopyWithContext(log.NewEntry(log.StandardLogger()), sshCtx, clientChan, backendChan)
netrelay.Relay(sshCtx, clientChan, backendChan, netrelay.Options{Logger: log.NewEntry(log.StandardLogger())})
}
func (p *SSHProxy) dialBackend(ctx context.Context, addr, user, jwtToken string) (*cryptossh.Client, error) {

View File

@@ -17,7 +17,7 @@ import (
log "github.com/sirupsen/logrus"
cryptossh "golang.org/x/crypto/ssh"
nbssh "github.com/netbirdio/netbird/client/ssh"
"github.com/netbirdio/netbird/util/netrelay"
)
const privilegedPortThreshold = 1024
@@ -357,7 +357,7 @@ func (s *Server) handleRemoteForwardConnection(ctx ssh.Context, conn net.Conn, h
return
}
nbssh.BidirectionalCopyWithContext(logger, ctx, conn, channel)
netrelay.Relay(ctx, conn, channel, netrelay.Options{Logger: logger})
}
// openForwardChannel creates an SSH forwarded-tcpip channel

View File

@@ -8,9 +8,9 @@ import (
"fmt"
"io"
"net"
"strconv"
"net/netip"
"slices"
"strconv"
"strings"
"sync"
"time"
@@ -27,6 +27,7 @@ import (
"github.com/netbirdio/netbird/client/ssh/detection"
"github.com/netbirdio/netbird/shared/auth"
"github.com/netbirdio/netbird/shared/auth/jwt"
"github.com/netbirdio/netbird/util/netrelay"
"github.com/netbirdio/netbird/version"
)
@@ -53,6 +54,10 @@ const (
DefaultJWTMaxTokenAge = 10 * 60
)
// directTCPIPDialTimeout bounds how long relayDirectTCPIP waits on a dial to
// the forwarded destination before rejecting the SSH channel.
const directTCPIPDialTimeout = 30 * time.Second
var (
ErrPrivilegedUserDisabled = errors.New(msgPrivilegedUserDisabled)
ErrUserNotFound = errors.New("user not found")
@@ -933,5 +938,29 @@ func (s *Server) directTCPIPHandler(srv *ssh.Server, conn *cryptossh.ServerConn,
s.addConnectionPortForward(ctx.User(), ctx.RemoteAddr(), forwardAddr)
logger.Infof("local port forwarding: %s", hostPort)
ssh.DirectTCPIPHandler(srv, conn, newChan, ctx)
s.relayDirectTCPIP(ctx, newChan, payload.Host, int(payload.Port), logger)
}
// relayDirectTCPIP is a netrelay-based replacement for gliderlabs'
// DirectTCPIPHandler. The upstream handler closes both sides on the first
// EOF; netrelay.Relay propagates CloseWrite so each direction drains on its
// own terms.
func (s *Server) relayDirectTCPIP(ctx ssh.Context, newChan cryptossh.NewChannel, host string, port int, logger *log.Entry) {
dest := net.JoinHostPort(host, strconv.Itoa(port))
dialer := net.Dialer{Timeout: directTCPIPDialTimeout}
dconn, err := dialer.DialContext(ctx, "tcp", dest)
if err != nil {
_ = newChan.Reject(cryptossh.ConnectionFailed, err.Error())
return
}
ch, reqs, err := newChan.Accept()
if err != nil {
_ = dconn.Close()
return
}
go cryptossh.DiscardRequests(reqs)
netrelay.Relay(ctx, dconn, ch, netrelay.Options{Logger: logger})
}

View File

@@ -193,7 +193,15 @@ func getOverlappingNetworks(routes []*proto.Network) []*proto.Network {
}
func isDefaultRoute(routeRange string) bool {
return routeRange == "0.0.0.0/0" || routeRange == "::/0"
// routeRange is the merged display string from the daemon, e.g. "0.0.0.0/0",
// "::/0", or "0.0.0.0/0, ::/0" when a v4 exit node has a paired v6 entry.
for _, part := range strings.Split(routeRange, ",") {
switch strings.TrimSpace(part) {
case "0.0.0.0/0", "::/0":
return true
}
}
return false
}
func getExitNodeNetworks(routes []*proto.Network) []*proto.Network {

View File

@@ -133,13 +133,18 @@ type ManagementConfig struct {
// AuthConfig contains authentication/identity provider settings
type AuthConfig struct {
Issuer string `yaml:"issuer"`
LocalAuthDisabled bool `yaml:"localAuthDisabled"`
SignKeyRefreshEnabled bool `yaml:"signKeyRefreshEnabled"`
Storage AuthStorageConfig `yaml:"storage"`
DashboardRedirectURIs []string `yaml:"dashboardRedirectURIs"`
CLIRedirectURIs []string `yaml:"cliRedirectURIs"`
Owner *AuthOwnerConfig `yaml:"owner,omitempty"`
Issuer string `yaml:"issuer"`
LocalAuthDisabled bool `yaml:"localAuthDisabled"`
SignKeyRefreshEnabled bool `yaml:"signKeyRefreshEnabled"`
MfaSessionMaxLifetime string `yaml:"mfaSessionMaxLifetime"`
MfaSessionIdleTimeout string `yaml:"mfaSessionIdleTimeout"`
MfaSessionRememberMe bool `yaml:"mfaSessionRememberMe"`
SessionCookieEncryptionKey string `yaml:"sessionCookieEncryptionKey"`
Storage AuthStorageConfig `yaml:"storage"`
DashboardRedirectURIs []string `yaml:"dashboardRedirectURIs"`
CLIRedirectURIs []string `yaml:"cliRedirectURIs"`
Owner *AuthOwnerConfig `yaml:"owner,omitempty"`
DashboardPostLogoutRedirectURIs []string `yaml:"dashboardPostLogoutRedirectURIs"`
}
// AuthStorageConfig contains auth storage settings
@@ -581,10 +586,14 @@ func (c *CombinedConfig) buildEmbeddedIdPConfig(mgmt ManagementConfig) (*idp.Emb
}
cfg := &idp.EmbeddedIdPConfig{
Enabled: true,
Issuer: mgmt.Auth.Issuer,
LocalAuthDisabled: mgmt.Auth.LocalAuthDisabled,
SignKeyRefreshEnabled: mgmt.Auth.SignKeyRefreshEnabled,
Enabled: true,
Issuer: mgmt.Auth.Issuer,
LocalAuthDisabled: mgmt.Auth.LocalAuthDisabled,
SignKeyRefreshEnabled: mgmt.Auth.SignKeyRefreshEnabled,
MfaSessionMaxLifetime: mgmt.Auth.MfaSessionMaxLifetime,
MfaSessionIdleTimeout: mgmt.Auth.MfaSessionIdleTimeout,
MfaSessionRememberMe: mgmt.Auth.MfaSessionRememberMe,
SessionCookieEncryptionKey: mgmt.Auth.SessionCookieEncryptionKey,
Storage: idp.EmbeddedStorageConfig{
Type: authStorageType,
Config: idp.EmbeddedStorageTypeConfig{
@@ -592,8 +601,9 @@ func (c *CombinedConfig) buildEmbeddedIdPConfig(mgmt ManagementConfig) (*idp.Emb
DSN: authStorageDSN,
},
},
DashboardRedirectURIs: mgmt.Auth.DashboardRedirectURIs,
CLIRedirectURIs: mgmt.Auth.CLIRedirectURIs,
DashboardRedirectURIs: mgmt.Auth.DashboardRedirectURIs,
CLIRedirectURIs: mgmt.Auth.CLIRedirectURIs,
DashboardPostLogoutRedirectURIs: mgmt.Auth.DashboardPostLogoutRedirectURIs,
}
if mgmt.Auth.Owner != nil && mgmt.Auth.Owner.Email != "" {

View File

@@ -86,6 +86,13 @@ server:
issuer: "https://example.com/oauth2"
localAuthDisabled: false
signKeyRefreshEnabled: false
# MFA session settings (applies when TOTP is enabled for an account)
# mfaSessionMaxLifetime: "24h" # Max duration for an MFA session from creation
# mfaSessionIdleTimeout: "1h" # MFA session expires after this idle period
# mfaSessionRememberMe: false # Pre-check "remember me" on login so the MFA session persists across tabs/restarts
# Optional AES key for encrypting embedded IdP session cookies. Can also be set via NB_IDP_SESSION_COOKIE_ENCRYPTION_KEY.
# Must be 16/24/32 raw bytes or base64-encoded to one of those lengths (for example: openssl rand -hex 16).
# sessionCookieEncryptionKey: ""
# OAuth2 redirect URIs for dashboard
dashboardRedirectURIs:
- "https://app.example.com/nb-auth"
@@ -93,6 +100,9 @@ server:
# OAuth2 redirect URIs for CLI
cliRedirectURIs:
- "http://localhost:53000/"
# OAuth2 post-logout redirect URIs for dashboard (RP-initiated logout)
# dashboardPostLogoutRedirectURIs:
# - "https://app.example.com/"
# Optional initial admin user
# owner:
# email: "admin@example.com"

46
go.mod
View File

@@ -14,7 +14,7 @@ require (
github.com/onsi/gomega v1.27.6
github.com/rs/cors v1.8.0
github.com/sirupsen/logrus v1.9.4
github.com/spf13/cobra v1.10.1
github.com/spf13/cobra v1.10.2
github.com/spf13/pflag v1.0.9
github.com/vishvananda/netlink v1.3.1
golang.org/x/crypto v0.50.0
@@ -41,11 +41,11 @@ require (
github.com/cilium/ebpf v0.15.0
github.com/coder/websocket v1.8.14
github.com/coreos/go-iptables v0.7.0
github.com/coreos/go-oidc/v3 v3.14.1
github.com/coreos/go-oidc/v3 v3.18.0
github.com/creack/pty v1.1.24
github.com/crowdsecurity/crowdsec v1.7.7
github.com/crowdsecurity/go-cs-bouncer v0.0.21
github.com/dexidp/dex v0.0.0-00010101000000-000000000000
github.com/dexidp/dex v2.13.0+incompatible
github.com/dexidp/dex/api/v2 v2.4.0
github.com/ebitengine/purego v0.8.4
github.com/eko/gocache/lib/v4 v4.2.0
@@ -53,9 +53,9 @@ require (
github.com/eko/gocache/store/redis/v4 v4.2.2
github.com/fsnotify/fsnotify v1.9.0
github.com/gliderlabs/ssh v0.3.8
github.com/go-jose/go-jose/v4 v4.1.3
github.com/go-jose/go-jose/v4 v4.1.4
github.com/godbus/dbus/v5 v5.1.0
github.com/golang-jwt/jwt/v5 v5.3.0
github.com/golang-jwt/jwt/v5 v5.3.1
github.com/golang/mock v1.6.0
github.com/google/go-cmp v0.7.0
github.com/google/gopacket v1.1.19
@@ -72,7 +72,7 @@ require (
github.com/lrh3321/ipset-go v0.0.0-20250619021614-54a0a98ace81
github.com/mdlayher/socket v0.5.1
github.com/mdp/qrterminal/v3 v3.2.1
github.com/miekg/dns v1.1.59
github.com/miekg/dns v1.1.72
github.com/mitchellh/hashstructure/v2 v2.0.2
github.com/netbirdio/management-integrations/integrations v0.0.0-20260416123949-2355d972be42
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45
@@ -113,7 +113,7 @@ require (
go.opentelemetry.io/otel/exporters/prometheus v0.64.0
go.opentelemetry.io/otel/metric v1.43.0
go.opentelemetry.io/otel/sdk/metric v1.43.0
go.uber.org/mock v0.5.2
go.uber.org/mock v0.6.0
go.uber.org/zap v1.27.0
goauthentik.io/api/v3 v3.2023051.3
golang.org/x/exp v0.0.0-20250620022241-b7579e27df2b
@@ -141,7 +141,7 @@ require (
filippo.io/edwards25519 v1.1.1 // indirect
github.com/AppsFlyer/go-sundheit v0.6.0 // indirect
github.com/Azure/go-ansiterm v0.0.0-20230124172434-306776ec8161 // indirect
github.com/Azure/go-ntlmssp v0.0.0-20221128193559-754e69321358 // indirect
github.com/Azure/go-ntlmssp v0.1.0 // indirect
github.com/BurntSushi/toml v1.5.0 // indirect
github.com/Masterminds/goutils v1.1.1 // indirect
github.com/Masterminds/semver/v3 v3.3.0 // indirect
@@ -168,6 +168,7 @@ require (
github.com/aws/smithy-go v1.23.0 // indirect
github.com/beevik/etree v1.6.0 // indirect
github.com/beorn7/perks v1.0.1 // indirect
github.com/boombuler/barcode v1.0.1-0.20190219062509-6c824513bacc // indirect
github.com/caddyserver/zerossl v0.1.3 // indirect
github.com/cenkalti/backoff/v5 v5.0.3 // indirect
github.com/cespare/xxhash/v2 v2.3.0 // indirect
@@ -183,6 +184,7 @@ require (
github.com/docker/go-units v0.5.0 // indirect
github.com/felixge/httpsnoop v1.0.4 // indirect
github.com/fredbi/uri v1.1.1 // indirect
github.com/fxamacker/cbor/v2 v2.9.1 // indirect
github.com/fyne-io/gl-js v0.2.0 // indirect
github.com/fyne-io/glfw-js v0.3.0 // indirect
github.com/fyne-io/image v0.1.1 // indirect
@@ -190,7 +192,7 @@ require (
github.com/go-asn1-ber/asn1-ber v1.5.8-0.20250403174932-29230038a667 // indirect
github.com/go-gl/gl v0.0.0-20231021071112-07e5d0ea2e71 // indirect
github.com/go-gl/glfw/v3.3/glfw v0.0.0-20240506104042-037f3cc74f2a // indirect
github.com/go-ldap/ldap/v3 v3.4.12 // indirect
github.com/go-ldap/ldap/v3 v3.4.13 // indirect
github.com/go-logr/logr v1.4.3 // indirect
github.com/go-logr/stdr v1.2.2 // indirect
github.com/go-ole/go-ole v1.3.0 // indirect
@@ -206,11 +208,15 @@ require (
github.com/go-sql-driver/mysql v1.9.3 // indirect
github.com/go-text/render v0.2.0 // indirect
github.com/go-text/typesetting v0.2.1 // indirect
github.com/go-viper/mapstructure/v2 v2.5.0 // indirect
github.com/go-webauthn/webauthn v0.16.4 // indirect
github.com/go-webauthn/x v0.2.3 // indirect
github.com/goccy/go-yaml v1.18.0 // indirect
github.com/gogo/protobuf v1.3.2 // indirect
github.com/golang-jwt/jwt/v4 v4.5.2 // indirect
github.com/google/btree v1.1.2 // indirect
github.com/google/go-querystring v1.1.0 // indirect
github.com/google/go-tpm v0.9.8 // indirect
github.com/google/s2a-go v0.1.9 // indirect
github.com/googleapis/enterprise-certificate-proxy v0.3.14 // indirect
github.com/googleapis/gax-go/v2 v2.21.0 // indirect
@@ -218,7 +224,13 @@ require (
github.com/hack-pad/go-indexeddb v0.3.2 // indirect
github.com/hack-pad/safejs v0.1.0 // indirect
github.com/hashicorp/errwrap v1.1.0 // indirect
github.com/hashicorp/go-cleanhttp v0.5.2 // indirect
github.com/hashicorp/go-retryablehttp v0.7.8 // indirect
github.com/hashicorp/go-secure-stdlib/parseutil v0.2.0 // indirect
github.com/hashicorp/go-secure-stdlib/strutil v0.1.2 // indirect
github.com/hashicorp/go-sockaddr v1.0.7 // indirect
github.com/hashicorp/go-uuid v1.0.3 // indirect
github.com/hashicorp/hcl v1.0.1-vault-7 // indirect
github.com/huandu/xstrings v1.5.0 // indirect
github.com/huin/goupnp v1.2.0 // indirect
github.com/inconshreveable/mousetrap v1.1.0 // indirect
@@ -238,13 +250,13 @@ require (
github.com/klauspost/cpuid/v2 v2.2.10 // indirect
github.com/koron/go-ssdp v0.0.4 // indirect
github.com/kr/fs v0.1.0 // indirect
github.com/lib/pq v1.10.9 // indirect
github.com/lib/pq v1.12.3 // indirect
github.com/libdns/libdns v0.2.2 // indirect
github.com/lufia/plan9stats v0.0.0-20240513124658-fba389f38bae // indirect
github.com/magiconair/properties v1.8.10 // indirect
github.com/mailru/easyjson v0.9.0 // indirect
github.com/mattermost/xml-roundtrip-validator v0.1.0 // indirect
github.com/mattn/go-sqlite3 v1.14.32 // indirect
github.com/mattn/go-sqlite3 v1.14.42 // indirect
github.com/mdelapenya/tlscert v0.2.0 // indirect
github.com/mdlayher/genetlink v1.3.2 // indirect
github.com/mdlayher/netlink v1.7.3-0.20250113171957-fbb4dce95f42 // indirect
@@ -265,8 +277,10 @@ require (
github.com/nxadm/tail v1.4.11 // indirect
github.com/oklog/ulid v1.3.1 // indirect
github.com/onsi/ginkgo/v2 v2.9.5 // indirect
github.com/openbao/openbao/api/v2 v2.5.1 // indirect
github.com/opencontainers/go-digest v1.0.0 // indirect
github.com/opencontainers/image-spec v1.1.1 // indirect
github.com/philhofer/fwd v1.2.0 // indirect
github.com/pion/dtls/v2 v2.2.10 // indirect
github.com/pion/dtls/v3 v3.0.9 // indirect
github.com/pion/mdns/v2 v2.0.7 // indirect
@@ -275,11 +289,13 @@ require (
github.com/pkg/errors v0.9.1 // indirect
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect
github.com/power-devops/perfstat v0.0.0-20240221224432-82ca36839d55 // indirect
github.com/pquerna/otp v1.5.0 // indirect
github.com/prometheus/client_model v0.6.2 // indirect
github.com/prometheus/common v0.67.5 // indirect
github.com/prometheus/otlptranslator v1.0.0 // indirect
github.com/prometheus/procfs v0.19.2 // indirect
github.com/russellhaering/goxmldsig v1.6.0 // indirect
github.com/ryanuber/go-glob v1.0.0 // indirect
github.com/rymdport/portal v0.4.2 // indirect
github.com/shirou/gopsutil/v4 v4.25.8 // indirect
github.com/shoenig/go-m1cpu v0.2.1 // indirect
@@ -288,11 +304,13 @@ require (
github.com/srwiley/oksvg v0.0.0-20221011165216-be6e8873101c // indirect
github.com/srwiley/rasterx v0.0.0-20220730225603-2ab79fcdd4ef // indirect
github.com/stretchr/objx v0.5.2 // indirect
github.com/tinylib/msgp v1.6.3 // indirect
github.com/tklauser/go-sysconf v0.3.15 // indirect
github.com/tklauser/numcpus v0.10.0 // indirect
github.com/vishvananda/netns v0.0.5 // indirect
github.com/vmihailenco/tagparser/v2 v2.0.0 // indirect
github.com/wlynxg/anet v0.0.5 // indirect
github.com/x448/float16 v0.8.4 // indirect
github.com/yuin/goldmark v1.7.8 // indirect
github.com/zeebo/blake3 v0.2.3 // indirect
go.mongodb.org/mongo-driver v1.17.9 // indirect
@@ -319,10 +337,12 @@ replace github.com/getlantern/systray => github.com/netbirdio/systray v0.0.0-202
replace golang.zx2c4.com/wireguard => github.com/netbirdio/wireguard-go v0.0.0-20260107100953-33b7c9d03db0
replace github.com/cloudflare/circl => github.com/cunicu/circl v0.0.0-20230801113412-fec58fc7b5f6
replace github.com/cloudflare/circl => codeberg.org/cunicu/circl v0.0.0-20230801113412-fec58fc7b5f6
replace github.com/pion/ice/v4 => github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51
replace github.com/dexidp/dex => github.com/netbirdio/dex v0.244.0
replace github.com/dexidp/dex => github.com/netbirdio/dex v0.244.1-0.20260512110716-8d70ad8647c1
replace github.com/dexidp/dex/api/v2 => github.com/netbirdio/dex/api/v2 v2.0.0-20260512110716-8d70ad8647c1
replace github.com/mailru/easyjson => github.com/netbirdio/easyjson v0.9.0

101
go.sum
View File

@@ -5,6 +5,8 @@ cloud.google.com/go/auth/oauth2adapt v0.2.8/go.mod h1:XQ9y31RkqZCcwJWNSx2Xvric3R
cloud.google.com/go/compute/metadata v0.2.0/go.mod h1:zFmK7XCadkQkj6TtorcaGlCW1hT1fIilQDwofLpJ20k=
cloud.google.com/go/compute/metadata v0.9.0 h1:pDUj4QMoPejqq20dK0Pg2N4yG9zIkYGdBtwLoEkH9Zs=
cloud.google.com/go/compute/metadata v0.9.0/go.mod h1:E0bWwX5wTnLPedCKqk3pJmVgCBSM6qQI1yTBdEb3C10=
codeberg.org/cunicu/circl v0.0.0-20230801113412-fec58fc7b5f6 h1:b8xUw3004wk+3ipBhu0VU4RtUJsegMIiqjxSK4++lzA=
codeberg.org/cunicu/circl v0.0.0-20230801113412-fec58fc7b5f6/go.mod h1:+CauBF6R70Jqcyl8N2hC8pAXYbWkGIezuSbuGLtRhnw=
cunicu.li/go-rosenpass v0.4.0 h1:LtPtBgFWY/9emfgC4glKLEqS0MJTylzV6+ChRhiZERw=
cunicu.li/go-rosenpass v0.4.0/go.mod h1:MPbjH9nxV4l3vEagKVdFNwHOketqgS5/To1VYJplf/M=
dario.cat/mergo v1.0.1 h1:Ra4+bf83h2ztPIQYNP99R6m+Y7KfnARDfID+a+vLl4s=
@@ -23,8 +25,8 @@ github.com/AppsFlyer/go-sundheit v0.6.0 h1:d2hBvCjBSb2lUsEWGfPigr4MCOt04sxB+Rppl
github.com/AppsFlyer/go-sundheit v0.6.0/go.mod h1:LDdBHD6tQBtmHsdW+i1GwdTt6Wqc0qazf5ZEJVTbTME=
github.com/Azure/go-ansiterm v0.0.0-20230124172434-306776ec8161 h1:L/gRVlceqvL25UVaW/CKtUDjefjrs0SPonmDGUVOYP0=
github.com/Azure/go-ansiterm v0.0.0-20230124172434-306776ec8161/go.mod h1:xomTg63KZ2rFqZQzSB4Vz2SUXa1BpHTVz9L5PTmPC4E=
github.com/Azure/go-ntlmssp v0.0.0-20221128193559-754e69321358 h1:mFRzDkZVAjdal+s7s0MwaRv9igoPqLRdzOLzw/8Xvq8=
github.com/Azure/go-ntlmssp v0.0.0-20221128193559-754e69321358/go.mod h1:chxPXzSsl7ZWRAuOIE23GDNzjWuZquvFlgA8xmpunjU=
github.com/Azure/go-ntlmssp v0.1.0 h1:DjFo6YtWzNqNvQdrwEyr/e4nhU3vRiwenz5QX7sFz+A=
github.com/Azure/go-ntlmssp v0.1.0/go.mod h1:NYqdhxd/8aAct/s4qSYZEerdPuH1liG2/X9DiVTbhpk=
github.com/BurntSushi/toml v1.5.0 h1:W5quZX/G/csjUnuI8SUYlsHs9M38FC7znL0lIO+DvMg=
github.com/BurntSushi/toml v1.5.0/go.mod h1:ukJfTF/6rtPPRCnwkur4qwRxa8vTRFBF0uk2lLoLwho=
github.com/Masterminds/goutils v1.1.1 h1:5nUrii3FMTL5diU80unEVvNevw1nH4+ZV4DSLVJLSYI=
@@ -91,6 +93,8 @@ github.com/beevik/etree v1.6.0/go.mod h1:bh4zJxiIr62SOf9pRzN7UUYaEDa9HEKafK25+sL
github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM=
github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw=
github.com/bmatcuk/doublestar v1.1.1/go.mod h1:UD6OnuiIn0yFxxA2le/rnRU1G4RaI4UvFv1sNto9p6w=
github.com/boombuler/barcode v1.0.1-0.20190219062509-6c824513bacc h1:biVzkmvwrH8WK8raXaxBx6fRVTlJILwEwQGL1I/ByEI=
github.com/boombuler/barcode v1.0.1-0.20190219062509-6c824513bacc/go.mod h1:paBWMcWSl3LHKBqUq+rly7CNSldXjb2rDl3JlRe0mD8=
github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs=
github.com/bsm/ginkgo/v2 v2.12.0/go.mod h1:SwYbGRRDovPVboqFv0tPTcG1sN61LM1Z4ARdbAV9g4c=
github.com/bsm/gomega v1.27.10 h1:yeMWxP2pV2fG3FgAODIY8EiRE3dy0aeFYt4l7wh6yKA=
@@ -117,8 +121,8 @@ github.com/containerd/platforms v0.2.1 h1:zvwtM3rz2YHPQsF2CHYM8+KtB5dvhISiXh5ZpS
github.com/containerd/platforms v0.2.1/go.mod h1:XHCb+2/hzowdiut9rkudds9bE5yJ7npe7dG/wG+uFPw=
github.com/coreos/go-iptables v0.7.0 h1:XWM3V+MPRr5/q51NuWSgU0fqMad64Zyxs8ZUoMsamr8=
github.com/coreos/go-iptables v0.7.0/go.mod h1:Qe8Bv2Xik5FyTXwgIbLAnv2sWSBmvWdFETJConOQ//Q=
github.com/coreos/go-oidc/v3 v3.14.1 h1:9ePWwfdwC4QKRlCXsJGou56adA/owXczOzwKdOumLqk=
github.com/coreos/go-oidc/v3 v3.14.1/go.mod h1:HaZ3szPaZ0e4r6ebqvsLWlk2Tn+aejfmrfah6hnSYEU=
github.com/coreos/go-oidc/v3 v3.18.0 h1:V9orjXynvu5wiC9SemFTWnG4F45v403aIcjWo0d41+A=
github.com/coreos/go-oidc/v3 v3.18.0/go.mod h1:DYCf24+ncYi+XkIH97GY1+dqoRlbaSI26KVTCI9SrY4=
github.com/cpuguy83/dockercfg v0.3.2 h1:DlJTyZGBDlXqUZ2Dk2Q3xHs/FtnooJJVaad2S9GKorA=
github.com/cpuguy83/dockercfg v0.3.2/go.mod h1:sugsbF4//dDlL/i+S+rtpIWp+5h0BHJHfjj5/jFyUJc=
github.com/cpuguy83/go-md2man/v2 v2.0.6/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g=
@@ -130,14 +134,10 @@ github.com/crowdsecurity/go-cs-bouncer v0.0.21 h1:arPz0VtdVSaz+auOSfHythzkZVLyy1
github.com/crowdsecurity/go-cs-bouncer v0.0.21/go.mod h1:4JiH0XXA4KKnnWThItUpe5+heJHWzsLOSA2IWJqUDBA=
github.com/crowdsecurity/go-cs-lib v0.0.25 h1:Ov6VPW9yV+OPsbAIQk1iTkEWhwkpaG0v3lrBzeqjzj4=
github.com/crowdsecurity/go-cs-lib v0.0.25/go.mod h1:X0GMJY2CxdA1S09SpuqIKaWQsvRGxXmecUp9cP599dE=
github.com/cunicu/circl v0.0.0-20230801113412-fec58fc7b5f6 h1:/DS5cDX3FJdl+XaN2D7XAwFpuanTxnp52DBLZAaJKx0=
github.com/cunicu/circl v0.0.0-20230801113412-fec58fc7b5f6/go.mod h1:+CauBF6R70Jqcyl8N2hC8pAXYbWkGIezuSbuGLtRhnw=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM=
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/dexidp/dex/api/v2 v2.4.0 h1:gNba7n6BKVp8X4Jp24cxYn5rIIGhM6kDOXcZoL6tr9A=
github.com/dexidp/dex/api/v2 v2.4.0/go.mod h1:/p550ADvFFh7K95VmhUD+jgm15VdaNnab9td8DHOpyI=
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78=
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=
github.com/distribution/reference v0.6.0 h1:0IXCQ5g4/QMHHkarYzh5l+u8T3t73zM5QvfrDyIgxBk=
@@ -156,6 +156,8 @@ github.com/eko/gocache/store/go_cache/v4 v4.2.2 h1:tAI9nl6TLoJyKG1ujF0CS0n/IgTEM
github.com/eko/gocache/store/go_cache/v4 v4.2.2/go.mod h1:T9zkHokzr8K9EiC7RfMbDg6HSwaV6rv3UdcNu13SGcA=
github.com/eko/gocache/store/redis/v4 v4.2.2 h1:Thw31fzGuH3WzJywsdbMivOmP550D6JS7GDHhvCJPA0=
github.com/eko/gocache/store/redis/v4 v4.2.2/go.mod h1:LaTxLKx9TG/YUEybQvPMij++D7PBTIJ4+pzvk0ykz0w=
github.com/fatih/color v1.18.0 h1:S8gINlzdQ840/4pfAwic/ZE0djQEH3wM94VfqLTZcOM=
github.com/fatih/color v1.18.0/go.mod h1:4FelSpRwEGDpQ12mAdzqdOukCy4u8WUtOY6lkT/6HfU=
github.com/felixge/fgprof v0.9.3 h1:VvyZxILNuCiUCSXtPtYmmtGvb65nqXh2QFWc0Wpf2/g=
github.com/felixge/fgprof v0.9.3/go.mod h1:RdbpDgzqYVh/T9fPELJyV7EYJuHB55UTEULNun8eiPw=
github.com/felixge/httpsnoop v1.0.4 h1:NFTV2Zj1bL4mc9sqWACXbQFVBBg2W3GPvqp8/ESS2Wg=
@@ -171,6 +173,8 @@ github.com/fsnotify/fsnotify v1.4.9/go.mod h1:znqG4EE+3YCdAaPaxE2ZRY/06pZUdp0tY4
github.com/fsnotify/fsnotify v1.6.0/go.mod h1:sl3t1tCWJFWoRz9R8WJCbQihKKwmorjAbSClcnxKAGw=
github.com/fsnotify/fsnotify v1.9.0 h1:2Ml+OJNzbYCTzsxtv8vKSFD9PbJjmhYF14k/jKC7S9k=
github.com/fsnotify/fsnotify v1.9.0/go.mod h1:8jBTzvmWwFyi3Pb8djgCCO5IBqzKJ/Jwo8TRcHyHii0=
github.com/fxamacker/cbor/v2 v2.9.1 h1:2rWm8B193Ll4VdjsJY28jxs70IdDsHRWgQYAI80+rMQ=
github.com/fxamacker/cbor/v2 v2.9.1/go.mod h1:vM4b+DJCtHn+zz7h3FFp/hDAI9WNWCsZj23V5ytsSxQ=
github.com/fyne-io/gl-js v0.2.0 h1:+EXMLVEa18EfkXBVKhifYB6OGs3HwKO3lUElA0LlAjs=
github.com/fyne-io/gl-js v0.2.0/go.mod h1:ZcepK8vmOYLu96JoxbCKJy2ybr+g1pTnaBDdl7c3ajI=
github.com/fyne-io/glfw-js v0.3.0 h1:d8k2+Y7l+zy2pc7wlGRyPfTgZoqDf3AI4G+2zOWhWUk=
@@ -189,10 +193,10 @@ github.com/go-gl/gl v0.0.0-20231021071112-07e5d0ea2e71 h1:5BVwOaUSBTlVZowGO6VZGw
github.com/go-gl/gl v0.0.0-20231021071112-07e5d0ea2e71/go.mod h1:9YTyiznxEY1fVinfM7RvRcjRHbw2xLBJ3AAGIT0I4Nw=
github.com/go-gl/glfw/v3.3/glfw v0.0.0-20240506104042-037f3cc74f2a h1:vxnBhFDDT+xzxf1jTJKMKZw3H0swfWk9RpWbBbDK5+0=
github.com/go-gl/glfw/v3.3/glfw v0.0.0-20240506104042-037f3cc74f2a/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8=
github.com/go-jose/go-jose/v4 v4.1.3 h1:CVLmWDhDVRa6Mi/IgCgaopNosCaHz7zrMeF9MlZRkrs=
github.com/go-jose/go-jose/v4 v4.1.3/go.mod h1:x4oUasVrzR7071A4TnHLGSPpNOm2a21K9Kf04k1rs08=
github.com/go-ldap/ldap/v3 v3.4.12 h1:1b81mv7MagXZ7+1r7cLTWmyuTqVqdwbtJSjC0DAp9s4=
github.com/go-ldap/ldap/v3 v3.4.12/go.mod h1:+SPAGcTtOfmGsCb3h1RFiq4xpp4N636G75OEace8lNo=
github.com/go-jose/go-jose/v4 v4.1.4 h1:moDMcTHmvE6Groj34emNPLs/qtYXRVcd6S7NHbHz3kA=
github.com/go-jose/go-jose/v4 v4.1.4/go.mod h1:x4oUasVrzR7071A4TnHLGSPpNOm2a21K9Kf04k1rs08=
github.com/go-ldap/ldap/v3 v3.4.13 h1:+x1nG9h+MZN7h/lUi5Q3UZ0fJ1GyDQYbPvbuH38baDQ=
github.com/go-ldap/ldap/v3 v3.4.13/go.mod h1:LxsGZV6vbaK0sIvYfsv47rfh4ca0JXokCoKjZxsszv0=
github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A=
github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI=
github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY=
@@ -229,12 +233,20 @@ github.com/go-sql-driver/mysql v1.9.3/go.mod h1:qn46aNg1333BRMNU69Lq93t8du/dwxI6
github.com/go-task/slim-sprig v0.0.0-20210107165309-348f09dbbbc0/go.mod h1:fyg7847qk6SyHyPtNmDHnmrv/HOrqktSC+C9fM+CJOE=
github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 h1:tfuBGBXKqDEevZMzYi5KSi8KkcZtzBcTgAUUtapy0OI=
github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572/go.mod h1:9Pwr4B2jHnOSGXyyzV8ROjYa2ojvAY6HCGYYfMoC3Ls=
github.com/go-test/deep v1.1.1 h1:0r/53hagsehfO4bzD2Pgr/+RgHqhmf+k1Bpse2cTu1U=
github.com/go-test/deep v1.1.1/go.mod h1:5C2ZWiW0ErCdrYzpqxLbTX7MG14M9iiw8DgHncVwcsE=
github.com/go-text/render v0.2.0 h1:LBYoTmp5jYiJ4NPqDc2pz17MLmA3wHw1dZSVGcOdeAc=
github.com/go-text/render v0.2.0/go.mod h1:CkiqfukRGKJA5vZZISkjSYrcdtgKQWRa2HIzvwNN5SU=
github.com/go-text/typesetting v0.2.1 h1:x0jMOGyO3d1qFAPI0j4GSsh7M0Q3Ypjzr4+CEVg82V8=
github.com/go-text/typesetting v0.2.1/go.mod h1:mTOxEwasOFpAMBjEQDhdWRckoLLeI/+qrQeBCTGEt6M=
github.com/go-text/typesetting-utils v0.0.0-20241103174707-87a29e9e6066 h1:qCuYC+94v2xrb1PoS4NIDe7DGYtLnU2wWiQe9a1B1c0=
github.com/go-text/typesetting-utils v0.0.0-20241103174707-87a29e9e6066/go.mod h1:DDxDdQEnB70R8owOx3LVpEFvpMK9eeH1o2r0yZhFI9o=
github.com/go-viper/mapstructure/v2 v2.5.0 h1:vM5IJoUAy3d7zRSVtIwQgBj7BiWtMPfmPEgAXnvj1Ro=
github.com/go-viper/mapstructure/v2 v2.5.0/go.mod h1:oJDH3BJKyqBA2TXFhDsKDGDTlndYOZ6rGS0BRZIxGhM=
github.com/go-webauthn/webauthn v0.16.4 h1:R9jqR/cYZa7hRquFF7Za/8qoH/K/TIs1/Q/4CyGN+1Q=
github.com/go-webauthn/webauthn v0.16.4/go.mod h1:SU2ljAgToTV/YLPI0C05QS4qn+e04WpB5g1RMfcZfS4=
github.com/go-webauthn/x v0.2.3 h1:8oArS+Rc1SWFLXhE17KZNx258Z4kUSyaDgsSncCO5RA=
github.com/go-webauthn/x v0.2.3/go.mod h1:tM04GF3V6VYq79AZMl7vbj4q6pz9r7L2criWRzbWhPk=
github.com/goccy/go-yaml v1.18.0 h1:8W7wMFS12Pcas7KU+VVkaiCng+kG8QiFeFwzFb+rwuw=
github.com/goccy/go-yaml v1.18.0/go.mod h1:XBurs7gK8ATbW4ZPGKgcbrY1Br56PdM69F7LkFRi1kA=
github.com/godbus/dbus/v5 v5.1.0 h1:4KLkAxT3aOY8Li4FRJe/KvhoNFFxo0m6fNuFUO8QJUk=
@@ -243,8 +255,8 @@ github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q=
github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q=
github.com/golang-jwt/jwt/v4 v4.5.2 h1:YtQM7lnr8iZ+j5q71MGKkNw9Mn7AjHM68uc9g5fXeUI=
github.com/golang-jwt/jwt/v4 v4.5.2/go.mod h1:m21LjoU+eqJr34lmDMbreY2eSTRJ1cv77w39/MY0Ch0=
github.com/golang-jwt/jwt/v5 v5.3.0 h1:pv4AsKCKKZuqlgs5sUmn4x8UlGa0kEVt/puTpKx9vvo=
github.com/golang-jwt/jwt/v5 v5.3.0/go.mod h1:fxCRLWMO43lRc8nhHWY6LGqRcf+1gQWArsqaEUEa5bE=
github.com/golang-jwt/jwt/v5 v5.3.1 h1:kYf81DTWFe7t+1VvL7eS+jKFVWaUnK9cB1qbwn63YCY=
github.com/golang-jwt/jwt/v5 v5.3.1/go.mod h1:fxCRLWMO43lRc8nhHWY6LGqRcf+1gQWArsqaEUEa5bE=
github.com/golang/mock v1.6.0 h1:ErTB+efbowRARo13NNdxyJji2egdxLGQhRaY+DUumQc=
github.com/golang/mock v1.6.0/go.mod h1:p6yTPP+5HYm5mzsMV8JkE6ZKdX+/wYM6Hr+LicevLPs=
github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
@@ -276,6 +288,10 @@ github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
github.com/google/go-querystring v1.1.0 h1:AnCroh3fv4ZBgVIf1Iwtovgjaw/GiKJo8M8yD/fhyJ8=
github.com/google/go-querystring v1.1.0/go.mod h1:Kcdr2DB4koayq7X8pmAG4sNG59So17icRSOU623lUBU=
github.com/google/go-tpm v0.9.8 h1:slArAR9Ft+1ybZu0lBwpSmpwhRXaa85hWtMinMyRAWo=
github.com/google/go-tpm v0.9.8/go.mod h1:h9jEsEECg7gtLis0upRBQU+GhYVH6jMjrFxI8u6bVUY=
github.com/google/go-tpm-tools v0.3.13-0.20230620182252-4639ecce2aba h1:qJEJcuLzH5KDR0gKc0zcktin6KSAwL7+jWKBYceddTc=
github.com/google/go-tpm-tools v0.3.13-0.20230620182252-4639ecce2aba/go.mod h1:EFYHy8/1y2KfgTAsx7Luu7NGhoxtuVHnNo8jE7FikKc=
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
github.com/google/gopacket v1.1.19 h1:ves8RnFZPGiFnTS0uPQStjwru6uO6h+nlr9j6fL7kF8=
github.com/google/gopacket v1.1.19/go.mod h1:iJ8V8n6KS+z2U1A8pUwu8bW5SyEMkXJB8Yo/Vo+TKTo=
@@ -308,15 +324,29 @@ github.com/hack-pad/safejs v0.1.0/go.mod h1:HdS+bKF1NrE72VoXZeWzxFOVQVUSqZJAG0xN
github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4=
github.com/hashicorp/errwrap v1.1.0 h1:OxrOeh75EUXMY8TBjag2fzXGZ40LB6IKw45YeGUDY2I=
github.com/hashicorp/errwrap v1.1.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4=
github.com/hashicorp/go-cleanhttp v0.5.2 h1:035FKYIWjmULyFRBKPs8TBQoi0x6d9G4xc9neXJWAZQ=
github.com/hashicorp/go-cleanhttp v0.5.2/go.mod h1:kO/YDlP8L1346E6Sodw+PrpBSV4/SoxCXGY6BqNFT48=
github.com/hashicorp/go-hclog v1.6.3 h1:Qr2kF+eVWjTiYmU7Y31tYlP1h0q/X3Nl3tPGdaB11/k=
github.com/hashicorp/go-hclog v1.6.3/go.mod h1:W4Qnvbt70Wk/zYJryRzDRU/4r0kIg0PVHBcfoyhpF5M=
github.com/hashicorp/go-multierror v1.1.1 h1:H5DkEtf6CXdFp0N0Em5UCwQpXMWke8IA0+lD48awMYo=
github.com/hashicorp/go-multierror v1.1.1/go.mod h1:iw975J/qwKPdAO1clOe2L8331t/9/fmwbPZ6JB6eMoM=
github.com/hashicorp/go-retryablehttp v0.7.8 h1:ylXZWnqa7Lhqpk0L1P1LzDtGcCR0rPVUrx/c8Unxc48=
github.com/hashicorp/go-retryablehttp v0.7.8/go.mod h1:rjiScheydd+CxvumBsIrFKlx3iS0jrZ7LvzFGFmuKbw=
github.com/hashicorp/go-secure-stdlib/base62 v0.1.2 h1:ET4pqyjiGmY09R5y+rSd70J2w45CtbWDNvGqWp/R3Ng=
github.com/hashicorp/go-secure-stdlib/base62 v0.1.2/go.mod h1:EdWO6czbmthiwZ3/PUsDV+UD1D5IRU4ActiaWGwt0Yw=
github.com/hashicorp/go-secure-stdlib/parseutil v0.2.0 h1:U+kC2dOhMFQctRfhK0gRctKAPTloZdMU5ZJxaesJ/VM=
github.com/hashicorp/go-secure-stdlib/parseutil v0.2.0/go.mod h1:Ll013mhdmsVDuoIXVfBtvgGJsXDYkTw1kooNcoCXuE0=
github.com/hashicorp/go-secure-stdlib/strutil v0.1.2 h1:kes8mmyCpxJsI7FTwtzRqEy9CdjCtrXrXGuOpxEA7Ts=
github.com/hashicorp/go-secure-stdlib/strutil v0.1.2/go.mod h1:Gou2R9+il93BqX25LAKCLuM+y9U2T4hlwvT1yprcna4=
github.com/hashicorp/go-sockaddr v1.0.7 h1:G+pTkSO01HpR5qCxg7lxfsFEZaG+C0VssTy/9dbT+Fw=
github.com/hashicorp/go-sockaddr v1.0.7/go.mod h1:FZQbEYa1pxkQ7WLpyXJ6cbjpT8q0YgQaK/JakXqGyWw=
github.com/hashicorp/go-uuid v1.0.2/go.mod h1:6SBZvOh/SIDV7/2o3Jml5SYk/TvGqwFJ/bN7x4byOro=
github.com/hashicorp/go-uuid v1.0.3 h1:2gKiV6YVmrJ1i2CKKa9obLvRieoRGviZFL26PcT/Co8=
github.com/hashicorp/go-uuid v1.0.3/go.mod h1:6SBZvOh/SIDV7/2o3Jml5SYk/TvGqwFJ/bN7x4byOro=
github.com/hashicorp/go-version v1.7.0 h1:5tqGy27NaOTB8yJKUZELlFAS/LTKJkrmONwQKeRZfjY=
github.com/hashicorp/go-version v1.7.0/go.mod h1:fltr4n8CU8Ke44wwGCBoEymUuxUHl09ZGVZPK5anwXA=
github.com/hashicorp/hcl v1.0.1-vault-7 h1:ag5OxFVy3QYTFTJODRzTKVZ6xvdfLLCA1cy/Y6xGI0I=
github.com/hashicorp/hcl v1.0.1-vault-7/go.mod h1:XYhtn6ijBSAj6n4YqAaf7RBPS4I06AItNorpy+MoQNM=
github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU=
github.com/huandu/xstrings v1.5.0 h1:2ag3IFq9ZDANvthTwTiqSSZLjDc+BedvHPAp5tJy2TI=
github.com/huandu/xstrings v1.5.0/go.mod h1:y5/lhBue+AyNmUVz9RLU9xbLR0o4KIIExikq4ovT0aE=
@@ -387,8 +417,8 @@ github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc=
github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw=
github.com/leodido/go-urn v1.1.0/go.mod h1:+cyI34gQWZcE1eQU7NVgKkkzdXDQHr1dBMtdAPozLkw=
github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw=
github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o=
github.com/lib/pq v1.12.3 h1:tTWxr2YLKwIvK90ZXEw8GP7UFHtcbTtty8zsI+YjrfQ=
github.com/lib/pq v1.12.3/go.mod h1:/p+8NSbOcwzAEI7wiMXFlgydTwcgTr3OSKMsD2BitpA=
github.com/libdns/libdns v0.2.2 h1:O6ws7bAfRPaBsgAYt8MDe2HcNBGC29hkZ9MX2eUSX3s=
github.com/libdns/libdns v0.2.2/go.mod h1:4Bj9+5CQiNMVGf87wjX4CY3HQJypUHRuLvlsfsZqLWQ=
github.com/libdns/route53 v1.5.0 h1:2SKdpPFl/qgWsXQvsLNJJAoX7rSxlk7zgoL4jnWdXVA=
@@ -406,9 +436,13 @@ github.com/magiconair/properties v1.8.10 h1:s31yESBquKXCV9a/ScB3ESkOjUYYv+X0rg8S
github.com/magiconair/properties v1.8.10/go.mod h1:Dhd985XPs7jluiymwWYZ0G4Z61jb3vdS329zhj2hYo0=
github.com/mattermost/xml-roundtrip-validator v0.1.0 h1:RXbVD2UAl7A7nOTR4u7E3ILa4IbtvKBHw64LDsmu9hU=
github.com/mattermost/xml-roundtrip-validator v0.1.0/go.mod h1:qccnGMcpgwcNaBnxqpJpWWUiPNr5H3O8eDgGV9gT5To=
github.com/mattn/go-colorable v0.1.14 h1:9A9LHSqF/7dyVVX6g0U9cwm9pG3kP9gSzcuIPHPsaIE=
github.com/mattn/go-colorable v0.1.14/go.mod h1:6LmQG8QLFO4G5z1gPvYEzlUgJ2wF+stgPZH1UqBm1s8=
github.com/mattn/go-isatty v0.0.9/go.mod h1:YNRxwqDuOph6SZLI9vUUz6OYw3QyUt7WiY2yME+cCiQ=
github.com/mattn/go-sqlite3 v1.14.32 h1:JD12Ag3oLy1zQA+BNn74xRgaBbdhbNIDYvQUEuuErjs=
github.com/mattn/go-sqlite3 v1.14.32/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y=
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mattn/go-sqlite3 v1.14.42 h1:MigqEP4ZmHw3aIdIT7T+9TLa90Z6smwcthx+Azv4Cgo=
github.com/mattn/go-sqlite3 v1.14.42/go.mod h1:pjEuOr8IwzLJP2MfGeTb0A35jauH+C2kbHKBr7yXKVQ=
github.com/mdelapenya/tlscert v0.2.0 h1:7H81W6Z/4weDvZBNOfQte5GpIMo0lGYEeWbkGp5LJHI=
github.com/mdelapenya/tlscert v0.2.0/go.mod h1:O4njj3ELLnJjGdkN7M/vIVCpZ+Cf0L6muqOG4tLSl8o=
github.com/mdlayher/genetlink v1.3.2 h1:KdrNKe+CTu+IbZnm/GVUMXSqBBLqcGpRDa0xkQy56gw=
@@ -421,8 +455,8 @@ github.com/mdp/qrterminal/v3 v3.2.1 h1:6+yQjiiOsSuXT5n9/m60E54vdgFsw0zhADHhHLrFe
github.com/mdp/qrterminal/v3 v3.2.1/go.mod h1:jOTmXvnBsMy5xqLniO0R++Jmjs2sTm9dFSuQ5kpz/SU=
github.com/mholt/acmez/v2 v2.0.1 h1:3/3N0u1pLjMK4sNEAFSI+bcvzbPhRpY383sy1kLHJ6k=
github.com/mholt/acmez/v2 v2.0.1/go.mod h1:fX4c9r5jYwMyMsC+7tkYRxHibkOTgta5DIFGoe67e1U=
github.com/miekg/dns v1.1.59 h1:C9EXc/UToRwKLhK5wKU/I4QVsBUc8kE6MkHBkeypWZs=
github.com/miekg/dns v1.1.59/go.mod h1:nZpewl5p6IvctfgrckopVx2OlSEHPRO/U4SYkRklrEk=
github.com/miekg/dns v1.1.72 h1:vhmr+TF2A3tuoGNkLDFK9zi36F2LS+hKTRW0Uf8kbzI=
github.com/miekg/dns v1.1.72/go.mod h1:+EuEPhdHOsfk6Wk5TT2CzssZdqkmFhf8r+aVyDEToIs=
github.com/mikioh/ipaddr v0.0.0-20190404000644-d465c8ab6721 h1:RlZweED6sbSArvlE924+mUcZuXKLBHA35U7LN621Bws=
github.com/mikioh/ipaddr v0.0.0-20190404000644-d465c8ab6721/go.mod h1:Ickgr2WtCLZ2MDGd4Gr0geeCH5HybhRJbonOgQpvSxc=
github.com/mitchellh/copystructure v1.2.0 h1:vpKXTN4ewci03Vljg/q9QvCGUDttBOGBIa15WveJJGw=
@@ -451,8 +485,10 @@ github.com/morikuni/aec v1.0.0 h1:nP9CBfwrvYnBRgY6qfDQkygYDmYwOilePFkwzv4dU8A=
github.com/morikuni/aec v1.0.0/go.mod h1:BbKIizmSmc5MMPqRYbxO4ZU0S0+P200+tUnFx7PXmsc=
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq1c1nUAm88MOHcQC9l5mIlSMApZMrHA=
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ=
github.com/netbirdio/dex v0.244.0 h1:1GOvi8wnXYassnKGildzNqRHq0RbcfEUw7LKYpKIN7U=
github.com/netbirdio/dex v0.244.0/go.mod h1:STGInJhPcAflrHmDO7vyit2kSq03PdL+8zQPoGALtcU=
github.com/netbirdio/dex v0.244.1-0.20260512110716-8d70ad8647c1 h1:4TaYr9O4xX0D2kszeOLclTiCbA3eHq3xWV+9ILJbIYs=
github.com/netbirdio/dex v0.244.1-0.20260512110716-8d70ad8647c1/go.mod h1:IHH+H8vK2GfqtIt5u/5OdPh18yk0oDHuj2vz5+Goetg=
github.com/netbirdio/dex/api/v2 v2.0.0-20260512110716-8d70ad8647c1 h1:neE7z+FPUkldl3faK/Jt+hJK2L+1XfQ1W33TQhU9m88=
github.com/netbirdio/dex/api/v2 v2.0.0-20260512110716-8d70ad8647c1/go.mod h1:awuTyT29CYALpEyET0S307EgNlPWrc7fFKRAyhsO45M=
github.com/netbirdio/easyjson v0.9.0 h1:6Nw2lghSVuy8RSkAYDhDv1thBVEmfVbKZnV7T7Z6Aus=
github.com/netbirdio/easyjson v0.9.0/go.mod h1:1+xMtQp2MRNVL/V1bOzuP3aP8VNwRW55fQUto+XFtTU=
github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51 h1:Ov4qdafATOgGMB1wbSuh+0aAHcwz9hdvB6VZjh1mVMI=
@@ -489,6 +525,8 @@ github.com/onsi/gomega v1.7.1/go.mod h1:XdKZgCCFLUoM/7CFJVPcG8C1xQ1AJ0vpAezJrB7J
github.com/onsi/gomega v1.10.1/go.mod h1:iN09h71vgCQne3DLsj+A5owkum+a2tYe+TOCB1ybHNo=
github.com/onsi/gomega v1.27.6 h1:ENqfyGeS5AX/rlXDd/ETokDz93u0YufY1Pgxuy/PvWE=
github.com/onsi/gomega v1.27.6/go.mod h1:PIQNjfQwkP3aQAH7lf7j87O/5FiNr+ZR8+ipb+qQlhg=
github.com/openbao/openbao/api/v2 v2.5.1 h1:Br79D6L20SbAa5P7xqENxmvv8LyI4HoKosPy7klhn4o=
github.com/openbao/openbao/api/v2 v2.5.1/go.mod h1:Dh5un77tqGgMbmlVEqjqN+8/dMyUohnkaQVg/wXW0Ig=
github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8Oi/yOhh5U=
github.com/opencontainers/go-digest v1.0.0/go.mod h1:0JzlMkj0TRzQZfJkVvzbP0HBR3IKzErnv2BNG4W4MAM=
github.com/opencontainers/image-spec v1.1.1 h1:y0fUlFfIZhPF1W537XOLg0/fcx6zcHCJwooC2xJA040=
@@ -501,6 +539,8 @@ github.com/pelletier/go-toml/v2 v2.2.4 h1:mye9XuhQ6gvn5h28+VilKrrPoQVanw5PMw/TB0
github.com/pelletier/go-toml/v2 v2.2.4/go.mod h1:2gIqNv+qfxSVS7cM2xJQKtLSTLUE9V8t9Stt+h56mCY=
github.com/petermattis/goid v0.0.0-20250303134427-723919f7f203 h1:E7Kmf11E4K7B5hDti2K2NqPb1nlYlGYsu02S1JNd/Bs=
github.com/petermattis/goid v0.0.0-20250303134427-723919f7f203/go.mod h1:pxMtw7cyUw6B2bRH0ZBANSPg+AoSud1I1iyJHI69jH4=
github.com/philhofer/fwd v1.2.0 h1:e6DnBTl7vGY+Gz322/ASL4Gyp1FspeMvx1RNDoToZuM=
github.com/philhofer/fwd v1.2.0/go.mod h1:RqIHx9QI14HlwKwm98g9Re5prTQ6LdeRQn+gXJFxsJM=
github.com/pion/dtls/v2 v2.2.7/go.mod h1:8WiMkebSHFD0T+dIU+UeBaoV7kDhOW5oDCzZ7WZ/F9s=
github.com/pion/dtls/v2 v2.2.10 h1:u2Axk+FyIR1VFTPurktB+1zoEPGIW3bmyj3LEFrXjAA=
github.com/pion/dtls/v2 v2.2.10/go.mod h1:d9SYc9fch0CqK90mRk1dC7AkzzpwJj6u2GU3u+9pqFE=
@@ -542,6 +582,8 @@ github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH
github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c/go.mod h1:OmDBASR4679mdNQnz2pUhc2G8CO2JrUAVFDRBDP/hJE=
github.com/power-devops/perfstat v0.0.0-20240221224432-82ca36839d55 h1:o4JXh1EVt9k/+g42oCprj/FisM4qX9L3sZB3upGN2ZU=
github.com/power-devops/perfstat v0.0.0-20240221224432-82ca36839d55/go.mod h1:OmDBASR4679mdNQnz2pUhc2G8CO2JrUAVFDRBDP/hJE=
github.com/pquerna/otp v1.5.0 h1:NMMR+WrmaqXU4EzdGJEE1aUUI0AMRzsp96fFFWNPwxs=
github.com/pquerna/otp v1.5.0/go.mod h1:dkJfzwRKNiegxyNb54X/3fLwhCynbMspSyWKnvi1AEg=
github.com/prometheus/client_golang v1.23.2 h1:Je96obch5RDVy3FDMndoUsjAhG5Edi49h0RJWRi/o0o=
github.com/prometheus/client_golang v1.23.2/go.mod h1:Tb1a6LWHB3/SPIzCoaDXI4I8UHKeFTEQ1YCr+0Gyqmg=
github.com/prometheus/client_model v0.6.2 h1:oBsgwpGs7iVziMvrGhE53c/GrLUsZdHnqNwqPLxwZyk=
@@ -565,6 +607,8 @@ github.com/rs/xid v1.3.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg=
github.com/russellhaering/goxmldsig v1.6.0 h1:8fdWXEPh2k/NZNQBPFNoVfS3JmzS4ZprY/sAOpKQLks=
github.com/russellhaering/goxmldsig v1.6.0/go.mod h1:TrnaquDcYxWXfJrOjeMBTX4mLBeYAqaHEyUeWPxZlBM=
github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
github.com/ryanuber/go-glob v1.0.0 h1:iQh3xXAumdQ+4Ufa5b25cRpC5TYKlno6hsv6Cb3pkBk=
github.com/ryanuber/go-glob v1.0.0/go.mod h1:807d1WSdnB0XRJzKNil9Om6lcp/3a0v4qIHxIXzX/Yc=
github.com/rymdport/portal v0.4.2 h1:7jKRSemwlTyVHHrTGgQg7gmNPJs88xkbKcIL3NlcmSU=
github.com/rymdport/portal v0.4.2/go.mod h1:kFF4jslnJ8pD5uCi17brj/ODlfIidOxlgUDTO5ncnC4=
github.com/shirou/gopsutil/v3 v3.24.4 h1:dEHgzZXt4LMNm+oYELpzl9YCqV65Yr/6SfrvgRBtXeU=
@@ -587,8 +631,8 @@ github.com/songgao/water v0.0.0-20200317203138-2b4b6d7c09d8 h1:TG/diQgUe0pntT/2D
github.com/songgao/water v0.0.0-20200317203138-2b4b6d7c09d8/go.mod h1:P5HUIBuIWKbyjl083/loAegFkfbFNx5i2qEP4CNbm7E=
github.com/spf13/cast v1.7.0 h1:ntdiHjuueXFgm5nzDRdOS4yfT43P5Fnud6DH50rz/7w=
github.com/spf13/cast v1.7.0/go.mod h1:ancEpBxwJDODSW/UG4rDrAqiKolqNNh2DX3mk86cAdo=
github.com/spf13/cobra v1.10.1 h1:lJeBwCfmrnXthfAupyUTzJ/J4Nc1RsHC/mSRU2dll/s=
github.com/spf13/cobra v1.10.1/go.mod h1:7SmJGaTHFVBY0jW4NXGluQoLvhqFQM+6XSKD+P4XaB0=
github.com/spf13/cobra v1.10.2 h1:DMTTonx5m65Ic0GOoRY2c16WCbHxOOw6xxezuLaBpcU=
github.com/spf13/cobra v1.10.2/go.mod h1:7C1pvHqHw5A4vrJfjNwvOdzYu0Gml16OCs2GRiTUUS4=
github.com/spf13/pflag v1.0.9 h1:9exaQaMOCwffKiiiYk6/BndUBv+iRViNW+4lEMi0PvY=
github.com/spf13/pflag v1.0.9/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
github.com/spkg/bom v0.0.0-20160624110644-59b7046e48ad/go.mod h1:qLr4V1qq6nMqFKkMo8ZTx3f+BZEkzsRUY10Xsm2mwU0=
@@ -628,6 +672,8 @@ github.com/ti-mo/conntrack v0.5.1 h1:opEwkFICnDbQc0BUXl73PHBK0h23jEIFVjXsqvF4GY0
github.com/ti-mo/conntrack v0.5.1/go.mod h1:T6NCbkMdVU4qEIgwL0njA6lw/iCAbzchlnwm1Sa314o=
github.com/ti-mo/netfilter v0.5.2 h1:CTjOwFuNNeZ9QPdRXt1MZFLFUf84cKtiQutNauHWd40=
github.com/ti-mo/netfilter v0.5.2/go.mod h1:Btx3AtFiOVdHReTDmP9AE+hlkOcvIy403u7BXXbWZKo=
github.com/tinylib/msgp v1.6.3 h1:bCSxiTz386UTgyT1i0MSCvdbWjVW+8sG3PjkGsZQt4s=
github.com/tinylib/msgp v1.6.3/go.mod h1:RSp0LW9oSxFut3KzESt5Voq4GVWyS+PSulT77roAqEA=
github.com/tklauser/go-sysconf v0.3.12/go.mod h1:Ho14jnntGE1fpdOqQEEaiKRpvIavV0hSfmBq8nJbHYI=
github.com/tklauser/go-sysconf v0.3.15 h1:VE89k0criAymJ/Os65CSn1IXaol+1wrsFHEB8Ol49K4=
github.com/tklauser/go-sysconf v0.3.15/go.mod h1:Dmjwr6tYFIseJw7a3dRLJfsHAMXZ3nEnL/aZY+0IuI4=
@@ -646,6 +692,8 @@ github.com/vmihailenco/tagparser/v2 v2.0.0 h1:y09buUbR+b5aycVFQs/g70pqKVZNBmxwAh
github.com/vmihailenco/tagparser/v2 v2.0.0/go.mod h1:Wri+At7QHww0WTrCBeu4J6bNtoV6mEfg5OIWRZA9qds=
github.com/wlynxg/anet v0.0.5 h1:J3VJGi1gvo0JwZ/P1/Yc/8p63SoW98B5dHkYDmpgvvU=
github.com/wlynxg/anet v0.0.5/go.mod h1:eay5PRQr7fIVAMbTbchTnO9gG65Hg/uYGdc7mguHxoA=
github.com/x448/float16 v0.8.4 h1:qLwI1I70+NjRFUR3zs1JPUCgaCXSh3SW62uAKT1mSBM=
github.com/x448/float16 v0.8.4/go.mod h1:14CWIYCyZA/cWjXOioeEpHeN/83MdbZDRQHoFcYsOfg=
github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k=
@@ -690,14 +738,15 @@ go.opentelemetry.io/proto/otlp v1.0.0 h1:T0TX0tmXU8a3CbNXzEKGeU5mIVOdf0oykP+u2lI
go.opentelemetry.io/proto/otlp v1.0.0/go.mod h1:Sy6pihPLfYHkr3NkUbEhGHFhINUSI/v80hjKIs5JXpM=
go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto=
go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE=
go.uber.org/mock v0.5.2 h1:LbtPTcP8A5k9WPXj54PPPbjcI4Y6lhyOZXn+VS7wNko=
go.uber.org/mock v0.5.2/go.mod h1:wLlUxC2vVTPTaE3UD51E0BGOAElKrILxhVSDYQLld5o=
go.uber.org/mock v0.6.0 h1:hyF9dfmbgIX5EfOdasqLsWD6xqpNZlXblLB/Dbnwv3Y=
go.uber.org/mock v0.6.0/go.mod h1:KiVJ4BqZJaMj4svdfmHM0AUx4NJYO8ZNpPnZn1Z+BBU=
go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0=
go.uber.org/multierr v1.11.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y=
go.uber.org/zap v1.27.0 h1:aJMhYGrd5QSmlpLMr2MftRKl7t8J8PTZPA732ud/XR8=
go.uber.org/zap v1.27.0/go.mod h1:GB2qFLM7cTU87MWRP2mPIjqfIDnGu+VIO4V/SdhGo2E=
go.yaml.in/yaml/v2 v2.4.3 h1:6gvOSjQoTB3vt1l+CU+tSyi/HOjfOjRLJ4YwYZGwRO0=
go.yaml.in/yaml/v2 v2.4.3/go.mod h1:zSxWcmIDjOzPXpjlTTbAsKokqkDNAVtZO0WOMiT90s8=
go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg=
goauthentik.io/api/v3 v3.2023051.3 h1:NebAhD/TeTWNo/9X3/Uj+rM5fG1HaiLOlKTNLQv9Qq4=
goauthentik.io/api/v3 v3.2023051.3/go.mod h1:nYECml4jGbp/541hj8GcylKQG1gVBsKppHy4+7G8u4U=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=

View File

@@ -51,6 +51,70 @@ type YAMLConfig struct {
// StaticPasswords cause the server use this list of passwords rather than
// querying the storage.
StaticPasswords []Password `yaml:"staticPasswords" json:"staticPasswords"`
// Sessions holds authentication session configuration.
// Requires DEX_SESSIONS_ENABLED=true feature flag.
Sessions *Sessions `yaml:"sessions" json:"sessions"`
// MFA holds multi-factor authentication configuration.
MFA MFAConfig `yaml:"mfa" json:"mfa"`
}
type Sessions struct {
// CookieName is the name of the session cookie. Defaults to "dex_session".
CookieName string `yaml:"cookieName" json:"cookieName"`
// AbsoluteLifetime is the maximum session lifetime from creation. Defaults to "24h".
AbsoluteLifetime string `yaml:"absoluteLifetime" json:"absoluteLifetime"`
// ValidIfNotUsedFor is the idle timeout. Defaults to "1h".
ValidIfNotUsedFor string `yaml:"validIfNotUsedFor" json:"validIfNotUsedFor"`
// RememberMeCheckedByDefault controls the default state of the "remember me" checkbox.
RememberMeCheckedByDefault *bool `yaml:"rememberMeCheckedByDefault" json:"rememberMeCheckedByDefault"`
// CookieEncryptionKey is the AES key for encrypting session cookies.
// Must be 16, 24, or 32 bytes for AES-128, AES-192, or AES-256.
// If empty, cookies are not encrypted.
CookieEncryptionKey string `yaml:"cookieEncryptionKey" json:"cookieEncryptionKey"`
// SSOSharedWithDefault is the default SSO sharing policy for clients without explicit ssoSharedWith.
// "all" = share with all clients, "none" = share with no one (default: "none").
SSOSharedWithDefault string `yaml:"ssoSharedWithDefault" json:"ssoSharedWithDefault"`
}
type MFAConfig struct {
Authenticators []MFAAuthenticator `yaml:"authenticators" json:"authenticators"`
}
type MFAAuthenticator struct {
ID string `yaml:"id" json:"id"`
Type string `yaml:"type" json:"type"`
Config map[string]interface{} `yaml:"config" json:"config"`
ConnectorTypes []string `yaml:"connectorTypes" json:"connectorTypes"`
}
type TOTPConfig struct {
Issuer string `yaml:"issuer" json:"issuer"`
}
// WebAuthnConfig holds configuration for a WebAuthn authenticator.
type WebAuthnConfig struct {
// RPDisplayName is the human-readable relying party name shown in the browser
// dialog during key registration and authentication (e.g., "My Company SSO").
RPDisplayName string `yaml:"rpDisplayName" json:"rpDisplayName"`
// RPID is the relying party identifier — must match the domain in the browser
// address bar. If empty, derived from the issuer URL hostname.
// Example: "auth.example.com"
RPID string `yaml:"rpID" json:"rpID"`
// RPOrigins is the list of allowed origins for WebAuthn ceremonies.
// If empty, derived from the issuer URL (scheme + host).
// Example: ["https://auth.example.com"]
RPOrigins []string `yaml:"rpOrigins" json:"rpOrigins"`
// AttestationPreference controls what attestation data the authenticator should provide:
// "none" — don't request attestation (simpler, more private)
// "indirect" — authenticator may anonymize attestation (default)
// "direct" — request full attestation (for enterprise key model verification)
AttestationPreference string `yaml:"attestationPreference" json:"attestationPreference"`
// Timeout is the duration allowed for the browser WebAuthn ceremony
// (registration or login). Defaults to "60s".
Timeout string `yaml:"timeout" json:"timeout"`
}
// Web is the config format for the HTTP server.
@@ -116,7 +180,6 @@ type Storage struct {
Config map[string]interface{} `yaml:"config" json:"config"`
}
// Password represents a static user configuration
type Password storage.Password
func (p *Password) UnmarshalYAML(node *yaml.Node) error {
@@ -429,9 +492,98 @@ func (c *YAMLConfig) Validate() error {
if !c.EnablePasswordDB && len(c.StaticPasswords) != 0 {
return fmt.Errorf("cannot specify static passwords without enabling password db")
}
return nil
}
func buildTotpConfig(auth MFAAuthenticator) (*server.TOTPProvider, error) {
data, err := json.Marshal(auth.Config)
if err != nil {
return nil, fmt.Errorf("failed to marshal TOTP config id: %s - %w", auth.ID, err)
}
var cfg TOTPConfig
if err := json.Unmarshal(data, &cfg); err != nil {
return nil, fmt.Errorf("failed to parse TOTP config id: %s - %w", auth.ID, err)
}
return server.NewTOTPProvider(cfg.Issuer, auth.ConnectorTypes), nil
}
func buildWebAuthnConfig(auth MFAAuthenticator, issuerURL string) (*server.WebAuthnProvider, error) {
data, err := json.Marshal(auth.Config)
if err != nil {
return nil, fmt.Errorf("failed to marshal WebAuthn config id: %s - %w", auth.ID, err)
}
var cfg WebAuthnConfig
if err := json.Unmarshal(data, &cfg); err != nil {
return nil, fmt.Errorf("failed to parse WebAuthn config id: %s - %w", auth.ID, err)
}
provider, err := server.NewWebAuthnProvider(cfg.RPDisplayName, cfg.RPID, cfg.RPOrigins,
cfg.AttestationPreference, cfg.Timeout, issuerURL, auth.ConnectorTypes)
if err != nil {
return nil, fmt.Errorf("failed to create WebAuthn provider id: %s - err: %w", auth.ID, err)
}
return provider, nil
}
func buildMFAProviders(authenticators []MFAAuthenticator, issuerURL string, logger *slog.Logger) map[string]server.MFAProvider {
if len(authenticators) == 0 {
return nil
}
providers := make(map[string]server.MFAProvider, len(authenticators))
for _, auth := range authenticators {
switch auth.Type {
case "TOTP":
provider, err := buildTotpConfig(auth)
if err != nil {
logger.Error("failed to parse TOTP config", "id", auth.ID, "err", err)
continue
}
providers[auth.ID] = provider
logger.Info("MFA authenticator configured", "id", auth.ID, "type", auth.Type)
case "WebAuthn":
provider, err := buildWebAuthnConfig(auth, issuerURL)
if err != nil {
logger.Error("failed to parse WebAuthn config", "id", auth.ID, "err", err)
continue
}
providers[auth.ID] = provider
logger.Info("MFA authenticator configured", "id", auth.ID, "type", auth.Type)
default:
logger.Error("unknown MFA authenticator type, skipping", "id", auth.ID, "type", auth.Type)
}
}
return providers
}
func buildSessionsConfig(sessions *Sessions) *server.SessionConfig {
if sessions == nil {
return nil
}
if sessions.RememberMeCheckedByDefault == nil {
defaultRememberMeCheckedByDefault := false
sessions.RememberMeCheckedByDefault = &defaultRememberMeCheckedByDefault
}
absoluteLifetime, _ := parseDuration(sessions.AbsoluteLifetime)
validIfNotUsedFor, _ := parseDuration(sessions.ValidIfNotUsedFor)
return &server.SessionConfig{
CookieEncryptionKey: []byte(sessions.CookieEncryptionKey),
CookieName: sessions.CookieName,
AbsoluteLifetime: absoluteLifetime,
ValidIfNotUsedFor: validIfNotUsedFor,
RememberMeCheckedByDefault: *sessions.RememberMeCheckedByDefault,
SSOSharedWithDefault: sessions.SSOSharedWithDefault,
}
}
// ToServerConfig converts YAMLConfig to dex server.Config
func (c *YAMLConfig) ToServerConfig(stor storage.Storage, logger *slog.Logger) server.Config {
cfg := server.Config{
@@ -448,6 +600,8 @@ func (c *YAMLConfig) ToServerConfig(stor storage.Storage, logger *slog.Logger) s
Dir: c.Frontend.Dir,
Extra: c.Frontend.Extra,
},
SessionConfig: buildSessionsConfig(c.Sessions),
MFAProviders: buildMFAProviders(c.MFA.Authenticators, c.Issuer, logger),
}
// Use embedded NetBird-styled templates if no custom dir specified
@@ -460,11 +614,6 @@ func (c *YAMLConfig) ToServerConfig(stor storage.Storage, logger *slog.Logger) s
}
// Apply expiry settings
if c.Expiry.SigningKeys != "" {
if d, err := parseDuration(c.Expiry.SigningKeys); err == nil {
cfg.RotateKeysAfter = d
}
}
if c.Expiry.IDTokens != "" {
if d, err := parseDuration(c.Expiry.IDTokens); err == nil {
cfg.IDTokensValidFor = d

View File

@@ -18,6 +18,7 @@ import (
dexapi "github.com/dexidp/dex/api/v2"
"github.com/dexidp/dex/server"
"github.com/dexidp/dex/server/signer"
"github.com/dexidp/dex/storage"
"github.com/dexidp/dex/storage/sql"
jose "github.com/go-jose/go-jose/v4"
@@ -70,7 +71,7 @@ func NewProvider(ctx context.Context, config *Config) (*Provider, error) {
logger := slog.New(slog.NewTextHandler(os.Stderr, nil))
// Ensure data directory exists
if err := os.MkdirAll(config.DataDir, 0700); err != nil {
if err := os.MkdirAll(config.DataDir, 0o700); err != nil {
return nil, fmt.Errorf("failed to create data directory: %w", err)
}
@@ -101,6 +102,15 @@ func NewProvider(ctx context.Context, config *Config) (*Provider, error) {
return nil, fmt.Errorf("failed to create refresh token policy: %w", err)
}
localSignerConfig := signer.LocalConfig{
KeysRotationPeriod: "6h",
}
localSigner, err := localSignerConfig.Open(ctx, stor, 24*time.Hour, time.Now, logger)
if err != nil {
return nil, fmt.Errorf("failed to create local signer: %w", err)
}
// Build Dex server config - use Dex's types directly
dexConfig := server.Config{
Issuer: issuer,
@@ -110,12 +120,12 @@ func NewProvider(ctx context.Context, config *Config) (*Provider, error) {
ContinueOnConnectorFailure: true,
Logger: logger,
PrometheusRegistry: prometheus.NewRegistry(),
RotateKeysAfter: 6 * time.Hour,
IDTokensValidFor: 24 * time.Hour,
RefreshTokenPolicy: refreshPolicy,
Web: server.WebConfig{
Issuer: "NetBird",
},
Signer: localSigner,
}
dexSrv, err := server.NewServer(ctx, dexConfig)
@@ -167,6 +177,14 @@ func NewProviderFromYAML(ctx context.Context, yamlConfig *YAMLConfig) (*Provider
return nil, fmt.Errorf("failed to create refresh token policy: %w", err)
}
localSigner, err := getSigner(ctx, stor, yamlConfig, logger)
if err != nil {
stor.Close()
return nil, fmt.Errorf("failed to create local signer: %w", err)
}
dexConfig.Signer = localSigner
dexSrv, err := server.NewServer(ctx, dexConfig)
if err != nil {
stor.Close()
@@ -182,6 +200,32 @@ func NewProviderFromYAML(ctx context.Context, yamlConfig *YAMLConfig) (*Provider
}, nil
}
func getSigner(ctx context.Context, stor storage.Storage, yamlConfig *YAMLConfig, logger *slog.Logger) (signer.Signer, error) {
// Parse expiry durations
idTokensValidFor := 24 * time.Hour // default
if yamlConfig.Expiry.IDTokens != "" {
var err error
idTokensValidFor, err = parseDuration(yamlConfig.Expiry.IDTokens)
if err != nil {
return nil, fmt.Errorf("invalid config value %q for id token expiry: %v", yamlConfig.Expiry.IDTokens, err)
}
}
localSignerConfig := &signer.LocalConfig{
KeysRotationPeriod: "720h", // 30 Days
}
if yamlConfig.Expiry.SigningKeys != "" {
if _, err := parseDuration(yamlConfig.Expiry.SigningKeys); err != nil {
return nil, fmt.Errorf("invalid config value %q for signing key expiry: %v", yamlConfig.Expiry.SigningKeys, err)
}
localSignerConfig.KeysRotationPeriod = yamlConfig.Expiry.SigningKeys
}
return localSignerConfig.Open(ctx, stor, idTokensValidFor, time.Now, logger)
}
// initializeStorage sets up connectors, passwords, and clients in storage
func initializeStorage(ctx context.Context, stor storage.Storage, cfg *YAMLConfig) error {
if cfg.EnablePasswordDB {
@@ -241,6 +285,8 @@ func ensureStaticClients(ctx context.Context, stor storage.Storage, clients []st
old.RedirectURIs = client.RedirectURIs
old.Name = client.Name
old.Public = client.Public
old.PostLogoutRedirectURIs = client.PostLogoutRedirectURIs
old.MFAChain = client.MFAChain
return old, nil
}); err != nil {
return fmt.Errorf("failed to update client %s: %w", client.ID, err)
@@ -253,9 +299,6 @@ func ensureStaticClients(ctx context.Context, stor storage.Storage, clients []st
func buildDexConfig(yamlConfig *YAMLConfig, stor storage.Storage, logger *slog.Logger) server.Config {
cfg := yamlConfig.ToServerConfig(stor, logger)
cfg.PrometheusRegistry = prometheus.NewRegistry()
if cfg.RotateKeysAfter == 0 {
cfg.RotateKeysAfter = 24 * 30 * time.Hour
}
if cfg.IDTokensValidFor == 0 {
cfg.IDTokensValidFor = 24 * time.Hour
}
@@ -450,10 +493,34 @@ func (p *Provider) Storage() storage.Storage {
return p.storage
}
// SetClientsMFAChain updates the MFAChain field on the dashboard and CLI OAuth2 clients.
// Pass a non-empty slice (e.g. []string{"default-totp"}) to enable MFA, or nil to disable it.
func (p *Provider) SetClientsMFAChain(ctx context.Context, clientIDs []string, mfaChain []string) error {
for _, clientID := range clientIDs {
if err := p.storage.UpdateClient(ctx, clientID, func(old storage.Client) (storage.Client, error) {
old.MFAChain = mfaChain
return old, nil
}); err != nil {
return fmt.Errorf("failed to update MFA chain on client %s: %w", clientID, err)
}
}
return nil
}
// Handler returns the Dex server as an http.Handler for embedding in another server.
// The handler expects requests with path prefix "/oauth2/".
func (p *Provider) Handler() http.Handler {
return p.dexServer
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Dex's /logout endpoint requires id_token_hint for RP-initiated logout with
// post_logout_redirect_uri. If the dashboard calls logout without one, avoid
// rendering Dex's non-actionable Bad Request page and send the user home.
if strings.HasSuffix(r.URL.Path, "/logout") && r.FormValue("id_token_hint") == "" {
http.Redirect(w, r, "/", http.StatusSeeOther)
return
}
p.dexServer.ServeHTTP(w, r)
})
}
// CreateUser creates a new user with the given email, username, and password.

View File

@@ -4,6 +4,8 @@ import (
"context"
"encoding/json"
"log/slog"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"testing"
@@ -144,6 +146,30 @@ func TestEncodeDexUserID_MatchesDexFormat(t *testing.T) {
assert.Equal(t, knownEncodedID, reEncoded)
}
func TestHandlerRedirectsLogoutWithoutIDTokenHint(t *testing.T) {
ctx := context.Background()
tmpDir, err := os.MkdirTemp("", "dex-logout-handler-*")
require.NoError(t, err)
defer os.RemoveAll(tmpDir)
provider, err := NewProvider(ctx, &Config{
Issuer: "http://localhost:5556/oauth2",
Port: 5556,
DataDir: tmpDir,
})
require.NoError(t, err)
defer func() { _ = provider.Stop(ctx) }()
req := httptest.NewRequest(http.MethodGet, "/oauth2/logout?post_logout_redirect_uri=https://example.com", nil)
rec := httptest.NewRecorder()
provider.Handler().ServeHTTP(rec, req)
require.Equal(t, http.StatusSeeOther, rec.Code)
require.Equal(t, "/", rec.Header().Get("Location"))
}
func TestCreateUserInTempDB(t *testing.T) {
ctx := context.Background()

View File

@@ -0,0 +1,12 @@
{{ template "header.html" . }}
<script>globalThis.location.replace("/");</script>
<noscript>
<div class="nb-card">
<h1 class="nb-heading">Redirecting…</h1>
<p class="nb-subheading">You are being redirected to the NetBird dashboard.</p>
<a href="/" class="nb-btn" style="display:block;text-align:center;text-decoration:none">Go to Dashboard</a>
</div>
</noscript>
{{ template "footer.html" . }}

View File

@@ -0,0 +1,14 @@
{{ template "header.html" . }}
<div class="nb-card">
<h1 class="nb-heading">Logged Out</h1>
<p class="nb-subheading">You have been successfully logged out.</p>
{{ if .BackURL }}
<div class="nb-back-link">
<a href="{{ .BackURL }}" class="nb-link">&larr; Back to Application</a>
</div>
{{ end }}
</div>
{{ template "footer.html" . }}

View File

@@ -18,6 +18,7 @@
id="login"
name="login"
class="nb-input"
autocomplete="username"
placeholder="Enter your {{ .UsernamePrompt | lower }}"
{{ if .Username }}value="{{ .Username }}"{{ else }}autofocus{{ end }}
required
@@ -31,6 +32,7 @@
id="password"
name="password"
class="nb-input"
autocomplete="current-password"
placeholder="Enter your password"
{{ if .Invalid }}autofocus{{ end }}
required

View File

@@ -0,0 +1,44 @@
{{ template "header.html" . }}
<div class="nb-card">
<h1 class="nb-heading">Two-factor authentication</h1>
{{ if not (eq .QRCode "") }}
<p class="nb-subheading">Scan the QR code below using your authenticator app, then enter the code.</p>
<div style="text-align: center; margin: 1em 0;">
<img src="data:image/png;base64,{{ .QRCode }}" alt="QR code" width="200" height="200"/>
</div>
{{ else }}
<p class="nb-subheading">Enter the code from your authenticator app.</p>
{{ end }}
<form method="post" action="{{ .PostURL }}">
{{ if .Invalid }}
<div class="nb-error">
Invalid code. Please try again.
</div>
{{ end }}
<div class="nb-form-group">
<label class="nb-label" for="totp">One-time code</label>
<input
type="text"
id="totp"
name="totp"
class="nb-input"
inputmode="numeric"
pattern="[0-9]*"
maxlength="6"
autocomplete="one-time-code"
placeholder="000000"
autofocus
required
>
</div>
<button type="submit" id="submit-login" class="nb-btn">
Verify
</button>
</form>
</div>
{{ template "footer.html" . }}

View File

@@ -0,0 +1,12 @@
{{ template "header.html" . }}
<script>globalThis.location.replace("/");</script>
<noscript>
<div class="nb-card">
<h1 class="nb-heading">Redirecting…</h1>
<p class="nb-subheading">You are being redirected to the NetBird dashboard.</p>
<a href="/" class="nb-btn" style="display:block;text-align:center;text-decoration:none">Go to Dashboard</a>
</div>
</noscript>
{{ template "footer.html" . }}

View File

@@ -31,6 +31,7 @@ type store interface {
type proxyManager interface {
GetActiveClusterAddresses(ctx context.Context) ([]string, error)
GetActiveClusterAddressesForAccount(ctx context.Context, accountID string) ([]string, error)
ClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool
ClusterRequireSubdomain(ctx context.Context, clusterAddr string) *bool
ClusterSupportsCrowdSec(ctx context.Context, clusterAddr string) *bool
@@ -71,8 +72,8 @@ func (m Manager) GetDomains(ctx context.Context, accountID, userID string) ([]*d
var ret []*domain.Domain
// Add connected proxy clusters as free domains.
// The cluster address itself is the free domain base (e.g., "eu.proxy.netbird.io").
allowList, err := m.proxyManager.GetActiveClusterAddresses(ctx)
// For BYOP accounts, only their own cluster is returned; otherwise shared clusters.
allowList, err := m.getClusterAllowList(ctx, accountID)
if err != nil {
log.WithContext(ctx).Errorf("failed to get active proxy cluster addresses: %v", err)
return nil, err
@@ -126,8 +127,8 @@ func (m Manager) CreateDomain(ctx context.Context, accountID, userID, domainName
return nil, status.NewPermissionDeniedError()
}
// Verify the target cluster is in the available clusters
allowList, err := m.proxyManager.GetActiveClusterAddresses(ctx)
// Verify the target cluster is in the available clusters for this account
allowList, err := m.getClusterAllowList(ctx, accountID)
if err != nil {
return nil, fmt.Errorf("failed to get active proxy cluster addresses: %w", err)
}
@@ -273,7 +274,7 @@ func (m Manager) GetClusterDomains() []string {
// For free domains (those ending with a known cluster suffix), the cluster is extracted from the domain.
// For custom domains, the cluster is determined by checking the registered custom domain's target cluster.
func (m Manager) DeriveClusterFromDomain(ctx context.Context, accountID, domain string) (string, error) {
allowList, err := m.proxyManager.GetActiveClusterAddresses(ctx)
allowList, err := m.getClusterAllowList(ctx, accountID)
if err != nil {
return "", fmt.Errorf("failed to get active proxy cluster addresses: %w", err)
}
@@ -298,6 +299,34 @@ func (m Manager) DeriveClusterFromDomain(ctx context.Context, accountID, domain
return "", fmt.Errorf("domain %s does not match any available proxy cluster", domain)
}
func (m Manager) getClusterAllowList(ctx context.Context, accountID string) ([]string, error) {
byopAddresses, err := m.proxyManager.GetActiveClusterAddressesForAccount(ctx, accountID)
if err != nil {
return nil, fmt.Errorf("get BYOP cluster addresses: %w", err)
}
publicAddresses, err := m.proxyManager.GetActiveClusterAddresses(ctx)
if err != nil {
return nil, fmt.Errorf("get public cluster addresses: %w", err)
}
seen := make(map[string]struct{}, len(byopAddresses)+len(publicAddresses))
merged := make([]string, 0, len(byopAddresses)+len(publicAddresses))
for _, addr := range byopAddresses {
if _, ok := seen[addr]; ok {
continue
}
seen[addr] = struct{}{}
merged = append(merged, addr)
}
for _, addr := range publicAddresses {
if _, ok := seen[addr]; ok {
continue
}
seen[addr] = struct{}{}
merged = append(merged, addr)
}
return merged, nil
}
func extractClusterFromCustomDomains(serviceDomain string, customDomains []*domain.Domain) (string, bool) {
bestCluster := ""
bestLen := -1

View File

@@ -0,0 +1,154 @@
package manager
import (
"context"
"errors"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
type mockProxyManager struct {
getActiveClusterAddressesFunc func(ctx context.Context) ([]string, error)
getActiveClusterAddressesForAccountFunc func(ctx context.Context, accountID string) ([]string, error)
}
func (m *mockProxyManager) GetActiveClusterAddresses(ctx context.Context) ([]string, error) {
if m.getActiveClusterAddressesFunc != nil {
return m.getActiveClusterAddressesFunc(ctx)
}
return nil, nil
}
func (m *mockProxyManager) GetActiveClusterAddressesForAccount(ctx context.Context, accountID string) ([]string, error) {
if m.getActiveClusterAddressesForAccountFunc != nil {
return m.getActiveClusterAddressesForAccountFunc(ctx, accountID)
}
return nil, nil
}
func (m *mockProxyManager) ClusterSupportsCustomPorts(_ context.Context, _ string) *bool {
return nil
}
func (m *mockProxyManager) ClusterRequireSubdomain(_ context.Context, _ string) *bool {
return nil
}
func (m *mockProxyManager) ClusterSupportsCrowdSec(_ context.Context, _ string) *bool {
return nil
}
func TestGetClusterAllowList_BYOPMergedWithPublic(t *testing.T) {
pm := &mockProxyManager{
getActiveClusterAddressesForAccountFunc: func(_ context.Context, accID string) ([]string, error) {
assert.Equal(t, "acc-123", accID)
return []string{"byop.example.com"}, nil
},
getActiveClusterAddressesFunc: func(_ context.Context) ([]string, error) {
return []string{"eu.proxy.netbird.io"}, nil
},
}
mgr := Manager{proxyManager: pm}
result, err := mgr.getClusterAllowList(context.Background(), "acc-123")
require.NoError(t, err)
assert.Equal(t, []string{"byop.example.com", "eu.proxy.netbird.io"}, result)
}
func TestGetClusterAllowList_DeduplicatesBYOPAndPublic(t *testing.T) {
pm := &mockProxyManager{
getActiveClusterAddressesForAccountFunc: func(_ context.Context, _ string) ([]string, error) {
return []string{"shared.example.com", "byop.example.com"}, nil
},
getActiveClusterAddressesFunc: func(_ context.Context) ([]string, error) {
return []string{"shared.example.com", "eu.proxy.netbird.io"}, nil
},
}
mgr := Manager{proxyManager: pm}
result, err := mgr.getClusterAllowList(context.Background(), "acc-123")
require.NoError(t, err)
assert.Equal(t, []string{"shared.example.com", "byop.example.com", "eu.proxy.netbird.io"}, result)
}
func TestGetClusterAllowList_NoBYOP_FallbackToShared(t *testing.T) {
pm := &mockProxyManager{
getActiveClusterAddressesForAccountFunc: func(_ context.Context, _ string) ([]string, error) {
return nil, nil
},
getActiveClusterAddressesFunc: func(_ context.Context) ([]string, error) {
return []string{"eu.proxy.netbird.io", "us.proxy.netbird.io"}, nil
},
}
mgr := Manager{proxyManager: pm}
result, err := mgr.getClusterAllowList(context.Background(), "acc-123")
require.NoError(t, err)
assert.Equal(t, []string{"eu.proxy.netbird.io", "us.proxy.netbird.io"}, result)
}
func TestGetClusterAllowList_BYOPError_ReturnsError(t *testing.T) {
pm := &mockProxyManager{
getActiveClusterAddressesForAccountFunc: func(_ context.Context, _ string) ([]string, error) {
return nil, errors.New("db error")
},
}
mgr := Manager{proxyManager: pm}
result, err := mgr.getClusterAllowList(context.Background(), "acc-123")
require.Error(t, err)
assert.Nil(t, result)
assert.Contains(t, err.Error(), "BYOP cluster addresses")
}
func TestGetClusterAllowList_PublicError_ReturnsError(t *testing.T) {
pm := &mockProxyManager{
getActiveClusterAddressesForAccountFunc: func(_ context.Context, _ string) ([]string, error) {
return []string{"byop.example.com"}, nil
},
getActiveClusterAddressesFunc: func(_ context.Context) ([]string, error) {
return nil, errors.New("db error")
},
}
mgr := Manager{proxyManager: pm}
result, err := mgr.getClusterAllowList(context.Background(), "acc-123")
require.Error(t, err)
assert.Nil(t, result)
assert.Contains(t, err.Error(), "public cluster addresses")
}
func TestGetClusterAllowList_BYOPEmptySlice_FallbackToShared(t *testing.T) {
pm := &mockProxyManager{
getActiveClusterAddressesForAccountFunc: func(_ context.Context, _ string) ([]string, error) {
return []string{}, nil
},
getActiveClusterAddressesFunc: func(_ context.Context) ([]string, error) {
return []string{"eu.proxy.netbird.io"}, nil
},
}
mgr := Manager{proxyManager: pm}
result, err := mgr.getClusterAllowList(context.Background(), "acc-123")
require.NoError(t, err)
assert.Equal(t, []string{"eu.proxy.netbird.io"}, result)
}
func TestGetClusterAllowList_PublicEmpty_BYOPOnly(t *testing.T) {
pm := &mockProxyManager{
getActiveClusterAddressesForAccountFunc: func(_ context.Context, _ string) ([]string, error) {
return []string{"byop.example.com"}, nil
},
getActiveClusterAddressesFunc: func(_ context.Context) ([]string, error) {
return nil, nil
},
}
mgr := Manager{proxyManager: pm}
result, err := mgr.getClusterAllowList(context.Background(), "acc-123")
require.NoError(t, err)
assert.Equal(t, []string{"byop.example.com"}, result)
}

View File

@@ -11,15 +11,19 @@ import (
// Manager defines the interface for proxy operations
type Manager interface {
Connect(ctx context.Context, proxyID, sessionID, clusterAddress, ipAddress string, capabilities *Capabilities) (*Proxy, error)
Connect(ctx context.Context, proxyID, sessionID, clusterAddress, ipAddress string, accountID *string, capabilities *Capabilities) (*Proxy, error)
Disconnect(ctx context.Context, proxyID, sessionID string) error
Heartbeat(ctx context.Context, p *Proxy) error
GetActiveClusterAddresses(ctx context.Context) ([]string, error)
GetActiveClusters(ctx context.Context) ([]Cluster, error)
GetActiveClusterAddressesForAccount(ctx context.Context, accountID string) ([]string, error)
ClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool
ClusterRequireSubdomain(ctx context.Context, clusterAddr string) *bool
ClusterSupportsCrowdSec(ctx context.Context, clusterAddr string) *bool
CleanupStale(ctx context.Context, inactivityDuration time.Duration) error
GetAccountProxy(ctx context.Context, accountID string) (*Proxy, error)
CountAccountProxies(ctx context.Context, accountID string) (int64, error)
IsClusterAddressAvailable(ctx context.Context, clusterAddress, accountID string) (bool, error)
DeleteAccountCluster(ctx context.Context, clusterAddress, accountID string) error
}
// OIDCValidationConfig contains the OIDC configuration needed for token validation.

View File

@@ -16,11 +16,16 @@ type store interface {
DisconnectProxy(ctx context.Context, proxyID, sessionID string) error
UpdateProxyHeartbeat(ctx context.Context, p *proxy.Proxy) error
GetActiveProxyClusterAddresses(ctx context.Context) ([]string, error)
GetActiveProxyClusters(ctx context.Context) ([]proxy.Cluster, error)
GetActiveProxyClusterAddressesForAccount(ctx context.Context, accountID string) ([]string, error)
GetActiveProxyClusters(ctx context.Context, accountID string) ([]proxy.Cluster, error)
GetClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool
GetClusterRequireSubdomain(ctx context.Context, clusterAddr string) *bool
GetClusterSupportsCrowdSec(ctx context.Context, clusterAddr string) *bool
CleanupStaleProxies(ctx context.Context, inactivityDuration time.Duration) error
GetProxyByAccountID(ctx context.Context, accountID string) (*proxy.Proxy, error)
CountProxiesByAccountID(ctx context.Context, accountID string) (int64, error)
IsClusterAddressConflicting(ctx context.Context, clusterAddress, accountID string) (bool, error)
DeleteAccountCluster(ctx context.Context, clusterAddress, accountID string) error
}
// Manager handles all proxy operations
@@ -44,7 +49,7 @@ func NewManager(store store, meter metric.Meter) (*Manager, error) {
// Connect registers a new proxy connection in the database.
// capabilities may be nil for old proxies that do not report them.
func (m Manager) Connect(ctx context.Context, proxyID, sessionID, clusterAddress, ipAddress string, capabilities *proxy.Capabilities) (*proxy.Proxy, error) {
func (m *Manager) Connect(ctx context.Context, proxyID, sessionID, clusterAddress, ipAddress string, accountID *string, capabilities *proxy.Capabilities) (*proxy.Proxy, error) {
now := time.Now()
var caps proxy.Capabilities
if capabilities != nil {
@@ -55,9 +60,10 @@ func (m Manager) Connect(ctx context.Context, proxyID, sessionID, clusterAddress
SessionID: sessionID,
ClusterAddress: clusterAddress,
IPAddress: ipAddress,
AccountID: accountID,
LastSeen: now,
ConnectedAt: &now,
Status: "connected",
Status: proxy.StatusConnected,
Capabilities: caps,
}
@@ -77,7 +83,7 @@ func (m Manager) Connect(ctx context.Context, proxyID, sessionID, clusterAddress
}
// Disconnect marks a proxy as disconnected in the database.
func (m Manager) Disconnect(ctx context.Context, proxyID, sessionID string) error {
func (m *Manager) Disconnect(ctx context.Context, proxyID, sessionID string) error {
if err := m.store.DisconnectProxy(ctx, proxyID, sessionID); err != nil {
log.WithContext(ctx).Errorf("failed to disconnect proxy %s session %s: %v", proxyID, sessionID, err)
return err
@@ -92,7 +98,7 @@ func (m Manager) Disconnect(ctx context.Context, proxyID, sessionID string) erro
}
// Heartbeat updates the proxy's last seen timestamp.
func (m Manager) Heartbeat(ctx context.Context, p *proxy.Proxy) error {
func (m *Manager) Heartbeat(ctx context.Context, p *proxy.Proxy) error {
if err := m.store.UpdateProxyHeartbeat(ctx, p); err != nil {
log.WithContext(ctx).Debugf("failed to update proxy %s heartbeat: %v", p.ID, err)
return err
@@ -104,7 +110,7 @@ func (m Manager) Heartbeat(ctx context.Context, p *proxy.Proxy) error {
}
// GetActiveClusterAddresses returns all unique cluster addresses for active proxies
func (m Manager) GetActiveClusterAddresses(ctx context.Context) ([]string, error) {
func (m *Manager) GetActiveClusterAddresses(ctx context.Context) ([]string, error) {
addresses, err := m.store.GetActiveProxyClusterAddresses(ctx)
if err != nil {
log.WithContext(ctx).Errorf("failed to get active proxy cluster addresses: %v", err)
@@ -113,16 +119,6 @@ func (m Manager) GetActiveClusterAddresses(ctx context.Context) ([]string, error
return addresses, nil
}
// GetActiveClusters returns all active proxy clusters with their connected proxy count.
func (m Manager) GetActiveClusters(ctx context.Context) ([]proxy.Cluster, error) {
clusters, err := m.store.GetActiveProxyClusters(ctx)
if err != nil {
log.WithContext(ctx).Errorf("failed to get active proxy clusters: %v", err)
return nil, err
}
return clusters, nil
}
// ClusterSupportsCustomPorts returns whether any active proxy in the cluster
// supports custom ports. Returns nil when no proxy has reported capabilities.
func (m Manager) ClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool {
@@ -142,10 +138,44 @@ func (m Manager) ClusterSupportsCrowdSec(ctx context.Context, clusterAddr string
}
// CleanupStale removes proxies that haven't sent heartbeat in the specified duration
func (m Manager) CleanupStale(ctx context.Context, inactivityDuration time.Duration) error {
func (m *Manager) CleanupStale(ctx context.Context, inactivityDuration time.Duration) error {
if err := m.store.CleanupStaleProxies(ctx, inactivityDuration); err != nil {
log.WithContext(ctx).Errorf("failed to cleanup stale proxies: %v", err)
return err
}
return nil
}
func (m *Manager) GetActiveClusterAddressesForAccount(ctx context.Context, accountID string) ([]string, error) {
addresses, err := m.store.GetActiveProxyClusterAddressesForAccount(ctx, accountID)
if err != nil {
log.WithContext(ctx).Errorf("failed to get active proxy cluster addresses for account %s: %v", accountID, err)
return nil, err
}
return addresses, nil
}
func (m *Manager) GetAccountProxy(ctx context.Context, accountID string) (*proxy.Proxy, error) {
return m.store.GetProxyByAccountID(ctx, accountID)
}
func (m *Manager) CountAccountProxies(ctx context.Context, accountID string) (int64, error) {
return m.store.CountProxiesByAccountID(ctx, accountID)
}
func (m *Manager) IsClusterAddressAvailable(ctx context.Context, clusterAddress, accountID string) (bool, error) {
conflicting, err := m.store.IsClusterAddressConflicting(ctx, clusterAddress, accountID)
if err != nil {
return false, err
}
return !conflicting, nil
}
func (m *Manager) DeleteAccountCluster(ctx context.Context, clusterAddress, accountID string) error {
if err := m.store.DeleteAccountCluster(ctx, clusterAddress, accountID); err != nil {
log.WithContext(ctx).Errorf("failed to delete cluster %s for account %s: %v", clusterAddress, accountID, err)
return err
}
return nil
}

View File

@@ -0,0 +1,337 @@
package manager
import (
"context"
"errors"
"fmt"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.opentelemetry.io/otel/metric/noop"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
)
type mockStore struct {
saveProxyFunc func(ctx context.Context, p *proxy.Proxy) error
disconnectProxyFunc func(ctx context.Context, proxyID, sessionID string) error
updateProxyHeartbeatFunc func(ctx context.Context, p *proxy.Proxy) error
getActiveProxyClusterAddressesFunc func(ctx context.Context) ([]string, error)
getActiveProxyClusterAddressesForAccFunc func(ctx context.Context, accountID string) ([]string, error)
cleanupStaleProxiesFunc func(ctx context.Context, d time.Duration) error
getProxyByAccountIDFunc func(ctx context.Context, accountID string) (*proxy.Proxy, error)
countProxiesByAccountIDFunc func(ctx context.Context, accountID string) (int64, error)
isClusterAddressConflictingFunc func(ctx context.Context, clusterAddress, accountID string) (bool, error)
deleteAccountClusterFunc func(ctx context.Context, clusterAddress, accountID string) error
}
func (m *mockStore) SaveProxy(ctx context.Context, p *proxy.Proxy) error {
if m.saveProxyFunc != nil {
return m.saveProxyFunc(ctx, p)
}
return nil
}
func (m *mockStore) DisconnectProxy(ctx context.Context, proxyID, sessionID string) error {
if m.disconnectProxyFunc != nil {
return m.disconnectProxyFunc(ctx, proxyID, sessionID)
}
return nil
}
func (m *mockStore) UpdateProxyHeartbeat(ctx context.Context, p *proxy.Proxy) error {
if m.updateProxyHeartbeatFunc != nil {
return m.updateProxyHeartbeatFunc(ctx, p)
}
return nil
}
func (m *mockStore) GetActiveProxyClusterAddresses(ctx context.Context) ([]string, error) {
if m.getActiveProxyClusterAddressesFunc != nil {
return m.getActiveProxyClusterAddressesFunc(ctx)
}
return nil, nil
}
func (m *mockStore) GetActiveProxyClusterAddressesForAccount(ctx context.Context, accountID string) ([]string, error) {
if m.getActiveProxyClusterAddressesForAccFunc != nil {
return m.getActiveProxyClusterAddressesForAccFunc(ctx, accountID)
}
return nil, nil
}
func (m *mockStore) GetActiveProxyClusters(_ context.Context, _ string) ([]proxy.Cluster, error) {
return nil, nil
}
func (m *mockStore) CleanupStaleProxies(ctx context.Context, d time.Duration) error {
if m.cleanupStaleProxiesFunc != nil {
return m.cleanupStaleProxiesFunc(ctx, d)
}
return nil
}
func (m *mockStore) GetProxyByAccountID(ctx context.Context, accountID string) (*proxy.Proxy, error) {
if m.getProxyByAccountIDFunc != nil {
return m.getProxyByAccountIDFunc(ctx, accountID)
}
return nil, fmt.Errorf("proxy not found for account %s", accountID)
}
func (m *mockStore) CountProxiesByAccountID(ctx context.Context, accountID string) (int64, error) {
if m.countProxiesByAccountIDFunc != nil {
return m.countProxiesByAccountIDFunc(ctx, accountID)
}
return 0, nil
}
func (m *mockStore) IsClusterAddressConflicting(ctx context.Context, clusterAddress, accountID string) (bool, error) {
if m.isClusterAddressConflictingFunc != nil {
return m.isClusterAddressConflictingFunc(ctx, clusterAddress, accountID)
}
return false, nil
}
func (m *mockStore) DeleteAccountCluster(ctx context.Context, clusterAddress, accountID string) error {
if m.deleteAccountClusterFunc != nil {
return m.deleteAccountClusterFunc(ctx, clusterAddress, accountID)
}
return nil
}
func (m *mockStore) GetClusterSupportsCustomPorts(_ context.Context, _ string) *bool {
return nil
}
func (m *mockStore) GetClusterRequireSubdomain(_ context.Context, _ string) *bool {
return nil
}
func (m *mockStore) GetClusterSupportsCrowdSec(_ context.Context, _ string) *bool {
return nil
}
func newTestManager(s store) *Manager {
meter := noop.NewMeterProvider().Meter("test")
m, err := NewManager(s, meter)
if err != nil {
panic(err)
}
return m
}
func TestConnect_WithAccountID(t *testing.T) {
accountID := "acc-123"
var savedProxy *proxy.Proxy
s := &mockStore{
saveProxyFunc: func(_ context.Context, p *proxy.Proxy) error {
savedProxy = p
return nil
},
}
mgr := newTestManager(s)
_, err := mgr.Connect(context.Background(), "proxy-1", "session-1", "cluster.example.com", "10.0.0.1", &accountID, nil)
require.NoError(t, err)
require.NotNil(t, savedProxy)
assert.Equal(t, "proxy-1", savedProxy.ID)
assert.Equal(t, "session-1", savedProxy.SessionID)
assert.Equal(t, "cluster.example.com", savedProxy.ClusterAddress)
assert.Equal(t, "10.0.0.1", savedProxy.IPAddress)
assert.Equal(t, &accountID, savedProxy.AccountID)
assert.Equal(t, proxy.StatusConnected, savedProxy.Status)
assert.NotNil(t, savedProxy.ConnectedAt)
}
func TestConnect_WithoutAccountID(t *testing.T) {
var savedProxy *proxy.Proxy
s := &mockStore{
saveProxyFunc: func(_ context.Context, p *proxy.Proxy) error {
savedProxy = p
return nil
},
}
mgr := newTestManager(s)
_, err := mgr.Connect(context.Background(), "proxy-1", "session-1", "eu.proxy.netbird.io", "10.0.0.1", nil, nil)
require.NoError(t, err)
require.NotNil(t, savedProxy)
assert.Nil(t, savedProxy.AccountID)
assert.Equal(t, proxy.StatusConnected, savedProxy.Status)
}
func TestConnect_StoreError(t *testing.T) {
s := &mockStore{
saveProxyFunc: func(_ context.Context, _ *proxy.Proxy) error {
return errors.New("db error")
},
}
mgr := newTestManager(s)
_, err := mgr.Connect(context.Background(), "proxy-1", "session-1", "cluster.example.com", "10.0.0.1", nil, nil)
assert.Error(t, err)
}
func TestIsClusterAddressAvailable(t *testing.T) {
tests := []struct {
name string
conflicting bool
storeErr error
wantResult bool
wantErr bool
}{
{
name: "available - no conflict",
conflicting: false,
wantResult: true,
},
{
name: "not available - conflict exists",
conflicting: true,
wantResult: false,
},
{
name: "store error",
storeErr: errors.New("db error"),
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
s := &mockStore{
isClusterAddressConflictingFunc: func(_ context.Context, _, _ string) (bool, error) {
return tt.conflicting, tt.storeErr
},
}
mgr := newTestManager(s)
result, err := mgr.IsClusterAddressAvailable(context.Background(), "cluster.example.com", "acc-123")
if tt.wantErr {
assert.Error(t, err)
return
}
require.NoError(t, err)
assert.Equal(t, tt.wantResult, result)
})
}
}
func TestCountAccountProxies(t *testing.T) {
tests := []struct {
name string
count int64
storeErr error
wantCount int64
wantErr bool
}{
{
name: "no proxies",
count: 0,
wantCount: 0,
},
{
name: "one proxy",
count: 1,
wantCount: 1,
},
{
name: "store error",
storeErr: errors.New("db error"),
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
s := &mockStore{
countProxiesByAccountIDFunc: func(_ context.Context, _ string) (int64, error) {
return tt.count, tt.storeErr
},
}
mgr := newTestManager(s)
count, err := mgr.CountAccountProxies(context.Background(), "acc-123")
if tt.wantErr {
assert.Error(t, err)
return
}
require.NoError(t, err)
assert.Equal(t, tt.wantCount, count)
})
}
}
func TestGetAccountProxy(t *testing.T) {
accountID := "acc-123"
t.Run("found", func(t *testing.T) {
expected := &proxy.Proxy{
ID: "proxy-1",
ClusterAddress: "byop.example.com",
AccountID: &accountID,
Status: proxy.StatusConnected,
}
s := &mockStore{
getProxyByAccountIDFunc: func(_ context.Context, accID string) (*proxy.Proxy, error) {
assert.Equal(t, accountID, accID)
return expected, nil
},
}
mgr := newTestManager(s)
p, err := mgr.GetAccountProxy(context.Background(), accountID)
require.NoError(t, err)
assert.Equal(t, expected, p)
})
t.Run("not found", func(t *testing.T) {
s := &mockStore{
getProxyByAccountIDFunc: func(_ context.Context, _ string) (*proxy.Proxy, error) {
return nil, errors.New("not found")
},
}
mgr := newTestManager(s)
_, err := mgr.GetAccountProxy(context.Background(), accountID)
assert.Error(t, err)
})
}
func TestDeleteAccountCluster(t *testing.T) {
t.Run("success", func(t *testing.T) {
var deletedCluster, deletedAccount string
s := &mockStore{
deleteAccountClusterFunc: func(_ context.Context, clusterAddress, accountID string) error {
deletedCluster = clusterAddress
deletedAccount = accountID
return nil
},
}
mgr := newTestManager(s)
err := mgr.DeleteAccountCluster(context.Background(), "cluster.example.com", "acc-123")
require.NoError(t, err)
assert.Equal(t, "cluster.example.com", deletedCluster)
assert.Equal(t, "acc-123", deletedAccount)
})
t.Run("store error", func(t *testing.T) {
s := &mockStore{
deleteAccountClusterFunc: func(_ context.Context, _, _ string) error {
return errors.New("db error")
},
}
mgr := newTestManager(s)
err := mgr.DeleteAccountCluster(context.Background(), "cluster.example.com", "acc-123")
assert.Error(t, err)
})
}
func TestGetActiveClusterAddressesForAccount(t *testing.T) {
expected := []string{"byop.example.com"}
s := &mockStore{
getActiveProxyClusterAddressesForAccFunc: func(_ context.Context, accID string) ([]string, error) {
assert.Equal(t, "acc-123", accID)
return expected, nil
},
}
mgr := newTestManager(s)
result, err := mgr.GetActiveClusterAddressesForAccount(context.Background(), "acc-123")
require.NoError(t, err)
assert.Equal(t, expected, result)
}

View File

@@ -93,18 +93,18 @@ func (mr *MockManagerMockRecorder) ClusterSupportsCrowdSec(ctx, clusterAddr inte
}
// Connect mocks base method.
func (m *MockManager) Connect(ctx context.Context, proxyID, sessionID, clusterAddress, ipAddress string, capabilities *Capabilities) (*Proxy, error) {
func (m *MockManager) Connect(ctx context.Context, proxyID, sessionID, clusterAddress, ipAddress string, accountID *string, capabilities *Capabilities) (*Proxy, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Connect", ctx, proxyID, sessionID, clusterAddress, ipAddress, capabilities)
ret := m.ctrl.Call(m, "Connect", ctx, proxyID, sessionID, clusterAddress, ipAddress, accountID, capabilities)
ret0, _ := ret[0].(*Proxy)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// Connect indicates an expected call of Connect.
func (mr *MockManagerMockRecorder) Connect(ctx, proxyID, sessionID, clusterAddress, ipAddress, capabilities interface{}) *gomock.Call {
func (mr *MockManagerMockRecorder) Connect(ctx, proxyID, sessionID, clusterAddress, ipAddress, accountID, capabilities interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Connect", reflect.TypeOf((*MockManager)(nil).Connect), ctx, proxyID, sessionID, clusterAddress, ipAddress, capabilities)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Connect", reflect.TypeOf((*MockManager)(nil).Connect), ctx, proxyID, sessionID, clusterAddress, ipAddress, accountID, capabilities)
}
// Disconnect mocks base method.
@@ -136,19 +136,17 @@ func (mr *MockManagerMockRecorder) GetActiveClusterAddresses(ctx interface{}) *g
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetActiveClusterAddresses", reflect.TypeOf((*MockManager)(nil).GetActiveClusterAddresses), ctx)
}
// GetActiveClusters mocks base method.
func (m *MockManager) GetActiveClusters(ctx context.Context) ([]Cluster, error) {
func (m *MockManager) GetActiveClusterAddressesForAccount(ctx context.Context, accountID string) ([]string, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetActiveClusters", ctx)
ret0, _ := ret[0].([]Cluster)
ret := m.ctrl.Call(m, "GetActiveClusterAddressesForAccount", ctx, accountID)
ret0, _ := ret[0].([]string)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetActiveClusters indicates an expected call of GetActiveClusters.
func (mr *MockManagerMockRecorder) GetActiveClusters(ctx interface{}) *gomock.Call {
func (mr *MockManagerMockRecorder) GetActiveClusterAddressesForAccount(ctx, accountID interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetActiveClusters", reflect.TypeOf((*MockManager)(nil).GetActiveClusters), ctx)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetActiveClusterAddressesForAccount", reflect.TypeOf((*MockManager)(nil).GetActiveClusterAddressesForAccount), ctx, accountID)
}
// Heartbeat mocks base method.
@@ -165,6 +163,65 @@ func (mr *MockManagerMockRecorder) Heartbeat(ctx, p interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Heartbeat", reflect.TypeOf((*MockManager)(nil).Heartbeat), ctx, p)
}
// GetAccountProxy mocks base method.
func (m *MockManager) GetAccountProxy(ctx context.Context, accountID string) (*Proxy, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetAccountProxy", ctx, accountID)
ret0, _ := ret[0].(*Proxy)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetAccountProxy indicates an expected call of GetAccountProxy.
func (mr *MockManagerMockRecorder) GetAccountProxy(ctx, accountID interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAccountProxy", reflect.TypeOf((*MockManager)(nil).GetAccountProxy), ctx, accountID)
}
// CountAccountProxies mocks base method.
func (m *MockManager) CountAccountProxies(ctx context.Context, accountID string) (int64, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "CountAccountProxies", ctx, accountID)
ret0, _ := ret[0].(int64)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// CountAccountProxies indicates an expected call of CountAccountProxies.
func (mr *MockManagerMockRecorder) CountAccountProxies(ctx, accountID interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CountAccountProxies", reflect.TypeOf((*MockManager)(nil).CountAccountProxies), ctx, accountID)
}
// IsClusterAddressAvailable mocks base method.
func (m *MockManager) IsClusterAddressAvailable(ctx context.Context, clusterAddress, accountID string) (bool, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "IsClusterAddressAvailable", ctx, clusterAddress, accountID)
ret0, _ := ret[0].(bool)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// IsClusterAddressAvailable indicates an expected call of IsClusterAddressAvailable.
func (mr *MockManagerMockRecorder) IsClusterAddressAvailable(ctx, clusterAddress, accountID interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsClusterAddressAvailable", reflect.TypeOf((*MockManager)(nil).IsClusterAddressAvailable), ctx, clusterAddress, accountID)
}
// DeleteAccountCluster mocks base method.
func (m *MockManager) DeleteAccountCluster(ctx context.Context, clusterAddress, accountID string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "DeleteAccountCluster", ctx, clusterAddress, accountID)
ret0, _ := ret[0].(error)
return ret0
}
// DeleteAccountCluster indicates an expected call of DeleteAccountCluster.
func (mr *MockManagerMockRecorder) DeleteAccountCluster(ctx, clusterAddress, accountID interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteAccountCluster", reflect.TypeOf((*MockManager)(nil).DeleteAccountCluster), ctx, clusterAddress, accountID)
}
// MockController is a mock of Controller interface.
type MockController struct {
ctrl *gomock.Controller

View File

@@ -1,6 +1,13 @@
package proxy
import "time"
import (
"time"
)
const (
StatusConnected = "connected"
StatusDisconnected = "disconnected"
)
// Capabilities describes what a proxy can handle, as reported via gRPC.
// Nil fields mean the proxy never reported this capability.
@@ -21,6 +28,7 @@ type Proxy struct {
SessionID string `gorm:"type:varchar(36)"`
ClusterAddress string `gorm:"type:varchar(255);not null;index:idx_proxy_cluster_status"`
IPAddress string `gorm:"type:varchar(45)"`
AccountID *string `gorm:"type:varchar(255);index:idx_proxy_account_id"`
LastSeen time.Time `gorm:"not null;index:idx_proxy_last_seen"`
ConnectedAt *time.Time
DisconnectedAt *time.Time
@@ -36,6 +44,8 @@ func (Proxy) TableName() string {
// Cluster represents a group of proxy nodes serving the same address.
type Cluster struct {
ID string
Address string
ConnectedProxies int
SelfHosted bool
}

View File

@@ -0,0 +1,195 @@
package proxytoken
import (
"encoding/json"
"net/http"
"time"
"github.com/gorilla/mux"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/http/util"
"github.com/netbirdio/netbird/shared/management/status"
)
type handler struct {
store store.Store
permissionsManager permissions.Manager
}
func RegisterEndpoints(s store.Store, permissionsManager permissions.Manager, router *mux.Router) {
h := &handler{store: s, permissionsManager: permissionsManager}
router.HandleFunc("/reverse-proxies/proxy-tokens", h.listTokens).Methods("GET", "OPTIONS")
router.HandleFunc("/reverse-proxies/proxy-tokens", h.createToken).Methods("POST", "OPTIONS")
router.HandleFunc("/reverse-proxies/proxy-tokens/{tokenId}", h.revokeToken).Methods("DELETE", "OPTIONS")
}
func (h *handler) createToken(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
ok, err := h.permissionsManager.ValidateUserPermissions(r.Context(), userAuth.AccountId, userAuth.UserId, modules.Services, operations.Create)
if err != nil {
util.WriteErrorResponse("failed to validate permissions", http.StatusInternalServerError, w)
return
}
if !ok {
util.WriteErrorResponse("permission denied", http.StatusForbidden, w)
return
}
var req api.ProxyTokenRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
return
}
if req.Name == "" || len(req.Name) > 255 {
util.WriteErrorResponse("name is required and must be at most 255 characters", http.StatusBadRequest, w)
return
}
var expiresIn time.Duration
if req.ExpiresIn != nil {
if *req.ExpiresIn < 0 {
util.WriteErrorResponse("expires_in must be non-negative", http.StatusBadRequest, w)
return
}
if *req.ExpiresIn > 0 {
expiresIn = time.Duration(*req.ExpiresIn) * time.Second
}
}
accountID := userAuth.AccountId
generated, err := types.CreateNewProxyAccessToken(req.Name, expiresIn, &accountID, userAuth.UserId)
if err != nil {
util.WriteErrorResponse("failed to generate token", http.StatusInternalServerError, w)
return
}
if err := h.store.SaveProxyAccessToken(r.Context(), &generated.ProxyAccessToken); err != nil {
util.WriteErrorResponse("failed to save token", http.StatusInternalServerError, w)
return
}
resp := toProxyTokenCreatedResponse(generated)
util.WriteJSONObject(r.Context(), w, resp)
}
func (h *handler) listTokens(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
ok, err := h.permissionsManager.ValidateUserPermissions(r.Context(), userAuth.AccountId, userAuth.UserId, modules.Services, operations.Read)
if err != nil {
util.WriteErrorResponse("failed to validate permissions", http.StatusInternalServerError, w)
return
}
if !ok {
util.WriteErrorResponse("permission denied", http.StatusForbidden, w)
return
}
tokens, err := h.store.GetProxyAccessTokensByAccountID(r.Context(), store.LockingStrengthNone, userAuth.AccountId)
if err != nil {
util.WriteErrorResponse("failed to list tokens", http.StatusInternalServerError, w)
return
}
resp := make([]api.ProxyToken, 0, len(tokens))
for _, token := range tokens {
resp = append(resp, toProxyTokenResponse(token))
}
util.WriteJSONObject(r.Context(), w, resp)
}
func (h *handler) revokeToken(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
ok, err := h.permissionsManager.ValidateUserPermissions(r.Context(), userAuth.AccountId, userAuth.UserId, modules.Services, operations.Delete)
if err != nil {
util.WriteErrorResponse("failed to validate permissions", http.StatusInternalServerError, w)
return
}
if !ok {
util.WriteErrorResponse("permission denied", http.StatusForbidden, w)
return
}
tokenID := mux.Vars(r)["tokenId"]
if tokenID == "" {
util.WriteErrorResponse("token ID is required", http.StatusBadRequest, w)
return
}
token, err := h.store.GetProxyAccessTokenByID(r.Context(), store.LockingStrengthNone, tokenID)
if err != nil {
if s, ok := status.FromError(err); ok && s.ErrorType == status.NotFound {
util.WriteErrorResponse("token not found", http.StatusNotFound, w)
} else {
util.WriteErrorResponse("failed to retrieve token", http.StatusInternalServerError, w)
}
return
}
if token.AccountID == nil || *token.AccountID != userAuth.AccountId {
util.WriteErrorResponse("token not found", http.StatusNotFound, w)
return
}
if err := h.store.RevokeProxyAccessToken(r.Context(), tokenID); err != nil {
util.WriteErrorResponse("failed to revoke token", http.StatusInternalServerError, w)
return
}
util.WriteJSONObject(r.Context(), w, util.EmptyObject{})
}
func toProxyTokenResponse(token *types.ProxyAccessToken) api.ProxyToken {
resp := api.ProxyToken{
Id: token.ID,
Name: token.Name,
Revoked: token.Revoked,
}
if !token.CreatedAt.IsZero() {
resp.CreatedAt = token.CreatedAt
}
if token.ExpiresAt != nil {
resp.ExpiresAt = token.ExpiresAt
}
if token.LastUsed != nil {
resp.LastUsed = token.LastUsed
}
return resp
}
func toProxyTokenCreatedResponse(generated *types.ProxyAccessTokenGenerated) api.ProxyTokenCreated {
base := toProxyTokenResponse(&generated.ProxyAccessToken)
plainToken := string(generated.PlainToken)
return api.ProxyTokenCreated{
Id: base.Id,
Name: base.Name,
CreatedAt: base.CreatedAt,
ExpiresAt: base.ExpiresAt,
LastUsed: base.LastUsed,
Revoked: base.Revoked,
PlainToken: plainToken,
}
}

View File

@@ -0,0 +1,275 @@
package proxytoken
import (
"bytes"
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/golang/mock/gomock"
"github.com/gorilla/mux"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/auth"
"github.com/netbirdio/netbird/shared/management/http/api"
)
func authContext(accountID, userID string) context.Context {
return nbcontext.SetUserAuthInContext(context.Background(), auth.UserAuth{
AccountId: accountID,
UserId: userID,
})
}
func TestCreateToken_AccountScoped(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
accountID := "acc-123"
var savedToken *types.ProxyAccessToken
mockStore := store.NewMockStore(ctrl)
mockStore.EXPECT().SaveProxyAccessToken(gomock.Any(), gomock.Any()).DoAndReturn(
func(_ context.Context, token *types.ProxyAccessToken) error {
savedToken = token
return nil
},
)
permsMgr := permissions.NewMockManager(ctrl)
permsMgr.EXPECT().ValidateUserPermissions(gomock.Any(), accountID, "user-1", modules.Services, operations.Create).Return(true, nil)
h := &handler{
store: mockStore,
permissionsManager: permsMgr,
}
body := `{"name": "my-token"}`
req := httptest.NewRequest("POST", "/reverse-proxies/proxy-tokens", bytes.NewBufferString(body))
req = req.WithContext(authContext(accountID, "user-1"))
w := httptest.NewRecorder()
h.createToken(w, req)
assert.Equal(t, http.StatusOK, w.Code)
var resp api.ProxyTokenCreated
require.NoError(t, json.NewDecoder(w.Body).Decode(&resp))
assert.NotEmpty(t, resp.PlainToken)
assert.Equal(t, "my-token", resp.Name)
assert.False(t, resp.Revoked)
require.NotNil(t, savedToken)
require.NotNil(t, savedToken.AccountID)
assert.Equal(t, accountID, *savedToken.AccountID)
assert.Equal(t, "user-1", savedToken.CreatedBy)
}
func TestCreateToken_WithExpiration(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
var savedToken *types.ProxyAccessToken
mockStore := store.NewMockStore(ctrl)
mockStore.EXPECT().SaveProxyAccessToken(gomock.Any(), gomock.Any()).DoAndReturn(
func(_ context.Context, token *types.ProxyAccessToken) error {
savedToken = token
return nil
},
)
permsMgr := permissions.NewMockManager(ctrl)
permsMgr.EXPECT().ValidateUserPermissions(gomock.Any(), "acc-123", "user-1", modules.Services, operations.Create).Return(true, nil)
h := &handler{
store: mockStore,
permissionsManager: permsMgr,
}
body := `{"name": "expiring-token", "expires_in": 3600}`
req := httptest.NewRequest("POST", "/reverse-proxies/proxy-tokens", bytes.NewBufferString(body))
req = req.WithContext(authContext("acc-123", "user-1"))
w := httptest.NewRecorder()
h.createToken(w, req)
assert.Equal(t, http.StatusOK, w.Code)
require.NotNil(t, savedToken)
require.NotNil(t, savedToken.ExpiresAt)
assert.True(t, savedToken.ExpiresAt.After(time.Now()))
}
func TestCreateToken_EmptyName(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
permsMgr := permissions.NewMockManager(ctrl)
permsMgr.EXPECT().ValidateUserPermissions(gomock.Any(), "acc-123", "user-1", modules.Services, operations.Create).Return(true, nil)
h := &handler{
permissionsManager: permsMgr,
}
body := `{"name": ""}`
req := httptest.NewRequest("POST", "/reverse-proxies/proxy-tokens", bytes.NewBufferString(body))
req = req.WithContext(authContext("acc-123", "user-1"))
w := httptest.NewRecorder()
h.createToken(w, req)
assert.Equal(t, http.StatusBadRequest, w.Code)
}
func TestCreateToken_PermissionDenied(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
permsMgr := permissions.NewMockManager(ctrl)
permsMgr.EXPECT().ValidateUserPermissions(gomock.Any(), "acc-123", "user-1", modules.Services, operations.Create).Return(false, nil)
h := &handler{
permissionsManager: permsMgr,
}
body := `{"name": "test"}`
req := httptest.NewRequest("POST", "/reverse-proxies/proxy-tokens", bytes.NewBufferString(body))
req = req.WithContext(authContext("acc-123", "user-1"))
w := httptest.NewRecorder()
h.createToken(w, req)
assert.Equal(t, http.StatusForbidden, w.Code)
}
func TestListTokens(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
accountID := "acc-123"
now := time.Now()
mockStore := store.NewMockStore(ctrl)
mockStore.EXPECT().GetProxyAccessTokensByAccountID(gomock.Any(), store.LockingStrengthNone, accountID).Return([]*types.ProxyAccessToken{
{ID: "tok-1", Name: "token-1", AccountID: &accountID, CreatedAt: now, Revoked: false},
{ID: "tok-2", Name: "token-2", AccountID: &accountID, CreatedAt: now, Revoked: true},
}, nil)
permsMgr := permissions.NewMockManager(ctrl)
permsMgr.EXPECT().ValidateUserPermissions(gomock.Any(), accountID, "user-1", modules.Services, operations.Read).Return(true, nil)
h := &handler{
store: mockStore,
permissionsManager: permsMgr,
}
req := httptest.NewRequest("GET", "/reverse-proxies/proxy-tokens", nil)
req = req.WithContext(authContext(accountID, "user-1"))
w := httptest.NewRecorder()
h.listTokens(w, req)
assert.Equal(t, http.StatusOK, w.Code)
var resp []api.ProxyToken
require.NoError(t, json.NewDecoder(w.Body).Decode(&resp))
require.Len(t, resp, 2)
assert.Equal(t, "tok-1", resp[0].Id)
assert.False(t, resp[0].Revoked)
assert.Equal(t, "tok-2", resp[1].Id)
assert.True(t, resp[1].Revoked)
}
func TestRevokeToken_Success(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
accountID := "acc-123"
mockStore := store.NewMockStore(ctrl)
mockStore.EXPECT().GetProxyAccessTokenByID(gomock.Any(), store.LockingStrengthNone, "tok-1").Return(&types.ProxyAccessToken{
ID: "tok-1",
Name: "test-token",
AccountID: &accountID,
}, nil)
mockStore.EXPECT().RevokeProxyAccessToken(gomock.Any(), "tok-1").Return(nil)
permsMgr := permissions.NewMockManager(ctrl)
permsMgr.EXPECT().ValidateUserPermissions(gomock.Any(), accountID, "user-1", modules.Services, operations.Delete).Return(true, nil)
h := &handler{
store: mockStore,
permissionsManager: permsMgr,
}
req := httptest.NewRequest("DELETE", "/reverse-proxies/proxy-tokens/tok-1", nil)
req = req.WithContext(authContext(accountID, "user-1"))
req = mux.SetURLVars(req, map[string]string{"tokenId": "tok-1"})
w := httptest.NewRecorder()
h.revokeToken(w, req)
assert.Equal(t, http.StatusOK, w.Code)
}
func TestRevokeToken_WrongAccount(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
otherAccount := "acc-other"
mockStore := store.NewMockStore(ctrl)
mockStore.EXPECT().GetProxyAccessTokenByID(gomock.Any(), store.LockingStrengthNone, "tok-1").Return(&types.ProxyAccessToken{
ID: "tok-1",
AccountID: &otherAccount,
}, nil)
permsMgr := permissions.NewMockManager(ctrl)
permsMgr.EXPECT().ValidateUserPermissions(gomock.Any(), "acc-123", "user-1", modules.Services, operations.Delete).Return(true, nil)
h := &handler{
store: mockStore,
permissionsManager: permsMgr,
}
req := httptest.NewRequest("DELETE", "/reverse-proxies/proxy-tokens/tok-1", nil)
req = req.WithContext(authContext("acc-123", "user-1"))
req = mux.SetURLVars(req, map[string]string{"tokenId": "tok-1"})
w := httptest.NewRecorder()
h.revokeToken(w, req)
assert.Equal(t, http.StatusNotFound, w.Code)
}
func TestRevokeToken_ManagementWideToken(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
mockStore := store.NewMockStore(ctrl)
mockStore.EXPECT().GetProxyAccessTokenByID(gomock.Any(), store.LockingStrengthNone, "tok-1").Return(&types.ProxyAccessToken{
ID: "tok-1",
AccountID: nil,
}, nil)
permsMgr := permissions.NewMockManager(ctrl)
permsMgr.EXPECT().ValidateUserPermissions(gomock.Any(), "acc-123", "user-1", modules.Services, operations.Delete).Return(true, nil)
h := &handler{
store: mockStore,
permissionsManager: permsMgr,
}
req := httptest.NewRequest("DELETE", "/reverse-proxies/proxy-tokens/tok-1", nil)
req = req.WithContext(authContext("acc-123", "user-1"))
req = mux.SetURLVars(req, map[string]string{"tokenId": "tok-1"})
w := httptest.NewRecorder()
h.revokeToken(w, req)
assert.Equal(t, http.StatusNotFound, w.Code)
}

View File

@@ -10,6 +10,7 @@ import (
type Manager interface {
GetActiveClusters(ctx context.Context, accountID, userID string) ([]proxy.Cluster, error)
DeleteAccountCluster(ctx context.Context, accountID, userID, clusterAddress string) error
GetAllServices(ctx context.Context, accountID, userID string) ([]*Service, error)
GetService(ctx context.Context, accountID, userID, serviceID string) (*Service, error)
CreateService(ctx context.Context, accountID, userID string, service *Service) (*Service, error)
@@ -28,4 +29,5 @@ type Manager interface {
RenewServiceFromPeer(ctx context.Context, accountID, peerID, serviceID string) error
StopServiceFromPeer(ctx context.Context, accountID, peerID, serviceID string) error
StartExposeReaper(ctx context.Context)
GetServiceByDomain(ctx context.Context, domain string) (*Service, error)
}

View File

@@ -79,6 +79,20 @@ func (mr *MockManagerMockRecorder) DeleteAllServices(ctx, accountID, userID inte
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteAllServices", reflect.TypeOf((*MockManager)(nil).DeleteAllServices), ctx, accountID, userID)
}
// DeleteAccountCluster mocks base method.
func (m *MockManager) DeleteAccountCluster(ctx context.Context, accountID, userID, clusterAddress string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "DeleteAccountCluster", ctx, accountID, userID, clusterAddress)
ret0, _ := ret[0].(error)
return ret0
}
// DeleteAccountCluster indicates an expected call of DeleteAccountCluster.
func (mr *MockManagerMockRecorder) DeleteAccountCluster(ctx, accountID, userID, clusterAddress interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteAccountCluster", reflect.TypeOf((*MockManager)(nil).DeleteAccountCluster), ctx, accountID, userID, clusterAddress)
}
// DeleteService mocks base method.
func (m *MockManager) DeleteService(ctx context.Context, accountID, userID, serviceID string) error {
m.ctrl.T.Helper()
@@ -138,6 +152,21 @@ func (mr *MockManagerMockRecorder) GetAllServices(ctx, accountID, userID interfa
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAllServices", reflect.TypeOf((*MockManager)(nil).GetAllServices), ctx, accountID, userID)
}
// GetServiceByDomain mocks base method.
func (m *MockManager) GetServiceByDomain(ctx context.Context, domain string) (*Service, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetServiceByDomain", ctx, domain)
ret0, _ := ret[0].(*Service)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetServiceByDomain indicates an expected call of GetServiceByDomain.
func (mr *MockManagerMockRecorder) GetServiceByDomain(ctx, domain interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetServiceByDomain", reflect.TypeOf((*MockManager)(nil).GetServiceByDomain), ctx, domain)
}
// GetGlobalServices mocks base method.
func (m *MockManager) GetGlobalServices(ctx context.Context) ([]*Service, error) {
m.ctrl.T.Helper()

View File

@@ -35,6 +35,7 @@ func RegisterEndpoints(manager rpservice.Manager, domainManager domainmanager.Ma
accesslogsmanager.RegisterEndpoints(router, accessLogsManager)
router.HandleFunc("/reverse-proxies/clusters", h.getClusters).Methods("GET", "OPTIONS")
router.HandleFunc("/reverse-proxies/clusters/{clusterAddress}", h.deleteCluster).Methods("DELETE", "OPTIONS")
router.HandleFunc("/reverse-proxies/services", h.getAllServices).Methods("GET", "OPTIONS")
router.HandleFunc("/reverse-proxies/services", h.createService).Methods("POST", "OPTIONS")
router.HandleFunc("/reverse-proxies/services/{serviceId}", h.getService).Methods("GET", "OPTIONS")
@@ -195,10 +196,33 @@ func (h *handler) getClusters(w http.ResponseWriter, r *http.Request) {
apiClusters := make([]api.ProxyCluster, 0, len(clusters))
for _, c := range clusters {
apiClusters = append(apiClusters, api.ProxyCluster{
Id: c.ID,
Address: c.Address,
ConnectedProxies: c.ConnectedProxies,
SelfHosted: c.SelfHosted,
})
}
util.WriteJSONObject(r.Context(), w, apiClusters)
}
func (h *handler) deleteCluster(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
clusterAddress := mux.Vars(r)["clusterAddress"]
if clusterAddress == "" {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "cluster address is required"), w)
return
}
if err := h.manager.DeleteAccountCluster(r.Context(), userAuth.AccountId, userAuth.UserId, clusterAddress); err != nil {
util.WriteError(r.Context(), err, w)
return
}
util.WriteJSONObject(r.Context(), w, util.EmptyObject{})
}

View File

@@ -122,7 +122,21 @@ func (m *Manager) GetActiveClusters(ctx context.Context, accountID, userID strin
return nil, status.NewPermissionDeniedError()
}
return m.store.GetActiveProxyClusters(ctx)
return m.store.GetActiveProxyClusters(ctx, accountID)
}
// DeleteAccountCluster removes all proxy registrations for the given cluster address
// owned by the account.
func (m *Manager) DeleteAccountCluster(ctx context.Context, accountID, userID, clusterAddress string) error {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Delete)
if err != nil {
return status.NewPermissionValidationError(err)
}
if !ok {
return status.NewPermissionDeniedError()
}
return m.store.DeleteAccountCluster(ctx, clusterAddress, accountID)
}
func (m *Manager) GetAllServices(ctx context.Context, accountID, userID string) ([]*service.Service, error) {
@@ -292,6 +306,10 @@ func (m *Manager) validateSubdomainRequirement(ctx context.Context, domain, clus
func (m *Manager) persistNewService(ctx context.Context, accountID string, svc *service.Service) error {
customPorts := m.clusterCustomPorts(ctx, svc)
if err := validateTargetReferences(ctx, m.store, accountID, svc.Targets); err != nil {
return err
}
return m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
if svc.Domain != "" {
if err := m.checkDomainAvailable(ctx, transaction, svc.Domain, ""); err != nil {
@@ -307,10 +325,6 @@ func (m *Manager) persistNewService(ctx context.Context, accountID string, svc *
return err
}
if err := validateTargetReferences(ctx, transaction, accountID, svc.Targets); err != nil {
return err
}
if err := transaction.CreateService(ctx, svc); err != nil {
return fmt.Errorf("create service: %w", err)
}
@@ -421,6 +435,10 @@ func (m *Manager) assignPort(ctx context.Context, tx store.Store, cluster string
func (m *Manager) persistNewEphemeralService(ctx context.Context, accountID, peerID string, svc *service.Service) error {
customPorts := m.clusterCustomPorts(ctx, svc)
if err := validateTargetReferences(ctx, m.store, accountID, svc.Targets); err != nil {
return err
}
return m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
if err := m.validateEphemeralPreconditions(ctx, transaction, accountID, peerID, svc); err != nil {
return err
@@ -434,10 +452,6 @@ func (m *Manager) persistNewEphemeralService(ctx context.Context, accountID, pee
return err
}
if err := validateTargetReferences(ctx, transaction, accountID, svc.Targets); err != nil {
return err
}
if err := transaction.CreateService(ctx, svc); err != nil {
return fmt.Errorf("create service: %w", err)
}
@@ -538,10 +552,22 @@ func (m *Manager) persistServiceUpdate(ctx context.Context, accountID string, se
svcForCaps.ProxyCluster = effectiveCluster
customPorts := m.clusterCustomPorts(ctx, &svcForCaps)
if err := validateTargetReferences(ctx, m.store, accountID, service.Targets); err != nil {
return nil, err
}
// Validate subdomain requirement *before* the transaction: the underlying
// capability lookup talks to the main DB pool, and SQLite's single-connection
// pool would self-deadlock if this ran while the tx already held the only
// connection.
if err := m.validateSubdomainRequirement(ctx, service.Domain, effectiveCluster); err != nil {
return nil, err
}
var updateInfo serviceUpdateInfo
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
return m.executeServiceUpdate(ctx, transaction, accountID, service, &updateInfo, customPorts)
return m.executeServiceUpdate(ctx, transaction, accountID, service, &updateInfo, customPorts, effectiveCluster)
})
return &updateInfo, err
@@ -571,7 +597,7 @@ func (m *Manager) resolveEffectiveCluster(ctx context.Context, accountID string,
return existing.ProxyCluster, nil
}
func (m *Manager) executeServiceUpdate(ctx context.Context, transaction store.Store, accountID string, service *service.Service, updateInfo *serviceUpdateInfo, customPorts *bool) error {
func (m *Manager) executeServiceUpdate(ctx context.Context, transaction store.Store, accountID string, service *service.Service, updateInfo *serviceUpdateInfo, customPorts *bool, effectiveCluster string) error {
existingService, err := transaction.GetServiceByID(ctx, store.LockingStrengthUpdate, accountID, service.ID)
if err != nil {
return err
@@ -589,17 +615,13 @@ func (m *Manager) executeServiceUpdate(ctx context.Context, transaction store.St
updateInfo.domainChanged = existingService.Domain != service.Domain
if updateInfo.domainChanged {
if err := m.handleDomainChange(ctx, transaction, accountID, service); err != nil {
if err := m.handleDomainChange(ctx, transaction, service, effectiveCluster); err != nil {
return err
}
} else {
service.ProxyCluster = existingService.ProxyCluster
}
if err := m.validateSubdomainRequirement(ctx, service.Domain, service.ProxyCluster); err != nil {
return err
}
m.preserveExistingAuthSecrets(service, existingService)
if err := validateHeaderAuthValues(service.Auth.HeaderAuths); err != nil {
return err
@@ -614,9 +636,6 @@ func (m *Manager) executeServiceUpdate(ctx context.Context, transaction store.St
if err := m.checkPortConflict(ctx, transaction, service); err != nil {
return err
}
if err := validateTargetReferences(ctx, transaction, accountID, service.Targets); err != nil {
return err
}
if err := transaction.UpdateService(ctx, service); err != nil {
return fmt.Errorf("update service: %w", err)
}
@@ -624,20 +643,18 @@ func (m *Manager) executeServiceUpdate(ctx context.Context, transaction store.St
return nil
}
func (m *Manager) handleDomainChange(ctx context.Context, transaction store.Store, accountID string, svc *service.Service) error {
// handleDomainChange validates the new domain is free inside the transaction
// and applies the pre-resolved cluster (computed outside the tx by
// resolveEffectiveCluster). It must NOT call clusterDeriver here: that talks
// to the main DB pool and would self-deadlock under SQLite (max_open_conns=1)
// because the transaction already holds the only connection.
func (m *Manager) handleDomainChange(ctx context.Context, transaction store.Store, svc *service.Service, effectiveCluster string) error {
if err := m.checkDomainAvailable(ctx, transaction, svc.Domain, svc.ID); err != nil {
return err
}
if m.clusterDeriver != nil {
newCluster, err := m.clusterDeriver.DeriveClusterFromDomain(ctx, accountID, svc.Domain)
if err != nil {
log.WithError(err).Warnf("could not derive cluster from domain %s", svc.Domain)
} else {
svc.ProxyCluster = newCluster
}
if effectiveCluster != "" {
svc.ProxyCluster = effectiveCluster
}
return nil
}
@@ -986,6 +1003,10 @@ func (m *Manager) GetAccountServices(ctx context.Context, accountID string) ([]*
return services, nil
}
func (m *Manager) GetServiceByDomain(ctx context.Context, domain string) (*service.Service, error) {
return m.store.GetServiceByDomain(ctx, domain)
}
func (m *Manager) GetServiceIDByTargetID(ctx context.Context, accountID string, resourceID string) (string, error) {
target, err := m.store.GetServiceTargetByTargetID(ctx, store.LockingStrengthNone, accountID, resourceID)
if err != nil {

View File

@@ -434,7 +434,7 @@ func TestDeletePeerService_SourcePeerValidation(t *testing.T) {
t.Helper()
tokenStore := nbgrpc.NewOneTimeTokenStore(context.Background(), testCacheStore(t))
pkceStore := nbgrpc.NewPKCEVerifierStore(context.Background(), testCacheStore(t))
srv := nbgrpc.NewProxyServiceServer(nil, tokenStore, pkceStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil)
srv := nbgrpc.NewProxyServiceServer(nil, tokenStore, pkceStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil, nil)
return srv
}
@@ -714,7 +714,7 @@ func setupIntegrationTest(t *testing.T) (*Manager, store.Store) {
tokenStore := nbgrpc.NewOneTimeTokenStore(ctx, testCacheStore(t))
pkceStore := nbgrpc.NewPKCEVerifierStore(ctx, testCacheStore(t))
proxySrv := nbgrpc.NewProxyServiceServer(nil, tokenStore, pkceStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil)
proxySrv := nbgrpc.NewProxyServiceServer(nil, tokenStore, pkceStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil, nil)
proxyController, err := proxymanager.NewGRPCController(proxySrv, noop.NewMeterProvider().Meter(""))
require.NoError(t, err)
@@ -1138,7 +1138,7 @@ func TestDeleteService_DeletesTargets(t *testing.T) {
tokenStore := nbgrpc.NewOneTimeTokenStore(ctx, testCacheStore(t))
pkceStore := nbgrpc.NewPKCEVerifierStore(ctx, testCacheStore(t))
proxySrv := nbgrpc.NewProxyServiceServer(nil, tokenStore, pkceStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil)
proxySrv := nbgrpc.NewProxyServiceServer(nil, tokenStore, pkceStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil, nil)
proxyController, err := proxymanager.NewGRPCController(proxySrv, noop.NewMeterProvider().Meter(""))
require.NoError(t, err)

View File

@@ -381,13 +381,14 @@ func (s *Service) buildPathMappings() []*proto.PathMapping {
}
// HTTP/HTTPS: build full URL
hostNoBrackets := strings.TrimSuffix(strings.TrimPrefix(target.Host, "["), "]")
targetURL := url.URL{
Scheme: target.Protocol,
Host: target.Host,
Host: bracketIPv6Host(hostNoBrackets),
Path: "/",
}
if target.Port > 0 && !isDefaultPort(target.Protocol, target.Port) {
targetURL.Host = net.JoinHostPort(targetURL.Host, strconv.FormatUint(uint64(target.Port), 10))
targetURL.Host = net.JoinHostPort(hostNoBrackets, strconv.FormatUint(uint64(target.Port), 10))
}
path := "/"
@@ -405,6 +406,19 @@ func (s *Service) buildPathMappings() []*proto.PathMapping {
return pathMappings
}
// bracketIPv6Host wraps host in square brackets when it is an IPv6 literal, as
// required for the Host field of net/url.URL (RFC 3986 §3.2.2). v4-mapped IPv6
// addresses are bracketed too since their textual form contains colons.
func bracketIPv6Host(host string) string {
if strings.HasPrefix(host, "[") {
return host
}
if addr, err := netip.ParseAddr(host); err == nil && addr.Is6() {
return "[" + host + "]"
}
return host
}
func operationToProtoType(op Operation) proto.ProxyMappingUpdateType {
switch op {
case Create:

View File

@@ -351,6 +351,83 @@ func TestToProtoMapping_PortInTargetURL(t *testing.T) {
port: 80,
wantTarget: "https://10.0.0.1:80/",
},
{
name: "domain host without port is unchanged",
protocol: "http",
host: "example.com",
port: 0,
wantTarget: "http://example.com/",
},
{
name: "domain host with non-default port is unchanged",
protocol: "http",
host: "example.com",
port: 8080,
wantTarget: "http://example.com:8080/",
},
{
name: "ipv6 host without port is bracketed",
protocol: "http",
host: "fb00:cafe:1::3",
port: 0,
wantTarget: "http://[fb00:cafe:1::3]/",
},
{
name: "ipv6 host with default port omits port and brackets host",
protocol: "http",
host: "fb00:cafe:1::3",
port: 80,
wantTarget: "http://[fb00:cafe:1::3]/",
},
{
name: "ipv6 host with non-default port is bracketed",
protocol: "http",
host: "fb00:cafe:1::3",
port: 8080,
wantTarget: "http://[fb00:cafe:1::3]:8080/",
},
{
name: "ipv6 loopback without port is bracketed",
protocol: "http",
host: "::1",
port: 0,
wantTarget: "http://[::1]/",
},
{
name: "ipv6 host with 5-digit port is bracketed",
protocol: "http",
host: "fb00:cafe::1",
port: 18080,
wantTarget: "http://[fb00:cafe::1]:18080/",
},
{
name: "pre-bracketed ipv6 without port stays single-bracketed",
protocol: "http",
host: "[fb00:cafe::1]",
port: 0,
wantTarget: "http://[fb00:cafe::1]/",
},
{
name: "pre-bracketed ipv6 with port is not double-bracketed",
protocol: "http",
host: "[fb00:cafe::1]",
port: 8080,
wantTarget: "http://[fb00:cafe::1]:8080/",
},
{
name: "v4-mapped ipv6 host without port is bracketed",
protocol: "http",
host: "::ffff:10.0.0.1",
port: 0,
wantTarget: "http://[::ffff:10.0.0.1]/",
},
{
name: "full-form 8-group ipv6 without port is bracketed",
protocol: "http",
host: "fb00:cafe:1:0:0:0:0:3",
port: 0,
wantTarget: "http://[fb00:cafe:1:0:0:0:0:3]/",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {

View File

@@ -193,7 +193,7 @@ func (s *BaseServer) GRPCServer() *grpc.Server {
func (s *BaseServer) ReverseProxyGRPCServer() *nbgrpc.ProxyServiceServer {
return Create(s, func() *nbgrpc.ProxyServiceServer {
proxyService := nbgrpc.NewProxyServiceServer(s.AccessLogsManager(), s.ProxyTokenStore(), s.PKCEVerifierStore(), s.proxyOIDCConfig(), s.PeersManager(), s.UsersManager(), s.ProxyManager())
proxyService := nbgrpc.NewProxyServiceServer(s.AccessLogsManager(), s.ProxyTokenStore(), s.PKCEVerifierStore(), s.proxyOIDCConfig(), s.PeersManager(), s.UsersManager(), s.ProxyManager(), s.Store())
s.AfterInit(func(s *BaseServer) {
proxyService.SetServiceManager(s.ServiceManager())
proxyService.SetProxyController(s.ServiceProxyController())

View File

@@ -26,6 +26,7 @@ import (
"github.com/netbirdio/netbird/management/server/networks"
"github.com/netbirdio/netbird/management/server/networks/resources"
"github.com/netbirdio/netbird/management/server/networks/routers"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/settings"
@@ -113,30 +114,47 @@ func (s *BaseServer) AccountManager() account.Manager {
})
}
func isMFAEnabledForAccount(accounts []*types.Account) bool {
if len(accounts) != 1 {
return false
}
settings := accounts[0].Settings
return settings != nil && settings.LocalMfaEnabled
}
func (s *BaseServer) IdpManager() idp.Manager {
return Create(s, func() idp.Manager {
var idpManager idp.Manager
var err error
// Use embedded IdP service if embedded Dex is configured and enabled.
// Legacy IdpManager won't be used anymore even if configured.
embeddedEnabled := s.Config.EmbeddedIdP != nil && s.Config.EmbeddedIdP.Enabled
if embeddedEnabled {
idpManager, err = idp.NewEmbeddedIdPManager(context.Background(), s.Config.EmbeddedIdP, s.Metrics())
embeddedMgr, err := idp.NewEmbeddedIdPManager(context.Background(), s.Config.EmbeddedIdP, s.Metrics())
if err != nil {
log.Fatalf("failed to create embedded IDP service: %v", err)
}
return idpManager
if val := isMFAEnabledForAccount(s.Store().GetAllAccounts(context.Background())); val {
if err := embeddedMgr.SetMFAEnabled(context.Background(), val); err != nil {
log.Errorf("failed to set MFA enabled on embedded IDP: %v", err)
}
}
return embeddedMgr
}
// Fall back to external IdP service
if s.Config.IdpManagerConfig != nil {
idpManager, err = idp.NewManager(context.Background(), *s.Config.IdpManagerConfig, s.Metrics())
idpManager, err := idp.NewManager(context.Background(), *s.Config.IdpManagerConfig, s.Metrics())
if err != nil {
log.Fatalf("failed to create IDP service: %v", err)
}
return idpManager
}
return idpManager
return nil
})
}

View File

@@ -9,6 +9,7 @@ import (
"encoding/hex"
"errors"
"fmt"
"net"
"net/http"
"net/url"
"os"
@@ -50,6 +51,11 @@ type ProxyOIDCConfig struct {
KeysLocation string
}
// ProxyTokenChecker checks whether a proxy access token is still valid.
type ProxyTokenChecker interface {
IsProxyAccessTokenValid(ctx context.Context, tokenID string) (bool, error)
}
// ProxyServiceServer implements the ProxyService gRPC server
type ProxyServiceServer struct {
proto.UnimplementedProxyServiceServer
@@ -78,6 +84,9 @@ type ProxyServiceServer struct {
// Store for one-time authentication tokens
tokenStore *OneTimeTokenStore
// Checker for proxy access token validity
tokenChecker ProxyTokenChecker
// OIDC configuration for proxy authentication
oidcConfig ProxyOIDCConfig
@@ -123,6 +132,8 @@ type proxyConnection struct {
proxyID string
sessionID string
address string
accountID *string
tokenID string
capabilities *proto.ProxyCapabilities
stream proto.ProxyService_GetMappingUpdateServer
sendChan chan *proto.GetMappingUpdateResponse
@@ -130,8 +141,19 @@ type proxyConnection struct {
cancel context.CancelFunc
}
func enforceAccountScope(ctx context.Context, requestAccountID string) error {
token := GetProxyTokenFromContext(ctx)
if token == nil || token.AccountID == nil {
return nil
}
if requestAccountID == "" || *token.AccountID != requestAccountID {
return status.Errorf(codes.PermissionDenied, "account-scoped token cannot access account %s", requestAccountID)
}
return nil
}
// NewProxyServiceServer creates a new proxy service server.
func NewProxyServiceServer(accessLogMgr accesslogs.Manager, tokenStore *OneTimeTokenStore, pkceStore *PKCEVerifierStore, oidcConfig ProxyOIDCConfig, peersManager peers.Manager, usersManager users.Manager, proxyMgr proxy.Manager) *ProxyServiceServer {
func NewProxyServiceServer(accessLogMgr accesslogs.Manager, tokenStore *OneTimeTokenStore, pkceStore *PKCEVerifierStore, oidcConfig ProxyOIDCConfig, peersManager peers.Manager, usersManager users.Manager, proxyMgr proxy.Manager, tokenChecker ProxyTokenChecker) *ProxyServiceServer {
ctx, cancel := context.WithCancel(context.Background())
s := &ProxyServiceServer{
accessLogManager: accessLogMgr,
@@ -141,6 +163,7 @@ func NewProxyServiceServer(accessLogMgr accesslogs.Manager, tokenStore *OneTimeT
peersManager: peersManager,
usersManager: usersManager,
proxyManager: proxyMgr,
tokenChecker: tokenChecker,
snapshotBatchSize: snapshotBatchSizeFromEnv(),
cancel: cancel,
}
@@ -200,6 +223,25 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest
return status.Errorf(codes.InvalidArgument, "proxy address is invalid")
}
var accountID *string
token := GetProxyTokenFromContext(ctx)
if token != nil && token.AccountID != nil {
accountID = token.AccountID
available, err := s.proxyManager.IsClusterAddressAvailable(ctx, proxyAddress, *accountID)
if err != nil {
return status.Errorf(codes.Internal, "check cluster address: %v", err)
}
if !available {
return status.Errorf(codes.AlreadyExists, "cluster address %s is already in use", proxyAddress)
}
}
var tokenID string
if token != nil {
tokenID = token.ID
}
sessionID := uuid.NewString()
if old, loaded := s.connectedProxies.Load(proxyID); loaded {
@@ -217,6 +259,8 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest
proxyID: proxyID,
sessionID: sessionID,
address: proxyAddress,
accountID: accountID,
tokenID: tokenID,
capabilities: req.GetCapabilities(),
stream: stream,
sendChan: make(chan *proto.GetMappingUpdateResponse, 100),
@@ -224,7 +268,6 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest
cancel: cancel,
}
// Register proxy in database with capabilities
var caps *proxy.Capabilities
if c := req.GetCapabilities(); c != nil {
caps = &proxy.Capabilities{
@@ -233,10 +276,13 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest
SupportsCrowdsec: c.SupportsCrowdsec,
}
}
proxyRecord, err := s.proxyManager.Connect(ctx, proxyID, sessionID, proxyAddress, peerInfo, caps)
proxyRecord, err := s.proxyManager.Connect(ctx, proxyID, sessionID, proxyAddress, peerInfo, accountID, caps)
if err != nil {
log.WithContext(ctx).Warnf("failed to register proxy %s in database: %v", proxyID, err)
cancel()
if accountID != nil {
return status.Errorf(codes.Internal, "failed to register BYOP proxy: %v", err)
}
log.WithContext(ctx).Warnf("failed to register proxy %s in database: %v", proxyID, err)
return status.Errorf(codes.Internal, "register proxy in database: %v", err)
}
@@ -266,6 +312,7 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest
"session_id": sessionID,
"address": proxyAddress,
"cluster_addr": proxyAddress,
"account_id": accountID,
"total_proxies": len(s.GetConnectedProxies()),
}).Info("Proxy registered in cluster")
defer func() {
@@ -286,7 +333,7 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest
log.Infof("Proxy %s session %s disconnected", proxyID, sessionID)
}()
go s.heartbeat(connCtx, proxyRecord)
go s.heartbeat(connCtx, conn, proxyRecord)
select {
case err := <-errChan:
@@ -298,8 +345,9 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest
}
}
// heartbeat updates the proxy's last_seen timestamp every minute
func (s *ProxyServiceServer) heartbeat(ctx context.Context, p *proxy.Proxy) {
// heartbeat updates the proxy's last_seen timestamp every minute and
// disconnects the proxy if its access token has been revoked.
func (s *ProxyServiceServer) heartbeat(ctx context.Context, conn *proxyConnection, p *proxy.Proxy) {
ticker := time.NewTicker(1 * time.Minute)
defer ticker.Stop()
@@ -309,6 +357,19 @@ func (s *ProxyServiceServer) heartbeat(ctx context.Context, p *proxy.Proxy) {
if err := s.proxyManager.Heartbeat(ctx, p); err != nil {
log.WithContext(ctx).Debugf("Failed to update proxy %s heartbeat: %v", p.ID, err)
}
if conn.tokenID != "" && s.tokenChecker != nil {
valid, err := s.tokenChecker.IsProxyAccessTokenValid(ctx, conn.tokenID)
if err != nil {
log.WithContext(ctx).Warnf("failed to check token validity for proxy %s: %v", conn.proxyID, err)
continue
}
if !valid {
log.WithContext(ctx).Warnf("proxy %s token revoked or expired, disconnecting", conn.proxyID)
conn.cancel()
return
}
}
case <-ctx.Done():
log.WithContext(ctx).Infof("proxy %s heartbeat stopped: context canceled", p.ID)
return
@@ -316,8 +377,6 @@ func (s *ProxyServiceServer) heartbeat(ctx context.Context, p *proxy.Proxy) {
}
}
// sendSnapshot sends the initial snapshot of services to the connecting proxy.
// Only entries matching the proxy's cluster address are sent.
func (s *ProxyServiceServer) sendSnapshot(ctx context.Context, conn *proxyConnection) error {
if !isProxyAddressValid(conn.address) {
return fmt.Errorf("proxy address is invalid")
@@ -335,6 +394,13 @@ func (s *ProxyServiceServer) sendSnapshot(ctx context.Context, conn *proxyConnec
if end > len(mappings) {
end = len(mappings)
}
for _, m := range mappings[i:end] {
token, err := s.tokenStore.GenerateToken(m.AccountId, m.Id, s.proxyTokenTTL())
if err != nil {
return fmt.Errorf("generate auth token for service %s: %w", m.Id, err)
}
m.AuthToken = token
}
if err := conn.stream.Send(&proto.GetMappingUpdateResponse{
Mapping: mappings[i:end],
InitialSyncComplete: end == len(mappings),
@@ -355,23 +421,25 @@ func (s *ProxyServiceServer) sendSnapshot(ctx context.Context, conn *proxyConnec
}
func (s *ProxyServiceServer) snapshotServiceMappings(ctx context.Context, conn *proxyConnection) ([]*proto.ProxyMapping, error) {
services, err := s.serviceManager.GetGlobalServices(ctx)
var services []*rpservice.Service
var err error
if conn.accountID != nil {
services, err = s.serviceManager.GetAccountServices(ctx, *conn.accountID)
} else {
services, err = s.serviceManager.GetGlobalServices(ctx)
}
if err != nil {
return nil, fmt.Errorf("get services from store: %w", err)
}
oidcCfg := s.GetOIDCValidationConfig()
var mappings []*proto.ProxyMapping
for _, service := range services {
if !service.Enabled || service.ProxyCluster == "" || service.ProxyCluster != conn.address {
continue
}
token, err := s.tokenStore.GenerateToken(service.AccountID, service.ID, s.proxyTokenTTL())
if err != nil {
return nil, fmt.Errorf("generate auth token for service %s: %w", service.ID, err)
}
m := service.ToProtoMapping(rpservice.Create, token, s.GetOIDCValidationConfig())
m := service.ToProtoMapping(rpservice.Create, "", oidcCfg)
if !proxyAcceptsMapping(conn, m) {
continue
}
@@ -380,8 +448,14 @@ func (s *ProxyServiceServer) snapshotServiceMappings(ctx context.Context, conn *
return mappings, nil
}
// isProxyAddressValid validates a proxy address
// isProxyAddressValid validates a proxy address (domain name or IP address)
func isProxyAddressValid(addr string) bool {
if addr == "" {
return false
}
if net.ParseIP(addr) != nil {
return true
}
_, err := domain.ValidateDomains([]string{addr})
return err == nil
}
@@ -405,6 +479,10 @@ func (s *ProxyServiceServer) sender(conn *proxyConnection, errChan chan<- error)
func (s *ProxyServiceServer) SendAccessLog(ctx context.Context, req *proto.SendAccessLogRequest) (*proto.SendAccessLogResponse, error) {
accessLog := req.GetLog()
if err := enforceAccountScope(ctx, accessLog.GetAccountId()); err != nil {
return nil, err
}
fields := log.Fields{
"service_id": accessLog.GetServiceId(),
"account_id": accessLog.GetAccountId(),
@@ -442,11 +520,32 @@ func (s *ProxyServiceServer) SendAccessLog(ctx context.Context, req *proto.SendA
// Management should call this when services are created/updated/removed.
// For create/update operations a unique one-time auth token is generated per
// proxy so that every replica can independently authenticate with management.
// BYOP proxies only receive updates for their own account's services.
func (s *ProxyServiceServer) SendServiceUpdate(update *proto.GetMappingUpdateResponse) {
log.Debugf("Broadcasting service update to all connected proxy servers")
updateAccountIDs := make(map[string]struct{})
for _, m := range update.Mapping {
if m.AccountId != "" {
updateAccountIDs[m.AccountId] = struct{}{}
}
}
s.connectedProxies.Range(func(key, value interface{}) bool {
conn := value.(*proxyConnection)
resp := s.perProxyMessage(update, conn.proxyID)
connUpdate := update
if conn.accountID != nil && len(updateAccountIDs) > 0 {
if _, ok := updateAccountIDs[*conn.accountID]; !ok {
return true
}
filtered := filterMappingsForAccount(update.Mapping, *conn.accountID)
if len(filtered) == 0 {
return true
}
connUpdate = &proto.GetMappingUpdateResponse{
Mapping: filtered,
InitialSyncComplete: update.InitialSyncComplete,
}
}
resp := s.perProxyMessage(connUpdate, conn.proxyID)
if resp == nil {
log.Warnf("Token generation failed for proxy %s, disconnecting to force resync", conn.proxyID)
conn.cancel()
@@ -463,6 +562,26 @@ func (s *ProxyServiceServer) SendServiceUpdate(update *proto.GetMappingUpdateRes
})
}
// ForceDisconnect cancels the gRPC stream for a connected proxy, causing it to disconnect.
func (s *ProxyServiceServer) ForceDisconnect(proxyID string) {
if connVal, ok := s.connectedProxies.Load(proxyID); ok {
conn := connVal.(*proxyConnection)
conn.cancel()
s.connectedProxies.Delete(proxyID)
log.WithFields(log.Fields{"proxyID": proxyID}).Info("force disconnected proxy")
}
}
func filterMappingsForAccount(mappings []*proto.ProxyMapping, accountID string) []*proto.ProxyMapping {
var filtered []*proto.ProxyMapping
for _, m := range mappings {
if m.AccountId == accountID {
filtered = append(filtered, m)
}
}
return filtered
}
// GetConnectedProxies returns a list of connected proxy IDs
func (s *ProxyServiceServer) GetConnectedProxies() []string {
var proxies []string
@@ -531,6 +650,9 @@ func (s *ProxyServiceServer) SendServiceUpdateToCluster(ctx context.Context, upd
continue
}
conn := connVal.(*proxyConnection)
if conn.accountID != nil && update.AccountId != "" && *conn.accountID != update.AccountId {
continue
}
if !proxyAcceptsMapping(conn, update) {
log.WithContext(ctx).Debugf("Skipping proxy %s: does not support custom ports for mapping %s", proxyID, update.Id)
continue
@@ -618,6 +740,10 @@ func shallowCloneMapping(m *proto.ProxyMapping) *proto.ProxyMapping {
}
func (s *ProxyServiceServer) Authenticate(ctx context.Context, req *proto.AuthenticateRequest) (*proto.AuthenticateResponse, error) {
if err := enforceAccountScope(ctx, req.GetAccountId()); err != nil {
return nil, err
}
service, err := s.serviceManager.GetServiceByID(ctx, req.GetAccountId(), req.GetId())
if err != nil {
log.WithContext(ctx).Debugf("failed to get service from store: %v", err)
@@ -737,6 +863,10 @@ func (s *ProxyServiceServer) generateSessionToken(ctx context.Context, authentic
// SendStatusUpdate handles status updates from proxy clients.
func (s *ProxyServiceServer) SendStatusUpdate(ctx context.Context, req *proto.SendStatusUpdateRequest) (*proto.SendStatusUpdateResponse, error) {
if err := enforceAccountScope(ctx, req.GetAccountId()); err != nil {
return nil, err
}
accountID := req.GetAccountId()
serviceID := req.GetServiceId()
protoStatus := req.GetStatus()
@@ -807,6 +937,10 @@ func protoStatusToInternal(protoStatus proto.ProxyStatus) rpservice.Status {
// CreateProxyPeer handles proxy peer creation with one-time token authentication
func (s *ProxyServiceServer) CreateProxyPeer(ctx context.Context, req *proto.CreateProxyPeerRequest) (*proto.CreateProxyPeerResponse, error) {
if err := enforceAccountScope(ctx, req.GetAccountId()); err != nil {
return nil, err
}
serviceID := req.GetServiceId()
accountID := req.GetAccountId()
token := req.GetToken()
@@ -861,6 +995,10 @@ func strPtr(s string) *string {
}
func (s *ProxyServiceServer) GetOIDCURL(ctx context.Context, req *proto.GetOIDCURLRequest) (*proto.GetOIDCURLResponse, error) {
if err := enforceAccountScope(ctx, req.GetAccountId()); err != nil {
return nil, err
}
redirectURL, err := url.Parse(req.GetRedirectUrl())
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "parse redirect url: %v", err)
@@ -989,21 +1127,9 @@ func (s *ProxyServiceServer) ValidateState(state string) (verifier, redirectURL
// GenerateSessionToken creates a signed session JWT for the given domain and user.
func (s *ProxyServiceServer) GenerateSessionToken(ctx context.Context, domain, userID string, method proxyauth.Method) (string, error) {
// Find the service by domain to get its signing key
services, err := s.serviceManager.GetGlobalServices(ctx)
service, err := s.getServiceByDomain(ctx, domain)
if err != nil {
return "", fmt.Errorf("get services: %w", err)
}
var service *rpservice.Service
for _, svc := range services {
if svc.Domain == domain {
service = svc
break
}
}
if service == nil {
return "", fmt.Errorf("service not found for domain: %s", domain)
return "", fmt.Errorf("service not found for domain %s: %w", domain, err)
}
if service.SessionPrivateKey == "" {
@@ -1101,6 +1227,10 @@ func (s *ProxyServiceServer) ValidateSession(ctx context.Context, req *proto.Val
}, nil
}
if err := enforceAccountScope(ctx, service.AccountID); err != nil {
return nil, err
}
pubKeyBytes, err := base64.StdEncoding.DecodeString(service.SessionPublicKey)
if err != nil {
log.WithFields(log.Fields{
@@ -1184,18 +1314,7 @@ func (s *ProxyServiceServer) ValidateSession(ctx context.Context, req *proto.Val
}
func (s *ProxyServiceServer) getServiceByDomain(ctx context.Context, domain string) (*rpservice.Service, error) {
services, err := s.serviceManager.GetGlobalServices(ctx)
if err != nil {
return nil, fmt.Errorf("get services: %w", err)
}
for _, service := range services {
if service.Domain == domain {
return service, nil
}
}
return nil, fmt.Errorf("service not found for domain: %s", domain)
return s.serviceManager.GetServiceByDomain(ctx, domain)
}
func (s *ProxyServiceServer) checkGroupAccess(service *rpservice.Service, user *types.User) error {

View File

@@ -0,0 +1,29 @@
package grpc
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestIsProxyAddressValid(t *testing.T) {
tests := []struct {
name string
addr string
valid bool
}{
{name: "valid domain", addr: "eu.proxy.netbird.io", valid: true},
{name: "valid subdomain", addr: "byop.proxy.example.com", valid: true},
{name: "valid IPv4", addr: "10.0.0.1", valid: true},
{name: "valid IPv4 public", addr: "203.0.113.10", valid: true},
{name: "valid IPv6", addr: "::1", valid: true},
{name: "valid IPv6 full", addr: "2001:db8::1", valid: true},
{name: "empty string", addr: "", valid: false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
assert.Equal(t, tt.valid, isProxyAddressValid(tt.addr))
})
}
}

View File

@@ -153,9 +153,6 @@ func (i *proxyAuthInterceptor) doValidateProxyToken(ctx context.Context) (*types
return nil, status.Errorf(codes.Unauthenticated, "invalid token")
}
// TODO: Enforce AccountID scope for "bring your own proxy" feature.
// Currently tokens are management-wide; AccountID field is reserved for future use.
if !token.IsValid() {
return nil, status.Errorf(codes.Unauthenticated, "token expired or revoked")
}

View File

@@ -53,6 +53,10 @@ func (m *mockReverseProxyManager) DeleteService(ctx context.Context, accountID,
return nil
}
func (m *mockReverseProxyManager) DeleteAccountCluster(_ context.Context, _, _, _ string) error {
return nil
}
func (m *mockReverseProxyManager) SetCertificateIssuedAt(ctx context.Context, accountID, reverseProxyID string) error {
return nil
}
@@ -91,6 +95,20 @@ func (m *mockReverseProxyManager) StopServiceFromPeer(_ context.Context, _, _, _
func (m *mockReverseProxyManager) StartExposeReaper(_ context.Context) {}
func (m *mockReverseProxyManager) GetServiceByDomain(_ context.Context, domain string) (*service.Service, error) {
if m.err != nil {
return nil, m.err
}
for _, services := range m.proxiesByAccount {
for _, svc := range services {
if svc.Domain == domain {
return svc, nil
}
}
}
return nil, errors.New("service not found for domain: " + domain)
}
func (m *mockReverseProxyManager) GetActiveClusters(_ context.Context, _, _ string) ([]proxy.Cluster, error) {
return nil, nil
}

View File

@@ -4,6 +4,7 @@ import (
"context"
"fmt"
"testing"
"time"
"github.com/golang/mock/gomock"
"github.com/stretchr/testify/assert"
@@ -172,3 +173,55 @@ func TestSendSnapshot_EmptySnapshot(t *testing.T) {
assert.Empty(t, stream.messages[0].Mapping)
assert.True(t, stream.messages[0].InitialSyncComplete)
}
type hookingStream struct {
grpc.ServerStream
onSend func(*proto.GetMappingUpdateResponse)
}
func (s *hookingStream) Send(m *proto.GetMappingUpdateResponse) error {
if s.onSend != nil {
s.onSend(m)
}
return nil
}
func (s *hookingStream) Context() context.Context { return context.Background() }
func (s *hookingStream) SetHeader(metadata.MD) error { return nil }
func (s *hookingStream) SendHeader(metadata.MD) error { return nil }
func (s *hookingStream) SetTrailer(metadata.MD) {}
func (s *hookingStream) SendMsg(any) error { return nil }
func (s *hookingStream) RecvMsg(any) error { return nil }
func TestSendSnapshot_TokensRemainValidUnderSlowSend(t *testing.T) {
const cluster = "cluster.example.com"
const batchSize = 2
const totalServices = 6
const ttl = 100 * time.Millisecond
const sendDelay = 200 * time.Millisecond
ctrl := gomock.NewController(t)
mgr := rpservice.NewMockManager(ctrl)
mgr.EXPECT().GetGlobalServices(gomock.Any()).Return(makeServices(totalServices, cluster), nil)
s := newSnapshotTestServer(t, batchSize)
s.serviceManager = mgr
s.tokenTTL = ttl
var validateErrs []error
stream := &hookingStream{
onSend: func(resp *proto.GetMappingUpdateResponse) {
for _, m := range resp.Mapping {
if err := s.tokenStore.ValidateAndConsume(m.AuthToken, m.AccountId, m.Id); err != nil {
validateErrs = append(validateErrs, fmt.Errorf("svc %s: %w", m.Id, err))
}
}
time.Sleep(sendDelay)
},
}
conn := &proxyConnection{proxyID: "proxy-a", address: cluster, stream: stream}
require.NoError(t, s.sendSnapshot(context.Background(), conn))
require.Empty(t, validateErrs,
"tokens must remain valid even when batches are sent slowly: lazy per-batch generation guarantees freshness")
}

View File

@@ -12,9 +12,12 @@ import (
cachestore "github.com/eko/gocache/lib/v4/store"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"google.golang.org/grpc/codes"
grpcstatus "google.golang.org/grpc/status"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
nbcache "github.com/netbirdio/netbird/management/server/cache"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/management/proto"
)
@@ -316,6 +319,58 @@ func TestValidateState_RejectsOldTwoPartFormat(t *testing.T) {
assert.Contains(t, err.Error(), "invalid state format")
}
func scopedCtx(accountID string) context.Context {
token := &types.ProxyAccessToken{
ID: "token-1",
AccountID: &accountID,
}
return context.WithValue(context.Background(), ProxyTokenContextKey, token)
}
func globalCtx() context.Context {
token := &types.ProxyAccessToken{
ID: "token-global",
}
return context.WithValue(context.Background(), ProxyTokenContextKey, token)
}
func TestEnforceAccountScope_AllowsMatchingAccount(t *testing.T) {
err := enforceAccountScope(scopedCtx("acc-1"), "acc-1")
assert.NoError(t, err)
}
func TestEnforceAccountScope_BlocksMismatchedAccount(t *testing.T) {
err := enforceAccountScope(scopedCtx("acc-1"), "acc-2")
require.Error(t, err)
st, ok := grpcstatus.FromError(err)
require.True(t, ok)
assert.Equal(t, codes.PermissionDenied, st.Code())
}
func TestEnforceAccountScope_BlocksEmptyRequestAccountID(t *testing.T) {
err := enforceAccountScope(scopedCtx("acc-1"), "")
require.Error(t, err)
st, ok := grpcstatus.FromError(err)
require.True(t, ok)
assert.Equal(t, codes.PermissionDenied, st.Code())
}
func TestEnforceAccountScope_AllowsGlobalToken(t *testing.T) {
err := enforceAccountScope(globalCtx(), "acc-1")
assert.NoError(t, err)
err = enforceAccountScope(globalCtx(), "acc-2")
assert.NoError(t, err)
err = enforceAccountScope(globalCtx(), "")
assert.NoError(t, err)
}
func TestEnforceAccountScope_AllowsNoTokenInContext(t *testing.T) {
err := enforceAccountScope(context.Background(), "acc-1")
assert.NoError(t, err)
}
func TestValidateState_RejectsInvalidHMAC(t *testing.T) {
ctx := context.Background()
pkceStore := NewPKCEVerifierStore(ctx, testCacheStore(t))

View File

@@ -522,10 +522,11 @@ func (s *Server) sendJob(ctx context.Context, peerKey wgtypes.Key, job *job.Even
}
func (s *Server) cancelPeerRoutines(ctx context.Context, accountID string, peer *nbpeer.Peer, streamStartTime time.Time) {
unlock := s.acquirePeerLockByUID(ctx, peer.Key)
uncanceledCTX := context.WithoutCancel(ctx)
unlock := s.acquirePeerLockByUID(uncanceledCTX, peer.Key)
defer unlock()
s.cancelPeerRoutinesWithoutLock(ctx, accountID, peer, streamStartTime)
s.cancelPeerRoutinesWithoutLock(uncanceledCTX, accountID, peer, streamStartTime)
}
func (s *Server) cancelPeerRoutinesWithoutLock(ctx context.Context, accountID string, peer *nbpeer.Peer, streamStartTime time.Time) {

View File

@@ -42,7 +42,7 @@ func setupValidateSessionTest(t *testing.T) *validateSessionTestSetup {
tokenStore := NewOneTimeTokenStore(ctx, testCacheStore(t))
pkceStore := NewPKCEVerifierStore(ctx, testCacheStore(t))
proxyService := NewProxyServiceServer(nil, tokenStore, pkceStore, ProxyOIDCConfig{}, nil, usersManager, proxyManager)
proxyService := NewProxyServiceServer(nil, tokenStore, pkceStore, ProxyOIDCConfig{}, nil, usersManager, proxyManager, nil)
proxyService.SetServiceManager(serviceManager)
createTestProxies(t, ctx, testStore)
@@ -318,21 +318,33 @@ func (m *testValidateSessionServiceManager) StopServiceFromPeer(_ context.Contex
func (m *testValidateSessionServiceManager) StartExposeReaper(_ context.Context) {}
func (m *testValidateSessionServiceManager) GetServiceByDomain(ctx context.Context, domain string) (*service.Service, error) {
return m.store.GetServiceByDomain(ctx, domain)
}
func (m *testValidateSessionServiceManager) GetActiveClusters(_ context.Context, _, _ string) ([]proxy.Cluster, error) {
return nil, nil
}
func (m *testValidateSessionServiceManager) DeleteAccountCluster(_ context.Context, _, _, _ string) error {
return nil
}
type testValidateSessionProxyManager struct{}
func (m *testValidateSessionProxyManager) Connect(_ context.Context, _, _, _ string, _ *proxy.Capabilities) error {
func (m *testValidateSessionProxyManager) Connect(_ context.Context, _, _, _, _ string, _ *string, _ *proxy.Capabilities) (*proxy.Proxy, error) {
return nil, nil
}
func (m *testValidateSessionProxyManager) Disconnect(_ context.Context, _, _ string) error {
return nil
}
func (m *testValidateSessionProxyManager) Disconnect(_ context.Context, _ string) error {
func (m *testValidateSessionProxyManager) Heartbeat(_ context.Context, _ *proxy.Proxy) error {
return nil
}
func (m *testValidateSessionProxyManager) Heartbeat(_ context.Context, _, _, _ string) error {
func (m *testValidateSessionProxyManager) DeleteAccountCluster(_ context.Context, _, _ string) error {
return nil
}
@@ -340,6 +352,10 @@ func (m *testValidateSessionProxyManager) GetActiveClusterAddresses(_ context.Co
return nil, nil
}
func (m *testValidateSessionProxyManager) GetActiveClusterAddressesForAccount(_ context.Context, _ string) ([]string, error) {
return nil, nil
}
func (m *testValidateSessionProxyManager) GetActiveClusters(_ context.Context) ([]proxy.Cluster, error) {
return nil, nil
}
@@ -348,6 +364,22 @@ func (m *testValidateSessionProxyManager) CleanupStale(_ context.Context, _ time
return nil
}
func (m *testValidateSessionProxyManager) GetAccountProxy(_ context.Context, _ string) (*proxy.Proxy, error) {
return nil, nil
}
func (m *testValidateSessionProxyManager) CountAccountProxies(_ context.Context, _ string) (int64, error) {
return 0, nil
}
func (m *testValidateSessionProxyManager) IsClusterAddressAvailable(_ context.Context, _, _ string) (bool, error) {
return true, nil
}
func (m *testValidateSessionProxyManager) DeleteProxy(_ context.Context, _ string) error {
return nil
}
func (m *testValidateSessionProxyManager) ClusterSupportsCustomPorts(_ context.Context, _ string) *bool {
return nil
}

View File

@@ -291,10 +291,15 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco
return nil, status.NewPermissionDeniedError()
}
// Canonicalize the incoming range so a caller-supplied prefix with host bits
// (e.g. 100.64.1.1/16) compares equal to the masked form stored on network.Net.
newSettings.NetworkRange = newSettings.NetworkRange.Masked()
var oldSettings *types.Settings
var updateAccountPeers bool
var groupChangesAffectPeers bool
var reloadReverseProxy bool
var effectiveOldNetworkRange netip.Prefix
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
var groupsUpdated bool
@@ -308,6 +313,16 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco
return err
}
// No lock: the transaction already holds Settings(Update), and network.Net is
// only mutated by reallocateAccountPeerIPs, which is reachable only through
// this same code path. A Share lock here would extend an unnecessary row lock
// and complicate ordering against updatePeerIPv6InTransaction.
network, err := transaction.GetAccountNetwork(ctx, store.LockingStrengthNone, accountID)
if err != nil {
return fmt.Errorf("get account network: %w", err)
}
effectiveOldNetworkRange = prefixFromIPNet(network.Net)
if oldSettings.Extra != nil && newSettings.Extra != nil &&
oldSettings.Extra.PeerApprovalEnabled && !newSettings.Extra.PeerApprovalEnabled {
approvedCount, err := transaction.ApproveAccountPeers(ctx, accountID)
@@ -321,7 +336,7 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco
}
}
if oldSettings.NetworkRange != newSettings.NetworkRange {
if newSettings.NetworkRange.IsValid() && newSettings.NetworkRange != effectiveOldNetworkRange {
if err = am.reallocateAccountPeerIPs(ctx, transaction, accountID, newSettings.NetworkRange); err != nil {
return err
}
@@ -386,6 +401,9 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco
if err = am.handleInactivityExpirationSettings(ctx, oldSettings, newSettings, userID, accountID); err != nil {
return nil, err
}
if err = am.handleLocalMfaSettings(ctx, oldSettings, newSettings, userID, accountID); err != nil {
return nil, err
}
if oldSettings.DNSDomain != newSettings.DNSDomain {
eventMeta := map[string]any{
"old_dns_domain": oldSettings.DNSDomain,
@@ -393,9 +411,9 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco
}
am.StoreEvent(ctx, userID, accountID, accountID, activity.AccountDNSDomainUpdated, eventMeta)
}
if oldSettings.NetworkRange != newSettings.NetworkRange {
if newSettings.NetworkRange.IsValid() && newSettings.NetworkRange != effectiveOldNetworkRange {
eventMeta := map[string]any{
"old_network_range": oldSettings.NetworkRange.String(),
"old_network_range": effectiveOldNetworkRange.String(),
"new_network_range": newSettings.NetworkRange.String(),
}
am.StoreEvent(ctx, userID, accountID, accountID, activity.AccountNetworkRangeUpdated, eventMeta)
@@ -440,6 +458,22 @@ func ipv6SettingsChanged(old, updated *types.Settings) bool {
return !slices.Equal(oldGroups, newGroups)
}
// prefixFromIPNet returns the overlay prefix actually allocated on the account
// network, or an invalid prefix if none is set. Settings.NetworkRange is a
// user-facing override that is empty on legacy accounts, so the effective
// range must be read from network.Net to compare against an incoming update.
func prefixFromIPNet(ipNet net.IPNet) netip.Prefix {
if ipNet.IP == nil {
return netip.Prefix{}
}
addr, ok := netip.AddrFromSlice(ipNet.IP)
if !ok {
return netip.Prefix{}
}
ones, _ := ipNet.Mask.Size()
return netip.PrefixFrom(addr.Unmap(), ones)
}
func (am *DefaultAccountManager) validateSettingsUpdate(ctx context.Context, transaction store.Store, newSettings, oldSettings *types.Settings, userID, accountID string) error {
halfYearLimit := 180 * 24 * time.Hour
if newSettings.PeerLoginExpiration > halfYearLimit {
@@ -602,6 +636,29 @@ func (am *DefaultAccountManager) handleInactivityExpirationSettings(ctx context.
return nil
}
func (am *DefaultAccountManager) handleLocalMfaSettings(ctx context.Context, oldSettings, newSettings *types.Settings, userID, accountID string) error {
if oldSettings.LocalMfaEnabled == newSettings.LocalMfaEnabled {
return nil
}
embeddedIdp, ok := am.idpManager.(*idp.EmbeddedIdPManager)
if !ok {
return nil
}
if err := embeddedIdp.SetMFAEnabled(ctx, newSettings.LocalMfaEnabled); err != nil {
return fmt.Errorf("failed to toggle MFA: %w", err)
}
if newSettings.LocalMfaEnabled {
am.StoreEvent(ctx, userID, accountID, accountID, activity.AccountLocalMfaEnabled, nil)
} else {
am.StoreEvent(ctx, userID, accountID, accountID, activity.AccountLocalMfaDisabled, nil)
}
return nil
}
func (am *DefaultAccountManager) peerLoginExpirationJob(ctx context.Context, accountID string) func() (time.Duration, bool) {
return func() (time.Duration, bool) {
//nolint
@@ -2461,6 +2518,18 @@ func (am *DefaultAccountManager) buildIPv6AllowedPeers(ctx context.Context, tran
allowedPeers[peerID] = struct{}{}
}
}
// Embedded proxy peers sit outside regular group membership but must
// participate in any v6-enabled overlay to reach v6-only peers.
peers, err := transaction.GetAccountPeers(ctx, store.LockingStrengthNone, accountID, "", "")
if err != nil {
return nil, fmt.Errorf("get peers: %w", err)
}
for _, p := range peers {
if p.ProxyMeta.Embedded {
allowedPeers[p.ID] = struct{}{}
}
}
return allowedPeers, nil
}

View File

@@ -3113,7 +3113,7 @@ func createManager(t testing.TB) (*DefaultAccountManager, *update_channel.PeersU
return nil, nil, err
}
proxyGrpcServer := nbgrpc.NewProxyServiceServer(nil, nil, nil, nbgrpc.ProxyOIDCConfig{}, peersManager, nil, proxyManager)
proxyGrpcServer := nbgrpc.NewProxyServiceServer(nil, nil, nil, nbgrpc.ProxyOIDCConfig{}, peersManager, nil, proxyManager, nil)
proxyController, err := proxymanager.NewGRPCController(proxyGrpcServer, noop.Meter{})
if err != nil {
return nil, nil, err
@@ -3970,6 +3970,96 @@ func TestDefaultAccountManager_UpdateAccountSettings_NetworkRangeChange(t *testi
}
}
// TestDefaultAccountManager_UpdateAccountSettings_NetworkRangePreserved guards against
// peer IP reallocation when a settings update carries the network range that is already
// in use. Legacy accounts have Settings.NetworkRange unset in the DB while network.Net
// holds the actual allocated overlay; the dashboard backfills the GET response from
// network.Net and echoes the value back on PUT, so the diff must be against the
// effective range to avoid renumbering every peer on an unrelated settings change.
func TestDefaultAccountManager_UpdateAccountSettings_NetworkRangePreserved(t *testing.T) {
manager, _, account, peer1, peer2, peer3 := setupNetworkMapTest(t)
ctx := context.Background()
settings, err := manager.Store.GetAccountSettings(ctx, store.LockingStrengthNone, account.Id)
require.NoError(t, err)
require.False(t, settings.NetworkRange.IsValid(), "precondition: new accounts leave Settings.NetworkRange unset")
network, err := manager.Store.GetAccountNetwork(ctx, store.LockingStrengthNone, account.Id)
require.NoError(t, err)
require.NotNil(t, network.Net.IP, "precondition: network.Net should be allocated")
addr, ok := netip.AddrFromSlice(network.Net.IP)
require.True(t, ok)
ones, _ := network.Net.Mask.Size()
effective := netip.PrefixFrom(addr.Unmap(), ones)
require.True(t, effective.IsValid())
before := map[string]netip.Addr{peer1.ID: peer1.IP, peer2.ID: peer2.IP, peer3.ID: peer3.IP}
// Round-trip the effective range as if the dashboard echoed back the GET-backfilled value.
_, err = manager.UpdateAccountSettings(ctx, account.Id, userID, &types.Settings{
PeerLoginExpirationEnabled: true,
PeerLoginExpiration: types.DefaultPeerLoginExpiration,
NetworkRange: effective,
Extra: &types.ExtraSettings{},
})
require.NoError(t, err)
peers, err := manager.Store.GetAccountPeers(ctx, store.LockingStrengthNone, account.Id, "", "")
require.NoError(t, err)
require.Len(t, peers, len(before))
for _, p := range peers {
assert.Equal(t, before[p.ID], p.IP, "peer %s IP should not change when range matches effective", p.ID)
}
// Carrying the same range with host bits set must also be a no-op once canonicalized.
hostBitsForm := netip.PrefixFrom(peer1.IP, ones)
require.NotEqual(t, effective, hostBitsForm, "precondition: host-bit form should differ before masking")
_, err = manager.UpdateAccountSettings(ctx, account.Id, userID, &types.Settings{
PeerLoginExpirationEnabled: true,
PeerLoginExpiration: types.DefaultPeerLoginExpiration,
NetworkRange: hostBitsForm,
Extra: &types.ExtraSettings{},
})
require.NoError(t, err)
peers, err = manager.Store.GetAccountPeers(ctx, store.LockingStrengthNone, account.Id, "", "")
require.NoError(t, err)
for _, p := range peers {
assert.Equal(t, before[p.ID], p.IP, "peer %s IP should not change for host-bit-set equivalent range", p.ID)
}
// Omitting NetworkRange (invalid prefix) must also be a no-op.
_, err = manager.UpdateAccountSettings(ctx, account.Id, userID, &types.Settings{
PeerLoginExpirationEnabled: true,
PeerLoginExpiration: types.DefaultPeerLoginExpiration,
Extra: &types.ExtraSettings{},
})
require.NoError(t, err)
peers, err = manager.Store.GetAccountPeers(ctx, store.LockingStrengthNone, account.Id, "", "")
require.NoError(t, err)
for _, p := range peers {
assert.Equal(t, before[p.ID], p.IP, "peer %s IP should not change when NetworkRange omitted", p.ID)
}
// Sanity: an actually different range still triggers reallocation.
newRange := netip.MustParsePrefix("100.99.0.0/16")
_, err = manager.UpdateAccountSettings(ctx, account.Id, userID, &types.Settings{
PeerLoginExpirationEnabled: true,
PeerLoginExpiration: types.DefaultPeerLoginExpiration,
NetworkRange: newRange,
Extra: &types.ExtraSettings{},
})
require.NoError(t, err)
peers, err = manager.Store.GetAccountPeers(ctx, store.LockingStrengthNone, account.Id, "", "")
require.NoError(t, err)
for _, p := range peers {
assert.True(t, newRange.Contains(p.IP), "peer %s should be in new range %s, got %s", p.ID, newRange, p.IP)
assert.NotEqual(t, before[p.ID], p.IP, "peer %s IP should change on real range update", p.ID)
}
}
func TestDefaultAccountManager_UpdateAccountSettings_IPv6EnabledGroups(t *testing.T) {
manager, _, account, peer1, peer2, peer3 := setupNetworkMapTest(t)
ctx := context.Background()

View File

@@ -236,6 +236,11 @@ const (
// AccountIPv6Disabled indicates that a user disabled IPv6 overlay for the account
AccountIPv6Disabled Activity = 122
// AccountLocalMfaEnabled indicates that a user enabled TOTP MFA for local users
AccountLocalMfaEnabled Activity = 123
// AccountLocalMfaDisabled indicates that a user disabled TOTP MFA for local users
AccountLocalMfaDisabled Activity = 124
AccountDeleted Activity = 99999
)
@@ -386,6 +391,9 @@ var activityMap = map[Activity]Code{
AccountPeerExposeEnabled: {"Account peer expose enabled", "account.setting.peer.expose.enable"},
AccountPeerExposeDisabled: {"Account peer expose disabled", "account.setting.peer.expose.disable"},
AccountLocalMfaEnabled: {"Account local MFA enabled", "account.setting.local.mfa.enable"},
AccountLocalMfaDisabled: {"Account local MFA disabled", "account.setting.local.mfa.disable"},
DomainAdded: {"Domain added", "domain.add"},
DomainDeleted: {"Domain deleted", "domain.delete"},
DomainValidated: {"Domain validated", "domain.validate"},

View File

@@ -16,6 +16,7 @@ import (
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxytoken"
reverseproxymanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service/manager"
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
@@ -144,6 +145,9 @@ func NewAPIHandler(ctx context.Context, accountManager account.Manager, networks
if serviceManager != nil && reverseProxyDomainManager != nil {
reverseproxymanager.RegisterEndpoints(serviceManager, *reverseProxyDomainManager, reverseProxyAccessLogsManager, permissionsManager, router)
}
proxytoken.RegisterEndpoints(accountManager.GetStore(), permissionsManager, router)
// Register OAuth callback handler for proxy authentication
if proxyGRPCServer != nil {
oauthHandler := proxy.NewAuthCallbackHandler(proxyGRPCServer, trustedHTTPProxies)

View File

@@ -277,6 +277,9 @@ func (h *handler) updateAccountRequestSettings(req api.PutApiAccountsAccountIdJS
if req.Settings.AutoUpdateAlways != nil {
returnSettings.AutoUpdateAlways = *req.Settings.AutoUpdateAlways
}
if req.Settings.LocalMfaEnabled != nil {
returnSettings.LocalMfaEnabled = *req.Settings.LocalMfaEnabled
}
if req.Settings.Ipv6EnabledGroups != nil {
returnSettings.IPv6EnabledGroups = *req.Settings.Ipv6EnabledGroups
}
@@ -412,6 +415,7 @@ func toAccountResponse(accountID string, settings *types.Settings, meta *types.A
Ipv6EnabledGroups: &settings.IPv6EnabledGroups,
EmbeddedIdpEnabled: &settings.EmbeddedIdpEnabled,
LocalAuthDisabled: &settings.LocalAuthDisabled,
LocalMfaEnabled: &settings.LocalMfaEnabled,
}
if settings.NetworkRange.IsValid() {

Some files were not shown because too many files have changed in this diff Show More