From 36399ccfd22bfffd8d178fbab09a31875fe426fa Mon Sep 17 00:00:00 2001 From: John Andersen Date: Sat, 4 Nov 2023 15:31:49 +0100 Subject: [PATCH] TLS connect fail to correct port Signed-off-by: John Andersen --- .../federation_activitypub_bovine.py | 2 +- tests/test_cli.py | 39 ++++++++++++++----- 2 files changed, 30 insertions(+), 11 deletions(-) diff --git a/scitt_emulator/federation_activitypub_bovine.py b/scitt_emulator/federation_activitypub_bovine.py index 3432a061..f8b1d3a0 100644 --- a/scitt_emulator/federation_activitypub_bovine.py +++ b/scitt_emulator/federation_activitypub_bovine.py @@ -368,4 +368,4 @@ async def loop(client_name, client_config, handlers): except Exception as e: logger.exception("Something went wrong for %s", client_name) logger.exception(e) - await asyncio.sleep(60) + await asyncio.sleep(1) diff --git a/tests/test_cli.py b/tests/test_cli.py index 84cf6238..095c4df3 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -29,19 +29,25 @@ old_create_sockets = hypercorn.config.Config.create_sockets -def socket_getaddrinfo_map_service_ports(services, host, *args, **kwargs): - # Map f"scitt.{handle_name}.example.com" to various local ports - if "scitt." not in host: - return old_socket_getaddrinfo(host, *args, **kwargs) - _, handle_name, _, _ = host.split(".") +def load_services_from_services_path(services): if isinstance(services, (str, pathlib.Path)): services_path = pathlib.Path(services) + if not services_path.exists(): + return old_socket_getaddrinfo(host, *args, **kwargs) services_content = services_path.read_text() services_dict = json.loads(services_content) services = { - handle_name: types.SimpleNameSpace(**service_dict) - for handle_name, service_dict in service_dict.items() + handle_name: types.SimpleNamespace(**service_dict) + for handle_name, service_dict in services_dict.items() } + return services + +def socket_getaddrinfo_map_service_ports(services, host, *args, **kwargs): + # Map f"scitt.{handle_name}.example.com" to various local ports + if "scitt." not in host: + return old_socket_getaddrinfo(host, *args, **kwargs) + _, handle_name, _, _ = host.split(".") + services = load_services_from_services_path(services) return [ ( socket.AF_INET, @@ -91,10 +97,23 @@ def __exit__(self, *args): def server_process(app, addr_queue, services): try: class MockResolver(aiohttp.resolver.DefaultResolver): - async def resolve(self, *args, **kwargs): + async def resolve(self, host, *args, **kwargs): nonlocal services - print("MockResolver.getaddrinfo") - return socket_getaddrinfo_map_service_ports(services, *args, **kwargs) + if "scitt." not in host: + return old_socket_getaddrinfo(host, *args, **kwargs) + _, handle_name, _, _ = host.split(".") + services = load_services_from_services_path(services) + return [ + { + "hostname": host, + "host": "127.0.0.1", + "port": services[handle_name].port, + "family": socket.AF_INET, + "proto": socket.SOCK_STREAM, + "flags": socket.AI_ADDRCONFIG, + } + ] + with contextlib.ExitStack() as exit_stack: exit_stack.enter_context( unittest.mock.patch(