SNOW-976704 : Adding lead and lag dataframe function #1147

Merged
merged 60 commits into from
Jan 31, 2024
Changes from 36 commits
6840bfd
changes
sfc-gh-rsureshbabu Nov 23, 2023
266a281
updating changelog
sfc-gh-rsureshbabu Nov 23, 2023
2499a6d
fixing comment
sfc-gh-rsureshbabu Nov 23, 2023
4ba1a64
generalizing default formatter
sfc-gh-rsureshbabu Nov 23, 2023
810bfb8
generalizing default formatter 2
sfc-gh-rsureshbabu Nov 23, 2023
b43ffec
fix comment
sfc-gh-rsureshbabu Nov 23, 2023
adca1de
changes
sfc-gh-rsureshbabu Nov 23, 2023
5a007b2
cleaning argument checks
sfc-gh-rsureshbabu Nov 23, 2023
2d574b7
Merge branch 'rsureshbabu-SNOW-SNOW-976701-movingagg' into rsureshbab…
sfc-gh-rsureshbabu Nov 23, 2023
91ff442
cleaning argument checks
sfc-gh-rsureshbabu Nov 23, 2023
bd0bf2a
Merge branch 'rsureshbabu-SNOW-SNOW-976701-movingagg' into rsureshbab…
sfc-gh-rsureshbabu Nov 23, 2023
f66d519
cleaning argument checks
sfc-gh-rsureshbabu Nov 23, 2023
faccd9b
refactor
sfc-gh-rsureshbabu Nov 23, 2023
9622651
Merge branch 'rsureshbabu-SNOW-SNOW-976701-movingagg' into rsureshbab…
sfc-gh-rsureshbabu Nov 23, 2023
f5b0040
changes
sfc-gh-rsureshbabu Nov 23, 2023
e27f0ea
changes
sfc-gh-rsureshbabu Nov 23, 2023
cc4713a
changes
sfc-gh-rsureshbabu Nov 23, 2023
c5ff590
changes
sfc-gh-rsureshbabu Nov 23, 2023
e8d4345
fix test
sfc-gh-rsureshbabu Nov 23, 2023
3427e27
Merge branch 'rsureshbabu-SNOW-SNOW-976701-movingagg' into rsureshbab…
sfc-gh-rsureshbabu Nov 23, 2023
249b384
changes
sfc-gh-rsureshbabu Nov 23, 2023
ccf6655
changes
sfc-gh-rsureshbabu Nov 23, 2023
919e1c8
Merge branch 'rsureshbabu-SNOW-SNOW-976701-movingagg' into rsureshbab…
sfc-gh-rsureshbabu Nov 23, 2023
eef9062
changes
sfc-gh-rsureshbabu Nov 23, 2023
8e7adad
changes
sfc-gh-rsureshbabu Nov 23, 2023
1d0b265
Merge branch 'rsureshbabu-SNOW-SNOW-976701-movingagg' into rsureshbab…
sfc-gh-rsureshbabu Nov 23, 2023
d3a8b6b
changes
sfc-gh-rsureshbabu Nov 23, 2023
dff89e5
changes
sfc-gh-rsureshbabu Nov 23, 2023
f8e0fc7
changes
sfc-gh-rsureshbabu Nov 23, 2023
9a9a045
changes
sfc-gh-rsureshbabu Nov 23, 2023
3025d15
changes
sfc-gh-rsureshbabu Nov 23, 2023
d849dab
changes
sfc-gh-rsureshbabu Nov 23, 2023
706f15f
changes
sfc-gh-rsureshbabu Nov 23, 2023
12e24cd
Merge branch 'rsureshbabu-SNOW-SNOW-976701-movingagg' into rsureshbab…
sfc-gh-rsureshbabu Nov 23, 2023
809e129
tests
sfc-gh-rsureshbabu Nov 23, 2023
39a3e8a
fix comment
sfc-gh-rsureshbabu Dec 18, 2023
b0c2c85
Merge branch 'main' into rsureshbabu-SNOW-SNOW-976701-movingagg
sfc-gh-rsureshbabu Dec 18, 2023
48cf80e
changes
sfc-gh-rsureshbabu Dec 18, 2023
2599167
skip tests when pandas are not available
sfc-gh-rsureshbabu Jan 17, 2024
c4416ad
Merge branch 'main' into rsureshbabu-SNOW-SNOW-976701-movingagg
sfc-gh-rsureshbabu Jan 17, 2024
ce061bd
update change log
sfc-gh-rsureshbabu Jan 17, 2024
5c00182
changes
sfc-gh-rsureshbabu Jan 17, 2024
e0e18bf
changes
sfc-gh-rsureshbabu Jan 17, 2024
2d05cb8
adding doctest
sfc-gh-rsureshbabu Jan 18, 2024
9423599
renaming
sfc-gh-rsureshbabu Jan 23, 2024
721383f
renaming 2
sfc-gh-rsureshbabu Jan 23, 2024
27df8d2
Merge branch 'main' into rsureshbabu-SNOW-SNOW-976701-movingagg
sfc-gh-rsureshbabu Jan 23, 2024
24a016e
Merge branch 'rsureshbabu-SNOW-SNOW-976701-movingagg' into rsureshbab…
sfc-gh-rsureshbabu Jan 23, 2024
9bb98a9
move comment
sfc-gh-rsureshbabu Jan 23, 2024
cfb5e4a
fix comments
sfc-gh-rsureshbabu Jan 23, 2024
9583161
fix error message
sfc-gh-rsureshbabu Jan 24, 2024
045f216
fix merge
sfc-gh-rsureshbabu Jan 24, 2024
05a2da5
fix code coverage
sfc-gh-rsureshbabu Jan 24, 2024
ed24d6a
Merge branch 'rsureshbabu-SNOW-SNOW-976701-movingagg' into rsureshbab…
sfc-gh-rsureshbabu Jan 24, 2024
222ec96
add docstest
sfc-gh-rsureshbabu Jan 24, 2024
594e39d
merge from main
sfc-gh-rsureshbabu Jan 25, 2024
5b78ef2
fix tests
sfc-gh-rsureshbabu Jan 25, 2024
af473ff
address review comments
sfc-gh-rsureshbabu Jan 31, 2024
a0a6ef0
Merge branch 'main' into rsureshbabu-SNOW-976704-leadlagfunctions
sfc-gh-rsureshbabu Jan 31, 2024
83a06a9
change to with_columns
sfc-gh-rsureshbabu Jan 31, 2024
4 changes: 4 additions & 0 deletions CHANGELOG.md
@@ -10,6 +10,10 @@

- Add the `conn_error` attribute to `SnowflakeSQLException` that stores the whole underlying exception from `snowflake-connector-python`

### New Features

- Added moving_agg function in DataFrame.transform for time series analysis, enabling moving aggregations like sums and averages with multiple window sizes.

### Bug Fixes

- DataFrame column names quoting check now supports newline characters.
2 changes: 2 additions & 0 deletions src/snowflake/snowpark/__init__.py
@@ -19,6 +19,7 @@
"GetResult",
"DataFrame",
"DataFrameStatFunctions",
"DataFrameTransformFunctions",
"DataFrameNaFunctions",
"DataFrameWriter",
"DataFrameReader",
@@ -49,6 +50,7 @@
from snowflake.snowpark.dataframe_na_functions import DataFrameNaFunctions
from snowflake.snowpark.dataframe_reader import DataFrameReader
from snowflake.snowpark.dataframe_stat_functions import DataFrameStatFunctions
from snowflake.snowpark.dataframe_transform_functions import DataFrameTransformFunctions
from snowflake.snowpark.dataframe_writer import DataFrameWriter
from snowflake.snowpark.file_operation import FileOperation, GetResult, PutResult
from snowflake.snowpark.query_history import QueryHistory, QueryRecord
6 changes: 6 additions & 0 deletions src/snowflake/snowpark/dataframe.py
@@ -123,6 +123,7 @@
from snowflake.snowpark.column import Column, _to_col_if_sql_expr, _to_col_if_str
from snowflake.snowpark.dataframe_na_functions import DataFrameNaFunctions
from snowflake.snowpark.dataframe_stat_functions import DataFrameStatFunctions
from snowflake.snowpark.dataframe_transform_functions import DataFrameTransformFunctions
from snowflake.snowpark.dataframe_writer import DataFrameWriter
from snowflake.snowpark.exceptions import SnowparkDataframeException
from snowflake.snowpark.functions import (
@@ -521,6 +522,7 @@ def __init__(
self._writer = DataFrameWriter(self)

self._stat = DataFrameStatFunctions(self)
self._transform = DataFrameTransformFunctions(self)
self.approxQuantile = self.approx_quantile = self._stat.approx_quantile
self.corr = self._stat.corr
self.cov = self._stat.cov
@@ -538,6 +540,10 @@ def __init__(
def stat(self) -> DataFrameStatFunctions:
return self._stat

@property
def transform(self) -> DataFrameTransformFunctions:
return self._transform

@overload
def collect(
self,
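Note on the hunk above: the wiring mirrors the existing `DataFrame.stat` accessor. The helper object is built once in `__init__` and exposed through a read-only property. Below is a minimal, library-free sketch of that pattern; `MiniFrame`, `MiniTransformFunctions`, and `row_count` are hypothetical stand-ins for illustration, not Snowpark names.

```python
class MiniTransformFunctions:
    """Hypothetical stand-in for DataFrameTransformFunctions."""

    def __init__(self, df: "MiniFrame") -> None:
        # Keep a back-reference so methods can build on the owning frame.
        self._df = df

    def row_count(self) -> int:
        # Illustrative method only; it is not part of the PR.
        return len(self._df.rows)


class MiniFrame:
    """Hypothetical stand-in for DataFrame."""

    def __init__(self, rows: list) -> None:
        self.rows = rows
        # Built once per frame, like `self._transform = ...` in the diff.
        self._transform = MiniTransformFunctions(self)

    @property
    def transform(self) -> MiniTransformFunctions:
        # Read-only accessor, like `DataFrame.transform` in the diff.
        return self._transform


frame = MiniFrame([1, 2, 3])
count = frame.transform.row_count()
```

Because the helper is created once, repeated access to `frame.transform` returns the same object rather than allocating a new one on every call.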
209 changes: 209 additions & 0 deletions src/snowflake/snowpark/dataframe_transform_functions.py
@@ -0,0 +1,209 @@
#
# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved.
#

from typing import Callable, Dict, List, Union

import snowflake.snowpark
from snowflake.snowpark import Column
from snowflake.snowpark.column import _to_col_if_str
from snowflake.snowpark.functions import expr, lag, lead
from snowflake.snowpark.window import Window


class DataFrameTransformFunctions:
"""Provides data transformation functions for DataFrames.
To access an object of this class, use :attr:`DataFrame.transform`.
"""

def __init__(self, df: "snowflake.snowpark.DataFrame") -> None:
self._df = df

def _default_col_formatter(input_col: str, operation: str, *args) -> str:
args_str = "_".join(map(str, args))
formatted_name = f"{input_col}_{operation}"
if args_str:
formatted_name += f"_{args_str}"
return formatted_name
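Note on the default formatter above: it joins the input column, the operation, and any extra arguments with underscores. The sketch below copies that logic into a free function to show the names it yields. `short_formatter` is a hypothetical custom formatter, shown only to illustrate what a caller could pass as `col_formatter`.

```python
def default_col_formatter(input_col: str, operation: str, *args) -> str:
    # Same logic as _default_col_formatter in the diff, as a free function.
    args_str = "_".join(map(str, args))
    formatted_name = f"{input_col}_{operation}"
    if args_str:
        formatted_name += f"_{args_str}"
    return formatted_name


def short_formatter(input_col: str, operation: str, value: int) -> str:
    # Hypothetical alternative a caller might supply as col_formatter.
    return f"{operation.lower()}{value}_{input_col.lower()}"


moving_name = default_col_formatter("SALESAMOUNT", "SUM", 7)
lag_name = default_col_formatter("SALESAMOUNT", "LAG", 1)
bare_name = default_col_formatter("SALESAMOUNT", "AVG")
custom_name = short_formatter("SALESAMOUNT", "LEAD", 2)
```

With the default formatter, a 7-row moving sum over `SALESAMOUNT` is named `SALESAMOUNT_SUM_7` and a one-period lag is named `SALESAMOUNT_LAG_1`.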

def _validate_aggs_argument(self, aggs):
if not isinstance(aggs, dict):
raise TypeError("aggs must be a dictionary")
if not aggs:
raise ValueError("aggs must not be empty")
if not all(
isinstance(key, str)
and isinstance(val, list)
and val
and all(isinstance(item, str) for item in val)
for key, val in aggs.items()
):
raise ValueError(
"aggs must have strings as keys and non-empty lists of strings as values"
)

def _validate_string_list_argument(self, data, argument_name):
if not isinstance(data, list):
raise TypeError(f"{argument_name} must be a list")
if not data:
raise ValueError(f"{argument_name} must not be empty")
if not all(isinstance(item, str) for item in data):
raise ValueError(f"{argument_name} must be a list of strings")

def _validate_positive_integer_list_argument(self, data, argument_name):
if not isinstance(data, list):
raise TypeError(f"{argument_name} must be a list")
if not data:
raise ValueError(f"{argument_name} must not be empty")
if not all(isinstance(item, int) and item > 0 for item in data):
raise ValueError(f"{argument_name} must be a list of integers > 0")

def _validate_formatter_argument(self, formatter):
if not callable(formatter):
raise TypeError("formatter must be a callable function")

def _compute_window_function(
self,
cols: List[Union[str, Column]],
periods: List[int],
order_by: List[str],
group_by: List[str],
col_formatter: Callable[[str, str, int], str],
window_func: Callable[[Column, int], Column],
func_name: str, # Either "LAG" or "LEAD"
) -> "snowflake.snowpark.dataframe.DataFrame":
"""
Generic function to create window function columns (lag or lead) for the DataFrame.
"""
self._validate_string_list_argument(order_by, "order_by")
self._validate_string_list_argument(group_by, "group_by")
self._validate_positive_integer_list_argument(periods, func_name.lower() + "s")
Review comment (Contributor):
QQ: why do you need append s on function name?

Reply (Collaborator Author):
The argument name is "lags" for the lag function and "leads" for the lead function.
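To make the reply concrete, here is a small standalone sketch of how the derived label ends up in the error text. The validator body is copied from `_validate_positive_integer_list_argument` in this diff; `error_for` is a hypothetical helper added only for this illustration.

```python
def validate_positive_integer_list(data, argument_name: str) -> None:
    # Same checks as _validate_positive_integer_list_argument in the diff.
    if not isinstance(data, list):
        raise TypeError(f"{argument_name} must be a list")
    if not data:
        raise ValueError(f"{argument_name} must not be empty")
    if not all(isinstance(item, int) and item > 0 for item in data):
        raise ValueError(f"{argument_name} must be a list of integers > 0")


def error_for(func_name: str, periods) -> str:
    # func_name is "LAG" or "LEAD"; the derived label is "lags" or "leads".
    try:
        validate_positive_integer_list(periods, func_name.lower() + "s")
    except (TypeError, ValueError) as exc:
        return str(exc)
    return ""


lag_message = error_for("LAG", [0])
lead_message = error_for("LEAD", "3")
ok_message = error_for("LEAD", [1, 2])
```

The message a user sees therefore names the public parameter (`lags` or `leads`) rather than the internal `periods` argument.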

self._validate_formatter_argument(col_formatter)

window_spec = Window.partition_by(group_by).order_by(order_by)
df = self._df
for c in cols:
for period in periods:
column = _to_col_if_str(c, f"transform.compute_{func_name.lower()}")
window_col = window_func(column, period).over(window_spec)
formatted_col_name = col_formatter(
column.get_name().replace('"', ""), func_name, period
)
df = df.with_column(formatted_col_name, window_col)

return df

def moving_agg(
self,
aggs: Dict[str, List[str]],
window_sizes: List[int],
order_by: List[str],
group_by: List[str],
col_formatter: Callable[[str, str, int], str] = _default_col_formatter,
) -> "snowflake.snowpark.dataframe.DataFrame":
"""
Applies moving aggregations to the specified columns of the DataFrame using defined window sizes,
and grouping and ordering criteria.

Args:
aggs: A dictionary where keys are column names and values are lists of the desired aggregation functions.
window_sizes: A list of positive integers, each representing the size of the window for which to
calculate the moving aggregate.
order_by: A list of column names that specify the order in which rows are processed.
group_by: A list of column names on which the DataFrame is partitioned for separate window calculations.
col_formatter: An optional function for formatting output column names, defaulting to the format '<input_col>_<agg>_<window>'.
This function takes three arguments: 'input_col' (str) for the column name, 'operation' (str) for the applied operation,
and 'value' (int) for the window size, and returns a formatted string for the column name.

Returns:
A Snowflake DataFrame with additional columns corresponding to each specified moving aggregation.

Raises:
ValueError: If an unsupported value is specified in arguments.
TypeError: If an unsupported type is specified in arguments.
SnowparkSQLException: If an unsupported aggregation is specified.

Example:
aggregated_df = df.transform.moving_agg(
aggs={"SALESAMOUNT": ['SUM', 'AVG']},
window_sizes=[1, 2, 3, 7],
order_by=['ORDERDATE'],
group_by=['PRODUCTKEY']
)
"""
# Validate input arguments
self._validate_aggs_argument(aggs)
self._validate_string_list_argument(order_by, "order_by")
self._validate_string_list_argument(group_by, "group_by")
self._validate_positive_integer_list_argument(window_sizes, "window_sizes")
self._validate_formatter_argument(col_formatter)

# Perform window aggregation
agg_df = self._df
for column, agg_funcs in aggs.items():
for window_size in window_sizes:
for agg_func in agg_funcs:
window_spec = (
Window.partition_by(group_by)
.order_by(order_by)
.rows_between(-window_size + 1, 0)
)

# Apply the user-specified aggregation function directly. Snowflake will handle any errors for invalid functions.
agg_col = expr(f"{agg_func}({column})").over(window_spec)

formatted_col_name = col_formatter(column, agg_func, window_size)
agg_df = agg_df.with_column(formatted_col_name, agg_col)

return agg_df
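Note on the frame used above: `rows_between(-window_size + 1, 0)` covers the current row plus the `window_size - 1` rows before it within the same partition, so early rows aggregate over fewer values. The sketch below reproduces that frame in plain Python for a sum only. It is an illustration under that reading, not Snowpark code, and `moving_sum` is a hypothetical name.

```python
from collections import defaultdict
from typing import Dict, List


def moving_sum(
    rows: List[Dict],
    value_col: str,
    window_size: int,
    order_by: str,
    group_by: str,
) -> List[Dict]:
    # Partition rows by the grouping key.
    partitions = defaultdict(list)
    for row in rows:
        partitions[row[group_by]].append(row)

    result = []
    for part in partitions.values():
        ordered = sorted(part, key=lambda r: r[order_by])
        for i, row in enumerate(ordered):
            # Frame spans offsets -window_size + 1 .. 0, clipped at the partition start.
            start = max(0, i - window_size + 1)
            frame = ordered[start : i + 1]
            out = dict(row)
            out[f"{value_col}_SUM_{window_size}"] = sum(r[value_col] for r in frame)
            result.append(out)
    return result


sales = [
    {"PRODUCTKEY": "A", "ORDERDATE": 1, "SALESAMOUNT": 10},
    {"PRODUCTKEY": "A", "ORDERDATE": 2, "SALESAMOUNT": 20},
    {"PRODUCTKEY": "A", "ORDERDATE": 3, "SALESAMOUNT": 30},
    {"PRODUCTKEY": "B", "ORDERDATE": 1, "SALESAMOUNT": 5},
]
summed = moving_sum(sales, "SALESAMOUNT", 2, "ORDERDATE", "PRODUCTKEY")
sums = [r["SALESAMOUNT_SUM_2"] for r in summed]
```

For product `A` the two-row sums are 10, 30, and 50; product `B` has a single row, so its sum is just 5.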

def compute_lag(
self,
cols: List[Union[str, Column]],
lags: List[int],
order_by: List[str],
group_by: List[str],
col_formatter: Callable[[str, str, int], str] = _default_col_formatter,
) -> "snowflake.snowpark.dataframe.DataFrame":
"""
Creates lag columns for the specified columns of the DataFrame, using the given grouping and ordering criteria.

Args:
cols: List of column names or Column objects to calculate lag features.
lags: List of positive integers specifying periods to lag by.
order_by: A list of column names that specify the order in which rows are processed.
group_by: A list of column names on which the DataFrame is partitioned for separate window calculations.
col_formatter: An optional function for formatting output column names, defaulting to the format '<input_col>_LAG_<lag>'.
This function takes three arguments: 'input_col' (str) for the column name, 'operation' (str) for the applied operation,
and 'value' (int) for lag value, and returns a formatted string for the column name.

Returns:
A Snowflake DataFrame with additional columns corresponding to each specified lag period.
"""
return self._compute_window_function(
cols, lags, order_by, group_by, col_formatter, lag, "LAG"
)

def compute_lead(
self,
cols: List[Union[str, Column]],
leads: List[int],
order_by: List[str],
group_by: List[str],
col_formatter: Callable[[str, str, int], str] = _default_col_formatter,
) -> "snowflake.snowpark.dataframe.DataFrame":
"""
Creates lead columns for the specified columns of the DataFrame, using the given grouping and ordering criteria.

Args:
cols: List of column names or Column objects to calculate lead features.
leads: List of positive integers specifying periods to lead by.
order_by: A list of column names that specify the order in which rows are processed.
group_by: A list of column names on which the DataFrame is partitioned for separate window calculations.
col_formatter: An optional function for formatting output column names, defaulting to the format '<input_col>_LEAD_<lead>'.
This function takes three arguments: 'input_col' (str) for the column name, 'operation' (str) for the applied operation,
and 'value' (int) for the lead value, and returns a formatted string for the column name.

Returns:
A Snowflake DataFrame with additional columns corresponding to each specified lead period.
"""
return self._compute_window_function(
cols, leads, order_by, group_by, col_formatter, lead, "LEAD"
)
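Note on `compute_lag` and `compute_lead`: both delegate to `_compute_window_function`, which partitions by `group_by`, orders by `order_by`, and shifts each column by the requested period. The plain-Python sketch below shows the expected row-level behavior, assuming that a shift past the partition boundary yields no value (`None` here, standing in for SQL NULL). `shift_within_groups` is a hypothetical name, and this is not Snowpark code.

```python
from collections import defaultdict
from typing import Dict, List


def shift_within_groups(
    rows: List[Dict],
    value_col: str,
    period: int,
    order_by: str,
    group_by: str,
    func_name: str,  # "LAG" looks backwards, "LEAD" looks forwards
) -> List[Dict]:
    step = -period if func_name == "LAG" else period

    # Partition rows by the grouping key.
    partitions = defaultdict(list)
    for row in rows:
        partitions[row[group_by]].append(row)

    result = []
    for part in partitions.values():
        ordered = sorted(part, key=lambda r: r[order_by])
        for i, row in enumerate(ordered):
            j = i + step
            # Outside the partition there is no neighbour, so the value is None.
            shifted = ordered[j][value_col] if 0 <= j < len(ordered) else None
            out = dict(row)
            out[f"{value_col}_{func_name}_{period}"] = shifted
            result.append(out)
    return result


sales = [
    {"PRODUCTKEY": "A", "ORDERDATE": 1, "SALESAMOUNT": 10},
    {"PRODUCTKEY": "A", "ORDERDATE": 2, "SALESAMOUNT": 20},
    {"PRODUCTKEY": "A", "ORDERDATE": 3, "SALESAMOUNT": 30},
    {"PRODUCTKEY": "B", "ORDERDATE": 1, "SALESAMOUNT": 5},
]
lag_rows = shift_within_groups(sales, "SALESAMOUNT", 1, "ORDERDATE", "PRODUCTKEY", "LAG")
lead_rows = shift_within_groups(sales, "SALESAMOUNT", 1, "ORDERDATE", "PRODUCTKEY", "LEAD")
lagged = [r["SALESAMOUNT_LAG_1"] for r in lag_rows]
led = [r["SALESAMOUNT_LEAD_1"] for r in lead_rows]
```

Values never cross partitions: the single row for product `B` has neither a lag nor a lead value, even though rows for product `A` exist on either side of it in the input.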