diff --git a/python/pydantic_core/core_schema.py b/python/pydantic_core/core_schema.py index 4ac24bd6c..311e2af42 100644 --- a/python/pydantic_core/core_schema.py +++ b/python/pydantic_core/core_schema.py @@ -1384,6 +1384,24 @@ def uuid_schema( ) +class NestedModelSchema(TypedDict, total=False): + type: Required[Literal['nested-model']] + model: Required[Type[Any]] + metadata: Any + + +def nested_model_schema( + *, + model: Type[Any], + metadata: Any = None, +) -> NestedModelSchema: + return _dict_not_none( + type='nested-model', + model=model, + metadata=metadata, + ) + + class IncExSeqSerSchema(TypedDict, total=False): type: Required[Literal['include-exclude-sequence']] include: Set[int] @@ -3796,6 +3814,7 @@ def definition_reference_schema( DefinitionsSchema, DefinitionReferenceSchema, UuidSchema, + NestedModelSchema, ] elif False: CoreSchema: TypeAlias = Mapping[str, Any] @@ -3851,6 +3870,7 @@ def definition_reference_schema( 'definitions', 'definition-ref', 'uuid', + 'nested-model', ] CoreSchemaFieldType = Literal['model-field', 'dataclass-field', 'typed-dict-field', 'computed-field'] diff --git a/src/serializers/shared.rs b/src/serializers/shared.rs index e7930512c..35fd1f9d1 100644 --- a/src/serializers/shared.rs +++ b/src/serializers/shared.rs @@ -142,6 +142,7 @@ combined_serializer! { Enum: super::type_serializers::enum_::EnumSerializer; Recursive: super::type_serializers::definitions::DefinitionRefSerializer; Tuple: super::type_serializers::tuple::TupleSerializer; + NestedModel: super::type_serializers::nested_model::NestedModelSerializer; } } @@ -251,6 +252,7 @@ impl PyGcTraverse for CombinedSerializer { CombinedSerializer::Recursive(inner) => inner.py_gc_traverse(visit), CombinedSerializer::Tuple(inner) => inner.py_gc_traverse(visit), CombinedSerializer::Uuid(inner) => inner.py_gc_traverse(visit), + CombinedSerializer::NestedModel(inner) => inner.py_gc_traverse(visit), } } } diff --git a/src/serializers/type_serializers/mod.rs b/src/serializers/type_serializers/mod.rs index da36f0bc1..285af0958 100644 --- a/src/serializers/type_serializers/mod.rs +++ b/src/serializers/type_serializers/mod.rs @@ -15,6 +15,7 @@ pub mod json_or_python; pub mod list; pub mod literal; pub mod model; +pub mod nested_model; pub mod nullable; pub mod other; pub mod set_frozenset; diff --git a/src/serializers/type_serializers/nested_model.rs b/src/serializers/type_serializers/nested_model.rs new file mode 100644 index 000000000..dd0652db2 --- /dev/null +++ b/src/serializers/type_serializers/nested_model.rs @@ -0,0 +1,108 @@ +use std::borrow::Cow; + +use pyo3::{ + intern, + types::{PyAnyMethods, PyDict, PyDictMethods, PyType}, + Bound, Py, PyAny, PyObject, PyResult, Python, +}; + +use crate::{ + definitions::DefinitionsBuilder, + serializers::{ + shared::{BuildSerializer, TypeSerializer}, + CombinedSerializer, Extra, + }, + SchemaSerializer, +}; + +#[derive(Debug, Clone)] +pub struct NestedModelSerializer { + model: Py, + name: String, +} + +impl_py_gc_traverse!(NestedModelSerializer { model }); + +impl BuildSerializer for NestedModelSerializer { + const EXPECTED_TYPE: &'static str = "nested-model"; + + fn build( + schema: &Bound<'_, PyDict>, + _config: Option<&Bound<'_, PyDict>>, + _definitions: &mut DefinitionsBuilder, + ) -> PyResult { + let py = schema.py(); + let model = schema + .get_item(intern!(py, "model"))? + .expect("Invalid core schema for `nested-model` type") + .downcast::() + .expect("Invalid core schema for `nested-model` type") + .clone(); + + let name = model.getattr(intern!(py, "__name__"))?.extract()?; + + Ok(CombinedSerializer::NestedModel(NestedModelSerializer { + model: model.clone().unbind(), + name, + })) + } +} + +impl NestedModelSerializer { + fn nested_serializer<'py>(&self, py: Python<'py>) -> Bound<'py, SchemaSerializer> { + self.model + .bind(py) + .call_method(intern!(py, "model_rebuild"), (), None) + .unwrap(); + + self.model + .getattr(py, intern!(py, "__pydantic_serializer__")) + .unwrap() + .downcast_bound::(py) + .unwrap() + .clone() + + // crate::schema_cache::retrieve_schema(py, self.model.as_any().clone()) + // .downcast_bound::(py) + // // FIXME: This actually will always trigger as we cache a `CoreSchema` lol + // .expect("Cached validator was not a `SchemaSerializer`") + // .clone() + } +} + +impl TypeSerializer for NestedModelSerializer { + fn to_python( + &self, + value: &Bound<'_, PyAny>, + include: Option<&Bound<'_, PyAny>>, + exclude: Option<&Bound<'_, PyAny>>, + extra: &Extra, + ) -> PyResult { + self.nested_serializer(value.py()) + .get() + .serializer + .to_python(value, include, exclude, extra) + } + + fn json_key<'a>(&self, key: &'a Bound<'_, PyAny>, extra: &Extra) -> PyResult> { + self.nested_serializer(key.py()).get().serializer.json_key(key, extra) + } + + fn serde_serialize( + &self, + value: &Bound<'_, PyAny>, + serializer: S, + include: Option<&Bound<'_, PyAny>>, + exclude: Option<&Bound<'_, PyAny>>, + extra: &Extra, + ) -> Result { + self.nested_serializer(value.py()) + .get() + .serializer + .serde_serialize(value, serializer, include, exclude, extra) + } + + fn get_name(&self) -> &str { + &self.name + } +} diff --git a/src/validators/mod.rs b/src/validators/mod.rs index 18c947313..71870cbd0 100644 --- a/src/validators/mod.rs +++ b/src/validators/mod.rs @@ -48,6 +48,7 @@ mod list; mod literal; mod model; mod model_fields; +mod nested_model; mod none; mod nullable; mod set; @@ -582,6 +583,7 @@ pub fn build_validator( // recursive (self-referencing) models definitions::DefinitionRefValidator, definitions::DefinitionsValidatorBuilder, + nested_model::NestedModelValidator, ) } @@ -735,6 +737,8 @@ pub enum CombinedValidator { DefinitionRef(definitions::DefinitionRefValidator), // input dependent JsonOrPython(json_or_python::JsonOrPython), + // Schema for a model inside of another schema + NestedModel(nested_model::NestedModelValidator), } /// This trait must be implemented by all validators, it allows various validators to be accessed consistently, diff --git a/src/validators/model.rs b/src/validators/model.rs index 2c0cef6fd..dad29e3e4 100644 --- a/src/validators/model.rs +++ b/src/validators/model.rs @@ -77,7 +77,7 @@ impl BuildValidator for ModelValidator { let class: Bound<'_, PyType> = schema.get_as_req(intern!(py, "cls"))?; let sub_schema = schema.get_as_req(intern!(py, "schema"))?; - let validator = build_validator(&sub_schema, config.as_ref(), definitions)?; + let validator: CombinedValidator = build_validator(&sub_schema, config.as_ref(), definitions)?; let name = class.getattr(intern!(py, "__name__"))?.extract()?; Ok(Self { diff --git a/src/validators/nested_model.rs b/src/validators/nested_model.rs new file mode 100644 index 000000000..94137fdb1 --- /dev/null +++ b/src/validators/nested_model.rs @@ -0,0 +1,76 @@ +use pyo3::{ + intern, + types::{PyAnyMethods, PyDict, PyDictMethods, PyType}, + Bound, Py, PyObject, PyResult, Python, +}; + +use crate::{definitions::DefinitionsBuilder, errors::ValResult, input::Input}; + +use super::{BuildValidator, CombinedValidator, SchemaValidator, ValidationState, Validator}; + +#[derive(Debug, Clone)] +pub struct NestedModelValidator { + model: Py, + name: String, +} + +impl_py_gc_traverse!(NestedModelValidator { model }); + +impl BuildValidator for NestedModelValidator { + const EXPECTED_TYPE: &'static str = "nested-model"; + + fn build( + schema: &Bound<'_, PyDict>, + _config: Option<&Bound<'_, PyDict>>, + _definitions: &mut DefinitionsBuilder, + ) -> PyResult { + let py = schema.py(); + let model = schema + .get_item(intern!(py, "model"))? + .expect("Invalid core schema for `nested-model` type") + .downcast::() + .expect("Invalid core schema for `nested-model` type") + .clone(); + + let name = model.getattr(intern!(py, "__name__"))?.extract()?; + + Ok(CombinedValidator::NestedModel(NestedModelValidator { + model: model.clone().unbind(), + name, + })) + } +} + +impl Validator for NestedModelValidator { + fn validate<'py>( + &self, + py: Python<'py>, + input: &(impl Input<'py> + ?Sized), + state: &mut ValidationState<'_, 'py>, + ) -> ValResult { + self.model + .bind(py) + .call_method(intern!(py, "model_rebuild"), (), None) + .unwrap(); + + let validator = self + .model + .getattr(py, intern!(py, "__pydantic_validator__")) + .unwrap() + .downcast_bound::(py) + .unwrap() + .clone(); + + // let validator = crate::schema_cache::retrieve_schema(py, self.model.as_any().clone()) + // .downcast_bound::(py) + // // FIXME: This actually will always trigger as we cache a `CoreSchema` lol + // .expect("Cached validator was not a `SchemaValidator`") + // .clone(); + + validator.get().validator.validate(py, input, state) + } + + fn get_name(&self) -> &str { + &self.name + } +}