Skip to content

Commit

Permalink
fix(update): fetch IPv6 AAAA records and not only IPv4
Browse files Browse the repository at this point in the history
  • Loading branch information
qdm12 committed Nov 21, 2024
1 parent 75191c2 commit e95816a
Showing 1 changed file with 63 additions and 32 deletions.
95 changes: 63 additions & 32 deletions internal/update/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,10 @@ package update
import (
"context"
"fmt"
"net"
"net/netip"
"strconv"
"strings"
"time"

"github.com/qdm12/ddns-updater/internal/constants"
Expand Down Expand Up @@ -34,7 +36,8 @@ type Service struct {

func NewService(db Database, updater UpdaterInterface, ipGetter PublicIPFetcher,
period time.Duration, cooldown time.Duration, logger Logger, resolver LookupIPer,
timeNow func() time.Time, hioClient HealthchecksIOClient) *Service {
timeNow func() time.Time, hioClient HealthchecksIOClient,
) *Service {
return &Service{
period: period,
db: db,
Expand All @@ -50,37 +53,59 @@ func NewService(db Database, updater UpdaterInterface, ipGetter PublicIPFetcher,
}
}

func (s *Service) lookupIPsResilient(ctx context.Context, hostname string, tries int) (
ipv4, ipv6 []netip.Addr, err error) {
for i := 0; i < tries; i++ {
ipv4, ipv6, err = s.lookupIPs(ctx, hostname)
if err == nil {
return ipv4, ipv6, nil
}
func (s *Service) lookupIPsResilient(ctx context.Context, hostname string, tries uint) (
ipv4, ipv6 []netip.Addr, err error,
) {
type result struct {
network string
ips []net.IP
err error
}
return nil, nil, err
}

func (s *Service) lookupIPs(ctx context.Context, hostname string) (
ipv4, ipv6 []netip.Addr, err error) {
netIPs, err := s.resolver.LookupIP(ctx, "ip", hostname)
if err != nil {
return nil, nil, err
results := make(chan result)
networks := []string{"ip4", "ip6"}
lookupCtx, cancel := context.WithCancel(ctx)
for _, network := range networks {
go func(ctx context.Context, network string, results chan<- result) {
for range tries {
ips, err := s.resolver.LookupIP(ctx, network, hostname)
if err != nil {
if strings.HasSuffix(err.Error(), "no such host") {
results <- result{network: network} // no IP address for this network
return
}
continue // retry
}
results <- result{network: network, ips: ips, err: err}
return
}
}(lookupCtx, network, results)
}

ipv4 = make([]netip.Addr, 0, len(netIPs))
ipv6 = make([]netip.Addr, 0, len(netIPs))
for _, netIP := range netIPs {
switch {
case netIP == nil:
case netIP.To4() != nil:
ipv4 = append(ipv4, netip.AddrFrom4([4]byte(netIP.To4())))
default: // IPv6
ipv6 = append(ipv6, netip.AddrFrom16([16]byte(netIP.To16())))
for range networks {
result := <-results
if result.err != nil {
if err == nil {
cancel()
err = fmt.Errorf("looking up %s addresses: %w", result.network, result.err)
}
continue
}
switch result.network {
case "ip4":
ipv4 = make([]netip.Addr, len(result.ips))
for i, ip := range result.ips {
ipv4[i] = netip.AddrFrom4([4]byte(ip))
}
case "ip6":
ipv6 = make([]netip.Addr, len(result.ips))
for i, ip := range result.ips {
ipv6[i] = netip.AddrFrom16([16]byte(ip))
}
}
}
cancel()

return ipv4, ipv6, nil
return ipv4, ipv6, err
}

func doIPVersion(records []librecords.Record) (doIP, doIPv4, doIPv6 bool) {
Expand All @@ -101,7 +126,8 @@ func doIPVersion(records []librecords.Record) (doIP, doIPv4, doIPv6 bool) {
}

func (s *Service) getNewIPs(ctx context.Context, doIP, doIPv4, doIPv6 bool) (
ip, ipv4, ipv6 netip.Addr, errors []error) {
ip, ipv4, ipv6 netip.Addr, errors []error,
) {
var err error
if doIP {
ip, err = tryAndRepeatGettingIP(ctx, s.ipGetter.IP, s.logger, ipversion.IP4or6)
Expand All @@ -125,7 +151,8 @@ func (s *Service) getNewIPs(ctx context.Context, doIP, doIPv4, doIPv6 bool) (
}

func (s *Service) getRecordIDsToUpdate(ctx context.Context, records []librecords.Record,
ip, ipv4, ipv6 netip.Addr) (recordIDs map[uint]struct{}) {
ip, ipv4, ipv6 netip.Addr,
) (recordIDs map[uint]struct{}) {
recordIDs = make(map[uint]struct{})
for i, record := range records {
shouldUpdate := s.shouldUpdateRecord(ctx, record, ip, ipv4, ipv6)
Expand All @@ -138,7 +165,8 @@ func (s *Service) getRecordIDsToUpdate(ctx context.Context, records []librecords
}

func (s *Service) shouldUpdateRecord(ctx context.Context, record librecords.Record,
ip, ipv4, ipv6 netip.Addr) (update bool) {
ip, ipv4, ipv6 netip.Addr,
) (update bool) {
now := s.timeNow()

isWithinCooldown := now.Sub(record.History.GetSuccessTime()) < s.cooldown
Expand Down Expand Up @@ -178,7 +206,8 @@ func (s *Service) shouldUpdateRecord(ctx context.Context, record librecords.Reco
}

func (s *Service) shouldUpdateRecordNoLookup(hostname string, ipVersion ipversion.IPVersion,
lastIP, publicIP netip.Addr) (update bool) {
lastIP, publicIP netip.Addr,
) (update bool) {
ipKind := ipVersionToIPKind(ipVersion)
if publicIP.IsValid() && publicIP.Compare(lastIP) != 0 {
s.logInfoNoLookupUpdate(hostname, ipKind, lastIP, publicIP)
Expand All @@ -189,7 +218,8 @@ func (s *Service) shouldUpdateRecordNoLookup(hostname string, ipVersion ipversio
}

func (s *Service) shouldUpdateRecordWithLookup(ctx context.Context, hostname string,
ipVersion ipversion.IPVersion, publicIP netip.Addr) (update bool) {
ipVersion ipversion.IPVersion, publicIP netip.Addr,
) (update bool) {
const tries = 5
recordIPv4s, recordIPv6s, err := s.lookupIPsResilient(ctx, hostname, tries)
if err != nil {
Expand Down Expand Up @@ -375,7 +405,8 @@ func (s *Service) Start(ctx context.Context) (runError <-chan error, startErr er
}

func (s *Service) run(ctx context.Context, ready chan<- struct{},
done chan<- struct{}) {
done chan<- struct{},
) {
defer close(done)
ticker := time.NewTicker(s.period)
close(ready)
Expand Down

0 comments on commit e95816a

Please sign in to comment.