Skip to content

Commit

Permalink
Fixed 35561 -- Made *args and **kwargs parsing more strict in Model.s…
Browse files Browse the repository at this point in the history
…ave()/asave().
  • Loading branch information
nessita authored Jun 26, 2024
1 parent 88966bc commit e56a32b
Show file tree
Hide file tree
Showing 2 changed files with 85 additions and 37 deletions.
88 changes: 51 additions & 37 deletions django/db/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -776,6 +776,43 @@ def serializable_value(self, field_name):
return getattr(self, field_name)
return getattr(self, field.attname)

# RemovedInDjango60Warning: When the deprecation ends, remove completely.
def _parse_params(self, *args, method_name, **kwargs):
defaults = {
"force_insert": False,
"force_update": False,
"using": None,
"update_fields": None,
}

warnings.warn(
f"Passing positional arguments to {method_name}() is deprecated",
RemovedInDjango60Warning,
stacklevel=2,
)
total_len_args = len(args) + 1 # include self
max_len_args = len(defaults) + 1
if total_len_args > max_len_args:
# Recreate the proper TypeError message from Python.
raise TypeError(
f"Model.{method_name}() takes from 1 to {max_len_args} positional "
f"arguments but {total_len_args} were given"
)

def get_param(param_name, param_value, arg_index):
if arg_index < len(args):
if param_value is not defaults[param_name]:
# Recreate the proper TypeError message from Python.
raise TypeError(
f"Model.{method_name}() got multiple values for argument "
f"'{param_name}'"
)
return args[arg_index]

return param_value

return [get_param(k, v, i) for i, (k, v) in enumerate(kwargs.items())]

# RemovedInDjango60Warning: When the deprecation ends, replace with:
# def save(
# self, *, force_insert=False, force_update=False, using=None, update_fields=None,
Expand All @@ -798,25 +835,14 @@ def save(
"""
# RemovedInDjango60Warning.
if args:
warnings.warn(
"Passing positional arguments to save() is deprecated",
RemovedInDjango60Warning,
stacklevel=2,
force_insert, force_update, using, update_fields = self._parse_params(
*args,
method_name="save",
force_insert=force_insert,
force_update=force_update,
using=using,
update_fields=update_fields,
)
total_len_args = len(args) + 1 # include self
if total_len_args > 5:
# Recreate the proper TypeError message from Python.
raise TypeError(
"Model.save() takes from 1 to 5 positional arguments but "
f"{total_len_args} were given"
)
force_insert = args[0]
try:
force_update = args[1]
using = args[2]
update_fields = args[3]
except IndexError:
pass

self._prepare_related_fields_for_save(operation_name="save")

Expand Down Expand Up @@ -885,26 +911,14 @@ async def asave(
):
# RemovedInDjango60Warning.
if args:
warnings.warn(
"Passing positional arguments to asave() is deprecated",
RemovedInDjango60Warning,
stacklevel=2,
force_insert, force_update, using, update_fields = self._parse_params(
*args,
method_name="asave",
force_insert=force_insert,
force_update=force_update,
using=using,
update_fields=update_fields,
)
total_len_args = len(args) + 1 # include self
if total_len_args > 5:
# Recreate the proper TypeError message from Python.
raise TypeError(
"Model.asave() takes from 1 to 5 positional arguments but "
f"{total_len_args} were given"
)
force_insert = args[0]
try:
force_update = args[1]
using = args[2]
update_fields = args[3]
except IndexError:
pass

return await sync_to_async(self.save)(
force_insert=force_insert,
force_update=force_update,
Expand Down
34 changes: 34 additions & 0 deletions tests/basic/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,23 @@ def test_save_too_many_positional_arguments(self):
):
a.save(False, False, None, None, None)

def test_save_conflicting_positional_and_named_arguments(self):
a = Article()
cases = [
("force_insert", True, [42]),
("force_update", None, [42, 41]),
("using", "some-db", [42, 41, 40]),
("update_fields", ["foo"], [42, 41, 40, 39]),
]
for param_name, param_value, args in cases:
with self.subTest(param_name=param_name):
msg = f"Model.save() got multiple values for argument '{param_name}'"
with (
self.assertWarns(RemovedInDjango60Warning),
self.assertRaisesMessage(TypeError, msg),
):
a.save(*args, **{param_name: param_value})

async def test_asave_deprecation(self):
a = Article(headline="original", pub_date=datetime(2014, 5, 16))
msg = "Passing positional arguments to asave() is deprecated"
Expand Down Expand Up @@ -275,6 +292,23 @@ async def test_asave_too_many_positional_arguments(self):
):
await a.asave(False, False, None, None, None)

async def test_asave_conflicting_positional_and_named_arguments(self):
a = Article()
cases = [
("force_insert", True, [42]),
("force_update", None, [42, 41]),
("using", "some-db", [42, 41, 40]),
("update_fields", ["foo"], [42, 41, 40, 39]),
]
for param_name, param_value, args in cases:
with self.subTest(param_name=param_name):
msg = f"Model.asave() got multiple values for argument '{param_name}'"
with (
self.assertWarns(RemovedInDjango60Warning),
self.assertRaisesMessage(TypeError, msg),
):
await a.asave(*args, **{param_name: param_value})

@ignore_warnings(category=RemovedInDjango60Warning)
def test_save_positional_arguments(self):
a = Article.objects.create(headline="original", pub_date=datetime(2014, 5, 16))
Expand Down

0 comments on commit e56a32b

Please sign in to comment.