Skip to content

Commit

Permalink
precommit
Browse files (browse the repository at this point in the history)
  • Loading branch information
dakinggg committed Oct 25, 2023
1 parent 9f1bfee commit 50edcf5
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 7 deletions.
9 changes: 6 additions & 3 deletions llmfoundry/callbacks/hf_checkpointer.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,11 @@
import contextlib
import copy
import logging
import math
import os
import tempfile
from pathlib import Path
from typing import Optional, Union
import math

import torch
from composer.core import Callback, Event, State, Time, TimeUnit
Expand Down Expand Up @@ -238,8 +238,11 @@ def _save_checkpoint(self, state: State, logger: Logger):
# we need a special case to identify we are on the last batch and should write the mlflow checkpoint
is_last_batch = False
if self.save_interval.unit == TimeUnit.DURATION and self.save_interval.value == 1 and state.max_duration.unit == TimeUnit.EPOCH:
is_last_batch = int(state.timestamp.batch) % math.ceil(state.max_duration.value * state.dataloader_len) == 0
if self.mlflow_registered_model_name is not None and ((elapsed_duration is not None and elapsed_duration >= 1.0) or is_last_batch):
is_last_batch = int(state.timestamp.batch) % math.ceil(
state.max_duration.value * state.dataloader_len) == 0
if self.mlflow_registered_model_name is not None and (
(elapsed_duration is not None and
elapsed_duration >= 1.0) or is_last_batch):
components = {'model': new_model_instance}
if original_tokenizer is not None:
components['tokenizer'] = original_tokenizer
Expand Down
9 changes: 5 additions & 4 deletions tests/test_hf_conversion_script.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,13 +251,14 @@ def test_callback_inits_with_defaults():
@pytest.mark.parametrize('model', ['mpt', 'neo', 'llama2'])
@pytest.mark.parametrize('fsdp_state_dict_type', ['full', 'sharded', None])
@pytest.mark.parametrize('log_to_mlflow', [True, False])
@pytest.mark.parametrize('hf_save_interval,save_interval,max_duration,expected_hf_checkpoints,expected_normal_checkpoints', [('3ba', '2ba', '7ba', 3, 4), ('1dur', '2ba', '1ep', 1, 4)])
@pytest.mark.parametrize(
'hf_save_interval,save_interval,max_duration,expected_hf_checkpoints,expected_normal_checkpoints',
[('3ba', '2ba', '7ba', 3, 4), ('1dur', '2ba', '1ep', 1, 4)])
def test_huggingface_conversion_callback(model: str, tmp_path: pathlib.Path,
fsdp_state_dict_type: Optional[str],
log_to_mlflow: bool,
hf_save_interval: str,
save_interval: str,
max_duration: str,
save_interval: str, max_duration: str,
expected_hf_checkpoints: int,
expected_normal_checkpoints: int):
delete_transformers_cache()
Expand All @@ -269,7 +270,7 @@ def test_huggingface_conversion_callback(model: str, tmp_path: pathlib.Path,
dataset_size = 14
precision_str = 'bfloat16'
precision = torch.bfloat16
batches_per_epoch = math.ceil(dataset_size / (device_batch_size*2))
batches_per_epoch = math.ceil(dataset_size / (device_batch_size * 2))

checkpointer_callback = HuggingFaceCheckpointer(
save_folder=os.path.join(tmp_path, 'checkpoints'),
Expand Down

0 comments on commit 50edcf5

Please sign in to comment.