mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-18 07:26:14 -04:00
Compare commits
2 Commits
feat/sync-
...
test/updat
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
ca46fe215a | ||
|
|
e5f926fa6d |
@@ -4,7 +4,7 @@
|
||||
# sudo podman build -t localhost/netbird:latest -f client/Dockerfile --ignorefile .dockerignore-client .
|
||||
# sudo podman run --rm -it --cap-add={BPF,NET_ADMIN,NET_RAW} localhost/netbird:latest
|
||||
|
||||
FROM alpine:3.22.2
|
||||
FROM alpine:3.22.0
|
||||
# iproute2: busybox doesn't display ip rules properly
|
||||
RUN apk add --no-cache \
|
||||
bash \
|
||||
|
||||
@@ -4,7 +4,6 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"os/user"
|
||||
"runtime"
|
||||
"strings"
|
||||
@@ -357,21 +356,13 @@ func openURL(cmd *cobra.Command, verificationURIComplete, userCode string, noBro
|
||||
cmd.Println("")
|
||||
|
||||
if !noBrowser {
|
||||
if err := openBrowser(verificationURIComplete); err != nil {
|
||||
if err := open.Run(verificationURIComplete); err != nil {
|
||||
cmd.Println("\nAlternatively, you may want to use a setup key, see:\n\n" +
|
||||
"https://docs.netbird.io/how-to/register-machines-using-setup-keys")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// openBrowser opens the URL in a browser, respecting the BROWSER environment variable.
|
||||
func openBrowser(url string) error {
|
||||
if browser := os.Getenv("BROWSER"); browser != "" {
|
||||
return exec.Command(browser, url).Start()
|
||||
}
|
||||
return open.Run(url)
|
||||
}
|
||||
|
||||
// isUnixRunningDesktop checks if a Linux OS is running desktop environment
|
||||
func isUnixRunningDesktop() bool {
|
||||
if runtime.GOOS != "linux" && runtime.GOOS != "freebsd" {
|
||||
|
||||
@@ -400,6 +400,7 @@ func transformIPsetName(ipsetName string, sPort, dPort *firewall.Port, action fi
|
||||
return ""
|
||||
}
|
||||
|
||||
// Include action in the ipset name to prevent squashing rules with different actions
|
||||
actionSuffix := ""
|
||||
if action == firewall.ActionDrop {
|
||||
actionSuffix = "-drop"
|
||||
|
||||
@@ -23,5 +23,4 @@ type WGTunDevice interface {
|
||||
FilteredDevice() *device.FilteredDevice
|
||||
Device() *wgdevice.Device
|
||||
GetNet() *netstack.Net
|
||||
GetICEBind() device.EndpointManager
|
||||
}
|
||||
|
||||
@@ -150,11 +150,6 @@ func (t *WGTunDevice) GetNet() *netstack.Net {
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetICEBind returns the ICEBind instance
|
||||
func (t *WGTunDevice) GetICEBind() EndpointManager {
|
||||
return t.iceBind
|
||||
}
|
||||
|
||||
func routesToString(routes []string) string {
|
||||
return strings.Join(routes, ";")
|
||||
}
|
||||
|
||||
@@ -154,8 +154,3 @@ func (t *TunDevice) assignAddr() error {
|
||||
func (t *TunDevice) GetNet() *netstack.Net {
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetICEBind returns the ICEBind instance
|
||||
func (t *TunDevice) GetICEBind() EndpointManager {
|
||||
return t.iceBind
|
||||
}
|
||||
|
||||
@@ -144,8 +144,3 @@ func (t *TunDevice) FilteredDevice() *FilteredDevice {
|
||||
func (t *TunDevice) GetNet() *netstack.Net {
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetICEBind returns the ICEBind instance
|
||||
func (t *TunDevice) GetICEBind() EndpointManager {
|
||||
return t.iceBind
|
||||
}
|
||||
|
||||
@@ -179,8 +179,3 @@ func (t *TunKernelDevice) assignAddr() error {
|
||||
func (t *TunKernelDevice) GetNet() *netstack.Net {
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetICEBind returns nil for kernel mode devices
|
||||
func (t *TunKernelDevice) GetICEBind() EndpointManager {
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -21,7 +21,6 @@ type Bind interface {
|
||||
conn.Bind
|
||||
GetICEMux() (*udpmux.UniversalUDPMuxDefault, error)
|
||||
ActivityRecorder() *bind.ActivityRecorder
|
||||
EndpointManager
|
||||
}
|
||||
|
||||
type TunNetstackDevice struct {
|
||||
@@ -156,8 +155,3 @@ func (t *TunNetstackDevice) Device() *device.Device {
|
||||
func (t *TunNetstackDevice) GetNet() *netstack.Net {
|
||||
return t.net
|
||||
}
|
||||
|
||||
// GetICEBind returns the bind instance
|
||||
func (t *TunNetstackDevice) GetICEBind() EndpointManager {
|
||||
return t.bind
|
||||
}
|
||||
|
||||
@@ -146,8 +146,3 @@ func (t *USPDevice) assignAddr() error {
|
||||
func (t *USPDevice) GetNet() *netstack.Net {
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetICEBind returns the ICEBind instance
|
||||
func (t *USPDevice) GetICEBind() EndpointManager {
|
||||
return t.iceBind
|
||||
}
|
||||
|
||||
@@ -185,8 +185,3 @@ func (t *TunDevice) assignAddr() error {
|
||||
func (t *TunDevice) GetNet() *netstack.Net {
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetICEBind returns the ICEBind instance
|
||||
func (t *TunDevice) GetICEBind() EndpointManager {
|
||||
return t.iceBind
|
||||
}
|
||||
|
||||
@@ -1,13 +0,0 @@
|
||||
package device
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/netip"
|
||||
)
|
||||
|
||||
// EndpointManager manages fake IP to connection mappings for userspace bind implementations.
|
||||
// Implemented by bind.ICEBind and bind.RelayBindJS.
|
||||
type EndpointManager interface {
|
||||
SetEndpoint(fakeIP netip.Addr, conn net.Conn)
|
||||
RemoveEndpoint(fakeIP netip.Addr)
|
||||
}
|
||||
@@ -21,5 +21,4 @@ type WGTunDevice interface {
|
||||
FilteredDevice() *device.FilteredDevice
|
||||
Device() *wgdevice.Device
|
||||
GetNet() *netstack.Net
|
||||
GetICEBind() device.EndpointManager
|
||||
}
|
||||
|
||||
@@ -80,17 +80,6 @@ func (w *WGIface) GetProxy() wgproxy.Proxy {
|
||||
return w.wgProxyFactory.GetProxy()
|
||||
}
|
||||
|
||||
// GetBind returns the EndpointManager userspace bind mode.
|
||||
func (w *WGIface) GetBind() device.EndpointManager {
|
||||
w.mu.Lock()
|
||||
defer w.mu.Unlock()
|
||||
|
||||
if w.tun == nil {
|
||||
return nil
|
||||
}
|
||||
return w.tun.GetICEBind()
|
||||
}
|
||||
|
||||
// IsUserspaceBind indicates whether this interfaces is userspace with bind.ICEBind
|
||||
func (w *WGIface) IsUserspaceBind() bool {
|
||||
return w.userspaceBind
|
||||
|
||||
@@ -29,6 +29,11 @@ type Manager interface {
|
||||
ApplyFiltering(networkMap *mgmProto.NetworkMap, dnsRouteFeatureFlag bool)
|
||||
}
|
||||
|
||||
type protoMatch struct {
|
||||
ips map[string]int
|
||||
policyID []byte
|
||||
}
|
||||
|
||||
// DefaultManager uses firewall manager to handle
|
||||
type DefaultManager struct {
|
||||
firewall firewall.Manager
|
||||
@@ -81,14 +86,21 @@ func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap, dnsRout
|
||||
}
|
||||
|
||||
func (d *DefaultManager) applyPeerACLs(networkMap *mgmProto.NetworkMap) {
|
||||
rules := networkMap.FirewallRules
|
||||
rules, squashedProtocols := d.squashAcceptRules(networkMap)
|
||||
|
||||
enableSSH := networkMap.PeerConfig != nil &&
|
||||
networkMap.PeerConfig.SshConfig != nil &&
|
||||
networkMap.PeerConfig.SshConfig.SshEnabled
|
||||
if _, ok := squashedProtocols[mgmProto.RuleProtocol_ALL]; ok {
|
||||
enableSSH = enableSSH && !ok
|
||||
}
|
||||
if _, ok := squashedProtocols[mgmProto.RuleProtocol_TCP]; ok {
|
||||
enableSSH = enableSSH && !ok
|
||||
}
|
||||
|
||||
// If SSH enabled, add default firewall rule which accepts connection to any peer
|
||||
// in the network by SSH (TCP port defined by ssh.DefaultSSHPort).
|
||||
// if TCP protocol rules not squashed and SSH enabled
|
||||
// we add default firewall rule which accepts connection to any peer
|
||||
// in the network by SSH (TCP 22 port).
|
||||
if enableSSH {
|
||||
rules = append(rules, &mgmProto.FirewallRule{
|
||||
PeerIP: "0.0.0.0",
|
||||
@@ -356,6 +368,145 @@ func (d *DefaultManager) getPeerRuleID(
|
||||
return id.RuleID(hex.EncodeToString(md5.New().Sum([]byte(idStr))))
|
||||
}
|
||||
|
||||
// squashAcceptRules does complex logic to convert many rules which allows connection by traffic type
|
||||
// to all peers in the network map to one rule which just accepts that type of the traffic.
|
||||
//
|
||||
// NOTE: It will not squash two rules for same protocol if one covers all peers in the network,
|
||||
// but other has port definitions or has drop policy.
|
||||
func (d *DefaultManager) squashAcceptRules(
|
||||
networkMap *mgmProto.NetworkMap,
|
||||
) ([]*mgmProto.FirewallRule, map[mgmProto.RuleProtocol]struct{}) {
|
||||
totalIPs := 0
|
||||
for _, p := range append(networkMap.RemotePeers, networkMap.OfflinePeers...) {
|
||||
for range p.AllowedIps {
|
||||
totalIPs++
|
||||
}
|
||||
}
|
||||
|
||||
in := map[mgmProto.RuleProtocol]*protoMatch{}
|
||||
out := map[mgmProto.RuleProtocol]*protoMatch{}
|
||||
|
||||
// trace which type of protocols was squashed
|
||||
squashedRules := []*mgmProto.FirewallRule{}
|
||||
squashedProtocols := map[mgmProto.RuleProtocol]struct{}{}
|
||||
|
||||
// this function we use to do calculation, can we squash the rules by protocol or not.
|
||||
// We summ amount of Peers IP for given protocol we found in original rules list.
|
||||
// But we zeroed the IP's for protocol if:
|
||||
// 1. Any of the rule has DROP action type.
|
||||
// 2. Any of rule contains Port.
|
||||
//
|
||||
// We zeroed this to notify squash function that this protocol can't be squashed.
|
||||
addRuleToCalculationMap := func(i int, r *mgmProto.FirewallRule, protocols map[mgmProto.RuleProtocol]*protoMatch) {
|
||||
hasPortRestrictions := r.Action == mgmProto.RuleAction_DROP ||
|
||||
r.Port != "" || !portInfoEmpty(r.PortInfo)
|
||||
|
||||
if hasPortRestrictions {
|
||||
// Don't squash rules with port restrictions
|
||||
protocols[r.Protocol] = &protoMatch{ips: map[string]int{}}
|
||||
return
|
||||
}
|
||||
|
||||
if _, ok := protocols[r.Protocol]; !ok {
|
||||
protocols[r.Protocol] = &protoMatch{
|
||||
ips: map[string]int{},
|
||||
// store the first encountered PolicyID for this protocol
|
||||
policyID: r.PolicyID,
|
||||
}
|
||||
}
|
||||
|
||||
// special case, when we receive this all network IP address
|
||||
// it means that rules for that protocol was already optimized on the
|
||||
// management side
|
||||
if r.PeerIP == "0.0.0.0" {
|
||||
squashedRules = append(squashedRules, r)
|
||||
squashedProtocols[r.Protocol] = struct{}{}
|
||||
return
|
||||
}
|
||||
|
||||
ipset := protocols[r.Protocol].ips
|
||||
|
||||
if _, ok := ipset[r.PeerIP]; ok {
|
||||
return
|
||||
}
|
||||
ipset[r.PeerIP] = i
|
||||
}
|
||||
|
||||
for i, r := range networkMap.FirewallRules {
|
||||
// calculate squash for different directions
|
||||
if r.Direction == mgmProto.RuleDirection_IN {
|
||||
addRuleToCalculationMap(i, r, in)
|
||||
} else {
|
||||
addRuleToCalculationMap(i, r, out)
|
||||
}
|
||||
}
|
||||
|
||||
// order of squashing by protocol is important
|
||||
// only for their first element ALL, it must be done first
|
||||
protocolOrders := []mgmProto.RuleProtocol{
|
||||
mgmProto.RuleProtocol_ALL,
|
||||
mgmProto.RuleProtocol_ICMP,
|
||||
mgmProto.RuleProtocol_TCP,
|
||||
mgmProto.RuleProtocol_UDP,
|
||||
}
|
||||
|
||||
squash := func(matches map[mgmProto.RuleProtocol]*protoMatch, direction mgmProto.RuleDirection) {
|
||||
for _, protocol := range protocolOrders {
|
||||
match, ok := matches[protocol]
|
||||
if !ok || len(match.ips) != totalIPs || len(match.ips) < 2 {
|
||||
// don't squash if :
|
||||
// 1. Rules not cover all peers in the network
|
||||
// 2. Rules cover only one peer in the network.
|
||||
continue
|
||||
}
|
||||
|
||||
// add special rule 0.0.0.0 which allows all IP's in our firewall implementations
|
||||
squashedRules = append(squashedRules, &mgmProto.FirewallRule{
|
||||
PeerIP: "0.0.0.0",
|
||||
Direction: direction,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: protocol,
|
||||
PolicyID: match.policyID,
|
||||
})
|
||||
squashedProtocols[protocol] = struct{}{}
|
||||
|
||||
if protocol == mgmProto.RuleProtocol_ALL {
|
||||
// if we have ALL traffic type squashed rule
|
||||
// it allows all other type of traffic, so we can stop processing
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
squash(in, mgmProto.RuleDirection_IN)
|
||||
squash(out, mgmProto.RuleDirection_OUT)
|
||||
|
||||
// if all protocol was squashed everything is allow and we can ignore all other rules
|
||||
if _, ok := squashedProtocols[mgmProto.RuleProtocol_ALL]; ok {
|
||||
return squashedRules, squashedProtocols
|
||||
}
|
||||
|
||||
if len(squashedRules) == 0 {
|
||||
return networkMap.FirewallRules, squashedProtocols
|
||||
}
|
||||
|
||||
var rules []*mgmProto.FirewallRule
|
||||
// filter out rules which was squashed from final list
|
||||
// if we also have other not squashed rules.
|
||||
for i, r := range networkMap.FirewallRules {
|
||||
if _, ok := squashedProtocols[r.Protocol]; ok {
|
||||
if m, ok := in[r.Protocol]; ok && m.ips[r.PeerIP] == i {
|
||||
continue
|
||||
} else if m, ok := out[r.Protocol]; ok && m.ips[r.PeerIP] == i {
|
||||
continue
|
||||
}
|
||||
}
|
||||
rules = append(rules, r)
|
||||
}
|
||||
|
||||
return append(rules, squashedRules...), squashedProtocols
|
||||
}
|
||||
|
||||
// getRuleGroupingSelector takes all rule properties except IP address to build selector
|
||||
func (d *DefaultManager) getRuleGroupingSelector(rule *mgmProto.FirewallRule) string {
|
||||
return fmt.Sprintf("%v:%v:%v:%s:%v", strconv.Itoa(int(rule.Direction)), rule.Action, rule.Protocol, rule.Port, rule.PortInfo)
|
||||
|
||||
@@ -188,6 +188,492 @@ func TestDefaultManagerStateless(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestDefaultManagerSquashRules(t *testing.T) {
|
||||
networkMap := &mgmProto.NetworkMap{
|
||||
RemotePeers: []*mgmProto.RemotePeerConfig{
|
||||
{AllowedIps: []string{"10.93.0.1"}},
|
||||
{AllowedIps: []string{"10.93.0.2"}},
|
||||
{AllowedIps: []string{"10.93.0.3"}},
|
||||
{AllowedIps: []string{"10.93.0.4"}},
|
||||
},
|
||||
FirewallRules: []*mgmProto.FirewallRule{
|
||||
{
|
||||
PeerIP: "10.93.0.1",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_ALL,
|
||||
},
|
||||
{
|
||||
PeerIP: "10.93.0.2",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_ALL,
|
||||
},
|
||||
{
|
||||
PeerIP: "10.93.0.3",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_ALL,
|
||||
},
|
||||
{
|
||||
PeerIP: "10.93.0.4",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_ALL,
|
||||
},
|
||||
{
|
||||
PeerIP: "10.93.0.1",
|
||||
Direction: mgmProto.RuleDirection_OUT,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_ALL,
|
||||
},
|
||||
{
|
||||
PeerIP: "10.93.0.2",
|
||||
Direction: mgmProto.RuleDirection_OUT,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_ALL,
|
||||
},
|
||||
{
|
||||
PeerIP: "10.93.0.3",
|
||||
Direction: mgmProto.RuleDirection_OUT,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_ALL,
|
||||
},
|
||||
{
|
||||
PeerIP: "10.93.0.4",
|
||||
Direction: mgmProto.RuleDirection_OUT,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_ALL,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
manager := &DefaultManager{}
|
||||
rules, _ := manager.squashAcceptRules(networkMap)
|
||||
assert.Equal(t, 2, len(rules))
|
||||
|
||||
r := rules[0]
|
||||
assert.Equal(t, "0.0.0.0", r.PeerIP)
|
||||
assert.Equal(t, mgmProto.RuleDirection_IN, r.Direction)
|
||||
assert.Equal(t, mgmProto.RuleProtocol_ALL, r.Protocol)
|
||||
assert.Equal(t, mgmProto.RuleAction_ACCEPT, r.Action)
|
||||
|
||||
r = rules[1]
|
||||
assert.Equal(t, "0.0.0.0", r.PeerIP)
|
||||
assert.Equal(t, mgmProto.RuleDirection_OUT, r.Direction)
|
||||
assert.Equal(t, mgmProto.RuleProtocol_ALL, r.Protocol)
|
||||
assert.Equal(t, mgmProto.RuleAction_ACCEPT, r.Action)
|
||||
}
|
||||
|
||||
func TestDefaultManagerSquashRulesNoAffect(t *testing.T) {
|
||||
networkMap := &mgmProto.NetworkMap{
|
||||
RemotePeers: []*mgmProto.RemotePeerConfig{
|
||||
{AllowedIps: []string{"10.93.0.1"}},
|
||||
{AllowedIps: []string{"10.93.0.2"}},
|
||||
{AllowedIps: []string{"10.93.0.3"}},
|
||||
{AllowedIps: []string{"10.93.0.4"}},
|
||||
},
|
||||
FirewallRules: []*mgmProto.FirewallRule{
|
||||
{
|
||||
PeerIP: "10.93.0.1",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_ALL,
|
||||
},
|
||||
{
|
||||
PeerIP: "10.93.0.2",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_ALL,
|
||||
},
|
||||
{
|
||||
PeerIP: "10.93.0.3",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_ALL,
|
||||
},
|
||||
{
|
||||
PeerIP: "10.93.0.4",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
},
|
||||
{
|
||||
PeerIP: "10.93.0.1",
|
||||
Direction: mgmProto.RuleDirection_OUT,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_ALL,
|
||||
},
|
||||
{
|
||||
PeerIP: "10.93.0.2",
|
||||
Direction: mgmProto.RuleDirection_OUT,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_ALL,
|
||||
},
|
||||
{
|
||||
PeerIP: "10.93.0.3",
|
||||
Direction: mgmProto.RuleDirection_OUT,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_ALL,
|
||||
},
|
||||
{
|
||||
PeerIP: "10.93.0.4",
|
||||
Direction: mgmProto.RuleDirection_OUT,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_UDP,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
manager := &DefaultManager{}
|
||||
rules, _ := manager.squashAcceptRules(networkMap)
|
||||
assert.Equal(t, len(networkMap.FirewallRules), len(rules))
|
||||
}
|
||||
|
||||
func TestDefaultManagerSquashRulesWithPortRestrictions(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
rules []*mgmProto.FirewallRule
|
||||
expectedCount int
|
||||
description string
|
||||
}{
|
||||
{
|
||||
name: "should not squash rules with port ranges",
|
||||
rules: []*mgmProto.FirewallRule{
|
||||
{
|
||||
PeerIP: "10.93.0.1",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
PortInfo: &mgmProto.PortInfo{
|
||||
PortSelection: &mgmProto.PortInfo_Range_{
|
||||
Range: &mgmProto.PortInfo_Range{
|
||||
Start: 8080,
|
||||
End: 8090,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
PeerIP: "10.93.0.2",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
PortInfo: &mgmProto.PortInfo{
|
||||
PortSelection: &mgmProto.PortInfo_Range_{
|
||||
Range: &mgmProto.PortInfo_Range{
|
||||
Start: 8080,
|
||||
End: 8090,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
PeerIP: "10.93.0.3",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
PortInfo: &mgmProto.PortInfo{
|
||||
PortSelection: &mgmProto.PortInfo_Range_{
|
||||
Range: &mgmProto.PortInfo_Range{
|
||||
Start: 8080,
|
||||
End: 8090,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
PeerIP: "10.93.0.4",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
PortInfo: &mgmProto.PortInfo{
|
||||
PortSelection: &mgmProto.PortInfo_Range_{
|
||||
Range: &mgmProto.PortInfo_Range{
|
||||
Start: 8080,
|
||||
End: 8090,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
expectedCount: 4,
|
||||
description: "Rules with port ranges should not be squashed even if they cover all peers",
|
||||
},
|
||||
{
|
||||
name: "should not squash rules with specific ports",
|
||||
rules: []*mgmProto.FirewallRule{
|
||||
{
|
||||
PeerIP: "10.93.0.1",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
PortInfo: &mgmProto.PortInfo{
|
||||
PortSelection: &mgmProto.PortInfo_Port{
|
||||
Port: 80,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
PeerIP: "10.93.0.2",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
PortInfo: &mgmProto.PortInfo{
|
||||
PortSelection: &mgmProto.PortInfo_Port{
|
||||
Port: 80,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
PeerIP: "10.93.0.3",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
PortInfo: &mgmProto.PortInfo{
|
||||
PortSelection: &mgmProto.PortInfo_Port{
|
||||
Port: 80,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
PeerIP: "10.93.0.4",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
PortInfo: &mgmProto.PortInfo{
|
||||
PortSelection: &mgmProto.PortInfo_Port{
|
||||
Port: 80,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
expectedCount: 4,
|
||||
description: "Rules with specific ports should not be squashed even if they cover all peers",
|
||||
},
|
||||
{
|
||||
name: "should not squash rules with legacy port field",
|
||||
rules: []*mgmProto.FirewallRule{
|
||||
{
|
||||
PeerIP: "10.93.0.1",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
Port: "443",
|
||||
},
|
||||
{
|
||||
PeerIP: "10.93.0.2",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
Port: "443",
|
||||
},
|
||||
{
|
||||
PeerIP: "10.93.0.3",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
Port: "443",
|
||||
},
|
||||
{
|
||||
PeerIP: "10.93.0.4",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
Port: "443",
|
||||
},
|
||||
},
|
||||
expectedCount: 4,
|
||||
description: "Rules with legacy port field should not be squashed",
|
||||
},
|
||||
{
|
||||
name: "should not squash rules with DROP action",
|
||||
rules: []*mgmProto.FirewallRule{
|
||||
{
|
||||
PeerIP: "10.93.0.1",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_DROP,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
},
|
||||
{
|
||||
PeerIP: "10.93.0.2",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_DROP,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
},
|
||||
{
|
||||
PeerIP: "10.93.0.3",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_DROP,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
},
|
||||
{
|
||||
PeerIP: "10.93.0.4",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_DROP,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
},
|
||||
},
|
||||
expectedCount: 4,
|
||||
description: "Rules with DROP action should not be squashed",
|
||||
},
|
||||
{
|
||||
name: "should squash rules without port restrictions",
|
||||
rules: []*mgmProto.FirewallRule{
|
||||
{
|
||||
PeerIP: "10.93.0.1",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
},
|
||||
{
|
||||
PeerIP: "10.93.0.2",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
},
|
||||
{
|
||||
PeerIP: "10.93.0.3",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
},
|
||||
{
|
||||
PeerIP: "10.93.0.4",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
},
|
||||
},
|
||||
expectedCount: 1,
|
||||
description: "Rules without port restrictions should be squashed into a single 0.0.0.0 rule",
|
||||
},
|
||||
{
|
||||
name: "mixed rules should not squash protocol with port restrictions",
|
||||
rules: []*mgmProto.FirewallRule{
|
||||
{
|
||||
PeerIP: "10.93.0.1",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
},
|
||||
{
|
||||
PeerIP: "10.93.0.2",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
PortInfo: &mgmProto.PortInfo{
|
||||
PortSelection: &mgmProto.PortInfo_Port{
|
||||
Port: 80,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
PeerIP: "10.93.0.3",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
},
|
||||
{
|
||||
PeerIP: "10.93.0.4",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
},
|
||||
},
|
||||
expectedCount: 4,
|
||||
description: "TCP should not be squashed because one rule has port restrictions",
|
||||
},
|
||||
{
|
||||
name: "should squash UDP but not TCP when TCP has port restrictions",
|
||||
rules: []*mgmProto.FirewallRule{
|
||||
// TCP rules with port restrictions - should NOT be squashed
|
||||
{
|
||||
PeerIP: "10.93.0.1",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
Port: "443",
|
||||
},
|
||||
{
|
||||
PeerIP: "10.93.0.2",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
Port: "443",
|
||||
},
|
||||
{
|
||||
PeerIP: "10.93.0.3",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
Port: "443",
|
||||
},
|
||||
{
|
||||
PeerIP: "10.93.0.4",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
Port: "443",
|
||||
},
|
||||
// UDP rules without port restrictions - SHOULD be squashed
|
||||
{
|
||||
PeerIP: "10.93.0.1",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_UDP,
|
||||
},
|
||||
{
|
||||
PeerIP: "10.93.0.2",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_UDP,
|
||||
},
|
||||
{
|
||||
PeerIP: "10.93.0.3",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_UDP,
|
||||
},
|
||||
{
|
||||
PeerIP: "10.93.0.4",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_UDP,
|
||||
},
|
||||
},
|
||||
expectedCount: 5, // 4 TCP rules + 1 squashed UDP rule (0.0.0.0)
|
||||
description: "UDP should be squashed to 0.0.0.0 rule, but TCP should remain as individual rules due to port restrictions",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
networkMap := &mgmProto.NetworkMap{
|
||||
RemotePeers: []*mgmProto.RemotePeerConfig{
|
||||
{AllowedIps: []string{"10.93.0.1"}},
|
||||
{AllowedIps: []string{"10.93.0.2"}},
|
||||
{AllowedIps: []string{"10.93.0.3"}},
|
||||
{AllowedIps: []string{"10.93.0.4"}},
|
||||
},
|
||||
FirewallRules: tt.rules,
|
||||
}
|
||||
|
||||
manager := &DefaultManager{}
|
||||
rules, _ := manager.squashAcceptRules(networkMap)
|
||||
|
||||
assert.Equal(t, tt.expectedCount, len(rules), tt.description)
|
||||
|
||||
// For squashed rules, verify we get the expected 0.0.0.0 rule
|
||||
if tt.expectedCount == 1 {
|
||||
assert.Equal(t, "0.0.0.0", rules[0].PeerIP)
|
||||
assert.Equal(t, mgmProto.RuleDirection_IN, rules[0].Direction)
|
||||
assert.Equal(t, mgmProto.RuleAction_ACCEPT, rules[0].Action)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPortInfoEmpty(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
|
||||
@@ -14,9 +14,6 @@ type WGIface interface {
|
||||
}
|
||||
|
||||
func (g *BundleGenerator) addWgShow() error {
|
||||
if g.statusRecorder == nil {
|
||||
return fmt.Errorf("no status recorder available for wg show")
|
||||
}
|
||||
result, err := g.statusRecorder.PeersStatus()
|
||||
if err != nil {
|
||||
return err
|
||||
|
||||
@@ -31,7 +31,6 @@ const (
|
||||
systemdDbusSetDefaultRouteMethodSuffix = systemdDbusLinkInterface + ".SetDefaultRoute"
|
||||
systemdDbusSetDomainsMethodSuffix = systemdDbusLinkInterface + ".SetDomains"
|
||||
systemdDbusSetDNSSECMethodSuffix = systemdDbusLinkInterface + ".SetDNSSEC"
|
||||
systemdDbusSetDNSOverTLSMethodSuffix = systemdDbusLinkInterface + ".SetDNSOverTLS"
|
||||
systemdDbusResolvConfModeForeign = "foreign"
|
||||
|
||||
dbusErrorUnknownObject = "org.freedesktop.DBus.Error.UnknownObject"
|
||||
@@ -103,11 +102,6 @@ func (s *systemdDbusConfigurator) applyDNSConfig(config HostDNSConfig, stateMana
|
||||
log.Warnf("failed to set DNSSEC to 'no': %v", err)
|
||||
}
|
||||
|
||||
// We don't support DNSOverTLS. On some machines this is default on so we explicitly set it to off
|
||||
if err := s.callLinkMethod(systemdDbusSetDNSOverTLSMethodSuffix, dnsSecDisabled); err != nil {
|
||||
log.Warnf("failed to set DNSOverTLS to 'no': %v", err)
|
||||
}
|
||||
|
||||
var (
|
||||
searchDomains []string
|
||||
matchDomains []string
|
||||
|
||||
@@ -1,78 +0,0 @@
|
||||
package dnsfwd
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"slices"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
type cache struct {
|
||||
mu sync.RWMutex
|
||||
records map[string]*cacheEntry
|
||||
}
|
||||
|
||||
type cacheEntry struct {
|
||||
ip4Addrs []netip.Addr
|
||||
ip6Addrs []netip.Addr
|
||||
}
|
||||
|
||||
func newCache() *cache {
|
||||
return &cache{
|
||||
records: make(map[string]*cacheEntry),
|
||||
}
|
||||
}
|
||||
|
||||
func (c *cache) get(domain string, reqType uint16) ([]netip.Addr, bool) {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
|
||||
entry, exists := c.records[normalizeDomain(domain)]
|
||||
if !exists {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
switch reqType {
|
||||
case dns.TypeA:
|
||||
return slices.Clone(entry.ip4Addrs), true
|
||||
case dns.TypeAAAA:
|
||||
return slices.Clone(entry.ip6Addrs), true
|
||||
default:
|
||||
return nil, false
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func (c *cache) set(domain string, reqType uint16, addrs []netip.Addr) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
norm := normalizeDomain(domain)
|
||||
entry, exists := c.records[norm]
|
||||
if !exists {
|
||||
entry = &cacheEntry{}
|
||||
c.records[norm] = entry
|
||||
}
|
||||
|
||||
switch reqType {
|
||||
case dns.TypeA:
|
||||
entry.ip4Addrs = slices.Clone(addrs)
|
||||
case dns.TypeAAAA:
|
||||
entry.ip6Addrs = slices.Clone(addrs)
|
||||
}
|
||||
}
|
||||
|
||||
// unset removes cached entries for the given domain and request type.
|
||||
func (c *cache) unset(domain string) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
delete(c.records, normalizeDomain(domain))
|
||||
}
|
||||
|
||||
// normalizeDomain converts an input domain into a canonical form used as cache key:
|
||||
// lowercase and fully-qualified (with trailing dot).
|
||||
func normalizeDomain(domain string) string {
|
||||
// dns.Fqdn ensures trailing dot; ToLower for consistent casing
|
||||
return dns.Fqdn(strings.ToLower(domain))
|
||||
}
|
||||
@@ -1,86 +0,0 @@
|
||||
package dnsfwd
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func mustAddr(t *testing.T, s string) netip.Addr {
|
||||
t.Helper()
|
||||
a, err := netip.ParseAddr(s)
|
||||
if err != nil {
|
||||
t.Fatalf("parse addr %s: %v", s, err)
|
||||
}
|
||||
return a
|
||||
}
|
||||
|
||||
func TestCacheNormalization(t *testing.T) {
|
||||
c := newCache()
|
||||
|
||||
// Mixed case, without trailing dot
|
||||
domainInput := "ExAmPlE.CoM"
|
||||
ipv4 := []netip.Addr{mustAddr(t, "1.2.3.4")}
|
||||
c.set(domainInput, 1 /* dns.TypeA */, ipv4)
|
||||
|
||||
// Lookup with lower, with trailing dot
|
||||
if got, ok := c.get("example.com.", 1); !ok || len(got) != 1 || got[0].String() != "1.2.3.4" {
|
||||
t.Fatalf("expected cached IPv4 result via normalized key, got=%v ok=%v", got, ok)
|
||||
}
|
||||
|
||||
// Lookup with different casing again
|
||||
if got, ok := c.get("EXAMPLE.COM", 1); !ok || len(got) != 1 || got[0].String() != "1.2.3.4" {
|
||||
t.Fatalf("expected cached IPv4 result via different casing, got=%v ok=%v", got, ok)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCacheSeparateTypes(t *testing.T) {
|
||||
c := newCache()
|
||||
|
||||
domain := "test.local"
|
||||
ipv4 := []netip.Addr{mustAddr(t, "10.0.0.1")}
|
||||
ipv6 := []netip.Addr{mustAddr(t, "2001:db8::1")}
|
||||
|
||||
c.set(domain, 1 /* A */, ipv4)
|
||||
c.set(domain, 28 /* AAAA */, ipv6)
|
||||
|
||||
got4, ok4 := c.get(domain, 1)
|
||||
if !ok4 || len(got4) != 1 || got4[0] != ipv4[0] {
|
||||
t.Fatalf("expected A record from cache, got=%v ok=%v", got4, ok4)
|
||||
}
|
||||
|
||||
got6, ok6 := c.get(domain, 28)
|
||||
if !ok6 || len(got6) != 1 || got6[0] != ipv6[0] {
|
||||
t.Fatalf("expected AAAA record from cache, got=%v ok=%v", got6, ok6)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCacheCloneOnGetAndSet(t *testing.T) {
|
||||
c := newCache()
|
||||
domain := "clone.test"
|
||||
|
||||
src := []netip.Addr{mustAddr(t, "8.8.8.8")}
|
||||
c.set(domain, 1, src)
|
||||
|
||||
// Mutate source slice; cache should be unaffected
|
||||
src[0] = mustAddr(t, "9.9.9.9")
|
||||
|
||||
got, ok := c.get(domain, 1)
|
||||
if !ok || len(got) != 1 || got[0].String() != "8.8.8.8" {
|
||||
t.Fatalf("expected cached value to be independent of source slice, got=%v ok=%v", got, ok)
|
||||
}
|
||||
|
||||
// Mutate returned slice; internal cache should remain unchanged
|
||||
got[0] = mustAddr(t, "4.4.4.4")
|
||||
got2, ok2 := c.get(domain, 1)
|
||||
if !ok2 || len(got2) != 1 || got2[0].String() != "8.8.8.8" {
|
||||
t.Fatalf("expected returned slice to be a clone, got=%v ok=%v", got2, ok2)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCacheMiss(t *testing.T) {
|
||||
c := newCache()
|
||||
if got, ok := c.get("missing.example", 1); ok || got != nil {
|
||||
t.Fatalf("expected cache miss, got=%v ok=%v", got, ok)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -46,7 +46,6 @@ type DNSForwarder struct {
|
||||
fwdEntries []*ForwarderEntry
|
||||
firewall firewaller
|
||||
resolver resolver
|
||||
cache *cache
|
||||
}
|
||||
|
||||
func NewDNSForwarder(listenAddress string, ttl uint32, firewall firewaller, statusRecorder *peer.Status) *DNSForwarder {
|
||||
@@ -57,7 +56,6 @@ func NewDNSForwarder(listenAddress string, ttl uint32, firewall firewaller, stat
|
||||
firewall: firewall,
|
||||
statusRecorder: statusRecorder,
|
||||
resolver: net.DefaultResolver,
|
||||
cache: newCache(),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -105,39 +103,10 @@ func (f *DNSForwarder) UpdateDomains(entries []*ForwarderEntry) {
|
||||
f.mutex.Lock()
|
||||
defer f.mutex.Unlock()
|
||||
|
||||
// remove cache entries for domains that no longer appear
|
||||
f.removeStaleCacheEntries(f.fwdEntries, entries)
|
||||
|
||||
f.fwdEntries = entries
|
||||
log.Debugf("Updated DNS forwarder with %d domains", len(entries))
|
||||
}
|
||||
|
||||
// removeStaleCacheEntries unsets cache items for domains that were present
|
||||
// in the old list but not present in the new list.
|
||||
func (f *DNSForwarder) removeStaleCacheEntries(oldEntries, newEntries []*ForwarderEntry) {
|
||||
if f.cache == nil {
|
||||
return
|
||||
}
|
||||
|
||||
newSet := make(map[string]struct{}, len(newEntries))
|
||||
for _, e := range newEntries {
|
||||
if e == nil {
|
||||
continue
|
||||
}
|
||||
newSet[e.Domain.PunycodeString()] = struct{}{}
|
||||
}
|
||||
|
||||
for _, e := range oldEntries {
|
||||
if e == nil {
|
||||
continue
|
||||
}
|
||||
pattern := e.Domain.PunycodeString()
|
||||
if _, ok := newSet[pattern]; !ok {
|
||||
f.cache.unset(pattern)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (f *DNSForwarder) Close(ctx context.Context) error {
|
||||
var result *multierror.Error
|
||||
|
||||
@@ -202,7 +171,6 @@ func (f *DNSForwarder) handleDNSQuery(w dns.ResponseWriter, query *dns.Msg) *dns
|
||||
|
||||
f.updateInternalState(ips, mostSpecificResId, matchingEntries)
|
||||
f.addIPsToResponse(resp, domain, ips)
|
||||
f.cache.set(domain, question.Qtype, ips)
|
||||
|
||||
return resp
|
||||
}
|
||||
@@ -314,69 +282,29 @@ func (f *DNSForwarder) setResponseCodeForNotFound(ctx context.Context, resp *dns
|
||||
resp.Rcode = dns.RcodeSuccess
|
||||
}
|
||||
|
||||
// handleDNSError processes DNS lookup errors and sends an appropriate error response.
|
||||
func (f *DNSForwarder) handleDNSError(
|
||||
ctx context.Context,
|
||||
w dns.ResponseWriter,
|
||||
question dns.Question,
|
||||
resp *dns.Msg,
|
||||
domain string,
|
||||
err error,
|
||||
) {
|
||||
// Default to SERVFAIL; override below when appropriate.
|
||||
resp.Rcode = dns.RcodeServerFailure
|
||||
|
||||
qType := question.Qtype
|
||||
qTypeName := dns.TypeToString[qType]
|
||||
|
||||
// Prefer typed DNS errors; fall back to generic logging otherwise.
|
||||
// handleDNSError processes DNS lookup errors and sends an appropriate error response
|
||||
func (f *DNSForwarder) handleDNSError(ctx context.Context, w dns.ResponseWriter, question dns.Question, resp *dns.Msg, domain string, err error) {
|
||||
var dnsErr *net.DNSError
|
||||
if !errors.As(err, &dnsErr) {
|
||||
log.Warnf(errResolveFailed, domain, err)
|
||||
if writeErr := w.WriteMsg(resp); writeErr != nil {
|
||||
log.Errorf("failed to write failure DNS response: %v", writeErr)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// NotFound: set NXDOMAIN / appropriate code via helper.
|
||||
if dnsErr.IsNotFound {
|
||||
f.setResponseCodeForNotFound(ctx, resp, domain, qType)
|
||||
if writeErr := w.WriteMsg(resp); writeErr != nil {
|
||||
log.Errorf("failed to write failure DNS response: %v", writeErr)
|
||||
switch {
|
||||
case errors.As(err, &dnsErr):
|
||||
resp.Rcode = dns.RcodeServerFailure
|
||||
if dnsErr.IsNotFound {
|
||||
f.setResponseCodeForNotFound(ctx, resp, domain, question.Qtype)
|
||||
}
|
||||
f.cache.set(domain, question.Qtype, nil)
|
||||
return
|
||||
}
|
||||
|
||||
// Upstream failed but we might have a cached answer—serve it if present.
|
||||
if ips, ok := f.cache.get(domain, qType); ok {
|
||||
if len(ips) > 0 {
|
||||
log.Debugf("serving cached DNS response after upstream failure: domain=%s type=%s", domain, qTypeName)
|
||||
f.addIPsToResponse(resp, domain, ips)
|
||||
resp.Rcode = dns.RcodeSuccess
|
||||
if writeErr := w.WriteMsg(resp); writeErr != nil {
|
||||
log.Errorf("failed to write cached DNS response: %v", writeErr)
|
||||
}
|
||||
} else { // send NXDOMAIN / appropriate code if cache is empty
|
||||
f.setResponseCodeForNotFound(ctx, resp, domain, qType)
|
||||
if writeErr := w.WriteMsg(resp); writeErr != nil {
|
||||
log.Errorf("failed to write failure DNS response: %v", writeErr)
|
||||
}
|
||||
if dnsErr.Server != "" {
|
||||
log.Warnf("failed to resolve query for type=%s domain=%s server=%s: %v", dns.TypeToString[question.Qtype], domain, dnsErr.Server, err)
|
||||
} else {
|
||||
log.Warnf(errResolveFailed, domain, err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// No cache. Log with or without the server field for more context.
|
||||
if dnsErr.Server != "" {
|
||||
log.Warnf("failed to resolve: type=%s domain=%s server=%s: %v", qTypeName, domain, dnsErr.Server, err)
|
||||
} else {
|
||||
default:
|
||||
resp.Rcode = dns.RcodeServerFailure
|
||||
log.Warnf(errResolveFailed, domain, err)
|
||||
}
|
||||
|
||||
// Write final failure response.
|
||||
if writeErr := w.WriteMsg(resp); writeErr != nil {
|
||||
log.Errorf("failed to write failure DNS response: %v", writeErr)
|
||||
if err := w.WriteMsg(resp); err != nil {
|
||||
log.Errorf("failed to write failure DNS response: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -648,95 +648,6 @@ func TestDNSForwarder_TCPTruncation(t *testing.T) {
|
||||
assert.LessOrEqual(t, writtenResp.Len(), dns.MinMsgSize, "Response should fit in minimum UDP size")
|
||||
}
|
||||
|
||||
// Ensures that when the first query succeeds and populates the cache,
|
||||
// a subsequent upstream failure still returns a successful response from cache.
|
||||
func TestDNSForwarder_ServeFromCacheOnUpstreamFailure(t *testing.T) {
|
||||
mockResolver := &MockResolver{}
|
||||
forwarder := NewDNSForwarder("127.0.0.1:0", 300, nil, &peer.Status{})
|
||||
forwarder.resolver = mockResolver
|
||||
|
||||
d, err := domain.FromString("example.com")
|
||||
require.NoError(t, err)
|
||||
entries := []*ForwarderEntry{{Domain: d, ResID: "res-cache"}}
|
||||
forwarder.UpdateDomains(entries)
|
||||
|
||||
ip := netip.MustParseAddr("1.2.3.4")
|
||||
|
||||
// First call resolves successfully and populates cache
|
||||
mockResolver.On("LookupNetIP", mock.Anything, "ip4", dns.Fqdn("example.com")).
|
||||
Return([]netip.Addr{ip}, nil).Once()
|
||||
|
||||
// Second call fails upstream; forwarder should serve from cache
|
||||
mockResolver.On("LookupNetIP", mock.Anything, "ip4", dns.Fqdn("example.com")).
|
||||
Return([]netip.Addr{}, &net.DNSError{Err: "temporary failure"}).Once()
|
||||
|
||||
// First query: populate cache
|
||||
q1 := &dns.Msg{}
|
||||
q1.SetQuestion(dns.Fqdn("example.com"), dns.TypeA)
|
||||
w1 := &test.MockResponseWriter{}
|
||||
resp1 := forwarder.handleDNSQuery(w1, q1)
|
||||
require.NotNil(t, resp1)
|
||||
require.Equal(t, dns.RcodeSuccess, resp1.Rcode)
|
||||
require.Len(t, resp1.Answer, 1)
|
||||
|
||||
// Second query: serve from cache after upstream failure
|
||||
q2 := &dns.Msg{}
|
||||
q2.SetQuestion(dns.Fqdn("example.com"), dns.TypeA)
|
||||
var writtenResp *dns.Msg
|
||||
w2 := &test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { writtenResp = m; return nil }}
|
||||
_ = forwarder.handleDNSQuery(w2, q2)
|
||||
|
||||
require.NotNil(t, writtenResp, "expected response to be written")
|
||||
require.Equal(t, dns.RcodeSuccess, writtenResp.Rcode)
|
||||
require.Len(t, writtenResp.Answer, 1)
|
||||
|
||||
mockResolver.AssertExpectations(t)
|
||||
}
|
||||
|
||||
// Verifies that cache normalization works across casing and trailing dot variations.
|
||||
func TestDNSForwarder_CacheNormalizationCasingAndDot(t *testing.T) {
|
||||
mockResolver := &MockResolver{}
|
||||
forwarder := NewDNSForwarder("127.0.0.1:0", 300, nil, &peer.Status{})
|
||||
forwarder.resolver = mockResolver
|
||||
|
||||
d, err := domain.FromString("ExAmPlE.CoM")
|
||||
require.NoError(t, err)
|
||||
entries := []*ForwarderEntry{{Domain: d, ResID: "res-norm"}}
|
||||
forwarder.UpdateDomains(entries)
|
||||
|
||||
ip := netip.MustParseAddr("9.8.7.6")
|
||||
|
||||
// Initial resolution with mixed case to populate cache
|
||||
mixedQuery := "ExAmPlE.CoM"
|
||||
mockResolver.On("LookupNetIP", mock.Anything, "ip4", dns.Fqdn(strings.ToLower(mixedQuery))).
|
||||
Return([]netip.Addr{ip}, nil).Once()
|
||||
|
||||
q1 := &dns.Msg{}
|
||||
q1.SetQuestion(mixedQuery+".", dns.TypeA)
|
||||
w1 := &test.MockResponseWriter{}
|
||||
resp1 := forwarder.handleDNSQuery(w1, q1)
|
||||
require.NotNil(t, resp1)
|
||||
require.Equal(t, dns.RcodeSuccess, resp1.Rcode)
|
||||
require.Len(t, resp1.Answer, 1)
|
||||
|
||||
// Subsequent query without dot and upper case should hit cache even if upstream fails
|
||||
// Forwarder lowercases and uses the question name as-is (no trailing dot here)
|
||||
mockResolver.On("LookupNetIP", mock.Anything, "ip4", strings.ToLower("EXAMPLE.COM")).
|
||||
Return([]netip.Addr{}, &net.DNSError{Err: "temporary failure"}).Once()
|
||||
|
||||
q2 := &dns.Msg{}
|
||||
q2.SetQuestion("EXAMPLE.COM", dns.TypeA)
|
||||
var writtenResp *dns.Msg
|
||||
w2 := &test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { writtenResp = m; return nil }}
|
||||
_ = forwarder.handleDNSQuery(w2, q2)
|
||||
|
||||
require.NotNil(t, writtenResp)
|
||||
require.Equal(t, dns.RcodeSuccess, writtenResp.Rcode)
|
||||
require.Len(t, writtenResp.Answer, 1)
|
||||
|
||||
mockResolver.AssertExpectations(t)
|
||||
}
|
||||
|
||||
func TestDNSForwarder_MultipleOverlappingPatterns(t *testing.T) {
|
||||
// Test complex overlapping pattern scenarios
|
||||
mockFirewall := &MockFirewall{}
|
||||
|
||||
@@ -40,6 +40,7 @@ type Manager struct {
|
||||
fwRules []firewall.Rule
|
||||
tcpRules []firewall.Rule
|
||||
dnsForwarder *DNSForwarder
|
||||
port uint16
|
||||
}
|
||||
|
||||
func ListenPort() uint16 {
|
||||
@@ -48,16 +49,11 @@ func ListenPort() uint16 {
|
||||
return listenPort
|
||||
}
|
||||
|
||||
func SetListenPort(port uint16) {
|
||||
listenPortMu.Lock()
|
||||
listenPort = port
|
||||
listenPortMu.Unlock()
|
||||
}
|
||||
|
||||
func NewManager(fw firewall.Manager, statusRecorder *peer.Status) *Manager {
|
||||
func NewManager(fw firewall.Manager, statusRecorder *peer.Status, port uint16) *Manager {
|
||||
return &Manager{
|
||||
firewall: fw,
|
||||
statusRecorder: statusRecorder,
|
||||
port: port,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -71,6 +67,12 @@ func (m *Manager) Start(fwdEntries []*ForwarderEntry) error {
|
||||
return err
|
||||
}
|
||||
|
||||
if m.port > 0 {
|
||||
listenPortMu.Lock()
|
||||
listenPort = m.port
|
||||
listenPortMu.Unlock()
|
||||
}
|
||||
|
||||
m.dnsForwarder = NewDNSForwarder(fmt.Sprintf(":%d", ListenPort()), dnsTTL, m.firewall, m.statusRecorder)
|
||||
go func() {
|
||||
if err := m.dnsForwarder.Listen(fwdEntries); err != nil {
|
||||
|
||||
@@ -1849,10 +1849,6 @@ func (e *Engine) updateDNSForwarder(
|
||||
return
|
||||
}
|
||||
|
||||
if forwarderPort > 0 {
|
||||
dnsfwd.SetListenPort(forwarderPort)
|
||||
}
|
||||
|
||||
if !enabled {
|
||||
if e.dnsForwardMgr == nil {
|
||||
return
|
||||
@@ -1866,7 +1862,7 @@ func (e *Engine) updateDNSForwarder(
|
||||
if len(fwdEntries) > 0 {
|
||||
switch {
|
||||
case e.dnsForwardMgr == nil:
|
||||
e.dnsForwardMgr = dnsfwd.NewManager(e.firewall, e.statusRecorder)
|
||||
e.dnsForwardMgr = dnsfwd.NewManager(e.firewall, e.statusRecorder, forwarderPort)
|
||||
if err := e.dnsForwardMgr.Start(fwdEntries); err != nil {
|
||||
log.Errorf("failed to start DNS forward: %v", err)
|
||||
e.dnsForwardMgr = nil
|
||||
@@ -1896,7 +1892,7 @@ func (e *Engine) restartDnsFwd(fwdEntries []*dnsfwd.ForwarderEntry, forwarderPor
|
||||
if err := e.dnsForwardMgr.Stop(context.Background()); err != nil {
|
||||
log.Errorf("failed to stop DNS forward: %v", err)
|
||||
}
|
||||
e.dnsForwardMgr = dnsfwd.NewManager(e.firewall, e.statusRecorder)
|
||||
e.dnsForwardMgr = dnsfwd.NewManager(e.firewall, e.statusRecorder, forwarderPort)
|
||||
if err := e.dnsForwardMgr.Start(fwdEntries); err != nil {
|
||||
log.Errorf("failed to start DNS forward: %v", err)
|
||||
e.dnsForwardMgr = nil
|
||||
|
||||
@@ -1,82 +0,0 @@
|
||||
package activity
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"net"
|
||||
"time"
|
||||
)
|
||||
|
||||
// lazyConn detects activity when WireGuard attempts to send packets.
|
||||
// It does not deliver packets, only signals that activity occurred.
|
||||
type lazyConn struct {
|
||||
activityCh chan struct{}
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
}
|
||||
|
||||
// newLazyConn creates a new lazyConn for activity detection.
|
||||
func newLazyConn() *lazyConn {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
return &lazyConn{
|
||||
activityCh: make(chan struct{}, 1),
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
}
|
||||
}
|
||||
|
||||
// Read blocks until the connection is closed.
|
||||
func (c *lazyConn) Read(_ []byte) (n int, err error) {
|
||||
<-c.ctx.Done()
|
||||
return 0, io.EOF
|
||||
}
|
||||
|
||||
// Write signals activity detection when ICEBind routes packets to this endpoint.
|
||||
func (c *lazyConn) Write(b []byte) (n int, err error) {
|
||||
if c.ctx.Err() != nil {
|
||||
return 0, io.EOF
|
||||
}
|
||||
|
||||
select {
|
||||
case c.activityCh <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
|
||||
return len(b), nil
|
||||
}
|
||||
|
||||
// ActivityChan returns the channel that signals when activity is detected.
|
||||
func (c *lazyConn) ActivityChan() <-chan struct{} {
|
||||
return c.activityCh
|
||||
}
|
||||
|
||||
// Close closes the connection.
|
||||
func (c *lazyConn) Close() error {
|
||||
c.cancel()
|
||||
return nil
|
||||
}
|
||||
|
||||
// LocalAddr returns the local address.
|
||||
func (c *lazyConn) LocalAddr() net.Addr {
|
||||
return &net.UDPAddr{IP: net.IP{127, 0, 0, 1}, Port: lazyBindPort}
|
||||
}
|
||||
|
||||
// RemoteAddr returns the remote address.
|
||||
func (c *lazyConn) RemoteAddr() net.Addr {
|
||||
return &net.UDPAddr{IP: net.IP{127, 0, 0, 1}, Port: lazyBindPort}
|
||||
}
|
||||
|
||||
// SetDeadline sets the read and write deadlines.
|
||||
func (c *lazyConn) SetDeadline(_ time.Time) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetReadDeadline sets the deadline for future Read calls.
|
||||
func (c *lazyConn) SetReadDeadline(_ time.Time) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetWriteDeadline sets the deadline for future Write calls.
|
||||
func (c *lazyConn) SetWriteDeadline(_ time.Time) error {
|
||||
return nil
|
||||
}
|
||||
@@ -11,27 +11,26 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal/lazyconn"
|
||||
)
|
||||
|
||||
// UDPListener uses UDP sockets for activity detection in kernel mode.
|
||||
type UDPListener struct {
|
||||
// Listener it is not a thread safe implementation, do not call Close before ReadPackets. It will cause blocking
|
||||
type Listener struct {
|
||||
wgIface WgInterface
|
||||
peerCfg lazyconn.PeerConfig
|
||||
conn *net.UDPConn
|
||||
endpoint *net.UDPAddr
|
||||
done sync.Mutex
|
||||
|
||||
isClosed atomic.Bool
|
||||
isClosed atomic.Bool // use to avoid error log when closing the listener
|
||||
}
|
||||
|
||||
// NewUDPListener creates a listener that detects activity via UDP socket reads.
|
||||
func NewUDPListener(wgIface WgInterface, cfg lazyconn.PeerConfig) (*UDPListener, error) {
|
||||
d := &UDPListener{
|
||||
func NewListener(wgIface WgInterface, cfg lazyconn.PeerConfig) (*Listener, error) {
|
||||
d := &Listener{
|
||||
wgIface: wgIface,
|
||||
peerCfg: cfg,
|
||||
}
|
||||
|
||||
conn, err := d.newConn()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create UDP connection: %v", err)
|
||||
return nil, fmt.Errorf("failed to creating activity listener: %v", err)
|
||||
}
|
||||
d.conn = conn
|
||||
d.endpoint = conn.LocalAddr().(*net.UDPAddr)
|
||||
@@ -39,14 +38,12 @@ func NewUDPListener(wgIface WgInterface, cfg lazyconn.PeerConfig) (*UDPListener,
|
||||
if err := d.createEndpoint(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
d.done.Lock()
|
||||
cfg.Log.Infof("created activity listener: %s", d.conn.LocalAddr().(*net.UDPAddr).String())
|
||||
cfg.Log.Infof("created activity listener: %s", conn.LocalAddr().(*net.UDPAddr).String())
|
||||
return d, nil
|
||||
}
|
||||
|
||||
// ReadPackets blocks reading from the UDP socket until activity is detected or the listener is closed.
|
||||
func (d *UDPListener) ReadPackets() {
|
||||
func (d *Listener) ReadPackets() {
|
||||
for {
|
||||
n, remoteAddr, err := d.conn.ReadFromUDP(make([]byte, 1))
|
||||
if err != nil {
|
||||
@@ -67,17 +64,15 @@ func (d *UDPListener) ReadPackets() {
|
||||
}
|
||||
|
||||
d.peerCfg.Log.Debugf("removing lazy endpoint: %s", d.endpoint.String())
|
||||
if err := d.wgIface.RemovePeer(d.peerCfg.PublicKey); err != nil {
|
||||
if err := d.removeEndpoint(); err != nil {
|
||||
d.peerCfg.Log.Errorf("failed to remove endpoint: %s", err)
|
||||
}
|
||||
|
||||
// Ignore close error as it may return "use of closed network connection" if already closed.
|
||||
_ = d.conn.Close()
|
||||
_ = d.conn.Close() // do not care err because some cases it will return "use of closed network connection"
|
||||
d.done.Unlock()
|
||||
}
|
||||
|
||||
// Close stops the listener and cleans up resources.
|
||||
func (d *UDPListener) Close() {
|
||||
func (d *Listener) Close() {
|
||||
d.peerCfg.Log.Infof("closing activity listener: %s", d.conn.LocalAddr().String())
|
||||
d.isClosed.Store(true)
|
||||
|
||||
@@ -87,12 +82,16 @@ func (d *UDPListener) Close() {
|
||||
d.done.Lock()
|
||||
}
|
||||
|
||||
func (d *UDPListener) createEndpoint() error {
|
||||
func (d *Listener) removeEndpoint() error {
|
||||
return d.wgIface.RemovePeer(d.peerCfg.PublicKey)
|
||||
}
|
||||
|
||||
func (d *Listener) createEndpoint() error {
|
||||
d.peerCfg.Log.Debugf("creating lazy endpoint: %s", d.endpoint.String())
|
||||
return d.wgIface.UpdatePeer(d.peerCfg.PublicKey, d.peerCfg.AllowedIPs, 0, d.endpoint, nil)
|
||||
}
|
||||
|
||||
func (d *UDPListener) newConn() (*net.UDPConn, error) {
|
||||
func (d *Listener) newConn() (*net.UDPConn, error) {
|
||||
addr := &net.UDPAddr{
|
||||
Port: 0,
|
||||
IP: listenIP,
|
||||
@@ -1,127 +0,0 @@
|
||||
package activity
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"sync"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
"github.com/netbirdio/netbird/client/internal/lazyconn"
|
||||
)
|
||||
|
||||
type bindProvider interface {
|
||||
GetBind() device.EndpointManager
|
||||
}
|
||||
|
||||
const (
|
||||
// lazyBindPort is an obscure port used for lazy peer endpoints to avoid confusion with real peers.
|
||||
// The actual routing is done via fakeIP in ICEBind, not by this port.
|
||||
lazyBindPort = 17473
|
||||
)
|
||||
|
||||
// BindListener uses lazyConn with bind implementations for direct data passing in userspace bind mode.
|
||||
type BindListener struct {
|
||||
wgIface WgInterface
|
||||
peerCfg lazyconn.PeerConfig
|
||||
done sync.WaitGroup
|
||||
|
||||
lazyConn *lazyConn
|
||||
bind device.EndpointManager
|
||||
fakeIP netip.Addr
|
||||
}
|
||||
|
||||
// NewBindListener creates a listener that passes data directly through bind using LazyConn.
|
||||
// It automatically derives a unique fake IP from the peer's NetBird IP in the 127.2.x.x range.
|
||||
func NewBindListener(wgIface WgInterface, bind device.EndpointManager, cfg lazyconn.PeerConfig) (*BindListener, error) {
|
||||
fakeIP, err := deriveFakeIP(wgIface, cfg.AllowedIPs)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("derive fake IP: %w", err)
|
||||
}
|
||||
|
||||
d := &BindListener{
|
||||
wgIface: wgIface,
|
||||
peerCfg: cfg,
|
||||
bind: bind,
|
||||
fakeIP: fakeIP,
|
||||
}
|
||||
|
||||
if err := d.setupLazyConn(); err != nil {
|
||||
return nil, fmt.Errorf("setup lazy connection: %v", err)
|
||||
}
|
||||
|
||||
d.done.Add(1)
|
||||
return d, nil
|
||||
}
|
||||
|
||||
// deriveFakeIP creates a deterministic fake IP for bind mode based on peer's NetBird IP.
|
||||
// Maps peer IP 100.64.x.y to fake IP 127.2.x.y (similar to relay proxy using 127.1.x.y).
|
||||
// It finds the peer's actual NetBird IP by checking which allowedIP is in the same subnet as our WG interface.
|
||||
func deriveFakeIP(wgIface WgInterface, allowedIPs []netip.Prefix) (netip.Addr, error) {
|
||||
if len(allowedIPs) == 0 {
|
||||
return netip.Addr{}, fmt.Errorf("no allowed IPs for peer")
|
||||
}
|
||||
|
||||
ourNetwork := wgIface.Address().Network
|
||||
|
||||
var peerIP netip.Addr
|
||||
for _, allowedIP := range allowedIPs {
|
||||
ip := allowedIP.Addr()
|
||||
if !ip.Is4() {
|
||||
continue
|
||||
}
|
||||
if ourNetwork.Contains(ip) {
|
||||
peerIP = ip
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !peerIP.IsValid() {
|
||||
return netip.Addr{}, fmt.Errorf("no peer NetBird IP found in allowed IPs")
|
||||
}
|
||||
|
||||
octets := peerIP.As4()
|
||||
fakeIP := netip.AddrFrom4([4]byte{127, 2, octets[2], octets[3]})
|
||||
return fakeIP, nil
|
||||
}
|
||||
|
||||
func (d *BindListener) setupLazyConn() error {
|
||||
d.lazyConn = newLazyConn()
|
||||
d.bind.SetEndpoint(d.fakeIP, d.lazyConn)
|
||||
|
||||
endpoint := &net.UDPAddr{
|
||||
IP: d.fakeIP.AsSlice(),
|
||||
Port: lazyBindPort,
|
||||
}
|
||||
return d.wgIface.UpdatePeer(d.peerCfg.PublicKey, d.peerCfg.AllowedIPs, 0, endpoint, nil)
|
||||
}
|
||||
|
||||
// ReadPackets blocks until activity is detected on the LazyConn or the listener is closed.
|
||||
func (d *BindListener) ReadPackets() {
|
||||
select {
|
||||
case <-d.lazyConn.ActivityChan():
|
||||
d.peerCfg.Log.Infof("activity detected via LazyConn")
|
||||
case <-d.lazyConn.ctx.Done():
|
||||
d.peerCfg.Log.Infof("exit from activity listener")
|
||||
}
|
||||
|
||||
d.peerCfg.Log.Debugf("removing lazy endpoint for peer %s", d.peerCfg.PublicKey)
|
||||
if err := d.wgIface.RemovePeer(d.peerCfg.PublicKey); err != nil {
|
||||
d.peerCfg.Log.Errorf("failed to remove endpoint: %s", err)
|
||||
}
|
||||
|
||||
_ = d.lazyConn.Close()
|
||||
d.bind.RemoveEndpoint(d.fakeIP)
|
||||
d.done.Done()
|
||||
}
|
||||
|
||||
// Close stops the listener and cleans up resources.
|
||||
func (d *BindListener) Close() {
|
||||
d.peerCfg.Log.Infof("closing activity listener (LazyConn)")
|
||||
|
||||
if err := d.lazyConn.Close(); err != nil {
|
||||
d.peerCfg.Log.Errorf("failed to close LazyConn: %s", err)
|
||||
}
|
||||
|
||||
d.done.Wait()
|
||||
}
|
||||
@@ -1,291 +0,0 @@
|
||||
package activity
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"runtime"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
"github.com/netbirdio/netbird/client/internal/lazyconn"
|
||||
peerid "github.com/netbirdio/netbird/client/internal/peer/id"
|
||||
)
|
||||
|
||||
func isBindListenerPlatform() bool {
|
||||
return runtime.GOOS == "windows" || runtime.GOOS == "js"
|
||||
}
|
||||
|
||||
// mockEndpointManager implements device.EndpointManager for testing
|
||||
type mockEndpointManager struct {
|
||||
endpoints map[netip.Addr]net.Conn
|
||||
}
|
||||
|
||||
func newMockEndpointManager() *mockEndpointManager {
|
||||
return &mockEndpointManager{
|
||||
endpoints: make(map[netip.Addr]net.Conn),
|
||||
}
|
||||
}
|
||||
|
||||
func (m *mockEndpointManager) SetEndpoint(fakeIP netip.Addr, conn net.Conn) {
|
||||
m.endpoints[fakeIP] = conn
|
||||
}
|
||||
|
||||
func (m *mockEndpointManager) RemoveEndpoint(fakeIP netip.Addr) {
|
||||
delete(m.endpoints, fakeIP)
|
||||
}
|
||||
|
||||
func (m *mockEndpointManager) GetEndpoint(fakeIP netip.Addr) net.Conn {
|
||||
return m.endpoints[fakeIP]
|
||||
}
|
||||
|
||||
// MockWGIfaceBind mocks WgInterface with bind support
|
||||
type MockWGIfaceBind struct {
|
||||
endpointMgr *mockEndpointManager
|
||||
}
|
||||
|
||||
func (m *MockWGIfaceBind) RemovePeer(string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MockWGIfaceBind) UpdatePeer(string, []netip.Prefix, time.Duration, *net.UDPAddr, *wgtypes.Key) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MockWGIfaceBind) IsUserspaceBind() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (m *MockWGIfaceBind) Address() wgaddr.Address {
|
||||
return wgaddr.Address{
|
||||
IP: netip.MustParseAddr("100.64.0.1"),
|
||||
Network: netip.MustParsePrefix("100.64.0.0/16"),
|
||||
}
|
||||
}
|
||||
|
||||
func (m *MockWGIfaceBind) GetBind() device.EndpointManager {
|
||||
return m.endpointMgr
|
||||
}
|
||||
|
||||
func TestBindListener_Creation(t *testing.T) {
|
||||
mockEndpointMgr := newMockEndpointManager()
|
||||
mockIface := &MockWGIfaceBind{endpointMgr: mockEndpointMgr}
|
||||
|
||||
peer := &MocPeer{PeerID: "testPeer1"}
|
||||
cfg := lazyconn.PeerConfig{
|
||||
PublicKey: peer.PeerID,
|
||||
PeerConnID: peer.ConnID(),
|
||||
AllowedIPs: []netip.Prefix{netip.MustParsePrefix("100.64.0.2/32")},
|
||||
Log: log.WithField("peer", "testPeer1"),
|
||||
}
|
||||
|
||||
listener, err := NewBindListener(mockIface, mockEndpointMgr, cfg)
|
||||
require.NoError(t, err)
|
||||
|
||||
expectedFakeIP := netip.MustParseAddr("127.2.0.2")
|
||||
conn := mockEndpointMgr.GetEndpoint(expectedFakeIP)
|
||||
require.NotNil(t, conn, "Endpoint should be registered in mock endpoint manager")
|
||||
|
||||
_, ok := conn.(*lazyConn)
|
||||
assert.True(t, ok, "Registered endpoint should be a lazyConn")
|
||||
|
||||
readPacketsDone := make(chan struct{})
|
||||
go func() {
|
||||
listener.ReadPackets()
|
||||
close(readPacketsDone)
|
||||
}()
|
||||
|
||||
listener.Close()
|
||||
|
||||
select {
|
||||
case <-readPacketsDone:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("timeout waiting for ReadPackets to exit after Close")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBindListener_ActivityDetection(t *testing.T) {
|
||||
mockEndpointMgr := newMockEndpointManager()
|
||||
mockIface := &MockWGIfaceBind{endpointMgr: mockEndpointMgr}
|
||||
|
||||
peer := &MocPeer{PeerID: "testPeer1"}
|
||||
cfg := lazyconn.PeerConfig{
|
||||
PublicKey: peer.PeerID,
|
||||
PeerConnID: peer.ConnID(),
|
||||
AllowedIPs: []netip.Prefix{netip.MustParsePrefix("100.64.0.2/32")},
|
||||
Log: log.WithField("peer", "testPeer1"),
|
||||
}
|
||||
|
||||
listener, err := NewBindListener(mockIface, mockEndpointMgr, cfg)
|
||||
require.NoError(t, err)
|
||||
|
||||
activityDetected := make(chan struct{})
|
||||
go func() {
|
||||
listener.ReadPackets()
|
||||
close(activityDetected)
|
||||
}()
|
||||
|
||||
fakeIP := listener.fakeIP
|
||||
conn := mockEndpointMgr.GetEndpoint(fakeIP)
|
||||
require.NotNil(t, conn, "Endpoint should be registered")
|
||||
|
||||
_, err = conn.Write([]byte{0x01, 0x02, 0x03})
|
||||
require.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-activityDetected:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("timeout waiting for activity detection")
|
||||
}
|
||||
|
||||
assert.Nil(t, mockEndpointMgr.GetEndpoint(fakeIP), "Endpoint should be removed after activity detection")
|
||||
}
|
||||
|
||||
func TestBindListener_Close(t *testing.T) {
|
||||
mockEndpointMgr := newMockEndpointManager()
|
||||
mockIface := &MockWGIfaceBind{endpointMgr: mockEndpointMgr}
|
||||
|
||||
peer := &MocPeer{PeerID: "testPeer1"}
|
||||
cfg := lazyconn.PeerConfig{
|
||||
PublicKey: peer.PeerID,
|
||||
PeerConnID: peer.ConnID(),
|
||||
AllowedIPs: []netip.Prefix{netip.MustParsePrefix("100.64.0.2/32")},
|
||||
Log: log.WithField("peer", "testPeer1"),
|
||||
}
|
||||
|
||||
listener, err := NewBindListener(mockIface, mockEndpointMgr, cfg)
|
||||
require.NoError(t, err)
|
||||
|
||||
readPacketsDone := make(chan struct{})
|
||||
go func() {
|
||||
listener.ReadPackets()
|
||||
close(readPacketsDone)
|
||||
}()
|
||||
|
||||
fakeIP := listener.fakeIP
|
||||
listener.Close()
|
||||
|
||||
select {
|
||||
case <-readPacketsDone:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("timeout waiting for ReadPackets to exit after Close")
|
||||
}
|
||||
|
||||
assert.Nil(t, mockEndpointMgr.GetEndpoint(fakeIP), "Endpoint should be removed after Close")
|
||||
}
|
||||
|
||||
func TestManager_BindMode(t *testing.T) {
|
||||
if !isBindListenerPlatform() {
|
||||
t.Skip("BindListener only used on Windows/JS platforms")
|
||||
}
|
||||
|
||||
mockEndpointMgr := newMockEndpointManager()
|
||||
mockIface := &MockWGIfaceBind{endpointMgr: mockEndpointMgr}
|
||||
|
||||
peer := &MocPeer{PeerID: "testPeer1"}
|
||||
mgr := NewManager(mockIface)
|
||||
defer mgr.Close()
|
||||
|
||||
cfg := lazyconn.PeerConfig{
|
||||
PublicKey: peer.PeerID,
|
||||
PeerConnID: peer.ConnID(),
|
||||
AllowedIPs: []netip.Prefix{netip.MustParsePrefix("100.64.0.2/32")},
|
||||
Log: log.WithField("peer", "testPeer1"),
|
||||
}
|
||||
|
||||
err := mgr.MonitorPeerActivity(cfg)
|
||||
require.NoError(t, err)
|
||||
|
||||
listener, exists := mgr.GetPeerListener(cfg.PeerConnID)
|
||||
require.True(t, exists, "Peer listener should be found")
|
||||
|
||||
bindListener, ok := listener.(*BindListener)
|
||||
require.True(t, ok, "Listener should be BindListener, got %T", listener)
|
||||
|
||||
fakeIP := bindListener.fakeIP
|
||||
conn := mockEndpointMgr.GetEndpoint(fakeIP)
|
||||
require.NotNil(t, conn, "Endpoint should be registered")
|
||||
|
||||
_, err = conn.Write([]byte{0x01, 0x02, 0x03})
|
||||
require.NoError(t, err)
|
||||
|
||||
select {
|
||||
case peerConnID := <-mgr.OnActivityChan:
|
||||
assert.Equal(t, cfg.PeerConnID, peerConnID, "Received peer connection ID should match")
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("timeout waiting for activity notification")
|
||||
}
|
||||
|
||||
assert.Nil(t, mockEndpointMgr.GetEndpoint(fakeIP), "Endpoint should be removed after activity")
|
||||
}
|
||||
|
||||
func TestManager_BindMode_MultiplePeers(t *testing.T) {
|
||||
if !isBindListenerPlatform() {
|
||||
t.Skip("BindListener only used on Windows/JS platforms")
|
||||
}
|
||||
|
||||
mockEndpointMgr := newMockEndpointManager()
|
||||
mockIface := &MockWGIfaceBind{endpointMgr: mockEndpointMgr}
|
||||
|
||||
peer1 := &MocPeer{PeerID: "testPeer1"}
|
||||
peer2 := &MocPeer{PeerID: "testPeer2"}
|
||||
mgr := NewManager(mockIface)
|
||||
defer mgr.Close()
|
||||
|
||||
cfg1 := lazyconn.PeerConfig{
|
||||
PublicKey: peer1.PeerID,
|
||||
PeerConnID: peer1.ConnID(),
|
||||
AllowedIPs: []netip.Prefix{netip.MustParsePrefix("100.64.0.2/32")},
|
||||
Log: log.WithField("peer", "testPeer1"),
|
||||
}
|
||||
|
||||
cfg2 := lazyconn.PeerConfig{
|
||||
PublicKey: peer2.PeerID,
|
||||
PeerConnID: peer2.ConnID(),
|
||||
AllowedIPs: []netip.Prefix{netip.MustParsePrefix("100.64.0.3/32")},
|
||||
Log: log.WithField("peer", "testPeer2"),
|
||||
}
|
||||
|
||||
err := mgr.MonitorPeerActivity(cfg1)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = mgr.MonitorPeerActivity(cfg2)
|
||||
require.NoError(t, err)
|
||||
|
||||
listener1, exists := mgr.GetPeerListener(cfg1.PeerConnID)
|
||||
require.True(t, exists, "Peer1 listener should be found")
|
||||
bindListener1 := listener1.(*BindListener)
|
||||
|
||||
listener2, exists := mgr.GetPeerListener(cfg2.PeerConnID)
|
||||
require.True(t, exists, "Peer2 listener should be found")
|
||||
bindListener2 := listener2.(*BindListener)
|
||||
|
||||
conn1 := mockEndpointMgr.GetEndpoint(bindListener1.fakeIP)
|
||||
require.NotNil(t, conn1, "Peer1 endpoint should be registered")
|
||||
_, err = conn1.Write([]byte{0x01})
|
||||
require.NoError(t, err)
|
||||
|
||||
conn2 := mockEndpointMgr.GetEndpoint(bindListener2.fakeIP)
|
||||
require.NotNil(t, conn2, "Peer2 endpoint should be registered")
|
||||
_, err = conn2.Write([]byte{0x02})
|
||||
require.NoError(t, err)
|
||||
|
||||
receivedPeers := make(map[peerid.ConnID]bool)
|
||||
for i := 0; i < 2; i++ {
|
||||
select {
|
||||
case peerConnID := <-mgr.OnActivityChan:
|
||||
receivedPeers[peerConnID] = true
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("timeout waiting for activity notifications")
|
||||
}
|
||||
}
|
||||
|
||||
assert.True(t, receivedPeers[cfg1.PeerConnID], "Peer1 activity should be received")
|
||||
assert.True(t, receivedPeers[cfg2.PeerConnID], "Peer2 activity should be received")
|
||||
}
|
||||
41
client/internal/lazyconn/activity/listener_test.go
Normal file
41
client/internal/lazyconn/activity/listener_test.go
Normal file
@@ -0,0 +1,41 @@
|
||||
package activity
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/lazyconn"
|
||||
)
|
||||
|
||||
func TestNewListener(t *testing.T) {
|
||||
peer := &MocPeer{
|
||||
PeerID: "examplePublicKey1",
|
||||
}
|
||||
|
||||
cfg := lazyconn.PeerConfig{
|
||||
PublicKey: peer.PeerID,
|
||||
PeerConnID: peer.ConnID(),
|
||||
Log: log.WithField("peer", "examplePublicKey1"),
|
||||
}
|
||||
|
||||
l, err := NewListener(MocWGIface{}, cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create listener: %v", err)
|
||||
}
|
||||
|
||||
chanClosed := make(chan struct{})
|
||||
go func() {
|
||||
defer close(chanClosed)
|
||||
l.ReadPackets()
|
||||
}()
|
||||
|
||||
time.Sleep(1 * time.Second)
|
||||
l.Close()
|
||||
|
||||
select {
|
||||
case <-chanClosed:
|
||||
case <-time.After(time.Second):
|
||||
}
|
||||
}
|
||||
@@ -1,110 +0,0 @@
|
||||
package activity
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/lazyconn"
|
||||
)
|
||||
|
||||
func TestUDPListener_Creation(t *testing.T) {
|
||||
mockIface := &MocWGIface{}
|
||||
|
||||
peer := &MocPeer{PeerID: "testPeer1"}
|
||||
cfg := lazyconn.PeerConfig{
|
||||
PublicKey: peer.PeerID,
|
||||
PeerConnID: peer.ConnID(),
|
||||
AllowedIPs: []netip.Prefix{netip.MustParsePrefix("100.64.0.2/32")},
|
||||
Log: log.WithField("peer", "testPeer1"),
|
||||
}
|
||||
|
||||
listener, err := NewUDPListener(mockIface, cfg)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, listener.conn)
|
||||
require.NotNil(t, listener.endpoint)
|
||||
|
||||
readPacketsDone := make(chan struct{})
|
||||
go func() {
|
||||
listener.ReadPackets()
|
||||
close(readPacketsDone)
|
||||
}()
|
||||
|
||||
listener.Close()
|
||||
|
||||
select {
|
||||
case <-readPacketsDone:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("timeout waiting for ReadPackets to exit after Close")
|
||||
}
|
||||
}
|
||||
|
||||
func TestUDPListener_ActivityDetection(t *testing.T) {
|
||||
mockIface := &MocWGIface{}
|
||||
|
||||
peer := &MocPeer{PeerID: "testPeer1"}
|
||||
cfg := lazyconn.PeerConfig{
|
||||
PublicKey: peer.PeerID,
|
||||
PeerConnID: peer.ConnID(),
|
||||
AllowedIPs: []netip.Prefix{netip.MustParsePrefix("100.64.0.2/32")},
|
||||
Log: log.WithField("peer", "testPeer1"),
|
||||
}
|
||||
|
||||
listener, err := NewUDPListener(mockIface, cfg)
|
||||
require.NoError(t, err)
|
||||
|
||||
activityDetected := make(chan struct{})
|
||||
go func() {
|
||||
listener.ReadPackets()
|
||||
close(activityDetected)
|
||||
}()
|
||||
|
||||
conn, err := net.Dial("udp", listener.conn.LocalAddr().String())
|
||||
require.NoError(t, err)
|
||||
defer conn.Close()
|
||||
|
||||
_, err = conn.Write([]byte{0x01, 0x02, 0x03})
|
||||
require.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-activityDetected:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("timeout waiting for activity detection")
|
||||
}
|
||||
}
|
||||
|
||||
func TestUDPListener_Close(t *testing.T) {
|
||||
mockIface := &MocWGIface{}
|
||||
|
||||
peer := &MocPeer{PeerID: "testPeer1"}
|
||||
cfg := lazyconn.PeerConfig{
|
||||
PublicKey: peer.PeerID,
|
||||
PeerConnID: peer.ConnID(),
|
||||
AllowedIPs: []netip.Prefix{netip.MustParsePrefix("100.64.0.2/32")},
|
||||
Log: log.WithField("peer", "testPeer1"),
|
||||
}
|
||||
|
||||
listener, err := NewUDPListener(mockIface, cfg)
|
||||
require.NoError(t, err)
|
||||
|
||||
readPacketsDone := make(chan struct{})
|
||||
go func() {
|
||||
listener.ReadPackets()
|
||||
close(readPacketsDone)
|
||||
}()
|
||||
|
||||
listener.Close()
|
||||
|
||||
select {
|
||||
case <-readPacketsDone:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("timeout waiting for ReadPackets to exit after Close")
|
||||
}
|
||||
|
||||
assert.True(t, listener.isClosed.Load(), "Listener should be marked as closed")
|
||||
}
|
||||
@@ -1,32 +1,21 @@
|
||||
package activity
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net"
|
||||
"net/netip"
|
||||
"runtime"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
"github.com/netbirdio/netbird/client/internal/lazyconn"
|
||||
peerid "github.com/netbirdio/netbird/client/internal/peer/id"
|
||||
)
|
||||
|
||||
// listener defines the contract for activity detection listeners.
|
||||
type listener interface {
|
||||
ReadPackets()
|
||||
Close()
|
||||
}
|
||||
|
||||
type WgInterface interface {
|
||||
RemovePeer(peerKey string) error
|
||||
UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error
|
||||
IsUserspaceBind() bool
|
||||
Address() wgaddr.Address
|
||||
}
|
||||
|
||||
type Manager struct {
|
||||
@@ -34,7 +23,7 @@ type Manager struct {
|
||||
|
||||
wgIface WgInterface
|
||||
|
||||
peers map[peerid.ConnID]listener
|
||||
peers map[peerid.ConnID]*Listener
|
||||
done chan struct{}
|
||||
|
||||
mu sync.Mutex
|
||||
@@ -44,7 +33,7 @@ func NewManager(wgIface WgInterface) *Manager {
|
||||
m := &Manager{
|
||||
OnActivityChan: make(chan peerid.ConnID, 1),
|
||||
wgIface: wgIface,
|
||||
peers: make(map[peerid.ConnID]listener),
|
||||
peers: make(map[peerid.ConnID]*Listener),
|
||||
done: make(chan struct{}),
|
||||
}
|
||||
return m
|
||||
@@ -59,38 +48,16 @@ func (m *Manager) MonitorPeerActivity(peerCfg lazyconn.PeerConfig) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
listener, err := m.createListener(peerCfg)
|
||||
listener, err := NewListener(m.wgIface, peerCfg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
m.peers[peerCfg.PeerConnID] = listener
|
||||
|
||||
go m.waitForTraffic(listener, peerCfg.PeerConnID)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Manager) createListener(peerCfg lazyconn.PeerConfig) (listener, error) {
|
||||
if !m.wgIface.IsUserspaceBind() {
|
||||
return NewUDPListener(m.wgIface, peerCfg)
|
||||
}
|
||||
|
||||
// BindListener is only used on Windows and JS platforms:
|
||||
// - JS: Cannot listen to UDP sockets
|
||||
// - Windows: IP_UNICAST_IF socket option forces packets out the interface the default
|
||||
// gateway points to, preventing them from reaching the loopback interface.
|
||||
// BindListener bypasses this by passing data directly through the bind.
|
||||
if runtime.GOOS != "windows" && runtime.GOOS != "js" {
|
||||
return NewUDPListener(m.wgIface, peerCfg)
|
||||
}
|
||||
|
||||
provider, ok := m.wgIface.(bindProvider)
|
||||
if !ok {
|
||||
return nil, errors.New("interface claims userspace bind but doesn't implement bindProvider")
|
||||
}
|
||||
|
||||
return NewBindListener(m.wgIface, provider.GetBind(), peerCfg)
|
||||
}
|
||||
|
||||
func (m *Manager) RemovePeer(log *log.Entry, peerConnID peerid.ConnID) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
@@ -115,8 +82,8 @@ func (m *Manager) Close() {
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Manager) waitForTraffic(l listener, peerConnID peerid.ConnID) {
|
||||
l.ReadPackets()
|
||||
func (m *Manager) waitForTraffic(listener *Listener, peerConnID peerid.ConnID) {
|
||||
listener.ReadPackets()
|
||||
|
||||
m.mu.Lock()
|
||||
if _, ok := m.peers[peerConnID]; !ok {
|
||||
|
||||
@@ -9,7 +9,6 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
"github.com/netbirdio/netbird/client/internal/lazyconn"
|
||||
peerid "github.com/netbirdio/netbird/client/internal/peer/id"
|
||||
)
|
||||
@@ -31,26 +30,16 @@ func (m MocWGIface) RemovePeer(string) error {
|
||||
|
||||
func (m MocWGIface) UpdatePeer(string, []netip.Prefix, time.Duration, *net.UDPAddr, *wgtypes.Key) error {
|
||||
return nil
|
||||
|
||||
}
|
||||
|
||||
func (m MocWGIface) IsUserspaceBind() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (m MocWGIface) Address() wgaddr.Address {
|
||||
return wgaddr.Address{
|
||||
IP: netip.MustParseAddr("100.64.0.1"),
|
||||
Network: netip.MustParsePrefix("100.64.0.0/16"),
|
||||
}
|
||||
}
|
||||
|
||||
// GetPeerListener is a test helper to access listeners
|
||||
func (m *Manager) GetPeerListener(peerConnID peerid.ConnID) (listener, bool) {
|
||||
// Add this method to the Manager struct
|
||||
func (m *Manager) GetPeerListener(peerConnID peerid.ConnID) (*Listener, bool) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
l, exists := m.peers[peerConnID]
|
||||
return l, exists
|
||||
listener, exists := m.peers[peerConnID]
|
||||
return listener, exists
|
||||
}
|
||||
|
||||
func TestManager_MonitorPeerActivity(t *testing.T) {
|
||||
@@ -76,12 +65,7 @@ func TestManager_MonitorPeerActivity(t *testing.T) {
|
||||
t.Fatalf("peer listener not found")
|
||||
}
|
||||
|
||||
// Get the UDP listener's address for triggering
|
||||
udpListener, ok := listener.(*UDPListener)
|
||||
if !ok {
|
||||
t.Fatalf("expected UDPListener")
|
||||
}
|
||||
if err := trigger(udpListener.conn.LocalAddr().String()); err != nil {
|
||||
if err := trigger(listener.conn.LocalAddr().String()); err != nil {
|
||||
t.Fatalf("failed to trigger activity: %v", err)
|
||||
}
|
||||
|
||||
@@ -113,9 +97,7 @@ func TestManager_RemovePeerActivity(t *testing.T) {
|
||||
t.Fatalf("failed to monitor peer activity: %v", err)
|
||||
}
|
||||
|
||||
listener, _ := mgr.GetPeerListener(peerCfg1.PeerConnID)
|
||||
udpListener, _ := listener.(*UDPListener)
|
||||
addr := udpListener.conn.LocalAddr().String()
|
||||
addr := mgr.peers[peerCfg1.PeerConnID].conn.LocalAddr().String()
|
||||
|
||||
mgr.RemovePeer(peerCfg1.Log, peerCfg1.PeerConnID)
|
||||
|
||||
@@ -165,8 +147,7 @@ func TestManager_MultiPeerActivity(t *testing.T) {
|
||||
t.Fatalf("peer listener for peer1 not found")
|
||||
}
|
||||
|
||||
udpListener1, _ := listener.(*UDPListener)
|
||||
if err := trigger(udpListener1.conn.LocalAddr().String()); err != nil {
|
||||
if err := trigger(listener.conn.LocalAddr().String()); err != nil {
|
||||
t.Fatalf("failed to trigger activity: %v", err)
|
||||
}
|
||||
|
||||
@@ -175,8 +156,7 @@ func TestManager_MultiPeerActivity(t *testing.T) {
|
||||
t.Fatalf("peer listener for peer2 not found")
|
||||
}
|
||||
|
||||
udpListener2, _ := listener.(*UDPListener)
|
||||
if err := trigger(udpListener2.conn.LocalAddr().String()); err != nil {
|
||||
if err := trigger(listener.conn.LocalAddr().String()); err != nil {
|
||||
t.Fatalf("failed to trigger activity: %v", err)
|
||||
}
|
||||
|
||||
|
||||
@@ -7,7 +7,6 @@ import (
|
||||
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
"github.com/netbirdio/netbird/monotime"
|
||||
)
|
||||
|
||||
@@ -15,6 +14,5 @@ type WGIface interface {
|
||||
RemovePeer(peerKey string) error
|
||||
UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error
|
||||
IsUserspaceBind() bool
|
||||
Address() wgaddr.Address
|
||||
LastActivities() map[string]monotime.Time
|
||||
}
|
||||
|
||||
@@ -171,9 +171,9 @@ func (conn *Conn) Open(engineCtx context.Context) error {
|
||||
|
||||
conn.handshaker = NewHandshaker(conn.Log, conn.config, conn.signaler, conn.workerICE, conn.workerRelay)
|
||||
|
||||
conn.handshaker.AddRelayListener(conn.workerRelay.OnNewOffer)
|
||||
conn.handshaker.AddOnNewOfferListener(conn.workerRelay.OnNewOffer)
|
||||
if !isForceRelayed() {
|
||||
conn.handshaker.AddICEListener(conn.workerICE.OnNewOffer)
|
||||
conn.handshaker.AddOnNewOfferListener(conn.workerICE.OnNewOffer)
|
||||
}
|
||||
|
||||
conn.guard = guard.NewGuard(conn.Log, conn.isConnectedOnAllWay, conn.config.Timeout, conn.srWatcher)
|
||||
|
||||
@@ -79,10 +79,10 @@ func TestConn_OnRemoteOffer(t *testing.T) {
|
||||
return
|
||||
}
|
||||
|
||||
onNewOfferChan := make(chan struct{})
|
||||
onNewOffeChan := make(chan struct{})
|
||||
|
||||
conn.handshaker.AddRelayListener(func(remoteOfferAnswer *OfferAnswer) {
|
||||
onNewOfferChan <- struct{}{}
|
||||
conn.handshaker.AddOnNewOfferListener(func(remoteOfferAnswer *OfferAnswer) {
|
||||
onNewOffeChan <- struct{}{}
|
||||
})
|
||||
|
||||
conn.OnRemoteOffer(OfferAnswer{
|
||||
@@ -98,7 +98,7 @@ func TestConn_OnRemoteOffer(t *testing.T) {
|
||||
defer cancel()
|
||||
|
||||
select {
|
||||
case <-onNewOfferChan:
|
||||
case <-onNewOffeChan:
|
||||
// success
|
||||
case <-ctx.Done():
|
||||
t.Error("expected to receive a new offer notification, but timed out")
|
||||
@@ -118,10 +118,10 @@ func TestConn_OnRemoteAnswer(t *testing.T) {
|
||||
return
|
||||
}
|
||||
|
||||
onNewOfferChan := make(chan struct{})
|
||||
onNewOffeChan := make(chan struct{})
|
||||
|
||||
conn.handshaker.AddRelayListener(func(remoteOfferAnswer *OfferAnswer) {
|
||||
onNewOfferChan <- struct{}{}
|
||||
conn.handshaker.AddOnNewOfferListener(func(remoteOfferAnswer *OfferAnswer) {
|
||||
onNewOffeChan <- struct{}{}
|
||||
})
|
||||
|
||||
conn.OnRemoteAnswer(OfferAnswer{
|
||||
@@ -136,7 +136,7 @@ func TestConn_OnRemoteAnswer(t *testing.T) {
|
||||
defer cancel()
|
||||
|
||||
select {
|
||||
case <-onNewOfferChan:
|
||||
case <-onNewOffeChan:
|
||||
// success
|
||||
case <-ctx.Done():
|
||||
t.Error("expected to receive a new offer notification, but timed out")
|
||||
|
||||
@@ -1,20 +0,0 @@
|
||||
package guard
|
||||
|
||||
import (
|
||||
"os"
|
||||
"strconv"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
envICEMonitorPeriod = "NB_ICE_MONITOR_PERIOD"
|
||||
)
|
||||
|
||||
func GetICEMonitorPeriod() time.Duration {
|
||||
if envVal := os.Getenv(envICEMonitorPeriod); envVal != "" {
|
||||
if seconds, err := strconv.Atoi(envVal); err == nil && seconds > 0 {
|
||||
return time.Duration(seconds) * time.Second
|
||||
}
|
||||
}
|
||||
return defaultCandidatesMonitorPeriod
|
||||
}
|
||||
@@ -16,8 +16,8 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
defaultCandidatesMonitorPeriod = 5 * time.Minute
|
||||
candidateGatheringTimeout = 5 * time.Second
|
||||
candidatesMonitorPeriod = 5 * time.Minute
|
||||
candidateGatheringTimeout = 5 * time.Second
|
||||
)
|
||||
|
||||
type ICEMonitor struct {
|
||||
@@ -25,19 +25,16 @@ type ICEMonitor struct {
|
||||
|
||||
iFaceDiscover stdnet.ExternalIFaceDiscover
|
||||
iceConfig icemaker.Config
|
||||
tickerPeriod time.Duration
|
||||
|
||||
currentCandidatesAddress []string
|
||||
candidatesMu sync.Mutex
|
||||
}
|
||||
|
||||
func NewICEMonitor(iFaceDiscover stdnet.ExternalIFaceDiscover, config icemaker.Config, period time.Duration) *ICEMonitor {
|
||||
log.Debugf("prepare ICE monitor with period: %s", period)
|
||||
func NewICEMonitor(iFaceDiscover stdnet.ExternalIFaceDiscover, config icemaker.Config) *ICEMonitor {
|
||||
cm := &ICEMonitor{
|
||||
ReconnectCh: make(chan struct{}, 1),
|
||||
iFaceDiscover: iFaceDiscover,
|
||||
iceConfig: config,
|
||||
tickerPeriod: period,
|
||||
}
|
||||
return cm
|
||||
}
|
||||
@@ -49,12 +46,7 @@ func (cm *ICEMonitor) Start(ctx context.Context, onChanged func()) {
|
||||
return
|
||||
}
|
||||
|
||||
// Initial check to populate the candidates for later comparison
|
||||
if _, err := cm.handleCandidateTick(ctx, ufrag, pwd); err != nil {
|
||||
log.Warnf("Failed to check initial ICE candidates: %v", err)
|
||||
}
|
||||
|
||||
ticker := time.NewTicker(cm.tickerPeriod)
|
||||
ticker := time.NewTicker(candidatesMonitorPeriod)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
|
||||
@@ -51,7 +51,7 @@ func (w *SRWatcher) Start() {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
w.cancelIceMonitor = cancel
|
||||
|
||||
iceMonitor := NewICEMonitor(w.iFaceDiscover, w.iceConfig, GetICEMonitorPeriod())
|
||||
iceMonitor := NewICEMonitor(w.iFaceDiscover, w.iceConfig)
|
||||
go iceMonitor.Start(ctx, w.onICEChanged)
|
||||
w.signalClient.SetOnReconnectedListener(w.onReconnected)
|
||||
w.relayManager.SetOnReconnectedListener(w.onReconnected)
|
||||
|
||||
@@ -44,19 +44,13 @@ type OfferAnswer struct {
|
||||
}
|
||||
|
||||
type Handshaker struct {
|
||||
mu sync.Mutex
|
||||
log *log.Entry
|
||||
config ConnConfig
|
||||
signaler *Signaler
|
||||
ice *WorkerICE
|
||||
relay *WorkerRelay
|
||||
// relayListener is not blocking because the listener is using a goroutine to process the messages
|
||||
// and it will only keep the latest message if multiple offers are received in a short time
|
||||
// this is to avoid blocking the handshaker if the listener is doing some heavy processing
|
||||
// and also to avoid processing old offers if multiple offers are received in a short time
|
||||
// the listener will always process the latest offer
|
||||
relayListener *AsyncOfferListener
|
||||
iceListener func(remoteOfferAnswer *OfferAnswer)
|
||||
mu sync.Mutex
|
||||
log *log.Entry
|
||||
config ConnConfig
|
||||
signaler *Signaler
|
||||
ice *WorkerICE
|
||||
relay *WorkerRelay
|
||||
onNewOfferListeners []*OfferListener
|
||||
|
||||
// remoteOffersCh is a channel used to wait for remote credentials to proceed with the connection
|
||||
remoteOffersCh chan OfferAnswer
|
||||
@@ -76,39 +70,28 @@ func NewHandshaker(log *log.Entry, config ConnConfig, signaler *Signaler, ice *W
|
||||
}
|
||||
}
|
||||
|
||||
func (h *Handshaker) AddRelayListener(offer func(remoteOfferAnswer *OfferAnswer)) {
|
||||
h.relayListener = NewAsyncOfferListener(offer)
|
||||
}
|
||||
|
||||
func (h *Handshaker) AddICEListener(offer func(remoteOfferAnswer *OfferAnswer)) {
|
||||
h.iceListener = offer
|
||||
func (h *Handshaker) AddOnNewOfferListener(offer func(remoteOfferAnswer *OfferAnswer)) {
|
||||
l := NewOfferListener(offer)
|
||||
h.onNewOfferListeners = append(h.onNewOfferListeners, l)
|
||||
}
|
||||
|
||||
func (h *Handshaker) Listen(ctx context.Context) {
|
||||
for {
|
||||
select {
|
||||
case remoteOfferAnswer := <-h.remoteOffersCh:
|
||||
h.log.Infof("received offer, running version %s, remote WireGuard listen port %d, session id: %s", remoteOfferAnswer.Version, remoteOfferAnswer.WgListenPort, remoteOfferAnswer.SessionIDString())
|
||||
if h.relayListener != nil {
|
||||
h.relayListener.Notify(&remoteOfferAnswer)
|
||||
}
|
||||
|
||||
if h.iceListener != nil {
|
||||
h.iceListener(&remoteOfferAnswer)
|
||||
}
|
||||
|
||||
// received confirmation from the remote peer -> ready to proceed
|
||||
if err := h.sendAnswer(); err != nil {
|
||||
h.log.Errorf("failed to send remote offer confirmation: %s", err)
|
||||
continue
|
||||
}
|
||||
for _, listener := range h.onNewOfferListeners {
|
||||
listener.Notify(&remoteOfferAnswer)
|
||||
}
|
||||
h.log.Infof("received offer, running version %s, remote WireGuard listen port %d, session id: %s", remoteOfferAnswer.Version, remoteOfferAnswer.WgListenPort, remoteOfferAnswer.SessionIDString())
|
||||
case remoteOfferAnswer := <-h.remoteAnswerCh:
|
||||
h.log.Infof("received answer, running version %s, remote WireGuard listen port %d, session id: %s", remoteOfferAnswer.Version, remoteOfferAnswer.WgListenPort, remoteOfferAnswer.SessionIDString())
|
||||
if h.relayListener != nil {
|
||||
h.relayListener.Notify(&remoteOfferAnswer)
|
||||
}
|
||||
|
||||
if h.iceListener != nil {
|
||||
h.iceListener(&remoteOfferAnswer)
|
||||
for _, listener := range h.onNewOfferListeners {
|
||||
listener.Notify(&remoteOfferAnswer)
|
||||
}
|
||||
case <-ctx.Done():
|
||||
h.log.Infof("stop listening for remote offers and answers")
|
||||
|
||||
@@ -13,20 +13,20 @@ func (oa *OfferAnswer) SessionIDString() string {
|
||||
return oa.SessionID.String()
|
||||
}
|
||||
|
||||
type AsyncOfferListener struct {
|
||||
type OfferListener struct {
|
||||
fn callbackFunc
|
||||
running bool
|
||||
latest *OfferAnswer
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
func NewAsyncOfferListener(fn callbackFunc) *AsyncOfferListener {
|
||||
return &AsyncOfferListener{
|
||||
func NewOfferListener(fn callbackFunc) *OfferListener {
|
||||
return &OfferListener{
|
||||
fn: fn,
|
||||
}
|
||||
}
|
||||
|
||||
func (o *AsyncOfferListener) Notify(remoteOfferAnswer *OfferAnswer) {
|
||||
func (o *OfferListener) Notify(remoteOfferAnswer *OfferAnswer) {
|
||||
o.mu.Lock()
|
||||
defer o.mu.Unlock()
|
||||
|
||||
|
||||
@@ -14,7 +14,7 @@ func Test_newOfferListener(t *testing.T) {
|
||||
runChan <- struct{}{}
|
||||
}
|
||||
|
||||
hl := NewAsyncOfferListener(longRunningFn)
|
||||
hl := NewOfferListener(longRunningFn)
|
||||
|
||||
hl.Notify(dummyOfferAnswer)
|
||||
hl.Notify(dummyOfferAnswer)
|
||||
|
||||
@@ -92,16 +92,23 @@ func NewWorkerICE(ctx context.Context, log *log.Entry, config ConnConfig, conn *
|
||||
func (w *WorkerICE) OnNewOffer(remoteOfferAnswer *OfferAnswer) {
|
||||
w.log.Debugf("OnNewOffer for ICE, serial: %s", remoteOfferAnswer.SessionIDString())
|
||||
w.muxAgent.Lock()
|
||||
defer w.muxAgent.Unlock()
|
||||
|
||||
if w.agent != nil || w.agentConnecting {
|
||||
if w.agentConnecting {
|
||||
w.log.Debugf("agent connection is in progress, skipping the offer")
|
||||
w.muxAgent.Unlock()
|
||||
return
|
||||
}
|
||||
|
||||
if w.agent != nil {
|
||||
// backward compatibility with old clients that do not send session ID
|
||||
if remoteOfferAnswer.SessionID == nil {
|
||||
w.log.Debugf("agent already exists, skipping the offer")
|
||||
w.muxAgent.Unlock()
|
||||
return
|
||||
}
|
||||
if w.remoteSessionID == *remoteOfferAnswer.SessionID {
|
||||
w.log.Debugf("agent already exists and session ID matches, skipping the offer: %s", remoteOfferAnswer.SessionIDString())
|
||||
w.muxAgent.Unlock()
|
||||
return
|
||||
}
|
||||
w.log.Debugf("agent already exists, recreate the connection")
|
||||
@@ -109,12 +116,6 @@ func (w *WorkerICE) OnNewOffer(remoteOfferAnswer *OfferAnswer) {
|
||||
if err := w.agent.Close(); err != nil {
|
||||
w.log.Warnf("failed to close ICE agent: %s", err)
|
||||
}
|
||||
|
||||
sessionID, err := NewICESessionID()
|
||||
if err != nil {
|
||||
w.log.Errorf("failed to create new session ID: %s", err)
|
||||
}
|
||||
w.sessionID = sessionID
|
||||
w.agent = nil
|
||||
}
|
||||
|
||||
@@ -125,23 +126,18 @@ func (w *WorkerICE) OnNewOffer(remoteOfferAnswer *OfferAnswer) {
|
||||
preferredCandidateTypes = icemaker.CandidateTypes()
|
||||
}
|
||||
|
||||
if remoteOfferAnswer.SessionID != nil {
|
||||
w.log.Debugf("recreate ICE agent: %s / %s", w.sessionID, *remoteOfferAnswer.SessionID)
|
||||
}
|
||||
w.log.Debugf("recreate ICE agent")
|
||||
dialerCtx, dialerCancel := context.WithCancel(w.ctx)
|
||||
agent, err := w.reCreateAgent(dialerCancel, preferredCandidateTypes)
|
||||
if err != nil {
|
||||
w.log.Errorf("failed to recreate ICE Agent: %s", err)
|
||||
w.muxAgent.Unlock()
|
||||
return
|
||||
}
|
||||
w.agent = agent
|
||||
w.agentDialerCancel = dialerCancel
|
||||
w.agentConnecting = true
|
||||
if remoteOfferAnswer.SessionID != nil {
|
||||
w.remoteSessionID = *remoteOfferAnswer.SessionID
|
||||
} else {
|
||||
w.remoteSessionID = ""
|
||||
}
|
||||
w.muxAgent.Unlock()
|
||||
|
||||
go w.connect(dialerCtx, agent, remoteOfferAnswer)
|
||||
}
|
||||
@@ -297,6 +293,9 @@ func (w *WorkerICE) connect(ctx context.Context, agent *icemaker.ThreadSafeAgent
|
||||
w.muxAgent.Lock()
|
||||
w.agentConnecting = false
|
||||
w.lastSuccess = time.Now()
|
||||
if remoteOfferAnswer.SessionID != nil {
|
||||
w.remoteSessionID = *remoteOfferAnswer.SessionID
|
||||
}
|
||||
w.muxAgent.Unlock()
|
||||
|
||||
// todo: the potential problem is a race between the onConnectionStateChange
|
||||
@@ -310,17 +309,16 @@ func (w *WorkerICE) closeAgent(agent *icemaker.ThreadSafeAgent, cancel context.C
|
||||
}
|
||||
|
||||
w.muxAgent.Lock()
|
||||
// todo review does it make sense to generate new session ID all the time when w.agent==agent
|
||||
sessionID, err := NewICESessionID()
|
||||
if err != nil {
|
||||
w.log.Errorf("failed to create new session ID: %s", err)
|
||||
}
|
||||
w.sessionID = sessionID
|
||||
|
||||
if w.agent == agent {
|
||||
// consider to remove from here and move to the OnNewOffer
|
||||
sessionID, err := NewICESessionID()
|
||||
if err != nil {
|
||||
w.log.Errorf("failed to create new session ID: %s", err)
|
||||
}
|
||||
w.sessionID = sessionID
|
||||
w.agent = nil
|
||||
w.agentConnecting = false
|
||||
w.remoteSessionID = ""
|
||||
}
|
||||
w.muxAgent.Unlock()
|
||||
}
|
||||
@@ -397,12 +395,11 @@ func (w *WorkerICE) onConnectionStateChange(agent *icemaker.ThreadSafeAgent, dia
|
||||
// ice.ConnectionStateClosed happens when we recreate the agent. For the P2P to TURN switch important to
|
||||
// notify the conn.onICEStateDisconnected changes to update the current used priority
|
||||
|
||||
w.closeAgent(agent, dialerCancel)
|
||||
|
||||
if w.lastKnownState == ice.ConnectionStateConnected {
|
||||
w.lastKnownState = ice.ConnectionStateDisconnected
|
||||
w.conn.onICEStateDisconnected()
|
||||
}
|
||||
w.closeAgent(agent, dialerCancel)
|
||||
default:
|
||||
return
|
||||
}
|
||||
|
||||
@@ -195,7 +195,6 @@ func createNewConfig(input ConfigInput) (*Config, error) {
|
||||
config := &Config{
|
||||
// defaults to false only for new (post 0.26) configurations
|
||||
ServerSSHAllowed: util.False(),
|
||||
WgPort: iface.DefaultWgPort,
|
||||
}
|
||||
|
||||
if _, err := config.apply(input); err != nil {
|
||||
|
||||
@@ -5,14 +5,11 @@ import (
|
||||
"errors"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/dynamic"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
)
|
||||
|
||||
@@ -144,95 +141,6 @@ func TestHiddenPreSharedKey(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewProfileDefaults(t *testing.T) {
|
||||
tempDir := t.TempDir()
|
||||
configPath := filepath.Join(tempDir, "config.json")
|
||||
|
||||
config, err := UpdateOrCreateConfig(ConfigInput{
|
||||
ConfigPath: configPath,
|
||||
})
|
||||
require.NoError(t, err, "should create new config")
|
||||
|
||||
assert.Equal(t, DefaultManagementURL, config.ManagementURL.String(), "ManagementURL should have default")
|
||||
assert.Equal(t, DefaultAdminURL, config.AdminURL.String(), "AdminURL should have default")
|
||||
assert.NotEmpty(t, config.PrivateKey, "PrivateKey should be generated")
|
||||
assert.NotEmpty(t, config.SSHKey, "SSHKey should be generated")
|
||||
assert.Equal(t, iface.WgInterfaceDefault, config.WgIface, "WgIface should have default")
|
||||
assert.Equal(t, iface.DefaultWgPort, config.WgPort, "WgPort should default to 51820")
|
||||
assert.Equal(t, uint16(iface.DefaultMTU), config.MTU, "MTU should have default")
|
||||
assert.Equal(t, dynamic.DefaultInterval, config.DNSRouteInterval, "DNSRouteInterval should have default")
|
||||
assert.NotNil(t, config.ServerSSHAllowed, "ServerSSHAllowed should be set")
|
||||
assert.NotNil(t, config.DisableNotifications, "DisableNotifications should be set")
|
||||
assert.NotEmpty(t, config.IFaceBlackList, "IFaceBlackList should have defaults")
|
||||
|
||||
if runtime.GOOS == "windows" || runtime.GOOS == "darwin" {
|
||||
assert.NotNil(t, config.NetworkMonitor, "NetworkMonitor should be set on Windows/macOS")
|
||||
assert.True(t, *config.NetworkMonitor, "NetworkMonitor should be enabled by default on Windows/macOS")
|
||||
}
|
||||
}
|
||||
|
||||
func TestWireguardPortZeroExplicit(t *testing.T) {
|
||||
tempDir := t.TempDir()
|
||||
configPath := filepath.Join(tempDir, "config.json")
|
||||
|
||||
// Create a new profile with explicit port 0 (random port)
|
||||
explicitZero := 0
|
||||
config, err := UpdateOrCreateConfig(ConfigInput{
|
||||
ConfigPath: configPath,
|
||||
WireguardPort: &explicitZero,
|
||||
})
|
||||
require.NoError(t, err, "should create config with explicit port 0")
|
||||
|
||||
assert.Equal(t, 0, config.WgPort, "WgPort should be 0 when explicitly set by user")
|
||||
|
||||
// Verify it persists
|
||||
readConfig, err := GetConfig(configPath)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 0, readConfig.WgPort, "WgPort should remain 0 after reading from file")
|
||||
}
|
||||
|
||||
func TestWireguardPortDefaultVsExplicit(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
wireguardPort *int
|
||||
expectedPort int
|
||||
description string
|
||||
}{
|
||||
{
|
||||
name: "no port specified uses default",
|
||||
wireguardPort: nil,
|
||||
expectedPort: iface.DefaultWgPort,
|
||||
description: "When user doesn't specify port, default to 51820",
|
||||
},
|
||||
{
|
||||
name: "explicit zero for random port",
|
||||
wireguardPort: func() *int { v := 0; return &v }(),
|
||||
expectedPort: 0,
|
||||
description: "When user explicitly sets 0, use 0 for random port",
|
||||
},
|
||||
{
|
||||
name: "explicit custom port",
|
||||
wireguardPort: func() *int { v := 52000; return &v }(),
|
||||
expectedPort: 52000,
|
||||
description: "When user sets custom port, use that port",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
tempDir := t.TempDir()
|
||||
configPath := filepath.Join(tempDir, "config.json")
|
||||
|
||||
config, err := UpdateOrCreateConfig(ConfigInput{
|
||||
ConfigPath: configPath,
|
||||
WireguardPort: tt.wireguardPort,
|
||||
})
|
||||
require.NoError(t, err, tt.description)
|
||||
assert.Equal(t, tt.expectedPort, config.WgPort, tt.description)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateOldManagementURL(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
|
||||
@@ -353,13 +353,6 @@ func (s *Server) SetConfig(callerCtx context.Context, msg *proto.SetConfigReques
|
||||
config.CustomDNSAddress = []byte{}
|
||||
}
|
||||
|
||||
config.ExtraIFaceBlackList = msg.ExtraIFaceBlacklist
|
||||
|
||||
if msg.DnsRouteInterval != nil {
|
||||
interval := msg.DnsRouteInterval.AsDuration()
|
||||
config.DNSRouteInterval = &interval
|
||||
}
|
||||
|
||||
config.RosenpassEnabled = msg.RosenpassEnabled
|
||||
config.RosenpassPermissive = msg.RosenpassPermissive
|
||||
config.DisableAutoConnect = msg.DisableAutoConnect
|
||||
|
||||
@@ -1,298 +0,0 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os/user"
|
||||
"path/filepath"
|
||||
"reflect"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
"google.golang.org/protobuf/types/known/durationpb"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
)
|
||||
|
||||
// TestSetConfig_AllFieldsSaved ensures that all fields in SetConfigRequest are properly saved to the config.
|
||||
// This test uses reflection to detect when new fields are added but not handled in SetConfig.
|
||||
func TestSetConfig_AllFieldsSaved(t *testing.T) {
|
||||
tempDir := t.TempDir()
|
||||
origDefaultProfileDir := profilemanager.DefaultConfigPathDir
|
||||
origDefaultConfigPath := profilemanager.DefaultConfigPath
|
||||
origActiveProfileStatePath := profilemanager.ActiveProfileStatePath
|
||||
profilemanager.ConfigDirOverride = tempDir
|
||||
profilemanager.DefaultConfigPathDir = tempDir
|
||||
profilemanager.ActiveProfileStatePath = tempDir + "/active_profile.json"
|
||||
profilemanager.DefaultConfigPath = filepath.Join(tempDir, "default.json")
|
||||
t.Cleanup(func() {
|
||||
profilemanager.DefaultConfigPathDir = origDefaultProfileDir
|
||||
profilemanager.ActiveProfileStatePath = origActiveProfileStatePath
|
||||
profilemanager.DefaultConfigPath = origDefaultConfigPath
|
||||
profilemanager.ConfigDirOverride = ""
|
||||
})
|
||||
|
||||
currUser, err := user.Current()
|
||||
require.NoError(t, err)
|
||||
|
||||
profName := "test-profile"
|
||||
|
||||
ic := profilemanager.ConfigInput{
|
||||
ConfigPath: filepath.Join(tempDir, profName+".json"),
|
||||
ManagementURL: "https://api.netbird.io:443",
|
||||
}
|
||||
_, err = profilemanager.UpdateOrCreateConfig(ic)
|
||||
require.NoError(t, err)
|
||||
|
||||
pm := profilemanager.ServiceManager{}
|
||||
err = pm.SetActiveProfileState(&profilemanager.ActiveProfileState{
|
||||
Name: profName,
|
||||
Username: currUser.Username,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
ctx := context.Background()
|
||||
s := New(ctx, "console", "", false, false)
|
||||
|
||||
rosenpassEnabled := true
|
||||
rosenpassPermissive := true
|
||||
serverSSHAllowed := true
|
||||
interfaceName := "utun100"
|
||||
wireguardPort := int64(51820)
|
||||
preSharedKey := "test-psk"
|
||||
disableAutoConnect := true
|
||||
networkMonitor := true
|
||||
disableClientRoutes := true
|
||||
disableServerRoutes := true
|
||||
disableDNS := true
|
||||
disableFirewall := true
|
||||
blockLANAccess := true
|
||||
disableNotifications := true
|
||||
lazyConnectionEnabled := true
|
||||
blockInbound := true
|
||||
mtu := int64(1280)
|
||||
|
||||
req := &proto.SetConfigRequest{
|
||||
ProfileName: profName,
|
||||
Username: currUser.Username,
|
||||
ManagementUrl: "https://new-api.netbird.io:443",
|
||||
AdminURL: "https://new-admin.netbird.io",
|
||||
RosenpassEnabled: &rosenpassEnabled,
|
||||
RosenpassPermissive: &rosenpassPermissive,
|
||||
ServerSSHAllowed: &serverSSHAllowed,
|
||||
InterfaceName: &interfaceName,
|
||||
WireguardPort: &wireguardPort,
|
||||
OptionalPreSharedKey: &preSharedKey,
|
||||
DisableAutoConnect: &disableAutoConnect,
|
||||
NetworkMonitor: &networkMonitor,
|
||||
DisableClientRoutes: &disableClientRoutes,
|
||||
DisableServerRoutes: &disableServerRoutes,
|
||||
DisableDns: &disableDNS,
|
||||
DisableFirewall: &disableFirewall,
|
||||
BlockLanAccess: &blockLANAccess,
|
||||
DisableNotifications: &disableNotifications,
|
||||
LazyConnectionEnabled: &lazyConnectionEnabled,
|
||||
BlockInbound: &blockInbound,
|
||||
NatExternalIPs: []string{"1.2.3.4", "5.6.7.8"},
|
||||
CleanNATExternalIPs: false,
|
||||
CustomDNSAddress: []byte("1.1.1.1:53"),
|
||||
ExtraIFaceBlacklist: []string{"eth1", "eth2"},
|
||||
DnsLabels: []string{"label1", "label2"},
|
||||
CleanDNSLabels: false,
|
||||
DnsRouteInterval: durationpb.New(2 * time.Minute),
|
||||
Mtu: &mtu,
|
||||
}
|
||||
|
||||
_, err = s.SetConfig(ctx, req)
|
||||
require.NoError(t, err)
|
||||
|
||||
profState := profilemanager.ActiveProfileState{
|
||||
Name: profName,
|
||||
Username: currUser.Username,
|
||||
}
|
||||
cfgPath, err := profState.FilePath()
|
||||
require.NoError(t, err)
|
||||
|
||||
cfg, err := profilemanager.GetConfig(cfgPath)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Equal(t, "https://new-api.netbird.io:443", cfg.ManagementURL.String())
|
||||
require.Equal(t, "https://new-admin.netbird.io:443", cfg.AdminURL.String())
|
||||
require.Equal(t, rosenpassEnabled, cfg.RosenpassEnabled)
|
||||
require.Equal(t, rosenpassPermissive, cfg.RosenpassPermissive)
|
||||
require.NotNil(t, cfg.ServerSSHAllowed)
|
||||
require.Equal(t, serverSSHAllowed, *cfg.ServerSSHAllowed)
|
||||
require.Equal(t, interfaceName, cfg.WgIface)
|
||||
require.Equal(t, int(wireguardPort), cfg.WgPort)
|
||||
require.Equal(t, preSharedKey, cfg.PreSharedKey)
|
||||
require.Equal(t, disableAutoConnect, cfg.DisableAutoConnect)
|
||||
require.NotNil(t, cfg.NetworkMonitor)
|
||||
require.Equal(t, networkMonitor, *cfg.NetworkMonitor)
|
||||
require.Equal(t, disableClientRoutes, cfg.DisableClientRoutes)
|
||||
require.Equal(t, disableServerRoutes, cfg.DisableServerRoutes)
|
||||
require.Equal(t, disableDNS, cfg.DisableDNS)
|
||||
require.Equal(t, disableFirewall, cfg.DisableFirewall)
|
||||
require.Equal(t, blockLANAccess, cfg.BlockLANAccess)
|
||||
require.NotNil(t, cfg.DisableNotifications)
|
||||
require.Equal(t, disableNotifications, *cfg.DisableNotifications)
|
||||
require.Equal(t, lazyConnectionEnabled, cfg.LazyConnectionEnabled)
|
||||
require.Equal(t, blockInbound, cfg.BlockInbound)
|
||||
require.Equal(t, []string{"1.2.3.4", "5.6.7.8"}, cfg.NATExternalIPs)
|
||||
require.Equal(t, "1.1.1.1:53", cfg.CustomDNSAddress)
|
||||
// IFaceBlackList contains defaults + extras
|
||||
require.Contains(t, cfg.IFaceBlackList, "eth1")
|
||||
require.Contains(t, cfg.IFaceBlackList, "eth2")
|
||||
require.Equal(t, []string{"label1", "label2"}, cfg.DNSLabels.ToPunycodeList())
|
||||
require.Equal(t, 2*time.Minute, cfg.DNSRouteInterval)
|
||||
require.Equal(t, uint16(mtu), cfg.MTU)
|
||||
|
||||
verifyAllFieldsCovered(t, req)
|
||||
}
|
||||
|
||||
// verifyAllFieldsCovered uses reflection to ensure we're testing all fields in SetConfigRequest.
|
||||
// If a new field is added to SetConfigRequest, this function will fail the test,
|
||||
// forcing the developer to update both the SetConfig handler and this test.
|
||||
func verifyAllFieldsCovered(t *testing.T, req *proto.SetConfigRequest) {
|
||||
t.Helper()
|
||||
|
||||
metadataFields := map[string]bool{
|
||||
"state": true, // protobuf internal
|
||||
"sizeCache": true, // protobuf internal
|
||||
"unknownFields": true, // protobuf internal
|
||||
"Username": true, // metadata
|
||||
"ProfileName": true, // metadata
|
||||
"CleanNATExternalIPs": true, // control flag for clearing
|
||||
"CleanDNSLabels": true, // control flag for clearing
|
||||
}
|
||||
|
||||
expectedFields := map[string]bool{
|
||||
"ManagementUrl": true,
|
||||
"AdminURL": true,
|
||||
"RosenpassEnabled": true,
|
||||
"RosenpassPermissive": true,
|
||||
"ServerSSHAllowed": true,
|
||||
"InterfaceName": true,
|
||||
"WireguardPort": true,
|
||||
"OptionalPreSharedKey": true,
|
||||
"DisableAutoConnect": true,
|
||||
"NetworkMonitor": true,
|
||||
"DisableClientRoutes": true,
|
||||
"DisableServerRoutes": true,
|
||||
"DisableDns": true,
|
||||
"DisableFirewall": true,
|
||||
"BlockLanAccess": true,
|
||||
"DisableNotifications": true,
|
||||
"LazyConnectionEnabled": true,
|
||||
"BlockInbound": true,
|
||||
"NatExternalIPs": true,
|
||||
"CustomDNSAddress": true,
|
||||
"ExtraIFaceBlacklist": true,
|
||||
"DnsLabels": true,
|
||||
"DnsRouteInterval": true,
|
||||
"Mtu": true,
|
||||
}
|
||||
|
||||
val := reflect.ValueOf(req).Elem()
|
||||
typ := val.Type()
|
||||
|
||||
var unexpectedFields []string
|
||||
for i := 0; i < val.NumField(); i++ {
|
||||
field := typ.Field(i)
|
||||
fieldName := field.Name
|
||||
|
||||
if metadataFields[fieldName] {
|
||||
continue
|
||||
}
|
||||
|
||||
if !expectedFields[fieldName] {
|
||||
unexpectedFields = append(unexpectedFields, fieldName)
|
||||
}
|
||||
}
|
||||
|
||||
if len(unexpectedFields) > 0 {
|
||||
t.Fatalf("New field(s) detected in SetConfigRequest: %v", unexpectedFields)
|
||||
}
|
||||
}
|
||||
|
||||
// TestCLIFlags_MappedToSetConfig ensures all CLI flags that modify config are properly mapped to SetConfigRequest.
|
||||
// This test catches bugs where a new CLI flag is added but not wired to the SetConfigRequest in setupSetConfigReq.
|
||||
func TestCLIFlags_MappedToSetConfig(t *testing.T) {
|
||||
// Map of CLI flag names to their corresponding SetConfigRequest field names.
|
||||
// This map must be updated when adding new config-related CLI flags.
|
||||
flagToField := map[string]string{
|
||||
"management-url": "ManagementUrl",
|
||||
"admin-url": "AdminURL",
|
||||
"enable-rosenpass": "RosenpassEnabled",
|
||||
"rosenpass-permissive": "RosenpassPermissive",
|
||||
"allow-server-ssh": "ServerSSHAllowed",
|
||||
"interface-name": "InterfaceName",
|
||||
"wireguard-port": "WireguardPort",
|
||||
"preshared-key": "OptionalPreSharedKey",
|
||||
"disable-auto-connect": "DisableAutoConnect",
|
||||
"network-monitor": "NetworkMonitor",
|
||||
"disable-client-routes": "DisableClientRoutes",
|
||||
"disable-server-routes": "DisableServerRoutes",
|
||||
"disable-dns": "DisableDns",
|
||||
"disable-firewall": "DisableFirewall",
|
||||
"block-lan-access": "BlockLanAccess",
|
||||
"block-inbound": "BlockInbound",
|
||||
"enable-lazy-connection": "LazyConnectionEnabled",
|
||||
"external-ip-map": "NatExternalIPs",
|
||||
"dns-resolver-address": "CustomDNSAddress",
|
||||
"extra-iface-blacklist": "ExtraIFaceBlacklist",
|
||||
"extra-dns-labels": "DnsLabels",
|
||||
"dns-router-interval": "DnsRouteInterval",
|
||||
"mtu": "Mtu",
|
||||
}
|
||||
|
||||
// SetConfigRequest fields that don't have CLI flags (settable only via UI or other means).
|
||||
fieldsWithoutCLIFlags := map[string]bool{
|
||||
"DisableNotifications": true, // Only settable via UI
|
||||
}
|
||||
|
||||
// Get all SetConfigRequest fields to verify our map is complete.
|
||||
req := &proto.SetConfigRequest{}
|
||||
val := reflect.ValueOf(req).Elem()
|
||||
typ := val.Type()
|
||||
|
||||
var unmappedFields []string
|
||||
for i := 0; i < val.NumField(); i++ {
|
||||
field := typ.Field(i)
|
||||
fieldName := field.Name
|
||||
|
||||
// Skip protobuf internal fields and metadata fields.
|
||||
if fieldName == "state" || fieldName == "sizeCache" || fieldName == "unknownFields" {
|
||||
continue
|
||||
}
|
||||
if fieldName == "Username" || fieldName == "ProfileName" {
|
||||
continue
|
||||
}
|
||||
if fieldName == "CleanNATExternalIPs" || fieldName == "CleanDNSLabels" {
|
||||
continue
|
||||
}
|
||||
|
||||
// Check if this field is either mapped to a CLI flag or explicitly documented as having no CLI flag.
|
||||
mappedToCLI := false
|
||||
for _, mappedField := range flagToField {
|
||||
if mappedField == fieldName {
|
||||
mappedToCLI = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
hasNoCLIFlag := fieldsWithoutCLIFlags[fieldName]
|
||||
|
||||
if !mappedToCLI && !hasNoCLIFlag {
|
||||
unmappedFields = append(unmappedFields, fieldName)
|
||||
}
|
||||
}
|
||||
|
||||
if len(unmappedFields) > 0 {
|
||||
t.Fatalf("SetConfigRequest field(s) not documented: %v\n"+
|
||||
"Either add the CLI flag to flagToField map, or if there's no CLI flag for this field, "+
|
||||
"add it to fieldsWithoutCLIFlags map with a comment explaining why.", unmappedFields)
|
||||
}
|
||||
|
||||
t.Log("All SetConfigRequest fields are properly documented")
|
||||
}
|
||||
@@ -205,18 +205,15 @@ func mapPeers(
|
||||
localICEEndpoint := ""
|
||||
remoteICEEndpoint := ""
|
||||
relayServerAddress := ""
|
||||
connType := "-"
|
||||
connType := "P2P"
|
||||
lastHandshake := time.Time{}
|
||||
transferReceived := int64(0)
|
||||
transferSent := int64(0)
|
||||
|
||||
isPeerConnected := pbPeerState.ConnStatus == peer.StatusConnected.String()
|
||||
|
||||
if isPeerConnected {
|
||||
connType = "P2P"
|
||||
if pbPeerState.Relayed {
|
||||
connType = "Relayed"
|
||||
}
|
||||
if pbPeerState.Relayed {
|
||||
connType = "Relayed"
|
||||
}
|
||||
|
||||
if skipDetailByFilters(pbPeerState, pbPeerState.ConnStatus, statusFilter, prefixNamesFilter, prefixNamesFilterMap, ipsFilter, connectionTypeFilter, connType) {
|
||||
|
||||
@@ -31,6 +31,7 @@ import (
|
||||
"fyne.io/systray"
|
||||
"github.com/cenkalti/backoff/v4"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/skratchdot/open-golang/open"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/credentials/insecure"
|
||||
@@ -632,7 +633,7 @@ func (s *serviceClient) login(openURL bool) (*proto.LoginResponse, error) {
|
||||
}
|
||||
|
||||
func (s *serviceClient) handleSSOLogin(loginResp *proto.LoginResponse, conn proto.DaemonServiceClient) error {
|
||||
err := openURL(loginResp.VerificationURIComplete)
|
||||
err := open.Run(loginResp.VerificationURIComplete)
|
||||
if err != nil {
|
||||
log.Errorf("opening the verification uri in the browser failed: %v", err)
|
||||
return err
|
||||
@@ -1486,10 +1487,6 @@ func (s *serviceClient) showLoginURL() context.CancelFunc {
|
||||
}
|
||||
|
||||
func openURL(url string) error {
|
||||
if browser := os.Getenv("BROWSER"); browser != "" {
|
||||
return exec.Command(browser, url).Start()
|
||||
}
|
||||
|
||||
var err error
|
||||
switch runtime.GOOS {
|
||||
case "windows":
|
||||
|
||||
@@ -73,8 +73,8 @@ func (p *RDCleanPathProxy) validateCertificateWithJS(conn *proxyConnection, cert
|
||||
}
|
||||
}
|
||||
|
||||
func (p *RDCleanPathProxy) getTLSConfigWithValidation(conn *proxyConnection, requiresCredSSP bool) *tls.Config {
|
||||
config := &tls.Config{
|
||||
func (p *RDCleanPathProxy) getTLSConfigWithValidation(conn *proxyConnection) *tls.Config {
|
||||
return &tls.Config{
|
||||
InsecureSkipVerify: true, // We'll validate manually after handshake
|
||||
VerifyConnection: func(cs tls.ConnectionState) error {
|
||||
var certChain [][]byte
|
||||
@@ -93,15 +93,4 @@ func (p *RDCleanPathProxy) getTLSConfigWithValidation(conn *proxyConnection, req
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
// CredSSP (NLA) requires TLS 1.2 - it's incompatible with TLS 1.3
|
||||
if requiresCredSSP {
|
||||
config.MinVersion = tls.VersionTLS12
|
||||
config.MaxVersion = tls.VersionTLS12
|
||||
} else {
|
||||
config.MinVersion = tls.VersionTLS12
|
||||
config.MaxVersion = tls.VersionTLS13
|
||||
}
|
||||
|
||||
return config
|
||||
}
|
||||
|
||||
@@ -6,13 +6,11 @@ import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"encoding/asn1"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"sync"
|
||||
"syscall/js"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
@@ -21,34 +19,18 @@ const (
|
||||
RDCleanPathVersion = 3390
|
||||
RDCleanPathProxyHost = "rdcleanpath.proxy.local"
|
||||
RDCleanPathProxyScheme = "ws"
|
||||
|
||||
rdpDialTimeout = 15 * time.Second
|
||||
|
||||
GeneralErrorCode = 1
|
||||
WSAETimedOut = 10060
|
||||
WSAEConnRefused = 10061
|
||||
WSAEConnAborted = 10053
|
||||
WSAEConnReset = 10054
|
||||
WSAEGenericError = 10050
|
||||
)
|
||||
|
||||
type RDCleanPathPDU struct {
|
||||
Version int64 `asn1:"tag:0,explicit"`
|
||||
Error RDCleanPathErr `asn1:"tag:1,explicit,optional"`
|
||||
Destination string `asn1:"utf8,tag:2,explicit,optional"`
|
||||
ProxyAuth string `asn1:"utf8,tag:3,explicit,optional"`
|
||||
ServerAuth string `asn1:"utf8,tag:4,explicit,optional"`
|
||||
PreconnectionBlob string `asn1:"utf8,tag:5,explicit,optional"`
|
||||
X224ConnectionPDU []byte `asn1:"tag:6,explicit,optional"`
|
||||
ServerCertChain [][]byte `asn1:"tag:7,explicit,optional"`
|
||||
ServerAddr string `asn1:"utf8,tag:9,explicit,optional"`
|
||||
}
|
||||
|
||||
type RDCleanPathErr struct {
|
||||
ErrorCode int16 `asn1:"tag:0,explicit"`
|
||||
HTTPStatusCode int16 `asn1:"tag:1,explicit,optional"`
|
||||
WSALastError int16 `asn1:"tag:2,explicit,optional"`
|
||||
TLSAlertCode int8 `asn1:"tag:3,explicit,optional"`
|
||||
Version int64 `asn1:"tag:0,explicit"`
|
||||
Error []byte `asn1:"tag:1,explicit,optional"`
|
||||
Destination string `asn1:"utf8,tag:2,explicit,optional"`
|
||||
ProxyAuth string `asn1:"utf8,tag:3,explicit,optional"`
|
||||
ServerAuth string `asn1:"utf8,tag:4,explicit,optional"`
|
||||
PreconnectionBlob string `asn1:"utf8,tag:5,explicit,optional"`
|
||||
X224ConnectionPDU []byte `asn1:"tag:6,explicit,optional"`
|
||||
ServerCertChain [][]byte `asn1:"tag:7,explicit,optional"`
|
||||
ServerAddr string `asn1:"utf8,tag:9,explicit,optional"`
|
||||
}
|
||||
|
||||
type RDCleanPathProxy struct {
|
||||
@@ -228,13 +210,9 @@ func (p *RDCleanPathProxy) handleDirectRDP(conn *proxyConnection, firstPacket []
|
||||
destination := conn.destination
|
||||
log.Infof("Direct RDP mode: Connecting to %s via NetBird", destination)
|
||||
|
||||
ctx, cancel := context.WithTimeout(conn.ctx, rdpDialTimeout)
|
||||
defer cancel()
|
||||
|
||||
rdpConn, err := p.nbClient.Dial(ctx, "tcp", destination)
|
||||
rdpConn, err := p.nbClient.Dial(conn.ctx, "tcp", destination)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to connect to %s: %v", destination, err)
|
||||
p.sendRDCleanPathError(conn, newWSAError(err))
|
||||
return
|
||||
}
|
||||
conn.rdpConn = rdpConn
|
||||
@@ -242,7 +220,6 @@ func (p *RDCleanPathProxy) handleDirectRDP(conn *proxyConnection, firstPacket []
|
||||
_, err = rdpConn.Write(firstPacket)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to write first packet: %v", err)
|
||||
p.sendRDCleanPathError(conn, newWSAError(err))
|
||||
return
|
||||
}
|
||||
|
||||
@@ -250,7 +227,6 @@ func (p *RDCleanPathProxy) handleDirectRDP(conn *proxyConnection, firstPacket []
|
||||
n, err := rdpConn.Read(response)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to read X.224 response: %v", err)
|
||||
p.sendRDCleanPathError(conn, newWSAError(err))
|
||||
return
|
||||
}
|
||||
|
||||
@@ -293,52 +269,3 @@ func (p *RDCleanPathProxy) sendToWebSocket(conn *proxyConnection, data []byte) {
|
||||
conn.wsHandlers.Call("send", uint8Array.Get("buffer"))
|
||||
}
|
||||
}
|
||||
|
||||
func (p *RDCleanPathProxy) sendRDCleanPathError(conn *proxyConnection, pdu RDCleanPathPDU) {
|
||||
data, err := asn1.Marshal(pdu)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to marshal error PDU: %v", err)
|
||||
return
|
||||
}
|
||||
p.sendToWebSocket(conn, data)
|
||||
}
|
||||
|
||||
func errorToWSACode(err error) int16 {
|
||||
if err == nil {
|
||||
return WSAEGenericError
|
||||
}
|
||||
var netErr *net.OpError
|
||||
if errors.As(err, &netErr) && netErr.Timeout() {
|
||||
return WSAETimedOut
|
||||
}
|
||||
if errors.Is(err, context.DeadlineExceeded) {
|
||||
return WSAETimedOut
|
||||
}
|
||||
if errors.Is(err, context.Canceled) {
|
||||
return WSAEConnAborted
|
||||
}
|
||||
if errors.Is(err, io.EOF) {
|
||||
return WSAEConnReset
|
||||
}
|
||||
return WSAEGenericError
|
||||
}
|
||||
|
||||
func newWSAError(err error) RDCleanPathPDU {
|
||||
return RDCleanPathPDU{
|
||||
Version: RDCleanPathVersion,
|
||||
Error: RDCleanPathErr{
|
||||
ErrorCode: GeneralErrorCode,
|
||||
WSALastError: errorToWSACode(err),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func newHTTPError(statusCode int16) RDCleanPathPDU {
|
||||
return RDCleanPathPDU{
|
||||
Version: RDCleanPathVersion,
|
||||
Error: RDCleanPathErr{
|
||||
ErrorCode: GeneralErrorCode,
|
||||
HTTPStatusCode: statusCode,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,7 +3,6 @@
|
||||
package rdp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"encoding/asn1"
|
||||
"io"
|
||||
@@ -12,17 +11,11 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
// MS-RDPBCGR: confusingly named, actually means PROTOCOL_HYBRID (CredSSP)
|
||||
protocolSSL = 0x00000001
|
||||
protocolHybridEx = 0x00000008
|
||||
)
|
||||
|
||||
func (p *RDCleanPathProxy) processRDCleanPathPDU(conn *proxyConnection, pdu RDCleanPathPDU) {
|
||||
log.Infof("Processing RDCleanPath PDU: Version=%d, Destination=%s", pdu.Version, pdu.Destination)
|
||||
|
||||
if pdu.Version != RDCleanPathVersion {
|
||||
p.sendRDCleanPathError(conn, newHTTPError(400))
|
||||
p.sendRDCleanPathError(conn, "Unsupported version")
|
||||
return
|
||||
}
|
||||
|
||||
@@ -31,13 +24,10 @@ func (p *RDCleanPathProxy) processRDCleanPathPDU(conn *proxyConnection, pdu RDCl
|
||||
destination = pdu.Destination
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(conn.ctx, rdpDialTimeout)
|
||||
defer cancel()
|
||||
|
||||
rdpConn, err := p.nbClient.Dial(ctx, "tcp", destination)
|
||||
rdpConn, err := p.nbClient.Dial(conn.ctx, "tcp", destination)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to connect to %s: %v", destination, err)
|
||||
p.sendRDCleanPathError(conn, newWSAError(err))
|
||||
p.sendRDCleanPathError(conn, "Connection failed")
|
||||
p.cleanupConnection(conn)
|
||||
return
|
||||
}
|
||||
@@ -50,34 +40,6 @@ func (p *RDCleanPathProxy) processRDCleanPathPDU(conn *proxyConnection, pdu RDCl
|
||||
p.setupTLSConnection(conn, pdu)
|
||||
}
|
||||
|
||||
// detectCredSSPFromX224 checks if the X.224 response indicates NLA/CredSSP is required.
|
||||
// Per MS-RDPBCGR spec: byte 11 = TYPE_RDP_NEG_RSP (0x02), bytes 15-18 = selectedProtocol flags.
|
||||
// Returns (requiresTLS12, selectedProtocol, detectionSuccessful).
|
||||
func (p *RDCleanPathProxy) detectCredSSPFromX224(x224Response []byte) (bool, uint32, bool) {
|
||||
const minResponseLength = 19
|
||||
|
||||
if len(x224Response) < minResponseLength {
|
||||
return false, 0, false
|
||||
}
|
||||
|
||||
// Per X.224 specification:
|
||||
// x224Response[0] == 0x03: Length of X.224 header (3 bytes)
|
||||
// x224Response[5] == 0xD0: X.224 Data TPDU code
|
||||
if x224Response[0] != 0x03 || x224Response[5] != 0xD0 {
|
||||
return false, 0, false
|
||||
}
|
||||
|
||||
if x224Response[11] == 0x02 {
|
||||
flags := uint32(x224Response[15]) | uint32(x224Response[16])<<8 |
|
||||
uint32(x224Response[17])<<16 | uint32(x224Response[18])<<24
|
||||
|
||||
hasNLA := (flags & (protocolSSL | protocolHybridEx)) != 0
|
||||
return hasNLA, flags, true
|
||||
}
|
||||
|
||||
return false, 0, false
|
||||
}
|
||||
|
||||
func (p *RDCleanPathProxy) setupTLSConnection(conn *proxyConnection, pdu RDCleanPathPDU) {
|
||||
var x224Response []byte
|
||||
if len(pdu.X224ConnectionPDU) > 0 {
|
||||
@@ -85,7 +47,7 @@ func (p *RDCleanPathProxy) setupTLSConnection(conn *proxyConnection, pdu RDClean
|
||||
_, err := conn.rdpConn.Write(pdu.X224ConnectionPDU)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to write X.224 PDU: %v", err)
|
||||
p.sendRDCleanPathError(conn, newWSAError(err))
|
||||
p.sendRDCleanPathError(conn, "Failed to forward X.224")
|
||||
return
|
||||
}
|
||||
|
||||
@@ -93,32 +55,21 @@ func (p *RDCleanPathProxy) setupTLSConnection(conn *proxyConnection, pdu RDClean
|
||||
n, err := conn.rdpConn.Read(response)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to read X.224 response: %v", err)
|
||||
p.sendRDCleanPathError(conn, newWSAError(err))
|
||||
p.sendRDCleanPathError(conn, "Failed to read X.224 response")
|
||||
return
|
||||
}
|
||||
x224Response = response[:n]
|
||||
log.Debugf("Received X.224 Connection Confirm (%d bytes)", n)
|
||||
}
|
||||
|
||||
requiresCredSSP, selectedProtocol, detected := p.detectCredSSPFromX224(x224Response)
|
||||
if detected {
|
||||
if requiresCredSSP {
|
||||
log.Warnf("Detected NLA/CredSSP (selectedProtocol: 0x%08X), forcing TLS 1.2 for compatibility", selectedProtocol)
|
||||
} else {
|
||||
log.Warnf("No NLA/CredSSP detected (selectedProtocol: 0x%08X), allowing up to TLS 1.3", selectedProtocol)
|
||||
}
|
||||
} else {
|
||||
log.Warnf("Could not detect RDP security protocol, allowing up to TLS 1.3")
|
||||
}
|
||||
|
||||
tlsConfig := p.getTLSConfigWithValidation(conn, requiresCredSSP)
|
||||
tlsConfig := p.getTLSConfigWithValidation(conn)
|
||||
|
||||
tlsConn := tls.Client(conn.rdpConn, tlsConfig)
|
||||
conn.tlsConn = tlsConn
|
||||
|
||||
if err := tlsConn.Handshake(); err != nil {
|
||||
log.Errorf("TLS handshake failed: %v", err)
|
||||
p.sendRDCleanPathError(conn, newWSAError(err))
|
||||
p.sendRDCleanPathError(conn, "TLS handshake failed")
|
||||
return
|
||||
}
|
||||
|
||||
@@ -155,6 +106,47 @@ func (p *RDCleanPathProxy) setupTLSConnection(conn *proxyConnection, pdu RDClean
|
||||
p.cleanupConnection(conn)
|
||||
}
|
||||
|
||||
func (p *RDCleanPathProxy) setupPlainConnection(conn *proxyConnection, pdu RDCleanPathPDU) {
|
||||
if len(pdu.X224ConnectionPDU) > 0 {
|
||||
log.Debugf("Forwarding X.224 Connection Request (%d bytes)", len(pdu.X224ConnectionPDU))
|
||||
_, err := conn.rdpConn.Write(pdu.X224ConnectionPDU)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to write X.224 PDU: %v", err)
|
||||
p.sendRDCleanPathError(conn, "Failed to forward X.224")
|
||||
return
|
||||
}
|
||||
|
||||
response := make([]byte, 1024)
|
||||
n, err := conn.rdpConn.Read(response)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to read X.224 response: %v", err)
|
||||
p.sendRDCleanPathError(conn, "Failed to read X.224 response")
|
||||
return
|
||||
}
|
||||
|
||||
responsePDU := RDCleanPathPDU{
|
||||
Version: RDCleanPathVersion,
|
||||
X224ConnectionPDU: response[:n],
|
||||
ServerAddr: conn.destination,
|
||||
}
|
||||
|
||||
p.sendRDCleanPathPDU(conn, responsePDU)
|
||||
} else {
|
||||
responsePDU := RDCleanPathPDU{
|
||||
Version: RDCleanPathVersion,
|
||||
ServerAddr: conn.destination,
|
||||
}
|
||||
p.sendRDCleanPathPDU(conn, responsePDU)
|
||||
}
|
||||
|
||||
go p.forwardConnToWS(conn, conn.rdpConn, "TCP")
|
||||
go p.forwardWSToConn(conn, conn.rdpConn, "TCP")
|
||||
|
||||
<-conn.ctx.Done()
|
||||
log.Debug("TCP connection context done, cleaning up")
|
||||
p.cleanupConnection(conn)
|
||||
}
|
||||
|
||||
func (p *RDCleanPathProxy) sendRDCleanPathPDU(conn *proxyConnection, pdu RDCleanPathPDU) {
|
||||
data, err := asn1.Marshal(pdu)
|
||||
if err != nil {
|
||||
@@ -166,6 +158,21 @@ func (p *RDCleanPathProxy) sendRDCleanPathPDU(conn *proxyConnection, pdu RDClean
|
||||
p.sendToWebSocket(conn, data)
|
||||
}
|
||||
|
||||
func (p *RDCleanPathProxy) sendRDCleanPathError(conn *proxyConnection, errorMsg string) {
|
||||
pdu := RDCleanPathPDU{
|
||||
Version: RDCleanPathVersion,
|
||||
Error: []byte(errorMsg),
|
||||
}
|
||||
|
||||
data, err := asn1.Marshal(pdu)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to marshal error PDU: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
p.sendToWebSocket(conn, data)
|
||||
}
|
||||
|
||||
func (p *RDCleanPathProxy) readWebSocketMessage(conn *proxyConnection) ([]byte, error) {
|
||||
msgChan := make(chan []byte)
|
||||
errChan := make(chan error)
|
||||
|
||||
2
go.mod
2
go.mod
@@ -62,7 +62,7 @@ require (
|
||||
github.com/miekg/dns v1.1.59
|
||||
github.com/mitchellh/hashstructure/v2 v2.0.2
|
||||
github.com/nadoo/ipset v0.5.0
|
||||
github.com/netbirdio/management-integrations/integrations v0.0.0-20251010134843-7af36217ac1f
|
||||
github.com/netbirdio/management-integrations/integrations v0.0.0-20250906095204-f87a07690ba0
|
||||
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45
|
||||
github.com/okta/okta-sdk-golang/v2 v2.18.0
|
||||
github.com/oschwald/maxminddb-golang v1.12.0
|
||||
|
||||
4
go.sum
4
go.sum
@@ -503,8 +503,8 @@ github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944 h1:TDtJKmM6S
|
||||
github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944/go.mod h1:sHA6TRxjQ6RLbnI+3R4DZo2Eseg/iKiPRfNmcuNySVQ=
|
||||
github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51 h1:Ov4qdafATOgGMB1wbSuh+0aAHcwz9hdvB6VZjh1mVMI=
|
||||
github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51/go.mod h1:ZSIbPdBn5hePO8CpF1PekH2SfpTxg1PDhEwtbqZS7R8=
|
||||
github.com/netbirdio/management-integrations/integrations v0.0.0-20251010134843-7af36217ac1f h1:XIpRDlpPz3zFUkpwaqDRHjwpQRsf2ZKHggoex1MTafs=
|
||||
github.com/netbirdio/management-integrations/integrations v0.0.0-20251010134843-7af36217ac1f/go.mod h1:v0nUbbHbuQnqR7yKIYnKzsLBCswLtp2JctmKYmGgVhc=
|
||||
github.com/netbirdio/management-integrations/integrations v0.0.0-20250906095204-f87a07690ba0 h1:9BUqQHPVOGr0edk8EifUBUfTr2Ob0ypAPxtasUApBxQ=
|
||||
github.com/netbirdio/management-integrations/integrations v0.0.0-20250906095204-f87a07690ba0/go.mod h1:v0nUbbHbuQnqR7yKIYnKzsLBCswLtp2JctmKYmGgVhc=
|
||||
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9axERMVN63dqyFqnvuD+EMJHzM7mNGON8=
|
||||
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM=
|
||||
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45 h1:ujgviVYmx243Ksy7NdSwrdGPSRNE3pb8kEDSpH0QuAQ=
|
||||
|
||||
@@ -49,7 +49,6 @@ services:
|
||||
- traefik.http.routers.netbird-wsproxy-signal.service=netbird-wsproxy-signal
|
||||
- traefik.http.services.netbird-wsproxy-signal.loadbalancer.server.port=80
|
||||
- traefik.http.routers.netbird-signal.rule=Host(`$NETBIRD_DOMAIN`) && PathPrefix(`/signalexchange.SignalExchange/`)
|
||||
- traefik.http.routers.netbird-signal.service=netbird-signal
|
||||
- traefik.http.services.netbird-signal.loadbalancer.server.port=10000
|
||||
- traefik.http.services.netbird-signal.loadbalancer.server.scheme=h2c
|
||||
|
||||
|
||||
@@ -682,6 +682,17 @@ renderManagementJson() {
|
||||
"URI": "stun:$NETBIRD_DOMAIN:3478"
|
||||
}
|
||||
],
|
||||
"TURNConfig": {
|
||||
"Turns": [
|
||||
{
|
||||
"Proto": "udp",
|
||||
"URI": "turn:$NETBIRD_DOMAIN:3478",
|
||||
"Username": "$TURN_USER",
|
||||
"Password": "$TURN_PASSWORD"
|
||||
}
|
||||
],
|
||||
"TimeBasedCredentials": false
|
||||
},
|
||||
"Relay": {
|
||||
"Addresses": ["$NETBIRD_RELAY_PROTO://$NETBIRD_DOMAIN:$NETBIRD_PORT"],
|
||||
"CredentialsTTL": "24h",
|
||||
|
||||
@@ -7,10 +7,8 @@ import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
pb "github.com/golang/protobuf/proto" // nolint
|
||||
@@ -46,9 +44,6 @@ import (
|
||||
const (
|
||||
envLogBlockedPeers = "NB_LOG_BLOCKED_PEERS"
|
||||
envBlockPeers = "NB_BLOCK_SAME_PEERS"
|
||||
envConcurrentSyncs = "NB_MAX_CONCURRENT_SYNCS"
|
||||
|
||||
defaultSyncLim = 1000
|
||||
)
|
||||
|
||||
// GRPCServer an instance of a Management gRPC API server
|
||||
@@ -68,9 +63,6 @@ type GRPCServer struct {
|
||||
logBlockedPeers bool
|
||||
blockPeersWithSameConfig bool
|
||||
integratedPeerValidator integrated_validator.IntegratedValidator
|
||||
|
||||
syncSem atomic.Int32
|
||||
syncLim int32
|
||||
}
|
||||
|
||||
// NewServer creates a new Management server
|
||||
@@ -94,7 +86,7 @@ func NewServer(
|
||||
if appMetrics != nil {
|
||||
// update gauge based on number of connected peers which is equal to open gRPC streams
|
||||
err = appMetrics.GRPCMetrics().RegisterConnectedStreams(func() int64 {
|
||||
return int64(len(peersUpdateManager.peerChannels))
|
||||
return int64(peersUpdateManager.GetChannelCount())
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -104,16 +96,6 @@ func NewServer(
|
||||
logBlockedPeers := strings.ToLower(os.Getenv(envLogBlockedPeers)) == "true"
|
||||
blockPeersWithSameConfig := strings.ToLower(os.Getenv(envBlockPeers)) == "true"
|
||||
|
||||
syncLim := int32(defaultSyncLim)
|
||||
if syncLimStr := os.Getenv(envConcurrentSyncs); syncLimStr != "" {
|
||||
syncLimParsed, err := strconv.Atoi(syncLimStr)
|
||||
if err != nil {
|
||||
log.Errorf("invalid value for %s: %v using %d", envConcurrentSyncs, err, defaultSyncLim)
|
||||
} else {
|
||||
syncLim = int32(syncLimParsed)
|
||||
}
|
||||
}
|
||||
|
||||
return &GRPCServer{
|
||||
wgKey: key,
|
||||
// peerKey -> event channel
|
||||
@@ -128,8 +110,6 @@ func NewServer(
|
||||
logBlockedPeers: logBlockedPeers,
|
||||
blockPeersWithSameConfig: blockPeersWithSameConfig,
|
||||
integratedPeerValidator: integratedPeerValidator,
|
||||
|
||||
syncLim: syncLim,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -171,11 +151,6 @@ func getRealIP(ctx context.Context) net.IP {
|
||||
// Sync validates the existence of a connecting peer, sends an initial state (all available for the connecting peers) and
|
||||
// notifies the connected peer of any updates (e.g. new peers under the same account)
|
||||
func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_SyncServer) error {
|
||||
if s.syncSem.Load() >= s.syncLim {
|
||||
return status.Errorf(codes.ResourceExhausted, "too many concurrent sync requests, please try again later")
|
||||
}
|
||||
s.syncSem.Add(1)
|
||||
|
||||
reqStart := time.Now()
|
||||
|
||||
ctx := srv.Context()
|
||||
@@ -183,7 +158,6 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi
|
||||
syncReq := &proto.SyncRequest{}
|
||||
peerKey, err := s.parseRequest(ctx, req, syncReq)
|
||||
if err != nil {
|
||||
s.syncSem.Add(-1)
|
||||
return err
|
||||
}
|
||||
realIP := getRealIP(ctx)
|
||||
@@ -198,7 +172,6 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi
|
||||
log.WithContext(ctx).Warnf("peer %s with meta hash %d is blocked from syncing", peerKey.String(), metahashed)
|
||||
}
|
||||
if s.blockPeersWithSameConfig {
|
||||
s.syncSem.Add(-1)
|
||||
return mapError(ctx, internalStatus.ErrPeerAlreadyLoggedIn)
|
||||
}
|
||||
}
|
||||
@@ -223,10 +196,8 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi
|
||||
ctx = context.WithValue(ctx, nbContext.AccountIDKey, "UNKNOWN")
|
||||
log.WithContext(ctx).Tracef("peer %s is not registered", peerKey.String())
|
||||
if errStatus, ok := internalStatus.FromError(err); ok && errStatus.Type() == internalStatus.NotFound {
|
||||
s.syncSem.Add(-1)
|
||||
return status.Errorf(codes.PermissionDenied, "peer is not registered")
|
||||
}
|
||||
s.syncSem.Add(-1)
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -242,14 +213,12 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi
|
||||
peer, netMap, postureChecks, err := s.accountManager.SyncAndMarkPeer(ctx, accountID, peerKey.String(), peerMeta, realIP)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Debugf("error while syncing peer %s: %v", peerKey.String(), err)
|
||||
s.syncSem.Add(-1)
|
||||
return mapError(ctx, err)
|
||||
}
|
||||
|
||||
err = s.sendInitialSync(ctx, peerKey, peer, netMap, postureChecks, srv)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Debugf("error while sending initial sync for %s: %v", peerKey.String(), err)
|
||||
s.syncSem.Add(-1)
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -268,8 +237,6 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi
|
||||
|
||||
log.WithContext(ctx).Debugf("Sync: took %v", time.Since(reqStart))
|
||||
|
||||
s.syncSem.Add(-1)
|
||||
|
||||
return s.handleUpdates(ctx, accountID, peerKey, peer, updates, srv)
|
||||
}
|
||||
|
||||
|
||||
@@ -26,11 +26,9 @@ type mockHTTPClient struct {
|
||||
}
|
||||
|
||||
func (c *mockHTTPClient) Do(req *http.Request) (*http.Response, error) {
|
||||
if req.Body != nil {
|
||||
body, err := io.ReadAll(req.Body)
|
||||
if err == nil {
|
||||
c.reqBody = string(body)
|
||||
}
|
||||
body, err := io.ReadAll(req.Body)
|
||||
if err == nil {
|
||||
c.reqBody = string(body)
|
||||
}
|
||||
return &http.Response{
|
||||
StatusCode: c.code,
|
||||
|
||||
@@ -201,12 +201,6 @@ func NewManager(ctx context.Context, config Config, appMetrics telemetry.AppMetr
|
||||
APIToken: config.ExtraConfig["ApiToken"],
|
||||
}
|
||||
return NewJumpCloudManager(jumpcloudConfig, appMetrics)
|
||||
case "pocketid":
|
||||
pocketidConfig := PocketIdClientConfig{
|
||||
APIToken: config.ExtraConfig["ApiToken"],
|
||||
ManagementEndpoint: config.ExtraConfig["ManagementEndpoint"],
|
||||
}
|
||||
return NewPocketIdManager(pocketidConfig, appMetrics)
|
||||
default:
|
||||
return nil, fmt.Errorf("invalid manager type: %s", config.ManagerType)
|
||||
}
|
||||
|
||||
@@ -1,384 +0,0 @@
|
||||
package idp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"slices"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
)
|
||||
|
||||
type PocketIdManager struct {
|
||||
managementEndpoint string
|
||||
apiToken string
|
||||
httpClient ManagerHTTPClient
|
||||
credentials ManagerCredentials
|
||||
helper ManagerHelper
|
||||
appMetrics telemetry.AppMetrics
|
||||
}
|
||||
|
||||
type pocketIdCustomClaimDto struct {
|
||||
Key string `json:"key"`
|
||||
Value string `json:"value"`
|
||||
}
|
||||
|
||||
type pocketIdUserDto struct {
|
||||
CustomClaims []pocketIdCustomClaimDto `json:"customClaims"`
|
||||
Disabled bool `json:"disabled"`
|
||||
DisplayName string `json:"displayName"`
|
||||
Email string `json:"email"`
|
||||
FirstName string `json:"firstName"`
|
||||
ID string `json:"id"`
|
||||
IsAdmin bool `json:"isAdmin"`
|
||||
LastName string `json:"lastName"`
|
||||
LdapID string `json:"ldapId"`
|
||||
Locale string `json:"locale"`
|
||||
UserGroups []pocketIdUserGroupDto `json:"userGroups"`
|
||||
Username string `json:"username"`
|
||||
}
|
||||
|
||||
type pocketIdUserCreateDto struct {
|
||||
Disabled bool `json:"disabled,omitempty"`
|
||||
DisplayName string `json:"displayName"`
|
||||
Email string `json:"email"`
|
||||
FirstName string `json:"firstName"`
|
||||
IsAdmin bool `json:"isAdmin,omitempty"`
|
||||
LastName string `json:"lastName,omitempty"`
|
||||
Locale string `json:"locale,omitempty"`
|
||||
Username string `json:"username"`
|
||||
}
|
||||
|
||||
type pocketIdPaginatedUserDto struct {
|
||||
Data []pocketIdUserDto `json:"data"`
|
||||
Pagination pocketIdPaginationDto `json:"pagination"`
|
||||
}
|
||||
|
||||
type pocketIdPaginationDto struct {
|
||||
CurrentPage int `json:"currentPage"`
|
||||
ItemsPerPage int `json:"itemsPerPage"`
|
||||
TotalItems int `json:"totalItems"`
|
||||
TotalPages int `json:"totalPages"`
|
||||
}
|
||||
|
||||
func (p *pocketIdUserDto) userData() *UserData {
|
||||
return &UserData{
|
||||
Email: p.Email,
|
||||
Name: p.DisplayName,
|
||||
ID: p.ID,
|
||||
AppMetadata: AppMetadata{},
|
||||
}
|
||||
}
|
||||
|
||||
type pocketIdUserGroupDto struct {
|
||||
CreatedAt string `json:"createdAt"`
|
||||
CustomClaims []pocketIdCustomClaimDto `json:"customClaims"`
|
||||
FriendlyName string `json:"friendlyName"`
|
||||
ID string `json:"id"`
|
||||
LdapID string `json:"ldapId"`
|
||||
Name string `json:"name"`
|
||||
}
|
||||
|
||||
func NewPocketIdManager(config PocketIdClientConfig, appMetrics telemetry.AppMetrics) (*PocketIdManager, error) {
|
||||
httpTransport := http.DefaultTransport.(*http.Transport).Clone()
|
||||
httpTransport.MaxIdleConns = 5
|
||||
|
||||
httpClient := &http.Client{
|
||||
Timeout: 10 * time.Second,
|
||||
Transport: httpTransport,
|
||||
}
|
||||
helper := JsonParser{}
|
||||
|
||||
if config.ManagementEndpoint == "" {
|
||||
return nil, fmt.Errorf("pocketId IdP configuration is incomplete, ManagementEndpoint is missing")
|
||||
}
|
||||
|
||||
if config.APIToken == "" {
|
||||
return nil, fmt.Errorf("pocketId IdP configuration is incomplete, APIToken is missing")
|
||||
}
|
||||
|
||||
credentials := &PocketIdCredentials{
|
||||
clientConfig: config,
|
||||
httpClient: httpClient,
|
||||
helper: helper,
|
||||
appMetrics: appMetrics,
|
||||
}
|
||||
|
||||
return &PocketIdManager{
|
||||
managementEndpoint: config.ManagementEndpoint,
|
||||
apiToken: config.APIToken,
|
||||
httpClient: httpClient,
|
||||
credentials: credentials,
|
||||
helper: helper,
|
||||
appMetrics: appMetrics,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (p *PocketIdManager) request(ctx context.Context, method, resource string, query *url.Values, body string) ([]byte, error) {
|
||||
var MethodsWithBody = []string{http.MethodPost, http.MethodPut}
|
||||
if !slices.Contains(MethodsWithBody, method) && body != "" {
|
||||
return nil, fmt.Errorf("Body provided to unsupported method: %s", method)
|
||||
}
|
||||
|
||||
reqURL := fmt.Sprintf("%s/api/%s", p.managementEndpoint, resource)
|
||||
if query != nil {
|
||||
reqURL = fmt.Sprintf("%s?%s", reqURL, query.Encode())
|
||||
}
|
||||
var req *http.Request
|
||||
var err error
|
||||
if body != "" {
|
||||
req, err = http.NewRequestWithContext(ctx, method, reqURL, strings.NewReader(body))
|
||||
} else {
|
||||
req, err = http.NewRequestWithContext(ctx, method, reqURL, nil)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
req.Header.Add("X-API-KEY", p.apiToken)
|
||||
|
||||
if body != "" {
|
||||
req.Header.Add("content-type", "application/json")
|
||||
req.Header.Add("content-length", fmt.Sprintf("%d", req.ContentLength))
|
||||
}
|
||||
|
||||
resp, err := p.httpClient.Do(req)
|
||||
if err != nil {
|
||||
if p.appMetrics != nil {
|
||||
p.appMetrics.IDPMetrics().CountRequestError()
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusCreated {
|
||||
if p.appMetrics != nil {
|
||||
p.appMetrics.IDPMetrics().CountRequestStatusError()
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("received unexpected status code from PocketID API: %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
return io.ReadAll(resp.Body)
|
||||
}
|
||||
|
||||
// getAllUsersPaginated fetches all users from PocketID API using pagination
|
||||
func (p *PocketIdManager) getAllUsersPaginated(ctx context.Context, searchParams url.Values) ([]pocketIdUserDto, error) {
|
||||
var allUsers []pocketIdUserDto
|
||||
currentPage := 1
|
||||
|
||||
for {
|
||||
params := url.Values{}
|
||||
// Copy existing search parameters
|
||||
for key, values := range searchParams {
|
||||
params[key] = values
|
||||
}
|
||||
|
||||
params.Set("pagination[limit]", "100")
|
||||
params.Set("pagination[page]", fmt.Sprintf("%d", currentPage))
|
||||
|
||||
body, err := p.request(ctx, http.MethodGet, "users", ¶ms, "")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var profiles pocketIdPaginatedUserDto
|
||||
err = p.helper.Unmarshal(body, &profiles)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
allUsers = append(allUsers, profiles.Data...)
|
||||
|
||||
// Check if we've reached the last page
|
||||
if currentPage >= profiles.Pagination.TotalPages {
|
||||
break
|
||||
}
|
||||
|
||||
currentPage++
|
||||
}
|
||||
|
||||
return allUsers, nil
|
||||
}
|
||||
|
||||
func (p *PocketIdManager) UpdateUserAppMetadata(_ context.Context, _ string, _ AppMetadata) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *PocketIdManager) GetUserDataByID(ctx context.Context, userId string, appMetadata AppMetadata) (*UserData, error) {
|
||||
body, err := p.request(ctx, http.MethodGet, "users/"+userId, nil, "")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if p.appMetrics != nil {
|
||||
p.appMetrics.IDPMetrics().CountGetUserDataByID()
|
||||
}
|
||||
|
||||
var user pocketIdUserDto
|
||||
err = p.helper.Unmarshal(body, &user)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
userData := user.userData()
|
||||
userData.AppMetadata = appMetadata
|
||||
|
||||
return userData, nil
|
||||
}
|
||||
|
||||
func (p *PocketIdManager) GetAccount(ctx context.Context, accountId string) ([]*UserData, error) {
|
||||
// Get all users using pagination
|
||||
allUsers, err := p.getAllUsersPaginated(ctx, url.Values{})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if p.appMetrics != nil {
|
||||
p.appMetrics.IDPMetrics().CountGetAccount()
|
||||
}
|
||||
|
||||
users := make([]*UserData, 0)
|
||||
for _, profile := range allUsers {
|
||||
userData := profile.userData()
|
||||
userData.AppMetadata.WTAccountID = accountId
|
||||
|
||||
users = append(users, userData)
|
||||
}
|
||||
return users, nil
|
||||
}
|
||||
|
||||
func (p *PocketIdManager) GetAllAccounts(ctx context.Context) (map[string][]*UserData, error) {
|
||||
// Get all users using pagination
|
||||
allUsers, err := p.getAllUsersPaginated(ctx, url.Values{})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if p.appMetrics != nil {
|
||||
p.appMetrics.IDPMetrics().CountGetAllAccounts()
|
||||
}
|
||||
|
||||
indexedUsers := make(map[string][]*UserData)
|
||||
for _, profile := range allUsers {
|
||||
userData := profile.userData()
|
||||
indexedUsers[UnsetAccountID] = append(indexedUsers[UnsetAccountID], userData)
|
||||
}
|
||||
|
||||
return indexedUsers, nil
|
||||
}
|
||||
|
||||
func (p *PocketIdManager) CreateUser(ctx context.Context, email, name, accountID, invitedByEmail string) (*UserData, error) {
|
||||
firstLast := strings.Split(name, " ")
|
||||
|
||||
createUser := pocketIdUserCreateDto{
|
||||
Disabled: false,
|
||||
DisplayName: name,
|
||||
Email: email,
|
||||
FirstName: firstLast[0],
|
||||
LastName: firstLast[1],
|
||||
Username: firstLast[0] + "." + firstLast[1],
|
||||
}
|
||||
payload, err := p.helper.Marshal(createUser)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
body, err := p.request(ctx, http.MethodPost, "users", nil, string(payload))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var newUser pocketIdUserDto
|
||||
err = p.helper.Unmarshal(body, &newUser)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if p.appMetrics != nil {
|
||||
p.appMetrics.IDPMetrics().CountCreateUser()
|
||||
}
|
||||
var pending bool = true
|
||||
ret := &UserData{
|
||||
Email: email,
|
||||
Name: name,
|
||||
ID: newUser.ID,
|
||||
AppMetadata: AppMetadata{
|
||||
WTAccountID: accountID,
|
||||
WTPendingInvite: &pending,
|
||||
WTInvitedBy: invitedByEmail,
|
||||
},
|
||||
}
|
||||
return ret, nil
|
||||
}
|
||||
|
||||
func (p *PocketIdManager) GetUserByEmail(ctx context.Context, email string) ([]*UserData, error) {
|
||||
params := url.Values{
|
||||
// This value a
|
||||
"search": []string{email},
|
||||
}
|
||||
body, err := p.request(ctx, http.MethodGet, "users", ¶ms, "")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if p.appMetrics != nil {
|
||||
p.appMetrics.IDPMetrics().CountGetUserByEmail()
|
||||
}
|
||||
|
||||
var profiles struct{ data []pocketIdUserDto }
|
||||
err = p.helper.Unmarshal(body, &profiles)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
users := make([]*UserData, 0)
|
||||
for _, profile := range profiles.data {
|
||||
users = append(users, profile.userData())
|
||||
}
|
||||
return users, nil
|
||||
}
|
||||
|
||||
func (p *PocketIdManager) InviteUserByID(ctx context.Context, userID string) error {
|
||||
_, err := p.request(ctx, http.MethodPut, "users/"+userID+"/one-time-access-email", nil, "")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *PocketIdManager) DeleteUser(ctx context.Context, userID string) error {
|
||||
_, err := p.request(ctx, http.MethodDelete, "users/"+userID, nil, "")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if p.appMetrics != nil {
|
||||
p.appMetrics.IDPMetrics().CountDeleteUser()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
var _ Manager = (*PocketIdManager)(nil)
|
||||
|
||||
type PocketIdClientConfig struct {
|
||||
APIToken string
|
||||
ManagementEndpoint string
|
||||
}
|
||||
|
||||
type PocketIdCredentials struct {
|
||||
clientConfig PocketIdClientConfig
|
||||
helper ManagerHelper
|
||||
httpClient ManagerHTTPClient
|
||||
appMetrics telemetry.AppMetrics
|
||||
}
|
||||
|
||||
var _ ManagerCredentials = (*PocketIdCredentials)(nil)
|
||||
|
||||
func (p PocketIdCredentials) Authenticate(_ context.Context) (JWTToken, error) {
|
||||
return JWTToken{}, nil
|
||||
}
|
||||
@@ -1,138 +0,0 @@
|
||||
package idp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
)
|
||||
|
||||
|
||||
func TestNewPocketIdManager(t *testing.T) {
|
||||
type test struct {
|
||||
name string
|
||||
inputConfig PocketIdClientConfig
|
||||
assertErrFunc require.ErrorAssertionFunc
|
||||
assertErrFuncMessage string
|
||||
}
|
||||
|
||||
defaultTestConfig := PocketIdClientConfig{
|
||||
APIToken: "api_token",
|
||||
ManagementEndpoint: "http://localhost",
|
||||
}
|
||||
|
||||
tests := []test{
|
||||
{
|
||||
name: "Good Configuration",
|
||||
inputConfig: defaultTestConfig,
|
||||
assertErrFunc: require.NoError,
|
||||
assertErrFuncMessage: "shouldn't return error",
|
||||
},
|
||||
{
|
||||
name: "Missing ManagementEndpoint",
|
||||
inputConfig: PocketIdClientConfig{
|
||||
APIToken: defaultTestConfig.APIToken,
|
||||
ManagementEndpoint: "",
|
||||
},
|
||||
assertErrFunc: require.Error,
|
||||
assertErrFuncMessage: "should return error when field empty",
|
||||
},
|
||||
{
|
||||
name: "Missing APIToken",
|
||||
inputConfig: PocketIdClientConfig{
|
||||
APIToken: "",
|
||||
ManagementEndpoint: defaultTestConfig.ManagementEndpoint,
|
||||
},
|
||||
assertErrFunc: require.Error,
|
||||
assertErrFuncMessage: "should return error when field empty",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
_, err := NewPocketIdManager(tc.inputConfig, &telemetry.MockAppMetrics{})
|
||||
tc.assertErrFunc(t, err, tc.assertErrFuncMessage)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPocketID_GetUserDataByID(t *testing.T) {
|
||||
client := &mockHTTPClient{code: 200, resBody: `{"id":"u1","email":"user1@example.com","displayName":"User One"}`}
|
||||
|
||||
mgr, err := NewPocketIdManager(PocketIdClientConfig{APIToken: "tok", ManagementEndpoint: "http://localhost"}, nil)
|
||||
require.NoError(t, err)
|
||||
mgr.httpClient = client
|
||||
|
||||
md := AppMetadata{WTAccountID: "acc1"}
|
||||
got, err := mgr.GetUserDataByID(context.Background(), "u1", md)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "u1", got.ID)
|
||||
assert.Equal(t, "user1@example.com", got.Email)
|
||||
assert.Equal(t, "User One", got.Name)
|
||||
assert.Equal(t, "acc1", got.AppMetadata.WTAccountID)
|
||||
}
|
||||
|
||||
func TestPocketID_GetAccount_WithPagination(t *testing.T) {
|
||||
// Single page response with two users
|
||||
client := &mockHTTPClient{code: 200, resBody: `{"data":[{"id":"u1","email":"e1","displayName":"n1"},{"id":"u2","email":"e2","displayName":"n2"}],"pagination":{"currentPage":1,"itemsPerPage":100,"totalItems":2,"totalPages":1}}`}
|
||||
|
||||
mgr, err := NewPocketIdManager(PocketIdClientConfig{APIToken: "tok", ManagementEndpoint: "http://localhost"}, nil)
|
||||
require.NoError(t, err)
|
||||
mgr.httpClient = client
|
||||
|
||||
users, err := mgr.GetAccount(context.Background(), "accX")
|
||||
require.NoError(t, err)
|
||||
require.Len(t, users, 2)
|
||||
assert.Equal(t, "u1", users[0].ID)
|
||||
assert.Equal(t, "accX", users[0].AppMetadata.WTAccountID)
|
||||
assert.Equal(t, "u2", users[1].ID)
|
||||
}
|
||||
|
||||
func TestPocketID_GetAllAccounts_WithPagination(t *testing.T) {
|
||||
client := &mockHTTPClient{code: 200, resBody: `{"data":[{"id":"u1","email":"e1","displayName":"n1"},{"id":"u2","email":"e2","displayName":"n2"}],"pagination":{"currentPage":1,"itemsPerPage":100,"totalItems":2,"totalPages":1}}`}
|
||||
|
||||
mgr, err := NewPocketIdManager(PocketIdClientConfig{APIToken: "tok", ManagementEndpoint: "http://localhost"}, nil)
|
||||
require.NoError(t, err)
|
||||
mgr.httpClient = client
|
||||
|
||||
accounts, err := mgr.GetAllAccounts(context.Background())
|
||||
require.NoError(t, err)
|
||||
require.Len(t, accounts[UnsetAccountID], 2)
|
||||
}
|
||||
|
||||
func TestPocketID_CreateUser(t *testing.T) {
|
||||
client := &mockHTTPClient{code: 201, resBody: `{"id":"newid","email":"new@example.com","displayName":"New User"}`}
|
||||
|
||||
mgr, err := NewPocketIdManager(PocketIdClientConfig{APIToken: "tok", ManagementEndpoint: "http://localhost"}, nil)
|
||||
require.NoError(t, err)
|
||||
mgr.httpClient = client
|
||||
|
||||
ud, err := mgr.CreateUser(context.Background(), "new@example.com", "New User", "acc1", "inviter@example.com")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "newid", ud.ID)
|
||||
assert.Equal(t, "new@example.com", ud.Email)
|
||||
assert.Equal(t, "New User", ud.Name)
|
||||
assert.Equal(t, "acc1", ud.AppMetadata.WTAccountID)
|
||||
if assert.NotNil(t, ud.AppMetadata.WTPendingInvite) {
|
||||
assert.True(t, *ud.AppMetadata.WTPendingInvite)
|
||||
}
|
||||
assert.Equal(t, "inviter@example.com", ud.AppMetadata.WTInvitedBy)
|
||||
}
|
||||
|
||||
func TestPocketID_InviteAndDeleteUser(t *testing.T) {
|
||||
// Same mock for both calls; returns OK with empty JSON
|
||||
client := &mockHTTPClient{code: 200, resBody: `{}`}
|
||||
|
||||
mgr, err := NewPocketIdManager(PocketIdClientConfig{APIToken: "tok", ManagementEndpoint: "http://localhost"}, nil)
|
||||
require.NoError(t, err)
|
||||
mgr.httpClient = client
|
||||
|
||||
err = mgr.InviteUserByID(context.Background(), "u1")
|
||||
require.NoError(t, err)
|
||||
|
||||
err = mgr.DeleteUser(context.Background(), "u1")
|
||||
require.NoError(t, err)
|
||||
}
|
||||
@@ -136,7 +136,7 @@ func (a MockIntegratedValidator) GetValidatedPeers(_ context.Context, accountID
|
||||
return validatedPeers, nil
|
||||
}
|
||||
|
||||
func (MockIntegratedValidator) PreparePeer(_ context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *types.ExtraSettings, temporary bool) *nbpeer.Peer {
|
||||
func (MockIntegratedValidator) PreparePeer(_ context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *types.ExtraSettings) *nbpeer.Peer {
|
||||
return peer
|
||||
}
|
||||
|
||||
|
||||
@@ -3,16 +3,16 @@ package integrated_validator
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/netbirdio/netbird/shared/management/proto"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/shared/management/proto"
|
||||
)
|
||||
|
||||
// IntegratedValidator interface exists to avoid the circle dependencies
|
||||
type IntegratedValidator interface {
|
||||
ValidateExtraSettings(ctx context.Context, newExtraSettings *types.ExtraSettings, oldExtraSettings *types.ExtraSettings, peers map[string]*nbpeer.Peer, userID string, accountID string) error
|
||||
ValidatePeer(ctx context.Context, update *nbpeer.Peer, peer *nbpeer.Peer, userID string, accountID string, dnsDomain string, peersGroup []string, extraSettings *types.ExtraSettings) (*nbpeer.Peer, bool, error)
|
||||
PreparePeer(ctx context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *types.ExtraSettings, temporary bool) *nbpeer.Peer
|
||||
PreparePeer(ctx context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *types.ExtraSettings) *nbpeer.Peer
|
||||
IsNotValidPeer(ctx context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *types.ExtraSettings) (bool, bool, error)
|
||||
GetValidatedPeers(ctx context.Context, accountID string, groups []*types.Group, peers []*nbpeer.Peer, extraSettings *types.ExtraSettings) (map[string]struct{}, error)
|
||||
PeerDeleted(ctx context.Context, accountID, peerID string, extraSettings *types.ExtraSettings) error
|
||||
|
||||
@@ -350,6 +350,7 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer
|
||||
}
|
||||
|
||||
var peer *nbpeer.Peer
|
||||
var updateAccountPeers bool
|
||||
var eventsToStore []func()
|
||||
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
@@ -362,6 +363,11 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer
|
||||
return err
|
||||
}
|
||||
|
||||
updateAccountPeers, err = isPeerInActiveGroup(ctx, transaction, accountID, peerID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
eventsToStore, err = deletePeers(ctx, am, transaction, accountID, userID, []*nbpeer.Peer{peer})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to delete peer: %w", err)
|
||||
@@ -381,7 +387,7 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer
|
||||
storeEvent()
|
||||
}
|
||||
|
||||
if userID != activity.SystemInitiator {
|
||||
if updateAccountPeers && userID != activity.SystemInitiator {
|
||||
am.BufferUpdateAccountPeers(ctx, accountID)
|
||||
}
|
||||
|
||||
@@ -578,7 +584,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, accountID, setupKe
|
||||
}
|
||||
}
|
||||
|
||||
newPeer = am.integratedPeerValidator.PreparePeer(ctx, accountID, newPeer, groupsToAdd, settings.Extra, temporary)
|
||||
newPeer = am.integratedPeerValidator.PreparePeer(ctx, accountID, newPeer, groupsToAdd, settings.Extra)
|
||||
|
||||
network, err := am.Store.GetAccountNetwork(ctx, store.LockingStrengthNone, accountID)
|
||||
if err != nil {
|
||||
@@ -678,6 +684,11 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, accountID, setupKe
|
||||
return nil, nil, nil, fmt.Errorf("failed to add peer to database after %d attempts: %w", maxAttempts, err)
|
||||
}
|
||||
|
||||
updateAccountPeers, err := isPeerInActiveGroup(ctx, am.Store, accountID, newPeer.ID)
|
||||
if err != nil {
|
||||
updateAccountPeers = true
|
||||
}
|
||||
|
||||
if newPeer == nil {
|
||||
return nil, nil, nil, fmt.Errorf("new peer is nil")
|
||||
}
|
||||
@@ -690,7 +701,9 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, accountID, setupKe
|
||||
|
||||
am.StoreEvent(ctx, opEvent.InitiatorID, opEvent.TargetID, opEvent.AccountID, opEvent.Activity, opEvent.Meta)
|
||||
|
||||
am.BufferUpdateAccountPeers(ctx, accountID)
|
||||
if updateAccountPeers {
|
||||
am.BufferUpdateAccountPeers(ctx, accountID)
|
||||
}
|
||||
|
||||
return am.getValidatedPeerWithMap(ctx, false, accountID, newPeer)
|
||||
}
|
||||
@@ -1257,12 +1270,10 @@ func (am *DefaultAccountManager) UpdateAccountPeers(ctx context.Context, account
|
||||
update := toSyncResponse(ctx, nil, p, nil, nil, remotePeerNetworkMap, dnsDomain, postureChecks, dnsCache, account.Settings, extraSetting, maps.Keys(peerGroups), dnsFwdPort)
|
||||
am.metrics.UpdateChannelMetrics().CountToSyncResponseDuration(time.Since(start))
|
||||
|
||||
am.peersUpdateManager.SendUpdate(ctx, p.ID, &UpdateMessage{Update: update, NetworkMap: remotePeerNetworkMap})
|
||||
am.peersUpdateManager.SendUpdate(ctx, p.ID, &UpdateMessage{Update: update})
|
||||
}(peer)
|
||||
}
|
||||
|
||||
//
|
||||
|
||||
wg.Wait()
|
||||
if am.metrics != nil {
|
||||
am.metrics.AccountManagerMetrics().CountUpdateAccountPeersDuration(time.Since(globalStart))
|
||||
@@ -1368,7 +1379,7 @@ func (am *DefaultAccountManager) UpdateAccountPeer(ctx context.Context, accountI
|
||||
dnsFwdPort := computeForwarderPort(maps.Values(account.Peers), dnsForwarderPortMinVersion)
|
||||
|
||||
update := toSyncResponse(ctx, nil, peer, nil, nil, remotePeerNetworkMap, dnsDomain, postureChecks, dnsCache, account.Settings, extraSettings, maps.Keys(peerGroups), dnsFwdPort)
|
||||
am.peersUpdateManager.SendUpdate(ctx, peer.ID, &UpdateMessage{Update: update, NetworkMap: remotePeerNetworkMap})
|
||||
am.peersUpdateManager.SendUpdate(ctx, peer.ID, &UpdateMessage{Update: update})
|
||||
}
|
||||
|
||||
// getNextPeerExpiration returns the minimum duration in which the next peer of the account will expire if it was found.
|
||||
@@ -1514,6 +1525,16 @@ func getPeerGroupIDs(ctx context.Context, transaction store.Store, accountID str
|
||||
return transaction.GetPeerGroupIDs(ctx, store.LockingStrengthNone, accountID, peerID)
|
||||
}
|
||||
|
||||
// IsPeerInActiveGroup checks if the given peer is part of a group that is used
|
||||
// in an active DNS, route, or ACL configuration.
|
||||
func isPeerInActiveGroup(ctx context.Context, transaction store.Store, accountID, peerID string) (bool, error) {
|
||||
peerGroupIDs, err := getPeerGroupIDs(ctx, transaction, accountID, peerID)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return areGroupChangesAffectPeers(ctx, transaction, accountID, peerGroupIDs) // TODO: use transaction
|
||||
}
|
||||
|
||||
// deletePeers deletes all specified peers and sends updates to the remote peers.
|
||||
// Returns a slice of functions to save events after successful peer deletion.
|
||||
func deletePeers(ctx context.Context, am *DefaultAccountManager, transaction store.Store, accountID, userID string, peers []*nbpeer.Peer) ([]func(), error) {
|
||||
@@ -1580,7 +1601,6 @@ func deletePeers(ctx context.Context, am *DefaultAccountManager, transaction sto
|
||||
},
|
||||
},
|
||||
},
|
||||
NetworkMap: &types.NetworkMap{},
|
||||
})
|
||||
am.peersUpdateManager.CloseChannel(ctx, peer.ID)
|
||||
peerDeletedEvents = append(peerDeletedEvents, func() {
|
||||
|
||||
@@ -1043,8 +1043,8 @@ func TestUpdateAccountPeers(t *testing.T) {
|
||||
for _, channel := range peerChannels {
|
||||
update := <-channel
|
||||
assert.Nil(t, update.Update.NetbirdConfig)
|
||||
assert.Equal(t, tc.peers, len(update.NetworkMap.Peers))
|
||||
assert.Equal(t, tc.peers*2, len(update.NetworkMap.FirewallRules))
|
||||
// assert.Equal(t, tc.peers, len(update.NetworkMap.Peers))
|
||||
// assert.Equal(t, tc.peers*2, len(update.NetworkMap.FirewallRules))
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -1790,7 +1790,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
|
||||
t.Run("adding peer to unlinked group", func(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, updMsg) //
|
||||
peerShouldNotReceiveUpdate(t, updMsg) //
|
||||
close(done)
|
||||
}()
|
||||
|
||||
@@ -1815,7 +1815,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
|
||||
t.Run("deleting peer with unlinked group", func(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, updMsg)
|
||||
peerShouldNotReceiveUpdate(t, updMsg)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
|
||||
@@ -7,23 +7,25 @@ import (
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/shared/management/proto"
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/shared/management/proto"
|
||||
)
|
||||
|
||||
const channelBufferSize = 100
|
||||
|
||||
type UpdateMessage struct {
|
||||
Update *proto.SyncResponse
|
||||
NetworkMap *types.NetworkMap
|
||||
Update *proto.SyncResponse
|
||||
}
|
||||
|
||||
type peerUpdate struct {
|
||||
mu sync.Mutex
|
||||
message *UpdateMessage
|
||||
notify chan struct{}
|
||||
}
|
||||
|
||||
type PeersUpdateManager struct {
|
||||
// peerChannels is an update channel indexed by Peer.ID
|
||||
peerChannels map[string]chan *UpdateMessage
|
||||
// channelsMux keeps the mutex to access peerChannels
|
||||
channelsMux *sync.RWMutex
|
||||
// latestUpdates stores the latest update message per peer
|
||||
latestUpdates sync.Map // map[string]*peerUpdate
|
||||
// activePeers tracks which peers have active sender goroutines
|
||||
activePeers sync.Map // map[string]struct{}
|
||||
// metrics provides method to collect application metrics
|
||||
metrics telemetry.AppMetrics
|
||||
}
|
||||
@@ -31,87 +33,137 @@ type PeersUpdateManager struct {
|
||||
// NewPeersUpdateManager returns a new instance of PeersUpdateManager
|
||||
func NewPeersUpdateManager(metrics telemetry.AppMetrics) *PeersUpdateManager {
|
||||
return &PeersUpdateManager{
|
||||
peerChannels: make(map[string]chan *UpdateMessage),
|
||||
channelsMux: &sync.RWMutex{},
|
||||
metrics: metrics,
|
||||
metrics: metrics,
|
||||
}
|
||||
}
|
||||
|
||||
// SendUpdate sends update message to the peer's channel
|
||||
// SendUpdate stores the latest update message for a peer and notifies the sender goroutine
|
||||
func (p *PeersUpdateManager) SendUpdate(ctx context.Context, peerID string, update *UpdateMessage) {
|
||||
start := time.Now()
|
||||
var found, dropped bool
|
||||
|
||||
p.channelsMux.RLock()
|
||||
|
||||
defer func() {
|
||||
p.channelsMux.RUnlock()
|
||||
if p.metrics != nil {
|
||||
p.metrics.UpdateChannelMetrics().CountSendUpdateDuration(time.Since(start), found, dropped)
|
||||
}
|
||||
}()
|
||||
|
||||
if channel, ok := p.peerChannels[peerID]; ok {
|
||||
found = true
|
||||
select {
|
||||
case channel <- update:
|
||||
log.WithContext(ctx).Debugf("update was sent to channel for peer %s", peerID)
|
||||
default:
|
||||
dropped = true
|
||||
log.WithContext(ctx).Warnf("channel for peer %s is %d full or closed", peerID, len(channel))
|
||||
}
|
||||
} else {
|
||||
log.WithContext(ctx).Debugf("peer %s has no channel", peerID)
|
||||
// Check if peer has an active sender goroutine
|
||||
if _, ok := p.activePeers.Load(peerID); !ok {
|
||||
log.WithContext(ctx).Debugf("peer %s has no active sender", peerID)
|
||||
return
|
||||
}
|
||||
|
||||
found = true
|
||||
|
||||
// Load or create peerUpdate entry
|
||||
val, _ := p.latestUpdates.LoadOrStore(peerID, &peerUpdate{
|
||||
notify: make(chan struct{}, 1),
|
||||
})
|
||||
|
||||
pu := val.(*peerUpdate)
|
||||
|
||||
// Store the latest message (overwrites any previous unsent message)
|
||||
pu.mu.Lock()
|
||||
pu.message = update
|
||||
pu.mu.Unlock()
|
||||
|
||||
// Non-blocking notification
|
||||
select {
|
||||
case pu.notify <- struct{}{}:
|
||||
log.WithContext(ctx).Debugf("update notification sent for peer %s", peerID)
|
||||
default:
|
||||
// Already notified, sender will pick up the latest message anyway
|
||||
log.WithContext(ctx).Tracef("peer %s already notified, update will be picked up", peerID)
|
||||
}
|
||||
}
|
||||
|
||||
// CreateChannel creates a go channel for a given peer used to deliver updates relevant to the peer.
|
||||
// CreateChannel creates a sender goroutine for a given peer and returns a channel to receive updates
|
||||
func (p *PeersUpdateManager) CreateChannel(ctx context.Context, peerID string) chan *UpdateMessage {
|
||||
start := time.Now()
|
||||
|
||||
closed := false
|
||||
|
||||
p.channelsMux.Lock()
|
||||
defer func() {
|
||||
p.channelsMux.Unlock()
|
||||
if p.metrics != nil {
|
||||
p.metrics.UpdateChannelMetrics().CountCreateChannelDuration(time.Since(start), closed)
|
||||
}
|
||||
}()
|
||||
|
||||
if channel, ok := p.peerChannels[peerID]; ok {
|
||||
// Close existing sender if any
|
||||
if _, exists := p.activePeers.LoadOrStore(peerID, struct{}{}); exists {
|
||||
closed = true
|
||||
delete(p.peerChannels, peerID)
|
||||
close(channel)
|
||||
p.closeChannel(ctx, peerID)
|
||||
}
|
||||
// mbragin: todo shouldn't it be more? or configurable?
|
||||
channel := make(chan *UpdateMessage, channelBufferSize)
|
||||
p.peerChannels[peerID] = channel
|
||||
|
||||
log.WithContext(ctx).Debugf("opened updates channel for a peer %s", peerID)
|
||||
// Create peerUpdate entry with notification channel
|
||||
pu := &peerUpdate{
|
||||
notify: make(chan struct{}, 1),
|
||||
}
|
||||
p.latestUpdates.Store(peerID, pu)
|
||||
|
||||
return channel
|
||||
// Create output channel for consumer
|
||||
outChan := make(chan *UpdateMessage, 1)
|
||||
|
||||
// Start sender goroutine
|
||||
go func() {
|
||||
defer close(outChan)
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
log.WithContext(ctx).Debugf("sender goroutine for peer %s stopped due to context cancellation", peerID)
|
||||
return
|
||||
case <-pu.notify:
|
||||
// Check if still active
|
||||
if _, ok := p.activePeers.Load(peerID); !ok {
|
||||
log.WithContext(ctx).Debugf("sender goroutine for peer %s stopped", peerID)
|
||||
return
|
||||
}
|
||||
|
||||
// Get the latest message with mutex protection
|
||||
pu.mu.Lock()
|
||||
msg := pu.message
|
||||
pu.message = nil // Clear after reading
|
||||
pu.mu.Unlock()
|
||||
|
||||
if msg != nil {
|
||||
select {
|
||||
case outChan <- msg:
|
||||
log.WithContext(ctx).Tracef("sent update to peer %s", peerID)
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
log.WithContext(ctx).Debugf("created sender goroutine for peer %s", peerID)
|
||||
|
||||
return outChan
|
||||
}
|
||||
|
||||
func (p *PeersUpdateManager) closeChannel(ctx context.Context, peerID string) {
|
||||
if channel, ok := p.peerChannels[peerID]; ok {
|
||||
delete(p.peerChannels, peerID)
|
||||
close(channel)
|
||||
|
||||
log.WithContext(ctx).Debugf("closed updates channel of a peer %s", peerID)
|
||||
// Mark peer as inactive to stop the sender goroutine
|
||||
if _, ok := p.activePeers.LoadAndDelete(peerID); ok {
|
||||
// Close notification channel
|
||||
if val, ok := p.latestUpdates.Load(peerID); ok {
|
||||
pu := val.(*peerUpdate)
|
||||
close(pu.notify)
|
||||
}
|
||||
p.latestUpdates.Delete(peerID)
|
||||
log.WithContext(ctx).Debugf("closed sender for peer %s", peerID)
|
||||
return
|
||||
}
|
||||
|
||||
log.WithContext(ctx).Debugf("closing updates channel: peer %s has no channel", peerID)
|
||||
log.WithContext(ctx).Debugf("closing sender: peer %s has no active sender", peerID)
|
||||
}
|
||||
|
||||
// CloseChannels closes updates channel for each given peer
|
||||
// CloseChannels closes sender goroutines for each given peer
|
||||
func (p *PeersUpdateManager) CloseChannels(ctx context.Context, peerIDs []string) {
|
||||
start := time.Now()
|
||||
|
||||
p.channelsMux.Lock()
|
||||
defer func() {
|
||||
p.channelsMux.Unlock()
|
||||
if p.metrics != nil {
|
||||
p.metrics.UpdateChannelMetrics().CountCloseChannelsDuration(time.Since(start), len(peerIDs))
|
||||
}
|
||||
@@ -122,13 +174,11 @@ func (p *PeersUpdateManager) CloseChannels(ctx context.Context, peerIDs []string
|
||||
}
|
||||
}
|
||||
|
||||
// CloseChannel closes updates channel of a given peer
|
||||
// CloseChannel closes the sender goroutine of a given peer
|
||||
func (p *PeersUpdateManager) CloseChannel(ctx context.Context, peerID string) {
|
||||
start := time.Now()
|
||||
|
||||
p.channelsMux.Lock()
|
||||
defer func() {
|
||||
p.channelsMux.Unlock()
|
||||
if p.metrics != nil {
|
||||
p.metrics.UpdateChannelMetrics().CountCloseChannelDuration(time.Since(start))
|
||||
}
|
||||
@@ -141,38 +191,43 @@ func (p *PeersUpdateManager) CloseChannel(ctx context.Context, peerID string) {
|
||||
func (p *PeersUpdateManager) GetAllConnectedPeers() map[string]struct{} {
|
||||
start := time.Now()
|
||||
|
||||
p.channelsMux.RLock()
|
||||
|
||||
m := make(map[string]struct{})
|
||||
|
||||
defer func() {
|
||||
p.channelsMux.RUnlock()
|
||||
if p.metrics != nil {
|
||||
p.metrics.UpdateChannelMetrics().CountGetAllConnectedPeersDuration(time.Since(start), len(m))
|
||||
}
|
||||
}()
|
||||
|
||||
for ID := range p.peerChannels {
|
||||
m[ID] = struct{}{}
|
||||
}
|
||||
p.activePeers.Range(func(key, value interface{}) bool {
|
||||
m[key.(string)] = struct{}{}
|
||||
return true
|
||||
})
|
||||
|
||||
return m
|
||||
}
|
||||
|
||||
// HasChannel returns true if peers has channel in update manager, otherwise false
|
||||
// HasChannel returns true if peer has an active sender goroutine, otherwise false
|
||||
func (p *PeersUpdateManager) HasChannel(peerID string) bool {
|
||||
start := time.Now()
|
||||
|
||||
p.channelsMux.RLock()
|
||||
|
||||
defer func() {
|
||||
p.channelsMux.RUnlock()
|
||||
if p.metrics != nil {
|
||||
p.metrics.UpdateChannelMetrics().CountHasChannelDuration(time.Since(start))
|
||||
}
|
||||
}()
|
||||
|
||||
_, ok := p.peerChannels[peerID]
|
||||
_, ok := p.activePeers.Load(peerID)
|
||||
|
||||
return ok
|
||||
}
|
||||
|
||||
// GetChannelCount returns the number of active peer channels
|
||||
func (p *PeersUpdateManager) GetChannelCount() int {
|
||||
count := 0
|
||||
p.activePeers.Range(func(key, value interface{}) bool {
|
||||
count++
|
||||
return true
|
||||
})
|
||||
return count
|
||||
}
|
||||
|
||||
@@ -1,13 +1,9 @@
|
||||
package version
|
||||
|
||||
import (
|
||||
"golang.org/x/sys/windows/registry"
|
||||
"runtime"
|
||||
)
|
||||
import "golang.org/x/sys/windows/registry"
|
||||
|
||||
const (
|
||||
urlWinExe = "https://pkgs.netbird.io/windows/x64"
|
||||
urlWinExeArm = "https://pkgs.netbird.io/windows/arm64"
|
||||
)
|
||||
|
||||
var regKeyAppPath = "SOFTWARE\\Microsoft\\Windows\\CurrentVersion\\App Paths\\Netbird"
|
||||
@@ -15,14 +11,9 @@ var regKeyAppPath = "SOFTWARE\\Microsoft\\Windows\\CurrentVersion\\App Paths\\Ne
|
||||
// DownloadUrl return with the proper download link
|
||||
func DownloadUrl() string {
|
||||
_, err := registry.OpenKey(registry.LOCAL_MACHINE, regKeyAppPath, registry.QUERY_VALUE)
|
||||
if err != nil {
|
||||
if err == nil {
|
||||
return urlWinExe
|
||||
} else {
|
||||
return downloadURL
|
||||
}
|
||||
|
||||
url := urlWinExe
|
||||
if runtime.GOARCH == "arm64" {
|
||||
url = urlWinExeArm
|
||||
}
|
||||
|
||||
return url
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user