Skip to content

Commit

Permalink
WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
BoxyUwU committed Sep 12, 2024
1 parent ba8eab4 commit 8ddba74
Show file tree
Hide file tree
Showing 8 changed files with 304 additions and 2 deletions.
26 changes: 26 additions & 0 deletions python/pydantic_core/core_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -1437,6 +1437,30 @@ def uuid_schema(
)


class NestedSchema(TypedDict, total=False):
type: Required[Literal['nested']]
cls: Required[Type[Any]]
# Should return `(CoreSchema, SchemaValidator, SchemaSerializer)` but this requires a forward ref
get_info: Required[Callable[[], Any]]
metadata: Any
serialization: SerSchema

def nested_schema(
*,
cls: Type[Any],
get_info: Callable[[], Any],
metadata: Any = None,
serialization: SerSchema | None = None
) -> NestedSchema:
return _dict_not_none(
type='nested',
cls=cls,
get_info=get_info,
metadata=metadata,
serialization=serialization
)


class IncExSeqSerSchema(TypedDict, total=False):
type: Required[Literal['include-exclude-sequence']]
include: Set[int]
Expand Down Expand Up @@ -3866,6 +3890,7 @@ def definition_reference_schema(
DefinitionReferenceSchema,
UuidSchema,
ComplexSchema,
NestedSchema,
]
elif False:
CoreSchema: TypeAlias = Mapping[str, Any]
Expand Down Expand Up @@ -3922,6 +3947,7 @@ def definition_reference_schema(
'definition-ref',
'uuid',
'complex',
'nested',
]

CoreSchemaFieldType = Literal['model-field', 'dataclass-field', 'typed-dict-field', 'computed-field']
Expand Down
21 changes: 20 additions & 1 deletion src/py_gc.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::sync::Arc;
use std::sync::{Arc, OnceLock};

use ahash::AHashMap;
use enum_dispatch::enum_dispatch;
Expand Down Expand Up @@ -58,6 +58,25 @@ impl<T: PyGcTraverse> PyGcTraverse for Option<T> {
}
}

impl<T: PyGcTraverse, E> PyGcTraverse for Result<T, E> {
fn py_gc_traverse(&self, visit: &PyVisit<'_>) -> Result<(), PyTraverseError> {
match self {
Ok(v) => T::py_gc_traverse(v, visit),
// FIXME(BoxyUwU): Lol
Err(_) => Ok(()),
}
}
}

impl<T: PyGcTraverse> PyGcTraverse for OnceLock<T> {
fn py_gc_traverse(&self, visit: &PyVisit<'_>) -> Result<(), PyTraverseError> {
match self.get() {
Some(item) => T::py_gc_traverse(item, visit),
None => Ok(()),
}
}
}

/// A crude alternative to a "derive" macro to help with building PyGcTraverse implementations
macro_rules! impl_py_gc_traverse {
($name:ty { }) => {
Expand Down
2 changes: 2 additions & 0 deletions src/serializers/shared.rs
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,7 @@ combined_serializer! {
Recursive: super::type_serializers::definitions::DefinitionRefSerializer;
Tuple: super::type_serializers::tuple::TupleSerializer;
Complex: super::type_serializers::complex::ComplexSerializer;
Nested: super::type_serializers::nested::NestedSerializer;
}
}

Expand Down Expand Up @@ -254,6 +255,7 @@ impl PyGcTraverse for CombinedSerializer {
CombinedSerializer::Tuple(inner) => inner.py_gc_traverse(visit),
CombinedSerializer::Uuid(inner) => inner.py_gc_traverse(visit),
CombinedSerializer::Complex(inner) => inner.py_gc_traverse(visit),
CombinedSerializer::Nested(inner) => inner.py_gc_traverse(visit),
}
}
}
Expand Down
1 change: 1 addition & 0 deletions src/serializers/type_serializers/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ pub mod json_or_python;
pub mod list;
pub mod literal;
pub mod model;
pub mod nested;
pub mod nullable;
pub mod other;
pub mod set_frozenset;
Expand Down
135 changes: 135 additions & 0 deletions src/serializers/type_serializers/nested.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
use std::{borrow::Cow, sync::OnceLock};

