Skip to content

Commit

Permalink
SNOW-976704 : Adding lead and lag dataframe function (#1147)
Browse files Browse the repository at this point in the history
The PR is adding the lead and lag functions as proposed in https://docs.google.com/document/d/14J5lr_a3fE1xeU-YzKWdeduBdC_035sry11lJUjjkT8/edit
  • Loading branch information
sfc-gh-rsureshbabu authored Jan 31, 2024
1 parent 42c525d commit 3cc1e70
Show file tree
Hide file tree
Showing 3 changed files with 253 additions and 3 deletions.
3 changes: 2 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@
- `sign`/`signum`
- Added the following functions to `DataFrame.analytics`:
- Added the `moving_agg` function in `DataFrame.analytics` to enable moving aggregations like sums and averages with multiple window sizes.
- Added the `cummulative_agg` function in `DataFrame.analytics` to enable moving aggregations like sums and averages with multiple window sizes.
- Added the `cummulative_agg` function in `DataFrame.analytics` to enable commulative aggregations like sums and averages on multiple columns.
- Added the `compute_lag` and `compute_lead` function in `DataFrame.analytics` for enabling lead and lag calculations on multiple columns.

### Bug Fixes

Expand Down
150 changes: 148 additions & 2 deletions src/snowflake/snowpark/dataframe_analytics_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@
# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved.
#

from typing import Callable, Dict, List
from typing import Callable, Dict, List, Union

import snowflake.snowpark
from snowflake.snowpark.functions import expr
from snowflake.snowpark import Column
from snowflake.snowpark.column import _to_col_if_str
from snowflake.snowpark.functions import expr, lag, lead
from snowflake.snowpark.window import Window


Expand Down Expand Up @@ -85,6 +87,42 @@ def _validate_formatter_argument(self, fromatter):
if not callable(fromatter):
raise TypeError("formatter must be a callable function")

def _compute_window_function(
self,
cols: List[Union[str, Column]],
periods: List[int],
order_by: List[str],
group_by: List[str],
col_formatter: Callable[[str, str, int], str],
window_func: Callable[[Column, int], Column],
func_name: str,
) -> "snowflake.snowpark.dataframe.DataFrame":
"""
Generic function to create window function columns (lag or lead) for the DataFrame.
Args:
func_name: Should be either "LEAD" or "LAG".
"""
self._validate_string_list_argument(order_by, "order_by")
self._validate_string_list_argument(group_by, "group_by")
self._validate_positive_integer_list_argument(periods, func_name.lower() + "s")
self._validate_formatter_argument(col_formatter)

window_spec = Window.partition_by(group_by).order_by(order_by)
df = self._df
col_names = []
values = []
for c in cols:
for period in periods:
column = _to_col_if_str(c, f"transform.compute_{func_name.lower()}")
window_col = window_func(column, period).over(window_spec)
formatted_col_name = col_formatter(
column.get_name().replace('"', ""), func_name, period
)
col_names.append(formatted_col_name)
values.append(window_col)

return df.with_columns(col_names, values)

def moving_agg(
self,
aggs: Dict[str, List[str]],
Expand Down Expand Up @@ -248,3 +286,111 @@ def cumulative_agg(
agg_df = agg_df.with_column(formatted_col_name, agg_col)

return agg_df

def compute_lag(
self,
cols: List[Union[str, Column]],
lags: List[int],
order_by: List[str],
group_by: List[str],
col_formatter: Callable[[str, str, int], str] = _default_col_formatter,
) -> "snowflake.snowpark.dataframe.DataFrame":
"""
Creates lag columns to the specified columns of the DataFrame by grouping and ordering criteria.
Args:
cols: List of column names or Column objects to calculate lag features.
lags: List of positive integers specifying periods to lag by.
order_by: A list of column names that specify the order in which rows are processed.
group_by: A list of column names on which the DataFrame is partitioned for separate window calculations.
col_formatter: An optional function for formatting output column names, defaulting to the format '<input_col>LAG<lag>'.
This function takes three arguments: 'input_col' (str) for the column name, 'operation' (str) for the applied operation,
and 'value' (int) for lag value, and returns a formatted string for the column name.
Returns:
A Snowflake DataFrame with additional columns corresponding to each specified lag period.
Example:
>>> sample_data = [
... ["2023-01-01", 101, 200],
... ["2023-01-02", 101, 100],
... ["2023-01-03", 101, 300],
... ["2023-01-04", 102, 250],
... ]
>>> df = session.create_dataframe(sample_data).to_df(
... "ORDERDATE", "PRODUCTKEY", "SALESAMOUNT"
... )
>>> res = df.analytics.compute_lag(
... cols=["SALESAMOUNT"],
... lags=[1, 2],
... order_by=["ORDERDATE"],
... group_by=["PRODUCTKEY"],
... )
>>> res.show()
------------------------------------------------------------------------------------------
|"ORDERDATE" |"PRODUCTKEY" |"SALESAMOUNT" |"SALESAMOUNT_LAG_1" |"SALESAMOUNT_LAG_2" |
------------------------------------------------------------------------------------------
|2023-01-04 |102 |250 |NULL |NULL |
|2023-01-01 |101 |200 |NULL |NULL |
|2023-01-02 |101 |100 |200 |NULL |
|2023-01-03 |101 |300 |100 |200 |
------------------------------------------------------------------------------------------
<BLANKLINE>
"""
return self._compute_window_function(
cols, lags, order_by, group_by, col_formatter, lag, "LAG"
)

def compute_lead(
self,
cols: List[Union[str, Column]],
leads: List[int],
order_by: List[str],
group_by: List[str],
col_formatter: Callable[[str, str, int], str] = _default_col_formatter,
) -> "snowflake.snowpark.dataframe.DataFrame":
"""
Creates lead columns to the specified columns of the DataFrame by grouping and ordering criteria.
Args:
cols: List of column names or Column objects to calculate lead features.
leads: List of positive integers specifying periods to lead by.
order_by: A list of column names that specify the order in which rows are processed.
group_by: A list of column names on which the DataFrame is partitioned for separate window calculations.
col_formatter: An optional function for formatting output column names, defaulting to the format '<input_col>LEAD<lead>'.
This function takes three arguments: 'input_col' (str) for the column name, 'operation' (str) for the applied operation,
and 'value' (int) for the lead value, and returns a formatted string for the column name.
Returns:
A Snowflake DataFrame with additional columns corresponding to each specified lead period.
Example:
>>> sample_data = [
... ["2023-01-01", 101, 200],
... ["2023-01-02", 101, 100],
... ["2023-01-03", 101, 300],
... ["2023-01-04", 102, 250],
... ]
>>> df = session.create_dataframe(sample_data).to_df(
... "ORDERDATE", "PRODUCTKEY", "SALESAMOUNT"
... )
>>> res = df.analytics.compute_lead(
... cols=["SALESAMOUNT"],
... leads=[1, 2],
... order_by=["ORDERDATE"],
... group_by=["PRODUCTKEY"]
... )
>>> res.show()
--------------------------------------------------------------------------------------------
|"ORDERDATE" |"PRODUCTKEY" |"SALESAMOUNT" |"SALESAMOUNT_LEAD_1" |"SALESAMOUNT_LEAD_2" |
--------------------------------------------------------------------------------------------
|2023-01-04 |102 |250 |NULL |NULL |
|2023-01-01 |101 |200 |100 |300 |
|2023-01-02 |101 |100 |300 |NULL |
|2023-01-03 |101 |300 |NULL |NULL |
--------------------------------------------------------------------------------------------
<BLANKLINE>
"""
return self._compute_window_function(
cols, leads, order_by, group_by, col_formatter, lead, "LEAD"
)
103 changes: 103 additions & 0 deletions tests/integ/test_df_analytics.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import pytest

from snowflake.snowpark.exceptions import SnowparkSQLException
from snowflake.snowpark.functions import col


def get_sample_dataframe(session):
Expand Down Expand Up @@ -310,3 +311,105 @@ def custom_formatter(input_col, agg):
check_dtype=False,
atol=1e-1,
)


@pytest.mark.skipif(not is_pandas_available, reason="pandas is required")
def test_compute_lead(session):
"""Tests df.analytics.compute_lead() happy path."""

df = get_sample_dataframe(session)

def custom_col_formatter(input_col, op, lead):
return f"{op}_{input_col}_{lead}"

res = df.analytics.compute_lead(
cols=["SALESAMOUNT"],
leads=[1, 2],
order_by=["ORDERDATE"],
group_by=["PRODUCTKEY"],
col_formatter=custom_col_formatter,
)

res = res.withColumn("LEAD_SALESAMOUNT_1", col("LEAD_SALESAMOUNT_1").cast("float"))
res = res.withColumn("LEAD_SALESAMOUNT_2", col("LEAD_SALESAMOUNT_2").cast("float"))

expected_data = {
"ORDERDATE": ["2023-01-01", "2023-01-02", "2023-01-03", "2023-01-04"],
"PRODUCTKEY": [101, 101, 101, 102],
"SALESAMOUNT": [200, 100, 300, 250],
"LEAD_SALESAMOUNT_1": [100, 300, None, None],
"LEAD_SALESAMOUNT_2": [300, None, None, None],
}
expected_df = pd.DataFrame(expected_data)

assert_frame_equal(
res.order_by("ORDERDATE").to_pandas(), expected_df, check_dtype=False, atol=1e-1
)


@pytest.mark.skipif(not is_pandas_available, reason="pandas is required")
def test_compute_lag(session):
"""Tests df.analytics.compute_lag() happy path."""

df = get_sample_dataframe(session)

def custom_col_formatter(input_col, op, lead):
return f"{op}_{input_col}_{lead}"

res = df.analytics.compute_lag(
cols=["SALESAMOUNT"],
lags=[1, 2],
order_by=["ORDERDATE"],
group_by=["PRODUCTKEY"],
col_formatter=custom_col_formatter,
)

res = res.withColumn("LAG_SALESAMOUNT_1", col("LAG_SALESAMOUNT_1").cast("float"))
res = res.withColumn("LAG_SALESAMOUNT_2", col("LAG_SALESAMOUNT_2").cast("float"))

expected_data = {
"ORDERDATE": ["2023-01-01", "2023-01-02", "2023-01-03", "2023-01-04"],
"PRODUCTKEY": [101, 101, 101, 102],
"SALESAMOUNT": [200, 100, 300, 250],
"LAG_SALESAMOUNT_1": [None, 200, 100, None],
"LAG_SALESAMOUNT_2": [None, None, 200, None],
}
expected_df = pd.DataFrame(expected_data)

assert_frame_equal(
res.order_by("ORDERDATE").to_pandas(), expected_df, check_dtype=False, atol=1e-1
)


@pytest.mark.skipif(not is_pandas_available, reason="pandas is required")
def test_lead_lag_invalid_inputs(session):
"""Tests df.analytics.compute_lag() and df.analytics.compute_lead() with invalid_inputs."""

df = get_sample_dataframe(session)

with pytest.raises(ValueError) as exc:
df.analytics.compute_lead(
cols=["SALESAMOUNT"],
leads=[-1, -2],
order_by=["ORDERDATE"],
group_by=["PRODUCTKEY"],
).collect()
assert "leads must be a list of integers > 0" in str(exc)

with pytest.raises(ValueError) as exc:
df.analytics.compute_lead(
cols=["SALESAMOUNT"],
leads=[0, 2],
order_by=["ORDERDATE"],
group_by=["PRODUCTKEY"],
).collect()
assert "leads must be a list of integers > 0" in str(exc)

with pytest.raises(ValueError) as exc:
df.analytics.compute_lag(
cols=["SALESAMOUNT"],
lags=[-1, -2],
order_by=["ORDERDATE"],
group_by=["PRODUCTKEY"],
).collect()
assert "lags must be a list of integers > 0" in str(exc)

0 comments on commit 3cc1e70

Please sign in to comment.