diff --git a/glide.lock b/glide.lock index e334b124..37c83d6c 100644 --- a/glide.lock +++ b/glide.lock @@ -45,4 +45,11 @@ imports: version: a408501be4d17ee978c04a618e7a1b22af058c0e subpackages: - unix +- name: github.com/benschw/srv-lb + version: c87a20cc531f1babdf710dc3db91dd69becb0e8b + subpackages: + - dns + - lb +- name: github.com/miekg/dns + version: 822ae18e7187e1bbde923a37081f6c1b8e9ba68a devImports: [] diff --git a/glide.yaml b/glide.yaml index 1e17b0ca..17049a67 100644 --- a/glide.yaml +++ b/glide.yaml @@ -5,3 +5,8 @@ import: - package: golang.org/x/net subpackages: - websocket +- package: github.com/benschw/srv-lb + subpackages: + - dns + - lb +- package: github.com/miekg/dns diff --git a/resolver/dns.go b/resolver/dns.go new file mode 100644 index 00000000..9626421e --- /dev/null +++ b/resolver/dns.go @@ -0,0 +1,41 @@ +package resolver + +import ( + "strings" + + "github.com/benschw/srv-lb/lb" +) + +// DNSConfig accepts an address and optional load balancer config +type DNSConfig struct { + Addr string + LbCfg *lb.Config +} + +// ResolveSrvAddr returns a load-balanced host:port based on DNS SRV lookup (or `addr` if not a hostname) +func ResolveSrvAddr(dnsCfg DNSConfig) (string, error) { + addr := dnsCfg.Addr + + if v := strings.Split(addr, ":"); len(v) < 2 { + var cfg *lb.Config + var err error + if dnsCfg.LbCfg == nil { + cfg, err = lb.DefaultConfig() + + if err != nil { + return addr, err + } + } else { + cfg = dnsCfg.LbCfg + } + + l := lb.New(cfg, addr) + resolvAddr, err := l.Next() + + if err == nil { + addr = resolvAddr.String() + } + } + + return addr, nil +} diff --git a/resolver/dns_test.go b/resolver/dns_test.go new file mode 100644 index 00000000..0e75285a --- /dev/null +++ b/resolver/dns_test.go @@ -0,0 +1,29 @@ +package resolver + +import ( + "testing" + + "github.com/benschw/srv-lb/lb" +) + +func TestSrvLookup(t *testing.T) { + lbCfg, _ := lb.DefaultConfig() + lbCfg.Strategy = MockStrategy + dnsCfg := DNSConfig{Addr: "foo.example.com", LbCfg: lbCfg} + + addr, _ := ResolveSrvAddr(dnsCfg) + + if addr != "1.2.3.4:1234" { + t.Error("expected address string of 1.2.3.4:1234, got:", addr) + } +} + +func TestIpPassthrough(t *testing.T) { + dnsCfg := DNSConfig{Addr: "10.0.0.1:1234"} + + addr, _ := ResolveSrvAddr(dnsCfg) + + if addr != "10.0.0.1:1234" { + t.Error("expected output to equal input (10.0.0.1:1234), got:", addr) + } +} diff --git a/resolver/lb_mock_strategy.go b/resolver/lb_mock_strategy.go new file mode 100644 index 00000000..499dc5c3 --- /dev/null +++ b/resolver/lb_mock_strategy.go @@ -0,0 +1,30 @@ +package resolver + +import ( + "github.com/benschw/srv-lb/dns" + "github.com/benschw/srv-lb/lb" +) + +// MockStrategy is used for testing DNS load balancing +const MockStrategy lb.StrategyType = "mock" + +// New creates a new instance of the load balancer +func New(lib dns.Lookup) lb.GenericLoadBalancer { + lb := new(MockClb) + lb.dnsLib = lib + return lb +} + +// MockClb contains the dnslib +type MockClb struct { + dnsLib dns.Lookup +} + +// Next gets the next server in the available nodes +func (lb *MockClb) Next(name string) (dns.Address, error) { + return dns.Address{Address: "1.2.3.4", Port: 1234}, nil +} + +func init() { + lb.RegisterStrategy(MockStrategy, New) +} diff --git a/transports/tcp/tcp.go b/transports/tcp/tcp.go index 84d2ade9..60339e02 100644 --- a/transports/tcp/tcp.go +++ b/transports/tcp/tcp.go @@ -4,6 +4,7 @@ import ( "net" "github.com/gliderlabs/logspout/adapters/raw" + "github.com/gliderlabs/logspout/resolver" "github.com/gliderlabs/logspout/router" ) @@ -21,7 +22,12 @@ func rawTCPAdapter(route *router.Route) (router.LogAdapter, error) { type tcpTransport int func (t *tcpTransport) Dial(addr string, options map[string]string) (net.Conn, error) { - raddr, err := net.ResolveTCPAddr("tcp", addr) + daddr, err := resolver.ResolveSrvAddr(resolver.DNSConfig{Addr: addr}) + if err != nil { + return nil, err + } + + raddr, err := net.ResolveTCPAddr("tcp", daddr) if err != nil { return nil, err } diff --git a/transports/tls/tls.go b/transports/tls/tls.go index 4844b0a5..9495d4cc 100644 --- a/transports/tls/tls.go +++ b/transports/tls/tls.go @@ -5,6 +5,7 @@ import ( "net" "github.com/gliderlabs/logspout/adapters/raw" + "github.com/gliderlabs/logspout/resolver" "github.com/gliderlabs/logspout/router" ) @@ -22,7 +23,12 @@ func rawTLSAdapter(route *router.Route) (router.LogAdapter, error) { type tlsTransport int func (t *tlsTransport) Dial(addr string, options map[string]string) (net.Conn, error) { - conn, err := tls.Dial("tcp", addr, nil) + daddr, err := resolver.ResolveSrvAddr(resolver.DNSConfig{Addr: addr}) + if err != nil { + return nil, err + } + + conn, err := tls.Dial("tcp", daddr, nil) if err != nil { return nil, err } diff --git a/transports/udp/udp.go b/transports/udp/udp.go index b960c389..08f133ad 100644 --- a/transports/udp/udp.go +++ b/transports/udp/udp.go @@ -4,6 +4,7 @@ import ( "net" "github.com/gliderlabs/logspout/adapters/raw" + "github.com/gliderlabs/logspout/resolver" "github.com/gliderlabs/logspout/router" ) @@ -26,7 +27,12 @@ func rawUDPAdapter(route *router.Route) (router.LogAdapter, error) { type udpTransport int func (t *udpTransport) Dial(addr string, options map[string]string) (net.Conn, error) { - raddr, err := net.ResolveUDPAddr("udp", addr) + daddr, err := resolver.ResolveSrvAddr(resolver.DNSConfig{Addr: addr}) + if err != nil { + return nil, err + } + + raddr, err := net.ResolveUDPAddr("udp", daddr) if err != nil { return nil, err }