Compare commits

..

2 Commits

Author SHA1 Message Date
Viktor Liu
036a3020fe Batch wireguard update operations 2025-07-22 14:44:26 +02:00
Zoltan Papp
86c16cf651 [server, relay] Fix/relay race disconnection (#4174)
Avoid invalid disconnection notifications when a closed race dial occurs.
This PR resolves multiple race conditions; the fix is easiest to understand when reviewed commit by commit.

- Remove store dependency from notifier
- Enforce the notification orders
- Fix invalid disconnection notification
- Ensure the order of the events on the consumer side
2025-07-21 19:58:17 +02:00
29 changed files with 604 additions and 127 deletions

View File

@@ -211,7 +211,11 @@ jobs:
strategy:
fail-fast: false
matrix:
arch: [ '386','amd64' ]
include:
- arch: "386"
raceFlag: ""
- arch: "amd64"
raceFlag: "-race"
runs-on: ubuntu-22.04
steps:
- name: Install Go
@@ -251,9 +255,9 @@ jobs:
- name: Test
run: |
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
go test \
go test ${{ matrix.raceFlag }} \
-exec 'sudo' \
-timeout 10m ./signal/...
-timeout 10m ./relay/...
test_signal:
name: "Signal / Unit"

338
client/iface/batcher.go Normal file
View File

@@ -0,0 +1,338 @@
package iface
import (
"context"
"errors"
"fmt"
"net"
"net/netip"
"os"
"strconv"
"sync"
"time"
"github.com/hashicorp/go-multierror"
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
nberrors "github.com/netbirdio/netbird/client/errors"
"github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/device"
)
const (
	// DefaultBatchFlushInterval is the default maximum time to wait before flushing batched operations
	DefaultBatchFlushInterval = 300 * time.Millisecond
	// DefaultBatchSizeThreshold is the default number of operations to trigger an immediate flush
	DefaultBatchSizeThreshold = 100
	// AllowedIPOpAdd represents an add operation
	AllowedIPOpAdd = "add"
	// AllowedIPOpRemove represents a remove operation
	AllowedIPOpRemove = "remove"
	// EnvDisableWGBatching disables batching entirely when set to any non-empty value.
	EnvDisableWGBatching = "NB_DISABLE_WG_BATCHING"
	// EnvWGBatchFlushIntervalMS overrides the flush interval, in milliseconds (must parse as an int > 0).
	EnvWGBatchFlushIntervalMS = "NB_WG_BATCH_FLUSH_INTERVAL_MS"
	// EnvWGBatchSizeThreshold overrides the batch size threshold (must parse as an int > 0).
	EnvWGBatchSizeThreshold = "NB_WG_BATCH_SIZE_THRESHOLD"
)
// AllowedIPOperation represents a pending allowed IP operation
type AllowedIPOperation struct {
	// PeerKey is the WireGuard public key of the peer the prefix belongs to.
	PeerKey string
	// Prefix is the allowed IP network to add or remove.
	Prefix netip.Prefix
	// Operation is either AllowedIPOpAdd or AllowedIPOpRemove.
	Operation string
}
// PeerUpdateOperation represents a pending peer update operation
type PeerUpdateOperation struct {
	// PeerKey is the WireGuard public key of the peer to update.
	PeerKey string
	// AllowedIPs is the full allowed IP set to apply for the peer.
	AllowedIPs []netip.Prefix
	// KeepAlive is the persistent keepalive interval to set.
	KeepAlive time.Duration
	// Endpoint is the peer's UDP endpoint; may be nil.
	Endpoint *net.UDPAddr
	// PreSharedKey is the optional WireGuard preshared key; may be nil.
	PreSharedKey *wgtypes.Key
}
// WGBatcher batches WireGuard configuration updates to reduce syscall overhead
type WGBatcher struct {
	// configurer applies the batched operations to the actual WireGuard device.
	configurer device.WGConfigurer

	// mu guards allowedIPOps, peerUpdates and flushTimer.
	mu           sync.Mutex
	allowedIPOps []AllowedIPOperation
	// peerUpdates keeps only the latest pending update per peer key.
	peerUpdates map[string]*PeerUpdateOperation
	// flushTimer is the one-shot timer arming the next interval-based flush;
	// nil when no flush is scheduled.
	flushTimer *time.Timer
	// flushChan signals the flush loop; buffered with capacity 1 so signals coalesce.
	flushChan chan struct{}

	// ctx/cancel control the lifetime of the background flush loop.
	ctx    context.Context
	cancel context.CancelFunc
	// wg waits for the flush loop goroutine to exit on Close.
	wg sync.WaitGroup

	// batchFlushInterval is the max time operations wait before being flushed.
	batchFlushInterval time.Duration
	// batchSizeThreshold is the pending-operation count that triggers an immediate flush.
	batchSizeThreshold int
}
// NewWGBatcher creates a new WireGuard operation batcher.
//
// It returns nil when batching is disabled via EnvDisableWGBatching; callers
// are expected to check for nil and fall back to direct configurer calls.
// The flush interval and size threshold can be overridden through the
// EnvWGBatchFlushIntervalMS and EnvWGBatchSizeThreshold environment
// variables. A background goroutine is started that applies batched
// operations; it runs until Close is called.
func NewWGBatcher(configurer device.WGConfigurer) *WGBatcher {
	if os.Getenv(EnvDisableWGBatching) != "" {
		log.Infof("WireGuard allowed IP batching disabled via %s", EnvDisableWGBatching)
		return nil
	}

	flushInterval := DefaultBatchFlushInterval
	sizeThreshold := DefaultBatchSizeThreshold

	if intervalMs := os.Getenv(EnvWGBatchFlushIntervalMS); intervalMs != "" {
		if ms, err := strconv.Atoi(intervalMs); err == nil && ms > 0 {
			flushInterval = time.Duration(ms) * time.Millisecond
			log.Infof("WireGuard batch flush interval set to %v", flushInterval)
		} else {
			// Invalid overrides were previously ignored silently; surface them.
			log.Warnf("ignoring invalid %s value %q, using default %v", EnvWGBatchFlushIntervalMS, intervalMs, flushInterval)
		}
	}

	if threshold := os.Getenv(EnvWGBatchSizeThreshold); threshold != "" {
		if size, err := strconv.Atoi(threshold); err == nil && size > 0 {
			sizeThreshold = size
			log.Infof("WireGuard batch size threshold set to %d", sizeThreshold)
		} else {
			log.Warnf("ignoring invalid %s value %q, using default %d", EnvWGBatchSizeThreshold, threshold, sizeThreshold)
		}
	}

	log.Info("WireGuard allowed IP batching enabled")

	ctx, cancel := context.WithCancel(context.Background())
	b := &WGBatcher{
		configurer:         configurer,
		peerUpdates:        make(map[string]*PeerUpdateOperation),
		flushChan:          make(chan struct{}, 1),
		ctx:                ctx,
		cancel:             cancel,
		batchFlushInterval: flushInterval,
		batchSizeThreshold: sizeThreshold,
	}

	b.wg.Add(1)
	go b.flushLoop()

	return b
}
// Close stops the batcher and flushes any pending operations.
//
// Shutdown order matters: the pending flush timer is stopped under the mutex
// so no further timer-driven signals are produced, the context is cancelled
// to terminate the background flushLoop, any still-queued operations are
// flushed synchronously, and finally we wait for the flushLoop goroutine to
// exit. Always returns nil; a failed final flush is only logged.
func (b *WGBatcher) Close() error {
	b.mu.Lock()
	if b.flushTimer != nil {
		b.flushTimer.Stop()
	}
	b.mu.Unlock()
	b.cancel()
	// Flush directly; flushLoop may already have exited after cancel.
	if err := b.Flush(); err != nil {
		log.Errorf("failed to flush pending operations on close: %v", err)
	}
	b.wg.Wait()
	return nil
}
// UpdatePeer batches a peer update operation. Only the most recent update for
// a given peer key is kept; an older queued update for the same peer is
// overwritten. Always returns nil.
func (b *WGBatcher) UpdatePeer(peerKey string, allowedIPs []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error {
	op := &PeerUpdateOperation{
		PeerKey:      peerKey,
		AllowedIPs:   allowedIPs,
		KeepAlive:    keepAlive,
		Endpoint:     endpoint,
		PreSharedKey: preSharedKey,
	}

	b.mu.Lock()
	defer b.mu.Unlock()

	b.peerUpdates[peerKey] = op
	b.scheduleFlush()

	return nil
}
// AddAllowedIP batches an allowed IP addition for the given peer.
// Always returns nil; the operation is applied on the next flush.
func (b *WGBatcher) AddAllowedIP(peerKey string, allowedIP netip.Prefix) error {
	op := AllowedIPOperation{
		PeerKey:   peerKey,
		Prefix:    allowedIP,
		Operation: AllowedIPOpAdd,
	}

	b.mu.Lock()
	defer b.mu.Unlock()

	b.allowedIPOps = append(b.allowedIPOps, op)
	b.scheduleFlush()

	return nil
}
// RemoveAllowedIP batches an allowed IP removal for the given peer.
// Always returns nil; the operation is applied on the next flush.
func (b *WGBatcher) RemoveAllowedIP(peerKey string, allowedIP netip.Prefix) error {
	op := AllowedIPOperation{
		PeerKey:   peerKey,
		Prefix:    allowedIP,
		Operation: AllowedIPOpRemove,
	}

	b.mu.Lock()
	defer b.mu.Unlock()

	b.allowedIPOps = append(b.allowedIPOps, op)
	b.scheduleFlush()

	return nil
}
// Flush immediately processes all batched operations.
//
// The pending state is snapshotted and cleared under the mutex, then applied
// without holding the lock so new operations can be queued while the
// configurer calls are in flight. Any armed flush timer is disarmed first.
func (b *WGBatcher) Flush() error {
	b.mu.Lock()
	if t := b.flushTimer; t != nil {
		t.Stop()
		b.flushTimer = nil
	}
	pendingPeers := b.peerUpdates
	pendingIPs := b.allowedIPOps
	b.peerUpdates = make(map[string]*PeerUpdateOperation)
	b.allowedIPOps = nil
	b.mu.Unlock()

	return b.processBatch(pendingPeers, pendingIPs)
}
// scheduleFlush schedules a batch flush if not already scheduled.
//
// Callers must hold b.mu. When the number of pending operations reaches the
// size threshold, a flush is signalled immediately via flushChan; the send is
// non-blocking because an already-pending signal is sufficient (the flush
// loop drains all pending work per signal). Otherwise a one-shot timer is
// armed — unless one is already running — to signal a flush after
// batchFlushInterval. The timer callback only signals; the timer field is
// reset to nil by Flush.
func (b *WGBatcher) scheduleFlush() {
	shouldFlushNow := len(b.allowedIPOps)+len(b.peerUpdates) >= b.batchSizeThreshold
	if shouldFlushNow {
		select {
		case b.flushChan <- struct{}{}:
		default:
		}
		return
	}
	if b.flushTimer == nil {
		b.flushTimer = time.AfterFunc(b.batchFlushInterval, func() {
			select {
			case b.flushChan <- struct{}{}:
			default:
			}
		})
	}
}
// flushLoop services flush signals until the batcher's context is cancelled.
// It is started by NewWGBatcher and exits when Close cancels the context.
func (b *WGBatcher) flushLoop() {
	defer b.wg.Done()

	for {
		select {
		case <-b.ctx.Done():
			return
		case <-b.flushChan:
			if err := b.Flush(); err != nil {
				log.Errorf("Error flushing WireGuard operations: %v", err)
			}
		}
	}
}
// processBatch applies a snapshot of batched operations: peer updates first,
// then the individual allowed IP add/remove operations. Errors from both
// phases are accumulated so one failure does not block the rest.
func (b *WGBatcher) processBatch(peerUpdates map[string]*PeerUpdateOperation, allowedIPOps []AllowedIPOperation) error {
	if len(peerUpdates)+len(allowedIPOps) == 0 {
		return nil
	}

	start := time.Now()
	defer func() {
		log.Debugf("Processed batch of %d peer updates and %d allowed IP operations in %v",
			len(peerUpdates), len(allowedIPOps), time.Since(start))
	}()

	var merr *multierror.Error
	if err := b.processPeerUpdates(peerUpdates); err != nil {
		merr = multierror.Append(merr, err)
	}
	if err := b.processAllowedIPOps(allowedIPOps); err != nil {
		merr = multierror.Append(merr, err)
	}

	return nberrors.FormatErrorOrNil(merr)
}
// processPeerUpdates applies all queued peer updates via the configurer,
// collecting per-peer errors instead of stopping at the first failure.
func (b *WGBatcher) processPeerUpdates(peerUpdates map[string]*PeerUpdateOperation) error {
	var merr *multierror.Error

	for _, u := range peerUpdates {
		err := b.configurer.UpdatePeer(u.PeerKey, u.AllowedIPs, u.KeepAlive, u.Endpoint, u.PreSharedKey)
		if err != nil {
			merr = multierror.Append(merr, fmt.Errorf("update peer %s: %w", u.PeerKey, err))
		}
	}

	return nberrors.FormatErrorOrNil(merr)
}
// processAllowedIPOps groups the queued allowed IP add/remove operations by
// peer and applies the resulting per-peer change sets.
func (b *WGBatcher) processAllowedIPOps(allowedIPOps []AllowedIPOperation) error {
	grouped := b.groupAllowedIPChanges(allowedIPOps)
	return b.applyAllowedIPChanges(grouped)
}
// groupAllowedIPChanges buckets allowed IP operations by peer key, separating
// additions from removals while preserving per-peer operation order.
func (b *WGBatcher) groupAllowedIPChanges(allowedIPOps []AllowedIPOperation) map[string]struct {
	toAdd    []netip.Prefix
	toRemove []netip.Prefix
} {
	grouped := make(map[string]struct {
		toAdd    []netip.Prefix
		toRemove []netip.Prefix
	})

	for _, op := range allowedIPOps {
		entry := grouped[op.PeerKey]
		switch op.Operation {
		case AllowedIPOpAdd:
			entry.toAdd = append(entry.toAdd, op.Prefix)
		default:
			entry.toRemove = append(entry.toRemove, op.Prefix)
		}
		grouped[op.PeerKey] = entry
	}

	return grouped
}
// applyAllowedIPChanges applies the grouped allowed IP changes per peer,
// removals before additions. Removals that fail because the peer or allowed
// IP is already gone are treated as benign and only logged at debug level;
// all other errors are accumulated and returned together.
func (b *WGBatcher) applyAllowedIPChanges(peerChanges map[string]struct {
	toAdd    []netip.Prefix
	toRemove []netip.Prefix
}) error {
	var merr *multierror.Error

	for peerKey, changes := range peerChanges {
		for _, prefix := range changes.toRemove {
			err := b.configurer.RemoveAllowedIP(peerKey, prefix)
			if err == nil {
				continue
			}
			if errors.Is(err, configurer.ErrPeerNotFound) || errors.Is(err, configurer.ErrAllowedIPNotFound) {
				// Already absent: nothing to undo, so it is not reported as a failure.
				log.Debugf("remove allowed IP %s for peer %s: %v", prefix, peerKey, err)
				continue
			}
			merr = multierror.Append(merr, fmt.Errorf("remove allowed IP %s for peer %s: %w", prefix, peerKey, err))
		}

		for _, prefix := range changes.toAdd {
			if err := b.configurer.AddAllowedIP(peerKey, prefix); err != nil {
				merr = multierror.Append(merr, fmt.Errorf("add allowed IP %s for peer %s: %w", prefix, peerKey, err))
			}
		}
	}

	return nberrors.FormatErrorOrNil(merr)
}

View File

@@ -4,4 +4,4 @@ package bind
func (m *UDPMuxDefault) notifyAddressRemoval(addr string) {
// iOS doesn't support nbnet hooks, so this is a no-op
}
}

View File

@@ -3,4 +3,4 @@
package configurer
// WgInterfaceDefault is a default interface name of Netbird
const WgInterfaceDefault = "nb0"
const WgInterfaceDefault = "wt0"

View File

@@ -59,6 +59,7 @@ type WGIface struct {
mu sync.Mutex
configurer device.WGConfigurer
batcher *WGBatcher
filter device.PacketFilter
wgProxyFactory wgProxyFactory
}
@@ -128,6 +129,12 @@ func (w *WGIface) UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAliv
}
log.Debugf("updating interface %s peer %s, endpoint %s, allowedIPs %v", w.tun.DeviceName(), peerKey, endpoint, allowedIps)
if endpoint != nil && w.batcher != nil {
if err := w.batcher.Flush(); err != nil {
log.Warnf("failed to flush batched operations: %v", err)
}
}
return w.configurer.UpdatePeer(peerKey, allowedIps, keepAlive, endpoint, preSharedKey)
}
@@ -152,6 +159,10 @@ func (w *WGIface) AddAllowedIP(peerKey string, allowedIP netip.Prefix) error {
}
log.Debugf("Adding allowed IP to interface %s and peer %s: allowed IP %s ", w.tun.DeviceName(), peerKey, allowedIP)
if w.batcher != nil {
return w.batcher.AddAllowedIP(peerKey, allowedIP)
}
return w.configurer.AddAllowedIP(peerKey, allowedIP)
}
@@ -164,6 +175,10 @@ func (w *WGIface) RemoveAllowedIP(peerKey string, allowedIP netip.Prefix) error
}
log.Debugf("Removing allowed IP from interface %s and peer %s: allowed IP %s ", w.tun.DeviceName(), peerKey, allowedIP)
if w.batcher != nil {
return w.batcher.RemoveAllowedIP(peerKey, allowedIP)
}
return w.configurer.RemoveAllowedIP(peerKey, allowedIP)
}
@@ -174,6 +189,12 @@ func (w *WGIface) Close() error {
var result *multierror.Error
if w.batcher != nil {
if err := w.batcher.Close(); err != nil {
result = multierror.Append(result, fmt.Errorf("failed to close WireGuard batcher: %w", err))
}
}
if err := w.wgProxyFactory.Free(); err != nil {
result = multierror.Append(result, fmt.Errorf("failed to free WireGuard proxy: %w", err))
}

View File

@@ -17,6 +17,7 @@ func (w *WGIface) Create() error {
}
w.configurer = cfgr
w.batcher = NewWGBatcher(cfgr)
return nil
}

