Compare commits

..

1 Commits

Author SHA1 Message Date
Pascal Fischer
e67b7ca110 basic store cache 2025-10-06 16:16:32 +02:00
47 changed files with 736 additions and 2449 deletions

View File

@@ -29,8 +29,7 @@ func Backoff(ctx context.Context) backoff.BackOff {
// The component parameter specifies the WebSocket proxy component path (e.g., "/management", "/signal").
func CreateConnection(ctx context.Context, addr string, tlsEnabled bool, component string) (*grpc.ClientConn, error) {
transportOption := grpc.WithTransportCredentials(insecure.NewCredentials())
// for js, the outer websocket layer takes care of tls
if tlsEnabled && runtime.GOOS != "js" {
if tlsEnabled {
certPool, err := x509.SystemCertPool()
if err != nil || certPool == nil {
log.Debugf("System cert pool not available; falling back to embedded cert, error: %v", err)
@@ -38,7 +37,9 @@ func CreateConnection(ctx context.Context, addr string, tlsEnabled bool, compone
}
transportOption = grpc.WithTransportCredentials(credentials.NewTLS(&tls.Config{
RootCAs: certPool,
// for js, outer websocket layer takes care of tls verification via WithCustomDialer
InsecureSkipVerify: runtime.GOOS == "js",
RootCAs: certPool,
}))
}

View File

@@ -73,44 +73,6 @@ func (c *KernelConfigurer) UpdatePeer(peerKey string, allowedIps []netip.Prefix,
return nil
}
func (c *KernelConfigurer) RemoveEndpointAddress(peerKey string) error {
peerKeyParsed, err := wgtypes.ParseKey(peerKey)
if err != nil {
return err
}
// Get the existing peer to preserve its allowed IPs
existingPeer, err := c.getPeer(c.deviceName, peerKey)
if err != nil {
return fmt.Errorf("get peer: %w", err)
}
removePeerCfg := wgtypes.PeerConfig{
PublicKey: peerKeyParsed,
Remove: true,
}
if err := c.configure(wgtypes.Config{Peers: []wgtypes.PeerConfig{removePeerCfg}}); err != nil {
return fmt.Errorf(`error removing peer %s from interface %s: %w`, peerKey, c.deviceName, err)
}
//Re-add the peer without the endpoint but same AllowedIPs
reAddPeerCfg := wgtypes.PeerConfig{
PublicKey: peerKeyParsed,
AllowedIPs: existingPeer.AllowedIPs,
ReplaceAllowedIPs: true,
}
if err := c.configure(wgtypes.Config{Peers: []wgtypes.PeerConfig{reAddPeerCfg}}); err != nil {
return fmt.Errorf(
`error re-adding peer %s to interface %s with allowed IPs %v: %w`,
peerKey, c.deviceName, existingPeer.AllowedIPs, err,
)
}
return nil
}
func (c *KernelConfigurer) RemovePeer(peerKey string) error {
peerKeyParsed, err := wgtypes.ParseKey(peerKey)
if err != nil {

View File

@@ -106,67 +106,6 @@ func (c *WGUSPConfigurer) UpdatePeer(peerKey string, allowedIps []netip.Prefix,
return nil
}
func (c *WGUSPConfigurer) RemoveEndpointAddress(peerKey string) error {
peerKeyParsed, err := wgtypes.ParseKey(peerKey)
if err != nil {
return fmt.Errorf("parse peer key: %w", err)
}
ipcStr, err := c.device.IpcGet()
if err != nil {
return fmt.Errorf("get IPC config: %w", err)
}
// Parse current status to get allowed IPs for the peer
stats, err := parseStatus(c.deviceName, ipcStr)
if err != nil {
return fmt.Errorf("parse IPC config: %w", err)
}
var allowedIPs []net.IPNet
found := false
for _, peer := range stats.Peers {
if peer.PublicKey == peerKey {
allowedIPs = peer.AllowedIPs
found = true
break
}
}
if !found {
return fmt.Errorf("peer %s not found", peerKey)
}
// remove the peer from the WireGuard configuration
peer := wgtypes.PeerConfig{
PublicKey: peerKeyParsed,
Remove: true,
}
config := wgtypes.Config{
Peers: []wgtypes.PeerConfig{peer},
}
if ipcErr := c.device.IpcSet(toWgUserspaceString(config)); ipcErr != nil {
return fmt.Errorf("failed to remove peer: %s", ipcErr)
}
// Build the peer config
peer = wgtypes.PeerConfig{
PublicKey: peerKeyParsed,
ReplaceAllowedIPs: true,
AllowedIPs: allowedIPs,
}
config = wgtypes.Config{
Peers: []wgtypes.PeerConfig{peer},
}
if err := c.device.IpcSet(toWgUserspaceString(config)); err != nil {
return fmt.Errorf("remove endpoint address: %w", err)
}
return nil
}
func (c *WGUSPConfigurer) RemovePeer(peerKey string) error {
peerKeyParsed, err := wgtypes.ParseKey(peerKey)
if err != nil {

View File

@@ -21,5 +21,4 @@ type WGConfigurer interface {
GetStats() (map[string]configurer.WGStats, error)
FullStats() (*configurer.Stats, error)
LastActivities() map[string]monotime.Time
RemoveEndpointAddress(peerKey string) error
}

View File

@@ -148,17 +148,6 @@ func (w *WGIface) UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAliv
return w.configurer.UpdatePeer(peerKey, allowedIps, keepAlive, endpoint, preSharedKey)
}
func (w *WGIface) RemoveEndpointAddress(peerKey string) error {
w.mu.Lock()
defer w.mu.Unlock()
if w.configurer == nil {
return ErrIfaceNotFound
}
log.Debugf("Removing endpoint address: %s", peerKey)
return w.configurer.RemoveEndpointAddress(peerKey)
}
// RemovePeer removes a Wireguard Peer from the interface iface
func (w *WGIface) RemovePeer(peerKey string) error {
w.mu.Lock()

View File

@@ -31,7 +31,6 @@ const (
systemdDbusSetDefaultRouteMethodSuffix = systemdDbusLinkInterface + ".SetDefaultRoute"
systemdDbusSetDomainsMethodSuffix = systemdDbusLinkInterface + ".SetDomains"
systemdDbusSetDNSSECMethodSuffix = systemdDbusLinkInterface + ".SetDNSSEC"
systemdDbusSetDNSOverTLSMethodSuffix = systemdDbusLinkInterface + ".SetDNSOverTLS"
systemdDbusResolvConfModeForeign = "foreign"
dbusErrorUnknownObject = "org.freedesktop.DBus.Error.UnknownObject"
@@ -103,11 +102,6 @@ func (s *systemdDbusConfigurator) applyDNSConfig(config HostDNSConfig, stateMana
log.Warnf("failed to set DNSSEC to 'no': %v", err)
}
// We don't support DNSOverTLS. On some machines this is default on so we explicitly set it to off
if err := s.callLinkMethod(systemdDbusSetDNSOverTLSMethodSuffix, dnsSecDisabled); err != nil {
log.Warnf("failed to set DNSOverTLS to 'no': %v", err)
}
var (
searchDomains []string
matchDomains []string

View File

@@ -1,78 +0,0 @@
package dnsfwd
import (
"net/netip"
"slices"
"strings"
"sync"
"github.com/miekg/dns"
)
type cache struct {
mu sync.RWMutex
records map[string]*cacheEntry
}
type cacheEntry struct {
ip4Addrs []netip.Addr
ip6Addrs []netip.Addr
}
func newCache() *cache {
return &cache{
records: make(map[string]*cacheEntry),
}
}
func (c *cache) get(domain string, reqType uint16) ([]netip.Addr, bool) {
c.mu.RLock()
defer c.mu.RUnlock()
entry, exists := c.records[normalizeDomain(domain)]
if !exists {
return nil, false
}
switch reqType {
case dns.TypeA:
return slices.Clone(entry.ip4Addrs), true
case dns.TypeAAAA:
return slices.Clone(entry.ip6Addrs), true
default:
return nil, false
}
}
func (c *cache) set(domain string, reqType uint16, addrs []netip.Addr) {
c.mu.Lock()
defer c.mu.Unlock()
norm := normalizeDomain(domain)
entry, exists := c.records[norm]
if !exists {
entry = &cacheEntry{}
c.records[norm] = entry
}
switch reqType {
case dns.TypeA:
entry.ip4Addrs = slices.Clone(addrs)
case dns.TypeAAAA:
entry.ip6Addrs = slices.Clone(addrs)
}
}
// unset removes cached entries for the given domain and request type.
func (c *cache) unset(domain string) {
c.mu.Lock()
defer c.mu.Unlock()
delete(c.records, normalizeDomain(domain))
}
// normalizeDomain converts an input domain into a canonical form used as cache key:
// lowercase and fully-qualified (with trailing dot).
func normalizeDomain(domain string) string {
// dns.Fqdn ensures trailing dot; ToLower for consistent casing
return dns.Fqdn(strings.ToLower(domain))
}

View File

@@ -1,86 +0,0 @@
package dnsfwd
import (
"net/netip"
"testing"
)
func mustAddr(t *testing.T, s string) netip.Addr {
t.Helper()
a, err := netip.ParseAddr(s)
if err != nil {
t.Fatalf("parse addr %s: %v", s, err)
}
return a
}
func TestCacheNormalization(t *testing.T) {
c := newCache()
// Mixed case, without trailing dot
domainInput := "ExAmPlE.CoM"
ipv4 := []netip.Addr{mustAddr(t, "1.2.3.4")}
c.set(domainInput, 1 /* dns.TypeA */, ipv4)
// Lookup with lower, with trailing dot
if got, ok := c.get("example.com.", 1); !ok || len(got) != 1 || got[0].String() != "1.2.3.4" {
t.Fatalf("expected cached IPv4 result via normalized key, got=%v ok=%v", got, ok)
}
// Lookup with different casing again
if got, ok := c.get("EXAMPLE.COM", 1); !ok || len(got) != 1 || got[0].String() != "1.2.3.4" {
t.Fatalf("expected cached IPv4 result via different casing, got=%v ok=%v", got, ok)
}
}
func TestCacheSeparateTypes(t *testing.T) {
c := newCache()
domain := "test.local"
ipv4 := []netip.Addr{mustAddr(t, "10.0.0.1")}
ipv6 := []netip.Addr{mustAddr(t, "2001:db8::1")}
c.set(domain, 1 /* A */, ipv4)
c.set(domain, 28 /* AAAA */, ipv6)
got4, ok4 := c.get(domain, 1)
if !ok4 || len(got4) != 1 || got4[0] != ipv4[0] {
t.Fatalf("expected A record from cache, got=%v ok=%v", got4, ok4)
}
got6, ok6 := c.get(domain, 28)
if !ok6 || len(got6) != 1 || got6[0] != ipv6[0] {
t.Fatalf("expected AAAA record from cache, got=%v ok=%v", got6, ok6)
}
}
func TestCacheCloneOnGetAndSet(t *testing.T) {
c := newCache()
domain := "clone.test"
src := []netip.Addr{mustAddr(t, "8.8.8.8")}
c.set(domain, 1, src)
// Mutate source slice; cache should be unaffected
src[0] = mustAddr(t, "9.9.9.9")
got, ok := c.get(domain, 1)
if !ok || len(got) != 1 || got[0].String() != "8.8.8.8" {
t.Fatalf("expected cached value to be independent of source slice, got=%v ok=%v", got, ok)
}
// Mutate returned slice; internal cache should remain unchanged
got[0] = mustAddr(t, "4.4.4.4")
got2, ok2 := c.get(domain, 1)
if !ok2 || len(got2) != 1 || got2[0].String() != "8.8.8.8" {
t.Fatalf("expected returned slice to be a clone, got=%v ok=%v", got2, ok2)
}
}
func TestCacheMiss(t *testing.T) {
c := newCache()
if got, ok := c.get("missing.example", 1); ok || got != nil {
t.Fatalf("expected cache miss, got=%v ok=%v", got, ok)
}
}

View File

@@ -46,7 +46,6 @@ type DNSForwarder struct {
fwdEntries []*ForwarderEntry
firewall firewaller
resolver resolver
cache *cache
}
func NewDNSForwarder(listenAddress string, ttl uint32, firewall firewaller, statusRecorder *peer.Status) *DNSForwarder {
@@ -57,7 +56,6 @@ func NewDNSForwarder(listenAddress string, ttl uint32, firewall firewaller, stat
firewall: firewall,
statusRecorder: statusRecorder,
resolver: net.DefaultResolver,
cache: newCache(),
}
}
@@ -105,39 +103,10 @@ func (f *DNSForwarder) UpdateDomains(entries []*ForwarderEntry) {
f.mutex.Lock()
defer f.mutex.Unlock()
// remove cache entries for domains that no longer appear
f.removeStaleCacheEntries(f.fwdEntries, entries)
f.fwdEntries = entries
log.Debugf("Updated DNS forwarder with %d domains", len(entries))
}
// removeStaleCacheEntries unsets cache items for domains that were present
// in the old list but not present in the new list.
func (f *DNSForwarder) removeStaleCacheEntries(oldEntries, newEntries []*ForwarderEntry) {
if f.cache == nil {
return
}
newSet := make(map[string]struct{}, len(newEntries))
for _, e := range newEntries {
if e == nil {
continue
}
newSet[e.Domain.PunycodeString()] = struct{}{}
}
for _, e := range oldEntries {
if e == nil {
continue
}
pattern := e.Domain.PunycodeString()
if _, ok := newSet[pattern]; !ok {
f.cache.unset(pattern)
}
}
}
func (f *DNSForwarder) Close(ctx context.Context) error {
var result *multierror.Error
@@ -202,7 +171,6 @@ func (f *DNSForwarder) handleDNSQuery(w dns.ResponseWriter, query *dns.Msg) *dns
f.updateInternalState(ips, mostSpecificResId, matchingEntries)
f.addIPsToResponse(resp, domain, ips)
f.cache.set(domain, question.Qtype, ips)
return resp
}
@@ -314,69 +282,29 @@ func (f *DNSForwarder) setResponseCodeForNotFound(ctx context.Context, resp *dns
resp.Rcode = dns.RcodeSuccess
}
// handleDNSError processes DNS lookup errors and sends an appropriate error response.
func (f *DNSForwarder) handleDNSError(
ctx context.Context,
w dns.ResponseWriter,
question dns.Question,
resp *dns.Msg,
domain string,
err error,
) {
// Default to SERVFAIL; override below when appropriate.
resp.Rcode = dns.RcodeServerFailure
qType := question.Qtype
qTypeName := dns.TypeToString[qType]
// Prefer typed DNS errors; fall back to generic logging otherwise.
// handleDNSError processes DNS lookup errors and sends an appropriate error response
func (f *DNSForwarder) handleDNSError(ctx context.Context, w dns.ResponseWriter, question dns.Question, resp *dns.Msg, domain string, err error) {
var dnsErr *net.DNSError
if !errors.As(err, &dnsErr) {
log.Warnf(errResolveFailed, domain, err)
if writeErr := w.WriteMsg(resp); writeErr != nil {
log.Errorf("failed to write failure DNS response: %v", writeErr)
}
return
}
// NotFound: set NXDOMAIN / appropriate code via helper.
if dnsErr.IsNotFound {
f.setResponseCodeForNotFound(ctx, resp, domain, qType)
if writeErr := w.WriteMsg(resp); writeErr != nil {
log.Errorf("failed to write failure DNS response: %v", writeErr)
switch {
case errors.As(err, &dnsErr):
resp.Rcode = dns.RcodeServerFailure
if dnsErr.IsNotFound {
f.setResponseCodeForNotFound(ctx, resp, domain, question.Qtype)
}
f.cache.set(domain, question.Qtype, nil)
return
}
// Upstream failed but we might have a cached answer—serve it if present.
if ips, ok := f.cache.get(domain, qType); ok {
if len(ips) > 0 {
log.Debugf("serving cached DNS response after upstream failure: domain=%s type=%s", domain, qTypeName)
f.addIPsToResponse(resp, domain, ips)
resp.Rcode = dns.RcodeSuccess
if writeErr := w.WriteMsg(resp); writeErr != nil {
log.Errorf("failed to write cached DNS response: %v", writeErr)
}
} else { // send NXDOMAIN / appropriate code if cache is empty
f.setResponseCodeForNotFound(ctx, resp, domain, qType)
if writeErr := w.WriteMsg(resp); writeErr != nil {
log.Errorf("failed to write failure DNS response: %v", writeErr)
}
if dnsErr.Server != "" {
log.Warnf("failed to resolve query for type=%s domain=%s server=%s: %v", dns.TypeToString[question.Qtype], domain, dnsErr.Server, err)
} else {
log.Warnf(errResolveFailed, domain, err)
}
return
}
// No cache. Log with or without the server field for more context.
if dnsErr.Server != "" {
log.Warnf("failed to resolve: type=%s domain=%s server=%s: %v", qTypeName, domain, dnsErr.Server, err)
} else {
default:
resp.Rcode = dns.RcodeServerFailure
log.Warnf(errResolveFailed, domain, err)
}
// Write final failure response.
if writeErr := w.WriteMsg(resp); writeErr != nil {
log.Errorf("failed to write failure DNS response: %v", writeErr)
if err := w.WriteMsg(resp); err != nil {
log.Errorf("failed to write failure DNS response: %v", err)
}
}

View File

@@ -648,95 +648,6 @@ func TestDNSForwarder_TCPTruncation(t *testing.T) {
assert.LessOrEqual(t, writtenResp.Len(), dns.MinMsgSize, "Response should fit in minimum UDP size")
}
// Ensures that when the first query succeeds and populates the cache,
// a subsequent upstream failure still returns a successful response from cache.
func TestDNSForwarder_ServeFromCacheOnUpstreamFailure(t *testing.T) {
mockResolver := &MockResolver{}
forwarder := NewDNSForwarder("127.0.0.1:0", 300, nil, &peer.Status{})
forwarder.resolver = mockResolver
d, err := domain.FromString("example.com")
require.NoError(t, err)
entries := []*ForwarderEntry{{Domain: d, ResID: "res-cache"}}
forwarder.UpdateDomains(entries)
ip := netip.MustParseAddr("1.2.3.4")
// First call resolves successfully and populates cache
mockResolver.On("LookupNetIP", mock.Anything, "ip4", dns.Fqdn("example.com")).
Return([]netip.Addr{ip}, nil).Once()
// Second call fails upstream; forwarder should serve from cache
mockResolver.On("LookupNetIP", mock.Anything, "ip4", dns.Fqdn("example.com")).
Return([]netip.Addr{}, &net.DNSError{Err: "temporary failure"}).Once()
// First query: populate cache
q1 := &dns.Msg{}
q1.SetQuestion(dns.Fqdn("example.com"), dns.TypeA)
w1 := &test.MockResponseWriter{}
resp1 := forwarder.handleDNSQuery(w1, q1)
require.NotNil(t, resp1)
require.Equal(t, dns.RcodeSuccess, resp1.Rcode)
require.Len(t, resp1.Answer, 1)
// Second query: serve from cache after upstream failure
q2 := &dns.Msg{}
q2.SetQuestion(dns.Fqdn("example.com"), dns.TypeA)
var writtenResp *dns.Msg
w2 := &test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { writtenResp = m; return nil }}
_ = forwarder.handleDNSQuery(w2, q2)
require.NotNil(t, writtenResp, "expected response to be written")
require.Equal(t, dns.RcodeSuccess, writtenResp.Rcode)
require.Len(t, writtenResp.Answer, 1)
mockResolver.AssertExpectations(t)
}
// Verifies that cache normalization works across casing and trailing dot variations.
func TestDNSForwarder_CacheNormalizationCasingAndDot(t *testing.T) {
mockResolver := &MockResolver{}
forwarder := NewDNSForwarder("127.0.0.1:0", 300, nil, &peer.Status{})
forwarder.resolver = mockResolver
d, err := domain.FromString("ExAmPlE.CoM")
require.NoError(t, err)
entries := []*ForwarderEntry{{Domain: d, ResID: "res-norm"}}
forwarder.UpdateDomains(entries)
ip := netip.MustParseAddr("9.8.7.6")
// Initial resolution with mixed case to populate cache
mixedQuery := "ExAmPlE.CoM"
mockResolver.On("LookupNetIP", mock.Anything, "ip4", dns.Fqdn(strings.ToLower(mixedQuery))).
Return([]netip.Addr{ip}, nil).Once()
q1 := &dns.Msg{}
q1.SetQuestion(mixedQuery+".", dns.TypeA)
w1 := &test.MockResponseWriter{}
resp1 := forwarder.handleDNSQuery(w1, q1)
require.NotNil(t, resp1)
require.Equal(t, dns.RcodeSuccess, resp1.Rcode)
require.Len(t, resp1.Answer, 1)
// Subsequent query without dot and upper case should hit cache even if upstream fails
// Forwarder lowercases and uses the question name as-is (no trailing dot here)
mockResolver.On("LookupNetIP", mock.Anything, "ip4", strings.ToLower("EXAMPLE.COM")).
Return([]netip.Addr{}, &net.DNSError{Err: "temporary failure"}).Once()
q2 := &dns.Msg{}
q2.SetQuestion("EXAMPLE.COM", dns.TypeA)
var writtenResp *dns.Msg
w2 := &test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { writtenResp = m; return nil }}
_ = forwarder.handleDNSQuery(w2, q2)
require.NotNil(t, writtenResp)
require.Equal(t, dns.RcodeSuccess, writtenResp.Rcode)
require.Len(t, writtenResp.Answer, 1)
mockResolver.AssertExpectations(t)
}
func TestDNSForwarder_MultipleOverlappingPatterns(t *testing.T) {
// Test complex overlapping pattern scenarios
mockFirewall := &MockFirewall{}

View File

@@ -40,6 +40,7 @@ type Manager struct {
fwRules []firewall.Rule
tcpRules []firewall.Rule
dnsForwarder *DNSForwarder
port uint16
}
func ListenPort() uint16 {
@@ -48,16 +49,11 @@ func ListenPort() uint16 {
return listenPort
}
func SetListenPort(port uint16) {
listenPortMu.Lock()
listenPort = port
listenPortMu.Unlock()
}
func NewManager(fw firewall.Manager, statusRecorder *peer.Status) *Manager {
func NewManager(fw firewall.Manager, statusRecorder *peer.Status, port uint16) *Manager {
return &Manager{
firewall: fw,
statusRecorder: statusRecorder,
port: port,
}
}
@@ -71,6 +67,12 @@ func (m *Manager) Start(fwdEntries []*ForwarderEntry) error {
return err
}
if m.port > 0 {
listenPortMu.Lock()
listenPort = m.port
listenPortMu.Unlock()
}
m.dnsForwarder = NewDNSForwarder(fmt.Sprintf(":%d", ListenPort()), dnsTTL, m.firewall, m.statusRecorder)
go func() {
if err := m.dnsForwarder.Listen(fwdEntries); err != nil {

View File

@@ -1849,10 +1849,6 @@ func (e *Engine) updateDNSForwarder(
return
}
if forwarderPort > 0 {
dnsfwd.SetListenPort(forwarderPort)
}
if !enabled {
if e.dnsForwardMgr == nil {
return
@@ -1866,7 +1862,7 @@ func (e *Engine) updateDNSForwarder(
if len(fwdEntries) > 0 {
switch {
case e.dnsForwardMgr == nil:
e.dnsForwardMgr = dnsfwd.NewManager(e.firewall, e.statusRecorder)
e.dnsForwardMgr = dnsfwd.NewManager(e.firewall, e.statusRecorder, forwarderPort)
if err := e.dnsForwardMgr.Start(fwdEntries); err != nil {
log.Errorf("failed to start DNS forward: %v", err)
e.dnsForwardMgr = nil
@@ -1896,7 +1892,7 @@ func (e *Engine) restartDnsFwd(fwdEntries []*dnsfwd.ForwarderEntry, forwarderPor
if err := e.dnsForwardMgr.Stop(context.Background()); err != nil {
log.Errorf("failed to stop DNS forward: %v", err)
}
e.dnsForwardMgr = dnsfwd.NewManager(e.firewall, e.statusRecorder)
e.dnsForwardMgr = dnsfwd.NewManager(e.firewall, e.statusRecorder, forwarderPort)
if err := e.dnsForwardMgr.Start(fwdEntries); err != nil {
log.Errorf("failed to start DNS forward: %v", err)
e.dnsForwardMgr = nil

View File

@@ -105,10 +105,6 @@ type MockWGIface struct {
LastActivitiesFunc func() map[string]monotime.Time
}
func (m *MockWGIface) RemoveEndpointAddress(_ string) error {
return nil
}
func (m *MockWGIface) FullStats() (*configurer.Stats, error) {
return nil, fmt.Errorf("not implemented")
}

View File

@@ -28,7 +28,6 @@ type wgIfaceBase interface {
UpdateAddr(newAddr string) error
GetProxy() wgproxy.Proxy
UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error
RemoveEndpointAddress(key string) error
RemovePeer(peerKey string) error
AddAllowedIP(peerKey string, allowedIP netip.Prefix) error
RemoveAllowedIP(peerKey string, allowedIP netip.Prefix) error

View File

@@ -171,9 +171,9 @@ func (conn *Conn) Open(engineCtx context.Context) error {
conn.handshaker = NewHandshaker(conn.Log, conn.config, conn.signaler, conn.workerICE, conn.workerRelay)
conn.handshaker.AddRelayListener(conn.workerRelay.OnNewOffer)
conn.handshaker.AddOnNewOfferListener(conn.workerRelay.OnNewOffer)
if !isForceRelayed() {
conn.handshaker.AddICEListener(conn.workerICE.OnNewOffer)
conn.handshaker.AddOnNewOfferListener(conn.workerICE.OnNewOffer)
}
conn.guard = guard.NewGuard(conn.Log, conn.isConnectedOnAllWay, conn.config.Timeout, conn.srWatcher)
@@ -430,9 +430,6 @@ func (conn *Conn) onICEStateDisconnected() {
} else {
conn.Log.Infof("ICE disconnected, do not switch to Relay. Reset priority to: %s", conntype.None.String())
conn.currentConnPriority = conntype.None
if err := conn.config.WgConfig.WgInterface.RemoveEndpointAddress(conn.config.WgConfig.RemoteKey); err != nil {
conn.Log.Errorf("failed to remove wg endpoint: %v", err)
}
}
changed := conn.statusICE.Get() != worker.StatusDisconnected
@@ -526,9 +523,6 @@ func (conn *Conn) onRelayDisconnected() {
if conn.currentConnPriority == conntype.Relay {
conn.Log.Debugf("clean up WireGuard config")
conn.currentConnPriority = conntype.None
if err := conn.config.WgConfig.WgInterface.RemoveEndpointAddress(conn.config.WgConfig.RemoteKey); err != nil {
conn.Log.Errorf("failed to remove wg endpoint: %v", err)
}
}
if conn.wgProxyRelay != nil {

View File

@@ -79,10 +79,10 @@ func TestConn_OnRemoteOffer(t *testing.T) {
return
}
onNewOfferChan := make(chan struct{})
onNewOffeChan := make(chan struct{})
conn.handshaker.AddRelayListener(func(remoteOfferAnswer *OfferAnswer) {
onNewOfferChan <- struct{}{}
conn.handshaker.AddOnNewOfferListener(func(remoteOfferAnswer *OfferAnswer) {
onNewOffeChan <- struct{}{}
})
conn.OnRemoteOffer(OfferAnswer{
@@ -98,7 +98,7 @@ func TestConn_OnRemoteOffer(t *testing.T) {
defer cancel()
select {
case <-onNewOfferChan:
case <-onNewOffeChan:
// success
case <-ctx.Done():
t.Error("expected to receive a new offer notification, but timed out")
@@ -118,10 +118,10 @@ func TestConn_OnRemoteAnswer(t *testing.T) {
return
}
onNewOfferChan := make(chan struct{})
onNewOffeChan := make(chan struct{})
conn.handshaker.AddRelayListener(func(remoteOfferAnswer *OfferAnswer) {
onNewOfferChan <- struct{}{}
conn.handshaker.AddOnNewOfferListener(func(remoteOfferAnswer *OfferAnswer) {
onNewOffeChan <- struct{}{}
})
conn.OnRemoteAnswer(OfferAnswer{
@@ -136,7 +136,7 @@ func TestConn_OnRemoteAnswer(t *testing.T) {
defer cancel()
select {
case <-onNewOfferChan:
case <-onNewOffeChan:
// success
case <-ctx.Done():
t.Error("expected to receive a new offer notification, but timed out")

View File

@@ -1,20 +0,0 @@
package guard
import (
"os"
"strconv"
"time"
)
const (
envICEMonitorPeriod = "NB_ICE_MONITOR_PERIOD"
)
func GetICEMonitorPeriod() time.Duration {
if envVal := os.Getenv(envICEMonitorPeriod); envVal != "" {
if seconds, err := strconv.Atoi(envVal); err == nil && seconds > 0 {
return time.Duration(seconds) * time.Second
}
}
return defaultCandidatesMonitorPeriod
}

View File

@@ -16,8 +16,8 @@ import (
)
const (
defaultCandidatesMonitorPeriod = 5 * time.Minute
candidateGatheringTimeout = 5 * time.Second
candidatesMonitorPeriod = 5 * time.Minute
candidateGatheringTimeout = 5 * time.Second
)
type ICEMonitor struct {
@@ -25,19 +25,16 @@ type ICEMonitor struct {
iFaceDiscover stdnet.ExternalIFaceDiscover
iceConfig icemaker.Config
tickerPeriod time.Duration
currentCandidatesAddress []string
candidatesMu sync.Mutex
}
func NewICEMonitor(iFaceDiscover stdnet.ExternalIFaceDiscover, config icemaker.Config, period time.Duration) *ICEMonitor {
log.Debugf("prepare ICE monitor with period: %s", period)
func NewICEMonitor(iFaceDiscover stdnet.ExternalIFaceDiscover, config icemaker.Config) *ICEMonitor {
cm := &ICEMonitor{
ReconnectCh: make(chan struct{}, 1),
iFaceDiscover: iFaceDiscover,
iceConfig: config,
tickerPeriod: period,
}
return cm
}
@@ -49,12 +46,7 @@ func (cm *ICEMonitor) Start(ctx context.Context, onChanged func()) {
return
}
// Initial check to populate the candidates for later comparison
if _, err := cm.handleCandidateTick(ctx, ufrag, pwd); err != nil {
log.Warnf("Failed to check initial ICE candidates: %v", err)
}
ticker := time.NewTicker(cm.tickerPeriod)
ticker := time.NewTicker(candidatesMonitorPeriod)
defer ticker.Stop()
for {

View File

@@ -51,7 +51,7 @@ func (w *SRWatcher) Start() {
ctx, cancel := context.WithCancel(context.Background())
w.cancelIceMonitor = cancel
iceMonitor := NewICEMonitor(w.iFaceDiscover, w.iceConfig, GetICEMonitorPeriod())
iceMonitor := NewICEMonitor(w.iFaceDiscover, w.iceConfig)
go iceMonitor.Start(ctx, w.onICEChanged)
w.signalClient.SetOnReconnectedListener(w.onReconnected)
w.relayManager.SetOnReconnectedListener(w.onReconnected)

View File

@@ -44,19 +44,13 @@ type OfferAnswer struct {
}
type Handshaker struct {
mu sync.Mutex
log *log.Entry
config ConnConfig
signaler *Signaler
ice *WorkerICE
relay *WorkerRelay
// relayListener is not blocking because the listener is using a goroutine to process the messages
// and it will only keep the latest message if multiple offers are received in a short time
// this is to avoid blocking the handshaker if the listener is doing some heavy processing
// and also to avoid processing old offers if multiple offers are received in a short time
// the listener will always process the latest offer
relayListener *AsyncOfferListener
iceListener func(remoteOfferAnswer *OfferAnswer)
mu sync.Mutex
log *log.Entry
config ConnConfig
signaler *Signaler
ice *WorkerICE
relay *WorkerRelay
onNewOfferListeners []*OfferListener
// remoteOffersCh is a channel used to wait for remote credentials to proceed with the connection
remoteOffersCh chan OfferAnswer
@@ -76,39 +70,28 @@ func NewHandshaker(log *log.Entry, config ConnConfig, signaler *Signaler, ice *W
}
}
func (h *Handshaker) AddRelayListener(offer func(remoteOfferAnswer *OfferAnswer)) {
h.relayListener = NewAsyncOfferListener(offer)
}
func (h *Handshaker) AddICEListener(offer func(remoteOfferAnswer *OfferAnswer)) {
h.iceListener = offer
func (h *Handshaker) AddOnNewOfferListener(offer func(remoteOfferAnswer *OfferAnswer)) {
l := NewOfferListener(offer)
h.onNewOfferListeners = append(h.onNewOfferListeners, l)
}
func (h *Handshaker) Listen(ctx context.Context) {
for {
select {
case remoteOfferAnswer := <-h.remoteOffersCh:
h.log.Infof("received offer, running version %s, remote WireGuard listen port %d, session id: %s", remoteOfferAnswer.Version, remoteOfferAnswer.WgListenPort, remoteOfferAnswer.SessionIDString())
if h.relayListener != nil {
h.relayListener.Notify(&remoteOfferAnswer)
}
if h.iceListener != nil {
h.iceListener(&remoteOfferAnswer)
}
// received confirmation from the remote peer -> ready to proceed
if err := h.sendAnswer(); err != nil {
h.log.Errorf("failed to send remote offer confirmation: %s", err)
continue
}
for _, listener := range h.onNewOfferListeners {
listener.Notify(&remoteOfferAnswer)
}
h.log.Infof("received offer, running version %s, remote WireGuard listen port %d, session id: %s", remoteOfferAnswer.Version, remoteOfferAnswer.WgListenPort, remoteOfferAnswer.SessionIDString())
case remoteOfferAnswer := <-h.remoteAnswerCh:
h.log.Infof("received answer, running version %s, remote WireGuard listen port %d, session id: %s", remoteOfferAnswer.Version, remoteOfferAnswer.WgListenPort, remoteOfferAnswer.SessionIDString())
if h.relayListener != nil {
h.relayListener.Notify(&remoteOfferAnswer)
}
if h.iceListener != nil {
h.iceListener(&remoteOfferAnswer)
for _, listener := range h.onNewOfferListeners {
listener.Notify(&remoteOfferAnswer)
}
case <-ctx.Done():
h.log.Infof("stop listening for remote offers and answers")

View File

@@ -13,20 +13,20 @@ func (oa *OfferAnswer) SessionIDString() string {
return oa.SessionID.String()
}
type AsyncOfferListener struct {
type OfferListener struct {
fn callbackFunc
running bool
latest *OfferAnswer
mu sync.Mutex
}
func NewAsyncOfferListener(fn callbackFunc) *AsyncOfferListener {
return &AsyncOfferListener{
func NewOfferListener(fn callbackFunc) *OfferListener {
return &OfferListener{
fn: fn,
}
}
func (o *AsyncOfferListener) Notify(remoteOfferAnswer *OfferAnswer) {
func (o *OfferListener) Notify(remoteOfferAnswer *OfferAnswer) {
o.mu.Lock()
defer o.mu.Unlock()

View File

@@ -14,7 +14,7 @@ func Test_newOfferListener(t *testing.T) {
runChan <- struct{}{}
}
hl := NewAsyncOfferListener(longRunningFn)
hl := NewOfferListener(longRunningFn)
hl.Notify(dummyOfferAnswer)
hl.Notify(dummyOfferAnswer)

View File

@@ -18,5 +18,4 @@ type WGIface interface {
GetStats() (map[string]configurer.WGStats, error)
GetProxy() wgproxy.Proxy
Address() wgaddr.Address
RemoveEndpointAddress(key string) error
}

View File

@@ -92,16 +92,23 @@ func NewWorkerICE(ctx context.Context, log *log.Entry, config ConnConfig, conn *
func (w *WorkerICE) OnNewOffer(remoteOfferAnswer *OfferAnswer) {
w.log.Debugf("OnNewOffer for ICE, serial: %s", remoteOfferAnswer.SessionIDString())
w.muxAgent.Lock()
defer w.muxAgent.Unlock()
if w.agent != nil || w.agentConnecting {
if w.agentConnecting {
w.log.Debugf("agent connection is in progress, skipping the offer")
w.muxAgent.Unlock()
return
}
if w.agent != nil {
// backward compatibility with old clients that do not send session ID
if remoteOfferAnswer.SessionID == nil {
w.log.Debugf("agent already exists, skipping the offer")
w.muxAgent.Unlock()
return
}
if w.remoteSessionID == *remoteOfferAnswer.SessionID {
w.log.Debugf("agent already exists and session ID matches, skipping the offer: %s", remoteOfferAnswer.SessionIDString())
w.muxAgent.Unlock()
return
}
w.log.Debugf("agent already exists, recreate the connection")
@@ -109,12 +116,6 @@ func (w *WorkerICE) OnNewOffer(remoteOfferAnswer *OfferAnswer) {
if err := w.agent.Close(); err != nil {
w.log.Warnf("failed to close ICE agent: %s", err)
}
sessionID, err := NewICESessionID()
if err != nil {
w.log.Errorf("failed to create new session ID: %s", err)
}
w.sessionID = sessionID
w.agent = nil
}
@@ -125,23 +126,18 @@ func (w *WorkerICE) OnNewOffer(remoteOfferAnswer *OfferAnswer) {
preferredCandidateTypes = icemaker.CandidateTypes()
}
if remoteOfferAnswer.SessionID != nil {
w.log.Debugf("recreate ICE agent: %s / %s", w.sessionID, *remoteOfferAnswer.SessionID)
}
w.log.Debugf("recreate ICE agent")
dialerCtx, dialerCancel := context.WithCancel(w.ctx)
agent, err := w.reCreateAgent(dialerCancel, preferredCandidateTypes)
if err != nil {
w.log.Errorf("failed to recreate ICE Agent: %s", err)
w.muxAgent.Unlock()
return
}
w.agent = agent
w.agentDialerCancel = dialerCancel
w.agentConnecting = true
if remoteOfferAnswer.SessionID != nil {
w.remoteSessionID = *remoteOfferAnswer.SessionID
} else {
w.remoteSessionID = ""
}
w.muxAgent.Unlock()
go w.connect(dialerCtx, agent, remoteOfferAnswer)
}
@@ -297,6 +293,9 @@ func (w *WorkerICE) connect(ctx context.Context, agent *icemaker.ThreadSafeAgent
w.muxAgent.Lock()
w.agentConnecting = false
w.lastSuccess = time.Now()
if remoteOfferAnswer.SessionID != nil {
w.remoteSessionID = *remoteOfferAnswer.SessionID
}
w.muxAgent.Unlock()
// todo: the potential problem is a race between the onConnectionStateChange
@@ -310,17 +309,16 @@ func (w *WorkerICE) closeAgent(agent *icemaker.ThreadSafeAgent, cancel context.C
}
w.muxAgent.Lock()
// todo review does it make sense to generate new session ID all the time when w.agent==agent
sessionID, err := NewICESessionID()
if err != nil {
w.log.Errorf("failed to create new session ID: %s", err)
}
w.sessionID = sessionID
if w.agent == agent {
// consider to remove from here and move to the OnNewOffer
sessionID, err := NewICESessionID()
if err != nil {
w.log.Errorf("failed to create new session ID: %s", err)
}
w.sessionID = sessionID
w.agent = nil
w.agentConnecting = false
w.remoteSessionID = ""
}
w.muxAgent.Unlock()
}
@@ -397,12 +395,11 @@ func (w *WorkerICE) onConnectionStateChange(agent *icemaker.ThreadSafeAgent, dia
// ice.ConnectionStateClosed happens when we recreate the agent. For the P2P to TURN switch important to
// notify the conn.onICEStateDisconnected changes to update the current used priority
w.closeAgent(agent, dialerCancel)
if w.lastKnownState == ice.ConnectionStateConnected {
w.lastKnownState = ice.ConnectionStateDisconnected
w.conn.onICEStateDisconnected()
}
w.closeAgent(agent, dialerCancel)
default:
return
}

View File

@@ -1354,13 +1354,7 @@ func (s *serviceClient) updateConfig() error {
}
// showLoginURL creates a borderless window styled like a pop-up in the top-right corner using s.wLoginURL.
// It also starts a background goroutine that periodically checks if the client is already connected
// and closes the window if so. The goroutine can be cancelled by the returned CancelFunc, and it is
// also cancelled when the window is closed.
func (s *serviceClient) showLoginURL() context.CancelFunc {
// create a cancellable context for the background check goroutine
ctx, cancel := context.WithCancel(s.ctx)
func (s *serviceClient) showLoginURL() {
resIcon := fyne.NewStaticResource("netbird.png", iconAbout)
@@ -1369,8 +1363,6 @@ func (s *serviceClient) showLoginURL() context.CancelFunc {
s.wLoginURL.Resize(fyne.NewSize(400, 200))
s.wLoginURL.SetIcon(resIcon)
}
// ensure goroutine is cancelled when the window is closed
s.wLoginURL.SetOnClosed(func() { cancel() })
// add a description label
label := widget.NewLabel("Your NetBird session has expired.\nPlease re-authenticate to continue using NetBird.")
@@ -1451,39 +1443,7 @@ func (s *serviceClient) showLoginURL() context.CancelFunc {
)
s.wLoginURL.SetContent(container.NewCenter(content))
// start a goroutine to check connection status and close the window if connected
go func() {
ticker := time.NewTicker(5 * time.Second)
defer ticker.Stop()
conn, err := s.getSrvClient(failFastTimeout)
if err != nil {
return
}
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
status, err := conn.Status(s.ctx, &proto.StatusRequest{})
if err != nil {
continue
}
if status.Status == string(internal.StatusConnected) {
if s.wLoginURL != nil {
s.wLoginURL.Close()
}
return
}
}
}
}()
s.wLoginURL.Show()
// return cancel func so callers can stop the background goroutine if desired
return cancel
}
func openURL(url string) error {

View File

@@ -73,8 +73,8 @@ func (p *RDCleanPathProxy) validateCertificateWithJS(conn *proxyConnection, cert
}
}
func (p *RDCleanPathProxy) getTLSConfigWithValidation(conn *proxyConnection, requiresCredSSP bool) *tls.Config {
config := &tls.Config{
func (p *RDCleanPathProxy) getTLSConfigWithValidation(conn *proxyConnection) *tls.Config {
return &tls.Config{
InsecureSkipVerify: true, // We'll validate manually after handshake
VerifyConnection: func(cs tls.ConnectionState) error {
var certChain [][]byte
@@ -93,15 +93,4 @@ func (p *RDCleanPathProxy) getTLSConfigWithValidation(conn *proxyConnection, req
return nil
},
}
// CredSSP (NLA) requires TLS 1.2 - it's incompatible with TLS 1.3
if requiresCredSSP {
config.MinVersion = tls.VersionTLS12
config.MaxVersion = tls.VersionTLS12
} else {
config.MinVersion = tls.VersionTLS12
config.MaxVersion = tls.VersionTLS13
}
return config
}

View File

@@ -6,13 +6,11 @@ import (
"context"
"crypto/tls"
"encoding/asn1"
"errors"
"fmt"
"io"
"net"
"sync"
"syscall/js"
"time"
log "github.com/sirupsen/logrus"
)
@@ -21,34 +19,18 @@ const (
RDCleanPathVersion = 3390
RDCleanPathProxyHost = "rdcleanpath.proxy.local"
RDCleanPathProxyScheme = "ws"
rdpDialTimeout = 15 * time.Second
GeneralErrorCode = 1
WSAETimedOut = 10060
WSAEConnRefused = 10061
WSAEConnAborted = 10053
WSAEConnReset = 10054
WSAEGenericError = 10050
)
type RDCleanPathPDU struct {
Version int64 `asn1:"tag:0,explicit"`
Error RDCleanPathErr `asn1:"tag:1,explicit,optional"`
Destination string `asn1:"utf8,tag:2,explicit,optional"`
ProxyAuth string `asn1:"utf8,tag:3,explicit,optional"`
ServerAuth string `asn1:"utf8,tag:4,explicit,optional"`
PreconnectionBlob string `asn1:"utf8,tag:5,explicit,optional"`
X224ConnectionPDU []byte `asn1:"tag:6,explicit,optional"`
ServerCertChain [][]byte `asn1:"tag:7,explicit,optional"`
ServerAddr string `asn1:"utf8,tag:9,explicit,optional"`
}
type RDCleanPathErr struct {
ErrorCode int16 `asn1:"tag:0,explicit"`
HTTPStatusCode int16 `asn1:"tag:1,explicit,optional"`
WSALastError int16 `asn1:"tag:2,explicit,optional"`
TLSAlertCode int8 `asn1:"tag:3,explicit,optional"`
Version int64 `asn1:"tag:0,explicit"`
Error []byte `asn1:"tag:1,explicit,optional"`
Destination string `asn1:"utf8,tag:2,explicit,optional"`
ProxyAuth string `asn1:"utf8,tag:3,explicit,optional"`
ServerAuth string `asn1:"utf8,tag:4,explicit,optional"`
PreconnectionBlob string `asn1:"utf8,tag:5,explicit,optional"`
X224ConnectionPDU []byte `asn1:"tag:6,explicit,optional"`
ServerCertChain [][]byte `asn1:"tag:7,explicit,optional"`
ServerAddr string `asn1:"utf8,tag:9,explicit,optional"`
}
type RDCleanPathProxy struct {
@@ -228,13 +210,9 @@ func (p *RDCleanPathProxy) handleDirectRDP(conn *proxyConnection, firstPacket []
destination := conn.destination
log.Infof("Direct RDP mode: Connecting to %s via NetBird", destination)
ctx, cancel := context.WithTimeout(conn.ctx, rdpDialTimeout)
defer cancel()
rdpConn, err := p.nbClient.Dial(ctx, "tcp", destination)
rdpConn, err := p.nbClient.Dial(conn.ctx, "tcp", destination)
if err != nil {
log.Errorf("Failed to connect to %s: %v", destination, err)
p.sendRDCleanPathError(conn, newWSAError(err))
return
}
conn.rdpConn = rdpConn
@@ -242,7 +220,6 @@ func (p *RDCleanPathProxy) handleDirectRDP(conn *proxyConnection, firstPacket []
_, err = rdpConn.Write(firstPacket)
if err != nil {
log.Errorf("Failed to write first packet: %v", err)
p.sendRDCleanPathError(conn, newWSAError(err))
return
}
@@ -250,7 +227,6 @@ func (p *RDCleanPathProxy) handleDirectRDP(conn *proxyConnection, firstPacket []
n, err := rdpConn.Read(response)
if err != nil {
log.Errorf("Failed to read X.224 response: %v", err)
p.sendRDCleanPathError(conn, newWSAError(err))
return
}
@@ -293,52 +269,3 @@ func (p *RDCleanPathProxy) sendToWebSocket(conn *proxyConnection, data []byte) {
conn.wsHandlers.Call("send", uint8Array.Get("buffer"))
}
}
func (p *RDCleanPathProxy) sendRDCleanPathError(conn *proxyConnection, pdu RDCleanPathPDU) {
data, err := asn1.Marshal(pdu)
if err != nil {
log.Errorf("Failed to marshal error PDU: %v", err)
return
}
p.sendToWebSocket(conn, data)
}
func errorToWSACode(err error) int16 {
if err == nil {
return WSAEGenericError
}
var netErr *net.OpError
if errors.As(err, &netErr) && netErr.Timeout() {
return WSAETimedOut
}
if errors.Is(err, context.DeadlineExceeded) {
return WSAETimedOut
}
if errors.Is(err, context.Canceled) {
return WSAEConnAborted
}
if errors.Is(err, io.EOF) {
return WSAEConnReset
}
return WSAEGenericError
}
func newWSAError(err error) RDCleanPathPDU {
return RDCleanPathPDU{
Version: RDCleanPathVersion,
Error: RDCleanPathErr{
ErrorCode: GeneralErrorCode,
WSALastError: errorToWSACode(err),
},
}
}
func newHTTPError(statusCode int16) RDCleanPathPDU {
return RDCleanPathPDU{
Version: RDCleanPathVersion,
Error: RDCleanPathErr{
ErrorCode: GeneralErrorCode,
HTTPStatusCode: statusCode,
},
}
}

View File

@@ -3,7 +3,6 @@
package rdp
import (
"context"
"crypto/tls"
"encoding/asn1"
"io"
@@ -12,17 +11,11 @@ import (
log "github.com/sirupsen/logrus"
)
const (
// MS-RDPBCGR: confusingly named, actually means PROTOCOL_HYBRID (CredSSP)
protocolSSL = 0x00000001
protocolHybridEx = 0x00000008
)
func (p *RDCleanPathProxy) processRDCleanPathPDU(conn *proxyConnection, pdu RDCleanPathPDU) {
log.Infof("Processing RDCleanPath PDU: Version=%d, Destination=%s", pdu.Version, pdu.Destination)
if pdu.Version != RDCleanPathVersion {
p.sendRDCleanPathError(conn, newHTTPError(400))
p.sendRDCleanPathError(conn, "Unsupported version")
return
}
@@ -31,13 +24,10 @@ func (p *RDCleanPathProxy) processRDCleanPathPDU(conn *proxyConnection, pdu RDCl
destination = pdu.Destination
}
ctx, cancel := context.WithTimeout(conn.ctx, rdpDialTimeout)
defer cancel()
rdpConn, err := p.nbClient.Dial(ctx, "tcp", destination)
rdpConn, err := p.nbClient.Dial(conn.ctx, "tcp", destination)
if err != nil {
log.Errorf("Failed to connect to %s: %v", destination, err)
p.sendRDCleanPathError(conn, newWSAError(err))
p.sendRDCleanPathError(conn, "Connection failed")
p.cleanupConnection(conn)
return
}
@@ -50,34 +40,6 @@ func (p *RDCleanPathProxy) processRDCleanPathPDU(conn *proxyConnection, pdu RDCl
p.setupTLSConnection(conn, pdu)
}
// detectCredSSPFromX224 checks if the X.224 response indicates NLA/CredSSP is required.
// Per MS-RDPBCGR spec: byte 11 = TYPE_RDP_NEG_RSP (0x02), bytes 15-18 = selectedProtocol flags.
// Returns (requiresTLS12, selectedProtocol, detectionSuccessful).
func (p *RDCleanPathProxy) detectCredSSPFromX224(x224Response []byte) (bool, uint32, bool) {
const minResponseLength = 19
if len(x224Response) < minResponseLength {
return false, 0, false
}
// Per X.224 specification:
// x224Response[0] == 0x03: Length of X.224 header (3 bytes)
// x224Response[5] == 0xD0: X.224 Data TPDU code
if x224Response[0] != 0x03 || x224Response[5] != 0xD0 {
return false, 0, false
}
if x224Response[11] == 0x02 {
flags := uint32(x224Response[15]) | uint32(x224Response[16])<<8 |
uint32(x224Response[17])<<16 | uint32(x224Response[18])<<24
hasNLA := (flags & (protocolSSL | protocolHybridEx)) != 0
return hasNLA, flags, true
}
return false, 0, false
}
func (p *RDCleanPathProxy) setupTLSConnection(conn *proxyConnection, pdu RDCleanPathPDU) {
var x224Response []byte
if len(pdu.X224ConnectionPDU) > 0 {
@@ -85,7 +47,7 @@ func (p *RDCleanPathProxy) setupTLSConnection(conn *proxyConnection, pdu RDClean
_, err := conn.rdpConn.Write(pdu.X224ConnectionPDU)
if err != nil {
log.Errorf("Failed to write X.224 PDU: %v", err)
p.sendRDCleanPathError(conn, newWSAError(err))
p.sendRDCleanPathError(conn, "Failed to forward X.224")
return
}
@@ -93,32 +55,21 @@ func (p *RDCleanPathProxy) setupTLSConnection(conn *proxyConnection, pdu RDClean
n, err := conn.rdpConn.Read(response)
if err != nil {
log.Errorf("Failed to read X.224 response: %v", err)
p.sendRDCleanPathError(conn, newWSAError(err))
p.sendRDCleanPathError(conn, "Failed to read X.224 response")
return
}
x224Response = response[:n]
log.Debugf("Received X.224 Connection Confirm (%d bytes)", n)
}
requiresCredSSP, selectedProtocol, detected := p.detectCredSSPFromX224(x224Response)
if detected {
if requiresCredSSP {
log.Warnf("Detected NLA/CredSSP (selectedProtocol: 0x%08X), forcing TLS 1.2 for compatibility", selectedProtocol)
} else {
log.Warnf("No NLA/CredSSP detected (selectedProtocol: 0x%08X), allowing up to TLS 1.3", selectedProtocol)
}
} else {
log.Warnf("Could not detect RDP security protocol, allowing up to TLS 1.3")
}
tlsConfig := p.getTLSConfigWithValidation(conn, requiresCredSSP)
tlsConfig := p.getTLSConfigWithValidation(conn)
tlsConn := tls.Client(conn.rdpConn, tlsConfig)
conn.tlsConn = tlsConn
if err := tlsConn.Handshake(); err != nil {
log.Errorf("TLS handshake failed: %v", err)
p.sendRDCleanPathError(conn, newWSAError(err))
p.sendRDCleanPathError(conn, "TLS handshake failed")
return
}
@@ -155,6 +106,47 @@ func (p *RDCleanPathProxy) setupTLSConnection(conn *proxyConnection, pdu RDClean
p.cleanupConnection(conn)
}
func (p *RDCleanPathProxy) setupPlainConnection(conn *proxyConnection, pdu RDCleanPathPDU) {
if len(pdu.X224ConnectionPDU) > 0 {
log.Debugf("Forwarding X.224 Connection Request (%d bytes)", len(pdu.X224ConnectionPDU))
_, err := conn.rdpConn.Write(pdu.X224ConnectionPDU)
if err != nil {
log.Errorf("Failed to write X.224 PDU: %v", err)
p.sendRDCleanPathError(conn, "Failed to forward X.224")
return
}
response := make([]byte, 1024)
n, err := conn.rdpConn.Read(response)
if err != nil {
log.Errorf("Failed to read X.224 response: %v", err)
p.sendRDCleanPathError(conn, "Failed to read X.224 response")
return
}
responsePDU := RDCleanPathPDU{
Version: RDCleanPathVersion,
X224ConnectionPDU: response[:n],
ServerAddr: conn.destination,
}
p.sendRDCleanPathPDU(conn, responsePDU)
} else {
responsePDU := RDCleanPathPDU{
Version: RDCleanPathVersion,
ServerAddr: conn.destination,
}
p.sendRDCleanPathPDU(conn, responsePDU)
}
go p.forwardConnToWS(conn, conn.rdpConn, "TCP")
go p.forwardWSToConn(conn, conn.rdpConn, "TCP")
<-conn.ctx.Done()
log.Debug("TCP connection context done, cleaning up")
p.cleanupConnection(conn)
}
func (p *RDCleanPathProxy) sendRDCleanPathPDU(conn *proxyConnection, pdu RDCleanPathPDU) {
data, err := asn1.Marshal(pdu)
if err != nil {
@@ -166,6 +158,21 @@ func (p *RDCleanPathProxy) sendRDCleanPathPDU(conn *proxyConnection, pdu RDClean
p.sendToWebSocket(conn, data)
}
func (p *RDCleanPathProxy) sendRDCleanPathError(conn *proxyConnection, errorMsg string) {
pdu := RDCleanPathPDU{
Version: RDCleanPathVersion,
Error: []byte(errorMsg),
}
data, err := asn1.Marshal(pdu)
if err != nil {
log.Errorf("Failed to marshal error PDU: %v", err)
return
}
p.sendToWebSocket(conn, data)
}
func (p *RDCleanPathProxy) readWebSocketMessage(conn *proxyConnection) ([]byte, error) {
msgChan := make(chan []byte)
errChan := make(chan error)

27
go.mod
View File

@@ -1,6 +1,8 @@
module github.com/netbirdio/netbird
go 1.23.0
go 1.24.0
toolchain go1.24.4
require (
cunicu.li/go-rosenpass v0.4.0
@@ -17,8 +19,8 @@ require (
github.com/spf13/cobra v1.7.0
github.com/spf13/pflag v1.0.5
github.com/vishvananda/netlink v1.3.0
golang.org/x/crypto v0.40.0
golang.org/x/sys v0.34.0
golang.org/x/crypto v0.41.0
golang.org/x/sys v0.35.0
golang.zx2c4.com/wireguard v0.0.0-20230704135630-469159ecf7d1
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6
golang.zx2c4.com/wireguard/windows v0.5.3
@@ -62,7 +64,7 @@ require (
github.com/miekg/dns v1.1.59
github.com/mitchellh/hashstructure/v2 v2.0.2
github.com/nadoo/ipset v0.5.0
github.com/netbirdio/management-integrations/integrations v0.0.0-20251010134843-7af36217ac1f
github.com/netbirdio/management-integrations/integrations v0.0.0-20250906095204-f87a07690ba0
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45
github.com/okta/okta-sdk-golang/v2 v2.18.0
github.com/oschwald/maxminddb-golang v1.12.0
@@ -102,17 +104,17 @@ require (
goauthentik.io/api/v3 v3.2023051.3
golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842
golang.org/x/mobile v0.0.0-20231127183840-76ac6878050a
golang.org/x/mod v0.25.0
golang.org/x/net v0.42.0
golang.org/x/mod v0.27.0
golang.org/x/net v0.43.0
golang.org/x/oauth2 v0.28.0
golang.org/x/sync v0.16.0
golang.org/x/term v0.33.0
golang.org/x/sync v0.17.0
golang.org/x/term v0.34.0
google.golang.org/api v0.177.0
gopkg.in/yaml.v3 v3.0.1
gorm.io/driver/mysql v1.5.7
gorm.io/driver/postgres v1.5.7
gorm.io/driver/sqlite v1.5.7
gorm.io/gorm v1.25.12
gorm.io/driver/sqlite v1.6.0
gorm.io/gorm v1.31.0
gvisor.dev/gvisor v0.0.0-20231020174304-b8a429915ff1
)
@@ -163,6 +165,7 @@ require (
github.com/fyne-io/image v0.0.0-20220602074514-4956b0afb3d2 // indirect
github.com/go-gl/gl v0.0.0-20211210172815-726fda9656d6 // indirect
github.com/go-gl/glfw/v3.3/glfw v0.0.0-20240506104042-037f3cc74f2a // indirect
github.com/go-gorm/caches/v4 v4.0.5 // indirect
github.com/go-logr/logr v1.4.2 // indirect
github.com/go-logr/stdr v1.2.2 // indirect
github.com/go-ole/go-ole v1.3.0 // indirect
@@ -244,9 +247,9 @@ require (
go.uber.org/mock v0.4.0 // indirect
go.uber.org/multierr v1.11.0 // indirect
golang.org/x/image v0.18.0 // indirect
golang.org/x/text v0.27.0 // indirect
golang.org/x/text v0.29.0 // indirect
golang.org/x/time v0.5.0 // indirect
golang.org/x/tools v0.34.0 // indirect
golang.org/x/tools v0.36.0 // indirect
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20250707201910-8d1bb00bc6a7 // indirect
gopkg.in/square/go-jose.v2 v2.6.0 // indirect

24
go.sum
View File

@@ -215,6 +215,8 @@ github.com/go-gl/glfw/v3.3/glfw v0.0.0-20191125211704-12ad95a8df72/go.mod h1:tQ2
github.com/go-gl/glfw/v3.3/glfw v0.0.0-20200222043503-6f7a984d4dc4/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8=
github.com/go-gl/glfw/v3.3/glfw v0.0.0-20240506104042-037f3cc74f2a h1:vxnBhFDDT+xzxf1jTJKMKZw3H0swfWk9RpWbBbDK5+0=
github.com/go-gl/glfw/v3.3/glfw v0.0.0-20240506104042-037f3cc74f2a/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8=
github.com/go-gorm/caches/v4 v4.0.5 h1:Sdj9vxbEM0sCmv5+s5o6GzoVMuraWF0bjJJvUU+7c1U=
github.com/go-gorm/caches/v4 v4.0.5/go.mod h1:Ms8LnWVoW4GkTofpDzFH8OfDGNTjLxQDyxBmRN67Ujw=
github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A=
github.com/go-logr/logr v1.4.2 h1:6pFjapn8bFcIbiKo3XT4j/BhANplGihG6tvd+8rYgrY=
github.com/go-logr/logr v1.4.2/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY=
@@ -503,8 +505,8 @@ github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944 h1:TDtJKmM6S
github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944/go.mod h1:sHA6TRxjQ6RLbnI+3R4DZo2Eseg/iKiPRfNmcuNySVQ=
github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51 h1:Ov4qdafATOgGMB1wbSuh+0aAHcwz9hdvB6VZjh1mVMI=
github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51/go.mod h1:ZSIbPdBn5hePO8CpF1PekH2SfpTxg1PDhEwtbqZS7R8=
github.com/netbirdio/management-integrations/integrations v0.0.0-20251010134843-7af36217ac1f h1:XIpRDlpPz3zFUkpwaqDRHjwpQRsf2ZKHggoex1MTafs=
github.com/netbirdio/management-integrations/integrations v0.0.0-20251010134843-7af36217ac1f/go.mod h1:v0nUbbHbuQnqR7yKIYnKzsLBCswLtp2JctmKYmGgVhc=
github.com/netbirdio/management-integrations/integrations v0.0.0-20250906095204-f87a07690ba0 h1:9BUqQHPVOGr0edk8EifUBUfTr2Ob0ypAPxtasUApBxQ=
github.com/netbirdio/management-integrations/integrations v0.0.0-20250906095204-f87a07690ba0/go.mod h1:v0nUbbHbuQnqR7yKIYnKzsLBCswLtp2JctmKYmGgVhc=
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9axERMVN63dqyFqnvuD+EMJHzM7mNGON8=
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM=
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45 h1:ujgviVYmx243Ksy7NdSwrdGPSRNE3pb8kEDSpH0QuAQ=
@@ -773,6 +775,7 @@ golang.org/x/crypto v0.12.0/go.mod h1:NF0Gs7EO5K4qLn+Ylc+fih8BSTeIjAP05siRnAh98y
golang.org/x/crypto v0.18.0/go.mod h1:R0j02AL6hcrfOiy9T4ZYp/rcWeMxM3L6QYxlOuEG1mg=
golang.org/x/crypto v0.40.0 h1:r4x+VvoG5Fm+eJcxMaY8CQM7Lb0l1lsmjGBQ6s8BfKM=
golang.org/x/crypto v0.40.0/go.mod h1:Qr1vMER5WyS2dfPHAlsOj01wgLbsyWtFn/aY+5+ZdxY=
golang.org/x/crypto v0.41.0/go.mod h1:pO5AFd7FA68rFak7rOAGVuygIISepHftHnr8dr6+sUc=
golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/exp v0.0.0-20190510132918-efd6b22b2522/go.mod h1:ZjyILWgesfNpC6sMxTJOJm9Kp84zZh5NQWvqDGG3Qr8=
@@ -820,6 +823,8 @@ golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91
golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
golang.org/x/mod v0.25.0 h1:n7a+ZbQKQA/Ysbyb0/6IbB1H/X41mKgbhfv7AfG/44w=
golang.org/x/mod v0.25.0/go.mod h1:IXM97Txy2VM4PJ3gI61r1YEk/gAj6zAHN3AdZt6S9Ww=
golang.org/x/mod v0.27.0 h1:kb+q2PyFnEADO2IEF935ehFUXlWiNjJWtRNgBLSfbxQ=
golang.org/x/mod v0.27.0/go.mod h1:rWI627Fq0DEoudcK+MBkNkCe0EetEaDSwJJkCcjpazc=
golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
@@ -867,6 +872,8 @@ golang.org/x/net v0.14.0/go.mod h1:PpSgVXXLK0OxS0F31C1/tv6XNguvCrnXIDrFMspZIUI=
golang.org/x/net v0.20.0/go.mod h1:z8BVo6PvndSri0LbOE3hAn0apkU+1YvI6E70E9jsnvY=
golang.org/x/net v0.42.0 h1:jzkYrhi3YQWD6MLBJcsklgQsoAcw89EcZbJw8Z614hs=
golang.org/x/net v0.42.0/go.mod h1:FF1RA5d3u7nAYA4z2TkclSCKh68eSXtiFwcWQpPXdt8=
golang.org/x/net v0.43.0 h1:lat02VYK2j4aLzMzecihNvTlJNQUq316m2Mr9rnM6YE=
golang.org/x/net v0.43.0/go.mod h1:vhO1fvI4dGsIjh73sWfUVjj3N7CA9WkKJNQm2svM6Jg=
golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
@@ -897,6 +904,8 @@ golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJ
golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.16.0 h1:ycBJEhp9p4vXvUZNszeOq0kGTPghopOL8q0fq3vstxw=
golang.org/x/sync v0.16.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA=
golang.org/x/sync v0.17.0 h1:l60nONMj9l5drqw6jlhIELNv9I0A4OFgRsG9k2oT9Ug=
golang.org/x/sync v0.17.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
golang.org/x/sys v0.0.0-20180823144017-11551d06cbcc/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
@@ -966,6 +975,7 @@ golang.org/x/sys v0.16.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.19.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.34.0 h1:H5Y5sJ2L2JRdyv7ROF1he/lPdvFsd0mJHFw2ThKHxLA=
golang.org/x/sys v0.34.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
golang.org/x/sys v0.35.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k=
@@ -975,6 +985,8 @@ golang.org/x/term v0.11.0/go.mod h1:zC9APTIj3jG3FdV/Ons+XE1riIZXG4aZ4GTHiPZJPIU=
golang.org/x/term v0.16.0/go.mod h1:yn7UURbUtPyrVJPGPq404EukNFxcm/foM+bV/bfcDsY=
golang.org/x/term v0.33.0 h1:NuFncQrRcaRvVmgRkvM3j/F00gWIAlcmlB8ACEKmGIg=
golang.org/x/term v0.33.0/go.mod h1:s18+ql9tYWp1IfpV9DmCtQDDSRBUjKaw9M1eAv5UeF0=
golang.org/x/term v0.34.0 h1:O/2T7POpk0ZZ7MAzMeWFSg6S5IpWd/RXDlM9hgM3DR4=
golang.org/x/term v0.34.0/go.mod h1:5jC53AEywhIVebHgPVeg0mj8OD3VO9OzclacVrqpaAw=
golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
@@ -990,6 +1002,8 @@ golang.org/x/text v0.12.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
golang.org/x/text v0.27.0 h1:4fGWRpyh641NLlecmyl4LOe6yDdfaYNrGb2zdfo4JV4=
golang.org/x/text v0.27.0/go.mod h1:1D28KMCvyooCX9hBiosv5Tz/+YLxj0j7XhWjpSUF7CU=
golang.org/x/text v0.29.0 h1:1neNs90w9YzJ9BocxfsQNHKuAT4pkghyXc4nhZ6sJvk=
golang.org/x/text v0.29.0/go.mod h1:7MhJOA9CD2qZyOKYazxdYMF85OwPdEr9jTtBpO7ydH4=
golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
@@ -1054,6 +1068,8 @@ golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc
golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU=
golang.org/x/tools v0.34.0 h1:qIpSLOxeCYGg9TrcJokLBG4KFA6d795g0xkBkiESGlo=
golang.org/x/tools v0.34.0/go.mod h1:pAP9OwEaY1CAW3HOmg3hLZC5Z0CCmzjAF2UQMSqNARg=
golang.org/x/tools v0.36.0 h1:kWS0uv/zsvHEle1LbV5LE8QujrxB3wfQyxHfhOk0Qkg=
golang.org/x/tools v0.36.0/go.mod h1:WBDiHKJK8YgLHlcQPYQzNCkUxUypCaa5ZegCVutKm+s=
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
@@ -1213,9 +1229,13 @@ gorm.io/driver/postgres v1.5.7 h1:8ptbNJTDbEmhdr62uReG5BGkdQyeasu/FZHxI0IMGnM=
gorm.io/driver/postgres v1.5.7/go.mod h1:3e019WlBaYI5o5LIdNV+LyxCMNtLOQETBXL2h4chKpA=
gorm.io/driver/sqlite v1.5.7 h1:8NvsrhP0ifM7LX9G4zPB97NwovUakUxc+2V2uuf3Z1I=
gorm.io/driver/sqlite v1.5.7/go.mod h1:U+J8craQU6Fzkcvu8oLeAQmi50TkwPEhHDEjQZXDah4=
gorm.io/driver/sqlite v1.6.0 h1:WHRRrIiulaPiPFmDcod6prc4l2VGVWHz80KspNsxSfQ=
gorm.io/driver/sqlite v1.6.0/go.mod h1:AO9V1qIQddBESngQUKWL9yoH93HIeA1X6V633rBwyT8=
gorm.io/gorm v1.25.7/go.mod h1:hbnx/Oo0ChWMn1BIhpy1oYozzpM15i4YPuHDmfYtwg8=
gorm.io/gorm v1.25.12 h1:I0u8i2hWQItBq1WfE0o2+WuL9+8L21K9e2HHSTE/0f8=
gorm.io/gorm v1.25.12/go.mod h1:xh7N7RHfYlNc5EmcI/El95gXusucDrQnHXe0+CgWcLQ=
gorm.io/gorm v1.31.0 h1:0VlycGreVhK7RF/Bwt51Fk8v0xLiiiFdbGDPIZQ7mJY=
gorm.io/gorm v1.31.0/go.mod h1:XyQVbO2k6YkOis7C2437jSit3SsDK72s7n7rsSHd+Gs=
gotest.tools/v3 v3.5.0 h1:Ljk6PdHdOhAb5aDMWXjDLMMhph+BpztA4v1QdqEW2eY=
gotest.tools/v3 v3.5.0/go.mod h1:isy3WKz7GK6uNw/sbHzfKBLvlvXwUyV06n6brMxxopU=
gvisor.dev/gvisor v0.0.0-20231020174304-b8a429915ff1 h1:qDCwdCWECGnwQSQC01Dpnp09fRHxJs9PbktotUqG+hs=

View File

@@ -47,9 +47,8 @@ services:
- traefik.enable=true
- traefik.http.routers.netbird-wsproxy-signal.rule=Host(`$NETBIRD_DOMAIN`) && PathPrefix(`/ws-proxy/signal`)
- traefik.http.routers.netbird-wsproxy-signal.service=netbird-wsproxy-signal
- traefik.http.services.netbird-wsproxy-signal.loadbalancer.server.port=80
- traefik.http.services.netbird-wsproxy-signal.loadbalancer.server.port=10000
- traefik.http.routers.netbird-signal.rule=Host(`$NETBIRD_DOMAIN`) && PathPrefix(`/signalexchange.SignalExchange/`)
- traefik.http.routers.netbird-signal.service=netbird-signal
- traefik.http.services.netbird-signal.loadbalancer.server.port=10000
- traefik.http.services.netbird-signal.loadbalancer.server.scheme=h2c

View File

@@ -621,7 +621,7 @@ renderCaddyfile() {
# relay
reverse_proxy /relay* relay:80
# Signal
reverse_proxy /ws-proxy/signal* signal:80
reverse_proxy /ws-proxy/signal* signal:10000
reverse_proxy /signalexchange.SignalExchange/* h2c://signal:10000
# Management
reverse_proxy /api/* management:80

View File

@@ -6,6 +6,7 @@ import (
"fmt"
"net"
"net/http"
"net/netip"
"strings"
"sync"
"time"
@@ -251,7 +252,7 @@ func updateMgmtConfig(ctx context.Context, path string, config *nbconfig.Config)
}
func (s *BaseServer) handlerFunc(gRPCHandler *grpc.Server, httpHandler http.Handler, meter metric.Meter) http.Handler {
wsProxy := wsproxyserver.New(gRPCHandler, wsproxyserver.WithOTelMeter(meter))
wsProxy := wsproxyserver.New(netip.AddrPortFrom(netip.AddrFrom4([4]byte{127, 0, 0, 1}), ManagementLegacyPort), wsproxyserver.WithOTelMeter(meter))
return http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) {
switch {

View File

@@ -136,7 +136,7 @@ func (a MockIntegratedValidator) GetValidatedPeers(_ context.Context, accountID
return validatedPeers, nil
}
func (MockIntegratedValidator) PreparePeer(_ context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *types.ExtraSettings, temporary bool) *nbpeer.Peer {
func (MockIntegratedValidator) PreparePeer(_ context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *types.ExtraSettings) *nbpeer.Peer {
return peer
}

View File

@@ -3,16 +3,16 @@ package integrated_validator
import (
"context"
"github.com/netbirdio/netbird/shared/management/proto"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/management/proto"
)
// IntegratedValidator interface exists to avoid the circle dependencies
type IntegratedValidator interface {
ValidateExtraSettings(ctx context.Context, newExtraSettings *types.ExtraSettings, oldExtraSettings *types.ExtraSettings, peers map[string]*nbpeer.Peer, userID string, accountID string) error
ValidatePeer(ctx context.Context, update *nbpeer.Peer, peer *nbpeer.Peer, userID string, accountID string, dnsDomain string, peersGroup []string, extraSettings *types.ExtraSettings) (*nbpeer.Peer, bool, error)
PreparePeer(ctx context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *types.ExtraSettings, temporary bool) *nbpeer.Peer
PreparePeer(ctx context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *types.ExtraSettings) *nbpeer.Peer
IsNotValidPeer(ctx context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *types.ExtraSettings) (bool, bool, error)
GetValidatedPeers(ctx context.Context, accountID string, groups []*types.Group, peers []*nbpeer.Peer, extraSettings *types.ExtraSettings) (map[string]struct{}, error)
PeerDeleted(ctx context.Context, accountID, peerID string, extraSettings *types.ExtraSettings) error

View File

@@ -350,6 +350,7 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer
}
var peer *nbpeer.Peer
var updateAccountPeers bool
var eventsToStore []func()
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
@@ -362,6 +363,11 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer
return err
}
updateAccountPeers, err = isPeerInActiveGroup(ctx, transaction, accountID, peerID)
if err != nil {
return err
}
eventsToStore, err = deletePeers(ctx, am, transaction, accountID, userID, []*nbpeer.Peer{peer})
if err != nil {
return fmt.Errorf("failed to delete peer: %w", err)
@@ -381,7 +387,7 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer
storeEvent()
}
if userID != activity.SystemInitiator {
if updateAccountPeers && userID != activity.SystemInitiator {
am.BufferUpdateAccountPeers(ctx, accountID)
}
@@ -578,7 +584,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, accountID, setupKe
}
}
newPeer = am.integratedPeerValidator.PreparePeer(ctx, accountID, newPeer, groupsToAdd, settings.Extra, temporary)
newPeer = am.integratedPeerValidator.PreparePeer(ctx, accountID, newPeer, groupsToAdd, settings.Extra)
network, err := am.Store.GetAccountNetwork(ctx, store.LockingStrengthNone, accountID)
if err != nil {
@@ -678,6 +684,11 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, accountID, setupKe
return nil, nil, nil, fmt.Errorf("failed to add peer to database after %d attempts: %w", maxAttempts, err)
}
updateAccountPeers, err := isPeerInActiveGroup(ctx, am.Store, accountID, newPeer.ID)
if err != nil {
updateAccountPeers = true
}
if newPeer == nil {
return nil, nil, nil, fmt.Errorf("new peer is nil")
}
@@ -690,7 +701,9 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, accountID, setupKe
am.StoreEvent(ctx, opEvent.InitiatorID, opEvent.TargetID, opEvent.AccountID, opEvent.Activity, opEvent.Meta)
am.BufferUpdateAccountPeers(ctx, accountID)
if updateAccountPeers {
am.BufferUpdateAccountPeers(ctx, accountID)
}
return am.getValidatedPeerWithMap(ctx, false, accountID, newPeer)
}
@@ -1514,6 +1527,16 @@ func getPeerGroupIDs(ctx context.Context, transaction store.Store, accountID str
return transaction.GetPeerGroupIDs(ctx, store.LockingStrengthNone, accountID, peerID)
}
// IsPeerInActiveGroup checks if the given peer is part of a group that is used
// in an active DNS, route, or ACL configuration.
func isPeerInActiveGroup(ctx context.Context, transaction store.Store, accountID, peerID string) (bool, error) {
peerGroupIDs, err := getPeerGroupIDs(ctx, transaction, accountID, peerID)
if err != nil {
return false, err
}
return areGroupChangesAffectPeers(ctx, transaction, accountID, peerGroupIDs) // TODO: use transaction
}
// deletePeers deletes all specified peers and sends updates to the remote peers.
// Returns a slice of functions to save events after successful peer deletion.
func deletePeers(ctx context.Context, am *DefaultAccountManager, transaction store.Store, accountID, userID string, peers []*nbpeer.Peer) ([]func(), error) {

View File

@@ -1790,7 +1790,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
t.Run("adding peer to unlinked group", func(t *testing.T) {
done := make(chan struct{})
go func() {
peerShouldReceiveUpdate(t, updMsg) //
peerShouldNotReceiveUpdate(t, updMsg) //
close(done)
}()
@@ -1815,7 +1815,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
t.Run("deleting peer with unlinked group", func(t *testing.T) {
done := make(chan struct{})
go func() {
peerShouldReceiveUpdate(t, updMsg)
peerShouldNotReceiveUpdate(t, updMsg)
close(done)
}()

View File

@@ -1,129 +0,0 @@
package cache
import (
"context"
"sync"
)
// DualKeyCache provides a caching mechanism where each entry has two keys:
// - Primary key (e.g., objectID): used for accessing and invalidating specific entries
// - Secondary key (e.g., accountID): used for bulk invalidation of all entries with the same secondary key
type DualKeyCache[K1 comparable, K2 comparable, V any] struct {
mu sync.RWMutex
primaryIndex map[K1]V // Primary key -> Value
secondaryIndex map[K2]map[K1]struct{} // Secondary key -> Set of primary keys
reverseLookup map[K1]K2 // Primary key -> Secondary key
}
// NewDualKeyCache creates a new dual-key cache
func NewDualKeyCache[K1 comparable, K2 comparable, V any]() *DualKeyCache[K1, K2, V] {
return &DualKeyCache[K1, K2, V]{
primaryIndex: make(map[K1]V),
secondaryIndex: make(map[K2]map[K1]struct{}),
reverseLookup: make(map[K1]K2),
}
}
// Get retrieves a value from the cache using the primary key
func (c *DualKeyCache[K1, K2, V]) Get(ctx context.Context, primaryKey K1) (V, bool) {
c.mu.RLock()
defer c.mu.RUnlock()
value, ok := c.primaryIndex[primaryKey]
return value, ok
}
// Set stores a value in the cache with both primary and secondary keys
func (c *DualKeyCache[K1, K2, V]) Set(ctx context.Context, primaryKey K1, secondaryKey K2, value V) {
c.mu.Lock()
defer c.mu.Unlock()
if oldSecondaryKey, exists := c.reverseLookup[primaryKey]; exists {
if primaryKeys, ok := c.secondaryIndex[oldSecondaryKey]; ok {
delete(primaryKeys, primaryKey)
if len(primaryKeys) == 0 {
delete(c.secondaryIndex, oldSecondaryKey)
}
}
}
c.primaryIndex[primaryKey] = value
c.reverseLookup[primaryKey] = secondaryKey
if _, exists := c.secondaryIndex[secondaryKey]; !exists {
c.secondaryIndex[secondaryKey] = make(map[K1]struct{})
}
c.secondaryIndex[secondaryKey][primaryKey] = struct{}{}
}
// InvalidateByPrimaryKey removes an entry using the primary key
func (c *DualKeyCache[K1, K2, V]) InvalidateByPrimaryKey(ctx context.Context, primaryKey K1) {
c.mu.Lock()
defer c.mu.Unlock()
if secondaryKey, exists := c.reverseLookup[primaryKey]; exists {
if primaryKeys, ok := c.secondaryIndex[secondaryKey]; ok {
delete(primaryKeys, primaryKey)
if len(primaryKeys) == 0 {
delete(c.secondaryIndex, secondaryKey)
}
}
delete(c.reverseLookup, primaryKey)
}
delete(c.primaryIndex, primaryKey)
}
// InvalidateBySecondaryKey removes all entries with the given secondary key
func (c *DualKeyCache[K1, K2, V]) InvalidateBySecondaryKey(ctx context.Context, secondaryKey K2) {
c.mu.Lock()
defer c.mu.Unlock()
primaryKeys, exists := c.secondaryIndex[secondaryKey]
if !exists {
return
}
for primaryKey := range primaryKeys {
delete(c.primaryIndex, primaryKey)
delete(c.reverseLookup, primaryKey)
}
delete(c.secondaryIndex, secondaryKey)
}
// InvalidateAll removes all entries from the cache
func (c *DualKeyCache[K1, K2, V]) InvalidateAll(ctx context.Context) {
c.mu.Lock()
defer c.mu.Unlock()
c.primaryIndex = make(map[K1]V)
c.secondaryIndex = make(map[K2]map[K1]struct{})
c.reverseLookup = make(map[K1]K2)
}
// Size returns the number of entries in the cache
func (c *DualKeyCache[K1, K2, V]) Size() int {
c.mu.RLock()
defer c.mu.RUnlock()
return len(c.primaryIndex)
}
// GetOrSet retrieves a value from the cache, or sets it using the provided function if not found
// The loadFunc should return both the value and the secondary key (extracted from the value)
func (c *DualKeyCache[K1, K2, V]) GetOrSet(ctx context.Context, primaryKey K1, loadFunc func() (V, K2, error)) (V, error) {
if value, ok := c.Get(ctx, primaryKey); ok {
return value, nil
}
value, secondaryKey, err := loadFunc()
if err != nil {
var zero V
return zero, err
}
c.Set(ctx, primaryKey, secondaryKey, value)
return value, nil
}

View File

@@ -0,0 +1,48 @@
package cache
import (
"context"
"sync"
"github.com/go-gorm/caches/v4"
)
type MemoryCacher struct {
store *sync.Map
}
func (c *MemoryCacher) init() {
if c.store == nil {
c.store = &sync.Map{}
}
}
func (c *MemoryCacher) Get(ctx context.Context, key string, q *caches.Query[any]) (*caches.Query[any], error) {
c.init()
val, ok := c.store.Load(key)
if !ok {
return nil, nil
}
if err := q.Unmarshal(val.([]byte)); err != nil {
return nil, err
}
return q, nil
}
func (c *MemoryCacher) Store(ctx context.Context, key string, val *caches.Query[any]) error {
c.init()
res, err := val.Marshal()
if err != nil {
return err
}
c.store.Store(key, res)
return nil
}
func (c *MemoryCacher) Invalidate(ctx context.Context) error {
c.store = &sync.Map{}
return nil
}

View File

@@ -0,0 +1,73 @@
package cache
import (
"context"
"fmt"
"time"
"github.com/go-gorm/caches/v4"
"github.com/redis/go-redis/v9"
)
type RedisCacher struct {
rdb *redis.Client
}
func NewRedisCacher(rdb *redis.Client) *RedisCacher {
return &RedisCacher{rdb: rdb}
}
func (c *RedisCacher) Get(ctx context.Context, key string, q *caches.Query[any]) (*caches.Query[any], error) {
res, err := c.rdb.Get(ctx, key).Result()
if err == redis.Nil {
return nil, nil
}
if err != nil {
return nil, err
}
if err := q.Unmarshal([]byte(res)); err != nil {
return nil, err
}
return q, nil
}
func (c *RedisCacher) Store(ctx context.Context, key string, val *caches.Query[any]) error {
res, err := val.Marshal()
if err != nil {
return err
}
c.rdb.Set(ctx, key, res, 300*time.Second) // Set proper cache time
return nil
}
func (c *RedisCacher) Invalidate(ctx context.Context) error {
var (
cursor uint64
keys []string
)
for {
var (
k []string
err error
)
k, cursor, err = c.rdb.Scan(ctx, cursor, fmt.Sprintf("%s*", caches.IdentifierPrefix), 0).Result()
if err != nil {
return err
}
keys = append(keys, k...)
if cursor == 0 {
break
}
}
if len(keys) > 0 {
if _, err := c.rdb.Del(ctx, keys...).Result(); err != nil {
return err
}
}
return nil
}

View File

@@ -1,77 +0,0 @@
package cache
import (
"context"
"sync"
)
// SingleKeyCache provides a simple caching mechanism with a single key
type SingleKeyCache[K comparable, V any] struct {
mu sync.RWMutex
cache map[K]V // Key -> Value
}
// NewSingleKeyCache creates a new single-key cache
func NewSingleKeyCache[K comparable, V any]() *SingleKeyCache[K, V] {
return &SingleKeyCache[K, V]{
cache: make(map[K]V),
}
}
// Get retrieves a value from the cache using the key
func (c *SingleKeyCache[K, V]) Get(ctx context.Context, key K) (V, bool) {
c.mu.RLock()
defer c.mu.RUnlock()
value, ok := c.cache[key]
return value, ok
}
// Set stores a value in the cache with the given key
func (c *SingleKeyCache[K, V]) Set(ctx context.Context, key K, value V) {
c.mu.Lock()
defer c.mu.Unlock()
c.cache[key] = value
}
// Invalidate removes an entry using the key
func (c *SingleKeyCache[K, V]) Invalidate(ctx context.Context, key K) {
c.mu.Lock()
defer c.mu.Unlock()
delete(c.cache, key)
}
// InvalidateAll removes all entries from the cache
func (c *SingleKeyCache[K, V]) InvalidateAll(ctx context.Context) {
c.mu.Lock()
defer c.mu.Unlock()
c.cache = make(map[K]V)
}
// Size returns the number of entries in the cache
func (c *SingleKeyCache[K, V]) Size() int {
c.mu.RLock()
defer c.mu.RUnlock()
return len(c.cache)
}
// GetOrSet retrieves a value from the cache, or sets it using the provided function if not found
func (c *SingleKeyCache[K, V]) GetOrSet(ctx context.Context, key K, loadFunc func() (V, error)) (V, error) {
if value, ok := c.Get(ctx, key); ok {
return value, nil
}
value, err := loadFunc()
if err != nil {
var zero V
return zero, err
}
c.Set(ctx, key, value)
return value, nil
}

View File

@@ -1,242 +0,0 @@
package cache
import (
"context"
"sync"
)
// TripleKeyCache provides a caching mechanism where each entry has three keys:
// - Primary key (K1): used for accessing and invalidating specific entries
// - Secondary key (K2): used for bulk invalidation of all entries with the same secondary key
// - Tertiary key (K3): used for bulk invalidation of all entries with the same tertiary key
type TripleKeyCache[K1 comparable, K2 comparable, K3 comparable, V any] struct {
mu sync.RWMutex
primaryIndex map[K1]V // Primary key -> Value
secondaryIndex map[K2]map[K1]struct{} // Secondary key -> Set of primary keys
tertiaryIndex map[K3]map[K1]struct{} // Tertiary key -> Set of primary keys
reverseLookup map[K1]keyPair[K2, K3] // Primary key -> Secondary and Tertiary keys
}
type keyPair[K2 comparable, K3 comparable] struct {
secondary K2
tertiary K3
}
// NewTripleKeyCache creates a new triple-key cache
func NewTripleKeyCache[K1 comparable, K2 comparable, K3 comparable, V any]() *TripleKeyCache[K1, K2, K3, V] {
return &TripleKeyCache[K1, K2, K3, V]{
primaryIndex: make(map[K1]V),
secondaryIndex: make(map[K2]map[K1]struct{}),
tertiaryIndex: make(map[K3]map[K1]struct{}),
reverseLookup: make(map[K1]keyPair[K2, K3]),
}
}
// Get retrieves a value from the cache using the primary key
func (c *TripleKeyCache[K1, K2, K3, V]) Get(ctx context.Context, primaryKey K1) (V, bool) {
c.mu.RLock()
defer c.mu.RUnlock()
value, ok := c.primaryIndex[primaryKey]
return value, ok
}
// Set stores a value in the cache with primary, secondary, and tertiary keys
func (c *TripleKeyCache[K1, K2, K3, V]) Set(ctx context.Context, primaryKey K1, secondaryKey K2, tertiaryKey K3, value V) {
c.mu.Lock()
defer c.mu.Unlock()
if oldKeys, exists := c.reverseLookup[primaryKey]; exists {
if primaryKeys, ok := c.secondaryIndex[oldKeys.secondary]; ok {
delete(primaryKeys, primaryKey)
if len(primaryKeys) == 0 {
delete(c.secondaryIndex, oldKeys.secondary)
}
}
if primaryKeys, ok := c.tertiaryIndex[oldKeys.tertiary]; ok {
delete(primaryKeys, primaryKey)
if len(primaryKeys) == 0 {
delete(c.tertiaryIndex, oldKeys.tertiary)
}
}
}
c.primaryIndex[primaryKey] = value
c.reverseLookup[primaryKey] = keyPair[K2, K3]{
secondary: secondaryKey,
tertiary: tertiaryKey,
}
if _, exists := c.secondaryIndex[secondaryKey]; !exists {
c.secondaryIndex[secondaryKey] = make(map[K1]struct{})
}
c.secondaryIndex[secondaryKey][primaryKey] = struct{}{}
if _, exists := c.tertiaryIndex[tertiaryKey]; !exists {
c.tertiaryIndex[tertiaryKey] = make(map[K1]struct{})
}
c.tertiaryIndex[tertiaryKey][primaryKey] = struct{}{}
}
// InvalidateByPrimaryKey removes an entry using the primary key
func (c *TripleKeyCache[K1, K2, K3, V]) InvalidateByPrimaryKey(ctx context.Context, primaryKey K1) {
c.mu.Lock()
defer c.mu.Unlock()
if keys, exists := c.reverseLookup[primaryKey]; exists {
if primaryKeys, ok := c.secondaryIndex[keys.secondary]; ok {
delete(primaryKeys, primaryKey)
if len(primaryKeys) == 0 {
delete(c.secondaryIndex, keys.secondary)
}
}
if primaryKeys, ok := c.tertiaryIndex[keys.tertiary]; ok {
delete(primaryKeys, primaryKey)
if len(primaryKeys) == 0 {
delete(c.tertiaryIndex, keys.tertiary)
}
}
delete(c.reverseLookup, primaryKey)
}
delete(c.primaryIndex, primaryKey)
}
// InvalidateBySecondaryKey removes all entries with the given secondary key
func (c *TripleKeyCache[K1, K2, K3, V]) InvalidateBySecondaryKey(ctx context.Context, secondaryKey K2) {
c.mu.Lock()
defer c.mu.Unlock()
primaryKeys, exists := c.secondaryIndex[secondaryKey]
if !exists {
return
}
for primaryKey := range primaryKeys {
if keys, ok := c.reverseLookup[primaryKey]; ok {
if tertiaryPrimaryKeys, exists := c.tertiaryIndex[keys.tertiary]; exists {
delete(tertiaryPrimaryKeys, primaryKey)
if len(tertiaryPrimaryKeys) == 0 {
delete(c.tertiaryIndex, keys.tertiary)
}
}
}
delete(c.primaryIndex, primaryKey)
delete(c.reverseLookup, primaryKey)
}
delete(c.secondaryIndex, secondaryKey)
}
// InvalidateByTertiaryKey removes all entries with the given tertiary key
func (c *TripleKeyCache[K1, K2, K3, V]) InvalidateByTertiaryKey(ctx context.Context, tertiaryKey K3) {
c.mu.Lock()
defer c.mu.Unlock()
primaryKeys, exists := c.tertiaryIndex[tertiaryKey]
if !exists {
return
}
for primaryKey := range primaryKeys {
if keys, ok := c.reverseLookup[primaryKey]; ok {
if secondaryPrimaryKeys, exists := c.secondaryIndex[keys.secondary]; exists {
delete(secondaryPrimaryKeys, primaryKey)
if len(secondaryPrimaryKeys) == 0 {
delete(c.secondaryIndex, keys.secondary)
}
}
}
delete(c.primaryIndex, primaryKey)
delete(c.reverseLookup, primaryKey)
}
delete(c.tertiaryIndex, tertiaryKey)
}
// InvalidateAll removes all entries from the cache
func (c *TripleKeyCache[K1, K2, K3, V]) InvalidateAll(ctx context.Context) {
c.mu.Lock()
defer c.mu.Unlock()
c.primaryIndex = make(map[K1]V)
c.secondaryIndex = make(map[K2]map[K1]struct{})
c.tertiaryIndex = make(map[K3]map[K1]struct{})
c.reverseLookup = make(map[K1]keyPair[K2, K3])
}
// Size returns the number of entries in the cache
func (c *TripleKeyCache[K1, K2, K3, V]) Size() int {
c.mu.RLock()
defer c.mu.RUnlock()
return len(c.primaryIndex)
}
// GetOrSet retrieves a value from the cache, or sets it using the provided function if not found
// The loadFunc should return the value, secondary key, and tertiary key (extracted from the value)
func (c *TripleKeyCache[K1, K2, K3, V]) GetOrSet(ctx context.Context, primaryKey K1, loadFunc func() (V, K2, K3, error)) (V, error) {
if value, ok := c.Get(ctx, primaryKey); ok {
return value, nil
}
value, secondaryKey, tertiaryKey, err := loadFunc()
if err != nil {
var zero V
return zero, err
}
c.Set(ctx, primaryKey, secondaryKey, tertiaryKey, value)
return value, nil
}
// GetOrSetBySecondaryKey retrieves a value from the cache using the secondary key, or sets it using the provided function if not found
// The loadFunc should return the value, primary key, secondary key, and tertiary key
func (c *TripleKeyCache[K1, K2, K3, V]) GetOrSetBySecondaryKey(ctx context.Context, secondaryKey K2, loadFunc func() (V, K1, K3, error)) (V, error) {
c.mu.RLock()
if primaryKeys, exists := c.secondaryIndex[secondaryKey]; exists && len(primaryKeys) > 0 {
for primaryKey := range primaryKeys {
if value, ok := c.primaryIndex[primaryKey]; ok {
c.mu.RUnlock()
return value, nil
}
}
}
c.mu.RUnlock()
value, primaryKey, tertiaryKey, err := loadFunc()
if err != nil {
var zero V
return zero, err
}
c.Set(ctx, primaryKey, secondaryKey, tertiaryKey, value)
return value, nil
}
// GetOrSetByTertiaryKey retrieves a value from the cache using the tertiary key, or sets it using the provided function if not found
// The loadFunc should return the value, primary key, secondary key, and tertiary key
func (c *TripleKeyCache[K1, K2, K3, V]) GetOrSetByTertiaryKey(ctx context.Context, tertiaryKey K3, loadFunc func() (V, K1, K2, error)) (V, error) {
c.mu.RLock()
if primaryKeys, exists := c.tertiaryIndex[tertiaryKey]; exists && len(primaryKeys) > 0 {
for primaryKey := range primaryKeys {
if value, ok := c.primaryIndex[primaryKey]; ok {
c.mu.RUnlock()
return value, nil
}
}
}
c.mu.RUnlock()
value, primaryKey, secondaryKey, err := loadFunc()
if err != nil {
var zero V
return zero, err
}
c.Set(ctx, primaryKey, secondaryKey, tertiaryKey, value)
return value, nil
}

File diff suppressed because it is too large Load Diff

View File

@@ -4,7 +4,6 @@ import (
"context"
"time"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/metric"
)
@@ -15,8 +14,6 @@ type StoreMetrics struct {
persistenceDurationMicro metric.Int64Histogram
persistenceDurationMs metric.Int64Histogram
transactionDurationMs metric.Int64Histogram
queryDurationMs metric.Int64Histogram
queryCounter metric.Int64Counter
ctx context.Context
}
@@ -62,29 +59,12 @@ func NewStoreMetrics(ctx context.Context, meter metric.Meter) (*StoreMetrics, er
return nil, err
}
queryDurationMs, err := meter.Int64Histogram("management.store.query.duration.ms",
metric.WithUnit("milliseconds"),
metric.WithDescription("Duration of database query operations with operation type and table name"),
)
if err != nil {
return nil, err
}
queryCounter, err := meter.Int64Counter("management.store.query.count",
metric.WithDescription("Count of database query operations with operation type, table name, and status"),
)
if err != nil {
return nil, err
}
return &StoreMetrics{
globalLockAcquisitionDurationMicro: globalLockAcquisitionDurationMicro,
globalLockAcquisitionDurationMs: globalLockAcquisitionDurationMs,
persistenceDurationMicro: persistenceDurationMicro,
persistenceDurationMs: persistenceDurationMs,
transactionDurationMs: transactionDurationMs,
queryDurationMs: queryDurationMs,
queryCounter: queryCounter,
ctx: ctx,
}, nil
}
@@ -105,13 +85,3 @@ func (metrics *StoreMetrics) CountPersistenceDuration(duration time.Duration) {
func (metrics *StoreMetrics) CountTransactionDuration(duration time.Duration) {
metrics.transactionDurationMs.Record(metrics.ctx, duration.Milliseconds())
}
// CountStoreOperation records a store operation with its method name, status, and duration
func (metrics *StoreMetrics) CountStoreOperation(method string, duration time.Duration) {
attrs := []attribute.KeyValue{
attribute.String("method", method),
}
metrics.queryDurationMs.Record(metrics.ctx, duration.Milliseconds(), metric.WithAttributes(attrs...))
metrics.queryCounter.Add(metrics.ctx, 1, metric.WithAttributes(attrs...))
}

View File

@@ -10,6 +10,7 @@ import (
"net/http"
// nolint:gosec
_ "net/http/pprof"
"net/netip"
"time"
"go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc"
@@ -62,10 +63,10 @@ var (
Use: "run",
Short: "start NetBird Signal Server daemon",
SilenceUsage: true,
PreRunE: func(cmd *cobra.Command, args []string) error {
PreRun: func(cmd *cobra.Command, args []string) {
err := util.InitLog(logLevel, logFile)
if err != nil {
return fmt.Errorf("failed initializing log: %w", err)
log.Fatalf("failed initializing log %v", err)
}
flag.Parse()
@@ -86,8 +87,6 @@ var (
signalPort = 80
}
}
return nil
},
RunE: func(cmd *cobra.Command, args []string) error {
flag.Parse()
@@ -255,7 +254,7 @@ func startServerWithCertManager(certManager *autocert.Manager, grpcRootHandler h
}
func grpcHandlerFunc(grpcServer *grpc.Server, meter metric.Meter) http.Handler {
wsProxy := wsproxyserver.New(grpcServer, wsproxyserver.WithOTelMeter(meter))
wsProxy := wsproxyserver.New(netip.AddrPortFrom(netip.AddrFrom4([4]byte{127, 0, 0, 1}), legacyGRPCPort), wsproxyserver.WithOTelMeter(meter))
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch {

View File

@@ -2,41 +2,42 @@ package server
import (
"context"
"errors"
"io"
"net"
"net/http"
"net/netip"
"sync"
"time"
"github.com/coder/websocket"
log "github.com/sirupsen/logrus"
"golang.org/x/net/http2"
"github.com/netbirdio/netbird/util/wsproxy"
)
const (
bufferSize = 32 * 1024
ioTimeout = 5 * time.Second
dialTimeout = 10 * time.Second
bufferSize = 32 * 1024
)
// Config contains the configuration for the WebSocket proxy.
type Config struct {
Handler http.Handler
LocalGRPCAddr netip.AddrPort
Path string
MetricsRecorder MetricsRecorder
}
// Proxy handles WebSocket to gRPC handler proxying.
// Proxy handles WebSocket to TCP proxying for gRPC connections.
type Proxy struct {
config Config
metrics MetricsRecorder
}
// New creates a new WebSocket proxy instance with optional configuration
func New(handler http.Handler, opts ...Option) *Proxy {
func New(localGRPCAddr netip.AddrPort, opts ...Option) *Proxy {
config := Config{
Handler: handler,
LocalGRPCAddr: localGRPCAddr,
Path: wsproxy.ProxyPath,
MetricsRecorder: NoOpMetricsRecorder{}, // Default to no-op
}
@@ -62,7 +63,7 @@ func (p *Proxy) handleWebSocket(w http.ResponseWriter, r *http.Request) {
p.metrics.RecordConnection(ctx)
defer p.metrics.RecordDisconnection(ctx)
log.Debugf("WebSocket proxy handling connection from %s, forwarding to internal gRPC handler", r.RemoteAddr)
log.Debugf("WebSocket proxy handling connection from %s, forwarding to %s", r.RemoteAddr, p.config.LocalGRPCAddr)
acceptOptions := &websocket.AcceptOptions{
OriginPatterns: []string{"*"},
}
@@ -74,41 +75,71 @@ func (p *Proxy) handleWebSocket(w http.ResponseWriter, r *http.Request) {
return
}
defer func() {
_ = wsConn.Close(websocket.StatusNormalClosure, "")
if err := wsConn.Close(websocket.StatusNormalClosure, ""); err != nil {
log.Debugf("Failed to close WebSocket: %v", err)
}
}()
clientConn, serverConn := net.Pipe()
log.Debugf("WebSocket proxy attempting to connect to local gRPC at %s", p.config.LocalGRPCAddr)
tcpConn, err := net.DialTimeout("tcp", p.config.LocalGRPCAddr.String(), dialTimeout)
if err != nil {
p.metrics.RecordError(ctx, "tcp_dial_failed")
log.Warnf("Failed to connect to local gRPC server at %s: %v", p.config.LocalGRPCAddr, err)
if err := wsConn.Close(websocket.StatusInternalError, "Backend unavailable"); err != nil {
log.Debugf("Failed to close WebSocket after connection failure: %v", err)
}
return
}
defer func() {
_ = clientConn.Close()
_ = serverConn.Close()
if err := tcpConn.Close(); err != nil {
log.Debugf("Failed to close TCP connection: %v", err)
}
}()
log.Debugf("WebSocket proxy established: %s -> gRPC handler", r.RemoteAddr)
log.Debugf("WebSocket proxy established: client %s -> local gRPC %s", r.RemoteAddr, p.config.LocalGRPCAddr)
go func() {
(&http2.Server{}).ServeConn(serverConn, &http2.ServeConnOpts{
Context: ctx,
Handler: p.config.Handler,
})
}()
p.proxyData(ctx, wsConn, clientConn, r.RemoteAddr)
p.proxyData(ctx, wsConn, tcpConn)
}
func (p *Proxy) proxyData(ctx context.Context, wsConn *websocket.Conn, pipeConn net.Conn, clientAddr string) {
func (p *Proxy) proxyData(ctx context.Context, wsConn *websocket.Conn, tcpConn net.Conn) {
proxyCtx, cancel := context.WithCancel(ctx)
defer cancel()
var wg sync.WaitGroup
wg.Add(2)
go p.wsToPipe(proxyCtx, cancel, &wg, wsConn, pipeConn, clientAddr)
go p.pipeToWS(proxyCtx, cancel, &wg, wsConn, pipeConn, clientAddr)
go p.wsToTCP(proxyCtx, cancel, &wg, wsConn, tcpConn)
go p.tcpToWS(proxyCtx, cancel, &wg, wsConn, tcpConn)
wg.Wait()
done := make(chan struct{})
go func() {
wg.Wait()
close(done)
}()
select {
case <-done:
log.Tracef("Proxy data transfer completed, both goroutines terminated")
case <-proxyCtx.Done():
log.Tracef("Proxy data transfer cancelled, forcing connection closure")
if err := wsConn.Close(websocket.StatusGoingAway, "proxy cancelled"); err != nil {
log.Tracef("Error closing WebSocket during cancellation: %v", err)
}
if err := tcpConn.Close(); err != nil {
log.Tracef("Error closing TCP connection during cancellation: %v", err)
}
select {
case <-done:
log.Tracef("Goroutines terminated after forced connection closure")
case <-time.After(2 * time.Second):
log.Tracef("Goroutines did not terminate within timeout after connection closure")
}
}
}
func (p *Proxy) wsToPipe(ctx context.Context, cancel context.CancelFunc, wg *sync.WaitGroup, wsConn *websocket.Conn, pipeConn net.Conn, clientAddr string) {
func (p *Proxy) wsToTCP(ctx context.Context, cancel context.CancelFunc, wg *sync.WaitGroup, wsConn *websocket.Conn, tcpConn net.Conn) {
defer wg.Done()
defer cancel()
@@ -117,73 +148,80 @@ func (p *Proxy) wsToPipe(ctx context.Context, cancel context.CancelFunc, wg *syn
if err != nil {
switch {
case ctx.Err() != nil:
log.Debugf("WebSocket from %s terminating due to context cancellation", clientAddr)
case websocket.CloseStatus(err) != -1:
log.Debugf("WebSocket from %s disconnected", clientAddr)
log.Debugf("wsToTCP goroutine terminating due to context cancellation")
case websocket.CloseStatus(err) == websocket.StatusNormalClosure:
log.Debugf("WebSocket closed normally")
default:
p.metrics.RecordError(ctx, "websocket_read_error")
log.Debugf("WebSocket read error from %s: %v", clientAddr, err)
log.Errorf("WebSocket read error: %v", err)
}
return
}
if msgType != websocket.MessageBinary {
log.Warnf("Unexpected WebSocket message type from %s: %v", clientAddr, msgType)
log.Warnf("Unexpected WebSocket message type: %v", msgType)
continue
}
if ctx.Err() != nil {
log.Tracef("wsToPipe goroutine terminating due to context cancellation before pipe write")
log.Tracef("wsToTCP goroutine terminating due to context cancellation before TCP write")
return
}
if err := pipeConn.SetWriteDeadline(time.Now().Add(ioTimeout)); err != nil {
log.Debugf("Failed to set pipe write deadline: %v", err)
if err := tcpConn.SetWriteDeadline(time.Now().Add(5 * time.Second)); err != nil {
log.Debugf("Failed to set TCP write deadline: %v", err)
}
n, err := pipeConn.Write(data)
n, err := tcpConn.Write(data)
if err != nil {
p.metrics.RecordError(ctx, "pipe_write_error")
log.Warnf("Pipe write error for %s: %v", clientAddr, err)
p.metrics.RecordError(ctx, "tcp_write_error")
log.Errorf("TCP write error: %v", err)
return
}
p.metrics.RecordBytesTransferred(ctx, "ws_to_grpc", int64(n))
p.metrics.RecordBytesTransferred(ctx, "ws_to_tcp", int64(n))
}
}
func (p *Proxy) pipeToWS(ctx context.Context, cancel context.CancelFunc, wg *sync.WaitGroup, wsConn *websocket.Conn, pipeConn net.Conn, clientAddr string) {
func (p *Proxy) tcpToWS(ctx context.Context, cancel context.CancelFunc, wg *sync.WaitGroup, wsConn *websocket.Conn, tcpConn net.Conn) {
defer wg.Done()
defer cancel()
buf := make([]byte, bufferSize)
for {
n, err := pipeConn.Read(buf)
if err := tcpConn.SetReadDeadline(time.Now().Add(5 * time.Second)); err != nil {
log.Debugf("Failed to set TCP read deadline: %v", err)
}
n, err := tcpConn.Read(buf)
if err != nil {
if ctx.Err() != nil {
log.Tracef("pipeToWS goroutine terminating due to context cancellation")
log.Tracef("tcpToWS goroutine terminating due to context cancellation")
return
}
var netErr net.Error
if errors.As(err, &netErr) && netErr.Timeout() {
continue
}
if err != io.EOF {
log.Debugf("Pipe read error for %s: %v", clientAddr, err)
log.Errorf("TCP read error: %v", err)
}
return
}
if ctx.Err() != nil {
log.Tracef("pipeToWS goroutine terminating due to context cancellation before WebSocket write")
log.Tracef("tcpToWS goroutine terminating due to context cancellation before WebSocket write")
return
}
if n > 0 {
if err := wsConn.Write(ctx, websocket.MessageBinary, buf[:n]); err != nil {
p.metrics.RecordError(ctx, "websocket_write_error")
log.Warnf("WebSocket write error for %s: %v", clientAddr, err)
return
}
p.metrics.RecordBytesTransferred(ctx, "grpc_to_ws", int64(n))
if err := wsConn.Write(ctx, websocket.MessageBinary, buf[:n]); err != nil {
p.metrics.RecordError(ctx, "websocket_write_error")
log.Errorf("WebSocket write error: %v", err)
return
}
p.metrics.RecordBytesTransferred(ctx, "tcp_to_ws", int64(n))
}
}

View File

@@ -1,13 +1,9 @@
package version
import (
"golang.org/x/sys/windows/registry"
"runtime"
)
import "golang.org/x/sys/windows/registry"
const (
urlWinExe = "https://pkgs.netbird.io/windows/x64"
urlWinExeArm = "https://pkgs.netbird.io/windows/arm64"
)
var regKeyAppPath = "SOFTWARE\\Microsoft\\Windows\\CurrentVersion\\App Paths\\Netbird"
@@ -15,14 +11,9 @@ var regKeyAppPath = "SOFTWARE\\Microsoft\\Windows\\CurrentVersion\\App Paths\\Ne
// DownloadUrl return with the proper download link
func DownloadUrl() string {
_, err := registry.OpenKey(registry.LOCAL_MACHINE, regKeyAppPath, registry.QUERY_VALUE)
if err != nil {
if err == nil {
return urlWinExe
} else {
return downloadURL
}
url := urlWinExe
if runtime.GOARCH == "arm64" {
url = urlWinExeArm
}
return url
}