From 791209f5574719962b102cdf32399262fd2f3cb3 Mon Sep 17 00:00:00 2001 From: Josh Warwick Date: Tue, 9 May 2023 21:52:55 +0100 Subject: [PATCH] Rejig concept to use middleware --- graphene_django/converter.py | 32 ++--------------------------- graphene_django/debug/middleware.py | 27 ++++++++++++++++++++++-- 2 files changed, 27 insertions(+), 32 deletions(-) diff --git a/graphene_django/converter.py b/graphene_django/converter.py index 760973f09..375d68312 100644 --- a/graphene_django/converter.py +++ b/graphene_django/converter.py @@ -1,7 +1,5 @@ from collections import OrderedDict from functools import singledispatch, wraps -from asyncio import get_running_loop -from asgiref.sync import sync_to_async from django.db import models from django.utils.encoding import force_str @@ -267,20 +265,7 @@ def dynamic_type(): if not _type: return - class CustomField(Field): - def wrap_resolve(self, parent_resolver): - resolver = super().wrap_resolve(parent_resolver) - - try: - get_running_loop() - except RuntimeError: - pass - else: - resolver = sync_to_async(resolver) - - return resolver - - return CustomField(_type, required=not field.null) + return Field(_type, required=not field.null) return Dynamic(dynamic_type) @@ -335,20 +320,7 @@ def dynamic_type(): if not _type: return - class CustomField(Field): - def wrap_resolve(self, parent_resolver): - resolver = super().wrap_resolve(parent_resolver) - - try: - get_running_loop() - except RuntimeError: - pass - else: - resolver = sync_to_async(resolver) - - return resolver - - return CustomField( + return Field( _type, description=get_django_field_description(field), required=not field.null, diff --git a/graphene_django/debug/middleware.py b/graphene_django/debug/middleware.py index d3052a14a..40a951a98 100644 --- a/graphene_django/debug/middleware.py +++ b/graphene_django/debug/middleware.py @@ -1,7 +1,7 @@ from django.db import connections -from promise import Promise - +from asgiref.sync import sync_to_async +import inspect from .sql.tracking import unwrap_cursor, wrap_cursor from .exception.formating import wrap_exception from .types import DjangoDebug @@ -69,3 +69,26 @@ def resolve(self, next, root, info, **args): return context.django_debug.on_resolve_error(e) context.django_debug.add_result(result) return result + + +class DjangoSyncRequiredMiddleware: + def resolve(self, next, root, info, **args): + parent_type = info.parent_type + + ## Anytime the parent is a DjangoObject type + # and we're resolving a sync field, we need to wrap it in a sync_to_async + if hasattr(parent_type, "graphene_type") and hasattr( + parent_type.graphene_type._meta, "model" + ): + if not inspect.iscoroutinefunction(next): + return sync_to_async(next)(root, info, **args) + + ## In addition, if we're resolving to a DjangoObject type + # we likely need to wrap it in a sync_to_async as well + if hasattr(info.return_type, "graphene_type") and hasattr( + info.return_type.graphene_type._meta, "model" + ): + if not info.is_awaitable(next): + return sync_to_async(next)(root, info, **args) + + return next(root, info, **args)