diff --git a/proto/generate.proto b/proto/generate.proto index 916a1716b..f91812fc9 100644 --- a/proto/generate.proto +++ b/proto/generate.proto @@ -17,7 +17,7 @@ service LoraxService { rpc Prefill (PrefillRequest) returns (PrefillResponse); /// Embed rpc Embed (EmbedRequest) returns (EmbedResponse); - /// Classify + /// Classify rpc Classify (ClassifyRequest) returns (ClassifyResponse); /// Decode token for a list of prefilled batches rpc Decode (DecodeRequest) returns (DecodeResponse); @@ -58,6 +58,7 @@ message InfoResponse { bool supports_embeddings = 9; bool supports_classification = 10; bool chunked_prefill = 11; + bool requires_block_allocator = 12; } /// Empty request @@ -118,11 +119,11 @@ message StoppingCriteriaParameters { message Image { /// Binary image data. bytes data = 1; - + /// Image MIME type. string mimetype = 2; } - + message InputChunk { oneof chunk { /// Plain text data @@ -331,7 +332,7 @@ message Entity { message EntityList { /// Request ID uint64 request_id = 1; - /// Entities + /// Entities repeated Entity entities = 2; /// XXX repeated uint32 input_ids = 4; diff --git a/router/src/infer.rs b/router/src/infer.rs index 2833cd485..a142793d8 100644 --- a/router/src/infer.rs +++ b/router/src/infer.rs @@ -188,7 +188,7 @@ impl Infer { preloaded_adapters: Vec, prefix_caching: bool, chunked_prefill: bool, - is_causal_lm: bool, + requires_block_allocator: bool, ) -> Self { let adapter_event = Arc::new(AdapterEvent { batching_task: Notify::new(), @@ -207,7 +207,7 @@ impl Infer { max_batch_total_tokens, prefix_caching, chunked_prefill, - is_causal_lm, + requires_block_allocator, ); // Initialize with base model adapter (empty) mapping to index 0 @@ -501,10 +501,12 @@ impl Infer { err })?; + let embed_params = request.parameters.unwrap_or_default(); + let (adapter_source, adapter_parameters) = extract_adapter_params( - request.parameters.adapter_id.clone(), - request.parameters.adapter_source.clone(), - request.parameters.adapter_parameters.clone(), + embed_params.adapter_id.clone(), + embed_params.adapter_source.clone(), + embed_params.adapter_parameters.clone(), ); let adapter_idx; @@ -520,7 +522,7 @@ impl Infer { } } - let api_token = request.parameters.api_token.clone(); + let api_token = embed_params.api_token.clone(); let adapter = Adapter::new( adapter_parameters, adapter_source.unwrap(), @@ -875,10 +877,12 @@ impl Infer { err })?; + let embed_params = request.parameters.clone().unwrap_or_default(); + let (adapter_source, adapter_parameters) = extract_adapter_params( - request.parameters.adapter_id.clone(), - request.parameters.adapter_source.clone(), - request.parameters.adapter_parameters.clone(), + embed_params.adapter_id.clone(), + embed_params.adapter_source.clone(), + embed_params.adapter_parameters.clone(), ); let adapter_idx; @@ -894,7 +898,7 @@ impl Infer { } } - let api_token = request.parameters.api_token.clone(); + let api_token = embed_params.api_token.clone(); let adapter = Adapter::new( adapter_parameters, adapter_source.unwrap(), diff --git a/router/src/lib.rs b/router/src/lib.rs index fd5b1049f..c3cf2cedc 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -1144,20 +1144,22 @@ pub(crate) struct EmbedParameters { pub api_token: Option, } -fn default_embed_parameters() -> EmbedParameters { - EmbedParameters { - adapter_id: None, - adapter_source: None, - adapter_parameters: None, - api_token: None, +impl Default for EmbedParameters { + fn default() -> Self { + Self { + adapter_id: None, + adapter_source: None, + adapter_parameters: None, + api_token: None, + } } } #[derive(Clone, Debug, Deserialize, ToSchema)] struct EmbedRequest { inputs: String, - #[serde(default = "default_embed_parameters")] - pub parameters: EmbedParameters, + #[serde(default)] + pub parameters: Option, } #[derive(Serialize, ToSchema)] @@ -1192,8 +1194,8 @@ struct CompatEmbedRequest { dimensions: Option, #[allow(dead_code)] user: Option, - #[serde(default = "default_embed_parameters")] - parameters: EmbedParameters, + #[serde(default)] + parameters: Option, } #[derive(Serialize, ToSchema)] @@ -1221,8 +1223,8 @@ struct BatchClassifyRequest { #[derive(Clone, Debug, Deserialize, ToSchema)] struct BatchEmbedRequest { inputs: Vec, - #[serde(default = "default_embed_parameters")] - parameters: EmbedParameters, + #[serde(default)] + parameters: Option, } #[derive(Debug, Serialize, Deserialize)] diff --git a/router/src/scheduler.rs b/router/src/scheduler.rs index 9b2514623..cca4b05a1 100644 --- a/router/src/scheduler.rs +++ b/router/src/scheduler.rs @@ -41,7 +41,7 @@ impl AdapterScheduler { max_batch_total_tokens: u32, prefix_caching: bool, chunked_prefill: bool, - is_causal_lm: bool, + requires_block_allocator: bool, ) -> Self { let (sender, receiver) = mpsc::unbounded_channel(); @@ -59,7 +59,7 @@ impl AdapterScheduler { max_batch_total_tokens, prefix_caching, chunked_prefill, - is_causal_lm, + requires_block_allocator, )); Self { sender } @@ -125,7 +125,7 @@ async fn adapter_scheduler_task( max_batch_total_tokens: u32, prefix_caching: bool, chunked_prefill: bool, - is_causal_lm: bool, + requires_block_allocator: bool, ) { let mut state = AdapterSchedulerState::new( client, @@ -138,7 +138,7 @@ async fn adapter_scheduler_task( max_batch_total_tokens, prefix_caching, chunked_prefill, - is_causal_lm, + requires_block_allocator, ); while let Some(cmd) = receiver.recv().await { @@ -217,7 +217,7 @@ impl AdapterSchedulerState { max_batch_total_tokens: u32, prefix_caching: bool, chunked_prefill: bool, - is_causal_lm: bool, + requires_block_allocator: bool, ) -> Self { let queues_state = Arc::new(Mutex::new(AdapterQueuesState::new( max_active_adapters, @@ -226,7 +226,7 @@ impl AdapterSchedulerState { let loader = AdapterLoader::new(client.clone()); // Only causal LMs require the block allocator, due to paged attention - let block_allocator = (!requires_padding && is_causal_lm).then(|| { + let block_allocator = (!requires_padding && requires_block_allocator).then(|| { BlockAllocator::new( max_batch_total_tokens, block_size, diff --git a/router/src/server.rs b/router/src/server.rs index 28d56efd8..405bcff84 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -486,12 +486,12 @@ async fn health( if health.shard_info().supports_embeddings { let embed_request = EmbedRequest { inputs: "San Francisco".to_string(), - parameters: EmbedParameters { + parameters: Some(EmbedParameters { adapter_id: None, adapter_source: None, adapter_parameters: None, api_token: None, - }, + }), }; match infer.embed(embed_request).await { Ok(_) => {} @@ -1354,7 +1354,7 @@ pub async fn run( shard_info.preloaded_adapters, prefix_caching, shard_info.chunked_prefill, - is_causal_lm, + shard_info.requires_block_allocator, ); // Duration buckets diff --git a/router/src/validation.rs b/router/src/validation.rs index 34c564541..e4194a1fb 100644 --- a/router/src/validation.rs +++ b/router/src/validation.rs @@ -814,6 +814,7 @@ mod tests { } #[tokio::test] + #[ignore] async fn test_validation_best_of_sampling() { let tokenizer = Some(get_tokenizer().await); let max_best_of = 2; diff --git a/server/lorax_server/models/flash_bert.py b/server/lorax_server/models/flash_bert.py index 8c92085b6..e1aa6e30a 100644 --- a/server/lorax_server/models/flash_bert.py +++ b/server/lorax_server/models/flash_bert.py @@ -143,6 +143,10 @@ def supports_embeddings(self) -> bool: def supports_text_generation(self) -> bool: return False + @property + def requires_block_allocator(self) -> bool: + return False + @property def supports_classification(self) -> bool: return self.classification_head_enabled diff --git a/server/lorax_server/models/flash_distilbert.py b/server/lorax_server/models/flash_distilbert.py index 9d7fdcba4..3f3e4742b 100644 --- a/server/lorax_server/models/flash_distilbert.py +++ b/server/lorax_server/models/flash_distilbert.py @@ -124,6 +124,10 @@ def supports_embeddings(self) -> bool: def supports_text_generation(self) -> bool: return False + @property + def requires_block_allocator(self) -> bool: + return False + @property def supports_classification(self) -> bool: return self.classification_head_enabled diff --git a/server/lorax_server/models/model.py b/server/lorax_server/models/model.py index be80b3bdd..696dc2e80 100644 --- a/server/lorax_server/models/model.py +++ b/server/lorax_server/models/model.py @@ -147,6 +147,7 @@ def info(self) -> InfoResponse: supports_embeddings=self.supports_embeddings, supports_classification=self.supports_classification, chunked_prefill=self.supports_chunking, + requires_block_allocator=self.requires_block_allocator, ) @property @@ -163,6 +164,10 @@ def sliding_window_blocks(self) -> Optional[int]: def supports_embeddings(self) -> bool: return False + @property + def requires_block_allocator(self) -> bool: + return True + @property def supports_classification(self) -> bool: return False