From e908dea702eb4520021b0cd0806e695619777127 Mon Sep 17 00:00:00 2001 From: Zoltan Papp Date: Wed, 21 Jan 2026 10:42:13 +0100 Subject: [PATCH] [client] Extend WG watcher for ICE connection too (#5133) Extend WG watcher for ICE connection too --- client/internal/peer/conn.go | 84 +++++++++++++++++++------ client/internal/peer/wg_watcher.go | 72 ++++++++++----------- client/internal/peer/wg_watcher_test.go | 20 +++--- client/internal/peer/worker_relay.go | 22 +------ 4 files changed, 114 insertions(+), 84 deletions(-) diff --git a/client/internal/peer/conn.go b/client/internal/peer/conn.go index ba82354a2..39133a6d3 100644 --- a/client/internal/peer/conn.go +++ b/client/internal/peer/conn.go @@ -99,7 +99,10 @@ type Conn struct { workerICE *WorkerICE workerRelay *WorkerRelay - wgWatcherWg sync.WaitGroup + + wgWatcher *WGWatcher + wgWatcherWg sync.WaitGroup + wgWatcherCancel context.CancelFunc // used to store the remote Rosenpass key for Relayed connection in case of connection update from ice rosenpassRemoteKey []byte @@ -127,6 +130,7 @@ func NewConn(config ConnConfig, services ServiceDependencies) (*Conn, error) { connLog := log.WithField("peer", config.Key) + dumpState := newStateDump(config.Key, connLog, services.StatusRecorder) var conn = &Conn{ Log: connLog, config: config, @@ -138,8 +142,9 @@ func NewConn(config ConnConfig, services ServiceDependencies) (*Conn, error) { semaphore: services.Semaphore, statusRelay: worker.NewAtomicStatus(), statusICE: worker.NewAtomicStatus(), - dumpState: newStateDump(config.Key, connLog, services.StatusRecorder), + dumpState: dumpState, endpointUpdater: NewEndpointUpdater(connLog, config.WgConfig, isController(config)), + wgWatcher: NewWGWatcher(connLog, config.WgConfig.WgInterface, config.Key, dumpState), } return conn, nil @@ -163,7 +168,7 @@ func (conn *Conn) Open(engineCtx context.Context) error { conn.ctx, conn.ctxCancel = context.WithCancel(engineCtx) - conn.workerRelay = NewWorkerRelay(conn.ctx, conn.Log, isController(conn.config), conn.config, conn, conn.relayManager, conn.dumpState) + conn.workerRelay = NewWorkerRelay(conn.ctx, conn.Log, isController(conn.config), conn.config, conn, conn.relayManager) relayIsSupportedLocally := conn.workerRelay.RelayIsSupportedLocally() workerICE, err := NewWorkerICE(conn.ctx, conn.Log, conn.config, conn, conn.signaler, conn.iFaceDiscover, conn.statusRecorder, relayIsSupportedLocally) @@ -232,7 +237,9 @@ func (conn *Conn) Close(signalToRemote bool) { conn.Log.Infof("close peer connection") conn.ctxCancel() - conn.workerRelay.DisableWgWatcher() + if conn.wgWatcherCancel != nil { + conn.wgWatcherCancel() + } conn.workerRelay.CloseConn() conn.workerICE.Close() @@ -374,9 +381,6 @@ func (conn *Conn) onICEConnectionIsReady(priority conntype.ConnPriority, iceConn ep = directEp } - conn.workerRelay.DisableWgWatcher() - // todo consider to run conn.wgWatcherWg.Wait() here - if conn.wgProxyRelay != nil { conn.wgProxyRelay.Pause() } @@ -398,6 +402,8 @@ func (conn *Conn) onICEConnectionIsReady(priority conntype.ConnPriority, iceConn conn.wgProxyRelay.RedirectAs(ep) } + conn.enableWgWatcherIfNeeded() + conn.currentConnPriority = priority conn.statusICE.SetConnected() conn.updateIceState(iceConnInfo) @@ -431,11 +437,6 @@ func (conn *Conn) onICEStateDisconnected() { conn.Log.Errorf("failed to switch to relay conn: %v", err) } - conn.wgWatcherWg.Add(1) - go func() { - defer conn.wgWatcherWg.Done() - conn.workerRelay.EnableWgWatcher(conn.ctx) - }() conn.wgProxyRelay.Work() conn.currentConnPriority = conntype.Relay } else { @@ -452,15 +453,15 @@ func (conn *Conn) onICEStateDisconnected() { } conn.statusICE.SetDisconnected() + conn.disableWgWatcherIfNeeded() + peerState := State{ PubKey: conn.config.Key, ConnStatus: conn.evalStatus(), Relayed: conn.isRelayed(), ConnStatusUpdate: time.Now(), } - - err := conn.statusRecorder.UpdatePeerICEStateToDisconnected(peerState) - if err != nil { + if err := conn.statusRecorder.UpdatePeerICEStateToDisconnected(peerState); err != nil { conn.Log.Warnf("unable to set peer's state to disconnected ice, got error: %v", err) } } @@ -508,11 +509,7 @@ func (conn *Conn) onRelayConnectionIsReady(rci RelayConnInfo) { return } - conn.wgWatcherWg.Add(1) - go func() { - defer conn.wgWatcherWg.Done() - conn.workerRelay.EnableWgWatcher(conn.ctx) - }() + conn.enableWgWatcherIfNeeded() wgConfigWorkaround() conn.rosenpassRemoteKey = rci.rosenpassPubKey @@ -527,7 +524,11 @@ func (conn *Conn) onRelayConnectionIsReady(rci RelayConnInfo) { func (conn *Conn) onRelayDisconnected() { conn.mu.Lock() defer conn.mu.Unlock() + conn.handleRelayDisconnectedLocked() +} +// handleRelayDisconnectedLocked handles relay disconnection. Caller must hold conn.mu. +func (conn *Conn) handleRelayDisconnectedLocked() { if conn.ctx.Err() != nil { return } @@ -553,6 +554,8 @@ func (conn *Conn) onRelayDisconnected() { } conn.statusRelay.SetDisconnected() + conn.disableWgWatcherIfNeeded() + peerState := State{ PubKey: conn.config.Key, ConnStatus: conn.evalStatus(), @@ -571,6 +574,28 @@ func (conn *Conn) onGuardEvent() { } } +func (conn *Conn) onWGDisconnected() { + conn.mu.Lock() + defer conn.mu.Unlock() + + if conn.ctx.Err() != nil { + return + } + + conn.Log.Warnf("WireGuard handshake timeout detected, closing current connection") + + // Close the active connection based on current priority + switch conn.currentConnPriority { + case conntype.Relay: + conn.workerRelay.CloseConn() + conn.handleRelayDisconnectedLocked() + case conntype.ICEP2P, conntype.ICETurn: + conn.workerICE.Close() + default: + conn.Log.Debugf("No active connection to close on WG timeout") + } +} + func (conn *Conn) updateRelayStatus(relayServerAddr string, rosenpassPubKey []byte) { peerState := State{ PubKey: conn.config.Key, @@ -697,6 +722,25 @@ func (conn *Conn) isConnectedOnAllWay() (connected bool) { return true } +func (conn *Conn) enableWgWatcherIfNeeded() { + if !conn.wgWatcher.IsEnabled() { + wgWatcherCtx, wgWatcherCancel := context.WithCancel(conn.ctx) + conn.wgWatcherCancel = wgWatcherCancel + conn.wgWatcherWg.Add(1) + go func() { + defer conn.wgWatcherWg.Done() + conn.wgWatcher.EnableWgWatcher(wgWatcherCtx, conn.onWGDisconnected) + }() + } +} + +func (conn *Conn) disableWgWatcherIfNeeded() { + if conn.currentConnPriority == conntype.None && conn.wgWatcherCancel != nil { + conn.wgWatcherCancel() + conn.wgWatcherCancel = nil + } +} + func (conn *Conn) newProxy(remoteConn net.Conn) (wgproxy.Proxy, error) { conn.Log.Debugf("setup proxied WireGuard connection") udpAddr := &net.UDPAddr{ diff --git a/client/internal/peer/wg_watcher.go b/client/internal/peer/wg_watcher.go index 0ed200fda..d40ec7a80 100644 --- a/client/internal/peer/wg_watcher.go +++ b/client/internal/peer/wg_watcher.go @@ -30,10 +30,8 @@ type WGWatcher struct { peerKey string stateDump *stateDump - ctx context.Context - ctxCancel context.CancelFunc - ctxLock sync.Mutex - enabledTime time.Time + enabled bool + muEnabled sync.RWMutex } func NewWGWatcher(log *log.Entry, wgIfaceStater WGInterfaceStater, peerKey string, stateDump *stateDump) *WGWatcher { @@ -46,52 +44,44 @@ func NewWGWatcher(log *log.Entry, wgIfaceStater WGInterfaceStater, peerKey strin } // EnableWgWatcher starts the WireGuard watcher. If it is already enabled, it will return immediately and do nothing. -func (w *WGWatcher) EnableWgWatcher(parentCtx context.Context, onDisconnectedFn func()) { - w.log.Debugf("enable WireGuard watcher") - w.ctxLock.Lock() - w.enabledTime = time.Now() - - if w.ctx != nil && w.ctx.Err() == nil { - w.log.Errorf("WireGuard watcher already enabled") - w.ctxLock.Unlock() +// The watcher runs until ctx is cancelled. Caller is responsible for context lifecycle management. +func (w *WGWatcher) EnableWgWatcher(ctx context.Context, onDisconnectedFn func()) { + w.muEnabled.Lock() + if w.enabled { + w.muEnabled.Unlock() return } - ctx, ctxCancel := context.WithCancel(parentCtx) - w.ctx = ctx - w.ctxCancel = ctxCancel - w.ctxLock.Unlock() + w.log.Debugf("enable WireGuard watcher") + enabledTime := time.Now() + w.enabled = true + w.muEnabled.Unlock() initialHandshake, err := w.wgState() if err != nil { w.log.Warnf("failed to read initial wg stats: %v", err) } - w.periodicHandshakeCheck(ctx, ctxCancel, onDisconnectedFn, initialHandshake) + w.periodicHandshakeCheck(ctx, onDisconnectedFn, enabledTime, initialHandshake) + + w.muEnabled.Lock() + w.enabled = false + w.muEnabled.Unlock() } -// DisableWgWatcher stops the WireGuard watcher and wait for the watcher to exit -func (w *WGWatcher) DisableWgWatcher() { - w.ctxLock.Lock() - defer w.ctxLock.Unlock() - - if w.ctxCancel == nil { - return - } - - w.log.Debugf("disable WireGuard watcher") - - w.ctxCancel() - w.ctxCancel = nil +// IsEnabled returns true if the WireGuard watcher is currently enabled +func (w *WGWatcher) IsEnabled() bool { + w.muEnabled.RLock() + defer w.muEnabled.RUnlock() + return w.enabled } // wgStateCheck help to check the state of the WireGuard handshake and relay connection -func (w *WGWatcher) periodicHandshakeCheck(ctx context.Context, ctxCancel context.CancelFunc, onDisconnectedFn func(), initialHandshake time.Time) { +func (w *WGWatcher) periodicHandshakeCheck(ctx context.Context, onDisconnectedFn func(), enabledTime time.Time, initialHandshake time.Time) { w.log.Infof("WireGuard watcher started") timer := time.NewTimer(wgHandshakeOvertime) defer timer.Stop() - defer ctxCancel() lastHandshake := initialHandshake @@ -104,7 +94,7 @@ func (w *WGWatcher) periodicHandshakeCheck(ctx context.Context, ctxCancel contex return } if lastHandshake.IsZero() { - elapsed := handshake.Sub(w.enabledTime).Seconds() + elapsed := calcElapsed(enabledTime, *handshake) w.log.Infof("first wg handshake detected within: %.2fsec, (%s)", elapsed, handshake) } @@ -134,19 +124,19 @@ func (w *WGWatcher) handshakeCheck(lastHandshake time.Time) (*time.Time, bool) { // the current know handshake did not change if handshake.Equal(lastHandshake) { - w.log.Warnf("WireGuard handshake timed out, closing relay connection: %v", handshake) + w.log.Warnf("WireGuard handshake timed out: %v", handshake) return nil, false } // in case if the machine is suspended, the handshake time will be in the past if handshake.Add(checkPeriod).Before(time.Now()) { - w.log.Warnf("WireGuard handshake timed out, closing relay connection: %v", handshake) + w.log.Warnf("WireGuard handshake timed out: %v", handshake) return nil, false } // error handling for handshake time in the future if handshake.After(time.Now()) { - w.log.Warnf("WireGuard handshake is in the future, closing relay connection: %v", handshake) + w.log.Warnf("WireGuard handshake is in the future: %v", handshake) return nil, false } @@ -164,3 +154,13 @@ func (w *WGWatcher) wgState() (time.Time, error) { } return wgState.LastHandshake, nil } + +// calcElapsed calculates elapsed time since watcher was enabled. +// The watcher started after the wg configuration happens, because of this need to normalise the negative value +func calcElapsed(enabledTime, handshake time.Time) float64 { + elapsed := handshake.Sub(enabledTime).Seconds() + if elapsed < 0 { + elapsed = 0 + } + return elapsed +} diff --git a/client/internal/peer/wg_watcher_test.go b/client/internal/peer/wg_watcher_test.go index d7c277eff..f79405a01 100644 --- a/client/internal/peer/wg_watcher_test.go +++ b/client/internal/peer/wg_watcher_test.go @@ -2,6 +2,7 @@ package peer import ( "context" + "sync" "testing" "time" @@ -48,7 +49,6 @@ func TestWGWatcher_EnableWgWatcher(t *testing.T) { case <-time.After(10 * time.Second): t.Errorf("timeout") } - watcher.DisableWgWatcher() } func TestWGWatcher_ReEnable(t *testing.T) { @@ -60,14 +60,21 @@ func TestWGWatcher_ReEnable(t *testing.T) { watcher := NewWGWatcher(mlog, mocWgIface, "", newStateDump("peer", mlog, &Status{})) ctx, cancel := context.WithCancel(context.Background()) + wg := &sync.WaitGroup{} + wg.Add(1) + go func() { + defer wg.Done() + watcher.EnableWgWatcher(ctx, func() {}) + }() + cancel() + + wg.Wait() + + // Re-enable with a new context + ctx, cancel = context.WithCancel(context.Background()) defer cancel() onDisconnected := make(chan struct{}, 1) - - go watcher.EnableWgWatcher(ctx, func() {}) - time.Sleep(1 * time.Second) - watcher.DisableWgWatcher() - go watcher.EnableWgWatcher(ctx, func() { onDisconnected <- struct{}{} }) @@ -80,5 +87,4 @@ func TestWGWatcher_ReEnable(t *testing.T) { case <-time.After(10 * time.Second): t.Errorf("timeout") } - watcher.DisableWgWatcher() } diff --git a/client/internal/peer/worker_relay.go b/client/internal/peer/worker_relay.go index f584487f5..06309fbaf 100644 --- a/client/internal/peer/worker_relay.go +++ b/client/internal/peer/worker_relay.go @@ -30,11 +30,9 @@ type WorkerRelay struct { relayLock sync.Mutex relaySupportedOnRemotePeer atomic.Bool - - wgWatcher *WGWatcher } -func NewWorkerRelay(ctx context.Context, log *log.Entry, ctrl bool, config ConnConfig, conn *Conn, relayManager *relayClient.Manager, stateDump *stateDump) *WorkerRelay { +func NewWorkerRelay(ctx context.Context, log *log.Entry, ctrl bool, config ConnConfig, conn *Conn, relayManager *relayClient.Manager) *WorkerRelay { r := &WorkerRelay{ peerCtx: ctx, log: log, @@ -42,7 +40,6 @@ func NewWorkerRelay(ctx context.Context, log *log.Entry, ctrl bool, config ConnC config: config, conn: conn, relayManager: relayManager, - wgWatcher: NewWGWatcher(log, config.WgConfig.WgInterface, config.Key, stateDump), } return r } @@ -93,14 +90,6 @@ func (w *WorkerRelay) OnNewOffer(remoteOfferAnswer *OfferAnswer) { }) } -func (w *WorkerRelay) EnableWgWatcher(ctx context.Context) { - w.wgWatcher.EnableWgWatcher(ctx, w.onWGDisconnected) -} - -func (w *WorkerRelay) DisableWgWatcher() { - w.wgWatcher.DisableWgWatcher() -} - func (w *WorkerRelay) RelayInstanceAddress() (string, error) { return w.relayManager.RelayInstanceAddress() } @@ -125,14 +114,6 @@ func (w *WorkerRelay) CloseConn() { } } -func (w *WorkerRelay) onWGDisconnected() { - w.relayLock.Lock() - _ = w.relayedConn.Close() - w.relayLock.Unlock() - - w.conn.onRelayDisconnected() -} - func (w *WorkerRelay) isRelaySupported(answer *OfferAnswer) bool { if !w.relayManager.HasRelayAddress() { return false @@ -148,6 +129,5 @@ func (w *WorkerRelay) preferredRelayServer(myRelayAddress, remoteRelayAddress st } func (w *WorkerRelay) onRelayClientDisconnected() { - w.wgWatcher.DisableWgWatcher() go w.conn.onRelayDisconnected() }