From 89d2b7ef6935d7a843ebd1e8f7c6c2870e8e01f0 Mon Sep 17 00:00:00 2001 From: Magdy Saleh Date: Mon, 2 Dec 2024 16:16:18 -0500 Subject: [PATCH 01/10] init --- router/src/lib.rs | 14 ++++++++++++++ router/src/server.rs | 1 + server/lorax_server/models/flash_bert.py | 2 ++ 3 files changed, 17 insertions(+) diff --git a/router/src/lib.rs b/router/src/lib.rs index 05cb07e96..763b39248 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -664,6 +664,20 @@ impl From for TextMessage { } } +// #[derive(Clone, Debug, Deserialize, ToSchema)] +// pub struct OpenAIEmbedRequest { +// pub input: String, +// pub model: String, +// pub encoding_format: Option, +// } + +// #[derive(Clone, Debug, Deserialize, ToSchema)] +// pub struct OpenAIEmbedResponse { +// pub index: i32, +// pub embedding: Vec, +// pub object: String, +// } + #[derive(Clone, Debug, Deserialize, ToSchema)] struct ChatCompletionRequest { model: String, diff --git a/router/src/server.rs b/router/src/server.rs index d33412aaf..a3e46ba95 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -1484,6 +1484,7 @@ pub async fn run( .route("/generate_stream", post(generate_stream)) .route("/v1/completions", post(completions_v1)) .route("/v1/chat/completions", post(chat_completions_v1)) + .route("/v1/embeddings", post(embed)) // AWS Sagemaker route .route("/invocations", post(compat_generate)); diff --git a/server/lorax_server/models/flash_bert.py b/server/lorax_server/models/flash_bert.py index 8c92085b6..5b70f50f8 100644 --- a/server/lorax_server/models/flash_bert.py +++ b/server/lorax_server/models/flash_bert.py @@ -150,6 +150,8 @@ def supports_classification(self) -> bool: def warmup(self, batch: FlashEmbeddingClassificationBatch, max_new_tokens: int) -> int | None: # Note: This is meant to 1) preallocate the memory by doing a forward pass # and then just returning the max seqlen since for embeddings we are never generating + # print(batch) + breakpoint() _ = self.embed(batch) return batch.max_s From 2711877747232817589cde76bda56b019a88b407 Mon Sep 17 00:00:00 2001 From: Magdy Saleh Date: Mon, 2 Dec 2024 19:54:32 -0500 Subject: [PATCH 02/10] need to fix uae model --- Cargo.lock | 12 +++++++ integration-tests/test_embeddings.py | 13 +++++++ router/Cargo.toml | 2 +- router/src/lib.rs | 26 +++++++------- router/src/server.rs | 46 ++++++++++++++++++++---- server/lorax_server/models/flash_bert.py | 2 -- 6 files changed, 77 insertions(+), 24 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 3e7eec1f7..3312d84ca 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -277,6 +277,7 @@ checksum = "504e3947307ac8326a5437504c517c4b56716c9d98fac0028c2acc7ca47d70ae" dependencies = [ "async-trait", "axum-core 0.4.5", + "axum-macros", "bytes", "futures-util", "http 1.1.0", @@ -341,6 +342,17 @@ dependencies = [ "tracing", ] +[[package]] +name = "axum-macros" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "57d123550fa8d071b7255cb0cc04dc302baa6c8c4a79f55701552684d8399bce" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.87", +] + [[package]] name = "axum-tracing-opentelemetry" version = "0.16.0" diff --git a/integration-tests/test_embeddings.py b/integration-tests/test_embeddings.py index 329a8810b..f34b4eb58 100644 --- a/integration-tests/test_embeddings.py +++ b/integration-tests/test_embeddings.py @@ -13,3 +13,16 @@ def test_stella_1_5b(): response.raise_for_status() print("RESPONSE FROM EMBEDDING: ", response.json()) assert len(response.json()["embeddings"]) > 0 + + +def 
test_uae_large_v1_1_5b(): + config = { + "name": "UAE-Large-V1-1.5b", + "model_id": "WhereIsAI/UAE-Large-V1", + "docker_args": {"max_input_length": 512}, + } + with run_lorax_container(config): + response = requests.post("http://localhost:8080/embed", json={"inputs": "Hello, world!"}) + response.raise_for_status() + print("RESPONSE FROM EMBEDDING: ", response.json()) + assert len(response.json()["embeddings"]) > 0 diff --git a/router/Cargo.toml b/router/Cargo.toml index c325ba9dc..084b2f589 100644 --- a/router/Cargo.toml +++ b/router/Cargo.toml @@ -16,7 +16,7 @@ path = "src/main.rs" [dependencies] async-stream = "0.3.3" -axum = { version = "0.7", features = ["json"] } +axum = { version = "0.7", features = ["json", "macros"] } axum-tracing-opentelemetry = "0.16" clap = { version = "4.1.4", features = ["derive", "env"] } futures = "0.3.26" diff --git a/router/src/lib.rs b/router/src/lib.rs index 763b39248..173aa1520 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -664,20 +664,6 @@ impl From for TextMessage { } } -// #[derive(Clone, Debug, Deserialize, ToSchema)] -// pub struct OpenAIEmbedRequest { -// pub input: String, -// pub model: String, -// pub encoding_format: Option, -// } - -// #[derive(Clone, Debug, Deserialize, ToSchema)] -// pub struct OpenAIEmbedResponse { -// pub index: i32, -// pub embedding: Vec, -// pub object: String, -// } - #[derive(Clone, Debug, Deserialize, ToSchema)] struct ChatCompletionRequest { model: String, @@ -1179,6 +1165,18 @@ struct EmbedResponse { embeddings: Vec, } +#[derive(Clone, Debug, Deserialize, ToSchema)] +struct CompatEmbedRequest { + inputs: String, + #[serde(default = "default_embed_parameters")] + pub parameters: EmbedParameters, +} + +#[derive(Serialize, ToSchema)] +struct CompatEmbedResponse { + embeddings: Vec, +} + #[derive(Clone, Debug, Deserialize, ToSchema)] struct ClassifyRequest { inputs: String, diff --git a/router/src/server.rs b/router/src/server.rs index a3e46ba95..1c6e0dd19 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -10,11 +10,12 @@ use crate::{ AdapterParameters, AlternativeToken, BatchClassifyRequest, BestOfSequence, ChatCompletionRequest, ChatCompletionResponse, ChatCompletionResponseChoice, ChatCompletionStreamResponse, ChatCompletionStreamResponseChoice, ChatMessage, ClassifyRequest, - CompatGenerateRequest, CompletionFinishReason, CompletionRequest, CompletionResponse, - CompletionResponseChoice, CompletionResponseStreamChoice, CompletionStreamResponse, Details, - EmbedParameters, EmbedRequest, EmbedResponse, Entity, ErrorResponse, FinishReason, - GenerateParameters, GenerateRequest, GenerateResponse, HubModelInfo, Infer, Info, JsonSchema, - LogProbs, Message, OpenAiResponseFormat, PrefillToken, ResponseFormat, ResponseFormatType, + CompatEmbedRequest, CompatEmbedResponse, CompatGenerateRequest, CompletionFinishReason, + CompletionRequest, CompletionResponse, CompletionResponseChoice, + CompletionResponseStreamChoice, CompletionStreamResponse, Details, EmbedParameters, + EmbedRequest, EmbedResponse, Entity, ErrorResponse, FinishReason, GenerateParameters, + GenerateRequest, GenerateResponse, HubModelInfo, Infer, Info, JsonSchema, LogProbs, Message, + OpenAiResponseFormat, PrefillToken, ResponseFormat, ResponseFormatType, ReturnFunctionDefinition, SimpleToken, StreamDetails, StreamResponse, Token, TokenizeRequest, TokenizeResponse, Tool, ToolCall, ToolChoice, UsageInfo, Validation, }; @@ -1483,8 +1484,8 @@ pub async fn run( .route("/classify_batch", post(classify_batch)) .route("/generate_stream", 
post(generate_stream)) .route("/v1/completions", post(completions_v1)) + .route("/v1/embeddings", post(compat_embed)) .route("/v1/chat/completions", post(chat_completions_v1)) - .route("/v1/embeddings", post(embed)) // AWS Sagemaker route .route("/invocations", post(compat_generate)); @@ -1626,7 +1627,7 @@ impl From for Event { post, tag = "Embedding", path = "/embed", - request_body = TokenizeRequest, + request_body = EmbedRequest, responses( (status = 200, description = "Embeddings ids", body = EmbedResponse), (status = 500, description = "Incomplete embedding", body = ErrorResponse), @@ -1644,6 +1645,37 @@ async fn embed( Ok(Json(response)) } +/// Embed inputs +#[utoipa::path( + post, + tag = "Embedding", + path = "/embed", + request_body = CompatEmbedRequest, + responses( + (status = 200, description = "Embeddings ids", body = CompatEmbedResponse), + (status = 500, description = "Incomplete embedding", body = ErrorResponse), + ) +)] +#[instrument(skip_all)] +#[axum::debug_handler] +async fn compat_embed( + infer: Extension, + Json(req): Json, +) -> Result, (StatusCode, Json)> { + metrics::increment_counter!("lorax_request_count"); + tracing::debug!("Input: {}", req.inputs); + // Inference + let embed_req = EmbedRequest { + inputs: req.inputs, + parameters: req.parameters, + }; + let response = infer.embed(embed_req).await?; + let compat_response = CompatEmbedResponse { + embeddings: response.embeddings, + }; + Ok(Json(compat_response)) +} + #[utoipa::path( post, tag = "Classify", diff --git a/server/lorax_server/models/flash_bert.py b/server/lorax_server/models/flash_bert.py index 5b70f50f8..8c92085b6 100644 --- a/server/lorax_server/models/flash_bert.py +++ b/server/lorax_server/models/flash_bert.py @@ -150,8 +150,6 @@ def supports_classification(self) -> bool: def warmup(self, batch: FlashEmbeddingClassificationBatch, max_new_tokens: int) -> int | None: # Note: This is meant to 1) preallocate the memory by doing a forward pass # and then just returning the max seqlen since for embeddings we are never generating - # print(batch) - breakpoint() _ = self.embed(batch) return batch.max_s From 6f28014fe773ad9148d218e71b9a1e45e73ef36c Mon Sep 17 00:00:00 2001 From: Magdy Saleh Date: Tue, 3 Dec 2024 12:50:19 -0500 Subject: [PATCH 03/10] openai compat --- router/src/infer.rs | 139 +++++++++++++++++++++++++++++++++++++++---- router/src/lib.rs | 44 +++++++++++++- router/src/server.rs | 70 ++++++++++++++++------ 3 files changed, 220 insertions(+), 33 deletions(-) diff --git a/router/src/infer.rs b/router/src/infer.rs index 101360ba1..b60adca34 100644 --- a/router/src/infer.rs +++ b/router/src/infer.rs @@ -4,9 +4,10 @@ use crate::queue::AdapterEvent; use crate::scheduler::AdapterScheduler; use crate::validation::{Validation, ValidationError}; use crate::{ - AdapterParameters, AlternativeToken, BatchClassifyRequest, ChatTemplateVersions, - ClassifyRequest, EmbedRequest, EmbedResponse, Entity, Entry, HubTokenizerConfig, Message, - MessageChunk, MessageContent, TextMessage, Token, TokenizerConfigToken, Tool, + AdapterParameters, AlternativeToken, BatchClassifyRequest, BatchEmbedRequest, + ChatTemplateVersions, ClassifyRequest, EmbedRequest, EmbedResponse, Entity, Entry, + HubTokenizerConfig, Message, MessageChunk, MessageContent, TextMessage, Token, + TokenizerConfigToken, Tool, }; use crate::{GenerateRequest, PrefillToken}; use futures::future::try_join_all; @@ -596,6 +597,7 @@ impl Infer { embedding, start: _, queued: _, + id: _, } => { return_embeddings = Some(embedding.values); } @@ -857,6 
+859,127 @@ impl Infer { } } + #[instrument(skip(self))] + pub(crate) async fn embed_batch( + &self, + request: BatchEmbedRequest, + ) -> Result, InferError> { + // Limit concurrent requests by acquiring a permit from the semaphore + let _permit = self + .clone() + .limit_concurrent_requests + .try_acquire_owned() + .map_err(|err| { + metrics::increment_counter!("lorax_request_failure", "err" => "overloaded"); + tracing::error!("{err}"); + err + })?; + + let adapter = Adapter::new( + AdapterParameters { + adapter_ids: vec![BASE_MODEL_ADAPTER_ID.to_string()], + ..Default::default() + }, + "hub".to_string(), + 0, + None, + ); + + // MPSC channel to communicate with the background batching task + let (response_tx, response_rx) = mpsc::unbounded_channel(); + + let request_id_map: HashMap = request + .inputs + .iter() + .enumerate() + .map(|(id, input)| (id as u64, input.clone())) + .collect(); + + // Call validate_input on every input in the request and await the results + let futures: Vec<_> = request + .inputs + .iter() + .map(|input| { + self.validation + .validate_input(input.clone(), true, None, Some(1)) + }) + .collect(); + + let all_tokenized_inputs = try_join_all(futures).await?; + + for ((id, r_inputs), (tokenized_inputs, input_length)) in + request.inputs.iter().enumerate().zip(all_tokenized_inputs) + { + let inputs = r_inputs.to_string().clone(); + let valid_request = ValidEmbedRequest { + inputs, + tokenized_inputs, + input_length: input_length as u32, + adapter: adapter.clone(), + }; + + // Process the request by sending it to the queue associated with `adapter` + self.adapter_scheduler.process( + adapter.clone(), + Entry { + request: Arc::new(valid_request), + response_tx: response_tx.clone(), + span: Span::current(), + temp_span: None, + queue_time: Instant::now(), + batch_time: None, + block_allocation: None, + id: Some(id as u64), + }, + ); + } + + drop(response_tx); // Close the sending end + + // Return values + + let mut all_embeddings = HashMap::new(); + let mut stream = UnboundedReceiverStream::new(response_rx); + while let Some(response) = stream.next().await { + match response? { + InferStreamResponse::Embed { + embedding, + start: _, + queued: _, + id, + } => { + all_embeddings.insert( + id.unwrap(), + EmbedResponse { + embeddings: embedding.values, + }, + ); + } + _ => { + tracing::error!( + "Received unexpected message type in classify_batch. This is a bug." + ); + } + } + } + if all_embeddings.is_empty() { + let err = InferError::EmbeddingFailure; + metrics::increment_counter!("lorax_request_failure", "err" => "embedding_failure"); + tracing::error!("{err}"); + Err(err) + } else { + let mut sorted_responses: Vec<_> = all_embeddings.into_iter().collect(); + sorted_responses.sort_by_key(|&(id, _)| id); + + let sorted_responses: Vec = sorted_responses + .into_iter() + .map(|(_, response)| response) + .collect(); + + Ok(sorted_responses) + } + } + /// Add best_of new requests to the queue and return a InferResponse of the sequence with /// the highest log probability per token #[instrument(skip(self))] @@ -1478,6 +1601,7 @@ fn send_embeddings( embedding: embedding.clone(), queued: entry.queue_time, start: entry.batch_time.unwrap(), + id: entry.id, }))?; // TODO(travis): redundant as we always return true, just make it return nothing @@ -1536,20 +1660,13 @@ pub(crate) enum InferStreamResponse { // Embeddings Embed { embedding: Embedding, - // For now allow this field even though it is unused. 
- // TODO:(magdy) enable tracing for these requests - #[allow(dead_code)] start: Instant, - #[allow(dead_code)] queued: Instant, + id: Option, // to support batching }, Classify { predictions: ClassifyPredictionList, - // For now allow this field even though it is unused. - // TODO:(magdy) enable tracing for these requests - #[allow(dead_code)] start: Instant, - #[allow(dead_code)] queued: Instant, id: Option, // to support batching }, diff --git a/router/src/lib.rs b/router/src/lib.rs index 173aa1520..fd5b1049f 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -1165,16 +1165,47 @@ struct EmbedResponse { embeddings: Vec, } +#[derive(Debug, Clone, Deserialize)] +#[serde(untagged)] +enum StringOrVec { + String(String), + Vec(Vec), +} + +impl std::fmt::Display for StringOrVec { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + match self { + StringOrVec::String(s) => write!(f, "{}", s), + StringOrVec::Vec(v) => write!(f, "{}", v.join(", ")), + } + } +} + #[derive(Clone, Debug, Deserialize, ToSchema)] struct CompatEmbedRequest { - inputs: String, + input: StringOrVec, + #[allow(dead_code)] + model: String, + #[allow(dead_code)] + encoding_format: Option, + #[allow(dead_code)] + dimensions: Option, + #[allow(dead_code)] + user: Option, #[serde(default = "default_embed_parameters")] - pub parameters: EmbedParameters, + parameters: EmbedParameters, } #[derive(Serialize, ToSchema)] struct CompatEmbedResponse { - embeddings: Vec, + embeddings: Vec, +} + +#[derive(Serialize, ToSchema)] +struct CompatEmbedding { + index: i32, + embedding: Vec, + object: String, } #[derive(Clone, Debug, Deserialize, ToSchema)] @@ -1187,6 +1218,13 @@ struct BatchClassifyRequest { inputs: Vec, } +#[derive(Clone, Debug, Deserialize, ToSchema)] +struct BatchEmbedRequest { + inputs: Vec, + #[serde(default = "default_embed_parameters")] + parameters: EmbedParameters, +} + #[derive(Debug, Serialize, Deserialize)] struct Entity { entity_group: String, diff --git a/router/src/server.rs b/router/src/server.rs index 1c6e0dd19..b874b6b9a 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -7,17 +7,17 @@ use crate::tool_grammar::ToolGrammar; use crate::validation::ValidationError; use crate::{json, HubPreprocessorConfig, HubProcessorConfig, HubTokenizerConfig}; use crate::{ - AdapterParameters, AlternativeToken, BatchClassifyRequest, BestOfSequence, + AdapterParameters, AlternativeToken, BatchClassifyRequest, BatchEmbedRequest, BestOfSequence, ChatCompletionRequest, ChatCompletionResponse, ChatCompletionResponseChoice, ChatCompletionStreamResponse, ChatCompletionStreamResponseChoice, ChatMessage, ClassifyRequest, - CompatEmbedRequest, CompatEmbedResponse, CompatGenerateRequest, CompletionFinishReason, - CompletionRequest, CompletionResponse, CompletionResponseChoice, + CompatEmbedRequest, CompatEmbedResponse, CompatEmbedding, CompatGenerateRequest, + CompletionFinishReason, CompletionRequest, CompletionResponse, CompletionResponseChoice, CompletionResponseStreamChoice, CompletionStreamResponse, Details, EmbedParameters, EmbedRequest, EmbedResponse, Entity, ErrorResponse, FinishReason, GenerateParameters, GenerateRequest, GenerateResponse, HubModelInfo, Infer, Info, JsonSchema, LogProbs, Message, OpenAiResponseFormat, PrefillToken, ResponseFormat, ResponseFormatType, - ReturnFunctionDefinition, SimpleToken, StreamDetails, StreamResponse, Token, TokenizeRequest, - TokenizeResponse, Tool, ToolCall, ToolChoice, UsageInfo, Validation, + ReturnFunctionDefinition, SimpleToken, StreamDetails, 
StreamResponse, StringOrVec, Token, + TokenizeRequest, TokenizeResponse, Tool, ToolCall, ToolChoice, UsageInfo, Validation, }; use axum::extract::Extension; use axum::http::{HeaderMap, Method, StatusCode}; @@ -1648,8 +1648,8 @@ async fn embed( /// Embed inputs #[utoipa::path( post, - tag = "Embedding", - path = "/embed", + tag = "OpenAI Compatible", + path = "/v1/embeddings", request_body = CompatEmbedRequest, responses( (status = 200, description = "Embeddings ids", body = CompatEmbedResponse), @@ -1663,17 +1663,49 @@ async fn compat_embed( Json(req): Json, ) -> Result, (StatusCode, Json)> { metrics::increment_counter!("lorax_request_count"); - tracing::debug!("Input: {}", req.inputs); - // Inference - let embed_req = EmbedRequest { - inputs: req.inputs, - parameters: req.parameters, - }; - let response = infer.embed(embed_req).await?; - let compat_response = CompatEmbedResponse { - embeddings: response.embeddings, - }; - Ok(Json(compat_response)) + tracing::debug!("Input: {}", req.input); + if let StringOrVec::Vec(inputs) = req.input { + let batch_embed_req = BatchEmbedRequest { + inputs, + parameters: req.parameters, + }; + let response = infer.embed_batch(batch_embed_req).await?; + let compat_embeddings = response + .into_iter() + .enumerate() + .map(|(i, e)| -> CompatEmbedding { + CompatEmbedding { + index: i as i32, + embedding: e.embeddings, + object: "embedding".to_string(), + } + }) + .collect(); + Ok(Json(CompatEmbedResponse { + embeddings: compat_embeddings, + })) + } else if let StringOrVec::String(input) = req.input { + let embed_req = EmbedRequest { + inputs: input.to_string(), + parameters: req.parameters, + }; + let response = infer.embed(embed_req).await?; + Ok(Json(CompatEmbedResponse { + embeddings: vec![CompatEmbedding { + index: 0, + embedding: response.embeddings, + object: "embedding".to_string(), + }], + })) + } else { + Err(( + StatusCode::BAD_REQUEST, + Json(ErrorResponse { + error: "Invalid input".to_string(), + error_type: "invalid_input".to_string(), + }), + )) + } } #[utoipa::path( @@ -1749,7 +1781,7 @@ async fn classify( post, tag = "ClassifyBatch", path = "/classify_batch", - request_body = TokenizeRequest, + request_body = BatchClassifyRequest, responses( (status = 200, description = "Classifications", body = BatchClassifyResponse), (status = 500, description = "Incomplete classification", body = ErrorResponse), From 896a34c16c828410c91d83009b5e87eeff9b09a4 Mon Sep 17 00:00:00 2001 From: Magdy Saleh Date: Tue, 3 Dec 2024 15:37:25 -0500 Subject: [PATCH 04/10] clean up warnings --- router/src/infer.rs | 41 +++++++++++++++++++++++++++-------------- 1 file changed, 27 insertions(+), 14 deletions(-) diff --git a/router/src/infer.rs b/router/src/infer.rs index b60adca34..d89090803 100644 --- a/router/src/infer.rs +++ b/router/src/infer.rs @@ -875,26 +875,36 @@ impl Infer { err })?; + let (adapter_source, adapter_parameters) = extract_adapter_params( + request.parameters.adapter_id.clone(), + request.parameters.adapter_source.clone(), + request.parameters.adapter_parameters.clone(), + ); + + let adapter_idx; + { + // TODO(travis): can optimize concurrency here using RWLock + let mut adapter_to_index = self.adapter_to_index.lock().await; + let adapter_key = adapter_parameters.clone(); + if adapter_to_index.contains_key(&adapter_key) { + adapter_idx = *adapter_to_index.get(&adapter_key).unwrap(); + } else { + adapter_idx = adapter_to_index.len() as u32; + adapter_to_index.insert(adapter_key, adapter_idx); + } + } + + let api_token = 
request.parameters.api_token.clone(); let adapter = Adapter::new( - AdapterParameters { - adapter_ids: vec![BASE_MODEL_ADAPTER_ID.to_string()], - ..Default::default() - }, - "hub".to_string(), - 0, - None, + adapter_parameters, + adapter_source.unwrap(), + adapter_idx, + api_token, ); // MPSC channel to communicate with the background batching task let (response_tx, response_rx) = mpsc::unbounded_channel(); - let request_id_map: HashMap = request - .inputs - .iter() - .enumerate() - .map(|(id, input)| (id as u64, input.clone())) - .collect(); - // Call validate_input on every input in the request and await the results let futures: Vec<_> = request .inputs @@ -1658,9 +1668,12 @@ pub(crate) enum InferStreamResponse { // Intermediate messages Token(Token), // Embeddings + // TODO: add tracing for embedding Embed { embedding: Embedding, + #[allow(dead_code)] start: Instant, + #[allow(dead_code)] queued: Instant, id: Option, // to support batching }, From 94687f3b3731f04dff093cc56dcb10667979308a Mon Sep 17 00:00:00 2001 From: Magdy Saleh Date: Tue, 3 Dec 2024 15:58:15 -0500 Subject: [PATCH 05/10] add classification tests --- integration-tests/test_classifications.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) create mode 100644 integration-tests/test_classifications.py diff --git a/integration-tests/test_classifications.py b/integration-tests/test_classifications.py new file mode 100644 index 000000000..a6e9c3716 --- /dev/null +++ b/integration-tests/test_classifications.py @@ -0,0 +1,18 @@ +import requests +from utils.docker_runner import run_lorax_container + + +def test_distilbert_ner(): + config = { + "name": "distilbert-ner", + "model_id": "dslim/distilbert-NER", + "docker_args": {"max_input_length": 512}, + } + with run_lorax_container(config): + response = requests.post( + "http://localhost:8080/classify", + json={"inputs": "Johnny supports the Golden State Warriors. 
He lives in London."},
+        )
+        response.raise_for_status()
+        print("RESPONSE FROM CLASSIFICATION: ", response.json())
+        assert len(response.json()["predictions"]) > 0

From 8e8591f9b2ab7624958fc5ce191825ff0869707b Mon Sep 17 00:00:00 2001
From: Magdy Saleh
Date: Tue, 3 Dec 2024 17:46:57 -0500
Subject: [PATCH 06/10] remove check

---
 router/src/main.rs | 7 -------
 1 file changed, 7 deletions(-)

diff --git a/router/src/main.rs b/router/src/main.rs
index 249031258..42d0c65a1 100644
--- a/router/src/main.rs
+++ b/router/src/main.rs
@@ -135,13 +135,6 @@ async fn main() -> Result<(), RouterError> {

     init_logging(otlp_endpoint, json_output);

-    // Validate args
-    if max_input_length >= max_total_tokens {
-        return Err(RouterError::ArgumentValidation(
-            "`max_input_length` must be < `max_total_tokens`".to_string(),
-        ));
-    }
-
     if validation_workers == 0 {
         return Err(RouterError::ArgumentValidation(
             "`validation_workers` must be > 0".to_string(),

From 8e10f6ed70d7b4b2ee7795885d0112f0a2f4338c Mon Sep 17 00:00:00 2001
From: Magdy Saleh
Date: Tue, 3 Dec 2024 17:49:19 -0500
Subject: [PATCH 07/10] add new tests

---
 .github/workflows/integration-tests/action.yaml | 6 ++++++
 1 file changed, 6 insertions(+)

diff --git a/.github/workflows/integration-tests/action.yaml b/.github/workflows/integration-tests/action.yaml
index c7492d46d..854efe238 100644
--- a/.github/workflows/integration-tests/action.yaml
+++ b/.github/workflows/integration-tests/action.yaml
@@ -66,6 +66,12 @@ runs:
         cd integration-tests
         pytest test_embeddings.py -vv --capture=tee-sys --log-cli-level=INFO

+    - name: Run Classification tests
+      shell: bash
+      run: |
+        cd integration-tests
+        pytest test_classifications.py -vv --capture=tee-sys --log-cli-level=INFO
+
     - name: Run LLM tests
       shell: bash
       run: |

From ffde1d17fe61a676639c276e0dd25200369fd153 Mon Sep 17 00:00:00 2001
From: Magdy Saleh
Date: Tue, 3 Dec 2024 18:12:06 -0500
Subject: [PATCH 08/10] remove check

---
 launcher/src/main.rs | 7 -------
 1 file changed, 7 deletions(-)

diff --git a/launcher/src/main.rs b/launcher/src/main.rs
index 53929a3cb..33116d167 100644
--- a/launcher/src/main.rs
+++ b/launcher/src/main.rs
@@ -1675,13 +1675,6 @@ fn main() -> Result<(), LauncherError> {
         }
     };

-    // Validate args
-    if max_input_tokens >= max_total_tokens {
-        return Err(LauncherError::ArgumentValidation(
-            "`max_input_length` must be < `max_total_tokens`".to_string(),
-        ));
-    }
-
     if args.validation_workers == 0 {
         return Err(LauncherError::ArgumentValidation(
             "`validation_workers` must be > 0".to_string(),

From 40db404087126a357884a6061efde971cc582078 Mon Sep 17 00:00:00 2001
From: Magdy Saleh
Date: Wed, 4 Dec 2024 10:24:49 -0500
Subject: [PATCH 09/10] fix test

---
 integration-tests/test_embeddings.py | 7 ++++++-
 1 file changed, 6 insertions(+), 1 deletion(-)

diff --git a/integration-tests/test_embeddings.py b/integration-tests/test_embeddings.py
index f34b4eb58..71723da52 100644
--- a/integration-tests/test_embeddings.py
+++ b/integration-tests/test_embeddings.py
@@ -19,7 +19,12 @@ def test_uae_large_v1_1_5b():
     config = {
         "name": "UAE-Large-V1-1.5b",
         "model_id": "WhereIsAI/UAE-Large-V1",
-        "docker_args": {"max_input_length": 512},
+        "docker_args": {
+            "max_input_length": 512,
+            "max_batch_prefill_tokens": 512,
+            "max_batch_total_tokens": 512,
+            "max_total_tokens": 512,
+        },
     }
     with run_lorax_container(config):
         response = requests.post("http://localhost:8080/embed", json={"inputs": "Hello, world!"})

From c9a25818a7872e0054ced0ae4ac682a16c64be41 Mon Sep 17 00:00:00 2001
From: Magdy Saleh
Date: Wed, 4 Dec 2024 16:11:23 -0500
Subject: [PATCH 10/10] fix second test

---
 integration-tests/test_classifications.py | 7 ++++++-
 1 file changed, 6 insertions(+), 1 deletion(-)

diff --git a/integration-tests/test_classifications.py b/integration-tests/test_classifications.py
index a6e9c3716..85b23285d 100644
--- a/integration-tests/test_classifications.py
+++ b/integration-tests/test_classifications.py
@@ -6,7 +6,12 @@ def test_distilbert_ner():
     config = {
         "name": "distilbert-ner",
         "model_id": "dslim/distilbert-NER",
-        "docker_args": {"max_input_length": 512},
+        "docker_args": {
+            "max_input_length": 512,
+            "max_batch_prefill_tokens": 512,
+            "max_batch_total_tokens": 512,
+            "max_total_tokens": 512,
+        },
     }
     with run_lorax_container(config):
         response = requests.post(
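
For reference, a minimal usage sketch of the OpenAI-compatible embeddings route these patches add (illustrative only, not part of the patch series): it assumes a LoRAX server is already listening on localhost:8080, as in the integration tests above, and that the request and response fields follow the CompatEmbedRequest and CompatEmbedding structs from PATCH 03/10 ("input" as a string or list of strings, a required "model" field, and a response whose "embeddings" entries carry "index", "embedding", and "object").

    import requests

    # Single input: handled by the StringOrVec::String branch of compat_embed.
    resp = requests.post(
        "http://localhost:8080/v1/embeddings",
        json={"input": "Hello, world!", "model": "WhereIsAI/UAE-Large-V1"},
    )
    resp.raise_for_status()
    print(len(resp.json()["embeddings"][0]["embedding"]))

    # Batch input: handled by the StringOrVec::Vec branch via Infer::embed_batch,
    # which returns one embedding per input, sorted back into request order.
    resp = requests.post(
        "http://localhost:8080/v1/embeddings",
        json={"input": ["first sentence", "second sentence"], "model": "WhereIsAI/UAE-Large-V1"},
    )
    resp.raise_for_status()
    for item in resp.json()["embeddings"]:
        print(item["index"], item["object"], len(item["embedding"]))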