From dde5c37272cdb189bf035d7f115f78a53d06fd90 Mon Sep 17 00:00:00 2001 From: Manuel Odendahl Date: Mon, 6 Aug 2018 13:17:17 -0400 Subject: [PATCH 1/6] :poop: Hack for more efficient parser --- server.go | 265 ++++++++++++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 250 insertions(+), 15 deletions(-) diff --git a/server.go b/server.go index 66e2a00..869edf1 100644 --- a/server.go +++ b/server.go @@ -10,9 +10,11 @@ import ( "sync" "time" + "encoding/binary" + "errors" + "github.com/miekg/dns" "golang.org/x/net/ipv4" - "golang.org/x/net/ipv6" ) var ( @@ -94,7 +96,7 @@ func Register(instance, service, domain string, port int, text []string, iface * } s.service = entry - go s.mainloop() + //go s.mainloop() go s.probe() return s, nil @@ -178,14 +180,14 @@ func newServer(iface *net.Interface) (*Server, error) { // Join multicast groups to receive announcements p1 := ipv4.NewPacketConn(ipv4conn) - p2 := ipv6.NewPacketConn(ipv6conn) + //p2 := ipv6.NewPacketConn(ipv6conn) if iface != nil { if err := p1.JoinGroup(iface, &net.UDPAddr{IP: mdnsGroupIPv4}); err != nil { return nil, err } - if err := p2.JoinGroup(iface, &net.UDPAddr{IP: mdnsGroupIPv6}); err != nil { - return nil, err - } + //if err := p2.JoinGroup(iface, &net.UDPAddr{IP: mdnsGroupIPv6}); err != nil { + // return nil, err + //} } else { ifaces, err := net.Interfaces() if err != nil { @@ -196,9 +198,9 @@ func newServer(iface *net.Interface) (*Server, error) { if err := p1.JoinGroup(&iface, &net.UDPAddr{IP: mdnsGroupIPv4}); err != nil { errCount1++ } - if err := p2.JoinGroup(&iface, &net.UDPAddr{IP: mdnsGroupIPv6}); err != nil { - errCount2++ - } + //if err := p2.JoinGroup(&iface, &net.UDPAddr{IP: mdnsGroupIPv6}); err != nil { + // errCount2++ + //} } if len(ifaces) == errCount1 && len(ifaces) == errCount2 { return nil, fmt.Errorf("Failed to join multicast group on all interfaces!") @@ -207,7 +209,7 @@ func newServer(iface *net.Interface) (*Server, error) { s := &Server{ ipv4conn: ipv4conn, - ipv6conn: ipv6conn, + //ipv6conn: ipv6conn, ttl: 3200, } @@ -282,12 +284,202 @@ func (s *Server) recv(c *net.UDPConn) { func (s *Server) parsePacket(packet []byte, from net.Addr) error { var msg dns.Msg if err := msg.Unpack(packet); err != nil { + //if err := UnpackDnsMsg(&msg, packet); err != nil { log.Printf("[ERR] bonjour: Failed to unpack packet: %v", err) return err } return s.handleQuery(&msg, from) } + +func unpackUint16(msg []byte, off int) (i uint16, off1 int, err error) { + if off+2 > len(msg) { + return 0, len(msg), errors.New("overflow") + } + return binary.BigEndian.Uint16(msg[off:]), off + 2, nil +} + +func unpackMsgHdr(msg []byte, off int) (dns.Header, int, error) { + var ( + dh dns.Header + err error + ) + dh.Id, off, err = unpackUint16(msg, off) + if err != nil { + return dh, off, err + } + dh.Bits, off, err = unpackUint16(msg, off) + if err != nil { + return dh, off, err + } + dh.Qdcount, off, err = unpackUint16(msg, off) + if err != nil { + return dh, off, err + } + dh.Ancount, off, err = unpackUint16(msg, off) + if err != nil { + return dh, off, err + } + dh.Nscount, off, err = unpackUint16(msg, off) + if err != nil { + return dh, off, err + } + dh.Arcount, off, err = unpackUint16(msg, off) + return dh, off, err +} + +const ( + headerSize = 12 + + // Header.Bits + _QR = 1 << 15 // query/response (response=1) + _AA = 1 << 10 // authoritative + _TC = 1 << 9 // truncated + _RD = 1 << 8 // recursion desired + _RA = 1 << 7 // recursion available + _Z = 1 << 6 // Z + _AD = 1 << 5 // authticated data + _CD = 1 << 4 // checking disabled +) + +func unpackQuestion(msg []byte, off int) (dns.Question, int, error) { + var ( + q dns.Question + err error + ) + q.Name, off, err = dns.UnpackDomainName(msg, off) + if err != nil { + return q, off, err + } + if off == len(msg) { + return q, off, nil + } + q.Qtype, off, err = unpackUint16(msg, off) + if err != nil { + return q, off, err + } + if off == len(msg) { + return q, off, nil + } + q.Qclass, off, err = unpackUint16(msg, off) + if off == len(msg) { + return q, off, nil + } + return q, off, err +} + +// unpackRRslice unpacks msg[off:] into an []RR. +// If we cannot unpack the whole array, then it will return nil +func unpackRRslice(l int, msg []byte, off int) (dst1 []dns.RR, off1 int, err error) { + var r dns.RR + // Don't pre-allocate, l may be under attacker control + var dst []dns.RR + for i := 0; i < l; i++ { + off1 := off + r, off, err = dns.UnpackRR(msg, off) + //println("RR slice", r.String()) + if err != nil { + off = len(msg) + break + } + // If offset does not increase anymore, l is a lie + if off1 == off { + l = i + break + } + dst = append(dst, r) + } + if err != nil && off == len(msg) { + dst = nil + } + return dst, off, err +} + +// Unpack unpacks a binary message to a Msg structure. +func UnpackDnsMsg(m *dns.Msg, msg []byte) (err error) { + var ( + dh dns.Header + off int + ) + if dh, off, err = unpackMsgHdr(msg, off); err != nil { + return err + } + + m.Id = dh.Id + m.Response = (dh.Bits & _QR) != 0 + m.Opcode = int(dh.Bits>>11) & 0xF + m.Authoritative = (dh.Bits & _AA) != 0 + m.Truncated = (dh.Bits & _TC) != 0 + m.RecursionDesired = (dh.Bits & _RD) != 0 + m.RecursionAvailable = (dh.Bits & _RA) != 0 + m.Zero = (dh.Bits & _Z) != 0 + m.AuthenticatedData = (dh.Bits & _AD) != 0 + m.CheckingDisabled = (dh.Bits & _CD) != 0 + m.Rcode = int(dh.Bits & 0xF) + + // If we are at the end of the message we should return *just* the + // header. This can still be useful to the caller. 9.9.9.9 sends these + // when responding with REFUSED for instance. + if off == len(msg) { + // reset sections before returning + m.Question, m.Answer, m.Ns, m.Extra = nil, nil, nil, nil + return nil + } + + // Qdcount, Ancount, Nscount, Arcount can't be trusted, as they are + // attacker controlled. This means we can't use them to pre-allocate + // slices. + m.Question = nil + for i := 0; i < int(dh.Qdcount); i++ { + off1 := off + var q dns.Question + q, off, err = unpackQuestion(msg, off) + //println("Unpack question ", q.Name, q.Qclass, q.Qtype) + if q.Qtype != dns.TypePTR { + //println("Skipping type", q.Qtype, dns.TypePTR) + return nil + } + if err != nil { + // Even if Truncated is set, we only will set ErrTruncated if we + // actually got the questions + return err + } + if off1 == off { // Offset does not increase anymore, dh.Qdcount is a lie! + dh.Qdcount = uint16(i) + break + } + m.Question = append(m.Question, q) + } + + m.Answer, off, err = unpackRRslice(int(dh.Ancount), msg, off) + //// The header counts might have been wrong so we need to update it + dh.Ancount = uint16(len(m.Answer)) + if err == nil { + m.Ns, off, err = unpackRRslice(int(dh.Nscount), msg, off) + } + //println("Answer count", dh.Ancount) + //// The header counts might have been wrong so we need to update it + //dh.Nscount = uint16(len(m.Ns)) + //if err == nil { + // m.Extra, off, err = unpackRRslice(int(dh.Arcount), msg, off) + //} + dh.Nscount = 0 + // The header counts might have been wrong so we need to update it + dh.Arcount = uint16(len(m.Extra)) + + + if off != len(msg) { + // TODO(miek) make this an error? + // use PackOpt to let people tell how detailed the error reporting should be? + // println("m: extra bytes in m packet", off, "<", len(msg)) + } else if m.Truncated { + // Whether we ran into a an error or not, we want to return that it + // was truncated + err = dns.ErrTruncated + } + return err +} + // handleQuery is used to handle an incoming query func (s *Server) handleQuery(query *dns.Msg, from net.Addr) error { // Ignore answer for now @@ -341,7 +533,19 @@ func (s *Server) handleQuestion(q dns.Question, resp *dns.Msg) error { return nil } + println("Match", q.Name) + println("Match", q.Name, "hostname", s.service.HostName) + //println("Match", q.Name, "to", s.service.ServiceName()) + //println("MatchInstance", q.Name, "to", s.service.ServiceInstanceName()) + //println("MatchService", q.Name, "to", s.service.ServiceTypeName()) + if (q.Name == s.service.HostName) { + println("Match", q.Name, "MATCHED MATCHED") + } + switch q.Name { + case s.service.HostName: + println("Match", q.Name, "Matched to host", s.service.HostName) + s.composeHostAnswers(resp, s.ttl) case s.service.ServiceName(): s.composeBrowsingAnswers(resp, s.ttl) case s.service.ServiceInstanceName(): @@ -353,6 +557,35 @@ func (s *Server) handleQuestion(q dns.Question, resp *dns.Msg) error { return nil } +func (s *Server) composeHostAnswers(resp *dns.Msg, ttl uint32) { + if s.service.AddrIPv4 != nil { + a := &dns.A{ + Hdr: dns.RR_Header{ + Name: s.service.HostName, + Rrtype: dns.TypeA, + Class: dns.ClassINET | dns.TypeTA, + Ttl: ttl, + }, + A: s.service.AddrIPv4, + } + resp.Answer = append(resp.Answer, a) + } + if s.service.AddrIPv6 != nil { + aaaa := &dns.AAAA{ + Hdr: dns.RR_Header{ + Name: s.service.HostName, + Rrtype: dns.TypeAAAA, + Class: dns.ClassINET | dns.TypeTA, + Ttl: ttl, + }, + AAAA: s.service.AddrIPv6, + } + resp.Answer = append(resp.Answer, aaaa) + } + + resp.Question = make([]dns.Question, 0) +} + func (s *Server) composeBrowsingAnswers(resp *dns.Msg, ttl uint32) { ptr := &dns.PTR{ Hdr: dns.RR_Header{ @@ -559,12 +792,14 @@ func (s *Server) probe() { // provided that the interval between unsolicited responses increases by // at least a factor of two with every response sent. timeout := 1 * time.Second - for i := 0; i < 3; i++ { - if err := s.multicastResponse(resp); err != nil { - log.Println("[ERR] bonjour: failed to send announcement:", err.Error()) + for !s.shouldShutdown { + for i := 0; i < 3 && !s.shouldShutdown; i++ { + if err := s.multicastResponse(resp); err != nil { + log.Println("[ERR] bonjour: failed to send announcement:", err.Error()) + } + time.Sleep(timeout) + timeout *= 2 } - time.Sleep(timeout) - timeout *= 2 } } From 1412c0312cde2270fe1b7dedc377a44ac7cd02af Mon Sep 17 00:00:00 2001 From: Manuel Odendahl Date: Mon, 6 Aug 2018 16:33:15 -0400 Subject: [PATCH 2/6] :sparkles: Faster mainloop --- server.go | 264 ++++++++++++------------------------------------------ 1 file changed, 56 insertions(+), 208 deletions(-) diff --git a/server.go b/server.go index 869edf1..1a3bebd 100644 --- a/server.go +++ b/server.go @@ -7,12 +7,11 @@ import ( "net" "os" "strings" + "runtime" "sync" + "syscall" "time" - "encoding/binary" - "errors" - "github.com/miekg/dns" "golang.org/x/net/ipv4" ) @@ -96,7 +95,9 @@ func Register(instance, service, domain string, port int, text []string, iface * } s.service = entry - //go s.mainloop() + s.service.HostName = strings.ToLower(s.service.HostName) + + go s.mainloop() go s.probe() return s, nil @@ -147,6 +148,8 @@ func RegisterProxy(instance, service, domain string, port int, host, ip string, } s.service = entry + s.service.HostName = strings.ToLower(s.service.HostName) + go s.mainloop() go s.probe() @@ -210,7 +213,7 @@ func newServer(iface *net.Interface) (*Server, error) { s := &Server{ ipv4conn: ipv4conn, //ipv6conn: ipv6conn, - ttl: 3200, + ttl: 3200, } return s, nil @@ -268,14 +271,52 @@ func (s *Server) recv(c *net.UDPConn) { if c == nil { return } + + f, _ := c.File() + fd, _ := syscall.Dup(int(f.Fd())) + + // duplicating the fd because c.fd which holds the actual one gets garbage collected + e := syscall.SetNonblock(fd, true) + if e != nil { + println("Could not set nonblock", e) + } + c.SetReadBuffer(4000000) + if e != nil { + println("Could not set buffer size", e) + } buf := make([]byte, 65536) + + i := 0 + bytes := 0 + for !s.shouldShutdown { - n, from, err := c.ReadFrom(buf) + n, sa, err := syscall.Recvfrom(fd, buf, 0) + var from *net.UDPAddr + switch sa := sa.(type) { + case *syscall.SockaddrInet4: + from = &net.UDPAddr{IP: sa.Addr[0:], Port: sa.Port} + case *syscall.SockaddrInet6: + from = &net.UDPAddr{IP: sa.Addr[0:], Port: sa.Port} + } + if err != nil { + n = 0 + } + + if n == 0 { + time.Sleep(50 * time.Millisecond) + i = 0 + bytes = 0 + continue } + + + i++ + bytes += n + if err := s.parsePacket(buf[:n], from); err != nil { - log.Printf("[ERR] bonjour: Failed to handle query: %v", err) + //log.Printf("[ERR] bonjour: Failed to handle query: %v", err) } } } @@ -284,200 +325,12 @@ func (s *Server) recv(c *net.UDPConn) { func (s *Server) parsePacket(packet []byte, from net.Addr) error { var msg dns.Msg if err := msg.Unpack(packet); err != nil { - //if err := UnpackDnsMsg(&msg, packet); err != nil { - log.Printf("[ERR] bonjour: Failed to unpack packet: %v", err) - return err - } - return s.handleQuery(&msg, from) -} - - -func unpackUint16(msg []byte, off int) (i uint16, off1 int, err error) { - if off+2 > len(msg) { - return 0, len(msg), errors.New("overflow") - } - return binary.BigEndian.Uint16(msg[off:]), off + 2, nil -} - -func unpackMsgHdr(msg []byte, off int) (dns.Header, int, error) { - var ( - dh dns.Header - err error - ) - dh.Id, off, err = unpackUint16(msg, off) - if err != nil { - return dh, off, err - } - dh.Bits, off, err = unpackUint16(msg, off) - if err != nil { - return dh, off, err - } - dh.Qdcount, off, err = unpackUint16(msg, off) - if err != nil { - return dh, off, err - } - dh.Ancount, off, err = unpackUint16(msg, off) - if err != nil { - return dh, off, err - } - dh.Nscount, off, err = unpackUint16(msg, off) - if err != nil { - return dh, off, err - } - dh.Arcount, off, err = unpackUint16(msg, off) - return dh, off, err -} - -const ( - headerSize = 12 - - // Header.Bits - _QR = 1 << 15 // query/response (response=1) - _AA = 1 << 10 // authoritative - _TC = 1 << 9 // truncated - _RD = 1 << 8 // recursion desired - _RA = 1 << 7 // recursion available - _Z = 1 << 6 // Z - _AD = 1 << 5 // authticated data - _CD = 1 << 4 // checking disabled -) - -func unpackQuestion(msg []byte, off int) (dns.Question, int, error) { - var ( - q dns.Question - err error - ) - q.Name, off, err = dns.UnpackDomainName(msg, off) - if err != nil { - return q, off, err - } - if off == len(msg) { - return q, off, nil - } - q.Qtype, off, err = unpackUint16(msg, off) - if err != nil { - return q, off, err - } - if off == len(msg) { - return q, off, nil - } - q.Qclass, off, err = unpackUint16(msg, off) - if off == len(msg) { - return q, off, nil - } - return q, off, err -} - -// unpackRRslice unpacks msg[off:] into an []RR. -// If we cannot unpack the whole array, then it will return nil -func unpackRRslice(l int, msg []byte, off int) (dst1 []dns.RR, off1 int, err error) { - var r dns.RR - // Don't pre-allocate, l may be under attacker control - var dst []dns.RR - for i := 0; i < l; i++ { - off1 := off - r, off, err = dns.UnpackRR(msg, off) - //println("RR slice", r.String()) - if err != nil { - off = len(msg) - break + if err != dns.ErrTruncated { + log.Printf("[ERR] bonjour: Failed to unpack packet: %v", err) } - // If offset does not increase anymore, l is a lie - if off1 == off { - l = i - break - } - dst = append(dst, r) - } - if err != nil && off == len(msg) { - dst = nil - } - return dst, off, err -} - -// Unpack unpacks a binary message to a Msg structure. -func UnpackDnsMsg(m *dns.Msg, msg []byte) (err error) { - var ( - dh dns.Header - off int - ) - if dh, off, err = unpackMsgHdr(msg, off); err != nil { return err } - - m.Id = dh.Id - m.Response = (dh.Bits & _QR) != 0 - m.Opcode = int(dh.Bits>>11) & 0xF - m.Authoritative = (dh.Bits & _AA) != 0 - m.Truncated = (dh.Bits & _TC) != 0 - m.RecursionDesired = (dh.Bits & _RD) != 0 - m.RecursionAvailable = (dh.Bits & _RA) != 0 - m.Zero = (dh.Bits & _Z) != 0 - m.AuthenticatedData = (dh.Bits & _AD) != 0 - m.CheckingDisabled = (dh.Bits & _CD) != 0 - m.Rcode = int(dh.Bits & 0xF) - - // If we are at the end of the message we should return *just* the - // header. This can still be useful to the caller. 9.9.9.9 sends these - // when responding with REFUSED for instance. - if off == len(msg) { - // reset sections before returning - m.Question, m.Answer, m.Ns, m.Extra = nil, nil, nil, nil - return nil - } - - // Qdcount, Ancount, Nscount, Arcount can't be trusted, as they are - // attacker controlled. This means we can't use them to pre-allocate - // slices. - m.Question = nil - for i := 0; i < int(dh.Qdcount); i++ { - off1 := off - var q dns.Question - q, off, err = unpackQuestion(msg, off) - //println("Unpack question ", q.Name, q.Qclass, q.Qtype) - if q.Qtype != dns.TypePTR { - //println("Skipping type", q.Qtype, dns.TypePTR) - return nil - } - if err != nil { - // Even if Truncated is set, we only will set ErrTruncated if we - // actually got the questions - return err - } - if off1 == off { // Offset does not increase anymore, dh.Qdcount is a lie! - dh.Qdcount = uint16(i) - break - } - m.Question = append(m.Question, q) - } - - m.Answer, off, err = unpackRRslice(int(dh.Ancount), msg, off) - //// The header counts might have been wrong so we need to update it - dh.Ancount = uint16(len(m.Answer)) - if err == nil { - m.Ns, off, err = unpackRRslice(int(dh.Nscount), msg, off) - } - //println("Answer count", dh.Ancount) - //// The header counts might have been wrong so we need to update it - //dh.Nscount = uint16(len(m.Ns)) - //if err == nil { - // m.Extra, off, err = unpackRRslice(int(dh.Arcount), msg, off) - //} - dh.Nscount = 0 - // The header counts might have been wrong so we need to update it - dh.Arcount = uint16(len(m.Extra)) - - - if off != len(msg) { - // TODO(miek) make this an error? - // use PackOpt to let people tell how detailed the error reporting should be? - // println("m: extra bytes in m packet", off, "<", len(msg)) - } else if m.Truncated { - // Whether we ran into a an error or not, we want to return that it - // was truncated - err = dns.ErrTruncated - } - return err + return s.handleQuery(&msg, from) } // handleQuery is used to handle an incoming query @@ -533,18 +386,13 @@ func (s *Server) handleQuestion(q dns.Question, resp *dns.Msg) error { return nil } - println("Match", q.Name) - println("Match", q.Name, "hostname", s.service.HostName) - //println("Match", q.Name, "to", s.service.ServiceName()) - //println("MatchInstance", q.Name, "to", s.service.ServiceInstanceName()) - //println("MatchService", q.Name, "to", s.service.ServiceTypeName()) - if (q.Name == s.service.HostName) { - println("Match", q.Name, "MATCHED MATCHED") - } + var name = strings.ToLower(q.Name) + + println("Match", name, "service", s.service.ServiceName()) + println("Match", name, "instance", s.service.ServiceInstanceName()) - switch q.Name { + switch name { case s.service.HostName: - println("Match", q.Name, "Matched to host", s.service.HostName) s.composeHostAnswers(resp, s.ttl) case s.service.ServiceName(): s.composeBrowsingAnswers(resp, s.ttl) From 0cbed9819468d65943487a295215a52cee0af485 Mon Sep 17 00:00:00 2001 From: Manuel Odendahl Date: Mon, 6 Aug 2018 16:41:10 -0400 Subject: [PATCH 3/6] :sparkles: Clear the question and answer regularly --- server.go | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/server.go b/server.go index 1a3bebd..e0d2cc0 100644 --- a/server.go +++ b/server.go @@ -541,7 +541,7 @@ func (s *Server) composeLookupAnswers(resp *dns.Msg, ttl uint32) { }, Ptr: s.service.ServiceName(), } - resp.Answer = append(resp.Answer, srv, txt, ptr, dnssd) + resp.Answer = append(resp.Answer, ptr, txt, srv, dnssd) if s.service.AddrIPv4 != nil { a := &dns.A{ @@ -567,6 +567,7 @@ func (s *Server) composeLookupAnswers(resp *dns.Msg, ttl uint32) { } resp.Extra = append(resp.Extra, aaaa) } + resp.Question = make([]dns.Question, 0) } func (s *Server) serviceTypeName(resp *dns.Msg, ttl uint32) { @@ -639,8 +640,8 @@ func (s *Server) probe() { // packet loss, a responder MAY send up to eight unsolicited responses, // provided that the interval between unsolicited responses increases by // at least a factor of two with every response sent. - timeout := 1 * time.Second for !s.shouldShutdown { + timeout := 1 * time.Second for i := 0; i < 3 && !s.shouldShutdown; i++ { if err := s.multicastResponse(resp); err != nil { log.Println("[ERR] bonjour: failed to send announcement:", err.Error()) From 1ca7850155713b23b18a4853a1f27c3fbea3c86f Mon Sep 17 00:00:00 2001 From: Manuel Odendahl Date: Mon, 6 Aug 2018 17:32:04 -0400 Subject: [PATCH 4/6] :sparkles: Add support for multiple IP addresses --- client.go | 8 ++-- server.go | 129 +++++++++++++---------------------------------------- service.go | 16 +++---- 3 files changed, 44 insertions(+), 109 deletions(-) diff --git a/client.go b/client.go index 05ec013..7120841 100644 --- a/client.go +++ b/client.go @@ -202,14 +202,14 @@ func (c *client) mainloop(params *LookupParams) { entries[rr.Hdr.Name].TTL = rr.Hdr.Ttl case *dns.A: for k, e := range entries { - if e.HostName == rr.Hdr.Name && entries[k].AddrIPv4 == nil { - entries[k].AddrIPv4 = rr.A + if e.HostName == rr.Hdr.Name { + entries[k].AddrsIPv4 = append(entries[k].AddrsIPv4, rr.A) } } case *dns.AAAA: for k, e := range entries { - if e.HostName == rr.Hdr.Name && entries[k].AddrIPv6 == nil { - entries[k].AddrIPv6 = rr.AAAA + if e.HostName == rr.Hdr.Name { + entries[k].AddrsIPv6 = append(entries[k].AddrsIPv6, rr.AAAA) } } } diff --git a/server.go b/server.go index e0d2cc0..dc90aa4 100644 --- a/server.go +++ b/server.go @@ -5,10 +5,8 @@ import ( "log" "math/rand" "net" - "os" - "strings" - "runtime" - "sync" + "strings" + "sync" "syscall" "time" @@ -42,70 +40,9 @@ var ( } ) -// Register a service by given arguments. This call will take the system's hostname -// and lookup IP by that hostname. -func Register(instance, service, domain string, port int, text []string, iface *net.Interface) (*Server, error) { - entry := NewServiceEntry(instance, service, domain) - entry.Port = port - entry.Text = text - - if entry.Instance == "" { - return nil, fmt.Errorf("Missing service instance name") - } - if entry.Service == "" { - return nil, fmt.Errorf("Missing service name") - } - if entry.Domain == "" { - entry.Domain = "local" - } - if entry.Port == 0 { - return nil, fmt.Errorf("Missing port") - } - - var err error - if entry.HostName == "" { - entry.HostName, err = os.Hostname() - if err != nil { - return nil, fmt.Errorf("Could not determine host") - } - } - entry.HostName = fmt.Sprintf("%s.", trimDot(entry.HostName)) - - addrs, err := net.LookupIP(entry.HostName) - if err != nil { - // Try appending the host domain suffix and lookup again - // (required for Linux-based hosts) - tmpHostName := fmt.Sprintf("%s%s.", entry.HostName, entry.Domain) - addrs, err = net.LookupIP(tmpHostName) - if err != nil { - return nil, fmt.Errorf("Could not determine host IP addresses for %s", entry.HostName) - } - } - for i := 0; i < len(addrs); i++ { - if ipv4 := addrs[i].To4(); ipv4 != nil { - entry.AddrIPv4 = addrs[i] - } else if ipv6 := addrs[i].To16(); ipv6 != nil { - entry.AddrIPv6 = addrs[i] - } - } - - s, err := newServer(iface) - if err != nil { - return nil, err - } - - s.service = entry - s.service.HostName = strings.ToLower(s.service.HostName) - - go s.mainloop() - go s.probe() - - return s, nil -} - // Register a service proxy by given argument. This call will skip the hostname/IP lookup and // will use the provided values. -func RegisterProxy(instance, service, domain string, port int, host, ip string, text []string, iface *net.Interface) (*Server, error) { +func RegisterProxy(instance, service, domain string, port int, host string, ips []net.IP, text []string, iface *net.Interface) (*Server, error) { entry := NewServiceEntry(instance, service, domain) entry.Port = port entry.Text = text @@ -131,15 +68,17 @@ func RegisterProxy(instance, service, domain string, port int, host, ip string, entry.HostName = fmt.Sprintf("%s.%s.", trimDot(entry.HostName), trimDot(entry.Domain)) } - ipAddr := net.ParseIP(ip) - if ipAddr == nil { - return nil, fmt.Errorf("Failed to parse given IP: %v", ip) - } else if ipv4 := ipAddr.To4(); ipv4 != nil { - entry.AddrIPv4 = ipAddr - } else if ipv6 := ipAddr.To16(); ipv6 != nil { - entry.AddrIPv4 = ipAddr - } else { - return nil, fmt.Errorf("The IP is neither IPv4 nor IPv6: %#v", ipAddr) + + for _, ipAddr := range(ips) { + if ipAddr == nil { + return nil, fmt.Errorf("Failed to parse given IP: %v", ipAddr) + } else if ipv4 := ipAddr.To4(); ipv4 != nil { + entry.AddrsIPv4 = append(entry.AddrsIPv4, ipv4) + } else if ipv6 := ipAddr.To16(); ipv6 != nil { + entry.AddrsIPv6 = append(entry.AddrsIPv6, ipv4) + } else { + return nil, fmt.Errorf("The IP is neither IPv4 nor IPv6: %#v", ipAddr) + } } s, err := newServer(iface) @@ -388,9 +327,6 @@ func (s *Server) handleQuestion(q dns.Question, resp *dns.Msg) error { var name = strings.ToLower(q.Name) - println("Match", name, "service", s.service.ServiceName()) - println("Match", name, "instance", s.service.ServiceInstanceName()) - switch name { case s.service.HostName: s.composeHostAnswers(resp, s.ttl) @@ -406,31 +342,30 @@ func (s *Server) handleQuestion(q dns.Question, resp *dns.Msg) error { } func (s *Server) composeHostAnswers(resp *dns.Msg, ttl uint32) { - if s.service.AddrIPv4 != nil { + for _, ipAddr := range s.service.AddrsIPv4 { a := &dns.A{ Hdr: dns.RR_Header{ Name: s.service.HostName, Rrtype: dns.TypeA, - Class: dns.ClassINET | dns.TypeTA, + Class: dns.ClassINET, Ttl: ttl, }, - A: s.service.AddrIPv4, + A: ipAddr, } resp.Answer = append(resp.Answer, a) } - if s.service.AddrIPv6 != nil { + for _, ipAddr := range s.service.AddrsIPv6 { aaaa := &dns.AAAA{ Hdr: dns.RR_Header{ Name: s.service.HostName, Rrtype: dns.TypeAAAA, - Class: dns.ClassINET | dns.TypeTA, + Class: dns.ClassINET, Ttl: ttl, }, - AAAA: s.service.AddrIPv6, + AAAA: ipAddr, } resp.Answer = append(resp.Answer, aaaa) } - resp.Question = make([]dns.Question, 0) } @@ -469,7 +404,7 @@ func (s *Server) composeBrowsingAnswers(resp *dns.Msg, ttl uint32) { } resp.Extra = append(resp.Extra, srv, txt) - if s.service.AddrIPv4 != nil { + for _, ipAddr := range s.service.AddrsIPv4 { a := &dns.A{ Hdr: dns.RR_Header{ Name: s.service.HostName, @@ -477,11 +412,11 @@ func (s *Server) composeBrowsingAnswers(resp *dns.Msg, ttl uint32) { Class: dns.ClassINET, Ttl: ttl, }, - A: s.service.AddrIPv4, + A: ipAddr, } resp.Extra = append(resp.Extra, a) } - if s.service.AddrIPv6 != nil { + for _, ipAddr := range s.service.AddrsIPv6 { aaaa := &dns.AAAA{ Hdr: dns.RR_Header{ Name: s.service.HostName, @@ -489,7 +424,7 @@ func (s *Server) composeBrowsingAnswers(resp *dns.Msg, ttl uint32) { Class: dns.ClassINET, Ttl: ttl, }, - AAAA: s.service.AddrIPv6, + AAAA: ipAddr, } resp.Extra = append(resp.Extra, aaaa) } @@ -543,27 +478,27 @@ func (s *Server) composeLookupAnswers(resp *dns.Msg, ttl uint32) { } resp.Answer = append(resp.Answer, ptr, txt, srv, dnssd) - if s.service.AddrIPv4 != nil { + for _, ipAddr := range s.service.AddrsIPv4 { a := &dns.A{ Hdr: dns.RR_Header{ Name: s.service.HostName, Rrtype: dns.TypeA, - Class: dns.ClassINET | cache_flush, - Ttl: 120, + Class: dns.ClassINET, + Ttl: ttl, }, - A: s.service.AddrIPv4, + A: ipAddr, } resp.Extra = append(resp.Extra, a) } - if s.service.AddrIPv6 != nil { + for _, ipAddr := range s.service.AddrsIPv6 { aaaa := &dns.AAAA{ Hdr: dns.RR_Header{ Name: s.service.HostName, Rrtype: dns.TypeAAAA, - Class: dns.ClassINET | cache_flush, - Ttl: 120, + Class: dns.ClassINET, + Ttl: ttl, }, - AAAA: s.service.AddrIPv6, + AAAA: ipAddr, } resp.Extra = append(resp.Extra, aaaa) } diff --git a/service.go b/service.go index 57c188e..55697db 100644 --- a/service.go +++ b/service.go @@ -76,12 +76,12 @@ func NewLookupParams(instance, service, domain string, entries chan<- *ServiceEn // used to answer multicast queries. type ServiceEntry struct { ServiceRecord - HostName string `json:"hostname"` // Host machine DNS name - Port int `json:"port"` // Service Port - Text []string `json:"text"` // Service info served as a TXT record - TTL uint32 `json:"ttl"` // TTL of the service record - AddrIPv4 net.IP `json:"-"` // Host machine IPv4 address - AddrIPv6 net.IP `json:"-"` // Host machine IPv6 address + HostName string `json:"hostname"` // Host machine DNS name + Port int `json:"port"` // Service Port + Text []string `json:"text"` // Service info served as a TXT record + TTL uint32 `json:"ttl"` // TTL of the service record + AddrsIPv4 []net.IP `json:"-"` // Host machine IPv4 address + AddrsIPv6 []net.IP `json:"-"` // Host machine IPv6 address } // Constructs a ServiceEntry structure by given arguments @@ -92,7 +92,7 @@ func NewServiceEntry(instance, service, domain string) *ServiceEntry { 0, []string{}, 0, - nil, - nil, + []net.IP{}, + []net.IP{}, } } From 36aebc51dc716ef94e9728424860d7550f3adfef Mon Sep 17 00:00:00 2001 From: Manuel Odendahl Date: Mon, 6 Aug 2018 17:33:02 -0400 Subject: [PATCH 5/6] :art: Whitespace --- server.go | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/server.go b/server.go index dc90aa4..dabe87c 100644 --- a/server.go +++ b/server.go @@ -5,8 +5,8 @@ import ( "log" "math/rand" "net" - "strings" - "sync" + "strings" + "sync" "syscall" "time" @@ -68,8 +68,7 @@ func RegisterProxy(instance, service, domain string, port int, host string, ips entry.HostName = fmt.Sprintf("%s.%s.", trimDot(entry.HostName), trimDot(entry.Domain)) } - - for _, ipAddr := range(ips) { + for _, ipAddr := range (ips) { if ipAddr == nil { return nil, fmt.Errorf("Failed to parse given IP: %v", ipAddr) } else if ipv4 := ipAddr.To4(); ipv4 != nil { @@ -250,7 +249,6 @@ func (s *Server) recv(c *net.UDPConn) { continue } - i++ bytes += n From 5d30ee26ba51cc37bd3b452c3739e61e3b0e7ea6 Mon Sep 17 00:00:00 2001 From: Manuel Odendahl Date: Mon, 6 Aug 2018 17:36:00 -0400 Subject: [PATCH 6/6] :sparkles: Put back ipv6 support --- server.go | 41 ++++++++++++++++++++++++++++------------- 1 file changed, 28 insertions(+), 13 deletions(-) diff --git a/server.go b/server.go index dabe87c..43bdc66 100644 --- a/server.go +++ b/server.go @@ -12,6 +12,7 @@ import ( "github.com/miekg/dns" "golang.org/x/net/ipv4" + "golang.org/x/net/ipv6" ) var ( @@ -120,15 +121,25 @@ func newServer(iface *net.Interface) (*Server, error) { } // Join multicast groups to receive announcements - p1 := ipv4.NewPacketConn(ipv4conn) - //p2 := ipv6.NewPacketConn(ipv6conn) + var p1 *ipv4.PacketConn + var p2 *ipv6.PacketConn + if ipv4conn != nil { + p1 = ipv4.NewPacketConn(ipv4conn) + } + if ipv6conn != nil { + p2 = ipv6.NewPacketConn(ipv6conn) + } if iface != nil { - if err := p1.JoinGroup(iface, &net.UDPAddr{IP: mdnsGroupIPv4}); err != nil { - return nil, err + if p1 != nil { + if err := p1.JoinGroup(iface, &net.UDPAddr{IP: mdnsGroupIPv4}); err != nil { + return nil, err + } + } + if p2 != nil { + if err := p2.JoinGroup(iface, &net.UDPAddr{IP: mdnsGroupIPv6}); err != nil { + return nil, err + } } - //if err := p2.JoinGroup(iface, &net.UDPAddr{IP: mdnsGroupIPv6}); err != nil { - // return nil, err - //} } else { ifaces, err := net.Interfaces() if err != nil { @@ -136,12 +147,16 @@ func newServer(iface *net.Interface) (*Server, error) { } errCount1, errCount2 := 0, 0 for _, iface := range ifaces { - if err := p1.JoinGroup(&iface, &net.UDPAddr{IP: mdnsGroupIPv4}); err != nil { - errCount1++ + if p1 != nil { + if err := p1.JoinGroup(&iface, &net.UDPAddr{IP: mdnsGroupIPv4}); err != nil { + errCount1++ + } + } + if p2 != nil { + if err := p2.JoinGroup(&iface, &net.UDPAddr{IP: mdnsGroupIPv6}); err != nil { + errCount2++ + } } - //if err := p2.JoinGroup(&iface, &net.UDPAddr{IP: mdnsGroupIPv6}); err != nil { - // errCount2++ - //} } if len(ifaces) == errCount1 && len(ifaces) == errCount2 { return nil, fmt.Errorf("Failed to join multicast group on all interfaces!") @@ -150,7 +165,7 @@ func newServer(iface *net.Interface) (*Server, error) { s := &Server{ ipv4conn: ipv4conn, - //ipv6conn: ipv6conn, + ipv6conn: ipv6conn, ttl: 3200, }