diff --git a/src/webgwas_backend/test_main.py b/src/webgwas_backend/test_main.py index e865e57..1273744 100644 --- a/src/webgwas_backend/test_main.py +++ b/src/webgwas_backend/test_main.py @@ -19,7 +19,7 @@ WebGWASResponse, WebGWASResult, ) -from webgwas_backend.worker import Worker +from webgwas_backend.worker import TestWorker def setup_db(session: Session, rootdir: pathlib.Path): @@ -137,7 +137,7 @@ def client(): n_workers=2, indirect_gwas=IndirectGWASSettings(batch_size=10000), ) - worker = Worker(settings) + worker = TestWorker(settings) def get_worker_override(): return worker @@ -146,7 +146,6 @@ def get_worker_override(): app.dependency_overrides[get_worker] = get_worker_override with TestClient(app) as client: yield client - worker.shutdown() def test_get_cohorts(client): diff --git a/src/webgwas_backend/worker.py b/src/webgwas_backend/worker.py index 5c6a355..200ceb1 100644 --- a/src/webgwas_backend/worker.py +++ b/src/webgwas_backend/worker.py @@ -59,3 +59,24 @@ def get_results(self, request_id: str) -> WebGWASResult: def shutdown(self): self.executor.shutdown(wait=True, cancel_futures=True) self.manager.shutdown() + + +class TestWorker(Worker): + __test__ = False + + def __init__(self, settings: Settings): + self.s3_dry_run = settings.dry_run + self.s3_bucket = settings.s3_bucket + self.batch_size = settings.indirect_gwas.batch_size + self.id_to_result: dict[str, WebGWASResult] = {} + + def submit(self, request: WebGWASRequestID): + logger.info(f"Submitting request: {request}") + result = self.handle_request( + request, self.s3_dry_run, self.s3_bucket, self.batch_size + ) + self.id_to_result[request.id] = result + logger.info(f"Queued request: {request.id}") + + def get_results(self, request_id: str) -> WebGWASResult: + return self.id_to_result[request_id]