From 311ccf1e56c58ae68ba0f90c82e642a96d73ccb2 Mon Sep 17 00:00:00 2001
From: haileyschoelkopf
Date: Wed, 22 Nov 2023 18:11:08 +0000
Subject: [PATCH] pipe neox_args into GPT2Dataset

---
 megatron/data/data_utils.py   |  9 +++++++++
 megatron/data/gpt2_dataset.py | 15 ++++++++-------
 2 files changed, 17 insertions(+), 7 deletions(-)

diff --git a/megatron/data/data_utils.py b/megatron/data/data_utils.py
index 513dd0e21..d5f166ad4 100644
--- a/megatron/data/data_utils.py
+++ b/megatron/data/data_utils.py
@@ -61,6 +61,7 @@ def build_the_dataset(
     skip_warmup,
     build_index_mappings=True,
     label_prefix=None,
+    neox_args=None,
 ):
     """Build train/valid/test datasets."""
 
@@ -85,6 +86,7 @@ def build_the_dataset(
         seed,
         build_index_mappings=build_index_mappings,
         label_dataset=label_dataset,
+        neox_args=neox_args
     )
     return dataset
 
@@ -98,6 +100,7 @@ def build_train_valid_test_datasets(
     seq_length,
     seed,
     skip_warmup,
+    neox_args=None
 ):
     """Build train, valid, and test datasets."""
 
@@ -139,6 +142,7 @@ def build_dataset(index, name):
                 seq_length,
                 seed,
                 use_shared_fs=use_shared_fs,
+                neox_args=neox_args
             )
         return dataset
 
@@ -224,6 +228,7 @@ def build_weighted_datasets(
                     skip_warmup=(not neox_args.mmap_warmup),
                     build_index_mappings=build_index_mappings,
                     label_prefix=label_path,
+                    neox_args=neox_args
                 )
             )
 
@@ -238,6 +243,7 @@ def build_weighted_datasets(
                     seed=neox_args.seed,
                     skip_warmup=(not neox_args.mmap_warmup),
                     build_index_mappings=build_index_mappings,
+                    neox_args=neox_args
                 )
             )
 
@@ -252,6 +258,7 @@ def build_weighted_datasets(
                     seed=neox_args.seed,
                     skip_warmup=(not neox_args.mmap_warmup),
                     build_index_mappings=build_index_mappings,
+                    neox_args=neox_args
                 )
             )
     return train_datasets, valid_datasets, test_datasets
@@ -394,6 +401,7 @@ def build_train_valid_test_data_iterators(neox_args):
             train_weights,
             valid_weights,
             test_weights,
+            neox_args=neox_args
         )
 
         if train_datasets:
@@ -414,6 +422,7 @@
             seq_length=neox_args.seq_length,
             seed=neox_args.seed,
             skip_warmup=(not neox_args.mmap_warmup),
+            neox_args=neox_args
         )
 
     # Build dataloders.
diff --git a/megatron/data/gpt2_dataset.py b/megatron/data/gpt2_dataset.py
index 9d6933d48..9c70eba01 100644
--- a/megatron/data/gpt2_dataset.py
+++ b/megatron/data/gpt2_dataset.py
@@ -39,12 +39,15 @@ def __init__(
         build_index_mappings=True,
         use_shared_fs=True,
         label_dataset=None,
+        neox_args=None,
     ):
 
         self.name = name
         self.indexed_dataset = indexed_dataset
         self.label_dataset = label_dataset
 
+        self.neox_args = neox_args
+
         # Checks
         assert np.min(documents) >= 0
         assert np.max(documents) < indexed_dataset.sizes.shape[0]
@@ -116,21 +119,19 @@ def __getitem__(self, idx):
 
             conditional_training = True
 
             if conditional_training:
-                sample_text = self.neox_args.tokenizer.decode(sample[0])
+                sample_text = self.neox_args.tokenizer.detokenize(samples[0])
                 new_sample_text = ""
-                for i, (spacy_doc, meta) in enumerate(spacy_model.pipe((sample_text, "metadata"), n_process=8, as_tuples=True)):
+
+                for i, spacy_doc in enumerate(spacy_model.pipe([sample_text])):
                     for sent in spacy_doc.sents:
                         sent = sent.text_with_ws
                         new_sample_text += "<|endoftext|>"
                         new_sample_text += sent
-                encoded_new_text = self.neox_args.tokenizer.encode(new_sample_text)
-                print(len(encoded_new_text), len(samples[0])
+                encoded_new_text = self.neox_args.tokenizer.tokenize(new_sample_text)
+                # print(len(encoded_new_text), len(samples[0]))
 
-            for idx, (spacy_doc, meta) in enumerate(spacy_model.pipe(raw_texts, n_process=8, as_tuples=True)):
-                for sent in spacy_doc.sents:
-                    yield {'text': sent.text_with_ws, 'meta': meta, 'idx': idx}
 
             if len(datasets) == 1:
                 return {"text": np.array(samples[0], dtype=np.int64)}
             else:
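
Note (not part of the patch): the conditional-training path above detokenizes a sample, resplits it into sentences with spaCy, prefixes each sentence with <|endoftext|>, and retokenizes via neox_args.tokenizer. A minimal standalone sketch of just the sentence-prefixing step follows; the en_core_web_sm pipeline and the prefix_sentences helper are illustrative assumptions, and the GPT-NeoX tokenize/detokenize round-trip is omitted.

    # Illustrative sketch, assuming spaCy and the en_core_web_sm model are installed.
    import spacy

    spacy_model = spacy.load("en_core_web_sm")  # assumed sentence-splitting pipeline

    def prefix_sentences(sample_text, eod="<|endoftext|>"):
        """Split decoded text into sentences and prepend the EOD token to each."""
        new_sample_text = ""
        for spacy_doc in spacy_model.pipe([sample_text]):
            for sent in spacy_doc.sents:
                new_sample_text += eod
                new_sample_text += sent.text_with_ws
        return new_sample_text

    print(prefix_sentences("First sentence. Second sentence."))
    # -> <|endoftext|>First sentence. <|endoftext|>Second sentence.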