pipe neox_args into GPT2Dataset
haileyschoelkopf committed Nov 22, 2023
1 parent 0372d4a commit 311ccf1
Showing 2 changed files with 17 additions and 7 deletions.
9 changes: 9 additions & 0 deletions megatron/data/data_utils.py
@@ -61,6 +61,7 @@ def build_the_dataset(
     skip_warmup,
     build_index_mappings=True,
     label_prefix=None,
+    neox_args=None,
 ):
     """Build train/valid/test datasets."""

@@ -85,6 +86,7 @@ def build_the_dataset(
         seed,
         build_index_mappings=build_index_mappings,
         label_dataset=label_dataset,
+        neox_args=neox_args
     )
     return dataset

@@ -98,6 +100,7 @@ def build_train_valid_test_datasets(
     seq_length,
     seed,
     skip_warmup,
+    neox_args=None
 ):
     """Build train, valid, and test datasets."""

@@ -139,6 +142,7 @@ def build_dataset(index, name):
                 seq_length,
                 seed,
                 use_shared_fs=use_shared_fs,
+                neox_args=neox_args
             )
         return dataset

@@ -224,6 +228,7 @@ def build_weighted_datasets(
                     skip_warmup=(not neox_args.mmap_warmup),
                     build_index_mappings=build_index_mappings,
                     label_prefix=label_path,
+                    neox_args=neox_args
                 )
             )

@@ -238,6 +243,7 @@ def build_weighted_datasets(
                     seed=neox_args.seed,
                     skip_warmup=(not neox_args.mmap_warmup),
                     build_index_mappings=build_index_mappings,
+                    neox_args=neox_args
                 )
             )

@@ -252,6 +258,7 @@ def build_weighted_datasets(
                     seed=neox_args.seed,
                     skip_warmup=(not neox_args.mmap_warmup),
                     build_index_mappings=build_index_mappings,
+                    neox_args=neox_args
                 )
             )
     return train_datasets, valid_datasets, test_datasets
@@ -394,6 +401,7 @@ def build_train_valid_test_data_iterators(neox_args):
             train_weights,
             valid_weights,
             test_weights,
+            neox_args=neox_args
         )

         if train_datasets:
@@ -414,6 +422,7 @@ def build_train_valid_test_data_iterators(neox_args):
             seq_length=neox_args.seq_length,
             seed=neox_args.seed,
             skip_warmup=(not neox_args.mmap_warmup),
+            neox_args=neox_args
         )

     # Build dataloders.
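Taken together, the data_utils.py changes thread the full NeoX argument namespace through every dataset construction path, defaulting to None so that callers which do not pass it keep working. A minimal sketch of the pattern (hypothetical, heavily abbreviated signatures; only the neox_args plumbing is shown):

    # Sketch only: signatures are abbreviated and **kwargs stands in for the
    # many real parameters; only the neox_args threading is illustrated.
    def build_the_dataset(data_prefix, name, neox_args=None, **kwargs):
        # The None default keeps pre-existing call sites working unchanged.
        return GPT2Dataset(name, data_prefix, neox_args=neox_args, **kwargs)

    class GPT2Dataset:
        def __init__(self, name, data_prefix, neox_args=None, **kwargs):
            # Stored so that __getitem__ can reach neox_args.tokenizer later.
            self.neox_args = neox_args

The point of the plumbing is visible in the gpt2_dataset.py hunks below: __getitem__ needs access to neox_args.tokenizer.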
15 changes: 8 additions & 7 deletions megatron/data/gpt2_dataset.py
@@ -39,12 +39,15 @@ def __init__(
         build_index_mappings=True,
         use_shared_fs=True,
         label_dataset=None,
+        neox_args=None,
     ):

         self.name = name
         self.indexed_dataset = indexed_dataset
         self.label_dataset = label_dataset

+        self.neox_args = neox_args
+
         # Checks
         assert np.min(documents) >= 0
         assert np.max(documents) < indexed_dataset.sizes.shape[0]
@@ -116,21 +119,19 @@ def __getitem__(self, idx):

         conditional_training = True
         if conditional_training:
-            sample_text = self.neox_args.tokenizer.decode(sample[0])
+            sample_text = self.neox_args.tokenizer.detokenize(samples[0])

             new_sample_text = ""
-            for i, (spacy_doc, meta) in enumerate(spacy_model.pipe((sample_text, "metadata"), n_process=8, as_tuples=True)):
+            for i, spacy_doc in enumerate(spacy_model.pipe([sample_text])):
                 for sent in spacy_doc.sents:
                     sent = sent.text_with_ws
                     new_sample_text += "<|endoftext|>"
                     new_sample_text += sent
-            encoded_new_text = self.neox_args.tokenizer.encode(new_sample_text)
-            print(len(encoded_new_text), len(samples[0]))
+            encoded_new_text = self.neox_args.tokenizer.tokenize(new_sample_text)
+            # print(len(encoded_new_text), len(samples[0]))

-        for idx, (spacy_doc, meta) in enumerate(spacy_model.pipe(raw_texts, n_process=8, as_tuples=True)):
-            for sent in spacy_doc.sents:
-                yield {'text': sent.text_with_ws, 'meta': meta, 'idx': idx}
         if len(datasets) == 1:
             return {"text": np.array(samples[0], dtype=np.int64)}
         else:
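The rewritten conditional-training branch decodes each packed sample back to text, splits it into sentences with spaCy, prefixes every sentence with <|endoftext|>, and re-tokenizes; it also swaps the HF-style encode/decode calls for the tokenize/detokenize methods that the NeoX tokenizer wrappers expose. A standalone sketch of that transformation, assuming a spaCy pipeline with sentence boundaries (en_core_web_sm is a stand-in) and any tokenizer object with tokenize/detokenize methods:

    import spacy

    # Assumption: any spaCy pipeline that sets sentence boundaries will do;
    # en_core_web_sm is a stand-in (python -m spacy download en_core_web_sm).
    spacy_model = spacy.load("en_core_web_sm")

    def resegment_with_eod(sample_tokens, tokenizer):
        # Mirror of the conditional-training branch above, outside the class.
        sample_text = tokenizer.detokenize(sample_tokens)
        new_sample_text = ""
        for spacy_doc in spacy_model.pipe([sample_text]):
            for sent in spacy_doc.sents:
                # Mark every sentence start with the end-of-text token.
                new_sample_text += "<|endoftext|>" + sent.text_with_ws
        return tokenizer.tokenize(new_sample_text)

Note that the re-encoded sequence is longer than the original sample (one extra <|endoftext|> token per sentence), which the commented-out print of the two lengths appears to be checking.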
