Skip to content

Commit

Permalink
stamp: clippy
Browse files Browse the repository at this point in the history
  • Loading branch information
kallebysantos committed Nov 5, 2024
1 parent 46df317 commit 8a3a36c
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 18 deletions.
8 changes: 6 additions & 2 deletions crates/sb_ai/onnxruntime/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ pub fn op_sb_ai_ort_init_session(#[buffer] model_bytes: &[u8]) -> Result<ModelIn

#[op2]
#[serde]
pub fn op_sb_ai_ort_run_session<'a>(
pub fn op_sb_ai_ort_run_session(
#[string] model_id: String,
#[serde] input_values: HashMap<String, JsTensor>,
) -> Result<HashMap<String, ToJsTensor>> {
Expand All @@ -31,7 +31,11 @@ pub fn op_sb_ai_ort_run_session<'a>(
// println!("{model_session:?}");
let input_values = input_values
.into_iter()
.map(|(key, value)| value.as_ort_input().map(|value| (Cow::from(key), value)))
.map(|(key, value)| {
value
.extract_ort_input()
.map(|value| (Cow::from(key), value))
})
.collect::<Result<Vec<_>>>()?;

let mut outputs = model_session.run(input_values)?;
Expand Down
2 changes: 1 addition & 1 deletion crates/sb_ai/onnxruntime/session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ pub(crate) fn load_session_from_bytes(model_bytes: &[u8]) -> Result<(String, Arc
pub(crate) fn get_session(session_id: &String) -> Option<Arc<Session>> {
let sessions = SESSIONS.lock().unwrap();

sessions.get(session_id).map(|session| session.clone())
sessions.get(session_id).cloned()
}

pub fn cleanup() -> Result<usize, AnyError> {
Expand Down
29 changes: 14 additions & 15 deletions crates/sb_ai/onnxruntime/tensor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ pub struct JsTensor {
}

impl JsTensor {
pub fn as_ort_tensor_ref<'a, T: IntoTensorElementType + Debug>(
pub fn extract_ort_tensor_ref<'a, T: IntoTensorElementType + Debug>(
mut self,
) -> anyhow::Result<ValueRefMut<'a, DynValueTypeMarker>> {
// Same impl. as the Tensor::from_array()
Expand All @@ -128,23 +128,23 @@ impl JsTensor {
Ok(tensor.into_dyn())
}

pub fn as_ort_input<'a>(self) -> anyhow::Result<SessionInputValue<'a>> {
pub fn extract_ort_input<'a>(self) -> anyhow::Result<SessionInputValue<'a>> {
let input_value = match self.data_type {
TensorElementType::Float32 => self.as_ort_tensor_ref::<f32>()?.into(),
TensorElementType::Float64 => self.as_ort_tensor_ref::<f64>()?.into(),
TensorElementType::Float32 => self.extract_ort_tensor_ref::<f32>()?.into(),
TensorElementType::Float64 => self.extract_ort_tensor_ref::<f64>()?.into(),
TensorElementType::String => {
// TODO: Handle string[] tensors from 'v8::Array'
return Err(anyhow!("Can't extract tensor from it: 'String' does not implement the 'IntoTensorElementType' trait."));
}
TensorElementType::Int8 => self.as_ort_tensor_ref::<i8>()?.into(),
TensorElementType::Uint8 => self.as_ort_tensor_ref::<u8>()?.into(),
TensorElementType::Int16 => self.as_ort_tensor_ref::<i16>()?.into(),
TensorElementType::Uint16 => self.as_ort_tensor_ref::<u16>()?.into(),
TensorElementType::Int32 => self.as_ort_tensor_ref::<i32>()?.into(),
TensorElementType::Uint32 => self.as_ort_tensor_ref::<u32>()?.into(),
TensorElementType::Int64 => self.as_ort_tensor_ref::<i64>()?.into(),
TensorElementType::Uint64 => self.as_ort_tensor_ref::<u64>()?.into(),
TensorElementType::Bool => self.as_ort_tensor_ref::<bool>()?.into(),
TensorElementType::Int8 => self.extract_ort_tensor_ref::<i8>()?.into(),
TensorElementType::Uint8 => self.extract_ort_tensor_ref::<u8>()?.into(),
TensorElementType::Int16 => self.extract_ort_tensor_ref::<i16>()?.into(),
TensorElementType::Uint16 => self.extract_ort_tensor_ref::<u16>()?.into(),
TensorElementType::Int32 => self.extract_ort_tensor_ref::<i32>()?.into(),
TensorElementType::Uint32 => self.extract_ort_tensor_ref::<u32>()?.into(),
TensorElementType::Int64 => self.extract_ort_tensor_ref::<i64>()?.into(),
TensorElementType::Uint64 => self.extract_ort_tensor_ref::<u64>()?.into(),
TensorElementType::Bool => self.extract_ort_tensor_ref::<bool>()?.into(),
TensorElementType::Float16 => {
return Err(anyhow!("'half::f16' is not supported by JS tensor."))
}
Expand Down Expand Up @@ -172,8 +172,7 @@ impl ToJsTensor {
let ValueType::Tensor { ty, dimensions } = ort_type else {
return Err(anyhow!(
"JS only support 'ort::Value' of 'Tensor' type, got '{ort_type:?}'."
)
.into());
));
};

let buffer_slice = match ty {
Expand Down

0 comments on commit 8a3a36c

Please sign in to comment.