Files
ddns-updater-qdm12-3/internal/provider/providers/cloudflare/provider.go
2024-10-20 13:12:33 +00:00

414 lines
12 KiB
Go

package cloudflare
import (
"bytes"
"context"
"encoding/json"
stderrors "errors"
"fmt"
"net/http"
"net/netip"
"net/url"
"regexp"
"strings"
"github.com/qdm12/ddns-updater/internal/models"
"github.com/qdm12/ddns-updater/internal/provider/constants"
"github.com/qdm12/ddns-updater/internal/provider/errors"
"github.com/qdm12/ddns-updater/internal/provider/headers"
"github.com/qdm12/ddns-updater/internal/provider/utils"
"github.com/qdm12/ddns-updater/pkg/publicip/ipversion"
)
type Provider struct {
domain string
owner string
ipVersion ipversion.IPVersion
ipv6Suffix netip.Prefix
key string
token string
email string
userServiceKey string
zoneIdentifier string
proxied bool
ttl uint32
}
func New(data json.RawMessage, domain, owner string,
ipVersion ipversion.IPVersion, ipv6Suffix netip.Prefix) (
p *Provider, err error,
) {
extraSettings := struct {
Key string `json:"key"`
Token string `json:"token"`
Email string `json:"email"`
UserServiceKey string `json:"user_service_key"`
ZoneIdentifier string `json:"zone_identifier"`
Proxied bool `json:"proxied"`
TTL uint32 `json:"ttl"`
}{}
err = json.Unmarshal(data, &extraSettings)
if err != nil {
return nil, err
}
err = validateSettings(domain, extraSettings.Email, extraSettings.Key, extraSettings.UserServiceKey,
extraSettings.ZoneIdentifier, extraSettings.TTL)
if err != nil {
return nil, fmt.Errorf("validating provider specific settings: %w", err)
}
return &Provider{
domain: domain,
owner: owner,
ipVersion: ipVersion,
ipv6Suffix: ipv6Suffix,
key: extraSettings.Key,
token: extraSettings.Token,
email: extraSettings.Email,
userServiceKey: extraSettings.UserServiceKey,
zoneIdentifier: extraSettings.ZoneIdentifier,
proxied: extraSettings.Proxied,
ttl: extraSettings.TTL,
}, nil
}
var (
keyRegex = regexp.MustCompile(`^[a-zA-Z0-9]+$`)
userServiceKeyRegex = regexp.MustCompile(`^v1\.0.+$`)
regexEmail = regexp.MustCompile(`[a-zA-Z0-9-_.+]+@[a-zA-Z0-9-_.]+\.[a-zA-Z]{2,10}`)
)
func validateSettings(domain, email, key, userServiceKey, zoneIdentifier string, ttl uint32) (err error) {
err = utils.CheckDomain(domain)
if err != nil {
return fmt.Errorf("%w: %w", errors.ErrDomainNotValid, err)
}
switch {
case email != "", key != "": // email and key must be provided
switch {
case !keyRegex.MatchString(key):
return fmt.Errorf("%w: key %q does not match regex %q",
errors.ErrKeyNotValid, key, keyRegex)
case !regexEmail.MatchString(email):
return fmt.Errorf("%w: email %q does not match regex %q",
errors.ErrEmailNotValid, email, regexEmail)
}
case userServiceKey != "": // only user service key
if !userServiceKeyRegex.MatchString(userServiceKey) {
return fmt.Errorf("%w: %q does not match regex %q",
errors.ErrUserServiceKeyNotValid, userServiceKey, userServiceKeyRegex)
}
default: // constants.API token only
}
switch {
case zoneIdentifier == "":
return fmt.Errorf("%w", errors.ErrZoneIdentifierNotSet)
case ttl == 0:
return fmt.Errorf("%w", errors.ErrTTLNotSet)
}
return nil
}
func (p *Provider) String() string {
return utils.ToString(p.domain, p.owner, constants.Cloudflare, p.ipVersion)
}
func (p *Provider) Domain() string {
return p.domain
}
func (p *Provider) Owner() string {
return p.owner
}
func (p *Provider) IPVersion() ipversion.IPVersion {
return p.ipVersion
}
func (p *Provider) IPv6Suffix() netip.Prefix {
return p.ipv6Suffix
}
func (p *Provider) Proxied() bool {
return p.proxied
}
func (p *Provider) BuildDomainName() string {
return utils.BuildDomainName(p.owner, p.domain)
}
func (p *Provider) HTML() models.HTMLRow {
return models.HTMLRow{
Domain: fmt.Sprintf("<a href=\"http://%s\">%s</a>", p.BuildDomainName(), p.BuildDomainName()),
Owner: p.Owner(),
Provider: "<a href=\"https://www.cloudflare.com\">Cloudflare</a>",
IPVersion: p.ipVersion.String(),
}
}
func (p *Provider) setHeaders(request *http.Request) {
headers.SetUserAgent(request)
headers.SetContentType(request, "application/json")
headers.SetAccept(request, "application/json")
switch {
case p.token != "":
headers.SetAuthBearer(request, p.token)
case p.userServiceKey != "":
request.Header.Set("X-Auth-User-Service-Key", p.userServiceKey)
case p.email != "" && p.key != "":
request.Header.Set("X-Auth-Email", p.email)
request.Header.Set("X-Auth-Key", p.key)
}
}
// Obtain domain ID.
// See https://api.cloudflare.com/#dns-records-for-a-zone-list-dns-records.
func (p *Provider) getRecordID(ctx context.Context, client *http.Client, newIP netip.Addr) (
identifier string, upToDate bool, err error,
) {
recordType := constants.A
if newIP.Is6() {
recordType = constants.AAAA
}
u := url.URL{
Scheme: "https",
Host: "api.cloudflare.com",
Path: fmt.Sprintf("/client/v4/zones/%s/dns_records", p.zoneIdentifier),
}
values := url.Values{}
values.Set("type", recordType)
values.Set("name", utils.BuildURLQueryHostname(p.owner, p.domain))
values.Set("page", "1")
values.Set("per_page", "1")
u.RawQuery = values.Encode()
request, err := http.NewRequestWithContext(ctx, http.MethodGet, u.String(), nil)
if err != nil {
return "", false, fmt.Errorf("creating http request: %w", err)
}
p.setHeaders(request)
response, err := client.Do(request)
if err != nil {
return "", false, err
}
defer response.Body.Close()
if response.StatusCode != http.StatusOK {
return "", false, fmt.Errorf("%w: %d: %s",
errors.ErrHTTPStatusNotValid, response.StatusCode, utils.BodyToSingleLine(response.Body))
}
decoder := json.NewDecoder(response.Body)
listRecordsResponse := struct {
Success bool `json:"success"`
Errors []string `json:"errors"`
Result []struct {
ID string `json:"id"`
Content string `json:"content"`
} `json:"result"`
}{}
err = decoder.Decode(&listRecordsResponse)
if err != nil {
return "", false, fmt.Errorf("json decoding response body: %w", err)
}
switch {
case len(listRecordsResponse.Errors) > 0:
return "", false, fmt.Errorf("%w: %s",
errors.ErrUnsuccessful, strings.Join(listRecordsResponse.Errors, ","))
case !listRecordsResponse.Success:
return "", false, fmt.Errorf("%w", errors.ErrUnsuccessful)
case len(listRecordsResponse.Result) == 0:
return "", false, fmt.Errorf("%w", errors.ErrReceivedNoResult)
case len(listRecordsResponse.Result) > 1:
return "", false, fmt.Errorf("%w: %d instead of 1",
errors.ErrResultsCountReceived, len(listRecordsResponse.Result))
case listRecordsResponse.Result[0].Content == newIP.String(): // up to date
return "", true, nil
}
return listRecordsResponse.Result[0].ID, false, nil
}
func (p *Provider) createRecord(ctx context.Context, client *http.Client, ip netip.Addr) (recordID string, err error) {
recordType := constants.A
if ip.Is6() {
recordType = constants.AAAA
}
u := url.URL{
Scheme: "https",
Host: "api.cloudflare.com",
Path: fmt.Sprintf("/client/v4/zones/%s/dns_records", p.zoneIdentifier),
}
requestData := struct {
Type string `json:"type"` // constants.A or constants.AAAA depending on ip address given
Name string `json:"name"` // DNS record name i.e. example.com
Content string `json:"content"` // ip address
Proxied bool `json:"proxied"` // whether the record is receiving the performance and security benefits of Cloudflare
TTL uint32 `json:"ttl"`
}{
Type: recordType,
Name: utils.BuildURLQueryHostname(p.owner, p.domain),
Content: ip.String(),
Proxied: p.proxied,
TTL: p.ttl,
}
buffer := bytes.NewBuffer(nil)
encoder := json.NewEncoder(buffer)
err = encoder.Encode(requestData)
if err != nil {
return "", fmt.Errorf("JSON encoding request data: %w", err)
}
request, err := http.NewRequestWithContext(ctx, http.MethodPost, u.String(), buffer)
if err != nil {
return "", fmt.Errorf("creating http request: %w", err)
}
p.setHeaders(request)
response, err := client.Do(request)
if err != nil {
return "", err
}
defer response.Body.Close()
if response.StatusCode > http.StatusUnsupportedMediaType {
return "", fmt.Errorf("%w: %d: %s",
errors.ErrHTTPStatusNotValid, response.StatusCode, utils.BodyToSingleLine(response.Body))
}
decoder := json.NewDecoder(response.Body)
var parsedJSON struct {
Success bool `json:"success"`
Errors []struct {
Code int `json:"code"`
Message string `json:"message"`
} `json:"errors"`
Result struct {
ID string `json:"id"`
} `json:"result"`
}
err = decoder.Decode(&parsedJSON)
if err != nil {
return "", fmt.Errorf("json decoding response body: %w", err)
}
if !parsedJSON.Success {
var errStr string
for _, e := range parsedJSON.Errors {
errStr += fmt.Sprintf("error %d: %s; ", e.Code, e.Message)
}
return "", fmt.Errorf("%w: %s", errors.ErrUnsuccessful, errStr)
}
return parsedJSON.Result.ID, nil
}
func (p *Provider) Update(ctx context.Context, client *http.Client, ip netip.Addr) (newIP netip.Addr, err error) {
recordType := constants.A
if ip.Is6() {
recordType = constants.AAAA
}
identifier, upToDate, err := p.getRecordID(ctx, client, ip)
switch {
case stderrors.Is(err, errors.ErrReceivedNoResult):
identifier, err = p.createRecord(ctx, client, ip)
if err != nil {
return netip.Addr{}, fmt.Errorf("creating record: %w", err)
}
case err != nil:
return netip.Addr{}, fmt.Errorf("getting record id: %w", err)
case upToDate:
return ip, nil
}
u := url.URL{
Scheme: "https",
Host: "api.cloudflare.com",
Path: fmt.Sprintf("/client/v4/zones/%s/dns_records/%s", p.zoneIdentifier, identifier),
}
requestData := struct {
Type string `json:"type"` // constants.A or constants.AAAA depending on ip address given
Name string `json:"name"` // DNS record name i.e. example.com
Content string `json:"content"` // ip address
Proxied bool `json:"proxied"` // whether the record is receiving the performance and security benefits of Cloudflare
TTL uint32 `json:"ttl"`
}{
Type: recordType,
Name: utils.BuildURLQueryHostname(p.owner, p.domain),
Content: ip.String(),
Proxied: p.proxied,
TTL: p.ttl,
}
buffer := bytes.NewBuffer(nil)
encoder := json.NewEncoder(buffer)
err = encoder.Encode(requestData)
if err != nil {
return netip.Addr{}, fmt.Errorf("json encoding request data: %w", err)
}
request, err := http.NewRequestWithContext(ctx, http.MethodPut, u.String(), buffer)
if err != nil {
return netip.Addr{}, fmt.Errorf("creating http request: %w", err)
}
p.setHeaders(request)
response, err := client.Do(request)
if err != nil {
return netip.Addr{}, err
}
defer response.Body.Close()
if response.StatusCode > http.StatusUnsupportedMediaType {
return netip.Addr{}, fmt.Errorf("%w: %d: %s",
errors.ErrHTTPStatusNotValid, response.StatusCode, utils.BodyToSingleLine(response.Body))
}
decoder := json.NewDecoder(response.Body)
var parsedJSON struct {
Success bool `json:"success"`
Errors []struct {
Code int `json:"code"`
Message string `json:"message"`
} `json:"errors"`
Result struct {
Content string `json:"content"`
} `json:"result"`
}
err = decoder.Decode(&parsedJSON)
if err != nil {
return netip.Addr{}, fmt.Errorf("json decoding response body: %w", err)
}
if !parsedJSON.Success {
var errStr string
for _, e := range parsedJSON.Errors {
errStr += fmt.Sprintf("error %d: %s; ", e.Code, e.Message)
}
return netip.Addr{}, fmt.Errorf("%w: %s", errors.ErrUnsuccessful, errStr)
}
newIP, err = netip.ParseAddr(parsedJSON.Result.Content)
if err != nil {
return netip.Addr{}, fmt.Errorf("%w: %w", errors.ErrIPReceivedMalformed, err)
} else if newIP.Compare(ip) != 0 {
return netip.Addr{}, fmt.Errorf("%w: sent ip %s to update but received %s",
errors.ErrIPReceivedMismatch, ip, newIP)
}
return newIP, nil
}