fix(update): do not update if public IP is part of multiple IPs found in records

This commit is contained in:
Quentin McGaw
2024-11-21 22:02:09 +00:00
parent 6c3490f5c5
commit 9ba1633300
2 changed files with 54 additions and 28 deletions

View File

@@ -32,14 +32,22 @@ func (s *Service) logInfoNoLookupUpdate(hostname, ipKind string, lastIP, ip neti
ipKind, hostname, lastIP, ipKind, ip)) ipKind, hostname, lastIP, ipKind, ip))
} }
func (s *Service) logDebugLookupSkip(hostname, ipKind string, recordIP, ip netip.Addr) { func (s *Service) logDebugLookupSkip(hostname, ipKind string, recordIPs []netip.Addr, ip netip.Addr) {
s.logger.Debug(fmt.Sprintf("%s address of %s is %s and your %s address"+ s.logger.Debug(fmt.Sprintf("%s address of %s is %s and your %s address"+
" is %s, skipping update", ipKind, hostname, recordIP, ipKind, ip)) " is %s, skipping update", ipKind, hostname, ipsToString(recordIPs), ipKind, ip))
} }
func (s *Service) logInfoLookupUpdate(hostname, ipKind string, recordIP, ip netip.Addr) { func (s *Service) logInfoLookupUpdate(hostname, ipKind string, recordIPs []netip.Addr, ip netip.Addr) {
s.logger.Info(fmt.Sprintf("%s address of %s is %s and your %s address is %s", s.logger.Info(fmt.Sprintf("%s address of %s is %s and your %s address is %s",
ipKind, hostname, recordIP, ipKind, ip)) ipKind, hostname, ipsToString(recordIPs), ipKind, ip))
}
func ipsToString(ips []netip.Addr) string {
ipStrings := make([]string, len(ips))
for i, ip := range ips {
ipStrings[i] = ip.String()
}
return strings.Join(ipStrings, ", ")
} }
type joinedErrors struct { //nolint:errname type joinedErrors struct { //nolint:errname

View File

@@ -52,7 +52,7 @@ func NewService(db Database, updater UpdaterInterface, ipGetter PublicIPFetcher,
} }
func (s *Service) lookupIPsResilient(ctx context.Context, hostname string, tries int) ( func (s *Service) lookupIPsResilient(ctx context.Context, hostname string, tries int) (
ipv4 netip.Addr, ipv6 netip.Addr, err error, ipv4, ipv6 []netip.Addr, err error,
) { ) {
for range tries { for range tries {
ipv4, ipv6, err = s.lookupIPs(ctx, hostname) ipv4, ipv6, err = s.lookupIPs(ctx, hostname)
@@ -60,34 +60,29 @@ func (s *Service) lookupIPsResilient(ctx context.Context, hostname string, tries
return ipv4, ipv6, nil return ipv4, ipv6, nil
} }
} }
return netip.Addr{}, netip.Addr{}, err return nil, nil, err
} }
func (s *Service) lookupIPs(ctx context.Context, hostname string) ( func (s *Service) lookupIPs(ctx context.Context, hostname string) (
ipv4 netip.Addr, ipv6 netip.Addr, err error, ipv4, ipv6 []netip.Addr, err error,
) { ) {
netIPs, err := s.resolver.LookupIP(ctx, "ip", hostname) netIPs, err := s.resolver.LookupIP(ctx, "ip", hostname)
if err != nil { if err != nil {
return netip.Addr{}, netip.Addr{}, err return nil, nil, err
} }
ips := make([]netip.Addr, len(netIPs))
for i, netIP := range netIPs { ipv4 = make([]netip.Addr, 0, len(netIPs))
ipv6 = make([]netip.Addr, 0, len(netIPs))
for _, netIP := range netIPs {
switch { switch {
case netIP == nil: case netIP == nil:
case netIP.To4() != nil: case netIP.To4() != nil:
ips[i] = netip.AddrFrom4([4]byte(netIP.To4())) ipv4 = append(ipv4, netip.AddrFrom4([4]byte(netIP.To4())))
default: // IPv6 default: // IPv6
ips[i] = netip.AddrFrom16([16]byte(netIP.To16())) ipv6 = append(ipv6, netip.AddrFrom16([16]byte(netIP.To16())))
} }
} }
for _, ip := range ips {
if ip.Is6() {
ipv6 = ip
} else {
ipv4 = ip
}
}
return ipv4, ipv6, nil return ipv4, ipv6, nil
} }
@@ -204,7 +199,7 @@ func (s *Service) shouldUpdateRecordWithLookup(ctx context.Context, hostname str
ipVersion ipversion.IPVersion, publicIP netip.Addr, ipVersion ipversion.IPVersion, publicIP netip.Addr,
) (update bool) { ) (update bool) {
const tries = 5 const tries = 5
recordIPv4, recordIPv6, err := s.lookupIPsResilient(ctx, hostname, tries) recordIPv4s, recordIPv6s, err := s.lookupIPsResilient(ctx, hostname, tries)
if err != nil { if err != nil {
ctxErr := ctx.Err() ctxErr := ctx.Err()
if ctxErr != nil { if ctxErr != nil {
@@ -216,18 +211,27 @@ func (s *Service) shouldUpdateRecordWithLookup(ctx context.Context, hostname str
} }
ipKind := ipVersionToIPKind(ipVersion) ipKind := ipVersionToIPKind(ipVersion)
recordIP := recordIPv4 recordIPs := recordIPv4s
if publicIP.Is6() { if publicIP.Is6() {
recordIP = recordIPv6 recordIPs = recordIPv6s
} }
recordIP = getIPMatchingVersion(recordIP, recordIPv4, recordIPv6, ipVersion) recordIPs = getIPsMatchingVersion(recordIPs, recordIPv4s, recordIPv6s, ipVersion)
if publicIP.IsValid() && publicIP.Compare(recordIP) != 0 { if publicIP.IsValid() && !ipsContainsIP(recordIPs, publicIP) {
// Note if the recordIP is not valid (not found), we want to update. // Note if the recordIP is not valid (not found), we want to update.
s.logInfoLookupUpdate(hostname, ipKind, recordIP, publicIP) s.logInfoLookupUpdate(hostname, ipKind, recordIPs, publicIP)
return true return true
} }
s.logDebugLookupSkip(hostname, ipKind, recordIP, publicIP) s.logDebugLookupSkip(hostname, ipKind, recordIPs, publicIP)
return false
}
func ipsContainsIP(ips []netip.Addr, ip netip.Addr) bool {
for _, ip2 := range ips {
if ip.Compare(ip2) == 0 {
return true
}
}
return false return false
} }
@@ -239,8 +243,22 @@ func getIPMatchingVersion(ip, ipv4, ipv6 netip.Addr, ipVersion ipversion.IPVersi
return ipv4 return ipv4
case ipversion.IP6: case ipversion.IP6:
return ipv6 return ipv6
default:
panic(fmt.Sprintf("invalid IP version %s", ipVersion))
}
}
func getIPsMatchingVersion(ip, ipv4, ipv6 []netip.Addr, ipVersion ipversion.IPVersion) []netip.Addr {
switch ipVersion {
case ipversion.IP4or6:
return ip
case ipversion.IP4:
return ipv4
case ipversion.IP6:
return ipv6
default:
panic(fmt.Sprintf("invalid IP version %s", ipVersion))
} }
return netip.Addr{}
} }
func setInitialUpToDateStatus(db Database, id uint, updateIP netip.Addr, now time.Time) error { func setInitialUpToDateStatus(db Database, id uint, updateIP netip.Addr, now time.Time) error {