SNOW-976704 : Adding lead and lag dataframe function #1147

Merged
merged 60 commits into from
Jan 31, 2024
60 commits
6840bfd
changes
sfc-gh-rsureshbabu Nov 23, 2023
266a281
updating changelog
sfc-gh-rsureshbabu Nov 23, 2023
2499a6d
fixing comment
sfc-gh-rsureshbabu Nov 23, 2023
4ba1a64
generalizing default formatter
sfc-gh-rsureshbabu Nov 23, 2023
810bfb8
generalizing default formatter 2
sfc-gh-rsureshbabu Nov 23, 2023
b43ffec
fix comment
sfc-gh-rsureshbabu Nov 23, 2023
adca1de
changes
sfc-gh-rsureshbabu Nov 23, 2023
5a007b2
cleaning argument checks
sfc-gh-rsureshbabu Nov 23, 2023
2d574b7
Merge branch 'rsureshbabu-SNOW-SNOW-976701-movingagg' into rsureshbab…
sfc-gh-rsureshbabu Nov 23, 2023
91ff442
cleaning argument checks
sfc-gh-rsureshbabu Nov 23, 2023
bd0bf2a
Merge branch 'rsureshbabu-SNOW-SNOW-976701-movingagg' into rsureshbab…
sfc-gh-rsureshbabu Nov 23, 2023
f66d519
cleaning argument checks
sfc-gh-rsureshbabu Nov 23, 2023
faccd9b
refactor
sfc-gh-rsureshbabu Nov 23, 2023
9622651
Merge branch 'rsureshbabu-SNOW-SNOW-976701-movingagg' into rsureshbab…
sfc-gh-rsureshbabu Nov 23, 2023
f5b0040
changes
sfc-gh-rsureshbabu Nov 23, 2023
e27f0ea
changes
sfc-gh-rsureshbabu Nov 23, 2023
cc4713a
changes
sfc-gh-rsureshbabu Nov 23, 2023
c5ff590
changes
sfc-gh-rsureshbabu Nov 23, 2023
e8d4345
fix test
sfc-gh-rsureshbabu Nov 23, 2023
3427e27
Merge branch 'rsureshbabu-SNOW-SNOW-976701-movingagg' into rsureshbab…
sfc-gh-rsureshbabu Nov 23, 2023
249b384
changes
sfc-gh-rsureshbabu Nov 23, 2023
ccf6655
changes
sfc-gh-rsureshbabu Nov 23, 2023
919e1c8
Merge branch 'rsureshbabu-SNOW-SNOW-976701-movingagg' into rsureshbab…
sfc-gh-rsureshbabu Nov 23, 2023
eef9062
changes
sfc-gh-rsureshbabu Nov 23, 2023
8e7adad
changes
sfc-gh-rsureshbabu Nov 23, 2023
1d0b265
Merge branch 'rsureshbabu-SNOW-SNOW-976701-movingagg' into rsureshbab…
sfc-gh-rsureshbabu Nov 23, 2023
d3a8b6b
changes
sfc-gh-rsureshbabu Nov 23, 2023
dff89e5
changes
sfc-gh-rsureshbabu Nov 23, 2023
f8e0fc7
changes
sfc-gh-rsureshbabu Nov 23, 2023
9a9a045
changes
sfc-gh-rsureshbabu Nov 23, 2023
3025d15
changes
sfc-gh-rsureshbabu Nov 23, 2023
d849dab
changes
sfc-gh-rsureshbabu Nov 23, 2023
706f15f
changes
sfc-gh-rsureshbabu Nov 23, 2023
12e24cd
Merge branch 'rsureshbabu-SNOW-SNOW-976701-movingagg' into rsureshbab…
sfc-gh-rsureshbabu Nov 23, 2023
809e129
tests
sfc-gh-rsureshbabu Nov 23, 2023
39a3e8a
fix comment
sfc-gh-rsureshbabu Dec 18, 2023
b0c2c85
Merge branch 'main' into rsureshbabu-SNOW-SNOW-976701-movingagg
sfc-gh-rsureshbabu Dec 18, 2023
48cf80e
changes
sfc-gh-rsureshbabu Dec 18, 2023
2599167
skip tests when pandas are not available
sfc-gh-rsureshbabu Jan 17, 2024
c4416ad
Merge branch 'main' into rsureshbabu-SNOW-SNOW-976701-movingagg
sfc-gh-rsureshbabu Jan 17, 2024
ce061bd
update change log
sfc-gh-rsureshbabu Jan 17, 2024
5c00182
changes
sfc-gh-rsureshbabu Jan 17, 2024
e0e18bf
changes
sfc-gh-rsureshbabu Jan 17, 2024
2d05cb8
adding doctest
sfc-gh-rsureshbabu Jan 18, 2024
9423599
renaming
sfc-gh-rsureshbabu Jan 23, 2024
721383f
renaming 2
sfc-gh-rsureshbabu Jan 23, 2024
27df8d2
Merge branch 'main' into rsureshbabu-SNOW-SNOW-976701-movingagg
sfc-gh-rsureshbabu Jan 23, 2024
24a016e
Merge branch 'rsureshbabu-SNOW-SNOW-976701-movingagg' into rsureshbab…
sfc-gh-rsureshbabu Jan 23, 2024
9bb98a9
move comment
sfc-gh-rsureshbabu Jan 23, 2024
cfb5e4a
fix comments
sfc-gh-rsureshbabu Jan 23, 2024
9583161
fix error message
sfc-gh-rsureshbabu Jan 24, 2024
045f216
fix merge
sfc-gh-rsureshbabu Jan 24, 2024
05a2da5
fix code coverage
sfc-gh-rsureshbabu Jan 24, 2024
ed24d6a
Merge branch 'rsureshbabu-SNOW-SNOW-976701-movingagg' into rsureshbab…
sfc-gh-rsureshbabu Jan 24, 2024
222ec96
add docstest
sfc-gh-rsureshbabu Jan 24, 2024
594e39d
merge from main
sfc-gh-rsureshbabu Jan 25, 2024
5b78ef2
fix tests
sfc-gh-rsureshbabu Jan 25, 2024
af473ff
address review comments
sfc-gh-rsureshbabu Jan 31, 2024
a0a6ef0
Merge branch 'main' into rsureshbabu-SNOW-976704-leadlagfunctions
sfc-gh-rsureshbabu Jan 31, 2024
83a06a9
change to with_columns
sfc-gh-rsureshbabu Jan 31, 2024
3 changes: 2 additions & 1 deletion CHANGELOG.md
@@ -18,7 +18,8 @@
- `sign`/`signum`
- Added the following functions to `DataFrame.analytics`:
- Added the `moving_agg` function in `DataFrame.analytics` to enable moving aggregations like sums and averages with multiple window sizes.
- Added the `cummulative_agg` function in `DataFrame.analytics` to enable moving aggregations like sums and averages with multiple window sizes.
- Added the `cumulative_agg` function in `DataFrame.analytics` to enable cumulative aggregations like sums and averages on multiple columns.
- Added the `compute_lag` and `compute_lead` functions in `DataFrame.analytics` to enable lead and lag calculations on multiple columns.

### Bug Fixes

150 changes: 148 additions & 2 deletions src/snowflake/snowpark/dataframe_analytics_functions.py
@@ -2,10 +2,12 @@
# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved.
#

from typing import Callable, Dict, List
from typing import Callable, Dict, List, Union

import snowflake.snowpark
from snowflake.snowpark.functions import expr
from snowflake.snowpark import Column
from snowflake.snowpark.column import _to_col_if_str
from snowflake.snowpark.functions import expr, lag, lead
from snowflake.snowpark.window import Window


@@ -85,6 +87,42 @@ def _validate_formatter_argument(self, formatter):
        if not callable(formatter):
            raise TypeError("formatter must be a callable function")

    def _compute_window_function(
        self,
        cols: List[Union[str, Column]],
        periods: List[int],
        order_by: List[str],
        group_by: List[str],
        col_formatter: Callable[[str, str, int], str],
        window_func: Callable[[Column, int], Column],
        func_name: str,
    ) -> "snowflake.snowpark.dataframe.DataFrame":
        """
        Generic function to create window function columns (lag or lead) for the DataFrame.
        Args:
            func_name: Should be either "LEAD" or "LAG".
        """
        self._validate_string_list_argument(order_by, "order_by")
        self._validate_string_list_argument(group_by, "group_by")
        self._validate_positive_integer_list_argument(periods, func_name.lower() + "s")
        self._validate_formatter_argument(col_formatter)

        window_spec = Window.partition_by(group_by).order_by(order_by)
        df = self._df
        col_names = []
        values = []
        for c in cols:
            for period in periods:
                column = _to_col_if_str(c, f"transform.compute_{func_name.lower()}")
                window_col = window_func(column, period).over(window_spec)
                formatted_col_name = col_formatter(
                    column.get_name().replace('"', ""), func_name, period
                )
                col_names.append(formatted_col_name)
                values.append(window_col)

        return df.with_columns(col_names, values)

    def moving_agg(
        self,
        aggs: Dict[str, List[str]],
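Reviewer note on the hunk above: the nested loop in `_compute_window_function` emits one output column per (column, period) pair, columns in the outer loop and periods in the inner one. A minimal pure-Python sketch of the resulting names follows. It assumes the default formatter joins its three arguments with underscores, which is inferred from the `SALESAMOUNT_LAG_1` header in the docstring output; `default_col_formatter` and `output_column_names` are hypothetical stand-ins, not the module's code.

```python
from itertools import product
from typing import Callable, List


def default_col_formatter(input_col: str, operation: str, value: int) -> str:
    # Hypothetical stand-in for _default_col_formatter (format inferred from
    # the docstring output, not copied from the library).
    return f"{input_col}_{operation}_{value}"


def output_column_names(
    cols: List[str],
    periods: List[int],
    func_name: str,
    col_formatter: Callable[[str, str, int], str] = default_col_formatter,
) -> List[str]:
    # product() iterates columns in the outer position and periods in the
    # inner one, matching the order of the nested for-loops in the helper.
    # Double quotes are stripped from the column name, as in the helper.
    return [
        col_formatter(c.replace('"', ""), func_name, p)
        for c, p in product(cols, periods)
    ]


names = output_column_names(["SALESAMOUNT", '"PRICE"'], [1, 2], "LAG")
custom = output_column_names(
    ["SALESAMOUNT"], [1, 2], "LEAD", lambda c, op, n: f"{op}_{c}_{n}"
)
print(names)
print(custom)
```

The second call mirrors the `custom_col_formatter` used in the integration tests, which puts the operation first.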
@@ -248,3 +286,111 @@ def cumulative_agg(
agg_df = agg_df.with_column(formatted_col_name, agg_col)

return agg_df

    def compute_lag(
        self,
        cols: List[Union[str, Column]],
        lags: List[int],
        order_by: List[str],
        group_by: List[str],
        col_formatter: Callable[[str, str, int], str] = _default_col_formatter,
    ) -> "snowflake.snowpark.dataframe.DataFrame":
        """
        Creates lag columns for the specified columns of the DataFrame, based on the grouping and ordering criteria.

        Args:
            cols: List of column names or Column objects to calculate lag features.
            lags: List of positive integers specifying periods to lag by.
            order_by: A list of column names that specify the order in which rows are processed.
            group_by: A list of column names on which the DataFrame is partitioned for separate window calculations.
            col_formatter: An optional function for formatting output column names, defaulting to the format '<input_col>_LAG_<lag>'.
                This function takes three arguments: 'input_col' (str) for the column name, 'operation' (str) for the applied operation,
                and 'value' (int) for the lag value, and returns a formatted string for the column name.

        Returns:
            A Snowflake DataFrame with additional columns corresponding to each specified lag period.

        Example:
            >>> sample_data = [
            ...     ["2023-01-01", 101, 200],
            ...     ["2023-01-02", 101, 100],
            ...     ["2023-01-03", 101, 300],
            ...     ["2023-01-04", 102, 250],
            ... ]
            >>> df = session.create_dataframe(sample_data).to_df(
            ...     "ORDERDATE", "PRODUCTKEY", "SALESAMOUNT"
            ... )
            >>> res = df.analytics.compute_lag(
            ...     cols=["SALESAMOUNT"],
            ...     lags=[1, 2],
            ...     order_by=["ORDERDATE"],
            ...     group_by=["PRODUCTKEY"],
            ... )
            >>> res.show()
            ------------------------------------------------------------------------------------------
            |"ORDERDATE"  |"PRODUCTKEY"  |"SALESAMOUNT"  |"SALESAMOUNT_LAG_1"  |"SALESAMOUNT_LAG_2"  |
            ------------------------------------------------------------------------------------------
            |2023-01-04   |102           |250            |NULL                 |NULL                 |
            |2023-01-01   |101           |200            |NULL                 |NULL                 |
            |2023-01-02   |101           |100            |200                  |NULL                 |
            |2023-01-03   |101           |300            |100                  |200                  |
            ------------------------------------------------------------------------------------------
            <BLANKLINE>
        """
        return self._compute_window_function(
            cols, lags, order_by, group_by, col_formatter, lag, "LAG"
        )

    def compute_lead(
        self,
        cols: List[Union[str, Column]],
        leads: List[int],
        order_by: List[str],
        group_by: List[str],
        col_formatter: Callable[[str, str, int], str] = _default_col_formatter,
    ) -> "snowflake.snowpark.dataframe.DataFrame":
        """
        Creates lead columns for the specified columns of the DataFrame, based on the grouping and ordering criteria.

        Args:
            cols: List of column names or Column objects to calculate lead features.
            leads: List of positive integers specifying periods to lead by.
            order_by: A list of column names that specify the order in which rows are processed.
            group_by: A list of column names on which the DataFrame is partitioned for separate window calculations.
            col_formatter: An optional function for formatting output column names, defaulting to the format '<input_col>_LEAD_<lead>'.
                This function takes three arguments: 'input_col' (str) for the column name, 'operation' (str) for the applied operation,
                and 'value' (int) for the lead value, and returns a formatted string for the column name.

        Returns:
            A Snowflake DataFrame with additional columns corresponding to each specified lead period.

        Example:
            >>> sample_data = [
            ...     ["2023-01-01", 101, 200],
            ...     ["2023-01-02", 101, 100],
            ...     ["2023-01-03", 101, 300],
            ...     ["2023-01-04", 102, 250],
            ... ]
            >>> df = session.create_dataframe(sample_data).to_df(
            ...     "ORDERDATE", "PRODUCTKEY", "SALESAMOUNT"
            ... )
            >>> res = df.analytics.compute_lead(
            ...     cols=["SALESAMOUNT"],
            ...     leads=[1, 2],
            ...     order_by=["ORDERDATE"],
            ...     group_by=["PRODUCTKEY"]
            ... )
            >>> res.show()
            --------------------------------------------------------------------------------------------
            |"ORDERDATE"  |"PRODUCTKEY"  |"SALESAMOUNT"  |"SALESAMOUNT_LEAD_1"  |"SALESAMOUNT_LEAD_2"  |
            --------------------------------------------------------------------------------------------
            |2023-01-04   |102           |250            |NULL                  |NULL                  |
            |2023-01-01   |101           |200            |100                   |300                   |
            |2023-01-02   |101           |100            |300                   |NULL                  |
            |2023-01-03   |101           |300            |NULL                  |NULL                  |
            --------------------------------------------------------------------------------------------
            <BLANKLINE>
        """
        return self._compute_window_function(
            cols, leads, order_by, group_by, col_formatter, lead, "LEAD"
        )
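Reviewer note: to sanity-check the docstring tables without a Snowflake session, the lag/lead semantics can be approximated in plain Python. This is only a sketch of "partition by `PRODUCTKEY`, order by `ORDERDATE`, look `offset` rows away"; the function `shift_within_groups` and the tuple layout are hypothetical, and it does not model Snowflake's execution or NULL-ordering rules.

```python
from collections import defaultdict
from typing import Dict, List, Optional, Tuple

Row = Tuple[str, int, int]  # (ORDERDATE, PRODUCTKEY, SALESAMOUNT)


def shift_within_groups(rows: List[Row], offset: int) -> Dict[Row, Optional[int]]:
    """Map each row to the SALESAMOUNT found `offset` rows away in its group.

    A negative offset looks backwards (lag); a positive one looks forwards (lead).
    Rows with no neighbour at that distance map to None (NULL in the tables).
    """
    groups: Dict[int, List[Row]] = defaultdict(list)
    for row in rows:
        groups[row[1]].append(row)  # partition by PRODUCTKEY

    shifted: Dict[Row, Optional[int]] = {}
    for members in groups.values():
        members.sort(key=lambda r: r[0])  # order by ORDERDATE (ISO date strings)
        for i, row in enumerate(members):
            j = i + offset
            shifted[row] = members[j][2] if 0 <= j < len(members) else None
    return shifted


rows = [
    ("2023-01-01", 101, 200),
    ("2023-01-02", 101, 100),
    ("2023-01-03", 101, 300),
    ("2023-01-04", 102, 250),
]
lag_1 = shift_within_groups(rows, -1)
lead_2 = shift_within_groups(rows, 2)
for row in rows:
    print(row, lag_1[row], lead_2[row])
```

Product 102 has a single row, so every lag and lead for it is `None`, which matches the `NULL` cells in both docstring examples.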
103 changes: 103 additions & 0 deletions tests/integ/test_df_analytics.py
@@ -14,6 +14,7 @@
import pytest

from snowflake.snowpark.exceptions import SnowparkSQLException
from snowflake.snowpark.functions import col


def get_sample_dataframe(session):
@@ -310,3 +311,105 @@ def custom_formatter(input_col, agg):
check_dtype=False,
atol=1e-1,
)


@pytest.mark.skipif(not is_pandas_available, reason="pandas is required")
def test_compute_lead(session):
    """Tests df.analytics.compute_lead() happy path."""

    df = get_sample_dataframe(session)

    def custom_col_formatter(input_col, op, lead):
        return f"{op}_{input_col}_{lead}"

    res = df.analytics.compute_lead(
        cols=["SALESAMOUNT"],
        leads=[1, 2],
        order_by=["ORDERDATE"],
        group_by=["PRODUCTKEY"],
        col_formatter=custom_col_formatter,
    )

    res = res.withColumn("LEAD_SALESAMOUNT_1", col("LEAD_SALESAMOUNT_1").cast("float"))
    res = res.withColumn("LEAD_SALESAMOUNT_2", col("LEAD_SALESAMOUNT_2").cast("float"))

    expected_data = {
        "ORDERDATE": ["2023-01-01", "2023-01-02", "2023-01-03", "2023-01-04"],
        "PRODUCTKEY": [101, 101, 101, 102],
        "SALESAMOUNT": [200, 100, 300, 250],
        "LEAD_SALESAMOUNT_1": [100, 300, None, None],
        "LEAD_SALESAMOUNT_2": [300, None, None, None],
    }
    expected_df = pd.DataFrame(expected_data)

    assert_frame_equal(
        res.order_by("ORDERDATE").to_pandas(), expected_df, check_dtype=False, atol=1e-1
    )


@pytest.mark.skipif(not is_pandas_available, reason="pandas is required")
def test_compute_lag(session):
    """Tests df.analytics.compute_lag() happy path."""

    df = get_sample_dataframe(session)

    def custom_col_formatter(input_col, op, lag):
        return f"{op}_{input_col}_{lag}"

    res = df.analytics.compute_lag(
        cols=["SALESAMOUNT"],
        lags=[1, 2],
        order_by=["ORDERDATE"],
        group_by=["PRODUCTKEY"],
        col_formatter=custom_col_formatter,
    )

    res = res.withColumn("LAG_SALESAMOUNT_1", col("LAG_SALESAMOUNT_1").cast("float"))
    res = res.withColumn("LAG_SALESAMOUNT_2", col("LAG_SALESAMOUNT_2").cast("float"))

    expected_data = {
        "ORDERDATE": ["2023-01-01", "2023-01-02", "2023-01-03", "2023-01-04"],
        "PRODUCTKEY": [101, 101, 101, 102],
        "SALESAMOUNT": [200, 100, 300, 250],
        "LAG_SALESAMOUNT_1": [None, 200, 100, None],
        "LAG_SALESAMOUNT_2": [None, None, 200, None],
    }
    expected_df = pd.DataFrame(expected_data)

    assert_frame_equal(
        res.order_by("ORDERDATE").to_pandas(), expected_df, check_dtype=False, atol=1e-1
    )


@pytest.mark.skipif(not is_pandas_available, reason="pandas is required")
def test_lead_lag_invalid_inputs(session):
    """Tests df.analytics.compute_lag() and df.analytics.compute_lead() with invalid inputs."""

    df = get_sample_dataframe(session)

    with pytest.raises(ValueError) as exc:
        df.analytics.compute_lead(
            cols=["SALESAMOUNT"],
            leads=[-1, -2],
            order_by=["ORDERDATE"],
            group_by=["PRODUCTKEY"],
        ).collect()
    assert "leads must be a list of integers > 0" in str(exc)

    with pytest.raises(ValueError) as exc:
        df.analytics.compute_lead(
            cols=["SALESAMOUNT"],
            leads=[0, 2],
            order_by=["ORDERDATE"],
            group_by=["PRODUCTKEY"],
        ).collect()
    assert "leads must be a list of integers > 0" in str(exc)

    with pytest.raises(ValueError) as exc:
        df.analytics.compute_lag(
            cols=["SALESAMOUNT"],
            lags=[-1, -2],
            order_by=["ORDERDATE"],
            group_by=["PRODUCTKEY"],
        ).collect()
    assert "lags must be a list of integers > 0" in str(exc)
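Reviewer note: the body of `_validate_positive_integer_list_argument` is not part of this diff, so the sketch below is a hypothetical re-implementation that only reproduces the error wording asserted in the tests. Rejecting `bool` values is an assumption of this sketch, not something the PR confirms. It also shows how the argument name in the message is derived from `func_name.lower() + "s"` in `_compute_window_function`.

```python
from typing import List, Optional


def validate_positive_integer_list(values: List[int], name: str) -> None:
    # Hypothetical validator; only the message format comes from the tests.
    # Assumption: bools are rejected even though bool is a subclass of int.
    ok = all(
        isinstance(v, int) and not isinstance(v, bool) and v > 0 for v in values
    )
    if not ok:
        raise ValueError(f"{name} must be a list of integers > 0")


def check(values: List[int], func_name: str) -> Optional[str]:
    # "LEAD" -> "leads", "LAG" -> "lags", as in _compute_window_function.
    arg_name = func_name.lower() + "s"
    try:
        validate_positive_integer_list(values, arg_name)
    except ValueError as err:
        return str(err)
    return None


print(check([1, 2], "LAG"))
print(check([0, 2], "LEAD"))
print(check([-1, -2], "LAG"))
```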