Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[wip] fix: add proto roundtrips for Spark tests and fix issues it surfaces #315

Draft
wants to merge 10 commits into
base: main
Choose a base branch
from
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import io.substrait.relation.ConsistentPartitionWindow;
import io.substrait.relation.ProtoRelConverter;
import io.substrait.type.Type;
import io.substrait.type.TypeVisitor;
import io.substrait.type.proto.ProtoTypeConverter;
import java.util.ArrayList;
import java.util.Collections;
Expand Down Expand Up @@ -196,7 +197,20 @@ public Expression from(io.substrait.proto.Expression expr) {
var rel = protoRelConverter.from(expr.getSubquery().getScalar().getInput());
yield ImmutableExpression.ScalarSubquery.builder()
.input(rel)
.type(rel.getRecordType())
.type(
rel.getRecordType()
.accept(
new TypeVisitor.TypeThrowsVisitor<Type, RuntimeException>(
"Expected struct field") {
@Override
public Type visit(Type.Struct type) throws RuntimeException {
if (type.fields().size() != 1) {
throw new UnsupportedOperationException(
"Scalar subquery must have exactly one field");
}
return type.fields().get(0);
}
}))
.build();
}
case IN_PREDICATE -> {
Expand Down
1 change: 1 addition & 0 deletions core/src/main/java/io/substrait/relation/Join.java
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ protected Type.Struct deriveRecordType() {
switch (getJoinType()) {
case LEFT, OUTER -> getRight().getRecordType().fields().stream()
.map(TypeCreator::asNullable);
case SEMI, ANTI -> Stream.of(); // these are left joins which ignore right side columns
default -> getRight().getRecordType().fields().stream();
};
return TypeCreator.REQUIRED.struct(Stream.concat(leftTypes, rightTypes));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -312,7 +312,7 @@ void scalarSubquery() {
Stream.of(
Expression.ScalarSubquery.builder()
.input(relWithEnhancement)
.type(TypeCreator.REQUIRED.struct(TypeCreator.REQUIRED.I64))
.type(TypeCreator.REQUIRED.I64)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not 100% sure about this: is a scalar subquery actually meant to return a struct type? Given that it's scalar, that seems a bit weird.

.build())
.collect(Collectors.toList()),
commonTable);
Expand Down
16 changes: 16 additions & 0 deletions spark/src/main/scala/io/substrait/debug/RelToVerboseString.scala
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ class RelToVerboseString(addSuffix: Boolean) extends DefaultRelVisitor[String] {
builder.append("commonExtension=").append(commonExtension)
})
}

