Skip to content

Commit

Permalink
Add stream_name and test
Browse files Browse the repository at this point in the history
  • Loading branch information
es94129 committed Jan 7, 2025
1 parent 8c7a176 commit f2ba3cd
Show file tree
Hide file tree
Showing 3 changed files with 68 additions and 11 deletions.
18 changes: 9 additions & 9 deletions docs/source/dataset_configuration/mixing_data_sources.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from matplotlib.artist import kwdoc

# Mixing Datasets

Training a model often requires combining data from multiple different sources. Streaming makes combining these data sources, or streams, easy and configurable. See the [main concepts page](../getting_started/main_concepts.md#distributed-model-training) for a high-level view of distributed training with multiple streams.
Expand All @@ -14,21 +16,19 @@ You can also customize the implementation of a `Stream`. To modify the behavior
<!--pytest.mark.skip-->
```python
from streaming.base.stream import streams_registry
from streaming.base.registry_utils import construct_from_registry

class MyStream(Stream):
# your implementation goes here
pass

# Register your custom stream class as 'stream'
streams_registry.register('stream', func=MyStream)
# Register your custom stream class as 'my_stream'
streams_registry.register('my_stream', func=MyStream)

# StreamingDataset creates a stream instance from the streams_registry
stream = construct_from_registry(
'stream',
streams_registry,
partial_function=False,
kwargs={'remote': remote, 'local': local}
# StreamingDataset creates a MyStream object when 'my_stream' is passed as a stream_name
dataset = StreamingDataset(
remote='s3://some/path',
local='/local/path',
stream_name='my_stream',
)
```

Expand Down
6 changes: 5 additions & 1 deletion streaming/base/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,6 +309,9 @@ class StreamingDataset(Array, IterableDataset):
replication (int, optional): Determines how many consecutive devices will receive the same
samples. Useful for training with tensor or sequence parallelism, where multiple
devices need to see the same partition of the dataset. Defaults to ``None``.
stream_name (str): Name under which the Stream class to instantiate is registered in
``streams_registry``. Defaults to ``'stream'``.
**kwargs (Any): Additional keyword arguments to pass to the Stream constructor.
"""

def __init__(self,
Expand Down Expand Up @@ -336,6 +339,7 @@ def __init__(self,
batching_method: str = 'random',
allow_unsafe_types: bool = False,
replication: Optional[int] = None,
stream_name: str = 'stream',
**kwargs: Any) -> None:
# Global arguments (which do not live in Streams).
self.predownload = predownload
Expand Down Expand Up @@ -453,7 +457,7 @@ def __init__(self,

# Construct a Stream instance using registry-based construction
default = construct_from_registry(
name='stream',
name=stream_name,
registry=streams_registry,
partial_function=False,
pre_validation_function=None,
Expand Down
55 changes: 54 additions & 1 deletion tests/test_streaming.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,14 @@
import os
import shutil
from multiprocessing import Process
from typing import Any
from typing import Any, Optional

import pytest
from torch.utils.data import DataLoader

from streaming.base import Stream, StreamingDataLoader, StreamingDataset
from streaming.base.batching import generate_work
from streaming.base.stream import streams_registry
from streaming.base.util import clean_stale_shared_memory
from streaming.base.world import World
from tests.common.utils import convert_to_mds
Expand Down Expand Up @@ -1053,3 +1054,55 @@ def test_same_local_diff_remote(local_remote_dir: tuple[str, str]):
# Build StreamingDataset
with pytest.raises(ValueError, match='Reused local directory.*vs.*Provide a different one.'):
_ = StreamingDataset(local=local_0, remote=remote_1, batch_size=2, num_canonical_nodes=1)


@pytest.mark.usefixtures('local_remote_dir')
def test_custom_stream_name_and_kwargs(local_remote_dir: tuple[str, str]):
    """Verify ``stream_name`` picks a registered Stream subclass and forwards extra kwargs."""
    remote_path, local_path = local_remote_dir

    # Write a small MDS dataset into the remote directory for the dataset to read.
    convert_to_mds(
        out_root=remote_path,
        dataset_name='sequencedataset',
        num_samples=117,
        size_limit=1 << 8,
    )

    class CustomStream(Stream):
        """Stream subclass that additionally stores a ``custom_arg`` keyword argument."""

        def __init__(
            self,
            *,
            remote: Optional[str] = None,
            local: Optional[str] = None,
            split: Optional[str] = None,
            proportion: Optional[float] = None,
            repeat: Optional[float] = None,
            choose: Optional[int] = None,
            download_retry: Optional[int] = None,
            download_timeout: Optional[float] = None,
            validate_hash: Optional[str] = None,
            keep_zip: Optional[bool] = None,
            **kwargs: Any,
        ):
            # Only the named arguments are handed to the base class; everything else
            # stays in ``kwargs`` for this subclass to consume.
            base_args = {
                'remote': remote,
                'local': local,
                'split': split,
                'proportion': proportion,
                'repeat': repeat,
                'choose': choose,
                'download_retry': download_retry,
                'download_timeout': download_timeout,
                'validate_hash': validate_hash,
                'keep_zip': keep_zip,
            }
            super().__init__(**base_args)

            # Required extra argument; a missing key raises KeyError.
            self.custom_arg = kwargs['custom_arg']

    # Make the subclass resolvable through the name given as ``stream_name`` below.
    streams_registry.register('custom_stream', func=CustomStream)

    dataset = StreamingDataset(
        local=local_path,
        remote=remote_path,
        stream_name='custom_stream',
        custom_arg=100,
    )

    built_streams = dataset.streams
    assert len(built_streams) == 1

    only_stream = built_streams[0]
    assert isinstance(only_stream, CustomStream)
    assert only_stream.custom_arg == 100

0 comments on commit f2ba3cd

Please sign in to comment.