use pyo3::{
intern,
types::{PyAnyMethods, PyDict, PyDictMethods, PyTuple, PyType},
Bound, Py, PyAny, PyObject, PyResult, Python,
};

use crate::{
definitions::DefinitionsBuilder,
serializers::{
shared::{BuildSerializer, TypeSerializer},
CombinedSerializer, Extra,
},
SchemaSerializer,
};

#[derive(Debug)]
pub struct NestedSerializer {
model: Py<PyType>,
name: String,
get_serializer: Py<PyAny>,
serializer: OnceLock<PyResult<Py<SchemaSerializer>>>,
}

impl_py_gc_traverse!(NestedSerializer {
model,
get_serializer,
serializer
});

impl BuildSerializer for NestedSerializer {
const EXPECTED_TYPE: &'static str = "nested";

fn build(
schema: &Bound<'_, PyDict>,
_config: Option<&Bound<'_, PyDict>>,
_definitions: &mut DefinitionsBuilder<CombinedSerializer>,
) -> PyResult<CombinedSerializer> {
let py = schema.py();

let get_serializer = schema
.get_item(intern!(py, "get_info"))?
.expect("Invalid core schema for `nested` type, no `get_info`")
.unbind();

let model = schema
.get_item(intern!(py, "cls"))?
.expect("Invalid core schema for `nested` type, no `model`")
.downcast::<PyType>()
.expect("Invalid core schema for `nested` type, not a `PyType`")
.clone();

let name = model.getattr(intern!(py, "__name__"))?.extract()?;

Ok(CombinedSerializer::Nested(NestedSerializer {
model: model.clone().unbind(),
name,
get_serializer,
serializer: OnceLock::new(),
}))
}
}

impl NestedSerializer {
fn nested_serializer<'py>(&self, py: Python<'py>) -> PyResult<&Py<SchemaSerializer>> {
self.serializer
.get_or_init(|| {
Ok(self
.get_serializer
.bind(py)
.call((), None)?
.downcast::<PyTuple>()?
.get_item(2)?
.downcast::<SchemaSerializer>()?
.clone()
.unbind())
})
.as_ref()
.map_err(|e| e.clone_ref(py))
}
}

impl TypeSerializer for NestedSerializer {
fn to_python(
&self,
value: &Bound<'_, PyAny>,
include: Option<&Bound<'_, PyAny>>,
exclude: Option<&Bound<'_, PyAny>>,
mut extra: &Extra,
) -> PyResult<PyObject> {
let mut guard = extra.recursion_guard(value, self.model.as_ptr() as usize)?;

self.nested_serializer(value.py())?
.bind(value.py())
.get()
.serializer
.to_python(value, include, exclude, guard.state())
}

fn json_key<'a>(&self, key: &'a Bound<'_, PyAny>, extra: &Extra) -> PyResult<Cow<'a, str>> {
self.nested_serializer(key.py())?
.bind(key.py())
.get()
.serializer
.json_key(key, extra)
}

fn serde_serialize<S: serde::ser::Serializer>(
&self,
value: &Bound<'_, PyAny>,
serializer: S,
include: Option<&Bound<'_, PyAny>>,
exclude: Option<&Bound<'_, PyAny>>,
mut extra: &Extra,
) -> Result<S::Ok, S::Error> {
use super::py_err_se_err;

let mut guard = extra
.recursion_guard(value, self.model.as_ptr() as usize)
.map_err(py_err_se_err)?;

self.nested_serializer(value.py())
// FIXME(BoxyUwU): Don't unwrap this
.unwrap()
.bind(value.py())
.get()
.serializer
.serde_serialize(value, serializer, include, exclude, guard.state())
}

fn get_name(&self) -> &str {
&self.name
}
}
4 changes: 4 additions & 0 deletions src/validators/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ mod list;
mod literal;
mod model;
mod model_fields;
mod nested;
mod none;
mod nullable;
mod set;
Expand Down Expand Up @@ -584,6 +585,7 @@ pub fn build_validator(
definitions::DefinitionRefValidator,
definitions::DefinitionsValidatorBuilder,
complex::ComplexValidator,
nested::NestedValidator,
)
}

