diff --git a/tests/test_cli.py b/tests/test_cli.py index 5dd16120..61ab70ee 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -2,12 +2,17 @@ # Licensed under the MIT License. import os import json +import socket import threading +import traceback +import unittest.mock +import multiprocessing import pytest import jwt import jwcrypto +import hypercorn.config +# TODO Remove in favor of Quart from flask import Flask, jsonify -# from werkzeug.serving import make_server from scitt_emulator import cli, server from scitt_emulator.oidc import OIDCAuthMiddleware @@ -35,71 +40,39 @@ def __enter__(self): if hasattr(app, "service_parameters_path"): self.service_parameters_path = app.service_parameters_path self.host = "127.0.0.1" - # self.server = make_server(self.host, 0, app) - # TODO Wrapper on run pass port via Queue, sys audithook or config.log - # for app.run (hypercorn asyncio.run.worker_serve) - def mythread(): - def capture_port(event, *args): - print("event", event) - if event not in ("socket.bind", "socket.__new__"): - return - socket = args[0][0] - print("socket", socket) - try: - breakpoint() - print("socket", socket.getsockname()) - except: - import traceback - traceback.print_exc() - # sys.addaudithook(capture_port) - import socket - - old_socket_bind = socket.socket.bind - - def socket_bind(*args, **kwargs): - print(args, kwargs) - return old_socket_bind(*args, **kwargs) - - import hypercorn.config - - old_create_sockets = hypercorn.config.Config.create_sockets - - class MockConfig(hypercorn.config.Config): - def create_sockets(self, *args, **kwargs): - sockets = old_create_sockets(self, *args, **kwargs) - port = sockets.insecure_sockets[0].getsockname()[1] - return sockets - - try: - import unittest.mock - with unittest.mock.patch( - "quart.app.HyperConfig", - side_effect=MockConfig, - ): - - print("running...") - print() - print(app.run(port=0)) - except: - import traceback - traceback.print_exc() - - import multiprocessing - self.thread = multiprocessing.Process(name="server", target=mythread) - self.thread.start() - - import time - time.sleep(60) - sys.exit(0) - - port = self.server.port - self.url = f"http://{self.host}:{port}" + addr_queue = multiprocessing.Queue() + self.process = multiprocessing.Process(name="server", target=self.server_process, + args=(addr_queue,)) + self.process.start() + self.host = addr_queue.get(True) + self.port = addr_queue.get(True) + self.url = f"http://{self.host}:{self.port}" app.url = self.url return self def __exit__(self, *args): - self.server.shutdown() - self.thread.join() + self.process.join() + + @staicmethod + def server_process(addr_queue): + old_create_sockets = hypercorn.config.Config.create_sockets + + class MockConfig(hypercorn.config.Config): + def create_sockets(self, *args, **kwargs): + sockets = old_create_sockets(self, *args, **kwargs) + server_name, server_port = sockets.insecure_sockets[0].getsockname() + addr_queue.put(server_name) + addr_queue.put(server_port) + return sockets + + try: + with unittest.mock.patch( + "quart.app.HyperConfig", + side_effect=MockConfig, + ): + app.run(port=0) + except: + traceback.print_exc() @pytest.mark.parametrize( "use_lro", [True, False],