SNOW-976702: Add timeseries agg dataframe function. #1181

Merged Feb 2, 2024 (47 commits)
Changes from 19 commits
Commits (all by sfc-gh-rsureshbabu):
6840bfd  changes (Nov 23, 2023)
266a281  updating changelog (Nov 23, 2023)
2499a6d  fixing comment (Nov 23, 2023)
4ba1a64  generalizing default formatter (Nov 23, 2023)
810bfb8  generalizing default formatter 2 (Nov 23, 2023)
b43ffec  fix comment (Nov 23, 2023)
5a007b2  cleaning argument checks (Nov 23, 2023)
91ff442  cleaning argument checks (Nov 23, 2023)
faccd9b  refactor (Nov 23, 2023)
e27f0ea  changes (Nov 23, 2023)
e8d4345  fix test (Nov 23, 2023)
ccf6655  changes (Nov 23, 2023)
8e7adad  changes (Nov 23, 2023)
9a9a045  changes (Nov 23, 2023)
706f15f  changes (Nov 23, 2023)
4488351  working changes (Dec 2, 2023)
833215a  working changes (Dec 12, 2023)
e22eaf9  working code (Dec 13, 2023)
8cf6bc8  fix empty lines (Dec 13, 2023)
39a3e8a  fix comment (Dec 18, 2023)
b0c2c85  Merge branch 'main' into rsureshbabu-SNOW-SNOW-976701-movingagg (Dec 18, 2023)
c8bd223  adding seconds and minutes support (Dec 19, 2023)
6b1a7e3  more tests (Dec 20, 2023)
2599167  skip tests when pandas are not available (Jan 17, 2024)
c4416ad  Merge branch 'main' into rsureshbabu-SNOW-SNOW-976701-movingagg (Jan 17, 2024)
ce061bd  update change log (Jan 17, 2024)
5c00182  changes (Jan 17, 2024)
e0e18bf  changes (Jan 17, 2024)
2d05cb8  adding doctest (Jan 18, 2024)
9423599  renaming (Jan 23, 2024)
721383f  renaming 2 (Jan 23, 2024)
27df8d2  Merge branch 'main' into rsureshbabu-SNOW-SNOW-976701-movingagg (Jan 23, 2024)
cfb5e4a  fix comments (Jan 23, 2024)
9583161  fix error message (Jan 24, 2024)
045f216  fix merge (Jan 24, 2024)
05a2da5  fix code coverage (Jan 24, 2024)
f765e7b  merge fix (Jan 24, 2024)
ff58802  merge from main (Jan 26, 2024)
f8fb88b  fix docstring (Jan 26, 2024)
d6b029b  address review comments (Jan 26, 2024)
0b74d98  remove old file (Jan 26, 2024)
39e1a84  switch to mm instead of t (Jan 26, 2024)
8bafa95  fix comment (Jan 29, 2024)
192ea4e  marking API as experimental (Jan 30, 2024)
cb66dc2  address review comments (Jan 31, 2024)
8ca0bd2  merge from main (Jan 31, 2024)
31c06e5  update changelog (Feb 1, 2024)
CHANGELOG.md (4 additions, 0 deletions)
@@ -10,6 +10,10 @@

- Add the `conn_error` attribute to `SnowflakeSQLException` that stores the whole underlying exception from `snowflake-connector-python`

### New Features

- Added the `moving_agg` function to `DataFrame.transform`, enabling moving aggregations for time series analysis, such as sums and averages over multiple window sizes.

### Bug Fixes

- DataFrame column name quoting check now supports newline characters.
Expand Down
src/snowflake/snowpark/__init__.py (2 additions, 0 deletions)
@@ -19,6 +19,7 @@
"GetResult",
"DataFrame",
"DataFrameStatFunctions",
"DataFrameTransformFunctions",
"DataFrameNaFunctions",
"DataFrameWriter",
"DataFrameReader",
@@ -49,6 +50,7 @@
from snowflake.snowpark.dataframe_na_functions import DataFrameNaFunctions
from snowflake.snowpark.dataframe_reader import DataFrameReader
from snowflake.snowpark.dataframe_stat_functions import DataFrameStatFunctions
from snowflake.snowpark.dataframe_transform_functions import DataFrameTransformFunctions
from snowflake.snowpark.dataframe_writer import DataFrameWriter
from snowflake.snowpark.file_operation import FileOperation, GetResult, PutResult
from snowflake.snowpark.query_history import QueryHistory, QueryRecord
src/snowflake/snowpark/dataframe.py (6 additions, 0 deletions)
@@ -123,6 +123,7 @@
from snowflake.snowpark.column import Column, _to_col_if_sql_expr, _to_col_if_str
from snowflake.snowpark.dataframe_na_functions import DataFrameNaFunctions
from snowflake.snowpark.dataframe_stat_functions import DataFrameStatFunctions
from snowflake.snowpark.dataframe_transform_functions import DataFrameTransformFunctions
from snowflake.snowpark.dataframe_writer import DataFrameWriter
from snowflake.snowpark.exceptions import SnowparkDataframeException
from snowflake.snowpark.functions import (
@@ -521,6 +522,7 @@ def __init__(
self._writer = DataFrameWriter(self)

self._stat = DataFrameStatFunctions(self)
self._transform = DataFrameTransformFunctions(self)
self.approxQuantile = self.approx_quantile = self._stat.approx_quantile
self.corr = self._stat.corr
self.cov = self._stat.cov
@@ -538,6 +540,10 @@ def __init__(
def stat(self) -> DataFrameStatFunctions:
return self._stat

@property
def transform(self) -> DataFrameTransformFunctions:
return self._transform

@overload
def collect(
self,
src/snowflake/snowpark/dataframe_transform_functions.py (298 additions, 0 deletions)
@@ -0,0 +1,298 @@
#
# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved.
#

from typing import Callable, Dict, List, Tuple

import snowflake.snowpark
from snowflake.snowpark.functions import (
col,
dateadd,
expr,
from_unixtime,
lit,
unix_timestamp,
)
from snowflake.snowpark.window import Window


class DataFrameTransformFunctions:
"""Provides data transformation functions for DataFrames.
To access an object of this class, use :attr:`DataFrame.transform`.
"""

def __init__(self, df: "snowflake.snowpark.DataFrame") -> None:
self._df = df

def _default_col_formatter(input_col: str, operation: str, *args) -> str:
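# e.g. _default_col_formatter("SALESAMOUNT", "SUM", 7) -> "SALESAMOUNT_SUM_7";
# with no extra args, _default_col_formatter("COL", "AVG") -> "COL_AVG".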
args_str = "_".join(map(str, args))
formatted_name = f"{input_col}_{operation}"
if args_str:
formatted_name += f"_{args_str}"
return formatted_name

def _validate_aggs_argument(self, aggs):
if not isinstance(aggs, dict):
raise TypeError("aggs must be a dictionary")
if not aggs:
raise ValueError("aggs must not be empty")
if not all(
isinstance(key, str) and isinstance(val, list) and val
for key, val in aggs.items()
):
raise ValueError(
"aggs must have strings as keys and non-empty lists of strings as values"
)

def _validate_string_list_argument(self, data, argument_name):
if not isinstance(data, list):
raise TypeError(f"{argument_name} must be a list")
if not data:
raise ValueError(f"{argument_name} must not be empty")
if not all(isinstance(item, str) for item in data):
raise ValueError(f"{argument_name} must be a list of strings")

def _validate_positive_integer_list_argument(self, data, argument_name):
if not isinstance(data, list):
raise TypeError(f"{argument_name} must be a list")
if not data:
raise ValueError(f"{argument_name} must not be empty")
if not all(isinstance(item, int) and item > 0 for item in data):
raise ValueError(f"{argument_name} must be a list of integers > 0")

def _validate_formatter_argument(self, formatter):
if not callable(formatter):
raise TypeError("formatter must be a callable function")

def _validate_and_extract_time_unit(
self, time_str, argument_name, allow_negative=True
) -> Tuple[int, str]:
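# Parses strings like '7D' into (7, 'd'). The sign carries the direction,
# e.g. '-2W' -> (-2, 'w'); month ('m') is mapped to Snowpark's 'mm' below.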
if not time_str:
raise ValueError(f"{argument_name} must not be empty")

duration = int(time_str[:-1])
unit = time_str[-1].lower()

if not allow_negative and duration < 0:
raise ValueError(f"{argument_name} must not be negative.")

supported_units = ["h", "d", "w", "m", "y"]
if unit not in supported_units:
raise ValueError(
f"Unsupported unit '{unit}'. Supported units are '{supported_units}"
)

# Converting month unit to 'mm' for Snowpark
if unit == "m":
unit = "mm"

return duration, unit

def _get_sliding_interval_start(self, time_col, unit, duration):
unit_seconds = {"h": 3600, "d": 86400, "w": 604800}

if unit not in unit_seconds:
raise ValueError("Invalid unit. Supported units are 'H', 'D', 'W'.")

interval_seconds = unit_seconds[unit] * duration

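# Floor the epoch seconds to a multiple of interval_seconds so every row
# maps to the start of the sliding bucket its timestamp falls into.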
return from_unixtime(
(unix_timestamp(time_col) / interval_seconds).cast("long")
* interval_seconds
)

def moving_agg(
self,
aggs: Dict[str, List[str]],
window_sizes: List[int],
order_by: List[str],
group_by: List[str],
col_formatter: Callable[[str, str, int], str] = _default_col_formatter,
) -> "snowflake.snowpark.dataframe.DataFrame":
"""
Applies moving aggregations to the specified columns of the DataFrame using defined window sizes,
and grouping and ordering criteria.

Args:
aggs: A dictionary where keys are column names and values are lists of the desired aggregation functions.
window_sizes: A list of positive integers, each representing the size of the window for which to
calculate the moving aggregate.
order_by: A list of column names that specify the order in which rows are processed.
group_by: A list of column names on which the DataFrame is partitioned for separate window calculations.
col_formatter: An optional function for formatting output column names, defaulting to the format '<input_col>_<agg>_<window>'.
This function takes three arguments: 'input_col' (str) for the column name, 'operation' (str) for the applied operation,
and 'value' (int) for the window size, and returns a formatted string for the column name.

Returns:
A Snowflake DataFrame with additional columns corresponding to each specified moving aggregation.

Raises:
ValueError: If an unsupported value is specified in arguments.
TypeError: If an unsupported type is specified in arguments.
SnowparkSQLException: If an unsupported aggregation is specified.

Example:
aggregated_df = df.transform.moving_agg(
aggs={"SALESAMOUNT": ['SUM', 'AVG']},
window_sizes=[1, 2, 3, 7],
order_by=['ORDERDATE'],
group_by=['PRODUCTKEY']
)
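This adds columns such as SALESAMOUNT_SUM_1 and SALESAMOUNT_AVG_7,
following the default '<input_col>_<agg>_<window>' naming.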
"""
# Validate input arguments
self._validate_aggs_argument(aggs)
self._validate_string_list_argument(order_by, "order_by")
self._validate_string_list_argument(group_by, "group_by")
self._validate_positive_integer_list_argument(window_sizes, "window_sizes")
self._validate_formatter_argument(col_formatter)

# Perform window aggregation
agg_df = self._df
for column, agg_funcs in aggs.items():
for window_size in window_sizes:
for agg_func in agg_funcs:
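# Trailing window of exactly window_size rows: the current row
# plus the (window_size - 1) rows that precede it.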
window_spec = (
Window.partition_by(group_by)
.order_by(order_by)
.rows_between(-window_size + 1, 0)
)

# Apply the user-specified aggregation function directly. Snowflake will handle any errors for invalid functions.
agg_col = expr(f"{agg_func}({column})").over(window_spec)

formatted_col_name = col_formatter(column, agg_func, window_size)
agg_df = agg_df.with_column(formatted_col_name, agg_col)

return agg_df

def time_series_agg(
self,
time_col: str,
aggs: Dict[str, List[str]],
windows: List[str],
group_by: List[str],
sliding_interval: str,
col_formatter: Callable[[str, str, str], str] = _default_col_formatter,
) -> "snowflake.snowpark.dataframe.DataFrame":
"""
Applies aggregations to the specified columns of the DataFrame over specified time windows,
and grouping criteria.

Args:
time_col: Name of the timestamp column over which sliding points and time windows are computed.
aggs: A dictionary where keys are column names and values are lists of the desired aggregation functions.
windows: Time windows for aggregations using strings such as '7D' for 7 days, where the units are
H: Hours, D: Days, W: Weeks, M: Months, Y: Years. For future-oriented analysis, use positive numbers,
and for past-oriented analysis, use negative numbers.
group_by: A list of column names on which the DataFrame is partitioned for separate window calculations.
sliding_interval: Interval at which the window slides, specified in the same format as the windows.
Supported units are H: Hours, D: Days, W: Weeks.
col_formatter: An optional function for formatting output column names, defaulting to the format '<input_col>_<agg>_<window>'.
This function takes three arguments: 'input_col' (str) for the column name, 'operation' (str) for the applied operation,
and 'value' (str) for the window, and returns a formatted string for the column name.

Returns:
A Snowflake DataFrame with additional columns corresponding to each specified time window aggregation.

Raises:
ValueError: If an unsupported value is specified in arguments.
TypeError: If an unsupported type is specified in arguments.
SnowparkSQLException: If an unsupported aggregation is specified.

Example:
aggregated_df = df.transform.time_series_agg(
time_col='ORDERTIME',
group_by=['PRODUCTKEY'],
aggs={
'SALESAMOUNT': ['SUM', 'MIN', 'MAX']
},
sliding_interval='12H',
windows=['7D', '14D', '-7D', '-14D']
)
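For each window, each sliding point receives the aggregate of all sliding
buckets in its group that fall within that window relative to it; positive
windows look forward in time, negative windows look back.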
"""
self._validate_aggs_argument(aggs)
self._validate_string_list_argument(group_by, "group_by")
self._validate_formatter_argument(col_formatter)

if not windows:
raise ValueError("windows must not be empty")

if not sliding_interval:
raise ValueError("sliding_interval must not be empty")

if not time_col or not isinstance(time_col, str):
raise ValueError("time_col must be a string")

# check for valid time_col names

slide_duration, slide_unit = self._validate_and_extract_time_unit(
sliding_interval, "sliding_interval", allow_negative=False
)
sliding_point_col = "sliding_point"

agg_df = self._df
agg_df = agg_df.withColumn(
sliding_point_col,
self._get_sliding_interval_start(time_col, slide_unit, slide_duration),
)
agg_exprs = []
for column, functions in aggs.items():
for function in functions:
agg_exprs.append((column, function))

# Perform the aggregation
sliding_windows_df = agg_df.groupBy(group_by + [sliding_point_col]).agg(
agg_exprs
)

for window in windows:
window_duration, window_unit = self._validate_and_extract_time_unit(
window, "window"
)

# Perform self-join on DataFrame for aggregation within each group and time window.
self_joined_df = sliding_windows_df.alias("A").join(
sliding_windows_df.alias("B"), on=group_by, how="leftouter"
)

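# The far edge of the window: the sliding point shifted by the window
# duration (forward for positive durations, backward for negative ones).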
window_frame = dateadd(
window_unit, lit(window_duration), f"{sliding_point_col}A"
)

if window_duration > 0: # Future window
window_start = col(f"{sliding_point_col}A")
window_end = window_frame
else: # Past window
window_start = window_frame
window_end = col(f"{sliding_point_col}A")

# Filter rows to include only those within the specified time window for aggregation.
self_joined_df = self_joined_df.filter(
col(f"{sliding_point_col}B") >= window_start
).filter(col(f"{sliding_point_col}B") <= window_end)

# Perform aggregations as specified in 'aggs'.
for agg_col, funcs in aggs.items():
for func in funcs:
output_column_name = col_formatter(agg_col, func, window)
input_column_name = f"{agg_col}_{func}"

agg_column_df = self_joined_df.withColumnRenamed(
f"{func}({agg_col})B", input_column_name
)

agg_expr = expr(f"{func}({input_column_name})").alias(
output_column_name
)
agg_column_df = agg_column_df.groupBy(
group_by + [f"{sliding_point_col}A"]
).agg(agg_expr)

agg_column_df = agg_column_df.withColumnRenamed(
f"{sliding_point_col}A", time_col
)

agg_df = agg_df.join(
agg_column_df, on=group_by + [time_col], how="left"
)
return agg_df
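
For reference, a minimal end-to-end sketch of how the new API is meant to be called (assuming an existing Session named `session` and a SALES_DATA table with PRODUCTKEY, ORDERDATE, and SALESAMOUNT columns; the table and data are illustrative, not part of this PR):

from snowflake.snowpark import Session

# session = Session.builder.configs(connection_parameters).create()
df = session.table("SALES_DATA")

# Moving aggregation: trailing SUM and AVG of SALESAMOUNT over 2- and 7-row
# windows, computed per product in order of ORDERDATE.
moving_df = df.transform.moving_agg(
    aggs={"SALESAMOUNT": ["SUM", "AVG"]},
    window_sizes=[2, 7],
    order_by=["ORDERDATE"],
    group_by=["PRODUCTKEY"],
)

# Time series aggregation: SUM of SALESAMOUNT over the past 7 days,
# with sliding points computed every 1 day, per product.
ts_df = df.transform.time_series_agg(
    time_col="ORDERDATE",
    aggs={"SALESAMOUNT": ["SUM"]},
    windows=["-7D"],
    sliding_interval="1D",
    group_by=["PRODUCTKEY"],
)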