Skip to content

Commit

Permalink
Change incorrect subscribe return type to a GraphQLError rather than …
Browse files Browse the repository at this point in the history
…systems error

Replicates graphql/graphql-js@ea1894a
  • Loading branch information
Cito committed Nov 3, 2022
1 parent 47ecdb3 commit 8d7d0f6
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 16 deletions.
19 changes: 8 additions & 11 deletions src/graphql/execution/subscribe.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,18 +145,8 @@ async def create_source_event_stream(
return ExecutionResult(data=None, errors=context)

try:
event_stream = await execute_subscription(context)

# Assert field returned an event stream, otherwise yield an error.
if not isinstance(event_stream, AsyncIterable):
raise TypeError(
"Subscription field must return AsyncIterable."
f" Received: {inspect(event_stream)}."
)
return event_stream

return await execute_subscription(context)
except GraphQLError as error:
# Report it as an ExecutionResult, containing only errors and no data.
return ExecutionResult(data=None, errors=[error])


Expand Down Expand Up @@ -207,6 +197,13 @@ async def execute_subscription(context: ExecutionContext) -> AsyncIterable[Any]:
if isinstance(event_stream, Exception):
raise event_stream

# Assert field returned an event stream, otherwise yield an error.
if not isinstance(event_stream, AsyncIterable):
raise GraphQLError(
"Subscription field must return AsyncIterable."
f" Received: {inspect(event_stream)}."
)

return event_stream
except Exception as error:
raise located_error(error, field_nodes, path.as_list())
15 changes: 10 additions & 5 deletions tests/execution/test_subscribe.py
Original file line number Diff line number Diff line change
Expand Up @@ -354,11 +354,16 @@ async def should_pass_through_unexpected_errors_thrown_in_subscribe():
@mark.asyncio
@mark.filterwarnings("ignore:.* was never awaited:RuntimeWarning")
async def throws_an_error_if_subscribe_does_not_return_an_iterator():
with raises(TypeError) as exc_info:
await subscribe_with_bad_fn(lambda _obj, _info: "test")

assert str(exc_info.value) == (
"Subscription field must return AsyncIterable. Received: 'test'."
assert await subscribe_with_bad_fn(lambda _obj, _info: "test") == (
None,
[
{
"message": "Subscription field must return AsyncIterable."
" Received: 'test'.",
"locations": [(1, 16)],
"path": ["foo"],
}
],
)

@mark.asyncio
Expand Down

0 comments on commit 8d7d0f6

Please sign in to comment.