
allow chat to halt new token generation on stop_sequence #364

Merged · 13 commits · Jul 12, 2023
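The change replaces the chat mode's unconditional token printing with a small buffering scheme: each inferred token is appended to a held-back buffer and compared against the stop sequence ("User:" in the chat loop below). A full match halts generation, a prefix match holds the token back, and anything else flushes the buffer to stdout. Here is a minimal, self-contained sketch of that idea; the `StopGuard` and `Step` names are illustrative, not part of the PR.

```rust
/// What to do with the text accumulated so far.
enum Step {
    /// No match against the stop sequence: safe to print this text.
    Emit(String),
    /// The text is a prefix of the stop sequence; hold it back.
    Hold,
    /// The stop sequence matched in full; stop generating.
    Halt,
}

struct StopGuard {
    stop_sequence: String,
    buf: String,
}

impl StopGuard {
    fn new(stop_sequence: &str) -> Self {
        Self {
            stop_sequence: stop_sequence.to_owned(),
            buf: String::new(),
        }
    }

    fn push(&mut self, token: &str) -> Step {
        // Tentatively extend the held-back text with the new token,
        // leaving the buffer empty unless we decide to keep holding.
        let mut candidate = std::mem::take(&mut self.buf);
        candidate.push_str(token);
        if candidate == self.stop_sequence {
            Step::Halt
        } else if self.stop_sequence.starts_with(candidate.as_str()) {
            self.buf = candidate;
            Step::Hold
        } else {
            Step::Emit(candidate)
        }
    }
}

fn main() {
    let mut guard = StopGuard::new("User:");
    for token in ["Hello", " there. ", "User", ":"] {
        match guard.push(token) {
            Step::Emit(text) => print!("{text}"),
            Step::Hold => {}
            Step::Halt => break,
        }
    }
    // Prints "Hello there. " and suppresses the matched "User:".
    println!();
}
```

Note that, like the PR's callback, this only matches when token boundaries line up with the stop sequence (a single token containing the stop sequence plus trailing text would slip through).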
46 changes: 37 additions & 9 deletions binaries/llm-cli/src/main.rs
@@ -256,6 +256,8 @@ fn interactive(
let parameters = generate.inference_parameters(model.eot_token_id());
let mut rng = generate.rng();

let mut buf = String::new();

fn session_ends_with_newline(session: &InferenceSession) -> bool {
session
.decoded_tokens()
@@ -264,6 +266,33 @@ fn interactive(
.unwrap_or(false)
}

fn inference_callback(
stop_sequence: String,
buf: &mut String,
) -> impl FnMut(InferenceResponse) -> Result<InferenceFeedback, Infallible> + '_ {
move |resp| match resp {
InferenceResponse::InferredToken(t) => {
// Tentatively append the new token to whatever is already held back.
let mut candidate = buf.clone();
candidate.push_str(t.as_str());
if stop_sequence.as_str().eq(candidate.as_str()) {
// Full match: swallow the stop sequence and halt generation.
buf.clear();
return Ok(InferenceFeedback::Halt);
} else if stop_sequence.as_str().starts_with(candidate.as_str()) {
// Partial match: hold this token back until the stop sequence
// either completes or diverges.
buf.push_str(t.as_str());
return Ok(InferenceFeedback::Continue);
}

// No match: flush any held-back text together with this token,
// and clear the buffer so stale text cannot be printed again.
if buf.is_empty() {
print_token(t)
} else {
buf.clear();
print_token(candidate)
}
}
InferenceResponse::EotToken => Ok(InferenceFeedback::Halt),
_ => Ok(InferenceFeedback::Continue),
}
}

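A note on the signature: `inference_callback` returns `impl FnMut(...) + '_`, so the closure mutably borrows `buf` for as long as it lives, which is why `buf` is declared up front in `interactive`. A stand-alone sketch of that borrowing pattern (the `make_counter` name is hypothetical, not from the PR):

```rust
// The returned closure captures `count` by mutable reference, so its
// lifetime is tied to the borrow (hence the `+ '_` in the return type).
fn make_counter(count: &mut usize) -> impl FnMut(&str) + '_ {
    move |token| {
        *count += 1;
        println!("token #{count}: {token}");
    }
}

fn main() {
    let mut count = 0;
    {
        let mut on_token = make_counter(&mut count);
        on_token("Hello");
        on_token("world");
    } // closure dropped here, releasing the mutable borrow
    assert_eq!(count, 2);
}
```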
let mut infer = |session: &mut InferenceSession, mut prompt: String| {
// Add a newline to the beginning of the prompt if the last character in the session is not a newline
if !session_ends_with_newline(session) {
@@ -293,15 +322,7 @@ fn interactive(
maximum_token_count: generate.num_predict,
},
&mut Default::default(),
|r| match r {
InferenceResponse::PromptToken(t) | InferenceResponse::InferredToken(t) => {
print!("{t}");
std::io::stdout().flush().unwrap();

Ok(InferenceFeedback::Continue)
}
_ => Ok(InferenceFeedback::Continue),
},
inference_callback(String::from("User:"), &mut buf),
)
};

@@ -448,3 +469,10 @@ impl Validator for LineContinuationValidator {
fn process_prompt(raw_prompt: &str, prompt: &str) -> String {
raw_prompt.replace("{{PROMPT}}", prompt)
}

/// Print a token to stdout and flush so it appears immediately.
fn print_token(t: String) -> Result<llm::InferenceFeedback, Infallible> {
print!("{t}");
std::io::stdout().flush().unwrap();

Ok(llm::InferenceFeedback::Continue)
}
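The explicit flush matters because `print!` does not end the line and stdout is line-buffered, so tokens would otherwise sit in the buffer until the next newline. A quick way to see the effect (the sleep is only to make the streaming visible):

```rust
use std::io::Write;
use std::{thread, time::Duration};

fn main() {
    for token in ["Hel", "lo, ", "wor", "ld!"] {
        print!("{token}");
        // Without this flush, nothing appears until the final newline.
        std::io::stdout().flush().unwrap();
        thread::sleep(Duration::from_millis(200));
    }
    println!();
}
```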