Skip to content

Commit

Permalink
WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
BoxyUwU committed Aug 22, 2024
1 parent bb67044 commit 0961627
Show file tree
Hide file tree
Showing 7 changed files with 212 additions and 1 deletion.
20 changes: 20 additions & 0 deletions python/pydantic_core/core_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -1384,6 +1384,24 @@ def uuid_schema(
)


class NestedModelSchema(TypedDict, total=False):
type: Required[Literal['nested-model']]
model: Required[Type[Any]]
metadata: Any


def nested_model_schema(
*,
model: Type[Any],
metadata: Any = None,
) -> NestedModelSchema:
return _dict_not_none(
type='nested-model',
model=model,
metadata=metadata,
)


class IncExSeqSerSchema(TypedDict, total=False):
type: Required[Literal['include-exclude-sequence']]
include: Set[int]
Expand Down Expand Up @@ -3796,6 +3814,7 @@ def definition_reference_schema(
DefinitionsSchema,
DefinitionReferenceSchema,
UuidSchema,
NestedModelSchema,
]
elif False:
CoreSchema: TypeAlias = Mapping[str, Any]
Expand Down Expand Up @@ -3851,6 +3870,7 @@ def definition_reference_schema(
'definitions',
'definition-ref',
'uuid',
'nested-model',
]

CoreSchemaFieldType = Literal['model-field', 'dataclass-field', 'typed-dict-field', 'computed-field']
Expand Down
2 changes: 2 additions & 0 deletions src/serializers/shared.rs
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,7 @@ combined_serializer! {
Enum: super::type_serializers::enum_::EnumSerializer;
Recursive: super::type_serializers::definitions::DefinitionRefSerializer;
Tuple: super::type_serializers::tuple::TupleSerializer;
NestedModel: super::type_serializers::nested_model::NestedModelSerializer;
}
}

Expand Down Expand Up @@ -251,6 +252,7 @@ impl PyGcTraverse for CombinedSerializer {
CombinedSerializer::Recursive(inner) => inner.py_gc_traverse(visit),
CombinedSerializer::Tuple(inner) => inner.py_gc_traverse(visit),
CombinedSerializer::Uuid(inner) => inner.py_gc_traverse(visit),
CombinedSerializer::NestedModel(inner) => inner.py_gc_traverse(visit),
}
}
}
Expand Down
1 change: 1 addition & 0 deletions src/serializers/type_serializers/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ pub mod json_or_python;
pub mod list;
pub mod literal;
pub mod model;
pub mod nested_model;
pub mod nullable;
pub mod other;
pub mod set_frozenset;
Expand Down
108 changes: 108 additions & 0 deletions src/serializers/type_serializers/nested_model.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
use std::borrow::Cow;

use pyo3::{
intern,
types::{PyAnyMethods, PyDict, PyDictMethods, PyType},
Bound, Py, PyAny, PyObject, PyResult, Python,
};

use crate::{
definitions::DefinitionsBuilder,
serializers::{
shared::{BuildSerializer, TypeSerializer},
CombinedSerializer, Extra,
},
SchemaSerializer,
};

#[derive(Debug, Clone)]
pub struct NestedModelSerializer {
model: Py<PyType>,
name: String,
}

impl_py_gc_traverse!(NestedModelSerializer { model });

impl BuildSerializer for NestedModelSerializer {
const EXPECTED_TYPE: &'static str = "nested-model";

fn build(
schema: &Bound<'_, PyDict>,
_config: Option<&Bound<'_, PyDict>>,
_definitions: &mut DefinitionsBuilder<CombinedSerializer>,
) -> PyResult<CombinedSerializer> {
let py = schema.py();
let model = schema
.get_item(intern!(py, "model"))?
.expect("Invalid core schema for `nested-model` type")
.downcast::<PyType>()
.expect("Invalid core schema for `nested-model` type")
.clone();

let name = model.getattr(intern!(py, "__name__"))?.extract()?;

Ok(CombinedSerializer::NestedModel(NestedModelSerializer {
model: model.clone().unbind(),
name,
}))
}
}

impl NestedModelSerializer {
fn nested_serializer<'py>(&self, py: Python<'py>) -> Bound<'py, SchemaSerializer> {
self.model
.bind(py)
.call_method(intern!(py, "model_rebuild"), (), None)
.unwrap();

self.model
.getattr(py, intern!(py, "__pydantic_serializer__"))
.unwrap()
.downcast_bound::<SchemaSerializer>(py)
.unwrap()
.clone()

// crate::schema_cache::retrieve_schema(py, self.model.as_any().clone())
// .downcast_bound::<SchemaSerializer>(py)
// // FIXME: This actually will always trigger as we cache a `CoreSchema` lol
// .expect("Cached validator was not a `SchemaSerializer`")
// .clone()
}
}

