Skip to content

Commit

Permalink
Modify Nested field to handle the combination of many=True and allow_…
Browse files Browse the repository at this point in the history
  • Loading branch information
Meallia committed Jan 31, 2020
1 parent 709fcbc commit f7b8cc7
Show file tree
Hide file tree
Showing 4 changed files with 53 additions and 5 deletions.
1 change: 1 addition & 0 deletions AUTHORS.rst
Original file line number Diff line number Diff line change
Expand Up @@ -143,3 +143,4 @@ Contributors (chronological)
- Taneli Hukkinen `@hukkinj1 <https://github.com/hukkinj1>`_
- `@Reskov <https://github.com/Reskov>`_
- Albert Tugushev `@atugushev <https://github.com/atugushev>`_
- Jonathan Gayvallet `@Meallia <https://github.com/Meallia>`_
23 changes: 20 additions & 3 deletions src/marshmallow/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -577,6 +577,13 @@ def _serialize(self, nested_obj, attr, obj, many=False, **kwargs):
if nested_obj is None:
return None
many = schema.many or self.many or many
if many and self.allow_none:
ret = schema.dump(
[obj for obj in nested_obj if obj is not None], many=self.many or many
)
for none_index in (i for i, obj in enumerate(nested_obj) if obj is None):
ret.insert(none_index, None)
return ret
return schema.dump(nested_obj, many=self.many or many)

def _test_collection(self, value, many=False):
Expand All @@ -587,9 +594,19 @@ def _test_collection(self, value, many=False):
def _load(self, value, data, partial=None, many=False):
many = self.schema.many or self.many or many
try:
valid_data = self.schema.load(
value, unknown=self.unknown, partial=partial, many=many
)
if many and self.allow_none:
valid_data = self.schema.load(
[obj for obj in value if obj is not None],
unknown=self.unknown,
partial=partial,
many=many,
)
for none_index in (i for i, obj in enumerate(value) if obj is None):
valid_data.insert(none_index, None)
else:
valid_data = self.schema.load(
value, unknown=self.unknown, partial=partial, many=many
)
except ValidationError as error:
raise ValidationError(
error.messages, valid_data=error.valid_data
Expand Down
22 changes: 20 additions & 2 deletions tests/test_deserialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ class TestDeserializingNone:
@pytest.mark.parametrize("FieldClass", ALL_FIELDS)
def test_fields_allow_none_deserialize_to_none(self, FieldClass):
field = FieldClass(allow_none=True)
field.deserialize(None) is None
assert field.deserialize(None) is None

# https://github.com/marshmallow-code/marshmallow/issues/111
@pytest.mark.parametrize("FieldClass", ALL_FIELDS)
Expand All @@ -28,7 +28,7 @@ def test_fields_dont_allow_none_by_default(self, FieldClass):
def test_allow_none_is_true_if_missing_is_true(self):
field = fields.Field(missing=None)
assert field.allow_none is True
field.deserialize(None) is None
assert field.deserialize(None) is None

def test_list_field_deserialize_none_to_none(self):
field = fields.List(fields.String(allow_none=True), allow_none=True)
Expand All @@ -38,6 +38,24 @@ def test_tuple_field_deserialize_none_to_none(self):
field = fields.Tuple([fields.String()], allow_none=True)
assert field.deserialize(None) is None

def test_list_of_nested_nullable_deserialize_none_to_none(self):
field = fields.List(fields.Nested(Schema(), allow_none=True))
assert field.deserialize([None, {}]) == [None, {}]

def test_nested_multiple_nullable_deserialize_none_to_none(self):
field = fields.Nested(Schema(), allow_none=True, many=True)
assert field.deserialize([None, {}]) == [None, {}]

def test_list_of_nested_non_nullable_deserialize_none_to_validation_error(self):
field = fields.List(fields.Nested(Schema(), allow_none=False))
with pytest.raises(ValidationError):
field.deserialize([None])

def test_nested_multiple_non_nullable_deserialize_none_to_validation_error(self):
field = fields.Nested(Schema(), allow_none=True, many=False)
with pytest.raises(ValidationError):
field.deserialize([None])


class TestFieldDeserialization:
def test_float_field_deserialization(self):
Expand Down
12 changes: 12 additions & 0 deletions tests/test_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -727,6 +727,18 @@ def test_all_fields_serialize_none_to_none(self, FieldClass):
res = field.serialize("foo", {"foo": None})
assert res is None

def test_list_of_nested_nullable_serialize_none_to_none(self):
res = fields.List(fields.Nested(Schema(), allow_none=True)).serialize(
"foo", {"foo": [None, {}]}
)
assert res == [None, {}]

def test_nested_multiple_nullable_serialize_none_to_none(self):
res = fields.Nested(Schema(), allow_none=True, many=True).serialize(
"foo", {"foo": [None, {}]}
)
assert res == [None, {}]


class TestSchemaSerialization:
def test_serialize_with_missing_param_value(self):
Expand Down

0 comments on commit f7b8cc7

Please sign in to comment.