Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding serialization to all Auto* objects in HuggingFace transformers #11645

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
3 changes: 2 additions & 1 deletion nemo/lightning/io/artifact/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from nemo.lightning.io.artifact.base import Artifact
from nemo.lightning.io.artifact.file import DirArtifact, DirOrStringArtifact, FileArtifact, PathArtifact
from nemo.lightning.io.artifact.hf_auto import HFAutoArtifact

__all__ = ["Artifact", "FileArtifact", "PathArtifact", "DirArtifact", "DirOrStringArtifact"]
__all__ = ["Artifact", "FileArtifact", "PathArtifact", "DirArtifact", "DirOrStringArtifact", "HFAutoArtifact"]
2 changes: 1 addition & 1 deletion nemo/lightning/io/artifact/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def __init__(self, attr: str, required: bool = True, skip: bool = False):
self.skip = skip

@abstractmethod
def dump(self, value: ValueT, absolute_dir: Path, relative_dir: Path) -> ValueT:
def dump(self, instance, value: ValueT, absolute_dir: Path, relative_dir: Path) -> ValueT:
pass

@abstractmethod
Expand Down
4 changes: 2 additions & 2 deletions nemo/lightning/io/artifact/file.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@


class PathArtifact(Artifact[Path]):
def dump(self, value: Path, absolute_dir: Path, relative_dir: Path) -> Path:
def dump(self, instance, value: Path, absolute_dir: Path, relative_dir: Path) -> Path:
new_value = copy_file(value, absolute_dir, relative_dir)
return new_value

Expand All @@ -31,7 +31,7 @@ def load(self, path: Path) -> Path:


class FileArtifact(Artifact[str]):
def dump(self, value: str, absolute_dir: Path, relative_dir: Path) -> str:
def dump(self, instance, value: str, absolute_dir: Path, relative_dir: Path) -> str:
if not pathize(value).exists():
# This Artifact is just a string.
return fdl.Config(FileArtifact, attr=value, skip=True)
Expand Down
152 changes: 152 additions & 0 deletions nemo/lightning/io/artifact/hf_auto.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


"""HuggingFace model serialization support for NeMo's configuration system.

This module provides integration between NeMo's configuration system and HuggingFace's
pretrained models. It enables automatic serialization and deserialization of HuggingFace
models within NeMo's configuration framework.

The integration works by:
1. Detecting HuggingFace object instances (not classes) from `transformers` modules through their characteristic methods (save_pretrained/from_pretrained)
2. Converting them to Fiddle configurations that preserve the model's class and path
3. Providing an artifact handler (HFAutoArtifact) that manages the actual model files

Example:
```python
from transformers import AutoModel

# This model will be automatically handled by the HFAutoArtifact system
model = AutoModel.from_pretrained("bert-base-uncased")

# When serialized, the model files will be saved to the artifacts directory
# When deserialized, the model will be loaded from the saved files
```
"""

import contextlib
import inspect
import threading
from pathlib import Path

import fiddle as fdl

from nemo.lightning.io.artifact import Artifact
from nemo.lightning.io.to_config import to_config

# Thread-local storage for the extra keyword arguments that `from_pretrained_kwargs`
# sets on its `kwargs` attribute and that `from_pretrained` forwards to the loader.
_local = threading.local()


class HFAutoArtifact(Artifact):
    """Artifact handler for HuggingFace pretrained objects (models, processors, tokenizers, etc.).

    Serialization is delegated to the object's own ``save_pretrained`` method, which
    writes into an ``artifacts`` subdirectory of the dump directory. Deserialization
    simply hands the stored path back so it can be passed to ``from_pretrained``.
    """

    def dump(self, instance, value: Path, absolute_dir: Path, relative_dir: Path) -> str:
        """Save a HuggingFace object to disk and return where it was written.

        Args:
            instance: The HuggingFace object to save; must provide ``save_pretrained``.
            value: Original value of the configured attribute (unused).
            absolute_dir: Absolute path to the save directory.
            relative_dir: Relative path from the config file to the save directory.

        Returns:
            str: The relative path (prefixed with ``./``) to the saved artifacts.
        """
        # NOTE(review): the subdirectory name is fixed, so every object dumped with the
        # same ``absolute_dir`` writes into the same folder -- confirm callers never
        # dump more than one HuggingFace object per directory.
        instance.save_pretrained(Path(absolute_dir) / "artifacts")
        return "./" + str(Path(relative_dir) / "artifacts")

    def load(self, path: Path) -> Path:
        """Return the path from which the HuggingFace object should be loaded.

        Args:
            path: Path to the saved artifacts.

        Returns:
            Path: The same path, unchanged, to be used with ``from_pretrained``.
        """
        return path


