Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SNAP-2315] Added support for CTEs in update/delete. #1012

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -253,9 +253,59 @@ class SnappySQLQuerySuite extends SnappyFunSuite {
"(exists (select col1 from r2 where r2.col1=r1.col1) " +
"or exists(select col1 from r3 where r3.col1=r1.col1))")

val result = df.collect()
checkAnswer(df, Seq(Row(1, "1", "1", 100),
Row(2, "2", "2", 2), Row(4, "4", "4", 4) ))
snc.dropTable("r1", ifExists = true)
}

test("Delete duplicate with WITH and window function") {
val snc = new SnappySession(sc)
snc.dropTable("r1", ifExists = true)
snc.sql("create table r1(col1 INT, col2 STRING, col3 String, col4 Int)" +
" using column ")


snc.insert("r1", Row(1, "1", "1", 100))
snc.insert("r1", Row(1, "1", "1", 100))
snc.insert("r1", Row(2, "4", "4", 4))
snc.insert("r1", Row(2, "4", "4", 4))
snc.sql("WITH dups AS " +
"(SELECT col1, ROW_NUMBER() OVER" +
" (PARTITION BY col1 ORDER BY ( SELECT 0))" +
" RN FROM r1) DELETE from dups where rn > 1;")

val df = snc.sql("Select * from r1")
checkAnswer(df, Seq(Row(1, "1", "1", 100),
Row(2, "4", "4", 4)))
}

test("Update rows duplicate with WITH and window function") {
val snc = new SnappySession(sc)
snc.dropTable("r1", ifExists = true)
snc.sql("create table r1(col1 INT, col2 STRING, col3 String, col4 Int)" +
" using column ")


snc.insert("r1", Row(1, "1", "1", 100))
snc.insert("r1", Row(1, "1", "1", 100))
snc.insert("r1", Row(2, "4", "4", 4))
snc.insert("r1", Row(2, "4", "4", 4))
snc.sql("WITH dups AS " +
"(SELECT col1, ROW_NUMBER() OVER" +
" (PARTITION BY col1 ORDER BY ( SELECT 0))" +
" RN FROM r1) update dups set col1 = 99 where rn > 1;")

val df = snc.sql("Select * from r1")
checkAnswer(df, Seq(Row(99, "1", "1", 100),
Row(99, "4", "4", 4), Row(1, "1", "1", 100), Row(2, "4", "4", 4)))
}

test("netsed CTEs") {
val snc = new SnappySession(sc)
val df = snc.sql("select * from range(10) where id" +
" not in (select id from range(2) union all select id from range(2))")

df.show
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1060,7 +1060,7 @@ class SnappyParser(session: SnappySession)
protected final def ctes: Rule1[LogicalPlan] = rule {
WITH ~ ((identifier ~ AS.? ~ '(' ~ ws ~ query ~ ')' ~ ws ~>
((id: String, p: LogicalPlan) => (id, p))) + commaSep) ~
(query | insert) ~> ((r: Seq[(String, LogicalPlan)], s: LogicalPlan) =>
(query | insert | delete | update) ~> ((r: Seq[(String, LogicalPlan)], s: LogicalPlan) =>
With(s, r.map(ns => (ns._1, SubqueryAlias(ns._1, ns._2, None)))))
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -424,6 +424,25 @@ class SnappySessionState(snappySession: SnappySession)
}
}

def projectKeyAttributes(table: LogicalPlan,
newChild: LogicalPlan,
keyAttrs: Seq[NamedExpression]): (LogicalPlan, LogicalPlan) = {
val transformedChild = newChild.transformUp {
case Project(attr, ch)
if keyAttrs.forall(k => ch.output.map(_.name).contains(k.name)) =>
Project(attr ++ keyAttrs, ch)
}
val physicalTables = table.collect {
case lr@LogicalRelation(mutable: MutableRelation, _, _) => lr
}
if (physicalTables.size > 1 || physicalTables.isEmpty) {
throw new AnalysisException("You need to update/delete on one and only one mutable table." +
" If you are using a subquery/CTE in the FROM clause ensure" +
" it is only on one mutable relation")
}
(physicalTables.head, transformedChild)
}

def apply(plan: LogicalPlan): LogicalPlan = plan transform {
case c: DMLExternalTable if !c.query.resolved =>
c.copy(query = analyzeQuery(c.query))
Expand Down Expand Up @@ -469,9 +488,15 @@ class SnappySessionState(snappySession: SnappySession)
// any extra columns
val allReferences = newChild.references ++
AttributeSet(newUpdateExprs.flatMap(_.references)) ++ AttributeSet(keyAttrs)
u.copy(child = Project(newChild.output.filter(allReferences.contains), newChild),

val (physicalTable, transformedChild) = projectKeyAttributes(table, newChild, keyAttrs)

u.copy(child = Project(transformedChild.output.filter(allReferences.contains),
transformedChild),
keyColumns = keyAttrs.map(_.toAttribute),
updateColumns = updateAttrs.map(_.toAttribute), updateExpressions = newUpdateExprs)
updateColumns = updateAttrs.map(_.toAttribute),
updateExpressions = newUpdateExprs,
table = physicalTable)
}

case d@Delete(table, child, keyColumns) if keyColumns.isEmpty && child.resolved =>
Expand All @@ -480,7 +505,8 @@ class SnappySessionState(snappySession: SnappySession)
// if this is a row table with no PK, then fallback to direct execution
if (keyAttrs.isEmpty) newChild
else {
d.copy(child = Project(keyAttrs, newChild),
val (physicalTable, transformedChild) = projectKeyAttributes(table, newChild, keyAttrs)
d.copy(table = physicalTable, child = Project(keyAttrs, transformedChild),
keyColumns = keyAttrs.map(_.toAttribute))
}
case d@DeleteFromTable(_, child) if child.resolved =>
Expand Down