Skip to content

Commit

Permalink
wip
Browse files: browse the repository at this point in the history
  • Loading branch information
AI-WAIFU committed Oct 22, 2024
1 parent afe2c40 commit 5e4d925
Show file tree
Hide file tree
Showing 3 changed files with 4 additions and 2 deletions.
4 changes: 3 additions & 1 deletion megatron/checkpointing.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,9 @@ def do_forward_pass(neox_args, model, inference=False):
tokens, attention_mask, position_ids = get_batch(
neox_args, context_tokens_tensor[:, : neox_args.seq_length]
)
- logits = model((tokens, position_ids, attention_mask))
+ output = model((tokens, position_ids, attention_mask))
+ logits = output[0] if isinstance(output, tuple) else output


# reset to train mode, if model was in training before
if model_was_in_train:
Expand Down
1 change: 0 additions & 1 deletion tests/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -512,7 +512,6 @@ def model_setup(yaml_list=None, param_dict=None, clear_data=True):
args_loaded.build_tokenizer()

initialize_megatron(neox_args=args_loaded)
- print("YAP")
model, optimizer, lr_scheduler, reference_model = setup_model_and_optimizer(
neox_args=args_loaded, use_cache=True
)
Expand Down
1 change: 1 addition & 0 deletions tests/model/test_model_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ def test_checkpoint(self, param_dict, tmpdir):
reloaded_model,
reloaded_optimizer,
reloaded_lr_scheduler,
+ reloaded_reference_model,
args_reloaded,
) = model_setup(yaml_list=None, param_dict=param_dict, clear_data=False)
iteration = load_checkpoint(
Expand Down

0 comments on commit 5e4d925

Please sign in to comment.