[SNOW-902943]: Add support for pd.NamedAgg in groupby.agg #1432

Merged (16 commits) on May 15, 2024
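
For context, this change targets the native pandas named-aggregation API. A minimal native-pandas sketch of the pattern being added (the DataFrame, column names, and labels are illustrative assumptions):

```python
import pandas as pd

df = pd.DataFrame({"A": ["x", "x", "y"], "B": [1, 2, 3]})

# Each keyword becomes the label of an output column; pd.NamedAgg (or an
# equivalent (column, aggfunc) 2-tuple) names the source column and the
# aggregation applied to it.
result = df.groupby("A").agg(
    b_min=pd.NamedAgg(column="B", aggfunc="min"),
    b_max=("B", "max"),  # plain 2-tuples are accepted as well
)
print(result)
#    b_min  b_max
# A
# x      1      2
# y      3      3
```

With this change, the same call is intended to work on a Snowpark pandas DataFrame instead of hitting the previous "func=None is currently not supported" error path shown in utils.py below.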
1 change: 1 addition & 0 deletions src/snowflake/snowpark/modin/pandas/base.py
@@ -692,6 +692,7 @@ def aggregate(
obj=self,
allow_duplication=False,
axis=axis,
**kwargs,
)

# This is to stay consistent with pandas result format, when the func is single
17 changes: 12 additions & 5 deletions src/snowflake/snowpark/modin/pandas/groupby.py
@@ -41,6 +41,7 @@
# Snowpark pandas API version
from snowflake.snowpark.modin.pandas.series import Series
from snowflake.snowpark.modin.pandas.utils import (
extract_validate_and_try_convert_named_aggs_from_kwargs,
raise_if_native_pandas_objects,
validate_and_try_convert_agg_func_arg_func_to_str,
)
@@ -569,10 +570,17 @@ def aggregate(
ErrorMessage.not_implemented(
"axis other than 0 is not supported"
) # pragma: no cover

func = validate_and_try_convert_agg_func_arg_func_to_str(
agg_func=func, obj=self, allow_duplication=True, axis=self._axis
)
if func is None:
func = extract_validate_and_try_convert_named_aggs_from_kwargs(
obj=self, allow_duplication=True, axis=self._axis, **kwargs
)
else:
func = validate_and_try_convert_agg_func_arg_func_to_str(
agg_func=func,
obj=self,
allow_duplication=True,
axis=self._axis,
)

if isinstance(func, str):
# Using "getattr" here masks possible AttributeError which we throw
@@ -1151,7 +1159,6 @@ def aggregate(
raise SpecificationError(
"Value for func argument in dict format is not allowed for SeriesGroupBy."
)

return super().aggregate(
func, *args, engine=engine, engine_kwargs=engine_kwargs, **kwargs
)
86 changes: 80 additions & 6 deletions src/snowflake/snowpark/modin/pandas/utils.py
@@ -44,6 +44,7 @@
FactoryDispatcher,
)
from snowflake.snowpark.modin.plugin._internal.aggregation_utils import (
AggFuncWithLabel,
get_pandas_aggr_func_name,
)
from snowflake.snowpark.modin.plugin.compiler import BaseQueryCompiler
@@ -539,6 +540,85 @@ def _try_convert_single_builtin_func_to_str(f):
return _try_convert_single_builtin_func_to_str(fn)


def extract_validate_and_try_convert_named_aggs_from_kwargs(
obj: object, allow_duplication: bool, axis: int, **kwargs
) -> AggFuncType:
"""
Attempt to extract pd.NamedAgg (or tuples of the same format) from the kwargs.

kwargs: dict
The kwargs to extract from.

Returns:
A dictionary mapping columns to a tuple containing the aggregation to perform, as well
as the pandas label to give the aggregated column.
"""
from snowflake.snowpark.modin.pandas.groupby import SeriesGroupBy

named_aggs = {}
accepted_keys = []
columns = obj._query_compiler.columns
for key, value in kwargs.items():
if isinstance(value, pd.NamedAgg) or (
isinstance(value, tuple) and len(value) == 2
):
if axis == 0:
# If axis == 1, we would need a query to materialize the index to check its existence
# so we defer the error checking to later.
if value[0] not in columns:
raise KeyError(f"Column(s) ['{value[0]}'] do not exist")

if value[0] in named_aggs:
if not isinstance(named_aggs[value[0]], list):
named_aggs[value[0]] = [named_aggs[value[0]]]
named_aggs[value[0]] += [
AggFuncWithLabel(func=value[1], pandas_label=key)
]
else:
named_aggs[value[0]] = AggFuncWithLabel(func=value[1], pandas_label=key)
accepted_keys += [key]
elif isinstance(obj, SeriesGroupBy):
col_name = obj._df._query_compiler.columns[0]
if col_name not in named_aggs:
named_aggs[col_name] = AggFuncWithLabel(func=value, pandas_label=key)
else:
if not isinstance(named_aggs[col_name], list):
named_aggs[col_name] = [named_aggs[col_name]]
named_aggs[col_name] += [AggFuncWithLabel(func=value, pandas_label=key)]
accepted_keys += [key]

if len(named_aggs.keys()) == 0:
ErrorMessage.not_implemented(
"Must provide value for 'func' argument, func=None is currently not supported with Snowpark pandas"
)

if any(key not in accepted_keys for key in kwargs.keys()):
# For compatibility with pandas errors. Otherwise, we would just ignore
# those kwargs.
raise TypeError("Must provide 'func' or tuples of '(column, aggfunc).")

validated_named_aggs = {}
for key, value in named_aggs.items():
if isinstance(value, list):
validated_named_aggs[key] = [
AggFuncWithLabel(
func=validate_and_try_convert_agg_func_arg_func_to_str(
v.func, obj, allow_duplication, axis
),
pandas_label=v.pandas_label,
)
for v in value
]
else:
validated_named_aggs[key] = AggFuncWithLabel(
func=validate_and_try_convert_agg_func_arg_func_to_str(
value.func, obj, allow_duplication, axis
),
pandas_label=value.pandas_label,
)
return validated_named_aggs
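
To make the return shape concrete, here is a hedged sketch (comments only, since an executable call needs a Snowpark pandas groupby object) of the mapping the helper above is expected to build when two named aggregations target the same column; the DataFrame and labels are assumptions for the example:

```python
# For a call such as
#     df.groupby("A").agg(
#         b_min=pd.NamedAgg(column="B", aggfunc="min"),
#         b_max=pd.NamedAgg(column="B", aggfunc="max"),
#     )
# both keywords target column "B", so they are collected under that column as a
# list, and each keyword is preserved as the pandas_label of its AggFuncWithLabel:
#
#     {
#         "B": [
#             AggFuncWithLabel(func="min", pandas_label="b_min"),
#             AggFuncWithLabel(func="max", pandas_label="b_max"),
#         ]
#     }
```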


def validate_and_try_convert_agg_func_arg_func_to_str(
agg_func: AggFuncType, obj: object, allow_duplication: bool, axis: int
) -> AggFuncType:
@@ -580,12 +660,6 @@ def validate_and_try_convert_agg_func_arg_func_to_str(

"""
if agg_func is None:
# Snowpark pandas only support func argument at this moment.
# TODO (SNOW-902943): pandas allows usage of NamedAgg in kwargs to configure
# tuples of (columns, agg_func) with rename. For example:
# df.groupby('A').agg(b_min=pd.NamedAgg(column='B', aggfunc='min')), which applies
# min function on column 'B', and uses 'b_min' as the new column name.
# Once supported, refine the check to check both.
ErrorMessage.not_implemented(
"Must provide value for 'func' argument, func=None is currently not supported with Snowpark pandas"
)
src/snowflake/snowpark/modin/plugin/_internal/aggregation_utils.py
@@ -267,6 +267,19 @@ def _columns_coalescing_idxmax_idxmin_helper(
}


class AggFuncWithLabel(NamedTuple):
"""
This class is used to process NamedAggs internally, and represents an AggFunc that
also includes a label to be used on the column that it generates.
"""

# The aggregate function
func: AggFuncTypeBase

# The label to give the new column produced by `func`.
pandas_label: Hashable


class AggFuncInfo(NamedTuple):
"""
Information needed to distinguish between dummy and normal aggregate functions.
@@ -278,6 +291,10 @@ class AggFuncInfo(NamedTuple):
# If true, the aggregate function is applied to "NULL" rather than a column
is_dummy_agg: bool

# If specified, the pandas label to give the new column generated by this aggregate
# function. Used in conjunction with pd.NamedAgg.
post_agg_pandas_label: Optional[Hashable] = None


def _columns_coalescing_min(*cols: SnowparkColumn) -> Callable:
"""
@@ -503,6 +520,8 @@ def is_supported_snowflake_agg_func(
Returns:
is_valid: bool. Whether it is valid to implement with snowflake or not.
"""
if isinstance(agg_func, tuple) and len(agg_func) == 2:
agg_func = agg_func[0]
return get_snowflake_agg_func(agg_func, agg_kwargs, axis) is not None


@@ -545,7 +564,7 @@ def check_is_aggregation_supported_in_snowflake(
return all(
(
are_all_agg_funcs_supported_by_snowflake(value, agg_kwargs, axis)
if is_list_like(value)
if is_list_like(value) and not is_named_tuple(value)
else is_supported_snowflake_agg_func(value, agg_kwargs, axis)
)
for value in agg_func.values()
@@ -910,7 +929,7 @@ def get_pandas_aggr_func_name(aggfunc: AggFuncTypeBase) -> str:
def generate_pandas_labels_for_agg_result_columns(
pandas_label: Hashable,
num_levels: int,
agg_func_list: list[AggFuncTypeBase],
agg_func_list: list[AggFuncInfo],
include_agg_func_in_agg_label: bool,
include_pandas_label_in_agg_label: bool,
) -> list[Hashable]:
@@ -947,17 +966,20 @@ def generate_pandas_labels_for_agg_result_columns(
), "the result aggregation label must at least contain at least the original label or the aggregation function name."
agg_func_column_labels = []
for agg_func in agg_func_list:
label_tuple = (
from_pandas_label(pandas_label, num_levels)
if include_pandas_label_in_agg_label
else ()
)
aggr_func_label = (
(get_pandas_aggr_func_name(agg_func),)
if include_agg_func_in_agg_label
else ()
)
label_tuple = label_tuple + aggr_func_label
if agg_func.post_agg_pandas_label is None:
label_tuple = (
from_pandas_label(pandas_label, num_levels)
if include_pandas_label_in_agg_label
else ()
)
aggr_func_label = (
(get_pandas_aggr_func_name(agg_func.func),)
if include_agg_func_in_agg_label
else ()
)
label_tuple = label_tuple + aggr_func_label
else:
label_tuple = (agg_func.post_agg_pandas_label,)
agg_func_column_labels.append(to_pandas_label(label_tuple))

return agg_func_column_labels
@@ -1011,8 +1033,13 @@ def generate_column_agg_info(
# if any value in the dictionary is a list, the aggregation function name is added as
# an extra level to the final pandas label; otherwise it is not.
# One exception to this rule is when the user passes in pd.NamedAgg for the aggregations
# instead of using the aggfunc argument. Then, each aggregation (even if on the same column)
# has a unique name, and so we do not need to insert the additional level.
agg_func_level_included = any(
is_list_like(fn) and not is_named_tuple(fn)
is_list_like(fn)
and not is_named_tuple(fn)
and not any(f.post_agg_pandas_label is not None for f in fn)
for fn in column_to_agg_func.values()
)
pandas_label_level_included = (
@@ -1031,7 +1058,7 @@ def generate_column_agg_info(
agg_col_labels = generate_pandas_labels_for_agg_result_columns(
pandas_label_to_identifier.pandas_label,
num_levels,
[func for (func, _) in agg_func_list],
agg_func_list,
agg_func_level_included,
pandas_label_level_included,
)
@@ -1045,7 +1072,8 @@ def generate_column_agg_info(
for func_info, label, identifier in zip(
agg_func_list, agg_col_labels, agg_col_identifiers
):
(func, is_dummy_agg) = func_info
func = func_info.func
is_dummy_agg = func_info.is_dummy_agg
agg_func_col = pandas_lit(None) if is_dummy_agg else quoted_identifier
snowflake_agg_func = get_snowflake_agg_func(func, agg_kwargs, axis=0)
# once reach here, we require all func have a corresponding snowflake aggregation function.
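The `agg_func_level_included` logic above mirrors native pandas labeling: a list of aggregations on a column adds the function name as an extra column level, while named aggregations already carry unique user-chosen labels and stay flat. A minimal native-pandas illustration (data and labels are assumptions):

```python
import pandas as pd

df = pd.DataFrame({"A": ["x", "x", "y"], "B": [1, 2, 3]})

# A list of aggregation functions adds the function name as an extra column level.
multi = df.groupby("A").agg({"B": ["min", "max"]})
print(multi.columns.tolist())  # [('B', 'min'), ('B', 'max')]

# Named aggregations already have unique labels, so the result keeps flat columns.
flat = df.groupby("A").agg(b_min=("B", "min"), b_max=("B", "max"))
print(flat.columns.tolist())   # ['b_min', 'b_max']
```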
src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py
@@ -42,6 +42,7 @@
is_bool_dtype,
is_datetime64_any_dtype,
is_integer_dtype,
is_named_tuple,
is_numeric_dtype,
is_re_compilable,
is_scalar,
@@ -2632,10 +2633,18 @@ def groupby_agg(
) and check_is_aggregation_supported_in_snowflake(agg_func, agg_kwargs, axis)

def register_default_to_pandas() -> SnowflakeQueryCompiler:
# Named aggregates are passed in via agg_kwargs. We should not pass `agg_func` since we have modified
# it to be of the form {column_name: (agg_func, new_column_name), ...}, which will cause pandas to error out.
if isinstance(agg_func, dict) and all(
is_named_tuple(func) and len(func) == 2 for func in agg_func.values()
):
func = None
else:
func = agg_func
return GroupByDefault.register(GroupByDefault.get_aggregation_method(how))(
self,
by=by,
agg_func=agg_func,
agg_func=func,
axis=axis,
groupby_kwargs=groupby_kwargs,
agg_args=agg_args,
@@ -2659,6 +2668,7 @@ def register_default_to_pandas() -> SnowflakeQueryCompiler:
sort = groupby_kwargs.get("sort", True)
as_index = groupby_kwargs.get("as_index", True)
dropna = groupby_kwargs.get("dropna", True)
uses_named_aggs = False

original_index_column_labels = self._modin_frame.index_column_pandas_labels

@@ -2689,11 +2699,25 @@ def register_default_to_pandas() -> SnowflakeQueryCompiler:

# turn each agg function into an AggFuncInfo named tuple, where is_dummy_agg is set to false;
# i.e., none of the aggregations here can be dummy.
def convert_func_to_agg_func_info(func):
nonlocal uses_named_aggs
if is_named_tuple(func):
uses_named_aggs = True
return AggFuncInfo(
func=func.func,
is_dummy_agg=False,
post_agg_pandas_label=func.pandas_label,
)
else:
return AggFuncInfo(
func=func, is_dummy_agg=False, post_agg_pandas_label=None
)

column_to_agg_func = {
agg_col: (
[AggFuncInfo(func=fn, is_dummy_agg=False) for fn in func]
if is_list_like(func)
else AggFuncInfo(func=func, is_dummy_agg=False)
[convert_func_to_agg_func_info(fn) for fn in func]
if is_list_like(func) and not is_named_tuple(func)
else convert_func_to_agg_func_info(func)
)
for (agg_col, func) in column_to_agg_func.items()
}
@@ -2768,7 +2792,7 @@ def register_default_to_pandas() -> SnowflakeQueryCompiler:
internal_frame.index_column_snowflake_quoted_identifiers
)
drop = False
if not as_index:
if not as_index and not uses_named_aggs:
# drop off the index columns that are from the original index columns and also the index
# columns that are from data column with aggregation function applied.
# For example: with the following dataframe, which has data column ['A', 'B', 'C', 'D', 'E']