mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-05 09:03:54 -04:00
487 lines
13 KiB
Go
487 lines
13 KiB
Go
package dnsfwd
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"math"
|
|
"net"
|
|
"net/netip"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/hashicorp/go-multierror"
|
|
"github.com/miekg/dns"
|
|
log "github.com/sirupsen/logrus"
|
|
"golang.zx2c4.com/wireguard/tun/netstack"
|
|
|
|
nberrors "github.com/netbirdio/netbird/client/errors"
|
|
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
|
"github.com/netbirdio/netbird/client/internal/peer"
|
|
"github.com/netbirdio/netbird/route"
|
|
)
|
|
|
|
const (
	// errResolveFailed is the log format used when an upstream lookup fails;
	// it expects the queried domain followed by the error.
	errResolveFailed = "failed to resolve query for domain=%s: %v"

	// upstreamTimeout bounds each upstream resolution started for a query.
	upstreamTimeout = 15 * time.Second
)
|
|
|
|
// resolver abstracts the upstream name lookup performed by the forwarder.
// net.DefaultResolver satisfies it and is what NewDNSForwarder installs; the
// interface exists so another implementation can be substituted.
type resolver interface {
	// LookupNetIP resolves host to addresses of the requested family. This
	// file calls it with network set to "ip4" or "ip6".
	LookupNetIP(ctx context.Context, network string, host string) ([]netip.Addr, error)
}
|
|
|
|
type firewaller interface {
|
|
UpdateSet(set firewall.Set, prefixes []netip.Prefix) error
|
|
}
|
|
|
|
type DNSForwarder struct {
|
|
listenAddress netip.AddrPort
|
|
ttl uint32
|
|
statusRecorder *peer.Status
|
|
|
|
dnsServer *dns.Server
|
|
mux *dns.ServeMux
|
|
tcpServer *dns.Server
|
|
tcpMux *dns.ServeMux
|
|
|
|
mutex sync.RWMutex
|
|
fwdEntries []*ForwarderEntry
|
|
firewall firewaller
|
|
resolver resolver
|
|
cache *cache
|
|
|
|
wgIface wgIface
|
|
}
|
|
|
|
func NewDNSForwarder(listenAddress netip.AddrPort, ttl uint32, firewall firewaller, statusRecorder *peer.Status, wgIface wgIface) *DNSForwarder {
|
|
log.Debugf("creating DNS forwarder with listen_address=%s ttl=%d", listenAddress, ttl)
|
|
return &DNSForwarder{
|
|
listenAddress: listenAddress,
|
|
ttl: ttl,
|
|
firewall: firewall,
|
|
statusRecorder: statusRecorder,
|
|
resolver: net.DefaultResolver,
|
|
cache: newCache(),
|
|
wgIface: wgIface,
|
|
}
|
|
}
|
|
|
|
func (f *DNSForwarder) Listen(entries []*ForwarderEntry) error {
|
|
var netstackNet *netstack.Net
|
|
if f.wgIface != nil {
|
|
netstackNet = f.wgIface.GetNet()
|
|
}
|
|
|
|
addrDesc := f.listenAddress.String()
|
|
if netstackNet != nil {
|
|
addrDesc = fmt.Sprintf("netstack %s", f.listenAddress)
|
|
}
|
|
log.Infof("starting DNS forwarder on address=%s", addrDesc)
|
|
|
|
udpLn, err := f.createUDPListener(netstackNet)
|
|
if err != nil {
|
|
return fmt.Errorf("create UDP listener: %w", err)
|
|
}
|
|
|
|
tcpLn, err := f.createTCPListener(netstackNet)
|
|
if err != nil {
|
|
return fmt.Errorf("create TCP listener: %w", err)
|
|
}
|
|
|
|
mux := dns.NewServeMux()
|
|
f.mux = mux
|
|
mux.HandleFunc(".", f.handleDNSQueryUDP)
|
|
f.dnsServer = &dns.Server{
|
|
PacketConn: udpLn,
|
|
Handler: mux,
|
|
}
|
|
|
|
tcpMux := dns.NewServeMux()
|
|
f.tcpMux = tcpMux
|
|
tcpMux.HandleFunc(".", f.handleDNSQueryTCP)
|
|
f.tcpServer = &dns.Server{
|
|
Listener: tcpLn,
|
|
Handler: tcpMux,
|
|
}
|
|
|
|
f.UpdateDomains(entries)
|
|
|
|
errCh := make(chan error, 2)
|
|
|
|
go func() {
|
|
log.Infof("DNS UDP listener running on %s", addrDesc)
|
|
errCh <- f.dnsServer.ActivateAndServe()
|
|
}()
|
|
go func() {
|
|
log.Infof("DNS TCP listener running on %s", addrDesc)
|
|
errCh <- f.tcpServer.ActivateAndServe()
|
|
}()
|
|
|
|
return <-errCh
|
|
}
|
|
|
|
func (f *DNSForwarder) createUDPListener(netstackNet *netstack.Net) (net.PacketConn, error) {
|
|
if netstackNet != nil {
|
|
return netstackNet.ListenUDPAddrPort(f.listenAddress)
|
|
}
|
|
|
|
return net.ListenUDP("udp", net.UDPAddrFromAddrPort(f.listenAddress))
|
|
}
|
|
|
|
func (f *DNSForwarder) createTCPListener(netstackNet *netstack.Net) (net.Listener, error) {
|
|
if netstackNet != nil {
|
|
return netstackNet.ListenTCPAddrPort(f.listenAddress)
|
|
}
|
|
|
|
return net.ListenTCP("tcp", net.TCPAddrFromAddrPort(f.listenAddress))
|
|
}
|
|
|
|
func (f *DNSForwarder) UpdateDomains(entries []*ForwarderEntry) {
|
|
f.mutex.Lock()
|
|
defer f.mutex.Unlock()
|
|
|
|
// remove cache entries for domains that no longer appear
|
|
f.removeStaleCacheEntries(f.fwdEntries, entries)
|
|
|
|
f.fwdEntries = entries
|
|
log.Debugf("Updated DNS forwarder with %d domains", len(entries))
|
|
}
|
|
|
|
// removeStaleCacheEntries unsets cache items for domains that were present
|
|
// in the old list but not present in the new list.
|
|
func (f *DNSForwarder) removeStaleCacheEntries(oldEntries, newEntries []*ForwarderEntry) {
|
|
if f.cache == nil {
|
|
return
|
|
}
|
|
|
|
newSet := make(map[string]struct{}, len(newEntries))
|
|
for _, e := range newEntries {
|
|
if e == nil {
|
|
continue
|
|
}
|
|
newSet[e.Domain.PunycodeString()] = struct{}{}
|
|
}
|
|
|
|
for _, e := range oldEntries {
|
|
if e == nil {
|
|
continue
|
|
}
|
|
pattern := e.Domain.PunycodeString()
|
|
if _, ok := newSet[pattern]; !ok {
|
|
f.cache.unset(pattern)
|
|
}
|
|
}
|
|
}
|
|
|
|
func (f *DNSForwarder) Close(ctx context.Context) error {
|
|
var result *multierror.Error
|
|
|
|
if f.dnsServer != nil {
|
|
if err := f.dnsServer.ShutdownContext(ctx); err != nil {
|
|
result = multierror.Append(result, fmt.Errorf("UDP shutdown: %w", err))
|
|
}
|
|
}
|
|
if f.tcpServer != nil {
|
|
if err := f.tcpServer.ShutdownContext(ctx); err != nil {
|
|
result = multierror.Append(result, fmt.Errorf("TCP shutdown: %w", err))
|
|
}
|
|
}
|
|
|
|
return nberrors.FormatErrorOrNil(result)
|
|
}
|
|
|
|
func (f *DNSForwarder) handleDNSQuery(w dns.ResponseWriter, query *dns.Msg) *dns.Msg {
|
|
if len(query.Question) == 0 {
|
|
return nil
|
|
}
|
|
question := query.Question[0]
|
|
log.Tracef("received DNS request for DNS forwarder: domain=%v type=%v class=%v",
|
|
question.Name, question.Qtype, question.Qclass)
|
|
|
|
domain := strings.ToLower(question.Name)
|
|
|
|
resp := query.SetReply(query)
|
|
var network string
|
|
switch question.Qtype {
|
|
case dns.TypeA:
|
|
network = "ip4"
|
|
case dns.TypeAAAA:
|
|
network = "ip6"
|
|
default:
|
|
// TODO: Handle other types
|
|
|
|
resp.Rcode = dns.RcodeNotImplemented
|
|
if err := w.WriteMsg(resp); err != nil {
|
|
log.Errorf("failed to write DNS response: %v", err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
mostSpecificResId, matchingEntries := f.getMatchingEntries(strings.TrimSuffix(domain, "."))
|
|
// query doesn't match any configured domain
|
|
if mostSpecificResId == "" {
|
|
resp.Rcode = dns.RcodeRefused
|
|
if err := w.WriteMsg(resp); err != nil {
|
|
log.Errorf("failed to write DNS response: %v", err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), upstreamTimeout)
|
|
defer cancel()
|
|
ips, err := f.resolver.LookupNetIP(ctx, network, domain)
|
|
if err != nil {
|
|
f.handleDNSError(ctx, w, question, resp, domain, err)
|
|
return nil
|
|
}
|
|
|
|
f.updateInternalState(ips, mostSpecificResId, matchingEntries)
|
|
f.addIPsToResponse(resp, domain, ips)
|
|
f.cache.set(domain, question.Qtype, ips)
|
|
|
|
return resp
|
|
}
|
|
|
|
func (f *DNSForwarder) handleDNSQueryUDP(w dns.ResponseWriter, query *dns.Msg) {
|
|
resp := f.handleDNSQuery(w, query)
|
|
if resp == nil {
|
|
return
|
|
}
|
|
|
|
opt := query.IsEdns0()
|
|
maxSize := dns.MinMsgSize
|
|
if opt != nil {
|
|
// client advertised a larger EDNS0 buffer
|
|
maxSize = int(opt.UDPSize())
|
|
}
|
|
|
|
// if our response is too big, truncate and set the TC bit
|
|
if resp.Len() > maxSize {
|
|
resp.Truncate(maxSize)
|
|
}
|
|
|
|
if err := w.WriteMsg(resp); err != nil {
|
|
log.Errorf("failed to write DNS response: %v", err)
|
|
}
|
|
}
|
|
|
|
func (f *DNSForwarder) handleDNSQueryTCP(w dns.ResponseWriter, query *dns.Msg) {
|
|
resp := f.handleDNSQuery(w, query)
|
|
if resp == nil {
|
|
return
|
|
}
|
|
|
|
if err := w.WriteMsg(resp); err != nil {
|
|
log.Errorf("failed to write DNS response: %v", err)
|
|
}
|
|
}
|
|
|
|
func (f *DNSForwarder) updateInternalState(ips []netip.Addr, mostSpecificResId route.ResID, matchingEntries []*ForwarderEntry) {
|
|
var prefixes []netip.Prefix
|
|
if mostSpecificResId != "" {
|
|
for _, ip := range ips {
|
|
var prefix netip.Prefix
|
|
if ip.Is4() {
|
|
prefix = netip.PrefixFrom(ip, 32)
|
|
} else {
|
|
prefix = netip.PrefixFrom(ip, 128)
|
|
}
|
|
prefixes = append(prefixes, prefix)
|
|
f.statusRecorder.AddResolvedIPLookupEntry(prefix, mostSpecificResId)
|
|
}
|
|
}
|
|
|
|
if f.firewall != nil {
|
|
f.updateFirewall(matchingEntries, prefixes)
|
|
}
|
|
}
|
|
|
|
func (f *DNSForwarder) updateFirewall(matchingEntries []*ForwarderEntry, prefixes []netip.Prefix) {
|
|
var merr *multierror.Error
|
|
for _, entry := range matchingEntries {
|
|
if err := f.firewall.UpdateSet(entry.Set, prefixes); err != nil {
|
|
merr = multierror.Append(merr, fmt.Errorf("update set for domain=%s: %w", entry.Domain, err))
|
|
}
|
|
}
|
|
if merr != nil {
|
|
log.Errorf("failed to update firewall sets (%d/%d): %v",
|
|
len(merr.Errors),
|
|
len(matchingEntries),
|
|
nberrors.FormatErrorOrNil(merr))
|
|
}
|
|
}
|
|
|
|
// setResponseCodeForNotFound determines and sets the appropriate response code when IsNotFound is true
|
|
// It distinguishes between NXDOMAIN (domain doesn't exist) and NODATA (domain exists but no records of requested type)
|
|
//
|
|
// LIMITATION: This function only checks A and AAAA record types to determine domain existence.
|
|
// If a domain has only other record types (MX, TXT, CNAME, etc.) but no A/AAAA records,
|
|
// it may incorrectly return NXDOMAIN instead of NODATA. This is acceptable since the forwarder
|
|
// only handles A/AAAA queries and returns NOTIMP for other types.
|
|
func (f *DNSForwarder) setResponseCodeForNotFound(ctx context.Context, resp *dns.Msg, domain string, originalQtype uint16) {
|
|
// Try querying for a different record type to see if the domain exists
|
|
// If the original query was for AAAA, try A. If it was for A, try AAAA.
|
|
// This helps distinguish between NXDOMAIN and NODATA.
|
|
var alternativeNetwork string
|
|
switch originalQtype {
|
|
case dns.TypeAAAA:
|
|
alternativeNetwork = "ip4"
|
|
case dns.TypeA:
|
|
alternativeNetwork = "ip6"
|
|
default:
|
|
resp.Rcode = dns.RcodeNameError
|
|
return
|
|
}
|
|
|
|
if _, err := f.resolver.LookupNetIP(ctx, alternativeNetwork, domain); err != nil {
|
|
var dnsErr *net.DNSError
|
|
if errors.As(err, &dnsErr) && dnsErr.IsNotFound {
|
|
// Alternative query also returned not found - domain truly doesn't exist
|
|
resp.Rcode = dns.RcodeNameError
|
|
return
|
|
}
|
|
// Some other error (timeout, server failure, etc.) - can't determine, assume domain exists
|
|
resp.Rcode = dns.RcodeSuccess
|
|
return
|
|
}
|
|
|
|
// Alternative query succeeded - domain exists but has no records of this type
|
|
resp.Rcode = dns.RcodeSuccess
|
|
}
|
|
|
|
// handleDNSError processes DNS lookup errors and sends an appropriate error response.
|
|
func (f *DNSForwarder) handleDNSError(
|
|
ctx context.Context,
|
|
w dns.ResponseWriter,
|
|
question dns.Question,
|
|
resp *dns.Msg,
|
|
domain string,
|
|
err error,
|
|
) {
|
|
// Default to SERVFAIL; override below when appropriate.
|
|
resp.Rcode = dns.RcodeServerFailure
|
|
|
|
qType := question.Qtype
|
|
qTypeName := dns.TypeToString[qType]
|
|
|
|
// Prefer typed DNS errors; fall back to generic logging otherwise.
|
|
var dnsErr *net.DNSError
|
|
if !errors.As(err, &dnsErr) {
|
|
log.Warnf(errResolveFailed, domain, err)
|
|
if writeErr := w.WriteMsg(resp); writeErr != nil {
|
|
log.Errorf("failed to write failure DNS response: %v", writeErr)
|
|
}
|
|
return
|
|
}
|
|
|
|
// NotFound: set NXDOMAIN / appropriate code via helper.
|
|
if dnsErr.IsNotFound {
|
|
f.setResponseCodeForNotFound(ctx, resp, domain, qType)
|
|
if writeErr := w.WriteMsg(resp); writeErr != nil {
|
|
log.Errorf("failed to write failure DNS response: %v", writeErr)
|
|
}
|
|
f.cache.set(domain, question.Qtype, nil)
|
|
return
|
|
}
|
|
|
|
// Upstream failed but we might have a cached answer—serve it if present.
|
|
if ips, ok := f.cache.get(domain, qType); ok {
|
|
if len(ips) > 0 {
|
|
log.Debugf("serving cached DNS response after upstream failure: domain=%s type=%s", domain, qTypeName)
|
|
f.addIPsToResponse(resp, domain, ips)
|
|
resp.Rcode = dns.RcodeSuccess
|
|
if writeErr := w.WriteMsg(resp); writeErr != nil {
|
|
log.Errorf("failed to write cached DNS response: %v", writeErr)
|
|
}
|
|
} else { // send NXDOMAIN / appropriate code if cache is empty
|
|
f.setResponseCodeForNotFound(ctx, resp, domain, qType)
|
|
if writeErr := w.WriteMsg(resp); writeErr != nil {
|
|
log.Errorf("failed to write failure DNS response: %v", writeErr)
|
|
}
|
|
}
|
|
return
|
|
}
|
|
|
|
// No cache. Log with or without the server field for more context.
|
|
if dnsErr.Server != "" {
|
|
log.Warnf("failed to resolve: type=%s domain=%s server=%s: %v", qTypeName, domain, dnsErr.Server, err)
|
|
} else {
|
|
log.Warnf(errResolveFailed, domain, err)
|
|
}
|
|
|
|
// Write final failure response.
|
|
if writeErr := w.WriteMsg(resp); writeErr != nil {
|
|
log.Errorf("failed to write failure DNS response: %v", writeErr)
|
|
}
|
|
}
|
|
|
|
// addIPsToResponse adds IP addresses to the DNS response as appropriate A or AAAA records
|
|
func (f *DNSForwarder) addIPsToResponse(resp *dns.Msg, domain string, ips []netip.Addr) {
|
|
for _, ip := range ips {
|
|
var respRecord dns.RR
|
|
if ip.Is6() {
|
|
log.Tracef("resolved domain=%s to IPv6=%s", domain, ip)
|
|
rr := dns.AAAA{
|
|
AAAA: ip.AsSlice(),
|
|
Hdr: dns.RR_Header{
|
|
Name: domain,
|
|
Rrtype: dns.TypeAAAA,
|
|
Class: dns.ClassINET,
|
|
Ttl: f.ttl,
|
|
},
|
|
}
|
|
respRecord = &rr
|
|
} else {
|
|
log.Tracef("resolved domain=%s to IPv4=%s", domain, ip)
|
|
rr := dns.A{
|
|
A: ip.AsSlice(),
|
|
Hdr: dns.RR_Header{
|
|
Name: domain,
|
|
Rrtype: dns.TypeA,
|
|
Class: dns.ClassINET,
|
|
Ttl: f.ttl,
|
|
},
|
|
}
|
|
respRecord = &rr
|
|
}
|
|
resp.Answer = append(resp.Answer, respRecord)
|
|
}
|
|
}
|
|
|
|
// getMatchingEntries retrieves the resource IDs for a given domain.
|
|
// It returns the most specific match and all matching resource IDs.
|
|
func (f *DNSForwarder) getMatchingEntries(domain string) (route.ResID, []*ForwarderEntry) {
|
|
var selectedResId route.ResID
|
|
var bestScore int
|
|
var matches []*ForwarderEntry
|
|
|
|
f.mutex.RLock()
|
|
defer f.mutex.RUnlock()
|
|
|
|
for _, entry := range f.fwdEntries {
|
|
var score int
|
|
pattern := entry.Domain.PunycodeString()
|
|
|
|
switch {
|
|
case strings.HasPrefix(pattern, "*."):
|
|
baseDomain := strings.TrimPrefix(pattern, "*.")
|
|
|
|
if strings.EqualFold(domain, baseDomain) || strings.HasSuffix(domain, "."+baseDomain) {
|
|
score = len(baseDomain)
|
|
matches = append(matches, entry)
|
|
}
|
|
case domain == pattern:
|
|
score = math.MaxInt
|
|
matches = append(matches, entry)
|
|
default:
|
|
continue
|
|
}
|
|
|
|
if score > bestScore {
|
|
bestScore = score
|
|
selectedResId = entry.ResID
|
|
}
|
|
}
|
|
|
|
return selectedResId, matches
|
|
}
|