Skip to content

Commit

Permalink
cleaning up pyo3 conversions
Browse files Browse the repository at this point in the history
  • Loading branch information
rickfast committed Oct 7, 2023
1 parent f4d602f commit 1510219
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 4 deletions.
34 changes: 32 additions & 2 deletions ferrix-model-api/src/internal.rs
Original file line number Diff line number Diff line change
@@ -1,16 +1,19 @@
use core::panic;
use std::collections::HashMap;

use ferrix_protos::infer_parameter::*;
use ferrix_protos::model_infer_request::*;
use ferrix_protos::model_infer_response::InferOutputTensor;
use ferrix_protos::*;
use pyo3::PyResult;
use pyo3::types::IntoPyDict;
use pyo3::types::PyDict;
use pyo3::types::PyList;
use pyo3::FromPyObject;
use pyo3::IntoPy;
use pyo3::Py;
use pyo3::ToPyObject;
use pyo3::PyAny;

use crate::python::PyInferInput;
use crate::python::PyInferOutput;
Expand Down Expand Up @@ -46,7 +49,8 @@ pub struct InferRequest {
impl ToPyObject for InferRequest {
fn to_object(&self, py: pyo3::Python<'_>) -> pyo3::PyObject {
let parameters: Py<PyDict> = self.parameters.clone().into_py_dict(py).into_py(py);
let inputs = PyList::new(py, self.inputs.clone()).into_py(py);
let x: Vec<Py<PyAny>> = self.inputs.iter().map(|input| input.to_object(py)).collect();
let inputs = PyList::new(py, x).into_py(py);
let outputs = PyList::empty(py).into_py(py);
let raw = PyList::empty(py).into_py(py);
let request = PyInferRequest::new(
Expand Down Expand Up @@ -182,7 +186,7 @@ impl ToPyObject for Parameter {
}
}

#[derive(Clone, FromPyObject)]
#[derive(Clone)]
pub struct InputTensor {
pub name: String,
pub datatype: String,
Expand Down Expand Up @@ -218,6 +222,32 @@ impl InputTensor {
}
}

impl FromPyObject<'_> for InputTensor {
fn extract(ob: &'_ PyAny) -> pyo3::PyResult<Self> {
let mut tensor_data = TensorData::default();
let name: String = ob.getattr("name")?.extract()?;
let datatype: String = ob.getattr("datatype")?.extract()?;
let shape: Vec<i64> = ob.getattr("shape")?.extract()?;
let parameters: HashMap<String, Parameter> = HashMap::new();
let _: Result<(), anyhow::Error> = match datatype.as_str() {
"FP32" => {
tensor_data.fp32_contents = ob.getattr("data")?.extract()?;
Ok(())
},
_ => todo!()
};
let tensor = InputTensor {
name,
datatype,
shape,
parameters,
data: tensor_data
};

return PyResult::Ok(tensor);
}
}

impl ToPyObject for InputTensor {
fn to_object(&self, py: pyo3::Python<'_>) -> pyo3::PyObject {
let parameters: Py<PyDict> = self.parameters.clone().into_py_dict(py).into_py(py);
Expand Down
4 changes: 2 additions & 2 deletions ferrix-python-hooks/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,6 @@ mod tests {
use super::*;

#[test]
#[ignore = "will fix"]
fn test() {
eval();

Expand All @@ -119,8 +118,9 @@ mod tests {
bytes_contents: vec![],
},
}],
// inputs: vec![],
outputs: vec![],
raw_input_contents: vec![],
raw_input_contents: vec![vec![1_u8]],
};

let response = preprocess(infer_request);
Expand Down

0 comments on commit 1510219

Please sign in to comment.