Skip to content

Commit

Permalink
try sequential sampling
Browse files Browse the repository at this point in the history
  • Loading branch information
qgallouedec committed Dec 22, 2023
1 parent 3a1be44 commit 85b7205
Showing 1 changed file with 6 additions and 1 deletion.
7 changes: 6 additions & 1 deletion scripts/train_jat_tokenized.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,11 @@ class DataTrainingArguments:
os.environ["WANDB_PROJECT"] = "jat"


class MyTrainer(Trainer):
    """Trainer variant whose training sampler hook always returns ``None``.

    Per the commit message this is an experiment with sequential sampling.
    NOTE(review): presumably the base ``Trainer`` then builds its training
    dataloader without a random sampler, so samples are read in dataset
    order -- confirm against the installed ``transformers`` version.
    """

    def _get_train_sampler(self, train_dataset=None) -> None:
        """Return ``None`` so that no sampler is supplied for training.

        Args:
            train_dataset: Ignored. Accepted only so the override stays
                callable whether or not the base class passes the dataset
                to this hook (newer ``transformers`` releases appear to do
                so -- TODO confirm).

        Returns:
            Always ``None``.
        """
        return None


def main():
parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))

Expand Down Expand Up @@ -148,7 +153,7 @@ def main():
raise ValueError("Make sure to pass `--dispatch_batches False`.")

# Why does training continue after exhausting the dataset? See https://github.com/huggingface/transformers/issues/26635
trainer = Trainer(model=model, args=training_args, train_dataset=train_dataset, tokenizer=processor)
trainer = MyTrainer(model=model, args=training_args, train_dataset=train_dataset, tokenizer=processor)
trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)


Expand Down

0 comments on commit 85b7205

Please sign in to comment.