diff --git a/crates/sb_ai/onnxruntime/mod.rs b/crates/sb_ai/onnxruntime/mod.rs index e00fd658..0fcd3fc1 100644 --- a/crates/sb_ai/onnxruntime/mod.rs +++ b/crates/sb_ai/onnxruntime/mod.rs @@ -21,7 +21,7 @@ pub fn op_sb_ai_ort_init_session(#[buffer] model_bytes: &[u8]) -> Result( +pub fn op_sb_ai_ort_run_session( #[string] model_id: String, #[serde] input_values: HashMap, ) -> Result> { @@ -31,7 +31,11 @@ pub fn op_sb_ai_ort_run_session<'a>( // println!("{model_session:?}"); let input_values = input_values .into_iter() - .map(|(key, value)| value.as_ort_input().map(|value| (Cow::from(key), value))) + .map(|(key, value)| { + value + .extract_ort_input() + .map(|value| (Cow::from(key), value)) + }) .collect::>>()?; let mut outputs = model_session.run(input_values)?; diff --git a/crates/sb_ai/onnxruntime/session.rs b/crates/sb_ai/onnxruntime/session.rs index 5b372c8d..ba01afe2 100644 --- a/crates/sb_ai/onnxruntime/session.rs +++ b/crates/sb_ai/onnxruntime/session.rs @@ -135,7 +135,7 @@ pub(crate) fn load_session_from_bytes(model_bytes: &[u8]) -> Result<(String, Arc pub(crate) fn get_session(session_id: &String) -> Option> { let sessions = SESSIONS.lock().unwrap(); - sessions.get(session_id).map(|session| session.clone()) + sessions.get(session_id).cloned() } pub fn cleanup() -> Result { diff --git a/crates/sb_ai/onnxruntime/tensor.rs b/crates/sb_ai/onnxruntime/tensor.rs index 91e03fa0..0f848b07 100644 --- a/crates/sb_ai/onnxruntime/tensor.rs +++ b/crates/sb_ai/onnxruntime/tensor.rs @@ -105,7 +105,7 @@ pub struct JsTensor { } impl JsTensor { - pub fn as_ort_tensor_ref<'a, T: IntoTensorElementType + Debug>( + pub fn extract_ort_tensor_ref<'a, T: IntoTensorElementType + Debug>( mut self, ) -> anyhow::Result> { // Same impl. as the Tensor::from_array() @@ -128,23 +128,23 @@ impl JsTensor { Ok(tensor.into_dyn()) } - pub fn as_ort_input<'a>(self) -> anyhow::Result> { + pub fn extract_ort_input<'a>(self) -> anyhow::Result> { let input_value = match self.data_type { - TensorElementType::Float32 => self.as_ort_tensor_ref::()?.into(), - TensorElementType::Float64 => self.as_ort_tensor_ref::()?.into(), + TensorElementType::Float32 => self.extract_ort_tensor_ref::()?.into(), + TensorElementType::Float64 => self.extract_ort_tensor_ref::()?.into(), TensorElementType::String => { // TODO: Handle string[] tensors from 'v8::Array' return Err(anyhow!("Can't extract tensor from it: 'String' does not implement the 'IntoTensorElementType' trait.")); } - TensorElementType::Int8 => self.as_ort_tensor_ref::()?.into(), - TensorElementType::Uint8 => self.as_ort_tensor_ref::()?.into(), - TensorElementType::Int16 => self.as_ort_tensor_ref::()?.into(), - TensorElementType::Uint16 => self.as_ort_tensor_ref::()?.into(), - TensorElementType::Int32 => self.as_ort_tensor_ref::()?.into(), - TensorElementType::Uint32 => self.as_ort_tensor_ref::()?.into(), - TensorElementType::Int64 => self.as_ort_tensor_ref::()?.into(), - TensorElementType::Uint64 => self.as_ort_tensor_ref::()?.into(), - TensorElementType::Bool => self.as_ort_tensor_ref::()?.into(), + TensorElementType::Int8 => self.extract_ort_tensor_ref::()?.into(), + TensorElementType::Uint8 => self.extract_ort_tensor_ref::()?.into(), + TensorElementType::Int16 => self.extract_ort_tensor_ref::()?.into(), + TensorElementType::Uint16 => self.extract_ort_tensor_ref::()?.into(), + TensorElementType::Int32 => self.extract_ort_tensor_ref::()?.into(), + TensorElementType::Uint32 => self.extract_ort_tensor_ref::()?.into(), + TensorElementType::Int64 => self.extract_ort_tensor_ref::()?.into(), + TensorElementType::Uint64 => self.extract_ort_tensor_ref::()?.into(), + TensorElementType::Bool => self.extract_ort_tensor_ref::()?.into(), TensorElementType::Float16 => { return Err(anyhow!("'half::f16' is not supported by JS tensor.")) } @@ -172,8 +172,7 @@ impl ToJsTensor { let ValueType::Tensor { ty, dimensions } = ort_type else { return Err(anyhow!( "JS only support 'ort::Value' of 'Tensor' type, got '{ort_type:?}'." - ) - .into()); + )); }; let buffer_slice = match ty {