diff --git a/examples/http.md b/examples/http.md
index de78961b8..2c1e04253 100644
--- a/examples/http.md
+++ b/examples/http.md
@@ -74,6 +74,45 @@ Example with `curl`:
 curl http://localhost:/docs
 ```
 
+## `POST`: `/v1/completions`
+Process an OpenAI-compatible completions request, returning an OpenAI-compatible response when finished. The official OpenAI API documentation can be found [here](https://platform.openai.com/docs/api-reference/completions).
+
+To send a request with the Python `openai` library:
+
+```python
+import openai
+
+client = openai.OpenAI(
+    base_url="http://localhost:8080/v1", # "http://:port"
+    api_key="EMPTY"
+)
+
+completion = client.completions.create(
+    model="mistral",
+    prompt="What is Rust?",
+    max_tokens=256,
+    frequency_penalty=1.0,
+    top_p=0.1,
+    temperature=0,
+)
+
+print(completion.choices[0].text)
+```
+
+Or with `curl`:
+```bash
+curl http://localhost:8080/v1/completions \
+-H "Content-Type: application/json" \
+-H "Authorization: Bearer EMPTY" \
+-d '{
+"model": "",
+"prompt": "What is Rust?",
+"max_tokens": 256
+}'
+```
+
+Streaming requests are not supported.
+
 ## Request
 ### `ChatCompletionRequest`
 OpenAI compatible request.
@@ -134,7 +173,7 @@ pub struct ChatCompletionResponse {
     pub model: &'static str,
     pub system_fingerprint: String,
     pub object: String,
-    pub usage: ChatCompletionUsage,
+    pub usage: Usage,
 }
 ```
 
@@ -186,9 +225,9 @@ pub struct TopLogprob {
 }
 ```
 
-### `ChatCompletionUsage`
+### `Usage`
 ```rust
-pub struct ChatCompletionUsage {
+pub struct Usage {
     pub completion_tokens: usize,
     pub prompt_tokens: usize,
     pub total_tokens: usize,
diff --git a/examples/server/prompt.py b/examples/server/completion.py
similarity index 92%
rename from examples/server/prompt.py
rename to examples/server/completion.py
index 0da715542..33a5530d1 100644
--- a/examples/server/prompt.py
+++ b/examples/server/completion.py
@@ -39,15 +39,15 @@ def log_response(response: httpx.Response):
 
 while True:
     prompt = input(">>> ")
-    completion = openai.chat.completions.create(
+    completion = openai.completions.create(
         model="mistral",
-        messages=prompt,
+        prompt=prompt,
         max_tokens=256,
         frequency_penalty=1.0,
         top_p=0.1,
         temperature=0,
     )
-    resp = completion.choices[0].message.content
+    resp = completion.choices[0].text
     for eos in eos_toks:
         if resp.endswith(eos):
             out = resp[: -len(eos)]
diff --git a/examples/server/streaming.py b/examples/server/streaming.py
index b6a762a2f..5fb0cd838 100644
--- a/examples/server/streaming.py
+++ b/examples/server/streaming.py
@@ -1,9 +1,10 @@
 import openai
+import sys
 
 openai.api_key = "EMPTY"
 openai.base_url = "http://localhost:1234/v1/"
 
-# """
+
 messages = []
 prompt = input("Enter system prompt >>> ")
 if len(prompt) > 0:
@@ -25,6 +26,7 @@
         delta = chunk.choices[0].delta.content
         if delta not in eos_toks:
             print(delta, end="")
+            sys.stdout.flush()
         resp += delta
     for eos in eos_toks:
         if resp.endswith(eos):
diff --git a/integrations/llama_index_integration.py b/integrations/llama_index_integration.py
index e12dd6b98..c3a52969e 100644
--- a/integrations/llama_index_integration.py
+++ b/integrations/llama_index_integration.py
@@ -365,7 +365,7 @@ def chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatResponse:
             top_k=self.generate_kwargs["top_k"],
             top_p=self.generate_kwargs["top_p"],
             presence_penalty=self.generate_kwargs.get("presence_penalty", None),
-            repetition_penalty=self.generate_kwargs.get("repetition_penalty", None),
+            frequency_penalty=self.generate_kwargs.get("frequency_penalty", None),
temperature=self.generate_kwargs.get("temperature", None), ) completion_response = self._runner.send_chat_completion_request(request) @@ -399,7 +399,7 @@ def complete( top_k=self.generate_kwargs["top_k"], top_p=self.generate_kwargs["top_p"], presence_penalty=self.generate_kwargs.get("presence_penalty", None), - repetition_penalty=self.generate_kwargs.get("repetition_penalty", None), + frequency_penalty=self.generate_kwargs.get("frequency_penalty", None), temperature=self.generate_kwargs.get("temperature", None), ) completion_response = self._runner.send_chat_completion_request(request) diff --git a/mistralrs-core/src/engine/mod.rs b/mistralrs-core/src/engine/mod.rs index 383950d5e..24f170881 100644 --- a/mistralrs-core/src/engine/mod.rs +++ b/mistralrs-core/src/engine/mod.rs @@ -7,7 +7,11 @@ use std::{ time::{Instant, SystemTime, UNIX_EPOCH}, }; -use crate::aici::{cfg::CfgParser, recognizer::StackRecognizer, rx::RecRx}; +use crate::{ + aici::{cfg::CfgParser, recognizer::StackRecognizer, rx::RecRx}, + response::CompletionChoice, + CompletionResponse, RequestType, +}; use candle_core::Tensor; use either::Either; use tracing::warn; @@ -149,7 +153,7 @@ impl Engine { seq.add_token(next_token.clone()); let is_done = seq.is_done(next_token_id, eos_tok, pipeline.get_max_seq_len()); // Handle streaming requests - if seq.get_mut_group().is_streaming { + if seq.get_mut_group().is_streaming && seq.get_mut_group().is_chat { let tokenizer = pipeline.tokenizer().clone(); if let Some(delta) = handle_seq_error!(seq.get_delta(&tokenizer), seq.responder()) { seq.add_streaming_chunk_choice_to_group(ChunkChoice { @@ -215,30 +219,55 @@ impl Engine { seq.responder() ); - let choice = Choice { - stopreason: reason.to_string(), - index: seq.get_response_index(), - message: ResponseMessage { - content: res, - role: "assistant".to_string(), - }, - logprobs: logprobs.map(|l| Logprobs { content: Some(l) }), - }; - seq.add_choice_to_group(choice); + if seq.get_mut_group().is_chat { + let choice = Choice { + stopreason: reason.to_string(), + index: seq.get_response_index(), + message: ResponseMessage { + content: res, + role: "assistant".to_string(), + }, + logprobs: logprobs.map(|l| Logprobs { content: Some(l) }), + }; + seq.add_choice_to_group(choice); + } else { + let choice = CompletionChoice { + stopreason: reason.to_string(), + index: seq.get_response_index(), + text: res, + logprobs: None, + }; + seq.add_completion_choice_to_group(choice); + } let group = seq.get_mut_group(); - group.maybe_send_done_response( - ChatCompletionResponse { - id: seq.id().to_string(), - choices: group.get_choices().to_vec(), - created: seq.creation_time(), - model: pipeline.name(), - system_fingerprint: SYSTEM_FINGERPRINT.to_string(), - object: "chat.completion".to_string(), - usage: group.get_usage(), - }, - seq.responder(), - ); + if group.is_chat { + group.maybe_send_done_response( + ChatCompletionResponse { + id: seq.id().to_string(), + choices: group.get_choices().to_vec(), + created: seq.creation_time(), + model: pipeline.name(), + system_fingerprint: SYSTEM_FINGERPRINT.to_string(), + object: "chat.completion".to_string(), + usage: group.get_usage(), + }, + seq.responder(), + ); + } else { + group.maybe_send_completion_done_response( + CompletionResponse { + id: seq.id().to_string(), + choices: group.get_completion_choices().to_vec(), + created: seq.creation_time(), + model: pipeline.name(), + system_fingerprint: SYSTEM_FINGERPRINT.to_string(), + object: "text_completion".to_string(), + usage: group.get_usage(), + }, + 
seq.responder(), + ); + } } /// Clone the cache FROM the sequences' cache TO the model cache. Only used for completion seqs. @@ -480,6 +509,8 @@ impl Engine { let group = Rc::new(RefCell::new(SequenceGroup::new( request.sampling_params.n_choices, request.is_streaming, + request.request_type == RequestType::Chat, + request.best_of, ))); let now = SystemTime::now() .duration_since(UNIX_EPOCH) @@ -525,6 +556,16 @@ impl Engine { response_index, now.as_secs(), recognizer.clone(), + request.suffix.clone(), + if let RequestType::Completion { echo_prompt } = request.request_type.clone() { + if echo_prompt { + Some(formatted_prompt.clone()) + } else { + None + } + } else { + None + }, ); let seq = if let Some(prefill_cache) = prefill_cache.clone() { match prefill_cache { diff --git a/mistralrs-core/src/lib.rs b/mistralrs-core/src/lib.rs index a20d6bd3a..9ba27dce4 100644 --- a/mistralrs-core/src/lib.rs +++ b/mistralrs-core/src/lib.rs @@ -35,9 +35,9 @@ pub use pipeline::{ MistralSpecificConfig, MixtralLoader, MixtralSpecificConfig, ModelKind, Phi2Loader, Phi2SpecificConfig, TokenSource, }; -pub use request::{Constraint, Request}; +pub use request::{Constraint, Request, RequestType}; pub use response::Response; -pub use response::{ChatCompletionResponse, ChatCompletionUsage}; +pub use response::{ChatCompletionResponse, CompletionResponse, Usage}; pub use sampler::{SamplingParams, StopTokens}; pub use scheduler::SchedulerMethod; use serde::Serialize; diff --git a/mistralrs-core/src/request.rs b/mistralrs-core/src/request.rs index 0a9654640..c2aad8c31 100644 --- a/mistralrs-core/src/request.rs +++ b/mistralrs-core/src/request.rs @@ -10,6 +10,12 @@ pub enum Constraint { None, } +#[derive(Debug, PartialEq, Clone)] +pub enum RequestType { + Chat, + Completion { echo_prompt: bool }, +} + pub struct Request { pub messages: Either>, String>, pub sampling_params: SamplingParams, @@ -18,14 +24,17 @@ pub struct Request { pub is_streaming: bool, pub id: usize, pub constraint: Constraint, + pub request_type: RequestType, + pub suffix: Option, + pub best_of: Option, } impl Debug for Request { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!( f, - "Request {} {{ messages: `{:?}`, sampling_params: {:?}}}", - self.id, self.messages, self.sampling_params + "Request {} ({:?}) {{ messages: `{:?}`, sampling_params: {:?}}}", + self.id, self.request_type, self.messages, self.sampling_params ) } } diff --git a/mistralrs-core/src/response.rs b/mistralrs-core/src/response.rs index 07355cb5d..7bf864752 100644 --- a/mistralrs-core/src/response.rs +++ b/mistralrs-core/src/response.rs @@ -50,7 +50,7 @@ pub struct ChunkChoice { } #[derive(Debug, Clone, Serialize)] -pub struct ChatCompletionUsage { +pub struct Usage { pub completion_tokens: usize, pub prompt_tokens: usize, pub total_tokens: usize, @@ -72,7 +72,7 @@ pub struct ChatCompletionResponse { pub model: String, pub system_fingerprint: String, pub object: String, - pub usage: ChatCompletionUsage, + pub usage: Usage, } #[derive(Debug, Clone, Serialize)] @@ -85,10 +85,34 @@ pub struct ChatCompletionChunkResponse { pub object: String, } +#[derive(Debug, Clone, Serialize)] +pub struct CompletionChoice { + #[serde(rename = "finish_reason")] + pub stopreason: String, + pub index: usize, + pub text: String, + pub logprobs: Option<()>, +} + +#[derive(Debug, Clone, Serialize)] +pub struct CompletionResponse { + pub id: String, + pub choices: Vec, + pub created: u64, + pub model: String, + pub system_fingerprint: String, + pub object: String, + pub usage: 
Usage, +} + pub enum Response { InternalError(Box), ValidationError(Box), + // Chat ModelError(String, ChatCompletionResponse), Done(ChatCompletionResponse), Chunk(ChatCompletionChunkResponse), + // Completion + CompletionModelError(String, CompletionResponse), + CompletionDone(CompletionResponse), } diff --git a/mistralrs-core/src/sequence.rs b/mistralrs-core/src/sequence.rs index 34728dec1..606ffcfb3 100644 --- a/mistralrs-core/src/sequence.rs +++ b/mistralrs-core/src/sequence.rs @@ -5,13 +5,17 @@ use std::{ time::{SystemTime, UNIX_EPOCH}, }; -use crate::aici::{cfg::CfgParser, recognizer::StackRecognizer, rx::RecRx}; +use crate::{ + aici::{cfg::CfgParser, recognizer::StackRecognizer, rx::RecRx}, + response::CompletionChoice, + CompletionResponse, +}; use crate::{ get_mut_group, models::LayerCaches, response::{ChatCompletionChunkResponse, Choice, ChunkChoice, Response, SYSTEM_FINGERPRINT}, sampler::{Logprobs, Sampler}, - ChatCompletionResponse, ChatCompletionUsage, + ChatCompletionResponse, Usage, }; use candle_core::Tensor; use regex_automata::util::primitives::StateID; @@ -65,6 +69,8 @@ pub struct Sequence { response_index: usize, creation_time: u64, prefill_prompt_toks: Option>, + pub suffix: Option, + pub prefix: Option, // Cache scaling_cache: Option, @@ -75,6 +81,7 @@ pub struct Sequence { tokens: Vec, decoded_tokens: Option>, logprobs: Vec, + cumulative_logprob: f32, // GPU things pub prompt_tok_per_sec: f32, @@ -104,6 +111,8 @@ impl Sequence { response_index: usize, creation_time: u64, recognizer: SequenceRecognizer, + suffix: Option, + prefix: Option, ) -> Self { let prompt_len = tokens.len(); Self { @@ -135,6 +144,9 @@ impl Sequence { creation_time, recognizer, prefill_prompt_toks: None, + suffix, + prefix, + cumulative_logprob: 0., } } @@ -212,6 +224,7 @@ impl Sequence { } pub fn add_token(&mut self, tok: Logprobs) { + self.cumulative_logprob += tok.logprob; self.tokens.push(tok.token); self.logprobs.push(tok); } @@ -303,9 +316,7 @@ impl Sequence { self.prompt_timestamp } - pub fn add_choice_to_group(&self, choice: Choice) { - get_mut_group!(self).choices.push(choice); - + fn update_time_info(&self) { let now = SystemTime::now() .duration_since(UNIX_EPOCH) .expect("Time travel has occurred!") @@ -324,6 +335,18 @@ impl Sequence { get_mut_group!(self).total_sampling_time += self.total_sampling_time; } + pub fn add_choice_to_group(&self, choice: Choice) { + get_mut_group!(self).choices.push(choice); + self.update_time_info(); + } + + pub fn add_completion_choice_to_group(&self, choice: CompletionChoice) { + get_mut_group!(self) + .completion_choices + .push((self.cumulative_logprob, choice)); + self.update_time_info(); + } + pub fn get_response_index(&self) -> usize { self.response_index } @@ -339,6 +362,7 @@ impl Sequence { pub struct SequenceGroup { n_choices: usize, // The target number of choices to return. Can be decreased if an error is thrown. + best_of: Option, // Top n seqs based on cumulative logprobs. 
pub total_prompt_toks: usize, pub total_toks: usize, pub total_prompt_time: u128, @@ -346,14 +370,22 @@ pub struct SequenceGroup { pub total_completion_time: u128, pub total_sampling_time: u128, choices: Vec, + completion_choices: Vec<(f32, CompletionChoice)>, pub streaming_chunks: Vec, pub is_streaming: bool, + pub is_chat: bool, } impl SequenceGroup { - pub fn new(n_choices: usize, is_streaming: bool) -> Self { + pub fn new( + n_choices: usize, + is_streaming: bool, + is_chat: bool, + best_of: Option, + ) -> Self { Self { choices: Vec::new(), + completion_choices: Vec::new(), n_choices, total_prompt_toks: 0, total_toks: 0, @@ -363,16 +395,31 @@ impl SequenceGroup { total_sampling_time: 0, streaming_chunks: Vec::new(), is_streaming, + is_chat, + best_of, } } + /// This does not apply best_of. pub fn get_choices(&self) -> &[Choice] { &self.choices } - pub fn get_usage(&self) -> ChatCompletionUsage { + /// This applies the best_of. + pub fn get_completion_choices(&self) -> Vec { + let mut choices = self.completion_choices.clone(); + // Sort by descending logprobs + choices.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap()); + choices + .into_iter() + .take(self.best_of.unwrap()) + .map(|(_, x)| x) + .collect::>() + } + + pub fn get_usage(&self) -> Usage { #[allow(clippy::cast_precision_loss)] - ChatCompletionUsage { + Usage { completion_tokens: self.total_toks - self.total_prompt_toks, prompt_tokens: self.total_prompt_toks, total_tokens: self.total_toks, @@ -417,4 +464,15 @@ impl SequenceGroup { self.streaming_chunks.clear(); } } + + pub fn maybe_send_completion_done_response( + &self, + response: CompletionResponse, + sender: Sender, + ) { + if self.completion_choices.len() == self.n_choices { + // NOTE(EricLBuehler): Unwrap reasoning: The receiver should really be there, otherwise it is their fault. + sender.send(Response::CompletionDone(response)).unwrap(); + } + } } diff --git a/mistralrs-core/src/utils/mod.rs b/mistralrs-core/src/utils/mod.rs index a4d982be5..df3d97128 100644 --- a/mistralrs-core/src/utils/mod.rs +++ b/mistralrs-core/src/utils/mod.rs @@ -65,37 +65,67 @@ macro_rules! 
handle_pipeline_forward_error { Ok(v) => v, Err(_) => "".to_string(), }; - let choice = Choice { - stopreason: "error".to_string(), - index: seq.get_response_index(), - message: ResponseMessage { - content: res, - role: "assistant".to_string(), - }, - logprobs: None, - }; - seq.add_choice_to_group(choice); + + if seq.get_mut_group().is_chat { + let choice = Choice { + stopreason: "error".to_string(), + index: seq.get_response_index(), + message: ResponseMessage { + content: res, + role: "assistant".to_string(), + }, + logprobs: None, + }; + seq.add_choice_to_group(choice); + } else { + let choice = CompletionChoice { + stopreason: "error".to_string(), + index: seq.get_response_index(), + text: res, + logprobs: None, + }; + seq.add_completion_choice_to_group(choice); + } } for seq in $seq_slice.iter_mut() { // Step 2: Respond with all groups let group = seq.get_mut_group(); - let partial_completion_response = ChatCompletionResponse { - id: seq.id().to_string(), - choices: group.get_choices().to_vec(), - created: seq.creation_time(), - model: $pipeline.name(), - system_fingerprint: SYSTEM_FINGERPRINT.to_string(), - object: "chat.completion".to_string(), - usage: group.get_usage(), - }; + if group.is_chat { + let partial_completion_response = ChatCompletionResponse { + id: seq.id().to_string(), + choices: group.get_choices().to_vec(), + created: seq.creation_time(), + model: $pipeline.name(), + system_fingerprint: SYSTEM_FINGERPRINT.to_string(), + object: "chat.completion".to_string(), + usage: group.get_usage(), + }; + + seq.responder() + .send(Response::ModelError( + e.to_string(), + partial_completion_response + )) + .unwrap(); + } else { + let partial_completion_response = CompletionResponse { + id: seq.id().to_string(), + choices: group.get_completion_choices().to_vec(), + created: seq.creation_time(), + model: $pipeline.name(), + system_fingerprint: SYSTEM_FINGERPRINT.to_string(), + object: "text_completion".to_string(), + usage: group.get_usage(), + }; - seq.responder() - .send(Response::ModelError( - e.to_string(), - partial_completion_response - )) - .unwrap(); + seq.responder() + .send(Response::CompletionModelError( + e.to_string(), + partial_completion_response + )) + .unwrap(); + } } for seq in $seq_slice.iter_mut() { // Step 3: Set state - This cannot be done in Step 2 as `group` is locking the refcell diff --git a/mistralrs-pyo3/API.md b/mistralrs-pyo3/API.md index 25522aac3..e58e97b02 100644 --- a/mistralrs-pyo3/API.md +++ b/mistralrs-pyo3/API.md @@ -95,13 +95,13 @@ Request is a class with a constructor which accepts the following arguments. 
It - `max_tokens: usize | None` - `n_choices: usize` - `presence_penalty: float | None` -- `repetition_penalty: float | None` +- `frequency_penalty: float | None` - `stop_token_ids: list[int] | None` - `temperature: float | None` - `top_p: float | None` - `top_k: usize | None` -`ChatCompletionRequest(messages, model, logprobs = false, n_choices = 1, logit_bias = None, top_logprobs = None, max_tokens = None, presence_penalty = None, repetition_penalty = None, stop_token_ids = None, temperature = None, top_p = None, top_k = None)` +`ChatCompletionRequest(messages, model, logprobs = false, n_choices = 1, logit_bias = None, top_logprobs = None, max_tokens = None, presence_penalty = None, frequency_penalty = None, stop_token_ids = None, temperature = None, top_p = None, top_k = None)` ## `ModelKind` - Normal diff --git a/mistralrs-pyo3/README.md b/mistralrs-pyo3/README.md index 9a8a722f8..c35b020b1 100644 --- a/mistralrs-pyo3/README.md +++ b/mistralrs-pyo3/README.md @@ -45,7 +45,7 @@ res = runner.send_chat_completion_request( {"role": "user", "content": "Tell me a story about the Rust type system."} ], max_tokens=256, - repetition_penalty=1.0, + frequency_penalty=1.0, top_p=0.1, temperature=0.1, ) diff --git a/mistralrs-pyo3/mistralrs.pyi b/mistralrs-pyo3/mistralrs.pyi index ea26ffe2a..ab37ad7ba 100644 --- a/mistralrs-pyo3/mistralrs.pyi +++ b/mistralrs-pyo3/mistralrs.pyi @@ -16,7 +16,7 @@ class ChatCompletionRequest: max_tokens: int | None = None n_choices: int = 1 presence_penalty: float | None = None - repetition_penalty: float | None = None + frequency_penalty: float | None = None stop_token_ids: list[int] | None = None temperature: float | None = None top_p: float | None = None @@ -25,6 +25,32 @@ class ChatCompletionRequest: grammar: str | None = None grammar_type: str | None = None +@dataclass +class CompletionRequest: + """ + A CompletionRequest represents a request sent to the mistral.rs engine. It encodes information + about input data, sampling, and how to return the response. + """ + + prompt: str + model: str + best_of: int + echo_prompt: bool + logit_bias: dict[int, float] | None = None + max_tokens: int | None = None + n_choices: int = 1 + best_of: int = 1 + presence_penalty: float | None = None + frequency_penalty: float | None = None + stop_token_ids: list[int] | None = None + temperature: float | None = None + top_p: float | None = None + stream: bool = False + top_k: int | None = None + suffix: str | None = None + grammar: str | None = None + grammar_type: str | None = None + class Runner: """ The Runner is a class with no constructor. It is only created via one of the loader classes. 
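Taken together, the `CompletionRequest` stub above and the `Runner.send_completion_request` binding added below are meant to mirror the existing chat flow. A minimal sketch of how that could look from Python, assuming a `runner` has already been created via one of the loader classes (that setup is elided here) and that the call returns the raw response JSON as a string:

```python
import json

import mistralrs  # the pyo3 bindings described above

# Assumption: `runner` was built earlier from one of the loader classes; that code is elided here.
request = mistralrs.CompletionRequest(
    prompt="What is Rust?",
    model="mistral",  # illustrative model name
    best_of=1,
    max_tokens=64,
    frequency_penalty=1.0,
    top_p=0.1,
    temperature=0.1,
)

raw = runner.send_completion_request(request)  # returns the raw response JSON as a string
print(json.loads(raw)["choices"][0]["text"])
```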
diff --git a/mistralrs-pyo3/src/lib.rs b/mistralrs-pyo3/src/lib.rs index d0a0bef91..59d0b1193 100644 --- a/mistralrs-pyo3/src/lib.rs +++ b/mistralrs-pyo3/src/lib.rs @@ -9,7 +9,7 @@ use std::{ }; use ::mistralrs::{ - Constraint, MistralRs, Request as _Request, Response, SamplingParams, StopTokens, + Constraint, MistralRs, Request as _Request, RequestType, Response, SamplingParams, StopTokens, }; use candle_core::Device; use loaders::{ @@ -135,7 +135,7 @@ impl Runner { top_k: request.top_k, top_p: request.top_p, top_n_logprobs: request.top_logprobs.unwrap_or(1), - frequency_penalty: request.repetition_penalty, + frequency_penalty: request.frequency_penalty, presence_penalty: request.presence_penalty, max_len: request.max_tokens, stop_toks, @@ -146,6 +146,9 @@ impl Runner { return_logprobs: request.logprobs, is_streaming: request.stream, constraint, + request_type: RequestType::Chat, + suffix: None, + best_of: None, }; MistralRs::maybe_log_request(self.runner.clone(), format!("{request:?}")); @@ -161,13 +164,164 @@ impl Runner { MistralRs::maybe_log_response(self.runner.clone(), &response); Ok(serde_json::to_string(&response).unwrap()) } + Response::ModelError(msg, _) => Err(PyValueError::new_err(msg.to_string())), Response::Chunk(_) => unreachable!(), + Response::CompletionDone(_) => unreachable!(), + Response::CompletionModelError(_, _) => unreachable!(), + } + }) + } + + /// Send an OpenAI API compatible request, returning raw JSON. + fn send_completion_request(&mut self, request: Py) -> PyResult { + let (tx, rx) = channel(); + Python::with_gil(|py| { + let request = request.bind(py).borrow(); + let stop_toks = request + .stop_token_ids + .as_ref() + .map(|x| StopTokens::Ids(x.to_vec())); + let constraint = if request.grammar_type == Some("regex".to_string()) { + if request.grammar.is_none() { + return Err(PyValueError::new_err( + "Grammar type is specified but not grammar text", + )); + } + Constraint::Regex(request.grammar.as_ref().unwrap().clone()) + } else if request.grammar_type == Some("yacc".to_string()) { + if request.grammar.is_none() { + return Err(PyValueError::new_err( + "Grammar type is specified but not grammar text", + )); + } + Constraint::Yacc(request.grammar.as_ref().unwrap().clone()) + } else if request.grammar_type.is_some() { + return Err(PyValueError::new_err( + "Grammar type is specified but is not `regex` or `yacc`", + )); + } else { + Constraint::None + }; + let model_request = _Request { + id: { + let l = NEXT_REQUEST_ID.lock().unwrap(); + let last = &mut *l.borrow_mut(); + let last_v = *last; + *last += 1; + last_v + }, + messages: Either::Right(request.prompt.clone()), + sampling_params: SamplingParams { + temperature: request.temperature, + top_k: request.top_k, + top_p: request.top_p, + top_n_logprobs: 1, + frequency_penalty: request.frequency_penalty, + presence_penalty: request.presence_penalty, + max_len: request.max_tokens, + stop_toks, + logits_bias: request.logit_bias.clone(), + n_choices: request.n_choices, + }, + response: tx, + return_logprobs: false, + is_streaming: false, + constraint, + request_type: RequestType::Completion { + echo_prompt: request.echo_prompt, + }, + suffix: request.suffix.clone(), + best_of: Some(request.best_of), + }; + + MistralRs::maybe_log_request(self.runner.clone(), format!("{request:?}")); + let sender = self.runner.get_sender(); + sender.send(model_request).unwrap(); + let response = rx.recv().unwrap(); + + match response { + Response::ValidationError(e) | Response::InternalError(e) => { + 
Err(PyValueError::new_err(e.to_string())) + } + Response::Done(response) => { + MistralRs::maybe_log_response(self.runner.clone(), &response); + Ok(serde_json::to_string(&response).unwrap()) + } Response::ModelError(msg, _) => Err(PyValueError::new_err(msg.to_string())), + Response::Chunk(_) => unreachable!(), + Response::CompletionDone(_) => unreachable!(), + Response::CompletionModelError(_, _) => unreachable!(), } }) } } +#[pyclass] +#[derive(Debug)] +/// An OpenAI API compatible completion request. +struct CompletionRequest { + _model: String, + prompt: String, + best_of: usize, + echo_prompt: bool, + presence_penalty: Option, + frequency_penalty: Option, + logit_bias: Option>, + max_tokens: Option, + n_choices: usize, + stop_token_ids: Option>, + temperature: Option, + top_p: Option, + suffix: Option, + top_k: Option, + grammar: Option, + grammar_type: Option, +} + +#[pymethods] +impl CompletionRequest { + #[new] + #[pyo3(signature = (prompt, model, best_of = 1, echo_prompt = false, presence_penalty=None,frequency_penalty=None,logit_bias=None,max_tokens=None,n_choices=1,stop_token_ids=None,temperature=None,top_p=None,suffix=None,top_k=None, grammar = None, grammar_type = None))] + #[allow(clippy::too_many_arguments)] + fn new( + prompt: String, + model: String, + best_of: usize, + echo_prompt: bool, + presence_penalty: Option, + frequency_penalty: Option, + logit_bias: Option>, + max_tokens: Option, + n_choices: usize, + stop_token_ids: Option>, + temperature: Option, + top_p: Option, + suffix: Option, + top_k: Option, + grammar: Option, + grammar_type: Option, + ) -> PyResult { + Ok(Self { + prompt, + best_of, + echo_prompt, + suffix, + _model: model, + logit_bias, + max_tokens, + n_choices, + presence_penalty, + frequency_penalty, + stop_token_ids, + temperature, + top_p, + top_k, + grammar, + grammar_type, + }) + } +} + #[pyclass] #[derive(Debug)] /// An OpenAI API compatible chat completion request. 
@@ -180,7 +334,7 @@ struct ChatCompletionRequest { max_tokens: Option, n_choices: usize, presence_penalty: Option, - repetition_penalty: Option, + frequency_penalty: Option, stop_token_ids: Option>, temperature: Option, top_p: Option, @@ -193,7 +347,7 @@ struct ChatCompletionRequest { #[pymethods] impl ChatCompletionRequest { #[new] - #[pyo3(signature = (messages, model, logprobs = false, n_choices = 1, logit_bias = None, top_logprobs = None, max_tokens = None, presence_penalty = None, repetition_penalty = None, stop_token_ids = None, temperature = None, top_p = None, top_k = None, stream=false, grammar = None, grammar_type = None))] + #[pyo3(signature = (messages, model, logprobs = false, n_choices = 1, logit_bias = None, top_logprobs = None, max_tokens = None, presence_penalty = None, frequency_penalty = None, stop_token_ids = None, temperature = None, top_p = None, top_k = None, stream=false, grammar = None, grammar_type = None))] #[allow(clippy::too_many_arguments)] fn new( messages: Py, @@ -204,7 +358,7 @@ impl ChatCompletionRequest { top_logprobs: Option, max_tokens: Option, presence_penalty: Option, - repetition_penalty: Option, + frequency_penalty: Option, stop_token_ids: Option>, temperature: Option, top_p: Option, @@ -253,7 +407,7 @@ impl ChatCompletionRequest { max_tokens, n_choices, presence_penalty, - repetition_penalty, + frequency_penalty, stop_token_ids, temperature, top_p, diff --git a/mistralrs-server/src/chat_completion.rs b/mistralrs-server/src/chat_completion.rs new file mode 100644 index 000000000..053763e39 --- /dev/null +++ b/mistralrs-server/src/chat_completion.rs @@ -0,0 +1,266 @@ +use std::{ + env, + error::Error, + pin::Pin, + sync::{ + mpsc::{channel, Receiver, Sender}, + Arc, + }, + task::{Context, Poll}, + time::Duration, +}; + +use crate::openai::{ChatCompletionRequest, Grammar, StopTokens}; +use anyhow::Result; +use axum::{ + extract::{Json, State}, + http::{self, StatusCode}, + response::{ + sse::{Event, KeepAlive}, + IntoResponse, Sse, + }, +}; +use either::Either; +use indexmap::IndexMap; +use mistralrs_core::{ + ChatCompletionResponse, Constraint, MistralRs, Request, RequestType, Response, SamplingParams, + StopTokens as InternalStopTokens, +}; +use serde::Serialize; + +#[derive(Debug)] +struct ModelErrorMessage(String); +impl std::fmt::Display for ModelErrorMessage { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.0) + } +} +impl std::error::Error for ModelErrorMessage {} +pub struct Streamer { + rx: Receiver, + is_done: bool, + state: Arc, +} + +impl futures::Stream for Streamer { + type Item = Result; + + fn poll_next(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + if self.is_done { + return Poll::Ready(None); + } + match self.rx.try_recv() { + Ok(resp) => match resp { + Response::ModelError(msg, _) => { + MistralRs::maybe_log_error( + self.state.clone(), + &ModelErrorMessage(msg.to_string()), + ); + Poll::Ready(Some(Ok(Event::default().data(msg)))) + } + Response::ValidationError(e) => { + Poll::Ready(Some(Ok(Event::default().data(e.to_string())))) + } + Response::InternalError(e) => { + MistralRs::maybe_log_error(self.state.clone(), &*e); + Poll::Ready(Some(Ok(Event::default().data(e.to_string())))) + } + Response::Chunk(response) => { + if response.choices.iter().all(|x| x.stopreason.is_some()) { + self.is_done = true; + } + MistralRs::maybe_log_response(self.state.clone(), &response); + Poll::Ready(Some(Event::default().json_data(response))) + } + Response::Done(_) => unreachable!(), + 
Response::CompletionDone(_) => unreachable!(), + Response::CompletionModelError(_, _) => unreachable!(), + }, + Err(_) => Poll::Pending, + } + } +} + +pub enum ChatCompletionResponder { + Sse(Sse), + Json(ChatCompletionResponse), + ModelError(String, ChatCompletionResponse), + InternalError(Box), + ValidationError(Box), +} + +trait ErrorToResponse: Serialize { + fn to_response(&self, code: StatusCode) -> axum::response::Response { + let mut r = Json(self).into_response(); + *r.status_mut() = code; + r + } +} + +#[derive(Serialize)] +struct JsonError { + message: String, +} + +impl JsonError { + fn new(message: String) -> Self { + Self { message } + } +} +impl ErrorToResponse for JsonError {} + +#[derive(Serialize)] +struct JsonModelError { + message: String, + partial_response: ChatCompletionResponse, +} + +impl JsonModelError { + fn new(message: String, partial_response: ChatCompletionResponse) -> Self { + Self { + message, + partial_response, + } + } +} + +impl ErrorToResponse for JsonModelError {} + +impl IntoResponse for ChatCompletionResponder { + fn into_response(self) -> axum::response::Response { + match self { + ChatCompletionResponder::Sse(s) => s.into_response(), + ChatCompletionResponder::Json(s) => Json(s).into_response(), + ChatCompletionResponder::InternalError(e) => { + JsonError::new(e.to_string()).to_response(http::StatusCode::INTERNAL_SERVER_ERROR) + } + ChatCompletionResponder::ValidationError(e) => { + JsonError::new(e.to_string()).to_response(http::StatusCode::UNPROCESSABLE_ENTITY) + } + ChatCompletionResponder::ModelError(msg, response) => { + JsonModelError::new(msg, response) + .to_response(http::StatusCode::INTERNAL_SERVER_ERROR) + } + } + } +} + +fn parse_request( + oairequest: ChatCompletionRequest, + state: Arc, + tx: Sender, +) -> Request { + let repr = serde_json::to_string(&oairequest).unwrap(); + MistralRs::maybe_log_request(state.clone(), repr); + + let stop_toks = match oairequest.stop_seqs { + Some(StopTokens::Multi(m)) => Some(InternalStopTokens::Seqs(m)), + Some(StopTokens::Single(s)) => Some(InternalStopTokens::Seqs(vec![s])), + Some(StopTokens::MultiId(m)) => Some(InternalStopTokens::Ids(m)), + Some(StopTokens::SingleId(s)) => Some(InternalStopTokens::Ids(vec![s])), + None => None, + }; + let messages = match oairequest.messages { + Either::Left(req_messages) => { + let mut messages = Vec::new(); + for message in req_messages { + let mut message_map = IndexMap::new(); + message_map.insert("role".to_string(), message.role); + message_map.insert("content".to_string(), message.content); + messages.push(message_map); + } + Either::Left(messages) + } + Either::Right(prompt) => Either::Right(prompt), + }; + + Request { + id: state.next_request_id(), + messages, + sampling_params: SamplingParams { + temperature: oairequest.temperature, + top_k: oairequest.top_k, + top_p: oairequest.top_p, + top_n_logprobs: oairequest.top_logprobs.unwrap_or(1), + frequency_penalty: oairequest.frequency_penalty, + presence_penalty: oairequest.presence_penalty, + max_len: oairequest.max_tokens, + stop_toks, + logits_bias: oairequest.logit_bias, + n_choices: oairequest.n_choices, + }, + response: tx, + return_logprobs: oairequest.logprobs, + is_streaming: oairequest.stream.unwrap_or(false), + suffix: None, + best_of: None, + + constraint: match oairequest.grammar { + Some(Grammar::Yacc(yacc)) => Constraint::Yacc(yacc), + Some(Grammar::Regex(regex)) => Constraint::Regex(regex), + None => Constraint::None, + }, + + request_type: RequestType::Chat, + } +} + +#[utoipa::path( + 
post, + tag = "Mistral.rs", + path = "/v1/chat/completions", + request_body = ChatCompletionRequest, + responses((status = 200, description = "Chat completions")) +)] +pub async fn chatcompletions( + State(state): State>, + Json(oairequest): Json, +) -> ChatCompletionResponder { + let (tx, rx) = channel(); + let request = parse_request(oairequest, state.clone(), tx); + let is_streaming = request.is_streaming; + let sender = state.get_sender(); + sender.send(request).unwrap(); + + if is_streaming { + let streamer = Streamer { + rx, + is_done: false, + state, + }; + + ChatCompletionResponder::Sse( + Sse::new(streamer).keep_alive( + KeepAlive::new() + .interval(Duration::from_millis( + env::var("KEEP_ALIVE_INTERVAL") + .map(|val| val.parse::().unwrap_or(1000)) + .unwrap_or(1000), + )) + .text("keep-alive-text"), + ), + ) + } else { + let response = rx.recv().unwrap(); + + match response { + Response::InternalError(e) => { + MistralRs::maybe_log_error(state, &*e); + ChatCompletionResponder::InternalError(e) + } + Response::ModelError(msg, response) => { + MistralRs::maybe_log_error(state.clone(), &ModelErrorMessage(msg.to_string())); + MistralRs::maybe_log_response(state, &response); + ChatCompletionResponder::ModelError(msg, response) + } + Response::ValidationError(e) => ChatCompletionResponder::ValidationError(e), + Response::Done(response) => { + MistralRs::maybe_log_response(state, &response); + ChatCompletionResponder::Json(response) + } + Response::Chunk(_) => unreachable!(), + Response::CompletionDone(_) => unreachable!(), + Response::CompletionModelError(_, _) => unreachable!(), + } + } +} diff --git a/mistralrs-server/src/completions.rs b/mistralrs-server/src/completions.rs new file mode 100644 index 000000000..a19feb55b --- /dev/null +++ b/mistralrs-server/src/completions.rs @@ -0,0 +1,199 @@ +use std::{ + error::Error, + sync::{ + mpsc::{channel, Sender}, + Arc, + }, +}; + +use crate::openai::{CompletionRequest, Grammar, StopTokens}; +use axum::{ + extract::{Json, State}, + http::{self, StatusCode}, + response::IntoResponse, +}; +use either::Either; +use mistralrs_core::{ + CompletionResponse, Constraint, MistralRs, Request, RequestType, Response, SamplingParams, + StopTokens as InternalStopTokens, +}; +use serde::Serialize; +use tracing::warn; + +#[derive(Debug)] +struct ModelErrorMessage(String); +impl std::fmt::Display for ModelErrorMessage { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.0) + } +} +impl std::error::Error for ModelErrorMessage {} +pub enum CompletionResponder { + Json(CompletionResponse), + ModelError(String, CompletionResponse), + InternalError(Box), + ValidationError(Box), +} + +trait ErrorToResponse: Serialize { + fn to_response(&self, code: StatusCode) -> axum::response::Response { + let mut r = Json(self).into_response(); + *r.status_mut() = code; + r + } +} + +#[derive(Serialize)] +struct JsonError { + message: String, +} + +impl JsonError { + fn new(message: String) -> Self { + Self { message } + } +} +impl ErrorToResponse for JsonError {} + +#[derive(Serialize)] +struct JsonModelError { + message: String, + partial_response: CompletionResponse, +} + +impl JsonModelError { + fn new(message: String, partial_response: CompletionResponse) -> Self { + Self { + message, + partial_response, + } + } +} + +impl ErrorToResponse for JsonModelError {} + +impl IntoResponse for CompletionResponder { + fn into_response(self) -> axum::response::Response { + match self { + CompletionResponder::Json(s) => 
Json(s).into_response(), + CompletionResponder::InternalError(e) => { + JsonError::new(e.to_string()).to_response(http::StatusCode::INTERNAL_SERVER_ERROR) + } + CompletionResponder::ValidationError(e) => { + JsonError::new(e.to_string()).to_response(http::StatusCode::UNPROCESSABLE_ENTITY) + } + CompletionResponder::ModelError(msg, response) => JsonModelError::new(msg, response) + .to_response(http::StatusCode::INTERNAL_SERVER_ERROR), + } + } +} + +fn parse_request( + oairequest: CompletionRequest, + state: Arc, + tx: Sender, +) -> Request { + let repr = serde_json::to_string(&oairequest).unwrap(); + MistralRs::maybe_log_request(state.clone(), repr); + + let stop_toks = match oairequest.stop_seqs { + Some(StopTokens::Multi(m)) => Some(InternalStopTokens::Seqs(m)), + Some(StopTokens::Single(s)) => Some(InternalStopTokens::Seqs(vec![s])), + Some(StopTokens::MultiId(m)) => Some(InternalStopTokens::Ids(m)), + Some(StopTokens::SingleId(s)) => Some(InternalStopTokens::Ids(vec![s])), + None => None, + }; + + if oairequest.logprobs.is_some() { + warn!("Completion requests do not support logprobs."); + } + + if oairequest._stream.is_some_and(|x| x) { + warn!("Completion requests do not support streaming."); + } + + Request { + id: state.next_request_id(), + messages: Either::Right(oairequest.prompt), + sampling_params: SamplingParams { + temperature: oairequest.temperature, + top_k: oairequest.top_k, + top_p: oairequest.top_p, + top_n_logprobs: 1, + frequency_penalty: oairequest.frequency_penalty, + presence_penalty: oairequest.presence_penalty, + max_len: oairequest.max_tokens, + stop_toks, + logits_bias: oairequest.logit_bias, + n_choices: oairequest.n_choices, + }, + response: tx, + return_logprobs: false, + is_streaming: false, + suffix: oairequest.suffix, + best_of: Some(oairequest.best_of), + + constraint: match oairequest.grammar { + Some(Grammar::Yacc(yacc)) => Constraint::Yacc(yacc), + Some(Grammar::Regex(regex)) => Constraint::Regex(regex), + None => Constraint::None, + }, + + request_type: RequestType::Completion { + echo_prompt: oairequest.echo_prompt, + }, + } +} + +#[utoipa::path( + post, + tag = "Mistral.rs", + path = "/v1/completions", + request_body = CompletionRequest, + responses((status = 200, description = "Completions")) +)] +pub async fn completions( + State(state): State>, + Json(oairequest): Json, +) -> CompletionResponder { + let (tx, rx) = channel(); + let request = parse_request(oairequest, state.clone(), tx); + let is_streaming = request.is_streaming; + let sender = state.get_sender(); + + if request.return_logprobs { + return CompletionResponder::ValidationError( + "Completion requests do not support logprobs.".into(), + ); + } + + if is_streaming { + return CompletionResponder::ValidationError( + "Completion requests do not support streaming.".into(), + ); + } + + sender.send(request).unwrap(); + + let response = rx.recv().unwrap(); + + match response { + Response::InternalError(e) => { + MistralRs::maybe_log_error(state, &*e); + CompletionResponder::InternalError(e) + } + Response::CompletionModelError(msg, response) => { + MistralRs::maybe_log_error(state.clone(), &ModelErrorMessage(msg.to_string())); + MistralRs::maybe_log_response(state, &response); + CompletionResponder::ModelError(msg, response) + } + Response::ValidationError(e) => CompletionResponder::ValidationError(e), + Response::CompletionDone(response) => { + MistralRs::maybe_log_response(state, &response); + CompletionResponder::Json(response) + } + Response::Chunk(_) => unreachable!(), + 
Response::Done(_) => unreachable!(), + Response::ModelError(_, _) => unreachable!(), + } +} diff --git a/mistralrs-server/src/interactive_mode.rs b/mistralrs-server/src/interactive_mode.rs index aa3363acc..82553f6c2 100644 --- a/mistralrs-server/src/interactive_mode.rs +++ b/mistralrs-server/src/interactive_mode.rs @@ -5,7 +5,7 @@ use std::{ use either::Either; use indexmap::IndexMap; -use mistralrs_core::{Constraint, MistralRs, Request, Response, SamplingParams}; +use mistralrs_core::{Constraint, MistralRs, Request, RequestType, Response, SamplingParams}; use tracing::{error, info}; pub fn interactive_mode(mistralrs: Arc) { @@ -46,6 +46,9 @@ pub fn interactive_mode(mistralrs: Arc) { return_logprobs: false, is_streaming: true, constraint: Constraint::None, + request_type: RequestType::Chat, + suffix: None, + best_of: None, }; sender.send(req).unwrap(); @@ -80,6 +83,8 @@ pub fn interactive_mode(mistralrs: Arc) { break 'outer; } Response::Done(_) => unreachable!(), + Response::CompletionDone(_) => unreachable!(), + Response::CompletionModelError(_, _) => unreachable!(), } } } diff --git a/mistralrs-server/src/main.rs b/mistralrs-server/src/main.rs index 94ef6a4bd..f31e05355 100644 --- a/mistralrs-server/src/main.rs +++ b/mistralrs-server/src/main.rs @@ -1,42 +1,26 @@ -use std::{ - env, - error::Error, - fs::File, - pin::Pin, - sync::{ - mpsc::{channel, Receiver, Sender}, - Arc, - }, - task::{Context, Poll}, - time::Duration, -}; +use std::{fs::File, sync::Arc}; use anyhow::Result; use axum::{ extract::{Json, State}, - http::{self, Method, StatusCode}, - response::{ - sse::{Event, KeepAlive}, - IntoResponse, Sse, - }, + http::{self, Method}, routing::{get, post}, Router, }; use candle_core::Device; use clap::Parser; -use either::Either; -use indexmap::IndexMap; use mistralrs_core::{ - ChatCompletionResponse, Constraint, GemmaLoader, GemmaSpecificConfig, LlamaLoader, - LlamaSpecificConfig, Loader, MistralLoader, MistralRs, MistralSpecificConfig, MixtralLoader, - MixtralSpecificConfig, ModelKind, Phi2Loader, Phi2SpecificConfig, Request, Response, - SamplingParams, SchedulerMethod, StopTokens as InternalStopTokens, TokenSource, + GemmaLoader, GemmaSpecificConfig, LlamaLoader, LlamaSpecificConfig, Loader, MistralLoader, + MistralRs, MistralSpecificConfig, MixtralLoader, MixtralSpecificConfig, ModelKind, Phi2Loader, + Phi2SpecificConfig, SchedulerMethod, TokenSource, }; use model_selected::ModelSelected; -use openai::{ChatCompletionGrammar, ChatCompletionRequest, Message, ModelObjects, StopTokens}; -use serde::Serialize; +use openai::{ChatCompletionRequest, Message, ModelObjects, StopTokens}; +mod chat_completion; +mod completions; +use crate::{chat_completion::__path_chatcompletions, completions::completions}; -use crate::openai::ModelObject; +use crate::{chat_completion::chatcompletions, openai::ModelObject}; mod interactive_mode; mod model_selected; mod openai; @@ -104,235 +88,6 @@ struct Args { prefix_cache_n: usize, } -#[derive(Debug)] -struct ModelErrorMessage(String); -impl std::fmt::Display for ModelErrorMessage { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "{}", self.0) - } -} -impl std::error::Error for ModelErrorMessage {} -struct Streamer { - rx: Receiver, - is_done: bool, - state: Arc, -} - -impl futures::Stream for Streamer { - type Item = Result; - - fn poll_next(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { - if self.is_done { - return Poll::Ready(None); - } - match self.rx.try_recv() { - Ok(resp) => match resp { - 
Response::ModelError(msg, _) => { - MistralRs::maybe_log_error( - self.state.clone(), - &ModelErrorMessage(msg.to_string()), - ); - Poll::Ready(Some(Ok(Event::default().data(msg)))) - } - Response::ValidationError(e) => { - Poll::Ready(Some(Ok(Event::default().data(e.to_string())))) - } - Response::InternalError(e) => { - MistralRs::maybe_log_error(self.state.clone(), &*e); - Poll::Ready(Some(Ok(Event::default().data(e.to_string())))) - } - Response::Chunk(response) => { - if response.choices.iter().all(|x| x.stopreason.is_some()) { - self.is_done = true; - } - MistralRs::maybe_log_response(self.state.clone(), &response); - Poll::Ready(Some(Event::default().json_data(response))) - } - Response::Done(_) => unreachable!(), - }, - Err(_) => Poll::Pending, - } - } -} - -enum ChatCompletionResponder { - Sse(Sse), - Json(ChatCompletionResponse), - ModelError(String, ChatCompletionResponse), - InternalError(Box), - ValidationError(Box), -} - -trait ErrorToResponse: Serialize { - fn to_response(&self, code: StatusCode) -> axum::response::Response { - let mut r = Json(self).into_response(); - *r.status_mut() = code; - r - } -} - -#[derive(Serialize)] -struct JsonError { - message: String, -} - -impl JsonError { - fn new(message: String) -> Self { - Self { message } - } -} -impl ErrorToResponse for JsonError {} - -#[derive(Serialize)] -struct JsonModelError { - message: String, - partial_response: ChatCompletionResponse, -} - -impl JsonModelError { - fn new(message: String, partial_response: ChatCompletionResponse) -> Self { - Self { - message, - partial_response, - } - } -} - -impl ErrorToResponse for JsonModelError {} - -impl IntoResponse for ChatCompletionResponder { - fn into_response(self) -> axum::response::Response { - match self { - ChatCompletionResponder::Sse(s) => s.into_response(), - ChatCompletionResponder::Json(s) => Json(s).into_response(), - ChatCompletionResponder::InternalError(e) => { - JsonError::new(e.to_string()).to_response(http::StatusCode::INTERNAL_SERVER_ERROR) - } - ChatCompletionResponder::ValidationError(e) => { - JsonError::new(e.to_string()).to_response(http::StatusCode::UNPROCESSABLE_ENTITY) - } - ChatCompletionResponder::ModelError(msg, response) => { - JsonModelError::new(msg, response) - .to_response(http::StatusCode::INTERNAL_SERVER_ERROR) - } - } - } -} - -fn parse_request( - oairequest: ChatCompletionRequest, - state: Arc, - tx: Sender, -) -> Request { - let repr = serde_json::to_string(&oairequest).unwrap(); - MistralRs::maybe_log_request(state.clone(), repr); - - let stop_toks = match oairequest.stop_seqs { - Some(StopTokens::Multi(m)) => Some(InternalStopTokens::Seqs(m)), - Some(StopTokens::Single(s)) => Some(InternalStopTokens::Seqs(vec![s])), - Some(StopTokens::MultiId(m)) => Some(InternalStopTokens::Ids(m)), - Some(StopTokens::SingleId(s)) => Some(InternalStopTokens::Ids(vec![s])), - None => None, - }; - let messages = match oairequest.messages { - Either::Left(req_messages) => { - let mut messages = Vec::new(); - for message in req_messages { - let mut message_map = IndexMap::new(); - message_map.insert("role".to_string(), message.role); - message_map.insert("content".to_string(), message.content); - messages.push(message_map); - } - Either::Left(messages) - } - Either::Right(prompt) => Either::Right(prompt), - }; - - Request { - id: state.next_request_id(), - messages, - sampling_params: SamplingParams { - temperature: oairequest.temperature, - top_k: oairequest.top_k, - top_p: oairequest.top_p, - top_n_logprobs: oairequest.top_logprobs.unwrap_or(1), - 
frequency_penalty: oairequest.repetition_penalty, - presence_penalty: oairequest.presence_penalty, - max_len: oairequest.max_tokens, - stop_toks, - logits_bias: oairequest.logit_bias, - n_choices: oairequest.n_choices, - }, - response: tx, - return_logprobs: oairequest.logprobs, - is_streaming: oairequest.stream.unwrap_or(false), - - constraint: match oairequest.grammar { - Some(ChatCompletionGrammar::Yacc(yacc)) => Constraint::Yacc(yacc), - Some(ChatCompletionGrammar::Regex(regex)) => Constraint::Regex(regex), - None => Constraint::None, - }, - } -} - -#[utoipa::path( - post, - tag = "Mistral.rs", - path = "/v1/chat/completions", - request_body = ChatCompletionRequest, - responses((status = 200, description = "Chat completions")) -)] -async fn chatcompletions( - State(state): State>, - Json(oairequest): Json, -) -> ChatCompletionResponder { - let (tx, rx) = channel(); - let request = parse_request(oairequest, state.clone(), tx); - let is_streaming = request.is_streaming; - let sender = state.get_sender(); - sender.send(request).unwrap(); - - if is_streaming { - let streamer = Streamer { - rx, - is_done: false, - state, - }; - - ChatCompletionResponder::Sse( - Sse::new(streamer).keep_alive( - KeepAlive::new() - .interval(Duration::from_millis( - env::var("KEEP_ALIVE_INTERVAL") - .map(|val| val.parse::().unwrap_or(1000)) - .unwrap_or(1000), - )) - .text("keep-alive-text"), - ), - ) - } else { - let response = rx.recv().unwrap(); - - match response { - Response::InternalError(e) => { - MistralRs::maybe_log_error(state, &*e); - ChatCompletionResponder::InternalError(e) - } - Response::ModelError(msg, response) => { - MistralRs::maybe_log_error(state.clone(), &ModelErrorMessage(msg.to_string())); - MistralRs::maybe_log_response(state, &response); - ChatCompletionResponder::ModelError(msg, response) - } - Response::ValidationError(e) => ChatCompletionResponder::ValidationError(e), - Response::Done(response) => { - MistralRs::maybe_log_response(state, &response); - ChatCompletionResponder::Json(response) - } - Response::Chunk(_) => unreachable!(), - } - } -} - #[utoipa::path( get, tag = "Mistral.rs", @@ -391,6 +146,7 @@ fn get_router(state: Arc) -> Router { .merge(SwaggerUi::new("/docs").url("/api-doc/openapi.json", doc)) .layer(cors_layer) .route("/v1/chat/completions", post(chatcompletions)) + .route("/v1/completions", post(completions)) .route("/v1/models", get(models)) .route("/health", get(health)) .route("/", get(health)) diff --git a/mistralrs-server/src/openai.rs b/mistralrs-server/src/openai.rs index a9308afe7..14bef215c 100644 --- a/mistralrs-server/src/openai.rs +++ b/mistralrs-server/src/openai.rs @@ -29,7 +29,7 @@ fn default_1usize() -> usize { #[derive(Debug, Clone, Deserialize, Serialize, ToSchema)] #[serde(tag = "type", content = "value")] -pub enum ChatCompletionGrammar { +pub enum Grammar { #[serde(rename = "regex")] Regex(String), #[serde(rename = "yacc")] @@ -58,9 +58,8 @@ pub struct ChatCompletionRequest { pub n_choices: usize, #[schema(example = json!(Option::None::))] pub presence_penalty: Option, - #[serde(rename = "frequency_penalty")] #[schema(example = json!(Option::None::))] - pub repetition_penalty: Option, + pub frequency_penalty: Option, #[serde(rename = "stop")] #[schema(example = json!(Option::None::))] pub stop_seqs: Option, @@ -75,8 +74,8 @@ pub struct ChatCompletionRequest { #[schema(example = json!(Option::None::))] pub top_k: Option, - #[schema(example = json!(Option::None::))] - pub grammar: Option, + #[schema(example = json!(Option::None::))] + pub 
grammar: Option, } #[derive(Debug, Serialize, ToSchema)] @@ -92,3 +91,52 @@ pub struct ModelObjects { pub object: &'static str, pub data: Vec, } + +#[derive(Debug, Clone, Deserialize, Serialize, ToSchema)] +pub struct CompletionRequest { + #[schema(example = "mistral")] + pub model: String, + #[schema(example = "Say this is a test.")] + pub prompt: String, + #[serde(default = "default_1usize")] + #[schema(example = 1)] + pub best_of: usize, + #[serde(rename = "echo")] + #[serde(default = "default_false")] + #[schema(example = false)] + pub echo_prompt: bool, + #[schema(example = json!(Option::None::))] + pub presence_penalty: Option, + #[schema(example = json!(Option::None::))] + pub frequency_penalty: Option, + #[schema(example = json!(Option::None::>))] + pub logit_bias: Option>, + #[schema(example = json!(Option::None::))] + pub logprobs: Option, + #[schema(example = 16)] + pub max_tokens: Option, + #[serde(rename = "n")] + #[serde(default = "default_1usize")] + #[schema(example = 1)] + pub n_choices: usize, + #[serde(rename = "stop")] + #[schema(example = json!(Option::None::))] + pub stop_seqs: Option, + #[serde(rename = "stream")] + pub _stream: Option, + #[schema(example = 0.7)] + pub temperature: Option, + #[schema(example = json!(Option::None::))] + pub top_p: Option, + #[schema(example = json!(Option::None::))] + pub suffix: Option, + #[serde(rename = "user")] + pub _user: Option, + + // mistral.rs additional + #[schema(example = json!(Option::None::))] + pub top_k: Option, + + #[schema(example = json!(Option::None::))] + pub grammar: Option, +} diff --git a/mistralrs/src/lib.rs b/mistralrs/src/lib.rs index ede8ce822..bcdcce3e6 100644 --- a/mistralrs/src/lib.rs +++ b/mistralrs/src/lib.rs @@ -1,6 +1,6 @@ pub use mistralrs_core::{ Constraint, GemmaLoader, GemmaSpecificConfig, LlamaLoader, LlamaSpecificConfig, Loader, MistralLoader, MistralRs, MistralSpecificConfig, MixtralLoader, MixtralSpecificConfig, - ModelKind, Ordering, Phi2Loader, Phi2SpecificConfig, Pipeline, Request, Response, + ModelKind, Ordering, Phi2Loader, Phi2SpecificConfig, Pipeline, Request, RequestType, Response, SamplingParams, SchedulerMethod, StopTokens, TokenSource, };
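Because of the serde renames in `CompletionRequest` (`echo` → `echo_prompt`, `n` → `n_choices`, `stop` → `stop_seqs`), the wire format of the new route uses the OpenAI field names. A minimal sketch of a raw HTTP call against `/v1/completions`, assuming a `mistralrs-server` instance listening on localhost:8080 with a model registered as "mistral" (both values are illustrative):

```python
import httpx

# Illustrative payload; field names follow the OpenAI completions schema plus
# the mistral.rs extension `top_k`.
payload = {
    "model": "mistral",
    "prompt": "What is Rust?",
    "max_tokens": 64,
    "n": 1,            # deserialized into `n_choices`
    "best_of": 1,      # choices are ranked by cumulative logprob server-side
    "echo": False,     # deserialized into `echo_prompt`
    "temperature": 0.7,
    "top_p": 0.9,
    "top_k": 32,
}

resp = httpx.post("http://localhost:8080/v1/completions", json=payload, timeout=None)
data = resp.json()
assert data["object"] == "text_completion"
print(data["choices"][0]["text"])
print(data["usage"])
```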