diff --git a/src/marshmallow/fields.py b/src/marshmallow/fields.py index afaa8bce0..f9ea86218 100644 --- a/src/marshmallow/fields.py +++ b/src/marshmallow/fields.py @@ -570,16 +570,18 @@ def _test_collection(self, value): if many and not utils.is_collection(value): raise self.make_error("type", input=value, type=value.__class__.__name__) - def _load(self, value, data, partial=None): + def _load(self, value, data, partial=None, unknown=None): try: - valid_data = self.schema.load(value, unknown=self.unknown, partial=partial) + valid_data = self.schema.load( + value, unknown=unknown or self.unknown, partial=partial, + ) except ValidationError as error: raise ValidationError( error.messages, valid_data=error.valid_data ) from error return valid_data - def _deserialize(self, value, attr, data, partial=None, **kwargs): + def _deserialize(self, value, attr, data, partial=None, unknown=None, **kwargs): """Same as :meth:`Field._deserialize` with additional ``partial`` argument. :param bool|tuple partial: For nested schemas, the ``partial`` @@ -589,7 +591,7 @@ def _deserialize(self, value, attr, data, partial=None, **kwargs): Add ``partial`` parameter. """ self._test_collection(value) - return self._load(value, data, partial=partial) + return self._load(value, data, partial=partial, unknown=unknown) class Pluck(Nested): diff --git a/src/marshmallow/schema.py b/src/marshmallow/schema.py index 0d4ef106b..606999e3f 100644 --- a/src/marshmallow/schema.py +++ b/src/marshmallow/schema.py @@ -658,6 +658,13 @@ def _deserialize( d_kwargs["partial"] = sub_partial else: d_kwargs["partial"] = partial + + try: + if self.context["propagate_unknown_to_nested"]: + d_kwargs["unknown"] = unknown + except KeyError: + pass + getter = lambda val: field_obj.deserialize( val, field_name, data, **d_kwargs ) @@ -835,6 +842,7 @@ def _do_load( error_store = ErrorStore() errors = {} # type: typing.Dict[str, typing.List[str]] many = self.many if many is None else bool(many) + self.context["propagate_unknown_to_nested"] = unknown is not None unknown = unknown or self.unknown if partial is None: partial = self.partial diff --git a/tests/test_fields.py b/tests/test_fields.py index 8671be959..a2f2f5b12 100644 --- a/tests/test_fields.py +++ b/tests/test_fields.py @@ -278,6 +278,70 @@ class MySchema(Schema): MySchema().load({"nested": {"x": 1}}) +class TestNestedFieldPropagatesUnknown: + class SpamSchema(Schema): + meat = fields.String() + + class CanSchema(Schema): + spam = fields.Nested("SpamSchema") + + class ShelfSchema(Schema): + can = fields.Nested("CanSchema") + + @pytest.fixture + def data_nested_unknown(self): + return { + "spam": {"meat": "pork", "add-on": "eggs"}, + } + + @pytest.fixture + def multi_nested_data_with_unknown(self, data_nested_unknown): + return { + "can": data_nested_unknown, + "box": {"foo": "bar"}, + } + + @pytest.mark.parametrize("schema_kw", ({}, {"unknown": INCLUDE})) + def test_raises_when_unknown_passed_to_first_level_nested( + self, schema_kw, data_nested_unknown, + ): + with pytest.raises(ValidationError) as exc_info: + self.CanSchema(**schema_kw).load(data_nested_unknown) + assert exc_info.value.messages == {"spam": {"add-on": ["Unknown field."]}} + + @pytest.mark.parametrize( + "load_kw,expected_data", + ( + ({"unknown": INCLUDE}, {"spam": {"meat": "pork", "add-on": "eggs"}}), + ({"unknown": EXCLUDE}, {"spam": {"meat": "pork"}}), + ), + ) + def test_processes_when_unknown_stated_directly( + self, load_kw, data_nested_unknown, expected_data, + ): + data = self.CanSchema().load(data_nested_unknown, **load_kw) + assert data == expected_data + + @pytest.mark.parametrize( + "load_kw,expected_data", + ( + ( + {"unknown": INCLUDE}, + { + "can": {"spam": {"meat": "pork", "add-on": "eggs"}}, + "box": {"foo": "bar"}, + }, + ), + ({"unknown": EXCLUDE}, {"can": {"spam": {"meat": "pork"}}}), + ), + ) + def test_propagates_unknown_to_multi_nested_fields( + self, load_kw, expected_data, multi_nested_data_with_unknown, + ): + data = self.ShelfSchema().load(multi_nested_data_with_unknown, **load_kw) + assert data == expected_data + + class TestListNested: @pytest.mark.parametrize("param", ("only", "exclude", "dump_only", "load_only")) def test_list_nested_only_exclude_dump_only_load_only_propagated_to_nested(