
Commit

[feat] Integrate NanoBeIR datasets; use model.similarity by default in evaluators (#2966)

* Added the possibility of masking the prompts if the tokenizer is left-padded.

* Simplify code

* Remove unrelated changes

* Add separate query and corpus prompts for IREvaluator

* Add query and corpus prompt_name

* Added NanoBEIREvaluator

* Rename, example and better logging

* Fix for all datasets

* Remove unrelated changes

* Remove unrelated changes

* Remove unrelated changes

* Remove unrelated changes

* Remove wrong function call to InformationRetrievalEvaluator

* Fix issue introduced in merge

* Flatten output dict, remove 'name' as we already know the dataset names

* Use the model similarity function by default for evaluators (see the usage sketch after these notes)

- Fix 'tokens' typo -> 'dimension' in model card
- Group multiple evaluators with the same output keys together.
- Fix edge case where datasets without languages are excluded in model card
- Truncate really really long texts in model card
- Make default similarity_fn_name "cosine" rather than None

* Update tests due to similarity_fn_name defaulting to "cosine" now

* Specify all similarity_fn_names to remain backwards compatible with the old expected performance

* Fix loading the similarity fn from a config

And update 'str' type to Literals
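
Putting the headline changes together, a minimal sketch of how they are meant to be used; the NanoBEIREvaluator constructor argument (dataset_names) and the dataset identifiers shown are assumptions based on this commit's description rather than a confirmed API:

from sentence_transformers import SentenceTransformer
from sentence_transformers.evaluation import NanoBEIREvaluator

model = SentenceTransformer("all-MiniLM-L6-v2")  # illustrative model
print(model.similarity_fn_name)  # now "cosine" by default instead of None

# Assumed API: pick a subset of the NanoBEIR datasets; omitting dataset_names
# is expected to evaluate on all of them.
evaluator = NanoBEIREvaluator(dataset_names=["msmarco", "nq"])
results = evaluator(model)  # relevance is scored with model.similarity by default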

---------

Co-authored-by: Tom Aarsen <[email protected]>
ArthurCamara and tomaarsen authored Oct 29, 2024
1 parent 96a4bd7 commit 210ea8b
Showing 12 changed files with 758 additions and 292 deletions.
14 changes: 9 additions & 5 deletions sentence_transformers/SentenceTransformer.py
@@ -689,22 +689,26 @@ def forward(self, input: dict[str, Tensor], **kwargs) -> dict[str, Tensor]:
return input

@property
def similarity_fn_name(self) -> str | None:
def similarity_fn_name(self) -> Literal["cosine", "dot", "euclidean", "manhattan"]:
"""Return the name of the similarity function used by :meth:`SentenceTransformer.similarity` and :meth:`SentenceTransformer.similarity_pairwise`.
Returns:
Optional[str]: The name of the similarity function. Can be None if not set, in which case any uses of
:meth:`SentenceTransformer.similarity` and :meth:`SentenceTransformer.similarity_pairwise` default to "cosine".
Optional[str]: The name of the similarity function. Can be None if not set, in which case it will
default to "cosine" when first called.
Example:
>>> model = SentenceTransformer("multi-qa-mpnet-base-dot-v1")
>>> model.similarity_fn_name
'dot'
"""
if self._similarity_fn_name is None:
self.similarity_fn_name = SimilarityFunction.COSINE
return self._similarity_fn_name

@similarity_fn_name.setter
def similarity_fn_name(self, value: str | SimilarityFunction) -> None:
def similarity_fn_name(
self, value: Literal["cosine", "dot", "euclidean", "manhattan"] | SimilarityFunction
) -> None:
if isinstance(value, SimilarityFunction):
value = value.value
self._similarity_fn_name = value
@@ -1613,7 +1617,7 @@ def _load_sbert_model(
)

# Set score functions & prompts if not already overridden by the __init__ calls
if self.similarity_fn_name is None:
if self._similarity_fn_name is None:
self.similarity_fn_name = self._model_config.get("similarity_fn_name", None)
if not self.prompts:
self.prompts = self._model_config.get("prompts", {})
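
A short REPL-style sketch of the lazy default introduced in the property above; the model name is illustrative, and it is assumed here that its config does not pin a similarity function:

>>> from sentence_transformers import SentenceTransformer
>>> model = SentenceTransformer("all-MiniLM-L6-v2")
>>> model.similarity_fn_name  # previously None when unset; now resolved to the default on first access
'cosine'
>>> model.similarity_fn_name = "dot"  # the setter also accepts SimilarityFunction members
>>> model.similarity_fn_name
'dot'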
105 changes: 60 additions & 45 deletions sentence_transformers/evaluation/BinaryClassificationEvaluator.py
@@ -4,7 +4,7 @@
import logging
import os
from contextlib import nullcontext
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Literal

import numpy as np
from sklearn.metrics import average_precision_score
@@ -40,6 +40,7 @@ class BinaryClassificationEvaluator(SentenceEvaluator):
show_progress_bar (bool, optional): If true, prints a progress bar. Defaults to False.
write_csv (bool, optional): Write results to a CSV file. Defaults to True.
truncate_dim (Optional[int], optional): The dimension to truncate sentence embeddings to. `None` uses the model's current truncation dimension. Defaults to None.
similarity_fn_names (Optional[List[Literal["cosine", "dot", "euclidean", "manhattan"]]], optional): The similarity functions to use. If not specified, defaults to the ``similarity_fn_name`` attribute of the model. Defaults to None.
Example:
::
@@ -59,7 +60,7 @@ class BinaryClassificationEvaluator(SentenceEvaluator):
sentences1=eval_dataset["sentence1"],
sentences2=eval_dataset["sentence2"],
labels=eval_dataset["label"],
name="quora-duplicates-dev",
name="quora_duplicates_dev",
)
results = binary_acc_evaluator(model)
'''
@@ -69,27 +70,9 @@ class BinaryClassificationEvaluator(SentenceEvaluator):
Precision with Cosine-Similarity: 65.81
Recall with Cosine-Similarity: 87.89
Average Precision with Cosine-Similarity: 76.03
Accuracy with Dot-Product: 81.60 (Threshold: 0.8352)
F1 with Dot-Product: 75.27 (Threshold: 0.7715)
Precision with Dot-Product: 65.81
Recall with Dot-Product: 87.89
Average Precision with Dot-Product: 76.03
Accuracy with Manhattan-Distance: 81.50 (Threshold: 12.0727)
F1 with Manhattan-Distance: 74.97 (Threshold: 15.2269)
Precision with Manhattan-Distance: 63.89
Recall with Manhattan-Distance: 90.68
Average Precision with Manhattan-Distance: 75.66
Accuracy with Euclidean-Distance: 81.60 (Threshold: 0.5741)
F1 with Euclidean-Distance: 75.27 (Threshold: 0.6760)
Precision with Euclidean-Distance: 65.81
Recall with Euclidean-Distance: 87.89
Average Precision with Euclidean-Distance: 76.03
'''
print(binary_acc_evaluator.primary_metric)
# => "quora-duplicates-dev_max_ap"
# => "quora_duplicates_dev_cosine_ap"
print(results[binary_acc_evaluator.primary_metric])
# => 0.760277070888393
"""
@@ -104,13 +87,14 @@ def __init__(
show_progress_bar: bool = False,
write_csv: bool = True,
truncate_dim: int | None = None,
similarity_fn_names: list[Literal["cosine", "dot", "euclidean", "manhattan"]] | None = None,
):
super().__init__()
self.sentences1 = sentences1
self.sentences2 = sentences2
self.labels = labels
self.truncate_dim = truncate_dim

self.primary_metric = "max_ap"
self.similarity_fn_names = similarity_fn_names or []

assert len(self.sentences1) == len(self.sentences2)
assert len(self.sentences1) == len(self.labels)
@@ -128,6 +112,10 @@ def __init__(

self.csv_file = "binary_classification_evaluation" + ("_" + name if name else "") + "_results.csv"
self.csv_headers = ["epoch", "steps"]

self._append_csv_headers(self.similarity_fn_names)

def _append_csv_headers(self, similarity_fn_names: list[str]) -> None:
metrics = [
"accuracy",
"accuracy_threshold",
@@ -137,7 +125,8 @@ def __init__(
"f1_threshold",
"ap",
]
for v in SimilarityFunction.possible_values():

for v in similarity_fn_names:
for m in metrics:
self.csv_headers.append(f"{v}_{m}")

@@ -180,6 +169,9 @@ def __call__(

logger.info(f"Binary Accuracy Evaluation of the model on the {self.name} dataset{out_txt}:")

if not self.similarity_fn_names:
self.similarity_fn_names = [model.similarity_fn_name]
self._append_csv_headers(self.similarity_fn_names)
scores = self.compute_metrices(model)

file_output_data = [epoch, steps]
@@ -206,14 +198,22 @@ def __call__(
for short_name, values in scores.items()
for metric, value in values.items()
}
metrics.update(
{f"max_{metric}": max(scores[short_name][metric] for short_name in scores) for metric in scores["cosine"]}
)
if len(self.similarity_fn_names) > 1:
metrics.update(
{
f"max_{metric}": max(scores[short_name][metric] for short_name in scores)
for metric in scores["cosine"]
}
)
self.primary_metric = "max_ap"
else:
self.primary_metric = f"{self.similarity_fn_names[0]}_ap"

metrics = self.prefix_name_to_metrics(metrics, self.name)
self.store_metrics_in_model_card_data(model, metrics)
return metrics

def compute_metrices(self, model):
def compute_metrices(self, model: SentenceTransformer) -> dict[str, dict[str, float]]:
with nullcontext() if self.truncate_dim is None else model.truncate_sentence_embeddings(self.truncate_dim):
try:
# If the sentences are hashable, then we can use a set to avoid embedding the same sentences multiple
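
Continuing that sketch, a REPL-style view of the flattened metric keys that the __call__ changes above produce; the key names follow from the code shown, while the scores themselves are omitted:

>>> [key for key in results if key.endswith("_ap")]
['quora_duplicates_dev_cosine_ap', 'quora_duplicates_dev_dot_ap', 'quora_duplicates_dev_max_ap']
>>> binary_acc_evaluator.primary_metric  # "max_ap" is used only when more than one similarity function is evaluated
'quora_duplicates_dev_max_ap'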
@@ -240,33 +240,48 @@ def compute_metrices(self, model):
embeddings1 = [emb_dict[sent] for sent in self.sentences1]
embeddings2 = [emb_dict[sent] for sent in self.sentences2]

cosine_scores = 1 - paired_cosine_distances(embeddings1, embeddings2)
manhattan_distances = paired_manhattan_distances(embeddings1, embeddings2)
euclidean_distances = paired_euclidean_distances(embeddings1, embeddings2)

embeddings1_np = np.asarray(embeddings1)
embeddings2_np = np.asarray(embeddings2)
dot_scores = np.sum(embeddings1_np * embeddings2_np, axis=-1)
similarity_fns = {
SimilarityFunction.COSINE.value: {
"score_fn": lambda x, y: 1 - paired_cosine_distances(x, y),
"name": "Cosine-Similarity",
"greater_is_better": True,
},
SimilarityFunction.DOT_PRODUCT.value: {
"score_fn": lambda x, y: np.sum(np.asarray(x) * np.asarray(y), axis=-1),
"name": "Dot-Product",
"greater_is_better": True,
},
SimilarityFunction.MANHATTAN.value: {
"score_fn": lambda x, y: -paired_manhattan_distances(x, y),
"name": "Manhattan-Distance",
"greater_is_better": False,
},
SimilarityFunction.EUCLIDEAN.value: {
"score_fn": lambda x, y: -paired_euclidean_distances(x, y),
"name": "Euclidean-Distance",
"greater_is_better": False,
},
}

labels = np.asarray(self.labels)
output_scores = {}
for short_name, name, scores, reverse in [
[SimilarityFunction.COSINE.value, "Cosine-Similarity", cosine_scores, True],
[SimilarityFunction.DOT_PRODUCT.value, "Dot-Product", dot_scores, True],
[SimilarityFunction.MANHATTAN.value, "Manhattan-Distance", manhattan_distances, False],
[SimilarityFunction.EUCLIDEAN.value, "Euclidean-Distance", euclidean_distances, False],
]:
acc, acc_threshold = self.find_best_acc_and_threshold(scores, labels, reverse)
f1, precision, recall, f1_threshold = self.find_best_f1_and_threshold(scores, labels, reverse)
ap = average_precision_score(labels, scores * (1 if reverse else -1))
for similarity_fn_name in self.similarity_fn_names:
similarity_fn = similarity_fns[similarity_fn_name]
scores = similarity_fn["score_fn"](embeddings1, embeddings2)
greater_is_better = similarity_fn["greater_is_better"]
name = similarity_fn["name"]

acc, acc_threshold = self.find_best_acc_and_threshold(scores, labels, greater_is_better)
f1, precision, recall, f1_threshold = self.find_best_f1_and_threshold(scores, labels, greater_is_better)
ap = average_precision_score(labels, scores * (1 if greater_is_better else -1))

logger.info(f"Accuracy with {name}: {acc * 100:.2f}\t(Threshold: {acc_threshold:.4f})")
logger.info(f"F1 with {name}: {f1 * 100:.2f}\t(Threshold: {f1_threshold:.4f})")
logger.info(f"Precision with {name}: {precision * 100:.2f}")
logger.info(f"Recall with {name}: {recall * 100:.2f}")
logger.info(f"Average Precision with {name}: {ap * 100:.2f}\n")

output_scores[short_name] = {
output_scores[similarity_fn_name] = {
"accuracy": acc,
"accuracy_threshold": acc_threshold,
"f1": f1,

