From 0ee303348abf772fb161d1edf4f75728089da759 Mon Sep 17 00:00:00 2001 From: Yun Zou Date: Mon, 16 Sep 2024 11:01:57 -0700 Subject: [PATCH] [SNOW-1632895] Add derive_dependent_columns_with_duplication capability (#2272) 1. Which Jira issue is this PR addressing? Make sure that there is an accompanying issue to your PR. Fixes SNOW-1632895 2. Fill out the following pre-review checklist: - [x] I am adding a new automated test(s) to verify correctness of my new code - [ ] If this test skips Local Testing mode, I'm requesting review from @snowflakedb/local-testing - [ ] I am adding new logging messages - [ ] I am adding a new telemetry message - [ ] I am adding new credentials - [ ] I am adding a new dependency - [ ] If this is a new feature/behavior, I'm adding the Local Testing parity changes. 3. Please describe how your code solves the related issue. in order to handle nested select with column dependency (not handled by sql simplifier) complexity calculation, added a utility support for deriving all column depenencies with duplication, for example, col('a') + col('b') + 3*('a'), should return dependency ['a', 'b', 'a'] this provides both information about the columns it dependents on and also the number of times/ --- .../_internal/analyzer/binary_expression.py | 6 +- .../snowpark/_internal/analyzer/expression.py | 96 +++++++++++++- .../_internal/analyzer/grouping_set.py | 8 ++ .../_internal/analyzer/sort_expression.py | 6 +- .../_internal/analyzer/unary_expression.py | 6 +- .../_internal/analyzer/window_expression.py | 17 +++ .../unit/test_expression_dependent_columns.py | 125 ++++++++++++++++++ 7 files changed, 259 insertions(+), 5 deletions(-) diff --git a/src/snowflake/snowpark/_internal/analyzer/binary_expression.py b/src/snowflake/snowpark/_internal/analyzer/binary_expression.py index 3ed969caada..22591f55e47 100644 --- a/src/snowflake/snowpark/_internal/analyzer/binary_expression.py +++ b/src/snowflake/snowpark/_internal/analyzer/binary_expression.py @@ -2,11 +2,12 @@ # Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved. # -from typing import AbstractSet, Optional +from typing import AbstractSet, List, Optional from snowflake.snowpark._internal.analyzer.expression import ( Expression, derive_dependent_columns, + derive_dependent_columns_with_duplication, ) from snowflake.snowpark._internal.analyzer.query_plan_analysis_utils import ( PlanNodeCategory, @@ -29,6 +30,9 @@ def __str__(self): def dependent_column_names(self) -> Optional[AbstractSet[str]]: return derive_dependent_columns(self.left, self.right) + def dependent_column_names_with_duplication(self) -> List[str]: + return derive_dependent_columns_with_duplication(self.left, self.right) + @property def plan_node_category(self) -> PlanNodeCategory: return PlanNodeCategory.LOW_IMPACT diff --git a/src/snowflake/snowpark/_internal/analyzer/expression.py b/src/snowflake/snowpark/_internal/analyzer/expression.py index a2d21db4eb2..a7cb5fd97a9 100644 --- a/src/snowflake/snowpark/_internal/analyzer/expression.py +++ b/src/snowflake/snowpark/_internal/analyzer/expression.py @@ -35,6 +35,13 @@ def derive_dependent_columns( *expressions: "Optional[Expression]", ) -> Optional[AbstractSet[str]]: + """ + Given set of expressions, derive the set of columns that the expressions dependents on. + + Note, the returned dependent columns is a set without duplication. For example, given expression + concat(col1, upper(co1), upper(col2)), the result will be {col1, col2} even if col1 has + occurred in the given expression twice. + """ result = set() for exp in expressions: if exp is not None: @@ -48,6 +55,23 @@ def derive_dependent_columns( return result +def derive_dependent_columns_with_duplication( + *expressions: "Optional[Expression]", +) -> List[str]: + """ + Given set of expressions, derive the list of columns that the expression dependents on. + + Note, the returned columns will have duplication if the column occurred more than once in + the given expression. For example, concat(col1, upper(co1), upper(col2)) will have result + [col1, col1, col2], where col1 occurred twice in the result. + """ + result = [] + for exp in expressions: + if exp is not None: + result.extend(exp.dependent_column_names_with_duplication()) + return result + + class Expression: """Consider removing attributes, and adding properties and methods. A subclass of Expression may have no child, one child, or multiple children. @@ -68,6 +92,9 @@ def dependent_column_names(self) -> Optional[AbstractSet[str]]: # TODO: consider adding it to __init__ or use cached_property. return COLUMN_DEPENDENCY_EMPTY + def dependent_column_names_with_duplication(self) -> List[str]: + return [] + @property def pretty_name(self) -> str: """Returns a user-facing string representation of this expression's name. @@ -143,6 +170,9 @@ def __init__(self, plan: "SnowflakePlan") -> None: def dependent_column_names(self) -> Optional[AbstractSet[str]]: return COLUMN_DEPENDENCY_DOLLAR + def dependent_column_names_with_duplication(self) -> List[str]: + return list(COLUMN_DEPENDENCY_DOLLAR) + @property def individual_node_complexity(self) -> Dict[PlanNodeCategory, int]: return self.plan.cumulative_node_complexity @@ -156,6 +186,9 @@ def __init__(self, expressions: List[Expression]) -> None: def dependent_column_names(self) -> Optional[AbstractSet[str]]: return derive_dependent_columns(*self.expressions) + def dependent_column_names_with_duplication(self) -> List[str]: + return derive_dependent_columns_with_duplication(*self.expressions) + @property def individual_node_complexity(self) -> Dict[PlanNodeCategory, int]: return sum_node_complexities( @@ -172,6 +205,9 @@ def __init__(self, columns: Expression, values: List[Expression]) -> None: def dependent_column_names(self) -> Optional[AbstractSet[str]]: return derive_dependent_columns(self.columns, *self.values) + def dependent_column_names_with_duplication(self) -> List[str]: + return derive_dependent_columns_with_duplication(self.columns, *self.values) + @property def plan_node_category(self) -> PlanNodeCategory: return PlanNodeCategory.IN @@ -212,6 +248,9 @@ def __str__(self): def dependent_column_names(self) -> Optional[AbstractSet[str]]: return {self.name} + def dependent_column_names_with_duplication(self) -> List[str]: + return [self.name] + @property def plan_node_category(self) -> PlanNodeCategory: return PlanNodeCategory.COLUMN @@ -235,6 +274,13 @@ def dependent_column_names(self) -> Optional[AbstractSet[str]]: else COLUMN_DEPENDENCY_ALL ) + def dependent_column_names_with_duplication(self) -> List[str]: + return ( + derive_dependent_columns_with_duplication(*self.expressions) + if self.expressions + else [] # we currently do not handle * dependency + ) + @property def individual_node_complexity(self) -> Dict[PlanNodeCategory, int]: complexity = {} if self.expressions else {PlanNodeCategory.COLUMN: 1} @@ -278,6 +324,14 @@ def __hash__(self): def dependent_column_names(self) -> Optional[AbstractSet[str]]: return self._dependent_column_names + def dependent_column_names_with_duplication(self) -> List[str]: + return ( + [] + if (self._dependent_column_names == COLUMN_DEPENDENCY_ALL) + or (self._dependent_column_names is None) + else list(self._dependent_column_names) + ) + @property def plan_node_category(self) -> PlanNodeCategory: return PlanNodeCategory.COLUMN @@ -371,6 +425,9 @@ def __init__(self, expr: Expression, pattern: Expression) -> None: def dependent_column_names(self) -> Optional[AbstractSet[str]]: return derive_dependent_columns(self.expr, self.pattern) + def dependent_column_names_with_duplication(self) -> List[str]: + return derive_dependent_columns_with_duplication(self.expr, self.pattern) + @property def plan_node_category(self) -> PlanNodeCategory: # expr LIKE pattern @@ -400,6 +457,9 @@ def __init__( def dependent_column_names(self) -> Optional[AbstractSet[str]]: return derive_dependent_columns(self.expr, self.pattern) + def dependent_column_names_with_duplication(self) -> List[str]: + return derive_dependent_columns_with_duplication(self.expr, self.pattern) + @property def plan_node_category(self) -> PlanNodeCategory: # expr REG_EXP pattern @@ -423,6 +483,9 @@ def __init__(self, expr: Expression, collation_spec: str) -> None: def dependent_column_names(self) -> Optional[AbstractSet[str]]: return derive_dependent_columns(self.expr) + def dependent_column_names_with_duplication(self) -> List[str]: + return derive_dependent_columns_with_duplication(self.expr) + @property def plan_node_category(self) -> PlanNodeCategory: # expr COLLATE collate_spec @@ -444,6 +507,9 @@ def __init__(self, expr: Expression, field: str) -> None: def dependent_column_names(self) -> Optional[AbstractSet[str]]: return derive_dependent_columns(self.expr) + def dependent_column_names_with_duplication(self) -> List[str]: + return derive_dependent_columns_with_duplication(self.expr) + @property def plan_node_category(self) -> PlanNodeCategory: # the literal corresponds to the contribution from self.field @@ -466,6 +532,9 @@ def __init__(self, expr: Expression, field: int) -> None: def dependent_column_names(self) -> Optional[AbstractSet[str]]: return derive_dependent_columns(self.expr) + def dependent_column_names_with_duplication(self) -> List[str]: + return derive_dependent_columns_with_duplication(self.expr) + @property def plan_node_category(self) -> PlanNodeCategory: # the literal corresponds to the contribution from self.field @@ -510,6 +579,9 @@ def sql(self) -> str: def dependent_column_names(self) -> Optional[AbstractSet[str]]: return derive_dependent_columns(*self.children) + def dependent_column_names_with_duplication(self) -> List[str]: + return derive_dependent_columns_with_duplication(*self.children) + @property def plan_node_category(self) -> PlanNodeCategory: return PlanNodeCategory.FUNCTION @@ -525,6 +597,9 @@ def __init__(self, expr: Expression, order_by_cols: List[Expression]) -> None: def dependent_column_names(self) -> Optional[AbstractSet[str]]: return derive_dependent_columns(self.expr, *self.order_by_cols) + def dependent_column_names_with_duplication(self) -> List[str]: + return derive_dependent_columns_with_duplication(self.expr, *self.order_by_cols) + @property def plan_node_category(self) -> PlanNodeCategory: # expr WITHIN GROUP (ORDER BY cols) @@ -549,13 +624,21 @@ def __init__( self.branches = branches self.else_value = else_value - def dependent_column_names(self) -> Optional[AbstractSet[str]]: + @property + def _child_expressions(self) -> List[Expression]: exps = [] for exp_tuple in self.branches: exps.extend(exp_tuple) if self.else_value is not None: exps.append(self.else_value) - return derive_dependent_columns(*exps) + + return exps + + def dependent_column_names(self) -> Optional[AbstractSet[str]]: + return derive_dependent_columns(*self._child_expressions) + + def dependent_column_names_with_duplication(self) -> List[str]: + return derive_dependent_columns_with_duplication(*self._child_expressions) @property def plan_node_category(self) -> PlanNodeCategory: @@ -602,6 +685,9 @@ def __init__( def dependent_column_names(self) -> Optional[AbstractSet[str]]: return derive_dependent_columns(*self.children) + def dependent_column_names_with_duplication(self) -> List[str]: + return derive_dependent_columns_with_duplication(*self.children) + @property def plan_node_category(self) -> PlanNodeCategory: return PlanNodeCategory.FUNCTION @@ -617,6 +703,9 @@ def __init__(self, col: Expression, delimiter: str, is_distinct: bool) -> None: def dependent_column_names(self) -> Optional[AbstractSet[str]]: return derive_dependent_columns(self.col) + def dependent_column_names_with_duplication(self) -> List[str]: + return derive_dependent_columns_with_duplication(self.col) + @property def plan_node_category(self) -> PlanNodeCategory: return PlanNodeCategory.FUNCTION @@ -636,6 +725,9 @@ def __init__(self, exprs: List[Expression]) -> None: def dependent_column_names(self) -> Optional[AbstractSet[str]]: return derive_dependent_columns(*self.exprs) + def dependent_column_names_with_duplication(self) -> List[str]: + return derive_dependent_columns_with_duplication(*self.exprs) + @property def individual_node_complexity(self) -> Dict[PlanNodeCategory, int]: return sum_node_complexities( diff --git a/src/snowflake/snowpark/_internal/analyzer/grouping_set.py b/src/snowflake/snowpark/_internal/analyzer/grouping_set.py index 84cd63fd87d..012940471d0 100644 --- a/src/snowflake/snowpark/_internal/analyzer/grouping_set.py +++ b/src/snowflake/snowpark/_internal/analyzer/grouping_set.py @@ -7,6 +7,7 @@ from snowflake.snowpark._internal.analyzer.expression import ( Expression, derive_dependent_columns, + derive_dependent_columns_with_duplication, ) from snowflake.snowpark._internal.analyzer.query_plan_analysis_utils import ( PlanNodeCategory, @@ -23,6 +24,9 @@ def __init__(self, group_by_exprs: List[Expression]) -> None: def dependent_column_names(self) -> Optional[AbstractSet[str]]: return derive_dependent_columns(*self.group_by_exprs) + def dependent_column_names_with_duplication(self) -> List[str]: + return derive_dependent_columns_with_duplication(*self.group_by_exprs) + @property def plan_node_category(self) -> PlanNodeCategory: return PlanNodeCategory.LOW_IMPACT @@ -45,6 +49,10 @@ def dependent_column_names(self) -> Optional[AbstractSet[str]]: flattened_args = [exp for sublist in self.args for exp in sublist] return derive_dependent_columns(*flattened_args) + def dependent_column_names_with_duplication(self) -> List[str]: + flattened_args = [exp for sublist in self.args for exp in sublist] + return derive_dependent_columns_with_duplication(*flattened_args) + @property def individual_node_complexity(self) -> Dict[PlanNodeCategory, int]: return sum_node_complexities( diff --git a/src/snowflake/snowpark/_internal/analyzer/sort_expression.py b/src/snowflake/snowpark/_internal/analyzer/sort_expression.py index 1d06f7290a0..82451245e4c 100644 --- a/src/snowflake/snowpark/_internal/analyzer/sort_expression.py +++ b/src/snowflake/snowpark/_internal/analyzer/sort_expression.py @@ -2,11 +2,12 @@ # Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved. # -from typing import AbstractSet, Optional, Type +from typing import AbstractSet, List, Optional, Type from snowflake.snowpark._internal.analyzer.expression import ( Expression, derive_dependent_columns, + derive_dependent_columns_with_duplication, ) @@ -55,3 +56,6 @@ def __init__( def dependent_column_names(self) -> Optional[AbstractSet[str]]: return derive_dependent_columns(self.child) + + def dependent_column_names_with_duplication(self) -> List[str]: + return derive_dependent_columns_with_duplication(self.child) diff --git a/src/snowflake/snowpark/_internal/analyzer/unary_expression.py b/src/snowflake/snowpark/_internal/analyzer/unary_expression.py index e5886e11069..1ae08e8fde2 100644 --- a/src/snowflake/snowpark/_internal/analyzer/unary_expression.py +++ b/src/snowflake/snowpark/_internal/analyzer/unary_expression.py @@ -2,12 +2,13 @@ # Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved. # -from typing import AbstractSet, Dict, Optional +from typing import AbstractSet, Dict, List, Optional from snowflake.snowpark._internal.analyzer.expression import ( Expression, NamedExpression, derive_dependent_columns, + derive_dependent_columns_with_duplication, ) from snowflake.snowpark._internal.analyzer.query_plan_analysis_utils import ( PlanNodeCategory, @@ -36,6 +37,9 @@ def __str__(self): def dependent_column_names(self) -> Optional[AbstractSet[str]]: return derive_dependent_columns(self.child) + def dependent_column_names_with_duplication(self) -> List[str]: + return derive_dependent_columns_with_duplication(self.child) + @property def plan_node_category(self) -> PlanNodeCategory: return PlanNodeCategory.LOW_IMPACT diff --git a/src/snowflake/snowpark/_internal/analyzer/window_expression.py b/src/snowflake/snowpark/_internal/analyzer/window_expression.py index 69db3f265ce..4381c4a2e22 100644 --- a/src/snowflake/snowpark/_internal/analyzer/window_expression.py +++ b/src/snowflake/snowpark/_internal/analyzer/window_expression.py @@ -7,6 +7,7 @@ from snowflake.snowpark._internal.analyzer.expression import ( Expression, derive_dependent_columns, + derive_dependent_columns_with_duplication, ) from snowflake.snowpark._internal.analyzer.query_plan_analysis_utils import ( PlanNodeCategory, @@ -71,6 +72,9 @@ def __init__( def dependent_column_names(self) -> Optional[AbstractSet[str]]: return derive_dependent_columns(self.lower, self.upper) + def dependent_column_names_with_duplication(self) -> List[str]: + return derive_dependent_columns_with_duplication(self.lower, self.upper) + @property def plan_node_category(self) -> PlanNodeCategory: return PlanNodeCategory.LOW_IMPACT @@ -102,6 +106,11 @@ def dependent_column_names(self) -> Optional[AbstractSet[str]]: *self.partition_spec, *self.order_spec, self.frame_spec ) + def dependent_column_names_with_duplication(self) -> List[str]: + return derive_dependent_columns_with_duplication( + *self.partition_spec, *self.order_spec, self.frame_spec + ) + @property def individual_node_complexity(self) -> Dict[PlanNodeCategory, int]: # partition_spec order_by_spec frame_spec @@ -138,6 +147,11 @@ def __init__( def dependent_column_names(self) -> Optional[AbstractSet[str]]: return derive_dependent_columns(self.window_function, self.window_spec) + def dependent_column_names_with_duplication(self) -> List[str]: + return derive_dependent_columns_with_duplication( + self.window_function, self.window_spec + ) + @property def plan_node_category(self) -> PlanNodeCategory: return PlanNodeCategory.WINDOW @@ -171,6 +185,9 @@ def __init__( def dependent_column_names(self) -> Optional[AbstractSet[str]]: return derive_dependent_columns(self.expr, self.default) + def dependent_column_names_with_duplication(self) -> List[str]: + return derive_dependent_columns_with_duplication(self.expr, self.default) + @property def individual_node_complexity(self) -> Dict[PlanNodeCategory, int]: # for func_name diff --git a/tests/unit/test_expression_dependent_columns.py b/tests/unit/test_expression_dependent_columns.py index c31e5cc6290..c9b8a1ce38d 100644 --- a/tests/unit/test_expression_dependent_columns.py +++ b/tests/unit/test_expression_dependent_columns.py @@ -87,30 +87,37 @@ def test_expression(): a = Expression() assert a.dependent_column_names() == COLUMN_DEPENDENCY_EMPTY + assert a.dependent_column_names_with_duplication() == [] b = Expression(child=UnresolvedAttribute("a")) assert b.dependent_column_names() == COLUMN_DEPENDENCY_EMPTY + assert b.dependent_column_names_with_duplication() == [] # root class Expression always returns empty dependency def test_literal(): a = Literal(5) assert a.dependent_column_names() == COLUMN_DEPENDENCY_EMPTY + assert a.dependent_column_names_with_duplication() == [] def test_attribute(): a = Attribute("A", IntegerType()) assert a.dependent_column_names() == {"A"} + assert a.dependent_column_names_with_duplication() == ["A"] def test_unresolved_attribute(): a = UnresolvedAttribute("A") assert a.dependent_column_names() == {"A"} + assert a.dependent_column_names_with_duplication() == ["A"] b = UnresolvedAttribute("a > 1", is_sql_text=True) assert b.dependent_column_names() == COLUMN_DEPENDENCY_ALL + assert b.dependent_column_names_with_duplication() == [] c = UnresolvedAttribute("$1 > 1", is_sql_text=True) assert c.dependent_column_names() == COLUMN_DEPENDENCY_DOLLAR + assert c.dependent_column_names_with_duplication() == ["$"] def test_case_when(): @@ -118,46 +125,85 @@ def test_case_when(): b = Column("b") z = when(a > b, col("c")).when(a < b, col("d")).else_(col("e")) assert z._expression.dependent_column_names() == {'"A"', '"B"', '"C"', '"D"', '"E"'} + # verify column '"A"', '"B"' occurred twice in the dependency columns + assert z._expression.dependent_column_names_with_duplication() == [ + '"A"', + '"B"', + '"C"', + '"A"', + '"B"', + '"D"', + '"E"', + ] def test_collate(): a = Collate(UnresolvedAttribute("a"), "spec") assert a.dependent_column_names() == {"a"} + assert a.dependent_column_names_with_duplication() == ["a"] def test_function_expression(): a = FunctionExpression("test_func", [UnresolvedAttribute(x) for x in "abcd"], False) assert a.dependent_column_names() == set("abcd") + assert a.dependent_column_names_with_duplication() == list("abcd") + + # expressions with duplicated dependent column + b = FunctionExpression( + "test_func", [UnresolvedAttribute(x) for x in "abcdad"], False + ) + assert b.dependent_column_names() == set("abcd") + assert b.dependent_column_names_with_duplication() == list("abcdad") def test_in_expression(): a = InExpression(UnresolvedAttribute("e"), [UnresolvedAttribute(x) for x in "abcd"]) assert a.dependent_column_names() == set("abcde") + assert a.dependent_column_names_with_duplication() == list("eabcd") def test_like(): a = Like(UnresolvedAttribute("a"), UnresolvedAttribute("b")) assert a.dependent_column_names() == {"a", "b"} + assert a.dependent_column_names_with_duplication() == ["a", "b"] + + # with duplication + b = Like(UnresolvedAttribute("a"), UnresolvedAttribute("a")) + assert b.dependent_column_names() == {"a"} + assert b.dependent_column_names_with_duplication() == ["a", "a"] def test_list_agg(): a = ListAgg(UnresolvedAttribute("a"), ",", True) assert a.dependent_column_names() == {"a"} + assert a.dependent_column_names_with_duplication() == ["a"] def test_multiple_expression(): a = MultipleExpression([UnresolvedAttribute(x) for x in "abcd"]) assert a.dependent_column_names() == set("abcd") + assert a.dependent_column_names_with_duplication() == list("abcd") + + # with duplication + a = MultipleExpression([UnresolvedAttribute(x) for x in "abcdbea"]) + assert a.dependent_column_names() == set("abcde") + assert a.dependent_column_names_with_duplication() == list("abcdbea") def test_reg_exp(): a = RegExp(UnresolvedAttribute("a"), UnresolvedAttribute("b")) assert a.dependent_column_names() == {"a", "b"} + assert a.dependent_column_names_with_duplication() == ["a", "b"] + + b = RegExp(UnresolvedAttribute("a"), UnresolvedAttribute("a")) + assert b.dependent_column_names() == {"a"} + assert b.dependent_column_names_with_duplication() == ["a", "a"] def test_scalar_subquery(): a = ScalarSubquery(None) assert a.dependent_column_names() == COLUMN_DEPENDENCY_DOLLAR + assert a.dependent_column_names_with_duplication() == list(COLUMN_DEPENDENCY_DOLLAR) def test_snowflake_udf(): @@ -165,21 +211,42 @@ def test_snowflake_udf(): "udf_name", [UnresolvedAttribute(x) for x in "abcd"], IntegerType() ) assert a.dependent_column_names() == set("abcd") + assert a.dependent_column_names_with_duplication() == list("abcd") + + # with duplication + b = SnowflakeUDF( + "udf_name", [UnresolvedAttribute(x) for x in "abcdfc"], IntegerType() + ) + assert b.dependent_column_names() == set("abcdf") + assert b.dependent_column_names_with_duplication() == list("abcdfc") def test_star(): a = Star([Attribute(x, IntegerType()) for x in "abcd"]) assert a.dependent_column_names() == set("abcd") + assert a.dependent_column_names_with_duplication() == list("abcd") + + b = Star([]) + assert b.dependent_column_names() == COLUMN_DEPENDENCY_ALL + assert b.dependent_column_names_with_duplication() == [] def test_subfield_string(): a = SubfieldString(UnresolvedAttribute("a"), "field") assert a.dependent_column_names() == {"a"} + assert a.dependent_column_names_with_duplication() == ["a"] def test_within_group(): a = WithinGroup(UnresolvedAttribute("e"), [UnresolvedAttribute(x) for x in "abcd"]) assert a.dependent_column_names() == set("abcde") + assert a.dependent_column_names_with_duplication() == list("eabcd") + + b = WithinGroup( + UnresolvedAttribute("e"), [UnresolvedAttribute(x) for x in "abcdea"] + ) + assert b.dependent_column_names() == set("abcde") + assert b.dependent_column_names_with_duplication() == list("eabcdea") @pytest.mark.parametrize( @@ -189,16 +256,19 @@ def test_within_group(): def test_unary_expression(expression_class): a = expression_class(child=UnresolvedAttribute("a")) assert a.dependent_column_names() == {"a"} + assert a.dependent_column_names_with_duplication() == ["a"] def test_alias(): a = Alias(child=Add(UnresolvedAttribute("a"), UnresolvedAttribute("b")), name="c") assert a.dependent_column_names() == {"a", "b"} + assert a.dependent_column_names_with_duplication() == ["a", "b"] def test_cast(): a = Cast(UnresolvedAttribute("a"), IntegerType()) assert a.dependent_column_names() == {"a"} + assert a.dependent_column_names_with_duplication() == ["a"] @pytest.mark.parametrize( @@ -234,6 +304,19 @@ def test_binary_expression(expression_class): assert b.dependent_column_names() == {"B"} assert binary_expression.dependent_column_names() == {"A", "B"} + assert a.dependent_column_names_with_duplication() == ["A"] + assert b.dependent_column_names_with_duplication() == ["B"] + assert binary_expression.dependent_column_names_with_duplication() == ["A", "B"] + + # hierarchical expressions with duplication + hierarchical_binary_expression = expression_class(expression_class(a, b), b) + assert hierarchical_binary_expression.dependent_column_names() == {"A", "B"} + assert hierarchical_binary_expression.dependent_column_names_with_duplication() == [ + "A", + "B", + "B", + ] + @pytest.mark.parametrize( "expression_class", @@ -253,6 +336,18 @@ def test_grouping_set(expression_class): ] ) assert a.dependent_column_names() == {"a", "b", "c", "d"} + assert a.dependent_column_names_with_duplication() == ["a", "b", "c", "d"] + + # with duplication + b = expression_class( + [ + UnresolvedAttribute("a"), + UnresolvedAttribute("a"), + UnresolvedAttribute("c"), + ] + ) + assert b.dependent_column_names() == {"a", "c"} + assert b.dependent_column_names_with_duplication() == ["a", "a", "c"] def test_grouping_sets_expression(): @@ -263,11 +358,13 @@ def test_grouping_sets_expression(): ] ) assert a.dependent_column_names() == {"a", "b", "c", "d"} + assert a.dependent_column_names_with_duplication() == ["a", "b", "c", "d"] def test_sort_order(): a = SortOrder(UnresolvedAttribute("a"), Ascending()) assert a.dependent_column_names() == {"a"} + assert a.dependent_column_names_with_duplication() == ["a"] def test_specified_window_frame(): @@ -275,12 +372,21 @@ def test_specified_window_frame(): RowFrame(), UnresolvedAttribute("a"), UnresolvedAttribute("b") ) assert a.dependent_column_names() == {"a", "b"} + assert a.dependent_column_names_with_duplication() == ["a", "b"] + + # with duplication + b = SpecifiedWindowFrame( + RowFrame(), UnresolvedAttribute("a"), UnresolvedAttribute("a") + ) + assert b.dependent_column_names() == {"a"} + assert b.dependent_column_names_with_duplication() == ["a", "a"] @pytest.mark.parametrize("expression_class", [RankRelatedFunctionExpression, Lag, Lead]) def test_rank_related_function_expression(expression_class): a = expression_class(UnresolvedAttribute("a"), 1, UnresolvedAttribute("b"), False) assert a.dependent_column_names() == {"a", "b"} + assert a.dependent_column_names_with_duplication() == ["a", "b"] def test_window_spec_definition(): @@ -295,6 +401,7 @@ def test_window_spec_definition(): ), ) assert a.dependent_column_names() == set("abcdef") + assert a.dependent_column_names_with_duplication() == list("abcdef") def test_window_expression(): @@ -310,6 +417,23 @@ def test_window_expression(): ) a = WindowExpression(UnresolvedAttribute("x"), window_spec_definition) assert a.dependent_column_names() == set("abcdefx") + assert a.dependent_column_names_with_duplication() == list("xabcdef") + + +def test_window_expression_with_duplication_columns(): + window_spec_definition = WindowSpecDefinition( + [UnresolvedAttribute("a"), UnresolvedAttribute("b")], + [ + SortOrder(UnresolvedAttribute("c"), Ascending()), + SortOrder(UnresolvedAttribute("a"), Ascending()), + ], + SpecifiedWindowFrame( + RowFrame(), UnresolvedAttribute("e"), UnresolvedAttribute("f") + ), + ) + a = WindowExpression(UnresolvedAttribute("e"), window_spec_definition) + assert a.dependent_column_names() == set("abcef") + assert a.dependent_column_names_with_duplication() == list("eabcaef") @pytest.mark.parametrize( @@ -325,3 +449,4 @@ def test_window_expression(): def test_other_window_expressions(expression_class): a = expression_class() assert a.dependent_column_names() == COLUMN_DEPENDENCY_EMPTY + assert a.dependent_column_names_with_duplication() == []