Commit
Merge pull request #107 from EricLBuehler/completion_api
Add the /v1/completion endpoint
EricLBuehler authored Apr 12, 2024
2 parents 3e169d7 + 78ef304 commit a36d4a1
Showing 20 changed files with 1,001 additions and 344 deletions.
45 changes: 42 additions & 3 deletions examples/http.md
@@ -74,6 +74,45 @@ Example with `curl`:
curl http://localhost:<port>/docs
```

## `POST`: `/v1/completions`
Process an OpenAI-compatible completions request, returning an OpenAI-compatible response when generation finishes. The official OpenAI API documentation for this endpoint can be found [here](https://platform.openai.com/docs/api-reference/completions).

To send a request with the Python `openai` library:

```python
import openai

client = openai.OpenAI(
    base_url="http://localhost:8080/v1",  # "http://<Your api-server IP>:port"
    api_key="EMPTY",
)

completion = client.completions.create(
    model="mistral",
    prompt="What is Rust?",
    max_tokens=256,
    frequency_penalty=1.0,
    top_p=0.1,
    temperature=0,
)

print(completion.choices[0].text)
```

Or with `curl`:
```bash
curl http://localhost:8080/v1/completions \
  -H "Content-Type: application/json" \
  -H "Authorization: Bearer EMPTY" \
  -d '{
    "model": "",
    "prompt": "What is Rust?"
  }'
```

Streaming requests are not supported.

## Request
### `ChatCompletionRequest`
OpenAI compatible request.
@@ -134,7 +173,7 @@ pub struct ChatCompletionResponse {
pub model: &'static str,
pub system_fingerprint: String,
pub object: String,
pub usage: ChatCompletionUsage,
pub usage: Usage,
}
```

@@ -186,9 +225,9 @@ pub struct TopLogprob {
}
```

### `ChatCompletionUsage`
### `Usage`
```rust
pub struct ChatCompletionUsage {
pub struct Usage {
pub completion_tokens: usize,
pub prompt_tokens: usize,
pub total_tokens: usize,
6 changes: 3 additions & 3 deletions examples/server/prompt.py → examples/server/completion.py
@@ -39,15 +39,15 @@ def log_response(response: httpx.Response):

while True:
prompt = input(">>> ")
completion = openai.chat.completions.create(
completion = openai.completions.create(
model="mistral",
messages=prompt,
prompt=prompt,
max_tokens=256,
frequency_penalty=1.0,
top_p=0.1,
temperature=0,
)
resp = completion.choices[0].message.content
resp = completion.choices[0].text
for eos in eos_toks:
if resp.endswith(eos):
out = resp[: -len(eos)]
4 changes: 3 additions & 1 deletion examples/server/streaming.py
@@ -1,9 +1,10 @@
import openai
import sys

openai.api_key = "EMPTY"

openai.base_url = "http://localhost:1234/v1/"
# """

messages = []
prompt = input("Enter system prompt >>> ")
if len(prompt) > 0:
@@ -25,6 +26,7 @@
delta = chunk.choices[0].delta.content
if delta not in eos_toks:
print(delta, end="")
sys.stdout.flush()
resp += delta
for eos in eos_toks:
if resp.endswith(eos):
4 changes: 2 additions & 2 deletions integrations/llama_index_integration.py
@@ -365,7 +365,7 @@ def chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatResponse:
top_k=self.generate_kwargs["top_k"],
top_p=self.generate_kwargs["top_p"],
presence_penalty=self.generate_kwargs.get("presence_penalty", None),
repetition_penalty=self.generate_kwargs.get("repetition_penalty", None),
frequency_penalty=self.generate_kwargs.get("frequency_penalty", None),
temperature=self.generate_kwargs.get("temperature", None),
)
completion_response = self._runner.send_chat_completion_request(request)
@@ -399,7 +399,7 @@ def complete(
top_k=self.generate_kwargs["top_k"],
top_p=self.generate_kwargs["top_p"],
presence_penalty=self.generate_kwargs.get("presence_penalty", None),
repetition_penalty=self.generate_kwargs.get("repetition_penalty", None),
frequency_penalty=self.generate_kwargs.get("frequency_penalty", None),
temperature=self.generate_kwargs.get("temperature", None),
)
completion_response = self._runner.send_chat_completion_request(request)
89 changes: 65 additions & 24 deletions mistralrs-core/src/engine/mod.rs
@@ -7,7 +7,11 @@ use std::{
time::{Instant, SystemTime, UNIX_EPOCH},
};

use crate::aici::{cfg::CfgParser, recognizer::StackRecognizer, rx::RecRx};
use crate::{
aici::{cfg::CfgParser, recognizer::StackRecognizer, rx::RecRx},
response::CompletionChoice,
CompletionResponse, RequestType,
};
use candle_core::Tensor;
use either::Either;
use tracing::warn;
@@ -149,7 +153,7 @@ impl Engine {
seq.add_token(next_token.clone());
let is_done = seq.is_done(next_token_id, eos_tok, pipeline.get_max_seq_len());
// Handle streaming requests
if seq.get_mut_group().is_streaming {
if seq.get_mut_group().is_streaming && seq.get_mut_group().is_chat {
let tokenizer = pipeline.tokenizer().clone();
if let Some(delta) = handle_seq_error!(seq.get_delta(&tokenizer), seq.responder()) {
seq.add_streaming_chunk_choice_to_group(ChunkChoice {
@@ -215,30 +219,55 @@
seq.responder()
);

let choice = Choice {
stopreason: reason.to_string(),
index: seq.get_response_index(),
message: ResponseMessage {
content: res,
role: "assistant".to_string(),
},
logprobs: logprobs.map(|l| Logprobs { content: Some(l) }),
};
seq.add_choice_to_group(choice);
if seq.get_mut_group().is_chat {
let choice = Choice {
stopreason: reason.to_string(),
index: seq.get_response_index(),
message: ResponseMessage {
content: res,
role: "assistant".to_string(),
},
logprobs: logprobs.map(|l| Logprobs { content: Some(l) }),
};
seq.add_choice_to_group(choice);
} else {
let choice = CompletionChoice {
stopreason: reason.to_string(),
index: seq.get_response_index(),
text: res,
logprobs: None,
};
seq.add_completion_choice_to_group(choice);
}

let group = seq.get_mut_group();
group.maybe_send_done_response(
ChatCompletionResponse {
id: seq.id().to_string(),
choices: group.get_choices().to_vec(),
created: seq.creation_time(),
model: pipeline.name(),
system_fingerprint: SYSTEM_FINGERPRINT.to_string(),
object: "chat.completion".to_string(),
usage: group.get_usage(),
},
seq.responder(),
);
if group.is_chat {
group.maybe_send_done_response(
ChatCompletionResponse {
id: seq.id().to_string(),
choices: group.get_choices().to_vec(),
created: seq.creation_time(),
model: pipeline.name(),
system_fingerprint: SYSTEM_FINGERPRINT.to_string(),
object: "chat.completion".to_string(),
usage: group.get_usage(),
},
seq.responder(),
);
} else {
group.maybe_send_completion_done_response(
CompletionResponse {
id: seq.id().to_string(),
choices: group.get_completion_choices().to_vec(),
created: seq.creation_time(),
model: pipeline.name(),
system_fingerprint: SYSTEM_FINGERPRINT.to_string(),
object: "text_completion".to_string(),
usage: group.get_usage(),
},
seq.responder(),
);
}
}

/// Clone the cache FROM the sequences' cache TO the model cache. Only used for completion seqs.
@@ -480,6 +509,8 @@ impl Engine {
let group = Rc::new(RefCell::new(SequenceGroup::new(
request.sampling_params.n_choices,
request.is_streaming,
request.request_type == RequestType::Chat,
request.best_of,
)));
let now = SystemTime::now()
.duration_since(UNIX_EPOCH)
@@ -525,6 +556,16 @@
response_index,
now.as_secs(),
recognizer.clone(),
request.suffix.clone(),
if let RequestType::Completion { echo_prompt } = request.request_type.clone() {
if echo_prompt {
Some(formatted_prompt.clone())
} else {
None
}
} else {
None
},
);
let seq = if let Some(prefill_cache) = prefill_cache.clone() {
match prefill_cache {
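
Taken together, the engine changes above route a finished sequence group down one of two paths based on the new `is_chat` flag. A minimal sketch of just that mapping; the helper and its tuple return are illustrative, not crate APIs:

```rust
/// Summarizes the branch added above: which choice field carries the output
/// and which OpenAI `object` string is reported. Illustrative only.
fn finished_shape(is_chat: bool) -> (&'static str, &'static str) {
    if is_chat {
        // Pre-existing chat path: a Choice wrapping ResponseMessage,
        // sent back as a ChatCompletionResponse.
        ("choices[].message", "chat.completion")
    } else {
        // New completion path: a CompletionChoice carrying plain text,
        // sent back as a CompletionResponse.
        ("choices[].text", "text_completion")
    }
}

fn main() {
    assert_eq!(finished_shape(true), ("choices[].message", "chat.completion"));
    assert_eq!(finished_shape(false), ("choices[].text", "text_completion"));
}
```
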
4 changes: 2 additions & 2 deletions mistralrs-core/src/lib.rs
@@ -35,9 +35,9 @@ pub use pipeline::{
MistralSpecificConfig, MixtralLoader, MixtralSpecificConfig, ModelKind, Phi2Loader,
Phi2SpecificConfig, TokenSource,
};
pub use request::{Constraint, Request};
pub use request::{Constraint, Request, RequestType};
pub use response::Response;
pub use response::{ChatCompletionResponse, ChatCompletionUsage};
pub use response::{ChatCompletionResponse, CompletionResponse, Usage};
pub use sampler::{SamplingParams, StopTokens};
pub use scheduler::SchedulerMethod;
use serde::Serialize;
13 changes: 11 additions & 2 deletions mistralrs-core/src/request.rs
@@ -10,6 +10,12 @@ pub enum Constraint {
None,
}

#[derive(Debug, PartialEq, Clone)]
pub enum RequestType {
Chat,
Completion { echo_prompt: bool },
}

pub struct Request {
pub messages: Either<Vec<IndexMap<String, String>>, String>,
pub sampling_params: SamplingParams,
@@ -18,14 +24,17 @@ pub struct Request {
pub is_streaming: bool,
pub id: usize,
pub constraint: Constraint,
pub request_type: RequestType,
pub suffix: Option<String>,
pub best_of: Option<usize>,
}

impl Debug for Request {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(
f,
"Request {} {{ messages: `{:?}`, sampling_params: {:?}}}",
self.id, self.messages, self.sampling_params
"Request {} ({:?}) {{ messages: `{:?}`, sampling_params: {:?}}}",
self.id, self.request_type, self.messages, self.sampling_params
)
}
}
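
The `echo_prompt` flag carried by `RequestType::Completion` is what the nested `if let` in the engine hunk consumes. A minimal standalone sketch of that decision, assuming the `RequestType` re-export shown in `lib.rs`; `echoed_prompt` is a hypothetical helper name, not part of this PR:

```rust
use mistralrs_core::RequestType;

/// Only a completion request with `echo_prompt == true` gets the formatted
/// prompt echoed back ahead of the generated text (OpenAI's `echo` behavior).
/// Hypothetical helper for illustration.
fn echoed_prompt(request_type: &RequestType, formatted_prompt: &str) -> Option<String> {
    match request_type {
        RequestType::Completion { echo_prompt: true } => Some(formatted_prompt.to_string()),
        _ => None,
    }
}

fn main() {
    let rt = RequestType::Completion { echo_prompt: true };
    assert_eq!(echoed_prompt(&rt, "What is Rust?").as_deref(), Some("What is Rust?"));
    assert_eq!(echoed_prompt(&RequestType::Chat, "What is Rust?"), None);
}
```
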
28 changes: 26 additions & 2 deletions mistralrs-core/src/response.rs
@@ -50,7 +50,7 @@ pub struct ChunkChoice {
}

#[derive(Debug, Clone, Serialize)]
pub struct ChatCompletionUsage {
pub struct Usage {
pub completion_tokens: usize,
pub prompt_tokens: usize,
pub total_tokens: usize,
@@ -72,7 +72,7 @@ pub struct ChatCompletionResponse {
pub model: String,
pub system_fingerprint: String,
pub object: String,
pub usage: ChatCompletionUsage,
pub usage: Usage,
}

#[derive(Debug, Clone, Serialize)]
@@ -85,10 +85,34 @@ pub struct ChatCompletionChunkResponse {
pub object: String,
}

#[derive(Debug, Clone, Serialize)]
pub struct CompletionChoice {
#[serde(rename = "finish_reason")]
pub stopreason: String,
pub index: usize,
pub text: String,
pub logprobs: Option<()>,
}

#[derive(Debug, Clone, Serialize)]
pub struct CompletionResponse {
pub id: String,
pub choices: Vec<CompletionChoice>,
pub created: u64,
pub model: String,
pub system_fingerprint: String,
pub object: String,
pub usage: Usage,
}

pub enum Response {
InternalError(Box<dyn Error + Send + Sync>),
ValidationError(Box<dyn Error + Send + Sync>),
// Chat
ModelError(String, ChatCompletionResponse),
Done(ChatCompletionResponse),
Chunk(ChatCompletionChunkResponse),
// Completion
CompletionModelError(String, CompletionResponse),
CompletionDone(CompletionResponse),
}
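
Whatever receives a `Response` from the engine now has two extra cases to handle. A hedged sketch of such a consumer (the function and the logging are illustrative; only the enum shape comes from this PR):

```rust
use mistralrs_core::Response;

// Illustrative consumer of the expanded enum; not code from this PR.
fn handle(resp: Response) {
    match resp {
        // Chat path (pre-existing variants)
        Response::Done(r) => println!("chat.completion with {} choice(s)", r.choices.len()),
        Response::Chunk(_) => { /* forward the streaming chunk to the client */ }
        Response::ModelError(msg, _partial) => eprintln!("chat model error: {msg}"),
        // Completion path (new in this PR)
        Response::CompletionDone(r) => println!("text_completion with {} choice(s)", r.choices.len()),
        Response::CompletionModelError(msg, _partial) => eprintln!("completion model error: {msg}"),
        // Shared error variants
        Response::InternalError(e) | Response::ValidationError(e) => eprintln!("request failed: {e}"),
    }
}

fn main() {
    handle(Response::InternalError("example error".into()));
}
```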