diff --git a/router/client/src/client.rs b/router/client/src/client.rs index ee68c6d96..6b7157885 100644 --- a/router/client/src/client.rs +++ b/router/client/src/client.rs @@ -64,6 +64,14 @@ impl Client { Ok(response) } + /// Embed + #[instrument(skip(self))] + pub async fn embed(&mut self, inputs: String) -> Result { + let request = tonic::Request::new(EmbedRequest { inputs }).inject_context(); + let response = self.stub.embed(request).await?.into_inner(); + Ok(response) + } + /// Get model health #[instrument(skip(self))] pub async fn health(&mut self) -> Result { diff --git a/router/client/src/sharded_client.rs b/router/client/src/sharded_client.rs index 7427114f4..1a6f885fb 100644 --- a/router/client/src/sharded_client.rs +++ b/router/client/src/sharded_client.rs @@ -1,3 +1,4 @@ +use crate::pb::generate::v1::EmbedResponse; /// Multi shard Client use crate::{ AdapterParameters, Batch, CachedBatch, Client, DownloadAdapterResponse, Generation, @@ -153,6 +154,17 @@ impl ShardedClient { merge_generations(results?) 
} + /// Embed the inputs on all shards + #[instrument(skip(self))] + pub async fn embed(&mut self, inputs: String) -> Result> { + let futures: Vec<_> = self + .clients + .iter_mut() + .map(|client| Box::pin(client.embed(inputs.clone()))) + .collect(); + join_all(futures).await.into_iter().collect() + } + pub async fn download_adapter( + &mut self, + adapter_parameters: AdapterParameters, diff --git a/router/src/lib.rs b/router/src/lib.rs index 4d820aa50..db90144fc 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -629,10 +629,12 @@ pub(crate) enum CompletionFinishReason { ToolCalls, } +#[derive(Clone, Debug, Deserialize, ToSchema)] struct EmbedRequest { inputs: String, } +#[derive(Serialize, ToSchema)] struct EmbedResponse { embeddings: Vec, } diff --git a/router/src/server.rs b/router/src/server.rs index 91ceeaebd..6efa1f1ff 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -977,7 +977,7 @@ pub async fn run( let generation_health = Arc::new(AtomicBool::new(false)); let health_ext = Health::new(client.clone(), generation_health.clone()); let infer = Infer::new( - client, + client.clone(), validation, waiting_served_ratio, max_batch_prefill_tokens, @@ -1108,6 +1108,7 @@ pub async fn run( .route("/", post(compat_generate)) .route("/info", get(get_model_info)) .route("/generate", post(generate)) + .route("/embed", post(embed)) .route("/generate_stream", post(generate_stream)) .route("/v1/completions", post(completions_v1)) .route("/v1/chat/completions", post(chat_completions_v1)) @@ -1123,6 +1124,7 @@ pub async fn run( .route("/metrics", get(metrics)) .route("/tokenize", post(tokenize)) .layer(Extension(info)) + .layer(Extension(client.clone())) .layer(Extension(request_logger_sender.clone())) .layer(Extension(health_ext.clone())) .layer(Extension(compat_return_full_text)) @@ -1279,10 +1281,21 @@ impl From for Event { )] #[instrument(skip_all)] async fn embed( + mut client: Extension, Json(req): Json, ) -> Result, (StatusCode, Json)> { - tracing::debug!("Input: 
{}", req.inputs); - Ok(Json(EmbedResponse { embeddings: vec![] })) + tracing::info!("Input: {}", req.inputs); + let input = req.inputs.clone(); + let embeddings = client.embed(input).await.unwrap(); + // initialize the values array with the first embedding + let values = embeddings + .get(0) + .map(|emb| emb.embeddings.as_ref().map(|emb| emb.values.clone())) + .flatten() + .unwrap_or_default(); + Ok(Json(EmbedResponse { + embeddings: values.clone(), + })) } /// Tokenize inputs diff --git a/server/lorax_server/models/__init__.py b/server/lorax_server/models/__init__.py index 1cbd4abd3..121cc3843 100644 --- a/server/lorax_server/models/__init__.py +++ b/server/lorax_server/models/__init__.py @@ -6,7 +6,6 @@ from lorax_server.models.bloom import BLOOMSharded from lorax_server.models.causal_lm import CausalLM -from lorax_server.models.flash_bert import FlashBert from lorax_server.models.flash_causal_lm import FlashCausalLM from lorax_server.models.galactica import GalacticaSharded from lorax_server.models.model import Model @@ -94,6 +93,7 @@ def get_model( ) if "WhereIsAI/UAE-Large-V1" in model_id: + from lorax_server.models.flash_bert import FlashBert return FlashBert(model_id, revision=revision, dtype=dtype) if model_id.startswith("bigcode/") or model_type == "gpt_bigcode": diff --git a/server/lorax_server/models/flash_bert.py b/server/lorax_server/models/flash_bert.py index 001cd584a..ef711ea3b 100644 --- a/server/lorax_server/models/flash_bert.py +++ b/server/lorax_server/models/flash_bert.py @@ -8,12 +8,11 @@ from transformers import AutoTokenizer from opentelemetry import trace -from dataclasses import dataclass -from abc import ABC from lorax_server.utils.layers import FastLayerNorm from lorax_server.utils.flash_attn import attention from lorax_server.models import Model +from lorax_server.models.types import FlashBatch from lorax_server.pb.generate_pb2 import Embedding from lorax_server.utils import ( @@ -25,22 +24,6 @@ tracer = trace.get_tracer(__name__) 
-@dataclass -class FlashBatch(ABC): - input_ids: torch.Tensor - token_type_ids: torch.Tensor - position_ids: torch.Tensor - - cu_seqlens: torch.Tensor - max_s: int - size: int - - def __len__(self): - return self.size - - def from_pb(self, *args, **kwargs): - return None - class BertEmbeddings: def __init__(self, prefix, weights, device, dtype, config: BertConfig): diff --git a/server/lorax_server/models/types.py b/server/lorax_server/models/types.py index 497c65f79..3bf1e961a 100644 --- a/server/lorax_server/models/types.py +++ b/server/lorax_server/models/types.py @@ -124,3 +124,19 @@ def to_pb(self) -> generate_pb2.Generation: next_tokens=self.next_tokens.to_pb(), generated_text=self.generated_text.to_pb() if self.generated_text is not None else None, ) + +@dataclass +class FlashBatch(ABC): + input_ids: torch.Tensor + token_type_ids: torch.Tensor + position_ids: torch.Tensor + + cu_seqlens: torch.Tensor + max_s: int + size: int + + def __len__(self): + return self.size + + def from_pb(self, *args, **kwargs): + return None \ No newline at end of file