forked from slackhq/nebula
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdns_server.go
125 lines (108 loc) · 2.63 KB
/
dns_server.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
package nebula
import (
"fmt"
"net"
"strconv"
"sync"
"github.com/miekg/dns"
)
// This whole thing should be rewritten to use context
var dnsR *dnsRecords
type dnsRecords struct {
sync.RWMutex
dnsMap map[string]string
hostMap *HostMap
}
func newDnsRecords(hostMap *HostMap) *dnsRecords {
return &dnsRecords{
dnsMap: make(map[string]string),
hostMap: hostMap,
}
}
func (d *dnsRecords) Query(data string) string {
d.RLock()
if r, ok := d.dnsMap[data]; ok {
d.RUnlock()
return r
}
d.RUnlock()
return ""
}
func (d *dnsRecords) QueryCert(data string) string {
ip := net.ParseIP(data[:len(data)-1])
if ip == nil {
return ""
}
iip := ip2int(ip)
hostinfo, err := d.hostMap.QueryVpnIP(iip)
if err != nil {
return ""
}
q := hostinfo.GetCert()
if q == nil {
return ""
}
cert := q.Details
c := fmt.Sprintf("\"Name: %s\" \"Ips: %s\" \"Subnets %s\" \"Groups %s\" \"NotBefore %s\" \"NotAFter %s\" \"PublicKey %x\" \"IsCA %t\" \"Issuer %s\"", cert.Name, cert.Ips, cert.Subnets, cert.Groups, cert.NotBefore, cert.NotAfter, cert.PublicKey, cert.IsCA, cert.Issuer)
return c
}
func (d *dnsRecords) Add(host, data string) {
d.Lock()
d.dnsMap[host] = data
d.Unlock()
}
func parseQuery(m *dns.Msg, w dns.ResponseWriter) {
for _, q := range m.Question {
switch q.Qtype {
case dns.TypeA:
l.Debugf("Query for A %s", q.Name)
ip := dnsR.Query(q.Name)
if ip != "" {
rr, err := dns.NewRR(fmt.Sprintf("%s A %s", q.Name, ip))
if err == nil {
m.Answer = append(m.Answer, rr)
}
}
case dns.TypeTXT:
a, _, _ := net.SplitHostPort(w.RemoteAddr().String())
b := net.ParseIP(a)
// We don't answer these queries from non nebula nodes or localhost
//l.Debugf("Does %s contain %s", b, dnsR.hostMap.vpnCIDR)
if !dnsR.hostMap.vpnCIDR.Contains(b) && a != "127.0.0.1" {
return
}
l.Debugf("Query for TXT %s", q.Name)
ip := dnsR.QueryCert(q.Name)
if ip != "" {
rr, err := dns.NewRR(fmt.Sprintf("%s TXT %s", q.Name, ip))
if err == nil {
m.Answer = append(m.Answer, rr)
}
}
}
}
}
func handleDnsRequest(w dns.ResponseWriter, r *dns.Msg) {
m := new(dns.Msg)
m.SetReply(r)
m.Compress = false
switch r.Opcode {
case dns.OpcodeQuery:
parseQuery(m, w)
}
w.WriteMsg(m)
}
func dnsMain(hostMap *HostMap) {
dnsR = newDnsRecords(hostMap)
// attach request handler func
dns.HandleFunc(".", handleDnsRequest)
// start server
port := 53
server := &dns.Server{Addr: ":" + strconv.Itoa(port), Net: "udp"}
l.Debugf("Starting DNS responder at %d\n", port)
err := server.ListenAndServe()
defer server.Shutdown()
if err != nil {
l.Errorf("Failed to start server: %s\n ", err.Error())
}
}