Compare commits

..

1 Commits

31 changed files with 183 additions and 711 deletions

View File

@@ -29,8 +29,7 @@ func Backoff(ctx context.Context) backoff.BackOff {
// The component parameter specifies the WebSocket proxy component path (e.g., "/management", "/signal"). // The component parameter specifies the WebSocket proxy component path (e.g., "/management", "/signal").
func CreateConnection(ctx context.Context, addr string, tlsEnabled bool, component string) (*grpc.ClientConn, error) { func CreateConnection(ctx context.Context, addr string, tlsEnabled bool, component string) (*grpc.ClientConn, error) {
transportOption := grpc.WithTransportCredentials(insecure.NewCredentials()) transportOption := grpc.WithTransportCredentials(insecure.NewCredentials())
// for js, the outer websocket layer takes care of tls if tlsEnabled {
if tlsEnabled && runtime.GOOS != "js" {
certPool, err := x509.SystemCertPool() certPool, err := x509.SystemCertPool()
if err != nil || certPool == nil { if err != nil || certPool == nil {
log.Debugf("System cert pool not available; falling back to embedded cert, error: %v", err) log.Debugf("System cert pool not available; falling back to embedded cert, error: %v", err)
@@ -38,7 +37,9 @@ func CreateConnection(ctx context.Context, addr string, tlsEnabled bool, compone
} }
transportOption = grpc.WithTransportCredentials(credentials.NewTLS(&tls.Config{ transportOption = grpc.WithTransportCredentials(credentials.NewTLS(&tls.Config{
RootCAs: certPool, // for js, outer websocket layer takes care of tls verification via WithCustomDialer
InsecureSkipVerify: runtime.GOOS == "js",
RootCAs: certPool,
})) }))
} }

View File

@@ -73,44 +73,6 @@ func (c *KernelConfigurer) UpdatePeer(peerKey string, allowedIps []netip.Prefix,
return nil return nil
} }
func (c *KernelConfigurer) RemoveEndpointAddress(peerKey string) error {
peerKeyParsed, err := wgtypes.ParseKey(peerKey)
if err != nil {
return err
}
// Get the existing peer to preserve its allowed IPs
existingPeer, err := c.getPeer(c.deviceName, peerKey)
if err != nil {
return fmt.Errorf("get peer: %w", err)
}
removePeerCfg := wgtypes.PeerConfig{
PublicKey: peerKeyParsed,
Remove: true,
}
if err := c.configure(wgtypes.Config{Peers: []wgtypes.PeerConfig{removePeerCfg}}); err != nil {
return fmt.Errorf(`error removing peer %s from interface %s: %w`, peerKey, c.deviceName, err)
}
//Re-add the peer without the endpoint but same AllowedIPs
reAddPeerCfg := wgtypes.PeerConfig{
PublicKey: peerKeyParsed,
AllowedIPs: existingPeer.AllowedIPs,
ReplaceAllowedIPs: true,
}
if err := c.configure(wgtypes.Config{Peers: []wgtypes.PeerConfig{reAddPeerCfg}}); err != nil {
return fmt.Errorf(
`error re-adding peer %s to interface %s with allowed IPs %v: %w`,
peerKey, c.deviceName, existingPeer.AllowedIPs, err,
)
}
return nil
}
func (c *KernelConfigurer) RemovePeer(peerKey string) error { func (c *KernelConfigurer) RemovePeer(peerKey string) error {
peerKeyParsed, err := wgtypes.ParseKey(peerKey) peerKeyParsed, err := wgtypes.ParseKey(peerKey)
if err != nil { if err != nil {

View File

@@ -106,67 +106,6 @@ func (c *WGUSPConfigurer) UpdatePeer(peerKey string, allowedIps []netip.Prefix,
return nil return nil
} }
func (c *WGUSPConfigurer) RemoveEndpointAddress(peerKey string) error {
peerKeyParsed, err := wgtypes.ParseKey(peerKey)
if err != nil {
return fmt.Errorf("parse peer key: %w", err)
}
ipcStr, err := c.device.IpcGet()
if err != nil {
return fmt.Errorf("get IPC config: %w", err)
}
// Parse current status to get allowed IPs for the peer
stats, err := parseStatus(c.deviceName, ipcStr)
if err != nil {
return fmt.Errorf("parse IPC config: %w", err)
}
var allowedIPs []net.IPNet
found := false
for _, peer := range stats.Peers {
if peer.PublicKey == peerKey {
allowedIPs = peer.AllowedIPs
found = true
break
}
}
if !found {
return fmt.Errorf("peer %s not found", peerKey)
}
// remove the peer from the WireGuard configuration
peer := wgtypes.PeerConfig{
PublicKey: peerKeyParsed,
Remove: true,
}
config := wgtypes.Config{
Peers: []wgtypes.PeerConfig{peer},
}
if ipcErr := c.device.IpcSet(toWgUserspaceString(config)); ipcErr != nil {
return fmt.Errorf("failed to remove peer: %s", ipcErr)
}
// Build the peer config
peer = wgtypes.PeerConfig{
PublicKey: peerKeyParsed,
ReplaceAllowedIPs: true,
AllowedIPs: allowedIPs,
}
config = wgtypes.Config{
Peers: []wgtypes.PeerConfig{peer},
}
if err := c.device.IpcSet(toWgUserspaceString(config)); err != nil {
return fmt.Errorf("remove endpoint address: %w", err)
}
return nil
}
func (c *WGUSPConfigurer) RemovePeer(peerKey string) error { func (c *WGUSPConfigurer) RemovePeer(peerKey string) error {
peerKeyParsed, err := wgtypes.ParseKey(peerKey) peerKeyParsed, err := wgtypes.ParseKey(peerKey)
if err != nil { if err != nil {

View File

@@ -21,5 +21,4 @@ type WGConfigurer interface {
GetStats() (map[string]configurer.WGStats, error) GetStats() (map[string]configurer.WGStats, error)
FullStats() (*configurer.Stats, error) FullStats() (*configurer.Stats, error)
LastActivities() map[string]monotime.Time LastActivities() map[string]monotime.Time
RemoveEndpointAddress(peerKey string) error
} }

View File

@@ -148,17 +148,6 @@ func (w *WGIface) UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAliv
return w.configurer.UpdatePeer(peerKey, allowedIps, keepAlive, endpoint, preSharedKey) return w.configurer.UpdatePeer(peerKey, allowedIps, keepAlive, endpoint, preSharedKey)
} }
func (w *WGIface) RemoveEndpointAddress(peerKey string) error {
w.mu.Lock()
defer w.mu.Unlock()
if w.configurer == nil {
return ErrIfaceNotFound
}
log.Debugf("Removing endpoint address: %s", peerKey)
return w.configurer.RemoveEndpointAddress(peerKey)
}
// RemovePeer removes a Wireguard Peer from the interface iface // RemovePeer removes a Wireguard Peer from the interface iface
func (w *WGIface) RemovePeer(peerKey string) error { func (w *WGIface) RemovePeer(peerKey string) error {
w.mu.Lock() w.mu.Lock()

View File

@@ -240,17 +240,15 @@ func (r *registryConfigurator) addDNSMatchPolicy(domains []string, ip netip.Addr
// if the gpo key is present, we need to put our DNS settings there, otherwise our config might be ignored // if the gpo key is present, we need to put our DNS settings there, otherwise our config might be ignored
// see https://learn.microsoft.com/en-us/openspecs/windows_protocols/ms-gpnrpt/8cc31cb9-20cb-4140-9e85-3e08703b4745 // see https://learn.microsoft.com/en-us/openspecs/windows_protocols/ms-gpnrpt/8cc31cb9-20cb-4140-9e85-3e08703b4745
for i, domain := range domains { for i, domain := range domains {
localPath := fmt.Sprintf("%s-%d", dnsPolicyConfigMatchPath, i)
gpoPath := fmt.Sprintf("%s-%d", gpoDnsPolicyConfigMatchPath, i)
singleDomain := []string{domain} singleDomain := []string{domain}
if err := r.configureDNSPolicy(localPath, singleDomain, ip); err != nil { if err := r.configureDNSPolicy(dnsPolicyConfigMatchPath, singleDomain, ip); err != nil {
return i, fmt.Errorf("configure DNS Local policy for domain %s: %w", domain, err) return i, fmt.Errorf("configure DNS Local policy for domain %s: %w", domain, err)
} }
if r.gpo { if r.gpo {
if err := r.configureDNSPolicy(gpoPath, singleDomain, ip); err != nil { if err := r.configureDNSPolicy(gpoDnsPolicyConfigMatchPath, singleDomain, ip); err != nil {
return i, fmt.Errorf("configure gpo DNS policy: %w", err) return i, fmt.Errorf("configure gpo DNS policy: %w", err)
} }
} }

View File

@@ -1,78 +0,0 @@
package dnsfwd
import (
"net/netip"
"slices"
"strings"
"sync"
"github.com/miekg/dns"
)
type cache struct {
mu sync.RWMutex
records map[string]*cacheEntry
}
type cacheEntry struct {
ip4Addrs []netip.Addr
ip6Addrs []netip.Addr
}
func newCache() *cache {
return &cache{
records: make(map[string]*cacheEntry),
}
}
func (c *cache) get(domain string, reqType uint16) ([]netip.Addr, bool) {
c.mu.RLock()
defer c.mu.RUnlock()
entry, exists := c.records[normalizeDomain(domain)]
if !exists {
return nil, false
}
switch reqType {
case dns.TypeA:
return slices.Clone(entry.ip4Addrs), true
case dns.TypeAAAA:
return slices.Clone(entry.ip6Addrs), true
default:
return nil, false
}
}
func (c *cache) set(domain string, reqType uint16, addrs []netip.Addr) {
c.mu.Lock()
defer c.mu.Unlock()
norm := normalizeDomain(domain)
entry, exists := c.records[norm]
if !exists {
entry = &cacheEntry{}
c.records[norm] = entry
}
switch reqType {
case dns.TypeA:
entry.ip4Addrs = slices.Clone(addrs)
case dns.TypeAAAA:
entry.ip6Addrs = slices.Clone(addrs)
}
}
// unset removes cached entries for the given domain and request type.
func (c *cache) unset(domain string) {
c.mu.Lock()
defer c.mu.Unlock()
delete(c.records, normalizeDomain(domain))
}
// normalizeDomain converts an input domain into a canonical form used as cache key:
// lowercase and fully-qualified (with trailing dot).
func normalizeDomain(domain string) string {
// dns.Fqdn ensures trailing dot; ToLower for consistent casing
return dns.Fqdn(strings.ToLower(domain))
}

View File

@@ -1,86 +0,0 @@
package dnsfwd
import (
"net/netip"
"testing"
)
func mustAddr(t *testing.T, s string) netip.Addr {
t.Helper()
a, err := netip.ParseAddr(s)
if err != nil {
t.Fatalf("parse addr %s: %v", s, err)
}
return a
}
func TestCacheNormalization(t *testing.T) {
c := newCache()
// Mixed case, without trailing dot
domainInput := "ExAmPlE.CoM"
ipv4 := []netip.Addr{mustAddr(t, "1.2.3.4")}
c.set(domainInput, 1 /* dns.TypeA */, ipv4)
// Lookup with lower, with trailing dot
if got, ok := c.get("example.com.", 1); !ok || len(got) != 1 || got[0].String() != "1.2.3.4" {
t.Fatalf("expected cached IPv4 result via normalized key, got=%v ok=%v", got, ok)
}
// Lookup with different casing again
if got, ok := c.get("EXAMPLE.COM", 1); !ok || len(got) != 1 || got[0].String() != "1.2.3.4" {
t.Fatalf("expected cached IPv4 result via different casing, got=%v ok=%v", got, ok)
}
}
func TestCacheSeparateTypes(t *testing.T) {
c := newCache()
domain := "test.local"
ipv4 := []netip.Addr{mustAddr(t, "10.0.0.1")}
ipv6 := []netip.Addr{mustAddr(t, "2001:db8::1")}
c.set(domain, 1 /* A */, ipv4)
c.set(domain, 28 /* AAAA */, ipv6)
got4, ok4 := c.get(domain, 1)
if !ok4 || len(got4) != 1 || got4[0] != ipv4[0] {
t.Fatalf("expected A record from cache, got=%v ok=%v", got4, ok4)
}
got6, ok6 := c.get(domain, 28)
if !ok6 || len(got6) != 1 || got6[0] != ipv6[0] {
t.Fatalf("expected AAAA record from cache, got=%v ok=%v", got6, ok6)
}
}
func TestCacheCloneOnGetAndSet(t *testing.T) {
c := newCache()
domain := "clone.test"
src := []netip.Addr{mustAddr(t, "8.8.8.8")}
c.set(domain, 1, src)
// Mutate source slice; cache should be unaffected
src[0] = mustAddr(t, "9.9.9.9")
got, ok := c.get(domain, 1)
if !ok || len(got) != 1 || got[0].String() != "8.8.8.8" {
t.Fatalf("expected cached value to be independent of source slice, got=%v ok=%v", got, ok)
}
// Mutate returned slice; internal cache should remain unchanged
got[0] = mustAddr(t, "4.4.4.4")
got2, ok2 := c.get(domain, 1)
if !ok2 || len(got2) != 1 || got2[0].String() != "8.8.8.8" {
t.Fatalf("expected returned slice to be a clone, got=%v ok=%v", got2, ok2)
}
}
func TestCacheMiss(t *testing.T) {
c := newCache()
if got, ok := c.get("missing.example", 1); ok || got != nil {
t.Fatalf("expected cache miss, got=%v ok=%v", got, ok)
}
}

View File

@@ -46,7 +46,6 @@ type DNSForwarder struct {
fwdEntries []*ForwarderEntry fwdEntries []*ForwarderEntry
firewall firewaller firewall firewaller
resolver resolver resolver resolver
cache *cache
} }
func NewDNSForwarder(listenAddress string, ttl uint32, firewall firewaller, statusRecorder *peer.Status) *DNSForwarder { func NewDNSForwarder(listenAddress string, ttl uint32, firewall firewaller, statusRecorder *peer.Status) *DNSForwarder {
@@ -57,7 +56,6 @@ func NewDNSForwarder(listenAddress string, ttl uint32, firewall firewaller, stat
firewall: firewall, firewall: firewall,
statusRecorder: statusRecorder, statusRecorder: statusRecorder,
resolver: net.DefaultResolver, resolver: net.DefaultResolver,
cache: newCache(),
} }
} }
@@ -105,39 +103,10 @@ func (f *DNSForwarder) UpdateDomains(entries []*ForwarderEntry) {
f.mutex.Lock() f.mutex.Lock()
defer f.mutex.Unlock() defer f.mutex.Unlock()
// remove cache entries for domains that no longer appear
f.removeStaleCacheEntries(f.fwdEntries, entries)
f.fwdEntries = entries f.fwdEntries = entries
log.Debugf("Updated DNS forwarder with %d domains", len(entries)) log.Debugf("Updated DNS forwarder with %d domains", len(entries))
} }
// removeStaleCacheEntries unsets cache items for domains that were present
// in the old list but not present in the new list.
func (f *DNSForwarder) removeStaleCacheEntries(oldEntries, newEntries []*ForwarderEntry) {
if f.cache == nil {
return
}
newSet := make(map[string]struct{}, len(newEntries))
for _, e := range newEntries {
if e == nil {
continue
}
newSet[e.Domain.PunycodeString()] = struct{}{}
}
for _, e := range oldEntries {
if e == nil {
continue
}
pattern := e.Domain.PunycodeString()
if _, ok := newSet[pattern]; !ok {
f.cache.unset(pattern)
}
}
}
func (f *DNSForwarder) Close(ctx context.Context) error { func (f *DNSForwarder) Close(ctx context.Context) error {
var result *multierror.Error var result *multierror.Error
@@ -202,7 +171,6 @@ func (f *DNSForwarder) handleDNSQuery(w dns.ResponseWriter, query *dns.Msg) *dns
f.updateInternalState(ips, mostSpecificResId, matchingEntries) f.updateInternalState(ips, mostSpecificResId, matchingEntries)
f.addIPsToResponse(resp, domain, ips) f.addIPsToResponse(resp, domain, ips)
f.cache.set(domain, question.Qtype, ips)
return resp return resp
} }
@@ -314,69 +282,29 @@ func (f *DNSForwarder) setResponseCodeForNotFound(ctx context.Context, resp *dns
resp.Rcode = dns.RcodeSuccess resp.Rcode = dns.RcodeSuccess
} }
// handleDNSError processes DNS lookup errors and sends an appropriate error response. // handleDNSError processes DNS lookup errors and sends an appropriate error response
func (f *DNSForwarder) handleDNSError( func (f *DNSForwarder) handleDNSError(ctx context.Context, w dns.ResponseWriter, question dns.Question, resp *dns.Msg, domain string, err error) {
ctx context.Context,
w dns.ResponseWriter,
question dns.Question,
resp *dns.Msg,
domain string,
err error,
) {
// Default to SERVFAIL; override below when appropriate.
resp.Rcode = dns.RcodeServerFailure
qType := question.Qtype
qTypeName := dns.TypeToString[qType]
// Prefer typed DNS errors; fall back to generic logging otherwise.
var dnsErr *net.DNSError var dnsErr *net.DNSError
if !errors.As(err, &dnsErr) {
log.Warnf(errResolveFailed, domain, err)
if writeErr := w.WriteMsg(resp); writeErr != nil {
log.Errorf("failed to write failure DNS response: %v", writeErr)
}
return
}
// NotFound: set NXDOMAIN / appropriate code via helper. switch {
if dnsErr.IsNotFound { case errors.As(err, &dnsErr):
f.setResponseCodeForNotFound(ctx, resp, domain, qType) resp.Rcode = dns.RcodeServerFailure
if writeErr := w.WriteMsg(resp); writeErr != nil { if dnsErr.IsNotFound {
log.Errorf("failed to write failure DNS response: %v", writeErr) f.setResponseCodeForNotFound(ctx, resp, domain, question.Qtype)
} }
f.cache.set(domain, question.Qtype, nil)
return
}
// Upstream failed but we might have a cached answer—serve it if present. if dnsErr.Server != "" {
if ips, ok := f.cache.get(domain, qType); ok { log.Warnf("failed to resolve query for type=%s domain=%s server=%s: %v", dns.TypeToString[question.Qtype], domain, dnsErr.Server, err)
if len(ips) > 0 { } else {
log.Debugf("serving cached DNS response after upstream failure: domain=%s type=%s", domain, qTypeName) log.Warnf(errResolveFailed, domain, err)
f.addIPsToResponse(resp, domain, ips)
resp.Rcode = dns.RcodeSuccess
if writeErr := w.WriteMsg(resp); writeErr != nil {
log.Errorf("failed to write cached DNS response: %v", writeErr)
}
} else { // send NXDOMAIN / appropriate code if cache is empty
f.setResponseCodeForNotFound(ctx, resp, domain, qType)
if writeErr := w.WriteMsg(resp); writeErr != nil {
log.Errorf("failed to write failure DNS response: %v", writeErr)
}
} }
return default:
} resp.Rcode = dns.RcodeServerFailure
// No cache. Log with or without the server field for more context.
if dnsErr.Server != "" {
log.Warnf("failed to resolve: type=%s domain=%s server=%s: %v", qTypeName, domain, dnsErr.Server, err)
} else {
log.Warnf(errResolveFailed, domain, err) log.Warnf(errResolveFailed, domain, err)
} }
// Write final failure response. if err := w.WriteMsg(resp); err != nil {
if writeErr := w.WriteMsg(resp); writeErr != nil { log.Errorf("failed to write failure DNS response: %v", err)
log.Errorf("failed to write failure DNS response: %v", writeErr)
} }
} }

View File

@@ -648,95 +648,6 @@ func TestDNSForwarder_TCPTruncation(t *testing.T) {
assert.LessOrEqual(t, writtenResp.Len(), dns.MinMsgSize, "Response should fit in minimum UDP size") assert.LessOrEqual(t, writtenResp.Len(), dns.MinMsgSize, "Response should fit in minimum UDP size")
} }
// Ensures that when the first query succeeds and populates the cache,
// a subsequent upstream failure still returns a successful response from cache.
func TestDNSForwarder_ServeFromCacheOnUpstreamFailure(t *testing.T) {
mockResolver := &MockResolver{}
forwarder := NewDNSForwarder("127.0.0.1:0", 300, nil, &peer.Status{})
forwarder.resolver = mockResolver
d, err := domain.FromString("example.com")
require.NoError(t, err)
entries := []*ForwarderEntry{{Domain: d, ResID: "res-cache"}}
forwarder.UpdateDomains(entries)
ip := netip.MustParseAddr("1.2.3.4")
// First call resolves successfully and populates cache
mockResolver.On("LookupNetIP", mock.Anything, "ip4", dns.Fqdn("example.com")).
Return([]netip.Addr{ip}, nil).Once()
// Second call fails upstream; forwarder should serve from cache
mockResolver.On("LookupNetIP", mock.Anything, "ip4", dns.Fqdn("example.com")).
Return([]netip.Addr{}, &net.DNSError{Err: "temporary failure"}).Once()
// First query: populate cache
q1 := &dns.Msg{}
q1.SetQuestion(dns.Fqdn("example.com"), dns.TypeA)
w1 := &test.MockResponseWriter{}
resp1 := forwarder.handleDNSQuery(w1, q1)
require.NotNil(t, resp1)
require.Equal(t, dns.RcodeSuccess, resp1.Rcode)
require.Len(t, resp1.Answer, 1)
// Second query: serve from cache after upstream failure
q2 := &dns.Msg{}
q2.SetQuestion(dns.Fqdn("example.com"), dns.TypeA)
var writtenResp *dns.Msg
w2 := &test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { writtenResp = m; return nil }}
_ = forwarder.handleDNSQuery(w2, q2)
require.NotNil(t, writtenResp, "expected response to be written")
require.Equal(t, dns.RcodeSuccess, writtenResp.Rcode)
require.Len(t, writtenResp.Answer, 1)
mockResolver.AssertExpectations(t)
}
// Verifies that cache normalization works across casing and trailing dot variations.
func TestDNSForwarder_CacheNormalizationCasingAndDot(t *testing.T) {
mockResolver := &MockResolver{}
forwarder := NewDNSForwarder("127.0.0.1:0", 300, nil, &peer.Status{})
forwarder.resolver = mockResolver
d, err := domain.FromString("ExAmPlE.CoM")
require.NoError(t, err)
entries := []*ForwarderEntry{{Domain: d, ResID: "res-norm"}}
forwarder.UpdateDomains(entries)
ip := netip.MustParseAddr("9.8.7.6")
// Initial resolution with mixed case to populate cache
mixedQuery := "ExAmPlE.CoM"
mockResolver.On("LookupNetIP", mock.Anything, "ip4", dns.Fqdn(strings.ToLower(mixedQuery))).
Return([]netip.Addr{ip}, nil).Once()
q1 := &dns.Msg{}
q1.SetQuestion(mixedQuery+".", dns.TypeA)
w1 := &test.MockResponseWriter{}
resp1 := forwarder.handleDNSQuery(w1, q1)
require.NotNil(t, resp1)
require.Equal(t, dns.RcodeSuccess, resp1.Rcode)
require.Len(t, resp1.Answer, 1)
// Subsequent query without dot and upper case should hit cache even if upstream fails
// Forwarder lowercases and uses the question name as-is (no trailing dot here)
mockResolver.On("LookupNetIP", mock.Anything, "ip4", strings.ToLower("EXAMPLE.COM")).
Return([]netip.Addr{}, &net.DNSError{Err: "temporary failure"}).Once()
q2 := &dns.Msg{}
q2.SetQuestion("EXAMPLE.COM", dns.TypeA)
var writtenResp *dns.Msg
w2 := &test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { writtenResp = m; return nil }}
_ = forwarder.handleDNSQuery(w2, q2)
require.NotNil(t, writtenResp)
require.Equal(t, dns.RcodeSuccess, writtenResp.Rcode)
require.Len(t, writtenResp.Answer, 1)
mockResolver.AssertExpectations(t)
}
func TestDNSForwarder_MultipleOverlappingPatterns(t *testing.T) { func TestDNSForwarder_MultipleOverlappingPatterns(t *testing.T) {
// Test complex overlapping pattern scenarios // Test complex overlapping pattern scenarios
mockFirewall := &MockFirewall{} mockFirewall := &MockFirewall{}

View File

@@ -105,10 +105,6 @@ type MockWGIface struct {
LastActivitiesFunc func() map[string]monotime.Time LastActivitiesFunc func() map[string]monotime.Time
} }
func (m *MockWGIface) RemoveEndpointAddress(_ string) error {
return nil
}
func (m *MockWGIface) FullStats() (*configurer.Stats, error) { func (m *MockWGIface) FullStats() (*configurer.Stats, error) {
return nil, fmt.Errorf("not implemented") return nil, fmt.Errorf("not implemented")
} }

View File

@@ -28,7 +28,6 @@ type wgIfaceBase interface {
UpdateAddr(newAddr string) error UpdateAddr(newAddr string) error
GetProxy() wgproxy.Proxy GetProxy() wgproxy.Proxy
UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error
RemoveEndpointAddress(key string) error
RemovePeer(peerKey string) error RemovePeer(peerKey string) error
AddAllowedIP(peerKey string, allowedIP netip.Prefix) error AddAllowedIP(peerKey string, allowedIP netip.Prefix) error
RemoveAllowedIP(peerKey string, allowedIP netip.Prefix) error RemoveAllowedIP(peerKey string, allowedIP netip.Prefix) error

View File

@@ -171,9 +171,9 @@ func (conn *Conn) Open(engineCtx context.Context) error {
conn.handshaker = NewHandshaker(conn.Log, conn.config, conn.signaler, conn.workerICE, conn.workerRelay) conn.handshaker = NewHandshaker(conn.Log, conn.config, conn.signaler, conn.workerICE, conn.workerRelay)
conn.handshaker.AddRelayListener(conn.workerRelay.OnNewOffer) conn.handshaker.AddOnNewOfferListener(conn.workerRelay.OnNewOffer)
if !isForceRelayed() { if !isForceRelayed() {
conn.handshaker.AddICEListener(conn.workerICE.OnNewOffer) conn.handshaker.AddOnNewOfferListener(conn.workerICE.OnNewOffer)
} }
conn.guard = guard.NewGuard(conn.Log, conn.isConnectedOnAllWay, conn.config.Timeout, conn.srWatcher) conn.guard = guard.NewGuard(conn.Log, conn.isConnectedOnAllWay, conn.config.Timeout, conn.srWatcher)
@@ -430,9 +430,6 @@ func (conn *Conn) onICEStateDisconnected() {
} else { } else {
conn.Log.Infof("ICE disconnected, do not switch to Relay. Reset priority to: %s", conntype.None.String()) conn.Log.Infof("ICE disconnected, do not switch to Relay. Reset priority to: %s", conntype.None.String())
conn.currentConnPriority = conntype.None conn.currentConnPriority = conntype.None
if err := conn.config.WgConfig.WgInterface.RemoveEndpointAddress(conn.config.WgConfig.RemoteKey); err != nil {
conn.Log.Errorf("failed to remove wg endpoint: %v", err)
}
} }
changed := conn.statusICE.Get() != worker.StatusDisconnected changed := conn.statusICE.Get() != worker.StatusDisconnected
@@ -526,9 +523,6 @@ func (conn *Conn) onRelayDisconnected() {
if conn.currentConnPriority == conntype.Relay { if conn.currentConnPriority == conntype.Relay {
conn.Log.Debugf("clean up WireGuard config") conn.Log.Debugf("clean up WireGuard config")
conn.currentConnPriority = conntype.None conn.currentConnPriority = conntype.None
if err := conn.config.WgConfig.WgInterface.RemoveEndpointAddress(conn.config.WgConfig.RemoteKey); err != nil {
conn.Log.Errorf("failed to remove wg endpoint: %v", err)
}
} }
if conn.wgProxyRelay != nil { if conn.wgProxyRelay != nil {

View File

@@ -79,10 +79,10 @@ func TestConn_OnRemoteOffer(t *testing.T) {
return return
} }
onNewOfferChan := make(chan struct{}) onNewOffeChan := make(chan struct{})
conn.handshaker.AddRelayListener(func(remoteOfferAnswer *OfferAnswer) { conn.handshaker.AddOnNewOfferListener(func(remoteOfferAnswer *OfferAnswer) {
onNewOfferChan <- struct{}{} onNewOffeChan <- struct{}{}
}) })
conn.OnRemoteOffer(OfferAnswer{ conn.OnRemoteOffer(OfferAnswer{
@@ -98,7 +98,7 @@ func TestConn_OnRemoteOffer(t *testing.T) {
defer cancel() defer cancel()
select { select {
case <-onNewOfferChan: case <-onNewOffeChan:
// success // success
case <-ctx.Done(): case <-ctx.Done():
t.Error("expected to receive a new offer notification, but timed out") t.Error("expected to receive a new offer notification, but timed out")
@@ -118,10 +118,10 @@ func TestConn_OnRemoteAnswer(t *testing.T) {
return return
} }
onNewOfferChan := make(chan struct{}) onNewOffeChan := make(chan struct{})
conn.handshaker.AddRelayListener(func(remoteOfferAnswer *OfferAnswer) { conn.handshaker.AddOnNewOfferListener(func(remoteOfferAnswer *OfferAnswer) {
onNewOfferChan <- struct{}{} onNewOffeChan <- struct{}{}
}) })
conn.OnRemoteAnswer(OfferAnswer{ conn.OnRemoteAnswer(OfferAnswer{
@@ -136,7 +136,7 @@ func TestConn_OnRemoteAnswer(t *testing.T) {
defer cancel() defer cancel()
select { select {
case <-onNewOfferChan: case <-onNewOffeChan:
// success // success
case <-ctx.Done(): case <-ctx.Done():
t.Error("expected to receive a new offer notification, but timed out") t.Error("expected to receive a new offer notification, but timed out")

View File

@@ -1,20 +0,0 @@
package guard
import (
"os"
"strconv"
"time"
)
const (
envICEMonitorPeriod = "NB_ICE_MONITOR_PERIOD"
)
func GetICEMonitorPeriod() time.Duration {
if envVal := os.Getenv(envICEMonitorPeriod); envVal != "" {
if seconds, err := strconv.Atoi(envVal); err == nil && seconds > 0 {
return time.Duration(seconds) * time.Second
}
}
return defaultCandidatesMonitorPeriod
}

View File

@@ -16,8 +16,8 @@ import (
) )
const ( const (
defaultCandidatesMonitorPeriod = 5 * time.Minute candidatesMonitorPeriod = 5 * time.Minute
candidateGatheringTimeout = 5 * time.Second candidateGatheringTimeout = 5 * time.Second
) )
type ICEMonitor struct { type ICEMonitor struct {
@@ -25,19 +25,16 @@ type ICEMonitor struct {
iFaceDiscover stdnet.ExternalIFaceDiscover iFaceDiscover stdnet.ExternalIFaceDiscover
iceConfig icemaker.Config iceConfig icemaker.Config
tickerPeriod time.Duration
currentCandidatesAddress []string currentCandidatesAddress []string
candidatesMu sync.Mutex candidatesMu sync.Mutex
} }
func NewICEMonitor(iFaceDiscover stdnet.ExternalIFaceDiscover, config icemaker.Config, period time.Duration) *ICEMonitor { func NewICEMonitor(iFaceDiscover stdnet.ExternalIFaceDiscover, config icemaker.Config) *ICEMonitor {
log.Debugf("prepare ICE monitor with period: %s", period)
cm := &ICEMonitor{ cm := &ICEMonitor{
ReconnectCh: make(chan struct{}, 1), ReconnectCh: make(chan struct{}, 1),
iFaceDiscover: iFaceDiscover, iFaceDiscover: iFaceDiscover,
iceConfig: config, iceConfig: config,
tickerPeriod: period,
} }
return cm return cm
} }
@@ -49,12 +46,7 @@ func (cm *ICEMonitor) Start(ctx context.Context, onChanged func()) {
return return
} }
// Initial check to populate the candidates for later comparison ticker := time.NewTicker(candidatesMonitorPeriod)
if _, err := cm.handleCandidateTick(ctx, ufrag, pwd); err != nil {
log.Warnf("Failed to check initial ICE candidates: %v", err)
}
ticker := time.NewTicker(cm.tickerPeriod)
defer ticker.Stop() defer ticker.Stop()
for { for {

View File

@@ -51,7 +51,7 @@ func (w *SRWatcher) Start() {
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
w.cancelIceMonitor = cancel w.cancelIceMonitor = cancel
iceMonitor := NewICEMonitor(w.iFaceDiscover, w.iceConfig, GetICEMonitorPeriod()) iceMonitor := NewICEMonitor(w.iFaceDiscover, w.iceConfig)
go iceMonitor.Start(ctx, w.onICEChanged) go iceMonitor.Start(ctx, w.onICEChanged)
w.signalClient.SetOnReconnectedListener(w.onReconnected) w.signalClient.SetOnReconnectedListener(w.onReconnected)
w.relayManager.SetOnReconnectedListener(w.onReconnected) w.relayManager.SetOnReconnectedListener(w.onReconnected)

View File

@@ -44,19 +44,13 @@ type OfferAnswer struct {
} }
type Handshaker struct { type Handshaker struct {
mu sync.Mutex mu sync.Mutex
log *log.Entry log *log.Entry
config ConnConfig config ConnConfig
signaler *Signaler signaler *Signaler
ice *WorkerICE ice *WorkerICE
relay *WorkerRelay relay *WorkerRelay
// relayListener is not blocking because the listener is using a goroutine to process the messages onNewOfferListeners []*OfferListener
// and it will only keep the latest message if multiple offers are received in a short time
// this is to avoid blocking the handshaker if the listener is doing some heavy processing
// and also to avoid processing old offers if multiple offers are received in a short time
// the listener will always process the latest offer
relayListener *AsyncOfferListener
iceListener func(remoteOfferAnswer *OfferAnswer)
// remoteOffersCh is a channel used to wait for remote credentials to proceed with the connection // remoteOffersCh is a channel used to wait for remote credentials to proceed with the connection
remoteOffersCh chan OfferAnswer remoteOffersCh chan OfferAnswer
@@ -76,39 +70,28 @@ func NewHandshaker(log *log.Entry, config ConnConfig, signaler *Signaler, ice *W
} }
} }
func (h *Handshaker) AddRelayListener(offer func(remoteOfferAnswer *OfferAnswer)) { func (h *Handshaker) AddOnNewOfferListener(offer func(remoteOfferAnswer *OfferAnswer)) {
h.relayListener = NewAsyncOfferListener(offer) l := NewOfferListener(offer)
} h.onNewOfferListeners = append(h.onNewOfferListeners, l)
func (h *Handshaker) AddICEListener(offer func(remoteOfferAnswer *OfferAnswer)) {
h.iceListener = offer
} }
func (h *Handshaker) Listen(ctx context.Context) { func (h *Handshaker) Listen(ctx context.Context) {
for { for {
select { select {
case remoteOfferAnswer := <-h.remoteOffersCh: case remoteOfferAnswer := <-h.remoteOffersCh:
h.log.Infof("received offer, running version %s, remote WireGuard listen port %d, session id: %s", remoteOfferAnswer.Version, remoteOfferAnswer.WgListenPort, remoteOfferAnswer.SessionIDString()) // received confirmation from the remote peer -> ready to proceed
if h.relayListener != nil {
h.relayListener.Notify(&remoteOfferAnswer)
}
if h.iceListener != nil {
h.iceListener(&remoteOfferAnswer)
}
if err := h.sendAnswer(); err != nil { if err := h.sendAnswer(); err != nil {
h.log.Errorf("failed to send remote offer confirmation: %s", err) h.log.Errorf("failed to send remote offer confirmation: %s", err)
continue continue
} }
for _, listener := range h.onNewOfferListeners {
listener.Notify(&remoteOfferAnswer)
}
h.log.Infof("received offer, running version %s, remote WireGuard listen port %d, session id: %s", remoteOfferAnswer.Version, remoteOfferAnswer.WgListenPort, remoteOfferAnswer.SessionIDString())
case remoteOfferAnswer := <-h.remoteAnswerCh: case remoteOfferAnswer := <-h.remoteAnswerCh:
h.log.Infof("received answer, running version %s, remote WireGuard listen port %d, session id: %s", remoteOfferAnswer.Version, remoteOfferAnswer.WgListenPort, remoteOfferAnswer.SessionIDString()) h.log.Infof("received answer, running version %s, remote WireGuard listen port %d, session id: %s", remoteOfferAnswer.Version, remoteOfferAnswer.WgListenPort, remoteOfferAnswer.SessionIDString())
if h.relayListener != nil { for _, listener := range h.onNewOfferListeners {
h.relayListener.Notify(&remoteOfferAnswer) listener.Notify(&remoteOfferAnswer)
}
if h.iceListener != nil {
h.iceListener(&remoteOfferAnswer)
} }
case <-ctx.Done(): case <-ctx.Done():
h.log.Infof("stop listening for remote offers and answers") h.log.Infof("stop listening for remote offers and answers")

View File

@@ -13,20 +13,20 @@ func (oa *OfferAnswer) SessionIDString() string {
return oa.SessionID.String() return oa.SessionID.String()
} }
type AsyncOfferListener struct { type OfferListener struct {
fn callbackFunc fn callbackFunc
running bool running bool
latest *OfferAnswer latest *OfferAnswer
mu sync.Mutex mu sync.Mutex
} }
func NewAsyncOfferListener(fn callbackFunc) *AsyncOfferListener { func NewOfferListener(fn callbackFunc) *OfferListener {
return &AsyncOfferListener{ return &OfferListener{
fn: fn, fn: fn,
} }
} }
func (o *AsyncOfferListener) Notify(remoteOfferAnswer *OfferAnswer) { func (o *OfferListener) Notify(remoteOfferAnswer *OfferAnswer) {
o.mu.Lock() o.mu.Lock()
defer o.mu.Unlock() defer o.mu.Unlock()

View File

@@ -14,7 +14,7 @@ func Test_newOfferListener(t *testing.T) {
runChan <- struct{}{} runChan <- struct{}{}
} }
hl := NewAsyncOfferListener(longRunningFn) hl := NewOfferListener(longRunningFn)
hl.Notify(dummyOfferAnswer) hl.Notify(dummyOfferAnswer)
hl.Notify(dummyOfferAnswer) hl.Notify(dummyOfferAnswer)

View File

@@ -18,5 +18,4 @@ type WGIface interface {
GetStats() (map[string]configurer.WGStats, error) GetStats() (map[string]configurer.WGStats, error)
GetProxy() wgproxy.Proxy GetProxy() wgproxy.Proxy
Address() wgaddr.Address Address() wgaddr.Address
RemoveEndpointAddress(key string) error
} }

View File

@@ -92,16 +92,23 @@ func NewWorkerICE(ctx context.Context, log *log.Entry, config ConnConfig, conn *
func (w *WorkerICE) OnNewOffer(remoteOfferAnswer *OfferAnswer) { func (w *WorkerICE) OnNewOffer(remoteOfferAnswer *OfferAnswer) {
w.log.Debugf("OnNewOffer for ICE, serial: %s", remoteOfferAnswer.SessionIDString()) w.log.Debugf("OnNewOffer for ICE, serial: %s", remoteOfferAnswer.SessionIDString())
w.muxAgent.Lock() w.muxAgent.Lock()
defer w.muxAgent.Unlock()
if w.agent != nil || w.agentConnecting { if w.agentConnecting {
w.log.Debugf("agent connection is in progress, skipping the offer")
w.muxAgent.Unlock()
return
}
if w.agent != nil {
// backward compatibility with old clients that do not send session ID // backward compatibility with old clients that do not send session ID
if remoteOfferAnswer.SessionID == nil { if remoteOfferAnswer.SessionID == nil {
w.log.Debugf("agent already exists, skipping the offer") w.log.Debugf("agent already exists, skipping the offer")
w.muxAgent.Unlock()
return return
} }
if w.remoteSessionID == *remoteOfferAnswer.SessionID { if w.remoteSessionID == *remoteOfferAnswer.SessionID {
w.log.Debugf("agent already exists and session ID matches, skipping the offer: %s", remoteOfferAnswer.SessionIDString()) w.log.Debugf("agent already exists and session ID matches, skipping the offer: %s", remoteOfferAnswer.SessionIDString())
w.muxAgent.Unlock()
return return
} }
w.log.Debugf("agent already exists, recreate the connection") w.log.Debugf("agent already exists, recreate the connection")
@@ -109,12 +116,6 @@ func (w *WorkerICE) OnNewOffer(remoteOfferAnswer *OfferAnswer) {
if err := w.agent.Close(); err != nil { if err := w.agent.Close(); err != nil {
w.log.Warnf("failed to close ICE agent: %s", err) w.log.Warnf("failed to close ICE agent: %s", err)
} }
sessionID, err := NewICESessionID()
if err != nil {
w.log.Errorf("failed to create new session ID: %s", err)
}
w.sessionID = sessionID
w.agent = nil w.agent = nil
} }
@@ -125,23 +126,18 @@ func (w *WorkerICE) OnNewOffer(remoteOfferAnswer *OfferAnswer) {
preferredCandidateTypes = icemaker.CandidateTypes() preferredCandidateTypes = icemaker.CandidateTypes()
} }
if remoteOfferAnswer.SessionID != nil { w.log.Debugf("recreate ICE agent")
w.log.Debugf("recreate ICE agent: %s / %s", w.sessionID, *remoteOfferAnswer.SessionID)
}
dialerCtx, dialerCancel := context.WithCancel(w.ctx) dialerCtx, dialerCancel := context.WithCancel(w.ctx)
agent, err := w.reCreateAgent(dialerCancel, preferredCandidateTypes) agent, err := w.reCreateAgent(dialerCancel, preferredCandidateTypes)
if err != nil { if err != nil {
w.log.Errorf("failed to recreate ICE Agent: %s", err) w.log.Errorf("failed to recreate ICE Agent: %s", err)
w.muxAgent.Unlock()
return return
} }
w.agent = agent w.agent = agent
w.agentDialerCancel = dialerCancel w.agentDialerCancel = dialerCancel
w.agentConnecting = true w.agentConnecting = true
if remoteOfferAnswer.SessionID != nil { w.muxAgent.Unlock()
w.remoteSessionID = *remoteOfferAnswer.SessionID
} else {
w.remoteSessionID = ""
}
go w.connect(dialerCtx, agent, remoteOfferAnswer) go w.connect(dialerCtx, agent, remoteOfferAnswer)
} }
@@ -297,6 +293,9 @@ func (w *WorkerICE) connect(ctx context.Context, agent *icemaker.ThreadSafeAgent
w.muxAgent.Lock() w.muxAgent.Lock()
w.agentConnecting = false w.agentConnecting = false
w.lastSuccess = time.Now() w.lastSuccess = time.Now()
if remoteOfferAnswer.SessionID != nil {
w.remoteSessionID = *remoteOfferAnswer.SessionID
}
w.muxAgent.Unlock() w.muxAgent.Unlock()
// todo: the potential problem is a race between the onConnectionStateChange // todo: the potential problem is a race between the onConnectionStateChange
@@ -310,17 +309,16 @@ func (w *WorkerICE) closeAgent(agent *icemaker.ThreadSafeAgent, cancel context.C
} }
w.muxAgent.Lock() w.muxAgent.Lock()
// todo review does it make sense to generate new session ID all the time when w.agent==agent
sessionID, err := NewICESessionID()
if err != nil {
w.log.Errorf("failed to create new session ID: %s", err)
}
w.sessionID = sessionID
if w.agent == agent { if w.agent == agent {
// consider to remove from here and move to the OnNewOffer
sessionID, err := NewICESessionID()
if err != nil {
w.log.Errorf("failed to create new session ID: %s", err)
}
w.sessionID = sessionID
w.agent = nil w.agent = nil
w.agentConnecting = false w.agentConnecting = false
w.remoteSessionID = ""
} }
w.muxAgent.Unlock() w.muxAgent.Unlock()
} }
@@ -397,12 +395,11 @@ func (w *WorkerICE) onConnectionStateChange(agent *icemaker.ThreadSafeAgent, dia
// ice.ConnectionStateClosed happens when we recreate the agent. For the P2P to TURN switch important to // ice.ConnectionStateClosed happens when we recreate the agent. For the P2P to TURN switch important to
// notify the conn.onICEStateDisconnected changes to update the current used priority // notify the conn.onICEStateDisconnected changes to update the current used priority
w.closeAgent(agent, dialerCancel)
if w.lastKnownState == ice.ConnectionStateConnected { if w.lastKnownState == ice.ConnectionStateConnected {
w.lastKnownState = ice.ConnectionStateDisconnected w.lastKnownState = ice.ConnectionStateDisconnected
w.conn.onICEStateDisconnected() w.conn.onICEStateDisconnected()
} }
w.closeAgent(agent, dialerCancel)
default: default:
return return
} }

View File

@@ -1354,13 +1354,7 @@ func (s *serviceClient) updateConfig() error {
} }
// showLoginURL creates a borderless window styled like a pop-up in the top-right corner using s.wLoginURL. // showLoginURL creates a borderless window styled like a pop-up in the top-right corner using s.wLoginURL.
// It also starts a background goroutine that periodically checks if the client is already connected func (s *serviceClient) showLoginURL() {
// and closes the window if so. The goroutine can be cancelled by the returned CancelFunc, and it is
// also cancelled when the window is closed.
func (s *serviceClient) showLoginURL() context.CancelFunc {
// create a cancellable context for the background check goroutine
ctx, cancel := context.WithCancel(s.ctx)
resIcon := fyne.NewStaticResource("netbird.png", iconAbout) resIcon := fyne.NewStaticResource("netbird.png", iconAbout)
@@ -1369,8 +1363,6 @@ func (s *serviceClient) showLoginURL() context.CancelFunc {
s.wLoginURL.Resize(fyne.NewSize(400, 200)) s.wLoginURL.Resize(fyne.NewSize(400, 200))
s.wLoginURL.SetIcon(resIcon) s.wLoginURL.SetIcon(resIcon)
} }
// ensure goroutine is cancelled when the window is closed
s.wLoginURL.SetOnClosed(func() { cancel() })
// add a description label // add a description label
label := widget.NewLabel("Your NetBird session has expired.\nPlease re-authenticate to continue using NetBird.") label := widget.NewLabel("Your NetBird session has expired.\nPlease re-authenticate to continue using NetBird.")
@@ -1451,39 +1443,7 @@ func (s *serviceClient) showLoginURL() context.CancelFunc {
) )
s.wLoginURL.SetContent(container.NewCenter(content)) s.wLoginURL.SetContent(container.NewCenter(content))
// start a goroutine to check connection status and close the window if connected
go func() {
ticker := time.NewTicker(5 * time.Second)
defer ticker.Stop()
conn, err := s.getSrvClient(failFastTimeout)
if err != nil {
return
}
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
status, err := conn.Status(s.ctx, &proto.StatusRequest{})
if err != nil {
continue
}
if status.Status == string(internal.StatusConnected) {
if s.wLoginURL != nil {
s.wLoginURL.Close()
}
return
}
}
}
}()
s.wLoginURL.Show() s.wLoginURL.Show()
// return cancel func so callers can stop the background goroutine if desired
return cancel
} }
func openURL(url string) error { func openURL(url string) error {

View File

@@ -47,7 +47,7 @@ services:
- traefik.enable=true - traefik.enable=true
- traefik.http.routers.netbird-wsproxy-signal.rule=Host(`$NETBIRD_DOMAIN`) && PathPrefix(`/ws-proxy/signal`) - traefik.http.routers.netbird-wsproxy-signal.rule=Host(`$NETBIRD_DOMAIN`) && PathPrefix(`/ws-proxy/signal`)
- traefik.http.routers.netbird-wsproxy-signal.service=netbird-wsproxy-signal - traefik.http.routers.netbird-wsproxy-signal.service=netbird-wsproxy-signal
- traefik.http.services.netbird-wsproxy-signal.loadbalancer.server.port=80 - traefik.http.services.netbird-wsproxy-signal.loadbalancer.server.port=10000
- traefik.http.routers.netbird-signal.rule=Host(`$NETBIRD_DOMAIN`) && PathPrefix(`/signalexchange.SignalExchange/`) - traefik.http.routers.netbird-signal.rule=Host(`$NETBIRD_DOMAIN`) && PathPrefix(`/signalexchange.SignalExchange/`)
- traefik.http.services.netbird-signal.loadbalancer.server.port=10000 - traefik.http.services.netbird-signal.loadbalancer.server.port=10000
- traefik.http.services.netbird-signal.loadbalancer.server.scheme=h2c - traefik.http.services.netbird-signal.loadbalancer.server.scheme=h2c

View File

@@ -621,7 +621,7 @@ renderCaddyfile() {
# relay # relay
reverse_proxy /relay* relay:80 reverse_proxy /relay* relay:80
# Signal # Signal
reverse_proxy /ws-proxy/signal* signal:80 reverse_proxy /ws-proxy/signal* signal:10000
reverse_proxy /signalexchange.SignalExchange/* h2c://signal:10000 reverse_proxy /signalexchange.SignalExchange/* h2c://signal:10000
# Management # Management
reverse_proxy /api/* management:80 reverse_proxy /api/* management:80

View File

@@ -20,10 +20,6 @@ upstream management {
# insert the grpc+http port of your management container here # insert the grpc+http port of your management container here
server 127.0.0.1:8012; server 127.0.0.1:8012;
} }
upstream relay {
# insert the port of your relay container here
server 127.0.0.1:33080;
}
server { server {
# HTTP server config # HTTP server config
@@ -59,10 +55,6 @@ server {
# Proxy Signal wsproxy endpoint # Proxy Signal wsproxy endpoint
location /ws-proxy/signal { location /ws-proxy/signal {
proxy_pass http://signal; proxy_pass http://signal;
proxy_http_version 1.1;
proxy_set_header Upgrade $http_upgrade;
proxy_set_header Connection "Upgrade";
proxy_set_header Host $host;
} }
# Proxy Signal # Proxy Signal
location /signalexchange.SignalExchange/ { location /signalexchange.SignalExchange/ {
@@ -79,10 +71,6 @@ server {
# Proxy Management wsproxy endpoint # Proxy Management wsproxy endpoint
location /ws-proxy/management { location /ws-proxy/management {
proxy_pass http://management; proxy_pass http://management;
proxy_http_version 1.1;
proxy_set_header Upgrade $http_upgrade;
proxy_set_header Connection "Upgrade";
proxy_set_header Host $host;
} }
# Proxy Management grpc endpoint # Proxy Management grpc endpoint
location /management.ManagementService/ { location /management.ManagementService/ {
@@ -92,14 +80,6 @@ server {
grpc_send_timeout 1d; grpc_send_timeout 1d;
grpc_socket_keepalive on; grpc_socket_keepalive on;
} }
# Proxy Relay
location /relay {
proxy_pass http://relay;
proxy_http_version 1.1;
proxy_set_header Upgrade $http_upgrade;
proxy_set_header Connection "Upgrade";
proxy_set_header Host $host;
}
ssl_certificate /etc/ssl/certs/ssl-cert-snakeoil.pem; ssl_certificate /etc/ssl/certs/ssl-cert-snakeoil.pem;
ssl_certificate_key /etc/ssl/certs/ssl-cert-snakeoil.pem; ssl_certificate_key /etc/ssl/certs/ssl-cert-snakeoil.pem;

View File

@@ -1,4 +1,4 @@
FROM ubuntu:24.04 FROM ubuntu:24.10
RUN apt update && apt install -y ca-certificates && rm -fr /var/cache/apt RUN apt update && apt install -y ca-certificates && rm -fr /var/cache/apt
ENTRYPOINT [ "/go/bin/netbird-mgmt","management","--log-level","debug"] ENTRYPOINT [ "/go/bin/netbird-mgmt","management","--log-level","debug"]
CMD ["--log-file", "console"] CMD ["--log-file", "console"]

View File

@@ -6,6 +6,7 @@ import (
"fmt" "fmt"
"net" "net"
"net/http" "net/http"
"net/netip"
"strings" "strings"
"sync" "sync"
"time" "time"
@@ -251,7 +252,7 @@ func updateMgmtConfig(ctx context.Context, path string, config *nbconfig.Config)
} }
func (s *BaseServer) handlerFunc(gRPCHandler *grpc.Server, httpHandler http.Handler, meter metric.Meter) http.Handler { func (s *BaseServer) handlerFunc(gRPCHandler *grpc.Server, httpHandler http.Handler, meter metric.Meter) http.Handler {
wsProxy := wsproxyserver.New(gRPCHandler, wsproxyserver.WithOTelMeter(meter)) wsProxy := wsproxyserver.New(netip.AddrPortFrom(netip.AddrFrom4([4]byte{127, 0, 0, 1}), ManagementLegacyPort), wsproxyserver.WithOTelMeter(meter))
return http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) { return http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) {
switch { switch {

View File

@@ -10,6 +10,7 @@ import (
"net/http" "net/http"
// nolint:gosec // nolint:gosec
_ "net/http/pprof" _ "net/http/pprof"
"net/netip"
"time" "time"
"go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc" "go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc"
@@ -62,10 +63,10 @@ var (
Use: "run", Use: "run",
Short: "start NetBird Signal Server daemon", Short: "start NetBird Signal Server daemon",
SilenceUsage: true, SilenceUsage: true,
PreRunE: func(cmd *cobra.Command, args []string) error { PreRun: func(cmd *cobra.Command, args []string) {
err := util.InitLog(logLevel, logFile) err := util.InitLog(logLevel, logFile)
if err != nil { if err != nil {
return fmt.Errorf("failed initializing log: %w", err) log.Fatalf("failed initializing log %v", err)
} }
flag.Parse() flag.Parse()
@@ -86,8 +87,6 @@ var (
signalPort = 80 signalPort = 80
} }
} }
return nil
}, },
RunE: func(cmd *cobra.Command, args []string) error { RunE: func(cmd *cobra.Command, args []string) error {
flag.Parse() flag.Parse()
@@ -255,7 +254,7 @@ func startServerWithCertManager(certManager *autocert.Manager, grpcRootHandler h
} }
func grpcHandlerFunc(grpcServer *grpc.Server, meter metric.Meter) http.Handler { func grpcHandlerFunc(grpcServer *grpc.Server, meter metric.Meter) http.Handler {
wsProxy := wsproxyserver.New(grpcServer, wsproxyserver.WithOTelMeter(meter)) wsProxy := wsproxyserver.New(netip.AddrPortFrom(netip.AddrFrom4([4]byte{127, 0, 0, 1}), legacyGRPCPort), wsproxyserver.WithOTelMeter(meter))
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch { switch {

View File

@@ -2,41 +2,42 @@ package server
import ( import (
"context" "context"
"errors"
"io" "io"
"net" "net"
"net/http" "net/http"
"net/netip"
"sync" "sync"
"time" "time"
"github.com/coder/websocket" "github.com/coder/websocket"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"golang.org/x/net/http2"
"github.com/netbirdio/netbird/util/wsproxy" "github.com/netbirdio/netbird/util/wsproxy"
) )
const ( const (
bufferSize = 32 * 1024 dialTimeout = 10 * time.Second
ioTimeout = 5 * time.Second bufferSize = 32 * 1024
) )
// Config contains the configuration for the WebSocket proxy. // Config contains the configuration for the WebSocket proxy.
type Config struct { type Config struct {
Handler http.Handler LocalGRPCAddr netip.AddrPort
Path string Path string
MetricsRecorder MetricsRecorder MetricsRecorder MetricsRecorder
} }
// Proxy handles WebSocket to gRPC handler proxying. // Proxy handles WebSocket to TCP proxying for gRPC connections.
type Proxy struct { type Proxy struct {
config Config config Config
metrics MetricsRecorder metrics MetricsRecorder
} }
// New creates a new WebSocket proxy instance with optional configuration // New creates a new WebSocket proxy instance with optional configuration
func New(handler http.Handler, opts ...Option) *Proxy { func New(localGRPCAddr netip.AddrPort, opts ...Option) *Proxy {
config := Config{ config := Config{
Handler: handler, LocalGRPCAddr: localGRPCAddr,
Path: wsproxy.ProxyPath, Path: wsproxy.ProxyPath,
MetricsRecorder: NoOpMetricsRecorder{}, // Default to no-op MetricsRecorder: NoOpMetricsRecorder{}, // Default to no-op
} }
@@ -62,7 +63,7 @@ func (p *Proxy) handleWebSocket(w http.ResponseWriter, r *http.Request) {
p.metrics.RecordConnection(ctx) p.metrics.RecordConnection(ctx)
defer p.metrics.RecordDisconnection(ctx) defer p.metrics.RecordDisconnection(ctx)
log.Debugf("WebSocket proxy handling connection from %s, forwarding to internal gRPC handler", r.RemoteAddr) log.Debugf("WebSocket proxy handling connection from %s, forwarding to %s", r.RemoteAddr, p.config.LocalGRPCAddr)
acceptOptions := &websocket.AcceptOptions{ acceptOptions := &websocket.AcceptOptions{
OriginPatterns: []string{"*"}, OriginPatterns: []string{"*"},
} }
@@ -74,41 +75,71 @@ func (p *Proxy) handleWebSocket(w http.ResponseWriter, r *http.Request) {
return return
} }
defer func() { defer func() {
_ = wsConn.Close(websocket.StatusNormalClosure, "") if err := wsConn.Close(websocket.StatusNormalClosure, ""); err != nil {
log.Debugf("Failed to close WebSocket: %v", err)
}
}() }()
clientConn, serverConn := net.Pipe() log.Debugf("WebSocket proxy attempting to connect to local gRPC at %s", p.config.LocalGRPCAddr)
tcpConn, err := net.DialTimeout("tcp", p.config.LocalGRPCAddr.String(), dialTimeout)
if err != nil {
p.metrics.RecordError(ctx, "tcp_dial_failed")
log.Warnf("Failed to connect to local gRPC server at %s: %v", p.config.LocalGRPCAddr, err)
if err := wsConn.Close(websocket.StatusInternalError, "Backend unavailable"); err != nil {
log.Debugf("Failed to close WebSocket after connection failure: %v", err)
}
return
}
defer func() { defer func() {
_ = clientConn.Close() if err := tcpConn.Close(); err != nil {
_ = serverConn.Close() log.Debugf("Failed to close TCP connection: %v", err)
}
}() }()
log.Debugf("WebSocket proxy established: %s -> gRPC handler", r.RemoteAddr) log.Debugf("WebSocket proxy established: client %s -> local gRPC %s", r.RemoteAddr, p.config.LocalGRPCAddr)
go func() { p.proxyData(ctx, wsConn, tcpConn)
(&http2.Server{}).ServeConn(serverConn, &http2.ServeConnOpts{
Context: ctx,
Handler: p.config.Handler,
})
}()
p.proxyData(ctx, wsConn, clientConn, r.RemoteAddr)
} }
func (p *Proxy) proxyData(ctx context.Context, wsConn *websocket.Conn, pipeConn net.Conn, clientAddr string) { func (p *Proxy) proxyData(ctx context.Context, wsConn *websocket.Conn, tcpConn net.Conn) {
proxyCtx, cancel := context.WithCancel(ctx) proxyCtx, cancel := context.WithCancel(ctx)
defer cancel() defer cancel()
var wg sync.WaitGroup var wg sync.WaitGroup
wg.Add(2) wg.Add(2)
go p.wsToPipe(proxyCtx, cancel, &wg, wsConn, pipeConn, clientAddr) go p.wsToTCP(proxyCtx, cancel, &wg, wsConn, tcpConn)
go p.pipeToWS(proxyCtx, cancel, &wg, wsConn, pipeConn, clientAddr) go p.tcpToWS(proxyCtx, cancel, &wg, wsConn, tcpConn)
wg.Wait() done := make(chan struct{})
go func() {
wg.Wait()
close(done)
}()
select {
case <-done:
log.Tracef("Proxy data transfer completed, both goroutines terminated")
case <-proxyCtx.Done():
log.Tracef("Proxy data transfer cancelled, forcing connection closure")
if err := wsConn.Close(websocket.StatusGoingAway, "proxy cancelled"); err != nil {
log.Tracef("Error closing WebSocket during cancellation: %v", err)
}
if err := tcpConn.Close(); err != nil {
log.Tracef("Error closing TCP connection during cancellation: %v", err)
}
select {
case <-done:
log.Tracef("Goroutines terminated after forced connection closure")
case <-time.After(2 * time.Second):
log.Tracef("Goroutines did not terminate within timeout after connection closure")
}
}
} }
func (p *Proxy) wsToPipe(ctx context.Context, cancel context.CancelFunc, wg *sync.WaitGroup, wsConn *websocket.Conn, pipeConn net.Conn, clientAddr string) { func (p *Proxy) wsToTCP(ctx context.Context, cancel context.CancelFunc, wg *sync.WaitGroup, wsConn *websocket.Conn, tcpConn net.Conn) {
defer wg.Done() defer wg.Done()
defer cancel() defer cancel()
@@ -117,73 +148,80 @@ func (p *Proxy) wsToPipe(ctx context.Context, cancel context.CancelFunc, wg *syn
if err != nil { if err != nil {
switch { switch {
case ctx.Err() != nil: case ctx.Err() != nil:
log.Debugf("WebSocket from %s terminating due to context cancellation", clientAddr) log.Debugf("wsToTCP goroutine terminating due to context cancellation")
case websocket.CloseStatus(err) != -1: case websocket.CloseStatus(err) == websocket.StatusNormalClosure:
log.Debugf("WebSocket from %s disconnected", clientAddr) log.Debugf("WebSocket closed normally")
default: default:
p.metrics.RecordError(ctx, "websocket_read_error") p.metrics.RecordError(ctx, "websocket_read_error")
log.Debugf("WebSocket read error from %s: %v", clientAddr, err) log.Errorf("WebSocket read error: %v", err)
} }
return return
} }
if msgType != websocket.MessageBinary { if msgType != websocket.MessageBinary {
log.Warnf("Unexpected WebSocket message type from %s: %v", clientAddr, msgType) log.Warnf("Unexpected WebSocket message type: %v", msgType)
continue continue
} }
if ctx.Err() != nil { if ctx.Err() != nil {
log.Tracef("wsToPipe goroutine terminating due to context cancellation before pipe write") log.Tracef("wsToTCP goroutine terminating due to context cancellation before TCP write")
return return
} }
if err := pipeConn.SetWriteDeadline(time.Now().Add(ioTimeout)); err != nil { if err := tcpConn.SetWriteDeadline(time.Now().Add(5 * time.Second)); err != nil {
log.Debugf("Failed to set pipe write deadline: %v", err) log.Debugf("Failed to set TCP write deadline: %v", err)
} }
n, err := pipeConn.Write(data) n, err := tcpConn.Write(data)
if err != nil { if err != nil {
p.metrics.RecordError(ctx, "pipe_write_error") p.metrics.RecordError(ctx, "tcp_write_error")
log.Warnf("Pipe write error for %s: %v", clientAddr, err) log.Errorf("TCP write error: %v", err)
return return
} }
p.metrics.RecordBytesTransferred(ctx, "ws_to_grpc", int64(n)) p.metrics.RecordBytesTransferred(ctx, "ws_to_tcp", int64(n))
} }
} }
func (p *Proxy) pipeToWS(ctx context.Context, cancel context.CancelFunc, wg *sync.WaitGroup, wsConn *websocket.Conn, pipeConn net.Conn, clientAddr string) { func (p *Proxy) tcpToWS(ctx context.Context, cancel context.CancelFunc, wg *sync.WaitGroup, wsConn *websocket.Conn, tcpConn net.Conn) {
defer wg.Done() defer wg.Done()
defer cancel() defer cancel()
buf := make([]byte, bufferSize) buf := make([]byte, bufferSize)
for { for {
n, err := pipeConn.Read(buf) if err := tcpConn.SetReadDeadline(time.Now().Add(5 * time.Second)); err != nil {
log.Debugf("Failed to set TCP read deadline: %v", err)
}
n, err := tcpConn.Read(buf)
if err != nil { if err != nil {
if ctx.Err() != nil { if ctx.Err() != nil {
log.Tracef("pipeToWS goroutine terminating due to context cancellation") log.Tracef("tcpToWS goroutine terminating due to context cancellation")
return return
} }
var netErr net.Error
if errors.As(err, &netErr) && netErr.Timeout() {
continue
}
if err != io.EOF { if err != io.EOF {
log.Debugf("Pipe read error for %s: %v", clientAddr, err) log.Errorf("TCP read error: %v", err)
} }
return return
} }
if ctx.Err() != nil { if ctx.Err() != nil {
log.Tracef("pipeToWS goroutine terminating due to context cancellation before WebSocket write") log.Tracef("tcpToWS goroutine terminating due to context cancellation before WebSocket write")
return return
} }
if n > 0 { if err := wsConn.Write(ctx, websocket.MessageBinary, buf[:n]); err != nil {
if err := wsConn.Write(ctx, websocket.MessageBinary, buf[:n]); err != nil { p.metrics.RecordError(ctx, "websocket_write_error")
p.metrics.RecordError(ctx, "websocket_write_error") log.Errorf("WebSocket write error: %v", err)
log.Warnf("WebSocket write error for %s: %v", clientAddr, err) return
return
}
p.metrics.RecordBytesTransferred(ctx, "grpc_to_ws", int64(n))
} }
p.metrics.RecordBytesTransferred(ctx, "tcp_to_ws", int64(n))
} }
} }

View File

@@ -1,13 +1,9 @@
package version package version
import ( import "golang.org/x/sys/windows/registry"
"golang.org/x/sys/windows/registry"
"runtime"
)
const ( const (
urlWinExe = "https://pkgs.netbird.io/windows/x64" urlWinExe = "https://pkgs.netbird.io/windows/x64"
urlWinExeArm = "https://pkgs.netbird.io/windows/arm64"
) )
var regKeyAppPath = "SOFTWARE\\Microsoft\\Windows\\CurrentVersion\\App Paths\\Netbird" var regKeyAppPath = "SOFTWARE\\Microsoft\\Windows\\CurrentVersion\\App Paths\\Netbird"
@@ -15,14 +11,9 @@ var regKeyAppPath = "SOFTWARE\\Microsoft\\Windows\\CurrentVersion\\App Paths\\Ne
// DownloadUrl return with the proper download link // DownloadUrl return with the proper download link
func DownloadUrl() string { func DownloadUrl() string {
_, err := registry.OpenKey(registry.LOCAL_MACHINE, regKeyAppPath, registry.QUERY_VALUE) _, err := registry.OpenKey(registry.LOCAL_MACHINE, regKeyAppPath, registry.QUERY_VALUE)
if err != nil { if err == nil {
return urlWinExe
} else {
return downloadURL return downloadURL
} }
url := urlWinExe
if runtime.GOARCH == "arm64" {
url = urlWinExeArm
}
return url
} }