Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Adds MERGE INTO transform #109

Merged
merged 32 commits into from
Sep 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
c842900
feat: Adds MERGE INTO transform
jsibbison-square Jun 27, 2024
ab86bf1
Merge branch 'main' into jsibbison-20240627-merge-into
tekumara Jun 29, 2024
40ada94
add test_merge using test case from snowflake docs
tekumara Jun 29, 2024
13a64cb
Merge branch 'main' into jsibbison-20240627-merge-into
jsibbison-square Jul 11, 2024
2eecaea
WIP: working on merge
jsibbison-square Jul 11, 2024
d5d72a2
Merge branch 'main' into jsibbison-20240627-merge-into
jsibbison-square Jul 11, 2024
a580635
implements the return values for merge statement
jsibbison-square Jul 14, 2024
cd38095
Removes old merge implementation
jsibbison-square Jul 15, 2024
0cbeb50
Merge branch 'main' into jsibbison-20240627-merge-into
jsibbison-square Jul 22, 2024
4310802
Adds test for NOT MATCHED INSERT with an AND condition
jsibbison-square Jul 22, 2024
10aa014
Merge branch 'main' into jsibbison-20240627-merge-into
jsibbison-square Aug 28, 2024
2ded345
Ruff fixes
jsibbison-square Aug 28, 2024
71361fd
ruff format
jsibbison-square Aug 28, 2024
dd8923c
fix typos
jsibbison-square Aug 28, 2024
533ca68
Fix pyright in cursor.py
jsibbison-square Aug 28, 2024
8f19190
Extract merge transform to separate file for further refactoring
jsibbison-square Aug 30, 2024
d5c3e88
Extract some outer expressions
jsibbison-square Aug 30, 2024
de6385b
extract remove_table_alias helper
jsibbison-square Aug 30, 2024
d85a7b2
Extract temp merge insertion func
jsibbison-square Aug 30, 2024
b39e2e5
Merge branch 'main' into jsibbison-20240627-merge-into
jsibbison-square Sep 4, 2024
4cd831c
Extract merge tests from test_fakes
jsibbison-square Sep 4, 2024
71457e8
Add documentation to the contents of the merge transformation
jsibbison-square Sep 5, 2024
6905ac3
some ruff fixes
jsibbison-square Sep 5, 2024
7cf899a
more ruff fixes
jsibbison-square Sep 5, 2024
ac83fed
pyright fixes
jsibbison-square Sep 5, 2024
97f6e2d
ruff format fix
jsibbison-square Sep 5, 2024
6763398
remove todo
jsibbison-square Sep 5, 2024
0c71782
Remove debugging print
jsibbison-square Sep 5, 2024
3e3e9c3
add transform test
tekumara Sep 7, 2024
63948c1
refactor class to pure functions
tekumara Sep 7, 2024
dfa68e1
unalias
tekumara Sep 8, 2024
458b545
Revert "unalias"
tekumara Sep 8, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 13 additions & 2 deletions fakesnow/cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,7 @@ def execute(
) -> FakeSnowflakeCursor:
try:
self._sqlstate = None
last_execute_result = None

if os.environ.get("FAKESNOW_DEBUG") == "snowflake":
print(f"{command};{params=}" if params else f"{command};", file=sys.stderr)
Expand All @@ -139,10 +140,15 @@ def execute(
command, params = self._rewrite_with_params(command, params)
if self._conn.nop_regexes and any(re.match(p, command, re.IGNORECASE) for p in self._conn.nop_regexes):
transformed = transforms.SUCCESS_NOP
last_execute_result = self._execute(transformed, params)
else:
expression = parse_one(command, read="snowflake")
transformed = self._transform(expression)
return self._execute(transformed, params)
for exp in self._split_transform(expression):
transformed = self._transform(exp)
last_execute_result = self._execute(transformed, params)

assert last_execute_result is not None
return last_execute_result
except snowflake.connector.errors.ProgrammingError as e:
self._sqlstate = e.sqlstate
raise e
Expand Down Expand Up @@ -207,6 +213,11 @@ def _transform(self, expression: exp.Expression) -> exp.Expression:
.transform(transforms.alter_table_strip_cluster_by)
)

