Compare commits

...

11 Commits

Author SHA1 Message Date
mlsmaycon
bfeb60fbb5 Create a system proxy change after receiving a network map
This is experimental and needs more test.

the purpose of this change is to validate that a TLS connection stuck using old routes because of keepalive settings on the remote webserver are reset once netbird receives a network map
2026-02-01 10:23:25 +01:00
mlsmaycon
ea41cf2d2c Create a system proxy change after receiving a network map
This is experimental and needs more test.

the purpose of this change is to validate that a TLS connection stuck using old routes because of keepalive settings on the remote webserver are reset once netbird receives a network map
2026-02-01 10:21:51 +01:00
Viktor Liu
0c990ab662 [client] Add block inbound option to the embed client (#5215) 2026-01-30 10:42:39 +01:00
Viktor Liu
101c813e98 [client] Add macOS default resolvers as fallback (#5201) 2026-01-30 10:42:14 +01:00
Zoltan Papp
5333e55a81 Fix WG watcher missing initial handshake (#5213)
Start the WireGuard watcher before configuring the WG endpoint to ensure it captures the initial handshake timestamp.

Previously, the watcher was started after endpoint configuration, causing it to miss the handshake that occurred during setup.
2026-01-29 16:58:10 +01:00
Viktor Liu
81c11df103 [management] Streamline domain validation (#5211) 2026-01-29 13:51:44 +01:00
Viktor Liu
f74bc48d16 [Client] Stop NetBird on firewall init failure (#5208) 2026-01-29 11:05:06 +01:00
Vlad
0169e4540f [management] fix skip of ephemeral peers on deletion (#5206) 2026-01-29 10:58:45 +01:00
Vlad
cead3f38ee [management] fix ephemeral peers being not removed (#5203) 2026-01-28 18:24:12 +01:00
Zoltan Papp
b55262d4a2 [client] Refactor/optimise raw socket headers (#5174)
Pre-create and reuse packet headers to eliminate per-packet allocations.
2026-01-28 15:06:59 +01:00
Zoltan Papp
2248ff392f Remove redundant square bracket trimming in USP endpoint parsing (#5197) 2026-01-27 20:10:59 +01:00
34 changed files with 1296 additions and 204 deletions

View File

@@ -69,6 +69,8 @@ type Options struct {
StatePath string
// DisableClientRoutes disables the client routes
DisableClientRoutes bool
// BlockInbound blocks all inbound connections from peers
BlockInbound bool
}
// validateCredentials checks that exactly one credential type is provided
@@ -137,6 +139,7 @@ func New(opts Options) (*Client, error) {
PreSharedKey: &opts.PreSharedKey,
DisableServerRoutes: &t,
DisableClientRoutes: &opts.DisableClientRoutes,
BlockInbound: &opts.BlockInbound,
}
if opts.ConfigPath != "" {
config, err = profilemanager.UpdateOrCreateConfig(input)

View File

@@ -558,7 +558,7 @@ func parseStatus(deviceName, ipcStr string) (*Stats, error) {
continue
}
host, portStr, err := net.SplitHostPort(strings.Trim(val, "[]"))
host, portStr, err := net.SplitHostPort(val)
if err != nil {
log.Errorf("failed to parse endpoint: %v", err)
continue

View File

@@ -8,8 +8,6 @@ import (
"net"
"sync"
"github.com/google/gopacket"
"github.com/google/gopacket/layers"
"github.com/hashicorp/go-multierror"
"github.com/pion/transport/v3"
log "github.com/sirupsen/logrus"
@@ -26,16 +24,6 @@ const (
loopbackAddr = "127.0.0.1"
)
var (
localHostNetIPv4 = net.ParseIP("127.0.0.1")
localHostNetIPv6 = net.ParseIP("::1")
serializeOpts = gopacket.SerializeOptions{
ComputeChecksums: true,
FixLengths: true,
}
)
// WGEBPFProxy definition for proxy with EBPF support
type WGEBPFProxy struct {
localWGListenPort int
@@ -253,63 +241,3 @@ generatePort:
}
return p.lastUsedPort, nil
}
func (p *WGEBPFProxy) sendPkg(data []byte, endpointAddr *net.UDPAddr) error {
var ipH gopacket.SerializableLayer
var networkLayer gopacket.NetworkLayer
var dstIP net.IP
var rawConn net.PacketConn
if endpointAddr.IP.To4() != nil {
// IPv4 path
ipv4 := &layers.IPv4{
DstIP: localHostNetIPv4,
SrcIP: endpointAddr.IP,
Version: 4,
TTL: 64,
Protocol: layers.IPProtocolUDP,
}
ipH = ipv4
networkLayer = ipv4
dstIP = localHostNetIPv4
rawConn = p.rawConnIPv4
} else {
// IPv6 path
if p.rawConnIPv6 == nil {
return fmt.Errorf("IPv6 raw socket not available")
}
ipv6 := &layers.IPv6{
DstIP: localHostNetIPv6,
SrcIP: endpointAddr.IP,
Version: 6,
HopLimit: 64,
NextHeader: layers.IPProtocolUDP,
}
ipH = ipv6
networkLayer = ipv6
dstIP = localHostNetIPv6
rawConn = p.rawConnIPv6
}
udpH := &layers.UDP{
SrcPort: layers.UDPPort(endpointAddr.Port),
DstPort: layers.UDPPort(p.localWGListenPort),
}
if err := udpH.SetNetworkLayerForChecksum(networkLayer); err != nil {
return fmt.Errorf("set network layer for checksum: %w", err)
}
layerBuffer := gopacket.NewSerializeBuffer()
payload := gopacket.Payload(data)
if err := gopacket.SerializeLayers(layerBuffer, serializeOpts, ipH, udpH, payload); err != nil {
return fmt.Errorf("serialize layers: %w", err)
}
if _, err := rawConn.WriteTo(layerBuffer.Bytes(), &net.IPAddr{IP: dstIP}); err != nil {
return fmt.Errorf("write to raw conn: %w", err)
}
return nil
}

View File

@@ -10,12 +10,89 @@ import (
"net"
"sync"
"github.com/google/gopacket"
"github.com/google/gopacket/layers"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/iface/bufsize"
"github.com/netbirdio/netbird/client/iface/wgproxy/listener"
)
var (
errIPv6ConnNotAvailable = errors.New("IPv6 endpoint but rawConnIPv6 is not available")
errIPv4ConnNotAvailable = errors.New("IPv4 endpoint but rawConnIPv4 is not available")
localHostNetIPv4 = net.ParseIP("127.0.0.1")
localHostNetIPv6 = net.ParseIP("::1")
serializeOpts = gopacket.SerializeOptions{
ComputeChecksums: true,
FixLengths: true,
}
)
// PacketHeaders holds pre-created headers and buffers for efficient packet sending
type PacketHeaders struct {
ipH gopacket.SerializableLayer
udpH *layers.UDP
layerBuffer gopacket.SerializeBuffer
localHostAddr net.IP
isIPv4 bool
}
func NewPacketHeaders(localWGListenPort int, endpoint *net.UDPAddr) (*PacketHeaders, error) {
var ipH gopacket.SerializableLayer
var networkLayer gopacket.NetworkLayer
var localHostAddr net.IP
var isIPv4 bool
// Check if source address is IPv4 or IPv6
if endpoint.IP.To4() != nil {
// IPv4 path
ipv4 := &layers.IPv4{
DstIP: localHostNetIPv4,
SrcIP: endpoint.IP,
Version: 4,
TTL: 64,
Protocol: layers.IPProtocolUDP,
}
ipH = ipv4
networkLayer = ipv4
localHostAddr = localHostNetIPv4
isIPv4 = true
} else {
// IPv6 path
ipv6 := &layers.IPv6{
DstIP: localHostNetIPv6,
SrcIP: endpoint.IP,
Version: 6,
HopLimit: 64,
NextHeader: layers.IPProtocolUDP,
}
ipH = ipv6
networkLayer = ipv6
localHostAddr = localHostNetIPv6
isIPv4 = false
}
udpH := &layers.UDP{
SrcPort: layers.UDPPort(endpoint.Port),
DstPort: layers.UDPPort(localWGListenPort),
}
if err := udpH.SetNetworkLayerForChecksum(networkLayer); err != nil {
return nil, fmt.Errorf("set network layer for checksum: %w", err)
}
return &PacketHeaders{
ipH: ipH,
udpH: udpH,
layerBuffer: gopacket.NewSerializeBuffer(),
localHostAddr: localHostAddr,
isIPv4: isIPv4,
}, nil
}
// ProxyWrapper help to keep the remoteConn instance for net.Conn.Close function call
type ProxyWrapper struct {
wgeBPFProxy *WGEBPFProxy
@@ -24,8 +101,10 @@ type ProxyWrapper struct {
ctx context.Context
cancel context.CancelFunc
wgRelayedEndpointAddr *net.UDPAddr
wgEndpointCurrentUsedAddr *net.UDPAddr
wgRelayedEndpointAddr *net.UDPAddr
headers *PacketHeaders
headerCurrentUsed *PacketHeaders
rawConn net.PacketConn
paused bool
pausedCond *sync.Cond
@@ -41,15 +120,32 @@ func NewProxyWrapper(proxy *WGEBPFProxy) *ProxyWrapper {
closeListener: listener.NewCloseListener(),
}
}
func (p *ProxyWrapper) AddTurnConn(ctx context.Context, _ *net.UDPAddr, remoteConn net.Conn) error {
addr, err := p.wgeBPFProxy.AddTurnConn(remoteConn)
if err != nil {
return fmt.Errorf("add turn conn: %w", err)
}
headers, err := NewPacketHeaders(p.wgeBPFProxy.localWGListenPort, addr)
if err != nil {
return fmt.Errorf("create packet sender: %w", err)
}
// Check if required raw connection is available
if !headers.isIPv4 && p.wgeBPFProxy.rawConnIPv6 == nil {
return errIPv6ConnNotAvailable
}
if headers.isIPv4 && p.wgeBPFProxy.rawConnIPv4 == nil {
return errIPv4ConnNotAvailable
}
p.remoteConn = remoteConn
p.ctx, p.cancel = context.WithCancel(ctx)
p.wgRelayedEndpointAddr = addr
return err
p.headers = headers
p.rawConn = p.selectRawConn(headers)
return nil
}
func (p *ProxyWrapper) EndpointAddr() *net.UDPAddr {
@@ -68,7 +164,8 @@ func (p *ProxyWrapper) Work() {
p.pausedCond.L.Lock()
p.paused = false
p.wgEndpointCurrentUsedAddr = p.wgRelayedEndpointAddr
p.headerCurrentUsed = p.headers
p.rawConn = p.selectRawConn(p.headerCurrentUsed)
if !p.isStarted {
p.isStarted = true
@@ -95,10 +192,28 @@ func (p *ProxyWrapper) RedirectAs(endpoint *net.UDPAddr) {
log.Errorf("failed to start package redirection, endpoint is nil")
return
}
header, err := NewPacketHeaders(p.wgeBPFProxy.localWGListenPort, endpoint)
if err != nil {
log.Errorf("failed to create packet headers: %s", err)
return
}
// Check if required raw connection is available
if !header.isIPv4 && p.wgeBPFProxy.rawConnIPv6 == nil {
log.Error(errIPv6ConnNotAvailable)
return
}
if header.isIPv4 && p.wgeBPFProxy.rawConnIPv4 == nil {
log.Error(errIPv4ConnNotAvailable)
return
}
p.pausedCond.L.Lock()
p.paused = false
p.wgEndpointCurrentUsedAddr = endpoint
p.headerCurrentUsed = header
p.rawConn = p.selectRawConn(header)
p.pausedCond.Signal()
p.pausedCond.L.Unlock()
@@ -140,7 +255,7 @@ func (p *ProxyWrapper) proxyToLocal(ctx context.Context) {
p.pausedCond.Wait()
}
err = p.wgeBPFProxy.sendPkg(buf[:n], p.wgEndpointCurrentUsedAddr)
err = p.sendPkg(buf[:n], p.headerCurrentUsed)
p.pausedCond.L.Unlock()
if err != nil {
@@ -166,3 +281,29 @@ func (p *ProxyWrapper) readFromRemote(ctx context.Context, buf []byte) (int, err
}
return n, nil
}
func (p *ProxyWrapper) sendPkg(data []byte, header *PacketHeaders) error {
defer func() {
if err := header.layerBuffer.Clear(); err != nil {
log.Errorf("failed to clear layer buffer: %s", err)
}
}()
payload := gopacket.Payload(data)
if err := gopacket.SerializeLayers(header.layerBuffer, serializeOpts, header.ipH, header.udpH, payload); err != nil {
return fmt.Errorf("serialize layers: %w", err)
}
if _, err := p.rawConn.WriteTo(header.layerBuffer.Bytes(), &net.IPAddr{IP: header.localHostAddr}); err != nil {
return fmt.Errorf("write to raw conn: %w", err)
}
return nil
}
func (p *ProxyWrapper) selectRawConn(header *PacketHeaders) net.PacketConn {
if header.isIPv4 {
return p.wgeBPFProxy.rawConnIPv4
}
return p.wgeBPFProxy.rawConnIPv6
}

View File

@@ -112,6 +112,54 @@ func TestHandlerChain_ServeDNS_DomainMatching(t *testing.T) {
matchSubdomains: false,
shouldMatch: false,
},
{
name: "single letter TLD exact match",
handlerDomain: "example.x.",
queryDomain: "example.x.",
isWildcard: false,
matchSubdomains: false,
shouldMatch: true,
},
{
name: "single letter TLD subdomain match",
handlerDomain: "example.x.",
queryDomain: "sub.example.x.",
isWildcard: false,
matchSubdomains: true,
shouldMatch: true,
},
{
name: "single letter TLD wildcard match",
handlerDomain: "*.example.x.",
queryDomain: "sub.example.x.",
isWildcard: true,
matchSubdomains: false,
shouldMatch: true,
},
{
name: "two letter domain labels",
handlerDomain: "a.b.",
queryDomain: "a.b.",
isWildcard: false,
matchSubdomains: false,
shouldMatch: true,
},
{
name: "single character domain",
handlerDomain: "x.",
queryDomain: "x.",
isWildcard: false,
matchSubdomains: false,
shouldMatch: true,
},
{
name: "single character domain with subdomain match",
handlerDomain: "x.",
queryDomain: "sub.x.",
isWildcard: false,
matchSubdomains: true,
shouldMatch: true,
},
}
for _, tt := range tests {

View File

@@ -9,8 +9,10 @@ import (
"io"
"net/netip"
"os/exec"
"slices"
"strconv"
"strings"
"sync"
log "github.com/sirupsen/logrus"
"golang.org/x/exp/maps"
@@ -38,6 +40,9 @@ const (
type systemConfigurator struct {
createdKeys map[string]struct{}
systemDNSSettings SystemDNSSettings
mu sync.RWMutex
origNameservers []netip.Addr
}
func newHostManager() (*systemConfigurator, error) {
@@ -218,6 +223,7 @@ func (s *systemConfigurator) getSystemDNSSettings() (SystemDNSSettings, error) {
}
var dnsSettings SystemDNSSettings
var serverAddresses []netip.Addr
inSearchDomainsArray := false
inServerAddressesArray := false
@@ -244,9 +250,12 @@ func (s *systemConfigurator) getSystemDNSSettings() (SystemDNSSettings, error) {
dnsSettings.Domains = append(dnsSettings.Domains, searchDomain)
} else if inServerAddressesArray {
address := strings.Split(line, " : ")[1]
if ip, err := netip.ParseAddr(address); err == nil && ip.Is4() {
dnsSettings.ServerIP = ip.Unmap()
inServerAddressesArray = false // Stop reading after finding the first IPv4 address
if ip, err := netip.ParseAddr(address); err == nil && !ip.IsUnspecified() {
ip = ip.Unmap()
serverAddresses = append(serverAddresses, ip)
if !dnsSettings.ServerIP.IsValid() && ip.Is4() {
dnsSettings.ServerIP = ip
}
}
}
}
@@ -258,9 +267,19 @@ func (s *systemConfigurator) getSystemDNSSettings() (SystemDNSSettings, error) {
// default to 53 port
dnsSettings.ServerPort = DefaultPort
s.mu.Lock()
s.origNameservers = serverAddresses
s.mu.Unlock()
return dnsSettings, nil
}
func (s *systemConfigurator) getOriginalNameservers() []netip.Addr {
s.mu.RLock()
defer s.mu.RUnlock()
return slices.Clone(s.origNameservers)
}
func (s *systemConfigurator) addSearchDomains(key, domains string, ip netip.Addr, port int) error {
err := s.addDNSState(key, domains, ip, port, true)
if err != nil {

View File

@@ -109,3 +109,169 @@ func removeTestDNSKey(key string) error {
_, err := cmd.CombinedOutput()
return err
}
func TestGetOriginalNameservers(t *testing.T) {
configurator := &systemConfigurator{
createdKeys: make(map[string]struct{}),
origNameservers: []netip.Addr{
netip.MustParseAddr("8.8.8.8"),
netip.MustParseAddr("1.1.1.1"),
},
}
servers := configurator.getOriginalNameservers()
assert.Len(t, servers, 2)
assert.Equal(t, netip.MustParseAddr("8.8.8.8"), servers[0])
assert.Equal(t, netip.MustParseAddr("1.1.1.1"), servers[1])
}
func TestGetOriginalNameserversFromSystem(t *testing.T) {
configurator := &systemConfigurator{
createdKeys: make(map[string]struct{}),
}
_, err := configurator.getSystemDNSSettings()
require.NoError(t, err)
servers := configurator.getOriginalNameservers()
require.NotEmpty(t, servers, "expected at least one DNS server from system configuration")
for _, server := range servers {
assert.True(t, server.IsValid(), "server address should be valid")
assert.False(t, server.IsUnspecified(), "server address should not be unspecified")
}
t.Logf("found %d original nameservers: %v", len(servers), servers)
}
func setupTestConfigurator(t *testing.T) (*systemConfigurator, *statemanager.Manager, func()) {
t.Helper()
tmpDir := t.TempDir()
stateFile := filepath.Join(tmpDir, "state.json")
sm := statemanager.New(stateFile)
sm.RegisterState(&ShutdownState{})
sm.Start()
configurator := &systemConfigurator{
createdKeys: make(map[string]struct{}),
}
searchKey := getKeyWithInput(netbirdDNSStateKeyFormat, searchSuffix)
matchKey := getKeyWithInput(netbirdDNSStateKeyFormat, matchSuffix)
localKey := getKeyWithInput(netbirdDNSStateKeyFormat, localSuffix)
cleanup := func() {
_ = sm.Stop(context.Background())
for _, key := range []string{searchKey, matchKey, localKey} {
_ = removeTestDNSKey(key)
}
}
return configurator, sm, cleanup
}
func TestOriginalNameserversNoTransition(t *testing.T) {
netbirdIP := netip.MustParseAddr("100.64.0.1")
testCases := []struct {
name string
routeAll bool
}{
{"routeall_false", false},
{"routeall_true", true},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
configurator, sm, cleanup := setupTestConfigurator(t)
defer cleanup()
_, err := configurator.getSystemDNSSettings()
require.NoError(t, err)
initialServers := configurator.getOriginalNameservers()
t.Logf("Initial servers: %v", initialServers)
require.NotEmpty(t, initialServers)
for _, srv := range initialServers {
require.NotEqual(t, netbirdIP, srv, "initial servers should not contain NetBird IP")
}
config := HostDNSConfig{
ServerIP: netbirdIP,
ServerPort: 53,
RouteAll: tc.routeAll,
Domains: []DomainConfig{{Domain: "example.com", MatchOnly: true}},
}
for i := 1; i <= 2; i++ {
err = configurator.applyDNSConfig(config, sm)
require.NoError(t, err)
servers := configurator.getOriginalNameservers()
t.Logf("After apply %d (RouteAll=%v): %v", i, tc.routeAll, servers)
assert.Equal(t, initialServers, servers)
}
})
}
}
func TestOriginalNameserversRouteAllTransition(t *testing.T) {
netbirdIP := netip.MustParseAddr("100.64.0.1")
testCases := []struct {
name string
initialRoute bool
}{
{"start_with_routeall_false", false},
{"start_with_routeall_true", true},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
configurator, sm, cleanup := setupTestConfigurator(t)
defer cleanup()
_, err := configurator.getSystemDNSSettings()
require.NoError(t, err)
initialServers := configurator.getOriginalNameservers()
t.Logf("Initial servers: %v", initialServers)
require.NotEmpty(t, initialServers)
config := HostDNSConfig{
ServerIP: netbirdIP,
ServerPort: 53,
RouteAll: tc.initialRoute,
Domains: []DomainConfig{{Domain: "example.com", MatchOnly: true}},
}
// First apply
err = configurator.applyDNSConfig(config, sm)
require.NoError(t, err)
servers := configurator.getOriginalNameservers()
t.Logf("After first apply (RouteAll=%v): %v", tc.initialRoute, servers)
assert.Equal(t, initialServers, servers)
// Toggle RouteAll
config.RouteAll = !tc.initialRoute
err = configurator.applyDNSConfig(config, sm)
require.NoError(t, err)
servers = configurator.getOriginalNameservers()
t.Logf("After toggle (RouteAll=%v): %v", config.RouteAll, servers)
assert.Equal(t, initialServers, servers)
// Toggle back
config.RouteAll = tc.initialRoute
err = configurator.applyDNSConfig(config, sm)
require.NoError(t, err)
servers = configurator.getOriginalNameservers()
t.Logf("After toggle back (RouteAll=%v): %v", config.RouteAll, servers)
assert.Equal(t, initialServers, servers)
for _, srv := range servers {
assert.NotEqual(t, netbirdIP, srv, "servers should not contain NetBird IP")
}
})
}
}

View File

@@ -615,7 +615,7 @@ func (s *DefaultServer) applyHostConfig() {
s.registerFallback(config)
}
// registerFallback registers original nameservers as low-priority fallback handlers
// registerFallback registers original nameservers as low-priority fallback handlers.
func (s *DefaultServer) registerFallback(config HostDNSConfig) {
hostMgrWithNS, ok := s.hostManager.(hostManagerWithOriginalNS)
if !ok {
@@ -624,6 +624,7 @@ func (s *DefaultServer) registerFallback(config HostDNSConfig) {
originalNameservers := hostMgrWithNS.getOriginalNameservers()
if len(originalNameservers) == 0 {
s.deregisterHandler([]string{nbdns.RootZone}, PriorityFallback)
return
}

View File

@@ -8,15 +8,21 @@ import (
type MockResponseWriter struct {
WriteMsgFunc func(m *dns.Msg) error
lastResponse *dns.Msg
}
func (rw *MockResponseWriter) WriteMsg(m *dns.Msg) error {
rw.lastResponse = m
if rw.WriteMsgFunc != nil {
return rw.WriteMsgFunc(m)
}
return nil
}
func (rw *MockResponseWriter) GetLastResponse() *dns.Msg {
return rw.lastResponse
}
func (rw *MockResponseWriter) LocalAddr() net.Addr { return nil }
func (rw *MockResponseWriter) RemoteAddr() net.Addr { return nil }
func (rw *MockResponseWriter) Write([]byte) (int, error) { return 0, nil }

View File

@@ -44,6 +44,7 @@ import (
icemaker "github.com/netbirdio/netbird/client/internal/peer/ice"
"github.com/netbirdio/netbird/client/internal/peerstore"
"github.com/netbirdio/netbird/client/internal/profilemanager"
"github.com/netbirdio/netbird/client/internal/proxy"
"github.com/netbirdio/netbird/client/internal/relay"
"github.com/netbirdio/netbird/client/internal/rosenpass"
"github.com/netbirdio/netbird/client/internal/routemanager"
@@ -140,6 +141,11 @@ type EngineConfig struct {
ProfileConfig *profilemanager.Config
LogPath string
// ProxyConfig contains system proxy settings for macOS
ProxyEnabled bool
ProxyHost string
ProxyPort int
}
// Engine is a mechanism responsible for reacting on Signal and Management stream events and managing connections to the remote peers.
@@ -223,6 +229,9 @@ type Engine struct {
jobExecutor *jobexec.Executor
jobExecutorWG sync.WaitGroup
// proxyManager manages system-wide browser proxy settings on macOS
proxyManager *proxy.Manager
}
// Peer is an instance of the Connection Peer
@@ -313,6 +322,12 @@ func (e *Engine) Stop() error {
e.updateManager.Stop()
}
if e.proxyManager != nil {
if err := e.proxyManager.DisableWebProxy(); err != nil {
log.Warnf("failed to disable system proxy: %v", err)
}
}
log.Info("cleaning up status recorder states")
e.statusRecorder.ReplaceOfflinePeers([]peer.State{})
e.statusRecorder.UpdateDNSStates([]peer.NSGroupState{})
@@ -448,6 +463,10 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
}
e.stateManager.Start()
// Initialize proxy manager and register state for cleanup
proxy.RegisterState(e.stateManager)
e.proxyManager = proxy.NewManager(e.stateManager)
initialRoutes, dnsConfig, dnsFeatureFlag, err := e.readInitialSettings()
if err != nil {
e.close()
@@ -573,9 +592,11 @@ func (e *Engine) createFirewall() error {
var err error
e.firewall, err = firewall.NewFirewall(e.wgInterface, e.stateManager, e.flowManager.GetLogger(), e.config.DisableServerRoutes, e.config.MTU)
if err != nil || e.firewall == nil {
log.Errorf("failed creating firewall manager: %s", err)
return nil
if err != nil {
return fmt.Errorf("create firewall manager: %w", err)
}
if e.firewall == nil {
return fmt.Errorf("create firewall manager: received nil manager")
}
if err := e.initFirewall(); err != nil {
@@ -1310,6 +1331,9 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
// If no server of a server group responds this will disable the respective handler and retry later.
e.dnsServer.ProbeAvailability()
// Update system proxy state based on routes after network map is fully applied
e.updateSystemProxy(clientRoutes)
return nil
}
@@ -2301,6 +2325,26 @@ func createFile(path string) error {
return file.Close()
}
// updateSystemProxy triggers a proxy enable/disable cycle after the network map is updated.
func (e *Engine) updateSystemProxy(clientRoutes route.HAMap) {
if runtime.GOOS != "darwin" || e.proxyManager == nil {
log.Errorf("not updating proxy")
return
}
if err := e.proxyManager.EnableWebProxy(e.config.ProxyHost, e.config.ProxyPort); err != nil {
log.Errorf("enable system proxy: %v", err)
return
}
log.Error("system proxy enabled after network map update")
if err := e.proxyManager.DisableWebProxy(); err != nil {
log.Errorf("disable system proxy: %v", err)
return
}
log.Error("system proxy disabled after network map update")
}
func convertToOfferAnswer(msg *sProto.Message) (*peer.OfferAnswer, error) {
remoteCred, err := signal.UnMarshalCredential(msg)
if err != nil {

View File

@@ -14,6 +14,7 @@ import (
"github.com/cenkalti/backoff/v4"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/iface/netstack"
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
)
@@ -37,6 +38,11 @@ func New() *NetworkMonitor {
// Listen begins monitoring network changes. When a change is detected, this function will return without error.
func (nw *NetworkMonitor) Listen(ctx context.Context) (err error) {
if netstack.IsEnabled() {
log.Debugf("Network monitor: skipping in netstack mode")
return nil
}
nw.mu.Lock()
if nw.cancel != nil {
nw.mu.Unlock()

View File

@@ -390,6 +390,8 @@ func (conn *Conn) onICEConnectionIsReady(priority conntype.ConnPriority, iceConn
}
conn.Log.Infof("configure WireGuard endpoint to: %s", ep.String())
conn.enableWgWatcherIfNeeded()
presharedKey := conn.presharedKey(iceConnInfo.RosenpassPubKey)
if err = conn.endpointUpdater.ConfigureWGEndpoint(ep, presharedKey); err != nil {
conn.handleConfigurationFailure(err, wgProxy)
@@ -402,8 +404,6 @@ func (conn *Conn) onICEConnectionIsReady(priority conntype.ConnPriority, iceConn
conn.wgProxyRelay.RedirectAs(ep)
}
conn.enableWgWatcherIfNeeded()
conn.currentConnPriority = priority
conn.statusICE.SetConnected()
conn.updateIceState(iceConnInfo)
@@ -501,6 +501,9 @@ func (conn *Conn) onRelayConnectionIsReady(rci RelayConnInfo) {
wgProxy.Work()
presharedKey := conn.presharedKey(rci.rosenpassPubKey)
conn.enableWgWatcherIfNeeded()
if err := conn.endpointUpdater.ConfigureWGEndpoint(wgProxy.EndpointAddr(), presharedKey); err != nil {
if err := wgProxy.CloseConn(); err != nil {
conn.Log.Warnf("Failed to close relay connection: %v", err)
@@ -509,8 +512,6 @@ func (conn *Conn) onRelayConnectionIsReady(rci RelayConnInfo) {
return
}
conn.enableWgWatcherIfNeeded()
wgConfigWorkaround()
conn.rosenpassRemoteKey = rci.rosenpassPubKey
conn.currentConnPriority = conntype.Relay

View File

@@ -0,0 +1,262 @@
//go:build darwin && !ios
package proxy
import (
"fmt"
"os/exec"
"strings"
"sync"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/statemanager"
)
const networksetupPath = "/usr/sbin/networksetup"
// Manager handles system-wide proxy configuration on macOS.
type Manager struct {
mu sync.Mutex
stateManager *statemanager.Manager
modifiedServices []string
enabled bool
}
// NewManager creates a new proxy manager.
func NewManager(stateManager *statemanager.Manager) *Manager {
return &Manager{
stateManager: stateManager,
}
}
// GetActiveNetworkServices returns the list of active network services.
func GetActiveNetworkServices() ([]string, error) {
cmd := exec.Command(networksetupPath, "-listallnetworkservices")
out, err := cmd.Output()
if err != nil {
return nil, fmt.Errorf("list network services: %w", err)
}
lines := strings.Split(string(out), "\n")
var services []string
for _, line := range lines {
line = strings.TrimSpace(line)
if line == "" || strings.HasPrefix(line, "*") || strings.Contains(line, "asterisk") {
continue
}
services = append(services, line)
}
return services, nil
}
// EnableWebProxy enables web proxy for all active network services.
func (m *Manager) EnableWebProxy(host string, port int) error {
m.mu.Lock()
defer m.mu.Unlock()
if m.enabled {
log.Debug("web proxy already enabled")
return nil
}
services, err := GetActiveNetworkServices()
if err != nil {
return err
}
var modifiedServices []string
for _, service := range services {
if err := m.enableProxyForService(service, host, port); err != nil {
log.Warnf("enable proxy for %s: %v", service, err)
continue
}
modifiedServices = append(modifiedServices, service)
}
m.modifiedServices = modifiedServices
m.enabled = true
m.updateState()
log.Infof("enabled web proxy on %d services -> %s:%d", len(modifiedServices), host, port)
return nil
}
func (m *Manager) enableProxyForService(service, host string, port int) error {
portStr := fmt.Sprintf("%d", port)
// Set web proxy (HTTP)
cmd := exec.Command(networksetupPath, "-setwebproxy", service, host, portStr)
if out, err := cmd.CombinedOutput(); err != nil {
return fmt.Errorf("set web proxy: %w, output: %s", err, out)
}
// Enable web proxy
cmd = exec.Command(networksetupPath, "-setwebproxystate", service, "on")
if out, err := cmd.CombinedOutput(); err != nil {
return fmt.Errorf("enable web proxy state: %w, output: %s", err, out)
}
// Set secure web proxy (HTTPS)
cmd = exec.Command(networksetupPath, "-setsecurewebproxy", service, host, portStr)
if out, err := cmd.CombinedOutput(); err != nil {
return fmt.Errorf("set secure web proxy: %w, output: %s", err, out)
}
// Enable secure web proxy
cmd = exec.Command(networksetupPath, "-setsecurewebproxystate", service, "on")
if out, err := cmd.CombinedOutput(); err != nil {
return fmt.Errorf("enable secure web proxy state: %w, output: %s", err, out)
}
log.Debugf("enabled proxy for service %s", service)
return nil
}
// DisableWebProxy disables web proxy for all modified network services.
func (m *Manager) DisableWebProxy() error {
m.mu.Lock()
defer m.mu.Unlock()
if !m.enabled {
log.Debug("web proxy already disabled")
return nil
}
services := m.modifiedServices
if len(services) == 0 {
services, _ = GetActiveNetworkServices()
}
for _, service := range services {
if err := m.disableProxyForService(service); err != nil {
log.Warnf("disable proxy for %s: %v", service, err)
}
}
m.modifiedServices = nil
m.enabled = false
m.updateState()
log.Info("disabled web proxy")
return nil
}
func (m *Manager) disableProxyForService(service string) error {
// Disable web proxy (HTTP)
cmd := exec.Command(networksetupPath, "-setwebproxystate", service, "off")
if out, err := cmd.CombinedOutput(); err != nil {
return fmt.Errorf("disable web proxy: %w, output: %s", err, out)
}
// Disable secure web proxy (HTTPS)
cmd = exec.Command(networksetupPath, "-setsecurewebproxystate", service, "off")
if out, err := cmd.CombinedOutput(); err != nil {
return fmt.Errorf("disable secure web proxy: %w, output: %s", err, out)
}
log.Debugf("disabled proxy for service %s", service)
return nil
}
// SetAutoproxyURL sets the automatic proxy configuration URL (PAC file).
func (m *Manager) SetAutoproxyURL(pacURL string) error {
m.mu.Lock()
defer m.mu.Unlock()
services, err := GetActiveNetworkServices()
if err != nil {
return err
}
var modifiedServices []string
for _, service := range services {
cmd := exec.Command(networksetupPath, "-setautoproxyurl", service, pacURL)
if out, err := cmd.CombinedOutput(); err != nil {
log.Warnf("set autoproxy for %s: %v, output: %s", service, err, out)
continue
}
cmd = exec.Command(networksetupPath, "-setautoproxystate", service, "on")
if out, err := cmd.CombinedOutput(); err != nil {
log.Warnf("enable autoproxy for %s: %v, output: %s", service, err, out)
continue
}
modifiedServices = append(modifiedServices, service)
log.Debugf("set autoproxy URL for %s -> %s", service, pacURL)
}
m.modifiedServices = modifiedServices
m.enabled = true
m.updateState()
return nil
}
// DisableAutoproxy disables automatic proxy configuration.
func (m *Manager) DisableAutoproxy() error {
m.mu.Lock()
defer m.mu.Unlock()
services := m.modifiedServices
if len(services) == 0 {
services, _ = GetActiveNetworkServices()
}
for _, service := range services {
cmd := exec.Command(networksetupPath, "-setautoproxystate", service, "off")
if out, err := cmd.CombinedOutput(); err != nil {
log.Warnf("disable autoproxy for %s: %v, output: %s", service, err, out)
}
}
m.modifiedServices = nil
m.enabled = false
m.updateState()
return nil
}
// IsEnabled returns whether the proxy is currently enabled.
func (m *Manager) IsEnabled() bool {
m.mu.Lock()
defer m.mu.Unlock()
return m.enabled
}
// Restore restores proxy settings from a previous state.
func (m *Manager) Restore(services []string) error {
m.mu.Lock()
defer m.mu.Unlock()
for _, service := range services {
if err := m.disableProxyForService(service); err != nil {
log.Warnf("restore proxy for %s: %v", service, err)
}
}
m.modifiedServices = nil
m.enabled = false
return nil
}
func (m *Manager) updateState() {
if m.stateManager == nil {
return
}
if m.enabled && len(m.modifiedServices) > 0 {
state := &ShutdownState{
ModifiedServices: m.modifiedServices,
}
if err := m.stateManager.UpdateState(state); err != nil {
log.Errorf("update proxy state: %v", err)
}
} else {
if err := m.stateManager.DeleteState(&ShutdownState{}); err != nil {
log.Debugf("delete proxy state: %v", err)
}
}
}

View File

@@ -0,0 +1,45 @@
//go:build !darwin || ios
package proxy
import (
"github.com/netbirdio/netbird/client/internal/statemanager"
)
// Manager is a no-op proxy manager for non-macOS platforms.
type Manager struct{}
// NewManager creates a new proxy manager (no-op on non-macOS).
func NewManager(_ *statemanager.Manager) *Manager {
return &Manager{}
}
// EnableWebProxy is a no-op on non-macOS platforms.
func (m *Manager) EnableWebProxy(host string, port int) error {
return nil
}
// DisableWebProxy is a no-op on non-macOS platforms.
func (m *Manager) DisableWebProxy() error {
return nil
}
// SetAutoproxyURL is a no-op on non-macOS platforms.
func (m *Manager) SetAutoproxyURL(pacURL string) error {
return nil
}
// DisableAutoproxy is a no-op on non-macOS platforms.
func (m *Manager) DisableAutoproxy() error {
return nil
}
// IsEnabled always returns false on non-macOS platforms.
func (m *Manager) IsEnabled() bool {
return false
}
// Restore is a no-op on non-macOS platforms.
func (m *Manager) Restore(services []string) error {
return nil
}

View File

@@ -0,0 +1,88 @@
//go:build darwin && !ios
package proxy
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestGetActiveNetworkServices(t *testing.T) {
services, err := GetActiveNetworkServices()
assert.NoError(t, err)
assert.NotEmpty(t, services, "should have at least one network service")
// Check that services don't contain invalid entries
for _, service := range services {
assert.NotEmpty(t, service)
assert.NotContains(t, service, "*")
}
}
func TestManager_EnableDisableWebProxy(t *testing.T) {
// Skip this test in CI as it requires admin privileges
if testing.Short() {
t.Skip("skipping proxy test in short mode")
}
m := NewManager(nil)
assert.NotNil(t, m)
assert.False(t, m.IsEnabled())
// This test would require admin privileges to actually enable the proxy
// So we just test the basic state management
}
func TestShutdownState_Name(t *testing.T) {
state := &ShutdownState{}
assert.Equal(t, "proxy_state", state.Name())
}
func TestShutdownState_Cleanup_EmptyServices(t *testing.T) {
state := &ShutdownState{
ModifiedServices: []string{},
}
err := state.Cleanup()
assert.NoError(t, err)
}
func TestContains(t *testing.T) {
tests := []struct {
s string
substr string
want bool
}{
{"Enabled: Yes", "Enabled: Yes", true},
{"Enabled: No", "Enabled: Yes", false},
{"Server: 127.0.0.1\nEnabled: Yes\nPort: 8080", "Enabled: Yes", true},
{"", "Enabled: Yes", false},
{"Enabled: Yes", "", true},
}
for _, tt := range tests {
t.Run(tt.s+"_"+tt.substr, func(t *testing.T) {
got := contains(tt.s, tt.substr)
assert.Equal(t, tt.want, got)
})
}
}
func TestIsProxyEnabled(t *testing.T) {
tests := []struct {
output string
want bool
}{
{"Enabled: Yes\nServer: 127.0.0.1\nPort: 8080", true},
{"Enabled: No\nServer: \nPort: 0", false},
{"Server: 127.0.0.1\nEnabled: Yes\nPort: 8080", true},
{"", false},
}
for _, tt := range tests {
t.Run(tt.output, func(t *testing.T) {
got := isProxyEnabled(tt.output)
assert.Equal(t, tt.want, got)
})
}
}

View File

@@ -0,0 +1,105 @@
//go:build darwin && !ios
package proxy
import (
"fmt"
"os/exec"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/statemanager"
)
// ShutdownState stores proxy state for cleanup on unclean shutdown.
type ShutdownState struct {
ModifiedServices []string `json:"modified_services"`
}
// Name returns the state name for persistence.
func (s *ShutdownState) Name() string {
return "proxy_state"
}
// Cleanup restores proxy settings after an unclean shutdown.
func (s *ShutdownState) Cleanup() error {
if len(s.ModifiedServices) == 0 {
return nil
}
log.Infof("cleaning up proxy state for %d services", len(s.ModifiedServices))
for _, service := range s.ModifiedServices {
// Disable web proxy (HTTP)
cmd := exec.Command(networksetupPath, "-setwebproxystate", service, "off")
if out, err := cmd.CombinedOutput(); err != nil {
log.Warnf("cleanup web proxy for %s: %v, output: %s", service, err, out)
}
// Disable secure web proxy (HTTPS)
cmd = exec.Command(networksetupPath, "-setsecurewebproxystate", service, "off")
if out, err := cmd.CombinedOutput(); err != nil {
log.Warnf("cleanup secure web proxy for %s: %v, output: %s", service, err, out)
}
// Disable autoproxy
cmd = exec.Command(networksetupPath, "-setautoproxystate", service, "off")
if out, err := cmd.CombinedOutput(); err != nil {
log.Warnf("cleanup autoproxy for %s: %v, output: %s", service, err, out)
}
log.Debugf("cleaned up proxy for service %s", service)
}
return nil
}
// RegisterState registers the proxy state with the state manager.
func RegisterState(stateManager *statemanager.Manager) {
if stateManager == nil {
return
}
stateManager.RegisterState(&ShutdownState{})
}
// GetProxyState returns the current proxy state from the command line.
func GetProxyState(service string) (webProxy, secureProxy, autoProxy bool, err error) {
// Check web proxy state
cmd := exec.Command(networksetupPath, "-getwebproxy", service)
out, err := cmd.Output()
if err != nil {
return false, false, false, fmt.Errorf("get web proxy: %w", err)
}
webProxy = isProxyEnabled(string(out))
// Check secure web proxy state
cmd = exec.Command(networksetupPath, "-getsecurewebproxy", service)
out, err = cmd.Output()
if err != nil {
return false, false, false, fmt.Errorf("get secure web proxy: %w", err)
}
secureProxy = isProxyEnabled(string(out))
// Check autoproxy state
cmd = exec.Command(networksetupPath, "-getautoproxyurl", service)
out, err = cmd.Output()
if err != nil {
return false, false, false, fmt.Errorf("get autoproxy: %w", err)
}
autoProxy = isProxyEnabled(string(out))
return webProxy, secureProxy, autoProxy, nil
}
func isProxyEnabled(output string) bool {
return !contains(output, "Enabled: No") && contains(output, "Enabled: Yes")
}
func contains(s, substr string) bool {
for i := 0; i <= len(s)-len(substr); i++ {
if s[i:i+len(substr)] == substr {
return true
}
}
return false
}

View File

@@ -0,0 +1,24 @@
//go:build !darwin || ios
package proxy
import (
"github.com/netbirdio/netbird/client/internal/statemanager"
)
// ShutdownState is a no-op state for non-macOS platforms.
type ShutdownState struct{}
// Name returns the state name.
func (s *ShutdownState) Name() string {
return "proxy_state"
}
// Cleanup is a no-op on non-macOS platforms.
func (s *ShutdownState) Cleanup() error {
return nil
}
// RegisterState is a no-op on non-macOS platforms.
func RegisterState(stateManager *statemanager.Manager) {
}

View File

@@ -9,6 +9,8 @@ import (
"time"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/iface/netstack"
)
// WGIfaceMonitor monitors the WireGuard interface lifecycle and restarts the engine
@@ -35,6 +37,11 @@ func (m *WGIfaceMonitor) Start(ctx context.Context, ifaceName string) (shouldRes
return false, errors.New("not supported on mobile platforms")
}
if netstack.IsEnabled() {
log.Debugf("Interface monitor: skipped in netstack mode")
return false, nil
}
if ifaceName == "" {
log.Debugf("Interface monitor: empty interface name, skipping monitor")
return false, errors.New("empty interface name")

View File

@@ -16,13 +16,13 @@ import (
"strings"
"syscall"
"github.com/miekg/dns"
log "github.com/sirupsen/logrus"
"github.com/spf13/cobra"
"github.com/netbirdio/netbird/formatter/hook"
"github.com/netbirdio/netbird/management/internals/server"
nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
nbdomain "github.com/netbirdio/netbird/shared/management/domain"
"github.com/netbirdio/netbird/util"
"github.com/netbirdio/netbird/util/crypt"
)
@@ -78,9 +78,8 @@ var (
}
}
_, valid := dns.IsDomainName(dnsDomain)
if !valid || len(dnsDomain) > 192 {
return fmt.Errorf("failed parsing the provided dns-domain. Valid status: %t, Length: %d", valid, len(dnsDomain))
if !nbdomain.IsValidDomainNoWildcard(dnsDomain) {
return fmt.Errorf("invalid dns-domain: %s", dnsDomain)
}
return nil

View File

@@ -187,10 +187,10 @@ func (e *EphemeralManager) cleanup(ctx context.Context) {
}
for accountID, peerIDs := range peerIDsPerAccount {
log.WithContext(ctx).Debugf("delete ephemeral peers for account: %s", accountID)
log.WithContext(ctx).Tracef("cleanup: deleting %d ephemeral peers for account %s", len(peerIDs), accountID)
err := e.peersManager.DeletePeers(ctx, accountID, peerIDs, activity.SystemInitiator, true)
if err != nil {
log.WithContext(ctx).Errorf("failed to delete ephemeral peer: %s", err)
log.WithContext(ctx).Errorf("failed to delete ephemeral peers: %s", err)
}
}
}

View File

@@ -108,10 +108,19 @@ func (m *managerImpl) DeletePeers(ctx context.Context, accountID string, peerIDs
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
peer, err := transaction.GetPeerByID(ctx, store.LockingStrengthNone, accountID, peerID)
if err != nil {
if e, ok := status.FromError(err); ok && e.Type() == status.NotFound {
log.WithContext(ctx).Tracef("DeletePeers: peer %s not found, skipping", peerID)
return nil
}
return err
}
if checkConnected && (peer.Status.Connected || peer.Status.LastSeen.After(time.Now().Add(-(ephemeral.EphemeralLifeTime - 10*time.Second)))) {
log.WithContext(ctx).Tracef("DeletePeers: peer %s skipped (connected=%t, lastSeen=%s, threshold=%s, ephemeral=%t)",
peerID, peer.Status.Connected,
peer.Status.LastSeen.Format(time.RFC3339),
time.Now().Add(-(ephemeral.EphemeralLifeTime - 10*time.Second)).Format(time.RFC3339),
peer.Ephemeral)
return nil
}
@@ -150,7 +159,8 @@ func (m *managerImpl) DeletePeers(ctx context.Context, accountID string, peerIDs
return nil
})
if err != nil {
return err
log.WithContext(ctx).Errorf("DeletePeers: failed to delete peer %s: %v", peerID, err)
continue
}
if m.integratedPeerValidator != nil {

View File

@@ -6,7 +6,7 @@ import (
"github.com/rs/xid"
"github.com/netbirdio/netbird/management/server/util"
"github.com/netbirdio/netbird/shared/management/domain"
"github.com/netbirdio/netbird/shared/management/http/api"
)
@@ -63,7 +63,7 @@ func (r *Record) Validate() error {
return errors.New("record name is required")
}
if !util.IsValidDomain(r.Name) {
if !domain.IsValidDomain(r.Name) {
return errors.New("invalid record name format")
}
@@ -81,8 +81,8 @@ func (r *Record) Validate() error {
return err
}
case RecordTypeCNAME:
if !util.IsValidDomain(r.Content) {
return errors.New("invalid CNAME record format")
if !domain.IsValidDomainNoWildcard(r.Content) {
return errors.New("invalid CNAME target format")
}
default:
return errors.New("invalid record type, must be A, AAAA, or CNAME")

View File

@@ -6,7 +6,7 @@ import (
"github.com/rs/xid"
"github.com/netbirdio/netbird/management/internals/modules/zones/records"
"github.com/netbirdio/netbird/management/server/util"
"github.com/netbirdio/netbird/shared/management/domain"
"github.com/netbirdio/netbird/shared/management/http/api"
)
@@ -73,7 +73,7 @@ func (z *Zone) Validate() error {
return errors.New("zone name exceeds maximum length of 255 characters")
}
if !util.IsValidDomain(z.Domain) {
if !domain.IsValidDomainNoWildcard(z.Domain) {
return errors.New("invalid zone domain format")
}

View File

@@ -17,13 +17,14 @@ import (
pb "github.com/golang/protobuf/proto" // nolint
"github.com/golang/protobuf/ptypes/timestamp"
"github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/realip"
"github.com/netbirdio/netbird/shared/management/client/common"
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/peer"
"google.golang.org/grpc/status"
"github.com/netbirdio/netbird/shared/management/client/common"
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
"github.com/netbirdio/netbird/management/server/idp"
@@ -304,6 +305,7 @@ func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_S
if err != nil {
log.WithContext(ctx).Debugf("error while sending initial sync for %s: %v", peerKey.String(), err)
s.syncSem.Add(-1)
s.cancelPeerRoutines(ctx, accountID, peer)
return err
}

View File

@@ -26,6 +26,7 @@ import (
"golang.org/x/exp/maps"
nbdns "github.com/netbirdio/netbird/dns"
nbdomain "github.com/netbirdio/netbird/shared/management/domain"
"github.com/netbirdio/netbird/formatter/hook"
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
@@ -231,7 +232,7 @@ func BuildManager(
// enable single account mode only if configured by user and number of existing accounts is not grater than 1
am.singleAccountMode = singleAccountModeDomain != "" && accountsCounter <= 1
if am.singleAccountMode {
if !isDomainValid(singleAccountModeDomain) {
if !nbdomain.IsValidDomainNoWildcard(singleAccountModeDomain) {
return nil, status.Errorf(status.InvalidArgument, "invalid domain \"%s\" provided for a single account mode. Please review your input for --single-account-mode-domain", singleAccountModeDomain)
}
am.singleAccountModeDomain = singleAccountModeDomain
@@ -402,7 +403,7 @@ func (am *DefaultAccountManager) validateSettingsUpdate(ctx context.Context, tra
return status.Errorf(status.InvalidArgument, "peer login expiration can't be smaller than one hour")
}
if newSettings.DNSDomain != "" && !isDomainValid(newSettings.DNSDomain) {
if newSettings.DNSDomain != "" && !nbdomain.IsValidDomainNoWildcard(newSettings.DNSDomain) {
return status.Errorf(status.InvalidArgument, "invalid domain \"%s\" provided for DNS domain", newSettings.DNSDomain)
}
@@ -1691,10 +1692,12 @@ func (am *DefaultAccountManager) SyncPeerMeta(ctx context.Context, peerPubKey st
return nil
}
var invalidDomainRegexp = regexp.MustCompile(`^([a-z0-9]+(-[a-z0-9]+)*\.)+[a-z]{2,}$`)
// isDomainValid validates public/IDP domains using stricter rules than internal DNS domains.
// Requires at least 2-char alphabetic TLD and no single-label domains.
var publicDomainRegexp = regexp.MustCompile(`^([a-z0-9]+(-[a-z0-9]+)*\.)+[a-z]{2,}$`)
func isDomainValid(domain string) bool {
return invalidDomainRegexp.MatchString(domain)
return publicDomainRegexp.MatchString(domain)
}
func (am *DefaultAccountManager) onPeersInvalidated(ctx context.Context, accountID string, peerIDs []string) {

View File

@@ -3,10 +3,10 @@ package server
import (
"context"
"errors"
"regexp"
"fmt"
"strings"
"unicode/utf8"
"github.com/miekg/dns"
"github.com/rs/xid"
nbdns "github.com/netbirdio/netbird/dns"
@@ -15,11 +15,10 @@ import (
"github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
nbdomain "github.com/netbirdio/netbird/shared/management/domain"
"github.com/netbirdio/netbird/shared/management/status"
)
const domainPattern = `^(?i)[a-z0-9]+([\-\.]{1}[a-z0-9]+)*[*.a-z]{1,}$`
var errInvalidDomainName = errors.New("invalid domain name")
// GetNameServerGroup gets a nameserver group object from account and nameserver group IDs
@@ -305,16 +304,18 @@ func validateGroups(list []string, groups map[string]*types.Group) error {
return nil
}
var domainMatcher = regexp.MustCompile(domainPattern)
func validateDomain(domain string) error {
if !domainMatcher.MatchString(domain) {
return errors.New("domain should consists of only letters, numbers, and hyphens with no leading, trailing hyphens, or spaces")
// validateDomain validates a nameserver match domain.
// Converts unicode to punycode. Wildcards are not allowed for nameservers.
func validateDomain(d string) error {
if strings.HasPrefix(d, "*.") {
return errors.New("wildcards not allowed")
}
_, valid := dns.IsDomainName(domain)
if !valid {
return errInvalidDomainName
// Nameservers allow trailing dot (FQDN format)
toValidate := strings.TrimSuffix(d, ".")
if _, err := nbdomain.ValidateDomains([]string{toValidate}); err != nil {
return fmt.Errorf("%w: %w", errInvalidDomainName, err)
}
return nil

View File

@@ -901,82 +901,53 @@ func initTestNSAccount(t *testing.T, am *DefaultAccountManager) (*types.Account,
return account, nil
}
// TestValidateDomain tests nameserver-specific domain validation.
// Core domain validation is tested in shared/management/domain/validate_test.go.
// This test only covers nameserver-specific behavior: wildcard rejection and unicode support.
func TestValidateDomain(t *testing.T) {
testCases := []struct {
name string
domain string
errFunc require.ErrorAssertionFunc
}{
// Nameserver-specific: wildcards not allowed
{
name: "Valid domain name with multiple labels",
domain: "123.example.com",
name: "Wildcard prefix rejected",
domain: "*.example.com",
errFunc: require.Error,
},
{
name: "Wildcard in middle rejected",
domain: "a.*.example.com",
errFunc: require.Error,
},
// Nameserver-specific: unicode converted to punycode
{
name: "Unicode domain converted to punycode",
domain: "münchen.de",
errFunc: require.NoError,
},
{
name: "Valid domain name with hyphen",
domain: "test-example.com",
name: "Unicode domain all labels",
domain: "中国.中国",
errFunc: require.NoError,
},
// Basic validation still works (delegates to shared validation)
{
name: "Valid multi-label domain",
domain: "example.com",
errFunc: require.NoError,
},
{
name: "Valid domain name with only one label",
domain: "example",
name: "Valid single label",
domain: "internal",
errFunc: require.NoError,
},
{
name: "Valid domain name with trailing dot",
domain: "example.",
errFunc: require.NoError,
},
{
name: "Invalid wildcard domain name",
domain: "*.example",
errFunc: require.Error,
},
{
name: "Invalid domain name with leading dot",
domain: ".com",
errFunc: require.Error,
},
{
name: "Invalid domain name with dot only",
domain: ".",
errFunc: require.Error,
},
{
name: "Invalid domain name with double hyphen",
domain: "test--example.com",
errFunc: require.Error,
},
{
name: "Invalid domain name with a label exceeding 63 characters",
domain: "dnsdnsdnsdnsdnsdnsdnsdnsdnsdnsdnsdnsdnsdnsdnsdnsdnsdnsdnsdnsdnsdns.com",
errFunc: require.Error,
},
{
name: "Invalid domain name starting with a hyphen",
name: "Invalid leading hyphen",
domain: "-example.com",
errFunc: require.Error,
},
{
name: "Invalid domain name ending with a hyphen",
domain: "example.com-",
errFunc: require.Error,
},
{
name: "Invalid domain with unicode",
domain: "example?,.com",
errFunc: require.Error,
},
{
name: "Invalid domain with space before top-level domain",
domain: "space .example.com",
errFunc: require.Error,
},
{
name: "Invalid domain with trailing space",
domain: "example.com ",
errFunc: require.Error,
},
}
for _, testCase := range testCases {

View File

@@ -203,7 +203,7 @@ func Test_CreateResourceFailsWithInvalidAddress(t *testing.T) {
NetworkID: "testNetworkId",
Name: "testResourceId",
Description: "description",
Address: "invalid-address",
Address: "-invalid",
}
store, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../../testdata/networks.sql", t.TempDir())
@@ -227,9 +227,9 @@ func Test_CreateResourceFailsWithUsedName(t *testing.T) {
resource := &types.NetworkResource{
AccountID: "testAccountId",
NetworkID: "testNetworkId",
Name: "testResourceId",
Name: "used-name",
Description: "description",
Address: "invalid-address",
Address: "example.com",
}
store, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../../testdata/networks.sql", t.TempDir())

View File

@@ -4,7 +4,6 @@ import (
"errors"
"fmt"
"net/netip"
"regexp"
"github.com/rs/xid"
@@ -166,8 +165,7 @@ func GetResourceType(address string) (NetworkResourceType, string, netip.Prefix,
return Host, "", netip.PrefixFrom(ip, ip.BitLen()), nil
}
domainRegex := regexp.MustCompile(`^(\*\.)?([a-zA-Z0-9-]+\.)+[a-zA-Z]{2,}$`)
if domainRegex.MatchString(address) {
if _, err := nbDomain.ValidateDomains([]string{address}); err == nil {
return Domain, address, netip.Prefix{}, nil
}

View File

@@ -23,10 +23,12 @@ func TestGetResourceType(t *testing.T) {
{"example.com", Domain, false, "example.com", netip.Prefix{}},
{"*.example.com", Domain, false, "*.example.com", netip.Prefix{}},
{"sub.example.com", Domain, false, "sub.example.com", netip.Prefix{}},
{"example.x", Domain, false, "example.x", netip.Prefix{}},
{"internal", Domain, false, "internal", netip.Prefix{}},
// Invalid inputs
{"invalid", "", true, "", netip.Prefix{}},
{"1.1.1.1/abc", "", true, "", netip.Prefix{}},
{"1234", "", true, "", netip.Prefix{}},
{"-invalid.com", "", true, "", netip.Prefix{}},
{"", "", true, "", netip.Prefix{}},
}
for _, tt := range tests {

View File

@@ -728,11 +728,6 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, accountID, setupKe
return fmt.Errorf("failed adding peer to All group: %w", err)
}
if temporary {
// we should track ephemeral peers to be able to clean them if the peer don't sync and be marked as connected
am.networkMapController.TrackEphemeralPeer(ctx, newPeer)
}
if addedByUser {
err := transaction.SaveUserLastLogin(ctx, accountID, userID, newPeer.GetLastLogin())
if err != nil {
@@ -760,6 +755,11 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, accountID, setupKe
return fmt.Errorf("failed to increment network serial: %w", err)
}
if ephemeral {
// we should track ephemeral peers to be able to clean them if the peer doesn't sync and isn't marked as connected
am.networkMapController.TrackEphemeralPeer(ctx, newPeer)
}
log.WithContext(ctx).Debugf("Peer %s added to account %s", newPeer.ID, accountID)
return nil
})

View File

@@ -1,9 +1,5 @@
package util
import "regexp"
var domainRegex = regexp.MustCompile(`^(\*\.)?([a-zA-Z0-9-]+\.)+[a-zA-Z]{2,}$`)
// Difference returns the elements in `a` that aren't in `b`.
func Difference(a, b []string) []string {
mb := make(map[string]struct{}, len(b))
@@ -55,9 +51,3 @@ func contains[T comparableObject[T]](slice []T, element T) bool {
return false
}
func IsValidDomain(domain string) bool {
if domain == "" {
return false
}
return domainRegex.MatchString(domain)
}

View File

@@ -10,7 +10,30 @@ const maxDomains = 32
var domainRegex = regexp.MustCompile(`^(?:\*\.)?(?:(?:xn--)?[a-zA-Z0-9_](?:[a-zA-Z0-9-_]{0,61}[a-zA-Z0-9])?\.)*(?:xn--)?[a-zA-Z0-9](?:[a-zA-Z0-9-_]{0,61}[a-zA-Z0-9])?$`)
// ValidateDomains checks if each domain in the list is valid and returns a punycode-encoded DomainList.
// IsValidDomain checks if a single domain string is valid.
// Does not convert unicode to punycode - domain must already be ASCII/punycode.
// Allows wildcard prefix (*.example.com).
func IsValidDomain(domain string) bool {
if domain == "" {
return false
}
return domainRegex.MatchString(strings.ToLower(domain))
}
// IsValidDomainNoWildcard checks if a single domain string is valid without wildcard prefix.
// Use for zone domains and CNAME targets where wildcards are not allowed.
func IsValidDomainNoWildcard(domain string) bool {
if domain == "" {
return false
}
if strings.HasPrefix(domain, "*.") {
return false
}
return domainRegex.MatchString(strings.ToLower(domain))
}
// ValidateDomains validates domains and converts unicode to punycode.
// Allows wildcard prefix (*.example.com). Maximum 32 domains.
func ValidateDomains(domains []string) (List, error) {
if len(domains) == 0 {
return nil, fmt.Errorf("domains list is empty")
@@ -37,7 +60,10 @@ func ValidateDomains(domains []string) (List, error) {
return domainList, nil
}
// ValidateDomainsList checks if each domain in the list is valid
// ValidateDomainsList validates domains without punycode conversion.
// Use this for domains that must already be in ASCII/punycode format (e.g., extra DNS labels).
// Unlike ValidateDomains, this does not convert unicode to punycode - unicode domains will fail.
// Allows wildcard prefix (*.example.com). Maximum 32 domains.
func ValidateDomainsList(domains []string) error {
if len(domains) == 0 {
return nil

View File

@@ -2,12 +2,16 @@ package domain
import (
"fmt"
"strings"
"testing"
"github.com/stretchr/testify/assert"
)
func TestValidateDomains(t *testing.T) {
label63 := strings.Repeat("a", 63)
label64 := strings.Repeat("a", 64)
tests := []struct {
name string
domains []string
@@ -26,6 +30,48 @@ func TestValidateDomains(t *testing.T) {
expected: List{"sub.ex-ample.com"},
wantErr: false,
},
{
name: "Valid uppercase domain normalized to lowercase",
domains: []string{"EXAMPLE.COM"},
expected: List{"example.com"},
wantErr: false,
},
{
name: "Valid mixed case domain",
domains: []string{"ExAmPlE.CoM"},
expected: List{"example.com"},
wantErr: false,
},
{
name: "Single letter TLD",
domains: []string{"example.x"},
expected: List{"example.x"},
wantErr: false,
},
{
name: "Two letter domain labels",
domains: []string{"a.b"},
expected: List{"a.b"},
wantErr: false,
},
{
name: "Single character domain",
domains: []string{"x"},
expected: List{"x"},
wantErr: false,
},
{
name: "Wildcard with single letter TLD",
domains: []string{"*.x"},
expected: List{"*.x"},
wantErr: false,
},
{
name: "Multi-level with single letter labels",
domains: []string{"a.b.c"},
expected: List{"a.b.c"},
wantErr: false,
},
{
name: "Valid Unicode domain",
domains: []string{"münchen.de"},
@@ -45,17 +91,92 @@ func TestValidateDomains(t *testing.T) {
wantErr: false,
},
{
name: "Invalid domain format",
name: "Valid domain starting with digit",
domains: []string{"123.example.com"},
expected: List{"123.example.com"},
wantErr: false,
},
// Numeric TLDs are allowed for internal/private DNS use cases.
// While ICANN doesn't issue all-numeric gTLDs, the DNS protocol permits them
// and resolvers like systemd-resolved handle them correctly.
{
name: "Numeric TLD allowed",
domains: []string{"example.123"},
expected: List{"example.123"},
wantErr: false,
},
{
name: "Single digit TLD allowed",
domains: []string{"example.1"},
expected: List{"example.1"},
wantErr: false,
},
{
name: "All numeric labels allowed",
domains: []string{"123.456"},
expected: List{"123.456"},
wantErr: false,
},
{
name: "Single numeric label allowed",
domains: []string{"123"},
expected: List{"123"},
wantErr: false,
},
{
name: "Valid domain with double hyphen",
domains: []string{"test--example.com"},
expected: List{"test--example.com"},
wantErr: false,
},
{
name: "Invalid leading hyphen",
domains: []string{"-example.com"},
expected: nil,
wantErr: true,
},
{
name: "Invalid domain format 2",
name: "Invalid trailing hyphen",
domains: []string{"example.com-"},
expected: nil,
wantErr: true,
},
{
name: "Invalid leading dot",
domains: []string{".com"},
expected: nil,
wantErr: true,
},
{
name: "Invalid dot only",
domains: []string{"."},
expected: nil,
wantErr: true,
},
{
name: "Invalid double dot",
domains: []string{"example..com"},
expected: nil,
wantErr: true,
},
{
name: "Invalid special characters",
domains: []string{"example?,.com"},
expected: nil,
wantErr: true,
},
{
name: "Invalid space in domain",
domains: []string{"space .example.com"},
expected: nil,
wantErr: true,
},
{
name: "Invalid trailing space",
domains: []string{"example.com "},
expected: nil,
wantErr: true,
},
{
name: "Multiple domains valid and invalid",
domains: []string{"google.com", "invalid,nbdomain.com", "münchen.de"},
@@ -86,6 +207,30 @@ func TestValidateDomains(t *testing.T) {
expected: nil,
wantErr: true,
},
{
name: "Valid 63 char label (max)",
domains: []string{label63 + ".com"},
expected: List{Domain(label63 + ".com")},
wantErr: false,
},
{
name: "Invalid 64 char label (exceeds max)",
domains: []string{label64 + ".com"},
expected: nil,
wantErr: true,
},
{
name: "Valid 253 char domain (max)",
domains: []string{strings.Repeat("a.", 126) + "a"},
expected: List{Domain(strings.Repeat("a.", 126) + "a")},
wantErr: false,
},
{
name: "Invalid 254+ char domain (exceeds max)",
domains: []string{strings.Repeat("ab.", 85)},
expected: nil,
wantErr: true,
},
}
for _, tt := range tests {
@@ -118,6 +263,57 @@ func TestValidateDomainsList(t *testing.T) {
domains: []string{"sub.ex-ample.com"},
wantErr: false,
},
{
name: "Uppercase domain accepted",
domains: []string{"EXAMPLE.COM"},
wantErr: false,
},
{
name: "Single letter TLD",
domains: []string{"example.x"},
wantErr: false,
},
{
name: "Two letter domain labels",
domains: []string{"a.b"},
wantErr: false,
},
{
name: "Single character domain",
domains: []string{"x"},
wantErr: false,
},
{
name: "Wildcard with single letter TLD",
domains: []string{"*.x"},
wantErr: false,
},
{
name: "Multi-level with single letter labels",
domains: []string{"a.b.c"},
wantErr: false,
},
// Numeric TLDs are allowed for internal/private DNS use cases.
{
name: "Numeric TLD allowed",
domains: []string{"example.123"},
wantErr: false,
},
{
name: "Single digit TLD allowed",
domains: []string{"example.1"},
wantErr: false,
},
{
name: "All numeric labels allowed",
domains: []string{"123.456"},
wantErr: false,
},
{
name: "Single numeric label allowed",
domains: []string{"123"},
wantErr: false,
},
{
name: "Underscores in labels",
domains: []string{"_jabber._tcp.gmail.com"},