Skip to content

Commit

Permalink
Fix doc strings (#469)
Browse files Browse the repository at this point in the history
* Fix doc strings

* Add util to doc

* Update warning message

* Add ordering for local_list_dir
  • Loading branch information
XiaohanZhangCMU authored Oct 13, 2023
1 parent baad3d9 commit 36ba4ce
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 21 deletions.
1 change: 1 addition & 0 deletions docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -371,6 +371,7 @@ def _modules_to_rst() -> List[types.ModuleType]:
streaming.base.shared,
streaming.base.shuffle,
streaming.base.storage,
streaming.base.util,
streaming.base.world,
]
exclude_modules: List[types.Module] = [streaming.base, streaming._version]
Expand Down
2 changes: 1 addition & 1 deletion streaming/base/storage/upload.py
Original file line number Diff line number Diff line change
Expand Up @@ -1000,7 +1000,7 @@ def list_objects(self, prefix: Optional[str] = None) -> List[str]:
if prefix is None:
prefix = ''
file_paths = []
for dirpath, _, files in os.walk(os.path.join(self.local, prefix)):
for dirpath, _, files in sorted(os.walk(os.path.join(self.local, prefix))):
for file in files:
file_paths.append(os.path.join(dirpath, file))
return file_paths
74 changes: 54 additions & 20 deletions streaming/base/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@

__all__ = [
'get_list_arg', 'wait_for_file_to_exist', 'bytes_to_int', 'number_abbrev_to_int',
'clean_stale_shared_memory', 'get_import_exception_message', 'retry'
'clean_stale_shared_memory', 'get_import_exception_message', 'merge_index', 'retry'
]


Expand Down Expand Up @@ -216,11 +216,38 @@ def get_import_exception_message(package_name: str, extra_deps: str) -> str:


def merge_index(*args: Any, **kwargs: Any):
"""Redirect to one of two merge_index functions based on arguments."""
r"""Merge index.json from partitions to form a global index.json.
This can be called as
merge_index(index_file_urls, out, keep_local, download_timeout)
merge_index(out, keep_local, download_timeout)
The first signature takes in a list of index files URLs of MDS partitions.
The second takes the root of a MDS dataset and parse the partition folders from there.
Args:
index_file_urls (List[Union[str, Tuple[str,str]]]): index.json from all the partitions.
Each element can take the form of a single path string or a tuple string.
1. If ``index_file_urls`` is a List of local URLs, merge locally without download.
2. If ``index_file_urls`` is a List of tuple (local, remote) URLs, check if local index.json are missing, download before merging.
3. If ``index_file_urls`` is a List of remote URLs, download all and merge.
out (Union[str, Tuple[str,str]]): folder that contain MDS partitions and to put the merged index file
1. A local directory, merge index happens locally.
2. A remote directory, download all the sub-directories index.json, merge locally and upload.
3. A tuple (local_dir, remote_dir), check if local index.json exist, download if not.
keep_local (bool): Keep local copy of the merged index file. Defaults to ``True``.
download_timeout (int): The allowed time for downloading each json file. Defaults to 60.
"""
if isinstance(args[0], list) and len(args) + len(kwargs) in [2, 3, 4]:
return _merge_index_from_list(*args, **kwargs)
elif (isinstance(args[0], str) or
isinstance(args[0], tuple)) and len(args) + len(kwargs) in [1, 2]:
isinstance(args[0], tuple)) and len(args) + len(kwargs) in [1, 2, 3]:
return _merge_index_from_root(*args, **kwargs)
raise ValueError(f'Invalid arguments to merge_index: {args}, {kwargs}')

Expand All @@ -243,13 +270,14 @@ def _merge_index_from_list(index_file_urls: List[Union[str, Tuple[str, str]]],
out (Union[str, Tuple[str, str]]): path to put the merged index file
keep_local (bool): Keep local copy of the merged index file. Defaults to ``True``
download_timeout (int): The allowed time for downloading each json file. Defaults to 60s.
download_timeout (int): The allowed time for downloading each json file. Defaults to 60.
"""
from streaming.base.storage.download import download_file
from streaming.base.storage.upload import CloudUploader

if not index_file_urls or not out:
logger.warning('Need to specify both `index_file_urls` and `out`. No index merged')
logger.warning('Either index_file_urls or out are None. ' +
'Need to specify both `index_file_urls` and `out`. ' + 'No index merged')
return

# This is the index json file name, e.g., it is index.json as of 0.6.0
Expand Down Expand Up @@ -330,7 +358,9 @@ def _merge_index_from_list(index_file_urls: List[Union[str, Tuple[str, str]]],
shutil.rmtree(cu.local, ignore_errors=True)


def _merge_index_from_root(out: Union[str, Tuple[str, str]], keep_local: bool = True) -> None:
def _merge_index_from_root(out: Union[str, Tuple[str, str]],
keep_local: bool = True,
download_timeout: int = 60) -> None:
"""Merge index.json given the root of MDS dataset. Write merged index to the root folder.
Args:
Expand All @@ -343,6 +373,7 @@ def _merge_index_from_root(out: Union[str, Tuple[str, str]], keep_local: bool =
If not, download all the sub-directories index.json from remote to local,
merge locally, and upload to remote_dir .
keep_local (bool): Keep local copy of the merged index file. Defaults to ``True``
download_timeout (int): The allowed time for downloading each json file. Defaults to 60.
"""
from streaming.base.storage.upload import CloudUploader

Expand Down Expand Up @@ -386,15 +417,18 @@ def not_merged_index(index_file_path: str, out: str):
_merge_index_from_list(list(zip(local_index_files, remote_index_files)),
out,
keep_local=keep_local,
download_timeout=60)
download_timeout=download_timeout)
else:
_merge_index_from_list(remote_index_files,
out,
keep_local=keep_local,
download_timeout=60)
download_timeout=download_timeout)
return

_merge_index_from_list(local_index_files, out, keep_local=keep_local, download_timeout=60)
_merge_index_from_list(local_index_files,
out,
keep_local=keep_local,
download_timeout=download_timeout)


@overload
Expand Down Expand Up @@ -427,21 +461,21 @@ def retry( # type: ignore
``initial_backoff * 2**num_attempts + random.random() * max_jitter`` seconds.
Example:
.. testcode::
.. testcode::
from streaming.base.util import retry
from streaming.base.util import retry
num_tries = 0
num_tries = 0
@retry(RuntimeError, num_attempts=3, initial_backoff=0.1)
def flaky_function():
global num_tries
if num_tries < 2:
num_tries += 1
raise RuntimeError("Called too soon!")
return "Third time's a charm."
@retry(RuntimeError, num_attempts=3, initial_backoff=0.1)
def flaky_function():
global num_tries
if num_tries < 2:
num_tries += 1
raise RuntimeError("Called too soon!")
return "Third time's a charm."
print(flaky_function())
print(flaky_function())
.. testoutput::
Expand Down

0 comments on commit 36ba4ce

Please sign in to comment.