From 18690a84d9e61fd33cb89052a5b11f44bc9f5127 Mon Sep 17 00:00:00 2001 From: alangenfeld Date: Thu, 22 Jun 2023 16:38:25 -0500 Subject: [PATCH] [graphql] support async resolvers --- python_modules/dagit/dagit/graphql.py | 28 ++++++++++++------- .../dagit/dagit_tests/webserver/test_app.py | 10 +++++++ .../dagster_graphql/schema/test.py | 7 +++++ 3 files changed, 35 insertions(+), 10 deletions(-) diff --git a/python_modules/dagit/dagit/graphql.py b/python_modules/dagit/dagit/graphql.py index 4ba4dc05bc14c..4ba1d2a1c89e4 100644 --- a/python_modules/dagit/dagit/graphql.py +++ b/python_modules/dagit/dagit/graphql.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from asyncio import Task, get_event_loop +from asyncio import Task, get_event_loop, run from enum import Enum from typing import Any, AsyncGenerator, Dict, List, Optional, Sequence, Tuple, Union, cast @@ -225,15 +225,23 @@ async def execute_graphql_request( variables: Optional[Dict[str, Any]], operation_name: Optional[str], ) -> ExecutionResult: - # use run_in_threadpool since underlying schema is sync - return await run_in_threadpool( - self._graphql_schema.execute, - query, - variables=variables, - operation_name=operation_name, - context=self.make_request_context(request), - middleware=self._graphql_middleware, - ) + # run each query in a separate thread, as much of the schema is sync/blocking + # use execute_async to allow async resolvers to facilitate parallel of io wait + + request_context = self.make_request_context(request) + + def _graphql_request(): + return run( + self._graphql_schema.execute_async( + query, + variables=variables, + operation_name=operation_name, + context=request_context, + middleware=self._graphql_middleware, + ) + ) + + return await run_in_threadpool(_graphql_request) async def execute_graphql_subscription( self, diff --git a/python_modules/dagit/dagit_tests/webserver/test_app.py b/python_modules/dagit/dagit_tests/webserver/test_app.py index 8eda9d33674d9..e9896abd5fa2e 100644 --- a/python_modules/dagit/dagit_tests/webserver/test_app.py +++ b/python_modules/dagit/dagit_tests/webserver/test_app.py @@ -268,3 +268,13 @@ def test_download_compute(instance, test_client: TestClient): response = test_client.get(f"/download/{run_id}/jonx/stdout") assert response.status_code == 404 + + +def test_async(test_client: TestClient): + response = test_client.post( + "/graphql", + params={"query": "{test{asyncString}}"}, + ) + assert response.status_code == 200, response.text + result = response.json() + assert result["data"]["test"]["asyncString"] == "slept", result diff --git a/python_modules/dagster-graphql/dagster_graphql/schema/test.py b/python_modules/dagster-graphql/dagster_graphql/schema/test.py index 01523d51701b9..035557985d601 100644 --- a/python_modules/dagster-graphql/dagster_graphql/schema/test.py +++ b/python_modules/dagster-graphql/dagster_graphql/schema/test.py @@ -1,3 +1,5 @@ +import asyncio + import graphene @@ -6,6 +8,11 @@ class Meta: name = "TestFields" alwaysException = graphene.String() + asyncString = graphene.String() def resolve_alwaysException(self, _): raise Exception("as advertised") + + async def resolve_asyncString(self, _): + await asyncio.sleep(0) + return "slept"