Fix MLM loss ignore idx #552

Merged · 10 commits · Dec 23, 2024
5 changes: 5 additions & 0 deletions requirements-test.txt
@@ -8,3 +8,8 @@ awscli==1.33.33
nbval==0.11.0
# For NvFaidx equivalence tests
pyfaidx==0.8.1.3

# Temporary pin for pytorch-lightning until megatron callbacks in ProgressPrinter can get fixed.
# See https://nvidia.slack.com/archives/C02A7LYGHK8/p1734727482697309
pytorch-lightning<2.5.0
lightning<2.5.0
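
For reference, a minimal sketch (not part of this diff) of how a test could check that the temporary <2.5.0 pin is actually in effect in the environment; it assumes the packaging library is installed:

# Hypothetical version check for the temporary pin above; not part of this PR.
from importlib.metadata import version
from packaging.version import Version

# Both distributions are pinned below 2.5.0 in requirements-test.txt.
assert Version(version("pytorch-lightning")) < Version("2.5.0")
assert Version(version("lightning")) < Version("2.5.0")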
@@ -27,6 +27,7 @@

from bionemo.esm2.api import ESM2GenericConfig, ESM2Model
from bionemo.esm2.data import tokenizer
from bionemo.llm.data.collate import MLM_LOSS_IGNORE_INDEX
from bionemo.llm.data.label2id_tokenizer import Label2IDTokenizer
from bionemo.llm.data.types import BertSample
from bionemo.llm.model.biobert.model import BioBertOutput
@@ -230,6 +231,7 @@ def __init__(
self.tokenizer = tokenizer
label_tokenizer = Label2IDTokenizer()
self.label_tokenizer = label_tokenizer.build_vocab("CHE")
self.label_cls_eos_id = MLM_LOSS_IGNORE_INDEX

def __len__(self) -> int:
"""Length of dataset."""
@@ -257,13 +259,13 @@ def _tokenize_labels(self, labels_sequence: str) -> Tensor:

# # for multi-label classification with BCEWithLogitsLoss
# tokenized_labels = torch.nn.functional.one_hot(label_ids, num_classes=self.label_tokenizer.vocab_size)
# cls_eos = torch.full((1, self.label_tokenizer.vocab_size), -1, dtype=tokenized_labels.dtype)
# cls_eos = torch.full((1, self.label_tokenizer.vocab_size), self.label_cls_eos_id, dtype=tokenized_labels.dtype)

# for multi-class (mutually exclusive) classification with CrossEntropyLoss
tokenized_labels = label_ids
cls_eos = torch.tensor([-1], dtype=tokenized_labels.dtype)
cls_eos = torch.tensor([self.label_cls_eos_id], dtype=tokenized_labels.dtype)

# add cls / eos labels with padding value -1 to have the same shape as tokenized_sequence
# add cls / eos label ids with padding value -100 to have the same shape as tokenized_sequence
labels = torch.cat((cls_eos, tokenized_labels, cls_eos))
return labels

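For context on this change: PyTorch's cross-entropy loss skips any target equal to its ignore_index, which defaults to -100, so tagging the cls/eos positions with MLM_LOSS_IGNORE_INDEX keeps them out of the loss while keeping the label tensor the same length as the tokenized sequence. A minimal sketch, not taken from this diff:

import torch

MLM_LOSS_IGNORE_INDEX = -100  # same sentinel used above for the cls/eos label ids

logits = torch.randn(5, 3)  # 5 token positions, 3 label classes ("C", "H", "E")
labels = torch.tensor([MLM_LOSS_IGNORE_INDEX, 0, 2, 1, MLM_LOSS_IGNORE_INDEX])  # cls, ..., eos

# cross_entropy defaults to ignore_index=-100; only the three real positions are averaged.
loss = torch.nn.functional.cross_entropy(logits, labels, ignore_index=MLM_LOSS_IGNORE_INDEX)
print(loss)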
5 changes: 4 additions & 1 deletion sub-packages/bionemo-llm/src/bionemo/llm/data/collate.py
@@ -28,6 +28,9 @@
_warned_once: bool = False


MLM_LOSS_IGNORE_INDEX = -100 # This should match the masked value used in the MLM loss mask.


def padding_collate_fn(
batch: Sequence[_T],
padding_values: dict[str, int],
@@ -105,7 +108,7 @@ def bert_padding_collate_fn(
"text": padding_value,
"types": 0,
"attention_mask": False,
"labels": -100, # This should match the masked value used in the MLM loss mask.
"labels": MLM_LOSS_IGNORE_INDEX, # This should match the masked value used in the MLM loss mask.
"loss_mask": False,
"is_random": 0,
}
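
A rough illustration of what padding the "labels" key with MLM_LOSS_IGNORE_INDEX looks like for a ragged batch; this is not the actual padding_collate_fn from bionemo.llm.data.collate, and the helper name below is made up:

import torch

MLM_LOSS_IGNORE_INDEX = -100

def pad_label_batch(label_seqs: list[torch.Tensor], max_len: int) -> torch.Tensor:
    # Right-pad each 1D label tensor with the MLM loss ignore index so that
    # padded positions are excluded from the cross-entropy loss downstream.
    out = torch.full((len(label_seqs), max_len), MLM_LOSS_IGNORE_INDEX, dtype=torch.long)
    for i, seq in enumerate(label_seqs):
        out[i, : seq.numel()] = seq
    return out

batch = [torch.tensor([0, 2, 1]), torch.tensor([1, 1])]
print(pad_label_batch(batch, max_len=4))
# tensor([[   0,    2,    1, -100],
#         [   1,    1, -100, -100]])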