Skip to content

Commit

Permalink
Add working /embed endpoint: wire ShardedClient.embed through the router and move FlashBatch into models/types.py
Browse files Browse the repository at this point in the history
  • Loading branch information
magdyksaleh committed Apr 4, 2024
1 parent b5ae3e8 commit 16a8113
Show file tree
Hide file tree
Showing 7 changed files with 56 additions and 22 deletions.
8 changes: 8 additions & 0 deletions router/client/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,14 @@ impl Client {
Ok(response)
}

/// Compute embeddings for `inputs` on this shard.
///
/// Sends an `EmbedRequest` over the gRPC stub (with tracing context
/// injected) and returns the shard's `EmbedResponse`.
#[instrument(skip(self))]
pub async fn embed(&mut self, inputs: String) -> Result<EmbedResponse> {
    let req = tonic::Request::new(EmbedRequest { inputs }).inject_context();
    Ok(self.stub.embed(req).await?.into_inner())
}

/// Get model health
#[instrument(skip(self))]
pub async fn health(&mut self) -> Result<HealthResponse> {
Expand Down
12 changes: 12 additions & 0 deletions router/client/src/sharded_client.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use crate::pb::generate::v1::EmbedResponse;
/// Multi shard Client
use crate::{
AdapterParameters, Batch, CachedBatch, Client, DownloadAdapterResponse, Generation,
Expand Down Expand Up @@ -153,6 +154,17 @@ impl ShardedClient {
merge_generations(results?)
}

/// Embed `inputs` on every shard concurrently and collect one
/// `EmbedResponse` per shard (fails fast if any shard errors).
#[instrument(skip(self))]
pub async fn embed(&mut self, inputs: String) -> Result<Vec<EmbedResponse>> {
    let futures: Vec<_> = self
        .clients
        .iter_mut()
        .map(|client| Box::pin(client.embed(inputs.clone())))
        .collect();
    join_all(futures).await.into_iter().collect()
}

pub async fn download_adapter(
&mut self,
adapter_parameters: AdapterParameters,
Expand Down
2 changes: 2 additions & 0 deletions router/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -629,10 +629,12 @@ pub(crate) enum CompletionFinishReason {
ToolCalls,
}

/// JSON request body accepted by the `/embed` route.
#[derive(Clone, Debug, Deserialize, ToSchema)]
struct EmbedRequest {
    /// Raw input text to embed.
    inputs: String,
}

/// JSON response returned by the `/embed` route.
#[derive(Serialize, ToSchema)]
struct EmbedResponse {
    /// Embedding vector for the input text.
    embeddings: Vec<f32>,
}
Expand Down
19 changes: 16 additions & 3 deletions router/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -977,7 +977,7 @@ pub async fn run(
let generation_health = Arc::new(AtomicBool::new(false));
let health_ext = Health::new(client.clone(), generation_health.clone());
let infer = Infer::new(
client,
client.clone(),
validation,
waiting_served_ratio,
max_batch_prefill_tokens,
Expand Down Expand Up @@ -1108,6 +1108,7 @@ pub async fn run(
.route("/", post(compat_generate))
.route("/info", get(get_model_info))
.route("/generate", post(generate))
.route("/embed", post(embed))
.route("/generate_stream", post(generate_stream))
.route("/v1/completions", post(completions_v1))
.route("/v1/chat/completions", post(chat_completions_v1))
Expand All @@ -1123,6 +1124,7 @@ pub async fn run(
.route("/metrics", get(metrics))
.route("/tokenize", post(tokenize))
.layer(Extension(info))
.layer(Extension(client.clone()))
.layer(Extension(request_logger_sender.clone()))
.layer(Extension(health_ext.clone()))
.layer(Extension(compat_return_full_text))
Expand Down Expand Up @@ -1279,10 +1281,21 @@ impl From<InferError> for Event {
)]
#[instrument(skip_all)]
async fn embed(
mut client: Extension<ShardedClient>,
Json(req): Json<EmbedRequest>,
) -> Result<Json<EmbedResponse>, (StatusCode, Json<ErrorResponse>)> {
tracing::debug!("Input: {}", req.inputs);
Ok(Json(EmbedResponse { embeddings: vec![] }))
tracing::info!("Input: {}", req.inputs);
let input = req.inputs.clone();
let embeddings = client.embed(input).await.unwrap();
// initialize the values array with the first embedding
let values = embeddings
.get(0)
.map(|emb| emb.embeddings.as_ref().map(|emb| emb.values.clone()))
.flatten()
.unwrap_or_default();
Ok(Json(EmbedResponse {
embeddings: values.clone(),
}))
}

/// Tokenize inputs
Expand Down
2 changes: 1 addition & 1 deletion server/lorax_server/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@

from lorax_server.models.bloom import BLOOMSharded
from lorax_server.models.causal_lm import CausalLM
from lorax_server.models.flash_bert import FlashBert
from lorax_server.models.flash_causal_lm import FlashCausalLM
from lorax_server.models.galactica import GalacticaSharded
from lorax_server.models.model import Model
Expand Down Expand Up @@ -94,6 +93,7 @@ def get_model(
)

if "WhereIsAI/UAE-Large-V1" in model_id:
from lorax_server.models.flash_bert import FlashBert
return FlashBert(model_id, revision=revision, dtype=dtype)

if model_id.startswith("bigcode/") or model_type == "gpt_bigcode":
Expand Down
19 changes: 1 addition & 18 deletions server/lorax_server/models/flash_bert.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,11 @@
from transformers import AutoTokenizer
from opentelemetry import trace

from dataclasses import dataclass
from abc import ABC

from lorax_server.utils.layers import FastLayerNorm
from lorax_server.utils.flash_attn import attention
from lorax_server.models import Model
from lorax_server.models.types import FlashBatch
from lorax_server.pb.generate_pb2 import Embedding

from lorax_server.utils import (
Expand All @@ -25,22 +24,6 @@
tracer = trace.get_tracer(__name__)


@dataclass
class FlashBatch(ABC):
input_ids: torch.Tensor
token_type_ids: torch.Tensor
position_ids: torch.Tensor

cu_seqlens: torch.Tensor
max_s: int
size: int

def __len__(self):
return self.size

def from_pb(self, *args, **kwargs):
return None


class BertEmbeddings:
def __init__(self, prefix, weights, device, dtype, config: BertConfig):
Expand Down
16 changes: 16 additions & 0 deletions server/lorax_server/models/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,3 +124,19 @@ def to_pb(self) -> generate_pb2.Generation:
next_tokens=self.next_tokens.to_pb(),
generated_text=self.generated_text.to_pb() if self.generated_text is not None else None,
)

@dataclass
class FlashBatch(ABC):
    """Packed batch of tokenized inputs for a flash-attention embedding model.

    Unlike generation batches, this carries no decoding/cache state: it is a
    flat, concatenated view of every sequence in the batch, delimited by
    ``cu_seqlens`` (the layout flash-attention kernels expect).
    """

    # Token ids for all sequences, concatenated into one flat tensor.
    input_ids: torch.Tensor
    # Segment/token-type ids aligned with input_ids (BERT-style models).
    token_type_ids: torch.Tensor
    # Per-token position ids aligned with input_ids.
    position_ids: torch.Tensor

    # Cumulative sequence lengths — presumably the flash-attn "cu_seqlens"
    # boundary array; TODO(review): confirm against the kernel call site.
    cu_seqlens: torch.Tensor
    # Length of the longest sequence in the batch.
    max_s: int
    # Number of sequences in the batch.
    size: int

    def __len__(self):
        # Batch length is the number of sequences, not the token count.
        return self.size

    def from_pb(self, *args, **kwargs):
        # Protobuf deserialization is not supported for embedding batches;
        # callers get None rather than an exception.
        return None

0 comments on commit 16a8113

Please sign in to comment.