fix(update): fetch IPv6 AAAA records and not only IPv4

This commit is contained in:
Quentin McGaw
2024-11-21 22:25:14 +00:00
parent 75191c2876
commit e95816ab46

View File

@@ -3,8 +3,10 @@ package update
import ( import (
"context" "context"
"fmt" "fmt"
"net"
"net/netip" "net/netip"
"strconv" "strconv"
"strings"
"time" "time"
"github.com/qdm12/ddns-updater/internal/constants" "github.com/qdm12/ddns-updater/internal/constants"
@@ -34,7 +36,8 @@ type Service struct {
func NewService(db Database, updater UpdaterInterface, ipGetter PublicIPFetcher, func NewService(db Database, updater UpdaterInterface, ipGetter PublicIPFetcher,
period time.Duration, cooldown time.Duration, logger Logger, resolver LookupIPer, period time.Duration, cooldown time.Duration, logger Logger, resolver LookupIPer,
timeNow func() time.Time, hioClient HealthchecksIOClient) *Service { timeNow func() time.Time, hioClient HealthchecksIOClient,
) *Service {
return &Service{ return &Service{
period: period, period: period,
db: db, db: db,
@@ -50,37 +53,59 @@ func NewService(db Database, updater UpdaterInterface, ipGetter PublicIPFetcher,
} }
} }
func (s *Service) lookupIPsResilient(ctx context.Context, hostname string, tries int) ( func (s *Service) lookupIPsResilient(ctx context.Context, hostname string, tries uint) (
ipv4, ipv6 []netip.Addr, err error) { ipv4, ipv6 []netip.Addr, err error,
for i := 0; i < tries; i++ { ) {
ipv4, ipv6, err = s.lookupIPs(ctx, hostname) type result struct {
if err == nil { network string
return ipv4, ipv6, nil ips []net.IP
err error
}
results := make(chan result)
networks := []string{"ip4", "ip6"}
lookupCtx, cancel := context.WithCancel(ctx)
for _, network := range networks {
go func(ctx context.Context, network string, results chan<- result) {
for range tries {
ips, err := s.resolver.LookupIP(ctx, network, hostname)
if err != nil {
if strings.HasSuffix(err.Error(), "no such host") {
results <- result{network: network} // no IP address for this network
return
}
continue // retry
}
results <- result{network: network, ips: ips, err: err}
return
}
}(lookupCtx, network, results)
}
for range networks {
result := <-results
if result.err != nil {
if err == nil {
cancel()
err = fmt.Errorf("looking up %s addresses: %w", result.network, result.err)
}
continue
}
switch result.network {
case "ip4":
ipv4 = make([]netip.Addr, len(result.ips))
for i, ip := range result.ips {
ipv4[i] = netip.AddrFrom4([4]byte(ip))
}
case "ip6":
ipv6 = make([]netip.Addr, len(result.ips))
for i, ip := range result.ips {
ipv6[i] = netip.AddrFrom16([16]byte(ip))
}
} }
} }
return nil, nil, err cancel()
}
func (s *Service) lookupIPs(ctx context.Context, hostname string) ( return ipv4, ipv6, err
ipv4, ipv6 []netip.Addr, err error) {
netIPs, err := s.resolver.LookupIP(ctx, "ip", hostname)
if err != nil {
return nil, nil, err
}
ipv4 = make([]netip.Addr, 0, len(netIPs))
ipv6 = make([]netip.Addr, 0, len(netIPs))
for _, netIP := range netIPs {
switch {
case netIP == nil:
case netIP.To4() != nil:
ipv4 = append(ipv4, netip.AddrFrom4([4]byte(netIP.To4())))
default: // IPv6
ipv6 = append(ipv6, netip.AddrFrom16([16]byte(netIP.To16())))
}
}
return ipv4, ipv6, nil
} }
func doIPVersion(records []librecords.Record) (doIP, doIPv4, doIPv6 bool) { func doIPVersion(records []librecords.Record) (doIP, doIPv4, doIPv6 bool) {
@@ -101,7 +126,8 @@ func doIPVersion(records []librecords.Record) (doIP, doIPv4, doIPv6 bool) {
} }
func (s *Service) getNewIPs(ctx context.Context, doIP, doIPv4, doIPv6 bool) ( func (s *Service) getNewIPs(ctx context.Context, doIP, doIPv4, doIPv6 bool) (
ip, ipv4, ipv6 netip.Addr, errors []error) { ip, ipv4, ipv6 netip.Addr, errors []error,
) {
var err error var err error
if doIP { if doIP {
ip, err = tryAndRepeatGettingIP(ctx, s.ipGetter.IP, s.logger, ipversion.IP4or6) ip, err = tryAndRepeatGettingIP(ctx, s.ipGetter.IP, s.logger, ipversion.IP4or6)
@@ -125,7 +151,8 @@ func (s *Service) getNewIPs(ctx context.Context, doIP, doIPv4, doIPv6 bool) (
} }
func (s *Service) getRecordIDsToUpdate(ctx context.Context, records []librecords.Record, func (s *Service) getRecordIDsToUpdate(ctx context.Context, records []librecords.Record,
ip, ipv4, ipv6 netip.Addr) (recordIDs map[uint]struct{}) { ip, ipv4, ipv6 netip.Addr,
) (recordIDs map[uint]struct{}) {
recordIDs = make(map[uint]struct{}) recordIDs = make(map[uint]struct{})
for i, record := range records { for i, record := range records {
shouldUpdate := s.shouldUpdateRecord(ctx, record, ip, ipv4, ipv6) shouldUpdate := s.shouldUpdateRecord(ctx, record, ip, ipv4, ipv6)
@@ -138,7 +165,8 @@ func (s *Service) getRecordIDsToUpdate(ctx context.Context, records []librecords
} }
func (s *Service) shouldUpdateRecord(ctx context.Context, record librecords.Record, func (s *Service) shouldUpdateRecord(ctx context.Context, record librecords.Record,
ip, ipv4, ipv6 netip.Addr) (update bool) { ip, ipv4, ipv6 netip.Addr,
) (update bool) {
now := s.timeNow() now := s.timeNow()
isWithinCooldown := now.Sub(record.History.GetSuccessTime()) < s.cooldown isWithinCooldown := now.Sub(record.History.GetSuccessTime()) < s.cooldown
@@ -178,7 +206,8 @@ func (s *Service) shouldUpdateRecord(ctx context.Context, record librecords.Reco
} }
func (s *Service) shouldUpdateRecordNoLookup(hostname string, ipVersion ipversion.IPVersion, func (s *Service) shouldUpdateRecordNoLookup(hostname string, ipVersion ipversion.IPVersion,
lastIP, publicIP netip.Addr) (update bool) { lastIP, publicIP netip.Addr,
) (update bool) {
ipKind := ipVersionToIPKind(ipVersion) ipKind := ipVersionToIPKind(ipVersion)
if publicIP.IsValid() && publicIP.Compare(lastIP) != 0 { if publicIP.IsValid() && publicIP.Compare(lastIP) != 0 {
s.logInfoNoLookupUpdate(hostname, ipKind, lastIP, publicIP) s.logInfoNoLookupUpdate(hostname, ipKind, lastIP, publicIP)
@@ -189,7 +218,8 @@ func (s *Service) shouldUpdateRecordNoLookup(hostname string, ipVersion ipversio
} }
func (s *Service) shouldUpdateRecordWithLookup(ctx context.Context, hostname string, func (s *Service) shouldUpdateRecordWithLookup(ctx context.Context, hostname string,
ipVersion ipversion.IPVersion, publicIP netip.Addr) (update bool) { ipVersion ipversion.IPVersion, publicIP netip.Addr,
) (update bool) {
const tries = 5 const tries = 5
recordIPv4s, recordIPv6s, err := s.lookupIPsResilient(ctx, hostname, tries) recordIPv4s, recordIPv6s, err := s.lookupIPsResilient(ctx, hostname, tries)
if err != nil { if err != nil {
@@ -375,7 +405,8 @@ func (s *Service) Start(ctx context.Context) (runError <-chan error, startErr er
} }
func (s *Service) run(ctx context.Context, ready chan<- struct{}, func (s *Service) run(ctx context.Context, ready chan<- struct{},
done chan<- struct{}) { done chan<- struct{},
) {
defer close(done) defer close(done)
ticker := time.NewTicker(s.period) ticker := time.NewTicker(s.period)
close(ready) close(ready)