[client] Fall through dns chain for custom dns zones (#5081)

This commit is contained in:
Viktor Liu
2026-01-12 20:56:39 +08:00
committed by GitHub
parent 394ad19507
commit b12c084a50
12 changed files with 437 additions and 132 deletions

View File

@@ -76,7 +76,7 @@ func collectPTRRecords(config *nbdns.Config, prefix netip.Prefix) []nbdns.Simple
var records []nbdns.SimpleRecord
for _, zone := range config.CustomZones {
if zone.SkipPTRProcess {
if zone.NonAuthoritative {
continue
}
for _, record := range zone.Records {

View File

@@ -28,10 +28,11 @@ type resolver interface {
}
type Resolver struct {
mu sync.RWMutex
records map[dns.Question][]dns.RR
domains map[domain.Domain]struct{}
zones []domain.Domain
mu sync.RWMutex
records map[dns.Question][]dns.RR
domains map[domain.Domain]struct{}
// zones maps zone domain -> NonAuthoritative (true = non-authoritative, user-created zone)
zones map[domain.Domain]bool
resolver resolver
ctx context.Context
@@ -43,6 +44,7 @@ func NewResolver() *Resolver {
return &Resolver{
records: make(map[dns.Question][]dns.RR),
domains: make(map[domain.Domain]struct{}),
zones: make(map[domain.Domain]bool),
ctx: ctx,
cancel: cancel,
}
@@ -67,7 +69,7 @@ func (d *Resolver) Stop() {
maps.Clear(d.records)
maps.Clear(d.domains)
d.zones = nil
maps.Clear(d.zones)
}
// ID returns the unique handler ID
@@ -97,6 +99,11 @@ func (d *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
replyMessage.Answer = result.records
replyMessage.Rcode = d.determineRcode(question, result)
if replyMessage.Rcode == dns.RcodeNameError && d.shouldFallthrough(question.Name) {
d.continueToNext(logger, w, r)
return
}
if err := w.WriteMsg(replyMessage); err != nil {
logger.Warnf("failed to write the local resolver response: %v", err)
}
@@ -120,6 +127,42 @@ func (d *Resolver) determineRcode(question dns.Question, result lookupResult) in
return dns.RcodeNameError
}
// findZone finds the matching zone for a query name using reverse suffix lookup.
// Returns (nonAuthoritative, found). This is O(k) where k = number of labels in qname.
func (d *Resolver) findZone(qname string) (nonAuthoritative bool, found bool) {
qname = strings.ToLower(dns.Fqdn(qname))
for {
if nonAuth, ok := d.zones[domain.Domain(qname)]; ok {
return nonAuth, true
}
// Move to parent domain
idx := strings.Index(qname, ".")
if idx == -1 || idx == len(qname)-1 {
return false, false
}
qname = qname[idx+1:]
}
}
// shouldFallthrough checks if the query should fallthrough to the next handler.
// Returns true if the queried name belongs to a non-authoritative zone.
func (d *Resolver) shouldFallthrough(qname string) bool {
d.mu.RLock()
defer d.mu.RUnlock()
nonAuth, found := d.findZone(qname)
return found && nonAuth
}
func (d *Resolver) continueToNext(logger *log.Entry, w dns.ResponseWriter, r *dns.Msg) {
resp := &dns.Msg{}
resp.SetRcode(r, dns.RcodeNameError)
resp.MsgHdr.Zero = true
if err := w.WriteMsg(resp); err != nil {
logger.Warnf("failed to write continue signal: %v", err)
}
}
// hasRecordsForDomain checks if any records exist for the given domain name regardless of type
func (d *Resolver) hasRecordsForDomain(domainName domain.Domain) bool {
d.mu.RLock()
@@ -137,14 +180,8 @@ func (d *Resolver) isInManagedZone(name string) bool {
d.mu.RLock()
defer d.mu.RUnlock()
name = dns.Fqdn(name)
for _, zone := range d.zones {
zoneStr := dns.Fqdn(zone.PunycodeString())
if strings.EqualFold(name, zoneStr) || strings.HasSuffix(strings.ToLower(name), strings.ToLower("."+zoneStr)) {
return true
}
}
return false
_, found := d.findZone(name)
return found
}
// lookupResult contains the result of a DNS lookup operation.
@@ -343,21 +380,23 @@ func (d *Resolver) logDNSError(logger *log.Entry, hostname string, qtype uint16,
}
}
// Update updates the resolver with new records and zone information.
// The zones parameter specifies which DNS zones this resolver manages.
func (d *Resolver) Update(update []nbdns.SimpleRecord, zones []domain.Domain) {
// Update replaces all zones and their records
func (d *Resolver) Update(customZones []nbdns.CustomZone) {
d.mu.Lock()
defer d.mu.Unlock()
maps.Clear(d.records)
maps.Clear(d.domains)
maps.Clear(d.zones)
d.zones = zones
for _, zone := range customZones {
zoneDomain := domain.Domain(strings.ToLower(dns.Fqdn(zone.Domain)))
d.zones[zoneDomain] = zone.NonAuthoritative
for _, rec := range update {
if err := d.registerRecord(rec); err != nil {
log.Warnf("failed to register the record (%s): %v", rec, err)
continue
for _, rec := range zone.Records {
if err := d.registerRecord(rec); err != nil {
log.Warnf("failed to register the record (%s): %v", rec, err)
}
}
}
}

View File

@@ -16,7 +16,6 @@ import (
"github.com/netbirdio/netbird/client/internal/dns/test"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/shared/management/domain"
)
// mockResolver implements resolver for testing
@@ -125,11 +124,11 @@ func TestLocalResolver_Update_StaleRecord(t *testing.T) {
resolver := NewResolver()
update1 := []nbdns.SimpleRecord{record1}
update2 := []nbdns.SimpleRecord{record2}
zone1 := []nbdns.CustomZone{{Domain: "example.com.", Records: []nbdns.SimpleRecord{record1}}}
zone2 := []nbdns.CustomZone{{Domain: "example.com.", Records: []nbdns.SimpleRecord{record2}}}
// Apply first update
resolver.Update(update1, nil)
resolver.Update(zone1)
// Verify first update
resolver.mu.RLock()
@@ -141,7 +140,7 @@ func TestLocalResolver_Update_StaleRecord(t *testing.T) {
assert.Contains(t, rrSlice1[0].String(), record1.RData, "Record after first update should be %s", record1.RData)
// Apply second update
resolver.Update(update2, nil)
resolver.Update(zone2)
// Verify second update
resolver.mu.RLock()
@@ -170,10 +169,10 @@ func TestLocalResolver_MultipleRecords_SameQuestion(t *testing.T) {
Name: recordName, Type: int(recordType), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.2",
}
update := []nbdns.SimpleRecord{record1, record2}
zones := []nbdns.CustomZone{{Domain: "example.com.", Records: []nbdns.SimpleRecord{record1, record2}}}
// Apply update with both records
resolver.Update(update, nil)
resolver.Update(zones)
// Create question that matches both records
question := dns.Question{
@@ -214,10 +213,10 @@ func TestLocalResolver_RecordRotation(t *testing.T) {
Name: recordName, Type: int(recordType), Class: nbdns.DefaultClass, TTL: 300, RData: "192.168.1.3",
}
update := []nbdns.SimpleRecord{record1, record2, record3}
zones := []nbdns.CustomZone{{Domain: "example.com.", Records: []nbdns.SimpleRecord{record1, record2, record3}}}
// Apply update with all three records
resolver.Update(update, nil)
resolver.Update(zones)
msg := new(dns.Msg).SetQuestion(recordName, recordType)
@@ -283,7 +282,7 @@ func TestLocalResolver_CaseInsensitiveMatching(t *testing.T) {
}
// Update resolver with the records
resolver.Update([]nbdns.SimpleRecord{lowerCaseRecord, mixedCaseRecord}, nil)
resolver.Update([]nbdns.CustomZone{{Domain: "example.com.", Records: []nbdns.SimpleRecord{lowerCaseRecord, mixedCaseRecord}}})
testCases := []struct {
name string
@@ -398,7 +397,7 @@ func TestLocalResolver_CNAMEFallback(t *testing.T) {
}
// Update resolver with both records
resolver.Update([]nbdns.SimpleRecord{cnameRecord, targetRecord}, nil)
resolver.Update([]nbdns.CustomZone{{Domain: "example.com.", Records: []nbdns.SimpleRecord{cnameRecord, targetRecord}}})
testCases := []struct {
name string
@@ -526,7 +525,7 @@ func TestLocalResolver_NoErrorWithDifferentRecordType(t *testing.T) {
RData: "target.example.com.",
}
resolver.Update([]nbdns.SimpleRecord{recordA, recordCNAME}, nil)
resolver.Update([]nbdns.CustomZone{{Domain: "netbird.cloud.", Records: []nbdns.SimpleRecord{recordA, recordCNAME}}})
testCases := []struct {
name string
@@ -620,10 +619,13 @@ func TestLocalResolver_NoErrorWithDifferentRecordType(t *testing.T) {
func TestLocalResolver_CNAMEChainResolution(t *testing.T) {
t.Run("simple internal CNAME chain", func(t *testing.T) {
resolver := NewResolver()
resolver.Update([]nbdns.SimpleRecord{
{Name: "alias.example.com.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "target.example.com."},
{Name: "target.example.com.", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "192.168.1.1"},
}, nil)
resolver.Update([]nbdns.CustomZone{{
Domain: "example.com.",
Records: []nbdns.SimpleRecord{
{Name: "alias.example.com.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "target.example.com."},
{Name: "target.example.com.", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "192.168.1.1"},
},
}})
msg := new(dns.Msg).SetQuestion("alias.example.com.", dns.TypeA)
var resp *dns.Msg
@@ -644,11 +646,14 @@ func TestLocalResolver_CNAMEChainResolution(t *testing.T) {
t.Run("multi-hop CNAME chain", func(t *testing.T) {
resolver := NewResolver()
resolver.Update([]nbdns.SimpleRecord{
{Name: "hop1.test.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "hop2.test."},
{Name: "hop2.test.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "hop3.test."},
{Name: "hop3.test.", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"},
}, nil)
resolver.Update([]nbdns.CustomZone{{
Domain: "test.",
Records: []nbdns.SimpleRecord{
{Name: "hop1.test.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "hop2.test."},
{Name: "hop2.test.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "hop3.test."},
{Name: "hop3.test.", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"},
},
}})
msg := new(dns.Msg).SetQuestion("hop1.test.", dns.TypeA)
var resp *dns.Msg
@@ -661,9 +666,12 @@ func TestLocalResolver_CNAMEChainResolution(t *testing.T) {
t.Run("CNAME to non-existent internal target returns only CNAME", func(t *testing.T) {
resolver := NewResolver()
resolver.Update([]nbdns.SimpleRecord{
{Name: "alias.test.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "nonexistent.test."},
}, nil)
resolver.Update([]nbdns.CustomZone{{
Domain: "test.",
Records: []nbdns.SimpleRecord{
{Name: "alias.test.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "nonexistent.test."},
},
}})
msg := new(dns.Msg).SetQuestion("alias.test.", dns.TypeA)
var resp *dns.Msg
@@ -695,7 +703,7 @@ func TestLocalResolver_CNAMEMaxDepth(t *testing.T) {
Name: "hop8.test.", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.10.10.10",
})
resolver.Update(records, nil)
resolver.Update([]nbdns.CustomZone{{Domain: "test.", Records: records}})
msg := new(dns.Msg).SetQuestion("hop1.test.", dns.TypeA)
var resp *dns.Msg
@@ -723,7 +731,7 @@ func TestLocalResolver_CNAMEMaxDepth(t *testing.T) {
Name: "deep11.test.", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.10.10.10",
})
resolver.Update(records, nil)
resolver.Update([]nbdns.CustomZone{{Domain: "test.", Records: records}})
msg := new(dns.Msg).SetQuestion("deep1.test.", dns.TypeA)
var resp *dns.Msg
@@ -736,10 +744,13 @@ func TestLocalResolver_CNAMEMaxDepth(t *testing.T) {
t.Run("circular CNAME is protected by max depth", func(t *testing.T) {
resolver := NewResolver()
resolver.Update([]nbdns.SimpleRecord{
{Name: "loop1.test.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "loop2.test."},
{Name: "loop2.test.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "loop1.test."},
}, nil)
resolver.Update([]nbdns.CustomZone{{
Domain: "test.",
Records: []nbdns.SimpleRecord{
{Name: "loop1.test.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "loop2.test."},
{Name: "loop2.test.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "loop1.test."},
},
}})
msg := new(dns.Msg).SetQuestion("loop1.test.", dns.TypeA)
var resp *dns.Msg
@@ -763,9 +774,12 @@ func TestLocalResolver_ExternalCNAMEResolution(t *testing.T) {
},
}
resolver.Update([]nbdns.SimpleRecord{
{Name: "alias.test.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "external.example.com."},
}, nil)
resolver.Update([]nbdns.CustomZone{{
Domain: "test.",
Records: []nbdns.SimpleRecord{
{Name: "alias.test.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "external.example.com."},
},
}})
msg := new(dns.Msg).SetQuestion("alias.test.", dns.TypeA)
var resp *dns.Msg
@@ -794,9 +808,12 @@ func TestLocalResolver_ExternalCNAMEResolution(t *testing.T) {
},
}
resolver.Update([]nbdns.SimpleRecord{
{Name: "alias.test.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "external.example.com."},
}, nil)
resolver.Update([]nbdns.CustomZone{{
Domain: "test.",
Records: []nbdns.SimpleRecord{
{Name: "alias.test.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "external.example.com."},
},
}})
msg := new(dns.Msg).SetQuestion("alias.test.", dns.TypeAAAA)
var resp *dns.Msg
@@ -825,9 +842,12 @@ func TestLocalResolver_ExternalCNAMEResolution(t *testing.T) {
},
}
resolver.Update([]nbdns.SimpleRecord{
{Name: "concurrent.test.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "external.example.com."},
}, nil)
resolver.Update([]nbdns.CustomZone{{
Domain: "test.",
Records: []nbdns.SimpleRecord{
{Name: "concurrent.test.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "external.example.com."},
},
}})
var wg sync.WaitGroup
results := make([]*dns.Msg, 10)
@@ -856,10 +876,12 @@ func TestLocalResolver_ZoneManagement(t *testing.T) {
t.Run("Update sets zones correctly", func(t *testing.T) {
resolver := NewResolver()
zones := []domain.Domain{"example.com", "test.local"}
resolver.Update([]nbdns.SimpleRecord{
{Name: "host.example.com.", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"},
}, zones)
resolver.Update([]nbdns.CustomZone{
{Domain: "example.com.", Records: []nbdns.SimpleRecord{
{Name: "host.example.com.", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"},
}},
{Domain: "test.local."},
})
assert.True(t, resolver.isInManagedZone("host.example.com."))
assert.True(t, resolver.isInManagedZone("other.example.com."))
@@ -869,7 +891,7 @@ func TestLocalResolver_ZoneManagement(t *testing.T) {
t.Run("isInManagedZone case insensitive", func(t *testing.T) {
resolver := NewResolver()
resolver.Update(nil, []domain.Domain{"Example.COM"})
resolver.Update([]nbdns.CustomZone{{Domain: "Example.COM."}})
assert.True(t, resolver.isInManagedZone("host.example.com."))
assert.True(t, resolver.isInManagedZone("HOST.EXAMPLE.COM."))
@@ -877,10 +899,10 @@ func TestLocalResolver_ZoneManagement(t *testing.T) {
t.Run("Update clears zones", func(t *testing.T) {
resolver := NewResolver()
resolver.Update(nil, []domain.Domain{"example.com"})
resolver.Update([]nbdns.CustomZone{{Domain: "example.com."}})
assert.True(t, resolver.isInManagedZone("host.example.com."))
resolver.Update(nil, nil)
resolver.Update(nil)
assert.False(t, resolver.isInManagedZone("host.example.com."))
})
}
@@ -889,9 +911,12 @@ func TestLocalResolver_ZoneManagement(t *testing.T) {
func TestLocalResolver_CNAMEZoneAwareResolution(t *testing.T) {
t.Run("CNAME target in managed zone returns NXDOMAIN per RFC 6604", func(t *testing.T) {
resolver := NewResolver()
resolver.Update([]nbdns.SimpleRecord{
{Name: "alias.myzone.test.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "nonexistent.myzone.test."},
}, []domain.Domain{"myzone.test"})
resolver.Update([]nbdns.CustomZone{{
Domain: "myzone.test.",
Records: []nbdns.SimpleRecord{
{Name: "alias.myzone.test.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "nonexistent.myzone.test."},
},
}})
msg := new(dns.Msg).SetQuestion("alias.myzone.test.", dns.TypeA)
var resp *dns.Msg
@@ -913,9 +938,12 @@ func TestLocalResolver_CNAMEZoneAwareResolution(t *testing.T) {
},
}
resolver.Update([]nbdns.SimpleRecord{
{Name: "alias.myzone.test.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "external.other.com."},
}, []domain.Domain{"myzone.test"})
resolver.Update([]nbdns.CustomZone{{
Domain: "myzone.test.",
Records: []nbdns.SimpleRecord{
{Name: "alias.myzone.test.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "external.other.com."},
},
}})
msg := new(dns.Msg).SetQuestion("alias.myzone.test.", dns.TypeA)
var resp *dns.Msg
@@ -929,10 +957,13 @@ func TestLocalResolver_CNAMEZoneAwareResolution(t *testing.T) {
t.Run("CNAME target exists with different type returns NODATA not NXDOMAIN", func(t *testing.T) {
resolver := NewResolver()
// CNAME points to target that has A but no AAAA - query for AAAA should be NODATA
resolver.Update([]nbdns.SimpleRecord{
{Name: "alias.myzone.test.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "target.myzone.test."},
{Name: "target.myzone.test.", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "1.1.1.1"},
}, []domain.Domain{"myzone.test"})
resolver.Update([]nbdns.CustomZone{{
Domain: "myzone.test.",
Records: []nbdns.SimpleRecord{
{Name: "alias.myzone.test.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "target.myzone.test."},
{Name: "target.myzone.test.", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "1.1.1.1"},
},
}})
msg := new(dns.Msg).SetQuestion("alias.myzone.test.", dns.TypeAAAA)
var resp *dns.Msg
@@ -963,9 +994,12 @@ func TestLocalResolver_CNAMEZoneAwareResolution(t *testing.T) {
},
}
resolver.Update([]nbdns.SimpleRecord{
{Name: "alias.test.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "external.example.com."},
}, nil)
resolver.Update([]nbdns.CustomZone{{
Domain: "test.",
Records: []nbdns.SimpleRecord{
{Name: "alias.test.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "external.example.com."},
},
}})
msg := new(dns.Msg).SetQuestion("alias.test.", dns.TypeAAAA)
var resp *dns.Msg
@@ -1035,9 +1069,12 @@ func TestLocalResolver_CNAMEZoneAwareResolution(t *testing.T) {
resolver := NewResolver()
resolver.resolver = &mockResolver{lookupFunc: tc.lookupFunc}
resolver.Update([]nbdns.SimpleRecord{
{Name: "alias.test.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "external.example.com."},
}, nil)
resolver.Update([]nbdns.CustomZone{{
Domain: "test.",
Records: []nbdns.SimpleRecord{
{Name: "alias.test.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "external.example.com."},
},
}})
msg := new(dns.Msg).SetQuestion("alias.test.", dns.TypeA)
var resp *dns.Msg
@@ -1054,13 +1091,112 @@ func TestLocalResolver_CNAMEZoneAwareResolution(t *testing.T) {
}
}
// TestLocalResolver_Fallthrough verifies that non-authoritative zones
// trigger fallthrough (Zero bit set) when no records match
func TestLocalResolver_Fallthrough(t *testing.T) {
resolver := NewResolver()
record := nbdns.SimpleRecord{
Name: "existing.custom.zone.",
Type: int(dns.TypeA),
Class: nbdns.DefaultClass,
TTL: 300,
RData: "10.0.0.1",
}
testCases := []struct {
name string
zones []nbdns.CustomZone
queryName string
expectFallthrough bool
expectRecord bool
}{
{
name: "Authoritative zone returns NXDOMAIN without fallthrough",
zones: []nbdns.CustomZone{{
Domain: "custom.zone.",
Records: []nbdns.SimpleRecord{record},
}},
queryName: "nonexistent.custom.zone.",
expectFallthrough: false,
expectRecord: false,
},
{
name: "Non-authoritative zone triggers fallthrough",
zones: []nbdns.CustomZone{{
Domain: "custom.zone.",
Records: []nbdns.SimpleRecord{record},
NonAuthoritative: true,
}},
queryName: "nonexistent.custom.zone.",
expectFallthrough: true,
expectRecord: false,
},
{
name: "Record found in non-authoritative zone returns normally",
zones: []nbdns.CustomZone{{
Domain: "custom.zone.",
Records: []nbdns.SimpleRecord{record},
NonAuthoritative: true,
}},
queryName: "existing.custom.zone.",
expectFallthrough: false,
expectRecord: true,
},
{
name: "Record found in authoritative zone returns normally",
zones: []nbdns.CustomZone{{
Domain: "custom.zone.",
Records: []nbdns.SimpleRecord{record},
}},
queryName: "existing.custom.zone.",
expectFallthrough: false,
expectRecord: true,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
resolver.Update(tc.zones)
var responseMSG *dns.Msg
responseWriter := &test.MockResponseWriter{
WriteMsgFunc: func(m *dns.Msg) error {
responseMSG = m
return nil
},
}
msg := new(dns.Msg).SetQuestion(tc.queryName, dns.TypeA)
resolver.ServeDNS(responseWriter, msg)
require.NotNil(t, responseMSG, "Should have received a response")
if tc.expectFallthrough {
assert.True(t, responseMSG.MsgHdr.Zero, "Zero bit should be set for fallthrough")
assert.Equal(t, dns.RcodeNameError, responseMSG.Rcode, "Should return NXDOMAIN")
} else {
assert.False(t, responseMSG.MsgHdr.Zero, "Zero bit should not be set")
}
if tc.expectRecord {
assert.Greater(t, len(responseMSG.Answer), 0, "Should have answer records")
assert.Equal(t, dns.RcodeSuccess, responseMSG.Rcode)
}
})
}
}
// TestLocalResolver_AuthoritativeFlag tests the AA flag behavior
func TestLocalResolver_AuthoritativeFlag(t *testing.T) {
t.Run("direct record lookup is authoritative", func(t *testing.T) {
resolver := NewResolver()
resolver.Update([]nbdns.SimpleRecord{
{Name: "host.example.com.", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"},
}, []domain.Domain{"example.com"})
resolver.Update([]nbdns.CustomZone{{
Domain: "example.com.",
Records: []nbdns.SimpleRecord{
{Name: "host.example.com.", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"},
},
}})
msg := new(dns.Msg).SetQuestion("host.example.com.", dns.TypeA)
var resp *dns.Msg
@@ -1081,9 +1217,12 @@ func TestLocalResolver_AuthoritativeFlag(t *testing.T) {
},
}
resolver.Update([]nbdns.SimpleRecord{
{Name: "alias.test.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "external.example.com."},
}, nil)
resolver.Update([]nbdns.CustomZone{{
Domain: "test.",
Records: []nbdns.SimpleRecord{
{Name: "alias.test.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "external.example.com."},
},
}})
msg := new(dns.Msg).SetQuestion("alias.test.", dns.TypeA)
var resp *dns.Msg
@@ -1099,9 +1238,12 @@ func TestLocalResolver_AuthoritativeFlag(t *testing.T) {
func TestLocalResolver_Stop(t *testing.T) {
t.Run("Stop clears all state", func(t *testing.T) {
resolver := NewResolver()
resolver.Update([]nbdns.SimpleRecord{
{Name: "host.example.com.", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"},
}, []domain.Domain{"example.com"})
resolver.Update([]nbdns.CustomZone{{
Domain: "example.com.",
Records: []nbdns.SimpleRecord{
{Name: "host.example.com.", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"},
},
}})
resolver.Stop()
@@ -1116,9 +1258,12 @@ func TestLocalResolver_Stop(t *testing.T) {
t.Run("Stop is safe to call multiple times", func(t *testing.T) {
resolver := NewResolver()
resolver.Update([]nbdns.SimpleRecord{
{Name: "host.example.com.", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"},
}, []domain.Domain{"example.com"})
resolver.Update([]nbdns.CustomZone{{
Domain: "example.com.",
Records: []nbdns.SimpleRecord{
{Name: "host.example.com.", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"},
},
}})
resolver.Stop()
resolver.Stop()
@@ -1140,9 +1285,12 @@ func TestLocalResolver_Stop(t *testing.T) {
},
}
resolver.Update([]nbdns.SimpleRecord{
{Name: "alias.test.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "external.example.com."},
}, nil)
resolver.Update([]nbdns.CustomZone{{
Domain: "test.",
Records: []nbdns.SimpleRecord{
{Name: "alias.test.", Type: int(dns.TypeCNAME), Class: nbdns.DefaultClass, TTL: 300, RData: "external.example.com."},
},
}})
done := make(chan struct{})
go func() {
@@ -1167,3 +1315,107 @@ func TestLocalResolver_Stop(t *testing.T) {
}
})
}
// TestLocalResolver_FallthroughCaseInsensitive verifies case-insensitive domain matching for fallthrough
func TestLocalResolver_FallthroughCaseInsensitive(t *testing.T) {
resolver := NewResolver()
resolver.Update([]nbdns.CustomZone{{
Domain: "EXAMPLE.COM.",
Records: []nbdns.SimpleRecord{{Name: "host.example.com.", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "1.2.3.4"}},
NonAuthoritative: true,
}})
var responseMSG *dns.Msg
responseWriter := &test.MockResponseWriter{
WriteMsgFunc: func(m *dns.Msg) error {
responseMSG = m
return nil
},
}
msg := new(dns.Msg).SetQuestion("nonexistent.example.com.", dns.TypeA)
resolver.ServeDNS(responseWriter, msg)
require.NotNil(t, responseMSG)
assert.True(t, responseMSG.MsgHdr.Zero, "Should fallthrough for non-authoritative zone with case-insensitive match")
}
// BenchmarkFindZone_BestCase benchmarks zone lookup with immediate match (first label)
func BenchmarkFindZone_BestCase(b *testing.B) {
resolver := NewResolver()
// Single zone that matches immediately
resolver.Update([]nbdns.CustomZone{{
Domain: "example.com.",
NonAuthoritative: true,
}})
b.ResetTimer()
for i := 0; i < b.N; i++ {
resolver.shouldFallthrough("example.com.")
}
}
// BenchmarkFindZone_WorstCase benchmarks zone lookup with many zones, no match, many labels
func BenchmarkFindZone_WorstCase(b *testing.B) {
resolver := NewResolver()
// 100 zones that won't match
var zones []nbdns.CustomZone
for i := 0; i < 100; i++ {
zones = append(zones, nbdns.CustomZone{
Domain: fmt.Sprintf("zone%d.internal.", i),
NonAuthoritative: true,
})
}
resolver.Update(zones)
// Query with many labels that won't match any zone
qname := "a.b.c.d.e.f.g.h.external.com."
b.ResetTimer()
for i := 0; i < b.N; i++ {
resolver.shouldFallthrough(qname)
}
}
// BenchmarkFindZone_TypicalCase benchmarks typical usage: few zones, subdomain match
func BenchmarkFindZone_TypicalCase(b *testing.B) {
resolver := NewResolver()
// Typical setup: peer zone (authoritative) + one user zone (non-authoritative)
resolver.Update([]nbdns.CustomZone{
{Domain: "netbird.cloud.", NonAuthoritative: false},
{Domain: "custom.local.", NonAuthoritative: true},
})
// Query for subdomain of user zone
qname := "myhost.custom.local."
b.ResetTimer()
for i := 0; i < b.N; i++ {
resolver.shouldFallthrough(qname)
}
}
// BenchmarkIsInManagedZone_ManyZones benchmarks isInManagedZone with 100 zones
func BenchmarkIsInManagedZone_ManyZones(b *testing.B) {
resolver := NewResolver()
var zones []nbdns.CustomZone
for i := 0; i < 100; i++ {
zones = append(zones, nbdns.CustomZone{
Domain: fmt.Sprintf("zone%d.internal.", i),
})
}
resolver.Update(zones)
// Query that matches zone50
qname := "host.zone50.internal."
b.ResetTimer()
for i := 0; i < b.N; i++ {
resolver.isInManagedZone(qname)
}
}

View File

@@ -485,7 +485,7 @@ func (s *DefaultServer) applyConfiguration(update nbdns.Config) error {
}
}
localMuxUpdates, localRecords, localZones, err := s.buildLocalHandlerUpdate(update.CustomZones)
localMuxUpdates, localZones, err := s.buildLocalHandlerUpdate(update.CustomZones)
if err != nil {
return fmt.Errorf("local handler updater: %w", err)
}
@@ -498,8 +498,7 @@ func (s *DefaultServer) applyConfiguration(update nbdns.Config) error {
s.updateMux(muxUpdates)
// register local records
s.localResolver.Update(localRecords, localZones)
s.localResolver.Update(localZones)
s.currentConfig = dnsConfigToHostDNSConfig(update, s.service.RuntimeIP(), s.service.RuntimePort())
@@ -659,10 +658,9 @@ func (s *DefaultServer) registerFallback(config HostDNSConfig) {
s.registerHandler([]string{nbdns.RootZone}, handler, PriorityFallback)
}
func (s *DefaultServer) buildLocalHandlerUpdate(customZones []nbdns.CustomZone) ([]handlerWrapper, []nbdns.SimpleRecord, []domain.Domain, error) {
func (s *DefaultServer) buildLocalHandlerUpdate(customZones []nbdns.CustomZone) ([]handlerWrapper, []nbdns.CustomZone, error) {
var muxUpdates []handlerWrapper
var localRecords []nbdns.SimpleRecord
var zones []domain.Domain
var zones []nbdns.CustomZone
for _, customZone := range customZones {
if len(customZone.Records) == 0 {
@@ -676,19 +674,20 @@ func (s *DefaultServer) buildLocalHandlerUpdate(customZones []nbdns.CustomZone)
priority: PriorityLocal,
})
zones = append(zones, domain.Domain(customZone.Domain))
// zone records contain the fqdn, so we can just flatten them
var localRecords []nbdns.SimpleRecord
for _, record := range customZone.Records {
if record.Class != nbdns.DefaultClass {
log.Warnf("received an invalid class type: %s", record.Class)
continue
}
// zone records contain the fqdn, so we can just flatten them
localRecords = append(localRecords, record)
}
customZone.Records = localRecords
zones = append(zones, customZone)
}
return muxUpdates, localRecords, zones, nil
return muxUpdates, zones, nil
}
func (s *DefaultServer) buildUpstreamHandlerUpdate(nameServerGroups []*nbdns.NameServerGroup) ([]handlerWrapper, error) {

View File

@@ -128,7 +128,7 @@ func TestUpdateDNSServer(t *testing.T) {
testCases := []struct {
name string
initUpstreamMap registeredHandlerMap
initLocalRecords []nbdns.SimpleRecord
initLocalZones []nbdns.CustomZone
initSerial uint64
inputSerial uint64
inputUpdate nbdns.Config
@@ -181,7 +181,7 @@ func TestUpdateDNSServer(t *testing.T) {
},
{
name: "New Config Should Succeed",
initLocalRecords: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: 1, Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}},
initLocalZones: []nbdns.CustomZone{{Domain: "netbird.cloud", Records: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: 1, Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}}}},
initUpstreamMap: registeredHandlerMap{
generateDummyHandler(zoneRecords[0].Name, nameServers).ID(): handlerWrapper{
domain: "netbird.cloud",
@@ -222,7 +222,7 @@ func TestUpdateDNSServer(t *testing.T) {
},
{
name: "Smaller Config Serial Should Be Skipped",
initLocalRecords: []nbdns.SimpleRecord{},
initLocalZones: []nbdns.CustomZone{},
initUpstreamMap: make(registeredHandlerMap),
initSerial: 2,
inputSerial: 1,
@@ -230,7 +230,7 @@ func TestUpdateDNSServer(t *testing.T) {
},
{
name: "Empty NS Group Domain Or Not Primary Element Should Fail",
initLocalRecords: []nbdns.SimpleRecord{},
initLocalZones: []nbdns.CustomZone{},
initUpstreamMap: make(registeredHandlerMap),
initSerial: 0,
inputSerial: 1,
@@ -252,7 +252,7 @@ func TestUpdateDNSServer(t *testing.T) {
},
{
name: "Invalid NS Group Nameservers list Should Fail",
initLocalRecords: []nbdns.SimpleRecord{},
initLocalZones: []nbdns.CustomZone{},
initUpstreamMap: make(registeredHandlerMap),
initSerial: 0,
inputSerial: 1,
@@ -274,7 +274,7 @@ func TestUpdateDNSServer(t *testing.T) {
},
{
name: "Invalid Custom Zone Records list Should Skip",
initLocalRecords: []nbdns.SimpleRecord{},
initLocalZones: []nbdns.CustomZone{},
initUpstreamMap: make(registeredHandlerMap),
initSerial: 0,
inputSerial: 1,
@@ -300,7 +300,7 @@ func TestUpdateDNSServer(t *testing.T) {
},
{
name: "Empty Config Should Succeed and Clean Maps",
initLocalRecords: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}},
initLocalZones: []nbdns.CustomZone{{Domain: "netbird.cloud", Records: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}}}},
initUpstreamMap: registeredHandlerMap{
generateDummyHandler(zoneRecords[0].Name, nameServers).ID(): handlerWrapper{
domain: zoneRecords[0].Name,
@@ -316,7 +316,7 @@ func TestUpdateDNSServer(t *testing.T) {
},
{
name: "Disabled Service Should clean map",
initLocalRecords: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}},
initLocalZones: []nbdns.CustomZone{{Domain: "netbird.cloud", Records: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}}}},
initUpstreamMap: registeredHandlerMap{
generateDummyHandler(zoneRecords[0].Name, nameServers).ID(): handlerWrapper{
domain: zoneRecords[0].Name,
@@ -385,7 +385,7 @@ func TestUpdateDNSServer(t *testing.T) {
}()
dnsServer.dnsMuxMap = testCase.initUpstreamMap
dnsServer.localResolver.Update(testCase.initLocalRecords, nil)
dnsServer.localResolver.Update(testCase.initLocalZones)
dnsServer.updateSerial = testCase.initSerial
err = dnsServer.UpdateDNSServer(testCase.inputSerial, testCase.inputUpdate)
@@ -510,8 +510,7 @@ func TestDNSFakeResolverHandleUpdates(t *testing.T) {
priority: PriorityUpstream,
},
}
//dnsServer.localResolver.RegisteredMap = local.RegistrationMap{local.BuildRecordKey("netbird.cloud", dns.ClassINET, dns.TypeA): struct{}{}}
dnsServer.localResolver.Update([]nbdns.SimpleRecord{{Name: "netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}}, nil)
dnsServer.localResolver.Update([]nbdns.CustomZone{{Domain: "netbird.cloud", Records: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}}}})
dnsServer.updateSerial = 0
nameServers := []nbdns.NameServer{
@@ -2013,7 +2012,7 @@ func TestLocalResolverPriorityInServer(t *testing.T) {
},
}
localMuxUpdates, _, _, err := server.buildLocalHandlerUpdate(config.CustomZones)
localMuxUpdates, _, err := server.buildLocalHandlerUpdate(config.CustomZones)
assert.NoError(t, err)
upstreamMuxUpdates, err := server.buildUpstreamHandlerUpdate(config.NameServerGroups)
@@ -2074,7 +2073,7 @@ func TestLocalResolverPriorityConstants(t *testing.T) {
},
}
localMuxUpdates, _, _, err := server.buildLocalHandlerUpdate(config.CustomZones)
localMuxUpdates, _, err := server.buildLocalHandlerUpdate(config.CustomZones)
assert.NoError(t, err)
assert.Len(t, localMuxUpdates, 1)
assert.Equal(t, PriorityLocal, localMuxUpdates[0].priority, "Local handler should use PriorityLocal")

View File

@@ -202,6 +202,10 @@ func (u *upstreamResolverBase) writeSuccessResponse(w dns.ResponseWriter, rm *dn
resutil.SetMeta(w, "upstream", upstream.String())
// Clear Zero bit from external responses to prevent upstream servers from
// manipulating our internal fallthrough signaling mechanism
rm.MsgHdr.Zero = false
if err := w.WriteMsg(rm); err != nil {
logger.Errorf("failed to write DNS response for question domain=%s: %s", domain, err)
return true

View File

@@ -1251,11 +1251,16 @@ func toDNSConfig(protoDNSConfig *mgmProto.DNSConfig, network netip.Prefix) nbdns
ForwarderPort: forwarderPort,
}
for _, zone := range protoDNSConfig.GetCustomZones() {
protoZones := protoDNSConfig.GetCustomZones()
// Treat single zone as authoritative for backward compatibility with old servers
// that only send the peer FQDN zone without setting field 4.
singleZoneCompat := len(protoZones) == 1
for _, zone := range protoZones {
dnsZone := nbdns.CustomZone{
Domain: zone.GetDomain(),
SearchDomainDisabled: zone.GetSearchDomainDisabled(),
SkipPTRProcess: zone.GetSkipPTRProcess(),
NonAuthoritative: zone.GetNonAuthoritative() && !singleZoneCompat,
}
for _, record := range zone.Records {
dnsRecord := nbdns.SimpleRecord{

View File

@@ -325,6 +325,10 @@ func (d *DnsInterceptor) writeMsg(w dns.ResponseWriter, r *dns.Msg, logger *log.
return fmt.Errorf("received nil DNS message")
}
// Clear Zero bit from peer responses to prevent external sources from
// manipulating our internal fallthrough signaling mechanism
r.MsgHdr.Zero = false
if len(r.Answer) > 0 && len(r.Question) > 0 {
origPattern := ""
if writer, ok := w.(*nbdns.ResponseWriterChain); ok {

View File

@@ -47,8 +47,8 @@ type CustomZone struct {
Records []SimpleRecord
// SearchDomainDisabled indicates whether to add match domains to a search domains list or not
SearchDomainDisabled bool
// SkipPTRProcess indicates whether a client should process PTR records from custom zones
SkipPTRProcess bool
// NonAuthoritative marks user-created zones
NonAuthoritative bool
}
// SimpleRecord provides a simple DNS record specification for CNAME, A and AAAA records

View File

@@ -374,8 +374,9 @@ func shouldUsePortRange(rule *proto.FirewallRule) bool {
// Helper function to convert nbdns.CustomZone to proto.CustomZone
func convertToProtoCustomZone(zone nbdns.CustomZone) *proto.CustomZone {
protoZone := &proto.CustomZone{
Domain: zone.Domain,
Records: make([]*proto.SimpleRecord, 0, len(zone.Records)),
Domain: zone.Domain,
Records: make([]*proto.SimpleRecord, 0, len(zone.Records)),
NonAuthoritative: zone.NonAuthoritative,
}
for _, record := range zone.Records {
protoZone.Records = append(protoZone.Records, &proto.SimpleRecord{

View File

@@ -2873,7 +2873,7 @@ type CustomZone struct {
Domain string `protobuf:"bytes,1,opt,name=Domain,proto3" json:"Domain,omitempty"`
Records []*SimpleRecord `protobuf:"bytes,2,rep,name=Records,proto3" json:"Records,omitempty"`
SearchDomainDisabled bool `protobuf:"varint,3,opt,name=SearchDomainDisabled,proto3" json:"SearchDomainDisabled,omitempty"`
SkipPTRProcess bool `protobuf:"varint,4,opt,name=SkipPTRProcess,proto3" json:"SkipPTRProcess,omitempty"`
NonAuthoritative bool `protobuf:"varint,4,opt,name=NonAuthoritative,proto3" json:"NonAuthoritative,omitempty"`
}
func (x *CustomZone) Reset() {
@@ -2929,9 +2929,9 @@ func (x *CustomZone) GetSearchDomainDisabled() bool {
return false
}
func (x *CustomZone) GetSkipPTRProcess() bool {
func (x *CustomZone) GetNonAuthoritative() bool {
if x != nil {
return x.SkipPTRProcess
return x.NonAuthoritative
}
return false
}

View File

@@ -464,7 +464,9 @@ message CustomZone {
string Domain = 1;
repeated SimpleRecord Records = 2;
bool SearchDomainDisabled = 3;
bool SkipPTRProcess = 4;
// NonAuthoritative indicates this is a user-created zone (not the built-in peer DNS zone).
// Non-authoritative zones will fallthrough to lower-priority handlers on NXDOMAIN and skip PTR processing.
bool NonAuthoritative = 4;
}
// SimpleRecord represents a dns.SimpleRecord