Skip to content

Commit

Permalink
MockResolver.getaddrinfo
Browse files Browse the repository at this point in the history
Signed-off-by: John Andersen <[email protected]>
  • Loading branch information
pdxjohnny committed Nov 4, 2023
1 parent 88b82d2 commit 836bc08
Show file tree
Hide file tree
Showing 3 changed files with 173 additions and 39 deletions.
77 changes: 76 additions & 1 deletion scitt_emulator/federation_activitypub_bovine.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ async def initialize_service(self):

# Run client handlers
async def mechanical_bull_loop(config):
from mechanical_bull.event_loop import loop
# from mechanical_bull.event_loop import loop
from mechanical_bull.handlers import load_handlers, build_handler

async with asyncio.TaskGroup() as taskgroup:
Expand Down Expand Up @@ -294,3 +294,78 @@ async def federate_created_entries(
print(f"End of messages in outbox, total: {count_messages}")
except:
logger.error(traceback.format_exc())

import asyncio

import bovine
import json

import logging

from mechanical_bull.handlers import HandlerEvent, call_handler_compat


async def handle_connection(client: bovine.BovineClient, handlers: list):
print("handle_connection")
event_source = await client.event_source()
print(event_source )
logger.info("Connected")
for handler in handlers:
await call_handler_compat(
handler,
client,
None,
handler_event=HandlerEvent.OPENED,
)
async for event in event_source:
if not event:
return
if event and event.data:
data = json.loads(event.data)

for handler in handlers:
await call_handler_compat(
handler,
client,
data,
handler_event=HandlerEvent.DATA,
)
for handler in handlers:
await call_handler_compat(
handler,
client,
None,
handler_event=HandlerEvent.CLOSED,
)


async def handle_connection_with_reconnect(
client: bovine.BovineClient,
handlers: list,
client_name: str = "BovineClient",
wait_time: int = 10,
):
while True:
await handle_connection(client, handlers)
logger.info(
"Disconnected from server for %s, reconnecting in %d seconds",
client_name,
wait_time,
)
await asyncio.sleep(wait_time)


async def loop(client_name, client_config, handlers):
while True:
try:
print(client_name)
pprint.pprint(client_config)
async with bovine.BovineClient(**client_config) as client:
print("client:", client)
await handle_connection_with_reconnect(
client, handlers, client_name=client_name
)
except Exception as e:
logger.exception("Something went wrong for %s", client_name)
logger.exception(e)
await asyncio.sleep(60)
94 changes: 79 additions & 15 deletions tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,14 @@
# Licensed under the MIT License.
import os
import json
import types
import socket
import pathlib
import aiohttp.resolver
import functools
import threading
import traceback
import contextlib
import unittest.mock
import multiprocessing
import pytest
Expand All @@ -20,21 +25,47 @@
content_type = "application/json"
payload = '{"foo": "bar"}'

old_socket_getaddrinfo = socket.getaddrinfo
old_create_sockets = hypercorn.config.Config.create_sockets


def socket_getaddrinfo_map_service_ports(services, host, *args, **kwargs):
# Map f"scitt.{handle_name}.example.com" to various local ports
if "scitt." not in host:
return old_socket_getaddrinfo(host, *args, **kwargs)
_, handle_name, _, _ = host.split(".")
if isinstance(services, (str, pathlib.Path)):
services_path = pathlib.Path(services)
services_content = services_path.read_text()
services_dict = json.loads(services_content)
services = {
handle_name: types.SimpleNameSpace(**service_dict)
for handle_name, service_dict in service_dict.items()
}
return [
(
socket.AF_INET,
socket.SOCK_STREAM,
6,
"",
("127.0.0.1", services[handle_name].port),
)
]


def execute_cli(argv):
return cli.main([str(v) for v in argv])


class Service:
def __init__(self, config, create_flask_app=None):
def __init__(self, config, create_flask_app=None, services=None):
self.config = config
self.create_flask_app = (
create_flask_app
if create_flask_app is not None
else server.create_flask_app
)
self.services = services

def __enter__(self):
app = self.create_flask_app(self.config)
Expand All @@ -43,7 +74,8 @@ def __enter__(self):
self.host = "127.0.0.1"
addr_queue = multiprocessing.Queue()
self.process = multiprocessing.Process(name="server", target=self.server_process,
args=(app, addr_queue,))
args=(app, addr_queue,
self.services))
self.process.start()
self.host = addr_queue.get(True)
self.port = addr_queue.get(True)
Expand All @@ -56,20 +88,52 @@ def __exit__(self, *args):
self.process.join()

