Skip to content

Commit

Permalink
SNOW-1830534: Add decoder logic for the dataframe analytics functions (
Browse files Browse the repository at this point in the history
…#2803)

1. Which Jira issue is this PR addressing? Make sure that there is an
accompanying issue to your PR.

   Fixes SNOW-1830534

2. Fill out the following pre-review checklist:

- [ ] I am adding a new automated test(s) to verify correctness of my
new code
- [ ] If this test skips Local Testing mode, I'm requesting review from
@snowflakedb/local-testing
   - [ ] I am adding new logging messages
   - [ ] I am adding a new telemetry message
   - [ ] I am adding new credentials
   - [ ] I am adding a new dependency
- [ ] If this is a new feature/behavior, I'm adding the Local Testing
parity changes.
- [x] I acknowledge that I have ensured my changes to be thread-safe.
Follow the link for more information: [Thread-safe Developer
Guidelines](https://github.com/snowflakedb/snowpark-python/blob/main/CONTRIBUTING.md#thread-safe-development)

3. Please describe how your code solves the related issue.

Added decoder logic for the DataFrame analytics functions. I "hardcoded"
the column-formatter lambdas that the decoder reconstructs: they are not
needed for actual decoding, but they allow the round-trip tests to pass.
It should not be hard to refactor them out or remove them later.
  • Loading branch information
sfc-gh-vbudati authored Dec 21, 2024
1 parent ea58fb7 commit d31ec43
Showing 1 changed file with 139 additions and 5 deletions.
144 changes: 139 additions & 5 deletions tests/ast/decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,16 @@
#

import logging
from typing import Any, Optional, Iterable, List, Union, Dict, Tuple
import re
from typing import Any, Optional, Iterable, List, Union, Dict, Tuple, Callable
from datetime import date, datetime, time, timedelta, timezone
from decimal import Decimal

import snowflake.snowpark._internal.proto.generated.ast_pb2 as proto

from google.protobuf.json_format import MessageToDict

from snowflake.snowpark import Session, Column
from snowflake.snowpark import Session, Column, DataFrameAnalyticsFunctions
import snowflake.snowpark.functions
from snowflake.snowpark.functions import udf, when
from snowflake.snowpark.types import (
Expand Down Expand Up @@ -73,6 +74,38 @@ def capture_local_variable_name(self, assign_expr: proto.Assign) -> str:
"""
return assign_expr.symbol.value

def get_dataframe_analytics_function_column_formatter(
    self, sp_dataframe_analytics_expr: "proto.Expr"
) -> Callable:
    """
    Create a dataframe analytics function column formatter.

    This is mainly to pass the df_analytics_functions test: the original
    ``col_formatter`` lambda cannot be serialized into the AST, so it is
    reverse-engineered here from the already-formatted column names that
    were recorded in the message.

    Parameters
    ----------
    sp_dataframe_analytics_expr : proto.Expr
        The dataframe analytics expression.

    Returns
    -------
    Callable
        The dataframe analytics function column formatter.
    """
    # A proto3 repeated field is falsy when empty, and the JSON mapping
    # (what MessageToDict produces) omits empty repeated fields entirely —
    # so this truthiness check is equivalent to the previous
    # `"formattedColNames" in MessageToDict(...)` test, minus the
    # full message-to-dict conversion cost.
    formatted_col_names = list(sp_dataframe_analytics_expr.formatted_col_names)
    if not formatted_col_names:
        return DataFrameAnalyticsFunctions._default_col_formatter

    # The two name shapes produced by the test suite's custom formatters.
    # `input_col` is used in the lambdas instead of `input` to avoid
    # shadowing the builtin.
    xy_lambda_pattern = re.compile(r"^(\w+)_X_(\w+)_Y_(\w+)$")
    w_lambda_pattern = re.compile(r"^(\w+)_W_(\w+)$")

    if all(xy_lambda_pattern.match(col) for col in formatted_col_names):
        return (
            lambda input_col, agg, window_size: f"{input_col}_X_{agg}_Y_{window_size}"
        )
    if all(w_lambda_pattern.match(col) for col in formatted_col_names):
        return lambda input_col, agg: f"{input_col}_W_{agg}"
    # Fallback: the snowpark default ordering ``<agg>_<col>_<window>``.
    return lambda input_col, agg, window: f"{agg}_{input_col}_{window}"

def decode_col_exprs(self, expr: proto.Expr, is_variadic: bool) -> List[Column]:
"""
Decode a protobuf object to a list of column expressions.
Expand Down Expand Up @@ -286,7 +319,7 @@ def decode_dataframe_schema_expr(
case _:
raise ValueError(
"Unknown dataframe schema type: %s"
% df_schema_expr.WhichOneof("variant")
% df_schema_expr.WhichOneof("sealed_value")
)

def decode_data_type_expr(
Expand Down Expand Up @@ -856,8 +889,11 @@ def decode_expr(self, expr: proto.Expr) -> Any:
# DATAFRAME FUNCTIONS
case "sp_create_dataframe":
data = self.decode_dataframe_data_expr(expr.sp_create_dataframe.data)
schema = self.decode_dataframe_schema_expr(
expr.sp_create_dataframe.schema
d = MessageToDict(expr.sp_create_dataframe)
schema = (
self.decode_dataframe_schema_expr(expr.sp_create_dataframe.schema)
if "schema" in d
else None
)
df = self.session.create_dataframe(data=data, schema=schema)
if hasattr(expr, "var_id"):
Expand All @@ -882,6 +918,96 @@ def decode_expr(self, expr: proto.Expr) -> Any:
name = expr.sp_dataframe_alias.name
return df.alias(name)

case "sp_dataframe_analytics_compute_lag":
df = self.decode_expr(expr.sp_dataframe_analytics_compute_lag.df)
cols = [
self.decode_expr(col)
for col in expr.sp_dataframe_analytics_compute_lag.cols
]
group_by = list(expr.sp_dataframe_analytics_compute_lag.group_by)
lags = list(expr.sp_dataframe_analytics_compute_lag.lags)
order_by = list(expr.sp_dataframe_analytics_compute_lag.order_by)
col_formatter = self.get_dataframe_analytics_function_column_formatter(
expr.sp_dataframe_analytics_compute_lag
)
return df.analytics.compute_lag(
cols, lags, order_by, group_by, col_formatter
)

case "sp_dataframe_analytics_compute_lead":
df = self.decode_expr(expr.sp_dataframe_analytics_compute_lead.df)
cols = [
self.decode_expr(col)
for col in expr.sp_dataframe_analytics_compute_lead.cols
]
group_by = list(expr.sp_dataframe_analytics_compute_lead.group_by)
leads = list(expr.sp_dataframe_analytics_compute_lead.leads)
order_by = list(expr.sp_dataframe_analytics_compute_lead.order_by)
col_formatter = self.get_dataframe_analytics_function_column_formatter(
expr.sp_dataframe_analytics_compute_lead
)
return df.analytics.compute_lead(
cols, leads, order_by, group_by, col_formatter
)

case "sp_dataframe_analytics_cumulative_agg":
df = self.decode_expr(expr.sp_dataframe_analytics_cumulative_agg.df)
gen_aggs = self.decode_dsl_map_expr(
expr.sp_dataframe_analytics_cumulative_agg.aggs
)
# The aggs dict created has generator objects as the kv pairs. Convert them to strings/list of strings.
aggs = {str(k): list(v) for k, v in gen_aggs.items()}
group_by = list(expr.sp_dataframe_analytics_cumulative_agg.group_by)
order_by = list(expr.sp_dataframe_analytics_cumulative_agg.order_by)
is_forward = (
expr.sp_dataframe_analytics_cumulative_agg.is_forward
if hasattr(expr.sp_dataframe_analytics_cumulative_agg, "is_forward")
else False
)
col_formatter = self.get_dataframe_analytics_function_column_formatter(
expr.sp_dataframe_analytics_cumulative_agg
)
return df.analytics.cumulative_agg(
aggs, group_by, order_by, is_forward, col_formatter
)

case "sp_dataframe_analytics_moving_agg":
df = self.decode_expr(expr.sp_dataframe_analytics_moving_agg.df)
gen_aggs = self.decode_dsl_map_expr(
expr.sp_dataframe_analytics_moving_agg.aggs
)
# The aggs dict created has generator objects as the kv pairs. Convert them to strings/list of strings.
aggs = {str(k): list(v) for k, v in gen_aggs.items()}
group_by = list(expr.sp_dataframe_analytics_moving_agg.group_by)
order_by = list(expr.sp_dataframe_analytics_moving_agg.order_by)
window_sizes = list(expr.sp_dataframe_analytics_moving_agg.window_sizes)
col_formatter = self.get_dataframe_analytics_function_column_formatter(
expr.sp_dataframe_analytics_moving_agg
)
return df.analytics.moving_agg(
aggs, window_sizes, order_by, group_by, col_formatter
)

case "sp_dataframe_analytics_time_series_agg":
df = self.decode_expr(expr.sp_dataframe_analytics_time_series_agg.df)
gen_aggs = self.decode_dsl_map_expr(
expr.sp_dataframe_analytics_time_series_agg.aggs
)
# The aggs dict created has generator objects as the kv pairs. Convert them to strings/list of strings.
aggs = {str(k): list(v) for k, v in gen_aggs.items()}
group_by = list(expr.sp_dataframe_analytics_time_series_agg.group_by)
sliding_interval = (
expr.sp_dataframe_analytics_time_series_agg.sliding_interval
)
time_col = expr.sp_dataframe_analytics_time_series_agg.time_col
windows = list(expr.sp_dataframe_analytics_time_series_agg.windows)
col_formatter = self.get_dataframe_analytics_function_column_formatter(
expr.sp_dataframe_analytics_time_series_agg
)
return df.analytics.time_series_agg(
time_col, aggs, windows, group_by, sliding_interval, col_formatter
)

case "sp_dataframe_col":
col_name = expr.sp_dataframe_col.col_name
df = self.decode_expr(expr.sp_dataframe_col.df)
Expand Down Expand Up @@ -1072,6 +1198,14 @@ def decode_expr(self, expr: proto.Expr) -> Any:
else:
return df.sort(cols, ascending)

case "sp_dataframe_to_df":
df = self.decode_expr(expr.sp_dataframe_to_df.df)
col_names = list(expr.sp_dataframe_to_df.col_names)
if expr.sp_dataframe_to_df.variadic:
return df.to_df(*col_names)
else:
return df.to_df(col_names)

case "sp_dataframe_unpivot":
df = self.decode_expr(expr.sp_dataframe_unpivot.df)
column_list = [
Expand Down

0 comments on commit d31ec43

Please sign in to comment.