Compare commits

...

38 Commits

Author SHA1 Message Date
Zoltán Papp
a4b8ddac52 Fix missing required parameter for Bind proxy test 2025-09-04 11:14:57 +02:00
Zoltán Papp
0ce32ec527 Fix unit test data 2025-09-04 10:24:57 +02:00
Zoltán Papp
0a03145709 Fix test on darwin 2025-09-04 10:13:07 +02:00
Zoltán Papp
af18385291 Merge branch 'debug-slow-conn' into debug-and-fixes 2025-09-03 19:07:52 +02:00
Zoltán Papp
b1f6e9b12d Merge branch 'fix/pkg-loss' into debug-and-fixes 2025-09-03 19:07:28 +02:00
Zoltán Papp
180b52ba26 Merge branch 'main' into fix/pkg-loss 2025-09-03 16:27:17 +02:00
Zoltán Papp
f86a7f745e Fix logging library 2025-09-03 14:08:53 +02:00
Zoltán Papp
fd13247d66 Explicit print out the first success handshake time 2025-09-03 14:06:28 +02:00
Zoltán Papp
1d83fccd9c Improve logging for management service connection and system info gathering 2025-09-03 13:28:04 +02:00
Zoltán Papp
7d3c972653 Do not block Offer processing from relay worker
- do not miss ICE offers when relay worker busy
- close p2p connection before recreate agent
2025-09-02 22:17:13 +02:00
Zoltán Papp
990aa8f7cd Remove wg initiator logic 2025-02-24 10:46:22 +01:00
Zoltán Papp
2e6a44af1f Merge branch 'main' into fix/pkg-loss 2025-02-24 10:44:32 +01:00
Zoltán Papp
559d347588 Revert async WireGuard handshake 2025-02-21 16:19:28 +01:00
Zoltán Papp
06d71257b4 Build rawsocket code on linux only. 2025-02-21 16:00:31 +01:00
Zoltán Papp
62d10496ee Merge branch 'main' into fix/pkg-loss 2025-02-21 15:52:51 +01:00
Zoltán Papp
651e88d611 Eliminate code duplication 2025-02-21 14:59:42 +01:00
Zoltán Papp
d496d21693 Apply same pausedCond logic on all implementation 2025-02-21 14:55:56 +01:00
Zoltan Papp
648b4cdf72 Update client/iface/wgproxy/udp/proxy.go
Co-authored-by: Viktor Liu <17948409+lixmal@users.noreply.github.com>
2025-02-21 14:50:29 +01:00
Zoltán Papp
f0020ad4ce Fallback to package loss solution if the raw socket does not work. 2025-02-18 13:58:20 +01:00
Zoltán Papp
1eacff250e Remove WireGuard kernel code from FreeBSD 2025-02-18 13:47:42 +01:00
Zoltán Papp
1f83ba4563 Ignore err in tests 2025-02-17 22:06:04 +01:00
Zoltán Papp
3d80a25b4d Fix possible blocker if the bind will be closed earlier then proxy 2025-02-17 22:03:15 +01:00
Zoltán Papp
1963644c99 Add close test for all implementation 2025-02-17 21:47:34 +01:00
Zoltán Papp
360c7134f7 Add unit test 2025-02-17 21:26:37 +01:00
Zoltán Papp
775b4feb7e Fix close operation 2025-02-17 20:17:47 +01:00
Zoltán Papp
aca443bdec Build UDP proxy on Linux only 2025-02-17 19:26:40 +01:00
Zoltán Papp
335866ac60 Close unused rawsocket 2025-02-17 15:56:48 +01:00
Zoltán Papp
082452eb5f Add info log line 2025-02-17 15:50:49 +01:00
Zoltan Papp
b17c1d96a5 [client] Improve WireGuard handshake success rate (#3092)
The controller peer sends WireGuard handshake requests only
2025-02-17 15:50:15 +01:00
Zoltán Papp
2d5b5f59c2 Remove log line 2025-02-17 15:49:04 +01:00
Zoltán Papp
d5042f688f Fix interface changes 2025-02-17 15:47:20 +01:00
Zoltán Papp
4db73a13d7 Implement redirect logic in UDP proxy 2025-02-17 15:47:20 +01:00
Zoltán Papp
06a17f0eee Implement redirect to in eBPF proxy 2025-02-17 15:47:20 +01:00
Zoltán Papp
1f088b7e69 Extend the proxy interface with RedirectTo function and implement it in Bind proxy 2025-02-17 15:47:20 +01:00
Zoltan Papp
ffe74365a8 Code cleaning 2025-02-17 15:47:20 +01:00
Zoltán Papp
6a0f6efc18 Always stop timer 2025-02-17 15:47:20 +01:00
Zoltán Papp
bfa6df13c5 The non handshake initiator peer start the handshake after timeout 2025-02-17 15:47:20 +01:00
Zoltán Papp
e9b3b6210d Improve WireGuard handshake success rate
The controller peer sends WireGuard
handshake requests only
2025-02-17 15:47:20 +01:00
25 changed files with 787 additions and 258 deletions

View File

@@ -1,5 +1,17 @@
package bind package bind
import wgConn "golang.zx2c4.com/wireguard/conn" import (
"net"
wgConn "golang.zx2c4.com/wireguard/conn"
)
type Endpoint = wgConn.StdNetEndpoint type Endpoint = wgConn.StdNetEndpoint
func EndpointToUDPAddr(e Endpoint) *net.UDPAddr {
return &net.UDPAddr{
IP: e.Addr().AsSlice(),
Port: int(e.Port()),
Zone: e.Addr().Zone(),
}
}

View File

@@ -1,6 +1,7 @@
package bind package bind
import ( import (
"context"
"encoding/binary" "encoding/binary"
"fmt" "fmt"
"net" "net"
@@ -41,7 +42,7 @@ func (rc receiverCreator) CreateIPv4ReceiverFn(pc *ipv4.PacketConn, conn *net.UD
// use the port because in the Send function the wgConn.Endpoint the port info is not exported. // use the port because in the Send function the wgConn.Endpoint the port info is not exported.
type ICEBind struct { type ICEBind struct {
*wgConn.StdNetBind *wgConn.StdNetBind
RecvChan chan RecvMessage recvChan chan RecvMessage
transportNet transport.Net transportNet transport.Net
filterFn FilterFn filterFn FilterFn
@@ -64,7 +65,7 @@ func NewICEBind(transportNet transport.Net, filterFn FilterFn, address wgaddr.Ad
b, _ := wgConn.NewStdNetBind().(*wgConn.StdNetBind) b, _ := wgConn.NewStdNetBind().(*wgConn.StdNetBind)
ib := &ICEBind{ ib := &ICEBind{
StdNetBind: b, StdNetBind: b,
RecvChan: make(chan RecvMessage, 1), recvChan: make(chan RecvMessage, 1),
transportNet: transportNet, transportNet: transportNet,
filterFn: filterFn, filterFn: filterFn,
endpoints: make(map[netip.Addr]net.Conn), endpoints: make(map[netip.Addr]net.Conn),
@@ -154,6 +155,14 @@ func (b *ICEBind) Send(bufs [][]byte, ep wgConn.Endpoint) error {
return nil return nil
} }
func (b *ICEBind) Recv(ctx context.Context, msg RecvMessage) {
select {
case <-ctx.Done():
return
case b.recvChan <- msg:
}
}
func (s *ICEBind) createIPv4ReceiverFn(pc *ipv4.PacketConn, conn *net.UDPConn, rxOffload bool, msgsPool *sync.Pool) wgConn.ReceiveFunc { func (s *ICEBind) createIPv4ReceiverFn(pc *ipv4.PacketConn, conn *net.UDPConn, rxOffload bool, msgsPool *sync.Pool) wgConn.ReceiveFunc {
s.muUDPMux.Lock() s.muUDPMux.Lock()
defer s.muUDPMux.Unlock() defer s.muUDPMux.Unlock()
@@ -270,7 +279,7 @@ func (c *ICEBind) receiveRelayed(buffs [][]byte, sizes []int, eps []wgConn.Endpo
select { select {
case <-c.closedChan: case <-c.closedChan:
return 0, net.ErrClosed return 0, net.ErrClosed
case msg, ok := <-c.RecvChan: case msg, ok := <-c.recvChan:
if !ok { if !ok {
return 0, net.ErrClosed return 0, net.ErrClosed
} }

View File

@@ -394,6 +394,13 @@ func toLastHandshake(stringVar string) (time.Time, error) {
if err != nil { if err != nil {
return time.Time{}, fmt.Errorf("parse handshake sec: %w", err) return time.Time{}, fmt.Errorf("parse handshake sec: %w", err)
} }
// If sec is 0 (Unix epoch), return zero time instead
// This indicates no handshake has occurred
if sec == 0 {
return time.Time{}, nil
}
return time.Unix(sec, 0), nil return time.Unix(sec, 0), nil
} }

View File

@@ -0,0 +1,40 @@
//go:build freebsd
package iface
import (
"fmt"
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/netstack"
"github.com/netbirdio/netbird/client/iface/wgproxy"
)
// NewWGIFace Creates a new WireGuard interface instance
func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
wgAddress, err := device.ParseWGAddress(opts.Address)
if err != nil {
return nil, err
}
wgIFace := &WGIface{}
if netstack.IsEnabled() {
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn)
wgIFace.tun = device.NewNetstackDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, netstack.ListenAddr())
wgIFace.userspaceBind = true
wgIFace.wgProxyFactory = wgproxy.NewUSPFactory(iceBind)
return wgIFace, nil
}
if device.ModuleTunIsLoaded() {
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn)
wgIFace.tun = device.NewUSPDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind)
wgIFace.userspaceBind = true
wgIFace.wgProxyFactory = wgproxy.NewUSPFactory(iceBind)
return wgIFace, nil
}
return nil, fmt.Errorf("couldn't check or load tun module")
}

View File

@@ -1,4 +1,4 @@
//go:build (linux && !android) || freebsd //go:build linux && !android
package iface package iface

View File

@@ -16,28 +16,37 @@ import (
"github.com/netbirdio/netbird/client/iface/wgproxy/listener" "github.com/netbirdio/netbird/client/iface/wgproxy/listener"
) )
type IceBind interface {
SetEndpoint(fakeIP netip.Addr, conn net.Conn)
RemoveEndpoint(fakeIP netip.Addr)
Recv(ctx context.Context, msg bind.RecvMessage)
MTU() uint16
}
type ProxyBind struct { type ProxyBind struct {
Bind *bind.ICEBind bind IceBind
fakeNetIP *netip.AddrPort // wgRelayedEndpoint is a fake address that generated by the Bind.SetEndpoint based on the remote NetBird peer address
wgBindEndpoint *bind.Endpoint wgRelayedEndpoint *bind.Endpoint
remoteConn net.Conn wgCurrentUsed *bind.Endpoint
ctx context.Context remoteConn net.Conn
cancel context.CancelFunc ctx context.Context
closeMu sync.Mutex cancel context.CancelFunc
closed bool closeMu sync.Mutex
closed bool
pausedMu sync.Mutex paused bool
paused bool pausedCond *sync.Cond
isStarted bool isStarted bool
closeListener *listener.CloseListener closeListener *listener.CloseListener
} }
func NewProxyBind(bind *bind.ICEBind) *ProxyBind { func NewProxyBind(bind IceBind) *ProxyBind {
p := &ProxyBind{ p := &ProxyBind{
Bind: bind, bind: bind,
closeListener: listener.NewCloseListener(), closeListener: listener.NewCloseListener(),
pausedCond: sync.NewCond(&sync.Mutex{}),
} }
return p return p
@@ -46,25 +55,25 @@ func NewProxyBind(bind *bind.ICEBind) *ProxyBind {
// AddTurnConn adds a new connection to the bind. // AddTurnConn adds a new connection to the bind.
// endpoint is the NetBird address of the remote peer. The SetEndpoint return with the address what will be used in the // endpoint is the NetBird address of the remote peer. The SetEndpoint return with the address what will be used in the
// WireGuard configuration. // WireGuard configuration.
//
// Parameters:
// - ctx: Context is used for proxyToLocal to avoid unnecessary error messages
// - nbAddr: The NetBird UDP address of the remote peer, it required to generate fake address
// - remoteConn: The established TURN connection to the remote peer
func (p *ProxyBind) AddTurnConn(ctx context.Context, nbAddr *net.UDPAddr, remoteConn net.Conn) error { func (p *ProxyBind) AddTurnConn(ctx context.Context, nbAddr *net.UDPAddr, remoteConn net.Conn) error {
fakeNetIP, err := fakeAddress(nbAddr) fakeNetIP, err := fakeAddress(nbAddr)
if err != nil { if err != nil {
return err return err
} }
p.wgRelayedEndpoint = &bind.Endpoint{AddrPort: *fakeNetIP}
p.fakeNetIP = fakeNetIP
p.wgBindEndpoint = &bind.Endpoint{AddrPort: *fakeNetIP}
p.remoteConn = remoteConn p.remoteConn = remoteConn
p.ctx, p.cancel = context.WithCancel(ctx) p.ctx, p.cancel = context.WithCancel(ctx)
return nil return nil
} }
func (p *ProxyBind) EndpointAddr() *net.UDPAddr { func (p *ProxyBind) EndpointAddr() *net.UDPAddr {
return &net.UDPAddr{ return bind.EndpointToUDPAddr(*p.wgRelayedEndpoint)
IP: p.fakeNetIP.Addr().AsSlice(),
Port: int(p.fakeNetIP.Port()),
Zone: p.fakeNetIP.Addr().Zone(),
}
} }
func (p *ProxyBind) SetDisconnectListener(disconnected func()) { func (p *ProxyBind) SetDisconnectListener(disconnected func()) {
@@ -76,17 +85,22 @@ func (p *ProxyBind) Work() {
return return
} }
p.Bind.SetEndpoint(p.fakeNetIP.Addr(), p.remoteConn) p.bind.SetEndpoint(p.wgRelayedEndpoint.Addr(), p.remoteConn)
p.pausedMu.Lock() p.pausedCond.L.Lock()
p.paused = false p.paused = false
p.pausedMu.Unlock()
p.wgCurrentUsed = p.wgRelayedEndpoint
// Start the proxy only once // Start the proxy only once
if !p.isStarted { if !p.isStarted {
p.isStarted = true p.isStarted = true
go p.proxyToLocal(p.ctx) go p.proxyToLocal(p.ctx)
} }
p.pausedCond.L.Unlock()
// todo: review to should be inside the lock scope
p.pausedCond.Signal()
} }
func (p *ProxyBind) Pause() { func (p *ProxyBind) Pause() {
@@ -94,9 +108,25 @@ func (p *ProxyBind) Pause() {
return return
} }
p.pausedMu.Lock() p.pausedCond.L.Lock()
p.paused = true p.paused = true
p.pausedMu.Unlock() p.pausedCond.L.Unlock()
}
func (p *ProxyBind) RedirectAs(endpoint *net.UDPAddr) {
p.pausedCond.L.Lock()
p.paused = false
p.wgCurrentUsed = addrToEndpoint(endpoint)
p.pausedCond.L.Unlock()
p.pausedCond.Signal()
}
func addrToEndpoint(addr *net.UDPAddr) *bind.Endpoint {
ip, _ := netip.AddrFromSlice(addr.IP.To4())
addrPort := netip.AddrPortFrom(ip, uint16(addr.Port))
return &bind.Endpoint{AddrPort: addrPort}
} }
func (p *ProxyBind) CloseConn() error { func (p *ProxyBind) CloseConn() error {
@@ -107,6 +137,10 @@ func (p *ProxyBind) CloseConn() error {
} }
func (p *ProxyBind) close() error { func (p *ProxyBind) close() error {
if p.remoteConn == nil {
return nil
}
p.closeMu.Lock() p.closeMu.Lock()
defer p.closeMu.Unlock() defer p.closeMu.Unlock()
@@ -120,7 +154,12 @@ func (p *ProxyBind) close() error {
p.cancel() p.cancel()
p.Bind.RemoveEndpoint(p.fakeNetIP.Addr()) p.pausedCond.L.Lock()
p.paused = false
p.pausedCond.L.Unlock()
p.pausedCond.Signal()
p.bind.RemoveEndpoint(p.wgRelayedEndpoint.Addr())
if rErr := p.remoteConn.Close(); rErr != nil && !errors.Is(rErr, net.ErrClosed) { if rErr := p.remoteConn.Close(); rErr != nil && !errors.Is(rErr, net.ErrClosed) {
return rErr return rErr
@@ -136,7 +175,7 @@ func (p *ProxyBind) proxyToLocal(ctx context.Context) {
}() }()
for { for {
buf := make([]byte, p.Bind.MTU()+bufsize.WGBufferOverhead) buf := make([]byte, p.bind.MTU()+bufsize.WGBufferOverhead)
n, err := p.remoteConn.Read(buf) n, err := p.remoteConn.Read(buf)
if err != nil { if err != nil {
if ctx.Err() != nil { if ctx.Err() != nil {
@@ -147,18 +186,17 @@ func (p *ProxyBind) proxyToLocal(ctx context.Context) {
return return
} }
p.pausedMu.Lock() p.pausedCond.L.Lock()
if p.paused { for p.paused {
p.pausedMu.Unlock() p.pausedCond.Wait()
continue
} }
msg := bind.RecvMessage{ msg := bind.RecvMessage{
Endpoint: p.wgBindEndpoint, Endpoint: p.wgCurrentUsed,
Buffer: buf[:n], Buffer: buf[:n],
} }
p.Bind.RecvChan <- msg p.bind.Recv(ctx, msg)
p.pausedMu.Unlock() p.pausedCond.L.Unlock()
} }
} }

View File

@@ -6,9 +6,7 @@ import (
"context" "context"
"fmt" "fmt"
"net" "net"
"os"
"sync" "sync"
"syscall"
"github.com/google/gopacket" "github.com/google/gopacket"
"github.com/google/gopacket/layers" "github.com/google/gopacket/layers"
@@ -18,6 +16,7 @@ import (
nberrors "github.com/netbirdio/netbird/client/errors" nberrors "github.com/netbirdio/netbird/client/errors"
"github.com/netbirdio/netbird/client/iface/bufsize" "github.com/netbirdio/netbird/client/iface/bufsize"
"github.com/netbirdio/netbird/client/iface/wgproxy/rawsocket"
"github.com/netbirdio/netbird/client/internal/ebpf" "github.com/netbirdio/netbird/client/internal/ebpf"
ebpfMgr "github.com/netbirdio/netbird/client/internal/ebpf/manager" ebpfMgr "github.com/netbirdio/netbird/client/internal/ebpf/manager"
nbnet "github.com/netbirdio/netbird/util/net" nbnet "github.com/netbirdio/netbird/util/net"
@@ -27,6 +26,10 @@ const (
loopbackAddr = "127.0.0.1" loopbackAddr = "127.0.0.1"
) )
var (
localHostNetIP = net.ParseIP("127.0.0.1")
)
// WGEBPFProxy definition for proxy with EBPF support // WGEBPFProxy definition for proxy with EBPF support
type WGEBPFProxy struct { type WGEBPFProxy struct {
localWGListenPort int localWGListenPort int
@@ -64,7 +67,7 @@ func (p *WGEBPFProxy) Listen() error {
return err return err
} }
p.rawConn, err = p.prepareSenderRawSocket() p.rawConn, err = rawsocket.PrepareSenderRawSocket()
if err != nil { if err != nil {
return err return err
} }
@@ -214,57 +217,17 @@ generatePort:
return p.lastUsedPort, nil return p.lastUsedPort, nil
} }
func (p *WGEBPFProxy) prepareSenderRawSocket() (net.PacketConn, error) { func (p *WGEBPFProxy) sendPkg(data []byte, endpointAddr *net.UDPAddr) error {
// Create a raw socket.
fd, err := syscall.Socket(syscall.AF_INET, syscall.SOCK_RAW, syscall.IPPROTO_RAW)
if err != nil {
return nil, fmt.Errorf("creating raw socket failed: %w", err)
}
// Set the IP_HDRINCL option on the socket to tell the kernel that headers are included in the packet.
err = syscall.SetsockoptInt(fd, syscall.IPPROTO_IP, syscall.IP_HDRINCL, 1)
if err != nil {
return nil, fmt.Errorf("setting IP_HDRINCL failed: %w", err)
}
// Bind the socket to the "lo" interface.
err = syscall.SetsockoptString(fd, syscall.SOL_SOCKET, syscall.SO_BINDTODEVICE, "lo")
if err != nil {
return nil, fmt.Errorf("binding to lo interface failed: %w", err)
}
// Set the fwmark on the socket.
err = nbnet.SetSocketOpt(fd)
if err != nil {
return nil, fmt.Errorf("setting fwmark failed: %w", err)
}
// Convert the file descriptor to a PacketConn.
file := os.NewFile(uintptr(fd), fmt.Sprintf("fd %d", fd))
if file == nil {
return nil, fmt.Errorf("converting fd to file failed")
}
packetConn, err := net.FilePacketConn(file)
if err != nil {
return nil, fmt.Errorf("converting file to packet conn failed: %w", err)
}
return packetConn, nil
}
func (p *WGEBPFProxy) sendPkg(data []byte, port int) error {
localhost := net.ParseIP("127.0.0.1")
payload := gopacket.Payload(data) payload := gopacket.Payload(data)
ipH := &layers.IPv4{ ipH := &layers.IPv4{
DstIP: localhost, DstIP: localHostNetIP,
SrcIP: localhost, SrcIP: endpointAddr.IP,
Version: 4, Version: 4,
TTL: 64, TTL: 64,
Protocol: layers.IPProtocolUDP, Protocol: layers.IPProtocolUDP,
} }
udpH := &layers.UDP{ udpH := &layers.UDP{
SrcPort: layers.UDPPort(port), SrcPort: layers.UDPPort(endpointAddr.Port),
DstPort: layers.UDPPort(p.localWGListenPort), DstPort: layers.UDPPort(p.localWGListenPort),
} }
@@ -279,7 +242,7 @@ func (p *WGEBPFProxy) sendPkg(data []byte, port int) error {
if err != nil { if err != nil {
return fmt.Errorf("serialize layers: %w", err) return fmt.Errorf("serialize layers: %w", err)
} }
if _, err = p.rawConn.WriteTo(layerBuffer.Bytes(), &net.IPAddr{IP: localhost}); err != nil { if _, err = p.rawConn.WriteTo(layerBuffer.Bytes(), &net.IPAddr{IP: localHostNetIP}); err != nil {
return fmt.Errorf("write to raw conn: %w", err) return fmt.Errorf("write to raw conn: %w", err)
} }
return nil return nil

View File

@@ -18,41 +18,42 @@ import (
// ProxyWrapper help to keep the remoteConn instance for net.Conn.Close function call // ProxyWrapper help to keep the remoteConn instance for net.Conn.Close function call
type ProxyWrapper struct { type ProxyWrapper struct {
WgeBPFProxy *WGEBPFProxy wgeBPFProxy *WGEBPFProxy
remoteConn net.Conn remoteConn net.Conn
ctx context.Context ctx context.Context
cancel context.CancelFunc cancel context.CancelFunc
wgEndpointAddr *net.UDPAddr wgRelayedEndpointAddr *net.UDPAddr
wgEndpointCurrentUsedAddr *net.UDPAddr
pausedMu sync.Mutex paused bool
paused bool pausedCond *sync.Cond
isStarted bool isStarted bool
closeListener *listener.CloseListener closeListener *listener.CloseListener
} }
func NewProxyWrapper(WgeBPFProxy *WGEBPFProxy) *ProxyWrapper { func NewProxyWrapper(proxy *WGEBPFProxy) *ProxyWrapper {
return &ProxyWrapper{ return &ProxyWrapper{
WgeBPFProxy: WgeBPFProxy, wgeBPFProxy: proxy,
pausedCond: sync.NewCond(&sync.Mutex{}),
closeListener: listener.NewCloseListener(), closeListener: listener.NewCloseListener(),
} }
} }
func (p *ProxyWrapper) AddTurnConn(ctx context.Context, endpoint *net.UDPAddr, remoteConn net.Conn) error { func (p *ProxyWrapper) AddTurnConn(ctx context.Context, endpoint *net.UDPAddr, remoteConn net.Conn) error {
addr, err := p.WgeBPFProxy.AddTurnConn(remoteConn) addr, err := p.wgeBPFProxy.AddTurnConn(remoteConn)
if err != nil { if err != nil {
return fmt.Errorf("add turn conn: %w", err) return fmt.Errorf("add turn conn: %w", err)
} }
p.remoteConn = remoteConn p.remoteConn = remoteConn
p.ctx, p.cancel = context.WithCancel(ctx) p.ctx, p.cancel = context.WithCancel(ctx)
p.wgEndpointAddr = addr p.wgRelayedEndpointAddr = addr
return err return err
} }
func (p *ProxyWrapper) EndpointAddr() *net.UDPAddr { func (p *ProxyWrapper) EndpointAddr() *net.UDPAddr {
return p.wgEndpointAddr return p.wgRelayedEndpointAddr
} }
func (p *ProxyWrapper) SetDisconnectListener(disconnected func()) { func (p *ProxyWrapper) SetDisconnectListener(disconnected func()) {
@@ -64,14 +65,19 @@ func (p *ProxyWrapper) Work() {
return return
} }
p.pausedMu.Lock() p.pausedCond.L.Lock()
p.paused = false p.paused = false
p.pausedMu.Unlock()
p.wgEndpointCurrentUsedAddr = p.wgRelayedEndpointAddr
if !p.isStarted { if !p.isStarted {
p.isStarted = true p.isStarted = true
go p.proxyToLocal(p.ctx) go p.proxyToLocal(p.ctx)
} }
p.pausedCond.L.Unlock()
// todo: review to should be inside the lock scope
p.pausedCond.Signal()
} }
func (p *ProxyWrapper) Pause() { func (p *ProxyWrapper) Pause() {
@@ -80,45 +86,59 @@ func (p *ProxyWrapper) Pause() {
} }
log.Tracef("pause proxy reading from: %s", p.remoteConn.RemoteAddr()) log.Tracef("pause proxy reading from: %s", p.remoteConn.RemoteAddr())
p.pausedMu.Lock() p.pausedCond.L.Lock()
p.paused = true p.paused = true
p.pausedMu.Unlock() p.pausedCond.L.Unlock()
}
func (p *ProxyWrapper) RedirectAs(endpoint *net.UDPAddr) {
p.pausedCond.L.Lock()
p.paused = false
p.wgEndpointCurrentUsedAddr = endpoint
p.pausedCond.L.Unlock()
p.pausedCond.Signal()
} }
// CloseConn close the remoteConn and automatically remove the conn instance from the map // CloseConn close the remoteConn and automatically remove the conn instance from the map
func (e *ProxyWrapper) CloseConn() error { func (p *ProxyWrapper) CloseConn() error {
if e.cancel == nil { if p.cancel == nil {
return fmt.Errorf("proxy not started") return fmt.Errorf("proxy not started")
} }
e.cancel() p.cancel()
e.closeListener.SetCloseListener(nil) p.closeListener.SetCloseListener(nil)
if err := e.remoteConn.Close(); err != nil && !errors.Is(err, net.ErrClosed) { p.pausedCond.L.Lock()
return fmt.Errorf("close remote conn: %w", err) p.paused = false
p.pausedCond.L.Unlock()
p.pausedCond.Signal()
if err := p.remoteConn.Close(); err != nil && !errors.Is(err, net.ErrClosed) {
return fmt.Errorf("failed to close remote conn: %w", err)
} }
return nil return nil
} }
func (p *ProxyWrapper) proxyToLocal(ctx context.Context) { func (p *ProxyWrapper) proxyToLocal(ctx context.Context) {
defer p.WgeBPFProxy.removeTurnConn(uint16(p.wgEndpointAddr.Port)) defer p.wgeBPFProxy.removeTurnConn(uint16(p.wgRelayedEndpointAddr.Port))
buf := make([]byte, p.WgeBPFProxy.mtu+bufsize.WGBufferOverhead) buf := make([]byte, p.wgeBPFProxy.mtu+bufsize.WGBufferOverhead)
for { for {
n, err := p.readFromRemote(ctx, buf) n, err := p.readFromRemote(ctx, buf)
if err != nil { if err != nil {
return return
} }
p.pausedMu.Lock() p.pausedCond.L.Lock()
if p.paused { for p.paused {
p.pausedMu.Unlock() p.pausedCond.Wait()
continue
} }
err = p.WgeBPFProxy.sendPkg(buf[:n], p.wgEndpointAddr.Port) err = p.wgeBPFProxy.sendPkg(buf[:n], p.wgEndpointCurrentUsedAddr)
p.pausedMu.Unlock() p.pausedCond.L.Unlock()
if err != nil { if err != nil {
if ctx.Err() != nil { if ctx.Err() != nil {
@@ -137,7 +157,7 @@ func (p *ProxyWrapper) readFromRemote(ctx context.Context, buf []byte) (int, err
} }
p.closeListener.Notify() p.closeListener.Notify()
if !errors.Is(err, io.EOF) { if !errors.Is(err, io.EOF) {
log.Errorf("failed to read from turn conn (endpoint: :%d): %s", p.wgEndpointAddr.Port, err) log.Errorf("failed to read from turn conn (endpoint: :%d): %s", p.wgRelayedEndpointAddr.Port, err)
} }
return 0, err return 0, err
} }

View File

@@ -39,7 +39,6 @@ func (w *KernelFactory) GetProxy() Proxy {
} }
return ebpf.NewProxyWrapper(w.ebpfProxy) return ebpf.NewProxyWrapper(w.ebpfProxy)
} }
func (w *KernelFactory) Free() error { func (w *KernelFactory) Free() error {

View File

@@ -1,31 +0,0 @@
package wgproxy
import (
log "github.com/sirupsen/logrus"
udpProxy "github.com/netbirdio/netbird/client/iface/wgproxy/udp"
)
// KernelFactory todo: check eBPF support on FreeBSD
type KernelFactory struct {
wgPort int
mtu uint16
}
func NewKernelFactory(wgPort int, mtu uint16) *KernelFactory {
log.Infof("WireGuard Proxy Factory will produce UDP proxy")
f := &KernelFactory{
wgPort: wgPort,
mtu: mtu,
}
return f
}
func (w *KernelFactory) GetProxy() Proxy {
return udpProxy.NewWGUDPProxy(w.wgPort, w.mtu)
}
func (w *KernelFactory) Free() error {
return nil
}

View File

@@ -11,6 +11,12 @@ type Proxy interface {
EndpointAddr() *net.UDPAddr // EndpointAddr returns the address of the WireGuard peer endpoint EndpointAddr() *net.UDPAddr // EndpointAddr returns the address of the WireGuard peer endpoint
Work() // Work start or resume the proxy Work() // Work start or resume the proxy
Pause() // Pause to forward the packages from remote connection to WireGuard. The opposite way still works. Pause() // Pause to forward the packages from remote connection to WireGuard. The opposite way still works.
/*
RedirectAs resume the forwarding the packages from relayed connection to WireGuard interface if it was paused
and rewrite the src address to the endpoint address.
With this logic can avoid the package loss from relayed connections.
*/
RedirectAs(endpoint *net.UDPAddr)
CloseConn() error CloseConn() error
SetDisconnectListener(disconnected func()) SetDisconnectListener(disconnected func())
} }

View File

@@ -3,54 +3,82 @@
package wgproxy package wgproxy
import ( import (
"context" "fmt"
"os" "net"
"testing"
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/wgaddr"
bindproxy "github.com/netbirdio/netbird/client/iface/wgproxy/bind"
"github.com/netbirdio/netbird/client/iface/wgproxy/ebpf" "github.com/netbirdio/netbird/client/iface/wgproxy/ebpf"
"github.com/netbirdio/netbird/client/iface/wgproxy/udp"
) )
func TestProxyCloseByRemoteConnEBPF(t *testing.T) { func seedProxies() ([]proxyInstance, error) {
if os.Getenv("GITHUB_ACTIONS") != "true" { pl := make([]proxyInstance, 0)
t.Skip("Skipping test as it requires root privileges")
}
ctx := context.Background()
ebpfProxy := ebpf.NewWGEBPFProxy(51831, 1280) ebpfProxy := ebpf.NewWGEBPFProxy(51831, 1280)
if err := ebpfProxy.Listen(); err != nil { if err := ebpfProxy.Listen(); err != nil {
t.Fatalf("failed to initialize ebpf proxy: %s", err) return nil, fmt.Errorf("failed to initialize ebpf proxy: %s", err)
} }
defer func() { pEbpf := proxyInstance{
if err := ebpfProxy.Free(); err != nil { name: "ebpf kernel proxy",
t.Errorf("failed to free ebpf proxy: %s", err) proxy: ebpf.NewProxyWrapper(ebpfProxy),
} wgPort: 51831,
}() closeFn: ebpfProxy.Free,
tests := []struct {
name string
proxy Proxy
}{
{
name: "ebpf proxy",
proxy: &ebpf.ProxyWrapper{
WgeBPFProxy: ebpfProxy,
},
},
} }
pl = append(pl, pEbpf)
for _, tt := range tests { pUDP := proxyInstance{
t.Run(tt.name, func(t *testing.T) { name: "udp kernel proxy",
relayedConn := newMockConn() proxy: udp.NewWGUDPProxy(51832, 1280),
err := tt.proxy.AddTurnConn(ctx, nil, relayedConn) wgPort: 51832,
if err != nil { closeFn: func() error { return nil },
t.Errorf("error: %v", err)
}
_ = relayedConn.Close()
if err := tt.proxy.CloseConn(); err != nil {
t.Errorf("error: %v", err)
}
})
} }
pl = append(pl, pUDP)
return pl, nil
}
func seedProxyForProxyCloseByRemoteConn() ([]proxyInstance, error) {
pl := make([]proxyInstance, 0)
ebpfProxy := ebpf.NewWGEBPFProxy(51831, 1280)
if err := ebpfProxy.Listen(); err != nil {
return nil, fmt.Errorf("failed to initialize ebpf proxy: %s", err)
}
pEbpf := proxyInstance{
name: "ebpf kernel proxy",
proxy: ebpf.NewProxyWrapper(ebpfProxy),
wgPort: 51831,
closeFn: ebpfProxy.Free,
}
pl = append(pl, pEbpf)
pUDP := proxyInstance{
name: "udp kernel proxy",
proxy: udp.NewWGUDPProxy(51832, 1280),
wgPort: 51832,
closeFn: func() error { return nil },
}
pl = append(pl, pUDP)
wgAddress, err := wgaddr.ParseWGAddress("10.0.0.1/32")
if err != nil {
return nil, err
}
iceBind := bind.NewICEBind(nil, nil, wgAddress, 1280)
endpointAddress := &net.UDPAddr{
IP: net.IPv4(10, 0, 0, 1),
Port: 1234,
}
pBind := proxyInstance{
name: "bind proxy",
proxy: bindproxy.NewProxyBind(iceBind),
endpointAddr: endpointAddress,
closeFn: func() error { return nil },
}
pl = append(pl, pBind)
return pl, nil
} }

View File

@@ -0,0 +1,39 @@
//go:build !linux
package wgproxy
import (
"net"
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/wgaddr"
bindproxy "github.com/netbirdio/netbird/client/iface/wgproxy/bind"
)
func seedProxies() ([]proxyInstance, error) {
// todo extend with Bind proxy
pl := make([]proxyInstance, 0)
return pl, nil
}
func seedProxyForProxyCloseByRemoteConn() ([]proxyInstance, error) {
pl := make([]proxyInstance, 0)
wgAddress, err := wgaddr.ParseWGAddress("10.0.0.1/32")
if err != nil {
return nil, err
}
iceBind := bind.NewICEBind(nil, nil, wgAddress, 1280)
endpointAddress := &net.UDPAddr{
IP: net.IPv4(10, 0, 0, 1),
Port: 1234,
}
pBind := proxyInstance{
name: "bind proxy",
proxy: bindproxy.NewProxyBind(iceBind),
endpointAddr: endpointAddress,
closeFn: func() error { return nil },
}
pl = append(pl, pBind)
return pl, nil
}

View File

@@ -1,5 +1,3 @@
//go:build linux
package wgproxy package wgproxy
import ( import (
@@ -7,12 +5,9 @@ import (
"io" "io"
"net" "net"
"os" "os"
"runtime"
"testing" "testing"
"time" "time"
"github.com/netbirdio/netbird/client/iface/wgproxy/ebpf"
udpProxy "github.com/netbirdio/netbird/client/iface/wgproxy/udp"
"github.com/netbirdio/netbird/util" "github.com/netbirdio/netbird/util"
) )
@@ -22,6 +17,14 @@ func TestMain(m *testing.M) {
os.Exit(code) os.Exit(code)
} }
type proxyInstance struct {
name string
proxy Proxy
wgPort int
endpointAddr *net.UDPAddr
closeFn func() error
}
type mocConn struct { type mocConn struct {
closeChan chan struct{} closeChan chan struct{}
closed bool closed bool
@@ -78,41 +81,21 @@ func (m *mocConn) SetWriteDeadline(t time.Time) error {
func TestProxyCloseByRemoteConn(t *testing.T) { func TestProxyCloseByRemoteConn(t *testing.T) {
ctx := context.Background() ctx := context.Background()
tests := []struct { tests, err := seedProxyForProxyCloseByRemoteConn()
name string if err != nil {
proxy Proxy t.Fatalf("error: %v", err)
}{
{
name: "userspace proxy",
proxy: udpProxy.NewWGUDPProxy(51830, 1280),
},
} }
if runtime.GOOS == "linux" && os.Getenv("GITHUB_ACTIONS") != "true" { relayedConn, _ := net.Dial("udp", "127.0.0.1:1234")
ebpfProxy := ebpf.NewWGEBPFProxy(51831, 1280) defer func() {
if err := ebpfProxy.Listen(); err != nil { _ = relayedConn.Close()
t.Fatalf("failed to initialize ebpf proxy: %s", err) }()
}
defer func() {
if err := ebpfProxy.Free(); err != nil {
t.Errorf("failed to free ebpf proxy: %s", err)
}
}()
proxyWrapper := ebpf.NewProxyWrapper(ebpfProxy)
tests = append(tests, struct {
name string
proxy Proxy
}{
name: "ebpf proxy",
proxy: proxyWrapper,
})
}
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
addr, _ := net.ResolveUDPAddr("udp", "100.108.135.221:51892")
relayedConn := newMockConn() relayedConn := newMockConn()
err := tt.proxy.AddTurnConn(ctx, nil, relayedConn) err := tt.proxy.AddTurnConn(ctx, addr, relayedConn)
if err != nil { if err != nil {
t.Errorf("error: %v", err) t.Errorf("error: %v", err)
} }
@@ -124,3 +107,104 @@ func TestProxyCloseByRemoteConn(t *testing.T) {
}) })
} }
} }
// TestProxyRedirect todo extend the proxies with Bind proxy
func TestProxyRedirect(t *testing.T) {
tests, err := seedProxies()
if err != nil {
t.Fatalf("error: %v", err)
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
redirectTraffic(t, tt.proxy, tt.wgPort, tt.endpointAddr)
if err := tt.closeFn(); err != nil {
t.Errorf("error: %v", err)
}
})
}
}
func redirectTraffic(t *testing.T, proxy Proxy, wgPort int, endPointAddr *net.UDPAddr) {
t.Helper()
msgHelloFromRelay := []byte("hello from relay")
msgRedirected := [][]byte{
[]byte("hello 1. to p2p"),
[]byte("hello 2. to p2p"),
[]byte("hello 3. to p2p"),
}
dummyWgListener, err := net.ListenUDP("udp", &net.UDPAddr{
IP: net.IPv4(127, 0, 0, 1),
Port: wgPort})
if err != nil {
t.Fatalf("failed to listen on udp port: %s", err)
}
relayedServer, _ := net.ListenUDP("udp",
&net.UDPAddr{
IP: net.IPv4(127, 0, 0, 1),
Port: 1234,
},
)
relayedConn, _ := net.Dial("udp", "127.0.0.1:1234")
defer func() {
_ = dummyWgListener.Close()
_ = relayedConn.Close()
_ = relayedServer.Close()
}()
if err := proxy.AddTurnConn(context.Background(), endPointAddr, relayedConn); err != nil {
t.Errorf("error: %v", err)
}
defer func() {
if err := proxy.CloseConn(); err != nil {
t.Errorf("error: %v", err)
}
}()
proxy.Work()
if _, err := relayedServer.WriteTo(msgHelloFromRelay, relayedConn.LocalAddr()); err != nil {
t.Errorf("error relayedServer.Write(msgHelloFromRelay): %v", err)
}
n, err := dummyWgListener.Read(make([]byte, 1024))
if err != nil {
t.Errorf("error: %v", err)
}
if n != len(msgHelloFromRelay) {
t.Errorf("expected %d bytes, got %d", len(msgHelloFromRelay), n)
}
p2pEndpointAddr := &net.UDPAddr{
IP: net.IPv4(192, 168, 0, 56),
Port: 1234,
}
proxy.RedirectAs(p2pEndpointAddr)
for _, msg := range msgRedirected {
if _, err := relayedServer.WriteTo(msg, relayedConn.LocalAddr()); err != nil {
t.Errorf("error: %v", err)
}
}
for i := 0; i < len(msgRedirected); i++ {
buf := make([]byte, 1024)
n, rAddr, err := dummyWgListener.ReadFrom(buf)
if err != nil {
t.Errorf("error: %v", err)
}
if rAddr.String() != p2pEndpointAddr.String() {
t.Errorf("expected %s, got %s", p2pEndpointAddr.String(), rAddr.String())
}
if string(buf[:n]) != string(msgRedirected[i]) {
t.Errorf("expected %s, got %s", string(msgRedirected[i]), string(buf[:n]))
}
}
}

View File

@@ -0,0 +1,50 @@
//go:build linux && !android
package rawsocket
import (
"fmt"
"net"
"os"
"syscall"
nbnet "github.com/netbirdio/netbird/util/net"
)
func PrepareSenderRawSocket() (net.PacketConn, error) {
// Create a raw socket.
fd, err := syscall.Socket(syscall.AF_INET, syscall.SOCK_RAW, syscall.IPPROTO_RAW)
if err != nil {
return nil, fmt.Errorf("creating raw socket failed: %w", err)
}
// Set the IP_HDRINCL option on the socket to tell the kernel that headers are included in the packet.
err = syscall.SetsockoptInt(fd, syscall.IPPROTO_IP, syscall.IP_HDRINCL, 1)
if err != nil {
return nil, fmt.Errorf("setting IP_HDRINCL failed: %w", err)
}
// Bind the socket to the "lo" interface.
err = syscall.SetsockoptString(fd, syscall.SOL_SOCKET, syscall.SO_BINDTODEVICE, "lo")
if err != nil {
return nil, fmt.Errorf("binding to lo interface failed: %w", err)
}
// Set the fwmark on the socket.
err = nbnet.SetSocketOpt(fd)
if err != nil {
return nil, fmt.Errorf("setting fwmark failed: %w", err)
}
// Convert the file descriptor to a PacketConn.
file := os.NewFile(uintptr(fd), fmt.Sprintf("fd %d", fd))
if file == nil {
return nil, fmt.Errorf("converting fd to file failed")
}
packetConn, err := net.FilePacketConn(file)
if err != nil {
return nil, fmt.Errorf("converting file to packet conn failed: %w", err)
}
return packetConn, nil
}

View File

@@ -1,3 +1,5 @@
//go:build linux && !android
package udp package udp
import ( import (
@@ -21,16 +23,18 @@ type WGUDPProxy struct {
localWGListenPort int localWGListenPort int
mtu uint16 mtu uint16
remoteConn net.Conn remoteConn net.Conn
localConn net.Conn localConn net.Conn
ctx context.Context srcFakerConn *SrcFaker
cancel context.CancelFunc sendPkg func(data []byte) (int, error)
closeMu sync.Mutex ctx context.Context
closed bool cancel context.CancelFunc
closeMu sync.Mutex
closed bool
pausedMu sync.Mutex paused bool
paused bool pausedCond *sync.Cond
isStarted bool isStarted bool
closeListener *listener.CloseListener closeListener *listener.CloseListener
} }
@@ -41,6 +45,7 @@ func NewWGUDPProxy(wgPort int, mtu uint16) *WGUDPProxy {
p := &WGUDPProxy{ p := &WGUDPProxy{
localWGListenPort: wgPort, localWGListenPort: wgPort,
mtu: mtu, mtu: mtu,
pausedCond: sync.NewCond(&sync.Mutex{}),
closeListener: listener.NewCloseListener(), closeListener: listener.NewCloseListener(),
} }
return p return p
@@ -61,6 +66,7 @@ func (p *WGUDPProxy) AddTurnConn(ctx context.Context, endpoint *net.UDPAddr, rem
p.ctx, p.cancel = context.WithCancel(ctx) p.ctx, p.cancel = context.WithCancel(ctx)
p.localConn = localConn p.localConn = localConn
p.sendPkg = p.localConn.Write
p.remoteConn = remoteConn p.remoteConn = remoteConn
return err return err
@@ -84,15 +90,24 @@ func (p *WGUDPProxy) Work() {
return return
} }
p.pausedMu.Lock() p.pausedCond.L.Lock()
p.paused = false p.paused = false
p.pausedMu.Unlock() p.sendPkg = p.localConn.Write
if p.srcFakerConn != nil {
if err := p.srcFakerConn.Close(); err != nil {
log.Errorf("failed to close src faker conn: %s", err)
}
p.srcFakerConn = nil
}
if !p.isStarted { if !p.isStarted {
p.isStarted = true p.isStarted = true
go p.proxyToRemote(p.ctx) go p.proxyToRemote(p.ctx)
go p.proxyToLocal(p.ctx) go p.proxyToLocal(p.ctx)
} }
p.pausedCond.L.Unlock()
p.pausedCond.Signal()
} }
// Pause pauses the proxy from receiving data from the remote peer // Pause pauses the proxy from receiving data from the remote peer
@@ -101,9 +116,35 @@ func (p *WGUDPProxy) Pause() {
return return
} }
p.pausedMu.Lock() p.pausedCond.L.Lock()
p.paused = true p.paused = true
p.pausedMu.Unlock() p.pausedCond.L.Unlock()
}
// RedirectAs start to use the fake sourced raw socket as package sender
func (p *WGUDPProxy) RedirectAs(endpoint *net.UDPAddr) {
p.pausedCond.L.Lock()
defer func() {
p.pausedCond.L.Unlock()
p.pausedCond.Signal()
}()
p.paused = false
if p.srcFakerConn != nil {
if err := p.srcFakerConn.Close(); err != nil {
log.Errorf("failed to close src faker conn: %s", err)
}
p.srcFakerConn = nil
}
srcFakerConn, err := NewSrcFaker(p.localWGListenPort, endpoint)
if err != nil {
log.Errorf("failed to create src faker conn: %s", err)
// fallback to continue without redirecting
p.paused = true
return
}
p.srcFakerConn = srcFakerConn
p.sendPkg = p.srcFakerConn.SendPkg
} }
// CloseConn close the localConn // CloseConn close the localConn
@@ -115,6 +156,8 @@ func (p *WGUDPProxy) CloseConn() error {
} }
func (p *WGUDPProxy) close() error { func (p *WGUDPProxy) close() error {
var result *multierror.Error
p.closeMu.Lock() p.closeMu.Lock()
defer p.closeMu.Unlock() defer p.closeMu.Unlock()
@@ -128,7 +171,11 @@ func (p *WGUDPProxy) close() error {
p.cancel() p.cancel()
var result *multierror.Error p.pausedCond.L.Lock()
p.paused = false
p.pausedCond.L.Unlock()
p.pausedCond.Signal()
if err := p.remoteConn.Close(); err != nil && !errors.Is(err, net.ErrClosed) { if err := p.remoteConn.Close(); err != nil && !errors.Is(err, net.ErrClosed) {
result = multierror.Append(result, fmt.Errorf("remote conn: %s", err)) result = multierror.Append(result, fmt.Errorf("remote conn: %s", err))
} }
@@ -136,6 +183,13 @@ func (p *WGUDPProxy) close() error {
if err := p.localConn.Close(); err != nil { if err := p.localConn.Close(); err != nil {
result = multierror.Append(result, fmt.Errorf("local conn: %s", err)) result = multierror.Append(result, fmt.Errorf("local conn: %s", err))
} }
if p.srcFakerConn != nil {
if err := p.srcFakerConn.Close(); err != nil {
result = multierror.Append(result, fmt.Errorf("src faker raw conn: %s", err))
}
}
return cerrors.FormatErrorOrNil(result) return cerrors.FormatErrorOrNil(result)
} }
@@ -194,14 +248,12 @@ func (p *WGUDPProxy) proxyToLocal(ctx context.Context) {
return return
} }
p.pausedMu.Lock() p.pausedCond.L.Lock()
if p.paused { for p.paused {
p.pausedMu.Unlock() p.pausedCond.Wait()
continue
} }
_, err = p.sendPkg(buf[:n])
_, err = p.localConn.Write(buf[:n]) p.pausedCond.L.Unlock()
p.pausedMu.Unlock()
if err != nil { if err != nil {
if ctx.Err() != nil { if ctx.Err() != nil {

View File

@@ -0,0 +1,101 @@
//go:build linux && !android
package udp
import (
"fmt"
"net"
"github.com/google/gopacket"
"github.com/google/gopacket/layers"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/iface/wgproxy/rawsocket"
)
var (
serializeOpts = gopacket.SerializeOptions{
ComputeChecksums: true,
FixLengths: true,
}
localHostNetIPAddr = &net.IPAddr{
IP: net.ParseIP("127.0.0.1"),
}
)
type SrcFaker struct {
srcAddr *net.UDPAddr
rawSocket net.PacketConn
ipH gopacket.SerializableLayer
udpH gopacket.SerializableLayer
layerBuffer gopacket.SerializeBuffer
}
func NewSrcFaker(dstPort int, srcAddr *net.UDPAddr) (*SrcFaker, error) {
rawSocket, err := rawsocket.PrepareSenderRawSocket()
if err != nil {
return nil, err
}
ipH, udpH, err := prepareHeaders(dstPort, srcAddr)
if err != nil {
return nil, err
}
f := &SrcFaker{
srcAddr: srcAddr,
rawSocket: rawSocket,
ipH: ipH,
udpH: udpH,
layerBuffer: gopacket.NewSerializeBuffer(),
}
return f, nil
}
func (f *SrcFaker) Close() error {
return f.rawSocket.Close()
}
func (f *SrcFaker) SendPkg(data []byte) (int, error) {
defer func() {
if err := f.layerBuffer.Clear(); err != nil {
log.Errorf("failed to clear layer buffer: %s", err)
}
}()
payload := gopacket.Payload(data)
err := gopacket.SerializeLayers(f.layerBuffer, serializeOpts, f.ipH, f.udpH, payload)
if err != nil {
return 0, fmt.Errorf("serialize layers: %w", err)
}
n, err := f.rawSocket.WriteTo(f.layerBuffer.Bytes(), localHostNetIPAddr)
if err != nil {
return 0, fmt.Errorf("write to raw conn: %w", err)
}
return n, nil
}
func prepareHeaders(dstPort int, srcAddr *net.UDPAddr) (gopacket.SerializableLayer, gopacket.SerializableLayer, error) {
ipH := &layers.IPv4{
DstIP: net.ParseIP("127.0.0.1"),
SrcIP: srcAddr.IP,
Version: 4,
TTL: 64,
Protocol: layers.IPProtocolUDP,
}
udpH := &layers.UDP{
SrcPort: layers.UDPPort(srcAddr.Port),
DstPort: layers.UDPPort(dstPort), // dst is the localhost WireGuard port
}
err := udpH.SetNetworkLayerForChecksum(ipH)
if err != nil {
return nil, nil, fmt.Errorf("set network layer for checksum: %w", err)
}
return ipH, udpH, nil
}

View File

@@ -949,7 +949,6 @@ func (e *Engine) receiveManagementEvents() {
e.config.LazyConnectionEnabled, e.config.LazyConnectionEnabled,
) )
// err = e.mgmClient.Sync(info, e.handleSync)
err = e.mgmClient.Sync(e.ctx, info, e.handleSync) err = e.mgmClient.Sync(e.ctx, info, e.handleSync)
if err != nil { if err != nil {
// happens if management is unavailable for a long time. // happens if management is unavailable for a long time.
@@ -960,7 +959,7 @@ func (e *Engine) receiveManagementEvents() {
} }
log.Debugf("stopped receiving updates from Management Service") log.Debugf("stopped receiving updates from Management Service")
}() }()
log.Debugf("connecting to Management Service updates stream") log.Infof("connecting to Management Service updates stream")
} }
func (e *Engine) updateSTUNs(stuns []*mgmProto.HostConfig) error { func (e *Engine) updateSTUNs(stuns []*mgmProto.HostConfig) error {

View File

@@ -376,12 +376,19 @@ func (conn *Conn) onICEConnectionIsReady(priority conntype.ConnPriority, iceConn
wgProxy.Work() wgProxy.Work()
} }
conn.Log.Infof("configure WireGuard endpoint to: %s", ep.String())
if err = conn.configureWGEndpoint(ep, iceConnInfo.RosenpassPubKey); err != nil { if err = conn.configureWGEndpoint(ep, iceConnInfo.RosenpassPubKey); err != nil {
conn.handleConfigurationFailure(err, wgProxy) conn.handleConfigurationFailure(err, wgProxy)
return return
} }
wgConfigWorkaround() wgConfigWorkaround()
if conn.wgProxyRelay != nil {
conn.Log.Debugf("redirect packages from relayed conn to WireGuard")
conn.wgProxyRelay.RedirectAs(ep)
}
conn.currentConnPriority = priority conn.currentConnPriority = priority
conn.statusICE.SetConnected() conn.statusICE.SetConnected()
conn.updateIceState(iceConnInfo) conn.updateIceState(iceConnInfo)
@@ -419,6 +426,7 @@ func (conn *Conn) onICEStateDisconnected() {
defer conn.wgWatcherWg.Done() defer conn.wgWatcherWg.Done()
conn.workerRelay.EnableWgWatcher(conn.ctx) conn.workerRelay.EnableWgWatcher(conn.ctx)
}() }()
conn.wgProxyRelay.Work()
conn.currentConnPriority = conntype.Relay conn.currentConnPriority = conntype.Relay
} else { } else {
conn.Log.Infof("ICE disconnected, do not switch to Relay. Reset priority to: %s", conntype.None.String()) conn.Log.Infof("ICE disconnected, do not switch to Relay. Reset priority to: %s", conntype.None.String())

View File

@@ -43,13 +43,6 @@ type OfferAnswer struct {
SessionID *ICESessionID SessionID *ICESessionID
} }
func (oa *OfferAnswer) SessionIDString() string {
if oa.SessionID == nil {
return "unknown"
}
return oa.SessionID.String()
}
type Handshaker struct { type Handshaker struct {
mu sync.Mutex mu sync.Mutex
log *log.Entry log *log.Entry
@@ -57,7 +50,7 @@ type Handshaker struct {
signaler *Signaler signaler *Signaler
ice *WorkerICE ice *WorkerICE
relay *WorkerRelay relay *WorkerRelay
onNewOfferListeners []func(*OfferAnswer) onNewOfferListeners []*OfferListener
// remoteOffersCh is a channel used to wait for remote credentials to proceed with the connection // remoteOffersCh is a channel used to wait for remote credentials to proceed with the connection
remoteOffersCh chan OfferAnswer remoteOffersCh chan OfferAnswer
@@ -78,7 +71,8 @@ func NewHandshaker(log *log.Entry, config ConnConfig, signaler *Signaler, ice *W
} }
func (h *Handshaker) AddOnNewOfferListener(offer func(remoteOfferAnswer *OfferAnswer)) { func (h *Handshaker) AddOnNewOfferListener(offer func(remoteOfferAnswer *OfferAnswer)) {
h.onNewOfferListeners = append(h.onNewOfferListeners, offer) l := NewOfferListener(offer)
h.onNewOfferListeners = append(h.onNewOfferListeners, l)
} }
func (h *Handshaker) Listen(ctx context.Context) { func (h *Handshaker) Listen(ctx context.Context) {
@@ -91,13 +85,13 @@ func (h *Handshaker) Listen(ctx context.Context) {
continue continue
} }
for _, listener := range h.onNewOfferListeners { for _, listener := range h.onNewOfferListeners {
listener(&remoteOfferAnswer) listener.Notify(&remoteOfferAnswer)
} }
h.log.Infof("received offer, running version %s, remote WireGuard listen port %d, session id: %s", remoteOfferAnswer.Version, remoteOfferAnswer.WgListenPort, remoteOfferAnswer.SessionIDString()) h.log.Infof("received offer, running version %s, remote WireGuard listen port %d, session id: %s", remoteOfferAnswer.Version, remoteOfferAnswer.WgListenPort, remoteOfferAnswer.SessionIDString())
case remoteOfferAnswer := <-h.remoteAnswerCh: case remoteOfferAnswer := <-h.remoteAnswerCh:
h.log.Infof("received answer, running version %s, remote WireGuard listen port %d, session id: %s", remoteOfferAnswer.Version, remoteOfferAnswer.WgListenPort, remoteOfferAnswer.SessionIDString()) h.log.Infof("received answer, running version %s, remote WireGuard listen port %d, session id: %s", remoteOfferAnswer.Version, remoteOfferAnswer.WgListenPort, remoteOfferAnswer.SessionIDString())
for _, listener := range h.onNewOfferListeners { for _, listener := range h.onNewOfferListeners {
listener(&remoteOfferAnswer) listener.Notify(&remoteOfferAnswer)
} }
case <-ctx.Done(): case <-ctx.Done():
h.log.Infof("stop listening for remote offers and answers") h.log.Infof("stop listening for remote offers and answers")

View File

@@ -0,0 +1,62 @@
package peer
import (
"sync"
)
type callbackFunc func(remoteOfferAnswer *OfferAnswer)
func (oa *OfferAnswer) SessionIDString() string {
if oa.SessionID == nil {
return "unknown"
}
return oa.SessionID.String()
}
type OfferListener struct {
fn callbackFunc
running bool
latest *OfferAnswer
mu sync.Mutex
}
func NewOfferListener(fn callbackFunc) *OfferListener {
return &OfferListener{
fn: fn,
}
}
func (o *OfferListener) Notify(remoteOfferAnswer *OfferAnswer) {
o.mu.Lock()
defer o.mu.Unlock()
// Store the latest offer
o.latest = remoteOfferAnswer
// If already running, the running goroutine will pick up this latest value
if o.running {
return
}
// Start processing
o.running = true
// Process in a goroutine to avoid blocking the caller
go func(remoteOfferAnswer *OfferAnswer) {
for {
o.fn(remoteOfferAnswer)
o.mu.Lock()
if o.latest == nil {
// No more work to do
o.running = false
o.mu.Unlock()
return
}
remoteOfferAnswer = o.latest
// Clear the latest to mark it as being processed
o.latest = nil
o.mu.Unlock()
}
}(remoteOfferAnswer)
}

View File

@@ -0,0 +1,39 @@
package peer
import (
"testing"
"time"
)
func Test_newOfferListener(t *testing.T) {
dummyOfferAnswer := &OfferAnswer{}
runChan := make(chan struct{}, 10)
longRunningFn := func(remoteOfferAnswer *OfferAnswer) {
time.Sleep(1 * time.Second)
runChan <- struct{}{}
}
hl := NewOfferListener(longRunningFn)
hl.Notify(dummyOfferAnswer)
hl.Notify(dummyOfferAnswer)
hl.Notify(dummyOfferAnswer)
// Wait for exactly 2 callbacks
for i := 0; i < 2; i++ {
select {
case <-runChan:
case <-time.After(3 * time.Second):
t.Fatal("Timeout waiting for callback")
}
}
// Verify no additional callbacks happen
select {
case <-runChan:
t.Fatal("Unexpected additional callback")
case <-time.After(100 * time.Millisecond):
t.Log("Correctly received exactly 2 callbacks")
}
}

View File

@@ -101,6 +101,10 @@ func (w *WGWatcher) periodicHandshakeCheck(ctx context.Context, ctxCancel contex
onDisconnectedFn() onDisconnectedFn()
return return
} }
if lastHandshake.IsZero() {
w.log.Infof("first wg handshake detected at: %s", handshake)
}
lastHandshake = *handshake lastHandshake = *handshake
resetTime := time.Until(handshake.Add(checkPeriod)) resetTime := time.Until(handshake.Add(checkPeriod))

View File

@@ -122,7 +122,6 @@ func (w *WorkerICE) OnNewOffer(remoteOfferAnswer *OfferAnswer) {
w.log.Warnf("failed to close ICE agent: %s", err) w.log.Warnf("failed to close ICE agent: %s", err)
} }
w.agent = nil w.agent = nil
// todo consider to switch to Relay connection while establishing a new ICE connection
} }
var preferredCandidateTypes []ice.CandidateType var preferredCandidateTypes []ice.CandidateType
@@ -410,7 +409,10 @@ func (w *WorkerICE) onConnectionStateChange(agent *icemaker.ThreadSafeAgent, dia
case ice.ConnectionStateConnected: case ice.ConnectionStateConnected:
w.lastKnownState = ice.ConnectionStateConnected w.lastKnownState = ice.ConnectionStateConnected
return return
case ice.ConnectionStateFailed, ice.ConnectionStateDisconnected: case ice.ConnectionStateFailed, ice.ConnectionStateDisconnected, ice.ConnectionStateClosed:
// ice.ConnectionStateClosed happens when we recreate the agent. For the P2P to TURN switch important to
// notify the conn.onICEStateDisconnected changes to update the current used priority
if w.lastKnownState == ice.ConnectionStateConnected { if w.lastKnownState == ice.ConnectionStateConnected {
w.lastKnownState = ice.ConnectionStateDisconnected w.lastKnownState = ice.ConnectionStateDisconnected
w.conn.onICEStateDisconnected() w.conn.onICEStateDisconnected()

View File

@@ -6,6 +6,7 @@ import (
"net/netip" "net/netip"
"strings" "strings"
log "github.com/sirupsen/logrus"
"google.golang.org/grpc/metadata" "google.golang.org/grpc/metadata"
"github.com/netbirdio/netbird/shared/management/proto" "github.com/netbirdio/netbird/shared/management/proto"
@@ -180,6 +181,7 @@ func isDuplicated(addresses []NetworkAddress, addr NetworkAddress) bool {
// GetInfoWithChecks retrieves and parses the system information with applied checks. // GetInfoWithChecks retrieves and parses the system information with applied checks.
func GetInfoWithChecks(ctx context.Context, checks []*proto.Checks) (*Info, error) { func GetInfoWithChecks(ctx context.Context, checks []*proto.Checks) (*Info, error) {
log.Debugf("gathering system information with checks: %d", len(checks))
processCheckPaths := make([]string, 0) processCheckPaths := make([]string, 0)
for _, check := range checks { for _, check := range checks {
processCheckPaths = append(processCheckPaths, check.GetFiles()...) processCheckPaths = append(processCheckPaths, check.GetFiles()...)
@@ -189,10 +191,12 @@ func GetInfoWithChecks(ctx context.Context, checks []*proto.Checks) (*Info, erro
if err != nil { if err != nil {
return nil, err return nil, err
} }
log.Debugf("gathering process check information completed")
info := GetInfo(ctx) info := GetInfo(ctx)
info.Files = files info.Files = files
log.Debugf("all system information gathered successfully")
return info, nil return info, nil
} }