Skip to content

Commit

Permalink
Zune: Fix Output node handling for now
Browse files Browse the repository at this point in the history
  • Loading branch information
saidinesh5 committed Jun 8, 2022
1 parent 41b7bd1 commit 45c6dc6
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 9 deletions.
9 changes: 8 additions & 1 deletion crates/runtime/examples/zune_example.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use hotg_rune_runtime::zune::{ElementType, TensorResult, ZuneEngine};
fn main() -> Result<(), Error> {
let args: Vec<String> = std::env::args().collect();

let filename = args.get(1).map(|s| s.as_str()).unwrap_or("sine.rune");
let filename = args.get(1).map(|s| s.as_str()).unwrap_or("/home/helios/Code/hotg/rune/crates/runtime/examples/sine.rune");

let sine_zune = std::fs::read(&filename)
.with_context(|| format!("Unable to read \"{filename}\""))?;
Expand Down Expand Up @@ -47,5 +47,12 @@ fn main() -> Result<(), Error> {
zune_engine.get_output_tensor("sine", "Identity")
);

for node in zune_engine.output_nodes() {
let input_tensor_names = zune_engine.get_input_tensor_names(node)?;
for tensor_name in &input_tensor_names {
println!("Output {:?} {:?}: {:?}", node, tensor_name, zune_engine.get_input_tensor(node, tensor_name));
}
}

Ok(())
}
36 changes: 28 additions & 8 deletions crates/runtime/src/zune/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ use zip;

pub use self::{proc_block_v1::*, runtime_v1::*};
use crate::{
zune::proc_block::{GraphContext, ProcBlockNode, TensorConstraint},
zune::proc_block::{GraphContext, ProcBlockNode, TensorConstraint, Dimensions},
LoadError,
};

Expand Down Expand Up @@ -156,7 +156,9 @@ impl ZuneEngine {
let _span =
tracing::debug_span!("Running Stage", %stage_name).entered();

self.nodes.get_mut(stage_name).unwrap().run()?;
if let Some(node) = self.nodes.get_mut(stage_name) {
node.run()?;
}
}
Ok(())
}
Expand Down Expand Up @@ -189,7 +191,7 @@ impl ZuneEngine {
}

pub fn get_input_tensor(
&mut self,
&self,
node_name: &str,
tensor_name: &str,
) -> Option<TensorResult> {
Expand Down Expand Up @@ -331,10 +333,9 @@ fn instantiate_nodes(
shared_state: shared_state.clone(),
};

for item in pipeline {
for (stage_name, stage) in pipeline {
// Collect each output tensor into tensors
let stage_name = item.0;
match item.1 {
match stage {
// Models are handled on the host side, so we treat them separately
Stage::Capability(stage) => {
let wasm =
Expand Down Expand Up @@ -390,8 +391,27 @@ fn instantiate_nodes(
)?;
nodes.insert(stage_name.to_string(), Box::new(pb));
},

_ => {}, // Do nothing for capabilities/outputs
Stage::Out(stage) => {
shared_state
.lock()
.unwrap()
.graph_contexts
.get_mut(stage_name)
.and_then(|c| {
for input in stage.inputs.iter() {
let tensor_key = key(&input.name, input.index);
let tensor_id = output_tensors.get(&tensor_key).copied();
c.input_tensors.insert(tensor_key,
TensorConstraint {
tensor_id,
element_type: ElementType::U8,
dimensions: Dimensions::Dynamic
}
);
}
Some(())
});
}, // Do nothing for capabilities/outputs
}
}

Expand Down

0 comments on commit 45c6dc6

Please sign in to comment.