Skip to content

Commit

Permalink
[SNOW-1632895] Add derive_dependent_columns_with_duplication capabili…
Browse files Browse the repository at this point in the history
…ty (#2272)

<!---
Please answer these questions before creating your pull request. Thanks!
--->

1. Which Jira issue is this PR addressing? Make sure that there is an
accompanying issue to your PR.

   <!---
   In this section, please add a Snowflake Jira issue number.
   
Note that if a corresponding GitHub issue exists, you should still
include
   the Snowflake Jira issue number. For example, for GitHub issue
#1400, you should
   add "SNOW-1335071" here.
    --->

   Fixes SNOW-1632895

2. Fill out the following pre-review checklist:

- [x] I am adding a new automated test(s) to verify correctness of my
new code
- [ ] If this test skips Local Testing mode, I'm requesting review from
@snowflakedb/local-testing
   - [ ] I am adding new logging messages
   - [ ] I am adding a new telemetry message
   - [ ] I am adding new credentials
   - [ ] I am adding a new dependency
- [ ] If this is a new feature/behavior, I'm adding the Local Testing
parity changes.

3. Please describe how your code solves the related issue.

To support complexity calculation for nested selects with column
dependencies (a case not handled by the SQL simplifier), this change adds
a utility for deriving all column dependencies with duplication. For
example, col('a') + col('b') + 3*col('a') should return the dependency
list ['a', 'b', 'a']. This provides both the columns the expression
depends on and the number of times each one is referenced.
  • Loading branch information
sfc-gh-yzou authored Sep 16, 2024
1 parent 64ced96 commit 0ee3033
Show file tree
Hide file tree
Showing 7 changed files with 259 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,12 @@
# Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved.
#

from typing import AbstractSet, Optional
from typing import AbstractSet, List, Optional

from snowflake.snowpark._internal.analyzer.expression import (
Expression,
derive_dependent_columns,
derive_dependent_columns_with_duplication,
)
from snowflake.snowpark._internal.analyzer.query_plan_analysis_utils import (
PlanNodeCategory,
Expand All @@ -29,6 +30,9 @@ def __str__(self):
def dependent_column_names(self) -> Optional[AbstractSet[str]]:
return derive_dependent_columns(self.left, self.right)

def dependent_column_names_with_duplication(self) -> List[str]:
return derive_dependent_columns_with_duplication(self.left, self.right)

@property
def plan_node_category(self) -> PlanNodeCategory:
return PlanNodeCategory.LOW_IMPACT
Expand Down
96 changes: 94 additions & 2 deletions src/snowflake/snowpark/_internal/analyzer/expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,13 @@
def derive_dependent_columns(
*expressions: "Optional[Expression]",
) -> Optional[AbstractSet[str]]:
"""
Given a set of expressions, derive the set of columns that the expressions depend on.
Note that the returned dependent columns form a set without duplication. For example, given the
expression concat(col1, upper(col1), upper(col2)), the result will be {col1, col2} even though
col1 occurs twice in the given expression.
"""
result = set()
for exp in expressions:
if exp is not None:
Expand All @@ -48,6 +55,23 @@ def derive_dependent_columns(
return result


def derive_dependent_columns_with_duplication(
    *expressions: "Optional[Expression]",
) -> List[str]:
    """
    Collect the column names that the given expressions depend on, keeping duplicates.

    Every expression that is not None contributes the result of its own
    ``dependent_column_names_with_duplication()`` call. Contributions are concatenated
    in argument order, so a column referenced more than once shows up more than once.
    For example, concat(col1, upper(col1), upper(col2)) yields [col1, col1, col2],
    where col1 appears twice.

    Args:
        *expressions: The expressions to inspect; None entries are skipped.

    Returns:
        A flat list of column names, or an empty list when no expression is given
        or every entry is None.
    """
    return [
        column_name
        for expression in expressions
        if expression is not None
        for column_name in expression.dependent_column_names_with_duplication()
    ]


class Expression:
"""Consider removing attributes, and adding properties and methods.
A subclass of Expression may have no child, one child, or multiple children.
Expand All @@ -68,6 +92,9 @@ def dependent_column_names(self) -> Optional[AbstractSet[str]]:
# TODO: consider adding it to __init__ or use cached_property.
return COLUMN_DEPENDENCY_EMPTY

def dependent_column_names_with_duplication(self) -> List[str]:
return []

@property
def pretty_name(self) -> str:
"""Returns a user-facing string representation of this expression's name.
Expand Down Expand Up @@ -143,6 +170,9 @@ def __init__(self, plan: "SnowflakePlan") -> None:
def dependent_column_names(self) -> Optional[AbstractSet[str]]:
return COLUMN_DEPENDENCY_DOLLAR

def dependent_column_names_with_duplication(self) -> List[str]:
return list(COLUMN_DEPENDENCY_DOLLAR)

@property
def individual_node_complexity(self) -> Dict[PlanNodeCategory, int]:
return self.plan.cumulative_node_complexity
Expand All @@ -156,6 +186,9 @@ def __init__(self, expressions: List[Expression]) -> None:
def dependent_column_names(self) -> Optional[AbstractSet[str]]:
return derive_dependent_columns(*self.expressions)

def dependent_column_names_with_duplication(self) -> List[str]:
return derive_dependent_columns_with_duplication(*self.expressions)

@property
def individual_node_complexity(self) -> Dict[PlanNodeCategory, int]:
return sum_node_complexities(
Expand All @@ -172,6 +205,9 @@ def __init__(self, columns: Expression, values: List[Expression]) -> None:
def dependent_column_names(self) -> Optional[AbstractSet[str]]:
return derive_dependent_columns(self.columns, *self.values)

def dependent_column_names_with_duplication(self) -> List[str]:
return derive_dependent_columns_with_duplication(self.columns, *self.values)

@property
def plan_node_category(self) -> PlanNodeCategory:
return PlanNodeCategory.IN
Expand Down Expand Up @@ -212,6 +248,9 @@ def __str__(self):
def dependent_column_names(self) -> Optional[AbstractSet[str]]:
return {self.name}

def dependent_column_names_with_duplication(self) -> List[str]:
return [self.name]

@property
def plan_node_category(self) -> PlanNodeCategory:
return PlanNodeCategory.COLUMN
Expand All @@ -235,6 +274,13 @@ def dependent_column_names(self) -> Optional[AbstractSet[str]]:
else COLUMN_DEPENDENCY_ALL
)

def dependent_column_names_with_duplication(self) -> List[str]:
    """
    Return the dependent column names of ``self.expressions``, keeping duplicates.

    When ``self.expressions`` is empty (or otherwise falsy) an empty list is
    returned; otherwise the work is delegated to
    ``derive_dependent_columns_with_duplication`` over all of the expressions.
    """
    if not self.expressions:
        # we currently do not handle * dependency
        return []
    return derive_dependent_columns_with_duplication(*self.expressions)

@property
def individual_node_complexity(self) -> Dict[PlanNodeCategory, int]:
complexity = {} if self.expressions else {PlanNodeCategory.COLUMN: 1}
Expand Down Expand Up @@ -278,6 +324,14 @@ def __hash__(self):
def dependent_column_names(self) -> Optional[AbstractSet[str]]:
return self._dependent_column_names

def dependent_column_names_with_duplication(self) -> List[str]:
    """
    Return ``self._dependent_column_names`` as a list.

    An empty list is returned when the stored value equals
    ``COLUMN_DEPENDENCY_ALL`` or is None; in every other case the stored
    collection is copied into a new list.

    NOTE(review): the stored value is presumably a set (it is what
    ``dependent_column_names`` returns as ``AbstractSet[str]``), so the result
    would hold each name at most once, in set iteration order -- confirm.
    """
    names = self._dependent_column_names
    if names == COLUMN_DEPENDENCY_ALL or names is None:
        return []
    return list(names)

@property
def plan_node_category(self) -> PlanNodeCategory:
return PlanNodeCategory.COLUMN
Expand Down Expand Up @@ -371,6 +425,9 @@ def __init__(self, expr: Expression, pattern: Expression) -> None:
def dependent_column_names(self) -> Optional[AbstractSet[str]]:
return derive_dependent_columns(self.expr, self.pattern)

def dependent_column_names_with_duplication(self) -> List[str]:
return derive_dependent_columns_with_duplication(self.expr, self.pattern)

@property
def plan_node_category(self) -> PlanNodeCategory:
# expr LIKE pattern
Expand Down Expand Up @@ -400,6 +457,9 @@ def __init__(
def dependent_column_names(self) -> Optional[AbstractSet[str]]:
return derive_dependent_columns(self.expr, self.pattern)

def dependent_column_names_with_duplication(self) -> List[str]:
return derive_dependent_columns_with_duplication(self.expr, self.pattern)

@property
def plan_node_category(self) -> PlanNodeCategory:
# expr REG_EXP pattern
Expand All @@ -423,6 +483,9 @@ def __init__(self, expr: Expression, collation_spec: str) -> None:
def dependent_column_names(self) -> Optional[AbstractSet[str]]:
return derive_dependent_columns(self.expr)

def dependent_column_names_with_duplication(self) -> List[str]:
return derive_dependent_columns_with_duplication(self.expr)

@property
def plan_node_category(self) -> PlanNodeCategory:
# expr COLLATE collate_spec
Expand All @@ -444,6 +507,9 @@ def __init__(self, expr: Expression, field: str) -> None:
def dependent_column_names(self) -> Optional[AbstractSet[str]]:
return derive_dependent_columns(self.expr)

def dependent_column_names_with_duplication(self) -> List[str]:
return derive_dependent_columns_with_duplication(self.expr)

@property
def plan_node_category(self) -> PlanNodeCategory:
# the literal corresponds to the contribution from self.field
Expand All @@ -466,6 +532,9 @@ def __init__(self, expr: Expression, field: int) -> None:
def dependent_column_names(self) -> Optional[AbstractSet[str]]:
return derive_dependent_columns(self.expr)

def dependent_column_names_with_duplication(self) -> List[str]:
return derive_dependent_columns_with_duplication(self.expr)

@property
def plan_node_category(self) -> PlanNodeCategory:
# the literal corresponds to the contribution from self.field
Expand Down Expand Up @@ -510,6 +579,9 @@ def sql(self) -> str:
def dependent_column_names(self) -> Optional[AbstractSet[str]]:
return derive_dependent_columns(*self.children)

def dependent_column_names_with_duplication(self) -> List[str]:
return derive_dependent_columns_with_duplication(*self.children)

@property
def plan_node_category(self) -> PlanNodeCategory:
return PlanNodeCategory.FUNCTION
Expand All @@ -525,6 +597,9 @@ def __init__(self, expr: Expression, order_by_cols: List[Expression]) -> None:
def dependent_column_names(self) -> Optional[AbstractSet[str]]:
return derive_dependent_columns(self.expr, *self.order_by_cols)

def dependent_column_names_with_duplication(self) -> List[str]:
return derive_dependent_columns_with_duplication(self.expr, *self.order_by_cols)

@property
def plan_node_category(self) -> PlanNodeCategory:
# expr WITHIN GROUP (ORDER BY cols)
Expand All @@ -549,13 +624,21 @@ def __init__(
self.branches = branches
self.else_value = else_value

def dependent_column_names(self) -> Optional[AbstractSet[str]]:
@property
def _child_expressions(self) -> List[Expression]:
exps = []
for exp_tuple in self.branches:
exps.extend(exp_tuple)
if self.else_value is not None:
exps.append(self.else_value)
return derive_dependent_columns(*exps)

return exps

def dependent_column_names(self) -> Optional[AbstractSet[str]]:
return derive_dependent_columns(*self._child_expressions)

def dependent_column_names_with_duplication(self) -> List[str]:
return derive_dependent_columns_with_duplication(*self._child_expressions)

@property
def plan_node_category(self) -> PlanNodeCategory:
Expand Down Expand Up @@ -602,6 +685,9 @@ def __init__(
def dependent_column_names(self) -> Optional[AbstractSet[str]]:
return derive_dependent_columns(*self.children)

def dependent_column_names_with_duplication(self) -> List[str]:
return derive_dependent_columns_with_duplication(*self.children)

@property
def plan_node_category(self) -> PlanNodeCategory:
return PlanNodeCategory.FUNCTION
Expand All @@ -617,6 +703,9 @@ def __init__(self, col: Expression, delimiter: str, is_distinct: bool) -> None:
def dependent_column_names(self) -> Optional[AbstractSet[str]]:
return derive_dependent_columns(self.col)

def dependent_column_names_with_duplication(self) -> List[str]:
return derive_dependent_columns_with_duplication(self.col)

@property
def plan_node_category(self) -> PlanNodeCategory:
return PlanNodeCategory.FUNCTION
Expand All @@ -636,6 +725,9 @@ def __init__(self, exprs: List[Expression]) -> None:
def dependent_column_names(self) -> Optional[AbstractSet[str]]:
return derive_dependent_columns(*self.exprs)

def dependent_column_names_with_duplication(self) -> List[str]:
return derive_dependent_columns_with_duplication(*self.exprs)

@property
def individual_node_complexity(self) -> Dict[PlanNodeCategory, int]:
return sum_node_complexities(
Expand Down
8 changes: 8 additions & 0 deletions src/snowflake/snowpark/_internal/analyzer/grouping_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from snowflake.snowpark._internal.analyzer.expression import (
Expression,
derive_dependent_columns,
derive_dependent_columns_with_duplication,
)
from snowflake.snowpark._internal.analyzer.query_plan_analysis_utils import (
PlanNodeCategory,
Expand All @@ -23,6 +24,9 @@ def __init__(self, group_by_exprs: List[Expression]) -> None:
def dependent_column_names(self) -> Optional[AbstractSet[str]]:
return derive_dependent_columns(*self.group_by_exprs)

def dependent_column_names_with_duplication(self) -> List[str]:
return derive_dependent_columns_with_duplication(*self.group_by_exprs)

@property
def plan_node_category(self) -> PlanNodeCategory:
return PlanNodeCategory.LOW_IMPACT
Expand All @@ -45,6 +49,10 @@ def dependent_column_names(self) -> Optional[AbstractSet[str]]:
flattened_args = [exp for sublist in self.args for exp in sublist]
return derive_dependent_columns(*flattened_args)

def dependent_column_names_with_duplication(self) -> List[str]:
flattened_args = [exp for sublist in self.args for exp in sublist]
return derive_dependent_columns_with_duplication(*flattened_args)

@property
def individual_node_complexity(self) -> Dict[PlanNodeCategory, int]:
return sum_node_complexities(
Expand Down
6 changes: 5 additions & 1 deletion src/snowflake/snowpark/_internal/analyzer/sort_expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,12 @@
# Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved.
#

from typing import AbstractSet, Optional, Type
from typing import AbstractSet, List, Optional, Type

from snowflake.snowpark._internal.analyzer.expression import (
Expression,
derive_dependent_columns,
derive_dependent_columns_with_duplication,
)


Expand Down Expand Up @@ -55,3 +56,6 @@ def __init__(

def dependent_column_names(self) -> Optional[AbstractSet[str]]:
return derive_dependent_columns(self.child)

def dependent_column_names_with_duplication(self) -> List[str]:
return derive_dependent_columns_with_duplication(self.child)
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,13 @@
# Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved.
#

from typing import AbstractSet, Dict, Optional
from typing import AbstractSet, Dict, List, Optional

from snowflake.snowpark._internal.analyzer.expression import (
Expression,
NamedExpression,
derive_dependent_columns,
derive_dependent_columns_with_duplication,
)
from snowflake.snowpark._internal.analyzer.query_plan_analysis_utils import (
PlanNodeCategory,
Expand Down Expand Up @@ -36,6 +37,9 @@ def __str__(self):
def dependent_column_names(self) -> Optional[AbstractSet[str]]:
return derive_dependent_columns(self.child)

def dependent_column_names_with_duplication(self) -> List[str]:
return derive_dependent_columns_with_duplication(self.child)

@property
def plan_node_category(self) -> PlanNodeCategory:
return PlanNodeCategory.LOW_IMPACT
Expand Down
17 changes: 17 additions & 0 deletions src/snowflake/snowpark/_internal/analyzer/window_expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from snowflake.snowpark._internal.analyzer.expression import (
Expression,
derive_dependent_columns,
derive_dependent_columns_with_duplication,
)
from snowflake.snowpark._internal.analyzer.query_plan_analysis_utils import (
PlanNodeCategory,
Expand Down Expand Up @@ -71,6 +72,9 @@ def __init__(
def dependent_column_names(self) -> Optional[AbstractSet[str]]:
return derive_dependent_columns(self.lower, self.upper)

def dependent_column_names_with_duplication(self) -> List[str]:
return derive_dependent_columns_with_duplication(self.lower, self.upper)

@property
def plan_node_category(self) -> PlanNodeCategory:
return PlanNodeCategory.LOW_IMPACT
Expand Down Expand Up @@ -102,6 +106,11 @@ def dependent_column_names(self) -> Optional[AbstractSet[str]]:
*self.partition_spec, *self.order_spec, self.frame_spec
)

def dependent_column_names_with_duplication(self) -> List[str]:
return derive_dependent_columns_with_duplication(
*self.partition_spec, *self.order_spec, self.frame_spec
)

@property
def individual_node_complexity(self) -> Dict[PlanNodeCategory, int]:
# partition_spec order_by_spec frame_spec
Expand Down Expand Up @@ -138,6 +147,11 @@ def __init__(
def dependent_column_names(self) -> Optional[AbstractSet[str]]:
return derive_dependent_columns(self.window_function, self.window_spec)

def dependent_column_names_with_duplication(self) -> List[str]:
return derive_dependent_columns_with_duplication(
self.window_function, self.window_spec
)

@property
def plan_node_category(self) -> PlanNodeCategory:
return PlanNodeCategory.WINDOW
Expand Down Expand Up @@ -171,6 +185,9 @@ def __init__(
def dependent_column_names(self) -> Optional[AbstractSet[str]]:
return derive_dependent_columns(self.expr, self.default)

def dependent_column_names_with_duplication(self) -> List[str]:
return derive_dependent_columns_with_duplication(self.expr, self.default)

@property
def individual_node_complexity(self) -> Dict[PlanNodeCategory, int]:
# for func_name
Expand Down
Loading

0 comments on commit 0ee3033

Please sign in to comment.