@contextlib.contextmanager
def from_pretrained_kwargs(**kwargs):
    """Temporarily add keyword arguments to every ``from_pretrained`` call in this thread.

    The given kwargs are merged on top of any kwargs already active (so contexts may
    be nested); on exit the previously active kwargs are restored, even if the body
    raised.

    Args:
        **kwargs: Keyword arguments to pass to from_pretrained

    Example:
        with from_pretrained_kwargs(trust_remote_code=True):
            io.load_context("path/to/checkpoint")
    """
    # Snapshot whatever is active right now (nothing on first use in this thread).
    saved = dict(getattr(_local, "kwargs", {}))
    _local.kwargs = {**saved, **kwargs}
    try:
        yield
    finally:
        _local.kwargs = saved


def from_pretrained(auto_cls, pretrained_model_name_or_path="dummy"):
    """Load a HuggingFace object through its class's ``from_pretrained`` method.

    This factory is the target stored in the serialized configuration for
    HuggingFace objects; building that configuration calls this function, which
    in turn recreates the object from the saved path.

    Args:
        auto_cls: The HuggingFace class to load with (e.g., AutoModel, AutoTokenizer)
        pretrained_model_name_or_path: Path to the saved object or a model identifier

    Returns:
        Whatever ``auto_cls.from_pretrained`` returns for the given path.
    """
    # Extra kwargs come from an enclosing ``from_pretrained_kwargs`` context, if any.
    extra_kwargs = getattr(_local, "kwargs", {})
    loader = auto_cls.from_pretrained
    return loader(pretrained_model_name_or_path, **extra_kwargs)


def _is_hf_pretrained_instance(value) -> bool:
    """Return True when ``value`` is an instance of a HuggingFace pretrained type.

    A value qualifies when it is not a class, its ``__module__`` starts with
    ``transformers``, and it exposes both ``save_pretrained`` and ``from_pretrained``.
    """
    if inspect.isclass(value):
        return False
    module_name = getattr(value, "__module__", None)
    # ``__module__`` can be present but None (e.g. bound built-in methods); answer
    # False for those instead of raising AttributeError on ``None.startswith``.
    if not isinstance(module_name, str) or not module_name.startswith("transformers"):
        return False
    return hasattr(value, "save_pretrained") and hasattr(value, "from_pretrained")


@to_config.register(_is_hf_pretrained_instance)
def handle_hf_pretrained(value):
    """Convert a HuggingFace object instance to a Fiddle configuration.

    The returned configuration targets the ``from_pretrained`` factory with the
    instance's class, so building it recreates the object via that class's
    ``from_pretrained`` method.

    Args:
        value: A HuggingFace object instance

    Returns:
        fdl.Config: A Fiddle configuration that will recreate the object
    """
    return fdl.Config(
        from_pretrained,
        auto_cls=value.__class__,
        # Placeholder only; HFAutoArtifact("pretrained_model_name_or_path") is
        # registered for this parameter so the real path can be filled in on dump.
        pretrained_model_name_or_path="dummy",
    )


# Register the HFAutoArtifact handler for the pretrained_model_name_or_path parameter.
# The IO mixin reads `__io_artifacts__` from a config's target function/class and, when
# saving, replaces the named attribute with the value returned by the handler's `dump`.
from_pretrained.__io_artifacts__ = [HFAutoArtifact("pretrained_model_name_or_path")]
2 changes: 1 addition & 1 deletion nemo/lightning/io/artifact/pickle.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@


