Skip to content

Commit

Permalink
Add generic flatten imports to HF checkpointer (#814)
Browse files Browse the repository at this point in the history
  • Loading branch information
b-chu authored Dec 21, 2023
1 parent 289536b commit bbf5cc7
Show file tree
Hide file tree
Showing 4 changed files with 93 additions and 44 deletions.
57 changes: 35 additions & 22 deletions llmfoundry/callbacks/hf_checkpointer.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import os
import tempfile
from pathlib import Path
from typing import Optional, Union
from typing import Optional, Sequence, Union

import torch
from composer.core import Callback, Event, State, Time, TimeUnit
Expand All @@ -32,31 +32,40 @@ class HuggingFaceCheckpointer(Callback):
"""Save a huggingface formatted checkpoint during training.
Args:
save_folder (str): Top level folder to save checkpoints to (can be a URI). It is likely that
this would be the same as your save_folder.
save_interval: Union[str, int, Time]: The interval describing how often checkpoints should be
saved. If an integer, it will be assumed to be in :attr:`.TimeUnit.EPOCH`.
Otherwise, the unit must be either :attr:`.TimeUnit.EPOCH`, :attr:`.TimeUnit.BATCH`,
save_folder (str): Top level folder to save checkpoints to (can be a
URI). It is likely that this would be the same as your save_folder.
save_interval: Union[str, int, Time]: The interval describing how often
checkpoints should be saved. If an integer, it will be assumed to be
in :attr:`.TimeUnit.EPOCH`. Otherwise, the unit must be either
:attr:`.TimeUnit.EPOCH`, :attr:`.TimeUnit.BATCH`,
:attr:`.TimeUnit.TOKEN`, or :attr:`.TimeUnit.SAMPLE`.
huggingface_folder_name (str): Folder to save each checkpoint under (can be a format string). Default is ``ba{batch}``.
precision: The precision to save the model in. Default is ``float32``. Options are ``bfloat16``, ``float16``, or ``float32``.
huggingface_folder_name (str): Folder to save each checkpoint under (can
be a format string). Default is ``ba{batch}``.
precision: The precision to save the model in. Default is ``float32``.
Options are ``bfloat16``, ``float16``, or ``float32``.
overwrite (bool): Whether to overwrite previous checkpoints.
mlflow_registered_model_name (Optional[str]): The name to register the model under in the MLflow model registry. If ``None``, the model will not
be registered. Default is ``None``.
mlflow_logging_config (Optional[dict]): A dictionary of config arguments that will get passed along to the MLflow ``save_model`` call.
Expected to contain ``metadata`` and ``task`` keys. If either is unspecified, the defaults are ``'text-generation'`` and
mlflow_registered_model_name (Optional[str]): The name to register the
model under in the MLflow model registry. If ``None``, the model
will not be registered. Default is ``None``.
mlflow_logging_config (Optional[dict]): A dictionary of config arguments
that will get passed along to the MLflow ``save_model`` call.
Expected to contain ``metadata`` and ``task`` keys. If either is
unspecified, the defaults are ``'text-generation'`` and
``{'task': 'llm/v1/completions'}`` respectively.
flatten_imports (Sequence[str]): A sequence of import prefixes that will
be flattened when editing MPT files.
"""

def __init__(
self,
save_folder: str,
save_interval: Union[str, int, Time],
huggingface_folder_name: str = 'ba{batch}',
precision: str = 'float32',
overwrite: bool = True,
mlflow_registered_model_name: Optional[str] = None,
mlflow_logging_config: Optional[dict] = None,
self,
save_folder: str,
save_interval: Union[str, int, Time],
huggingface_folder_name: str = 'ba{batch}',
precision: str = 'float32',
overwrite: bool = True,
mlflow_registered_model_name: Optional[str] = None,
mlflow_logging_config: Optional[dict] = None,
flatten_imports: Sequence[str] = ('llmfoundry',),
):
_, _, self.save_dir_format_str = parse_uri(save_folder)
self.overwrite = overwrite
Expand All @@ -66,6 +75,7 @@ def __init__(
'float16': torch.float16,
'bfloat16': torch.bfloat16,
}[precision]
self.flatten_imports = flatten_imports

# mlflow config setup
self.mlflow_registered_model_name = mlflow_registered_model_name
Expand All @@ -91,7 +101,7 @@ def __init__(
if isinstance(save_interval, int):
save_interval = Time(save_interval, TimeUnit.EPOCH)

self.save_interval = save_interval
self.save_interval: Time = save_interval
self.check_interval = create_interval_scheduler(
save_interval, include_end_of_training=True)
self.remote_ud = maybe_create_remote_uploader_downloader_from_uri(
Expand Down Expand Up @@ -229,7 +239,10 @@ def _save_checkpoint(self, state: State, logger: Logger):
# Only need to edit files for MPT because it has custom code
if original_model.config.model_type == 'mpt':
log.debug('Editing MPT files for HuggingFace compatibility')
edit_files_for_hf_compatibility(temp_save_dir)
edit_files_for_hf_compatibility(
temp_save_dir,
self.flatten_imports,
)

if self.remote_ud is not None:
log.info(f'Uploading HuggingFace formatted checkpoint')
Expand Down
2 changes: 1 addition & 1 deletion llmfoundry/models/mpt/modeling_mpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@
# Otherwise, certain modules are missing.
# isort: off
from llmfoundry.models.utils.adapt_tokenizer import (
AutoTokenizerForMOD, # type: ignore (see note),
AutoTokenizerForMOD, # type: ignore (see note)
adapt_tokenizer_for_denoising, # type: ignore (see note)
)
from llmfoundry.models.utils.hf_prefixlm_converter import (
Expand Down
62 changes: 41 additions & 21 deletions llmfoundry/utils/huggingface_hub_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,14 @@
import ast
import importlib
import os
from typing import List, Optional
from typing import Optional, Sequence

__all__ = ['edit_files_for_hf_compatibility']


class DeleteSpecificNodes(ast.NodeTransformer):

def __init__(self, nodes_to_remove: List[ast.AST]):
def __init__(self, nodes_to_remove: list[ast.AST]):
self.nodes_to_remove = nodes_to_remove

def visit(self, node: ast.AST) -> Optional[ast.AST]:
Expand Down Expand Up @@ -39,7 +39,26 @@ def find_module_file(module_name: str) -> str:
return module_file


def process_file(file_path: str, folder_path: str) -> List[str]:
def _flatten_import(
node: ast.ImportFrom,
flatten_imports_prefix: Sequence[str],
) -> bool:
"""Returns True if import should be flattened.
Checks whether the node starts the same as any of the imports in
flatten_imports_prefix.
"""
for import_prefix in flatten_imports_prefix:
if node.module is not None and node.module.startswith(import_prefix):
return True
return False


def process_file(
file_path: str,
folder_path: str,
flatten_imports_prefix: Sequence[str],
) -> list[str]:
with open(file_path, 'r') as f:
source = f.read()

Expand All @@ -51,37 +70,35 @@ def process_file(file_path: str, folder_path: str) -> List[str]:
new_files_to_process = []
nodes_to_remove = []
for node in ast.walk(tree):
# convert any llmfoundry imports into relative imports
if isinstance(
node, ast.ImportFrom
) and node.module is not None and node.module.startswith('llmfoundry'):
# Convert any llmfoundry imports into relative imports
if (isinstance(node, ast.ImportFrom) and node.module is not None and
_flatten_import(node, flatten_imports_prefix)):
module_path = find_module_file(node.module)
node.module = convert_to_relative_import(node.module,
parent_module_name)
# recursively process any llmfoundry files
# Recursively process any llmfoundry files
new_files_to_process.append(module_path)
# remove any imports from composer or omegaconf
# Remove any imports from composer or omegaconf
elif isinstance(node, ast.ImportFrom) and node.module is not None and (
node.module.startswith('composer') or
node.module.startswith('omegaconf')):
nodes_to_remove.append(node)
# remove the Composer* class
elif isinstance(node,
ast.ClassDef) and node.name.startswith('Composer'):
# Remove the Composer* class
elif (isinstance(node, ast.ClassDef) and
node.name.startswith('Composer')):
nodes_to_remove.append(node)
# remove the __all__ declaration in any __init__.py files, whose enclosing module
# will be converted to a single file of the same name
elif isinstance(node,
ast.Assign) and len(node.targets) == 1 and isinstance(
node.targets[0],
ast.Name) and node.targets[0].id == '__all__':
# Remove the __all__ declaration in any __init__.py files, whose
# enclosing module will be converted to a single file of the same name
elif (isinstance(node, ast.Assign) and len(node.targets) == 1 and
isinstance(node.targets[0], ast.Name) and
node.targets[0].id == '__all__'):
nodes_to_remove.append(node)

transformer = DeleteSpecificNodes(nodes_to_remove)
new_tree = transformer.visit(tree)

new_filename = os.path.basename(file_path)
# special case for __init__.py to mimic the original submodule
# Special case for __init__.py to mimic the original submodule
if new_filename == '__init__.py':
new_filename = file_path.split('/')[-2] + '.py'
new_file_path = os.path.join(folder_path, new_filename)
Expand All @@ -92,7 +109,10 @@ def process_file(file_path: str, folder_path: str) -> List[str]:
return new_files_to_process


def edit_files_for_hf_compatibility(folder: str) -> None:
def edit_files_for_hf_compatibility(
folder: str,
flatten_imports_prefix: Sequence[str] = ('llmfoundry',),
) -> None:
files_to_process = [
os.path.join(folder, filename)
for filename in os.listdir(folder)
Expand All @@ -103,7 +123,7 @@ def edit_files_for_hf_compatibility(folder: str) -> None:
while len(files_to_process) > 0:
to_process = files_to_process.pop()
if os.path.isfile(to_process) and to_process.endswith('.py'):
to_add = process_file(to_process, folder)
to_add = process_file(to_process, folder, flatten_imports_prefix)
for file in to_add:
if file not in files_processed_and_queued:
files_to_process.append(file)
Expand Down
16 changes: 16 additions & 0 deletions tests/utils/test_huggingface_hub_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# Copyright 2022 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0

import ast

from llmfoundry.utils.huggingface_hub_utils import _flatten_import


def test_flatten_import_true():
node = ast.ImportFrom('y', ['x', 'y', 'z'])
assert _flatten_import(node, ('x', 'y', 'z'))


def test_flatten_import_false():
node = ast.ImportFrom('y', ['x', 'y', 'z'])
assert not _flatten_import(node, ('x', 'z'))

0 comments on commit bbf5cc7

Please sign in to comment.