Skip to content

Commit

Permalink
add arg for conditional training
Browse files Browse the repository at this point in the history
  • Loading branch information
haileyschoelkopf committed Nov 22, 2023
1 parent 311ccf1 commit 1dd4054
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 7 deletions.
15 changes: 8 additions & 7 deletions megatron/data/gpt2_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,13 @@ def __init__(

self.neox_args = neox_args

# TODO: move spacy model def outside this fn, into __init__()
if neox_args.conditional_training:
import spacy
self.spacy_model = spacy.blank("en")
sentencizer = self.spacy_model.add_pipe("sentencizer")
self.spacy_model.max_length = 1e12

# Checks
assert np.min(documents) >= 0
assert np.max(documents) < indexed_dataset.sizes.shape[0]
Expand Down Expand Up @@ -111,14 +118,8 @@ def __getitem__(self, idx):
)
samples.append(np.concatenate(sample_list))

# TODO: move spacy model def outside this fn, into __init__()
import spacy
spacy_model = spacy.blank("en")
sentencizer = spacy_model.add_pipe("sentencizer")
spacy_model.max_length = 1e12

conditional_training = True
if conditional_training:
if self.neox_args.conditional_training:
sample_text = self.neox_args.tokenizer.detokenize(samples[0])

new_sample_text = ""
Expand Down
5 changes: 5 additions & 0 deletions megatron/neox_arguments/neox_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -788,6 +788,11 @@ class NeoXArgsTraining(NeoXArgsTemplate):
Warm up mmap files.
"""

conditional_training: bool = False
"""
Whether to perform (decision-transformer style) conditional training.
"""

save: str = None
"""
Output directory to save checkpoints to.
Expand Down

0 comments on commit 1dd4054

Please sign in to comment.