Skip to content

Commit

Permalink
feat(spark): add support for ‘offset’ clause
Browse files Browse the repository at this point in the history
Add missing support for the ‘offset’ clause in the spark module.

Signed-off-by: Andrew Coleman <[email protected]>
  • Loading branch information
andrew-coleman committed Oct 3, 2024
1 parent eec9727 commit 1e80765
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 18 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.util.toPrettySQL
import org.apache.spark.sql.execution.QueryExecution
import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, InMemoryFileIndex, LogicalRelation}
import org.apache.spark.sql.execution.datasources.csv.CSVFileFormat
import org.apache.spark.sql.types.{DataTypes, IntegerType, StructField, StructType}
import org.apache.spark.sql.types.{DataTypes, IntegerType, LongType, StructField, StructType}
import io.substrait.`type`.{StringTypeVisitor, Type}
import io.substrait.{expression => exp}
import io.substrait.expression.{Expression => SExpression}
Expand Down Expand Up @@ -160,11 +160,19 @@ class ToLogicalPlan(spark: SparkSession) extends DefaultRelVisitor[LogicalPlan]
}
/**
 * Converts a Substrait `Fetch` relation into a Spark logical plan.
 *
 * NOTE(review): this listing looks like a diff whose +/- markers were lost, so two
 * implementations are interleaved: `limit` is declared twice and the braces do not
 * balance. Presumably the first `val limit` and the `fetch.getOffset match` arms are
 * the removed version, and everything from the second `val limit` onwards is the
 * replacement — confirm against the repository.
 *
 * Plan shapes built by the if/else below:
 *  - count >= 0 and offset > 0:
 *      GlobalLimit(count, Offset(offset, LocalLimit(offset + count, child)))
 *  - count >= 0 and offset <= 0:
 *      GlobalLimit(count, LocalLimit(count, child))
 *  - count < 0:
 *      Offset(offset, child), i.e. no limit node is produced
 */
