diff --git a/dns.py b/dns.py index 9bb52da..eab2575 100644 --- a/dns.py +++ b/dns.py @@ -16,7 +16,7 @@ # limitations under the License. import socket -from dnslib import DNSRecord, DNSHeader, RR, A +from dnslib import DNSRecord, DNSHeader, RR, A, RCODE import threading import docker @@ -300,7 +300,7 @@ def resolve_dnsA_to_ip(network_data, networks, domain): # DNS Server class DNSServer: - def __init__(self, ip="0.0.0.0", port=53): + def __init__(self, ip="0.0.0.0", port=53, upstream_dns="8.8.8.8"): print_debug(f"Initializing DNS server on {ip}:{port}") self.ip = ip self.port = port @@ -308,6 +308,7 @@ def __init__(self, ip="0.0.0.0", port=53): self.server.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, 4096) # Increase buffer size self.server.bind((self.ip, self.port)) self.executor = ThreadPoolExecutor(max_workers=10) # Limit the number of concurrent workers + self.upstream_dns = upstream_dns print_debug(f"DNS server initialized on {self.ip}:{self.port}") def handle_request(self, data, addr): @@ -316,8 +317,6 @@ def handle_request(self, data, addr): print_debug(f"Received request from {addr}: {request.q.qname}") # Create DNS response header - reply = DNSRecord(DNSHeader(id=request.header.id, qr=1, aa=1, ra=0), q=request.q) - original_domain = str(request.q.qname) # Find the domain in the DNS_TABLE and return the corresponding IP address @@ -344,9 +343,33 @@ def handle_request(self, data, addr): # Resolve DNS A records to IP addresses dnsA_records = resolve_dnsA_to_ip(network_data, networks, domain) - for ip in dnsA_records: - reply.add_answer(RR(original_domain, rdata=A(ip))) - print_debug(f"Added response for {original_domain} -> {ip}") + if len(dnsA_records) == 0: + # fallback to upstream DNS server + try: + # Convert the request to binary format + query_data = request.pack() + + # Create a socket and send the query to the upstream server + sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + sock.settimeout(2) + + # Send the request to the upstream DNS server + sock.sendto(query_data, (self.upstream_dns, 53)) + + # Receive the response from the upstream server + data, _ = sock.recvfrom(4096) + sock.close() + + # Parse the response + reply = DNSRecord.parse(data) + except Exception as e: + reply = request.reply() + reply.header.rcode = RCODE.SERVFAIL + else: + reply = DNSRecord(DNSHeader(id=request.header.id, qr=1, aa=1, ra=0), q=request.q) + for ip in dnsA_records: + reply.add_answer(RR(original_domain, rdata=A(ip))) + print_debug(f"Added response for {original_domain} -> {ip}") else: print_debug(f"IP {request_addr} not found in any Docker network.") @@ -399,7 +422,7 @@ def stop(self): bind_ip = os.environ['BIND_IP'] # Initialize DNS server - server = DNSServer(ip=bind_ip, port=53) + server = DNSServer(ip=bind_ip, port=53, upstream_dns=os.getenv('UPSTREAM_DNS', '8.8.8.8')) # Start the DNS server in the main thread server.start()