Merge branch 'main' into main
frederiksteiner authored Feb 5, 2024
2 parents d4c6115 + 24e6d79 commit 898b098
Showing 7 changed files with 580 additions and 15 deletions.
8 changes: 5 additions & 3 deletions CHANGELOG.md
@@ -2,8 +2,7 @@

## 1.13.0 (TBD)

-### Behavior Changes (API Compatible)
-
### New Features
+- Added support for an optional `date_part` argument in function `last_day`

## 1.12.1 (TBD)
@@ -13,6 +12,8 @@

### Bug Fixes

- Fixed a bug in `DataFrame.to_pandas` that caused an error when evaluating on a dataframe with an IntegerType column with null values.
+- Fixed a bug in `DataFrame.to_local_iterator` where the iterator could yield incorrect results if another query was executed before the iterator finished, due to a wrong isolation level. For details, please see #945.
+- Fixed a bug that truncated table names in error messages while running a plan with local testing enabled.

## 1.12.0 (2024-01-30)

@@ -33,7 +34,8 @@
- Added the following functions to `DataFrame.analytics`:
  - Added the `moving_agg` function in `DataFrame.analytics` to enable moving aggregations like sums and averages with multiple window sizes.
  - Added the `cummulative_agg` function in `DataFrame.analytics` to enable cumulative aggregations like sums and averages on multiple columns.
-  - Added the `compute_lag` and `compute_lead` function in `DataFrame.analytics` for enabling lead and lag calculations on multiple columns.
+  - Added the `compute_lag` and `compute_lead` functions in `DataFrame.analytics` for enabling lead and lag calculations on multiple columns.
+  - Added the `time_series_agg` function in `DataFrame.analytics` to enable time series aggregations like sums and averages with multiple time windows.

### Bug Fixes

10 changes: 9 additions & 1 deletion src/snowflake/snowpark/_internal/server_connection.py
@@ -434,8 +434,16 @@ def _to_data_or_iter(
results_cursor: SnowflakeCursor,
to_pandas: bool = False,
to_iter: bool = False,
num_statements: Optional[int] = None,
) -> Dict[str, Any]:
        if (
            to_iter and not to_pandas
        ):  # Fix for SNOW-869536: to_pandas doesn't have this issue; SnowflakeCursor.fetch_pandas_batches already handles the isolation.
new_cursor = results_cursor.connection.cursor()
new_cursor.execute(
f"SELECT * FROM TABLE(RESULT_SCAN('{results_cursor.sfqid}'))"
)
results_cursor = new_cursor

if to_pandas:
try:
data_or_iter = (
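The fix above re-runs the finished query through RESULT_SCAN on a fresh cursor, so rows streamed by `to_local_iterator` cannot be disturbed by later queries on the shared connection. A rough repro sketch of the scenario this guards against (the session and `df1`/`df2` are hypothetical):

    # Before this fix, interleaving a second query could corrupt the iterator.
    it = df1.to_local_iterator()   # streams results lazily from the cursor
    _ = df2.collect()              # second query on the same session/connection
    rows = list(it)                # could previously yield wrong results (#945)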
290 changes: 286 additions & 4 deletions src/snowflake/snowpark/dataframe_analytics_functions.py
Expand Up @@ -2,14 +2,30 @@
# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved.
#

-from typing import Callable, Dict, List, Union
+from typing import Callable, Dict, List, Tuple, Union

import snowflake.snowpark
-from snowflake.snowpark import Column
-from snowflake.snowpark.column import _to_col_if_str
-from snowflake.snowpark.functions import expr, lag, lead
+from snowflake.snowpark._internal.utils import experimental
+from snowflake.snowpark.column import Column, _to_col_if_str
+from snowflake.snowpark.functions import (
+    add_months,
+    col,
+    dateadd,
+    expr,
+    from_unixtime,
+    lag,
+    lead,
+    lit,
+    months_between,
+    to_timestamp,
+    unix_timestamp,
+    year,
+)
from snowflake.snowpark.window import Window

# "s" (seconds), "m" (minutes), "h" (hours), "d" (days), "w" (weeks), "mm" (months), "y" (years)
SUPPORTED_TIME_UNITS = ["s", "m", "h", "d", "w", "mm", "y"]


class DataFrameAnalyticsFunctions:
"""Provides data analytics functions for DataFrames.
@@ -123,6 +139,131 @@ def _compute_window_function(

return df.with_columns(col_names, values)

def _parse_time_string(self, time_str: str) -> Tuple[int, str]:
index = len(time_str)
for i, char in enumerate(time_str):
if not char.isdigit() and char not in ["+", "-"]:
index = i
break

duration = int(time_str[:index])
unit = time_str[index:].lower()

return duration, unit
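
    # Illustrative behavior of the parser above (hypothetical calls):
    #   _parse_time_string("7d")   -> (7, "d")
    #   _parse_time_string("-12H") -> (-12, "h")   # unit is lowercased
    #   _parse_time_string("3mm")  -> (3, "mm")    # "mm" denotes months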

def _validate_and_extract_time_unit(
self, time_str: str, argument_name: str, allow_negative: bool = True
) -> Tuple[int, str]:
argument_requirements = (
f"The '{argument_name}' argument must adhere to the following criteria: "
"1) It must not be an empty string. "
"2) The last character must be a supported time unit. "
f"Supported units are '{', '.join(SUPPORTED_TIME_UNITS)}'. "
"3) The preceding characters must represent an integer. "
"4) The integer must not be negative if allow_negative is False."
)
if not time_str:
raise ValueError(
f"{argument_name} must not be empty. {argument_requirements}"
)

duration, unit = self._parse_time_string(time_str)

if not allow_negative and duration < 0:
raise ValueError(
f"{argument_name} must not be negative. {argument_requirements}"
)

        if unit not in SUPPORTED_TIME_UNITS:
            raise ValueError(
                f"Unsupported unit '{unit}'. Supported units are '{', '.join(SUPPORTED_TIME_UNITS)}'. {argument_requirements}"
            )

return duration, unit
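
    # For example (hypothetical calls): a negative duration is fine for windows
    # but rejected for the slide:
    #   _validate_and_extract_time_unit("-7d", "window")              -> (-7, "d")
    #   _validate_and_extract_time_unit("-7d", "sliding_interval",
    #                                   allow_negative=False)         -> ValueError
    #   _validate_and_extract_time_unit("7q", "window")               -> ValueError ('q' unsupported)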

def _get_sliding_interval_start(
self, time_col: Column, unit: str, duration: int
) -> Column:
unit_seconds = {
"s": 1, # seconds
"m": 60, # minutes
"h": 3600, # hours
"d": 86400, # days
"w": 604800, # weeks
}

if unit == "mm":
base_date = lit("1970-01-01").cast("date")
months_since_base = months_between(time_col, base_date)
current_window_start_month = (months_since_base / duration).cast(
"long"
) * duration
return to_timestamp(add_months(base_date, current_window_start_month))

elif unit == "y":
base_date = lit("1970-01-01").cast("date")
years_since_base = year(time_col) - year(base_date)
current_window_start_year = (years_since_base / duration).cast(
"long"
) * duration
return to_timestamp(add_months(base_date, current_window_start_year * 12))

elif unit in unit_seconds:
# Handle seconds, minutes, hours, days, weeks
interval_seconds = unit_seconds[unit] * duration
return from_unixtime(
(unix_timestamp(time_col) / interval_seconds).cast("long")
* interval_seconds
)

else:
raise ValueError(
"Invalid unit. Supported units are 'S', 'M', 'H', 'D', 'W', 'MM', 'Y'."
)
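
    # Worked example (illustrative): unit="h", duration=12 buckets timestamps
    # into 12-hour windows anchored at the Unix epoch:
    #   interval_seconds = 3600 * 12            # 43200
    #   ts = 1672597800                         # 2023-01-01 18:30:00 UTC
    #   (ts // 43200) * 43200 == 1672574400     # window start 2023-01-01 12:00:00 UTC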

def _perform_window_aggregations(
self,
base_df: "snowflake.snowpark.dataframe.DataFrame",
input_df: "snowflake.snowpark.dataframe.DataFrame",
aggs: Dict[str, List[str]],
group_by_cols: List[str],
col_formatter: Callable[[str, str, str], str] = None,
window: str = None,
rename_suffix: str = "",
) -> "snowflake.snowpark.dataframe.DataFrame":
"""
Perform window-based aggregations on the given DataFrame.
This function applies specified aggregation functions to columns of an input DataFrame,
grouped by specified columns, and joins the results back to a base DataFrame.
Parameters:
- base_df: DataFrame to which the aggregated results will be joined.
- input_df: DataFrame on which aggregations are to be performed.
- aggs: A dictionary where keys are column names and values are lists of aggregation functions.
- group_by_cols: List of column names to group by.
- col_formatter: Optional callable to format column names of aggregated results.
- window: Optional window specification for aggregations.
- rename_suffix: Optional suffix to append to column names.
Returns:
- DataFrame with the aggregated data joined to the base DataFrame.
"""
for column, funcs in aggs.items():
for func in funcs:
agg_column_name = (
col_formatter(column, func, window)
if col_formatter
else f"{column}_{func}{rename_suffix}"
)
agg_expression = expr(f"{func}({column}{rename_suffix})").alias(
agg_column_name
)
agg_df = input_df.group_by(group_by_cols).agg(agg_expression)
base_df = base_df.join(agg_df, on=group_by_cols, how="left")

return base_df
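
    # Sketch of the expansion above (hypothetical inputs): with
    #   aggs = {"SALESAMOUNT": ["SUM", "MAX"]} and rename_suffix = "B",
    # two grouped aggregates are computed and left-joined onto base_df, aliased
    # SALESAMOUNT_SUMB and SALESAMOUNT_MAXB (or whatever col_formatter returns).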

def moving_agg(
self,
aggs: Dict[str, List[str]],
@@ -394,3 +535,144 @@ def compute_lead(
return self._compute_window_function(
cols, leads, order_by, group_by, col_formatter, lead, "LEAD"
)

@experimental(version="1.12.0")
def time_series_agg(
self,
time_col: str,
aggs: Dict[str, List[str]],
windows: List[str],
group_by: List[str],
sliding_interval: str,
col_formatter: Callable[[str, str, int], str] = _default_col_formatter,
) -> "snowflake.snowpark.dataframe.DataFrame":
"""
        Applies aggregations to the specified columns of the DataFrame over the specified
        time windows, partitioned by the specified grouping criteria.
Args:
            time_col: The name of the timestamp column over which the time windows and sliding intervals are computed.
            aggs: A dictionary where keys are column names and values are lists of the desired aggregation functions.
windows: Time windows for aggregations using strings such as '7D' for 7 days, where the units are
S: Seconds, M: Minutes, H: Hours, D: Days, W: Weeks, MM: Months, Y: Years. For future-oriented analysis, use positive numbers,
and for past-oriented analysis, use negative numbers.
sliding_interval: Interval at which the window slides, specified in the same format as the windows.
group_by: A list of column names on which the DataFrame is partitioned for separate window calculations.
col_formatter: An optional function for formatting output column names, defaulting to the format '<input_col>_<agg>_<window>'.
This function takes three arguments: 'input_col' (str) for the column name, 'operation' (str) for the applied operation,
and 'value' (int) for the window size, and returns a formatted string for the column name.
Returns:
A Snowflake DataFrame with additional columns corresponding to each specified time window aggregation.
Raises:
ValueError: If an unsupported value is specified in arguments.
TypeError: If an unsupported type is specified in arguments.
            SnowparkSQLException: If an unsupported aggregation is specified.
Example:
>>> sample_data = [
... ["2023-01-01", 101, 200],
... ["2023-01-02", 101, 100],
... ["2023-01-03", 101, 300],
... ["2023-01-04", 102, 250],
... ]
>>> df = session.create_dataframe(sample_data).to_df(
... "ORDERDATE", "PRODUCTKEY", "SALESAMOUNT"
... )
>>> df = df.with_column("ORDERDATE", to_timestamp(df["ORDERDATE"]))
>>> def custom_formatter(input_col, agg, window):
... return f"{agg}_{input_col}_{window}"
>>> res = df.analytics.time_series_agg(
... time_col="ORDERDATE",
... group_by=["PRODUCTKEY"],
... aggs={"SALESAMOUNT": ["SUM", "MAX"]},
... windows=["1D", "-1D"],
... sliding_interval="12H",
... col_formatter=custom_formatter,
... )
>>> res.show()
--------------------------------------------------------------------------------------------------------------------------------------------------------------------------
|"PRODUCTKEY" |"SLIDING_POINT" |"SALESAMOUNT" |"ORDERDATE" |"SUM_SALESAMOUNT_1D" |"MAX_SALESAMOUNT_1D" |"SUM_SALESAMOUNT_-1D" |"MAX_SALESAMOUNT_-1D" |
--------------------------------------------------------------------------------------------------------------------------------------------------------------------------
|101 |2023-01-01 00:00:00 |200 |2023-01-01 00:00:00 |300 |200 |200 |200 |
|101 |2023-01-02 00:00:00 |100 |2023-01-02 00:00:00 |400 |300 |300 |200 |
|101 |2023-01-03 00:00:00 |300 |2023-01-03 00:00:00 |300 |300 |400 |300 |
|102 |2023-01-04 00:00:00 |250 |2023-01-04 00:00:00 |250 |250 |250 |250 |
--------------------------------------------------------------------------------------------------------------------------------------------------------------------------
<BLANKLINE>
"""
self._validate_aggs_argument(aggs)
self._validate_string_list_argument(group_by, "group_by")
self._validate_formatter_argument(col_formatter)

if not windows:
raise ValueError("windows must not be empty")

if not sliding_interval:
raise ValueError("sliding_interval must not be empty")

if not time_col or not isinstance(time_col, str):
raise ValueError("time_col must be a string")

slide_duration, slide_unit = self._validate_and_extract_time_unit(
sliding_interval, "sliding_interval", allow_negative=False
)
sliding_point_col = "sliding_point"

agg_df = self._df
agg_df = agg_df.with_column(
sliding_point_col,
self._get_sliding_interval_start(time_col, slide_unit, slide_duration),
)

# Perform aggregations at sliding interval granularity.
group_by_cols = group_by + [sliding_point_col]
sliding_windows_df = self._perform_window_aggregations(
agg_df, agg_df, aggs, group_by_cols
)

# Perform aggregations at window intervals.
result_df = agg_df
for window in windows:
window_duration, window_unit = self._validate_and_extract_time_unit(
window, "window"
)
# Perform self-join on DataFrame for aggregation within each group and time window.
left_df = sliding_windows_df.alias("A")
right_df = sliding_windows_df.alias("B")

for column in right_df.columns:
if column not in group_by:
right_df = right_df.with_column_renamed(column, f"{column}B")

self_joined_df = left_df.join(right_df, on=group_by, how="leftouter")

window_frame = dateadd(
window_unit, lit(window_duration), f"{sliding_point_col}"
)

if window_duration > 0: # Future window
window_start = col(f"{sliding_point_col}")
window_end = window_frame
else: # Past window
window_start = window_frame
window_end = col(f"{sliding_point_col}")

# Filter rows to include only those within the specified time window for aggregation.
self_joined_df = self_joined_df.filter(
col(f"{sliding_point_col}B") >= window_start
).filter(col(f"{sliding_point_col}B") <= window_end)

            # Perform final aggregations.
group_by_cols = group_by + [sliding_point_col]
result_df = self._perform_window_aggregations(
result_df,
self_joined_df,
aggs,
group_by_cols,
col_formatter,
window,
rename_suffix="B",
)

return result_df
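
    # Hypothetical walk-through: with sliding_interval="12H" and window="-1D",
    # every sliding point P is paired (via the self-join) with the sliding
    # points PB of the same group satisfying P - 1 day <= PB <= P, and the
    # requested aggregations run over that one-day look-back span.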
10 changes: 7 additions & 3 deletions src/snowflake/snowpark/functions.py
@@ -3244,7 +3244,7 @@ def hour(e: ColumnOrName) -> Column:
return builtin("hour")(c)


-def last_day(expr: ColumnOrName, part: Optional[ColumnOrName] = "MONTH") -> Column:
+def last_day(expr: ColumnOrName, part: Optional[ColumnOrName] = None) -> Column:
"""
Returns the last day of the specified date part for a date or timestamp.
Commonly used to return the last day of the month for a date or timestamp.
@@ -3262,12 +3262,16 @@ def last_day(expr: ColumnOrName, part: Optional[ColumnOrName] = "MONTH") -> Column:
... datetime.datetime.strptime("2020-08-21 01:30:05.000", "%Y-%m-%d %H:%M:%S.%f")
... ], schema=["a"])
>>> df.select(last_day("a")).collect()
[Row(LAST_DAY("A", "MONTH")=datetime.date(2020, 5, 31)), Row(LAST_DAY("A", "MONTH")=datetime.date(2020, 8, 31))]
[Row(LAST_DAY("A")=datetime.date(2020, 5, 31)), Row(LAST_DAY("A")=datetime.date(2020, 8, 31))]
>>> df.select(last_day("a", "YEAR")).collect()
[Row(LAST_DAY("A", "YEAR")=datetime.date(2020, 12, 31)), Row(LAST_DAY("A", "YEAR")=datetime.date(2020, 12, 31))]
"""
-    part_col = _to_col_if_str(part, "last_day")
    expr_col = _to_col_if_str(expr, "last_day")
+    if part is None:
+        # Ensure we do not change the column name
+        return builtin("last_day")(expr_col)
+
+    part_col = _to_col_if_str(part, "last_day")
return builtin("last_day")(expr_col, part_col)


3 changes: 2 additions & 1 deletion src/snowflake/snowpark/mock/_plan.py
@@ -553,8 +553,9 @@ def execute_mock_plan(
return res_df
else:
            db_schme_table = parse_table_name(entity_name)
+            table = ".".join([part.strip("\"'") for part in db_schme_table[:3]])
            raise SnowparkSQLException(
-                f"Object '{db_schme_table[0][1:-1]}.{db_schme_table[1][1:-1]}.{db_schme_table[2][1:-1]}' does not exist or not authorized."
+                f"Object '{table}' does not exist or not authorized."
)
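
            # Illustrative (hypothetical input): for entity_name '"DB"."SCH"."T"',
            # parse_table_name returns ['"DB"', '"SCH"', '"T"'] and the join above
            # yields 'DB.SCH.T'; the old [1:-1] slicing truncated unquoted names
            # (e.g. 'mytable' -> 'ytabl') instead of just stripping quotes.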
if isinstance(source_plan, Aggregate):
child_rf = execute_mock_plan(source_plan.child)
