Compare commits: cli-ws-pro...loadtest-s (20 commits)
| SHA1 |
|---|
| 4787e28ae3 |
| f9a71e98c7 |
| 26d1a9b68a |
| 3d983ddc60 |
| 9217df05eb |
| 213043fe7a |
| e67829d1d7 |
| 2288664fe7 |
| d1153b5b5d |
| 6d26c9d1ba |
| d35a845dbd |
| 4e03f708a4 |
| 654aa9581d |
| 9021bb512b |
| 768332820e |
| 229c65ffa1 |
| 4d33567888 |
| 88467883fc |
| 954f40991f |
| 34341d95a9 |
@@ -29,7 +29,8 @@ func Backoff(ctx context.Context) backoff.BackOff {
 // The component parameter specifies the WebSocket proxy component path (e.g., "/management", "/signal").
 func CreateConnection(ctx context.Context, addr string, tlsEnabled bool, component string) (*grpc.ClientConn, error) {
     transportOption := grpc.WithTransportCredentials(insecure.NewCredentials())
-    if tlsEnabled {
+    // for js, the outer websocket layer takes care of tls
+    if tlsEnabled && runtime.GOOS != "js" {
         certPool, err := x509.SystemCertPool()
         if err != nil || certPool == nil {
             log.Debugf("System cert pool not available; falling back to embedded cert, error: %v", err)
@@ -37,9 +38,7 @@ func CreateConnection(ctx context.Context, addr string, tlsEnabled bool, compone
         }

         transportOption = grpc.WithTransportCredentials(credentials.NewTLS(&tls.Config{
-            // for js, outer websocket layer takes care of tls verification via WithCustomDialer
-            InsecureSkipVerify: runtime.GOOS == "js",
-            RootCAs:            certPool,
+            RootCAs: certPool,
         }))
     }
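For context, a minimal sketch of how this helper might be called from the same package; the server address and the "/signal" component path are illustrative assumptions, only the `CreateConnection` signature comes from the diff above:

```go
package wsproxy // assumed package name, for illustration only

import (
	"context"
	"fmt"

	"google.golang.org/grpc"
)

// dialSignal sketches a call to CreateConnection from the diff above.
// "signal.example.com:443" is a placeholder, not taken from the diff.
func dialSignal(ctx context.Context) (*grpc.ClientConn, error) {
	conn, err := CreateConnection(ctx, "signal.example.com:443", true, "/signal")
	if err != nil {
		return nil, fmt.Errorf("create ws-proxy connection: %w", err)
	}
	return conn, nil
}
```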
@@ -73,6 +73,44 @@ func (c *KernelConfigurer) UpdatePeer(peerKey string, allowedIps []netip.Prefix,
     return nil
 }

+func (c *KernelConfigurer) RemoveEndpointAddress(peerKey string) error {
+    peerKeyParsed, err := wgtypes.ParseKey(peerKey)
+    if err != nil {
+        return err
+    }
+
+    // Get the existing peer to preserve its allowed IPs
+    existingPeer, err := c.getPeer(c.deviceName, peerKey)
+    if err != nil {
+        return fmt.Errorf("get peer: %w", err)
+    }
+
+    removePeerCfg := wgtypes.PeerConfig{
+        PublicKey: peerKeyParsed,
+        Remove:    true,
+    }
+
+    if err := c.configure(wgtypes.Config{Peers: []wgtypes.PeerConfig{removePeerCfg}}); err != nil {
+        return fmt.Errorf(`error removing peer %s from interface %s: %w`, peerKey, c.deviceName, err)
+    }
+
+    // Re-add the peer without the endpoint but same AllowedIPs
+    reAddPeerCfg := wgtypes.PeerConfig{
+        PublicKey:         peerKeyParsed,
+        AllowedIPs:        existingPeer.AllowedIPs,
+        ReplaceAllowedIPs: true,
+    }
+
+    if err := c.configure(wgtypes.Config{Peers: []wgtypes.PeerConfig{reAddPeerCfg}}); err != nil {
+        return fmt.Errorf(
+            `error re-adding peer %s to interface %s with allowed IPs %v: %w`,
+            peerKey, c.deviceName, existingPeer.AllowedIPs, err,
+        )
+    }
+
+    return nil
+}
+
 func (c *KernelConfigurer) RemovePeer(peerKey string) error {
     peerKeyParsed, err := wgtypes.ParseKey(peerKey)
     if err != nil {
@@ -106,6 +106,67 @@ func (c *WGUSPConfigurer) UpdatePeer(peerKey string, allowedIps []netip.Prefix,
     return nil
 }

+func (c *WGUSPConfigurer) RemoveEndpointAddress(peerKey string) error {
+    peerKeyParsed, err := wgtypes.ParseKey(peerKey)
+    if err != nil {
+        return fmt.Errorf("parse peer key: %w", err)
+    }
+
+    ipcStr, err := c.device.IpcGet()
+    if err != nil {
+        return fmt.Errorf("get IPC config: %w", err)
+    }
+
+    // Parse current status to get allowed IPs for the peer
+    stats, err := parseStatus(c.deviceName, ipcStr)
+    if err != nil {
+        return fmt.Errorf("parse IPC config: %w", err)
+    }
+
+    var allowedIPs []net.IPNet
+    found := false
+    for _, peer := range stats.Peers {
+        if peer.PublicKey == peerKey {
+            allowedIPs = peer.AllowedIPs
+            found = true
+            break
+        }
+    }
+    if !found {
+        return fmt.Errorf("peer %s not found", peerKey)
+    }
+
+    // remove the peer from the WireGuard configuration
+    peer := wgtypes.PeerConfig{
+        PublicKey: peerKeyParsed,
+        Remove:    true,
+    }
+
+    config := wgtypes.Config{
+        Peers: []wgtypes.PeerConfig{peer},
+    }
+    if ipcErr := c.device.IpcSet(toWgUserspaceString(config)); ipcErr != nil {
+        return fmt.Errorf("failed to remove peer: %s", ipcErr)
+    }
+
+    // Build the peer config
+    peer = wgtypes.PeerConfig{
+        PublicKey:         peerKeyParsed,
+        ReplaceAllowedIPs: true,
+        AllowedIPs:        allowedIPs,
+    }
+
+    config = wgtypes.Config{
+        Peers: []wgtypes.PeerConfig{peer},
+    }
+
+    if err := c.device.IpcSet(toWgUserspaceString(config)); err != nil {
+        return fmt.Errorf("remove endpoint address: %w", err)
+    }
+
+    return nil
+}
+
 func (c *WGUSPConfigurer) RemovePeer(peerKey string) error {
     peerKeyParsed, err := wgtypes.ParseKey(peerKey)
     if err != nil {
@@ -21,4 +21,5 @@ type WGConfigurer interface {
     GetStats() (map[string]configurer.WGStats, error)
     FullStats() (*configurer.Stats, error)
     LastActivities() map[string]monotime.Time
+    RemoveEndpointAddress(peerKey string) error
 }
@@ -148,6 +148,17 @@ func (w *WGIface) UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAliv
     return w.configurer.UpdatePeer(peerKey, allowedIps, keepAlive, endpoint, preSharedKey)
 }

+func (w *WGIface) RemoveEndpointAddress(peerKey string) error {
+    w.mu.Lock()
+    defer w.mu.Unlock()
+    if w.configurer == nil {
+        return ErrIfaceNotFound
+    }
+
+    log.Debugf("Removing endpoint address: %s", peerKey)
+    return w.configurer.RemoveEndpointAddress(peerKey)
+}
+
 // RemovePeer removes a Wireguard Peer from the interface iface
 func (w *WGIface) RemovePeer(peerKey string) error {
     w.mu.Lock()
client/internal/dnsfwd/cache.go (new file, 78 lines)
@@ -0,0 +1,78 @@
package dnsfwd

import (
    "net/netip"
    "slices"
    "strings"
    "sync"

    "github.com/miekg/dns"
)

type cache struct {
    mu      sync.RWMutex
    records map[string]*cacheEntry
}

type cacheEntry struct {
    ip4Addrs []netip.Addr
    ip6Addrs []netip.Addr
}

func newCache() *cache {
    return &cache{
        records: make(map[string]*cacheEntry),
    }
}

func (c *cache) get(domain string, reqType uint16) ([]netip.Addr, bool) {
    c.mu.RLock()
    defer c.mu.RUnlock()

    entry, exists := c.records[normalizeDomain(domain)]
    if !exists {
        return nil, false
    }

    switch reqType {
    case dns.TypeA:
        return slices.Clone(entry.ip4Addrs), true
    case dns.TypeAAAA:
        return slices.Clone(entry.ip6Addrs), true
    default:
        return nil, false
    }
}

func (c *cache) set(domain string, reqType uint16, addrs []netip.Addr) {
    c.mu.Lock()
    defer c.mu.Unlock()
    norm := normalizeDomain(domain)
    entry, exists := c.records[norm]
    if !exists {
        entry = &cacheEntry{}
        c.records[norm] = entry
    }

    switch reqType {
    case dns.TypeA:
        entry.ip4Addrs = slices.Clone(addrs)
    case dns.TypeAAAA:
        entry.ip6Addrs = slices.Clone(addrs)
    }
}

// unset removes all cached entries for the given domain.
func (c *cache) unset(domain string) {
    c.mu.Lock()
    defer c.mu.Unlock()
    delete(c.records, normalizeDomain(domain))
}

// normalizeDomain converts an input domain into a canonical form used as cache key:
// lowercase and fully-qualified (with trailing dot).
func normalizeDomain(domain string) string {
    // dns.Fqdn ensures trailing dot; ToLower for consistent casing
    return dns.Fqdn(strings.ToLower(domain))
}
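A minimal usage sketch of the cache API defined above; the domain and addresses are illustrative, and `normalizeDomain` makes both lookups below hit the same entry despite casing and trailing-dot differences:

```go
package dnsfwd

import (
	"net/netip"

	"github.com/miekg/dns"
)

// exampleCacheUsage exercises set/get/unset from the file above.
func exampleCacheUsage() {
	c := newCache()
	c.set("Example.COM", dns.TypeA, []netip.Addr{netip.MustParseAddr("1.2.3.4")})

	if addrs, ok := c.get("example.com.", dns.TypeA); ok {
		_ = addrs // returned slice is a clone; mutating it does not affect the cache
	}

	c.unset("example.com") // drops both the A and AAAA entries for the domain
}
```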
client/internal/dnsfwd/cache_test.go (new file, 86 lines)
@@ -0,0 +1,86 @@
package dnsfwd

import (
    "net/netip"
    "testing"
)

func mustAddr(t *testing.T, s string) netip.Addr {
    t.Helper()
    a, err := netip.ParseAddr(s)
    if err != nil {
        t.Fatalf("parse addr %s: %v", s, err)
    }
    return a
}

func TestCacheNormalization(t *testing.T) {
    c := newCache()

    // Mixed case, without trailing dot
    domainInput := "ExAmPlE.CoM"
    ipv4 := []netip.Addr{mustAddr(t, "1.2.3.4")}
    c.set(domainInput, 1 /* dns.TypeA */, ipv4)

    // Lookup with lower case, with trailing dot
    if got, ok := c.get("example.com.", 1); !ok || len(got) != 1 || got[0].String() != "1.2.3.4" {
        t.Fatalf("expected cached IPv4 result via normalized key, got=%v ok=%v", got, ok)
    }

    // Lookup with different casing again
    if got, ok := c.get("EXAMPLE.COM", 1); !ok || len(got) != 1 || got[0].String() != "1.2.3.4" {
        t.Fatalf("expected cached IPv4 result via different casing, got=%v ok=%v", got, ok)
    }
}

func TestCacheSeparateTypes(t *testing.T) {
    c := newCache()

    domain := "test.local"
    ipv4 := []netip.Addr{mustAddr(t, "10.0.0.1")}
    ipv6 := []netip.Addr{mustAddr(t, "2001:db8::1")}

    c.set(domain, 1 /* A */, ipv4)
    c.set(domain, 28 /* AAAA */, ipv6)

    got4, ok4 := c.get(domain, 1)
    if !ok4 || len(got4) != 1 || got4[0] != ipv4[0] {
        t.Fatalf("expected A record from cache, got=%v ok=%v", got4, ok4)
    }

    got6, ok6 := c.get(domain, 28)
    if !ok6 || len(got6) != 1 || got6[0] != ipv6[0] {
        t.Fatalf("expected AAAA record from cache, got=%v ok=%v", got6, ok6)
    }
}

func TestCacheCloneOnGetAndSet(t *testing.T) {
    c := newCache()
    domain := "clone.test"

    src := []netip.Addr{mustAddr(t, "8.8.8.8")}
    c.set(domain, 1, src)

    // Mutate source slice; cache should be unaffected
    src[0] = mustAddr(t, "9.9.9.9")

    got, ok := c.get(domain, 1)
    if !ok || len(got) != 1 || got[0].String() != "8.8.8.8" {
        t.Fatalf("expected cached value to be independent of source slice, got=%v ok=%v", got, ok)
    }

    // Mutate returned slice; internal cache should remain unchanged
    got[0] = mustAddr(t, "4.4.4.4")
    got2, ok2 := c.get(domain, 1)
    if !ok2 || len(got2) != 1 || got2[0].String() != "8.8.8.8" {
        t.Fatalf("expected returned slice to be a clone, got=%v ok=%v", got2, ok2)
    }
}

func TestCacheMiss(t *testing.T) {
    c := newCache()
    if got, ok := c.get("missing.example", 1); ok || got != nil {
        t.Fatalf("expected cache miss, got=%v ok=%v", got, ok)
    }
}
@@ -46,6 +46,7 @@ type DNSForwarder struct {
     fwdEntries []*ForwarderEntry
     firewall   firewaller
     resolver   resolver
+    cache      *cache
 }

 func NewDNSForwarder(listenAddress string, ttl uint32, firewall firewaller, statusRecorder *peer.Status) *DNSForwarder {
@@ -56,6 +57,7 @@ func NewDNSForwarder(listenAddress string, ttl uint32, firewall firewaller, stat
         firewall:       firewall,
         statusRecorder: statusRecorder,
         resolver:       net.DefaultResolver,
+        cache:          newCache(),
     }
 }
@@ -103,10 +105,39 @@ func (f *DNSForwarder) UpdateDomains(entries []*ForwarderEntry) {
     f.mutex.Lock()
     defer f.mutex.Unlock()

+    // remove cache entries for domains that no longer appear
+    f.removeStaleCacheEntries(f.fwdEntries, entries)
+
     f.fwdEntries = entries
     log.Debugf("Updated DNS forwarder with %d domains", len(entries))
 }

+// removeStaleCacheEntries unsets cache items for domains that were present
+// in the old list but not present in the new list.
+func (f *DNSForwarder) removeStaleCacheEntries(oldEntries, newEntries []*ForwarderEntry) {
+    if f.cache == nil {
+        return
+    }
+
+    newSet := make(map[string]struct{}, len(newEntries))
+    for _, e := range newEntries {
+        if e == nil {
+            continue
+        }
+        newSet[e.Domain.PunycodeString()] = struct{}{}
+    }
+
+    for _, e := range oldEntries {
+        if e == nil {
+            continue
+        }
+        pattern := e.Domain.PunycodeString()
+        if _, ok := newSet[pattern]; !ok {
+            f.cache.unset(pattern)
+        }
+    }
+}
+
 func (f *DNSForwarder) Close(ctx context.Context) error {
     var result *multierror.Error
@@ -171,6 +202,7 @@ func (f *DNSForwarder) handleDNSQuery(w dns.ResponseWriter, query *dns.Msg) *dns

     f.updateInternalState(ips, mostSpecificResId, matchingEntries)
     f.addIPsToResponse(resp, domain, ips)
+    f.cache.set(domain, question.Qtype, ips)

     return resp
 }
@@ -282,29 +314,69 @@ func (f *DNSForwarder) setResponseCodeForNotFound(ctx context.Context, resp *dns
     resp.Rcode = dns.RcodeSuccess
 }

-// handleDNSError processes DNS lookup errors and sends an appropriate error response
-func (f *DNSForwarder) handleDNSError(ctx context.Context, w dns.ResponseWriter, question dns.Question, resp *dns.Msg, domain string, err error) {
-    var dnsErr *net.DNSError
-
-    switch {
-    case errors.As(err, &dnsErr):
-        resp.Rcode = dns.RcodeServerFailure
-        if dnsErr.IsNotFound {
-            f.setResponseCodeForNotFound(ctx, resp, domain, question.Qtype)
-        }
-
-        if dnsErr.Server != "" {
-            log.Warnf("failed to resolve query for type=%s domain=%s server=%s: %v", dns.TypeToString[question.Qtype], domain, dnsErr.Server, err)
-        } else {
-            log.Warnf(errResolveFailed, domain, err)
-        }
-    default:
-        resp.Rcode = dns.RcodeServerFailure
-    }
-
-    if err := w.WriteMsg(resp); err != nil {
-        log.Errorf("failed to write failure DNS response: %v", err)
-    }
+// handleDNSError processes DNS lookup errors and sends an appropriate error response.
+func (f *DNSForwarder) handleDNSError(
+    ctx context.Context,
+    w dns.ResponseWriter,
+    question dns.Question,
+    resp *dns.Msg,
+    domain string,
+    err error,
+) {
+    // Default to SERVFAIL; override below when appropriate.
+    resp.Rcode = dns.RcodeServerFailure
+
+    qType := question.Qtype
+    qTypeName := dns.TypeToString[qType]
+
+    // Prefer typed DNS errors; fall back to generic logging otherwise.
+    var dnsErr *net.DNSError
+    if !errors.As(err, &dnsErr) {
+        log.Warnf(errResolveFailed, domain, err)
+        if writeErr := w.WriteMsg(resp); writeErr != nil {
+            log.Errorf("failed to write failure DNS response: %v", writeErr)
+        }
+        return
+    }
+
+    // NotFound: set NXDOMAIN / appropriate code via helper.
+    if dnsErr.IsNotFound {
+        f.setResponseCodeForNotFound(ctx, resp, domain, qType)
+        if writeErr := w.WriteMsg(resp); writeErr != nil {
+            log.Errorf("failed to write failure DNS response: %v", writeErr)
+        }
+        f.cache.set(domain, question.Qtype, nil)
+        return
+    }
+
+    // Upstream failed but we might have a cached answer—serve it if present.
+    if ips, ok := f.cache.get(domain, qType); ok {
+        if len(ips) > 0 {
+            log.Debugf("serving cached DNS response after upstream failure: domain=%s type=%s", domain, qTypeName)
+            f.addIPsToResponse(resp, domain, ips)
+            resp.Rcode = dns.RcodeSuccess
+            if writeErr := w.WriteMsg(resp); writeErr != nil {
+                log.Errorf("failed to write cached DNS response: %v", writeErr)
+            }
+        } else { // send NXDOMAIN / appropriate code if cache is empty
+            f.setResponseCodeForNotFound(ctx, resp, domain, qType)
+            if writeErr := w.WriteMsg(resp); writeErr != nil {
+                log.Errorf("failed to write failure DNS response: %v", writeErr)
+            }
+        }
+        return
+    }
+
+    // No cache. Log with or without the server field for more context.
+    if dnsErr.Server != "" {
+        log.Warnf("failed to resolve: type=%s domain=%s server=%s: %v", qTypeName, domain, dnsErr.Server, err)
+    } else {
+        log.Warnf(errResolveFailed, domain, err)
+    }
+
+    // Write final failure response.
+    if writeErr := w.WriteMsg(resp); writeErr != nil {
+        log.Errorf("failed to write failure DNS response: %v", writeErr)
+    }
 }
@@ -648,6 +648,95 @@ func TestDNSForwarder_TCPTruncation(t *testing.T) {
     assert.LessOrEqual(t, writtenResp.Len(), dns.MinMsgSize, "Response should fit in minimum UDP size")
 }

+// Ensures that when the first query succeeds and populates the cache,
+// a subsequent upstream failure still returns a successful response from cache.
+func TestDNSForwarder_ServeFromCacheOnUpstreamFailure(t *testing.T) {
+    mockResolver := &MockResolver{}
+    forwarder := NewDNSForwarder("127.0.0.1:0", 300, nil, &peer.Status{})
+    forwarder.resolver = mockResolver
+
+    d, err := domain.FromString("example.com")
+    require.NoError(t, err)
+    entries := []*ForwarderEntry{{Domain: d, ResID: "res-cache"}}
+    forwarder.UpdateDomains(entries)
+
+    ip := netip.MustParseAddr("1.2.3.4")
+
+    // First call resolves successfully and populates cache
+    mockResolver.On("LookupNetIP", mock.Anything, "ip4", dns.Fqdn("example.com")).
+        Return([]netip.Addr{ip}, nil).Once()
+
+    // Second call fails upstream; forwarder should serve from cache
+    mockResolver.On("LookupNetIP", mock.Anything, "ip4", dns.Fqdn("example.com")).
+        Return([]netip.Addr{}, &net.DNSError{Err: "temporary failure"}).Once()
+
+    // First query: populate cache
+    q1 := &dns.Msg{}
+    q1.SetQuestion(dns.Fqdn("example.com"), dns.TypeA)
+    w1 := &test.MockResponseWriter{}
+    resp1 := forwarder.handleDNSQuery(w1, q1)
+    require.NotNil(t, resp1)
+    require.Equal(t, dns.RcodeSuccess, resp1.Rcode)
+    require.Len(t, resp1.Answer, 1)
+
+    // Second query: serve from cache after upstream failure
+    q2 := &dns.Msg{}
+    q2.SetQuestion(dns.Fqdn("example.com"), dns.TypeA)
+    var writtenResp *dns.Msg
+    w2 := &test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { writtenResp = m; return nil }}
+    _ = forwarder.handleDNSQuery(w2, q2)
+
+    require.NotNil(t, writtenResp, "expected response to be written")
+    require.Equal(t, dns.RcodeSuccess, writtenResp.Rcode)
+    require.Len(t, writtenResp.Answer, 1)
+
+    mockResolver.AssertExpectations(t)
+}
+
+// Verifies that cache normalization works across casing and trailing dot variations.
+func TestDNSForwarder_CacheNormalizationCasingAndDot(t *testing.T) {
+    mockResolver := &MockResolver{}
+    forwarder := NewDNSForwarder("127.0.0.1:0", 300, nil, &peer.Status{})
+    forwarder.resolver = mockResolver
+
+    d, err := domain.FromString("ExAmPlE.CoM")
+    require.NoError(t, err)
+    entries := []*ForwarderEntry{{Domain: d, ResID: "res-norm"}}
+    forwarder.UpdateDomains(entries)
+
+    ip := netip.MustParseAddr("9.8.7.6")
+
+    // Initial resolution with mixed case to populate cache
+    mixedQuery := "ExAmPlE.CoM"
+    mockResolver.On("LookupNetIP", mock.Anything, "ip4", dns.Fqdn(strings.ToLower(mixedQuery))).
+        Return([]netip.Addr{ip}, nil).Once()
+
+    q1 := &dns.Msg{}
+    q1.SetQuestion(mixedQuery+".", dns.TypeA)
+    w1 := &test.MockResponseWriter{}
+    resp1 := forwarder.handleDNSQuery(w1, q1)
+    require.NotNil(t, resp1)
+    require.Equal(t, dns.RcodeSuccess, resp1.Rcode)
+    require.Len(t, resp1.Answer, 1)
+
+    // Subsequent query without dot and upper case should hit cache even if upstream fails
+    // Forwarder lowercases and uses the question name as-is (no trailing dot here)
+    mockResolver.On("LookupNetIP", mock.Anything, "ip4", strings.ToLower("EXAMPLE.COM")).
+        Return([]netip.Addr{}, &net.DNSError{Err: "temporary failure"}).Once()
+
+    q2 := &dns.Msg{}
+    q2.SetQuestion("EXAMPLE.COM", dns.TypeA)
+    var writtenResp *dns.Msg
+    w2 := &test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { writtenResp = m; return nil }}
+    _ = forwarder.handleDNSQuery(w2, q2)
+
+    require.NotNil(t, writtenResp)
+    require.Equal(t, dns.RcodeSuccess, writtenResp.Rcode)
+    require.Len(t, writtenResp.Answer, 1)
+
+    mockResolver.AssertExpectations(t)
+}
+
 func TestDNSForwarder_MultipleOverlappingPatterns(t *testing.T) {
     // Test complex overlapping pattern scenarios
     mockFirewall := &MockFirewall{}
@@ -40,7 +40,6 @@ type Manager struct {
     fwRules      []firewall.Rule
     tcpRules     []firewall.Rule
     dnsForwarder *DNSForwarder
-    port         uint16
 }

 func ListenPort() uint16 {
@@ -49,11 +48,16 @@ func ListenPort() uint16 {
     return listenPort
 }

-func NewManager(fw firewall.Manager, statusRecorder *peer.Status, port uint16) *Manager {
+func SetListenPort(port uint16) {
+    listenPortMu.Lock()
+    listenPort = port
+    listenPortMu.Unlock()
+}
+
+func NewManager(fw firewall.Manager, statusRecorder *peer.Status) *Manager {
     return &Manager{
         firewall:       fw,
         statusRecorder: statusRecorder,
-        port:           port,
     }
 }
@@ -67,12 +71,6 @@ func (m *Manager) Start(fwdEntries []*ForwarderEntry) error {
         return err
     }

-    if m.port > 0 {
-        listenPortMu.Lock()
-        listenPort = m.port
-        listenPortMu.Unlock()
-    }
-
     m.dnsForwarder = NewDNSForwarder(fmt.Sprintf(":%d", ListenPort()), dnsTTL, m.firewall, m.statusRecorder)
     go func() {
         if err := m.dnsForwarder.Listen(fwdEntries); err != nil {
@@ -1849,6 +1849,10 @@ func (e *Engine) updateDNSForwarder(
         return
     }

+    if forwarderPort > 0 {
+        dnsfwd.SetListenPort(forwarderPort)
+    }
+
     if !enabled {
         if e.dnsForwardMgr == nil {
             return
@@ -1862,7 +1866,7 @@ func (e *Engine) updateDNSForwarder(
     if len(fwdEntries) > 0 {
         switch {
         case e.dnsForwardMgr == nil:
-            e.dnsForwardMgr = dnsfwd.NewManager(e.firewall, e.statusRecorder, forwarderPort)
+            e.dnsForwardMgr = dnsfwd.NewManager(e.firewall, e.statusRecorder)
             if err := e.dnsForwardMgr.Start(fwdEntries); err != nil {
                 log.Errorf("failed to start DNS forward: %v", err)
                 e.dnsForwardMgr = nil
@@ -1892,7 +1896,7 @@ func (e *Engine) restartDnsFwd(fwdEntries []*dnsfwd.ForwarderEntry, forwarderPor
     if err := e.dnsForwardMgr.Stop(context.Background()); err != nil {
         log.Errorf("failed to stop DNS forward: %v", err)
     }
-    e.dnsForwardMgr = dnsfwd.NewManager(e.firewall, e.statusRecorder, forwarderPort)
+    e.dnsForwardMgr = dnsfwd.NewManager(e.firewall, e.statusRecorder)
     if err := e.dnsForwardMgr.Start(fwdEntries); err != nil {
         log.Errorf("failed to start DNS forward: %v", err)
         e.dnsForwardMgr = nil
@@ -105,6 +105,10 @@ type MockWGIface struct {
     LastActivitiesFunc func() map[string]monotime.Time
 }

+func (m *MockWGIface) RemoveEndpointAddress(_ string) error {
+    return nil
+}
+
 func (m *MockWGIface) FullStats() (*configurer.Stats, error) {
     return nil, fmt.Errorf("not implemented")
 }
@@ -28,6 +28,7 @@ type wgIfaceBase interface {
     UpdateAddr(newAddr string) error
     GetProxy() wgproxy.Proxy
     UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error
+    RemoveEndpointAddress(key string) error
     RemovePeer(peerKey string) error
    AddAllowedIP(peerKey string, allowedIP netip.Prefix) error
    RemoveAllowedIP(peerKey string, allowedIP netip.Prefix) error
@@ -171,9 +171,9 @@ func (conn *Conn) Open(engineCtx context.Context) error {

     conn.handshaker = NewHandshaker(conn.Log, conn.config, conn.signaler, conn.workerICE, conn.workerRelay)

-    conn.handshaker.AddOnNewOfferListener(conn.workerRelay.OnNewOffer)
+    conn.handshaker.AddRelayListener(conn.workerRelay.OnNewOffer)
     if !isForceRelayed() {
-        conn.handshaker.AddOnNewOfferListener(conn.workerICE.OnNewOffer)
+        conn.handshaker.AddICEListener(conn.workerICE.OnNewOffer)
     }

     conn.guard = guard.NewGuard(conn.Log, conn.isConnectedOnAllWay, conn.config.Timeout, conn.srWatcher)
@@ -430,6 +430,9 @@ func (conn *Conn) onICEStateDisconnected() {
     } else {
         conn.Log.Infof("ICE disconnected, do not switch to Relay. Reset priority to: %s", conntype.None.String())
         conn.currentConnPriority = conntype.None
+        if err := conn.config.WgConfig.WgInterface.RemoveEndpointAddress(conn.config.WgConfig.RemoteKey); err != nil {
+            conn.Log.Errorf("failed to remove wg endpoint: %v", err)
+        }
     }

     changed := conn.statusICE.Get() != worker.StatusDisconnected
@@ -523,6 +526,9 @@ func (conn *Conn) onRelayDisconnected() {
     if conn.currentConnPriority == conntype.Relay {
         conn.Log.Debugf("clean up WireGuard config")
         conn.currentConnPriority = conntype.None
+        if err := conn.config.WgConfig.WgInterface.RemoveEndpointAddress(conn.config.WgConfig.RemoteKey); err != nil {
+            conn.Log.Errorf("failed to remove wg endpoint: %v", err)
+        }
     }

     if conn.wgProxyRelay != nil {
@@ -79,10 +79,10 @@ func TestConn_OnRemoteOffer(t *testing.T) {
         return
     }

-    onNewOffeChan := make(chan struct{})
+    onNewOfferChan := make(chan struct{})

-    conn.handshaker.AddOnNewOfferListener(func(remoteOfferAnswer *OfferAnswer) {
-        onNewOffeChan <- struct{}{}
+    conn.handshaker.AddRelayListener(func(remoteOfferAnswer *OfferAnswer) {
+        onNewOfferChan <- struct{}{}
     })

     conn.OnRemoteOffer(OfferAnswer{
@@ -98,7 +98,7 @@ func TestConn_OnRemoteOffer(t *testing.T) {
     defer cancel()

     select {
-    case <-onNewOffeChan:
+    case <-onNewOfferChan:
         // success
     case <-ctx.Done():
         t.Error("expected to receive a new offer notification, but timed out")
@@ -118,10 +118,10 @@ func TestConn_OnRemoteAnswer(t *testing.T) {
         return
     }

-    onNewOffeChan := make(chan struct{})
+    onNewOfferChan := make(chan struct{})

-    conn.handshaker.AddOnNewOfferListener(func(remoteOfferAnswer *OfferAnswer) {
-        onNewOffeChan <- struct{}{}
+    conn.handshaker.AddRelayListener(func(remoteOfferAnswer *OfferAnswer) {
+        onNewOfferChan <- struct{}{}
     })

     conn.OnRemoteAnswer(OfferAnswer{
@@ -136,7 +136,7 @@ func TestConn_OnRemoteAnswer(t *testing.T) {
     defer cancel()

     select {
-    case <-onNewOffeChan:
+    case <-onNewOfferChan:
         // success
     case <-ctx.Done():
         t.Error("expected to receive a new offer notification, but timed out")
client/internal/peer/guard/env.go (new file, 20 lines)
@@ -0,0 +1,20 @@
package guard

import (
    "os"
    "strconv"
    "time"
)

const (
    envICEMonitorPeriod = "NB_ICE_MONITOR_PERIOD"
)

func GetICEMonitorPeriod() time.Duration {
    if envVal := os.Getenv(envICEMonitorPeriod); envVal != "" {
        if seconds, err := strconv.Atoi(envVal); err == nil && seconds > 0 {
            return time.Duration(seconds) * time.Second
        }
    }
    return defaultCandidatesMonitorPeriod
}
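A hypothetical test illustrating the parsing rules of `GetICEMonitorPeriod` above: the env var takes whole seconds, and invalid or non-positive values fall back to the default period (this test is not part of the diff):

```go
package guard

import (
	"testing"
	"time"
)

// TestGetICEMonitorPeriodExample is illustrative only.
func TestGetICEMonitorPeriodExample(t *testing.T) {
	t.Setenv("NB_ICE_MONITOR_PERIOD", "60")
	if got := GetICEMonitorPeriod(); got != 60*time.Second {
		t.Fatalf("expected 60s, got %s", got)
	}

	// Non-numeric values fall back to the default monitor period.
	t.Setenv("NB_ICE_MONITOR_PERIOD", "not-a-number")
	if got := GetICEMonitorPeriod(); got != defaultCandidatesMonitorPeriod {
		t.Fatalf("expected default period, got %s", got)
	}
}
```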
@@ -16,8 +16,8 @@ import (
 )

 const (
-    candidatesMonitorPeriod   = 5 * time.Minute
-    candidateGatheringTimeout = 5 * time.Second
+    defaultCandidatesMonitorPeriod = 5 * time.Minute
+    candidateGatheringTimeout      = 5 * time.Second
 )

 type ICEMonitor struct {
@@ -25,16 +25,19 @@ type ICEMonitor struct {

     iFaceDiscover stdnet.ExternalIFaceDiscover
     iceConfig     icemaker.Config
+    tickerPeriod  time.Duration

     currentCandidatesAddress []string
     candidatesMu             sync.Mutex
 }

-func NewICEMonitor(iFaceDiscover stdnet.ExternalIFaceDiscover, config icemaker.Config) *ICEMonitor {
+func NewICEMonitor(iFaceDiscover stdnet.ExternalIFaceDiscover, config icemaker.Config, period time.Duration) *ICEMonitor {
+    log.Debugf("prepare ICE monitor with period: %s", period)
     cm := &ICEMonitor{
         ReconnectCh:   make(chan struct{}, 1),
         iFaceDiscover: iFaceDiscover,
         iceConfig:     config,
+        tickerPeriod:  period,
     }
     return cm
 }
@@ -46,7 +49,12 @@ func (cm *ICEMonitor) Start(ctx context.Context, onChanged func()) {
         return
     }

-    ticker := time.NewTicker(candidatesMonitorPeriod)
+    // Initial check to populate the candidates for later comparison
+    if _, err := cm.handleCandidateTick(ctx, ufrag, pwd); err != nil {
+        log.Warnf("Failed to check initial ICE candidates: %v", err)
+    }
+
+    ticker := time.NewTicker(cm.tickerPeriod)
     defer ticker.Stop()

     for {
@@ -51,7 +51,7 @@ func (w *SRWatcher) Start() {
     ctx, cancel := context.WithCancel(context.Background())
     w.cancelIceMonitor = cancel

-    iceMonitor := NewICEMonitor(w.iFaceDiscover, w.iceConfig)
+    iceMonitor := NewICEMonitor(w.iFaceDiscover, w.iceConfig, GetICEMonitorPeriod())
     go iceMonitor.Start(ctx, w.onICEChanged)
     w.signalClient.SetOnReconnectedListener(w.onReconnected)
     w.relayManager.SetOnReconnectedListener(w.onReconnected)
@@ -44,13 +44,19 @@ type OfferAnswer struct {
 }

 type Handshaker struct {
-    mu                  sync.Mutex
-    log                 *log.Entry
-    config              ConnConfig
-    signaler            *Signaler
-    ice                 *WorkerICE
-    relay               *WorkerRelay
-    onNewOfferListeners []*OfferListener
+    mu       sync.Mutex
+    log      *log.Entry
+    config   ConnConfig
+    signaler *Signaler
+    ice      *WorkerICE
+    relay    *WorkerRelay
+    // relayListener is not blocking because the listener is using a goroutine to process the messages
+    // and it will only keep the latest message if multiple offers are received in a short time
+    // this is to avoid blocking the handshaker if the listener is doing some heavy processing
+    // and also to avoid processing old offers if multiple offers are received in a short time
+    // the listener will always process the latest offer
+    relayListener *AsyncOfferListener
+    iceListener   func(remoteOfferAnswer *OfferAnswer)

     // remoteOffersCh is a channel used to wait for remote credentials to proceed with the connection
     remoteOffersCh chan OfferAnswer
@@ -70,28 +76,39 @@ func NewHandshaker(log *log.Entry, config ConnConfig, signaler *Signaler, ice *W
     }
 }

-func (h *Handshaker) AddOnNewOfferListener(offer func(remoteOfferAnswer *OfferAnswer)) {
-    l := NewOfferListener(offer)
-    h.onNewOfferListeners = append(h.onNewOfferListeners, l)
+func (h *Handshaker) AddRelayListener(offer func(remoteOfferAnswer *OfferAnswer)) {
+    h.relayListener = NewAsyncOfferListener(offer)
+}
+
+func (h *Handshaker) AddICEListener(offer func(remoteOfferAnswer *OfferAnswer)) {
+    h.iceListener = offer
 }

 func (h *Handshaker) Listen(ctx context.Context) {
     for {
         select {
         case remoteOfferAnswer := <-h.remoteOffersCh:
             // received confirmation from the remote peer -> ready to proceed
+            h.log.Infof("received offer, running version %s, remote WireGuard listen port %d, session id: %s", remoteOfferAnswer.Version, remoteOfferAnswer.WgListenPort, remoteOfferAnswer.SessionIDString())
+            if h.relayListener != nil {
+                h.relayListener.Notify(&remoteOfferAnswer)
+            }
+
+            if h.iceListener != nil {
+                h.iceListener(&remoteOfferAnswer)
+            }
+
             if err := h.sendAnswer(); err != nil {
                 h.log.Errorf("failed to send remote offer confirmation: %s", err)
                 continue
             }
-            for _, listener := range h.onNewOfferListeners {
-                listener.Notify(&remoteOfferAnswer)
-            }
-            h.log.Infof("received offer, running version %s, remote WireGuard listen port %d, session id: %s", remoteOfferAnswer.Version, remoteOfferAnswer.WgListenPort, remoteOfferAnswer.SessionIDString())
         case remoteOfferAnswer := <-h.remoteAnswerCh:
             h.log.Infof("received answer, running version %s, remote WireGuard listen port %d, session id: %s", remoteOfferAnswer.Version, remoteOfferAnswer.WgListenPort, remoteOfferAnswer.SessionIDString())
-            for _, listener := range h.onNewOfferListeners {
-                listener.Notify(&remoteOfferAnswer)
-            }
+            if h.relayListener != nil {
+                h.relayListener.Notify(&remoteOfferAnswer)
+            }
+
+            if h.iceListener != nil {
+                h.iceListener(&remoteOfferAnswer)
+            }
         case <-ctx.Done():
             h.log.Infof("stop listening for remote offers and answers")
@@ -13,20 +13,20 @@ func (oa *OfferAnswer) SessionIDString() string {
     return oa.SessionID.String()
 }

-type OfferListener struct {
+type AsyncOfferListener struct {
     fn      callbackFunc
     running bool
     latest  *OfferAnswer
     mu      sync.Mutex
 }

-func NewOfferListener(fn callbackFunc) *OfferListener {
-    return &OfferListener{
+func NewAsyncOfferListener(fn callbackFunc) *AsyncOfferListener {
+    return &AsyncOfferListener{
         fn: fn,
     }
 }

-func (o *OfferListener) Notify(remoteOfferAnswer *OfferAnswer) {
+func (o *AsyncOfferListener) Notify(remoteOfferAnswer *OfferAnswer) {
     o.mu.Lock()
     defer o.mu.Unlock()
@@ -14,7 +14,7 @@ func Test_newOfferListener(t *testing.T) {
         runChan <- struct{}{}
     }

-    hl := NewOfferListener(longRunningFn)
+    hl := NewAsyncOfferListener(longRunningFn)

     hl.Notify(dummyOfferAnswer)
     hl.Notify(dummyOfferAnswer)
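The renamed type's fields (`fn`, `running`, `latest`, `mu`) together with the comment in the Handshaker diff describe a coalescing listener: Notify never blocks, a single goroutine drains work, and only the latest offer is processed when several arrive while the callback is busy. A speculative reconstruction of that pattern, for illustration only (the actual method body is not shown in the diff):

```go
// notifySketch illustrates the coalescing pattern; it is not the real Notify.
func (o *AsyncOfferListener) notifySketch(remoteOfferAnswer *OfferAnswer) {
	o.mu.Lock()
	defer o.mu.Unlock()

	o.latest = remoteOfferAnswer // overwrite: older pending offers are dropped
	if o.running {
		return // the active goroutine will pick up o.latest
	}

	o.running = true
	go func() {
		for {
			o.mu.Lock()
			next := o.latest
			o.latest = nil
			if next == nil {
				o.running = false
				o.mu.Unlock()
				return
			}
			o.mu.Unlock()
			o.fn(next) // long-running callback runs outside the lock
		}
	}()
}
```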
@@ -18,4 +18,5 @@ type WGIface interface {
     GetStats() (map[string]configurer.WGStats, error)
     GetProxy() wgproxy.Proxy
     Address() wgaddr.Address
+    RemoveEndpointAddress(key string) error
 }
@@ -92,23 +92,16 @@ func NewWorkerICE(ctx context.Context, log *log.Entry, config ConnConfig, conn *
 func (w *WorkerICE) OnNewOffer(remoteOfferAnswer *OfferAnswer) {
     w.log.Debugf("OnNewOffer for ICE, serial: %s", remoteOfferAnswer.SessionIDString())
     w.muxAgent.Lock()
+    defer w.muxAgent.Unlock()

-    if w.agentConnecting {
-        w.log.Debugf("agent connection is in progress, skipping the offer")
-        w.muxAgent.Unlock()
-        return
-    }
-
-    if w.agent != nil {
+    if w.agent != nil || w.agentConnecting {
         // backward compatibility with old clients that do not send session ID
         if remoteOfferAnswer.SessionID == nil {
             w.log.Debugf("agent already exists, skipping the offer")
-            w.muxAgent.Unlock()
             return
         }
         if w.remoteSessionID == *remoteOfferAnswer.SessionID {
             w.log.Debugf("agent already exists and session ID matches, skipping the offer: %s", remoteOfferAnswer.SessionIDString())
-            w.muxAgent.Unlock()
             return
         }
         w.log.Debugf("agent already exists, recreate the connection")
@@ -116,6 +109,12 @@ func (w *WorkerICE) OnNewOffer(remoteOfferAnswer *OfferAnswer) {
         if err := w.agent.Close(); err != nil {
             w.log.Warnf("failed to close ICE agent: %s", err)
         }
+
+        sessionID, err := NewICESessionID()
+        if err != nil {
+            w.log.Errorf("failed to create new session ID: %s", err)
+        }
+        w.sessionID = sessionID
         w.agent = nil
     }
@@ -126,18 +125,23 @@ func (w *WorkerICE) OnNewOffer(remoteOfferAnswer *OfferAnswer) {
         preferredCandidateTypes = icemaker.CandidateTypes()
     }

-    w.log.Debugf("recreate ICE agent")
+    if remoteOfferAnswer.SessionID != nil {
+        w.log.Debugf("recreate ICE agent: %s / %s", w.sessionID, *remoteOfferAnswer.SessionID)
+    }
     dialerCtx, dialerCancel := context.WithCancel(w.ctx)
     agent, err := w.reCreateAgent(dialerCancel, preferredCandidateTypes)
     if err != nil {
         w.log.Errorf("failed to recreate ICE Agent: %s", err)
-        w.muxAgent.Unlock()
         return
     }
     w.agent = agent
     w.agentDialerCancel = dialerCancel
     w.agentConnecting = true
-    w.muxAgent.Unlock()
+    if remoteOfferAnswer.SessionID != nil {
+        w.remoteSessionID = *remoteOfferAnswer.SessionID
+    } else {
+        w.remoteSessionID = ""
+    }

     go w.connect(dialerCtx, agent, remoteOfferAnswer)
 }
@@ -293,9 +297,6 @@ func (w *WorkerICE) connect(ctx context.Context, agent *icemaker.ThreadSafeAgent
     w.muxAgent.Lock()
     w.agentConnecting = false
     w.lastSuccess = time.Now()
-    if remoteOfferAnswer.SessionID != nil {
-        w.remoteSessionID = *remoteOfferAnswer.SessionID
-    }
     w.muxAgent.Unlock()

     // todo: the potential problem is a race between the onConnectionStateChange
@@ -309,16 +310,17 @@ func (w *WorkerICE) closeAgent(agent *icemaker.ThreadSafeAgent, cancel context.C
     }

     w.muxAgent.Lock()
-    // todo review does it make sense to generate new session ID all the time when w.agent==agent
-    sessionID, err := NewICESessionID()
-    if err != nil {
-        w.log.Errorf("failed to create new session ID: %s", err)
-    }
-    w.sessionID = sessionID
-
     if w.agent == agent {
+        // consider to remove from here and move to the OnNewOffer
+        sessionID, err := NewICESessionID()
+        if err != nil {
+            w.log.Errorf("failed to create new session ID: %s", err)
+        }
+        w.sessionID = sessionID
         w.agent = nil
         w.agentConnecting = false
+        w.remoteSessionID = ""
     }
     w.muxAgent.Unlock()
 }
@@ -395,11 +397,12 @@ func (w *WorkerICE) onConnectionStateChange(agent *icemaker.ThreadSafeAgent, dia
         // ice.ConnectionStateClosed happens when we recreate the agent. For the P2P to TURN switch important to
         // notify the conn.onICEStateDisconnected changes to update the current used priority

-        w.closeAgent(agent, dialerCancel)
-
         if w.lastKnownState == ice.ConnectionStateConnected {
             w.lastKnownState = ice.ConnectionStateDisconnected
             w.conn.onICEStateDisconnected()
         }
+        w.closeAgent(agent, dialerCancel)
     default:
         return
     }
@@ -1354,7 +1354,13 @@ func (s *serviceClient) updateConfig() error {
 }

 // showLoginURL creates a borderless window styled like a pop-up in the top-right corner using s.wLoginURL.
-func (s *serviceClient) showLoginURL() {
+// It also starts a background goroutine that periodically checks if the client is already connected
+// and closes the window if so. The goroutine can be cancelled by the returned CancelFunc, and it is
+// also cancelled when the window is closed.
+func (s *serviceClient) showLoginURL() context.CancelFunc {
+    // create a cancellable context for the background check goroutine
+    ctx, cancel := context.WithCancel(s.ctx)
+
     resIcon := fyne.NewStaticResource("netbird.png", iconAbout)
@@ -1363,6 +1369,8 @@ func (s *serviceClient) showLoginURL() {
         s.wLoginURL.Resize(fyne.NewSize(400, 200))
         s.wLoginURL.SetIcon(resIcon)
     }
+    // ensure goroutine is cancelled when the window is closed
+    s.wLoginURL.SetOnClosed(func() { cancel() })
     // add a description label
     label := widget.NewLabel("Your NetBird session has expired.\nPlease re-authenticate to continue using NetBird.")
@@ -1443,7 +1451,39 @@ func (s *serviceClient) showLoginURL() {
     )
     s.wLoginURL.SetContent(container.NewCenter(content))

+    // start a goroutine to check connection status and close the window if connected
+    go func() {
+        ticker := time.NewTicker(5 * time.Second)
+        defer ticker.Stop()
+
+        conn, err := s.getSrvClient(failFastTimeout)
+        if err != nil {
+            return
+        }
+
+        for {
+            select {
+            case <-ctx.Done():
+                return
+            case <-ticker.C:
+                status, err := conn.Status(s.ctx, &proto.StatusRequest{})
+                if err != nil {
+                    continue
+                }
+                if status.Status == string(internal.StatusConnected) {
+                    if s.wLoginURL != nil {
+                        s.wLoginURL.Close()
+                    }
+                    return
+                }
+            }
+        }
+    }()
+
     s.wLoginURL.Show()
+
+    // return cancel func so callers can stop the background goroutine if desired
+    return cancel
 }

 func openURL(url string) error {
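A hedged sketch of a call site for the new return value; the surrounding code is hypothetical and not taken from the diff:

```go
// Stop the background status-check goroutine when this flow ends; the
// window's SetOnClosed hook cancels it as well.
cancel := s.showLoginURL()
defer cancel()
```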
@@ -47,7 +47,7 @@ services:
       - traefik.enable=true
       - traefik.http.routers.netbird-wsproxy-signal.rule=Host(`$NETBIRD_DOMAIN`) && PathPrefix(`/ws-proxy/signal`)
       - traefik.http.routers.netbird-wsproxy-signal.service=netbird-wsproxy-signal
-      - traefik.http.services.netbird-wsproxy-signal.loadbalancer.server.port=10000
+      - traefik.http.services.netbird-wsproxy-signal.loadbalancer.server.port=80
       - traefik.http.routers.netbird-signal.rule=Host(`$NETBIRD_DOMAIN`) && PathPrefix(`/signalexchange.SignalExchange/`)
       - traefik.http.services.netbird-signal.loadbalancer.server.port=10000
       - traefik.http.services.netbird-signal.loadbalancer.server.scheme=h2c
@@ -621,7 +621,7 @@ renderCaddyfile() {
    # relay
    reverse_proxy /relay* relay:80
    # Signal
-   reverse_proxy /ws-proxy/signal* signal:10000
+   reverse_proxy /ws-proxy/signal* signal:80
    reverse_proxy /signalexchange.SignalExchange/* h2c://signal:10000
    # Management
    reverse_proxy /api/* management:80
@@ -6,7 +6,6 @@ import (
     "fmt"
     "net"
     "net/http"
-    "net/netip"
     "strings"
     "sync"
     "time"
@@ -252,7 +251,7 @@ func updateMgmtConfig(ctx context.Context, path string, config *nbconfig.Config)
 }

 func (s *BaseServer) handlerFunc(gRPCHandler *grpc.Server, httpHandler http.Handler, meter metric.Meter) http.Handler {
-    wsProxy := wsproxyserver.New(netip.AddrPortFrom(netip.AddrFrom4([4]byte{127, 0, 0, 1}), ManagementLegacyPort), wsproxyserver.WithOTelMeter(meter))
+    wsProxy := wsproxyserver.New(gRPCHandler, wsproxyserver.WithOTelMeter(meter))

     return http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) {
         switch {
@@ -350,7 +350,6 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer
     }

     var peer *nbpeer.Peer
-    var updateAccountPeers bool
     var eventsToStore []func()

     err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
@@ -363,11 +362,6 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer
             return err
         }

-        updateAccountPeers, err = isPeerInActiveGroup(ctx, transaction, accountID, peerID)
-        if err != nil {
-            return err
-        }
-
         eventsToStore, err = deletePeers(ctx, am, transaction, accountID, userID, []*nbpeer.Peer{peer})
         if err != nil {
             return fmt.Errorf("failed to delete peer: %w", err)
@@ -387,7 +381,7 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer
         storeEvent()
     }

-    if updateAccountPeers && userID != activity.SystemInitiator {
+    if userID != activity.SystemInitiator {
         am.BufferUpdateAccountPeers(ctx, accountID)
     }
@@ -684,11 +678,6 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, accountID, setupKe
         return nil, nil, nil, fmt.Errorf("failed to add peer to database after %d attempts: %w", maxAttempts, err)
     }

-    updateAccountPeers, err := isPeerInActiveGroup(ctx, am.Store, accountID, newPeer.ID)
-    if err != nil {
-        updateAccountPeers = true
-    }
-
     if newPeer == nil {
         return nil, nil, nil, fmt.Errorf("new peer is nil")
     }
@@ -701,9 +690,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, accountID, setupKe

     am.StoreEvent(ctx, opEvent.InitiatorID, opEvent.TargetID, opEvent.AccountID, opEvent.Activity, opEvent.Meta)

-    if updateAccountPeers {
-        am.BufferUpdateAccountPeers(ctx, accountID)
-    }
+    am.BufferUpdateAccountPeers(ctx, accountID)

     return am.getValidatedPeerWithMap(ctx, false, accountID, newPeer)
 }
@@ -1527,16 +1514,6 @@ func getPeerGroupIDs(ctx context.Context, transaction store.Store, accountID str
     return transaction.GetPeerGroupIDs(ctx, store.LockingStrengthNone, accountID, peerID)
 }

-// IsPeerInActiveGroup checks if the given peer is part of a group that is used
-// in an active DNS, route, or ACL configuration.
-func isPeerInActiveGroup(ctx context.Context, transaction store.Store, accountID, peerID string) (bool, error) {
-    peerGroupIDs, err := getPeerGroupIDs(ctx, transaction, accountID, peerID)
-    if err != nil {
-        return false, err
-    }
-    return areGroupChangesAffectPeers(ctx, transaction, accountID, peerGroupIDs) // TODO: use transaction
-}
-
 // deletePeers deletes all specified peers and sends updates to the remote peers.
 // Returns a slice of functions to save events after successful peer deletion.
 func deletePeers(ctx context.Context, am *DefaultAccountManager, transaction store.Store, accountID, userID string, peers []*nbpeer.Peer) ([]func(), error) {
@@ -1790,7 +1790,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
     t.Run("adding peer to unlinked group", func(t *testing.T) {
         done := make(chan struct{})
         go func() {
-            peerShouldNotReceiveUpdate(t, updMsg) //
+            peerShouldReceiveUpdate(t, updMsg) //
             close(done)
         }()
@@ -1815,7 +1815,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
     t.Run("deleting peer with unlinked group", func(t *testing.T) {
         done := make(chan struct{})
         go func() {
-            peerShouldNotReceiveUpdate(t, updMsg)
+            peerShouldReceiveUpdate(t, updMsg)
             close(done)
         }()
@@ -10,7 +10,6 @@ import (
     "net/http"
     // nolint:gosec
     _ "net/http/pprof"
-    "net/netip"
     "time"

     "go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc"
@@ -63,10 +62,10 @@ var (
         Use:          "run",
         Short:        "start NetBird Signal Server daemon",
         SilenceUsage: true,
-        PreRun: func(cmd *cobra.Command, args []string) {
+        PreRunE: func(cmd *cobra.Command, args []string) error {
             err := util.InitLog(logLevel, logFile)
             if err != nil {
-                log.Fatalf("failed initializing log %v", err)
+                return fmt.Errorf("failed initializing log: %w", err)
             }

             flag.Parse()
@@ -87,6 +86,8 @@ var (
                     signalPort = 80
                 }
             }
+
+            return nil
         },
         RunE: func(cmd *cobra.Command, args []string) error {
             flag.Parse()
@@ -254,7 +255,7 @@ func startServerWithCertManager(certManager *autocert.Manager, grpcRootHandler h
 }

 func grpcHandlerFunc(grpcServer *grpc.Server, meter metric.Meter) http.Handler {
-    wsProxy := wsproxyserver.New(netip.AddrPortFrom(netip.AddrFrom4([4]byte{127, 0, 0, 1}), legacyGRPCPort), wsproxyserver.WithOTelMeter(meter))
+    wsProxy := wsproxyserver.New(grpcServer, wsproxyserver.WithOTelMeter(meter))

     return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
         switch {
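The switch from `PreRun` to `PreRunE` lets setup errors flow through cobra's normal error handling instead of `log.Fatalf`. A small illustrative sketch of that pattern (the command and helper below are hypothetical, not part of the diff):

```go
package main

import (
	"fmt"

	"github.com/spf13/cobra"
)

// exampleCmd demonstrates error propagation from PreRunE: returning an
// error aborts the run and surfaces the error to the caller of Execute.
var exampleCmd = &cobra.Command{
	Use:          "example",
	SilenceUsage: true,
	PreRunE: func(cmd *cobra.Command, args []string) error {
		if err := initLogging(); err != nil {
			return fmt.Errorf("failed initializing log: %w", err)
		}
		return nil
	},
	RunE: func(cmd *cobra.Command, args []string) error {
		return nil
	},
}

func initLogging() error { return nil } // placeholder standing in for util.InitLog
```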
signal/loadtest/README.md (new file, 381 lines)

# Signal Server Load Test

Load testing tool for the NetBird signal server.

## Features

- **Rate-based peer pair creation**: Spawn peer pairs at configurable rates (e.g., 10, 20 pairs/sec)
- **Two exchange modes**:
  - **Single message**: Each pair exchanges one message for validation
  - **Continuous exchange**: Pairs continuously exchange messages for a specified duration (e.g., 30 seconds, 10 minutes)
- **TLS/HTTPS support**: Connect to TLS-enabled signal servers with optional certificate verification
- **Automatic reconnection**: Optional automatic reconnection with exponential backoff on connection loss
- **Configurable message interval**: Control the message send rate in continuous mode
- **Message exchange validation**: Validates that the encrypted body size is > 0
- **Comprehensive metrics**: Tracks throughput, success/failure rates, latency statistics, and reconnection counts
- **Local server testing**: Tests include an embedded signal server for easy development
- **Worker pool pattern**: Efficient concurrent execution
- **Graceful shutdown**: Context-based cancellation

## Usage

### Standalone Binary

Build and run the load test as a standalone binary:

```bash
# Build the binary
cd signal/loadtest/cmd/signal-loadtest
go build -o signal-loadtest

# Single message exchange
./signal-loadtest \
  -server http://localhost:10000 \
  -pairs-per-sec 10 \
  -total-pairs 100 \
  -message-size 100

# Continuous exchange for 30 seconds
./signal-loadtest \
  -server http://localhost:10000 \
  -pairs-per-sec 10 \
  -total-pairs 20 \
  -message-size 200 \
  -exchange-duration 30s \
  -message-interval 200ms

# Long-running test (10 minutes)
./signal-loadtest \
  -server http://localhost:10000 \
  -pairs-per-sec 20 \
  -total-pairs 50 \
  -message-size 500 \
  -exchange-duration 10m \
  -message-interval 100ms \
  -test-duration 15m \
  -log-level debug

# TLS server with valid certificate
./signal-loadtest \
  -server https://signal.example.com:443 \
  -pairs-per-sec 10 \
  -total-pairs 50 \
  -message-size 100

# TLS server with self-signed certificate
./signal-loadtest \
  -server https://localhost:443 \
  -pairs-per-sec 5 \
  -total-pairs 10 \
  -insecure-skip-verify \
  -log-level debug

# High load test with custom worker pool
./signal-loadtest \
  -server http://localhost:10000 \
  -pairs-per-sec 100 \
  -total-pairs 1000 \
  -worker-pool-size 500 \
  -channel-buffer-size 1000 \
  -exchange-duration 60s \
  -log-level info

# Progress reporting - report every 5000 messages
./signal-loadtest \
  -server http://localhost:10000 \
  -pairs-per-sec 50 \
  -total-pairs 100 \
  -exchange-duration 5m \
  -report-interval 5000 \
  -log-level info

# With automatic reconnection
./signal-loadtest \
  -server http://localhost:10000 \
  -pairs-per-sec 10 \
  -total-pairs 50 \
  -exchange-duration 5m \
  -enable-reconnect \
  -initial-retry-delay 100ms \
  -max-reconnect-delay 30s \
  -log-level debug

# Show help
./signal-loadtest -h
```

**Graceful Shutdown:**

The load test supports graceful shutdown via Ctrl+C (SIGINT/SIGTERM); a sketch of this flow follows the list:
- Press Ctrl+C to interrupt the test at any time
- All active clients will be closed gracefully
- A final aggregated report will be printed showing metrics up to the point of interruption
- Shutdown timeout: 5 seconds (after which the process will force exit)
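A minimal sketch of the shutdown flow described above, assuming hypothetical `closeAllClients` and `printFinalReport` helpers; the tool's actual internals may differ:

```go
package main

import (
	"context"
	"os"
	"os/signal"
	"syscall"
	"time"
)

// runWithGracefulShutdown wires SIGINT/SIGTERM to context cancellation and
// enforces the 5-second force-exit timeout mentioned above.
func runWithGracefulShutdown(run func(ctx context.Context)) {
	ctx, stop := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM)
	defer stop()

	run(ctx) // returns when the test completes or Ctrl+C cancels the context

	done := make(chan struct{})
	go func() {
		closeAllClients()  // hypothetical: close all active peer clients
		printFinalReport() // hypothetical: print the aggregated metrics
		close(done)
	}()

	select {
	case <-done:
	case <-time.After(5 * time.Second):
		os.Exit(1) // force exit after the shutdown timeout
	}
}

func closeAllClients()  {}
func printFinalReport() {}
```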
**Available Flags:**
|
||||
- `-server`: Signal server URL (http:// or https://) (default: `http://localhost:10000`)
|
||||
- `-pairs-per-sec`: Peer pairs created per second (default: 10)
|
||||
- `-total-pairs`: Total number of peer pairs (default: 100)
|
||||
- `-message-size`: Message size in bytes (default: 100)
|
||||
- `-test-duration`: Maximum test duration, 0 = unlimited (default: 0)
|
||||
- `-exchange-duration`: Continuous exchange duration per pair, 0 = single message (default: 0)
|
||||
- `-message-interval`: Interval between messages in continuous mode (default: 100ms)
|
||||
- `-worker-pool-size`: Number of concurrent workers, 0 = auto (pairs-per-sec × 2) (default: 0)
|
||||
- `-channel-buffer-size`: Work queue buffer size, 0 = auto (pairs-per-sec × 4) (default: 0)
|
||||
- `-report-interval`: Report progress every N messages, 0 = no periodic reports (default: 10000)
|
||||
- `-enable-reconnect`: Enable automatic reconnection on connection loss (default: false)
|
||||
- `-initial-retry-delay`: Initial delay before first reconnection attempt (default: 100ms)
|
||||
- `-max-reconnect-delay`: Maximum delay between reconnection attempts (default: 30s)
|
||||
- `-insecure-skip-verify`: Skip TLS certificate verification for self-signed certificates (default: false)
|
||||
- `-log-level`: Log level: trace, debug, info, warn, error (default: info)
|
||||
|
||||
### Running Tests
|
||||
|
||||
```bash
|
||||
# Run all tests (includes load tests)
|
||||
go test -v -timeout 2m
|
||||
|
||||
# Run specific single-message load tests
|
||||
go test -v -run TestLoadTest_10PairsPerSecond -timeout 40s
|
||||
go test -v -run TestLoadTest_20PairsPerSecond -timeout 40s
|
||||
go test -v -run TestLoadTest_SmallBurst -timeout 30s
|
||||
|
||||
# Run continuous exchange tests
|
||||
go test -v -run TestLoadTest_ContinuousExchange_ShortBurst -timeout 30s
|
||||
go test -v -run TestLoadTest_ContinuousExchange_30Seconds -timeout 2m
|
||||
go test -v -run TestLoadTest_ContinuousExchange_10Minutes -timeout 15m
|
||||
|
||||
# Skip long-running tests in quick runs
|
||||
go test -short
|
||||
```
|
||||
|
||||
### Programmatic Usage
|
||||
|
||||
#### Single Message Exchange
|
||||
```go
|
||||
package main
|
||||
|
||||
import (
|
||||
"github.com/netbirdio/netbird/signal/loadtest"
|
||||
"time"
|
||||
)
|
||||
|
||||
func main() {
|
||||
config := loadtest.LoadTestConfig{
|
||||
ServerURL: "http://localhost:10000",
|
||||
PairsPerSecond: 10,
|
||||
TotalPairs: 100,
|
||||
MessageSize: 100,
|
||||
TestDuration: 30 * time.Second,
|
||||
}
|
||||
|
||||
lt := loadtest.NewLoadTest(config)
|
||||
if err := lt.Run(); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
metrics := lt.GetMetrics()
|
||||
metrics.PrintReport()
|
||||
}
|
||||
```
|
||||
|
||||
#### Continuous Message Exchange
```go
package main

import (
    "time"

    "github.com/netbirdio/netbird/signal/loadtest"
)

func main() {
    config := loadtest.LoadTestConfig{
        ServerURL:        "http://localhost:10000",
        PairsPerSecond:   10,
        TotalPairs:       20,
        MessageSize:      200,
        ExchangeDuration: 10 * time.Minute,       // Each pair exchanges messages for 10 minutes
        MessageInterval:  200 * time.Millisecond, // Send a message every 200ms
        TestDuration:     15 * time.Minute,       // Overall test timeout
    }

    lt := loadtest.NewLoadTest(config)
    if err := lt.Run(); err != nil {
        panic(err)
    }

    metrics := lt.GetMetrics()
    metrics.PrintReport()
}
```

## Configuration Options

- **ServerURL**: Signal server URL (e.g., `http://localhost:10000` or `https://signal.example.com:443`)
- **PairsPerSecond**: Rate at which peer pairs are created (e.g., 10, 20)
- **TotalPairs**: Total number of peer pairs to create
- **MessageSize**: Size of test message payload in bytes
- **TestDuration**: Maximum test duration (optional, 0 = no limit)
- **ExchangeDuration**: Duration for continuous message exchange per pair (0 = single message)
- **MessageInterval**: Interval between messages in continuous mode (default: 100ms)
- **WorkerPoolSize**: Number of concurrent worker goroutines (0 = auto: pairs-per-sec × 2)
- **ChannelBufferSize**: Work queue buffer size (0 = auto: pairs-per-sec × 4)
- **ReportInterval**: Report progress every N messages (0 = no periodic reports, default: 10000)
- **EnableReconnect**: Enable automatic reconnection on connection loss (default: false)
- **InitialRetryDelay**: Initial delay before first reconnection attempt (default: 100ms)
- **MaxReconnectDelay**: Maximum delay between reconnection attempts (default: 30s)
- **InsecureSkipVerify**: Skip TLS certificate verification (for self-signed certificates)
- **RampUpDuration**: Gradual ramp-up period (not yet implemented)

### Reconnection Handling

The load test supports automatic reconnection on connection loss:

- **Disabled by default**: Connections will fail on any network interruption
- **When enabled**: Clients automatically reconnect with exponential backoff
- **Exponential backoff**: Starts at `InitialRetryDelay`, doubles on each failure, and caps at `MaxReconnectDelay` (with the defaults: 100ms, 200ms, 400ms, ... up to 30s)
- **Transparent reconnection**: Message exchange continues after successful reconnection
- **Metrics tracking**: Total reconnection count is reported

**Use cases:**
- Testing resilience to network interruptions
- Validating server restart behavior
- Simulating flaky network conditions
- Long-running stability tests

**Example with reconnection:**
```go
config := loadtest.LoadTestConfig{
    ServerURL:         "http://localhost:10000",
    PairsPerSecond:    10,
    TotalPairs:        20,
    ExchangeDuration:  10 * time.Minute,
    EnableReconnect:   true,
    InitialRetryDelay: 100 * time.Millisecond,
    MaxReconnectDelay: 30 * time.Second,
}
```

### Performance Tuning

When running high-load tests, you may need to adjust the worker pool and buffer sizes:

- **Default sizing**: Auto-configured based on `PairsPerSecond`
  - Worker pool: `PairsPerSecond × 2`
  - Channel buffer: `PairsPerSecond × 4`
- **For continuous exchange**: Increase worker pool size (e.g., `PairsPerSecond × 5`)
- **For high pair rates** (>50/sec): Increase both worker pool and buffer proportionally
- **Signs you need more workers**: Log warnings about "Worker pool saturated"

Example for 100 pairs/sec with continuous exchange:
```go
config := LoadTestConfig{
    PairsPerSecond:    100,
    WorkerPoolSize:    500,  // 5x pairs/sec
    ChannelBufferSize: 1000, // 10x pairs/sec
}
```

## Metrics

The load test collects and reports:

- **Total Pairs Sent**: Number of peer pairs attempted
- **Successful Exchanges**: Completed message exchanges
- **Failed Exchanges**: Failed message exchanges
- **Total Messages Exchanged**: Count of successfully exchanged messages
- **Total Errors**: Cumulative error count
- **Total Reconnections**: Number of automatic reconnections (if enabled)
- **Throughput**: Pairs per second (actual)
- **Latency Statistics**: Min, Max, Avg message exchange latency

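The counters are plain `atomic.Int64` fields on `LoadTestMetrics` (see `rate_loadtest.go` below), so they can also be read programmatically instead of, or in addition to, calling `PrintReport()`. A minimal sketch:

```go
metrics := lt.GetMetrics()

// Load() on the atomic counters is safe to call even while the test is running.
sent := metrics.TotalPairsSent.Load()
ok := metrics.SuccessfulExchanges.Load()
failed := metrics.FailedExchanges.Load()

fmt.Printf("pairs: %d, successful: %d, failed: %d\n", sent, ok, failed)
```
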
## Graceful Shutdown Example

You can interrupt a long-running test at any time with Ctrl+C:

```
./signal-loadtest -server http://localhost:10000 -pairs-per-sec 10 -total-pairs 100 -exchange-duration 10m

# Press Ctrl+C after some time...
^C
WARN[0045]
Received interrupt signal, shutting down gracefully...

=== Load Test Report ===
Test Duration: 45.234s
Total Pairs Sent: 75
Successful Exchanges: 75
Failed Exchanges: 0
Total Messages Exchanged: 22500
Total Errors: 0
Throughput: 1.66 pairs/sec
...
========================
```

## Test Results

Example output from a 20 pairs/sec test:

```
=== Load Test Report ===
Test Duration: 5.055249917s
Total Pairs Sent: 100
Successful Exchanges: 100
Failed Exchanges: 0
Total Messages Exchanged: 100
Total Errors: 0
Throughput: 19.78 pairs/sec

Latency Statistics:
 Min: 170.375µs
 Max: 5.176916ms
 Avg: 441.566µs
========================
```

## Architecture

### Client (`client.go`)
- Manages gRPC connection to signal server
- Establishes bidirectional stream for receiving messages
- Sends messages via `Send` RPC method
- Handles message reception asynchronously

### Load Test Engine (`rate_loadtest.go`)
- Worker pool pattern for concurrent peer pairs
- Rate-limited pair creation using a ticker
- Atomic counters for thread-safe metrics collection
- Graceful shutdown on context cancellation

### Test Suite
- `loadtest_test.go`: Single pair validation test
- `rate_loadtest_test.go`: Multiple rate-based load tests and benchmarks

## Implementation Details

### Message Flow
1. Create sender and receiver clients with unique IDs
2. Both clients connect to signal server via bidirectional stream
3. Sender sends encrypted message using `Send` RPC
4. Signal server forwards message to receiver's stream
5. Receiver reads message from stream
6. Validate encrypted body size > 0
7. Record latency and success metrics (steps 1-6 are sketched in code below)

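A minimal sketch of this flow using the `Client` type from `client.go` (error handling elided for brevity; the server URL and peer IDs are illustrative):

```go
package main

import (
    "fmt"

    "github.com/netbirdio/netbird/signal/loadtest"
)

func main() {
    // Steps 1-2: create two clients with unique IDs and open their streams.
    sender, _ := loadtest.NewClient("http://localhost:10000", "sender-1")
    receiver, _ := loadtest.NewClient("http://localhost:10000", "receiver-1")
    defer sender.Close()
    defer receiver.Close()
    _ = sender.Connect()
    _ = receiver.Connect()

    // Steps 3-4: send via the Send RPC; the server forwards the message
    // to the receiver's stream.
    _ = sender.SendMessage("receiver-1", []byte("payload"))

    // Steps 5-6: read the forwarded message and validate the body size.
    msg, _ := receiver.ReceiveMessage()
    fmt.Printf("received %d bytes from %s\n", len(msg.Body), msg.Key)
}
```
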
### Concurrency
- Worker pool size defaults to `PairsPerSecond × 2` (see `Run()` in `rate_loadtest.go`)
- Each worker handles multiple peer pairs sequentially
- Atomic operations for metrics to avoid lock contention
- Channel-based work distribution (a stripped-down sketch follows this list)

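Stripped of metrics and networking, the distribution pattern amounts to the following sketch (`handlePair`, `workerPoolSize`, and the other identifiers stand in for the real fields in `rate_loadtest.go`):

```go
pairChan := make(chan int, channelBufferSize)

// Workers drain the queue concurrently; each worker processes its pairs sequentially.
var wg sync.WaitGroup
for i := 0; i < workerPoolSize; i++ {
    wg.Add(1)
    go func() {
        defer wg.Done()
        for pairID := range pairChan {
            handlePair(pairID) // executePairExchange in the real code
        }
    }()
}

// The producer is rate-limited by a ticker; a full buffer signals saturation.
ticker := time.NewTicker(time.Second / time.Duration(pairsPerSecond))
defer ticker.Stop()
for id := 0; id < totalPairs; id++ {
    <-ticker.C
    select {
    case pairChan <- id:
    default:
        log.Warnf("Worker pool saturated, skipping pair creation")
    }
}
close(pairChan)
wg.Wait()
```
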
## Future Enhancements

- [x] TLS/HTTPS support for production servers
- [x] Automatic reconnection with exponential backoff
- [ ] Ramp-up period implementation
- [ ] Percentile latency metrics (p50, p95, p99)
- [ ] Connection reuse for multiple messages per pair
- [ ] Support for custom message payloads
- [ ] CSV/JSON metrics export
- [ ] Real-time metrics dashboard

301
signal/loadtest/client.go
Normal file
@@ -0,0 +1,301 @@
package loadtest

import (
    "context"
    "crypto/tls"
    "fmt"
    "strings"
    "sync"
    "time"

    log "github.com/sirupsen/logrus"
    "google.golang.org/grpc"
    "google.golang.org/grpc/credentials"
    "google.golang.org/grpc/credentials/insecure"
    "google.golang.org/grpc/metadata"

    "github.com/netbirdio/netbird/shared/signal/proto"
)

// Client represents a signal client for load testing
type Client struct {
    id         string
    serverURL  string
    config     *ClientConfig
    conn       *grpc.ClientConn
    client     proto.SignalExchangeClient
    stream     proto.SignalExchange_ConnectStreamClient
    ctx        context.Context
    cancel     context.CancelFunc
    msgChannel chan *proto.EncryptedMessage

    mu              sync.RWMutex
    reconnectCount  int64
    connected       bool
    receiverStarted bool
}

// ClientConfig holds optional configuration for the client
type ClientConfig struct {
    InsecureSkipVerify bool
    EnableReconnect    bool
    MaxReconnectDelay  time.Duration
    InitialRetryDelay  time.Duration
}

// NewClient creates a new signal client for load testing
func NewClient(serverURL, peerID string) (*Client, error) {
    return NewClientWithConfig(serverURL, peerID, nil)
}

// NewClientWithConfig creates a new signal client with custom TLS configuration
func NewClientWithConfig(serverURL, peerID string, config *ClientConfig) (*Client, error) {
    if config == nil {
        config = &ClientConfig{}
    }

    // Set default reconnect delays if not specified
    if config.EnableReconnect {
        if config.InitialRetryDelay == 0 {
            config.InitialRetryDelay = 100 * time.Millisecond
        }
        if config.MaxReconnectDelay == 0 {
            config.MaxReconnectDelay = 30 * time.Second
        }
    }

    addr, opts, err := parseServerURL(serverURL, config.InsecureSkipVerify)
    if err != nil {
        return nil, fmt.Errorf("parse server URL: %w", err)
    }

    conn, err := grpc.Dial(addr, opts...)
    if err != nil {
        return nil, fmt.Errorf("dial server: %w", err)
    }

    client := proto.NewSignalExchangeClient(conn)
    ctx, cancel := context.WithCancel(context.Background())

    return &Client{
        id:         peerID,
        serverURL:  serverURL,
        config:     config,
        conn:       conn,
        client:     client,
        ctx:        ctx,
        cancel:     cancel,
        msgChannel: make(chan *proto.EncryptedMessage, 10),
        connected:  false,
    }, nil
}

// Connect establishes a stream connection to the signal server
func (c *Client) Connect() error {
    md := metadata.New(map[string]string{proto.HeaderId: c.id})
    ctx := metadata.NewOutgoingContext(c.ctx, md)

    stream, err := c.client.ConnectStream(ctx)
    if err != nil {
        return fmt.Errorf("connect stream: %w", err)
    }

    if _, err := stream.Header(); err != nil {
        return fmt.Errorf("receive header: %w", err)
    }

    c.mu.Lock()
    c.stream = stream
    c.connected = true
    if !c.receiverStarted {
        c.receiverStarted = true
        c.mu.Unlock()
        go c.receiveMessages()
    } else {
        c.mu.Unlock()
    }

    return nil
}

// reconnectStream reconnects the stream without starting a new receiver goroutine
func (c *Client) reconnectStream() error {
    if !c.config.EnableReconnect {
        return fmt.Errorf("reconnect disabled")
    }

    delay := c.config.InitialRetryDelay
    attempt := 0

    for {
        select {
        case <-c.ctx.Done():
            return c.ctx.Err()
        case <-time.After(delay):
            attempt++
            log.Debugf("Client %s reconnect attempt %d (delay: %v)", c.id, attempt, delay)

            md := metadata.New(map[string]string{proto.HeaderId: c.id})
            ctx := metadata.NewOutgoingContext(c.ctx, md)

            stream, err := c.client.ConnectStream(ctx)
            if err != nil {
                log.Debugf("Client %s reconnect attempt %d failed: %v", c.id, attempt, err)
                delay *= 2
                if delay > c.config.MaxReconnectDelay {
                    delay = c.config.MaxReconnectDelay
                }
                continue
            }

            if _, err := stream.Header(); err != nil {
                log.Debugf("Client %s reconnect header failed: %v", c.id, err)
                delay *= 2
                if delay > c.config.MaxReconnectDelay {
                    delay = c.config.MaxReconnectDelay
                }
                continue
            }

            c.mu.Lock()
            c.stream = stream
            c.connected = true
            c.reconnectCount++
            c.mu.Unlock()

            log.Debugf("Client %s reconnected successfully (attempt %d, total reconnects: %d)",
                c.id, attempt, c.reconnectCount)
            return nil
        }
    }
}

// SendMessage sends an encrypted message to a remote peer using the Send RPC
func (c *Client) SendMessage(remotePeerID string, body []byte) error {
    msg := &proto.EncryptedMessage{
        Key:       c.id,
        RemoteKey: remotePeerID,
        Body:      body,
    }

    ctx, cancel := context.WithTimeout(c.ctx, 10*time.Second)
    defer cancel()

    _, err := c.client.Send(ctx, msg)
    if err != nil {
        return fmt.Errorf("send message: %w", err)
    }

    return nil
}

// ReceiveMessage waits for and returns the next message
func (c *Client) ReceiveMessage() (*proto.EncryptedMessage, error) {
    select {
    case msg := <-c.msgChannel:
        return msg, nil
    case <-c.ctx.Done():
        return nil, c.ctx.Err()
    }
}

// Close closes the client connection
func (c *Client) Close() error {
    c.cancel()
    if c.conn != nil {
        return c.conn.Close()
    }
    return nil
}

func (c *Client) receiveMessages() {
    for {
        c.mu.RLock()
        stream := c.stream
        c.mu.RUnlock()

        if stream == nil {
            return
        }

        msg, err := stream.Recv()
        if err != nil {
            // Check if context is cancelled before attempting reconnection
            select {
            case <-c.ctx.Done():
                return
            default:
            }

            c.mu.Lock()
            c.connected = false
            c.mu.Unlock()

            log.Debugf("Client %s receive error: %v", c.id, err)

            // Attempt reconnection if enabled
            if c.config.EnableReconnect {
                if reconnectErr := c.reconnectStream(); reconnectErr != nil {
                    log.Debugf("Client %s reconnection failed: %v", c.id, reconnectErr)
                    return
                }
                // Successfully reconnected, continue receiving
                continue
            }

            // Reconnect disabled, exit
            return
        }

        select {
        case c.msgChannel <- msg:
        case <-c.ctx.Done():
            return
        }
    }
}

// IsConnected returns whether the client is currently connected
func (c *Client) IsConnected() bool {
    c.mu.RLock()
    defer c.mu.RUnlock()
    return c.connected
}

// GetReconnectCount returns the number of reconnections
func (c *Client) GetReconnectCount() int64 {
    c.mu.RLock()
    defer c.mu.RUnlock()
    return c.reconnectCount
}

func parseServerURL(serverURL string, insecureSkipVerify bool) (string, []grpc.DialOption, error) {
    serverURL = strings.TrimSpace(serverURL)
    if serverURL == "" {
        return "", nil, fmt.Errorf("server URL is empty")
    }

    var addr string
    var opts []grpc.DialOption

    if strings.HasPrefix(serverURL, "https://") {
        addr = strings.TrimPrefix(serverURL, "https://")
        tlsConfig := &tls.Config{
            MinVersion:         tls.VersionTLS12,
            InsecureSkipVerify: insecureSkipVerify,
        }
        opts = append(opts, grpc.WithTransportCredentials(credentials.NewTLS(tlsConfig)))
    } else if strings.HasPrefix(serverURL, "http://") {
        addr = strings.TrimPrefix(serverURL, "http://")
        opts = append(opts, grpc.WithTransportCredentials(insecure.NewCredentials()))
    } else {
        addr = serverURL
        opts = append(opts, grpc.WithTransportCredentials(insecure.NewCredentials()))
    }

    if !strings.Contains(addr, ":") {
        return "", nil, fmt.Errorf("server URL must include port")
    }

    return addr, opts, nil
}
128
signal/loadtest/cmd/signal-loadtest/integration_test.go
Normal file
@@ -0,0 +1,128 @@
package main

import (
    "context"
    "fmt"
    "net"
    "os"
    "os/exec"
    "testing"
    "time"

    "github.com/stretchr/testify/require"
    "go.opentelemetry.io/otel"
    "google.golang.org/grpc"

    "github.com/netbirdio/netbird/shared/signal/proto"
    "github.com/netbirdio/netbird/signal/server"
)

func TestCLI_SingleMessage(t *testing.T) {
    ctx, cancel := context.WithCancel(context.Background())
    defer cancel()

    grpcServer, serverAddr := startTestSignalServer(t, ctx)
    defer grpcServer.Stop()

    cmd := exec.Command("go", "run", "main.go",
        "-server", serverAddr,
        "-pairs-per-sec", "3",
        "-total-pairs", "5",
        "-message-size", "50",
        "-log-level", "warn")

    output, err := cmd.CombinedOutput()
    require.NoError(t, err, "CLI should execute successfully")

    outputStr := string(output)
    require.Contains(t, outputStr, "Load Test Report")
    require.Contains(t, outputStr, "Total Pairs Sent: 5")
    require.Contains(t, outputStr, "Successful Exchanges: 5")
    t.Logf("Output:\n%s", outputStr)
}

func TestCLI_ContinuousExchange(t *testing.T) {
    if testing.Short() {
        t.Skip("Skipping continuous exchange CLI test in short mode")
    }

    ctx, cancel := context.WithCancel(context.Background())
    defer cancel()

    grpcServer, serverAddr := startTestSignalServer(t, ctx)
    defer grpcServer.Stop()

    cmd := exec.Command("go", "run", "main.go",
        "-server", serverAddr,
        "-pairs-per-sec", "2",
        "-total-pairs", "3",
        "-message-size", "100",
        "-exchange-duration", "3s",
        "-message-interval", "100ms",
        "-log-level", "warn")

    output, err := cmd.CombinedOutput()
    require.NoError(t, err, "CLI should execute successfully")

    outputStr := string(output)
    require.Contains(t, outputStr, "Load Test Report")
    require.Contains(t, outputStr, "Total Pairs Sent: 3")
    require.Contains(t, outputStr, "Successful Exchanges: 3")
    t.Logf("Output:\n%s", outputStr)
}

func TestCLI_InvalidConfig(t *testing.T) {
    tests := []struct {
        name string
        args []string
    }{
        {
            name: "negative pairs",
            args: []string{"-pairs-per-sec", "-1"},
        },
        {
            name: "zero total pairs",
            args: []string{"-total-pairs", "0"},
        },
        {
            name: "negative message size",
            args: []string{"-message-size", "-100"},
        },
    }

    for _, tt := range tests {
        t.Run(tt.name, func(t *testing.T) {
            args := append([]string{"run", "main.go"}, tt.args...)
            cmd := exec.Command("go", args...)
            output, err := cmd.CombinedOutput()
            require.Error(t, err, "Should fail with invalid config")
            require.Contains(t, string(output), "Configuration error")
        })
    }
}

func startTestSignalServer(t *testing.T, ctx context.Context) (*grpc.Server, string) {
    listener, err := net.Listen("tcp", "127.0.0.1:0")
    require.NoError(t, err)

    grpcServer := grpc.NewServer()

    signalServer, err := server.NewServer(ctx, otel.Meter("cli-test"))
    require.NoError(t, err)

    proto.RegisterSignalExchangeServer(grpcServer, signalServer)

    go func() {
        if err := grpcServer.Serve(listener); err != nil {
            t.Logf("Server stopped: %v", err)
        }
    }()

    time.Sleep(100 * time.Millisecond)

    return grpcServer, fmt.Sprintf("http://%s", listener.Addr().String())
}

func TestMain(m *testing.M) {
    os.Exit(m.Run())
}
165
signal/loadtest/cmd/signal-loadtest/main.go
Normal file
@@ -0,0 +1,165 @@
package main

import (
    "context"
    "flag"
    "fmt"
    "os"
    "os/signal"
    "syscall"
    "time"

    log "github.com/sirupsen/logrus"

    "github.com/netbirdio/netbird/signal/loadtest"
)

var (
    serverURL          string
    pairsPerSecond     int
    totalPairs         int
    messageSize        int
    testDuration       time.Duration
    exchangeDuration   time.Duration
    messageInterval    time.Duration
    insecureSkipVerify bool
    workerPoolSize     int
    channelBufferSize  int
    reportInterval     int
    logLevel           string
    enableReconnect    bool
    maxReconnectDelay  time.Duration
    initialRetryDelay  time.Duration
)

func init() {
    flag.StringVar(&serverURL, "server", "http://localhost:10000", "Signal server URL (http:// or https://)")
    flag.IntVar(&pairsPerSecond, "pairs-per-sec", 10, "Number of peer pairs to create per second")
    flag.IntVar(&totalPairs, "total-pairs", 100, "Total number of peer pairs to create")
    flag.IntVar(&messageSize, "message-size", 100, "Size of test message in bytes")
    flag.DurationVar(&testDuration, "test-duration", 0, "Maximum test duration (0 = unlimited)")
    flag.DurationVar(&exchangeDuration, "exchange-duration", 0, "Duration for continuous message exchange per pair (0 = single message)")
    flag.DurationVar(&messageInterval, "message-interval", 100*time.Millisecond, "Interval between messages in continuous mode")
    flag.BoolVar(&insecureSkipVerify, "insecure-skip-verify", false, "Skip TLS certificate verification (use with self-signed certificates)")
    flag.IntVar(&workerPoolSize, "worker-pool-size", 0, "Number of worker goroutines (0 = auto: pairs-per-sec * 2)")
    flag.IntVar(&channelBufferSize, "channel-buffer-size", 0, "Channel buffer size (0 = auto: pairs-per-sec * 4)")
    flag.IntVar(&reportInterval, "report-interval", 10000, "Report progress every N messages (0 = no periodic reports)")
    flag.StringVar(&logLevel, "log-level", "info", "Log level (trace, debug, info, warn, error)")
    flag.BoolVar(&enableReconnect, "enable-reconnect", true, "Enable automatic reconnection on connection loss")
    flag.DurationVar(&maxReconnectDelay, "max-reconnect-delay", 30*time.Second, "Maximum delay between reconnection attempts")
    flag.DurationVar(&initialRetryDelay, "initial-retry-delay", 100*time.Millisecond, "Initial delay before first reconnection attempt")
}

func main() {
    flag.Parse()

    level, err := log.ParseLevel(logLevel)
    if err != nil {
        fmt.Fprintf(os.Stderr, "Invalid log level: %v\n", err)
        os.Exit(1)
    }
    log.SetLevel(level)

    config := loadtest.LoadTestConfig{
        ServerURL:          serverURL,
        PairsPerSecond:     pairsPerSecond,
        TotalPairs:         totalPairs,
        MessageSize:        messageSize,
        TestDuration:       testDuration,
        ExchangeDuration:   exchangeDuration,
        MessageInterval:    messageInterval,
        InsecureSkipVerify: insecureSkipVerify,
        WorkerPoolSize:     workerPoolSize,
        ChannelBufferSize:  channelBufferSize,
        ReportInterval:     reportInterval,
        EnableReconnect:    enableReconnect,
        MaxReconnectDelay:  maxReconnectDelay,
        InitialRetryDelay:  initialRetryDelay,
    }

    if err := validateConfig(config); err != nil {
        fmt.Fprintf(os.Stderr, "Configuration error: %v\n", err)
        flag.Usage()
        os.Exit(1)
    }

    log.Infof("Signal Load Test Configuration:")
    log.Infof(" Server URL: %s", config.ServerURL)
    log.Infof(" Pairs per second: %d", config.PairsPerSecond)
    log.Infof(" Total pairs: %d", config.TotalPairs)
    log.Infof(" Message size: %d bytes", config.MessageSize)
    if config.InsecureSkipVerify {
        log.Warnf(" TLS certificate verification: DISABLED (insecure)")
    }
    if config.TestDuration > 0 {
        log.Infof(" Test duration: %v", config.TestDuration)
    }
    if config.ExchangeDuration > 0 {
        log.Infof(" Exchange duration: %v", config.ExchangeDuration)
        log.Infof(" Message interval: %v", config.MessageInterval)
    } else {
        log.Infof(" Mode: Single message exchange")
    }
    if config.EnableReconnect {
        log.Infof(" Reconnection: ENABLED")
        log.Infof(" Initial retry delay: %v", config.InitialRetryDelay)
        log.Infof(" Max reconnect delay: %v", config.MaxReconnectDelay)
    }
    fmt.Println()

    // Set up signal handler for graceful shutdown
    ctx, cancel := context.WithCancel(context.Background())
    defer cancel()

    sigChan := make(chan os.Signal, 1)
    signal.Notify(sigChan, os.Interrupt, syscall.SIGTERM)

    lt := loadtest.NewLoadTestWithContext(ctx, config)

    // Run load test in a goroutine
    done := make(chan error, 1)
    go func() {
        done <- lt.Run()
    }()

    // Wait for completion or signal
    select {
    case <-sigChan:
        log.Warnf("\nReceived interrupt signal, shutting down gracefully...")
        cancel()
        // Wait a bit for graceful shutdown
        select {
        case <-done:
        case <-time.After(5 * time.Second):
            log.Warnf("Shutdown timeout, forcing exit")
        }
    case err := <-done:
        if err != nil && err != context.Canceled {
            log.Errorf("Load test failed: %v", err)
            os.Exit(1)
        }
    }

    metrics := lt.GetMetrics()
    fmt.Println() // Add blank line before report
    metrics.PrintReport()
}

func validateConfig(config loadtest.LoadTestConfig) error {
    if config.ServerURL == "" {
        return fmt.Errorf("server URL is required")
    }
    if config.PairsPerSecond <= 0 {
        return fmt.Errorf("pairs-per-sec must be greater than 0")
    }
    if config.TotalPairs <= 0 {
        return fmt.Errorf("total-pairs must be greater than 0")
    }
    if config.MessageSize <= 0 {
        return fmt.Errorf("message-size must be greater than 0")
    }
    if config.MessageInterval <= 0 {
        return fmt.Errorf("message-interval must be greater than 0")
    }
    return nil
}
40
signal/loadtest/cmd/signal-loadtest/test.sh
Normal file
@@ -0,0 +1,40 @@
#!/bin/bash
set -e

echo "Building signal-loadtest binary..."
go build -o signal-loadtest

echo ""
echo "=== Test 1: Single message exchange (5 pairs) ==="
./signal-loadtest \
  -server http://localhost:10000 \
  -pairs-per-sec 5 \
  -total-pairs 5 \
  -message-size 50 \
  -log-level info

echo ""
echo "=== Test 2: Continuous exchange (3 pairs, 5 seconds) ==="
./signal-loadtest \
  -server http://localhost:10000 \
  -pairs-per-sec 3 \
  -total-pairs 3 \
  -message-size 100 \
  -exchange-duration 5s \
  -message-interval 200ms \
  -log-level info

echo ""
echo "=== Test 3: Progress reporting (10 pairs, 10s, report every 100 messages) ==="
./signal-loadtest \
  -server http://localhost:10000 \
  -pairs-per-sec 10 \
  -total-pairs 10 \
  -message-size 100 \
  -exchange-duration 10s \
  -message-interval 100ms \
  -report-interval 100 \
  -log-level info

echo ""
echo "All tests completed!"
91
signal/loadtest/loadtest_test.go
Normal file
@@ -0,0 +1,91 @@
package loadtest

import (
    "context"
    "fmt"
    "net"
    "testing"
    "time"

    "github.com/stretchr/testify/require"
    "go.opentelemetry.io/otel"
    "google.golang.org/grpc"

    "github.com/netbirdio/netbird/shared/signal/proto"
    "github.com/netbirdio/netbird/signal/server"
)

func TestSignalLoadTest_SinglePair(t *testing.T) {
    ctx, cancel := context.WithCancel(context.Background())
    defer cancel()

    grpcServer, serverAddr := startTestSignalServer(t, ctx)
    defer grpcServer.Stop()

    sender, err := NewClient(serverAddr, "sender-peer-id")
    require.NoError(t, err)
    defer sender.Close()

    receiver, err := NewClient(serverAddr, "receiver-peer-id")
    require.NoError(t, err)
    defer receiver.Close()

    err = sender.Connect()
    require.NoError(t, err, "Sender should connect successfully")

    err = receiver.Connect()
    require.NoError(t, err, "Receiver should connect successfully")

    time.Sleep(100 * time.Millisecond)

    testMessage := []byte("test message payload")

    t.Log("Sending message from sender to receiver")
    err = sender.SendMessage("receiver-peer-id", testMessage)
    require.NoError(t, err, "Sender should send message successfully")

    t.Log("Waiting for receiver to receive message")

    receiveDone := make(chan struct{})
    var msg *proto.EncryptedMessage
    var receiveErr error

    go func() {
        msg, receiveErr = receiver.ReceiveMessage()
        close(receiveDone)
    }()

    select {
    case <-receiveDone:
        require.NoError(t, receiveErr, "Receiver should receive message")
        require.NotNil(t, msg, "Received message should not be nil")
        require.Greater(t, len(msg.Body), 0, "Encrypted message body size should be greater than 0")
        require.Equal(t, "sender-peer-id", msg.Key)
        require.Equal(t, "receiver-peer-id", msg.RemoteKey)
        t.Log("Message received successfully")
    case <-time.After(5 * time.Second):
        t.Fatal("Timeout waiting for message")
    }
}

func startTestSignalServer(t *testing.T, ctx context.Context) (*grpc.Server, string) {
    listener, err := net.Listen("tcp", "127.0.0.1:0")
    require.NoError(t, err)

    grpcServer := grpc.NewServer()

    signalServer, err := server.NewServer(ctx, otel.Meter("test"))
    require.NoError(t, err)

    proto.RegisterSignalExchangeServer(grpcServer, signalServer)

    go func() {
        if err := grpcServer.Serve(listener); err != nil {
            t.Logf("Server stopped: %v", err)
        }
    }()

    time.Sleep(100 * time.Millisecond)

    return grpcServer, fmt.Sprintf("http://%s", listener.Addr().String())
}
461
signal/loadtest/rate_loadtest.go
Normal file
@@ -0,0 +1,461 @@
package loadtest

import (
    "context"
    "fmt"
    "sync"
    "sync/atomic"
    "time"

    log "github.com/sirupsen/logrus"
)

// LoadTestConfig configuration for the load test
type LoadTestConfig struct {
    IDPrefix           string
    ServerURL          string
    PairsPerSecond     int
    TotalPairs         int
    MessageSize        int
    TestDuration       time.Duration
    ExchangeDuration   time.Duration
    MessageInterval    time.Duration
    RampUpDuration     time.Duration
    InsecureSkipVerify bool
    WorkerPoolSize     int
    ChannelBufferSize  int
    ReportInterval     int // Report progress every N messages (0 = no periodic reports)
    EnableReconnect    bool
    MaxReconnectDelay  time.Duration
    InitialRetryDelay  time.Duration
}

// LoadTestMetrics metrics collected during the load test
type LoadTestMetrics struct {
    TotalPairsSent         atomic.Int64
    TotalMessagesExchanged atomic.Int64
    TotalErrors            atomic.Int64
    SuccessfulExchanges    atomic.Int64
    FailedExchanges        atomic.Int64
    ActivePairs            atomic.Int64
    TotalReconnections     atomic.Int64

    mu        sync.Mutex
    latencies []time.Duration
    startTime time.Time
    endTime   time.Time
}

// PeerPair represents a sender-receiver pair
type PeerPair struct {
    sender   *Client
    receiver *Client
    pairID   int
}

// LoadTest manages the load test execution
type LoadTest struct {
    config         LoadTestConfig
    metrics        *LoadTestMetrics
    ctx            context.Context
    cancel         context.CancelFunc
    reporterCtx    context.Context
    reporterCancel context.CancelFunc
}

// NewLoadTest creates a new load test instance
func NewLoadTest(config LoadTestConfig) *LoadTest {
    ctx, cancel := context.WithCancel(context.Background())
    return newLoadTestWithContext(ctx, cancel, config)
}

// NewLoadTestWithContext creates a new load test instance with a custom context
func NewLoadTestWithContext(ctx context.Context, config LoadTestConfig) *LoadTest {
    ctx, cancel := context.WithCancel(ctx)
    return newLoadTestWithContext(ctx, cancel, config)
}

func newLoadTestWithContext(ctx context.Context, cancel context.CancelFunc, config LoadTestConfig) *LoadTest {
    reporterCtx, reporterCancel := context.WithCancel(ctx)
    config.IDPrefix = fmt.Sprintf("%d-", time.Now().UnixNano())
    return &LoadTest{
        config:         config,
        metrics:        &LoadTestMetrics{},
        ctx:            ctx,
        cancel:         cancel,
        reporterCtx:    reporterCtx,
        reporterCancel: reporterCancel,
    }
}

// Run executes the load test
func (lt *LoadTest) Run() error {
    lt.metrics.startTime = time.Now()
    defer func() {
        lt.metrics.endTime = time.Now()
    }()

    exchangeInfo := "single message"
    if lt.config.ExchangeDuration > 0 {
        exchangeInfo = fmt.Sprintf("continuous for %v", lt.config.ExchangeDuration)
    }

    workerPoolSize := lt.config.WorkerPoolSize
    if workerPoolSize == 0 {
        workerPoolSize = lt.config.PairsPerSecond * 2
    }

    channelBufferSize := lt.config.ChannelBufferSize
    if channelBufferSize == 0 {
        channelBufferSize = lt.config.PairsPerSecond * 4
    }

    log.Infof("Starting load test: %d pairs/sec, %d total pairs, message size: %d bytes, exchange: %s",
        lt.config.PairsPerSecond, lt.config.TotalPairs, lt.config.MessageSize, exchangeInfo)
    log.Infof("Worker pool size: %d, channel buffer: %d", workerPoolSize, channelBufferSize)

    var wg sync.WaitGroup
    var reporterWg sync.WaitGroup
    pairChan := make(chan int, channelBufferSize)

    // Start progress reporter if configured
    if lt.config.ReportInterval > 0 {
        reporterWg.Add(1)
        go lt.progressReporter(&reporterWg, lt.config.ReportInterval)
    }

    for i := 0; i < workerPoolSize; i++ {
        wg.Add(1)
        go lt.pairWorker(&wg, pairChan)
    }

    testCtx := lt.ctx
    if lt.config.TestDuration > 0 {
        var testCancel context.CancelFunc
        testCtx, testCancel = context.WithTimeout(lt.ctx, lt.config.TestDuration)
        defer testCancel()
    }

    ticker := time.NewTicker(time.Second / time.Duration(lt.config.PairsPerSecond))
    defer ticker.Stop()

    pairsCreated := 0
    for pairsCreated < lt.config.TotalPairs {
        select {
        case <-testCtx.Done():
            log.Infof("Test duration reached or context cancelled")
            close(pairChan)
            wg.Wait()
            return testCtx.Err()
        case <-ticker.C:
            select {
            case pairChan <- pairsCreated:
                pairsCreated++
            default:
                log.Warnf("Worker pool saturated, skipping pair creation")
            }
        }
    }

    log.Infof("All %d pairs queued, waiting for completion...", pairsCreated)
    close(pairChan)
    wg.Wait()

    // Cancel progress reporter context after all work is done and wait for it
    lt.reporterCancel()
    reporterWg.Wait()

    return nil
}

func (lt *LoadTest) pairWorker(wg *sync.WaitGroup, pairChan <-chan int) {
    defer wg.Done()

    for pairID := range pairChan {
        lt.metrics.ActivePairs.Add(1)
        if err := lt.executePairExchange(pairID); err != nil {
            lt.metrics.TotalErrors.Add(1)
            lt.metrics.FailedExchanges.Add(1)
            log.Debugf("Pair %d exchange failed: %v", pairID, err)
        } else {
            lt.metrics.SuccessfulExchanges.Add(1)
        }
        lt.metrics.ActivePairs.Add(-1)
        lt.metrics.TotalPairsSent.Add(1)
    }
}

func (lt *LoadTest) executePairExchange(pairID int) error {
    senderID := fmt.Sprintf("%ssender-%d", lt.config.IDPrefix, pairID)
    receiverID := fmt.Sprintf("%sreceiver-%d", lt.config.IDPrefix, pairID)

    clientConfig := &ClientConfig{
        InsecureSkipVerify: lt.config.InsecureSkipVerify,
        EnableReconnect:    lt.config.EnableReconnect,
        MaxReconnectDelay:  lt.config.MaxReconnectDelay,
        InitialRetryDelay:  lt.config.InitialRetryDelay,
    }

    sender, err := NewClientWithConfig(lt.config.ServerURL, senderID, clientConfig)
    if err != nil {
        return fmt.Errorf("create sender: %w", err)
    }
    defer func() {
        sender.Close()
        lt.metrics.TotalReconnections.Add(sender.GetReconnectCount())
    }()

    receiver, err := NewClientWithConfig(lt.config.ServerURL, receiverID, clientConfig)
    if err != nil {
        return fmt.Errorf("create receiver: %w", err)
    }
    defer func() {
        receiver.Close()
        lt.metrics.TotalReconnections.Add(receiver.GetReconnectCount())
    }()

    if err := sender.Connect(); err != nil {
        return fmt.Errorf("sender connect: %w", err)
    }

    if err := receiver.Connect(); err != nil {
        return fmt.Errorf("receiver connect: %w", err)
    }

    time.Sleep(50 * time.Millisecond)

    testMessage := make([]byte, lt.config.MessageSize)
    for i := range testMessage {
        testMessage[i] = byte(i % 256)
    }

    if lt.config.ExchangeDuration > 0 {
        return lt.continuousExchange(pairID, sender, receiver, receiverID, testMessage)
    }

    return lt.singleExchange(sender, receiver, receiverID, testMessage)
}

func (lt *LoadTest) singleExchange(sender, receiver *Client, receiverID string, testMessage []byte) error {
    startTime := time.Now()

    if err := sender.SendMessage(receiverID, testMessage); err != nil {
        return fmt.Errorf("send message: %w", err)
    }

    receiveDone := make(chan error, 1)
    go func() {
        msg, err := receiver.ReceiveMessage()
        if err != nil {
            receiveDone <- err
            return
        }
        if len(msg.Body) == 0 {
            receiveDone <- fmt.Errorf("empty message body")
            return
        }
        receiveDone <- nil
    }()

    select {
    case err := <-receiveDone:
        if err != nil {
            return fmt.Errorf("receive message: %w", err)
        }
        latency := time.Since(startTime)
        lt.recordLatency(latency)
        lt.metrics.TotalMessagesExchanged.Add(1)
        return nil
    case <-time.After(5 * time.Second):
        return fmt.Errorf("timeout waiting for message")
    case <-lt.ctx.Done():
        return lt.ctx.Err()
    }
}

func (lt *LoadTest) continuousExchange(pairID int, sender, receiver *Client, receiverID string, testMessage []byte) error {
    exchangeCtx, cancel := context.WithTimeout(lt.ctx, lt.config.ExchangeDuration)
    defer cancel()

    messageInterval := lt.config.MessageInterval
    if messageInterval == 0 {
        messageInterval = 100 * time.Millisecond
    }

    errChan := make(chan error, 1)
    var wg sync.WaitGroup

    wg.Add(1)
    go func() {
        defer wg.Done()
        if err := lt.receiverLoop(exchangeCtx, receiver, pairID); err != nil && err != context.DeadlineExceeded && err != context.Canceled {
            select {
            case errChan <- err:
            default:
            }
        }
    }()

    wg.Add(1)
    go func() {
        defer wg.Done()
        if err := lt.senderLoop(exchangeCtx, sender, receiverID, testMessage, messageInterval); err != nil && err != context.DeadlineExceeded && err != context.Canceled {
            select {
            case errChan <- err:
            default:
            }
        }
    }()

    wg.Wait()

    select {
    case err := <-errChan:
        return err
    default:
        return nil
    }
}

func (lt *LoadTest) senderLoop(ctx context.Context, sender *Client, receiverID string, message []byte, interval time.Duration) error {
    ticker := time.NewTicker(interval)
    defer ticker.Stop()

    for {
        select {
        case <-ctx.Done():
            return ctx.Err()
        case <-ticker.C:
            startTime := time.Now()
            if err := sender.SendMessage(receiverID, message); err != nil {
                lt.metrics.TotalErrors.Add(1)
                log.Debugf("Send error: %v", err)
                continue
            }
            lt.recordLatency(time.Since(startTime))
        }
    }
}

func (lt *LoadTest) receiverLoop(ctx context.Context, receiver *Client, pairID int) error {
    for {
        if ctx.Err() != nil {
            return ctx.Err()
        }

        select {
        case msg, ok := <-receiver.msgChannel:
            if !ok {
                return nil
            }
            if len(msg.Body) > 0 {
                lt.metrics.TotalMessagesExchanged.Add(1)
            }
        case <-ctx.Done():
            return ctx.Err()
        case <-time.After(100 * time.Millisecond):
            continue
        }
    }
}

func (lt *LoadTest) recordLatency(latency time.Duration) {
    lt.metrics.mu.Lock()
    defer lt.metrics.mu.Unlock()
    lt.metrics.latencies = append(lt.metrics.latencies, latency)
}

// progressReporter prints periodic progress reports
func (lt *LoadTest) progressReporter(wg *sync.WaitGroup, interval int) {
    defer wg.Done()

    lastReported := int64(0)
    ticker := time.NewTicker(1 * time.Second)
    defer ticker.Stop()

    for {
        select {
        case <-lt.reporterCtx.Done():
            return
        case <-ticker.C:
            currentMessages := lt.metrics.TotalMessagesExchanged.Load()
            if currentMessages-lastReported >= int64(interval) {
                elapsed := time.Since(lt.metrics.startTime)
                activePairs := lt.metrics.ActivePairs.Load()
                errors := lt.metrics.TotalErrors.Load()
                reconnections := lt.metrics.TotalReconnections.Load()

                var msgRate float64
                if elapsed.Seconds() > 0 {
                    msgRate = float64(currentMessages) / elapsed.Seconds()
                }

                log.Infof("Progress: %d messages exchanged, %d active pairs, %d errors, %d reconnections, %.2f msg/sec, elapsed: %v",
                    currentMessages, activePairs, errors, reconnections, msgRate, elapsed.Round(time.Second))

                lastReported = (currentMessages / int64(interval)) * int64(interval)
            }
        }
    }
}

// Stop stops the load test
func (lt *LoadTest) Stop() {
    lt.cancel()
    lt.reporterCancel()
}

// GetMetrics returns the collected metrics
func (lt *LoadTest) GetMetrics() *LoadTestMetrics {
    return lt.metrics
}

// PrintReport prints a summary report of the test results
func (m *LoadTestMetrics) PrintReport() {
    duration := m.endTime.Sub(m.startTime)

    fmt.Println("\n=== Load Test Report ===")
    fmt.Printf("Test Duration: %v\n", duration)
    fmt.Printf("Total Pairs Sent: %d\n", m.TotalPairsSent.Load())
    fmt.Printf("Successful Exchanges: %d\n", m.SuccessfulExchanges.Load())
    fmt.Printf("Failed Exchanges: %d\n", m.FailedExchanges.Load())
    fmt.Printf("Total Messages Exchanged: %d\n", m.TotalMessagesExchanged.Load())
    fmt.Printf("Total Errors: %d\n", m.TotalErrors.Load())

    reconnections := m.TotalReconnections.Load()
    if reconnections > 0 {
        fmt.Printf("Total Reconnections: %d\n", reconnections)
    }

    if duration.Seconds() > 0 {
        throughput := float64(m.SuccessfulExchanges.Load()) / duration.Seconds()
        fmt.Printf("Throughput: %.2f pairs/sec\n", throughput)
    }

    m.mu.Lock()
    latencies := m.latencies
    m.mu.Unlock()

    if len(latencies) > 0 {
        var total time.Duration
        minLatency := latencies[0]
        maxLatency := latencies[0]

        for _, lat := range latencies {
            total += lat
            if lat < minLatency {
                minLatency = lat
            }
            if lat > maxLatency {
                maxLatency = lat
            }
        }

        avg := total / time.Duration(len(latencies))
        fmt.Printf("\nLatency Statistics:\n")
        fmt.Printf(" Min: %v\n", minLatency)
        fmt.Printf(" Max: %v\n", maxLatency)
        fmt.Printf(" Avg: %v\n", avg)
    }
    fmt.Println("========================")
}
305
signal/loadtest/rate_loadtest_test.go
Normal file
@@ -0,0 +1,305 @@
package loadtest

import (
    "context"
    "fmt"
    "net"
    "testing"
    "time"

    "github.com/stretchr/testify/require"
    "go.opentelemetry.io/otel"
    "google.golang.org/grpc"

    "github.com/netbirdio/netbird/shared/signal/proto"
    "github.com/netbirdio/netbird/signal/server"
)

func TestLoadTest_10PairsPerSecond(t *testing.T) {
    if testing.Short() {
        t.Skip("Skipping load test in short mode")
    }

    ctx, cancel := context.WithCancel(context.Background())
    defer cancel()

    grpcServer, serverAddr := startTestSignalServerForLoad(t, ctx)
    defer grpcServer.Stop()

    config := LoadTestConfig{
        ServerURL:      serverAddr,
        PairsPerSecond: 10,
        TotalPairs:     50,
        MessageSize:    100,
        TestDuration:   30 * time.Second,
    }

    loadTest := NewLoadTest(config)
    err := loadTest.Run()
    require.NoError(t, err)

    metrics := loadTest.GetMetrics()
    metrics.PrintReport()

    require.Equal(t, int64(50), metrics.TotalPairsSent.Load(), "Should send all 50 pairs")
    require.Greater(t, metrics.SuccessfulExchanges.Load(), int64(0), "Should have successful exchanges")
    require.Equal(t, metrics.TotalMessagesExchanged.Load(), metrics.SuccessfulExchanges.Load(), "Messages exchanged should match successful exchanges")
}

func TestLoadTest_20PairsPerSecond(t *testing.T) {
    if testing.Short() {
        t.Skip("Skipping load test in short mode")
    }

    ctx, cancel := context.WithCancel(context.Background())
    defer cancel()

    grpcServer, serverAddr := startTestSignalServerForLoad(t, ctx)
    defer grpcServer.Stop()

    config := LoadTestConfig{
        ServerURL:      serverAddr,
        PairsPerSecond: 20,
        TotalPairs:     100,
        MessageSize:    500,
        TestDuration:   30 * time.Second,
    }

    loadTest := NewLoadTest(config)
    err := loadTest.Run()
    require.NoError(t, err)

    metrics := loadTest.GetMetrics()
    metrics.PrintReport()

    require.Equal(t, int64(100), metrics.TotalPairsSent.Load(), "Should send all 100 pairs")
    require.Greater(t, metrics.SuccessfulExchanges.Load(), int64(0), "Should have successful exchanges")
}

func TestLoadTest_SmallBurst(t *testing.T) {
    ctx, cancel := context.WithCancel(context.Background())
    defer cancel()

    grpcServer, serverAddr := startTestSignalServerForLoad(t, ctx)
    defer grpcServer.Stop()

    config := LoadTestConfig{
        ServerURL:      serverAddr,
        PairsPerSecond: 5,
        TotalPairs:     10,
        MessageSize:    50,
        TestDuration:   10 * time.Second,
    }

    loadTest := NewLoadTest(config)
    err := loadTest.Run()
    require.NoError(t, err)

    metrics := loadTest.GetMetrics()
    metrics.PrintReport()

    require.Equal(t, int64(10), metrics.TotalPairsSent.Load())
    require.Greater(t, metrics.SuccessfulExchanges.Load(), int64(5), "At least 50% success rate")
    require.Less(t, metrics.FailedExchanges.Load(), int64(5), "Less than 50% failure rate")
}

func TestLoadTest_ContinuousExchange_30Seconds(t *testing.T) {
    if testing.Short() {
        t.Skip("Skipping continuous exchange test in short mode")
    }

    ctx, cancel := context.WithCancel(context.Background())
    defer cancel()

    grpcServer, serverAddr := startTestSignalServerForLoad(t, ctx)
    defer grpcServer.Stop()

    config := LoadTestConfig{
        ServerURL:        serverAddr,
        PairsPerSecond:   5,
        TotalPairs:       10,
        MessageSize:      100,
        ExchangeDuration: 30 * time.Second,
        MessageInterval:  100 * time.Millisecond,
        TestDuration:     2 * time.Minute,
    }

    loadTest := NewLoadTest(config)
    err := loadTest.Run()
    require.NoError(t, err)

    metrics := loadTest.GetMetrics()
    metrics.PrintReport()

    require.Equal(t, int64(10), metrics.TotalPairsSent.Load())
    require.Greater(t, metrics.TotalMessagesExchanged.Load(), int64(2000), "Should exchange many messages over 30 seconds")
}

func TestLoadTest_ContinuousExchange_10Minutes(t *testing.T) {
    if testing.Short() {
        t.Skip("Skipping long continuous exchange test in short mode")
    }

    ctx, cancel := context.WithCancel(context.Background())
    defer cancel()

    grpcServer, serverAddr := startTestSignalServerForLoad(t, ctx)
    defer grpcServer.Stop()

    config := LoadTestConfig{
        ServerURL:        serverAddr,
        PairsPerSecond:   10,
        TotalPairs:       20,
        MessageSize:      200,
        ExchangeDuration: 10 * time.Minute,
        MessageInterval:  200 * time.Millisecond,
        TestDuration:     15 * time.Minute,
    }

    loadTest := NewLoadTest(config)
    err := loadTest.Run()
    require.NoError(t, err)

    metrics := loadTest.GetMetrics()
    metrics.PrintReport()

    require.Equal(t, int64(20), metrics.TotalPairsSent.Load())
    require.Greater(t, metrics.TotalMessagesExchanged.Load(), int64(50000), "Should exchange many messages over 10 minutes")
}

func TestLoadTest_ContinuousExchange_ShortBurst(t *testing.T) {
    ctx, cancel := context.WithCancel(context.Background())
    defer cancel()

    grpcServer, serverAddr := startTestSignalServerForLoad(t, ctx)
    defer grpcServer.Stop()

    config := LoadTestConfig{
        ServerURL:        serverAddr,
        PairsPerSecond:   3,
        TotalPairs:       5,
        MessageSize:      50,
        ExchangeDuration: 3 * time.Second,
        MessageInterval:  100 * time.Millisecond,
        TestDuration:     10 * time.Second,
        ReportInterval:   50, // Report every 50 messages for testing
    }

    loadTest := NewLoadTest(config)
    err := loadTest.Run()
    require.NoError(t, err)

    metrics := loadTest.GetMetrics()
    metrics.PrintReport()

    require.Equal(t, int64(5), metrics.TotalPairsSent.Load())
    require.Greater(t, metrics.TotalMessagesExchanged.Load(), int64(100), "Should exchange multiple messages in 3 seconds")
    require.Equal(t, int64(5), metrics.SuccessfulExchanges.Load(), "All pairs should complete successfully")
}

func TestLoadTest_ReconnectionConfig(t *testing.T) {
    ctx, cancel := context.WithCancel(context.Background())
    defer cancel()

    grpcServer, serverAddr := startTestSignalServerForLoad(t, ctx)
    defer grpcServer.Stop()

    config := LoadTestConfig{
        ServerURL:         serverAddr,
        PairsPerSecond:    3,
        TotalPairs:        5,
        MessageSize:       50,
        ExchangeDuration:  2 * time.Second,
        MessageInterval:   200 * time.Millisecond,
        TestDuration:      5 * time.Second,
        EnableReconnect:   true,
        InitialRetryDelay: 100 * time.Millisecond,
        MaxReconnectDelay: 2 * time.Second,
    }

    loadTest := NewLoadTest(config)
    err := loadTest.Run()
    require.NoError(t, err)

    metrics := loadTest.GetMetrics()
    metrics.PrintReport()

    // Test should complete successfully with reconnection enabled
    require.Equal(t, int64(5), metrics.TotalPairsSent.Load())
    require.Greater(t, metrics.TotalMessagesExchanged.Load(), int64(0), "Should have exchanged messages")
    require.Equal(t, int64(5), metrics.SuccessfulExchanges.Load(), "All pairs should complete successfully")

    // Reconnections counter should exist (even if zero for this stable test)
    reconnections := metrics.TotalReconnections.Load()
    require.GreaterOrEqual(t, reconnections, int64(0), "Reconnections metric should be tracked")
}

func BenchmarkLoadTest_Throughput(b *testing.B) {
    ctx, cancel := context.WithCancel(context.Background())
    defer cancel()

    grpcServer, serverAddr := startBenchSignalServer(b, ctx)
    defer grpcServer.Stop()

    b.Run("5pairs-per-sec", func(b *testing.B) {
        config := LoadTestConfig{
            ServerURL:      serverAddr,
            PairsPerSecond: 5,
            TotalPairs:     b.N,
            MessageSize:    100,
        }

        loadTest := NewLoadTest(config)
        b.ResetTimer()
        _ = loadTest.Run()
        b.StopTimer()

        metrics := loadTest.GetMetrics()
        b.ReportMetric(float64(metrics.SuccessfulExchanges.Load()), "successful")
        b.ReportMetric(float64(metrics.FailedExchanges.Load()), "failed")
    })
}

func startTestSignalServerForLoad(t *testing.T, ctx context.Context) (*grpc.Server, string) {
    listener, err := net.Listen("tcp", "127.0.0.1:0")
    require.NoError(t, err)

    grpcServer := grpc.NewServer()

    signalServer, err := server.NewServer(ctx, otel.Meter("test"))
    require.NoError(t, err)

    proto.RegisterSignalExchangeServer(grpcServer, signalServer)

    go func() {
        if err := grpcServer.Serve(listener); err != nil {
            t.Logf("Server stopped: %v", err)
        }
    }()

    time.Sleep(100 * time.Millisecond)

    return grpcServer, fmt.Sprintf("http://%s", listener.Addr().String())
}

func startBenchSignalServer(b *testing.B, ctx context.Context) (*grpc.Server, string) {
    listener, err := net.Listen("tcp", "127.0.0.1:0")
    require.NoError(b, err)

    grpcServer := grpc.NewServer()

    signalServer, err := server.NewServer(ctx, otel.Meter("bench"))
    require.NoError(b, err)

    proto.RegisterSignalExchangeServer(grpcServer, signalServer)

    go func() {
        if err := grpcServer.Serve(listener); err != nil {
            b.Logf("Server stopped: %v", err)
        }
    }()

    time.Sleep(100 * time.Millisecond)

    return grpcServer, fmt.Sprintf("http://%s", listener.Addr().String())
}
@@ -2,42 +2,41 @@ package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/netip"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/coder/websocket"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/net/http2"
|
||||
|
||||
"github.com/netbirdio/netbird/util/wsproxy"
|
||||
)
|
||||
|
||||
const (
|
||||
dialTimeout = 10 * time.Second
|
||||
bufferSize = 32 * 1024
|
||||
bufferSize = 32 * 1024
|
||||
ioTimeout = 5 * time.Second
|
||||
)
|
||||
|
||||
// Config contains the configuration for the WebSocket proxy.
|
||||
type Config struct {
|
||||
LocalGRPCAddr netip.AddrPort
|
||||
Handler http.Handler
|
||||
Path string
|
||||
MetricsRecorder MetricsRecorder
|
||||
}
|
||||
|
||||
// Proxy handles WebSocket to TCP proxying for gRPC connections.
|
||||
// Proxy handles WebSocket to gRPC handler proxying.
|
||||
type Proxy struct {
|
||||
config Config
|
||||
metrics MetricsRecorder
|
||||
}
|
||||
|
||||
// New creates a new WebSocket proxy instance with optional configuration
|
||||
func New(localGRPCAddr netip.AddrPort, opts ...Option) *Proxy {
|
||||
func New(handler http.Handler, opts ...Option) *Proxy {
|
||||
config := Config{
|
||||
LocalGRPCAddr: localGRPCAddr,
|
||||
Handler: handler,
|
||||
Path: wsproxy.ProxyPath,
|
||||
MetricsRecorder: NoOpMetricsRecorder{}, // Default to no-op
|
||||
}
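
With this constructor change, callers hand the proxy an http.Handler instead of a loopback gRPC address. A hypothetical wiring sketch, not part of this diff; it assumes the caller passes grpc-go's *grpc.Server, which satisfies http.Handler through its ServeHTTP method:

    // illustrative only; signalServer/proto follow the test helpers above
    grpcSrv := grpc.NewServer()
    proto.RegisterSignalExchangeServer(grpcSrv, signalServer)
    wsProxy := New(grpcSrv) // previously: New(localGRPCAddr)
    // wsProxy is then mounted at config.Path by the surrounding HTTP server (not shown)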

@@ -63,7 +62,7 @@ func (p *Proxy) handleWebSocket(w http.ResponseWriter, r *http.Request) {
    p.metrics.RecordConnection(ctx)
    defer p.metrics.RecordDisconnection(ctx)

    log.Debugf("WebSocket proxy handling connection from %s, forwarding to %s", r.RemoteAddr, p.config.LocalGRPCAddr)
    log.Debugf("WebSocket proxy handling connection from %s, forwarding to internal gRPC handler", r.RemoteAddr)
    acceptOptions := &websocket.AcceptOptions{
        OriginPatterns: []string{"*"},
    }

@@ -75,71 +74,41 @@ func (p *Proxy) handleWebSocket(w http.ResponseWriter, r *http.Request) {
        return
    }
    defer func() {
        if err := wsConn.Close(websocket.StatusNormalClosure, ""); err != nil {
            log.Debugf("Failed to close WebSocket: %v", err)
        }
        _ = wsConn.Close(websocket.StatusNormalClosure, "")
    }()

    log.Debugf("WebSocket proxy attempting to connect to local gRPC at %s", p.config.LocalGRPCAddr)
    tcpConn, err := net.DialTimeout("tcp", p.config.LocalGRPCAddr.String(), dialTimeout)
    if err != nil {
        p.metrics.RecordError(ctx, "tcp_dial_failed")
        log.Warnf("Failed to connect to local gRPC server at %s: %v", p.config.LocalGRPCAddr, err)
        if err := wsConn.Close(websocket.StatusInternalError, "Backend unavailable"); err != nil {
            log.Debugf("Failed to close WebSocket after connection failure: %v", err)
        }
        return
    }
    clientConn, serverConn := net.Pipe()
    defer func() {
        if err := tcpConn.Close(); err != nil {
            log.Debugf("Failed to close TCP connection: %v", err)
        }
        _ = clientConn.Close()
        _ = serverConn.Close()
    }()

    log.Debugf("WebSocket proxy established: client %s -> local gRPC %s", r.RemoteAddr, p.config.LocalGRPCAddr)
    log.Debugf("WebSocket proxy established: %s -> gRPC handler", r.RemoteAddr)

    p.proxyData(ctx, wsConn, tcpConn)
    go func() {
        (&http2.Server{}).ServeConn(serverConn, &http2.ServeConnOpts{
            Context: ctx,
            Handler: p.config.Handler,
        })
    }()

    p.proxyData(ctx, wsConn, clientConn, r.RemoteAddr)
}
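
Instead of dialing a loopback TCP port, the new handleWebSocket serves the gRPC handler over an in-memory net.Pipe using HTTP/2 cleartext. A self-contained sketch of the same technique, with a plain http.Handler standing in for the gRPC server (all names illustrative):

    package main

    import (
        "crypto/tls"
        "fmt"
        "io"
        "net"
        "net/http"

        "golang.org/x/net/http2"
    )

    func main() {
        clientConn, serverConn := net.Pipe()

        // Serve an http.Handler over the pipe's server end as HTTP/2 cleartext;
        // in the proxy above, p.config.Handler takes this handler's place.
        go (&http2.Server{}).ServeConn(serverConn, &http2.ServeConnOpts{
            Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
                fmt.Fprint(w, "hello over an in-memory pipe")
            }),
        })

        // Drive the client end with an h2c transport whose dialer returns the pipe.
        tr := &http2.Transport{
            AllowHTTP: true,
            DialTLS: func(network, addr string, _ *tls.Config) (net.Conn, error) {
                return clientConn, nil
            },
        }
        resp, err := (&http.Client{Transport: tr}).Get("http://pipe")
        if err != nil {
            panic(err)
        }
        defer resp.Body.Close()
        body, _ := io.ReadAll(resp.Body)
        fmt.Println(string(body))
    }

Serving the handler directly removes the tcp_dial_failed error path and the extra loopback hop present in the old code.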

func (p *Proxy) proxyData(ctx context.Context, wsConn *websocket.Conn, tcpConn net.Conn) {
func (p *Proxy) proxyData(ctx context.Context, wsConn *websocket.Conn, pipeConn net.Conn, clientAddr string) {
    proxyCtx, cancel := context.WithCancel(ctx)
    defer cancel()

    var wg sync.WaitGroup
    wg.Add(2)

    go p.wsToTCP(proxyCtx, cancel, &wg, wsConn, tcpConn)
    go p.tcpToWS(proxyCtx, cancel, &wg, wsConn, tcpConn)
    go p.wsToPipe(proxyCtx, cancel, &wg, wsConn, pipeConn, clientAddr)
    go p.pipeToWS(proxyCtx, cancel, &wg, wsConn, pipeConn, clientAddr)

    done := make(chan struct{})
    go func() {
        wg.Wait()
        close(done)
    }()

    select {
    case <-done:
        log.Tracef("Proxy data transfer completed, both goroutines terminated")
    case <-proxyCtx.Done():
        log.Tracef("Proxy data transfer cancelled, forcing connection closure")

        if err := wsConn.Close(websocket.StatusGoingAway, "proxy cancelled"); err != nil {
            log.Tracef("Error closing WebSocket during cancellation: %v", err)
        }
        if err := tcpConn.Close(); err != nil {
            log.Tracef("Error closing TCP connection during cancellation: %v", err)
        }

        select {
        case <-done:
            log.Tracef("Goroutines terminated after forced connection closure")
        case <-time.After(2 * time.Second):
            log.Tracef("Goroutines did not terminate within timeout after connection closure")
        }
    }
    wg.Wait()
}

func (p *Proxy) wsToTCP(ctx context.Context, cancel context.CancelFunc, wg *sync.WaitGroup, wsConn *websocket.Conn, tcpConn net.Conn) {
func (p *Proxy) wsToPipe(ctx context.Context, cancel context.CancelFunc, wg *sync.WaitGroup, wsConn *websocket.Conn, pipeConn net.Conn, clientAddr string) {
    defer wg.Done()
    defer cancel()

@@ -148,80 +117,73 @@ func (p *Proxy) wsToTCP(ctx context.Context, cancel context.CancelFunc, wg *sync
        if err != nil {
            switch {
            case ctx.Err() != nil:
                log.Debugf("wsToTCP goroutine terminating due to context cancellation")
            case websocket.CloseStatus(err) == websocket.StatusNormalClosure:
                log.Debugf("WebSocket closed normally")
                log.Debugf("WebSocket from %s terminating due to context cancellation", clientAddr)
            case websocket.CloseStatus(err) != -1:
                log.Debugf("WebSocket from %s disconnected", clientAddr)
            default:
                p.metrics.RecordError(ctx, "websocket_read_error")
                log.Errorf("WebSocket read error: %v", err)
                log.Debugf("WebSocket read error from %s: %v", clientAddr, err)
            }
            return
        }

        if msgType != websocket.MessageBinary {
            log.Warnf("Unexpected WebSocket message type: %v", msgType)
            log.Warnf("Unexpected WebSocket message type from %s: %v", clientAddr, msgType)
            continue
        }

        if ctx.Err() != nil {
            log.Tracef("wsToTCP goroutine terminating due to context cancellation before TCP write")
            log.Tracef("wsToPipe goroutine terminating due to context cancellation before pipe write")
            return
        }

        if err := tcpConn.SetWriteDeadline(time.Now().Add(5 * time.Second)); err != nil {
            log.Debugf("Failed to set TCP write deadline: %v", err)
        if err := pipeConn.SetWriteDeadline(time.Now().Add(ioTimeout)); err != nil {
            log.Debugf("Failed to set pipe write deadline: %v", err)
        }

        n, err := tcpConn.Write(data)
        n, err := pipeConn.Write(data)
        if err != nil {
            p.metrics.RecordError(ctx, "tcp_write_error")
            log.Errorf("TCP write error: %v", err)
            p.metrics.RecordError(ctx, "pipe_write_error")
            log.Warnf("Pipe write error for %s: %v", clientAddr, err)
            return
        }

        p.metrics.RecordBytesTransferred(ctx, "ws_to_tcp", int64(n))
        p.metrics.RecordBytesTransferred(ctx, "ws_to_grpc", int64(n))
    }
}

func (p *Proxy) tcpToWS(ctx context.Context, cancel context.CancelFunc, wg *sync.WaitGroup, wsConn *websocket.Conn, tcpConn net.Conn) {
func (p *Proxy) pipeToWS(ctx context.Context, cancel context.CancelFunc, wg *sync.WaitGroup, wsConn *websocket.Conn, pipeConn net.Conn, clientAddr string) {
    defer wg.Done()
    defer cancel()

    buf := make([]byte, bufferSize)
    for {
        if err := tcpConn.SetReadDeadline(time.Now().Add(5 * time.Second)); err != nil {
            log.Debugf("Failed to set TCP read deadline: %v", err)
        }
        n, err := tcpConn.Read(buf)

        n, err := pipeConn.Read(buf)
        if err != nil {
            if ctx.Err() != nil {
                log.Tracef("tcpToWS goroutine terminating due to context cancellation")
                log.Tracef("pipeToWS goroutine terminating due to context cancellation")
                return
            }

            var netErr net.Error
            if errors.As(err, &netErr) && netErr.Timeout() {
                continue
            }

            if err != io.EOF {
                log.Errorf("TCP read error: %v", err)
                log.Debugf("Pipe read error for %s: %v", clientAddr, err)
            }
            return
        }

        if ctx.Err() != nil {
            log.Tracef("tcpToWS goroutine terminating due to context cancellation before WebSocket write")
            log.Tracef("pipeToWS goroutine terminating due to context cancellation before WebSocket write")
            return
        }

        if err := wsConn.Write(ctx, websocket.MessageBinary, buf[:n]); err != nil {
            p.metrics.RecordError(ctx, "websocket_write_error")
            log.Errorf("WebSocket write error: %v", err)
            return
        }
        if n > 0 {
            if err := wsConn.Write(ctx, websocket.MessageBinary, buf[:n]); err != nil {
                p.metrics.RecordError(ctx, "websocket_write_error")
                log.Warnf("WebSocket write error for %s: %v", clientAddr, err)
                return
            }

            p.metrics.RecordBytesTransferred(ctx, "tcp_to_ws", int64(n))
            p.metrics.RecordBytesTransferred(ctx, "grpc_to_ws", int64(n))
        }
    }
}

@@ -1,9 +1,13 @@
package version

import "golang.org/x/sys/windows/registry"
import (
    "golang.org/x/sys/windows/registry"
    "runtime"
)

const (
    urlWinExe    = "https://pkgs.netbird.io/windows/x64"
    urlWinExeArm = "https://pkgs.netbird.io/windows/arm64"
)

var regKeyAppPath = "SOFTWARE\\Microsoft\\Windows\\CurrentVersion\\App Paths\\Netbird"

@@ -11,9 +15,14 @@ var regKeyAppPath = "SOFTWARE\\Microsoft\\Windows\\CurrentVersion\\App Paths\\Ne
// DownloadUrl return with the proper download link
func DownloadUrl() string {
    _, err := registry.OpenKey(registry.LOCAL_MACHINE, regKeyAppPath, registry.QUERY_VALUE)
    if err == nil {
        return urlWinExe
    } else {
    if err != nil {
        return downloadURL
    }

    url := urlWinExe
    if runtime.GOARCH == "arm64" {
        url = urlWinExeArm
    }

    return url
}