mirror of
https://github.com/netbirdio/netbird.git
synced 2026-03-31 06:34:14 -04:00
480 lines
12 KiB
Go
480 lines
12 KiB
Go
package stun
|
|
|
|
import (
|
|
"errors"
|
|
"fmt"
|
|
"math/rand"
|
|
"net"
|
|
"os"
|
|
"sync"
|
|
"sync/atomic"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/pion/stun/v3"
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
)
|
|
|
|
// createTestServer creates a STUN server listening on a random port for testing.
|
|
// Returns the server, the listener connection (caller must close), and the server address.
|
|
func createTestServer(t testing.TB) (*Server, *net.UDPConn, *net.UDPAddr) {
|
|
t.Helper()
|
|
conn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 0})
|
|
require.NoError(t, err)
|
|
server := NewServer([]*net.UDPConn{conn}, "debug")
|
|
return server, conn, conn.LocalAddr().(*net.UDPAddr)
|
|
}
|
|
|
|
// waitForServerReady polls the server with STUN binding requests until it responds.
|
|
// This avoids flaky tests on slow CI machines that relied on time.Sleep.
|
|
func waitForServerReady(t testing.TB, serverAddr *net.UDPAddr, timeout time.Duration) {
|
|
t.Helper()
|
|
deadline := time.Now().Add(timeout)
|
|
retryInterval := 10 * time.Millisecond
|
|
|
|
clientConn, err := net.DialUDP("udp", nil, serverAddr)
|
|
require.NoError(t, err)
|
|
defer clientConn.Close()
|
|
|
|
buf := make([]byte, 1500)
|
|
for time.Now().Before(deadline) {
|
|
msg, err := stun.Build(stun.TransactionID, stun.BindingRequest)
|
|
require.NoError(t, err)
|
|
|
|
_, err = clientConn.Write(msg.Raw)
|
|
require.NoError(t, err)
|
|
|
|
_ = clientConn.SetReadDeadline(time.Now().Add(retryInterval))
|
|
n, err := clientConn.Read(buf)
|
|
if err != nil {
|
|
// Timeout or other error, retry
|
|
continue
|
|
}
|
|
|
|
response := &stun.Message{Raw: buf[:n]}
|
|
if err := response.Decode(); err != nil {
|
|
continue
|
|
}
|
|
|
|
if response.Type == stun.BindingSuccess {
|
|
return // Server is ready
|
|
}
|
|
}
|
|
|
|
t.Fatalf("server did not become ready within %v", timeout)
|
|
}
|
|
|
|
func TestServer_BindingRequest(t *testing.T) {
|
|
// Start the STUN server on a random port
|
|
server, listener, serverAddr := createTestServer(t)
|
|
|
|
// Start server in background
|
|
serverErrCh := make(chan error, 1)
|
|
go func() {
|
|
serverErrCh <- server.Listen()
|
|
}()
|
|
|
|
// Wait for server to be ready
|
|
waitForServerReady(t, serverAddr, 2*time.Second)
|
|
|
|
// Create a UDP client
|
|
clientConn, err := net.DialUDP("udp", nil, serverAddr)
|
|
require.NoError(t, err)
|
|
defer clientConn.Close()
|
|
|
|
// Build a STUN binding request
|
|
msg, err := stun.Build(stun.TransactionID, stun.BindingRequest)
|
|
require.NoError(t, err)
|
|
|
|
// Send the request
|
|
_, err = clientConn.Write(msg.Raw)
|
|
require.NoError(t, err)
|
|
|
|
// Read the response
|
|
buf := make([]byte, 1500)
|
|
_ = clientConn.SetReadDeadline(time.Now().Add(2 * time.Second))
|
|
n, err := clientConn.Read(buf)
|
|
require.NoError(t, err)
|
|
|
|
// Parse the response
|
|
response := &stun.Message{Raw: buf[:n]}
|
|
err = response.Decode()
|
|
require.NoError(t, err)
|
|
|
|
// Verify it's a binding success
|
|
assert.Equal(t, stun.BindingSuccess, response.Type)
|
|
|
|
// Extract the XOR-MAPPED-ADDRESS
|
|
var xorAddr stun.XORMappedAddress
|
|
err = xorAddr.GetFrom(response)
|
|
require.NoError(t, err)
|
|
|
|
// Verify the address matches our client's local address
|
|
clientAddr := clientConn.LocalAddr().(*net.UDPAddr)
|
|
assert.Equal(t, clientAddr.IP.String(), xorAddr.IP.String())
|
|
assert.Equal(t, clientAddr.Port, xorAddr.Port)
|
|
|
|
// Close listener first to unblock readLoop, then shutdown
|
|
_ = listener.Close()
|
|
err = server.Shutdown()
|
|
require.NoError(t, err)
|
|
}
|
|
|
|
func TestServer_IgnoresNonSTUNPackets(t *testing.T) {
|
|
server, listener, serverAddr := createTestServer(t)
|
|
|
|
go func() {
|
|
_ = server.Listen()
|
|
}()
|
|
|
|
waitForServerReady(t, serverAddr, 2*time.Second)
|
|
|
|
clientConn, err := net.DialUDP("udp", nil, serverAddr)
|
|
require.NoError(t, err)
|
|
defer clientConn.Close()
|
|
|
|
// Send non-STUN data
|
|
_, err = clientConn.Write([]byte("hello world"))
|
|
require.NoError(t, err)
|
|
|
|
// Try to read response (should timeout since server ignores non-STUN)
|
|
buf := make([]byte, 1500)
|
|
_ = clientConn.SetReadDeadline(time.Now().Add(200 * time.Millisecond))
|
|
_, err = clientConn.Read(buf)
|
|
assert.Error(t, err) // Should be a timeout error
|
|
|
|
// Close listener first to unblock readLoop, then shutdown
|
|
_ = listener.Close()
|
|
_ = server.Shutdown()
|
|
}
|
|
|
|
func TestServer_Shutdown(t *testing.T) {
|
|
server, listener, serverAddr := createTestServer(t)
|
|
|
|
serverDone := make(chan struct{})
|
|
go func() {
|
|
err := server.Listen()
|
|
assert.True(t, errors.Is(err, ErrServerClosed))
|
|
close(serverDone)
|
|
}()
|
|
|
|
waitForServerReady(t, serverAddr, 2*time.Second)
|
|
|
|
// Close listener first to unblock readLoop, then shutdown
|
|
_ = listener.Close()
|
|
|
|
err := server.Shutdown()
|
|
require.NoError(t, err)
|
|
|
|
// Wait for Listen to return
|
|
select {
|
|
case <-serverDone:
|
|
// Success
|
|
case <-time.After(3 * time.Second):
|
|
t.Fatal("server did not shutdown in time")
|
|
}
|
|
}
|
|
|
|
func TestServer_MultipleRequests(t *testing.T) {
|
|
server, listener, serverAddr := createTestServer(t)
|
|
|
|
go func() {
|
|
_ = server.Listen()
|
|
}()
|
|
|
|
waitForServerReady(t, serverAddr, 2*time.Second)
|
|
|
|
// Create multiple clients and send requests
|
|
for i := 0; i < 5; i++ {
|
|
func() {
|
|
clientConn, err := net.DialUDP("udp", nil, serverAddr)
|
|
require.NoError(t, err)
|
|
defer clientConn.Close()
|
|
|
|
msg, err := stun.Build(stun.TransactionID, stun.BindingRequest)
|
|
require.NoError(t, err)
|
|
|
|
_, err = clientConn.Write(msg.Raw)
|
|
require.NoError(t, err)
|
|
|
|
buf := make([]byte, 1500)
|
|
_ = clientConn.SetReadDeadline(time.Now().Add(2 * time.Second))
|
|
n, err := clientConn.Read(buf)
|
|
require.NoError(t, err)
|
|
|
|
response := &stun.Message{Raw: buf[:n]}
|
|
err = response.Decode()
|
|
require.NoError(t, err)
|
|
|
|
assert.Equal(t, stun.BindingSuccess, response.Type)
|
|
}()
|
|
}
|
|
|
|
// Close listener first to unblock readLoop, then shutdown
|
|
_ = listener.Close()
|
|
_ = server.Shutdown()
|
|
}
|
|
|
|
func TestServer_ConcurrentClients(t *testing.T) {
|
|
numClients := 100
|
|
requestsPerClient := 5
|
|
maxStartDelay := 100 * time.Millisecond // Random delay before client starts
|
|
maxRequestDelay := 500 * time.Millisecond // Random delay between requests
|
|
|
|
// Remote server to test against via env var STUN_TEST_SERVER
|
|
// Example: STUN_TEST_SERVER=example.netbird.io:3478 go test -v ./stun/... -run ConcurrentClients
|
|
remoteServer := os.Getenv("STUN_TEST_SERVER")
|
|
|
|
var serverAddr *net.UDPAddr
|
|
var server *Server
|
|
var listener *net.UDPConn
|
|
|
|
if remoteServer != "" {
|
|
// Use remote server
|
|
var err error
|
|
serverAddr, err = net.ResolveUDPAddr("udp", remoteServer)
|
|
require.NoError(t, err)
|
|
t.Logf("Testing against remote server: %s", remoteServer)
|
|
} else {
|
|
// Start local server
|
|
server, listener, serverAddr = createTestServer(t)
|
|
go func() {
|
|
_ = server.Listen()
|
|
}()
|
|
waitForServerReady(t, serverAddr, 2*time.Second)
|
|
t.Logf("Testing against local server: %s", serverAddr)
|
|
}
|
|
|
|
var wg sync.WaitGroup
|
|
errorz := make(chan error, numClients*requestsPerClient)
|
|
successCount := make(chan int, numClients)
|
|
|
|
startTime := time.Now()
|
|
|
|
for i := 0; i < numClients; i++ {
|
|
wg.Add(1)
|
|
go func(clientID int) {
|
|
defer wg.Done()
|
|
|
|
// Random delay before starting
|
|
time.Sleep(time.Duration(rand.Int63n(int64(maxStartDelay))))
|
|
|
|
clientConn, err := net.DialUDP("udp", nil, serverAddr)
|
|
if err != nil {
|
|
errorz <- fmt.Errorf("client %d: failed to dial: %w", clientID, err)
|
|
return
|
|
}
|
|
defer clientConn.Close()
|
|
|
|
success := 0
|
|
for j := 0; j < requestsPerClient; j++ {
|
|
// Random delay between requests
|
|
if j > 0 {
|
|
time.Sleep(time.Duration(rand.Int63n(int64(maxRequestDelay))))
|
|
}
|
|
|
|
msg, err := stun.Build(stun.TransactionID, stun.BindingRequest)
|
|
if err != nil {
|
|
errorz <- fmt.Errorf("client %d: failed to build request: %w", clientID, err)
|
|
continue
|
|
}
|
|
|
|
_, err = clientConn.Write(msg.Raw)
|
|
if err != nil {
|
|
errorz <- fmt.Errorf("client %d: failed to write: %w", clientID, err)
|
|
continue
|
|
}
|
|
|
|
buf := make([]byte, 1500)
|
|
_ = clientConn.SetReadDeadline(time.Now().Add(5 * time.Second))
|
|
n, err := clientConn.Read(buf)
|
|
if err != nil {
|
|
errorz <- fmt.Errorf("client %d: failed to read: %w", clientID, err)
|
|
continue
|
|
}
|
|
|
|
response := &stun.Message{Raw: buf[:n]}
|
|
if err := response.Decode(); err != nil {
|
|
errorz <- fmt.Errorf("client %d: failed to decode: %w", clientID, err)
|
|
continue
|
|
}
|
|
|
|
if response.Type != stun.BindingSuccess {
|
|
errorz <- fmt.Errorf("client %d: unexpected response type: %s", clientID, response.Type)
|
|
continue
|
|
}
|
|
|
|
success++
|
|
}
|
|
successCount <- success
|
|
}(i)
|
|
}
|
|
|
|
wg.Wait()
|
|
close(errorz)
|
|
close(successCount)
|
|
|
|
elapsed := time.Since(startTime)
|
|
|
|
totalSuccess := 0
|
|
for count := range successCount {
|
|
totalSuccess += count
|
|
}
|
|
|
|
var errs []error
|
|
for err := range errorz {
|
|
errs = append(errs, err)
|
|
}
|
|
|
|
totalRequests := numClients * requestsPerClient
|
|
t.Logf("Completed %d/%d requests in %v (%.2f req/s)",
|
|
totalSuccess, totalRequests, elapsed,
|
|
float64(totalSuccess)/elapsed.Seconds())
|
|
|
|
if len(errs) > 0 {
|
|
t.Logf("Errors (%d):", len(errs))
|
|
for i, err := range errs {
|
|
if i < 10 { // Only show first 10 errors
|
|
t.Logf(" - %v", err)
|
|
}
|
|
}
|
|
}
|
|
|
|
// Require at least 95% success rate
|
|
successRate := float64(totalSuccess) / float64(totalRequests)
|
|
require.GreaterOrEqual(t, successRate, 0.95, "success rate too low: %.2f%%", successRate*100)
|
|
|
|
// Cleanup local server if used
|
|
if server != nil {
|
|
// Close listener first to unblock readLoop, then shutdown
|
|
_ = listener.Close()
|
|
_ = server.Shutdown()
|
|
}
|
|
}
|
|
|
|
func TestServer_MultiplePorts(t *testing.T) {
|
|
// Create listeners on two random ports
|
|
conn1, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 0})
|
|
require.NoError(t, err)
|
|
conn2, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 0})
|
|
require.NoError(t, err)
|
|
|
|
addr1 := conn1.LocalAddr().(*net.UDPAddr)
|
|
addr2 := conn2.LocalAddr().(*net.UDPAddr)
|
|
|
|
server := NewServer([]*net.UDPConn{conn1, conn2}, "debug")
|
|
|
|
go func() {
|
|
_ = server.Listen()
|
|
}()
|
|
|
|
// Wait for server to be ready (checking first port is sufficient)
|
|
waitForServerReady(t, addr1, 2*time.Second)
|
|
|
|
// Test requests on both ports
|
|
for _, serverAddr := range []*net.UDPAddr{addr1, addr2} {
|
|
func() {
|
|
clientConn, err := net.DialUDP("udp", nil, serverAddr)
|
|
require.NoError(t, err)
|
|
defer clientConn.Close()
|
|
|
|
msg, err := stun.Build(stun.TransactionID, stun.BindingRequest)
|
|
require.NoError(t, err)
|
|
|
|
_, err = clientConn.Write(msg.Raw)
|
|
require.NoError(t, err)
|
|
|
|
buf := make([]byte, 1500)
|
|
_ = clientConn.SetReadDeadline(time.Now().Add(2 * time.Second))
|
|
n, err := clientConn.Read(buf)
|
|
require.NoError(t, err)
|
|
|
|
response := &stun.Message{Raw: buf[:n]}
|
|
err = response.Decode()
|
|
require.NoError(t, err)
|
|
|
|
assert.Equal(t, stun.BindingSuccess, response.Type)
|
|
|
|
var xorAddr stun.XORMappedAddress
|
|
err = xorAddr.GetFrom(response)
|
|
require.NoError(t, err)
|
|
|
|
clientAddr := clientConn.LocalAddr().(*net.UDPAddr)
|
|
assert.Equal(t, clientAddr.Port, xorAddr.Port)
|
|
}()
|
|
}
|
|
|
|
// Close listeners first to unblock readLoops, then shutdown
|
|
_ = conn1.Close()
|
|
_ = conn2.Close()
|
|
_ = server.Shutdown()
|
|
}
|
|
|
|
// BenchmarkSTUNServer benchmarks the STUN server with concurrent clients
|
|
func BenchmarkSTUNServer(b *testing.B) {
|
|
server, listener, serverAddr := createTestServer(b)
|
|
|
|
go func() {
|
|
_ = server.Listen()
|
|
}()
|
|
|
|
waitForServerReady(b, serverAddr, 2*time.Second)
|
|
|
|
// Capture first error atomically - b.Fatal cannot be called from worker goroutines
|
|
var firstErr atomic.Pointer[error]
|
|
setErr := func(err error) {
|
|
firstErr.CompareAndSwap(nil, &err)
|
|
}
|
|
|
|
b.ResetTimer()
|
|
b.RunParallel(func(pb *testing.PB) {
|
|
// Stop work if an error has occurred
|
|
if firstErr.Load() != nil {
|
|
return
|
|
}
|
|
|
|
clientConn, err := net.DialUDP("udp", nil, serverAddr)
|
|
if err != nil {
|
|
setErr(err)
|
|
return
|
|
}
|
|
defer clientConn.Close()
|
|
|
|
buf := make([]byte, 1500)
|
|
|
|
for pb.Next() {
|
|
if firstErr.Load() != nil {
|
|
return
|
|
}
|
|
|
|
msg, _ := stun.Build(stun.TransactionID, stun.BindingRequest)
|
|
_, _ = clientConn.Write(msg.Raw)
|
|
|
|
_ = clientConn.SetReadDeadline(time.Now().Add(2 * time.Second))
|
|
n, err := clientConn.Read(buf)
|
|
if err != nil {
|
|
setErr(err)
|
|
return
|
|
}
|
|
|
|
response := &stun.Message{Raw: buf[:n]}
|
|
if err := response.Decode(); err != nil {
|
|
setErr(err)
|
|
return
|
|
}
|
|
}
|
|
})
|
|
|
|
b.StopTimer()
|
|
|
|
// Fail after RunParallel completes
|
|
if errPtr := firstErr.Load(); errPtr != nil {
|
|
b.Fatal(*errPtr)
|
|
}
|
|
|
|
// Close listener first to unblock readLoop, then shutdown
|
|
_ = listener.Close()
|
|
_ = server.Shutdown()
|
|
}
|