diff --git a/ferrix-model-api/src/internal.rs b/ferrix-model-api/src/internal.rs index 833c494..09179f4 100644 --- a/ferrix-model-api/src/internal.rs +++ b/ferrix-model-api/src/internal.rs @@ -429,3 +429,53 @@ fn data_to_py(datatype: String, contents: TensorData, py: pyo3::Python<'_>) -> P _ => todo!(), } } + +#[cfg(test)] +mod tests { + use pyo3::{prepare_freethreaded_python, Py, Python, ToPyObject}; + + use crate::python::PyParameter; + + use super::Parameter; + + fn setup() { + prepare_freethreaded_python(); + } + + #[test] + fn test_parameter_to_py() { + setup(); + + Python::with_gil(|py| { + let result = Parameter { + bool_param: None, + str_param: Some("blah".to_string()), + float_param: None, + int_param: None, + } + .to_object(py); + + assert_eq!( + result + .getattr(py, "str_param") + .unwrap() + .extract::>(py) + .unwrap(), + Some("blah".to_string()) + ) + }) + } + + #[test] + fn test_parameter_from_py() { + setup(); + + Python::with_gil(|py| { + let py_parameter = Py::new(py, PyParameter::new(None, None, None, Some(true))).unwrap(); + let extracted = py_parameter.extract::(py); + + assert!(extracted.is_ok()); + assert!(extracted.unwrap().bool_param.unwrap()); + }); + } +} diff --git a/ferrix-python-hooks/src/lib.rs b/ferrix-python-hooks/src/lib.rs index 21ce80a..6488a4b 100644 --- a/ferrix-python-hooks/src/lib.rs +++ b/ferrix-python-hooks/src/lib.rs @@ -87,8 +87,6 @@ pub fn postprocess(input: InferResponse) -> InferResponse { mod tests { use std::collections::HashMap; - use ferrix_protos::model_infer_request::InferInputTensor; - use super::*; #[test]