From d2ddb834650b085337c4d914f77bb80c76201e9d Mon Sep 17 00:00:00 2001 From: Daniel King <43149077+dakinggg@users.noreply.github.com> Date: Thu, 9 Nov 2023 10:25:13 -0800 Subject: [PATCH] Add Unity Catalog support to HF checkpointer (#721) --- llmfoundry/callbacks/hf_checkpointer.py | 31 +++++++++++-------------- 1 file changed, 13 insertions(+), 18 deletions(-) diff --git a/llmfoundry/callbacks/hf_checkpointer.py b/llmfoundry/callbacks/hf_checkpointer.py index 4f400738e4..e02bf03693 100644 --- a/llmfoundry/callbacks/hf_checkpointer.py +++ b/llmfoundry/callbacks/hf_checkpointer.py @@ -14,9 +14,10 @@ from composer.core import Callback, Event, State, Time, TimeUnit from composer.core.state import fsdp_state_dict_type_context from composer.loggers import Logger, MLFlowLogger -from composer.loggers.remote_uploader_downloader import RemoteUploaderDownloader from composer.models import HuggingFaceModel -from composer.utils import dist, format_name_with_dist_and_time, parse_uri +from composer.utils import (dist, format_name_with_dist_and_time, + maybe_create_remote_uploader_downloader_from_uri, + parse_uri) from composer.utils.misc import create_interval_scheduler from transformers import PreTrainedModel, PreTrainedTokenizerBase @@ -57,8 +58,7 @@ def __init__( mlflow_registered_model_name: Optional[str] = None, mlflow_logging_config: Optional[dict] = None, ): - self.backend, self.bucket_name, self.save_dir_format_str = parse_uri( - save_folder) + _, _, self.save_dir_format_str = parse_uri(save_folder) self.overwrite = overwrite self.precision = precision self.dtype = { @@ -93,13 +93,11 @@ def __init__( self.save_interval = save_interval self.check_interval = create_interval_scheduler( save_interval, include_end_of_training=True) - self.upload_to_object_store = (self.backend != '') - if self.upload_to_object_store: - self.remote_ud = RemoteUploaderDownloader( - bucket_uri=f'{self.backend}://{self.bucket_name}', - num_concurrent_uploads=4) - else: - self.remote_ud = None + + self.remote_ud = maybe_create_remote_uploader_downloader_from_uri( + save_folder, loggers=[]) + if self.remote_ud is not None: + self.remote_ud._num_concurrent_uploads = 4 self.last_checkpoint_batch: Optional[Time] = None self.mlflow_loggers = [] @@ -115,7 +113,7 @@ def run_event(self, event: Event, state: State, logger: Logger) -> None: raise ValueError( f'`HuggingFaceCheckpointer` is only compatible with `HuggingFaceModel`s. ' + f'Got {type(state.model)} instead.') - if self.upload_to_object_store and self.remote_ud is not None: + if self.remote_ud is not None: self.remote_ud.init(state, logger) state.callbacks.append(self.remote_ud) @@ -169,7 +167,7 @@ def _save_checkpoint(self, state: State, logger: Logger): self.huggingface_folder_name_fstr), state.run_name, state.timestamp) dir_context_mgr = tempfile.TemporaryDirectory( - ) if self.upload_to_object_store else contextlib.nullcontext( + ) if self.remote_ud is not None else contextlib.nullcontext( enter_result=save_dir) with dir_context_mgr as temp_save_dir: @@ -233,11 +231,8 @@ def _save_checkpoint(self, state: State, logger: Logger): log.debug('Editing MPT files for HuggingFace compatibility') edit_files_for_hf_compatibility(temp_save_dir) - if self.upload_to_object_store: - assert self.remote_ud is not None - log.info( - f'Uploading HuggingFace formatted checkpoint to {self.backend}://{self.bucket_name}/{save_dir}' - ) + if self.remote_ud is not None: + log.info(f'Uploading HuggingFace formatted checkpoint') for filename in os.listdir(temp_save_dir): self.remote_ud.upload_file( state=state,