mirror of
https://github.com/qdm12/ddns-updater.git
synced 2026-04-05 00:43:53 -04:00
345 lines
9.6 KiB
Go
345 lines
9.6 KiB
Go
package update
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"net/netip"
|
|
"time"
|
|
|
|
"github.com/qdm12/ddns-updater/internal/constants"
|
|
"github.com/qdm12/ddns-updater/internal/models"
|
|
librecords "github.com/qdm12/ddns-updater/internal/records"
|
|
"github.com/qdm12/ddns-updater/pkg/publicip/ipversion"
|
|
)
|
|
|
|
type Runner struct {
|
|
period time.Duration
|
|
db Database
|
|
updater UpdaterInterface
|
|
force chan struct{}
|
|
forceResult chan []error
|
|
cooldown time.Duration
|
|
resolver LookupIPer
|
|
ipGetter PublicIPFetcher
|
|
logger Logger
|
|
timeNow func() time.Time
|
|
hioClient HealthchecksIOClient
|
|
}
|
|
|
|
func NewRunner(db Database, updater UpdaterInterface, ipGetter PublicIPFetcher,
|
|
period time.Duration, cooldown time.Duration, logger Logger, resolver LookupIPer,
|
|
timeNow func() time.Time, hioClient HealthchecksIOClient) *Runner {
|
|
return &Runner{
|
|
period: period,
|
|
db: db,
|
|
updater: updater,
|
|
force: make(chan struct{}),
|
|
forceResult: make(chan []error),
|
|
cooldown: cooldown,
|
|
resolver: resolver,
|
|
ipGetter: ipGetter,
|
|
logger: logger,
|
|
timeNow: timeNow,
|
|
hioClient: hioClient,
|
|
}
|
|
}
|
|
|
|
func (r *Runner) lookupIPsResilient(ctx context.Context, hostname string, tries int) (
|
|
ipv4 netip.Addr, ipv6 netip.Addr, err error) {
|
|
for i := 0; i < tries; i++ {
|
|
ipv4, ipv6, err = r.lookupIPs(ctx, hostname)
|
|
if err == nil {
|
|
return ipv4, ipv6, nil
|
|
}
|
|
}
|
|
return netip.Addr{}, netip.Addr{}, err
|
|
}
|
|
|
|
func (r *Runner) lookupIPs(ctx context.Context, hostname string) (
|
|
ipv4 netip.Addr, ipv6 netip.Addr, err error) {
|
|
netIPs, err := r.resolver.LookupIP(ctx, "ip", hostname)
|
|
if err != nil {
|
|
return netip.Addr{}, netip.Addr{}, err
|
|
}
|
|
ips := make([]netip.Addr, len(netIPs))
|
|
for i, netIP := range netIPs {
|
|
switch {
|
|
case netIP == nil:
|
|
case netIP.To4() != nil:
|
|
ips[i] = netip.AddrFrom4([4]byte(netIP.To4()))
|
|
default: // IPv6
|
|
ips[i] = netip.AddrFrom16([16]byte(netIP.To16()))
|
|
}
|
|
}
|
|
|
|
for _, ip := range ips {
|
|
if ip.Is6() {
|
|
ipv6 = ip
|
|
} else {
|
|
ipv4 = ip
|
|
}
|
|
}
|
|
return ipv4, ipv6, nil
|
|
}
|
|
|
|
func doIPVersion(records []librecords.Record) (doIP, doIPv4, doIPv6 bool) {
|
|
for _, record := range records {
|
|
switch record.Provider.IPVersion() {
|
|
case ipversion.IP4or6:
|
|
doIP = true
|
|
case ipversion.IP4:
|
|
doIPv4 = true
|
|
case ipversion.IP6:
|
|
doIPv6 = true
|
|
}
|
|
if doIP && doIPv4 && doIPv6 {
|
|
return true, true, true
|
|
}
|
|
}
|
|
return doIP, doIPv4, doIPv6
|
|
}
|
|
|
|
func (r *Runner) getNewIPs(ctx context.Context, doIP, doIPv4, doIPv6 bool) (
|
|
ip, ipv4, ipv6 netip.Addr, errors []error) {
|
|
var err error
|
|
if doIP {
|
|
ip, err = tryAndRepeatGettingIP(ctx, r.ipGetter.IP, r.logger, ipversion.IP4or6)
|
|
if err != nil {
|
|
errors = append(errors, err)
|
|
}
|
|
}
|
|
if doIPv4 {
|
|
ipv4, err = tryAndRepeatGettingIP(ctx, r.ipGetter.IP4, r.logger, ipversion.IP4)
|
|
if err != nil {
|
|
errors = append(errors, err)
|
|
}
|
|
}
|
|
if doIPv6 {
|
|
ipv6, err = tryAndRepeatGettingIP(ctx, r.ipGetter.IP6, r.logger, ipversion.IP6)
|
|
if err != nil {
|
|
errors = append(errors, err)
|
|
}
|
|
}
|
|
return ip, ipv4, ipv6, errors
|
|
}
|
|
|
|
func (r *Runner) getRecordIDsToUpdate(ctx context.Context, records []librecords.Record,
|
|
ip, ipv4, ipv6 netip.Addr) (recordIDs map[uint]struct{}) {
|
|
recordIDs = make(map[uint]struct{})
|
|
for i, record := range records {
|
|
shouldUpdate := r.shouldUpdateRecord(ctx, record, ip, ipv4, ipv6)
|
|
if shouldUpdate {
|
|
id := uint(i)
|
|
recordIDs[id] = struct{}{}
|
|
}
|
|
}
|
|
return recordIDs
|
|
}
|
|
|
|
func (r *Runner) shouldUpdateRecord(ctx context.Context, record librecords.Record,
|
|
ip, ipv4, ipv6 netip.Addr) (update bool) {
|
|
now := r.timeNow()
|
|
isWithinBanPeriod := record.LastBan != nil && now.Sub(*record.LastBan) < time.Hour
|
|
isWithinCooldown := now.Sub(record.History.GetSuccessTime()) < r.cooldown
|
|
if isWithinBanPeriod || isWithinCooldown {
|
|
domain := record.Provider.BuildDomainName()
|
|
r.logger.Debug("record " + domain + " is within ban period or cooldown period, skipping update")
|
|
return false
|
|
}
|
|
|
|
hostname := record.Provider.BuildDomainName()
|
|
ipVersion := record.Provider.IPVersion()
|
|
publicIP := getIPMatchingVersion(ip, ipv4, ipv6, ipVersion)
|
|
|
|
if !publicIP.IsValid() {
|
|
r.logger.Warn(fmt.Sprintf("Skipping update for %s because %s address was not found",
|
|
hostname, ipVersionToIPKind(ipVersion)))
|
|
return false
|
|
} else if publicIP.Is6() {
|
|
publicIP = ipv6WithSuffix(publicIP, record.Provider.IPv6Suffix())
|
|
}
|
|
|
|
if record.Provider.Proxied() {
|
|
lastIP := record.History.GetCurrentIP() // can be nil
|
|
return r.shouldUpdateRecordNoLookup(hostname, ipVersion, lastIP, publicIP)
|
|
}
|
|
return r.shouldUpdateRecordWithLookup(ctx, hostname, ipVersion, publicIP)
|
|
}
|
|
|
|
func (r *Runner) shouldUpdateRecordNoLookup(hostname string, ipVersion ipversion.IPVersion,
|
|
lastIP, publicIP netip.Addr) (update bool) {
|
|
ipKind := ipVersionToIPKind(ipVersion)
|
|
if publicIP.IsValid() && publicIP.Compare(lastIP) != 0 {
|
|
r.logInfoNoLookupUpdate(hostname, ipKind, lastIP, publicIP)
|
|
return true
|
|
}
|
|
r.logDebugNoLookupSkip(hostname, ipKind, lastIP, publicIP)
|
|
return false
|
|
}
|
|
|
|
func (r *Runner) shouldUpdateRecordWithLookup(ctx context.Context, hostname string,
|
|
ipVersion ipversion.IPVersion, publicIP netip.Addr) (update bool) {
|
|
const tries = 5
|
|
recordIPv4, recordIPv6, err := r.lookupIPsResilient(ctx, hostname, tries)
|
|
if err != nil {
|
|
ctxErr := ctx.Err()
|
|
if ctxErr != nil {
|
|
r.logger.Warn("DNS resolution of " + hostname + ": " + ctxErr.Error())
|
|
return false
|
|
}
|
|
r.logger.Warn("cannot DNS resolve " + hostname + " after " +
|
|
fmt.Sprint(tries) + " tries: " + err.Error()) // update anyway
|
|
}
|
|
|
|
ipKind := ipVersionToIPKind(ipVersion)
|
|
recordIP := recordIPv4
|
|
if publicIP.Is6() {
|
|
recordIP = recordIPv6
|
|
}
|
|
recordIP = getIPMatchingVersion(recordIP, recordIPv4, recordIPv6, ipVersion)
|
|
|
|
if publicIP.IsValid() && publicIP.Compare(recordIP) != 0 {
|
|
// Note if the recordIP is not valid (not found), we want to update.
|
|
r.logInfoLookupUpdate(hostname, ipKind, recordIP, publicIP)
|
|
return true
|
|
}
|
|
r.logDebugLookupSkip(hostname, ipKind, recordIP, publicIP)
|
|
return false
|
|
}
|
|
|
|
func getIPMatchingVersion(ip, ipv4, ipv6 netip.Addr, ipVersion ipversion.IPVersion) netip.Addr {
|
|
switch ipVersion {
|
|
case ipversion.IP4or6:
|
|
return ip
|
|
case ipversion.IP4:
|
|
return ipv4
|
|
case ipversion.IP6:
|
|
return ipv6
|
|
}
|
|
return netip.Addr{}
|
|
}
|
|
|
|
func setInitialUpToDateStatus(db Database, id uint, updateIP netip.Addr, now time.Time) error {
|
|
record, err := db.Select(id)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
record.Status = constants.UPTODATE
|
|
record.Time = now
|
|
if !record.History.GetCurrentIP().IsValid() {
|
|
record.History = append(record.History, models.HistoryEvent{
|
|
IP: updateIP,
|
|
Time: now,
|
|
})
|
|
}
|
|
return db.Update(id, record)
|
|
}
|
|
|
|
func setInitialPublicIPFailStatus(db Database, id uint, now time.Time) error {
|
|
record, err := db.Select(id)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
record.Status = constants.FAIL
|
|
record.Message = "public IP address not found"
|
|
record.Time = now
|
|
return db.Update(id, record)
|
|
}
|
|
|
|
func (r *Runner) updateNecessary(ctx context.Context) (errors []error) {
|
|
records := r.db.SelectAll()
|
|
doIP, doIPv4, doIPv6 := doIPVersion(records)
|
|
r.logger.Debug(fmt.Sprintf("configured to fetch IP: v4 or v6: %t, v4: %t, v6: %t", doIP, doIPv4, doIPv6))
|
|
ip, ipv4, ipv6, errors := r.getNewIPs(ctx, doIP, doIPv4, doIPv6)
|
|
r.logger.Debug(fmt.Sprintf("your public IP address are: v4 or v6: %s, v4: %s, v6: %s", ip, ipv4, ipv6))
|
|
for _, err := range errors {
|
|
r.logger.Error(err.Error())
|
|
}
|
|
|
|
recordIDs := r.getRecordIDsToUpdate(ctx, records, ip, ipv4, ipv6)
|
|
|
|
// Current time is used to set initial states for records already
|
|
// up to date or in the fail state due to the public IP not found.
|
|
// No need to have it queried within the next for loop since each
|
|
// iteration is fast and has no IO involved.
|
|
now := r.timeNow()
|
|
|
|
for i, record := range records {
|
|
id := uint(i)
|
|
_, requireUpdate := recordIDs[id]
|
|
if requireUpdate || record.Status != constants.UNSET {
|
|
continue
|
|
}
|
|
|
|
ipVersion := record.Provider.IPVersion()
|
|
updateIP := getIPMatchingVersion(ip, ipv4, ipv6, ipVersion)
|
|
if !updateIP.IsValid() {
|
|
// warning was already logged in getRecordIDsToUpdate
|
|
err := setInitialPublicIPFailStatus(r.db, id, now)
|
|
if err != nil {
|
|
err = fmt.Errorf("setting initial public IP fail status: %w", err)
|
|
errors = append(errors, err)
|
|
r.logger.Error(err.Error())
|
|
}
|
|
continue
|
|
} else if updateIP.Is6() {
|
|
updateIP = ipv6WithSuffix(updateIP, record.Provider.IPv6Suffix())
|
|
}
|
|
|
|
err := setInitialUpToDateStatus(r.db, id, updateIP, now)
|
|
if err != nil {
|
|
err = fmt.Errorf("setting initial up to date status: %w", err)
|
|
errors = append(errors, err)
|
|
r.logger.Error(err.Error())
|
|
}
|
|
}
|
|
for id := range recordIDs {
|
|
record := records[id]
|
|
updateIP := getIPMatchingVersion(ip, ipv4, ipv6, record.Provider.IPVersion())
|
|
// Note: each record id has a matching valid public IP address.
|
|
if updateIP.Is6() {
|
|
updateIP = ipv6WithSuffix(updateIP, record.Provider.IPv6Suffix())
|
|
}
|
|
r.logger.Info("Updating record " + record.Provider.String() + " to use " + updateIP.String())
|
|
err := r.updater.Update(ctx, id, updateIP)
|
|
if err != nil {
|
|
errors = append(errors, err)
|
|
r.logger.Error(err.Error())
|
|
}
|
|
}
|
|
|
|
err := r.hioClient.Ping(ctx)
|
|
if err != nil {
|
|
r.logger.Error("pinging healthchecks.io failed: " + err.Error())
|
|
}
|
|
|
|
return errors
|
|
}
|
|
|
|
func (r *Runner) Run(ctx context.Context, done chan<- struct{}) {
|
|
defer close(done)
|
|
ticker := time.NewTicker(r.period)
|
|
for {
|
|
select {
|
|
case <-ticker.C:
|
|
r.updateNecessary(ctx)
|
|
case <-r.force:
|
|
r.forceResult <- r.updateNecessary(ctx)
|
|
case <-ctx.Done():
|
|
ticker.Stop()
|
|
return
|
|
}
|
|
}
|
|
}
|
|
|
|
func (r *Runner) ForceUpdate(ctx context.Context) (errs []error) {
|
|
r.force <- struct{}{}
|
|
|
|
select {
|
|
case errs = <-r.forceResult:
|
|
case <-ctx.Done():
|
|
errs = []error{ctx.Err()}
|
|
}
|
|
return errs
|
|
}
|