Compare commits

..

16 Commits

Author SHA1 Message Date
crn4
b8d9386466 sync limit fix 2025-10-19 17:51:58 +02:00
crn4
9914212ce5 simple balancing 2025-10-19 17:51:55 +02:00
Misha Bragin
cd9a867ad0 [client] Delete TURNConfig section from script (#4639) 2025-10-17 19:48:26 +02:00
Maycon Santos
0f9bfeff7c [client] Security upgrade alpine from 3.22.0 to 3.22.2 #4618 2025-10-17 19:47:11 +02:00
Viktor Liu
f5301230bf [client] Fix status showing P2P without connection (#4661) 2025-10-17 13:31:15 +02:00
Viktor Liu
429d7d6585 [client] Support BROWSER env for login (#4654) 2025-10-17 11:10:16 +02:00
Viktor Liu
3cdb10cde7 [client] Remove rule squashing (#4653) 2025-10-17 11:09:39 +02:00
Zoltan Papp
af95aabb03 Handle the case when the service has already been down and the status recorder is not available (#4652) 2025-10-16 17:15:39 +02:00
Viktor Liu
3abae0bd17 [client] Set default wg port for new profiles (#4651) 2025-10-16 16:16:51 +02:00
Viktor Liu
8252ff41db [client] Add bind activity listener to bypass udp sockets (#4646) 2025-10-16 15:58:29 +02:00
Viktor Liu
277aa2b7cc [client] Fix missing flag values in profiles (#4650) 2025-10-16 15:13:41 +02:00
John Conley
bb37dc89ce [management] feat: Basic PocketID IDP integration (#4529) 2025-10-16 10:46:29 +02:00
Viktor Liu
000e99e7f3 [client] Force TLS1.2 for RDP with Win11/Server2025 for CredSSP compatibility (#4617) 2025-10-13 17:50:16 +02:00
Maycon Santos
0d2e67983a [misc] Add service definition for netbird-signal (#4620) 2025-10-10 19:16:48 +02:00
Pascal Fischer
5151f19d29 [management] pass temporary flag to validator (#4599) 2025-10-10 16:15:51 +02:00
Kostya Leschenko
bedd3cabc9 [client] Explicitly disable DNSOverTLS for systemd-resolved (#4579) 2025-10-10 15:24:24 +02:00
56 changed files with 1918 additions and 2689 deletions

View File

@@ -4,7 +4,7 @@
# sudo podman build -t localhost/netbird:latest -f client/Dockerfile --ignorefile .dockerignore-client .
# sudo podman run --rm -it --cap-add={BPF,NET_ADMIN,NET_RAW} localhost/netbird:latest
FROM alpine:3.22.0
FROM alpine:3.22.2
# iproute2: busybox doesn't display ip rules properly
RUN apk add --no-cache \
bash \

View File

@@ -4,6 +4,7 @@ import (
"context"
"fmt"
"os"
"os/exec"
"os/user"
"runtime"
"strings"
@@ -356,13 +357,21 @@ func openURL(cmd *cobra.Command, verificationURIComplete, userCode string, noBro
cmd.Println("")
if !noBrowser {
if err := open.Run(verificationURIComplete); err != nil {
if err := openBrowser(verificationURIComplete); err != nil {
cmd.Println("\nAlternatively, you may want to use a setup key, see:\n\n" +
"https://docs.netbird.io/how-to/register-machines-using-setup-keys")
}
}
}
// openBrowser opens the given URL in a browser. If the BROWSER environment
// variable is set, its value is launched with the URL as the only argument;
// otherwise the platform default opener (open.Run) is used.
func openBrowser(url string) error {
	browser := os.Getenv("BROWSER")
	if browser == "" {
		return open.Run(url)
	}
	return exec.Command(browser, url).Start()
}
// isUnixRunningDesktop checks if a Linux OS is running desktop environment
func isUnixRunningDesktop() bool {
if runtime.GOOS != "linux" && runtime.GOOS != "freebsd" {

View File

@@ -400,7 +400,6 @@ func transformIPsetName(ipsetName string, sPort, dPort *firewall.Port, action fi
return ""
}
// Include action in the ipset name to prevent squashing rules with different actions
actionSuffix := ""
if action == firewall.ActionDrop {
actionSuffix = "-drop"

View File

@@ -23,4 +23,5 @@ type WGTunDevice interface {
FilteredDevice() *device.FilteredDevice
Device() *wgdevice.Device
GetNet() *netstack.Net
GetICEBind() device.EndpointManager
}

View File

@@ -150,6 +150,11 @@ func (t *WGTunDevice) GetNet() *netstack.Net {
return nil
}
// GetICEBind returns the device's ICE bind as an EndpointManager so callers
// can register fake-IP endpoints for userspace packet routing.
func (t *WGTunDevice) GetICEBind() EndpointManager {
	return t.iceBind
}
func routesToString(routes []string) string {
return strings.Join(routes, ";")
}

View File

@@ -154,3 +154,8 @@ func (t *TunDevice) assignAddr() error {
func (t *TunDevice) GetNet() *netstack.Net {
return nil
}
// GetICEBind returns the device's ICE bind as an EndpointManager so callers
// can register fake-IP endpoints for userspace packet routing.
func (t *TunDevice) GetICEBind() EndpointManager {
	return t.iceBind
}

View File

@@ -144,3 +144,8 @@ func (t *TunDevice) FilteredDevice() *FilteredDevice {
func (t *TunDevice) GetNet() *netstack.Net {
return nil
}
// GetICEBind returns the device's ICE bind as an EndpointManager so callers
// can register fake-IP endpoints for userspace packet routing.
func (t *TunDevice) GetICEBind() EndpointManager {
	return t.iceBind
}

View File

@@ -179,3 +179,8 @@ func (t *TunKernelDevice) assignAddr() error {
func (t *TunKernelDevice) GetNet() *netstack.Net {
return nil
}
// GetICEBind returns nil: kernel-mode devices have no userspace bind, so
// there is no EndpointManager to expose. Callers must handle the nil result.
func (t *TunKernelDevice) GetICEBind() EndpointManager {
	return nil
}

View File

@@ -21,6 +21,7 @@ type Bind interface {
conn.Bind
GetICEMux() (*udpmux.UniversalUDPMuxDefault, error)
ActivityRecorder() *bind.ActivityRecorder
EndpointManager
}
type TunNetstackDevice struct {
@@ -155,3 +156,8 @@ func (t *TunNetstackDevice) Device() *device.Device {
func (t *TunNetstackDevice) GetNet() *netstack.Net {
return t.net
}
// GetICEBind returns the netstack device's bind as an EndpointManager so
// callers can register fake-IP endpoints for userspace packet routing.
func (t *TunNetstackDevice) GetICEBind() EndpointManager {
	return t.bind
}

View File

@@ -146,3 +146,8 @@ func (t *USPDevice) assignAddr() error {
func (t *USPDevice) GetNet() *netstack.Net {
return nil
}
// GetICEBind returns the device's ICE bind as an EndpointManager so callers
// can register fake-IP endpoints for userspace packet routing.
func (t *USPDevice) GetICEBind() EndpointManager {
	return t.iceBind
}

View File

@@ -185,3 +185,8 @@ func (t *TunDevice) assignAddr() error {
func (t *TunDevice) GetNet() *netstack.Net {
return nil
}
// GetICEBind returns the device's ICE bind as an EndpointManager so callers
// can register fake-IP endpoints for userspace packet routing.
func (t *TunDevice) GetICEBind() EndpointManager {
	return t.iceBind
}

View File

@@ -0,0 +1,13 @@
package device

import (
	"net"
	"net/netip"
)

// EndpointManager manages fake IP to connection mappings for userspace bind implementations.
// Implemented by bind.ICEBind and bind.RelayBindJS.
type EndpointManager interface {
	// SetEndpoint registers conn as the destination for packets routed to fakeIP.
	SetEndpoint(fakeIP netip.Addr, conn net.Conn)
	// RemoveEndpoint removes the mapping previously registered for fakeIP.
	RemoveEndpoint(fakeIP netip.Addr)
}

View File

@@ -21,4 +21,5 @@ type WGTunDevice interface {
FilteredDevice() *device.FilteredDevice
Device() *wgdevice.Device
GetNet() *netstack.Net
GetICEBind() device.EndpointManager
}

View File

@@ -80,6 +80,17 @@ func (w *WGIface) GetProxy() wgproxy.Proxy {
return w.wgProxyFactory.GetProxy()
}
// GetBind returns the tun device's EndpointManager for userspace bind mode.
// It returns nil when no tun device is configured; it may also return nil for
// devices without a userspace bind (e.g. kernel mode) — callers must check.
func (w *WGIface) GetBind() device.EndpointManager {
	w.mu.Lock()
	defer w.mu.Unlock()

	if w.tun == nil {
		return nil
	}
	return w.tun.GetICEBind()
}
// IsUserspaceBind indicates whether this interface is a userspace interface using bind.ICEBind
func (w *WGIface) IsUserspaceBind() bool {
return w.userspaceBind

View File

@@ -29,11 +29,6 @@ type Manager interface {
ApplyFiltering(networkMap *mgmProto.NetworkMap, dnsRouteFeatureFlag bool)
}
// protoMatch accumulates, for a single protocol, the peer IPs covered by
// firewall rules (each mapped to the rule's index in the original list)
// together with the first PolicyID encountered for that protocol.
type protoMatch struct {
	ips      map[string]int
	policyID []byte
}
// DefaultManager uses firewall manager to handle
type DefaultManager struct {
firewall firewall.Manager
@@ -86,21 +81,14 @@ func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap, dnsRout
}
func (d *DefaultManager) applyPeerACLs(networkMap *mgmProto.NetworkMap) {
rules, squashedProtocols := d.squashAcceptRules(networkMap)
rules := networkMap.FirewallRules
enableSSH := networkMap.PeerConfig != nil &&
networkMap.PeerConfig.SshConfig != nil &&
networkMap.PeerConfig.SshConfig.SshEnabled
if _, ok := squashedProtocols[mgmProto.RuleProtocol_ALL]; ok {
enableSSH = enableSSH && !ok
}
if _, ok := squashedProtocols[mgmProto.RuleProtocol_TCP]; ok {
enableSSH = enableSSH && !ok
}
// if TCP protocol rules not squashed and SSH enabled
// we add default firewall rule which accepts connection to any peer
// in the network by SSH (TCP 22 port).
// If SSH enabled, add default firewall rule which accepts connection to any peer
// in the network by SSH (TCP port defined by ssh.DefaultSSHPort).
if enableSSH {
rules = append(rules, &mgmProto.FirewallRule{
PeerIP: "0.0.0.0",
@@ -368,145 +356,6 @@ func (d *DefaultManager) getPeerRuleID(
return id.RuleID(hex.EncodeToString(md5.New().Sum([]byte(idStr))))
}
// squashAcceptRules converts many per-peer accept rules for a given traffic
// type into a single rule that accepts that traffic from any peer (0.0.0.0),
// when the per-peer rules cover every peer in the network map.
//
// It returns the resulting rule list and the set of protocols that were
// squashed.
//
// NOTE: it will not squash rules for a protocol if any rule for that protocol
// has port definitions or a drop policy, even if the rules cover all peers.
func (d *DefaultManager) squashAcceptRules(
	networkMap *mgmProto.NetworkMap,
) ([]*mgmProto.FirewallRule, map[mgmProto.RuleProtocol]struct{}) {
	// total number of peer IPs in the network map; a protocol is only
	// squashable if its accept rules cover all of them
	totalIPs := 0
	for _, p := range networkMap.RemotePeers {
		totalIPs += len(p.AllowedIps)
	}
	for _, p := range networkMap.OfflinePeers {
		totalIPs += len(p.AllowedIps)
	}

	in := map[mgmProto.RuleProtocol]*protoMatch{}
	out := map[mgmProto.RuleProtocol]*protoMatch{}

	// track which protocols were squashed and the resulting 0.0.0.0 rules
	squashedRules := []*mgmProto.FirewallRule{}
	squashedProtocols := map[mgmProto.RuleProtocol]struct{}{}

	// addRuleToCalculationMap accumulates, per protocol, the set of peer IPs
	// seen in the original rules. A protocol's IP set is reset to empty
	// (marking it unsquashable) when any of its rules:
	//  1. has a DROP action, or
	//  2. defines a port (legacy Port field or PortInfo).
	addRuleToCalculationMap := func(i int, r *mgmProto.FirewallRule, protocols map[mgmProto.RuleProtocol]*protoMatch) {
		cannotSquash := r.Action == mgmProto.RuleAction_DROP ||
			r.Port != "" || !portInfoEmpty(r.PortInfo)
		if cannotSquash {
			// clearing the IP set means len(ips) can never reach totalIPs,
			// so this protocol will not be squashed
			protocols[r.Protocol] = &protoMatch{ips: map[string]int{}}
			return
		}

		if _, ok := protocols[r.Protocol]; !ok {
			protocols[r.Protocol] = &protoMatch{
				ips: map[string]int{},
				// store the first encountered PolicyID for this protocol
				policyID: r.PolicyID,
			}
		}

		// special case: an all-network address means the rules for this
		// protocol were already optimized on the management side
		if r.PeerIP == "0.0.0.0" {
			squashedRules = append(squashedRules, r)
			squashedProtocols[r.Protocol] = struct{}{}
			return
		}

		ipset := protocols[r.Protocol].ips
		if _, ok := ipset[r.PeerIP]; ok {
			return
		}
		ipset[r.PeerIP] = i
	}

	for i, r := range networkMap.FirewallRules {
		// accumulate squash candidates separately per direction
		if r.Direction == mgmProto.RuleDirection_IN {
			addRuleToCalculationMap(i, r, in)
		} else {
			addRuleToCalculationMap(i, r, out)
		}
	}

	// the processing order matters only for the first element:
	// ALL must be handled before the per-protocol entries
	protocolOrders := []mgmProto.RuleProtocol{
		mgmProto.RuleProtocol_ALL,
		mgmProto.RuleProtocol_ICMP,
		mgmProto.RuleProtocol_TCP,
		mgmProto.RuleProtocol_UDP,
	}

	squash := func(matches map[mgmProto.RuleProtocol]*protoMatch, direction mgmProto.RuleDirection) {
		for _, protocol := range protocolOrders {
			match, ok := matches[protocol]
			if !ok || len(match.ips) != totalIPs || len(match.ips) < 2 {
				// don't squash if:
				// 1. the rules don't cover all peers in the network, or
				// 2. the rules cover only one peer in the network
				continue
			}

			// add the special 0.0.0.0 rule which allows all IPs in our firewall implementations
			squashedRules = append(squashedRules, &mgmProto.FirewallRule{
				PeerIP:    "0.0.0.0",
				Direction: direction,
				Action:    mgmProto.RuleAction_ACCEPT,
				Protocol:  protocol,
				PolicyID:  match.policyID,
			})
			squashedProtocols[protocol] = struct{}{}

			if protocol == mgmProto.RuleProtocol_ALL {
				// a squashed ALL rule covers every other traffic type,
				// so we can stop processing
				break
			}
		}
	}

	squash(in, mgmProto.RuleDirection_IN)
	squash(out, mgmProto.RuleDirection_OUT)

	// if the ALL protocol was squashed, everything is allowed and all other rules can be ignored
	if _, ok := squashedProtocols[mgmProto.RuleProtocol_ALL]; ok {
		return squashedRules, squashedProtocols
	}

	if len(squashedRules) == 0 {
		return networkMap.FirewallRules, squashedProtocols
	}

	var rules []*mgmProto.FirewallRule
	// filter the individual rules that were replaced by a squashed rule
	// out of the final list, keeping all non-squashed rules
	for i, r := range networkMap.FirewallRules {
		if _, ok := squashedProtocols[r.Protocol]; ok {
			if m, ok := in[r.Protocol]; ok && m.ips[r.PeerIP] == i {
				continue
			} else if m, ok := out[r.Protocol]; ok && m.ips[r.PeerIP] == i {
				continue
			}
		}
		rules = append(rules, r)
	}

	return append(rules, squashedRules...), squashedProtocols
}
// getRuleGroupingSelector takes all rule properties except IP address to build selector
func (d *DefaultManager) getRuleGroupingSelector(rule *mgmProto.FirewallRule) string {
return fmt.Sprintf("%v:%v:%v:%s:%v", strconv.Itoa(int(rule.Direction)), rule.Action, rule.Protocol, rule.Port, rule.PortInfo)

View File

@@ -188,492 +188,6 @@ func TestDefaultManagerStateless(t *testing.T) {
})
}
// TestDefaultManagerSquashRules verifies that per-peer ALL-protocol accept
// rules covering every peer collapse into one 0.0.0.0 rule per direction.
func TestDefaultManagerSquashRules(t *testing.T) {
	peerIPs := []string{"10.93.0.1", "10.93.0.2", "10.93.0.3", "10.93.0.4"}
	directions := []mgmProto.RuleDirection{mgmProto.RuleDirection_IN, mgmProto.RuleDirection_OUT}

	remotePeers := make([]*mgmProto.RemotePeerConfig, 0, len(peerIPs))
	for _, ip := range peerIPs {
		remotePeers = append(remotePeers, &mgmProto.RemotePeerConfig{AllowedIps: []string{ip}})
	}

	// one ALL/ACCEPT rule per peer per direction: IN rules first, then OUT
	firewallRules := make([]*mgmProto.FirewallRule, 0, len(peerIPs)*len(directions))
	for _, dir := range directions {
		for _, ip := range peerIPs {
			firewallRules = append(firewallRules, &mgmProto.FirewallRule{
				PeerIP:    ip,
				Direction: dir,
				Action:    mgmProto.RuleAction_ACCEPT,
				Protocol:  mgmProto.RuleProtocol_ALL,
			})
		}
	}

	networkMap := &mgmProto.NetworkMap{
		RemotePeers:   remotePeers,
		FirewallRules: firewallRules,
	}

	manager := &DefaultManager{}
	rules, _ := manager.squashAcceptRules(networkMap)
	assert.Equal(t, 2, len(rules))

	// one squashed 0.0.0.0 rule per direction, IN first
	for i, dir := range directions {
		r := rules[i]
		assert.Equal(t, "0.0.0.0", r.PeerIP)
		assert.Equal(t, dir, r.Direction)
		assert.Equal(t, mgmProto.RuleProtocol_ALL, r.Protocol)
		assert.Equal(t, mgmProto.RuleAction_ACCEPT, r.Action)
	}
}
// TestDefaultManagerSquashRulesNoAffect verifies that rules are left untouched
// when no single protocol's accept rules cover every peer in a direction.
func TestDefaultManagerSquashRulesNoAffect(t *testing.T) {
	peerIPs := []string{"10.93.0.1", "10.93.0.2", "10.93.0.3", "10.93.0.4"}

	remotePeers := make([]*mgmProto.RemotePeerConfig, 0, len(peerIPs))
	for _, ip := range peerIPs {
		remotePeers = append(remotePeers, &mgmProto.RemotePeerConfig{AllowedIps: []string{ip}})
	}

	// build builds ACCEPT rules for every peer in one direction; the last
	// peer gets lastProto instead of ALL, so the ALL rules never cover all
	// peers and nothing can be squashed.
	build := func(dir mgmProto.RuleDirection, lastProto mgmProto.RuleProtocol) []*mgmProto.FirewallRule {
		rules := make([]*mgmProto.FirewallRule, 0, len(peerIPs))
		for i, ip := range peerIPs {
			proto := mgmProto.RuleProtocol_ALL
			if i == len(peerIPs)-1 {
				proto = lastProto
			}
			rules = append(rules, &mgmProto.FirewallRule{
				PeerIP:    ip,
				Direction: dir,
				Action:    mgmProto.RuleAction_ACCEPT,
				Protocol:  proto,
			})
		}
		return rules
	}

	firewallRules := append(
		build(mgmProto.RuleDirection_IN, mgmProto.RuleProtocol_TCP),
		build(mgmProto.RuleDirection_OUT, mgmProto.RuleProtocol_UDP)...,
	)

	networkMap := &mgmProto.NetworkMap{
		RemotePeers:   remotePeers,
		FirewallRules: firewallRules,
	}

	manager := &DefaultManager{}
	rules, _ := manager.squashAcceptRules(networkMap)
	assert.Equal(t, len(networkMap.FirewallRules), len(rules))
}
// TestDefaultManagerSquashRulesWithPortRestrictions verifies that rules with
// port definitions or DROP actions are never squashed, while plain accept
// rules covering all peers still collapse into a single 0.0.0.0 rule.
func TestDefaultManagerSquashRulesWithPortRestrictions(t *testing.T) {
	peerIPs := []string{"10.93.0.1", "10.93.0.2", "10.93.0.3", "10.93.0.4"}

	// rule builds an inbound ACCEPT rule for one peer and protocol.
	rule := func(ip string, proto mgmProto.RuleProtocol) *mgmProto.FirewallRule {
		return &mgmProto.FirewallRule{
			PeerIP:    ip,
			Direction: mgmProto.RuleDirection_IN,
			Action:    mgmProto.RuleAction_ACCEPT,
			Protocol:  proto,
		}
	}

	// rulesForAll builds one rule per peer, letting mutate tweak each one.
	rulesForAll := func(proto mgmProto.RuleProtocol, mutate func(*mgmProto.FirewallRule)) []*mgmProto.FirewallRule {
		out := make([]*mgmProto.FirewallRule, 0, len(peerIPs))
		for _, ip := range peerIPs {
			r := rule(ip, proto)
			if mutate != nil {
				mutate(r)
			}
			out = append(out, r)
		}
		return out
	}

	withRange := func(r *mgmProto.FirewallRule) {
		r.PortInfo = &mgmProto.PortInfo{
			PortSelection: &mgmProto.PortInfo_Range_{
				Range: &mgmProto.PortInfo_Range{
					Start: 8080,
					End:   8090,
				},
			},
		}
	}
	withPort80 := func(r *mgmProto.FirewallRule) {
		r.PortInfo = &mgmProto.PortInfo{
			PortSelection: &mgmProto.PortInfo_Port{
				Port: 80,
			},
		}
	}
	withLegacyPort443 := func(r *mgmProto.FirewallRule) {
		r.Port = "443"
	}
	withDrop := func(r *mgmProto.FirewallRule) {
		r.Action = mgmProto.RuleAction_DROP
	}

	// TCP rules for all peers where only the second peer has a port restriction
	mixed := rulesForAll(mgmProto.RuleProtocol_TCP, nil)
	withPort80(mixed[1])

	tests := []struct {
		name          string
		rules         []*mgmProto.FirewallRule
		expectedCount int
		description   string
	}{
		{
			name:          "should not squash rules with port ranges",
			rules:         rulesForAll(mgmProto.RuleProtocol_TCP, withRange),
			expectedCount: 4,
			description:   "Rules with port ranges should not be squashed even if they cover all peers",
		},
		{
			name:          "should not squash rules with specific ports",
			rules:         rulesForAll(mgmProto.RuleProtocol_TCP, withPort80),
			expectedCount: 4,
			description:   "Rules with specific ports should not be squashed even if they cover all peers",
		},
		{
			name:          "should not squash rules with legacy port field",
			rules:         rulesForAll(mgmProto.RuleProtocol_TCP, withLegacyPort443),
			expectedCount: 4,
			description:   "Rules with legacy port field should not be squashed",
		},
		{
			name:          "should not squash rules with DROP action",
			rules:         rulesForAll(mgmProto.RuleProtocol_TCP, withDrop),
			expectedCount: 4,
			description:   "Rules with DROP action should not be squashed",
		},
		{
			name:          "should squash rules without port restrictions",
			rules:         rulesForAll(mgmProto.RuleProtocol_TCP, nil),
			expectedCount: 1,
			description:   "Rules without port restrictions should be squashed into a single 0.0.0.0 rule",
		},
		{
			name:          "mixed rules should not squash protocol with port restrictions",
			rules:         mixed,
			expectedCount: 4,
			description:   "TCP should not be squashed because one rule has port restrictions",
		},
		{
			name: "should squash UDP but not TCP when TCP has port restrictions",
			rules: append(
				rulesForAll(mgmProto.RuleProtocol_TCP, withLegacyPort443),
				rulesForAll(mgmProto.RuleProtocol_UDP, nil)...,
			),
			expectedCount: 5, // 4 TCP rules + 1 squashed UDP rule (0.0.0.0)
			description:   "UDP should be squashed to 0.0.0.0 rule, but TCP should remain as individual rules due to port restrictions",
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			remotePeers := make([]*mgmProto.RemotePeerConfig, 0, len(peerIPs))
			for _, ip := range peerIPs {
				remotePeers = append(remotePeers, &mgmProto.RemotePeerConfig{AllowedIps: []string{ip}})
			}
			networkMap := &mgmProto.NetworkMap{
				RemotePeers:   remotePeers,
				FirewallRules: tt.rules,
			}

			manager := &DefaultManager{}
			rules, _ := manager.squashAcceptRules(networkMap)
			assert.Equal(t, tt.expectedCount, len(rules), tt.description)

			// For squashed rules, verify we get the expected 0.0.0.0 rule
			if tt.expectedCount == 1 {
				assert.Equal(t, "0.0.0.0", rules[0].PeerIP)
				assert.Equal(t, mgmProto.RuleDirection_IN, rules[0].Direction)
				assert.Equal(t, mgmProto.RuleAction_ACCEPT, rules[0].Action)
			}
		})
	}
}
func TestPortInfoEmpty(t *testing.T) {
tests := []struct {
name string

View File

@@ -14,6 +14,9 @@ type WGIface interface {
}
func (g *BundleGenerator) addWgShow() error {
if g.statusRecorder == nil {
return fmt.Errorf("no status recorder available for wg show")
}
result, err := g.statusRecorder.PeersStatus()
if err != nil {
return err

View File

@@ -31,6 +31,7 @@ const (
systemdDbusSetDefaultRouteMethodSuffix = systemdDbusLinkInterface + ".SetDefaultRoute"
systemdDbusSetDomainsMethodSuffix = systemdDbusLinkInterface + ".SetDomains"
systemdDbusSetDNSSECMethodSuffix = systemdDbusLinkInterface + ".SetDNSSEC"
systemdDbusSetDNSOverTLSMethodSuffix = systemdDbusLinkInterface + ".SetDNSOverTLS"
systemdDbusResolvConfModeForeign = "foreign"
dbusErrorUnknownObject = "org.freedesktop.DBus.Error.UnknownObject"
@@ -102,6 +103,11 @@ func (s *systemdDbusConfigurator) applyDNSConfig(config HostDNSConfig, stateMana
log.Warnf("failed to set DNSSEC to 'no': %v", err)
}
// We don't support DNSOverTLS. On some machines this is default on so we explicitly set it to off
if err := s.callLinkMethod(systemdDbusSetDNSOverTLSMethodSuffix, dnsSecDisabled); err != nil {
log.Warnf("failed to set DNSOverTLS to 'no': %v", err)
}
var (
searchDomains []string
matchDomains []string

View File

@@ -0,0 +1,82 @@
package activity
import (
"context"
"io"
"net"
"time"
)
// lazyConn detects activity when WireGuard attempts to send packets.
// It does not deliver packets, only signals that activity occurred.
type lazyConn struct {
	activityCh chan struct{}      // buffered (cap 1): at most one pending activity signal
	ctx        context.Context    // done when the conn is closed
	cancel     context.CancelFunc // cancelling ctx marks the conn as closed
}

// newLazyConn creates a new lazyConn for activity detection.
func newLazyConn() *lazyConn {
	ctx, cancel := context.WithCancel(context.Background())
	return &lazyConn{
		activityCh: make(chan struct{}, 1),
		ctx:        ctx,
		cancel:     cancel,
	}
}

// Read blocks until the connection is closed, then reports io.EOF.
// It never delivers data.
func (c *lazyConn) Read(_ []byte) (n int, err error) {
	<-c.ctx.Done()
	return 0, io.EOF
}

// Write signals activity detection when ICEBind routes packets to this endpoint.
// The packet itself is dropped; the write is reported as successful so the
// caller does not retry. If a signal is already pending it is not duplicated.
func (c *lazyConn) Write(b []byte) (n int, err error) {
	if c.ctx.Err() != nil {
		return 0, io.EOF
	}

	// non-blocking send: keep at most one pending signal
	select {
	case c.activityCh <- struct{}{}:
	default:
	}
	return len(b), nil
}

// ActivityChan returns the channel that signals when activity is detected.
func (c *lazyConn) ActivityChan() <-chan struct{} {
	return c.activityCh
}

// Close closes the connection and unblocks any pending Read.
func (c *lazyConn) Close() error {
	c.cancel()
	return nil
}

// LocalAddr returns a placeholder loopback address; the conn is not backed by
// a real socket.
func (c *lazyConn) LocalAddr() net.Addr {
	return &net.UDPAddr{IP: net.IP{127, 0, 0, 1}, Port: lazyBindPort}
}

// RemoteAddr returns a placeholder loopback address; the conn is not backed by
// a real socket.
func (c *lazyConn) RemoteAddr() net.Addr {
	return &net.UDPAddr{IP: net.IP{127, 0, 0, 1}, Port: lazyBindPort}
}

// SetDeadline is a no-op; deadlines are not supported.
func (c *lazyConn) SetDeadline(_ time.Time) error {
	return nil
}

// SetReadDeadline is a no-op; deadlines are not supported.
func (c *lazyConn) SetReadDeadline(_ time.Time) error {
	return nil
}

// SetWriteDeadline is a no-op; deadlines are not supported.
func (c *lazyConn) SetWriteDeadline(_ time.Time) error {
	return nil
}

View File

@@ -0,0 +1,127 @@
package activity
import (
"fmt"
"net"
"net/netip"
"sync"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/internal/lazyconn"
)
// bindProvider abstracts access to the underlying userspace bind's
// EndpointManager (satisfied by WGIface.GetBind).
// NOTE(review): no caller is visible in this chunk — confirm where it is used.
type bindProvider interface {
	GetBind() device.EndpointManager
}
const (
	// lazyBindPort is an obscure port used for lazy peer endpoints to avoid
	// confusion with real peers. The actual routing is done via the fake IP
	// in ICEBind, not by this port.
	lazyBindPort = 17473
)
// BindListener uses lazyConn with bind implementations for direct data passing in userspace bind mode.
type BindListener struct {
	wgIface WgInterface
	peerCfg lazyconn.PeerConfig
	done    sync.WaitGroup // released when ReadPackets finishes its cleanup

	lazyConn *lazyConn              // fake endpoint whose Write calls signal peer activity
	bind     device.EndpointManager // bind that routes fakeIP traffic to lazyConn
	fakeIP   netip.Addr             // unique 127.2.x.y address derived from the peer's NetBird IP
}
// NewBindListener creates a listener that passes data directly through bind using LazyConn.
// It automatically derives a unique fake IP from the peer's NetBird IP in the 127.2.x.x range.
func NewBindListener(wgIface WgInterface, bind device.EndpointManager, cfg lazyconn.PeerConfig) (*BindListener, error) {
	fakeIP, err := deriveFakeIP(wgIface, cfg.AllowedIPs)
	if err != nil {
		return nil, fmt.Errorf("derive fake IP: %w", err)
	}

	d := &BindListener{
		wgIface: wgIface,
		peerCfg: cfg,
		bind:    bind,
		fakeIP:  fakeIP,
	}

	// wrap with %w (not %v) so callers can unwrap the underlying error,
	// consistent with the deriveFakeIP error above
	if err := d.setupLazyConn(); err != nil {
		return nil, fmt.Errorf("setup lazy connection: %w", err)
	}

	// one outstanding ReadPackets call; Close waits for it to finish
	d.done.Add(1)
	return d, nil
}
// deriveFakeIP creates a deterministic fake IP for bind mode based on the
// peer's NetBird IP: peer 100.64.x.y maps to 127.2.x.y (the relay proxy uses
// 127.1.x.y in the same way). The peer's actual NetBird IP is the first IPv4
// allowed IP that falls inside our WG interface's network.
func deriveFakeIP(wgIface WgInterface, allowedIPs []netip.Prefix) (netip.Addr, error) {
	if len(allowedIPs) == 0 {
		return netip.Addr{}, fmt.Errorf("no allowed IPs for peer")
	}

	ourNetwork := wgIface.Address().Network
	for _, prefix := range allowedIPs {
		candidate := prefix.Addr()
		if !candidate.Is4() || !ourNetwork.Contains(candidate) {
			continue
		}
		b := candidate.As4()
		return netip.AddrFrom4([4]byte{127, 2, b[2], b[3]}), nil
	}

	return netip.Addr{}, fmt.Errorf("no peer NetBird IP found in allowed IPs")
}
// setupLazyConn registers a lazyConn endpoint for the fake IP in the bind and
// points the WireGuard peer at it, so outbound traffic to the peer triggers
// activity detection.
func (d *BindListener) setupLazyConn() error {
	d.lazyConn = newLazyConn()
	d.bind.SetEndpoint(d.fakeIP, d.lazyConn)

	endpoint := &net.UDPAddr{
		IP:   d.fakeIP.AsSlice(),
		Port: lazyBindPort,
	}
	if err := d.wgIface.UpdatePeer(d.peerCfg.PublicKey, d.peerCfg.AllowedIPs, 0, endpoint, nil); err != nil {
		// roll back the endpoint registration so a failed setup does not
		// leak the fakeIP mapping in the bind
		d.bind.RemoveEndpoint(d.fakeIP)
		_ = d.lazyConn.Close()
		return err
	}
	return nil
}
// ReadPackets blocks until activity is detected on the LazyConn or the
// listener is closed, then removes the lazy peer and endpoint mapping.
func (d *BindListener) ReadPackets() {
	defer d.done.Done()

	select {
	case <-d.lazyConn.ActivityChan():
		d.peerCfg.Log.Infof("activity detected via LazyConn")
	case <-d.lazyConn.ctx.Done():
		d.peerCfg.Log.Infof("exit from activity listener")
	}

	d.peerCfg.Log.Debugf("removing lazy endpoint for peer %s", d.peerCfg.PublicKey)
	if err := d.wgIface.RemovePeer(d.peerCfg.PublicKey); err != nil {
		d.peerCfg.Log.Errorf("failed to remove endpoint: %s", err)
	}

	_ = d.lazyConn.Close()
	d.bind.RemoveEndpoint(d.fakeIP)
}
// Close stops the listener and cleans up resources.
// It blocks until the ReadPackets goroutine has finished its cleanup.
func (d *BindListener) Close() {
	d.peerCfg.Log.Infof("closing activity listener (LazyConn)")
	// closing the lazyConn cancels its context, which unblocks ReadPackets
	if err := d.lazyConn.Close(); err != nil {
		d.peerCfg.Log.Errorf("failed to close LazyConn: %s", err)
	}
	d.done.Wait()
}

View File

@@ -0,0 +1,291 @@
package activity
import (
"net"
"net/netip"
"runtime"
"testing"
"time"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/internal/lazyconn"
peerid "github.com/netbirdio/netbird/client/internal/peer/id"
)
func isBindListenerPlatform() bool {
return runtime.GOOS == "windows" || runtime.GOOS == "js"
}
// mockEndpointManager implements device.EndpointManager for testing
type mockEndpointManager struct {
endpoints map[netip.Addr]net.Conn
}
func newMockEndpointManager() *mockEndpointManager {
return &mockEndpointManager{
endpoints: make(map[netip.Addr]net.Conn),
}
}
func (m *mockEndpointManager) SetEndpoint(fakeIP netip.Addr, conn net.Conn) {
m.endpoints[fakeIP] = conn
}
func (m *mockEndpointManager) RemoveEndpoint(fakeIP netip.Addr) {
delete(m.endpoints, fakeIP)
}
func (m *mockEndpointManager) GetEndpoint(fakeIP netip.Addr) net.Conn {
return m.endpoints[fakeIP]
}
// MockWGIfaceBind mocks WgInterface with bind support
type MockWGIfaceBind struct {
	endpointMgr *mockEndpointManager
}

// RemovePeer is a no-op that always succeeds.
func (m *MockWGIfaceBind) RemovePeer(string) error {
	return nil
}

// UpdatePeer is a no-op that always succeeds.
func (m *MockWGIfaceBind) UpdatePeer(string, []netip.Prefix, time.Duration, *net.UDPAddr, *wgtypes.Key) error {
	return nil
}

// IsUserspaceBind always reports userspace bind mode for the mock.
func (m *MockWGIfaceBind) IsUserspaceBind() bool {
	return true
}

// Address returns a fixed NetBird address (100.64.0.1 in 100.64.0.0/16).
func (m *MockWGIfaceBind) Address() wgaddr.Address {
	return wgaddr.Address{
		IP:      netip.MustParseAddr("100.64.0.1"),
		Network: netip.MustParsePrefix("100.64.0.0/16"),
	}
}

// GetBind returns the mock endpoint manager.
func (m *MockWGIfaceBind) GetBind() device.EndpointManager {
	return m.endpointMgr
}
// TestBindListener_Creation checks that creating a BindListener registers a
// lazyConn endpoint in the bind and that ReadPackets exits after Close.
func TestBindListener_Creation(t *testing.T) {
	endpointMgr := newMockEndpointManager()
	wgIface := &MockWGIfaceBind{endpointMgr: endpointMgr}

	p := &MocPeer{PeerID: "testPeer1"}
	peerCfg := lazyconn.PeerConfig{
		PublicKey:  p.PeerID,
		PeerConnID: p.ConnID(),
		AllowedIPs: []netip.Prefix{netip.MustParsePrefix("100.64.0.2/32")},
		Log:        log.WithField("peer", "testPeer1"),
	}

	l, err := NewBindListener(wgIface, endpointMgr, peerCfg)
	require.NoError(t, err)

	// NOTE(review): 127.2.0.2 appears to be derived from the allowed IP
	// 100.64.0.2 — confirm against NewBindListener's fake-IP derivation.
	registered := endpointMgr.GetEndpoint(netip.MustParseAddr("127.2.0.2"))
	require.NotNil(t, registered, "Endpoint should be registered in mock endpoint manager")
	_, isLazy := registered.(*lazyConn)
	assert.True(t, isLazy, "Registered endpoint should be a lazyConn")

	done := make(chan struct{})
	go func() {
		defer close(done)
		l.ReadPackets()
	}()

	l.Close()

	select {
	case <-done:
	case <-time.After(2 * time.Second):
		t.Fatal("timeout waiting for ReadPackets to exit after Close")
	}
}
// TestBindListener_ActivityDetection checks that a write into the registered
// endpoint makes ReadPackets return and removes the endpoint from the bind.
func TestBindListener_ActivityDetection(t *testing.T) {
	endpointMgr := newMockEndpointManager()
	wgIface := &MockWGIfaceBind{endpointMgr: endpointMgr}

	p := &MocPeer{PeerID: "testPeer1"}
	peerCfg := lazyconn.PeerConfig{
		PublicKey:  p.PeerID,
		PeerConnID: p.ConnID(),
		AllowedIPs: []netip.Prefix{netip.MustParsePrefix("100.64.0.2/32")},
		Log:        log.WithField("peer", "testPeer1"),
	}

	l, err := NewBindListener(wgIface, endpointMgr, peerCfg)
	require.NoError(t, err)

	detected := make(chan struct{})
	go func() {
		defer close(detected)
		l.ReadPackets()
	}()

	endpoint := endpointMgr.GetEndpoint(l.fakeIP)
	require.NotNil(t, endpoint, "Endpoint should be registered")

	// Writing into the endpoint simulates inbound traffic from the peer.
	_, err = endpoint.Write([]byte{0x01, 0x02, 0x03})
	require.NoError(t, err)

	select {
	case <-detected:
	case <-time.After(2 * time.Second):
		t.Fatal("timeout waiting for activity detection")
	}

	assert.Nil(t, endpointMgr.GetEndpoint(l.fakeIP), "Endpoint should be removed after activity detection")
}
// TestBindListener_Close checks that Close unblocks ReadPackets and that the
// endpoint is removed from the bind afterwards.
func TestBindListener_Close(t *testing.T) {
	endpointMgr := newMockEndpointManager()
	wgIface := &MockWGIfaceBind{endpointMgr: endpointMgr}

	p := &MocPeer{PeerID: "testPeer1"}
	peerCfg := lazyconn.PeerConfig{
		PublicKey:  p.PeerID,
		PeerConnID: p.ConnID(),
		AllowedIPs: []netip.Prefix{netip.MustParsePrefix("100.64.0.2/32")},
		Log:        log.WithField("peer", "testPeer1"),
	}

	l, err := NewBindListener(wgIface, endpointMgr, peerCfg)
	require.NoError(t, err)

	done := make(chan struct{})
	go func() {
		defer close(done)
		l.ReadPackets()
	}()

	ip := l.fakeIP
	l.Close()

	select {
	case <-done:
	case <-time.After(2 * time.Second):
		t.Fatal("timeout waiting for ReadPackets to exit after Close")
	}

	assert.Nil(t, endpointMgr.GetEndpoint(ip), "Endpoint should be removed after Close")
}
// TestManager_BindMode checks the end-to-end flow through the Manager on
// bind platforms: monitor a peer, trigger traffic, receive the activity
// notification, and observe the endpoint being removed.
func TestManager_BindMode(t *testing.T) {
	if !isBindListenerPlatform() {
		t.Skip("BindListener only used on Windows/JS platforms")
	}

	endpointMgr := newMockEndpointManager()
	wgIface := &MockWGIfaceBind{endpointMgr: endpointMgr}
	p := &MocPeer{PeerID: "testPeer1"}

	mgr := NewManager(wgIface)
	defer mgr.Close()

	peerCfg := lazyconn.PeerConfig{
		PublicKey:  p.PeerID,
		PeerConnID: p.ConnID(),
		AllowedIPs: []netip.Prefix{netip.MustParsePrefix("100.64.0.2/32")},
		Log:        log.WithField("peer", "testPeer1"),
	}
	require.NoError(t, mgr.MonitorPeerActivity(peerCfg))

	l, exists := mgr.GetPeerListener(peerCfg.PeerConnID)
	require.True(t, exists, "Peer listener should be found")
	bl, ok := l.(*BindListener)
	require.True(t, ok, "Listener should be BindListener, got %T", l)

	endpoint := endpointMgr.GetEndpoint(bl.fakeIP)
	require.NotNil(t, endpoint, "Endpoint should be registered")

	// Simulate inbound peer traffic.
	_, err := endpoint.Write([]byte{0x01, 0x02, 0x03})
	require.NoError(t, err)

	select {
	case gotID := <-mgr.OnActivityChan:
		assert.Equal(t, peerCfg.PeerConnID, gotID, "Received peer connection ID should match")
	case <-time.After(2 * time.Second):
		t.Fatal("timeout waiting for activity notification")
	}

	assert.Nil(t, endpointMgr.GetEndpoint(bl.fakeIP), "Endpoint should be removed after activity")
}
// TestManager_BindMode_MultiplePeers verifies that two peers monitored at the
// same time each get their own bind endpoint, and that activity on both is
// reported on OnActivityChan regardless of delivery order.
func TestManager_BindMode_MultiplePeers(t *testing.T) {
	// The Manager only selects BindListener on Windows/JS (see createListener).
	if !isBindListenerPlatform() {
		t.Skip("BindListener only used on Windows/JS platforms")
	}
	mockEndpointMgr := newMockEndpointManager()
	mockIface := &MockWGIfaceBind{endpointMgr: mockEndpointMgr}
	peer1 := &MocPeer{PeerID: "testPeer1"}
	peer2 := &MocPeer{PeerID: "testPeer2"}
	mgr := NewManager(mockIface)
	defer mgr.Close()
	// Two distinct allowed IPs, one per peer.
	cfg1 := lazyconn.PeerConfig{
		PublicKey:  peer1.PeerID,
		PeerConnID: peer1.ConnID(),
		AllowedIPs: []netip.Prefix{netip.MustParsePrefix("100.64.0.2/32")},
		Log:        log.WithField("peer", "testPeer1"),
	}
	cfg2 := lazyconn.PeerConfig{
		PublicKey:  peer2.PeerID,
		PeerConnID: peer2.ConnID(),
		AllowedIPs: []netip.Prefix{netip.MustParsePrefix("100.64.0.3/32")},
		Log:        log.WithField("peer", "testPeer2"),
	}
	err := mgr.MonitorPeerActivity(cfg1)
	require.NoError(t, err)
	err = mgr.MonitorPeerActivity(cfg2)
	require.NoError(t, err)
	listener1, exists := mgr.GetPeerListener(cfg1.PeerConnID)
	require.True(t, exists, "Peer1 listener should be found")
	bindListener1 := listener1.(*BindListener)
	listener2, exists := mgr.GetPeerListener(cfg2.PeerConnID)
	require.True(t, exists, "Peer2 listener should be found")
	bindListener2 := listener2.(*BindListener)
	// Writing into each endpoint simulates inbound traffic for that peer.
	conn1 := mockEndpointMgr.GetEndpoint(bindListener1.fakeIP)
	require.NotNil(t, conn1, "Peer1 endpoint should be registered")
	_, err = conn1.Write([]byte{0x01})
	require.NoError(t, err)
	conn2 := mockEndpointMgr.GetEndpoint(bindListener2.fakeIP)
	require.NotNil(t, conn2, "Peer2 endpoint should be registered")
	_, err = conn2.Write([]byte{0x02})
	require.NoError(t, err)
	// The two notifications may arrive in either order; collect both first.
	receivedPeers := make(map[peerid.ConnID]bool)
	for i := 0; i < 2; i++ {
		select {
		case peerConnID := <-mgr.OnActivityChan:
			receivedPeers[peerConnID] = true
		case <-time.After(2 * time.Second):
			t.Fatal("timeout waiting for activity notifications")
		}
	}
	assert.True(t, receivedPeers[cfg1.PeerConnID], "Peer1 activity should be received")
	assert.True(t, receivedPeers[cfg2.PeerConnID], "Peer2 activity should be received")
}

View File

@@ -1,41 +0,0 @@
package activity
import (
"testing"
"time"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/lazyconn"
)
func TestNewListener(t *testing.T) {
peer := &MocPeer{
PeerID: "examplePublicKey1",
}
cfg := lazyconn.PeerConfig{
PublicKey: peer.PeerID,
PeerConnID: peer.ConnID(),
Log: log.WithField("peer", "examplePublicKey1"),
}
l, err := NewListener(MocWGIface{}, cfg)
if err != nil {
t.Fatalf("failed to create listener: %v", err)
}
chanClosed := make(chan struct{})
go func() {
defer close(chanClosed)
l.ReadPackets()
}()
time.Sleep(1 * time.Second)
l.Close()
select {
case <-chanClosed:
case <-time.After(time.Second):
}
}

View File

@@ -11,26 +11,27 @@ import (
"github.com/netbirdio/netbird/client/internal/lazyconn"
)
// Listener it is not a thread safe implementation, do not call Close before ReadPackets. It will cause blocking
type Listener struct {
// UDPListener uses UDP sockets for activity detection in kernel mode.
type UDPListener struct {
wgIface WgInterface
peerCfg lazyconn.PeerConfig
conn *net.UDPConn
endpoint *net.UDPAddr
done sync.Mutex
isClosed atomic.Bool // use to avoid error log when closing the listener
isClosed atomic.Bool
}
func NewListener(wgIface WgInterface, cfg lazyconn.PeerConfig) (*Listener, error) {
d := &Listener{
// NewUDPListener creates a listener that detects activity via UDP socket reads.
func NewUDPListener(wgIface WgInterface, cfg lazyconn.PeerConfig) (*UDPListener, error) {
d := &UDPListener{
wgIface: wgIface,
peerCfg: cfg,
}
conn, err := d.newConn()
if err != nil {
return nil, fmt.Errorf("failed to creating activity listener: %v", err)
return nil, fmt.Errorf("create UDP connection: %v", err)
}
d.conn = conn
d.endpoint = conn.LocalAddr().(*net.UDPAddr)
@@ -38,12 +39,14 @@ func NewListener(wgIface WgInterface, cfg lazyconn.PeerConfig) (*Listener, error
if err := d.createEndpoint(); err != nil {
return nil, err
}
d.done.Lock()
cfg.Log.Infof("created activity listener: %s", conn.LocalAddr().(*net.UDPAddr).String())
cfg.Log.Infof("created activity listener: %s", d.conn.LocalAddr().(*net.UDPAddr).String())
return d, nil
}
func (d *Listener) ReadPackets() {
// ReadPackets blocks reading from the UDP socket until activity is detected or the listener is closed.
func (d *UDPListener) ReadPackets() {
for {
n, remoteAddr, err := d.conn.ReadFromUDP(make([]byte, 1))
if err != nil {
@@ -64,15 +67,17 @@ func (d *Listener) ReadPackets() {
}
d.peerCfg.Log.Debugf("removing lazy endpoint: %s", d.endpoint.String())
if err := d.removeEndpoint(); err != nil {
if err := d.wgIface.RemovePeer(d.peerCfg.PublicKey); err != nil {
d.peerCfg.Log.Errorf("failed to remove endpoint: %s", err)
}
_ = d.conn.Close() // do not care err because some cases it will return "use of closed network connection"
// Ignore close error as it may return "use of closed network connection" if already closed.
_ = d.conn.Close()
d.done.Unlock()
}
func (d *Listener) Close() {
// Close stops the listener and cleans up resources.
func (d *UDPListener) Close() {
d.peerCfg.Log.Infof("closing activity listener: %s", d.conn.LocalAddr().String())
d.isClosed.Store(true)
@@ -82,16 +87,12 @@ func (d *Listener) Close() {
d.done.Lock()
}
func (d *Listener) removeEndpoint() error {
return d.wgIface.RemovePeer(d.peerCfg.PublicKey)
}
func (d *Listener) createEndpoint() error {
func (d *UDPListener) createEndpoint() error {
d.peerCfg.Log.Debugf("creating lazy endpoint: %s", d.endpoint.String())
return d.wgIface.UpdatePeer(d.peerCfg.PublicKey, d.peerCfg.AllowedIPs, 0, d.endpoint, nil)
}
func (d *Listener) newConn() (*net.UDPConn, error) {
func (d *UDPListener) newConn() (*net.UDPConn, error) {
addr := &net.UDPAddr{
Port: 0,
IP: listenIP,

View File

@@ -0,0 +1,110 @@
package activity
import (
"net"
"net/netip"
"testing"
"time"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/client/internal/lazyconn"
)
// TestUDPListener_Creation checks that a UDPListener comes up with a bound
// socket and endpoint, and that ReadPackets exits after Close.
func TestUDPListener_Creation(t *testing.T) {
	wgIface := &MocWGIface{}
	p := &MocPeer{PeerID: "testPeer1"}
	peerCfg := lazyconn.PeerConfig{
		PublicKey:  p.PeerID,
		PeerConnID: p.ConnID(),
		AllowedIPs: []netip.Prefix{netip.MustParsePrefix("100.64.0.2/32")},
		Log:        log.WithField("peer", "testPeer1"),
	}

	l, err := NewUDPListener(wgIface, peerCfg)
	require.NoError(t, err)
	require.NotNil(t, l.conn)
	require.NotNil(t, l.endpoint)

	done := make(chan struct{})
	go func() {
		defer close(done)
		l.ReadPackets()
	}()

	l.Close()

	select {
	case <-done:
	case <-time.After(2 * time.Second):
		t.Fatal("timeout waiting for ReadPackets to exit after Close")
	}
}
// TestUDPListener_ActivityDetection checks that a UDP datagram sent to the
// listener's address makes ReadPackets return.
func TestUDPListener_ActivityDetection(t *testing.T) {
	wgIface := &MocWGIface{}
	p := &MocPeer{PeerID: "testPeer1"}
	peerCfg := lazyconn.PeerConfig{
		PublicKey:  p.PeerID,
		PeerConnID: p.ConnID(),
		AllowedIPs: []netip.Prefix{netip.MustParsePrefix("100.64.0.2/32")},
		Log:        log.WithField("peer", "testPeer1"),
	}

	l, err := NewUDPListener(wgIface, peerCfg)
	require.NoError(t, err)

	detected := make(chan struct{})
	go func() {
		defer close(detected)
		l.ReadPackets()
	}()

	// Any datagram to the listener's local address counts as activity.
	sender, err := net.Dial("udp", l.conn.LocalAddr().String())
	require.NoError(t, err)
	defer sender.Close()

	_, err = sender.Write([]byte{0x01, 0x02, 0x03})
	require.NoError(t, err)

	select {
	case <-detected:
	case <-time.After(2 * time.Second):
		t.Fatal("timeout waiting for activity detection")
	}
}
// TestUDPListener_Close checks that Close unblocks ReadPackets and marks the
// listener as closed.
func TestUDPListener_Close(t *testing.T) {
	wgIface := &MocWGIface{}
	p := &MocPeer{PeerID: "testPeer1"}
	peerCfg := lazyconn.PeerConfig{
		PublicKey:  p.PeerID,
		PeerConnID: p.ConnID(),
		AllowedIPs: []netip.Prefix{netip.MustParsePrefix("100.64.0.2/32")},
		Log:        log.WithField("peer", "testPeer1"),
	}

	l, err := NewUDPListener(wgIface, peerCfg)
	require.NoError(t, err)

	done := make(chan struct{})
	go func() {
		defer close(done)
		l.ReadPackets()
	}()

	l.Close()

	select {
	case <-done:
	case <-time.After(2 * time.Second):
		t.Fatal("timeout waiting for ReadPackets to exit after Close")
	}

	assert.True(t, l.isClosed.Load(), "Listener should be marked as closed")
}

View File

@@ -1,21 +1,32 @@
package activity
import (
"errors"
"net"
"net/netip"
"runtime"
"sync"
"time"
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/internal/lazyconn"
peerid "github.com/netbirdio/netbird/client/internal/peer/id"
)
// listener defines the contract for activity detection listeners.
type listener interface {
	// ReadPackets blocks until activity is detected or the listener is closed.
	ReadPackets()
	// Close stops the listener and releases its resources.
	Close()
}

// WgInterface is the subset of the WireGuard interface that the activity
// manager needs: peer add/remove, bind-mode detection, and the local address.
type WgInterface interface {
	RemovePeer(peerKey string) error
	UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error
	IsUserspaceBind() bool
	Address() wgaddr.Address
}
type Manager struct {
@@ -23,7 +34,7 @@ type Manager struct {
wgIface WgInterface
peers map[peerid.ConnID]*Listener
peers map[peerid.ConnID]listener
done chan struct{}
mu sync.Mutex
@@ -33,7 +44,7 @@ func NewManager(wgIface WgInterface) *Manager {
m := &Manager{
OnActivityChan: make(chan peerid.ConnID, 1),
wgIface: wgIface,
peers: make(map[peerid.ConnID]*Listener),
peers: make(map[peerid.ConnID]listener),
done: make(chan struct{}),
}
return m
@@ -48,16 +59,38 @@ func (m *Manager) MonitorPeerActivity(peerCfg lazyconn.PeerConfig) error {
return nil
}
listener, err := NewListener(m.wgIface, peerCfg)
listener, err := m.createListener(peerCfg)
if err != nil {
return err
}
m.peers[peerCfg.PeerConnID] = listener
m.peers[peerCfg.PeerConnID] = listener
go m.waitForTraffic(listener, peerCfg.PeerConnID)
return nil
}
// createListener picks the listener implementation for the current platform
// and interface mode.
func (m *Manager) createListener(peerCfg lazyconn.PeerConfig) (listener, error) {
	// BindListener is only used on Windows and JS platforms:
	// - JS: Cannot listen to UDP sockets
	// - Windows: IP_UNICAST_IF socket option forces packets out the interface
	//   the default gateway points to, preventing them from reaching the
	//   loopback interface. BindListener bypasses this by passing data
	//   directly through the bind.
	bindPlatform := runtime.GOOS == "windows" || runtime.GOOS == "js"
	if !m.wgIface.IsUserspaceBind() || !bindPlatform {
		return NewUDPListener(m.wgIface, peerCfg)
	}

	provider, ok := m.wgIface.(bindProvider)
	if !ok {
		return nil, errors.New("interface claims userspace bind but doesn't implement bindProvider")
	}
	return NewBindListener(m.wgIface, provider.GetBind(), peerCfg)
}
func (m *Manager) RemovePeer(log *log.Entry, peerConnID peerid.ConnID) {
m.mu.Lock()
defer m.mu.Unlock()
@@ -82,8 +115,8 @@ func (m *Manager) Close() {
}
}
func (m *Manager) waitForTraffic(listener *Listener, peerConnID peerid.ConnID) {
listener.ReadPackets()
func (m *Manager) waitForTraffic(l listener, peerConnID peerid.ConnID) {
l.ReadPackets()
m.mu.Lock()
if _, ok := m.peers[peerConnID]; !ok {

View File

@@ -9,6 +9,7 @@ import (
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/internal/lazyconn"
peerid "github.com/netbirdio/netbird/client/internal/peer/id"
)
@@ -30,16 +31,26 @@ func (m MocWGIface) RemovePeer(string) error {
func (m MocWGIface) UpdatePeer(string, []netip.Prefix, time.Duration, *net.UDPAddr, *wgtypes.Key) error {
return nil
}
// Add this method to the Manager struct
func (m *Manager) GetPeerListener(peerConnID peerid.ConnID) (*Listener, bool) {
// IsUserspaceBind reports kernel mode so the manager selects the UDP listener.
func (m MocWGIface) IsUserspaceBind() bool { return false }

// Address returns a fixed address inside the test network.
func (m MocWGIface) Address() wgaddr.Address {
	return wgaddr.Address{
		IP:      netip.MustParseAddr("100.64.0.1"),
		Network: netip.MustParsePrefix("100.64.0.0/16"),
	}
}
// GetPeerListener is a test helper to access listeners
func (m *Manager) GetPeerListener(peerConnID peerid.ConnID) (listener, bool) {
	m.mu.Lock()
	defer m.mu.Unlock()

	lsn, ok := m.peers[peerConnID]
	return lsn, ok
}
func TestManager_MonitorPeerActivity(t *testing.T) {
@@ -65,7 +76,12 @@ func TestManager_MonitorPeerActivity(t *testing.T) {
t.Fatalf("peer listener not found")
}
if err := trigger(listener.conn.LocalAddr().String()); err != nil {
// Get the UDP listener's address for triggering
udpListener, ok := listener.(*UDPListener)
if !ok {
t.Fatalf("expected UDPListener")
}
if err := trigger(udpListener.conn.LocalAddr().String()); err != nil {
t.Fatalf("failed to trigger activity: %v", err)
}
@@ -97,7 +113,9 @@ func TestManager_RemovePeerActivity(t *testing.T) {
t.Fatalf("failed to monitor peer activity: %v", err)
}
addr := mgr.peers[peerCfg1.PeerConnID].conn.LocalAddr().String()
listener, _ := mgr.GetPeerListener(peerCfg1.PeerConnID)
udpListener, _ := listener.(*UDPListener)
addr := udpListener.conn.LocalAddr().String()
mgr.RemovePeer(peerCfg1.Log, peerCfg1.PeerConnID)
@@ -147,7 +165,8 @@ func TestManager_MultiPeerActivity(t *testing.T) {
t.Fatalf("peer listener for peer1 not found")
}
if err := trigger(listener.conn.LocalAddr().String()); err != nil {
udpListener1, _ := listener.(*UDPListener)
if err := trigger(udpListener1.conn.LocalAddr().String()); err != nil {
t.Fatalf("failed to trigger activity: %v", err)
}
@@ -156,7 +175,8 @@ func TestManager_MultiPeerActivity(t *testing.T) {
t.Fatalf("peer listener for peer2 not found")
}
if err := trigger(listener.conn.LocalAddr().String()); err != nil {
udpListener2, _ := listener.(*UDPListener)
if err := trigger(udpListener2.conn.LocalAddr().String()); err != nil {
t.Fatalf("failed to trigger activity: %v", err)
}

View File

@@ -7,6 +7,7 @@ import (
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/monotime"
)
@@ -14,5 +15,6 @@ type WGIface interface {
RemovePeer(peerKey string) error
UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error
IsUserspaceBind() bool
Address() wgaddr.Address
LastActivities() map[string]monotime.Time
}

View File

@@ -195,6 +195,7 @@ func createNewConfig(input ConfigInput) (*Config, error) {
config := &Config{
// defaults to false only for new (post 0.26) configurations
ServerSSHAllowed: util.False(),
WgPort: iface.DefaultWgPort,
}
if _, err := config.apply(input); err != nil {

View File

@@ -5,11 +5,14 @@ import (
"errors"
"os"
"path/filepath"
"runtime"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/internal/routemanager/dynamic"
"github.com/netbirdio/netbird/util"
)
@@ -141,6 +144,95 @@ func TestHiddenPreSharedKey(t *testing.T) {
}
}
// TestNewProfileDefaults checks that a freshly created profile config is
// populated with all expected defaults, including the 51820 WireGuard port.
func TestNewProfileDefaults(t *testing.T) {
	cfgPath := filepath.Join(t.TempDir(), "config.json")

	cfg, err := UpdateOrCreateConfig(ConfigInput{ConfigPath: cfgPath})
	require.NoError(t, err, "should create new config")

	assert.Equal(t, DefaultManagementURL, cfg.ManagementURL.String(), "ManagementURL should have default")
	assert.Equal(t, DefaultAdminURL, cfg.AdminURL.String(), "AdminURL should have default")
	assert.NotEmpty(t, cfg.PrivateKey, "PrivateKey should be generated")
	assert.NotEmpty(t, cfg.SSHKey, "SSHKey should be generated")
	assert.Equal(t, iface.WgInterfaceDefault, cfg.WgIface, "WgIface should have default")
	assert.Equal(t, iface.DefaultWgPort, cfg.WgPort, "WgPort should default to 51820")
	assert.Equal(t, uint16(iface.DefaultMTU), cfg.MTU, "MTU should have default")
	assert.Equal(t, dynamic.DefaultInterval, cfg.DNSRouteInterval, "DNSRouteInterval should have default")
	assert.NotNil(t, cfg.ServerSSHAllowed, "ServerSSHAllowed should be set")
	assert.NotNil(t, cfg.DisableNotifications, "DisableNotifications should be set")
	assert.NotEmpty(t, cfg.IFaceBlackList, "IFaceBlackList should have defaults")

	if runtime.GOOS == "windows" || runtime.GOOS == "darwin" {
		assert.NotNil(t, cfg.NetworkMonitor, "NetworkMonitor should be set on Windows/macOS")
		assert.True(t, *cfg.NetworkMonitor, "NetworkMonitor should be enabled by default on Windows/macOS")
	}
}
// TestWireguardPortZeroExplicit checks that an explicitly requested port 0
// (random port) wins over the default and survives a reload from disk.
func TestWireguardPortZeroExplicit(t *testing.T) {
	cfgPath := filepath.Join(t.TempDir(), "config.json")

	zero := 0
	cfg, err := UpdateOrCreateConfig(ConfigInput{
		ConfigPath:    cfgPath,
		WireguardPort: &zero,
	})
	require.NoError(t, err, "should create config with explicit port 0")
	assert.Equal(t, 0, cfg.WgPort, "WgPort should be 0 when explicitly set by user")

	// Reload from disk to make sure the explicit zero persists.
	reloaded, err := GetConfig(cfgPath)
	require.NoError(t, err)
	assert.Equal(t, 0, reloaded.WgPort, "WgPort should remain 0 after reading from file")
}
// TestWireguardPortDefaultVsExplicit covers the three port-selection cases:
// unset (falls back to the default), explicit 0 (random port), and an
// explicit custom port.
func TestWireguardPortDefaultVsExplicit(t *testing.T) {
	ptr := func(v int) *int { return &v }

	tests := []struct {
		name          string
		wireguardPort *int
		expectedPort  int
		description   string
	}{
		{
			name:          "no port specified uses default",
			wireguardPort: nil,
			expectedPort:  iface.DefaultWgPort,
			description:   "When user doesn't specify port, default to 51820",
		},
		{
			name:          "explicit zero for random port",
			wireguardPort: ptr(0),
			expectedPort:  0,
			description:   "When user explicitly sets 0, use 0 for random port",
		},
		{
			name:          "explicit custom port",
			wireguardPort: ptr(52000),
			expectedPort:  52000,
			description:   "When user sets custom port, use that port",
		},
	}

	for _, tc := range tests {
		t.Run(tc.name, func(t *testing.T) {
			cfgPath := filepath.Join(t.TempDir(), "config.json")

			cfg, err := UpdateOrCreateConfig(ConfigInput{
				ConfigPath:    cfgPath,
				WireguardPort: tc.wireguardPort,
			})
			require.NoError(t, err, tc.description)
			assert.Equal(t, tc.expectedPort, cfg.WgPort, tc.description)
		})
	}
}
func TestUpdateOldManagementURL(t *testing.T) {
tests := []struct {
name string

View File

@@ -353,6 +353,13 @@ func (s *Server) SetConfig(callerCtx context.Context, msg *proto.SetConfigReques
config.CustomDNSAddress = []byte{}
}
config.ExtraIFaceBlackList = msg.ExtraIFaceBlacklist
if msg.DnsRouteInterval != nil {
interval := msg.DnsRouteInterval.AsDuration()
config.DNSRouteInterval = &interval
}
config.RosenpassEnabled = msg.RosenpassEnabled
config.RosenpassPermissive = msg.RosenpassPermissive
config.DisableAutoConnect = msg.DisableAutoConnect

View File

@@ -0,0 +1,298 @@
package server
import (
"context"
"os/user"
"path/filepath"
"reflect"
"testing"
"time"
"github.com/stretchr/testify/require"
"google.golang.org/protobuf/types/known/durationpb"
"github.com/netbirdio/netbird/client/internal/profilemanager"
"github.com/netbirdio/netbird/client/proto"
)
// TestSetConfig_AllFieldsSaved ensures that all fields in SetConfigRequest are properly saved to the config.
// This test uses reflection to detect when new fields are added but not handled in SetConfig.
func TestSetConfig_AllFieldsSaved(t *testing.T) {
	tempDir := t.TempDir()
	// Redirect the profilemanager's package-level paths into the temp dir;
	// the originals are restored via t.Cleanup below.
	origDefaultProfileDir := profilemanager.DefaultConfigPathDir
	origDefaultConfigPath := profilemanager.DefaultConfigPath
	origActiveProfileStatePath := profilemanager.ActiveProfileStatePath
	profilemanager.ConfigDirOverride = tempDir
	profilemanager.DefaultConfigPathDir = tempDir
	profilemanager.ActiveProfileStatePath = tempDir + "/active_profile.json"
	profilemanager.DefaultConfigPath = filepath.Join(tempDir, "default.json")
	t.Cleanup(func() {
		profilemanager.DefaultConfigPathDir = origDefaultProfileDir
		profilemanager.ActiveProfileStatePath = origActiveProfileStatePath
		profilemanager.DefaultConfigPath = origDefaultConfigPath
		profilemanager.ConfigDirOverride = ""
	})
	currUser, err := user.Current()
	require.NoError(t, err)
	profName := "test-profile"
	// Seed a profile config on disk that SetConfig will then update.
	ic := profilemanager.ConfigInput{
		ConfigPath:    filepath.Join(tempDir, profName+".json"),
		ManagementURL: "https://api.netbird.io:443",
	}
	_, err = profilemanager.UpdateOrCreateConfig(ic)
	require.NoError(t, err)
	pm := profilemanager.ServiceManager{}
	err = pm.SetActiveProfileState(&profilemanager.ActiveProfileState{
		Name:     profName,
		Username: currUser.Username,
	})
	require.NoError(t, err)
	ctx := context.Background()
	s := New(ctx, "console", "", false, false)
	// Non-default values for every settable field, so the asserts below can
	// distinguish "saved" from "left at default".
	rosenpassEnabled := true
	rosenpassPermissive := true
	serverSSHAllowed := true
	interfaceName := "utun100"
	wireguardPort := int64(51820)
	preSharedKey := "test-psk"
	disableAutoConnect := true
	networkMonitor := true
	disableClientRoutes := true
	disableServerRoutes := true
	disableDNS := true
	disableFirewall := true
	blockLANAccess := true
	disableNotifications := true
	lazyConnectionEnabled := true
	blockInbound := true
	mtu := int64(1280)
	req := &proto.SetConfigRequest{
		ProfileName:           profName,
		Username:              currUser.Username,
		ManagementUrl:         "https://new-api.netbird.io:443",
		AdminURL:              "https://new-admin.netbird.io",
		RosenpassEnabled:      &rosenpassEnabled,
		RosenpassPermissive:   &rosenpassPermissive,
		ServerSSHAllowed:      &serverSSHAllowed,
		InterfaceName:         &interfaceName,
		WireguardPort:         &wireguardPort,
		OptionalPreSharedKey:  &preSharedKey,
		DisableAutoConnect:    &disableAutoConnect,
		NetworkMonitor:        &networkMonitor,
		DisableClientRoutes:   &disableClientRoutes,
		DisableServerRoutes:   &disableServerRoutes,
		DisableDns:            &disableDNS,
		DisableFirewall:       &disableFirewall,
		BlockLanAccess:        &blockLANAccess,
		DisableNotifications:  &disableNotifications,
		LazyConnectionEnabled: &lazyConnectionEnabled,
		BlockInbound:          &blockInbound,
		NatExternalIPs:        []string{"1.2.3.4", "5.6.7.8"},
		CleanNATExternalIPs:   false,
		CustomDNSAddress:      []byte("1.1.1.1:53"),
		ExtraIFaceBlacklist:   []string{"eth1", "eth2"},
		DnsLabels:             []string{"label1", "label2"},
		CleanDNSLabels:        false,
		DnsRouteInterval:      durationpb.New(2 * time.Minute),
		Mtu:                   &mtu,
	}
	_, err = s.SetConfig(ctx, req)
	require.NoError(t, err)
	// Reload the profile from disk and verify every field round-tripped.
	profState := profilemanager.ActiveProfileState{
		Name:     profName,
		Username: currUser.Username,
	}
	cfgPath, err := profState.FilePath()
	require.NoError(t, err)
	cfg, err := profilemanager.GetConfig(cfgPath)
	require.NoError(t, err)
	require.Equal(t, "https://new-api.netbird.io:443", cfg.ManagementURL.String())
	// The admin URL was sent without a port but is asserted with :443 here.
	require.Equal(t, "https://new-admin.netbird.io:443", cfg.AdminURL.String())
	require.Equal(t, rosenpassEnabled, cfg.RosenpassEnabled)
	require.Equal(t, rosenpassPermissive, cfg.RosenpassPermissive)
	require.NotNil(t, cfg.ServerSSHAllowed)
	require.Equal(t, serverSSHAllowed, *cfg.ServerSSHAllowed)
	require.Equal(t, interfaceName, cfg.WgIface)
	require.Equal(t, int(wireguardPort), cfg.WgPort)
	require.Equal(t, preSharedKey, cfg.PreSharedKey)
	require.Equal(t, disableAutoConnect, cfg.DisableAutoConnect)
	require.NotNil(t, cfg.NetworkMonitor)
	require.Equal(t, networkMonitor, *cfg.NetworkMonitor)
	require.Equal(t, disableClientRoutes, cfg.DisableClientRoutes)
	require.Equal(t, disableServerRoutes, cfg.DisableServerRoutes)
	require.Equal(t, disableDNS, cfg.DisableDNS)
	require.Equal(t, disableFirewall, cfg.DisableFirewall)
	require.Equal(t, blockLANAccess, cfg.BlockLANAccess)
	require.NotNil(t, cfg.DisableNotifications)
	require.Equal(t, disableNotifications, *cfg.DisableNotifications)
	require.Equal(t, lazyConnectionEnabled, cfg.LazyConnectionEnabled)
	require.Equal(t, blockInbound, cfg.BlockInbound)
	require.Equal(t, []string{"1.2.3.4", "5.6.7.8"}, cfg.NATExternalIPs)
	require.Equal(t, "1.1.1.1:53", cfg.CustomDNSAddress)
	// IFaceBlackList contains defaults + extras
	require.Contains(t, cfg.IFaceBlackList, "eth1")
	require.Contains(t, cfg.IFaceBlackList, "eth2")
	require.Equal(t, []string{"label1", "label2"}, cfg.DNSLabels.ToPunycodeList())
	require.Equal(t, 2*time.Minute, cfg.DNSRouteInterval)
	require.Equal(t, uint16(mtu), cfg.MTU)
	verifyAllFieldsCovered(t, req)
}
// verifyAllFieldsCovered uses reflection to ensure we're testing all fields in SetConfigRequest.
// If a new field is added to SetConfigRequest, this function will fail the test,
// forcing the developer to update both the SetConfig handler and this test.
// It also fails when expectedFields contains entries that no longer exist on
// the struct, so a removed or renamed proto field cannot leave a stale entry
// behind and silently shrink the coverage.
func verifyAllFieldsCovered(t *testing.T, req *proto.SetConfigRequest) {
	t.Helper()

	metadataFields := map[string]bool{
		"state":               true, // protobuf internal
		"sizeCache":           true, // protobuf internal
		"unknownFields":       true, // protobuf internal
		"Username":            true, // metadata
		"ProfileName":         true, // metadata
		"CleanNATExternalIPs": true, // control flag for clearing
		"CleanDNSLabels":      true, // control flag for clearing
	}
	expectedFields := map[string]bool{
		"ManagementUrl":         true,
		"AdminURL":              true,
		"RosenpassEnabled":      true,
		"RosenpassPermissive":   true,
		"ServerSSHAllowed":      true,
		"InterfaceName":         true,
		"WireguardPort":         true,
		"OptionalPreSharedKey":  true,
		"DisableAutoConnect":    true,
		"NetworkMonitor":        true,
		"DisableClientRoutes":   true,
		"DisableServerRoutes":   true,
		"DisableDns":            true,
		"DisableFirewall":       true,
		"BlockLanAccess":        true,
		"DisableNotifications":  true,
		"LazyConnectionEnabled": true,
		"BlockInbound":          true,
		"NatExternalIPs":        true,
		"CustomDNSAddress":      true,
		"ExtraIFaceBlacklist":   true,
		"DnsLabels":             true,
		"DnsRouteInterval":      true,
		"Mtu":                   true,
	}

	val := reflect.ValueOf(req).Elem()
	typ := val.Type()

	// Forward check: every struct field is either metadata or expected.
	seen := make(map[string]bool, val.NumField())
	var unexpectedFields []string
	for i := 0; i < val.NumField(); i++ {
		fieldName := typ.Field(i).Name
		seen[fieldName] = true
		if metadataFields[fieldName] {
			continue
		}
		if !expectedFields[fieldName] {
			unexpectedFields = append(unexpectedFields, fieldName)
		}
	}
	if len(unexpectedFields) > 0 {
		t.Fatalf("New field(s) detected in SetConfigRequest: %v", unexpectedFields)
	}

	// Reverse check: every expected field must still exist on the struct.
	var staleFields []string
	for name := range expectedFields {
		if !seen[name] {
			staleFields = append(staleFields, name)
		}
	}
	if len(staleFields) > 0 {
		t.Fatalf("Stale field(s) in expectedFields no longer present in SetConfigRequest: %v", staleFields)
	}
}
// TestCLIFlags_MappedToSetConfig ensures all CLI flags that modify config are properly mapped to SetConfigRequest.
// This test catches bugs where a new CLI flag is added but not wired to the SetConfigRequest in setupSetConfigReq.
// It also fails when a flag is mapped to a field name that no longer exists on
// the struct, so renamed proto fields cannot leave stale mappings behind.
func TestCLIFlags_MappedToSetConfig(t *testing.T) {
	// Map of CLI flag names to their corresponding SetConfigRequest field names.
	// This map must be updated when adding new config-related CLI flags.
	flagToField := map[string]string{
		"management-url":         "ManagementUrl",
		"admin-url":              "AdminURL",
		"enable-rosenpass":       "RosenpassEnabled",
		"rosenpass-permissive":   "RosenpassPermissive",
		"allow-server-ssh":       "ServerSSHAllowed",
		"interface-name":         "InterfaceName",
		"wireguard-port":         "WireguardPort",
		"preshared-key":          "OptionalPreSharedKey",
		"disable-auto-connect":   "DisableAutoConnect",
		"network-monitor":        "NetworkMonitor",
		"disable-client-routes":  "DisableClientRoutes",
		"disable-server-routes":  "DisableServerRoutes",
		"disable-dns":            "DisableDns",
		"disable-firewall":       "DisableFirewall",
		"block-lan-access":       "BlockLanAccess",
		"block-inbound":          "BlockInbound",
		"enable-lazy-connection": "LazyConnectionEnabled",
		"external-ip-map":        "NatExternalIPs",
		"dns-resolver-address":   "CustomDNSAddress",
		"extra-iface-blacklist":  "ExtraIFaceBlacklist",
		"extra-dns-labels":       "DnsLabels",
		"dns-router-interval":    "DnsRouteInterval",
		"mtu":                    "Mtu",
	}

	// SetConfigRequest fields that don't have CLI flags (settable only via UI or other means).
	fieldsWithoutCLIFlags := map[string]bool{
		"DisableNotifications": true, // Only settable via UI
	}

	// Protobuf internals and metadata/control fields that need no CLI flag.
	skipFields := map[string]bool{
		"state":               true,
		"sizeCache":           true,
		"unknownFields":       true,
		"Username":            true,
		"ProfileName":         true,
		"CleanNATExternalIPs": true,
		"CleanDNSLabels":      true,
	}

	// Build the reverse index once instead of scanning the map per field.
	mappedFields := make(map[string]bool, len(flagToField))
	for _, fieldName := range flagToField {
		mappedFields[fieldName] = true
	}

	req := &proto.SetConfigRequest{}
	val := reflect.ValueOf(req).Elem()
	typ := val.Type()

	structFields := make(map[string]bool, val.NumField())
	var unmappedFields []string
	for i := 0; i < val.NumField(); i++ {
		fieldName := typ.Field(i).Name
		structFields[fieldName] = true
		if skipFields[fieldName] {
			continue
		}
		if !mappedFields[fieldName] && !fieldsWithoutCLIFlags[fieldName] {
			unmappedFields = append(unmappedFields, fieldName)
		}
	}
	if len(unmappedFields) > 0 {
		t.Fatalf("SetConfigRequest field(s) not documented: %v\n"+
			"Either add the CLI flag to flagToField map, or if there's no CLI flag for this field, "+
			"add it to fieldsWithoutCLIFlags map with a comment explaining why.", unmappedFields)
	}

	// Stale-mapping check: every mapped field must still exist on the struct.
	var staleMappings []string
	for flag, fieldName := range flagToField {
		if !structFields[fieldName] {
			staleMappings = append(staleMappings, flag+" -> "+fieldName)
		}
	}
	if len(staleMappings) > 0 {
		t.Fatalf("CLI flag(s) mapped to nonexistent SetConfigRequest field(s): %v", staleMappings)
	}

	t.Log("All SetConfigRequest fields are properly documented")
}

View File

@@ -205,15 +205,18 @@ func mapPeers(
localICEEndpoint := ""
remoteICEEndpoint := ""
relayServerAddress := ""
connType := "P2P"
connType := "-"
lastHandshake := time.Time{}
transferReceived := int64(0)
transferSent := int64(0)
isPeerConnected := pbPeerState.ConnStatus == peer.StatusConnected.String()
if pbPeerState.Relayed {
connType = "Relayed"
if isPeerConnected {
connType = "P2P"
if pbPeerState.Relayed {
connType = "Relayed"
}
}
if skipDetailByFilters(pbPeerState, pbPeerState.ConnStatus, statusFilter, prefixNamesFilter, prefixNamesFilterMap, ipsFilter, connectionTypeFilter, connType) {

View File

@@ -31,7 +31,6 @@ import (
"fyne.io/systray"
"github.com/cenkalti/backoff/v4"
log "github.com/sirupsen/logrus"
"github.com/skratchdot/open-golang/open"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure"
@@ -633,7 +632,7 @@ func (s *serviceClient) login(openURL bool) (*proto.LoginResponse, error) {
}
func (s *serviceClient) handleSSOLogin(loginResp *proto.LoginResponse, conn proto.DaemonServiceClient) error {
err := open.Run(loginResp.VerificationURIComplete)
err := openURL(loginResp.VerificationURIComplete)
if err != nil {
log.Errorf("opening the verification uri in the browser failed: %v", err)
return err
@@ -1487,6 +1486,10 @@ func (s *serviceClient) showLoginURL() context.CancelFunc {
}
func openURL(url string) error {
if browser := os.Getenv("BROWSER"); browser != "" {
return exec.Command(browser, url).Start()
}
var err error
switch runtime.GOOS {
case "windows":

View File

@@ -73,8 +73,8 @@ func (p *RDCleanPathProxy) validateCertificateWithJS(conn *proxyConnection, cert
}
}
func (p *RDCleanPathProxy) getTLSConfigWithValidation(conn *proxyConnection) *tls.Config {
return &tls.Config{
func (p *RDCleanPathProxy) getTLSConfigWithValidation(conn *proxyConnection, requiresCredSSP bool) *tls.Config {
config := &tls.Config{
InsecureSkipVerify: true, // We'll validate manually after handshake
VerifyConnection: func(cs tls.ConnectionState) error {
var certChain [][]byte
@@ -93,4 +93,15 @@ func (p *RDCleanPathProxy) getTLSConfigWithValidation(conn *proxyConnection) *tl
return nil
},
}
// CredSSP (NLA) requires TLS 1.2 - it's incompatible with TLS 1.3
if requiresCredSSP {
config.MinVersion = tls.VersionTLS12
config.MaxVersion = tls.VersionTLS12
} else {
config.MinVersion = tls.VersionTLS12
config.MaxVersion = tls.VersionTLS13
}
return config
}

View File

@@ -6,11 +6,13 @@ import (
"context"
"crypto/tls"
"encoding/asn1"
"errors"
"fmt"
"io"
"net"
"sync"
"syscall/js"
"time"
log "github.com/sirupsen/logrus"
)
@@ -19,18 +21,34 @@ const (
RDCleanPathVersion = 3390
RDCleanPathProxyHost = "rdcleanpath.proxy.local"
RDCleanPathProxyScheme = "ws"
rdpDialTimeout = 15 * time.Second
GeneralErrorCode = 1
WSAETimedOut = 10060
WSAEConnRefused = 10061
WSAEConnAborted = 10053
WSAEConnReset = 10054
WSAEGenericError = 10050
)
type RDCleanPathPDU struct {
Version int64 `asn1:"tag:0,explicit"`
Error []byte `asn1:"tag:1,explicit,optional"`
Destination string `asn1:"utf8,tag:2,explicit,optional"`
ProxyAuth string `asn1:"utf8,tag:3,explicit,optional"`
ServerAuth string `asn1:"utf8,tag:4,explicit,optional"`
PreconnectionBlob string `asn1:"utf8,tag:5,explicit,optional"`
X224ConnectionPDU []byte `asn1:"tag:6,explicit,optional"`
ServerCertChain [][]byte `asn1:"tag:7,explicit,optional"`
ServerAddr string `asn1:"utf8,tag:9,explicit,optional"`
Version int64 `asn1:"tag:0,explicit"`
Error RDCleanPathErr `asn1:"tag:1,explicit,optional"`
Destination string `asn1:"utf8,tag:2,explicit,optional"`
ProxyAuth string `asn1:"utf8,tag:3,explicit,optional"`
ServerAuth string `asn1:"utf8,tag:4,explicit,optional"`
PreconnectionBlob string `asn1:"utf8,tag:5,explicit,optional"`
X224ConnectionPDU []byte `asn1:"tag:6,explicit,optional"`
ServerCertChain [][]byte `asn1:"tag:7,explicit,optional"`
ServerAddr string `asn1:"utf8,tag:9,explicit,optional"`
}
type RDCleanPathErr struct {
ErrorCode int16 `asn1:"tag:0,explicit"`
HTTPStatusCode int16 `asn1:"tag:1,explicit,optional"`
WSALastError int16 `asn1:"tag:2,explicit,optional"`
TLSAlertCode int8 `asn1:"tag:3,explicit,optional"`
}
type RDCleanPathProxy struct {
@@ -210,9 +228,13 @@ func (p *RDCleanPathProxy) handleDirectRDP(conn *proxyConnection, firstPacket []
destination := conn.destination
log.Infof("Direct RDP mode: Connecting to %s via NetBird", destination)
rdpConn, err := p.nbClient.Dial(conn.ctx, "tcp", destination)
ctx, cancel := context.WithTimeout(conn.ctx, rdpDialTimeout)
defer cancel()
rdpConn, err := p.nbClient.Dial(ctx, "tcp", destination)
if err != nil {
log.Errorf("Failed to connect to %s: %v", destination, err)
p.sendRDCleanPathError(conn, newWSAError(err))
return
}
conn.rdpConn = rdpConn
@@ -220,6 +242,7 @@ func (p *RDCleanPathProxy) handleDirectRDP(conn *proxyConnection, firstPacket []
_, err = rdpConn.Write(firstPacket)
if err != nil {
log.Errorf("Failed to write first packet: %v", err)
p.sendRDCleanPathError(conn, newWSAError(err))
return
}
@@ -227,6 +250,7 @@ func (p *RDCleanPathProxy) handleDirectRDP(conn *proxyConnection, firstPacket []
n, err := rdpConn.Read(response)
if err != nil {
log.Errorf("Failed to read X.224 response: %v", err)
p.sendRDCleanPathError(conn, newWSAError(err))
return
}
@@ -269,3 +293,52 @@ func (p *RDCleanPathProxy) sendToWebSocket(conn *proxyConnection, data []byte) {
conn.wsHandlers.Call("send", uint8Array.Get("buffer"))
}
}
// sendRDCleanPathError ASN.1-encodes the given error PDU and pushes it to the
// browser over the WebSocket. Marshal failures are logged and dropped, since
// there is no further channel to report them on.
func (p *RDCleanPathProxy) sendRDCleanPathError(conn *proxyConnection, pdu RDCleanPathPDU) {
	encoded, marshalErr := asn1.Marshal(pdu)
	if marshalErr != nil {
		log.Errorf("Failed to marshal error PDU: %v", marshalErr)
		return
	}
	p.sendToWebSocket(conn, encoded)
}
// errorToWSACode translates a Go-side connection error into the closest
// Windows Sockets (WSA) error code for the RDCleanPath error PDU.
// Unrecognized (and nil) errors map to the generic code.
func errorToWSACode(err error) int16 {
	if err == nil {
		return WSAEGenericError
	}

	// Network-level timeouts come wrapped in *net.OpError.
	var opErr *net.OpError
	if errors.As(err, &opErr) && opErr.Timeout() {
		return WSAETimedOut
	}

	switch {
	case errors.Is(err, context.DeadlineExceeded):
		return WSAETimedOut
	case errors.Is(err, context.Canceled):
		return WSAEConnAborted
	case errors.Is(err, io.EOF):
		return WSAEConnReset
	default:
		return WSAEGenericError
	}
}
// newWSAError builds an RDCleanPath error PDU whose WSALastError field is
// derived from err via errorToWSACode.
func newWSAError(err error) RDCleanPathPDU {
	pdu := RDCleanPathPDU{Version: RDCleanPathVersion}
	pdu.Error = RDCleanPathErr{
		ErrorCode:    GeneralErrorCode,
		WSALastError: errorToWSACode(err),
	}
	return pdu
}
// newHTTPError builds an RDCleanPath error PDU carrying the given HTTP
// status code (e.g. 400 for an unsupported protocol version).
func newHTTPError(statusCode int16) RDCleanPathPDU {
	pdu := RDCleanPathPDU{Version: RDCleanPathVersion}
	pdu.Error = RDCleanPathErr{
		ErrorCode:      GeneralErrorCode,
		HTTPStatusCode: statusCode,
	}
	return pdu
}

View File

@@ -3,6 +3,7 @@
package rdp
import (
"context"
"crypto/tls"
"encoding/asn1"
"io"
@@ -11,11 +12,17 @@ import (
log "github.com/sirupsen/logrus"
)
const (
// MS-RDPBCGR: confusingly named, actually means PROTOCOL_HYBRID (CredSSP)
protocolSSL = 0x00000001
protocolHybridEx = 0x00000008
)
func (p *RDCleanPathProxy) processRDCleanPathPDU(conn *proxyConnection, pdu RDCleanPathPDU) {
log.Infof("Processing RDCleanPath PDU: Version=%d, Destination=%s", pdu.Version, pdu.Destination)
if pdu.Version != RDCleanPathVersion {
p.sendRDCleanPathError(conn, "Unsupported version")
p.sendRDCleanPathError(conn, newHTTPError(400))
return
}
@@ -24,10 +31,13 @@ func (p *RDCleanPathProxy) processRDCleanPathPDU(conn *proxyConnection, pdu RDCl
destination = pdu.Destination
}
rdpConn, err := p.nbClient.Dial(conn.ctx, "tcp", destination)
ctx, cancel := context.WithTimeout(conn.ctx, rdpDialTimeout)
defer cancel()
rdpConn, err := p.nbClient.Dial(ctx, "tcp", destination)
if err != nil {
log.Errorf("Failed to connect to %s: %v", destination, err)
p.sendRDCleanPathError(conn, "Connection failed")
p.sendRDCleanPathError(conn, newWSAError(err))
p.cleanupConnection(conn)
return
}
@@ -40,6 +50,34 @@ func (p *RDCleanPathProxy) processRDCleanPathPDU(conn *proxyConnection, pdu RDCl
p.setupTLSConnection(conn, pdu)
}
// detectCredSSPFromX224 inspects an X.224 Connection Confirm for an RDP
// Negotiation Response and reports whether the server selected a protocol
// requiring NLA/CredSSP. Per MS-RDPBCGR: byte 11 carries TYPE_RDP_NEG_RSP
// (0x02) and bytes 15-18 the little-endian selectedProtocol flags.
// Returns (requiresTLS12, selectedProtocol, detectionSuccessful).
func (p *RDCleanPathProxy) detectCredSSPFromX224(x224Response []byte) (bool, uint32, bool) {
	const minResponseLength = 19
	if len(x224Response) < minResponseLength {
		return false, 0, false
	}

	// Verify the expected framing: byte 0 must be 0x03 and byte 5 the X.224
	// Data TPDU code 0xD0; anything else means we can't parse the response.
	if x224Response[0] != 0x03 || x224Response[5] != 0xD0 {
		return false, 0, false
	}

	// Without a TYPE_RDP_NEG_RSP marker there are no protocol flags to read.
	if x224Response[11] != 0x02 {
		return false, 0, false
	}

	// Assemble the little-endian selectedProtocol field.
	flags := uint32(x224Response[15]) |
		uint32(x224Response[16])<<8 |
		uint32(x224Response[17])<<16 |
		uint32(x224Response[18])<<24
	nlaSelected := flags&(protocolSSL|protocolHybridEx) != 0
	return nlaSelected, flags, true
}
func (p *RDCleanPathProxy) setupTLSConnection(conn *proxyConnection, pdu RDCleanPathPDU) {
var x224Response []byte
if len(pdu.X224ConnectionPDU) > 0 {
@@ -47,7 +85,7 @@ func (p *RDCleanPathProxy) setupTLSConnection(conn *proxyConnection, pdu RDClean
_, err := conn.rdpConn.Write(pdu.X224ConnectionPDU)
if err != nil {
log.Errorf("Failed to write X.224 PDU: %v", err)
p.sendRDCleanPathError(conn, "Failed to forward X.224")
p.sendRDCleanPathError(conn, newWSAError(err))
return
}
@@ -55,21 +93,32 @@ func (p *RDCleanPathProxy) setupTLSConnection(conn *proxyConnection, pdu RDClean
n, err := conn.rdpConn.Read(response)
if err != nil {
log.Errorf("Failed to read X.224 response: %v", err)
p.sendRDCleanPathError(conn, "Failed to read X.224 response")
p.sendRDCleanPathError(conn, newWSAError(err))
return
}
x224Response = response[:n]
log.Debugf("Received X.224 Connection Confirm (%d bytes)", n)
}
tlsConfig := p.getTLSConfigWithValidation(conn)
requiresCredSSP, selectedProtocol, detected := p.detectCredSSPFromX224(x224Response)
if detected {
if requiresCredSSP {
log.Warnf("Detected NLA/CredSSP (selectedProtocol: 0x%08X), forcing TLS 1.2 for compatibility", selectedProtocol)
} else {
log.Warnf("No NLA/CredSSP detected (selectedProtocol: 0x%08X), allowing up to TLS 1.3", selectedProtocol)
}
} else {
log.Warnf("Could not detect RDP security protocol, allowing up to TLS 1.3")
}
tlsConfig := p.getTLSConfigWithValidation(conn, requiresCredSSP)
tlsConn := tls.Client(conn.rdpConn, tlsConfig)
conn.tlsConn = tlsConn
if err := tlsConn.Handshake(); err != nil {
log.Errorf("TLS handshake failed: %v", err)
p.sendRDCleanPathError(conn, "TLS handshake failed")
p.sendRDCleanPathError(conn, newWSAError(err))
return
}
@@ -106,47 +155,6 @@ func (p *RDCleanPathProxy) setupTLSConnection(conn *proxyConnection, pdu RDClean
p.cleanupConnection(conn)
}
func (p *RDCleanPathProxy) setupPlainConnection(conn *proxyConnection, pdu RDCleanPathPDU) {
if len(pdu.X224ConnectionPDU) > 0 {
log.Debugf("Forwarding X.224 Connection Request (%d bytes)", len(pdu.X224ConnectionPDU))
_, err := conn.rdpConn.Write(pdu.X224ConnectionPDU)
if err != nil {
log.Errorf("Failed to write X.224 PDU: %v", err)
p.sendRDCleanPathError(conn, "Failed to forward X.224")
return
}
response := make([]byte, 1024)
n, err := conn.rdpConn.Read(response)
if err != nil {
log.Errorf("Failed to read X.224 response: %v", err)
p.sendRDCleanPathError(conn, "Failed to read X.224 response")
return
}
responsePDU := RDCleanPathPDU{
Version: RDCleanPathVersion,
X224ConnectionPDU: response[:n],
ServerAddr: conn.destination,
}
p.sendRDCleanPathPDU(conn, responsePDU)
} else {
responsePDU := RDCleanPathPDU{
Version: RDCleanPathVersion,
ServerAddr: conn.destination,
}
p.sendRDCleanPathPDU(conn, responsePDU)
}
go p.forwardConnToWS(conn, conn.rdpConn, "TCP")
go p.forwardWSToConn(conn, conn.rdpConn, "TCP")
<-conn.ctx.Done()
log.Debug("TCP connection context done, cleaning up")
p.cleanupConnection(conn)
}
func (p *RDCleanPathProxy) sendRDCleanPathPDU(conn *proxyConnection, pdu RDCleanPathPDU) {
data, err := asn1.Marshal(pdu)
if err != nil {
@@ -158,21 +166,6 @@ func (p *RDCleanPathProxy) sendRDCleanPathPDU(conn *proxyConnection, pdu RDClean
p.sendToWebSocket(conn, data)
}
func (p *RDCleanPathProxy) sendRDCleanPathError(conn *proxyConnection, errorMsg string) {
pdu := RDCleanPathPDU{
Version: RDCleanPathVersion,
Error: []byte(errorMsg),
}
data, err := asn1.Marshal(pdu)
if err != nil {
log.Errorf("Failed to marshal error PDU: %v", err)
return
}
p.sendToWebSocket(conn, data)
}
func (p *RDCleanPathProxy) readWebSocketMessage(conn *proxyConnection) ([]byte, error) {
msgChan := make(chan []byte)
errChan := make(chan error)

2
go.mod
View File

@@ -62,7 +62,7 @@ require (
github.com/miekg/dns v1.1.59
github.com/mitchellh/hashstructure/v2 v2.0.2
github.com/nadoo/ipset v0.5.0
github.com/netbirdio/management-integrations/integrations v0.0.0-20250906095204-f87a07690ba0
github.com/netbirdio/management-integrations/integrations v0.0.0-20251010134843-7af36217ac1f
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45
github.com/okta/okta-sdk-golang/v2 v2.18.0
github.com/oschwald/maxminddb-golang v1.12.0

4
go.sum
View File

@@ -503,8 +503,8 @@ github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944 h1:TDtJKmM6S
github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944/go.mod h1:sHA6TRxjQ6RLbnI+3R4DZo2Eseg/iKiPRfNmcuNySVQ=
github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51 h1:Ov4qdafATOgGMB1wbSuh+0aAHcwz9hdvB6VZjh1mVMI=
github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51/go.mod h1:ZSIbPdBn5hePO8CpF1PekH2SfpTxg1PDhEwtbqZS7R8=
github.com/netbirdio/management-integrations/integrations v0.0.0-20250906095204-f87a07690ba0 h1:9BUqQHPVOGr0edk8EifUBUfTr2Ob0ypAPxtasUApBxQ=
github.com/netbirdio/management-integrations/integrations v0.0.0-20250906095204-f87a07690ba0/go.mod h1:v0nUbbHbuQnqR7yKIYnKzsLBCswLtp2JctmKYmGgVhc=
github.com/netbirdio/management-integrations/integrations v0.0.0-20251010134843-7af36217ac1f h1:XIpRDlpPz3zFUkpwaqDRHjwpQRsf2ZKHggoex1MTafs=
github.com/netbirdio/management-integrations/integrations v0.0.0-20251010134843-7af36217ac1f/go.mod h1:v0nUbbHbuQnqR7yKIYnKzsLBCswLtp2JctmKYmGgVhc=
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9axERMVN63dqyFqnvuD+EMJHzM7mNGON8=
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM=
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45 h1:ujgviVYmx243Ksy7NdSwrdGPSRNE3pb8kEDSpH0QuAQ=

View File

@@ -49,6 +49,7 @@ services:
- traefik.http.routers.netbird-wsproxy-signal.service=netbird-wsproxy-signal
- traefik.http.services.netbird-wsproxy-signal.loadbalancer.server.port=80
- traefik.http.routers.netbird-signal.rule=Host(`$NETBIRD_DOMAIN`) && PathPrefix(`/signalexchange.SignalExchange/`)
- traefik.http.routers.netbird-signal.service=netbird-signal
- traefik.http.services.netbird-signal.loadbalancer.server.port=10000
- traefik.http.services.netbird-signal.loadbalancer.server.scheme=h2c

View File

@@ -682,17 +682,6 @@ renderManagementJson() {
"URI": "stun:$NETBIRD_DOMAIN:3478"
}
],
"TURNConfig": {
"Turns": [
{
"Proto": "udp",
"URI": "turn:$NETBIRD_DOMAIN:3478",
"Username": "$TURN_USER",
"Password": "$TURN_PASSWORD"
}
],
"TimeBasedCredentials": false
},
"Relay": {
"Addresses": ["$NETBIRD_RELAY_PROTO://$NETBIRD_DOMAIN:$NETBIRD_PORT"],
"CredentialsTTL": "24h",

View File

@@ -7,8 +7,10 @@ import (
"net"
"net/netip"
"os"
"strconv"
"strings"
"sync"
"sync/atomic"
"time"
pb "github.com/golang/protobuf/proto" // nolint
@@ -44,6 +46,9 @@ import (
const (
envLogBlockedPeers = "NB_LOG_BLOCKED_PEERS"
envBlockPeers = "NB_BLOCK_SAME_PEERS"
envConcurrentSyncs = "NB_MAX_CONCURRENT_SYNCS"
defaultSyncLim = 1000
)
// GRPCServer an instance of a Management gRPC API server
@@ -63,6 +68,9 @@ type GRPCServer struct {
logBlockedPeers bool
blockPeersWithSameConfig bool
integratedPeerValidator integrated_validator.IntegratedValidator
syncSem atomic.Int32
syncLim int32
}
// NewServer creates a new Management server
@@ -96,6 +104,16 @@ func NewServer(
logBlockedPeers := strings.ToLower(os.Getenv(envLogBlockedPeers)) == "true"
blockPeersWithSameConfig := strings.ToLower(os.Getenv(envBlockPeers)) == "true"
syncLim := int32(defaultSyncLim)
if syncLimStr := os.Getenv(envConcurrentSyncs); syncLimStr != "" {
syncLimParsed, err := strconv.Atoi(syncLimStr)
if err != nil {
log.Errorf("invalid value for %s: %v using %d", envConcurrentSyncs, err, defaultSyncLim)
} else {
syncLim = int32(syncLimParsed)
}
}
return &GRPCServer{
wgKey: key,
// peerKey -> event channel
@@ -110,6 +128,8 @@ func NewServer(
logBlockedPeers: logBlockedPeers,
blockPeersWithSameConfig: blockPeersWithSameConfig,
integratedPeerValidator: integratedPeerValidator,
syncLim: syncLim,
}, nil
}
@@ -151,6 +171,11 @@ func getRealIP(ctx context.Context) net.IP {
// Sync validates the existence of a connecting peer, sends an initial state (all available for the connecting peers) and
// notifies the connected peer of any updates (e.g. new peers under the same account)
func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_SyncServer) error {
if s.syncSem.Load() >= s.syncLim {
return status.Errorf(codes.ResourceExhausted, "too many concurrent sync requests, please try again later")
}
s.syncSem.Add(1)
reqStart := time.Now()
ctx := srv.Context()
@@ -158,6 +183,7 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi
syncReq := &proto.SyncRequest{}
peerKey, err := s.parseRequest(ctx, req, syncReq)
if err != nil {
s.syncSem.Add(-1)
return err
}
realIP := getRealIP(ctx)
@@ -172,6 +198,7 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi
log.WithContext(ctx).Warnf("peer %s with meta hash %d is blocked from syncing", peerKey.String(), metahashed)
}
if s.blockPeersWithSameConfig {
s.syncSem.Add(-1)
return mapError(ctx, internalStatus.ErrPeerAlreadyLoggedIn)
}
}
@@ -196,8 +223,10 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi
ctx = context.WithValue(ctx, nbContext.AccountIDKey, "UNKNOWN")
log.WithContext(ctx).Tracef("peer %s is not registered", peerKey.String())
if errStatus, ok := internalStatus.FromError(err); ok && errStatus.Type() == internalStatus.NotFound {
s.syncSem.Add(-1)
return status.Errorf(codes.PermissionDenied, "peer is not registered")
}
s.syncSem.Add(-1)
return err
}
@@ -213,12 +242,14 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi
peer, netMap, postureChecks, err := s.accountManager.SyncAndMarkPeer(ctx, accountID, peerKey.String(), peerMeta, realIP)
if err != nil {
log.WithContext(ctx).Debugf("error while syncing peer %s: %v", peerKey.String(), err)
s.syncSem.Add(-1)
return mapError(ctx, err)
}
err = s.sendInitialSync(ctx, peerKey, peer, netMap, postureChecks, srv)
if err != nil {
log.WithContext(ctx).Debugf("error while sending initial sync for %s: %v", peerKey.String(), err)
s.syncSem.Add(-1)
return err
}
@@ -237,6 +268,8 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi
log.WithContext(ctx).Debugf("Sync: took %v", time.Since(reqStart))
s.syncSem.Add(-1)
return s.handleUpdates(ctx, accountID, peerKey, peer, updates, srv)
}

View File

@@ -26,9 +26,11 @@ type mockHTTPClient struct {
}
func (c *mockHTTPClient) Do(req *http.Request) (*http.Response, error) {
body, err := io.ReadAll(req.Body)
if err == nil {
c.reqBody = string(body)
if req.Body != nil {
body, err := io.ReadAll(req.Body)
if err == nil {
c.reqBody = string(body)
}
}
return &http.Response{
StatusCode: c.code,

View File

@@ -201,6 +201,12 @@ func NewManager(ctx context.Context, config Config, appMetrics telemetry.AppMetr
APIToken: config.ExtraConfig["ApiToken"],
}
return NewJumpCloudManager(jumpcloudConfig, appMetrics)
case "pocketid":
pocketidConfig := PocketIdClientConfig{
APIToken: config.ExtraConfig["ApiToken"],
ManagementEndpoint: config.ExtraConfig["ManagementEndpoint"],
}
return NewPocketIdManager(pocketidConfig, appMetrics)
default:
return nil, fmt.Errorf("invalid manager type: %s", config.ManagerType)
}

View File

@@ -0,0 +1,384 @@
package idp
import (
"context"
"fmt"
"io"
"net/http"
"net/url"
"slices"
"strings"
"time"
"github.com/netbirdio/netbird/management/server/telemetry"
)
// PocketIdManager is an IdP Manager implementation backed by the PocketID
// management API. Requests are authenticated with a static API token sent
// in the X-API-KEY header (see request).
type PocketIdManager struct {
	managementEndpoint string
	apiToken           string
	httpClient         ManagerHTTPClient
	credentials        ManagerCredentials
	helper             ManagerHelper
	appMetrics         telemetry.AppMetrics
}

// pocketIdCustomClaimDto is a single custom key/value claim attached to a
// PocketID user or group.
type pocketIdCustomClaimDto struct {
	Key   string `json:"key"`
	Value string `json:"value"`
}

// pocketIdUserDto mirrors the user object returned by the PocketID API.
type pocketIdUserDto struct {
	CustomClaims []pocketIdCustomClaimDto `json:"customClaims"`
	Disabled     bool                     `json:"disabled"`
	DisplayName  string                   `json:"displayName"`
	Email        string                   `json:"email"`
	FirstName    string                   `json:"firstName"`
	ID           string                   `json:"id"`
	IsAdmin      bool                     `json:"isAdmin"`
	LastName     string                   `json:"lastName"`
	LdapID       string                   `json:"ldapId"`
	Locale       string                   `json:"locale"`
	UserGroups   []pocketIdUserGroupDto   `json:"userGroups"`
	Username     string                   `json:"username"`
}

// pocketIdUserCreateDto is the request payload for creating a PocketID user.
type pocketIdUserCreateDto struct {
	Disabled    bool   `json:"disabled,omitempty"`
	DisplayName string `json:"displayName"`
	Email       string `json:"email"`
	FirstName   string `json:"firstName"`
	IsAdmin     bool   `json:"isAdmin,omitempty"`
	LastName    string `json:"lastName,omitempty"`
	Locale      string `json:"locale,omitempty"`
	Username    string `json:"username"`
}

// pocketIdPaginatedUserDto is the paginated envelope the PocketID /users
// endpoint wraps its results in.
type pocketIdPaginatedUserDto struct {
	Data       []pocketIdUserDto     `json:"data"`
	Pagination pocketIdPaginationDto `json:"pagination"`
}

// pocketIdPaginationDto carries the pagination metadata of a list response;
// TotalPages drives the loop in getAllUsersPaginated.
type pocketIdPaginationDto struct {
	CurrentPage  int `json:"currentPage"`
	ItemsPerPage int `json:"itemsPerPage"`
	TotalItems   int `json:"totalItems"`
	TotalPages   int `json:"totalPages"`
}

// userData converts the PocketID user representation into NetBird's generic
// UserData. AppMetadata is left empty; callers fill in account linkage.
func (p *pocketIdUserDto) userData() *UserData {
	return &UserData{
		Email:       p.Email,
		Name:        p.DisplayName,
		ID:          p.ID,
		AppMetadata: AppMetadata{},
	}
}

// pocketIdUserGroupDto mirrors the group object embedded in a PocketID user.
type pocketIdUserGroupDto struct {
	CreatedAt    string                   `json:"createdAt"`
	CustomClaims []pocketIdCustomClaimDto `json:"customClaims"`
	FriendlyName string                   `json:"friendlyName"`
	ID           string                   `json:"id"`
	LdapID       string                   `json:"ldapId"`
	Name         string                   `json:"name"`
}
// NewPocketIdManager builds a PocketIdManager from the given configuration.
// Both the management endpoint and the API token are mandatory; the returned
// manager shares one HTTP client (10s timeout, small idle pool) for all calls.
func NewPocketIdManager(config PocketIdClientConfig, appMetrics telemetry.AppMetrics) (*PocketIdManager, error) {
	if config.ManagementEndpoint == "" {
		return nil, fmt.Errorf("pocketId IdP configuration is incomplete, ManagementEndpoint is missing")
	}
	if config.APIToken == "" {
		return nil, fmt.Errorf("pocketId IdP configuration is incomplete, APIToken is missing")
	}

	transport := http.DefaultTransport.(*http.Transport).Clone()
	transport.MaxIdleConns = 5
	client := &http.Client{
		Timeout:   10 * time.Second,
		Transport: transport,
	}
	parser := JsonParser{}

	creds := &PocketIdCredentials{
		clientConfig: config,
		httpClient:   client,
		helper:       parser,
		appMetrics:   appMetrics,
	}

	return &PocketIdManager{
		managementEndpoint: config.ManagementEndpoint,
		apiToken:           config.APIToken,
		httpClient:         client,
		credentials:        creds,
		helper:             parser,
		appMetrics:         appMetrics,
	}, nil
}
// request performs an authenticated HTTP call against the PocketID management
// API and returns the raw response body. The resource path is appended to
// "<endpoint>/api/", optional query parameters are URL-encoded, and a
// non-empty body is only permitted for POST and PUT. Only 200 and 201 are
// treated as success; anything else is an error (and counted in metrics).
func (p *PocketIdManager) request(ctx context.Context, method, resource string, query *url.Values, body string) ([]byte, error) {
	methodsWithBody := []string{http.MethodPost, http.MethodPut}
	if !slices.Contains(methodsWithBody, method) && body != "" {
		return nil, fmt.Errorf("Body provided to unsupported method: %s", method)
	}

	reqURL := fmt.Sprintf("%s/api/%s", p.managementEndpoint, resource)
	if query != nil {
		reqURL = fmt.Sprintf("%s?%s", reqURL, query.Encode())
	}

	var bodyReader io.Reader
	if body != "" {
		bodyReader = strings.NewReader(body)
	}
	req, err := http.NewRequestWithContext(ctx, method, reqURL, bodyReader)
	if err != nil {
		return nil, err
	}

	req.Header.Add("X-API-KEY", p.apiToken)
	if body != "" {
		req.Header.Add("content-type", "application/json")
		// No manual Content-Length header: the net/http transport ignores a
		// user-set Content-Length in Request.Header and derives the length
		// from req.ContentLength, so setting it here was dead code.
	}

	resp, err := p.httpClient.Do(req)
	if err != nil {
		if p.appMetrics != nil {
			p.appMetrics.IDPMetrics().CountRequestError()
		}
		return nil, err
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusCreated {
		if p.appMetrics != nil {
			p.appMetrics.IDPMetrics().CountRequestStatusError()
		}
		return nil, fmt.Errorf("received unexpected status code from PocketID API: %d", resp.StatusCode)
	}

	return io.ReadAll(resp.Body)
}
// getAllUsersPaginated fetches all users from the PocketID API, walking the
// paginated /users endpoint page by page (100 items per page) until the
// reported last page has been consumed. Any extra search parameters are
// merged into every page request.
func (p *PocketIdManager) getAllUsersPaginated(ctx context.Context, searchParams url.Values) ([]pocketIdUserDto, error) {
	var users []pocketIdUserDto

	for page := 1; ; page++ {
		// Rebuild the query each iteration so the caller's params are kept
		// intact and only the pagination keys change.
		query := url.Values{}
		for key, values := range searchParams {
			query[key] = values
		}
		query.Set("pagination[limit]", "100")
		query.Set("pagination[page]", fmt.Sprintf("%d", page))

		body, err := p.request(ctx, http.MethodGet, "users", &query, "")
		if err != nil {
			return nil, err
		}

		var result pocketIdPaginatedUserDto
		if err := p.helper.Unmarshal(body, &result); err != nil {
			return nil, err
		}
		users = append(users, result.Data...)

		if page >= result.Pagination.TotalPages {
			return users, nil
		}
	}
}
// UpdateUserAppMetadata is intentionally a no-op for PocketID: nothing is
// written back to the IdP here (presumably NetBird keeps this state on its
// own side — confirm against the Manager interface's other implementations).
func (p *PocketIdManager) UpdateUserAppMetadata(_ context.Context, _ string, _ AppMetadata) error {
	return nil
}
// GetUserDataByID fetches a single user from PocketID by ID and attaches the
// caller-supplied app metadata to the returned UserData.
func (p *PocketIdManager) GetUserDataByID(ctx context.Context, userId string, appMetadata AppMetadata) (*UserData, error) {
	respBody, err := p.request(ctx, http.MethodGet, "users/"+userId, nil, "")
	if err != nil {
		return nil, err
	}
	if p.appMetrics != nil {
		p.appMetrics.IDPMetrics().CountGetUserDataByID()
	}

	var user pocketIdUserDto
	if err := p.helper.Unmarshal(respBody, &user); err != nil {
		return nil, err
	}

	data := user.userData()
	data.AppMetadata = appMetadata
	return data, nil
}
// GetAccount returns every PocketID user, each tagged with the given NetBird
// account ID in its app metadata.
func (p *PocketIdManager) GetAccount(ctx context.Context, accountId string) ([]*UserData, error) {
	profiles, err := p.getAllUsersPaginated(ctx, url.Values{})
	if err != nil {
		return nil, err
	}
	if p.appMetrics != nil {
		p.appMetrics.IDPMetrics().CountGetAccount()
	}

	users := make([]*UserData, 0, len(profiles))
	for i := range profiles {
		data := profiles[i].userData()
		data.AppMetadata.WTAccountID = accountId
		users = append(users, data)
	}
	return users, nil
}
// GetAllAccounts lists every PocketID user, grouped under the single
// UnsetAccountID key (no per-account partitioning is performed here).
func (p *PocketIdManager) GetAllAccounts(ctx context.Context) (map[string][]*UserData, error) {
	profiles, err := p.getAllUsersPaginated(ctx, url.Values{})
	if err != nil {
		return nil, err
	}
	if p.appMetrics != nil {
		p.appMetrics.IDPMetrics().CountGetAllAccounts()
	}

	grouped := make(map[string][]*UserData)
	for i := range profiles {
		grouped[UnsetAccountID] = append(grouped[UnsetAccountID], profiles[i].userData())
	}
	return grouped, nil
}
// CreateUser provisions a new PocketID user and returns it as UserData with
// a pending-invite flag set in the app metadata.
//
// The display name is split into first/last name for the PocketID API.
// Previously `strings.Split(name, " ")[1]` panicked with an index
// out-of-range when the name contained no space; a single-word name now
// yields an empty last name and a username equal to the first name.
func (p *PocketIdManager) CreateUser(ctx context.Context, email, name, accountID, invitedByEmail string) (*UserData, error) {
	parts := strings.Split(name, " ")
	firstName := parts[0]
	lastName := ""
	if len(parts) > 1 {
		lastName = parts[1]
	}
	username := firstName
	if lastName != "" {
		username = firstName + "." + lastName
	}

	createUser := pocketIdUserCreateDto{
		Disabled:    false,
		DisplayName: name,
		Email:       email,
		FirstName:   firstName,
		LastName:    lastName,
		Username:    username,
	}
	payload, err := p.helper.Marshal(createUser)
	if err != nil {
		return nil, err
	}

	body, err := p.request(ctx, http.MethodPost, "users", nil, string(payload))
	if err != nil {
		return nil, err
	}

	var newUser pocketIdUserDto
	if err := p.helper.Unmarshal(body, &newUser); err != nil {
		return nil, err
	}

	if p.appMetrics != nil {
		p.appMetrics.IDPMetrics().CountCreateUser()
	}

	pending := true
	return &UserData{
		Email: email,
		Name:  name,
		ID:    newUser.ID,
		AppMetadata: AppMetadata{
			WTAccountID:     accountID,
			WTPendingInvite: &pending,
			WTInvitedBy:     invitedByEmail,
		},
	}, nil
}
// GetUserByEmail looks up users whose profile matches the given email.
// PocketID's /users endpoint filters with a free-text "search" parameter;
// the email address is passed as the search term (NOTE(review): the search
// may also match other fields such as username — confirm against PocketID).
func (p *PocketIdManager) GetUserByEmail(ctx context.Context, email string) ([]*UserData, error) {
	params := url.Values{
		"search": []string{email},
	}
	body, err := p.request(ctx, http.MethodGet, "users", &params, "")
	if err != nil {
		return nil, err
	}
	if p.appMetrics != nil {
		p.appMetrics.IDPMetrics().CountGetUserByEmail()
	}

	// The endpoint returns the same paginated envelope as the list calls.
	// The previous anonymous struct used an unexported "data" field, which
	// encoding/json silently ignores, so this function always returned an
	// empty slice; decode into the shared exported DTO instead.
	var profiles pocketIdPaginatedUserDto
	if err := p.helper.Unmarshal(body, &profiles); err != nil {
		return nil, err
	}

	users := make([]*UserData, 0, len(profiles.Data))
	for i := range profiles.Data {
		users = append(users, profiles.Data[i].userData())
	}
	return users, nil
}
// InviteUserByID asks PocketID to send a one-time access email to the user,
// which serves as the invitation mechanism.
func (p *PocketIdManager) InviteUserByID(ctx context.Context, userID string) error {
	_, err := p.request(ctx, http.MethodPut, "users/"+userID+"/one-time-access-email", nil, "")
	return err
}
// DeleteUser removes the user with the given ID from PocketID and records
// the operation in the IdP metrics.
func (p *PocketIdManager) DeleteUser(ctx context.Context, userID string) error {
	if _, err := p.request(ctx, http.MethodDelete, "users/"+userID, nil, ""); err != nil {
		return err
	}
	if p.appMetrics != nil {
		p.appMetrics.IDPMetrics().CountDeleteUser()
	}
	return nil
}
// Compile-time assertion that PocketIdManager satisfies the Manager interface.
var _ Manager = (*PocketIdManager)(nil)

// PocketIdClientConfig holds the settings required to talk to a PocketID
// instance: the base URL of the management API and a static API token.
type PocketIdClientConfig struct {
	APIToken           string
	ManagementEndpoint string
}

// PocketIdCredentials implements ManagerCredentials for PocketID. The API is
// authenticated with a static token on every request (see request), so there
// is no token exchange to perform here.
type PocketIdCredentials struct {
	clientConfig PocketIdClientConfig
	helper       ManagerHelper
	httpClient   ManagerHTTPClient
	appMetrics   telemetry.AppMetrics
}

// Compile-time assertion that PocketIdCredentials satisfies ManagerCredentials.
var _ ManagerCredentials = (*PocketIdCredentials)(nil)

// Authenticate is a no-op: PocketID uses the static X-API-KEY header instead
// of an OAuth-style token exchange, so an empty JWTToken is returned.
func (p PocketIdCredentials) Authenticate(_ context.Context) (JWTToken, error) {
	return JWTToken{}, nil
}

View File

@@ -0,0 +1,138 @@
package idp
import (
"context"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/management/server/telemetry"
)
// TestNewPocketIdManager verifies the constructor's configuration checks:
// a complete config succeeds while a missing endpoint or token is rejected.
func TestNewPocketIdManager(t *testing.T) {
	cases := []struct {
		name         string
		config       PocketIdClientConfig
		assertErr    require.ErrorAssertionFunc
		assertErrMsg string
	}{
		{
			name: "Good Configuration",
			config: PocketIdClientConfig{
				APIToken:           "api_token",
				ManagementEndpoint: "http://localhost",
			},
			assertErr:    require.NoError,
			assertErrMsg: "shouldn't return error",
		},
		{
			name: "Missing ManagementEndpoint",
			config: PocketIdClientConfig{
				APIToken:           "api_token",
				ManagementEndpoint: "",
			},
			assertErr:    require.Error,
			assertErrMsg: "should return error when field empty",
		},
		{
			name: "Missing APIToken",
			config: PocketIdClientConfig{
				APIToken:           "",
				ManagementEndpoint: "http://localhost",
			},
			assertErr:    require.Error,
			assertErrMsg: "should return error when field empty",
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			_, err := NewPocketIdManager(tc.config, &telemetry.MockAppMetrics{})
			tc.assertErr(t, err, tc.assertErrMsg)
		})
	}
}
func TestPocketID_GetUserDataByID(t *testing.T) {
client := &mockHTTPClient{code: 200, resBody: `{"id":"u1","email":"user1@example.com","displayName":"User One"}`}
mgr, err := NewPocketIdManager(PocketIdClientConfig{APIToken: "tok", ManagementEndpoint: "http://localhost"}, nil)
require.NoError(t, err)
mgr.httpClient = client
md := AppMetadata{WTAccountID: "acc1"}
got, err := mgr.GetUserDataByID(context.Background(), "u1", md)
require.NoError(t, err)
assert.Equal(t, "u1", got.ID)
assert.Equal(t, "user1@example.com", got.Email)
assert.Equal(t, "User One", got.Name)
assert.Equal(t, "acc1", got.AppMetadata.WTAccountID)
}
// TestPocketID_GetAccount_WithPagination verifies that GetAccount consumes a
// paginated users response (one page, two users here) and stamps every
// returned user with the requested account ID.
func TestPocketID_GetAccount_WithPagination(t *testing.T) {
// Single page response with two users
client := &mockHTTPClient{code: 200, resBody: `{"data":[{"id":"u1","email":"e1","displayName":"n1"},{"id":"u2","email":"e2","displayName":"n2"}],"pagination":{"currentPage":1,"itemsPerPage":100,"totalItems":2,"totalPages":1}}`}
mgr, err := NewPocketIdManager(PocketIdClientConfig{APIToken: "tok", ManagementEndpoint: "http://localhost"}, nil)
require.NoError(t, err)
mgr.httpClient = client
users, err := mgr.GetAccount(context.Background(), "accX")
require.NoError(t, err)
require.Len(t, users, 2)
// Every user carries the account ID passed to GetAccount.
assert.Equal(t, "u1", users[0].ID)
assert.Equal(t, "accX", users[0].AppMetadata.WTAccountID)
assert.Equal(t, "u2", users[1].ID)
}
// TestPocketID_GetAllAccounts_WithPagination verifies that GetAllAccounts
// returns all users from a paginated response grouped under UnsetAccountID
// (PocketID has no account concept of its own).
func TestPocketID_GetAllAccounts_WithPagination(t *testing.T) {
client := &mockHTTPClient{code: 200, resBody: `{"data":[{"id":"u1","email":"e1","displayName":"n1"},{"id":"u2","email":"e2","displayName":"n2"}],"pagination":{"currentPage":1,"itemsPerPage":100,"totalItems":2,"totalPages":1}}`}
mgr, err := NewPocketIdManager(PocketIdClientConfig{APIToken: "tok", ManagementEndpoint: "http://localhost"}, nil)
require.NoError(t, err)
mgr.httpClient = client
accounts, err := mgr.GetAllAccounts(context.Background())
require.NoError(t, err)
require.Len(t, accounts[UnsetAccountID], 2)
}
// TestPocketID_CreateUser verifies that CreateUser maps the 201 response
// onto UserData and populates the NetBird app metadata: account ID,
// pending-invite flag set to true, and the inviter's email.
func TestPocketID_CreateUser(t *testing.T) {
client := &mockHTTPClient{code: 201, resBody: `{"id":"newid","email":"new@example.com","displayName":"New User"}`}
mgr, err := NewPocketIdManager(PocketIdClientConfig{APIToken: "tok", ManagementEndpoint: "http://localhost"}, nil)
require.NoError(t, err)
mgr.httpClient = client
ud, err := mgr.CreateUser(context.Background(), "new@example.com", "New User", "acc1", "inviter@example.com")
require.NoError(t, err)
assert.Equal(t, "newid", ud.ID)
assert.Equal(t, "new@example.com", ud.Email)
assert.Equal(t, "New User", ud.Name)
assert.Equal(t, "acc1", ud.AppMetadata.WTAccountID)
// Newly created users must be marked as having a pending invite.
if assert.NotNil(t, ud.AppMetadata.WTPendingInvite) {
assert.True(t, *ud.AppMetadata.WTPendingInvite)
}
assert.Equal(t, "inviter@example.com", ud.AppMetadata.WTInvitedBy)
}
func TestPocketID_InviteAndDeleteUser(t *testing.T) {
// Same mock for both calls; returns OK with empty JSON
client := &mockHTTPClient{code: 200, resBody: `{}`}
mgr, err := NewPocketIdManager(PocketIdClientConfig{APIToken: "tok", ManagementEndpoint: "http://localhost"}, nil)
require.NoError(t, err)
mgr.httpClient = client
err = mgr.InviteUserByID(context.Background(), "u1")
require.NoError(t, err)
err = mgr.DeleteUser(context.Background(), "u1")
require.NoError(t, err)
}

View File

@@ -136,7 +136,7 @@ func (a MockIntegratedValidator) GetValidatedPeers(_ context.Context, accountID
return validatedPeers, nil
}
func (MockIntegratedValidator) PreparePeer(_ context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *types.ExtraSettings) *nbpeer.Peer {
// PreparePeer is a no-op in the mock validator: it returns the peer
// unchanged regardless of account, groups, extra settings, or the
// temporary flag.
func (MockIntegratedValidator) PreparePeer(_ context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *types.ExtraSettings, temporary bool) *nbpeer.Peer {
return peer
}

View File

@@ -3,16 +3,16 @@ package integrated_validator
import (
"context"
"github.com/netbirdio/netbird/shared/management/proto"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/management/proto"
)
// IntegratedValidator interface exists to avoid the circle dependencies
type IntegratedValidator interface {
ValidateExtraSettings(ctx context.Context, newExtraSettings *types.ExtraSettings, oldExtraSettings *types.ExtraSettings, peers map[string]*nbpeer.Peer, userID string, accountID string) error
ValidatePeer(ctx context.Context, update *nbpeer.Peer, peer *nbpeer.Peer, userID string, accountID string, dnsDomain string, peersGroup []string, extraSettings *types.ExtraSettings) (*nbpeer.Peer, bool, error)
PreparePeer(ctx context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *types.ExtraSettings) *nbpeer.Peer
PreparePeer(ctx context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *types.ExtraSettings, temporary bool) *nbpeer.Peer
IsNotValidPeer(ctx context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *types.ExtraSettings) (bool, bool, error)
GetValidatedPeers(ctx context.Context, accountID string, groups []*types.Group, peers []*nbpeer.Peer, extraSettings *types.ExtraSettings) (map[string]struct{}, error)
PeerDeleted(ctx context.Context, accountID, peerID string, extraSettings *types.ExtraSettings) error

View File

@@ -578,7 +578,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, accountID, setupKe
}
}
newPeer = am.integratedPeerValidator.PreparePeer(ctx, accountID, newPeer, groupsToAdd, settings.Extra)
newPeer = am.integratedPeerValidator.PreparePeer(ctx, accountID, newPeer, groupsToAdd, settings.Extra, temporary)
network, err := am.Store.GetAccountNetwork(ctx, store.LockingStrengthNone, accountID)
if err != nil {

View File

@@ -1,381 +0,0 @@
# Signal Server Load Test
Load testing tool for the NetBird signal server.
## Features
- **Rate-based peer pair creation**: Spawn peer pairs at configurable rates (e.g., 10, 20 pairs/sec)
- **Two exchange modes**:
- **Single message**: Each pair exchanges one message for validation
- **Continuous exchange**: Pairs continuously exchange messages for a specified duration (e.g., 30 seconds, 10 minutes)
- **TLS/HTTPS support**: Connect to TLS-enabled signal servers with optional certificate verification
- **Automatic reconnection**: Optional automatic reconnection with exponential backoff on connection loss
- **Configurable message interval**: Control message send rate in continuous mode
- **Message exchange validation**: Validates encrypted body size > 0
- **Comprehensive metrics**: Tracks throughput, success/failure rates, latency statistics, and reconnection counts
- **Local server testing**: Tests include embedded signal server for easy development
- **Worker pool pattern**: Efficient concurrent execution
- **Graceful shutdown**: Context-based cancellation
## Usage
### Standalone Binary
Build and run the load test as a standalone binary:
```bash
# Build the binary
cd signal/loadtest/cmd/signal-loadtest
go build -o signal-loadtest
# Single message exchange
./signal-loadtest \
-server http://localhost:10000 \
-pairs-per-sec 10 \
-total-pairs 100 \
-message-size 100
# Continuous exchange for 30 seconds
./signal-loadtest \
-server http://localhost:10000 \
-pairs-per-sec 10 \
-total-pairs 20 \
-message-size 200 \
-exchange-duration 30s \
-message-interval 200ms
# Long-running test (10 minutes)
./signal-loadtest \
-server http://localhost:10000 \
-pairs-per-sec 20 \
-total-pairs 50 \
-message-size 500 \
-exchange-duration 10m \
-message-interval 100ms \
-test-duration 15m \
-log-level debug
# TLS server with valid certificate
./signal-loadtest \
-server https://signal.example.com:443 \
-pairs-per-sec 10 \
-total-pairs 50 \
-message-size 100
# TLS server with self-signed certificate
./signal-loadtest \
-server https://localhost:443 \
-pairs-per-sec 5 \
-total-pairs 10 \
-insecure-skip-verify \
-log-level debug
# High load test with custom worker pool
./signal-loadtest \
-server http://localhost:10000 \
-pairs-per-sec 100 \
-total-pairs 1000 \
-worker-pool-size 500 \
-channel-buffer-size 1000 \
-exchange-duration 60s \
-log-level info
# Progress reporting - report every 5000 messages
./signal-loadtest \
-server http://localhost:10000 \
-pairs-per-sec 50 \
-total-pairs 100 \
-exchange-duration 5m \
-report-interval 5000 \
-log-level info
# With automatic reconnection
./signal-loadtest \
-server http://localhost:10000 \
-pairs-per-sec 10 \
-total-pairs 50 \
-exchange-duration 5m \
-enable-reconnect \
-initial-retry-delay 100ms \
-max-reconnect-delay 30s \
-log-level debug
# Show help
./signal-loadtest -h
```
**Graceful Shutdown:**
The load test supports graceful shutdown via Ctrl+C (SIGINT/SIGTERM):
- Press Ctrl+C to interrupt the test at any time
- All active clients will be closed gracefully
- A final aggregated report will be printed showing metrics up to the point of interruption
- Shutdown timeout: 5 seconds (after which the process will force exit)
**Available Flags:**
- `-server`: Signal server URL (http:// or https://) (default: `http://localhost:10000`)
- `-pairs-per-sec`: Peer pairs created per second (default: 10)
- `-total-pairs`: Total number of peer pairs (default: 100)
- `-message-size`: Message size in bytes (default: 100)
- `-test-duration`: Maximum test duration, 0 = unlimited (default: 0)
- `-exchange-duration`: Continuous exchange duration per pair, 0 = single message (default: 0)
- `-message-interval`: Interval between messages in continuous mode (default: 100ms)
- `-worker-pool-size`: Number of concurrent workers, 0 = auto (pairs-per-sec × 2) (default: 0)
- `-channel-buffer-size`: Work queue buffer size, 0 = auto (pairs-per-sec × 4) (default: 0)
- `-report-interval`: Report progress every N messages, 0 = no periodic reports (default: 10000)
- `-enable-reconnect`: Enable automatic reconnection on connection loss (default: true)
- `-initial-retry-delay`: Initial delay before first reconnection attempt (default: 100ms)
- `-max-reconnect-delay`: Maximum delay between reconnection attempts (default: 30s)
- `-insecure-skip-verify`: Skip TLS certificate verification for self-signed certificates (default: false)
- `-log-level`: Log level: trace, debug, info, warn, error (default: info)
### Running Tests
```bash
# Run all tests (includes load tests)
go test -v -timeout 2m
# Run specific single-message load tests
go test -v -run TestLoadTest_10PairsPerSecond -timeout 40s
go test -v -run TestLoadTest_20PairsPerSecond -timeout 40s
go test -v -run TestLoadTest_SmallBurst -timeout 30s
# Run continuous exchange tests
go test -v -run TestLoadTest_ContinuousExchange_ShortBurst -timeout 30s
go test -v -run TestLoadTest_ContinuousExchange_30Seconds -timeout 2m
go test -v -run TestLoadTest_ContinuousExchange_10Minutes -timeout 15m
# Skip long-running tests in quick runs
go test -short
```
### Programmatic Usage
#### Single Message Exchange
```go
package main
import (
"github.com/netbirdio/netbird/signal/loadtest"
"time"
)
func main() {
config := loadtest.LoadTestConfig{
ServerURL: "http://localhost:10000",
PairsPerSecond: 10,
TotalPairs: 100,
MessageSize: 100,
TestDuration: 30 * time.Second,
}
lt := loadtest.NewLoadTest(config)
if err := lt.Run(); err != nil {
panic(err)
}
metrics := lt.GetMetrics()
metrics.PrintReport()
}
```
#### Continuous Message Exchange
```go
package main
import (
"github.com/netbirdio/netbird/signal/loadtest"
"time"
)
func main() {
config := loadtest.LoadTestConfig{
ServerURL: "http://localhost:10000",
PairsPerSecond: 10,
TotalPairs: 20,
MessageSize: 200,
ExchangeDuration: 10 * time.Minute, // Each pair exchanges messages for 10 minutes
MessageInterval: 200 * time.Millisecond, // Send message every 200ms
TestDuration: 15 * time.Minute, // Overall test timeout
}
lt := loadtest.NewLoadTest(config)
if err := lt.Run(); err != nil {
panic(err)
}
metrics := lt.GetMetrics()
metrics.PrintReport()
}
```
## Configuration Options
- **ServerURL**: Signal server URL (e.g., `http://localhost:10000` or `https://signal.example.com:443`)
- **PairsPerSecond**: Rate at which peer pairs are created (e.g., 10, 20)
- **TotalPairs**: Total number of peer pairs to create
- **MessageSize**: Size of test message payload in bytes
- **TestDuration**: Maximum test duration (optional, 0 = no limit)
- **ExchangeDuration**: Duration for continuous message exchange per pair (0 = single message)
- **MessageInterval**: Interval between messages in continuous mode (default: 100ms)
- **WorkerPoolSize**: Number of concurrent worker goroutines (0 = auto: pairs-per-sec × 2)
- **ChannelBufferSize**: Work queue buffer size (0 = auto: pairs-per-sec × 4)
- **ReportInterval**: Report progress every N messages (0 = no periodic reports, default: 10000)
- **EnableReconnect**: Enable automatic reconnection on connection loss (default: false)
- **InitialRetryDelay**: Initial delay before first reconnection attempt (default: 100ms)
- **MaxReconnectDelay**: Maximum delay between reconnection attempts (default: 30s)
- **InsecureSkipVerify**: Skip TLS certificate verification (for self-signed certificates)
- **RampUpDuration**: Gradual ramp-up period (not yet implemented)
### Reconnection Handling
The load test supports automatic reconnection on connection loss:
- **Disabled by default**: Connections will fail on any network interruption
- **When enabled**: Clients automatically reconnect with exponential backoff
- **Exponential backoff**: Starts at `InitialRetryDelay`, doubles on each failure, caps at `MaxReconnectDelay`
- **Transparent reconnection**: Message exchange continues after successful reconnection
- **Metrics tracking**: Total reconnection count is reported
**Use cases:**
- Testing resilience to network interruptions
- Validating server restart behavior
- Simulating flaky network conditions
- Long-running stability tests
**Example with reconnection:**
```go
config := loadtest.LoadTestConfig{
ServerURL: "http://localhost:10000",
PairsPerSecond: 10,
TotalPairs: 20,
ExchangeDuration: 10 * time.Minute,
EnableReconnect: true,
InitialRetryDelay: 100 * time.Millisecond,
MaxReconnectDelay: 30 * time.Second,
}
```
### Performance Tuning
When running high-load tests, you may need to adjust the worker pool and buffer sizes:
- **Default sizing**: Auto-configured based on `PairsPerSecond`
- Worker pool: `PairsPerSecond × 2`
- Channel buffer: `PairsPerSecond × 4`
- **For continuous exchange**: Increase worker pool size (e.g., `PairsPerSecond × 5`)
- **For high pair rates** (>50/sec): Increase both worker pool and buffer proportionally
- **Signs you need more workers**: Log warnings about "Worker pool saturated"
Example for 100 pairs/sec with continuous exchange:
```go
config := LoadTestConfig{
PairsPerSecond: 100,
WorkerPoolSize: 500, // 5x pairs/sec
ChannelBufferSize: 1000, // 10x pairs/sec
}
```
## Metrics
The load test collects and reports:
- **Total Pairs Sent**: Number of peer pairs attempted
- **Successful Exchanges**: Completed message exchanges
- **Failed Exchanges**: Failed message exchanges
- **Total Messages Exchanged**: Count of successfully exchanged messages
- **Total Errors**: Cumulative error count
- **Total Reconnections**: Number of automatic reconnections (if enabled)
- **Throughput**: Pairs per second (actual)
- **Latency Statistics**: Min, Max, Avg message exchange latency
## Graceful Shutdown Example
You can interrupt a long-running test at any time with Ctrl+C:
```
./signal-loadtest -server http://localhost:10000 -pairs-per-sec 10 -total-pairs 100 -exchange-duration 10m
# Press Ctrl+C after some time...
^C
WARN[0045]
Received interrupt signal, shutting down gracefully...
=== Load Test Report ===
Test Duration: 45.234s
Total Pairs Sent: 75
Successful Exchanges: 75
Failed Exchanges: 0
Total Messages Exchanged: 22500
Total Errors: 0
Throughput: 1.66 pairs/sec
...
========================
```
## Test Results
Example output from a 20 pairs/sec test:
```
=== Load Test Report ===
Test Duration: 5.055249917s
Total Pairs Sent: 100
Successful Exchanges: 100
Failed Exchanges: 0
Total Messages Exchanged: 100
Total Errors: 0
Throughput: 19.78 pairs/sec
Latency Statistics:
Min: 170.375µs
Max: 5.176916ms
Avg: 441.566µs
========================
```
## Architecture
### Client (`client.go`)
- Manages gRPC connection to signal server
- Establishes bidirectional stream for receiving messages
- Sends messages via `Send` RPC method
- Handles message reception asynchronously
### Load Test Engine (`rate_loadtest.go`)
- Worker pool pattern for concurrent peer pairs
- Rate-limited pair creation using ticker
- Atomic counters for thread-safe metrics collection
- Graceful shutdown on context cancellation
### Test Suite
- `loadtest_test.go`: Single pair validation test
- `rate_loadtest_test.go`: Multiple rate-based load tests and benchmarks
## Implementation Details
### Message Flow
1. Create sender and receiver clients with unique IDs
2. Both clients connect to signal server via bidirectional stream
3. Sender sends encrypted message using `Send` RPC
4. Signal server forwards message to receiver's stream
5. Receiver reads message from stream
6. Validate encrypted body size > 0
7. Record latency and success metrics
### Concurrency
- Worker pool size defaults to `PairsPerSecond × 2` (see Performance Tuning)
- Each worker handles multiple peer pairs sequentially
- Atomic operations for metrics to avoid lock contention
- Channel-based work distribution
## Future Enhancements
- [x] TLS/HTTPS support for production servers
- [x] Automatic reconnection with exponential backoff
- [ ] Ramp-up period implementation
- [ ] Percentile latency metrics (p50, p95, p99)
- [ ] Connection reuse for multiple messages per pair
- [ ] Support for custom message payloads
- [ ] CSV/JSON metrics export
- [ ] Real-time metrics dashboard

View File

@@ -1,301 +0,0 @@
package loadtest
import (
"context"
"crypto/tls"
"fmt"
"strings"
"sync"
"time"
log "github.com/sirupsen/logrus"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/metadata"
"github.com/netbirdio/netbird/shared/signal/proto"
)
// Client represents a signal client for load testing. It wraps a gRPC
// connection to the signal server, a bidirectional receive stream, and a
// channel that buffers incoming messages for ReceiveMessage.
type Client struct {
id string
serverURL string
config *ClientConfig
conn *grpc.ClientConn
client proto.SignalExchangeClient
// stream is the current receive stream; replaced on reconnect under mu.
stream proto.SignalExchange_ConnectStreamClient
ctx context.Context
cancel context.CancelFunc
// msgChannel delivers received messages to ReceiveMessage (buffer 10).
msgChannel chan *proto.EncryptedMessage
// mu guards stream, connected, reconnectCount, and receiverStarted.
mu sync.RWMutex
reconnectCount int64
connected bool
// receiverStarted ensures only one receiveMessages goroutine is spawned.
receiverStarted bool
}
// ClientConfig holds optional configuration for the client: TLS
// verification bypass and automatic-reconnect backoff parameters.
type ClientConfig struct {
InsecureSkipVerify bool
EnableReconnect bool
MaxReconnectDelay time.Duration
InitialRetryDelay time.Duration
}
// NewClient creates a new signal client for load testing with the default
// configuration (TLS verification on, reconnection disabled).
func NewClient(serverURL, peerID string) (*Client, error) {
return NewClientWithConfig(serverURL, peerID, nil)
}
// NewClientWithConfig creates a new signal client with custom TLS and
// reconnection configuration. A nil config is treated as all defaults; when
// reconnection is enabled, unset backoff delays receive defaults of 100ms
// (initial) and 30s (maximum).
func NewClientWithConfig(serverURL, peerID string, config *ClientConfig) (*Client, error) {
	if config == nil {
		config = &ClientConfig{}
	}

	// Fill in backoff defaults only when reconnection is requested.
	if config.EnableReconnect {
		if config.InitialRetryDelay == 0 {
			config.InitialRetryDelay = 100 * time.Millisecond
		}
		if config.MaxReconnectDelay == 0 {
			config.MaxReconnectDelay = 30 * time.Second
		}
	}

	target, dialOpts, err := parseServerURL(serverURL, config.InsecureSkipVerify)
	if err != nil {
		return nil, fmt.Errorf("parse server URL: %w", err)
	}

	conn, err := grpc.Dial(target, dialOpts...)
	if err != nil {
		return nil, fmt.Errorf("dial server: %w", err)
	}

	ctx, cancel := context.WithCancel(context.Background())
	c := &Client{
		id:         peerID,
		serverURL:  serverURL,
		config:     config,
		conn:       conn,
		client:     proto.NewSignalExchangeClient(conn),
		ctx:        ctx,
		cancel:     cancel,
		msgChannel: make(chan *proto.EncryptedMessage, 10),
		connected:  false,
	}
	return c, nil
}
// Connect establishes a stream connection to the signal server. The peer ID
// is sent via gRPC metadata; waiting on the stream header confirms the
// server has registered the peer. The first successful Connect starts the
// single receiveMessages goroutine; later calls only swap in a new stream.
func (c *Client) Connect() error {
md := metadata.New(map[string]string{proto.HeaderId: c.id})
ctx := metadata.NewOutgoingContext(c.ctx, md)
stream, err := c.client.ConnectStream(ctx)
if err != nil {
return fmt.Errorf("connect stream: %w", err)
}
// Header receipt signals the server accepted the registration.
if _, err := stream.Header(); err != nil {
return fmt.Errorf("receive header: %w", err)
}
c.mu.Lock()
c.stream = stream
c.connected = true
// Unlock before spawning the goroutine; receiverStarted guarantees it
// is spawned at most once across reconnects.
if !c.receiverStarted {
c.receiverStarted = true
c.mu.Unlock()
go c.receiveMessages()
} else {
c.mu.Unlock()
}
return nil
}
// reconnectStream re-establishes the server stream with exponential backoff,
// without starting a new receiver goroutine (the existing receiveMessages
// loop keeps running and picks up the refreshed stream).
//
// It returns nil once a stream is connected and its header is received, the
// context error if the client shuts down while waiting, or an immediate
// error when reconnection is disabled in the config.
func (c *Client) reconnectStream() error {
	if !c.config.EnableReconnect {
		return fmt.Errorf("reconnect disabled")
	}

	delay := c.config.InitialRetryDelay
	// backoff doubles the wait after a failed attempt, capped at
	// MaxReconnectDelay. Extracted to avoid duplicating the cap logic on
	// both failure paths below.
	backoff := func() {
		delay *= 2
		if delay > c.config.MaxReconnectDelay {
			delay = c.config.MaxReconnectDelay
		}
	}

	attempt := 0
	for {
		select {
		case <-c.ctx.Done():
			return c.ctx.Err()
		case <-time.After(delay):
			attempt++
			log.Debugf("Client %s reconnect attempt %d (delay: %v)", c.id, attempt, delay)
			md := metadata.New(map[string]string{proto.HeaderId: c.id})
			ctx := metadata.NewOutgoingContext(c.ctx, md)
			stream, err := c.client.ConnectStream(ctx)
			if err != nil {
				log.Debugf("Client %s reconnect attempt %d failed: %v", c.id, attempt, err)
				backoff()
				continue
			}
			if _, err := stream.Header(); err != nil {
				log.Debugf("Client %s reconnect header failed: %v", c.id, err)
				backoff()
				continue
			}
			c.mu.Lock()
			c.stream = stream
			c.connected = true
			c.reconnectCount++
			// Capture the count under the lock so the log below does not
			// read the field unsynchronized.
			count := c.reconnectCount
			c.mu.Unlock()
			log.Debugf("Client %s reconnected successfully (attempt %d, total reconnects: %d)",
				c.id, attempt, count)
			return nil
		}
	}
}
// SendMessage sends an encrypted message to a remote peer using the Send
// RPC, with the client's own ID as the sender key. The call is bounded by a
// 10-second timeout derived from the client context.
func (c *Client) SendMessage(remotePeerID string, body []byte) error {
	ctx, cancel := context.WithTimeout(c.ctx, 10*time.Second)
	defer cancel()

	msg := &proto.EncryptedMessage{
		Key:       c.id,
		RemoteKey: remotePeerID,
		Body:      body,
	}
	if _, err := c.client.Send(ctx, msg); err != nil {
		return fmt.Errorf("send message: %w", err)
	}
	return nil
}
// ReceiveMessage waits for and returns the next message delivered by the
// receiver goroutine, or the context error if the client is closed first.
func (c *Client) ReceiveMessage() (*proto.EncryptedMessage, error) {
select {
case msg := <-c.msgChannel:
return msg, nil
case <-c.ctx.Done():
return nil, c.ctx.Err()
}
}
// Close cancels the client context (stopping the receiver goroutine and any
// in-flight calls) and closes the underlying gRPC connection.
func (c *Client) Close() error {
	c.cancel()
	if c.conn == nil {
		return nil
	}
	return c.conn.Close()
}
// receiveMessages is the single receiver loop: it reads from the current
// stream and forwards each message to msgChannel. On a receive error it
// either reconnects (when enabled) and continues with the refreshed stream,
// or exits. Started exactly once by Connect.
func (c *Client) receiveMessages() {
for {
// Re-read the stream each iteration; reconnectStream may replace it.
c.mu.RLock()
stream := c.stream
c.mu.RUnlock()
if stream == nil {
return
}
msg, err := stream.Recv()
if err != nil {
// Check if context is cancelled before attempting reconnection
select {
case <-c.ctx.Done():
return
default:
}
c.mu.Lock()
c.connected = false
c.mu.Unlock()
log.Debugf("Client %s receive error: %v", c.id, err)
// Attempt reconnection if enabled
if c.config.EnableReconnect {
if reconnectErr := c.reconnectStream(); reconnectErr != nil {
log.Debugf("Client %s reconnection failed: %v", c.id, reconnectErr)
return
}
// Successfully reconnected, continue receiving
continue
}
// Reconnect disabled, exit
return
}
// Hand the message to ReceiveMessage; bail out if the client closes
// while the channel is full.
select {
case c.msgChannel <- msg:
case <-c.ctx.Done():
return
}
}
}
// IsConnected reports whether the client currently holds a live stream.
func (c *Client) IsConnected() bool {
	c.mu.RLock()
	connected := c.connected
	c.mu.RUnlock()
	return connected
}
// GetReconnectCount returns how many times this client has successfully
// reconnected its stream.
func (c *Client) GetReconnectCount() int64 {
	c.mu.RLock()
	count := c.reconnectCount
	c.mu.RUnlock()
	return count
}
// parseServerURL turns a signal server URL into a gRPC dial target plus the
// matching transport-credential options.
//
// https:// URLs get TLS (minimum version 1.2) with an optional
// certificate-verification bypass; http:// URLs and bare host:port strings
// use plaintext credentials. The resulting address must contain a port.
func parseServerURL(serverURL string, insecureSkipVerify bool) (string, []grpc.DialOption, error) {
	serverURL = strings.TrimSpace(serverURL)
	if serverURL == "" {
		return "", nil, fmt.Errorf("server URL is empty")
	}

	var addr string
	var opts []grpc.DialOption
	switch {
	case strings.HasPrefix(serverURL, "https://"):
		addr = strings.TrimPrefix(serverURL, "https://")
		creds := credentials.NewTLS(&tls.Config{
			MinVersion:         tls.VersionTLS12,
			InsecureSkipVerify: insecureSkipVerify,
		})
		opts = append(opts, grpc.WithTransportCredentials(creds))
	case strings.HasPrefix(serverURL, "http://"):
		addr = strings.TrimPrefix(serverURL, "http://")
		opts = append(opts, grpc.WithTransportCredentials(insecure.NewCredentials()))
	default:
		addr = serverURL
		opts = append(opts, grpc.WithTransportCredentials(insecure.NewCredentials()))
	}

	if !strings.Contains(addr, ":") {
		return "", nil, fmt.Errorf("server URL must include port")
	}
	return addr, opts, nil
}

View File

@@ -1,128 +0,0 @@
package main
import (
"context"
"fmt"
"net"
"os"
"os/exec"
"testing"
"time"
"github.com/stretchr/testify/require"
"go.opentelemetry.io/otel"
"google.golang.org/grpc"
"github.com/netbirdio/netbird/shared/signal/proto"
"github.com/netbirdio/netbird/signal/server"
)
// TestCLI_SingleMessage runs the load-test CLI (via `go run`) against an
// embedded signal server and checks the final report for five successful
// single-message exchanges.
func TestCLI_SingleMessage(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
grpcServer, serverAddr := startTestSignalServer(t, ctx)
defer grpcServer.Stop()
// warn level keeps the CLI output limited to the final report.
cmd := exec.Command("go", "run", "main.go",
"-server", serverAddr,
"-pairs-per-sec", "3",
"-total-pairs", "5",
"-message-size", "50",
"-log-level", "warn")
output, err := cmd.CombinedOutput()
require.NoError(t, err, "CLI should execute successfully")
outputStr := string(output)
require.Contains(t, outputStr, "Load Test Report")
require.Contains(t, outputStr, "Total Pairs Sent: 5")
require.Contains(t, outputStr, "Successful Exchanges: 5")
t.Logf("Output:\n%s", outputStr)
}
// TestCLI_ContinuousExchange runs the CLI in continuous-exchange mode (three
// pairs exchanging messages for 3s at a 100ms interval) against an embedded
// signal server and validates the report. Skipped with -short.
func TestCLI_ContinuousExchange(t *testing.T) {
if testing.Short() {
t.Skip("Skipping continuous exchange CLI test in short mode")
}
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
grpcServer, serverAddr := startTestSignalServer(t, ctx)
defer grpcServer.Stop()
cmd := exec.Command("go", "run", "main.go",
"-server", serverAddr,
"-pairs-per-sec", "2",
"-total-pairs", "3",
"-message-size", "100",
"-exchange-duration", "3s",
"-message-interval", "100ms",
"-log-level", "warn")
output, err := cmd.CombinedOutput()
require.NoError(t, err, "CLI should execute successfully")
outputStr := string(output)
require.Contains(t, outputStr, "Load Test Report")
require.Contains(t, outputStr, "Total Pairs Sent: 3")
require.Contains(t, outputStr, "Successful Exchanges: 3")
t.Logf("Output:\n%s", outputStr)
}
// TestCLI_InvalidConfig runs the CLI with invalid flag values and expects a
// non-zero exit whose output mentions the configuration problem.
func TestCLI_InvalidConfig(t *testing.T) {
	cases := []struct {
		name string
		args []string
	}{
		{name: "negative pairs", args: []string{"-pairs-per-sec", "-1"}},
		{name: "zero total pairs", args: []string{"-total-pairs", "0"}},
		{name: "negative message size", args: []string{"-message-size", "-100"}},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			cmdArgs := append([]string{"run", "main.go"}, tc.args...)
			output, err := exec.Command("go", cmdArgs...).CombinedOutput()
			require.Error(t, err, "Should fail with invalid config")
			require.Contains(t, string(output), "Configuration error")
		})
	}
}
// startTestSignalServer starts an in-process signal gRPC server on a random
// loopback port and returns the server plus an http:// URL for clients.
// NOTE(review): the fixed 100ms sleep assumes the serve goroutine is
// accepting by then — a dial-until-ready loop would be more robust; confirm
// before reuse under heavy CI load.
func startTestSignalServer(t *testing.T, ctx context.Context) (*grpc.Server, string) {
// Port 0 lets the OS pick a free port.
listener, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
grpcServer := grpc.NewServer()
signalServer, err := server.NewServer(ctx, otel.Meter("cli-test"))
require.NoError(t, err)
proto.RegisterSignalExchangeServer(grpcServer, signalServer)
go func() {
if err := grpcServer.Serve(listener); err != nil {
t.Logf("Server stopped: %v", err)
}
}()
// Give the serve goroutine a moment to start accepting connections.
time.Sleep(100 * time.Millisecond)
return grpcServer, fmt.Sprintf("http://%s", listener.Addr().String())
}
// TestMain is the package test entry point. This pass-through form has been
// optional since Go 1.15; it is kept as an explicit hook for future
// setup/teardown.
func TestMain(m *testing.M) {
os.Exit(m.Run())
}

View File

@@ -1,165 +0,0 @@
package main
import (
"context"
"flag"
"fmt"
"os"
"os/signal"
"syscall"
"time"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/signal/loadtest"
)
// Command-line flag values; registered in init and read in main.
var (
serverURL string
pairsPerSecond int
totalPairs int
messageSize int
testDuration time.Duration
exchangeDuration time.Duration
messageInterval time.Duration
insecureSkipVerify bool
workerPoolSize int
channelBufferSize int
reportInterval int
logLevel string
enableReconnect bool
maxReconnectDelay time.Duration
initialRetryDelay time.Duration
)
// init registers all CLI flags with their defaults and help text.
func init() {
flag.StringVar(&serverURL, "server", "http://localhost:10000", "Signal server URL (http:// or https://)")
flag.IntVar(&pairsPerSecond, "pairs-per-sec", 10, "Number of peer pairs to create per second")
flag.IntVar(&totalPairs, "total-pairs", 100, "Total number of peer pairs to create")
flag.IntVar(&messageSize, "message-size", 100, "Size of test message in bytes")
flag.DurationVar(&testDuration, "test-duration", 0, "Maximum test duration (0 = unlimited)")
flag.DurationVar(&exchangeDuration, "exchange-duration", 0, "Duration for continuous message exchange per pair (0 = single message)")
flag.DurationVar(&messageInterval, "message-interval", 100*time.Millisecond, "Interval between messages in continuous mode")
flag.BoolVar(&insecureSkipVerify, "insecure-skip-verify", false, "Skip TLS certificate verification (use with self-signed certificates)")
flag.IntVar(&workerPoolSize, "worker-pool-size", 0, "Number of worker goroutines (0 = auto: pairs-per-sec * 2)")
flag.IntVar(&channelBufferSize, "channel-buffer-size", 0, "Channel buffer size (0 = auto: pairs-per-sec * 4)")
flag.IntVar(&reportInterval, "report-interval", 10000, "Report progress every N messages (0 = no periodic reports)")
flag.StringVar(&logLevel, "log-level", "info", "Log level (trace, debug, info, warn, error)")
// NOTE(review): this default is true, but the package README documents
// "-enable-reconnect ... (default: false)" — confirm which is intended.
flag.BoolVar(&enableReconnect, "enable-reconnect", true, "Enable automatic reconnection on connection loss")
flag.DurationVar(&maxReconnectDelay, "max-reconnect-delay", 30*time.Second, "Maximum delay between reconnection attempts")
flag.DurationVar(&initialRetryDelay, "initial-retry-delay", 100*time.Millisecond, "Initial delay before first reconnection attempt")
}
// main parses and validates flags, logs the effective configuration, runs
// the load test, and prints the aggregated report. SIGINT/SIGTERM triggers
// a graceful shutdown with a 5-second timeout before forcing exit; the
// report is printed even after an interrupt.
func main() {
flag.Parse()
level, err := log.ParseLevel(logLevel)
if err != nil {
fmt.Fprintf(os.Stderr, "Invalid log level: %v\n", err)
os.Exit(1)
}
log.SetLevel(level)
// Assemble the load-test configuration from the parsed flags.
config := loadtest.LoadTestConfig{
ServerURL: serverURL,
PairsPerSecond: pairsPerSecond,
TotalPairs: totalPairs,
MessageSize: messageSize,
TestDuration: testDuration,
ExchangeDuration: exchangeDuration,
MessageInterval: messageInterval,
InsecureSkipVerify: insecureSkipVerify,
WorkerPoolSize: workerPoolSize,
ChannelBufferSize: channelBufferSize,
ReportInterval: reportInterval,
EnableReconnect: enableReconnect,
MaxReconnectDelay: maxReconnectDelay,
InitialRetryDelay: initialRetryDelay,
}
if err := validateConfig(config); err != nil {
fmt.Fprintf(os.Stderr, "Configuration error: %v\n", err)
flag.Usage()
os.Exit(1)
}
log.Infof("Signal Load Test Configuration:")
log.Infof(" Server URL: %s", config.ServerURL)
log.Infof(" Pairs per second: %d", config.PairsPerSecond)
log.Infof(" Total pairs: %d", config.TotalPairs)
log.Infof(" Message size: %d bytes", config.MessageSize)
if config.InsecureSkipVerify {
log.Warnf(" TLS certificate verification: DISABLED (insecure)")
}
if config.TestDuration > 0 {
log.Infof(" Test duration: %v", config.TestDuration)
}
// ExchangeDuration > 0 selects continuous mode; otherwise one message per pair.
if config.ExchangeDuration > 0 {
log.Infof(" Exchange duration: %v", config.ExchangeDuration)
log.Infof(" Message interval: %v", config.MessageInterval)
} else {
log.Infof(" Mode: Single message exchange")
}
if config.EnableReconnect {
log.Infof(" Reconnection: ENABLED")
log.Infof(" Initial retry delay: %v", config.InitialRetryDelay)
log.Infof(" Max reconnect delay: %v", config.MaxReconnectDelay)
}
fmt.Println()
// Set up signal handler for graceful shutdown
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
sigChan := make(chan os.Signal, 1)
signal.Notify(sigChan, os.Interrupt, syscall.SIGTERM)
lt := loadtest.NewLoadTestWithContext(ctx, config)
// Run load test in a goroutine
done := make(chan error, 1)
go func() {
done <- lt.Run()
}()
// Wait for completion or signal
select {
case <-sigChan:
log.Warnf("\nReceived interrupt signal, shutting down gracefully...")
cancel()
// Wait a bit for graceful shutdown
select {
case <-done:
case <-time.After(5 * time.Second):
log.Warnf("Shutdown timeout, forcing exit")
}
case err := <-done:
if err != nil && err != context.Canceled {
log.Errorf("Load test failed: %v", err)
os.Exit(1)
}
}
// Print the report whether the test completed or was interrupted.
metrics := lt.GetMetrics()
fmt.Println() // Add blank line before report
metrics.PrintReport()
}
// validateConfig rejects configurations that cannot produce a meaningful
// load test: a missing server URL or non-positive pair rate, pair count,
// message size, or message interval.
func validateConfig(config loadtest.LoadTestConfig) error {
	switch {
	case config.ServerURL == "":
		return fmt.Errorf("server URL is required")
	case config.PairsPerSecond <= 0:
		return fmt.Errorf("pairs-per-sec must be greater than 0")
	case config.TotalPairs <= 0:
		return fmt.Errorf("total-pairs must be greater than 0")
	case config.MessageSize <= 0:
		return fmt.Errorf("message-size must be greater than 0")
	case config.MessageInterval <= 0:
		return fmt.Errorf("message-interval must be greater than 0")
	default:
		return nil
	}
}

View File

@@ -1,40 +0,0 @@
#!/bin/bash
# Smoke-test driver for the signal-server load tester.
# Builds the signal-loadtest binary, then runs it against a locally running
# signal server (expected at http://localhost:10000) in three modes of
# increasing intensity. Aborts on the first failing command.
set -e
echo "Building signal-loadtest binary..."
go build -o signal-loadtest
echo ""
# Test 1: each pair connects, exchanges exactly one message, and disconnects.
echo "=== Test 1: Single message exchange (5 pairs) ==="
./signal-loadtest \
-server http://localhost:10000 \
-pairs-per-sec 5 \
-total-pairs 5 \
-message-size 50 \
-log-level info
echo ""
# Test 2: continuous mode — each pair keeps exchanging for 5 seconds.
echo "=== Test 2: Continuous exchange (3 pairs, 5 seconds) ==="
./signal-loadtest \
-server http://localhost:10000 \
-pairs-per-sec 3 \
-total-pairs 3 \
-message-size 100 \
-exchange-duration 5s \
-message-interval 200ms \
-log-level info
echo ""
# Test 3: continuous mode with periodic progress reports every 100 messages.
echo "=== Test 3: Progress reporting (10 pairs, 10s, report every 100 messages) ==="
./signal-loadtest \
-server http://localhost:10000 \
-pairs-per-sec 10 \
-total-pairs 10 \
-message-size 100 \
-exchange-duration 10s \
-message-interval 100ms \
-report-interval 100 \
-log-level info
echo ""
echo "All tests completed!"

View File

@@ -1,91 +0,0 @@
package loadtest
import (
"context"
"fmt"
"net"
"testing"
"time"
"github.com/stretchr/testify/require"
"go.opentelemetry.io/otel"
"google.golang.org/grpc"
"github.com/netbirdio/netbird/shared/signal/proto"
"github.com/netbirdio/netbird/signal/server"
)
// TestSignalLoadTest_SinglePair spins up an in-process signal server, wires
// up one sender/receiver pair, and verifies that a single message makes it
// through with the expected routing keys and a non-empty encrypted body.
func TestSignalLoadTest_SinglePair(t *testing.T) {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	grpcServer, serverAddr := startTestSignalServer(t, ctx)
	defer grpcServer.Stop()

	sender, err := NewClient(serverAddr, "sender-peer-id")
	require.NoError(t, err)
	defer sender.Close()

	receiver, err := NewClient(serverAddr, "receiver-peer-id")
	require.NoError(t, err)
	defer receiver.Close()

	require.NoError(t, sender.Connect(), "Sender should connect successfully")
	require.NoError(t, receiver.Connect(), "Receiver should connect successfully")

	// Give both streams a moment to register with the server.
	time.Sleep(100 * time.Millisecond)

	payload := []byte("test message payload")
	t.Log("Sending message from sender to receiver")
	require.NoError(t, sender.SendMessage("receiver-peer-id", payload), "Sender should send message successfully")

	t.Log("Waiting for receiver to receive message")
	type recvResult struct {
		msg *proto.EncryptedMessage
		err error
	}
	resultCh := make(chan recvResult, 1)
	go func() {
		m, e := receiver.ReceiveMessage()
		resultCh <- recvResult{msg: m, err: e}
	}()

	select {
	case r := <-resultCh:
		require.NoError(t, r.err, "Receiver should receive message")
		require.NotNil(t, r.msg, "Received message should not be nil")
		require.Greater(t, len(r.msg.Body), 0, "Encrypted message body size should be greater than 0")
		require.Equal(t, "sender-peer-id", r.msg.Key)
		require.Equal(t, "receiver-peer-id", r.msg.RemoteKey)
		t.Log("Message received successfully")
	case <-time.After(5 * time.Second):
		t.Fatal("Timeout waiting for message")
	}
}
// startTestSignalServer boots an in-process gRPC signal server on a random
// loopback port and returns it together with an http:// URL for clients.
func startTestSignalServer(t *testing.T, ctx context.Context) (*grpc.Server, string) {
	lis, err := net.Listen("tcp", "127.0.0.1:0")
	require.NoError(t, err)

	srv := grpc.NewServer()
	signalSrv, err := server.NewServer(ctx, otel.Meter("test"))
	require.NoError(t, err)
	proto.RegisterSignalExchangeServer(srv, signalSrv)

	go func() {
		if serveErr := srv.Serve(lis); serveErr != nil {
			t.Logf("Server stopped: %v", serveErr)
		}
	}()

	// Brief pause so the listener is accepting before clients dial in.
	time.Sleep(100 * time.Millisecond)
	return srv, fmt.Sprintf("http://%s", lis.Addr().String())
}

View File

@@ -1,461 +0,0 @@
package loadtest
import (
"context"
"fmt"
"sync"
"sync/atomic"
"time"
log "github.com/sirupsen/logrus"
)
// LoadTestConfig configuration for the load test
type LoadTestConfig struct {
	IDPrefix           string        // prefix prepended to generated peer IDs (set internally from a timestamp)
	ServerURL          string        // signal server URL, e.g. http://host:port
	PairsPerSecond     int           // rate at which new sender/receiver pairs are started
	TotalPairs         int           // total number of pairs to run before the test ends
	MessageSize        int           // payload size in bytes for each test message
	TestDuration       time.Duration // overall cap on the run (0 = no cap)
	ExchangeDuration   time.Duration // per-pair continuous exchange time (0 = single message exchange)
	MessageInterval    time.Duration // delay between messages in continuous mode (0 = defaulted by the exchange)
	RampUpDuration     time.Duration // NOTE(review): not referenced in this file — confirm it is used elsewhere
	InsecureSkipVerify bool          // disable TLS certificate verification on clients
	WorkerPoolSize     int           // number of pair workers (0 = PairsPerSecond * 2)
	ChannelBufferSize  int           // pair queue buffer size (0 = PairsPerSecond * 4)
	ReportInterval     int           // Report progress every N messages (0 = no periodic reports)
	EnableReconnect    bool          // allow clients to reconnect with backoff
	MaxReconnectDelay  time.Duration // upper bound for reconnect backoff
	InitialRetryDelay  time.Duration // initial reconnect backoff delay
}
// LoadTestMetrics metrics collected during the load test
type LoadTestMetrics struct {
	TotalPairsSent         atomic.Int64 // pairs processed, successful or failed
	TotalMessagesExchanged atomic.Int64 // messages confirmed received
	TotalErrors            atomic.Int64 // send/receive/setup errors
	SuccessfulExchanges    atomic.Int64 // pairs that completed without error
	FailedExchanges        atomic.Int64 // pairs that returned an error
	ActivePairs            atomic.Int64 // pairs currently executing
	TotalReconnections     atomic.Int64 // client reconnects accumulated across all pairs

	// mu guards latencies; startTime/endTime are written by Run.
	mu        sync.Mutex
	latencies []time.Duration
	startTime time.Time
	endTime   time.Time
}
// PeerPair represents a sender-receiver pair
// NOTE(review): this type is not referenced anywhere in this file — confirm
// it is used elsewhere before keeping it.
type PeerPair struct {
	sender   *Client // initiating side of the exchange
	receiver *Client // answering side of the exchange
	pairID   int     // sequential pair identifier
}
// LoadTest manages the load test execution
type LoadTest struct {
	config  LoadTestConfig
	metrics *LoadTestMetrics
	// ctx/cancel govern the test itself; reporterCtx/reporterCancel form a
	// child scope so the progress reporter can be stopped independently
	// after the workers finish.
	ctx            context.Context
	cancel         context.CancelFunc
	reporterCtx    context.Context
	reporterCancel context.CancelFunc
}
// NewLoadTest creates a load test rooted in a background context; it is
// cancelled only via Stop (or an internal TestDuration timeout).
func NewLoadTest(config LoadTestConfig) *LoadTest {
	return NewLoadTestWithContext(context.Background(), config)
}
// NewLoadTestWithContext creates a load test whose lifetime is bound to the
// supplied parent context (a derived cancellable context is used internally).
func NewLoadTestWithContext(ctx context.Context, config LoadTestConfig) *LoadTest {
	derived, cancel := context.WithCancel(ctx)
	return newLoadTestWithContext(derived, cancel, config)
}
// newLoadTestWithContext wires up the shared state of a load test: a
// reporter sub-context derived from the run context, and a timestamp-based
// ID prefix so peer IDs from concurrent runs never collide.
func newLoadTestWithContext(ctx context.Context, cancel context.CancelFunc, config LoadTestConfig) *LoadTest {
	repCtx, repCancel := context.WithCancel(ctx)
	config.IDPrefix = fmt.Sprintf("%d-", time.Now().UnixNano())
	lt := &LoadTest{
		config:         config,
		metrics:        &LoadTestMetrics{},
		ctx:            ctx,
		cancel:         cancel,
		reporterCtx:    repCtx,
		reporterCancel: repCancel,
	}
	return lt
}
// Run executes the load test: it starts a worker pool, feeds pair IDs into
// it at the configured rate, and blocks until all pairs complete or the run
// context / TestDuration expires. It returns nil on completion, or the
// context error when the run was cut short.
func (lt *LoadTest) Run() error {
	lt.metrics.startTime = time.Now()
	defer func() {
		lt.metrics.endTime = time.Now()
	}()
	exchangeInfo := "single message"
	if lt.config.ExchangeDuration > 0 {
		exchangeInfo = fmt.Sprintf("continuous for %v", lt.config.ExchangeDuration)
	}
	// Derive pool sizing defaults from the pair rate when not set explicitly.
	workerPoolSize := lt.config.WorkerPoolSize
	if workerPoolSize == 0 {
		workerPoolSize = lt.config.PairsPerSecond * 2
	}
	channelBufferSize := lt.config.ChannelBufferSize
	if channelBufferSize == 0 {
		channelBufferSize = lt.config.PairsPerSecond * 4
	}
	log.Infof("Starting load test: %d pairs/sec, %d total pairs, message size: %d bytes, exchange: %s",
		lt.config.PairsPerSecond, lt.config.TotalPairs, lt.config.MessageSize, exchangeInfo)
	log.Infof("Worker pool size: %d, channel buffer: %d", workerPoolSize, channelBufferSize)
	var wg sync.WaitGroup
	var reporterWg sync.WaitGroup
	pairChan := make(chan int, channelBufferSize)
	// Start progress reporter if configured.
	if lt.config.ReportInterval > 0 {
		reporterWg.Add(1)
		go lt.progressReporter(&reporterWg, lt.config.ReportInterval)
	}
	// Always stop and drain the reporter, including on the early-return
	// timeout path below. Previously the reporter leaked on that path: a
	// TestDuration timeout cancels only testCtx, not lt.ctx/reporterCtx.
	defer func() {
		lt.reporterCancel()
		reporterWg.Wait()
	}()
	for i := 0; i < workerPoolSize; i++ {
		wg.Add(1)
		go lt.pairWorker(&wg, pairChan)
	}
	testCtx := lt.ctx
	if lt.config.TestDuration > 0 {
		var testCancel context.CancelFunc
		testCtx, testCancel = context.WithTimeout(lt.ctx, lt.config.TestDuration)
		defer testCancel()
	}
	ticker := time.NewTicker(time.Second / time.Duration(lt.config.PairsPerSecond))
	defer ticker.Stop()
	pairsCreated := 0
	for pairsCreated < lt.config.TotalPairs {
		select {
		case <-testCtx.Done():
			log.Infof("Test duration reached or context cancelled")
			close(pairChan)
			wg.Wait()
			return testCtx.Err()
		case <-ticker.C:
			// Non-blocking enqueue: if the workers cannot keep up, skip
			// this tick rather than stalling the rate loop.
			select {
			case pairChan <- pairsCreated:
				pairsCreated++
			default:
				log.Warnf("Worker pool saturated, skipping pair creation")
			}
		}
	}
	log.Infof("All %d pairs queued, waiting for completion...", pairsCreated)
	close(pairChan)
	wg.Wait()
	return nil
}
// pairWorker consumes pair IDs from pairChan and runs one full exchange per
// ID, updating the active/success/failure counters as it goes.
func (lt *LoadTest) pairWorker(wg *sync.WaitGroup, pairChan <-chan int) {
	defer wg.Done()
	for id := range pairChan {
		lt.metrics.ActivePairs.Add(1)
		err := lt.executePairExchange(id)
		if err == nil {
			lt.metrics.SuccessfulExchanges.Add(1)
		} else {
			lt.metrics.TotalErrors.Add(1)
			lt.metrics.FailedExchanges.Add(1)
			log.Debugf("Pair %d exchange failed: %v", id, err)
		}
		lt.metrics.ActivePairs.Add(-1)
		lt.metrics.TotalPairsSent.Add(1)
	}
}
// executePairExchange runs one full sender/receiver exchange for pairID: it
// creates and connects both clients, then performs either a single message
// round-trip or a continuous exchange depending on ExchangeDuration. Each
// client's reconnect count is folded into the metrics at teardown.
func (lt *LoadTest) executePairExchange(pairID int) error {
	// Peer IDs carry the per-run prefix so concurrent runs cannot collide.
	senderID := fmt.Sprintf("%ssender-%d", lt.config.IDPrefix, pairID)
	receiverID := fmt.Sprintf("%sreceiver-%d", lt.config.IDPrefix, pairID)
	clientConfig := &ClientConfig{
		InsecureSkipVerify: lt.config.InsecureSkipVerify,
		EnableReconnect:    lt.config.EnableReconnect,
		MaxReconnectDelay:  lt.config.MaxReconnectDelay,
		InitialRetryDelay:  lt.config.InitialRetryDelay,
	}
	sender, err := NewClientWithConfig(lt.config.ServerURL, senderID, clientConfig)
	if err != nil {
		return fmt.Errorf("create sender: %w", err)
	}
	defer func() {
		// Close first, then record the final reconnect count for this client.
		sender.Close()
		lt.metrics.TotalReconnections.Add(sender.GetReconnectCount())
	}()
	receiver, err := NewClientWithConfig(lt.config.ServerURL, receiverID, clientConfig)
	if err != nil {
		return fmt.Errorf("create receiver: %w", err)
	}
	defer func() {
		receiver.Close()
		lt.metrics.TotalReconnections.Add(receiver.GetReconnectCount())
	}()
	if err := sender.Connect(); err != nil {
		return fmt.Errorf("sender connect: %w", err)
	}
	if err := receiver.Connect(); err != nil {
		return fmt.Errorf("receiver connect: %w", err)
	}
	// Short settle time so both registrations land before the first send.
	time.Sleep(50 * time.Millisecond)
	// Deterministic payload of the configured size (bytes 0..255 repeating).
	testMessage := make([]byte, lt.config.MessageSize)
	for i := range testMessage {
		testMessage[i] = byte(i % 256)
	}
	if lt.config.ExchangeDuration > 0 {
		return lt.continuousExchange(pairID, sender, receiver, receiverID, testMessage)
	}
	return lt.singleExchange(sender, receiver, receiverID, testMessage)
}
// singleExchange sends one message from sender to receiver and waits up to
// 5 seconds for it to arrive, recording the round-trip latency on success.
func (lt *LoadTest) singleExchange(sender, receiver *Client, receiverID string, testMessage []byte) error {
	startTime := time.Now()
	if err := sender.SendMessage(receiverID, testMessage); err != nil {
		return fmt.Errorf("send message: %w", err)
	}
	// Buffered channel: the goroutine can always deliver its result even if
	// the select below has already taken the timeout or cancellation path.
	receiveDone := make(chan error, 1)
	go func() {
		msg, err := receiver.ReceiveMessage()
		if err != nil {
			receiveDone <- err
			return
		}
		if len(msg.Body) == 0 {
			receiveDone <- fmt.Errorf("empty message body")
			return
		}
		receiveDone <- nil
	}()
	select {
	case err := <-receiveDone:
		if err != nil {
			return fmt.Errorf("receive message: %w", err)
		}
		latency := time.Since(startTime)
		lt.recordLatency(latency)
		lt.metrics.TotalMessagesExchanged.Add(1)
		return nil
	case <-time.After(5 * time.Second):
		// NOTE(review): on timeout the receive goroutine may stay blocked in
		// ReceiveMessage until the caller closes the client — confirm Close
		// unblocks it.
		return fmt.Errorf("timeout waiting for message")
	case <-lt.ctx.Done():
		return lt.ctx.Err()
	}
}
// continuousExchange runs concurrent sender and receiver loops for the
// configured ExchangeDuration and returns the first non-cancellation error
// either loop produced (nil if both ended cleanly).
func (lt *LoadTest) continuousExchange(pairID int, sender, receiver *Client, receiverID string, testMessage []byte) error {
	exchangeCtx, cancel := context.WithTimeout(lt.ctx, lt.config.ExchangeDuration)
	defer cancel()
	messageInterval := lt.config.MessageInterval
	if messageInterval == 0 {
		// Fall back to a default cadence when the interval was left unset.
		messageInterval = 100 * time.Millisecond
	}
	// errChan holds at most one error; later errors are dropped via the
	// non-blocking sends below.
	errChan := make(chan error, 1)
	var wg sync.WaitGroup
	wg.Add(1)
	go func() {
		defer wg.Done()
		if err := lt.receiverLoop(exchangeCtx, receiver, pairID); err != nil && err != context.DeadlineExceeded && err != context.Canceled {
			select {
			case errChan <- err:
			default:
			}
		}
	}()
	wg.Add(1)
	go func() {
		defer wg.Done()
		if err := lt.senderLoop(exchangeCtx, sender, receiverID, testMessage, messageInterval); err != nil && err != context.DeadlineExceeded && err != context.Canceled {
			select {
			case errChan <- err:
			default:
			}
		}
	}()
	wg.Wait()
	// Drain the single buffered error, if any.
	select {
	case err := <-errChan:
		return err
	default:
		return nil
	}
}
// senderLoop pushes the test payload to receiverID once per interval until
// ctx is cancelled. Send failures are counted but do not stop the loop.
func (lt *LoadTest) senderLoop(ctx context.Context, sender *Client, receiverID string, message []byte, interval time.Duration) error {
	tick := time.NewTicker(interval)
	defer tick.Stop()
	for {
		select {
		case <-ctx.Done():
			return ctx.Err()
		case <-tick.C:
		}
		sentAt := time.Now()
		if err := sender.SendMessage(receiverID, message); err != nil {
			lt.metrics.TotalErrors.Add(1)
			log.Debugf("Send error: %v", err)
			continue
		}
		lt.recordLatency(time.Since(sentAt))
	}
}
// receiverLoop counts non-empty messages arriving on the receiver's channel
// until ctx is cancelled or the channel is closed.
//
// The previous implementation also woke up every 100ms via time.After,
// allocating a timer per iteration without changing behavior: the ctx.Done
// case already unblocks the select on cancellation. That redundant timeout
// branch and the duplicate ctx.Err() pre-check have been removed.
func (lt *LoadTest) receiverLoop(ctx context.Context, receiver *Client, pairID int) error {
	for {
		select {
		case msg, ok := <-receiver.msgChannel:
			if !ok {
				// Channel closed: the client was shut down; clean exit.
				return nil
			}
			if len(msg.Body) > 0 {
				lt.metrics.TotalMessagesExchanged.Add(1)
			}
		case <-ctx.Done():
			return ctx.Err()
		}
	}
}
// recordLatency appends one latency sample under the metrics lock.
func (lt *LoadTest) recordLatency(latency time.Duration) {
	m := lt.metrics
	m.mu.Lock()
	m.latencies = append(m.latencies, latency)
	m.mu.Unlock()
}
// progressReporter wakes once a second and, whenever at least `interval`
// new messages have been exchanged since the last report, logs a progress
// line with the live counters and overall message rate.
func (lt *LoadTest) progressReporter(wg *sync.WaitGroup, interval int) {
	defer wg.Done()
	step := int64(interval)
	var lastReported int64
	tick := time.NewTicker(1 * time.Second)
	defer tick.Stop()
	for {
		select {
		case <-lt.reporterCtx.Done():
			return
		case <-tick.C:
		}
		total := lt.metrics.TotalMessagesExchanged.Load()
		if total-lastReported < step {
			continue
		}
		elapsed := time.Since(lt.metrics.startTime)
		rate := 0.0
		if secs := elapsed.Seconds(); secs > 0 {
			rate = float64(total) / secs
		}
		log.Infof("Progress: %d messages exchanged, %d active pairs, %d errors, %d reconnections, %.2f msg/sec, elapsed: %v",
			total, lt.metrics.ActivePairs.Load(), lt.metrics.TotalErrors.Load(), lt.metrics.TotalReconnections.Load(), rate, elapsed.Round(time.Second))
		// Snap to the last full reporting step so partial progress carries
		// over to the next report.
		lastReported = (total / step) * step
	}
}
// Stop aborts the load test and its progress reporter.
func (lt *LoadTest) Stop() {
	// reporterCtx is derived from ctx, but cancel both explicitly so the
	// reporter stops even if that derivation ever changes.
	lt.reporterCancel()
	lt.cancel()
}
// GetMetrics exposes the live metrics of this run; counters may still be
// changing while the test is in flight.
func (lt *LoadTest) GetMetrics() *LoadTestMetrics {
	return lt.metrics
}
// PrintReport writes a human-readable summary of the run to stdout:
// counters, throughput, and min/avg/max latency over all recorded samples.
func (m *LoadTestMetrics) PrintReport() {
	elapsed := m.endTime.Sub(m.startTime)
	fmt.Println("\n=== Load Test Report ===")
	fmt.Printf("Test Duration: %v\n", elapsed)
	fmt.Printf("Total Pairs Sent: %d\n", m.TotalPairsSent.Load())
	fmt.Printf("Successful Exchanges: %d\n", m.SuccessfulExchanges.Load())
	fmt.Printf("Failed Exchanges: %d\n", m.FailedExchanges.Load())
	fmt.Printf("Total Messages Exchanged: %d\n", m.TotalMessagesExchanged.Load())
	fmt.Printf("Total Errors: %d\n", m.TotalErrors.Load())
	if r := m.TotalReconnections.Load(); r > 0 {
		fmt.Printf("Total Reconnections: %d\n", r)
	}
	if elapsed.Seconds() > 0 {
		fmt.Printf("Throughput: %.2f pairs/sec\n", float64(m.SuccessfulExchanges.Load())/elapsed.Seconds())
	}
	// Snapshot the slice header under the lock; samples are only appended,
	// so reading the snapshot afterwards is safe once the run has ended.
	m.mu.Lock()
	samples := m.latencies
	m.mu.Unlock()
	if len(samples) > 0 {
		sum := time.Duration(0)
		lo, hi := samples[0], samples[0]
		for _, s := range samples {
			sum += s
			if s < lo {
				lo = s
			}
			if s > hi {
				hi = s
			}
		}
		fmt.Printf("\nLatency Statistics:\n")
		fmt.Printf("  Min: %v\n", lo)
		fmt.Printf("  Max: %v\n", hi)
		fmt.Printf("  Avg: %v\n", sum/time.Duration(len(samples)))
	}
	fmt.Println("========================")
}

View File

@@ -1,305 +0,0 @@
package loadtest
import (
"context"
"fmt"
"net"
"testing"
"time"
"github.com/stretchr/testify/require"
"go.opentelemetry.io/otel"
"google.golang.org/grpc"
"github.com/netbirdio/netbird/shared/signal/proto"
"github.com/netbirdio/netbird/signal/server"
)
// TestLoadTest_10PairsPerSecond drives 50 pairs at 10 pairs/sec against an
// in-process signal server and checks that every pair is accounted for.
func TestLoadTest_10PairsPerSecond(t *testing.T) {
	if testing.Short() {
		t.Skip("Skipping load test in short mode")
	}
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()
	grpcServer, serverAddr := startTestSignalServerForLoad(t, ctx)
	defer grpcServer.Stop()
	config := LoadTestConfig{
		ServerURL:      serverAddr,
		PairsPerSecond: 10,
		TotalPairs:     50,
		MessageSize:    100,
		TestDuration:   30 * time.Second, // safety cap; normally finishes sooner
	}
	loadTest := NewLoadTest(config)
	err := loadTest.Run()
	require.NoError(t, err)
	metrics := loadTest.GetMetrics()
	metrics.PrintReport()
	require.Equal(t, int64(50), metrics.TotalPairsSent.Load(), "Should send all 50 pairs")
	require.Greater(t, metrics.SuccessfulExchanges.Load(), int64(0), "Should have successful exchanges")
	require.Equal(t, metrics.TotalMessagesExchanged.Load(), metrics.SuccessfulExchanges.Load(), "Messages exchanged should match successful exchanges")
}
// TestLoadTest_20PairsPerSecond doubles the rate and uses a larger payload
// than the 10 pairs/sec test, checking that all 100 pairs are sent.
func TestLoadTest_20PairsPerSecond(t *testing.T) {
	if testing.Short() {
		t.Skip("Skipping load test in short mode")
	}
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()
	grpcServer, serverAddr := startTestSignalServerForLoad(t, ctx)
	defer grpcServer.Stop()
	config := LoadTestConfig{
		ServerURL:      serverAddr,
		PairsPerSecond: 20,
		TotalPairs:     100,
		MessageSize:    500,
		TestDuration:   30 * time.Second, // safety cap
	}
	loadTest := NewLoadTest(config)
	err := loadTest.Run()
	require.NoError(t, err)
	metrics := loadTest.GetMetrics()
	metrics.PrintReport()
	require.Equal(t, int64(100), metrics.TotalPairsSent.Load(), "Should send all 100 pairs")
	require.Greater(t, metrics.SuccessfulExchanges.Load(), int64(0), "Should have successful exchanges")
}
// TestLoadTest_SmallBurst is a quick sanity run (10 pairs) that also runs in
// short mode, asserting at least a 50% success rate.
func TestLoadTest_SmallBurst(t *testing.T) {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()
	grpcServer, serverAddr := startTestSignalServerForLoad(t, ctx)
	defer grpcServer.Stop()
	config := LoadTestConfig{
		ServerURL:      serverAddr,
		PairsPerSecond: 5,
		TotalPairs:     10,
		MessageSize:    50,
		TestDuration:   10 * time.Second, // safety cap
	}
	loadTest := NewLoadTest(config)
	err := loadTest.Run()
	require.NoError(t, err)
	metrics := loadTest.GetMetrics()
	metrics.PrintReport()
	require.Equal(t, int64(10), metrics.TotalPairsSent.Load())
	require.Greater(t, metrics.SuccessfulExchanges.Load(), int64(5), "At least 50% success rate")
	require.Less(t, metrics.FailedExchanges.Load(), int64(5), "Less than 50% failure rate")
}
// TestLoadTest_ContinuousExchange_30Seconds runs 10 pairs in continuous
// exchange mode for 30s each and expects a large total message volume.
func TestLoadTest_ContinuousExchange_30Seconds(t *testing.T) {
	if testing.Short() {
		t.Skip("Skipping continuous exchange test in short mode")
	}
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()
	grpcServer, serverAddr := startTestSignalServerForLoad(t, ctx)
	defer grpcServer.Stop()
	config := LoadTestConfig{
		ServerURL:        serverAddr,
		PairsPerSecond:   5,
		TotalPairs:       10,
		MessageSize:      100,
		ExchangeDuration: 30 * time.Second,
		MessageInterval:  100 * time.Millisecond,
		TestDuration:     2 * time.Minute, // safety cap
	}
	loadTest := NewLoadTest(config)
	err := loadTest.Run()
	require.NoError(t, err)
	metrics := loadTest.GetMetrics()
	metrics.PrintReport()
	require.Equal(t, int64(10), metrics.TotalPairsSent.Load())
	require.Greater(t, metrics.TotalMessagesExchanged.Load(), int64(2000), "Should exchange many messages over 30 seconds")
}
// TestLoadTest_ContinuousExchange_10Minutes is a long soak test: 20 pairs
// exchanging continuously for 10 minutes each (skipped in short mode).
func TestLoadTest_ContinuousExchange_10Minutes(t *testing.T) {
	if testing.Short() {
		t.Skip("Skipping long continuous exchange test in short mode")
	}
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()
	grpcServer, serverAddr := startTestSignalServerForLoad(t, ctx)
	defer grpcServer.Stop()
	config := LoadTestConfig{
		ServerURL:        serverAddr,
		PairsPerSecond:   10,
		TotalPairs:       20,
		MessageSize:      200,
		ExchangeDuration: 10 * time.Minute,
		MessageInterval:  200 * time.Millisecond,
		TestDuration:     15 * time.Minute, // safety cap
	}
	loadTest := NewLoadTest(config)
	err := loadTest.Run()
	require.NoError(t, err)
	metrics := loadTest.GetMetrics()
	metrics.PrintReport()
	require.Equal(t, int64(20), metrics.TotalPairsSent.Load())
	require.Greater(t, metrics.TotalMessagesExchanged.Load(), int64(50000), "Should exchange many messages over 10 minutes")
}
// TestLoadTest_ContinuousExchange_ShortBurst exercises continuous mode with
// periodic progress reporting in a run short enough for CI.
func TestLoadTest_ContinuousExchange_ShortBurst(t *testing.T) {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()
	grpcServer, serverAddr := startTestSignalServerForLoad(t, ctx)
	defer grpcServer.Stop()
	config := LoadTestConfig{
		ServerURL:        serverAddr,
		PairsPerSecond:   3,
		TotalPairs:       5,
		MessageSize:      50,
		ExchangeDuration: 3 * time.Second,
		MessageInterval:  100 * time.Millisecond,
		TestDuration:     10 * time.Second, // safety cap
		ReportInterval:   50,               // Report every 50 messages for testing
	}
	loadTest := NewLoadTest(config)
	err := loadTest.Run()
	require.NoError(t, err)
	metrics := loadTest.GetMetrics()
	metrics.PrintReport()
	require.Equal(t, int64(5), metrics.TotalPairsSent.Load())
	require.Greater(t, metrics.TotalMessagesExchanged.Load(), int64(100), "Should exchange multiple messages in 3 seconds")
	require.Equal(t, int64(5), metrics.SuccessfulExchanges.Load(), "All pairs should complete successfully")
}
// TestLoadTest_ReconnectionConfig verifies the load test runs cleanly with
// client reconnection enabled and that the reconnection metric is tracked.
func TestLoadTest_ReconnectionConfig(t *testing.T) {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()
	grpcServer, serverAddr := startTestSignalServerForLoad(t, ctx)
	defer grpcServer.Stop()
	config := LoadTestConfig{
		ServerURL:         serverAddr,
		PairsPerSecond:    3,
		TotalPairs:        5,
		MessageSize:       50,
		ExchangeDuration:  2 * time.Second,
		MessageInterval:   200 * time.Millisecond,
		TestDuration:      5 * time.Second, // safety cap
		EnableReconnect:   true,
		InitialRetryDelay: 100 * time.Millisecond,
		MaxReconnectDelay: 2 * time.Second,
	}
	loadTest := NewLoadTest(config)
	err := loadTest.Run()
	require.NoError(t, err)
	metrics := loadTest.GetMetrics()
	metrics.PrintReport()
	// Test should complete successfully with reconnection enabled
	require.Equal(t, int64(5), metrics.TotalPairsSent.Load())
	require.Greater(t, metrics.TotalMessagesExchanged.Load(), int64(0), "Should have exchanged messages")
	require.Equal(t, int64(5), metrics.SuccessfulExchanges.Load(), "All pairs should complete successfully")
	// Reconnections counter should exist (even if zero for this stable test)
	reconnections := metrics.TotalReconnections.Load()
	require.GreaterOrEqual(t, reconnections, int64(0), "Reconnections metric should be tracked")
}
// BenchmarkLoadTest_Throughput measures pair-exchange throughput against an
// in-process signal server; b.N controls the total number of pairs.
func BenchmarkLoadTest_Throughput(b *testing.B) {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()
	grpcServer, serverAddr := startBenchSignalServer(b, ctx)
	defer grpcServer.Stop()
	b.Run("5pairs-per-sec", func(b *testing.B) {
		config := LoadTestConfig{
			ServerURL:      serverAddr,
			PairsPerSecond: 5,
			TotalPairs:     b.N,
			MessageSize:    100,
		}
		loadTest := NewLoadTest(config)
		b.ResetTimer()
		// Run errors are reflected in the failed metric reported below.
		_ = loadTest.Run()
		b.StopTimer()
		metrics := loadTest.GetMetrics()
		b.ReportMetric(float64(metrics.SuccessfulExchanges.Load()), "successful")
		b.ReportMetric(float64(metrics.FailedExchanges.Load()), "failed")
	})
}
// startTestSignalServerForLoad starts an in-process gRPC signal server on an
// ephemeral loopback port and returns it together with an http:// URL.
func startTestSignalServerForLoad(t *testing.T, ctx context.Context) (*grpc.Server, string) {
	lis, err := net.Listen("tcp", "127.0.0.1:0")
	require.NoError(t, err)
	gs := grpc.NewServer()
	sigSrv, err := server.NewServer(ctx, otel.Meter("test"))
	require.NoError(t, err)
	proto.RegisterSignalExchangeServer(gs, sigSrv)
	go func() {
		if serveErr := gs.Serve(lis); serveErr != nil {
			t.Logf("Server stopped: %v", serveErr)
		}
	}()
	// Give the server a moment to start accepting connections.
	time.Sleep(100 * time.Millisecond)
	addr := fmt.Sprintf("http://%s", lis.Addr().String())
	return gs, addr
}
// startBenchSignalServer is the benchmark counterpart of
// startTestSignalServerForLoad: same setup, but reporting through *testing.B.
func startBenchSignalServer(b *testing.B, ctx context.Context) (*grpc.Server, string) {
	lis, err := net.Listen("tcp", "127.0.0.1:0")
	require.NoError(b, err)
	gs := grpc.NewServer()
	sigSrv, err := server.NewServer(ctx, otel.Meter("bench"))
	require.NoError(b, err)
	proto.RegisterSignalExchangeServer(gs, sigSrv)
	go func() {
		if serveErr := gs.Serve(lis); serveErr != nil {
			b.Logf("Server stopped: %v", serveErr)
		}
	}()
	// Give the server a moment to start accepting connections.
	time.Sleep(100 * time.Millisecond)
	addr := fmt.Sprintf("http://%s", lis.Addr().String())
	return gs, addr
}