From d2efc6e848803e9149f4db8882ad5bb3ddaac4e8 Mon Sep 17 00:00:00 2001 From: Stanislas Polu Date: Sat, 13 Jan 2024 09:51:52 +0100 Subject: [PATCH] core: move to RwLock for BPE (#3188) --- core/src/providers/azure_openai.rs | 6 ++-- core/src/providers/google_vertex_ai.rs | 4 +-- core/src/providers/mistral.rs | 4 +-- core/src/providers/openai.rs | 6 ++-- core/src/providers/tiktoken/tiktoken.rs | 37 +++++++++++++------------ 5 files changed, 29 insertions(+), 28 deletions(-) diff --git a/core/src/providers/azure_openai.rs b/core/src/providers/azure_openai.rs index c4b84547419c..ed45bea29ca2 100644 --- a/core/src/providers/azure_openai.rs +++ b/core/src/providers/azure_openai.rs @@ -17,7 +17,7 @@ use hyper::header; use hyper::{body::Buf, http::StatusCode, Body, Client, Method, Request, Uri}; use hyper_tls::HttpsConnector; use itertools::izip; -use parking_lot::Mutex; +use parking_lot::RwLock; use serde::{Deserialize, Serialize}; use serde_json::Value; use std::io::prelude::*; @@ -179,7 +179,7 @@ impl AzureOpenAILLM { .parse::<Uri>()?) } - fn tokenizer(&self) -> Arc<Mutex<CoreBPE>> { + fn tokenizer(&self) -> Arc<RwLock<CoreBPE>> { match self.model_id.as_ref() { Some(model_id) => match model_id.as_str() { "code_davinci-002" | "code-cushman-001" => p50k_base_singleton(), @@ -606,7 +606,7 @@ impl AzureOpenAIEmbedder { .parse::<Uri>()?) 
} - fn tokenizer(&self) -> Arc<Mutex<CoreBPE>> { + fn tokenizer(&self) -> Arc<RwLock<CoreBPE>> { match self.model_id.as_ref() { Some(model_id) => match model_id.as_str() { "text-embedding-ada-002" => cl100k_base_singleton(), diff --git a/core/src/providers/google_vertex_ai.rs b/core/src/providers/google_vertex_ai.rs index 11dee13dbc4d..9d17d13e1db8 100644 --- a/core/src/providers/google_vertex_ai.rs +++ b/core/src/providers/google_vertex_ai.rs @@ -4,7 +4,7 @@ use eventsource_client as es; use eventsource_client::Client as ESClient; use futures::TryStreamExt; use hyper_tls::HttpsConnector; -use parking_lot::Mutex; +use parking_lot::{Mutex, RwLock}; use serde::{Deserialize, Serialize}; use serde_json::{json, Value}; use std::sync::Arc; @@ -255,7 +255,7 @@ impl GoogleVertexAiLLM { } } - fn tokenizer(&self) -> Arc<Mutex<CoreBPE>> { + fn tokenizer(&self) -> Arc<RwLock<CoreBPE>> { cl100k_base_singleton() } } diff --git a/core/src/providers/mistral.rs b/core/src/providers/mistral.rs index de17b5f57a73..77ac2dc77984 100644 --- a/core/src/providers/mistral.rs +++ b/core/src/providers/mistral.rs @@ -11,7 +11,7 @@ use eventsource_client::Client as ESClient; use futures::TryStreamExt; use hyper::{body::Buf, Body, Client, Method, Request, Uri}; use hyper_tls::HttpsConnector; -use parking_lot::Mutex; +use parking_lot::{Mutex, RwLock}; use serde::{Deserialize, Serialize}; use serde_json::json; use serde_json::Value; @@ -198,7 +198,7 @@ impl MistralAILLM { mistral_messages } - fn tokenizer(&self) -> Arc<Mutex<CoreBPE>> { + fn tokenizer(&self) -> Arc<RwLock<CoreBPE>> { return p50k_base_singleton(); } diff --git a/core/src/providers/openai.rs b/core/src/providers/openai.rs index 06ed9934d2d5..b7d81e1e3b7b 100644 --- a/core/src/providers/openai.rs +++ b/core/src/providers/openai.rs @@ -15,7 +15,7 @@ use futures::TryStreamExt; use hyper::{body::Buf, Body, Client, Method, Request, Uri}; use hyper_tls::HttpsConnector; use itertools::izip; -use parking_lot::Mutex; +use parking_lot::{Mutex, RwLock}; use serde::{Deserialize, Serialize}; use serde_json::json; use serde_json::Value; 
@@ -1116,7 +1116,7 @@ impl OpenAILLM { Ok(format!("https://api.openai.com/v1/chat/completions",).parse::<Uri>()?) } - fn tokenizer(&self) -> Arc<Mutex<CoreBPE>> { + fn tokenizer(&self) -> Arc<RwLock<CoreBPE>> { match self.id.as_str() { "code_davinci-002" | "code-cushman-001" => p50k_base_singleton(), "text-davinci-002" | "text-davinci-003" => p50k_base_singleton(), @@ -1534,7 +1534,7 @@ impl OpenAIEmbedder { Ok(format!("https://api.openai.com/v1/embeddings",).parse::<Uri>()?) } - fn tokenizer(&self) -> Arc<Mutex<CoreBPE>> { + fn tokenizer(&self) -> Arc<RwLock<CoreBPE>> { match self.id.as_str() { "text-embedding-ada-002" => cl100k_base_singleton(), _ => r50k_base_singleton(), diff --git a/core/src/providers/tiktoken/tiktoken.rs b/core/src/providers/tiktoken/tiktoken.rs index d894461270c2..8936158bd31b 100644 --- a/core/src/providers/tiktoken/tiktoken.rs +++ b/core/src/providers/tiktoken/tiktoken.rs @@ -6,7 +6,7 @@ use anyhow::{anyhow, Result}; use base64::{engine::general_purpose, Engine as _}; use fancy_regex::Regex; use lazy_static::lazy_static; -use parking_lot::Mutex; +use parking_lot::RwLock; use rustc_hash::FxHashMap as HashMap; use std::collections::HashSet; use std::sync::Arc; @@ -106,48 +106,49 @@ pub fn cl100k_base() -> Result<CoreBPE> { ) } -pub fn anthropic_base_singleton() -> Arc<Mutex<CoreBPE>> { +pub fn anthropic_base_singleton() -> Arc<RwLock<CoreBPE>> { lazy_static! { - static ref ANTHROPIC_BASE: Arc<Mutex<CoreBPE>> = Arc::new(Mutex::new(anthropic_base().unwrap())); + static ref ANTHROPIC_BASE: Arc<RwLock<CoreBPE>> = Arc::new(RwLock::new(anthropic_base().unwrap())); } ANTHROPIC_BASE.clone() } -pub fn r50k_base_singleton() -> Arc<Mutex<CoreBPE>> { + pub fn r50k_base_singleton() -> Arc<RwLock<CoreBPE>> { lazy_static! { - static ref R50K_BASE: Arc<Mutex<CoreBPE>> = Arc::new(Mutex::new(r50k_base().unwrap())); + static ref R50K_BASE: Arc<RwLock<CoreBPE>> = Arc::new(RwLock::new(r50k_base().unwrap())); } R50K_BASE.clone() } -pub fn p50k_base_singleton() -> Arc<Mutex<CoreBPE>> { +pub fn p50k_base_singleton() -> Arc<RwLock<CoreBPE>> { lazy_static! 
{ - static ref P50K_BASE: Arc<Mutex<CoreBPE>> = Arc::new(Mutex::new(p50k_base().unwrap())); + static ref P50K_BASE: Arc<RwLock<CoreBPE>> = Arc::new(RwLock::new(p50k_base().unwrap())); } P50K_BASE.clone() } -pub fn cl100k_base_singleton() -> Arc<Mutex<CoreBPE>> { +pub fn cl100k_base_singleton() -> Arc<RwLock<CoreBPE>> { lazy_static! { - static ref CL100K_BASE: Arc<Mutex<CoreBPE>> = Arc::new(Mutex::new(cl100k_base().unwrap())); + static ref CL100K_BASE: Arc<RwLock<CoreBPE>> = + Arc::new(RwLock::new(cl100k_base().unwrap())); } CL100K_BASE.clone() } -pub async fn decode_async(bpe: Arc<Mutex<CoreBPE>>, tokens: Vec<usize>) -> Result<String> { - task::spawn_blocking(move || bpe.lock().decode(tokens)).await? +pub async fn decode_async(bpe: Arc<RwLock<CoreBPE>>, tokens: Vec<usize>) -> Result<String> { + task::spawn_blocking(move || bpe.read().decode(tokens)).await? } -pub async fn encode_async(bpe: Arc<Mutex<CoreBPE>>, text: &str) -> Result<Vec<usize>> { +pub async fn encode_async(bpe: Arc<RwLock<CoreBPE>>, text: &str) -> Result<Vec<usize>> { let text = text.to_string(); - let r = task::spawn_blocking(move || bpe.lock().encode_with_special_tokens(&text)).await?; + let r = task::spawn_blocking(move || bpe.read().encode_with_special_tokens(&text)).await?; Ok(r) } -pub async fn tokenize_async(bpe: Arc<Mutex<CoreBPE>>, text: &str) -> Result<Vec<(usize, String)>> { +pub async fn tokenize_async(bpe: Arc<RwLock<CoreBPE>>, text: &str) -> Result<Vec<(usize, String)>> { let text = text.to_string(); - let r = task::spawn_blocking(move || bpe.lock().tokenize(&text)).await?; + let r = task::spawn_blocking(move || bpe.read().tokenize(&text)).await?; Ok(r) } @@ -818,7 +819,7 @@ mod tests { // println!("p50k_base_singleton load 1: {:?}", now.elapsed()); // let now = std::time::Instant::now(); { - let guard = bpe1.lock(); + let guard = bpe1.read(); let tokens = guard.encode_with_special_tokens("This is a test with a lot of spaces"); guard.decode(tokens.clone()).unwrap(); @@ -830,7 +831,7 @@ // println!("p50k_base_singleton load 2: {:?}", now.elapsed()); // let now = std::time::Instant::now(); { - let guard = bpe2.lock(); + let guard = bpe2.read(); let tokens = guard.encode_with_special_tokens("This is a test with a lot of spaces"); 
guard.decode(tokens.clone()).unwrap();