diff --git a/coredns/plugin/handler_test.go b/coredns/plugin/handler_test.go index 30845624..0e4e0d60 100644 --- a/coredns/plugin/handler_test.go +++ b/coredns/plugin/handler_test.go @@ -64,6 +64,18 @@ var ( } port2 = mcsv1a1.ServicePort{ + Name: "udp", + Protocol: v1.ProtocolUDP, + Port: 42, + } + + port3 = mcsv1a1.ServicePort{ + Name: "tcp", + Protocol: v1.ProtocolTCP, + Port: 42, + } + + port4 = mcsv1a1.ServicePort{ Name: "dns", Protocol: v1.ProtocolUDP, Port: 53, @@ -886,14 +898,18 @@ func testSRVMultiplePorts() { t.lh.Resolver.PutServiceImport(newServiceImport(namespace1, service1, mcsv1a1.ClusterSetIP)) - t.lh.Resolver.PutEndpointSlices(newEndpointSlice(namespace1, service1, clusterID, []mcsv1a1.ServicePort{port1, port2}, + t.lh.Resolver.PutEndpointSlices(newEndpointSlice(namespace1, service1, clusterID, []mcsv1a1.ServicePort{port1, port2, port3}, newEndpoint(endpointIP, "", true))) + t.lh.Resolver.PutEndpointSlices(newEndpointSlice(namespace1, service1, clusterID2, + []mcsv1a1.ServicePort{port1, port2, port3, port4}, + newEndpoint(serviceIP2, "", true))) + rec = dnstest.NewRecorder(&test.ResponseWriter{}) }) Context("a DNS query of type SRV", func() { - Specify("without a port name should return all the ports", func() { + Specify("without a port name should return all the unique ports", func() { qname := fmt.Sprintf("%s.%s.svc.clusterset.local.", service1, namespace1) t.executeTestCase(rec, test.Case{ @@ -929,9 +945,21 @@ func testSRVMultiplePorts() { test.SRV(fmt.Sprintf("%s 5 IN SRV 0 50 %d %s.%s.svc.clusterset.local.", qname, port2.Port, service1, namespace1)), }, }) + + qname = fmt.Sprintf("%s.%s.%s.%s.svc.clusterset.local.", port3.Name, port3.Protocol, service1, namespace1) + + t.executeTestCase(rec, test.Case{ + Qname: qname, + Qtype: dns.TypeSRV, + Rcode: dns.RcodeSuccess, + Answer: []dns.RR{ + test.SRV(fmt.Sprintf("%s 5 IN SRV 0 50 %d %s.%s.svc.clusterset.local.", qname, + port3.Port, service1, namespace1)), + }, + }) }) - Specify("with a DNS cluster name requested should return all the ports from the cluster", func() { + Specify("with a DNS cluster name requested should return all the unique ports from the cluster", func() { qname := fmt.Sprintf("%s.%s.%s.svc.clusterset.local.", clusterID, service1, namespace1) t.executeTestCase(rec, test.Case{ @@ -943,6 +971,19 @@ func testSRVMultiplePorts() { test.SRV(fmt.Sprintf("%s 5 IN SRV 0 50 %d %s", qname, port1.Port, qname)), }, }) + + qname = fmt.Sprintf("%s.%s.%s.svc.clusterset.local.", clusterID2, service1, namespace1) + + t.executeTestCase(rec, test.Case{ + Qname: qname, + Qtype: dns.TypeSRV, + Rcode: dns.RcodeSuccess, + Answer: []dns.RR{ + test.SRV(fmt.Sprintf("%s 5 IN SRV 0 50 %d %s", qname, port2.Port, qname)), + test.SRV(fmt.Sprintf("%s 5 IN SRV 0 50 %d %s", qname, port4.Port, qname)), + test.SRV(fmt.Sprintf("%s 5 IN SRV 0 50 %d %s", qname, port1.Port, qname)), + }, + }) }) Specify("with a port name requested with underscore prefix should return the port", func() { diff --git a/coredns/plugin/record.go b/coredns/plugin/record.go index e7afdcb9..8c43023d 100644 --- a/coredns/plugin/record.go +++ b/coredns/plugin/record.go @@ -25,6 +25,7 @@ import ( "github.com/coredns/coredns/request" "github.com/miekg/dns" "github.com/submariner-io/lighthouse/coredns/resolver" + "k8s.io/utils/set" "sigs.k8s.io/mcs-api/pkg/apis/v1alpha1" ) @@ -83,7 +84,15 @@ func (lh *Lighthouse) createSRVRecords(dnsrecords []resolver.DNSRecord, state *r target = dnsRecord.HostName + "." + target } + portsSeen := set.New[int32]() + for _, port := range reqPorts { + if portsSeen.Has(port.Port) { + continue + } + + portsSeen.Insert(port.Port) + record := &dns.SRV{ Hdr: dns.RR_Header{Name: state.QName(), Rrtype: dns.TypeSRV, Class: state.QClass(), Ttl: lh.TTL}, Priority: 0, @@ -91,6 +100,7 @@ func (lh *Lighthouse) createSRVRecords(dnsrecords []resolver.DNSRecord, state *r Port: uint16(port.Port), //nolint:gosec // Need to ignore integer conversion error Target: target, } + records = append(records, record) } }