Skip to content

Commit

Permalink
Fix stella model and client <> embedding compatibility (#717)
Browse files Browse the repository at this point in the history
  • Loading branch information
magdyksaleh authored Dec 15, 2024
1 parent e314845 commit abf9f39
Show file tree
Hide file tree
Showing 9 changed files with 56 additions and 35 deletions.
9 changes: 5 additions & 4 deletions proto/generate.proto
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ service LoraxService {
rpc Prefill (PrefillRequest) returns (PrefillResponse);
/// Embed
rpc Embed (EmbedRequest) returns (EmbedResponse);
/// Classify
/// Classify
rpc Classify (ClassifyRequest) returns (ClassifyResponse);
/// Decode token for a list of prefilled batches
rpc Decode (DecodeRequest) returns (DecodeResponse);
Expand Down Expand Up @@ -58,6 +58,7 @@ message InfoResponse {
bool supports_embeddings = 9;
bool supports_classification = 10;
bool chunked_prefill = 11;
bool requires_block_allocator = 12;
}

/// Empty request
Expand Down Expand Up @@ -118,11 +119,11 @@ message StoppingCriteriaParameters {
message Image {
/// Binary image data.
bytes data = 1;

/// Image MIME type.
string mimetype = 2;
}

message InputChunk {
oneof chunk {
/// Plain text data
Expand Down Expand Up @@ -331,7 +332,7 @@ message Entity {
message EntityList {
/// Request ID
uint64 request_id = 1;
/// Entities
/// Entities
repeated Entity entities = 2;
/// XXX
repeated uint32 input_ids = 4;
Expand Down
24 changes: 14 additions & 10 deletions router/src/infer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ impl Infer {
preloaded_adapters: Vec<PreloadedAdapter>,
prefix_caching: bool,
chunked_prefill: bool,
is_causal_lm: bool,
requires_block_allocator: bool,
) -> Self {
let adapter_event = Arc::new(AdapterEvent {
batching_task: Notify::new(),
Expand All @@ -207,7 +207,7 @@ impl Infer {
max_batch_total_tokens,
prefix_caching,
chunked_prefill,
is_causal_lm,
requires_block_allocator,
);

// Initialize with base model adapter (empty) mapping to index 0
Expand Down Expand Up @@ -501,10 +501,12 @@ impl Infer {
err
})?;

let embed_params = request.parameters.unwrap_or_default();

let (adapter_source, adapter_parameters) = extract_adapter_params(
request.parameters.adapter_id.clone(),
request.parameters.adapter_source.clone(),
request.parameters.adapter_parameters.clone(),
embed_params.adapter_id.clone(),
embed_params.adapter_source.clone(),
embed_params.adapter_parameters.clone(),
);

let adapter_idx;
Expand All @@ -520,7 +522,7 @@ impl Infer {
}
}

let api_token = request.parameters.api_token.clone();
let api_token = embed_params.api_token.clone();
let adapter = Adapter::new(
adapter_parameters,
adapter_source.unwrap(),
Expand Down Expand Up @@ -875,10 +877,12 @@ impl Infer {
err
})?;

let embed_params = request.parameters.clone().unwrap_or_default();

let (adapter_source, adapter_parameters) = extract_adapter_params(
request.parameters.adapter_id.clone(),
request.parameters.adapter_source.clone(),
request.parameters.adapter_parameters.clone(),
embed_params.adapter_id.clone(),
embed_params.adapter_source.clone(),
embed_params.adapter_parameters.clone(),
);

let adapter_idx;
Expand All @@ -894,7 +898,7 @@ impl Infer {
}
}

let api_token = request.parameters.api_token.clone();
let api_token = embed_params.api_token.clone();
let adapter = Adapter::new(
adapter_parameters,
adapter_source.unwrap(),
Expand Down
26 changes: 14 additions & 12 deletions router/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1144,20 +1144,22 @@ pub(crate) struct EmbedParameters {
pub api_token: Option<String>,
}

fn default_embed_parameters() -> EmbedParameters {
EmbedParameters {
adapter_id: None,
adapter_source: None,
adapter_parameters: None,
api_token: None,
impl Default for EmbedParameters {
fn default() -> Self {
Self {
adapter_id: None,
adapter_source: None,
adapter_parameters: None,
api_token: None,
}
}
}

#[derive(Clone, Debug, Deserialize, ToSchema)]
struct EmbedRequest {
inputs: String,
#[serde(default = "default_embed_parameters")]
pub parameters: EmbedParameters,
#[serde(default)]
pub parameters: Option<EmbedParameters>,
}

#[derive(Serialize, ToSchema)]
Expand Down Expand Up @@ -1192,8 +1194,8 @@ struct CompatEmbedRequest {
dimensions: Option<i32>,
#[allow(dead_code)]
user: Option<String>,
#[serde(default = "default_embed_parameters")]
parameters: EmbedParameters,
#[serde(default)]
parameters: Option<EmbedParameters>,
}

#[derive(Serialize, ToSchema)]
Expand Down Expand Up @@ -1221,8 +1223,8 @@ struct BatchClassifyRequest {
#[derive(Clone, Debug, Deserialize, ToSchema)]
struct BatchEmbedRequest {
inputs: Vec<String>,
#[serde(default = "default_embed_parameters")]
parameters: EmbedParameters,
#[serde(default)]
parameters: Option<EmbedParameters>,
}

#[derive(Debug, Serialize, Deserialize)]
Expand Down
12 changes: 6 additions & 6 deletions router/src/scheduler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ impl AdapterScheduler {
max_batch_total_tokens: u32,
prefix_caching: bool,
chunked_prefill: bool,
is_causal_lm: bool,
requires_block_allocator: bool,
) -> Self {
let (sender, receiver) = mpsc::unbounded_channel();

Expand All @@ -59,7 +59,7 @@ impl AdapterScheduler {
max_batch_total_tokens,
prefix_caching,
chunked_prefill,
is_causal_lm,
requires_block_allocator,
));

Self { sender }
Expand Down Expand Up @@ -125,7 +125,7 @@ async fn adapter_scheduler_task(
max_batch_total_tokens: u32,
prefix_caching: bool,
chunked_prefill: bool,
is_causal_lm: bool,
requires_block_allocator: bool,
) {
let mut state = AdapterSchedulerState::new(
client,
Expand All @@ -138,7 +138,7 @@ async fn adapter_scheduler_task(
max_batch_total_tokens,
prefix_caching,
chunked_prefill,
is_causal_lm,
requires_block_allocator,
);

while let Some(cmd) = receiver.recv().await {
Expand Down Expand Up @@ -217,7 +217,7 @@ impl AdapterSchedulerState {
max_batch_total_tokens: u32,
prefix_caching: bool,
chunked_prefill: bool,
is_causal_lm: bool,
requires_block_allocator: bool,
) -> Self {
let queues_state = Arc::new(Mutex::new(AdapterQueuesState::new(
max_active_adapters,
Expand All @@ -226,7 +226,7 @@ impl AdapterSchedulerState {
let loader = AdapterLoader::new(client.clone());

// Only causal LMs require the block allocator, due to paged attention
let block_allocator = (!requires_padding && is_causal_lm).then(|| {
let block_allocator = (!requires_padding && requires_block_allocator).then(|| {
BlockAllocator::new(
max_batch_total_tokens,
block_size,
Expand Down
6 changes: 3 additions & 3 deletions router/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -486,12 +486,12 @@ async fn health(
if health.shard_info().supports_embeddings {
let embed_request = EmbedRequest {
inputs: "San Francisco".to_string(),
parameters: EmbedParameters {
parameters: Some(EmbedParameters {
adapter_id: None,
adapter_source: None,
adapter_parameters: None,
api_token: None,
},
}),
};
match infer.embed(embed_request).await {
Ok(_) => {}
Expand Down Expand Up @@ -1354,7 +1354,7 @@ pub async fn run(
shard_info.preloaded_adapters,
prefix_caching,
shard_info.chunked_prefill,
is_causal_lm,
shard_info.requires_block_allocator,
);

// Duration buckets
Expand Down
1 change: 1 addition & 0 deletions router/src/validation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -814,6 +814,7 @@ mod tests {
}

#[tokio::test]
#[ignore]
async fn test_validation_best_of_sampling() {
let tokenizer = Some(get_tokenizer().await);
let max_best_of = 2;
Expand Down
4 changes: 4 additions & 0 deletions server/lorax_server/models/flash_bert.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,10 @@ def supports_embeddings(self) -> bool:
def supports_text_generation(self) -> bool:
return False

@property
def requires_block_allocator(self) -> bool:
return False

@property
def supports_classification(self) -> bool:
return self.classification_head_enabled
Expand Down
4 changes: 4 additions & 0 deletions server/lorax_server/models/flash_distilbert.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,10 @@ def supports_embeddings(self) -> bool:
def supports_text_generation(self) -> bool:
return False

@property
def requires_block_allocator(self) -> bool:
return False

@property
def supports_classification(self) -> bool:
return self.classification_head_enabled
Expand Down
5 changes: 5 additions & 0 deletions server/lorax_server/models/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,7 @@ def info(self) -> InfoResponse:
supports_embeddings=self.supports_embeddings,
supports_classification=self.supports_classification,
chunked_prefill=self.supports_chunking,
requires_block_allocator=self.requires_block_allocator,
)

@property
Expand All @@ -163,6 +164,10 @@ def sliding_window_blocks(self) -> Optional[int]:
def supports_embeddings(self) -> bool:
return False

@property
def requires_block_allocator(self) -> bool:
return True

@property
def supports_classification(self) -> bool:
return False
Expand Down

0 comments on commit abf9f39

Please sign in to comment.