Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SNOW-1348621 Fix bug and add support for DataFrame key in getitem #1445

Merged
merged 4 commits into from
Apr 29, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/snowflake/snowpark/modin/pandas/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3892,6 +3892,10 @@ def __getitem__(self, key):
if isinstance(self, pd.Series):
return self.loc[key]

# Sometimes the result of a callable is a DataFrame (e.g. df[df > 0]) - use where.
sfc-gh-vbudati marked this conversation as resolved.
Show resolved Hide resolved
elif isinstance(key, pd.DataFrame):
sfc-gh-azhan marked this conversation as resolved.
Show resolved Hide resolved
return self.where(cond=key)

# If the object is a boolean list-like object, use .loc[key] to filter index.
# The if statement is structured this way to avoid calling dtype and reduce query count.
if isinstance(key, pd.Series):
Expand Down
41 changes: 26 additions & 15 deletions tests/integ/modin/frame/test_getitem.py
Original file line number Diff line number Diff line change
Expand Up @@ -367,24 +367,35 @@ def test_df_getitem_with_multiindex(
)


@sql_count_checker(query_count=0)
def test_getitem_lambda_dataframe():
@sql_count_checker(query_count=1)
def test_df_getitem_lambda_dataframe():
data = {"a": [1, 2, 3], "b": [4, 5, 6]}
snow_df = pd.DataFrame(data)

# TODO SNOW-980818: Raise ValueError until dataframe key support is implemented
# The ValueError is being thrown by _validate_get_locator_key(col_key) in indexer.py
with pytest.raises(ValueError):
snow_df[lambda x: x < 2]
eval_snowpark_pandas_result(*create_test_dfs(data), lambda df: df[lambda x: x < 2])


@sql_count_checker(query_count=1)
def test_getitem_lambda_series():
data = {"a": 1, "b": 2, "c": 3}
snow_ser = pd.Series(data)
native_ser = native_pd.Series(data)
def test_df_getitem_boolean_df_comparator():
sfc-gh-azhan marked this conversation as resolved.
Show resolved Hide resolved
"""
DataFrame keys (as a result of callables) are valid for getitem but not loc and iloc get.
"""
eval_snowpark_pandas_result(
*create_test_dfs(
[[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]]
sfc-gh-vbudati marked this conversation as resolved.
Show resolved Hide resolved
),
lambda df: df[df > 7]
)

def helper(ser):
return ser[lambda x: x < 2]

eval_snowpark_pandas_result(snow_ser, native_ser, helper)
@sql_count_checker(query_count=1)
def test_df_getitem_boolean_df_comparator_datetime_index():
"""
Based on bug from SNOW-1348621.
Code adapted from the pandas 10 minute quick start (https://pandas.pydata.org/docs/user_guide/10min.html).
"""
dates = native_pd.date_range("20130101", periods=6)
data = np.random.randn(6, 4)
native_df = native_pd.DataFrame(data, index=dates, columns=list("ABCD"))
snow_df = pd.DataFrame(native_df)
eval_snowpark_pandas_result(
snow_df, native_df, lambda df: df[df > 0], check_freq=False
)
14 changes: 13 additions & 1 deletion tests/integ/modin/series/test_getitem.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ def test_series_getitem_with_slice(
[("foo", "one"), ("bar", "two")],
],
)
def test_df_getitem_with_multiindex(
def test_series_getitem_with_multiindex(
key, default_index_native_series, multiindex_native
):
expected_join_count = 0 if isinstance(key, slice) or isinstance(key, str) else 1
Expand All @@ -201,3 +201,15 @@ def test_df_getitem_with_multiindex(
lambda ser: ser[key],
check_index_type=False,
)


@sql_count_checker(query_count=1)
def test_series_getitem_lambda_series():
sfc-gh-azhan marked this conversation as resolved.
Show resolved Hide resolved
data = {"a": 1, "b": 2, "c": 3, "d": -1, "e": 0, "f": 10}
snow_ser = pd.Series(data)
native_ser = native_pd.Series(data)

def helper(ser):
return ser[lambda x: x < 2]

eval_snowpark_pandas_result(snow_ser, native_ser, helper)
Loading