diff --git a/llmfoundry/callbacks/hf_checkpointer.py b/llmfoundry/callbacks/hf_checkpointer.py index a219b98ddb..614ac4314b 100644 --- a/llmfoundry/callbacks/hf_checkpointer.py +++ b/llmfoundry/callbacks/hf_checkpointer.py @@ -4,11 +4,11 @@ import contextlib import copy import logging +import math import os import tempfile from pathlib import Path from typing import Optional, Union -import math import torch from composer.core import Callback, Event, State, Time, TimeUnit @@ -238,8 +238,11 @@ def _save_checkpoint(self, state: State, logger: Logger): # we need a special case to identify we are on the last batch and should write the mlflow checkpoint is_last_batch = False if self.save_interval.unit == TimeUnit.DURATION and self.save_interval.value == 1 and state.max_duration.unit == TimeUnit.EPOCH: - is_last_batch = int(state.timestamp.batch) % math.ceil(state.max_duration.value * state.dataloader_len) == 0 - if self.mlflow_registered_model_name is not None and ((elapsed_duration is not None and elapsed_duration >= 1.0) or is_last_batch): + is_last_batch = int(state.timestamp.batch) % math.ceil( + state.max_duration.value * state.dataloader_len) == 0 + if self.mlflow_registered_model_name is not None and ( + (elapsed_duration is not None and + elapsed_duration >= 1.0) or is_last_batch): components = {'model': new_model_instance} if original_tokenizer is not None: components['tokenizer'] = original_tokenizer diff --git a/tests/test_hf_conversion_script.py b/tests/test_hf_conversion_script.py index 71c35f3723..d2f203d3a0 100644 --- a/tests/test_hf_conversion_script.py +++ b/tests/test_hf_conversion_script.py @@ -251,13 +251,14 @@ def test_callback_inits_with_defaults(): @pytest.mark.parametrize('model', ['mpt', 'neo', 'llama2']) @pytest.mark.parametrize('fsdp_state_dict_type', ['full', 'sharded', None]) @pytest.mark.parametrize('log_to_mlflow', [True, False]) -@pytest.mark.parametrize('hf_save_interval,save_interval,max_duration,expected_hf_checkpoints,expected_normal_checkpoints', [('3ba', '2ba', '7ba', 3, 4), ('1dur', '2ba', '1ep', 1, 4)]) +@pytest.mark.parametrize( + 'hf_save_interval,save_interval,max_duration,expected_hf_checkpoints,expected_normal_checkpoints', + [('3ba', '2ba', '7ba', 3, 4), ('1dur', '2ba', '1ep', 1, 4)]) def test_huggingface_conversion_callback(model: str, tmp_path: pathlib.Path, fsdp_state_dict_type: Optional[str], log_to_mlflow: bool, hf_save_interval: str, - save_interval: str, - max_duration: str, + save_interval: str, max_duration: str, expected_hf_checkpoints: int, expected_normal_checkpoints: int): delete_transformers_cache() @@ -269,7 +270,7 @@ def test_huggingface_conversion_callback(model: str, tmp_path: pathlib.Path, dataset_size = 14 precision_str = 'bfloat16' precision = torch.bfloat16 - batches_per_epoch = math.ceil(dataset_size / (device_batch_size*2)) + batches_per_epoch = math.ceil(dataset_size / (device_batch_size * 2)) checkpointer_callback = HuggingFaceCheckpointer( save_folder=os.path.join(tmp_path, 'checkpoints'),