Skip to content

Commit

Permalink
core: move to RwLock for BPE (#3188)
Browse files — browse the repository at this point in the history
  • Loading branch information
spolu authored Jan 13, 2024
1 parent 5afb283 commit d2efc6e
Show file tree
Hide file tree
Showing 5 changed files with 29 additions and 28 deletions.
6 changes: 3 additions & 3 deletions core/src/providers/azure_openai.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ use hyper::header;
use hyper::{body::Buf, http::StatusCode, Body, Client, Method, Request, Uri};
use hyper_tls::HttpsConnector;
use itertools::izip;
use parking_lot::Mutex;
use parking_lot::RwLock;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::io::prelude::*;
Expand Down Expand Up @@ -179,7 +179,7 @@ impl AzureOpenAILLM {
.parse::<Uri>()?)
}

fn tokenizer(&self) -> Arc<Mutex<CoreBPE>> {
fn tokenizer(&self) -> Arc<RwLock<CoreBPE>> {
match self.model_id.as_ref() {
Some(model_id) => match model_id.as_str() {
"code_davinci-002" | "code-cushman-001" => p50k_base_singleton(),
Expand Down Expand Up @@ -606,7 +606,7 @@ impl AzureOpenAIEmbedder {
.parse::<Uri>()?)
}

fn tokenizer(&self) -> Arc<Mutex<CoreBPE>> {
fn tokenizer(&self) -> Arc<RwLock<CoreBPE>> {
match self.model_id.as_ref() {
Some(model_id) => match model_id.as_str() {
"text-embedding-ada-002" => cl100k_base_singleton(),
Expand Down
4 changes: 2 additions & 2 deletions core/src/providers/google_vertex_ai.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use eventsource_client as es;
use eventsource_client::Client as ESClient;
use futures::TryStreamExt;
use hyper_tls::HttpsConnector;
use parking_lot::Mutex;
use parking_lot::{Mutex, RwLock};
use serde::{Deserialize, Serialize};
use serde_json::{json, Value};
use std::sync::Arc;
Expand Down Expand Up @@ -255,7 +255,7 @@ impl GoogleVertexAiLLM {
}
}

fn tokenizer(&self) -> Arc<Mutex<CoreBPE>> {
fn tokenizer(&self) -> Arc<RwLock<CoreBPE>> {
cl100k_base_singleton()
}
}
Expand Down
4 changes: 2 additions & 2 deletions core/src/providers/mistral.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ use eventsource_client::Client as ESClient;
use futures::TryStreamExt;
use hyper::{body::Buf, Body, Client, Method, Request, Uri};
use hyper_tls::HttpsConnector;
use parking_lot::Mutex;
use parking_lot::{Mutex, RwLock};
use serde::{Deserialize, Serialize};
use serde_json::json;
use serde_json::Value;
Expand Down Expand Up @@ -198,7 +198,7 @@ impl MistralAILLM {
mistral_messages
}

fn tokenizer(&self) -> Arc<Mutex<CoreBPE>> {
fn tokenizer(&self) -> Arc<RwLock<CoreBPE>> {
return p50k_base_singleton();
}

Expand Down
6 changes: 3 additions & 3 deletions core/src/providers/openai.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ use futures::TryStreamExt;
use hyper::{body::Buf, Body, Client, Method, Request, Uri};
use hyper_tls::HttpsConnector;
use itertools::izip;
use parking_lot::Mutex;
use parking_lot::{Mutex, RwLock};
use serde::{Deserialize, Serialize};
use serde_json::json;
use serde_json::Value;
Expand Down Expand Up @@ -1116,7 +1116,7 @@ impl OpenAILLM {
Ok(format!("https://api.openai.com/v1/chat/completions",).parse::<Uri>()?)
}

fn tokenizer(&self) -> Arc<Mutex<CoreBPE>> {
fn tokenizer(&self) -> Arc<RwLock<CoreBPE>> {
match self.id.as_str() {
"code_davinci-002" | "code-cushman-001" => p50k_base_singleton(),
"text-davinci-002" | "text-davinci-003" => p50k_base_singleton(),
Expand Down Expand Up @@ -1534,7 +1534,7 @@ impl OpenAIEmbedder {
Ok(format!("https://api.openai.com/v1/embeddings",).parse::<Uri>()?)
}

fn tokenizer(&self) -> Arc<Mutex<CoreBPE>> {
fn tokenizer(&self) -> Arc<RwLock<CoreBPE>> {
match self.id.as_str() {
"text-embedding-ada-002" => cl100k_base_singleton(),
_ => r50k_base_singleton(),
Expand Down
37 changes: 19 additions & 18 deletions core/src/providers/tiktoken/tiktoken.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use anyhow::{anyhow, Result};
use base64::{engine::general_purpose, Engine as _};
use fancy_regex::Regex;
use lazy_static::lazy_static;
use parking_lot::Mutex;
use parking_lot::RwLock;
use rustc_hash::FxHashMap as HashMap;
use std::collections::HashSet;
use std::sync::Arc;
Expand Down Expand Up @@ -106,48 +106,49 @@ pub fn cl100k_base() -> Result<CoreBPE> {
)
}

pub fn anthropic_base_singleton() -> Arc<Mutex<CoreBPE>> {
pub fn anthropic_base_singleton() -> Arc<RwLock<CoreBPE>> {
lazy_static! {
static ref ANTHROPIC_BASE: Arc<Mutex<CoreBPE>> =
Arc::new(Mutex::new(anthropic_base().unwrap()));
static ref ANTHROPIC_BASE: Arc<RwLock<CoreBPE>> =
Arc::new(RwLock::new(anthropic_base().unwrap()));
}
ANTHROPIC_BASE.clone()
}

pub fn r50k_base_singleton() -> Arc<Mutex<CoreBPE>> {
pub fn r50k_base_singleton() -> Arc<RwLock<CoreBPE>> {
lazy_static! {
static ref R50K_BASE: Arc<Mutex<CoreBPE>> = Arc::new(Mutex::new(r50k_base().unwrap()));
static ref R50K_BASE: Arc<RwLock<CoreBPE>> = Arc::new(RwLock::new(r50k_base().unwrap()));
}
R50K_BASE.clone()
}

pub fn p50k_base_singleton() -> Arc<Mutex<CoreBPE>> {
pub fn p50k_base_singleton() -> Arc<RwLock<CoreBPE>> {
lazy_static! {
static ref P50K_BASE: Arc<Mutex<CoreBPE>> = Arc::new(Mutex::new(p50k_base().unwrap()));
static ref P50K_BASE: Arc<RwLock<CoreBPE>> = Arc::new(RwLock::new(p50k_base().unwrap()));
}
P50K_BASE.clone()
}

pub fn cl100k_base_singleton() -> Arc<Mutex<CoreBPE>> {
pub fn cl100k_base_singleton() -> Arc<RwLock<CoreBPE>> {
lazy_static! {
static ref CL100K_BASE: Arc<Mutex<CoreBPE>> = Arc::new(Mutex::new(cl100k_base().unwrap()));
static ref CL100K_BASE: Arc<RwLock<CoreBPE>> =
Arc::new(RwLock::new(cl100k_base().unwrap()));
}
CL100K_BASE.clone()
}

pub async fn decode_async(bpe: Arc<Mutex<CoreBPE>>, tokens: Vec<usize>) -> Result<String> {
task::spawn_blocking(move || bpe.lock().decode(tokens)).await?
pub async fn decode_async(bpe: Arc<RwLock<CoreBPE>>, tokens: Vec<usize>) -> Result<String> {
task::spawn_blocking(move || bpe.read().decode(tokens)).await?
}

pub async fn encode_async(bpe: Arc<Mutex<CoreBPE>>, text: &str) -> Result<Vec<usize>> {
pub async fn encode_async(bpe: Arc<RwLock<CoreBPE>>, text: &str) -> Result<Vec<usize>> {
let text = text.to_string();
let r = task::spawn_blocking(move || bpe.lock().encode_with_special_tokens(&text)).await?;
let r = task::spawn_blocking(move || bpe.read().encode_with_special_tokens(&text)).await?;
Ok(r)
}

pub async fn tokenize_async(bpe: Arc<Mutex<CoreBPE>>, text: &str) -> Result<Vec<(usize, String)>> {
pub async fn tokenize_async(bpe: Arc<RwLock<CoreBPE>>, text: &str) -> Result<Vec<(usize, String)>> {
let text = text.to_string();
let r = task::spawn_blocking(move || bpe.lock().tokenize(&text)).await?;
let r = task::spawn_blocking(move || bpe.read().tokenize(&text)).await?;
Ok(r)
}

Expand Down Expand Up @@ -818,7 +819,7 @@ mod tests {
// println!("p50k_base_singleton load 1: {:?}", now.elapsed());
// let now = std::time::Instant::now();
{
let guard = bpe1.lock();
let guard = bpe1.read();
let tokens =
guard.encode_with_special_tokens("This is a test with a lot of spaces");
guard.decode(tokens.clone()).unwrap();
Expand All @@ -830,7 +831,7 @@ mod tests {
// println!("p50k_base_singleton load 2: {:?}", now.elapsed());
// let now = std::time::Instant::now();
{
let guard = bpe2.lock();
let guard = bpe2.read();
let tokens =
guard.encode_with_special_tokens("This is a test with a lot of spaces");
guard.decode(tokens.clone()).unwrap();
Expand Down

0 comments on commit d2efc6e

Please sign in to comment.