Skip to content

Commit

Permalink
feat: Add custom dns resolver (#54)
Browse files Browse the repository at this point in the history
  • Loading branch information
mohitsethia authored Oct 4, 2024
1 parent 4093d58 commit 784898b
Show file tree
Hide file tree
Showing 4 changed files with 66 additions and 13 deletions.
3 changes: 3 additions & 0 deletions internal/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ import (
// Options are the flags supported by the command line application.
type Options struct {
// Input flags.
//Custom DNS Resolver
DNSResolver string
// Protocol to use.
Protocol string
// Number of iterations. Zero means infinite.
Expand All @@ -31,6 +33,7 @@ type Options struct {

// Parse fulfills the command line flags provided by the user.
func (opts *Options) Parse() {
flag.StringVar(&opts.DNSResolver, "r", "", "DNS resolution server")
flag.StringVar(&opts.Protocol, "p", "", "Test only one protocol")
flag.UintVar(&opts.Count, "c", 0, "Number of iterations")
flag.DurationVar(
Expand Down
3 changes: 3 additions & 0 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,9 @@ func main() {
if protocol == nil {
internal.Fatal(fmt.Errorf("unknown protocol: %s", opts.Protocol))
}
if opts.DNSResolver != "" {
protocol.WithDNSResolver(opts.DNSResolver)
}
protocols = []*pkg.Protocol{protocol}
}
logger.Info("Starting ...", "protocols", protocols, "count", opts.Count)
Expand Down
61 changes: 54 additions & 7 deletions pkg/protocols.go
Original file line number Diff line number Diff line change
@@ -1,17 +1,42 @@
package pkg

import (
"context"
"fmt"
"net"
"net/http"
"time"
)

// Protocols included in the library.
var Protocols = []*Protocol{
{ID: "http", Probe: httpProbe, RHost: RandomCaptivePortal},
{ID: "tcp", Probe: tcpProbe, RHost: RandomTCPServer},
{ID: "dns", Probe: dnsProbe, RHost: RandomDomain},
var Protocols []*Protocol

func init() {
httpProtocol := &Protocol{
ID: "http",
RHost: RandomCaptivePortal,
}
httpProtocol.Probe = func(domain string, timeout time.Duration) (string, error) {
return httpProtocol.httpProbe(domain, timeout)
}

tcpProtocol := &Protocol{
ID: "tcp",
RHost: RandomTCPServer,
}
tcpProtocol.Probe = func(domain string, timeout time.Duration) (string, error) {
return tcpProtocol.tcpProbe(domain, timeout)
}

dnsProtocol := &Protocol{
ID: "dns",
RHost: RandomDomain,
}
dnsProtocol.Probe = func(domain string, timeout time.Duration) (string, error) {
return dnsProtocol.dnsProbe(domain, timeout)
}

Protocols = []*Protocol{httpProtocol, tcpProtocol, dnsProtocol}
}

// Protocol defines a probe attempt.
Expand All @@ -22,6 +47,12 @@ type Protocol struct {
Probe func(rhost string, timeout time.Duration) (string, error)
// Function to create a random remote
RHost func() (string, error)
// customDNSResolver
dnsResolver string
}

func (p *Protocol) WithDNSResolver(dnsResolver string) {
p.dnsResolver = dnsResolver
}

// String returns an human-readable representation of the protocol.
Expand All @@ -43,7 +74,7 @@ func (p *Protocol) validate() error {
// Makes an HTTP request.
//
// The extra information is the status code.
func httpProbe(u string, timeout time.Duration) (string, error) {
func (p *Protocol) httpProbe(u string, timeout time.Duration) (string, error) {
cli := &http.Client{Timeout: timeout}
resp, err := cli.Get(u)
if err != nil {
Expand All @@ -59,7 +90,7 @@ func httpProbe(u string, timeout time.Duration) (string, error) {
// Makes a TCP request.
//
// The extra information is the local host/port.
func tcpProbe(hostPort string, timeout time.Duration) (string, error) {
func (p *Protocol) tcpProbe(hostPort string, timeout time.Duration) (string, error) {
conn, err := net.DialTimeout("tcp", hostPort, timeout)
if err != nil {
return "", fmt.Errorf("making request to %s: %w", hostPort, err)
Expand All @@ -75,7 +106,23 @@ func tcpProbe(hostPort string, timeout time.Duration) (string, error) {
//
// The extra information is the first resolved IP address.
// TODO(#31)
func dnsProbe(domain string, timeout time.Duration) (string, error) {
func (p *Protocol) dnsProbe(domain string, timeout time.Duration) (string, error) {
if p != nil && p.dnsResolver != "" {
r := &net.Resolver{
PreferGo: true,
Dial: func(ctx context.Context, network, address string) (net.Conn, error) {
d := net.Dialer{
Timeout: timeout,
}
return d.DialContext(ctx, network, p.dnsResolver)
},
}
addr, err := r.LookupHost(context.Background(), domain)
if err != nil {
return "", fmt.Errorf("resolving %s: %w", domain, err)
}
return fmt.Sprintf(addr[0]), nil
}
addrs, err := net.LookupHost(domain)
if err != nil {
return "", fmt.Errorf("resolving %s: %w", domain, err)
Expand Down
12 changes: 6 additions & 6 deletions pkg/protocols_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ func TestHttpProbe(t *testing.T) {
"returns the status code if the request is successful",
func(t *testing.T) {
u := url.URL{Scheme: "http", Host: server.Addr}
got, err := httpProbe(u.String(), tout)
got, err := (&Protocol{}).httpProbe(u.String(), tout)
if err != nil {
t.Fatal(err)
}
Expand All @@ -41,7 +41,7 @@ func TestHttpProbe(t *testing.T) {
)
t.Run("returns an error if the request fails", func(t *testing.T) {
u := url.URL{Scheme: "http", Host: "localhost"}
got, err := httpProbe(u.String(), 1)
got, err := (&Protocol{}).httpProbe(u.String(), 1)
if err == nil {
t.Fatal("got nil, want an error")
}
Expand Down Expand Up @@ -95,7 +95,7 @@ func TestTcpProbe(t *testing.T) {
t.Run(
"returns the local host/port if the request is successful",
func(t *testing.T) {
got, err := tcpProbe(hostPort, tout)
got, err := (&Protocol{}).tcpProbe(hostPort, tout)
if err != nil {
t.Fatal(err)
}
Expand All @@ -113,7 +113,7 @@ func TestTcpProbe(t *testing.T) {
},
)
t.Run("returns an error if the request fails", func(t *testing.T) {
got, err := tcpProbe("localhost:80", 1)
got, err := (&Protocol{}).tcpProbe("localhost:80", 1)
if err == nil {
t.Fatal("got nil, want an error")
}
Expand Down Expand Up @@ -145,7 +145,7 @@ func TestDnsProbe(t *testing.T) {
t.Run(
"returns the first resolved IP address if the request is successful",
func(t *testing.T) {
got, err := dnsProbe("google.com", tout)
got, err := (&Protocol{}).dnsProbe("google.com", tout)
if err != nil {
t.Fatal(err)
}
Expand All @@ -156,7 +156,7 @@ func TestDnsProbe(t *testing.T) {
},
)
t.Run("returns an error if the request fails", func(t *testing.T) {
got, err := dnsProbe("invalid.aa", 1)
got, err := (&Protocol{}).dnsProbe("invalid.aa", 1)
if err == nil {
t.Fatal("got nil, want an error")
}
Expand Down

0 comments on commit 784898b

Please sign in to comment.