-
Notifications
You must be signed in to change notification settings - Fork 148
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Use registry when creating Stream in StreamingDataset
- Loading branch information
Showing
4 changed files
with
233 additions
and
10 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,200 @@ | ||
# Copyright 2024 MosaicML Streaming authors | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
import copy | ||
import functools | ||
import importlib.util | ||
import os | ||
from contextlib import contextmanager | ||
from pathlib import Path | ||
from types import ModuleType | ||
from typing import Any, Callable, Generic, Optional, Sequence, TypeVar, Union | ||
|
||
import catalogue | ||
|
||
__all__ = [ | ||
'TypedRegistry', | ||
'create_registry', | ||
'construct_from_registry', | ||
'import_file', | ||
'save_registry', | ||
] | ||
|
||
T = TypeVar('T') | ||
TypeBoundT = TypeVar('TypeBoundT', bound=type) | ||
CallableBoundT = TypeVar('CallableBoundT', bound=Callable[..., Any]) | ||
|
||
|
||
class TypedRegistry(catalogue.Registry, Generic[T]): | ||
"""A thin wrapper around catalogue.Registry to add static typing and. | ||
descriptions. | ||
""" | ||
|
||
def __init__( | ||
self, | ||
namespace: Sequence[str], | ||
entry_points: bool = False, | ||
description: str = '', | ||
) -> None: | ||
super().__init__(namespace, entry_points=entry_points) | ||
|
||
self.description = description | ||
|
||
def __call__(self, name: str, func: Optional[T] = None) -> Callable[[T], T]: | ||
return super().__call__(name, func) | ||
|
||
def register(self, name: str, *, func: Optional[T] = None) -> T: | ||
return super().register(name, func=func) | ||
|
||
def register_class( | ||
self, | ||
name: str, | ||
*, | ||
func: Optional[TypeBoundT] = None, | ||
) -> TypeBoundT: | ||
return super().register(name, func=func) | ||
|
||
def get(self, name: str) -> T: | ||
return super().get(name) | ||
|
||
def get_all(self) -> dict[str, T]: | ||
return super().get_all() | ||
|
||
def get_entry_point(self, name: str, default: Optional[T] = None) -> T: | ||
return super().get_entry_point(name, default=default) | ||
|
||
def get_entry_points(self) -> dict[str, T]: | ||
return super().get_entry_points() | ||
|
||
|
||
S = TypeVar('S') | ||
|
||
|
||
def create_registry( | ||
*namespace: str, | ||
generic_type: type[S], | ||
entry_points: bool = False, | ||
description: str = '', | ||
) -> 'TypedRegistry[S]': | ||
"""Create a new registry. | ||
Args: | ||
namespace (str): The namespace, e.g. "llmfoundry.loggers" | ||
generic_type (Type[S]): The type of the registry. | ||
entry_points (bool): Accept registered functions from entry points. | ||
description (str): A description of the registry. | ||
Returns: | ||
The TypedRegistry object. | ||
""" | ||
if catalogue.check_exists(*namespace): | ||
raise catalogue.RegistryError(f'Namespace already exists: {namespace}') | ||
|
||
return TypedRegistry[generic_type]( | ||
namespace, | ||
entry_points=entry_points, | ||
description=description, | ||
) | ||
|
||
|
||
def construct_from_registry( | ||
name: str, | ||
registry: TypedRegistry, | ||
partial_function: bool = True, | ||
pre_validation_function: Optional[Union[Callable[[Any], None], type]] = None, | ||
post_validation_function: Optional[Callable[[Any], None]] = None, | ||
kwargs: Optional[dict[str, Any]] = None, | ||
) -> Any: | ||
"""Helper function to build an item from the registry. | ||
Args: | ||
name (str): The name of the registered item | ||
registry (catalogue.Registry): The registry to fetch the item from | ||
partial_function (bool, optional): Whether to return a partial function for registered callables. Defaults to True. | ||
pre_validation_function (Optional[Union[Callable[[Any], None], type]], optional): An optional validation function called | ||
before constructing the item to return. This should throw an exception if validation fails. Defaults to None. | ||
post_validation_function (Optional[Callable[[Any], None]], optional): An optional validation function called after | ||
constructing the item to return. This should throw an exception if validation fails. Defaults to None. | ||
kwargs (Optional[Dict[str, Any]]): Other relevant keyword arguments. | ||
Raises: | ||
ValueError: If the validation functions failed or the registered item is invalid | ||
Returns: | ||
Any: The constructed item from the registry | ||
""" | ||
if kwargs is None: | ||
kwargs = {} | ||
|
||
registered_constructor = registry.get(name) | ||
|
||
if pre_validation_function is not None: | ||
if isinstance(pre_validation_function, type): | ||
if not issubclass(registered_constructor, pre_validation_function): | ||
raise ValueError( | ||
f'Expected {name} to be of type {pre_validation_function}, but got {type(registered_constructor)}', | ||
) | ||
elif isinstance(pre_validation_function, Callable): | ||
pre_validation_function(registered_constructor) | ||
else: | ||
raise ValueError( | ||
f'Expected pre_validation_function to be a callable or a type, but got {type(pre_validation_function)}', | ||
) | ||
|
||
# If it is a class, or a builder function, construct the class with kwargs | ||
# If it is a function, create a partial with kwargs | ||
if isinstance( | ||
registered_constructor, | ||
type, | ||
) or callable(registered_constructor) and not partial_function: | ||
constructed_item = registered_constructor(**kwargs) | ||
elif callable(registered_constructor): | ||
constructed_item = functools.partial(registered_constructor, **kwargs) | ||
else: | ||
raise ValueError( | ||
f'Expected {name} to be a class or function, but got {type(registered_constructor)}',) | ||
|
||
if post_validation_function is not None: | ||
post_validation_function(constructed_item) | ||
|
||
return constructed_item | ||
|
||
|
||
def import_file(loc: Union[str, Path]) -> ModuleType: | ||
"""Import module from a file. | ||
Used to run arbitrary python code. | ||
Args: | ||
name (str): Name of module to load. | ||
loc (str / Path): Path to the file. | ||
Returns: | ||
ModuleType: The module object. | ||
""" | ||
if not os.path.exists(loc): | ||
raise FileNotFoundError(f'File {loc} does not exist.') | ||
|
||
spec = importlib.util.spec_from_file_location('python_code', str(loc)) | ||
|
||
assert spec is not None | ||
assert spec.loader is not None | ||
|
||
module = importlib.util.module_from_spec(spec) | ||
|
||
try: | ||
spec.loader.exec_module(module) | ||
except Exception as e: | ||
raise RuntimeError(f'Error executing {loc}') from e | ||
return module | ||
|
||
|
||
@contextmanager | ||
def save_registry(): | ||
"""Save the registry state and restore after the context manager exits.""" | ||
saved_registry_state = copy.deepcopy(catalogue.REGISTRY) | ||
|
||
yield | ||
|
||
catalogue.REGISTRY = saved_registry_state |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters