Skip to content

Commit

Permalink
Offload exception to mds_write. (#528)
Browse files Browse the repository at this point in the history
* remove exception throwing and let mds_write handle it

* fix tests

* update

* update

* fix lints

* fix lints

---------

Co-authored-by: Karan Jariwala <[email protected]>
  • Loading branch information
XiaohanZhangCMU and karan6181 authored Dec 12, 2023
1 parent d449b46 commit 13fb2eb
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 32 deletions.
27 changes: 5 additions & 22 deletions streaming/base/converters/dataframe_to_mds.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ def dataframeToMDS(dataframe: DataFrame,
merge_index: bool = True,
mds_kwargs: Optional[Dict[str, Any]] = None,
udf_iterable: Optional[Callable] = None,
udf_kwargs: Optional[Dict[str, Any]] = None) -> Tuple[Any, int]:
udf_kwargs: Optional[Dict[str, Any]] = None) -> Tuple[str, str]:
"""Deprecated API Signature.
To be replaced by dataframe_to_mds
Expand All @@ -138,7 +138,7 @@ def dataframe_to_mds(dataframe: DataFrame,
merge_index: bool = True,
mds_kwargs: Optional[Dict[str, Any]] = None,
udf_iterable: Optional[Callable] = None,
udf_kwargs: Optional[Dict[str, Any]] = None) -> Tuple[Any, int]:
udf_kwargs: Optional[Dict[str, Any]] = None) -> Tuple[str, str]:
"""Execute a spark dataframe to MDS conversion process.
This method orchestrates the conversion of a spark dataframe into MDS format by processing the
Expand All @@ -157,8 +157,6 @@ def dataframe_to_mds(dataframe: DataFrame,
Returns:
mds_path (str or (str, str)): the actual local and remote paths that were used
fail_count (int): number of records that failed to be converted
Notes:
- The method creates a SparkSession if not already available.
- The 'udf_kwargs' dictionaries can be used to pass additional
Expand Down Expand Up @@ -192,8 +190,6 @@ def write_mds(iterator: Iterable):
if merge_index:
kwargs['keep_local'] = True # need to keep workers' locals to do merge

count = 0

with MDSWriter(**kwargs) as mds_writer:
for pdf in iterator:
if udf_iterable is not None:
Expand All @@ -206,11 +202,7 @@ def write_mds(iterator: Iterable):
f'{type(records)}')

for sample in records:
try:
mds_writer.write(sample)
except Exception as ex:
raise RuntimeError(f'failed to write sample: {sample}') from ex
count += 1
mds_writer.write(sample)

yield pd.concat([
pd.Series([os.path.join(partition_path[0], get_index_basename())],
Expand All @@ -219,8 +211,7 @@ def write_mds(iterator: Iterable):
os.path.join(partition_path[1], get_index_basename())
if partition_path[1] != '' else ''
],
name='mds_path_remote'),
pd.Series([count], name='fail_count')
name='mds_path_remote')
],
axis=1)

Expand Down Expand Up @@ -267,7 +258,6 @@ def write_mds(iterator: Iterable):
result_schema = StructType([
StructField('mds_path_local', StringType(), False),
StructField('mds_path_remote', StringType(), False),
StructField('fail_count', IntegerType(), False)
])
partitions = dataframe.mapInPandas(func=write_mds, schema=result_schema).collect()

Expand All @@ -285,11 +275,4 @@ def write_mds(iterator: Iterable):
if not keep_local_files:
shutil.rmtree(cu.local, ignore_errors=True)

sum_fail_count = 0
for row in partitions:
sum_fail_count += row['fail_count']

if sum_fail_count > 0:
logger.warning(
f'Total failed records = {sum_fail_count}\nOverall records {dataframe.count()}')
return mds_path, sum_fail_count
return mds_path
20 changes: 10 additions & 10 deletions tests/base/converters/test_dataframe_to_mds.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,13 +68,13 @@ def test_end_to_end_conversion_local_nocolumns(self, dataframe: Any, keep_local:
}

with pytest.raises(ValueError, match=f'.*is not supported by MDSWriter.*'):
_, _ = dataframe_to_mds(dataframe.select(col('id'), col('dept'), col('properties')),
merge_index=merge_index,
mds_kwargs=mds_kwargs)
_ = dataframe_to_mds(dataframe.select(col('id'), col('dept'), col('properties')),
merge_index=merge_index,
mds_kwargs=mds_kwargs)

_, _ = dataframe_to_mds(dataframe.select(col('id'), col('dept')),
merge_index=merge_index,
mds_kwargs=mds_kwargs)
_ = dataframe_to_mds(dataframe.select(col('id'), col('dept')),
merge_index=merge_index,
mds_kwargs=mds_kwargs)

if keep_local:
assert len(os.listdir(out)) > 0, f'{out} is empty'
Expand Down Expand Up @@ -115,7 +115,7 @@ def test_end_to_end_conversion_local_decimal(self, decimal_dataframe: Any, use_c
if use_columns:
mds_kwargs['columns'] = user_defined_columns

_, _ = dataframe_to_mds(decimal_dataframe, merge_index=True, mds_kwargs=mds_kwargs)
_ = dataframe_to_mds(decimal_dataframe, merge_index=True, mds_kwargs=mds_kwargs)
assert len(os.listdir(out)) > 0, f'{out} is empty'

def test_user_defined_columns(self, dataframe: Any, local_remote_dir: Tuple[str, str]):
Expand All @@ -126,7 +126,7 @@ def test_user_defined_columns(self, dataframe: Any, local_remote_dir: Tuple[str,
'columns': user_defined_columns,
}
with pytest.raises(ValueError, match=f'.*is not a column of input dataframe.*'):
_, _ = dataframe_to_mds(dataframe, merge_index=False, mds_kwargs=mds_kwargs)
_ = dataframe_to_mds(dataframe, merge_index=False, mds_kwargs=mds_kwargs)

user_defined_columns = {'id': 'strr', 'dept': 'str'}

Expand All @@ -135,7 +135,7 @@ def test_user_defined_columns(self, dataframe: Any, local_remote_dir: Tuple[str,
'columns': user_defined_columns,
}
with pytest.raises(ValueError, match=f'.* is not supported by MDSWriter.*'):
_, _ = dataframe_to_mds(dataframe, merge_index=False, mds_kwargs=mds_kwargs)
_ = dataframe_to_mds(dataframe, merge_index=False, mds_kwargs=mds_kwargs)

@pytest.mark.parametrize('keep_local', [True, False])
@pytest.mark.parametrize('merge_index', [True, False])
Expand All @@ -154,7 +154,7 @@ def test_end_to_end_conversion_local(self, dataframe: Any, keep_local: bool, mer
'size_limit': 1 << 26
}

_, _ = dataframe_to_mds(dataframe, merge_index=merge_index, mds_kwargs=mds_kwargs)
_ = dataframe_to_mds(dataframe, merge_index=merge_index, mds_kwargs=mds_kwargs)

if keep_local:
assert len(os.listdir(out)) > 0, f'{out} is empty'
Expand Down

0 comments on commit 13fb2eb

Please sign in to comment.