Skip to content

Commit

Permalink
[feat] Trainer with prompts and prompt masking (#2964)
Browse files Browse the repository at this point in the history
* Added the possibility of masking the prompts if the tokenizer is left-padded.

* Simplify code

* Remove unrelated changes

* Move prompt_mask into the Transformer model

* Added query and corpus prompts to Information Retrieval Evaluator

* Fix for failing test

* Fix for pooling when mask is not passed

* Fix device placement for prompt_mask

* Revert left-padding changes

* Revert left-padding changes

* Added support to prompts in the Trainer

* Simplify logic and add prompt to eval dataset

* add prompt to test dataset

* Added support to prompts in the Trainer

* Simplify logic and add prompt to eval dataset

* add prompt to test dataset

* rename prompt to prompts

* Move prompts to collator

* rename to set_prompts

* Move prompts into data_collator

* Fix for pooling check

* Move prompt logic to Collator, add logic to add dataset column when prompt exists in the Trainer

* typo and init bug

* redundant initialization

* remove unused method

* add dtype of tensort

* Fix for dtype and None dataset

* Remove unused argument

* Fix typos

* Always tokenize a list, otherwise the prompt length is off

* Use a simple int as a prompt length instead of a tensor

* Add prompts via .set_transform/.map to Dataset rather than via Collator

* Remove dead code/TODO

* Move prompts to SentenceTransformersArguments

* Only include dataset_name if strictly needed, stricter tests

* (Unrelated) Warn if using a batch sampler with a streaming dataset

* Always return batch_size samples in transform

This is just safer & less hacky - I encountered a nasty bug where only returning 1 value (because we technically only need 1) results in all other samples being skipped. Not great.

* Fix bug with prompts + prompt_lengths & NoDuplicatesBatchSampler

* (Unrelated) add NanoBEIREvaluator to docs

* Add Training with Prompts docs + example script

This also already mentions the v3.3 release - a bit premature, but it's a tad simpler this way

* Slight updates to the docs

* Simplify/revert slightly in the data collator

---------

Co-authored-by: Tom Aarsen <[email protected]>
  • Loading branch information
ArthurCamara and tomaarsen authored Nov 8, 2024
1 parent 15d3898 commit 7be3eac
Show file tree
Hide file tree
Showing 14 changed files with 1,305 additions and 328 deletions.
5 changes: 5 additions & 0 deletions docs/package_reference/sentence_transformer/evaluation.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,11 @@
.. autoclass:: sentence_transformers.evaluation.InformationRetrievalEvaluator
```

## NanoBEIREvaluator
```eval_rst
.. autoclass:: sentence_transformers.evaluation.NanoBEIREvaluator
```

## MSEEvaluator
```eval_rst
.. autoclass:: sentence_transformers.evaluation.MSEEvaluator
Expand Down
1 change: 1 addition & 0 deletions docs/sentence_transformer/training/examples.rst
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ Training Examples
../../../examples/training/multilingual/README
../../../examples/training/distillation/README
../../../examples/training/data_augmentation/README
../../../examples/training/prompts/README

.. toctree::
:maxdepth: 1
Expand Down
177 changes: 177 additions & 0 deletions examples/training/prompts/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,177 @@
# Training with Prompts

## What are Prompts?
Many modern embedding models are trained with "instructions" or "prompts" following the [INSTRUCTOR paper](https://arxiv.org/abs/2212.09741). These prompts are strings, prefixed to each text to be embedded, allowing the model to distinguish between different types of text.

For example, the [mixedbread-ai/mxbai-embed-large-v1](https://huggingface.co/mixedbread-ai/mxbai-embed-large-v1) model was trained with `Represent this sentence for searching relevant passages: ` as the prompt for all queries. This prompt is stored in the [model configuration](https://huggingface.co/mixedbread-ai/mxbai-embed-large-v1/blob/main/config_sentence_transformers.json) under the prompt name `"query"`, so users can specify that `prompt_name` in `model.encode`:

```python
from sentence_transformers import SentenceTransformer

model = SentenceTransformer("mixedbread-ai/mxbai-embed-large-v1")
query_embedding = model.encode("What are Pandas?", prompt_name="query")
# or
# query_embedding = model.encode("What are Pandas?", prompt="Represent this sentence for searching relevant passages: ")
document_embeddings = model.encode([
"Pandas is a software library written for the Python programming language for data manipulation and analysis.",
"Pandas are a species of bear native to South Central China. They are also known as the giant panda or simply panda.",
"Koala bears are not actually bears, they are marsupials native to Australia.",
])
similarity = model.similarity(query_embedding, document_embeddings)
print(similarity)
# => tensor([[0.7594, 0.7560, 0.4674]])
```
See [Prompt Templates](https://sbert.net/examples/applications/computing-embeddings/README.html#prompt-templates) for more information about inference with prompts.

## Why would we train with Prompts?

The [INSTRUCTOR paper](https://arxiv.org/abs/2212.09741) showed that adding prompts or instructions before each text could improve model performance by an average of ~6%, with major gains especially for classification, clustering, and semantic textual similarity. Note that the performance improvements for retrieval are notably lower at 0.4% and 2.7% for small and large models, respectively.

<div align="center">
<img src="https://huggingface.co/tomaarsen/mpnet-base-nq-prompts/resolve/main/instructor.png" alt="instructor results" width="720"/>
</div>

More recently, the [BGE paper](https://arxiv.org/pdf/2309.07597) showed similar findings, showing about a 1.4% performance increase for retrieval if the query is prefixed with `Represent this sentence for searching relevant passages: `. The authors conclude that using instructions may substantially contribute to the quality of task-specific fine-tuning.

<div align="center">
<img src="https://huggingface.co/tomaarsen/mpnet-base-nq-prompts/resolve/main/bge.png" alt="bge results" width="720"/>
</div>

In essence, using instructions or prompts allows for improved performance as long as they are used both during training and inference.

## How do we train with Prompts?

```eval_rst
Since the v3.3.0 Sentence Transformers release, it's possible to finetune embedding models with prompts using the ``prompts`` argument in the :class:`~sentence_transformers.training_args.SentenceTransformerTrainingArguments` class. There are 4 separate accepted formats for this argument:
1. ``str``: A single prompt to use for all columns in all datasets. For example::
args = SentenceTransformerTrainingArguments(
...,
prompts="text: ",
...,
)
2. ``Dict[str, str]``: A dictionary mapping column names to prompts, applied to all datasets. For example::
args = SentenceTransformerTrainingArguments(
...,
prompts={
"query": "query: ",
"answer": "document: ",
},
...,
)
3. ``Dict[str, str]``: A dictionary mapping dataset names to prompts. This should only be used if your training/evaluation/test datasets are a :class:`~datasets.DatasetDict` or a dictionary of :class:`~datasets.Dataset`. For example::
args = SentenceTransformerTrainingArguments(
...,
prompts={
"stsb": "Represent this text for semantic similarity search: ",
"nq": "Represent this text for retrieval: ",
},
...,
)
4. ``Dict[str, Dict[str, str]]``: A dictionary mapping dataset names to dictionaries mapping column names to prompts. This should only be used if your training/evaluation/test datasets are a :class:`~datasets.DatasetDict` or a dictionary of :class:`~datasets.Dataset`. For example::
args = SentenceTransformerTrainingArguments(
...,
prompts={
"stsb": {
"sentence1": "sts: ",
"sentence2": "sts: ",
},
"nq": {
"query": "query: ",
"document": "document: ",
},
},
...,
)
Additionally, some research papers (`INSTRUCTOR <https://arxiv.org/abs/2212.09741>`_, `NV-Embed <https://arxiv.org/pdf/2405.17428>`_) exclude the prompt from the mean pooling step, such that it's only used in the Transformer blocks. In Sentence Transformers, this can be configured with the ``include_prompt`` argument/attribute in the :class:`~sentence_transformers.models.Pooling` module or via the :meth:`SentenceTransformer.set_pooling_include_prompt() <sentence_transformers.SentenceTransformer.set_pooling_include_prompt>` method. In my personal experience, models that include the prompt in the pooling tend to perform better.
```

### Training Script

```eval_rst
See the following script as an example of how to train with prompts in practice:
* `training_nq_prompts.py <training_nq_prompts.py>`_: This script finetunes `mpnet-base <https://huggingface.co/microsoft/mpnet-base>`_ on 100k query-answer pairs from the `natural-questions <https://huggingface.co/datasets/sentence-transformers/natural-questions>`_ dataset using the :class:`~sentence_transformers.losses.CachedMultipleNegativesRankingLoss` loss. The model is evaluated during training using the :class:`~sentence_transformers.evaluation.NanoBEIREvaluator`.
This script has two variables that affect 1) whether prompts are used and 2) whether prompts are included in the pooling. I have finetuned both ``mpnet-base`` and ``bert-base-uncased`` under the various different settings, resulting in a 0.66% and 0.90% relative improvements on ``NDCG@10`` at no extra cost.
.. tab:: Experiments with ``mpnet-base``
Running the script under various settings resulted in these checkpoints:
* `tomaarsen/mpnet-base-nq <https://huggingface.co/tomaarsen/mpnet-base-nq>`_
* `tomaarsen/mpnet-base-nq-prompts <https://huggingface.co/tomaarsen/mpnet-base-nq-prompts>`_
.. note::
``mpnet-base`` seems to be a tad unstable when training with prompts and excluding those prompts in the pooling: the loss spikes at some point, an effect not observed with e.g. ``bert-base-uncased``.
For these two models, the model trained with prompts consistently outperforms the baseline model all throughout training:
.. raw:: html
<img src="https://huggingface.co/tomaarsen/mpnet-base-nq-prompts/resolve/main/mpnet_base_nq_nanobeir.png" alt="NanoBEIR results of mpnet-base-nq vs mpnet-base-nq-prompts" width="480"/>
Additionally, the model trained with prompts includes these prompts in the training dataset details in the automatically generated model card: `tomaarsen/mpnet-base-nq-prompts#natural-questions <https://huggingface.co/tomaarsen/mpnet-base-nq-prompts#natural-questions>`_.
.. important::
If you train with prompts, then it's heavily recommended to store prompts in the model configuration as a mapping of prompt names to prompt strings. You can do this by initializing the :class:`~sentence_transformers.SentenceTransformer` with a ``prompts`` dictionary before saving it, updating the ``model.prompts`` of a loaded model before saving it, and/or updating the `config_sentence_transformers.json <https://huggingface.co/tomaarsen/mpnet-base-nq-prompts/blob/main/config_sentence_transformers.json>`_ file of the saved model.
After adding the prompts in the model configuration, the final usage of the prompt-trained model becomes::
from sentence_transformers import SentenceTransformer
model = SentenceTransformer("tomaarsen/mpnet-base-nq-prompts")
query_embedding = model.encode("What are Pandas?", prompt_name="query")
document_embeddings = model.encode([
"Pandas is a software library written for the Python programming language for data manipulation and analysis.",
"Pandas are a species of bear native to South Central China. They are also known as the giant panda or simply panda.",
"Koala bears are not actually bears, they are marsupials native to Australia.",
],
prompt_name="document",
)
similarity = model.similarity(query_embedding, document_embeddings)
print(similarity)
# => tensor([[0.4725, 0.7339, 0.4369]])
.. tab:: Experiments with ``bert-base-uncased``
Running the script under various settings resulted in these checkpoints:
* `tomaarsen/bert-base-nq <https://huggingface.co/tomaarsen/bert-base-nq>`_
* `tomaarsen/bert-base-nq-prompts <https://huggingface.co/tomaarsen/bert-base-nq-prompts>`_
* `tomaarsen/bert-base-nq-prompts-exclude-pooling-prompts <https://huggingface.co/tomaarsen/bert-base-nq-prompts-exclude-pooling-prompts>`_
For these three models, the model trained with prompts consistently outperforms the baseline model all throughout training, except for the very first evaluation. The model that excludes the prompt in the mean pooling consistently performs notably worse than either of the other two.
.. raw:: html
<img src="https://huggingface.co/tomaarsen/mpnet-base-nq-prompts/resolve/main/bert_base_nq_nanobeir.png" alt="NanoBEIR results" width="480"/>
Additionally, the model trained with prompts includes these prompts in the training dataset details in the automatically generated model card: `tomaarsen/bert-base-nq-prompts#natural-questions <https://huggingface.co/tomaarsen/bert-base-nq-prompts#natural-questions>`_.
.. important::
If you train with prompts, then it's heavily recommended to store prompts in the model configuration as a mapping of prompt names to prompt strings. You can do this by initializing the :class:`~sentence_transformers.SentenceTransformer` with a ``prompts`` dictionary before saving it, updating the ``model.prompts`` of a loaded model before saving it, and/or updating the `config_sentence_transformers.json <https://huggingface.co/tomaarsen/mpnet-base-nq-prompts/blob/main/config_sentence_transformers.json>`_ file of the saved model.
After adding the prompts in the model configuration, the final usage of the prompt-trained model becomes::
from sentence_transformers import SentenceTransformer
model = SentenceTransformer("tomaarsen/bert-base-nq-prompts")
query_embedding = model.encode("What are Pandas?", prompt_name="query")
document_embeddings = model.encode([
"Pandas is a software library written for the Python programming language for data manipulation and analysis.",
"Pandas are a species of bear native to South Central China. They are also known as the giant panda or simply panda.",
"Koala bears are not actually bears, they are marsupials native to Australia.",
],
prompt_name="document",
)
similarity = model.similarity(query_embedding, document_embeddings)
print(similarity)
# => tensor([[0.3955, 0.8226, 0.5706]])
```
114 changes: 114 additions & 0 deletions examples/training/prompts/training_nq_prompts.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
# See https://huggingface.co/collections/tomaarsen/training-with-prompts-672ce423c85b4d39aed52853 for some already trained models

import logging
import random

import numpy
import torch
from datasets import Dataset, load_dataset

from sentence_transformers import (
SentenceTransformer,
SentenceTransformerModelCardData,
SentenceTransformerTrainer,
SentenceTransformerTrainingArguments,
)
from sentence_transformers.evaluation import NanoBEIREvaluator
from sentence_transformers.losses import CachedMultipleNegativesRankingLoss
from sentence_transformers.training_args import BatchSamplers

logging.basicConfig(format="%(asctime)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO)
random.seed(12)
torch.manual_seed(12)
numpy.random.seed(12)

# Feel free to adjust these variables:
use_prompts = True
include_prompts_in_pooling = True

# 1. Load a model to finetune with 2. (Optional) model card data
model = SentenceTransformer(
"microsoft/mpnet-base",
model_card_data=SentenceTransformerModelCardData(
language="en",
license="apache-2.0",
model_name="MPNet base trained on Natural Questions pairs",
),
)
model.set_pooling_include_prompt(include_prompts_in_pooling)

# 2. (Optional) Define prompts
if use_prompts:
query_prompt = "query: "
corpus_prompt = "document: "
prompts = {
"query": query_prompt,
"answer": corpus_prompt,
}

# 3. Load a dataset to finetune on
dataset = load_dataset("sentence-transformers/natural-questions", split="train")
dataset_dict = dataset.train_test_split(test_size=1_000, seed=12)
train_dataset: Dataset = dataset_dict["train"]
eval_dataset: Dataset = dataset_dict["test"]

# 4. Define a loss function
loss = CachedMultipleNegativesRankingLoss(model, mini_batch_size=16)

# 5. (Optional) Specify training arguments
run_name = "mpnet-base-nq"
if use_prompts:
run_name += "-prompts"
if not include_prompts_in_pooling:
run_name += "-exclude-pooling-prompts"
args = SentenceTransformerTrainingArguments(
# Required parameter:
output_dir=f"models/{run_name}",
# Optional training parameters:
num_train_epochs=1,
per_device_train_batch_size=256,
per_device_eval_batch_size=256,
learning_rate=2e-5,
warmup_ratio=0.1,
fp16=False, # Set to False if you get an error that your GPU can't run on FP16
bf16=True, # Set to True if you have a GPU that supports BF16
batch_sampler=BatchSamplers.NO_DUPLICATES, # MultipleNegativesRankingLoss benefits from no duplicate samples in a batch
# Optional tracking/debugging parameters:
eval_strategy="steps",
eval_steps=50,
save_strategy="steps",
save_steps=50,
save_total_limit=2,
logging_steps=5,
logging_first_step=True,
run_name=run_name, # Will be used in W&B if `wandb` is installed
seed=12,
prompts=prompts if use_prompts else None,
)

# 6. (Optional) Create an evaluator & evaluate the base model
dev_evaluator = NanoBEIREvaluator(
query_prompts=query_prompt if use_prompts else None,
corpus_prompts=corpus_prompt if use_prompts else None,
)
dev_evaluator(model)

# 7. Create a trainer & train
trainer = SentenceTransformerTrainer(
model=model,
args=args,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
loss=loss,
evaluator=dev_evaluator,
)
trainer.train()

# (Optional) Evaluate the trained model on the evaluator after training
dev_evaluator(model)

# 8. Save the trained model
model.save_pretrained(f"models/{run_name}/final")

# 9. (Optional) Push it to the Hugging Face Hub
model.push_to_hub(run_name)
8 changes: 6 additions & 2 deletions index.rst
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
.. tip::
.. note::

Sentence Transformers v3.2 just released, introducing the ONNX and OpenVINO backends for Sentence Transformer models. Read `SentenceTransformer > Usage > Speeding up Inference <docs/sentence_transformer/usage/efficiency.html>`_ to learn more about the new backends and what they can mean for your inference speed.
Sentence Transformers v3.2 recently released, introducing the ONNX and OpenVINO backends for Sentence Transformer models. Read `SentenceTransformer > Usage > Speeding up Inference <docs/sentence_transformer/usage/efficiency.html>`_ to learn more about the new backends and what they can mean for your inference speed.

.. note::

Sentence Transformers v3.3 just released, introducing training with Prompts. Read `SentenceTransformer > Training Examples > Training with Prompts <examples/training/prompts/README.html>`_ to learn more about how you can use them to train stronger models.

SentenceTransformers Documentation
==================================
Expand Down
7 changes: 6 additions & 1 deletion sentence_transformers/data_collator.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,11 +46,16 @@ def __call__(self, features: list[dict[str, Any]]) -> dict[str, torch.Tensor]:
column_names.remove(label_column)
break

# Extract the feature columns
for column_name in column_names:
# If the prompt length has been set, we should add it to the batch
if column_name.endswith("_prompt_length") and column_name[: -len("_prompt_length")] in column_names:
batch[column_name] = torch.tensor([row[column_name] for row in features], dtype=torch.int)
continue

tokenized = self.tokenize_fn([row[column_name] for row in features])
for key, value in tokenized.items():
batch[f"{column_name}_{key}"] = value

return batch

def maybe_warn_about_column_order(self, column_names: list[str]) -> None:
Expand Down
4 changes: 2 additions & 2 deletions sentence_transformers/evaluation/NanoBEIREvaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,8 +88,8 @@ class NanoBEIREvaluator(SentenceEvaluator):
datasets = ["QuoraRetrieval", "MSMARCO"]
query_prompts = {
"QuoraRetrieval": "Instruct: Given a question, retrieve questions that are semantically equivalent to the given question\nQuery: ",
"MSMARCO": "Instruct: Given a web search query, retrieve relevant passages that answer the query\nQuery: "
"QuoraRetrieval": "Instruct: Given a question, retrieve questions that are semantically equivalent to the given question\\nQuery: ",
"MSMARCO": "Instruct: Given a web search query, retrieve relevant passages that answer the query\\nQuery: "
}
evaluator = NanoBEIREvaluator(
Expand Down
Loading

0 comments on commit 7be3eac

Please sign in to comment.