From 6a187341197cc0c145f12104f14d88b531de28e6 Mon Sep 17 00:00:00 2001 From: Arttu Voutilainen Date: Sun, 27 Oct 2024 17:15:47 +0100 Subject: [PATCH 01/10] fix: converting proto to pojo should take into account join type for column matching --- core/src/main/java/io/substrait/relation/Join.java | 1 + 1 file changed, 1 insertion(+) diff --git a/core/src/main/java/io/substrait/relation/Join.java b/core/src/main/java/io/substrait/relation/Join.java index c3f9387b3..403663471 100644 --- a/core/src/main/java/io/substrait/relation/Join.java +++ b/core/src/main/java/io/substrait/relation/Join.java @@ -59,6 +59,7 @@ protected Type.Struct deriveRecordType() { switch (getJoinType()) { case LEFT, OUTER -> getRight().getRecordType().fields().stream() .map(TypeCreator::asNullable); + case SEMI, ANTI -> Stream.of(); // these are left joins which ignore right side columns default -> getRight().getRecordType().fields().stream(); }; return TypeCreator.REQUIRED.struct(Stream.concat(leftTypes, rightTypes)); From c56197e9ddc212c59761df5d5b457cc0a503e0bb Mon Sep 17 00:00:00 2001 From: Arttu Voutilainen Date: Sun, 27 Oct 2024 17:16:39 +0100 Subject: [PATCH 02/10] fix: support treestring for VirtualTableScan --- .../io/substrait/debug/RelToVerboseString.scala | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/spark/src/main/scala/io/substrait/debug/RelToVerboseString.scala b/spark/src/main/scala/io/substrait/debug/RelToVerboseString.scala index 79b34462f..0ef4a5310 100644 --- a/spark/src/main/scala/io/substrait/debug/RelToVerboseString.scala +++ b/spark/src/main/scala/io/substrait/debug/RelToVerboseString.scala @@ -100,6 +100,7 @@ class RelToVerboseString(addSuffix: Boolean) extends DefaultRelVisitor[String] { builder.append("commonExtension=").append(commonExtension) }) } + override def visit(namedScan: NamedScan): String = { withBuilder(namedScan, 10)( builder => { @@ -115,6 +116,21 @@ class RelToVerboseString(addSuffix: Boolean) extends DefaultRelVisitor[String] { }) } + override def visit(virtualTableScan: VirtualTableScan): String = { + withBuilder(virtualTableScan, 10)( + builder => { + fillReadRel(virtualTableScan, builder) + builder.append(", ") + builder.append("rows=").append(virtualTableScan.getRows) + + virtualTableScan.getExtension.ifPresent( + extension => { + builder.append(", ") + builder.append("extension=").append(extension) + }) + }) + } + override def visit(emptyScan: EmptyScan): String = { withBuilder(emptyScan, 10)( builder => { From 351c6b663f1d4f1236b838c5deb0c493bee63ae6 Mon Sep 17 00:00:00 2001 From: Arttu Voutilainen Date: Sun, 27 Oct 2024 17:16:57 +0100 Subject: [PATCH 03/10] fix: correctly set nullability for aggregate references --- .../main/scala/io/substrait/spark/logical/ToSubstraitRel.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/spark/src/main/scala/io/substrait/spark/logical/ToSubstraitRel.scala b/spark/src/main/scala/io/substrait/spark/logical/ToSubstraitRel.scala index 2827a9c30..605d33b3a 100644 --- a/spark/src/main/scala/io/substrait/spark/logical/ToSubstraitRel.scala +++ b/spark/src/main/scala/io/substrait/spark/logical/ToSubstraitRel.scala @@ -131,7 +131,7 @@ class ToSubstraitRel extends AbstractLogicalPlanVisitor with Logging { val aggregates = collectAggregates(actualResultExprs, aggExprToOutputOrdinal) val aggOutputMap = aggregates.zipWithIndex.map { case (e, i) => - AttributeReference(s"agg_func_$i", e.dataType)() -> e + AttributeReference(s"agg_func_$i", e.dataType, nullable = e.nullable)() -> e } val aggOutput = aggOutputMap.map(_._1) From 681a02e104793edef79a37dbe5f8565b1a6aec45 Mon Sep 17 00:00:00 2001 From: Arttu Voutilainen Date: Sun, 27 Oct 2024 17:17:20 +0100 Subject: [PATCH 04/10] fix: handle nested aliases for window functions --- .../io/substrait/spark/logical/ToSubstraitRel.scala | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/spark/src/main/scala/io/substrait/spark/logical/ToSubstraitRel.scala b/spark/src/main/scala/io/substrait/spark/logical/ToSubstraitRel.scala index 605d33b3a..3ba86c7a2 100644 --- a/spark/src/main/scala/io/substrait/spark/logical/ToSubstraitRel.scala +++ b/spark/src/main/scala/io/substrait/spark/logical/ToSubstraitRel.scala @@ -185,13 +185,10 @@ class ToSubstraitRel extends AbstractLogicalPlanVisitor with Logging { } override def visitWindow(window: Window): relation.Rel = { - val windowExpressions = window.windowExpressions.map { - case w: WindowExpression => fromWindowCall(w, window.child.output) - case a: Alias if a.child.isInstanceOf[WindowExpression] => - fromWindowCall(a.child.asInstanceOf[WindowExpression], window.child.output) - case other => - throw new UnsupportedOperationException(s"Unsupported window expression: $other") - }.asJava + val windowExpressions = window.windowExpressions + .flatMap(expr => expr.collect { case w: WindowExpression => w }) + .map(fromWindowCall(_, window.child.output)) + .asJava val partitionExpressions = window.partitionSpec.map(toExpression(window.child.output)).asJava val sorts = window.orderSpec.map(toSortField(window.child.output)).asJava From a7effe2378a67366875164561db447fc0d5e04ad Mon Sep 17 00:00:00 2001 From: Arttu Voutilainen Date: Sun, 27 Oct 2024 21:21:08 -0400 Subject: [PATCH 05/10] fix: correctly set nullability for aggregate grouping exprs --- .../main/scala/io/substrait/spark/logical/ToSubstraitRel.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/spark/src/main/scala/io/substrait/spark/logical/ToSubstraitRel.scala b/spark/src/main/scala/io/substrait/spark/logical/ToSubstraitRel.scala index 3ba86c7a2..695f93919 100644 --- a/spark/src/main/scala/io/substrait/spark/logical/ToSubstraitRel.scala +++ b/spark/src/main/scala/io/substrait/spark/logical/ToSubstraitRel.scala @@ -145,7 +145,7 @@ class ToSubstraitRel extends AbstractLogicalPlanVisitor with Logging { } val groupOutputMap = actualGroupExprs.zipWithIndex.map { case (e, i) => - AttributeReference(s"group_col_$i", e.dataType)() -> e + AttributeReference(s"group_col_$i", e.dataType, nullable = e.nullable)() -> e } val groupOutput = groupOutputMap.map(_._1) From 591848ae7f1cd3f4fd3591cf5f4edd2f7a9cde7a Mon Sep 17 00:00:00 2001 From: Arttu Voutilainen Date: Sun, 27 Oct 2024 21:23:51 -0400 Subject: [PATCH 06/10] fix: correctly set type for scalar subquery when converting proto to pojo --- .../expression/proto/ProtoExpressionConverter.java | 11 ++++++++++- .../substrait/type/proto/ExtensionRoundtripTest.java | 2 +- 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/core/src/main/java/io/substrait/expression/proto/ProtoExpressionConverter.java b/core/src/main/java/io/substrait/expression/proto/ProtoExpressionConverter.java index d2b95d74f..8857614f8 100644 --- a/core/src/main/java/io/substrait/expression/proto/ProtoExpressionConverter.java +++ b/core/src/main/java/io/substrait/expression/proto/ProtoExpressionConverter.java @@ -15,6 +15,7 @@ import io.substrait.relation.ConsistentPartitionWindow; import io.substrait.relation.ProtoRelConverter; import io.substrait.type.Type; +import io.substrait.type.TypeVisitor; import io.substrait.type.proto.ProtoTypeConverter; import java.util.ArrayList; import java.util.Collections; @@ -196,7 +197,15 @@ public Expression from(io.substrait.proto.Expression expr) { var rel = protoRelConverter.from(expr.getSubquery().getScalar().getInput()); yield ImmutableExpression.ScalarSubquery.builder() .input(rel) - .type(rel.getRecordType()) + .type(rel.getRecordType().accept(new TypeVisitor.TypeThrowsVisitor("Expected struct field") { + @Override + public Type visit(Type.Struct type) throws RuntimeException { + if (type.fields().size() != 1) { + throw new UnsupportedOperationException("Scalar subquery must have exactly one field"); + } + return type.fields().get(0); + } + })) .build(); } case IN_PREDICATE -> { diff --git a/core/src/test/java/io/substrait/type/proto/ExtensionRoundtripTest.java b/core/src/test/java/io/substrait/type/proto/ExtensionRoundtripTest.java index eb2c05b29..7b2237661 100644 --- a/core/src/test/java/io/substrait/type/proto/ExtensionRoundtripTest.java +++ b/core/src/test/java/io/substrait/type/proto/ExtensionRoundtripTest.java @@ -312,7 +312,7 @@ void scalarSubquery() { Stream.of( Expression.ScalarSubquery.builder() .input(relWithEnhancement) - .type(TypeCreator.REQUIRED.struct(TypeCreator.REQUIRED.I64)) + .type(TypeCreator.REQUIRED.I64) .build()) .collect(Collectors.toList()), commonTable); From e374c85907c48227826402b84246a91cdec0c25c Mon Sep 17 00:00:00 2001 From: Arttu Voutilainen Date: Thu, 21 Nov 2024 18:27:02 +0100 Subject: [PATCH 07/10] fix: spotless --- .../proto/ProtoExpressionConverter.java | 23 +++++++++++-------- .../spark/logical/ToSubstraitRel.scala | 4 ++-- 2 files changed, 16 insertions(+), 11 deletions(-) diff --git a/core/src/main/java/io/substrait/expression/proto/ProtoExpressionConverter.java b/core/src/main/java/io/substrait/expression/proto/ProtoExpressionConverter.java index 8857614f8..3120c8941 100644 --- a/core/src/main/java/io/substrait/expression/proto/ProtoExpressionConverter.java +++ b/core/src/main/java/io/substrait/expression/proto/ProtoExpressionConverter.java @@ -197,15 +197,20 @@ public Expression from(io.substrait.proto.Expression expr) { var rel = protoRelConverter.from(expr.getSubquery().getScalar().getInput()); yield ImmutableExpression.ScalarSubquery.builder() .input(rel) - .type(rel.getRecordType().accept(new TypeVisitor.TypeThrowsVisitor("Expected struct field") { - @Override - public Type visit(Type.Struct type) throws RuntimeException { - if (type.fields().size() != 1) { - throw new UnsupportedOperationException("Scalar subquery must have exactly one field"); - } - return type.fields().get(0); - } - })) + .type( + rel.getRecordType() + .accept( + new TypeVisitor.TypeThrowsVisitor( + "Expected struct field") { + @Override + public Type visit(Type.Struct type) throws RuntimeException { + if (type.fields().size() != 1) { + throw new UnsupportedOperationException( + "Scalar subquery must have exactly one field"); + } + return type.fields().get(0); + } + })) .build(); } case IN_PREDICATE -> { diff --git a/spark/src/main/scala/io/substrait/spark/logical/ToSubstraitRel.scala b/spark/src/main/scala/io/substrait/spark/logical/ToSubstraitRel.scala index 695f93919..3451c1166 100644 --- a/spark/src/main/scala/io/substrait/spark/logical/ToSubstraitRel.scala +++ b/spark/src/main/scala/io/substrait/spark/logical/ToSubstraitRel.scala @@ -186,8 +186,8 @@ class ToSubstraitRel extends AbstractLogicalPlanVisitor with Logging { override def visitWindow(window: Window): relation.Rel = { val windowExpressions = window.windowExpressions - .flatMap(expr => expr.collect { case w: WindowExpression => w }) - .map(fromWindowCall(_, window.child.output)) + .flatMap(expr => expr.collect { case w: WindowExpression => w }) + .map(fromWindowCall(_, window.child.output)) .asJava val partitionExpressions = window.partitionSpec.map(toExpression(window.child.output)).asJava From ab104f59019039eeec25e2b6b3310afdc381aee4 Mon Sep 17 00:00:00 2001 From: Arttu Voutilainen Date: Thu, 21 Nov 2024 18:48:42 +0100 Subject: [PATCH 08/10] fix: add assert to check the pojo-proto-roundtrip --- .../test/scala/io/substrait/spark/SubstraitPlanTestBase.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/spark/src/test/scala/io/substrait/spark/SubstraitPlanTestBase.scala b/spark/src/test/scala/io/substrait/spark/SubstraitPlanTestBase.scala index cbd7a151c..dea8f3a53 100644 --- a/spark/src/test/scala/io/substrait/spark/SubstraitPlanTestBase.scala +++ b/spark/src/test/scala/io/substrait/spark/SubstraitPlanTestBase.scala @@ -95,7 +95,8 @@ trait SubstraitPlanTestBase { self: SharedSparkSession => val extensionCollector = new ExtensionCollector; val proto = new RelProtoConverter(extensionCollector).toProto(pojoRel) - new ProtoRelConverter(extensionCollector, SparkExtension.COLLECTION).from(proto) + val pojoFromProto = new ProtoRelConverter(extensionCollector, SparkExtension.COLLECTION).from(proto) + assertResult(pojoRel)(pojoFromProto) pojoRel2.shouldEqualPlainly(pojoRel) logicalPlan2 From 05ce6034be395d779cdfcfb10c2ef51aea92aff8 Mon Sep 17 00:00:00 2001 From: Arttu Voutilainen Date: Thu, 21 Nov 2024 18:49:46 +0100 Subject: [PATCH 09/10] fix: handle fetch's count in a way that matches roundtrip --- .../scala/io/substrait/spark/logical/ToLogicalPlan.scala | 2 +- .../io/substrait/spark/logical/ToSubstraitRel.scala | 9 ++++++--- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/spark/src/main/scala/io/substrait/spark/logical/ToLogicalPlan.scala b/spark/src/main/scala/io/substrait/spark/logical/ToLogicalPlan.scala index 68b15345a..278233ee4 100644 --- a/spark/src/main/scala/io/substrait/spark/logical/ToLogicalPlan.scala +++ b/spark/src/main/scala/io/substrait/spark/logical/ToLogicalPlan.scala @@ -215,7 +215,7 @@ class ToLogicalPlan(spark: SparkSession) extends DefaultRelVisitor[LogicalPlan] override def visit(fetch: relation.Fetch): LogicalPlan = { val child = fetch.getInput.accept(this) - val limit = fetch.getCount.getAsLong.intValue() + val limit = fetch.getCount.orElse(-1).intValue() val offset = fetch.getOffset.intValue() val toLiteral = (i: Int) => Literal(i, IntegerType) if (limit >= 0) { diff --git a/spark/src/main/scala/io/substrait/spark/logical/ToSubstraitRel.scala b/spark/src/main/scala/io/substrait/spark/logical/ToSubstraitRel.scala index 3451c1166..5cf4cbc97 100644 --- a/spark/src/main/scala/io/substrait/spark/logical/ToSubstraitRel.scala +++ b/spark/src/main/scala/io/substrait/spark/logical/ToSubstraitRel.scala @@ -206,12 +206,15 @@ class ToSubstraitRel extends AbstractLogicalPlanVisitor with Logging { } private def fetch(child: LogicalPlan, offset: Long, limit: Long = -1): relation.Fetch = { - relation.Fetch + val builder = relation.Fetch .builder() .input(visit(child)) .offset(offset) - .count(limit) - .build() + if (limit != -1) { + builder.count(limit) + } + + builder.build() } override def visitGlobalLimit(p: GlobalLimit): relation.Rel = { From 13a3b99022d22fb822a4bce997baf5ae004c5914 Mon Sep 17 00:00:00 2001 From: Arttu Voutilainen Date: Thu, 21 Nov 2024 18:50:07 +0100 Subject: [PATCH 10/10] fix: spotless --- .../main/scala/io/substrait/spark/logical/ToSubstraitRel.scala | 2 +- .../test/scala/io/substrait/spark/SubstraitPlanTestBase.scala | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/spark/src/main/scala/io/substrait/spark/logical/ToSubstraitRel.scala b/spark/src/main/scala/io/substrait/spark/logical/ToSubstraitRel.scala index 5cf4cbc97..64fa085e6 100644 --- a/spark/src/main/scala/io/substrait/spark/logical/ToSubstraitRel.scala +++ b/spark/src/main/scala/io/substrait/spark/logical/ToSubstraitRel.scala @@ -214,7 +214,7 @@ class ToSubstraitRel extends AbstractLogicalPlanVisitor with Logging { builder.count(limit) } - builder.build() + builder.build() } override def visitGlobalLimit(p: GlobalLimit): relation.Rel = { diff --git a/spark/src/test/scala/io/substrait/spark/SubstraitPlanTestBase.scala b/spark/src/test/scala/io/substrait/spark/SubstraitPlanTestBase.scala index dea8f3a53..0571ee544 100644 --- a/spark/src/test/scala/io/substrait/spark/SubstraitPlanTestBase.scala +++ b/spark/src/test/scala/io/substrait/spark/SubstraitPlanTestBase.scala @@ -95,7 +95,8 @@ trait SubstraitPlanTestBase { self: SharedSparkSession => val extensionCollector = new ExtensionCollector; val proto = new RelProtoConverter(extensionCollector).toProto(pojoRel) - val pojoFromProto = new ProtoRelConverter(extensionCollector, SparkExtension.COLLECTION).from(proto) + val pojoFromProto = + new ProtoRelConverter(extensionCollector, SparkExtension.COLLECTION).from(proto) assertResult(pojoRel)(pojoFromProto) pojoRel2.shouldEqualPlainly(pojoRel)