Skip to content

Commit

Permalink
update flashattn, fix ppo save model
Browse files Browse the repository at this point in the history
  • Loading branch information
hiyouga committed Sep 11, 2023
1 parent b218c27 commit 0fbece8
Show file tree
Hide file tree
Showing 5 changed files with 105 additions and 518 deletions.
12 changes: 6 additions & 6 deletions src/llmtuner/extras/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,16 +25,16 @@ def on_save(self, args: "TrainingArguments", state: "TrainerState", control: "Tr
r"""
Event called after a checkpoint save.
"""
output_dir = os.path.join(args.output_dir, "{}-{}".format(PREFIX_CHECKPOINT_DIR, state.global_step))
getattr(kwargs.get("model"), "pretrained_model").save_pretrained(output_dir)
return control
if args.should_save:
output_dir = os.path.join(args.output_dir, "{}-{}".format(PREFIX_CHECKPOINT_DIR, state.global_step))
getattr(kwargs.get("model"), "pretrained_model").save_pretrained(output_dir)

def on_train_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
r"""
Event called at the end of training.
"""
getattr(kwargs.get("model"), "pretrained_model").save_pretrained(args.output_dir)
return control
if args.should_save:
getattr(kwargs.get("model"), "pretrained_model").save_pretrained(args.output_dir)


class LogCallback(TrainerCallback):
Expand Down
Loading

0 comments on commit 0fbece8

Please sign in to comment.