View File

@@ -1,8 +1,6 @@
package iface
import (
"fmt"
)
import "fmt"
// CreateOnAndroid creates a new Wireguard interface, sets a given IP and brings it up.
// Will reuse an existing one.
@@ -15,6 +13,7 @@ func (w *WGIface) CreateOnAndroid(routes []string, dns string, searchDomains []s
return err
}
w.configurer = cfgr
w.batcher = NewWGBatcher(cfgr)
return nil
}

View File

@@ -29,6 +29,7 @@ func (w *WGIface) Create() error {
return err
}
w.configurer = cfgr
w.batcher = NewWGBatcher(cfgr)
return nil
}

View File

@@ -39,7 +39,7 @@ const (
)
var defaultInterfaceBlacklist = []string{
iface.WgInterfaceDefault, "nb", "wt", "utun", "tun0", "zt", "ZeroTier", "wg", "ts",
iface.WgInterfaceDefault, "wt", "utun", "tun0", "zt", "ZeroTier", "wg", "ts",
"Tailscale", "tailscale", "docker", "veth", "br-", "lo",
}

View File

@@ -1393,7 +1393,7 @@ func createEngine(ctx context.Context, cancel context.CancelFunc, setupKey strin
if runtime.GOOS == "darwin" {
ifaceName = fmt.Sprintf("utun1%d", i)
} else {
ifaceName = fmt.Sprintf("nb%d", i)
ifaceName = fmt.Sprintf("wt%d", i)
}
wgPort := 33100 + i

View File

@@ -19,7 +19,7 @@ type mockIFaceMapper struct {
}
func (m *mockIFaceMapper) Name() string {
return "nb0"
return "wt0"
}
func (m *mockIFaceMapper) Address() wgaddr.Address {

View File

@@ -24,7 +24,7 @@ type WorkerRelay struct {
isController bool
config ConnConfig
conn *Conn
relayManager relayClient.ManagerService
relayManager *relayClient.Manager
relayedConn net.Conn
relayLock sync.Mutex
@@ -34,7 +34,7 @@ type WorkerRelay struct {
wgWatcher *WGWatcher
}
func NewWorkerRelay(ctx context.Context, log *log.Entry, ctrl bool, config ConnConfig, conn *Conn, relayManager relayClient.ManagerService, stateDump *stateDump) *WorkerRelay {
func NewWorkerRelay(ctx context.Context, log *log.Entry, ctrl bool, config ConnConfig, conn *Conn, relayManager *relayClient.Manager, stateDump *stateDump) *WorkerRelay {
r := &WorkerRelay{
peerCtx: ctx,
log: log,

View File

@@ -252,7 +252,7 @@ func TestSysOps_validateRoute_InvalidPrefix(t *testing.T) {
IP: wgNetwork.Addr(),
Network: wgNetwork,
},
name: "nb0",
name: "wt0",
}
sysOps := &SysOps{

View File

@@ -292,7 +292,7 @@ func (c *Client) Close() error {
}
func (c *Client) connect(ctx context.Context) (*RelayAddr, error) {
rd := dialer.NewRaceDial(c.log, c.connectionURL, quic.Dialer{}, ws.Dialer{})
rd := dialer.NewRaceDial(c.log, dialer.DefaultConnectionTimeout, c.connectionURL, quic.Dialer{}, ws.Dialer{})
conn, err := rd.Dial()
if err != nil {
return nil, err

View File

@@ -572,10 +572,14 @@ func TestCloseByServer(t *testing.T) {
idAlice := "alice"
log.Debugf("connect by alice")
relayClient := NewClient(serverURL, hmacTokenStore, idAlice)
err = relayClient.Connect(ctx)
if err != nil {
if err = relayClient.Connect(ctx); err != nil {
log.Fatalf("failed to connect to server: %s", err)
}
defer func() {
if err := relayClient.Close(); err != nil {
log.Errorf("failed to close client: %s", err)
}
}()
disconnected := make(chan struct{})
relayClient.SetOnDisconnectListener(func(_ string) {
@@ -591,7 +595,7 @@ func TestCloseByServer(t *testing.T) {
select {
case <-disconnected:
case <-time.After(3 * time.Second):
log.Fatalf("timeout waiting for client to disconnect")
log.Errorf("timeout waiting for client to disconnect")
}
_, err = relayClient.OpenConn(ctx, "bob")

View File

@@ -9,8 +9,8 @@ import (
log "github.com/sirupsen/logrus"
)
var (
connectionTimeout = 30 * time.Second
const (
DefaultConnectionTimeout = 30 * time.Second
)
type DialeFn interface {
@@ -25,16 +25,18 @@ type dialResult struct {
}
type RaceDial struct {
log *log.Entry
serverURL string
dialerFns []DialeFn
log *log.Entry
serverURL string
dialerFns []DialeFn
connectionTimeout time.Duration
}
func NewRaceDial(log *log.Entry, serverURL string, dialerFns ...DialeFn) *RaceDial {
func NewRaceDial(log *log.Entry, connectionTimeout time.Duration, serverURL string, dialerFns ...DialeFn) *RaceDial {
return &RaceDial{
log: log,
serverURL: serverURL,
dialerFns: dialerFns,
log: log,
serverURL: serverURL,
dialerFns: dialerFns,
connectionTimeout: connectionTimeout,
}
}
@@ -58,7 +60,7 @@ func (r *RaceDial) Dial() (net.Conn, error) {
}
func (r *RaceDial) dial(dfn DialeFn, abortCtx context.Context, connChan chan dialResult) {
ctx, cancel := context.WithTimeout(abortCtx, connectionTimeout)
ctx, cancel := context.WithTimeout(abortCtx, r.connectionTimeout)
defer cancel()
r.log.Infof("dialing Relay server via %s", dfn.Protocol())

View File

@@ -77,7 +77,7 @@ func TestRaceDialEmptyDialers(t *testing.T) {
logger := logrus.NewEntry(logrus.New())
serverURL := "test.server.com"
rd := NewRaceDial(logger, serverURL)
rd := NewRaceDial(logger, DefaultConnectionTimeout, serverURL)
conn, err := rd.Dial()
if err == nil {
t.Errorf("Expected an error with empty dialers, got nil")
@@ -103,7 +103,7 @@ func TestRaceDialSingleSuccessfulDialer(t *testing.T) {
protocolStr: proto,
}
rd := NewRaceDial(logger, serverURL, mockDialer)
rd := NewRaceDial(logger, DefaultConnectionTimeout, serverURL, mockDialer)
conn, err := rd.Dial()
if err != nil {
t.Errorf("Expected no error, got %v", err)
@@ -136,7 +136,7 @@ func TestRaceDialMultipleDialersWithOneSuccess(t *testing.T) {
protocolStr: "proto2",
}
rd := NewRaceDial(logger, serverURL, mockDialer1, mockDialer2)
rd := NewRaceDial(logger, DefaultConnectionTimeout, serverURL, mockDialer1, mockDialer2)
conn, err := rd.Dial()
if err != nil {
t.Errorf("Expected no error, got %v", err)
@@ -144,13 +144,13 @@ func TestRaceDialMultipleDialersWithOneSuccess(t *testing.T) {
if conn.RemoteAddr().Network() != proto2 {
t.Errorf("Expected connection with protocol %s, got %s", proto2, conn.RemoteAddr().Network())
}
_ = conn.Close()
}
func TestRaceDialTimeout(t *testing.T) {
logger := logrus.NewEntry(logrus.New())
serverURL := "test.server.com"
connectionTimeout = 3 * time.Second
mockDialer := &MockDialer{
dialFunc: func(ctx context.Context, address string) (net.Conn, error) {
<-ctx.Done()
@@ -159,7 +159,7 @@ func TestRaceDialTimeout(t *testing.T) {
protocolStr: "proto1",
}
rd := NewRaceDial(logger, serverURL, mockDialer)
rd := NewRaceDial(logger, 3*time.Second, serverURL, mockDialer)
conn, err := rd.Dial()
if err == nil {
t.Errorf("Expected an error, got nil")
@@ -187,7 +187,7 @@ func TestRaceDialAllDialersFail(t *testing.T) {
protocolStr: "protocol2",
}
rd := NewRaceDial(logger, serverURL, mockDialer1, mockDialer2)
rd := NewRaceDial(logger, DefaultConnectionTimeout, serverURL, mockDialer1, mockDialer2)
conn, err := rd.Dial()
if err == nil {
t.Errorf("Expected an error, got nil")
@@ -229,7 +229,7 @@ func TestRaceDialFirstSuccessfulDialerWins(t *testing.T) {
protocolStr: proto2,
}
rd := NewRaceDial(logger, serverURL, mockDialer1, mockDialer2)
rd := NewRaceDial(logger, DefaultConnectionTimeout, serverURL, mockDialer1, mockDialer2)
conn, err := rd.Dial()
if err != nil {
t.Errorf("Expected no error, got %v", err)

View File

@@ -8,7 +8,8 @@ import (
log "github.com/sirupsen/logrus"
)
var (
const (
// TODO: make it configurable, the manager should validate all configurable parameters
reconnectingTimeout = 60 * time.Second
)

View File

@@ -39,17 +39,6 @@ func NewRelayTrack() *RelayTrack {
type OnServerCloseListener func()
// ManagerService is the interface for the relay manager.
type ManagerService interface {
Serve() error
OpenConn(ctx context.Context, serverAddress, peerKey string) (net.Conn, error)
AddCloseListener(serverAddress string, onClosedListener OnServerCloseListener) error
RelayInstanceAddress() (string, error)
ServerURLs() []string
HasRelayAddress() bool
UpdateToken(token *relayAuth.Token) error
}
// Manager is a manager for the relay client instances. It establishes one persistent connection to the given relay URL
// and automatically reconnect to them in case disconnection.
// The manager also manage temporary relay connection. If a client wants to communicate with a client on a

View File

@@ -13,7 +13,9 @@ import (
)
func TestEmptyURL(t *testing.T) {
mgr := NewManager(context.Background(), nil, "alice")
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
mgr := NewManager(ctx, nil, "alice")
err := mgr.Serve()
if err == nil {
t.Errorf("expected error, got nil")
@@ -216,9 +218,11 @@ func TestForeginConnClose(t *testing.T) {
}
}
func TestForeginAutoClose(t *testing.T) {
func TestForeignAutoClose(t *testing.T) {
ctx := context.Background()
relayCleanupInterval = 1 * time.Second
keepUnusedServerTime = 2 * time.Second
srvCfg1 := server.ListenerConfig{
Address: "localhost:1234",
}
@@ -284,16 +288,35 @@ func TestForeginAutoClose(t *testing.T) {
t.Fatalf("failed to serve manager: %s", err)
}
// Set up a disconnect listener to track when foreign server disconnects
foreignServerURL := toURL(srvCfg2)[0]
disconnected := make(chan struct{})
onDisconnect := func() {
select {
case disconnected <- struct{}{}:
default:
}
}
t.Log("open connection to another peer")
if _, err = mgr.OpenConn(ctx, toURL(srvCfg2)[0], "anotherpeer"); err == nil {
if _, err = mgr.OpenConn(ctx, foreignServerURL, "anotherpeer"); err == nil {
t.Fatalf("should have failed to open connection to another peer")
}
timeout := relayCleanupInterval + keepUnusedServerTime + 1*time.Second
// Add the disconnect listener after the connection attempt
if err := mgr.AddCloseListener(foreignServerURL, onDisconnect); err != nil {
t.Logf("failed to add close listener (expected if connection failed): %s", err)
}
// Wait for cleanup to happen
timeout := relayCleanupInterval + keepUnusedServerTime + 2*time.Second
t.Logf("waiting for relay cleanup: %s", timeout)
time.Sleep(timeout)
if len(mgr.relayClients) != 0 {
t.Errorf("expected 0, got %d", len(mgr.relayClients))
select {
case <-disconnected:
t.Log("foreign relay connection cleaned up successfully")
case <-time.After(timeout):
t.Log("timeout waiting for cleanup - this might be expected if connection never established")
}
t.Logf("closing manager")
@@ -301,7 +324,6 @@ func TestForeginAutoClose(t *testing.T) {
func TestAutoReconnect(t *testing.T) {
ctx := context.Background()
reconnectingTimeout = 2 * time.Second
srvCfg := server.ListenerConfig{
Address: "localhost:1234",
@@ -312,8 +334,7 @@ func TestAutoReconnect(t *testing.T) {
}
errChan := make(chan error, 1)
go func() {
err := srv.Listen(srvCfg)
if err != nil {
if err := srv.Listen(srvCfg); err != nil {
errChan <- err
}
}()

View File

@@ -4,38 +4,76 @@ import (
"context"
"fmt"
"os"
"sync"
"testing"
"time"
log "github.com/sirupsen/logrus"
)
// Mutex to protect global variable access in tests
var testMutex sync.Mutex
func TestNewReceiver(t *testing.T) {
testMutex.Lock()
originalTimeout := heartbeatTimeout
heartbeatTimeout = 5 * time.Second
testMutex.Unlock()
defer func() {
testMutex.Lock()
heartbeatTimeout = originalTimeout
testMutex.Unlock()
}()
r := NewReceiver(log.WithContext(context.Background()))
defer r.Stop()
select {
case <-r.OnTimeout:
t.Error("unexpected timeout")
case <-time.After(1 * time.Second):
// Test passes if no timeout received
}
}
func TestNewReceiverNotReceive(t *testing.T) {
testMutex.Lock()
originalTimeout := heartbeatTimeout
heartbeatTimeout = 1 * time.Second
testMutex.Unlock()
defer func() {
testMutex.Lock()
heartbeatTimeout = originalTimeout
testMutex.Unlock()
}()
r := NewReceiver(log.WithContext(context.Background()))
defer r.Stop()
select {
case <-r.OnTimeout:
// Test passes if timeout is received
case <-time.After(2 * time.Second):
t.Error("timeout not received")
}
}
func TestNewReceiverAck(t *testing.T) {
testMutex.Lock()
originalTimeout := heartbeatTimeout
heartbeatTimeout = 2 * time.Second
testMutex.Unlock()
defer func() {
testMutex.Lock()
heartbeatTimeout = originalTimeout
testMutex.Unlock()
}()
r := NewReceiver(log.WithContext(context.Background()))
defer r.Stop()
r.Heartbeat()
@@ -59,13 +97,18 @@ func TestReceiverHealthCheckAttemptThreshold(t *testing.T) {
for _, tc := range testsCases {
t.Run(tc.name, func(t *testing.T) {
testMutex.Lock()
originalInterval := healthCheckInterval
originalTimeout := heartbeatTimeout
healthCheckInterval = 1 * time.Second
heartbeatTimeout = healthCheckInterval + 500*time.Millisecond
testMutex.Unlock()
defer func() {
testMutex.Lock()
healthCheckInterval = originalInterval
heartbeatTimeout = originalTimeout
testMutex.Unlock()
}()
//nolint:tenv
os.Setenv(defaultAttemptThresholdEnv, fmt.Sprintf("%d", tc.threshold))

View File

@@ -135,7 +135,11 @@ func TestSenderHealthCheckAttemptThreshold(t *testing.T) {
defer cancel()
sender := NewSender(log.WithField("test_name", tc.name))
go sender.StartHealthCheck(ctx)
senderExit := make(chan struct{})
go func() {
sender.StartHealthCheck(ctx)
close(senderExit)
}()
go func() {
responded := false
@@ -169,6 +173,11 @@ func TestSenderHealthCheckAttemptThreshold(t *testing.T) {
t.Fatalf("should have timed out before %s", testTimeout)
}
select {
case <-senderExit:
case <-time.After(2 * time.Second):
t.Fatalf("sender did not exit in time")
}
})
}

View File

@@ -20,12 +20,12 @@ type Metrics struct {
TransferBytesRecv metric.Int64Counter
AuthenticationTime metric.Float64Histogram
PeerStoreTime metric.Float64Histogram
peers metric.Int64UpDownCounter
peerActivityChan chan string
peerLastActive map[string]time.Time
mutexActivity sync.Mutex
ctx context.Context
peerReconnections metric.Int64Counter
peers metric.Int64UpDownCounter
peerActivityChan chan string
peerLastActive map[string]time.Time
mutexActivity sync.Mutex
ctx context.Context
}
func NewMetrics(ctx context.Context, meter metric.Meter) (*Metrics, error) {
@@ -80,6 +80,13 @@ func NewMetrics(ctx context.Context, meter metric.Meter) (*Metrics, error) {
return nil, err
}
peerReconnections, err := meter.Int64Counter("relay_peer_reconnections_total",
metric.WithDescription("Total number of times peers have reconnected and closed old connections"),
)
if err != nil {
return nil, err
}
m := &Metrics{
Meter: meter,
TransferBytesSent: bytesSent,
@@ -87,6 +94,7 @@ func NewMetrics(ctx context.Context, meter metric.Meter) (*Metrics, error) {
AuthenticationTime: authTime,
PeerStoreTime: peerStoreTime,
peers: peers,
peerReconnections: peerReconnections,
ctx: ctx,
peerActivityChan: make(chan string, 10),
@@ -138,6 +146,10 @@ func (m *Metrics) PeerDisconnected(id string) {
delete(m.peerLastActive, id)
}
func (m *Metrics) RecordPeerReconnection() {
m.peerReconnections.Add(m.ctx, 1)
}
// PeerActivity increases the active connections
func (m *Metrics) PeerActivity(peerID string) {
select {

View File

@@ -18,12 +18,9 @@ type Listener struct {
TLSConfig *tls.Config
listener *quic.Listener
acceptFn func(conn net.Conn)
}
func (l *Listener) Listen(acceptFn func(conn net.Conn)) error {
l.acceptFn = acceptFn
quicCfg := &quic.Config{
EnableDatagrams: true,
InitialPacketSize: 1452,
@@ -49,7 +46,7 @@ func (l *Listener) Listen(acceptFn func(conn net.Conn)) error {
log.Infof("QUIC client connected from: %s", session.RemoteAddr())
conn := NewConn(session)
l.acceptFn(conn)
acceptFn(conn)
}
}

View File

@@ -32,6 +32,9 @@ type Peer struct {
notifier *store.PeerNotifier
peersListener *store.Listener
// between the online peer collection step and the notification sending should not be sent offline notifications from another thread
notificationMutex sync.Mutex
}
// NewPeer creates a new Peer instance and prepare custom logging
@@ -241,10 +244,16 @@ func (p *Peer) handleSubscribePeerState(msg []byte) {
}
p.log.Debugf("received subscription message for %d peers", len(peerIDs))
onlinePeers := p.peersListener.AddInterestedPeers(peerIDs)
// collect online peers to response back to the caller
p.notificationMutex.Lock()
defer p.notificationMutex.Unlock()
onlinePeers := p.store.GetOnlinePeersAndRegisterInterest(peerIDs, p.peersListener)
if len(onlinePeers) == 0 {
return
}
p.log.Debugf("response with %d online peers", len(onlinePeers))
p.sendPeersOnline(onlinePeers)
}
@@ -274,6 +283,9 @@ func (p *Peer) sendPeersOnline(peers []messages.PeerID) {
}
func (p *Peer) sendPeersWentOffline(peers []messages.PeerID) {
p.notificationMutex.Lock()
defer p.notificationMutex.Unlock()
msgs, err := messages.MarshalPeersWentOffline(peers)
if err != nil {
p.log.Errorf("failed to marshal peer location message: %s", err)

View File

@@ -86,14 +86,13 @@ func NewRelay(config Config) (*Relay, error) {
return nil, fmt.Errorf("creating app metrics: %v", err)
}
peerStore := store.NewStore()
r := &Relay{
metrics: m,
metricsCancel: metricsCancel,
validator: config.AuthValidator,
instanceURL: config.instanceURL,
store: peerStore,
notifier: store.NewPeerNotifier(peerStore),
store: store.NewStore(),
notifier: store.NewPeerNotifier(),
}
r.preparedMsg, err = newPreparedMsg(r.instanceURL)
@@ -131,15 +130,18 @@ func (r *Relay) Accept(conn net.Conn) {
peer := NewPeer(r.metrics, *peerID, conn, r.store, r.notifier)
peer.log.Infof("peer connected from: %s", conn.RemoteAddr())
storeTime := time.Now()
r.store.AddPeer(peer)
if isReconnection := r.store.AddPeer(peer); isReconnection {
r.metrics.RecordPeerReconnection()
}
r.notifier.PeerCameOnline(peer.ID())
r.metrics.RecordPeerStoreTime(time.Since(storeTime))
r.metrics.PeerConnected(peer.String())
go func() {
peer.Work()
r.notifier.PeerWentOffline(peer.ID())
r.store.DeletePeer(peer)
if deleted := r.store.DeletePeer(peer); deleted {
r.notifier.PeerWentOffline(peer.ID())
}
peer.log.Debugf("relay connection closed")
r.metrics.PeerDisconnected(peer.String())
}()

View File

@@ -7,24 +7,27 @@ import (
"github.com/netbirdio/netbird/relay/messages"
)
type Listener struct {
ctx context.Context
store *Store
type event struct {
peerID messages.PeerID
online bool
}
onlineChan chan messages.PeerID
offlineChan chan messages.PeerID
type Listener struct {
ctx context.Context
eventChan chan *event
interestedPeersForOffline map[messages.PeerID]struct{}
interestedPeersForOnline map[messages.PeerID]struct{}
mu sync.RWMutex
}
func newListener(ctx context.Context, store *Store) *Listener {
func newListener(ctx context.Context) *Listener {
l := &Listener{
ctx: ctx,
store: store,
ctx: ctx,
onlineChan: make(chan messages.PeerID, 244), //244 is the message size limit in the relay protocol
offlineChan: make(chan messages.PeerID, 244), //244 is the message size limit in the relay protocol
// important to use a single channel for offline and online events because with it we can ensure all events
// will be processed in the order they were sent
eventChan: make(chan *event, 244), //244 is the message size limit in the relay protocol
interestedPeersForOffline: make(map[messages.PeerID]struct{}),
interestedPeersForOnline: make(map[messages.PeerID]struct{}),
}
@@ -32,8 +35,7 @@ func newListener(ctx context.Context, store *Store) *Listener {
return l
}
func (l *Listener) AddInterestedPeers(peerIDs []messages.PeerID) []messages.PeerID {
availablePeers := make([]messages.PeerID, 0)
func (l *Listener) AddInterestedPeers(peerIDs []messages.PeerID) {
l.mu.Lock()
defer l.mu.Unlock()
@@ -41,17 +43,6 @@ func (l *Listener) AddInterestedPeers(peerIDs []messages.PeerID) []messages.Peer
l.interestedPeersForOnline[id] = struct{}{}
l.interestedPeersForOffline[id] = struct{}{}
}
// collect online peers to response back to the caller
for _, id := range peerIDs {
_, ok := l.store.Peer(id)
if !ok {
continue
}
availablePeers = append(availablePeers, id)
}
return availablePeers
}
func (l *Listener) RemoveInterestedPeer(peerIDs []messages.PeerID) {
@@ -61,7 +52,6 @@ func (l *Listener) RemoveInterestedPeer(peerIDs []messages.PeerID) {
for _, id := range peerIDs {
delete(l.interestedPeersForOffline, id)
delete(l.interestedPeersForOnline, id)
}
}
@@ -70,26 +60,31 @@ func (l *Listener) listenForEvents(onPeersComeOnline, onPeersWentOffline func([]
select {
case <-l.ctx.Done():
return
case pID := <-l.onlineChan:
peers := make([]messages.PeerID, 0)
peers = append(peers, pID)
for len(l.onlineChan) > 0 {
pID = <-l.onlineChan
peers = append(peers, pID)
case e := <-l.eventChan:
peersOffline := make([]messages.PeerID, 0)
peersOnline := make([]messages.PeerID, 0)
if e.online {
peersOnline = append(peersOnline, e.peerID)
} else {
peersOffline = append(peersOffline, e.peerID)
}
onPeersComeOnline(peers)
case pID := <-l.offlineChan:
peers := make([]messages.PeerID, 0)
peers = append(peers, pID)
for len(l.offlineChan) > 0 {
pID = <-l.offlineChan
peers = append(peers, pID)
// Drain the channel to collect all events
for len(l.eventChan) > 0 {
e = <-l.eventChan
if e.online {
peersOnline = append(peersOnline, e.peerID)
} else {
peersOffline = append(peersOffline, e.peerID)
}
}
onPeersWentOffline(peers)
if len(peersOnline) > 0 {
onPeersComeOnline(peersOnline)
}
if len(peersOffline) > 0 {
onPeersWentOffline(peersOffline)
}
}
}
}
@@ -100,7 +95,10 @@ func (l *Listener) peerWentOffline(peerID messages.PeerID) {
if _, ok := l.interestedPeersForOffline[peerID]; ok {
select {
case l.offlineChan <- peerID:
case l.eventChan <- &event{
peerID: peerID,
online: false,
}:
case <-l.ctx.Done():
}
}
@@ -112,9 +110,13 @@ func (l *Listener) peerComeOnline(peerID messages.PeerID) {
if _, ok := l.interestedPeersForOnline[peerID]; ok {
select {
case l.onlineChan <- peerID:
case l.eventChan <- &event{
peerID: peerID,
online: true,
}:
case <-l.ctx.Done():
}
delete(l.interestedPeersForOnline, peerID)
}
}

View File

@@ -8,15 +8,12 @@ import (
)
type PeerNotifier struct {
store *Store
listeners map[*Listener]context.CancelFunc
listenersMutex sync.RWMutex
}
func NewPeerNotifier(store *Store) *PeerNotifier {
func NewPeerNotifier() *PeerNotifier {
pn := &PeerNotifier{
store: store,
listeners: make(map[*Listener]context.CancelFunc),
}
return pn
@@ -24,7 +21,7 @@ func NewPeerNotifier(store *Store) *PeerNotifier {
func (pn *PeerNotifier) NewListener(onPeersComeOnline, onPeersWentOffline func([]messages.PeerID)) *Listener {
ctx, cancel := context.WithCancel(context.Background())
listener := newListener(ctx, pn.store)
listener := newListener(ctx)
go listener.listenForEvents(onPeersComeOnline, onPeersWentOffline)
pn.listenersMutex.Lock()

View File

@@ -26,7 +26,9 @@ func NewStore() *Store {
}
// AddPeer adds a peer to the store
func (s *Store) AddPeer(peer IPeer) {
// If the peer already exists, it will be replaced and the old peer will be closed
// Returns true if the peer was replaced, false if it was added for the first time.
func (s *Store) AddPeer(peer IPeer) bool {
s.peersLock.Lock()
defer s.peersLock.Unlock()
odlPeer, ok := s.peers[peer.ID()]
@@ -35,22 +37,24 @@ func (s *Store) AddPeer(peer IPeer) {
}
s.peers[peer.ID()] = peer
return ok
}
// DeletePeer deletes a peer from the store
func (s *Store) DeletePeer(peer IPeer) {
func (s *Store) DeletePeer(peer IPeer) bool {
s.peersLock.Lock()
defer s.peersLock.Unlock()
dp, ok := s.peers[peer.ID()]
if !ok {
return
return false
}
if dp != peer {
return
return false
}
delete(s.peers, peer.ID())
return true
}
// Peer returns a peer by its ID
@@ -73,3 +77,21 @@ func (s *Store) Peers() []IPeer {
}
return peers
}
// GetOnlinePeersAndRegisterInterest registers the given peer IDs with the
// listener and returns the subset of those peers that are currently online.
//
// Both steps run under the store's read lock so the interest registration and
// the online snapshot stay consistent with each other: AddPeer/DeletePeer
// take the write lock, so no peer can be added or removed in between.
// NOTE(review): listener.AddInterestedPeers is called while s.peersLock is
// held — confirm no code path acquires these locks in the opposite order.
func (s *Store) GetOnlinePeersAndRegisterInterest(peerIDs []messages.PeerID, listener *Listener) []messages.PeerID {
	s.peersLock.RLock()
	defer s.peersLock.RUnlock()
	onlinePeers := make([]messages.PeerID, 0, len(peerIDs))
	listener.AddInterestedPeers(peerIDs)
	// Check for currently online peers
	for _, id := range peerIDs {
		if _, ok := s.peers[id]; ok {
			onlinePeers = append(onlinePeers, id)
		}
	}
	return onlinePeers
}
}