Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SNOW-1649172]: Fix loc set when setting DataFrame row with Series value #2213

Open
wants to merge 26 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 8 commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
3151ed7
[SNOW-1649172]: Fix `loc` set when setting DataFrame row with Series …
sfc-gh-rdurrani Sep 3, 2024
2bd792f
Add some more tests (including some negatives)
sfc-gh-rdurrani Sep 3, 2024
9e2a26d
Fix tests
sfc-gh-rdurrani Sep 3, 2024
66f01bc
minor changes
sfc-gh-vbudati Sep 5, 2024
c3b9582
fix test
sfc-gh-vbudati Sep 9, 2024
a9aceb9
Merge branch 'main' into rdurrani-SNOW-1649172
sfc-gh-vbudati Sep 9, 2024
c18ae1f
fix bug
sfc-gh-vbudati Sep 9, 2024
c159e3a
fix tests
sfc-gh-vbudati Sep 11, 2024
2289960
add example
sfc-gh-vbudati Sep 11, 2024
89401a8
Merge branch 'main' into rdurrani-SNOW-1649172
sfc-gh-rdurrani Sep 18, 2024
25ccbb9
Merge branch 'main' into rdurrani-SNOW-1649172
sfc-gh-rdurrani Sep 19, 2024
8f75bec
Merge branch 'main' into rdurrani-SNOW-1649172
sfc-gh-rdurrani Sep 19, 2024
f8797d8
Merge branch 'main' into rdurrani-SNOW-1649172
sfc-gh-rdurrani Sep 19, 2024
c252eb5
Merge branch 'main' into rdurrani-SNOW-1649172
sfc-gh-rdurrani Sep 19, 2024
247fb02
Merge branch 'main' into rdurrani-SNOW-1649172
sfc-gh-rdurrani Sep 27, 2024
81d8752
Address review comments
sfc-gh-rdurrani Oct 2, 2024
ad817c4
Merge branch 'main' into rdurrani-SNOW-1649172
sfc-gh-rdurrani Oct 2, 2024
88b8f86
Address potential bug
sfc-gh-rdurrani Oct 2, 2024
230a015
Merge branch 'main' into rdurrani-SNOW-1649172
sfc-gh-rdurrani Oct 3, 2024
8e4f3e7
Add tests
sfc-gh-rdurrani Oct 3, 2024
6f4ba8e
Fix tests
sfc-gh-rdurrani Oct 3, 2024
15c48e3
Address review comments
sfc-gh-rdurrani Oct 3, 2024
ecac5a0
Update docs
sfc-gh-rdurrani Oct 3, 2024
a3a1fb0
Refactor into helper method
sfc-gh-rdurrani Oct 3, 2024
6ade2b9
Update test coverage
sfc-gh-rdurrani Oct 4, 2024
13a54ab
Merge branch 'main' into rdurrani-SNOW-1649172
sfc-gh-rdurrani Oct 4, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,7 @@
- Fixed `inplace` argument for Series objects derived from DataFrame columns.
- Fixed a bug where `Series.reindex` and `DataFrame.reindex` did not update the result index's name correctly.
- Fixed a bug where `Series.take` did not error when `axis=1` was specified.
- Fixed `loc` set when setting row of DataFrame with Series value.
sfc-gh-rdurrani marked this conversation as resolved.
Show resolved Hide resolved

#### Behavior Change

