diff --git a/setup.py b/setup.py index 0d7cbd134..dc2fa13ec 100644 --- a/setup.py +++ b/setup.py @@ -59,6 +59,7 @@ 'azure-storage-blob>=12.0.0,<13', 'azure-storage-file-datalake>=12.11.0,<13', 'azure-identity>=1.13.0', + 'catalogue>=2,<3', ] extra_deps = {} diff --git a/streaming/base/dataset.py b/streaming/base/dataset.py index cb5c32ba9..447ca5b93 100644 --- a/streaming/base/dataset.py +++ b/streaming/base/dataset.py @@ -34,7 +34,8 @@ from streaming.base.shared import (SharedArray, SharedBarrier, SharedMemory, SharedScalar, _get_path, get_shm_prefix) from streaming.base.spanner import Spanner -from streaming.base.stream import Stream +from streaming.base.registry_utils import construct_from_registry +from streaming.base.stream import Stream, streams_registry from streaming.base.util import bytes_to_int, number_abbrev_to_int from streaming.base.world import World @@ -334,7 +335,8 @@ def __init__(self, shuffle_block_size: Optional[int] = None, batching_method: str = 'random', allow_unsafe_types: bool = False, - replication: Optional[int] = None) -> None: + replication: Optional[int] = None, + **kwargs: Any) -> None: # Global arguments (which do not live in Streams). self.predownload = predownload self.cache_limit = cache_limit @@ -438,13 +440,26 @@ def __init__(self, for stream in streams: stream.apply_default(default) else: - default = Stream(remote=remote, - local=local, - split=split, - download_retry=download_retry, - download_timeout=download_timeout, - validate_hash=validate_hash, - keep_zip=keep_zip) + kwargs = { + 'remote': remote, + 'local': local, + 'split': split, + 'download_retry': download_retry, + 'download_timeout': download_timeout, + 'validate_hash': validate_hash, + 'keep_zip': keep_zip, + **kwargs, + } + + default = construct_from_registry( + name='stream', + registry=streams_registry, + partial_function=False, + pre_validation_function=None, + post_validation_function=None, + kwargs=kwargs, + ) + streams = [default] # Validate the stream weighting scheme (relative or absolute) to catch errors before we go diff --git a/streaming/base/registry_utils.py b/streaming/base/registry_utils.py new file mode 100644 index 000000000..fa6ca44a8 --- /dev/null +++ b/streaming/base/registry_utils.py @@ -0,0 +1,200 @@ +# Copyright 2024 MosaicML Streaming authors +# SPDX-License-Identifier: Apache-2.0 + +import copy +import functools +import importlib.util +import os +from contextlib import contextmanager +from pathlib import Path +from types import ModuleType +from typing import Any, Callable, Generic, Optional, Sequence, TypeVar, Union + +import catalogue + +__all__ = [ + 'TypedRegistry', + 'create_registry', + 'construct_from_registry', + 'import_file', + 'save_registry', +] + +T = TypeVar('T') +TypeBoundT = TypeVar('TypeBoundT', bound=type) +CallableBoundT = TypeVar('CallableBoundT', bound=Callable[..., Any]) + + +class TypedRegistry(catalogue.Registry, Generic[T]): + """A thin wrapper around catalogue.Registry to add static typing and. + + descriptions. + """ + + def __init__( + self, + namespace: Sequence[str], + entry_points: bool = False, + description: str = '', + ) -> None: + super().__init__(namespace, entry_points=entry_points) + + self.description = description + + def __call__(self, name: str, func: Optional[T] = None) -> Callable[[T], T]: + return super().__call__(name, func) + + def register(self, name: str, *, func: Optional[T] = None) -> T: + return super().register(name, func=func) + + def register_class( + self, + name: str, + *, + func: Optional[TypeBoundT] = None, + ) -> TypeBoundT: + return super().register(name, func=func) + + def get(self, name: str) -> T: + return super().get(name) + + def get_all(self) -> dict[str, T]: + return super().get_all() + + def get_entry_point(self, name: str, default: Optional[T] = None) -> T: + return super().get_entry_point(name, default=default) + + def get_entry_points(self) -> dict[str, T]: + return super().get_entry_points() + + +S = TypeVar('S') + + +def create_registry( + *namespace: str, + generic_type: type[S], + entry_points: bool = False, + description: str = '', +) -> 'TypedRegistry[S]': + """Create a new registry. + + Args: + namespace (str): The namespace, e.g. "llmfoundry.loggers" + generic_type (Type[S]): The type of the registry. + entry_points (bool): Accept registered functions from entry points. + description (str): A description of the registry. + + Returns: + The TypedRegistry object. + """ + if catalogue.check_exists(*namespace): + raise catalogue.RegistryError(f'Namespace already exists: {namespace}') + + return TypedRegistry[generic_type]( + namespace, + entry_points=entry_points, + description=description, + ) + + +def construct_from_registry( + name: str, + registry: TypedRegistry, + partial_function: bool = True, + pre_validation_function: Optional[Union[Callable[[Any], None], type]] = None, + post_validation_function: Optional[Callable[[Any], None]] = None, + kwargs: Optional[dict[str, Any]] = None, +) -> Any: + """Helper function to build an item from the registry. + + Args: + name (str): The name of the registered item + registry (catalogue.Registry): The registry to fetch the item from + partial_function (bool, optional): Whether to return a partial function for registered callables. Defaults to True. + pre_validation_function (Optional[Union[Callable[[Any], None], type]], optional): An optional validation function called + before constructing the item to return. This should throw an exception if validation fails. Defaults to None. + post_validation_function (Optional[Callable[[Any], None]], optional): An optional validation function called after + constructing the item to return. This should throw an exception if validation fails. Defaults to None. + kwargs (Optional[Dict[str, Any]]): Other relevant keyword arguments. + + Raises: + ValueError: If the validation functions failed or the registered item is invalid + + Returns: + Any: The constructed item from the registry + """ + if kwargs is None: + kwargs = {} + + registered_constructor = registry.get(name) + + if pre_validation_function is not None: + if isinstance(pre_validation_function, type): + if not issubclass(registered_constructor, pre_validation_function): + raise ValueError( + f'Expected {name} to be of type {pre_validation_function}, but got {type(registered_constructor)}', + ) + elif isinstance(pre_validation_function, Callable): + pre_validation_function(registered_constructor) + else: + raise ValueError( + f'Expected pre_validation_function to be a callable or a type, but got {type(pre_validation_function)}', + ) + + # If it is a class, or a builder function, construct the class with kwargs + # If it is a function, create a partial with kwargs + if isinstance( + registered_constructor, + type, + ) or callable(registered_constructor) and not partial_function: + constructed_item = registered_constructor(**kwargs) + elif callable(registered_constructor): + constructed_item = functools.partial(registered_constructor, **kwargs) + else: + raise ValueError( + f'Expected {name} to be a class or function, but got {type(registered_constructor)}',) + + if post_validation_function is not None: + post_validation_function(constructed_item) + + return constructed_item + + +def import_file(loc: Union[str, Path]) -> ModuleType: + """Import module from a file. + + Used to run arbitrary python code. + + Args: + name (str): Name of module to load. + loc (str / Path): Path to the file. + + Returns: + ModuleType: The module object. + """ + if not os.path.exists(loc): + raise FileNotFoundError(f'File {loc} does not exist.') + + spec = importlib.util.spec_from_file_location('python_code', str(loc)) + + assert spec is not None + assert spec.loader is not None + + module = importlib.util.module_from_spec(spec) + + try: + spec.loader.exec_module(module) + except Exception as e: + raise RuntimeError(f'Error executing {loc}') from e + return module + + +@contextmanager +def save_registry(): + """Save the registry state and restore after the context manager exits.""" + saved_registry_state = copy.deepcopy(catalogue.REGISTRY) + + yield + + catalogue.REGISTRY = saved_registry_state diff --git a/streaming/base/stream.py b/streaming/base/stream.py index 133938e12..b1e1dd05f 100644 --- a/streaming/base/stream.py +++ b/streaming/base/stream.py @@ -21,7 +21,7 @@ from streaming.base.storage import CloudDownloader from streaming.base.util import retry, wait_for_file_to_exist from streaming.base.world import World - +from streaming.base.registry_utils import create_registry class Stream: """A dataset, or sub-dataset if mixing, from which we stream/cache samples. @@ -507,3 +507,10 @@ def get_index_size(self) -> int: """ filename = os.path.join(self.local, self.split, get_index_basename()) return os.stat(filename).st_size + +streams_registry = create_registry('streaming', 'streams_registry', + generic_type=type[Stream], + entry_points=True, + description="The streams registry is used for registering Stream classes.") + +streams_registry.register('stream', func=Stream) \ No newline at end of file