diff --git a/cmd/algons/dnsCmd.go b/cmd/algons/dnsCmd.go index 10b4c18881..97b3b5f334 100644 --- a/cmd/algons/dnsCmd.go +++ b/cmd/algons/dnsCmd.go @@ -105,10 +105,9 @@ var listRecordsCmd = &cobra.Command{ Long: "List the A/SRV entries of the given network", Run: func(cmd *cobra.Command, args []string) { recordType = strings.ToUpper(recordType) - if recordType == "" || recordType == "A" || recordType == "CNAME" || recordType == "SRV" { - listEntries(listNetwork, recordType) - } else { - fmt.Fprintf(os.Stderr, "Invalid recordType specified.\n") + err := listEntries(listNetwork, recordType) + if err != nil { + fmt.Fprintf(os.Stderr, "%v\n", err) os.Exit(1) } }, @@ -181,6 +180,19 @@ var exportCmd = &cobra.Command{ }, } +func doAddTXT(from string, to string) error { + cfZoneID, cfToken, err := getClouldflareCredentials() + if err != nil { + return fmt.Errorf("error getting DNS credentials: %v", err) + } + + cloudflareDNS := cloudflare.NewDNS(cfZoneID, cfToken) + + const priority = 1 + const proxied = false + return cloudflareDNS.CreateDNSRecord(context.Background(), "TXT", from, to, cloudflare.AutomaticTTL, priority, proxied) +} + func doAddDNS(from string, to string) (err error) { cfZoneID, cfToken, err := getClouldflareCredentials() if err != nil { @@ -315,7 +327,7 @@ func doDeleteDNS(network string, noPrompt bool, excludePattern string, includePa cloudflareDNS := cloudflare.NewDNS(cfZoneID, cfToken) - idsToDelete := make(map[string]string) // Maps record ID to Name + var idsToDelete []cloudflare.DNSRecordResponseEntry services := []string{"_algobootstrap", "_metrics"} servicesRegexp, err := regexp.Compile("^(_algobootstrap|_metrics)\\._tcp\\..*algodev.network$") @@ -355,7 +367,7 @@ func doDeleteDNS(network string, noPrompt bool, excludePattern string, includePa if includeRegex == nil || (includeRegex.MatchString(r.Name) && servicesRegexp.MatchString(r.Name)) { fmt.Printf("Found SRV record: %s\n", r.Name) - idsToDelete[r.ID] = r.Name + idsToDelete = append(idsToDelete, r) } } } @@ -367,7 +379,7 @@ func doDeleteDNS(network string, noPrompt bool, excludePattern string, includePa networkSuffix = "." + network + ".algodev.network" } - for _, recordType := range []string{"A", "CNAME"} { + for _, recordType := range []string{"A", "CNAME", "TXT"} { records, err := cloudflareDNS.ListDNSRecord(context.Background(), recordType, "", "", "", "", "") if err != nil { fmt.Fprintf(os.Stderr, "Error listing DNS '%s' entries: %v\n", recordType, err) @@ -384,21 +396,29 @@ func doDeleteDNS(network string, noPrompt bool, excludePattern string, includePa if includeRegex == nil || includeRegex.MatchString(r.Name) { fmt.Printf("Found DNS '%s' record: %s\n", recordType, r.Name) - idsToDelete[r.ID] = r.Name + idsToDelete = append(idsToDelete, r) } } } } - if len(idsToDelete) == 0 { + err = checkedDelete(idsToDelete, cloudflareDNS) + if err != nil { + fmt.Fprintf(os.Stderr, "Error deleting: %s\n", err) + } + return true +} + +func checkedDelete(toDelete []cloudflare.DNSRecordResponseEntry, cloudflareDNS *cloudflare.DNS) error { + if len(toDelete) == 0 { fmt.Printf("No DNS/SRV records found\n") - return true + return nil } var text string if !noPrompt { reader := bufio.NewReader(os.Stdin) - fmt.Printf("Delete these %d entries (type 'yes' to delete)? ", len(idsToDelete)) + fmt.Printf("Delete these %d entries (type 'yes' to delete)? ", len(toDelete)) text, _ = reader.ReadString('\n') text = strings.Replace(text, "\n", "", -1) } else { @@ -406,42 +426,59 @@ func doDeleteDNS(network string, noPrompt bool, excludePattern string, includePa } if text == "yes" { - for id, name := range idsToDelete { - fmt.Fprintf(os.Stdout, "Deleting %s\n", name) - err = cloudflareDNS.DeleteDNSRecord(context.Background(), id) + for _, entry := range toDelete { + fmt.Fprintf(os.Stdout, "Deleting %s\n", entry.Name) + err := cloudflareDNS.DeleteDNSRecord(context.Background(), entry.ID) if err != nil { - fmt.Fprintf(os.Stderr, " !! error deleting %s: %v\n", name, err) + return fmt.Errorf(" !! error deleting %s: %v", entry.Name, err) } } } - return true + return nil } -func listEntries(listNetwork string, recordType string) { +func getEntries(getNetwork string, recordType string) ([]cloudflare.DNSRecordResponseEntry, error) { + recordTypes := []string{"A", "CNAME", "SRV", "TXT"} + isKnown := false + for _, known := range append(recordTypes, "") { + if recordType == known { + isKnown = true + break + } + } + if !isKnown { + return nil, fmt.Errorf("invalid recordType specified %s", recordType) + } cfZoneID, cfToken, err := getClouldflareCredentials() if err != nil { - fmt.Fprintf(os.Stderr, "error getting DNS credentials: %v", err) - return + return nil, fmt.Errorf("error getting DNS credentials: %v", err) } cloudflareDNS := cloudflare.NewDNS(cfZoneID, cfToken) - recordTypes := []string{"A", "CNAME", "SRV"} if recordType != "" { recordTypes = []string{recordType} } + var records []cloudflare.DNSRecordResponseEntry for _, recType := range recordTypes { - records, err := cloudflareDNS.ListDNSRecord(context.Background(), recType, "", "", "", "", "") + records, err = cloudflareDNS.ListDNSRecord(context.Background(), recType, getNetwork, "", "", "", "") if err != nil { - fmt.Fprintf(os.Stderr, "Error listing DNS entries: %v\n", err) - os.Exit(1) + return nil, fmt.Errorf("error listing DNS entries %w", err) } + } + return records, nil +} - for _, record := range records { - if strings.HasSuffix(record.Name, listNetwork) { - fmt.Printf("%v\n", record.Name) - } +func listEntries(listNetwork string, recordType string) error { + records, err := getEntries("", recordType) + if err != nil { + return err + } + for _, record := range records { + if strings.HasSuffix(record.Name, listNetwork) { + fmt.Printf("%v\n", record.Name) } } + return nil } func doExportZone(network string, outputFilename string) bool { diff --git a/cmd/algons/dnsaddrCmd.go b/cmd/algons/dnsaddrCmd.go index 1df13cecfe..1d9189082b 100644 --- a/cmd/algons/dnsaddrCmd.go +++ b/cmd/algons/dnsaddrCmd.go @@ -17,16 +17,22 @@ package main import ( + "context" "fmt" + "os" + "github.com/multiformats/go-multiaddr" "github.com/spf13/cobra" "github.com/algorand/go-algorand/network/p2p/dnsaddr" + "github.com/algorand/go-algorand/tools/network/cloudflare" ) var ( dnsaddrDomain string secure bool + cmdMultiaddrs []string + nodeSize int ) func init() { @@ -35,6 +41,17 @@ func init() { dnsaddrTreeCmd.Flags().StringVarP(&dnsaddrDomain, "domain", "d", "", "Top level domain") dnsaddrTreeCmd.MarkFlagRequired("domain") dnsaddrTreeCmd.Flags().BoolVarP(&secure, "secure", "s", true, "Enable dnssec") + + dnsaddrTreeCmd.AddCommand(dnsaddrTreeCreateCmd) + dnsaddrTreeCreateCmd.Flags().StringArrayVarP(&cmdMultiaddrs, "multiaddrs", "m", []string{}, "multiaddrs to add") + dnsaddrTreeCreateCmd.Flags().StringVarP(&dnsaddrDomain, "domain", "d", "", "Top level domain") + dnsaddrTreeCreateCmd.Flags().IntVarP(&nodeSize, "node-size", "n", 50, "Number of multiaddrs entries per TXT record") + dnsaddrTreeCreateCmd.MarkFlagRequired("domain") + dnsaddrTreeCreateCmd.MarkFlagRequired("multiaddrs") + + dnsaddrTreeCmd.AddCommand(dnsaddrTreeDeleteCmd) + dnsaddrTreeDeleteCmd.Flags().StringVarP(&dnsaddrDomain, "domain", "d", "", "Top level domain") + dnsaddrTreeDeleteCmd.MarkFlagRequired("domain") } var dnsaddrCmd = &cobra.Command{ @@ -63,3 +80,105 @@ var dnsaddrTreeCmd = &cobra.Command{ } }, } +var dnsaddrTreeDeleteCmd = &cobra.Command{ + Use: "delete", + Short: "Recursively resolves and deletes the dnsaddr entries of the given domain", + Long: "Recursively resolves and deletes the dnsaddr entries of the given domain", + Run: func(cmd *cobra.Command, args []string) { + addr, err := multiaddr.NewMultiaddr(fmt.Sprintf("/dnsaddr/%s", dnsaddrDomain)) + if err != nil { + fmt.Printf("unable to construct multiaddr for %s : %v\n", dnsaddrDomain, err) + return + } + controller := dnsaddr.NewMultiaddrDNSResolveController(secure, "") + cfZoneID, cfToken, err := getClouldflareCredentials() + if err != nil { + fmt.Fprintf(os.Stderr, "error getting DNS credentials: %v", err) + return + } + cloudflareDNS := cloudflare.NewDNS(cfZoneID, cfToken) + var recordsToDelete []cloudflare.DNSRecordResponseEntry + err = dnsaddr.Iterate(addr, controller, func(entryFrom multiaddr.Multiaddr, entries []multiaddr.Multiaddr) error { + domain, _ := entryFrom.ValueForProtocol(multiaddr.P_DNSADDR) + name := fmt.Sprintf("_dnsaddr.%s", domain) + fmt.Printf("listing records for %s\n", name) + records, err0 := cloudflareDNS.ListDNSRecord(context.Background(), "TXT", name, "", "", "", "") + if err0 != nil { + fmt.Printf("erroring listing dns records for %s %s\n", domain, err) + return err + } + for _, record := range records { + fmt.Printf("found record to delete %s:%s\n", record.Name, record.Content) + recordsToDelete = append(recordsToDelete, record) + } + return nil + }) + if err != nil { + fmt.Printf("%s\n", err.Error()) + return + } + err = checkedDelete(recordsToDelete, cloudflareDNS) + if err != nil { + fmt.Printf("error deleting records: %s\n", err) + } + }, +} + +var dnsaddrTreeCreateCmd = &cobra.Command{ + Use: "create", + Short: "Creates a tree of entries containing the multiaddrs at the provided root domain", + Long: "Creates a tree of entries containing the multiaddrs at the provided root domain", + Run: func(cmd *cobra.Command, args []string) { + if len(cmdMultiaddrs) == 0 { + fmt.Printf("must provide multiaddrs to put in the DNS records") + return + } + // Generate the dnsaddr entries required for the full tree + var dnsaddrsTo []string + for i := 0; i < len(cmdMultiaddrs)/nodeSize; i++ { + dnsaddrsTo = append(dnsaddrsTo, fmt.Sprintf("%d%s", i, dnsaddrDomain)) + } + dnsaddrsFrom := []string{fmt.Sprintf("_dnsaddr.%s", dnsaddrDomain)} + entries, err := getEntries(dnsaddrsFrom[0], "TXT") + if err != nil { + fmt.Printf("failed fetching entries for %s\n", dnsaddrsFrom[0]) + os.Exit(1) + } + if len(entries) > 0 { + for _, entry := range entries { + fmt.Printf("found entry %s => %s\n", entry.Name, entry.Content) + } + fmt.Printf("found entries already existing at %s, bailing out\n", dnsaddrsFrom[0]) + os.Exit(1) + } + for _, addrTo := range dnsaddrsTo { + dnsaddrsFrom = append(dnsaddrsFrom, fmt.Sprintf("_dnsaddr.%s", addrTo)) + } + for _, from := range dnsaddrsFrom { + for i := 0; i < nodeSize; i++ { + if len(dnsaddrsTo) > 0 { + newDnsaddr := fmt.Sprintf("dnsaddr=/dnsaddr/%s", dnsaddrsTo[len(dnsaddrsTo)-1]) + fmt.Printf("writing %s => %s\n", from, newDnsaddr) + err := doAddTXT(from, newDnsaddr) + if err != nil { + fmt.Printf("failed writing dnsaddr entry %s: %s\n", newDnsaddr, err) + os.Exit(1) + } + dnsaddrsTo = dnsaddrsTo[:len(dnsaddrsTo)-1] + continue + } + newDnsaddr := fmt.Sprintf("dnsaddr=%s", cmdMultiaddrs[len(cmdMultiaddrs)-1]) + fmt.Printf("writing %s => %s\n", from, newDnsaddr) + err := doAddTXT(from, newDnsaddr) + if err != nil { + fmt.Printf("failed writing dns entry %s\n", err) + os.Exit(1) + } + cmdMultiaddrs = cmdMultiaddrs[:len(cmdMultiaddrs)-1] + if len(cmdMultiaddrs) == 0 { + return + } + } + } + }, +} diff --git a/network/p2p/dnsaddr/resolve.go b/network/p2p/dnsaddr/resolve.go index ad9f4e8b42..0e21a5704a 100644 --- a/network/p2p/dnsaddr/resolve.go +++ b/network/p2p/dnsaddr/resolve.go @@ -29,21 +29,13 @@ func isDnsaddr(maddr multiaddr.Multiaddr) bool { return first.Protocol().Code == multiaddr.P_DNSADDR } -// MultiaddrsFromResolver attempts to recurse through dnsaddrs starting at domain. -// Any further dnsaddrs will be looked up until all TXT records have been fetched, -// and the full list of resulting Multiaddrs is returned. -// It uses the MultiaddrDNSResolveController to cycle through DNS resolvers on failure. -func MultiaddrsFromResolver(domain string, controller *MultiaddrDNSResolveController) ([]multiaddr.Multiaddr, error) { +// Iterate runs through the resolvable dnsaddrs in the tree using the resolveController and invokes f for each dnsaddr node lookup +func Iterate(initial multiaddr.Multiaddr, controller *MultiaddrDNSResolveController, f func(dnsaddr multiaddr.Multiaddr, entries []multiaddr.Multiaddr) error) error { resolver := controller.Resolver() if resolver == nil { - return nil, errors.New("passed controller has no resolvers MultiaddrsFromResolver") - } - dnsaddr, err := multiaddr.NewMultiaddr(fmt.Sprintf("/dnsaddr/%s", domain)) - if err != nil { - return nil, fmt.Errorf("unable to construct multiaddr for %s : %v", domain, err) + return errors.New("passed controller has no resolvers Iterate") } - var resolved []multiaddr.Multiaddr - var toResolve = []multiaddr.Multiaddr{dnsaddr} + var toResolve = []multiaddr.Multiaddr{initial} for resolver != nil && len(toResolve) > 0 { curr := toResolve[0] maddrs, resolveErr := resolver.Resolve(context.Background(), curr) @@ -51,18 +43,40 @@ func MultiaddrsFromResolver(domain string, controller *MultiaddrDNSResolveContro resolver = controller.NextResolver() // If we errored, and have exhausted all resolvers, just return if resolver == nil { - return resolved, resolveErr + return resolveErr } continue } for _, maddr := range maddrs { if isDnsaddr(maddr) { toResolve = append(toResolve, maddr) - } else { - resolved = append(resolved, maddr) } } + if err := f(curr, maddrs); err != nil { + return err + } toResolve = toResolve[1:] } - return resolved, nil + return nil +} + +// MultiaddrsFromResolver attempts to recurse through dnsaddrs starting at domain. +// Any further dnsaddrs will be looked up until all TXT records have been fetched, +// and the full list of resulting Multiaddrs is returned. +// It uses the MultiaddrDNSResolveController to cycle through DNS resolvers on failure. +func MultiaddrsFromResolver(domain string, controller *MultiaddrDNSResolveController) ([]multiaddr.Multiaddr, error) { + dnsaddr, err := multiaddr.NewMultiaddr(fmt.Sprintf("/dnsaddr/%s", domain)) + if err != nil { + return nil, fmt.Errorf("unable to construct multiaddr for %s : %v", domain, err) + } + var resolved []multiaddr.Multiaddr + err = Iterate(dnsaddr, controller, func(_ multiaddr.Multiaddr, entries []multiaddr.Multiaddr) error { + for _, maddr := range entries { + if !isDnsaddr(maddr) { + resolved = append(resolved, maddr) + } + } + return nil + }) + return resolved, err } diff --git a/network/p2p/dnsaddr/resolve_test.go b/network/p2p/dnsaddr/resolve_test.go index df564d5e92..2834a2a5e7 100644 --- a/network/p2p/dnsaddr/resolve_test.go +++ b/network/p2p/dnsaddr/resolve_test.go @@ -94,9 +94,9 @@ func TestMultiaddrsFromResolverDnsFailure(t *testing.T) { } // Fail on no resolver - maddrs, err := MultiaddrsFromResolver("", dnsaddrCont) + maddrs, err := MultiaddrsFromResolver("0.0.0.1", dnsaddrCont) assert.Empty(t, maddrs) - assert.ErrorContains(t, err, fmt.Sprintf("passed controller has no resolvers MultiaddrsFromResolver")) + assert.ErrorContains(t, err, fmt.Sprintf("passed controller has no resolvers Iterate")) resolver, _ := madns.NewResolver(madns.WithDefaultResolver(&failureResolver{})) dnsaddrCont = &MultiaddrDNSResolveController{