override def visit(fetch: relation.Fetch): LogicalPlan = {
val child = fetch.getInput.accept(this)
val limit = Literal(fetch.getCount.getAsLong.intValue(), IntegerType)
fetch.getOffset match {
case 1L => GlobalLimit(limitExpr = limit, child = child)
case -1L => LocalLimit(limitExpr = limit, child = child)
case _ => visitFallback(fetch)
// NOTE(review): `getAsLong` presumably throws when the Fetch carries no count (if
// `getCount` is a java.util.OptionalLong) — confirm. `orElse(-1)` would route a
// missing count to the offset-only branch below.
// Count and offset are narrowed to Int with intValue(); no range check is done here.
val limit = fetch.getCount.getAsLong.intValue()
val offset = fetch.getOffset.intValue()
if (limit >= 0) {
val limitExpr = Literal(limit, IntegerType)
if (offset > 0) {
// The inner LocalLimit keeps offset + limit rows so that `limit` rows remain once
// Offset has dropped the first `offset` rows.
// NOTE(review): `offset + limit` is Int addition and is not guarded against overflow.
GlobalLimit(limitExpr,
Offset(Literal(offset, IntegerType),
LocalLimit(Literal(offset + limit, IntegerType), child)))
} else {
GlobalLimit(limitExpr, LocalLimit(limitExpr, child))
}
} else {
// A negative count is treated as "no limit": only the offset is applied.
Offset(Literal(offset, IntegerType), child)
}
}
override def visit(sort: relation.Sort): LogicalPlan = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -170,23 +170,37 @@ class ToSubstraitRel extends AbstractLogicalPlanVisitor with Logging {
case other => throw new UnsupportedOperationException(s"Unknown type: $other")
}

private def fetchBuilder(limit: Long, global: Boolean): relation.ImmutableFetch.Builder = {
val offset = if (global) 1L else -1L
relation.Fetch
.builder()
.count(limit)
/**
 * Builds a Substrait `Fetch` whose input is the converted `child` plan.
 *
 * @param child  Spark plan that is converted with `visit` and used as the fetch input
 * @param offset value handed to the builder's `offset`
 * @param limit  value handed to the builder's `count`; defaults to -1
 * @return the built `relation.Fetch`
 */
private def fetch(child: LogicalPlan, offset: Long, limit: Long = -1): relation.Fetch = {
  val fetchBuilder = relation.Fetch.builder()
  val convertedInput = visit(child)
  fetchBuilder
    .input(convertedInput)
    .offset(offset)
    .count(limit)
    .build()
}

/**
 * Converts a Spark `GlobalLimit` node into a Substrait relation via `fetch`.
 *
 * The `p match` below accepts two shapes:
 *  - whatever the `OffsetAndLimit` extractor recognises, yielding (offset, limit, child);
 *  - `GlobalLimit` directly over `LocalLimit` where both integer-literal limits are
 *    equal, which is emitted as a fetch with offset 0.
 * Any other shape throws `UnsupportedOperationException`.
 *
 * NOTE(review): the leading `fetchBuilder(...).input(...).build()` statement appears to
 * be the removed pre-change body (diff markers lost); its result is unused here — confirm.
 */
override def visitGlobalLimit(p: GlobalLimit): relation.Rel = {
fetchBuilder(asLong(p.limitExpr), global = true)
.input(visit(p.child))
.build()
p match {
case OffsetAndLimit((offset, limit, child)) => fetch(child, offset, limit)
case GlobalLimit(IntegerLiteral(globalLimit), LocalLimit(IntegerLiteral(localLimit), child))
if globalLimit == localLimit => fetch(child, 0, localLimit)
case _ =>
throw new UnsupportedOperationException(s"Unable to convert the limit expression: $p")
}
}

/**
 * Converts a Spark `LocalLimit` node into a Substrait relation via `fetch`.
 *
 * The child of `p` decides the result:
 *  - child matches `OffsetAndLimit` and this local limit is >= its limit: the local
 *    limit is dropped and the child's (offset, limit) pair is emitted;
 *  - child is a `GlobalLimit` with an integer-literal limit and this local limit is >=
 *    that limit: a fetch with offset 0 and the global limit is emitted over the
 *    `GlobalLimit`'s own child;
 *  - otherwise: a fetch with offset 0 and this local limit over `p.child`.
 *
 * NOTE(review): the leading `fetchBuilder(...).input(...).build()` statement appears to
 * be the removed pre-change body (diff markers lost); its result is unused here — confirm.
 */
override def visitLocalLimit(p: LocalLimit): relation.Rel = {
fetchBuilder(asLong(p.limitExpr), global = false)
.input(visit(p.child))
.build()
val localLimit = asLong(p.limitExpr)
p.child match {
case OffsetAndLimit((offset, limit, child)) if localLimit >= limit =>
fetch(child, offset, limit)
case GlobalLimit(IntegerLiteral(globalLimit), child) if localLimit >= globalLimit =>
fetch(child, 0, globalLimit)
case _ => fetch(p.child, 0, localLimit)
}
}

/**
 * Converts a Spark `Offset` node into a Substrait relation.
 *
 * Only the offset is supplied to `fetch`; its `limit` parameter keeps its default.
 */
override def visitOffset(p: Offset): relation.Rel = {
  val rowsToSkip = asLong(p.offsetExpr)
  fetch(p.child, rowsToSkip)
}

override def visitFilter(p: Filter): relation.Rel = {
Expand Down
8 changes: 7 additions & 1 deletion spark/src/test/scala/io/substrait/spark/TPCHPlan.scala
Original file line number Diff line number Diff line change
Expand Up @@ -73,10 +73,16 @@ class TPCHPlan extends TPCHBase with SubstraitPlanTestBase {
"order by l_shipdate asc, l_discount desc nulls last")
}

// Round-trips three variants of the same query through Substrait:
// LIMIT with OFFSET, OFFSET only, and LIMIT only.
// NOTE(review): the `ignore(...)` line directly below appears to be the removed
// pre-change test header (diff markers lost); `test(...)` is its replacement — confirm.
ignore("simpleOffsetClause") { // TODO need to implement the 'offset' clause for this to pass
test("simpleOffsetClause") {
// limit + offset
assertSqlSubstraitRelRoundTrip(
"select l_partkey from lineitem where l_shipdate < date '1998-01-01' " +
"order by l_shipdate asc, l_discount desc limit 100 offset 1000")
// offset only
assertSqlSubstraitRelRoundTrip(
"select l_partkey from lineitem where l_shipdate < date '1998-01-01' " +
"order by l_shipdate asc, l_discount desc offset 1000")
// limit only
assertSqlSubstraitRelRoundTrip(
"select l_partkey from lineitem where l_shipdate < date '1998-01-01' " +
"order by l_shipdate asc, l_discount desc limit 100")
}

test("simpleTest") {
Expand Down

0 comments on commit 1e80765

Please sign in to comment.