Expand Down
4 changes: 4 additions & 0 deletions src/snowflake/snowpark/modin/pandas/indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -1041,6 +1041,9 @@ def __setitem__(
)
if item_is_2d_array:
item = pd.DataFrame(item)
frame_is_df_and_item_is_series = isinstance(item, pd.Series) and isinstance(
sfc-gh-rdurrani marked this conversation as resolved.
Show resolved Hide resolved
self.df, pd.DataFrame
)
item = item._query_compiler if isinstance(item, BasePandasDataset) else item
new_qc = self.qc.set_2d_labels(
index,
Expand All @@ -1049,6 +1052,7 @@ def __setitem__(
matching_item_columns_by_label=matching_item_columns_by_label,
matching_item_rows_by_label=matching_item_rows_by_label,
index_is_bool_indexer=index_is_bool_indexer,
frame_is_df_and_item_is_series=frame_is_df_and_item_is_series,
)

self.df._update_inplace(new_query_compiler=new_qc)
Expand Down
49 changes: 44 additions & 5 deletions src/snowflake/snowpark/modin/plugin/_internal/indexing_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1891,6 +1891,7 @@ def _set_2d_labels_helper_for_frame_item(
assert len(index.data_column_snowflake_quoted_identifiers) == len(
item.index_column_snowflake_quoted_identifiers
), "TODO: SNOW-966427 handle it well in multiindex case"

if not matching_item_rows_by_label:
index = index.ensure_row_position_column()
left_on = [index.row_position_snowflake_quoted_identifier]
Expand Down Expand Up @@ -2127,6 +2128,7 @@ def set_frame_2d_labels(
matching_item_rows_by_label: bool,
index_is_bool_indexer: bool,
deduplicate_columns: bool,
frame_is_df_and_item_is_series: bool,
) -> InternalFrame:
"""
Helper function to handle the general loc set functionality. The general idea here is to join the key from ``index``
Expand All @@ -2153,6 +2155,7 @@ def set_frame_2d_labels(
index_is_bool_indexer: if True, the index is a boolean indexer. Note we only handle boolean indexer with
item is a SnowflakeQueryCompiler here.
deduplicate_columns: if True, deduplicate columns from ``columns``.
frame_is_df_and_item_is_series: Whether item is from a Series object and is being assigned to a DataFrame object
Returns:
New frame where values have been set
"""
Expand Down Expand Up @@ -2215,6 +2218,36 @@ def set_frame_2d_labels(
index_is_frame = isinstance(index, InternalFrame)
item_is_frame = isinstance(item, InternalFrame)
item_is_scalar = is_scalar(item)
original_index = index
# If `item` is from a Series (rather than a Dataframe), flip the series item values to apply them
sfc-gh-rdurrani marked this conversation as resolved.
Show resolved Hide resolved
# across columns rather than rows.
if frame_is_df_and_item_is_series and (columns == slice(None) or len(columns) > 1): # type: ignore[arg-type]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you wrap this into a function, and use the function name to briefly describe what it does?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What does this condition mean: `(columns == slice(None) or len(columns) > 1)`?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This `type: ignore[arg-type]` actually indicates something is wrong. You didn't consider all type cases.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I believe this is checking to see if more than one column is being set. As for the arg-type, I think that is because it's ignoring the case where `columns` is a SnowflakeQueryCompiler. I've added a test for that case, and will fix it!

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should be done in _set_2d_labels_helper_for_frame_item

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I actually think it needs to be done in this method, since we need to modify `item` before the map is created (which is passed into `_set_2d_labels_helper_for_frame_item`), and we need the modified `item` later on in this method.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I can move this into the conditional for if item_is_frame though!

# If columns is slice(None), we are setting all columns in the InternalFrame.
matching_item_columns_by_label = True
col_len = (
len(internal_frame.data_column_snowflake_quoted_identifiers)
if columns == slice(None)
else len(columns) # type: ignore[arg-type]
)
item = get_item_series_as_single_row_frame(
item, col_len, move_index_to_cols=True
)

if is_scalar(index):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What happens if `index` is not a scalar?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If `index` is not a scalar, we don't have to append it to the item to match the index — it should either be `slice(None)` or an `InternalFrame`, which we handle in the rest of the method.

new_item = item.append_column("__index__", pandas_lit(index))
item = InternalFrame.create(
ordered_dataframe=new_item.ordered_dataframe,
data_column_pandas_labels=item.data_column_pandas_labels,
data_column_snowflake_quoted_identifiers=item.data_column_snowflake_quoted_identifiers,
data_column_pandas_index_names=item.data_column_pandas_index_names,
index_column_pandas_labels=item.index_column_pandas_labels,
index_column_snowflake_quoted_identifiers=[
new_item.data_column_snowflake_quoted_identifiers[-1]
],
data_column_types=item.cached_data_column_snowpark_pandas_types,
index_column_types=[item.cached_data_column_snowpark_pandas_types[-1]],
)
index = pd.Series([index])._query_compiler._modin_frame

assert not isinstance(index, slice) or index == slice(
None
Expand Down Expand Up @@ -2411,7 +2444,7 @@ def generate_updated_expr_for_existing_col(

if index_is_scalar:
col_obj = iff(
result_frame_index_col.equal_null(pandas_lit(index)),
result_frame_index_col.equal_null(pandas_lit(original_index)),
col_obj,
original_col,
)
Expand Down Expand Up @@ -2470,7 +2503,7 @@ def generate_updated_expr_for_new_col(
return SnowparkPandasColumn(pandas_lit(None), None)
if index_is_scalar:
new_column = iff(
result_frame_index_col.equal_null(pandas_lit(index)),
result_frame_index_col.equal_null(pandas_lit(original_index)),
new_column,
pandas_lit(None),
)
Expand Down Expand Up @@ -2606,7 +2639,6 @@ def set_frame_2d_positional(
index = _get_adjusted_key_frame_by_row_pos_int_frame(internal_frame, index)

assert isinstance(index_data_type, (_IntegralType, BooleanType))

if isinstance(item, InternalFrame):
# If item is Series (rather than a Dataframe), then we need to flip the series item values so they apply across
# columns rather than rows.
Expand Down Expand Up @@ -2920,7 +2952,9 @@ def get_kv_frame_from_index_and_item_frames(


def get_item_series_as_single_row_frame(
item: InternalFrame, num_columns: int
item: InternalFrame,
num_columns: int,
move_index_to_cols: Optional[bool] = False,
) -> InternalFrame:
"""
Get an internal frame that transpose single data column into frame with single row. For example, if the
Expand All @@ -2942,13 +2976,18 @@ def get_item_series_as_single_row_frame(
----------
num_columns: Number of columns in the return frame
item: Item frame that contains a single column of values.
move_index_to_cols: Whether to use the index as the column names.

Returns
-------
Frame containing single row with columns for each row.
"""
item = item.ensure_row_position_column()
item_series_pandas_labels = list(range(num_columns))
item_series_pandas_labels = (
list(range(num_columns))
if not move_index_to_cols
else item.index_columns_pandas_index().values
)

# This is a 2 step process.
#
Expand Down
10 changes: 6 additions & 4 deletions src/snowflake/snowpark/modin/plugin/_internal/isin_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,10 +148,11 @@ def compute_isin_with_series(
slice(None),
[agg_label],
values_series,
False,
True,
False,
False,
matching_item_columns_by_label=False,
matching_item_rows_by_label=True,
index_is_bool_indexer=False,
deduplicate_columns=False,
frame_is_df_and_item_is_series=False,
)

# apply isin operation for all columns except the appended agg_label/agg_identifier column.
Expand Down Expand Up @@ -272,6 +273,7 @@ def compute_isin_with_dataframe(
True,
False,
False,
False,
)

isin_identifiers = [
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9275,6 +9275,7 @@ def set_2d_labels(
matching_item_rows_by_label: bool,
index_is_bool_indexer: bool,
deduplicate_columns: bool = False,
frame_is_df_and_item_is_series: bool = False,
) -> "SnowflakeQueryCompiler":
"""
Create a new SnowflakeQueryCompiler with indexed columns and rows replaced by item.
Expand All @@ -9299,6 +9300,7 @@ def set_2d_labels(
index_is_bool_indexer: if True, the index is a boolean indexer.
deduplicate_columns: if True, deduplicate columns from ``columns``, e.g., if columns = ["A","A"], only the
second "A" column will be used.
frame_is_df_and_item_is_series: Whether item is from a Series and is being set to a DataFrame object
Returns:
Updated SnowflakeQueryCompiler
"""
Expand All @@ -9324,6 +9326,7 @@ def set_2d_labels(
matching_item_rows_by_label=matching_item_rows_by_label,
index_is_bool_indexer=index_is_bool_indexer,
deduplicate_columns=deduplicate_columns,
frame_is_df_and_item_is_series=frame_is_df_and_item_is_series,
)

return SnowflakeQueryCompiler(result_frame)
Expand Down
65 changes: 65 additions & 0 deletions tests/integ/modin/frame/test_iloc.py
Original file line number Diff line number Diff line change
Expand Up @@ -3205,3 +3205,68 @@ def test_raise_set_cell_with_list_like_value_error():
s.iloc[0] = [0, 0]
with pytest.raises(NotImplementedError):
s.to_frame().iloc[0, 0] = [0, 0]


@sql_count_checker(query_count=1, join_count=3)
@pytest.mark.parametrize("index", [list("ABC"), [0, 1, 2]])
def test_df_iloc_set_row_from_series(index):
native_df = native_pd.DataFrame([[1, 2, 3], [4, 5, 6]], columns=list("ABC"))
snow_df = pd.DataFrame(native_df)

def ilocset(df):
series = (
pd.Series([1, 4, 9], index=index)
if isinstance(df, pd.DataFrame)
else native_pd.Series([1, 4, 9], index=index)
)
df.iloc[1] = series
return df

eval_snowpark_pandas_result(
snow_df,
native_df,
ilocset,
)


@sql_count_checker(query_count=1, join_count=3)
@pytest.mark.parametrize("index", [[3, 4, 5], [0, 1, 2]])
def test_df_iloc_full_set_row_from_series(index):
native_df = native_pd.DataFrame([[1, 2, 3], [4, 5, 6]])
snow_df = pd.DataFrame(native_df)

def ilocset(df):
series = (
pd.Series([1, 4, 9], index=index)
if isinstance(df, pd.DataFrame)
else native_pd.Series([1, 4, 9], index=index)
)
df.iloc[:] = series
return df

eval_snowpark_pandas_result(
snow_df,
native_df,
ilocset,
)


@sql_count_checker(query_count=1, join_count=3)
def test_df_iloc_full_set_row_from_series_int_and_string_indexes():
sfc-gh-rdurrani marked this conversation as resolved.
Show resolved Hide resolved
native_df = native_pd.DataFrame([[1, 2, 3], [4, 5, 6]], columns=list("ABC"))
snow_df = pd.DataFrame(native_df)

def ilocset(df):
series = (
pd.Series([1, 4, 9], index=list("ABC"))
if isinstance(df, pd.DataFrame)
else native_pd.Series([1, 4, 9], index=list("ABC"))
)
df.iloc[:] = series
return df

eval_snowpark_pandas_result(
snow_df,
native_df,
ilocset,
)
68 changes: 68 additions & 0 deletions tests/integ/modin/frame/test_loc.py
Original file line number Diff line number Diff line change
Expand Up @@ -4072,3 +4072,71 @@ def test_df_loc_get_with_timedelta_and_none_key():
# Compare with an empty DataFrame, since native pandas raises a KeyError.
expected_df = native_pd.DataFrame()
assert_frame_equal(snow_df.loc[None], expected_df, check_column_type=False)


@sql_count_checker(query_count=2, join_count=4)
@pytest.mark.parametrize("index", [list("ABC"), [0, 1, 2]])
def test_df_loc_set_row_from_series(index):
native_df = native_pd.DataFrame([[1, 2, 3], [4, 5, 6]], columns=list("ABC"))
snow_df = pd.DataFrame(native_df)

def locset(df):
series = (
pd.Series([1, 4, 9], index=index)
if isinstance(df, pd.DataFrame)
else native_pd.Series([1, 4, 9], index=index)
)
df.loc[1] = series
return df

eval_snowpark_pandas_result(
snow_df,
native_df,
locset,
)


@sql_count_checker(query_count=2, join_count=1)
@pytest.mark.parametrize("index", [[3, 4, 5], [0, 1, 2]])
def test_df_loc_full_set_row_from_series_pandas_errors(index):
native_df = native_pd.DataFrame([[1, 2, 3], [4, 5, 6]])
snow_df = pd.DataFrame(native_df)

with pytest.raises(ValueError, match="setting an array element with a sequence."):
native_df.loc[:] = native_pd.Series([1, 4, 9], index=index)

def locset(df):
series = (
pd.Series([1, 4, 9], index=index)
if isinstance(df, pd.DataFrame)
else native_pd.Series([1, 4, 9], index=index)
)
if isinstance(df, pd.DataFrame):
df.loc[:] = series
else:
if index == [0, 1, 2]:
sfc-gh-rdurrani marked this conversation as resolved.
Show resolved Hide resolved
df.loc[0] = series
df.loc[1] = None
else:
df.loc[[0, 1]] = None
return df

eval_snowpark_pandas_result(
snow_df,
native_df,
locset,
)


@sql_count_checker(query_count=1)
def test_df_loc_full_set_row_from_series_errors():
# We error here because our join columns are an int (item.index)
# and a string (value.index) column respectively, and we do not
# support joins between those.
snow_df = pd.DataFrame([[1, 2, 3], [4, 5, 6]], columns=list("ABC"))

with pytest.raises(
SnowparkSQLException, match="Numeric value 'A' is not recognized"
):
snow_df.loc[:] = pd.Series([1, 4, 9], index=list("ABC"))
snow_df.to_pandas() # Force materialization.
26 changes: 15 additions & 11 deletions tests/integ/modin/series/test_loc.py
Original file line number Diff line number Diff line change
Expand Up @@ -943,19 +943,23 @@ def set_loc_helper(ser):
else native_item_ser
)

expected_join_count = 1 if not start and not stop and not step else 4

with SqlCounter(query_count=1, join_count=expected_join_count):
if slice_len == 0:
# pandas can fail in this case, so we skip call loc for it, see more below in
# test_series_loc_set_key_slice_with_series_item_pandas_bug
set_loc_helper(snow_ser)
# snow_ser should not change when slice_len = 0
if slice_len == 0:
# pandas can fail in this case, so we skip call loc for it, see more below in
# test_series_loc_set_key_slice_with_series_item_pandas_bug
set_loc_helper(snow_ser)
# snow_ser should not change when slice_len = 0
with SqlCounter(query_count=1):
assert_snowpark_pandas_equal_to_pandas(snow_ser, native_ser)
else:
native_res = set_loc_helper(native_ser)
if is_scalar(native_res):
with SqlCounter(query_count=0):
snow_res = set_loc_helper(snow_ser)
assert snow_res == native_res
else:
eval_snowpark_pandas_result(
snow_ser, native_ser, set_loc_helper, inplace=True
)
with SqlCounter(query_count=1, join_count=4):
snow_res = set_loc_helper(snow_ser)
assert_series_equal(snow_res, native_res)


@pytest.mark.parametrize(
Expand Down
Loading