Merge pull request #105 from EricLBuehler/non_chat_models
Support no chat template
EricLBuehler authored Apr 11, 2024
2 parents 370463c + 84f4792 commit 6667ab8
Showing 6 changed files with 110 additions and 12 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -75,7 +75,7 @@ To use a derivative model, select the model architecture using the correct subcommand

See [this](#adapter-ordering-file) section to determine whether it is necessary to prepare an X-LoRA/LoRA ordering file; it is always necessary if the target modules or architecture have changed, or if the adapter order has changed.

It is also important to check the chat template style of the model. If the HF hub repo has a `tokenizer_config.json` file, it is not necessary to specify a chat template. Otherwise, templates can be found in `chat_templates` and should be passed before the subcommand.
It is also important to check the chat template style of the model. If the HF hub repo has a `tokenizer_config.json` file, it is not necessary to specify a chat template. Otherwise, templates can be found in `chat_templates` and should be passed before the subcommand. If the model is not instruction tuned, no chat template will be found and the APIs will only accept a prompt, not messages.

For example, when using a Zephyr model:

57 changes: 57 additions & 0 deletions examples/server/prompt.py
@@ -0,0 +1,57 @@
import openai
import httpx
import textwrap, json


def log_response(response: httpx.Response):
request = response.request
print(f"Request: {request.method} {request.url}")
print(" Headers:")
for key, value in request.headers.items():
if key.lower() == "authorization":
value = "[...]"
if key.lower() == "cookie":
value = value.split("=")[0] + "=..."
print(f" {key}: {value}")
print(" Body:")
try:
request_body = json.loads(request.content)
print(textwrap.indent(json.dumps(request_body, indent=2), " "))
except json.JSONDecodeError:
print(textwrap.indent(request.content.decode(), " "))
print(f"Response: status_code={response.status_code}")
print(" Headers:")
for key, value in response.headers.items():
if key.lower() == "set-cookie":
value = value.split("=")[0] + "=..."
print(f" {key}: {value}")


openai.api_key = "EMPTY"
openai.base_url = "http://localhost:1234/v1/"

# Enable this to log requests and responses
# openai.http_client = httpx.Client(
# event_hooks={"request": [print], "response": [log_response]}
# )

eos_toks = ["</s>", "<eos>", "<|endoftext|>"]

while True:
prompt = input(">>> ")
    # The model has no chat template, so a bare prompt string is passed where a
    # messages list would normally go.
    completion = openai.chat.completions.create(
model="mistral",
messages=prompt,
max_tokens=256,
frequency_penalty=1.0,
top_p=0.1,
temperature=0,
)
resp = completion.choices[0].message.content
    # Strip a trailing EOS token if one is present; if none matches, the reply
    # was likely cut off by max_tokens, so mark it with "...".
    for eos in eos_toks:
if resp.endswith(eos):
out = resp[: -len(eos)]
print(out)
break
else:
print(resp + "...")
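
Note that `messages=prompt` above passes a bare string: the openai client serializes it directly into the `messages` field of the request body, which is what a model without a chat template expects. As a minimal sketch, the equivalent request made directly with `httpx` (assuming the server is listening on `localhost:1234` as in the example) looks like this:

```python
import httpx

# Equivalent of the openai call above, with the request body written out
# explicitly. For a model without a chat template, the string in "messages"
# is treated as the raw prompt.
body = {
    "model": "mistral",
    "messages": "The meaning of life is",
    "max_tokens": 256,
    "temperature": 0,
}
resp = httpx.post("http://localhost:1234/v1/chat/completions", json=body, timeout=None)
resp.raise_for_status()
print(resp.json()["choices"][0]["message"]["content"])
```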
23 changes: 17 additions & 6 deletions integrations/llama_index_integration.py
@@ -53,6 +53,7 @@
DEFAULT_REPEAT_LAST_N = 64
DEFAULT_MAX_SEQS = 10


class MistralRS(CustomLLM):
r"""MistralRS LLM.
@@ -188,7 +189,7 @@ def __init__(
if len(splits) == 1:
model = splits[0]
elif len(splits) == 2 and is_xlora:
model = splits[1]
model = splits[1]
elif len(splits) == 2 and not is_xlora and (is_ggml or is_gguf):
model = splits[0]
elif len(splits) == 2 and is_xlora and (is_ggml or is_gguf):
@@ -215,7 +216,9 @@ def __init__(
model_id,
no_kv_cache=model_kwargs.get("no_kv_cache", False),
use_flash_attn=True, # will be disabled by &
repeat_last_n=model_kwargs.get("repeat_last_n", DEFAULT_REPEAT_LAST_N),
repeat_last_n=model_kwargs.get(
"repeat_last_n", DEFAULT_REPEAT_LAST_N
),
gqa=model_kwargs.get("gqa", None),
chat_template=model_kwargs.get("chat_template", None),
tokenizer_json=model_kwargs.get("tokenizer_json", None),
@@ -227,7 +230,9 @@ def __init__(
is_gguf=True,
no_kv_cache=model_kwargs.get("no_kv_cache", False),
use_flash_attn=True, # will be disabled by &
repeat_last_n=model_kwargs.get("repeat_last_n", DEFAULT_REPEAT_LAST_N),
repeat_last_n=model_kwargs.get(
"repeat_last_n", DEFAULT_REPEAT_LAST_N
),
gqa=model_kwargs.get("gqa", None),
quantized_model_id=quantized_model_id,
quantized_filename=quantized_filename,
@@ -241,7 +246,9 @@ def __init__(
is_gguf=False,
no_kv_cache=model_kwargs.get("no_kv_cache", False),
use_flash_attn=True, # will be disabled by &
repeat_last_n=model_kwargs.get("repeat_last_n", DEFAULT_REPEAT_LAST_N),
repeat_last_n=model_kwargs.get(
"repeat_last_n", DEFAULT_REPEAT_LAST_N
),
gqa=model_kwargs.get("gqa", None),
quantized_model_id=quantized_model_id,
quantized_filename=quantized_filename,
@@ -254,7 +261,9 @@ def __init__(
model_id,
no_kv_cache=model_kwargs.get("no_kv_cache", False),
use_flash_attn=True, # will be disabled by &
repeat_last_n=model_kwargs.get("repeat_last_n", DEFAULT_REPEAT_LAST_N),
repeat_last_n=model_kwargs.get(
"repeat_last_n", DEFAULT_REPEAT_LAST_N
),
gqa=model_kwargs.get("gqa", None),
order_file=xlora_order_file,
xlora_model_id=xlora_model_id,
@@ -299,7 +308,9 @@ def __init__(
)

self._runner = loader.load(
token_source=model_kwargs.get("token_source", {"source": "cache"})["source"], # default source is "cache"
token_source=model_kwargs.get("token_source", {"source": "cache"})[
"source"
], # default source is "cache"
max_seqs=model_kwargs.get("max_seqs", DEFAULT_MAX_SEQS),
logfile=None,
revision=model_kwargs.get("revision", None),
23 changes: 23 additions & 0 deletions mistralrs-core/src/engine/mod.rs
@@ -361,6 +361,19 @@ impl Engine {
}

fn add_request(&mut self, request: Request) {
if request.messages.is_left()
&& !get_mut_arcmutex!(self.pipeline)
.get_chat_template()
.has_chat_template()
{
// NOTE(EricLBuehler): Unwrap reasoning: The receiver should really be there, otherwise it is their fault.
request
.response
.send(Response::ValidationError(
"Received messages for a model which does not have a chat template. Either use a different model or pass a single string as the prompt".into(),
)).unwrap();
return;
}
let formatted_prompt = match request.messages {
Either::Left(messages) => {
handle_seq_error!(
@@ -370,6 +383,16 @@
}
Either::Right(prompt) => prompt,
};
if formatted_prompt.is_empty() {
// NOTE(EricLBuehler): Unwrap reasoning: The receiver should really be there, otherwise it is their fault.
request
.response
.send(Response::ValidationError(
"Received an empty prompt.".into(),
))
.unwrap();
return;
}
let mut prompt = handle_seq_error!(
get_mut_arcmutex!(self.pipeline).tokenize_prompt(&formatted_prompt),
request.response
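
The two new checks reject a request before tokenization: a messages list sent to a model that has no chat template, and an empty prompt. A hedged client-side sketch for exercising both against the example server (how the `ValidationError` is mapped to an HTTP status and body is handled outside this diff, so the sketch just prints whatever comes back):

```python
import httpx

# Assumed address of the mistralrs server from examples/server/prompt.py.
URL = "http://localhost:1234/v1/chat/completions"

# 1. A messages list sent to a model without a chat template should trigger
#    the first validation error added in this commit.
r = httpx.post(URL, json={
    "model": "mistral",
    "messages": [{"role": "user", "content": "Hello"}],
    "max_tokens": 16,
})
print(r.status_code, r.text)

# 2. An empty prompt should trigger the second validation error.
r = httpx.post(URL, json={"model": "mistral", "messages": "", "max_tokens": 16})
print(r.status_code, r.text)
```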
11 changes: 8 additions & 3 deletions mistralrs-core/src/pipeline/mod.rs
@@ -84,6 +84,12 @@ pub struct ChatTemplate {
use_default_system_prompt: Option<bool>,
}

impl ChatTemplate {
pub fn has_chat_template(&self) -> bool {
self.chat_template.is_some()
}
}

#[derive(Debug, Clone)]
pub enum TokenSource {
Literal(String),
@@ -701,12 +707,11 @@ macro_rules! deserialize_chat_template {
}
},
None => {
info!("No specified chat template, loading default chat template at `./default.json`.");
info!("No specified chat template. No chat template will be used. Only prompts will be accepted, not messages.");
deser.insert(
"chat_template".to_string(),
Value::String(fs::read_to_string("./default.json")?),
Value::Null,
);
info!("Default chat template loaded.");
}
};
let ser = serde_json::to_string_pretty(&deser).unwrap();
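
`has_chat_template` simply reports whether the optional `chat_template` field was present after deserialization; on the Hugging Face side that field comes from `tokenizer_config.json`. As a small sketch (using `huggingface_hub`, which is not part of this PR), a client can check up front whether a repo ships a chat template and therefore whether to send messages or a plain prompt:

```python
import json
from huggingface_hub import hf_hub_download

def repo_has_chat_template(repo_id: str) -> bool:
    """Best-effort mirror of ChatTemplate::has_chat_template for a hub repo."""
    try:
        path = hf_hub_download(repo_id=repo_id, filename="tokenizer_config.json")
    except Exception:
        # No tokenizer_config.json at all -> no chat template.
        return False
    with open(path) as f:
        cfg = json.load(f)
    return cfg.get("chat_template") is not None

# Instruction-tuned repos typically carry a template; base models usually do not.
print(repo_has_chat_template("mistralai/Mistral-7B-Instruct-v0.1"))  # expected: True
print(repo_has_chat_template("gpt2"))                                # expected: False
```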
6 changes: 4 additions & 2 deletions mistralrs-core/src/pipeline/phi2.rs
@@ -340,8 +340,10 @@ impl Loader for Phi2Loader {
.map_err(|e| TokenizerError::Error(e.to_string()))?;

let mut chat_template: ChatTemplate = deserialize_chat_template!(paths, self);
chat_template.chat_template = Some("{% for message in messages %}{% if message['role'] == 'system' %}{raise_exception('System prompt not supported')}{% endif %}{% if message['role'] == 'user' %}{{ 'Instruct: '+message['content'] + '\n' }}{% endif %}{% if message['role'] == 'assistant' %}{{ 'Output: '+message['content'] + '\n' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ 'Output:' }}{% endif %}".to_string());
warn!("The chat template for Phi 2 is being used as: `{:?}`. If this is not desired behavior please raise an issue.", &chat_template.chat_template);
chat_template.chat_template = chat_template.chat_template.map(|_| "{% for message in messages %}{% if message['role'] == 'system' %}{raise_exception('System prompt not supported')}{% endif %}{% if message['role'] == 'user' %}{{ 'Instruct: '+message['content'] + '\n' }}{% endif %}{% if message['role'] == 'assistant' %}{{ 'Output: '+message['content'] + '\n' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ 'Output:' }}{% endif %}".to_string());
if let Some(c) = chat_template.chat_template.as_ref() {
warn!("The chat template for Phi 2 is being used as: `{:?}`. If this is not desired behavior please raise an issue.", &c);
}

Ok(Box::new(Mutex::new(Phi2Pipeline {
model,
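
Since the hard-coded Phi 2 template above is plain Jinja, its effect can be previewed offline. A sketch rendering a transcription of that template with `jinja2` for a short conversation (the real rendering happens inside mistral.rs; this only illustrates the resulting `Instruct:`/`Output:` format):

```python
from jinja2 import Template

# Transcription of the Phi 2 chat template set in the diff above.
PHI2_TEMPLATE = (
    "{% for message in messages %}"
    "{% if message['role'] == 'system' %}{raise_exception('System prompt not supported')}{% endif %}"
    "{% if message['role'] == 'user' %}{{ 'Instruct: '+message['content'] + '\n' }}{% endif %}"
    "{% if message['role'] == 'assistant' %}{{ 'Output: '+message['content'] + '\n' }}{% endif %}"
    "{% endfor %}"
    "{% if add_generation_prompt %}{{ 'Output:' }}{% endif %}"
)

messages = [
    {"role": "user", "content": "What is graphene?"},
    {"role": "assistant", "content": "A single layer of carbon atoms."},
    {"role": "user", "content": "What is it used for?"},
]

print(Template(PHI2_TEMPLATE).render(messages=messages, add_generation_prompt=True))
# Instruct: What is graphene?
# Output: A single layer of carbon atoms.
# Instruct: What is it used for?
# Output:
```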
