From c2c7dbb2f88ce8f0ef6d48a61b93866c9926349a Mon Sep 17 00:00:00 2001 From: Bendeguz Csirmaz Date: Fri, 20 Sep 2024 10:03:47 +0200 Subject: [PATCH] Refs #373 -- Updated TupleIsNull lookup to check if any is NULL rather than all. Regression in 1eac690d25dd49088256954d4046813daa37dc95. --- django/db/models/fields/tuple_lookups.py | 27 +++++++++++++++--------- tests/foreign_object/models/person.py | 2 +- tests/foreign_object/tests.py | 25 ++++++++++++++++++---- 3 files changed, 39 insertions(+), 15 deletions(-) diff --git a/django/db/models/fields/tuple_lookups.py b/django/db/models/fields/tuple_lookups.py index 04c53944dc11..cdb3b4720989 100644 --- a/django/db/models/fields/tuple_lookups.py +++ b/django/db/models/fields/tuple_lookups.py @@ -57,18 +57,25 @@ def as_oracle(self, compiler, connection): return root.as_sql(compiler, connection) -class TupleIsNull(IsNull): +class TupleIsNull(TupleLookupMixin, IsNull): + def get_prep_lookup(self): + rhs = self.rhs + if isinstance(rhs, (tuple, list)) and len(rhs) == 1: + rhs = rhs[0] + if isinstance(rhs, bool): + return rhs + raise ValueError( + "The QuerySet value for an isnull lookup must be True or False." + ) + def as_sql(self, compiler, connection): # e.g.: (a, b, c) is None as SQL: - # WHERE a IS NULL AND b IS NULL AND c IS NULL - vals = self.rhs - if isinstance(vals, bool): - vals = [vals] * len(self.lhs) - - cols = self.lhs.get_cols() - lookups = [IsNull(col, val) for col, val in zip(cols, vals)] - root = WhereNode(lookups, connector=AND) - + # WHERE a IS NULL OR b IS NULL OR c IS NULL + # e.g.: (a, b, c) is not None as SQL: + # WHERE a IS NOT NULL AND b IS NOT NULL AND c IS NOT NULL + rhs = self.rhs + lookups = [IsNull(col, rhs) for col in self.lhs] + root = WhereNode(lookups, connector=OR if rhs else AND) return root.as_sql(compiler, connection) diff --git a/tests/foreign_object/models/person.py b/tests/foreign_object/models/person.py index 33063e728abd..f0848e6c3e57 100644 --- a/tests/foreign_object/models/person.py +++ b/tests/foreign_object/models/person.py @@ -49,7 +49,7 @@ def __str__(self): class Membership(models.Model): # Table Column Fields - membership_country = models.ForeignKey(Country, models.CASCADE) + membership_country = models.ForeignKey(Country, models.CASCADE, null=True) date_joined = models.DateTimeField(default=datetime.datetime.now) invite_reason = models.CharField(max_length=64, null=True) person_id = models.IntegerField() diff --git a/tests/foreign_object/tests.py b/tests/foreign_object/tests.py index 89ed85b658d7..e288ecd7d4d5 100644 --- a/tests/foreign_object/tests.py +++ b/tests/foreign_object/tests.py @@ -516,18 +516,35 @@ def test_batch_create_foreign_object(self): def test_isnull_lookup(self): m1 = Membership.objects.create( - membership_country=self.usa, person=self.bob, group_id=None + person_id=self.bob.id, + membership_country_id=self.usa.id, + group_id=None, ) m2 = Membership.objects.create( - membership_country=self.usa, person=self.bob, group=self.cia + person_id=self.jim.id, + membership_country_id=None, + group_id=self.cia.id, + ) + m3 = Membership.objects.create( + person_id=self.jane.id, + membership_country_id=None, + group_id=None, + ) + m4 = Membership.objects.create( + person_id=self.george.id, + membership_country_id=self.soviet_union.id, + group_id=self.kgb.id, ) + for member in [m1, m2, m3]: + with self.assertRaises(Membership.group.RelatedObjectDoesNotExist): + getattr(member, "group") self.assertSequenceEqual( Membership.objects.filter(group__isnull=True), - [m1], + [m1, m2, m3], ) self.assertSequenceEqual( Membership.objects.filter(group__isnull=False), - [m2], + [m4], )