From 9bf55c213dd5e054815a08c35b587e74640630bb Mon Sep 17 00:00:00 2001 From: Alexey Romanov Date: Thu, 15 Dec 2016 14:57:37 +0300 Subject: [PATCH] Rewrite away vacuous sort/partialSort immediately (fixes #8) --- .../src/main/scala/scalan/sql/Iters.scala | 39 ++++++++++++++++++- .../scala/scalan.sql/SqlBridgeTests.scala | 6 +++ 2 files changed, 44 insertions(+), 1 deletion(-) diff --git a/scalan-sql-core/src/main/scala/scalan/sql/Iters.scala b/scalan-sql-core/src/main/scala/scalan/sql/Iters.scala index a310caa..b89384e 100644 --- a/scalan-sql-core/src/main/scala/scalan/sql/Iters.scala +++ b/scalan-sql-core/src/main/scala/scalan/sql/Iters.scala @@ -82,7 +82,7 @@ trait Iters extends ScalanDsl { def partialSort(prefixComparator: Rep[((Row, Row)) => Boolean], suffixComparator: Rep[((Row, Row)) => Boolean]): RIter[Row] = delayInvoke - // if `leftIsOuter` is true, `other` will be hashed; otherwise, `this` will be + // when adding joinType make sure hasAtMostOneRow checks it's Inner def join[B, Key](other: RIter[B], thisKey: Rep[Row => Key], otherKey: Rep[B => Key], cloneOther: Rep[B => B]/*, joinType: JoinType*/): RIter[(Row, B)] = delayInvoke def toArray: Arr[Row] = delayInvoke @@ -371,6 +371,36 @@ trait ItersDslExp extends impl.ItersExp { self: ScalanSqlExp => super.getResultElem(receiver, m, args) } + def hasAtMostOneRow(dIterOrRelation: Def[_]): Boolean = dIterOrRelation match { + case _: SingletonIter[_] | _: EmptyIter[_] => + true + case IterMethods.reduce(_, _, _) => + true + case MethodCall(_, m, _, _) if m.getName.startsWith("uniqueBy") => + true + // TODO add case for unique Scannable#search (maybe it should be a separate method) + case MethodCall(iterOrRelation, m, _, _) if { + val name = m.getName + name == "map" || name == "filter" || name == "takeWhile" || name == "mapReduce" || + name == "partialMapReduce" + } => + hasAtMostOneRow(iterOrRelation) + case ExpConditionalIter(_, iterOrRelation) => + hasAtMostOneRow(iterOrRelation) + case IterMethods.join(iterOrRelation1, iterOrRelation2, _, _, _) => + hasAtMostOneRow(iterOrRelation1) && hasAtMostOneRow(iterOrRelation2) + // this case could go into Relations.scala, but no point splitting like this + case RelationMethods.hashJoin(iterOrRelation1, iterOrRelation2, _, _, _) => + hasAtMostOneRow(iterOrRelation1) && hasAtMostOneRow(iterOrRelation2) + case _ => false + } + + def hasAtMostOneRow(iterOrRelation: Exp[_]): Boolean = iterOrRelation match { + case Def(dIterOrRelation) => + hasAtMostOneRow(dIterOrRelation) + case _ => false + } + // hacky, but should be equivalent to building the lambda using `fun` as normal def copyLambda[A, B, C](l: Lambda[A, B], v: Rep[C]): Rep[A => C] = { val x = l.x @@ -544,6 +574,13 @@ trait ItersDslExp extends impl.ItersExp { self: ScalanSqlExp => case _ => super.rewriteDef(d) } + case MethodCall(receiver, m, _, _) if { + val name = m.getName + (name == "sort" || name == "sortBy" || name == "partialSort") && + hasAtMostOneRow(receiver) + } => + receiver + case TableIterMethods.byRowids(iter, Def(ExpSingletonIter(value: Rep[a])), f) => iter.uniqueByRowid(f.asRep[a => Rowid](value)) case TableIterMethods.byRowids(iter, Def(ExpConditionalIter(c, baseIter: RIter[a] @unchecked)), f) => diff --git a/scalan-sql-core/src/test/scala/scalan.sql/SqlBridgeTests.scala b/scalan-sql-core/src/test/scala/scalan.sql/SqlBridgeTests.scala index a404e2d..f8cbbdb 100644 --- a/scalan-sql-core/src/test/scala/scalan.sql/SqlBridgeTests.scala +++ b/scalan-sql-core/src/test/scala/scalan.sql/SqlBridgeTests.scala @@ -154,6 +154,12 @@ abstract class AbstractSqlBridgeTests extends BaseNestedTests { describe("vacuous order by") { // queries where results have at most one row and so shouldn't contain sort + it("rowid") { + // TODO currently o_custkey is read from the table even though it won't be used in the end, + // but fixing this is non-trivial and low priority + testQuery("SELECT o_totalprice FROM orders WHERE o_orderkey = 1000 ORDER BY o_custkey") + } + it("join and aggregate") { pendingUntilFixed { // there is currently a resolution error