Skip to content

Commit

Permalink
feat: Adds MERGE INTO transform
Browse files Browse the repository at this point in the history
This commit adds the ability to convert snowflakes [MERGE INTO](https://docs.snowflake.com/en/sql-reference/sql/merge) functionality into a functional equivalent implementation in duckdb.
To do this we need to break apart the WHEN [NOT] MATCHED syntax into separate statements to be executed indepedently.

This commit only adds the transform, there is more refactoring required in fakes.py in order to handle a transform
that transforms a single expression into multiple expressions.
  • Loading branch information
jsibbison-square committed Jun 27, 2024
1 parent 0eafd8c commit b7b996b
Show file tree
Hide file tree
Showing 2 changed files with 154 additions and 0 deletions.
63 changes: 63 additions & 0 deletions fakesnow/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -652,6 +652,69 @@ def json_extract_precedence(expression: exp.Expression) -> exp.Expression:
return expression


def merge(expression: exp.Expression) -> list[exp.Expression]:
"""Create multiple compatible duckdb statements to be functionally equivalent to Snowflake's MERGE INTO.
Snowflake's MERGE INTO: See https://docs.snowflake.com/en/sql-reference/sql/merge.html
"""

if isinstance(expression, exp.Merge):
output_expressions = []
target_table = expression.this
source_table = expression.args.get("using")
on_expression = expression.args.get("on")
whens = expression.expressions
for w in whens:
assert isinstance(w, exp.When), f"Expected When expression, got {w}"

and_condition = w.args.get("condition")
subquery_on_expression = on_expression.copy()
if and_condition:
subquery_on_expression = exp.And(this=subquery_on_expression, expression=and_condition)
subquery = exp.Exists(
this=exp.Select(expressions=[exp.Star()])
.from_(source_table)
.join(target_table, on=subquery_on_expression)
)

matched = w.args.get("matched")
then = w.args.get("then")
if matched:
if isinstance(then, exp.Update):

def remove_source_alias(eq_exp: exp.EQ) -> exp.EQ:
eq_exp.args.get("this").set("table", None)
eq_exp.set("this", exp.Column(this=eq_exp.args.get("this"), table=None))
return eq_exp

then.set("this", target_table)
then.set(
"expressions",
exp.Set(expressions=[remove_source_alias(e) for e in then.args.get("expressions")]),
)
then.set("from", exp.From(this=source_table))
then.set("where", exp.Where(this=subquery))
output_expressions.append(then)
elif then.args.get("this") == "DELETE": # Var(this=DELETE) when processing WHEN MATCHED THEN DELETE.
output_expressions.append(exp.Delete(this=target_table).where(subquery))
else:
assert isinstance(then, (exp.Update, exp.Delete)), f"Expected 'Update' or 'Delete', got {then}"
else:
assert isinstance(then, exp.Insert), f"Expected 'Insert', got {then}"
not_exists_subquery = exp.Not(this=subquery)

statement = exp.Insert(
this=exp.Schema(this=target_table, expressions=then.args.get("this").expressions),
expression=exp.Select()
.select(*(then.args.get("expression").args.get("expressions")))
.from_(source_table)
.where(not_exists_subquery),
)
output_expressions.append(statement)
return output_expressions
else:
return [expression]


def random(expression: exp.Expression) -> exp.Expression:
"""Convert random() and random(seed).
Expand Down
91 changes: 91 additions & 0 deletions tests/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
json_extract_cased_as_varchar,
json_extract_cast_as_varchar,
json_extract_precedence,
merge,
object_construct,
random,
regex_replace,
Expand Down Expand Up @@ -550,6 +551,96 @@ def test_json_extract_precedence() -> None:
)


def test_merge_update_insert() -> None:
expression = sqlglot.parse_one("""
MERGE INTO table1 AS T
USING table2 AS S
ON T.id = S.id AND T.blah = S.blah AND T.foo = S.foo
WHEN MATCHED THEN
UPDATE SET T.name = S.name, T.version = S.version
WHEN NOT MATCHED THEN
INSERT (id, name) VALUES (S.id, S.name)
""")

expressions = merge(expression)
assert len(expressions) == 2
assert (
expressions[0].sql()
== "UPDATE table1 AS T SET name = S.name, version = S.version FROM table2 AS S WHERE EXISTS(SELECT * FROM table2 AS S JOIN table1 AS T ON T.id = S.id AND T.blah = S.blah AND T.foo = S.foo)" # noqa: E501
)
assert (
expressions[1].sql()
== "INSERT INTO table1 AS T (id, name) SELECT S.id, S.name FROM table2 AS S WHERE NOT EXISTS(SELECT * FROM table2 AS S JOIN table1 AS T ON T.id = S.id AND T.blah = S.blah AND T.foo = S.foo)" # noqa: E501
)


def test_merge_update_insert_and() -> None:
expression = sqlglot.parse_one("""
MERGE INTO table1 AS T
USING table2 AS S
ON T.id = S.id AND T.blah = S.blah
WHEN MATCHED AND T.foo = S.foo THEN
UPDATE SET T.name = S.name, T.version = S.version
WHEN NOT MATCHED AND T.foo = S.foo THEN
INSERT (id, name) VALUES (S.id, S.name)
""")

expressions = merge(expression)
assert len(expressions) == 2
assert (
expressions[0].sql()
== "UPDATE table1 AS T SET name = S.name, version = S.version FROM table2 AS S WHERE EXISTS(SELECT * FROM table2 AS S JOIN table1 AS T ON T.id = S.id AND T.blah = S.blah AND T.foo = S.foo)" # noqa: E501
)
assert (
expressions[1].sql()
== "INSERT INTO table1 AS T (id, name) SELECT S.id, S.name FROM table2 AS S WHERE NOT EXISTS(SELECT * FROM table2 AS S JOIN table1 AS T ON T.id = S.id AND T.blah = S.blah AND T.foo = S.foo)" # noqa: E501
)


def test_merge_delete_insert() -> None:
expression = sqlglot.parse_one("""
MERGE INTO table1 AS T
USING table2 AS S
ON T.id = S.id AND T.blah = S.blah AND T.foo = S.foo
WHEN MATCHED THEN DELETE
WHEN NOT MATCHED THEN
INSERT (id, name) VALUES (S.id, S.name)
""")

expressions = merge(expression)
assert len(expressions) == 2
assert (
expressions[0].sql()
== "DELETE FROM table1 AS T WHERE EXISTS(SELECT * FROM table2 AS S JOIN table1 AS T ON T.id = S.id AND T.blah = S.blah AND T.foo = S.foo)" # noqa: E501
)
assert (
expressions[1].sql()
== "INSERT INTO table1 AS T (id, name) SELECT S.id, S.name FROM table2 AS S WHERE NOT EXISTS(SELECT * FROM table2 AS S JOIN table1 AS T ON T.id = S.id AND T.blah = S.blah AND T.foo = S.foo)" # noqa: E501
)


def test_merge_delete_insert_and() -> None:
expression = sqlglot.parse_one("""
MERGE INTO table1 AS T
USING table2 AS S
ON T.id = S.id AND T.blah = S.blah
WHEN MATCHED AND T.foo = S.foo THEN DELETE
WHEN NOT MATCHED AND T.foo = S.foo THEN
INSERT (id, name) VALUES (S.id, S.name)
""")

expressions = merge(expression)
assert len(expressions) == 2
assert (
expressions[0].sql()
== "DELETE FROM table1 AS T WHERE EXISTS(SELECT * FROM table2 AS S JOIN table1 AS T ON T.id = S.id AND T.blah = S.blah AND T.foo = S.foo)" # noqa: E501
)
assert (
expressions[1].sql()
== "INSERT INTO table1 AS T (id, name) SELECT S.id, S.name FROM table2 AS S WHERE NOT EXISTS(SELECT * FROM table2 AS S JOIN table1 AS T ON T.id = S.id AND T.blah = S.blah AND T.foo = S.foo)" # noqa: E501
)


def test_object_construct() -> None:
assert (
sqlglot.parse_one(
Expand Down

0 comments on commit b7b996b

Please sign in to comment.