Skip to content

Commit

Permalink
Use stream_config instead of kwargs
Browse files Browse the repository at this point in the history
  • Loading branch information
es94129 committed Jan 7, 2025
1 parent 607bd9c commit 52de4dc
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 10 deletions.
1 change: 1 addition & 0 deletions docs/source/dataset_configuration/mixing_data_sources.md
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ dataset = StreamingDataset(
remote='s3://some/path',
local='/local/path',
stream_name='my_stream',
stream_config={'arg1': 'val1'},
)
```

Expand Down
12 changes: 6 additions & 6 deletions streaming/base/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,7 +311,7 @@ class StreamingDataset(Array, IterableDataset):
devices need to see the same partition of the dataset. Defaults to ``None``.
stream_name (str): The name of the Stream to use which is registered in streams_registry.
Defaults to ``stream``.
kwargs (any): Additional arguments to pass to the Stream constructor.
stream_config (dict[str, any]): Additional arguments to pass to the Stream constructor.
"""

def __init__(self,
Expand Down Expand Up @@ -340,7 +340,7 @@ def __init__(self,
allow_unsafe_types: bool = False,
replication: Optional[int] = None,
stream_name: str = 'stream',
**kwargs: Any) -> None:
stream_config: Optional[dict[str, Any]] = None) -> None:
# Global arguments (which do not live in Streams).
self.predownload = predownload
self.cache_limit = cache_limit
Expand Down Expand Up @@ -444,16 +444,16 @@ def __init__(self,
for stream in streams:
stream.apply_default(default)
else:
kwargs = {
stream_config = stream_config or {}
stream_config.update({
'remote': remote,
'local': local,
'split': split,
'download_retry': download_retry,
'download_timeout': download_timeout,
'validate_hash': validate_hash,
'keep_zip': keep_zip,
**kwargs,
}
})

# Construct a Stream instance using registry-based construction
default = construct_from_registry(
Expand All @@ -462,7 +462,7 @@ def __init__(self,
partial_function=False,
pre_validation_function=None,
post_validation_function=None,
kwargs=kwargs,
kwargs=stream_config,
)

streams = [default]
Expand Down
12 changes: 8 additions & 4 deletions tests/test_streaming.py
Original file line number Diff line number Diff line change
Expand Up @@ -1098,10 +1098,14 @@ def __init__(

streams_registry.register('custom_stream', func=CustomStream)

dataset = StreamingDataset(local=local_dir,
remote=remote_dir,
stream_name='custom_stream',
custom_arg=100)
dataset = StreamingDataset(
local=local_dir,
remote=remote_dir,
stream_name='custom_stream',
stream_config={
'custom_arg': 100,
},
)

assert len(dataset.streams) == 1
assert isinstance(dataset.streams[0], CustomStream)
Expand Down

0 comments on commit 52de4dc

Please sign in to comment.