Skip to content

Commit

Permalink
[Example] ggml: update gemma example for stream output
Browse files Browse the repository at this point in the history
Signed-off-by: dm4 <[email protected]>
  • Loading branch information
dm4 authored and hydai committed Feb 22, 2024
1 parent 84efc91 commit a3b0685
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 6 deletions.
19 changes: 13 additions & 6 deletions wasmedge-ggml/gemma/src/main.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use serde_json::Value;
use serde_json::json;
use serde_json::Value;
use std::env;
use std::io;
use wasi_nn::{self, GraphExecutionContext};
Expand Down Expand Up @@ -67,8 +67,7 @@ fn get_output_from_context(context: &GraphExecutionContext) -> String {

#[allow(dead_code)]
fn get_metadata_from_context(context: &GraphExecutionContext) -> Value {
serde_json::from_str(&get_data_from_context(context, 1))
.expect("Failed to get metadata")
serde_json::from_str(&get_data_from_context(context, 1)).expect("Failed to get metadata")
}

fn main() {
Expand Down Expand Up @@ -119,15 +118,18 @@ fn main() {
let mut saved_prompt = String::new();

loop {
println!("Question:");
println!("USER:");
let input = read_input();
if saved_prompt.is_empty() {
saved_prompt = format!(
"<start_of_turn>user {} <end_of_turn><start_of_turn>model",
input
);
} else {
saved_prompt = format!("{} <start_of_turn>user {} <end_of_turn><start_of_turn>model", saved_prompt, input);
saved_prompt = format!(
"{} <start_of_turn>user {} <end_of_turn><start_of_turn>model",
saved_prompt, input
);
}

// Set prompt to the input tensor.
Expand All @@ -148,6 +150,7 @@ fn main() {

// Execute the inference.
let mut reset_prompt = false;
println!("ASSISTANT:");
match context.compute() {
Ok(_) => (),
Err(wasi_nn::Error::BackendError(wasi_nn::BackendError::ContextFull)) => {
Expand All @@ -165,7 +168,11 @@ fn main() {

// Retrieve the output.
let mut output = get_output_from_context(&context);
println!("Answer:\n{}", output.trim());
if let Some(true) = options["stream-stdout"].as_bool() {
println!("");
} else {
println!("{}", output.trim());
}

// Update the saved prompt.
if reset_prompt {
Expand Down
Binary file modified wasmedge-ggml/gemma/wasmedge-ggml-gemma.wasm
Binary file not shown.

0 comments on commit a3b0685

Please sign in to comment.