Skip to content

Commit

Permalink
feat(spark): add support for ‘offset’ clause
Browse files Browse the repository at this point in the history
Add missing support for the ‘offset’ clause in the spark module.

Signed-off-by: Andrew Coleman <[email protected]>
  • Loading branch information
andrew-coleman committed Oct 4, 2024
1 parent eec9727 commit 3498d44
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 17 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -160,11 +160,20 @@ class ToLogicalPlan(spark: SparkSession) extends DefaultRelVisitor[LogicalPlan]
}
/**
 * Converts a Substrait `Fetch` relation into Spark limit/offset operators.
 *
 * A non-negative count becomes a `GlobalLimit`. When the offset is also positive, the
 * limit wraps an `Offset` over a `LocalLimit(offset + count)`. A negative or absent count
 * is treated as "no limit", so only an `Offset` node is produced.
 *
 * @param fetch the Substrait fetch relation to convert
 * @return the equivalent Spark logical plan
 * @throws ArithmeticException if the count, the offset, or their sum does not fit in an Int
 */
override def visit(fetch: relation.Fetch): LogicalPlan = {
  val child = fetch.getInput.accept(this)
  // The count is optional: fall back to -1 ("no limit") instead of letting getAsLong
  // throw NoSuchElementException on a plan that carries only an offset.
  val limit = Math.toIntExact(fetch.getCount.orElse(-1L))
  // The limit expressions built here are Int literals, so fail loudly on values that
  // would otherwise be silently truncated.
  val offset = Math.toIntExact(fetch.getOffset)
  val toLiteral = (i: Int) => Literal(i, IntegerType)
  if (limit >= 0) {
    val limitExpr = toLiteral(limit)
    if (offset > 0) {
      // The local limit must admit the skipped rows as well as the returned ones;
      // addExact guards against the sum wrapping around to a negative value.
      val localLimit = toLiteral(Math.addExact(offset, limit))
      GlobalLimit(limitExpr, Offset(toLiteral(offset), LocalLimit(localLimit, child)))
    } else {
      GlobalLimit(limitExpr, LocalLimit(limitExpr, child))
    }
  } else {
    Offset(toLiteral(offset), child)
  }
}
override def visit(sort: relation.Sort): LogicalPlan = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -170,23 +170,37 @@ class ToSubstraitRel extends AbstractLogicalPlanVisitor with Logging {
case other => throw new UnsupportedOperationException(s"Unknown type: $other")
}

/**
 * Builds a Substrait `Fetch` relation over the converted `child` plan.
 *
 * @param child  Spark plan supplying the input rows; converted with `visit`
 * @param offset value stored as the fetch offset
 * @param limit  value stored as the fetch count; defaults to -1, which is what
 *               `visitOffset` relies on when there is an offset but no limit
 * @return the built fetch relation
 */
private def fetch(child: LogicalPlan, offset: Long, limit: Long = -1): relation.Fetch =
  relation.Fetch
    .builder()
    .input(visit(child))
    .offset(offset)
    .count(limit)
    .build()

/**
 * Converts a Spark `GlobalLimit` into a Substrait fetch.
 *
 * Two plan shapes are accepted: the limit-with-offset shape matched by the
 * `OffsetAndLimit` extractor, and a `GlobalLimit` directly over a `LocalLimit` carrying
 * the same literal limit. Any other shape is rejected.
 *
 * @param p the global limit node to convert
 * @return a fetch relation over the converted child
 * @throws UnsupportedOperationException if `p` matches neither supported shape
 */
override def visitGlobalLimit(p: GlobalLimit): relation.Rel = p match {
  case OffsetAndLimit((offset, limit, child)) =>
    fetch(child, offset, limit)
  case GlobalLimit(IntegerLiteral(globalLimit), LocalLimit(IntegerLiteral(localLimit), child))
      if globalLimit == localLimit =>
    // Plain LIMIT: both limit nodes collapse into a single fetch with no offset.
    fetch(child, 0, localLimit)
  case _ =>
    throw new UnsupportedOperationException(s"Unable to convert the limit expression: $p")
}

/**
 * Converts a Spark `LocalLimit` into a Substrait fetch.
 *
 * When the child is itself a limit whose bound is no larger than this local limit, the
 * fetch is built from the child's bound and input, so this node contributes nothing.
 * Otherwise the local limit becomes a fetch with offset 0.
 *
 * @param p the local limit node to convert
 * @return a fetch relation describing the limit
 */
override def visitLocalLimit(p: LocalLimit): relation.Rel = {
  val localLimit = asLong(p.limitExpr)
  p.child match {
    case OffsetAndLimit((offset, limit, child)) if localLimit >= limit =>
      fetch(child, offset, limit)
    case GlobalLimit(IntegerLiteral(globalLimit), child) if localLimit >= globalLimit =>
      fetch(child, 0, globalLimit)
    case _ =>
      fetch(p.child, 0, localLimit)
  }
}

/**
 * Converts a Spark `Offset` node into a Substrait fetch that carries only an offset;
 * the count is left at the default supplied by `fetch`.
 *
 * @param p the offset node to convert
 * @return a fetch relation over the converted child
 */
override def visitOffset(p: Offset): relation.Rel =
  fetch(p.child, asLong(p.offsetExpr))

override def visitFilter(p: Filter): relation.Rel = {
Expand Down
8 changes: 7 additions & 1 deletion spark/src/test/scala/io/substrait/spark/TPCHPlan.scala
Original file line number Diff line number Diff line change
Expand Up @@ -73,10 +73,16 @@ class TPCHPlan extends TPCHBase with SubstraitPlanTestBase {
"order by l_shipdate asc, l_discount desc nulls last")
}

// Round-trips the same ordered query through Substrait with each limit/offset combination.
test("simpleOffsetClause") {
  // LIMIT together with OFFSET
  assertSqlSubstraitRelRoundTrip(
    "select l_partkey from lineitem where l_shipdate < date '1998-01-01' " +
      "order by l_shipdate asc, l_discount desc limit 100 offset 1000")
  // OFFSET on its own
  assertSqlSubstraitRelRoundTrip(
    "select l_partkey from lineitem where l_shipdate < date '1998-01-01' " +
      "order by l_shipdate asc, l_discount desc offset 1000")
  // LIMIT on its own
  assertSqlSubstraitRelRoundTrip(
    "select l_partkey from lineitem where l_shipdate < date '1998-01-01' " +
      "order by l_shipdate asc, l_discount desc limit 100")
}

test("simpleTest") {
Expand Down

0 comments on commit 3498d44

Please sign in to comment.