Skip to content

Commit

Permalink
Add Unity Catalog support to HF checkpointer (#721)
Browse files (browse the repository at this point in the history)
  • Loading branch information
dakinggg authored Nov 9, 2023
1 parent efaa545 commit d2ddb83
Showing 1 changed file with 13 additions and 18 deletions.
31 changes: 13 additions & 18 deletions llmfoundry/callbacks/hf_checkpointer.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -14,9 +14,10 @@
from composer.core import Callback, Event, State, Time, TimeUnit
from composer.core.state import fsdp_state_dict_type_context
from composer.loggers import Logger, MLFlowLogger
- from composer.loggers.remote_uploader_downloader import RemoteUploaderDownloader
from composer.models import HuggingFaceModel
- from composer.utils import dist, format_name_with_dist_and_time, parse_uri
+ from composer.utils import (dist, format_name_with_dist_and_time,
+ maybe_create_remote_uploader_downloader_from_uri,
+ parse_uri)
from composer.utils.misc import create_interval_scheduler
from transformers import PreTrainedModel, PreTrainedTokenizerBase

Expand Down Expand Up @@ -57,8 +58,7 @@ def __init__(
mlflow_registered_model_name: Optional[str] = None,
mlflow_logging_config: Optional[dict] = None,
):
- self.backend, self.bucket_name, self.save_dir_format_str = parse_uri(
- save_folder)
+ _, _, self.save_dir_format_str = parse_uri(save_folder)
self.overwrite = overwrite
self.precision = precision
self.dtype = {
Expand Down Expand Up @@ -93,13 +93,11 @@ def __init__(
self.save_interval = save_interval
self.check_interval = create_interval_scheduler(
save_interval, include_end_of_training=True)
- self.upload_to_object_store = (self.backend != '')
- if self.upload_to_object_store:
- self.remote_ud = RemoteUploaderDownloader(
- bucket_uri=f'{self.backend}://{self.bucket_name}',
- num_concurrent_uploads=4)
- else:
- self.remote_ud = None
+
+ self.remote_ud = maybe_create_remote_uploader_downloader_from_uri(
+ save_folder, loggers=[])
+ if self.remote_ud is not None:
+ self.remote_ud._num_concurrent_uploads = 4

self.last_checkpoint_batch: Optional[Time] = None
self.mlflow_loggers = []
Expand All @@ -115,7 +113,7 @@ def run_event(self, event: Event, state: State, logger: Logger) -> None:
raise ValueError(
f'`HuggingFaceCheckpointer` is only compatible with `HuggingFaceModel`s. '
+ f'Got {type(state.model)} instead.')
- if self.upload_to_object_store and self.remote_ud is not None:
+ if self.remote_ud is not None:
self.remote_ud.init(state, logger)
state.callbacks.append(self.remote_ud)

Expand Down Expand Up @@ -169,7 +167,7 @@ def _save_checkpoint(self, state: State, logger: Logger):
self.huggingface_folder_name_fstr), state.run_name,
state.timestamp)
dir_context_mgr = tempfile.TemporaryDirectory(
- ) if self.upload_to_object_store else contextlib.nullcontext(
+ ) if self.remote_ud is not None else contextlib.nullcontext(
enter_result=save_dir)

with dir_context_mgr as temp_save_dir:
Expand Down Expand Up @@ -233,11 +231,8 @@ def _save_checkpoint(self, state: State, logger: Logger):
log.debug('Editing MPT files for HuggingFace compatibility')
edit_files_for_hf_compatibility(temp_save_dir)

- if self.upload_to_object_store:
- assert self.remote_ud is not None
- log.info(
- f'Uploading HuggingFace formatted checkpoint to {self.backend}://{self.bucket_name}/{save_dir}'
- )
+ if self.remote_ud is not None:
+ log.info(f'Uploading HuggingFace formatted checkpoint')
for filename in os.listdir(temp_save_dir):
self.remote_ud.upload_file(
state=state,
Expand Down

0 comments on commit d2ddb83

Please sign in to comment.