Compare commits

..

12 Commits

Author SHA1 Message Date
Pascal Fischer
df101bf071 add custom store cache 2025-10-17 15:57:40 +02:00
Pascal Fischer
8393bf1b17 pass metrics in transactions 2025-10-15 16:29:36 +02:00
Pascal Fischer
02a04958e7 add metrics to store methods 2025-10-14 21:28:52 +02:00
Viktor Liu
000e99e7f3 [client] Force TLS1.2 for RDP with Win11/Server2025 for CredSSP compatibility (#4617) 2025-10-13 17:50:16 +02:00
Maycon Santos
0d2e67983a [misc] Add service definition for netbird-signal (#4620) 2025-10-10 19:16:48 +02:00
Pascal Fischer
5151f19d29 [management] pass temporary flag to validator (#4599) 2025-10-10 16:15:51 +02:00
Kostya Leschenko
bedd3cabc9 [client] Explicitly disable DNSOverTLS for systemd-resolved (#4579) 2025-10-10 15:24:24 +02:00
hakansa
d35a845dbd [management] sync all other peers on peer add/remove (#4614) 2025-10-09 21:18:00 +02:00
hakansa
4e03f708a4 fix dns forwarder port update (#4613)
fix dns forwarder port update (#4613)
2025-10-09 17:39:02 +03:00
Ashley
654aa9581d [client,gui] Update url_windows.go to offer arm64 executable download (#4586) 2025-10-08 21:27:32 +02:00
Zoltan Papp
9021bb512b [client] Recreate agent when receive new session id (#4564)
When an ICE agent connection was in progress, new offers were being ignored. This was incorrect logic because the remote agent could be restarted at any time.
In this change, whenever a new session ID is received, the ongoing handshake is closed and a new one is started.
2025-10-08 17:14:24 +02:00
hakansa
768332820e [client] Implement DNS query caching in DNSForwarder (#4574)
This change implements DNS query caching in the DNSForwarder to improve performance and provide fallback responses when upstream DNS servers fail. The cache stores successful DNS query results and serves them when upstream resolution fails.

- Added a new cache component to store DNS query results by domain and query type
- Integrated cache storage after successful DNS resolutions
- Enhanced error handling to serve cached responses as fallback when upstream DNS fails
2025-10-08 16:54:27 +02:00
34 changed files with 2284 additions and 571 deletions

View File

@@ -31,6 +31,7 @@ const (
systemdDbusSetDefaultRouteMethodSuffix = systemdDbusLinkInterface + ".SetDefaultRoute"
systemdDbusSetDomainsMethodSuffix = systemdDbusLinkInterface + ".SetDomains"
systemdDbusSetDNSSECMethodSuffix = systemdDbusLinkInterface + ".SetDNSSEC"
systemdDbusSetDNSOverTLSMethodSuffix = systemdDbusLinkInterface + ".SetDNSOverTLS"
systemdDbusResolvConfModeForeign = "foreign"
dbusErrorUnknownObject = "org.freedesktop.DBus.Error.UnknownObject"
@@ -102,6 +103,11 @@ func (s *systemdDbusConfigurator) applyDNSConfig(config HostDNSConfig, stateMana
log.Warnf("failed to set DNSSEC to 'no': %v", err)
}
// We don't support DNSOverTLS. On some machines this is default on so we explicitly set it to off
if err := s.callLinkMethod(systemdDbusSetDNSOverTLSMethodSuffix, dnsSecDisabled); err != nil {
log.Warnf("failed to set DNSOverTLS to 'no': %v", err)
}
var (
searchDomains []string
matchDomains []string

View File

@@ -0,0 +1,78 @@
package dnsfwd
import (
"net/netip"
"slices"
"strings"
"sync"
"github.com/miekg/dns"
)
// cache stores resolved addresses per domain for the DNS forwarder.
// All access to records goes through mu, so the methods on *cache may be
// called from multiple goroutines.
type cache struct {
// mu guards records.
mu sync.RWMutex
// records maps a normalized domain name (see normalizeDomain) to its cached addresses.
records map[string]*cacheEntry
}
// cacheEntry holds the addresses cached for a single domain, split by record type.
//
// The has* flags record whether a result for that record type was ever stored.
// They allow an explicitly cached empty result (a negative answer) to be told
// apart from a record type that has never been resolved for this domain.
type cacheEntry struct {
	ip4Addrs []netip.Addr
	ip6Addrs []netip.Addr
	hasIP4   bool
	hasIP6   bool
}

// newCache returns an empty, ready-to-use cache.
func newCache() *cache {
	return &cache{
		records: make(map[string]*cacheEntry),
	}
}

// get returns a copy of the addresses cached for domain and reqType.
//
// The boolean is true only when a result for exactly this record type was
// stored earlier via set; the returned slice may then still be empty, which
// represents a cached negative answer. It is false when the domain is unknown,
// when nothing was stored for this record type yet, or when reqType is neither
// dns.TypeA nor dns.TypeAAAA.
func (c *cache) get(domain string, reqType uint16) ([]netip.Addr, bool) {
	c.mu.RLock()
	defer c.mu.RUnlock()

	entry, exists := c.records[normalizeDomain(domain)]
	if !exists {
		return nil, false
	}

	switch reqType {
	case dns.TypeA:
		// An entry created by an AAAA lookup must not count as an A hit.
		if !entry.hasIP4 {
			return nil, false
		}
		return slices.Clone(entry.ip4Addrs), true
	case dns.TypeAAAA:
		// An entry created by an A lookup must not count as an AAAA hit.
		if !entry.hasIP6 {
			return nil, false
		}
		return slices.Clone(entry.ip6Addrs), true
	default:
		return nil, false
	}
}

// set stores a copy of addrs as the cached result for domain and reqType,
// replacing any previous result of the same record type. Passing a nil or
// empty addrs caches a negative answer. Record types other than dns.TypeA and
// dns.TypeAAAA are ignored and do not create an entry.
func (c *cache) set(domain string, reqType uint16, addrs []netip.Addr) {
	if reqType != dns.TypeA && reqType != dns.TypeAAAA {
		return
	}

	c.mu.Lock()
	defer c.mu.Unlock()

	norm := normalizeDomain(domain)
	entry, exists := c.records[norm]
	if !exists {
		entry = &cacheEntry{}
		c.records[norm] = entry
	}

	switch reqType {
	case dns.TypeA:
		entry.ip4Addrs = slices.Clone(addrs)
		entry.hasIP4 = true
	case dns.TypeAAAA:
		entry.ip6Addrs = slices.Clone(addrs)
		entry.hasIP6 = true
	}
}
// unset removes the whole cached entry for the given domain, i.e. the results
// of every request type stored under it. The domain is normalized first, so
// casing and a trailing dot do not matter. Unknown domains are a no-op.
func (c *cache) unset(domain string) {
c.mu.Lock()
defer c.mu.Unlock()
delete(c.records, normalizeDomain(domain))
}
// normalizeDomain produces the canonical cache key for a domain name:
// the name is lowercased and made fully qualified (trailing dot appended
// when missing).
func normalizeDomain(domain string) string {
	lowered := strings.ToLower(domain)
	// dns.Fqdn leaves already fully-qualified names untouched.
	return dns.Fqdn(lowered)
}

View File

@@ -0,0 +1,86 @@
package dnsfwd
import (
"net/netip"
"testing"
)
// mustAddr parses s as an IP address and aborts the calling test when s is
// not a valid address.
func mustAddr(t *testing.T, s string) netip.Addr {
	t.Helper()

	addr, parseErr := netip.ParseAddr(s)
	if parseErr == nil {
		return addr
	}

	// Fatalf stops the test goroutine; the return below only satisfies the compiler.
	t.Fatalf("parse addr %s: %v", s, parseErr)
	return netip.Addr{}
}
// TestCacheNormalization verifies that a value stored under a mixed-case name
// without a trailing dot is found again regardless of casing or trailing dot.
func TestCacheNormalization(t *testing.T) {
	c := newCache()

	// Store under mixed case, without trailing dot.
	c.set("ExAmPlE.CoM", 1 /* dns.TypeA */, []netip.Addr{mustAddr(t, "1.2.3.4")})

	// Lowercase, fully-qualified lookup.
	got, ok := c.get("example.com.", 1)
	if !ok || len(got) != 1 || got[0].String() != "1.2.3.4" {
		t.Fatalf("expected cached IPv4 result via normalized key, got=%v ok=%v", got, ok)
	}

	// Uppercase lookup without trailing dot.
	got, ok = c.get("EXAMPLE.COM", 1)
	if !ok || len(got) != 1 || got[0].String() != "1.2.3.4" {
		t.Fatalf("expected cached IPv4 result via different casing, got=%v ok=%v", got, ok)
	}
}
// TestCacheSeparateTypes checks that A and AAAA results for the same domain
// are stored and returned independently of each other.
func TestCacheSeparateTypes(t *testing.T) {
	const (
		typeA    = 1  // dns.TypeA
		typeAAAA = 28 // dns.TypeAAAA
	)

	c := newCache()
	domain := "test.local"
	wantV4 := mustAddr(t, "10.0.0.1")
	wantV6 := mustAddr(t, "2001:db8::1")

	c.set(domain, typeA, []netip.Addr{wantV4})
	c.set(domain, typeAAAA, []netip.Addr{wantV6})

	if got4, ok4 := c.get(domain, typeA); !ok4 || len(got4) != 1 || got4[0] != wantV4 {
		t.Fatalf("expected A record from cache, got=%v ok=%v", got4, ok4)
	}

	if got6, ok6 := c.get(domain, typeAAAA); !ok6 || len(got6) != 1 || got6[0] != wantV6 {
		t.Fatalf("expected AAAA record from cache, got=%v ok=%v", got6, ok6)
	}
}
// TestCacheCloneOnGetAndSet checks that the cache neither aliases the slice
// handed to set nor exposes its internal slice through get.
func TestCacheCloneOnGetAndSet(t *testing.T) {
	const want = "8.8.8.8"

	c := newCache()
	domain := "clone.test"

	input := []netip.Addr{mustAddr(t, want)}
	c.set(domain, 1, input)

	// Overwriting the caller's slice after set must not reach the cache.
	input[0] = mustAddr(t, "9.9.9.9")

	first, okFirst := c.get(domain, 1)
	if !okFirst || len(first) != 1 || first[0].String() != want {
		t.Fatalf("expected cached value to be independent of source slice, got=%v ok=%v", first, okFirst)
	}

	// Overwriting the slice returned by get must not reach the cache either.
	first[0] = mustAddr(t, "4.4.4.4")

	second, okSecond := c.get(domain, 1)
	if !okSecond || len(second) != 1 || second[0].String() != want {
		t.Fatalf("expected returned slice to be a clone, got=%v ok=%v", second, okSecond)
	}
}
// TestCacheMiss checks that looking up a domain that was never stored yields
// a nil slice and a false hit flag.
func TestCacheMiss(t *testing.T) {
	c := newCache()

	got, ok := c.get("missing.example", 1)
	if ok || got != nil {
		t.Fatalf("expected cache miss, got=%v ok=%v", got, ok)
	}
}

View File

@@ -46,6 +46,7 @@ type DNSForwarder struct {
fwdEntries []*ForwarderEntry
firewall firewaller
resolver resolver
cache *cache
}
func NewDNSForwarder(listenAddress string, ttl uint32, firewall firewaller, statusRecorder *peer.Status) *DNSForwarder {
@@ -56,6 +57,7 @@ func NewDNSForwarder(listenAddress string, ttl uint32, firewall firewaller, stat
firewall: firewall,
statusRecorder: statusRecorder,
resolver: net.DefaultResolver,
cache: newCache(),
}
}
@@ -103,10 +105,39 @@ func (f *DNSForwarder) UpdateDomains(entries []*ForwarderEntry) {
f.mutex.Lock()
defer f.mutex.Unlock()
// remove cache entries for domains that no longer appear
f.removeStaleCacheEntries(f.fwdEntries, entries)
f.fwdEntries = entries
log.Debugf("Updated DNS forwarder with %d domains", len(entries))
}
// removeStaleCacheEntries drops cached results for every domain that appears
// in oldEntries but no longer appears in newEntries. Nil entries in either
// list are skipped, and the call does nothing when the forwarder has no cache.
//
// NOTE(review): the cache is filled with the queried name, so if Domain can be
// a wildcard pattern its Punycode string may never equal a cache key — confirm.
func (f *DNSForwarder) removeStaleCacheEntries(oldEntries, newEntries []*ForwarderEntry) {
	if f.cache == nil {
		return
	}

	// Index the domains that remain configured.
	kept := make(map[string]struct{}, len(newEntries))
	for _, entry := range newEntries {
		if entry != nil {
			kept[entry.Domain.PunycodeString()] = struct{}{}
		}
	}

	// Evict everything that was configured before but is not kept.
	for _, entry := range oldEntries {
		if entry == nil {
			continue
		}
		key := entry.Domain.PunycodeString()
		if _, stillPresent := kept[key]; stillPresent {
			continue
		}
		f.cache.unset(key)
	}
}
func (f *DNSForwarder) Close(ctx context.Context) error {
var result *multierror.Error
@@ -171,6 +202,7 @@ func (f *DNSForwarder) handleDNSQuery(w dns.ResponseWriter, query *dns.Msg) *dns
f.updateInternalState(ips, mostSpecificResId, matchingEntries)
f.addIPsToResponse(resp, domain, ips)
f.cache.set(domain, question.Qtype, ips)
return resp
}
@@ -282,29 +314,69 @@ func (f *DNSForwarder) setResponseCodeForNotFound(ctx context.Context, resp *dns
resp.Rcode = dns.RcodeSuccess
}
// handleDNSError processes DNS lookup errors and sends an appropriate error response
func (f *DNSForwarder) handleDNSError(ctx context.Context, w dns.ResponseWriter, question dns.Question, resp *dns.Msg, domain string, err error) {
// handleDNSError processes DNS lookup errors and sends an appropriate error response.
func (f *DNSForwarder) handleDNSError(
ctx context.Context,
w dns.ResponseWriter,
question dns.Question,
resp *dns.Msg,
domain string,
err error,
) {
// Default to SERVFAIL; override below when appropriate.
resp.Rcode = dns.RcodeServerFailure
qType := question.Qtype
qTypeName := dns.TypeToString[qType]
// Prefer typed DNS errors; fall back to generic logging otherwise.
var dnsErr *net.DNSError
switch {
case errors.As(err, &dnsErr):
resp.Rcode = dns.RcodeServerFailure
if dnsErr.IsNotFound {
f.setResponseCodeForNotFound(ctx, resp, domain, question.Qtype)
if !errors.As(err, &dnsErr) {
log.Warnf(errResolveFailed, domain, err)
if writeErr := w.WriteMsg(resp); writeErr != nil {
log.Errorf("failed to write failure DNS response: %v", writeErr)
}
return
}
if dnsErr.Server != "" {
log.Warnf("failed to resolve query for type=%s domain=%s server=%s: %v", dns.TypeToString[question.Qtype], domain, dnsErr.Server, err)
} else {
log.Warnf(errResolveFailed, domain, err)
// NotFound: set NXDOMAIN / appropriate code via helper.
if dnsErr.IsNotFound {
f.setResponseCodeForNotFound(ctx, resp, domain, qType)
if writeErr := w.WriteMsg(resp); writeErr != nil {
log.Errorf("failed to write failure DNS response: %v", writeErr)
}
default:
resp.Rcode = dns.RcodeServerFailure
f.cache.set(domain, question.Qtype, nil)
return
}
// Upstream failed but we might have a cached answer—serve it if present.
if ips, ok := f.cache.get(domain, qType); ok {
if len(ips) > 0 {
log.Debugf("serving cached DNS response after upstream failure: domain=%s type=%s", domain, qTypeName)
f.addIPsToResponse(resp, domain, ips)
resp.Rcode = dns.RcodeSuccess
if writeErr := w.WriteMsg(resp); writeErr != nil {
log.Errorf("failed to write cached DNS response: %v", writeErr)
}
} else { // send NXDOMAIN / appropriate code if cache is empty
f.setResponseCodeForNotFound(ctx, resp, domain, qType)
if writeErr := w.WriteMsg(resp); writeErr != nil {
log.Errorf("failed to write failure DNS response: %v", writeErr)
}
}
return
}
// No cache. Log with or without the server field for more context.
if dnsErr.Server != "" {
log.Warnf("failed to resolve: type=%s domain=%s server=%s: %v", qTypeName, domain, dnsErr.Server, err)
} else {
log.Warnf(errResolveFailed, domain, err)
}
if err := w.WriteMsg(resp); err != nil {
log.Errorf("failed to write failure DNS response: %v", err)
// Write final failure response.
if writeErr := w.WriteMsg(resp); writeErr != nil {
log.Errorf("failed to write failure DNS response: %v", writeErr)
}
}

View File

@@ -648,6 +648,95 @@ func TestDNSForwarder_TCPTruncation(t *testing.T) {
assert.LessOrEqual(t, writtenResp.Len(), dns.MinMsgSize, "Response should fit in minimum UDP size")
}
// TestDNSForwarder_ServeFromCacheOnUpstreamFailure ensures that when the first
// query succeeds and populates the cache, a subsequent upstream failure still
// returns a successful response from cache.
func TestDNSForwarder_ServeFromCacheOnUpstreamFailure(t *testing.T) {
mockResolver := &MockResolver{}
forwarder := NewDNSForwarder("127.0.0.1:0", 300, nil, &peer.Status{})
forwarder.resolver = mockResolver
d, err := domain.FromString("example.com")
require.NoError(t, err)
entries := []*ForwarderEntry{{Domain: d, ResID: "res-cache"}}
forwarder.UpdateDomains(entries)
ip := netip.MustParseAddr("1.2.3.4")
// First call resolves successfully and populates cache
mockResolver.On("LookupNetIP", mock.Anything, "ip4", dns.Fqdn("example.com")).
Return([]netip.Addr{ip}, nil).Once()
// Second call fails upstream; forwarder should serve from cache
mockResolver.On("LookupNetIP", mock.Anything, "ip4", dns.Fqdn("example.com")).
Return([]netip.Addr{}, &net.DNSError{Err: "temporary failure"}).Once()
// First query: populate cache
q1 := &dns.Msg{}
q1.SetQuestion(dns.Fqdn("example.com"), dns.TypeA)
w1 := &test.MockResponseWriter{}
resp1 := forwarder.handleDNSQuery(w1, q1)
require.NotNil(t, resp1)
require.Equal(t, dns.RcodeSuccess, resp1.Rcode)
require.Len(t, resp1.Answer, 1)
// Second query: serve from cache after upstream failure
q2 := &dns.Msg{}
q2.SetQuestion(dns.Fqdn("example.com"), dns.TypeA)
// Capture the message written on the failure path so it can be inspected below.
var writtenResp *dns.Msg
w2 := &test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { writtenResp = m; return nil }}
_ = forwarder.handleDNSQuery(w2, q2)
require.NotNil(t, writtenResp, "expected response to be written")
require.Equal(t, dns.RcodeSuccess, writtenResp.Rcode)
require.Len(t, writtenResp.Answer, 1)
// Both single-use resolver expectations (success, then failure) must have been consumed.
mockResolver.AssertExpectations(t)
}
// TestDNSForwarder_CacheNormalizationCasingAndDot verifies that cache
// normalization works across casing and trailing dot variations: a result
// cached from a mixed-case, dot-terminated query is served for an upper-case
// query without trailing dot once the upstream lookup fails.
func TestDNSForwarder_CacheNormalizationCasingAndDot(t *testing.T) {
mockResolver := &MockResolver{}
forwarder := NewDNSForwarder("127.0.0.1:0", 300, nil, &peer.Status{})
forwarder.resolver = mockResolver
d, err := domain.FromString("ExAmPlE.CoM")
require.NoError(t, err)
entries := []*ForwarderEntry{{Domain: d, ResID: "res-norm"}}
forwarder.UpdateDomains(entries)
ip := netip.MustParseAddr("9.8.7.6")
// Initial resolution with mixed case to populate cache
mixedQuery := "ExAmPlE.CoM"
mockResolver.On("LookupNetIP", mock.Anything, "ip4", dns.Fqdn(strings.ToLower(mixedQuery))).
Return([]netip.Addr{ip}, nil).Once()
q1 := &dns.Msg{}
q1.SetQuestion(mixedQuery+".", dns.TypeA)
w1 := &test.MockResponseWriter{}
resp1 := forwarder.handleDNSQuery(w1, q1)
require.NotNil(t, resp1)
require.Equal(t, dns.RcodeSuccess, resp1.Rcode)
require.Len(t, resp1.Answer, 1)
// Subsequent query without dot and upper case should hit cache even if upstream fails
// Forwarder lowercases and uses the question name as-is (no trailing dot here)
mockResolver.On("LookupNetIP", mock.Anything, "ip4", strings.ToLower("EXAMPLE.COM")).
Return([]netip.Addr{}, &net.DNSError{Err: "temporary failure"}).Once()
q2 := &dns.Msg{}
q2.SetQuestion("EXAMPLE.COM", dns.TypeA)
// Capture the message written on the failure path so it can be inspected below.
var writtenResp *dns.Msg
w2 := &test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { writtenResp = m; return nil }}
_ = forwarder.handleDNSQuery(w2, q2)
require.NotNil(t, writtenResp)
require.Equal(t, dns.RcodeSuccess, writtenResp.Rcode)
require.Len(t, writtenResp.Answer, 1)
// Both single-use resolver expectations must have been consumed.
mockResolver.AssertExpectations(t)
}
func TestDNSForwarder_MultipleOverlappingPatterns(t *testing.T) {
// Test complex overlapping pattern scenarios
mockFirewall := &MockFirewall{}

View File

@@ -40,7 +40,6 @@ type Manager struct {
fwRules []firewall.Rule
tcpRules []firewall.Rule
dnsForwarder *DNSForwarder
port uint16
}
func ListenPort() uint16 {
@@ -49,11 +48,16 @@ func ListenPort() uint16 {
return listenPort
}
func NewManager(fw firewall.Manager, statusRecorder *peer.Status, port uint16) *Manager {
// SetListenPort replaces the package-wide listen port that ListenPort reports.
// The assignment happens while holding listenPortMu.
func SetListenPort(port uint16) {
	listenPortMu.Lock()
	defer listenPortMu.Unlock()

	listenPort = port
}
func NewManager(fw firewall.Manager, statusRecorder *peer.Status) *Manager {
return &Manager{
firewall: fw,
statusRecorder: statusRecorder,
port: port,
}
}
@@ -67,12 +71,6 @@ func (m *Manager) Start(fwdEntries []*ForwarderEntry) error {
return err
}
if m.port > 0 {
listenPortMu.Lock()
listenPort = m.port
listenPortMu.Unlock()
}
m.dnsForwarder = NewDNSForwarder(fmt.Sprintf(":%d", ListenPort()), dnsTTL, m.firewall, m.statusRecorder)
go func() {
if err := m.dnsForwarder.Listen(fwdEntries); err != nil {

View File

@@ -1849,6 +1849,10 @@ func (e *Engine) updateDNSForwarder(
return
}
if forwarderPort > 0 {
dnsfwd.SetListenPort(forwarderPort)
}
if !enabled {
if e.dnsForwardMgr == nil {
return
@@ -1862,7 +1866,7 @@ func (e *Engine) updateDNSForwarder(
if len(fwdEntries) > 0 {
switch {
case e.dnsForwardMgr == nil:
e.dnsForwardMgr = dnsfwd.NewManager(e.firewall, e.statusRecorder, forwarderPort)
e.dnsForwardMgr = dnsfwd.NewManager(e.firewall, e.statusRecorder)
if err := e.dnsForwardMgr.Start(fwdEntries); err != nil {
log.Errorf("failed to start DNS forward: %v", err)
e.dnsForwardMgr = nil
@@ -1892,7 +1896,7 @@ func (e *Engine) restartDnsFwd(fwdEntries []*dnsfwd.ForwarderEntry, forwarderPor
if err := e.dnsForwardMgr.Stop(context.Background()); err != nil {
log.Errorf("failed to stop DNS forward: %v", err)
}
e.dnsForwardMgr = dnsfwd.NewManager(e.firewall, e.statusRecorder, forwarderPort)
e.dnsForwardMgr = dnsfwd.NewManager(e.firewall, e.statusRecorder)
if err := e.dnsForwardMgr.Start(fwdEntries); err != nil {
log.Errorf("failed to start DNS forward: %v", err)
e.dnsForwardMgr = nil

View File

@@ -171,9 +171,9 @@ func (conn *Conn) Open(engineCtx context.Context) error {
conn.handshaker = NewHandshaker(conn.Log, conn.config, conn.signaler, conn.workerICE, conn.workerRelay)
conn.handshaker.AddOnNewOfferListener(conn.workerRelay.OnNewOffer)
conn.handshaker.AddRelayListener(conn.workerRelay.OnNewOffer)
if !isForceRelayed() {
conn.handshaker.AddOnNewOfferListener(conn.workerICE.OnNewOffer)
conn.handshaker.AddICEListener(conn.workerICE.OnNewOffer)
}
conn.guard = guard.NewGuard(conn.Log, conn.isConnectedOnAllWay, conn.config.Timeout, conn.srWatcher)

View File

@@ -79,10 +79,10 @@ func TestConn_OnRemoteOffer(t *testing.T) {
return
}
onNewOffeChan := make(chan struct{})
onNewOfferChan := make(chan struct{})
conn.handshaker.AddOnNewOfferListener(func(remoteOfferAnswer *OfferAnswer) {
onNewOffeChan <- struct{}{}
conn.handshaker.AddRelayListener(func(remoteOfferAnswer *OfferAnswer) {
onNewOfferChan <- struct{}{}
})
conn.OnRemoteOffer(OfferAnswer{
@@ -98,7 +98,7 @@ func TestConn_OnRemoteOffer(t *testing.T) {
defer cancel()
select {
case <-onNewOffeChan:
case <-onNewOfferChan:
// success
case <-ctx.Done():
t.Error("expected to receive a new offer notification, but timed out")
@@ -118,10 +118,10 @@ func TestConn_OnRemoteAnswer(t *testing.T) {
return
}
onNewOffeChan := make(chan struct{})
onNewOfferChan := make(chan struct{})
conn.handshaker.AddOnNewOfferListener(func(remoteOfferAnswer *OfferAnswer) {
onNewOffeChan <- struct{}{}
conn.handshaker.AddRelayListener(func(remoteOfferAnswer *OfferAnswer) {
onNewOfferChan <- struct{}{}
})
conn.OnRemoteAnswer(OfferAnswer{
@@ -136,7 +136,7 @@ func TestConn_OnRemoteAnswer(t *testing.T) {
defer cancel()
select {
case <-onNewOffeChan:
case <-onNewOfferChan:
// success
case <-ctx.Done():
t.Error("expected to receive a new offer notification, but timed out")

View File

@@ -0,0 +1,20 @@
package guard
import (
"os"
"strconv"
"time"
)
const (
envICEMonitorPeriod = "NB_ICE_MONITOR_PERIOD"
)
// GetICEMonitorPeriod returns the period to use for the ICE monitor.
//
// The NB_ICE_MONITOR_PERIOD environment variable, when set, is interpreted as
// a whole number of seconds. defaultCandidatesMonitorPeriod is returned when
// the variable is unset or empty, is not a valid integer, is zero or negative,
// or is so large that it cannot be represented as a time.Duration.
func GetICEMonitorPeriod() time.Duration {
	// Largest number of whole seconds that fits into a time.Duration (int64
	// nanoseconds). Anything above it would overflow the multiplication below
	// and could yield a non-positive period, which time.NewTicker rejects with
	// a panic.
	const maxSeconds = int64(1<<63-1) / int64(time.Second)

	envVal := os.Getenv(envICEMonitorPeriod)
	if envVal == "" {
		return defaultCandidatesMonitorPeriod
	}

	seconds, err := strconv.ParseInt(envVal, 10, 64)
	if err != nil || seconds <= 0 || seconds > maxSeconds {
		return defaultCandidatesMonitorPeriod
	}

	return time.Duration(seconds) * time.Second
}

View File

@@ -16,8 +16,8 @@ import (
)
const (
candidatesMonitorPeriod = 5 * time.Minute
candidateGatheringTimeout = 5 * time.Second
defaultCandidatesMonitorPeriod = 5 * time.Minute
candidateGatheringTimeout = 5 * time.Second
)
type ICEMonitor struct {
@@ -25,16 +25,19 @@ type ICEMonitor struct {
iFaceDiscover stdnet.ExternalIFaceDiscover
iceConfig icemaker.Config
tickerPeriod time.Duration
currentCandidatesAddress []string
candidatesMu sync.Mutex
}
func NewICEMonitor(iFaceDiscover stdnet.ExternalIFaceDiscover, config icemaker.Config) *ICEMonitor {
func NewICEMonitor(iFaceDiscover stdnet.ExternalIFaceDiscover, config icemaker.Config, period time.Duration) *ICEMonitor {
log.Debugf("prepare ICE monitor with period: %s", period)
cm := &ICEMonitor{
ReconnectCh: make(chan struct{}, 1),
iFaceDiscover: iFaceDiscover,
iceConfig: config,
tickerPeriod: period,
}
return cm
}
@@ -46,7 +49,12 @@ func (cm *ICEMonitor) Start(ctx context.Context, onChanged func()) {
return
}
ticker := time.NewTicker(candidatesMonitorPeriod)
// Initial check to populate the candidates for later comparison
if _, err := cm.handleCandidateTick(ctx, ufrag, pwd); err != nil {
log.Warnf("Failed to check initial ICE candidates: %v", err)
}
ticker := time.NewTicker(cm.tickerPeriod)
defer ticker.Stop()
for {

View File

@@ -51,7 +51,7 @@ func (w *SRWatcher) Start() {
ctx, cancel := context.WithCancel(context.Background())
w.cancelIceMonitor = cancel
iceMonitor := NewICEMonitor(w.iFaceDiscover, w.iceConfig)
iceMonitor := NewICEMonitor(w.iFaceDiscover, w.iceConfig, GetICEMonitorPeriod())
go iceMonitor.Start(ctx, w.onICEChanged)
w.signalClient.SetOnReconnectedListener(w.onReconnected)
w.relayManager.SetOnReconnectedListener(w.onReconnected)

View File

@@ -44,13 +44,19 @@ type OfferAnswer struct {
}
type Handshaker struct {
mu sync.Mutex
log *log.Entry
config ConnConfig
signaler *Signaler
ice *WorkerICE
relay *WorkerRelay
onNewOfferListeners []*OfferListener
mu sync.Mutex
log *log.Entry
config ConnConfig
signaler *Signaler
ice *WorkerICE
relay *WorkerRelay
// relayListener is not blocking because the listener is using a goroutine to process the messages
// and it will only keep the latest message if multiple offers are received in a short time
// this is to avoid blocking the handshaker if the listener is doing some heavy processing
// and also to avoid processing old offers if multiple offers are received in a short time
// the listener will always process the latest offer
relayListener *AsyncOfferListener
iceListener func(remoteOfferAnswer *OfferAnswer)
// remoteOffersCh is a channel used to wait for remote credentials to proceed with the connection
remoteOffersCh chan OfferAnswer
@@ -70,28 +76,39 @@ func NewHandshaker(log *log.Entry, config ConnConfig, signaler *Signaler, ice *W
}
}
func (h *Handshaker) AddOnNewOfferListener(offer func(remoteOfferAnswer *OfferAnswer)) {
l := NewOfferListener(offer)
h.onNewOfferListeners = append(h.onNewOfferListeners, l)
func (h *Handshaker) AddRelayListener(offer func(remoteOfferAnswer *OfferAnswer)) {
h.relayListener = NewAsyncOfferListener(offer)
}
func (h *Handshaker) AddICEListener(offer func(remoteOfferAnswer *OfferAnswer)) {
h.iceListener = offer
}
func (h *Handshaker) Listen(ctx context.Context) {
for {
select {
case remoteOfferAnswer := <-h.remoteOffersCh:
// received confirmation from the remote peer -> ready to proceed
h.log.Infof("received offer, running version %s, remote WireGuard listen port %d, session id: %s", remoteOfferAnswer.Version, remoteOfferAnswer.WgListenPort, remoteOfferAnswer.SessionIDString())
if h.relayListener != nil {
h.relayListener.Notify(&remoteOfferAnswer)
}
if h.iceListener != nil {
h.iceListener(&remoteOfferAnswer)
}
if err := h.sendAnswer(); err != nil {
h.log.Errorf("failed to send remote offer confirmation: %s", err)
continue
}
for _, listener := range h.onNewOfferListeners {
listener.Notify(&remoteOfferAnswer)
}
h.log.Infof("received offer, running version %s, remote WireGuard listen port %d, session id: %s", remoteOfferAnswer.Version, remoteOfferAnswer.WgListenPort, remoteOfferAnswer.SessionIDString())
case remoteOfferAnswer := <-h.remoteAnswerCh:
h.log.Infof("received answer, running version %s, remote WireGuard listen port %d, session id: %s", remoteOfferAnswer.Version, remoteOfferAnswer.WgListenPort, remoteOfferAnswer.SessionIDString())
for _, listener := range h.onNewOfferListeners {
listener.Notify(&remoteOfferAnswer)
if h.relayListener != nil {
h.relayListener.Notify(&remoteOfferAnswer)
}
if h.iceListener != nil {
h.iceListener(&remoteOfferAnswer)
}
case <-ctx.Done():
h.log.Infof("stop listening for remote offers and answers")

View File

@@ -13,20 +13,20 @@ func (oa *OfferAnswer) SessionIDString() string {
return oa.SessionID.String()
}
type OfferListener struct {
type AsyncOfferListener struct {
fn callbackFunc
running bool
latest *OfferAnswer
mu sync.Mutex
}
func NewOfferListener(fn callbackFunc) *OfferListener {
return &OfferListener{
func NewAsyncOfferListener(fn callbackFunc) *AsyncOfferListener {
return &AsyncOfferListener{
fn: fn,
}
}
func (o *OfferListener) Notify(remoteOfferAnswer *OfferAnswer) {
func (o *AsyncOfferListener) Notify(remoteOfferAnswer *OfferAnswer) {
o.mu.Lock()
defer o.mu.Unlock()

View File

@@ -14,7 +14,7 @@ func Test_newOfferListener(t *testing.T) {
runChan <- struct{}{}
}
hl := NewOfferListener(longRunningFn)
hl := NewAsyncOfferListener(longRunningFn)
hl.Notify(dummyOfferAnswer)
hl.Notify(dummyOfferAnswer)

View File

@@ -92,23 +92,16 @@ func NewWorkerICE(ctx context.Context, log *log.Entry, config ConnConfig, conn *
func (w *WorkerICE) OnNewOffer(remoteOfferAnswer *OfferAnswer) {
w.log.Debugf("OnNewOffer for ICE, serial: %s", remoteOfferAnswer.SessionIDString())
w.muxAgent.Lock()
defer w.muxAgent.Unlock()
if w.agentConnecting {
w.log.Debugf("agent connection is in progress, skipping the offer")
w.muxAgent.Unlock()
return
}
if w.agent != nil {
if w.agent != nil || w.agentConnecting {
// backward compatibility with old clients that do not send session ID
if remoteOfferAnswer.SessionID == nil {
w.log.Debugf("agent already exists, skipping the offer")
w.muxAgent.Unlock()
return
}
if w.remoteSessionID == *remoteOfferAnswer.SessionID {
w.log.Debugf("agent already exists and session ID matches, skipping the offer: %s", remoteOfferAnswer.SessionIDString())
w.muxAgent.Unlock()
return
}
w.log.Debugf("agent already exists, recreate the connection")
@@ -116,6 +109,12 @@ func (w *WorkerICE) OnNewOffer(remoteOfferAnswer *OfferAnswer) {
if err := w.agent.Close(); err != nil {
w.log.Warnf("failed to close ICE agent: %s", err)
}
sessionID, err := NewICESessionID()
if err != nil {
w.log.Errorf("failed to create new session ID: %s", err)
}
w.sessionID = sessionID
w.agent = nil
}
@@ -126,18 +125,23 @@ func (w *WorkerICE) OnNewOffer(remoteOfferAnswer *OfferAnswer) {
preferredCandidateTypes = icemaker.CandidateTypes()
}
w.log.Debugf("recreate ICE agent")
if remoteOfferAnswer.SessionID != nil {
w.log.Debugf("recreate ICE agent: %s / %s", w.sessionID, *remoteOfferAnswer.SessionID)
}
dialerCtx, dialerCancel := context.WithCancel(w.ctx)
agent, err := w.reCreateAgent(dialerCancel, preferredCandidateTypes)
if err != nil {
w.log.Errorf("failed to recreate ICE Agent: %s", err)
w.muxAgent.Unlock()
return
}
w.agent = agent
w.agentDialerCancel = dialerCancel
w.agentConnecting = true
w.muxAgent.Unlock()
if remoteOfferAnswer.SessionID != nil {
w.remoteSessionID = *remoteOfferAnswer.SessionID
} else {
w.remoteSessionID = ""
}
go w.connect(dialerCtx, agent, remoteOfferAnswer)
}
@@ -293,9 +297,6 @@ func (w *WorkerICE) connect(ctx context.Context, agent *icemaker.ThreadSafeAgent
w.muxAgent.Lock()
w.agentConnecting = false
w.lastSuccess = time.Now()
if remoteOfferAnswer.SessionID != nil {
w.remoteSessionID = *remoteOfferAnswer.SessionID
}
w.muxAgent.Unlock()
// todo: the potential problem is a race between the onConnectionStateChange
@@ -309,16 +310,17 @@ func (w *WorkerICE) closeAgent(agent *icemaker.ThreadSafeAgent, cancel context.C
}
w.muxAgent.Lock()
// todo review does it make sense to generate new session ID all the time when w.agent==agent
sessionID, err := NewICESessionID()
if err != nil {
w.log.Errorf("failed to create new session ID: %s", err)
}
w.sessionID = sessionID
if w.agent == agent {
// consider to remove from here and move to the OnNewOffer
sessionID, err := NewICESessionID()
if err != nil {
w.log.Errorf("failed to create new session ID: %s", err)
}
w.sessionID = sessionID
w.agent = nil
w.agentConnecting = false
w.remoteSessionID = ""
}
w.muxAgent.Unlock()
}
@@ -395,11 +397,12 @@ func (w *WorkerICE) onConnectionStateChange(agent *icemaker.ThreadSafeAgent, dia
// ice.ConnectionStateClosed happens when we recreate the agent. For the P2P to TURN switch important to
// notify the conn.onICEStateDisconnected changes to update the current used priority
w.closeAgent(agent, dialerCancel)
if w.lastKnownState == ice.ConnectionStateConnected {
w.lastKnownState = ice.ConnectionStateDisconnected
w.conn.onICEStateDisconnected()
}
w.closeAgent(agent, dialerCancel)
default:
return
}

View File

@@ -73,8 +73,8 @@ func (p *RDCleanPathProxy) validateCertificateWithJS(conn *proxyConnection, cert
}
}
func (p *RDCleanPathProxy) getTLSConfigWithValidation(conn *proxyConnection) *tls.Config {
return &tls.Config{
func (p *RDCleanPathProxy) getTLSConfigWithValidation(conn *proxyConnection, requiresCredSSP bool) *tls.Config {
config := &tls.Config{
InsecureSkipVerify: true, // We'll validate manually after handshake
VerifyConnection: func(cs tls.ConnectionState) error {
var certChain [][]byte
@@ -93,4 +93,15 @@ func (p *RDCleanPathProxy) getTLSConfigWithValidation(conn *proxyConnection) *tl
return nil
},
}
// CredSSP (NLA) requires TLS 1.2 - it's incompatible with TLS 1.3
if requiresCredSSP {
config.MinVersion = tls.VersionTLS12
config.MaxVersion = tls.VersionTLS12
} else {
config.MinVersion = tls.VersionTLS12
config.MaxVersion = tls.VersionTLS13
}
return config
}

View File

@@ -6,11 +6,13 @@ import (
"context"
"crypto/tls"
"encoding/asn1"
"errors"
"fmt"
"io"
"net"
"sync"
"syscall/js"
"time"
log "github.com/sirupsen/logrus"
)
@@ -19,18 +21,34 @@ const (
RDCleanPathVersion = 3390
RDCleanPathProxyHost = "rdcleanpath.proxy.local"
RDCleanPathProxyScheme = "ws"
rdpDialTimeout = 15 * time.Second
GeneralErrorCode = 1
WSAETimedOut = 10060
WSAEConnRefused = 10061
WSAEConnAborted = 10053
WSAEConnReset = 10054
WSAEGenericError = 10050
)
type RDCleanPathPDU struct {
Version int64 `asn1:"tag:0,explicit"`
Error []byte `asn1:"tag:1,explicit,optional"`
Destination string `asn1:"utf8,tag:2,explicit,optional"`
ProxyAuth string `asn1:"utf8,tag:3,explicit,optional"`
ServerAuth string `asn1:"utf8,tag:4,explicit,optional"`
PreconnectionBlob string `asn1:"utf8,tag:5,explicit,optional"`
X224ConnectionPDU []byte `asn1:"tag:6,explicit,optional"`
ServerCertChain [][]byte `asn1:"tag:7,explicit,optional"`
ServerAddr string `asn1:"utf8,tag:9,explicit,optional"`
Version int64 `asn1:"tag:0,explicit"`
Error RDCleanPathErr `asn1:"tag:1,explicit,optional"`
Destination string `asn1:"utf8,tag:2,explicit,optional"`
ProxyAuth string `asn1:"utf8,tag:3,explicit,optional"`
ServerAuth string `asn1:"utf8,tag:4,explicit,optional"`
PreconnectionBlob string `asn1:"utf8,tag:5,explicit,optional"`
X224ConnectionPDU []byte `asn1:"tag:6,explicit,optional"`
ServerCertChain [][]byte `asn1:"tag:7,explicit,optional"`
ServerAddr string `asn1:"utf8,tag:9,explicit,optional"`
}
// RDCleanPathErr is the structured error value used for the Error field of
// RDCleanPathPDU. Every field is encoded with an explicit ASN.1 context tag
// (0-3); only ErrorCode is mandatory, the remaining fields are optional.
type RDCleanPathErr struct {
// ErrorCode is the error category; the constructors in this file set it to GeneralErrorCode.
ErrorCode int16 `asn1:"tag:0,explicit"`
// HTTPStatusCode optionally carries a status code (filled in by newHTTPError).
HTTPStatusCode int16 `asn1:"tag:1,explicit,optional"`
// WSALastError optionally carries one of the WSAE* codes (filled in by newWSAError).
WSALastError int16 `asn1:"tag:2,explicit,optional"`
// TLSAlertCode is optional; presumably a TLS alert number — TODO confirm, no visible code sets it.
TLSAlertCode int8 `asn1:"tag:3,explicit,optional"`
}
type RDCleanPathProxy struct {
@@ -210,9 +228,13 @@ func (p *RDCleanPathProxy) handleDirectRDP(conn *proxyConnection, firstPacket []
destination := conn.destination
log.Infof("Direct RDP mode: Connecting to %s via NetBird", destination)
rdpConn, err := p.nbClient.Dial(conn.ctx, "tcp", destination)
ctx, cancel := context.WithTimeout(conn.ctx, rdpDialTimeout)
defer cancel()
rdpConn, err := p.nbClient.Dial(ctx, "tcp", destination)
if err != nil {
log.Errorf("Failed to connect to %s: %v", destination, err)
p.sendRDCleanPathError(conn, newWSAError(err))
return
}
conn.rdpConn = rdpConn
@@ -220,6 +242,7 @@ func (p *RDCleanPathProxy) handleDirectRDP(conn *proxyConnection, firstPacket []
_, err = rdpConn.Write(firstPacket)
if err != nil {
log.Errorf("Failed to write first packet: %v", err)
p.sendRDCleanPathError(conn, newWSAError(err))
return
}
@@ -227,6 +250,7 @@ func (p *RDCleanPathProxy) handleDirectRDP(conn *proxyConnection, firstPacket []
n, err := rdpConn.Read(response)
if err != nil {
log.Errorf("Failed to read X.224 response: %v", err)
p.sendRDCleanPathError(conn, newWSAError(err))
return
}
@@ -269,3 +293,52 @@ func (p *RDCleanPathProxy) sendToWebSocket(conn *proxyConnection, data []byte) {
conn.wsHandlers.Call("send", uint8Array.Get("buffer"))
}
}
// sendRDCleanPathError ASN.1-encodes the given error PDU and pushes the
// encoded bytes to the connection's WebSocket. If encoding fails the failure
// is logged and nothing is sent.
func (p *RDCleanPathProxy) sendRDCleanPathError(conn *proxyConnection, pdu RDCleanPathPDU) {
	encoded, marshalErr := asn1.Marshal(pdu)
	if marshalErr == nil {
		p.sendToWebSocket(conn, encoded)
		return
	}

	log.Errorf("Failed to marshal error PDU: %v", marshalErr)
}
// errorToWSACode translates a Go error into the Winsock-style code that is
// reported to the client in RDCleanPathErr.WSALastError.
//
// Mapping, evaluated in order:
//   - nil error                          -> WSAEGenericError
//   - any net.Error reporting a timeout  -> WSAETimedOut
//   - context.DeadlineExceeded           -> WSAETimedOut
//   - context.Canceled                   -> WSAEConnAborted
//   - io.EOF                             -> WSAEConnReset
//   - everything else                    -> WSAEGenericError
func errorToWSACode(err error) int16 {
	if err == nil {
		return WSAEGenericError
	}

	// Match on the net.Error interface rather than the concrete *net.OpError:
	// *net.OpError still satisfies it, and timeouts surfaced through other
	// error types (anything in the chain exposing Timeout()) are no longer
	// misreported as a generic failure.
	var netErr net.Error
	switch {
	case errors.As(err, &netErr) && netErr.Timeout():
		return WSAETimedOut
	case errors.Is(err, context.DeadlineExceeded):
		return WSAETimedOut
	case errors.Is(err, context.Canceled):
		return WSAEConnAborted
	case errors.Is(err, io.EOF):
		return WSAEConnReset
	default:
		return WSAEGenericError
	}
}
// newWSAError builds an error PDU for a transport-level failure: the error
// category is GeneralErrorCode and WSALastError carries the code that
// errorToWSACode derives from err.
func newWSAError(err error) RDCleanPathPDU {
	var pdu RDCleanPathPDU
	pdu.Version = RDCleanPathVersion
	pdu.Error.ErrorCode = GeneralErrorCode
	pdu.Error.WSALastError = errorToWSACode(err)
	return pdu
}
// newHTTPError builds an error PDU that reports an HTTP-style status: the
// error category is GeneralErrorCode and HTTPStatusCode carries statusCode.
func newHTTPError(statusCode int16) RDCleanPathPDU {
	var pdu RDCleanPathPDU
	pdu.Version = RDCleanPathVersion
	pdu.Error.ErrorCode = GeneralErrorCode
	pdu.Error.HTTPStatusCode = statusCode
	return pdu
}

View File

@@ -3,6 +3,7 @@
package rdp
import (
"context"
"crypto/tls"
"encoding/asn1"
"io"
@@ -11,11 +12,17 @@ import (
log "github.com/sirupsen/logrus"
)
// Security-protocol bits that detectCredSSPFromX224 tests against the
// selectedProtocol value of the X.224 negotiation response; a match on either
// bit is reported as "requires TLS 1.2".
const (
// MS-RDPBCGR: confusingly named, actually means PROTOCOL_HYBRID (CredSSP)
// NOTE(review): MS-RDPBCGR appears to list PROTOCOL_SSL = 0x00000001,
// PROTOCOL_HYBRID = 0x00000002 and PROTOCOL_HYBRID_EX = 0x00000008. Confirm
// whether the statement above is accurate and whether 0x00000002 should
// also be part of the CredSSP check.
protocolSSL = 0x00000001
protocolHybridEx = 0x00000008
)
func (p *RDCleanPathProxy) processRDCleanPathPDU(conn *proxyConnection, pdu RDCleanPathPDU) {
log.Infof("Processing RDCleanPath PDU: Version=%d, Destination=%s", pdu.Version, pdu.Destination)
if pdu.Version != RDCleanPathVersion {
p.sendRDCleanPathError(conn, "Unsupported version")
p.sendRDCleanPathError(conn, newHTTPError(400))
return
}
@@ -24,10 +31,13 @@ func (p *RDCleanPathProxy) processRDCleanPathPDU(conn *proxyConnection, pdu RDCl
destination = pdu.Destination
}
rdpConn, err := p.nbClient.Dial(conn.ctx, "tcp", destination)
ctx, cancel := context.WithTimeout(conn.ctx, rdpDialTimeout)
defer cancel()
rdpConn, err := p.nbClient.Dial(ctx, "tcp", destination)
if err != nil {
log.Errorf("Failed to connect to %s: %v", destination, err)
p.sendRDCleanPathError(conn, "Connection failed")
p.sendRDCleanPathError(conn, newWSAError(err))
p.cleanupConnection(conn)
return
}
@@ -40,6 +50,34 @@ func (p *RDCleanPathProxy) processRDCleanPathPDU(conn *proxyConnection, pdu RDCl
p.setupTLSConnection(conn, pdu)
}
// detectCredSSPFromX224 inspects the server's X.224 response for an RDP
// negotiation response and reports which security protocol was selected.
//
// Byte offsets the check relies on:
//
//	0      TPKT version, must be 0x03
//	5      X.224 TPDU code, must be 0xD0 (Connection Confirm)
//	11     negotiation message type, 0x02 = TYPE_RDP_NEG_RSP
//	15-18  selectedProtocol, little-endian uint32
//
// Returns (requiresTLS12, selectedProtocol, detectionSuccessful).
// requiresTLS12 is true when selectedProtocol has the protocolSSL or
// protocolHybridEx bit set. When the response is too short, has an unexpected
// header, or is not a negotiation response, the result is (false, 0, false).
// The receiver is not used.
func (p *RDCleanPathProxy) detectCredSSPFromX224(x224Response []byte) (bool, uint32, bool) {
	const (
		minResponseLength = 19
		tpktVersion       = 0x03
		tpduConnConfirm   = 0xD0
		negRspType        = 0x02
		protoFirst        = 15
		protoLast         = 18
	)

	if len(x224Response) < minResponseLength {
		return false, 0, false
	}

	headerOK := x224Response[0] == tpktVersion && x224Response[5] == tpduConnConfirm
	if !headerOK || x224Response[11] != negRspType {
		return false, 0, false
	}

	// Fold the four bytes from most to least significant to assemble the
	// little-endian value.
	var selected uint32
	for i := protoLast; i >= protoFirst; i-- {
		selected = selected<<8 | uint32(x224Response[i])
	}

	needsTLS12 := selected&(protocolSSL|protocolHybridEx) != 0
	return needsTLS12, selected, true
}
func (p *RDCleanPathProxy) setupTLSConnection(conn *proxyConnection, pdu RDCleanPathPDU) {
var x224Response []byte
if len(pdu.X224ConnectionPDU) > 0 {
@@ -47,7 +85,7 @@ func (p *RDCleanPathProxy) setupTLSConnection(conn *proxyConnection, pdu RDClean
_, err := conn.rdpConn.Write(pdu.X224ConnectionPDU)
if err != nil {
log.Errorf("Failed to write X.224 PDU: %v", err)
p.sendRDCleanPathError(conn, "Failed to forward X.224")
p.sendRDCleanPathError(conn, newWSAError(err))
return
}
@@ -55,21 +93,32 @@ func (p *RDCleanPathProxy) setupTLSConnection(conn *proxyConnection, pdu RDClean
n, err := conn.rdpConn.Read(response)
if err != nil {
log.Errorf("Failed to read X.224 response: %v", err)
p.sendRDCleanPathError(conn, "Failed to read X.224 response")
p.sendRDCleanPathError(conn, newWSAError(err))
return
}
x224Response = response[:n]
log.Debugf("Received X.224 Connection Confirm (%d bytes)", n)
}
tlsConfig := p.getTLSConfigWithValidation(conn)
requiresCredSSP, selectedProtocol, detected := p.detectCredSSPFromX224(x224Response)
if detected {
if requiresCredSSP {
log.Warnf("Detected NLA/CredSSP (selectedProtocol: 0x%08X), forcing TLS 1.2 for compatibility", selectedProtocol)
} else {
log.Warnf("No NLA/CredSSP detected (selectedProtocol: 0x%08X), allowing up to TLS 1.3", selectedProtocol)
}
} else {
log.Warnf("Could not detect RDP security protocol, allowing up to TLS 1.3")
}
tlsConfig := p.getTLSConfigWithValidation(conn, requiresCredSSP)
tlsConn := tls.Client(conn.rdpConn, tlsConfig)
conn.tlsConn = tlsConn
if err := tlsConn.Handshake(); err != nil {
log.Errorf("TLS handshake failed: %v", err)
p.sendRDCleanPathError(conn, "TLS handshake failed")
p.sendRDCleanPathError(conn, newWSAError(err))
return
}
@@ -106,47 +155,6 @@ func (p *RDCleanPathProxy) setupTLSConnection(conn *proxyConnection, pdu RDClean
p.cleanupConnection(conn)
}
func (p *RDCleanPathProxy) setupPlainConnection(conn *proxyConnection, pdu RDCleanPathPDU) {
if len(pdu.X224ConnectionPDU) > 0 {
log.Debugf("Forwarding X.224 Connection Request (%d bytes)", len(pdu.X224ConnectionPDU))
_, err := conn.rdpConn.Write(pdu.X224ConnectionPDU)
if err != nil {
log.Errorf("Failed to write X.224 PDU: %v", err)
p.sendRDCleanPathError(conn, "Failed to forward X.224")
return
}
response := make([]byte, 1024)
n, err := conn.rdpConn.Read(response)
if err != nil {
log.Errorf("Failed to read X.224 response: %v", err)
p.sendRDCleanPathError(conn, "Failed to read X.224 response")
return
}
responsePDU := RDCleanPathPDU{
Version: RDCleanPathVersion,
X224ConnectionPDU: response[:n],
ServerAddr: conn.destination,
}
p.sendRDCleanPathPDU(conn, responsePDU)
} else {
responsePDU := RDCleanPathPDU{
Version: RDCleanPathVersion,
ServerAddr: conn.destination,
}
p.sendRDCleanPathPDU(conn, responsePDU)
}
go p.forwardConnToWS(conn, conn.rdpConn, "TCP")
go p.forwardWSToConn(conn, conn.rdpConn, "TCP")
<-conn.ctx.Done()
log.Debug("TCP connection context done, cleaning up")
p.cleanupConnection(conn)
}
func (p *RDCleanPathProxy) sendRDCleanPathPDU(conn *proxyConnection, pdu RDCleanPathPDU) {
data, err := asn1.Marshal(pdu)
if err != nil {
@@ -158,21 +166,6 @@ func (p *RDCleanPathProxy) sendRDCleanPathPDU(conn *proxyConnection, pdu RDClean
p.sendToWebSocket(conn, data)
}
func (p *RDCleanPathProxy) sendRDCleanPathError(conn *proxyConnection, errorMsg string) {
pdu := RDCleanPathPDU{
Version: RDCleanPathVersion,
Error: []byte(errorMsg),
}
data, err := asn1.Marshal(pdu)
if err != nil {
log.Errorf("Failed to marshal error PDU: %v", err)
return
}
p.sendToWebSocket(conn, data)
}
func (p *RDCleanPathProxy) readWebSocketMessage(conn *proxyConnection) ([]byte, error) {
msgChan := make(chan []byte)
errChan := make(chan error)

2
go.mod
View File

@@ -62,7 +62,7 @@ require (
github.com/miekg/dns v1.1.59
github.com/mitchellh/hashstructure/v2 v2.0.2
github.com/nadoo/ipset v0.5.0
github.com/netbirdio/management-integrations/integrations v0.0.0-20250906095204-f87a07690ba0
github.com/netbirdio/management-integrations/integrations v0.0.0-20251010134843-7af36217ac1f
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45
github.com/okta/okta-sdk-golang/v2 v2.18.0
github.com/oschwald/maxminddb-golang v1.12.0

4
go.sum
View File

@@ -503,8 +503,8 @@ github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944 h1:TDtJKmM6S
github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944/go.mod h1:sHA6TRxjQ6RLbnI+3R4DZo2Eseg/iKiPRfNmcuNySVQ=
github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51 h1:Ov4qdafATOgGMB1wbSuh+0aAHcwz9hdvB6VZjh1mVMI=
github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51/go.mod h1:ZSIbPdBn5hePO8CpF1PekH2SfpTxg1PDhEwtbqZS7R8=
github.com/netbirdio/management-integrations/integrations v0.0.0-20250906095204-f87a07690ba0 h1:9BUqQHPVOGr0edk8EifUBUfTr2Ob0ypAPxtasUApBxQ=
github.com/netbirdio/management-integrations/integrations v0.0.0-20250906095204-f87a07690ba0/go.mod h1:v0nUbbHbuQnqR7yKIYnKzsLBCswLtp2JctmKYmGgVhc=
github.com/netbirdio/management-integrations/integrations v0.0.0-20251010134843-7af36217ac1f h1:XIpRDlpPz3zFUkpwaqDRHjwpQRsf2ZKHggoex1MTafs=
github.com/netbirdio/management-integrations/integrations v0.0.0-20251010134843-7af36217ac1f/go.mod h1:v0nUbbHbuQnqR7yKIYnKzsLBCswLtp2JctmKYmGgVhc=
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9axERMVN63dqyFqnvuD+EMJHzM7mNGON8=
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM=
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45 h1:ujgviVYmx243Ksy7NdSwrdGPSRNE3pb8kEDSpH0QuAQ=

View File

@@ -49,6 +49,7 @@ services:
- traefik.http.routers.netbird-wsproxy-signal.service=netbird-wsproxy-signal
- traefik.http.services.netbird-wsproxy-signal.loadbalancer.server.port=80
- traefik.http.routers.netbird-signal.rule=Host(`$NETBIRD_DOMAIN`) && PathPrefix(`/signalexchange.SignalExchange/`)
- traefik.http.routers.netbird-signal.service=netbird-signal
- traefik.http.services.netbird-signal.loadbalancer.server.port=10000
- traefik.http.services.netbird-signal.loadbalancer.server.scheme=h2c

View File

@@ -86,7 +86,7 @@ func NewServer(
if appMetrics != nil {
// update gauge based on number of connected peers which is equal to open gRPC streams
err = appMetrics.GRPCMetrics().RegisterConnectedStreams(func() int64 {
return int64(peersUpdateManager.GetChannelCount())
return int64(len(peersUpdateManager.peerChannels))
})
if err != nil {
return nil, err

View File

@@ -136,7 +136,7 @@ func (a MockIntegratedValidator) GetValidatedPeers(_ context.Context, accountID
return validatedPeers, nil
}
func (MockIntegratedValidator) PreparePeer(_ context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *types.ExtraSettings) *nbpeer.Peer {
func (MockIntegratedValidator) PreparePeer(_ context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *types.ExtraSettings, temporary bool) *nbpeer.Peer {
return peer
}

View File

@@ -3,16 +3,16 @@ package integrated_validator
import (
"context"
"github.com/netbirdio/netbird/shared/management/proto"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/management/proto"
)
// IntegratedValidator interface exists to avoid the circle dependencies
type IntegratedValidator interface {
ValidateExtraSettings(ctx context.Context, newExtraSettings *types.ExtraSettings, oldExtraSettings *types.ExtraSettings, peers map[string]*nbpeer.Peer, userID string, accountID string) error
ValidatePeer(ctx context.Context, update *nbpeer.Peer, peer *nbpeer.Peer, userID string, accountID string, dnsDomain string, peersGroup []string, extraSettings *types.ExtraSettings) (*nbpeer.Peer, bool, error)
PreparePeer(ctx context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *types.ExtraSettings) *nbpeer.Peer
PreparePeer(ctx context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *types.ExtraSettings, temporary bool) *nbpeer.Peer
IsNotValidPeer(ctx context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *types.ExtraSettings) (bool, bool, error)
GetValidatedPeers(ctx context.Context, accountID string, groups []*types.Group, peers []*nbpeer.Peer, extraSettings *types.ExtraSettings) (map[string]struct{}, error)
PeerDeleted(ctx context.Context, accountID, peerID string, extraSettings *types.ExtraSettings) error

View File

@@ -350,7 +350,6 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer
}
var peer *nbpeer.Peer
var updateAccountPeers bool
var eventsToStore []func()
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
@@ -363,11 +362,6 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer
return err
}
updateAccountPeers, err = isPeerInActiveGroup(ctx, transaction, accountID, peerID)
if err != nil {
return err
}
eventsToStore, err = deletePeers(ctx, am, transaction, accountID, userID, []*nbpeer.Peer{peer})
if err != nil {
return fmt.Errorf("failed to delete peer: %w", err)
@@ -387,7 +381,7 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer
storeEvent()
}
if updateAccountPeers && userID != activity.SystemInitiator {
if userID != activity.SystemInitiator {
am.BufferUpdateAccountPeers(ctx, accountID)
}
@@ -584,7 +578,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, accountID, setupKe
}
}
newPeer = am.integratedPeerValidator.PreparePeer(ctx, accountID, newPeer, groupsToAdd, settings.Extra)
newPeer = am.integratedPeerValidator.PreparePeer(ctx, accountID, newPeer, groupsToAdd, settings.Extra, temporary)
network, err := am.Store.GetAccountNetwork(ctx, store.LockingStrengthNone, accountID)
if err != nil {
@@ -684,11 +678,6 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, accountID, setupKe
return nil, nil, nil, fmt.Errorf("failed to add peer to database after %d attempts: %w", maxAttempts, err)
}
updateAccountPeers, err := isPeerInActiveGroup(ctx, am.Store, accountID, newPeer.ID)
if err != nil {
updateAccountPeers = true
}
if newPeer == nil {
return nil, nil, nil, fmt.Errorf("new peer is nil")
}
@@ -701,9 +690,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, accountID, setupKe
am.StoreEvent(ctx, opEvent.InitiatorID, opEvent.TargetID, opEvent.AccountID, opEvent.Activity, opEvent.Meta)
if updateAccountPeers {
am.BufferUpdateAccountPeers(ctx, accountID)
}
am.BufferUpdateAccountPeers(ctx, accountID)
return am.getValidatedPeerWithMap(ctx, false, accountID, newPeer)
}
@@ -1270,10 +1257,12 @@ func (am *DefaultAccountManager) UpdateAccountPeers(ctx context.Context, account
update := toSyncResponse(ctx, nil, p, nil, nil, remotePeerNetworkMap, dnsDomain, postureChecks, dnsCache, account.Settings, extraSetting, maps.Keys(peerGroups), dnsFwdPort)
am.metrics.UpdateChannelMetrics().CountToSyncResponseDuration(time.Since(start))
am.peersUpdateManager.SendUpdate(ctx, p.ID, &UpdateMessage{Update: update})
am.peersUpdateManager.SendUpdate(ctx, p.ID, &UpdateMessage{Update: update, NetworkMap: remotePeerNetworkMap})
}(peer)
}
//
wg.Wait()
if am.metrics != nil {
am.metrics.AccountManagerMetrics().CountUpdateAccountPeersDuration(time.Since(globalStart))
@@ -1379,7 +1368,7 @@ func (am *DefaultAccountManager) UpdateAccountPeer(ctx context.Context, accountI
dnsFwdPort := computeForwarderPort(maps.Values(account.Peers), dnsForwarderPortMinVersion)
update := toSyncResponse(ctx, nil, peer, nil, nil, remotePeerNetworkMap, dnsDomain, postureChecks, dnsCache, account.Settings, extraSettings, maps.Keys(peerGroups), dnsFwdPort)
am.peersUpdateManager.SendUpdate(ctx, peer.ID, &UpdateMessage{Update: update})
am.peersUpdateManager.SendUpdate(ctx, peer.ID, &UpdateMessage{Update: update, NetworkMap: remotePeerNetworkMap})
}
// getNextPeerExpiration returns the minimum duration in which the next peer of the account will expire if it was found.
@@ -1525,16 +1514,6 @@ func getPeerGroupIDs(ctx context.Context, transaction store.Store, accountID str
return transaction.GetPeerGroupIDs(ctx, store.LockingStrengthNone, accountID, peerID)
}
// IsPeerInActiveGroup checks if the given peer is part of a group that is used
// in an active DNS, route, or ACL configuration.
func isPeerInActiveGroup(ctx context.Context, transaction store.Store, accountID, peerID string) (bool, error) {
peerGroupIDs, err := getPeerGroupIDs(ctx, transaction, accountID, peerID)
if err != nil {
return false, err
}
return areGroupChangesAffectPeers(ctx, transaction, accountID, peerGroupIDs) // TODO: use transaction
}
// deletePeers deletes all specified peers and sends updates to the remote peers.
// Returns a slice of functions to save events after successful peer deletion.
func deletePeers(ctx context.Context, am *DefaultAccountManager, transaction store.Store, accountID, userID string, peers []*nbpeer.Peer) ([]func(), error) {
@@ -1601,6 +1580,7 @@ func deletePeers(ctx context.Context, am *DefaultAccountManager, transaction sto
},
},
},
NetworkMap: &types.NetworkMap{},
})
am.peersUpdateManager.CloseChannel(ctx, peer.ID)
peerDeletedEvents = append(peerDeletedEvents, func() {

View File

@@ -1043,8 +1043,8 @@ func TestUpdateAccountPeers(t *testing.T) {
for _, channel := range peerChannels {
update := <-channel
assert.Nil(t, update.Update.NetbirdConfig)
// assert.Equal(t, tc.peers, len(update.NetworkMap.Peers))
// assert.Equal(t, tc.peers*2, len(update.NetworkMap.FirewallRules))
assert.Equal(t, tc.peers, len(update.NetworkMap.Peers))
assert.Equal(t, tc.peers*2, len(update.NetworkMap.FirewallRules))
}
})
}
@@ -1790,7 +1790,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
t.Run("adding peer to unlinked group", func(t *testing.T) {
done := make(chan struct{})
go func() {
peerShouldNotReceiveUpdate(t, updMsg) //
peerShouldReceiveUpdate(t, updMsg) //
close(done)
}()
@@ -1815,7 +1815,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
t.Run("deleting peer with unlinked group", func(t *testing.T) {
done := make(chan struct{})
go func() {
peerShouldNotReceiveUpdate(t, updMsg)
peerShouldReceiveUpdate(t, updMsg)
close(done)
}()

View File

@@ -0,0 +1,129 @@
package cache
import (
"context"
"sync"
)
// DualKeyCache is an in-memory cache whose entries are addressed by a primary
// key and additionally grouped under a secondary key:
//   - the primary key (e.g. an object ID) reads and invalidates one entry;
//   - the secondary key (e.g. an account ID) invalidates every entry that was
//     stored under it in one call.
//
// Every method acquires the internal read/write mutex. The ctx parameters are
// accepted but not used by the current implementation.
type DualKeyCache[K1 comparable, K2 comparable, V any] struct {
	mu             sync.RWMutex
	primaryIndex   map[K1]V               // primary key -> value
	secondaryIndex map[K2]map[K1]struct{} // secondary key -> set of primary keys
	reverseLookup  map[K1]K2              // primary key -> secondary key it is grouped under
}

// NewDualKeyCache returns an empty, ready-to-use DualKeyCache.
func NewDualKeyCache[K1 comparable, K2 comparable, V any]() *DualKeyCache[K1, K2, V] {
	c := new(DualKeyCache[K1, K2, V])
	c.resetLocked()
	return c
}

// resetLocked replaces all indexes with fresh empty maps. The caller must
// hold the write lock or otherwise have exclusive access to c.
func (c *DualKeyCache[K1, K2, V]) resetLocked() {
	c.primaryIndex = make(map[K1]V)
	c.secondaryIndex = make(map[K2]map[K1]struct{})
	c.reverseLookup = make(map[K1]K2)
}

// unlinkLocked removes primaryKey from the secondary-key set it is currently
// registered in (if any) and drops that set once it is empty. It leaves
// primaryIndex and reverseLookup untouched. The caller must hold the write lock.
func (c *DualKeyCache[K1, K2, V]) unlinkLocked(primaryKey K1) {
	group, linked := c.reverseLookup[primaryKey]
	if !linked {
		return
	}
	members, ok := c.secondaryIndex[group]
	if !ok {
		return
	}
	delete(members, primaryKey)
	if len(members) == 0 {
		delete(c.secondaryIndex, group)
	}
}

// Get returns the value stored under primaryKey and whether it was present.
func (c *DualKeyCache[K1, K2, V]) Get(ctx context.Context, primaryKey K1) (V, bool) {
	c.mu.RLock()
	defer c.mu.RUnlock()

	cached, found := c.primaryIndex[primaryKey]
	return cached, found
}

// Set stores value under primaryKey and groups it under secondaryKey. If the
// primary key was already present, it is first detached from its previous
// secondary key.
func (c *DualKeyCache[K1, K2, V]) Set(ctx context.Context, primaryKey K1, secondaryKey K2, value V) {
	c.mu.Lock()
	defer c.mu.Unlock()

	c.unlinkLocked(primaryKey)

	c.primaryIndex[primaryKey] = value
	c.reverseLookup[primaryKey] = secondaryKey

	members := c.secondaryIndex[secondaryKey]
	if members == nil {
		members = make(map[K1]struct{})
		c.secondaryIndex[secondaryKey] = members
	}
	members[primaryKey] = struct{}{}
}

// InvalidateByPrimaryKey removes the entry stored under primaryKey, if any.
func (c *DualKeyCache[K1, K2, V]) InvalidateByPrimaryKey(ctx context.Context, primaryKey K1) {
	c.mu.Lock()
	defer c.mu.Unlock()

	c.unlinkLocked(primaryKey)
	delete(c.reverseLookup, primaryKey)
	delete(c.primaryIndex, primaryKey)
}

// InvalidateBySecondaryKey removes every entry grouped under secondaryKey.
func (c *DualKeyCache[K1, K2, V]) InvalidateBySecondaryKey(ctx context.Context, secondaryKey K2) {
	c.mu.Lock()
	defer c.mu.Unlock()

	// Ranging over a missing (nil) set is a no-op, as is deleting a missing key.
	for member := range c.secondaryIndex[secondaryKey] {
		delete(c.primaryIndex, member)
		delete(c.reverseLookup, member)
	}
	delete(c.secondaryIndex, secondaryKey)
}

// InvalidateAll drops every entry.
func (c *DualKeyCache[K1, K2, V]) InvalidateAll(ctx context.Context) {
	c.mu.Lock()
	defer c.mu.Unlock()

	c.resetLocked()
}

// Size reports how many entries are currently cached.
func (c *DualKeyCache[K1, K2, V]) Size() int {
	c.mu.RLock()
	defer c.mu.RUnlock()

	return len(c.primaryIndex)
}

// GetOrSet returns the value cached under primaryKey. On a miss it calls
// loadFunc, which must return the value together with the secondary key to
// group it under, stores the result and returns it. If loadFunc fails, the
// zero value and the error are returned and nothing is cached.
//
// The lookup and the store are separate critical sections, so concurrent
// callers that miss on the same key may each invoke loadFunc.
func (c *DualKeyCache[K1, K2, V]) GetOrSet(ctx context.Context, primaryKey K1, loadFunc func() (V, K2, error)) (V, error) {
	if cached, hit := c.Get(ctx, primaryKey); hit {
		return cached, nil
	}

	loaded, group, err := loadFunc()
	if err != nil {
		var none V
		return none, err
	}

	c.Set(ctx, primaryKey, group, loaded)
	return loaded, nil
}

View File

@@ -0,0 +1,77 @@
package cache
import (
"context"
"sync"
)
// SingleKeyCache is a minimal in-memory key/value cache. Every method
// acquires the internal read/write mutex. The ctx parameters are accepted but
// not used by the current implementation.
type SingleKeyCache[K comparable, V any] struct {
	mu    sync.RWMutex
	cache map[K]V // key -> cached value
}

// NewSingleKeyCache returns an empty, ready-to-use SingleKeyCache.
func NewSingleKeyCache[K comparable, V any]() *SingleKeyCache[K, V] {
	c := new(SingleKeyCache[K, V])
	c.cache = make(map[K]V)
	return c
}

// Get returns the value stored under key and whether it was present.
func (c *SingleKeyCache[K, V]) Get(ctx context.Context, key K) (V, bool) {
	c.mu.RLock()
	defer c.mu.RUnlock()

	cached, found := c.cache[key]
	return cached, found
}

// Set stores value under key, replacing any previous value.
func (c *SingleKeyCache[K, V]) Set(ctx context.Context, key K, value V) {
	c.mu.Lock()
	defer c.mu.Unlock()

	c.cache[key] = value
}

// Invalidate removes the entry stored under key, if any.
func (c *SingleKeyCache[K, V]) Invalidate(ctx context.Context, key K) {
	c.mu.Lock()
	defer c.mu.Unlock()

	delete(c.cache, key)
}

// InvalidateAll drops every entry by swapping in a fresh map.
func (c *SingleKeyCache[K, V]) InvalidateAll(ctx context.Context) {
	c.mu.Lock()
	defer c.mu.Unlock()

	c.cache = make(map[K]V)
}

// Size reports how many entries are currently cached.
func (c *SingleKeyCache[K, V]) Size() int {
	c.mu.RLock()
	defer c.mu.RUnlock()

	return len(c.cache)
}

// GetOrSet returns the value cached under key. On a miss it calls loadFunc,
// stores the result and returns it. If loadFunc fails, the zero value and the
// error are returned and nothing is cached.
//
// The lookup and the store are separate critical sections, so concurrent
// callers that miss on the same key may each invoke loadFunc.
func (c *SingleKeyCache[K, V]) GetOrSet(ctx context.Context, key K, loadFunc func() (V, error)) (V, error) {
	if cached, hit := c.Get(ctx, key); hit {
		return cached, nil
	}

	loaded, err := loadFunc()
	if err != nil {
		var none V
		return none, err
	}

	c.Set(ctx, key, loaded)
	return loaded, nil
}

View File

@@ -0,0 +1,242 @@
package cache
import (
"context"
"sync"
)
// TripleKeyCache is an in-memory cache whose entries are addressed by a
// primary key and additionally grouped under two independent keys:
//   - the primary key (K1) reads and invalidates one entry;
//   - the secondary key (K2) invalidates every entry stored under it;
//   - the tertiary key (K3) invalidates every entry stored under it.
//
// Every method acquires the internal read/write mutex. The ctx parameters are
// accepted but not used by the current implementation.
type TripleKeyCache[K1 comparable, K2 comparable, K3 comparable, V any] struct {
	mu             sync.RWMutex
	primaryIndex   map[K1]V               // primary key -> value
	secondaryIndex map[K2]map[K1]struct{} // secondary key -> set of primary keys
	tertiaryIndex  map[K3]map[K1]struct{} // tertiary key -> set of primary keys
	reverseLookup  map[K1]keyPair[K2, K3] // primary key -> the two grouping keys
}

// keyPair records the secondary and tertiary key an entry is grouped under.
type keyPair[K2 comparable, K3 comparable] struct {
	secondary K2
	tertiary  K3
}

// NewTripleKeyCache returns an empty, ready-to-use TripleKeyCache.
func NewTripleKeyCache[K1 comparable, K2 comparable, K3 comparable, V any]() *TripleKeyCache[K1, K2, K3, V] {
	c := new(TripleKeyCache[K1, K2, K3, V])
	c.resetLocked()
	return c
}

// resetLocked replaces all indexes with fresh empty maps. The caller must
// hold the write lock or otherwise have exclusive access to c.
func (c *TripleKeyCache[K1, K2, K3, V]) resetLocked() {
	c.primaryIndex = make(map[K1]V)
	c.secondaryIndex = make(map[K2]map[K1]struct{})
	c.tertiaryIndex = make(map[K3]map[K1]struct{})
	c.reverseLookup = make(map[K1]keyPair[K2, K3])
}

// dropFromSecondaryLocked removes primaryKey from the set registered under
// group and deletes the set once it is empty. The caller must hold the write lock.
func (c *TripleKeyCache[K1, K2, K3, V]) dropFromSecondaryLocked(group K2, primaryKey K1) {
	members, ok := c.secondaryIndex[group]
	if !ok {
		return
	}
	delete(members, primaryKey)
	if len(members) == 0 {
		delete(c.secondaryIndex, group)
	}
}

// dropFromTertiaryLocked removes primaryKey from the set registered under
// group and deletes the set once it is empty. The caller must hold the write lock.
func (c *TripleKeyCache[K1, K2, K3, V]) dropFromTertiaryLocked(group K3, primaryKey K1) {
	members, ok := c.tertiaryIndex[group]
	if !ok {
		return
	}
	delete(members, primaryKey)
	if len(members) == 0 {
		delete(c.tertiaryIndex, group)
	}
}

// anyValueLocked returns the value of the first member (in map iteration
// order) that is present in primaryIndex. The caller must hold at least the
// read lock.
func (c *TripleKeyCache[K1, K2, K3, V]) anyValueLocked(members map[K1]struct{}) (V, bool) {
	for member := range members {
		if cached, ok := c.primaryIndex[member]; ok {
			return cached, true
		}
	}
	var none V
	return none, false
}

// Get returns the value stored under primaryKey and whether it was present.
func (c *TripleKeyCache[K1, K2, K3, V]) Get(ctx context.Context, primaryKey K1) (V, bool) {
	c.mu.RLock()
	defer c.mu.RUnlock()

	cached, found := c.primaryIndex[primaryKey]
	return cached, found
}

// Set stores value under primaryKey and groups it under secondaryKey and
// tertiaryKey. If the primary key was already present, it is first detached
// from the grouping keys it was previously stored with.
func (c *TripleKeyCache[K1, K2, K3, V]) Set(ctx context.Context, primaryKey K1, secondaryKey K2, tertiaryKey K3, value V) {
	c.mu.Lock()
	defer c.mu.Unlock()

	if previous, known := c.reverseLookup[primaryKey]; known {
		c.dropFromSecondaryLocked(previous.secondary, primaryKey)
		c.dropFromTertiaryLocked(previous.tertiary, primaryKey)
	}

	c.primaryIndex[primaryKey] = value
	c.reverseLookup[primaryKey] = keyPair[K2, K3]{secondary: secondaryKey, tertiary: tertiaryKey}

	bySecondary := c.secondaryIndex[secondaryKey]
	if bySecondary == nil {
		bySecondary = make(map[K1]struct{})
		c.secondaryIndex[secondaryKey] = bySecondary
	}
	bySecondary[primaryKey] = struct{}{}

	byTertiary := c.tertiaryIndex[tertiaryKey]
	if byTertiary == nil {
		byTertiary = make(map[K1]struct{})
		c.tertiaryIndex[tertiaryKey] = byTertiary
	}
	byTertiary[primaryKey] = struct{}{}
}

// InvalidateByPrimaryKey removes the entry stored under primaryKey, if any.
func (c *TripleKeyCache[K1, K2, K3, V]) InvalidateByPrimaryKey(ctx context.Context, primaryKey K1) {
	c.mu.Lock()
	defer c.mu.Unlock()

	if groups, known := c.reverseLookup[primaryKey]; known {
		c.dropFromSecondaryLocked(groups.secondary, primaryKey)
		c.dropFromTertiaryLocked(groups.tertiary, primaryKey)
		delete(c.reverseLookup, primaryKey)
	}
	delete(c.primaryIndex, primaryKey)
}

// InvalidateBySecondaryKey removes every entry grouped under secondaryKey and
// detaches those entries from their tertiary keys.
func (c *TripleKeyCache[K1, K2, K3, V]) InvalidateBySecondaryKey(ctx context.Context, secondaryKey K2) {
	c.mu.Lock()
	defer c.mu.Unlock()

	members, ok := c.secondaryIndex[secondaryKey]
	if !ok {
		return
	}

	for member := range members {
		if groups, known := c.reverseLookup[member]; known {
			c.dropFromTertiaryLocked(groups.tertiary, member)
		}
		delete(c.primaryIndex, member)
		delete(c.reverseLookup, member)
	}
	delete(c.secondaryIndex, secondaryKey)
}

// InvalidateByTertiaryKey removes every entry grouped under tertiaryKey and
// detaches those entries from their secondary keys.
func (c *TripleKeyCache[K1, K2, K3, V]) InvalidateByTertiaryKey(ctx context.Context, tertiaryKey K3) {
	c.mu.Lock()
	defer c.mu.Unlock()

	members, ok := c.tertiaryIndex[tertiaryKey]
	if !ok {
		return
	}

	for member := range members {
		if groups, known := c.reverseLookup[member]; known {
			c.dropFromSecondaryLocked(groups.secondary, member)
		}
		delete(c.primaryIndex, member)
		delete(c.reverseLookup, member)
	}
	delete(c.tertiaryIndex, tertiaryKey)
}

// InvalidateAll drops every entry.
func (c *TripleKeyCache[K1, K2, K3, V]) InvalidateAll(ctx context.Context) {
	c.mu.Lock()
	defer c.mu.Unlock()

	c.resetLocked()
}

// Size reports how many entries are currently cached.
func (c *TripleKeyCache[K1, K2, K3, V]) Size() int {
	c.mu.RLock()
	defer c.mu.RUnlock()

	return len(c.primaryIndex)
}

// GetOrSet returns the value cached under primaryKey. On a miss it calls
// loadFunc, which must return the value, the secondary key and the tertiary
// key, stores the result and returns it. If loadFunc fails, the zero value and
// the error are returned and nothing is cached.
//
// The lookup and the store are separate critical sections, so concurrent
// callers that miss on the same key may each invoke loadFunc.
func (c *TripleKeyCache[K1, K2, K3, V]) GetOrSet(ctx context.Context, primaryKey K1, loadFunc func() (V, K2, K3, error)) (V, error) {
	if cached, hit := c.Get(ctx, primaryKey); hit {
		return cached, nil
	}

	loaded, secondaryKey, tertiaryKey, err := loadFunc()
	if err != nil {
		var none V
		return none, err
	}

	c.Set(ctx, primaryKey, secondaryKey, tertiaryKey, loaded)
	return loaded, nil
}

// GetOrSetBySecondaryKey returns a value cached under secondaryKey. When
// several entries share that key, which one is returned follows map iteration
// order. On a miss it calls loadFunc, which must return the value, its primary
// key and its tertiary key, stores the result and returns it. If loadFunc
// fails, the zero value and the error are returned and nothing is cached.
func (c *TripleKeyCache[K1, K2, K3, V]) GetOrSetBySecondaryKey(ctx context.Context, secondaryKey K2, loadFunc func() (V, K1, K3, error)) (V, error) {
	c.mu.RLock()
	cached, hit := c.anyValueLocked(c.secondaryIndex[secondaryKey])
	c.mu.RUnlock()
	if hit {
		return cached, nil
	}

	loaded, primaryKey, tertiaryKey, err := loadFunc()
	if err != nil {
		var none V
		return none, err
	}

	c.Set(ctx, primaryKey, secondaryKey, tertiaryKey, loaded)
	return loaded, nil
}

// GetOrSetByTertiaryKey returns a value cached under tertiaryKey. When
// several entries share that key, which one is returned follows map iteration
// order. On a miss it calls loadFunc, which must return the value, its primary
// key and its secondary key, stores the result and returns it. If loadFunc
// fails, the zero value and the error are returned and nothing is cached.
func (c *TripleKeyCache[K1, K2, K3, V]) GetOrSetByTertiaryKey(ctx context.Context, tertiaryKey K3, loadFunc func() (V, K1, K2, error)) (V, error) {
	c.mu.RLock()
	cached, hit := c.anyValueLocked(c.tertiaryIndex[tertiaryKey])
	c.mu.RUnlock()
	if hit {
		return cached, nil
	}

	loaded, primaryKey, secondaryKey, err := loadFunc()
	if err != nil {
		var none V
		return none, err
	}

	c.Set(ctx, primaryKey, secondaryKey, tertiaryKey, loaded)
	return loaded, nil
}

File diff suppressed because it is too large Load Diff

View File

@@ -4,6 +4,7 @@ import (
"context"
"time"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/metric"
)
@@ -14,6 +15,8 @@ type StoreMetrics struct {
persistenceDurationMicro metric.Int64Histogram
persistenceDurationMs metric.Int64Histogram
transactionDurationMs metric.Int64Histogram
queryDurationMs metric.Int64Histogram
queryCounter metric.Int64Counter
ctx context.Context
}
@@ -59,12 +62,29 @@ func NewStoreMetrics(ctx context.Context, meter metric.Meter) (*StoreMetrics, er
return nil, err
}
queryDurationMs, err := meter.Int64Histogram("management.store.query.duration.ms",
metric.WithUnit("milliseconds"),
metric.WithDescription("Duration of database query operations with operation type and table name"),
)
if err != nil {
return nil, err
}
queryCounter, err := meter.Int64Counter("management.store.query.count",
metric.WithDescription("Count of database query operations with operation type, table name, and status"),
)
if err != nil {
return nil, err
}
return &StoreMetrics{
globalLockAcquisitionDurationMicro: globalLockAcquisitionDurationMicro,
globalLockAcquisitionDurationMs: globalLockAcquisitionDurationMs,
persistenceDurationMicro: persistenceDurationMicro,
persistenceDurationMs: persistenceDurationMs,
transactionDurationMs: transactionDurationMs,
queryDurationMs: queryDurationMs,
queryCounter: queryCounter,
ctx: ctx,
}, nil
}
@@ -85,3 +105,13 @@ func (metrics *StoreMetrics) CountPersistenceDuration(duration time.Duration) {
// CountTransactionDuration records the given duration, truncated to whole
// milliseconds, on the transaction-duration histogram.
func (metrics *StoreMetrics) CountTransactionDuration(duration time.Duration) {
	elapsedMs := duration.Milliseconds()
	metrics.transactionDurationMs.Record(metrics.ctx, elapsedMs)
}
// CountStoreOperation records a single store operation: the duration, in whole
// milliseconds, is added to the query-duration histogram and the query counter
// is incremented by one. Both measurements carry only the "method" attribute;
// no status attribute is recorded.
func (metrics *StoreMetrics) CountStoreOperation(method string, duration time.Duration) {
	// Build the attribute option once and share it between both instruments
	// instead of evaluating metric.WithAttributes for each call.
	opt := metric.WithAttributes(attribute.String("method", method))

	metrics.queryDurationMs.Record(metrics.ctx, duration.Milliseconds(), opt)
	metrics.queryCounter.Add(metrics.ctx, 1, opt)
}

View File

@@ -7,25 +7,23 @@ import (
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/shared/management/proto"
"github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/management/server/types"
)
type UpdateMessage struct {
Update *proto.SyncResponse
}
const channelBufferSize = 100
type peerUpdate struct {
mu sync.Mutex
message *UpdateMessage
notify chan struct{}
type UpdateMessage struct {
Update *proto.SyncResponse
NetworkMap *types.NetworkMap
}
type PeersUpdateManager struct {
// latestUpdates stores the latest update message per peer
latestUpdates sync.Map // map[string]*peerUpdate
// activePeers tracks which peers have active sender goroutines
activePeers sync.Map // map[string]struct{}
// peerChannels is an update channel indexed by Peer.ID
peerChannels map[string]chan *UpdateMessage
// channelsMux keeps the mutex to access peerChannels
channelsMux *sync.RWMutex
// metrics provides methods to collect application metrics
metrics telemetry.AppMetrics
}
@@ -33,137 +31,87 @@ type PeersUpdateManager struct {
// NewPeersUpdateManager returns a new instance of PeersUpdateManager
func NewPeersUpdateManager(metrics telemetry.AppMetrics) *PeersUpdateManager {
return &PeersUpdateManager{
metrics: metrics,
peerChannels: make(map[string]chan *UpdateMessage),
channelsMux: &sync.RWMutex{},
metrics: metrics,
}
}
// SendUpdate stores the latest update message for a peer and notifies the sender goroutine
// SendUpdate sends update message to the peer's channel
func (p *PeersUpdateManager) SendUpdate(ctx context.Context, peerID string, update *UpdateMessage) {
start := time.Now()
var found, dropped bool
p.channelsMux.RLock()
defer func() {
p.channelsMux.RUnlock()
if p.metrics != nil {
p.metrics.UpdateChannelMetrics().CountSendUpdateDuration(time.Since(start), found, dropped)
}
}()
// Check if peer has an active sender goroutine
if _, ok := p.activePeers.Load(peerID); !ok {
log.WithContext(ctx).Debugf("peer %s has no active sender", peerID)
return
}
found = true
// Load or create peerUpdate entry
val, _ := p.latestUpdates.LoadOrStore(peerID, &peerUpdate{
notify: make(chan struct{}, 1),
})
pu := val.(*peerUpdate)
// Store the latest message (overwrites any previous unsent message)
pu.mu.Lock()
pu.message = update
pu.mu.Unlock()
// Non-blocking notification
select {
case pu.notify <- struct{}{}:
log.WithContext(ctx).Debugf("update notification sent for peer %s", peerID)
default:
// Already notified, sender will pick up the latest message anyway
log.WithContext(ctx).Tracef("peer %s already notified, update will be picked up", peerID)
if channel, ok := p.peerChannels[peerID]; ok {
found = true
select {
case channel <- update:
log.WithContext(ctx).Debugf("update was sent to channel for peer %s", peerID)
default:
dropped = true
log.WithContext(ctx).Warnf("channel for peer %s is %d full or closed", peerID, len(channel))
}
} else {
log.WithContext(ctx).Debugf("peer %s has no channel", peerID)
}
}
// CreateChannel creates a sender goroutine for a given peer and returns a channel to receive updates
// CreateChannel creates a go channel for a given peer used to deliver updates relevant to the peer.
func (p *PeersUpdateManager) CreateChannel(ctx context.Context, peerID string) chan *UpdateMessage {
start := time.Now()
closed := false
p.channelsMux.Lock()
defer func() {
p.channelsMux.Unlock()
if p.metrics != nil {
p.metrics.UpdateChannelMetrics().CountCreateChannelDuration(time.Since(start), closed)
}
}()
// Close existing sender if any
if _, exists := p.activePeers.LoadOrStore(peerID, struct{}{}); exists {
if channel, ok := p.peerChannels[peerID]; ok {
closed = true
p.closeChannel(ctx, peerID)
delete(p.peerChannels, peerID)
close(channel)
}
// mbragin: todo shouldn't it be more? or configurable?
channel := make(chan *UpdateMessage, channelBufferSize)
p.peerChannels[peerID] = channel
// Create peerUpdate entry with notification channel
pu := &peerUpdate{
notify: make(chan struct{}, 1),
}
p.latestUpdates.Store(peerID, pu)
log.WithContext(ctx).Debugf("opened updates channel for a peer %s", peerID)
// Create output channel for consumer
outChan := make(chan *UpdateMessage, 1)
// Start sender goroutine
go func() {
defer close(outChan)
for {
select {
case <-ctx.Done():
log.WithContext(ctx).Debugf("sender goroutine for peer %s stopped due to context cancellation", peerID)
return
case <-pu.notify:
// Check if still active
if _, ok := p.activePeers.Load(peerID); !ok {
log.WithContext(ctx).Debugf("sender goroutine for peer %s stopped", peerID)
return
}
// Get the latest message with mutex protection
pu.mu.Lock()
msg := pu.message
pu.message = nil // Clear after reading
pu.mu.Unlock()
if msg != nil {
select {
case outChan <- msg:
log.WithContext(ctx).Tracef("sent update to peer %s", peerID)
case <-ctx.Done():
return
}
}
}
}
}()
log.WithContext(ctx).Debugf("created sender goroutine for peer %s", peerID)
return outChan
return channel
}
func (p *PeersUpdateManager) closeChannel(ctx context.Context, peerID string) {
// Mark peer as inactive to stop the sender goroutine
if _, ok := p.activePeers.LoadAndDelete(peerID); ok {
// Close notification channel
if val, ok := p.latestUpdates.Load(peerID); ok {
pu := val.(*peerUpdate)
close(pu.notify)
}
p.latestUpdates.Delete(peerID)
log.WithContext(ctx).Debugf("closed sender for peer %s", peerID)
if channel, ok := p.peerChannels[peerID]; ok {
delete(p.peerChannels, peerID)
close(channel)
log.WithContext(ctx).Debugf("closed updates channel of a peer %s", peerID)
return
}
log.WithContext(ctx).Debugf("closing sender: peer %s has no active sender", peerID)
log.WithContext(ctx).Debugf("closing updates channel: peer %s has no channel", peerID)
}
// CloseChannels closes sender goroutines for each given peer
// CloseChannels closes updates channel for each given peer
func (p *PeersUpdateManager) CloseChannels(ctx context.Context, peerIDs []string) {
start := time.Now()
p.channelsMux.Lock()
defer func() {
p.channelsMux.Unlock()
if p.metrics != nil {
p.metrics.UpdateChannelMetrics().CountCloseChannelsDuration(time.Since(start), len(peerIDs))
}
@@ -174,11 +122,13 @@ func (p *PeersUpdateManager) CloseChannels(ctx context.Context, peerIDs []string
}
}
// CloseChannel closes the sender goroutine of a given peer
// CloseChannel closes updates channel of a given peer
func (p *PeersUpdateManager) CloseChannel(ctx context.Context, peerID string) {
start := time.Now()
p.channelsMux.Lock()
defer func() {
p.channelsMux.Unlock()
if p.metrics != nil {
p.metrics.UpdateChannelMetrics().CountCloseChannelDuration(time.Since(start))
}
@@ -191,43 +141,38 @@ func (p *PeersUpdateManager) CloseChannel(ctx context.Context, peerID string) {
func (p *PeersUpdateManager) GetAllConnectedPeers() map[string]struct{} {
start := time.Now()
p.channelsMux.RLock()
m := make(map[string]struct{})
defer func() {
p.channelsMux.RUnlock()
if p.metrics != nil {
p.metrics.UpdateChannelMetrics().CountGetAllConnectedPeersDuration(time.Since(start), len(m))
}
}()
p.activePeers.Range(func(key, value interface{}) bool {
m[key.(string)] = struct{}{}
return true
})
for ID := range p.peerChannels {
m[ID] = struct{}{}
}
return m
}
// HasChannel returns true if peer has an active sender goroutine, otherwise false
// HasChannel returns true if the peer has a channel in the update manager, otherwise false
func (p *PeersUpdateManager) HasChannel(peerID string) bool {
start := time.Now()
p.channelsMux.RLock()
defer func() {
p.channelsMux.RUnlock()
if p.metrics != nil {
p.metrics.UpdateChannelMetrics().CountHasChannelDuration(time.Since(start))
}
}()
_, ok := p.activePeers.Load(peerID)
_, ok := p.peerChannels[peerID]
return ok
}
// GetChannelCount walks the activePeers map and reports how many entries it holds.
func (p *PeersUpdateManager) GetChannelCount() int {
	var active int
	// Returning true from the visitor keeps Range iterating over every entry.
	tally := func(_, _ interface{}) bool {
		active++
		return true
	}
	p.activePeers.Range(tally)
	return active
}

View File

@@ -1,9 +1,13 @@
package version
import "golang.org/x/sys/windows/registry"
import (
"golang.org/x/sys/windows/registry"
"runtime"
)
const (
urlWinExe = "https://pkgs.netbird.io/windows/x64"
urlWinExeArm = "https://pkgs.netbird.io/windows/arm64"
)
var regKeyAppPath = "SOFTWARE\\Microsoft\\Windows\\CurrentVersion\\App Paths\\Netbird"
@@ -11,9 +15,14 @@ var regKeyAppPath = "SOFTWARE\\Microsoft\\Windows\\CurrentVersion\\App Paths\\Ne
// DownloadUrl returns the proper download link
func DownloadUrl() string {
_, err := registry.OpenKey(registry.LOCAL_MACHINE, regKeyAppPath, registry.QUERY_VALUE)
if err == nil {
return urlWinExe
} else {
if err != nil {
return downloadURL
}
url := urlWinExe
if runtime.GOARCH == "arm64" {
url = urlWinExeArm
}
return url
}