Skip to content

Commit

Permalink
Use is_awaitable more consistently
Browse files Browse the repository at this point in the history
  • Loading branch information
Cito committed Jun 4, 2023
1 parent a46cc84 commit 1a62f30
Show file tree
Hide file tree
Showing 10 changed files with 30 additions and 34 deletions.
9 changes: 4 additions & 5 deletions src/graphql/execution/execute.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from asyncio import Event, as_completed, ensure_future, gather, shield, sleep, wait_for
from collections.abc import Mapping
from contextlib import suppress
from inspect import isawaitable
from typing import (
Any,
AsyncGenerator,
Expand Down Expand Up @@ -1646,7 +1645,7 @@ def map_source_to_response(
async def callback(payload: Any) -> AsyncGenerator:
result = execute_impl(self.build_per_event_execution_context(payload))
return ensure_async_iterable(
await result if isawaitable(result) else result # type: ignore
await result if self.is_awaitable(result) else result # type: ignore
)

return flatten_async_iterable(map_async_iterable(result_or_stream, callback))
Expand Down Expand Up @@ -2124,7 +2123,7 @@ async def await_result() -> Any:


def assume_not_awaitable(_value: Any) -> bool:
"""Replacement for isawaitable if everything is assumed to be synchronous."""
"""Replacement for is_awaitable if everything is assumed to be synchronous."""
return False


Expand Down Expand Up @@ -2172,10 +2171,10 @@ def execute_sync(
)

# Assert that the execution was synchronous.
if isawaitable(result) or isinstance(
if default_is_awaitable(result) or isinstance(
result, ExperimentalIncrementalExecutionResults
):
if isawaitable(result):
if default_is_awaitable(result):
ensure_future(cast(Awaitable[ExecutionResult], result)).cancel()
raise RuntimeError("GraphQL execution failed to complete synchronously.")

Expand Down
8 changes: 4 additions & 4 deletions src/graphql/graphql.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
from asyncio import ensure_future
from inspect import isawaitable
from typing import Any, Awaitable, Callable, Dict, Optional, Type, Union, cast

from .error import GraphQLError
from .execution import ExecutionContext, ExecutionResult, Middleware, execute
from .language import Source, parse
from .pyutils import AwaitableOrValue
from .pyutils import is_awaitable as default_is_awaitable
from .type import (
GraphQLFieldResolver,
GraphQLSchema,
Expand Down Expand Up @@ -92,14 +92,14 @@ async def graphql(
is_awaitable,
)

if isawaitable(result):
if default_is_awaitable(result):
return await cast(Awaitable[ExecutionResult], result)

return cast(ExecutionResult, result)


def assume_not_awaitable(_value: Any) -> bool:
"""Replacement for isawaitable if everything is assumed to be synchronous."""
"""Replacement for is_awaitable if everything is assumed to be synchronous."""
return False


Expand Down Expand Up @@ -145,7 +145,7 @@ def graphql_sync(
)

# Assert that the execution was synchronous.
if isawaitable(result):
if default_is_awaitable(result):
ensure_future(cast(Awaitable[ExecutionResult], result)).cancel()
raise RuntimeError("GraphQL execution failed to complete synchronously.")

Expand Down
6 changes: 3 additions & 3 deletions src/graphql/pyutils/is_awaitable.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,10 @@


def is_awaitable(value: Any) -> TypeGuard[Awaitable]:
"""Return true if object can be passed to an ``await`` expression.
"""Return True if object can be passed to an ``await`` expression.
Instead of testing if the object is an instance of abc.Awaitable, it checks
the existence of an `__await__` attribute. This is much faster.
Instead of testing whether the object is an instance of abc.Awaitable, we
check the existence of an `__await__` attribute. This is much faster.
"""
return (
# check for coroutine objects
Expand Down
5 changes: 3 additions & 2 deletions src/graphql/pyutils/simple_pub_sub.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from __future__ import annotations # Python < 3.10

from asyncio import Future, Queue, create_task, get_running_loop, sleep
from inspect import isawaitable
from typing import Any, AsyncIterator, Callable, Optional, Set

from .is_awaitable import is_awaitable


__all__ = ["SimplePubSub", "SimplePubSubIterator"]

Expand All @@ -25,7 +26,7 @@ def emit(self, event: Any) -> bool:
"""Emit an event."""
for subscriber in self.subscribers:
result = subscriber(event)
if isawaitable(result):
if is_awaitable(result):
create_task(result) # type: ignore
return bool(self.subscribers)

Expand Down
4 changes: 2 additions & 2 deletions tests/benchmarks/test_async_iterable.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import asyncio
from inspect import isawaitable

from graphql import ExecutionResult, build_schema, execute, parse
from graphql.pyutils import is_awaitable


schema = build_schema("type Query { listField: [String] }")
Expand All @@ -18,7 +18,7 @@ async def listField(info_):

async def execute_async() -> ExecutionResult:
result = execute(schema, document, Data())
assert isawaitable(result)
assert is_awaitable(result)
return await result


Expand Down
4 changes: 2 additions & 2 deletions tests/execution/test_abstract.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
from inspect import isawaitable
from typing import Any, NamedTuple, Optional

from pytest import mark

from graphql.execution import ExecutionResult, execute, execute_sync
from graphql.language import parse
from graphql.pyutils import is_awaitable
from graphql.type import (
GraphQLBoolean,
GraphQLField,
Expand Down Expand Up @@ -43,7 +43,7 @@ async def execute_query(
result = (execute_sync if sync else execute)(
schema, document, root_value
) # type: ignore
if not sync and isawaitable(result):
if not sync and is_awaitable(result):
result = await result
assert isinstance(result, ExecutionResult)
return result
Expand Down
5 changes: 2 additions & 3 deletions tests/execution/test_sync.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
from inspect import isawaitable

from pytest import mark, raises

from graphql import graphql_sync
from graphql.execution import execute, execute_sync
from graphql.language import parse
from graphql.pyutils import is_awaitable
from graphql.type import GraphQLField, GraphQLObjectType, GraphQLSchema, GraphQLString
from graphql.validation import validate

Expand Down Expand Up @@ -57,7 +56,7 @@ def does_not_return_an_awaitable_if_mutation_fields_are_all_synchronous():
async def returns_an_awaitable_if_any_field_is_asynchronous():
doc = "query Example { syncField, asyncField }"
result = execute(schema, parse(doc), "rootValue")
assert isawaitable(result)
assert is_awaitable(result)
assert await result == (
{"syncField": "rootValue", "asyncField": "rootValue"},
None,
Expand Down
11 changes: 5 additions & 6 deletions tests/pyutils/test_async_reduce.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
from functools import reduce
from inspect import isawaitable

from pytest import mark

from graphql.pyutils import async_reduce
from graphql.pyutils import async_reduce, is_awaitable


def describe_async_reduce():
Expand All @@ -25,7 +24,7 @@ def callback(accumulator, current_value):

values = ["bar", "baz"]
result = async_reduce(callback, values, "foo")
assert not isawaitable(result)
assert not is_awaitable(result)
assert result == "foo-bar-baz"

@mark.asyncio
Expand All @@ -38,7 +37,7 @@ def callback(accumulator, current_value):

values = ["bar", "baz"]
result = async_reduce(callback, values, async_initial_value())
assert isawaitable(result)
assert is_awaitable(result)
assert await result == "foo-bar-baz"

@mark.asyncio
Expand All @@ -48,7 +47,7 @@ async def async_callback(accumulator, current_value):

values = ["bar", "baz"]
result = async_reduce(async_callback, values, "foo")
assert isawaitable(result)
assert is_awaitable(result)
assert await result == "foo-bar-baz"

@mark.asyncio
Expand All @@ -60,5 +59,5 @@ async def async_callback(accumulator, current_value):
return accumulator * current_value

result = async_reduce(async_callback, range(6, 9), async_initial_value())
assert isawaitable(result)
assert is_awaitable(result)
assert await result == 42
7 changes: 3 additions & 4 deletions tests/pyutils/test_simple_pub_sub.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
from asyncio import sleep
from inspect import isawaitable

from pytest import mark, raises

from graphql.pyutils import SimplePubSub
from graphql.pyutils import SimplePubSub, is_awaitable


def describe_simple_pub_sub():
Expand All @@ -22,9 +21,9 @@ async def subscribe_async_iterator_mock():

# Read ahead
i3 = await iterator.__anext__()
assert isawaitable(i3)
assert is_awaitable(i3)
i4 = await iterator.__anext__()
assert isawaitable(i4)
assert is_awaitable(i4)

# Publish
assert pubsub.emit("Coconut") is True
Expand Down
5 changes: 2 additions & 3 deletions tests/test_user_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from asyncio import create_task, sleep, wait
from collections import defaultdict
from enum import Enum
from inspect import isawaitable
from typing import Any, AsyncIterable, Dict, List, NamedTuple, Optional

from pytest import fixture, mark
Expand All @@ -29,7 +28,7 @@
parse,
subscribe,
)
from graphql.pyutils import SimplePubSub, SimplePubSubIterator
from graphql.pyutils import SimplePubSub, SimplePubSubIterator, is_awaitable


class User(NamedTuple):
Expand Down Expand Up @@ -157,7 +156,7 @@ async def subscribe_user(_root, info, id=None):
"""Subscribe to mutations of a specific user object or all user objects"""
async_iterator = info.context["registry"].event_iterator(id)
async for event in async_iterator:
yield await event if isawaitable(event) else event # pragma: no cover exit
yield await event if is_awaitable(event) else event # pragma: no cover exit


# noinspection PyShadowingBuiltins,PyUnusedLocal
Expand Down

0 comments on commit 1a62f30

Please sign in to comment.