impl TypeSerializer for NestedModelSerializer {
fn to_python(
&self,
value: &Bound<'_, PyAny>,
include: Option<&Bound<'_, PyAny>>,
exclude: Option<&Bound<'_, PyAny>>,
extra: &Extra,
) -> PyResult<PyObject> {
self.nested_serializer(value.py())
.get()
.serializer
.to_python(value, include, exclude, extra)
}

fn json_key<'a>(&self, key: &'a Bound<'_, PyAny>, extra: &Extra) -> PyResult<Cow<'a, str>> {
self.nested_serializer(key.py()).get().serializer.json_key(key, extra)
}

fn serde_serialize<S: serde::ser::Serializer>(
&self,
value: &Bound<'_, PyAny>,
serializer: S,
include: Option<&Bound<'_, PyAny>>,
exclude: Option<&Bound<'_, PyAny>>,
extra: &Extra,
) -> Result<S::Ok, S::Error> {
self.nested_serializer(value.py())
.get()
.serializer
.serde_serialize(value, serializer, include, exclude, extra)
}

fn get_name(&self) -> &str {
&self.name
}
}
4 changes: 4 additions & 0 deletions src/validators/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ mod list;
mod literal;
mod model;
mod model_fields;
mod nested_model;
mod none;
mod nullable;
mod set;
Expand Down Expand Up @@ -582,6 +583,7 @@ pub fn build_validator(
// recursive (self-referencing) models
definitions::DefinitionRefValidator,
definitions::DefinitionsValidatorBuilder,
nested_model::NestedModelValidator,
)
}

Expand Down Expand Up @@ -735,6 +737,8 @@ pub enum CombinedValidator {
DefinitionRef(definitions::DefinitionRefValidator),
// input dependent
JsonOrPython(json_or_python::JsonOrPython),
// Schema for a model inside of another schema
NestedModel(nested_model::NestedModelValidator),
}

/// This trait must be implemented by all validators, it allows various validators to be accessed consistently,
Expand Down
2 changes: 1 addition & 1 deletion src/validators/model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ impl BuildValidator for ModelValidator {

let class: Bound<'_, PyType> = schema.get_as_req(intern!(py, "cls"))?;
let sub_schema = schema.get_as_req(intern!(py, "schema"))?;
let validator = build_validator(&sub_schema, config.as_ref(), definitions)?;
let validator: CombinedValidator = build_validator(&sub_schema, config.as_ref(), definitions)?;
let name = class.getattr(intern!(py, "__name__"))?.extract()?;

Ok(Self {
Expand Down
76 changes: 76 additions & 0 deletions src/validators/nested_model.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
use pyo3::{
intern,
types::{PyAnyMethods, PyDict, PyDictMethods, PyType},
Bound, Py, PyObject, PyResult, Python,
};

use crate::{definitions::DefinitionsBuilder, errors::ValResult, input::Input};

use super::{BuildValidator, CombinedValidator, SchemaValidator, ValidationState, Validator};

#[derive(Debug, Clone)]
pub struct NestedModelValidator {
model: Py<PyType>,
name: String,
}

impl_py_gc_traverse!(NestedModelValidator { model });

impl BuildValidator for NestedModelValidator {
const EXPECTED_TYPE: &'static str = "nested-model";

fn build(
schema: &Bound<'_, PyDict>,
_config: Option<&Bound<'_, PyDict>>,
_definitions: &mut DefinitionsBuilder<super::CombinedValidator>,
) -> PyResult<super::CombinedValidator> {
let py = schema.py();
let model = schema
.get_item(intern!(py, "model"))?
.expect("Invalid core schema for `nested-model` type")
.downcast::<PyType>()
.expect("Invalid core schema for `nested-model` type")
.clone();

let name = model.getattr(intern!(py, "__name__"))?.extract()?;

Ok(CombinedValidator::NestedModel(NestedModelValidator {
model: model.clone().unbind(),
name,
}))
}
}

impl Validator for NestedModelValidator {
fn validate<'py>(
&self,
py: Python<'py>,
input: &(impl Input<'py> + ?Sized),
state: &mut ValidationState<'_, 'py>,
) -> ValResult<PyObject> {
self.model
.bind(py)
.call_method(intern!(py, "model_rebuild"), (), None)
.unwrap();

let validator = self
.model
.getattr(py, intern!(py, "__pydantic_validator__"))
.unwrap()
.downcast_bound::<SchemaValidator>(py)
.unwrap()
.clone();

// let validator = crate::schema_cache::retrieve_schema(py, self.model.as_any().clone())
// .downcast_bound::<SchemaValidator>(py)
// // FIXME: This actually will always trigger as we cache a `CoreSchema` lol
// .expect("Cached validator was not a `SchemaValidator`")
// .clone();

validator.get().validator.validate(py, input, state)
}

fn get_name(&self) -> &str {
&self.name
}
}

0 comments on commit 0961627

Please sign in to comment.