[client] Extend WG watcher for ICE connection too (#5133)

Extend WG watcher for ICE connection too
This commit is contained in:
Zoltan Papp
2026-01-21 10:42:13 +01:00
committed by GitHub
parent 030650a905
commit e908dea702
4 changed files with 114 additions and 84 deletions

View File

@@ -99,7 +99,10 @@ type Conn struct {
workerICE *WorkerICE workerICE *WorkerICE
workerRelay *WorkerRelay workerRelay *WorkerRelay
wgWatcherWg sync.WaitGroup
wgWatcher *WGWatcher
wgWatcherWg sync.WaitGroup
wgWatcherCancel context.CancelFunc
// used to store the remote Rosenpass key for Relayed connection in case of connection update from ice // used to store the remote Rosenpass key for Relayed connection in case of connection update from ice
rosenpassRemoteKey []byte rosenpassRemoteKey []byte
@@ -127,6 +130,7 @@ func NewConn(config ConnConfig, services ServiceDependencies) (*Conn, error) {
connLog := log.WithField("peer", config.Key) connLog := log.WithField("peer", config.Key)
dumpState := newStateDump(config.Key, connLog, services.StatusRecorder)
var conn = &Conn{ var conn = &Conn{
Log: connLog, Log: connLog,
config: config, config: config,
@@ -138,8 +142,9 @@ func NewConn(config ConnConfig, services ServiceDependencies) (*Conn, error) {
semaphore: services.Semaphore, semaphore: services.Semaphore,
statusRelay: worker.NewAtomicStatus(), statusRelay: worker.NewAtomicStatus(),
statusICE: worker.NewAtomicStatus(), statusICE: worker.NewAtomicStatus(),
dumpState: newStateDump(config.Key, connLog, services.StatusRecorder), dumpState: dumpState,
endpointUpdater: NewEndpointUpdater(connLog, config.WgConfig, isController(config)), endpointUpdater: NewEndpointUpdater(connLog, config.WgConfig, isController(config)),
wgWatcher: NewWGWatcher(connLog, config.WgConfig.WgInterface, config.Key, dumpState),
} }
return conn, nil return conn, nil
@@ -163,7 +168,7 @@ func (conn *Conn) Open(engineCtx context.Context) error {
conn.ctx, conn.ctxCancel = context.WithCancel(engineCtx) conn.ctx, conn.ctxCancel = context.WithCancel(engineCtx)
conn.workerRelay = NewWorkerRelay(conn.ctx, conn.Log, isController(conn.config), conn.config, conn, conn.relayManager, conn.dumpState) conn.workerRelay = NewWorkerRelay(conn.ctx, conn.Log, isController(conn.config), conn.config, conn, conn.relayManager)
relayIsSupportedLocally := conn.workerRelay.RelayIsSupportedLocally() relayIsSupportedLocally := conn.workerRelay.RelayIsSupportedLocally()
workerICE, err := NewWorkerICE(conn.ctx, conn.Log, conn.config, conn, conn.signaler, conn.iFaceDiscover, conn.statusRecorder, relayIsSupportedLocally) workerICE, err := NewWorkerICE(conn.ctx, conn.Log, conn.config, conn, conn.signaler, conn.iFaceDiscover, conn.statusRecorder, relayIsSupportedLocally)
@@ -232,7 +237,9 @@ func (conn *Conn) Close(signalToRemote bool) {
conn.Log.Infof("close peer connection") conn.Log.Infof("close peer connection")
conn.ctxCancel() conn.ctxCancel()
conn.workerRelay.DisableWgWatcher() if conn.wgWatcherCancel != nil {
conn.wgWatcherCancel()
}
conn.workerRelay.CloseConn() conn.workerRelay.CloseConn()
conn.workerICE.Close() conn.workerICE.Close()
@@ -374,9 +381,6 @@ func (conn *Conn) onICEConnectionIsReady(priority conntype.ConnPriority, iceConn
ep = directEp ep = directEp
} }
conn.workerRelay.DisableWgWatcher()
// todo consider to run conn.wgWatcherWg.Wait() here
if conn.wgProxyRelay != nil { if conn.wgProxyRelay != nil {
conn.wgProxyRelay.Pause() conn.wgProxyRelay.Pause()
} }
@@ -398,6 +402,8 @@ func (conn *Conn) onICEConnectionIsReady(priority conntype.ConnPriority, iceConn
conn.wgProxyRelay.RedirectAs(ep) conn.wgProxyRelay.RedirectAs(ep)
} }
conn.enableWgWatcherIfNeeded()
conn.currentConnPriority = priority conn.currentConnPriority = priority
conn.statusICE.SetConnected() conn.statusICE.SetConnected()
conn.updateIceState(iceConnInfo) conn.updateIceState(iceConnInfo)
@@ -431,11 +437,6 @@ func (conn *Conn) onICEStateDisconnected() {
conn.Log.Errorf("failed to switch to relay conn: %v", err) conn.Log.Errorf("failed to switch to relay conn: %v", err)
} }
conn.wgWatcherWg.Add(1)
go func() {
defer conn.wgWatcherWg.Done()
conn.workerRelay.EnableWgWatcher(conn.ctx)
}()
conn.wgProxyRelay.Work() conn.wgProxyRelay.Work()
conn.currentConnPriority = conntype.Relay conn.currentConnPriority = conntype.Relay
} else { } else {
@@ -452,15 +453,15 @@ func (conn *Conn) onICEStateDisconnected() {
} }
conn.statusICE.SetDisconnected() conn.statusICE.SetDisconnected()
conn.disableWgWatcherIfNeeded()
peerState := State{ peerState := State{
PubKey: conn.config.Key, PubKey: conn.config.Key,
ConnStatus: conn.evalStatus(), ConnStatus: conn.evalStatus(),
Relayed: conn.isRelayed(), Relayed: conn.isRelayed(),
ConnStatusUpdate: time.Now(), ConnStatusUpdate: time.Now(),
} }
if err := conn.statusRecorder.UpdatePeerICEStateToDisconnected(peerState); err != nil {
err := conn.statusRecorder.UpdatePeerICEStateToDisconnected(peerState)
if err != nil {
conn.Log.Warnf("unable to set peer's state to disconnected ice, got error: %v", err) conn.Log.Warnf("unable to set peer's state to disconnected ice, got error: %v", err)
} }
} }
@@ -508,11 +509,7 @@ func (conn *Conn) onRelayConnectionIsReady(rci RelayConnInfo) {
return return
} }
conn.wgWatcherWg.Add(1) conn.enableWgWatcherIfNeeded()
go func() {
defer conn.wgWatcherWg.Done()
conn.workerRelay.EnableWgWatcher(conn.ctx)
}()
wgConfigWorkaround() wgConfigWorkaround()
conn.rosenpassRemoteKey = rci.rosenpassPubKey conn.rosenpassRemoteKey = rci.rosenpassPubKey
@@ -527,7 +524,11 @@ func (conn *Conn) onRelayConnectionIsReady(rci RelayConnInfo) {
func (conn *Conn) onRelayDisconnected() { func (conn *Conn) onRelayDisconnected() {
conn.mu.Lock() conn.mu.Lock()
defer conn.mu.Unlock() defer conn.mu.Unlock()
conn.handleRelayDisconnectedLocked()
}
// handleRelayDisconnectedLocked handles relay disconnection. Caller must hold conn.mu.
func (conn *Conn) handleRelayDisconnectedLocked() {
if conn.ctx.Err() != nil { if conn.ctx.Err() != nil {
return return
} }
@@ -553,6 +554,8 @@ func (conn *Conn) onRelayDisconnected() {
} }
conn.statusRelay.SetDisconnected() conn.statusRelay.SetDisconnected()
conn.disableWgWatcherIfNeeded()
peerState := State{ peerState := State{
PubKey: conn.config.Key, PubKey: conn.config.Key,
ConnStatus: conn.evalStatus(), ConnStatus: conn.evalStatus(),
@@ -571,6 +574,28 @@ func (conn *Conn) onGuardEvent() {
} }
} }
func (conn *Conn) onWGDisconnected() {
conn.mu.Lock()
defer conn.mu.Unlock()
if conn.ctx.Err() != nil {
return
}
conn.Log.Warnf("WireGuard handshake timeout detected, closing current connection")
// Close the active connection based on current priority
switch conn.currentConnPriority {
case conntype.Relay:
conn.workerRelay.CloseConn()
conn.handleRelayDisconnectedLocked()
case conntype.ICEP2P, conntype.ICETurn:
conn.workerICE.Close()
default:
conn.Log.Debugf("No active connection to close on WG timeout")
}
}
func (conn *Conn) updateRelayStatus(relayServerAddr string, rosenpassPubKey []byte) { func (conn *Conn) updateRelayStatus(relayServerAddr string, rosenpassPubKey []byte) {
peerState := State{ peerState := State{
PubKey: conn.config.Key, PubKey: conn.config.Key,
@@ -697,6 +722,25 @@ func (conn *Conn) isConnectedOnAllWay() (connected bool) {
return true return true
} }
func (conn *Conn) enableWgWatcherIfNeeded() {
if !conn.wgWatcher.IsEnabled() {
wgWatcherCtx, wgWatcherCancel := context.WithCancel(conn.ctx)
conn.wgWatcherCancel = wgWatcherCancel
conn.wgWatcherWg.Add(1)
go func() {
defer conn.wgWatcherWg.Done()
conn.wgWatcher.EnableWgWatcher(wgWatcherCtx, conn.onWGDisconnected)
}()
}
}
func (conn *Conn) disableWgWatcherIfNeeded() {
if conn.currentConnPriority == conntype.None && conn.wgWatcherCancel != nil {
conn.wgWatcherCancel()
conn.wgWatcherCancel = nil
}
}
func (conn *Conn) newProxy(remoteConn net.Conn) (wgproxy.Proxy, error) { func (conn *Conn) newProxy(remoteConn net.Conn) (wgproxy.Proxy, error) {
conn.Log.Debugf("setup proxied WireGuard connection") conn.Log.Debugf("setup proxied WireGuard connection")
udpAddr := &net.UDPAddr{ udpAddr := &net.UDPAddr{

View File

@@ -30,10 +30,8 @@ type WGWatcher struct {
peerKey string peerKey string
stateDump *stateDump stateDump *stateDump
ctx context.Context enabled bool
ctxCancel context.CancelFunc muEnabled sync.RWMutex
ctxLock sync.Mutex
enabledTime time.Time
} }
func NewWGWatcher(log *log.Entry, wgIfaceStater WGInterfaceStater, peerKey string, stateDump *stateDump) *WGWatcher { func NewWGWatcher(log *log.Entry, wgIfaceStater WGInterfaceStater, peerKey string, stateDump *stateDump) *WGWatcher {
@@ -46,52 +44,44 @@ func NewWGWatcher(log *log.Entry, wgIfaceStater WGInterfaceStater, peerKey strin
} }
// EnableWgWatcher starts the WireGuard watcher. If it is already enabled, it will return immediately and do nothing. // EnableWgWatcher starts the WireGuard watcher. If it is already enabled, it will return immediately and do nothing.
func (w *WGWatcher) EnableWgWatcher(parentCtx context.Context, onDisconnectedFn func()) { // The watcher runs until ctx is cancelled. Caller is responsible for context lifecycle management.
w.log.Debugf("enable WireGuard watcher") func (w *WGWatcher) EnableWgWatcher(ctx context.Context, onDisconnectedFn func()) {
w.ctxLock.Lock() w.muEnabled.Lock()
w.enabledTime = time.Now() if w.enabled {
w.muEnabled.Unlock()
if w.ctx != nil && w.ctx.Err() == nil {
w.log.Errorf("WireGuard watcher already enabled")
w.ctxLock.Unlock()
return return
} }
ctx, ctxCancel := context.WithCancel(parentCtx) w.log.Debugf("enable WireGuard watcher")
w.ctx = ctx enabledTime := time.Now()
w.ctxCancel = ctxCancel w.enabled = true
w.ctxLock.Unlock() w.muEnabled.Unlock()
initialHandshake, err := w.wgState() initialHandshake, err := w.wgState()
if err != nil { if err != nil {
w.log.Warnf("failed to read initial wg stats: %v", err) w.log.Warnf("failed to read initial wg stats: %v", err)
} }
w.periodicHandshakeCheck(ctx, ctxCancel, onDisconnectedFn, initialHandshake) w.periodicHandshakeCheck(ctx, onDisconnectedFn, enabledTime, initialHandshake)
w.muEnabled.Lock()
w.enabled = false
w.muEnabled.Unlock()
} }
// DisableWgWatcher stops the WireGuard watcher and wait for the watcher to exit // IsEnabled returns true if the WireGuard watcher is currently enabled
func (w *WGWatcher) DisableWgWatcher() { func (w *WGWatcher) IsEnabled() bool {
w.ctxLock.Lock() w.muEnabled.RLock()
defer w.ctxLock.Unlock() defer w.muEnabled.RUnlock()
return w.enabled
if w.ctxCancel == nil {
return
}
w.log.Debugf("disable WireGuard watcher")
w.ctxCancel()
w.ctxCancel = nil
} }
// wgStateCheck help to check the state of the WireGuard handshake and relay connection // wgStateCheck help to check the state of the WireGuard handshake and relay connection
func (w *WGWatcher) periodicHandshakeCheck(ctx context.Context, ctxCancel context.CancelFunc, onDisconnectedFn func(), initialHandshake time.Time) { func (w *WGWatcher) periodicHandshakeCheck(ctx context.Context, onDisconnectedFn func(), enabledTime time.Time, initialHandshake time.Time) {
w.log.Infof("WireGuard watcher started") w.log.Infof("WireGuard watcher started")
timer := time.NewTimer(wgHandshakeOvertime) timer := time.NewTimer(wgHandshakeOvertime)
defer timer.Stop() defer timer.Stop()
defer ctxCancel()
lastHandshake := initialHandshake lastHandshake := initialHandshake
@@ -104,7 +94,7 @@ func (w *WGWatcher) periodicHandshakeCheck(ctx context.Context, ctxCancel contex
return return
} }
if lastHandshake.IsZero() { if lastHandshake.IsZero() {
elapsed := handshake.Sub(w.enabledTime).Seconds() elapsed := calcElapsed(enabledTime, *handshake)
w.log.Infof("first wg handshake detected within: %.2fsec, (%s)", elapsed, handshake) w.log.Infof("first wg handshake detected within: %.2fsec, (%s)", elapsed, handshake)
} }
@@ -134,19 +124,19 @@ func (w *WGWatcher) handshakeCheck(lastHandshake time.Time) (*time.Time, bool) {
// the current know handshake did not change // the current know handshake did not change
if handshake.Equal(lastHandshake) { if handshake.Equal(lastHandshake) {
w.log.Warnf("WireGuard handshake timed out, closing relay connection: %v", handshake) w.log.Warnf("WireGuard handshake timed out: %v", handshake)
return nil, false return nil, false
} }
// in case if the machine is suspended, the handshake time will be in the past // in case if the machine is suspended, the handshake time will be in the past
if handshake.Add(checkPeriod).Before(time.Now()) { if handshake.Add(checkPeriod).Before(time.Now()) {
w.log.Warnf("WireGuard handshake timed out, closing relay connection: %v", handshake) w.log.Warnf("WireGuard handshake timed out: %v", handshake)
return nil, false return nil, false
} }
// error handling for handshake time in the future // error handling for handshake time in the future
if handshake.After(time.Now()) { if handshake.After(time.Now()) {
w.log.Warnf("WireGuard handshake is in the future, closing relay connection: %v", handshake) w.log.Warnf("WireGuard handshake is in the future: %v", handshake)
return nil, false return nil, false
} }
@@ -164,3 +154,13 @@ func (w *WGWatcher) wgState() (time.Time, error) {
} }
return wgState.LastHandshake, nil return wgState.LastHandshake, nil
} }
// calcElapsed calculates elapsed time since watcher was enabled.
// The watcher started after the wg configuration happens, because of this need to normalise the negative value
func calcElapsed(enabledTime, handshake time.Time) float64 {
elapsed := handshake.Sub(enabledTime).Seconds()
if elapsed < 0 {
elapsed = 0
}
return elapsed
}

View File

@@ -2,6 +2,7 @@ package peer
import ( import (
"context" "context"
"sync"
"testing" "testing"
"time" "time"
@@ -48,7 +49,6 @@ func TestWGWatcher_EnableWgWatcher(t *testing.T) {
case <-time.After(10 * time.Second): case <-time.After(10 * time.Second):
t.Errorf("timeout") t.Errorf("timeout")
} }
watcher.DisableWgWatcher()
} }
func TestWGWatcher_ReEnable(t *testing.T) { func TestWGWatcher_ReEnable(t *testing.T) {
@@ -60,14 +60,21 @@ func TestWGWatcher_ReEnable(t *testing.T) {
watcher := NewWGWatcher(mlog, mocWgIface, "", newStateDump("peer", mlog, &Status{})) watcher := NewWGWatcher(mlog, mocWgIface, "", newStateDump("peer", mlog, &Status{}))
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
wg := &sync.WaitGroup{}
wg.Add(1)
go func() {
defer wg.Done()
watcher.EnableWgWatcher(ctx, func() {})
}()
cancel()
wg.Wait()
// Re-enable with a new context
ctx, cancel = context.WithCancel(context.Background())
defer cancel() defer cancel()
onDisconnected := make(chan struct{}, 1) onDisconnected := make(chan struct{}, 1)
go watcher.EnableWgWatcher(ctx, func() {})
time.Sleep(1 * time.Second)
watcher.DisableWgWatcher()
go watcher.EnableWgWatcher(ctx, func() { go watcher.EnableWgWatcher(ctx, func() {
onDisconnected <- struct{}{} onDisconnected <- struct{}{}
}) })
@@ -80,5 +87,4 @@ func TestWGWatcher_ReEnable(t *testing.T) {
case <-time.After(10 * time.Second): case <-time.After(10 * time.Second):
t.Errorf("timeout") t.Errorf("timeout")
} }
watcher.DisableWgWatcher()
} }

View File

@@ -30,11 +30,9 @@ type WorkerRelay struct {
relayLock sync.Mutex relayLock sync.Mutex
relaySupportedOnRemotePeer atomic.Bool relaySupportedOnRemotePeer atomic.Bool
wgWatcher *WGWatcher
} }
func NewWorkerRelay(ctx context.Context, log *log.Entry, ctrl bool, config ConnConfig, conn *Conn, relayManager *relayClient.Manager, stateDump *stateDump) *WorkerRelay { func NewWorkerRelay(ctx context.Context, log *log.Entry, ctrl bool, config ConnConfig, conn *Conn, relayManager *relayClient.Manager) *WorkerRelay {
r := &WorkerRelay{ r := &WorkerRelay{
peerCtx: ctx, peerCtx: ctx,
log: log, log: log,
@@ -42,7 +40,6 @@ func NewWorkerRelay(ctx context.Context, log *log.Entry, ctrl bool, config ConnC
config: config, config: config,
conn: conn, conn: conn,
relayManager: relayManager, relayManager: relayManager,
wgWatcher: NewWGWatcher(log, config.WgConfig.WgInterface, config.Key, stateDump),
} }
return r return r
} }
@@ -93,14 +90,6 @@ func (w *WorkerRelay) OnNewOffer(remoteOfferAnswer *OfferAnswer) {
}) })
} }
func (w *WorkerRelay) EnableWgWatcher(ctx context.Context) {
w.wgWatcher.EnableWgWatcher(ctx, w.onWGDisconnected)
}
func (w *WorkerRelay) DisableWgWatcher() {
w.wgWatcher.DisableWgWatcher()
}
func (w *WorkerRelay) RelayInstanceAddress() (string, error) { func (w *WorkerRelay) RelayInstanceAddress() (string, error) {
return w.relayManager.RelayInstanceAddress() return w.relayManager.RelayInstanceAddress()
} }
@@ -125,14 +114,6 @@ func (w *WorkerRelay) CloseConn() {
} }
} }
func (w *WorkerRelay) onWGDisconnected() {
w.relayLock.Lock()
_ = w.relayedConn.Close()
w.relayLock.Unlock()
w.conn.onRelayDisconnected()
}
func (w *WorkerRelay) isRelaySupported(answer *OfferAnswer) bool { func (w *WorkerRelay) isRelaySupported(answer *OfferAnswer) bool {
if !w.relayManager.HasRelayAddress() { if !w.relayManager.HasRelayAddress() {
return false return false
@@ -148,6 +129,5 @@ func (w *WorkerRelay) preferredRelayServer(myRelayAddress, remoteRelayAddress st
} }
func (w *WorkerRelay) onRelayClientDisconnected() { func (w *WorkerRelay) onRelayClientDisconnected() {
w.wgWatcher.DisableWgWatcher()
go w.conn.onRelayDisconnected() go w.conn.onRelayDisconnected()
} }