Mirror of https://github.com/netbirdio/netbird.git (synced 2026-04-05 08:54:11 -04:00)

Compare commits: v0.59.4...snyk-fix-d (1 commit)

Commit: ea465680af
@@ -29,8 +29,7 @@ func Backoff(ctx context.Context) backoff.BackOff {
 // The component parameter specifies the WebSocket proxy component path (e.g., "/management", "/signal").
 func CreateConnection(ctx context.Context, addr string, tlsEnabled bool, component string) (*grpc.ClientConn, error) {
 	transportOption := grpc.WithTransportCredentials(insecure.NewCredentials())
-	// for js, the outer websocket layer takes care of tls
-	if tlsEnabled && runtime.GOOS != "js" {
+	if tlsEnabled {
 		certPool, err := x509.SystemCertPool()
 		if err != nil || certPool == nil {
 			log.Debugf("System cert pool not available; falling back to embedded cert, error: %v", err)
@@ -38,7 +37,9 @@ func CreateConnection(ctx context.Context, addr string, tlsEnabled bool, compone
 		}
 
 		transportOption = grpc.WithTransportCredentials(credentials.NewTLS(&tls.Config{
-			RootCAs: certPool,
+			// for js, outer websocket layer takes care of tls verification via WithCustomDialer
+			InsecureSkipVerify: runtime.GOOS == "js",
+			RootCAs:            certPool,
 		}))
 	}
 
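Note on the two hunks above: v0.59.4 skipped the gRPC TLS credentials entirely on js (`runtime.GOOS != "js"`), while the compared branch always builds a `tls.Config` and sets `InsecureSkipVerify` on js, deferring verification to the outer WebSocket layer. A minimal sketch of the resulting credential selection, assuming the imports used in the file (an illustration, not the exact NetBird source):

package sketch

import (
	"crypto/tls"
	"crypto/x509"
	"runtime"

	"google.golang.org/grpc"
	"google.golang.org/grpc/credentials"
	"google.golang.org/grpc/credentials/insecure"
)

// transportCredentials picks the dial option the way the new side of the
// diff does: plaintext when TLS is off, otherwise TLS with verification
// delegated to the outer websocket dialer on js/wasm builds.
func transportCredentials(tlsEnabled bool, certPool *x509.CertPool) grpc.DialOption {
	if !tlsEnabled {
		return grpc.WithTransportCredentials(insecure.NewCredentials())
	}
	return grpc.WithTransportCredentials(credentials.NewTLS(&tls.Config{
		InsecureSkipVerify: runtime.GOOS == "js", // verified by the websocket layer instead
		RootCAs:            certPool,
	}))
}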
@@ -73,44 +73,6 @@ func (c *KernelConfigurer) UpdatePeer(peerKey string, allowedIps []netip.Prefix,
 	return nil
 }
 
-func (c *KernelConfigurer) RemoveEndpointAddress(peerKey string) error {
-	peerKeyParsed, err := wgtypes.ParseKey(peerKey)
-	if err != nil {
-		return err
-	}
-
-	// Get the existing peer to preserve its allowed IPs
-	existingPeer, err := c.getPeer(c.deviceName, peerKey)
-	if err != nil {
-		return fmt.Errorf("get peer: %w", err)
-	}
-
-	removePeerCfg := wgtypes.PeerConfig{
-		PublicKey: peerKeyParsed,
-		Remove:    true,
-	}
-
-	if err := c.configure(wgtypes.Config{Peers: []wgtypes.PeerConfig{removePeerCfg}}); err != nil {
-		return fmt.Errorf(`error removing peer %s from interface %s: %w`, peerKey, c.deviceName, err)
-	}
-
-	//Re-add the peer without the endpoint but same AllowedIPs
-	reAddPeerCfg := wgtypes.PeerConfig{
-		PublicKey:         peerKeyParsed,
-		AllowedIPs:        existingPeer.AllowedIPs,
-		ReplaceAllowedIPs: true,
-	}
-
-	if err := c.configure(wgtypes.Config{Peers: []wgtypes.PeerConfig{reAddPeerCfg}}); err != nil {
-		return fmt.Errorf(
-			`error re-adding peer %s to interface %s with allowed IPs %v: %w`,
-			peerKey, c.deviceName, existingPeer.AllowedIPs, err,
-		)
-	}
-
-	return nil
-}
-
 func (c *KernelConfigurer) RemovePeer(peerKey string) error {
 	peerKeyParsed, err := wgtypes.ParseKey(peerKey)
 	if err != nil {
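The deleted helper above works around a WireGuard API limitation: `wgtypes.PeerConfig` has no field to unset a peer's endpoint in place, so the code removes the peer and re-adds it with the same AllowedIPs. A minimal sketch of that remove-then-re-add pattern using `wgctrl` directly (a hypothetical standalone helper; NetBird wraps this in its configurer):

package sketch

import (
	"fmt"
	"net"

	"golang.zx2c4.com/wireguard/wgctrl"
	"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)

// clearEndpoint drops a peer's endpoint while keeping its allowed IPs.
func clearEndpoint(device string, key wgtypes.Key, allowedIPs []net.IPNet) error {
	client, err := wgctrl.New()
	if err != nil {
		return fmt.Errorf("open wgctrl: %w", err)
	}
	defer client.Close()

	// Step 1: remove the peer entirely; the kernel API offers no way to
	// clear only the endpoint.
	remove := wgtypes.Config{Peers: []wgtypes.PeerConfig{{PublicKey: key, Remove: true}}}
	if err := client.ConfigureDevice(device, remove); err != nil {
		return fmt.Errorf("remove peer: %w", err)
	}

	// Step 2: re-add the peer with the AllowedIPs captured beforehand.
	readd := wgtypes.Config{Peers: []wgtypes.PeerConfig{{
		PublicKey:         key,
		AllowedIPs:        allowedIPs,
		ReplaceAllowedIPs: true,
	}}}
	return client.ConfigureDevice(device, readd)
}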
@@ -106,67 +106,6 @@ func (c *WGUSPConfigurer) UpdatePeer(peerKey string, allowedIps []netip.Prefix,
 	return nil
 }
 
-func (c *WGUSPConfigurer) RemoveEndpointAddress(peerKey string) error {
-	peerKeyParsed, err := wgtypes.ParseKey(peerKey)
-	if err != nil {
-		return fmt.Errorf("parse peer key: %w", err)
-	}
-
-	ipcStr, err := c.device.IpcGet()
-	if err != nil {
-		return fmt.Errorf("get IPC config: %w", err)
-	}
-
-	// Parse current status to get allowed IPs for the peer
-	stats, err := parseStatus(c.deviceName, ipcStr)
-	if err != nil {
-		return fmt.Errorf("parse IPC config: %w", err)
-	}
-
-	var allowedIPs []net.IPNet
-	found := false
-	for _, peer := range stats.Peers {
-		if peer.PublicKey == peerKey {
-			allowedIPs = peer.AllowedIPs
-			found = true
-			break
-		}
-	}
-	if !found {
-		return fmt.Errorf("peer %s not found", peerKey)
-	}
-
-	// remove the peer from the WireGuard configuration
-	peer := wgtypes.PeerConfig{
-		PublicKey: peerKeyParsed,
-		Remove:    true,
-	}
-
-	config := wgtypes.Config{
-		Peers: []wgtypes.PeerConfig{peer},
-	}
-	if ipcErr := c.device.IpcSet(toWgUserspaceString(config)); ipcErr != nil {
-		return fmt.Errorf("failed to remove peer: %s", ipcErr)
-	}
-
-	// Build the peer config
-	peer = wgtypes.PeerConfig{
-		PublicKey:         peerKeyParsed,
-		ReplaceAllowedIPs: true,
-		AllowedIPs:        allowedIPs,
-	}
-
-	config = wgtypes.Config{
-		Peers: []wgtypes.PeerConfig{peer},
-	}
-
-	if err := c.device.IpcSet(toWgUserspaceString(config)); err != nil {
-		return fmt.Errorf("remove endpoint address: %w", err)
-	}
-
-	return nil
-}
-
 func (c *WGUSPConfigurer) RemovePeer(peerKey string) error {
 	peerKeyParsed, err := wgtypes.ParseKey(peerKey)
 	if err != nil {
@@ -21,5 +21,4 @@ type WGConfigurer interface {
 	GetStats() (map[string]configurer.WGStats, error)
 	FullStats() (*configurer.Stats, error)
 	LastActivities() map[string]monotime.Time
-	RemoveEndpointAddress(peerKey string) error
 }
@@ -148,17 +148,6 @@ func (w *WGIface) UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAliv
 	return w.configurer.UpdatePeer(peerKey, allowedIps, keepAlive, endpoint, preSharedKey)
 }
 
-func (w *WGIface) RemoveEndpointAddress(peerKey string) error {
-	w.mu.Lock()
-	defer w.mu.Unlock()
-	if w.configurer == nil {
-		return ErrIfaceNotFound
-	}
-
-	log.Debugf("Removing endpoint address: %s", peerKey)
-	return w.configurer.RemoveEndpointAddress(peerKey)
-}
-
 // RemovePeer removes a Wireguard Peer from the interface iface
 func (w *WGIface) RemovePeer(peerKey string) error {
 	w.mu.Lock()
@@ -240,17 +240,15 @@ func (r *registryConfigurator) addDNSMatchPolicy(domains []string, ip netip.Addr
 	// if the gpo key is present, we need to put our DNS settings there, otherwise our config might be ignored
 	// see https://learn.microsoft.com/en-us/openspecs/windows_protocols/ms-gpnrpt/8cc31cb9-20cb-4140-9e85-3e08703b4745
 	for i, domain := range domains {
-		localPath := fmt.Sprintf("%s-%d", dnsPolicyConfigMatchPath, i)
-		gpoPath := fmt.Sprintf("%s-%d", gpoDnsPolicyConfigMatchPath, i)
 
 		singleDomain := []string{domain}
 
-		if err := r.configureDNSPolicy(localPath, singleDomain, ip); err != nil {
+		if err := r.configureDNSPolicy(dnsPolicyConfigMatchPath, singleDomain, ip); err != nil {
 			return i, fmt.Errorf("configure DNS Local policy for domain %s: %w", domain, err)
 		}
 
 		if r.gpo {
-			if err := r.configureDNSPolicy(gpoPath, singleDomain, ip); err != nil {
+			if err := r.configureDNSPolicy(gpoDnsPolicyConfigMatchPath, singleDomain, ip); err != nil {
 				return i, fmt.Errorf("configure gpo DNS policy: %w", err)
 			}
 		}
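Worth noting in the hunk above: v0.59.4 derived a distinct registry key per domain by suffixing the base path with the loop index, while the compared branch writes every domain's policy to the same base path, which presumably lets later iterations overwrite earlier ones. A tiny sketch of the per-domain scheme the old side used (illustrative only; the base-path constants live elsewhere in the file):

package sketch

import "fmt"

// perDomainPaths mirrors the old side's naming: one NRPT rule key per
// domain, e.g. basePath-0, basePath-1, ...
func perDomainPaths(basePath string, domains []string) []string {
	paths := make([]string, 0, len(domains))
	for i := range domains {
		paths = append(paths, fmt.Sprintf("%s-%d", basePath, i))
	}
	return paths
}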
@@ -1,78 +0,0 @@
-package dnsfwd
-
-import (
-	"net/netip"
-	"slices"
-	"strings"
-	"sync"
-
-	"github.com/miekg/dns"
-)
-
-type cache struct {
-	mu      sync.RWMutex
-	records map[string]*cacheEntry
-}
-
-type cacheEntry struct {
-	ip4Addrs []netip.Addr
-	ip6Addrs []netip.Addr
-}
-
-func newCache() *cache {
-	return &cache{
-		records: make(map[string]*cacheEntry),
-	}
-}
-
-func (c *cache) get(domain string, reqType uint16) ([]netip.Addr, bool) {
-	c.mu.RLock()
-	defer c.mu.RUnlock()
-
-	entry, exists := c.records[normalizeDomain(domain)]
-	if !exists {
-		return nil, false
-	}
-
-	switch reqType {
-	case dns.TypeA:
-		return slices.Clone(entry.ip4Addrs), true
-	case dns.TypeAAAA:
-		return slices.Clone(entry.ip6Addrs), true
-	default:
-		return nil, false
-	}
-}
-
-func (c *cache) set(domain string, reqType uint16, addrs []netip.Addr) {
-	c.mu.Lock()
-	defer c.mu.Unlock()
-	norm := normalizeDomain(domain)
-	entry, exists := c.records[norm]
-	if !exists {
-		entry = &cacheEntry{}
-		c.records[norm] = entry
-	}
-
-	switch reqType {
-	case dns.TypeA:
-		entry.ip4Addrs = slices.Clone(addrs)
-	case dns.TypeAAAA:
-		entry.ip6Addrs = slices.Clone(addrs)
-	}
-}
-
-// unset removes cached entries for the given domain and request type.
-func (c *cache) unset(domain string) {
-	c.mu.Lock()
-	defer c.mu.Unlock()
-	delete(c.records, normalizeDomain(domain))
-}
-
-// normalizeDomain converts an input domain into a canonical form used as cache key:
-// lowercase and fully-qualified (with trailing dot).
-func normalizeDomain(domain string) string {
-	// dns.Fqdn ensures trailing dot; ToLower for consistent casing
-	return dns.Fqdn(strings.ToLower(domain))
-}
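Since the deleted file above defines the whole cache API, a short usage sketch may help when reading the forwarder hunks below (illustrative; it only uses functions shown in the deleted code):

package dnsfwd

import (
	"net/netip"

	"github.com/miekg/dns"
)

func exampleCacheUsage() {
	c := newCache()

	// Keys are normalized, so casing and trailing dots do not matter.
	c.set("ExAmPlE.CoM", dns.TypeA, []netip.Addr{netip.MustParseAddr("1.2.3.4")})

	if addrs, ok := c.get("example.com.", dns.TypeA); ok {
		_ = addrs // cloned slice; safe to mutate without touching the cache
	}

	c.unset("example.com") // drops both A and AAAA entries for the domain
}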
@@ -1,86 +0,0 @@
-package dnsfwd
-
-import (
-	"net/netip"
-	"testing"
-)
-
-func mustAddr(t *testing.T, s string) netip.Addr {
-	t.Helper()
-	a, err := netip.ParseAddr(s)
-	if err != nil {
-		t.Fatalf("parse addr %s: %v", s, err)
-	}
-	return a
-}
-
-func TestCacheNormalization(t *testing.T) {
-	c := newCache()
-
-	// Mixed case, without trailing dot
-	domainInput := "ExAmPlE.CoM"
-	ipv4 := []netip.Addr{mustAddr(t, "1.2.3.4")}
-	c.set(domainInput, 1 /* dns.TypeA */, ipv4)
-
-	// Lookup with lower, with trailing dot
-	if got, ok := c.get("example.com.", 1); !ok || len(got) != 1 || got[0].String() != "1.2.3.4" {
-		t.Fatalf("expected cached IPv4 result via normalized key, got=%v ok=%v", got, ok)
-	}
-
-	// Lookup with different casing again
-	if got, ok := c.get("EXAMPLE.COM", 1); !ok || len(got) != 1 || got[0].String() != "1.2.3.4" {
-		t.Fatalf("expected cached IPv4 result via different casing, got=%v ok=%v", got, ok)
-	}
-}
-
-func TestCacheSeparateTypes(t *testing.T) {
-	c := newCache()
-
-	domain := "test.local"
-	ipv4 := []netip.Addr{mustAddr(t, "10.0.0.1")}
-	ipv6 := []netip.Addr{mustAddr(t, "2001:db8::1")}
-
-	c.set(domain, 1 /* A */, ipv4)
-	c.set(domain, 28 /* AAAA */, ipv6)
-
-	got4, ok4 := c.get(domain, 1)
-	if !ok4 || len(got4) != 1 || got4[0] != ipv4[0] {
-		t.Fatalf("expected A record from cache, got=%v ok=%v", got4, ok4)
-	}
-
-	got6, ok6 := c.get(domain, 28)
-	if !ok6 || len(got6) != 1 || got6[0] != ipv6[0] {
-		t.Fatalf("expected AAAA record from cache, got=%v ok=%v", got6, ok6)
-	}
-}
-
-func TestCacheCloneOnGetAndSet(t *testing.T) {
-	c := newCache()
-	domain := "clone.test"
-
-	src := []netip.Addr{mustAddr(t, "8.8.8.8")}
-	c.set(domain, 1, src)
-
-	// Mutate source slice; cache should be unaffected
-	src[0] = mustAddr(t, "9.9.9.9")
-
-	got, ok := c.get(domain, 1)
-	if !ok || len(got) != 1 || got[0].String() != "8.8.8.8" {
-		t.Fatalf("expected cached value to be independent of source slice, got=%v ok=%v", got, ok)
-	}
-
-	// Mutate returned slice; internal cache should remain unchanged
-	got[0] = mustAddr(t, "4.4.4.4")
-	got2, ok2 := c.get(domain, 1)
-	if !ok2 || len(got2) != 1 || got2[0].String() != "8.8.8.8" {
-		t.Fatalf("expected returned slice to be a clone, got=%v ok=%v", got2, ok2)
-	}
-}
-
-func TestCacheMiss(t *testing.T) {
-	c := newCache()
-	if got, ok := c.get("missing.example", 1); ok || got != nil {
-		t.Fatalf("expected cache miss, got=%v ok=%v", got, ok)
-	}
-}
@@ -46,7 +46,6 @@ type DNSForwarder struct {
 	fwdEntries []*ForwarderEntry
 	firewall   firewaller
 	resolver   resolver
-	cache      *cache
 }
 
 func NewDNSForwarder(listenAddress string, ttl uint32, firewall firewaller, statusRecorder *peer.Status) *DNSForwarder {
@@ -57,7 +56,6 @@ func NewDNSForwarder(listenAddress string, ttl uint32, firewall firewaller, stat
 		firewall:       firewall,
 		statusRecorder: statusRecorder,
 		resolver:       net.DefaultResolver,
-		cache:          newCache(),
 	}
 }
 
@@ -105,39 +103,10 @@ func (f *DNSForwarder) UpdateDomains(entries []*ForwarderEntry) {
 	f.mutex.Lock()
 	defer f.mutex.Unlock()
 
-	// remove cache entries for domains that no longer appear
-	f.removeStaleCacheEntries(f.fwdEntries, entries)
-
 	f.fwdEntries = entries
 	log.Debugf("Updated DNS forwarder with %d domains", len(entries))
 }
 
-// removeStaleCacheEntries unsets cache items for domains that were present
-// in the old list but not present in the new list.
-func (f *DNSForwarder) removeStaleCacheEntries(oldEntries, newEntries []*ForwarderEntry) {
-	if f.cache == nil {
-		return
-	}
-
-	newSet := make(map[string]struct{}, len(newEntries))
-	for _, e := range newEntries {
-		if e == nil {
-			continue
-		}
-		newSet[e.Domain.PunycodeString()] = struct{}{}
-	}
-
-	for _, e := range oldEntries {
-		if e == nil {
-			continue
-		}
-		pattern := e.Domain.PunycodeString()
-		if _, ok := newSet[pattern]; !ok {
-			f.cache.unset(pattern)
-		}
-	}
-}
-
 func (f *DNSForwarder) Close(ctx context.Context) error {
 	var result *multierror.Error
 
@@ -202,7 +171,6 @@ func (f *DNSForwarder) handleDNSQuery(w dns.ResponseWriter, query *dns.Msg) *dns
 
 	f.updateInternalState(ips, mostSpecificResId, matchingEntries)
 	f.addIPsToResponse(resp, domain, ips)
-	f.cache.set(domain, question.Qtype, ips)
 
 	return resp
 }
@@ -314,69 +282,29 @@ func (f *DNSForwarder) setResponseCodeForNotFound(ctx context.Context, resp *dns
 	resp.Rcode = dns.RcodeSuccess
 }
 
-// handleDNSError processes DNS lookup errors and sends an appropriate error response.
-func (f *DNSForwarder) handleDNSError(
-	ctx context.Context,
-	w dns.ResponseWriter,
-	question dns.Question,
-	resp *dns.Msg,
-	domain string,
-	err error,
-) {
-	// Default to SERVFAIL; override below when appropriate.
-	resp.Rcode = dns.RcodeServerFailure
-
-	qType := question.Qtype
-	qTypeName := dns.TypeToString[qType]
-
-	// Prefer typed DNS errors; fall back to generic logging otherwise.
+// handleDNSError processes DNS lookup errors and sends an appropriate error response
+func (f *DNSForwarder) handleDNSError(ctx context.Context, w dns.ResponseWriter, question dns.Question, resp *dns.Msg, domain string, err error) {
 	var dnsErr *net.DNSError
-	if !errors.As(err, &dnsErr) {
-		log.Warnf(errResolveFailed, domain, err)
-		if writeErr := w.WriteMsg(resp); writeErr != nil {
-			log.Errorf("failed to write failure DNS response: %v", writeErr)
-		}
-		return
-	}
-
-	// NotFound: set NXDOMAIN / appropriate code via helper.
-	if dnsErr.IsNotFound {
-		f.setResponseCodeForNotFound(ctx, resp, domain, qType)
-		if writeErr := w.WriteMsg(resp); writeErr != nil {
-			log.Errorf("failed to write failure DNS response: %v", writeErr)
-		}
-		f.cache.set(domain, question.Qtype, nil)
-		return
-	}
-
-	// Upstream failed but we might have a cached answer—serve it if present.
-	if ips, ok := f.cache.get(domain, qType); ok {
-		if len(ips) > 0 {
-			log.Debugf("serving cached DNS response after upstream failure: domain=%s type=%s", domain, qTypeName)
-			f.addIPsToResponse(resp, domain, ips)
-			resp.Rcode = dns.RcodeSuccess
-			if writeErr := w.WriteMsg(resp); writeErr != nil {
-				log.Errorf("failed to write cached DNS response: %v", writeErr)
-			}
-		} else { // send NXDOMAIN / appropriate code if cache is empty
-			f.setResponseCodeForNotFound(ctx, resp, domain, qType)
-			if writeErr := w.WriteMsg(resp); writeErr != nil {
-				log.Errorf("failed to write failure DNS response: %v", writeErr)
-			}
-		}
-		return
-	}
-
-	// No cache. Log with or without the server field for more context.
-	if dnsErr.Server != "" {
-		log.Warnf("failed to resolve: type=%s domain=%s server=%s: %v", qTypeName, domain, dnsErr.Server, err)
-	} else {
-		log.Warnf(errResolveFailed, domain, err)
-	}
-
-	// Write final failure response.
-	if writeErr := w.WriteMsg(resp); writeErr != nil {
-		log.Errorf("failed to write failure DNS response: %v", writeErr)
+	switch {
+	case errors.As(err, &dnsErr):
+		resp.Rcode = dns.RcodeServerFailure
+		if dnsErr.IsNotFound {
+			f.setResponseCodeForNotFound(ctx, resp, domain, question.Qtype)
+		}
+
+		if dnsErr.Server != "" {
+			log.Warnf("failed to resolve query for type=%s domain=%s server=%s: %v", dns.TypeToString[question.Qtype], domain, dnsErr.Server, err)
+		} else {
+			log.Warnf(errResolveFailed, domain, err)
+		}
+	default:
+		resp.Rcode = dns.RcodeServerFailure
+		log.Warnf(errResolveFailed, domain, err)
+	}
+
+	if err := w.WriteMsg(resp); err != nil {
+		log.Errorf("failed to write failure DNS response: %v", err)
 	}
 }
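Both versions of the function above triage lookup failures by unwrapping to a typed `*net.DNSError` with `errors.As`; the restored version folds everything into one `switch` and writes the response once at the end, while the deleted version returned early per case and consulted the cache. A standalone sketch of that triage pattern using only the standard library:

package sketch

import (
	"errors"
	"fmt"
	"net"
)

// classify shows the error inspection both versions rely on: unwrap to
// *net.DNSError, then check IsNotFound and the Server field.
func classify(err error) string {
	var dnsErr *net.DNSError
	switch {
	case errors.As(err, &dnsErr) && dnsErr.IsNotFound:
		return "nxdomain"
	case errors.As(err, &dnsErr) && dnsErr.Server != "":
		return fmt.Sprintf("upstream %s failed", dnsErr.Server)
	case errors.As(err, &dnsErr):
		return "dns failure"
	default:
		return "generic failure"
	}
}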
@@ -648,95 +648,6 @@ func TestDNSForwarder_TCPTruncation(t *testing.T) {
 	assert.LessOrEqual(t, writtenResp.Len(), dns.MinMsgSize, "Response should fit in minimum UDP size")
 }
 
-// Ensures that when the first query succeeds and populates the cache,
-// a subsequent upstream failure still returns a successful response from cache.
-func TestDNSForwarder_ServeFromCacheOnUpstreamFailure(t *testing.T) {
-	mockResolver := &MockResolver{}
-	forwarder := NewDNSForwarder("127.0.0.1:0", 300, nil, &peer.Status{})
-	forwarder.resolver = mockResolver
-
-	d, err := domain.FromString("example.com")
-	require.NoError(t, err)
-	entries := []*ForwarderEntry{{Domain: d, ResID: "res-cache"}}
-	forwarder.UpdateDomains(entries)
-
-	ip := netip.MustParseAddr("1.2.3.4")
-
-	// First call resolves successfully and populates cache
-	mockResolver.On("LookupNetIP", mock.Anything, "ip4", dns.Fqdn("example.com")).
-		Return([]netip.Addr{ip}, nil).Once()
-
-	// Second call fails upstream; forwarder should serve from cache
-	mockResolver.On("LookupNetIP", mock.Anything, "ip4", dns.Fqdn("example.com")).
-		Return([]netip.Addr{}, &net.DNSError{Err: "temporary failure"}).Once()
-
-	// First query: populate cache
-	q1 := &dns.Msg{}
-	q1.SetQuestion(dns.Fqdn("example.com"), dns.TypeA)
-	w1 := &test.MockResponseWriter{}
-	resp1 := forwarder.handleDNSQuery(w1, q1)
-	require.NotNil(t, resp1)
-	require.Equal(t, dns.RcodeSuccess, resp1.Rcode)
-	require.Len(t, resp1.Answer, 1)
-
-	// Second query: serve from cache after upstream failure
-	q2 := &dns.Msg{}
-	q2.SetQuestion(dns.Fqdn("example.com"), dns.TypeA)
-	var writtenResp *dns.Msg
-	w2 := &test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { writtenResp = m; return nil }}
-	_ = forwarder.handleDNSQuery(w2, q2)
-
-	require.NotNil(t, writtenResp, "expected response to be written")
-	require.Equal(t, dns.RcodeSuccess, writtenResp.Rcode)
-	require.Len(t, writtenResp.Answer, 1)
-
-	mockResolver.AssertExpectations(t)
-}
-
-// Verifies that cache normalization works across casing and trailing dot variations.
-func TestDNSForwarder_CacheNormalizationCasingAndDot(t *testing.T) {
-	mockResolver := &MockResolver{}
-	forwarder := NewDNSForwarder("127.0.0.1:0", 300, nil, &peer.Status{})
-	forwarder.resolver = mockResolver
-
-	d, err := domain.FromString("ExAmPlE.CoM")
-	require.NoError(t, err)
-	entries := []*ForwarderEntry{{Domain: d, ResID: "res-norm"}}
-	forwarder.UpdateDomains(entries)
-
-	ip := netip.MustParseAddr("9.8.7.6")
-
-	// Initial resolution with mixed case to populate cache
-	mixedQuery := "ExAmPlE.CoM"
-	mockResolver.On("LookupNetIP", mock.Anything, "ip4", dns.Fqdn(strings.ToLower(mixedQuery))).
-		Return([]netip.Addr{ip}, nil).Once()
-
-	q1 := &dns.Msg{}
-	q1.SetQuestion(mixedQuery+".", dns.TypeA)
-	w1 := &test.MockResponseWriter{}
-	resp1 := forwarder.handleDNSQuery(w1, q1)
-	require.NotNil(t, resp1)
-	require.Equal(t, dns.RcodeSuccess, resp1.Rcode)
-	require.Len(t, resp1.Answer, 1)
-
-	// Subsequent query without dot and upper case should hit cache even if upstream fails
-	// Forwarder lowercases and uses the question name as-is (no trailing dot here)
-	mockResolver.On("LookupNetIP", mock.Anything, "ip4", strings.ToLower("EXAMPLE.COM")).
-		Return([]netip.Addr{}, &net.DNSError{Err: "temporary failure"}).Once()
-
-	q2 := &dns.Msg{}
-	q2.SetQuestion("EXAMPLE.COM", dns.TypeA)
-	var writtenResp *dns.Msg
-	w2 := &test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { writtenResp = m; return nil }}
-	_ = forwarder.handleDNSQuery(w2, q2)
-
-	require.NotNil(t, writtenResp)
-	require.Equal(t, dns.RcodeSuccess, writtenResp.Rcode)
-	require.Len(t, writtenResp.Answer, 1)
-
-	mockResolver.AssertExpectations(t)
-}
-
 func TestDNSForwarder_MultipleOverlappingPatterns(t *testing.T) {
 	// Test complex overlapping pattern scenarios
 	mockFirewall := &MockFirewall{}
@@ -105,10 +105,6 @@ type MockWGIface struct {
 	LastActivitiesFunc func() map[string]monotime.Time
 }
 
-func (m *MockWGIface) RemoveEndpointAddress(_ string) error {
-	return nil
-}
-
 func (m *MockWGIface) FullStats() (*configurer.Stats, error) {
 	return nil, fmt.Errorf("not implemented")
 }
@@ -28,7 +28,6 @@ type wgIfaceBase interface {
 	UpdateAddr(newAddr string) error
 	GetProxy() wgproxy.Proxy
 	UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error
-	RemoveEndpointAddress(key string) error
 	RemovePeer(peerKey string) error
 	AddAllowedIP(peerKey string, allowedIP netip.Prefix) error
 	RemoveAllowedIP(peerKey string, allowedIP netip.Prefix) error
@@ -171,9 +171,9 @@ func (conn *Conn) Open(engineCtx context.Context) error {
 
 	conn.handshaker = NewHandshaker(conn.Log, conn.config, conn.signaler, conn.workerICE, conn.workerRelay)
 
-	conn.handshaker.AddRelayListener(conn.workerRelay.OnNewOffer)
+	conn.handshaker.AddOnNewOfferListener(conn.workerRelay.OnNewOffer)
 	if !isForceRelayed() {
-		conn.handshaker.AddICEListener(conn.workerICE.OnNewOffer)
+		conn.handshaker.AddOnNewOfferListener(conn.workerICE.OnNewOffer)
 	}
 
 	conn.guard = guard.NewGuard(conn.Log, conn.isConnectedOnAllWay, conn.config.Timeout, conn.srWatcher)
@@ -430,9 +430,6 @@ func (conn *Conn) onICEStateDisconnected() {
 	} else {
 		conn.Log.Infof("ICE disconnected, do not switch to Relay. Reset priority to: %s", conntype.None.String())
 		conn.currentConnPriority = conntype.None
-		if err := conn.config.WgConfig.WgInterface.RemoveEndpointAddress(conn.config.WgConfig.RemoteKey); err != nil {
-			conn.Log.Errorf("failed to remove wg endpoint: %v", err)
-		}
 	}
 
 	changed := conn.statusICE.Get() != worker.StatusDisconnected
@@ -526,9 +523,6 @@ func (conn *Conn) onRelayDisconnected() {
 	if conn.currentConnPriority == conntype.Relay {
 		conn.Log.Debugf("clean up WireGuard config")
 		conn.currentConnPriority = conntype.None
-		if err := conn.config.WgConfig.WgInterface.RemoveEndpointAddress(conn.config.WgConfig.RemoteKey); err != nil {
-			conn.Log.Errorf("failed to remove wg endpoint: %v", err)
-		}
 	}
 
 	if conn.wgProxyRelay != nil {
@@ -79,10 +79,10 @@ func TestConn_OnRemoteOffer(t *testing.T) {
 		return
 	}
 
-	onNewOfferChan := make(chan struct{})
+	onNewOffeChan := make(chan struct{})
 
-	conn.handshaker.AddRelayListener(func(remoteOfferAnswer *OfferAnswer) {
-		onNewOfferChan <- struct{}{}
+	conn.handshaker.AddOnNewOfferListener(func(remoteOfferAnswer *OfferAnswer) {
+		onNewOffeChan <- struct{}{}
 	})
 
 	conn.OnRemoteOffer(OfferAnswer{
@@ -98,7 +98,7 @@ func TestConn_OnRemoteOffer(t *testing.T) {
 	defer cancel()
 
 	select {
-	case <-onNewOfferChan:
+	case <-onNewOffeChan:
 		// success
 	case <-ctx.Done():
 		t.Error("expected to receive a new offer notification, but timed out")
@@ -118,10 +118,10 @@ func TestConn_OnRemoteAnswer(t *testing.T) {
 		return
 	}
 
-	onNewOfferChan := make(chan struct{})
+	onNewOffeChan := make(chan struct{})
 
-	conn.handshaker.AddRelayListener(func(remoteOfferAnswer *OfferAnswer) {
-		onNewOfferChan <- struct{}{}
+	conn.handshaker.AddOnNewOfferListener(func(remoteOfferAnswer *OfferAnswer) {
+		onNewOffeChan <- struct{}{}
 	})
 
 	conn.OnRemoteAnswer(OfferAnswer{
@@ -136,7 +136,7 @@ func TestConn_OnRemoteAnswer(t *testing.T) {
 	defer cancel()
 
 	select {
-	case <-onNewOfferChan:
+	case <-onNewOffeChan:
 		// success
 	case <-ctx.Done():
 		t.Error("expected to receive a new offer notification, but timed out")
@@ -1,20 +0,0 @@
-package guard
-
-import (
-	"os"
-	"strconv"
-	"time"
-)
-
-const (
-	envICEMonitorPeriod = "NB_ICE_MONITOR_PERIOD"
-)
-
-func GetICEMonitorPeriod() time.Duration {
-	if envVal := os.Getenv(envICEMonitorPeriod); envVal != "" {
-		if seconds, err := strconv.Atoi(envVal); err == nil && seconds > 0 {
-			return time.Duration(seconds) * time.Second
-		}
-	}
-	return defaultCandidatesMonitorPeriod
-}
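Deleting this file removes the `NB_ICE_MONITOR_PERIOD` override along with `GetICEMonitorPeriod`; the monitor interval becomes a fixed constant again (see the next two hunks). A sketch pinning the behavior of the removed override as it stood on v0.59.4, written against the deleted function:

package guard

import (
	"testing"
	"time"
)

// The env var was interpreted as a positive number of seconds; anything
// else fell back to the package default.
func TestICEMonitorPeriodOverride(t *testing.T) {
	t.Setenv("NB_ICE_MONITOR_PERIOD", "60")
	if got := GetICEMonitorPeriod(); got != 60*time.Second {
		t.Fatalf("expected 60s, got %s", got)
	}

	t.Setenv("NB_ICE_MONITOR_PERIOD", "not-a-number")
	if got := GetICEMonitorPeriod(); got != defaultCandidatesMonitorPeriod {
		t.Fatalf("expected default period, got %s", got)
	}
}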
@@ -16,8 +16,8 @@ import (
 )
 
 const (
-	defaultCandidatesMonitorPeriod = 5 * time.Minute
-	candidateGatheringTimeout      = 5 * time.Second
+	candidatesMonitorPeriod   = 5 * time.Minute
+	candidateGatheringTimeout = 5 * time.Second
 )
 
 type ICEMonitor struct {
@@ -25,19 +25,16 @@ type ICEMonitor struct {
 
 	iFaceDiscover stdnet.ExternalIFaceDiscover
 	iceConfig     icemaker.Config
-	tickerPeriod  time.Duration
 
 	currentCandidatesAddress []string
 	candidatesMu             sync.Mutex
 }
 
-func NewICEMonitor(iFaceDiscover stdnet.ExternalIFaceDiscover, config icemaker.Config, period time.Duration) *ICEMonitor {
-	log.Debugf("prepare ICE monitor with period: %s", period)
+func NewICEMonitor(iFaceDiscover stdnet.ExternalIFaceDiscover, config icemaker.Config) *ICEMonitor {
 	cm := &ICEMonitor{
 		ReconnectCh:   make(chan struct{}, 1),
 		iFaceDiscover: iFaceDiscover,
 		iceConfig:     config,
-		tickerPeriod:  period,
 	}
 	return cm
 }
@@ -49,12 +46,7 @@ func (cm *ICEMonitor) Start(ctx context.Context, onChanged func()) {
 		return
 	}
 
-	// Initial check to populate the candidates for later comparison
-	if _, err := cm.handleCandidateTick(ctx, ufrag, pwd); err != nil {
-		log.Warnf("Failed to check initial ICE candidates: %v", err)
-	}
-
-	ticker := time.NewTicker(cm.tickerPeriod)
+	ticker := time.NewTicker(candidatesMonitorPeriod)
 	defer ticker.Stop()
 
 	for {
@@ -51,7 +51,7 @@ func (w *SRWatcher) Start() {
 	ctx, cancel := context.WithCancel(context.Background())
 	w.cancelIceMonitor = cancel
 
-	iceMonitor := NewICEMonitor(w.iFaceDiscover, w.iceConfig, GetICEMonitorPeriod())
+	iceMonitor := NewICEMonitor(w.iFaceDiscover, w.iceConfig)
 	go iceMonitor.Start(ctx, w.onICEChanged)
 	w.signalClient.SetOnReconnectedListener(w.onReconnected)
 	w.relayManager.SetOnReconnectedListener(w.onReconnected)
@@ -44,19 +44,13 @@ type OfferAnswer struct {
 }
 
 type Handshaker struct {
 	mu       sync.Mutex
 	log      *log.Entry
 	config   ConnConfig
 	signaler *Signaler
 	ice      *WorkerICE
 	relay    *WorkerRelay
-	// relayListener is not blocking because the listener is using a goroutine to process the messages
-	// and it will only keep the latest message if multiple offers are received in a short time
-	// this is to avoid blocking the handshaker if the listener is doing some heavy processing
-	// and also to avoid processing old offers if multiple offers are received in a short time
-	// the listener will always process the latest offer
-	relayListener *AsyncOfferListener
-	iceListener   func(remoteOfferAnswer *OfferAnswer)
+	onNewOfferListeners []*OfferListener
 
 	// remoteOffersCh is a channel used to wait for remote credentials to proceed with the connection
 	remoteOffersCh chan OfferAnswer
@@ -76,39 +70,28 @@ func NewHandshaker(log *log.Entry, config ConnConfig, signaler *Signaler, ice *W
 	}
 }
 
-func (h *Handshaker) AddRelayListener(offer func(remoteOfferAnswer *OfferAnswer)) {
-	h.relayListener = NewAsyncOfferListener(offer)
-}
-
-func (h *Handshaker) AddICEListener(offer func(remoteOfferAnswer *OfferAnswer)) {
-	h.iceListener = offer
+func (h *Handshaker) AddOnNewOfferListener(offer func(remoteOfferAnswer *OfferAnswer)) {
+	l := NewOfferListener(offer)
+	h.onNewOfferListeners = append(h.onNewOfferListeners, l)
 }
 
 func (h *Handshaker) Listen(ctx context.Context) {
 	for {
 		select {
 		case remoteOfferAnswer := <-h.remoteOffersCh:
-			h.log.Infof("received offer, running version %s, remote WireGuard listen port %d, session id: %s", remoteOfferAnswer.Version, remoteOfferAnswer.WgListenPort, remoteOfferAnswer.SessionIDString())
-			if h.relayListener != nil {
-				h.relayListener.Notify(&remoteOfferAnswer)
-			}
-
-			if h.iceListener != nil {
-				h.iceListener(&remoteOfferAnswer)
-			}
-
+			// received confirmation from the remote peer -> ready to proceed
 			if err := h.sendAnswer(); err != nil {
 				h.log.Errorf("failed to send remote offer confirmation: %s", err)
 				continue
 			}
+			for _, listener := range h.onNewOfferListeners {
+				listener.Notify(&remoteOfferAnswer)
+			}
+			h.log.Infof("received offer, running version %s, remote WireGuard listen port %d, session id: %s", remoteOfferAnswer.Version, remoteOfferAnswer.WgListenPort, remoteOfferAnswer.SessionIDString())
 		case remoteOfferAnswer := <-h.remoteAnswerCh:
 			h.log.Infof("received answer, running version %s, remote WireGuard listen port %d, session id: %s", remoteOfferAnswer.Version, remoteOfferAnswer.WgListenPort, remoteOfferAnswer.SessionIDString())
-			if h.relayListener != nil {
-				h.relayListener.Notify(&remoteOfferAnswer)
-			}
-
-			if h.iceListener != nil {
-				h.iceListener(&remoteOfferAnswer)
+			for _, listener := range h.onNewOfferListeners {
+				listener.Notify(&remoteOfferAnswer)
 			}
 		case <-ctx.Done():
 			h.log.Infof("stop listening for remote offers and answers")
@@ -13,20 +13,20 @@ func (oa *OfferAnswer) SessionIDString() string {
 	return oa.SessionID.String()
 }
 
-type AsyncOfferListener struct {
+type OfferListener struct {
 	fn      callbackFunc
 	running bool
 	latest  *OfferAnswer
 	mu      sync.Mutex
 }
 
-func NewAsyncOfferListener(fn callbackFunc) *AsyncOfferListener {
-	return &AsyncOfferListener{
+func NewOfferListener(fn callbackFunc) *OfferListener {
+	return &OfferListener{
 		fn: fn,
 	}
 }
 
-func (o *AsyncOfferListener) Notify(remoteOfferAnswer *OfferAnswer) {
+func (o *OfferListener) Notify(remoteOfferAnswer *OfferAnswer) {
 	o.mu.Lock()
 	defer o.mu.Unlock()
 
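The `Notify` body is cut off in this view; judging by the fields shown (`fn`, `running`, `latest`, `mu`) and the comments deleted in the Handshaker hunk above, the listener coalesces bursts of offers to the latest one and never blocks the caller. A minimal sketch of that pattern, as an illustration rather than the actual implementation:

package sketch

import "sync"

type offerAnswer struct{ /* payload elided */ }

type offerListener struct {
	fn      func(*offerAnswer)
	running bool
	latest  *offerAnswer
	mu      sync.Mutex
}

// Notify stores the newest offer and starts at most one worker goroutine;
// the worker drains until no newer offer is pending, so bursts collapse to
// the most recent message and the caller never blocks.
func (o *offerListener) Notify(oa *offerAnswer) {
	o.mu.Lock()
	defer o.mu.Unlock()
	o.latest = oa
	if o.running {
		return
	}
	o.running = true
	go o.run()
}

func (o *offerListener) run() {
	for {
		o.mu.Lock()
		oa := o.latest
		o.latest = nil
		if oa == nil {
			o.running = false
			o.mu.Unlock()
			return
		}
		o.mu.Unlock()
		o.fn(oa)
	}
}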
@@ -14,7 +14,7 @@ func Test_newOfferListener(t *testing.T) {
 		runChan <- struct{}{}
 	}
 
-	hl := NewAsyncOfferListener(longRunningFn)
+	hl := NewOfferListener(longRunningFn)
 
 	hl.Notify(dummyOfferAnswer)
 	hl.Notify(dummyOfferAnswer)
@@ -18,5 +18,4 @@ type WGIface interface {
 	GetStats() (map[string]configurer.WGStats, error)
 	GetProxy() wgproxy.Proxy
 	Address() wgaddr.Address
-	RemoveEndpointAddress(key string) error
 }
@@ -92,16 +92,23 @@ func NewWorkerICE(ctx context.Context, log *log.Entry, config ConnConfig, conn *
 func (w *WorkerICE) OnNewOffer(remoteOfferAnswer *OfferAnswer) {
 	w.log.Debugf("OnNewOffer for ICE, serial: %s", remoteOfferAnswer.SessionIDString())
 	w.muxAgent.Lock()
-	defer w.muxAgent.Unlock()
 
-	if w.agent != nil || w.agentConnecting {
+	if w.agentConnecting {
+		w.log.Debugf("agent connection is in progress, skipping the offer")
+		w.muxAgent.Unlock()
+		return
+	}
+
+	if w.agent != nil {
 		// backward compatibility with old clients that do not send session ID
 		if remoteOfferAnswer.SessionID == nil {
 			w.log.Debugf("agent already exists, skipping the offer")
+			w.muxAgent.Unlock()
 			return
 		}
 		if w.remoteSessionID == *remoteOfferAnswer.SessionID {
 			w.log.Debugf("agent already exists and session ID matches, skipping the offer: %s", remoteOfferAnswer.SessionIDString())
+			w.muxAgent.Unlock()
 			return
 		}
 		w.log.Debugf("agent already exists, recreate the connection")
@@ -109,12 +116,6 @@ func (w *WorkerICE) OnNewOffer(remoteOfferAnswer *OfferAnswer) {
 		if err := w.agent.Close(); err != nil {
 			w.log.Warnf("failed to close ICE agent: %s", err)
 		}
-
-		sessionID, err := NewICESessionID()
-		if err != nil {
-			w.log.Errorf("failed to create new session ID: %s", err)
-		}
-		w.sessionID = sessionID
 		w.agent = nil
 	}
 
@@ -125,23 +126,18 @@ func (w *WorkerICE) OnNewOffer(remoteOfferAnswer *OfferAnswer) {
 		preferredCandidateTypes = icemaker.CandidateTypes()
 	}
 
-	if remoteOfferAnswer.SessionID != nil {
-		w.log.Debugf("recreate ICE agent: %s / %s", w.sessionID, *remoteOfferAnswer.SessionID)
-	}
+	w.log.Debugf("recreate ICE agent")
 	dialerCtx, dialerCancel := context.WithCancel(w.ctx)
 	agent, err := w.reCreateAgent(dialerCancel, preferredCandidateTypes)
 	if err != nil {
 		w.log.Errorf("failed to recreate ICE Agent: %s", err)
+		w.muxAgent.Unlock()
 		return
 	}
 	w.agent = agent
 	w.agentDialerCancel = dialerCancel
 	w.agentConnecting = true
-	if remoteOfferAnswer.SessionID != nil {
-		w.remoteSessionID = *remoteOfferAnswer.SessionID
-	} else {
-		w.remoteSessionID = ""
-	}
+	w.muxAgent.Unlock()
 
 	go w.connect(dialerCtx, agent, remoteOfferAnswer)
 }
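Across the three hunks above, the compared branch drops `defer w.muxAgent.Unlock()` in favor of explicit unlocks on every path, releasing the mutex before `go w.connect(...)` is scheduled; the cost is that every early return must unlock by hand. A generic sketch of that trade-off (not NetBird code):

package sketch

import "sync"

type worker struct {
	mu   sync.Mutex
	busy bool
}

// start releases the lock explicitly before kicking off long-running work,
// instead of holding it via defer until the method returns.
func (w *worker) start(job func()) bool {
	w.mu.Lock()
	if w.busy {
		w.mu.Unlock() // every early-exit path must unlock by hand
		return false
	}
	w.busy = true
	w.mu.Unlock() // critical section ends here, not at function return

	go job()
	return true
}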
@@ -297,6 +293,9 @@ func (w *WorkerICE) connect(ctx context.Context, agent *icemaker.ThreadSafeAgent
 	w.muxAgent.Lock()
 	w.agentConnecting = false
 	w.lastSuccess = time.Now()
+	if remoteOfferAnswer.SessionID != nil {
+		w.remoteSessionID = *remoteOfferAnswer.SessionID
+	}
 	w.muxAgent.Unlock()
 
 	// todo: the potential problem is a race between the onConnectionStateChange
@@ -310,17 +309,16 @@ func (w *WorkerICE) closeAgent(agent *icemaker.ThreadSafeAgent, cancel context.C
 	}
 
 	w.muxAgent.Lock()
+	// todo review does it make sense to generate new session ID all the time when w.agent==agent
+	sessionID, err := NewICESessionID()
+	if err != nil {
+		w.log.Errorf("failed to create new session ID: %s", err)
+	}
+	w.sessionID = sessionID
 
 	if w.agent == agent {
-		// consider to remove from here and move to the OnNewOffer
-		sessionID, err := NewICESessionID()
-		if err != nil {
-			w.log.Errorf("failed to create new session ID: %s", err)
-		}
-		w.sessionID = sessionID
 		w.agent = nil
 		w.agentConnecting = false
-		w.remoteSessionID = ""
 	}
 	w.muxAgent.Unlock()
 }
@@ -397,12 +395,11 @@ func (w *WorkerICE) onConnectionStateChange(agent *icemaker.ThreadSafeAgent, dia
 		// ice.ConnectionStateClosed happens when we recreate the agent. For the P2P to TURN switch important to
 		// notify the conn.onICEStateDisconnected changes to update the current used priority
 
-		w.closeAgent(agent, dialerCancel)
-
 		if w.lastKnownState == ice.ConnectionStateConnected {
 			w.lastKnownState = ice.ConnectionStateDisconnected
 			w.conn.onICEStateDisconnected()
 		}
+		w.closeAgent(agent, dialerCancel)
 	default:
 		return
 	}
|||||||
@@ -1354,13 +1354,7 @@ func (s *serviceClient) updateConfig() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// showLoginURL creates a borderless window styled like a pop-up in the top-right corner using s.wLoginURL.
|
// showLoginURL creates a borderless window styled like a pop-up in the top-right corner using s.wLoginURL.
|
||||||
// It also starts a background goroutine that periodically checks if the client is already connected
|
func (s *serviceClient) showLoginURL() {
|
||||||
// and closes the window if so. The goroutine can be cancelled by the returned CancelFunc, and it is
|
|
||||||
// also cancelled when the window is closed.
|
|
||||||
func (s *serviceClient) showLoginURL() context.CancelFunc {
|
|
||||||
|
|
||||||
// create a cancellable context for the background check goroutine
|
|
||||||
ctx, cancel := context.WithCancel(s.ctx)
|
|
||||||
|
|
||||||
resIcon := fyne.NewStaticResource("netbird.png", iconAbout)
|
resIcon := fyne.NewStaticResource("netbird.png", iconAbout)
|
||||||
|
|
||||||
@@ -1369,8 +1363,6 @@
 		s.wLoginURL.Resize(fyne.NewSize(400, 200))
 		s.wLoginURL.SetIcon(resIcon)
 	}
-	// ensure goroutine is cancelled when the window is closed
-	s.wLoginURL.SetOnClosed(func() { cancel() })
 	// add a description label
 	label := widget.NewLabel("Your NetBird session has expired.\nPlease re-authenticate to continue using NetBird.")
 
@@ -1451,39 +1443,7 @@
 	)
 	s.wLoginURL.SetContent(container.NewCenter(content))
 
-	// start a goroutine to check connection status and close the window if connected
-	go func() {
-		ticker := time.NewTicker(5 * time.Second)
-		defer ticker.Stop()
-
-		conn, err := s.getSrvClient(failFastTimeout)
-		if err != nil {
-			return
-		}
-
-		for {
-			select {
-			case <-ctx.Done():
-				return
-			case <-ticker.C:
-				status, err := conn.Status(s.ctx, &proto.StatusRequest{})
-				if err != nil {
-					continue
-				}
-				if status.Status == string(internal.StatusConnected) {
-					if s.wLoginURL != nil {
-						s.wLoginURL.Close()
-					}
-					return
-				}
-			}
-		}
-	}()
-
 	s.wLoginURL.Show()
-
-	// return cancel func so callers can stop the background goroutine if desired
-	return cancel
 }
 
 func openURL(url string) error {
@@ -47,7 +47,7 @@ services:
       - traefik.enable=true
      - traefik.http.routers.netbird-wsproxy-signal.rule=Host(`$NETBIRD_DOMAIN`) && PathPrefix(`/ws-proxy/signal`)
       - traefik.http.routers.netbird-wsproxy-signal.service=netbird-wsproxy-signal
-      - traefik.http.services.netbird-wsproxy-signal.loadbalancer.server.port=80
+      - traefik.http.services.netbird-wsproxy-signal.loadbalancer.server.port=10000
      - traefik.http.routers.netbird-signal.rule=Host(`$NETBIRD_DOMAIN`) && PathPrefix(`/signalexchange.SignalExchange/`)
       - traefik.http.services.netbird-signal.loadbalancer.server.port=10000
       - traefik.http.services.netbird-signal.loadbalancer.server.scheme=h2c
@@ -621,7 +621,7 @@ renderCaddyfile() {
     # relay
     reverse_proxy /relay* relay:80
     # Signal
-    reverse_proxy /ws-proxy/signal* signal:80
+    reverse_proxy /ws-proxy/signal* signal:10000
     reverse_proxy /signalexchange.SignalExchange/* h2c://signal:10000
     # Management
     reverse_proxy /api/* management:80
@@ -20,10 +20,6 @@ upstream management {
     # insert the grpc+http port of your management container here
     server 127.0.0.1:8012;
 }
-upstream relay {
-    # insert the port of your relay container here
-    server 127.0.0.1:33080;
-}
 
 server {
     # HTTP server config
@@ -59,10 +55,6 @@ server {
     # Proxy Signal wsproxy endpoint
     location /ws-proxy/signal {
         proxy_pass http://signal;
-        proxy_http_version 1.1;
-        proxy_set_header Upgrade $http_upgrade;
-        proxy_set_header Connection "Upgrade";
-        proxy_set_header Host $host;
     }
     # Proxy Signal
     location /signalexchange.SignalExchange/ {
@@ -79,10 +71,6 @@ server {
     # Proxy Management wsproxy endpoint
     location /ws-proxy/management {
         proxy_pass http://management;
-        proxy_http_version 1.1;
-        proxy_set_header Upgrade $http_upgrade;
-        proxy_set_header Connection "Upgrade";
-        proxy_set_header Host $host;
     }
     # Proxy Management grpc endpoint
     location /management.ManagementService/ {
@@ -92,14 +80,6 @@ server {
         grpc_send_timeout 1d;
         grpc_socket_keepalive on;
     }
-    # Proxy Relay
-    location /relay {
-        proxy_pass http://relay;
-        proxy_http_version 1.1;
-        proxy_set_header Upgrade $http_upgrade;
-        proxy_set_header Connection "Upgrade";
-        proxy_set_header Host $host;
-    }
 
     ssl_certificate /etc/ssl/certs/ssl-cert-snakeoil.pem;
     ssl_certificate_key /etc/ssl/certs/ssl-cert-snakeoil.pem;
@@ -1,4 +1,4 @@
-FROM ubuntu:24.04
+FROM ubuntu:24.10
 RUN apt update && apt install -y ca-certificates && rm -fr /var/cache/apt
 ENTRYPOINT [ "/go/bin/netbird-mgmt","management","--log-level","debug"]
 CMD ["--log-file", "console"]
@@ -6,6 +6,7 @@ import (
 	"fmt"
 	"net"
 	"net/http"
+	"net/netip"
 	"strings"
 	"sync"
 	"time"
@@ -251,7 +252,7 @@ func updateMgmtConfig(ctx context.Context, path string, config *nbconfig.Config)
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *BaseServer) handlerFunc(gRPCHandler *grpc.Server, httpHandler http.Handler, meter metric.Meter) http.Handler {
|
func (s *BaseServer) handlerFunc(gRPCHandler *grpc.Server, httpHandler http.Handler, meter metric.Meter) http.Handler {
|
||||||
wsProxy := wsproxyserver.New(gRPCHandler, wsproxyserver.WithOTelMeter(meter))
|
wsProxy := wsproxyserver.New(netip.AddrPortFrom(netip.AddrFrom4([4]byte{127, 0, 0, 1}), ManagementLegacyPort), wsproxyserver.WithOTelMeter(meter))
|
||||||
|
|
||||||
return http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) {
|
return http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) {
|
||||||
switch {
|
switch {
|
||||||
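
The call site now passes a `netip.AddrPort` pointing at the loopback gRPC listener instead of handing the proxy an in-process handler. A small self-contained sketch of that address construction (the literal port stands in for the repo's `ManagementLegacyPort` constant):

    package main

    import (
        "fmt"
        "net/netip"
    )

    func main() {
        // Field-by-field form, as used in the diff: 127.0.0.1:<port>.
        addr := netip.AddrPortFrom(netip.AddrFrom4([4]byte{127, 0, 0, 1}), 80)

        // Equivalent parse-based form; MustParseAddrPort panics on bad input.
        addr2 := netip.MustParseAddrPort("127.0.0.1:80")

        fmt.Println(addr == addr2, addr.String()) // true 127.0.0.1:80
    }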

@@ -10,6 +10,7 @@ import (
     "net/http"
     // nolint:gosec
     _ "net/http/pprof"
+    "net/netip"
     "time"
 
     "go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc"
@@ -62,10 +63,10 @@ var (
         Use:          "run",
         Short:        "start NetBird Signal Server daemon",
         SilenceUsage: true,
-        PreRunE: func(cmd *cobra.Command, args []string) error {
+        PreRun: func(cmd *cobra.Command, args []string) {
             err := util.InitLog(logLevel, logFile)
             if err != nil {
-                return fmt.Errorf("failed initializing log: %w", err)
+                log.Fatalf("failed initializing log %v", err)
             }
 
             flag.Parse()
@@ -86,8 +87,6 @@ var (
                     signalPort = 80
                 }
             }
-
-            return nil
         },
         RunE: func(cmd *cobra.Command, args []string) error {
             flag.Parse()
@@ -255,7 +254,7 @@ func startServerWithCertManager(certManager *autocert.Manager, grpcRootHandler h
 }
 
 func grpcHandlerFunc(grpcServer *grpc.Server, meter metric.Meter) http.Handler {
-    wsProxy := wsproxyserver.New(grpcServer, wsproxyserver.WithOTelMeter(meter))
+    wsProxy := wsproxyserver.New(netip.AddrPortFrom(netip.AddrFrom4([4]byte{127, 0, 0, 1}), legacyGRPCPort), wsproxyserver.WithOTelMeter(meter))
 
     return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
         switch {
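
The `PreRunE` → `PreRun` swap changes how an init failure surfaces: `PreRunE` returns the error through `Execute()`, while `PreRun` has no error channel and the hook must exit the process itself. A hedged, hypothetical sketch contrasting the two cobra hook shapes (not the repo's actual wiring):

    package main

    import (
        "errors"
        "fmt"
        "log"

        "github.com/spf13/cobra"
    )

    func main() {
        errInit := errors.New("init failed")

        // PreRunE: the returned error aborts the run and is surfaced by Execute().
        withE := &cobra.Command{
            Use:           "run",
            SilenceUsage:  true,
            SilenceErrors: true,
            PreRunE: func(cmd *cobra.Command, args []string) error {
                return fmt.Errorf("failed initializing log: %w", errInit)
            },
            RunE: func(cmd *cobra.Command, args []string) error { return nil },
        }
        fmt.Println(withE.Execute()) // prints the wrapped init error

        // PreRun: no error return, so a failure has to terminate directly.
        withPlain := &cobra.Command{
            Use: "run",
            PreRun: func(cmd *cobra.Command, args []string) {
                log.Fatalf("failed initializing log %v", errInit)
            },
            RunE: func(cmd *cobra.Command, args []string) error { return nil },
        }
        _ = withPlain // Execute() here would exit the process via log.Fatalf
    }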

@@ -2,41 +2,42 @@ package server
 
 import (
     "context"
+    "errors"
     "io"
     "net"
     "net/http"
+    "net/netip"
     "sync"
     "time"
 
     "github.com/coder/websocket"
     log "github.com/sirupsen/logrus"
-    "golang.org/x/net/http2"
 
     "github.com/netbirdio/netbird/util/wsproxy"
 )
 
 const (
-    bufferSize = 32 * 1024
-    ioTimeout  = 5 * time.Second
+    dialTimeout = 10 * time.Second
+    bufferSize  = 32 * 1024
 )
 
 // Config contains the configuration for the WebSocket proxy.
 type Config struct {
-    Handler         http.Handler
+    LocalGRPCAddr   netip.AddrPort
     Path            string
     MetricsRecorder MetricsRecorder
 }
 
-// Proxy handles WebSocket to gRPC handler proxying.
+// Proxy handles WebSocket to TCP proxying for gRPC connections.
 type Proxy struct {
     config  Config
     metrics MetricsRecorder
 }
 
 // New creates a new WebSocket proxy instance with optional configuration
-func New(handler http.Handler, opts ...Option) *Proxy {
+func New(localGRPCAddr netip.AddrPort, opts ...Option) *Proxy {
     config := Config{
-        Handler:         handler,
+        LocalGRPCAddr:   localGRPCAddr,
         Path:            wsproxy.ProxyPath,
        MetricsRecorder: NoOpMetricsRecorder{}, // Default to no-op
     }
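
The removed `Handler` field and `golang.org/x/net/http2` import existed so the proxy could serve the gRPC handler over an in-memory pipe; the hunks below show those call sites being deleted in favor of a real TCP dial. The removed pattern, reconstructed as a standalone sketch (the handler here is a placeholder for a *grpc.Server):

    package main

    import (
        "context"
        "net"
        "net/http"

        "golang.org/x/net/http2"
    )

    // serveInMemory serves handler over one end of an in-memory pipe and
    // returns the other end, so callers can speak HTTP/2 (and hence gRPC)
    // to it without any TCP listener being involved.
    func serveInMemory(ctx context.Context, handler http.Handler) net.Conn {
        clientConn, serverConn := net.Pipe()
        go func() {
            (&http2.Server{}).ServeConn(serverConn, &http2.ServeConnOpts{
                Context: ctx,
                Handler: handler,
            })
        }()
        return clientConn // bytes written here reach the handler
    }

    func main() {
        conn := serveInMemory(context.Background(), http.NewServeMux())
        defer conn.Close()
    }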
@@ -62,7 +63,7 @@ func (p *Proxy) handleWebSocket(w http.ResponseWriter, r *http.Request) {
     p.metrics.RecordConnection(ctx)
     defer p.metrics.RecordDisconnection(ctx)
 
-    log.Debugf("WebSocket proxy handling connection from %s, forwarding to internal gRPC handler", r.RemoteAddr)
+    log.Debugf("WebSocket proxy handling connection from %s, forwarding to %s", r.RemoteAddr, p.config.LocalGRPCAddr)
     acceptOptions := &websocket.AcceptOptions{
         OriginPatterns: []string{"*"},
     }
@@ -74,41 +75,71 @@ func (p *Proxy) handleWebSocket(w http.ResponseWriter, r *http.Request) {
         return
     }
     defer func() {
-        _ = wsConn.Close(websocket.StatusNormalClosure, "")
+        if err := wsConn.Close(websocket.StatusNormalClosure, ""); err != nil {
+            log.Debugf("Failed to close WebSocket: %v", err)
+        }
     }()
 
-    clientConn, serverConn := net.Pipe()
+    log.Debugf("WebSocket proxy attempting to connect to local gRPC at %s", p.config.LocalGRPCAddr)
+    tcpConn, err := net.DialTimeout("tcp", p.config.LocalGRPCAddr.String(), dialTimeout)
+    if err != nil {
+        p.metrics.RecordError(ctx, "tcp_dial_failed")
+        log.Warnf("Failed to connect to local gRPC server at %s: %v", p.config.LocalGRPCAddr, err)
+        if err := wsConn.Close(websocket.StatusInternalError, "Backend unavailable"); err != nil {
+            log.Debugf("Failed to close WebSocket after connection failure: %v", err)
+        }
+        return
+    }
     defer func() {
-        _ = clientConn.Close()
-        _ = serverConn.Close()
+        if err := tcpConn.Close(); err != nil {
+            log.Debugf("Failed to close TCP connection: %v", err)
+        }
     }()
 
-    log.Debugf("WebSocket proxy established: %s -> gRPC handler", r.RemoteAddr)
+    log.Debugf("WebSocket proxy established: client %s -> local gRPC %s", r.RemoteAddr, p.config.LocalGRPCAddr)
 
-    go func() {
-        (&http2.Server{}).ServeConn(serverConn, &http2.ServeConnOpts{
-            Context: ctx,
-            Handler: p.config.Handler,
-        })
-    }()
-
-    p.proxyData(ctx, wsConn, clientConn, r.RemoteAddr)
+    p.proxyData(ctx, wsConn, tcpConn)
 }
 
-func (p *Proxy) proxyData(ctx context.Context, wsConn *websocket.Conn, pipeConn net.Conn, clientAddr string) {
+func (p *Proxy) proxyData(ctx context.Context, wsConn *websocket.Conn, tcpConn net.Conn) {
     proxyCtx, cancel := context.WithCancel(ctx)
     defer cancel()
 
     var wg sync.WaitGroup
     wg.Add(2)
 
-    go p.wsToPipe(proxyCtx, cancel, &wg, wsConn, pipeConn, clientAddr)
-    go p.pipeToWS(proxyCtx, cancel, &wg, wsConn, pipeConn, clientAddr)
+    go p.wsToTCP(proxyCtx, cancel, &wg, wsConn, tcpConn)
+    go p.tcpToWS(proxyCtx, cancel, &wg, wsConn, tcpConn)
 
-    wg.Wait()
+    done := make(chan struct{})
+    go func() {
+        wg.Wait()
+        close(done)
+    }()
+
+    select {
+    case <-done:
+        log.Tracef("Proxy data transfer completed, both goroutines terminated")
+    case <-proxyCtx.Done():
+        log.Tracef("Proxy data transfer cancelled, forcing connection closure")
+
+        if err := wsConn.Close(websocket.StatusGoingAway, "proxy cancelled"); err != nil {
+            log.Tracef("Error closing WebSocket during cancellation: %v", err)
+        }
+        if err := tcpConn.Close(); err != nil {
+            log.Tracef("Error closing TCP connection during cancellation: %v", err)
+        }
+
+        select {
+        case <-done:
+            log.Tracef("Goroutines terminated after forced connection closure")
+        case <-time.After(2 * time.Second):
+            log.Tracef("Goroutines did not terminate within timeout after connection closure")
+        }
+    }
 }
 
-func (p *Proxy) wsToPipe(ctx context.Context, cancel context.CancelFunc, wg *sync.WaitGroup, wsConn *websocket.Conn, pipeConn net.Conn, clientAddr string) {
+func (p *Proxy) wsToTCP(ctx context.Context, cancel context.CancelFunc, wg *sync.WaitGroup, wsConn *websocket.Conn, tcpConn net.Conn) {
     defer wg.Done()
     defer cancel()
 
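
Worth calling out in the rewritten `proxyData`: instead of blocking on `wg.Wait()` unconditionally, it races completion against cancellation, force-closes both connections to unblock stuck reads, and then allows a two-second grace period. The same shutdown idiom in isolation, as a generic runnable sketch (names are illustrative):

    package main

    import (
        "context"
        "log"
        "sync"
        "time"
    )

    // waitOrForceClose waits for workers, but on ctx cancellation runs
    // closeFn (e.g. closing both connections to unblock blocking reads)
    // and then waits at most two more seconds for the workers to exit.
    func waitOrForceClose(ctx context.Context, wg *sync.WaitGroup, closeFn func()) {
        done := make(chan struct{})
        go func() {
            wg.Wait()
            close(done)
        }()

        select {
        case <-done:
            return // both goroutines finished on their own
        case <-ctx.Done():
            closeFn() // unblock any goroutine stuck in a blocking Read/Write
            select {
            case <-done:
            case <-time.After(2 * time.Second):
                log.Println("workers did not stop within the grace period")
            }
        }
    }

    func main() {
        ctx, cancel := context.WithCancel(context.Background())
        var wg sync.WaitGroup
        wg.Add(1)
        go func() { defer wg.Done(); <-ctx.Done() }()
        cancel()
        waitOrForceClose(ctx, &wg, func() {})
    }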
@@ -117,73 +148,80 @@ func (p *Proxy) wsToPipe(ctx context.Context, cancel context.CancelFunc, wg *syn
         if err != nil {
             switch {
             case ctx.Err() != nil:
-                log.Debugf("WebSocket from %s terminating due to context cancellation", clientAddr)
-            case websocket.CloseStatus(err) != -1:
-                log.Debugf("WebSocket from %s disconnected", clientAddr)
+                log.Debugf("wsToTCP goroutine terminating due to context cancellation")
+            case websocket.CloseStatus(err) == websocket.StatusNormalClosure:
+                log.Debugf("WebSocket closed normally")
             default:
                 p.metrics.RecordError(ctx, "websocket_read_error")
-                log.Debugf("WebSocket read error from %s: %v", clientAddr, err)
+                log.Errorf("WebSocket read error: %v", err)
             }
             return
         }
 
         if msgType != websocket.MessageBinary {
-            log.Warnf("Unexpected WebSocket message type from %s: %v", clientAddr, msgType)
+            log.Warnf("Unexpected WebSocket message type: %v", msgType)
             continue
         }
 
         if ctx.Err() != nil {
-            log.Tracef("wsToPipe goroutine terminating due to context cancellation before pipe write")
+            log.Tracef("wsToTCP goroutine terminating due to context cancellation before TCP write")
             return
         }
 
-        if err := pipeConn.SetWriteDeadline(time.Now().Add(ioTimeout)); err != nil {
-            log.Debugf("Failed to set pipe write deadline: %v", err)
+        if err := tcpConn.SetWriteDeadline(time.Now().Add(5 * time.Second)); err != nil {
+            log.Debugf("Failed to set TCP write deadline: %v", err)
         }
 
-        n, err := pipeConn.Write(data)
+        n, err := tcpConn.Write(data)
         if err != nil {
-            p.metrics.RecordError(ctx, "pipe_write_error")
-            log.Warnf("Pipe write error for %s: %v", clientAddr, err)
+            p.metrics.RecordError(ctx, "tcp_write_error")
+            log.Errorf("TCP write error: %v", err)
             return
         }
 
-        p.metrics.RecordBytesTransferred(ctx, "ws_to_grpc", int64(n))
+        p.metrics.RecordBytesTransferred(ctx, "ws_to_tcp", int64(n))
     }
 }
 
-func (p *Proxy) pipeToWS(ctx context.Context, cancel context.CancelFunc, wg *sync.WaitGroup, wsConn *websocket.Conn, pipeConn net.Conn, clientAddr string) {
+func (p *Proxy) tcpToWS(ctx context.Context, cancel context.CancelFunc, wg *sync.WaitGroup, wsConn *websocket.Conn, tcpConn net.Conn) {
     defer wg.Done()
     defer cancel()
 
     buf := make([]byte, bufferSize)
     for {
-        n, err := pipeConn.Read(buf)
+        if err := tcpConn.SetReadDeadline(time.Now().Add(5 * time.Second)); err != nil {
+            log.Debugf("Failed to set TCP read deadline: %v", err)
+        }
+        n, err := tcpConn.Read(buf)
 
         if err != nil {
             if ctx.Err() != nil {
-                log.Tracef("pipeToWS goroutine terminating due to context cancellation")
+                log.Tracef("tcpToWS goroutine terminating due to context cancellation")
                 return
             }
 
+            var netErr net.Error
+            if errors.As(err, &netErr) && netErr.Timeout() {
+                continue
+            }
+
             if err != io.EOF {
-                log.Debugf("Pipe read error for %s: %v", clientAddr, err)
+                log.Errorf("TCP read error: %v", err)
             }
             return
         }
 
         if ctx.Err() != nil {
-            log.Tracef("pipeToWS goroutine terminating due to context cancellation before WebSocket write")
+            log.Tracef("tcpToWS goroutine terminating due to context cancellation before WebSocket write")
             return
         }
 
-        if n > 0 {
-            if err := wsConn.Write(ctx, websocket.MessageBinary, buf[:n]); err != nil {
-                p.metrics.RecordError(ctx, "websocket_write_error")
-                log.Warnf("WebSocket write error for %s: %v", clientAddr, err)
-                return
-            }
-
-            p.metrics.RecordBytesTransferred(ctx, "grpc_to_ws", int64(n))
-        }
+        if err := wsConn.Write(ctx, websocket.MessageBinary, buf[:n]); err != nil {
+            p.metrics.RecordError(ctx, "websocket_write_error")
+            log.Errorf("WebSocket write error: %v", err)
+            return
+        }
+
+        p.metrics.RecordBytesTransferred(ctx, "tcp_to_ws", int64(n))
     }
 }
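
The new read loop replaces a blocking `Read` with a short deadline plus a `net.Error` timeout check, so the goroutine wakes up periodically and can notice cancellation instead of hanging until the peer sends bytes. The core idiom as a self-contained sketch (the buffer size and five-second slice mirror the diff; the pipe-based demo is illustrative):

    package main

    import (
        "context"
        "errors"
        "io"
        "log"
        "net"
        "time"
    )

    // readWithDeadlines reads conn in 5-second slices; a deadline expiry
    // surfaces as a net.Error with Timeout() == true and simply restarts
    // the loop, letting ctx stop the goroutine between reads.
    func readWithDeadlines(ctx context.Context, conn net.Conn, handle func([]byte)) {
        buf := make([]byte, 32*1024)
        for {
            if ctx.Err() != nil {
                return
            }
            if err := conn.SetReadDeadline(time.Now().Add(5 * time.Second)); err != nil {
                log.Printf("set read deadline: %v", err)
            }
            n, err := conn.Read(buf)
            if err != nil {
                var netErr net.Error
                if errors.As(err, &netErr) && netErr.Timeout() {
                    continue // no data this slice; check ctx and try again
                }
                if err != io.EOF {
                    log.Printf("read error: %v", err)
                }
                return
            }
            handle(buf[:n])
        }
    }

    func main() {
        client, server := net.Pipe()
        ctx, cancel := context.WithTimeout(context.Background(), time.Second)
        defer cancel()
        go func() {
            _, _ = client.Write([]byte("hello"))
            _ = client.Close()
        }()
        readWithDeadlines(ctx, server, func(b []byte) { log.Printf("got %q", b) })
    }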

@@ -1,13 +1,9 @@
 package version
 
-import (
-    "golang.org/x/sys/windows/registry"
-    "runtime"
-)
+import "golang.org/x/sys/windows/registry"
 
 const (
     urlWinExe    = "https://pkgs.netbird.io/windows/x64"
-    urlWinExeArm = "https://pkgs.netbird.io/windows/arm64"
 )
 
 var regKeyAppPath = "SOFTWARE\\Microsoft\\Windows\\CurrentVersion\\App Paths\\Netbird"
@@ -15,14 +11,9 @@ var regKeyAppPath = "SOFTWARE\\Microsoft\\Windows\\CurrentVersion\\App Paths\\Ne
 // DownloadUrl return with the proper download link
 func DownloadUrl() string {
     _, err := registry.OpenKey(registry.LOCAL_MACHINE, regKeyAppPath, registry.QUERY_VALUE)
-    if err != nil {
+    if err == nil {
+        return urlWinExe
+    } else {
         return downloadURL
     }
-
-    url := urlWinExe
-    if runtime.GOARCH == "arm64" {
-        url = urlWinExeArm
-    }
-
-    return url
 }
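
Both sides of this last hunk use the registry key purely as an install-detection probe: if the NetBird App Paths key opens under HKLM, the binary came from the EXE installer. A Windows-only sketch of that probe (a hedged reconstruction; unlike the snippet in the diff it also closes the key handle):

    //go:build windows

    package main

    import (
        "fmt"

        "golang.org/x/sys/windows/registry"
    )

    const regKeyAppPath = `SOFTWARE\Microsoft\Windows\CurrentVersion\App Paths\Netbird`

    // installedViaExe reports whether the App Paths key written by the
    // Windows installer exists under HKLM.
    func installedViaExe() bool {
        k, err := registry.OpenKey(registry.LOCAL_MACHINE, regKeyAppPath, registry.QUERY_VALUE)
        if err != nil {
            return false
        }
        defer k.Close()
        return true
    }

    func main() {
        fmt.Println("EXE install:", installedViaExe())
    }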