normalize to float in NanoBEIREvaluator, InformationRetrievalEvaluator, MSEEvaluator #3095

Closed · wants to merge 4 commits
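Context for the change (illustration, not part of the diff): evaluator metrics are often computed as numpy scalars (e.g. np.float32), which common consumers such as json.dumps reject. A minimal sketch of that failure mode and of the float() normalization this PR applies; the metric name and value here are hypothetical:

```python
import json

import numpy as np

# Hypothetical metric dict resembling raw evaluator output before this PR
metrics = {"cosine_ndcg@10": np.float32(0.7312)}

try:
    json.dumps(metrics)
except TypeError as err:
    print(err)  # Object of type float32 is not JSON serializable

# Normalizing values to built-in floats, as this PR does, makes the dict serializable
metrics = {key: float(value) for key, value in metrics.items()}
print(json.dumps(metrics))  # prints the dict with a plain float value
```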
101 changes: 50 additions & 51 deletions examples/training/data_augmentation/train_sts_seed_optimization.py
@@ -23,34 +23,25 @@
 python train_sts_seed_optimization.py bert-base-uncased 10 0.3
 """
 
-import csv
-import gzip
 import logging
 import math
-import os
 import random
 import sys
 
 import numpy as np
 import torch
-from torch.utils.data import DataLoader
+from datasets import load_dataset
 
-from sentence_transformers import LoggingHandler, SentenceTransformer, losses, models, util
+from sentence_transformers import LoggingHandler, SentenceTransformer, losses, models
 from sentence_transformers.evaluation import EmbeddingSimilarityEvaluator
-from sentence_transformers.readers import InputExample
+from sentence_transformers.similarity_functions import SimilarityFunction
+from sentence_transformers.trainer import SentenceTransformerTrainer
+from sentence_transformers.training_args import SentenceTransformerTrainingArguments
 
 #### Just some code to print debug information to stdout
 logging.basicConfig(
     format="%(asctime)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO, handlers=[LoggingHandler()]
 )
 #### /print debug information to stdout
 
-
-# Check if dataset exists. If not, download and extract it
-sts_dataset_path = "datasets/stsbenchmark.tsv.gz"
-
-if not os.path.exists(sts_dataset_path):
-    util.http_get("https://sbert.net/datasets/stsbenchmark.tsv.gz", sts_dataset_path)
-
-
 # You can specify any huggingface/transformers pre-trained model here, for example, bert-base-uncased, roberta-base, xlm-roberta-base
@@ -85,49 +76,57 @@
 
     model = SentenceTransformer(modules=[word_embedding_model, pooling_model])
 
-    # Convert the dataset to a DataLoader ready for training
-    logging.info("Read STSbenchmark train dataset")
-
-    train_samples = []
-    dev_samples = []
-    test_samples = []
-    with gzip.open(sts_dataset_path, "rt", encoding="utf8") as fIn:
-        reader = csv.DictReader(fIn, delimiter="\t", quoting=csv.QUOTE_NONE)
-        for row in reader:
-            score = float(row["score"]) / 5.0  # Normalize score to range 0 ... 1
-            inp_example = InputExample(texts=[row["sentence1"], row["sentence2"]], label=score)
-
-            if row["split"] == "dev":
-                dev_samples.append(inp_example)
-            elif row["split"] == "test":
-                test_samples.append(inp_example)
-            else:
-                train_samples.append(inp_example)
-
-    train_dataloader = DataLoader(train_samples, shuffle=True, batch_size=train_batch_size)
-    train_loss = losses.CosineSimilarityLoss(model=model)
+    # 2. Load the STSB dataset: https://huggingface.co/datasets/sentence-transformers/stsb
+    train_dataset = load_dataset("sentence-transformers/stsb", split="train")
+    eval_dataset = load_dataset("sentence-transformers/stsb", split="validation")
+    test_dataset = load_dataset("sentence-transformers/stsb", split="test")
+    logging.info(train_dataset)
 
-    logging.info("Read STSbenchmark dev dataset")
-    evaluator = EmbeddingSimilarityEvaluator.from_input_examples(dev_samples, name="sts-dev")
+    train_loss = losses.CosineSimilarityLoss(model=model)
 
-    # Configure the training. We skip evaluation in this example
-    warmup_steps = math.ceil(len(train_dataloader) * num_epochs * 0.1)  # 10% of train data for warm-up
+    # 4. Define an evaluator for use during training.
+    dev_evaluator = EmbeddingSimilarityEvaluator(
+        sentences1=eval_dataset["sentence1"],
+        sentences2=eval_dataset["sentence2"],
+        scores=eval_dataset["score"],
+        main_similarity=SimilarityFunction.COSINE,
+        name="sts-dev",
+    )
 
     # Stopping and Evaluating after 30% of training data (less than 1 epoch)
    # We find from (Dodge et al.) that 20-30% is often ideal for convergence of random seed
-    steps_per_epoch = math.ceil(len(train_dataloader) * stop_after)
-
-    logging.info(f"Warmup-steps: {warmup_steps}")
+    steps_per_epoch = math.ceil(len(train_dataset) * stop_after)
 
     logging.info(f"Early-stopping: {int(stop_after * 100)}% of the training-data")
 
-    # Train the model
-    model.fit(
-        train_objectives=[(train_dataloader, train_loss)],
-        evaluator=evaluator,
-        epochs=num_epochs,
-        steps_per_epoch=steps_per_epoch,
-        evaluation_steps=1000,
-        warmup_steps=warmup_steps,
-        output_path=model_save_path,
-    )
+    # 5. Define the training arguments
+    args = SentenceTransformerTrainingArguments(
+        # Required parameter:
+        output_dir=model_save_path,
+        # Optional training parameters:
+        num_train_epochs=num_epochs,
+        per_device_train_batch_size=train_batch_size,
+        per_device_eval_batch_size=train_batch_size,
+        warmup_ratio=0.1,
+        max_steps=steps_per_epoch,
+        fp16=True,  # Set to False if you get an error that your GPU can't run on FP16
+        bf16=False,  # Set to True if you have a GPU that supports BF16
+        # Optional tracking/debugging parameters:
+        evaluation_strategy="steps",
+        eval_steps=1000,
+        save_strategy="steps",
+        save_steps=1000,
+        logging_steps=1000,
+        run_name="sts",  # Will be used in W&B if `wandb` is installed
+    )
+
+    # 6. Create the trainer & start training
+    trainer = SentenceTransformerTrainer(
+        model=model,
+        args=args,
+        train_dataset=train_dataset,
+        eval_dataset=eval_dataset,
+        loss=train_loss,
+        evaluator=dev_evaluator,
+    )
+    trainer.train()
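In the rewritten example the evaluator results flow through the Trainer's logging pipeline, which is exactly where non-native float types used to surface. A small usage sketch under the script's own names; the metric values shown are illustrative:

```python
# Sketch: v3 evaluators are callable and return a dict of metrics. After this PR,
# the values are plain Python floats, so Trainer logs and JSON reports serialize cleanly.
metrics = dev_evaluator(model)
print(metrics)  # e.g. {"sts-dev_pearson_cosine": 0.86, "sts-dev_spearman_cosine": 0.87, ...}
assert all(isinstance(value, float) for value in metrics.values())
```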
2 changes: 1 addition & 1 deletion sentence_transformers/evaluation/SentenceEvaluator.py
@@ -57,7 +57,7 @@ def __call__(
     def prefix_name_to_metrics(self, metrics: dict[str, float], name: str) -> dict[str, float]:
         if not name:
             return metrics
-        metrics = {name + "_" + key: value for key, value in metrics.items()}
+        metrics = {name + "_" + key: float(value) for key, value in metrics.items()}
         if hasattr(self, "primary_metric") and not self.primary_metric.startswith(name + "_"):
             self.primary_metric = name + "_" + self.primary_metric
         return metrics
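To illustrate the one-line change above: prefixed metrics are now guaranteed to be built-in floats even when an evaluator computed them with numpy. The evaluator prefix and values below are hypothetical:

```python
import numpy as np

# Hypothetical raw metrics as an evaluator might compute them internally
metrics = {"ndcg@10": np.float32(0.7312), "mrr@10": np.float64(0.6881)}

# Equivalent of the patched SentenceEvaluator.prefix_name_to_metrics(metrics, "NanoBEIR_mean")
prefixed = {"NanoBEIR_mean_" + key: float(value) for key, value in metrics.items()}
assert all(type(value) is float for value in prefixed.values())
```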