Add merge index file utility (#449)
* First commit

* Add a naive mds dataset

* fix lints

* Fix

* Change dataframeToMDS API to use merge_util helper

* Fix unit tests

* Fix tests

* Fix lints

* Address a few comments

* update

* updates

* Address comments

* update unit tests

* Update tests

* unit tests + pre-commit ok

* Add list objects for oci, gs, s3

* fix tests

* Fix lints

* list_objects returns only basename

* Fix lints

* fix bugs in list_objects

* updates

* Fix lints

* use new list_objects

* Fix lints

* remove

* Add merge_index

* remove materialized test dataset

* Change do_merge_index to merge_index_from_list

* Fix lints

* Change merge_index to auto_merge_index to avoid duplicate naming

* update pytest yaml

* update

* update

* Fix lints

* Make merge_index a wrapper

* add print

* Change fail msg for missing local file and invalid remote url

* update msg

* remove print

* Fix lints

* Add warning msg for exist_ok=True

* Address comments

* fix lints

* Turn off manual integration

* Address comments

* Update

* updates

* Fix lints

* remove integration tests

* Fix lints

* Add specific exceptions to oci list_objects

* Fix comments

* Add deprecated warning for dataframeToMDS

* Fix remote url for /Volume

* Fix lints

* Address comments

---------

Co-authored-by: Karan Jariwala <[email protected]>
XiaohanZhangCMU and karan6181 authored Oct 13, 2023
1 parent 91eccbb commit 5a5fa6f
Showing 7 changed files with 717 additions and 290 deletions.
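For orientation, a minimal sketch of how the merge_index utility this commit adds is invoked, based on its call site in the converter diff below; the paths and bucket name are made up, and the keyword values simply mirror what the diff passes:

    from streaming.base.util import merge_index

    # Partial index.json files written by two hypothetical partitions; each entry
    # is a (local, remote) pair, matching what dataframe_to_mds collects below.
    index_files = [
        ('/tmp/mds/0/index.json', 's3://my-bucket/mds/0/index.json'),
        ('/tmp/mds/1/index.json', 's3://my-bucket/mds/1/index.json'),
    ]

    # Merge the partial indexes into a single index.json at the dataset root.
    merge_index(index_files, 's3://my-bucket/mds', keep_local=True, download_timeout=60)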
15 changes: 8 additions & 7 deletions .github/workflows/pytest.yaml
@@ -47,10 +47,11 @@ jobs:
id: tests
run: |
set -ex
pytest --splits 7 --group 1 --cov-fail-under=10
pytest --splits 7 --group 2 --cov-fail-under=10
pytest --splits 7 --group 3 --cov-fail-under=10
pytest --splits 7 --group 4 --cov-fail-under=10
pytest --splits 7 --group 5 --cov-fail-under=10
pytest --splits 7 --group 6 --cov-fail-under=10
pytest --splits 7 --group 7 --cov-fail-under=10
pytest --splits 8 --group 1 --cov-fail-under=10
pytest --splits 8 --group 2 --cov-fail-under=10
pytest --splits 8 --group 3 --cov-fail-under=10
pytest --splits 8 --group 4 --cov-fail-under=10
pytest --splits 8 --group 5 --cov-fail-under=10
pytest --splits 8 --group 6 --cov-fail-under=10
pytest --splits 8 --group 7 --cov-fail-under=10
pytest --splits 8 --group 8 --cov-fail-under=10
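The converter diff below drops the module-local do_merge_index helper in favor of the shared streaming.base.util.merge_index. For reference, the merged result in both versions is a version-2 MDS index whose shard basenames are prefixed with each partition's sub-directory; a rough, hand-written sketch of that shape (shard and directory names are made up):

    merged_index = {
        'version': 2,
        'shards': [
            # Each entry comes from a partition's own index.json; its raw_data /
            # zip_data basenames are rewritten to include that partition's directory.
            {'raw_data': {'basename': '0/shard.00000.mds'}},  # other shard fields omitted
            {'raw_data': {'basename': '1/shard.00000.mds'}},
        ],
    }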
98 changes: 40 additions & 58 deletions streaming/base/converters/dataframe_to_mds.py
@@ -3,16 +3,16 @@

"""A utility to convert spark dataframe to MDS."""

import json
import logging
import os
import shutil
from collections.abc import Iterable
from typing import Any, Callable, Dict, Iterable, Optional, Tuple, Union
from typing import Any, Callable, Dict, Iterable, Optional, Tuple

import pandas as pd

from streaming.base.util import get_import_exception_message
from streaming.base.util import merge_index as do_merge_index

try:
from pyspark import TaskContext
@@ -119,52 +119,26 @@ def map_spark_dtype(spark_data_type: Any) -> str:
return schema_dict


def do_merge_index(partitions: Iterable, mds_path: Union[str, Tuple[str, str]]) -> None:
"""Merge index.json from partitions into one for streaming.
Args:
partitions (Iterable): partitions that contain pd.DataFrame
mds_path (Union[str, Tuple[str, str]]): (str,str)=(local,remote), str = local or remote
based on parse_uri(url) result
"""
if not partitions:
logger.warning('No partitions exist, no index merged')
return

shards = []

for row in partitions:
mds_partition_index = f'{row.mds_path}/{get_index_basename()}'
mds_partition_basename = os.path.basename(row.mds_path)
obj = json.load(open(mds_partition_index))
for i in range(len(obj['shards'])):
shard = obj['shards'][i]
for key in ('raw_data', 'zip_data'):
if shard.get(key):
basename = shard[key]['basename']
obj['shards'][i][key]['basename'] = os.path.join(mds_partition_basename,
basename)
shards += obj['shards']

obj = {
'version': 2,
'shards': shards,
}

if isinstance(mds_path, str):
mds_index = os.path.join(mds_path, get_index_basename())
else:
mds_index = os.path.join(mds_path[0], get_index_basename())

with open(mds_index, 'w') as out:
json.dump(obj, out)


def dataframeToMDS(dataframe: DataFrame,
merge_index: bool = True,
mds_kwargs: Optional[Dict[str, Any]] = None,
udf_iterable: Optional[Callable] = None,
udf_kwargs: Optional[Dict[str, Any]] = None) -> Tuple[Any, int]:
"""Deprecated API Signature.
To be replaced by dataframe_to_mds
"""
logger.warning(
'The DataframeToMDS signature has been deprecated and will be removed in Streaming 0.8. ' +
'Use dataframe_to_mds with the same arguments going forward')
return dataframe_to_mds(dataframe, merge_index, mds_kwargs, udf_iterable, udf_kwargs)


def dataframe_to_mds(dataframe: DataFrame,
merge_index: bool = True,
mds_kwargs: Optional[Dict[str, Any]] = None,
udf_iterable: Optional[Callable] = None,
udf_kwargs: Optional[Dict[str, Any]] = None) -> Tuple[Any, int]:
"""Execute a spark dataframe to MDS conversion process.
This method orchestrates the conversion of a spark dataframe into MDS format by processing the
@@ -194,19 +168,20 @@ def dataframeToMDS(dataframe: DataFrame,
"""

def write_mds(iterator: Iterable):
"""Worker node writes iterable to MDS datasets locally."""
context = TaskContext.get()

if context is not None:
id = context.taskAttemptId()
else:
raise RuntimeError('TaskContext.get() returns None')

if isinstance(mds_path, str): # local
output = os.path.join(mds_path, f'{id}')
out_file_path = output
if mds_path[1] == '': # only local
output = os.path.join(mds_path[0], f'{id}')
partition_path = (output, '')
else:
output = (os.path.join(mds_path[0], f'{id}'), os.path.join(mds_path[1], f'{id}'))
out_file_path = output[0]
partition_path = output

if mds_kwargs:
kwargs = mds_kwargs.copy()
@@ -215,7 +190,7 @@ def write_mds(iterator: Iterable):
kwargs = {}

if merge_index:
kwargs['keep_local'] = True # need to keep local to do merge
kwargs['keep_local'] = True # need to keep workers' locals to do merge

count = 0

@@ -237,10 +212,17 @@ def write_mds(iterator: Iterable):
raise RuntimeError(f'failed to write sample: {sample}') from ex
count += 1

yield pd.concat(
[pd.Series([out_file_path], name='mds_path'),
pd.Series([count], name='fail_count')],
axis=1)
yield pd.concat([
pd.Series([os.path.join(partition_path[0], get_index_basename())],
name='mds_path_local'),
pd.Series([
os.path.join(partition_path[1], get_index_basename())
if partition_path[1] != '' else ''
],
name='mds_path_remote'),
pd.Series([count], name='fail_count')
],
axis=1)

if dataframe is None or dataframe.isEmpty():
raise ValueError(f'Input dataframe is None or Empty!')
@@ -275,25 +257,25 @@ def write_mds(iterator: Iterable):
keep_local = False if 'keep_local' not in mds_kwargs else mds_kwargs['keep_local']
cu = CloudUploader.get(out, keep_local=keep_local)

# Fix output format as mds_path: Tuple => remote Str => local only
# Fix output format as mds_path: Tuple(local, remote)
if cu.remote is None:
mds_path = cu.local
mds_path = (cu.local, '')
else:
mds_path = (cu.local, cu.remote)

# Prepare partition schema
result_schema = StructType([
StructField('mds_path', StringType(), False),
StructField('mds_path_local', StringType(), False),
StructField('mds_path_remote', StringType(), False),
StructField('fail_count', IntegerType(), False)
])
partitions = dataframe.mapInPandas(func=write_mds, schema=result_schema).collect()

if merge_index:
do_merge_index(partitions, mds_path)
index_files = [(row['mds_path_local'], row['mds_path_remote']) for row in partitions]
do_merge_index(index_files, out, keep_local=keep_local, download_timeout=60)

if cu.remote is not None:
if merge_index:
cu.upload_file(get_index_basename())
if 'keep_local' in mds_kwargs and mds_kwargs['keep_local'] == False:
shutil.rmtree(cu.local, ignore_errors=True)

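Putting the converter changes together, a hedged end-to-end sketch of the new flow, assuming streaming.base.converters re-exports dataframe_to_mds and that its return value unpacks into the output path(s) plus a failed-sample count; the dataframe, column encodings, and bucket are illustrative:

    from pyspark.sql import SparkSession

    from streaming.base.converters import dataframe_to_mds

    spark = SparkSession.builder.getOrCreate()
    df = spark.createDataFrame([(i, f'sample {i}') for i in range(100)], ['id', 'text'])

    # Column encodings and the output location are illustrative.
    mds_kwargs = {
        'out': 's3://my-bucket/mds',  # or a plain local path
        'columns': {'id': 'int64', 'text': 'str'},
        'keep_local': False,
    }

    # Each worker writes one MDS sub-directory and reports its local/remote
    # index.json paths; with merge_index=True those partial indexes are merged
    # into a single index.json at the dataset root via streaming.base.util.merge_index.
    mds_path, fail_count = dataframe_to_mds(df, merge_index=True, mds_kwargs=mds_kwargs)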
