Skip to content

Commit

Permalink
fix: add cls pooling as default for BERT variants (#426)
Browse files Browse the repository at this point in the history
  • Loading branch information
OlivierDehaene authored Oct 17, 2024
1 parent 205f96c commit 750898d
Showing 1 changed file with 25 additions and 7 deletions.
32 changes: 25 additions & 7 deletions router/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -91,13 +91,20 @@ pub async fn run(
// Optionally download the pooling config.
if pooling.is_none() {
// If a pooling config exist, download it
let _ = download_pool_config(&api_repo).await;
let _ = download_pool_config(&api_repo).await.map_err(|err| {
tracing::warn!("Download failed: {err}");
err
});
}

// Download sentence transformers config
// Download legacy sentence transformers config
// We don't warn on failure as it is a legacy file
let _ = download_st_config(&api_repo).await;
// Download new sentence transformers config
let _ = download_new_st_config(&api_repo).await;
let _ = download_new_st_config(&api_repo).await.map_err(|err| {
tracing::warn!("Download failed: {err}");
err
});

// Download model from the Hub
download_artifacts(&api_repo)
Expand Down Expand Up @@ -387,10 +394,21 @@ fn get_backend_model_type(
None => {
// Load pooling config
let config_path = model_root.join("1_Pooling/config.json");
let config = fs::read_to_string(config_path).context("The `--pooling` arg is not set and we could not find a pooling configuration (`1_Pooling/config.json`) for this model.")?;
let config: PoolConfig =
serde_json::from_str(&config).context("Failed to parse `1_Pooling/config.json`")?;
Pool::try_from(config)?

match fs::read_to_string(config_path) {
Ok(config) => {
let config: PoolConfig = serde_json::from_str(&config)
.context("Failed to parse `1_Pooling/config.json`")?;
Pool::try_from(config)?
}
Err(err) => {
if !config.model_type.to_lowercase().contains("bert") {
return Err(err).context("The `--pooling` arg is not set and we could not find a pooling configuration (`1_Pooling/config.json`) for this model.");
}
tracing::warn!("The `--pooling` arg is not set and we could not find a pooling configuration (`1_Pooling/config.json`) for this model but the model is a BERT variant. Defaulting to `CLS` pooling.");
text_embeddings_backend::Pool::Cls
}
}
}
};
Ok(text_embeddings_backend::ModelType::Embedding(pool))
Expand Down

0 comments on commit 750898d

Please sign in to comment.