diff --git a/megatron/data/gpt2_dataset.py b/megatron/data/gpt2_dataset.py
index 9c70eba01..b6824cab5 100644
--- a/megatron/data/gpt2_dataset.py
+++ b/megatron/data/gpt2_dataset.py
@@ -48,6 +48,13 @@ def __init__(
 
         self.neox_args = neox_args
 
+        # Build the spaCy sentencizer once, up front, for conditional training.
+        if neox_args.conditional_training:
+            import spacy
+            self.spacy_model = spacy.blank("en")
+            self.spacy_model.add_pipe("sentencizer")
+            self.spacy_model.max_length = int(1e12)
+
         # Checks
         assert np.min(documents) >= 0
         assert np.max(documents) < indexed_dataset.sizes.shape[0]
@@ -111,14 +118,8 @@ def __getitem__(self, idx):
                 )
                 samples.append(np.concatenate(sample_list))
 
-            # TODO: move spacy model def outside this fn, into __init__()
-            import spacy
-            spacy_model = spacy.blank("en")
-            sentencizer = spacy_model.add_pipe("sentencizer")
-            spacy_model.max_length = 1e12
-            conditional_training = True
-            if conditional_training:
+            if self.neox_args.conditional_training:
                 sample_text = self.neox_args.tokenizer.detokenize(samples[0])
 
                 new_sample_text = ""
 
diff --git a/megatron/neox_arguments/neox_args.py b/megatron/neox_arguments/neox_args.py
index 957960832..aac369773 100644
--- a/megatron/neox_arguments/neox_args.py
+++ b/megatron/neox_arguments/neox_args.py
@@ -788,6 +788,11 @@ class NeoXArgsTraining(NeoXArgsTemplate):
     Warm up mmap files.
     """
 
+    conditional_training: bool = False
+    """
+    Whether to perform (decision-transformer style) conditional training.
+    """
+
     save: str = None
     """
     Output directory to save checkpoints to.
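
For context, a minimal runnable sketch of the sentence-splitting setup this patch moves into __init__: spacy.blank("en") plus the rule-based "sentencizer" component segments text into sentences without downloading a trained model. The sample text below is illustrative, not from the dataset code.

    # Sketch mirroring the added __init__ code: a blank English pipeline
    # with only a rule-based sentencizer (no trained model required).
    import spacy

    spacy_model = spacy.blank("en")
    spacy_model.add_pipe("sentencizer")
    spacy_model.max_length = int(1e12)  # permit very long documents

    doc = spacy_model("First sentence. Second sentence.")
    print([sent.text for sent in doc.sents])
    # -> ['First sentence.', 'Second sentence.']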