diff --git a/tests/service_mocks/test_udp_service_mock.py b/tests/service_mocks/test_udp_service_mock.py new file mode 100644 index 0000000..3748abb --- /dev/null +++ b/tests/service_mocks/test_udp_service_mock.py @@ -0,0 +1,23 @@ +import socket + +from threat9_test_bed.service_mocks import UDPServiceMock + + +def test_udp_service_mock_add_banner(): + with UDPServiceMock("127.0.0.1", 8023) as target: + assert target.host == "127.0.0.1" + assert target.port == 8023 + + mocked_doo = target.get_command_mock(b"doo") + mocked_doo.return_value = b"where are you?" + with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s: + s.connect((target.host, target.port)) + s.send(b"doo") + assert s.recv(1024) == b"where are you?" + + mocked_scoo = target.get_command_mock(b"scoo") + mocked_scoo.return_value = b"bee" + with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s: + s.connect((target.host, target.port)) + s.send(b"scoo") + assert s.recv(1024) == b"bee" diff --git a/threat9_test_bed/service_mocks/__init__.py b/threat9_test_bed/service_mocks/__init__.py index 035f90e..79f8dcc 100644 --- a/threat9_test_bed/service_mocks/__init__.py +++ b/threat9_test_bed/service_mocks/__init__.py @@ -2,3 +2,4 @@ from .http_service_mock import HttpServiceMock # noqa: F401 from .tcp_service_mock import TCPServiceMock # noqa: F401 from .telnet_service_mock import TelnetServiceMock # noqa: F401 +from .udp_service_mock import UDPServiceMock # noqa: F401 diff --git a/threat9_test_bed/service_mocks/base_service.py b/threat9_test_bed/service_mocks/base_service.py index f4ec0d9..01d5712 100644 --- a/threat9_test_bed/service_mocks/base_service.py +++ b/threat9_test_bed/service_mocks/base_service.py @@ -6,6 +6,9 @@ class BaseService: + + socket_type = socket.SOCK_STREAM + def __init__(self, host: str, port: int): self.host = host self.port, self.dibbed_port_socket = self.dib_port(port) @@ -14,7 +17,7 @@ def _wait_for_service(self): elapsed_time = 0 start_time = time.time() while elapsed_time < 5: - s = socket.socket() + s = socket.socket(type=self.socket_type) s.settimeout(1) try: s.connect((self.host, self.port)) diff --git a/threat9_test_bed/service_mocks/udp_service_mock.py b/threat9_test_bed/service_mocks/udp_service_mock.py new file mode 100644 index 0000000..b5a3332 --- /dev/null +++ b/threat9_test_bed/service_mocks/udp_service_mock.py @@ -0,0 +1,33 @@ +from logging import getLogger +import socket +import threading +from unittest import mock + +from ..udp_service.udp_server import UDPHandler, UDPServer +from .base_service import BaseService + +logger = getLogger(__name__) + + +class UDPServiceMock(BaseService): + + socket_type = socket.SOCK_DGRAM + + def __init__(self, host: str, port: int): + super().__init__(host, port) + self.server = UDPServer((self.host, self.port), UDPHandler, False) + self.server_thread = threading.Thread(target=self.server.serve_forever) + + def start(self): + self.server.server_bind() + self.server.server_activate() + self.server_thread.start() + + def teardown(self): + self.server.shutdown() + self.server_thread.join() + self.server.server_close() + + def get_command_mock(self, command: bytes) -> mock.Mock: + logger.debug(f"{self} mock for '{command}' has been added.") + return self.server.get_command_mock(command) diff --git a/threat9_test_bed/udp_service/__init__.py b/threat9_test_bed/udp_service/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/threat9_test_bed/udp_service/udp_server.py b/threat9_test_bed/udp_service/udp_server.py new file mode 100644 index 0000000..df62361 --- /dev/null +++ b/threat9_test_bed/udp_service/udp_server.py @@ -0,0 +1,32 @@ +from logging import getLogger +import socketserver +from unittest import mock + +logger = getLogger(__name__) + + +class UDPServer(socketserver.ThreadingUDPServer): + allow_reuse_address = True + + def __init__( + self, + server_address, + request_handler_class, + bind_and_activate=True, + ): + super().__init__( + server_address, request_handler_class, bind_and_activate, + ) + self.handlers = {} + + def get_command_mock(self, command: bytes) -> mock.Mock: + mocked_handler = mock.MagicMock(name=command) + self.handlers[command] = mocked_handler + return mocked_handler + + +class UDPHandler(socketserver.DatagramRequestHandler): + def handle(self): + data = self.rfile.read() + handler = self.server.handlers[data] + self.wfile.write(handler())