diff --git a/crates/runtime/examples/zune_example.rs b/crates/runtime/examples/zune_example.rs index 4e45230363..46019e560c 100644 --- a/crates/runtime/examples/zune_example.rs +++ b/crates/runtime/examples/zune_example.rs @@ -4,7 +4,7 @@ use hotg_rune_runtime::zune::{ElementType, TensorResult, ZuneEngine}; fn main() -> Result<(), Error> { let args: Vec = std::env::args().collect(); - let filename = args.get(1).map(|s| s.as_str()).unwrap_or("sine.rune"); + let filename = args.get(1).map(|s| s.as_str()).unwrap_or("/home/helios/Code/hotg/rune/crates/runtime/examples/sine.rune"); let sine_zune = std::fs::read(&filename) .with_context(|| format!("Unable to read \"{filename}\""))?; @@ -47,5 +47,12 @@ fn main() -> Result<(), Error> { zune_engine.get_output_tensor("sine", "Identity") ); + for node in zune_engine.output_nodes() { + let input_tensor_names = zune_engine.get_input_tensor_names(node)?; + for tensor_name in &input_tensor_names { + println!("Output {:?} {:?}: {:?}", node, tensor_name, zune_engine.get_input_tensor(node, tensor_name)); + } + } + Ok(()) } diff --git a/crates/runtime/src/zune/mod.rs b/crates/runtime/src/zune/mod.rs index 7501538509..8b4d34e043 100644 --- a/crates/runtime/src/zune/mod.rs +++ b/crates/runtime/src/zune/mod.rs @@ -16,7 +16,7 @@ use zip; pub use self::{proc_block_v1::*, runtime_v1::*}; use crate::{ - zune::proc_block::{GraphContext, ProcBlockNode, TensorConstraint}, + zune::proc_block::{GraphContext, ProcBlockNode, TensorConstraint, Dimensions}, LoadError, }; @@ -156,7 +156,9 @@ impl ZuneEngine { let _span = tracing::debug_span!("Running Stage", %stage_name).entered(); - self.nodes.get_mut(stage_name).unwrap().run()?; + if let Some(node) = self.nodes.get_mut(stage_name) { + node.run()?; + } } Ok(()) } @@ -189,7 +191,7 @@ impl ZuneEngine { } pub fn get_input_tensor( - &mut self, + &self, node_name: &str, tensor_name: &str, ) -> Option { @@ -331,10 +333,9 @@ fn instantiate_nodes( shared_state: shared_state.clone(), }; - for item in pipeline { + for (stage_name, stage) in pipeline { // Collect each output tensor into tensors - let stage_name = item.0; - match item.1 { + match stage { // Models are handled on the host side, so we treat them separately Stage::Capability(stage) => { let wasm = @@ -390,8 +391,27 @@ fn instantiate_nodes( )?; nodes.insert(stage_name.to_string(), Box::new(pb)); }, - - _ => {}, // Do nothing for capabilities/outputs + Stage::Out(stage) => { + shared_state + .lock() + .unwrap() + .graph_contexts + .get_mut(stage_name) + .and_then(|c| { + for input in stage.inputs.iter() { + let tensor_key = key(&input.name, input.index); + let tensor_id = output_tensors.get(&tensor_key).copied(); + c.input_tensors.insert(tensor_key, + TensorConstraint { + tensor_id, + element_type: ElementType::U8, + dimensions: Dimensions::Dynamic + } + ); + } + Some(()) + }); + }, // Do nothing for capabilities/outputs } }