Skip to content

Commit

Permalink
WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
BoxyUwU committed Aug 28, 2024
1 parent 26e3ec4 commit 40da0ee
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 42 deletions.
10 changes: 10 additions & 0 deletions src/py_gc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,16 @@ impl<T: PyGcTraverse> PyGcTraverse for Option<T> {
}
}

impl<T: PyGcTraverse, E> PyGcTraverse for Result<T, E> {
fn py_gc_traverse(&self, visit: &PyVisit<'_>) -> Result<(), PyTraverseError> {
match self {
Ok(v) => T::py_gc_traverse(v, visit),
// FIXME(BoxyUwU): Lol
Err(_) => Ok(()),
}
}
}

impl<T: PyGcTraverse> PyGcTraverse for OnceLock<T> {
fn py_gc_traverse(&self, visit: &PyVisit<'_>) -> Result<(), PyTraverseError> {
match self.get() {
Expand Down
38 changes: 19 additions & 19 deletions src/serializers/type_serializers/nested_model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,12 @@ use crate::{
SchemaSerializer,
};

#[derive(Debug, Clone)]
#[derive(Debug)]
pub struct NestedModelSerializer {
model: Py<PyType>,
name: String,
get_serializer: Py<PyAny>,
serializer: OnceLock<Py<SchemaSerializer>>,
serializer: OnceLock<PyResult<Py<SchemaSerializer>>>,
}

impl_py_gc_traverse!(NestedModelSerializer {
Expand All @@ -41,14 +41,14 @@ impl BuildSerializer for NestedModelSerializer {

let get_serializer = schema
.get_item(intern!(py, "get_info"))?
.expect("Invalid core schema for `nested-model` type")
.expect("Invalid core schema for `nested-model` type, no `get_info`")
.unbind();

let model = schema
.get_item(intern!(py, "model"))?
.expect("Invalid core schema for `nested-model` type")
.expect("Invalid core schema for `nested-model` type, no `model`")
.downcast::<PyType>()
.expect("Invalid core schema for `nested-model` type")
.expect("Invalid core schema for `nested-model` type, not a `PyType`")
.clone();

let name = model.getattr(intern!(py, "__name__"))?.extract()?;
Expand All @@ -63,23 +63,21 @@ impl BuildSerializer for NestedModelSerializer {
}

impl NestedModelSerializer {
fn nested_serializer<'py>(&self, py: Python<'py>) -> Py<SchemaSerializer> {
fn nested_serializer<'py>(&self, py: Python<'py>) -> PyResult<&Py<SchemaSerializer>> {
self.serializer
.get_or_init(|| {
self.get_serializer
Ok(self
.get_serializer
.bind(py)
.call((), None)
.expect("Invalid core schema for `nested-model`")
.downcast::<PyTuple>()
.expect("Invalid return value from `nested-model`'s `get_info` callable")
.get_item(2)
.expect("Invalid return value from `nested-model`'s `get_info` callable")
.downcast::<SchemaSerializer>()
.expect("Invalid return value from `nested-model`'s `get_info` callable")
.call((), None)?
.downcast::<PyTuple>()?
.get_item(2)?
.downcast::<SchemaSerializer>()?
.clone()
.unbind()
.unbind())
})
.clone()
.as_ref()
.map_err(|e| e.clone_ref(py))
}
}

Expand All @@ -93,15 +91,15 @@ impl TypeSerializer for NestedModelSerializer {
) -> PyResult<PyObject> {
let mut guard = extra.recursion_guard(value, self.model.as_ptr() as usize)?;

self.nested_serializer(value.py())
self.nested_serializer(value.py())?
.bind(value.py())
.get()
.serializer
.to_python(value, include, exclude, guard.state())
}

fn json_key<'a>(&self, key: &'a Bound<'_, PyAny>, extra: &Extra) -> PyResult<Cow<'a, str>> {
self.nested_serializer(key.py())
self.nested_serializer(key.py())?
.bind(key.py())
.get()
.serializer
Expand All @@ -123,6 +121,8 @@ impl TypeSerializer for NestedModelSerializer {
.map_err(py_err_se_err)?;

self.nested_serializer(value.py())
// FIXME(BoxyUwU): Don't unwrap this
.unwrap()
.bind(value.py())
.get()
.serializer
Expand Down
45 changes: 22 additions & 23 deletions src/validators/nested_model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,12 @@ use crate::{

use super::{BuildValidator, CombinedValidator, SchemaValidator, ValidationState, Validator};

#[derive(Debug, Clone)]
#[derive(Debug)]
pub struct NestedModelValidator {
model: Py<PyType>,
name: String,
get_validator: Py<PyAny>,
validator: OnceLock<Py<SchemaValidator>>,
validator: OnceLock<PyResult<Py<SchemaValidator>>>,
}

impl_py_gc_traverse!(NestedModelValidator {
Expand All @@ -39,16 +39,12 @@ impl BuildValidator for NestedModelValidator {
) -> PyResult<super::CombinedValidator> {
let py = schema.py();

let get_validator = schema
.get_item(intern!(py, "get_info"))?
.expect("Invalid core schema for `nested-model` type")
.unbind();
let get_validator = schema.get_item(intern!(py, "get_info"))?.unwrap().unbind();

let model = schema
.get_item(intern!(py, "model"))?
.expect("Invalid core schema for `nested-model` type")
.downcast::<PyType>()
.expect("Invalid core schema for `nested-model` type")
.unwrap()
.downcast::<PyType>()?
.clone();

let name = model.getattr(intern!(py, "__name__"))?.extract()?;
Expand All @@ -63,23 +59,21 @@ impl BuildValidator for NestedModelValidator {
}

impl NestedModelValidator {
fn nested_validator<'py>(&self, py: Python<'py>) -> Py<SchemaValidator> {
fn nested_validator<'py>(&self, py: Python<'py>) -> PyResult<&Py<SchemaValidator>> {
self.validator
.get_or_init(|| {
self.get_validator
Ok(self
.get_validator
.bind(py)
.call((), None)
.expect("Invalid core schema for `nested-model`")
.downcast::<PyTuple>()
.expect("Invalid return value from `nested-model`'s `get_info` callable")
.get_item(1)
.expect("Invalid return value from `nested-model`'s `get_info` callable")
.downcast::<SchemaValidator>()
.expect("Invalid return value from `nested-model`'s `get_info` callable")
.call((), None)?
.downcast::<PyTuple>()?
.get_item(1)?
.downcast::<SchemaValidator>()?
.clone()
.unbind()
.unbind())
})
.clone()
.as_ref()
.map_err(|e| e.clone_ref(py))
}
}

Expand All @@ -91,15 +85,20 @@ impl Validator for NestedModelValidator {
state: &mut ValidationState<'_, 'py>,
) -> ValResult<PyObject> {
let Some(id) = input.as_python().map(py_identity) else {
panic!("")
return self
.nested_validator(py)?
.bind(py)
.get()
.validator
.validate(py, input, state);
};

// Python objects can be cyclic, so need recursion guard
let Ok(mut guard) = RecursionGuard::new(state, id, self.model.as_ptr() as usize) else {
return Err(ValError::new(ErrorTypeDefaults::RecursionLoop, input));
};

self.nested_validator(py)
self.nested_validator(py)?
.bind(py)
.get()
.validator
Expand Down

0 comments on commit 40da0ee

Please sign in to comment.