Expand Down Expand Up @@ -738,6 +740,8 @@ pub enum CombinedValidator {
// input dependent
JsonOrPython(json_or_python::JsonOrPython),
Complex(complex::ComplexValidator),
// Schema for reusing an existing validator
Nested(nested::NestedValidator),
}

/// This trait must be implemented by all validators, it allows various validators to be accessed consistently,
Expand Down
2 changes: 1 addition & 1 deletion src/validators/model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ impl BuildValidator for ModelValidator {

let class: Bound<'_, PyType> = schema.get_as_req(intern!(py, "cls"))?;
let sub_schema = schema.get_as_req(intern!(py, "schema"))?;
let validator = build_validator(&sub_schema, config.as_ref(), definitions)?;
let validator: CombinedValidator = build_validator(&sub_schema, config.as_ref(), definitions)?;
let name = class.getattr(intern!(py, "__name__"))?.extract()?;

Ok(Self {
Expand Down
115 changes: 115 additions & 0 deletions src/validators/nested.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
use std::sync::OnceLock;

use pyo3::{
intern,
types::{PyAnyMethods, PyDict, PyDictMethods, PyTuple, PyTupleMethods, PyType},
Bound, Py, PyAny, PyObject, PyResult, Python,
};

use crate::{
definitions::DefinitionsBuilder,
errors::{ErrorTypeDefaults, ValError, ValResult},
input::Input,
recursion_guard::RecursionGuard,
};

use super::{BuildValidator, CombinedValidator, SchemaValidator, ValidationState, Validator};

#[derive(Debug)]
pub struct NestedValidator {
cls: Py<PyType>,
name: String,
get_validator: Py<PyAny>,
validator: OnceLock<PyResult<Py<SchemaValidator>>>,
}

impl_py_gc_traverse!(NestedValidator {
cls,
get_validator,
validator
});

impl BuildValidator for NestedValidator {
const EXPECTED_TYPE: &'static str = "nested";

fn build(
schema: &Bound<'_, PyDict>,
_config: Option<&Bound<'_, PyDict>>,
_definitions: &mut DefinitionsBuilder<super::CombinedValidator>,
) -> PyResult<super::CombinedValidator> {
let py = schema.py();

let get_validator = schema.get_item(intern!(py, "get_info"))?.unwrap().unbind();

let cls = schema
.get_item(intern!(py, "cls"))?
.unwrap()
.downcast::<PyType>()?
.clone();

let name = cls.getattr(intern!(py, "__name__"))?.extract()?;

Ok(CombinedValidator::Nested(NestedValidator {
cls: cls.clone().unbind(),
name,
get_validator: get_validator,
validator: OnceLock::new(),
}))
}
}

impl NestedValidator {
fn nested_validator<'py>(&self, py: Python<'py>) -> PyResult<&Py<SchemaValidator>> {
self.validator
.get_or_init(|| {
Ok(self
.get_validator
.bind(py)
.call((), None)?
.downcast::<PyTuple>()?
.get_item(1)?
.downcast::<SchemaValidator>()?
.clone()
.unbind())
})
.as_ref()
.map_err(|e| e.clone_ref(py))
}
}

impl Validator for NestedValidator {
fn validate<'py>(
&self,
py: Python<'py>,
input: &(impl Input<'py> + ?Sized),
state: &mut ValidationState<'_, 'py>,
) -> ValResult<PyObject> {
let Some(id) = input.as_python().map(py_identity) else {
return self
.nested_validator(py)?
.bind(py)
.get()
.validator
.validate(py, input, state);
};

// Python objects can be cyclic, so need recursion guard
let Ok(mut guard) = RecursionGuard::new(state, id, self.cls.as_ptr() as usize) else {
return Err(ValError::new(ErrorTypeDefaults::RecursionLoop, input));
};

self.nested_validator(py)?
.bind(py)
.get()
.validator
.validate(py, input, guard.state())
}

fn get_name(&self) -> &str {
&self.name
}
}

fn py_identity(obj: &Bound<'_, PyAny>) -> usize {
obj.as_ptr() as usize
}

0 comments on commit 8ddba74

Please sign in to comment.