Skip to content

Commit

Permalink
Allow injecting custom data to custom execution context (#226)
Browse files Browse the repository at this point in the history
  • Loading branch information
Cito committed Jan 25, 2025
1 parent c685d84 commit 3818653
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 15 deletions.
28 changes: 14 additions & 14 deletions src/graphql/execution/execute.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from asyncio import ensure_future, gather, shield, wait_for
from contextlib import suppress
from copy import copy
from typing import (
Any,
AsyncGenerator,
Expand Down Expand Up @@ -219,6 +220,7 @@ def build(
subscribe_field_resolver: GraphQLFieldResolver | None = None,
middleware: Middleware | None = None,
is_awaitable: Callable[[Any], bool] | None = None,
**custom_args: Any,
) -> list[GraphQLError] | ExecutionContext:
"""Build an execution context
Expand Down Expand Up @@ -292,24 +294,14 @@ def build(
IncrementalPublisher(),
middleware_manager,
is_awaitable,
**custom_args,
)

def build_per_event_execution_context(self, payload: Any) -> ExecutionContext:
"""Create a copy of the execution context for usage with subscribe events."""
return self.__class__(
self.schema,
self.fragments,
payload,
self.context_value,
self.operation,
self.variable_values,
self.field_resolver,
self.type_resolver,
self.subscribe_field_resolver,
self.incremental_publisher,
self.middleware_manager,
self.is_awaitable,
)
context = copy(self)
context.root_value = payload
return context

def execute_operation(
self, initial_result_record: InitialResultRecord
Expand Down Expand Up @@ -1709,6 +1701,7 @@ def execute(
middleware: Middleware | None = None,
execution_context_class: type[ExecutionContext] | None = None,
is_awaitable: Callable[[Any], bool] | None = None,
**custom_context_args: Any,
) -> AwaitableOrValue[ExecutionResult]:
"""Execute a GraphQL operation.
Expand Down Expand Up @@ -1741,6 +1734,7 @@ def execute(
middleware,
execution_context_class,
is_awaitable,
**custom_context_args,
)
if isinstance(result, ExecutionResult):
return result
Expand Down Expand Up @@ -1769,6 +1763,7 @@ def experimental_execute_incrementally(
middleware: Middleware | None = None,
execution_context_class: type[ExecutionContext] | None = None,
is_awaitable: Callable[[Any], bool] | None = None,
**custom_context_args: Any,
) -> AwaitableOrValue[ExecutionResult | ExperimentalIncrementalExecutionResults]:
"""Execute GraphQL operation incrementally (internal implementation).
Expand Down Expand Up @@ -1797,6 +1792,7 @@ def experimental_execute_incrementally(
subscribe_field_resolver,
middleware,
is_awaitable,
**custom_context_args,
)

# Return early errors if execution context failed.
Expand Down Expand Up @@ -2127,6 +2123,7 @@ def subscribe(
subscribe_field_resolver: GraphQLFieldResolver | None = None,
execution_context_class: type[ExecutionContext] | None = None,
middleware: MiddlewareManager | None = None,
**custom_context_args: Any,
) -> AwaitableOrValue[AsyncIterator[ExecutionResult] | ExecutionResult]:
"""Create a GraphQL subscription.
Expand Down Expand Up @@ -2167,6 +2164,7 @@ def subscribe(
type_resolver,
subscribe_field_resolver,
middleware=middleware,
**custom_context_args,
)

# Return early errors if execution context failed.
Expand Down Expand Up @@ -2202,6 +2200,7 @@ def create_source_event_stream(
type_resolver: GraphQLTypeResolver | None = None,
subscribe_field_resolver: GraphQLFieldResolver | None = None,
execution_context_class: type[ExecutionContext] | None = None,
**custom_context_args: Any,
) -> AwaitableOrValue[AsyncIterable[Any] | ExecutionResult]:
"""Create source event stream
Expand Down Expand Up @@ -2238,6 +2237,7 @@ def create_source_event_stream(
field_resolver,
type_resolver,
subscribe_field_resolver,
**custom_context_args,
)

# Return early errors if execution context failed.
Expand Down
16 changes: 15 additions & 1 deletion tests/execution/test_customize.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,10 @@ def uses_a_custom_execution_context_class():
)

class TestExecutionContext(ExecutionContext):
def __init__(self, *args, **kwargs):
assert kwargs.pop("custom_arg", None) == "baz"
super().__init__(*args, **kwargs)

def execute_field(
self,
parent_type,
Expand All @@ -62,7 +66,12 @@ def execute_field(
)
return result * 2 # type: ignore

assert execute(schema, query, execution_context_class=TestExecutionContext) == (
assert execute(
schema,
query,
execution_context_class=TestExecutionContext,
custom_arg="baz",
) == (
{"foo": "barbar"},
None,
)
Expand Down Expand Up @@ -101,6 +110,10 @@ async def custom_foo():
@pytest.mark.asyncio
async def uses_a_custom_execution_context_class():
class TestExecutionContext(ExecutionContext):
def __init__(self, *args, **kwargs):
assert kwargs.pop("custom_arg", None) == "baz"
super().__init__(*args, **kwargs)

def build_resolve_info(self, *args, **kwargs):
resolve_info = super().build_resolve_info(*args, **kwargs)
resolve_info.context["foo"] = "bar"
Expand Down Expand Up @@ -132,6 +145,7 @@ def resolve_foo(message, _info):
document,
context_value={},
execution_context_class=TestExecutionContext,
custom_arg="baz",
)
assert isasyncgen(subscription)

Expand Down

0 comments on commit 3818653

Please sign in to comment.