This repository has been archived by the owner on Jun 24, 2024. It is now read-only.

allow chat to halt new token generation on stop_sequence #364

Merged
merged 13 commits into rustformers:main from fix/chat-halting on Jul 12, 2023

Conversation

@averypelle (Contributor) commented Jul 9, 2023

Closes #363

Stop token generation after reaching a specified stop_sequence in chat mode

I am still new to Rust, so please let me know how I can improve my code!

@averypelle marked this pull request as draft July 9, 2023 20:06
@@ -264,6 +266,33 @@ fn interactive(
.unwrap_or(false)
}

fn inference_callback(

@@ -256,6 +256,12 @@ fn interactive(
let parameters = generate.inference_parameters(model.eot_token_id());
let mut rng = generate.rng();

let stop_sequence = message_prompt_template
@averypelle (Contributor, Author) Jul 9, 2023

I am assuming the message_prompt_template is something like:

User: {{PROMPT}}

But very open to suggestions here on how to make this more robust
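
(For reference, a minimal sketch of how a stop sequence could be derived from a template like the one above, by taking whatever precedes the {{PROMPT}} placeholder. The helper name is hypothetical and this is not code from the PR.)

```rust
/// Sketch only: derive a stop sequence from a template such as "User: {{PROMPT}}"
/// by taking everything before the {{PROMPT}} placeholder.
fn stop_sequence_from_template(message_prompt_template: &str) -> Option<String> {
    message_prompt_template
        .split_once("{{PROMPT}}")
        .map(|(prefix, _)| prefix.trim().to_string())
        .filter(|prefix| !prefix.is_empty())
}

fn main() {
    assert_eq!(
        stop_sequence_from_template("User: {{PROMPT}}").as_deref(),
        Some("User:")
    );
    // Without a placeholder (or with nothing before it) there is no sensible stop sequence.
    assert_eq!(stop_sequence_from_template("{{PROMPT}}"), None);
}
```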

@averypelle marked this pull request as ready for review July 9, 2023 20:25
) -> impl FnMut(InferenceResponse) -> Result<InferenceFeedback, Infallible> + '_ {
move |resp| match resp {
InferenceResponse::InferredToken(t) => {
if chat_mode {
@averypelle (Contributor, Author)

I didn't touch REPL mode - is the desired behavior to continue generating tokens in that mode? Otherwise happy to change.

@averypelle mentioned this pull request Jul 9, 2023
@averypelle changed the title from "allow chat to halt new token generation" to "allow chat to halt new token generation on stop_sequence" Jul 9, 2023
@philpax (Collaborator) commented Jul 9, 2023

Great work! I was actually thinking about bringing that logic out so that I could use it for llmcord.

Do you think you'd be able to move the inference_callback to llm-base, name it something like conversation_inference_callback, and update both llm-cli and the vicuna-chat example to use it? (You might need to parameterise over print_token)

That would allow for this logic to be used across both, as well as elsewhere (the aforementioned llmcord).

I'd suggest passing the stop sequence in from the CLI (i.e. maybe replace message_prompt with message_prompt_prefix and use that as the stop sequence.)

I didn't touch REPL mode - is the desired behavior to continue generating tokens in that mode? Otherwise happy to change.

Yup, that's for a back and forth where no state is preserved and the model can produce as much output as it wants. I'd suggest splitting the code paths so that they use entirely different inference callbacks - they share the readline, but their inference behaviour is pretty different.

Great work once again, let me know if you need a hand with any of this! 🙂

@philpax added the issue:enhancement (New feature or request) and app:cli (App: the `llm` CLI) labels Jul 9, 2023
@averypelle (Contributor, Author) commented

Thanks @philpax! Just pushed an update where I moved the function to llm-base. For the message_prompt, is there any case where someone would want a template that includes a postfix? If so, maybe a new option is needed? Otherwise, I can rename to prefix - in this case, would it still make sense for the prefix to include the string {{PROMPT}}?

@philpax (Collaborator) commented Jul 11, 2023

Great work! Looking forward to merging this soon 🚀

For the message_prompt, is there any case where someone would want a template that includes a postfix? If so, maybe a new option is needed? Otherwise, I can rename to prefix - in this case, would it still make sense for the prefix to include the string {{PROMPT}}?

I don't think the logic would work if it was postfix anyway - we should make it clear that you need to pass in a prefix. I'd say you can leave out the {{PROMPT}} in that case, because it should always be implied that the prompt will be suffixed.

@averypelle (Contributor, Author) commented Jul 12, 2023

Okay @philpax I have updated the CLI to take a message_prompt_prefix instead. I also tried running locally with several models and it is working as expected for a multi-prompt chat.

@philpax (Collaborator) commented Jul 12, 2023

Heya! ...apologies for hijacking the PR. I went to test it and all of your changes worked as expected, but I realised that there were quite a few latent bugs with the stuff not covered by your PR and that the whole chat/REPL logic just wasn't working how I wanted it to work. I ended up revising way more than I intended 😅

The upshot is that it should now work consistently, and there shouldn't be any surprise discrepancies between REPL and chat mode. Sorry once again for the complete hijack 😭


Feel free to ask about any of the changes I made! Most of them were unrelated to the code you introduced (I mostly addressed issues that were already present before your changes), but I'm happy to explain them nonetheless. You might be interested in the simplify message_prompt_prefix commit, in which I replaced the if-lets with a match.

@philpax merged commit fc1c052 into rustformers:main Jul 12, 2023
14 checks passed
eyre::bail!(
"Must specify either --message-prompt-prefix or --message-prompt-prefix-file"
)
(None, Some(message_prompt_prefix_file)) => {
@averypelle (Contributor, Author)

wow very cool!
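
(For reference, a rough sketch of the shape that match over the two prefix options can take, as described in the "simplify message_prompt_prefix" commit above. The helper name, arm order, and file-reading step are assumptions for illustration, not the merged CLI code.)

```rust
use eyre::Result;
use std::path::PathBuf;

/// Hypothetical helper showing the (prefix, prefix-file) match; not the exact CLI code.
fn resolve_message_prompt_prefix(
    message_prompt_prefix: Option<String>,
    message_prompt_prefix_file: Option<PathBuf>,
) -> Result<String> {
    match (message_prompt_prefix, message_prompt_prefix_file) {
        (Some(_), Some(_)) => eyre::bail!(
            "Cannot specify both --message-prompt-prefix and --message-prompt-prefix-file"
        ),
        (Some(prefix), None) => Ok(prefix),
        (None, Some(path)) => {
            // Read the prefix from a file instead of taking it from the flag directly.
            Ok(std::fs::read_to_string(path)?.trim_end().to_string())
        }
        (None, None) => eyre::bail!(
            "Must specify either --message-prompt-prefix or --message-prompt-prefix-file"
        ),
    }
}

fn main() -> Result<()> {
    let prefix = resolve_message_prompt_prefix(Some("User:".to_string()), None)?;
    println!("message prompt prefix / stop sequence: {prefix}");
    Ok(())
}
```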

) -> impl FnMut(InferenceResponse) -> Result<InferenceFeedback, E> + 'a {
let mut stop_sequence_buf = String::new();
@averypelle (Contributor, Author)

interesting, scoping this buffer to the function seems a lot better!
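
(For reference, a self-contained sketch of how a function-scoped buffer can implement the stop-sequence check: buffer tokens while they could still be the start of the stop sequence, halt once it fully appears, otherwise flush the buffered text through print_token. The enum definitions are simplified stand-ins and the body is illustrative, not the exact llm-base implementation.)

```rust
// Stand-in definitions so the sketch compiles on its own; the real types live in llm-base.
enum InferenceResponse {
    InferredToken(String),
    EotToken,
}

enum InferenceFeedback {
    Continue,
    Halt,
}

/// Sketch: buffer tokens while they could still be the start of `stop_sequence`;
/// halt once the full stop sequence has been generated, otherwise flush the
/// buffered text through `print_token`.
fn conversation_inference_callback<'a, E: std::error::Error + Send + Sync + 'static>(
    stop_sequence: String,
    mut print_token: impl FnMut(String) + 'a,
) -> impl FnMut(InferenceResponse) -> Result<InferenceFeedback, E> + 'a {
    let mut stop_sequence_buf = String::new();
    move |resp| match resp {
        InferenceResponse::InferredToken(token) => {
            let mut buf = stop_sequence_buf.clone();
            buf.push_str(&token);

            if buf.starts_with(stop_sequence.as_str()) {
                // The stop sequence has been generated: halt inference.
                stop_sequence_buf.clear();
                Ok(InferenceFeedback::Halt)
            } else if stop_sequence.starts_with(buf.as_str()) {
                // Could still turn into the stop sequence: keep buffering, print nothing yet.
                stop_sequence_buf = buf;
                Ok(InferenceFeedback::Continue)
            } else {
                // Definitely not the stop sequence: flush the buffer and keep going.
                stop_sequence_buf.clear();
                print_token(buf);
                Ok(InferenceFeedback::Continue)
            }
        }
        InferenceResponse::EotToken => Ok(InferenceFeedback::Halt),
    }
}

fn main() {
    let mut callback =
        conversation_inference_callback("User:".to_string(), |t| print!("{t}"));

    // "Hello " is printed; the generated "User:" triggers a halt instead of being printed.
    // Annotating the error type pins the callback's `E` parameter for this example.
    let first: Result<InferenceFeedback, std::convert::Infallible> =
        callback(InferenceResponse::InferredToken("Hello ".to_string()));
    assert!(matches!(first, Ok(InferenceFeedback::Continue)));

    let second = callback(InferenceResponse::InferredToken("User:".to_string()));
    assert!(matches!(second, Ok(InferenceFeedback::Halt)));
    println!();
}
```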

@@ -897,8 +897,8 @@ pub fn feed_prompt_callback<'a, E: std::error::Error + 'static>(

/// An [InferenceResponse] callback that will halt inference when a `stop_sequence` is generated.
/// This callback is used in [InferenceSession::infer] in chat_mode.
pub fn conversation_inference_callback<'a, E: std::error::Error + 'static>(
stop_sequence: String,
pub fn conversation_inference_callback<'a, E: std::error::Error + Send + Sync + 'static>(
@averypelle (Contributor, Author)

what do these do?

@philpax (Collaborator)

In Rust, a type gets the Send trait if it can be sent across threads, and Sync if it can be shared between multiple threads (see the std::marker::Send and std::marker::Sync documentation for details).

I needed to add this because eyre, which we use for error reporting in the CLI, expects the error from infer to be Send + Sync. The error is passed down from callback to infer, so the trait requirements need to be updated across the library.
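
(For reference, a minimal illustration of the bound in question: converting an error into an eyre::Report, e.g. via `?`, requires `std::error::Error + Send + Sync + 'static`, which is why the requirement propagates to the callback's error type. The error type below is made up for the example.)

```rust
use std::fmt;

// Hypothetical error type standing in for whatever the inference callback can return.
#[derive(Debug)]
struct CallbackError;

impl fmt::Display for CallbackError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "callback failed")
    }
}

impl std::error::Error for CallbackError {}

fn fallible_callback_result() -> Result<(), CallbackError> {
    Err(CallbackError)
}

fn main() -> eyre::Result<()> {
    // This `?` only compiles because CallbackError is automatically Send + Sync;
    // the same bound therefore has to appear on the generic error type E in the library.
    fallible_callback_result()?;
    Ok(())
}
```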

Assistant: How may I help you?
@averypelle (Contributor, Author)

I have some prompts I made for Falcon and MPT too, since I was testing with those. Want me to add them in a follow-up PR?

@philpax (Collaborator)

Sure thing!

)
}

fn interactive(
@averypelle (Contributor, Author)

ah love that you split these out into separate functions!

@averypelle deleted the fix/chat-halting branch July 13, 2023 17:57
@hhamud mentioned this pull request Aug 7, 2023
Labels
app:cli (App: the `llm` CLI), issue:enhancement (New feature or request)
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Chat does not halt
3 participants