fix(update): do not update if public IP is part of multiple IPs found in records

This commit is contained in:
Quentin McGaw
2024-11-21 22:02:09 +00:00
parent 6c3490f5c5
commit 9ba1633300
2 changed files with 54 additions and 28 deletions

View File

@@ -32,14 +32,22 @@ func (s *Service) logInfoNoLookupUpdate(hostname, ipKind string, lastIP, ip neti
ipKind, hostname, lastIP, ipKind, ip))
}
func (s *Service) logDebugLookupSkip(hostname, ipKind string, recordIP, ip netip.Addr) {
func (s *Service) logDebugLookupSkip(hostname, ipKind string, recordIPs []netip.Addr, ip netip.Addr) {
s.logger.Debug(fmt.Sprintf("%s address of %s is %s and your %s address"+
" is %s, skipping update", ipKind, hostname, recordIP, ipKind, ip))
" is %s, skipping update", ipKind, hostname, ipsToString(recordIPs), ipKind, ip))
}
func (s *Service) logInfoLookupUpdate(hostname, ipKind string, recordIP, ip netip.Addr) {
s.logger.Info(fmt.Sprintf("%s address of %s is %s and your %s address is %s",
ipKind, hostname, recordIP, ipKind, ip))
func (s *Service) logInfoLookupUpdate(hostname, ipKind string, recordIPs []netip.Addr, ip netip.Addr) {
s.logger.Info(fmt.Sprintf("%s address of %s is %s and your %s address is %s",
ipKind, hostname, ipsToString(recordIPs), ipKind, ip))
}
func ipsToString(ips []netip.Addr) string {
ipStrings := make([]string, len(ips))
for i, ip := range ips {
ipStrings[i] = ip.String()
}
return strings.Join(ipStrings, ", ")
}
type joinedErrors struct { //nolint:errname

View File

@@ -52,7 +52,7 @@ func NewService(db Database, updater UpdaterInterface, ipGetter PublicIPFetcher,
}
func (s *Service) lookupIPsResilient(ctx context.Context, hostname string, tries int) (
ipv4 netip.Addr, ipv6 netip.Addr, err error,
ipv4, ipv6 []netip.Addr, err error,
) {
for range tries {
ipv4, ipv6, err = s.lookupIPs(ctx, hostname)
@@ -60,34 +60,29 @@ func (s *Service) lookupIPsResilient(ctx context.Context, hostname string, tries
return ipv4, ipv6, nil
}
}
return netip.Addr{}, netip.Addr{}, err
return nil, nil, err
}
func (s *Service) lookupIPs(ctx context.Context, hostname string) (
ipv4 netip.Addr, ipv6 netip.Addr, err error,
ipv4, ipv6 []netip.Addr, err error,
) {
netIPs, err := s.resolver.LookupIP(ctx, "ip", hostname)
if err != nil {
return netip.Addr{}, netip.Addr{}, err
return nil, nil, err
}
ips := make([]netip.Addr, len(netIPs))
for i, netIP := range netIPs {
ipv4 = make([]netip.Addr, 0, len(netIPs))
ipv6 = make([]netip.Addr, 0, len(netIPs))
for _, netIP := range netIPs {
switch {
case netIP == nil:
case netIP.To4() != nil:
ips[i] = netip.AddrFrom4([4]byte(netIP.To4()))
ipv4 = append(ipv4, netip.AddrFrom4([4]byte(netIP.To4())))
default: // IPv6
ips[i] = netip.AddrFrom16([16]byte(netIP.To16()))
ipv6 = append(ipv6, netip.AddrFrom16([16]byte(netIP.To16())))
}
}
for _, ip := range ips {
if ip.Is6() {
ipv6 = ip
} else {
ipv4 = ip
}
}
return ipv4, ipv6, nil
}
@@ -204,7 +199,7 @@ func (s *Service) shouldUpdateRecordWithLookup(ctx context.Context, hostname str
ipVersion ipversion.IPVersion, publicIP netip.Addr,
) (update bool) {
const tries = 5
recordIPv4, recordIPv6, err := s.lookupIPsResilient(ctx, hostname, tries)
recordIPv4s, recordIPv6s, err := s.lookupIPsResilient(ctx, hostname, tries)
if err != nil {
ctxErr := ctx.Err()
if ctxErr != nil {
@@ -216,18 +211,27 @@ func (s *Service) shouldUpdateRecordWithLookup(ctx context.Context, hostname str
}
ipKind := ipVersionToIPKind(ipVersion)
recordIP := recordIPv4
recordIPs := recordIPv4s
if publicIP.Is6() {
recordIP = recordIPv6
recordIPs = recordIPv6s
}
recordIP = getIPMatchingVersion(recordIP, recordIPv4, recordIPv6, ipVersion)
recordIPs = getIPsMatchingVersion(recordIPs, recordIPv4s, recordIPv6s, ipVersion)
if publicIP.IsValid() && publicIP.Compare(recordIP) != 0 {
if publicIP.IsValid() && !ipsContainsIP(recordIPs, publicIP) {
// Note if the recordIP is not valid (not found), we want to update.
s.logInfoLookupUpdate(hostname, ipKind, recordIP, publicIP)
s.logInfoLookupUpdate(hostname, ipKind, recordIPs, publicIP)
return true
}
s.logDebugLookupSkip(hostname, ipKind, recordIP, publicIP)
s.logDebugLookupSkip(hostname, ipKind, recordIPs, publicIP)
return false
}
func ipsContainsIP(ips []netip.Addr, ip netip.Addr) bool {
for _, ip2 := range ips {
if ip.Compare(ip2) == 0 {
return true
}
}
return false
}
@@ -239,8 +243,22 @@ func getIPMatchingVersion(ip, ipv4, ipv6 netip.Addr, ipVersion ipversion.IPVersi
return ipv4
case ipversion.IP6:
return ipv6
default:
panic(fmt.Sprintf("invalid IP version %s", ipVersion))
}
}
func getIPsMatchingVersion(ip, ipv4, ipv6 []netip.Addr, ipVersion ipversion.IPVersion) []netip.Addr {
switch ipVersion {
case ipversion.IP4or6:
return ip
case ipversion.IP4:
return ipv4
case ipversion.IP6:
return ipv6
default:
panic(fmt.Sprintf("invalid IP version %s", ipVersion))
}
return netip.Addr{}
}
func setInitialUpToDateStatus(db Database, id uint, updateIP netip.Addr, now time.Time) error {