diff --git a/src/graphql/execution/subscribe.py b/src/graphql/execution/subscribe.py index 992fb051..358487ec 100644 --- a/src/graphql/execution/subscribe.py +++ b/src/graphql/execution/subscribe.py @@ -145,18 +145,8 @@ async def create_source_event_stream( return ExecutionResult(data=None, errors=context) try: - event_stream = await execute_subscription(context) - - # Assert field returned an event stream, otherwise yield an error. - if not isinstance(event_stream, AsyncIterable): - raise TypeError( - "Subscription field must return AsyncIterable." - f" Received: {inspect(event_stream)}." - ) - return event_stream - + return await execute_subscription(context) except GraphQLError as error: - # Report it as an ExecutionResult, containing only errors and no data. return ExecutionResult(data=None, errors=[error]) @@ -207,6 +197,13 @@ async def execute_subscription(context: ExecutionContext) -> AsyncIterable[Any]: if isinstance(event_stream, Exception): raise event_stream + # Assert field returned an event stream, otherwise yield an error. + if not isinstance(event_stream, AsyncIterable): + raise GraphQLError( + "Subscription field must return AsyncIterable." + f" Received: {inspect(event_stream)}." + ) + return event_stream except Exception as error: raise located_error(error, field_nodes, path.as_list()) diff --git a/tests/execution/test_subscribe.py b/tests/execution/test_subscribe.py index 24655915..94a4c4f9 100644 --- a/tests/execution/test_subscribe.py +++ b/tests/execution/test_subscribe.py @@ -354,11 +354,16 @@ async def should_pass_through_unexpected_errors_thrown_in_subscribe(): @mark.asyncio @mark.filterwarnings("ignore:.* was never awaited:RuntimeWarning") async def throws_an_error_if_subscribe_does_not_return_an_iterator(): - with raises(TypeError) as exc_info: - await subscribe_with_bad_fn(lambda _obj, _info: "test") - - assert str(exc_info.value) == ( - "Subscription field must return AsyncIterable. Received: 'test'." + assert await subscribe_with_bad_fn(lambda _obj, _info: "test") == ( + None, + [ + { + "message": "Subscription field must return AsyncIterable." + " Received: 'test'.", + "locations": [(1, 16)], + "path": ["foo"], + } + ], ) @mark.asyncio