def _split_transform(self, expression: exp.Expression) -> list[exp.Expression]:
# Applies transformations that require splitting the expression into multiple expressions
# Split transforms have limited support at the moment.
return transforms.merge(expression)

def _execute(
self, transformed: exp.Expression, params: Sequence[Any] | dict[Any, Any] | None = None
) -> FakeSnowflakeCursor:
Expand Down
5 changes: 5 additions & 0 deletions fakesnow/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import sqlglot
from sqlglot import exp

from fakesnow import transforms_merge
from fakesnow.instance import USERS_TABLE_FQ_NAME
from fakesnow.variables import Variables

Expand Down Expand Up @@ -691,6 +692,10 @@ def json_extract_precedence(expression: exp.Expression) -> exp.Expression:
return expression


def merge(expression: exp.Expression) -> list[exp.Expression]:
return transforms_merge.merge(expression)


def random(expression: exp.Expression) -> exp.Expression:
"""Convert random() and random(seed).

Expand Down
257 changes: 257 additions & 0 deletions fakesnow/transforms_merge.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,257 @@
import sqlglot
from sqlglot import exp

# Implements snowflake's MERGE INTO functionality in duckdb (https://docs.snowflake.com/en/sql-reference/sql/merge).

TEMP_MERGE_UPDATED_DELETES = "temp_merge_updates_deletes"
TEMP_MERGE_INSERTS = "temp_merge_inserts"


def _target_table(merge_expr: exp.Expression) -> exp.Expression:
target_table = merge_expr.this
if target_table is None:
raise ValueError("Target table expression is None")
return target_table


def _source_table(merge_expr: exp.Expression) -> exp.Expression:
source_table = merge_expr.args.get("using")
if source_table is None:
raise ValueError("Source table expression is None")
return source_table


def _merge_on_expr(merge_expr: exp.Expression) -> exp.Expression:
# Get the ON expression from the merge operation (Eg. MERGE INTO t1 USING t2 ON t1.t1Key = t2.t2Key)
# which is applied to filter the source and target tables.
on_expr = merge_expr.args.get("on")
if on_expr is None:
raise ValueError("Merge ON expression is None")
return on_expr


def _create_temp_tables() -> list[exp.Expression]:
# Creates temp tables to store the source rows and the modifications to apply to those rows.
return [
# Create temp table for update and delete operations
sqlglot.parse_one(
f"CREATE OR REPLACE TEMP TABLE {TEMP_MERGE_UPDATED_DELETES} "
+ "(target_rowid INTEGER, when_id INTEGER, type CHAR(1))"
),
# Create temp table for insert operations
sqlglot.parse_one(f"CREATE OR REPLACE TEMP TABLE {TEMP_MERGE_INSERTS} (source_rowid INTEGER, when_id INTEGER)"),
]


def _generate_final_expression_set(
temp_table_inserts: list[exp.Expression], output_expressions: list[exp.Expression]
) -> list[exp.Expression]:
# Collects all of the operations together to be executed in a single transaction.
# Operate in a transaction to freeze rowids https://duckdb.org/docs/sql/statements/select#row-ids
begin_transaction_exp = sqlglot.parse_one("BEGIN TRANSACTION")
end_transaction_exp = sqlglot.parse_one("COMMIT")

# Add modifications results outcome query
results_exp = sqlglot.parse_one(f"""
WITH merge_update_deletes AS (
select count_if(type == 'U')::int AS "updates", count_if(type == 'D')::int as "deletes"
from {TEMP_MERGE_UPDATED_DELETES}
), merge_inserts AS (select count() AS "inserts" from {TEMP_MERGE_INSERTS})
SELECT mi.inserts as "number of rows inserted",
mud.updates as "number of rows updated",
mud.deletes as "number of rows deleted"
from merge_update_deletes mud, merge_inserts mi
""")

return [
begin_transaction_exp,
*temp_table_inserts,
*output_expressions,
end_transaction_exp,
results_exp,
]


def _insert_temp_merge_operation(
op_type: str, when_idx: int, subquery: exp.Expression, target_table: exp.Expression
) -> exp.Expression:
assert op_type in {
"U",
"D",
}, f"Expected 'U' or 'D', got merge op_type: {op_type}" # Updates/Deletes
return exp.insert(
into=TEMP_MERGE_UPDATED_DELETES,
expression=exp.select("rowid", exp.Literal.number(when_idx), exp.Literal.string(op_type))
.from_(target_table)
.where(subquery),
)


def _remove_table_alias(eq_exp: exp.Condition) -> exp.Condition:
eq_exp.set("table", None)
return eq_exp


def merge(merge_expr: exp.Expression) -> list[exp.Expression]:
"""
Create multiple compatible duckdb statements to be functionally equivalent to Snowflake's MERGE INTO.
Snowflake's MERGE INTO: See https://docs.snowflake.com/en/sql-reference/sql/merge.html
"""
# Breaking down how we implement MERGE:
# Insert into a temp table the source rows (rowid is stable in a transaction: https://duckdb.org/docs/sql/statements/select.html#row-ids)
# and which modification to apply to those source rows (insert, update, delete).
#
# Then apply the modifications to the target table based on the rowids in the temp tables.

# TODO: Error if attempting to update the same source row multiple times (based on a config in the doco).
if not isinstance(merge_expr, exp.Merge):
return [merge_expr]

temp_table_inserts = _create_temp_tables()
output_expressions = []

target_tbl = _target_table(merge_expr)
source_tbl = _source_table(merge_expr)
on_expr = _merge_on_expr(merge_expr)

whens = merge_expr.expressions
# Loop through each WHEN clause
for w_idx, w in enumerate(whens):
assert isinstance(w, exp.When), f"Expected When expression, got {w}"

# Combine the top level ON expression with the AND condition
# from this specific WHEN into a subquery, we use to target rows.
# Eg. # MERGE INTO t1 USING t2 ON t1.t1Key = t2.t2Key
# WHEN MATCHED AND t2.marked = 1 THEN DELETE
and_condition = w.args.get("condition")
subquery_on_expression = on_expr.copy()
if and_condition:
subquery_on_expression = exp.And(this=subquery_on_expression, expression=and_condition)

matched = w.args.get("matched")
then = w.args.get("then")
assert then is not None, "then is None"
# Handling WHEN MATCHED AND <Condition> THEN DELETE / UPDATE SET <Updates>
if matched:
# Ensuring rows already exist in temporary table for
# previous WHEN clauses are not added again
not_in_temp_table_subquery = exp.Not(
this=exp.Exists(
this=exp.select(exp.Literal.number(1))
.from_(TEMP_MERGE_UPDATED_DELETES)
.where(
exp.EQ(
this=exp.Column(this="rowid", table=target_tbl),
expression=exp.Column(this="target_rowid"),
)
)
)
)

# Query finding rows that match the original ON condition and this WHEN condition
subquery_ignoring_temp_table = exp.Exists(
this=exp.select(exp.Literal.number(1)).from_(source_tbl).where(subquery_on_expression)
)
# Include both of the above subqueries in the final subquery
subquery = exp.And(this=subquery_ignoring_temp_table, expression=not_in_temp_table_subquery)

# Select the rowids of the rows to apply the operation to
rowid_in_temp_table_expr = exp.In(
this=exp.Column(this="rowid", table=target_tbl),
expressions=[
exp.select("target_rowid")
.from_(TEMP_MERGE_UPDATED_DELETES)
.where(exp.EQ(this="when_id", expression=exp.Literal.number(w_idx)))
.where(exp.EQ(this="target_rowid", expression=exp.Column(this="rowid", table=target_tbl)))
],
)
if isinstance(then, exp.Update):
# Insert into the temp table the rowids of the rows that match the WHEN condition
temp_table_inserts.append(_insert_temp_merge_operation("U", w_idx, subquery, target_tbl))

# Build the UPDATE statement to apply to the target table for this specific WHEN clause sourcing
# its target rows from the temp table that we just inserted into.
then.set("this", target_tbl)
then_exprs = then.args.get("expressions")
assert then_exprs is not None, "then_exprs is None"
then.set(
"expressions",
exp.Set(expressions=[_remove_table_alias(e) for e in then_exprs]),
)
then.set("from", exp.From(this=source_tbl))
then.set(
"where",
exp.Where(this=exp.And(this=subquery_on_expression, expression=rowid_in_temp_table_expr)),
)
output_expressions.append(then)

# Handling WHEN MATCHED AND <Condition> THEN DELETE
elif then.args.get("this") == "DELETE":
# Insert into the temp table the rowids of the rows that match the WHEN condition
temp_table_inserts.append(_insert_temp_merge_operation("D", w_idx, subquery, target_tbl))

# Build the DELETE statement to apply to the target table for this specific WHEN clause sourcing
# its target rows from the temp table that we just inserted into.
delete_from_temp = exp.delete(table=target_tbl, where=rowid_in_temp_table_expr)
output_expressions.append(delete_from_temp)
else:
assert isinstance(then, (exp.Update, exp.Delete)), f"Expected 'Update' or 'Delete', got {then}"

# Handling WHEN NOT MATCHED THEN INSERT (<Columns>) VALUES (<Values>)
else:
assert isinstance(then, exp.Insert), f"Expected 'Insert', got {then}"
# Query ensuring rows already exist in the temporary table for
# previous WHEN clauses are not added again
not_in_temp_table_subquery = exp.Not(
this=exp.Exists(
this=exp.select(exp.Literal.number(1))
.from_(TEMP_MERGE_INSERTS)
.where(
exp.EQ(
this=exp.Column(this="rowid", table=source_tbl),
expression=exp.Column(this="source_rowid"),
)
)
)
)
# Query finding rows that match the original ON condition and this WHEN condition
subquery_ignoring_temp_table = exp.Exists(
this=exp.select(exp.Literal.number(1)).from_(target_tbl).where(on_expr)
)
subquery = exp.And(this=subquery_ignoring_temp_table, expression=not_in_temp_table_subquery)

not_exists_subquery = exp.Not(this=subquery)
temp_match_where = (
exp.And(this=and_condition, expression=not_exists_subquery) if and_condition else not_exists_subquery
)
temp_match_expr = exp.insert(
into=TEMP_MERGE_INSERTS,
expression=exp.select("rowid", exp.Literal.number(w_idx)).from_(source_tbl).where(temp_match_where),
)
temp_table_inserts.append(temp_match_expr)

rowid_in_temp_table_expr = exp.In(
this=exp.Column(this="rowid", table=source_tbl),
expressions=[
exp.select("source_rowid")
.from_(TEMP_MERGE_INSERTS)
.where(exp.EQ(this="when_id", expression=exp.Literal.number(w_idx)))
.where(exp.EQ(this="source_rowid", expression=exp.Column(this="rowid", table=source_tbl)))
],
)

this_expr = then.args.get("this")
assert this_expr is not None, "this_expr is None"
then_exprs = then.args.get("expression")
assert then_exprs is not None, "then_exprs is None"
columns = [_remove_table_alias(e) for e in this_expr.expressions]
statement = exp.insert(
into=target_tbl,
columns=[c.this for c in columns],
expression=exp.select(*(then_exprs.args.get("expressions")))
.from_(source_tbl)
.where(rowid_in_temp_table_expr),
)
output_expressions.append(statement)

return _generate_final_expression_set(temp_table_inserts, output_expressions)
Loading