From a52d2ef0ae72cc84ced442a5026d3a69325b6885 Mon Sep 17 00:00:00 2001 From: Varnika Budati Date: Fri, 26 Apr 2024 15:45:45 -0700 Subject: [PATCH 1/3] Add support for DataFrame key in getitem --- src/snowflake/snowpark/modin/pandas/base.py | 4 +++ tests/integ/modin/frame/test_getitem.py | 31 ++++++++++----------- tests/integ/modin/series/test_getitem.py | 14 +++++++++- 3 files changed, 31 insertions(+), 18 deletions(-) diff --git a/src/snowflake/snowpark/modin/pandas/base.py b/src/snowflake/snowpark/modin/pandas/base.py index c06ebeae1b7..84a0ea1a082 100644 --- a/src/snowflake/snowpark/modin/pandas/base.py +++ b/src/snowflake/snowpark/modin/pandas/base.py @@ -3892,6 +3892,10 @@ def __getitem__(self, key): if isinstance(self, pd.Series): return self.loc[key] + # A DataFrame key, passed directly (e.g. df[df > 0]) or returned by a callable key (e.g. df[lambda x: x < 2]), is applied as an element-wise mask - use where. + elif isinstance(key, pd.DataFrame): + return self.where(cond=key) + # If the object is a boolean list-like object, use .loc[key] to filter index. # The if statement is structured this way to avoid calling dtype and reduce query count. 
if isinstance(key, pd.Series): diff --git a/tests/integ/modin/frame/test_getitem.py b/tests/integ/modin/frame/test_getitem.py index a24082e067d..b70e19869f2 100644 --- a/tests/integ/modin/frame/test_getitem.py +++ b/tests/integ/modin/frame/test_getitem.py @@ -367,24 +367,21 @@ def test_df_getitem_with_multiindex( ) -@sql_count_checker(query_count=0) -def test_getitem_lambda_dataframe(): +@sql_count_checker(query_count=1) +def test_df_getitem_lambda_dataframe(): data = {"a": [1, 2, 3], "b": [4, 5, 6]} - snow_df = pd.DataFrame(data) - - # TODO SNOW-980818: Raise ValueError until dataframe key support is implemented - # The ValueError is being thrown by _validate_get_locator_key(col_key) in indexer.py - with pytest.raises(ValueError): - snow_df[lambda x: x < 2] + eval_snowpark_pandas_result(*create_test_dfs(data), lambda df: df[lambda x: x < 2]) @sql_count_checker(query_count=1) -def test_getitem_lambda_series(): - data = {"a": 1, "b": 2, "c": 3} - snow_ser = pd.Series(data) - native_ser = native_pd.Series(data) - - def helper(ser): - return ser[lambda x: x < 2] - - eval_snowpark_pandas_result(snow_ser, native_ser, helper) +def test_df_getitem_boolean_df_comparator(): + """ + Based on bug from SNOW-1348621. + DataFrame keys (as a result of callables) are valid for getitem but not loc and iloc get. 
+ """ + eval_snowpark_pandas_result( + *create_test_dfs( + [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]] + ), + lambda df: df[df > 0] + ) diff --git a/tests/integ/modin/series/test_getitem.py b/tests/integ/modin/series/test_getitem.py index feb22fdbb4f..08df540961f 100644 --- a/tests/integ/modin/series/test_getitem.py +++ b/tests/integ/modin/series/test_getitem.py @@ -187,7 +187,7 @@ def test_series_getitem_with_slice( [("foo", "one"), ("bar", "two")], ], ) -def test_df_getitem_with_multiindex( +def test_series_getitem_with_multiindex( key, default_index_native_series, multiindex_native ): expected_join_count = 0 if isinstance(key, slice) or isinstance(key, str) else 1 @@ -201,3 +201,15 @@ def test_df_getitem_with_multiindex( lambda ser: ser[key], check_index_type=False, ) + + +@sql_count_checker(query_count=1) +def test_series_getitem_lambda_series(): + data = {"a": 1, "b": 2, "c": 3} + snow_ser = pd.Series(data) + native_ser = native_pd.Series(data) + + def helper(ser): + return ser[lambda x: x < 2] + + eval_snowpark_pandas_result(snow_ser, native_ser, helper) From 86377ad00d24b743fb5db37c813f27e390140fcf Mon Sep 17 00:00:00 2001 From: Varnika Budati Date: Fri, 26 Apr 2024 16:03:51 -0700 Subject: [PATCH 2/3] Change the comparisons in tests --- tests/integ/modin/frame/test_getitem.py | 2 +- tests/integ/modin/series/test_getitem.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/integ/modin/frame/test_getitem.py b/tests/integ/modin/frame/test_getitem.py index b70e19869f2..8db5f2a3be1 100644 --- a/tests/integ/modin/frame/test_getitem.py +++ b/tests/integ/modin/frame/test_getitem.py @@ -383,5 +383,5 @@ def test_df_getitem_boolean_df_comparator(): *create_test_dfs( [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]] ), - lambda df: df[df > 0] + lambda df: df[df > 7] ) diff --git a/tests/integ/modin/series/test_getitem.py b/tests/integ/modin/series/test_getitem.py index 08df540961f..8d2564571b2 100644 --- 
a/tests/integ/modin/series/test_getitem.py +++ b/tests/integ/modin/series/test_getitem.py @@ -205,7 +205,7 @@ def test_series_getitem_with_multiindex( @sql_count_checker(query_count=1) def test_series_getitem_lambda_series(): - data = {"a": 1, "b": 2, "c": 3} + data = {"a": 1, "b": 2, "c": 3, "d": -1, "e": 0, "f": 10} snow_ser = pd.Series(data) native_ser = native_pd.Series(data) From 6cc36659085ead775f3b877d6153aeb2b9a0660f Mon Sep 17 00:00:00 2001 From: Varnika Budati Date: Fri, 26 Apr 2024 16:44:31 -0700 Subject: [PATCH 3/3] Add WAM bug test --- tests/integ/modin/frame/test_getitem.py | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/tests/integ/modin/frame/test_getitem.py b/tests/integ/modin/frame/test_getitem.py index 8db5f2a3be1..89bd0580fd9 100644 --- a/tests/integ/modin/frame/test_getitem.py +++ b/tests/integ/modin/frame/test_getitem.py @@ -376,7 +376,6 @@ def test_df_getitem_lambda_dataframe(): @sql_count_checker(query_count=1) def test_df_getitem_boolean_df_comparator(): """ - Based on bug from SNOW-1348621. DataFrame keys (as a result of callables) are valid for getitem but not loc and iloc get. """ eval_snowpark_pandas_result( @@ -385,3 +384,18 @@ def test_df_getitem_boolean_df_comparator(): ), lambda df: df[df > 7] ) + + +@sql_count_checker(query_count=1) +def test_df_getitem_boolean_df_comparator_datetime_index(): + """ + Based on bug from SNOW-1348621. + Code adapted from the pandas 10 minute quick start (https://pandas.pydata.org/docs/user_guide/10min.html). + """ + dates = native_pd.date_range("20130101", periods=6) + data = np.random.randn(6, 4) + native_df = native_pd.DataFrame(data, index=dates, columns=list("ABCD")) + snow_df = pd.DataFrame(native_df) + eval_snowpark_pandas_result( + snow_df, native_df, lambda df: df[df > 0], check_freq=False + )