From cecb8534f61de16aaa1a1a84eaaeb14dfdf67448 Mon Sep 17 00:00:00 2001 From: Bob Halley Date: Fri, 16 Feb 2024 08:46:24 -0800 Subject: [PATCH] Further improve CVE fix coverage to 100% for sync and async. (cherry picked from commit a1a998938b7370dae41784f8bc0a841dc2addba9) --- tests/test_async.py | 184 +++++++++++++++++++++++++++++++++++++++++++- tests/test_query.py | 21 +++++ 2 files changed, 204 insertions(+), 1 deletion(-) diff --git a/tests/test_async.py b/tests/test_async.py index 4ea23015..ba2078cd 100644 --- a/tests/test_async.py +++ b/tests/test_async.py @@ -18,7 +18,6 @@ import asyncio import random import socket -import sys import time import unittest @@ -28,6 +27,7 @@ import dns.message import dns.name import dns.query +import dns.rcode import dns.rdataclass import dns.rdatatype import dns.resolver @@ -664,3 +664,185 @@ def async_run(self, afunc): except ImportError: pass + + +class MockSock: + def __init__(self, wire1, from1, wire2, from2): + self.family = socket.AF_INET + self.first_time = True + self.wire1 = wire1 + self.from1 = from1 + self.wire2 = wire2 + self.from2 = from2 + + async def sendto(self, data, where, timeout): + return len(data) + + async def recvfrom(self, bufsize, expiration): + if self.first_time: + self.first_time = False + return self.wire1, self.from1 + else: + return self.wire2, self.from2 + + +class IgnoreErrors(unittest.TestCase): + def setUp(self): + self.q = dns.message.make_query("example.", "A") + self.good_r = dns.message.make_response(self.q) + self.good_r.set_rcode(dns.rcode.NXDOMAIN) + self.good_r_wire = self.good_r.to_wire() + dns.asyncbackend.set_default_backend("asyncio") + + def async_run(self, afunc): + return asyncio.run(afunc()) + + async def mock_receive( + self, + wire1, + from1, + wire2, + from2, + ignore_unexpected=True, + ignore_errors=True, + ): + s = MockSock(wire1, from1, wire2, from2) + (r, when, _) = await dns.asyncquery.receive_udp( + s, + ("127.0.0.1", 53), + time.time() + 2, + ignore_unexpected=ignore_unexpected, + ignore_errors=ignore_errors, + query=self.q, + ) + self.assertEqual(r, self.good_r) + + def test_good_mock(self): + async def run(): + await self.mock_receive(self.good_r_wire, ("127.0.0.1", 53), None, None) + + self.async_run(run) + + def test_bad_address(self): + async def run(): + await self.mock_receive( + self.good_r_wire, ("127.0.0.2", 53), self.good_r_wire, ("127.0.0.1", 53) + ) + + self.async_run(run) + + def test_bad_address_not_ignored(self): + async def abad(): + await self.mock_receive( + self.good_r_wire, + ("127.0.0.2", 53), + self.good_r_wire, + ("127.0.0.1", 53), + ignore_unexpected=False, + ) + + def bad(): + self.async_run(abad) + + self.assertRaises(dns.query.UnexpectedSource, bad) + + def test_not_response_not_ignored_udp_level(self): + async def abad(): + bad_r = dns.message.make_response(self.q) + bad_r.id += 1 + bad_r_wire = bad_r.to_wire() + s = MockSock( + bad_r_wire, ("127.0.0.1", 53), self.good_r_wire, ("127.0.0.1", 53) + ) + await dns.asyncquery.udp(self.good_r, "127.0.0.1", sock=s) + + def bad(): + self.async_run(abad) + + self.assertRaises(dns.query.BadResponse, bad) + + def test_bad_id(self): + async def run(): + bad_r = dns.message.make_response(self.q) + bad_r.id += 1 + bad_r_wire = bad_r.to_wire() + await self.mock_receive( + bad_r_wire, ("127.0.0.1", 53), self.good_r_wire, ("127.0.0.1", 53) + ) + + self.async_run(run) + + def test_bad_id_not_ignored(self): + bad_r = dns.message.make_response(self.q) + bad_r.id += 1 + bad_r_wire = bad_r.to_wire() + + async def abad(): + (r, wire) = await self.mock_receive( + bad_r_wire, + ("127.0.0.1", 53), + self.good_r_wire, + ("127.0.0.1", 53), + ignore_errors=False, + ) + + def bad(): + self.async_run(abad) + + self.assertRaises(AssertionError, bad) + + def test_bad_wire(self): + async def run(): + bad_r = dns.message.make_response(self.q) + bad_r.id += 1 + bad_r_wire = bad_r.to_wire() + await self.mock_receive( + bad_r_wire[:10], ("127.0.0.1", 53), self.good_r_wire, ("127.0.0.1", 53) + ) + + self.async_run(run) + + def test_bad_wire_not_ignored(self): + bad_r = dns.message.make_response(self.q) + bad_r.id += 1 + bad_r_wire = bad_r.to_wire() + + async def abad(): + await self.mock_receive( + bad_r_wire[:10], + ("127.0.0.1", 53), + self.good_r_wire, + ("127.0.0.1", 53), + ignore_errors=False, + ) + + def bad(): + self.async_run(abad) + + self.assertRaises(dns.message.ShortHeader, bad) + + def test_trailing_wire(self): + async def run(): + wire = self.good_r_wire + b"abcd" + await self.mock_receive( + wire, ("127.0.0.1", 53), self.good_r_wire, ("127.0.0.1", 53) + ) + + self.async_run(run) + + def test_trailing_wire_not_ignored(self): + wire = self.good_r_wire + b"abcd" + + async def abad(): + await self.mock_receive( + wire, + ("127.0.0.1", 53), + self.good_r_wire, + ("127.0.0.1", 53), + ignore_errors=False, + ) + + def bad(): + self.async_run(abad) + + self.assertRaises(dns.message.TrailingJunk, bad) diff --git a/tests/test_query.py b/tests/test_query.py index a47daa45..1039a14e 100644 --- a/tests/test_query.py +++ b/tests/test_query.py @@ -683,6 +683,14 @@ def mock(sock, max_size, expiration): dns.query._udp_recv = saved +class MockSock: + def __init__(self): + self.family = socket.AF_INET + + def sendto(self, data, where): + return len(data) + + class IgnoreErrors(unittest.TestCase): def setUp(self): self.q = dns.message.make_query("example.", "A") @@ -758,6 +766,19 @@ def bad(): self.assertRaises(AssertionError, bad) + def test_not_response_not_ignored_udp_level(self): + def bad(): + bad_r = dns.message.make_response(self.q) + bad_r.id += 1 + bad_r_wire = bad_r.to_wire() + with mock_udp_recv( + bad_r_wire, ("127.0.0.1", 53), self.good_r_wire, ("127.0.0.1", 53) + ): + s = MockSock() + dns.query.udp(self.good_r, "127.0.0.1", sock=s) + + self.assertRaises(dns.query.BadResponse, bad) + def test_bad_wire(self): bad_r = dns.message.make_response(self.q) bad_r.id += 1