Compare commits

...

1 Commits

16 changed files with 184 additions and 34 deletions

View File

@@ -136,6 +136,15 @@ func (p *ProxyBind) CloseConn() error {
return p.close()
}
// InjectPacket writes b to the remote peer over the underlying transport.
func (p *ProxyBind) InjectPacket(b []byte) error {
if p.remoteConn == nil {
return fmt.Errorf("proxy not started")
}
_, err := p.remoteConn.Write(b)
return err
}
func (p *ProxyBind) close() error {
if p.remoteConn == nil {
return nil

View File

@@ -219,6 +219,15 @@ func (p *ProxyWrapper) RedirectAs(endpoint *net.UDPAddr) {
p.pausedCond.L.Unlock()
}
// InjectPacket writes b to the remote peer over the underlying transport.
func (p *ProxyWrapper) InjectPacket(b []byte) error {
if p.remoteConn == nil {
return fmt.Errorf("proxy not started")
}
_, err := p.remoteConn.Write(b)
return err
}
// CloseConn close the remoteConn and automatically remove the conn instance from the map
func (p *ProxyWrapper) CloseConn() error {
if p.cancel == nil {

View File

@@ -18,4 +18,8 @@ type Proxy interface {
RedirectAs(endpoint *net.UDPAddr)
CloseConn() error
SetDisconnectListener(disconnected func())
// InjectPacket writes a raw packet directly to the remote peer over the underlying transport,
// bypassing WireGuard. Used to replay the captured lazyconn handshake initiation.
InjectPacket(b []byte) error
}

View File

@@ -147,6 +147,15 @@ func (p *WGUDPProxy) RedirectAs(endpoint *net.UDPAddr) {
p.sendPkg = p.srcFakerConn.SendPkg
}
// InjectPacket writes b to the remote peer over the underlying transport.
func (p *WGUDPProxy) InjectPacket(b []byte) error {
if p.remoteConn == nil {
return fmt.Errorf("proxy not started")
}
_, err := p.remoteConn.Write(b)
return err
}
// CloseConn close the localConn
func (p *WGUDPProxy) CloseConn() error {
if p.cancel == nil {

View File

@@ -224,6 +224,10 @@ func (m *MockWGIface) LastActivities() map[string]monotime.Time {
return nil
}
func (m *MockWGIface) MTU() uint16 {
return 1280
}
func (m *MockWGIface) SetPresharedKey(peerKey string, psk wgtypes.Key, updateOnly bool) error {
return nil
}

View File

@@ -44,4 +44,5 @@ type wgIfaceBase interface {
FullStats() (*configurer.Stats, error)
LastActivities() map[string]monotime.Time
SetPresharedKey(peerKey string, psk wgtypes.Key, updateOnly bool) error
MTU() uint16
}

View File

@@ -4,15 +4,21 @@ import (
"context"
"io"
"net"
"sync"
"time"
)
// lazyConn detects activity when WireGuard attempts to send packets.
// It does not deliver packets, only signals that activity occurred.
// It does not deliver packets, only signals that activity occurred. The first packet WG-go writes
// is captured so the lazyconn manager can reinject it through the real bind endpoint, avoiding the
// handshake-retry wait when the real connection comes up.
type lazyConn struct {
activityCh chan struct{}
ctx context.Context
cancel context.CancelFunc
mu sync.Mutex
capturedPacket []byte
}
// newLazyConn creates a new lazyConn for activity detection.
@@ -31,12 +37,19 @@ func (c *lazyConn) Read(_ []byte) (n int, err error) {
return 0, io.EOF
}
// Write signals activity detection when ICEBind routes packets to this endpoint.
// Write signals activity detection when ICEBind routes packets to this endpoint. The first packet
// is cloned and stored so it can be replayed on the real endpoint once ICE/relay comes up.
func (c *lazyConn) Write(b []byte) (n int, err error) {
if c.ctx.Err() != nil {
return 0, io.EOF
}
c.mu.Lock()
if c.capturedPacket == nil && len(b) > 0 {
c.capturedPacket = append([]byte(nil), b...)
}
c.mu.Unlock()
select {
case c.activityCh <- struct{}{}:
default:
@@ -45,6 +58,13 @@ func (c *lazyConn) Write(b []byte) (n int, err error) {
return len(b), nil
}
// CapturedPacket returns a copy of the first packet written to this connection, or nil.
func (c *lazyConn) CapturedPacket() []byte {
c.mu.Lock()
defer c.mu.Unlock()
return c.capturedPacket
}
// ActivityChan returns the channel that signals when activity is detected.
func (c *lazyConn) ActivityChan() <-chan struct{} {
return c.activityCh

View File

@@ -118,16 +118,23 @@ func (d *BindListener) ReadPackets() {
d.peerCfg.Log.Infof("exit from activity listener")
}
d.peerCfg.Log.Debugf("removing lazy endpoint for peer %s", d.peerCfg.PublicKey)
if err := d.wgIface.RemovePeer(d.peerCfg.PublicKey); err != nil {
d.peerCfg.Log.Errorf("failed to remove endpoint: %s", err)
}
// Leave the peer in place. ConfigureWGEndpoint will UpdatePeer with the real endpoint;
// removing the peer here wipes wireguard-go's staged queue and drops the user packet that
// triggered activation.
_ = d.lazyConn.Close()
d.bind.RemoveEndpoint(d.fakeIP)
d.done.Done()
}
// CapturedPacket returns the first packet that triggered activity, or nil if none was captured.
// Safe to call after ReadPackets returns.
func (d *BindListener) CapturedPacket() []byte {
if d.lazyConn == nil {
return nil
}
return d.lazyConn.CapturedPacket()
}
// Close stops the listener and cleans up resources.
func (d *BindListener) Close() {
d.peerCfg.Log.Infof("closing activity listener (LazyConn)")

View File

@@ -68,6 +68,10 @@ func (m *MockWGIfaceBind) GetBind() device.EndpointManager {
return m.endpointMgr
}
func (m *MockWGIfaceBind) MTU() uint16 {
return 1280
}
func TestBindListener_Creation(t *testing.T) {
mockEndpointMgr := newMockEndpointManager()
mockIface := &MockWGIfaceBind{endpointMgr: mockEndpointMgr}
@@ -207,8 +211,9 @@ func TestManager_BindMode(t *testing.T) {
require.NoError(t, err)
select {
case peerConnID := <-mgr.OnActivityChan:
assert.Equal(t, cfg.PeerConnID, peerConnID, "Received peer connection ID should match")
case ev := <-mgr.OnActivityChan:
assert.Equal(t, cfg.PeerConnID, ev.PeerConnID, "Received peer connection ID should match")
assert.Equal(t, []byte{0x01, 0x02, 0x03}, ev.FirstPacket, "First packet should be captured for reinjection")
case <-time.After(2 * time.Second):
t.Fatal("timeout waiting for activity notification")
}
@@ -266,8 +271,8 @@ func TestManager_BindMode_MultiplePeers(t *testing.T) {
receivedPeers := make(map[peerid.ConnID]bool)
for i := 0; i < 2; i++ {
select {
case peerConnID := <-mgr.OnActivityChan:
receivedPeers[peerConnID] = true
case ev := <-mgr.OnActivityChan:
receivedPeers[ev.PeerConnID] = true
case <-time.After(2 * time.Second):
t.Fatal("timeout waiting for activity notifications")
}

View File

@@ -8,6 +8,7 @@ import (
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/iface/bufsize"
"github.com/netbirdio/netbird/client/internal/lazyconn"
)
@@ -20,6 +21,8 @@ type UDPListener struct {
done sync.Mutex
isClosed atomic.Bool
capturedPacket []byte
}
// NewUDPListener creates a listener that detects activity via UDP socket reads.
@@ -46,9 +49,13 @@ func NewUDPListener(wgIface WgInterface, cfg lazyconn.PeerConfig) (*UDPListener,
}
// ReadPackets blocks reading from the UDP socket until activity is detected or the listener is closed.
// The first packet that triggers activity is captured so it can be reinjected through the real
// transport once it is established. Without this, kernel WireGuard's handshake initiation would be
// dropped and WG would only retry after REKEY_TIMEOUT.
func (d *UDPListener) ReadPackets() {
buf := make([]byte, int(d.wgIface.MTU())+bufsize.WGBufferOverhead)
for {
n, remoteAddr, err := d.conn.ReadFromUDP(make([]byte, 1))
n, remoteAddr, err := d.conn.ReadFromUDP(buf)
if err != nil {
if d.isClosed.Load() {
d.peerCfg.Log.Infof("exit from activity listener")
@@ -62,20 +69,24 @@ func (d *UDPListener) ReadPackets() {
d.peerCfg.Log.Warnf("received %d bytes from %s, too short", n, remoteAddr)
continue
}
d.peerCfg.Log.Infof("activity detected")
d.capturedPacket = append([]byte(nil), buf[:n]...)
d.peerCfg.Log.Infof("activity detected, captured %d bytes for reinjection", n)
break
}
d.peerCfg.Log.Debugf("removing lazy endpoint: %s", d.endpoint.String())
if err := d.wgIface.RemovePeer(d.peerCfg.PublicKey); err != nil {
d.peerCfg.Log.Errorf("failed to remove endpoint: %s", err)
}
// Ignore close error as it may return "use of closed network connection" if already closed.
// Leave the peer in place. ConfigureWGEndpoint will UpdatePeer with the real endpoint;
// removing the peer here wipes kernel WG's staged queue and drops the user packet that
// triggered activation.
_ = d.conn.Close()
d.done.Unlock()
}
// CapturedPacket returns the first packet that triggered activity, or nil if none was captured.
// Safe to call after ReadPackets returns.
func (d *UDPListener) CapturedPacket() []byte {
return d.capturedPacket
}
// Close stops the listener and cleans up resources.
func (d *UDPListener) Close() {
d.peerCfg.Log.Infof("closing activity listener: %s", d.conn.LocalAddr().String())

View File

@@ -19,6 +19,14 @@ import (
type listener interface {
ReadPackets()
Close()
CapturedPacket() []byte
}
// Event reports activity on a managed peer. FirstPacket is the bytes that triggered activation,
// captured for reinjection through the real transport.
type Event struct {
PeerConnID peerid.ConnID
FirstPacket []byte
}
type WgInterface interface {
@@ -26,10 +34,11 @@ type WgInterface interface {
UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error
IsUserspaceBind() bool
Address() wgaddr.Address
MTU() uint16
}
type Manager struct {
OnActivityChan chan peerid.ConnID
OnActivityChan chan Event
wgIface WgInterface
@@ -41,7 +50,7 @@ type Manager struct {
func NewManager(wgIface WgInterface) *Manager {
m := &Manager{
OnActivityChan: make(chan peerid.ConnID, 1),
OnActivityChan: make(chan Event, 1),
wgIface: wgIface,
peers: make(map[peerid.ConnID]listener),
done: make(chan struct{}),
@@ -116,12 +125,12 @@ func (m *Manager) waitForTraffic(l listener, peerConnID peerid.ConnID) {
delete(m.peers, peerConnID)
m.mu.Unlock()
m.notify(peerConnID)
m.notify(Event{PeerConnID: peerConnID, FirstPacket: l.CapturedPacket()})
}
func (m *Manager) notify(peerConnID peerid.ConnID) {
func (m *Manager) notify(ev Event) {
select {
case <-m.done:
case m.OnActivityChan <- peerConnID:
case m.OnActivityChan <- ev:
}
}

View File

@@ -44,6 +44,10 @@ func (m MocWGIface) Address() wgaddr.Address {
}
}
func (m MocWGIface) MTU() uint16 {
return 1280
}
// GetPeerListener is a test helper to access listeners
func (m *Manager) GetPeerListener(peerConnID peerid.ConnID) (listener, bool) {
m.mu.Lock()
@@ -86,9 +90,9 @@ func TestManager_MonitorPeerActivity(t *testing.T) {
}
select {
case peerConnID := <-mgr.OnActivityChan:
if peerConnID != peerCfg1.PeerConnID {
t.Fatalf("unexpected peerConnID: %v", peerConnID)
case ev := <-mgr.OnActivityChan:
if ev.PeerConnID != peerCfg1.PeerConnID {
t.Fatalf("unexpected peerConnID: %v", ev.PeerConnID)
}
case <-time.After(1 * time.Second):
}

View File

@@ -130,8 +130,8 @@ func (m *Manager) Start(ctx context.Context) {
select {
case <-ctx.Done():
return
case peerConnID := <-m.activityManager.OnActivityChan:
m.onPeerActivity(peerConnID)
case ev := <-m.activityManager.OnActivityChan:
m.onPeerActivity(ev)
case peerIDs := <-m.inactivityManager.InactivePeersChan():
m.onPeerInactivityTimedOut(peerIDs)
}
@@ -513,13 +513,13 @@ func (m *Manager) checkHaGroupActivity(haGroup route.HAUniqueID, peerID string,
return false
}
func (m *Manager) onPeerActivity(peerConnID peerid.ConnID) {
func (m *Manager) onPeerActivity(ev activity.Event) {
m.managedPeersMu.Lock()
defer m.managedPeersMu.Unlock()
mp, ok := m.managedPeersByConnID[peerConnID]
mp, ok := m.managedPeersByConnID[ev.PeerConnID]
if !ok {
log.Errorf("peer not found by conn id: %v", peerConnID)
log.Errorf("peer not found by conn id: %v", ev.PeerConnID)
return
}
@@ -536,7 +536,7 @@ func (m *Manager) onPeerActivity(peerConnID peerid.ConnID) {
m.activateHAGroupPeers(mp.peerCfg)
m.peerStore.PeerConnOpen(m.engineCtx, mp.peerCfg.PublicKey)
m.peerStore.PeerConnOpenWithFirstPacket(m.engineCtx, mp.peerCfg.PublicKey, ev.FirstPacket)
}
func (m *Manager) onPeerInactivityTimedOut(peerIDs map[string]struct{}) {

View File

@@ -17,4 +17,5 @@ type WGIface interface {
IsUserspaceBind() bool
Address() wgaddr.Address
LastActivities() map[string]monotime.Time
MTU() uint16
}

View File

@@ -136,6 +136,51 @@ type Conn struct {
// Connection stage timestamps for metrics
metricsRecorder MetricsRecorder
metricsStages *MetricsStages
// pendingFirstPacket is the lazyconn-captured handshake init, replayed once the real
// transport is up.
pendingFirstPacket []byte
}
// SetPendingFirstPacket stashes a packet to be replayed once the connection is established.
func (conn *Conn) SetPendingFirstPacket(b []byte) {
conn.mu.Lock()
conn.pendingFirstPacket = append([]byte(nil), b...)
conn.mu.Unlock()
}
// takePendingFirstPacket returns and clears the stashed first packet. Caller must hold conn.mu.
func (conn *Conn) takePendingFirstPacket() []byte {
pkt := conn.pendingFirstPacket
conn.pendingFirstPacket = nil
return pkt
}
// injectPendingFirstPacket replays the captured handshake through the proxy if present, else
// directly through the ICE conn. Caller must hold conn.mu.
func (conn *Conn) injectPendingFirstPacket(proxy wgproxy.Proxy, directConn net.Conn) {
pkt := conn.takePendingFirstPacket()
if len(pkt) == 0 {
return
}
switch {
case proxy != nil:
if err := proxy.InjectPacket(pkt); err != nil {
conn.Log.Debugf("failed to reinject captured first packet via proxy: %v", err)
return
}
case directConn != nil:
if _, err := directConn.Write(pkt); err != nil {
conn.Log.Debugf("failed to reinject captured first packet via direct conn: %v", err)
return
}
default:
conn.Log.Debugf("no transport available to reinject captured first packet")
return
}
conn.Log.Debugf("reinjected captured first packet (%d bytes)", len(pkt))
}
// NewConn creates a new not opened Conn to the remote peer.
@@ -423,6 +468,8 @@ func (conn *Conn) onICEConnectionIsReady(priority conntype.ConnPriority, iceConn
conn.wgProxyRelay.RedirectAs(ep)
}
conn.injectPendingFirstPacket(wgProxy, iceConnInfo.RemoteConn)
conn.currentConnPriority = priority
conn.statusICE.SetConnected()
conn.updateIceState(iceConnInfo, updateTime)
@@ -546,6 +593,8 @@ func (conn *Conn) onRelayConnectionIsReady(rci RelayConnInfo) {
wgConfigWorkaround()
conn.injectPendingFirstPacket(wgProxy, nil)
conn.rosenpassRemoteKey = rci.rosenpassPubKey
conn.currentConnPriority = conntype.Relay
conn.statusRelay.SetConnected()

View File

@@ -81,6 +81,12 @@ func (s *Store) PeerConn(pubKey string) (*peer.Conn, bool) {
}
func (s *Store) PeerConnOpen(ctx context.Context, pubKey string) {
s.PeerConnOpenWithFirstPacket(ctx, pubKey, nil)
}
// PeerConnOpenWithFirstPacket opens the peer connection and stashes a first packet to be
// reinjected once the real transport is established.
func (s *Store) PeerConnOpenWithFirstPacket(ctx context.Context, pubKey string, firstPacket []byte) {
s.peerConnsMu.RLock()
defer s.peerConnsMu.RUnlock()
@@ -88,11 +94,13 @@ func (s *Store) PeerConnOpen(ctx context.Context, pubKey string) {
if !ok {
return
}
if len(firstPacket) > 0 {
p.SetPendingFirstPacket(firstPacket)
}
// this can be blocked because of the connect open limiter semaphore
if err := p.Open(ctx); err != nil {
p.Log.Errorf("failed to open peer connection: %v", err)
}
}
func (s *Store) PeerConnIdle(pubKey string) {