diff --git a/spark/src/main/scala/io/substrait/spark/logical/ToLogicalPlan.scala b/spark/src/main/scala/io/substrait/spark/logical/ToLogicalPlan.scala
index 45b6c2205..ca1fbe25b 100644
--- a/spark/src/main/scala/io/substrait/spark/logical/ToLogicalPlan.scala
+++ b/spark/src/main/scala/io/substrait/spark/logical/ToLogicalPlan.scala
@@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.util.toPrettySQL
 import org.apache.spark.sql.execution.QueryExecution
 import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, InMemoryFileIndex, LogicalRelation}
 import org.apache.spark.sql.execution.datasources.csv.CSVFileFormat
-import org.apache.spark.sql.types.{DataTypes, IntegerType, StructField, StructType}
+import org.apache.spark.sql.types.{DataTypes, IntegerType, LongType, StructField, StructType}
 import io.substrait.`type`.{StringTypeVisitor, Type}
 import io.substrait.{expression => exp}
 import io.substrait.expression.{Expression => SExpression}
@@ -160,11 +160,19 @@ class ToLogicalPlan(spark: SparkSession) extends DefaultRelVisitor[LogicalPlan]
   }
   override def visit(fetch: relation.Fetch): LogicalPlan = {
     val child = fetch.getInput.accept(this)
-    val limit = Literal(fetch.getCount.getAsLong.intValue(), IntegerType)
-    fetch.getOffset match {
-      case 1L => GlobalLimit(limitExpr = limit, child = child)
-      case -1L => LocalLimit(limitExpr = limit, child = child)
-      case _ => visitFallback(fetch)
+    val limit = fetch.getCount.getAsLong.intValue()
+    val offset = fetch.getOffset.intValue()
+    if (limit >= 0) {
+      val limitExpr = Literal(limit, IntegerType)
+      if (offset > 0) {
+        GlobalLimit(
+          limitExpr,
+          Offset(Literal(offset, IntegerType), LocalLimit(Literal(offset + limit, IntegerType), child)))
+      } else {
+        GlobalLimit(limitExpr, LocalLimit(limitExpr, child))
+      }
+    } else {
+      Offset(Literal(offset, IntegerType), child)
     }
   }
   override def visit(sort: relation.Sort): LogicalPlan = {
diff --git a/spark/src/main/scala/io/substrait/spark/logical/ToSubstraitRel.scala b/spark/src/main/scala/io/substrait/spark/logical/ToSubstraitRel.scala
index 08a06c2e4..e66126bb4 100644
--- a/spark/src/main/scala/io/substrait/spark/logical/ToSubstraitRel.scala
+++ b/spark/src/main/scala/io/substrait/spark/logical/ToSubstraitRel.scala
@@ -170,23 +170,37 @@
     case other => throw new UnsupportedOperationException(s"Unknown type: $other")
   }
 
-  private def fetchBuilder(limit: Long, global: Boolean): relation.ImmutableFetch.Builder = {
-    val offset = if (global) 1L else -1L
-    relation.Fetch
-      .builder()
-      .count(limit)
+  private def fetch(child: LogicalPlan, offset: Long, limit: Long = -1): relation.Fetch = {
+    relation.Fetch.builder()
+      .input(visit(child))
       .offset(offset)
+      .count(limit)
+      .build()
   }
+
   override def visitGlobalLimit(p: GlobalLimit): relation.Rel = {
-    fetchBuilder(asLong(p.limitExpr), global = true)
-      .input(visit(p.child))
-      .build()
+    p match {
+      case OffsetAndLimit((offset, limit, child)) => fetch(child, offset, limit)
+      case GlobalLimit(IntegerLiteral(globalLimit), LocalLimit(IntegerLiteral(localLimit), child))
+          if globalLimit == localLimit => fetch(child, 0, localLimit)
+      case _ =>
+        throw new UnsupportedOperationException(s"Unable to convert the limit expression: $p")
+    }
   }
 
   override def visitLocalLimit(p: LocalLimit): relation.Rel = {
-    fetchBuilder(asLong(p.limitExpr), global = false)
-      .input(visit(p.child))
-      .build()
+    val localLimit = asLong(p.limitExpr)
+    p.child match {
+      case OffsetAndLimit((offset, limit, child)) if localLimit >= limit =>
+        fetch(child, offset, limit)
+      case GlobalLimit(IntegerLiteral(globalLimit), child) if localLimit >= globalLimit =>
+        fetch(child, 0, globalLimit)
+      case _ => fetch(p.child, 0, localLimit)
+    }
+  }
+
+  override def visitOffset(p: Offset): relation.Rel = {
+    fetch(p.child, asLong(p.offsetExpr))
   }
 
   override def visitFilter(p: Filter): relation.Rel = {
diff --git a/spark/src/test/scala/io/substrait/spark/TPCHPlan.scala b/spark/src/test/scala/io/substrait/spark/TPCHPlan.scala
index 224ac2e8d..76c5b9a6a 100644
--- a/spark/src/test/scala/io/substrait/spark/TPCHPlan.scala
+++ b/spark/src/test/scala/io/substrait/spark/TPCHPlan.scala
@@ -73,10 +73,16 @@
       "order by l_shipdate asc, l_discount desc nulls last")
   }
 
-  ignore("simpleOffsetClause") { // TODO need to implement the 'offset' clause for this to pass
+  test("simpleOffsetClause") {
     assertSqlSubstraitRelRoundTrip(
       "select l_partkey from lineitem where l_shipdate < date '1998-01-01' " +
         "order by l_shipdate asc, l_discount desc limit 100 offset 1000")
+    assertSqlSubstraitRelRoundTrip(
+      "select l_partkey from lineitem where l_shipdate < date '1998-01-01' " +
+        "order by l_shipdate asc, l_discount desc offset 1000")
+    assertSqlSubstraitRelRoundTrip(
+      "select l_partkey from lineitem where l_shipdate < date '1998-01-01' " +
+        "order by l_shipdate asc, l_discount desc limit 100")
   }
 
   test("simpleTest") {
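
For context on the plan shape this diff targets: on Spark 3.4+, which added the Offset operator, a query with "limit 100 offset 1000" is optimized into GlobalLimit(100, Offset(1000, LocalLimit(1100, child))), with the local limit pre-fetching offset + limit rows per partition. That is exactly the tree the new visit(fetch) rebuilds, and the shape Spark's OffsetAndLimit extractor recognizes in visitGlobalLimit. Below is a minimal sketch for inspecting this locally; the session setup and the tiny stand-in lineitem view are hypothetical, only the query mirrors TPCHPlan.scala:

    import org.apache.spark.sql.SparkSession

    object PlanShapeCheck extends App {
      val spark = SparkSession.builder().master("local[*]").appName("plan-shape").getOrCreate()
      import spark.implicits._

      // Tiny stand-in for the TPC-H lineitem table used by the tests.
      Seq((1, "1997-01-01"), (2, "1997-06-01"))
        .toDF("l_partkey", "l_shipdate")
        .createOrReplaceTempView("lineitem")

      val plan = spark
        .sql("select l_partkey from lineitem order by l_shipdate limit 100 offset 1000")
        .queryExecution.optimizedPlan
      println(plan.treeString)
      // Expected shape (details vary by version); this is the tree visit(fetch) rebuilds:
      //   GlobalLimit 100
      //   +- Offset 1000
      //      +- LocalLimit 1100
      //         +- Sort [l_shipdate ASC NULLS FIRST], true
      //            +- ...
      spark.stop()
    }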
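The round-trip contract between the two visitors, including the count = -1 sentinel that the new fetch() helper uses to mean "offset only", can be summarized in a dependency-free sketch. The case classes below are illustrative stand-ins for the Catalyst operators and the Substrait Fetch relation, not the real types:

    object FetchMappingSketch extends App {
      sealed trait Plan
      case class Leaf(name: String) extends Plan
      case class LocalLimit(limit: Int, child: Plan) extends Plan
      case class GlobalLimit(limit: Int, child: Plan) extends Plan
      case class Offset(offset: Int, child: Plan) extends Plan

      // Stand-in for Substrait's Fetch: count == -1 means "no limit, offset only".
      case class Fetch(offset: Int, count: Int, input: Plan)

      // Decode direction, mirroring ToLogicalPlan.visit(fetch).
      def toPlan(f: Fetch): Plan =
        if (f.count >= 0) {
          if (f.offset > 0)
            // Limit plus offset: pre-fetch offset + count rows per partition.
            GlobalLimit(f.count, Offset(f.offset, LocalLimit(f.offset + f.count, f.input)))
          else
            GlobalLimit(f.count, LocalLimit(f.count, f.input))
        } else {
          Offset(f.offset, f.input)
        }

      // Encode direction, mirroring ToSubstraitRel.visitGlobalLimit/visitOffset.
      def toFetch(p: Plan): Fetch = p match {
        case GlobalLimit(g, Offset(o, LocalLimit(l, child))) if o + g == l =>
          Fetch(o, g, child) // the OffsetAndLimit shape
        case GlobalLimit(g, LocalLimit(l, child)) if g == l =>
          Fetch(0, g, child) // plain limit
        case Offset(o, child) =>
          Fetch(o, -1, child) // plain offset
        case other =>
          throw new UnsupportedOperationException(s"Unable to convert: $other")
      }

      // Round trip for "limit 100 offset 1000", as in the simpleOffsetClause test.
      val f = Fetch(offset = 1000, count = 100, input = Leaf("lineitem"))
      assert(toFetch(toPlan(f)) == f)
      println(toPlan(f))
    }

The visitLocalLimit cases are intentionally left out of the sketch; in the diff they only strip a redundant outer LocalLimit when it is no stricter (localLimit >= limit) than the pattern beneath it, so the tighter bound is always the one encoded.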