Skip to content

Commit

Permalink
Refs django#373 -- Added additional validations to tuple lookups.
Browse files Browse the repository at this point in the history
  • Loading branch information
csirmazbendeguz authored and sarahboyce committed Oct 14, 2024
1 parent 263f731 commit 97c05a6
Show file tree
Hide file tree
Showing 2 changed files with 139 additions and 7 deletions.
36 changes: 31 additions & 5 deletions django/db/models/fields/tuple_lookups.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from django.core.exceptions import EmptyResultSet
from django.db.models import Field
from django.db.models.expressions import Func, Value
from django.db.models.expressions import ColPairs, Func, Value
from django.db.models.lookups import (
Exact,
GreaterThan,
Expand All @@ -28,17 +28,32 @@ def __iter__(self):

class TupleLookupMixin:
def get_prep_lookup(self):
self.check_rhs_is_tuple_or_list()
self.check_rhs_length_equals_lhs_length()
return self.rhs

def check_rhs_is_tuple_or_list(self):
if not isinstance(self.rhs, (tuple, list)):
lhs_str = self.get_lhs_str()
raise ValueError(
f"{self.lookup_name!r} lookup of {lhs_str} must be a tuple or a list"
)

def check_rhs_length_equals_lhs_length(self):
len_lhs = len(self.lhs)
if len_lhs != len(self.rhs):
lhs_str = self.get_lhs_str()
raise ValueError(
f"'{self.lookup_name}' lookup of '{self.lhs.field.name}' field "
f"must have {len_lhs} elements"
f"{self.lookup_name!r} lookup of {lhs_str} must have {len_lhs} elements"
)

def get_lhs_str(self):
if isinstance(self.lhs, ColPairs):
return repr(self.lhs.field.name)
else:
names = ", ".join(repr(f.name) for f in self.lhs)
return f"({names})"

def get_prep_lhs(self):
if isinstance(self.lhs, (tuple, list)):
return Tuple(*self.lhs)
Expand Down Expand Up @@ -196,14 +211,25 @@ def as_oracle(self, compiler, connection):

class TupleIn(TupleLookupMixin, In):
def get_prep_lookup(self):
self.check_rhs_is_tuple_or_list()
self.check_rhs_is_collection_of_tuples_or_lists()
self.check_rhs_elements_length_equals_lhs_length()
return super(TupleLookupMixin, self).get_prep_lookup()
return self.rhs # skip checks from mixin

def check_rhs_is_collection_of_tuples_or_lists(self):
if not all(isinstance(vals, (tuple, list)) for vals in self.rhs):
lhs_str = self.get_lhs_str()
raise ValueError(
f"{self.lookup_name!r} lookup of {lhs_str} "
"must be a collection of tuples or lists"
)

def check_rhs_elements_length_equals_lhs_length(self):
len_lhs = len(self.lhs)
if not all(len_lhs == len(vals) for vals in self.rhs):
lhs_str = self.get_lhs_str()
raise ValueError(
f"'{self.lookup_name}' lookup of '{self.lhs.field.name}' field "
f"{self.lookup_name!r} lookup of {lhs_str} "
f"must have {len_lhs} elements each"
)

Expand Down
110 changes: 108 additions & 2 deletions tests/foreign_object/test_tuple_lookups.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import itertools
import unittest

from django.db import NotSupportedError, connection
Expand Down Expand Up @@ -129,6 +130,37 @@ def test_in_subquery(self):
(self.contact_1, self.contact_2, self.contact_5),
)

def test_tuple_in_rhs_must_be_collection_of_tuples_or_lists(self):
test_cases = (
(1, 2, 3),
((1, 2), (3, 4), None),
)

for rhs in test_cases:
with self.subTest(rhs=rhs):
with self.assertRaisesMessage(
ValueError,
"'in' lookup of ('customer_code', 'company_code') "
"must be a collection of tuples or lists",
):
TupleIn((F("customer_code"), F("company_code")), rhs)

def test_tuple_in_rhs_must_have_2_elements_each(self):
test_cases = (
((),),
((1,),),
((1, 2, 3),),
)

for rhs in test_cases:
with self.subTest(rhs=rhs):
with self.assertRaisesMessage(
ValueError,
"'in' lookup of ('customer_code', 'company_code') "
"must have 2 elements each",
):
TupleIn((F("customer_code"), F("company_code")), rhs)

def test_lt(self):
c1, c2, c3, c4, c5, c6 = (
self.contact_1,
Expand Down Expand Up @@ -358,8 +390,8 @@ def test_isnull_subquery(self):
)

def test_lookup_errors(self):
m_2_elements = "'%s' lookup of 'customer' field must have 2 elements"
m_2_elements_each = "'in' lookup of 'customer' field must have 2 elements each"
m_2_elements = "'%s' lookup of 'customer' must have 2 elements"
m_2_elements_each = "'in' lookup of 'customer' must have 2 elements each"
test_cases = (
({"customer": 1}, m_2_elements % "exact"),
({"customer": (1, 2, 3)}, m_2_elements % "exact"),
Expand All @@ -381,3 +413,77 @@ def test_lookup_errors(self):
self.assertRaisesMessage(ValueError, message),
):
Contact.objects.get(**kwargs)

def test_tuple_lookup_names(self):
test_cases = (
(TupleExact, "exact"),
(TupleGreaterThan, "gt"),
(TupleGreaterThanOrEqual, "gte"),
(TupleLessThan, "lt"),
(TupleLessThanOrEqual, "lte"),
(TupleIn, "in"),
(TupleIsNull, "isnull"),
)

for lookup_class, lookup_name in test_cases:
with self.subTest(lookup_name):
self.assertEqual(lookup_class.lookup_name, lookup_name)

def test_tuple_lookup_rhs_must_be_tuple_or_list(self):
test_cases = itertools.product(
(
TupleExact,
TupleGreaterThan,
TupleGreaterThanOrEqual,
TupleLessThan,
TupleLessThanOrEqual,
TupleIn,
),
(
0,
1,
None,
True,
False,
{"foo": "bar"},
),
)

for lookup_cls, rhs in test_cases:
lookup_name = lookup_cls.lookup_name
with self.subTest(lookup_name=lookup_name, rhs=rhs):
with self.assertRaisesMessage(
ValueError,
f"'{lookup_name}' lookup of ('customer_code', 'company_code') "
"must be a tuple or a list",
):
lookup_cls((F("customer_code"), F("company_code")), rhs)

def test_tuple_lookup_rhs_must_have_2_elements(self):
test_cases = itertools.product(
(
TupleExact,
TupleGreaterThan,
TupleGreaterThanOrEqual,
TupleLessThan,
TupleLessThanOrEqual,
),
(
[],
[1],
[1, 2, 3],
(),
(1,),
(1, 2, 3),
),
)

for lookup_cls, rhs in test_cases:
lookup_name = lookup_cls.lookup_name
with self.subTest(lookup_name=lookup_name, rhs=rhs):
with self.assertRaisesMessage(
ValueError,
f"'{lookup_name}' lookup of ('customer_code', 'company_code') "
"must have 2 elements",
):
lookup_cls((F("customer_code"), F("company_code")), rhs)

0 comments on commit 97c05a6

Please sign in to comment.