fix(update): fetch IPv6 AAAA records and not only IPv4

This commit is contained in:
Quentin McGaw
2024-11-21 22:25:14 +00:00
parent 75191c2876
commit e95816ab46

View File

@@ -3,8 +3,10 @@ package update
import (
"context"
"fmt"
"net"
"net/netip"
"strconv"
"strings"
"time"
"github.com/qdm12/ddns-updater/internal/constants"
@@ -34,7 +36,8 @@ type Service struct {
func NewService(db Database, updater UpdaterInterface, ipGetter PublicIPFetcher,
period time.Duration, cooldown time.Duration, logger Logger, resolver LookupIPer,
timeNow func() time.Time, hioClient HealthchecksIOClient) *Service {
timeNow func() time.Time, hioClient HealthchecksIOClient,
) *Service {
return &Service{
period: period,
db: db,
@@ -50,37 +53,59 @@ func NewService(db Database, updater UpdaterInterface, ipGetter PublicIPFetcher,
}
}
func (s *Service) lookupIPsResilient(ctx context.Context, hostname string, tries int) (
ipv4, ipv6 []netip.Addr, err error) {
for i := 0; i < tries; i++ {
ipv4, ipv6, err = s.lookupIPs(ctx, hostname)
if err == nil {
return ipv4, ipv6, nil
func (s *Service) lookupIPsResilient(ctx context.Context, hostname string, tries uint) (
ipv4, ipv6 []netip.Addr, err error,
) {
type result struct {
network string
ips []net.IP
err error
}
results := make(chan result)
networks := []string{"ip4", "ip6"}
lookupCtx, cancel := context.WithCancel(ctx)
for _, network := range networks {
go func(ctx context.Context, network string, results chan<- result) {
for range tries {
ips, err := s.resolver.LookupIP(ctx, network, hostname)
if err != nil {
if strings.HasSuffix(err.Error(), "no such host") {
results <- result{network: network} // no IP address for this network
return
}
continue // retry
}
results <- result{network: network, ips: ips, err: err}
return
}
}(lookupCtx, network, results)
}
for range networks {
result := <-results
if result.err != nil {
if err == nil {
cancel()
err = fmt.Errorf("looking up %s addresses: %w", result.network, result.err)
}
continue
}
switch result.network {
case "ip4":
ipv4 = make([]netip.Addr, len(result.ips))
for i, ip := range result.ips {
ipv4[i] = netip.AddrFrom4([4]byte(ip))
}
case "ip6":
ipv6 = make([]netip.Addr, len(result.ips))
for i, ip := range result.ips {
ipv6[i] = netip.AddrFrom16([16]byte(ip))
}
}
}
return nil, nil, err
}
cancel()
func (s *Service) lookupIPs(ctx context.Context, hostname string) (
ipv4, ipv6 []netip.Addr, err error) {
netIPs, err := s.resolver.LookupIP(ctx, "ip", hostname)
if err != nil {
return nil, nil, err
}
ipv4 = make([]netip.Addr, 0, len(netIPs))
ipv6 = make([]netip.Addr, 0, len(netIPs))
for _, netIP := range netIPs {
switch {
case netIP == nil:
case netIP.To4() != nil:
ipv4 = append(ipv4, netip.AddrFrom4([4]byte(netIP.To4())))
default: // IPv6
ipv6 = append(ipv6, netip.AddrFrom16([16]byte(netIP.To16())))
}
}
return ipv4, ipv6, nil
return ipv4, ipv6, err
}
func doIPVersion(records []librecords.Record) (doIP, doIPv4, doIPv6 bool) {
@@ -101,7 +126,8 @@ func doIPVersion(records []librecords.Record) (doIP, doIPv4, doIPv6 bool) {
}
func (s *Service) getNewIPs(ctx context.Context, doIP, doIPv4, doIPv6 bool) (
ip, ipv4, ipv6 netip.Addr, errors []error) {
ip, ipv4, ipv6 netip.Addr, errors []error,
) {
var err error
if doIP {
ip, err = tryAndRepeatGettingIP(ctx, s.ipGetter.IP, s.logger, ipversion.IP4or6)
@@ -125,7 +151,8 @@ func (s *Service) getNewIPs(ctx context.Context, doIP, doIPv4, doIPv6 bool) (
}
func (s *Service) getRecordIDsToUpdate(ctx context.Context, records []librecords.Record,
ip, ipv4, ipv6 netip.Addr) (recordIDs map[uint]struct{}) {
ip, ipv4, ipv6 netip.Addr,
) (recordIDs map[uint]struct{}) {
recordIDs = make(map[uint]struct{})
for i, record := range records {
shouldUpdate := s.shouldUpdateRecord(ctx, record, ip, ipv4, ipv6)
@@ -138,7 +165,8 @@ func (s *Service) getRecordIDsToUpdate(ctx context.Context, records []librecords
}
func (s *Service) shouldUpdateRecord(ctx context.Context, record librecords.Record,
ip, ipv4, ipv6 netip.Addr) (update bool) {
ip, ipv4, ipv6 netip.Addr,
) (update bool) {
now := s.timeNow()
isWithinCooldown := now.Sub(record.History.GetSuccessTime()) < s.cooldown
@@ -178,7 +206,8 @@ func (s *Service) shouldUpdateRecord(ctx context.Context, record librecords.Reco
}
func (s *Service) shouldUpdateRecordNoLookup(hostname string, ipVersion ipversion.IPVersion,
lastIP, publicIP netip.Addr) (update bool) {
lastIP, publicIP netip.Addr,
) (update bool) {
ipKind := ipVersionToIPKind(ipVersion)
if publicIP.IsValid() && publicIP.Compare(lastIP) != 0 {
s.logInfoNoLookupUpdate(hostname, ipKind, lastIP, publicIP)
@@ -189,7 +218,8 @@ func (s *Service) shouldUpdateRecordNoLookup(hostname string, ipVersion ipversio
}
func (s *Service) shouldUpdateRecordWithLookup(ctx context.Context, hostname string,
ipVersion ipversion.IPVersion, publicIP netip.Addr) (update bool) {
ipVersion ipversion.IPVersion, publicIP netip.Addr,
) (update bool) {
const tries = 5
recordIPv4s, recordIPv6s, err := s.lookupIPsResilient(ctx, hostname, tries)
if err != nil {
@@ -375,7 +405,8 @@ func (s *Service) Start(ctx context.Context) (runError <-chan error, startErr er
}
func (s *Service) run(ctx context.Context, ready chan<- struct{},
done chan<- struct{}) {
done chan<- struct{},
) {
defer close(done)
ticker := time.NewTicker(s.period)
close(ready)