mirror of
https://github.com/netbirdio/netbird.git
synced 2026-03-31 06:34:19 -04:00
Reset WireGuard endpoint on ICE session change during relay fallback (#5283)
When an ICE connection disconnects and falls back to relay, reset the WireGuard endpoint and handshake watcher if the remote peer's ICE session has changed. This ensures the controller re-establishes a fresh WireGuard handshake rather than waiting on a stale endpoint from the previous session.
This commit is contained in:
@@ -410,7 +410,7 @@ func (conn *Conn) onICEConnectionIsReady(priority conntype.ConnPriority, iceConn
|
|||||||
conn.doOnConnected(iceConnInfo.RosenpassPubKey, iceConnInfo.RosenpassAddr)
|
conn.doOnConnected(iceConnInfo.RosenpassPubKey, iceConnInfo.RosenpassAddr)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (conn *Conn) onICEStateDisconnected() {
|
func (conn *Conn) onICEStateDisconnected(sessionChanged bool) {
|
||||||
conn.mu.Lock()
|
conn.mu.Lock()
|
||||||
defer conn.mu.Unlock()
|
defer conn.mu.Unlock()
|
||||||
|
|
||||||
@@ -430,6 +430,10 @@ func (conn *Conn) onICEStateDisconnected() {
|
|||||||
if conn.isReadyToUpgrade() {
|
if conn.isReadyToUpgrade() {
|
||||||
conn.Log.Infof("ICE disconnected, set Relay to active connection")
|
conn.Log.Infof("ICE disconnected, set Relay to active connection")
|
||||||
conn.dumpState.SwitchToRelay()
|
conn.dumpState.SwitchToRelay()
|
||||||
|
if sessionChanged {
|
||||||
|
conn.resetEndpoint()
|
||||||
|
}
|
||||||
|
|
||||||
conn.wgProxyRelay.Work()
|
conn.wgProxyRelay.Work()
|
||||||
|
|
||||||
presharedKey := conn.presharedKey(conn.rosenpassRemoteKey)
|
presharedKey := conn.presharedKey(conn.rosenpassRemoteKey)
|
||||||
@@ -757,6 +761,17 @@ func (conn *Conn) newProxy(remoteConn net.Conn) (wgproxy.Proxy, error) {
|
|||||||
return wgProxy, nil
|
return wgProxy, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (conn *Conn) resetEndpoint() {
|
||||||
|
if !isController(conn.config) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
conn.Log.Infof("reset wg endpoint")
|
||||||
|
conn.wgWatcher.Reset()
|
||||||
|
if err := conn.endpointUpdater.RemoveEndpointAddress(); err != nil {
|
||||||
|
conn.Log.Warnf("failed to remove endpoint address before update: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (conn *Conn) isReadyToUpgrade() bool {
|
func (conn *Conn) isReadyToUpgrade() bool {
|
||||||
return conn.wgProxyRelay != nil && conn.currentConnPriority != conntype.Relay
|
return conn.wgProxyRelay != nil && conn.currentConnPriority != conntype.Relay
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -66,6 +66,10 @@ func (e *EndpointUpdater) RemoveWgPeer() error {
|
|||||||
return e.wgConfig.WgInterface.RemovePeer(e.wgConfig.RemoteKey)
|
return e.wgConfig.WgInterface.RemovePeer(e.wgConfig.RemoteKey)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (e *EndpointUpdater) RemoveEndpointAddress() error {
|
||||||
|
return e.wgConfig.WgInterface.RemoveEndpointAddress(e.wgConfig.RemoteKey)
|
||||||
|
}
|
||||||
|
|
||||||
func (e *EndpointUpdater) waitForCloseTheDelayedUpdate() {
|
func (e *EndpointUpdater) waitForCloseTheDelayedUpdate() {
|
||||||
if e.cancelFunc == nil {
|
if e.cancelFunc == nil {
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -32,6 +32,8 @@ type WGWatcher struct {
|
|||||||
|
|
||||||
enabled bool
|
enabled bool
|
||||||
muEnabled sync.RWMutex
|
muEnabled sync.RWMutex
|
||||||
|
|
||||||
|
resetCh chan struct{}
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewWGWatcher(log *log.Entry, wgIfaceStater WGInterfaceStater, peerKey string, stateDump *stateDump) *WGWatcher {
|
func NewWGWatcher(log *log.Entry, wgIfaceStater WGInterfaceStater, peerKey string, stateDump *stateDump) *WGWatcher {
|
||||||
@@ -40,6 +42,7 @@ func NewWGWatcher(log *log.Entry, wgIfaceStater WGInterfaceStater, peerKey strin
|
|||||||
wgIfaceStater: wgIfaceStater,
|
wgIfaceStater: wgIfaceStater,
|
||||||
peerKey: peerKey,
|
peerKey: peerKey,
|
||||||
stateDump: stateDump,
|
stateDump: stateDump,
|
||||||
|
resetCh: make(chan struct{}, 1),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -76,6 +79,15 @@ func (w *WGWatcher) IsEnabled() bool {
|
|||||||
return w.enabled
|
return w.enabled
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Reset signals the watcher that the WireGuard peer has been reset and a new
|
||||||
|
// handshake is expected. This restarts the handshake timeout from scratch.
|
||||||
|
func (w *WGWatcher) Reset() {
|
||||||
|
select {
|
||||||
|
case w.resetCh <- struct{}{}:
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// wgStateCheck help to check the state of the WireGuard handshake and relay connection
|
// wgStateCheck help to check the state of the WireGuard handshake and relay connection
|
||||||
func (w *WGWatcher) periodicHandshakeCheck(ctx context.Context, onDisconnectedFn func(), enabledTime time.Time, initialHandshake time.Time) {
|
func (w *WGWatcher) periodicHandshakeCheck(ctx context.Context, onDisconnectedFn func(), enabledTime time.Time, initialHandshake time.Time) {
|
||||||
w.log.Infof("WireGuard watcher started")
|
w.log.Infof("WireGuard watcher started")
|
||||||
@@ -105,6 +117,12 @@ func (w *WGWatcher) periodicHandshakeCheck(ctx context.Context, onDisconnectedFn
|
|||||||
w.stateDump.WGcheckSuccess()
|
w.stateDump.WGcheckSuccess()
|
||||||
|
|
||||||
w.log.Debugf("WireGuard watcher reset timer: %v", resetTime)
|
w.log.Debugf("WireGuard watcher reset timer: %v", resetTime)
|
||||||
|
case <-w.resetCh:
|
||||||
|
w.log.Infof("WireGuard watcher received peer reset, restarting handshake timeout")
|
||||||
|
lastHandshake = time.Time{}
|
||||||
|
enabledTime = time.Now()
|
||||||
|
timer.Stop()
|
||||||
|
timer.Reset(wgHandshakeOvertime)
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
w.log.Infof("WireGuard watcher stopped")
|
w.log.Infof("WireGuard watcher stopped")
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -52,8 +52,9 @@ type WorkerICE struct {
|
|||||||
// increase by one when disconnecting the agent
|
// increase by one when disconnecting the agent
|
||||||
// with it the remote peer can discard the already deprecated offer/answer
|
// with it the remote peer can discard the already deprecated offer/answer
|
||||||
// Without it the remote peer may recreate a workable ICE connection
|
// Without it the remote peer may recreate a workable ICE connection
|
||||||
sessionID ICESessionID
|
sessionID ICESessionID
|
||||||
muxAgent sync.Mutex
|
remoteSessionChanged bool
|
||||||
|
muxAgent sync.Mutex
|
||||||
|
|
||||||
localUfrag string
|
localUfrag string
|
||||||
localPwd string
|
localPwd string
|
||||||
@@ -106,6 +107,7 @@ func (w *WorkerICE) OnNewOffer(remoteOfferAnswer *OfferAnswer) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
w.log.Debugf("agent already exists, recreate the connection")
|
w.log.Debugf("agent already exists, recreate the connection")
|
||||||
|
w.remoteSessionChanged = true
|
||||||
w.agentDialerCancel()
|
w.agentDialerCancel()
|
||||||
if w.agent != nil {
|
if w.agent != nil {
|
||||||
if err := w.agent.Close(); err != nil {
|
if err := w.agent.Close(); err != nil {
|
||||||
@@ -306,13 +308,17 @@ func (w *WorkerICE) connect(ctx context.Context, agent *icemaker.ThreadSafeAgent
|
|||||||
w.conn.onICEConnectionIsReady(selectedPriority(pair), ci)
|
w.conn.onICEConnectionIsReady(selectedPriority(pair), ci)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w *WorkerICE) closeAgent(agent *icemaker.ThreadSafeAgent, cancel context.CancelFunc) {
|
func (w *WorkerICE) closeAgent(agent *icemaker.ThreadSafeAgent, cancel context.CancelFunc) bool {
|
||||||
cancel()
|
cancel()
|
||||||
if err := agent.Close(); err != nil {
|
if err := agent.Close(); err != nil {
|
||||||
w.log.Warnf("failed to close ICE agent: %s", err)
|
w.log.Warnf("failed to close ICE agent: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
w.muxAgent.Lock()
|
w.muxAgent.Lock()
|
||||||
|
defer w.muxAgent.Unlock()
|
||||||
|
|
||||||
|
sessionChanged := w.remoteSessionChanged
|
||||||
|
w.remoteSessionChanged = false
|
||||||
|
|
||||||
if w.agent == agent {
|
if w.agent == agent {
|
||||||
// consider to remove from here and move to the OnNewOffer
|
// consider to remove from here and move to the OnNewOffer
|
||||||
@@ -325,7 +331,7 @@ func (w *WorkerICE) closeAgent(agent *icemaker.ThreadSafeAgent, cancel context.C
|
|||||||
w.agentConnecting = false
|
w.agentConnecting = false
|
||||||
w.remoteSessionID = ""
|
w.remoteSessionID = ""
|
||||||
}
|
}
|
||||||
w.muxAgent.Unlock()
|
return sessionChanged
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w *WorkerICE) punchRemoteWGPort(pair *ice.CandidatePair, remoteWgPort int) {
|
func (w *WorkerICE) punchRemoteWGPort(pair *ice.CandidatePair, remoteWgPort int) {
|
||||||
@@ -426,11 +432,11 @@ func (w *WorkerICE) onConnectionStateChange(agent *icemaker.ThreadSafeAgent, dia
|
|||||||
// ice.ConnectionStateClosed happens when we recreate the agent. For the P2P to TURN switch important to
|
// ice.ConnectionStateClosed happens when we recreate the agent. For the P2P to TURN switch important to
|
||||||
// notify the conn.onICEStateDisconnected changes to update the current used priority
|
// notify the conn.onICEStateDisconnected changes to update the current used priority
|
||||||
|
|
||||||
w.closeAgent(agent, dialerCancel)
|
sessionChanged := w.closeAgent(agent, dialerCancel)
|
||||||
|
|
||||||
if w.lastKnownState == ice.ConnectionStateConnected {
|
if w.lastKnownState == ice.ConnectionStateConnected {
|
||||||
w.lastKnownState = ice.ConnectionStateDisconnected
|
w.lastKnownState = ice.ConnectionStateDisconnected
|
||||||
w.conn.onICEStateDisconnected()
|
w.conn.onICEStateDisconnected(sessionChanged)
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
return
|
return
|
||||||
|
|||||||
Reference in New Issue
Block a user