feat: allow tokenizing in batch (#6217)
* feat: allow tokenizing in batch

* tokenize batch endpoint

* lib code

* use in front

* nits

---------

Co-authored-by: Henry Fontanier <[email protected]>
fontanierh and Henry Fontanier authored Jul 15, 2024
1 parent 77bda6e commit 8168475
Showing 13 changed files with 223 additions and 98 deletions.
93 changes: 80 additions & 13 deletions core/bin/dust_api.rs
@@ -1179,22 +1179,30 @@ async fn data_sources_tokenize(
let embedder_config = ds.embedder_config().clone();
let embedder =
provider(embedder_config.provider_id).embedder(embedder_config.model_id);
match embedder.tokenize(&payload.text).await {
match embedder.tokenize(vec![payload.text]).await {
Err(e) => error_response(
StatusCode::INTERNAL_SERVER_ERROR,
"internal_server_error",
"Failed to tokenize text",
Some(e),
),
Ok(tokens) => (
StatusCode::OK,
Json(APIResponse {
error: None,
response: Some(json!({
"tokens": tokens,
})),
}),
),
Ok(mut res) => match res.pop() {
None => error_response(
StatusCode::INTERNAL_SERVER_ERROR,
"internal_server_error",
"Failed to tokenize text",
None,
),
Some(tokens) => (
StatusCode::OK,
Json(APIResponse {
error: None,
response: Some(json!({
"tokens": tokens,
})),
}),
),
},
}
}
},
@@ -2407,19 +2415,77 @@ async fn tokenize(Json(payload): Json<TokenizePayload>) -> (StatusCode, Json<APIResponse>) {
None => (),
}

match llm.tokenize(&payload.text).await {
match llm.tokenize(vec![payload.text]).await {
Err(e) => error_response(
StatusCode::INTERNAL_SERVER_ERROR,
"internal_server_error",
"Failed to tokenize text",
Some(e),
),
Ok(mut res) => match res.pop() {
None => error_response(
StatusCode::INTERNAL_SERVER_ERROR,
"internal_server_error",
"Failed to tokenize text",
None,
),
Some(tokens) => (
StatusCode::OK,
Json(APIResponse {
error: None,
response: Some(json!({
"tokens": tokens,
})),
}),
),
},
}
}

#[derive(serde::Deserialize)]
struct TokenizeBatchPayload {
texts: Vec<String>,
provider_id: ProviderID,
model_id: String,
credentials: Option<run::Credentials>,
}

async fn tokenize_batch(
Json(payload): Json<TokenizeBatchPayload>,
) -> (StatusCode, Json<APIResponse>) {
let mut llm = provider(payload.provider_id).llm(payload.model_id);

// If we received credentials we initialize the llm with them.
match payload.credentials {
Some(c) => {
match llm.initialize(c.clone()).await {
Err(e) => {
return error_response(
StatusCode::INTERNAL_SERVER_ERROR,
"internal_server_error",
"Failed to initialize LLM",
Some(e),
);
}
Ok(()) => (),
};
}
None => (),
}

match llm.tokenize(payload.texts).await {
Err(e) => error_response(
StatusCode::INTERNAL_SERVER_ERROR,
"internal_server_error",
"Failed to tokenize text",
Some(e),
),
Ok(tokens) => (
Ok(res) => (
StatusCode::OK,
Json(APIResponse {
error: None,
response: Some(json!({
"tokens": tokens,
"tokens": res,
})),
}),
),
@@ -2609,6 +2675,7 @@ fn main() {
.route("/sqlite_workers", delete(sqlite_workers_delete))
// Misc
.route("/tokenize", post(tokenize))
.route("/tokenize/batch", post(tokenize_batch))

// Extensions
.layer(DefaultBodyLimit::disable())
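Below, a minimal sketch of what a client call to the new POST /tokenize/batch route could look like, based on the TokenizeBatchPayload struct above. The host/port, the provider_id string form, and the model name are assumptions, not taken from this commit.

```rust
// Sketch only: exercising the new POST /tokenize/batch route. The host/port,
// provider_id string form, and model_id below are assumptions, not part of
// this commit.
use serde_json::json;

#[tokio::main]
async fn main() -> anyhow::Result<()> {
    let payload = json!({
        "texts": ["first text to tokenize", "second text to tokenize"],
        "provider_id": "openai",      // assumed ProviderID serialization
        "model_id": "gpt-3.5-turbo",  // assumed model
        // "credentials": { ... }     // optional, mirrors TokenizeBatchPayload
    });

    let res = reqwest::Client::new()
        .post("http://localhost:3001/tokenize/batch") // assumed local core API address
        .json(&payload)
        .send()
        .await?;

    // Per the handler above, the body is { "error": null, "response": { "tokens": [...] } },
    // with one [[token_id, "piece"], ...] list per input text.
    println!("{}", res.text().await?);
    Ok(())
}
```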
17 changes: 9 additions & 8 deletions core/src/providers/anthropic.rs
@@ -1,9 +1,14 @@
use crate::providers::chat_messages::{
AssistantChatMessage, ChatMessage, ContentBlock, MixedContent,
};
use crate::providers::embedder::{Embedder, EmbedderVector};
use crate::providers::llm::{ChatFunction, ChatFunctionCall};
use crate::providers::llm::{
ChatMessageRole, LLMChatGeneration, LLMGeneration, LLMTokenUsage, Tokens, LLM,
};
use crate::providers::provider::{ModelError, ModelErrorRetryOptions, Provider, ProviderID};
use crate::providers::tiktoken::tiktoken::anthropic_base_singleton;
use crate::providers::tiktoken::tiktoken::{batch_tokenize_async, decode_async, encode_async};
use crate::run::Credentials;
use crate::utils;
use crate::utils::ParseError;
@@ -22,10 +27,6 @@ use std::str::FromStr;
use std::time::Duration;
use tokio::sync::mpsc::UnboundedSender;

use super::chat_messages::{AssistantChatMessage, ChatMessage, ContentBlock, MixedContent};
use super::llm::{ChatFunction, ChatFunctionCall};
use super::tiktoken::tiktoken::{decode_async, encode_async, tokenize_async};

#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
#[serde(rename_all = "snake_case")]
pub enum StopReason {
@@ -1524,8 +1525,8 @@ impl LLM for AnthropicLLM {
decode_async(anthropic_base_singleton(), tokens).await
}

async fn tokenize(&self, text: &str) -> Result<Vec<(usize, String)>> {
tokenize_async(anthropic_base_singleton(), text).await
async fn tokenize(&self, texts: Vec<String>) -> Result<Vec<Vec<(usize, String)>>> {
batch_tokenize_async(anthropic_base_singleton(), texts).await
}

async fn chat(
@@ -1692,8 +1693,8 @@ impl Embedder for AnthropicEmbedder {
decode_async(anthropic_base_singleton(), tokens).await
}

async fn tokenize(&self, text: &str) -> Result<Vec<(usize, String)>> {
tokenize_async(anthropic_base_singleton(), text).await
async fn tokenize(&self, texts: Vec<String>) -> Result<Vec<Vec<(usize, String)>>> {
batch_tokenize_async(anthropic_base_singleton(), texts).await
}

async fn embed(&self, _text: Vec<&str>, _extras: Option<Value>) -> Result<Vec<EmbedderVector>> {
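The batch_tokenize_async helper that replaces tokenize_async in these imports lives in core/src/providers/tiktoken/tiktoken.rs, which is not rendered in this view. A rough, hypothetical sketch of the shape its call sites imply (a BPE singleton plus a Vec<String>, one token list per input); the real implementation may differ, for example by tokenizing in parallel.

```rust
// Hypothetical sketch only: the real batch_tokenize_async is defined in
// tiktoken.rs and is not shown in this diff. CoreBPE and tokenize_one are
// stand-ins for the module's existing singleton type and single-text logic.
use std::sync::Arc;

use anyhow::Result;

pub struct CoreBPE; // placeholder for the module's BPE tokenizer

fn tokenize_one(_bpe: &CoreBPE, text: &str) -> Result<Vec<(usize, String)>> {
    // Stand-in: the real code BPE-encodes `text` and pairs each token id with
    // the string piece it decodes back to.
    Ok(text
        .split_whitespace()
        .enumerate()
        .map(|(i, word)| (i, word.to_string()))
        .collect())
}

pub async fn batch_tokenize_async(
    bpe: Arc<CoreBPE>,
    texts: Vec<String>,
) -> Result<Vec<Vec<(usize, String)>>> {
    // One result per input text, preserving input order.
    texts.iter().map(|t| tokenize_one(&bpe, t)).collect()
}
```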
13 changes: 6 additions & 7 deletions core/src/providers/azure_openai.rs
@@ -1,4 +1,5 @@
use crate::providers::chat_messages::AssistantChatMessage;
use crate::providers::chat_messages::ChatMessage;
use crate::providers::embedder::{Embedder, EmbedderVector};
use crate::providers::llm::ChatFunction;
use crate::providers::llm::Tokens;
@@ -8,10 +9,10 @@ use crate::providers::openai::{
to_openai_messages, OpenAILLM, OpenAITool, OpenAIToolChoice,
};
use crate::providers::provider::{Provider, ProviderID};
use crate::providers::tiktoken::tiktoken::{batch_tokenize_async, decode_async, encode_async};
use crate::providers::tiktoken::tiktoken::{
cl100k_base_singleton, p50k_base_singleton, r50k_base_singleton, CoreBPE,
};
use crate::providers::tiktoken::tiktoken::{decode_async, encode_async, tokenize_async};
use crate::run::Credentials;
use crate::utils;
use anyhow::{anyhow, Result};
@@ -27,8 +28,6 @@ use std::str::FromStr;
use std::sync::Arc;
use tokio::sync::mpsc::UnboundedSender;

use super::chat_messages::ChatMessage;

#[derive(Serialize, Deserialize, Debug, Clone)]
struct AzureOpenAIScaleSettings {
scale_type: String,
@@ -238,8 +237,8 @@ impl LLM for AzureOpenAILLM {
decode_async(self.tokenizer(), tokens).await
}

async fn tokenize(&self, text: &str) -> Result<Vec<(usize, String)>> {
tokenize_async(self.tokenizer(), text).await
async fn tokenize(&self, texts: Vec<String>) -> Result<Vec<Vec<(usize, String)>>> {
batch_tokenize_async(self.tokenizer(), texts).await
}

async fn generate(
@@ -687,8 +686,8 @@ impl Embedder for AzureOpenAIEmbedder {
decode_async(self.tokenizer(), tokens).await
}

async fn tokenize(&self, text: &str) -> Result<Vec<(usize, String)>> {
tokenize_async(self.tokenizer(), text).await
async fn tokenize(&self, texts: Vec<String>) -> Result<Vec<Vec<(usize, String)>>> {
batch_tokenize_async(self.tokenizer(), texts).await
}

async fn embed(&self, text: Vec<&str>, extras: Option<Value>) -> Result<Vec<EmbedderVector>> {
2 changes: 1 addition & 1 deletion core/src/providers/embedder.rs
@@ -29,7 +29,7 @@ pub trait Embedder {

async fn encode(&self, text: &str) -> Result<Vec<usize>>;
async fn decode(&self, tokens: Vec<usize>) -> Result<String>;
async fn tokenize(&self, text: &str) -> Result<Vec<(usize, String)>>;
async fn tokenize(&self, texts: Vec<String>) -> Result<Vec<Vec<(usize, String)>>>;

async fn embed(&self, text: Vec<&str>, extras: Option<Value>) -> Result<Vec<EmbedderVector>>;
}
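With the Embedder trait batched, callers get back one token list per input, in input order. An illustrative sketch of pairing the results with their inputs to count tokens; the helper below is not part of this commit.

```rust
// Illustrative sketch, not part of this commit: counting tokens per input
// with the batched Embedder::tokenize defined above.
use crate::providers::embedder::Embedder;

async fn token_counts(
    embedder: &impl Embedder,
    texts: Vec<String>,
) -> anyhow::Result<Vec<(String, usize)>> {
    // One Vec<(token_id, token_string)> comes back per input, in input order.
    let tokenized = embedder.tokenize(texts.clone()).await?;
    Ok(texts
        .into_iter()
        .zip(tokenized)
        .map(|(text, tokens)| (text, tokens.len()))
        .collect())
}
```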
6 changes: 3 additions & 3 deletions core/src/providers/google_ai_studio.rs
@@ -30,7 +30,7 @@ use super::{
},
provider::{Provider, ProviderID},
tiktoken::tiktoken::{
cl100k_base_singleton, decode_async, encode_async, tokenize_async, CoreBPE,
batch_tokenize_async, cl100k_base_singleton, decode_async, encode_async, CoreBPE,
},
};

@@ -391,8 +391,8 @@ impl LLM for GoogleAiStudioLLM {
decode_async(self.tokenizer(), tokens).await
}

async fn tokenize(&self, text: &str) -> Result<Vec<(usize, String)>> {
tokenize_async(self.tokenizer(), text).await
async fn tokenize(&self, texts: Vec<String>) -> Result<Vec<Vec<(usize, String)>>> {
batch_tokenize_async(self.tokenizer(), texts).await
}

async fn generate(
5 changes: 2 additions & 3 deletions core/src/providers/llm.rs
@@ -1,5 +1,6 @@
use crate::cached_request::CachedRequest;
use crate::project::Project;
use crate::providers::chat_messages::{AssistantChatMessage, ChatMessage};
use crate::providers::provider::{provider, with_retryable_back_off, ProviderID};
use crate::run::Credentials;
use crate::stores::store::Store;
@@ -13,8 +14,6 @@ use std::str::FromStr;
use tokio::sync::mpsc::UnboundedSender;
use tracing::{error, info};

use super::chat_messages::{AssistantChatMessage, ChatMessage};

#[derive(Debug, Serialize, PartialEq, Clone, Deserialize)]
pub struct Tokens {
pub text: String,
@@ -107,7 +106,7 @@ pub trait LLM {

async fn encode(&self, text: &str) -> Result<Vec<usize>>;
async fn decode(&self, tokens: Vec<usize>) -> Result<String>;
async fn tokenize(&self, text: &str) -> Result<Vec<(usize, String)>>;
async fn tokenize(&self, texts: Vec<String>) -> Result<Vec<Vec<(usize, String)>>>;

async fn generate(
&self,
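Single-text callers such as the /tokenize handler above keep the old behavior by wrapping the input in a one-element Vec and popping the single result. That adapter pattern, written once as a hypothetical helper:

```rust
// Sketch of the adapter the handlers use, written as a hypothetical helper
// (not part of this commit).
use anyhow::{anyhow, Result};

use crate::providers::llm::LLM;

async fn tokenize_single(llm: &dyn LLM, text: String) -> Result<Vec<(usize, String)>> {
    // Send a one-element batch and take the single result back out.
    let mut results = llm.tokenize(vec![text]).await?;
    results
        .pop()
        .ok_or_else(|| anyhow!("tokenizer returned no result for the input text"))
}
```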
23 changes: 13 additions & 10 deletions core/src/providers/mistral.rs
@@ -1,15 +1,18 @@
use super::chat_messages::{AssistantChatMessage, ChatMessage, ContentBlock, MixedContent};
use super::llm::{ChatFunction, ChatFunctionCall};
use super::sentencepiece::sentencepiece::{
decode_async, encode_async, mistral_instruct_tokenizer_240216_model_v2_base_singleton,
mistral_instruct_tokenizer_240216_model_v3_base_singleton,
mistral_tokenizer_model_v1_base_singleton, tokenize_async,
use crate::providers::chat_messages::{
AssistantChatMessage, ChatMessage, ContentBlock, MixedContent,
};
use crate::providers::embedder::{Embedder, EmbedderVector};
use crate::providers::llm::{ChatFunction, ChatFunctionCall};
use crate::providers::llm::{
ChatMessageRole, LLMChatGeneration, LLMGeneration, LLMTokenUsage, LLM,
};
use crate::providers::provider::{ModelError, ModelErrorRetryOptions, Provider, ProviderID};
use crate::providers::sentencepiece::sentencepiece::{
batch_tokenize_async, decode_async, encode_async,
mistral_instruct_tokenizer_240216_model_v2_base_singleton,
mistral_instruct_tokenizer_240216_model_v3_base_singleton,
mistral_tokenizer_model_v1_base_singleton,
};
use crate::run::Credentials;
use crate::utils::{self, now, ParseError};
use anyhow::{anyhow, Result};
@@ -912,8 +915,8 @@ impl LLM for MistralAILLM {
decode_async(self.tokenizer(), tokens).await
}

async fn tokenize(&self, text: &str) -> Result<Vec<(usize, String)>> {
tokenize_async(self.tokenizer(), text).await
async fn tokenize(&self, texts: Vec<String>) -> Result<Vec<Vec<(usize, String)>>> {
batch_tokenize_async(self.tokenizer(), texts).await
}

async fn chat(
@@ -1145,8 +1148,8 @@ impl Embedder for MistralEmbedder {
decode_async(self.tokenizer(), tokens).await
}

async fn tokenize(&self, text: &str) -> Result<Vec<(usize, String)>> {
tokenize_async(self.tokenizer(), text).await
async fn tokenize(&self, texts: Vec<String>) -> Result<Vec<Vec<(usize, String)>>> {
batch_tokenize_async(self.tokenizer(), texts).await
}

async fn embed(&self, text: Vec<&str>, _extras: Option<Value>) -> Result<Vec<EmbedderVector>> {
17 changes: 9 additions & 8 deletions core/src/providers/openai.rs
@@ -1,3 +1,6 @@
use crate::providers::chat_messages::{
AssistantChatMessage, ChatMessage, ContentBlock, MixedContent,
};
use crate::providers::embedder::{Embedder, EmbedderVector};
use crate::providers::llm::Tokens;
use crate::providers::llm::{ChatFunction, ChatFunctionCall};
@@ -6,9 +9,9 @@ use crate::providers::llm::{
};
use crate::providers::provider::{ModelError, ModelErrorRetryOptions, Provider, ProviderID};
use crate::providers::tiktoken::tiktoken::{
cl100k_base_singleton, p50k_base_singleton, r50k_base_singleton, CoreBPE,
batch_tokenize_async, cl100k_base_singleton, p50k_base_singleton, r50k_base_singleton, CoreBPE,
};
use crate::providers::tiktoken::tiktoken::{decode_async, encode_async, tokenize_async};
use crate::providers::tiktoken::tiktoken::{decode_async, encode_async};
use crate::run::Credentials;
use crate::utils;
use crate::utils::ParseError;
@@ -32,8 +35,6 @@ use std::time::Duration;
use tokio::sync::mpsc::UnboundedSender;
use tokio::time::timeout;

use super::chat_messages::{AssistantChatMessage, ChatMessage, ContentBlock, MixedContent};

#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct Usage {
pub prompt_tokens: u64,
@@ -1754,8 +1755,8 @@ impl LLM for OpenAILLM {
decode_async(self.tokenizer(), tokens).await
}

async fn tokenize(&self, text: &str) -> Result<Vec<(usize, String)>> {
tokenize_async(self.tokenizer(), text).await
async fn tokenize(&self, texts: Vec<String>) -> Result<Vec<Vec<(usize, String)>>> {
batch_tokenize_async(self.tokenizer(), texts).await
}

async fn generate(
@@ -2193,8 +2194,8 @@ impl Embedder for OpenAIEmbedder {
decode_async(self.tokenizer(), tokens).await
}

async fn tokenize(&self, text: &str) -> Result<Vec<(usize, String)>> {
tokenize_async(self.tokenizer(), text).await
async fn tokenize(&self, texts: Vec<String>) -> Result<Vec<Vec<(usize, String)>>> {
batch_tokenize_async(self.tokenizer(), texts).await
}

async fn embed(&self, text: Vec<&str>, extras: Option<Value>) -> Result<Vec<EmbedderVector>> {
(Diffs for the remaining 5 changed files are not shown in this view.)