override def visit(namedScan: NamedScan): String = {
withBuilder(namedScan, 10)(
builder => {
Expand All @@ -115,6 +116,21 @@ class RelToVerboseString(addSuffix: Boolean) extends DefaultRelVisitor[String] {
})
}

/**
 * Renders a [[VirtualTableScan]] as a verbose string.
 *
 * Delegates to `fillReadRel` first, then appends the scan's rows and, when one is
 * present, its extension. Each appended section is preceded by a ", " separator.
 *
 * @param virtualTableScan the relation to render
 * @return the string produced by `withBuilder` for this relation
 */
override def visit(virtualTableScan: VirtualTableScan): String = {
  withBuilder(virtualTableScan, 10) {
    sb =>
      // NOTE(review): presumably writes the fields shared by all read relations; confirm
      // against fillReadRel's definition.
      fillReadRel(virtualTableScan, sb)
      sb.append(", ").append("rows=").append(virtualTableScan.getRows)

      // The extension is optional, so it is only rendered when a value is present.
      virtualTableScan.getExtension.ifPresent {
        ext =>
          sb.append(", ").append("extension=").append(ext)
      }
  }
}

override def visit(emptyScan: EmptyScan): String = {
withBuilder(emptyScan, 10)(
builder => {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,7 @@ class ToLogicalPlan(spark: SparkSession) extends DefaultRelVisitor[LogicalPlan]

override def visit(fetch: relation.Fetch): LogicalPlan = {
val child = fetch.getInput.accept(this)
val limit = fetch.getCount.getAsLong.intValue()
val limit = fetch.getCount.orElse(-1).intValue()
val offset = fetch.getOffset.intValue()
val toLiteral = (i: Int) => Literal(i, IntegerType)
if (limit >= 0) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ class ToSubstraitRel extends AbstractLogicalPlanVisitor with Logging {
val aggregates = collectAggregates(actualResultExprs, aggExprToOutputOrdinal)
val aggOutputMap = aggregates.zipWithIndex.map {
case (e, i) =>
AttributeReference(s"agg_func_$i", e.dataType)() -> e
AttributeReference(s"agg_func_$i", e.dataType, nullable = e.nullable)() -> e
Copy link
Contributor Author

@Blizzara Blizzara Oct 28, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These were causing the wrong nullability for the type in the created POJOs. I don't think that type field is used anywhere, so it didn't cause harm, but it still failed the roundtrip tests: the type isn't written to the proto, so on read it is derived (correctly) from the other fields and no longer matches the original.

}
val aggOutput = aggOutputMap.map(_._1)

Expand All @@ -145,7 +145,7 @@ class ToSubstraitRel extends AbstractLogicalPlanVisitor with Logging {
}
val groupOutputMap = actualGroupExprs.zipWithIndex.map {
case (e, i) =>
AttributeReference(s"group_col_$i", e.dataType)() -> e
AttributeReference(s"group_col_$i", e.dataType, nullable = e.nullable)() -> e
}
val groupOutput = groupOutputMap.map(_._1)

Expand Down Expand Up @@ -185,13 +185,10 @@ class ToSubstraitRel extends AbstractLogicalPlanVisitor with Logging {
}

override def visitWindow(window: Window): relation.Rel = {
val windowExpressions = window.windowExpressions.map {
case w: WindowExpression => fromWindowCall(w, window.child.output)
case a: Alias if a.child.isInstanceOf[WindowExpression] =>
fromWindowCall(a.child.asInstanceOf[WindowExpression], window.child.output)
case other =>
throw new UnsupportedOperationException(s"Unsupported window expression: $other")
}.asJava
val windowExpressions = window.windowExpressions
.flatMap(expr => expr.collect { case w: WindowExpression => w })
.map(fromWindowCall(_, window.child.output))
.asJava
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The previous code would fail if there was a column like Alias(Alias(WindowExpression(..))); this catches those. It doesn't explicitly fail if there are wrappers around WindowExpressions other than Alias, but I hope that's not the case in valid Spark plans.


val partitionExpressions = window.partitionSpec.map(toExpression(window.child.output)).asJava
val sorts = window.orderSpec.map(toSortField(window.child.output)).asJava
Expand All @@ -209,12 +206,15 @@ class ToSubstraitRel extends AbstractLogicalPlanVisitor with Logging {
}

private def fetch(child: LogicalPlan, offset: Long, limit: Long = -1): relation.Fetch = {
relation.Fetch
val builder = relation.Fetch
.builder()
.input(visit(child))
.offset(offset)
.count(limit)
.build()
if (limit != -1) {
builder.count(limit)
}

builder.build()
}

override def visitGlobalLimit(p: GlobalLimit): relation.Rel = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,9 @@ trait SubstraitPlanTestBase { self: SharedSparkSession =>

val extensionCollector = new ExtensionCollector;
val proto = new RelProtoConverter(extensionCollector).toProto(pojoRel)
new ProtoRelConverter(extensionCollector, SparkExtension.COLLECTION).from(proto)
val pojoFromProto =
new ProtoRelConverter(extensionCollector, SparkExtension.COLLECTION).from(proto)
assertResult(pojoRel)(pojoFromProto)

pojoRel2.shouldEqualPlainly(pojoRel)
logicalPlan2
Expand Down
Loading