mirror of
https://github.com/netbirdio/netbird.git
synced 2026-03-31 06:34:14 -04:00
[client] Add IPv6 support to usersace bind (#5147)
This commit is contained in:
169
client/iface/bind/dual_stack_conn.go
Normal file
169
client/iface/bind/dual_stack_conn.go
Normal file
@@ -0,0 +1,169 @@
|
|||||||
|
package bind
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"net"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/hashicorp/go-multierror"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
errNoIPv4Conn = errors.New("no IPv4 connection available")
|
||||||
|
errNoIPv6Conn = errors.New("no IPv6 connection available")
|
||||||
|
errInvalidAddr = errors.New("invalid address type")
|
||||||
|
)
|
||||||
|
|
||||||
|
// DualStackPacketConn wraps IPv4 and IPv6 UDP connections and routes writes
|
||||||
|
// to the appropriate connection based on the destination address.
|
||||||
|
// ReadFrom is not used in the hot path - ICEBind receives packets via
|
||||||
|
// BatchReader.ReadBatch() directly. This is only used by udpMux for sending.
|
||||||
|
type DualStackPacketConn struct {
|
||||||
|
ipv4Conn net.PacketConn
|
||||||
|
ipv6Conn net.PacketConn
|
||||||
|
|
||||||
|
readFromWarn sync.Once
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewDualStackPacketConn creates a new dual-stack packet connection.
|
||||||
|
func NewDualStackPacketConn(ipv4Conn, ipv6Conn net.PacketConn) *DualStackPacketConn {
|
||||||
|
return &DualStackPacketConn{
|
||||||
|
ipv4Conn: ipv4Conn,
|
||||||
|
ipv6Conn: ipv6Conn,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ReadFrom reads from the available connection (preferring IPv4).
|
||||||
|
// NOTE: This method is NOT used in the data path. ICEBind receives packets via
|
||||||
|
// BatchReader.ReadBatch() directly for both IPv4 and IPv6, which is much more efficient.
|
||||||
|
// This implementation exists only to satisfy the net.PacketConn interface for the udpMux,
|
||||||
|
// but the udpMux only uses WriteTo() for sending STUN responses - it never calls ReadFrom()
|
||||||
|
// because STUN packets are filtered and forwarded via HandleSTUNMessage() from the receive path.
|
||||||
|
func (d *DualStackPacketConn) ReadFrom(b []byte) (n int, addr net.Addr, err error) {
|
||||||
|
d.readFromWarn.Do(func() {
|
||||||
|
log.Warn("DualStackPacketConn.ReadFrom called - this is unexpected and may indicate an inefficient code path")
|
||||||
|
})
|
||||||
|
|
||||||
|
if d.ipv4Conn != nil {
|
||||||
|
return d.ipv4Conn.ReadFrom(b)
|
||||||
|
}
|
||||||
|
if d.ipv6Conn != nil {
|
||||||
|
return d.ipv6Conn.ReadFrom(b)
|
||||||
|
}
|
||||||
|
return 0, nil, net.ErrClosed
|
||||||
|
}
|
||||||
|
|
||||||
|
// WriteTo writes to the appropriate connection based on the address type.
|
||||||
|
func (d *DualStackPacketConn) WriteTo(b []byte, addr net.Addr) (n int, err error) {
|
||||||
|
udpAddr, ok := addr.(*net.UDPAddr)
|
||||||
|
if !ok {
|
||||||
|
return 0, &net.OpError{
|
||||||
|
Op: "write",
|
||||||
|
Net: "udp",
|
||||||
|
Addr: addr,
|
||||||
|
Err: errInvalidAddr,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if udpAddr.IP.To4() == nil {
|
||||||
|
if d.ipv6Conn != nil {
|
||||||
|
return d.ipv6Conn.WriteTo(b, addr)
|
||||||
|
}
|
||||||
|
return 0, &net.OpError{
|
||||||
|
Op: "write",
|
||||||
|
Net: "udp6",
|
||||||
|
Addr: addr,
|
||||||
|
Err: errNoIPv6Conn,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if d.ipv4Conn != nil {
|
||||||
|
return d.ipv4Conn.WriteTo(b, addr)
|
||||||
|
}
|
||||||
|
return 0, &net.OpError{
|
||||||
|
Op: "write",
|
||||||
|
Net: "udp4",
|
||||||
|
Addr: addr,
|
||||||
|
Err: errNoIPv4Conn,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close closes both connections.
|
||||||
|
func (d *DualStackPacketConn) Close() error {
|
||||||
|
var result *multierror.Error
|
||||||
|
if d.ipv4Conn != nil {
|
||||||
|
if err := d.ipv4Conn.Close(); err != nil {
|
||||||
|
result = multierror.Append(result, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if d.ipv6Conn != nil {
|
||||||
|
if err := d.ipv6Conn.Close(); err != nil {
|
||||||
|
result = multierror.Append(result, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nberrors.FormatErrorOrNil(result)
|
||||||
|
}
|
||||||
|
|
||||||
|
// LocalAddr returns the local address of the IPv4 connection if available,
|
||||||
|
// otherwise the IPv6 connection.
|
||||||
|
func (d *DualStackPacketConn) LocalAddr() net.Addr {
|
||||||
|
if d.ipv4Conn != nil {
|
||||||
|
return d.ipv4Conn.LocalAddr()
|
||||||
|
}
|
||||||
|
if d.ipv6Conn != nil {
|
||||||
|
return d.ipv6Conn.LocalAddr()
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetDeadline sets the deadline for both connections.
|
||||||
|
func (d *DualStackPacketConn) SetDeadline(t time.Time) error {
|
||||||
|
var result *multierror.Error
|
||||||
|
if d.ipv4Conn != nil {
|
||||||
|
if err := d.ipv4Conn.SetDeadline(t); err != nil {
|
||||||
|
result = multierror.Append(result, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if d.ipv6Conn != nil {
|
||||||
|
if err := d.ipv6Conn.SetDeadline(t); err != nil {
|
||||||
|
result = multierror.Append(result, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nberrors.FormatErrorOrNil(result)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetReadDeadline sets the read deadline for both connections.
|
||||||
|
func (d *DualStackPacketConn) SetReadDeadline(t time.Time) error {
|
||||||
|
var result *multierror.Error
|
||||||
|
if d.ipv4Conn != nil {
|
||||||
|
if err := d.ipv4Conn.SetReadDeadline(t); err != nil {
|
||||||
|
result = multierror.Append(result, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if d.ipv6Conn != nil {
|
||||||
|
if err := d.ipv6Conn.SetReadDeadline(t); err != nil {
|
||||||
|
result = multierror.Append(result, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nberrors.FormatErrorOrNil(result)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetWriteDeadline sets the write deadline for both connections.
|
||||||
|
func (d *DualStackPacketConn) SetWriteDeadline(t time.Time) error {
|
||||||
|
var result *multierror.Error
|
||||||
|
if d.ipv4Conn != nil {
|
||||||
|
if err := d.ipv4Conn.SetWriteDeadline(t); err != nil {
|
||||||
|
result = multierror.Append(result, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if d.ipv6Conn != nil {
|
||||||
|
if err := d.ipv6Conn.SetWriteDeadline(t); err != nil {
|
||||||
|
result = multierror.Append(result, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nberrors.FormatErrorOrNil(result)
|
||||||
|
}
|
||||||
119
client/iface/bind/dual_stack_conn_bench_test.go
Normal file
119
client/iface/bind/dual_stack_conn_bench_test.go
Normal file
@@ -0,0 +1,119 @@
|
|||||||
|
package bind
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
ipv4Addr = &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 12345}
|
||||||
|
ipv6Addr = &net.UDPAddr{IP: net.ParseIP("::1"), Port: 12345}
|
||||||
|
payload = make([]byte, 1200)
|
||||||
|
)
|
||||||
|
|
||||||
|
func BenchmarkWriteTo_DirectUDPConn(b *testing.B) {
|
||||||
|
conn, err := net.ListenUDP("udp4", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
|
||||||
|
if err != nil {
|
||||||
|
b.Fatal(err)
|
||||||
|
}
|
||||||
|
defer conn.Close()
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
_, _ = conn.WriteTo(payload, ipv4Addr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkWriteTo_DualStack_IPv4Only(b *testing.B) {
|
||||||
|
conn, err := net.ListenUDP("udp4", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
|
||||||
|
if err != nil {
|
||||||
|
b.Fatal(err)
|
||||||
|
}
|
||||||
|
defer conn.Close()
|
||||||
|
|
||||||
|
ds := NewDualStackPacketConn(conn, nil)
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
_, _ = ds.WriteTo(payload, ipv4Addr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkWriteTo_DualStack_IPv6Only(b *testing.B) {
|
||||||
|
conn, err := net.ListenUDP("udp6", &net.UDPAddr{IP: net.IPv6zero, Port: 0})
|
||||||
|
if err != nil {
|
||||||
|
b.Skipf("IPv6 not available: %v", err)
|
||||||
|
}
|
||||||
|
defer conn.Close()
|
||||||
|
|
||||||
|
ds := NewDualStackPacketConn(nil, conn)
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
_, _ = ds.WriteTo(payload, ipv6Addr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkWriteTo_DualStack_Both_IPv4Traffic(b *testing.B) {
|
||||||
|
conn4, err := net.ListenUDP("udp4", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
|
||||||
|
if err != nil {
|
||||||
|
b.Fatal(err)
|
||||||
|
}
|
||||||
|
defer conn4.Close()
|
||||||
|
|
||||||
|
conn6, err := net.ListenUDP("udp6", &net.UDPAddr{IP: net.IPv6zero, Port: 0})
|
||||||
|
if err != nil {
|
||||||
|
b.Skipf("IPv6 not available: %v", err)
|
||||||
|
}
|
||||||
|
defer conn6.Close()
|
||||||
|
|
||||||
|
ds := NewDualStackPacketConn(conn4, conn6)
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
_, _ = ds.WriteTo(payload, ipv4Addr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkWriteTo_DualStack_Both_IPv6Traffic(b *testing.B) {
|
||||||
|
conn4, err := net.ListenUDP("udp4", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
|
||||||
|
if err != nil {
|
||||||
|
b.Fatal(err)
|
||||||
|
}
|
||||||
|
defer conn4.Close()
|
||||||
|
|
||||||
|
conn6, err := net.ListenUDP("udp6", &net.UDPAddr{IP: net.IPv6zero, Port: 0})
|
||||||
|
if err != nil {
|
||||||
|
b.Skipf("IPv6 not available: %v", err)
|
||||||
|
}
|
||||||
|
defer conn6.Close()
|
||||||
|
|
||||||
|
ds := NewDualStackPacketConn(conn4, conn6)
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
_, _ = ds.WriteTo(payload, ipv6Addr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkWriteTo_DualStack_Both_MixedTraffic(b *testing.B) {
|
||||||
|
conn4, err := net.ListenUDP("udp4", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
|
||||||
|
if err != nil {
|
||||||
|
b.Fatal(err)
|
||||||
|
}
|
||||||
|
defer conn4.Close()
|
||||||
|
|
||||||
|
conn6, err := net.ListenUDP("udp6", &net.UDPAddr{IP: net.IPv6zero, Port: 0})
|
||||||
|
if err != nil {
|
||||||
|
b.Skipf("IPv6 not available: %v", err)
|
||||||
|
}
|
||||||
|
defer conn6.Close()
|
||||||
|
|
||||||
|
ds := NewDualStackPacketConn(conn4, conn6)
|
||||||
|
addrs := []net.Addr{ipv4Addr, ipv6Addr}
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
_, _ = ds.WriteTo(payload, addrs[i&1])
|
||||||
|
}
|
||||||
|
}
|
||||||
191
client/iface/bind/dual_stack_conn_test.go
Normal file
191
client/iface/bind/dual_stack_conn_test.go
Normal file
@@ -0,0 +1,191 @@
|
|||||||
|
package bind
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestDualStackPacketConn_RoutesWritesToCorrectSocket(t *testing.T) {
|
||||||
|
ipv4Conn := &mockPacketConn{network: "udp4"}
|
||||||
|
ipv6Conn := &mockPacketConn{network: "udp6"}
|
||||||
|
dualStack := NewDualStackPacketConn(ipv4Conn, ipv6Conn)
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
addr *net.UDPAddr
|
||||||
|
wantSocket string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "IPv4 address",
|
||||||
|
addr: &net.UDPAddr{IP: net.ParseIP("192.168.1.1"), Port: 1234},
|
||||||
|
wantSocket: "udp4",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "IPv6 address",
|
||||||
|
addr: &net.UDPAddr{IP: net.ParseIP("2001:db8::1"), Port: 1234},
|
||||||
|
wantSocket: "udp6",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "IPv4-mapped IPv6 goes to IPv4",
|
||||||
|
addr: &net.UDPAddr{IP: net.ParseIP("::ffff:192.168.1.1"), Port: 1234},
|
||||||
|
wantSocket: "udp4",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "IPv4 loopback",
|
||||||
|
addr: &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 1234},
|
||||||
|
wantSocket: "udp4",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "IPv6 loopback",
|
||||||
|
addr: &net.UDPAddr{IP: net.ParseIP("::1"), Port: 1234},
|
||||||
|
wantSocket: "udp6",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
ipv4Conn.writeCount = 0
|
||||||
|
ipv6Conn.writeCount = 0
|
||||||
|
|
||||||
|
n, err := dualStack.WriteTo([]byte("test"), tt.addr)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, 4, n)
|
||||||
|
|
||||||
|
if tt.wantSocket == "udp4" {
|
||||||
|
assert.Equal(t, 1, ipv4Conn.writeCount, "expected write to IPv4")
|
||||||
|
assert.Equal(t, 0, ipv6Conn.writeCount, "expected no write to IPv6")
|
||||||
|
} else {
|
||||||
|
assert.Equal(t, 0, ipv4Conn.writeCount, "expected no write to IPv4")
|
||||||
|
assert.Equal(t, 1, ipv6Conn.writeCount, "expected write to IPv6")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDualStackPacketConn_IPv4OnlyRejectsIPv6(t *testing.T) {
|
||||||
|
dualStack := NewDualStackPacketConn(&mockPacketConn{network: "udp4"}, nil)
|
||||||
|
|
||||||
|
// IPv4 works
|
||||||
|
_, err := dualStack.WriteTo([]byte("test"), &net.UDPAddr{IP: net.ParseIP("192.168.1.1"), Port: 1234})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// IPv6 fails
|
||||||
|
_, err = dualStack.WriteTo([]byte("test"), &net.UDPAddr{IP: net.ParseIP("2001:db8::1"), Port: 1234})
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "no IPv6 connection")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDualStackPacketConn_IPv6OnlyRejectsIPv4(t *testing.T) {
|
||||||
|
dualStack := NewDualStackPacketConn(nil, &mockPacketConn{network: "udp6"})
|
||||||
|
|
||||||
|
// IPv6 works
|
||||||
|
_, err := dualStack.WriteTo([]byte("test"), &net.UDPAddr{IP: net.ParseIP("2001:db8::1"), Port: 1234})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// IPv4 fails
|
||||||
|
_, err = dualStack.WriteTo([]byte("test"), &net.UDPAddr{IP: net.ParseIP("192.168.1.1"), Port: 1234})
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "no IPv4 connection")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestDualStackPacketConn_ReadFromIsNotUsedInHotPath documents that ReadFrom
|
||||||
|
// only reads from one socket (IPv4 preferred). This is fine because the actual
|
||||||
|
// receive path uses wireguard-go's BatchReader directly, not ReadFrom.
|
||||||
|
func TestDualStackPacketConn_ReadFromIsNotUsedInHotPath(t *testing.T) {
|
||||||
|
ipv4Conn := &mockPacketConn{
|
||||||
|
network: "udp4",
|
||||||
|
readData: []byte("from ipv4"),
|
||||||
|
readAddr: &net.UDPAddr{IP: net.ParseIP("192.168.1.1"), Port: 1234},
|
||||||
|
}
|
||||||
|
ipv6Conn := &mockPacketConn{
|
||||||
|
network: "udp6",
|
||||||
|
readData: []byte("from ipv6"),
|
||||||
|
readAddr: &net.UDPAddr{IP: net.ParseIP("2001:db8::1"), Port: 1234},
|
||||||
|
}
|
||||||
|
|
||||||
|
dualStack := NewDualStackPacketConn(ipv4Conn, ipv6Conn)
|
||||||
|
|
||||||
|
buf := make([]byte, 100)
|
||||||
|
n, addr, err := dualStack.ReadFrom(buf)
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
// reads from IPv4 (preferred) - this is expected behavior
|
||||||
|
assert.Equal(t, "from ipv4", string(buf[:n]))
|
||||||
|
assert.Equal(t, "192.168.1.1", addr.(*net.UDPAddr).IP.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDualStackPacketConn_LocalAddrPrefersIPv4(t *testing.T) {
|
||||||
|
ipv4Addr := &net.UDPAddr{IP: net.ParseIP("0.0.0.0"), Port: 51820}
|
||||||
|
ipv6Addr := &net.UDPAddr{IP: net.ParseIP("::"), Port: 51820}
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
ipv4 net.PacketConn
|
||||||
|
ipv6 net.PacketConn
|
||||||
|
wantAddr net.Addr
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "both available returns IPv4",
|
||||||
|
ipv4: &mockPacketConn{localAddr: ipv4Addr},
|
||||||
|
ipv6: &mockPacketConn{localAddr: ipv6Addr},
|
||||||
|
wantAddr: ipv4Addr,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "IPv4 only",
|
||||||
|
ipv4: &mockPacketConn{localAddr: ipv4Addr},
|
||||||
|
ipv6: nil,
|
||||||
|
wantAddr: ipv4Addr,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "IPv6 only",
|
||||||
|
ipv4: nil,
|
||||||
|
ipv6: &mockPacketConn{localAddr: ipv6Addr},
|
||||||
|
wantAddr: ipv6Addr,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "neither returns nil",
|
||||||
|
ipv4: nil,
|
||||||
|
ipv6: nil,
|
||||||
|
wantAddr: nil,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
dualStack := NewDualStackPacketConn(tt.ipv4, tt.ipv6)
|
||||||
|
assert.Equal(t, tt.wantAddr, dualStack.LocalAddr())
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// mock
|
||||||
|
|
||||||
|
type mockPacketConn struct {
|
||||||
|
network string
|
||||||
|
writeCount int
|
||||||
|
readData []byte
|
||||||
|
readAddr net.Addr
|
||||||
|
localAddr net.Addr
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockPacketConn) ReadFrom(b []byte) (n int, addr net.Addr, err error) {
|
||||||
|
if m.readData != nil {
|
||||||
|
return copy(b, m.readData), m.readAddr, nil
|
||||||
|
}
|
||||||
|
return 0, nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockPacketConn) WriteTo(b []byte, addr net.Addr) (n int, err error) {
|
||||||
|
m.writeCount++
|
||||||
|
return len(b), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockPacketConn) Close() error { return nil }
|
||||||
|
func (m *mockPacketConn) LocalAddr() net.Addr { return m.localAddr }
|
||||||
|
func (m *mockPacketConn) SetDeadline(t time.Time) error { return nil }
|
||||||
|
func (m *mockPacketConn) SetReadDeadline(t time.Time) error { return nil }
|
||||||
|
func (m *mockPacketConn) SetWriteDeadline(t time.Time) error { return nil }
|
||||||
@@ -14,7 +14,6 @@ import (
|
|||||||
"github.com/pion/stun/v3"
|
"github.com/pion/stun/v3"
|
||||||
"github.com/pion/transport/v3"
|
"github.com/pion/transport/v3"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"golang.org/x/net/ipv4"
|
|
||||||
"golang.org/x/net/ipv6"
|
"golang.org/x/net/ipv6"
|
||||||
wgConn "golang.zx2c4.com/wireguard/conn"
|
wgConn "golang.zx2c4.com/wireguard/conn"
|
||||||
|
|
||||||
@@ -28,22 +27,7 @@ type receiverCreator struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (rc receiverCreator) CreateReceiverFn(pc wgConn.BatchReader, conn *net.UDPConn, rxOffload bool, msgPool *sync.Pool) wgConn.ReceiveFunc {
|
func (rc receiverCreator) CreateReceiverFn(pc wgConn.BatchReader, conn *net.UDPConn, rxOffload bool, msgPool *sync.Pool) wgConn.ReceiveFunc {
|
||||||
if ipv4PC, ok := pc.(*ipv4.PacketConn); ok {
|
return rc.iceBind.createReceiverFn(pc, conn, rxOffload, msgPool)
|
||||||
return rc.iceBind.createIPv4ReceiverFn(ipv4PC, conn, rxOffload, msgPool)
|
|
||||||
}
|
|
||||||
// IPv6 is currently not supported in the udpmux, this is a stub for compatibility with the
|
|
||||||
// wireguard-go ReceiverCreator interface which is called for both IPv4 and IPv6.
|
|
||||||
return func(bufs [][]byte, sizes []int, eps []wgConn.Endpoint) (n int, err error) {
|
|
||||||
buf := bufs[0]
|
|
||||||
size, ep, err := conn.ReadFromUDPAddrPort(buf)
|
|
||||||
if err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
sizes[0] = size
|
|
||||||
stdEp := &wgConn.StdNetEndpoint{AddrPort: ep}
|
|
||||||
eps[0] = stdEp
|
|
||||||
return 1, nil
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// ICEBind is a bind implementation with two main features:
|
// ICEBind is a bind implementation with two main features:
|
||||||
@@ -73,6 +57,8 @@ type ICEBind struct {
|
|||||||
|
|
||||||
muUDPMux sync.Mutex
|
muUDPMux sync.Mutex
|
||||||
udpMux *udpmux.UniversalUDPMuxDefault
|
udpMux *udpmux.UniversalUDPMuxDefault
|
||||||
|
ipv4Conn *net.UDPConn
|
||||||
|
ipv6Conn *net.UDPConn
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewICEBind(transportNet transport.Net, filterFn udpmux.FilterFn, address wgaddr.Address, mtu uint16) *ICEBind {
|
func NewICEBind(transportNet transport.Net, filterFn udpmux.FilterFn, address wgaddr.Address, mtu uint16) *ICEBind {
|
||||||
@@ -118,6 +104,12 @@ func (s *ICEBind) Close() error {
|
|||||||
|
|
||||||
close(s.closedChan)
|
close(s.closedChan)
|
||||||
|
|
||||||
|
s.muUDPMux.Lock()
|
||||||
|
s.ipv4Conn = nil
|
||||||
|
s.ipv6Conn = nil
|
||||||
|
s.udpMux = nil
|
||||||
|
s.muUDPMux.Unlock()
|
||||||
|
|
||||||
return s.StdNetBind.Close()
|
return s.StdNetBind.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -175,19 +167,18 @@ func (b *ICEBind) Send(bufs [][]byte, ep wgConn.Endpoint) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *ICEBind) createIPv4ReceiverFn(pc *ipv4.PacketConn, conn *net.UDPConn, rxOffload bool, msgsPool *sync.Pool) wgConn.ReceiveFunc {
|
func (s *ICEBind) createReceiverFn(pc wgConn.BatchReader, conn *net.UDPConn, rxOffload bool, msgsPool *sync.Pool) wgConn.ReceiveFunc {
|
||||||
s.muUDPMux.Lock()
|
s.muUDPMux.Lock()
|
||||||
defer s.muUDPMux.Unlock()
|
defer s.muUDPMux.Unlock()
|
||||||
|
|
||||||
s.udpMux = udpmux.NewUniversalUDPMuxDefault(
|
// Detect IPv4 vs IPv6 from connection's local address
|
||||||
udpmux.UniversalUDPMuxParams{
|
if localAddr := conn.LocalAddr().(*net.UDPAddr); localAddr.IP.To4() != nil {
|
||||||
UDPConn: nbnet.WrapPacketConn(conn),
|
s.ipv4Conn = conn
|
||||||
Net: s.transportNet,
|
} else {
|
||||||
FilterFn: s.filterFn,
|
s.ipv6Conn = conn
|
||||||
WGAddress: s.address,
|
}
|
||||||
MTU: s.mtu,
|
s.createOrUpdateMux()
|
||||||
},
|
|
||||||
)
|
|
||||||
return func(bufs [][]byte, sizes []int, eps []wgConn.Endpoint) (n int, err error) {
|
return func(bufs [][]byte, sizes []int, eps []wgConn.Endpoint) (n int, err error) {
|
||||||
msgs := getMessages(msgsPool)
|
msgs := getMessages(msgsPool)
|
||||||
for i := range bufs {
|
for i := range bufs {
|
||||||
@@ -195,12 +186,13 @@ func (s *ICEBind) createIPv4ReceiverFn(pc *ipv4.PacketConn, conn *net.UDPConn, r
|
|||||||
(*msgs)[i].OOB = (*msgs)[i].OOB[:cap((*msgs)[i].OOB)]
|
(*msgs)[i].OOB = (*msgs)[i].OOB[:cap((*msgs)[i].OOB)]
|
||||||
}
|
}
|
||||||
defer putMessages(msgs, msgsPool)
|
defer putMessages(msgs, msgsPool)
|
||||||
|
|
||||||
var numMsgs int
|
var numMsgs int
|
||||||
if runtime.GOOS == "linux" || runtime.GOOS == "android" {
|
if runtime.GOOS == "linux" || runtime.GOOS == "android" {
|
||||||
if rxOffload {
|
if rxOffload {
|
||||||
readAt := len(*msgs) - (wgConn.IdealBatchSize / wgConn.UdpSegmentMaxDatagrams)
|
readAt := len(*msgs) - (wgConn.IdealBatchSize / wgConn.UdpSegmentMaxDatagrams)
|
||||||
//nolint
|
//nolint:staticcheck
|
||||||
numMsgs, err = pc.ReadBatch((*msgs)[readAt:], 0)
|
_, err = pc.ReadBatch((*msgs)[readAt:], 0)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
@@ -222,12 +214,12 @@ func (s *ICEBind) createIPv4ReceiverFn(pc *ipv4.PacketConn, conn *net.UDPConn, r
|
|||||||
}
|
}
|
||||||
numMsgs = 1
|
numMsgs = 1
|
||||||
}
|
}
|
||||||
|
|
||||||
for i := 0; i < numMsgs; i++ {
|
for i := 0; i < numMsgs; i++ {
|
||||||
msg := &(*msgs)[i]
|
msg := &(*msgs)[i]
|
||||||
|
|
||||||
// todo: handle err
|
// todo: handle err
|
||||||
ok, _ := s.filterOutStunMessages(msg.Buffers, msg.N, msg.Addr)
|
if ok, _ := s.filterOutStunMessages(msg.Buffers, msg.N, msg.Addr); ok {
|
||||||
if ok {
|
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
sizes[i] = msg.N
|
sizes[i] = msg.N
|
||||||
@@ -248,6 +240,38 @@ func (s *ICEBind) createIPv4ReceiverFn(pc *ipv4.PacketConn, conn *net.UDPConn, r
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// createOrUpdateMux creates or updates the UDP mux with the available connections.
|
||||||
|
// Must be called with muUDPMux held.
|
||||||
|
func (s *ICEBind) createOrUpdateMux() {
|
||||||
|
var muxConn net.PacketConn
|
||||||
|
|
||||||
|
switch {
|
||||||
|
case s.ipv4Conn != nil && s.ipv6Conn != nil:
|
||||||
|
muxConn = NewDualStackPacketConn(
|
||||||
|
nbnet.WrapPacketConn(s.ipv4Conn),
|
||||||
|
nbnet.WrapPacketConn(s.ipv6Conn),
|
||||||
|
)
|
||||||
|
case s.ipv4Conn != nil:
|
||||||
|
muxConn = nbnet.WrapPacketConn(s.ipv4Conn)
|
||||||
|
case s.ipv6Conn != nil:
|
||||||
|
muxConn = nbnet.WrapPacketConn(s.ipv6Conn)
|
||||||
|
default:
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Don't close the old mux - it doesn't own the underlying connections.
|
||||||
|
// The sockets are managed by WireGuard's StdNetBind, not by us.
|
||||||
|
s.udpMux = udpmux.NewUniversalUDPMuxDefault(
|
||||||
|
udpmux.UniversalUDPMuxParams{
|
||||||
|
UDPConn: muxConn,
|
||||||
|
Net: s.transportNet,
|
||||||
|
FilterFn: s.filterFn,
|
||||||
|
WGAddress: s.address,
|
||||||
|
MTU: s.mtu,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
func (s *ICEBind) filterOutStunMessages(buffers [][]byte, n int, addr net.Addr) (bool, error) {
|
func (s *ICEBind) filterOutStunMessages(buffers [][]byte, n int, addr net.Addr) (bool, error) {
|
||||||
for i := range buffers {
|
for i := range buffers {
|
||||||
if !stun.IsMessage(buffers[i]) {
|
if !stun.IsMessage(buffers[i]) {
|
||||||
@@ -260,9 +284,14 @@ func (s *ICEBind) filterOutStunMessages(buffers [][]byte, n int, addr net.Addr)
|
|||||||
return true, err
|
return true, err
|
||||||
}
|
}
|
||||||
|
|
||||||
muxErr := s.udpMux.HandleSTUNMessage(msg, addr)
|
s.muUDPMux.Lock()
|
||||||
if muxErr != nil {
|
mux := s.udpMux
|
||||||
log.Warnf("failed to handle STUN packet")
|
s.muUDPMux.Unlock()
|
||||||
|
|
||||||
|
if mux != nil {
|
||||||
|
if muxErr := mux.HandleSTUNMessage(msg, addr); muxErr != nil {
|
||||||
|
log.Warnf("failed to handle STUN packet: %v", muxErr)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
buffers[i] = []byte{}
|
buffers[i] = []byte{}
|
||||||
|
|||||||
324
client/iface/bind/ice_bind_test.go
Normal file
324
client/iface/bind/ice_bind_test.go
Normal file
@@ -0,0 +1,324 @@
|
|||||||
|
package bind
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"net/netip"
|
||||||
|
"sync"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/pion/transport/v3/stdnet"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
"golang.org/x/net/ipv4"
|
||||||
|
"golang.org/x/net/ipv6"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestICEBind_CreatesReceiverForBothIPv4AndIPv6(t *testing.T) {
|
||||||
|
iceBind := setupICEBind(t)
|
||||||
|
|
||||||
|
ipv4Conn, ipv6Conn := createDualStackConns(t)
|
||||||
|
defer ipv4Conn.Close()
|
||||||
|
defer ipv6Conn.Close()
|
||||||
|
|
||||||
|
rc := receiverCreator{iceBind}
|
||||||
|
pool := createMsgPool()
|
||||||
|
|
||||||
|
// Simulate wireguard-go calling CreateReceiverFn for IPv4
|
||||||
|
ipv4RecvFn := rc.CreateReceiverFn(ipv4.NewPacketConn(ipv4Conn), ipv4Conn, false, pool)
|
||||||
|
require.NotNil(t, ipv4RecvFn)
|
||||||
|
|
||||||
|
iceBind.muUDPMux.Lock()
|
||||||
|
assert.NotNil(t, iceBind.ipv4Conn, "should store IPv4 connection")
|
||||||
|
assert.Nil(t, iceBind.ipv6Conn, "IPv6 not added yet")
|
||||||
|
assert.NotNil(t, iceBind.udpMux, "mux should be created after first connection")
|
||||||
|
iceBind.muUDPMux.Unlock()
|
||||||
|
|
||||||
|
// Simulate wireguard-go calling CreateReceiverFn for IPv6
|
||||||
|
ipv6RecvFn := rc.CreateReceiverFn(ipv6.NewPacketConn(ipv6Conn), ipv6Conn, false, pool)
|
||||||
|
require.NotNil(t, ipv6RecvFn)
|
||||||
|
|
||||||
|
iceBind.muUDPMux.Lock()
|
||||||
|
assert.NotNil(t, iceBind.ipv4Conn, "should still have IPv4 connection")
|
||||||
|
assert.NotNil(t, iceBind.ipv6Conn, "should now have IPv6 connection")
|
||||||
|
assert.NotNil(t, iceBind.udpMux, "mux should still exist")
|
||||||
|
iceBind.muUDPMux.Unlock()
|
||||||
|
|
||||||
|
mux, err := iceBind.GetICEMux()
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, mux)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestICEBind_WorksWithIPv4Only(t *testing.T) {
|
||||||
|
iceBind := setupICEBind(t)
|
||||||
|
|
||||||
|
ipv4Conn, err := net.ListenUDP("udp4", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer ipv4Conn.Close()
|
||||||
|
|
||||||
|
rc := receiverCreator{iceBind}
|
||||||
|
recvFn := rc.CreateReceiverFn(ipv4.NewPacketConn(ipv4Conn), ipv4Conn, false, createMsgPool())
|
||||||
|
require.NotNil(t, recvFn)
|
||||||
|
|
||||||
|
iceBind.muUDPMux.Lock()
|
||||||
|
assert.NotNil(t, iceBind.ipv4Conn)
|
||||||
|
assert.Nil(t, iceBind.ipv6Conn)
|
||||||
|
assert.NotNil(t, iceBind.udpMux)
|
||||||
|
iceBind.muUDPMux.Unlock()
|
||||||
|
|
||||||
|
mux, err := iceBind.GetICEMux()
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, mux)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestICEBind_WorksWithIPv6Only(t *testing.T) {
|
||||||
|
iceBind := setupICEBind(t)
|
||||||
|
|
||||||
|
ipv6Conn, err := net.ListenUDP("udp6", &net.UDPAddr{IP: net.IPv6zero, Port: 0})
|
||||||
|
if err != nil {
|
||||||
|
t.Skipf("IPv6 not available: %v", err)
|
||||||
|
}
|
||||||
|
defer ipv6Conn.Close()
|
||||||
|
|
||||||
|
rc := receiverCreator{iceBind}
|
||||||
|
recvFn := rc.CreateReceiverFn(ipv6.NewPacketConn(ipv6Conn), ipv6Conn, false, createMsgPool())
|
||||||
|
require.NotNil(t, recvFn)
|
||||||
|
|
||||||
|
iceBind.muUDPMux.Lock()
|
||||||
|
assert.Nil(t, iceBind.ipv4Conn)
|
||||||
|
assert.NotNil(t, iceBind.ipv6Conn)
|
||||||
|
assert.NotNil(t, iceBind.udpMux)
|
||||||
|
iceBind.muUDPMux.Unlock()
|
||||||
|
|
||||||
|
mux, err := iceBind.GetICEMux()
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, mux)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestICEBind_SendsToIPv4AndIPv6PeersSimultaneously verifies that we can communicate
|
||||||
|
// with peers on different address families through the same DualStackPacketConn.
|
||||||
|
func TestICEBind_SendsToIPv4AndIPv6PeersSimultaneously(t *testing.T) {
|
||||||
|
// two "remote peers" listening on different address families
|
||||||
|
ipv4Peer := listenUDP(t, "udp4", "127.0.0.1:0")
|
||||||
|
defer ipv4Peer.Close()
|
||||||
|
|
||||||
|
ipv6Peer, err := net.ListenUDP("udp6", &net.UDPAddr{IP: net.IPv6loopback, Port: 0})
|
||||||
|
if err != nil {
|
||||||
|
t.Skipf("IPv6 not available: %v", err)
|
||||||
|
}
|
||||||
|
defer ipv6Peer.Close()
|
||||||
|
|
||||||
|
// our local dual-stack connection
|
||||||
|
ipv4Local := listenUDP(t, "udp4", "127.0.0.1:0")
|
||||||
|
defer ipv4Local.Close()
|
||||||
|
|
||||||
|
ipv6Local := listenUDP(t, "udp6", "[::1]:0")
|
||||||
|
defer ipv6Local.Close()
|
||||||
|
|
||||||
|
dualStack := NewDualStackPacketConn(ipv4Local, ipv6Local)
|
||||||
|
|
||||||
|
// send to both peers
|
||||||
|
_, err = dualStack.WriteTo([]byte("to-ipv4"), ipv4Peer.LocalAddr())
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
_, err = dualStack.WriteTo([]byte("to-ipv6"), ipv6Peer.LocalAddr())
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// verify IPv4 peer got its packet from the IPv4 socket
|
||||||
|
buf := make([]byte, 100)
|
||||||
|
_ = ipv4Peer.SetReadDeadline(time.Now().Add(time.Second))
|
||||||
|
n, addr, err := ipv4Peer.ReadFrom(buf)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, "to-ipv4", string(buf[:n]))
|
||||||
|
assert.Equal(t, ipv4Local.LocalAddr().(*net.UDPAddr).Port, addr.(*net.UDPAddr).Port)
|
||||||
|
|
||||||
|
// verify IPv6 peer got its packet from the IPv6 socket
|
||||||
|
_ = ipv6Peer.SetReadDeadline(time.Now().Add(time.Second))
|
||||||
|
n, addr, err = ipv6Peer.ReadFrom(buf)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, "to-ipv6", string(buf[:n]))
|
||||||
|
assert.Equal(t, ipv6Local.LocalAddr().(*net.UDPAddr).Port, addr.(*net.UDPAddr).Port)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestICEBind_HandlesConcurrentMixedTraffic sends packets concurrently to both IPv4
|
||||||
|
// and IPv6 peers. Verifies no packets get misrouted (IPv4 peer only gets v4- packets,
|
||||||
|
// IPv6 peer only gets v6- packets). Some packet loss is acceptable for UDP.
|
||||||
|
func TestICEBind_HandlesConcurrentMixedTraffic(t *testing.T) {
|
||||||
|
ipv4Peer := listenUDP(t, "udp4", "127.0.0.1:0")
|
||||||
|
defer ipv4Peer.Close()
|
||||||
|
|
||||||
|
ipv6Peer, err := net.ListenUDP("udp6", &net.UDPAddr{IP: net.IPv6loopback, Port: 0})
|
||||||
|
if err != nil {
|
||||||
|
t.Skipf("IPv6 not available: %v", err)
|
||||||
|
}
|
||||||
|
defer ipv6Peer.Close()
|
||||||
|
|
||||||
|
ipv4Local := listenUDP(t, "udp4", "127.0.0.1:0")
|
||||||
|
defer ipv4Local.Close()
|
||||||
|
|
||||||
|
ipv6Local := listenUDP(t, "udp6", "[::1]:0")
|
||||||
|
defer ipv6Local.Close()
|
||||||
|
|
||||||
|
dualStack := NewDualStackPacketConn(ipv4Local, ipv6Local)
|
||||||
|
|
||||||
|
const packetsPerFamily = 500
|
||||||
|
|
||||||
|
ipv4Received := make(chan string, packetsPerFamily)
|
||||||
|
ipv6Received := make(chan string, packetsPerFamily)
|
||||||
|
|
||||||
|
startGate := make(chan struct{})
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
|
||||||
|
wg.Add(1)
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
buf := make([]byte, 100)
|
||||||
|
for i := 0; i < packetsPerFamily; i++ {
|
||||||
|
n, _, err := ipv4Peer.ReadFrom(buf)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
ipv4Received <- string(buf[:n])
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
wg.Add(1)
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
buf := make([]byte, 100)
|
||||||
|
for i := 0; i < packetsPerFamily; i++ {
|
||||||
|
n, _, err := ipv6Peer.ReadFrom(buf)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
ipv6Received <- string(buf[:n])
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
wg.Add(1)
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
<-startGate
|
||||||
|
for i := 0; i < packetsPerFamily; i++ {
|
||||||
|
_, _ = dualStack.WriteTo([]byte(fmt.Sprintf("v4-%04d", i)), ipv4Peer.LocalAddr())
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
wg.Add(1)
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
<-startGate
|
||||||
|
for i := 0; i < packetsPerFamily; i++ {
|
||||||
|
_, _ = dualStack.WriteTo([]byte(fmt.Sprintf("v6-%04d", i)), ipv6Peer.LocalAddr())
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
close(startGate)
|
||||||
|
|
||||||
|
time.AfterFunc(5*time.Second, func() {
|
||||||
|
_ = ipv4Peer.SetReadDeadline(time.Now())
|
||||||
|
_ = ipv6Peer.SetReadDeadline(time.Now())
|
||||||
|
})
|
||||||
|
|
||||||
|
wg.Wait()
|
||||||
|
close(ipv4Received)
|
||||||
|
close(ipv6Received)
|
||||||
|
|
||||||
|
ipv4Count := 0
|
||||||
|
for pkt := range ipv4Received {
|
||||||
|
require.True(t, len(pkt) >= 3 && pkt[:3] == "v4-", "IPv4 peer got misrouted packet: %s", pkt)
|
||||||
|
ipv4Count++
|
||||||
|
}
|
||||||
|
|
||||||
|
ipv6Count := 0
|
||||||
|
for pkt := range ipv6Received {
|
||||||
|
require.True(t, len(pkt) >= 3 && pkt[:3] == "v6-", "IPv6 peer got misrouted packet: %s", pkt)
|
||||||
|
ipv6Count++
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.Equal(t, packetsPerFamily, ipv4Count)
|
||||||
|
assert.Equal(t, packetsPerFamily, ipv6Count)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestICEBind_DetectsAddressFamilyFromConnection(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
network string
|
||||||
|
addr string
|
||||||
|
wantIPv4 bool
|
||||||
|
}{
|
||||||
|
{"IPv4 any", "udp4", "0.0.0.0:0", true},
|
||||||
|
{"IPv4 loopback", "udp4", "127.0.0.1:0", true},
|
||||||
|
{"IPv6 any", "udp6", "[::]:0", false},
|
||||||
|
{"IPv6 loopback", "udp6", "[::1]:0", false},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
addr, err := net.ResolveUDPAddr(tt.network, tt.addr)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
conn, err := net.ListenUDP(tt.network, addr)
|
||||||
|
if err != nil {
|
||||||
|
t.Skipf("%s not available: %v", tt.network, err)
|
||||||
|
}
|
||||||
|
defer conn.Close()
|
||||||
|
|
||||||
|
localAddr := conn.LocalAddr().(*net.UDPAddr)
|
||||||
|
isIPv4 := localAddr.IP.To4() != nil
|
||||||
|
assert.Equal(t, tt.wantIPv4, isIPv4)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// helpers
|
||||||
|
|
||||||
|
func setupICEBind(t *testing.T) *ICEBind {
|
||||||
|
t.Helper()
|
||||||
|
transportNet, err := stdnet.NewNet()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
address := wgaddr.Address{
|
||||||
|
IP: netip.MustParseAddr("100.64.0.1"),
|
||||||
|
Network: netip.MustParsePrefix("100.64.0.0/10"),
|
||||||
|
}
|
||||||
|
return NewICEBind(transportNet, nil, address, 1280)
|
||||||
|
}
|
||||||
|
|
||||||
|
func createDualStackConns(t *testing.T) (*net.UDPConn, *net.UDPConn) {
|
||||||
|
t.Helper()
|
||||||
|
ipv4Conn, err := net.ListenUDP("udp4", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
ipv6Conn, err := net.ListenUDP("udp6", &net.UDPAddr{IP: net.IPv6zero, Port: 0})
|
||||||
|
if err != nil {
|
||||||
|
ipv4Conn.Close()
|
||||||
|
t.Skipf("IPv6 not available: %v", err)
|
||||||
|
}
|
||||||
|
return ipv4Conn, ipv6Conn
|
||||||
|
}
|
||||||
|
|
||||||
|
func createMsgPool() *sync.Pool {
|
||||||
|
return &sync.Pool{
|
||||||
|
New: func() any {
|
||||||
|
msgs := make([]ipv6.Message, 1)
|
||||||
|
for i := range msgs {
|
||||||
|
msgs[i].Buffers = make(net.Buffers, 1)
|
||||||
|
msgs[i].OOB = make([]byte, 0, 40)
|
||||||
|
}
|
||||||
|
return &msgs
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func listenUDP(t *testing.T, network, addr string) *net.UDPConn {
|
||||||
|
t.Helper()
|
||||||
|
udpAddr, err := net.ResolveUDPAddr(network, addr)
|
||||||
|
require.NoError(t, err)
|
||||||
|
conn, err := net.ListenUDP(network, udpAddr)
|
||||||
|
require.NoError(t, err)
|
||||||
|
return conn
|
||||||
|
}
|
||||||
@@ -5,6 +5,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
|
"strconv"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -286,8 +287,8 @@ func (w *WorkerICE) connect(ctx context.Context, agent *icemaker.ThreadSafeAgent
|
|||||||
RosenpassAddr: remoteOfferAnswer.RosenpassAddr,
|
RosenpassAddr: remoteOfferAnswer.RosenpassAddr,
|
||||||
LocalIceCandidateType: pair.Local.Type().String(),
|
LocalIceCandidateType: pair.Local.Type().String(),
|
||||||
RemoteIceCandidateType: pair.Remote.Type().String(),
|
RemoteIceCandidateType: pair.Remote.Type().String(),
|
||||||
LocalIceCandidateEndpoint: fmt.Sprintf("%s:%d", pair.Local.Address(), pair.Local.Port()),
|
LocalIceCandidateEndpoint: net.JoinHostPort(pair.Local.Address(), strconv.Itoa(pair.Local.Port())),
|
||||||
RemoteIceCandidateEndpoint: fmt.Sprintf("%s:%d", pair.Remote.Address(), pair.Remote.Port()),
|
RemoteIceCandidateEndpoint: net.JoinHostPort(pair.Remote.Address(), strconv.Itoa(pair.Remote.Port())),
|
||||||
Relayed: isRelayed(pair),
|
Relayed: isRelayed(pair),
|
||||||
RelayedOnLocal: isRelayCandidate(pair.Local),
|
RelayedOnLocal: isRelayCandidate(pair.Local),
|
||||||
}
|
}
|
||||||
@@ -328,13 +329,7 @@ func (w *WorkerICE) closeAgent(agent *icemaker.ThreadSafeAgent, cancel context.C
|
|||||||
func (w *WorkerICE) punchRemoteWGPort(pair *ice.CandidatePair, remoteWgPort int) {
|
func (w *WorkerICE) punchRemoteWGPort(pair *ice.CandidatePair, remoteWgPort int) {
|
||||||
// wait local endpoint configuration
|
// wait local endpoint configuration
|
||||||
time.Sleep(time.Second)
|
time.Sleep(time.Second)
|
||||||
addrString := pair.Remote.Address()
|
addr, err := net.ResolveUDPAddr("udp", net.JoinHostPort(pair.Remote.Address(), strconv.Itoa(remoteWgPort)))
|
||||||
parsed, err := netip.ParseAddr(addrString)
|
|
||||||
if (err == nil) && (parsed.Is6()) {
|
|
||||||
addrString = fmt.Sprintf("[%s]", addrString)
|
|
||||||
//IPv6 Literals need to be wrapped in brackets for Resolve*Addr()
|
|
||||||
}
|
|
||||||
addr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("%s:%d", addrString, remoteWgPort))
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
w.log.Warnf("got an error while resolving the udp address, err: %s", err)
|
w.log.Warnf("got an error while resolving the udp address, err: %s", err)
|
||||||
return
|
return
|
||||||
@@ -386,12 +381,44 @@ func (w *WorkerICE) onICESelectedCandidatePair(agent *icemaker.ThreadSafeAgent,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (w *WorkerICE) logSuccessfulPaths(agent *icemaker.ThreadSafeAgent) {
|
||||||
|
sessionID := w.SessionID()
|
||||||
|
stats := agent.GetCandidatePairsStats()
|
||||||
|
localCandidates, _ := agent.GetLocalCandidates()
|
||||||
|
remoteCandidates, _ := agent.GetRemoteCandidates()
|
||||||
|
|
||||||
|
localMap := make(map[string]ice.Candidate)
|
||||||
|
for _, c := range localCandidates {
|
||||||
|
localMap[c.ID()] = c
|
||||||
|
}
|
||||||
|
remoteMap := make(map[string]ice.Candidate)
|
||||||
|
for _, c := range remoteCandidates {
|
||||||
|
remoteMap[c.ID()] = c
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, stat := range stats {
|
||||||
|
if stat.State == ice.CandidatePairStateSucceeded {
|
||||||
|
local, lok := localMap[stat.LocalCandidateID]
|
||||||
|
remote, rok := remoteMap[stat.RemoteCandidateID]
|
||||||
|
if !lok || !rok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
w.log.Debugf("successful ICE path %s: [%s %s %s] <-> [%s %s %s] rtt=%.3fms",
|
||||||
|
sessionID,
|
||||||
|
local.NetworkType(), local.Type(), local.Address(),
|
||||||
|
remote.NetworkType(), remote.Type(), remote.Address(),
|
||||||
|
stat.CurrentRoundTripTime*1000)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (w *WorkerICE) onConnectionStateChange(agent *icemaker.ThreadSafeAgent, dialerCancel context.CancelFunc) func(ice.ConnectionState) {
|
func (w *WorkerICE) onConnectionStateChange(agent *icemaker.ThreadSafeAgent, dialerCancel context.CancelFunc) func(ice.ConnectionState) {
|
||||||
return func(state ice.ConnectionState) {
|
return func(state ice.ConnectionState) {
|
||||||
w.log.Debugf("ICE ConnectionState has changed to %s", state.String())
|
w.log.Debugf("ICE ConnectionState has changed to %s", state.String())
|
||||||
switch state {
|
switch state {
|
||||||
case ice.ConnectionStateConnected:
|
case ice.ConnectionStateConnected:
|
||||||
w.lastKnownState = ice.ConnectionStateConnected
|
w.lastKnownState = ice.ConnectionStateConnected
|
||||||
|
w.logSuccessfulPaths(agent)
|
||||||
return
|
return
|
||||||
case ice.ConnectionStateFailed, ice.ConnectionStateDisconnected, ice.ConnectionStateClosed:
|
case ice.ConnectionStateFailed, ice.ConnectionStateDisconnected, ice.ConnectionStateClosed:
|
||||||
// ice.ConnectionStateClosed happens when we recreate the agent. For the P2P to TURN switch important to
|
// ice.ConnectionStateClosed happens when we recreate the agent. For the P2P to TURN switch important to
|
||||||
|
|||||||
@@ -154,9 +154,20 @@ func (s *SharedSocket) updateRouter() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// LocalAddr returns an IPv4 address using the supplied port
|
// LocalAddr returns the local address, preferring IPv4 for backward compatibility.
|
||||||
func (s *SharedSocket) LocalAddr() net.Addr {
|
func (s *SharedSocket) LocalAddr() net.Addr {
|
||||||
// todo check impact on ipv6 discovery
|
if s.conn4 != nil {
|
||||||
|
return &net.UDPAddr{
|
||||||
|
IP: net.IPv4zero,
|
||||||
|
Port: s.port,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if s.conn6 != nil {
|
||||||
|
return &net.UDPAddr{
|
||||||
|
IP: net.IPv6zero,
|
||||||
|
Port: s.port,
|
||||||
|
}
|
||||||
|
}
|
||||||
return &net.UDPAddr{
|
return &net.UDPAddr{
|
||||||
IP: net.IPv4zero,
|
IP: net.IPv4zero,
|
||||||
Port: s.port,
|
Port: s.port,
|
||||||
|
|||||||
Reference in New Issue
Block a user