mirror of
https://github.com/netbirdio/netbird.git
synced 2026-03-31 06:34:14 -04:00
[client] Extend WG watcher for ICE connection too (#5133)
Extend WG watcher for ICE connection too
This commit is contained in:
@@ -99,7 +99,10 @@ type Conn struct {
|
||||
|
||||
workerICE *WorkerICE
|
||||
workerRelay *WorkerRelay
|
||||
wgWatcherWg sync.WaitGroup
|
||||
|
||||
wgWatcher *WGWatcher
|
||||
wgWatcherWg sync.WaitGroup
|
||||
wgWatcherCancel context.CancelFunc
|
||||
|
||||
// used to store the remote Rosenpass key for Relayed connection in case of connection update from ice
|
||||
rosenpassRemoteKey []byte
|
||||
@@ -127,6 +130,7 @@ func NewConn(config ConnConfig, services ServiceDependencies) (*Conn, error) {
|
||||
|
||||
connLog := log.WithField("peer", config.Key)
|
||||
|
||||
dumpState := newStateDump(config.Key, connLog, services.StatusRecorder)
|
||||
var conn = &Conn{
|
||||
Log: connLog,
|
||||
config: config,
|
||||
@@ -138,8 +142,9 @@ func NewConn(config ConnConfig, services ServiceDependencies) (*Conn, error) {
|
||||
semaphore: services.Semaphore,
|
||||
statusRelay: worker.NewAtomicStatus(),
|
||||
statusICE: worker.NewAtomicStatus(),
|
||||
dumpState: newStateDump(config.Key, connLog, services.StatusRecorder),
|
||||
dumpState: dumpState,
|
||||
endpointUpdater: NewEndpointUpdater(connLog, config.WgConfig, isController(config)),
|
||||
wgWatcher: NewWGWatcher(connLog, config.WgConfig.WgInterface, config.Key, dumpState),
|
||||
}
|
||||
|
||||
return conn, nil
|
||||
@@ -163,7 +168,7 @@ func (conn *Conn) Open(engineCtx context.Context) error {
|
||||
|
||||
conn.ctx, conn.ctxCancel = context.WithCancel(engineCtx)
|
||||
|
||||
conn.workerRelay = NewWorkerRelay(conn.ctx, conn.Log, isController(conn.config), conn.config, conn, conn.relayManager, conn.dumpState)
|
||||
conn.workerRelay = NewWorkerRelay(conn.ctx, conn.Log, isController(conn.config), conn.config, conn, conn.relayManager)
|
||||
|
||||
relayIsSupportedLocally := conn.workerRelay.RelayIsSupportedLocally()
|
||||
workerICE, err := NewWorkerICE(conn.ctx, conn.Log, conn.config, conn, conn.signaler, conn.iFaceDiscover, conn.statusRecorder, relayIsSupportedLocally)
|
||||
@@ -232,7 +237,9 @@ func (conn *Conn) Close(signalToRemote bool) {
|
||||
conn.Log.Infof("close peer connection")
|
||||
conn.ctxCancel()
|
||||
|
||||
conn.workerRelay.DisableWgWatcher()
|
||||
if conn.wgWatcherCancel != nil {
|
||||
conn.wgWatcherCancel()
|
||||
}
|
||||
conn.workerRelay.CloseConn()
|
||||
conn.workerICE.Close()
|
||||
|
||||
@@ -374,9 +381,6 @@ func (conn *Conn) onICEConnectionIsReady(priority conntype.ConnPriority, iceConn
|
||||
ep = directEp
|
||||
}
|
||||
|
||||
conn.workerRelay.DisableWgWatcher()
|
||||
// todo consider to run conn.wgWatcherWg.Wait() here
|
||||
|
||||
if conn.wgProxyRelay != nil {
|
||||
conn.wgProxyRelay.Pause()
|
||||
}
|
||||
@@ -398,6 +402,8 @@ func (conn *Conn) onICEConnectionIsReady(priority conntype.ConnPriority, iceConn
|
||||
conn.wgProxyRelay.RedirectAs(ep)
|
||||
}
|
||||
|
||||
conn.enableWgWatcherIfNeeded()
|
||||
|
||||
conn.currentConnPriority = priority
|
||||
conn.statusICE.SetConnected()
|
||||
conn.updateIceState(iceConnInfo)
|
||||
@@ -431,11 +437,6 @@ func (conn *Conn) onICEStateDisconnected() {
|
||||
conn.Log.Errorf("failed to switch to relay conn: %v", err)
|
||||
}
|
||||
|
||||
conn.wgWatcherWg.Add(1)
|
||||
go func() {
|
||||
defer conn.wgWatcherWg.Done()
|
||||
conn.workerRelay.EnableWgWatcher(conn.ctx)
|
||||
}()
|
||||
conn.wgProxyRelay.Work()
|
||||
conn.currentConnPriority = conntype.Relay
|
||||
} else {
|
||||
@@ -452,15 +453,15 @@ func (conn *Conn) onICEStateDisconnected() {
|
||||
}
|
||||
conn.statusICE.SetDisconnected()
|
||||
|
||||
conn.disableWgWatcherIfNeeded()
|
||||
|
||||
peerState := State{
|
||||
PubKey: conn.config.Key,
|
||||
ConnStatus: conn.evalStatus(),
|
||||
Relayed: conn.isRelayed(),
|
||||
ConnStatusUpdate: time.Now(),
|
||||
}
|
||||
|
||||
err := conn.statusRecorder.UpdatePeerICEStateToDisconnected(peerState)
|
||||
if err != nil {
|
||||
if err := conn.statusRecorder.UpdatePeerICEStateToDisconnected(peerState); err != nil {
|
||||
conn.Log.Warnf("unable to set peer's state to disconnected ice, got error: %v", err)
|
||||
}
|
||||
}
|
||||
@@ -508,11 +509,7 @@ func (conn *Conn) onRelayConnectionIsReady(rci RelayConnInfo) {
|
||||
return
|
||||
}
|
||||
|
||||
conn.wgWatcherWg.Add(1)
|
||||
go func() {
|
||||
defer conn.wgWatcherWg.Done()
|
||||
conn.workerRelay.EnableWgWatcher(conn.ctx)
|
||||
}()
|
||||
conn.enableWgWatcherIfNeeded()
|
||||
|
||||
wgConfigWorkaround()
|
||||
conn.rosenpassRemoteKey = rci.rosenpassPubKey
|
||||
@@ -527,7 +524,11 @@ func (conn *Conn) onRelayConnectionIsReady(rci RelayConnInfo) {
|
||||
func (conn *Conn) onRelayDisconnected() {
|
||||
conn.mu.Lock()
|
||||
defer conn.mu.Unlock()
|
||||
conn.handleRelayDisconnectedLocked()
|
||||
}
|
||||
|
||||
// handleRelayDisconnectedLocked handles relay disconnection. Caller must hold conn.mu.
|
||||
func (conn *Conn) handleRelayDisconnectedLocked() {
|
||||
if conn.ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
@@ -553,6 +554,8 @@ func (conn *Conn) onRelayDisconnected() {
|
||||
}
|
||||
conn.statusRelay.SetDisconnected()
|
||||
|
||||
conn.disableWgWatcherIfNeeded()
|
||||
|
||||
peerState := State{
|
||||
PubKey: conn.config.Key,
|
||||
ConnStatus: conn.evalStatus(),
|
||||
@@ -571,6 +574,28 @@ func (conn *Conn) onGuardEvent() {
|
||||
}
|
||||
}
|
||||
|
||||
func (conn *Conn) onWGDisconnected() {
|
||||
conn.mu.Lock()
|
||||
defer conn.mu.Unlock()
|
||||
|
||||
if conn.ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
|
||||
conn.Log.Warnf("WireGuard handshake timeout detected, closing current connection")
|
||||
|
||||
// Close the active connection based on current priority
|
||||
switch conn.currentConnPriority {
|
||||
case conntype.Relay:
|
||||
conn.workerRelay.CloseConn()
|
||||
conn.handleRelayDisconnectedLocked()
|
||||
case conntype.ICEP2P, conntype.ICETurn:
|
||||
conn.workerICE.Close()
|
||||
default:
|
||||
conn.Log.Debugf("No active connection to close on WG timeout")
|
||||
}
|
||||
}
|
||||
|
||||
func (conn *Conn) updateRelayStatus(relayServerAddr string, rosenpassPubKey []byte) {
|
||||
peerState := State{
|
||||
PubKey: conn.config.Key,
|
||||
@@ -697,6 +722,25 @@ func (conn *Conn) isConnectedOnAllWay() (connected bool) {
|
||||
return true
|
||||
}
|
||||
|
||||
func (conn *Conn) enableWgWatcherIfNeeded() {
|
||||
if !conn.wgWatcher.IsEnabled() {
|
||||
wgWatcherCtx, wgWatcherCancel := context.WithCancel(conn.ctx)
|
||||
conn.wgWatcherCancel = wgWatcherCancel
|
||||
conn.wgWatcherWg.Add(1)
|
||||
go func() {
|
||||
defer conn.wgWatcherWg.Done()
|
||||
conn.wgWatcher.EnableWgWatcher(wgWatcherCtx, conn.onWGDisconnected)
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
||||
func (conn *Conn) disableWgWatcherIfNeeded() {
|
||||
if conn.currentConnPriority == conntype.None && conn.wgWatcherCancel != nil {
|
||||
conn.wgWatcherCancel()
|
||||
conn.wgWatcherCancel = nil
|
||||
}
|
||||
}
|
||||
|
||||
func (conn *Conn) newProxy(remoteConn net.Conn) (wgproxy.Proxy, error) {
|
||||
conn.Log.Debugf("setup proxied WireGuard connection")
|
||||
udpAddr := &net.UDPAddr{
|
||||
|
||||
@@ -30,10 +30,8 @@ type WGWatcher struct {
|
||||
peerKey string
|
||||
stateDump *stateDump
|
||||
|
||||
ctx context.Context
|
||||
ctxCancel context.CancelFunc
|
||||
ctxLock sync.Mutex
|
||||
enabledTime time.Time
|
||||
enabled bool
|
||||
muEnabled sync.RWMutex
|
||||
}
|
||||
|
||||
func NewWGWatcher(log *log.Entry, wgIfaceStater WGInterfaceStater, peerKey string, stateDump *stateDump) *WGWatcher {
|
||||
@@ -46,52 +44,44 @@ func NewWGWatcher(log *log.Entry, wgIfaceStater WGInterfaceStater, peerKey strin
|
||||
}
|
||||
|
||||
// EnableWgWatcher starts the WireGuard watcher. If it is already enabled, it will return immediately and do nothing.
|
||||
func (w *WGWatcher) EnableWgWatcher(parentCtx context.Context, onDisconnectedFn func()) {
|
||||
w.log.Debugf("enable WireGuard watcher")
|
||||
w.ctxLock.Lock()
|
||||
w.enabledTime = time.Now()
|
||||
|
||||
if w.ctx != nil && w.ctx.Err() == nil {
|
||||
w.log.Errorf("WireGuard watcher already enabled")
|
||||
w.ctxLock.Unlock()
|
||||
// The watcher runs until ctx is cancelled. Caller is responsible for context lifecycle management.
|
||||
func (w *WGWatcher) EnableWgWatcher(ctx context.Context, onDisconnectedFn func()) {
|
||||
w.muEnabled.Lock()
|
||||
if w.enabled {
|
||||
w.muEnabled.Unlock()
|
||||
return
|
||||
}
|
||||
|
||||
ctx, ctxCancel := context.WithCancel(parentCtx)
|
||||
w.ctx = ctx
|
||||
w.ctxCancel = ctxCancel
|
||||
w.ctxLock.Unlock()
|
||||
w.log.Debugf("enable WireGuard watcher")
|
||||
enabledTime := time.Now()
|
||||
w.enabled = true
|
||||
w.muEnabled.Unlock()
|
||||
|
||||
initialHandshake, err := w.wgState()
|
||||
if err != nil {
|
||||
w.log.Warnf("failed to read initial wg stats: %v", err)
|
||||
}
|
||||
|
||||
w.periodicHandshakeCheck(ctx, ctxCancel, onDisconnectedFn, initialHandshake)
|
||||
w.periodicHandshakeCheck(ctx, onDisconnectedFn, enabledTime, initialHandshake)
|
||||
|
||||
w.muEnabled.Lock()
|
||||
w.enabled = false
|
||||
w.muEnabled.Unlock()
|
||||
}
|
||||
|
||||
// DisableWgWatcher stops the WireGuard watcher and wait for the watcher to exit
|
||||
func (w *WGWatcher) DisableWgWatcher() {
|
||||
w.ctxLock.Lock()
|
||||
defer w.ctxLock.Unlock()
|
||||
|
||||
if w.ctxCancel == nil {
|
||||
return
|
||||
}
|
||||
|
||||
w.log.Debugf("disable WireGuard watcher")
|
||||
|
||||
w.ctxCancel()
|
||||
w.ctxCancel = nil
|
||||
// IsEnabled returns true if the WireGuard watcher is currently enabled
|
||||
func (w *WGWatcher) IsEnabled() bool {
|
||||
w.muEnabled.RLock()
|
||||
defer w.muEnabled.RUnlock()
|
||||
return w.enabled
|
||||
}
|
||||
|
||||
// periodicHandshakeCheck helps to check the state of the WireGuard handshake and relay connection
|
||||
func (w *WGWatcher) periodicHandshakeCheck(ctx context.Context, ctxCancel context.CancelFunc, onDisconnectedFn func(), initialHandshake time.Time) {
|
||||
func (w *WGWatcher) periodicHandshakeCheck(ctx context.Context, onDisconnectedFn func(), enabledTime time.Time, initialHandshake time.Time) {
|
||||
w.log.Infof("WireGuard watcher started")
|
||||
|
||||
timer := time.NewTimer(wgHandshakeOvertime)
|
||||
defer timer.Stop()
|
||||
defer ctxCancel()
|
||||
|
||||
lastHandshake := initialHandshake
|
||||
|
||||
@@ -104,7 +94,7 @@ func (w *WGWatcher) periodicHandshakeCheck(ctx context.Context, ctxCancel contex
|
||||
return
|
||||
}
|
||||
if lastHandshake.IsZero() {
|
||||
elapsed := handshake.Sub(w.enabledTime).Seconds()
|
||||
elapsed := calcElapsed(enabledTime, *handshake)
|
||||
w.log.Infof("first wg handshake detected within: %.2fsec, (%s)", elapsed, handshake)
|
||||
}
|
||||
|
||||
@@ -134,19 +124,19 @@ func (w *WGWatcher) handshakeCheck(lastHandshake time.Time) (*time.Time, bool) {
|
||||
|
||||
// the currently known handshake did not change
|
||||
if handshake.Equal(lastHandshake) {
|
||||
w.log.Warnf("WireGuard handshake timed out, closing relay connection: %v", handshake)
|
||||
w.log.Warnf("WireGuard handshake timed out: %v", handshake)
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// in case if the machine is suspended, the handshake time will be in the past
|
||||
if handshake.Add(checkPeriod).Before(time.Now()) {
|
||||
w.log.Warnf("WireGuard handshake timed out, closing relay connection: %v", handshake)
|
||||
w.log.Warnf("WireGuard handshake timed out: %v", handshake)
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// error handling for handshake time in the future
|
||||
if handshake.After(time.Now()) {
|
||||
w.log.Warnf("WireGuard handshake is in the future, closing relay connection: %v", handshake)
|
||||
w.log.Warnf("WireGuard handshake is in the future: %v", handshake)
|
||||
return nil, false
|
||||
}
|
||||
|
||||
@@ -164,3 +154,13 @@ func (w *WGWatcher) wgState() (time.Time, error) {
|
||||
}
|
||||
return wgState.LastHandshake, nil
|
||||
}
|
||||
|
||||
// calcElapsed returns the time, in seconds, between enabledTime (the moment the
// watcher was enabled) and handshake.
//
// The watcher is started after the WireGuard configuration has been applied, so
// the handshake can predate enabledTime. In that case the difference would be
// negative and is normalised to 0.
func calcElapsed(enabledTime, handshake time.Time) float64 {
	elapsed := handshake.Sub(enabledTime).Seconds()
	if elapsed < 0 {
		// The handshake happened before the watcher was enabled.
		return 0
	}
	return elapsed
}
|
||||
|
||||
@@ -2,6 +2,7 @@ package peer
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -48,7 +49,6 @@ func TestWGWatcher_EnableWgWatcher(t *testing.T) {
|
||||
case <-time.After(10 * time.Second):
|
||||
t.Errorf("timeout")
|
||||
}
|
||||
watcher.DisableWgWatcher()
|
||||
}
|
||||
|
||||
func TestWGWatcher_ReEnable(t *testing.T) {
|
||||
@@ -60,14 +60,21 @@ func TestWGWatcher_ReEnable(t *testing.T) {
|
||||
watcher := NewWGWatcher(mlog, mocWgIface, "", newStateDump("peer", mlog, &Status{}))
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
wg := &sync.WaitGroup{}
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
watcher.EnableWgWatcher(ctx, func() {})
|
||||
}()
|
||||
cancel()
|
||||
|
||||
wg.Wait()
|
||||
|
||||
// Re-enable with a new context
|
||||
ctx, cancel = context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
onDisconnected := make(chan struct{}, 1)
|
||||
|
||||
go watcher.EnableWgWatcher(ctx, func() {})
|
||||
time.Sleep(1 * time.Second)
|
||||
watcher.DisableWgWatcher()
|
||||
|
||||
go watcher.EnableWgWatcher(ctx, func() {
|
||||
onDisconnected <- struct{}{}
|
||||
})
|
||||
@@ -80,5 +87,4 @@ func TestWGWatcher_ReEnable(t *testing.T) {
|
||||
case <-time.After(10 * time.Second):
|
||||
t.Errorf("timeout")
|
||||
}
|
||||
watcher.DisableWgWatcher()
|
||||
}
|
||||
|
||||
@@ -30,11 +30,9 @@ type WorkerRelay struct {
|
||||
relayLock sync.Mutex
|
||||
|
||||
relaySupportedOnRemotePeer atomic.Bool
|
||||
|
||||
wgWatcher *WGWatcher
|
||||
}
|
||||
|
||||
func NewWorkerRelay(ctx context.Context, log *log.Entry, ctrl bool, config ConnConfig, conn *Conn, relayManager *relayClient.Manager, stateDump *stateDump) *WorkerRelay {
|
||||
func NewWorkerRelay(ctx context.Context, log *log.Entry, ctrl bool, config ConnConfig, conn *Conn, relayManager *relayClient.Manager) *WorkerRelay {
|
||||
r := &WorkerRelay{
|
||||
peerCtx: ctx,
|
||||
log: log,
|
||||
@@ -42,7 +40,6 @@ func NewWorkerRelay(ctx context.Context, log *log.Entry, ctrl bool, config ConnC
|
||||
config: config,
|
||||
conn: conn,
|
||||
relayManager: relayManager,
|
||||
wgWatcher: NewWGWatcher(log, config.WgConfig.WgInterface, config.Key, stateDump),
|
||||
}
|
||||
return r
|
||||
}
|
||||
@@ -93,14 +90,6 @@ func (w *WorkerRelay) OnNewOffer(remoteOfferAnswer *OfferAnswer) {
|
||||
})
|
||||
}
|
||||
|
||||
func (w *WorkerRelay) EnableWgWatcher(ctx context.Context) {
|
||||
w.wgWatcher.EnableWgWatcher(ctx, w.onWGDisconnected)
|
||||
}
|
||||
|
||||
func (w *WorkerRelay) DisableWgWatcher() {
|
||||
w.wgWatcher.DisableWgWatcher()
|
||||
}
|
||||
|
||||
func (w *WorkerRelay) RelayInstanceAddress() (string, error) {
|
||||
return w.relayManager.RelayInstanceAddress()
|
||||
}
|
||||
@@ -125,14 +114,6 @@ func (w *WorkerRelay) CloseConn() {
|
||||
}
|
||||
}
|
||||
|
||||
func (w *WorkerRelay) onWGDisconnected() {
|
||||
w.relayLock.Lock()
|
||||
_ = w.relayedConn.Close()
|
||||
w.relayLock.Unlock()
|
||||
|
||||
w.conn.onRelayDisconnected()
|
||||
}
|
||||
|
||||
func (w *WorkerRelay) isRelaySupported(answer *OfferAnswer) bool {
|
||||
if !w.relayManager.HasRelayAddress() {
|
||||
return false
|
||||
@@ -148,6 +129,5 @@ func (w *WorkerRelay) preferredRelayServer(myRelayAddress, remoteRelayAddress st
|
||||
}
|
||||
|
||||
func (w *WorkerRelay) onRelayClientDisconnected() {
|
||||
w.wgWatcher.DisableWgWatcher()
|
||||
go w.conn.onRelayDisconnected()
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user