From 859f90862aab98c38be4a3872962558cdebd75b3 Mon Sep 17 00:00:00 2001 From: Xiaohan Zhang Date: Tue, 26 Sep 2023 12:26:22 -0700 Subject: [PATCH 01/59] First commit --- streaming/base/storage/upload.py | 67 ++++++++++------ streaming/base/util.py | 132 ++++++++++++++++++++++++++++++- tests/test_util.py | 85 +++++++++++++++++++- 3 files changed, 260 insertions(+), 24 deletions(-) diff --git a/streaming/base/storage/upload.py b/streaming/base/storage/upload.py index a41626512..22a8abff2 100644 --- a/streaming/base/storage/upload.py +++ b/streaming/base/storage/upload.py @@ -56,7 +56,9 @@ class CloudUploader: def get(cls, out: Union[str, Tuple[str, str]], keep_local: bool = False, - progress_bar: bool = False) -> Any: + progress_bar: bool = False, + exist_ok: bool = False, + ) -> Any: """Instantiate a cloud provider uploader or a local uploader based on remote path. Args: @@ -72,6 +74,7 @@ def get(cls, shard file or remove it after uploading. Defaults to ``False``. progress_bar (bool): Display TQDM progress bars for uploading output dataset files to a remote location. Default to ``False``. + exist_ok (bool): Throw error if out already exists and not empty. Defaults to ``False``. Returns: CloudUploader: An instance of sub-class. @@ -85,7 +88,7 @@ def get(cls, if prefix == 'dbfs:/Volumes': provider_prefix = prefix return getattr(sys.modules[__name__], UPLOADERS[provider_prefix])(out, keep_local, - progress_bar) + progress_bar, exist_ok) def _validate(self, out: Union[str, Tuple[str, str]]) -> None: """Validate the `out` argument. @@ -118,7 +121,8 @@ def _validate(self, out: Union[str, Tuple[str, str]]) -> None: def __init__(self, out: Union[str, Tuple[str, str]], keep_local: bool = False, - progress_bar: bool = False) -> None: + progress_bar: bool = False, + exist_ok: bool = False) -> None: """Initialize and validate local and remote path. Args: @@ -134,6 +138,7 @@ def __init__(self, shard file or remove it after uploading. Defaults to ``False``. progress_bar (bool): Display TQDM progress bars for uploading output dataset files to a remote location. Default to ``False``. + exist_ok (bool): Throw error if out already exists and has contents. Defaults to ``False``. Raises: FileExistsError: Local directory must be empty. @@ -155,7 +160,7 @@ def __init__(self, self.local = out[0] self.remote = out[1] - if os.path.exists(self.local) and len(os.listdir(self.local)) != 0: + if not exist_ok and os.path.exists(self.local) and len(os.listdir(self.local)) != 0: raise FileExistsError(f'Directory is not empty: {self.local}') os.makedirs(self.local, exist_ok=True) @@ -196,13 +201,15 @@ class S3Uploader(CloudUploader): shard file or remove it after uploading. Defaults to ``False``. progress_bar (bool): Display TQDM progress bars for uploading output dataset files to a remote location. Default to ``False``. + exist_ok (bool): Throw error if out already exists and not empty. Defaults to ``False``. """ def __init__(self, out: Union[str, Tuple[str, str]], keep_local: bool = False, - progress_bar: bool = False) -> None: - super().__init__(out, keep_local, progress_bar) + progress_bar: bool = False, + exist_ok: bool = False) -> None: + super().__init__(out, keep_local, progress_bar, exist_ok) import boto3 from botocore.config import Config @@ -277,13 +284,15 @@ class GCSUploader(CloudUploader): shard file or remove it after uploading. Defaults to ``False``. progress_bar (bool): Display TQDM progress bars for uploading output dataset files to a remote location. Default to ``False``. 
+ exist_ok (bool): Throw error if out already exists and not empty. Defaults to ``False``. """ def __init__(self, out: Union[str, Tuple[str, str]], keep_local: bool = False, - progress_bar: bool = False) -> None: - super().__init__(out, keep_local, progress_bar) + progress_bar: bool = False, + exist_ok: bool = False) -> None: + super().__init__(out, keep_local, progress_bar, exist_ok) if 'GCS_KEY' in os.environ and 'GCS_SECRET' in os.environ: import boto3 @@ -386,13 +395,15 @@ class OCIUploader(CloudUploader): shard file or remove it after uploading. Defaults to ``False``. progress_bar (bool): Display TQDM progress bars for uploading output dataset files to a remote location. Default to ``False``. + exist_ok (bool): Throw error if out already exists and not empty. Defaults to ``False``. """ def __init__(self, out: Union[str, Tuple[str, str]], keep_local: bool = False, - progress_bar: bool = False) -> None: - super().__init__(out, keep_local, progress_bar) + progress_bar: bool = False, + exist_ok: bool = False) -> None: + super().__init__(out, keep_local, progress_bar, exist_ok) import oci @@ -469,13 +480,15 @@ class AzureUploader(CloudUploader): shard file or remove it after uploading. Defaults to ``False``. progress_bar (bool): Display TQDM progress bars for uploading output dataset files to a remote location. Default to ``False``. + exist_ok (bool): Throw error if out already exists and not empty. Defaults to ``False``. """ def __init__(self, out: Union[str, Tuple[str, str]], keep_local: bool = False, - progress_bar: bool = False) -> None: - super().__init__(out, keep_local, progress_bar) + progress_bar: bool = False, + exist_ok: bool = False) -> None: + super().__init__(out, keep_local, progress_bar, exist_ok) from azure.storage.blob import BlobServiceClient @@ -547,13 +560,15 @@ class AzureDataLakeUploader(CloudUploader): shard file or remove it after uploading. Defaults to ``False``. progress_bar (bool): Display TQDM progress bars for uploading output dataset files to a remote location. Default to ``False``. + exist_ok (bool): Throw error if out already exists and not empty. Defaults to ``False``. """ def __init__(self, out: Union[str, Tuple[str, str]], keep_local: bool = False, - progress_bar: bool = False) -> None: - super().__init__(out, keep_local, progress_bar) + progress_bar: bool = False, + exist_ok: bool = False) -> None: + super().__init__(out, keep_local, progress_bar, exist_ok) from azure.storage.filedatalake import DataLakeServiceClient @@ -622,13 +637,15 @@ class DatabricksUploader(CloudUploader): shard file or remove it after uploading. Defaults to ``False``. progress_bar (bool): Display TQDM progress bars for uploading output dataset files to a remote location. Default to ``False``. + exist_ok (bool): Throw error if out already exists and not empty. Defaults to ``False``. """ def __init__(self, out: Union[str, Tuple[str, str]], keep_local: bool = False, - progress_bar: bool = False) -> None: - super().__init__(out, keep_local, progress_bar) + progress_bar: bool = False, + exist_ok: bool = False) -> None: + super().__init__(out, keep_local, progress_bar, exist_ok) self.client = self._create_workspace_client() def _create_workspace_client(self): @@ -656,13 +673,15 @@ class DatabricksUnityCatalogUploader(DatabricksUploader): shard file or remove it after uploading. Defaults to ``False``. progress_bar (bool): Display TQDM progress bars for uploading output dataset files to a remote location. Default to ``False``. 
+ exist_ok (bool): Throw error if out already exists and not empty. Defaults to ``False``. """ def __init__(self, out: Union[str, Tuple[str, str]], keep_local: bool = False, - progress_bar: bool = False) -> None: - super().__init__(out, keep_local, progress_bar) + progress_bar: bool = False, + exist_ok: bool = False) -> None: + super().__init__(out, keep_local, progress_bar, exist_ok) def upload_file(self, filename: str): """Upload file from local instance to Databricks Unity Catalog. @@ -695,13 +714,15 @@ class DBFSUploader(DatabricksUploader): shard file or remove it after uploading. Defaults to ``False``. progress_bar (bool): Display TQDM progress bars for uploading output dataset files to a remote location. Default to ``False``. + exist_ok (bool): Throw error if out already exists and not empty. Defaults to ``False``. """ def __init__(self, out: Union[str, Tuple[str, str]], keep_local: bool = False, - progress_bar: bool = False) -> None: - super().__init__(out, keep_local, progress_bar) + progress_bar: bool = False, + exist_ok: bool = False) -> None: + super().__init__(out, keep_local, progress_bar, exist_ok) self.dbfs_path = self.remote.lstrip('dbfs:') # pyright: ignore self.check_folder_exists() @@ -755,13 +776,15 @@ class LocalUploader(CloudUploader): shard file or remove it after uploading. Defaults to ``False``. progress_bar (bool): Display TQDM progress bars for uploading output dataset files to a remote location. Default to ``False``. + exist_ok (bool): Throw error if out already exists and not empty. Defaults to ``False``. """ def __init__(self, out: Union[str, Tuple[str, str]], keep_local: bool = False, - progress_bar: bool = False) -> None: - super().__init__(out, keep_local, progress_bar) + progress_bar: bool = False, + exist_ok: bool = False) -> None: + super().__init__(out, keep_local, progress_bar, exist_ok) # Create remote directory if it doesn't exist if self.remote: os.makedirs(self.remote, exist_ok=True) diff --git a/streaming/base/util.py b/streaming/base/util.py index 8fc988f69..1c5f42cc1 100644 --- a/streaming/base/util.py +++ b/streaming/base/util.py @@ -3,16 +3,23 @@ """Utility and helper functions for datasets.""" +import json +import logging import os from multiprocessing.shared_memory import SharedMemory as BuiltinSharedMemory from time import sleep, time -from typing import List, Union +from typing import List, Union, Tuple import torch.distributed as dist from streaming.base.constant import SHM_TO_CLEAN from streaming.base.distributed import get_local_rank, maybe_init_dist from streaming.base.shared.prefix import _get_path +from streaming.base.format.index import get_index_basename +import urllib.parse +from tempfile import mkdtemp +import shutil +logger = logging.getLogger(__name__) __all__ = ['get_list_arg'] @@ -196,3 +203,126 @@ def get_import_exception_message(package_name: str, extra_deps: str) -> str: return f'Streaming was installed without {package_name} support. ' + \ f'To use {package_name} related packages with Streaming, run ' + \ f'`pip install \'mosaicml-streaming[{package_name}]\'`.' + + +def merge_index(folder_urls: List[Union[str, Tuple[str, str]]], + out: Union[str, Tuple[str, str]], + *, + keep_local: bool = True, + overwrite: bool = True, + download_timeout: int = 100) -> None: + """Merge index.json from a list of remote or local directories into one for streaming and save the merged index to local or remote. 
+ + Args: + folder_urls (Iterable): folders that contain index.json for the partition + each element can take the form of a single path string or a tuple string + + for each url in folder_urls, if url is + 1. tuple (local, remote): check if local is accessible. + -> Yes: use local index to merge + -> No: download from remote first, then merge + 2. str (local path): use local path to merge. + raise FileNotFoundError if any local index is not accessible + 3. str (remote url): download to a temp directory first, then merge + + out (Union[str, Tuple[str, str]]): path to put the merged index file + keep_local (bool): Keep local copy of the merged index file. Defaults to ``True`` + overwrite (bool): Overwrite merged index file in out if there exists one.Defaults to ``True`` + download_timeout (int): The allowed time for downloading each json file + defaults to 60, same as streaming.download_file + """ + # Import here to avoid circular import error + from streaming.base.storage.upload import CloudUploader + from streaming.base.storage.download import download_file + + if not folder_urls: + logger.warning('No partitions exist, no index merged') + return + + # This is the index json file name, e.g., it is index.json as of 0.6.0 + index_basename = get_index_basename() + + if os.path.exists(os.path.join(out, index_basename)) and overwrite: + logger.warning('Merged index already exists. no index merged if overwrite=False') + return + + # Prepare a temp folder to download index.json rom remote if necessary. Removed in the end. + temp_root = mkdtemp() + + # Remove '/' from right, so os.path.basename gives relative path to each folder + urls = [] + for url in folder_urls: + if type(url) is str: + urls.append(url.rstrip('/').strip()) + else: + urls.append((url[0].rstrip('/').strip(), url[1].rstrip('/').strip())) + + # Determine if we need to call download_file. 
+ download = False + for url in urls: + local = remote = url + if type(url) is tuple: + # If driver cannot access the local path, download = True + download = not os.path.exists(url[0]) + else: + # If url is a remote, download = True, False otherwise + download = urllib.parse.urlparse(url).scheme is not None + + # As long as one index file needs download, we download them all to keep it simple + if download: + break + + # container for absolute local folder path + partitions = [] + for url in urls: + local = remote = url + + if download: + # If download is needed, download url from remote to temp_root + path = urllib.parse.urlparse(remote).path + local = os.path.join(temp_root, path.lstrip('/')) + try: + remote_url = os.path.join(remote, index_basename) + local_path = os.path.join(local, index_basename) + download_file(remote_url, local_path, download_timeout) + except Exception as ex: + raise RuntimeError(f'failed to download index.json {remote_url} -->{local_path}') from ex + + assert(os.path.exists(local)), f"Folder {local} does not exit or cannot be acceessed by the current process" + partitions.append(local) + + # merge index files into shards + shards = [] + for partition in partitions: + partition_index = f'{partition}/{index_basename}' + mds_partition_basename = os.path.basename(partition) + obj = json.load(open(partition_index)) + for i in range(len(obj['shards'])): + shard = obj['shards'][i] + for key in ('raw_data', 'zip_data'): + if shard.get(key): + basename = shard[key]['basename'] + obj['shards'][i][key]['basename'] = os.path.join(mds_partition_basename, basename) + shards += obj['shards'] + + # Save merged index locally + obj = { + 'version': 2, + 'shards': shards, + } + merged_index_path = os.path.join(temp_root, index_basename) + with open(merged_index_path, 'w') as outfile: + json.dump(obj, outfile) + + # Upload merged index to remote if out has remote part + # Otherwise, move it from temp root to out location + cu = CloudUploader.get(out, keep_local = True, exist_ok = True) + shutil.move(merged_index_path, cu.local) + if cu.remote is not None: + cu.upload_file(index_basename) + + # Clean up + shutil.rmtree(temp_root, ignore_errors=True) + if not keep_local: + shutil.rmtree(cu.local, ignore_errors=True) + diff --git a/tests/test_util.py b/tests/test_util.py index e47a27c52..2308db7a5 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -2,7 +2,7 @@ # SPDX-License-Identifier: Apache-2.0 from multiprocessing.shared_memory import SharedMemory as BuiltinSharedMemory -from typing import List, Optional +from typing import List, Optional, Tuple, Dict, Any import pytest @@ -10,7 +10,16 @@ from streaming.base.shared.prefix import _get_path from streaming.base.util import (bytes_to_int, clean_stale_shared_memory, get_list_arg, number_abbrev_to_int) +from tests.common.utils import convert_to_mds +from pyspark.sql import SparkSession +from pyspark.sql.functions import col +from pyspark.sql.types import DecimalType, IntegerType, StringType, StructField, StructType + +from streaming.base.converters import dataframeToMDS +from streaming.base.util import merge_index +import os +import glob @pytest.mark.parametrize(('text', 'expected_output'), [('hello,world', ['hello', 'world']), ('hello', ['hello']), ('', [])]) @@ -105,3 +114,77 @@ def test_clean_stale_shared_memory(): # If clean up is successful, it should raise FileNotFoundError Exception with pytest.raises(FileNotFoundError): _ = BuiltinSharedMemory(name, False, 64) + + +@pytest.fixture +def dataframe(): + spark = 
SparkSession.builder.getOrCreate() # pyright: ignore + + data = [('36636', 'Finance', (3000, 'USA')), ('40288', 'Finance', (5000, 'IND')), + ('42114', 'Sales', (3900, 'USA')), ('39192', 'Marketing', (2500, 'CAN')), + ('34534', 'Sales', (6500, 'USA'))] + schema = StructType([ + StructField('id', StringType(), True), + StructField('dept', StringType(), True), + StructField( + 'properties', + StructType([ + StructField('salary', IntegerType(), True), + StructField('location', StringType(), True) + ])) + ]) + + df = spark.createDataFrame(data=data, schema=schema).repartition(3) + yield df + +"""Ideally, we want to test all these combinations +@pytest.mark.parametrize('folder_urls', { + ('/Volumes/mdsdata/00/', + '/Volumes/mdsdata/01/', + '/Volumes/mdsdata/02/' + ): 'driver can access the folder_urls, no download, merge diretly', + ('gs://mybucket/mdsdata/00/', + 'gs://mybucket/mdsdata/01/', + 'gs://mybucket/mdsdata/02/' + ): 'driver downloads from remote bucket, merge locally', + (('/tmp/zscf/mdsdata/00', 'gs://mybucket/mdsdata/00/'), + ('/tmp/zscf/mdsdata/01', 'gs://mybucket/mdsdata/01/'), + ('/tmp/zscf/mdsdata/02', 'gs://mybucket/mdsdata/02/') + ): 'driver cannot access local directories, so download from remote', + (('/Volumes/mdsdata/00', 'gs://mybucket/mdsdata/00/'), + ('/Volumes/mdsdata/01', 'gs://mybucket/mdsdata/01/'), + ('/Volumes/mdsdata/02', 'gs://mybucket/mdsdata/02/') + ): 'driver can access local directories, use local to merge', + }) +@pytest.mark.parametrize('out', { + '/Volumes/mdsdata/' : 'save merged index to /Volumes/mdsdata/', + 'gs://mybucket/mdsdata/' : 'save merged index to gs://mybucket/mdsdata. remove local copy if keep_local = False', + ('/Volumes/mdsdata/', 'gs://mybucket/mdsdata'): 'save merged index to both /Volumes/mdsdata and gs://mybucket/dsdata/, remove local copy if keep_local=False', + }) +""" +@pytest.mark.usefixtures('local_remote_dir') +@pytest.mark.parametrize('keep_local', [True, False]) +@pytest.mark.parametrize('overwrite', [True, False]) +def test_merge_index(local_remote_dir: Tuple[str, str], + dataframe: Any, + keep_local: bool, + #folder_urls: Dict, + #out: Dict, + overwrite: bool): + """ Test the following scenarios for merge_index + """ + + local, remote = local_remote_dir + local = '/tmp/mdsdata' + mds_kwargs = { + 'out': local, + 'keep_local': True, + } + _, _ = dataframeToMDS(dataframe.select(col('id'), col('dept')), + merge_index=False, + mds_kwargs=mds_kwargs) + + folder_urls = glob.glob(local + '/*') + merge_index(folder_urls, local) + + assert(os.path.exists(os.path.join(local, 'index.json'))) From 45751749231f03bab5a5f0b053ab959539f3da07 Mon Sep 17 00:00:00 2001 From: Xiaohan Zhang Date: Tue, 26 Sep 2023 13:43:14 -0700 Subject: [PATCH 02/59] Add a naive mds datasts --- streaming/base/util.py | 122 +++++++++--------- .../resources/naive_MDSdataset/25/index.json | 1 + .../naive_MDSdataset/25/shard.00000.mds | Bin 0 -> 244 bytes .../resources/naive_MDSdataset/26/index.json | 1 + .../naive_MDSdataset/26/shard.00000.mds | Bin 0 -> 266 bytes .../resources/naive_MDSdataset/27/index.json | 1 + tests/resources/naive_MDSdataset/index.json | 1 + tests/test_util.py | 97 +++++++------- 8 files changed, 120 insertions(+), 103 deletions(-) create mode 100644 tests/resources/naive_MDSdataset/25/index.json create mode 100644 tests/resources/naive_MDSdataset/25/shard.00000.mds create mode 100644 tests/resources/naive_MDSdataset/26/index.json create mode 100644 tests/resources/naive_MDSdataset/26/shard.00000.mds create mode 100644 
tests/resources/naive_MDSdataset/27/index.json create mode 100644 tests/resources/naive_MDSdataset/index.json diff --git a/streaming/base/util.py b/streaming/base/util.py index 1c5f42cc1..d1d6a3a0f 100644 --- a/streaming/base/util.py +++ b/streaming/base/util.py @@ -17,7 +17,7 @@ from streaming.base.shared.prefix import _get_path from streaming.base.format.index import get_index_basename import urllib.parse -from tempfile import mkdtemp +import tempfile import shutil logger = logging.getLogger(__name__) @@ -246,9 +246,6 @@ def merge_index(folder_urls: List[Union[str, Tuple[str, str]]], logger.warning('Merged index already exists. no index merged if overwrite=False') return - # Prepare a temp folder to download index.json rom remote if necessary. Removed in the end. - temp_root = mkdtemp() - # Remove '/' from right, so os.path.basename gives relative path to each folder urls = [] for url in folder_urls: @@ -266,63 +263,72 @@ def merge_index(folder_urls: List[Union[str, Tuple[str, str]]], download = not os.path.exists(url[0]) else: # If url is a remote, download = True, False otherwise - download = urllib.parse.urlparse(url).scheme is not None + download = urllib.parse.urlparse(url).scheme != '' # As long as one index file needs download, we download them all to keep it simple if download: break - # container for absolute local folder path - partitions = [] - for url in urls: - local = remote = url - - if download: - # If download is needed, download url from remote to temp_root - path = urllib.parse.urlparse(remote).path - local = os.path.join(temp_root, path.lstrip('/')) - try: - remote_url = os.path.join(remote, index_basename) - local_path = os.path.join(local, index_basename) - download_file(remote_url, local_path, download_timeout) - except Exception as ex: - raise RuntimeError(f'failed to download index.json {remote_url} -->{local_path}') from ex - - assert(os.path.exists(local)), f"Folder {local} does not exit or cannot be acceessed by the current process" - partitions.append(local) - - # merge index files into shards - shards = [] - for partition in partitions: - partition_index = f'{partition}/{index_basename}' - mds_partition_basename = os.path.basename(partition) - obj = json.load(open(partition_index)) - for i in range(len(obj['shards'])): - shard = obj['shards'][i] - for key in ('raw_data', 'zip_data'): - if shard.get(key): - basename = shard[key]['basename'] - obj['shards'][i][key]['basename'] = os.path.join(mds_partition_basename, basename) - shards += obj['shards'] - - # Save merged index locally - obj = { - 'version': 2, - 'shards': shards, - } - merged_index_path = os.path.join(temp_root, index_basename) - with open(merged_index_path, 'w') as outfile: - json.dump(obj, outfile) - - # Upload merged index to remote if out has remote part - # Otherwise, move it from temp root to out location - cu = CloudUploader.get(out, keep_local = True, exist_ok = True) - shutil.move(merged_index_path, cu.local) - if cu.remote is not None: - cu.upload_file(index_basename) - - # Clean up - shutil.rmtree(temp_root, ignore_errors=True) - if not keep_local: - shutil.rmtree(cu.local, ignore_errors=True) + print('download = ', download) + # Prepare a temp folder to download index.json rom remote if necessary. Removed in the end. 
+ with tempfile.TemporaryDirectory() as temp_root: + + # container for absolute local folder path + partitions = [] + n_downloads = 0 + for url in urls: + local = remote = url + + if download: + # If download is needed, download url from remote to temp_root + path = urllib.parse.urlparse(remote).path + local = os.path.join(temp_root, path.lstrip('/')) + try: + remote_url = os.path.join(remote, index_basename) + local_path = os.path.join(local, index_basename) + download_file(remote_url, local_path, download_timeout) + n_downloads += 1 + except Exception as ex: + raise RuntimeError(f'failed to download index.json {remote_url} -->{local_path}') from ex + + if not (os.path.exists(local)): + raise FileNotFoundError("Folder {local} does not exit or cannot be acceessed by the current process") + partitions.append(local) + + # merge index files into shards + shards = [] + for partition in partitions: + partition_index = f'{partition}/{index_basename}' + mds_partition_basename = os.path.basename(partition) + obj = json.load(open(partition_index)) + for i in range(len(obj['shards'])): + shard = obj['shards'][i] + for key in ('raw_data', 'zip_data'): + if shard.get(key): + basename = shard[key]['basename'] + obj['shards'][i][key]['basename'] = os.path.join(mds_partition_basename, basename) + shards += obj['shards'] + + # Save merged index locally + obj = { + 'version': 2, + 'shards': shards, + } + merged_index_path = os.path.join(temp_root, index_basename) + with open(merged_index_path, 'w') as outfile: + json.dump(obj, outfile) + + # Upload merged index to remote if out has remote part + # Otherwise, move it from temp root to out location + cu = CloudUploader.get(out, keep_local = True, exist_ok = True) + shutil.move(merged_index_path, cu.local) + if cu.remote is not None: + cu.upload_file(index_basename) + + # Clean up + # shutil.rmtree(temp_root, ignore_errors=True) + if not keep_local: + shutil.rmtree(cu.local, ignore_errors=True) + + return n_downloads diff --git a/tests/resources/naive_MDSdataset/25/index.json b/tests/resources/naive_MDSdataset/25/index.json new file mode 100644 index 000000000..c893ad07e --- /dev/null +++ b/tests/resources/naive_MDSdataset/25/index.json @@ -0,0 +1 @@ +{"shards": [{"column_encodings": ["str", "str"], "column_names": ["dept", "id"], "column_sizes": [null, null], "compression": null, "format": "mds", "hashes": [], "raw_data": {"basename": "shard.00000.mds", "bytes": 244, "hashes": {}}, "samples": 2, "size_limit": 67108864, "version": 2, "zip_data": null}], "version": 2} \ No newline at end of file diff --git a/tests/resources/naive_MDSdataset/25/shard.00000.mds b/tests/resources/naive_MDSdataset/25/shard.00000.mds new file mode 100644 index 0000000000000000000000000000000000000000..c776992af71090060606221dac217c624c93070d GIT binary patch literal 244 zcmYk0Jr2S!423Hs25wPh?m&MiEq7pIM5v0IKqK-~#VG@b8*&QHzyYvH2Zqm<{hpt^ zNRs4*ypa#`V7=3mv7NNN6UttI?b0KI;8~Xb+6nvYvE0b03poZdD8c@8Q1__YN$V`7 z8dWoT380+C@Tjq~^M(hUnGrxy1BW4A(+x#+S{X%_dYiACrmk>*lYY)Ao-6!+iR`(* V%7DL@ZQd5NAr4$iD636beF0ktMydb+ literal 0 HcmV?d00001 diff --git a/tests/resources/naive_MDSdataset/26/index.json b/tests/resources/naive_MDSdataset/26/index.json new file mode 100644 index 000000000..869b225f4 --- /dev/null +++ b/tests/resources/naive_MDSdataset/26/index.json @@ -0,0 +1 @@ +{"shards": [{"column_encodings": ["str", "str"], "column_names": ["dept", "id"], "column_sizes": [null, null], "compression": null, "format": "mds", "hashes": [], "raw_data": {"basename": "shard.00000.mds", "bytes": 266, "hashes": 
{}}, "samples": 3, "size_limit": 67108864, "version": 2, "zip_data": null}], "version": 2} \ No newline at end of file diff --git a/tests/resources/naive_MDSdataset/26/shard.00000.mds b/tests/resources/naive_MDSdataset/26/shard.00000.mds new file mode 100644 index 0000000000000000000000000000000000000000..42d6c202b350ecba24425d681efdcb36d1518b2b GIT binary patch literal 266 zcmZ9HO$x#=5QQTmf~P21chNs$@Bm)GrAP@)w1H$&lGcUd9X*7HaOuQ!AuhfdX5O0z zvm}H(kr(ntHVZ=Tv~y;%&?@Mh)Nl!OmmO&sJP6_&b-amDt d6|&bLpi?ztHT&B&Ma6maL=M8J&{SpFd;oURNsj;k literal 0 HcmV?d00001 diff --git a/tests/resources/naive_MDSdataset/27/index.json b/tests/resources/naive_MDSdataset/27/index.json new file mode 100644 index 000000000..0abdc1f33 --- /dev/null +++ b/tests/resources/naive_MDSdataset/27/index.json @@ -0,0 +1 @@ +{"shards": [], "version": 2} \ No newline at end of file diff --git a/tests/resources/naive_MDSdataset/index.json b/tests/resources/naive_MDSdataset/index.json new file mode 100644 index 000000000..a64978381 --- /dev/null +++ b/tests/resources/naive_MDSdataset/index.json @@ -0,0 +1 @@ +{"version": 2, "shards": [{"column_encodings": ["str", "str"], "column_names": ["dept", "id"], "column_sizes": [null, null], "compression": null, "format": "mds", "hashes": [], "raw_data": {"basename": "26/shard.00000.mds", "bytes": 266, "hashes": {}}, "samples": 3, "size_limit": 67108864, "version": 2, "zip_data": null}, {"column_encodings": ["str", "str"], "column_names": ["dept", "id"], "column_sizes": [null, null], "compression": null, "format": "mds", "hashes": [], "raw_data": {"basename": "25/shard.00000.mds", "bytes": 244, "hashes": {}}, "samples": 2, "size_limit": 67108864, "version": 2, "zip_data": null}]} \ No newline at end of file diff --git a/tests/test_util.py b/tests/test_util.py index 2308db7a5..4eec145bc 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -18,6 +18,7 @@ from streaming.base.converters import dataframeToMDS from streaming.base.util import merge_index +import tempfile import os import glob @@ -137,54 +138,60 @@ def dataframe(): df = spark.createDataFrame(data=data, schema=schema).repartition(3) yield df -"""Ideally, we want to test all these combinations -@pytest.mark.parametrize('folder_urls', { - ('/Volumes/mdsdata/00/', - '/Volumes/mdsdata/01/', - '/Volumes/mdsdata/02/' - ): 'driver can access the folder_urls, no download, merge diretly', - ('gs://mybucket/mdsdata/00/', - 'gs://mybucket/mdsdata/01/', - 'gs://mybucket/mdsdata/02/' - ): 'driver downloads from remote bucket, merge locally', - (('/tmp/zscf/mdsdata/00', 'gs://mybucket/mdsdata/00/'), - ('/tmp/zscf/mdsdata/01', 'gs://mybucket/mdsdata/01/'), - ('/tmp/zscf/mdsdata/02', 'gs://mybucket/mdsdata/02/') - ): 'driver cannot access local directories, so download from remote', - (('/Volumes/mdsdata/00', 'gs://mybucket/mdsdata/00/'), - ('/Volumes/mdsdata/01', 'gs://mybucket/mdsdata/01/'), - ('/Volumes/mdsdata/02', 'gs://mybucket/mdsdata/02/') - ): 'driver can access local directories, use local to merge', - }) -@pytest.mark.parametrize('out', { - '/Volumes/mdsdata/' : 'save merged index to /Volumes/mdsdata/', - 'gs://mybucket/mdsdata/' : 'save merged index to gs://mybucket/mdsdata. remove local copy if keep_local = False', - ('/Volumes/mdsdata/', 'gs://mybucket/mdsdata'): 'save merged index to both /Volumes/mdsdata and gs://mybucket/dsdata/, remove local copy if keep_local=False', - }) +"""Example input urls to test + ['gs://mybucket/mdsdata/25/'...] + ['/path/never/exists/25',... 
] + [('/path/never/exists/25', 'gs://mybucket/mdsdata/25/'), ...] + [('tests/resources/naive_MDSdataset/25/', 'gs://mybucket/mdsdata/25/'), ...] """ +@pytest.mark.parametrize('folder_urls', ['local_accessible', 'remote', 'local_unaccessible', 'local_accessible_tuple']) +@pytest.mark.parametrize('out', ['local_str', 'remote_str', 'tuple']) @pytest.mark.usefixtures('local_remote_dir') @pytest.mark.parametrize('keep_local', [True, False]) -@pytest.mark.parametrize('overwrite', [True, False]) def test_merge_index(local_remote_dir: Tuple[str, str], dataframe: Any, keep_local: bool, - #folder_urls: Dict, - #out: Dict, - overwrite: bool): - """ Test the following scenarios for merge_index - """ - - local, remote = local_remote_dir - local = '/tmp/mdsdata' - mds_kwargs = { - 'out': local, - 'keep_local': True, - } - _, _ = dataframeToMDS(dataframe.select(col('id'), col('dept')), - merge_index=False, - mds_kwargs=mds_kwargs) - - folder_urls = glob.glob(local + '/*') - merge_index(folder_urls, local) - - assert(os.path.exists(os.path.join(local, 'index.json'))) + folder_urls: Dict, + out: Dict): + + naive_mds_partitions= ['tests/resources/naive_MDSdataset/25/', + 'tests/resources/naive_MDSdataset/26/', + 'tests/resources/naive_MDSdataset/27/'] + + if folder_urls == 'local_accessible': + folder_urls = [ os.getcwd() + '/' + s for s in naive_mds_partitions] + print(folder_urls) + + if out == 'local_str': + with tempfile.TemporaryDirectory() as tmp: + n_downloads = merge_index(folder_urls, tmp, keep_local = keep_local, overwrite = True) + if keep_local: + assert(os.path.exists(os.path.join(tmp, 'index.json'))) + else: + assert(not os.path.exists(os.path.join(tmp, 'index.json'))) + assert n_downloads == 0, f"n_downloads should be 0 instead of {n_downloads}" + else: + return + + if folder_urls == 'remote': + return + + if folder_urls == 'local_unaccessible': + with tempfile.TemporaryDirectory() as tmp_data_root: + folder_urls = [ tmp_data_root + '/' + s for s in naive_mds_partitions] + with pytest.raises(FileNotFoundError, match=f'.* does not exit or cannot be acceessed by the current process.*'): + merge_index(folder_urls, tmp_data_root, keep_local = keep_local, overwrite = True) + + if folder_urls == 'local_accessible_tuple': + folder_urls = [] + for s in naive_mds_partitions: + folder_urls.append((os.getcwd() + '/' + s, 'gs://mybucket/'+ s )) + if out == 'local_str': + with tempfile.TemporaryDirectory() as tmp_data_root: + folder_urls = [ tmp_data_root + '/' + s for s in naive_mds_partitions] + with pytest.raises(FileNotFoundError, match=f'.* does not exit or cannot be acceessed by the current process.*'): + merge_index(folder_urls, tmp_data_root, keep_local = keep_local, overwrite = True) + else: + return + + From 6b3164077c89c1e0499994b1bf8ceacaa2d3df20 Mon Sep 17 00:00:00 2001 From: Xiaohan Zhang Date: Tue, 26 Sep 2023 13:54:20 -0700 Subject: [PATCH 03/59] fix lints --- streaming/base/storage/upload.py | 13 +-- streaming/base/util.py | 39 ++++---- .../resources/naive_MDSdataset/25/index.json | 32 +++++- .../resources/naive_MDSdataset/26/index.json | 32 +++++- .../resources/naive_MDSdataset/27/index.json | 5 +- tests/resources/naive_MDSdataset/index.json | 58 ++++++++++- tests/test_util.py | 98 +++++++------------ 7 files changed, 187 insertions(+), 90 deletions(-) diff --git a/streaming/base/storage/upload.py b/streaming/base/storage/upload.py index 22a8abff2..47d794fd4 100644 --- a/streaming/base/storage/upload.py +++ b/streaming/base/storage/upload.py @@ -53,12 +53,13 @@ class CloudUploader: 
"""Upload local files to a cloud storage.""" @classmethod - def get(cls, - out: Union[str, Tuple[str, str]], - keep_local: bool = False, - progress_bar: bool = False, - exist_ok: bool = False, - ) -> Any: + def get( + cls, + out: Union[str, Tuple[str, str]], + keep_local: bool = False, + progress_bar: bool = False, + exist_ok: bool = False, + ) -> Any: """Instantiate a cloud provider uploader or a local uploader based on remote path. Args: diff --git a/streaming/base/util.py b/streaming/base/util.py index d1d6a3a0f..e4510451b 100644 --- a/streaming/base/util.py +++ b/streaming/base/util.py @@ -6,19 +6,20 @@ import json import logging import os +import shutil +import tempfile +import urllib.parse from multiprocessing.shared_memory import SharedMemory as BuiltinSharedMemory from time import sleep, time -from typing import List, Union, Tuple +from typing import List, Tuple, Union import torch.distributed as dist from streaming.base.constant import SHM_TO_CLEAN from streaming.base.distributed import get_local_rank, maybe_init_dist -from streaming.base.shared.prefix import _get_path from streaming.base.format.index import get_index_basename -import urllib.parse -import tempfile -import shutil +from streaming.base.shared.prefix import _get_path + logger = logging.getLogger(__name__) __all__ = ['get_list_arg'] @@ -210,8 +211,8 @@ def merge_index(folder_urls: List[Union[str, Tuple[str, str]]], *, keep_local: bool = True, overwrite: bool = True, - download_timeout: int = 100) -> None: - """Merge index.json from a list of remote or local directories into one for streaming and save the merged index to local or remote. + download_timeout: int = 100) -> int: + """Merge index.json from a list of remote or local directories. Args: folder_urls (Iterable): folders that contain index.json for the partition @@ -230,21 +231,25 @@ def merge_index(folder_urls: List[Union[str, Tuple[str, str]]], overwrite (bool): Overwrite merged index file in out if there exists one.Defaults to ``True`` download_timeout (int): The allowed time for downloading each json file defaults to 60, same as streaming.download_file + + Returns: + int: count of files downloaded during function call """ # Import here to avoid circular import error - from streaming.base.storage.upload import CloudUploader from streaming.base.storage.download import download_file + from streaming.base.storage.upload import CloudUploader if not folder_urls: logger.warning('No partitions exist, no index merged') - return + return 0 # This is the index json file name, e.g., it is index.json as of 0.6.0 index_basename = get_index_basename() - if os.path.exists(os.path.join(out, index_basename)) and overwrite: - logger.warning('Merged index already exists. no index merged if overwrite=False') - return + cu = CloudUploader.get(out, keep_local=True, exist_ok=True) + if os.path.exists(os.path.join(cu.local, index_basename)) and overwrite: + logger.warning('Merged index already exists locally. 
no index merged if overwrite=False') + return 0 # Remove '/' from right, so os.path.basename gives relative path to each folder urls = [] @@ -289,10 +294,11 @@ def merge_index(folder_urls: List[Union[str, Tuple[str, str]]], download_file(remote_url, local_path, download_timeout) n_downloads += 1 except Exception as ex: - raise RuntimeError(f'failed to download index.json {remote_url} -->{local_path}') from ex + raise RuntimeError(f'failed to download index.json {url}') from ex if not (os.path.exists(local)): - raise FileNotFoundError("Folder {local} does not exit or cannot be acceessed by the current process") + raise FileNotFoundError( + 'Folder {local} does not exit or cannot be acceessed by the current process') partitions.append(local) # merge index files into shards @@ -306,7 +312,8 @@ def merge_index(folder_urls: List[Union[str, Tuple[str, str]]], for key in ('raw_data', 'zip_data'): if shard.get(key): basename = shard[key]['basename'] - obj['shards'][i][key]['basename'] = os.path.join(mds_partition_basename, basename) + obj['shards'][i][key]['basename'] = os.path.join( + mds_partition_basename, basename) shards += obj['shards'] # Save merged index locally @@ -320,7 +327,6 @@ def merge_index(folder_urls: List[Union[str, Tuple[str, str]]], # Upload merged index to remote if out has remote part # Otherwise, move it from temp root to out location - cu = CloudUploader.get(out, keep_local = True, exist_ok = True) shutil.move(merged_index_path, cu.local) if cu.remote is not None: cu.upload_file(index_basename) @@ -331,4 +337,3 @@ def merge_index(folder_urls: List[Union[str, Tuple[str, str]]], shutil.rmtree(cu.local, ignore_errors=True) return n_downloads - diff --git a/tests/resources/naive_MDSdataset/25/index.json b/tests/resources/naive_MDSdataset/25/index.json index c893ad07e..b7c1f591f 100644 --- a/tests/resources/naive_MDSdataset/25/index.json +++ b/tests/resources/naive_MDSdataset/25/index.json @@ -1 +1,31 @@ -{"shards": [{"column_encodings": ["str", "str"], "column_names": ["dept", "id"], "column_sizes": [null, null], "compression": null, "format": "mds", "hashes": [], "raw_data": {"basename": "shard.00000.mds", "bytes": 244, "hashes": {}}, "samples": 2, "size_limit": 67108864, "version": 2, "zip_data": null}], "version": 2} \ No newline at end of file +{ + "shards": [ + { + "column_encodings": [ + "str", + "str" + ], + "column_names": [ + "dept", + "id" + ], + "column_sizes": [ + null, + null + ], + "compression": null, + "format": "mds", + "hashes": [], + "raw_data": { + "basename": "shard.00000.mds", + "bytes": 244, + "hashes": {} + }, + "samples": 2, + "size_limit": 67108864, + "version": 2, + "zip_data": null + } + ], + "version": 2 +} diff --git a/tests/resources/naive_MDSdataset/26/index.json b/tests/resources/naive_MDSdataset/26/index.json index 869b225f4..7aac4a36b 100644 --- a/tests/resources/naive_MDSdataset/26/index.json +++ b/tests/resources/naive_MDSdataset/26/index.json @@ -1 +1,31 @@ -{"shards": [{"column_encodings": ["str", "str"], "column_names": ["dept", "id"], "column_sizes": [null, null], "compression": null, "format": "mds", "hashes": [], "raw_data": {"basename": "shard.00000.mds", "bytes": 266, "hashes": {}}, "samples": 3, "size_limit": 67108864, "version": 2, "zip_data": null}], "version": 2} \ No newline at end of file +{ + "shards": [ + { + "column_encodings": [ + "str", + "str" + ], + "column_names": [ + "dept", + "id" + ], + "column_sizes": [ + null, + null + ], + "compression": null, + "format": "mds", + "hashes": [], + "raw_data": { + "basename": 
"shard.00000.mds", + "bytes": 266, + "hashes": {} + }, + "samples": 3, + "size_limit": 67108864, + "version": 2, + "zip_data": null + } + ], + "version": 2 +} diff --git a/tests/resources/naive_MDSdataset/27/index.json b/tests/resources/naive_MDSdataset/27/index.json index 0abdc1f33..16d98001e 100644 --- a/tests/resources/naive_MDSdataset/27/index.json +++ b/tests/resources/naive_MDSdataset/27/index.json @@ -1 +1,4 @@ -{"shards": [], "version": 2} \ No newline at end of file +{ + "shards": [], + "version": 2 +} diff --git a/tests/resources/naive_MDSdataset/index.json b/tests/resources/naive_MDSdataset/index.json index a64978381..2915cd61b 100644 --- a/tests/resources/naive_MDSdataset/index.json +++ b/tests/resources/naive_MDSdataset/index.json @@ -1 +1,57 @@ -{"version": 2, "shards": [{"column_encodings": ["str", "str"], "column_names": ["dept", "id"], "column_sizes": [null, null], "compression": null, "format": "mds", "hashes": [], "raw_data": {"basename": "26/shard.00000.mds", "bytes": 266, "hashes": {}}, "samples": 3, "size_limit": 67108864, "version": 2, "zip_data": null}, {"column_encodings": ["str", "str"], "column_names": ["dept", "id"], "column_sizes": [null, null], "compression": null, "format": "mds", "hashes": [], "raw_data": {"basename": "25/shard.00000.mds", "bytes": 244, "hashes": {}}, "samples": 2, "size_limit": 67108864, "version": 2, "zip_data": null}]} \ No newline at end of file +{ + "version": 2, + "shards": [ + { + "column_encodings": [ + "str", + "str" + ], + "column_names": [ + "dept", + "id" + ], + "column_sizes": [ + null, + null + ], + "compression": null, + "format": "mds", + "hashes": [], + "raw_data": { + "basename": "26/shard.00000.mds", + "bytes": 266, + "hashes": {} + }, + "samples": 3, + "size_limit": 67108864, + "version": 2, + "zip_data": null + }, + { + "column_encodings": [ + "str", + "str" + ], + "column_names": [ + "dept", + "id" + ], + "column_sizes": [ + null, + null + ], + "compression": null, + "format": "mds", + "hashes": [], + "raw_data": { + "basename": "25/shard.00000.mds", + "bytes": 244, + "hashes": {} + }, + "samples": 2, + "size_limit": 67108864, + "version": 2, + "zip_data": null + } + ] +} diff --git a/tests/test_util.py b/tests/test_util.py index 4eec145bc..6989e60c2 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -1,26 +1,18 @@ # Copyright 2023 MosaicML Streaming authors # SPDX-License-Identifier: Apache-2.0 +import os +import tempfile from multiprocessing.shared_memory import SharedMemory as BuiltinSharedMemory -from typing import List, Optional, Tuple, Dict, Any +from typing import Any, List, Optional, Tuple import pytest from streaming.base.constant import RESUME from streaming.base.shared.prefix import _get_path from streaming.base.util import (bytes_to_int, clean_stale_shared_memory, get_list_arg, - number_abbrev_to_int) -from tests.common.utils import convert_to_mds - -from pyspark.sql import SparkSession -from pyspark.sql.functions import col -from pyspark.sql.types import DecimalType, IntegerType, StringType, StructField, StructType + merge_index, number_abbrev_to_int) -from streaming.base.converters import dataframeToMDS -from streaming.base.util import merge_index -import tempfile -import os -import glob @pytest.mark.parametrize(('text', 'expected_output'), [('hello,world', ['hello', 'world']), ('hello', ['hello']), ('', [])]) @@ -117,59 +109,37 @@ def test_clean_stale_shared_memory(): _ = BuiltinSharedMemory(name, False, 64) -@pytest.fixture -def dataframe(): - spark = SparkSession.builder.getOrCreate() # 
pyright: ignore - - data = [('36636', 'Finance', (3000, 'USA')), ('40288', 'Finance', (5000, 'IND')), - ('42114', 'Sales', (3900, 'USA')), ('39192', 'Marketing', (2500, 'CAN')), - ('34534', 'Sales', (6500, 'USA'))] - schema = StructType([ - StructField('id', StringType(), True), - StructField('dept', StringType(), True), - StructField( - 'properties', - StructType([ - StructField('salary', IntegerType(), True), - StructField('location', StringType(), True) - ])) - ]) - - df = spark.createDataFrame(data=data, schema=schema).repartition(3) - yield df - -"""Example input urls to test - ['gs://mybucket/mdsdata/25/'...] - ['/path/never/exists/25',... ] - [('/path/never/exists/25', 'gs://mybucket/mdsdata/25/'), ...] - [('tests/resources/naive_MDSdataset/25/', 'gs://mybucket/mdsdata/25/'), ...] -""" -@pytest.mark.parametrize('folder_urls', ['local_accessible', 'remote', 'local_unaccessible', 'local_accessible_tuple']) +@pytest.mark.parametrize( + 'folder_urls', ['local_accessible', 'remote', 'local_unaccessible', 'local_accessible_tuple']) @pytest.mark.parametrize('out', ['local_str', 'remote_str', 'tuple']) @pytest.mark.usefixtures('local_remote_dir') @pytest.mark.parametrize('keep_local', [True, False]) -def test_merge_index(local_remote_dir: Tuple[str, str], - dataframe: Any, - keep_local: bool, - folder_urls: Dict, - out: Dict): - - naive_mds_partitions= ['tests/resources/naive_MDSdataset/25/', - 'tests/resources/naive_MDSdataset/26/', - 'tests/resources/naive_MDSdataset/27/'] +def test_merge_index(local_remote_dir: Tuple[str, str], dataframe: Any, keep_local: Any, + folder_urls: Any, out: Any): + """Example input urls to test + ['gs://mybucket/mdsdata/25/'...] + ['/path/never/exists/25',... ] + [('/path/never/exists/25', 'gs://mybucket/mdsdata/25/'), ...] + [('tests/resources/naive_MDSdataset/25/', 'gs://mybucket/mdsdata/25/'), ...] 
+ """ + + naive_mds_partitions = [ + 'tests/resources/naive_MDSdataset/25/', 'tests/resources/naive_MDSdataset/26/', + 'tests/resources/naive_MDSdataset/27/' + ] if folder_urls == 'local_accessible': - folder_urls = [ os.getcwd() + '/' + s for s in naive_mds_partitions] + folder_urls = [os.getcwd() + '/' + s for s in naive_mds_partitions] print(folder_urls) if out == 'local_str': with tempfile.TemporaryDirectory() as tmp: - n_downloads = merge_index(folder_urls, tmp, keep_local = keep_local, overwrite = True) + n_downloads = merge_index(folder_urls, tmp, keep_local=keep_local, overwrite=True) if keep_local: - assert(os.path.exists(os.path.join(tmp, 'index.json'))) + assert (os.path.exists(os.path.join(tmp, 'index.json'))) else: - assert(not os.path.exists(os.path.join(tmp, 'index.json'))) - assert n_downloads == 0, f"n_downloads should be 0 instead of {n_downloads}" + assert (not os.path.exists(os.path.join(tmp, 'index.json'))) + assert n_downloads == 0, f'n_downloads should be 0 instead of {n_downloads}' else: return @@ -178,20 +148,22 @@ def test_merge_index(local_remote_dir: Tuple[str, str], if folder_urls == 'local_unaccessible': with tempfile.TemporaryDirectory() as tmp_data_root: - folder_urls = [ tmp_data_root + '/' + s for s in naive_mds_partitions] - with pytest.raises(FileNotFoundError, match=f'.* does not exit or cannot be acceessed by the current process.*'): - merge_index(folder_urls, tmp_data_root, keep_local = keep_local, overwrite = True) + folder_urls = [tmp_data_root + '/' + s for s in naive_mds_partitions] + with pytest.raises( + FileNotFoundError, + match=f'.* does not exit or cannot be acceessed by the current process.*'): + merge_index(folder_urls, tmp_data_root, keep_local=keep_local, overwrite=True) if folder_urls == 'local_accessible_tuple': folder_urls = [] for s in naive_mds_partitions: - folder_urls.append((os.getcwd() + '/' + s, 'gs://mybucket/'+ s )) + folder_urls.append((os.getcwd() + '/' + s, 'gs://mybucket/' + s)) if out == 'local_str': with tempfile.TemporaryDirectory() as tmp_data_root: - folder_urls = [ tmp_data_root + '/' + s for s in naive_mds_partitions] - with pytest.raises(FileNotFoundError, match=f'.* does not exit or cannot be acceessed by the current process.*'): - merge_index(folder_urls, tmp_data_root, keep_local = keep_local, overwrite = True) + folder_urls = [tmp_data_root + '/' + s for s in naive_mds_partitions] + with pytest.raises( + FileNotFoundError, + match=f'.* does not exit or cannot be acceessed by the current process.*'): + merge_index(folder_urls, tmp_data_root, keep_local=keep_local, overwrite=True) else: return - - From b4a0ff7c72dfb82bac24cd92a3c3fd71dc43207f Mon Sep 17 00:00:00 2001 From: Xiaohan Zhang Date: Tue, 26 Sep 2023 14:20:04 -0700 Subject: [PATCH 04/59] Fix --- tests/test_util.py | 69 +++++++++++++++++++++++++--------------------- 1 file changed, 37 insertions(+), 32 deletions(-) diff --git a/tests/test_util.py b/tests/test_util.py index 6989e60c2..77fb1bb50 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -109,18 +109,24 @@ def test_clean_stale_shared_memory(): _ = BuiltinSharedMemory(name, False, 64) -@pytest.mark.parametrize( - 'folder_urls', ['local_accessible', 'remote', 'local_unaccessible', 'local_accessible_tuple']) +@pytest.mark.parametrize('folder_urls', [ + 'local_accessible', 'remote', 'local_unaccessible', 'tuple_local_accessible', + 'tuple_partial_local_accessible' +]) @pytest.mark.parametrize('out', ['local_str', 'remote_str', 'tuple']) @pytest.mark.usefixtures('local_remote_dir') 
@pytest.mark.parametrize('keep_local', [True, False]) -def test_merge_index(local_remote_dir: Tuple[str, str], dataframe: Any, keep_local: Any, - folder_urls: Any, out: Any): - """Example input urls to test - ['gs://mybucket/mdsdata/25/'...] - ['/path/never/exists/25',... ] - [('/path/never/exists/25', 'gs://mybucket/mdsdata/25/'), ...] - [('tests/resources/naive_MDSdataset/25/', 'gs://mybucket/mdsdata/25/'), ...] +def test_merge_index(local_remote_dir: Tuple[str, str], keep_local: Any, folder_urls: Any, + out: Any): + """Example input urls to test based on folder accessibility to the driver + ['pwd()/tests/resources/naive_MDSdataset/25/', ...] --> all accessible locally + ['gs://mybucket/mdsdata/25/'...] --> all remote urls, assume all accessible remotely + ['/path/never/exists/25',... ] --> none accessible locally + [('/path/never/exists/25', 'gs://mybucket/mdsdata/25/'), ...] --> all accessible remotely but not locally + [('pwd()/tests/resources/naive_MDSdataset/25/', 'gs://mybucket/mdsdata/25/'), ...] --> all accessible both locally and remotely + [('pwd()/tests/resources/naive_MDSdataset/25/', 'gs://mybucket/mdsdata/25/'), + ('/path/never/exists/25/', 'gs://mybucket/mdsdata/25/'), + ...] --> some accessible both locally and remotely, some are not """ naive_mds_partitions = [ @@ -128,23 +134,25 @@ def test_merge_index(local_remote_dir: Tuple[str, str], dataframe: Any, keep_loc 'tests/resources/naive_MDSdataset/27/' ] + # require download_file from cloud + if out == 'remote_str' or 'tuple': + return + + # require download_file from cloud + if folder_urls == 'remote' or 'tuple_partial_local_accessible': + return + if folder_urls == 'local_accessible': folder_urls = [os.getcwd() + '/' + s for s in naive_mds_partitions] print(folder_urls) - if out == 'local_str': - with tempfile.TemporaryDirectory() as tmp: - n_downloads = merge_index(folder_urls, tmp, keep_local=keep_local, overwrite=True) - if keep_local: - assert (os.path.exists(os.path.join(tmp, 'index.json'))) - else: - assert (not os.path.exists(os.path.join(tmp, 'index.json'))) - assert n_downloads == 0, f'n_downloads should be 0 instead of {n_downloads}' - else: - return - - if folder_urls == 'remote': - return + with tempfile.TemporaryDirectory() as tmp: + n_downloads = merge_index(folder_urls, tmp, keep_local=keep_local, overwrite=True) + if keep_local: + assert (os.path.exists(os.path.join(tmp, 'index.json'))) + else: + assert (not os.path.exists(os.path.join(tmp, 'index.json'))) + assert n_downloads == 0, f'n_downloads should be 0 instead of {n_downloads}' if folder_urls == 'local_unaccessible': with tempfile.TemporaryDirectory() as tmp_data_root: @@ -154,16 +162,13 @@ def test_merge_index(local_remote_dir: Tuple[str, str], dataframe: Any, keep_loc match=f'.* does not exit or cannot be acceessed by the current process.*'): merge_index(folder_urls, tmp_data_root, keep_local=keep_local, overwrite=True) - if folder_urls == 'local_accessible_tuple': + if folder_urls == 'tuple_local_accessible': folder_urls = [] for s in naive_mds_partitions: folder_urls.append((os.getcwd() + '/' + s, 'gs://mybucket/' + s)) - if out == 'local_str': - with tempfile.TemporaryDirectory() as tmp_data_root: - folder_urls = [tmp_data_root + '/' + s for s in naive_mds_partitions] - with pytest.raises( - FileNotFoundError, - match=f'.* does not exit or cannot be acceessed by the current process.*'): - merge_index(folder_urls, tmp_data_root, keep_local=keep_local, overwrite=True) - else: - return + with tempfile.TemporaryDirectory() as tmp_data_root: + 
n_downloads = merge_index(folder_urls, + tmp_data_root, + keep_local=keep_local, + overwrite=True) + assert n_downloads == 0, f'n_downloads should be 0 instead of {n_downloads}' From 2f37037d3783b8c4f66b25054c435c53206dead4 Mon Sep 17 00:00:00 2001 From: Xiaohan Zhang Date: Tue, 26 Sep 2023 14:47:16 -0700 Subject: [PATCH 05/59] Change dataframeToMDS API to use merge_util helper --- streaming/base/converters/dataframe_to_mds.py | 53 +++---------------- 1 file changed, 6 insertions(+), 47 deletions(-) diff --git a/streaming/base/converters/dataframe_to_mds.py b/streaming/base/converters/dataframe_to_mds.py index 0f5d71071..cd53975e0 100644 --- a/streaming/base/converters/dataframe_to_mds.py +++ b/streaming/base/converters/dataframe_to_mds.py @@ -13,6 +13,8 @@ import pandas as pd from streaming.base.util import get_import_exception_message +from streaming.base.util import merge_index as do_merge_index + try: from pyspark import TaskContext @@ -119,47 +121,6 @@ def map_spark_dtype(spark_data_type: Any) -> str: return schema_dict -def do_merge_index(partitions: Iterable, mds_path: Union[str, Tuple[str, str]]) -> None: - """Merge index.json from partitions into one for streaming. - - Args: - partitions (Iterable): partitions that contain pd.DataFrame - mds_path (Union[str, Tuple[str, str]]): (str,str)=(local,remote), str = local or remote - based on parse_uri(url) result - """ - if not partitions: - logger.warning('No partitions exist, no index merged') - return - - shards = [] - - for row in partitions: - mds_partition_index = f'{row.mds_path}/{get_index_basename()}' - mds_partition_basename = os.path.basename(row.mds_path) - obj = json.load(open(mds_partition_index)) - for i in range(len(obj['shards'])): - shard = obj['shards'][i] - for key in ('raw_data', 'zip_data'): - if shard.get(key): - basename = shard[key]['basename'] - obj['shards'][i][key]['basename'] = os.path.join(mds_partition_basename, - basename) - shards += obj['shards'] - - obj = { - 'version': 2, - 'shards': shards, - } - - if isinstance(mds_path, str): - mds_index = os.path.join(mds_path, get_index_basename()) - else: - mds_index = os.path.join(mds_path[0], get_index_basename()) - - with open(mds_index, 'w') as out: - json.dump(obj, out) - - def dataframeToMDS(dataframe: DataFrame, merge_index: bool = True, mds_kwargs: Optional[Dict[str, Any]] = None, @@ -203,10 +164,8 @@ def write_mds(iterator: Iterable): if isinstance(mds_path, str): # local output = os.path.join(mds_path, f'{id}') - out_file_path = output else: output = (os.path.join(mds_path[0], f'{id}'), os.path.join(mds_path[1], f'{id}')) - out_file_path = output[0] if mds_kwargs: kwargs = mds_kwargs.copy() @@ -238,7 +197,7 @@ def write_mds(iterator: Iterable): count += 1 yield pd.concat( - [pd.Series([out_file_path], name='mds_path'), + [pd.Series([out_put], name='mds_path'), pd.Series([count], name='fail_count')], axis=1) @@ -289,11 +248,11 @@ def write_mds(iterator: Iterable): partitions = dataframe.mapInPandas(func=write_mds, schema=result_schema).collect() if merge_index: - do_merge_index(partitions, mds_path) + folder_urls = [ row['mds_path'] for row in partitions ] + n_downloads = do_merge_index(folder_urls, out, keep_local = keep_local, overwrite=True) + logger.warning(f"{n_download} index files have been downloaded during index merging") if cu.remote is not None: - if merge_index: - cu.upload_file(get_index_basename()) if 'keep_local' in mds_kwargs and mds_kwargs['keep_local'] == False: shutil.rmtree(cu.local, ignore_errors=True) From 
ed9e8d05289a93363cd1ce1494843fbf6eada8cc Mon Sep 17 00:00:00 2001 From: Xiaohan Zhang Date: Tue, 26 Sep 2023 15:39:25 -0700 Subject: [PATCH 06/59] Fix unit tests --- streaming/base/converters/dataframe_to_mds.py | 17 ++-- streaming/base/util.py | 6 +- .../base/converters/test_dataframe_to_mds.py | 78 +++++++++++-------- 3 files changed, 62 insertions(+), 39 deletions(-) diff --git a/streaming/base/converters/dataframe_to_mds.py b/streaming/base/converters/dataframe_to_mds.py index cd53975e0..9b25912f1 100644 --- a/streaming/base/converters/dataframe_to_mds.py +++ b/streaming/base/converters/dataframe_to_mds.py @@ -163,9 +163,10 @@ def write_mds(iterator: Iterable): raise RuntimeError('TaskContext.get() returns None') if isinstance(mds_path, str): # local - output = os.path.join(mds_path, f'{id}') + output_path = output = os.path.join(mds_path, f'{id}') else: output = (os.path.join(mds_path[0], f'{id}'), os.path.join(mds_path[1], f'{id}')) + output_path = ','.join(output) if mds_kwargs: kwargs = mds_kwargs.copy() @@ -174,7 +175,7 @@ def write_mds(iterator: Iterable): kwargs = {} if merge_index: - kwargs['keep_local'] = True # need to keep local to do merge + kwargs['keep_local'] = True # need to keep workers' locals to do merge count = 0 @@ -197,7 +198,7 @@ def write_mds(iterator: Iterable): count += 1 yield pd.concat( - [pd.Series([out_put], name='mds_path'), + [pd.Series([output_path], name='mds_path'), pd.Series([count], name='fail_count')], axis=1) @@ -233,6 +234,7 @@ def write_mds(iterator: Iterable): out = mds_kwargs['out'] keep_local = False if 'keep_local' not in mds_kwargs else mds_kwargs['keep_local'] cu = CloudUploader.get(out, keep_local=keep_local) + print('cu.local = ', cu.local) # Fix output format as mds_path: Tuple => remote Str => local only if cu.remote is None: @@ -248,9 +250,14 @@ def write_mds(iterator: Iterable): partitions = dataframe.mapInPandas(func=write_mds, schema=result_schema).collect() if merge_index: - folder_urls = [ row['mds_path'] for row in partitions ] + folder_urls = [] + for row in partitions: + if ',' in row['mds_path']: + folder_urls.append(row['mds_path'].split(',')) + else: + folder_urls.append(row['mds_path']) n_downloads = do_merge_index(folder_urls, out, keep_local = keep_local, overwrite=True) - logger.warning(f"{n_download} index files have been downloaded during index merging") + logger.warning(f"{n_downloads} index files have been downloaded during index merging") if cu.remote is not None: if 'keep_local' in mds_kwargs and mds_kwargs['keep_local'] == False: diff --git a/streaming/base/util.py b/streaming/base/util.py index e4510451b..974eed1e2 100644 --- a/streaming/base/util.py +++ b/streaming/base/util.py @@ -262,7 +262,6 @@ def merge_index(folder_urls: List[Union[str, Tuple[str, str]]], # Determine if we need to call download_file. 
download = False for url in urls: - local = remote = url if type(url) is tuple: # If driver cannot access the local path, download = True download = not os.path.exists(url[0]) @@ -282,7 +281,10 @@ def merge_index(folder_urls: List[Union[str, Tuple[str, str]]], partitions = [] n_downloads = 0 for url in urls: - local = remote = url + if type(url) is tuple: + local, remote = url + else: + local = remote = url if download: # If download is needed, download url from remote to temp_root diff --git a/tests/base/converters/test_dataframe_to_mds.py b/tests/base/converters/test_dataframe_to_mds.py index 418fd9490..a15db2562 100644 --- a/tests/base/converters/test_dataframe_to_mds.py +++ b/tests/base/converters/test_dataframe_to_mds.py @@ -18,7 +18,7 @@ MY_PREFIX = 'train' MY_BUCKET = 'mosaicml-composer-tests' -MANUAL_INTEGRATION_TEST = False +MANUAL_INTEGRATION_TEST = True os.environ[ 'OBJC_DISABLE_INITIALIZE_FORK_SAFETY'] = 'YES' # set to yes to all fork process in spark calls @@ -27,7 +27,8 @@ def manual_integration_dir() -> Any: """Creates a temporary directory and then deletes it when the calling function is done.""" if MANUAL_INTEGRATION_TEST: - os.environ['GOOGLE_APPLICATION_CREDENTIALS'] = 'path/to/gooogle_api_credential.json' + #os.environ['GOOGLE_APPLICATION_CREDENTIALS'] = 'path/to/gooogle_api_credential.json' + os.environ['GOOGLE_APPLICATION_CREDENTIALS'] = '/Users/xiaohan.zhang/.mosaic/mosaicml-research-nonprod-027345ddbdfd.json' tmp_dir = mkdtemp() @@ -96,6 +97,8 @@ def dataframe(self): def test_end_to_end_conversion_local_nocolumns(self, dataframe: Any, keep_local: bool, merge_index: bool, local_remote_dir: Tuple[str, str]): + print('keep local = ', keep_local) + print('merge_index = ', merge_index) out, _ = local_remote_dir mds_kwargs = { 'out': out, @@ -111,23 +114,27 @@ def test_end_to_end_conversion_local_nocolumns(self, dataframe: Any, keep_local: merge_index=merge_index, mds_kwargs=mds_kwargs) - assert (len(os.listdir(out)) > 0), f'{out} is empty' - for d in os.listdir(out): - if os.path.isdir(os.path.join(out, d)): - assert (os.path.exists(os.path.join( - out, d, 'index.json'))), f'No index.json found in subdirectory {d}' + if keep_local: + assert (len(os.listdir(out)) > 0), f'{out} is empty' + for d in os.listdir(out): + if os.path.isdir(os.path.join(out, d)): + assert (os.path.exists(os.path.join( + out, d, 'index.json'))), f'No index.json found in subdirectory {d}' if merge_index: - assert (os.path.exists(os.path.join(out, 'index.json'))), 'No merged index.json found' - mgi = json.load(open(os.path.join(out, 'index.json'), 'r')) - nsamples = 0 - for d in os.listdir(out): - sub_dir = os.path.join(out, d) - if os.path.isdir(sub_dir): - shards = json.load(open(os.path.join(sub_dir, 'index.json'), 'r'))['shards'] - if shards: - nsamples += shards[0]['samples'] - assert (nsamples == sum([a['samples'] for a in mgi['shards']])) + if keep_local: + assert (os.path.exists(os.path.join(out, 'index.json'))), 'No merged index.json found' + mgi = json.load(open(os.path.join(out, 'index.json'), 'r')) + nsamples = 0 + for d in os.listdir(out): + sub_dir = os.path.join(out, d) + if os.path.isdir(sub_dir): + shards = json.load(open(os.path.join(sub_dir, 'index.json'), 'r'))['shards'] + if shards: + nsamples += shards[0]['samples'] + assert (nsamples == sum([a['samples'] for a in mgi['shards']])) + if not keep_local: + assert (not os.path.exists(os.path.join(out, 'index.json'))), 'merged index.json is found even keep_local = False' else: assert not (os.path.exists(os.path.join( out, 
'index.json'))), 'merged index is created when merge_index=False' @@ -140,6 +147,7 @@ def test_end_to_end_conversion_local_decimal(self, decimal_dataframe: Any, use_c mds_kwargs = { 'out': out, 'columns': user_defined_columns, + 'keep_local' : True } if use_columns: @@ -183,26 +191,32 @@ def test_end_to_end_conversion_local(self, dataframe: Any, keep_local: bool, mer 'hashes': ['sha1', 'xxh64'], 'size_limit': 1 << 26 } + print('keep_local = ', keep_local) + print('merge_index = ', merge_index) _, _ = dataframeToMDS(dataframe, merge_index=merge_index, mds_kwargs=mds_kwargs) - assert (len(os.listdir(out)) > 0), f'{out} is empty' - for d in os.listdir(out): - if os.path.isdir(os.path.join(out, d)): - assert (os.path.exists(os.path.join( - out, d, 'index.json'))), f'No index.json found in subdirectory {d}' + if keep_local: + assert (len(os.listdir(out)) > 0), f'{out} is empty' + for d in os.listdir(out): + if os.path.isdir(os.path.join(out, d)): + assert (os.path.exists(os.path.join( + out, d, 'index.json'))), f'No index.json found in subdirectory {d}' if merge_index == True: - assert (os.path.exists(os.path.join(out, 'index.json'))), 'No merged index.json found' - mgi = json.load(open(os.path.join(out, 'index.json'), 'r')) - nsamples = 0 - for d in os.listdir(out): - sub_dir = os.path.join(out, d) - if os.path.isdir(sub_dir): - shards = json.load(open(os.path.join(sub_dir, 'index.json'), 'r'))['shards'] - if shards: - nsamples += shards[0]['samples'] - assert (nsamples == sum([a['samples'] for a in mgi['shards']])) + if keep_local: + assert (os.path.exists(os.path.join(out, 'index.json'))), 'No merged index.json found' + mgi = json.load(open(os.path.join(out, 'index.json'), 'r')) + nsamples = 0 + for d in os.listdir(out): + sub_dir = os.path.join(out, d) + if os.path.isdir(sub_dir): + shards = json.load(open(os.path.join(sub_dir, 'index.json'), 'r'))['shards'] + if shards: + nsamples += shards[0]['samples'] + assert (nsamples == sum([a['samples'] for a in mgi['shards']])) + else: + assert (not os.path.exists(os.path.join(out, 'index.json'))), 'merged index.json is found even keep_local=False' else: assert not (os.path.exists(os.path.join( out, 'index.json'))), 'merged index is created when merge_index=False' From af9b6dd020ffe78db080936e2bd91c3297282e77 Mon Sep 17 00:00:00 2001 From: Xiaohan Zhang Date: Tue, 26 Sep 2023 15:45:51 -0700 Subject: [PATCH 07/59] Fix tests --- tests/base/converters/test_dataframe_to_mds.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/tests/base/converters/test_dataframe_to_mds.py b/tests/base/converters/test_dataframe_to_mds.py index a15db2562..7a2894971 100644 --- a/tests/base/converters/test_dataframe_to_mds.py +++ b/tests/base/converters/test_dataframe_to_mds.py @@ -18,7 +18,7 @@ MY_PREFIX = 'train' MY_BUCKET = 'mosaicml-composer-tests' -MANUAL_INTEGRATION_TEST = True +MANUAL_INTEGRATION_TEST = False os.environ[ 'OBJC_DISABLE_INITIALIZE_FORK_SAFETY'] = 'YES' # set to yes to all fork process in spark calls @@ -27,8 +27,7 @@ def manual_integration_dir() -> Any: """Creates a temporary directory and then deletes it when the calling function is done.""" if MANUAL_INTEGRATION_TEST: - #os.environ['GOOGLE_APPLICATION_CREDENTIALS'] = 'path/to/gooogle_api_credential.json' - os.environ['GOOGLE_APPLICATION_CREDENTIALS'] = '/Users/xiaohan.zhang/.mosaic/mosaicml-research-nonprod-027345ddbdfd.json' + os.environ['GOOGLE_APPLICATION_CREDENTIALS'] = 'path/to/gooogle_api_credential.json' tmp_dir = mkdtemp() From c45ceb963d4b728968178b3ce71185f5b821a150 
Mon Sep 17 00:00:00 2001 From: Xiaohan Zhang Date: Tue, 26 Sep 2023 15:47:06 -0700 Subject: [PATCH 08/59] Fix lints --- streaming/base/converters/dataframe_to_mds.py | 9 +++---- .../base/converters/test_dataframe_to_mds.py | 24 ++++++++++--------- 2 files changed, 16 insertions(+), 17 deletions(-) diff --git a/streaming/base/converters/dataframe_to_mds.py b/streaming/base/converters/dataframe_to_mds.py index 9b25912f1..4f4b551f3 100644 --- a/streaming/base/converters/dataframe_to_mds.py +++ b/streaming/base/converters/dataframe_to_mds.py @@ -3,19 +3,17 @@ """A utility to convert spark dataframe to MDS.""" -import json import logging import os import shutil from collections.abc import Iterable -from typing import Any, Callable, Dict, Iterable, Optional, Tuple, Union +from typing import Any, Callable, Dict, Iterable, Optional, Tuple import pandas as pd from streaming.base.util import get_import_exception_message from streaming.base.util import merge_index as do_merge_index - try: from pyspark import TaskContext from pyspark.sql.dataframe import DataFrame @@ -28,7 +26,6 @@ raise e from streaming import MDSWriter -from streaming.base.format.index import get_index_basename from streaming.base.format.mds.encodings import _encodings from streaming.base.storage.upload import CloudUploader @@ -256,8 +253,8 @@ def write_mds(iterator: Iterable): folder_urls.append(row['mds_path'].split(',')) else: folder_urls.append(row['mds_path']) - n_downloads = do_merge_index(folder_urls, out, keep_local = keep_local, overwrite=True) - logger.warning(f"{n_downloads} index files have been downloaded during index merging") + n_downloads = do_merge_index(folder_urls, out, keep_local=keep_local, overwrite=True) + logger.warning(f'{n_downloads} index files have been downloaded during index merging') if cu.remote is not None: if 'keep_local' in mds_kwargs and mds_kwargs['keep_local'] == False: diff --git a/tests/base/converters/test_dataframe_to_mds.py b/tests/base/converters/test_dataframe_to_mds.py index 7a2894971..a0689460d 100644 --- a/tests/base/converters/test_dataframe_to_mds.py +++ b/tests/base/converters/test_dataframe_to_mds.py @@ -122,18 +122,21 @@ def test_end_to_end_conversion_local_nocolumns(self, dataframe: Any, keep_local: if merge_index: if keep_local: - assert (os.path.exists(os.path.join(out, 'index.json'))), 'No merged index.json found' + assert (os.path.exists(os.path.join(out, + 'index.json'))), 'No merged index.json found' mgi = json.load(open(os.path.join(out, 'index.json'), 'r')) nsamples = 0 for d in os.listdir(out): sub_dir = os.path.join(out, d) if os.path.isdir(sub_dir): - shards = json.load(open(os.path.join(sub_dir, 'index.json'), 'r'))['shards'] + shards = json.load(open(os.path.join(sub_dir, 'index.json'), + 'r'))['shards'] if shards: nsamples += shards[0]['samples'] assert (nsamples == sum([a['samples'] for a in mgi['shards']])) if not keep_local: - assert (not os.path.exists(os.path.join(out, 'index.json'))), 'merged index.json is found even keep_local = False' + assert (not os.path.exists(os.path.join( + out, 'index.json'))), 'merged index.json is found even keep_local = False' else: assert not (os.path.exists(os.path.join( out, 'index.json'))), 'merged index is created when merge_index=False' @@ -143,11 +146,7 @@ def test_end_to_end_conversion_local_decimal(self, decimal_dataframe: Any, use_c local_remote_dir: Tuple[str, str]): out, _ = local_remote_dir user_defined_columns = {'id': 'int', 'name': 'str', 'amount': 'str_decimal'} - mds_kwargs = { - 'out': out, - 'columns': 
user_defined_columns, - 'keep_local' : True - } + mds_kwargs = {'out': out, 'columns': user_defined_columns, 'keep_local': True} if use_columns: mds_kwargs['columns'] = user_defined_columns @@ -204,18 +203,21 @@ def test_end_to_end_conversion_local(self, dataframe: Any, keep_local: bool, mer if merge_index == True: if keep_local: - assert (os.path.exists(os.path.join(out, 'index.json'))), 'No merged index.json found' + assert (os.path.exists(os.path.join(out, + 'index.json'))), 'No merged index.json found' mgi = json.load(open(os.path.join(out, 'index.json'), 'r')) nsamples = 0 for d in os.listdir(out): sub_dir = os.path.join(out, d) if os.path.isdir(sub_dir): - shards = json.load(open(os.path.join(sub_dir, 'index.json'), 'r'))['shards'] + shards = json.load(open(os.path.join(sub_dir, 'index.json'), + 'r'))['shards'] if shards: nsamples += shards[0]['samples'] assert (nsamples == sum([a['samples'] for a in mgi['shards']])) else: - assert (not os.path.exists(os.path.join(out, 'index.json'))), 'merged index.json is found even keep_local=False' + assert (not os.path.exists(os.path.join( + out, 'index.json'))), 'merged index.json is found even keep_local=False' else: assert not (os.path.exists(os.path.join( out, 'index.json'))), 'merged index is created when merge_index=False' From 8b7db391f7c1e54f01a935f6ebc0174ee869a936 Mon Sep 17 00:00:00 2001 From: Xiaohan Zhang Date: Thu, 28 Sep 2023 11:52:28 -0700 Subject: [PATCH 09/59] Address a few comments --- streaming/base/util.py | 32 +++++++------------ .../base/converters/test_dataframe_to_mds.py | 4 --- tests/test_util.py | 12 +++---- 3 files changed, 17 insertions(+), 31 deletions(-) diff --git a/streaming/base/util.py b/streaming/base/util.py index 974eed1e2..d6d9fc652 100644 --- a/streaming/base/util.py +++ b/streaming/base/util.py @@ -206,16 +206,15 @@ def get_import_exception_message(package_name: str, extra_deps: str) -> str: f'`pip install \'mosaicml-streaming[{package_name}]\'`.' -def merge_index(folder_urls: List[Union[str, Tuple[str, str]]], - out: Union[str, Tuple[str, str]], +def merge_index(root_to_MDSdataset: Union[str, Tuple[str, str]], *, keep_local: bool = True, overwrite: bool = True, - download_timeout: int = 100) -> int: + download_timeout: int = 60) -> int: """Merge index.json from a list of remote or local directories. Args: - folder_urls (Iterable): folders that contain index.json for the partition + root_to_MDSdataset (Union[str, Tuple[str,str]]): folders that contain index.json for the partition each element can take the form of a single path string or a tuple string for each url in folder_urls, if url is @@ -228,12 +227,8 @@ def merge_index(folder_urls: List[Union[str, Tuple[str, str]]], out (Union[str, Tuple[str, str]]): path to put the merged index file keep_local (bool): Keep local copy of the merged index file. Defaults to ``True`` - overwrite (bool): Overwrite merged index file in out if there exists one.Defaults to ``True`` - download_timeout (int): The allowed time for downloading each json file - defaults to 60, same as streaming.download_file - - Returns: - int: count of files downloaded during function call + overwrite (bool): Overwrite merged index file in out if there exists one. Defaults to ``True`` + download_timeout (int): The allowed time for downloading each json file. Defaults to 60s. 
""" # Import here to avoid circular import error from streaming.base.storage.download import download_file @@ -254,7 +249,7 @@ def merge_index(folder_urls: List[Union[str, Tuple[str, str]]], # Remove '/' from right, so os.path.basename gives relative path to each folder urls = [] for url in folder_urls: - if type(url) is str: + if isinstance(url, str): urls.append(url.rstrip('/').strip()) else: urls.append((url[0].rstrip('/').strip(), url[1].rstrip('/').strip())) @@ -262,7 +257,7 @@ def merge_index(folder_urls: List[Union[str, Tuple[str, str]]], # Determine if we need to call download_file. download = False for url in urls: - if type(url) is tuple: + if isinstance(url, tuple): # If driver cannot access the local path, download = True download = not os.path.exists(url[0]) else: @@ -273,15 +268,14 @@ def merge_index(folder_urls: List[Union[str, Tuple[str, str]]], if download: break - print('download = ', download) - # Prepare a temp folder to download index.json rom remote if necessary. Removed in the end. + # Prepare a temp folder to download index.json from remote if necessary. Removed in the end. with tempfile.TemporaryDirectory() as temp_root: + logging.warning(f"Create a temporary folder {temp_root} to store index files") # container for absolute local folder path partitions = [] - n_downloads = 0 for url in urls: - if type(url) is tuple: + if isinstance(url, tuple): local, remote = url else: local = remote = url @@ -294,13 +288,12 @@ def merge_index(folder_urls: List[Union[str, Tuple[str, str]]], remote_url = os.path.join(remote, index_basename) local_path = os.path.join(local, index_basename) download_file(remote_url, local_path, download_timeout) - n_downloads += 1 except Exception as ex: - raise RuntimeError(f'failed to download index.json {url}') from ex + raise RuntimeError(f'Failed to download index.json {url}') from ex if not (os.path.exists(local)): raise FileNotFoundError( - 'Folder {local} does not exit or cannot be acceessed by the current process') + 'Folder {local} does not exit or cannot be acceessed.') partitions.append(local) # merge index files into shards @@ -338,4 +331,3 @@ def merge_index(folder_urls: List[Union[str, Tuple[str, str]]], if not keep_local: shutil.rmtree(cu.local, ignore_errors=True) - return n_downloads diff --git a/tests/base/converters/test_dataframe_to_mds.py b/tests/base/converters/test_dataframe_to_mds.py index a0689460d..2dacdb0a4 100644 --- a/tests/base/converters/test_dataframe_to_mds.py +++ b/tests/base/converters/test_dataframe_to_mds.py @@ -96,8 +96,6 @@ def dataframe(self): def test_end_to_end_conversion_local_nocolumns(self, dataframe: Any, keep_local: bool, merge_index: bool, local_remote_dir: Tuple[str, str]): - print('keep local = ', keep_local) - print('merge_index = ', merge_index) out, _ = local_remote_dir mds_kwargs = { 'out': out, @@ -189,8 +187,6 @@ def test_end_to_end_conversion_local(self, dataframe: Any, keep_local: bool, mer 'hashes': ['sha1', 'xxh64'], 'size_limit': 1 << 26 } - print('keep_local = ', keep_local) - print('merge_index = ', merge_index) _, _ = dataframeToMDS(dataframe, merge_index=merge_index, mds_kwargs=mds_kwargs) diff --git a/tests/test_util.py b/tests/test_util.py index 77fb1bb50..5ec0595f4 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -147,12 +147,11 @@ def test_merge_index(local_remote_dir: Tuple[str, str], keep_local: Any, folder_ print(folder_urls) with tempfile.TemporaryDirectory() as tmp: - n_downloads = merge_index(folder_urls, tmp, keep_local=keep_local, overwrite=True) + 
merge_index(folder_urls, tmp, keep_local=keep_local, overwrite=True) if keep_local: assert (os.path.exists(os.path.join(tmp, 'index.json'))) else: assert (not os.path.exists(os.path.join(tmp, 'index.json'))) - assert n_downloads == 0, f'n_downloads should be 0 instead of {n_downloads}' if folder_urls == 'local_unaccessible': with tempfile.TemporaryDirectory() as tmp_data_root: @@ -167,8 +166,7 @@ def test_merge_index(local_remote_dir: Tuple[str, str], keep_local: Any, folder_ for s in naive_mds_partitions: folder_urls.append((os.getcwd() + '/' + s, 'gs://mybucket/' + s)) with tempfile.TemporaryDirectory() as tmp_data_root: - n_downloads = merge_index(folder_urls, - tmp_data_root, - keep_local=keep_local, - overwrite=True) - assert n_downloads == 0, f'n_downloads should be 0 instead of {n_downloads}' + merge_index(folder_urls, + tmp_data_root, + keep_local=keep_local, + overwrite=True) From 091518cf8a40dd9ec6054528235f9f1fff685201 Mon Sep 17 00:00:00 2001 From: Xiaohan Zhang Date: Thu, 28 Sep 2023 21:35:52 -0700 Subject: [PATCH 10/59] update --- streaming/base/util.py | 59 +++++++++++++++++++++++++++++++++++------- 1 file changed, 50 insertions(+), 9 deletions(-) diff --git a/streaming/base/util.py b/streaming/base/util.py index d6d9fc652..c11cffca3 100644 --- a/streaming/base/util.py +++ b/streaming/base/util.py @@ -20,8 +20,14 @@ from streaming.base.format.index import get_index_basename from streaming.base.shared.prefix import _get_path +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from streaming.base.storage.download import download_file + from streaming.base.storage.upload import CloudUploader + logger = logging.getLogger(__name__) + __all__ = ['get_list_arg'] @@ -206,11 +212,13 @@ def get_import_exception_message(package_name: str, extra_deps: str) -> str: f'`pip install \'mosaicml-streaming[{package_name}]\'`.' -def merge_index(root_to_MDSdataset: Union[str, Tuple[str, str]], - *, - keep_local: bool = True, - overwrite: bool = True, - download_timeout: int = 60) -> int: +def do_merge_index(folder_urls: List[Union[str, Tuple[str, str]]], + *, + out: Union[str, Tuple[str,str]], + keep_local: bool = True, + overwrite: bool = True, + download_timeout: int = 60) -> int: + """Merge index.json from a list of remote or local directories. Args: @@ -230,10 +238,6 @@ def merge_index(root_to_MDSdataset: Union[str, Tuple[str, str]], overwrite (bool): Overwrite merged index file in out if there exists one. Defaults to ``True`` download_timeout (int): The allowed time for downloading each json file. Defaults to 60s. """ - # Import here to avoid circular import error - from streaming.base.storage.download import download_file - from streaming.base.storage.upload import CloudUploader - if not folder_urls: logger.warning('No partitions exist, no index merged') return 0 @@ -331,3 +335,40 @@ def merge_index(root_to_MDSdataset: Union[str, Tuple[str, str]], if not keep_local: shutil.rmtree(cu.local, ignore_errors=True) +def merge_index(root_to_mds: Union[str, Tuple[str, str]], + *, + keep_local: bool = True, + overwrite: bool = True) -> int: + """Merge index.json of a MDS dataset from the subdirectories of root_to_mds and store in root_to_mds + + Args: + root_to_MDSdataset (Union[str, Tuple[str,str]]): folders that contain MDS partitions. + It can be local str or remote str or (local, remote) + keep_local (bool): Keep local copy of the merged index file. Defaults to ``True`` + overwrite (bool): Overwrite merged index file in out if there exists one. 
Defaults to ``True`` + download_timeout (int): The allowed time for downloading each json file. Defaults to 60s. + """ + + if not root_to_mds: + logger.warning('No MDS dataset folder specified, no index merged') + return + + if isinstance(root_to_mds, tuple): + local_folders = list_objects(root_to_mds[0]) + remote_folders = list_objects(root_to_mds[1]) + folder_urls = list(zip(local_folders, remote_folders)) + else: + folder_urls = list_objects(root_to_mds) + + do_merge_index(folder_urls, + root_to_mds, + keep_local=keep_local, + overwrite=overwrite, + download_timeout = 60) + + + + + + + From b0219c40fdd4b993ae0f0f6018a78fc751d68b7f Mon Sep 17 00:00:00 2001 From: Xiaohan Zhang Date: Thu, 28 Sep 2023 22:11:57 -0700 Subject: [PATCH 11/59] updates --- streaming/base/storage/upload.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/streaming/base/storage/upload.py b/streaming/base/storage/upload.py index 47d794fd4..3a0e1885b 100644 --- a/streaming/base/storage/upload.py +++ b/streaming/base/storage/upload.py @@ -75,7 +75,7 @@ def get( shard file or remove it after uploading. Defaults to ``False``. progress_bar (bool): Display TQDM progress bars for uploading output dataset files to a remote location. Default to ``False``. - exist_ok (bool): Throw error if out already exists and not empty. Defaults to ``False``. + exist_ok (bool): When exist_ok = False, raise error if out already exists and has contents. Defaults to ``False``. Returns: CloudUploader: An instance of sub-class. @@ -139,7 +139,7 @@ def __init__(self, shard file or remove it after uploading. Defaults to ``False``. progress_bar (bool): Display TQDM progress bars for uploading output dataset files to a remote location. Default to ``False``. - exist_ok (bool): Throw error if out already exists and has contents. Defaults to ``False``. + exist_ok (bool): When exist_ok = False, raise error if out already exists and has contents. Defaults to ``False``. Raises: FileExistsError: Local directory must be empty. @@ -202,7 +202,7 @@ class S3Uploader(CloudUploader): shard file or remove it after uploading. Defaults to ``False``. progress_bar (bool): Display TQDM progress bars for uploading output dataset files to a remote location. Default to ``False``. - exist_ok (bool): Throw error if out already exists and not empty. Defaults to ``False``. + exist_ok (bool): When exist_ok = False, raise error if out already exists and has contents. Defaults to ``False``. """ def __init__(self, @@ -285,7 +285,7 @@ class GCSUploader(CloudUploader): shard file or remove it after uploading. Defaults to ``False``. progress_bar (bool): Display TQDM progress bars for uploading output dataset files to a remote location. Default to ``False``. - exist_ok (bool): Throw error if out already exists and not empty. Defaults to ``False``. + exist_ok (bool): When exist_ok = False, raise error if out already exists and has contents. Defaults to ``False``. """ def __init__(self, @@ -396,7 +396,7 @@ class OCIUploader(CloudUploader): shard file or remove it after uploading. Defaults to ``False``. progress_bar (bool): Display TQDM progress bars for uploading output dataset files to a remote location. Default to ``False``. - exist_ok (bool): Throw error if out already exists and not empty. Defaults to ``False``. + exist_ok (bool): When exist_ok = False, raise error if out already exists and has contents. Defaults to ``False``. 
""" def __init__(self, @@ -561,7 +561,7 @@ class AzureDataLakeUploader(CloudUploader): shard file or remove it after uploading. Defaults to ``False``. progress_bar (bool): Display TQDM progress bars for uploading output dataset files to a remote location. Default to ``False``. - exist_ok (bool): Throw error if out already exists and not empty. Defaults to ``False``. + exist_ok (bool): When exist_ok = False, raise error if out already exists and has contents. Defaults to ``False``. """ def __init__(self, @@ -638,7 +638,7 @@ class DatabricksUploader(CloudUploader): shard file or remove it after uploading. Defaults to ``False``. progress_bar (bool): Display TQDM progress bars for uploading output dataset files to a remote location. Default to ``False``. - exist_ok (bool): Throw error if out already exists and not empty. Defaults to ``False``. + exist_ok (bool): When exist_ok = False, raise error if out already exists and has contents. Defaults to ``False``. """ def __init__(self, @@ -674,7 +674,7 @@ class DatabricksUnityCatalogUploader(DatabricksUploader): shard file or remove it after uploading. Defaults to ``False``. progress_bar (bool): Display TQDM progress bars for uploading output dataset files to a remote location. Default to ``False``. - exist_ok (bool): Throw error if out already exists and not empty. Defaults to ``False``. + exist_ok (bool): When exist_ok = False, raise error if out already exists and has contents. Defaults to ``False``. """ def __init__(self, @@ -715,7 +715,7 @@ class DBFSUploader(DatabricksUploader): shard file or remove it after uploading. Defaults to ``False``. progress_bar (bool): Display TQDM progress bars for uploading output dataset files to a remote location. Default to ``False``. - exist_ok (bool): Throw error if out already exists and not empty. Defaults to ``False``. + exist_ok (bool): When exist_ok = False, raise error if out already exists and has contents. Defaults to ``False``. """ def __init__(self, @@ -777,7 +777,7 @@ class LocalUploader(CloudUploader): shard file or remove it after uploading. Defaults to ``False``. progress_bar (bool): Display TQDM progress bars for uploading output dataset files to a remote location. Default to ``False``. - exist_ok (bool): Throw error if out already exists and not empty. Defaults to ``False``. + exist_ok (bool): When exist_ok = False, raise error if out already exists and has contents. Defaults to ``False``. 
""" def __init__(self, From 1678844bd3ccab1216aa910cec9de3e8070fdf52 Mon Sep 17 00:00:00 2001 From: Xiaohan Zhang Date: Thu, 28 Sep 2023 23:03:23 -0700 Subject: [PATCH 12/59] Address comments --- streaming/base/util.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/streaming/base/util.py b/streaming/base/util.py index 462c392d9..3779c55bb 100644 --- a/streaming/base/util.py +++ b/streaming/base/util.py @@ -314,7 +314,7 @@ def do_merge_index(folder_urls: List[Union[str, Tuple[str, str]]], obj = json.load(open(partition_index)) for i in range(len(obj['shards'])): shard = obj['shards'][i] - for key in ('raw_data', 'zip_data'): + for key in ('raw_data', 'zip_data', 'raw_meta', 'zip_meta'): if shard.get(key): basename = shard[key]['basename'] obj['shards'][i][key]['basename'] = os.path.join( @@ -337,7 +337,6 @@ def do_merge_index(folder_urls: List[Union[str, Tuple[str, str]]], cu.upload_file(index_basename) # Clean up - # shutil.rmtree(temp_root, ignore_errors=True) if not keep_local: shutil.rmtree(cu.local, ignore_errors=True) From 400050ed126ddd15c0af1534c1077dc2db302391 Mon Sep 17 00:00:00 2001 From: Xiaohan Zhang Date: Fri, 29 Sep 2023 00:23:50 -0700 Subject: [PATCH 13/59] update unit tests --- streaming/base/storage/download.py | 34 +++++++++++ streaming/base/storage/upload.py | 4 +- streaming/base/util.py | 61 +++++++++---------- tests/test_util.py | 95 +++++++++++++++--------------- 4 files changed, 110 insertions(+), 84 deletions(-) diff --git a/streaming/base/storage/download.py b/streaming/base/storage/download.py index 9db4af328..70426dd1a 100644 --- a/streaming/base/storage/download.py +++ b/streaming/base/storage/download.py @@ -489,3 +489,37 @@ def wait_for_download(local: str, timeout: float = 60) -> None: raise TimeoutError( f'Waited longer than {timeout}s for other worker to download {local}.') sleep(0.25) + + +def list_objects(remote: Optional[str]): + """Use the correct cloud handler to list objects. + + Args: + remote (str, optional): Remote path (local filesystem). 
+ If remote is None or '', list current working directory with os.listdir() + """ + # fix paths for windows + if remote: + remote = remote.replace('\\', '/') + + if not remote: # '' or None + return os.listdir() + elif remote.startswith('s3://'): + list_objects_from_s3(remote) + elif remote.startswith('sftp://'): + raise NotImplemented('list_objects for sftp not supported') + elif remote.startswith('gs://'): + list_objects_from_gcs(remote, local) + elif remote.startswith('oci://'): + list_objects_from_oci(remote, local) + elif remote.startswith('azure://'): + raise NotImplemented('list_objects for azure not supported') + elif remote.startswith('azure-dl://'): + raise NotImplemented('list_objects for azure-dl not supported') + elif remote.startswith('dbfs:/Volumes'): + raise NotImplemented('list_objects for dbfs:/Volumes not supported') + elif remote.startswith('dbfs:/'): + raise NotImplemented('list_objects for dbfs:/ not supported') + else: + raise ValueError("remote scheme is not recognizable") + diff --git a/streaming/base/storage/upload.py b/streaming/base/storage/upload.py index 3567ea61a..e5ac2eb6b 100644 --- a/streaming/base/storage/upload.py +++ b/streaming/base/storage/upload.py @@ -89,8 +89,8 @@ def get( prefix = os.path.join(path.parts[0], path.parts[1]) if prefix == 'dbfs:/Volumes': provider_prefix = prefix - return getattr(sys.modules[__name__], UPLOADERS[provider_prefix])(out, keep_local, - progress_bar, retry, exist_ok) + return getattr(sys.modules[__name__], + UPLOADERS[provider_prefix])(out, keep_local, progress_bar, retry, exist_ok) def _validate(self, out: Union[str, Tuple[str, str]]) -> None: """Validate the `out` argument. diff --git a/streaming/base/util.py b/streaming/base/util.py index 3779c55bb..a58b765fe 100644 --- a/streaming/base/util.py +++ b/streaming/base/util.py @@ -5,16 +5,17 @@ import collections.abc import functools +import json import logging import os import random -from multiprocessing.shared_memory import SharedMemory as BuiltinSharedMemory -from time import sleep, time -from typing import Any, Callable, List, Sequence, Type, TypeVar, Union, cast, overload, Tuple -import json import shutil import tempfile import urllib.parse +from multiprocessing.shared_memory import SharedMemory as BuiltinSharedMemory +from time import sleep, time +from typing import (TYPE_CHECKING, Any, Callable, List, Sequence, Tuple, Type, TypeVar, Union, + cast, overload, Union) import torch.distributed as dist @@ -23,10 +24,6 @@ from streaming.base.format.index import get_index_basename from streaming.base.shared.prefix import _get_path -from typing import TYPE_CHECKING -if TYPE_CHECKING: - from streaming.base.storage.download import download_file - from streaming.base.storage.upload import CloudUploader logger = logging.getLogger(__name__) @@ -37,6 +34,7 @@ 'clean_stale_shared_memory', 'get_import_exception_message', 'retry' ] + def get_list_arg(text: str) -> List[str]: """Pass a list as a command-line flag. @@ -219,42 +217,39 @@ def get_import_exception_message(package_name: str, extra_deps: str) -> str: def do_merge_index(folder_urls: List[Union[str, Tuple[str, str]]], + out: Union[str, Tuple[str, str]], *, - out: Union[str, Tuple[str,str]], keep_local: bool = True, - overwrite: bool = True, - download_timeout: int = 60) -> int: - + download_timeout: int = 60) -> None: """Merge index.json from a list of remote or local directories. + Write merged index file to `out`. Overwrite if an old file exists. 
Args: - root_to_MDSdataset (Union[str, Tuple[str,str]]): folders that contain index.json for the partition - each element can take the form of a single path string or a tuple string + folder_urls (Union[str, Tuple[str,str]]): folders that contain index.json for the partition + each element can take the form of a single path string or a tuple string. - for each url in folder_urls, if url is - 1. tuple (local, remote): check if local is accessible. - -> Yes: use local index to merge - -> No: download from remote first, then merge - 2. str (local path): use local path to merge. - raise FileNotFoundError if any local index is not accessible - 3. str (remote url): download to a temp directory first, then merge + The pattern of folder_urls and corresponding reaction is one of: + 1. All urls are str (local). All urls are accessible locally -> no download + 2. All urls are tuple (local, remote). All urls are accessible locally -> no download + 3. All urls are tuple (local, remote). At least one url is not accessible locally -> download all + 4. All urls are str (remote) -> download all out (Union[str, Tuple[str, str]]): path to put the merged index file keep_local (bool): Keep local copy of the merged index file. Defaults to ``True`` - overwrite (bool): Overwrite merged index file in out if there exists one. Defaults to ``True`` download_timeout (int): The allowed time for downloading each json file. Defaults to 60s. """ - if not folder_urls: - logger.warning('No partitions exist, no index merged') + + from streaming.base.storage.download import download_file, list_objects + from streaming.base.storage.upload import CloudUploader + + if not folder_urls or not out: + logger.warning('Need to specify both folder_urls and out. No index merged') return 0 # This is the index json file name, e.g., it is index.json as of 0.6.0 index_basename = get_index_basename() cu = CloudUploader.get(out, keep_local=True, exist_ok=True) - if os.path.exists(os.path.join(cu.local, index_basename)) and overwrite: - logger.warning('Merged index already exists locally. no index merged if overwrite=False') - return 0 # Remove '/' from right, so os.path.basename gives relative path to each folder urls = [] @@ -280,7 +275,7 @@ def do_merge_index(folder_urls: List[Union[str, Tuple[str, str]]], # Prepare a temp folder to download index.json from remote if necessary. Removed in the end. with tempfile.TemporaryDirectory() as temp_root: - logging.warning(f"Create a temporary folder {temp_root} to store index files") + logging.warning(f'Create a temporary folder {temp_root} to store index files') # container for absolute local folder path partitions = [] @@ -302,8 +297,7 @@ def do_merge_index(folder_urls: List[Union[str, Tuple[str, str]]], raise RuntimeError(f'Failed to download index.json {url}') from ex if not (os.path.exists(local)): - raise FileNotFoundError( - 'Folder {local} does not exit or cannot be acceessed.') + raise FileNotFoundError('Folder {local} does not exist or not accessible.') partitions.append(local) # merge index files into shards @@ -340,11 +334,13 @@ def do_merge_index(folder_urls: List[Union[str, Tuple[str, str]]], if not keep_local: shutil.rmtree(cu.local, ignore_errors=True) + def merge_index(root_to_mds: Union[str, Tuple[str, str]], *, keep_local: bool = True, overwrite: bool = True) -> int: - """Merge index.json of a MDS dataset from the subdirectories of root_to_mds and store in root_to_mds + """Merge index.json of a MDS dataset from the subdirectories of root_to_mds and store in + root_to_mds. 
Args: root_to_MDSdataset (Union[str, Tuple[str,str]]): folders that contain MDS partitions. @@ -368,8 +364,7 @@ def merge_index(root_to_mds: Union[str, Tuple[str, str]], do_merge_index(folder_urls, root_to_mds, keep_local=keep_local, - overwrite=overwrite, - download_timeout = 60) + download_timeout=60) return diff --git a/tests/test_util.py b/tests/test_util.py index c7dd92179..8169d0a43 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -2,16 +2,17 @@ # SPDX-License-Identifier: Apache-2.0 import os +import json import tempfile from multiprocessing.shared_memory import SharedMemory as BuiltinSharedMemory -from typing import Any, List, Optional, Tuple +from typing import Any, List, Optional, Tuple, Union import pytest from streaming.base.constant import RESUME from streaming.base.shared.prefix import _get_path from streaming.base.util import (bytes_to_int, clean_stale_shared_memory, get_list_arg, - number_abbrev_to_int, merge_index, retry) + merge_index, do_merge_index, number_abbrev_to_int, retry) @pytest.mark.parametrize(('text', 'expected_output'), [('hello,world', ['hello', 'world']), @@ -109,24 +110,18 @@ def test_clean_stale_shared_memory(): _ = BuiltinSharedMemory(name, False, 64) -@pytest.mark.parametrize('folder_urls', [ - 'local_accessible', 'remote', 'local_unaccessible', 'tuple_local_accessible', - 'tuple_partial_local_accessible' -]) -@pytest.mark.parametrize('out', ['local_str', 'remote_str', 'tuple']) +@pytest.mark.parametrize('folder_urls_pattern', [1, 2, 3, 4, 5]) @pytest.mark.usefixtures('local_remote_dir') @pytest.mark.parametrize('keep_local', [True, False]) -def test_merge_index(local_remote_dir: Tuple[str, str], keep_local: Any, folder_urls: Any, - out: Any): - """Example input urls to test based on folder accessibility to the driver - ['pwd()/tests/resources/naive_MDSdataset/25/', ...] --> all accessible locally - ['gs://mybucket/mdsdata/25/'...] --> all remote urls, assume all accessible remotely - ['/path/never/exists/25',... ] --> none accessible locally - [('/path/never/exists/25', 'gs://mybucket/mdsdata/25/'), ...] --> all accessible remotely but not locally - [('pwd()/tests/resources/naive_MDSdataset/25/', 'gs://mybucket/mdsdata/25/'), ...] --> all accessible both locally and remotely - [('pwd()/tests/resources/naive_MDSdataset/25/', 'gs://mybucket/mdsdata/25/'), - ('/path/never/exists/25/', 'gs://mybucket/mdsdata/25/'), - ...] --> some accessible both locally and remotely, some are not +def test_do_merge_index(local_remote_dir: Tuple[str, str], + keep_local: bool, + folder_urls_pattern: int): + """Validate the final merge index json for following patterns of folder_urls: + 1. All urls are str (local). All urls are accessible locally -> no download + 2. All urls are str (local). At least one url is unaccessible locally -> Error + 3. All urls are tuple (local, remote). All urls are accessible locally -> no download + 4. All urls are tuple (local, remote). At least one url is not accessible locally -> download all + 5. All urls are str (remote) -> download all """ naive_mds_partitions = [ @@ -134,42 +129,44 @@ def test_merge_index(local_remote_dir: Tuple[str, str], keep_local: Any, folder_ 'tests/resources/naive_MDSdataset/27/' ] - # require download_file from cloud - if out == 'remote_str' or 'tuple': + if folder_urls_pattern in [4,5]: + # Require cloud file transfers. Will be covered by integration tests. 
return - # require download_file from cloud - if folder_urls == 'remote' or 'tuple_partial_local_accessible': - return - - if folder_urls == 'local_accessible': - folder_urls = [os.getcwd() + '/' + s for s in naive_mds_partitions] - print(folder_urls) + with tempfile.TemporaryDirectory() as out: + if folder_urls_pattern == 1: + folder_urls = [os.getcwd() + '/' + s for s in naive_mds_partitions] + do_merge_index(folder_urls, out, keep_local=keep_local) - with tempfile.TemporaryDirectory() as tmp: - merge_index(folder_urls, tmp, keep_local=keep_local, overwrite=True) - if keep_local: - assert (os.path.exists(os.path.join(tmp, 'index.json'))) - else: - assert (not os.path.exists(os.path.join(tmp, 'index.json'))) - if folder_urls == 'local_unaccessible': - with tempfile.TemporaryDirectory() as tmp_data_root: - folder_urls = [tmp_data_root + '/' + s for s in naive_mds_partitions] + if folder_urls_pattern == 2: + folder_urls = [out + '/' + s for s in naive_mds_partitions] with pytest.raises( - FileNotFoundError, - match=f'.* does not exit or cannot be acceessed by the current process.*'): - merge_index(folder_urls, tmp_data_root, keep_local=keep_local, overwrite=True) - - if folder_urls == 'tuple_local_accessible': - folder_urls = [] - for s in naive_mds_partitions: - folder_urls.append((os.getcwd() + '/' + s, 'gs://mybucket/' + s)) - with tempfile.TemporaryDirectory() as tmp_data_root: - merge_index(folder_urls, - tmp_data_root, - keep_local=keep_local, - overwrite=True) + FileNotFoundError, + match=f'.* does not exist or not accessible.*'): + do_merge_index(folder_urls, out, keep_local=keep_local) + return + + if folder_urls_pattern == 3: + folder_urls = [] + for s in naive_mds_partitions: + folder_urls.append((os.getcwd() + '/' + s, 'gs://mybucket/' + s)) + do_merge_index(folder_urls, out, keep_local=keep_local) + + # Integrity checks + + merged_index_path = os.path.join(out, 'index.json') + + if not keep_local: + assert not os.path.exists(merged_index_path) + return + + assert os.path.exists(merged_index_path) + merged_index = json.load(open(merged_index_path, 'r')) + n_shard_files = len(set([b['raw_data']['basename'] for b in merged_index['shards']])) + assert(n_shard_files == 2), "expected 2 shard files but got {n_shard_files}" + + @pytest.mark.parametrize('with_args', [True, False]) def test_retry(with_args: bool): From 72430be1d8bb0a647224f82e5ccc47aaff4bdd03 Mon Sep 17 00:00:00 2001 From: Xiaohan Zhang Date: Fri, 29 Sep 2023 12:41:52 -0700 Subject: [PATCH 14/59] Update tests --- streaming/base/converters/dataframe_to_mds.py | 3 +- streaming/base/storage/download.py | 31 ++-- streaming/base/storage/upload.py | 15 +- streaming/base/util.py | 30 ++-- tests/test_util.py | 158 ++++++++++++++---- 5 files changed, 164 insertions(+), 73 deletions(-) diff --git a/streaming/base/converters/dataframe_to_mds.py b/streaming/base/converters/dataframe_to_mds.py index 4f4b551f3..7a83f9549 100644 --- a/streaming/base/converters/dataframe_to_mds.py +++ b/streaming/base/converters/dataframe_to_mds.py @@ -253,8 +253,7 @@ def write_mds(iterator: Iterable): folder_urls.append(row['mds_path'].split(',')) else: folder_urls.append(row['mds_path']) - n_downloads = do_merge_index(folder_urls, out, keep_local=keep_local, overwrite=True) - logger.warning(f'{n_downloads} index files have been downloaded during index merging') + do_merge_index(folder_urls, out, keep_local=keep_local, download_timeout=60) if cu.remote is not None: if 'keep_local' in mds_kwargs and mds_kwargs['keep_local'] == False: diff --git 
a/streaming/base/storage/download.py b/streaming/base/storage/download.py index 70426dd1a..f10d2020f 100644 --- a/streaming/base/storage/download.py +++ b/streaming/base/storage/download.py @@ -8,7 +8,7 @@ import shutil import urllib.parse from time import sleep, time -from typing import Any, Dict, Optional +from typing import Any, Dict, List, Optional from streaming.base.util import get_import_exception_message @@ -491,7 +491,19 @@ def wait_for_download(local: str, timeout: float = 60) -> None: sleep(0.25) -def list_objects(remote: Optional[str]): +def list_objects_from_s3(remote: str) -> List[str]: + return [] + + +def list_objects_from_gcs(remote: str) -> List[str]: + return [] + + +def list_objects_from_oci(remote: str) -> List[str]: + return [] + + +def list_objects(remote: Optional[str]) -> List[str]: """Use the correct cloud handler to list objects. Args: @@ -502,16 +514,16 @@ def list_objects(remote: Optional[str]): if remote: remote = remote.replace('\\', '/') - if not remote: # '' or None + if not remote: # '' or None return os.listdir() elif remote.startswith('s3://'): - list_objects_from_s3(remote) - elif remote.startswith('sftp://'): - raise NotImplemented('list_objects for sftp not supported') + return list_objects_from_s3(remote) elif remote.startswith('gs://'): - list_objects_from_gcs(remote, local) + return list_objects_from_gcs(remote) elif remote.startswith('oci://'): - list_objects_from_oci(remote, local) + return list_objects_from_oci(remote) + elif remote.startswith('sftp://'): + raise NotImplemented('list_objects for sftp not supported') elif remote.startswith('azure://'): raise NotImplemented('list_objects for azure not supported') elif remote.startswith('azure-dl://'): @@ -521,5 +533,4 @@ def list_objects(remote: Optional[str]): elif remote.startswith('dbfs:/'): raise NotImplemented('list_objects for dbfs:/ not supported') else: - raise ValueError("remote scheme is not recognizable") - + raise ValueError('remote scheme is not recognizable') diff --git a/streaming/base/storage/upload.py b/streaming/base/storage/upload.py index e5ac2eb6b..a506377fa 100644 --- a/streaming/base/storage/upload.py +++ b/streaming/base/storage/upload.py @@ -53,13 +53,12 @@ class CloudUploader: """Upload local files to a cloud storage.""" @classmethod - def get( - cls, - out: Union[str, Tuple[str, str]], - keep_local: bool = False, - progress_bar: bool = False, - exist_ok: bool = False, - ) -> Any: + def get(cls, + out: Union[str, Tuple[str, str]], + keep_local: bool = False, + progress_bar: bool = False, + exist_ok: bool = False, + retry: int = 2) -> Any: """Instantiate a cloud provider uploader or a local uploader based on remote path. Args: @@ -90,7 +89,7 @@ def get( if prefix == 'dbfs:/Volumes': provider_prefix = prefix return getattr(sys.modules[__name__], - UPLOADERS[provider_prefix])(out, keep_local, progress_bar, retry, exist_ok) + UPLOADERS[provider_prefix])(out, keep_local, progress_bar, exist_ok, retry) def _validate(self, out: Union[str, Tuple[str, str]]) -> None: """Validate the `out` argument. 
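A minimal sketch of driving the merge helper defined in util.py below, assuming a hypothetical local folder /tmp/mds_dataset whose subfolders are MDS partitions (pattern 1 in the docstring: all entries are accessible local paths, so nothing is downloaded):

    from streaming.base.util import do_merge_index

    # Hypothetical MDS partition folders, each already containing its own index.json.
    folder_urls = [
        '/tmp/mds_dataset/0',
        '/tmp/mds_dataset/1',
    ]

    # Merge the per-partition index files into /tmp/mds_dataset/index.json,
    # keeping the local copy of the merged index.
    do_merge_index(folder_urls, '/tmp/mds_dataset', keep_local=True, download_timeout=60)
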
diff --git a/streaming/base/util.py b/streaming/base/util.py index a58b765fe..dcf1dfe67 100644 --- a/streaming/base/util.py +++ b/streaming/base/util.py @@ -14,8 +14,7 @@ import urllib.parse from multiprocessing.shared_memory import SharedMemory as BuiltinSharedMemory from time import sleep, time -from typing import (TYPE_CHECKING, Any, Callable, List, Sequence, Tuple, Type, TypeVar, Union, - cast, overload, Union) +from typing import Any, Callable, List, Sequence, Tuple, Type, TypeVar, Union, cast, overload import torch.distributed as dist @@ -24,7 +23,6 @@ from streaming.base.format.index import get_index_basename from streaming.base.shared.prefix import _get_path - logger = logging.getLogger(__name__) TCallable = TypeVar('TCallable', bound=Callable) @@ -216,13 +214,11 @@ def get_import_exception_message(package_name: str, extra_deps: str) -> str: f'`pip install \'mosaicml-streaming[{package_name}]\'`.' -def do_merge_index(folder_urls: List[Union[str, Tuple[str, str]]], +def do_merge_index(folder_urls: List[Any], out: Union[str, Tuple[str, str]], - *, keep_local: bool = True, download_timeout: int = 60) -> None: - """Merge index.json from a list of remote or local directories. - Write merged index file to `out`. Overwrite if an old file exists. + """Merge index.json from a list of directories. Write to `out`, overwriting if exists. Args: folder_urls (Union[str, Tuple[str,str]]): folders that contain index.json for the partition @@ -238,13 +234,12 @@ def do_merge_index(folder_urls: List[Union[str, Tuple[str, str]]], keep_local (bool): Keep local copy of the merged index file. Defaults to ``True`` download_timeout (int): The allowed time for downloading each json file. Defaults to 60s. """ - - from streaming.base.storage.download import download_file, list_objects + from streaming.base.storage.download import download_file from streaming.base.storage.upload import CloudUploader if not folder_urls or not out: logger.warning('Need to specify both folder_urls and out. No index merged') - return 0 + return # This is the index json file name, e.g., it is index.json as of 0.6.0 index_basename = get_index_basename() @@ -264,7 +259,7 @@ def do_merge_index(folder_urls: List[Union[str, Tuple[str, str]]], for url in urls: if isinstance(url, tuple): # If driver cannot access the local path, download = True - download = not os.path.exists(url[0]) + download = not os.path.exists(os.path.join(url[0], index_basename)) else: # If url is a remote, download = True, False otherwise download = urllib.parse.urlparse(url).scheme != '' @@ -338,17 +333,17 @@ def do_merge_index(folder_urls: List[Union[str, Tuple[str, str]]], def merge_index(root_to_mds: Union[str, Tuple[str, str]], *, keep_local: bool = True, - overwrite: bool = True) -> int: - """Merge index.json of a MDS dataset from the subdirectories of root_to_mds and store in - root_to_mds. + overwrite: bool = True) -> None: + """Merge index.json given the root of MDS dataset. Write merged index to the root folder. Args: - root_to_MDSdataset (Union[str, Tuple[str,str]]): folders that contain MDS partitions. + root_to_mds (Union[str, Tuple[str,str]]): folders that contain MDS partitions. It can be local str or remote str or (local, remote) keep_local (bool): Keep local copy of the merged index file. Defaults to ``True`` overwrite (bool): Overwrite merged index file in out if there exists one. Defaults to ``True`` download_timeout (int): The allowed time for downloading each json file. Defaults to 60s. 
""" + from streaming.base.storage.download import list_objects if not root_to_mds: logger.warning('No MDS dataset folder specified, no index merged') @@ -361,10 +356,7 @@ def merge_index(root_to_mds: Union[str, Tuple[str, str]], else: folder_urls = list_objects(root_to_mds) - do_merge_index(folder_urls, - root_to_mds, - keep_local=keep_local, - download_timeout=60) + do_merge_index(folder_urls, root_to_mds, keep_local=keep_local, download_timeout=60) return diff --git a/tests/test_util.py b/tests/test_util.py index 8169d0a43..cf8cf002b 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -1,8 +1,9 @@ # Copyright 2023 MosaicML Streaming authors # SPDX-License-Identifier: Apache-2.0 -import os import json +import os +import shutil import tempfile from multiprocessing.shared_memory import SharedMemory as BuiltinSharedMemory from typing import Any, List, Optional, Tuple, Union @@ -11,8 +12,47 @@ from streaming.base.constant import RESUME from streaming.base.shared.prefix import _get_path -from streaming.base.util import (bytes_to_int, clean_stale_shared_memory, get_list_arg, - merge_index, do_merge_index, number_abbrev_to_int, retry) +from streaming.base.storage.download import download_file +from streaming.base.storage.upload import CloudUploader +from streaming.base.util import (bytes_to_int, clean_stale_shared_memory, do_merge_index, + get_list_arg, number_abbrev_to_int, retry) + +MY_PREFIX = 'train' +MY_BUCKET = 'mosaicml-composer-tests' +MANUAL_INTEGRATION_TEST = True +os.environ[ + 'OBJC_DISABLE_INITIALIZE_FORK_SAFETY'] = 'YES' # set to yes to all fork process in spark calls + + +@pytest.fixture(scope='function', autouse=True) +def manual_integration_dir() -> Any: + """Creates a temporary directory and then deletes it when the calling function is done.""" + if MANUAL_INTEGRATION_TEST: + #os.environ['GOOGLE_APPLICATION_CREDENTIALS'] = 'path/to/gooogle_api_credential.json' + os.environ[ + 'GOOGLE_APPLICATION_CREDENTIALS'] = '/Users/xiaohan.zhang/.mosaic/mosaicml-research-nonprod-027345ddbdfd.json' + + tmp_dir = tempfile.mkdtemp() + + def _method(cloud_prefix: str = 'gs://') -> Tuple[str, str]: + mock_local_dir = tmp_dir + mock_remote_dir = os.path.join(cloud_prefix, MY_BUCKET, MY_PREFIX) + return mock_local_dir, mock_remote_dir + + try: + yield _method + finally: + shutil.rmtree(tmp_dir, ignore_errors=True) # pyright: ignore + if MANUAL_INTEGRATION_TEST: + try: + from google.cloud.storage import Client + storage_client = Client() + bucket = storage_client.get_bucket(MY_BUCKET) + blobs = bucket.list_blobs(prefix=MY_PREFIX) + for blob in blobs: + blob.delete() + except ImportError: + raise ImportError('google.cloud.storage is not imported correctly.') @pytest.mark.parametrize(('text', 'expected_output'), [('hello,world', ['hello', 'world']), @@ -111,11 +151,11 @@ def test_clean_stale_shared_memory(): @pytest.mark.parametrize('folder_urls_pattern', [1, 2, 3, 4, 5]) -@pytest.mark.usefixtures('local_remote_dir') +@pytest.mark.parametrize('output_format', ['local', 'remote', 'tuple']) +@pytest.mark.usefixtures('manual_integration_dir') @pytest.mark.parametrize('keep_local', [True, False]) -def test_do_merge_index(local_remote_dir: Tuple[str, str], - keep_local: bool, - folder_urls_pattern: int): +def test_do_merge_index(manual_integration_dir: Any, keep_local: bool, folder_urls_pattern: int, + output_format: str): """Validate the final merge index json for following patterns of folder_urls: 1. All urls are str (local). All urls are accessible locally -> no download 2. 
All urls are str (local). At least one url is unaccessible locally -> Error @@ -124,48 +164,98 @@ def test_do_merge_index(local_remote_dir: Tuple[str, str], 5. All urls are str (remote) -> download all """ + def integrity_check(out: Union[str, Tuple[str, str]]): + """ Check if merged_index file has integrity + If merged_index is a cloud url, first download it to a temp local file. + """ + + cu = CloudUploader.get(out, keep_local=True, exist_ok=True) + + with tempfile.TemporaryDirectory() as temp_dir: + + if cu.remote: + download_file(os.path.join(cu.remote, 'index.json'), + os.path.join(temp_dir, 'index.json'), + timeout=60) + local_merged_index_path = os.path.join(temp_dir, 'index.json') + else: + local_merged_index_path = os.path.join(cu.local, 'index.json') + + if not keep_local: + assert not os.path.exists(os.path.join(cu.local, 'index.json')) + return + + assert os.path.exists(local_merged_index_path) + merged_index = json.load(open(local_merged_index_path, 'r')) + n_shard_files = len({b['raw_data']['basename'] for b in merged_index['shards']}) + assert (n_shard_files == 2), 'expected 2 shard files but got {n_shard_files}' + + if output_format != 'local': + if not MANUAL_INTEGRATION_TEST: + pytest.skip('Require cloud credentials. ' + + 'skipping. Set MANUAL_INTEGRATION_TEST=True to run the check manually!') + if output_format == 'remote': + out = manual_integration_dir()[1] + else: + out = manual_integration_dir() + else: + out = manual_integration_dir()[0] + naive_mds_partitions = [ 'tests/resources/naive_MDSdataset/25/', 'tests/resources/naive_MDSdataset/26/', 'tests/resources/naive_MDSdataset/27/' ] - if folder_urls_pattern in [4,5]: - # Require cloud file transfers. Will be covered by integration tests. - return + if folder_urls_pattern == 1: + folder_urls = [os.getcwd() + '/' + s for s in naive_mds_partitions] + do_merge_index(folder_urls, out, keep_local=keep_local) - with tempfile.TemporaryDirectory() as out: - if folder_urls_pattern == 1: - folder_urls = [os.getcwd() + '/' + s for s in naive_mds_partitions] - do_merge_index(folder_urls, out, keep_local=keep_local) - - - if folder_urls_pattern == 2: - folder_urls = [out + '/' + s for s in naive_mds_partitions] - with pytest.raises( - FileNotFoundError, - match=f'.* does not exist or not accessible.*'): + if folder_urls_pattern == 2: + with tempfile.TemporaryDirectory() as a_temporary_folder: + folder_urls = [a_temporary_folder + '/' + s for s in naive_mds_partitions] + with pytest.raises(FileNotFoundError, match=f'.* does not exist or not accessible.*'): do_merge_index(folder_urls, out, keep_local=keep_local) return - if folder_urls_pattern == 3: + if folder_urls_pattern == 3: + folder_urls = [] + for s in naive_mds_partitions: + folder_urls.append((os.getcwd() + '/' + s, 'gs://mybucket/' + s)) + do_merge_index(folder_urls, out, keep_local=keep_local) + + if folder_urls_pattern == 4: + if not MANUAL_INTEGRATION_TEST: + pytest.skip('Require cloud credentials. ' + + 'skipping. 
Set MANUAL_INTEGRATION_TEST=True to run the check manually!') + + with tempfile.TemporaryDirectory() as a_temporary_folder: folder_urls = [] for s in naive_mds_partitions: - folder_urls.append((os.getcwd() + '/' + s, 'gs://mybucket/' + s)) + cu_path = (os.getcwd() + '/' + s, 'gs://' + MY_BUCKET + '/' + s) + cu = CloudUploader.get(cu_path, keep_local=True, exist_ok=True) + index_json = os.path.join(cu.local, 'index.json') + if os.path.exists(index_json): + cu.upload_file('index.json') + folder_urls.append((a_temporary_folder, 'gs://' + MY_BUCKET + '/' + s)) do_merge_index(folder_urls, out, keep_local=keep_local) - # Integrity checks - - merged_index_path = os.path.join(out, 'index.json') - - if not keep_local: - assert not os.path.exists(merged_index_path) - return + if folder_urls_pattern == 5: + if not MANUAL_INTEGRATION_TEST: + pytest.skip('Require cloud credentials. ' + + 'skipping. Set MANUAL_INTEGRATION_TEST=True to run the check manually!') - assert os.path.exists(merged_index_path) - merged_index = json.load(open(merged_index_path, 'r')) - n_shard_files = len(set([b['raw_data']['basename'] for b in merged_index['shards']])) - assert(n_shard_files == 2), "expected 2 shard files but got {n_shard_files}" + with tempfile.TemporaryDirectory() as a_temporary_folder: + folder_urls = [] + for s in naive_mds_partitions: + cu_path = (os.getcwd() + '/' + s, 'gs://' + MY_BUCKET + '/' + s) + cu = CloudUploader.get(cu_path, keep_local=True, exist_ok=True) + index_json = os.path.join(cu.local, 'index.json') + if os.path.exists(index_json): + cu.upload_file('index.json') + folder_urls.append('gs://' + MY_BUCKET + '/' + s) + do_merge_index(folder_urls, out, keep_local=keep_local) + integrity_check(out) @pytest.mark.parametrize('with_args', [True, False]) From 69857d4bc82252a09cd0b47206eb7c17cf0dd830 Mon Sep 17 00:00:00 2001 From: Xiaohan Zhang Date: Fri, 29 Sep 2023 13:07:34 -0700 Subject: [PATCH 15/59] unit tests + pre-commit ok --- streaming/base/converters/dataframe_to_mds.py | 3 +-- tests/base/converters/test_dataframe_to_mds.py | 6 ++++-- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/streaming/base/converters/dataframe_to_mds.py b/streaming/base/converters/dataframe_to_mds.py index 7a83f9549..04067c28f 100644 --- a/streaming/base/converters/dataframe_to_mds.py +++ b/streaming/base/converters/dataframe_to_mds.py @@ -11,8 +11,7 @@ import pandas as pd -from streaming.base.util import get_import_exception_message -from streaming.base.util import merge_index as do_merge_index +from streaming.base.util import do_merge_index, get_import_exception_message try: from pyspark import TaskContext diff --git a/tests/base/converters/test_dataframe_to_mds.py b/tests/base/converters/test_dataframe_to_mds.py index 2dacdb0a4..3672e7040 100644 --- a/tests/base/converters/test_dataframe_to_mds.py +++ b/tests/base/converters/test_dataframe_to_mds.py @@ -18,7 +18,7 @@ MY_PREFIX = 'train' MY_BUCKET = 'mosaicml-composer-tests' -MANUAL_INTEGRATION_TEST = False +MANUAL_INTEGRATION_TEST = True os.environ[ 'OBJC_DISABLE_INITIALIZE_FORK_SAFETY'] = 'YES' # set to yes to all fork process in spark calls @@ -27,7 +27,9 @@ def manual_integration_dir() -> Any: """Creates a temporary directory and then deletes it when the calling function is done.""" if MANUAL_INTEGRATION_TEST: - os.environ['GOOGLE_APPLICATION_CREDENTIALS'] = 'path/to/gooogle_api_credential.json' + #os.environ['GOOGLE_APPLICATION_CREDENTIALS'] = 'path/to/gooogle_api_credential.json' + os.environ[ + 'GOOGLE_APPLICATION_CREDENTIALS'] = 
'/Users/xiaohan.zhang/.mosaic/mosaicml-research-nonprod-027345ddbdfd.json' tmp_dir = mkdtemp() From bf2d4efaa92511571b64ef6ae669e5cc9d0d7fd3 Mon Sep 17 00:00:00 2001 From: Xiaohan Zhang Date: Fri, 29 Sep 2023 15:53:13 -0700 Subject: [PATCH 16/59] Add list objects for oci, gs, s3 --- streaming/base/storage/download.py | 214 ++++++++++++++++++++++++++--- streaming/base/util.py | 18 ++- tests/test_list_objects.py | 158 +++++++++++++++++++++ tests/test_util.py | 94 +++++++++---- 4 files changed, 429 insertions(+), 55 deletions(-) create mode 100644 tests/test_list_objects.py diff --git a/streaming/base/storage/download.py b/streaming/base/storage/download.py index f10d2020f..65c0d89af 100644 --- a/streaming/base/storage/download.py +++ b/streaming/base/storage/download.py @@ -490,18 +490,195 @@ def wait_for_download(local: str, timeout: float = 60) -> None: f'Waited longer than {timeout}s for other worker to download {local}.') sleep(0.25) +def remove_prefix(obj: str): + """Remove prefix from ob -def list_objects_from_s3(remote: str) -> List[str]: - return [] + Args: + obj (str): take form of 'path/to/folder' + return: + (str): take form of 'to/folder' + """ + return "/".join(obj.strip("/").split('/')[1:]) + +def list_objects_from_s3(remote: str, timeout: int = 60) -> List[str]: + """List objects from remote AWS S3. + + Args: + remote (str): Remote path (S3). + """ + import boto3 + from boto3.s3.transfer import TransferConfig + from botocore import UNSIGNED + from botocore.config import Config + from botocore.exceptions import ClientError, NoCredentialsError + + def _list_objects(obj, unsigned: bool = False) -> None: + """List the objects from AWS S3 bucket. The bucket can be either public or private. + + Args: + unsigned (bool, optional): Set to True if it is a public bucket. + Defaults to ``False``. + """ + if unsigned: + # Client will be using unsigned mode in which public + # resources can be accessed without credentials + config = Config(read_timeout=timeout, signature_version=UNSIGNED) + else: + config = Config(read_timeout=timeout) + + # Create a new session per thread + session = boto3.session.Session() + # Create a resource client using a thread's session object + s3 = session.client('s3', config=config, endpoint_url=os.environ.get('S3_ENDPOINT_URL')) + # Threads calling S3 operations return RuntimeError (cannot schedule new futures after + # interpreter shutdown). Temporary solution is to have `use_threads` as `False`. + # Issue: https://github.com/boto/boto3/issues/3113 + paginator = s3.get_paginator('list_objects_v2') + pages = paginator.paginate(Bucket=obj.netloc, Prefix=obj.path) + ans = [] + for page in pages: + if 'Contents' in page: + for o in page['Contents']: + ans.append(remove_prefix(o['Key'])) + return ans + + + obj = urllib.parse.urlparse(remote) + if obj.scheme != 's3': + raise ValueError( + f'Expected obj.scheme to be `s3`, instead, got {obj.scheme} for remote={remote}') + + try: + return _list_objects(obj) + except NoCredentialsError: + # Public S3 buckets without credentials + return _list_objects(obj, unsigned=True) + except ClientError as e: + if e.response['Error']['Code'] in BOTOCORE_CLIENT_ERROR_CODES: + e.args = (f'Object {remote} not found! Either check the bucket path or the bucket ' + + f'permission. 
If the bucket is a requester pays bucket, then provide the ' + + f'bucket name to the environment variable ' + + f'`MOSAICML_STREAMING_AWS_REQUESTER_PAYS`.',) + raise e + elif e.response['Error']['Code'] == '400': + # Public S3 buckets without credentials + return _list_objects(obj, unsigned=True) + except Exception: + raise + + +def list_objects_from_gcs(remote: str, timeout: int = 60) -> List[str]: + """List objects from remote Google Cloud Bucket. + + Args: + remote (str): Remote path (S3). + """ + from google.auth.exceptions import DefaultCredentialsError + def _gcs_with_hmac(obj: urllib.parse.ParseResult) -> None: + """Return a list of objects from remote GCS using user level credentials. -def list_objects_from_gcs(remote: str) -> List[str]: - return [] + Args: + obj (ParseResult): ParseResult object of remote. + """ + import boto3 + from boto3.s3.transfer import TransferConfig + from botocore.exceptions import ClientError + + # Create a new session per thread + session = boto3.session.Session() + # Create a resource client using a thread's session object + gcs_client = session.client('s3', + region_name='auto', + endpoint_url='https://storage.googleapis.com', + aws_access_key_id=os.environ['GCS_KEY'], + aws_secret_access_key=os.environ['GCS_SECRET']) + try: + response = gcs_client.list_objects_v2(Bucket=obj.netloc, + Prefix=obj.path.lstrip("/")) + if response and 'Contents' in response: + return [remove_prefix(ob['Key']) for ob in response['Contents']] + + except ClientError as e: + if e.response['Error']['Code'] in BOTOCORE_CLIENT_ERROR_CODES: + raise FileNotFoundError(f'Object {obj.sheme}, {obj.netloc}, {obj.path} not found.') from e + except Exception: + raise + + def _gcs_with_service_account(obj: urllib.parse.ParseResult) -> None: + """Return a list of objects from remote GCS using service account credentials. + + Args: + obj (ParseResult): ParseResult object of remote path (GCS). + """ + from google.auth import default as default_auth + from google.cloud.storage import Blob, Bucket, Client + + credentials, _ = default_auth() + gcs_client = Client(credentials=credentials) + bucket = gcs_client.get_bucket(obj.netloc, timeout=60.0) + objects = bucket.list_blobs(prefix=obj.path.lstrip('/')) + return [remove_prefix(ob.name) for ob in objects] + + obj = urllib.parse.urlparse(remote) + if obj.scheme != 'gs': + raise ValueError( + f'Expected obj.scheme to be `gs`, instead, got {obj.scheme} for remote={remote}') + + if 'GCS_KEY' in os.environ and 'GCS_SECRET' in os.environ: + return _gcs_with_hmac(obj) + else: + try: + return _gcs_with_service_account(obj) + except (DefaultCredentialsError, EnvironmentError): + raise ValueError(GCS_ERROR_NO_AUTHENTICATION) def list_objects_from_oci(remote: str) -> List[str]: - return [] + """List objects from remote OCI to local. + + Args: + remote (str): Remote path (OCI). 
+ """ + import oci + config = oci.config.from_file() + client = oci.object_storage.ObjectStorageClient( + config=config, retry_strategy=oci.retry.DEFAULT_RETRY_STRATEGY) + + obj = urllib.parse.urlparse(remote) + if obj.scheme != 'oci': + raise ValueError( + f'Expected obj.scheme to be `oci`, instead, got {obj.scheme} for remote={remote}') + + object_names = [] + next_start_with = None + response_complete = False + + while not response_complete: + response = client.list_objects(namespace_name=client.get_namespace().data, + bucket_name=obj.netloc.split('@' + namespace)[0], + prefix=obj.path.strip('/'), + start=next_start_with).data + object_names.extend([obj.name for obj in response.objects]) + next_start_with = response.next_start_with + if not next_start_with: + response_complete = True + + return object_names + + +def list_objects_from_local(path: Optional[str]) -> List[str]: + """List objects from a local directory. List current directory if path is None. Return path if path is not a directory. + + Args: + path (str): absolute path or None. + """ + if not path: + return os.listdir() + if os.path.isdir(path): + return os.listdir(path) + return path def list_objects(remote: Optional[str]) -> List[str]: """Use the correct cloud handler to list objects. @@ -510,27 +687,22 @@ def list_objects(remote: Optional[str]) -> List[str]: remote (str, optional): Remote path (local filesystem). If remote is None or '', list current working directory with os.listdir() """ + if not remote: # '' or None + return list_objects_from_local(remote) + # fix paths for windows if remote: remote = remote.replace('\\', '/') - if not remote: # '' or None - return os.listdir() - elif remote.startswith('s3://'): + obj = urllib.parse.urlparse(remote) + + if obj.scheme == '': + return list_objects_from_local(remote) + elif obj.scheme == 's3': return list_objects_from_s3(remote) - elif remote.startswith('gs://'): + elif remote.startswith('gs'): return list_objects_from_gcs(remote) - elif remote.startswith('oci://'): + elif remote.startswith('oci'): return list_objects_from_oci(remote) - elif remote.startswith('sftp://'): - raise NotImplemented('list_objects for sftp not supported') - elif remote.startswith('azure://'): - raise NotImplemented('list_objects for azure not supported') - elif remote.startswith('azure-dl://'): - raise NotImplemented('list_objects for azure-dl not supported') - elif remote.startswith('dbfs:/Volumes'): - raise NotImplemented('list_objects for dbfs:/Volumes not supported') - elif remote.startswith('dbfs:/'): - raise NotImplemented('list_objects for dbfs:/ not supported') else: - raise ValueError('remote scheme is not recognizable') + raise NotImplementedError diff --git a/streaming/base/util.py b/streaming/base/util.py index dcf1dfe67..086fbe4fb 100644 --- a/streaming/base/util.py +++ b/streaming/base/util.py @@ -241,6 +241,7 @@ def do_merge_index(folder_urls: List[Any], logger.warning('Need to specify both folder_urls and out. 
No index merged') return + print('folder_urls = ', folder_urls) # This is the index json file name, e.g., it is index.json as of 0.6.0 index_basename = get_index_basename() @@ -289,10 +290,10 @@ def do_merge_index(folder_urls: List[Any], local_path = os.path.join(local, index_basename) download_file(remote_url, local_path, download_timeout) except Exception as ex: - raise RuntimeError(f'Failed to download index.json {url}') from ex + raise RuntimeError(f'Failed to download index.json {remote_url}') from ex if not (os.path.exists(local)): - raise FileNotFoundError('Folder {local} does not exist or not accessible.') + raise FileNotFoundError(f'Folder {local} does not exist or not accessible.') partitions.append(local) # merge index files into shards @@ -344,17 +345,24 @@ def merge_index(root_to_mds: Union[str, Tuple[str, str]], download_timeout (int): The allowed time for downloading each json file. Defaults to 60s. """ from streaming.base.storage.download import list_objects + from streaming.base.storage.upload import CloudUploader if not root_to_mds: logger.warning('No MDS dataset folder specified, no index merged') return + cu = CloudUploader.get(root_to_mds, exist_ok=True, keep_local=True) if isinstance(root_to_mds, tuple): - local_folders = list_objects(root_to_mds[0]) - remote_folders = list_objects(root_to_mds[1]) + local_folders = [os.path.join(cu.local, os.path.dirname(o)) for o in list_objects(root_to_mds[0])] + remote_folders = [os.path.join(cu.remote, os.path.dirname(o)) for o in list_objects(root_to_mds[1])] folder_urls = list(zip(local_folders, remote_folders)) else: - folder_urls = list_objects(root_to_mds) + print('I am here 3.1', root_to_mds) + print('I am here 3', list_objects(root_to_mds)) + if cu.remote: + folder_urls = [os.path.join(cu.remote, os.path.dirname(o)) for o in list_objects(root_to_mds)] + else: + folder_urls = [os.path.join(cu.local, os.path.dirname(o)) for o in list_objects(root_to_mds)] do_merge_index(folder_urls, root_to_mds, keep_local=keep_local, download_timeout=60) diff --git a/tests/test_list_objects.py b/tests/test_list_objects.py new file mode 100644 index 000000000..8c6cc8b7e --- /dev/null +++ b/tests/test_list_objects.py @@ -0,0 +1,158 @@ +# Copyright 2023 MosaicML Streaming authors +# SPDX-License-Identifier: Apache-2.0 + +import os +import tempfile +from typing import Any, Tuple +from unittest.mock import Mock, patch + +import boto3 +import pytest +from botocore.exceptions import ClientError + +from streaming.base.storage.download import (list_objects, list_objects_from_gcs, + list_objects_from_local, list_objects_from_s3) +from tests.conftest import GCS_URL, MY_BUCKET, R2_URL + +MY_PREFIX = 'train' + + +@pytest.fixture(scope='function') +def remote_local_file() -> Any: + """Creates a temporary directory and then deletes it when the calling function is done.""" + + def _method(cloud_prefix: str = '', filename: str = 'file.txt') -> Tuple[str, str]: + try: + mock_local_dir = tempfile.TemporaryDirectory() + mock_local_filepath = os.path.join(mock_local_dir.name, filename) + mock_remote_filepath = os.path.join(cloud_prefix, MY_BUCKET, MY_PREFIX, filename) + return mock_remote_filepath, mock_local_filepath + finally: + mock_local_dir.cleanup() # pyright: ignore + + return _method + + +class TestS3Client: + + @pytest.mark.usefixtures('s3_client', 's3_test', 'remote_local_file') + def test_list_objects_from_s3(self, remote_local_file: Any): + with tempfile.NamedTemporaryFile(delete=True, suffix='.txt') as tmp: + file_name = tmp.name.split(os.sep)[-1] + 
mock_remote_filepath, _ = remote_local_file(cloud_prefix='s3://', filename=file_name) + client = boto3.client('s3', region_name='us-east-1') + client.put_object(Bucket=MY_BUCKET, Key=os.path.join(MY_PREFIX, file_name), Body='') + objs = list_objects_from_s3(mock_remote_filepath) + assert isinstance(objs, list) + + @pytest.mark.usefixtures('s3_client', 's3_test', 'r2_credentials', 'remote_local_file') + def test_list_objects_from_s3_with_endpoint_URL(self, remote_local_file: Any): + with tempfile.NamedTemporaryFile(delete=True, suffix='.txt') as tmp: + file_name = tmp.name.split(os.sep)[-1] + mock_remote_filepath, _ = remote_local_file(cloud_prefix='s3://', filename=file_name) + client = boto3.client('s3', region_name='us-east-1') + client.put_object(Bucket=MY_BUCKET, Key=os.path.join(MY_PREFIX, file_name), Body='') + objs = list_objects_from_s3(mock_remote_filepath) + assert os.environ['S3_ENDPOINT_URL'] == R2_URL + + @pytest.mark.usefixtures('s3_client', 's3_test', 'remote_local_file') + def test_clienterror_exception(self, remote_local_file: Any): + mock_remote_filepath, mock_local_filepath = remote_local_file(cloud_prefix='s3://') + objs = list_objects_from_s3(mock_remote_filepath) + assert(len(objs) == 0) + + @pytest.mark.usefixtures('s3_client', 's3_test', 'remote_local_file') + def test_invalid_cloud_prefix(self, remote_local_file: Any): + with pytest.raises(ValueError): + mock_remote_filepath, mock_local_filepath = remote_local_file(cloud_prefix='s9://') + objs = list_objects_from_s3(mock_remote_filepath) + + +class TestGCSClient: + + @pytest.mark.usefixtures('gcs_hmac_client', 'gcs_test', 'remote_local_file') + def test_list_objects_from_gcs(self, remote_local_file: Any): + with tempfile.NamedTemporaryFile(delete=True, suffix='.txt') as tmp: + file_name = tmp.name.split(os.sep)[-1] + mock_remote_filepath, _ = remote_local_file(cloud_prefix='gs://', filename=file_name) + client = boto3.client('s3', + region_name='us-east-1', + endpoint_url=GCS_URL, + aws_access_key_id=os.environ['GCS_KEY'], + aws_secret_access_key=os.environ['GCS_SECRET']) + client.put_object(Bucket=MY_BUCKET, Key=os.path.join(MY_PREFIX, file_name), Body='') + objs = list_objects_from_gcs(mock_remote_filepath) + assert isinstance(objs, list) + + @patch('google.auth.default') + @patch('google.cloud.storage.Client') + @pytest.mark.usefixtures('gcs_service_account_credentials') + @pytest.mark.parametrize('out', ['gs://bucket/dir']) + def test_download_service_account(self, mock_client: Mock, mock_default: Mock, out: str): + with tempfile.NamedTemporaryFile(delete=True, suffix='.txt') as tmp: + credentials_mock = Mock() + mock_default.return_value = credentials_mock, None + objs = list_objects_from_gcs(out) + mock_client.assert_called_once_with(credentials=credentials_mock) + assert isinstance(objs, list) + + @pytest.mark.usefixtures('gcs_hmac_client', 'gcs_test', 'remote_local_file') + def test_filenotfound_exception(self, remote_local_file: Any): + mock_remote_filepath, mock_local_filepath = remote_local_file(cloud_prefix='gs://') + objs = list_objects_from_gcs(mock_remote_filepath) + + @pytest.mark.usefixtures('gcs_hmac_client', 'gcs_test', 'remote_local_file') + def test_invalid_cloud_prefix(self, remote_local_file: Any): + with pytest.raises(ValueError): + mock_remote_filepath, mock_local_filepath = remote_local_file(cloud_prefix='s3://') + objs = list_objects_from_gcs(mock_remote_filepath) + + def test_no_credentials_error(self, remote_local_file: Any): + """Ensure we raise a value error correctly if we have no 
credentials available.""" + with pytest.raises(ValueError): + mock_remote_filepath, mock_local_filepath = remote_local_file(cloud_prefix='gs://') + objs = list_objects_from_gcs(mock_remote_filepath) + + +def test_list_objects_from_local(): + mock_local_dir = tempfile.TemporaryDirectory() + file_name = 'file.txt' + mock_local_file = os.path.join(mock_local_dir.name, file_name) + # Creates a new empty file + with open(mock_local_file, 'w') as _: + pass + + objs = list_objects_from_local(mock_local_file) + assert objs == mock_local_file + + +class TestListObjects: + + @patch('streaming.base.storage.download.list_objects_from_s3') + @pytest.mark.usefixtures('remote_local_file') + def test_list_objects_from_s3_gets_called(self, mocked_requests: Mock, remote_local_file: Any): + mock_remote_filepath, mock_local_filepath = remote_local_file(cloud_prefix='s3://') + objs = list_objects(mock_remote_filepath) + mocked_requests.assert_called_once() + mocked_requests.assert_called_once_with(mock_remote_filepath) + + @patch('streaming.base.storage.download.list_objects_from_gcs') + @pytest.mark.usefixtures('remote_local_file') + def test_list_objects_from_gcs_gets_called(self, mocked_requests: Mock, remote_local_file: Any): + mock_remote_filepath, mock_local_filepath = remote_local_file(cloud_prefix='gs://') + objs = list_objects(mock_remote_filepath) + mocked_requests.assert_called_once() + mocked_requests.assert_called_once_with(mock_remote_filepath) + + @patch('streaming.base.storage.download.list_objects_from_local') + @pytest.mark.usefixtures('remote_local_file') + def test_list_objects_from_local_gets_called(self, mocked_requests: Mock, remote_local_file: Any): + mock_remote_filepath, mock_local_filepath = remote_local_file() + objs = list_objects(mock_remote_filepath) + mocked_requests.assert_called_once() + mocked_requests.assert_called_once_with(mock_remote_filepath) + + @pytest.mark.usefixtures('remote_local_file') + def test_list_objects_invalid_missing_remote(self): + obj = list_objects(None) + assert(obj == os.listdir()) diff --git a/tests/test_util.py b/tests/test_util.py index cf8cf002b..f5dd7b8a2 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -12,9 +12,9 @@ from streaming.base.constant import RESUME from streaming.base.shared.prefix import _get_path -from streaming.base.storage.download import download_file +from streaming.base.storage.download import download_file, list_objects from streaming.base.storage.upload import CloudUploader -from streaming.base.util import (bytes_to_int, clean_stale_shared_memory, do_merge_index, +from streaming.base.util import (bytes_to_int, clean_stale_shared_memory, do_merge_index, merge_index, get_list_arg, number_abbrev_to_int, retry) MY_PREFIX = 'train' @@ -149,6 +149,68 @@ def test_clean_stale_shared_memory(): with pytest.raises(FileNotFoundError): _ = BuiltinSharedMemory(name, False, 64) +def integrity_check(out: Union[str, Tuple[str, str]], keep_local): + """ Check if merged_index file has integrity + If merged_index is a cloud url, first download it to a temp local file. 
+ """ + + cu = CloudUploader.get(out, keep_local=True, exist_ok=True) + + with tempfile.TemporaryDirectory() as temp_dir: + + if cu.remote: + download_file(os.path.join(cu.remote, 'index.json'), + os.path.join(temp_dir, 'index.json'), + timeout=60) + local_merged_index_path = os.path.join(temp_dir, 'index.json') + else: + local_merged_index_path = os.path.join(cu.local, 'index.json') + + if not keep_local: + assert not os.path.exists(os.path.join(cu.local, 'index.json')) + return + + assert os.path.exists(local_merged_index_path) + merged_index = json.load(open(local_merged_index_path, 'r')) + n_shard_files = len({b['raw_data']['basename'] for b in merged_index['shards']}) + assert (n_shard_files == 2), 'expected 2 shard files but got {n_shard_files}' + +def test_merge_index(manual_integration_dir: Any): + from decimal import Decimal + from streaming.base.converters import dataframeToMDS + from pyspark.sql import SparkSession + from pyspark.sql.types import DecimalType, IntegerType, StringType, StructField, StructType + + spark = SparkSession.builder.getOrCreate() # pyright: ignore + schema = StructType([ + StructField('id', IntegerType(), nullable=False), + StructField('name', StringType(), nullable=False), + StructField('amount', DecimalType(10, 2), nullable=False) + ]) + + data = [(1, 'Alice', Decimal('123.45')), (2, 'Bob', Decimal('67.89')), + (3, 'Charlie', Decimal('987.65'))] + + df = spark.createDataFrame(data=data, schema=schema).repartition(3) + + _, remote = manual_integration_dir() + mds_kwargs = { + 'out': remote, + 'columns': { + 'id': 'int', + 'name': 'str' + }, + } + print('I am here 0: remote = ', remote) + + mds_path, _ = dataframeToMDS(df, merge_index=False, mds_kwargs=mds_kwargs) + + print('mds_path = ', mds_path) + print(list_objects("gs://mosaicml-composer-tests/train/")) + merge_index(remote) + + integrity_check(remote, keep_local=True) + @pytest.mark.parametrize('folder_urls_pattern', [1, 2, 3, 4, 5]) @pytest.mark.parametrize('output_format', ['local', 'remote', 'tuple']) @@ -164,32 +226,6 @@ def test_do_merge_index(manual_integration_dir: Any, keep_local: bool, folder_ur 5. All urls are str (remote) -> download all """ - def integrity_check(out: Union[str, Tuple[str, str]]): - """ Check if merged_index file has integrity - If merged_index is a cloud url, first download it to a temp local file. - """ - - cu = CloudUploader.get(out, keep_local=True, exist_ok=True) - - with tempfile.TemporaryDirectory() as temp_dir: - - if cu.remote: - download_file(os.path.join(cu.remote, 'index.json'), - os.path.join(temp_dir, 'index.json'), - timeout=60) - local_merged_index_path = os.path.join(temp_dir, 'index.json') - else: - local_merged_index_path = os.path.join(cu.local, 'index.json') - - if not keep_local: - assert not os.path.exists(os.path.join(cu.local, 'index.json')) - return - - assert os.path.exists(local_merged_index_path) - merged_index = json.load(open(local_merged_index_path, 'r')) - n_shard_files = len({b['raw_data']['basename'] for b in merged_index['shards']}) - assert (n_shard_files == 2), 'expected 2 shard files but got {n_shard_files}' - if output_format != 'local': if not MANUAL_INTEGRATION_TEST: pytest.skip('Require cloud credentials. 
' + @@ -255,7 +291,7 @@ def integrity_check(out: Union[str, Tuple[str, str]]): folder_urls.append('gs://' + MY_BUCKET + '/' + s) do_merge_index(folder_urls, out, keep_local=keep_local) - integrity_check(out) + integrity_check(out, keep_local=keep_local) @pytest.mark.parametrize('with_args', [True, False]) From cf0fe95c9a99b8db6131fe73b9fc717e1ae374c0 Mon Sep 17 00:00:00 2001 From: Xiaohan Zhang Date: Fri, 29 Sep 2023 16:48:31 -0700 Subject: [PATCH 17/59] fix tests --- streaming/base/storage/download.py | 83 ++++++++++++++++++------------ streaming/base/util.py | 18 +++++-- tests/test_list_objects.py | 51 +++++++++--------- tests/test_util.py | 18 +++++-- 4 files changed, 102 insertions(+), 68 deletions(-) diff --git a/streaming/base/storage/download.py b/streaming/base/storage/download.py index 65c0d89af..a09362e61 100644 --- a/streaming/base/storage/download.py +++ b/streaming/base/storage/download.py @@ -490,32 +490,35 @@ def wait_for_download(local: str, timeout: float = 60) -> None: f'Waited longer than {timeout}s for other worker to download {local}.') sleep(0.25) + def remove_prefix(obj: str): - """Remove prefix from ob + """Remove prefix from ob. Args: obj (str): take form of 'path/to/folder' return: (str): take form of 'to/folder' """ - return "/".join(obj.strip("/").split('/')[1:]) + return '/'.join(obj.strip('/').split('/')[1:]) + -def list_objects_from_s3(remote: str, timeout: int = 60) -> List[str]: +def list_objects_from_s3(remote: str, timeout: float = 60) -> Optional[List[str]]: """List objects from remote AWS S3. Args: remote (str): Remote path (S3). + timeout (float): How long to wait for objects to be returned. """ import boto3 - from boto3.s3.transfer import TransferConfig from botocore import UNSIGNED from botocore.config import Config from botocore.exceptions import ClientError, NoCredentialsError - def _list_objects(obj, unsigned: bool = False) -> None: + def _list_objects(obj: urllib.parse.ParseResult, unsigned: bool = False) -> List[str]: """List the objects from AWS S3 bucket. The bucket can be either public or private. Args: + obj (ParseResult): ParseResult object of remote. unsigned (bool, optional): Set to True if it is a public bucket. Defaults to ``False``. """ @@ -542,7 +545,6 @@ def _list_objects(obj, unsigned: bool = False) -> None: ans.append(remove_prefix(o['Key'])) return ans - obj = urllib.parse.urlparse(remote) if obj.scheme != 's3': raise ValueError( @@ -567,22 +569,22 @@ def _list_objects(obj, unsigned: bool = False) -> None: raise -def list_objects_from_gcs(remote: str, timeout: int = 60) -> List[str]: +def list_objects_from_gcs(remote: str, timeout: float = 60) -> Optional[List[str]]: """List objects from remote Google Cloud Bucket. Args: remote (str): Remote path (S3). + timeout (float): How long to wait for objects to be returned. """ from google.auth.exceptions import DefaultCredentialsError - def _gcs_with_hmac(obj: urllib.parse.ParseResult) -> None: + def _gcs_with_hmac(obj: urllib.parse.ParseResult) -> Optional[List[str]]: """Return a list of objects from remote GCS using user level credentials. Args: obj (ParseResult): ParseResult object of remote. 
""" import boto3 - from boto3.s3.transfer import TransferConfig from botocore.exceptions import ClientError # Create a new session per thread @@ -594,31 +596,34 @@ def _gcs_with_hmac(obj: urllib.parse.ParseResult) -> None: aws_access_key_id=os.environ['GCS_KEY'], aws_secret_access_key=os.environ['GCS_SECRET']) try: - response = gcs_client.list_objects_v2(Bucket=obj.netloc, - Prefix=obj.path.lstrip("/")) + response = gcs_client.list_objects_v2(Bucket=obj.netloc, Prefix=obj.path.lstrip('/')) if response and 'Contents' in response: return [remove_prefix(ob['Key']) for ob in response['Contents']] except ClientError as e: if e.response['Error']['Code'] in BOTOCORE_CLIENT_ERROR_CODES: - raise FileNotFoundError(f'Object {obj.sheme}, {obj.netloc}, {obj.path} not found.') from e + raise FileNotFoundError( + f'Object {obj.scheme}, {obj.netloc}, {obj.path} not found.') from e except Exception: raise - def _gcs_with_service_account(obj: urllib.parse.ParseResult) -> None: + def _gcs_with_service_account(obj: urllib.parse.ParseResult) -> Optional[List[str]]: """Return a list of objects from remote GCS using service account credentials. Args: obj (ParseResult): ParseResult object of remote path (GCS). """ from google.auth import default as default_auth - from google.cloud.storage import Blob, Bucket, Client + from google.cloud.storage import Client credentials, _ = default_auth() gcs_client = Client(credentials=credentials) bucket = gcs_client.get_bucket(obj.netloc, timeout=60.0) objects = bucket.list_blobs(prefix=obj.path.lstrip('/')) - return [remove_prefix(ob.name) for ob in objects] + ans = [] + for ob in objects: + ans.append(remove_prefix(ob.name)) + return ans obj = urllib.parse.urlparse(remote) if obj.scheme != 'gs': @@ -626,7 +631,10 @@ def _gcs_with_service_account(obj: urllib.parse.ParseResult) -> None: f'Expected obj.scheme to be `gs`, instead, got {obj.scheme} for remote={remote}') if 'GCS_KEY' in os.environ and 'GCS_SECRET' in os.environ: - return _gcs_with_hmac(obj) + try: + return _gcs_with_hmac(obj) + except (DefaultCredentialsError, EnvironmentError): + raise ValueError(GCS_ERROR_NO_AUTHENTICATION) else: try: return _gcs_with_service_account(obj) @@ -634,7 +642,7 @@ def _gcs_with_service_account(obj: urllib.parse.ParseResult) -> None: raise ValueError(GCS_ERROR_NO_AUTHENTICATION) -def list_objects_from_oci(remote: str) -> List[str]: +def list_objects_from_oci(remote: str) -> Optional[List[str]]: """List objects from remote OCI to local. Args: @@ -650,16 +658,16 @@ def list_objects_from_oci(remote: str) -> List[str]: raise ValueError( f'Expected obj.scheme to be `oci`, instead, got {obj.scheme} for remote={remote}') - object_names = [] + object_names = [''] next_start_with = None response_complete = False - + namespace = client.get_namespace().data while not response_complete: - response = client.list_objects(namespace_name=client.get_namespace().data, + response = client.list_objects(namespace_name=namespace, bucket_name=obj.netloc.split('@' + namespace)[0], prefix=obj.path.strip('/'), start=next_start_with).data - object_names.extend([obj.name for obj in response.objects]) + object_names.extend([remove_prefix(obj.name) for obj in response.objects]) next_start_with = response.next_start_with if not next_start_with: response_complete = True @@ -667,18 +675,20 @@ def list_objects_from_oci(remote: str) -> List[str]: return object_names -def list_objects_from_local(path: Optional[str]) -> List[str]: - """List objects from a local directory. List current directory if path is None. 
Return path if path is not a directory. +def list_objects_from_local(path: Optional[str]) -> Optional[List[str]]: + """List objects from a local directory. Args: path (str): absolute path or None. + + Notes: + List current directory if path is None. + Raise error if path is a file """ if not path: return os.listdir() - if os.path.isdir(path): - return os.listdir(path) + return os.listdir(path) - return path def list_objects(remote: Optional[str]) -> List[str]: """Use the correct cloud handler to list objects. @@ -688,7 +698,10 @@ def list_objects(remote: Optional[str]) -> List[str]: If remote is None or '', list current working directory with os.listdir() """ if not remote: # '' or None - return list_objects_from_local(remote) + ans = list_objects_from_local(remote) + if not ans: + return [''] + return ans # fix paths for windows if remote: @@ -697,12 +710,16 @@ def list_objects(remote: Optional[str]) -> List[str]: obj = urllib.parse.urlparse(remote) if obj.scheme == '': - return list_objects_from_local(remote) + ans = list_objects_from_local(remote) elif obj.scheme == 's3': - return list_objects_from_s3(remote) - elif remote.startswith('gs'): - return list_objects_from_gcs(remote) - elif remote.startswith('oci'): - return list_objects_from_oci(remote) + ans = list_objects_from_s3(remote) + elif obj.scheme == 'gs': + ans = list_objects_from_gcs(remote) + elif obj.scheme == 'oci': + ans = list_objects_from_oci(remote) else: raise NotImplementedError + + if not ans: + return [''] + return ans diff --git a/streaming/base/util.py b/streaming/base/util.py index 086fbe4fb..cbc3c7728 100644 --- a/streaming/base/util.py +++ b/streaming/base/util.py @@ -290,7 +290,7 @@ def do_merge_index(folder_urls: List[Any], local_path = os.path.join(local, index_basename) download_file(remote_url, local_path, download_timeout) except Exception as ex: - raise RuntimeError(f'Failed to download index.json {remote_url}') from ex + raise RuntimeError(f'Failed to download index.json') from ex if not (os.path.exists(local)): raise FileNotFoundError(f'Folder {local} does not exist or not accessible.') @@ -353,16 +353,24 @@ def merge_index(root_to_mds: Union[str, Tuple[str, str]], cu = CloudUploader.get(root_to_mds, exist_ok=True, keep_local=True) if isinstance(root_to_mds, tuple): - local_folders = [os.path.join(cu.local, os.path.dirname(o)) for o in list_objects(root_to_mds[0])] - remote_folders = [os.path.join(cu.remote, os.path.dirname(o)) for o in list_objects(root_to_mds[1])] + local_folders = [ + os.path.join(cu.local, os.path.dirname(o)) for o in list_objects(root_to_mds[0]) + ] + remote_folders = [ + os.path.join(cu.remote, os.path.dirname(o)) for o in list_objects(root_to_mds[1]) + ] folder_urls = list(zip(local_folders, remote_folders)) else: print('I am here 3.1', root_to_mds) print('I am here 3', list_objects(root_to_mds)) if cu.remote: - folder_urls = [os.path.join(cu.remote, os.path.dirname(o)) for o in list_objects(root_to_mds)] + folder_urls = [ + os.path.join(cu.remote, os.path.dirname(o)) for o in list_objects(root_to_mds) + ] else: - folder_urls = [os.path.join(cu.local, os.path.dirname(o)) for o in list_objects(root_to_mds)] + folder_urls = [ + os.path.join(cu.local, os.path.dirname(o)) for o in list_objects(root_to_mds) + ] do_merge_index(folder_urls, root_to_mds, keep_local=keep_local, download_timeout=60) diff --git a/tests/test_list_objects.py b/tests/test_list_objects.py index 8c6cc8b7e..90f22efc7 100644 --- a/tests/test_list_objects.py +++ b/tests/test_list_objects.py @@ -8,7 +8,6 @@ import 
boto3 import pytest -from botocore.exceptions import ClientError from streaming.base.storage.download import (list_objects, list_objects_from_gcs, list_objects_from_local, list_objects_from_s3) @@ -52,20 +51,21 @@ def test_list_objects_from_s3_with_endpoint_URL(self, remote_local_file: Any): mock_remote_filepath, _ = remote_local_file(cloud_prefix='s3://', filename=file_name) client = boto3.client('s3', region_name='us-east-1') client.put_object(Bucket=MY_BUCKET, Key=os.path.join(MY_PREFIX, file_name), Body='') - objs = list_objects_from_s3(mock_remote_filepath) + _ = list_objects_from_s3(mock_remote_filepath) assert os.environ['S3_ENDPOINT_URL'] == R2_URL @pytest.mark.usefixtures('s3_client', 's3_test', 'remote_local_file') def test_clienterror_exception(self, remote_local_file: Any): - mock_remote_filepath, mock_local_filepath = remote_local_file(cloud_prefix='s3://') + mock_remote_filepath, _ = remote_local_file(cloud_prefix='s3://') objs = list_objects_from_s3(mock_remote_filepath) - assert(len(objs) == 0) + if objs: + assert (len(objs) == 0) @pytest.mark.usefixtures('s3_client', 's3_test', 'remote_local_file') def test_invalid_cloud_prefix(self, remote_local_file: Any): with pytest.raises(ValueError): - mock_remote_filepath, mock_local_filepath = remote_local_file(cloud_prefix='s9://') - objs = list_objects_from_s3(mock_remote_filepath) + mock_remote_filepath, _ = remote_local_file(cloud_prefix='s9://') + _ = list_objects_from_s3(mock_remote_filepath) class TestGCSClient: @@ -89,7 +89,7 @@ def test_list_objects_from_gcs(self, remote_local_file: Any): @pytest.mark.usefixtures('gcs_service_account_credentials') @pytest.mark.parametrize('out', ['gs://bucket/dir']) def test_download_service_account(self, mock_client: Mock, mock_default: Mock, out: str): - with tempfile.NamedTemporaryFile(delete=True, suffix='.txt') as tmp: + with tempfile.NamedTemporaryFile(delete=True, suffix='.txt') as _: credentials_mock = Mock() mock_default.return_value = credentials_mock, None objs = list_objects_from_gcs(out) @@ -98,20 +98,20 @@ def test_download_service_account(self, mock_client: Mock, mock_default: Mock, o @pytest.mark.usefixtures('gcs_hmac_client', 'gcs_test', 'remote_local_file') def test_filenotfound_exception(self, remote_local_file: Any): - mock_remote_filepath, mock_local_filepath = remote_local_file(cloud_prefix='gs://') - objs = list_objects_from_gcs(mock_remote_filepath) + mock_remote_filepath, _ = remote_local_file(cloud_prefix='gs://') + _ = list_objects_from_gcs(mock_remote_filepath) @pytest.mark.usefixtures('gcs_hmac_client', 'gcs_test', 'remote_local_file') def test_invalid_cloud_prefix(self, remote_local_file: Any): with pytest.raises(ValueError): - mock_remote_filepath, mock_local_filepath = remote_local_file(cloud_prefix='s3://') - objs = list_objects_from_gcs(mock_remote_filepath) + mock_remote_filepath, _ = remote_local_file(cloud_prefix='s3://') + _ = list_objects_from_gcs(mock_remote_filepath) def test_no_credentials_error(self, remote_local_file: Any): """Ensure we raise a value error correctly if we have no credentials available.""" with pytest.raises(ValueError): - mock_remote_filepath, mock_local_filepath = remote_local_file(cloud_prefix='gs://') - objs = list_objects_from_gcs(mock_remote_filepath) + mock_remote_filepath, _ = remote_local_file(cloud_prefix='gs://') + _ = list_objects_from_gcs(mock_remote_filepath) def test_list_objects_from_local(): @@ -121,9 +121,8 @@ def test_list_objects_from_local(): # Creates a new empty file with open(mock_local_file, 'w') as _: 
pass - - objs = list_objects_from_local(mock_local_file) - assert objs == mock_local_file + with pytest.raises(NotADirectoryError): + objs = list_objects_from_local(mock_local_file) class TestListObjects: @@ -131,28 +130,30 @@ class TestListObjects: @patch('streaming.base.storage.download.list_objects_from_s3') @pytest.mark.usefixtures('remote_local_file') def test_list_objects_from_s3_gets_called(self, mocked_requests: Mock, remote_local_file: Any): - mock_remote_filepath, mock_local_filepath = remote_local_file(cloud_prefix='s3://') - objs = list_objects(mock_remote_filepath) + mock_remote_filepath, _ = remote_local_file(cloud_prefix='s3://') + list_objects(mock_remote_filepath) mocked_requests.assert_called_once() mocked_requests.assert_called_once_with(mock_remote_filepath) @patch('streaming.base.storage.download.list_objects_from_gcs') @pytest.mark.usefixtures('remote_local_file') - def test_list_objects_from_gcs_gets_called(self, mocked_requests: Mock, remote_local_file: Any): - mock_remote_filepath, mock_local_filepath = remote_local_file(cloud_prefix='gs://') - objs = list_objects(mock_remote_filepath) + def test_list_objects_from_gcs_gets_called(self, mocked_requests: Mock, + remote_local_file: Any): + mock_remote_filepath, _ = remote_local_file(cloud_prefix='gs://') + list_objects(mock_remote_filepath) mocked_requests.assert_called_once() mocked_requests.assert_called_once_with(mock_remote_filepath) @patch('streaming.base.storage.download.list_objects_from_local') @pytest.mark.usefixtures('remote_local_file') - def test_list_objects_from_local_gets_called(self, mocked_requests: Mock, remote_local_file: Any): - mock_remote_filepath, mock_local_filepath = remote_local_file() - objs = list_objects(mock_remote_filepath) + def test_list_objects_from_local_gets_called(self, mocked_requests: Mock, + remote_local_file: Any): + mock_remote_filepath, _ = remote_local_file() + list_objects(mock_remote_filepath) mocked_requests.assert_called_once() mocked_requests.assert_called_once_with(mock_remote_filepath) @pytest.mark.usefixtures('remote_local_file') def test_list_objects_invalid_missing_remote(self): obj = list_objects(None) - assert(obj == os.listdir()) + assert (obj == os.listdir()) diff --git a/tests/test_util.py b/tests/test_util.py index f5dd7b8a2..2996b7c49 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -14,8 +14,8 @@ from streaming.base.shared.prefix import _get_path from streaming.base.storage.download import download_file, list_objects from streaming.base.storage.upload import CloudUploader -from streaming.base.util import (bytes_to_int, clean_stale_shared_memory, do_merge_index, merge_index, - get_list_arg, number_abbrev_to_int, retry) +from streaming.base.util import (bytes_to_int, clean_stale_shared_memory, do_merge_index, + get_list_arg, merge_index, number_abbrev_to_int, retry) MY_PREFIX = 'train' MY_BUCKET = 'mosaicml-composer-tests' @@ -149,9 +149,14 @@ def test_clean_stale_shared_memory(): with pytest.raises(FileNotFoundError): _ = BuiltinSharedMemory(name, False, 64) -def integrity_check(out: Union[str, Tuple[str, str]], keep_local): + +def integrity_check(out: Union[str, Tuple[str, str]], keep_local: bool): """ Check if merged_index file has integrity If merged_index is a cloud url, first download it to a temp local file. 
+ + Args: + out (Union[str, Tuple[str,str]]): folder that merged index.json resides + keep_local: whether to check local file """ cu = CloudUploader.get(out, keep_local=True, exist_ok=True) @@ -175,12 +180,15 @@ def integrity_check(out: Union[str, Tuple[str, str]], keep_local): n_shard_files = len({b['raw_data']['basename'] for b in merged_index['shards']}) assert (n_shard_files == 2), 'expected 2 shard files but got {n_shard_files}' + def test_merge_index(manual_integration_dir: Any): from decimal import Decimal - from streaming.base.converters import dataframeToMDS + from pyspark.sql import SparkSession from pyspark.sql.types import DecimalType, IntegerType, StringType, StructField, StructType + from streaming.base.converters import dataframeToMDS + spark = SparkSession.builder.getOrCreate() # pyright: ignore schema = StructType([ StructField('id', IntegerType(), nullable=False), @@ -206,7 +214,7 @@ def test_merge_index(manual_integration_dir: Any): mds_path, _ = dataframeToMDS(df, merge_index=False, mds_kwargs=mds_kwargs) print('mds_path = ', mds_path) - print(list_objects("gs://mosaicml-composer-tests/train/")) + print(list_objects('gs://mosaicml-composer-tests/train/')) merge_index(remote) integrity_check(remote, keep_local=True) From 5e0a1b3c8316b7af3045cb86108293f8676abd36 Mon Sep 17 00:00:00 2001 From: Xiaohan Zhang Date: Fri, 29 Sep 2023 16:49:22 -0700 Subject: [PATCH 18/59] Fix lints --- tests/test_list_objects.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_list_objects.py b/tests/test_list_objects.py index 90f22efc7..44a14a17f 100644 --- a/tests/test_list_objects.py +++ b/tests/test_list_objects.py @@ -122,7 +122,7 @@ def test_list_objects_from_local(): with open(mock_local_file, 'w') as _: pass with pytest.raises(NotADirectoryError): - objs = list_objects_from_local(mock_local_file) + _ = list_objects_from_local(mock_local_file) class TestListObjects: From cbcb3526843b337834b0eda6a3f5be7a5c12f76e Mon Sep 17 00:00:00 2001 From: Xiaohan Zhang Date: Fri, 29 Sep 2023 23:20:43 -0700 Subject: [PATCH 19/59] list_objects returns only basename --- streaming/base/storage/download.py | 11 +++---- streaming/base/util.py | 8 ++--- .../base/converters/test_dataframe_to_mds.py | 7 ++-- tests/test_util.py | 32 +++++++++---------- 4 files changed, 27 insertions(+), 31 deletions(-) diff --git a/streaming/base/storage/download.py b/streaming/base/storage/download.py index a09362e61..a4c61e00e 100644 --- a/streaming/base/storage/download.py +++ b/streaming/base/storage/download.py @@ -675,7 +675,7 @@ def list_objects_from_oci(remote: str) -> Optional[List[str]]: return object_names -def list_objects_from_local(path: Optional[str]) -> Optional[List[str]]: +def list_objects_from_local(path: Optional[str]) -> List[str]: """List objects from a local directory. 
Args: @@ -698,10 +698,7 @@ def list_objects(remote: Optional[str]) -> List[str]: If remote is None or '', list current working directory with os.listdir() """ if not remote: # '' or None - ans = list_objects_from_local(remote) - if not ans: - return [''] - return ans + return list_objects_from_local(remote) # fix paths for windows if remote: @@ -710,7 +707,7 @@ def list_objects(remote: Optional[str]) -> List[str]: obj = urllib.parse.urlparse(remote) if obj.scheme == '': - ans = list_objects_from_local(remote) + return list_objects_from_local(remote) elif obj.scheme == 's3': ans = list_objects_from_s3(remote) elif obj.scheme == 'gs': @@ -722,4 +719,4 @@ def list_objects(remote: Optional[str]) -> List[str]: if not ans: return [''] - return ans + return [ os.path.dirname(o) for o in ans] diff --git a/streaming/base/util.py b/streaming/base/util.py index cbc3c7728..f8e6d216a 100644 --- a/streaming/base/util.py +++ b/streaming/base/util.py @@ -354,10 +354,10 @@ def merge_index(root_to_mds: Union[str, Tuple[str, str]], cu = CloudUploader.get(root_to_mds, exist_ok=True, keep_local=True) if isinstance(root_to_mds, tuple): local_folders = [ - os.path.join(cu.local, os.path.dirname(o)) for o in list_objects(root_to_mds[0]) + os.path.join(cu.local, o) for o in list_objects(root_to_mds[0]) ] remote_folders = [ - os.path.join(cu.remote, os.path.dirname(o)) for o in list_objects(root_to_mds[1]) + os.path.join(cu.remote, o) for o in list_objects(root_to_mds[1]) ] folder_urls = list(zip(local_folders, remote_folders)) else: @@ -365,11 +365,11 @@ def merge_index(root_to_mds: Union[str, Tuple[str, str]], print('I am here 3', list_objects(root_to_mds)) if cu.remote: folder_urls = [ - os.path.join(cu.remote, os.path.dirname(o)) for o in list_objects(root_to_mds) + os.path.join(cu.remote, o) for o in list_objects(root_to_mds) ] else: folder_urls = [ - os.path.join(cu.local, os.path.dirname(o)) for o in list_objects(root_to_mds) + os.path.join(cu.local, o) for o in list_objects(root_to_mds) ] do_merge_index(folder_urls, root_to_mds, keep_local=keep_local, download_timeout=60) diff --git a/tests/base/converters/test_dataframe_to_mds.py b/tests/base/converters/test_dataframe_to_mds.py index 3672e7040..cd936f774 100644 --- a/tests/base/converters/test_dataframe_to_mds.py +++ b/tests/base/converters/test_dataframe_to_mds.py @@ -18,7 +18,7 @@ MY_PREFIX = 'train' MY_BUCKET = 'mosaicml-composer-tests' -MANUAL_INTEGRATION_TEST = True +MANUAL_INTEGRATION_TEST = False os.environ[ 'OBJC_DISABLE_INITIALIZE_FORK_SAFETY'] = 'YES' # set to yes to all fork process in spark calls @@ -27,9 +27,8 @@ def manual_integration_dir() -> Any: """Creates a temporary directory and then deletes it when the calling function is done.""" if MANUAL_INTEGRATION_TEST: - #os.environ['GOOGLE_APPLICATION_CREDENTIALS'] = 'path/to/gooogle_api_credential.json' - os.environ[ - 'GOOGLE_APPLICATION_CREDENTIALS'] = '/Users/xiaohan.zhang/.mosaic/mosaicml-research-nonprod-027345ddbdfd.json' + os.environ['GOOGLE_APPLICATION_CREDENTIALS'] = 'path/to/gooogle_api_credential.json' + #os.environ['GOOGLE_APPLICATION_CREDENTIALS'] = '/Users/xiaohan.zhang/.mosaic/mosaicml-research-nonprod-027345ddbdfd.json' tmp_dir = mkdtemp() diff --git a/tests/test_util.py b/tests/test_util.py index 2996b7c49..0c10aeaab 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -19,7 +19,7 @@ MY_PREFIX = 'train' MY_BUCKET = 'mosaicml-composer-tests' -MANUAL_INTEGRATION_TEST = True +MANUAL_INTEGRATION_TEST = False os.environ[ 'OBJC_DISABLE_INITIALIZE_FORK_SAFETY'] = 'YES' # set to 
yes to all fork process in spark calls @@ -28,9 +28,8 @@ def manual_integration_dir() -> Any: """Creates a temporary directory and then deletes it when the calling function is done.""" if MANUAL_INTEGRATION_TEST: - #os.environ['GOOGLE_APPLICATION_CREDENTIALS'] = 'path/to/gooogle_api_credential.json' - os.environ[ - 'GOOGLE_APPLICATION_CREDENTIALS'] = '/Users/xiaohan.zhang/.mosaic/mosaicml-research-nonprod-027345ddbdfd.json' + os.environ['GOOGLE_APPLICATION_CREDENTIALS'] = 'path/to/gooogle_api_credential.json' + #os.environ['GOOGLE_APPLICATION_CREDENTIALS'] = '/Users/xiaohan.zhang/.mosaic/mosaicml-research-nonprod-027345ddbdfd.json' tmp_dir = tempfile.mkdtemp() @@ -181,14 +180,21 @@ def integrity_check(out: Union[str, Tuple[str, str]], keep_local: bool): assert (n_shard_files == 2), 'expected 2 shard files but got {n_shard_files}' -def test_merge_index(manual_integration_dir: Any): +@pytest.mark.parametrize('output_format', ['local', 'remote']) +def test_merge_index(manual_integration_dir: Any, output_format: str): from decimal import Decimal - from pyspark.sql import SparkSession from pyspark.sql.types import DecimalType, IntegerType, StringType, StructField, StructType - from streaming.base.converters import dataframeToMDS + if output_format == 'remote': + if not MANUAL_INTEGRATION_TEST: + pytest.skip('Require cloud credentials. ' + + 'skipping. Set MANUAL_INTEGRATION_TEST=True to run the check manually!') + _, out = manual_integration_dir() + else: + out, _ = manual_integration_dir() + spark = SparkSession.builder.getOrCreate() # pyright: ignore schema = StructType([ StructField('id', IntegerType(), nullable=False), @@ -201,23 +207,17 @@ def test_merge_index(manual_integration_dir: Any): df = spark.createDataFrame(data=data, schema=schema).repartition(3) - _, remote = manual_integration_dir() mds_kwargs = { - 'out': remote, + 'out': out, 'columns': { 'id': 'int', 'name': 'str' }, } - print('I am here 0: remote = ', remote) mds_path, _ = dataframeToMDS(df, merge_index=False, mds_kwargs=mds_kwargs) - - print('mds_path = ', mds_path) - print(list_objects('gs://mosaicml-composer-tests/train/')) - merge_index(remote) - - integrity_check(remote, keep_local=True) + merge_index(out) + integrity_check(out, keep_local=True) @pytest.mark.parametrize('folder_urls_pattern', [1, 2, 3, 4, 5]) From 81c3b88b24ccd605c96d61b6a66564fa2bb8350f Mon Sep 17 00:00:00 2001 From: Xiaohan Zhang Date: Fri, 29 Sep 2023 23:26:01 -0700 Subject: [PATCH 20/59] Fix lints --- streaming/base/storage/download.py | 2 +- streaming/base/util.py | 16 ++++------------ tests/test_util.py | 6 ++++-- 3 files changed, 9 insertions(+), 15 deletions(-) diff --git a/streaming/base/storage/download.py b/streaming/base/storage/download.py index a4c61e00e..14b7bdfd2 100644 --- a/streaming/base/storage/download.py +++ b/streaming/base/storage/download.py @@ -719,4 +719,4 @@ def list_objects(remote: Optional[str]) -> List[str]: if not ans: return [''] - return [ os.path.dirname(o) for o in ans] + return [os.path.dirname(o) for o in ans] diff --git a/streaming/base/util.py b/streaming/base/util.py index f8e6d216a..0bb6810fe 100644 --- a/streaming/base/util.py +++ b/streaming/base/util.py @@ -353,24 +353,16 @@ def merge_index(root_to_mds: Union[str, Tuple[str, str]], cu = CloudUploader.get(root_to_mds, exist_ok=True, keep_local=True) if isinstance(root_to_mds, tuple): - local_folders = [ - os.path.join(cu.local, o) for o in list_objects(root_to_mds[0]) - ] - remote_folders = [ - os.path.join(cu.remote, o) for o in list_objects(root_to_mds[1]) 
- ] + local_folders = [os.path.join(cu.local, o) for o in list_objects(root_to_mds[0])] + remote_folders = [os.path.join(cu.remote, o) for o in list_objects(root_to_mds[1])] folder_urls = list(zip(local_folders, remote_folders)) else: print('I am here 3.1', root_to_mds) print('I am here 3', list_objects(root_to_mds)) if cu.remote: - folder_urls = [ - os.path.join(cu.remote, o) for o in list_objects(root_to_mds) - ] + folder_urls = [os.path.join(cu.remote, o) for o in list_objects(root_to_mds)] else: - folder_urls = [ - os.path.join(cu.local, o) for o in list_objects(root_to_mds) - ] + folder_urls = [os.path.join(cu.local, o) for o in list_objects(root_to_mds)] do_merge_index(folder_urls, root_to_mds, keep_local=keep_local, download_timeout=60) diff --git a/tests/test_util.py b/tests/test_util.py index 0c10aeaab..555649261 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -12,7 +12,7 @@ from streaming.base.constant import RESUME from streaming.base.shared.prefix import _get_path -from streaming.base.storage.download import download_file, list_objects +from streaming.base.storage.download import download_file from streaming.base.storage.upload import CloudUploader from streaming.base.util import (bytes_to_int, clean_stale_shared_memory, do_merge_index, get_list_arg, merge_index, number_abbrev_to_int, retry) @@ -183,8 +183,10 @@ def integrity_check(out: Union[str, Tuple[str, str]], keep_local: bool): @pytest.mark.parametrize('output_format', ['local', 'remote']) def test_merge_index(manual_integration_dir: Any, output_format: str): from decimal import Decimal + from pyspark.sql import SparkSession from pyspark.sql.types import DecimalType, IntegerType, StringType, StructField, StructType + from streaming.base.converters import dataframeToMDS if output_format == 'remote': @@ -215,7 +217,7 @@ def test_merge_index(manual_integration_dir: Any, output_format: str): }, } - mds_path, _ = dataframeToMDS(df, merge_index=False, mds_kwargs=mds_kwargs) + dataframeToMDS(df, merge_index=False, mds_kwargs=mds_kwargs) merge_index(out) integrity_check(out, keep_local=True) From a1185d3a471e5ff33319386d725c8a97370c76f7 Mon Sep 17 00:00:00 2001 From: Xiaohan Zhang Date: Sat, 30 Sep 2023 01:16:53 -0700 Subject: [PATCH 21/59] fix bugs in list_objects --- streaming/base/storage/download.py | 14 +++++++++-- streaming/base/util.py | 34 ++++++++++++++++++++------ tests/test_util.py | 39 +++++++++++++++++++++++------- 3 files changed, 69 insertions(+), 18 deletions(-) diff --git a/streaming/base/storage/download.py b/streaming/base/storage/download.py index 14b7bdfd2..74c895ea3 100644 --- a/streaming/base/storage/download.py +++ b/streaming/base/storage/download.py @@ -499,7 +499,7 @@ def remove_prefix(obj: str): return: (str): take form of 'to/folder' """ - return '/'.join(obj.strip('/').split('/')[1:]) + return obj # '/'.join(obj.strip('/').split('/')[1:]) def list_objects_from_s3(remote: str, timeout: float = 60) -> Optional[List[str]]: @@ -719,4 +719,14 @@ def list_objects(remote: Optional[str]) -> List[str]: if not ans: return [''] - return [os.path.dirname(o) for o in ans] + level_one_list = [] + for o in ans: + print('I am here 5', o) + suffix = o[len(obj.path):] + print('I am here 5.1', suffix, obj.path) + if '/' in suffix.strip('/'): + level_one_list.append(os.path.dirname(suffix)) + print('I am here 5.2', os.path.dirname(suffix)) + else: + level_one_list.append(suffix) + return list(set(level_one_list)) diff --git a/streaming/base/util.py b/streaming/base/util.py index 0bb6810fe..cde5df4a0 100644 
--- a/streaming/base/util.py +++ b/streaming/base/util.py @@ -285,12 +285,12 @@ def do_merge_index(folder_urls: List[Any], # If download is needed, download url from remote to temp_root path = urllib.parse.urlparse(remote).path local = os.path.join(temp_root, path.lstrip('/')) + remote_url = os.path.join(remote, index_basename) + local_path = os.path.join(local, index_basename) try: - remote_url = os.path.join(remote, index_basename) - local_path = os.path.join(local, index_basename) download_file(remote_url, local_path, download_timeout) except Exception as ex: - raise RuntimeError(f'Failed to download index.json') from ex + raise RuntimeError(f'Failed to download index.json: {remote_url}') from ex if not (os.path.exists(local)): raise FileNotFoundError(f'Folder {local} does not exist or not accessible.') @@ -353,16 +353,36 @@ def merge_index(root_to_mds: Union[str, Tuple[str, str]], cu = CloudUploader.get(root_to_mds, exist_ok=True, keep_local=True) if isinstance(root_to_mds, tuple): - local_folders = [os.path.join(cu.local, o) for o in list_objects(root_to_mds[0])] - remote_folders = [os.path.join(cu.remote, o) for o in list_objects(root_to_mds[1])] + local_folders = [ + os.path.join(cu.local, o) + for o in list_objects(root_to_mds[0]) + if not o.endswith('.json') + ] + remote_folders = [ + os.path.join(cu.remote, o) + for o in list_objects(root_to_mds[1]) + if not o.endswith('.json') + ] folder_urls = list(zip(local_folders, remote_folders)) else: print('I am here 3.1', root_to_mds) print('I am here 3', list_objects(root_to_mds)) if cu.remote: - folder_urls = [os.path.join(cu.remote, o) for o in list_objects(root_to_mds)] + folder_urls = [ + os.path.join(cu.remote, o) + for o in list_objects(root_to_mds) + if not o.endswith('.json') + ] else: - folder_urls = [os.path.join(cu.local, o) for o in list_objects(root_to_mds)] + folder_urls = [ + os.path.join(cu.local, o) + for o in list_objects(root_to_mds) + if not o.endswith('.json') + ] + + print('I am here 3.2') + for fu in folder_urls: + print(list_objects(fu)) do_merge_index(folder_urls, root_to_mds, keep_local=keep_local, download_timeout=60) diff --git a/tests/test_util.py b/tests/test_util.py index 555649261..2f280f96d 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -12,7 +12,7 @@ from streaming.base.constant import RESUME from streaming.base.shared.prefix import _get_path -from streaming.base.storage.download import download_file +from streaming.base.storage.download import download_file, list_objects from streaming.base.storage.upload import CloudUploader from streaming.base.util import (bytes_to_int, clean_stale_shared_memory, do_merge_index, get_list_arg, merge_index, number_abbrev_to_int, retry) @@ -28,8 +28,9 @@ def manual_integration_dir() -> Any: """Creates a temporary directory and then deletes it when the calling function is done.""" if MANUAL_INTEGRATION_TEST: - os.environ['GOOGLE_APPLICATION_CREDENTIALS'] = 'path/to/gooogle_api_credential.json' - #os.environ['GOOGLE_APPLICATION_CREDENTIALS'] = '/Users/xiaohan.zhang/.mosaic/mosaicml-research-nonprod-027345ddbdfd.json' + #os.environ['GOOGLE_APPLICATION_CREDENTIALS'] = 'path/to/gooogle_api_credential.json' + os.environ[ + 'GOOGLE_APPLICATION_CREDENTIALS'] = '/Users/xiaohan.zhang/.mosaic/mosaicml-research-nonprod-027345ddbdfd.json' tmp_dir = tempfile.mkdtemp() @@ -149,26 +150,44 @@ def test_clean_stale_shared_memory(): _ = BuiltinSharedMemory(name, False, 64) -def integrity_check(out: Union[str, Tuple[str, str]], keep_local: bool): +def integrity_check(out: 
Union[str, Tuple[str, str]], + keep_local: bool, + expected_n_shard_files: int = -1): """ Check if merged_index file has integrity If merged_index is a cloud url, first download it to a temp local file. Args: out (Union[str, Tuple[str,str]]): folder that merged index.json resides keep_local: whether to check local file + expected_n_shard_files (int): If -1, find the number in out with get_expected() """ + def get_expected(mds_root: str): + n_shard_files = 0 + for o in list_objects(mds_root): + print('I am here 4.1: ', o) + if o.endswith('.json'): + continue + for b in list_objects(os.path.join(mds_root, o)): + print(f'I am here 4.2 for {os.path.join(mds_root, o)}: ', b) + if b.endswith('.mds'): + n_shard_files += 1 + return n_shard_files + cu = CloudUploader.get(out, keep_local=True, exist_ok=True) with tempfile.TemporaryDirectory() as temp_dir: - if cu.remote: download_file(os.path.join(cu.remote, 'index.json'), os.path.join(temp_dir, 'index.json'), timeout=60) + if expected_n_shard_files == -1: + expected_n_shard_files = get_expected(cu.remote) local_merged_index_path = os.path.join(temp_dir, 'index.json') else: local_merged_index_path = os.path.join(cu.local, 'index.json') + if expected_n_shard_files == -1: + expected_n_shard_files = get_expected(cu.local) if not keep_local: assert not os.path.exists(os.path.join(cu.local, 'index.json')) @@ -177,11 +196,13 @@ def integrity_check(out: Union[str, Tuple[str, str]], keep_local: bool): assert os.path.exists(local_merged_index_path) merged_index = json.load(open(local_merged_index_path, 'r')) n_shard_files = len({b['raw_data']['basename'] for b in merged_index['shards']}) - assert (n_shard_files == 2), 'expected 2 shard files but got {n_shard_files}' + assert (n_shard_files == expected_n_shard_files + ), f'expected {expected_n_shard_files} shard files but got {n_shard_files}' @pytest.mark.parametrize('output_format', ['local', 'remote']) -def test_merge_index(manual_integration_dir: Any, output_format: str): +@pytest.mark.parametrize('n_partitions', [1, 2, 3, 4]) +def test_merge_index(manual_integration_dir: Any, output_format: str, n_partitions: int): from decimal import Decimal from pyspark.sql import SparkSession @@ -207,7 +228,7 @@ def test_merge_index(manual_integration_dir: Any, output_format: str): data = [(1, 'Alice', Decimal('123.45')), (2, 'Bob', Decimal('67.89')), (3, 'Charlie', Decimal('987.65'))] - df = spark.createDataFrame(data=data, schema=schema).repartition(3) + df = spark.createDataFrame(data=data, schema=schema).repartition(n_partitions) mds_kwargs = { 'out': out, @@ -301,7 +322,7 @@ def test_do_merge_index(manual_integration_dir: Any, keep_local: bool, folder_ur folder_urls.append('gs://' + MY_BUCKET + '/' + s) do_merge_index(folder_urls, out, keep_local=keep_local) - integrity_check(out, keep_local=keep_local) + integrity_check(out, keep_local=keep_local, expected_n_shard_files=2) @pytest.mark.parametrize('with_args', [True, False]) From 19723027c979ba853ef82c3f72e6f3be19848e0f Mon Sep 17 00:00:00 2001 From: Xiaohan Zhang Date: Sun, 1 Oct 2023 22:51:26 -0700 Subject: [PATCH 22/59] updates --- streaming/base/converters/dataframe_to_mds.py | 22 ++++++-------- .../base/converters/test_dataframe_to_mds.py | 30 +++++++++---------- 2 files changed, 24 insertions(+), 28 deletions(-) diff --git a/streaming/base/converters/dataframe_to_mds.py b/streaming/base/converters/dataframe_to_mds.py index 04067c28f..3ddb576f5 100644 --- a/streaming/base/converters/dataframe_to_mds.py +++ b/streaming/base/converters/dataframe_to_mds.py @@ 
-158,11 +158,10 @@ def write_mds(iterator: Iterable): else: raise RuntimeError('TaskContext.get() returns None') - if isinstance(mds_path, str): # local - output_path = output = os.path.join(mds_path, f'{id}') + if mds_path[1] == '': + output = (os.path.join(mds_path[0], f'{id}'), '') else: output = (os.path.join(mds_path[0], f'{id}'), os.path.join(mds_path[1], f'{id}')) - output_path = ','.join(output) if mds_kwargs: kwargs = mds_kwargs.copy() @@ -194,7 +193,8 @@ def write_mds(iterator: Iterable): count += 1 yield pd.concat( - [pd.Series([output_path], name='mds_path'), + [pd.Series([output[0]], name='mds_path_local'), + pd.Series([output[1]], name='mds_path_remote'), pd.Series([count], name='fail_count')], axis=1) @@ -232,26 +232,22 @@ def write_mds(iterator: Iterable): cu = CloudUploader.get(out, keep_local=keep_local) print('cu.local = ', cu.local) - # Fix output format as mds_path: Tuple => remote Str => local only + # Fix output format as mds_path: Tuple(local, remote) if cu.remote is None: - mds_path = cu.local + mds_path = (cu.local, "") else: mds_path = (cu.local, cu.remote) # Prepare partition schema result_schema = StructType([ - StructField('mds_path', StringType(), False), + StructField('mds_path_local', StringType(), False), + StructField('mds_path_remote', StringType(), False), StructField('fail_count', IntegerType(), False) ]) partitions = dataframe.mapInPandas(func=write_mds, schema=result_schema).collect() if merge_index: - folder_urls = [] - for row in partitions: - if ',' in row['mds_path']: - folder_urls.append(row['mds_path'].split(',')) - else: - folder_urls.append(row['mds_path']) + folder_urls = [ (row['mds_path_local'], row['mds_path_remote']) for row in partitions ] do_merge_index(folder_urls, out, keep_local=keep_local, download_timeout=60) if cu.remote is not None: diff --git a/tests/base/converters/test_dataframe_to_mds.py b/tests/base/converters/test_dataframe_to_mds.py index cd936f774..ad18c02a8 100644 --- a/tests/base/converters/test_dataframe_to_mds.py +++ b/tests/base/converters/test_dataframe_to_mds.py @@ -17,7 +17,7 @@ from streaming.base.converters import dataframeToMDS MY_PREFIX = 'train' -MY_BUCKET = 'mosaicml-composer-tests' +MY_BUCKET= {'gs://': 'mosaicml-composer-tests', 's3://': 'streaming-upload-test-bucket'} MANUAL_INTEGRATION_TEST = False os.environ[ 'OBJC_DISABLE_INITIALIZE_FORK_SAFETY'] = 'YES' # set to yes to all fork process in spark calls @@ -27,30 +27,30 @@ def manual_integration_dir() -> Any: """Creates a temporary directory and then deletes it when the calling function is done.""" if MANUAL_INTEGRATION_TEST: - os.environ['GOOGLE_APPLICATION_CREDENTIALS'] = 'path/to/gooogle_api_credential.json' - #os.environ['GOOGLE_APPLICATION_CREDENTIALS'] = '/Users/xiaohan.zhang/.mosaic/mosaicml-research-nonprod-027345ddbdfd.json' + #os.environ['GOOGLE_APPLICATION_CREDENTIALS'] = 'path/to/gooogle_api_credential.json' + os.environ['GOOGLE_APPLICATION_CREDENTIALS'] = '/Users/xiaohan.zhang/.mosaic/mosaicml-research-nonprod-027345ddbdfd.json' tmp_dir = mkdtemp() - def _method(cloud_prefix: str = 'gs://') -> Tuple[str, str]: + def _method(cloud_prefix: str = 's3://') -> Tuple[str, str]: mock_local_dir = tmp_dir # mkdtemp() - mock_remote_dir = os.path.join(cloud_prefix, MY_BUCKET, MY_PREFIX) + mock_remote_dir = os.path.join(cloud_prefix, MY_BUCKET[cloud_prefix], MY_PREFIX) return mock_local_dir, mock_remote_dir try: yield _method finally: shutil.rmtree(tmp_dir, ignore_errors=True) # pyright: ignore - if MANUAL_INTEGRATION_TEST: - try: - from 
google.cloud.storage import Client - storage_client = Client() - bucket = storage_client.get_bucket(MY_BUCKET) - blobs = bucket.list_blobs(prefix=MY_PREFIX) - for blob in blobs: - blob.delete() - except ImportError: - raise ImportError('google.cloud.storage is not imported correctly.') + #if MANUAL_INTEGRATION_TEST: + # try: + # from google.cloud.storage import Client + # storage_client = Client() + # bucket = storage_client.get_bucket(MY_BUCKET) + # blobs = bucket.list_blobs(prefix=MY_PREFIX) + # for blob in blobs: + # blob.delete() + # except ImportError: + # raise ImportError('google.cloud.storage is not imported correctly.') class TestDataFrameToMDS: From 38f495f1b165dcbdc785ecfd00c16f30702fdd70 Mon Sep 17 00:00:00 2001 From: Xiaohan Zhang Date: Sun, 1 Oct 2023 22:51:53 -0700 Subject: [PATCH 23/59] Fix lints --- streaming/base/converters/dataframe_to_mds.py | 15 ++++++++------- tests/base/converters/test_dataframe_to_mds.py | 5 +++-- 2 files changed, 11 insertions(+), 9 deletions(-) diff --git a/streaming/base/converters/dataframe_to_mds.py b/streaming/base/converters/dataframe_to_mds.py index 3ddb576f5..e00bb17fc 100644 --- a/streaming/base/converters/dataframe_to_mds.py +++ b/streaming/base/converters/dataframe_to_mds.py @@ -192,11 +192,12 @@ def write_mds(iterator: Iterable): raise RuntimeError(f'failed to write sample: {sample}') from ex count += 1 - yield pd.concat( - [pd.Series([output[0]], name='mds_path_local'), - pd.Series([output[1]], name='mds_path_remote'), - pd.Series([count], name='fail_count')], - axis=1) + yield pd.concat([ + pd.Series([output[0]], name='mds_path_local'), + pd.Series([output[1]], name='mds_path_remote'), + pd.Series([count], name='fail_count') + ], + axis=1) if dataframe is None or dataframe.isEmpty(): raise ValueError(f'Input dataframe is None or Empty!') @@ -234,7 +235,7 @@ def write_mds(iterator: Iterable): # Fix output format as mds_path: Tuple(local, remote) if cu.remote is None: - mds_path = (cu.local, "") + mds_path = (cu.local, '') else: mds_path = (cu.local, cu.remote) @@ -247,7 +248,7 @@ def write_mds(iterator: Iterable): partitions = dataframe.mapInPandas(func=write_mds, schema=result_schema).collect() if merge_index: - folder_urls = [ (row['mds_path_local'], row['mds_path_remote']) for row in partitions ] + folder_urls = [(row['mds_path_local'], row['mds_path_remote']) for row in partitions] do_merge_index(folder_urls, out, keep_local=keep_local, download_timeout=60) if cu.remote is not None: diff --git a/tests/base/converters/test_dataframe_to_mds.py b/tests/base/converters/test_dataframe_to_mds.py index ad18c02a8..80b3eae98 100644 --- a/tests/base/converters/test_dataframe_to_mds.py +++ b/tests/base/converters/test_dataframe_to_mds.py @@ -17,7 +17,7 @@ from streaming.base.converters import dataframeToMDS MY_PREFIX = 'train' -MY_BUCKET= {'gs://': 'mosaicml-composer-tests', 's3://': 'streaming-upload-test-bucket'} +MY_BUCKET = {'gs://': 'mosaicml-composer-tests', 's3://': 'streaming-upload-test-bucket'} MANUAL_INTEGRATION_TEST = False os.environ[ 'OBJC_DISABLE_INITIALIZE_FORK_SAFETY'] = 'YES' # set to yes to all fork process in spark calls @@ -28,7 +28,8 @@ def manual_integration_dir() -> Any: """Creates a temporary directory and then deletes it when the calling function is done.""" if MANUAL_INTEGRATION_TEST: #os.environ['GOOGLE_APPLICATION_CREDENTIALS'] = 'path/to/gooogle_api_credential.json' - os.environ['GOOGLE_APPLICATION_CREDENTIALS'] = '/Users/xiaohan.zhang/.mosaic/mosaicml-research-nonprod-027345ddbdfd.json' + os.environ[ + 
'GOOGLE_APPLICATION_CREDENTIALS'] = '/Users/xiaohan.zhang/.mosaic/mosaicml-research-nonprod-027345ddbdfd.json' tmp_dir = mkdtemp() From a5a4ffc748c7c9b384401347c8b3d72097b19b41 Mon Sep 17 00:00:00 2001 From: Xiaohan Zhang Date: Tue, 3 Oct 2023 12:11:55 -0700 Subject: [PATCH 24/59] use new list_objects --- streaming/base/converters/dataframe_to_mds.py | 7 +- streaming/base/storage/download.py | 243 +----------------- streaming/base/storage/upload.py | 117 ++++++++- streaming/base/util.py | 94 ++----- .../base/converters/test_dataframe_to_mds.py | 55 ++-- tests/test_util.py | 139 ++++------ 6 files changed, 227 insertions(+), 428 deletions(-) diff --git a/streaming/base/converters/dataframe_to_mds.py b/streaming/base/converters/dataframe_to_mds.py index e00bb17fc..73391f70d 100644 --- a/streaming/base/converters/dataframe_to_mds.py +++ b/streaming/base/converters/dataframe_to_mds.py @@ -25,6 +25,7 @@ raise e from streaming import MDSWriter +from streaming.base.format.index import get_index_basename from streaming.base.format.mds.encodings import _encodings from streaming.base.storage.upload import CloudUploader @@ -193,8 +194,9 @@ def write_mds(iterator: Iterable): count += 1 yield pd.concat([ - pd.Series([output[0]], name='mds_path_local'), - pd.Series([output[1]], name='mds_path_remote'), + pd.Series([os.path.join(output[0], get_index_basename())], name='mds_path_local'), + pd.Series([os.path.join(output[1], get_index_basename()) if output[1] != '' else ''], + name='mds_path_remote'), pd.Series([count], name='fail_count') ], axis=1) @@ -231,7 +233,6 @@ def write_mds(iterator: Iterable): out = mds_kwargs['out'] keep_local = False if 'keep_local' not in mds_kwargs else mds_kwargs['keep_local'] cu = CloudUploader.get(out, keep_local=keep_local) - print('cu.local = ', cu.local) # Fix output format as mds_path: Tuple(local, remote) if cu.remote is None: diff --git a/streaming/base/storage/download.py b/streaming/base/storage/download.py index 74c895ea3..9db4af328 100644 --- a/streaming/base/storage/download.py +++ b/streaming/base/storage/download.py @@ -8,7 +8,7 @@ import shutil import urllib.parse from time import sleep, time -from typing import Any, Dict, List, Optional +from typing import Any, Dict, Optional from streaming.base.util import get_import_exception_message @@ -489,244 +489,3 @@ def wait_for_download(local: str, timeout: float = 60) -> None: raise TimeoutError( f'Waited longer than {timeout}s for other worker to download {local}.') sleep(0.25) - - -def remove_prefix(obj: str): - """Remove prefix from ob. - - Args: - obj (str): take form of 'path/to/folder' - return: - (str): take form of 'to/folder' - """ - return obj # '/'.join(obj.strip('/').split('/')[1:]) - - -def list_objects_from_s3(remote: str, timeout: float = 60) -> Optional[List[str]]: - """List objects from remote AWS S3. - - Args: - remote (str): Remote path (S3). - timeout (float): How long to wait for objects to be returned. - """ - import boto3 - from botocore import UNSIGNED - from botocore.config import Config - from botocore.exceptions import ClientError, NoCredentialsError - - def _list_objects(obj: urllib.parse.ParseResult, unsigned: bool = False) -> List[str]: - """List the objects from AWS S3 bucket. The bucket can be either public or private. - - Args: - obj (ParseResult): ParseResult object of remote. - unsigned (bool, optional): Set to True if it is a public bucket. - Defaults to ``False``. 
- """ - if unsigned: - # Client will be using unsigned mode in which public - # resources can be accessed without credentials - config = Config(read_timeout=timeout, signature_version=UNSIGNED) - else: - config = Config(read_timeout=timeout) - - # Create a new session per thread - session = boto3.session.Session() - # Create a resource client using a thread's session object - s3 = session.client('s3', config=config, endpoint_url=os.environ.get('S3_ENDPOINT_URL')) - # Threads calling S3 operations return RuntimeError (cannot schedule new futures after - # interpreter shutdown). Temporary solution is to have `use_threads` as `False`. - # Issue: https://github.com/boto/boto3/issues/3113 - paginator = s3.get_paginator('list_objects_v2') - pages = paginator.paginate(Bucket=obj.netloc, Prefix=obj.path) - ans = [] - for page in pages: - if 'Contents' in page: - for o in page['Contents']: - ans.append(remove_prefix(o['Key'])) - return ans - - obj = urllib.parse.urlparse(remote) - if obj.scheme != 's3': - raise ValueError( - f'Expected obj.scheme to be `s3`, instead, got {obj.scheme} for remote={remote}') - - try: - return _list_objects(obj) - except NoCredentialsError: - # Public S3 buckets without credentials - return _list_objects(obj, unsigned=True) - except ClientError as e: - if e.response['Error']['Code'] in BOTOCORE_CLIENT_ERROR_CODES: - e.args = (f'Object {remote} not found! Either check the bucket path or the bucket ' + - f'permission. If the bucket is a requester pays bucket, then provide the ' + - f'bucket name to the environment variable ' + - f'`MOSAICML_STREAMING_AWS_REQUESTER_PAYS`.',) - raise e - elif e.response['Error']['Code'] == '400': - # Public S3 buckets without credentials - return _list_objects(obj, unsigned=True) - except Exception: - raise - - -def list_objects_from_gcs(remote: str, timeout: float = 60) -> Optional[List[str]]: - """List objects from remote Google Cloud Bucket. - - Args: - remote (str): Remote path (S3). - timeout (float): How long to wait for objects to be returned. - """ - from google.auth.exceptions import DefaultCredentialsError - - def _gcs_with_hmac(obj: urllib.parse.ParseResult) -> Optional[List[str]]: - """Return a list of objects from remote GCS using user level credentials. - - Args: - obj (ParseResult): ParseResult object of remote. - """ - import boto3 - from botocore.exceptions import ClientError - - # Create a new session per thread - session = boto3.session.Session() - # Create a resource client using a thread's session object - gcs_client = session.client('s3', - region_name='auto', - endpoint_url='https://storage.googleapis.com', - aws_access_key_id=os.environ['GCS_KEY'], - aws_secret_access_key=os.environ['GCS_SECRET']) - try: - response = gcs_client.list_objects_v2(Bucket=obj.netloc, Prefix=obj.path.lstrip('/')) - if response and 'Contents' in response: - return [remove_prefix(ob['Key']) for ob in response['Contents']] - - except ClientError as e: - if e.response['Error']['Code'] in BOTOCORE_CLIENT_ERROR_CODES: - raise FileNotFoundError( - f'Object {obj.scheme}, {obj.netloc}, {obj.path} not found.') from e - except Exception: - raise - - def _gcs_with_service_account(obj: urllib.parse.ParseResult) -> Optional[List[str]]: - """Return a list of objects from remote GCS using service account credentials. - - Args: - obj (ParseResult): ParseResult object of remote path (GCS). 
- """ - from google.auth import default as default_auth - from google.cloud.storage import Client - - credentials, _ = default_auth() - gcs_client = Client(credentials=credentials) - bucket = gcs_client.get_bucket(obj.netloc, timeout=60.0) - objects = bucket.list_blobs(prefix=obj.path.lstrip('/')) - ans = [] - for ob in objects: - ans.append(remove_prefix(ob.name)) - return ans - - obj = urllib.parse.urlparse(remote) - if obj.scheme != 'gs': - raise ValueError( - f'Expected obj.scheme to be `gs`, instead, got {obj.scheme} for remote={remote}') - - if 'GCS_KEY' in os.environ and 'GCS_SECRET' in os.environ: - try: - return _gcs_with_hmac(obj) - except (DefaultCredentialsError, EnvironmentError): - raise ValueError(GCS_ERROR_NO_AUTHENTICATION) - else: - try: - return _gcs_with_service_account(obj) - except (DefaultCredentialsError, EnvironmentError): - raise ValueError(GCS_ERROR_NO_AUTHENTICATION) - - -def list_objects_from_oci(remote: str) -> Optional[List[str]]: - """List objects from remote OCI to local. - - Args: - remote (str): Remote path (OCI). - """ - import oci - config = oci.config.from_file() - client = oci.object_storage.ObjectStorageClient( - config=config, retry_strategy=oci.retry.DEFAULT_RETRY_STRATEGY) - - obj = urllib.parse.urlparse(remote) - if obj.scheme != 'oci': - raise ValueError( - f'Expected obj.scheme to be `oci`, instead, got {obj.scheme} for remote={remote}') - - object_names = [''] - next_start_with = None - response_complete = False - namespace = client.get_namespace().data - while not response_complete: - response = client.list_objects(namespace_name=namespace, - bucket_name=obj.netloc.split('@' + namespace)[0], - prefix=obj.path.strip('/'), - start=next_start_with).data - object_names.extend([remove_prefix(obj.name) for obj in response.objects]) - next_start_with = response.next_start_with - if not next_start_with: - response_complete = True - - return object_names - - -def list_objects_from_local(path: Optional[str]) -> List[str]: - """List objects from a local directory. - - Args: - path (str): absolute path or None. - - Notes: - List current directory if path is None. - Raise error if path is a file - """ - if not path: - return os.listdir() - return os.listdir(path) - - -def list_objects(remote: Optional[str]) -> List[str]: - """Use the correct cloud handler to list objects. - - Args: - remote (str, optional): Remote path (local filesystem). 
- If remote is None or '', list current working directory with os.listdir() - """ - if not remote: # '' or None - return list_objects_from_local(remote) - - # fix paths for windows - if remote: - remote = remote.replace('\\', '/') - - obj = urllib.parse.urlparse(remote) - - if obj.scheme == '': - return list_objects_from_local(remote) - elif obj.scheme == 's3': - ans = list_objects_from_s3(remote) - elif obj.scheme == 'gs': - ans = list_objects_from_gcs(remote) - elif obj.scheme == 'oci': - ans = list_objects_from_oci(remote) - else: - raise NotImplementedError - - if not ans: - return [''] - level_one_list = [] - for o in ans: - print('I am here 5', o) - suffix = o[len(obj.path):] - print('I am here 5.1', suffix, obj.path) - if '/' in suffix.strip('/'): - level_one_list.append(os.path.dirname(suffix)) - print('I am here 5.2', os.path.dirname(suffix)) - else: - level_one_list.append(suffix) - return list(set(level_one_list)) diff --git a/streaming/base/storage/upload.py b/streaming/base/storage/upload.py index a506377fa..b98dfaeef 100644 --- a/streaming/base/storage/upload.py +++ b/streaming/base/storage/upload.py @@ -11,7 +11,7 @@ import urllib.parse from enum import Enum from tempfile import mkdtemp -from typing import Any, Tuple, Union +from typing import Any, List, Optional, Tuple, Union import tqdm @@ -179,6 +179,17 @@ def upload_file(self, filename: str): """ raise NotImplementedError('Override this method in your sub-class') + def list_objects(self, prefix: Optional[str] = None) -> List[str]: + """List all objects in the object store with the given prefix. + + Args: + prefix (Optional[str], optional): The prefix to search for. Defaults to None. + + Returns: + List[str]: A list of object names that match the prefix. + """ + raise NotImplementedError(f'{type(self).__name__}.list_objects is not implemented') + def clear_local(self, local: str): """Remove the local file if it is enabled. @@ -278,6 +289,29 @@ def check_bucket_exists(self, remote: str): f'or check the bucket permission.',) raise error + def list_objects(self, prefix: Optional[str] = None) -> List[str]: + """List all objects in the S3 object store with the given prefix. + + Args: + prefix (Optional[str], optional): The prefix to search for. Defaults to None. + + Returns: + List[str]: A list of object names that match the prefix. + """ + if prefix is None: + prefix = '' + + obj = urllib.parse.urlparse(self.remote) + bucket_name = obj.netloc + prefix = os.path.join(obj.path.lstrip('/'), prefix) + + paginator = self.s3.get_paginator('list_objects_v2') + pages = paginator.paginate(Bucket=bucket_name, Prefix=prefix) + try: + return [obj['Key'] for page in pages for obj in page['Contents']] + except KeyError: + return [] + class GCSUploader(CloudUploader): """Upload file from local machine to Google Cloud Storage bucket. @@ -396,6 +430,35 @@ def check_bucket_exists(self, remote: str) -> None: elif self.authentication == GCSAuthentication.SERVICE_ACCOUNT: self.gcs_client.get_bucket(bucket_name) + def list_objects(self, prefix: Optional[str] = None) -> List[str]: + """List all objects in the GCS object store with the given prefix. + + Args: + prefix (Optional[str], optional): The prefix to search for. Defaults to None. + + Returns: + List[str]: A list of object names that match the prefix. 
+ """ + if prefix is None: + prefix = '' + + if self.authentication == GCSAuthentication.HMAC: + obj = urllib.parse.urlparse(self.remote) + bucket_name = obj.netloc + prefix = os.path.join(obj.path.lstrip('/'), prefix) + + paginator = self.s3.get_paginator('list_objects_v2') + pages = paginator.paginate(Bucket=bucket_name, Prefix=prefix) + try: + return [obj['Key'] for page in pages for obj in page['Contents']] + except KeyError: + return [] + elif self.authentication == GCSAuthentication.SERVICE_ACCOUNT: + from google.cloud.storage import Blob, Bucket + + blob = Blob(obj.path.lstrip('/'), Bucket(self.gcs_client, obj.netloc)) + return [blob.name for blob in blob.bucket.list_blobs(prefix=prefix)] + class OCIUploader(CloudUploader): """Upload file from local machine to Oracle Cloud Infrastructure (OCI) Cloud Storage. @@ -488,6 +551,40 @@ def check_bucket_exists(self, remote: str): f'Check the bucket permission or create the bucket.',) raise error + def list_objects(self, prefix: Optional[str] = None) -> List[str]: + """List all objects in the OCI object store with the given prefix. + + Args: + prefix (Optional[str], optional): The prefix to search for. Defaults to None. + + Returns: + List[str]: A list of object names that match the prefix. + """ + if prefix is None: + prefix = '' + + obj = urllib.parse.urlparse(self.remote) + bucket_name = obj.netloc.split('@' + self.namespace)[0] + prefix = os.path.join(obj.path.strip('/'), prefix) + + object_names = [] + next_start_with = None + response_complete = False + try: + while not response_complete: + response = self.client.list_objects(namespace_name=self.namespace, + bucket_name=bucket_name, + prefix=prefix, + start=next_start_with).data + object_names.extend([resp_obj.name for resp_obj in response.objects]) + next_start_with = response.next_start_with + if not next_start_with: + response_complete = True + except Exception as e: + _reraise_oci_errors(self.get_uri(prefix), e) + + return object_names + class AzureUploader(CloudUploader): """Upload file from local machine to Microsoft Azure bucket. @@ -863,3 +960,21 @@ def _upload_file(): self.clear_local(local=local_filename) _upload_file() + + def list_objects(self, prefix: Optional[str] = None) -> List[str]: + """List all objects locally with the given prefix. + + Args: + prefix (Optional[str], optional): The prefix to search for. Defaults to None. + + Returns: + List[str]: A list of object names that match the prefix. + """ + if prefix is None: + prefix = '.' + + ans = [] + for dirpath, _, files in os.walk(prefix): + for file in files: + ans.append(os.path.join(dirpath, file)) + return ans diff --git a/streaming/base/util.py b/streaming/base/util.py index cde5df4a0..c60433ab4 100644 --- a/streaming/base/util.py +++ b/streaming/base/util.py @@ -13,6 +13,7 @@ import tempfile import urllib.parse from multiprocessing.shared_memory import SharedMemory as BuiltinSharedMemory +from pathlib import Path from time import sleep, time from typing import Any, Callable, List, Sequence, Tuple, Type, TypeVar, Union, cast, overload @@ -214,17 +215,17 @@ def get_import_exception_message(package_name: str, extra_deps: str) -> str: f'`pip install \'mosaicml-streaming[{package_name}]\'`.' -def do_merge_index(folder_urls: List[Any], +def do_merge_index(index_file_urls: List[Any], out: Union[str, Tuple[str, str]], keep_local: bool = True, download_timeout: int = 60) -> None: - """Merge index.json from a list of directories. Write to `out`, overwriting if exists. + """Merge index.json from a list of index.json. 
Write to `out`, overwriting if exists. Args: - folder_urls (Union[str, Tuple[str,str]]): folders that contain index.json for the partition + index_file_urls (Union[str, Tuple[str,str]]): index.json from all the partitions each element can take the form of a single path string or a tuple string. - The pattern of folder_urls and corresponding reaction is one of: + The pattern of index_file_urls and corresponding reaction is one of: 1. All urls are str (local). All urls are accessible locally -> no download 2. All urls are tuple (local, remote). All urls are accessible locally -> no download 3. All urls are tuple (local, remote). At least one url is not accessible locally -> download all @@ -237,19 +238,19 @@ def do_merge_index(folder_urls: List[Any], from streaming.base.storage.download import download_file from streaming.base.storage.upload import CloudUploader - if not folder_urls or not out: - logger.warning('Need to specify both folder_urls and out. No index merged') + if not index_file_urls or not out: + logger.warning('Need to specify both index_file_urls and out. No index merged') return - print('folder_urls = ', folder_urls) # This is the index json file name, e.g., it is index.json as of 0.6.0 index_basename = get_index_basename() cu = CloudUploader.get(out, keep_local=True, exist_ok=True) - # Remove '/' from right, so os.path.basename gives relative path to each folder + # Remove duplicates, and strip '/' from right if any + index_file_urls = list(set(index_file_urls)) urls = [] - for url in folder_urls: + for url in index_file_urls: if isinstance(url, str): urls.append(url.rstrip('/').strip()) else: @@ -260,7 +261,7 @@ def do_merge_index(folder_urls: List[Any], for url in urls: if isinstance(url, tuple): # If driver cannot access the local path, download = True - download = not os.path.exists(os.path.join(url[0], index_basename)) + download = not os.path.exists(url[0]) else: # If url is a remote, download = True, False otherwise download = urllib.parse.urlparse(url).scheme != '' @@ -285,22 +286,19 @@ def do_merge_index(folder_urls: List[Any], # If download is needed, download url from remote to temp_root path = urllib.parse.urlparse(remote).path local = os.path.join(temp_root, path.lstrip('/')) - remote_url = os.path.join(remote, index_basename) - local_path = os.path.join(local, index_basename) try: - download_file(remote_url, local_path, download_timeout) + download_file(remote, local, download_timeout) except Exception as ex: - raise RuntimeError(f'Failed to download index.json: {remote_url}') from ex + raise RuntimeError(f'Failed to download index.json: {remote}') from ex if not (os.path.exists(local)): - raise FileNotFoundError(f'Folder {local} does not exist or not accessible.') + raise FileNotFoundError(f'Index file {local} does not exist or not accessible.') partitions.append(local) # merge index files into shards shards = [] - for partition in partitions: - partition_index = f'{partition}/{index_basename}' - mds_partition_basename = os.path.basename(partition) + for partition_index in partitions: + p = Path(partition_index) obj = json.load(open(partition_index)) for i in range(len(obj['shards'])): shard = obj['shards'][i] @@ -308,7 +306,7 @@ def do_merge_index(folder_urls: List[Any], if shard.get(key): basename = shard[key]['basename'] obj['shards'][i][key]['basename'] = os.path.join( - mds_partition_basename, basename) + os.path.basename(p.parent), basename) shards += obj['shards'] # Save merged index locally @@ -331,64 +329,6 @@ def do_merge_index(folder_urls: 
List[Any], shutil.rmtree(cu.local, ignore_errors=True) -def merge_index(root_to_mds: Union[str, Tuple[str, str]], - *, - keep_local: bool = True, - overwrite: bool = True) -> None: - """Merge index.json given the root of MDS dataset. Write merged index to the root folder. - - Args: - root_to_mds (Union[str, Tuple[str,str]]): folders that contain MDS partitions. - It can be local str or remote str or (local, remote) - keep_local (bool): Keep local copy of the merged index file. Defaults to ``True`` - overwrite (bool): Overwrite merged index file in out if there exists one. Defaults to ``True`` - download_timeout (int): The allowed time for downloading each json file. Defaults to 60s. - """ - from streaming.base.storage.download import list_objects - from streaming.base.storage.upload import CloudUploader - - if not root_to_mds: - logger.warning('No MDS dataset folder specified, no index merged') - return - - cu = CloudUploader.get(root_to_mds, exist_ok=True, keep_local=True) - if isinstance(root_to_mds, tuple): - local_folders = [ - os.path.join(cu.local, o) - for o in list_objects(root_to_mds[0]) - if not o.endswith('.json') - ] - remote_folders = [ - os.path.join(cu.remote, o) - for o in list_objects(root_to_mds[1]) - if not o.endswith('.json') - ] - folder_urls = list(zip(local_folders, remote_folders)) - else: - print('I am here 3.1', root_to_mds) - print('I am here 3', list_objects(root_to_mds)) - if cu.remote: - folder_urls = [ - os.path.join(cu.remote, o) - for o in list_objects(root_to_mds) - if not o.endswith('.json') - ] - else: - folder_urls = [ - os.path.join(cu.local, o) - for o in list_objects(root_to_mds) - if not o.endswith('.json') - ] - - print('I am here 3.2') - for fu in folder_urls: - print(list_objects(fu)) - - do_merge_index(folder_urls, root_to_mds, keep_local=keep_local, download_timeout=60) - - return - - @overload def retry( exc_class: Union[Type[Exception], Sequence[Type[Exception]]] = ..., diff --git a/tests/base/converters/test_dataframe_to_mds.py b/tests/base/converters/test_dataframe_to_mds.py index 80b3eae98..fa766a8d3 100644 --- a/tests/base/converters/test_dataframe_to_mds.py +++ b/tests/base/converters/test_dataframe_to_mds.py @@ -17,7 +17,10 @@ from streaming.base.converters import dataframeToMDS MY_PREFIX = 'train' -MY_BUCKET = {'gs://': 'mosaicml-composer-tests', 's3://': 'streaming-upload-test-bucket'} +MY_BUCKET = { + 'gs://': 'mosaicml-composer-tests', + 's3://': 'mosaicml-internal-temporary-composer-testing' +} MANUAL_INTEGRATION_TEST = False os.environ[ 'OBJC_DISABLE_INITIALIZE_FORK_SAFETY'] = 'YES' # set to yes to all fork process in spark calls @@ -27,13 +30,19 @@ def manual_integration_dir() -> Any: """Creates a temporary directory and then deletes it when the calling function is done.""" if MANUAL_INTEGRATION_TEST: - #os.environ['GOOGLE_APPLICATION_CREDENTIALS'] = 'path/to/gooogle_api_credential.json' + os.environ['GOOGLE_APPLICATION_CREDENTIALS'] = 'path/to/gooogle_api_credential.json' os.environ[ 'GOOGLE_APPLICATION_CREDENTIALS'] = '/Users/xiaohan.zhang/.mosaic/mosaicml-research-nonprod-027345ddbdfd.json' + os.environ.pop('AWS_ACCESS_KEY_ID', None) + os.environ.pop('AWS_SECRET_ACCESS_KEY', None) + os.environ.pop('AWS_SECURITY_TOKEN', None) + os.environ.pop('AWS_SESSION_TOKEN', None) + os.environ['AWS_PROFILE'] = 'temporary' + tmp_dir = mkdtemp() - def _method(cloud_prefix: str = 's3://') -> Tuple[str, str]: + def _method(cloud_prefix: str = 'gs://') -> Tuple[str, str]: mock_local_dir = tmp_dir # mkdtemp() mock_remote_dir = 
os.path.join(cloud_prefix, MY_BUCKET[cloud_prefix], MY_PREFIX) return mock_local_dir, mock_remote_dir @@ -42,16 +51,16 @@ def _method(cloud_prefix: str = 's3://') -> Tuple[str, str]: yield _method finally: shutil.rmtree(tmp_dir, ignore_errors=True) # pyright: ignore - #if MANUAL_INTEGRATION_TEST: - # try: - # from google.cloud.storage import Client - # storage_client = Client() - # bucket = storage_client.get_bucket(MY_BUCKET) - # blobs = bucket.list_blobs(prefix=MY_PREFIX) - # for blob in blobs: - # blob.delete() - # except ImportError: - # raise ImportError('google.cloud.storage is not imported correctly.') + if MANUAL_INTEGRATION_TEST: + try: + from google.cloud.storage import Client + storage_client = Client() + bucket = storage_client.get_bucket(MY_BUCKET['gs://']) + blobs = bucket.list_blobs(prefix=MY_PREFIX) + for blob in blobs: + blob.delete() + except ImportError: + raise ImportError('google.cloud.storage is not imported correctly.') class TestDataFrameToMDS: @@ -220,9 +229,9 @@ def test_end_to_end_conversion_local(self, dataframe: Any, keep_local: bool, mer assert not (os.path.exists(os.path.join( out, 'index.json'))), 'merged index is created when merge_index=False' - @pytest.mark.parametrize('scheme', ['gs']) - @pytest.mark.parametrize('keep_local', [True, False]) - @pytest.mark.parametrize('merge_index', [True, False]) + @pytest.mark.parametrize('scheme', ['gs://', 's3://']) + @pytest.mark.parametrize('keep_local', [True]) # , False]) + @pytest.mark.parametrize('merge_index', [True]) # , False]) @pytest.mark.usefixtures('manual_integration_dir') def test_patch_conversion_local_and_remote(self, dataframe: Any, scheme: str, merge_index: bool, keep_local: bool, @@ -231,7 +240,7 @@ def test_patch_conversion_local_and_remote(self, dataframe: Any, scheme: str, pytest.skip( 'Overlap with integration tests. But better figure out how to run this test ' + 'suite with Mock.') - mock_local, mock_remote = manual_integration_dir() + mock_local, mock_remote = manual_integration_dir(scheme) out = (mock_local, mock_remote) mds_kwargs = { 'out': out, @@ -269,15 +278,17 @@ def test_patch_conversion_local_and_remote(self, dataframe: Any, scheme: str, assert not (os.path.exists(os.path.join( mds_path[0], 'index.json'))), 'merged index is created when merge_index=False' + @pytest.mark.parametrize('scheme', ['gs://', 's3://']) @pytest.mark.parametrize('keep_local', [True, False]) @pytest.mark.parametrize('merge_index', [True, False]) @pytest.mark.usefixtures('manual_integration_dir') def test_integration_conversion_local_and_remote(self, dataframe: Any, manual_integration_dir: Any, - merge_index: bool, keep_local: bool): + merge_index: bool, keep_local: bool, + scheme: str): if not MANUAL_INTEGRATION_TEST: pytest.skip('run local only. CI cluster does not have GCS service acct set up.') - out = manual_integration_dir() + out = manual_integration_dir(scheme) mds_kwargs = { 'out': out, 'columns': { @@ -312,11 +323,13 @@ def test_integration_conversion_local_and_remote(self, dataframe: Any, f'merged index is created at {mds_path[0]} when merge_index={merge_index} and ' + f'keep_local={keep_local}') + @pytest.mark.parametrize('scheme', ['gs://', 's3://']) @pytest.mark.usefixtures('manual_integration_dir') - def test_integration_conversion_remote_only(self, dataframe: Any, manual_integration_dir: Any): + def test_integration_conversion_remote_only(self, dataframe: Any, manual_integration_dir: Any, + scheme: str): if not MANUAL_INTEGRATION_TEST: pytest.skip('run local only. 
CI cluster does not have GCS service acct set up.') - _, remote = manual_integration_dir() + _, remote = manual_integration_dir('s3://') mds_kwargs = { 'out': remote, 'columns': { diff --git a/tests/test_util.py b/tests/test_util.py index 2f280f96d..347b3c80d 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -12,13 +12,16 @@ from streaming.base.constant import RESUME from streaming.base.shared.prefix import _get_path -from streaming.base.storage.download import download_file, list_objects +from streaming.base.storage.download import download_file from streaming.base.storage.upload import CloudUploader from streaming.base.util import (bytes_to_int, clean_stale_shared_memory, do_merge_index, - get_list_arg, merge_index, number_abbrev_to_int, retry) + get_list_arg, number_abbrev_to_int, retry) MY_PREFIX = 'train' -MY_BUCKET = 'mosaicml-composer-tests' +MY_BUCKET = { + 'gs://': 'mosaicml-composer-tests', + 's3://': 'mosaicml-internal-temporary-composer-testing' +} MANUAL_INTEGRATION_TEST = False os.environ[ 'OBJC_DISABLE_INITIALIZE_FORK_SAFETY'] = 'YES' # set to yes to all fork process in spark calls @@ -32,11 +35,17 @@ def manual_integration_dir() -> Any: os.environ[ 'GOOGLE_APPLICATION_CREDENTIALS'] = '/Users/xiaohan.zhang/.mosaic/mosaicml-research-nonprod-027345ddbdfd.json' + os.environ.pop('AWS_ACCESS_KEY_ID', None) + os.environ.pop('AWS_SECRET_ACCESS_KEY', None) + os.environ.pop('AWS_SECURITY_TOKEN', None) + os.environ.pop('AWS_SESSION_TOKEN', None) + os.environ['AWS_PROFILE'] = 'temporary' + tmp_dir = tempfile.mkdtemp() def _method(cloud_prefix: str = 'gs://') -> Tuple[str, str]: mock_local_dir = tmp_dir - mock_remote_dir = os.path.join(cloud_prefix, MY_BUCKET, MY_PREFIX) + mock_remote_dir = os.path.join(cloud_prefix, MY_BUCKET[cloud_prefix], MY_PREFIX) return mock_local_dir, mock_remote_dir try: @@ -47,7 +56,7 @@ def _method(cloud_prefix: str = 'gs://') -> Tuple[str, str]: try: from google.cloud.storage import Client storage_client = Client() - bucket = storage_client.get_bucket(MY_BUCKET) + bucket = storage_client.get_bucket(MY_BUCKET['gs://']) blobs = bucket.list_blobs(prefix=MY_PREFIX) for blob in blobs: blob.delete() @@ -164,14 +173,10 @@ def integrity_check(out: Union[str, Tuple[str, str]], def get_expected(mds_root: str): n_shard_files = 0 - for o in list_objects(mds_root): - print('I am here 4.1: ', o) - if o.endswith('.json'): - continue - for b in list_objects(os.path.join(mds_root, o)): - print(f'I am here 4.2 for {os.path.join(mds_root, o)}: ', b) - if b.endswith('.mds'): - n_shard_files += 1 + cu = CloudUploader.get(mds_root, exist_ok=True, keep_local=True) + for o in cu.list_objects(): + if o.endswith('.mds'): + n_shard_files += 1 return n_shard_files cu = CloudUploader.get(out, keep_local=True, exist_ok=True) @@ -200,56 +205,14 @@ def get_expected(mds_root: str): ), f'expected {expected_n_shard_files} shard files but got {n_shard_files}' -@pytest.mark.parametrize('output_format', ['local', 'remote']) -@pytest.mark.parametrize('n_partitions', [1, 2, 3, 4]) -def test_merge_index(manual_integration_dir: Any, output_format: str, n_partitions: int): - from decimal import Decimal - - from pyspark.sql import SparkSession - from pyspark.sql.types import DecimalType, IntegerType, StringType, StructField, StructType - - from streaming.base.converters import dataframeToMDS - - if output_format == 'remote': - if not MANUAL_INTEGRATION_TEST: - pytest.skip('Require cloud credentials. ' + - 'skipping. 
Set MANUAL_INTEGRATION_TEST=True to run the check manually!') - _, out = manual_integration_dir() - else: - out, _ = manual_integration_dir() - - spark = SparkSession.builder.getOrCreate() # pyright: ignore - schema = StructType([ - StructField('id', IntegerType(), nullable=False), - StructField('name', StringType(), nullable=False), - StructField('amount', DecimalType(10, 2), nullable=False) - ]) - - data = [(1, 'Alice', Decimal('123.45')), (2, 'Bob', Decimal('67.89')), - (3, 'Charlie', Decimal('987.65'))] - - df = spark.createDataFrame(data=data, schema=schema).repartition(n_partitions) - - mds_kwargs = { - 'out': out, - 'columns': { - 'id': 'int', - 'name': 'str' - }, - } - - dataframeToMDS(df, merge_index=False, mds_kwargs=mds_kwargs) - merge_index(out) - integrity_check(out, keep_local=True) - - -@pytest.mark.parametrize('folder_urls_pattern', [1, 2, 3, 4, 5]) +@pytest.mark.parametrize('scheme', ['gs://', 's3://']) +@pytest.mark.parametrize('index_file_urls_pattern', [1, 2, 3, 4, 5]) @pytest.mark.parametrize('output_format', ['local', 'remote', 'tuple']) @pytest.mark.usefixtures('manual_integration_dir') @pytest.mark.parametrize('keep_local', [True, False]) -def test_do_merge_index(manual_integration_dir: Any, keep_local: bool, folder_urls_pattern: int, - output_format: str): - """Validate the final merge index json for following patterns of folder_urls: +def test_do_merge_index(manual_integration_dir: Any, keep_local: bool, + index_file_urls_pattern: int, output_format: str, scheme: str): + """Validate the final merge index json for following patterns of index_file_urls: 1. All urls are str (local). All urls are accessible locally -> no download 2. All urls are str (local). At least one url is unaccessible locally -> Error 3. All urls are tuple (local, remote). All urls are accessible locally -> no download @@ -262,65 +225,73 @@ def test_do_merge_index(manual_integration_dir: Any, keep_local: bool, folder_ur pytest.skip('Require cloud credentials. ' + 'skipping. 
Set MANUAL_INTEGRATION_TEST=True to run the check manually!') if output_format == 'remote': - out = manual_integration_dir()[1] + out = manual_integration_dir(scheme)[1] else: - out = manual_integration_dir() + out = manual_integration_dir(scheme) else: - out = manual_integration_dir()[0] + out = manual_integration_dir(scheme)[0] naive_mds_partitions = [ - 'tests/resources/naive_MDSdataset/25/', 'tests/resources/naive_MDSdataset/26/', + 'tests/resources/naive_MDSdataset/25/', 'tests/resources/naive_MDSdataset/26', 'tests/resources/naive_MDSdataset/27/' ] - if folder_urls_pattern == 1: - folder_urls = [os.getcwd() + '/' + s for s in naive_mds_partitions] - do_merge_index(folder_urls, out, keep_local=keep_local) + if index_file_urls_pattern == 1: + index_file_urls = [ + os.path.join(os.getcwd(), s, 'index.json') for s in naive_mds_partitions + ] + do_merge_index(index_file_urls, out, keep_local=keep_local) - if folder_urls_pattern == 2: + if index_file_urls_pattern == 2: with tempfile.TemporaryDirectory() as a_temporary_folder: - folder_urls = [a_temporary_folder + '/' + s for s in naive_mds_partitions] + index_file_urls = [ + os.path.join(a_temporary_folder, s, 'index.json') for s in naive_mds_partitions + ] with pytest.raises(FileNotFoundError, match=f'.* does not exist or not accessible.*'): - do_merge_index(folder_urls, out, keep_local=keep_local) + do_merge_index(index_file_urls, out, keep_local=keep_local) return - if folder_urls_pattern == 3: - folder_urls = [] + if index_file_urls_pattern == 3: + index_file_urls = [] for s in naive_mds_partitions: - folder_urls.append((os.getcwd() + '/' + s, 'gs://mybucket/' + s)) - do_merge_index(folder_urls, out, keep_local=keep_local) + index_file_urls.append((os.path.join(os.getcwd(), s, 'index.json'), + os.path.join(scheme, MY_BUCKET[scheme], s, 'index.json'))) + do_merge_index(index_file_urls, out, keep_local=keep_local) - if folder_urls_pattern == 4: + if index_file_urls_pattern == 4: if not MANUAL_INTEGRATION_TEST: pytest.skip('Require cloud credentials. ' + 'skipping. Set MANUAL_INTEGRATION_TEST=True to run the check manually!') with tempfile.TemporaryDirectory() as a_temporary_folder: - folder_urls = [] + index_file_urls = [] for s in naive_mds_partitions: - cu_path = (os.getcwd() + '/' + s, 'gs://' + MY_BUCKET + '/' + s) + cu_path = (os.path.join(os.getcwd(), + s), os.path.join(scheme, MY_BUCKET[scheme], s)) cu = CloudUploader.get(cu_path, keep_local=True, exist_ok=True) index_json = os.path.join(cu.local, 'index.json') if os.path.exists(index_json): cu.upload_file('index.json') - folder_urls.append((a_temporary_folder, 'gs://' + MY_BUCKET + '/' + s)) - do_merge_index(folder_urls, out, keep_local=keep_local) + index_file_urls.append((os.path.join(a_temporary_folder, s, 'index.json'), + os.path.join(scheme, MY_BUCKET[scheme], s, 'index.json'))) + do_merge_index(index_file_urls, out, keep_local=keep_local) - if folder_urls_pattern == 5: + if index_file_urls_pattern == 5: if not MANUAL_INTEGRATION_TEST: pytest.skip('Require cloud credentials. ' + 'skipping. 
Set MANUAL_INTEGRATION_TEST=True to run the check manually!') with tempfile.TemporaryDirectory() as a_temporary_folder: - folder_urls = [] + index_file_urls = [] for s in naive_mds_partitions: - cu_path = (os.getcwd() + '/' + s, 'gs://' + MY_BUCKET + '/' + s) + cu_path = (os.path.join(os.getcwd(), + s), os.path.join(scheme, MY_BUCKET[scheme], s)) cu = CloudUploader.get(cu_path, keep_local=True, exist_ok=True) index_json = os.path.join(cu.local, 'index.json') if os.path.exists(index_json): cu.upload_file('index.json') - folder_urls.append('gs://' + MY_BUCKET + '/' + s) - do_merge_index(folder_urls, out, keep_local=keep_local) + index_file_urls.append(os.path.join(scheme, MY_BUCKET[scheme], s, 'index.json')) + do_merge_index(index_file_urls, out, keep_local=keep_local) integrity_check(out, keep_local=keep_local, expected_n_shard_files=2) From 00e6a8c03f7a04b88f0eed4b52dc6516688b22af Mon Sep 17 00:00:00 2001 From: Xiaohan Zhang Date: Tue, 3 Oct 2023 12:23:23 -0700 Subject: [PATCH 25/59] Fix lints --- streaming/base/storage/upload.py | 28 +++--- tests/test_list_objects.py | 145 ++++++++----------------------- 2 files changed, 52 insertions(+), 121 deletions(-) diff --git a/streaming/base/storage/upload.py b/streaming/base/storage/upload.py index b98dfaeef..73cae6549 100644 --- a/streaming/base/storage/upload.py +++ b/streaming/base/storage/upload.py @@ -179,7 +179,7 @@ def upload_file(self, filename: str): """ raise NotImplementedError('Override this method in your sub-class') - def list_objects(self, prefix: Optional[str] = None) -> List[str]: + def list_objects(self, prefix: Optional[str] = None) -> Optional[List[str]]: """List all objects in the object store with the given prefix. Args: @@ -289,7 +289,7 @@ def check_bucket_exists(self, remote: str): f'or check the bucket permission.',) raise error - def list_objects(self, prefix: Optional[str] = None) -> List[str]: + def list_objects(self, prefix: Optional[str] = None) -> Optional[List[str]]: """List all objects in the S3 object store with the given prefix. Args: @@ -303,7 +303,7 @@ def list_objects(self, prefix: Optional[str] = None) -> List[str]: obj = urllib.parse.urlparse(self.remote) bucket_name = obj.netloc - prefix = os.path.join(obj.path.lstrip('/'), prefix) + prefix = os.path.join(str(obj.path).lstrip('/'), prefix) paginator = self.s3.get_paginator('list_objects_v2') pages = paginator.paginate(Bucket=bucket_name, Prefix=prefix) @@ -430,7 +430,7 @@ def check_bucket_exists(self, remote: str) -> None: elif self.authentication == GCSAuthentication.SERVICE_ACCOUNT: self.gcs_client.get_bucket(bucket_name) - def list_objects(self, prefix: Optional[str] = None) -> List[str]: + def list_objects(self, prefix: Optional[str] = None) -> Optional[List[str]]: """List all objects in the GCS object store with the given prefix. 
Args: @@ -442,12 +442,13 @@ def list_objects(self, prefix: Optional[str] = None) -> List[str]: if prefix is None: prefix = '' + obj = urllib.parse.urlparse(self.remote) + if self.authentication == GCSAuthentication.HMAC: - obj = urllib.parse.urlparse(self.remote) bucket_name = obj.netloc - prefix = os.path.join(obj.path.lstrip('/'), prefix) + prefix = os.path.join(str(obj.path).lstrip('/'), prefix) - paginator = self.s3.get_paginator('list_objects_v2') + paginator = self.gcs_client.get_paginator('list_objects_v2') pages = paginator.paginate(Bucket=bucket_name, Prefix=prefix) try: return [obj['Key'] for page in pages for obj in page['Contents']] @@ -456,7 +457,7 @@ def list_objects(self, prefix: Optional[str] = None) -> List[str]: elif self.authentication == GCSAuthentication.SERVICE_ACCOUNT: from google.cloud.storage import Blob, Bucket - blob = Blob(obj.path.lstrip('/'), Bucket(self.gcs_client, obj.netloc)) + blob = Blob(str(obj.path).lstrip('/'), Bucket(self.gcs_client, obj.netloc)) return [blob.name for blob in blob.bucket.list_blobs(prefix=prefix)] @@ -551,7 +552,7 @@ def check_bucket_exists(self, remote: str): f'Check the bucket permission or create the bucket.',) raise error - def list_objects(self, prefix: Optional[str] = None) -> List[str]: + def list_objects(self, prefix: Optional[str] = None) -> Optional[List[str]]: """List all objects in the OCI object store with the given prefix. Args: @@ -565,7 +566,7 @@ def list_objects(self, prefix: Optional[str] = None) -> List[str]: obj = urllib.parse.urlparse(self.remote) bucket_name = obj.netloc.split('@' + self.namespace)[0] - prefix = os.path.join(obj.path.strip('/'), prefix) + prefix = os.path.join(str(obj.path).strip('/'), prefix) object_names = [] next_start_with = None @@ -580,10 +581,9 @@ def list_objects(self, prefix: Optional[str] = None) -> List[str]: next_start_with = response.next_start_with if not next_start_with: response_complete = True - except Exception as e: - _reraise_oci_errors(self.get_uri(prefix), e) - - return object_names + return object_names + except Exception as _: + return [] class AzureUploader(CloudUploader): diff --git a/tests/test_list_objects.py b/tests/test_list_objects.py index 44a14a17f..f43208628 100644 --- a/tests/test_list_objects.py +++ b/tests/test_list_objects.py @@ -9,23 +9,22 @@ import boto3 import pytest -from streaming.base.storage.download import (list_objects, list_objects_from_gcs, - list_objects_from_local, list_objects_from_s3) -from tests.conftest import GCS_URL, MY_BUCKET, R2_URL +from streaming.base.storage.upload import CloudUploader +from tests.conftest import MY_BUCKET MY_PREFIX = 'train' @pytest.fixture(scope='function') -def remote_local_file() -> Any: +def remote_local_dir() -> Any: """Creates a temporary directory and then deletes it when the calling function is done.""" - def _method(cloud_prefix: str = '', filename: str = 'file.txt') -> Tuple[str, str]: + def _method(cloud_prefix: str = '') -> Tuple[str, str]: try: mock_local_dir = tempfile.TemporaryDirectory() - mock_local_filepath = os.path.join(mock_local_dir.name, filename) - mock_remote_filepath = os.path.join(cloud_prefix, MY_BUCKET, MY_PREFIX, filename) - return mock_remote_filepath, mock_local_filepath + mock_local = mock_local_dir.name + mock_remote = os.path.join(cloud_prefix, MY_BUCKET, MY_PREFIX) + return mock_remote, mock_local finally: mock_local_dir.cleanup() # pyright: ignore @@ -34,126 +33,58 @@ def _method(cloud_prefix: str = '', filename: str = 'file.txt') -> Tuple[str, st class TestS3Client: - 
@pytest.mark.usefixtures('s3_client', 's3_test', 'remote_local_file') - def test_list_objects_from_s3(self, remote_local_file: Any): + @pytest.mark.usefixtures('s3_client', 's3_test', 'remote_local_dir') + def test_list_objects_from_s3(self, remote_local_dir: Any): with tempfile.NamedTemporaryFile(delete=True, suffix='.txt') as tmp: file_name = tmp.name.split(os.sep)[-1] - mock_remote_filepath, _ = remote_local_file(cloud_prefix='s3://', filename=file_name) + mock_remote_dir, _ = remote_local_dir(cloud_prefix='s3://') client = boto3.client('s3', region_name='us-east-1') client.put_object(Bucket=MY_BUCKET, Key=os.path.join(MY_PREFIX, file_name), Body='') - objs = list_objects_from_s3(mock_remote_filepath) - assert isinstance(objs, list) - @pytest.mark.usefixtures('s3_client', 's3_test', 'r2_credentials', 'remote_local_file') - def test_list_objects_from_s3_with_endpoint_URL(self, remote_local_file: Any): - with tempfile.NamedTemporaryFile(delete=True, suffix='.txt') as tmp: - file_name = tmp.name.split(os.sep)[-1] - mock_remote_filepath, _ = remote_local_file(cloud_prefix='s3://', filename=file_name) - client = boto3.client('s3', region_name='us-east-1') - client.put_object(Bucket=MY_BUCKET, Key=os.path.join(MY_PREFIX, file_name), Body='') - _ = list_objects_from_s3(mock_remote_filepath) - assert os.environ['S3_ENDPOINT_URL'] == R2_URL + cu = CloudUploader.get(mock_remote_dir, exist_ok=True, keep_local=True) + objs = cu.list_objects(mock_remote_dir) + assert isinstance(objs, list) - @pytest.mark.usefixtures('s3_client', 's3_test', 'remote_local_file') - def test_clienterror_exception(self, remote_local_file: Any): - mock_remote_filepath, _ = remote_local_file(cloud_prefix='s3://') - objs = list_objects_from_s3(mock_remote_filepath) + @pytest.mark.usefixtures('s3_client', 's3_test', 'remote_local_dir') + def test_clienterror_exception(self, remote_local_dir: Any): + mock_remote_dir, _ = remote_local_dir(cloud_prefix='s3://') + cu = CloudUploader.get(mock_remote_dir, exist_ok=True, keep_local=True) + objs = cu.list_objects() if objs: assert (len(objs) == 0) - @pytest.mark.usefixtures('s3_client', 's3_test', 'remote_local_file') - def test_invalid_cloud_prefix(self, remote_local_file: Any): + @pytest.mark.usefixtures('s3_client', 's3_test', 'remote_local_dir') + def test_invalid_cloud_prefix(self, remote_local_dir: Any): with pytest.raises(ValueError): - mock_remote_filepath, _ = remote_local_file(cloud_prefix='s9://') - _ = list_objects_from_s3(mock_remote_filepath) + mock_remote_dir, _ = remote_local_dir(cloud_prefix='s9://') + cu = CloudUploader.get(mock_remote_dir, exist_ok=True, keep_local=True) + _ = cu.list_objects() class TestGCSClient: - @pytest.mark.usefixtures('gcs_hmac_client', 'gcs_test', 'remote_local_file') - def test_list_objects_from_gcs(self, remote_local_file: Any): - with tempfile.NamedTemporaryFile(delete=True, suffix='.txt') as tmp: - file_name = tmp.name.split(os.sep)[-1] - mock_remote_filepath, _ = remote_local_file(cloud_prefix='gs://', filename=file_name) - client = boto3.client('s3', - region_name='us-east-1', - endpoint_url=GCS_URL, - aws_access_key_id=os.environ['GCS_KEY'], - aws_secret_access_key=os.environ['GCS_SECRET']) - client.put_object(Bucket=MY_BUCKET, Key=os.path.join(MY_PREFIX, file_name), Body='') - objs = list_objects_from_gcs(mock_remote_filepath) - assert isinstance(objs, list) - - @patch('google.auth.default') - @patch('google.cloud.storage.Client') - @pytest.mark.usefixtures('gcs_service_account_credentials') - @pytest.mark.parametrize('out', 
['gs://bucket/dir']) - def test_download_service_account(self, mock_client: Mock, mock_default: Mock, out: str): - with tempfile.NamedTemporaryFile(delete=True, suffix='.txt') as _: - credentials_mock = Mock() - mock_default.return_value = credentials_mock, None - objs = list_objects_from_gcs(out) - mock_client.assert_called_once_with(credentials=credentials_mock) - assert isinstance(objs, list) - - @pytest.mark.usefixtures('gcs_hmac_client', 'gcs_test', 'remote_local_file') - def test_filenotfound_exception(self, remote_local_file: Any): - mock_remote_filepath, _ = remote_local_file(cloud_prefix='gs://') - _ = list_objects_from_gcs(mock_remote_filepath) - - @pytest.mark.usefixtures('gcs_hmac_client', 'gcs_test', 'remote_local_file') - def test_invalid_cloud_prefix(self, remote_local_file: Any): + @pytest.mark.usefixtures('gcs_hmac_client', 'gcs_test', 'remote_local_dir') + def test_invalid_cloud_prefix(self, remote_local_dir: Any): with pytest.raises(ValueError): - mock_remote_filepath, _ = remote_local_file(cloud_prefix='s3://') - _ = list_objects_from_gcs(mock_remote_filepath) + mock_remote_dir, _ = remote_local_dir(cloud_prefix='gs9://') + cu = CloudUploader.get(mock_remote_dir, exist_ok=True, keep_local=True) + _ = cu.list_objects() - def test_no_credentials_error(self, remote_local_file: Any): + def test_no_credentials_error(self, remote_local_dir: Any): """Ensure we raise a value error correctly if we have no credentials available.""" with pytest.raises(ValueError): - mock_remote_filepath, _ = remote_local_file(cloud_prefix='gs://') - _ = list_objects_from_gcs(mock_remote_filepath) - - -def test_list_objects_from_local(): - mock_local_dir = tempfile.TemporaryDirectory() - file_name = 'file.txt' - mock_local_file = os.path.join(mock_local_dir.name, file_name) - # Creates a new empty file - with open(mock_local_file, 'w') as _: - pass - with pytest.raises(NotADirectoryError): - _ = list_objects_from_local(mock_local_file) + mock_remote_dir, _ = remote_local_dir(cloud_prefix='gs://') + cu = CloudUploader.get(mock_remote_dir, exist_ok=True, keep_local=True) + _ = cu.list_objects() class TestListObjects: - @patch('streaming.base.storage.download.list_objects_from_s3') - @pytest.mark.usefixtures('remote_local_file') - def test_list_objects_from_s3_gets_called(self, mocked_requests: Mock, remote_local_file: Any): - mock_remote_filepath, _ = remote_local_file(cloud_prefix='s3://') - list_objects(mock_remote_filepath) - mocked_requests.assert_called_once() - mocked_requests.assert_called_once_with(mock_remote_filepath) - - @patch('streaming.base.storage.download.list_objects_from_gcs') - @pytest.mark.usefixtures('remote_local_file') - def test_list_objects_from_gcs_gets_called(self, mocked_requests: Mock, - remote_local_file: Any): - mock_remote_filepath, _ = remote_local_file(cloud_prefix='gs://') - list_objects(mock_remote_filepath) - mocked_requests.assert_called_once() - mocked_requests.assert_called_once_with(mock_remote_filepath) - - @patch('streaming.base.storage.download.list_objects_from_local') - @pytest.mark.usefixtures('remote_local_file') + @patch('streaming.base.storage.LocalUploader.list_objects') + @pytest.mark.usefixtures('remote_local_dir') def test_list_objects_from_local_gets_called(self, mocked_requests: Mock, - remote_local_file: Any): - mock_remote_filepath, _ = remote_local_file() - list_objects(mock_remote_filepath) + remote_local_dir: Any): + mock_remote_dir, _ = remote_local_dir() + cu = CloudUploader.get(mock_remote_dir, exist_ok=True, keep_local=True) + 
cu.list_objects() mocked_requests.assert_called_once() - mocked_requests.assert_called_once_with(mock_remote_filepath) - - @pytest.mark.usefixtures('remote_local_file') - def test_list_objects_invalid_missing_remote(self): - obj = list_objects(None) - assert (obj == os.listdir()) From 22429cfd35cc879b1c1b288b9600f1aeca2d8c2a Mon Sep 17 00:00:00 2001 From: Xiaohan Zhang Date: Tue, 3 Oct 2023 12:26:39 -0700 Subject: [PATCH 26/59] remove --- tests/base/converters/test_dataframe_to_mds.py | 3 --- tests/test_util.py | 5 +---- 2 files changed, 1 insertion(+), 7 deletions(-) diff --git a/tests/base/converters/test_dataframe_to_mds.py b/tests/base/converters/test_dataframe_to_mds.py index fa766a8d3..b5756ca4d 100644 --- a/tests/base/converters/test_dataframe_to_mds.py +++ b/tests/base/converters/test_dataframe_to_mds.py @@ -31,9 +31,6 @@ def manual_integration_dir() -> Any: """Creates a temporary directory and then deletes it when the calling function is done.""" if MANUAL_INTEGRATION_TEST: os.environ['GOOGLE_APPLICATION_CREDENTIALS'] = 'path/to/gooogle_api_credential.json' - os.environ[ - 'GOOGLE_APPLICATION_CREDENTIALS'] = '/Users/xiaohan.zhang/.mosaic/mosaicml-research-nonprod-027345ddbdfd.json' - os.environ.pop('AWS_ACCESS_KEY_ID', None) os.environ.pop('AWS_SECRET_ACCESS_KEY', None) os.environ.pop('AWS_SECURITY_TOKEN', None) diff --git a/tests/test_util.py b/tests/test_util.py index 347b3c80d..dd66c44cf 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -31,10 +31,7 @@ def manual_integration_dir() -> Any: """Creates a temporary directory and then deletes it when the calling function is done.""" if MANUAL_INTEGRATION_TEST: - #os.environ['GOOGLE_APPLICATION_CREDENTIALS'] = 'path/to/gooogle_api_credential.json' - os.environ[ - 'GOOGLE_APPLICATION_CREDENTIALS'] = '/Users/xiaohan.zhang/.mosaic/mosaicml-research-nonprod-027345ddbdfd.json' - + os.environ['GOOGLE_APPLICATION_CREDENTIALS'] = 'path/to/gooogle_api_credential.json' os.environ.pop('AWS_ACCESS_KEY_ID', None) os.environ.pop('AWS_SECRET_ACCESS_KEY', None) os.environ.pop('AWS_SECURITY_TOKEN', None) From 5301563eff50b52959b1c1ca3b9e72795be8a38b Mon Sep 17 00:00:00 2001 From: Xiaohan Zhang Date: Tue, 3 Oct 2023 17:11:57 -0700 Subject: [PATCH 27/59] Add merge_index --- streaming/base/converters/dataframe_to_mds.py | 16 ++-- streaming/base/storage/upload.py | 19 ++--- streaming/base/util.py | 50 +++++++++++++ .../base/converters/test_dataframe_to_mds.py | 15 +++- tests/test_util.py | 75 +++++++++++++++++-- 5 files changed, 152 insertions(+), 23 deletions(-) diff --git a/streaming/base/converters/dataframe_to_mds.py b/streaming/base/converters/dataframe_to_mds.py index 73391f70d..f654649d0 100644 --- a/streaming/base/converters/dataframe_to_mds.py +++ b/streaming/base/converters/dataframe_to_mds.py @@ -160,9 +160,11 @@ def write_mds(iterator: Iterable): raise RuntimeError('TaskContext.get() returns None') if mds_path[1] == '': - output = (os.path.join(mds_path[0], f'{id}'), '') + output = os.path.join(mds_path[0], f'{id}') + partition_path = (output, '') else: output = (os.path.join(mds_path[0], f'{id}'), os.path.join(mds_path[1], f'{id}')) + partition_path = output if mds_kwargs: kwargs = mds_kwargs.copy() @@ -194,8 +196,12 @@ def write_mds(iterator: Iterable): count += 1 yield pd.concat([ - pd.Series([os.path.join(output[0], get_index_basename())], name='mds_path_local'), - pd.Series([os.path.join(output[1], get_index_basename()) if output[1] != '' else ''], + pd.Series([os.path.join(partition_path[0], get_index_basename())], + 
name='mds_path_local'), + pd.Series([ + os.path.join(partition_path[1], get_index_basename()) + if partition_path[1] != '' else '' + ], name='mds_path_remote'), pd.Series([count], name='fail_count') ], @@ -249,8 +255,8 @@ def write_mds(iterator: Iterable): partitions = dataframe.mapInPandas(func=write_mds, schema=result_schema).collect() if merge_index: - folder_urls = [(row['mds_path_local'], row['mds_path_remote']) for row in partitions] - do_merge_index(folder_urls, out, keep_local=keep_local, download_timeout=60) + index_files = [(row['mds_path_local'], row['mds_path_remote']) for row in partitions] + do_merge_index(index_files, out, keep_local=keep_local, download_timeout=60) if cu.remote is not None: if 'keep_local' in mds_kwargs and mds_kwargs['keep_local'] == False: diff --git a/streaming/base/storage/upload.py b/streaming/base/storage/upload.py index 73cae6549..d72472675 100644 --- a/streaming/base/storage/upload.py +++ b/streaming/base/storage/upload.py @@ -443,11 +443,10 @@ def list_objects(self, prefix: Optional[str] = None) -> Optional[List[str]]: prefix = '' obj = urllib.parse.urlparse(self.remote) + bucket_name = obj.netloc if self.authentication == GCSAuthentication.HMAC: - bucket_name = obj.netloc prefix = os.path.join(str(obj.path).lstrip('/'), prefix) - paginator = self.gcs_client.get_paginator('list_objects_v2') pages = paginator.paginate(Bucket=bucket_name, Prefix=prefix) try: @@ -455,10 +454,12 @@ def list_objects(self, prefix: Optional[str] = None) -> Optional[List[str]]: except KeyError: return [] elif self.authentication == GCSAuthentication.SERVICE_ACCOUNT: - from google.cloud.storage import Blob, Bucket - - blob = Blob(str(obj.path).lstrip('/'), Bucket(self.gcs_client, obj.netloc)) - return [blob.name for blob in blob.bucket.list_blobs(prefix=prefix)] + #from google.cloud.storage import Blob, Bucket + #blob = Blob(str(obj.path).lstrip('/'), Bucket(self.gcs_client, obj.netloc)) + prefix = os.path.join(str(obj.path).lstrip('/'), prefix) + return [ + b.name for b in self.gcs_client.get_bucket(bucket_name).list_blobs(prefix=prefix) + ] class OCIUploader(CloudUploader): @@ -971,10 +972,10 @@ def list_objects(self, prefix: Optional[str] = None) -> List[str]: List[str]: A list of object names that match the prefix. """ if prefix is None: - prefix = '.' - + prefix = '' ans = [] - for dirpath, _, files in os.walk(prefix): + print('I am here 100', os.path.join(self.local, prefix)) + for dirpath, _, files in os.walk(os.path.join(self.local, prefix)): for file in files: ans.append(os.path.join(dirpath, file)) return ans diff --git a/streaming/base/util.py b/streaming/base/util.py index c60433ab4..c2e1b7ca8 100644 --- a/streaming/base/util.py +++ b/streaming/base/util.py @@ -215,6 +215,56 @@ def get_import_exception_message(package_name: str, extra_deps: str) -> str: f'`pip install \'mosaicml-streaming[{package_name}]\'`.' +def merge_index(out: Union[str, Tuple[str, str]], *, keep_local: bool = True) -> None: + """Merge index.json given the root of MDS dataset. Write merged index to the root folder. + + Args: + out (Union[str, Tuple[str,str]]): folder that contain MDS partitions. + :a local directory, merge index happens locally + :a remote directory, download all the sub-directories index.json in a temporary sub-directories, merge locally, and then upload it to out location + :a (local_dir, remote_dir), check if sub-directories index.json file present locally + If yes, then merge locally and upload to remote_dir . 
+ If not, download all the sub-directories index.json from remote to local , merge locally, and upload to remote_dir . + keep_local (bool): Keep local copy of the merged index file. Defaults to ``True`` + download_timeout (int): The allowed time for downloading each json file. Defaults to 60s. + """ + from streaming.base.storage.upload import CloudUploader + + def not_merged_index(index_file_path: str, out: str): + """Check if index_file_path is the merged index at folder out.""" + prefix = str(urllib.parse.urlparse(out).path) + return os.path.dirname(index_file_path).strip('/') != prefix.strip('/') + + if not out: + logger.warning('No MDS dataset folder specified, no index merged') + return + + cu = CloudUploader.get(out, exist_ok=True, keep_local=True) + + local_index_files = [] + cl = CloudUploader.get(cu.local, exist_ok=True, keep_local=True) + for o in cl.list_objects(): + if o.endswith('.json') and not_merged_index(o, cu.local): + local_index_files.append(o) + + if cu.remote: + obj = urllib.parse.urlparse(cu.remote) + remote_index_files = [] + for o in cu.list_objects(): + if o.endswith(get_index_basename()) and not_merged_index(o, cu.remote): + remote_index_files.append(obj.scheme + '://' + os.path.join(obj.netloc, o)) + if len(local_index_files) == len(remote_index_files): + do_merge_index(list(zip(local_index_files, remote_index_files)), + out, + keep_local=keep_local, + download_timeout=60) + else: + do_merge_index(remote_index_files, out, keep_local=keep_local, download_timeout=60) + return + + do_merge_index(local_index_files, out, keep_local=keep_local, download_timeout=60) + + def do_merge_index(index_file_urls: List[Any], out: Union[str, Tuple[str, str]], keep_local: bool = True, diff --git a/tests/base/converters/test_dataframe_to_mds.py b/tests/base/converters/test_dataframe_to_mds.py index b5756ca4d..a22052d07 100644 --- a/tests/base/converters/test_dataframe_to_mds.py +++ b/tests/base/converters/test_dataframe_to_mds.py @@ -21,7 +21,7 @@ 'gs://': 'mosaicml-composer-tests', 's3://': 'mosaicml-internal-temporary-composer-testing' } -MANUAL_INTEGRATION_TEST = False +MANUAL_INTEGRATION_TEST = True os.environ[ 'OBJC_DISABLE_INITIALIZE_FORK_SAFETY'] = 'YES' # set to yes to all fork process in spark calls @@ -31,6 +31,8 @@ def manual_integration_dir() -> Any: """Creates a temporary directory and then deletes it when the calling function is done.""" if MANUAL_INTEGRATION_TEST: os.environ['GOOGLE_APPLICATION_CREDENTIALS'] = 'path/to/gooogle_api_credential.json' + os.environ[ + 'GOOGLE_APPLICATION_CREDENTIALS'] = '/Users/xiaohan.zhang/.mosaic/mosaicml-research-nonprod-027345ddbdfd.json' os.environ.pop('AWS_ACCESS_KEY_ID', None) os.environ.pop('AWS_SECRET_ACCESS_KEY', None) os.environ.pop('AWS_SECURITY_TOKEN', None) @@ -59,6 +61,17 @@ def _method(cloud_prefix: str = 'gs://') -> Tuple[str, str]: except ImportError: raise ImportError('google.cloud.storage is not imported correctly.') + try: + import boto3 + s3 = boto3.client('s3') + response = s3.list_objects_v2(Bucket=MY_BUCKET['s3://'], Prefix=MY_PREFIX) + objects_to_delete = [{'Key': obj['Key']} for obj in response.get('Contents', [])] + if objects_to_delete: + s3.delete_objects(Bucket=MY_BUCKET['s3://'], + Delete={'Objects': objects_to_delete}) + except ImportError: + raise ImportError('boto3 is not imported correctly.') + class TestDataFrameToMDS: diff --git a/tests/test_util.py b/tests/test_util.py index dd66c44cf..eb856a569 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -5,6 +5,7 @@ import os import shutil import 
tempfile +import time from multiprocessing.shared_memory import SharedMemory as BuiltinSharedMemory from typing import Any, List, Optional, Tuple, Union @@ -15,14 +16,14 @@ from streaming.base.storage.download import download_file from streaming.base.storage.upload import CloudUploader from streaming.base.util import (bytes_to_int, clean_stale_shared_memory, do_merge_index, - get_list_arg, number_abbrev_to_int, retry) + get_list_arg, merge_index, number_abbrev_to_int, retry) -MY_PREFIX = 'train' +MY_PREFIX = 'train_' + str(time.time()) MY_BUCKET = { 'gs://': 'mosaicml-composer-tests', 's3://': 'mosaicml-internal-temporary-composer-testing' } -MANUAL_INTEGRATION_TEST = False +MANUAL_INTEGRATION_TEST = True os.environ[ 'OBJC_DISABLE_INITIALIZE_FORK_SAFETY'] = 'YES' # set to yes to all fork process in spark calls @@ -32,6 +33,8 @@ def manual_integration_dir() -> Any: """Creates a temporary directory and then deletes it when the calling function is done.""" if MANUAL_INTEGRATION_TEST: os.environ['GOOGLE_APPLICATION_CREDENTIALS'] = 'path/to/gooogle_api_credential.json' + os.environ[ + 'GOOGLE_APPLICATION_CREDENTIALS'] = '/Users/xiaohan.zhang/.mosaic/mosaicml-research-nonprod-027345ddbdfd.json' os.environ.pop('AWS_ACCESS_KEY_ID', None) os.environ.pop('AWS_SECRET_ACCESS_KEY', None) os.environ.pop('AWS_SECURITY_TOKEN', None) @@ -60,6 +63,17 @@ def _method(cloud_prefix: str = 'gs://') -> Tuple[str, str]: except ImportError: raise ImportError('google.cloud.storage is not imported correctly.') + try: + import boto3 + s3 = boto3.client('s3') + response = s3.list_objects_v2(Bucket=MY_BUCKET['s3://'], Prefix=MY_PREFIX) + objects_to_delete = [{'Key': obj['Key']} for obj in response.get('Contents', [])] + if objects_to_delete: + s3.delete_objects(Bucket=MY_BUCKET['s3://'], + Delete={'Objects': objects_to_delete}) + except ImportError: + raise ImportError('boto3 is not imported correctly.') + @pytest.mark.parametrize(('text', 'expected_output'), [('hello,world', ['hello', 'world']), ('hello', ['hello']), ('', [])]) @@ -195,7 +209,9 @@ def get_expected(mds_root: str): assert not os.path.exists(os.path.join(cu.local, 'index.json')) return - assert os.path.exists(local_merged_index_path) + assert os.path.exists( + local_merged_index_path + ), f'{local_merged_index_path} does not exist when keep_local is {keep_local}' merged_index = json.load(open(local_merged_index_path, 'r')) n_shard_files = len({b['raw_data']['basename'] for b in merged_index['shards']}) assert (n_shard_files == expected_n_shard_files @@ -204,11 +220,11 @@ def get_expected(mds_root: str): @pytest.mark.parametrize('scheme', ['gs://', 's3://']) @pytest.mark.parametrize('index_file_urls_pattern', [1, 2, 3, 4, 5]) -@pytest.mark.parametrize('output_format', ['local', 'remote', 'tuple']) +@pytest.mark.parametrize('out_format', ['local', 'remote', 'tuple']) @pytest.mark.usefixtures('manual_integration_dir') @pytest.mark.parametrize('keep_local', [True, False]) def test_do_merge_index(manual_integration_dir: Any, keep_local: bool, - index_file_urls_pattern: int, output_format: str, scheme: str): + index_file_urls_pattern: int, out_format: str, scheme: str): """Validate the final merge index json for following patterns of index_file_urls: 1. All urls are str (local). All urls are accessible locally -> no download 2. All urls are str (local). At least one url is unaccessible locally -> Error @@ -217,11 +233,11 @@ def test_do_merge_index(manual_integration_dir: Any, keep_local: bool, 5. 
All urls are str (remote) -> download all """ - if output_format != 'local': + if out_format != 'local': if not MANUAL_INTEGRATION_TEST: pytest.skip('Require cloud credentials. ' + 'skipping. Set MANUAL_INTEGRATION_TEST=True to run the check manually!') - if output_format == 'remote': + if out_format == 'remote': out = manual_integration_dir(scheme)[1] else: out = manual_integration_dir(scheme) @@ -293,6 +309,49 @@ def test_do_merge_index(manual_integration_dir: Any, keep_local: bool, integrity_check(out, keep_local=keep_local, expected_n_shard_files=2) +@pytest.mark.parametrize('scheme', ['gs://', 's3://']) +@pytest.mark.parametrize('out_format', ['remote', 'local', 'tuple']) +@pytest.mark.parametrize('n_partitions', [1, 2, 3, 4]) +@pytest.mark.parametrize('keep_local', [False, True]) +def test_merge_index(manual_integration_dir: Any, out_format: str, n_partitions: int, + keep_local: bool, scheme: str): + from decimal import Decimal + + from pyspark.sql import SparkSession + from pyspark.sql.types import DecimalType, IntegerType, StringType, StructField, StructType + + from streaming.base.converters import dataframeToMDS + + if out_format == 'remote' or out_format == 'tuple': + if not MANUAL_INTEGRATION_TEST: + pytest.skip('Require cloud credentials. ' + + 'skipping. Set MANUAL_INTEGRATION_TEST=True to run the check manually!') + if out_format == 'remote': + _, out = manual_integration_dir(scheme) + else: + out = manual_integration_dir(scheme) + else: + out, _ = manual_integration_dir(scheme) + + spark = SparkSession.builder.getOrCreate() # pyright: ignore + schema = StructType([ + StructField('id', IntegerType(), nullable=False), + StructField('name', StringType(), nullable=False), + StructField('amount', DecimalType(10, 2), nullable=False) + ]) + + data = [(1, 'Alice', Decimal('123.45')), (2, 'Bob', Decimal('67.89')), + (3, 'Charlie', Decimal('987.65'))] + + df = spark.createDataFrame(data=data, schema=schema).repartition(n_partitions) + + mds_kwargs = {'out': out, 'columns': {'id': 'int', 'name': 'str'}, 'keep_local': keep_local} + + mds_path, _ = dataframeToMDS(df, merge_index=False, mds_kwargs=mds_kwargs) + merge_index(mds_path, keep_local=keep_local) + integrity_check(mds_path, keep_local=keep_local) + + @pytest.mark.parametrize('with_args', [True, False]) def test_retry(with_args: bool): num_tries = 0 From 36dff131b308e0ebc1633e815eca9bb198093b32 Mon Sep 17 00:00:00 2001 From: Xiaohan Zhang Date: Wed, 4 Oct 2023 00:36:18 -0700 Subject: [PATCH 28/59] remove materialized test dataset --- .../base/converters/test_dataframe_to_mds.py | 2 - .../resources/naive_MDSdataset/25/index.json | 31 ------ .../naive_MDSdataset/25/shard.00000.mds | Bin 244 -> 0 bytes .../resources/naive_MDSdataset/26/index.json | 31 ------ .../naive_MDSdataset/26/shard.00000.mds | Bin 266 -> 0 bytes .../resources/naive_MDSdataset/27/index.json | 4 - tests/resources/naive_MDSdataset/index.json | 57 ----------- tests/test_util.py | 94 ++++++++++-------- 8 files changed, 54 insertions(+), 165 deletions(-) delete mode 100644 tests/resources/naive_MDSdataset/25/index.json delete mode 100644 tests/resources/naive_MDSdataset/25/shard.00000.mds delete mode 100644 tests/resources/naive_MDSdataset/26/index.json delete mode 100644 tests/resources/naive_MDSdataset/26/shard.00000.mds delete mode 100644 tests/resources/naive_MDSdataset/27/index.json delete mode 100644 tests/resources/naive_MDSdataset/index.json diff --git a/tests/base/converters/test_dataframe_to_mds.py b/tests/base/converters/test_dataframe_to_mds.py index 
a22052d07..ed4f75358 100644 --- a/tests/base/converters/test_dataframe_to_mds.py +++ b/tests/base/converters/test_dataframe_to_mds.py @@ -31,8 +31,6 @@ def manual_integration_dir() -> Any: """Creates a temporary directory and then deletes it when the calling function is done.""" if MANUAL_INTEGRATION_TEST: os.environ['GOOGLE_APPLICATION_CREDENTIALS'] = 'path/to/gooogle_api_credential.json' - os.environ[ - 'GOOGLE_APPLICATION_CREDENTIALS'] = '/Users/xiaohan.zhang/.mosaic/mosaicml-research-nonprod-027345ddbdfd.json' os.environ.pop('AWS_ACCESS_KEY_ID', None) os.environ.pop('AWS_SECRET_ACCESS_KEY', None) os.environ.pop('AWS_SECURITY_TOKEN', None) diff --git a/tests/resources/naive_MDSdataset/25/index.json b/tests/resources/naive_MDSdataset/25/index.json deleted file mode 100644 index b7c1f591f..000000000 --- a/tests/resources/naive_MDSdataset/25/index.json +++ /dev/null @@ -1,31 +0,0 @@ -{ - "shards": [ - { - "column_encodings": [ - "str", - "str" - ], - "column_names": [ - "dept", - "id" - ], - "column_sizes": [ - null, - null - ], - "compression": null, - "format": "mds", - "hashes": [], - "raw_data": { - "basename": "shard.00000.mds", - "bytes": 244, - "hashes": {} - }, - "samples": 2, - "size_limit": 67108864, - "version": 2, - "zip_data": null - } - ], - "version": 2 -} diff --git a/tests/resources/naive_MDSdataset/25/shard.00000.mds b/tests/resources/naive_MDSdataset/25/shard.00000.mds deleted file mode 100644 index c776992af71090060606221dac217c624c93070d..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 244 zcmYk0Jr2S!423Hs25wPh?m&MiEq7pIM5v0IKqK-~#VG@b8*&QHzyYvH2Zqm<{hpt^ zNRs4*ypa#`V7=3mv7NNN6UttI?b0KI;8~Xb+6nvYvE0b03poZdD8c@8Q1__YN$V`7 z8dWoT380+C@Tjq~^M(hUnGrxy1BW4A(+x#+S{X%_dYiACrmk>*lYY)Ao-6!+iR`(* V%7DL@ZQd5NAr4$iD636beF0ktMydb+ diff --git a/tests/resources/naive_MDSdataset/26/index.json b/tests/resources/naive_MDSdataset/26/index.json deleted file mode 100644 index 7aac4a36b..000000000 --- a/tests/resources/naive_MDSdataset/26/index.json +++ /dev/null @@ -1,31 +0,0 @@ -{ - "shards": [ - { - "column_encodings": [ - "str", - "str" - ], - "column_names": [ - "dept", - "id" - ], - "column_sizes": [ - null, - null - ], - "compression": null, - "format": "mds", - "hashes": [], - "raw_data": { - "basename": "shard.00000.mds", - "bytes": 266, - "hashes": {} - }, - "samples": 3, - "size_limit": 67108864, - "version": 2, - "zip_data": null - } - ], - "version": 2 -} diff --git a/tests/resources/naive_MDSdataset/26/shard.00000.mds b/tests/resources/naive_MDSdataset/26/shard.00000.mds deleted file mode 100644 index 42d6c202b350ecba24425d681efdcb36d1518b2b..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 266 zcmZ9HO$x#=5QQTmf~P21chNs$@Bm)GrAP@)w1H$&lGcUd9X*7HaOuQ!AuhfdX5O0z zvm}H(kr(ntHVZ=Tv~y;%&?@Mh)Nl!OmmO&sJP6_&b-amDt d6|&bLpi?ztHT&B&Ma6maL=M8J&{SpFd;oURNsj;k diff --git a/tests/resources/naive_MDSdataset/27/index.json b/tests/resources/naive_MDSdataset/27/index.json deleted file mode 100644 index 16d98001e..000000000 --- a/tests/resources/naive_MDSdataset/27/index.json +++ /dev/null @@ -1,4 +0,0 @@ -{ - "shards": [], - "version": 2 -} diff --git a/tests/resources/naive_MDSdataset/index.json b/tests/resources/naive_MDSdataset/index.json deleted file mode 100644 index 2915cd61b..000000000 --- a/tests/resources/naive_MDSdataset/index.json +++ /dev/null @@ -1,57 +0,0 @@ -{ - "version": 2, - "shards": [ - { - "column_encodings": [ - "str", - "str" - ], - "column_names": [ - "dept", - "id" - ], - "column_sizes": [ - 
null, - null - ], - "compression": null, - "format": "mds", - "hashes": [], - "raw_data": { - "basename": "26/shard.00000.mds", - "bytes": 266, - "hashes": {} - }, - "samples": 3, - "size_limit": 67108864, - "version": 2, - "zip_data": null - }, - { - "column_encodings": [ - "str", - "str" - ], - "column_names": [ - "dept", - "id" - ], - "column_sizes": [ - null, - null - ], - "compression": null, - "format": "mds", - "hashes": [], - "raw_data": { - "basename": "25/shard.00000.mds", - "bytes": 244, - "hashes": {} - }, - "samples": 2, - "size_limit": 67108864, - "version": 2, - "zip_data": null - } - ] -} diff --git a/tests/test_util.py b/tests/test_util.py index eb856a569..6b6364672 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -6,6 +6,7 @@ import shutil import tempfile import time +import urllib.parse from multiprocessing.shared_memory import SharedMemory as BuiltinSharedMemory from typing import Any, List, Optional, Tuple, Union @@ -23,7 +24,7 @@ 'gs://': 'mosaicml-composer-tests', 's3://': 'mosaicml-internal-temporary-composer-testing' } -MANUAL_INTEGRATION_TEST = True +MANUAL_INTEGRATION_TEST = False os.environ[ 'OBJC_DISABLE_INITIALIZE_FORK_SAFETY'] = 'YES' # set to yes to all fork process in spark calls @@ -33,8 +34,6 @@ def manual_integration_dir() -> Any: """Creates a temporary directory and then deletes it when the calling function is done.""" if MANUAL_INTEGRATION_TEST: os.environ['GOOGLE_APPLICATION_CREDENTIALS'] = 'path/to/gooogle_api_credential.json' - os.environ[ - 'GOOGLE_APPLICATION_CREDENTIALS'] = '/Users/xiaohan.zhang/.mosaic/mosaicml-research-nonprod-027345ddbdfd.json' os.environ.pop('AWS_ACCESS_KEY_ID', None) os.environ.pop('AWS_SECRET_ACCESS_KEY', None) os.environ.pop('AWS_SECURITY_TOKEN', None) @@ -232,43 +231,75 @@ def test_do_merge_index(manual_integration_dir: Any, keep_local: bool, 4. All urls are tuple (local, remote). At least one url is not accessible locally -> download all 5. All urls are str (remote) -> download all """ + from decimal import Decimal + + from pyspark.sql import SparkSession + from pyspark.sql.types import DecimalType, IntegerType, StringType, StructField, StructType + + from streaming.base.converters import dataframeToMDS + + def not_merged_index(index_file_path: str, out: str): + """Check if index_file_path is the merged index at folder out.""" + prefix = str(urllib.parse.urlparse(out).path) + return os.path.dirname(index_file_path).strip('/') != prefix.strip('/') + + local, remote = manual_integration_dir(scheme) if out_format != 'local': if not MANUAL_INTEGRATION_TEST: pytest.skip('Require cloud credentials. ' + 'skipping. 
Set MANUAL_INTEGRATION_TEST=True to run the check manually!') if out_format == 'remote': - out = manual_integration_dir(scheme)[1] + out = remote else: - out = manual_integration_dir(scheme) + out = (local, remote) else: - out = manual_integration_dir(scheme)[0] + out = local - naive_mds_partitions = [ - 'tests/resources/naive_MDSdataset/25/', 'tests/resources/naive_MDSdataset/26', - 'tests/resources/naive_MDSdataset/27/' + spark = SparkSession.builder.getOrCreate() # pyright: ignore + schema = StructType([ + StructField('id', IntegerType(), nullable=False), + StructField('name', StringType(), nullable=False), + StructField('amount', DecimalType(10, 2), nullable=False) + ]) + data = [(1, 'Alice', Decimal('123.45')), (2, 'Bob', Decimal('67.89')), + (3, 'Charlie', Decimal('987.65'))] + df = spark.createDataFrame(data=data, schema=schema).repartition(3) + mds_kwargs = { + 'out': (local, remote), + 'columns': { + 'id': 'int', + 'name': 'str' + }, + 'keep_local': True + } + dataframeToMDS(df, merge_index=False, mds_kwargs=mds_kwargs) + + local_cu = CloudUploader.get(local, exist_ok=True, keep_local=True) + local_index_files = [ + o for o in local_cu.list_objects() if o.endswith('.json') and not_merged_index(o, local) + ] + remote_cu = CloudUploader.get(remote, exist_ok=True, keep_local=True) + remote_index_files = [ + os.path.join(scheme, MY_BUCKET[scheme], o) + for o in remote_cu.list_objects() + if o.endswith('.json') and not_merged_index(o, remote) ] if index_file_urls_pattern == 1: - index_file_urls = [ - os.path.join(os.getcwd(), s, 'index.json') for s in naive_mds_partitions - ] - do_merge_index(index_file_urls, out, keep_local=keep_local) + do_merge_index(local_index_files, out, keep_local=keep_local) if index_file_urls_pattern == 2: with tempfile.TemporaryDirectory() as a_temporary_folder: index_file_urls = [ - os.path.join(a_temporary_folder, s, 'index.json') for s in naive_mds_partitions + os.path.join(a_temporary_folder, os.path.basename(s)) for s in local_index_files ] with pytest.raises(FileNotFoundError, match=f'.* does not exist or not accessible.*'): do_merge_index(index_file_urls, out, keep_local=keep_local) return if index_file_urls_pattern == 3: - index_file_urls = [] - for s in naive_mds_partitions: - index_file_urls.append((os.path.join(os.getcwd(), s, 'index.json'), - os.path.join(scheme, MY_BUCKET[scheme], s, 'index.json'))) + index_file_urls = list(zip(local_index_files, remote_index_files)) do_merge_index(index_file_urls, out, keep_local=keep_local) if index_file_urls_pattern == 4: @@ -277,34 +308,17 @@ def test_do_merge_index(manual_integration_dir: Any, keep_local: bool, 'skipping. 
Set MANUAL_INTEGRATION_TEST=True to run the check manually!') with tempfile.TemporaryDirectory() as a_temporary_folder: - index_file_urls = [] - for s in naive_mds_partitions: - cu_path = (os.path.join(os.getcwd(), - s), os.path.join(scheme, MY_BUCKET[scheme], s)) - cu = CloudUploader.get(cu_path, keep_local=True, exist_ok=True) - index_json = os.path.join(cu.local, 'index.json') - if os.path.exists(index_json): - cu.upload_file('index.json') - index_file_urls.append((os.path.join(a_temporary_folder, s, 'index.json'), - os.path.join(scheme, MY_BUCKET[scheme], s, 'index.json'))) + non_exist_local_files = [ + os.path.join(a_temporary_folder, os.path.basename(s)) for s in local_index_files + ] + index_file_urls = list(zip(non_exist_local_files, remote_index_files)) do_merge_index(index_file_urls, out, keep_local=keep_local) if index_file_urls_pattern == 5: if not MANUAL_INTEGRATION_TEST: pytest.skip('Require cloud credentials. ' + 'skipping. Set MANUAL_INTEGRATION_TEST=True to run the check manually!') - - with tempfile.TemporaryDirectory() as a_temporary_folder: - index_file_urls = [] - for s in naive_mds_partitions: - cu_path = (os.path.join(os.getcwd(), - s), os.path.join(scheme, MY_BUCKET[scheme], s)) - cu = CloudUploader.get(cu_path, keep_local=True, exist_ok=True) - index_json = os.path.join(cu.local, 'index.json') - if os.path.exists(index_json): - cu.upload_file('index.json') - index_file_urls.append(os.path.join(scheme, MY_BUCKET[scheme], s, 'index.json')) - do_merge_index(index_file_urls, out, keep_local=keep_local) + do_merge_index(remote_index_files, out, keep_local=keep_local) integrity_check(out, keep_local=keep_local, expected_n_shard_files=2) From 691a588dd6dda2bb08b0d83e10a3e35abc7f6e51 Mon Sep 17 00:00:00 2001 From: Xiaohan Zhang Date: Wed, 4 Oct 2023 10:28:30 -0700 Subject: [PATCH 29/59] Change do_merge_index to merge_index_from_list --- streaming/base/converters/dataframe_to_mds.py | 4 +- streaming/base/util.py | 20 ++--- .../base/converters/test_dataframe_to_mds.py | 74 +++++++++---------- tests/test_util.py | 55 +++++++++----- 4 files changed, 85 insertions(+), 68 deletions(-) diff --git a/streaming/base/converters/dataframe_to_mds.py b/streaming/base/converters/dataframe_to_mds.py index f654649d0..1e303edfb 100644 --- a/streaming/base/converters/dataframe_to_mds.py +++ b/streaming/base/converters/dataframe_to_mds.py @@ -11,7 +11,7 @@ import pandas as pd -from streaming.base.util import do_merge_index, get_import_exception_message +from streaming.base.util import merge_index_from_list, get_import_exception_message try: from pyspark import TaskContext @@ -256,7 +256,7 @@ def write_mds(iterator: Iterable): if merge_index: index_files = [(row['mds_path_local'], row['mds_path_remote']) for row in partitions] - do_merge_index(index_files, out, keep_local=keep_local, download_timeout=60) + merge_index_from_list(index_files, out, keep_local=keep_local, download_timeout=60) if cu.remote is not None: if 'keep_local' in mds_kwargs and mds_kwargs['keep_local'] == False: diff --git a/streaming/base/util.py b/streaming/base/util.py index c2e1b7ca8..a4eadcd4c 100644 --- a/streaming/base/util.py +++ b/streaming/base/util.py @@ -254,21 +254,21 @@ def not_merged_index(index_file_path: str, out: str): if o.endswith(get_index_basename()) and not_merged_index(o, cu.remote): remote_index_files.append(obj.scheme + '://' + os.path.join(obj.netloc, o)) if len(local_index_files) == len(remote_index_files): - do_merge_index(list(zip(local_index_files, remote_index_files)), - out, - 
keep_local=keep_local, - download_timeout=60) + merge_index_from_list(list(zip(local_index_files, remote_index_files)), + out, + keep_local=keep_local, + download_timeout=60) else: - do_merge_index(remote_index_files, out, keep_local=keep_local, download_timeout=60) + merge_index_from_list(remote_index_files, out, keep_local=keep_local, download_timeout=60) return - do_merge_index(local_index_files, out, keep_local=keep_local, download_timeout=60) + merge_index_from_list(local_index_files, out, keep_local=keep_local, download_timeout=60) -def do_merge_index(index_file_urls: List[Any], - out: Union[str, Tuple[str, str]], - keep_local: bool = True, - download_timeout: int = 60) -> None: +def merge_index_from_list(index_file_urls: List[Any], + out: Union[str, Tuple[str, str]], + keep_local: bool = True, + download_timeout: int = 60) -> None: """Merge index.json from a list of index.json. Write to `out`, overwriting if exists. Args: diff --git a/tests/base/converters/test_dataframe_to_mds.py b/tests/base/converters/test_dataframe_to_mds.py index ed4f75358..ea5a86a48 100644 --- a/tests/base/converters/test_dataframe_to_mds.py +++ b/tests/base/converters/test_dataframe_to_mds.py @@ -21,7 +21,7 @@ 'gs://': 'mosaicml-composer-tests', 's3://': 'mosaicml-internal-temporary-composer-testing' } -MANUAL_INTEGRATION_TEST = True +MANUAL_INTEGRATION_TEST = False os.environ[ 'OBJC_DISABLE_INITIALIZE_FORK_SAFETY'] = 'YES' # set to yes to all fork process in spark calls @@ -131,16 +131,16 @@ def test_end_to_end_conversion_local_nocolumns(self, dataframe: Any, keep_local: mds_kwargs=mds_kwargs) if keep_local: - assert (len(os.listdir(out)) > 0), f'{out} is empty' + assert len(os.listdir(out)) > 0, f'{out} is empty' for d in os.listdir(out): if os.path.isdir(os.path.join(out, d)): - assert (os.path.exists(os.path.join( - out, d, 'index.json'))), f'No index.json found in subdirectory {d}' + assert os.path.exists(os.path.join( + out, d, 'index.json')), f'No index.json found in subdirectory {d}' if merge_index: if keep_local: - assert (os.path.exists(os.path.join(out, - 'index.json'))), 'No merged index.json found' + assert os.path.exists(os.path.join(out, + 'index.json')), 'No merged index.json found' mgi = json.load(open(os.path.join(out, 'index.json'), 'r')) nsamples = 0 for d in os.listdir(out): @@ -150,13 +150,13 @@ def test_end_to_end_conversion_local_nocolumns(self, dataframe: Any, keep_local: 'r'))['shards'] if shards: nsamples += shards[0]['samples'] - assert (nsamples == sum([a['samples'] for a in mgi['shards']])) + assert nsamples == sum([a['samples'] for a in mgi['shards']]) if not keep_local: - assert (not os.path.exists(os.path.join( - out, 'index.json'))), 'merged index.json is found even keep_local = False' + assert not os.path.exists(os.path.join( + out, 'index.json')), 'merged index.json is found even keep_local = False' else: - assert not (os.path.exists(os.path.join( - out, 'index.json'))), 'merged index is created when merge_index=False' + assert not os.path.exists(os.path.join( + out, 'index.json')), 'merged index is created when merge_index=False' @pytest.mark.parametrize('use_columns', [True, False]) def test_end_to_end_conversion_local_decimal(self, decimal_dataframe: Any, use_columns: bool, @@ -169,7 +169,7 @@ def test_end_to_end_conversion_local_decimal(self, decimal_dataframe: Any, use_c mds_kwargs['columns'] = user_defined_columns _, _ = dataframeToMDS(decimal_dataframe, merge_index=True, mds_kwargs=mds_kwargs) - assert (len(os.listdir(out)) > 0), f'{out} is empty' + assert 
len(os.listdir(out)) > 0, f'{out} is empty' def test_user_defined_columns(self, dataframe: Any, local_remote_dir: Tuple[str, str]): out, _ = local_remote_dir @@ -210,16 +210,16 @@ def test_end_to_end_conversion_local(self, dataframe: Any, keep_local: bool, mer _, _ = dataframeToMDS(dataframe, merge_index=merge_index, mds_kwargs=mds_kwargs) if keep_local: - assert (len(os.listdir(out)) > 0), f'{out} is empty' + assert len(os.listdir(out)) > 0, f'{out} is empty' for d in os.listdir(out): if os.path.isdir(os.path.join(out, d)): - assert (os.path.exists(os.path.join( - out, d, 'index.json'))), f'No index.json found in subdirectory {d}' + assert os.path.exists(os.path.join( + out, d, 'index.json')), f'No index.json found in subdirectory {d}' if merge_index == True: if keep_local: - assert (os.path.exists(os.path.join(out, - 'index.json'))), 'No merged index.json found' + assert os.path.exists(os.path.join(out, + 'index.json')), 'No merged index.json found' mgi = json.load(open(os.path.join(out, 'index.json'), 'r')) nsamples = 0 for d in os.listdir(out): @@ -229,13 +229,13 @@ def test_end_to_end_conversion_local(self, dataframe: Any, keep_local: bool, mer 'r'))['shards'] if shards: nsamples += shards[0]['samples'] - assert (nsamples == sum([a['samples'] for a in mgi['shards']])) + assert nsamples == sum([a['samples'] for a in mgi['shards']]) else: - assert (not os.path.exists(os.path.join( - out, 'index.json'))), 'merged index.json is found even keep_local=False' + assert not os.path.exists(os.path.join( + out, 'index.json')), 'merged index.json is found even keep_local=False' else: - assert not (os.path.exists(os.path.join( - out, 'index.json'))), 'merged index is created when merge_index=False' + assert not os.path.exists(os.path.join( + out, 'index.json')), 'merged index is created when merge_index=False' @pytest.mark.parametrize('scheme', ['gs://', 's3://']) @pytest.mark.parametrize('keep_local', [True]) # , False]) @@ -266,22 +266,22 @@ def test_patch_conversion_local_and_remote(self, dataframe: Any, scheme: str, merge_index=merge_index, mds_kwargs=mds_kwargs) - assert (fail_count == 0), 'some records were not converted correctly' + assert fail_count == 0, 'some records were not converted correctly' assert out == mds_path, f'returned mds_path: {mds_path} is not the same as out: {out}' if not keep_local: - assert (not os.path.exists(mds_path[0])), 'local folder were not removed' + assert not os.path.exists(mds_path[0]), 'local folder were not removed' return - assert (len(os.listdir(mds_path[0])) > 0), f'{mds_path[0]} is empty' + assert len(os.listdir(mds_path[0])) > 0, f'{mds_path[0]} is empty' for d in os.listdir(mds_path[0]): if os.path.isdir(os.path.join(mds_path[0], d)): - assert (os.path.exists(os.path.join( - mds_path[0], d, 'index.json'))), f'No index.json found in subdirectory {d}' + assert os.path.exists(os.path.join( + mds_path[0], d, 'index.json')), f'No index.json found in subdirectory {d}' if merge_index == True: - assert (os.path.exists(os.path.join(mds_path[0], - 'index.json'))), 'No merged index.json found' + assert os.path.exists(os.path.join(mds_path[0], + 'index.json')), 'No merged index.json found' else: assert not (os.path.exists(os.path.join( mds_path[0], 'index.json'))), 'merged index is created when merge_index=False' @@ -314,20 +314,20 @@ def test_integration_conversion_local_and_remote(self, dataframe: Any, assert out == mds_path, f'returned mds_path: {mds_path} is not the same as out: {out}' if not keep_local: - assert (not os.path.exists(mds_path[0])), 'local 
folder were not removed' + assert not os.path.exists(mds_path[0]), 'local folder were not removed' return - assert (len(os.listdir(mds_path[0])) > 0), f'{mds_path[0]} is empty' + assert len(os.listdir(mds_path[0])) > 0, f'{mds_path[0]} is empty' for d in os.listdir(mds_path[0]): if os.path.isdir(os.path.join(mds_path[0], d)): - assert (os.path.exists(os.path.join( - mds_path[0], d, 'index.json'))), f'No index.json found in subdirectory {d}' + assert os.path.exists(os.path.join( + mds_path[0], d, 'index.json')), f'No index.json found in subdirectory {d}' if merge_index == True: - assert (os.path.exists(os.path.join(mds_path[0], - 'index.json'))), 'No merged index.json found' + assert os.path.exists(os.path.join(mds_path[0], + 'index.json')), 'No merged index.json found' else: - assert not (os.path.exists(os.path.join(mds_path[0], 'index.json'))), ( + assert not os.path.exists(os.path.join(mds_path[0], 'index.json')), ( f'merged index is created at {mds_path[0]} when merge_index={merge_index} and ' + f'keep_local={keep_local}') @@ -351,7 +351,7 @@ def test_integration_conversion_remote_only(self, dataframe: Any, manual_integra assert len(mds_path) == 2, 'returned mds is a str but should be a tuple (local, remote)' assert not (os.path.exists(os.path.join( mds_path[0], 'index.json'))), 'Local merged index was not removed successfully' - assert (len(os.listdir(mds_path[0])) > 0), f'{mds_path[0]} is not empty' + assert len(os.listdir(mds_path[0])) > 0, f'{mds_path[0]} is not empty' def test_simple_remote(self, dataframe: Any): if not MANUAL_INTEGRATION_TEST: diff --git a/tests/test_util.py b/tests/test_util.py index 6b6364672..73b13a8b7 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -16,7 +16,7 @@ from streaming.base.shared.prefix import _get_path from streaming.base.storage.download import download_file from streaming.base.storage.upload import CloudUploader -from streaming.base.util import (bytes_to_int, clean_stale_shared_memory, do_merge_index, +from streaming.base.util import (bytes_to_int, clean_stale_shared_memory, merge_index_from_list, get_list_arg, merge_index, number_abbrev_to_int, retry) MY_PREFIX = 'train_' + str(time.time()) @@ -213,17 +213,16 @@ def get_expected(mds_root: str): ), f'{local_merged_index_path} does not exist when keep_local is {keep_local}' merged_index = json.load(open(local_merged_index_path, 'r')) n_shard_files = len({b['raw_data']['basename'] for b in merged_index['shards']}) - assert (n_shard_files == expected_n_shard_files - ), f'expected {expected_n_shard_files} shard files but got {n_shard_files}' + assert n_shard_files == expected_n_shard_files, f'expected {expected_n_shard_files} shard files but got {n_shard_files}' @pytest.mark.parametrize('scheme', ['gs://', 's3://']) @pytest.mark.parametrize('index_file_urls_pattern', [1, 2, 3, 4, 5]) -@pytest.mark.parametrize('out_format', ['local', 'remote', 'tuple']) +@pytest.mark.parametrize('out_format', ['remote', 'local', 'tuple']) @pytest.mark.usefixtures('manual_integration_dir') @pytest.mark.parametrize('keep_local', [True, False]) -def test_do_merge_index(manual_integration_dir: Any, keep_local: bool, - index_file_urls_pattern: int, out_format: str, scheme: str): +def test_merge_index_from_list(manual_integration_dir: Any, keep_local: bool, + index_file_urls_pattern: int, out_format: str, scheme: str): """Validate the final merge index json for following patterns of index_file_urls: 1. All urls are str (local). All urls are accessible locally -> no download 2. All urls are str (local). 
At least one url is unaccessible locally -> Error @@ -253,8 +252,9 @@ def not_merged_index(index_file_path: str, out: str): out = remote else: out = (local, remote) + mds_out = (local, remote) else: - out = local + mds_out = out = local spark = SparkSession.builder.getOrCreate() # pyright: ignore schema = StructType([ @@ -266,7 +266,7 @@ def not_merged_index(index_file_path: str, out: str): (3, 'Charlie', Decimal('987.65'))] df = spark.createDataFrame(data=data, schema=schema).repartition(3) mds_kwargs = { - 'out': (local, remote), + 'out': mds_out, 'columns': { 'id': 'int', 'name': 'str' @@ -279,15 +279,9 @@ def not_merged_index(index_file_path: str, out: str): local_index_files = [ o for o in local_cu.list_objects() if o.endswith('.json') and not_merged_index(o, local) ] - remote_cu = CloudUploader.get(remote, exist_ok=True, keep_local=True) - remote_index_files = [ - os.path.join(scheme, MY_BUCKET[scheme], o) - for o in remote_cu.list_objects() - if o.endswith('.json') and not_merged_index(o, remote) - ] if index_file_urls_pattern == 1: - do_merge_index(local_index_files, out, keep_local=keep_local) + merge_index_from_list(local_index_files, out, keep_local=keep_local) if index_file_urls_pattern == 2: with tempfile.TemporaryDirectory() as a_temporary_folder: @@ -295,30 +289,53 @@ def not_merged_index(index_file_path: str, out: str): os.path.join(a_temporary_folder, os.path.basename(s)) for s in local_index_files ] with pytest.raises(FileNotFoundError, match=f'.* does not exist or not accessible.*'): - do_merge_index(index_file_urls, out, keep_local=keep_local) + merge_index_from_list(index_file_urls, out, keep_local=keep_local) return if index_file_urls_pattern == 3: + remote_index_files = [ + os.path.join(scheme, MY_BUCKET[scheme], MY_PREFIX, os.path.basename(o)) + for o in local_index_files + if o.endswith('.json') and not_merged_index(o, local) + ] index_file_urls = list(zip(local_index_files, remote_index_files)) - do_merge_index(index_file_urls, out, keep_local=keep_local) + merge_index_from_list(index_file_urls, out, keep_local=keep_local) if index_file_urls_pattern == 4: + if out_format == 'local': + return + if not MANUAL_INTEGRATION_TEST: pytest.skip('Require cloud credentials. ' + 'skipping. Set MANUAL_INTEGRATION_TEST=True to run the check manually!') + remote_cu = CloudUploader.get(remote, exist_ok=True, keep_local=True) + remote_index_files = [ + os.path.join(scheme, MY_BUCKET[scheme], o) + for o in remote_cu.list_objects() + if o.endswith('.json') and not_merged_index(o, remote) + ] with tempfile.TemporaryDirectory() as a_temporary_folder: non_exist_local_files = [ os.path.join(a_temporary_folder, os.path.basename(s)) for s in local_index_files ] index_file_urls = list(zip(non_exist_local_files, remote_index_files)) - do_merge_index(index_file_urls, out, keep_local=keep_local) + merge_index_from_list(index_file_urls, out, keep_local=keep_local) if index_file_urls_pattern == 5: + if out_format == 'local': + return + if not MANUAL_INTEGRATION_TEST: pytest.skip('Require cloud credentials. ' + 'skipping. 
Set MANUAL_INTEGRATION_TEST=True to run the check manually!') - do_merge_index(remote_index_files, out, keep_local=keep_local) + remote_cu = CloudUploader.get(remote, exist_ok=True, keep_local=True) + remote_index_files = [ + os.path.join(scheme, MY_BUCKET[scheme], o) + for o in remote_cu.list_objects() + if o.endswith('.json') and not_merged_index(o, remote) + ] + merge_index_from_list(remote_index_files, out, keep_local=keep_local) integrity_check(out, keep_local=keep_local, expected_n_shard_files=2) From 05a8f324171134d203a2109b76ffe0dcfbb358e1 Mon Sep 17 00:00:00 2001 From: Xiaohan Zhang Date: Wed, 4 Oct 2023 11:12:22 -0700 Subject: [PATCH 30/59] Fix lints --- streaming/base/converters/dataframe_to_mds.py | 2 +- streaming/base/util.py | 5 ++++- tests/base/converters/test_dataframe_to_mds.py | 8 ++++---- tests/test_util.py | 15 ++++----------- 4 files changed, 13 insertions(+), 17 deletions(-) diff --git a/streaming/base/converters/dataframe_to_mds.py b/streaming/base/converters/dataframe_to_mds.py index 1e303edfb..789efb414 100644 --- a/streaming/base/converters/dataframe_to_mds.py +++ b/streaming/base/converters/dataframe_to_mds.py @@ -11,7 +11,7 @@ import pandas as pd -from streaming.base.util import merge_index_from_list, get_import_exception_message +from streaming.base.util import get_import_exception_message, merge_index_from_list try: from pyspark import TaskContext diff --git a/streaming/base/util.py b/streaming/base/util.py index a4eadcd4c..253cb34d2 100644 --- a/streaming/base/util.py +++ b/streaming/base/util.py @@ -259,7 +259,10 @@ def not_merged_index(index_file_path: str, out: str): keep_local=keep_local, download_timeout=60) else: - merge_index_from_list(remote_index_files, out, keep_local=keep_local, download_timeout=60) + merge_index_from_list(remote_index_files, + out, + keep_local=keep_local, + download_timeout=60) return merge_index_from_list(local_index_files, out, keep_local=keep_local, download_timeout=60) diff --git a/tests/base/converters/test_dataframe_to_mds.py b/tests/base/converters/test_dataframe_to_mds.py index ea5a86a48..a35cc4bb6 100644 --- a/tests/base/converters/test_dataframe_to_mds.py +++ b/tests/base/converters/test_dataframe_to_mds.py @@ -140,7 +140,7 @@ def test_end_to_end_conversion_local_nocolumns(self, dataframe: Any, keep_local: if merge_index: if keep_local: assert os.path.exists(os.path.join(out, - 'index.json')), 'No merged index.json found' + 'index.json')), 'No merged index.json found' mgi = json.load(open(os.path.join(out, 'index.json'), 'r')) nsamples = 0 for d in os.listdir(out): @@ -219,7 +219,7 @@ def test_end_to_end_conversion_local(self, dataframe: Any, keep_local: bool, mer if merge_index == True: if keep_local: assert os.path.exists(os.path.join(out, - 'index.json')), 'No merged index.json found' + 'index.json')), 'No merged index.json found' mgi = json.load(open(os.path.join(out, 'index.json'), 'r')) nsamples = 0 for d in os.listdir(out): @@ -281,7 +281,7 @@ def test_patch_conversion_local_and_remote(self, dataframe: Any, scheme: str, if merge_index == True: assert os.path.exists(os.path.join(mds_path[0], - 'index.json')), 'No merged index.json found' + 'index.json')), 'No merged index.json found' else: assert not (os.path.exists(os.path.join( mds_path[0], 'index.json'))), 'merged index is created when merge_index=False' @@ -325,7 +325,7 @@ def test_integration_conversion_local_and_remote(self, dataframe: Any, if merge_index == True: assert os.path.exists(os.path.join(mds_path[0], - 'index.json')), 'No merged index.json found' 
+ 'index.json')), 'No merged index.json found' else: assert not os.path.exists(os.path.join(mds_path[0], 'index.json')), ( f'merged index is created at {mds_path[0]} when merge_index={merge_index} and ' + diff --git a/tests/test_util.py b/tests/test_util.py index 73b13a8b7..96cbaebb6 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -16,8 +16,8 @@ from streaming.base.shared.prefix import _get_path from streaming.base.storage.download import download_file from streaming.base.storage.upload import CloudUploader -from streaming.base.util import (bytes_to_int, clean_stale_shared_memory, merge_index_from_list, - get_list_arg, merge_index, number_abbrev_to_int, retry) +from streaming.base.util import (bytes_to_int, clean_stale_shared_memory, get_list_arg, + merge_index, merge_index_from_list, number_abbrev_to_int, retry) MY_PREFIX = 'train_' + str(time.time()) MY_BUCKET = { @@ -218,7 +218,7 @@ def get_expected(mds_root: str): @pytest.mark.parametrize('scheme', ['gs://', 's3://']) @pytest.mark.parametrize('index_file_urls_pattern', [1, 2, 3, 4, 5]) -@pytest.mark.parametrize('out_format', ['remote', 'local', 'tuple']) +@pytest.mark.parametrize('out_format', ['remote', 'local', 'tuple']) @pytest.mark.usefixtures('manual_integration_dir') @pytest.mark.parametrize('keep_local', [True, False]) def test_merge_index_from_list(manual_integration_dir: Any, keep_local: bool, @@ -265,14 +265,7 @@ def not_merged_index(index_file_path: str, out: str): data = [(1, 'Alice', Decimal('123.45')), (2, 'Bob', Decimal('67.89')), (3, 'Charlie', Decimal('987.65'))] df = spark.createDataFrame(data=data, schema=schema).repartition(3) - mds_kwargs = { - 'out': mds_out, - 'columns': { - 'id': 'int', - 'name': 'str' - }, - 'keep_local': True - } + mds_kwargs = {'out': mds_out, 'columns': {'id': 'int', 'name': 'str'}, 'keep_local': True} dataframeToMDS(df, merge_index=False, mds_kwargs=mds_kwargs) local_cu = CloudUploader.get(local, exist_ok=True, keep_local=True) From ed48f867934791d86a7682ddf527ecdfc7fea6e7 Mon Sep 17 00:00:00 2001 From: Xiaohan Zhang Date: Thu, 5 Oct 2023 10:45:22 -0700 Subject: [PATCH 31/59] Change merge_index to auto_merge_index to avoid duplicate naming --- streaming/base/converters/dataframe_to_mds.py | 4 +- streaming/base/util.py | 66 +++++++++---------- tests/test_util.py | 58 ++++++++++++---- 3 files changed, 76 insertions(+), 52 deletions(-) diff --git a/streaming/base/converters/dataframe_to_mds.py b/streaming/base/converters/dataframe_to_mds.py index 789efb414..44476cf56 100644 --- a/streaming/base/converters/dataframe_to_mds.py +++ b/streaming/base/converters/dataframe_to_mds.py @@ -11,7 +11,7 @@ import pandas as pd -from streaming.base.util import get_import_exception_message, merge_index_from_list +from streaming.base.util import get_import_exception_message, auto_merge_index try: from pyspark import TaskContext @@ -256,7 +256,7 @@ def write_mds(iterator: Iterable): if merge_index: index_files = [(row['mds_path_local'], row['mds_path_remote']) for row in partitions] - merge_index_from_list(index_files, out, keep_local=keep_local, download_timeout=60) + auto_merge_index(index_files, out, keep_local=keep_local, download_timeout=60) if cu.remote is not None: if 'keep_local' in mds_kwargs and mds_kwargs['keep_local'] == False: diff --git a/streaming/base/util.py b/streaming/base/util.py index 253cb34d2..42ac8d0f9 100644 --- a/streaming/base/util.py +++ b/streaming/base/util.py @@ -214,8 +214,16 @@ def get_import_exception_message(package_name: str, extra_deps: str) -> str: f'To use 
{package_name} related packages with Streaming, run ' + \ f'`pip install \'mosaicml-streaming[{package_name}]\'`.' - -def merge_index(out: Union[str, Tuple[str, str]], *, keep_local: bool = True) -> None: +def auto_merge_index(*args, **kwargs): + if isinstance(args[0], list): + return merge_index_from_list(*args, **kwargs) + elif len(args) + len(kwargs) in [2,3,4]: + return merge_index_from_root(*args, **kwargs) + raise ValueError(f"Invalid arguments to merge_index: {args}, {kwargs}") + +def merge_index_from_root(out: Union[str, + Tuple[str, str]], + keep_local: bool = True) -> None: """Merge index.json given the root of MDS dataset. Write merged index to the root folder. Args: @@ -255,14 +263,14 @@ def not_merged_index(index_file_path: str, out: str): remote_index_files.append(obj.scheme + '://' + os.path.join(obj.netloc, o)) if len(local_index_files) == len(remote_index_files): merge_index_from_list(list(zip(local_index_files, remote_index_files)), - out, - keep_local=keep_local, - download_timeout=60) + out, + keep_local=keep_local, + download_timeout=60) else: merge_index_from_list(remote_index_files, - out, - keep_local=keep_local, - download_timeout=60) + out, + keep_local=keep_local, + download_timeout=60) return merge_index_from_list(local_index_files, out, keep_local=keep_local, download_timeout=60) @@ -309,44 +317,30 @@ def merge_index_from_list(index_file_urls: List[Any], else: urls.append((url[0].rstrip('/').strip(), url[1].rstrip('/').strip())) - # Determine if we need to call download_file. - download = False - for url in urls: - if isinstance(url, tuple): - # If driver cannot access the local path, download = True - download = not os.path.exists(url[0]) - else: - # If url is a remote, download = True, False otherwise - download = urllib.parse.urlparse(url).scheme != '' - - # As long as one index file needs download, we download them all to keep it simple - if download: - break - # Prepare a temp folder to download index.json from remote if necessary. Removed in the end. with tempfile.TemporaryDirectory() as temp_root: logging.warning(f'Create a temporary folder {temp_root} to store index files') - # container for absolute local folder path + # Copy files to a temporary directory. 
Download if necessary partitions = [] for url in urls: if isinstance(url, tuple): - local, remote = url + src = url[0] if os.path.exists(url[0]) else url[1] else: - local = remote = url + src = url - if download: - # If download is needed, download url from remote to temp_root - path = urllib.parse.urlparse(remote).path - local = os.path.join(temp_root, path.lstrip('/')) - try: - download_file(remote, local, download_timeout) - except Exception as ex: - raise RuntimeError(f'Failed to download index.json: {remote}') from ex + path = urllib.parse.urlparse(src).path + dest = os.path.join(temp_root, path.lstrip('/')) + + try: + download_file(src, dest, download_timeout) + except Exception as ex: + raise RuntimeError(f'Failed to download index.json: {src} to {dest}') from ex + + if not (os.path.exists(dest)): + raise FileNotFoundError(f'Index file {dest} does not exist or not accessible.') - if not (os.path.exists(local)): - raise FileNotFoundError(f'Index file {local} does not exist or not accessible.') - partitions.append(local) + partitions.append(dest) # merge index files into shards shards = [] diff --git a/tests/test_util.py b/tests/test_util.py index 96cbaebb6..8838f3708 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -17,14 +17,15 @@ from streaming.base.storage.download import download_file from streaming.base.storage.upload import CloudUploader from streaming.base.util import (bytes_to_int, clean_stale_shared_memory, get_list_arg, - merge_index, merge_index_from_list, number_abbrev_to_int, retry) + auto_merge_index, number_abbrev_to_int, retry) MY_PREFIX = 'train_' + str(time.time()) MY_BUCKET = { 'gs://': 'mosaicml-composer-tests', - 's3://': 'mosaicml-internal-temporary-composer-testing' + 's3://': 'mosaicml-internal-temporary-composer-testing', + 'oci://': 'mosaicml-internal-checkpoints', } -MANUAL_INTEGRATION_TEST = False +MANUAL_INTEGRATION_TEST = True os.environ[ 'OBJC_DISABLE_INITIALIZE_FORK_SAFETY'] = 'YES' # set to yes to all fork process in spark calls @@ -33,7 +34,7 @@ def manual_integration_dir() -> Any: """Creates a temporary directory and then deletes it when the calling function is done.""" if MANUAL_INTEGRATION_TEST: - os.environ['GOOGLE_APPLICATION_CREDENTIALS'] = 'path/to/gooogle_api_credential.json' + os.environ['GOOGLE_APPLICATION_CREDENTIALS'] = '/Users/xiaohan.zhang/.mosaic/mosaicml-research-nonprod-027345ddbdfd.json' # 'path/to/gooogle_api_credential.json' os.environ.pop('AWS_ACCESS_KEY_ID', None) os.environ.pop('AWS_SECRET_ACCESS_KEY', None) os.environ.pop('AWS_SECURITY_TOKEN', None) @@ -74,6 +75,35 @@ def _method(cloud_prefix: str = 'gs://') -> Tuple[str, str]: raise ImportError('boto3 is not imported correctly.') + try: + import oci + config = oci.config.from_file() + client = oci.object_storage.ObjectStorageClient(config) + bucket_name = MY_BUCKET['oci://'] + prefix_to_delete = MY_PREFIX + # List objects with the specified prefix + response = client.list_objects( + namespace_name=client.get_namespace().data, + bucket_name=bucket_name, + fields=["name"], + prefix=prefix_to_delete, + ) + + # Delete the objects + for obj in response.data.objects: + object_name = obj.name + client.delete_object( + namespace_name=client.get_namespace().data, + bucket_name=bucket_name, + object_name=object_name, + ) + print(f"Deleted: {object_name}") + + print(f"Deleted {len(response.data.objects)} objects with prefix: {prefix_to_delete}") + + except ImportError: + raise ImportError('boto3 is not imported correctly.') + @pytest.mark.parametrize(('text', 
'expected_output'), [('hello,world', ['hello', 'world']), ('hello', ['hello']), ('', [])]) def test_get_list_arg(text: str, expected_output: List[Optional[str]]): @@ -216,7 +246,7 @@ def get_expected(mds_root: str): assert n_shard_files == expected_n_shard_files, f'expected {expected_n_shard_files} shard files but got {n_shard_files}' -@pytest.mark.parametrize('scheme', ['gs://', 's3://']) +@pytest.mark.parametrize('scheme', ['oci://', 'gs://', 's3://']) @pytest.mark.parametrize('index_file_urls_pattern', [1, 2, 3, 4, 5]) @pytest.mark.parametrize('out_format', ['remote', 'local', 'tuple']) @pytest.mark.usefixtures('manual_integration_dir') @@ -274,15 +304,15 @@ def not_merged_index(index_file_path: str, out: str): ] if index_file_urls_pattern == 1: - merge_index_from_list(local_index_files, out, keep_local=keep_local) + auto_merge_index(local_index_files, out, keep_local=keep_local) if index_file_urls_pattern == 2: with tempfile.TemporaryDirectory() as a_temporary_folder: index_file_urls = [ os.path.join(a_temporary_folder, os.path.basename(s)) for s in local_index_files ] - with pytest.raises(FileNotFoundError, match=f'.* does not exist or not accessible.*'): - merge_index_from_list(index_file_urls, out, keep_local=keep_local) + with pytest.raises(RuntimeError, match=f'.*Failed to download index.json.*'): + auto_merge_index(index_file_urls, out, keep_local=keep_local) return if index_file_urls_pattern == 3: @@ -292,7 +322,7 @@ def not_merged_index(index_file_path: str, out: str): if o.endswith('.json') and not_merged_index(o, local) ] index_file_urls = list(zip(local_index_files, remote_index_files)) - merge_index_from_list(index_file_urls, out, keep_local=keep_local) + auto_merge_index(index_file_urls, out, keep_local=keep_local) if index_file_urls_pattern == 4: if out_format == 'local': @@ -313,7 +343,7 @@ def not_merged_index(index_file_path: str, out: str): os.path.join(a_temporary_folder, os.path.basename(s)) for s in local_index_files ] index_file_urls = list(zip(non_exist_local_files, remote_index_files)) - merge_index_from_list(index_file_urls, out, keep_local=keep_local) + auto_merge_index(index_file_urls, out, keep_local=keep_local) if index_file_urls_pattern == 5: if out_format == 'local': @@ -328,16 +358,16 @@ def not_merged_index(index_file_path: str, out: str): for o in remote_cu.list_objects() if o.endswith('.json') and not_merged_index(o, remote) ] - merge_index_from_list(remote_index_files, out, keep_local=keep_local) + auto_merge_index(remote_index_files, out, keep_local=keep_local) integrity_check(out, keep_local=keep_local, expected_n_shard_files=2) -@pytest.mark.parametrize('scheme', ['gs://', 's3://']) +@pytest.mark.parametrize('scheme', ['oci://', 'gs://', 's3://']) @pytest.mark.parametrize('out_format', ['remote', 'local', 'tuple']) @pytest.mark.parametrize('n_partitions', [1, 2, 3, 4]) @pytest.mark.parametrize('keep_local', [False, True]) -def test_merge_index(manual_integration_dir: Any, out_format: str, n_partitions: int, +def test_merge_index_from_root(manual_integration_dir: Any, out_format: str, n_partitions: int, keep_local: bool, scheme: str): from decimal import Decimal @@ -372,7 +402,7 @@ def test_merge_index(manual_integration_dir: Any, out_format: str, n_partitions: mds_kwargs = {'out': out, 'columns': {'id': 'int', 'name': 'str'}, 'keep_local': keep_local} mds_path, _ = dataframeToMDS(df, merge_index=False, mds_kwargs=mds_kwargs) - merge_index(mds_path, keep_local=keep_local) + auto_merge_index(mds_path, keep_local=keep_local) integrity_check(mds_path, 
keep_local=keep_local) From 34799a814a59b4d618e948ea8da7f661be5d2171 Mon Sep 17 00:00:00 2001 From: Xiaohan Zhang Date: Thu, 5 Oct 2023 13:03:38 -0700 Subject: [PATCH 32/59] update pytest yaml --- .github/workflows/pytest.yaml | 15 +++++++------ .../base/converters/test_dataframe_to_mds.py | 22 +++++++++++++++++++ tests/test_util.py | 21 ++++++------------ 3 files changed, 37 insertions(+), 21 deletions(-) diff --git a/.github/workflows/pytest.yaml b/.github/workflows/pytest.yaml index bd9518611..4a2d49d9d 100644 --- a/.github/workflows/pytest.yaml +++ b/.github/workflows/pytest.yaml @@ -47,10 +47,11 @@ jobs: id: tests run: | set -ex - pytest --splits 7 --group 1 --cov-fail-under=10 - pytest --splits 7 --group 2 --cov-fail-under=10 - pytest --splits 7 --group 3 --cov-fail-under=10 - pytest --splits 7 --group 4 --cov-fail-under=10 - pytest --splits 7 --group 5 --cov-fail-under=10 - pytest --splits 7 --group 6 --cov-fail-under=10 - pytest --splits 7 --group 7 --cov-fail-under=10 + pytest --splits 8 --group 1 --cov-fail-under=10 + pytest --splits 8 --group 2 --cov-fail-under=10 + pytest --splits 8 --group 3 --cov-fail-under=10 + pytest --splits 8 --group 4 --cov-fail-under=10 + pytest --splits 8 --group 5 --cov-fail-under=10 + pytest --splits 8 --group 6 --cov-fail-under=10 + pytest --splits 8 --group 7 --cov-fail-under=10 + pytest --splits 8 --group 8 --cov-fail-under=10 diff --git a/tests/base/converters/test_dataframe_to_mds.py b/tests/base/converters/test_dataframe_to_mds.py index a35cc4bb6..f19137616 100644 --- a/tests/base/converters/test_dataframe_to_mds.py +++ b/tests/base/converters/test_dataframe_to_mds.py @@ -70,6 +70,28 @@ def _method(cloud_prefix: str = 'gs://') -> Tuple[str, str]: except ImportError: raise ImportError('boto3 is not imported correctly.') + try: + import oci + client = oci.object_storage.ObjectStorageClient(oci.config.from_file()) + response = client.list_objects( + namespace_name=client.get_namespace().data, + bucket_name=MY_BUCKET['oci://'], + fields=["name"], + prefix=MY_PREFIX, + ) + + # Delete the objects + for obj in response.data.objects: + client.delete_object( + namespace_name=client.get_namespace().data, + bucket_name=bucket_name, + object_name=obj.name, + ) + print(f"Deleted {len(response.data.objects)} objects with prefix: {MY_PREFIX}") + + except ImportError: + raise ImportError('boto3 is not imported correctly.') + class TestDataFrameToMDS: diff --git a/tests/test_util.py b/tests/test_util.py index 8838f3708..39ab52674 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -74,36 +74,29 @@ def _method(cloud_prefix: str = 'gs://') -> Tuple[str, str]: except ImportError: raise ImportError('boto3 is not imported correctly.') - try: import oci - config = oci.config.from_file() - client = oci.object_storage.ObjectStorageClient(config) - bucket_name = MY_BUCKET['oci://'] - prefix_to_delete = MY_PREFIX - # List objects with the specified prefix + client = oci.object_storage.ObjectStorageClient(oci.config.from_file()) response = client.list_objects( namespace_name=client.get_namespace().data, - bucket_name=bucket_name, + bucket_name=MY_BUCKET['oci://'], fields=["name"], - prefix=prefix_to_delete, + prefix=MY_PREFIX, ) # Delete the objects for obj in response.data.objects: - object_name = obj.name client.delete_object( namespace_name=client.get_namespace().data, - bucket_name=bucket_name, - object_name=object_name, + bucket_name=MY_BUCKET['oci://'], + object_name=obj.name, ) - print(f"Deleted: {object_name}") - - print(f"Deleted 
{len(response.data.objects)} objects with prefix: {prefix_to_delete}") + print(f"Deleted {len(response.data.objects)} objects with prefix: {MY_PREFIX}") except ImportError: raise ImportError('boto3 is not imported correctly.') + @pytest.mark.parametrize(('text', 'expected_output'), [('hello,world', ['hello', 'world']), ('hello', ['hello']), ('', [])]) def test_get_list_arg(text: str, expected_output: List[Optional[str]]): From 85a4a6db0c0adf9ff4039ea40474ab52e92cb3f5 Mon Sep 17 00:00:00 2001 From: Xiaohan Zhang Date: Thu, 5 Oct 2023 13:22:05 -0700 Subject: [PATCH 33/59] update --- tests/test_util.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_util.py b/tests/test_util.py index 39ab52674..e7f6aaec5 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -25,7 +25,7 @@ 's3://': 'mosaicml-internal-temporary-composer-testing', 'oci://': 'mosaicml-internal-checkpoints', } -MANUAL_INTEGRATION_TEST = True +MANUAL_INTEGRATION_TEST = False os.environ[ 'OBJC_DISABLE_INITIALIZE_FORK_SAFETY'] = 'YES' # set to yes to all fork process in spark calls From 8a4d43b1fbb00f3fc12d5760cf7377bb0ca0034a Mon Sep 17 00:00:00 2001 From: Xiaohan Zhang Date: Thu, 5 Oct 2023 13:55:22 -0700 Subject: [PATCH 34/59] update --- tests/test_util.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_util.py b/tests/test_util.py index e7f6aaec5..a3cea795b 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -353,7 +353,7 @@ def not_merged_index(index_file_path: str, out: str): ] auto_merge_index(remote_index_files, out, keep_local=keep_local) - integrity_check(out, keep_local=keep_local, expected_n_shard_files=2) + integrity_check(out, keep_local=keep_local) @pytest.mark.parametrize('scheme', ['oci://', 'gs://', 's3://']) From 5f1a63b1729be810f55edfd68348684f0fd49ab6 Mon Sep 17 00:00:00 2001 From: Xiaohan Zhang Date: Thu, 5 Oct 2023 14:45:50 -0700 Subject: [PATCH 35/59] Fix lints --- streaming/base/converters/dataframe_to_mds.py | 2 +- streaming/base/util.py | 24 +++++++++---------- .../base/converters/test_dataframe_to_mds.py | 6 ++--- tests/test_util.py | 13 +++++----- 4 files changed, 23 insertions(+), 22 deletions(-) diff --git a/streaming/base/converters/dataframe_to_mds.py b/streaming/base/converters/dataframe_to_mds.py index 44476cf56..a904c1f33 100644 --- a/streaming/base/converters/dataframe_to_mds.py +++ b/streaming/base/converters/dataframe_to_mds.py @@ -11,7 +11,7 @@ import pandas as pd -from streaming.base.util import get_import_exception_message, auto_merge_index +from streaming.base.util import auto_merge_index, get_import_exception_message try: from pyspark import TaskContext diff --git a/streaming/base/util.py b/streaming/base/util.py index 42ac8d0f9..ac48a57e7 100644 --- a/streaming/base/util.py +++ b/streaming/base/util.py @@ -214,16 +214,16 @@ def get_import_exception_message(package_name: str, extra_deps: str) -> str: f'To use {package_name} related packages with Streaming, run ' + \ f'`pip install \'mosaicml-streaming[{package_name}]\'`.' 
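# A hedged usage sketch, not part of the patch above: the two call shapes the
# `auto_merge_index` dispatcher just below is written to accept. Bucket names
# and prefixes here are hypothetical placeholders, not values from this repo.
from streaming.base.util import auto_merge_index

# Shape 1: an explicit list of per-partition index.json URLs plus `out`;
# this routes to merge_index_from_list().
auto_merge_index(
    ['s3://my-bucket/mds/0/index.json', 's3://my-bucket/mds/1/index.json'],
    out='s3://my-bucket/mds',
    keep_local=True)

# Shape 2: only the dataset root (a str or a (local, remote) tuple);
# this routes to merge_index_from_root().
auto_merge_index('s3://my-bucket/mds', keep_local=True)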
-def auto_merge_index(*args, **kwargs): + +def auto_merge_index(*args, **kwargs): # pyright: ignore if isinstance(args[0], list): return merge_index_from_list(*args, **kwargs) - elif len(args) + len(kwargs) in [2,3,4]: + elif len(args) + len(kwargs) in [2, 3, 4]: return merge_index_from_root(*args, **kwargs) - raise ValueError(f"Invalid arguments to merge_index: {args}, {kwargs}") + raise ValueError(f'Invalid arguments to merge_index: {args}, {kwargs}') + -def merge_index_from_root(out: Union[str, - Tuple[str, str]], - keep_local: bool = True) -> None: +def merge_index_from_root(out: Union[str, Tuple[str, str]], keep_local: bool = True) -> None: """Merge index.json given the root of MDS dataset. Write merged index to the root folder. Args: @@ -263,14 +263,14 @@ def not_merged_index(index_file_path: str, out: str): remote_index_files.append(obj.scheme + '://' + os.path.join(obj.netloc, o)) if len(local_index_files) == len(remote_index_files): merge_index_from_list(list(zip(local_index_files, remote_index_files)), - out, - keep_local=keep_local, - download_timeout=60) + out, + keep_local=keep_local, + download_timeout=60) else: merge_index_from_list(remote_index_files, - out, - keep_local=keep_local, - download_timeout=60) + out, + keep_local=keep_local, + download_timeout=60) return merge_index_from_list(local_index_files, out, keep_local=keep_local, download_timeout=60) diff --git a/tests/base/converters/test_dataframe_to_mds.py b/tests/base/converters/test_dataframe_to_mds.py index f19137616..9a0351a02 100644 --- a/tests/base/converters/test_dataframe_to_mds.py +++ b/tests/base/converters/test_dataframe_to_mds.py @@ -76,7 +76,7 @@ def _method(cloud_prefix: str = 'gs://') -> Tuple[str, str]: response = client.list_objects( namespace_name=client.get_namespace().data, bucket_name=MY_BUCKET['oci://'], - fields=["name"], + fields=['name'], prefix=MY_PREFIX, ) @@ -84,10 +84,10 @@ def _method(cloud_prefix: str = 'gs://') -> Tuple[str, str]: for obj in response.data.objects: client.delete_object( namespace_name=client.get_namespace().data, - bucket_name=bucket_name, + bucket_name=MY_BUCKET['oci://'], object_name=obj.name, ) - print(f"Deleted {len(response.data.objects)} objects with prefix: {MY_PREFIX}") + print(f'Deleted {len(response.data.objects)} objects with prefix: {MY_PREFIX}') except ImportError: raise ImportError('boto3 is not imported correctly.') diff --git a/tests/test_util.py b/tests/test_util.py index a3cea795b..8d4ad32ab 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -16,8 +16,8 @@ from streaming.base.shared.prefix import _get_path from streaming.base.storage.download import download_file from streaming.base.storage.upload import CloudUploader -from streaming.base.util import (bytes_to_int, clean_stale_shared_memory, get_list_arg, - auto_merge_index, number_abbrev_to_int, retry) +from streaming.base.util import (auto_merge_index, bytes_to_int, clean_stale_shared_memory, + get_list_arg, number_abbrev_to_int, retry) MY_PREFIX = 'train_' + str(time.time()) MY_BUCKET = { @@ -34,7 +34,8 @@ def manual_integration_dir() -> Any: """Creates a temporary directory and then deletes it when the calling function is done.""" if MANUAL_INTEGRATION_TEST: - os.environ['GOOGLE_APPLICATION_CREDENTIALS'] = '/Users/xiaohan.zhang/.mosaic/mosaicml-research-nonprod-027345ddbdfd.json' # 'path/to/gooogle_api_credential.json' + os.environ[ + 'GOOGLE_APPLICATION_CREDENTIALS'] = '/Users/xiaohan.zhang/.mosaic/mosaicml-research-nonprod-027345ddbdfd.json' # 'path/to/gooogle_api_credential.json' 
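# A hedged aside, not part of the patch: the hard-coded path above is the
# author's local service-account key; anyone reproducing the manual integration
# run would point the variable at their own credentials, e.g. (hypothetical path):
#
#     os.environ['GOOGLE_APPLICATION_CREDENTIALS'] = '/path/to/your_gcs_service_account.json'
#
# The os.environ.pop(...) calls that follow clear AWS credentials left in the
# shell, presumably so the boto3 clients used by these tests pick up the
# intended credential chain rather than stale values.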
os.environ.pop('AWS_ACCESS_KEY_ID', None) os.environ.pop('AWS_SECRET_ACCESS_KEY', None) os.environ.pop('AWS_SECURITY_TOKEN', None) @@ -80,7 +81,7 @@ def _method(cloud_prefix: str = 'gs://') -> Tuple[str, str]: response = client.list_objects( namespace_name=client.get_namespace().data, bucket_name=MY_BUCKET['oci://'], - fields=["name"], + fields=['name'], prefix=MY_PREFIX, ) @@ -91,7 +92,7 @@ def _method(cloud_prefix: str = 'gs://') -> Tuple[str, str]: bucket_name=MY_BUCKET['oci://'], object_name=obj.name, ) - print(f"Deleted {len(response.data.objects)} objects with prefix: {MY_PREFIX}") + print(f'Deleted {len(response.data.objects)} objects with prefix: {MY_PREFIX}') except ImportError: raise ImportError('boto3 is not imported correctly.') @@ -361,7 +362,7 @@ def not_merged_index(index_file_path: str, out: str): @pytest.mark.parametrize('n_partitions', [1, 2, 3, 4]) @pytest.mark.parametrize('keep_local', [False, True]) def test_merge_index_from_root(manual_integration_dir: Any, out_format: str, n_partitions: int, - keep_local: bool, scheme: str): + keep_local: bool, scheme: str): from decimal import Decimal from pyspark.sql import SparkSession From 320fa8d6846dd1007f2090352c79660d6a4b237d Mon Sep 17 00:00:00 2001 From: Xiaohan Zhang Date: Thu, 5 Oct 2023 22:15:56 -0700 Subject: [PATCH 36/59] Make merge_index a wrapper --- streaming/base/converters/dataframe_to_mds.py | 5 +- streaming/base/util.py | 125 +++++++++--------- tests/test_util.py | 16 +-- 3 files changed, 74 insertions(+), 72 deletions(-) diff --git a/streaming/base/converters/dataframe_to_mds.py b/streaming/base/converters/dataframe_to_mds.py index a904c1f33..a4ead0a47 100644 --- a/streaming/base/converters/dataframe_to_mds.py +++ b/streaming/base/converters/dataframe_to_mds.py @@ -11,7 +11,8 @@ import pandas as pd -from streaming.base.util import auto_merge_index, get_import_exception_message +from streaming.base.util import get_import_exception_message +from streaming.base.util import merge_index as do_merge_index try: from pyspark import TaskContext @@ -256,7 +257,7 @@ def write_mds(iterator: Iterable): if merge_index: index_files = [(row['mds_path_local'], row['mds_path_remote']) for row in partitions] - auto_merge_index(index_files, out, keep_local=keep_local, download_timeout=60) + do_merge_index(index_files, out, keep_local=keep_local, download_timeout=60) if cu.remote is not None: if 'keep_local' in mds_kwargs and mds_kwargs['keep_local'] == False: diff --git a/streaming/base/util.py b/streaming/base/util.py index ac48a57e7..2a35fc842 100644 --- a/streaming/base/util.py +++ b/streaming/base/util.py @@ -215,71 +215,19 @@ def get_import_exception_message(package_name: str, extra_deps: str) -> str: f'`pip install \'mosaicml-streaming[{package_name}]\'`.' -def auto_merge_index(*args, **kwargs): # pyright: ignore - if isinstance(args[0], list): - return merge_index_from_list(*args, **kwargs) - elif len(args) + len(kwargs) in [2, 3, 4]: - return merge_index_from_root(*args, **kwargs) +def merge_index(*args, **kwargs): # pyright: ignore + if isinstance(args[0], list) and len(args) + len(kwargs) in [2, 3, 4]: + return _merge_index_from_list(*args, **kwargs) + elif (isinstance(args[0], str) or + isinstance(args[0], tuple)) and len(args) + len(kwargs) in [1, 2]: + return _merge_index_from_root(*args, **kwargs) raise ValueError(f'Invalid arguments to merge_index: {args}, {kwargs}') -def merge_index_from_root(out: Union[str, Tuple[str, str]], keep_local: bool = True) -> None: - """Merge index.json given the root of MDS dataset. 
Write merged index to the root folder. - - Args: - out (Union[str, Tuple[str,str]]): folder that contain MDS partitions. - :a local directory, merge index happens locally - :a remote directory, download all the sub-directories index.json in a temporary sub-directories, merge locally, and then upload it to out location - :a (local_dir, remote_dir), check if sub-directories index.json file present locally - If yes, then merge locally and upload to remote_dir . - If not, download all the sub-directories index.json from remote to local , merge locally, and upload to remote_dir . - keep_local (bool): Keep local copy of the merged index file. Defaults to ``True`` - download_timeout (int): The allowed time for downloading each json file. Defaults to 60s. - """ - from streaming.base.storage.upload import CloudUploader - - def not_merged_index(index_file_path: str, out: str): - """Check if index_file_path is the merged index at folder out.""" - prefix = str(urllib.parse.urlparse(out).path) - return os.path.dirname(index_file_path).strip('/') != prefix.strip('/') - - if not out: - logger.warning('No MDS dataset folder specified, no index merged') - return - - cu = CloudUploader.get(out, exist_ok=True, keep_local=True) - - local_index_files = [] - cl = CloudUploader.get(cu.local, exist_ok=True, keep_local=True) - for o in cl.list_objects(): - if o.endswith('.json') and not_merged_index(o, cu.local): - local_index_files.append(o) - - if cu.remote: - obj = urllib.parse.urlparse(cu.remote) - remote_index_files = [] - for o in cu.list_objects(): - if o.endswith(get_index_basename()) and not_merged_index(o, cu.remote): - remote_index_files.append(obj.scheme + '://' + os.path.join(obj.netloc, o)) - if len(local_index_files) == len(remote_index_files): - merge_index_from_list(list(zip(local_index_files, remote_index_files)), - out, - keep_local=keep_local, - download_timeout=60) - else: - merge_index_from_list(remote_index_files, - out, - keep_local=keep_local, - download_timeout=60) - return - - merge_index_from_list(local_index_files, out, keep_local=keep_local, download_timeout=60) - - -def merge_index_from_list(index_file_urls: List[Any], - out: Union[str, Tuple[str, str]], - keep_local: bool = True, - download_timeout: int = 60) -> None: +def _merge_index_from_list(index_file_urls: List[Any], + out: Union[str, Tuple[str, str]], + keep_local: bool = True, + download_timeout: int = 60) -> None: """Merge index.json from a list of index.json. Write to `out`, overwriting if exists. Args: @@ -376,6 +324,59 @@ def merge_index_from_list(index_file_urls: List[Any], shutil.rmtree(cu.local, ignore_errors=True) +def _merge_index_from_root(out: Union[str, Tuple[str, str]], keep_local: bool = True) -> None: + """Merge index.json given the root of MDS dataset. Write merged index to the root folder. + + Args: + out (Union[str, Tuple[str,str]]): folder that contain MDS partitions. + :a local directory, merge index happens locally + :a remote directory, download all the sub-directories index.json in a temporary sub-directories, merge locally, and then upload it to out location + :a (local_dir, remote_dir), check if sub-directories index.json file present locally + If yes, then merge locally and upload to remote_dir . + If not, download all the sub-directories index.json from remote to local , merge locally, and upload to remote_dir . + keep_local (bool): Keep local copy of the merged index file. Defaults to ``True`` + download_timeout (int): The allowed time for downloading each json file. Defaults to 60s. 
+ """ + from streaming.base.storage.upload import CloudUploader + + def not_merged_index(index_file_path: str, out: str): + """Check if index_file_path is the merged index at folder out.""" + prefix = str(urllib.parse.urlparse(out).path) + return os.path.dirname(index_file_path).strip('/') != prefix.strip('/') + + if not out: + logger.warning('No MDS dataset folder specified, no index merged') + return + + cu = CloudUploader.get(out, exist_ok=True, keep_local=True) + + local_index_files = [] + cl = CloudUploader.get(cu.local, exist_ok=True, keep_local=True) + for o in cl.list_objects(): + if o.endswith('.json') and not_merged_index(o, cu.local): + local_index_files.append(o) + + if cu.remote: + obj = urllib.parse.urlparse(cu.remote) + remote_index_files = [] + for o in cu.list_objects(): + if o.endswith(get_index_basename()) and not_merged_index(o, cu.remote): + remote_index_files.append(obj.scheme + '://' + os.path.join(obj.netloc, o)) + if len(local_index_files) == len(remote_index_files): + _merge_index_from_list(list(zip(local_index_files, remote_index_files)), + out, + keep_local=keep_local, + download_timeout=60) + else: + _merge_index_from_list(remote_index_files, + out, + keep_local=keep_local, + download_timeout=60) + return + + _merge_index_from_list(local_index_files, out, keep_local=keep_local, download_timeout=60) + + @overload def retry( exc_class: Union[Type[Exception], Sequence[Type[Exception]]] = ..., diff --git a/tests/test_util.py b/tests/test_util.py index 8d4ad32ab..419400b41 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -16,8 +16,8 @@ from streaming.base.shared.prefix import _get_path from streaming.base.storage.download import download_file from streaming.base.storage.upload import CloudUploader -from streaming.base.util import (auto_merge_index, bytes_to_int, clean_stale_shared_memory, - get_list_arg, number_abbrev_to_int, retry) +from streaming.base.util import (bytes_to_int, clean_stale_shared_memory, get_list_arg, + merge_index, number_abbrev_to_int, retry) MY_PREFIX = 'train_' + str(time.time()) MY_BUCKET = { @@ -298,7 +298,7 @@ def not_merged_index(index_file_path: str, out: str): ] if index_file_urls_pattern == 1: - auto_merge_index(local_index_files, out, keep_local=keep_local) + merge_index(local_index_files, out, keep_local=keep_local) if index_file_urls_pattern == 2: with tempfile.TemporaryDirectory() as a_temporary_folder: @@ -306,7 +306,7 @@ def not_merged_index(index_file_path: str, out: str): os.path.join(a_temporary_folder, os.path.basename(s)) for s in local_index_files ] with pytest.raises(RuntimeError, match=f'.*Failed to download index.json.*'): - auto_merge_index(index_file_urls, out, keep_local=keep_local) + merge_index(index_file_urls, out, keep_local=keep_local) return if index_file_urls_pattern == 3: @@ -316,7 +316,7 @@ def not_merged_index(index_file_path: str, out: str): if o.endswith('.json') and not_merged_index(o, local) ] index_file_urls = list(zip(local_index_files, remote_index_files)) - auto_merge_index(index_file_urls, out, keep_local=keep_local) + merge_index(index_file_urls, out, keep_local=keep_local) if index_file_urls_pattern == 4: if out_format == 'local': @@ -337,7 +337,7 @@ def not_merged_index(index_file_path: str, out: str): os.path.join(a_temporary_folder, os.path.basename(s)) for s in local_index_files ] index_file_urls = list(zip(non_exist_local_files, remote_index_files)) - auto_merge_index(index_file_urls, out, keep_local=keep_local) + merge_index(index_file_urls, out, keep_local=keep_local) if 
index_file_urls_pattern == 5: if out_format == 'local': @@ -352,7 +352,7 @@ def not_merged_index(index_file_path: str, out: str): for o in remote_cu.list_objects() if o.endswith('.json') and not_merged_index(o, remote) ] - auto_merge_index(remote_index_files, out, keep_local=keep_local) + merge_index(remote_index_files, out, keep_local=keep_local) integrity_check(out, keep_local=keep_local) @@ -396,7 +396,7 @@ def test_merge_index_from_root(manual_integration_dir: Any, out_format: str, n_p mds_kwargs = {'out': out, 'columns': {'id': 'int', 'name': 'str'}, 'keep_local': keep_local} mds_path, _ = dataframeToMDS(df, merge_index=False, mds_kwargs=mds_kwargs) - auto_merge_index(mds_path, keep_local=keep_local) + merge_index(mds_path, keep_local=keep_local) integrity_check(mds_path, keep_local=keep_local) From 22a9cc4d3e399140e655588b4bfac7dcdda5a53f Mon Sep 17 00:00:00 2001 From: Xiaohan Zhang Date: Fri, 6 Oct 2023 11:16:41 -0700 Subject: [PATCH 37/59] add print --- streaming/base/util.py | 1 + 1 file changed, 1 insertion(+) diff --git a/streaming/base/util.py b/streaming/base/util.py index 2a35fc842..e1a2a714e 100644 --- a/streaming/base/util.py +++ b/streaming/base/util.py @@ -251,6 +251,7 @@ def _merge_index_from_list(index_file_urls: List[Any], logger.warning('Need to specify both index_file_urls and out. No index merged') return + print('index_file_urls = ', index_file_urls) # This is the index json file name, e.g., it is index.json as of 0.6.0 index_basename = get_index_basename() From f8afb662ad436983c51a60c7d653f5242f34b2b0 Mon Sep 17 00:00:00 2001 From: Xiaohan Zhang Date: Fri, 6 Oct 2023 11:31:54 -0700 Subject: [PATCH 38/59] Change fail msg for missing local file and invalid remote url --- streaming/base/util.py | 5 ++++- tests/test_util.py | 7 +++++++ 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/streaming/base/util.py b/streaming/base/util.py index e1a2a714e..5c9757a86 100644 --- a/streaming/base/util.py +++ b/streaming/base/util.py @@ -278,7 +278,10 @@ def _merge_index_from_list(index_file_urls: List[Any], else: src = url - path = urllib.parse.urlparse(src).path + obj = urllib.parse.urlparse(src) + scheme, bucket, path = obj.scheme, obj.netloc, obj.path + if scheme == '' and bucket=='' and path == '': + raise FileNotFoundError("Check data availability! url[0] is not accessible. 
url[1] does not have a valid url format") dest = os.path.join(temp_root, path.lstrip('/')) try: diff --git a/tests/test_util.py b/tests/test_util.py index 419400b41..b7597151f 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -307,6 +307,13 @@ def not_merged_index(index_file_path: str, out: str): ] with pytest.raises(RuntimeError, match=f'.*Failed to download index.json.*'): merge_index(index_file_urls, out, keep_local=keep_local) + + with tempfile.TemporaryDirectory() as a_temporary_folder: + index_file_urls = [ + (os.path.join(a_temporary_folder, os.path.basename(s)), '') for s in local_index_files + ] + with pytest.raises(FileNotFoundError, match=f'.*Check data availability!.*'): + merge_index(index_file_urls, out, keep_local=keep_local) return if index_file_urls_pattern == 3: From e9d82a12a6f38940d7a362d4108ba639f0ce2f1f Mon Sep 17 00:00:00 2001 From: Xiaohan Zhang Date: Fri, 6 Oct 2023 11:36:03 -0700 Subject: [PATCH 39/59] update msg --- streaming/base/util.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/streaming/base/util.py b/streaming/base/util.py index 5c9757a86..2ac2397bb 100644 --- a/streaming/base/util.py +++ b/streaming/base/util.py @@ -281,7 +281,7 @@ def _merge_index_from_list(index_file_urls: List[Any], obj = urllib.parse.urlparse(src) scheme, bucket, path = obj.scheme, obj.netloc, obj.path if scheme == '' and bucket=='' and path == '': - raise FileNotFoundError("Check data availability! url[0] is not accessible. url[1] does not have a valid url format") + raise FileNotFoundError(f"Check data availability! local index {url[0]} is not accessible. remote index {url[1]} does not have a valid url format") dest = os.path.join(temp_root, path.lstrip('/')) try: From e0e0343914240ce3a76268dd437552babf88f47d Mon Sep 17 00:00:00 2001 From: Xiaohan Zhang Date: Fri, 6 Oct 2023 11:40:08 -0700 Subject: [PATCH 40/59] remove print --- streaming/base/util.py | 1 - 1 file changed, 1 deletion(-) diff --git a/streaming/base/util.py b/streaming/base/util.py index 2ac2397bb..ee8e9b109 100644 --- a/streaming/base/util.py +++ b/streaming/base/util.py @@ -251,7 +251,6 @@ def _merge_index_from_list(index_file_urls: List[Any], logger.warning('Need to specify both index_file_urls and out. No index merged') return - print('index_file_urls = ', index_file_urls) # This is the index json file name, e.g., it is index.json as of 0.6.0 index_basename = get_index_basename() From 6b0e6d80e76309f3f3108fe944c5ebcd6b66a736 Mon Sep 17 00:00:00 2001 From: Xiaohan Zhang Date: Fri, 6 Oct 2023 12:45:54 -0700 Subject: [PATCH 41/59] Fix lints --- streaming/base/util.py | 6 ++++-- tests/test_util.py | 5 ++--- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/streaming/base/util.py b/streaming/base/util.py index ee8e9b109..afe141781 100644 --- a/streaming/base/util.py +++ b/streaming/base/util.py @@ -279,8 +279,10 @@ def _merge_index_from_list(index_file_urls: List[Any], obj = urllib.parse.urlparse(src) scheme, bucket, path = obj.scheme, obj.netloc, obj.path - if scheme == '' and bucket=='' and path == '': - raise FileNotFoundError(f"Check data availability! local index {url[0]} is not accessible. remote index {url[1]} does not have a valid url format") + if scheme == '' and bucket == '' and path == '': + raise FileNotFoundError( + f'Check data availability! local index {url[0]} is not accessible. 
remote index {url[1]} does not have a valid url format' + ) dest = os.path.join(temp_root, path.lstrip('/')) try: diff --git a/tests/test_util.py b/tests/test_util.py index b7597151f..b53237485 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -309,9 +309,8 @@ def not_merged_index(index_file_path: str, out: str): merge_index(index_file_urls, out, keep_local=keep_local) with tempfile.TemporaryDirectory() as a_temporary_folder: - index_file_urls = [ - (os.path.join(a_temporary_folder, os.path.basename(s)), '') for s in local_index_files - ] + index_file_urls = [(os.path.join(a_temporary_folder, os.path.basename(s)), '') + for s in local_index_files] with pytest.raises(FileNotFoundError, match=f'.*Check data availability!.*'): merge_index(index_file_urls, out, keep_local=keep_local) return From a95e34b48555e53fccaf776142c2da1b862ea61b Mon Sep 17 00:00:00 2001 From: Xiaohan Zhang Date: Fri, 6 Oct 2023 22:47:48 -0700 Subject: [PATCH 42/59] Add warning msg for exist_ok=True --- streaming/base/storage/upload.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/streaming/base/storage/upload.py b/streaming/base/storage/upload.py index d72472675..7a4404a02 100644 --- a/streaming/base/storage/upload.py +++ b/streaming/base/storage/upload.py @@ -164,8 +164,12 @@ def __init__(self, self.local = out[0] self.remote = out[1] - if not exist_ok and os.path.exists(self.local) and len(os.listdir(self.local)) != 0: - raise FileExistsError(f'Directory is not empty: {self.local}') + if os.path.exists(self.local) and len(os.listdir(self.local)) != 0: + if not exist_ok: + raise FileExistsError(f'Directory is not empty: {self.local}') + else: + logger.warning(f'Directory {self.local} exists and not empty. But continue to mkdir since exist_ok is set to be True.') + os.makedirs(self.local, exist_ok=True) def upload_file(self, filename: str): From 2de66e2bb90504b326367ea47ed40a2e772ec7ef Mon Sep 17 00:00:00 2001 From: Xiaohan Zhang Date: Fri, 6 Oct 2023 23:13:42 -0700 Subject: [PATCH 43/59] Address comments --- streaming/base/converters/dataframe_to_mds.py | 2 +- streaming/base/storage/upload.py | 103 +++++++++--------- streaming/base/util.py | 2 +- .../base/converters/test_dataframe_to_mds.py | 2 +- tests/test_util.py | 2 +- 5 files changed, 55 insertions(+), 56 deletions(-) diff --git a/streaming/base/converters/dataframe_to_mds.py b/streaming/base/converters/dataframe_to_mds.py index a4ead0a47..7f0d3d548 100644 --- a/streaming/base/converters/dataframe_to_mds.py +++ b/streaming/base/converters/dataframe_to_mds.py @@ -160,7 +160,7 @@ def write_mds(iterator: Iterable): else: raise RuntimeError('TaskContext.get() returns None') - if mds_path[1] == '': + if mds_path[1] == '': # only local output = os.path.join(mds_path[0], f'{id}') partition_path = (output, '') else: diff --git a/streaming/base/storage/upload.py b/streaming/base/storage/upload.py index 7a4404a02..e98a3a6d9 100644 --- a/streaming/base/storage/upload.py +++ b/streaming/base/storage/upload.py @@ -57,8 +57,8 @@ def get(cls, out: Union[str, Tuple[str, str]], keep_local: bool = False, progress_bar: bool = False, - exist_ok: bool = False, - retry: int = 2) -> Any: + retry: int = 2, + exist_ok: bool = False) -> Any: """Instantiate a cloud provider uploader or a local uploader based on remote path. Args: @@ -74,8 +74,8 @@ def get(cls, shard file or remove it after uploading. Defaults to ``False``. progress_bar (bool): Display TQDM progress bars for uploading output dataset files to a remote location. Default to ``False``. 
- exist_ok (bool): When exist_ok = False, raise error if out already exists and has contents. Defaults to ``False``. retry (int): Number of times to retry uploading a file. Defaults to ``2``. + exist_ok (bool): When exist_ok = False, raise error if out already exists and has contents. Defaults to ``False``. Returns: CloudUploader: An instance of sub-class. @@ -89,7 +89,7 @@ def get(cls, if prefix == 'dbfs:/Volumes': provider_prefix = prefix return getattr(sys.modules[__name__], - UPLOADERS[provider_prefix])(out, keep_local, progress_bar, exist_ok, retry) + UPLOADERS[provider_prefix])(out, keep_local, progress_bar, retry, exist_ok) def _validate(self, out: Union[str, Tuple[str, str]]) -> None: """Validate the `out` argument. @@ -123,8 +123,8 @@ def __init__(self, out: Union[str, Tuple[str, str]], keep_local: bool = False, progress_bar: bool = False, - exist_ok: bool = False, - retry: int = 2) -> None: + retry: int = 2, + exist_ok: bool = False) -> None: """Initialize and validate local and remote path. Args: @@ -140,8 +140,8 @@ def __init__(self, shard file or remove it after uploading. Defaults to ``False``. progress_bar (bool): Display TQDM progress bars for uploading output dataset files to a remote location. Default to ``False``. - exist_ok (bool): When exist_ok = False, raise error if out already exists and has contents. Defaults to ``False``. retry (int): Number of times to retry uploading a file. Defaults to ``2``. + exist_ok (bool): When exist_ok = False, raise error if out already exists and has contents. Defaults to ``False``. Raises: FileExistsError: Local directory must be empty. @@ -181,13 +181,13 @@ def upload_file(self, filename: str): Raises: NotImplementedError: Override this method in your sub-class. """ - raise NotImplementedError('Override this method in your sub-class') + raise NotImplementedError(f'{type(self).__name__}.upload_file is not implemented') def list_objects(self, prefix: Optional[str] = None) -> Optional[List[str]]: """List all objects in the object store with the given prefix. Args: - prefix (Optional[str], optional): The prefix to search for. Defaults to None. + prefix (Optional[str], optional): The prefix to search for. Defaults to ``None``. Returns: List[str]: A list of object names that match the prefix. @@ -220,17 +220,17 @@ class S3Uploader(CloudUploader): shard file or remove it after uploading. Defaults to ``False``. progress_bar (bool): Display TQDM progress bars for uploading output dataset files to a remote location. Default to ``False``. - exist_ok (bool): When exist_ok = False, raise error if out already exists and has contents. Defaults to ``False``. retry (int): Number of times to retry uploading a file. Defaults to ``2``. + exist_ok (bool): When exist_ok = False, raise error if out already exists and has contents. Defaults to ``False``. """ def __init__(self, out: Union[str, Tuple[str, str]], keep_local: bool = False, progress_bar: bool = False, - exist_ok: bool = False, - retry: int = 2) -> None: - super().__init__(out, keep_local, progress_bar, exist_ok, retry) + retry: int = 2, + exist_ok: bool = False) -> None: + super().__init__(out, keep_local, progress_bar, retry, exist_ok) import boto3 from botocore.config import Config @@ -297,7 +297,7 @@ def list_objects(self, prefix: Optional[str] = None) -> Optional[List[str]]: """List all objects in the S3 object store with the given prefix. Args: - prefix (Optional[str], optional): The prefix to search for. Defaults to None. + prefix (Optional[str], optional): The prefix to search for. 
Defaults to ``None``. Returns: List[str]: A list of object names that match the prefix. @@ -333,17 +333,17 @@ class GCSUploader(CloudUploader): shard file or remove it after uploading. Defaults to ``False``. progress_bar (bool): Display TQDM progress bars for uploading output dataset files to a remote location. Default to ``False``. - exist_ok (bool): When exist_ok = False, raise error if out already exists and has contents. Defaults to ``False``. retry (int): Number of times to retry uploading a file. Defaults to ``2``. + exist_ok (bool): When exist_ok = False, raise error if out already exists and has contents. Defaults to ``False``. """ def __init__(self, out: Union[str, Tuple[str, str]], keep_local: bool = False, progress_bar: bool = False, - exist_ok: bool = False, - retry: int = 2) -> None: - super().__init__(out, keep_local, progress_bar, exist_ok, retry) + retry: int = 2, + exist_ok: bool = False) -> None: + super().__init__(out, keep_local, progress_bar, retry, exist_ok) if 'GCS_KEY' in os.environ and 'GCS_SECRET' in os.environ: import boto3 @@ -482,17 +482,17 @@ class OCIUploader(CloudUploader): shard file or remove it after uploading. Defaults to ``False``. progress_bar (bool): Display TQDM progress bars for uploading output dataset files to a remote location. Default to ``False``. - exist_ok (bool): When exist_ok = False, raise error if out already exists and has contents. Defaults to ``False``. retry (int): Number of times to retry uploading a file. Defaults to ``2``. + exist_ok (bool): When exist_ok = False, raise error if out already exists and has contents. Defaults to ``False``. """ def __init__(self, out: Union[str, Tuple[str, str]], keep_local: bool = False, progress_bar: bool = False, - exist_ok: bool = False, - retry: int = 2) -> None: - super().__init__(out, keep_local, progress_bar, exist_ok, retry) + retry: int = 2, + exist_ok: bool = False) -> None: + super().__init__(out, keep_local, progress_bar, retry, exist_ok) import oci @@ -561,7 +561,7 @@ def list_objects(self, prefix: Optional[str] = None) -> Optional[List[str]]: """List all objects in the OCI object store with the given prefix. Args: - prefix (Optional[str], optional): The prefix to search for. Defaults to None. + prefix (Optional[str], optional): The prefix to search for. Defaults to ``None``. Returns: List[str]: A list of object names that match the prefix. @@ -607,17 +607,17 @@ class AzureUploader(CloudUploader): shard file or remove it after uploading. Defaults to ``False``. progress_bar (bool): Display TQDM progress bars for uploading output dataset files to a remote location. Default to ``False``. - exist_ok (bool): Throw error if out already exists and not empty. Defaults to ``False``. retry (int): Number of times to retry uploading a file. Defaults to ``2``. + exist_ok (bool): Throw error if out already exists and not empty. Defaults to ``False``. """ def __init__(self, out: Union[str, Tuple[str, str]], keep_local: bool = False, progress_bar: bool = False, - exist_ok: bool = False, - retry: int = 2) -> None: - super().__init__(out, keep_local, progress_bar, exist_ok, retry) + retry: int = 2, + exist_ok: bool = False) -> None: + super().__init__(out, keep_local, progress_bar, retry, exist_ok) from azure.storage.blob import BlobServiceClient @@ -694,17 +694,17 @@ class AzureDataLakeUploader(CloudUploader): shard file or remove it after uploading. Defaults to ``False``. progress_bar (bool): Display TQDM progress bars for uploading output dataset files to a remote location. Default to ``False``. 
- exist_ok (bool): When exist_ok = False, raise error if out already exists and has contents. Defaults to ``False``. retry (int): Number of times to retry uploading a file. Defaults to ``2``. + exist_ok (bool): When exist_ok = False, raise error if out already exists and has contents. Defaults to ``False``. """ def __init__(self, out: Union[str, Tuple[str, str]], keep_local: bool = False, progress_bar: bool = False, - exist_ok: bool = False, - retry: int = 2) -> None: - super().__init__(out, keep_local, progress_bar, exist_ok, retry) + retry: int = 2, + exist_ok: bool = False) -> None: + super().__init__(out, keep_local, progress_bar, retry, exist_ok) from azure.storage.filedatalake import DataLakeServiceClient @@ -778,17 +778,17 @@ class DatabricksUploader(CloudUploader): shard file or remove it after uploading. Defaults to ``False``. progress_bar (bool): Display TQDM progress bars for uploading output dataset files to a remote location. Default to ``False``. - exist_ok (bool): When exist_ok = False, raise error if out already exists and has contents. Defaults to ``False``. retry (int): Number of times to retry uploading a file. Defaults to ``2``. + exist_ok (bool): When exist_ok = False, raise error if out already exists and has contents. Defaults to ``False``. """ def __init__(self, out: Union[str, Tuple[str, str]], keep_local: bool = False, progress_bar: bool = False, - exist_ok: bool = False, - retry: int = 2) -> None: - super().__init__(out, keep_local, progress_bar, exist_ok, retry) + retry: int = 2, + exist_ok: bool = False) -> None: + super().__init__(out, keep_local, progress_bar, retry, exist_ok) self.client = self._create_workspace_client() def _create_workspace_client(self): @@ -816,17 +816,17 @@ class DatabricksUnityCatalogUploader(DatabricksUploader): shard file or remove it after uploading. Defaults to ``False``. progress_bar (bool): Display TQDM progress bars for uploading output dataset files to a remote location. Default to ``False``. - exist_ok (bool): When exist_ok = False, raise error if out already exists and has contents. Defaults to ``False``. retry (int): Number of times to retry uploading a file. Defaults to ``2``. + exist_ok (bool): When exist_ok = False, raise error if out already exists and has contents. Defaults to ``False``. """ def __init__(self, out: Union[str, Tuple[str, str]], keep_local: bool = False, progress_bar: bool = False, - exist_ok: bool = False, - retry: int = 2) -> None: - super().__init__(out, keep_local, progress_bar, exist_ok, retry) + retry: int = 2, + exist_ok: bool = False) -> None: + super().__init__(out, keep_local, progress_bar, retry, exist_ok) def upload_file(self, filename: str): """Upload file from local instance to Databricks Unity Catalog. @@ -864,17 +864,17 @@ class DBFSUploader(DatabricksUploader): shard file or remove it after uploading. Defaults to ``False``. progress_bar (bool): Display TQDM progress bars for uploading output dataset files to a remote location. Default to ``False``. - exist_ok (bool): When exist_ok = False, raise error if out already exists and has contents. Defaults to ``False``. retry (int): Number of times to retry uploading a file. Defaults to ``2``. + exist_ok (bool): When exist_ok = False, raise error if out already exists and has contents. Defaults to ``False``. 
""" def __init__(self, out: Union[str, Tuple[str, str]], keep_local: bool = False, progress_bar: bool = False, - exist_ok: bool = False, - retry: int = 2) -> None: - super().__init__(out, keep_local, progress_bar, exist_ok, retry) + retry: int = 2, + exist_ok: bool = False) -> None: + super().__init__(out, keep_local, progress_bar, retry, exist_ok) self.dbfs_path = self.remote.lstrip('dbfs:') # pyright: ignore self.check_folder_exists() @@ -933,17 +933,17 @@ class LocalUploader(CloudUploader): shard file or remove it after uploading. Defaults to ``False``. progress_bar (bool): Display TQDM progress bars for uploading output dataset files to a remote location. Default to ``False``. - exist_ok (bool): When exist_ok = False, raise error if out already exists and has contents. Defaults to ``False``. retry (int): Number of times to retry uploading a file. Defaults to ``2``. + exist_ok (bool): When exist_ok = False, raise error if out already exists and has contents. Defaults to ``False``. """ def __init__(self, out: Union[str, Tuple[str, str]], keep_local: bool = False, progress_bar: bool = False, - exist_ok: bool = False, - retry: int = 2) -> None: - super().__init__(out, keep_local, progress_bar, exist_ok, retry) + retry: int = 2, + exist_ok: bool = False) -> None: + super().__init__(out, keep_local, progress_bar, retry, exist_ok) # Create remote directory if it doesn't exist if self.remote: os.makedirs(self.remote, exist_ok=True) @@ -970,16 +970,15 @@ def list_objects(self, prefix: Optional[str] = None) -> List[str]: """List all objects locally with the given prefix. Args: - prefix (Optional[str], optional): The prefix to search for. Defaults to None. + prefix (Optional[str], optional): The prefix to search for. Defaults to ``None``. Returns: List[str]: A list of object names that match the prefix. """ if prefix is None: prefix = '' - ans = [] - print('I am here 100', os.path.join(self.local, prefix)) + file_paths = [] for dirpath, _, files in os.walk(os.path.join(self.local, prefix)): for file in files: - ans.append(os.path.join(dirpath, file)) - return ans + file_paths.append(os.path.join(dirpath, file)) + return file_paths diff --git a/streaming/base/util.py b/streaming/base/util.py index afe141781..2661251f6 100644 --- a/streaming/base/util.py +++ b/streaming/base/util.py @@ -228,7 +228,7 @@ def _merge_index_from_list(index_file_urls: List[Any], out: Union[str, Tuple[str, str]], keep_local: bool = True, download_timeout: int = 60) -> None: - """Merge index.json from a list of index.json. Write to `out`, overwriting if exists. + """Merge index.json from a list of subset of MDS dataset index.json to create joined index file. 
Args: index_file_urls (Union[str, Tuple[str,str]]): index.json from all the partitions diff --git a/tests/base/converters/test_dataframe_to_mds.py b/tests/base/converters/test_dataframe_to_mds.py index 9a0351a02..b86d84bc5 100644 --- a/tests/base/converters/test_dataframe_to_mds.py +++ b/tests/base/converters/test_dataframe_to_mds.py @@ -175,7 +175,7 @@ def test_end_to_end_conversion_local_nocolumns(self, dataframe: Any, keep_local: assert nsamples == sum([a['samples'] for a in mgi['shards']]) if not keep_local: assert not os.path.exists(os.path.join( - out, 'index.json')), 'merged index.json is found even keep_local = False' + out, 'index.json')), 'merged index.json is found even through keep_local = False' else: assert not os.path.exists(os.path.join( out, 'index.json')), 'merged index is created when merge_index=False' diff --git a/tests/test_util.py b/tests/test_util.py index b53237485..f3a9c340b 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -25,7 +25,7 @@ 's3://': 'mosaicml-internal-temporary-composer-testing', 'oci://': 'mosaicml-internal-checkpoints', } -MANUAL_INTEGRATION_TEST = False +MANUAL_INTEGRATION_TEST = True os.environ[ 'OBJC_DISABLE_INITIALIZE_FORK_SAFETY'] = 'YES' # set to yes to all fork process in spark calls From 8e6df9c4100823ea4d1f36fd9a021641a2a3408a Mon Sep 17 00:00:00 2001 From: Xiaohan Zhang Date: Fri, 6 Oct 2023 23:30:55 -0700 Subject: [PATCH 44/59] fix lints --- streaming/base/converters/dataframe_to_mds.py | 2 +- streaming/base/storage/upload.py | 4 +++- streaming/base/util.py | 2 +- tests/base/converters/test_dataframe_to_mds.py | 17 ++++++++++------- tests/test_util.py | 6 +++--- 5 files changed, 18 insertions(+), 13 deletions(-) diff --git a/streaming/base/converters/dataframe_to_mds.py b/streaming/base/converters/dataframe_to_mds.py index 7f0d3d548..80f1b8362 100644 --- a/streaming/base/converters/dataframe_to_mds.py +++ b/streaming/base/converters/dataframe_to_mds.py @@ -160,7 +160,7 @@ def write_mds(iterator: Iterable): else: raise RuntimeError('TaskContext.get() returns None') - if mds_path[1] == '': # only local + if mds_path[1] == '': # only local output = os.path.join(mds_path[0], f'{id}') partition_path = (output, '') else: diff --git a/streaming/base/storage/upload.py b/streaming/base/storage/upload.py index e98a3a6d9..59f3d6aaf 100644 --- a/streaming/base/storage/upload.py +++ b/streaming/base/storage/upload.py @@ -168,7 +168,9 @@ def __init__(self, if not exist_ok: raise FileExistsError(f'Directory is not empty: {self.local}') else: - logger.warning(f'Directory {self.local} exists and not empty. But continue to mkdir since exist_ok is set to be True.') + logger.warning( + f'Directory {self.local} exists and not empty. But continue to mkdir since exist_ok is set to be True.' + ) os.makedirs(self.local, exist_ok=True) diff --git a/streaming/base/util.py b/streaming/base/util.py index 2661251f6..83b65ea10 100644 --- a/streaming/base/util.py +++ b/streaming/base/util.py @@ -228,7 +228,7 @@ def _merge_index_from_list(index_file_urls: List[Any], out: Union[str, Tuple[str, str]], keep_local: bool = True, download_timeout: int = 60) -> None: - """Merge index.json from a list of subset of MDS dataset index.json to create joined index file. + """Merge index.json from a list of index files of MDS directories to create joined index. 
Args: index_file_urls (Union[str, Tuple[str,str]]): index.json from all the partitions diff --git a/tests/base/converters/test_dataframe_to_mds.py b/tests/base/converters/test_dataframe_to_mds.py index b86d84bc5..aba66a9cf 100644 --- a/tests/base/converters/test_dataframe_to_mds.py +++ b/tests/base/converters/test_dataframe_to_mds.py @@ -19,9 +19,10 @@ MY_PREFIX = 'train' MY_BUCKET = { 'gs://': 'mosaicml-composer-tests', - 's3://': 'mosaicml-internal-temporary-composer-testing' + 's3://': 'mosaicml-internal-temporary-composer-testing', + 'oci://': 'mosaicml-internal-checkpoints', } -MANUAL_INTEGRATION_TEST = False +MANUAL_INTEGRATION_TEST = True os.environ[ 'OBJC_DISABLE_INITIALIZE_FORK_SAFETY'] = 'YES' # set to yes to all fork process in spark calls @@ -30,7 +31,8 @@ def manual_integration_dir() -> Any: """Creates a temporary directory and then deletes it when the calling function is done.""" if MANUAL_INTEGRATION_TEST: - os.environ['GOOGLE_APPLICATION_CREDENTIALS'] = 'path/to/gooogle_api_credential.json' + os.environ['GOOGLE_APPLICATION_CREDENTIALS'] = os.path.join( + os.environ['HOME'], '.mosaic/mosaicml-research-gcs.json') os.environ.pop('AWS_ACCESS_KEY_ID', None) os.environ.pop('AWS_SECRET_ACCESS_KEY', None) os.environ.pop('AWS_SECURITY_TOKEN', None) @@ -175,7 +177,8 @@ def test_end_to_end_conversion_local_nocolumns(self, dataframe: Any, keep_local: assert nsamples == sum([a['samples'] for a in mgi['shards']]) if not keep_local: assert not os.path.exists(os.path.join( - out, 'index.json')), 'merged index.json is found even through keep_local = False' + out, + 'index.json')), 'merged index.json is found even through keep_local = False' else: assert not os.path.exists(os.path.join( out, 'index.json')), 'merged index is created when merge_index=False' @@ -259,7 +262,7 @@ def test_end_to_end_conversion_local(self, dataframe: Any, keep_local: bool, mer assert not os.path.exists(os.path.join( out, 'index.json')), 'merged index is created when merge_index=False' - @pytest.mark.parametrize('scheme', ['gs://', 's3://']) + @pytest.mark.parametrize('scheme', ['oci://', 'gs://', 's3://']) @pytest.mark.parametrize('keep_local', [True]) # , False]) @pytest.mark.parametrize('merge_index', [True]) # , False]) @pytest.mark.usefixtures('manual_integration_dir') @@ -308,7 +311,7 @@ def test_patch_conversion_local_and_remote(self, dataframe: Any, scheme: str, assert not (os.path.exists(os.path.join( mds_path[0], 'index.json'))), 'merged index is created when merge_index=False' - @pytest.mark.parametrize('scheme', ['gs://', 's3://']) + @pytest.mark.parametrize('scheme', ['oci://', 'gs://', 's3://']) @pytest.mark.parametrize('keep_local', [True, False]) @pytest.mark.parametrize('merge_index', [True, False]) @pytest.mark.usefixtures('manual_integration_dir') @@ -353,7 +356,7 @@ def test_integration_conversion_local_and_remote(self, dataframe: Any, f'merged index is created at {mds_path[0]} when merge_index={merge_index} and ' + f'keep_local={keep_local}') - @pytest.mark.parametrize('scheme', ['gs://', 's3://']) + @pytest.mark.parametrize('scheme', ['oci://', 'gs://', 's3://']) @pytest.mark.usefixtures('manual_integration_dir') def test_integration_conversion_remote_only(self, dataframe: Any, manual_integration_dir: Any, scheme: str): diff --git a/tests/test_util.py b/tests/test_util.py index f3a9c340b..1201b78e7 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -25,7 +25,7 @@ 's3://': 'mosaicml-internal-temporary-composer-testing', 'oci://': 'mosaicml-internal-checkpoints', } 
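# A hedged aside, not part of the patch: each parametrized scheme ('oci://',
# 'gs://', 's3://') keys into MY_BUCKET above, and a manual integration run
# (MANUAL_INTEGRATION_TEST = True) targets a remote root built roughly like the
# sketch below; the exact join used by the fixture may differ.
scheme = 'oci://'
remote_root = scheme + MY_BUCKET[scheme] + '/' + MY_PREFIX
# e.g. 'oci://mosaicml-internal-checkpoints/train_1696612345.67'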
-MANUAL_INTEGRATION_TEST = True +MANUAL_INTEGRATION_TEST = False os.environ[ 'OBJC_DISABLE_INITIALIZE_FORK_SAFETY'] = 'YES' # set to yes to all fork process in spark calls @@ -34,8 +34,8 @@ def manual_integration_dir() -> Any: """Creates a temporary directory and then deletes it when the calling function is done.""" if MANUAL_INTEGRATION_TEST: - os.environ[ - 'GOOGLE_APPLICATION_CREDENTIALS'] = '/Users/xiaohan.zhang/.mosaic/mosaicml-research-nonprod-027345ddbdfd.json' # 'path/to/gooogle_api_credential.json' + os.environ['GOOGLE_APPLICATION_CREDENTIALS'] = os.path.join( + os.environ['HOME'], '.mosaic/mosaicml-research-gcs.json') os.environ.pop('AWS_ACCESS_KEY_ID', None) os.environ.pop('AWS_SECRET_ACCESS_KEY', None) os.environ.pop('AWS_SECURITY_TOKEN', None) From 36b4369301ffbf7be57718c18c46e180d251952f Mon Sep 17 00:00:00 2001 From: Xiaohan Zhang Date: Sat, 7 Oct 2023 15:54:41 -0700 Subject: [PATCH 45/59] Turn off manual integratin --- streaming/base/storage/upload.py | 33 ++++++++++++------- streaming/base/util.py | 2 +- .../base/converters/test_dataframe_to_mds.py | 2 +- 3 files changed, 24 insertions(+), 13 deletions(-) diff --git a/streaming/base/storage/upload.py b/streaming/base/storage/upload.py index 59f3d6aaf..371256f0c 100644 --- a/streaming/base/storage/upload.py +++ b/streaming/base/storage/upload.py @@ -75,7 +75,8 @@ def get(cls, progress_bar (bool): Display TQDM progress bars for uploading output dataset files to a remote location. Default to ``False``. retry (int): Number of times to retry uploading a file. Defaults to ``2``. - exist_ok (bool): When exist_ok = False, raise error if out already exists and has contents. Defaults to ``False``. + exist_ok (bool): When exist_ok = False, raise error if the local part of ``out`` already + exists and has contents. Defaults to ``False``. Returns: CloudUploader: An instance of sub-class. @@ -141,7 +142,8 @@ def __init__(self, progress_bar (bool): Display TQDM progress bars for uploading output dataset files to a remote location. Default to ``False``. retry (int): Number of times to retry uploading a file. Defaults to ``2``. - exist_ok (bool): When exist_ok = False, raise error if out already exists and has contents. Defaults to ``False``. + exist_ok (bool): When exist_ok = False, raise error if the local part of ``out`` already + exists and has contents. Defaults to ``False``. Raises: FileExistsError: Local directory must be empty. @@ -223,7 +225,8 @@ class S3Uploader(CloudUploader): progress_bar (bool): Display TQDM progress bars for uploading output dataset files to a remote location. Default to ``False``. retry (int): Number of times to retry uploading a file. Defaults to ``2``. - exist_ok (bool): When exist_ok = False, raise error if out already exists and has contents. Defaults to ``False``. + exist_ok (bool): When exist_ok = False, raise error if the local part of ``out`` already + exists and has contents. Defaults to ``False``. """ def __init__(self, @@ -336,7 +339,8 @@ class GCSUploader(CloudUploader): progress_bar (bool): Display TQDM progress bars for uploading output dataset files to a remote location. Default to ``False``. retry (int): Number of times to retry uploading a file. Defaults to ``2``. - exist_ok (bool): When exist_ok = False, raise error if out already exists and has contents. Defaults to ``False``. + exist_ok (bool): When exist_ok = False, raise error if the local part of ``out`` already + exists and has contents. Defaults to ``False``. 
""" def __init__(self, @@ -485,7 +489,8 @@ class OCIUploader(CloudUploader): progress_bar (bool): Display TQDM progress bars for uploading output dataset files to a remote location. Default to ``False``. retry (int): Number of times to retry uploading a file. Defaults to ``2``. - exist_ok (bool): When exist_ok = False, raise error if out already exists and has contents. Defaults to ``False``. + exist_ok (bool): When exist_ok = False, raise error if the local part of ``out`` already + exists and has contents. Defaults to ``False``. """ def __init__(self, @@ -610,7 +615,8 @@ class AzureUploader(CloudUploader): progress_bar (bool): Display TQDM progress bars for uploading output dataset files to a remote location. Default to ``False``. retry (int): Number of times to retry uploading a file. Defaults to ``2``. - exist_ok (bool): Throw error if out already exists and not empty. Defaults to ``False``. + exist_ok (bool): When exist_ok = False, raise error if the local part of ``out`` already + exists and has contents. Defaults to ``False``. """ def __init__(self, @@ -697,7 +703,8 @@ class AzureDataLakeUploader(CloudUploader): progress_bar (bool): Display TQDM progress bars for uploading output dataset files to a remote location. Default to ``False``. retry (int): Number of times to retry uploading a file. Defaults to ``2``. - exist_ok (bool): When exist_ok = False, raise error if out already exists and has contents. Defaults to ``False``. + exist_ok (bool): When exist_ok = False, raise error if the local part of ``out`` already + exists and has contents. Defaults to ``False``. """ def __init__(self, @@ -781,7 +788,8 @@ class DatabricksUploader(CloudUploader): progress_bar (bool): Display TQDM progress bars for uploading output dataset files to a remote location. Default to ``False``. retry (int): Number of times to retry uploading a file. Defaults to ``2``. - exist_ok (bool): When exist_ok = False, raise error if out already exists and has contents. Defaults to ``False``. + exist_ok (bool): When exist_ok = False, raise error if the local part of ``out`` already + exists and has contents. Defaults to ``False``. """ def __init__(self, @@ -819,7 +827,8 @@ class DatabricksUnityCatalogUploader(DatabricksUploader): progress_bar (bool): Display TQDM progress bars for uploading output dataset files to a remote location. Default to ``False``. retry (int): Number of times to retry uploading a file. Defaults to ``2``. - exist_ok (bool): When exist_ok = False, raise error if out already exists and has contents. Defaults to ``False``. + exist_ok (bool): When exist_ok = False, raise error if the local part of ``out`` already + exists and has contents. Defaults to ``False``. """ def __init__(self, @@ -867,7 +876,8 @@ class DBFSUploader(DatabricksUploader): progress_bar (bool): Display TQDM progress bars for uploading output dataset files to a remote location. Default to ``False``. retry (int): Number of times to retry uploading a file. Defaults to ``2``. - exist_ok (bool): When exist_ok = False, raise error if out already exists and has contents. Defaults to ``False``. + exist_ok (bool): When exist_ok = False, raise error if the local part of ``out`` already + exists and has contents. Defaults to ``False``. """ def __init__(self, @@ -936,7 +946,8 @@ class LocalUploader(CloudUploader): progress_bar (bool): Display TQDM progress bars for uploading output dataset files to a remote location. Default to ``False``. retry (int): Number of times to retry uploading a file. Defaults to ``2``. 
- exist_ok (bool): When exist_ok = False, raise error if out already exists and has contents. Defaults to ``False``. + exist_ok (bool): When exist_ok = False, raise error if the local part of ``out`` already + exists and has contents. Defaults to ``False``. """ def __init__(self, diff --git a/streaming/base/util.py b/streaming/base/util.py index 83b65ea10..39afbec1a 100644 --- a/streaming/base/util.py +++ b/streaming/base/util.py @@ -318,8 +318,8 @@ def _merge_index_from_list(index_file_urls: List[Any], with open(merged_index_path, 'w') as outfile: json.dump(obj, outfile) + # Move merged index from temp path to local part in out # Upload merged index to remote if out has remote part - # Otherwise, move it from temp root to out location shutil.move(merged_index_path, cu.local) if cu.remote is not None: cu.upload_file(index_basename) diff --git a/tests/base/converters/test_dataframe_to_mds.py b/tests/base/converters/test_dataframe_to_mds.py index aba66a9cf..d36c6f921 100644 --- a/tests/base/converters/test_dataframe_to_mds.py +++ b/tests/base/converters/test_dataframe_to_mds.py @@ -22,7 +22,7 @@ 's3://': 'mosaicml-internal-temporary-composer-testing', 'oci://': 'mosaicml-internal-checkpoints', } -MANUAL_INTEGRATION_TEST = True +MANUAL_INTEGRATION_TEST = False os.environ[ 'OBJC_DISABLE_INITIALIZE_FORK_SAFETY'] = 'YES' # set to yes to all fork process in spark calls From 828e74d403096077984ca6a92004955a7e364b73 Mon Sep 17 00:00:00 2001 From: Xiaohan Zhang Date: Sun, 8 Oct 2023 22:26:18 -0700 Subject: [PATCH 46/59] Address comments --- streaming/base/storage/upload.py | 2 - streaming/base/util.py | 36 +++++++------ tests/test_list_objects.py | 90 -------------------------------- tests/test_upload.py | 71 ++++++++++++++++++++++++- 4 files changed, 91 insertions(+), 108 deletions(-) delete mode 100644 tests/test_list_objects.py diff --git a/streaming/base/storage/upload.py b/streaming/base/storage/upload.py index 371256f0c..ffbfef33a 100644 --- a/streaming/base/storage/upload.py +++ b/streaming/base/storage/upload.py @@ -464,8 +464,6 @@ def list_objects(self, prefix: Optional[str] = None) -> Optional[List[str]]: except KeyError: return [] elif self.authentication == GCSAuthentication.SERVICE_ACCOUNT: - #from google.cloud.storage import Blob, Bucket - #blob = Blob(str(obj.path).lstrip('/'), Bucket(self.gcs_client, obj.netloc)) prefix = os.path.join(str(obj.path).lstrip('/'), prefix) return [ b.name for b in self.gcs_client.get_bucket(bucket_name).list_blobs(prefix=prefix) diff --git a/streaming/base/util.py b/streaming/base/util.py index 39afbec1a..fc86db21d 100644 --- a/streaming/base/util.py +++ b/streaming/base/util.py @@ -224,7 +224,7 @@ def merge_index(*args, **kwargs): # pyright: ignore raise ValueError(f'Invalid arguments to merge_index: {args}, {kwargs}') -def _merge_index_from_list(index_file_urls: List[Any], +def _merge_index_from_list(index_file_urls: List[Union[str, Tuple[str, str]]], out: Union[str, Tuple[str, str]], keep_local: bool = True, download_timeout: int = 60) -> None: @@ -281,8 +281,8 @@ def _merge_index_from_list(index_file_urls: List[Any], scheme, bucket, path = obj.scheme, obj.netloc, obj.path if scheme == '' and bucket == '' and path == '': raise FileNotFoundError( - f'Check data availability! local index {url[0]} is not accessible. remote index {url[1]} does not have a valid url format' - ) + f'Check data availability! local index {url[0]} is not accessible.' 
+ + 'remote index {url[1]} does not have a valid url format') dest = os.path.join(temp_root, path.lstrip('/')) try: @@ -290,7 +290,7 @@ def _merge_index_from_list(index_file_urls: List[Any], except Exception as ex: raise RuntimeError(f'Failed to download index.json: {src} to {dest}') from ex - if not (os.path.exists(dest)): + if not os.path.exists(dest): raise FileNotFoundError(f'Index file {dest} does not exist or not accessible.') partitions.append(dest) @@ -334,18 +334,24 @@ def _merge_index_from_root(out: Union[str, Tuple[str, str]], keep_local: bool = Args: out (Union[str, Tuple[str,str]]): folder that contain MDS partitions. - :a local directory, merge index happens locally - :a remote directory, download all the sub-directories index.json in a temporary sub-directories, merge locally, and then upload it to out location - :a (local_dir, remote_dir), check if sub-directories index.json file present locally + :A local directory, merge index happens locally + :A remote directory, download all the sub-directories index.json in a temporary sub-directories, merge locally, and then upload it to out location + :A (local_dir, remote_dir), check if sub-directories index.json file present locally If yes, then merge locally and upload to remote_dir . If not, download all the sub-directories index.json from remote to local , merge locally, and upload to remote_dir . keep_local (bool): Keep local copy of the merged index file. Defaults to ``True`` - download_timeout (int): The allowed time for downloading each json file. Defaults to 60s. """ from streaming.base.storage.upload import CloudUploader def not_merged_index(index_file_path: str, out: str): - """Check if index_file_path is the merged index at folder out.""" + """Check if index_file_path is the merged index at folder out. 
+ + Args: + index_file_path (str): the path to index.json file + out (str): remote or local url of a folder + Return: + (bool): no if index.json sits in out instead of in the subfolders of out + """ prefix = str(urllib.parse.urlparse(out).path) return os.path.dirname(index_file_path).strip('/') != prefix.strip('/') @@ -357,16 +363,16 @@ def not_merged_index(index_file_path: str, out: str): local_index_files = [] cl = CloudUploader.get(cu.local, exist_ok=True, keep_local=True) - for o in cl.list_objects(): - if o.endswith('.json') and not_merged_index(o, cu.local): - local_index_files.append(o) + for file in cl.list_objects(): + if file.endswith('.json') and not_merged_index(file, cu.local): + local_index_files.append(file) if cu.remote: obj = urllib.parse.urlparse(cu.remote) remote_index_files = [] - for o in cu.list_objects(): - if o.endswith(get_index_basename()) and not_merged_index(o, cu.remote): - remote_index_files.append(obj.scheme + '://' + os.path.join(obj.netloc, o)) + for file in cu.list_objects(): + if file.endswith(get_index_basename()) and not_merged_index(file, cu.remote): + remote_index_files.append(obj.scheme + '://' + os.path.join(obj.netloc, file)) if len(local_index_files) == len(remote_index_files): _merge_index_from_list(list(zip(local_index_files, remote_index_files)), out, diff --git a/tests/test_list_objects.py b/tests/test_list_objects.py deleted file mode 100644 index f43208628..000000000 --- a/tests/test_list_objects.py +++ /dev/null @@ -1,90 +0,0 @@ -# Copyright 2023 MosaicML Streaming authors -# SPDX-License-Identifier: Apache-2.0 - -import os -import tempfile -from typing import Any, Tuple -from unittest.mock import Mock, patch - -import boto3 -import pytest - -from streaming.base.storage.upload import CloudUploader -from tests.conftest import MY_BUCKET - -MY_PREFIX = 'train' - - -@pytest.fixture(scope='function') -def remote_local_dir() -> Any: - """Creates a temporary directory and then deletes it when the calling function is done.""" - - def _method(cloud_prefix: str = '') -> Tuple[str, str]: - try: - mock_local_dir = tempfile.TemporaryDirectory() - mock_local = mock_local_dir.name - mock_remote = os.path.join(cloud_prefix, MY_BUCKET, MY_PREFIX) - return mock_remote, mock_local - finally: - mock_local_dir.cleanup() # pyright: ignore - - return _method - - -class TestS3Client: - - @pytest.mark.usefixtures('s3_client', 's3_test', 'remote_local_dir') - def test_list_objects_from_s3(self, remote_local_dir: Any): - with tempfile.NamedTemporaryFile(delete=True, suffix='.txt') as tmp: - file_name = tmp.name.split(os.sep)[-1] - mock_remote_dir, _ = remote_local_dir(cloud_prefix='s3://') - client = boto3.client('s3', region_name='us-east-1') - client.put_object(Bucket=MY_BUCKET, Key=os.path.join(MY_PREFIX, file_name), Body='') - - cu = CloudUploader.get(mock_remote_dir, exist_ok=True, keep_local=True) - objs = cu.list_objects(mock_remote_dir) - assert isinstance(objs, list) - - @pytest.mark.usefixtures('s3_client', 's3_test', 'remote_local_dir') - def test_clienterror_exception(self, remote_local_dir: Any): - mock_remote_dir, _ = remote_local_dir(cloud_prefix='s3://') - cu = CloudUploader.get(mock_remote_dir, exist_ok=True, keep_local=True) - objs = cu.list_objects() - if objs: - assert (len(objs) == 0) - - @pytest.mark.usefixtures('s3_client', 's3_test', 'remote_local_dir') - def test_invalid_cloud_prefix(self, remote_local_dir: Any): - with pytest.raises(ValueError): - mock_remote_dir, _ = remote_local_dir(cloud_prefix='s9://') - cu = 
CloudUploader.get(mock_remote_dir, exist_ok=True, keep_local=True) - _ = cu.list_objects() - - -class TestGCSClient: - - @pytest.mark.usefixtures('gcs_hmac_client', 'gcs_test', 'remote_local_dir') - def test_invalid_cloud_prefix(self, remote_local_dir: Any): - with pytest.raises(ValueError): - mock_remote_dir, _ = remote_local_dir(cloud_prefix='gs9://') - cu = CloudUploader.get(mock_remote_dir, exist_ok=True, keep_local=True) - _ = cu.list_objects() - - def test_no_credentials_error(self, remote_local_dir: Any): - """Ensure we raise a value error correctly if we have no credentials available.""" - with pytest.raises(ValueError): - mock_remote_dir, _ = remote_local_dir(cloud_prefix='gs://') - cu = CloudUploader.get(mock_remote_dir, exist_ok=True, keep_local=True) - _ = cu.list_objects() - - -class TestListObjects: - - @patch('streaming.base.storage.LocalUploader.list_objects') - @pytest.mark.usefixtures('remote_local_dir') - def test_list_objects_from_local_gets_called(self, mocked_requests: Mock, - remote_local_dir: Any): - mock_remote_dir, _ = remote_local_dir() - cu = CloudUploader.get(mock_remote_dir, exist_ok=True, keep_local=True) - cu.list_objects() - mocked_requests.assert_called_once() diff --git a/tests/test_upload.py b/tests/test_upload.py index 1c63746fd..dc4ae9201 100644 --- a/tests/test_upload.py +++ b/tests/test_upload.py @@ -7,13 +7,32 @@ from typing import Any, List, Tuple from unittest.mock import Mock, patch +import boto3 import pytest from streaming.base.storage.upload import (AzureDataLakeUploader, AzureUploader, CloudUploader, DatabricksUnityCatalogUploader, DBFSUploader, GCSAuthentication, GCSUploader, LocalUploader, S3Uploader) -from tests.conftest import R2_URL +from tests.conftest import MY_BUCKET, R2_URL + +MY_PREFIX = 'train' + + +@pytest.fixture(scope='function') +def remote_local_dir() -> Any: + """Creates a temporary directory and then deletes it when the calling function is done.""" + + def _method(cloud_prefix: str = '') -> Tuple[str, str]: + try: + mock_local_dir = tempfile.TemporaryDirectory() + mock_local = mock_local_dir.name + mock_remote = os.path.join(cloud_prefix, MY_BUCKET, MY_PREFIX) + return mock_remote, mock_local + finally: + mock_local_dir.cleanup() # pyright: ignore + + return _method class TestCloudUploader: @@ -92,6 +111,15 @@ def test_check_bucket_exists_exception(self, out: str): with pytest.raises(botocore.exceptions.ClientError): _ = CloudUploader.get(out=out) + @patch('streaming.base.storage.LocalUploader.list_objects') + @pytest.mark.usefixtures('remote_local_dir') + def test_list_objects_from_local_gets_called(self, mocked_requests: Mock, + remote_local_dir: Any): + mock_remote_dir, _ = remote_local_dir() + cu = CloudUploader.get(mock_remote_dir, exist_ok=True, keep_local=True) + cu.list_objects() + mocked_requests.assert_called_once() + class TestS3Uploader: @@ -157,6 +185,33 @@ def test_check_bucket_exists_exception(self, out: str): with pytest.raises(botocore.exceptions.ClientError): _ = S3Uploader(out=out) + @pytest.mark.usefixtures('s3_client', 's3_test', 'remote_local_dir') + def test_list_objects_from_s3(self, remote_local_dir: Any): + with tempfile.NamedTemporaryFile(delete=True, suffix='.txt') as tmp: + file_name = tmp.name.split(os.sep)[-1] + mock_remote_dir, _ = remote_local_dir(cloud_prefix='s3://') + client = boto3.client('s3', region_name='us-east-1') + client.put_object(Bucket=MY_BUCKET, Key=os.path.join(MY_PREFIX, file_name), Body='') + + cu = CloudUploader.get(mock_remote_dir, exist_ok=True, keep_local=True) + objs = 
cu.list_objects(mock_remote_dir) + assert isinstance(objs, list) + + @pytest.mark.usefixtures('s3_client', 's3_test', 'remote_local_dir') + def test_clienterror_exception(self, remote_local_dir: Any): + mock_remote_dir, _ = remote_local_dir(cloud_prefix='s3://') + cu = CloudUploader.get(mock_remote_dir, exist_ok=True, keep_local=True) + objs = cu.list_objects() + if objs: + assert (len(objs) == 0) + + @pytest.mark.usefixtures('s3_client', 's3_test', 'remote_local_dir') + def test_invalid_cloud_prefix(self, remote_local_dir: Any): + with pytest.raises(ValueError): + mock_remote_dir, _ = remote_local_dir(cloud_prefix='s9://') + cu = CloudUploader.get(mock_remote_dir, exist_ok=True, keep_local=True) + _ = cu.list_objects() + class TestGCSUploader: @@ -252,6 +307,20 @@ def test_no_authentication(self, out: str): f'https://docs.mosaicml.com/projects/mcli/en/latest/resources/secrets/gcp.html.')): _ = GCSUploader(out=out) + @pytest.mark.usefixtures('gcs_hmac_client', 'gcs_test', 'remote_local_dir') + def test_invalid_cloud_prefix(self, remote_local_dir: Any): + with pytest.raises(ValueError): + mock_remote_dir, _ = remote_local_dir(cloud_prefix='gs9://') + cu = CloudUploader.get(mock_remote_dir, exist_ok=True, keep_local=True) + _ = cu.list_objects() + + def test_no_credentials_error(self, remote_local_dir: Any): + """Ensure we raise a value error correctly if we have no credentials available.""" + with pytest.raises(ValueError): + mock_remote_dir, _ = remote_local_dir(cloud_prefix='gs://') + cu = CloudUploader.get(mock_remote_dir, exist_ok=True, keep_local=True) + _ = cu.list_objects() + class TestAzureUploader: From 8e616d85712e39e83829d60ca47ce59559cca1ec Mon Sep 17 00:00:00 2001 From: Xiaohan Zhang Date: Sun, 8 Oct 2023 22:30:44 -0700 Subject: [PATCH 47/59] Update --- tests/base/converters/test_dataframe_to_mds.py | 6 +++--- tests/test_util.py | 6 +++--- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/tests/base/converters/test_dataframe_to_mds.py b/tests/base/converters/test_dataframe_to_mds.py index d36c6f921..d56f9c3b6 100644 --- a/tests/base/converters/test_dataframe_to_mds.py +++ b/tests/base/converters/test_dataframe_to_mds.py @@ -18,9 +18,9 @@ MY_PREFIX = 'train' MY_BUCKET = { - 'gs://': 'mosaicml-composer-tests', - 's3://': 'mosaicml-internal-temporary-composer-testing', - 'oci://': 'mosaicml-internal-checkpoints', + 'gs://': 'testing-bucket', + 's3://': 'testing-bucket', + 'oci://': 'testing-bucket', } MANUAL_INTEGRATION_TEST = False os.environ[ diff --git a/tests/test_util.py b/tests/test_util.py index 1201b78e7..8f378b9af 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -21,9 +21,9 @@ MY_PREFIX = 'train_' + str(time.time()) MY_BUCKET = { - 'gs://': 'mosaicml-composer-tests', - 's3://': 'mosaicml-internal-temporary-composer-testing', - 'oci://': 'mosaicml-internal-checkpoints', + 'gs://': 'testing-bucket', + 's3://': 'testing-bucket', + 'oci://': 'testing-bucket', } MANUAL_INTEGRATION_TEST = False os.environ[ From 282973a3e7f8fdfe8d30238ff1786c6892ec1dd1 Mon Sep 17 00:00:00 2001 From: Xiaohan Zhang Date: Mon, 9 Oct 2023 22:08:37 -0700 Subject: [PATCH 48/59] updates --- streaming/base/converters/dataframe_to_mds.py | 1 + streaming/base/util.py | 4 +- .../base/converters/test_dataframe_to_mds.py | 11 +- tests/test_util.py | 168 ++++++++++++------ 4 files changed, 119 insertions(+), 65 deletions(-) diff --git a/streaming/base/converters/dataframe_to_mds.py b/streaming/base/converters/dataframe_to_mds.py index 80f1b8362..f83b796cd 100644 --- 
a/streaming/base/converters/dataframe_to_mds.py +++ b/streaming/base/converters/dataframe_to_mds.py @@ -153,6 +153,7 @@ def dataframeToMDS(dataframe: DataFrame, """ def write_mds(iterator: Iterable): + """worker node writes iterable to MDS datasets locally""" context = TaskContext.get() if context is not None: diff --git a/streaming/base/util.py b/streaming/base/util.py index fc86db21d..582725003 100644 --- a/streaming/base/util.py +++ b/streaming/base/util.py @@ -237,7 +237,7 @@ def _merge_index_from_list(index_file_urls: List[Union[str, Tuple[str, str]]], The pattern of index_file_urls and corresponding reaction is one of: 1. All urls are str (local). All urls are accessible locally -> no download 2. All urls are tuple (local, remote). All urls are accessible locally -> no download - 3. All urls are tuple (local, remote). At least one url is not accessible locally -> download all + 3. All urls are tuple (local, remote). Download url thtat is not accessible locally 4. All urls are str (remote) -> download all out (Union[str, Tuple[str, str]]): path to put the merged index file @@ -295,7 +295,7 @@ def _merge_index_from_list(index_file_urls: List[Union[str, Tuple[str, str]]], partitions.append(dest) - # merge index files into shards + # merge shards from all index files shards = [] for partition_index in partitions: p = Path(partition_index) diff --git a/tests/base/converters/test_dataframe_to_mds.py b/tests/base/converters/test_dataframe_to_mds.py index d56f9c3b6..32ce5b5b5 100644 --- a/tests/base/converters/test_dataframe_to_mds.py +++ b/tests/base/converters/test_dataframe_to_mds.py @@ -266,13 +266,10 @@ def test_end_to_end_conversion_local(self, dataframe: Any, keep_local: bool, mer @pytest.mark.parametrize('keep_local', [True]) # , False]) @pytest.mark.parametrize('merge_index', [True]) # , False]) @pytest.mark.usefixtures('manual_integration_dir') + @pytest.mark.remote def test_patch_conversion_local_and_remote(self, dataframe: Any, scheme: str, merge_index: bool, keep_local: bool, manual_integration_dir: Any): - if not MANUAL_INTEGRATION_TEST: - pytest.skip( - 'Overlap with integration tests. But better figure out how to run this test ' + - 'suite with Mock.') mock_local, mock_remote = manual_integration_dir(scheme) out = (mock_local, mock_remote) mds_kwargs = { @@ -315,12 +312,11 @@ def test_patch_conversion_local_and_remote(self, dataframe: Any, scheme: str, @pytest.mark.parametrize('keep_local', [True, False]) @pytest.mark.parametrize('merge_index', [True, False]) @pytest.mark.usefixtures('manual_integration_dir') + @pytest.mark.remote def test_integration_conversion_local_and_remote(self, dataframe: Any, manual_integration_dir: Any, merge_index: bool, keep_local: bool, scheme: str): - if not MANUAL_INTEGRATION_TEST: - pytest.skip('run local only. CI cluster does not have GCS service acct set up.') out = manual_integration_dir(scheme) mds_kwargs = { 'out': out, @@ -358,10 +354,9 @@ def test_integration_conversion_local_and_remote(self, dataframe: Any, @pytest.mark.parametrize('scheme', ['oci://', 'gs://', 's3://']) @pytest.mark.usefixtures('manual_integration_dir') + @pytest.mark.remote def test_integration_conversion_remote_only(self, dataframe: Any, manual_integration_dir: Any, scheme: str): - if not MANUAL_INTEGRATION_TEST: - pytest.skip('run local only. 
CI cluster does not have GCS service acct set up.') _, remote = manual_integration_dir('s3://') mds_kwargs = { 'out': remote, diff --git a/tests/test_util.py b/tests/test_util.py index 8f378b9af..d122cb071 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -241,11 +241,12 @@ def get_expected(mds_root: str): @pytest.mark.parametrize('scheme', ['oci://', 'gs://', 's3://']) -@pytest.mark.parametrize('index_file_urls_pattern', [1, 2, 3, 4, 5]) +@pytest.mark.parametrize('index_file_urls_pattern', [4,5]) @pytest.mark.parametrize('out_format', ['remote', 'local', 'tuple']) @pytest.mark.usefixtures('manual_integration_dir') @pytest.mark.parametrize('keep_local', [True, False]) -def test_merge_index_from_list(manual_integration_dir: Any, keep_local: bool, +@pytest.mark.remote +def test_merge_index_from_list_remote(manual_integration_dir: Any, keep_local: bool, index_file_urls_pattern: int, out_format: str, scheme: str): """Validate the final merge index json for following patterns of index_file_urls: 1. All urls are str (local). All urls are accessible locally -> no download @@ -268,17 +269,84 @@ def not_merged_index(index_file_path: str, out: str): local, remote = manual_integration_dir(scheme) - if out_format != 'local': - if not MANUAL_INTEGRATION_TEST: - pytest.skip('Require cloud credentials. ' + - 'skipping. Set MANUAL_INTEGRATION_TEST=True to run the check manually!') - if out_format == 'remote': - out = remote - else: - out = (local, remote) - mds_out = (local, remote) + if out_format == 'remote': + out = remote else: - mds_out = out = local + out = (local, remote) + mds_out = (local, remote) + + spark = SparkSession.builder.getOrCreate() # pyright: ignore + schema = StructType([ + StructField('id', IntegerType(), nullable=False), + StructField('name', StringType(), nullable=False), + StructField('amount', DecimalType(10, 2), nullable=False) + ]) + data = [(1, 'Alice', Decimal('123.45')), (2, 'Bob', Decimal('67.89')), + (3, 'Charlie', Decimal('987.65'))] + df = spark.createDataFrame(data=data, schema=schema).repartition(3) + mds_kwargs = {'out': mds_out, 'columns': {'id': 'int', 'name': 'str'}, 'keep_local': True} + dataframeToMDS(df, merge_index=False, mds_kwargs=mds_kwargs) + + local_cu = CloudUploader.get(local, exist_ok=True, keep_local=True) + local_index_files = [ + o for o in local_cu.list_objects() if o.endswith('.json') and not_merged_index(o, local) + ] + + if index_file_urls_pattern == 4: + + remote_cu = CloudUploader.get(remote, exist_ok=True, keep_local=True) + remote_index_files = [ + os.path.join(scheme, MY_BUCKET[scheme], o) + for o in remote_cu.list_objects() + if o.endswith('.json') and not_merged_index(o, remote) + ] + with tempfile.TemporaryDirectory() as a_temporary_folder: + non_exist_local_files = [ + os.path.join(a_temporary_folder, os.path.basename(s)) for s in local_index_files + ] + index_file_urls = list(zip(non_exist_local_files, remote_index_files)) + merge_index(index_file_urls, out, keep_local=keep_local) + + if index_file_urls_pattern == 5: + + remote_cu = CloudUploader.get(remote, exist_ok=True, keep_local=True) + remote_index_files = [ + os.path.join(scheme, MY_BUCKET[scheme], o) + for o in remote_cu.list_objects() + if o.endswith('.json') and not_merged_index(o, remote) + ] + merge_index(remote_index_files, out, keep_local=keep_local) + + integrity_check(out, keep_local=keep_local) + +@pytest.mark.parametrize('index_file_urls_pattern', [1, 2, 3]) +@pytest.mark.usefixtures('manual_integration_dir') +@pytest.mark.parametrize('keep_local', [True, 
False]) +def test_merge_index_from_list_local(manual_integration_dir: Any, keep_local: bool, + index_file_urls_pattern: int): + """Validate the final merge index json for following patterns of index_file_urls: + 1. All urls are str (local). All urls are accessible locally -> no download + 2. All urls are str (local). At least one url is unaccessible locally -> Error + 3. All urls are tuple (local, remote). All urls are accessible locally -> no download + 4. All urls are tuple (local, remote). At least one url is not accessible locally -> download all + 5. All urls are str (remote) -> download all + """ + from decimal import Decimal + + from pyspark.sql import SparkSession + from pyspark.sql.types import DecimalType, IntegerType, StringType, StructField, StructType + + from streaming.base.converters import dataframeToMDS + + def not_merged_index(index_file_path: str, out: str): + """Check if index_file_path is the merged index at folder out.""" + prefix = str(urllib.parse.urlparse(out).path) + return os.path.dirname(index_file_path).strip('/') != prefix.strip('/') + + local, remote = manual_integration_dir() + + mds_out = out = local + scheme = 's3://' spark = SparkSession.builder.getOrCreate() # pyright: ignore schema = StructType([ @@ -324,50 +392,46 @@ def not_merged_index(index_file_path: str, out: str): index_file_urls = list(zip(local_index_files, remote_index_files)) merge_index(index_file_urls, out, keep_local=keep_local) - if index_file_urls_pattern == 4: - if out_format == 'local': - return + integrity_check(out, keep_local=keep_local) - if not MANUAL_INTEGRATION_TEST: - pytest.skip('Require cloud credentials. ' + - 'skipping. Set MANUAL_INTEGRATION_TEST=True to run the check manually!') - remote_cu = CloudUploader.get(remote, exist_ok=True, keep_local=True) - remote_index_files = [ - os.path.join(scheme, MY_BUCKET[scheme], o) - for o in remote_cu.list_objects() - if o.endswith('.json') and not_merged_index(o, remote) - ] - with tempfile.TemporaryDirectory() as a_temporary_folder: - non_exist_local_files = [ - os.path.join(a_temporary_folder, os.path.basename(s)) for s in local_index_files - ] - index_file_urls = list(zip(non_exist_local_files, remote_index_files)) - merge_index(index_file_urls, out, keep_local=keep_local) +@pytest.mark.parametrize('n_partitions', [1, 2, 3, 4]) +@pytest.mark.parametrize('keep_local', [False, True]) +def test_merge_index_from_root_local(manual_integration_dir: Any, n_partitions: int, + keep_local: bool): + from decimal import Decimal - if index_file_urls_pattern == 5: - if out_format == 'local': - return + from pyspark.sql import SparkSession + from pyspark.sql.types import DecimalType, IntegerType, StringType, StructField, StructType - if not MANUAL_INTEGRATION_TEST: - pytest.skip('Require cloud credentials. ' + - 'skipping. 
Set MANUAL_INTEGRATION_TEST=True to run the check manually!') - remote_cu = CloudUploader.get(remote, exist_ok=True, keep_local=True) - remote_index_files = [ - os.path.join(scheme, MY_BUCKET[scheme], o) - for o in remote_cu.list_objects() - if o.endswith('.json') and not_merged_index(o, remote) - ] - merge_index(remote_index_files, out, keep_local=keep_local) + from streaming.base.converters import dataframeToMDS - integrity_check(out, keep_local=keep_local) + out, _ = manual_integration_dir() + spark = SparkSession.builder.getOrCreate() # pyright: ignore + schema = StructType([ + StructField('id', IntegerType(), nullable=False), + StructField('name', StringType(), nullable=False), + StructField('amount', DecimalType(10, 2), nullable=False) + ]) + + data = [(1, 'Alice', Decimal('123.45')), (2, 'Bob', Decimal('67.89')), + (3, 'Charlie', Decimal('987.65'))] + + df = spark.createDataFrame(data=data, schema=schema).repartition(n_partitions) + + mds_kwargs = {'out': out, 'columns': {'id': 'int', 'name': 'str'}, 'keep_local': keep_local} + + mds_path, _ = dataframeToMDS(df, merge_index=False, mds_kwargs=mds_kwargs) + merge_index(mds_path, keep_local=keep_local) + integrity_check(mds_path, keep_local=keep_local) @pytest.mark.parametrize('scheme', ['oci://', 'gs://', 's3://']) -@pytest.mark.parametrize('out_format', ['remote', 'local', 'tuple']) +@pytest.mark.parametrize('out_format', ['remote', 'tuple']) @pytest.mark.parametrize('n_partitions', [1, 2, 3, 4]) @pytest.mark.parametrize('keep_local', [False, True]) -def test_merge_index_from_root(manual_integration_dir: Any, out_format: str, n_partitions: int, +@pytest.mark.remote +def test_merge_index_from_root_remote(manual_integration_dir: Any, out_format: str, n_partitions: int, keep_local: bool, scheme: str): from decimal import Decimal @@ -376,16 +440,10 @@ def test_merge_index_from_root(manual_integration_dir: Any, out_format: str, n_p from streaming.base.converters import dataframeToMDS - if out_format == 'remote' or out_format == 'tuple': - if not MANUAL_INTEGRATION_TEST: - pytest.skip('Require cloud credentials. ' + - 'skipping. 
Set MANUAL_INTEGRATION_TEST=True to run the check manually!') - if out_format == 'remote': - _, out = manual_integration_dir(scheme) - else: - out = manual_integration_dir(scheme) + if out_format == 'remote': + _, out = manual_integration_dir(scheme) else: - out, _ = manual_integration_dir(scheme) + out = manual_integration_dir(scheme) spark = SparkSession.builder.getOrCreate() # pyright: ignore schema = StructType([ From ebacc8776def6188b395d36ffe411518188ad24d Mon Sep 17 00:00:00 2001 From: Xiaohan Zhang Date: Mon, 9 Oct 2023 22:10:16 -0700 Subject: [PATCH 49/59] Fix lints --- streaming/base/converters/dataframe_to_mds.py | 2 +- tests/test_util.py | 16 +++++++++------- 2 files changed, 10 insertions(+), 8 deletions(-) diff --git a/streaming/base/converters/dataframe_to_mds.py b/streaming/base/converters/dataframe_to_mds.py index f83b796cd..8749bbdec 100644 --- a/streaming/base/converters/dataframe_to_mds.py +++ b/streaming/base/converters/dataframe_to_mds.py @@ -153,7 +153,7 @@ def dataframeToMDS(dataframe: DataFrame, """ def write_mds(iterator: Iterable): - """worker node writes iterable to MDS datasets locally""" + """Worker node writes iterable to MDS datasets locally.""" context = TaskContext.get() if context is not None: diff --git a/tests/test_util.py b/tests/test_util.py index d122cb071..a9af5e805 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -241,13 +241,13 @@ def get_expected(mds_root: str): @pytest.mark.parametrize('scheme', ['oci://', 'gs://', 's3://']) -@pytest.mark.parametrize('index_file_urls_pattern', [4,5]) +@pytest.mark.parametrize('index_file_urls_pattern', [4, 5]) @pytest.mark.parametrize('out_format', ['remote', 'local', 'tuple']) @pytest.mark.usefixtures('manual_integration_dir') @pytest.mark.parametrize('keep_local', [True, False]) @pytest.mark.remote def test_merge_index_from_list_remote(manual_integration_dir: Any, keep_local: bool, - index_file_urls_pattern: int, out_format: str, scheme: str): + index_file_urls_pattern: int, out_format: str, scheme: str): """Validate the final merge index json for following patterns of index_file_urls: 1. All urls are str (local). All urls are accessible locally -> no download 2. All urls are str (local). At least one url is unaccessible locally -> Error @@ -319,11 +319,12 @@ def not_merged_index(index_file_path: str, out: str): integrity_check(out, keep_local=keep_local) + @pytest.mark.parametrize('index_file_urls_pattern', [1, 2, 3]) @pytest.mark.usefixtures('manual_integration_dir') @pytest.mark.parametrize('keep_local', [True, False]) def test_merge_index_from_list_local(manual_integration_dir: Any, keep_local: bool, - index_file_urls_pattern: int): + index_file_urls_pattern: int): """Validate the final merge index json for following patterns of index_file_urls: 1. All urls are str (local). All urls are accessible locally -> no download 2. All urls are str (local). 
At least one url is unaccessible locally -> Error @@ -343,7 +344,7 @@ def not_merged_index(index_file_path: str, out: str): prefix = str(urllib.parse.urlparse(out).path) return os.path.dirname(index_file_path).strip('/') != prefix.strip('/') - local, remote = manual_integration_dir() + local, _ = manual_integration_dir() mds_out = out = local scheme = 's3://' @@ -398,7 +399,7 @@ def not_merged_index(index_file_path: str, out: str): @pytest.mark.parametrize('n_partitions', [1, 2, 3, 4]) @pytest.mark.parametrize('keep_local', [False, True]) def test_merge_index_from_root_local(manual_integration_dir: Any, n_partitions: int, - keep_local: bool): + keep_local: bool): from decimal import Decimal from pyspark.sql import SparkSession @@ -426,13 +427,14 @@ def test_merge_index_from_root_local(manual_integration_dir: Any, n_partitions: merge_index(mds_path, keep_local=keep_local) integrity_check(mds_path, keep_local=keep_local) + @pytest.mark.parametrize('scheme', ['oci://', 'gs://', 's3://']) @pytest.mark.parametrize('out_format', ['remote', 'tuple']) @pytest.mark.parametrize('n_partitions', [1, 2, 3, 4]) @pytest.mark.parametrize('keep_local', [False, True]) @pytest.mark.remote -def test_merge_index_from_root_remote(manual_integration_dir: Any, out_format: str, n_partitions: int, - keep_local: bool, scheme: str): +def test_merge_index_from_root_remote(manual_integration_dir: Any, out_format: str, + n_partitions: int, keep_local: bool, scheme: str): from decimal import Decimal from pyspark.sql import SparkSession From 90dccce861c549486a9bb304104fa663deafde1f Mon Sep 17 00:00:00 2001 From: Xiaohan Zhang Date: Mon, 9 Oct 2023 22:23:48 -0700 Subject: [PATCH 50/59] remove integration tests --- .../base/converters/test_dataframe_to_mds.py | 190 ----------------- tests/test_util.py | 197 +----------------- 2 files changed, 4 insertions(+), 383 deletions(-) diff --git a/tests/base/converters/test_dataframe_to_mds.py b/tests/base/converters/test_dataframe_to_mds.py index 32ce5b5b5..e4fbb95a7 100644 --- a/tests/base/converters/test_dataframe_to_mds.py +++ b/tests/base/converters/test_dataframe_to_mds.py @@ -22,79 +22,10 @@ 's3://': 'testing-bucket', 'oci://': 'testing-bucket', } -MANUAL_INTEGRATION_TEST = False os.environ[ 'OBJC_DISABLE_INITIALIZE_FORK_SAFETY'] = 'YES' # set to yes to all fork process in spark calls -@pytest.fixture(scope='function', autouse=True) -def manual_integration_dir() -> Any: - """Creates a temporary directory and then deletes it when the calling function is done.""" - if MANUAL_INTEGRATION_TEST: - os.environ['GOOGLE_APPLICATION_CREDENTIALS'] = os.path.join( - os.environ['HOME'], '.mosaic/mosaicml-research-gcs.json') - os.environ.pop('AWS_ACCESS_KEY_ID', None) - os.environ.pop('AWS_SECRET_ACCESS_KEY', None) - os.environ.pop('AWS_SECURITY_TOKEN', None) - os.environ.pop('AWS_SESSION_TOKEN', None) - os.environ['AWS_PROFILE'] = 'temporary' - - tmp_dir = mkdtemp() - - def _method(cloud_prefix: str = 'gs://') -> Tuple[str, str]: - mock_local_dir = tmp_dir # mkdtemp() - mock_remote_dir = os.path.join(cloud_prefix, MY_BUCKET[cloud_prefix], MY_PREFIX) - return mock_local_dir, mock_remote_dir - - try: - yield _method - finally: - shutil.rmtree(tmp_dir, ignore_errors=True) # pyright: ignore - if MANUAL_INTEGRATION_TEST: - try: - from google.cloud.storage import Client - storage_client = Client() - bucket = storage_client.get_bucket(MY_BUCKET['gs://']) - blobs = bucket.list_blobs(prefix=MY_PREFIX) - for blob in blobs: - blob.delete() - except ImportError: - raise 
ImportError('google.cloud.storage is not imported correctly.') - - try: - import boto3 - s3 = boto3.client('s3') - response = s3.list_objects_v2(Bucket=MY_BUCKET['s3://'], Prefix=MY_PREFIX) - objects_to_delete = [{'Key': obj['Key']} for obj in response.get('Contents', [])] - if objects_to_delete: - s3.delete_objects(Bucket=MY_BUCKET['s3://'], - Delete={'Objects': objects_to_delete}) - except ImportError: - raise ImportError('boto3 is not imported correctly.') - - try: - import oci - client = oci.object_storage.ObjectStorageClient(oci.config.from_file()) - response = client.list_objects( - namespace_name=client.get_namespace().data, - bucket_name=MY_BUCKET['oci://'], - fields=['name'], - prefix=MY_PREFIX, - ) - - # Delete the objects - for obj in response.data.objects: - client.delete_object( - namespace_name=client.get_namespace().data, - bucket_name=MY_BUCKET['oci://'], - object_name=obj.name, - ) - print(f'Deleted {len(response.data.objects)} objects with prefix: {MY_PREFIX}') - - except ImportError: - raise ImportError('boto3 is not imported correctly.') - - class TestDataFrameToMDS: @pytest.fixture @@ -262,124 +193,3 @@ def test_end_to_end_conversion_local(self, dataframe: Any, keep_local: bool, mer assert not os.path.exists(os.path.join( out, 'index.json')), 'merged index is created when merge_index=False' - @pytest.mark.parametrize('scheme', ['oci://', 'gs://', 's3://']) - @pytest.mark.parametrize('keep_local', [True]) # , False]) - @pytest.mark.parametrize('merge_index', [True]) # , False]) - @pytest.mark.usefixtures('manual_integration_dir') - @pytest.mark.remote - def test_patch_conversion_local_and_remote(self, dataframe: Any, scheme: str, - merge_index: bool, keep_local: bool, - manual_integration_dir: Any): - mock_local, mock_remote = manual_integration_dir(scheme) - out = (mock_local, mock_remote) - mds_kwargs = { - 'out': out, - 'columns': { - 'id': 'str', - 'dept': 'str' - }, - 'keep_local': keep_local, - 'compression': 'zstd:7', - 'hashes': ['sha1', 'xxh64'], - 'size_limit': 1 << 26 - } - - mds_path, fail_count = dataframeToMDS(dataframe, - merge_index=merge_index, - mds_kwargs=mds_kwargs) - - assert fail_count == 0, 'some records were not converted correctly' - assert out == mds_path, f'returned mds_path: {mds_path} is not the same as out: {out}' - - if not keep_local: - assert not os.path.exists(mds_path[0]), 'local folder were not removed' - return - - assert len(os.listdir(mds_path[0])) > 0, f'{mds_path[0]} is empty' - for d in os.listdir(mds_path[0]): - if os.path.isdir(os.path.join(mds_path[0], d)): - assert os.path.exists(os.path.join( - mds_path[0], d, 'index.json')), f'No index.json found in subdirectory {d}' - - if merge_index == True: - assert os.path.exists(os.path.join(mds_path[0], - 'index.json')), 'No merged index.json found' - else: - assert not (os.path.exists(os.path.join( - mds_path[0], 'index.json'))), 'merged index is created when merge_index=False' - - @pytest.mark.parametrize('scheme', ['oci://', 'gs://', 's3://']) - @pytest.mark.parametrize('keep_local', [True, False]) - @pytest.mark.parametrize('merge_index', [True, False]) - @pytest.mark.usefixtures('manual_integration_dir') - @pytest.mark.remote - def test_integration_conversion_local_and_remote(self, dataframe: Any, - manual_integration_dir: Any, - merge_index: bool, keep_local: bool, - scheme: str): - out = manual_integration_dir(scheme) - mds_kwargs = { - 'out': out, - 'columns': { - 'id': 'str', - 'dept': 'str' - }, - 'keep_local': keep_local, - 'compression': 'zstd:7', - 'hashes': ['sha1', 
'xxh64'], - 'size_limit': 1 << 26 - } - - mds_path, _ = dataframeToMDS(dataframe, merge_index=merge_index, mds_kwargs=mds_kwargs) - - assert out == mds_path, f'returned mds_path: {mds_path} is not the same as out: {out}' - - if not keep_local: - assert not os.path.exists(mds_path[0]), 'local folder were not removed' - return - - assert len(os.listdir(mds_path[0])) > 0, f'{mds_path[0]} is empty' - for d in os.listdir(mds_path[0]): - if os.path.isdir(os.path.join(mds_path[0], d)): - assert os.path.exists(os.path.join( - mds_path[0], d, 'index.json')), f'No index.json found in subdirectory {d}' - - if merge_index == True: - assert os.path.exists(os.path.join(mds_path[0], - 'index.json')), 'No merged index.json found' - else: - assert not os.path.exists(os.path.join(mds_path[0], 'index.json')), ( - f'merged index is created at {mds_path[0]} when merge_index={merge_index} and ' + - f'keep_local={keep_local}') - - @pytest.mark.parametrize('scheme', ['oci://', 'gs://', 's3://']) - @pytest.mark.usefixtures('manual_integration_dir') - @pytest.mark.remote - def test_integration_conversion_remote_only(self, dataframe: Any, manual_integration_dir: Any, - scheme: str): - _, remote = manual_integration_dir('s3://') - mds_kwargs = { - 'out': remote, - 'columns': { - 'id': 'str', - 'dept': 'str' - }, - } - - mds_path, _ = dataframeToMDS(dataframe, merge_index=True, mds_kwargs=mds_kwargs) - - assert len(mds_path) == 2, 'returned mds is a str but should be a tuple (local, remote)' - assert not (os.path.exists(os.path.join( - mds_path[0], 'index.json'))), 'Local merged index was not removed successfully' - assert len(os.listdir(mds_path[0])) > 0, f'{mds_path[0]} is not empty' - - def test_simple_remote(self, dataframe: Any): - if not MANUAL_INTEGRATION_TEST: - pytest.skip('run local only. 
CI cluster does not have GCS service acct set up.') - - out = 'gs://mosaicml-composer-tests/test_df2mds' - - with MDSWriter(out=out, columns={'id': 'str', 'dept': 'str'}) as mds_writer: - d = dataframe.toPandas().to_dict('records') - for row in d: - mds_writer.write(row) diff --git a/tests/test_util.py b/tests/test_util.py index a9af5e805..539e91f6c 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -25,79 +25,9 @@ 's3://': 'testing-bucket', 'oci://': 'testing-bucket', } -MANUAL_INTEGRATION_TEST = False os.environ[ 'OBJC_DISABLE_INITIALIZE_FORK_SAFETY'] = 'YES' # set to yes to all fork process in spark calls - -@pytest.fixture(scope='function', autouse=True) -def manual_integration_dir() -> Any: - """Creates a temporary directory and then deletes it when the calling function is done.""" - if MANUAL_INTEGRATION_TEST: - os.environ['GOOGLE_APPLICATION_CREDENTIALS'] = os.path.join( - os.environ['HOME'], '.mosaic/mosaicml-research-gcs.json') - os.environ.pop('AWS_ACCESS_KEY_ID', None) - os.environ.pop('AWS_SECRET_ACCESS_KEY', None) - os.environ.pop('AWS_SECURITY_TOKEN', None) - os.environ.pop('AWS_SESSION_TOKEN', None) - os.environ['AWS_PROFILE'] = 'temporary' - - tmp_dir = tempfile.mkdtemp() - - def _method(cloud_prefix: str = 'gs://') -> Tuple[str, str]: - mock_local_dir = tmp_dir - mock_remote_dir = os.path.join(cloud_prefix, MY_BUCKET[cloud_prefix], MY_PREFIX) - return mock_local_dir, mock_remote_dir - - try: - yield _method - finally: - shutil.rmtree(tmp_dir, ignore_errors=True) # pyright: ignore - if MANUAL_INTEGRATION_TEST: - try: - from google.cloud.storage import Client - storage_client = Client() - bucket = storage_client.get_bucket(MY_BUCKET['gs://']) - blobs = bucket.list_blobs(prefix=MY_PREFIX) - for blob in blobs: - blob.delete() - except ImportError: - raise ImportError('google.cloud.storage is not imported correctly.') - - try: - import boto3 - s3 = boto3.client('s3') - response = s3.list_objects_v2(Bucket=MY_BUCKET['s3://'], Prefix=MY_PREFIX) - objects_to_delete = [{'Key': obj['Key']} for obj in response.get('Contents', [])] - if objects_to_delete: - s3.delete_objects(Bucket=MY_BUCKET['s3://'], - Delete={'Objects': objects_to_delete}) - except ImportError: - raise ImportError('boto3 is not imported correctly.') - - try: - import oci - client = oci.object_storage.ObjectStorageClient(oci.config.from_file()) - response = client.list_objects( - namespace_name=client.get_namespace().data, - bucket_name=MY_BUCKET['oci://'], - fields=['name'], - prefix=MY_PREFIX, - ) - - # Delete the objects - for obj in response.data.objects: - client.delete_object( - namespace_name=client.get_namespace().data, - bucket_name=MY_BUCKET['oci://'], - object_name=obj.name, - ) - print(f'Deleted {len(response.data.objects)} objects with prefix: {MY_PREFIX}') - - except ImportError: - raise ImportError('boto3 is not imported correctly.') - - @pytest.mark.parametrize(('text', 'expected_output'), [('hello,world', ['hello', 'world']), ('hello', ['hello']), ('', [])]) def test_get_list_arg(text: str, expected_output: List[Optional[str]]): @@ -240,90 +170,9 @@ def get_expected(mds_root: str): assert n_shard_files == expected_n_shard_files, f'expected {expected_n_shard_files} shard files but got {n_shard_files}' -@pytest.mark.parametrize('scheme', ['oci://', 'gs://', 's3://']) -@pytest.mark.parametrize('index_file_urls_pattern', [4, 5]) -@pytest.mark.parametrize('out_format', ['remote', 'local', 'tuple']) -@pytest.mark.usefixtures('manual_integration_dir') -@pytest.mark.parametrize('keep_local', [True, 
False]) -@pytest.mark.remote -def test_merge_index_from_list_remote(manual_integration_dir: Any, keep_local: bool, - index_file_urls_pattern: int, out_format: str, scheme: str): - """Validate the final merge index json for following patterns of index_file_urls: - 1. All urls are str (local). All urls are accessible locally -> no download - 2. All urls are str (local). At least one url is unaccessible locally -> Error - 3. All urls are tuple (local, remote). All urls are accessible locally -> no download - 4. All urls are tuple (local, remote). At least one url is not accessible locally -> download all - 5. All urls are str (remote) -> download all - """ - from decimal import Decimal - - from pyspark.sql import SparkSession - from pyspark.sql.types import DecimalType, IntegerType, StringType, StructField, StructType - - from streaming.base.converters import dataframeToMDS - - def not_merged_index(index_file_path: str, out: str): - """Check if index_file_path is the merged index at folder out.""" - prefix = str(urllib.parse.urlparse(out).path) - return os.path.dirname(index_file_path).strip('/') != prefix.strip('/') - - local, remote = manual_integration_dir(scheme) - - if out_format == 'remote': - out = remote - else: - out = (local, remote) - mds_out = (local, remote) - - spark = SparkSession.builder.getOrCreate() # pyright: ignore - schema = StructType([ - StructField('id', IntegerType(), nullable=False), - StructField('name', StringType(), nullable=False), - StructField('amount', DecimalType(10, 2), nullable=False) - ]) - data = [(1, 'Alice', Decimal('123.45')), (2, 'Bob', Decimal('67.89')), - (3, 'Charlie', Decimal('987.65'))] - df = spark.createDataFrame(data=data, schema=schema).repartition(3) - mds_kwargs = {'out': mds_out, 'columns': {'id': 'int', 'name': 'str'}, 'keep_local': True} - dataframeToMDS(df, merge_index=False, mds_kwargs=mds_kwargs) - - local_cu = CloudUploader.get(local, exist_ok=True, keep_local=True) - local_index_files = [ - o for o in local_cu.list_objects() if o.endswith('.json') and not_merged_index(o, local) - ] - - if index_file_urls_pattern == 4: - - remote_cu = CloudUploader.get(remote, exist_ok=True, keep_local=True) - remote_index_files = [ - os.path.join(scheme, MY_BUCKET[scheme], o) - for o in remote_cu.list_objects() - if o.endswith('.json') and not_merged_index(o, remote) - ] - with tempfile.TemporaryDirectory() as a_temporary_folder: - non_exist_local_files = [ - os.path.join(a_temporary_folder, os.path.basename(s)) for s in local_index_files - ] - index_file_urls = list(zip(non_exist_local_files, remote_index_files)) - merge_index(index_file_urls, out, keep_local=keep_local) - - if index_file_urls_pattern == 5: - - remote_cu = CloudUploader.get(remote, exist_ok=True, keep_local=True) - remote_index_files = [ - os.path.join(scheme, MY_BUCKET[scheme], o) - for o in remote_cu.list_objects() - if o.endswith('.json') and not_merged_index(o, remote) - ] - merge_index(remote_index_files, out, keep_local=keep_local) - - integrity_check(out, keep_local=keep_local) - - @pytest.mark.parametrize('index_file_urls_pattern', [1, 2, 3]) -@pytest.mark.usefixtures('manual_integration_dir') @pytest.mark.parametrize('keep_local', [True, False]) -def test_merge_index_from_list_local(manual_integration_dir: Any, keep_local: bool, +def test_merge_index_from_list_local(local_remote_dir: Tuple[str, str], keep_local: bool, index_file_urls_pattern: int): """Validate the final merge index json for following patterns of index_file_urls: 1. All urls are str (local). 
All urls are accessible locally -> no download @@ -344,7 +193,7 @@ def not_merged_index(index_file_path: str, out: str): prefix = str(urllib.parse.urlparse(out).path) return os.path.dirname(index_file_path).strip('/') != prefix.strip('/') - local, _ = manual_integration_dir() + local, _ = local_remote_dir mds_out = out = local scheme = 's3://' @@ -398,7 +247,7 @@ def not_merged_index(index_file_path: str, out: str): @pytest.mark.parametrize('n_partitions', [1, 2, 3, 4]) @pytest.mark.parametrize('keep_local', [False, True]) -def test_merge_index_from_root_local(manual_integration_dir: Any, n_partitions: int, +def test_merge_index_from_root_local(local_remote_dir: [Tuple[str, str]], n_partitions: int, keep_local: bool): from decimal import Decimal @@ -407,45 +256,7 @@ def test_merge_index_from_root_local(manual_integration_dir: Any, n_partitions: from streaming.base.converters import dataframeToMDS - out, _ = manual_integration_dir() - - spark = SparkSession.builder.getOrCreate() # pyright: ignore - schema = StructType([ - StructField('id', IntegerType(), nullable=False), - StructField('name', StringType(), nullable=False), - StructField('amount', DecimalType(10, 2), nullable=False) - ]) - - data = [(1, 'Alice', Decimal('123.45')), (2, 'Bob', Decimal('67.89')), - (3, 'Charlie', Decimal('987.65'))] - - df = spark.createDataFrame(data=data, schema=schema).repartition(n_partitions) - - mds_kwargs = {'out': out, 'columns': {'id': 'int', 'name': 'str'}, 'keep_local': keep_local} - - mds_path, _ = dataframeToMDS(df, merge_index=False, mds_kwargs=mds_kwargs) - merge_index(mds_path, keep_local=keep_local) - integrity_check(mds_path, keep_local=keep_local) - - -@pytest.mark.parametrize('scheme', ['oci://', 'gs://', 's3://']) -@pytest.mark.parametrize('out_format', ['remote', 'tuple']) -@pytest.mark.parametrize('n_partitions', [1, 2, 3, 4]) -@pytest.mark.parametrize('keep_local', [False, True]) -@pytest.mark.remote -def test_merge_index_from_root_remote(manual_integration_dir: Any, out_format: str, - n_partitions: int, keep_local: bool, scheme: str): - from decimal import Decimal - - from pyspark.sql import SparkSession - from pyspark.sql.types import DecimalType, IntegerType, StringType, StructField, StructType - - from streaming.base.converters import dataframeToMDS - - if out_format == 'remote': - _, out = manual_integration_dir(scheme) - else: - out = manual_integration_dir(scheme) + out, _ = local_remote_dir spark = SparkSession.builder.getOrCreate() # pyright: ignore schema = StructType([ From 9f2dd04d0aeca73caceeada4b2ee02222a5568b5 Mon Sep 17 00:00:00 2001 From: Xiaohan Zhang Date: Mon, 9 Oct 2023 22:28:38 -0700 Subject: [PATCH 51/59] Fix lints --- tests/base/converters/test_dataframe_to_mds.py | 4 ---- tests/test_util.py | 6 +++--- 2 files changed, 3 insertions(+), 7 deletions(-) diff --git a/tests/base/converters/test_dataframe_to_mds.py b/tests/base/converters/test_dataframe_to_mds.py index e4fbb95a7..ede834de1 100644 --- a/tests/base/converters/test_dataframe_to_mds.py +++ b/tests/base/converters/test_dataframe_to_mds.py @@ -3,9 +3,7 @@ import json import os -import shutil from decimal import Decimal -from tempfile import mkdtemp from typing import Any, Tuple import pytest @@ -13,7 +11,6 @@ from pyspark.sql.functions import col from pyspark.sql.types import DecimalType, IntegerType, StringType, StructField, StructType -from streaming import MDSWriter from streaming.base.converters import dataframeToMDS MY_PREFIX = 'train' @@ -192,4 +189,3 @@ def test_end_to_end_conversion_local(self, 
dataframe: Any, keep_local: bool, mer else: assert not os.path.exists(os.path.join( out, 'index.json')), 'merged index is created when merge_index=False' - diff --git a/tests/test_util.py b/tests/test_util.py index 539e91f6c..5602b2268 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -3,12 +3,11 @@ import json import os -import shutil import tempfile import time import urllib.parse from multiprocessing.shared_memory import SharedMemory as BuiltinSharedMemory -from typing import Any, List, Optional, Tuple, Union +from typing import List, Optional, Tuple, Union import pytest @@ -28,6 +27,7 @@ os.environ[ 'OBJC_DISABLE_INITIALIZE_FORK_SAFETY'] = 'YES' # set to yes to all fork process in spark calls + @pytest.mark.parametrize(('text', 'expected_output'), [('hello,world', ['hello', 'world']), ('hello', ['hello']), ('', [])]) def test_get_list_arg(text: str, expected_output: List[Optional[str]]): @@ -247,7 +247,7 @@ def not_merged_index(index_file_path: str, out: str): @pytest.mark.parametrize('n_partitions', [1, 2, 3, 4]) @pytest.mark.parametrize('keep_local', [False, True]) -def test_merge_index_from_root_local(local_remote_dir: [Tuple[str, str]], n_partitions: int, +def test_merge_index_from_root_local(local_remote_dir: Tuple[str, str], n_partitions: int, keep_local: bool): from decimal import Decimal From 013a97bdca66824b3cb547c17d2a34870054c2e7 Mon Sep 17 00:00:00 2001 From: Xiaohan Zhang Date: Tue, 10 Oct 2023 10:08:24 -0700 Subject: [PATCH 52/59] Add specific exceptions to oci list_objects --- streaming/base/storage/upload.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/streaming/base/storage/upload.py b/streaming/base/storage/upload.py index ffbfef33a..462e19515 100644 --- a/streaming/base/storage/upload.py +++ b/streaming/base/storage/upload.py @@ -592,8 +592,16 @@ def list_objects(self, prefix: Optional[str] = None) -> Optional[List[str]]: if not next_start_with: response_complete = True return object_names - except Exception as _: - return [] + except Exception as e: + if isinstance(e, oci.exceptions.ServiceError): + if e.status == 404: # type: ignore + if e.code == 'ObjectNotFound': # type: ignore + raise FileNotFoundError(f'Object {bucket_name}/{prefix} not found. {e.message}') from e # type: ignore + if e.code == 'BucketNotFound': # type: ignore + raise ValueError(f'Bucket {bucket_name} not found. {e.message}') from e # type: ignore + raise e + raise e + return [] class AzureUploader(CloudUploader): From 2c214c8f92be5b0c426ae3a680da81d31cc33624 Mon Sep 17 00:00:00 2001 From: Xiaohan Zhang Date: Tue, 10 Oct 2023 10:08:44 -0700 Subject: [PATCH 53/59] Fix comments --- streaming/base/util.py | 15 ++++++++------- tests/test_util.py | 6 +++--- 2 files changed, 11 insertions(+), 10 deletions(-) diff --git a/streaming/base/util.py b/streaming/base/util.py index 582725003..1ff371db5 100644 --- a/streaming/base/util.py +++ b/streaming/base/util.py @@ -215,7 +215,8 @@ def get_import_exception_message(package_name: str, extra_deps: str) -> str: f'`pip install \'mosaicml-streaming[{package_name}]\'`.' 
-def merge_index(*args, **kwargs): # pyright: ignore +def merge_index(*args: Any, **kwargs: Any): + """Redirect to one of two merge_index functions based on arguments""" if isinstance(args[0], list) and len(args) + len(kwargs) in [2, 3, 4]: return _merge_index_from_list(*args, **kwargs) elif (isinstance(args[0], str) or @@ -235,10 +236,10 @@ def _merge_index_from_list(index_file_urls: List[Union[str, Tuple[str, str]]], each element can take the form of a single path string or a tuple string. The pattern of index_file_urls and corresponding reaction is one of: - 1. All urls are str (local). All urls are accessible locally -> no download - 2. All urls are tuple (local, remote). All urls are accessible locally -> no download - 3. All urls are tuple (local, remote). Download url thtat is not accessible locally - 4. All urls are str (remote) -> download all + 1. All URLS are str (local). All URLS are accessible locally -> no download + 2. All URLS are tuple (local, remote). All URLS are accessible locally -> no download + 3. All URLS are tuple (local, remote). Download URL that is not accessible locally + 4. All URLS are str (remote) -> download all out (Union[str, Tuple[str, str]]): path to put the merged index file keep_local (bool): Keep local copy of the merged index file. Defaults to ``True`` @@ -248,7 +249,7 @@ def _merge_index_from_list(index_file_urls: List[Union[str, Tuple[str, str]]], from streaming.base.storage.upload import CloudUploader if not index_file_urls or not out: - logger.warning('Need to specify both index_file_urls and out. No index merged') + logger.warning('Need to specify both `index_file_urls` and `out`. No index merged') return # This is the index json file name, e.g., it is index.json as of 0.6.0 @@ -282,7 +283,7 @@ def _merge_index_from_list(index_file_urls: List[Union[str, Tuple[str, str]]], if scheme == '' and bucket == '' and path == '': raise FileNotFoundError( f'Check data availability! local index {url[0]} is not accessible.' + - 'remote index {url[1]} does not have a valid url format') + f'remote index {url[1]} does not have a valid url format') dest = os.path.join(temp_root, path.lstrip('/')) try: diff --git a/tests/test_util.py b/tests/test_util.py index 5602b2268..971b96791 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -132,7 +132,7 @@ def integrity_check(out: Union[str, Tuple[str, str]], Args: out (Union[str, Tuple[str,str]]): folder that merged index.json resides keep_local: whether to check local file - expected_n_shard_files (int): If -1, find the number in out with get_expected() + expected_n_shard_files (int): If -1, find the number in `out` with get_expected() """ def get_expected(mds_root: str): @@ -172,8 +172,9 @@ def get_expected(mds_root: str): @pytest.mark.parametrize('index_file_urls_pattern', [1, 2, 3]) @pytest.mark.parametrize('keep_local', [True, False]) +@pytest.mark.parametrize('scheme', ['gs://', 's3://', 'oci://']) def test_merge_index_from_list_local(local_remote_dir: Tuple[str, str], keep_local: bool, - index_file_urls_pattern: int): + index_file_urls_pattern: int, scheme: str): """Validate the final merge index json for following patterns of index_file_urls: 1. All urls are str (local). All urls are accessible locally -> no download 2. All urls are str (local). 
At least one url is unaccessible locally -> Error @@ -196,7 +197,6 @@ def not_merged_index(index_file_path: str, out: str): local, _ = local_remote_dir mds_out = out = local - scheme = 's3://' spark = SparkSession.builder.getOrCreate() # pyright: ignore schema = StructType([ From 224cba624a1c2ed32cfb09e49b6680087f9042ad Mon Sep 17 00:00:00 2001 From: Xiaohan Zhang Date: Tue, 10 Oct 2023 10:09:11 -0700 Subject: [PATCH 54/59] Add deprecated warning for dataframeToMDS --- streaming/base/converters/dataframe_to_mds.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/streaming/base/converters/dataframe_to_mds.py b/streaming/base/converters/dataframe_to_mds.py index 8749bbdec..5eb4a4840 100644 --- a/streaming/base/converters/dataframe_to_mds.py +++ b/streaming/base/converters/dataframe_to_mds.py @@ -124,6 +124,14 @@ def dataframeToMDS(dataframe: DataFrame, mds_kwargs: Optional[Dict[str, Any]] = None, udf_iterable: Optional[Callable] = None, udf_kwargs: Optional[Dict[str, Any]] = None) -> Tuple[Any, int]: + logger.warning("This signature is deprecated. Use dataframe_to_mds with the same arguments going forward.") + return dataframe_to_mds(dataframe, merge_index, mds_kwargs, udf_iterable, udf_kwargs) + +def dataframe_to_mds(dataframe: DataFrame, + merge_index: bool = True, + mds_kwargs: Optional[Dict[str, Any]] = None, + udf_iterable: Optional[Callable] = None, + udf_kwargs: Optional[Dict[str, Any]] = None) -> Tuple[Any, int]: """Execute a spark dataframe to MDS conversion process. This method orchestrates the conversion of a spark dataframe into MDS format by processing the From d806cbf5b6773d4fb668b5b88ec6e04d8c5ce4d6 Mon Sep 17 00:00:00 2001 From: Xiaohan Zhang Date: Tue, 10 Oct 2023 10:32:09 -0700 Subject: [PATCH 55/59] Fix remote url for /Volume --- streaming/base/util.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/streaming/base/util.py b/streaming/base/util.py index 1ff371db5..7aec88b1a 100644 --- a/streaming/base/util.py +++ b/streaming/base/util.py @@ -373,7 +373,13 @@ def not_merged_index(index_file_path: str, out: str): remote_index_files = [] for file in cu.list_objects(): if file.endswith(get_index_basename()) and not_merged_index(file, cu.remote): - remote_index_files.append(obj.scheme + '://' + os.path.join(obj.netloc, file)) + join_char = '//' + if obj.scheme == 'dbfs': + path = Path(cu.remote) + prefix = os.path.join(path.parts[0], path.parts[1]) + if prefix == 'dbfs:/Volumes': + join_char = '/' + remote_index_files.append(obj.scheme + join_char + os.path.join(obj.netloc, file)) if len(local_index_files) == len(remote_index_files): _merge_index_from_list(list(zip(local_index_files, remote_index_files)), out, From 0944aad03697741516b9fc868a0c1844bdb3cb49 Mon Sep 17 00:00:00 2001 From: Xiaohan Zhang Date: Tue, 10 Oct 2023 14:18:28 -0700 Subject: [PATCH 56/59] Fix lints --- streaming/base/converters/dataframe_to_mds.py | 9 ++++++++- streaming/base/storage/upload.py | 9 ++++++--- streaming/base/util.py | 2 +- 3 files changed, 15 insertions(+), 5 deletions(-) diff --git a/streaming/base/converters/dataframe_to_mds.py b/streaming/base/converters/dataframe_to_mds.py index 5eb4a4840..d0774254e 100644 --- a/streaming/base/converters/dataframe_to_mds.py +++ b/streaming/base/converters/dataframe_to_mds.py @@ -124,9 +124,16 @@ def dataframeToMDS(dataframe: DataFrame, mds_kwargs: Optional[Dict[str, Any]] = None, udf_iterable: Optional[Callable] = None, udf_kwargs: Optional[Dict[str, Any]] = None) -> Tuple[Any, int]: - logger.warning("This signature is 
deprecated. Use dataframe_to_mds with the same arguments going forward.") + """Deprecated API Signature. + + To be replaced by dataframe_to_mds + """ + logger.warning( + 'This signature is deprecated. Use dataframe_to_mds with the same arguments going forward.' + ) return dataframe_to_mds(dataframe, merge_index, mds_kwargs, udf_iterable, udf_kwargs) + def dataframe_to_mds(dataframe: DataFrame, merge_index: bool = True, mds_kwargs: Optional[Dict[str, Any]] = None, diff --git a/streaming/base/storage/upload.py b/streaming/base/storage/upload.py index 462e19515..c100c6a1d 100644 --- a/streaming/base/storage/upload.py +++ b/streaming/base/storage/upload.py @@ -593,12 +593,15 @@ def list_objects(self, prefix: Optional[str] = None) -> Optional[List[str]]: response_complete = True return object_names except Exception as e: - if isinstance(e, oci.exceptions.ServiceError): + if isinstance(e, oci.exceptions.ServiceError): # type: ignore if e.status == 404: # type: ignore if e.code == 'ObjectNotFound': # type: ignore - raise FileNotFoundError(f'Object {bucket_name}/{prefix} not found. {e.message}') from e # type: ignore + raise FileNotFoundError( + f'Object {bucket_name}/{prefix} not found. {e.message}' # type: ignore + ) from e # type: ignore if e.code == 'BucketNotFound': # type: ignore - raise ValueError(f'Bucket {bucket_name} not found. {e.message}') from e # type: ignore + raise ValueError( + f'Bucket {bucket_name} not found. {e.message}') from e # type: ignore raise e raise e return [] diff --git a/streaming/base/util.py b/streaming/base/util.py index 7aec88b1a..86090f9b7 100644 --- a/streaming/base/util.py +++ b/streaming/base/util.py @@ -216,7 +216,7 @@ def get_import_exception_message(package_name: str, extra_deps: str) -> str: def merge_index(*args: Any, **kwargs: Any): - """Redirect to one of two merge_index functions based on arguments""" + """Redirect to one of two merge_index functions based on arguments.""" if isinstance(args[0], list) and len(args) + len(kwargs) in [2, 3, 4]: return _merge_index_from_list(*args, **kwargs) elif (isinstance(args[0], str) or From f4a429bb192e0a0e80650b39be46cc26b5ce1761 Mon Sep 17 00:00:00 2001 From: Xiaohan Zhang Date: Wed, 11 Oct 2023 22:58:36 -0700 Subject: [PATCH 57/59] Address comments --- streaming/base/converters/dataframe_to_mds.py | 4 ++-- streaming/base/util.py | 6 ++++-- tests/base/converters/test_dataframe_to_mds.py | 6 ------ tests/test_util.py | 10 +++++----- 4 files changed, 11 insertions(+), 15 deletions(-) diff --git a/streaming/base/converters/dataframe_to_mds.py b/streaming/base/converters/dataframe_to_mds.py index d0774254e..1d62ef4e7 100644 --- a/streaming/base/converters/dataframe_to_mds.py +++ b/streaming/base/converters/dataframe_to_mds.py @@ -129,8 +129,8 @@ def dataframeToMDS(dataframe: DataFrame, To be replaced by dataframe_to_mds """ logger.warning( - 'This signature is deprecated. Use dataframe_to_mds with the same arguments going forward.' - ) + 'The DataframeToMDS signature has been deprecated and will be removed in Streaming 0.8. ' + + 'Use dataframe_to_mds with the same arguments going forward') return dataframe_to_mds(dataframe, merge_index, mds_kwargs, udf_iterable, udf_kwargs) diff --git a/streaming/base/util.py b/streaming/base/util.py index 86090f9b7..1933a258c 100644 --- a/streaming/base/util.py +++ b/streaming/base/util.py @@ -336,10 +336,12 @@ def _merge_index_from_root(out: Union[str, Tuple[str, str]], keep_local: bool = Args: out (Union[str, Tuple[str,str]]): folder that contain MDS partitions. 
:A local directory, merge index happens locally - :A remote directory, download all the sub-directories index.json in a temporary sub-directories, merge locally, and then upload it to out location + :A remote directory, download all the sub-directories index.json in a temporary + sub-directories, merge locally, and then upload it to out location :A (local_dir, remote_dir), check if sub-directories index.json file present locally If yes, then merge locally and upload to remote_dir . - If not, download all the sub-directories index.json from remote to local , merge locally, and upload to remote_dir . + If not, download all the sub-directories index.json from remote to local, + merge locally, and upload to remote_dir . keep_local (bool): Keep local copy of the merged index file. Defaults to ``True`` """ from streaming.base.storage.upload import CloudUploader diff --git a/tests/base/converters/test_dataframe_to_mds.py b/tests/base/converters/test_dataframe_to_mds.py index ede834de1..a87309beb 100644 --- a/tests/base/converters/test_dataframe_to_mds.py +++ b/tests/base/converters/test_dataframe_to_mds.py @@ -13,12 +13,6 @@ from streaming.base.converters import dataframeToMDS -MY_PREFIX = 'train' -MY_BUCKET = { - 'gs://': 'testing-bucket', - 's3://': 'testing-bucket', - 'oci://': 'testing-bucket', -} os.environ[ 'OBJC_DISABLE_INITIALIZE_FORK_SAFETY'] = 'YES' # set to yes to all fork process in spark calls diff --git a/tests/test_util.py b/tests/test_util.py index 971b96791..aa107cc17 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -176,11 +176,11 @@ def get_expected(mds_root: str): def test_merge_index_from_list_local(local_remote_dir: Tuple[str, str], keep_local: bool, index_file_urls_pattern: int, scheme: str): """Validate the final merge index json for following patterns of index_file_urls: - 1. All urls are str (local). All urls are accessible locally -> no download - 2. All urls are str (local). At least one url is unaccessible locally -> Error - 3. All urls are tuple (local, remote). All urls are accessible locally -> no download - 4. All urls are tuple (local, remote). At least one url is not accessible locally -> download all - 5. All urls are str (remote) -> download all + 1. All URLs are str (local). All URLs are accessible locally -> no download + 2. All URLs are str (local). At least one url is unaccessible locally -> Error + 3. All URLs are tuple (local, remote). All URLs are accessible locally -> no download + 4. All URLs are tuple (local, remote). At least one url is not accessible locally -> download all + 5. 
All URLs are str (remote) -> download all """ from decimal import Decimal From 5c92d20f159e9b4a10bfdb2792925b6b78f874d2 Mon Sep 17 00:00:00 2001 From: Xiaohan Zhang Date: Fri, 13 Oct 2023 11:26:32 -0700 Subject: [PATCH 58/59] Update doc string for wiki --- docs/source/conf.py | 1 + streaming/base/util.py | 47 +++++++++++++++++++++++++++++++++++------- 2 files changed, 41 insertions(+), 7 deletions(-) diff --git a/docs/source/conf.py b/docs/source/conf.py index e37c15bf7..e25dc24ba 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -371,6 +371,7 @@ def _modules_to_rst() -> List[types.ModuleType]: streaming.base.shared, streaming.base.shuffle, streaming.base.storage, + streaming.base.util, streaming.base.world, ] exclude_modules: List[types.Module] = [streaming.base, streaming._version] diff --git a/streaming/base/util.py b/streaming/base/util.py index 1933a258c..e7a2de0ed 100644 --- a/streaming/base/util.py +++ b/streaming/base/util.py @@ -30,7 +30,7 @@ __all__ = [ 'get_list_arg', 'wait_for_file_to_exist', 'bytes_to_int', 'number_abbrev_to_int', - 'clean_stale_shared_memory', 'get_import_exception_message', 'retry' + 'clean_stale_shared_memory', 'get_import_exception_message', 'merge_index', 'retry' ] @@ -216,11 +216,38 @@ def get_import_exception_message(package_name: str, extra_deps: str) -> str: def merge_index(*args: Any, **kwargs: Any): - """Redirect to one of two merge_index functions based on arguments.""" + r"""Merge index.json from partitions to form a global index.json. + + This can be called as + + merge_index(index_file_urls, out, keep_local, download_timeout) + + merge_index(out, keep_local, download_timeout) + + The first signature takes in a list of index files URLs of MDS partitions. + The second takes the root of a MDS dataset and parse the partition folders from there. + + Args: + index_file_urls (List[Union[str, Tuple[str,str]]]): index.json from all the partitions. + Each element can take the form of a single path string or a tuple string. + + 1. If ``index_file_urls`` is a List of local URLs, merge locally without download. + 2. If ``index_file_urls`` is a List of tuple (local, remote) URLs, check if local index.json are missing, download before merging. + 3. If ``index_file_urls`` is a List of remote URLs, download all and merge. + + out (Union[str, Tuple[str,str]]): folder that contain MDS partitions and to put the merged index file + + 1. A local directory, merge index happens locally. + 2. A remote directory, download all the sub-directories index.json, merge locally and upload. + 3. A tuple (local_dir, remote_dir), check if local index.json exist, download if not. + + keep_local (bool): Keep local copy of the merged index file. Defaults to ``True``. + download_timeout (int): The allowed time for downloading each json file. Defaults to 60s. 
+ """ if isinstance(args[0], list) and len(args) + len(kwargs) in [2, 3, 4]: return _merge_index_from_list(*args, **kwargs) elif (isinstance(args[0], str) or - isinstance(args[0], tuple)) and len(args) + len(kwargs) in [1, 2]: + isinstance(args[0], tuple)) and len(args) + len(kwargs) in [1, 2, 3]: return _merge_index_from_root(*args, **kwargs) raise ValueError(f'Invalid arguments to merge_index: {args}, {kwargs}') @@ -330,7 +357,9 @@ def _merge_index_from_list(index_file_urls: List[Union[str, Tuple[str, str]]], shutil.rmtree(cu.local, ignore_errors=True) -def _merge_index_from_root(out: Union[str, Tuple[str, str]], keep_local: bool = True) -> None: +def _merge_index_from_root(out: Union[str, Tuple[str, str]], + keep_local: bool = True, + download_timeout: int = 60) -> None: """Merge index.json given the root of MDS dataset. Write merged index to the root folder. Args: @@ -343,6 +372,7 @@ def _merge_index_from_root(out: Union[str, Tuple[str, str]], keep_local: bool = If not, download all the sub-directories index.json from remote to local, merge locally, and upload to remote_dir . keep_local (bool): Keep local copy of the merged index file. Defaults to ``True`` + download_timeout (int): The allowed time for downloading each json file. Defaults to 60s. """ from streaming.base.storage.upload import CloudUploader @@ -386,15 +416,18 @@ def not_merged_index(index_file_path: str, out: str): _merge_index_from_list(list(zip(local_index_files, remote_index_files)), out, keep_local=keep_local, - download_timeout=60) + download_timeout=download_timeout) else: _merge_index_from_list(remote_index_files, out, keep_local=keep_local, - download_timeout=60) + download_timeout=download_timeout) return - _merge_index_from_list(local_index_files, out, keep_local=keep_local, download_timeout=60) + _merge_index_from_list(local_index_files, + out, + keep_local=keep_local, + download_timeout=download_timeout) @overload From 5e44bbf189c2f16174725cdf251d0be76575c54b Mon Sep 17 00:00:00 2001 From: Xiaohan Zhang Date: Fri, 13 Oct 2023 11:38:07 -0700 Subject: [PATCH 59/59] small change --- streaming/base/util.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/streaming/base/util.py b/streaming/base/util.py index e7a2de0ed..bb3a52d7e 100644 --- a/streaming/base/util.py +++ b/streaming/base/util.py @@ -242,7 +242,7 @@ def merge_index(*args: Any, **kwargs: Any): 3. A tuple (local_dir, remote_dir), check if local index.json exist, download if not. keep_local (bool): Keep local copy of the merged index file. Defaults to ``True``. - download_timeout (int): The allowed time for downloading each json file. Defaults to 60s. + download_timeout (int): The allowed time for downloading each json file. Defaults to 60. """ if isinstance(args[0], list) and len(args) + len(kwargs) in [2, 3, 4]: return _merge_index_from_list(*args, **kwargs) @@ -270,7 +270,7 @@ def _merge_index_from_list(index_file_urls: List[Union[str, Tuple[str, str]]], out (Union[str, Tuple[str, str]]): path to put the merged index file keep_local (bool): Keep local copy of the merged index file. Defaults to ``True`` - download_timeout (int): The allowed time for downloading each json file. Defaults to 60s. + download_timeout (int): The allowed time for downloading each json file. Defaults to 60. 
""" from streaming.base.storage.download import download_file from streaming.base.storage.upload import CloudUploader @@ -372,7 +372,7 @@ def _merge_index_from_root(out: Union[str, Tuple[str, str]], If not, download all the sub-directories index.json from remote to local, merge locally, and upload to remote_dir . keep_local (bool): Keep local copy of the merged index file. Defaults to ``True`` - download_timeout (int): The allowed time for downloading each json file. Defaults to 60s. + download_timeout (int): The allowed time for downloading each json file. Defaults to 60. """ from streaming.base.storage.upload import CloudUploader