@staticmethod
def server_process(app, addr_queue):
class MockConfig(hypercorn.config.Config):
def create_sockets(self, *args, **kwargs):
sockets = old_create_sockets(self, *args, **kwargs)
server_name, server_port = sockets.insecure_sockets[0].getsockname()
addr_queue.put(server_name)
addr_queue.put(server_port)
return sockets

def server_process(app, addr_queue, services):
try:
with unittest.mock.patch(
"quart.app.HyperConfig",
side_effect=MockConfig,
):
class MockResolver(aiohttp.resolver.DefaultResolver):
async def resolve(self, *args, **kwargs):
nonlocal services
print("MockResolver.getaddrinfo")
return socket_getaddrinfo_map_service_ports(services, *args, **kwargs)
with contextlib.ExitStack() as exit_stack:
exit_stack.enter_context(
unittest.mock.patch(
"aiohttp.connector.DefaultResolver",
side_effect=MockResolver,
)
)
class MockConfig(hypercorn.config.Config):
def create_sockets(self, *args, **kwargs):
sockets = old_create_sockets(self, *args, **kwargs)
server_name, server_port = sockets.insecure_sockets[0].getsockname()
addr_queue.put(server_name)
addr_queue.put(server_port)
# Ensure that connect calls to them resolve as we want
exit_stack.enter_context(
unittest.mock.patch(
"socket.getaddrinfo",
wraps=functools.partial(
socket_getaddrinfo_map_service_ports,
services,
)
)
)
# exit_stack.enter_context(
# unittest.mock.patch(
# "asyncio.base_events.BaseEventLoop.getaddrinfo",
# wraps=make_loop_getaddrinfo_map_service_ports(
# services,
# )
# )
# )
return sockets

exit_stack.enter_context(
unittest.mock.patch(
"quart.app.HyperConfig",
side_effect=MockConfig,
)
)
app.run(port=0)
except:
traceback.print_exc()
Expand Down
41 changes: 18 additions & 23 deletions tests/test_federation_activitypub_bovine.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
content_type,
payload,
execute_cli,
socket_getaddrinfo_map_service_ports,
)
from .test_docs import (
docutils_recursively_extract_nodes,
Expand All @@ -42,23 +43,6 @@
docs_dir = repo_root.joinpath("docs")
allowlisted_issuer = "did:web:example.org"

old_socket_getaddrinfo = socket.getaddrinfo

def socket_getaddrinfo_map_service_ports(services, host, *args, **kwargs):
# Map f"scitt.{handle_name}.example.com" to various local ports
if "scitt." not in host:
return old_socket_getaddrinfo(host, *args, **kwargs)
_, handle_name, _, _ = host.split(".")
return [
(
socket.AF_INET,
socket.SOCK_STREAM,
6,
"",
("127.0.0.1", services[handle_name].port),
)
]


def test_docs_federation_activitypub_bovine(tmp_path):
claim_path = tmp_path / "claim.cose"
Expand All @@ -77,17 +61,18 @@ def test_docs_federation_activitypub_bovine(tmp_path):
tmp_path.joinpath(name).write_text(content)

services = {}
services_path = tmp_path / "services.json"
for handle_name, following in {
"bob": {
"alice": {
"actor_id": "[email protected]",
},
},
"alice": {
"bob": {
"actor_id": "[email protected]",
},
},
# "alice": {
# "bob": {
# "actor_id": "[email protected]",
# },
# },
}.items():
middleware_config_path = (
tmp_path
Expand All @@ -114,7 +99,8 @@ def test_docs_federation_activitypub_bovine(tmp_path):
"workspace": tmp_path / handle_name / "workspace",
"error_rate": 0,
"use_lro": False,
}
},
services=services_path,
)

with contextlib.ExitStack() as exit_stack:
Expand All @@ -131,6 +117,15 @@ def test_docs_federation_activitypub_bovine(tmp_path):
# Start all the services
for handle_name, service in services.items():
services[handle_name] = exit_stack.enter_context(service)
# Serialize services
services_path.write_text(
json.dumps(
{
handle_name: {"port": service.port}
for handle_name, service in services.items()
}
)
)
# Test of resolution
assert (
socket.getaddrinfo(f"scitt.{handle_name}.example.com", 0)[0][-1][-1]
Expand Down

0 comments on commit 836bc08

Please sign in to comment.