Skip to content

Commit

Permalink
SNOW-1830517, SNOW-1830516 Add Decoder functionality for more DataFrame functions (#2806)
Browse files Browse the repository at this point in the history

1. Which Jira issue is this PR addressing? Make sure that there is an
accompanying issue to your PR.

   Fixes SNOW-1830517, SNOW-1830516

2. Fill out the following pre-review checklist:

- [ ] I am adding new automated test(s) to verify the correctness of my
new code
- [ ] If this test skips Local Testing mode, I'm requesting review from
@snowflakedb/local-testing
   - [ ] I am adding new logging messages
   - [ ] I am adding a new telemetry message
   - [ ] I am adding new credentials
   - [ ] I am adding a new dependency
- [ ] If this is a new feature/behavior, I'm adding the Local Testing
parity changes.
- [x] I acknowledge that I have ensured my changes to be thread-safe.
Follow the link for more information: [Thread-safe Developer
Guidelines](https://github.com/snowflakedb/snowpark-python/blob/main/CONTRIBUTING.md#thread-safe-development)

3. Please describe how your code solves the related issue.
Fixed the way columns were being decoded, and added decoder functionality
for the DataFrame methods cube, describe, distinct, drop_duplicates,
filter, and __getitem__.
  • Loading branch information
sfc-gh-vbudati authored Dec 21, 2024
1 parent d31ec43 commit ca608fb
Show file tree
Hide file tree
Showing 3 changed files with 89 additions and 38 deletions.
12 changes: 10 additions & 2 deletions tests/ast/data/RelationalGroupedDataFrame.test
Original file line number Diff line number Diff line change
Expand Up @@ -47,13 +47,13 @@ df = session.create_dataframe([("SF", 21.0), ("SF", 17.5), ("SF", 24.0), ("NY",

res10 = udtf("_ApplyInPandas", output_schema=PandasDataFrameType(StringType(), FloatType(), FloatType(), "LOCATION", "TEMP_C", "TEMP_F"), input_types=[StringType(), FloatType()], copy_grants=False, _registered_object_name="\"MOCK_DATABASE\".\"MOCK_SCHEMA\".SNOWPARK_TEMP_TABLE_FUNCTION_xxx")

df.group_by("location").apply_in_pandas(convert, StructType([StructField("location", StringType(), nullable=True), StructField("temp_c", FloatType(), nullable=True), StructField("temp_f", FloatType(), nullable=True)], structured=False), input_types=[StringType(), FloatType()], input_names=["LOCATION", "TEMP_C"]).sort("temp_c").collect()
df.group_by("location").apply_in_pandas(convert, StructType([StructField("location", StringType(), nullable=True), StructField("temp_c", FloatType(), nullable=True), StructField("temp_f", FloatType(), nullable=True)], structured=False), input_types=[StringType(), FloatType()], input_names=["LOCATION", "TEMP_C"]).sort("temp_c", ascending=None).collect()

df = session.create_dataframe([(1, "A", 10000, "JAN"), (1, "B", 400, "JAN"), (1, "B", 5000, "FEB")], schema=["empid", "team", "amount", "month"])

df.group_by("empid").pivot("month", values=["JAN", "FEB"]).sum("amount").show()

df.group_by(["empid", "team"]).pivot("month").sum("amount").sort("empid", "team").show()
df.group_by(["empid", "team"]).pivot("month").sum("amount").sort("empid", "team", ascending=None).show()

## EXPECTED ENCODED AST

Expand Down Expand Up @@ -917,6 +917,10 @@ body {
sp_dataframe_sort {
ascending {
null_val {
src {
file: "SRC_POSITION_TEST_MODE"
start_line: 46
}
}
}
cols {
Expand Down Expand Up @@ -1431,6 +1435,10 @@ body {
sp_dataframe_sort {
ascending {
null_val {
src {
file: "SRC_POSITION_TEST_MODE"
start_line: 56
}
}
}
cols {
Expand Down
10 changes: 9 additions & 1 deletion tests/ast/data/df_sort.test
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ df = df.sort(col("A"), col("B"), ascending=[0, 1])

df = df.sort(col("A"), col("B"), col("C"), ascending=[0, True, 1])

df = df.sort(col("B"))
df = df.sort(col("B"), ascending=None)

## EXPECTED ENCODED AST

Expand Down Expand Up @@ -485,6 +485,14 @@ body {
assign {
expr {
sp_dataframe_sort {
ascending {
null_val {
src {
file: "SRC_POSITION_TEST_MODE"
start_line: 37
}
}
}
cols {
apply_expr {
fn {
Expand Down
105 changes: 70 additions & 35 deletions tests/ast/decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,16 +106,14 @@ def get_dataframe_analytics_function_column_formatter(
else:
return DataFrameAnalyticsFunctions._default_col_formatter

def decode_col_exprs(self, expr: proto.Expr, is_variadic: bool) -> List[Column]:
def decode_col_exprs(self, expr: proto.Expr) -> List[Column]:
"""
Decode a protobuf object to a list of column expressions.
Parameters
----------
expr : proto.Expr
The protobuf object to decode.
is_variadic : bool
Whether the expression is variadic.
Returns
-------
Expand All @@ -126,7 +124,7 @@ def decode_col_exprs(self, expr: proto.Expr, is_variadic: bool) -> List[Column]:
# Prevent nesting the list in a list if there is only one expression.
# This usually happens when the expression is a list_val.
col_list = self.decode_expr(expr[0])
if not isinstance(col_list, list) and not is_variadic:
if not isinstance(col_list, list):
col_list = [col_list]
else:
col_list = [self.decode_expr(arg) for arg in expr]
Expand Down Expand Up @@ -1045,22 +1043,58 @@ def decode_expr(self, expr: proto.Expr) -> Any:
block=block,
)

case "sp_dataframe_cube":
df = self.decode_expr(expr.sp_dataframe_cube.df)
d = MessageToDict(expr.sp_dataframe_cube.cols)
if "args" not in d:
return df.cube()
cols = self.decode_col_exprs(expr.sp_dataframe_cube.cols.args)
if d.get("variadic", False):
return df.cube(*cols)
else:
return df.cube(cols)

case "sp_dataframe_describe":
df = self.decode_expr(expr.sp_dataframe_describe.df)
d = MessageToDict(expr.sp_dataframe_describe.cols)
if "args" not in d:
return df.describe()
cols = self.decode_col_exprs(expr.sp_dataframe_describe.cols.args)
if d.get("variadic", False):
return df.describe(*cols)
else:
return df.describe(cols)

case "sp_dataframe_distinct":
df = self.decode_expr(expr.sp_dataframe_distinct.df)
return df.distinct()

case "sp_dataframe_drop":
df = self.decode_expr(expr.sp_dataframe_drop.df)
cols = self.decode_col_exprs(
expr.sp_dataframe_drop.cols.args,
expr.sp_dataframe_drop.cols.variadic,
)
if expr.sp_dataframe_group_by.cols.variadic:
cols = self.decode_col_exprs(expr.sp_dataframe_drop.cols.args)
if MessageToDict(expr.sp_dataframe_drop.cols).get("variadic", False):
return df.drop(*cols)
else:
return df.drop(cols)

case "sp_dataframe_drop_duplicates":
df = self.decode_expr(expr.sp_dataframe_drop_duplicates.df)
cols = list(expr.sp_dataframe_drop_duplicates.cols)
if expr.sp_dataframe_drop_duplicates.variadic:
return df.drop_duplicates(*cols)
else:
return df.drop_duplicates(cols)

case "sp_dataframe_except":
df = self.decode_expr(expr.sp_dataframe_except.df)
other = self.decode_expr(expr.sp_dataframe_except.other)
return df.except_(other)

case "sp_dataframe_filter":
df = self.decode_expr(expr.sp_dataframe_filter.df)
condition = self.decode_expr(expr.sp_dataframe_filter.condition)
return df.filter(condition)

case "sp_dataframe_first":
df = self.decode_expr(expr.sp_dataframe_first.df)
block = expr.sp_dataframe_first.block
Expand All @@ -1072,11 +1106,10 @@ def decode_expr(self, expr: proto.Expr) -> Any:

case "sp_dataframe_group_by":
df = self.decode_expr(expr.sp_dataframe_group_by.df)
cols = self.decode_col_exprs(
expr.sp_dataframe_group_by.cols.args,
expr.sp_dataframe_group_by.cols.variadic,
)
if expr.sp_dataframe_group_by.cols.variadic:
cols = self.decode_col_exprs(expr.sp_dataframe_group_by.cols.args)
if MessageToDict(expr.sp_dataframe_group_by.cols).get(
"variadic", False
):
return df.group_by(*cols)
else:
return df.group_by(cols)
Expand Down Expand Up @@ -1166,11 +1199,10 @@ def decode_expr(self, expr: proto.Expr) -> Any:
case "sp_dataframe_select__columns":
df = self.decode_expr(expr.sp_dataframe_select__columns.df)
# The columns can be a list of Expr or a single Expr.
cols = self.decode_col_exprs(
expr.sp_dataframe_select__columns.cols,
not hasattr(expr.sp_dataframe_select__columns, "variadic"),
)
if hasattr(expr.sp_dataframe_select__columns, "variadic"):
cols = self.decode_col_exprs(expr.sp_dataframe_select__columns.cols)
if MessageToDict(expr.sp_dataframe_select__columns).get(
"variadic", False
):
val = df.select(*cols)
else:
val = df.select(cols)
Expand All @@ -1189,14 +1221,14 @@ def decode_expr(self, expr: proto.Expr) -> Any:

case "sp_dataframe_sort":
df = self.decode_expr(expr.sp_dataframe_sort.df)
cols = self.decode_col_exprs(
expr.sp_dataframe_sort.cols, expr.sp_dataframe_sort.cols.variadic
cols = list(
self.decode_expr(col) for col in expr.sp_dataframe_sort.cols
)
ascending = self.decode_expr(expr.sp_dataframe_sort.ascending)
if expr.sp_dataframe_sort.cols_variadic:
return df.sort(*cols, ascending)
if MessageToDict(expr.sp_dataframe_sort).get("colsVariadic", False):
return df.sort(*cols, ascending=ascending)
else:
return df.sort(cols, ascending)
return df.sort(cols, ascending=ascending)

case "sp_dataframe_to_df":
df = self.decode_expr(expr.sp_dataframe_to_df.df)
Expand Down Expand Up @@ -1241,10 +1273,11 @@ def decode_expr(self, expr: proto.Expr) -> Any:
expr.sp_relational_grouped_dataframe_agg.grouped_df
)
exprs = self.decode_col_exprs(
expr.sp_relational_grouped_dataframe_agg.exprs.args,
expr.sp_relational_grouped_dataframe_agg.cols.variadic,
expr.sp_relational_grouped_dataframe_agg.exprs.args
)
if expr.sp_relational_grouped_dataframe_agg.exprs.variadic is True:
if MessageToDict(expr.sp_relational_grouped_dataframe_agg.exprs).get(
"variadic", False
):
return grouped_df.agg(*exprs)
else:
return grouped_df.agg(exprs)
Expand All @@ -1263,18 +1296,20 @@ def decode_expr(self, expr: proto.Expr) -> Any:
grouped_df = self.decode_expr(
expr.sp_relational_grouped_dataframe_builtin.grouped_df
)
agg_name = expr.sp_relational_grouped_dataframe_builtin.agg_name
if "cols" not in MessageToDict(
expr.sp_relational_grouped_dataframe_builtin
):
return getattr(grouped_df, agg_name)()
cols = self.decode_col_exprs(
expr.sp_relational_grouped_dataframe_builtin.cols.args,
expr.sp_relational_grouped_dataframe_builtin.cols.variadic,
expr.sp_relational_grouped_dataframe_builtin.cols.args
)
agg_name = expr.sp_relational_grouped_dataframe_builtin.agg_name
if (
expr.sp_relational_grouped_dataframe_builtin.cols.variadic
and isinstance(agg_name, list)
if MessageToDict(expr.sp_relational_grouped_dataframe_builtin.cols).get(
"variadic", False
):
return grouped_df.function(*agg_name)(*cols)
return getattr(grouped_df, agg_name)(*cols)
else:
return grouped_df.function(agg_name)(*cols)
return getattr(grouped_df, agg_name)(cols)

case "sp_relational_grouped_dataframe_ref":
return self.symbol_table[
Expand Down

0 comments on commit ca608fb

Please sign in to comment.