diff --git a/client/internal/dns.go b/client/internal/dns.go index 3c68e4d00..f5040ee49 100644 --- a/client/internal/dns.go +++ b/client/internal/dns.go @@ -76,7 +76,7 @@ func collectPTRRecords(config *nbdns.Config, prefix netip.Prefix) []nbdns.Simple var records []nbdns.SimpleRecord for _, zone := range config.CustomZones { - if zone.SkipPTRProcess { + if zone.NonAuthoritative { continue } for _, record := range zone.Records { diff --git a/client/internal/dns/local/local.go b/client/internal/dns/local/local.go index cb1fa5293..63c2428ce 100644 --- a/client/internal/dns/local/local.go +++ b/client/internal/dns/local/local.go @@ -28,10 +28,11 @@ type resolver interface { } type Resolver struct { - mu sync.RWMutex - records map[dns.Question][]dns.RR - domains map[domain.Domain]struct{} - zones []domain.Domain + mu sync.RWMutex + records map[dns.Question][]dns.RR + domains map[domain.Domain]struct{} + // zones maps zone domain -> NonAuthoritative (true = non-authoritative, user-created zone) + zones map[domain.Domain]bool resolver resolver ctx context.Context @@ -43,6 +44,7 @@ func NewResolver() *Resolver { return &Resolver{ records: make(map[dns.Question][]dns.RR), domains: make(map[domain.Domain]struct{}), + zones: make(map[domain.Domain]bool), ctx: ctx, cancel: cancel, } @@ -67,7 +69,7 @@ func (d *Resolver) Stop() { maps.Clear(d.records) maps.Clear(d.domains) - d.zones = nil + maps.Clear(d.zones) } // ID returns the unique handler ID @@ -97,6 +99,11 @@ func (d *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { replyMessage.Answer = result.records replyMessage.Rcode = d.determineRcode(question, result) + if replyMessage.Rcode == dns.RcodeNameError && d.shouldFallthrough(question.Name) { + d.continueToNext(logger, w, r) + return + } + if err := w.WriteMsg(replyMessage); err != nil { logger.Warnf("failed to write the local resolver response: %v", err) } @@ -120,6 +127,42 @@ func (d *Resolver) determineRcode(question dns.Question, result lookupResult) in return dns.RcodeNameError } +// findZone finds the matching zone for a query name using reverse suffix lookup. +// Returns (nonAuthoritative, found). This is O(k) where k = number of labels in qname. +func (d *Resolver) findZone(qname string) (nonAuthoritative bool, found bool) { + qname = strings.ToLower(dns.Fqdn(qname)) + for { + if nonAuth, ok := d.zones[domain.Domain(qname)]; ok { + return nonAuth, true + } + // Move to parent domain + idx := strings.Index(qname, ".") + if idx == -1 || idx == len(qname)-1 { + return false, false + } + qname = qname[idx+1:] + } +} + +// shouldFallthrough checks if the query should fallthrough to the next handler. +// Returns true if the queried name belongs to a non-authoritative zone. +func (d *Resolver) shouldFallthrough(qname string) bool { + d.mu.RLock() + defer d.mu.RUnlock() + + nonAuth, found := d.findZone(qname) + return found && nonAuth +} + +func (d *Resolver) continueToNext(logger *log.Entry, w dns.ResponseWriter, r *dns.Msg) { + resp := &dns.Msg{} + resp.SetRcode(r, dns.RcodeNameError) + resp.MsgHdr.Zero = true + if err := w.WriteMsg(resp); err != nil { + logger.Warnf("failed to write continue signal: %v", err) + } +} + // hasRecordsForDomain checks if any records exist for the given domain name regardless of type func (d *Resolver) hasRecordsForDomain(domainName domain.Domain) bool { d.mu.RLock() @@ -137,14 +180,8 @@ func (d *Resolver) isInManagedZone(name string) bool { d.mu.RLock() defer d.mu.RUnlock() - name = dns.Fqdn(name) - for _, zone := range d.zones { - zoneStr := dns.Fqdn(zone.PunycodeString()) - if strings.EqualFold(name, zoneStr) || strings.HasSuffix(strings.ToLower(name), strings.ToLower("."+zoneStr)) { - return true - } - } - return false + _, found := d.findZone(name) + return found } // lookupResult contains the result of a DNS lookup operation. @@ -343,21 +380,23 @@ func (d *Resolver) logDNSError(logger *log.Entry, hostname string, qtype uint16, } } -// Update updates the resolver with new records and zone information. -// The zones parameter specifies which DNS zones this resolver manages. -func (d *Resolver) Update(update []nbdns.SimpleRecord, zones []domain.Domain) { +// Update replaces all zones and their records +func (d *Resolver) Update(customZones []nbdns.CustomZone) { d.mu.Lock() defer d.mu.Unlock() maps.Clear(d.records) maps.Clear(d.domains) + maps.Clear(d.zones) - d.zones = zones + for _, zone := range customZones { + zoneDomain := domain.Domain(strings.ToLower(dns.Fqdn(zone.Domain))) + d.zones[zoneDomain] = zone.NonAuthoritative - for _, rec := range update { - if err := d.registerRecord(rec); err != nil { - log.Warnf("failed to register the record (%s): %v", rec, err) - continue + for _, rec := range zone.Records { + if err := d.registerRecord(rec); err != nil { + log.Warnf("failed to register the record (%s): %v", rec, err) + } } } } diff --git a/client/internal/dns/local/local_test.go b/client/internal/dns/local/local_test.go index 2f8e08b1a..1c7cad5d1 100644 --- a/client/internal/dns/local/local_test.go +++ b/client/internal/dns/local/local_test.go @@ -16,7 +16,6 @@ import ( "github.com/netbirdio/netbird/client/internal/dns/test" nbdns "github.com/netbirdio/netbird/dns" - "github.com/netbirdio/netbird/shared/management/domain" ) // mockResolver implements resolver for testing @@ -125,11 +124,11 @@ func TestLocalResolver_Update_StaleRecord(t *testing.T) { resolver := NewResolver() - update1 := []nbdns.SimpleRecord{record1} - update2 := []nbdns.SimpleRecord{record2} + zone1 := []nbdns.CustomZone{{Domain: "example.com.", Records: []nbdns.SimpleRecord{record1}}} + zone2 := []nbdns.CustomZone{{Domain: "example.com.", Records: []nbdns.SimpleRecord{record2}}} // Apply first update - resolver.Update(update1, nil) + resolver.Update(zone1) // Verify first update resolver.mu.RLock() @@ -141,7 +140,7 @@ func TestLocalResolver_Update_StaleRecord(t *testing.T) { assert.Contains(t, rrSlice1[0].String(), record1.RData, "Record after first update should be %s", record1.RData) // Apply second update - resolver.Update(update2, nil) + resolver.Update(zone2) // Verify second update resolver.mu.RLock() @@ -170,10 +169,10 @@ func TestLocalResolver_MultipleRecords_SameQuestion(t *testing.T) { Name: recordName, Type: int(recordType), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.2", } - update := []nbdns.SimpleRecord{record1, record2} + zones := []nbdns.CustomZone{{Domain: "example.com.", Records: []nbdns.SimpleRecord{record1, record2}}} // Apply update with both records - resolver.Update(update, nil) + resolver.Update(zones) // Create question that matches both records question := dns.Question{ @@ -214,10 +213,10 @@ func TestLocalResolver_RecordRotation(t *testing.T) { Name: recordName, Type: int(recordType), Class: nbdns.DefaultClass, TTL: 300, RData: "192.168.1.3", } - update := []nbdns.SimpleRecord{record1, record2, record3} + zones := []nbdns.CustomZone{{Domain: "example.com.", Records: []nbdns.SimpleRecord{record1, record2, record3}}} // Apply update with all three records - resolver.Update(update, nil) + resolver.Update(zones) msg := new(dns.Msg).SetQuestion(recordName, recordType) @@ -283,7 +282,7 @@ func TestLocalResolver_CaseInsensitiveMatching(t *testing.T) { } // Update resolver with the records - resolver.Update([]nbdns.SimpleRecord{lowerCaseRecord, mixedCaseRecord}, nil) + resolver.Update([]nbdns.CustomZone{{Domain: "example.com.", Records: []nbdns.SimpleRecord{lowerCaseRecord, mixedCaseRecord}}}) testCases := []struct { name string @@ -398,7 +397,7 @@ func TestLocalResolver_CNAMEFallback(t *testing.T) { } // Update resolver with both records - resolver.Update([]nbdns.SimpleRecord{cnameRecord, targetRecord}, nil) + resolver.Update([]nbdns.CustomZone{{Domain: "example.com.", Records: []nbdns.SimpleRecord{cnameRecord, targetRecord}}}) testCases := []struct { name string @@ -526,7 +525,7 @@ func TestLocalResolver_NoErrorWithDifferentRecordType(t *testing.T) { RData: "target.example.com.", } - resolver.Update([]nbdns.SimpleRecord{recordA, recordCNAME}, nil) + resolver.Update([]nbdns.CustomZone{{Domain: "netbird.cloud.", Records: []nbdns.SimpleRecord{recordA, recordCNAME}}}) testCases := []struct { name string @@ -620,10 +619,13 @@ func TestLocalResolver_NoErrorWithDifferentRecordType(t *testing.T) { func TestLocalResolver_CNAMEChainResolution(t *testing.T) { t.Run("simple internal CNAME chain", func(t *testing.T) { resolver := NewResolver() - resolver.Update([]nbdns.SimpleRecord{ - {Name: "alias.example.com.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "target.example.com."}, - {Name: "target.example.com.", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "192.168.1.1"}, - }, nil) + resolver.Update([]nbdns.CustomZone{{ + Domain: "example.com.", + Records: []nbdns.SimpleRecord{ + {Name: "alias.example.com.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "target.example.com."}, + {Name: "target.example.com.", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "192.168.1.1"}, + }, + }}) msg := new(dns.Msg).SetQuestion("alias.example.com.", dns.TypeA) var resp *dns.Msg @@ -644,11 +646,14 @@ func TestLocalResolver_CNAMEChainResolution(t *testing.T) { t.Run("multi-hop CNAME chain", func(t *testing.T) { resolver := NewResolver() - resolver.Update([]nbdns.SimpleRecord{ - {Name: "hop1.test.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "hop2.test."}, - {Name: "hop2.test.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "hop3.test."}, - {Name: "hop3.test.", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}, - }, nil) + resolver.Update([]nbdns.CustomZone{{ + Domain: "test.", + Records: []nbdns.SimpleRecord{ + {Name: "hop1.test.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "hop2.test."}, + {Name: "hop2.test.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "hop3.test."}, + {Name: "hop3.test.", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}, + }, + }}) msg := new(dns.Msg).SetQuestion("hop1.test.", dns.TypeA) var resp *dns.Msg @@ -661,9 +666,12 @@ func TestLocalResolver_CNAMEChainResolution(t *testing.T) { t.Run("CNAME to non-existent internal target returns only CNAME", func(t *testing.T) { resolver := NewResolver() - resolver.Update([]nbdns.SimpleRecord{ - {Name: "alias.test.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "nonexistent.test."}, - }, nil) + resolver.Update([]nbdns.CustomZone{{ + Domain: "test.", + Records: []nbdns.SimpleRecord{ + {Name: "alias.test.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "nonexistent.test."}, + }, + }}) msg := new(dns.Msg).SetQuestion("alias.test.", dns.TypeA) var resp *dns.Msg @@ -695,7 +703,7 @@ func TestLocalResolver_CNAMEMaxDepth(t *testing.T) { Name: "hop8.test.", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.10.10.10", }) - resolver.Update(records, nil) + resolver.Update([]nbdns.CustomZone{{Domain: "test.", Records: records}}) msg := new(dns.Msg).SetQuestion("hop1.test.", dns.TypeA) var resp *dns.Msg @@ -723,7 +731,7 @@ func TestLocalResolver_CNAMEMaxDepth(t *testing.T) { Name: "deep11.test.", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.10.10.10", }) - resolver.Update(records, nil) + resolver.Update([]nbdns.CustomZone{{Domain: "test.", Records: records}}) msg := new(dns.Msg).SetQuestion("deep1.test.", dns.TypeA) var resp *dns.Msg @@ -736,10 +744,13 @@ func TestLocalResolver_CNAMEMaxDepth(t *testing.T) { t.Run("circular CNAME is protected by max depth", func(t *testing.T) { resolver := NewResolver() - resolver.Update([]nbdns.SimpleRecord{ - {Name: "loop1.test.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "loop2.test."}, - {Name: "loop2.test.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "loop1.test."}, - }, nil) + resolver.Update([]nbdns.CustomZone{{ + Domain: "test.", + Records: []nbdns.SimpleRecord{ + {Name: "loop1.test.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "loop2.test."}, + {Name: "loop2.test.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "loop1.test."}, + }, + }}) msg := new(dns.Msg).SetQuestion("loop1.test.", dns.TypeA) var resp *dns.Msg @@ -763,9 +774,12 @@ func TestLocalResolver_ExternalCNAMEResolution(t *testing.T) { }, } - resolver.Update([]nbdns.SimpleRecord{ - {Name: "alias.test.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "external.example.com."}, - }, nil) + resolver.Update([]nbdns.CustomZone{{ + Domain: "test.", + Records: []nbdns.SimpleRecord{ + {Name: "alias.test.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "external.example.com."}, + }, + }}) msg := new(dns.Msg).SetQuestion("alias.test.", dns.TypeA) var resp *dns.Msg @@ -794,9 +808,12 @@ func TestLocalResolver_ExternalCNAMEResolution(t *testing.T) { }, } - resolver.Update([]nbdns.SimpleRecord{ - {Name: "alias.test.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "external.example.com."}, - }, nil) + resolver.Update([]nbdns.CustomZone{{ + Domain: "test.", + Records: []nbdns.SimpleRecord{ + {Name: "alias.test.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "external.example.com."}, + }, + }}) msg := new(dns.Msg).SetQuestion("alias.test.", dns.TypeAAAA) var resp *dns.Msg @@ -825,9 +842,12 @@ func TestLocalResolver_ExternalCNAMEResolution(t *testing.T) { }, } - resolver.Update([]nbdns.SimpleRecord{ - {Name: "concurrent.test.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "external.example.com."}, - }, nil) + resolver.Update([]nbdns.CustomZone{{ + Domain: "test.", + Records: []nbdns.SimpleRecord{ + {Name: "concurrent.test.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "external.example.com."}, + }, + }}) var wg sync.WaitGroup results := make([]*dns.Msg, 10) @@ -856,10 +876,12 @@ func TestLocalResolver_ZoneManagement(t *testing.T) { t.Run("Update sets zones correctly", func(t *testing.T) { resolver := NewResolver() - zones := []domain.Domain{"example.com", "test.local"} - resolver.Update([]nbdns.SimpleRecord{ - {Name: "host.example.com.", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}, - }, zones) + resolver.Update([]nbdns.CustomZone{ + {Domain: "example.com.", Records: []nbdns.SimpleRecord{ + {Name: "host.example.com.", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}, + }}, + {Domain: "test.local."}, + }) assert.True(t, resolver.isInManagedZone("host.example.com.")) assert.True(t, resolver.isInManagedZone("other.example.com.")) @@ -869,7 +891,7 @@ func TestLocalResolver_ZoneManagement(t *testing.T) { t.Run("isInManagedZone case insensitive", func(t *testing.T) { resolver := NewResolver() - resolver.Update(nil, []domain.Domain{"Example.COM"}) + resolver.Update([]nbdns.CustomZone{{Domain: "Example.COM."}}) assert.True(t, resolver.isInManagedZone("host.example.com.")) assert.True(t, resolver.isInManagedZone("HOST.EXAMPLE.COM.")) @@ -877,10 +899,10 @@ func TestLocalResolver_ZoneManagement(t *testing.T) { t.Run("Update clears zones", func(t *testing.T) { resolver := NewResolver() - resolver.Update(nil, []domain.Domain{"example.com"}) + resolver.Update([]nbdns.CustomZone{{Domain: "example.com."}}) assert.True(t, resolver.isInManagedZone("host.example.com.")) - resolver.Update(nil, nil) + resolver.Update(nil) assert.False(t, resolver.isInManagedZone("host.example.com.")) }) } @@ -889,9 +911,12 @@ func TestLocalResolver_ZoneManagement(t *testing.T) { func TestLocalResolver_CNAMEZoneAwareResolution(t *testing.T) { t.Run("CNAME target in managed zone returns NXDOMAIN per RFC 6604", func(t *testing.T) { resolver := NewResolver() - resolver.Update([]nbdns.SimpleRecord{ - {Name: "alias.myzone.test.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "nonexistent.myzone.test."}, - }, []domain.Domain{"myzone.test"}) + resolver.Update([]nbdns.CustomZone{{ + Domain: "myzone.test.", + Records: []nbdns.SimpleRecord{ + {Name: "alias.myzone.test.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "nonexistent.myzone.test."}, + }, + }}) msg := new(dns.Msg).SetQuestion("alias.myzone.test.", dns.TypeA) var resp *dns.Msg @@ -913,9 +938,12 @@ func TestLocalResolver_CNAMEZoneAwareResolution(t *testing.T) { }, } - resolver.Update([]nbdns.SimpleRecord{ - {Name: "alias.myzone.test.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "external.other.com."}, - }, []domain.Domain{"myzone.test"}) + resolver.Update([]nbdns.CustomZone{{ + Domain: "myzone.test.", + Records: []nbdns.SimpleRecord{ + {Name: "alias.myzone.test.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "external.other.com."}, + }, + }}) msg := new(dns.Msg).SetQuestion("alias.myzone.test.", dns.TypeA) var resp *dns.Msg @@ -929,10 +957,13 @@ func TestLocalResolver_CNAMEZoneAwareResolution(t *testing.T) { t.Run("CNAME target exists with different type returns NODATA not NXDOMAIN", func(t *testing.T) { resolver := NewResolver() // CNAME points to target that has A but no AAAA - query for AAAA should be NODATA - resolver.Update([]nbdns.SimpleRecord{ - {Name: "alias.myzone.test.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "target.myzone.test."}, - {Name: "target.myzone.test.", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "1.1.1.1"}, - }, []domain.Domain{"myzone.test"}) + resolver.Update([]nbdns.CustomZone{{ + Domain: "myzone.test.", + Records: []nbdns.SimpleRecord{ + {Name: "alias.myzone.test.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "target.myzone.test."}, + {Name: "target.myzone.test.", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "1.1.1.1"}, + }, + }}) msg := new(dns.Msg).SetQuestion("alias.myzone.test.", dns.TypeAAAA) var resp *dns.Msg @@ -963,9 +994,12 @@ func TestLocalResolver_CNAMEZoneAwareResolution(t *testing.T) { }, } - resolver.Update([]nbdns.SimpleRecord{ - {Name: "alias.test.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "external.example.com."}, - }, nil) + resolver.Update([]nbdns.CustomZone{{ + Domain: "test.", + Records: []nbdns.SimpleRecord{ + {Name: "alias.test.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "external.example.com."}, + }, + }}) msg := new(dns.Msg).SetQuestion("alias.test.", dns.TypeAAAA) var resp *dns.Msg @@ -1035,9 +1069,12 @@ func TestLocalResolver_CNAMEZoneAwareResolution(t *testing.T) { resolver := NewResolver() resolver.resolver = &mockResolver{lookupFunc: tc.lookupFunc} - resolver.Update([]nbdns.SimpleRecord{ - {Name: "alias.test.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "external.example.com."}, - }, nil) + resolver.Update([]nbdns.CustomZone{{ + Domain: "test.", + Records: []nbdns.SimpleRecord{ + {Name: "alias.test.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "external.example.com."}, + }, + }}) msg := new(dns.Msg).SetQuestion("alias.test.", dns.TypeA) var resp *dns.Msg @@ -1054,13 +1091,112 @@ func TestLocalResolver_CNAMEZoneAwareResolution(t *testing.T) { } } +// TestLocalResolver_Fallthrough verifies that non-authoritative zones +// trigger fallthrough (Zero bit set) when no records match +func TestLocalResolver_Fallthrough(t *testing.T) { + resolver := NewResolver() + + record := nbdns.SimpleRecord{ + Name: "existing.custom.zone.", + Type: int(dns.TypeA), + Class: nbdns.DefaultClass, + TTL: 300, + RData: "10.0.0.1", + } + + testCases := []struct { + name string + zones []nbdns.CustomZone + queryName string + expectFallthrough bool + expectRecord bool + }{ + { + name: "Authoritative zone returns NXDOMAIN without fallthrough", + zones: []nbdns.CustomZone{{ + Domain: "custom.zone.", + Records: []nbdns.SimpleRecord{record}, + }}, + queryName: "nonexistent.custom.zone.", + expectFallthrough: false, + expectRecord: false, + }, + { + name: "Non-authoritative zone triggers fallthrough", + zones: []nbdns.CustomZone{{ + Domain: "custom.zone.", + Records: []nbdns.SimpleRecord{record}, + NonAuthoritative: true, + }}, + queryName: "nonexistent.custom.zone.", + expectFallthrough: true, + expectRecord: false, + }, + { + name: "Record found in non-authoritative zone returns normally", + zones: []nbdns.CustomZone{{ + Domain: "custom.zone.", + Records: []nbdns.SimpleRecord{record}, + NonAuthoritative: true, + }}, + queryName: "existing.custom.zone.", + expectFallthrough: false, + expectRecord: true, + }, + { + name: "Record found in authoritative zone returns normally", + zones: []nbdns.CustomZone{{ + Domain: "custom.zone.", + Records: []nbdns.SimpleRecord{record}, + }}, + queryName: "existing.custom.zone.", + expectFallthrough: false, + expectRecord: true, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + resolver.Update(tc.zones) + + var responseMSG *dns.Msg + responseWriter := &test.MockResponseWriter{ + WriteMsgFunc: func(m *dns.Msg) error { + responseMSG = m + return nil + }, + } + + msg := new(dns.Msg).SetQuestion(tc.queryName, dns.TypeA) + resolver.ServeDNS(responseWriter, msg) + + require.NotNil(t, responseMSG, "Should have received a response") + + if tc.expectFallthrough { + assert.True(t, responseMSG.MsgHdr.Zero, "Zero bit should be set for fallthrough") + assert.Equal(t, dns.RcodeNameError, responseMSG.Rcode, "Should return NXDOMAIN") + } else { + assert.False(t, responseMSG.MsgHdr.Zero, "Zero bit should not be set") + } + + if tc.expectRecord { + assert.Greater(t, len(responseMSG.Answer), 0, "Should have answer records") + assert.Equal(t, dns.RcodeSuccess, responseMSG.Rcode) + } + }) + } +} + // TestLocalResolver_AuthoritativeFlag tests the AA flag behavior func TestLocalResolver_AuthoritativeFlag(t *testing.T) { t.Run("direct record lookup is authoritative", func(t *testing.T) { resolver := NewResolver() - resolver.Update([]nbdns.SimpleRecord{ - {Name: "host.example.com.", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}, - }, []domain.Domain{"example.com"}) + resolver.Update([]nbdns.CustomZone{{ + Domain: "example.com.", + Records: []nbdns.SimpleRecord{ + {Name: "host.example.com.", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}, + }, + }}) msg := new(dns.Msg).SetQuestion("host.example.com.", dns.TypeA) var resp *dns.Msg @@ -1081,9 +1217,12 @@ func TestLocalResolver_AuthoritativeFlag(t *testing.T) { }, } - resolver.Update([]nbdns.SimpleRecord{ - {Name: "alias.test.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "external.example.com."}, - }, nil) + resolver.Update([]nbdns.CustomZone{{ + Domain: "test.", + Records: []nbdns.SimpleRecord{ + {Name: "alias.test.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "external.example.com."}, + }, + }}) msg := new(dns.Msg).SetQuestion("alias.test.", dns.TypeA) var resp *dns.Msg @@ -1099,9 +1238,12 @@ func TestLocalResolver_AuthoritativeFlag(t *testing.T) { func TestLocalResolver_Stop(t *testing.T) { t.Run("Stop clears all state", func(t *testing.T) { resolver := NewResolver() - resolver.Update([]nbdns.SimpleRecord{ - {Name: "host.example.com.", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}, - }, []domain.Domain{"example.com"}) + resolver.Update([]nbdns.CustomZone{{ + Domain: "example.com.", + Records: []nbdns.SimpleRecord{ + {Name: "host.example.com.", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}, + }, + }}) resolver.Stop() @@ -1116,9 +1258,12 @@ func TestLocalResolver_Stop(t *testing.T) { t.Run("Stop is safe to call multiple times", func(t *testing.T) { resolver := NewResolver() - resolver.Update([]nbdns.SimpleRecord{ - {Name: "host.example.com.", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}, - }, []domain.Domain{"example.com"}) + resolver.Update([]nbdns.CustomZone{{ + Domain: "example.com.", + Records: []nbdns.SimpleRecord{ + {Name: "host.example.com.", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}, + }, + }}) resolver.Stop() resolver.Stop() @@ -1140,9 +1285,12 @@ func TestLocalResolver_Stop(t *testing.T) { }, } - resolver.Update([]nbdns.SimpleRecord{ - {Name: "alias.test.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "external.example.com."}, - }, nil) + resolver.Update([]nbdns.CustomZone{{ + Domain: "test.", + Records: []nbdns.SimpleRecord{ + {Name: "alias.test.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "external.example.com."}, + }, + }}) done := make(chan struct{}) go func() { @@ -1167,3 +1315,107 @@ func TestLocalResolver_Stop(t *testing.T) { } }) } + +// TestLocalResolver_FallthroughCaseInsensitive verifies case-insensitive domain matching for fallthrough +func TestLocalResolver_FallthroughCaseInsensitive(t *testing.T) { + resolver := NewResolver() + + resolver.Update([]nbdns.CustomZone{{ + Domain: "EXAMPLE.COM.", + Records: []nbdns.SimpleRecord{{Name: "host.example.com.", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "1.2.3.4"}}, + NonAuthoritative: true, + }}) + + var responseMSG *dns.Msg + responseWriter := &test.MockResponseWriter{ + WriteMsgFunc: func(m *dns.Msg) error { + responseMSG = m + return nil + }, + } + + msg := new(dns.Msg).SetQuestion("nonexistent.example.com.", dns.TypeA) + resolver.ServeDNS(responseWriter, msg) + + require.NotNil(t, responseMSG) + assert.True(t, responseMSG.MsgHdr.Zero, "Should fallthrough for non-authoritative zone with case-insensitive match") +} + +// BenchmarkFindZone_BestCase benchmarks zone lookup with immediate match (first label) +func BenchmarkFindZone_BestCase(b *testing.B) { + resolver := NewResolver() + + // Single zone that matches immediately + resolver.Update([]nbdns.CustomZone{{ + Domain: "example.com.", + NonAuthoritative: true, + }}) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + resolver.shouldFallthrough("example.com.") + } +} + +// BenchmarkFindZone_WorstCase benchmarks zone lookup with many zones, no match, many labels +func BenchmarkFindZone_WorstCase(b *testing.B) { + resolver := NewResolver() + + // 100 zones that won't match + var zones []nbdns.CustomZone + for i := 0; i < 100; i++ { + zones = append(zones, nbdns.CustomZone{ + Domain: fmt.Sprintf("zone%d.internal.", i), + NonAuthoritative: true, + }) + } + resolver.Update(zones) + + // Query with many labels that won't match any zone + qname := "a.b.c.d.e.f.g.h.external.com." + + b.ResetTimer() + for i := 0; i < b.N; i++ { + resolver.shouldFallthrough(qname) + } +} + +// BenchmarkFindZone_TypicalCase benchmarks typical usage: few zones, subdomain match +func BenchmarkFindZone_TypicalCase(b *testing.B) { + resolver := NewResolver() + + // Typical setup: peer zone (authoritative) + one user zone (non-authoritative) + resolver.Update([]nbdns.CustomZone{ + {Domain: "netbird.cloud.", NonAuthoritative: false}, + {Domain: "custom.local.", NonAuthoritative: true}, + }) + + // Query for subdomain of user zone + qname := "myhost.custom.local." + + b.ResetTimer() + for i := 0; i < b.N; i++ { + resolver.shouldFallthrough(qname) + } +} + +// BenchmarkIsInManagedZone_ManyZones benchmarks isInManagedZone with 100 zones +func BenchmarkIsInManagedZone_ManyZones(b *testing.B) { + resolver := NewResolver() + + var zones []nbdns.CustomZone + for i := 0; i < 100; i++ { + zones = append(zones, nbdns.CustomZone{ + Domain: fmt.Sprintf("zone%d.internal.", i), + }) + } + resolver.Update(zones) + + // Query that matches zone50 + qname := "host.zone50.internal." + + b.ResetTimer() + for i := 0; i < b.N; i++ { + resolver.isInManagedZone(qname) + } +} diff --git a/client/internal/dns/server.go b/client/internal/dns/server.go index 0a56b92a1..29bb7f3dc 100644 --- a/client/internal/dns/server.go +++ b/client/internal/dns/server.go @@ -485,7 +485,7 @@ func (s *DefaultServer) applyConfiguration(update nbdns.Config) error { } } - localMuxUpdates, localRecords, localZones, err := s.buildLocalHandlerUpdate(update.CustomZones) + localMuxUpdates, localZones, err := s.buildLocalHandlerUpdate(update.CustomZones) if err != nil { return fmt.Errorf("local handler updater: %w", err) } @@ -498,8 +498,7 @@ func (s *DefaultServer) applyConfiguration(update nbdns.Config) error { s.updateMux(muxUpdates) - // register local records - s.localResolver.Update(localRecords, localZones) + s.localResolver.Update(localZones) s.currentConfig = dnsConfigToHostDNSConfig(update, s.service.RuntimeIP(), s.service.RuntimePort()) @@ -659,10 +658,9 @@ func (s *DefaultServer) registerFallback(config HostDNSConfig) { s.registerHandler([]string{nbdns.RootZone}, handler, PriorityFallback) } -func (s *DefaultServer) buildLocalHandlerUpdate(customZones []nbdns.CustomZone) ([]handlerWrapper, []nbdns.SimpleRecord, []domain.Domain, error) { +func (s *DefaultServer) buildLocalHandlerUpdate(customZones []nbdns.CustomZone) ([]handlerWrapper, []nbdns.CustomZone, error) { var muxUpdates []handlerWrapper - var localRecords []nbdns.SimpleRecord - var zones []domain.Domain + var zones []nbdns.CustomZone for _, customZone := range customZones { if len(customZone.Records) == 0 { @@ -676,19 +674,20 @@ func (s *DefaultServer) buildLocalHandlerUpdate(customZones []nbdns.CustomZone) priority: PriorityLocal, }) - zones = append(zones, domain.Domain(customZone.Domain)) - + // zone records contain the fqdn, so we can just flatten them + var localRecords []nbdns.SimpleRecord for _, record := range customZone.Records { if record.Class != nbdns.DefaultClass { log.Warnf("received an invalid class type: %s", record.Class) continue } - // zone records contain the fqdn, so we can just flatten them localRecords = append(localRecords, record) } + customZone.Records = localRecords + zones = append(zones, customZone) } - return muxUpdates, localRecords, zones, nil + return muxUpdates, zones, nil } func (s *DefaultServer) buildUpstreamHandlerUpdate(nameServerGroups []*nbdns.NameServerGroup) ([]handlerWrapper, error) { diff --git a/client/internal/dns/server_test.go b/client/internal/dns/server_test.go index 2b5b460b4..200a5f496 100644 --- a/client/internal/dns/server_test.go +++ b/client/internal/dns/server_test.go @@ -128,7 +128,7 @@ func TestUpdateDNSServer(t *testing.T) { testCases := []struct { name string initUpstreamMap registeredHandlerMap - initLocalRecords []nbdns.SimpleRecord + initLocalZones []nbdns.CustomZone initSerial uint64 inputSerial uint64 inputUpdate nbdns.Config @@ -181,7 +181,7 @@ func TestUpdateDNSServer(t *testing.T) { }, { name: "New Config Should Succeed", - initLocalRecords: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: 1, Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}}, + initLocalZones: []nbdns.CustomZone{{Domain: "netbird.cloud", Records: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: 1, Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}}}}, initUpstreamMap: registeredHandlerMap{ generateDummyHandler(zoneRecords[0].Name, nameServers).ID(): handlerWrapper{ domain: "netbird.cloud", @@ -222,7 +222,7 @@ func TestUpdateDNSServer(t *testing.T) { }, { name: "Smaller Config Serial Should Be Skipped", - initLocalRecords: []nbdns.SimpleRecord{}, + initLocalZones: []nbdns.CustomZone{}, initUpstreamMap: make(registeredHandlerMap), initSerial: 2, inputSerial: 1, @@ -230,7 +230,7 @@ func TestUpdateDNSServer(t *testing.T) { }, { name: "Empty NS Group Domain Or Not Primary Element Should Fail", - initLocalRecords: []nbdns.SimpleRecord{}, + initLocalZones: []nbdns.CustomZone{}, initUpstreamMap: make(registeredHandlerMap), initSerial: 0, inputSerial: 1, @@ -252,7 +252,7 @@ func TestUpdateDNSServer(t *testing.T) { }, { name: "Invalid NS Group Nameservers list Should Fail", - initLocalRecords: []nbdns.SimpleRecord{}, + initLocalZones: []nbdns.CustomZone{}, initUpstreamMap: make(registeredHandlerMap), initSerial: 0, inputSerial: 1, @@ -274,7 +274,7 @@ func TestUpdateDNSServer(t *testing.T) { }, { name: "Invalid Custom Zone Records list Should Skip", - initLocalRecords: []nbdns.SimpleRecord{}, + initLocalZones: []nbdns.CustomZone{}, initUpstreamMap: make(registeredHandlerMap), initSerial: 0, inputSerial: 1, @@ -300,7 +300,7 @@ func TestUpdateDNSServer(t *testing.T) { }, { name: "Empty Config Should Succeed and Clean Maps", - initLocalRecords: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}}, + initLocalZones: []nbdns.CustomZone{{Domain: "netbird.cloud", Records: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}}}}, initUpstreamMap: registeredHandlerMap{ generateDummyHandler(zoneRecords[0].Name, nameServers).ID(): handlerWrapper{ domain: zoneRecords[0].Name, @@ -316,7 +316,7 @@ func TestUpdateDNSServer(t *testing.T) { }, { name: "Disabled Service Should clean map", - initLocalRecords: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}}, + initLocalZones: []nbdns.CustomZone{{Domain: "netbird.cloud", Records: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}}}}, initUpstreamMap: registeredHandlerMap{ generateDummyHandler(zoneRecords[0].Name, nameServers).ID(): handlerWrapper{ domain: zoneRecords[0].Name, @@ -385,7 +385,7 @@ func TestUpdateDNSServer(t *testing.T) { }() dnsServer.dnsMuxMap = testCase.initUpstreamMap - dnsServer.localResolver.Update(testCase.initLocalRecords, nil) + dnsServer.localResolver.Update(testCase.initLocalZones) dnsServer.updateSerial = testCase.initSerial err = dnsServer.UpdateDNSServer(testCase.inputSerial, testCase.inputUpdate) @@ -510,8 +510,7 @@ func TestDNSFakeResolverHandleUpdates(t *testing.T) { priority: PriorityUpstream, }, } - //dnsServer.localResolver.RegisteredMap = local.RegistrationMap{local.BuildRecordKey("netbird.cloud", dns.ClassINET, dns.TypeA): struct{}{}} - dnsServer.localResolver.Update([]nbdns.SimpleRecord{{Name: "netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}}, nil) + dnsServer.localResolver.Update([]nbdns.CustomZone{{Domain: "netbird.cloud", Records: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}}}}) dnsServer.updateSerial = 0 nameServers := []nbdns.NameServer{ @@ -2013,7 +2012,7 @@ func TestLocalResolverPriorityInServer(t *testing.T) { }, } - localMuxUpdates, _, _, err := server.buildLocalHandlerUpdate(config.CustomZones) + localMuxUpdates, _, err := server.buildLocalHandlerUpdate(config.CustomZones) assert.NoError(t, err) upstreamMuxUpdates, err := server.buildUpstreamHandlerUpdate(config.NameServerGroups) @@ -2074,7 +2073,7 @@ func TestLocalResolverPriorityConstants(t *testing.T) { }, } - localMuxUpdates, _, _, err := server.buildLocalHandlerUpdate(config.CustomZones) + localMuxUpdates, _, err := server.buildLocalHandlerUpdate(config.CustomZones) assert.NoError(t, err) assert.Len(t, localMuxUpdates, 1) assert.Equal(t, PriorityLocal, localMuxUpdates[0].priority, "Local handler should use PriorityLocal") diff --git a/client/internal/dns/upstream.go b/client/internal/dns/upstream.go index 6b52010fb..c997acc75 100644 --- a/client/internal/dns/upstream.go +++ b/client/internal/dns/upstream.go @@ -202,6 +202,10 @@ func (u *upstreamResolverBase) writeSuccessResponse(w dns.ResponseWriter, rm *dn resutil.SetMeta(w, "upstream", upstream.String()) + // Clear Zero bit from external responses to prevent upstream servers from + // manipulating our internal fallthrough signaling mechanism + rm.MsgHdr.Zero = false + if err := w.WriteMsg(rm); err != nil { logger.Errorf("failed to write DNS response for question domain=%s: %s", domain, err) return true diff --git a/client/internal/engine.go b/client/internal/engine.go index 4f18c3bc8..2acd86a16 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -1251,11 +1251,16 @@ func toDNSConfig(protoDNSConfig *mgmProto.DNSConfig, network netip.Prefix) nbdns ForwarderPort: forwarderPort, } - for _, zone := range protoDNSConfig.GetCustomZones() { + protoZones := protoDNSConfig.GetCustomZones() + // Treat single zone as authoritative for backward compatibility with old servers + // that only send the peer FQDN zone without setting field 4. + singleZoneCompat := len(protoZones) == 1 + + for _, zone := range protoZones { dnsZone := nbdns.CustomZone{ Domain: zone.GetDomain(), SearchDomainDisabled: zone.GetSearchDomainDisabled(), - SkipPTRProcess: zone.GetSkipPTRProcess(), + NonAuthoritative: zone.GetNonAuthoritative() && !singleZoneCompat, } for _, record := range zone.Records { dnsRecord := nbdns.SimpleRecord{ diff --git a/client/internal/routemanager/dnsinterceptor/handler.go b/client/internal/routemanager/dnsinterceptor/handler.go index 928b85acb..c7ec47da4 100644 --- a/client/internal/routemanager/dnsinterceptor/handler.go +++ b/client/internal/routemanager/dnsinterceptor/handler.go @@ -325,6 +325,10 @@ func (d *DnsInterceptor) writeMsg(w dns.ResponseWriter, r *dns.Msg, logger *log. return fmt.Errorf("received nil DNS message") } + // Clear Zero bit from peer responses to prevent external sources from + // manipulating our internal fallthrough signaling mechanism + r.MsgHdr.Zero = false + if len(r.Answer) > 0 && len(r.Question) > 0 { origPattern := "" if writer, ok := w.(*nbdns.ResponseWriterChain); ok { diff --git a/dns/dns.go b/dns/dns.go index aa0e16eb1..c43e5de00 100644 --- a/dns/dns.go +++ b/dns/dns.go @@ -47,8 +47,8 @@ type CustomZone struct { Records []SimpleRecord // SearchDomainDisabled indicates whether to add match domains to a search domains list or not SearchDomainDisabled bool - // SkipPTRProcess indicates whether a client should process PTR records from custom zones - SkipPTRProcess bool + // NonAuthoritative marks user-created zones + NonAuthoritative bool } // SimpleRecord provides a simple DNS record specification for CNAME, A and AAAA records diff --git a/management/internals/shared/grpc/conversion.go b/management/internals/shared/grpc/conversion.go index 455e6bd58..c4d2e92f9 100644 --- a/management/internals/shared/grpc/conversion.go +++ b/management/internals/shared/grpc/conversion.go @@ -374,8 +374,9 @@ func shouldUsePortRange(rule *proto.FirewallRule) bool { // Helper function to convert nbdns.CustomZone to proto.CustomZone func convertToProtoCustomZone(zone nbdns.CustomZone) *proto.CustomZone { protoZone := &proto.CustomZone{ - Domain: zone.Domain, - Records: make([]*proto.SimpleRecord, 0, len(zone.Records)), + Domain: zone.Domain, + Records: make([]*proto.SimpleRecord, 0, len(zone.Records)), + NonAuthoritative: zone.NonAuthoritative, } for _, record := range zone.Records { protoZone.Records = append(protoZone.Records, &proto.SimpleRecord{ diff --git a/shared/management/proto/management.pb.go b/shared/management/proto/management.pb.go index 2047c51ea..077f84ed3 100644 --- a/shared/management/proto/management.pb.go +++ b/shared/management/proto/management.pb.go @@ -2873,7 +2873,7 @@ type CustomZone struct { Domain string `protobuf:"bytes,1,opt,name=Domain,proto3" json:"Domain,omitempty"` Records []*SimpleRecord `protobuf:"bytes,2,rep,name=Records,proto3" json:"Records,omitempty"` SearchDomainDisabled bool `protobuf:"varint,3,opt,name=SearchDomainDisabled,proto3" json:"SearchDomainDisabled,omitempty"` - SkipPTRProcess bool `protobuf:"varint,4,opt,name=SkipPTRProcess,proto3" json:"SkipPTRProcess,omitempty"` + NonAuthoritative bool `protobuf:"varint,4,opt,name=NonAuthoritative,proto3" json:"NonAuthoritative,omitempty"` } func (x *CustomZone) Reset() { @@ -2929,9 +2929,9 @@ func (x *CustomZone) GetSearchDomainDisabled() bool { return false } -func (x *CustomZone) GetSkipPTRProcess() bool { +func (x *CustomZone) GetNonAuthoritative() bool { if x != nil { - return x.SkipPTRProcess + return x.NonAuthoritative } return false } diff --git a/shared/management/proto/management.proto b/shared/management/proto/management.proto index f2e591e88..c4cc43295 100644 --- a/shared/management/proto/management.proto +++ b/shared/management/proto/management.proto @@ -464,7 +464,9 @@ message CustomZone { string Domain = 1; repeated SimpleRecord Records = 2; bool SearchDomainDisabled = 3; - bool SkipPTRProcess = 4; + // NonAuthoritative indicates this is a user-created zone (not the built-in peer DNS zone). + // Non-authoritative zones will fallthrough to lower-priority handlers on NXDOMAIN and skip PTR processing. + bool NonAuthoritative = 4; } // SimpleRecord represents a dns.SimpleRecord