From d1ead2265ba114e84ce9fec04e9abf01525cecc8 Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Wed, 18 Feb 2026 19:14:09 +0100 Subject: [PATCH] [client] Batch macOS DNS domains to avoid truncation (#5368) * [client] Batch macOS DNS domains across multiple scutil keys to avoid truncation scutil has undocumented limits: 99-element cap on d.add arrays and ~2048 byte value buffer for SupplementalMatchDomains. Users with 60+ domains hit silent domain loss. This applies the same batching approach used on Windows (nrptMaxDomainsPerRule=50), splitting domains into indexed resolver keys (NetBird-Match-0, NetBird-Match-1, etc.) with 50-element and 1500-byte limits per key. * check for all keys on getRemovableKeysWithDefaults * use multi error --- client/internal/dns/host_darwin.go | 171 +++++++++++++---- client/internal/dns/host_darwin_test.go | 238 +++++++++++++++++++++++- 2 files changed, 360 insertions(+), 49 deletions(-) diff --git a/client/internal/dns/host_darwin.go b/client/internal/dns/host_darwin.go index af84c8a85..b3908f163 100644 --- a/client/internal/dns/host_darwin.go +++ b/client/internal/dns/host_darwin.go @@ -14,6 +14,8 @@ import ( "strings" "sync" + "github.com/hashicorp/go-multierror" + nberrors "github.com/netbirdio/netbird/client/errors" log "github.com/sirupsen/logrus" "golang.org/x/exp/maps" @@ -22,6 +24,7 @@ import ( const ( netbirdDNSStateKeyFormat = "State:/Network/Service/NetBird-%s/DNS" + netbirdDNSStateKeyIndexedFormat = "State:/Network/Service/NetBird-%s-%d/DNS" globalIPv4State = "State:/Network/Global/IPv4" primaryServiceStateKeyFormat = "State:/Network/Service/%s/DNS" keySupplementalMatchDomains = "SupplementalMatchDomains" @@ -35,6 +38,14 @@ const ( searchSuffix = "Search" matchSuffix = "Match" localSuffix = "Local" + + // maxDomainsPerResolverEntry is the max number of domains per scutil resolver key. + // scutil's d.add has maxArgs=101 (key + * + 99 values), so 99 is the hard cap. 
+ maxDomainsPerResolverEntry = 50 + + // maxDomainBytesPerResolverEntry is the max total bytes of domain strings per key. + // scutil has an undocumented ~2048 byte value buffer; we stay well under it. + maxDomainBytesPerResolverEntry = 1500 ) type systemConfigurator struct { @@ -84,28 +95,23 @@ func (s *systemConfigurator) applyDNSConfig(config HostDNSConfig, stateManager * searchDomains = append(searchDomains, strings.TrimSuffix(""+dConf.Domain, ".")) } - matchKey := getKeyWithInput(netbirdDNSStateKeyFormat, matchSuffix) - var err error - if len(matchDomains) != 0 { - err = s.addMatchDomains(matchKey, strings.Join(matchDomains, " "), config.ServerIP, config.ServerPort) - } else { - log.Infof("removing match domains from the system") - err = s.removeKeyFromSystemConfig(matchKey) + if err := s.removeKeysContaining(matchSuffix); err != nil { + log.Warnf("failed to remove old match keys: %v", err) } - if err != nil { - return fmt.Errorf("add match domains: %w", err) + if len(matchDomains) != 0 { + if err := s.addBatchedDomains(matchSuffix, matchDomains, config.ServerIP, config.ServerPort, false); err != nil { + return fmt.Errorf("add match domains: %w", err) + } } s.updateState(stateManager) - searchKey := getKeyWithInput(netbirdDNSStateKeyFormat, searchSuffix) - if len(searchDomains) != 0 { - err = s.addSearchDomains(searchKey, strings.Join(searchDomains, " "), config.ServerIP, config.ServerPort) - } else { - log.Infof("removing search domains from the system") - err = s.removeKeyFromSystemConfig(searchKey) + if err := s.removeKeysContaining(searchSuffix); err != nil { + log.Warnf("failed to remove old search keys: %v", err) } - if err != nil { - return fmt.Errorf("add search domains: %w", err) + if len(searchDomains) != 0 { + if err := s.addBatchedDomains(searchSuffix, searchDomains, config.ServerIP, config.ServerPort, true); err != nil { + return fmt.Errorf("add search domains: %w", err) + } } s.updateState(stateManager) @@ -149,8 +155,7 @@ func (s 
*systemConfigurator) restoreHostDNS() error { func (s *systemConfigurator) getRemovableKeysWithDefaults() []string { if len(s.createdKeys) == 0 { - // return defaults for startup calls - return []string{getKeyWithInput(netbirdDNSStateKeyFormat, searchSuffix), getKeyWithInput(netbirdDNSStateKeyFormat, matchSuffix)} + return s.discoverExistingKeys() } keys := make([]string, 0, len(s.createdKeys)) @@ -160,6 +165,47 @@ func (s *systemConfigurator) getRemovableKeysWithDefaults() []string { return keys } +// discoverExistingKeys probes scutil for all NetBird DNS keys that may exist. +// This handles the case where createdKeys is empty (e.g., state file lost after unclean shutdown). +func (s *systemConfigurator) discoverExistingKeys() []string { + dnsKeys, err := getSystemDNSKeys() + if err != nil { + log.Errorf("failed to get system DNS keys: %v", err) + return nil + } + + var keys []string + + for _, suffix := range []string{searchSuffix, matchSuffix, localSuffix} { + key := getKeyWithInput(netbirdDNSStateKeyFormat, suffix) + if strings.Contains(dnsKeys, key) { + keys = append(keys, key) + } + } + + for _, suffix := range []string{searchSuffix, matchSuffix} { + for i := 0; ; i++ { + key := fmt.Sprintf(netbirdDNSStateKeyIndexedFormat, suffix, i) + if !strings.Contains(dnsKeys, key) { + break + } + keys = append(keys, key) + } + } + + return keys +} + +// getSystemDNSKeys gets all DNS keys +func getSystemDNSKeys() (string, error) { + command := "list .*DNS\nquit\n" + out, err := runSystemConfigCommand(command) + if err != nil { + return "", err + } + return string(out), nil +} + func (s *systemConfigurator) removeKeyFromSystemConfig(key string) error { line := buildRemoveKeyOperation(key) _, err := runSystemConfigCommand(wrapCommand(line)) @@ -184,12 +230,11 @@ func (s *systemConfigurator) addLocalDNS() error { return nil } - if err := s.addSearchDomains( - localKey, - strings.Join(s.systemDNSSettings.Domains, " "), s.systemDNSSettings.ServerIP, 
s.systemDNSSettings.ServerPort, - ); err != nil { - return fmt.Errorf("add search domains: %w", err) + domainsStr := strings.Join(s.systemDNSSettings.Domains, " ") + if err := s.addDNSState(localKey, domainsStr, s.systemDNSSettings.ServerIP, s.systemDNSSettings.ServerPort, true); err != nil { + return fmt.Errorf("add local dns state: %w", err) } + s.createdKeys[localKey] = struct{}{} return nil } @@ -280,28 +325,77 @@ func (s *systemConfigurator) getOriginalNameservers() []netip.Addr { return slices.Clone(s.origNameservers) } -func (s *systemConfigurator) addSearchDomains(key, domains string, ip netip.Addr, port int) error { - err := s.addDNSState(key, domains, ip, port, true) - if err != nil { - return fmt.Errorf("add dns state: %w", err) +// splitDomainsIntoBatches splits domains into batches respecting both element count and byte size limits. +func splitDomainsIntoBatches(domains []string) [][]string { + if len(domains) == 0 { + return nil } - log.Infof("added %d search domains to the state. 
Domain list: %s", len(strings.Split(domains, " ")), domains) + var batches [][]string + var current []string + currentBytes := 0 - s.createdKeys[key] = struct{}{} + for _, d := range domains { + domainLen := len(d) + newBytes := currentBytes + domainLen + if currentBytes > 0 { + newBytes++ // space separator + } - return nil + if len(current) > 0 && (len(current) >= maxDomainsPerResolverEntry || newBytes > maxDomainBytesPerResolverEntry) { + batches = append(batches, current) + current = nil + currentBytes = 0 + } + + current = append(current, d) + if currentBytes > 0 { + currentBytes += 1 + domainLen + } else { + currentBytes = domainLen + } + } + + if len(current) > 0 { + batches = append(batches, current) + } + + return batches } -func (s *systemConfigurator) addMatchDomains(key, domains string, dnsServer netip.Addr, port int) error { - err := s.addDNSState(key, domains, dnsServer, port, false) - if err != nil { - return fmt.Errorf("add dns state: %w", err) +// removeKeysContaining removes all created keys that contain the given substring. +func (s *systemConfigurator) removeKeysContaining(suffix string) error { + var toRemove []string + for key := range s.createdKeys { + if strings.Contains(key, suffix) { + toRemove = append(toRemove, key) + } + } + var multiErr *multierror.Error + for _, key := range toRemove { + if err := s.removeKeyFromSystemConfig(key); err != nil { + multiErr = multierror.Append(multiErr, fmt.Errorf("couldn't remove key %s: %w", key, err)) + } + } + return nberrors.FormatErrorOrNil(multiErr) +} + +// addBatchedDomains splits domains into batches and creates indexed scutil keys for each batch. 
+func (s *systemConfigurator) addBatchedDomains(suffix string, domains []string, ip netip.Addr, port int, enableSearch bool) error { + batches := splitDomainsIntoBatches(domains) + + for i, batch := range batches { + key := fmt.Sprintf(netbirdDNSStateKeyIndexedFormat, suffix, i) + domainsStr := strings.Join(batch, " ") + + if err := s.addDNSState(key, domainsStr, ip, port, enableSearch); err != nil { + return fmt.Errorf("add dns state for batch %d: %w", i, err) + } + + s.createdKeys[key] = struct{}{} } - log.Infof("added %d match domains to the state. Domain list: %s", len(strings.Split(domains, " ")), domains) - - s.createdKeys[key] = struct{}{} + log.Infof("added %d %s domains across %d resolver entries", len(domains), suffix, len(batches)) return nil } @@ -364,7 +458,6 @@ func (s *systemConfigurator) flushDNSCache() error { if out, err := cmd.CombinedOutput(); err != nil { return fmt.Errorf("restart mDNSResponder: %w, output: %s", err, out) } - log.Info("flushed DNS cache") return nil } diff --git a/client/internal/dns/host_darwin_test.go b/client/internal/dns/host_darwin_test.go index 28915de65..94d020c39 100644 --- a/client/internal/dns/host_darwin_test.go +++ b/client/internal/dns/host_darwin_test.go @@ -3,7 +3,10 @@ package dns import ( + "bufio" + "bytes" "context" + "fmt" "net/netip" "os/exec" "path/filepath" @@ -49,17 +52,22 @@ func TestDarwinDNSUncleanShutdownCleanup(t *testing.T) { require.NoError(t, sm.PersistState(context.Background())) - searchKey := getKeyWithInput(netbirdDNSStateKeyFormat, searchSuffix) - matchKey := getKeyWithInput(netbirdDNSStateKeyFormat, matchSuffix) localKey := getKeyWithInput(netbirdDNSStateKeyFormat, localSuffix) + // Collect all created keys for cleanup verification + createdKeys := make([]string, 0, len(configurator.createdKeys)) + for key := range configurator.createdKeys { + createdKeys = append(createdKeys, key) + } + defer func() { - for _, key := range []string{searchKey, matchKey, localKey} { + for _, key := range 
createdKeys { _ = removeTestDNSKey(key) } + _ = removeTestDNSKey(localKey) }() - for _, key := range []string{searchKey, matchKey, localKey} { + for _, key := range createdKeys { exists, err := checkDNSKeyExists(key) require.NoError(t, err) if exists { @@ -83,13 +91,223 @@ func TestDarwinDNSUncleanShutdownCleanup(t *testing.T) { err = shutdownState.Cleanup() require.NoError(t, err) - for _, key := range []string{searchKey, matchKey, localKey} { + for _, key := range createdKeys { exists, err := checkDNSKeyExists(key) require.NoError(t, err) assert.False(t, exists, "Key %s should NOT exist after cleanup", key) } } +// generateShortDomains generates domains like a.com, b.com, ..., aa.com, ab.com, etc. +func generateShortDomains(count int) []string { + domains := make([]string, 0, count) + for i := range count { + label := "" + n := i + for { + label = string(rune('a'+n%26)) + label + n = n/26 - 1 + if n < 0 { + break + } + } + domains = append(domains, label+".com") + } + return domains +} + +// generateLongDomains generates domains like subdomain-000.department.organization-name.example.com +func generateLongDomains(count int) []string { + domains := make([]string, 0, count) + for i := range count { + domains = append(domains, fmt.Sprintf("subdomain-%03d.department.organization-name.example.com", i)) + } + return domains +} + +// readDomainsFromKey reads the SupplementalMatchDomains array back from scutil for a given key. 
+func readDomainsFromKey(t *testing.T, key string) []string { + t.Helper() + + cmd := exec.Command(scutilPath) + cmd.Stdin = strings.NewReader(fmt.Sprintf("open\nshow %s\nquit\n", key)) + out, err := cmd.Output() + require.NoError(t, err, "scutil show should succeed") + + var domains []string + inArray := false + scanner := bufio.NewScanner(bytes.NewReader(out)) + for scanner.Scan() { + line := strings.TrimSpace(scanner.Text()) + if strings.HasPrefix(line, "SupplementalMatchDomains") && strings.Contains(line, "<array>") { + inArray = true + continue + } + if inArray { + if line == "}" { + break + } + // lines look like: "0 : a.com" + parts := strings.SplitN(line, " : ", 2) + if len(parts) == 2 { + domains = append(domains, parts[1]) + } + } + } + require.NoError(t, scanner.Err()) + return domains +} + +func TestSplitDomainsIntoBatches(t *testing.T) { + tests := []struct { + name string + domains []string + expectedCount int + checkAllPresent bool + }{ + { + name: "empty", + domains: nil, + expectedCount: 0, + }, + { + name: "under_limit", + domains: generateShortDomains(10), + expectedCount: 1, + checkAllPresent: true, + }, + { + name: "at_element_limit", + domains: generateShortDomains(50), + expectedCount: 1, + checkAllPresent: true, + }, + { + name: "over_element_limit", + domains: generateShortDomains(51), + expectedCount: 2, + checkAllPresent: true, + }, + { + name: "triple_element_limit", + domains: generateShortDomains(150), + expectedCount: 3, + checkAllPresent: true, + }, + { + name: "long_domains_hit_byte_limit", + domains: generateLongDomains(50), + checkAllPresent: true, + }, + { + name: "500_short_domains", + domains: generateShortDomains(500), + expectedCount: 10, + checkAllPresent: true, + }, + { + name: "500_long_domains", + domains: generateLongDomains(500), + checkAllPresent: true, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + batches := splitDomainsIntoBatches(tc.domains) + + if tc.expectedCount > 0 { + assert.Len(t, 
batches, tc.expectedCount, "expected %d batches", tc.expectedCount) + } + + // Verify each batch respects limits + for i, batch := range batches { + assert.LessOrEqual(t, len(batch), maxDomainsPerResolverEntry, + "batch %d exceeds element limit", i) + + totalBytes := 0 + for j, d := range batch { + if j > 0 { + totalBytes++ + } + totalBytes += len(d) + } + assert.LessOrEqual(t, totalBytes, maxDomainBytesPerResolverEntry, + "batch %d exceeds byte limit (%d bytes)", i, totalBytes) + } + + if tc.checkAllPresent { + var all []string + for _, batch := range batches { + all = append(all, batch...) + } + assert.Equal(t, tc.domains, all, "all domains should be present in order") + } + }) + } +} + +// TestMatchDomainBatching writes increasing numbers of domains via the batching mechanism +// and verifies all domains are readable across multiple scutil keys. +func TestMatchDomainBatching(t *testing.T) { + if testing.Short() { + t.Skip("skipping scutil integration test in short mode") + } + + testCases := []struct { + name string + count int + generator func(int) []string + }{ + {"short_10", 10, generateShortDomains}, + {"short_50", 50, generateShortDomains}, + {"short_100", 100, generateShortDomains}, + {"short_200", 200, generateShortDomains}, + {"short_500", 500, generateShortDomains}, + {"long_10", 10, generateLongDomains}, + {"long_50", 50, generateLongDomains}, + {"long_100", 100, generateLongDomains}, + {"long_200", 200, generateLongDomains}, + {"long_500", 500, generateLongDomains}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + configurator := &systemConfigurator{ + createdKeys: make(map[string]struct{}), + } + + defer func() { + for key := range configurator.createdKeys { + _ = removeTestDNSKey(key) + } + }() + + domains := tc.generator(tc.count) + err := configurator.addBatchedDomains(matchSuffix, domains, netip.MustParseAddr("100.64.0.1"), 53, false) + require.NoError(t, err) + + batches := splitDomainsIntoBatches(domains) + 
t.Logf("wrote %d domains across %d batched keys", tc.count, len(batches)) + + // Read back all domains from all batched keys + var got []string + for i := range batches { + key := fmt.Sprintf(netbirdDNSStateKeyIndexedFormat, matchSuffix, i) + exists, err := checkDNSKeyExists(key) + require.NoError(t, err) + require.True(t, exists, "key %s should exist", key) + + got = append(got, readDomainsFromKey(t, key)...) + } + + t.Logf("read back %d/%d domains from %d keys", len(got), tc.count, len(batches)) + assert.Equal(t, tc.count, len(got), "all domains should be readable") + assert.Equal(t, domains, got, "domains should match in order") + }) + } +} + func checkDNSKeyExists(key string) (bool, error) { cmd := exec.Command(scutilPath) cmd.Stdin = strings.NewReader("show " + key + "\nquit\n") @@ -158,15 +376,15 @@ func setupTestConfigurator(t *testing.T) (*systemConfigurator, *statemanager.Man createdKeys: make(map[string]struct{}), } - searchKey := getKeyWithInput(netbirdDNSStateKeyFormat, searchSuffix) - matchKey := getKeyWithInput(netbirdDNSStateKeyFormat, matchSuffix) - localKey := getKeyWithInput(netbirdDNSStateKeyFormat, localSuffix) - cleanup := func() { _ = sm.Stop(context.Background()) - for _, key := range []string{searchKey, matchKey, localKey} { + for key := range configurator.createdKeys { _ = removeTestDNSKey(key) } + // Also clean up old-format keys and local key in case they exist + _ = removeTestDNSKey(getKeyWithInput(netbirdDNSStateKeyFormat, searchSuffix)) + _ = removeTestDNSKey(getKeyWithInput(netbirdDNSStateKeyFormat, matchSuffix)) + _ = removeTestDNSKey(getKeyWithInput(netbirdDNSStateKeyFormat, localSuffix)) } return configurator, sm, cleanup