[client] Fix dns forwarder handling of requested record types (#3615)

This commit is contained in:
Viktor Liu
2025-04-03 13:58:36 +02:00
committed by GitHub
parent 09243a0fe0
commit 80702b9323

View File

@@ -4,6 +4,8 @@ import (
"context"
"errors"
"net"
"net/netip"
"time"
"github.com/miekg/dns"
log "github.com/sirupsen/logrus"
@@ -12,6 +14,7 @@ import (
)
const errResolveFailed = "failed to resolve query for domain=%s: %v"
const upstreamTimeout = 15 * time.Second
type DNSForwarder struct {
listenAddress string
@@ -79,41 +82,72 @@ func (f *DNSForwarder) handleDNSQuery(w dns.ResponseWriter, query *dns.Msg) {
domain := question.Name
resp := query.SetReply(query)
var network string
switch question.Qtype {
case dns.TypeA:
network = "ip4"
case dns.TypeAAAA:
network = "ip6"
default:
// TODO: Handle other types
ips, err := net.LookupIP(domain)
if err != nil {
var dnsErr *net.DNSError
switch {
case errors.As(err, &dnsErr):
resp.Rcode = dns.RcodeServerFailure
if dnsErr.IsNotFound {
// Pass through NXDOMAIN
resp.Rcode = dns.RcodeNameError
}
if dnsErr.Server != "" {
log.Warnf("failed to resolve query for domain=%s server=%s: %v", domain, dnsErr.Server, err)
} else {
log.Warnf(errResolveFailed, domain, err)
}
default:
resp.Rcode = dns.RcodeServerFailure
log.Warnf(errResolveFailed, domain, err)
}
resp.Rcode = dns.RcodeNotImplemented
if err := w.WriteMsg(resp); err != nil {
log.Errorf("failed to write failure DNS response: %v", err)
log.Errorf("failed to write DNS response: %v", err)
}
return
}
ctx, cancel := context.WithTimeout(context.Background(), upstreamTimeout)
defer cancel()
ips, err := net.DefaultResolver.LookupNetIP(ctx, network, domain)
if err != nil {
f.handleDNSError(w, resp, domain, err)
return
}
f.addIPsToResponse(resp, domain, ips)
if err := w.WriteMsg(resp); err != nil {
log.Errorf("failed to write DNS response: %v", err)
}
}
// handleDNSError processes DNS lookup errors and sends an appropriate error response
func (f *DNSForwarder) handleDNSError(w dns.ResponseWriter, resp *dns.Msg, domain string, err error) {
var dnsErr *net.DNSError
switch {
case errors.As(err, &dnsErr):
resp.Rcode = dns.RcodeServerFailure
if dnsErr.IsNotFound {
// Pass through NXDOMAIN
resp.Rcode = dns.RcodeNameError
}
if dnsErr.Server != "" {
log.Warnf("failed to resolve query for domain=%s server=%s: %v", domain, dnsErr.Server, err)
} else {
log.Warnf(errResolveFailed, domain, err)
}
default:
resp.Rcode = dns.RcodeServerFailure
log.Warnf(errResolveFailed, domain, err)
}
if err := w.WriteMsg(resp); err != nil {
log.Errorf("failed to write failure DNS response: %v", err)
}
}
// addIPsToResponse adds IP addresses to the DNS response as appropriate A or AAAA records
func (f *DNSForwarder) addIPsToResponse(resp *dns.Msg, domain string, ips []netip.Addr) {
for _, ip := range ips {
var respRecord dns.RR
if ip.To4() == nil {
if ip.Is6() {
log.Tracef("resolved domain=%s to IPv6=%s", domain, ip)
rr := dns.AAAA{
AAAA: ip,
AAAA: ip.AsSlice(),
Hdr: dns.RR_Header{
Name: domain,
Rrtype: dns.TypeAAAA,
@@ -125,7 +159,7 @@ func (f *DNSForwarder) handleDNSQuery(w dns.ResponseWriter, query *dns.Msg) {
} else {
log.Tracef("resolved domain=%s to IPv4=%s", domain, ip)
rr := dns.A{
A: ip,
A: ip.AsSlice(),
Hdr: dns.RR_Header{
Name: domain,
Rrtype: dns.TypeA,
@@ -137,10 +171,6 @@ func (f *DNSForwarder) handleDNSQuery(w dns.ResponseWriter, query *dns.Msg) {
}
resp.Answer = append(resp.Answer, respRecord)
}
if err := w.WriteMsg(resp); err != nil {
log.Errorf("failed to write DNS response: %v", err)
}
}
// filterDomains returns a list of normalized domains