diff --git a/.exampleEnv b/.exampleEnv new file mode 100644 index 0000000..c82ae15 --- /dev/null +++ b/.exampleEnv @@ -0,0 +1,10 @@ +# +# An example .env file. +# With all module used keys. +# + +HANDLER = MirrorHandler +SERVER = SocketServer + +HOSTNAME = localhost +PORT = 80 \ No newline at end of file diff --git a/main.py b/main.py new file mode 100644 index 0000000..dfedd12 --- /dev/null +++ b/main.py @@ -0,0 +1,4 @@ +import protoserver + +server = protoserver.ProtoServer() +server.run() \ No newline at end of file diff --git a/protoserver/__init__.py b/protoserver/__init__.py new file mode 100644 index 0000000..2da492f --- /dev/null +++ b/protoserver/__init__.py @@ -0,0 +1,6 @@ +from protoserver.config import * +from protoserver.protoServer import * + +from protoserver import handlers +from protoserver import http +from protoserver import servers \ No newline at end of file diff --git a/protoserver/config.py b/protoserver/config.py new file mode 100644 index 0000000..34896ef --- /dev/null +++ b/protoserver/config.py @@ -0,0 +1,67 @@ +from __future__ import annotations +import dotenv +from typing import Mapping, TypeVar, TYPE_CHECKING + +if TYPE_CHECKING: + from protoserver.handlers.iHandler import IHandler + from protoserver.servers.iServer import IServer + + +_T = TypeVar("_T") + +class Config(object): + + MINIMAL_VALID_CONFIG_OPTIONS = { + "HANDLER": "MirrorHandler", + "SERVER": "FunctionServer", + } + + _registeredHandlers: dict[str, type[IHandler]] = {} + _registeredServers: dict[str, type[IServer]] = {} + + def __init__(self, data: Mapping[str, None | str]) -> None: + self._config = data + + possibleHandlers = self._registeredHandlers.keys() + if self._config["HANDLER"] not in possibleHandlers: + raise ValueError(f"HANDLER not correctly set in .env.\nMust be one of:\n\t{', '.join(possibleHandlers)}") + + possibleServers = self._registeredServers.keys() + if self._config["SERVER"] not in possibleServers: + raise ValueError(f"SERVER not correctly set in .env.\nMust be one of:\n\t{', '.join(possibleServers)}") + + @classmethod + def loadFromFile(cls, path = ".env") -> Config: + return Config(dotenv.dotenv_values(path)) + + @classmethod + def getMinimalValidConfig(cls, data: None | Mapping[str, None | str] = None) -> Config: + if data is None: data = {} + return Config(Config.MINIMAL_VALID_CONFIG_OPTIONS | data) + + + def get(self, key: str, type: type[_T] = str) -> _T: + if key not in self._config: + raise KeyError(f"Unknown config option {key}.\nMaybe it wasn't in .exampleEnv?") + + value = self._config[key] + if value is None: + raise ValueError(f"{key} has no default value, and wasn't set in .env.") + + return type(value) + + def getHandler(self) -> type[IHandler]: + handler = self.get("HANDLER") + return self._registeredHandlers[handler] + + def getServer(self) -> type[IServer]: + server = self.get("SERVER") + return self._registeredServers[server] + + @classmethod + def registerHandler(cls, name: str, handler: type[IHandler]) -> None: + cls._registeredHandlers[name] = handler + + @classmethod + def registerServer(cls, name: str, server: type[IServer]) -> None: + cls._registeredServers[name] = server \ No newline at end of file diff --git a/protoserver/handlers/__init__.py b/protoserver/handlers/__init__.py new file mode 100644 index 0000000..2cb8bac --- /dev/null +++ b/protoserver/handlers/__init__.py @@ -0,0 +1,2 @@ +from protoserver.handlers.iHandler import IHandler +from protoserver.handlers.mirrorHandler import MirrorHandler \ No newline at end of file diff --git a/protoserver/handlers/iHandler.py b/protoserver/handlers/iHandler.py new file mode 100644 index 0000000..54c59b7 --- /dev/null +++ b/protoserver/handlers/iHandler.py @@ -0,0 +1,21 @@ +from __future__ import annotations +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from protoserver import Config + from protoserver.http import Request, Response, StatusCode + + +class IHandler(ABC): + + def __init__(self, config: Config) -> None: + self.config = config + + @abstractmethod + def handleRequest(self, request: Request) -> None | Response: + raise NotImplementedError() + + @abstractmethod + def handleError(self, statusCode: StatusCode) -> None | Response: + raise NotImplementedError() \ No newline at end of file diff --git a/protoserver/handlers/mirrorHandler.py b/protoserver/handlers/mirrorHandler.py new file mode 100644 index 0000000..f601898 --- /dev/null +++ b/protoserver/handlers/mirrorHandler.py @@ -0,0 +1,25 @@ +from __future__ import annotations +from typing import TYPE_CHECKING + +from protoserver.config import Config +from protoserver.handlers.iHandler import IHandler +from protoserver.http.response import Response + +if TYPE_CHECKING: + from protoserver.http.request import Request + from protoserver.http.statusCodes import StatusCode + + +class MirrorHandler(IHandler): + + def handleRequest(self, request: Request) -> Response | None: + response = Response() + + response.setBody(request.raw.decode("iso-8859-1")) + + return response + + def handleError(self, statusCode: StatusCode) -> Response | None: + return super().handleError(statusCode) + +Config.registerHandler("MirrorHandler", MirrorHandler) \ No newline at end of file diff --git a/protoserver/http/__init__.py b/protoserver/http/__init__.py new file mode 100644 index 0000000..e67d310 --- /dev/null +++ b/protoserver/http/__init__.py @@ -0,0 +1,4 @@ +from protoserver.http.client import Client +from protoserver.http.request import Request +from protoserver.http.response import Response +from protoserver.http.statusCodes import StatusCode \ No newline at end of file diff --git a/protoserver/http/client.py b/protoserver/http/client.py new file mode 100644 index 0000000..74f7c90 --- /dev/null +++ b/protoserver/http/client.py @@ -0,0 +1,63 @@ +from __future__ import annotations +from threading import Thread +from typing import TYPE_CHECKING + +from protoserver.http.request import Request +from protoserver.http.response import Response +from protoserver.http.statusCodes import StatusCode + +if TYPE_CHECKING: + from protoserver import Config + from protoserver.handlers import IHandler + from protoserver.servers import IConnection + + +class Client(object): + + def __init__(self, config: Config, connection: IConnection, handler: IHandler) -> None: + self.config = config + self.connection = connection + self.handler = handler + + self.isRunning = True + self.thread = Thread(target = self.loop) + self.thread.start() + + def loop(self) -> None: + + while self.isRunning: + request = self.recv() + if request is None: + continue + + response = self.handler.handleRequest(request) + if response is None: + continue + + self.send(response) + + self.close() + + def recv(self) -> None | Request: + requestBytes = self.connection.recv() + if requestBytes == b"": + self.isRunning = False + return None + + request = Request(requestBytes) + + match (request.status): + case "OK": + return request + case "BAD_REQUEST": + response = self.handler.handleError(StatusCode.BAD_REQUEST) + if response is not None: + self.send(response) + return None + + def send(self, response: Response) -> None: + self.connection.send(response.build()) + + def close(self) -> None: + self.connection.close() + diff --git a/protoserver/http/request.py b/protoserver/http/request.py new file mode 100644 index 0000000..d539f08 --- /dev/null +++ b/protoserver/http/request.py @@ -0,0 +1,215 @@ +from __future__ import annotations +from typing import Literal + + +_Context = Literal["START", "METHOD", "PATH", "QUERYNAME", "QUERYVALUE", "VERSION", "SKIPCHECK", "SKIP", "HEADERNAME", "HEADERVALUE", "BODY"] +_Token = tuple[_Context, bytes] + +class Request(object): + + def __init__(self, requestBytes: bytes) -> None: + + self.status = "OK" + self.raw = requestBytes + + readBytes: list[bytes] = [] + tokens: list[_Token] = [] + byteContext: _Context = "START" + for byte in map(lambda b: chr(b).encode(), requestBytes): + readBytes.append(byte) + + res = self._parseReadBytes(byteContext, readBytes) + if res is None: + continue + + byteContext = res[0] + if res[1] is not None: + tokens.append(res[1]) + + if byteContext != "BODY": + self.status = "BAD_REQUEST" + + self.headers = {} + self.queries = {} + + try: + self._parseTokens(tokens) + except Exception as e: + print(e) + self.status = "BAD_REQUEST" + return + + self._parseBody(readBytes) + + def _parseReadBytes(self, context: _Context, readBytes: list[bytes]) -> None | tuple[_Context, None | _Token]: + match (context): + case "START": + match (readBytes): + case [b"\r", b"\n"]: + readBytes.clear() + return "START", None + case _: + return "METHOD", None + + case "METHOD": + match (readBytes): + case [*b, b" "]: + readBytes.clear() + return "PATH", ("METHOD", b"".join(b)) + + case "PATH": + match (readBytes): + case [*b, b"?"]: + readBytes.clear() + return "QUERYNAME", ("PATH", b"".join(b)) + case [*b, b" "]: + readBytes.clear() + return "VERSION", ("PATH", b"".join(b)) + case [*b, b"\r", b"\n"]: + readBytes.clear() + return "SKIPCHECK", ("PATH", b"".join(b)) + + case "QUERYNAME": + match (readBytes): + case [*b, b"&"]: + readBytes.clear() + return "QUERYNAME", ("QUERYNAME", b"".join(b)) + case [*b, b"="]: + readBytes.clear() + return "QUERYVALUE", ("QUERYNAME", b"".join(b)) + case [*b, b" "]: + readBytes.clear() + return "VERSION", ("QUERYNAME", b"".join(b)) + case [*b, b"\r", b"\n"]: + readBytes.clear() + return "SKIPCHECK", ("QUERYNAME", b"".join(b)) + + case "QUERYVALUE": + match (readBytes): + case [*b, b"&"]: + readBytes.clear() + return "QUERYNAME", ("QUERYVALUE", b"".join(b)) + case [*b, b" "]: + readBytes.clear() + return "VERSION", ("QUERYVALUE", b"".join(b)) + case [*b, b"\r", b"\n"]: + readBytes.clear() + return "SKIPCHECK", ("QUERYVALUE", b"".join(b)) + + case "VERSION": + match (readBytes): + case [*b, b"\r", b"\n"]: + readBytes.clear() + return "SKIPCHECK", ("VERSION", b"".join(b)) + + case "SKIPCHECK": + match (readBytes): + case [b" "]: + readBytes.clear() + return "SKIP", None + case [b"\r"]: + return + case [b"\r", b"\n"]: + readBytes.clear() + return "BODY", None + case _: + return "HEADERNAME", None + + case "SKIP": + match (readBytes): + case [*b, b"\r", b"\n"]: + readBytes.clear() + return "SKIPCHECK", None + + case "HEADERNAME": + match (readBytes): + case [*b, b":"]: + readBytes.clear() + return "HEADERVALUE", ("HEADERNAME", b"".join(b)) + + case "HEADERVALUE": + match (readBytes): + case [b"\r", b"\n"]: + readBytes.clear() + return "SKIPCHECK", None + case [*b, b"\r", b"\n"]: + readBytes.clear() + return "SKIPCHECK", ("HEADERVALUE", b"".join(b)) + + def _parseTokens(self, tokens: list[_Token]) -> None: + + t = tokens.pop(0) + if t[0] != "METHOD": raise Exception + self.method = self._decodeBytes(t[1]).strip() + + t = tokens.pop(0) + if t[0] != "PATH": raise Exception + self.path = self._decodeUrlEncodedBytes(t[1]).strip() + + t = tokens.pop(0) + while t[0] == "QUERYNAME": + qName = self._decodeUrlEncodedBytes(t[1]).strip() + if qName not in self.queries: + self.queries[qName] = [] + + if not len(tokens): + self.queries[qName].append(None) + return + t = tokens.pop(0) + + if t[0] != "QUERYVALUE": + self.queries[qName].append(None) + continue + qValue = self._decodeUrlEncodedBytes(t[1]).strip() + self.queries[qName].append(qValue) + if not len(tokens): return + t = tokens.pop(0) + + if t[0] == "VERSION": + self.version = self._decodeBytes(t[1]).strip() + if not len(tokens): return + t = tokens.pop(0) + + while t[0] == "HEADERNAME": + hName = self._decodeBytes(t[1]).strip() + if hName not in self.headers: + self.headers[hName] = [] + + if not len(tokens): + self.headers[hName].append(None) + return + t = tokens.pop(0) + + if t[0] != "HEADERVALUE": + self.headers[hName].append(None) + continue + hValue = self._decodeBytes(t[1]).strip() + self.headers[hName].append(hValue) + + if not len(tokens): return + t = tokens.pop(0) + + def _parseBody(self, bodyBytes: list[bytes]) -> None: + self.body = self._decodeBytes(b"".join(bodyBytes)) + # TODO, properly parse body using Transfer-Encoding, and Content-Length headers. + + def _decodeBytes(self, encodedBytes: bytes) -> str: + return encodedBytes.decode("iso-8859-1") + + def _decodeUrlEncodedBytes(self, encodedBytes: bytes) -> str: + urlEncodedString = self._decodeBytes(encodedBytes) + decodedString = "" + + i = 0 + while i < len(urlEncodedString): + char = urlEncodedString[i] + + if char != "%": + decodedString += char + i += 1 + continue + + decodedString += chr(int(urlEncodedString[i+1:i+3], 16)) + i += 3 + + return decodedString \ No newline at end of file diff --git a/protoserver/http/response.py b/protoserver/http/response.py new file mode 100644 index 0000000..29a41cd --- /dev/null +++ b/protoserver/http/response.py @@ -0,0 +1,51 @@ +from __future__ import annotations +from typing import Self + +from protoserver.http.statusCodes import StatusCode + + +class Response(object): + + def __init__(self) -> None: + + self.statusCode = StatusCode.OK + self.headers = {} + self.body = "" + + def setStatusCode(self, statusCode: StatusCode) -> Self: + self.statusCode = statusCode + return self + + def setHeader(self, header: str, value: str) -> Self: + self.headers[header] = value + return self + + def setBody(self, body: str) -> Self: + self.body = body + + bodyLength = len(self._buildBody()) + self.setHeader("Content-Length", str(bodyLength)) + + return self + + def build(self) -> bytes: + return b"".join([ + self._buildStatusLine(), + self._buildHeaders(), + self._buildBody(), + ]) + + def _buildStatusLine(self) -> bytes: + statusLine = f"HTTP/1.1 {self.statusCode.value} {self.statusCode.name}\r\n" + return statusLine.encode("iso-8859-1") + + def _buildHeaders(self) -> bytes: + headerString = "" + for name, value in self.headers.items(): + headerString += f"{name}: {value}\r\n" + + headerString += "\r\n" + return headerString.encode("iso-8859-1") + + def _buildBody(self) -> bytes: + return self.body.encode("iso-8859-1") \ No newline at end of file diff --git a/protoserver/http/statusCodes.py b/protoserver/http/statusCodes.py new file mode 100644 index 0000000..1638f78 --- /dev/null +++ b/protoserver/http/statusCodes.py @@ -0,0 +1,54 @@ +from __future__ import annotations +from enum import Enum + + +class StatusCode(Enum): + + CONTINUE = 100 + SWITCHING_PROTOCOLS = 101 + + OK = 200 + CREATED = 201 + ACCEPTED = 202 + NON_AUTHORITATIVE_INFORMATION = 203 + NO_CONTENT = 204 + RESET_CONTENT = 205 + PARTIAL_CONTENT = 206 + + MULTIPLE_CHOICE = 300 + MOVED_PERMANENTLY = 301 + FOUND = 302 + SEE_OTHER = 303 + NOT_MODIFIED = 304 + USE_PROXY = 305 + TEMPORARY_REDIRECT = 307 + PERMANENT_REDIRECT = 308 + + BAD_REQUEST = 400 + UNAUTHORIZED = 401 + PAYMENT_REQUIRED = 402 + FORBIDDEN = 403 + NOT_FOUND = 404 + METHOD_NOT_ALLOWED = 405 + NOT_ACCEPTABLE = 406 + PROXY_AUTHENTICATION_REQUIRED = 407 + REQUEST_TIMEOUT = 408 + CONFLICT = 409 + GONE = 410 + LENGTH_REQUIRED = 411 + PRECONDITION_FAILED = 412 + CONTENT_TOO_LARGE = 413 + URI_TOO_LONG = 414 + UNSUPPORTED_MEDIA_TYPE = 415 + RANGE_NOT_SATISFIABLE = 416 + EXPECTATION_FAILED = 417 + MISDIRECTED_REQUEST = 421 + UNPROCESSABLE_CONTENT = 422 + UPGRADE_REQUIRED = 426 + + INTERNAL_SERVER_ERROR = 500 + NOT_IMPLEMENTED = 501 + BAD_GATEWAY = 502 + SERVICE_UNAVAILABLE = 503 + GATEWAY_TIMEOUT = 504 + HTTP_VERSION_NOT_SUPPORTED = 505 \ No newline at end of file diff --git a/protoserver/protoServer.py b/protoserver/protoServer.py new file mode 100644 index 0000000..19c2c5d --- /dev/null +++ b/protoserver/protoServer.py @@ -0,0 +1,35 @@ +from __future__ import annotations +from typing import TYPE_CHECKING + +from protoserver.config import Config +from protoserver.http.client import Client + +if TYPE_CHECKING: + from protoserver.servers.iServer import IConnection + + +class ProtoServer(object): + + def __init__(self, config: None | Config = None) -> None: + if config is None: config = Config.loadFromFile() + self.config = config + + self.handler = self.config.getHandler()(self.config) + self.server = self.config.getServer()(self.config) + + self.isRunning = False + self.connections: list[IConnection] = [] + + def run(self) -> None: + + self.isRunning = True + while self.isRunning: + connection = self.server.accept() + if connection is None: + break + client = Client(self.config, connection, self.handler) + + self.stop() + + def stop(self) -> None: + self.server.stop() \ No newline at end of file diff --git a/protoserver/servers/__init__.py b/protoserver/servers/__init__.py new file mode 100644 index 0000000..b42ddae --- /dev/null +++ b/protoserver/servers/__init__.py @@ -0,0 +1,3 @@ +from protoserver.servers.iServer import * +from protoserver.servers.function import * +from protoserver.servers.socket import * \ No newline at end of file diff --git a/protoserver/servers/function.py b/protoserver/servers/function.py new file mode 100644 index 0000000..e6ad330 --- /dev/null +++ b/protoserver/servers/function.py @@ -0,0 +1,49 @@ +from __future__ import annotations +from queue import Queue + +from protoserver.config import Config +from protoserver.servers.iServer import IConnection, IServer + + +class FunctionConnection(IConnection): + + requestBuffer = Queue() + responseBuffer = Queue() + + def recv(self) -> bytes: + return self.requestBuffer.get() + + def send(self, data: bytes) -> None: + self.responseBuffer.put(data) + + def stop(self) -> None: + return + + def clientSend(self, data: bytes) -> None: + self.requestBuffer.put(data) + + def clientRecv(self) -> bytes: + return self.responseBuffer.get() + + +class FunctionServer(IServer): + + def __init__(self, *args) -> None: + super().__init__(*args) + + self.connection = None + + def accept(self) -> None | FunctionConnection: + if self.connection is not None: + return None + + self.connection = FunctionConnection(self.config, self) + self.connections.append(self.connection) + return self.connection + + def getConnection(self) -> FunctionConnection: + while self.connection is None: + pass + return self.connection + +Config.registerServer("FunctionServer", FunctionServer) \ No newline at end of file diff --git a/protoserver/servers/iServer.py b/protoserver/servers/iServer.py new file mode 100644 index 0000000..982120c --- /dev/null +++ b/protoserver/servers/iServer.py @@ -0,0 +1,42 @@ +from __future__ import annotations +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from protoserver.config import Config + + +class IConnection(ABC): + + def __init__(self, config: Config, server: IServer) -> None: + self.config = config + self.server = server + + @abstractmethod + def recv(self) -> bytes: + raise NotImplementedError() + + @abstractmethod + def send(self, data: bytes) -> None: + raise NotImplementedError() + + def close(self) -> None: + self.server.onConnectionClose(self) + + +class IServer(ABC): + + def __init__(self, config: Config) -> None: + self.config = config + self.connections: list[IConnection] = [] + + @abstractmethod + def accept(self) -> None | IConnection: + raise NotImplementedError() + + def stop(self) -> None: + for connection in self.connections.copy(): + connection.close() + + def onConnectionClose(self, connection: IConnection) -> None: + self.connections.remove(connection) diff --git a/protoserver/servers/socket.py b/protoserver/servers/socket.py new file mode 100644 index 0000000..6d794bd --- /dev/null +++ b/protoserver/servers/socket.py @@ -0,0 +1,59 @@ +from __future__ import annotations +from socket import create_server, socket + +from protoserver.config import Config +from protoserver.servers.iServer import IConnection, IServer + + +class SocketConnection(IConnection): + + def __init__(self, config: Config, server: SocketServer, connection: socket, clientName: str, port: int) -> None: + super().__init__(config, server) + + self.connection = connection + self.clientName = clientName + self.port = port + + def recv(self) -> bytes: + return self.connection.recv(8192) + + def send(self, data: bytes) -> None: + self.connection.sendall(data) + + def close(self) -> None: + super().close() + self.connection.close() + + +class SocketServer(IServer): + + def __init__(self, config: Config) -> None: + super().__init__(config) + + self.hostname = config.get("HOSTNAME") + self.port = config.get("PORT", int) + + self.server = socket() + + try: + self.server.bind((self.hostname, self.port)) + except: + self.server.close() + raise + + self.server.listen() + + def accept(self) -> None | SocketConnection: + try: + connection, address = self.server.accept() + except OSError: + return None + conn = SocketConnection(self.config, self, connection, address[0], address[1]) + self.connections.append(conn) + return conn + + def stop(self) -> None: + super().stop() + self.server.close() + +Config.registerServer("SocketServer", SocketServer) \ No newline at end of file diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/testConfig.py b/tests/testConfig.py new file mode 100644 index 0000000..27a8957 --- /dev/null +++ b/tests/testConfig.py @@ -0,0 +1,35 @@ +from unittest import TestCase + +from protoserver.config import Config +from protoserver.servers.function import FunctionServer + + +class TestConfig(TestCase): + + def testNoServerAttribute(self) -> None: + c = {"HANDLER": "HTMLMirrorHandler"} + self.assertRaises(ValueError, lambda: Config(c)) + + def testInvalidServerAttribute(self) -> None: + c = {"SERVER": "InvalidServer", "HANDLER": "HTMLMirrorHandler"} + self.assertRaises(ValueError, lambda: Config(c)) + + def testValidServerAttribute(self) -> None: + c = Config.MINIMAL_VALID_CONFIG_OPTIONS + self.assertEqual(Config(c).getServer(), FunctionServer) + + def testInvalidAttribute(self) -> None: + c = Config.MINIMAL_VALID_CONFIG_OPTIONS + self.assertRaises(KeyError, lambda: Config(c).get("InvalidAttribute")) + + def testNoneAttribute(self) -> None: + c = Config.MINIMAL_VALID_CONFIG_OPTIONS | {"NoneAttribute": None} + self.assertRaises(ValueError, lambda: Config(c).get("NoneAttribute")) + + def testValidAttribute(self) -> None: + c = Config.MINIMAL_VALID_CONFIG_OPTIONS | {"ValidAttribute": "Value"} + self.assertEqual(Config(c).get("ValidAttribute"), "Value") + + def testValidAttributeTypeCast(self) -> None: + c = Config.MINIMAL_VALID_CONFIG_OPTIONS | {"IntAttribute": "1"} + self.assertEqual(Config(c).get("IntAttribute", int), 1) diff --git a/tests/testFunctionConnection.py b/tests/testFunctionConnection.py new file mode 100644 index 0000000..2d026e2 --- /dev/null +++ b/tests/testFunctionConnection.py @@ -0,0 +1,56 @@ +from unittest import TestCase + +from protoserver.config import Config +from protoserver.servers.function import FunctionServer + + +class testFunctionConnection(TestCase): + + def setUp(self) -> None: + c = Config.getMinimalValidConfig() + self.s = FunctionServer(c) + self.s.accept() + + def tearDown(self) -> None: + self.s.stop() + del self.s + + def testClientSend(self) -> None: + connection = self.s.getConnection() + + connection.clientSend(b"Request Bytes") + self.assertEqual(connection.recv(), b"Request Bytes") + + def testClientRecv(self) -> None: + connection = self.s.getConnection() + + connection.send(b"Response Bytes") + self.assertEqual(connection.clientRecv(), b"Response Bytes") + + def testSequentialClientSend(self) -> None: + connection = self.s.getConnection() + + connection.clientSend(b"Request Bytes 1") + connection.clientSend(b"Request Bytes 2") + self.assertEqual(connection.recv(), b"Request Bytes 1") + self.assertEqual(connection.recv(), b"Request Bytes 2") + + def testSequentialClientRecv(self) -> None: + connection = self.s.getConnection() + + connection.send(b"Response Bytes 1") + connection.send(b"Response Bytes 2") + self.assertEqual(connection.clientRecv(), b"Response Bytes 1") + self.assertEqual(connection.clientRecv(), b"Response Bytes 2") + + def testMixedSendRecv(self) -> None: + connection = self.s.getConnection() + + connection.clientSend(b"Request Bytes 1") + connection.send(b"Response Bytes 1") + self.assertEqual(connection.clientRecv(), b"Response Bytes 1") + connection.clientSend(b"Request Bytes 2") + connection.send(b"Response Bytes 2") + self.assertEqual(connection.recv(), b"Request Bytes 1") + self.assertEqual(connection.clientRecv(), b"Response Bytes 2") + self.assertEqual(connection.recv(), b"Request Bytes 2") \ No newline at end of file diff --git a/tests/testFunctionServer.py b/tests/testFunctionServer.py new file mode 100644 index 0000000..169f351 --- /dev/null +++ b/tests/testFunctionServer.py @@ -0,0 +1,23 @@ +from unittest import TestCase + +from protoserver.config import Config +from protoserver.servers.function import FunctionServer + +class testFunctionServer(TestCase): + + def setUp(self) -> None: + c = Config.getMinimalValidConfig() + self.s = FunctionServer(c) + + def tearDown(self) -> None: + self.s.stop() + del self.s + + def testFirstConnectionIsSetConnection(self) -> None: + connection = self.s.accept() + self.assertEqual(connection, self.s.getConnection()) + + def testNoneAfterFirstConnection(self) -> None: + self.s.accept() + secondConnection = self.s.accept() + self.assertIsNone(secondConnection) \ No newline at end of file diff --git a/tests/testHTTPRequest.py b/tests/testHTTPRequest.py new file mode 100644 index 0000000..e59cc08 --- /dev/null +++ b/tests/testHTTPRequest.py @@ -0,0 +1,112 @@ +from unittest import TestCase + +from protoserver.http.request import Request + +class testFunctionServer(TestCase): + + def assertRequestEquals(self, request: Request, method: str, path: str, queries: dict, version: str, headers: dict, body: str): + self.assertEqual(request.method, method) + self.assertEqual(request.path, path) + self.assertDictEqual(request.queries, queries) + self.assertEqual(request.version, version) + self.assertDictEqual(request.headers, headers) + self.assertEqual(request.body, body) + self.assertEqual(request.status, "OK") + + + def testBadRequests(self) -> None: + requestBytes = [ + b"", + b"/index.html", + b"GET /index.html HTTP/1.1\r\nHost: localhost", + b"GET /index.html HTTP/1.1\r\nHost: localhost\r\n", + ] + + for rBytes in requestBytes: + with self.subTest(): + request = Request(rBytes) + self.assertEqual(request.status, "BAD_REQUEST") + + def testMinimalRequest(self) -> None: + requestBytes = b"GET /index.html HTTP/1.1\r\n\r\n" + + request = Request(requestBytes) + self.assertRequestEquals(request, "GET", "/index.html", {}, "HTTP/1.1", {}, "") + + def testMethod(self) -> None: + requestBytes = b"POST /index.html HTTP/1.1\r\n\r\n" + + request = Request(requestBytes) + self.assertRequestEquals(request, "POST", "/index.html", {}, "HTTP/1.1", {}, "") + + def testUrlEncodedPath(self) -> None: + requestBytes = b"GET %2f%69%6e%64%65%78%2e%68%74%6d%6c HTTP/1.1\r\n\r\n" + + request = Request(requestBytes) + self.assertRequestEquals(request, "GET", "/index.html", {}, "HTTP/1.1", {}, "") + + def testUrlEncodedPercent(self) -> None: + requestBytes = b"GET %25%32%35 HTTP/1.1\r\n\r\n" + + request = Request(requestBytes) + self.assertRequestEquals(request, "GET", "%25", {}, "HTTP/1.1", {}, "") + + def testNoValueQuery(self) -> None: + requestBytes = b"GET /index.html?q HTTP/1.1\r\n\r\n" + + request = Request(requestBytes) + self.assertRequestEquals(request, "GET", "/index.html", {"q": [None]}, "HTTP/1.1", {}, "") + + def testSingleQuery(self) -> None: + requestBytes = b"GET /index.html?q=1 HTTP/1.1\r\n\r\n" + + request = Request(requestBytes) + self.assertRequestEquals(request, "GET", "/index.html", {"q": ["1"]}, "HTTP/1.1", {}, "") + + def testMultipleQuery(self) -> None: + requestBytes = b"GET /index.html?p=query1&q=query2 HTTP/1.1\r\n\r\n" + + request = Request(requestBytes) + self.assertRequestEquals(request, "GET", "/index.html", {"p": ["query1"], "q": ["query2"]}, "HTTP/1.1", {}, "") + + def testDuplicateQuery(self) -> None: + requestBytes = b"GET /index.html?p=1&p=2&p&p&p=3 HTTP/1.1\r\n\r\n" + + request = Request(requestBytes) + self.assertRequestEquals(request, "GET", "/index.html", {"p": ["1", "2", None, None, "3"]}, "HTTP/1.1", {}, "") + + def testNoValueHeader(self) -> None: + requestBytes = b"GET /index.html HTTP/1.1\r\nHost:\r\n\r\n" + + request = Request(requestBytes) + self.assertRequestEquals(request, "GET", "/index.html", {}, "HTTP/1.1", {"Host": [None]}, "") + + def testSingleHeader(self) -> None: + requestBytes = b"GET /index.html HTTP/1.1\r\nHost: localhost\r\n\r\n" + + request = Request(requestBytes) + self.assertRequestEquals(request, "GET", "/index.html", {}, "HTTP/1.1", {"Host": ["localhost"]}, "") + + def testMultipleHeader(self) -> None: + requestBytes = b"GET /index.html HTTP/1.1\r\nHost: localhost\r\nHeader: Value\r\n\r\n" + + request = Request(requestBytes) + self.assertRequestEquals(request, "GET", "/index.html", {}, "HTTP/1.1", {"Host": ["localhost"], "Header": ["Value"]}, "") + + def testDuplicateHeader(self) -> None: + requestBytes = b"GET /index.html HTTP/1.1\r\nHeader: Value1\r\nHeader: Value2\r\n\r\n" + + request = Request(requestBytes) + self.assertRequestEquals(request, "GET", "/index.html", {}, "HTTP/1.1", {"Header": ["Value1", "Value2"]}, "") + + def testBody(self) -> None: + requestBytes = b"GET /index.html HTTP/1.1\r\n\r\nBody Text" + + request = Request(requestBytes) + self.assertRequestEquals(request, "GET", "/index.html", {}, "HTTP/1.1", {}, "Body Text") + + def testFullRequest(self) -> None: + requestBytes = b"POST /%69%6e%64%65%78.html?%69=%69%20%69&q&q=1 HTTP/1.1\r\nHost: localhost\r\nContent-Length: 22\r\n\r\nThis is 22 characters." + + request = Request(requestBytes) + self.assertRequestEquals(request, "POST", "/index.html", {"i": ["i i"], "q": [None, "1"]}, "HTTP/1.1", {"Host": ["localhost"], "Content-Length": ["22"]}, "This is 22 characters.") \ No newline at end of file diff --git a/tests/testHTTPResponse.py b/tests/testHTTPResponse.py new file mode 100644 index 0000000..30e9b3b --- /dev/null +++ b/tests/testHTTPResponse.py @@ -0,0 +1,48 @@ +from unittest import TestCase + +from protoserver.http.response import Response +from protoserver.http.statusCodes import StatusCode + + +class testFunctionServer(TestCase): + + def testMinimalRequest(self) -> None: + response = Response()\ + .build() + + self.assertEqual(b"HTTP/1.1 200 OK\r\n\r\n", response) + + def testStatusCodes(self) -> None: + for code in StatusCode: + with self.subTest(): + response = Response()\ + .setStatusCode(code)\ + .build() + + expectedResponse = f"HTTP/1.1 {code.value} {code.name}\r\n\r\n".encode("iso-8859-1") + self.assertEqual(expectedResponse, response) + + def testSetSingleHeader(self) -> None: + response = Response()\ + .setHeader("X-Custom-Header", "Value")\ + .build() + + expectedResponse = b"HTTP/1.1 200 OK\r\nX-Custom-Header: Value\r\n\r\n" + self.assertEqual(expectedResponse, response) + + def testSetMultipleHeader(self) -> None: + response = Response()\ + .setHeader("X-Custom-Header", "Value")\ + .setHeader("X-Custom-Header2", "Value2")\ + .build() + + expectedResponse = b"HTTP/1.1 200 OK\r\nX-Custom-Header: Value\r\nX-Custom-Header2: Value2\r\n\r\n" + self.assertEqual(expectedResponse, response) + + def testSetBody(self) -> None: + response = Response()\ + .setBody("Body Text")\ + .build() + + expectedResponse = b"HTTP/1.1 200 OK\r\nContent-Length: 9\r\n\r\nBody Text" + self.assertEqual(expectedResponse, response) \ No newline at end of file diff --git a/tests/testMirrorHandler.py b/tests/testMirrorHandler.py new file mode 100644 index 0000000..c8e1148 --- /dev/null +++ b/tests/testMirrorHandler.py @@ -0,0 +1,32 @@ +from queue import Queue +from socket import socket +from threading import Thread +from unittest import TestCase + +from protoserver.config import Config +from protoserver.handlers import MirrorHandler +from protoserver.http.request import Request + + +class testMirrorHandler(TestCase): + + def testMirrorHandler(self) -> None: + config = Config.getMinimalValidConfig({"HANDLER": "MirrorHandler"}) + handler = MirrorHandler(config) + + requestBytes = [ + b"GET / HTTP/1.1\r\n\r\n", + b"GET /path/to/resource?query=value\r\nHost: localhost\r\nX-Custom-Header: value\r\n\r\nBody Text", + ] + + for test in requestBytes: + with self.subTest(): + request = Request(test) + response = handler.handleRequest(request) + + if response is None: + self.fail("Response was None.") + + expectedResponse = f"HTTP/1.1 200 OK\r\nContent-Length: {len(test)}\r\n\r\n{test.decode('iso-8859-1')}".encode("iso-8859-1") + + self.assertEqual(expectedResponse, response.build()) diff --git a/tests/testSocketConnection.py b/tests/testSocketConnection.py new file mode 100644 index 0000000..18899c5 --- /dev/null +++ b/tests/testSocketConnection.py @@ -0,0 +1,100 @@ +from queue import Queue +from socket import socket +from threading import Thread +from unittest import TestCase + +from protoserver.config import Config +from protoserver.servers.socket import SocketConnection, SocketServer + + +class testSocketConnection(TestCase): + + def setUp(self) -> None: + c = Config.getMinimalValidConfig({"SERVER": "SocketServer", "HOSTNAME": "localhost", "PORT": "80"}) + self.s = SocketServer(c) + self.c: Queue[SocketConnection] = Queue() + + self.clientConns: list[socket] = [] + + self.serverThread = Thread(target = self.runServer) + self.serverThread.start() + + def runServer(self) -> None: + while True: + conn = self.s.accept() + if conn is None: + break + self.c.put(conn) + + def tearDown(self) -> None: + self.s.stop() + del self.s + del self.c + + self.serverThread.join() + del self.serverThread + + for conn in self.clientConns: + conn.close() + del self.clientConns + + + def getClientConn(self) -> socket: + clientConn = socket() + self.clientConns.append(clientConn) + clientConn.connect(("localhost", 80)) + return clientConn + + def testAccept(self) -> None: + clientConn = self.getClientConn() + + serverConn = self.c.get() + + def testRecv(self) -> None: + clientConn = self.getClientConn() + + serverConn = self.c.get() + + clientConn.send(b"Request Bytes") + self.assertEqual(serverConn.recv(), b"Request Bytes") + + def testSend(self) -> None: + clientConn = self.getClientConn() + + serverConn = self.c.get() + + serverConn.send(b"Response Bytes") + self.assertEqual(clientConn.recv(8192), b"Response Bytes") + + def testClientClose(self) -> None: + clientConn = self.getClientConn() + + serverConn = self.c.get() + + clientConn.close() + self.assertEqual(serverConn.recv(), b"") + + def testServerClose(self) -> None: + clientConn = self.getClientConn() + + serverConn = self.c.get() + + serverConn.close() + self.assertEqual(clientConn.recv(8192), b"") + + def testMultipleConnections(self) -> None: + clientConn1 = self.getClientConn() + clientConn2 = self.getClientConn() + + serverConn1 = self.c.get() + serverConn2 = self.c.get() + + clientConn1.send(b"Request Bytes 1") + clientConn2.send(b"Request Bytes 2") + serverConn1.send(b"Response Bytes 1") + serverConn2.send(b"Response Bytes 2") + + self.assertEqual(clientConn1.recv(8192), b"Response Bytes 1") + self.assertEqual(clientConn2.recv(8192), b"Response Bytes 2") + self.assertEqual(serverConn1.recv(), b"Request Bytes 1") + self.assertEqual(serverConn2.recv(), b"Request Bytes 2") \ No newline at end of file diff --git a/tests/testSocketServer.py b/tests/testSocketServer.py new file mode 100644 index 0000000..33b749d --- /dev/null +++ b/tests/testSocketServer.py @@ -0,0 +1,79 @@ +from queue import Queue +from socket import socket +from threading import Thread +from unittest import TestCase + +from protoserver.config import Config +from protoserver.servers.socket import SocketConnection, SocketServer + +class testSocketServer(TestCase): + + def setUp(self) -> None: + c = Config.getMinimalValidConfig({"SERVER": "SocketServer", "HOSTNAME": "localhost", "PORT": "80"}) + self.s = SocketServer(c) + self.c: Queue[SocketConnection] = Queue() + + self.clientConns: list[socket] = [] + + self.serverThread = Thread(target = self.runServer) + self.serverThread.start() + + def runServer(self) -> None: + while True: + conn = self.s.accept() + if conn is None: + break + self.c.put(conn) + + def tearDown(self) -> None: + self.s.stop() + del self.s + del self.c + + self.serverThread.join() + del self.serverThread + + for conn in self.clientConns: + conn.close() + del self.clientConns + + def getClientConn(self) -> socket: + clientConn = socket() + self.clientConns.append(clientConn) + clientConn.connect(("localhost", 80)) + return clientConn + + def testInvalidConfig(self) -> None: + + with self.subTest("A"): + c = Config.getMinimalValidConfig({"SERVER": "SocketServer"}) + self.assertRaises(KeyError, lambda: SocketServer(c)) + + with self.subTest("B"): + c = Config.getMinimalValidConfig({"SERVER": "SocketServer", "HOSTNAME": "localhost"}) + self.assertRaises(KeyError, lambda: SocketServer(c)) + + with self.subTest("C"): + c = Config.getMinimalValidConfig({"SERVER": "SocketServer", "PORT": "22"}) + self.assertRaises(KeyError, lambda: SocketServer(c)) + + with self.subTest("D"): + c = Config.getMinimalValidConfig({"SERVER": "SocketServer", "HOSTNAME": "localhost", "PORT": "abc"}) + self.assertRaises(ValueError, lambda: SocketServer(c)) + + with self.subTest("E"): + c = Config.getMinimalValidConfig({"SERVER": "SocketServer", "HOSTNAME": "localhost", "PORT": "-1"}) + self.assertRaises(OverflowError, lambda: SocketServer(c)) + + with self.subTest("F"): + c = Config.getMinimalValidConfig({"SERVER": "SocketServer", "HOSTNAME": "localhost", "PORT": "65536"}) + self.assertRaises(OverflowError, lambda: SocketServer(c)) + + def testServerStop(self) -> None: + clientConn = self.getClientConn() + + self.c.get() + + self.s.stop() + + self.assertEqual(clientConn.recv(8192), b"") \ No newline at end of file