diff --git a/nemo/lightning/io/artifact/__init__.py b/nemo/lightning/io/artifact/__init__.py index 50f77f968a07..a70d39a0873e 100644 --- a/nemo/lightning/io/artifact/__init__.py +++ b/nemo/lightning/io/artifact/__init__.py @@ -1,4 +1,5 @@ from nemo.lightning.io.artifact.base import Artifact from nemo.lightning.io.artifact.file import DirArtifact, DirOrStringArtifact, FileArtifact, PathArtifact +from nemo.lightning.io.artifact.hf_auto import HFAutoArtifact -__all__ = ["Artifact", "FileArtifact", "PathArtifact", "DirArtifact", "DirOrStringArtifact"] +__all__ = ["Artifact", "FileArtifact", "PathArtifact", "DirArtifact", "DirOrStringArtifact", "HFAutoArtifact"] diff --git a/nemo/lightning/io/artifact/base.py b/nemo/lightning/io/artifact/base.py index c7243a22af52..fdc4b05f8eac 100644 --- a/nemo/lightning/io/artifact/base.py +++ b/nemo/lightning/io/artifact/base.py @@ -26,7 +26,7 @@ def __init__(self, attr: str, required: bool = True, skip: bool = False): self.skip = skip @abstractmethod - def dump(self, value: ValueT, absolute_dir: Path, relative_dir: Path) -> ValueT: + def dump(self, instance, value: ValueT, absolute_dir: Path, relative_dir: Path) -> ValueT: pass @abstractmethod diff --git a/nemo/lightning/io/artifact/file.py b/nemo/lightning/io/artifact/file.py index 619effbea58f..efebe8694004 100644 --- a/nemo/lightning/io/artifact/file.py +++ b/nemo/lightning/io/artifact/file.py @@ -22,7 +22,7 @@ class PathArtifact(Artifact[Path]): - def dump(self, value: Path, absolute_dir: Path, relative_dir: Path) -> Path: + def dump(self, instance, value: Path, absolute_dir: Path, relative_dir: Path) -> Path: new_value = copy_file(value, absolute_dir, relative_dir) return new_value @@ -31,7 +31,7 @@ def load(self, path: Path) -> Path: class FileArtifact(Artifact[str]): - def dump(self, value: str, absolute_dir: Path, relative_dir: Path) -> str: + def dump(self, instance, value: str, absolute_dir: Path, relative_dir: Path) -> str: if not pathize(value).exists(): # This is Artifact is just a string. return fdl.Config(FileArtifact, attr=value, skip=True) diff --git a/nemo/lightning/io/artifact/hf_auto.py b/nemo/lightning/io/artifact/hf_auto.py new file mode 100644 index 000000000000..6e3683aedceb --- /dev/null +++ b/nemo/lightning/io/artifact/hf_auto.py @@ -0,0 +1,152 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +"""HuggingFace model serialization support for NeMo's configuration system. + +This module provides integration between NeMo's configuration system and HuggingFace's +pretrained models. It enables automatic serialization and deserialization of HuggingFace +models within NeMo's configuration framework. + +The integration works by: +1. Detecting HuggingFace models through their characteristic methods (save_pretrained/from_pretrained) +2. Converting them to Fiddle configurations that preserve the model's class and path +3. 
Providing an artifact handler (HFAutoArtifact) that manages the actual model files + +Example: + ```python + from transformers import AutoModel + + # This model will be automatically handled by the HFAutoArtifact system + model = AutoModel.from_pretrained("bert-base-uncased") + + # When serialized, the model files will be saved to the artifacts directory + # When deserialized, the model will be loaded from the saved files + ``` +""" + +import contextlib +import inspect +import threading +from pathlib import Path + +import fiddle as fdl + +from nemo.lightning.io.artifact import Artifact +from nemo.lightning.io.to_config import to_config + +_local = threading.local() + + +class HFAutoArtifact(Artifact): + """Artifact handler for HuggingFace pretrained model/processor/tokenizer/etc.. + + This handler manages the serialization and deserialization of HuggingFace models + by utilizing their save_pretrained/from_pretrained methods. It saves models to + an 'artifacts' subdirectory within the specified path. + """ + + def dump(self, instance, value: Path, absolute_dir: Path, relative_dir: Path) -> Path: + """Save a HuggingFace model to disk. + + Args: + instance: The HuggingFace model instance to save + value: Original path value (unused) + absolute_dir: Absolute path to the save directory + relative_dir: Relative path from the config file to the save directory + + Returns: + str: The relative path to the saved model artifacts + """ + instance.save_pretrained(Path(absolute_dir) / "artifacts") + return "./" + str(Path(relative_dir) / "artifacts") + + def load(self, path: Path) -> Path: + """Return the path to load a HuggingFace model. + + Args: + path: Path to the saved model artifacts + + Returns: + Path: The same path, to be used with from_pretrained + """ + return path + + +@contextlib.contextmanager +def from_pretrained_kwargs(**kwargs): + """Context manager for passing additional kwargs to from_pretrained. + + Args: + **kwargs: Keyword arguments to pass to from_pretrained + + Example: + with from_pretrained_kwargs(trust_remote_code=True): + io.load_context("path/to/checkpoint") + """ + if not hasattr(_local, "kwargs"): + _local.kwargs = {} + previous = _local.kwargs.copy() + _local.kwargs.update(kwargs) + try: + yield + finally: + _local.kwargs = previous + + +def from_pretrained(auto_cls, pretrained_model_name_or_path="dummy"): + """Factory function for loading HuggingFace pretrained models. + + This function is used as the serialization target for HuggingFace models. + When deserialized, it will recreate the model using its from_pretrained method. + + Args: + auto_cls: The HuggingFace model class (e.g., AutoModel, AutoTokenizer) + pretrained_model_name_or_path: Path to the saved model or model identifier + + Returns: + The loaded HuggingFace model + """ + kwargs = getattr(_local, "kwargs", {}) + return auto_cls.from_pretrained(pretrained_model_name_or_path, **kwargs) + + +@to_config.register( + lambda v: not inspect.isclass(v) + and getattr(v, "__module__", "").startswith("transformers") + and hasattr(v, "save_pretrained") + and hasattr(v, "from_pretrained") +) +def handle_hf_pretrained(value): + """Convert a HuggingFace model instance to a Fiddle configuration. + + This handler detects HuggingFace model instances by checking for the presence + of save_pretrained and from_pretrained methods. It converts them to a Fiddle + configuration that will recreate the model using from_pretrained. 
+ + Args: + value: A HuggingFace model instance + + Returns: + fdl.Config: A Fiddle configuration that will recreate the model + """ + return fdl.Config( + from_pretrained, + auto_cls=value.__class__, + pretrained_model_name_or_path="dummy", + ) + + +# Register the HFAutoArtifact handler for the pretrained_model_name_or_path parameter +from_pretrained.__io_artifacts__ = [HFAutoArtifact("pretrained_model_name_or_path")] diff --git a/nemo/lightning/io/artifact/pickle.py b/nemo/lightning/io/artifact/pickle.py index 941d69e777a1..b2bf2b4105a4 100644 --- a/nemo/lightning/io/artifact/pickle.py +++ b/nemo/lightning/io/artifact/pickle.py @@ -21,7 +21,7 @@ class PickleArtifact(Artifact[Any]): - def dump(self, absolute_dir: Path, relative_dir: Path) -> Path: + def dump(self, instance, value: Any, absolute_dir: Path, relative_dir: Path) -> Path: relative_file = self.file_path(relative_dir) with open(Path(absolute_dir) / relative_file, "wb") as f: dump(value, f) diff --git a/nemo/lightning/io/fdl_torch.py b/nemo/lightning/io/fdl_torch.py index a619e4d4d160..72c107674243 100644 --- a/nemo/lightning/io/fdl_torch.py +++ b/nemo/lightning/io/fdl_torch.py @@ -19,9 +19,7 @@ """ import types -from functools import partial -import fiddle as fdl import libcst as cst import torch import torch.nn as nn @@ -29,8 +27,12 @@ from fiddle._src.codegen import import_manager, py_val_to_cst_converter, special_value_codegen from fiddle._src.experimental import serialization +from nemo.lightning.io.artifact import * # noqa: F403 +from nemo.lightning.io.to_config import to_config + def _make_torch_importable(name: str) -> special_value_codegen.Importable: + """Make a torch importable.""" return special_value_codegen.SingleImportable("torch", lambda torch_name: f"{torch_name}.{name}") @@ -67,6 +69,7 @@ def _make_torch_importable(name: str) -> special_value_codegen.Importable: def _make_torch_nn_importable(name: str) -> special_value_codegen.Importable: + """Make a torch.nn importable.""" return special_value_codegen.SingleImportable("torch", lambda torch_mod_name: f"{torch_mod_name}.nn.{name}") @@ -88,6 +91,7 @@ def is_torch_tensor(value): def convert_torch_tensor_to_cst(value, convert_child): + """Convert a PyTorch tensor to a CST node.""" return cst.Call( func=cst.Attribute(value=convert_child(torch), attr=cst.Name("tensor")), args=[ @@ -124,11 +128,10 @@ def enable(): # Monkey-patch the Serialization class to handle things like activation-functions def _modified_serialize(self, value, current_path, all_paths=None): + """Serialize a value to a Fiddle configuration.""" if isinstance(value, types.BuiltinFunctionType): return self._pyref(value, current_path) - if isinstance(value, partial): - value = fdl.Partial(value.func, *value.args, **value.keywords) - return self._original_serialize(value, current_path, all_paths) + return self._original_serialize(to_config(value), current_path, all_paths) serialization.Serialization._original_serialize = serialization.Serialization._serialize serialization.Serialization._serialize = _modified_serialize diff --git a/nemo/lightning/io/mixin.py b/nemo/lightning/io/mixin.py index 283cea6943b5..08768f54448c 100644 --- a/nemo/lightning/io/mixin.py +++ b/nemo/lightning/io/mixin.py @@ -38,6 +38,7 @@ from nemo.lightning.io.capture import IOProtocol from nemo.lightning.io.connector import ModelConnector from nemo.lightning.io.fdl_torch import enable as _enable_ext +from nemo.lightning.io.to_config import to_config from nemo.utils import logging ConnT = TypeVar("ConnT", bound=ModelConnector) @@ -233,8 +234,7
@@ def io_dump(self, output: Path, yaml_attrs: list[str]): config_path = output_path / "io.json" with open(config_path, "w") as f: - io = deepcopy(self.__io__) - _artifact_transform_save(io, output_path, local_artifacts_dir) + io = _artifact_transform_save(self, deepcopy(self.__io__), output_path, local_artifacts_dir) json = serialization.dump_json(io) f.write(json) @@ -632,8 +632,10 @@ def _io_path_elements_fn(x): return x.__io__.__path_elements__() -def _artifact_transform_save(cfg: fdl.Config, output_path: Path, relative_dir: Path = "."): - for artifact in getattr(cfg.__fn_or_cls__, "__io_artifacts__", []): +def _artifact_transform_save(instance, cfg: fdl.Config, output_path: Path, relative_dir: Path = Path(".")): + artifacts = getattr(cfg.__fn_or_cls__, "__io_artifacts__", []) + + for artifact in artifacts: # Allow optional artifacts if artifact.skip or (not hasattr(cfg, artifact.attr) and not artifact.required): continue @@ -647,16 +649,29 @@ def _artifact_transform_save(cfg: fdl.Config, output_path: Path, relative_dir: P raise ValueError(f"Artifact '{artifact.attr}' is required but not provided") continue ## dump artifact and return the relative path - new_val = artifact.dump(current_val, output_path, relative_dir) + new_val = artifact.dump(instance, current_val, output_path, relative_dir) setattr(cfg, artifact.attr, new_val) for attr in dir(cfg): + child = to_config(getattr(cfg, attr)) + try: - if isinstance(getattr(cfg, attr), fdl.Config): - _artifact_transform_save(getattr(cfg, attr), output_path=output_path, relative_dir=relative_dir) + if isinstance(child, (fdl.Config, fdl.Partial)): + setattr( + cfg, + attr, + _artifact_transform_save( + getattr(instance, attr, None), + child, + output_path=output_path, + relative_dir=relative_dir, + ), + ) except ValueError: pass + return cfg + def _artifact_transform_load(cfg: fdl.Config, path: Path): for artifact in getattr(cfg.__fn_or_cls__, "__io_artifacts__", []): diff --git a/nemo/lightning/io/to_config.py b/nemo/lightning/io/to_config.py new file mode 100644 index 000000000000..539d12af33ed --- /dev/null +++ b/nemo/lightning/io/to_config.py @@ -0,0 +1,123 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from functools import partial +from typing import Any, Callable, TypeVar + +import fiddle as fdl + +T = TypeVar("T") + + +class PredicateDispatch: + """A dispatcher that routes values to handlers based on predicates. + + This class implements a predicate-based dispatch system where handlers are registered + with conditions that determine when they should be used. When called, it tries each + predicate in order until it finds a match. 
+ + Example: + ```python + dispatcher = PredicateDispatch() + + @dispatcher.register(lambda x: isinstance(x, str)) + def handle_strings(s): + return f"String: {s}" + + result = dispatcher("hello") # Returns "String: hello" + ``` + """ + + def __init__(self): + """Initialize an empty handler registry.""" + self.handlers = [] + + def register(self, predicate: Callable[[Any], bool]): + """Register a new handler with a predicate. + + Args: + predicate: A callable that takes a value and returns True if the handler + should process this value. + + Returns: + A decorator function that registers the handler. + """ + + def decorator(func: Callable[[T], Any]): + self.handlers.append((predicate, func)) + return func + + return decorator + + def __call__(self, value): + """Process a value through the registered handlers. + + Args: + value: The value to be processed. + + Returns: + The processed value from the first matching handler, or the original value + if no handlers match. + """ + for predicate, handler in self.handlers: + if predicate(value): + return handler(value) + return value # default case: return unchanged + + def register_class(self, cls: type): + """Register a handler for instances of a specific class. + + A convenience method that automatically creates an isinstance predicate. + + Args: + cls: The class to check instances against. + + Returns: + A decorator function that registers the handler. + """ + return self.register(lambda v: isinstance(v, cls)) + + +"""Global dispatcher for converting Python objects to Fiddle configurations. + +This dispatcher is used by Fiddle's serialization system to handle special cases +during configuration serialization. When Fiddle encounters an object it doesn't +know how to serialize, it will pass it through this dispatcher to convert it +into a serializable Fiddle configuration. + +Example use cases: + - Converting functools.partial to fdl.Partial + - Converting HuggingFace models to their from_pretrained configurations + - Handling custom classes with special serialization needs + +The dispatcher is extended by registering new handlers with predicates that +determine when they should be used. See PredicateDispatch for more details. +""" +to_config = PredicateDispatch() + + +@to_config.register_class(partial) +def handle_partial(value: partial): + """Convert functools.partial objects to Fiddle Partial configurations. + + This handler enables serialization of partial function applications by converting + them to Fiddle's equivalent representation. + + Args: + value: A functools.partial object. + + Returns: + A Fiddle Partial configuration representing the same partial application. + """ + return fdl.Partial(value.func, *value.args, **value.keywords)
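A minimal usage sketch of how the pieces in this change compose, assuming only what the new docstrings already show; the `scale` helper is a made-up stand-in, and `bert-base-uncased` and `path/to/checkpoint` are the illustrative values taken from those docstrings:

```python
# Sketch only: exercises the new to_config dispatcher and the HF from_pretrained plumbing.
from functools import partial

import fiddle as fdl
from transformers import AutoModel

# Side-effect import: registers handle_hf_pretrained with the global to_config dispatcher.
from nemo.lightning.io.artifact.hf_auto import from_pretrained_kwargs  # noqa: F401
from nemo.lightning.io.to_config import to_config


def scale(x, factor=1.0):
    """Throwaway helper used to demonstrate the functools.partial handler."""
    return x * factor


# functools.partial objects are rewritten by handle_partial into fdl.Partial configs.
assert isinstance(to_config(partial(scale, factor=2.0)), fdl.Partial)

# Any transformers object exposing save_pretrained/from_pretrained is rewritten by
# handle_hf_pretrained into a from_pretrained config; during io_dump, HFAutoArtifact
# then calls save_pretrained into the checkpoint's artifacts/ directory.
model = AutoModel.from_pretrained("bert-base-uncased")  # downloads weights on first use
assert isinstance(to_config(model), fdl.Config)

# Extra kwargs (e.g. trust_remote_code) can be threaded into from_pretrained at load time:
# with from_pretrained_kwargs(trust_remote_code=True):
#     io.load_context("path/to/checkpoint")
```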