Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SNOW-1016954: Add an internal API for supporting common table expression (CTE) #1219

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/snowflake/snowpark/_internal/analyzer/analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@
SnowflakePlanBuilder,
)
from snowflake.snowpark._internal.analyzer.snowflake_plan_node import (
CTE,
CopyIntoLocationNode,
CopyIntoTableNode,
Limit,
Expand Down Expand Up @@ -1126,6 +1127,9 @@ def do_resolve_with_resolved_children(
if isinstance(logical_plan, Selectable):
return self.plan_builder.select_statement(logical_plan)

if isinstance(logical_plan, CTE):
return self.plan_builder.cte(resolved_children[logical_plan.child])

raise TypeError(
f"Cannot resolve type logical_plan of {type(logical_plan).__name__} to a SnowflakePlan"
)
Expand Down
19 changes: 19 additions & 0 deletions src/snowflake/snowpark/_internal/analyzer/analyzer_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,7 @@
INTERSECT = f" {Intersect.sql} "
EXCEPT = f" {Except.sql} "
NOT_NULL = " NOT NULL "
WITH = " WITH "

TEMPORARY_STRING_SET = frozenset(["temporary", "temp"])

Expand Down Expand Up @@ -1382,3 +1383,21 @@ def get_file_format_spec(
file_format_str += FORMAT_NAME + EQUALS + file_format_name
file_format_str += RIGHT_PARENTHESIS
return file_format_str


def cte_statement(query: str, table_name: str) -> str:
return WITH + table_name + AS + LEFT_PARENTHESIS + query + RIGHT_PARENTHESIS


def combine_cte_statements(statements: List[str]) -> str:
# order is maintained
statements_without_dups = list(dict.fromkeys(statements))
if len(statements_without_dups) == 1:
return statements_without_dups[0]
elif len(statements_without_dups) > 1:
return (
statements_without_dups[0]
+ COMMA
+ SPACE
+ COMMA.join(s[len(WITH) :] for s in statements_without_dups[1:])
)
153 changes: 133 additions & 20 deletions src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,16 +35,19 @@
import snowflake.connector
import snowflake.snowpark
from snowflake.snowpark._internal.analyzer.analyzer_utils import (
SPACE,
aggregate_statement,
attribute_to_schema_string,
batch_insert_into_statement,
combine_cte_statements,
copy_into_location,
copy_into_table,
create_file_format_statement,
create_or_replace_dynamic_table_statement,
create_or_replace_view_statement,
create_table_as_select_statement,
create_table_statement,
cte_statement,
delete_statement,
drop_file_format_if_exists_statement,
drop_table_if_exists_statement,
Expand Down Expand Up @@ -301,14 +304,24 @@ def build(
is_ddl_on_temp_object: bool = False,
) -> SnowflakePlan:
select_child = self.add_result_scan_if_not_select(child)
queries = select_child.queries[:-1] + [
Query(
sql_generator(select_child.queries[-1].sql),
last_query = select_child.queries[-1]
if isinstance(last_query, CTEQuery):
new_query = CTEQuery(
sql_generator(last_query.select_sql),
last_query.with_sql,
query_id_place_holder="",
is_ddl_on_temp_object=is_ddl_on_temp_object,
params=select_child.queries[-1].params,
params=last_query.params,
)
]
else:
new_query = Query(
sql_generator(last_query.sql),
query_id_place_holder="",
is_ddl_on_temp_object=is_ddl_on_temp_object,
params=last_query.params,
)

queries = select_child.queries[:-1] + [new_query]
new_schema_query = (
schema_query if schema_query else sql_generator(child.schema_query)
)
Expand Down Expand Up @@ -369,21 +382,37 @@ def build_binary(
) -> SnowflakePlan:
select_left = self.add_result_scan_if_not_select(left)
select_right = self.add_result_scan_if_not_select(right)
queries = (
select_left.queries[:-1]
+ select_right.queries[:-1]
+ [
Query(
sql_generator(
select_left.queries[-1].sql, select_right.queries[-1].sql
),
params=[
*select_left.queries[-1].params,
*select_right.queries[-1].params,
],
)
]
)
select_left_last_query = select_left.queries[-1]
select_right_last_query = select_right.queries[-1]
params = [*select_left_last_query.params, *select_right_last_query.params]

is_left_cte = isinstance(select_left_last_query, CTEQuery)
is_right_cte = isinstance(select_right_last_query, CTEQuery)
if is_left_cte and is_right_cte:
sql = sql_generator(
select_left_last_query.select_sql, select_right_last_query.select_sql
)
with_sql = combine_cte_statements(
[select_left_last_query.with_sql, select_right_last_query.with_sql]
)
elif is_left_cte:
sql = sql_generator(
select_left_last_query.select_sql, select_right_last_query.sql
)
with_sql = select_left_last_query.with_sql
elif is_right_cte:
sql = sql_generator(
select_left_last_query.sql, select_right_last_query.select_sql
)
with_sql = select_right_last_query.with_sql
else:
sql = sql_generator(select_left_last_query.sql, select_right_last_query.sql)
with_sql = None
if with_sql:
last_query = CTEQuery(sql, with_sql, params=params)
else:
last_query = Query(sql, params=params)
queries = select_left.queries[:-1] + select_right.queries[:-1] + [last_query]

left_schema_query = schema_value_statement(select_left.attributes)
right_schema_query = schema_value_statement(select_right.attributes)
Expand Down Expand Up @@ -1205,6 +1234,33 @@ def add_result_scan_if_not_select(self, plan: SnowflakePlan) -> SnowflakePlan:
session=self.session,
)

def cte(self, plan: SnowflakePlan) -> SnowflakePlan:
last_query = plan.queries[-1]
# CTE only supports select query
assert is_sql_select_statement(last_query.sql)
if isinstance(last_query, CTEQuery):
# once the frontend CTE API is called, we always move the select sql to be a CTE
new_query = last_query.move_select_sql_to_cte()
else:
new_query = CTEQuery(
last_query.sql,
query_id_place_holder=last_query.query_id_place_holder,
is_ddl_on_temp_object=last_query.is_ddl_on_temp_object,
params=last_query.params,
)
queries = plan.queries[:-1] + [new_query]
return SnowflakePlan(
queries,
plan.schema_query,
plan.post_actions,
plan.expr_to_alias,
plan.source_plan,
plan.is_ddl_on_temp_object,
api_calls=plan.api_calls,
df_aliased_col_name_to_real_col_name=plan.df_aliased_col_name_to_real_col_name,
session=self.session,
)


class Query:
def __init__(
Expand Down Expand Up @@ -1250,3 +1306,60 @@ def __init__(
) -> None:
super().__init__(sql)
self.rows = rows


class CTEQuery(Query):
"""
CTEQuery class has two extra attributes, `with_sql` and `select_sql`, which are broken down
from the full sql (`sql`).
"""

def __init__(
self,
select_sql: str,
with_sql: Optional[str] = None,
*,
query_id_place_holder: Optional[str] = None,
is_ddl_on_temp_object: bool = False,
params: Optional[Sequence[Any]] = None,
) -> None:
if with_sql:
self.with_sql = with_sql
self.select_sql = select_sql
else:
# When with_sql is None, we convert the plain select sql to
# a with sql + a select * from table sql
self.with_sql, self.select_sql = CTEQuery.generate_with_and_select_sql(
select_sql
)
super().__init__(
self.with_sql + SPACE + self.select_sql,
query_id_place_holder=query_id_place_holder,
is_ddl_on_temp_object=is_ddl_on_temp_object,
params=params,
)

@staticmethod
def generate_with_and_select_sql(sql: str) -> Tuple[str, str]:
# The rewritten logic is (for example),
# select a, b from t1 ->
# with t2 as (select a, b from t1) select * from t2
temp_table_name = random_name_for_temp_object(TempObjectType.TABLE)
select_sql = project_statement([], temp_table_name)
with_sql = cte_statement(sql, temp_table_name)
return with_sql, select_sql

def move_select_sql_to_cte(self) -> "CTEQuery":
# The rewritten logic is (for example),
# with t2 as (select a, b from t1) select a, b from t2 ->
# with t3 as (select a, b from t2), t2 as (select a, b from t1) select * from t3
# and it can actually be called recursively
with_sql, select_sql = CTEQuery.generate_with_and_select_sql(self.select_sql)
with_sql = combine_cte_statements([self.with_sql, with_sql])
return CTEQuery(
select_sql,
with_sql,
query_id_place_holder=self.query_id_place_holder,
is_ddl_on_temp_object=self.is_ddl_on_temp_object,
params=self.params,
)
Original file line number Diff line number Diff line change
Expand Up @@ -154,3 +154,10 @@ def __init__(
self.file_format_name = file_format_name
self.file_format_type = file_format_type
self.copy_options = copy_options


class CTE(LogicalPlan):
def __init__(self, child: LogicalPlan) -> None:
super().__init__()
self.child = child
self.children.append(child)
5 changes: 5 additions & 0 deletions src/snowflake/snowpark/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@
SelectTableFunction,
)
from snowflake.snowpark._internal.analyzer.snowflake_plan_node import (
CTE,
CopyIntoTableNode,
Limit,
LogicalPlan,
Expand Down Expand Up @@ -3737,6 +3738,10 @@ def random_split(
]
return res_dfs

def _as_cte(self) -> "DataFrame":
"""Returns a new DataFrame where select query is rewritten using common table expressions (CTE)."""
return self._with_plan(CTE(self._plan))

@property
def queries(self) -> Dict[str, List[str]]:
"""
Expand Down
110 changes: 110 additions & 0 deletions tests/integ/test_cte.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
#
# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved.
#

import re

import pytest

from snowflake.snowpark._internal.analyzer import analyzer
from snowflake.snowpark._internal.utils import TEMP_OBJECT_NAME_PREFIX
from snowflake.snowpark.functions import col
from tests.utils import Utils

WITH = "WITH"
original_threshold = analyzer.ARRAY_BIND_THRESHOLD


@pytest.fixture(autouse=True)
def cleanup(session):
yield
analyzer.ARRAY_BIND_THRESHOLD = original_threshold


@pytest.fixture(params=[False, True])
def has_multi_queries(request):
return request.param


@pytest.fixture(scope="function")
def df(session, has_multi_queries):
# TODO SNOW-1020742: Integerate CTE support with sql simplifier
session.sql_simplifier_enabled = False
if has_multi_queries:
analyzer.ARRAY_BIND_THRESHOLD = 2
else:
analyzer.ARRAY_BIND_THRESHOLD = original_threshold
return session.create_dataframe([[1, 2], [3, 4]], schema=["a", "b"])


def is_select_star_from_table_query(query):
# check whether the query ends with `SELECT * FROM (SNOWPARK_TEMP_TABLE_XXX)`
pattern = re.compile(
rf"SELECT\s+\*\s+FROM \({TEMP_OBJECT_NAME_PREFIX}TABLE_[0-9A-Z]+\)$"
)
return bool(pattern.search(query))


def check_result(df_result, df_cte_result):
Utils.check_answer(df_result, df_cte_result)
last_query = df_cte_result.queries["queries"][-1]
assert last_query.startswith(WITH)
assert last_query.count(WITH) == 1

# move the select sql to CTE
df_cte_result2 = df_cte_result._as_cte()
Utils.check_answer(df_result, df_cte_result2)
last_query = df_cte_result2.queries["queries"][-1]
assert last_query.startswith(WITH)
assert last_query.count(WITH) == 1
assert is_select_star_from_table_query(last_query)


@pytest.mark.parametrize(
"action",
[
lambda x: x,
lambda x: x.select("a"),
lambda x: x.select("a", "b"),
lambda x: x.select("*"),
lambda x: x.select("a", "b").select("b"),
lambda x: x.filter(col("a") == 1),
lambda x: x.filter(col("a") == 1).select("b"),
lambda x: x.select("a").filter(col("a") == 1),
lambda x: x.sort("a", ascending=False),
lambda x: x.filter(col("a") == 1).sort("a"),
lambda x: x.limit(1),
lambda x: x.sort("a").limit(1),
lambda x: x.drop("b"),
lambda x: x.select("a", "b").drop("b"),
lambda x: x.agg({"a": "count", "b": "sum"}),
lambda x: x.group_by("a").min("b"),
],
)
def test_basic(session, df, action):
df_cte = df._as_cte()
df_result, df_cte_result = action(df), action(df_cte)
check_result(df_result, df_cte_result)


@pytest.mark.parametrize(
"action",
[
lambda x, y: x.union_all(y),
lambda x, y: x.select("a").union_all(y.select("a")),
lambda x, y: x.except_(y),
lambda x, y: x.select("a").except_(y.select("a")),
lambda x, y: x.join(y.select("a", "b"), rsuffix="_y"),
lambda x, y: x.select("a").join(y, rsuffix="_y"),
lambda x, y: x.join(y.select("a"), rsuffix="_y"),
],
)
def test_binary(session, df, action):
df_cte = df._as_cte()
df_result = action(df, df)
for df_cte_result in [
action(df_cte, df),
action(df, df_cte),
action(df_cte, df_cte),
]:
check_result(df_result, df_cte_result)
Loading