This repository has been archived by the owner on Jun 24, 2024. It is now read-only.

allow chat to halt new token generation on stop_sequence #364

Merged · 13 commits · Jul 12, 2023
55 changes: 46 additions & 9 deletions binaries/llm-cli/src/main.rs
@@ -256,6 +256,12 @@ fn interactive(
let parameters = generate.inference_parameters(model.eot_token_id());
let mut rng = generate.rng();

let stop_sequence = message_prompt_template
@averypelle (Contributor, Author) commented on Jul 9, 2023:

I am assuming the message_prompt_template is something like:

User: {{PROMPT}}

But I'm very open to suggestions here on how to make this more robust.

.map(|s| s.replace("{{PROMPT}}", "").trim().to_owned())
.unwrap_or_default();
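
For context, a minimal standalone sketch of what the derivation above yields, assuming a chat template of the form "User: {{PROMPT}}" as described in the comment. The derive_stop_sequence helper is hypothetical; the PR inlines this expression.

fn derive_stop_sequence(message_prompt_template: Option<&str>) -> String {
    message_prompt_template
        .map(|s| s.replace("{{PROMPT}}", "").trim().to_owned())
        .unwrap_or_default()
}

fn main() {
    // "User: {{PROMPT}}" becomes "User:"; generation should halt once the
    // model starts emitting the next "User:" turn marker.
    assert_eq!(derive_stop_sequence(Some("User: {{PROMPT}}")), "User:");
    // No template configured: empty stop sequence, which never matches.
    assert_eq!(derive_stop_sequence(None), "");
}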

let mut buf = String::new();

fn session_ends_with_newline(session: &InferenceSession) -> bool {
session
.decoded_tokens()
@@ -293,15 +299,7 @@ fn interactive(
maximum_token_count: generate.num_predict,
},
&mut Default::default(),
|r| match r {
InferenceResponse::PromptToken(t) | InferenceResponse::InferredToken(t) => {
print!("{t}");
std::io::stdout().flush().unwrap();

Ok(InferenceFeedback::Continue)
}
_ => Ok(InferenceFeedback::Continue),
},
inference_callback(stop_sequence.clone(), chat_mode, &mut buf),
)
};

@@ -448,3 +446,42 @@ impl Validator for LineContinuationValidator {
fn process_prompt(raw_prompt: &str, prompt: &str) -> String {
raw_prompt.replace("{{PROMPT}}", prompt)
}

fn inference_callback(
stop_sequence: String,
chat_mode: bool,
buf: &mut String,
) -> impl FnMut(InferenceResponse) -> Result<InferenceFeedback, Infallible> + '_ {
move |resp| match resp {
InferenceResponse::InferredToken(t) => {
if chat_mode {
@averypelle (Contributor, Author) commented:

I didn't touch REPL mode - is the desired behavior to continue generating tokens in that mode? Otherwise happy to change.

let mut reverse_buf = buf.clone();
reverse_buf.push_str(t.as_str());
if stop_sequence.as_str().eq(reverse_buf.as_str()) {
buf.clear();
return Ok(InferenceFeedback::Halt);
} else if stop_sequence.as_str().starts_with(reverse_buf.as_str()) {
buf.push_str(t.as_str());
return Ok(InferenceFeedback::Continue);
}

if buf.is_empty() {
print_token(t)
} else {
print_token(reverse_buf)
}
} else {
print_token(t)
}
}
InferenceResponse::EotToken => Ok(InferenceFeedback::Halt),
_ => Ok(InferenceFeedback::Continue),
}
}

fn print_token(t: String) -> Result<llm::InferenceFeedback, Infallible> {
print!("{t}");
std::io::stdout().flush().unwrap();

Ok(llm::InferenceFeedback::Continue)
}
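
To make the buffering behavior easier to follow, here is a small self-contained sketch that mirrors the prefix-matching logic of inference_callback outside the CLI. The filter_token helper and the sample token stream are hypothetical and only illustrate the idea: tokens that could still grow into the stop sequence are held back, and everything else is flushed.

fn filter_token(stop_sequence: &str, buf: &mut String, t: &str) -> Option<String> {
    // Mirror of the chat_mode branch: the candidate text is whatever has
    // been held back so far plus the incoming token.
    let mut reverse_buf = buf.clone();
    reverse_buf.push_str(t);
    if stop_sequence == reverse_buf.as_str() {
        // Exactly the stop sequence: suppress it and signal a halt.
        buf.clear();
        None
    } else if stop_sequence.starts_with(reverse_buf.as_str()) {
        // Still a prefix of the stop sequence: hold the token back.
        buf.push_str(t);
        Some(String::new())
    } else if buf.is_empty() {
        // Nothing held back: emit the token as-is.
        Some(t.to_owned())
    } else {
        // The partial match broke: emit what was held back plus this token.
        Some(reverse_buf)
    }
}

fn main() {
    let mut buf = String::new();
    // "Use" and "User" are prefixes of "User:", so they are buffered instead
    // of printed; the final ":" completes the stop sequence and halts.
    for t in ["Hello", " there.", "Use", "r", ":"] {
        match filter_token("User:", &mut buf, t) {
            Some(out) => print!("{out}"), // prints "Hello there."
            None => break, // where the callback returns InferenceFeedback::Halt
        }
    }
    println!();
}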