From b4bccd0449b8982499fc3be65804fc97f5ec806f Mon Sep 17 00:00:00 2001 From: EricLBuehler Date: Thu, 11 Apr 2024 14:00:23 -0400 Subject: [PATCH 01/15] Initial impl of api --- examples/http.md | 6 +- integrations/llama_index_integration.py | 4 +- mistralrs-core/src/lib.rs | 4 +- mistralrs-core/src/request.rs | 11 +- mistralrs-core/src/response.rs | 48 +++- mistralrs-core/src/sequence.rs | 6 +- mistralrs-pyo3/API.md | 4 +- mistralrs-pyo3/README.md | 2 +- mistralrs-pyo3/mistralrs.pyi | 2 +- mistralrs-pyo3/src/lib.rs | 18 +- mistralrs-server/src/interactive_mode.rs | 6 +- mistralrs-server/src/main.rs | 266 +---------------------- mistralrs-server/src/openai.rs | 58 ++++- mistralrs/src/lib.rs | 2 +- 14 files changed, 150 insertions(+), 287 deletions(-) diff --git a/examples/http.md b/examples/http.md index de78961b8..433654385 100644 --- a/examples/http.md +++ b/examples/http.md @@ -134,7 +134,7 @@ pub struct ChatCompletionResponse { pub model: &'static str, pub system_fingerprint: String, pub object: String, - pub usage: ChatCompletionUsage, + pub usage: Usage, } ``` @@ -186,9 +186,9 @@ pub struct TopLogprob { } ``` -### `ChatCompletionUsage` +### `Usage` ```rust -pub struct ChatCompletionUsage { +pub struct Usage { pub completion_tokens: usize, pub prompt_tokens: usize, pub total_tokens: usize, diff --git a/integrations/llama_index_integration.py b/integrations/llama_index_integration.py index e12dd6b98..c3a52969e 100644 --- a/integrations/llama_index_integration.py +++ b/integrations/llama_index_integration.py @@ -365,7 +365,7 @@ def chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatResponse: top_k=self.generate_kwargs["top_k"], top_p=self.generate_kwargs["top_p"], presence_penalty=self.generate_kwargs.get("presence_penalty", None), - repetition_penalty=self.generate_kwargs.get("repetition_penalty", None), + frequency_penalty=self.generate_kwargs.get("frequency_penalty", None), temperature=self.generate_kwargs.get("temperature", None), ) completion_response = self._runner.send_chat_completion_request(request) @@ -399,7 +399,7 @@ def complete( top_k=self.generate_kwargs["top_k"], top_p=self.generate_kwargs["top_p"], presence_penalty=self.generate_kwargs.get("presence_penalty", None), - repetition_penalty=self.generate_kwargs.get("repetition_penalty", None), + frequency_penalty=self.generate_kwargs.get("frequency_penalty", None), temperature=self.generate_kwargs.get("temperature", None), ) completion_response = self._runner.send_chat_completion_request(request) diff --git a/mistralrs-core/src/lib.rs b/mistralrs-core/src/lib.rs index a20d6bd3a..9ba27dce4 100644 --- a/mistralrs-core/src/lib.rs +++ b/mistralrs-core/src/lib.rs @@ -35,9 +35,9 @@ pub use pipeline::{ MistralSpecificConfig, MixtralLoader, MixtralSpecificConfig, ModelKind, Phi2Loader, Phi2SpecificConfig, TokenSource, }; -pub use request::{Constraint, Request}; +pub use request::{Constraint, Request, RequestType}; pub use response::Response; -pub use response::{ChatCompletionResponse, ChatCompletionUsage}; +pub use response::{ChatCompletionResponse, CompletionResponse, Usage}; pub use sampler::{SamplingParams, StopTokens}; pub use scheduler::SchedulerMethod; use serde::Serialize; diff --git a/mistralrs-core/src/request.rs b/mistralrs-core/src/request.rs index 0a9654640..fe21c1a80 100644 --- a/mistralrs-core/src/request.rs +++ b/mistralrs-core/src/request.rs @@ -10,6 +10,12 @@ pub enum Constraint { None, } +#[derive(Debug)] +pub enum RequestType { + Chat, + Completion, +} + pub struct Request { pub messages: Either>, 
String>, pub sampling_params: SamplingParams, @@ -18,14 +24,15 @@ pub struct Request { pub is_streaming: bool, pub id: usize, pub constraint: Constraint, + pub request_type: RequestType, } impl Debug for Request { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!( f, - "Request {} {{ messages: `{:?}`, sampling_params: {:?}}}", - self.id, self.messages, self.sampling_params + "Request {} ({:?}) {{ messages: `{:?}`, sampling_params: {:?}}}", + self.id, self.request_type, self.messages, self.sampling_params ) } } diff --git a/mistralrs-core/src/response.rs b/mistralrs-core/src/response.rs index 07355cb5d..be6edb76c 100644 --- a/mistralrs-core/src/response.rs +++ b/mistralrs-core/src/response.rs @@ -50,7 +50,7 @@ pub struct ChunkChoice { } #[derive(Debug, Clone, Serialize)] -pub struct ChatCompletionUsage { +pub struct Usage { pub completion_tokens: usize, pub prompt_tokens: usize, pub total_tokens: usize, @@ -72,7 +72,7 @@ pub struct ChatCompletionResponse { pub model: String, pub system_fingerprint: String, pub object: String, - pub usage: ChatCompletionUsage, + pub usage: Usage, } #[derive(Debug, Clone, Serialize)] @@ -85,10 +85,54 @@ pub struct ChatCompletionChunkResponse { pub object: String, } +#[derive(Debug, Clone, Serialize)] +pub struct CompletionResponseLogprob { + pub token: String, + pub logprob: f32, + pub bytes: Vec, + pub top_logprobs: Vec, +} + +#[derive(Debug, Clone, Serialize)] +pub struct CompletionLogprobs { + pub content: Option>, +} + +#[derive(Debug, Clone, Serialize)] +pub struct CompletionChoice { + #[serde(rename = "finish_reason")] + pub stopreason: String, + pub index: usize, + pub text: String, + pub logprobs: Option, +} + +#[derive(Debug, Clone, Serialize)] +pub struct CompletionResponse { + pub id: String, + pub choices: Vec, + pub created: u64, + pub model: String, + pub system_fingerprint: String, + pub object: String, + pub usage: Usage, +} + +#[derive(Debug, Clone, Serialize)] +pub struct CompletionChunkResponse { + pub data: String, + pub done: bool, +} + pub enum Response { InternalError(Box), ValidationError(Box), + // Chat specific ModelError(String, ChatCompletionResponse), Done(ChatCompletionResponse), Chunk(ChatCompletionChunkResponse), + // Completion specific + CompletionModelError(String, CompletionResponse), + CompletionDone(CompletionResponse), + CompletionChunk(CompletionChunkResponse), } diff --git a/mistralrs-core/src/sequence.rs b/mistralrs-core/src/sequence.rs index 34728dec1..1d7e6e403 100644 --- a/mistralrs-core/src/sequence.rs +++ b/mistralrs-core/src/sequence.rs @@ -11,7 +11,7 @@ use crate::{ models::LayerCaches, response::{ChatCompletionChunkResponse, Choice, ChunkChoice, Response, SYSTEM_FINGERPRINT}, sampler::{Logprobs, Sampler}, - ChatCompletionResponse, ChatCompletionUsage, + ChatCompletionResponse, Usage, }; use candle_core::Tensor; use regex_automata::util::primitives::StateID; @@ -370,9 +370,9 @@ impl SequenceGroup { &self.choices } - pub fn get_usage(&self) -> ChatCompletionUsage { + pub fn get_usage(&self) -> Usage { #[allow(clippy::cast_precision_loss)] - ChatCompletionUsage { + Usage { completion_tokens: self.total_toks - self.total_prompt_toks, prompt_tokens: self.total_prompt_toks, total_tokens: self.total_toks, diff --git a/mistralrs-pyo3/API.md b/mistralrs-pyo3/API.md index 25522aac3..e58e97b02 100644 --- a/mistralrs-pyo3/API.md +++ b/mistralrs-pyo3/API.md @@ -95,13 +95,13 @@ Request is a class with a constructor which accepts the following arguments. 
It - `max_tokens: usize | None` - `n_choices: usize` - `presence_penalty: float | None` -- `repetition_penalty: float | None` +- `frequency_penalty: float | None` - `stop_token_ids: list[int] | None` - `temperature: float | None` - `top_p: float | None` - `top_k: usize | None` -`ChatCompletionRequest(messages, model, logprobs = false, n_choices = 1, logit_bias = None, top_logprobs = None, max_tokens = None, presence_penalty = None, repetition_penalty = None, stop_token_ids = None, temperature = None, top_p = None, top_k = None)` +`ChatCompletionRequest(messages, model, logprobs = false, n_choices = 1, logit_bias = None, top_logprobs = None, max_tokens = None, presence_penalty = None, frequency_penalty = None, stop_token_ids = None, temperature = None, top_p = None, top_k = None)` ## `ModelKind` - Normal diff --git a/mistralrs-pyo3/README.md b/mistralrs-pyo3/README.md index 9a8a722f8..c35b020b1 100644 --- a/mistralrs-pyo3/README.md +++ b/mistralrs-pyo3/README.md @@ -45,7 +45,7 @@ res = runner.send_chat_completion_request( {"role": "user", "content": "Tell me a story about the Rust type system."} ], max_tokens=256, - repetition_penalty=1.0, + frequency_penalty=1.0, top_p=0.1, temperature=0.1, ) diff --git a/mistralrs-pyo3/mistralrs.pyi b/mistralrs-pyo3/mistralrs.pyi index ea26ffe2a..852fa60ce 100644 --- a/mistralrs-pyo3/mistralrs.pyi +++ b/mistralrs-pyo3/mistralrs.pyi @@ -16,7 +16,7 @@ class ChatCompletionRequest: max_tokens: int | None = None n_choices: int = 1 presence_penalty: float | None = None - repetition_penalty: float | None = None + frequency_penalty: float | None = None stop_token_ids: list[int] | None = None temperature: float | None = None top_p: float | None = None diff --git a/mistralrs-pyo3/src/lib.rs b/mistralrs-pyo3/src/lib.rs index d0a0bef91..1dab93a7b 100644 --- a/mistralrs-pyo3/src/lib.rs +++ b/mistralrs-pyo3/src/lib.rs @@ -9,7 +9,7 @@ use std::{ }; use ::mistralrs::{ - Constraint, MistralRs, Request as _Request, Response, SamplingParams, StopTokens, + Constraint, MistralRs, Request as _Request, RequestType, Response, SamplingParams, StopTokens, }; use candle_core::Device; use loaders::{ @@ -135,7 +135,7 @@ impl Runner { top_k: request.top_k, top_p: request.top_p, top_n_logprobs: request.top_logprobs.unwrap_or(1), - frequency_penalty: request.repetition_penalty, + frequency_penalty: request.frequency_penalty, presence_penalty: request.presence_penalty, max_len: request.max_tokens, stop_toks, @@ -146,6 +146,7 @@ impl Runner { return_logprobs: request.logprobs, is_streaming: request.stream, constraint, + request_type: RequestType::Chat, }; MistralRs::maybe_log_request(self.runner.clone(), format!("{request:?}")); @@ -161,8 +162,11 @@ impl Runner { MistralRs::maybe_log_response(self.runner.clone(), &response); Ok(serde_json::to_string(&response).unwrap()) } - Response::Chunk(_) => unreachable!(), Response::ModelError(msg, _) => Err(PyValueError::new_err(msg.to_string())), + Response::Chunk(_) => unreachable!(), + Response::CompletionChunk(_) => unreachable!(), + Response::CompletionDone(_) => unreachable!(), + Response::CompletionModelError(_, _) => unreachable!(), } }) } @@ -180,7 +184,7 @@ struct ChatCompletionRequest { max_tokens: Option, n_choices: usize, presence_penalty: Option, - repetition_penalty: Option, + frequency_penalty: Option, stop_token_ids: Option>, temperature: Option, top_p: Option, @@ -193,7 +197,7 @@ struct ChatCompletionRequest { #[pymethods] impl ChatCompletionRequest { #[new] - #[pyo3(signature = (messages, model, logprobs = false, n_choices = 1, 
logit_bias = None, top_logprobs = None, max_tokens = None, presence_penalty = None, repetition_penalty = None, stop_token_ids = None, temperature = None, top_p = None, top_k = None, stream=false, grammar = None, grammar_type = None))] + #[pyo3(signature = (messages, model, logprobs = false, n_choices = 1, logit_bias = None, top_logprobs = None, max_tokens = None, presence_penalty = None, frequency_penalty = None, stop_token_ids = None, temperature = None, top_p = None, top_k = None, stream=false, grammar = None, grammar_type = None))] #[allow(clippy::too_many_arguments)] fn new( messages: Py, @@ -204,7 +208,7 @@ impl ChatCompletionRequest { top_logprobs: Option, max_tokens: Option, presence_penalty: Option, - repetition_penalty: Option, + frequency_penalty: Option, stop_token_ids: Option>, temperature: Option, top_p: Option, @@ -253,7 +257,7 @@ impl ChatCompletionRequest { max_tokens, n_choices, presence_penalty, - repetition_penalty, + frequency_penalty, stop_token_ids, temperature, top_p, diff --git a/mistralrs-server/src/interactive_mode.rs b/mistralrs-server/src/interactive_mode.rs index aa3363acc..833b33144 100644 --- a/mistralrs-server/src/interactive_mode.rs +++ b/mistralrs-server/src/interactive_mode.rs @@ -5,7 +5,7 @@ use std::{ use either::Either; use indexmap::IndexMap; -use mistralrs_core::{Constraint, MistralRs, Request, Response, SamplingParams}; +use mistralrs_core::{Constraint, MistralRs, Request, RequestType, Response, SamplingParams}; use tracing::{error, info}; pub fn interactive_mode(mistralrs: Arc) { @@ -46,6 +46,7 @@ pub fn interactive_mode(mistralrs: Arc) { return_logprobs: false, is_streaming: true, constraint: Constraint::None, + request_type: RequestType::Chat, }; sender.send(req).unwrap(); @@ -80,6 +81,9 @@ pub fn interactive_mode(mistralrs: Arc) { break 'outer; } Response::Done(_) => unreachable!(), + Response::CompletionChunk(_) => unreachable!(), + Response::CompletionDone(_) => unreachable!(), + Response::CompletionModelError(_, _) => unreachable!(), } } } diff --git a/mistralrs-server/src/main.rs b/mistralrs-server/src/main.rs index 94ef6a4bd..f31e05355 100644 --- a/mistralrs-server/src/main.rs +++ b/mistralrs-server/src/main.rs @@ -1,42 +1,26 @@ -use std::{ - env, - error::Error, - fs::File, - pin::Pin, - sync::{ - mpsc::{channel, Receiver, Sender}, - Arc, - }, - task::{Context, Poll}, - time::Duration, -}; +use std::{fs::File, sync::Arc}; use anyhow::Result; use axum::{ extract::{Json, State}, - http::{self, Method, StatusCode}, - response::{ - sse::{Event, KeepAlive}, - IntoResponse, Sse, - }, + http::{self, Method}, routing::{get, post}, Router, }; use candle_core::Device; use clap::Parser; -use either::Either; -use indexmap::IndexMap; use mistralrs_core::{ - ChatCompletionResponse, Constraint, GemmaLoader, GemmaSpecificConfig, LlamaLoader, - LlamaSpecificConfig, Loader, MistralLoader, MistralRs, MistralSpecificConfig, MixtralLoader, - MixtralSpecificConfig, ModelKind, Phi2Loader, Phi2SpecificConfig, Request, Response, - SamplingParams, SchedulerMethod, StopTokens as InternalStopTokens, TokenSource, + GemmaLoader, GemmaSpecificConfig, LlamaLoader, LlamaSpecificConfig, Loader, MistralLoader, + MistralRs, MistralSpecificConfig, MixtralLoader, MixtralSpecificConfig, ModelKind, Phi2Loader, + Phi2SpecificConfig, SchedulerMethod, TokenSource, }; use model_selected::ModelSelected; -use openai::{ChatCompletionGrammar, ChatCompletionRequest, Message, ModelObjects, StopTokens}; -use serde::Serialize; +use openai::{ChatCompletionRequest, Message, ModelObjects, 
StopTokens}; +mod chat_completion; +mod completions; +use crate::{chat_completion::__path_chatcompletions, completions::completions}; -use crate::openai::ModelObject; +use crate::{chat_completion::chatcompletions, openai::ModelObject}; mod interactive_mode; mod model_selected; mod openai; @@ -104,235 +88,6 @@ struct Args { prefix_cache_n: usize, } -#[derive(Debug)] -struct ModelErrorMessage(String); -impl std::fmt::Display for ModelErrorMessage { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "{}", self.0) - } -} -impl std::error::Error for ModelErrorMessage {} -struct Streamer { - rx: Receiver, - is_done: bool, - state: Arc, -} - -impl futures::Stream for Streamer { - type Item = Result; - - fn poll_next(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { - if self.is_done { - return Poll::Ready(None); - } - match self.rx.try_recv() { - Ok(resp) => match resp { - Response::ModelError(msg, _) => { - MistralRs::maybe_log_error( - self.state.clone(), - &ModelErrorMessage(msg.to_string()), - ); - Poll::Ready(Some(Ok(Event::default().data(msg)))) - } - Response::ValidationError(e) => { - Poll::Ready(Some(Ok(Event::default().data(e.to_string())))) - } - Response::InternalError(e) => { - MistralRs::maybe_log_error(self.state.clone(), &*e); - Poll::Ready(Some(Ok(Event::default().data(e.to_string())))) - } - Response::Chunk(response) => { - if response.choices.iter().all(|x| x.stopreason.is_some()) { - self.is_done = true; - } - MistralRs::maybe_log_response(self.state.clone(), &response); - Poll::Ready(Some(Event::default().json_data(response))) - } - Response::Done(_) => unreachable!(), - }, - Err(_) => Poll::Pending, - } - } -} - -enum ChatCompletionResponder { - Sse(Sse), - Json(ChatCompletionResponse), - ModelError(String, ChatCompletionResponse), - InternalError(Box), - ValidationError(Box), -} - -trait ErrorToResponse: Serialize { - fn to_response(&self, code: StatusCode) -> axum::response::Response { - let mut r = Json(self).into_response(); - *r.status_mut() = code; - r - } -} - -#[derive(Serialize)] -struct JsonError { - message: String, -} - -impl JsonError { - fn new(message: String) -> Self { - Self { message } - } -} -impl ErrorToResponse for JsonError {} - -#[derive(Serialize)] -struct JsonModelError { - message: String, - partial_response: ChatCompletionResponse, -} - -impl JsonModelError { - fn new(message: String, partial_response: ChatCompletionResponse) -> Self { - Self { - message, - partial_response, - } - } -} - -impl ErrorToResponse for JsonModelError {} - -impl IntoResponse for ChatCompletionResponder { - fn into_response(self) -> axum::response::Response { - match self { - ChatCompletionResponder::Sse(s) => s.into_response(), - ChatCompletionResponder::Json(s) => Json(s).into_response(), - ChatCompletionResponder::InternalError(e) => { - JsonError::new(e.to_string()).to_response(http::StatusCode::INTERNAL_SERVER_ERROR) - } - ChatCompletionResponder::ValidationError(e) => { - JsonError::new(e.to_string()).to_response(http::StatusCode::UNPROCESSABLE_ENTITY) - } - ChatCompletionResponder::ModelError(msg, response) => { - JsonModelError::new(msg, response) - .to_response(http::StatusCode::INTERNAL_SERVER_ERROR) - } - } - } -} - -fn parse_request( - oairequest: ChatCompletionRequest, - state: Arc, - tx: Sender, -) -> Request { - let repr = serde_json::to_string(&oairequest).unwrap(); - MistralRs::maybe_log_request(state.clone(), repr); - - let stop_toks = match oairequest.stop_seqs { - Some(StopTokens::Multi(m)) => 
Some(InternalStopTokens::Seqs(m)), - Some(StopTokens::Single(s)) => Some(InternalStopTokens::Seqs(vec![s])), - Some(StopTokens::MultiId(m)) => Some(InternalStopTokens::Ids(m)), - Some(StopTokens::SingleId(s)) => Some(InternalStopTokens::Ids(vec![s])), - None => None, - }; - let messages = match oairequest.messages { - Either::Left(req_messages) => { - let mut messages = Vec::new(); - for message in req_messages { - let mut message_map = IndexMap::new(); - message_map.insert("role".to_string(), message.role); - message_map.insert("content".to_string(), message.content); - messages.push(message_map); - } - Either::Left(messages) - } - Either::Right(prompt) => Either::Right(prompt), - }; - - Request { - id: state.next_request_id(), - messages, - sampling_params: SamplingParams { - temperature: oairequest.temperature, - top_k: oairequest.top_k, - top_p: oairequest.top_p, - top_n_logprobs: oairequest.top_logprobs.unwrap_or(1), - frequency_penalty: oairequest.repetition_penalty, - presence_penalty: oairequest.presence_penalty, - max_len: oairequest.max_tokens, - stop_toks, - logits_bias: oairequest.logit_bias, - n_choices: oairequest.n_choices, - }, - response: tx, - return_logprobs: oairequest.logprobs, - is_streaming: oairequest.stream.unwrap_or(false), - - constraint: match oairequest.grammar { - Some(ChatCompletionGrammar::Yacc(yacc)) => Constraint::Yacc(yacc), - Some(ChatCompletionGrammar::Regex(regex)) => Constraint::Regex(regex), - None => Constraint::None, - }, - } -} - -#[utoipa::path( - post, - tag = "Mistral.rs", - path = "/v1/chat/completions", - request_body = ChatCompletionRequest, - responses((status = 200, description = "Chat completions")) -)] -async fn chatcompletions( - State(state): State>, - Json(oairequest): Json, -) -> ChatCompletionResponder { - let (tx, rx) = channel(); - let request = parse_request(oairequest, state.clone(), tx); - let is_streaming = request.is_streaming; - let sender = state.get_sender(); - sender.send(request).unwrap(); - - if is_streaming { - let streamer = Streamer { - rx, - is_done: false, - state, - }; - - ChatCompletionResponder::Sse( - Sse::new(streamer).keep_alive( - KeepAlive::new() - .interval(Duration::from_millis( - env::var("KEEP_ALIVE_INTERVAL") - .map(|val| val.parse::().unwrap_or(1000)) - .unwrap_or(1000), - )) - .text("keep-alive-text"), - ), - ) - } else { - let response = rx.recv().unwrap(); - - match response { - Response::InternalError(e) => { - MistralRs::maybe_log_error(state, &*e); - ChatCompletionResponder::InternalError(e) - } - Response::ModelError(msg, response) => { - MistralRs::maybe_log_error(state.clone(), &ModelErrorMessage(msg.to_string())); - MistralRs::maybe_log_response(state, &response); - ChatCompletionResponder::ModelError(msg, response) - } - Response::ValidationError(e) => ChatCompletionResponder::ValidationError(e), - Response::Done(response) => { - MistralRs::maybe_log_response(state, &response); - ChatCompletionResponder::Json(response) - } - Response::Chunk(_) => unreachable!(), - } - } -} - #[utoipa::path( get, tag = "Mistral.rs", @@ -391,6 +146,7 @@ fn get_router(state: Arc) -> Router { .merge(SwaggerUi::new("/docs").url("/api-doc/openapi.json", doc)) .layer(cors_layer) .route("/v1/chat/completions", post(chatcompletions)) + .route("/v1/completions", post(completions)) .route("/v1/models", get(models)) .route("/health", get(health)) .route("/", get(health)) diff --git a/mistralrs-server/src/openai.rs b/mistralrs-server/src/openai.rs index a9308afe7..60c467f9f 100644 --- a/mistralrs-server/src/openai.rs 
+++ b/mistralrs-server/src/openai.rs @@ -29,7 +29,7 @@ fn default_1usize() -> usize { #[derive(Debug, Clone, Deserialize, Serialize, ToSchema)] #[serde(tag = "type", content = "value")] -pub enum ChatCompletionGrammar { +pub enum Grammar { #[serde(rename = "regex")] Regex(String), #[serde(rename = "yacc")] @@ -58,9 +58,8 @@ pub struct ChatCompletionRequest { pub n_choices: usize, #[schema(example = json!(Option::None::))] pub presence_penalty: Option, - #[serde(rename = "frequency_penalty")] #[schema(example = json!(Option::None::))] - pub repetition_penalty: Option, + pub frequency_penalty: Option, #[serde(rename = "stop")] #[schema(example = json!(Option::None::))] pub stop_seqs: Option, @@ -75,8 +74,8 @@ pub struct ChatCompletionRequest { #[schema(example = json!(Option::None::))] pub top_k: Option, - #[schema(example = json!(Option::None::))] - pub grammar: Option, + #[schema(example = json!(Option::None::))] + pub grammar: Option, } #[derive(Debug, Serialize, ToSchema)] @@ -92,3 +91,52 @@ pub struct ModelObjects { pub object: &'static str, pub data: Vec, } + +#[derive(Debug, Clone, Deserialize, Serialize, ToSchema)] +pub struct CompletionRequest { + #[schema(example = "mistral")] + pub model: String, + #[schema(example = "Say this is a test.")] + pub prompt: String, + #[serde(default = "default_1usize")] + #[schema(example = 1)] + pub best_of: usize, + #[serde(rename = "echo")] + #[serde(default = "default_false")] + #[schema(example = false)] + pub echo_prompt: bool, + #[schema(example = json!(Option::None::))] + pub presence_penalty: Option, + #[schema(example = json!(Option::None::))] + pub frequency_penalty: Option, + #[schema(example = json!(Option::None::>))] + pub logit_bias: Option>, + #[schema(example = json!(Option::None::))] + pub logprobs: Option, + #[schema(example = 16)] + pub max_tokens: Option, + #[serde(rename = "n")] + #[serde(default = "default_1usize")] + #[schema(example = 1)] + pub n_choices: usize, + #[serde(rename = "stop")] + #[schema(example = json!(Option::None::))] + pub stop_seqs: Option, + #[schema(example = true)] + pub stream: Option, + #[schema(example = 0.7)] + pub temperature: Option, + #[schema(example = json!(Option::None::))] + pub top_p: Option, + #[schema(example = json!(Option::None::))] + pub suffix: Option, + #[serde(rename = "user")] + pub _user: Option, + + // mistral.rs additional + #[schema(example = json!(Option::None::))] + pub top_k: Option, + + #[schema(example = json!(Option::None::))] + pub grammar: Option, +} diff --git a/mistralrs/src/lib.rs b/mistralrs/src/lib.rs index ede8ce822..bcdcce3e6 100644 --- a/mistralrs/src/lib.rs +++ b/mistralrs/src/lib.rs @@ -1,6 +1,6 @@ pub use mistralrs_core::{ Constraint, GemmaLoader, GemmaSpecificConfig, LlamaLoader, LlamaSpecificConfig, Loader, MistralLoader, MistralRs, MistralSpecificConfig, MixtralLoader, MixtralSpecificConfig, - ModelKind, Ordering, Phi2Loader, Phi2SpecificConfig, Pipeline, Request, Response, + ModelKind, Ordering, Phi2Loader, Phi2SpecificConfig, Pipeline, Request, RequestType, Response, SamplingParams, SchedulerMethod, StopTokens, TokenSource, }; From 5218dd133221720d71d82f4f637e449520c4574a Mon Sep 17 00:00:00 2001 From: EricLBuehler Date: Thu, 11 Apr 2024 14:00:30 -0400 Subject: [PATCH 02/15] Initial impl of api --- mistralrs-server/src/chat_completion.rs | 266 ++++++++++++++++++++++++ mistralrs-server/src/completions.rs | 250 ++++++++++++++++++++++ 2 files changed, 516 insertions(+) create mode 100644 mistralrs-server/src/chat_completion.rs create mode 100644 
mistralrs-server/src/completions.rs diff --git a/mistralrs-server/src/chat_completion.rs b/mistralrs-server/src/chat_completion.rs new file mode 100644 index 000000000..761fa34d1 --- /dev/null +++ b/mistralrs-server/src/chat_completion.rs @@ -0,0 +1,266 @@ +use std::{ + env, + error::Error, + pin::Pin, + sync::{ + mpsc::{channel, Receiver, Sender}, + Arc, + }, + task::{Context, Poll}, + time::Duration, +}; + +use crate::openai::{ChatCompletionRequest, Grammar, StopTokens}; +use anyhow::Result; +use axum::{ + extract::{Json, State}, + http::{self, StatusCode}, + response::{ + sse::{Event, KeepAlive}, + IntoResponse, Sse, + }, +}; +use either::Either; +use indexmap::IndexMap; +use mistralrs_core::{ + ChatCompletionResponse, Constraint, MistralRs, Request, RequestType, Response, SamplingParams, + StopTokens as InternalStopTokens, +}; +use serde::Serialize; + +#[derive(Debug)] +struct ModelErrorMessage(String); +impl std::fmt::Display for ModelErrorMessage { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.0) + } +} +impl std::error::Error for ModelErrorMessage {} +pub struct Streamer { + rx: Receiver, + is_done: bool, + state: Arc, +} + +impl futures::Stream for Streamer { + type Item = Result; + + fn poll_next(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + if self.is_done { + return Poll::Ready(None); + } + match self.rx.try_recv() { + Ok(resp) => match resp { + Response::ModelError(msg, _) => { + MistralRs::maybe_log_error( + self.state.clone(), + &ModelErrorMessage(msg.to_string()), + ); + Poll::Ready(Some(Ok(Event::default().data(msg)))) + } + Response::ValidationError(e) => { + Poll::Ready(Some(Ok(Event::default().data(e.to_string())))) + } + Response::InternalError(e) => { + MistralRs::maybe_log_error(self.state.clone(), &*e); + Poll::Ready(Some(Ok(Event::default().data(e.to_string())))) + } + Response::Chunk(response) => { + if response.choices.iter().all(|x| x.stopreason.is_some()) { + self.is_done = true; + } + MistralRs::maybe_log_response(self.state.clone(), &response); + Poll::Ready(Some(Event::default().json_data(response))) + } + Response::Done(_) => unreachable!(), + Response::CompletionChunk(_) => unreachable!(), + Response::CompletionDone(_) => unreachable!(), + Response::CompletionModelError(_, _) => unreachable!(), + }, + Err(_) => Poll::Pending, + } + } +} + +pub enum ChatCompletionResponder { + Sse(Sse), + Json(ChatCompletionResponse), + ModelError(String, ChatCompletionResponse), + InternalError(Box), + ValidationError(Box), +} + +trait ErrorToResponse: Serialize { + fn to_response(&self, code: StatusCode) -> axum::response::Response { + let mut r = Json(self).into_response(); + *r.status_mut() = code; + r + } +} + +#[derive(Serialize)] +struct JsonError { + message: String, +} + +impl JsonError { + fn new(message: String) -> Self { + Self { message } + } +} +impl ErrorToResponse for JsonError {} + +#[derive(Serialize)] +struct JsonModelError { + message: String, + partial_response: ChatCompletionResponse, +} + +impl JsonModelError { + fn new(message: String, partial_response: ChatCompletionResponse) -> Self { + Self { + message, + partial_response, + } + } +} + +impl ErrorToResponse for JsonModelError {} + +impl IntoResponse for ChatCompletionResponder { + fn into_response(self) -> axum::response::Response { + match self { + ChatCompletionResponder::Sse(s) => s.into_response(), + ChatCompletionResponder::Json(s) => Json(s).into_response(), + ChatCompletionResponder::InternalError(e) => { + 
JsonError::new(e.to_string()).to_response(http::StatusCode::INTERNAL_SERVER_ERROR) + } + ChatCompletionResponder::ValidationError(e) => { + JsonError::new(e.to_string()).to_response(http::StatusCode::UNPROCESSABLE_ENTITY) + } + ChatCompletionResponder::ModelError(msg, response) => { + JsonModelError::new(msg, response) + .to_response(http::StatusCode::INTERNAL_SERVER_ERROR) + } + } + } +} + +fn parse_request( + oairequest: ChatCompletionRequest, + state: Arc, + tx: Sender, +) -> Request { + let repr = serde_json::to_string(&oairequest).unwrap(); + MistralRs::maybe_log_request(state.clone(), repr); + + let stop_toks = match oairequest.stop_seqs { + Some(StopTokens::Multi(m)) => Some(InternalStopTokens::Seqs(m)), + Some(StopTokens::Single(s)) => Some(InternalStopTokens::Seqs(vec![s])), + Some(StopTokens::MultiId(m)) => Some(InternalStopTokens::Ids(m)), + Some(StopTokens::SingleId(s)) => Some(InternalStopTokens::Ids(vec![s])), + None => None, + }; + let messages = match oairequest.messages { + Either::Left(req_messages) => { + let mut messages = Vec::new(); + for message in req_messages { + let mut message_map = IndexMap::new(); + message_map.insert("role".to_string(), message.role); + message_map.insert("content".to_string(), message.content); + messages.push(message_map); + } + Either::Left(messages) + } + Either::Right(prompt) => Either::Right(prompt), + }; + + Request { + id: state.next_request_id(), + messages, + sampling_params: SamplingParams { + temperature: oairequest.temperature, + top_k: oairequest.top_k, + top_p: oairequest.top_p, + top_n_logprobs: oairequest.top_logprobs.unwrap_or(1), + frequency_penalty: oairequest.frequency_penalty, + presence_penalty: oairequest.presence_penalty, + max_len: oairequest.max_tokens, + stop_toks, + logits_bias: oairequest.logit_bias, + n_choices: oairequest.n_choices, + }, + response: tx, + return_logprobs: oairequest.logprobs, + is_streaming: oairequest.stream.unwrap_or(false), + + constraint: match oairequest.grammar { + Some(Grammar::Yacc(yacc)) => Constraint::Yacc(yacc), + Some(Grammar::Regex(regex)) => Constraint::Regex(regex), + None => Constraint::None, + }, + + request_type: RequestType::Chat, + } +} + +#[utoipa::path( + post, + tag = "Mistral.rs", + path = "/v1/chat/completions", + request_body = ChatCompletionRequest, + responses((status = 200, description = "Chat completions")) +)] +pub async fn chatcompletions( + State(state): State>, + Json(oairequest): Json, +) -> ChatCompletionResponder { + let (tx, rx) = channel(); + let request = parse_request(oairequest, state.clone(), tx); + let is_streaming = request.is_streaming; + let sender = state.get_sender(); + sender.send(request).unwrap(); + + if is_streaming { + let streamer = Streamer { + rx, + is_done: false, + state, + }; + + ChatCompletionResponder::Sse( + Sse::new(streamer).keep_alive( + KeepAlive::new() + .interval(Duration::from_millis( + env::var("KEEP_ALIVE_INTERVAL") + .map(|val| val.parse::().unwrap_or(1000)) + .unwrap_or(1000), + )) + .text("keep-alive-text"), + ), + ) + } else { + let response = rx.recv().unwrap(); + + match response { + Response::InternalError(e) => { + MistralRs::maybe_log_error(state, &*e); + ChatCompletionResponder::InternalError(e) + } + Response::ModelError(msg, response) => { + MistralRs::maybe_log_error(state.clone(), &ModelErrorMessage(msg.to_string())); + MistralRs::maybe_log_response(state, &response); + ChatCompletionResponder::ModelError(msg, response) + } + Response::ValidationError(e) => ChatCompletionResponder::ValidationError(e), + 
Response::Done(response) => { + MistralRs::maybe_log_response(state, &response); + ChatCompletionResponder::Json(response) + } + Response::Chunk(_) => unreachable!(), + Response::CompletionChunk(_) => unreachable!(), + Response::CompletionDone(_) => unreachable!(), + Response::CompletionModelError(_, _) => unreachable!(), + } + } +} diff --git a/mistralrs-server/src/completions.rs b/mistralrs-server/src/completions.rs new file mode 100644 index 000000000..cffc8cf7c --- /dev/null +++ b/mistralrs-server/src/completions.rs @@ -0,0 +1,250 @@ +use std::{ + env, + error::Error, + pin::Pin, + sync::{ + mpsc::{channel, Receiver, Sender}, + Arc, + }, + task::{Context, Poll}, + time::Duration, +}; + +use crate::openai::{CompletionRequest, Grammar, StopTokens}; +use anyhow::Result; +use axum::{ + extract::{Json, State}, + http::{self, StatusCode}, + response::{ + sse::{Event, KeepAlive}, + IntoResponse, Sse, + }, +}; +use either::Either; +use mistralrs_core::{ + CompletionResponse, Constraint, MistralRs, Request, RequestType, Response, SamplingParams, + StopTokens as InternalStopTokens, +}; +use serde::Serialize; + +#[derive(Debug)] +struct ModelErrorMessage(String); +impl std::fmt::Display for ModelErrorMessage { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.0) + } +} +impl std::error::Error for ModelErrorMessage {} +pub struct Streamer { + rx: Receiver, + is_done: bool, + state: Arc, +} + +impl futures::Stream for Streamer { + type Item = Result; + + fn poll_next(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + if self.is_done { + return Poll::Ready(None); + } + match self.rx.try_recv() { + Ok(resp) => match resp { + Response::CompletionModelError(msg, _) => { + MistralRs::maybe_log_error( + self.state.clone(), + &ModelErrorMessage(msg.to_string()), + ); + Poll::Ready(Some(Ok(Event::default().data(msg)))) + } + Response::ValidationError(e) => { + Poll::Ready(Some(Ok(Event::default().data(e.to_string())))) + } + Response::InternalError(e) => { + MistralRs::maybe_log_error(self.state.clone(), &*e); + Poll::Ready(Some(Ok(Event::default().data(e.to_string())))) + } + Response::CompletionChunk(response) => { + if response.done { + self.is_done = true; + } + MistralRs::maybe_log_response(self.state.clone(), &response); + Poll::Ready(Some(Ok(Event::default().data(response.data)))) + } + Response::CompletionDone(_) => unreachable!(), + Response::Chunk(_) => unreachable!(), + Response::Done(_) => unreachable!(), + Response::ModelError(_, _) => unreachable!(), + }, + Err(_) => Poll::Pending, + } + } +} + +pub enum CompletionResponder { + Sse(Sse), + Json(CompletionResponse), + ModelError(String, CompletionResponse), + InternalError(Box), + ValidationError(Box), +} + +trait ErrorToResponse: Serialize { + fn to_response(&self, code: StatusCode) -> axum::response::Response { + let mut r = Json(self).into_response(); + *r.status_mut() = code; + r + } +} + +#[derive(Serialize)] +struct JsonError { + message: String, +} + +impl JsonError { + fn new(message: String) -> Self { + Self { message } + } +} +impl ErrorToResponse for JsonError {} + +#[derive(Serialize)] +struct JsonModelError { + message: String, + partial_response: CompletionResponse, +} + +impl JsonModelError { + fn new(message: String, partial_response: CompletionResponse) -> Self { + Self { + message, + partial_response, + } + } +} + +impl ErrorToResponse for JsonModelError {} + +impl IntoResponse for CompletionResponder { + fn into_response(self) -> axum::response::Response { + match self { 
+ CompletionResponder::Sse(s) => s.into_response(), + CompletionResponder::Json(s) => Json(s).into_response(), + CompletionResponder::InternalError(e) => { + JsonError::new(e.to_string()).to_response(http::StatusCode::INTERNAL_SERVER_ERROR) + } + CompletionResponder::ValidationError(e) => { + JsonError::new(e.to_string()).to_response(http::StatusCode::UNPROCESSABLE_ENTITY) + } + CompletionResponder::ModelError(msg, response) => JsonModelError::new(msg, response) + .to_response(http::StatusCode::INTERNAL_SERVER_ERROR), + } + } +} + +fn parse_request( + oairequest: CompletionRequest, + state: Arc, + tx: Sender, +) -> Request { + let repr = serde_json::to_string(&oairequest).unwrap(); + MistralRs::maybe_log_request(state.clone(), repr); + + let stop_toks = match oairequest.stop_seqs { + Some(StopTokens::Multi(m)) => Some(InternalStopTokens::Seqs(m)), + Some(StopTokens::Single(s)) => Some(InternalStopTokens::Seqs(vec![s])), + Some(StopTokens::MultiId(m)) => Some(InternalStopTokens::Ids(m)), + Some(StopTokens::SingleId(s)) => Some(InternalStopTokens::Ids(vec![s])), + None => None, + }; + + Request { + id: state.next_request_id(), + messages: Either::Right(oairequest.prompt), + sampling_params: SamplingParams { + temperature: oairequest.temperature, + top_k: oairequest.top_k, + top_p: oairequest.top_p, + top_n_logprobs: oairequest.logprobs.unwrap_or(1), + frequency_penalty: oairequest.frequency_penalty, + presence_penalty: oairequest.presence_penalty, + max_len: oairequest.max_tokens, + stop_toks, + logits_bias: oairequest.logit_bias, + n_choices: oairequest.n_choices, + }, + response: tx, + return_logprobs: oairequest.logprobs.is_some(), + is_streaming: oairequest.stream.unwrap_or(false), + + constraint: match oairequest.grammar { + Some(Grammar::Yacc(yacc)) => Constraint::Yacc(yacc), + Some(Grammar::Regex(regex)) => Constraint::Regex(regex), + None => Constraint::None, + }, + + request_type: RequestType::Completion, + } +} + +#[utoipa::path( + post, + tag = "Mistral.rs", + path = "/v1/completions", + request_body = CompletionRequest, + responses((status = 200, description = "Completions")) +)] +pub async fn completions( + State(state): State>, + Json(oairequest): Json, +) -> CompletionResponder { + let (tx, rx) = channel(); + let request = parse_request(oairequest, state.clone(), tx); + let is_streaming = request.is_streaming; + let sender = state.get_sender(); + sender.send(request).unwrap(); + + if is_streaming { + let streamer = Streamer { + rx, + is_done: false, + state, + }; + + CompletionResponder::Sse( + Sse::new(streamer).keep_alive( + KeepAlive::new() + .interval(Duration::from_millis( + env::var("KEEP_ALIVE_INTERVAL") + .map(|val| val.parse::().unwrap_or(1000)) + .unwrap_or(1000), + )) + .text("keep-alive-text"), + ), + ) + } else { + let response = rx.recv().unwrap(); + + match response { + Response::InternalError(e) => { + MistralRs::maybe_log_error(state, &*e); + CompletionResponder::InternalError(e) + } + Response::CompletionModelError(msg, response) => { + MistralRs::maybe_log_error(state.clone(), &ModelErrorMessage(msg.to_string())); + MistralRs::maybe_log_response(state, &response); + CompletionResponder::ModelError(msg, response) + } + Response::ValidationError(e) => CompletionResponder::ValidationError(e), + Response::CompletionDone(response) => { + MistralRs::maybe_log_response(state, &response); + CompletionResponder::Json(response) + } + Response::CompletionChunk(_) => unreachable!(), + Response::Chunk(_) => unreachable!(), + Response::Done(_) => unreachable!(), + 
Response::ModelError(_, _) => unreachable!(), + } + } +} From 249aa90b2d03844d4b7a9fa68e2c61c2f30ddcd2 Mon Sep 17 00:00:00 2001 From: EricLBuehler Date: Thu, 11 Apr 2024 14:38:22 -0400 Subject: [PATCH 03/15] Finalize impl for completion endpt, update example --- examples/server/{prompt.py => completion.py} | 6 +- mistralrs-core/src/engine/mod.rs | 97 +++++++++++++++----- mistralrs-core/src/request.rs | 3 +- mistralrs-core/src/response.rs | 19 +--- mistralrs-core/src/sequence.rs | 60 +++++++++++- mistralrs-core/src/utils/mod.rs | 81 +++++++++++----- mistralrs-pyo3/src/lib.rs | 1 + mistralrs-server/src/chat_completion.rs | 1 + mistralrs-server/src/completions.rs | 10 +- mistralrs-server/src/interactive_mode.rs | 1 + 10 files changed, 202 insertions(+), 77 deletions(-) rename examples/server/{prompt.py => completion.py} (92%) diff --git a/examples/server/prompt.py b/examples/server/completion.py similarity index 92% rename from examples/server/prompt.py rename to examples/server/completion.py index 0da715542..e95a19218 100644 --- a/examples/server/prompt.py +++ b/examples/server/completion.py @@ -39,15 +39,15 @@ def log_response(response: httpx.Response): while True: prompt = input(">>> ") - completion = openai.chat.completions.create( + completion = openai.completions.create( model="mistral", - messages=prompt, + prompt="Rust is a ", max_tokens=256, frequency_penalty=1.0, top_p=0.1, temperature=0, ) - resp = completion.choices[0].message.content + resp = completion.choices[0].text for eos in eos_toks: if resp.endswith(eos): out = resp[: -len(eos)] diff --git a/mistralrs-core/src/engine/mod.rs b/mistralrs-core/src/engine/mod.rs index 383950d5e..c1a10cd20 100644 --- a/mistralrs-core/src/engine/mod.rs +++ b/mistralrs-core/src/engine/mod.rs @@ -7,7 +7,11 @@ use std::{ time::{Instant, SystemTime, UNIX_EPOCH}, }; -use crate::aici::{cfg::CfgParser, recognizer::StackRecognizer, rx::RecRx}; +use crate::{ + aici::{cfg::CfgParser, recognizer::StackRecognizer, rx::RecRx}, + response::CompletionChoice, + CompletionResponse, RequestType, +}; use candle_core::Tensor; use either::Either; use tracing::warn; @@ -149,7 +153,7 @@ impl Engine { seq.add_token(next_token.clone()); let is_done = seq.is_done(next_token_id, eos_tok, pipeline.get_max_seq_len()); // Handle streaming requests - if seq.get_mut_group().is_streaming { + if seq.get_mut_group().is_streaming && seq.get_mut_group().is_chat { let tokenizer = pipeline.tokenizer().clone(); if let Some(delta) = handle_seq_error!(seq.get_delta(&tokenizer), seq.responder()) { seq.add_streaming_chunk_choice_to_group(ChunkChoice { @@ -178,6 +182,22 @@ impl Engine { seq.get_mut_group() .maybe_send_streaming_response(seq, pipeline.name()); } + } else if seq.get_mut_group().is_streaming { + let tokenizer = pipeline.tokenizer().clone(); + if let Some(mut delta) = + handle_seq_error!(seq.get_delta(&tokenizer), seq.responder()) + { + let seq_is_done = is_done.is_some(); + if let Some(reason) = is_done { + seq.set_state(SequenceState::Done(reason)); + if let Some(ref suffix) = seq.suffix { + delta = delta + suffix; + } + } + + seq.get_mut_group() + .maybe_send_completion_streaming_response(seq, delta, seq_is_done); + } } else if let Some(reason) = is_done { Self::finish_seq(pipeline, seq, reason); pipeline.reset_non_granular_state(); @@ -215,30 +235,55 @@ impl Engine { seq.responder() ); - let choice = Choice { - stopreason: reason.to_string(), - index: seq.get_response_index(), - message: ResponseMessage { - content: res, - role: "assistant".to_string(), - }, - logprobs: 
logprobs.map(|l| Logprobs { content: Some(l) }), - }; - seq.add_choice_to_group(choice); - let group = seq.get_mut_group(); - group.maybe_send_done_response( - ChatCompletionResponse { - id: seq.id().to_string(), - choices: group.get_choices().to_vec(), - created: seq.creation_time(), - model: pipeline.name(), - system_fingerprint: SYSTEM_FINGERPRINT.to_string(), - object: "chat.completion".to_string(), - usage: group.get_usage(), - }, - seq.responder(), - ); + if group.is_chat { + let choice = Choice { + stopreason: reason.to_string(), + index: seq.get_response_index(), + message: ResponseMessage { + content: res, + role: "assistant".to_string(), + }, + logprobs: logprobs.map(|l| Logprobs { content: Some(l) }), + }; + seq.add_choice_to_group(choice); + } else { + let choice = CompletionChoice { + stopreason: reason.to_string(), + index: seq.get_response_index(), + text: res, + logprobs: None, + }; + seq.add_completion_choice_to_group(choice); + } + + if group.is_chat { + group.maybe_send_done_response( + ChatCompletionResponse { + id: seq.id().to_string(), + choices: group.get_choices().to_vec(), + created: seq.creation_time(), + model: pipeline.name(), + system_fingerprint: SYSTEM_FINGERPRINT.to_string(), + object: "chat.completion".to_string(), + usage: group.get_usage(), + }, + seq.responder(), + ); + } else { + group.maybe_send_completion_done_response( + CompletionResponse { + id: seq.id().to_string(), + choices: group.get_completion_choices().to_vec(), + created: seq.creation_time(), + model: pipeline.name(), + system_fingerprint: SYSTEM_FINGERPRINT.to_string(), + object: "text_completion".to_string(), + usage: group.get_usage(), + }, + seq.responder(), + ); + } } /// Clone the cache FROM the sequences' cache TO the model cache. Only used for completion seqs. 
@@ -480,6 +525,7 @@ impl Engine { let group = Rc::new(RefCell::new(SequenceGroup::new( request.sampling_params.n_choices, request.is_streaming, + request.request_type == RequestType::Chat, ))); let now = SystemTime::now() .duration_since(UNIX_EPOCH) @@ -525,6 +571,7 @@ impl Engine { response_index, now.as_secs(), recognizer.clone(), + request.suffix.clone(), ); let seq = if let Some(prefill_cache) = prefill_cache.clone() { match prefill_cache { diff --git a/mistralrs-core/src/request.rs b/mistralrs-core/src/request.rs index fe21c1a80..9d458e1a6 100644 --- a/mistralrs-core/src/request.rs +++ b/mistralrs-core/src/request.rs @@ -10,7 +10,7 @@ pub enum Constraint { None, } -#[derive(Debug)] +#[derive(Debug, PartialEq)] pub enum RequestType { Chat, Completion, @@ -25,6 +25,7 @@ pub struct Request { pub id: usize, pub constraint: Constraint, pub request_type: RequestType, + pub suffix: Option, } impl Debug for Request { diff --git a/mistralrs-core/src/response.rs b/mistralrs-core/src/response.rs index be6edb76c..809d1505f 100644 --- a/mistralrs-core/src/response.rs +++ b/mistralrs-core/src/response.rs @@ -85,26 +85,13 @@ pub struct ChatCompletionChunkResponse { pub object: String, } -#[derive(Debug, Clone, Serialize)] -pub struct CompletionResponseLogprob { - pub token: String, - pub logprob: f32, - pub bytes: Vec, - pub top_logprobs: Vec, -} - -#[derive(Debug, Clone, Serialize)] -pub struct CompletionLogprobs { - pub content: Option>, -} - #[derive(Debug, Clone, Serialize)] pub struct CompletionChoice { #[serde(rename = "finish_reason")] pub stopreason: String, pub index: usize, pub text: String, - pub logprobs: Option, + pub logprobs: Option<()>, } #[derive(Debug, Clone, Serialize)] @@ -127,11 +114,11 @@ pub struct CompletionChunkResponse { pub enum Response { InternalError(Box), ValidationError(Box), - // Chat specific + // Chat ModelError(String, ChatCompletionResponse), Done(ChatCompletionResponse), Chunk(ChatCompletionChunkResponse), - // Completion specific + // Completion CompletionModelError(String, CompletionResponse), CompletionDone(CompletionResponse), CompletionChunk(CompletionChunkResponse), diff --git a/mistralrs-core/src/sequence.rs b/mistralrs-core/src/sequence.rs index 1d7e6e403..807e541b2 100644 --- a/mistralrs-core/src/sequence.rs +++ b/mistralrs-core/src/sequence.rs @@ -5,7 +5,11 @@ use std::{ time::{SystemTime, UNIX_EPOCH}, }; -use crate::aici::{cfg::CfgParser, recognizer::StackRecognizer, rx::RecRx}; +use crate::{ + aici::{cfg::CfgParser, recognizer::StackRecognizer, rx::RecRx}, + response::{CompletionChoice, CompletionChunkResponse}, + CompletionResponse, +}; use crate::{ get_mut_group, models::LayerCaches, @@ -65,6 +69,7 @@ pub struct Sequence { response_index: usize, creation_time: u64, prefill_prompt_toks: Option>, + pub suffix: Option, // Cache scaling_cache: Option, @@ -104,6 +109,7 @@ impl Sequence { response_index: usize, creation_time: u64, recognizer: SequenceRecognizer, + suffix: Option, ) -> Self { let prompt_len = tokens.len(); Self { @@ -135,6 +141,7 @@ impl Sequence { creation_time, recognizer, prefill_prompt_toks: None, + suffix, } } @@ -303,9 +310,7 @@ impl Sequence { self.prompt_timestamp } - pub fn add_choice_to_group(&self, choice: Choice) { - get_mut_group!(self).choices.push(choice); - + fn update_time_info(&self) { let now = SystemTime::now() .duration_since(UNIX_EPOCH) .expect("Time travel has occurred!") @@ -324,6 +329,16 @@ impl Sequence { get_mut_group!(self).total_sampling_time += self.total_sampling_time; } + pub fn add_choice_to_group(&self, 
choice: Choice) { + get_mut_group!(self).choices.push(choice); + self.update_time_info(); + } + + pub fn add_completion_choice_to_group(&self, choice: CompletionChoice) { + get_mut_group!(self).completion_choices.push(choice); + self.update_time_info(); + } + pub fn get_response_index(&self) -> usize { self.response_index } @@ -346,14 +361,17 @@ pub struct SequenceGroup { pub total_completion_time: u128, pub total_sampling_time: u128, choices: Vec, + completion_choices: Vec, pub streaming_chunks: Vec, pub is_streaming: bool, + pub is_chat: bool, } impl SequenceGroup { - pub fn new(n_choices: usize, is_streaming: bool) -> Self { + pub fn new(n_choices: usize, is_streaming: bool, is_chat: bool) -> Self { Self { choices: Vec::new(), + completion_choices: Vec::new(), n_choices, total_prompt_toks: 0, total_toks: 0, @@ -363,6 +381,7 @@ impl SequenceGroup { total_sampling_time: 0, streaming_chunks: Vec::new(), is_streaming, + is_chat, } } @@ -370,6 +389,10 @@ impl SequenceGroup { &self.choices } + pub fn get_completion_choices(&self) -> &[CompletionChoice] { + &self.completion_choices + } + pub fn get_usage(&self) -> Usage { #[allow(clippy::cast_precision_loss)] Usage { @@ -417,4 +440,31 @@ impl SequenceGroup { self.streaming_chunks.clear(); } } + + pub fn maybe_send_completion_done_response( + &self, + response: CompletionResponse, + sender: Sender, + ) { + if self.choices.len() == self.n_choices { + // NOTE(EricLBuehler): Unwrap reasoning: The receiver should really be there, otherwise it is their fault. + sender.send(Response::CompletionDone(response)).unwrap(); + } + } + + pub fn maybe_send_completion_streaming_response( + &mut self, + seq: &Sequence, + chunk: String, + is_done: bool, + ) { + if self.streaming_chunks.len() == self.n_choices && self.is_streaming { + seq.responder() + .send(Response::CompletionChunk(CompletionChunkResponse { + data: chunk, + done: is_done, + })) + .unwrap(); + } + } } diff --git a/mistralrs-core/src/utils/mod.rs b/mistralrs-core/src/utils/mod.rs index a4d982be5..16754b019 100644 --- a/mistralrs-core/src/utils/mod.rs +++ b/mistralrs-core/src/utils/mod.rs @@ -65,37 +65,68 @@ macro_rules! 
handle_pipeline_forward_error { Ok(v) => v, Err(_) => "".to_string(), }; - let choice = Choice { - stopreason: "error".to_string(), - index: seq.get_response_index(), - message: ResponseMessage { - content: res, - role: "assistant".to_string(), - }, - logprobs: None, - }; - seq.add_choice_to_group(choice); + + let group = seq.get_mut_group(); + if group.is_chat { + let choice = Choice { + stopreason: "error".to_string(), + index: seq.get_response_index(), + message: ResponseMessage { + content: res, + role: "assistant".to_string(), + }, + logprobs: None, + }; + seq.add_choice_to_group(choice); + } else { + let choice = CompletionChoice { + stopreason: "error".to_string(), + index: seq.get_response_index(), + text: res, + logprobs: None, + }; + seq.add_completion_choice_to_group(choice); + } } for seq in $seq_slice.iter_mut() { // Step 2: Respond with all groups let group = seq.get_mut_group(); - let partial_completion_response = ChatCompletionResponse { - id: seq.id().to_string(), - choices: group.get_choices().to_vec(), - created: seq.creation_time(), - model: $pipeline.name(), - system_fingerprint: SYSTEM_FINGERPRINT.to_string(), - object: "chat.completion".to_string(), - usage: group.get_usage(), - }; + if group.is_chat { + let partial_completion_response = ChatCompletionResponse { + id: seq.id().to_string(), + choices: group.get_choices().to_vec(), + created: seq.creation_time(), + model: $pipeline.name(), + system_fingerprint: SYSTEM_FINGERPRINT.to_string(), + object: "chat.completion".to_string(), + usage: group.get_usage(), + }; + + seq.responder() + .send(Response::ModelError( + e.to_string(), + partial_completion_response + )) + .unwrap(); + } else { + let partial_completion_response = CompletionResponse { + id: seq.id().to_string(), + choices: group.get_completion_choices().to_vec(), + created: seq.creation_time(), + model: $pipeline.name(), + system_fingerprint: SYSTEM_FINGERPRINT.to_string(), + object: "text_completion".to_string(), + usage: group.get_usage(), + }; - seq.responder() - .send(Response::ModelError( - e.to_string(), - partial_completion_response - )) - .unwrap(); + seq.responder() + .send(Response::CompletionModelError( + e.to_string(), + partial_completion_response + )) + .unwrap(); + } } for seq in $seq_slice.iter_mut() { // Step 3: Set state - This cannot be done in Step 2 as `group` is locking the refcell diff --git a/mistralrs-pyo3/src/lib.rs b/mistralrs-pyo3/src/lib.rs index 1dab93a7b..241df5aea 100644 --- a/mistralrs-pyo3/src/lib.rs +++ b/mistralrs-pyo3/src/lib.rs @@ -147,6 +147,7 @@ impl Runner { is_streaming: request.stream, constraint, request_type: RequestType::Chat, + suffix: None, }; MistralRs::maybe_log_request(self.runner.clone(), format!("{request:?}")); diff --git a/mistralrs-server/src/chat_completion.rs b/mistralrs-server/src/chat_completion.rs index 761fa34d1..bdff00071 100644 --- a/mistralrs-server/src/chat_completion.rs +++ b/mistralrs-server/src/chat_completion.rs @@ -193,6 +193,7 @@ fn parse_request( response: tx, return_logprobs: oairequest.logprobs, is_streaming: oairequest.stream.unwrap_or(false), + suffix: None, constraint: match oairequest.grammar { Some(Grammar::Yacc(yacc)) => Constraint::Yacc(yacc), diff --git a/mistralrs-server/src/completions.rs b/mistralrs-server/src/completions.rs index cffc8cf7c..48ce79661 100644 --- a/mistralrs-server/src/completions.rs +++ b/mistralrs-server/src/completions.rs @@ -26,6 +26,7 @@ use mistralrs_core::{ StopTokens as InternalStopTokens, }; use serde::Serialize; +use tracing::warn; #[derive(Debug)] 
struct ModelErrorMessage(String); @@ -159,6 +160,10 @@ fn parse_request( None => None, }; + if oairequest.logprobs.is_some() { + warn!("Completion requests do not support logprobs."); + } + Request { id: state.next_request_id(), messages: Either::Right(oairequest.prompt), @@ -166,7 +171,7 @@ fn parse_request( temperature: oairequest.temperature, top_k: oairequest.top_k, top_p: oairequest.top_p, - top_n_logprobs: oairequest.logprobs.unwrap_or(1), + top_n_logprobs: 1, frequency_penalty: oairequest.frequency_penalty, presence_penalty: oairequest.presence_penalty, max_len: oairequest.max_tokens, @@ -175,8 +180,9 @@ fn parse_request( n_choices: oairequest.n_choices, }, response: tx, - return_logprobs: oairequest.logprobs.is_some(), + return_logprobs: false, is_streaming: oairequest.stream.unwrap_or(false), + suffix: oairequest.suffix, constraint: match oairequest.grammar { Some(Grammar::Yacc(yacc)) => Constraint::Yacc(yacc), diff --git a/mistralrs-server/src/interactive_mode.rs b/mistralrs-server/src/interactive_mode.rs index 833b33144..ab0d1f836 100644 --- a/mistralrs-server/src/interactive_mode.rs +++ b/mistralrs-server/src/interactive_mode.rs @@ -47,6 +47,7 @@ pub fn interactive_mode(mistralrs: Arc) { is_streaming: true, constraint: Constraint::None, request_type: RequestType::Chat, + suffix: None, }; sender.send(req).unwrap(); From fe22991daa17d49c1ae989eac715962d50a7a021 Mon Sep 17 00:00:00 2001 From: EricLBuehler Date: Thu, 11 Apr 2024 14:40:05 -0400 Subject: [PATCH 04/15] Update example --- examples/server/completion.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/server/completion.py b/examples/server/completion.py index e95a19218..33a5530d1 100644 --- a/examples/server/completion.py +++ b/examples/server/completion.py @@ -41,7 +41,7 @@ def log_response(response: httpx.Response): prompt = input(">>> ") completion = openai.completions.create( model="mistral", - prompt="Rust is a ", + prompt=prompt, max_tokens=256, frequency_penalty=1.0, top_p=0.1, From f698eb9303a885fee704e4d82b721fd06499e7c1 Mon Sep 17 00:00:00 2001 From: EricLBuehler Date: Thu, 11 Apr 2024 14:45:31 -0400 Subject: [PATCH 05/15] Fix a deadlock --- mistralrs-core/src/engine/mod.rs | 4 ++-- mistralrs-core/src/utils/mod.rs | 3 +-- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/mistralrs-core/src/engine/mod.rs b/mistralrs-core/src/engine/mod.rs index c1a10cd20..0ae833288 100644 --- a/mistralrs-core/src/engine/mod.rs +++ b/mistralrs-core/src/engine/mod.rs @@ -235,8 +235,7 @@ impl Engine { seq.responder() ); - let group = seq.get_mut_group(); - if group.is_chat { + if seq.get_mut_group().is_chat { let choice = Choice { stopreason: reason.to_string(), index: seq.get_response_index(), @@ -257,6 +256,7 @@ impl Engine { seq.add_completion_choice_to_group(choice); } + let group = seq.get_mut_group(); if group.is_chat { group.maybe_send_done_response( ChatCompletionResponse { diff --git a/mistralrs-core/src/utils/mod.rs b/mistralrs-core/src/utils/mod.rs index 16754b019..df3d97128 100644 --- a/mistralrs-core/src/utils/mod.rs +++ b/mistralrs-core/src/utils/mod.rs @@ -66,8 +66,7 @@ macro_rules! 
handle_pipeline_forward_error { Err(_) => "".to_string(), }; - let group = seq.get_mut_group(); - if group.is_chat { + if seq.get_mut_group().is_chat { let choice = Choice { stopreason: "error".to_string(), index: seq.get_response_index(), From b4442dde2735ca58cf2b7c66e24ea16a3a38ad09 Mon Sep 17 00:00:00 2001 From: EricLBuehler Date: Thu, 11 Apr 2024 14:49:57 -0400 Subject: [PATCH 06/15] Fix a bug and add a streaming example --- examples/server/completion_streaming.py | 21 +++++++++++++++++++++ examples/server/streaming.py | 2 +- mistralrs-core/src/sequence.rs | 4 ++-- 3 files changed, 24 insertions(+), 3 deletions(-) create mode 100644 examples/server/completion_streaming.py diff --git a/examples/server/completion_streaming.py b/examples/server/completion_streaming.py new file mode 100644 index 000000000..4130146a1 --- /dev/null +++ b/examples/server/completion_streaming.py @@ -0,0 +1,21 @@ +import openai + +openai.api_key = "EMPTY" + +openai.base_url = "http://localhost:1234/v1/" + +eos_toks = ["", "", "<|endoftext|>"] + +while True: + prompt = input(">>> ") + response = openai.completions.create( + model="mistral", + prompt=prompt, + max_tokens=256, + stream=True, + ) + resp = "" + for chunk in response: + delta = chunk.choices[0].delta.content + if delta not in eos_toks: + print(delta, end="") diff --git a/examples/server/streaming.py b/examples/server/streaming.py index b6a762a2f..ff493caf0 100644 --- a/examples/server/streaming.py +++ b/examples/server/streaming.py @@ -3,7 +3,7 @@ openai.api_key = "EMPTY" openai.base_url = "http://localhost:1234/v1/" -# """ + messages = [] prompt = input("Enter system prompt >>> ") if len(prompt) > 0: diff --git a/mistralrs-core/src/sequence.rs b/mistralrs-core/src/sequence.rs index 807e541b2..f25307b0e 100644 --- a/mistralrs-core/src/sequence.rs +++ b/mistralrs-core/src/sequence.rs @@ -446,7 +446,7 @@ impl SequenceGroup { response: CompletionResponse, sender: Sender, ) { - if self.choices.len() == self.n_choices { + if self.completion_choices.len() == self.n_choices { // NOTE(EricLBuehler): Unwrap reasoning: The receiver should really be there, otherwise it is their fault. 
sender.send(Response::CompletionDone(response)).unwrap(); } @@ -458,7 +458,7 @@ impl SequenceGroup { chunk: String, is_done: bool, ) { - if self.streaming_chunks.len() == self.n_choices && self.is_streaming { + if self.is_streaming { seq.responder() .send(Response::CompletionChunk(CompletionChunkResponse { data: chunk, From 0c125074b5aff1093a07a34d77cd7b09f9ffd57b Mon Sep 17 00:00:00 2001 From: EricLBuehler Date: Thu, 11 Apr 2024 14:50:42 -0400 Subject: [PATCH 07/15] Flush the streaming examples --- examples/server/completion_streaming.py | 2 ++ examples/server/streaming.py | 2 ++ 2 files changed, 4 insertions(+) diff --git a/examples/server/completion_streaming.py b/examples/server/completion_streaming.py index 4130146a1..d25914971 100644 --- a/examples/server/completion_streaming.py +++ b/examples/server/completion_streaming.py @@ -1,4 +1,5 @@ import openai +import sys openai.api_key = "EMPTY" @@ -19,3 +20,4 @@ delta = chunk.choices[0].delta.content if delta not in eos_toks: print(delta, end="") + sys.stdout.flush() diff --git a/examples/server/streaming.py b/examples/server/streaming.py index ff493caf0..5fb0cd838 100644 --- a/examples/server/streaming.py +++ b/examples/server/streaming.py @@ -1,4 +1,5 @@ import openai +import sys openai.api_key = "EMPTY" @@ -25,6 +26,7 @@ delta = chunk.choices[0].delta.content if delta not in eos_toks: print(delta, end="") + sys.stdout.flush() resp += delta for eos in eos_toks: if resp.endswith(eos): From fc211c14b108230a21dbb3e4b491da7299470785 Mon Sep 17 00:00:00 2001 From: EricLBuehler Date: Thu, 11 Apr 2024 15:46:52 -0400 Subject: [PATCH 08/15] New receive pattern --- mistralrs-server/src/completions.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mistralrs-server/src/completions.rs b/mistralrs-server/src/completions.rs index 48ce79661..38b92b601 100644 --- a/mistralrs-server/src/completions.rs +++ b/mistralrs-server/src/completions.rs @@ -49,7 +49,7 @@ impl futures::Stream for Streamer { if self.is_done { return Poll::Ready(None); } - match self.rx.try_recv() { + match self.rx.recv() { Ok(resp) => match resp { Response::CompletionModelError(msg, _) => { MistralRs::maybe_log_error( From e8c7764a4c6f22d11044e63b94bbe6f9f7c0549c Mon Sep 17 00:00:00 2001 From: EricLBuehler Date: Thu, 11 Apr 2024 16:24:34 -0400 Subject: [PATCH 09/15] Validate logprobs, streaming --- mistralrs-core/src/engine/mod.rs | 16 --- mistralrs-core/src/response.rs | 7 -- mistralrs-core/src/sequence.rs | 18 +-- mistralrs-pyo3/src/lib.rs | 1 - mistralrs-server/src/chat_completion.rs | 2 - mistralrs-server/src/completions.rs | 134 +++++++---------------- mistralrs-server/src/interactive_mode.rs | 1 - mistralrs-server/src/openai.rs | 4 +- 8 files changed, 40 insertions(+), 143 deletions(-) diff --git a/mistralrs-core/src/engine/mod.rs b/mistralrs-core/src/engine/mod.rs index 0ae833288..e6aeb3050 100644 --- a/mistralrs-core/src/engine/mod.rs +++ b/mistralrs-core/src/engine/mod.rs @@ -182,22 +182,6 @@ impl Engine { seq.get_mut_group() .maybe_send_streaming_response(seq, pipeline.name()); } - } else if seq.get_mut_group().is_streaming { - let tokenizer = pipeline.tokenizer().clone(); - if let Some(mut delta) = - handle_seq_error!(seq.get_delta(&tokenizer), seq.responder()) - { - let seq_is_done = is_done.is_some(); - if let Some(reason) = is_done { - seq.set_state(SequenceState::Done(reason)); - if let Some(ref suffix) = seq.suffix { - delta = delta + suffix; - } - } - - seq.get_mut_group() - .maybe_send_completion_streaming_response(seq, delta, seq_is_done); - } } 
else if let Some(reason) = is_done { Self::finish_seq(pipeline, seq, reason); pipeline.reset_non_granular_state(); diff --git a/mistralrs-core/src/response.rs b/mistralrs-core/src/response.rs index 809d1505f..7bf864752 100644 --- a/mistralrs-core/src/response.rs +++ b/mistralrs-core/src/response.rs @@ -105,12 +105,6 @@ pub struct CompletionResponse { pub usage: Usage, } -#[derive(Debug, Clone, Serialize)] -pub struct CompletionChunkResponse { - pub data: String, - pub done: bool, -} - pub enum Response { InternalError(Box), ValidationError(Box), @@ -121,5 +115,4 @@ pub enum Response { // Completion CompletionModelError(String, CompletionResponse), CompletionDone(CompletionResponse), - CompletionChunk(CompletionChunkResponse), } diff --git a/mistralrs-core/src/sequence.rs b/mistralrs-core/src/sequence.rs index f25307b0e..63796e904 100644 --- a/mistralrs-core/src/sequence.rs +++ b/mistralrs-core/src/sequence.rs @@ -7,7 +7,7 @@ use std::{ use crate::{ aici::{cfg::CfgParser, recognizer::StackRecognizer, rx::RecRx}, - response::{CompletionChoice, CompletionChunkResponse}, + response::CompletionChoice, CompletionResponse, }; use crate::{ @@ -451,20 +451,4 @@ impl SequenceGroup { sender.send(Response::CompletionDone(response)).unwrap(); } } - - pub fn maybe_send_completion_streaming_response( - &mut self, - seq: &Sequence, - chunk: String, - is_done: bool, - ) { - if self.is_streaming { - seq.responder() - .send(Response::CompletionChunk(CompletionChunkResponse { - data: chunk, - done: is_done, - })) - .unwrap(); - } - } } diff --git a/mistralrs-pyo3/src/lib.rs b/mistralrs-pyo3/src/lib.rs index 241df5aea..2e30ed69f 100644 --- a/mistralrs-pyo3/src/lib.rs +++ b/mistralrs-pyo3/src/lib.rs @@ -165,7 +165,6 @@ impl Runner { } Response::ModelError(msg, _) => Err(PyValueError::new_err(msg.to_string())), Response::Chunk(_) => unreachable!(), - Response::CompletionChunk(_) => unreachable!(), Response::CompletionDone(_) => unreachable!(), Response::CompletionModelError(_, _) => unreachable!(), } diff --git a/mistralrs-server/src/chat_completion.rs b/mistralrs-server/src/chat_completion.rs index bdff00071..d411b25e3 100644 --- a/mistralrs-server/src/chat_completion.rs +++ b/mistralrs-server/src/chat_completion.rs @@ -73,7 +73,6 @@ impl futures::Stream for Streamer { Poll::Ready(Some(Event::default().json_data(response))) } Response::Done(_) => unreachable!(), - Response::CompletionChunk(_) => unreachable!(), Response::CompletionDone(_) => unreachable!(), Response::CompletionModelError(_, _) => unreachable!(), }, @@ -259,7 +258,6 @@ pub async fn chatcompletions( ChatCompletionResponder::Json(response) } Response::Chunk(_) => unreachable!(), - Response::CompletionChunk(_) => unreachable!(), Response::CompletionDone(_) => unreachable!(), Response::CompletionModelError(_, _) => unreachable!(), } diff --git a/mistralrs-server/src/completions.rs b/mistralrs-server/src/completions.rs index 38b92b601..5e1316601 100644 --- a/mistralrs-server/src/completions.rs +++ b/mistralrs-server/src/completions.rs @@ -1,24 +1,16 @@ use std::{ - env, error::Error, - pin::Pin, sync::{ - mpsc::{channel, Receiver, Sender}, + mpsc::{channel, Sender}, Arc, }, - task::{Context, Poll}, - time::Duration, }; use crate::openai::{CompletionRequest, Grammar, StopTokens}; -use anyhow::Result; use axum::{ extract::{Json, State}, http::{self, StatusCode}, - response::{ - sse::{Event, KeepAlive}, - IntoResponse, Sse, - }, + response::IntoResponse, }; use either::Either; use mistralrs_core::{ @@ -36,54 +28,7 @@ impl std::fmt::Display for 
ModelErrorMessage { } } impl std::error::Error for ModelErrorMessage {} -pub struct Streamer { - rx: Receiver, - is_done: bool, - state: Arc, -} - -impl futures::Stream for Streamer { - type Item = Result; - - fn poll_next(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { - if self.is_done { - return Poll::Ready(None); - } - match self.rx.recv() { - Ok(resp) => match resp { - Response::CompletionModelError(msg, _) => { - MistralRs::maybe_log_error( - self.state.clone(), - &ModelErrorMessage(msg.to_string()), - ); - Poll::Ready(Some(Ok(Event::default().data(msg)))) - } - Response::ValidationError(e) => { - Poll::Ready(Some(Ok(Event::default().data(e.to_string())))) - } - Response::InternalError(e) => { - MistralRs::maybe_log_error(self.state.clone(), &*e); - Poll::Ready(Some(Ok(Event::default().data(e.to_string())))) - } - Response::CompletionChunk(response) => { - if response.done { - self.is_done = true; - } - MistralRs::maybe_log_response(self.state.clone(), &response); - Poll::Ready(Some(Ok(Event::default().data(response.data)))) - } - Response::CompletionDone(_) => unreachable!(), - Response::Chunk(_) => unreachable!(), - Response::Done(_) => unreachable!(), - Response::ModelError(_, _) => unreachable!(), - }, - Err(_) => Poll::Pending, - } - } -} - pub enum CompletionResponder { - Sse(Sse), Json(CompletionResponse), ModelError(String, CompletionResponse), InternalError(Box), @@ -130,7 +75,6 @@ impl ErrorToResponse for JsonModelError {} impl IntoResponse for CompletionResponder { fn into_response(self) -> axum::response::Response { match self { - CompletionResponder::Sse(s) => s.into_response(), CompletionResponder::Json(s) => Json(s).into_response(), CompletionResponder::InternalError(e) => { JsonError::new(e.to_string()).to_response(http::StatusCode::INTERNAL_SERVER_ERROR) @@ -164,6 +108,10 @@ fn parse_request( warn!("Completion requests do not support logprobs."); } + if oairequest._stream.is_some_and(|x| x) { + warn!("Completion requests do not support streaming."); + } + Request { id: state.next_request_id(), messages: Either::Right(oairequest.prompt), @@ -181,7 +129,7 @@ fn parse_request( }, response: tx, return_logprobs: false, - is_streaming: oairequest.stream.unwrap_or(false), + is_streaming: false, suffix: oairequest.suffix, constraint: match oairequest.grammar { @@ -209,48 +157,40 @@ pub async fn completions( let request = parse_request(oairequest, state.clone(), tx); let is_streaming = request.is_streaming; let sender = state.get_sender(); - sender.send(request).unwrap(); + + if request.return_logprobs { + return CompletionResponder::ValidationError( + "Completion requests do not support logprobs.".into(), + ); + } if is_streaming { - let streamer = Streamer { - rx, - is_done: false, - state, - }; + return CompletionResponder::ValidationError( + "Completion requests do not support streaming.".into(), + ); + } - CompletionResponder::Sse( - Sse::new(streamer).keep_alive( - KeepAlive::new() - .interval(Duration::from_millis( - env::var("KEEP_ALIVE_INTERVAL") - .map(|val| val.parse::().unwrap_or(1000)) - .unwrap_or(1000), - )) - .text("keep-alive-text"), - ), - ) - } else { - let response = rx.recv().unwrap(); + sender.send(request).unwrap(); - match response { - Response::InternalError(e) => { - MistralRs::maybe_log_error(state, &*e); - CompletionResponder::InternalError(e) - } - Response::CompletionModelError(msg, response) => { - MistralRs::maybe_log_error(state.clone(), &ModelErrorMessage(msg.to_string())); - MistralRs::maybe_log_response(state, &response); - 
CompletionResponder::ModelError(msg, response) - } - Response::ValidationError(e) => CompletionResponder::ValidationError(e), - Response::CompletionDone(response) => { - MistralRs::maybe_log_response(state, &response); - CompletionResponder::Json(response) - } - Response::CompletionChunk(_) => unreachable!(), - Response::Chunk(_) => unreachable!(), - Response::Done(_) => unreachable!(), - Response::ModelError(_, _) => unreachable!(), + let response = rx.recv().unwrap(); + + match response { + Response::InternalError(e) => { + MistralRs::maybe_log_error(state, &*e); + CompletionResponder::InternalError(e) + } + Response::CompletionModelError(msg, response) => { + MistralRs::maybe_log_error(state.clone(), &ModelErrorMessage(msg.to_string())); + MistralRs::maybe_log_response(state, &response); + CompletionResponder::ModelError(msg, response) + } + Response::ValidationError(e) => CompletionResponder::ValidationError(e), + Response::CompletionDone(response) => { + MistralRs::maybe_log_response(state, &response); + CompletionResponder::Json(response) } + Response::Chunk(_) => unreachable!(), + Response::Done(_) => unreachable!(), + Response::ModelError(_, _) => unreachable!(), } } diff --git a/mistralrs-server/src/interactive_mode.rs b/mistralrs-server/src/interactive_mode.rs index ab0d1f836..069de49bd 100644 --- a/mistralrs-server/src/interactive_mode.rs +++ b/mistralrs-server/src/interactive_mode.rs @@ -82,7 +82,6 @@ pub fn interactive_mode(mistralrs: Arc) { break 'outer; } Response::Done(_) => unreachable!(), - Response::CompletionChunk(_) => unreachable!(), Response::CompletionDone(_) => unreachable!(), Response::CompletionModelError(_, _) => unreachable!(), } diff --git a/mistralrs-server/src/openai.rs b/mistralrs-server/src/openai.rs index 60c467f9f..14bef215c 100644 --- a/mistralrs-server/src/openai.rs +++ b/mistralrs-server/src/openai.rs @@ -122,8 +122,8 @@ pub struct CompletionRequest { #[serde(rename = "stop")] #[schema(example = json!(Option::None::))] pub stop_seqs: Option, - #[schema(example = true)] - pub stream: Option, + #[serde(rename = "stream")] + pub _stream: Option, #[schema(example = 0.7)] pub temperature: Option, #[schema(example = json!(Option::None::))] From 51c2d9e2ef2294361cdd648db8270645a3c0aa17 Mon Sep 17 00:00:00 2001 From: EricLBuehler Date: Thu, 11 Apr 2024 16:38:27 -0400 Subject: [PATCH 10/15] Add to python bindings --- mistralrs-pyo3/mistralrs.pyi | 25 ++++++ mistralrs-pyo3/src/lib.rs | 146 +++++++++++++++++++++++++++++++++++ 2 files changed, 171 insertions(+) diff --git a/mistralrs-pyo3/mistralrs.pyi b/mistralrs-pyo3/mistralrs.pyi index 852fa60ce..a296595d9 100644 --- a/mistralrs-pyo3/mistralrs.pyi +++ b/mistralrs-pyo3/mistralrs.pyi @@ -25,6 +25,31 @@ class ChatCompletionRequest: grammar: str | None = None grammar_type: str | None = None +@dataclass +class CompletionRequest: + """ + A CompletionRequest represents a request sent to the mistral.rs engine. It encodes information + about input data, sampling, and how to return the response. 
+ """ + + prompt: str + model: str + best_of: int + echo_prompt: bool + logit_bias: dict[int, float] | None = None + max_tokens: int | None = None + n_choices: int = 1 + presence_penalty: float | None = None + frequency_penalty: float | None = None + stop_token_ids: list[int] | None = None + temperature: float | None = None + top_p: float | None = None + stream: bool = False + top_k: int | None = None + suffix: str | None = None + grammar: str | None = None + grammar_type: str | None = None + class Runner: """ The Runner is a class with no constructor. It is only created via one of the loader classes. diff --git a/mistralrs-pyo3/src/lib.rs b/mistralrs-pyo3/src/lib.rs index 2e30ed69f..ac02b36fc 100644 --- a/mistralrs-pyo3/src/lib.rs +++ b/mistralrs-pyo3/src/lib.rs @@ -170,6 +170,152 @@ impl Runner { } }) } + + /// Send an OpenAI API compatible request, returning raw JSON. + fn send_completion_request(&mut self, request: Py) -> PyResult { + let (tx, rx) = channel(); + Python::with_gil(|py| { + let request = request.bind(py).borrow(); + let stop_toks = request + .stop_token_ids + .as_ref() + .map(|x| StopTokens::Ids(x.to_vec())); + let constraint = if request.grammar_type == Some("regex".to_string()) { + if request.grammar.is_none() { + return Err(PyValueError::new_err( + "Grammar type is specified but not grammar text", + )); + } + Constraint::Regex(request.grammar.as_ref().unwrap().clone()) + } else if request.grammar_type == Some("yacc".to_string()) { + if request.grammar.is_none() { + return Err(PyValueError::new_err( + "Grammar type is specified but not grammar text", + )); + } + Constraint::Yacc(request.grammar.as_ref().unwrap().clone()) + } else if request.grammar_type.is_some() { + return Err(PyValueError::new_err( + "Grammar type is specified but is not `regex` or `yacc`", + )); + } else { + Constraint::None + }; + let model_request = _Request { + id: { + let l = NEXT_REQUEST_ID.lock().unwrap(); + let last = &mut *l.borrow_mut(); + let last_v = *last; + *last += 1; + last_v + }, + messages: Either::Right(request.prompt.clone()), + sampling_params: SamplingParams { + temperature: request.temperature, + top_k: request.top_k, + top_p: request.top_p, + top_n_logprobs: 1, + frequency_penalty: request.frequency_penalty, + presence_penalty: request.presence_penalty, + max_len: request.max_tokens, + stop_toks, + logits_bias: request.logit_bias.clone(), + n_choices: request.n_choices, + }, + response: tx, + return_logprobs: false, + is_streaming: false, + constraint, + request_type: RequestType::Completion, + suffix: request.suffix.clone(), + }; + + MistralRs::maybe_log_request(self.runner.clone(), format!("{request:?}")); + let sender = self.runner.get_sender(); + sender.send(model_request).unwrap(); + let response = rx.recv().unwrap(); + + match response { + Response::ValidationError(e) | Response::InternalError(e) => { + Err(PyValueError::new_err(e.to_string())) + } + Response::Done(response) => { + MistralRs::maybe_log_response(self.runner.clone(), &response); + Ok(serde_json::to_string(&response).unwrap()) + } + Response::ModelError(msg, _) => Err(PyValueError::new_err(msg.to_string())), + Response::Chunk(_) => unreachable!(), + Response::CompletionDone(_) => unreachable!(), + Response::CompletionModelError(_, _) => unreachable!(), + } + }) + } +} + +#[pyclass] +#[derive(Debug)] +/// An OpenAI API compatible completion request. 
+struct CompletionRequest { + _model: String, + prompt: String, + best_of: usize, + echo_prompt: bool, + presence_penalty: Option, + frequency_penalty: Option, + logit_bias: Option>, + max_tokens: Option, + n_choices: usize, + stop_token_ids: Option>, + temperature: Option, + top_p: Option, + suffix: Option, + top_k: Option, + grammar: Option, + grammar_type: Option, +} + +#[pymethods] +impl CompletionRequest { + #[new] + #[pyo3(signature = (prompt, model, best_of = 1, echo_prompt = false, presence_penalty=None,frequency_penalty=None,logit_bias=None,max_tokens=None,n_choices=1,stop_token_ids=None,temperature=None,top_p=None,suffix=None,top_k=None, grammar = None, grammar_type = None))] + #[allow(clippy::too_many_arguments)] + fn new( + prompt: String, + model: String, + best_of: usize, + echo_prompt: bool, + presence_penalty: Option, + frequency_penalty: Option, + logit_bias: Option>, + max_tokens: Option, + n_choices: usize, + stop_token_ids: Option>, + temperature: Option, + top_p: Option, + suffix: Option, + top_k: Option, + grammar: Option, + grammar_type: Option, + ) -> PyResult { + Ok(Self { + prompt, + best_of, + echo_prompt, + suffix, + _model: model, + logit_bias, + max_tokens, + n_choices, + presence_penalty, + frequency_penalty, + stop_token_ids, + temperature, + top_p, + top_k, + grammar, + grammar_type, + }) + } } #[pyclass] From 6c2d76c230f41e60ac570b938d1f1e8bc4e53ba0 Mon Sep 17 00:00:00 2001 From: EricLBuehler Date: Thu, 11 Apr 2024 16:54:12 -0400 Subject: [PATCH 11/15] Implement echo_prompt --- mistralrs-core/src/engine/mod.rs | 9 +++++++++ mistralrs-core/src/request.rs | 4 ++-- mistralrs-core/src/sequence.rs | 3 +++ mistralrs-pyo3/src/lib.rs | 4 +++- mistralrs-server/src/completions.rs | 4 +++- 5 files changed, 20 insertions(+), 4 deletions(-) diff --git a/mistralrs-core/src/engine/mod.rs b/mistralrs-core/src/engine/mod.rs index e6aeb3050..030fbebd8 100644 --- a/mistralrs-core/src/engine/mod.rs +++ b/mistralrs-core/src/engine/mod.rs @@ -556,6 +556,15 @@ impl Engine { now.as_secs(), recognizer.clone(), request.suffix.clone(), + if let RequestType::Completion { echo_prompt } = request.request_type.clone() { + if echo_prompt { + Some(formatted_prompt.clone()) + } else { + None + } + } else { + None + }, ); let seq = if let Some(prefill_cache) = prefill_cache.clone() { match prefill_cache { diff --git a/mistralrs-core/src/request.rs b/mistralrs-core/src/request.rs index 9d458e1a6..5b8258777 100644 --- a/mistralrs-core/src/request.rs +++ b/mistralrs-core/src/request.rs @@ -10,10 +10,10 @@ pub enum Constraint { None, } -#[derive(Debug, PartialEq)] +#[derive(Debug, PartialEq, Clone)] pub enum RequestType { Chat, - Completion, + Completion { echo_prompt: bool }, } pub struct Request { diff --git a/mistralrs-core/src/sequence.rs b/mistralrs-core/src/sequence.rs index 63796e904..be166cde5 100644 --- a/mistralrs-core/src/sequence.rs +++ b/mistralrs-core/src/sequence.rs @@ -70,6 +70,7 @@ pub struct Sequence { creation_time: u64, prefill_prompt_toks: Option>, pub suffix: Option, + pub prefix: Option, // Cache scaling_cache: Option, @@ -110,6 +111,7 @@ impl Sequence { creation_time: u64, recognizer: SequenceRecognizer, suffix: Option, + prefix: Option, ) -> Self { let prompt_len = tokens.len(); Self { @@ -142,6 +144,7 @@ impl Sequence { recognizer, prefill_prompt_toks: None, suffix, + prefix, } } diff --git a/mistralrs-pyo3/src/lib.rs b/mistralrs-pyo3/src/lib.rs index ac02b36fc..486a677f9 100644 --- a/mistralrs-pyo3/src/lib.rs +++ b/mistralrs-pyo3/src/lib.rs @@ -226,7 +226,9 @@ impl 
Runner { return_logprobs: false, is_streaming: false, constraint, - request_type: RequestType::Completion, + request_type: RequestType::Completion { + echo_prompt: request.echo_prompt, + }, suffix: request.suffix.clone(), }; diff --git a/mistralrs-server/src/completions.rs b/mistralrs-server/src/completions.rs index 5e1316601..091446df5 100644 --- a/mistralrs-server/src/completions.rs +++ b/mistralrs-server/src/completions.rs @@ -138,7 +138,9 @@ fn parse_request( None => Constraint::None, }, - request_type: RequestType::Completion, + request_type: RequestType::Completion { + echo_prompt: oairequest.echo_prompt, + }, } } From 4a6091a0897ead83d9320d58bfee8dd0b6e7b75d Mon Sep 17 00:00:00 2001 From: EricLBuehler Date: Thu, 11 Apr 2024 20:11:16 -0400 Subject: [PATCH 12/15] Implement best_of for completion --- mistralrs-core/src/engine/mod.rs | 1 + mistralrs-core/src/request.rs | 1 + mistralrs-core/src/sequence.rs | 31 ++++++++++++++++++++---- mistralrs-pyo3/mistralrs.pyi | 1 + mistralrs-pyo3/src/lib.rs | 2 ++ mistralrs-server/src/chat_completion.rs | 1 + mistralrs-server/src/completions.rs | 1 + mistralrs-server/src/interactive_mode.rs | 1 + 8 files changed, 34 insertions(+), 5 deletions(-) diff --git a/mistralrs-core/src/engine/mod.rs b/mistralrs-core/src/engine/mod.rs index 030fbebd8..24f170881 100644 --- a/mistralrs-core/src/engine/mod.rs +++ b/mistralrs-core/src/engine/mod.rs @@ -510,6 +510,7 @@ impl Engine { request.sampling_params.n_choices, request.is_streaming, request.request_type == RequestType::Chat, + request.best_of, ))); let now = SystemTime::now() .duration_since(UNIX_EPOCH) diff --git a/mistralrs-core/src/request.rs b/mistralrs-core/src/request.rs index 5b8258777..c2aad8c31 100644 --- a/mistralrs-core/src/request.rs +++ b/mistralrs-core/src/request.rs @@ -26,6 +26,7 @@ pub struct Request { pub constraint: Constraint, pub request_type: RequestType, pub suffix: Option, + pub best_of: Option, } impl Debug for Request { diff --git a/mistralrs-core/src/sequence.rs b/mistralrs-core/src/sequence.rs index be166cde5..606ffcfb3 100644 --- a/mistralrs-core/src/sequence.rs +++ b/mistralrs-core/src/sequence.rs @@ -81,6 +81,7 @@ pub struct Sequence { tokens: Vec, decoded_tokens: Option>, logprobs: Vec, + cumulative_logprob: f32, // GPU things pub prompt_tok_per_sec: f32, @@ -145,6 +146,7 @@ impl Sequence { prefill_prompt_toks: None, suffix, prefix, + cumulative_logprob: 0., } } @@ -222,6 +224,7 @@ impl Sequence { } pub fn add_token(&mut self, tok: Logprobs) { + self.cumulative_logprob += tok.logprob; self.tokens.push(tok.token); self.logprobs.push(tok); } @@ -338,7 +341,9 @@ impl Sequence { } pub fn add_completion_choice_to_group(&self, choice: CompletionChoice) { - get_mut_group!(self).completion_choices.push(choice); + get_mut_group!(self) + .completion_choices + .push((self.cumulative_logprob, choice)); self.update_time_info(); } @@ -357,6 +362,7 @@ impl Sequence { pub struct SequenceGroup { n_choices: usize, // The target number of choices to return. Can be decreased if an error is thrown. + best_of: Option, // Top n seqs based on cumulative logprobs. 
pub total_prompt_toks: usize, pub total_toks: usize, pub total_prompt_time: u128, @@ -364,14 +370,19 @@ pub struct SequenceGroup { pub total_completion_time: u128, pub total_sampling_time: u128, choices: Vec, - completion_choices: Vec, + completion_choices: Vec<(f32, CompletionChoice)>, pub streaming_chunks: Vec, pub is_streaming: bool, pub is_chat: bool, } impl SequenceGroup { - pub fn new(n_choices: usize, is_streaming: bool, is_chat: bool) -> Self { + pub fn new( + n_choices: usize, + is_streaming: bool, + is_chat: bool, + best_of: Option, + ) -> Self { Self { choices: Vec::new(), completion_choices: Vec::new(), @@ -385,15 +396,25 @@ impl SequenceGroup { streaming_chunks: Vec::new(), is_streaming, is_chat, + best_of, } } + /// This does not apply best_of. pub fn get_choices(&self) -> &[Choice] { &self.choices } - pub fn get_completion_choices(&self) -> &[CompletionChoice] { - &self.completion_choices + /// This applies the best_of. + pub fn get_completion_choices(&self) -> Vec { + let mut choices = self.completion_choices.clone(); + // Sort by descending logprobs + choices.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap()); + choices + .into_iter() + .take(self.best_of.unwrap()) + .map(|(_, x)| x) + .collect::>() } pub fn get_usage(&self) -> Usage { diff --git a/mistralrs-pyo3/mistralrs.pyi b/mistralrs-pyo3/mistralrs.pyi index a296595d9..ab37ad7ba 100644 --- a/mistralrs-pyo3/mistralrs.pyi +++ b/mistralrs-pyo3/mistralrs.pyi @@ -39,6 +39,7 @@ class CompletionRequest: logit_bias: dict[int, float] | None = None max_tokens: int | None = None n_choices: int = 1 + best_of: int = 1 presence_penalty: float | None = None frequency_penalty: float | None = None stop_token_ids: list[int] | None = None diff --git a/mistralrs-pyo3/src/lib.rs b/mistralrs-pyo3/src/lib.rs index 486a677f9..59d0b1193 100644 --- a/mistralrs-pyo3/src/lib.rs +++ b/mistralrs-pyo3/src/lib.rs @@ -148,6 +148,7 @@ impl Runner { constraint, request_type: RequestType::Chat, suffix: None, + best_of: None, }; MistralRs::maybe_log_request(self.runner.clone(), format!("{request:?}")); @@ -230,6 +231,7 @@ impl Runner { echo_prompt: request.echo_prompt, }, suffix: request.suffix.clone(), + best_of: Some(request.best_of), }; MistralRs::maybe_log_request(self.runner.clone(), format!("{request:?}")); diff --git a/mistralrs-server/src/chat_completion.rs b/mistralrs-server/src/chat_completion.rs index d411b25e3..053763e39 100644 --- a/mistralrs-server/src/chat_completion.rs +++ b/mistralrs-server/src/chat_completion.rs @@ -193,6 +193,7 @@ fn parse_request( return_logprobs: oairequest.logprobs, is_streaming: oairequest.stream.unwrap_or(false), suffix: None, + best_of: None, constraint: match oairequest.grammar { Some(Grammar::Yacc(yacc)) => Constraint::Yacc(yacc), diff --git a/mistralrs-server/src/completions.rs b/mistralrs-server/src/completions.rs index 091446df5..a19feb55b 100644 --- a/mistralrs-server/src/completions.rs +++ b/mistralrs-server/src/completions.rs @@ -131,6 +131,7 @@ fn parse_request( return_logprobs: false, is_streaming: false, suffix: oairequest.suffix, + best_of: Some(oairequest.best_of), constraint: match oairequest.grammar { Some(Grammar::Yacc(yacc)) => Constraint::Yacc(yacc), diff --git a/mistralrs-server/src/interactive_mode.rs b/mistralrs-server/src/interactive_mode.rs index 069de49bd..82553f6c2 100644 --- a/mistralrs-server/src/interactive_mode.rs +++ b/mistralrs-server/src/interactive_mode.rs @@ -48,6 +48,7 @@ pub fn interactive_mode(mistralrs: Arc) { constraint: Constraint::None, request_type: RequestType::Chat, suffix: 
None, + best_of: None, }; sender.send(req).unwrap(); From 73f6ae5d02d123eeb1511634e02702935796ab8d Mon Sep 17 00:00:00 2001 From: EricLBuehler Date: Thu, 11 Apr 2024 20:19:00 -0400 Subject: [PATCH 13/15] Update --- examples/server/completion_streaming.py | 23 ----------------------- 1 file changed, 23 deletions(-) delete mode 100644 examples/server/completion_streaming.py diff --git a/examples/server/completion_streaming.py b/examples/server/completion_streaming.py deleted file mode 100644 index d25914971..000000000 --- a/examples/server/completion_streaming.py +++ /dev/null @@ -1,23 +0,0 @@ -import openai -import sys - -openai.api_key = "EMPTY" - -openai.base_url = "http://localhost:1234/v1/" - -eos_toks = ["", "", "<|endoftext|>"] - -while True: - prompt = input(">>> ") - response = openai.completions.create( - model="mistral", - prompt=prompt, - max_tokens=256, - stream=True, - ) - resp = "" - for chunk in response: - delta = chunk.choices[0].delta.content - if delta not in eos_toks: - print(delta, end="") - sys.stdout.flush() From 372f843bb6236ec2273177c42fb3fd437751826e Mon Sep 17 00:00:00 2001 From: EricLBuehler Date: Thu, 11 Apr 2024 20:29:46 -0400 Subject: [PATCH 14/15] Update docs --- examples/http.md | 39 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 39 insertions(+) diff --git a/examples/http.md b/examples/http.md index 433654385..c9b303c26 100644 --- a/examples/http.md +++ b/examples/http.md @@ -74,6 +74,45 @@ Example with `curl`: curl http://localhost:/docs ``` +## `POST`: `/v1/completions` +Process an OpenAI compatible completions request, returning an OpenAI compatible response when finished. Please find the official OpenAI API documentation [here](https://platform.openai.com/docs/api-reference/completions). + +To send a request with the Python `openai` library: + +```python +import openai + +client = openai.OpenAI( + base_url="http://localhost:8080/v1", # "http://:port" + api_key = "EMPTY" +) + +completion = client.completions.create( + model="mistral", + prompt="What is Rust?", + max_tokens=256, + frequency_penalty=1.0, + top_p=0.1, + temperature=0, +) + +print(completion.choices[0].message) +``` + +Or with `curl`: +```bash +curl http://localhost:8080/v1/chat/completions \ +-H "Content-Type: application/json" \ +-H "Authorization: Bearer EMPTY" \ +-d '{ +"model": "", +"prompt": "What is Rust"? +] +}' +``` + +Streaming requests are not supported. + ## Request ### `ChatCompletionRequest` OpenAI compatible request. From 78ef30491ae63dbeb13f61a2ebbc98d402ea77f3 Mon Sep 17 00:00:00 2001 From: EricLBuehler Date: Thu, 11 Apr 2024 20:30:35 -0400 Subject: [PATCH 15/15] Update docs --- examples/http.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/http.md b/examples/http.md index c9b303c26..2c1e04253 100644 --- a/examples/http.md +++ b/examples/http.md @@ -106,7 +106,7 @@ curl http://localhost:8080/v1/chat/completions \ -H "Authorization: Bearer EMPTY" \ -d '{ "model": "", -"prompt": "What is Rust"? +"prompt": "What is Rust?" ] }' ```
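
---

For the Python bindings added in this series, a minimal usage sketch follows. It assumes `runner` is a `mistralrs.Runner` that has already been constructed via one of the loader classes, that `CompletionRequest` is exported from the `mistralrs` module as declared in `mistralrs.pyi`, and that `send_completion_request` returns the completion response serialized as a JSON string; the loader construction and the exact response shape are assumptions, not something demonstrated end-to-end in these patches.

```python
import json

# Assumption: CompletionRequest is exported from the mistralrs module,
# mirroring how ChatCompletionRequest is exposed in mistralrs.pyi.
from mistralrs import CompletionRequest

# Assumption: `runner` is a mistralrs.Runner already built from one of the
# loader classes (see mistralrs-pyo3/API.md); construction is omitted here.
request = CompletionRequest(
    prompt="What is Rust?",
    model="mistral",
    echo_prompt=False,
    best_of=1,
    max_tokens=256,
    temperature=0.1,
    top_p=0.1,
)

# send_completion_request returns the raw response as a JSON string.
raw = runner.send_completion_request(request)
response = json.loads(raw)

# CompletionChoice serializes `text` and renames `stopreason` to `finish_reason`.
print(response["choices"][0]["text"])
```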