From 3915abee4cba364a83f18cc4b8c2bb571cce3bac Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Tue, 31 Dec 2024 15:22:18 -0500 Subject: [PATCH] make sure padding is labeled as -100 for pretraining (#2227) --- src/axolotl/utils/data/pretraining.py | 44 +++++++++++++++++++++++++-- 1 file changed, 42 insertions(+), 2 deletions(-) diff --git a/src/axolotl/utils/data/pretraining.py b/src/axolotl/utils/data/pretraining.py index 16f38218cd..f493db70eb 100644 --- a/src/axolotl/utils/data/pretraining.py +++ b/src/axolotl/utils/data/pretraining.py @@ -28,8 +28,10 @@ def encode_pretraining( ) # Convert to PyTorch tensors input_ids = [torch.tensor(seq) for seq in res["input_ids"]] + targets = [torch.tensor(seq) for seq in res["input_ids"]] attention_mask = [torch.tensor(seq) for seq in res["attention_mask"]] new_input_ids = [] + new_labels = [] new_attention_mask = [] # Append EOS and PAD tokens to input_ids, and correct attention_mask for i, _ in enumerate(input_ids): @@ -40,22 +42,34 @@ def encode_pretraining( ), dim=0, ) + targets[i] = torch.cat( + ( + targets[i], + torch.tensor([tokenizer.eos_token_id, -100]), + ), + dim=0, + ) attention_mask[i] = torch.cat((attention_mask[i], torch.tensor([1, 0])), dim=0) # Concatenate tokens so that their lengths are less than max_tokens buffer_input_ids = torch.tensor([], dtype=torch.long) + buffer_labels = torch.tensor([], dtype=torch.long) buffer_attention_mask = torch.tensor([], dtype=torch.long) - for ids, mask in zip(input_ids, attention_mask): + for ids, labels, mask in zip(input_ids, targets, attention_mask): if buffer_input_ids.numel() == max_tokens: new_input_ids.append(buffer_input_ids) + new_labels.append(buffer_labels) new_attention_mask.append(buffer_attention_mask) buffer_input_ids = torch.tensor([], dtype=torch.long) + buffer_labels = torch.tensor([], dtype=torch.long) buffer_attention_mask = torch.tensor([], dtype=torch.long) buffer_input_ids = torch.cat((buffer_input_ids, ids), dim=0) + buffer_labels = torch.cat((buffer_labels, labels), dim=0) buffer_attention_mask = torch.cat((buffer_attention_mask, mask), dim=0) elif buffer_input_ids.numel() + ids.numel() <= max_tokens: buffer_input_ids = torch.cat((buffer_input_ids, ids), dim=0) + buffer_labels = torch.cat((buffer_labels, labels), dim=0) buffer_attention_mask = torch.cat((buffer_attention_mask, mask), dim=0) else: buffer_input_ids = torch.cat( @@ -69,6 +83,17 @@ def encode_pretraining( ), dim=0, ) + buffer_labels = torch.cat( + ( + buffer_labels, + torch.full( + (max_tokens - buffer_labels.numel(),), + -100, + dtype=torch.long, + ), + ), + dim=0, + ) buffer_attention_mask = torch.cat( ( buffer_attention_mask, @@ -81,11 +106,14 @@ def encode_pretraining( dim=0, ) new_input_ids.append(buffer_input_ids) + new_labels.append(buffer_labels) new_attention_mask.append(buffer_attention_mask) buffer_input_ids = torch.tensor([], dtype=torch.long) + buffer_labels = torch.tensor([], dtype=torch.long) buffer_attention_mask = torch.tensor([], dtype=torch.long) buffer_input_ids = torch.cat((buffer_input_ids, ids), dim=0) + buffer_labels = torch.cat((buffer_labels, labels), dim=0) buffer_attention_mask = torch.cat((buffer_attention_mask, mask), dim=0) if buffer_input_ids.numel() > 0: # for any leftover tokens @@ -101,6 +129,17 @@ def encode_pretraining( ), dim=0, ) + buffer_labels = torch.cat( + ( + buffer_labels, + torch.full( + (max_tokens - buffer_labels.numel(),), + -100, + dtype=torch.long, + ), + ), + dim=0, + ) buffer_attention_mask = torch.cat( ( buffer_attention_mask, @@ -113,11 +152,12 @@ def encode_pretraining( dim=0, ) new_input_ids.append(buffer_input_ids) + new_labels.append(buffer_labels) new_attention_mask.append(buffer_attention_mask) ret = { "input_ids": [seq.tolist() for seq in new_input_ids], - "labels": [seq.tolist() for seq in new_input_ids], + "labels": [seq.tolist() for seq in new_labels], "attention_mask": [seq.tolist() for seq in new_attention_mask], }