diff --git a/src/snowflake/snowpark/modin/pandas/groupby.py b/src/snowflake/snowpark/modin/pandas/groupby.py index bd9cff241bc..7f20c8c9ad7 100644 --- a/src/snowflake/snowpark/modin/pandas/groupby.py +++ b/src/snowflake/snowpark/modin/pandas/groupby.py @@ -41,6 +41,7 @@ # Snowpark pandas API version from snowflake.snowpark.modin.pandas.series import Series from snowflake.snowpark.modin.pandas.utils import ( + extract_validate_and_try_convert_named_aggs_from_kwargs, raise_if_native_pandas_objects, validate_and_try_convert_agg_func_arg_func_to_str, ) @@ -570,10 +571,17 @@ def aggregate( ErrorMessage.not_implemented( "axis other than 0 is not supported" ) # pragma: no cover - - func = validate_and_try_convert_agg_func_arg_func_to_str( - agg_func=func, obj=self, allow_duplication=True, axis=self._axis - ) + if func is None: + func = extract_validate_and_try_convert_named_aggs_from_kwargs( + obj=self, allow_duplication=True, axis=self._axis, **kwargs + ) + else: + func = validate_and_try_convert_agg_func_arg_func_to_str( + agg_func=func, + obj=self, + allow_duplication=True, + axis=self._axis, + ) if isinstance(func, str): # Using "getattr" here masks possible AttributeError which we throw diff --git a/src/snowflake/snowpark/modin/pandas/utils.py b/src/snowflake/snowpark/modin/pandas/utils.py index 4a4ccb215a2..7ad2b192943 100644 --- a/src/snowflake/snowpark/modin/pandas/utils.py +++ b/src/snowflake/snowpark/modin/pandas/utils.py @@ -45,6 +45,7 @@ FactoryDispatcher, ) from snowflake.snowpark.modin.plugin._internal.aggregation_utils import ( + AggFuncWithLabel, get_pandas_aggr_func_name, ) from snowflake.snowpark.modin.plugin.utils.error_message import ErrorMessage @@ -539,6 +540,85 @@ def _try_convert_single_builtin_func_to_str(f): return _try_convert_single_builtin_func_to_str(fn) +def extract_validate_and_try_convert_named_aggs_from_kwargs( + obj: object, allow_duplication: bool, axis: int, **kwargs +) -> AggFuncType: + """ + Attempt to extract pd.NamedAgg (or tuples 
of the same format) from the kwargs. + + kwargs: dict + The kwargs to extract from. + + Returns: + A dictionary mapping columns to a tuple containing the aggregation to perform, as well + as the pandas label to give the aggregated column. + """ + from snowflake.snowpark.modin.pandas.groupby import SeriesGroupBy + + named_aggs = {} + accepted_keys = [] + columns = obj._query_compiler.columns + for key, value in kwargs.items(): + if isinstance(value, pd.NamedAgg) or ( + isinstance(value, tuple) and len(value) == 2 + ): + if axis == 0: + # If axis == 1, we would need a query to materialize the index to check its existence + # so we defer the error checking to later. + if value[0] not in columns: + raise KeyError(f"Column(s) ['{value[0]}'] do not exist") + + if value[0] in named_aggs: + if not isinstance(named_aggs[value[0]], list): + named_aggs[value[0]] = [named_aggs[value[0]]] + named_aggs[value[0]] += [ + AggFuncWithLabel(func=value[1], pandas_label=key) + ] + else: + named_aggs[value[0]] = AggFuncWithLabel(func=value[1], pandas_label=key) + accepted_keys += [key] + elif isinstance(obj, SeriesGroupBy): + col_name = obj._df._query_compiler.columns[0] + if col_name not in named_aggs: + named_aggs[col_name] = AggFuncWithLabel(func=value, pandas_label=key) + else: + if not isinstance(named_aggs[col_name], list): + named_aggs[col_name] = [named_aggs[col_name]] + named_aggs[col_name] += [AggFuncWithLabel(func=value, pandas_label=key)] + accepted_keys += [key] + + if len(named_aggs.keys()) == 0: + ErrorMessage.not_implemented( + "Must provide value for 'func' argument, func=None is currently not supported with Snowpark pandas" + ) + + if any(key not in accepted_keys for key in kwargs.keys()): + # For compatibility with pandas errors. Otherwise, we would just ignore + # those kwargs. 
+ raise TypeError("Must provide 'func' or tuples of '(column, aggfunc).") + + validated_named_aggs = {} + for key, value in named_aggs.items(): + if isinstance(value, list): + validated_named_aggs[key] = [ + AggFuncWithLabel( + func=validate_and_try_convert_agg_func_arg_func_to_str( + v.func, obj, allow_duplication, axis + ), + pandas_label=v.pandas_label, + ) + for v in value + ] + else: + validated_named_aggs[key] = AggFuncWithLabel( + func=validate_and_try_convert_agg_func_arg_func_to_str( + value.func, obj, allow_duplication, axis + ), + pandas_label=value.pandas_label, + ) + return validated_named_aggs + + def validate_and_try_convert_agg_func_arg_func_to_str( agg_func: AggFuncType, obj: object, allow_duplication: bool, axis: int ) -> AggFuncType: @@ -580,12 +660,6 @@ def validate_and_try_convert_agg_func_arg_func_to_str( """ if agg_func is None: - # Snowpark pandas only support func argument at this moment. - # TODO (SNOW-902943): pandas allows usage of NamedAgg in kwargs to configure - # tuples of (columns, agg_func) with rename. For example: - # df.groupby('A').agg(b_min=pd.NamedAgg(column='B', aggfunc='min')), which applies - # min function on column 'B', and uses 'b_min' as the new column name. - # Once supported, refine the check to check both. ErrorMessage.not_implemented( "Must provide value for 'func' argument, func=None is currently not supported with Snowpark pandas" ) diff --git a/src/snowflake/snowpark/modin/plugin/PANDAS_CHANGELOG.md b/src/snowflake/snowpark/modin/plugin/PANDAS_CHANGELOG.md index ca568d9a234..7d5100ffb17 100644 --- a/src/snowflake/snowpark/modin/plugin/PANDAS_CHANGELOG.md +++ b/src/snowflake/snowpark/modin/plugin/PANDAS_CHANGELOG.md @@ -21,6 +21,7 @@ ### New Features - Added partial support for `SeriesGroupBy.apply` (where the `SeriesGrouBy` is obtained through `DataFrameGroupBy.__getitem__`). +- Added support for `pd.NamedAgg` in `DataFrameGroupBy.agg` and `SeriesGroupBy.agg`. 
## 1.15.0a1 (2024-05-03) diff --git a/src/snowflake/snowpark/modin/plugin/_internal/aggregation_utils.py b/src/snowflake/snowpark/modin/plugin/_internal/aggregation_utils.py index eb54d15b794..e7dd73de499 100644 --- a/src/snowflake/snowpark/modin/plugin/_internal/aggregation_utils.py +++ b/src/snowflake/snowpark/modin/plugin/_internal/aggregation_utils.py @@ -267,6 +267,19 @@ def _columns_coalescing_idxmax_idxmin_helper( } +class AggFuncWithLabel(NamedTuple): + """ + This class is used to process NamedAgg's internally, and represents an AggFunc that + also includes a label to be used on the column that it generates. + """ + + # The aggregate function + func: AggFuncTypeBase + + # The label to provide the new column produced by `func`. + pandas_label: Hashable + + class AggFuncInfo(NamedTuple): """ Information needed to distinguish between dummy and normal aggregate functions. @@ -278,6 +291,10 @@ class AggFuncInfo(NamedTuple): # If true, the aggregate function is applied to "NULL" rather than a column is_dummy_agg: bool + # If specified, the pandas label to provide the new column generated by this aggregate + # function. Used in conjunction with pd.NamedAgg. + post_agg_pandas_label: Optional[Hashable] = None + def _columns_coalescing_min(*cols: SnowparkColumn) -> Callable: """ @@ -503,6 +520,8 @@ def is_supported_snowflake_agg_func( Returns: is_valid: bool. Whether it is valid to implement with snowflake or not. 
""" + if isinstance(agg_func, tuple) and len(agg_func) == 2: + agg_func = agg_func[0] return get_snowflake_agg_func(agg_func, agg_kwargs, axis) is not None @@ -545,7 +564,7 @@ def check_is_aggregation_supported_in_snowflake( return all( ( are_all_agg_funcs_supported_by_snowflake(value, agg_kwargs, axis) - if is_list_like(value) + if is_list_like(value) and not is_named_tuple(value) else is_supported_snowflake_agg_func(value, agg_kwargs, axis) ) for value in agg_func.values() @@ -910,7 +929,7 @@ def get_pandas_aggr_func_name(aggfunc: AggFuncTypeBase) -> str: def generate_pandas_labels_for_agg_result_columns( pandas_label: Hashable, num_levels: int, - agg_func_list: list[AggFuncTypeBase], + agg_func_list: list[AggFuncInfo], include_agg_func_in_agg_label: bool, include_pandas_label_in_agg_label: bool, ) -> list[Hashable]: @@ -947,17 +966,20 @@ def generate_pandas_labels_for_agg_result_columns( ), "the result aggregation label must at least contain at least the original label or the aggregation function name." 
agg_func_column_labels = [] for agg_func in agg_func_list: - label_tuple = ( - from_pandas_label(pandas_label, num_levels) - if include_pandas_label_in_agg_label - else () - ) - aggr_func_label = ( - (get_pandas_aggr_func_name(agg_func),) - if include_agg_func_in_agg_label - else () - ) - label_tuple = label_tuple + aggr_func_label + if agg_func.post_agg_pandas_label is None: + label_tuple = ( + from_pandas_label(pandas_label, num_levels) + if include_pandas_label_in_agg_label + else () + ) + aggr_func_label = ( + (get_pandas_aggr_func_name(agg_func.func),) + if include_agg_func_in_agg_label + else () + ) + label_tuple = label_tuple + aggr_func_label + else: + label_tuple = (agg_func.post_agg_pandas_label,) agg_func_column_labels.append(to_pandas_label(label_tuple)) return agg_func_column_labels @@ -1011,8 +1033,13 @@ def generate_column_agg_info( # if any value in the dictionary is a list, the aggregation function name is added as # an extra level to the final pandas label, otherwise not. When any value in the dictionary is a list, # the aggregation function name will be added as an extra level for the result label. + # One exception to this rule is when the user passes in pd.NamedAgg for the aggregations + # instead of using the aggfunc argument. Then, each aggregation (even if on the same column) + # has a unique name, and so we do not need to insert the additional level. 
agg_func_level_included = any( - is_list_like(fn) and not is_named_tuple(fn) + is_list_like(fn) + and not is_named_tuple(fn) + and not any(f.post_agg_pandas_label is not None for f in fn) for fn in column_to_agg_func.values() ) pandas_label_level_included = ( @@ -1031,7 +1058,7 @@ def generate_column_agg_info( agg_col_labels = generate_pandas_labels_for_agg_result_columns( pandas_label_to_identifier.pandas_label, num_levels, - [func for (func, _) in agg_func_list], + agg_func_list, # type: ignore[arg-type] agg_func_level_included, pandas_label_level_included, ) @@ -1045,7 +1072,8 @@ def generate_column_agg_info( for func_info, label, identifier in zip( agg_func_list, agg_col_labels, agg_col_identifiers ): - (func, is_dummy_agg) = func_info + func = func_info.func + is_dummy_agg = func_info.is_dummy_agg agg_func_col = pandas_lit(None) if is_dummy_agg else quoted_identifier snowflake_agg_func = get_snowflake_agg_func(func, agg_kwargs, axis=0) # once reach here, we require all func have a corresponding snowflake aggregation function. diff --git a/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py b/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py index 5d3ae7c1c14..476e3cc0cd8 100644 --- a/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py +++ b/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py @@ -45,6 +45,7 @@ is_bool_dtype, is_datetime64_any_dtype, is_integer_dtype, + is_named_tuple, is_numeric_dtype, is_re_compilable, is_scalar, @@ -122,6 +123,7 @@ from snowflake.snowpark.modin.plugin._internal.aggregation_utils import ( AGG_NAME_COL_LABEL, AggFuncInfo, + AggFuncWithLabel, AggregateColumnOpParameters, _columns_coalescing_idxmax_idxmin_helper, aggregate_with_ordered_dataframe, @@ -2566,6 +2568,17 @@ def groupby_agg( dropna=agg_kwargs.get("dropna", True), ) else: + # Named aggregates are passed in via agg_kwargs. 
We should not pass `agg_func` since we have modified + # it to be of the form {column_name: (agg_func, new_column_name), ...}, which will cause pandas to error out. + if isinstance(agg_func, dict) and all( + is_named_tuple(func) and len(func) == 2 + for func in agg_func.values() + ): + agg_func = ", ".join( + [f"{key}={value}" for key, value in agg_kwargs.items()] + ) + agg_func = f"agg({agg_func})" + ErrorMessage.not_implemented( f"Snowpark pandas GroupBy.{agg_func} does not yet support pd.Grouper, axis == 1, by != None and level != None, by containing any non-pandas hashable labels, or unsupported aggregation parameters." ) @@ -2573,6 +2586,7 @@ def groupby_agg( sort = groupby_kwargs.get("sort", True) as_index = groupby_kwargs.get("as_index", True) dropna = groupby_kwargs.get("dropna", True) + uses_named_aggs = False original_index_column_labels = self._modin_frame.index_column_pandas_labels @@ -2605,11 +2619,27 @@ def groupby_agg( # turn each agg function into an AggFuncInfo named tuple, where is_dummy_agg is set to false; # i.e., none of the aggregations here can be dummy. 
+ def convert_func_to_agg_func_info( + func: Union[AggFuncType, AggFuncWithLabel] + ) -> AggFuncInfo: + nonlocal uses_named_aggs + if is_named_tuple(func): + uses_named_aggs = True + return AggFuncInfo( + func=func.func, + is_dummy_agg=False, + post_agg_pandas_label=func.pandas_label, + ) + else: + return AggFuncInfo( + func=func, is_dummy_agg=False, post_agg_pandas_label=None + ) + column_to_agg_func = { agg_col: ( - [AggFuncInfo(func=fn, is_dummy_agg=False) for fn in func] - if is_list_like(func) - else AggFuncInfo(func=func, is_dummy_agg=False) + [convert_func_to_agg_func_info(fn) for fn in func] + if is_list_like(func) and not is_named_tuple(func) + else convert_func_to_agg_func_info(func) ) for (agg_col, func) in column_to_agg_func.items() } @@ -2684,7 +2714,7 @@ def groupby_agg( internal_frame.index_column_snowflake_quoted_identifiers ) drop = False - if not as_index: + if not as_index and not uses_named_aggs: # drop off the index columns that are from the original index columns and also the index # columns that are from data column with aggregation function applied. 
# For example: with the following dataframe, which has data column ['A', 'B', 'C', 'D', 'E'] diff --git a/tests/integ/modin/groupby/test_groupby_basic_agg.py b/tests/integ/modin/groupby/test_groupby_basic_agg.py index 53b511ad69d..4c3e83c7c46 100644 --- a/tests/integ/modin/groupby/test_groupby_basic_agg.py +++ b/tests/integ/modin/groupby/test_groupby_basic_agg.py @@ -23,6 +23,7 @@ assert_frame_equal, assert_snowpark_pandas_equal_to_pandas, assert_snowpark_pandas_equals_to_pandas_with_coerce_to_float64, + assert_snowpark_pandas_equals_to_pandas_without_dtypecheck, create_snow_df_with_table_and_data, create_test_dfs, eval_snowpark_pandas_result, @@ -142,6 +143,34 @@ def test_groupby_agg_with_decimal_dtype(session, agg_method) -> None: eval_snowpark_pandas_result(snowpark_pandas_groupby, pandas_groupby, agg_method) +@sql_count_checker(query_count=8) +def test_groupby_agg_with_decimal_dtype_named_agg(session) -> None: + # create table + table_name = Utils.random_name_for_temp_object(TempObjectType.TABLE) + Utils.create_table( + session, table_name, "COL_G string, COL_D decimal(38, 1)", is_temporary=True + ) + session.sql(f"insert into {table_name} values ('A', 1)").collect() + session.sql(f"insert into {table_name} values ('B', 2)").collect() + session.sql(f"insert into {table_name} values ('A', 3)").collect() + session.sql(f"insert into {table_name} values ('B', 5)").collect() + + snowpark_pandas_df = pd.read_snowflake(table_name) + pandas_df = snowpark_pandas_df.to_pandas() + + by = "COL_G" + with SqlCounter(query_count=1): + snowpark_pandas_groupby = snowpark_pandas_df.groupby(by=by) + pandas_groupby = pandas_df.groupby(by=by) + eval_snowpark_pandas_result( + snowpark_pandas_groupby, + pandas_groupby, + lambda gr: gr.agg( + new_col=pd.NamedAgg("COL_D", "max"), new_col1=("COL_D", np.std) + ), + ) + + @sql_count_checker(query_count=2) def test_groupby_agg_with_float_dtypes(agg_method) -> None: snowpark_pandas_df = pd.DataFrame( @@ -192,6 +221,64 @@ def 
test_groupby_agg_with_float_dtypes(agg_method) -> None: ) +@sql_count_checker(query_count=2) +def test_groupby_agg_with_float_dtypes_named_agg() -> None: + snowpark_pandas_df = pd.DataFrame( + { + "col1_grp": ["g1", "g2", "g0", "g0", "g2", "g3", "g0", "g2", "g3"], + "col2_float16": np.arange(9, dtype="float16") // 3, + "col3_float64": np.arange(9, dtype="float64") // 4, + "col4_float32": np.arange(9, dtype="float32") // 5, + "col5_mixed": np.concatenate( + [ + np.arange(3, dtype="int64"), + np.arange(3, dtype="float32"), + np.arange(3, dtype="float64"), + ] + ), + "col6_float_identical": [3.0] * 9, + "col7_float_missing": [ + 3.0, + 2.0, + np.nan, + 1.0, + np.nan, + 4.0, + np.nan, + np.nan, + 7.0, + ], + "col8_mix_missing": np.concatenate( + [ + np.arange(2, dtype="int64"), + [np.nan, np.nan], + np.arange(2, dtype="float32"), + [np.nan], + np.arange(2, dtype="float64"), + ] + ), + } + ) + + by = "col1_grp" + snowpark_pandas_groupby, pandas_groupby = eval_groupby_result( + snowpark_pandas_df, by + ) + eval_snowpark_pandas_result( + snowpark_pandas_groupby, + pandas_groupby, + lambda gr: gr.agg( + new_col1=pd.NamedAgg("col2_float16", max), + new_col2=("col3_float64", min), + new_col3=("col4_float32", np.std), + new_col4=("col5_mixed", max), + new_col5=("col6_float_identical", np.std), + new_col6=("col7_float_missing", max), + new_col7=("col8_mix_missing", min), + ), + ) + + @sql_count_checker(query_count=2) def test_groupby_agg_with_int_dtypes(int_to_decimal_float_agg_method) -> None: snowpark_pandas_df = pd.DataFrame( @@ -364,6 +451,43 @@ def test_groupby_agg_on_groupby_columns( ) +@pytest.mark.parametrize( + "by", ["col1", ["col1", "col2", "col3"], ["col1", "col1", "col2"]] +) +@pytest.mark.parametrize("as_index", [True, False]) +@pytest.mark.parametrize("sort", [True, False]) +def test_groupby_agg_on_groupby_columns_named_agg( + basic_snowpark_pandas_df, by, as_index, sort +) -> None: + query_count = 2 + kwargs = {} + # 
https://github.com/pandas-dev/pandas/issues/58446 + # pandas (and Snowpark pandas) fail when duplicate columns are specified for + # by and `as_index` is False and `pd.NamedAgg`s are + # used for aggregation functions, but not when a dictionary + # is passed in. + if by == ["col1", "col1", "col2"] and not as_index: + kwargs = { + "expect_exception": True, + "expect_exception_type": ValueError, + "expect_exception_match": "cannot insert col1, already exists", + "assert_exception_equal": True, + } + query_count = 1 + with SqlCounter(query_count=query_count): + native_pandas = basic_snowpark_pandas_df.to_pandas() + eval_snowpark_pandas_result( + basic_snowpark_pandas_df, + native_pandas, + lambda df: df.groupby(by=by, sort=sort, as_index=as_index).agg( + new_col=pd.NamedAgg("col2", sum), + new_col1=("col4", sum), + new_col3=("col3", "min"), + ), + **kwargs, + ) + + @pytest.mark.parametrize( "agg_func", [ @@ -410,6 +534,33 @@ def test_groupby_dropna_single_index(group_data, dropna, as_index) -> None: ) +@pytest.mark.parametrize( + "group_data", + [ + ["A", "B", "A", "B"], + ["A", np.nan, "A", np.nan], + ["A", np.nan, "A", "B"], + [np.nan, np.nan, np.nan, np.nan], + ], +) +@pytest.mark.parametrize("dropna", [True, False]) +@pytest.mark.parametrize("as_index", [True, False]) +@sql_count_checker(query_count=1) +def test_groupby_dropna_single_index_named_agg(group_data, dropna, as_index) -> None: + pandas_df = native_pd.DataFrame( + {"grp_col": group_data, "value": [123.23, 13.0, 12.3, 1.0]} + ) + snow_df = pd.DataFrame(pandas_df) + + eval_snowpark_pandas_result( + snow_df, + pandas_df, + lambda df: df.groupby(by="grp_col", dropna=dropna, as_index=as_index).agg( + new_col=("value", max) + ), + ) + + @pytest.mark.parametrize( "group_index, expected_index_dropna_false", [ @@ -511,6 +662,29 @@ def test_groupby_with_dropna_random(agg_method, dropna: bool) -> None: ) +@pytest.mark.parametrize("dropna", [True, False]) +@sql_count_checker(query_count=2) +def 
test_groupby_with_dropna_random_named_agg(dropna: bool) -> None: + snowpark_pandas_df = pd.DataFrame(TEST_DF_DATA["float_nan_data"]) + pandas_df = snowpark_pandas_df.to_pandas() + + by = ["col2"] + + agg_funcs = {} + for i, c in enumerate(snowpark_pandas_df.columns): + if c not in by: + agg_funcs[f"new_col{i}"] = (c, np.std) + + snowpark_pandas_groupby = snowpark_pandas_df.groupby(by=by, dropna=dropna) + pandas_groupby = pandas_df.groupby(by=by, dropna=dropna) + + eval_snowpark_pandas_result( + snowpark_pandas_groupby, + pandas_groupby, + lambda gr: gr.agg(**agg_funcs), + ) + + @pytest.mark.parametrize( "by", ["col1_str", "col2_int", "col3_float", ["col3_float", "col1_str"]] ) @@ -912,3 +1086,11 @@ def test_groupby_agg_on_valid_variant_column(session, test_table_name): } ), ) + + +@sql_count_checker(query_count=2) +def test_valid_func_valid_kwarg_should_work(basic_snowpark_pandas_df): + assert_snowpark_pandas_equals_to_pandas_without_dtypecheck( + basic_snowpark_pandas_df.groupby("col1").agg(max, min_count=2), + basic_snowpark_pandas_df.to_pandas().groupby("col1").max(min_count=2), + ) diff --git a/tests/integ/modin/groupby/test_groupby_default2pandas.py b/tests/integ/modin/groupby/test_groupby_default2pandas.py index 76e7ebe3c82..3ce87de4364 100644 --- a/tests/integ/modin/groupby/test_groupby_default2pandas.py +++ b/tests/integ/modin/groupby/test_groupby_default2pandas.py @@ -111,6 +111,20 @@ def test_groupby_agg_func_unsupported(basic_snowpark_pandas_df, agg_func, args): basic_snowpark_pandas_df.groupby(by).agg(agg_func, *args) +@pytest.mark.parametrize( + "agg_func", + [ + lambda x: np.sum(x), # callable + np.ptp, # Unsupported aggregation function + ], +) +@sql_count_checker(query_count=0) +def test_groupby_agg_func_unsupported_named_agg(basic_snowpark_pandas_df, agg_func): + by = "col1" + with pytest.raises(NotImplementedError): + basic_snowpark_pandas_df.groupby(by=by).agg(new_col=("col2", agg_func)) + + @pytest.mark.parametrize( "agg_func", [lambda x: x * 2, 
np.sin, {"col2": "max", "col4": np.sin}], diff --git a/tests/integ/modin/groupby/test_groupby_named_agg.py b/tests/integ/modin/groupby/test_groupby_named_agg.py new file mode 100644 index 00000000000..0b56de391c6 --- /dev/null +++ b/tests/integ/modin/groupby/test_groupby_named_agg.py @@ -0,0 +1,57 @@ +# +# Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved. +# +import pytest + +from snowflake.snowpark.exceptions import SnowparkSQLException +from tests.integ.modin.sql_counter import sql_count_checker +from tests.integ.modin.utils import eval_snowpark_pandas_result + + +@sql_count_checker(query_count=1) +def test_invalid_named_agg_errors(basic_snowpark_pandas_df): + eval_snowpark_pandas_result( + basic_snowpark_pandas_df, + basic_snowpark_pandas_df.to_pandas(), + lambda df: df.groupby("col1").agg(args=80, valid_agg=("col2", min)), + expect_exception=True, + expect_exception_match="Must provide 'func' or tuples of '\\(column, aggfunc\\).", + assert_exception_equal=False,  # There is a typo in the pandas exception. + expect_exception_type=TypeError, + ) + + +@sql_count_checker(query_count=6) +@pytest.mark.xfail( + reason="SNOW-1336091: Snowpark pandas cannot run in sprocs until modin 0.28.1 is available in conda", + strict=True, + raises=AssertionError, +) +def test_invalid_func_with_named_agg_errors(basic_snowpark_pandas_df): + # This test checks that a SnowparkSQLException is raised by this code, since the + # code is invalid. This code relies on falling back to native pandas though, + # so until SNOW-1336091 is fixed, a RuntimeError will instead be raised by the + # Snowpark pandas code. This test then errors out with an AssertionError, since + # the assertion that the raised exception is a SnowparkSQLException is False, + # so we mark it as xfail with raises=AssertionError. When SNOW-1336091 is fixed, + # this test should pass automatically. 
+ eval_snowpark_pandas_result( + basic_snowpark_pandas_df, + basic_snowpark_pandas_df.to_pandas(), + lambda df: df.groupby("col1").agg(80, valid_agg=("col2", min)), + expect_exception=True, + assert_exception_equal=False, # We fallback and then raise the correct error. + expect_exception_type=SnowparkSQLException, + ) + + +@sql_count_checker(query_count=1) +def test_valid_func_with_named_agg_errors(basic_snowpark_pandas_df): + eval_snowpark_pandas_result( + basic_snowpark_pandas_df, + basic_snowpark_pandas_df.to_pandas(), + lambda df: df.groupby("col1").agg(max, new_col=("col2", min)), + expect_exception=True, + assert_exception_equal=False, # There is a difference in our errors. + expect_exception_type=TypeError, + ) diff --git a/tests/integ/modin/groupby/test_groupby_negative.py b/tests/integ/modin/groupby/test_groupby_negative.py index ae1a9e9d103..e01850b25c7 100644 --- a/tests/integ/modin/groupby/test_groupby_negative.py +++ b/tests/integ/modin/groupby/test_groupby_negative.py @@ -295,6 +295,20 @@ def test_groupby_agg_dict_like_input_invalid_column_raises(basic_snowpark_pandas ) +@sql_count_checker(query_count=1) +def test_groupby_named_agg_like_input_invalid_column_raises(basic_snowpark_pandas_df): + eval_snowpark_pandas_result( + basic_snowpark_pandas_df, + basic_snowpark_pandas_df.to_pandas(), + lambda df: df.groupby(by="col1").aggregate( + new_col=("col2", max), new_col1=("col_invalid", "min") + ), + expect_exception=True, + expect_exception_type=KeyError, + expect_exception_match=re.escape("Column(s) ['col_invalid'] do not exist"), + ) + + @sql_count_checker(query_count=1) def test_groupby_as_index_false_with_dup(basic_snowpark_pandas_df) -> None: by = ["col1", "col1"] diff --git a/tests/integ/modin/groupby/test_groupby_series.py b/tests/integ/modin/groupby/test_groupby_series.py index 2b2b1591705..652855d0ad6 100644 --- a/tests/integ/modin/groupby/test_groupby_series.py +++ b/tests/integ/modin/groupby/test_groupby_series.py @@ -83,6 +83,21 @@ def 
test_groupby_agg_series(agg_func, sort): ) +@pytest.mark.parametrize("sort", [True, False]) +@pytest.mark.parametrize("aggs", [{"minimum": min}, {"minimum": min, "maximum": max}]) +@sql_count_checker(query_count=2) +def test_groupby_agg_series_named_agg(aggs, sort): + index = native_pd.Index(["a", "b", "b", "a", "c"]) + index.names = ["grp_col"] + series = pd.Series([3.5, 1.2, 4.3, 2.0, 1.8], index=index) + + eval_snowpark_pandas_result( + series, + series.to_pandas(), + lambda se: se.groupby(by="grp_col", sort=sort).agg(**aggs), + ) + + @pytest.mark.parametrize("numeric_only", [False, None]) @sql_count_checker(query_count=2) def test_groupby_series_numeric_only(series_str, numeric_only):