diff --git a/graphene_django/compat.py b/graphene_django/compat.py index b3d160a1..5056baa6 100644 --- a/graphene_django/compat.py +++ b/graphene_django/compat.py @@ -1,10 +1,12 @@ import sys +from collections.abc import Callable from pathlib import PurePath # For backwards compatibility, we import JSONField to have it available for import via # this compat module (https://github.com/graphql-python/graphene-django/issues/1428). # Django's JSONField is available in Django 3.2+ (the minimum version we support) -from django.db.models import JSONField +import django +from django.db import models class MissingType: @@ -34,7 +36,7 @@ def __init__(self, *args, **kwargs): ] ): - class ArrayField(JSONField): + class ArrayField(models.JSONField): def __init__(self, *args, **kwargs): if len(args) > 0: self.base_field = args[0] @@ -42,3 +44,41 @@ def __init__(self, *args, **kwargs): else: ArrayField = MissingType + + +try: + from django.utils.choices import normalize_choices +except ImportError: + + def normalize_choices(choices): + if isinstance(choices, type) and issubclass(choices, models.Choices): + choices = choices.choices + + if isinstance(choices, Callable): + choices = choices() + + # In restframework==3.15.0, choices are not passed + # as OrderedDict anymore, so it's safer to check + # for a dict + if isinstance(choices, dict): + choices = choices.items() + + return choices + + +def get_choices_as_class(choices_class): + if django.VERSION >= (5, 0): + return choices_class + else: + return choices_class.choices + + +def get_choices_as_callable(choices_class): + if django.VERSION >= (5, 0): + + def choices(): + return choices_class.choices + + return choices + else: + return choices_class.choices diff --git a/graphene_django/converter.py b/graphene_django/converter.py index 7eba22a1..12480a0d 100644 --- a/graphene_django/converter.py +++ b/graphene_django/converter.py @@ -1,5 +1,4 @@ import inspect -from collections.abc import Callable from functools import partial, singledispatch, wraps from django.db import models @@ -37,7 +36,7 @@ from graphql import assert_valid_name as assert_name from graphql.pyutils import register_description -from .compat import ArrayField, HStoreField, RangeField +from .compat import ArrayField, HStoreField, RangeField, normalize_choices from .fields import DjangoConnectionField, DjangoListField from .settings import graphene_settings from .utils.str_converters import to_const @@ -61,6 +60,24 @@ def wrapped_resolver(*args, **kwargs): return blank_field_wrapper(resolver) +class EnumValueField(Field): + def wrap_resolve(self, parent_resolver): + resolver = self.resolver or parent_resolver + + # create custom resolver + def enum_field_wrapper(func): + @wraps(func) + def wrapped_resolver(*args, **kwargs): + return_value = func(*args, **kwargs) + if isinstance(return_value, models.Choices): + return_value = return_value.value + return return_value + + return wrapped_resolver + + return enum_field_wrapper(resolver) + + def convert_choice_name(name): name = to_const(force_str(name)) try: @@ -72,15 +89,7 @@ def convert_choice_name(name): def get_choices(choices): converted_names = [] - if isinstance(choices, Callable): - choices = choices() - - # In restframework==3.15.0, choices are not passed - # as OrderedDict anymore, so it's safer to check - # for a dict - if isinstance(choices, dict): - choices = choices.items() - + choices = normalize_choices(choices) for value, help_text in choices: if isinstance(help_text, (tuple, list)): yield from get_choices(help_text) @@ -157,7 +166,7 @@ def convert_django_field_with_choices( converted = EnumCls( description=get_django_field_description(field), required=required - ).mount_as(BlankValueField) + ).mount_as(EnumValueField) else: converted = convert_django_field(field, registry) if registry is not None: diff --git a/graphene_django/forms/types.py b/graphene_django/forms/types.py index 0e311e5d..68ffa663 100644 --- a/graphene_django/forms/types.py +++ b/graphene_django/forms/types.py @@ -3,7 +3,7 @@ from graphene.types.inputobjecttype import InputObjectType from graphene.utils.str_converters import to_camel_case -from ..converter import BlankValueField +from ..converter import EnumValueField from ..types import ErrorType # noqa Import ErrorType for backwards compatibility from .mutation import fields_for_form @@ -57,11 +57,10 @@ def mutate(_root, _info, data): if ( object_type and name in object_type._meta.fields - and isinstance(object_type._meta.fields[name], BlankValueField) + and isinstance(object_type._meta.fields[name], EnumValueField) ): - # Field type BlankValueField here means that field + # Field type EnumValueField here means that field # with choices have been converted to enum - # (BlankValueField is using only for that task ?) setattr(cls, name, cls.get_enum_cnv_cls_instance(name, object_type)) elif ( object_type diff --git a/graphene_django/tests/models.py b/graphene_django/tests/models.py index ece1bb6d..d9b781ca 100644 --- a/graphene_django/tests/models.py +++ b/graphene_django/tests/models.py @@ -1,9 +1,16 @@ from django.db import models from django.utils.translation import gettext_lazy as _ +from graphene_django.compat import get_choices_as_callable, get_choices_as_class + CHOICES = ((1, "this"), (2, _("that"))) +class TypedChoice(models.IntegerChoices): + CHOICE_THIS = 1 + CHOICE_THAT = 2 + + class Person(models.Model): name = models.CharField(max_length=30) parent = models.ForeignKey( @@ -51,6 +58,15 @@ class Reporter(models.Model): email = models.EmailField() pets = models.ManyToManyField("self") a_choice = models.IntegerField(choices=CHOICES, null=True, blank=True) + typed_choice = models.IntegerField( + choices=TypedChoice.choices, null=True, blank=True + ) + class_choice = models.IntegerField( + choices=get_choices_as_class(TypedChoice), null=True, blank=True + ) + callable_choice = models.IntegerField( + choices=get_choices_as_callable(TypedChoice), null=True, blank=True + ) objects = models.Manager() doe_objects = DoeReporterManager() fans = models.ManyToManyField(Person) diff --git a/graphene_django/tests/test_converter.py b/graphene_django/tests/test_converter.py index 2f8b1d51..05747566 100644 --- a/graphene_django/tests/test_converter.py +++ b/graphene_django/tests/test_converter.py @@ -25,7 +25,7 @@ ) from ..registry import Registry from ..types import DjangoObjectType -from .models import Article, Film, FilmDetails, Reporter +from .models import Article, Film, FilmDetails, Reporter, TypedChoice # from graphene.core.types.custom_scalars import DateTime, Time, JSONString @@ -475,3 +475,44 @@ def resolve_reporter(root, info): assert result.data == { "reporter": {"firstName": "Bridget", "aChoice": None}, } + + +def test_typed_choice_value(): + """Test that choice fields with blank values work""" + + class ReporterType(DjangoObjectType): + class Meta: + model = Reporter + fields = ("typed_choice", "class_choice", "callable_choice") + + class Query(graphene.ObjectType): + reporter = graphene.Field(ReporterType) + + def resolve_reporter(root, info): + return Reporter( + typed_choice=TypedChoice.CHOICE_THIS, + class_choice=TypedChoice.CHOICE_THAT, + callable_choice=TypedChoice.CHOICE_THIS, + ) + + schema = graphene.Schema(query=Query) + + result = schema.execute( + """ + query { + reporter { + typedChoice + classChoice + callableChoice + } + } + """ + ) + assert not result.errors + assert result.data == { + "reporter": { + "typedChoice": "A_1", + "classChoice": "A_2", + "callableChoice": "A_1", + }, + } diff --git a/graphene_django/tests/test_schema.py b/graphene_django/tests/test_schema.py index 93cbd9f0..00921129 100644 --- a/graphene_django/tests/test_schema.py +++ b/graphene_django/tests/test_schema.py @@ -40,6 +40,9 @@ class Meta: "email", "pets", "a_choice", + "typed_choice", + "class_choice", + "callable_choice", "fans", "reporter_type", ] diff --git a/graphene_django/tests/test_types.py b/graphene_django/tests/test_types.py index 72514d23..c63bf2dc 100644 --- a/graphene_django/tests/test_types.py +++ b/graphene_django/tests/test_types.py @@ -77,6 +77,9 @@ def test_django_objecttype_map_correct_fields(): "email", "pets", "a_choice", + "typed_choice", + "class_choice", + "callable_choice", "fans", "reporter_type", ] @@ -186,6 +189,9 @@ def test_schema_representation(): email: String! pets: [Reporter!]! aChoice: TestsReporterAChoiceChoices + typedChoice: TestsReporterTypedChoiceChoices + classChoice: TestsReporterClassChoiceChoices + callableChoice: TestsReporterCallableChoiceChoices reporterType: TestsReporterReporterTypeChoices articles(offset: Int, before: String, after: String, first: Int, last: Int): ArticleConnection! } @@ -199,6 +205,33 @@ def test_schema_representation(): A_2 } + \"""An enumeration.\""" + enum TestsReporterTypedChoiceChoices { + \"""Choice This\""" + A_1 + + \"""Choice That\""" + A_2 + } + + \"""An enumeration.\""" + enum TestsReporterClassChoiceChoices { + \"""Choice This\""" + A_1 + + \"""Choice That\""" + A_2 + } + + \"""An enumeration.\""" + enum TestsReporterCallableChoiceChoices { + \"""Choice This\""" + A_1 + + \"""Choice That\""" + A_2 + } + \"""An enumeration.\""" enum TestsReporterReporterTypeChoices { \"""Regular\"""