class PickleArtifact(Artifact[Any]):
def dump(self, absolute_dir: Path, relative_dir: Path) -> Path:
def dump(self, instance, absolute_dir: Path, relative_dir: Path) -> Path:
relative_file = self.file_path(relative_dir)
with open(Path(absolute_dir) / relative_file, "wb") as f:
dump(value, f)
Expand Down
13 changes: 8 additions & 5 deletions nemo/lightning/io/fdl_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,18 +19,20 @@
"""

import types
from functools import partial

import fiddle as fdl
import libcst as cst
import torch
import torch.nn as nn
from fiddle._src import daglish_extensions
from fiddle._src.codegen import import_manager, py_val_to_cst_converter, special_value_codegen
from fiddle._src.experimental import serialization

from nemo.lightning.io.artifact import * # noqa: F403
from nemo.lightning.io.to_config import to_config


def _make_torch_importable(name: str) -> special_value_codegen.Importable:
    """Build an importable that renders ``name`` as an attribute of the ``torch`` module."""

    def _qualify(torch_name):
        # ``torch_name`` is whatever name the importable is handed for the module.
        return f"{torch_name}.{name}"

    return special_value_codegen.SingleImportable("torch", _qualify)


Expand Down Expand Up @@ -67,6 +69,7 @@ def _make_torch_importable(name: str) -> special_value_codegen.Importable:


def _make_torch_nn_importable(name: str) -> special_value_codegen.Importable:
    """Build an importable that renders ``name`` as an attribute of ``torch.nn``."""

    def _qualify(torch_mod_name):
        # Only ``torch`` is imported; the ``nn`` submodule is reached by attribute access.
        return f"{torch_mod_name}.nn.{name}"

    return special_value_codegen.SingleImportable("torch", _qualify)


Expand All @@ -88,6 +91,7 @@ def is_torch_tensor(value):


def convert_torch_tensor_to_cst(value, convert_child):
"""Convert a PyTorch tensor to a CST node."""
return cst.Call(
func=cst.Attribute(value=convert_child(torch), attr=cst.Name("tensor")),
args=[
Expand Down Expand Up @@ -124,11 +128,10 @@ def enable():

# Monkey-patch the Serialization class to handle things like activation-functions
def _modified_serialize(self, value, current_path, all_paths=None):
"""Serialize a value to a Fiddle configuration."""
if isinstance(value, types.BuiltinFunctionType):
return self._pyref(value, current_path)
if isinstance(value, partial):
value = fdl.Partial(value.func, *value.args, **value.keywords)
return self._original_serialize(value, current_path, all_paths)
return self._original_serialize(to_config(value), current_path, all_paths)

serialization.Serialization._original_serialize = serialization.Serialization._serialize
serialization.Serialization._serialize = _modified_serialize
29 changes: 22 additions & 7 deletions nemo/lightning/io/mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
from nemo.lightning.io.capture import IOProtocol
from nemo.lightning.io.connector import ModelConnector
from nemo.lightning.io.fdl_torch import enable as _enable_ext
from nemo.lightning.io.to_config import to_config
from nemo.utils import logging

ConnT = TypeVar("ConnT", bound=ModelConnector)
Expand Down Expand Up @@ -233,8 +234,7 @@ def io_dump(self, output: Path, yaml_attrs: list[str]):

config_path = output_path / "io.json"
with open(config_path, "w") as f:
io = deepcopy(self.__io__)
_artifact_transform_save(io, output_path, local_artifacts_dir)
io = _artifact_transform_save(self, deepcopy(self.__io__), output_path, local_artifacts_dir)
json = serialization.dump_json(io)
f.write(json)

Expand Down Expand Up @@ -632,8 +632,10 @@ def _io_path_elements_fn(x):
return x.__io__.__path_elements__()


def _artifact_transform_save(cfg: fdl.Config, output_path: Path, relative_dir: Path = "."):
for artifact in getattr(cfg.__fn_or_cls__, "__io_artifacts__", []):
def _artifact_transform_save(instance, cfg: fdl.Config, output_path: Path, relative_dir: Path = Path(".")):
artifacts = getattr(cfg.__fn_or_cls__, "__io_artifacts__", [])

for artifact in artifacts:
# Allow optional artifacts
if artifact.skip or (not hasattr(cfg, artifact.attr) and not artifact.required):
continue
Expand All @@ -647,16 +649,29 @@ def _artifact_transform_save(cfg: fdl.Config, output_path: Path, relative_dir: P
raise ValueError(f"Artifact '{artifact.attr}' is required but not provided")
continue
## dump artifact and return the relative path
new_val = artifact.dump(current_val, output_path, relative_dir)
new_val = artifact.dump(instance, current_val, output_path, relative_dir)
marcromeyn marked this conversation as resolved.
Show resolved Hide resolved
setattr(cfg, artifact.attr, new_val)

for attr in dir(cfg):
child = to_config(getattr(cfg, attr))

try:
if isinstance(getattr(cfg, attr), fdl.Config):
_artifact_transform_save(getattr(cfg, attr), output_path=output_path, relative_dir=relative_dir)
if isinstance(child, (fdl.Config, fdl.Partial)):
setattr(
cfg,
attr,
_artifact_transform_save(
getattr(instance, attr, None),
child,
output_path=output_path,
relative_dir=relative_dir,
),
)
except ValueError:
pass

return cfg


def _artifact_transform_load(cfg: fdl.Config, path: Path):
for artifact in getattr(cfg.__fn_or_cls__, "__io_artifacts__", []):
Expand Down
Loading
Loading