From f2ba3cd110df6529b72ba6441f12e441739c060c Mon Sep 17 00:00:00 2001 From: Ying Chen Date: Mon, 6 Jan 2025 16:33:54 -0800 Subject: [PATCH] Add stream_name and test --- .../mixing_data_sources.md | 18 +++--- streaming/base/dataset.py | 6 +- tests/test_streaming.py | 55 ++++++++++++++++++- 3 files changed, 68 insertions(+), 11 deletions(-) diff --git a/docs/source/dataset_configuration/mixing_data_sources.md b/docs/source/dataset_configuration/mixing_data_sources.md index 64ca9af20..8dd4043df 100644 --- a/docs/source/dataset_configuration/mixing_data_sources.md +++ b/docs/source/dataset_configuration/mixing_data_sources.md @@ -1,3 +1,5 @@ +from matplotlib.artist import kwdoc + # Mixing Datasets Training a model often requires combining data from multiple different sources. Streaming makes combining these data sources, or streams, easy and configurable. See the [main concepts page](../getting_started/main_concepts.md#distributed-model-training) for a high-level view of distributed training with multiple streams. @@ -14,21 +16,19 @@ You can also customize the implementation of a `Stream`. To modify the behavior ```python from streaming.base.stream import streams_registry -from streaming.base.registry_utils import construct_from_registry class MyStream(Stream): # your implementation goes here pass -# Register your custom stream class as 'stream' -streams_registry.register('stream', func=MyStream) +# Register your custom stream class as 'my_stream' +streams_registry.register('my_stream', func=MyStream) -# StreamingDataset creates a stream instance from the streams_registry -stream = construct_from_registry( - 'stream', - streams_registry, - partial_function=False, - kwargs={'remote': remote, 'local': local} +# StreamingDataset creates a MyStream object when 'my_stream' is passed as a stream_name +dataset = StreamingDataset( + remote='s3://some/path', + local='/local/path', + stream_name='my_stream', ) ``` diff --git a/streaming/base/dataset.py b/streaming/base/dataset.py index 475e67d74..247a0397f 100644 --- a/streaming/base/dataset.py +++ b/streaming/base/dataset.py @@ -309,6 +309,9 @@ class StreamingDataset(Array, IterableDataset): replication (int, optional): Determines how many consecutive devices will receive the same samples. Useful for training with tensor or sequence parallelism, where multiple devices need to see the same partition of the dataset. Defaults to ``None``. + stream_name (str): The name of the Stream to use which is registered in streams_registry. + Defaults to ``stream``. + kwargs (any): Additional arguments to pass to the Stream constructor. """ def __init__(self, @@ -336,6 +339,7 @@ def __init__(self, batching_method: str = 'random', allow_unsafe_types: bool = False, replication: Optional[int] = None, + stream_name: str = 'stream', **kwargs: Any) -> None: # Global arguments (which do not live in Streams). self.predownload = predownload @@ -453,7 +457,7 @@ def __init__(self, # Construct a Stream instance using registry-based construction default = construct_from_registry( - name='stream', + name=stream_name, registry=streams_registry, partial_function=False, pre_validation_function=None, diff --git a/tests/test_streaming.py b/tests/test_streaming.py index cd113c6e8..7e3ddc8af 100644 --- a/tests/test_streaming.py +++ b/tests/test_streaming.py @@ -5,13 +5,14 @@ import os import shutil from multiprocessing import Process -from typing import Any +from typing import Any, Optional import pytest from torch.utils.data import DataLoader from streaming.base import Stream, StreamingDataLoader, StreamingDataset from streaming.base.batching import generate_work +from streaming.base.stream import streams_registry from streaming.base.util import clean_stale_shared_memory from streaming.base.world import World from tests.common.utils import convert_to_mds @@ -1053,3 +1054,55 @@ def test_same_local_diff_remote(local_remote_dir: tuple[str, str]): # Build StreamingDataset with pytest.raises(ValueError, match='Reused local directory.*vs.*Provide a different one.'): _ = StreamingDataset(local=local_0, remote=remote_1, batch_size=2, num_canonical_nodes=1) + + +@pytest.mark.usefixtures('local_remote_dir') +def test_custom_stream_name_and_kwargs(local_remote_dir: tuple[str, str]): + remote_dir, local_dir = local_remote_dir + convert_to_mds(out_root=remote_dir, + dataset_name='sequencedataset', + num_samples=117, + size_limit=1 << 8) + + class CustomStream(Stream): + + def __init__( + self, + *, + remote: Optional[str] = None, + local: Optional[str] = None, + split: Optional[str] = None, + proportion: Optional[float] = None, + repeat: Optional[float] = None, + choose: Optional[int] = None, + download_retry: Optional[int] = None, + download_timeout: Optional[float] = None, + validate_hash: Optional[str] = None, + keep_zip: Optional[bool] = None, + **kwargs: Any, + ): + super().__init__( + remote=remote, + local=local, + split=split, + proportion=proportion, + repeat=repeat, + choose=choose, + download_retry=download_retry, + download_timeout=download_timeout, + validate_hash=validate_hash, + keep_zip=keep_zip, + ) + + self.custom_arg = kwargs['custom_arg'] + + streams_registry.register('custom_stream', func=CustomStream) + + dataset = StreamingDataset(local=local_dir, + remote=remote_dir, + stream_name='custom_stream', + custom_arg=100) + + assert len(dataset.streams) == 1 + assert isinstance(dataset.streams[0], CustomStream) + assert dataset.streams[0].custom_arg == 100