diff --git a/client/iface/bind/dual_stack_conn.go b/client/iface/bind/dual_stack_conn.go new file mode 100644 index 000000000..061016ecc --- /dev/null +++ b/client/iface/bind/dual_stack_conn.go @@ -0,0 +1,169 @@ +package bind + +import ( + "errors" + "net" + "sync" + "time" + + "github.com/hashicorp/go-multierror" + log "github.com/sirupsen/logrus" + + nberrors "github.com/netbirdio/netbird/client/errors" +) + +var ( + errNoIPv4Conn = errors.New("no IPv4 connection available") + errNoIPv6Conn = errors.New("no IPv6 connection available") + errInvalidAddr = errors.New("invalid address type") +) + +// DualStackPacketConn wraps IPv4 and IPv6 UDP connections and routes writes +// to the appropriate connection based on the destination address. +// ReadFrom is not used in the hot path - ICEBind receives packets via +// BatchReader.ReadBatch() directly. This is only used by udpMux for sending. +type DualStackPacketConn struct { + ipv4Conn net.PacketConn + ipv6Conn net.PacketConn + + readFromWarn sync.Once +} + +// NewDualStackPacketConn creates a new dual-stack packet connection. +func NewDualStackPacketConn(ipv4Conn, ipv6Conn net.PacketConn) *DualStackPacketConn { + return &DualStackPacketConn{ + ipv4Conn: ipv4Conn, + ipv6Conn: ipv6Conn, + } +} + +// ReadFrom reads from the available connection (preferring IPv4). +// NOTE: This method is NOT used in the data path. ICEBind receives packets via +// BatchReader.ReadBatch() directly for both IPv4 and IPv6, which is much more efficient. +// This implementation exists only to satisfy the net.PacketConn interface for the udpMux, +// but the udpMux only uses WriteTo() for sending STUN responses - it never calls ReadFrom() +// because STUN packets are filtered and forwarded via HandleSTUNMessage() from the receive path. +func (d *DualStackPacketConn) ReadFrom(b []byte) (n int, addr net.Addr, err error) { + d.readFromWarn.Do(func() { + log.Warn("DualStackPacketConn.ReadFrom called - this is unexpected and may indicate an inefficient code path") + }) + + if d.ipv4Conn != nil { + return d.ipv4Conn.ReadFrom(b) + } + if d.ipv6Conn != nil { + return d.ipv6Conn.ReadFrom(b) + } + return 0, nil, net.ErrClosed +} + +// WriteTo writes to the appropriate connection based on the address type. +func (d *DualStackPacketConn) WriteTo(b []byte, addr net.Addr) (n int, err error) { + udpAddr, ok := addr.(*net.UDPAddr) + if !ok { + return 0, &net.OpError{ + Op: "write", + Net: "udp", + Addr: addr, + Err: errInvalidAddr, + } + } + + if udpAddr.IP.To4() == nil { + if d.ipv6Conn != nil { + return d.ipv6Conn.WriteTo(b, addr) + } + return 0, &net.OpError{ + Op: "write", + Net: "udp6", + Addr: addr, + Err: errNoIPv6Conn, + } + } + + if d.ipv4Conn != nil { + return d.ipv4Conn.WriteTo(b, addr) + } + return 0, &net.OpError{ + Op: "write", + Net: "udp4", + Addr: addr, + Err: errNoIPv4Conn, + } +} + +// Close closes both connections. +func (d *DualStackPacketConn) Close() error { + var result *multierror.Error + if d.ipv4Conn != nil { + if err := d.ipv4Conn.Close(); err != nil { + result = multierror.Append(result, err) + } + } + if d.ipv6Conn != nil { + if err := d.ipv6Conn.Close(); err != nil { + result = multierror.Append(result, err) + } + } + return nberrors.FormatErrorOrNil(result) +} + +// LocalAddr returns the local address of the IPv4 connection if available, +// otherwise the IPv6 connection. +func (d *DualStackPacketConn) LocalAddr() net.Addr { + if d.ipv4Conn != nil { + return d.ipv4Conn.LocalAddr() + } + if d.ipv6Conn != nil { + return d.ipv6Conn.LocalAddr() + } + return nil +} + +// SetDeadline sets the deadline for both connections. +func (d *DualStackPacketConn) SetDeadline(t time.Time) error { + var result *multierror.Error + if d.ipv4Conn != nil { + if err := d.ipv4Conn.SetDeadline(t); err != nil { + result = multierror.Append(result, err) + } + } + if d.ipv6Conn != nil { + if err := d.ipv6Conn.SetDeadline(t); err != nil { + result = multierror.Append(result, err) + } + } + return nberrors.FormatErrorOrNil(result) +} + +// SetReadDeadline sets the read deadline for both connections. +func (d *DualStackPacketConn) SetReadDeadline(t time.Time) error { + var result *multierror.Error + if d.ipv4Conn != nil { + if err := d.ipv4Conn.SetReadDeadline(t); err != nil { + result = multierror.Append(result, err) + } + } + if d.ipv6Conn != nil { + if err := d.ipv6Conn.SetReadDeadline(t); err != nil { + result = multierror.Append(result, err) + } + } + return nberrors.FormatErrorOrNil(result) +} + +// SetWriteDeadline sets the write deadline for both connections. +func (d *DualStackPacketConn) SetWriteDeadline(t time.Time) error { + var result *multierror.Error + if d.ipv4Conn != nil { + if err := d.ipv4Conn.SetWriteDeadline(t); err != nil { + result = multierror.Append(result, err) + } + } + if d.ipv6Conn != nil { + if err := d.ipv6Conn.SetWriteDeadline(t); err != nil { + result = multierror.Append(result, err) + } + } + return nberrors.FormatErrorOrNil(result) +} diff --git a/client/iface/bind/dual_stack_conn_bench_test.go b/client/iface/bind/dual_stack_conn_bench_test.go new file mode 100644 index 000000000..940c44966 --- /dev/null +++ b/client/iface/bind/dual_stack_conn_bench_test.go @@ -0,0 +1,119 @@ +package bind + +import ( + "net" + "testing" +) + +var ( + ipv4Addr = &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 12345} + ipv6Addr = &net.UDPAddr{IP: net.ParseIP("::1"), Port: 12345} + payload = make([]byte, 1200) +) + +func BenchmarkWriteTo_DirectUDPConn(b *testing.B) { + conn, err := net.ListenUDP("udp4", &net.UDPAddr{IP: net.IPv4zero, Port: 0}) + if err != nil { + b.Fatal(err) + } + defer conn.Close() + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, _ = conn.WriteTo(payload, ipv4Addr) + } +} + +func BenchmarkWriteTo_DualStack_IPv4Only(b *testing.B) { + conn, err := net.ListenUDP("udp4", &net.UDPAddr{IP: net.IPv4zero, Port: 0}) + if err != nil { + b.Fatal(err) + } + defer conn.Close() + + ds := NewDualStackPacketConn(conn, nil) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, _ = ds.WriteTo(payload, ipv4Addr) + } +} + +func BenchmarkWriteTo_DualStack_IPv6Only(b *testing.B) { + conn, err := net.ListenUDP("udp6", &net.UDPAddr{IP: net.IPv6zero, Port: 0}) + if err != nil { + b.Skipf("IPv6 not available: %v", err) + } + defer conn.Close() + + ds := NewDualStackPacketConn(nil, conn) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, _ = ds.WriteTo(payload, ipv6Addr) + } +} + +func BenchmarkWriteTo_DualStack_Both_IPv4Traffic(b *testing.B) { + conn4, err := net.ListenUDP("udp4", &net.UDPAddr{IP: net.IPv4zero, Port: 0}) + if err != nil { + b.Fatal(err) + } + defer conn4.Close() + + conn6, err := net.ListenUDP("udp6", &net.UDPAddr{IP: net.IPv6zero, Port: 0}) + if err != nil { + b.Skipf("IPv6 not available: %v", err) + } + defer conn6.Close() + + ds := NewDualStackPacketConn(conn4, conn6) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, _ = ds.WriteTo(payload, ipv4Addr) + } +} + +func BenchmarkWriteTo_DualStack_Both_IPv6Traffic(b *testing.B) { + conn4, err := net.ListenUDP("udp4", &net.UDPAddr{IP: net.IPv4zero, Port: 0}) + if err != nil { + b.Fatal(err) + } + defer conn4.Close() + + conn6, err := net.ListenUDP("udp6", &net.UDPAddr{IP: net.IPv6zero, Port: 0}) + if err != nil { + b.Skipf("IPv6 not available: %v", err) + } + defer conn6.Close() + + ds := NewDualStackPacketConn(conn4, conn6) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, _ = ds.WriteTo(payload, ipv6Addr) + } +} + +func BenchmarkWriteTo_DualStack_Both_MixedTraffic(b *testing.B) { + conn4, err := net.ListenUDP("udp4", &net.UDPAddr{IP: net.IPv4zero, Port: 0}) + if err != nil { + b.Fatal(err) + } + defer conn4.Close() + + conn6, err := net.ListenUDP("udp6", &net.UDPAddr{IP: net.IPv6zero, Port: 0}) + if err != nil { + b.Skipf("IPv6 not available: %v", err) + } + defer conn6.Close() + + ds := NewDualStackPacketConn(conn4, conn6) + addrs := []net.Addr{ipv4Addr, ipv6Addr} + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, _ = ds.WriteTo(payload, addrs[i&1]) + } +} diff --git a/client/iface/bind/dual_stack_conn_test.go b/client/iface/bind/dual_stack_conn_test.go new file mode 100644 index 000000000..3007d907f --- /dev/null +++ b/client/iface/bind/dual_stack_conn_test.go @@ -0,0 +1,191 @@ +package bind + +import ( + "net" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestDualStackPacketConn_RoutesWritesToCorrectSocket(t *testing.T) { + ipv4Conn := &mockPacketConn{network: "udp4"} + ipv6Conn := &mockPacketConn{network: "udp6"} + dualStack := NewDualStackPacketConn(ipv4Conn, ipv6Conn) + + tests := []struct { + name string + addr *net.UDPAddr + wantSocket string + }{ + { + name: "IPv4 address", + addr: &net.UDPAddr{IP: net.ParseIP("192.168.1.1"), Port: 1234}, + wantSocket: "udp4", + }, + { + name: "IPv6 address", + addr: &net.UDPAddr{IP: net.ParseIP("2001:db8::1"), Port: 1234}, + wantSocket: "udp6", + }, + { + name: "IPv4-mapped IPv6 goes to IPv4", + addr: &net.UDPAddr{IP: net.ParseIP("::ffff:192.168.1.1"), Port: 1234}, + wantSocket: "udp4", + }, + { + name: "IPv4 loopback", + addr: &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 1234}, + wantSocket: "udp4", + }, + { + name: "IPv6 loopback", + addr: &net.UDPAddr{IP: net.ParseIP("::1"), Port: 1234}, + wantSocket: "udp6", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ipv4Conn.writeCount = 0 + ipv6Conn.writeCount = 0 + + n, err := dualStack.WriteTo([]byte("test"), tt.addr) + require.NoError(t, err) + assert.Equal(t, 4, n) + + if tt.wantSocket == "udp4" { + assert.Equal(t, 1, ipv4Conn.writeCount, "expected write to IPv4") + assert.Equal(t, 0, ipv6Conn.writeCount, "expected no write to IPv6") + } else { + assert.Equal(t, 0, ipv4Conn.writeCount, "expected no write to IPv4") + assert.Equal(t, 1, ipv6Conn.writeCount, "expected write to IPv6") + } + }) + } +} + +func TestDualStackPacketConn_IPv4OnlyRejectsIPv6(t *testing.T) { + dualStack := NewDualStackPacketConn(&mockPacketConn{network: "udp4"}, nil) + + // IPv4 works + _, err := dualStack.WriteTo([]byte("test"), &net.UDPAddr{IP: net.ParseIP("192.168.1.1"), Port: 1234}) + require.NoError(t, err) + + // IPv6 fails + _, err = dualStack.WriteTo([]byte("test"), &net.UDPAddr{IP: net.ParseIP("2001:db8::1"), Port: 1234}) + require.Error(t, err) + assert.Contains(t, err.Error(), "no IPv6 connection") +} + +func TestDualStackPacketConn_IPv6OnlyRejectsIPv4(t *testing.T) { + dualStack := NewDualStackPacketConn(nil, &mockPacketConn{network: "udp6"}) + + // IPv6 works + _, err := dualStack.WriteTo([]byte("test"), &net.UDPAddr{IP: net.ParseIP("2001:db8::1"), Port: 1234}) + require.NoError(t, err) + + // IPv4 fails + _, err = dualStack.WriteTo([]byte("test"), &net.UDPAddr{IP: net.ParseIP("192.168.1.1"), Port: 1234}) + require.Error(t, err) + assert.Contains(t, err.Error(), "no IPv4 connection") +} + +// TestDualStackPacketConn_ReadFromIsNotUsedInHotPath documents that ReadFrom +// only reads from one socket (IPv4 preferred). This is fine because the actual +// receive path uses wireguard-go's BatchReader directly, not ReadFrom. +func TestDualStackPacketConn_ReadFromIsNotUsedInHotPath(t *testing.T) { + ipv4Conn := &mockPacketConn{ + network: "udp4", + readData: []byte("from ipv4"), + readAddr: &net.UDPAddr{IP: net.ParseIP("192.168.1.1"), Port: 1234}, + } + ipv6Conn := &mockPacketConn{ + network: "udp6", + readData: []byte("from ipv6"), + readAddr: &net.UDPAddr{IP: net.ParseIP("2001:db8::1"), Port: 1234}, + } + + dualStack := NewDualStackPacketConn(ipv4Conn, ipv6Conn) + + buf := make([]byte, 100) + n, addr, err := dualStack.ReadFrom(buf) + + require.NoError(t, err) + // reads from IPv4 (preferred) - this is expected behavior + assert.Equal(t, "from ipv4", string(buf[:n])) + assert.Equal(t, "192.168.1.1", addr.(*net.UDPAddr).IP.String()) +} + +func TestDualStackPacketConn_LocalAddrPrefersIPv4(t *testing.T) { + ipv4Addr := &net.UDPAddr{IP: net.ParseIP("0.0.0.0"), Port: 51820} + ipv6Addr := &net.UDPAddr{IP: net.ParseIP("::"), Port: 51820} + + tests := []struct { + name string + ipv4 net.PacketConn + ipv6 net.PacketConn + wantAddr net.Addr + }{ + { + name: "both available returns IPv4", + ipv4: &mockPacketConn{localAddr: ipv4Addr}, + ipv6: &mockPacketConn{localAddr: ipv6Addr}, + wantAddr: ipv4Addr, + }, + { + name: "IPv4 only", + ipv4: &mockPacketConn{localAddr: ipv4Addr}, + ipv6: nil, + wantAddr: ipv4Addr, + }, + { + name: "IPv6 only", + ipv4: nil, + ipv6: &mockPacketConn{localAddr: ipv6Addr}, + wantAddr: ipv6Addr, + }, + { + name: "neither returns nil", + ipv4: nil, + ipv6: nil, + wantAddr: nil, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + dualStack := NewDualStackPacketConn(tt.ipv4, tt.ipv6) + assert.Equal(t, tt.wantAddr, dualStack.LocalAddr()) + }) + } +} + +// mock + +type mockPacketConn struct { + network string + writeCount int + readData []byte + readAddr net.Addr + localAddr net.Addr +} + +func (m *mockPacketConn) ReadFrom(b []byte) (n int, addr net.Addr, err error) { + if m.readData != nil { + return copy(b, m.readData), m.readAddr, nil + } + return 0, nil, nil +} + +func (m *mockPacketConn) WriteTo(b []byte, addr net.Addr) (n int, err error) { + m.writeCount++ + return len(b), nil +} + +func (m *mockPacketConn) Close() error { return nil } +func (m *mockPacketConn) LocalAddr() net.Addr { return m.localAddr } +func (m *mockPacketConn) SetDeadline(t time.Time) error { return nil } +func (m *mockPacketConn) SetReadDeadline(t time.Time) error { return nil } +func (m *mockPacketConn) SetWriteDeadline(t time.Time) error { return nil } diff --git a/client/iface/bind/ice_bind.go b/client/iface/bind/ice_bind.go index 0957d2dd5..bf79ecd79 100644 --- a/client/iface/bind/ice_bind.go +++ b/client/iface/bind/ice_bind.go @@ -14,7 +14,6 @@ import ( "github.com/pion/stun/v3" "github.com/pion/transport/v3" log "github.com/sirupsen/logrus" - "golang.org/x/net/ipv4" "golang.org/x/net/ipv6" wgConn "golang.zx2c4.com/wireguard/conn" @@ -28,22 +27,7 @@ type receiverCreator struct { } func (rc receiverCreator) CreateReceiverFn(pc wgConn.BatchReader, conn *net.UDPConn, rxOffload bool, msgPool *sync.Pool) wgConn.ReceiveFunc { - if ipv4PC, ok := pc.(*ipv4.PacketConn); ok { - return rc.iceBind.createIPv4ReceiverFn(ipv4PC, conn, rxOffload, msgPool) - } - // IPv6 is currently not supported in the udpmux, this is a stub for compatibility with the - // wireguard-go ReceiverCreator interface which is called for both IPv4 and IPv6. - return func(bufs [][]byte, sizes []int, eps []wgConn.Endpoint) (n int, err error) { - buf := bufs[0] - size, ep, err := conn.ReadFromUDPAddrPort(buf) - if err != nil { - return 0, err - } - sizes[0] = size - stdEp := &wgConn.StdNetEndpoint{AddrPort: ep} - eps[0] = stdEp - return 1, nil - } + return rc.iceBind.createReceiverFn(pc, conn, rxOffload, msgPool) } // ICEBind is a bind implementation with two main features: @@ -73,6 +57,8 @@ type ICEBind struct { muUDPMux sync.Mutex udpMux *udpmux.UniversalUDPMuxDefault + ipv4Conn *net.UDPConn + ipv6Conn *net.UDPConn } func NewICEBind(transportNet transport.Net, filterFn udpmux.FilterFn, address wgaddr.Address, mtu uint16) *ICEBind { @@ -118,6 +104,12 @@ func (s *ICEBind) Close() error { close(s.closedChan) + s.muUDPMux.Lock() + s.ipv4Conn = nil + s.ipv6Conn = nil + s.udpMux = nil + s.muUDPMux.Unlock() + return s.StdNetBind.Close() } @@ -175,19 +167,18 @@ func (b *ICEBind) Send(bufs [][]byte, ep wgConn.Endpoint) error { return nil } -func (s *ICEBind) createIPv4ReceiverFn(pc *ipv4.PacketConn, conn *net.UDPConn, rxOffload bool, msgsPool *sync.Pool) wgConn.ReceiveFunc { +func (s *ICEBind) createReceiverFn(pc wgConn.BatchReader, conn *net.UDPConn, rxOffload bool, msgsPool *sync.Pool) wgConn.ReceiveFunc { s.muUDPMux.Lock() defer s.muUDPMux.Unlock() - s.udpMux = udpmux.NewUniversalUDPMuxDefault( - udpmux.UniversalUDPMuxParams{ - UDPConn: nbnet.WrapPacketConn(conn), - Net: s.transportNet, - FilterFn: s.filterFn, - WGAddress: s.address, - MTU: s.mtu, - }, - ) + // Detect IPv4 vs IPv6 from connection's local address + if localAddr := conn.LocalAddr().(*net.UDPAddr); localAddr.IP.To4() != nil { + s.ipv4Conn = conn + } else { + s.ipv6Conn = conn + } + s.createOrUpdateMux() + return func(bufs [][]byte, sizes []int, eps []wgConn.Endpoint) (n int, err error) { msgs := getMessages(msgsPool) for i := range bufs { @@ -195,12 +186,13 @@ func (s *ICEBind) createIPv4ReceiverFn(pc *ipv4.PacketConn, conn *net.UDPConn, r (*msgs)[i].OOB = (*msgs)[i].OOB[:cap((*msgs)[i].OOB)] } defer putMessages(msgs, msgsPool) + var numMsgs int if runtime.GOOS == "linux" || runtime.GOOS == "android" { if rxOffload { readAt := len(*msgs) - (wgConn.IdealBatchSize / wgConn.UdpSegmentMaxDatagrams) - //nolint - numMsgs, err = pc.ReadBatch((*msgs)[readAt:], 0) + //nolint:staticcheck + _, err = pc.ReadBatch((*msgs)[readAt:], 0) if err != nil { return 0, err } @@ -222,12 +214,12 @@ func (s *ICEBind) createIPv4ReceiverFn(pc *ipv4.PacketConn, conn *net.UDPConn, r } numMsgs = 1 } + for i := 0; i < numMsgs; i++ { msg := &(*msgs)[i] // todo: handle err - ok, _ := s.filterOutStunMessages(msg.Buffers, msg.N, msg.Addr) - if ok { + if ok, _ := s.filterOutStunMessages(msg.Buffers, msg.N, msg.Addr); ok { continue } sizes[i] = msg.N @@ -248,6 +240,38 @@ func (s *ICEBind) createIPv4ReceiverFn(pc *ipv4.PacketConn, conn *net.UDPConn, r } } +// createOrUpdateMux creates or updates the UDP mux with the available connections. +// Must be called with muUDPMux held. +func (s *ICEBind) createOrUpdateMux() { + var muxConn net.PacketConn + + switch { + case s.ipv4Conn != nil && s.ipv6Conn != nil: + muxConn = NewDualStackPacketConn( + nbnet.WrapPacketConn(s.ipv4Conn), + nbnet.WrapPacketConn(s.ipv6Conn), + ) + case s.ipv4Conn != nil: + muxConn = nbnet.WrapPacketConn(s.ipv4Conn) + case s.ipv6Conn != nil: + muxConn = nbnet.WrapPacketConn(s.ipv6Conn) + default: + return + } + + // Don't close the old mux - it doesn't own the underlying connections. + // The sockets are managed by WireGuard's StdNetBind, not by us. + s.udpMux = udpmux.NewUniversalUDPMuxDefault( + udpmux.UniversalUDPMuxParams{ + UDPConn: muxConn, + Net: s.transportNet, + FilterFn: s.filterFn, + WGAddress: s.address, + MTU: s.mtu, + }, + ) +} + func (s *ICEBind) filterOutStunMessages(buffers [][]byte, n int, addr net.Addr) (bool, error) { for i := range buffers { if !stun.IsMessage(buffers[i]) { @@ -260,9 +284,14 @@ func (s *ICEBind) filterOutStunMessages(buffers [][]byte, n int, addr net.Addr) return true, err } - muxErr := s.udpMux.HandleSTUNMessage(msg, addr) - if muxErr != nil { - log.Warnf("failed to handle STUN packet") + s.muUDPMux.Lock() + mux := s.udpMux + s.muUDPMux.Unlock() + + if mux != nil { + if muxErr := mux.HandleSTUNMessage(msg, addr); muxErr != nil { + log.Warnf("failed to handle STUN packet: %v", muxErr) + } } buffers[i] = []byte{} diff --git a/client/iface/bind/ice_bind_test.go b/client/iface/bind/ice_bind_test.go new file mode 100644 index 000000000..1fdd955c9 --- /dev/null +++ b/client/iface/bind/ice_bind_test.go @@ -0,0 +1,324 @@ +package bind + +import ( + "fmt" + "net" + "net/netip" + "sync" + "testing" + "time" + + "github.com/pion/transport/v3/stdnet" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "golang.org/x/net/ipv4" + "golang.org/x/net/ipv6" + + "github.com/netbirdio/netbird/client/iface/wgaddr" +) + +func TestICEBind_CreatesReceiverForBothIPv4AndIPv6(t *testing.T) { + iceBind := setupICEBind(t) + + ipv4Conn, ipv6Conn := createDualStackConns(t) + defer ipv4Conn.Close() + defer ipv6Conn.Close() + + rc := receiverCreator{iceBind} + pool := createMsgPool() + + // Simulate wireguard-go calling CreateReceiverFn for IPv4 + ipv4RecvFn := rc.CreateReceiverFn(ipv4.NewPacketConn(ipv4Conn), ipv4Conn, false, pool) + require.NotNil(t, ipv4RecvFn) + + iceBind.muUDPMux.Lock() + assert.NotNil(t, iceBind.ipv4Conn, "should store IPv4 connection") + assert.Nil(t, iceBind.ipv6Conn, "IPv6 not added yet") + assert.NotNil(t, iceBind.udpMux, "mux should be created after first connection") + iceBind.muUDPMux.Unlock() + + // Simulate wireguard-go calling CreateReceiverFn for IPv6 + ipv6RecvFn := rc.CreateReceiverFn(ipv6.NewPacketConn(ipv6Conn), ipv6Conn, false, pool) + require.NotNil(t, ipv6RecvFn) + + iceBind.muUDPMux.Lock() + assert.NotNil(t, iceBind.ipv4Conn, "should still have IPv4 connection") + assert.NotNil(t, iceBind.ipv6Conn, "should now have IPv6 connection") + assert.NotNil(t, iceBind.udpMux, "mux should still exist") + iceBind.muUDPMux.Unlock() + + mux, err := iceBind.GetICEMux() + require.NoError(t, err) + require.NotNil(t, mux) +} + +func TestICEBind_WorksWithIPv4Only(t *testing.T) { + iceBind := setupICEBind(t) + + ipv4Conn, err := net.ListenUDP("udp4", &net.UDPAddr{IP: net.IPv4zero, Port: 0}) + require.NoError(t, err) + defer ipv4Conn.Close() + + rc := receiverCreator{iceBind} + recvFn := rc.CreateReceiverFn(ipv4.NewPacketConn(ipv4Conn), ipv4Conn, false, createMsgPool()) + require.NotNil(t, recvFn) + + iceBind.muUDPMux.Lock() + assert.NotNil(t, iceBind.ipv4Conn) + assert.Nil(t, iceBind.ipv6Conn) + assert.NotNil(t, iceBind.udpMux) + iceBind.muUDPMux.Unlock() + + mux, err := iceBind.GetICEMux() + require.NoError(t, err) + require.NotNil(t, mux) +} + +func TestICEBind_WorksWithIPv6Only(t *testing.T) { + iceBind := setupICEBind(t) + + ipv6Conn, err := net.ListenUDP("udp6", &net.UDPAddr{IP: net.IPv6zero, Port: 0}) + if err != nil { + t.Skipf("IPv6 not available: %v", err) + } + defer ipv6Conn.Close() + + rc := receiverCreator{iceBind} + recvFn := rc.CreateReceiverFn(ipv6.NewPacketConn(ipv6Conn), ipv6Conn, false, createMsgPool()) + require.NotNil(t, recvFn) + + iceBind.muUDPMux.Lock() + assert.Nil(t, iceBind.ipv4Conn) + assert.NotNil(t, iceBind.ipv6Conn) + assert.NotNil(t, iceBind.udpMux) + iceBind.muUDPMux.Unlock() + + mux, err := iceBind.GetICEMux() + require.NoError(t, err) + require.NotNil(t, mux) +} + +// TestICEBind_SendsToIPv4AndIPv6PeersSimultaneously verifies that we can communicate +// with peers on different address families through the same DualStackPacketConn. +func TestICEBind_SendsToIPv4AndIPv6PeersSimultaneously(t *testing.T) { + // two "remote peers" listening on different address families + ipv4Peer := listenUDP(t, "udp4", "127.0.0.1:0") + defer ipv4Peer.Close() + + ipv6Peer, err := net.ListenUDP("udp6", &net.UDPAddr{IP: net.IPv6loopback, Port: 0}) + if err != nil { + t.Skipf("IPv6 not available: %v", err) + } + defer ipv6Peer.Close() + + // our local dual-stack connection + ipv4Local := listenUDP(t, "udp4", "127.0.0.1:0") + defer ipv4Local.Close() + + ipv6Local := listenUDP(t, "udp6", "[::1]:0") + defer ipv6Local.Close() + + dualStack := NewDualStackPacketConn(ipv4Local, ipv6Local) + + // send to both peers + _, err = dualStack.WriteTo([]byte("to-ipv4"), ipv4Peer.LocalAddr()) + require.NoError(t, err) + + _, err = dualStack.WriteTo([]byte("to-ipv6"), ipv6Peer.LocalAddr()) + require.NoError(t, err) + + // verify IPv4 peer got its packet from the IPv4 socket + buf := make([]byte, 100) + _ = ipv4Peer.SetReadDeadline(time.Now().Add(time.Second)) + n, addr, err := ipv4Peer.ReadFrom(buf) + require.NoError(t, err) + assert.Equal(t, "to-ipv4", string(buf[:n])) + assert.Equal(t, ipv4Local.LocalAddr().(*net.UDPAddr).Port, addr.(*net.UDPAddr).Port) + + // verify IPv6 peer got its packet from the IPv6 socket + _ = ipv6Peer.SetReadDeadline(time.Now().Add(time.Second)) + n, addr, err = ipv6Peer.ReadFrom(buf) + require.NoError(t, err) + assert.Equal(t, "to-ipv6", string(buf[:n])) + assert.Equal(t, ipv6Local.LocalAddr().(*net.UDPAddr).Port, addr.(*net.UDPAddr).Port) +} + +// TestICEBind_HandlesConcurrentMixedTraffic sends packets concurrently to both IPv4 +// and IPv6 peers. Verifies no packets get misrouted (IPv4 peer only gets v4- packets, +// IPv6 peer only gets v6- packets). Some packet loss is acceptable for UDP. +func TestICEBind_HandlesConcurrentMixedTraffic(t *testing.T) { + ipv4Peer := listenUDP(t, "udp4", "127.0.0.1:0") + defer ipv4Peer.Close() + + ipv6Peer, err := net.ListenUDP("udp6", &net.UDPAddr{IP: net.IPv6loopback, Port: 0}) + if err != nil { + t.Skipf("IPv6 not available: %v", err) + } + defer ipv6Peer.Close() + + ipv4Local := listenUDP(t, "udp4", "127.0.0.1:0") + defer ipv4Local.Close() + + ipv6Local := listenUDP(t, "udp6", "[::1]:0") + defer ipv6Local.Close() + + dualStack := NewDualStackPacketConn(ipv4Local, ipv6Local) + + const packetsPerFamily = 500 + + ipv4Received := make(chan string, packetsPerFamily) + ipv6Received := make(chan string, packetsPerFamily) + + startGate := make(chan struct{}) + var wg sync.WaitGroup + + wg.Add(1) + go func() { + defer wg.Done() + buf := make([]byte, 100) + for i := 0; i < packetsPerFamily; i++ { + n, _, err := ipv4Peer.ReadFrom(buf) + if err != nil { + return + } + ipv4Received <- string(buf[:n]) + } + }() + + wg.Add(1) + go func() { + defer wg.Done() + buf := make([]byte, 100) + for i := 0; i < packetsPerFamily; i++ { + n, _, err := ipv6Peer.ReadFrom(buf) + if err != nil { + return + } + ipv6Received <- string(buf[:n]) + } + }() + + wg.Add(1) + go func() { + defer wg.Done() + <-startGate + for i := 0; i < packetsPerFamily; i++ { + _, _ = dualStack.WriteTo([]byte(fmt.Sprintf("v4-%04d", i)), ipv4Peer.LocalAddr()) + } + }() + + wg.Add(1) + go func() { + defer wg.Done() + <-startGate + for i := 0; i < packetsPerFamily; i++ { + _, _ = dualStack.WriteTo([]byte(fmt.Sprintf("v6-%04d", i)), ipv6Peer.LocalAddr()) + } + }() + + close(startGate) + + time.AfterFunc(5*time.Second, func() { + _ = ipv4Peer.SetReadDeadline(time.Now()) + _ = ipv6Peer.SetReadDeadline(time.Now()) + }) + + wg.Wait() + close(ipv4Received) + close(ipv6Received) + + ipv4Count := 0 + for pkt := range ipv4Received { + require.True(t, len(pkt) >= 3 && pkt[:3] == "v4-", "IPv4 peer got misrouted packet: %s", pkt) + ipv4Count++ + } + + ipv6Count := 0 + for pkt := range ipv6Received { + require.True(t, len(pkt) >= 3 && pkt[:3] == "v6-", "IPv6 peer got misrouted packet: %s", pkt) + ipv6Count++ + } + + assert.Equal(t, packetsPerFamily, ipv4Count) + assert.Equal(t, packetsPerFamily, ipv6Count) +} + +func TestICEBind_DetectsAddressFamilyFromConnection(t *testing.T) { + tests := []struct { + name string + network string + addr string + wantIPv4 bool + }{ + {"IPv4 any", "udp4", "0.0.0.0:0", true}, + {"IPv4 loopback", "udp4", "127.0.0.1:0", true}, + {"IPv6 any", "udp6", "[::]:0", false}, + {"IPv6 loopback", "udp6", "[::1]:0", false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + addr, err := net.ResolveUDPAddr(tt.network, tt.addr) + require.NoError(t, err) + + conn, err := net.ListenUDP(tt.network, addr) + if err != nil { + t.Skipf("%s not available: %v", tt.network, err) + } + defer conn.Close() + + localAddr := conn.LocalAddr().(*net.UDPAddr) + isIPv4 := localAddr.IP.To4() != nil + assert.Equal(t, tt.wantIPv4, isIPv4) + }) + } +} + +// helpers + +func setupICEBind(t *testing.T) *ICEBind { + t.Helper() + transportNet, err := stdnet.NewNet() + require.NoError(t, err) + + address := wgaddr.Address{ + IP: netip.MustParseAddr("100.64.0.1"), + Network: netip.MustParsePrefix("100.64.0.0/10"), + } + return NewICEBind(transportNet, nil, address, 1280) +} + +func createDualStackConns(t *testing.T) (*net.UDPConn, *net.UDPConn) { + t.Helper() + ipv4Conn, err := net.ListenUDP("udp4", &net.UDPAddr{IP: net.IPv4zero, Port: 0}) + require.NoError(t, err) + + ipv6Conn, err := net.ListenUDP("udp6", &net.UDPAddr{IP: net.IPv6zero, Port: 0}) + if err != nil { + ipv4Conn.Close() + t.Skipf("IPv6 not available: %v", err) + } + return ipv4Conn, ipv6Conn +} + +func createMsgPool() *sync.Pool { + return &sync.Pool{ + New: func() any { + msgs := make([]ipv6.Message, 1) + for i := range msgs { + msgs[i].Buffers = make(net.Buffers, 1) + msgs[i].OOB = make([]byte, 0, 40) + } + return &msgs + }, + } +} + +func listenUDP(t *testing.T, network, addr string) *net.UDPConn { + t.Helper() + udpAddr, err := net.ResolveUDPAddr(network, addr) + require.NoError(t, err) + conn, err := net.ListenUDP(network, udpAddr) + require.NoError(t, err) + return conn +} diff --git a/client/internal/peer/worker_ice.go b/client/internal/peer/worker_ice.go index 840fc9241..b6b9d2cf4 100644 --- a/client/internal/peer/worker_ice.go +++ b/client/internal/peer/worker_ice.go @@ -5,6 +5,7 @@ import ( "fmt" "net" "net/netip" + "strconv" "sync" "time" @@ -286,8 +287,8 @@ func (w *WorkerICE) connect(ctx context.Context, agent *icemaker.ThreadSafeAgent RosenpassAddr: remoteOfferAnswer.RosenpassAddr, LocalIceCandidateType: pair.Local.Type().String(), RemoteIceCandidateType: pair.Remote.Type().String(), - LocalIceCandidateEndpoint: fmt.Sprintf("%s:%d", pair.Local.Address(), pair.Local.Port()), - RemoteIceCandidateEndpoint: fmt.Sprintf("%s:%d", pair.Remote.Address(), pair.Remote.Port()), + LocalIceCandidateEndpoint: net.JoinHostPort(pair.Local.Address(), strconv.Itoa(pair.Local.Port())), + RemoteIceCandidateEndpoint: net.JoinHostPort(pair.Remote.Address(), strconv.Itoa(pair.Remote.Port())), Relayed: isRelayed(pair), RelayedOnLocal: isRelayCandidate(pair.Local), } @@ -328,13 +329,7 @@ func (w *WorkerICE) closeAgent(agent *icemaker.ThreadSafeAgent, cancel context.C func (w *WorkerICE) punchRemoteWGPort(pair *ice.CandidatePair, remoteWgPort int) { // wait local endpoint configuration time.Sleep(time.Second) - addrString := pair.Remote.Address() - parsed, err := netip.ParseAddr(addrString) - if (err == nil) && (parsed.Is6()) { - addrString = fmt.Sprintf("[%s]", addrString) - //IPv6 Literals need to be wrapped in brackets for Resolve*Addr() - } - addr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("%s:%d", addrString, remoteWgPort)) + addr, err := net.ResolveUDPAddr("udp", net.JoinHostPort(pair.Remote.Address(), strconv.Itoa(remoteWgPort))) if err != nil { w.log.Warnf("got an error while resolving the udp address, err: %s", err) return @@ -386,12 +381,44 @@ func (w *WorkerICE) onICESelectedCandidatePair(agent *icemaker.ThreadSafeAgent, } } +func (w *WorkerICE) logSuccessfulPaths(agent *icemaker.ThreadSafeAgent) { + sessionID := w.SessionID() + stats := agent.GetCandidatePairsStats() + localCandidates, _ := agent.GetLocalCandidates() + remoteCandidates, _ := agent.GetRemoteCandidates() + + localMap := make(map[string]ice.Candidate) + for _, c := range localCandidates { + localMap[c.ID()] = c + } + remoteMap := make(map[string]ice.Candidate) + for _, c := range remoteCandidates { + remoteMap[c.ID()] = c + } + + for _, stat := range stats { + if stat.State == ice.CandidatePairStateSucceeded { + local, lok := localMap[stat.LocalCandidateID] + remote, rok := remoteMap[stat.RemoteCandidateID] + if !lok || !rok { + continue + } + w.log.Debugf("successful ICE path %s: [%s %s %s] <-> [%s %s %s] rtt=%.3fms", + sessionID, + local.NetworkType(), local.Type(), local.Address(), + remote.NetworkType(), remote.Type(), remote.Address(), + stat.CurrentRoundTripTime*1000) + } + } +} + func (w *WorkerICE) onConnectionStateChange(agent *icemaker.ThreadSafeAgent, dialerCancel context.CancelFunc) func(ice.ConnectionState) { return func(state ice.ConnectionState) { w.log.Debugf("ICE ConnectionState has changed to %s", state.String()) switch state { case ice.ConnectionStateConnected: w.lastKnownState = ice.ConnectionStateConnected + w.logSuccessfulPaths(agent) return case ice.ConnectionStateFailed, ice.ConnectionStateDisconnected, ice.ConnectionStateClosed: // ice.ConnectionStateClosed happens when we recreate the agent. For the P2P to TURN switch important to diff --git a/sharedsock/sock_linux.go b/sharedsock/sock_linux.go index bc2d4d1be..523beb32b 100644 --- a/sharedsock/sock_linux.go +++ b/sharedsock/sock_linux.go @@ -154,9 +154,20 @@ func (s *SharedSocket) updateRouter() { } } -// LocalAddr returns an IPv4 address using the supplied port +// LocalAddr returns the local address, preferring IPv4 for backward compatibility. func (s *SharedSocket) LocalAddr() net.Addr { - // todo check impact on ipv6 discovery + if s.conn4 != nil { + return &net.UDPAddr{ + IP: net.IPv4zero, + Port: s.port, + } + } + if s.conn6 != nil { + return &net.UDPAddr{ + IP: net.IPv6zero, + Port: s.port, + } + } return &net.UDPAddr{ IP: net.IPv4zero, Port: s.port,