fix(spark): incorrect conversion of expand relation (#316)
In the expand relation, the projection expressions are stored in a
two-dimensional array. The Spark matrix needs to be transposed in
order to map the expressions into Substrait, and vice versa.

Also, the remap field should not be used because the outputs
are defined directly in the projections array.
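
A minimal Scala sketch of the orientation mismatch (the values below are illustrative stand-ins, not code from this patch): Spark's Expand holds one Seq per generated output row, while Substrait's expand holds one switching field per output column, so converting between them is a transpose.

    // Spark: row-major, one Seq per emitted row
    // (plain strings stand in for Catalyst expressions).
    val sparkProjections: Seq[Seq[String]] = Seq(
      Seq("a", "b", "0"), // row emitted for grouping set 0
      Seq("a", "null", "1") // row emitted for grouping set 1
    )
    // Substrait: column-major, one switching field per output column,
    // listing that column's expression in every row.
    val substraitFields: Seq[Seq[String]] = sparkProjections.transpose
    // => Seq(Seq("a", "a"), Seq("b", "null"), Seq("0", "1"))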

Signed-off-by: Andrew Coleman <[email protected]>
andrew-coleman authored Nov 20, 2024
1 parent e3139c6 commit 6c78d48
Showing 5 changed files with 15 additions and 10 deletions.
7 changes: 4 additions & 3 deletions core/src/main/java/io/substrait/relation/Expand.java
@@ -4,7 +4,6 @@
 import io.substrait.type.Type;
 import io.substrait.type.TypeCreator;
 import java.util.List;
-import java.util.stream.Stream;
 import org.immutables.value.Value;
 
 @Value.Enclosing
@@ -18,7 +17,7 @@ public abstract class Expand extends SingleInputRel {
   public Type.Struct deriveRecordType() {
     Type.Struct initial = getInput().getRecordType();
     return TypeCreator.of(initial.nullable())
-        .struct(Stream.concat(initial.fields().stream(), Stream.of(TypeCreator.REQUIRED.I64)));
+        .struct(getFields().stream().map(ExpandField::getType));
   }
 
   @Override
@@ -52,7 +51,9 @@ public abstract static class SwitchingField implements ExpandField {
     public abstract List<Expression> getDuplicates();
 
     public Type getType() {
-      return getDuplicates().get(0).getType();
+      var nullable = getDuplicates().stream().anyMatch(d -> d.getType().nullable());
+      var type = getDuplicates().get(0).getType();
+      return nullable ? TypeCreator.asNullable(type) : TypeCreator.asNotNullable(type);
     }
 
     public static ImmutableExpand.SwitchingField.Builder builder() {
2 changes: 2 additions & 0 deletions spark/src/main/scala/io/substrait/spark/SparkExtension.scala
@@ -34,6 +34,8 @@ object SparkExtension {
   private val EXTENSION_COLLECTION: SimpleExtension.ExtensionCollection =
     SimpleExtension.loadDefaults()
 
+  val COLLECTION: SimpleExtension.ExtensionCollection = EXTENSION_COLLECTION.merge(SparkImpls)
+
   lazy val SparkScalarFunctions: Seq[SimpleExtension.ScalarFunctionVariant] = {
     val ret = new collection.mutable.ArrayBuffer[SimpleExtension.ScalarFunctionVariant]()
     ret.appendAll(EXTENSION_COLLECTION.scalarFunctions().asScala)
7 changes: 3 additions & 4 deletions ToLogicalPlan.scala
@@ -277,14 +277,13 @@ class ToLogicalPlan(spark: SparkSession) extends DefaultRelVisitor[LogicalPlan]
     }
 
     // An output column is nullable if any of the projections can assign null to it
-    val types = projections.transpose.map(p => (p.head.dataType, p.exists(_.nullable)))
-
-    val output = types
+    val output = projections
+      .map(p => (p.head.dataType, p.exists(_.nullable)))
       .zip(names)
       .map { case (t, name) => StructField(name, t._1, t._2) }
      .map(f => AttributeReference(f.name, f.dataType, f.nullable, f.metadata)())
 
-    Expand(projections, output, child)
+    Expand(projections.transpose, output, child)
   }
 }
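
To make the nullability comment above concrete, here is a hedged sketch, using (dataType, nullable) tuples as stand-ins for Catalyst expressions, of how the column-major matrix yields each output attribute before being transposed back into Spark's row-major layout:

    // Each inner Seq holds one output column's duplicates across projection rows.
    val columns = Seq(
      Seq(("i64", false), ("i64", true)), // some row can emit null => nullable
      Seq(("i32", false), ("i32", false)) // never null => not nullable
    )
    // Mirrors: projections.map(p => (p.head.dataType, p.exists(_.nullable)))
    val types = columns.map(p => (p.head._1, p.exists(_._2)))
    // => Seq(("i64", true), ("i32", false))
    val rowMajor = columns.transpose // layout Spark's Expand constructor expects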
3 changes: 1 addition & 2 deletions ToSubstraitRel.scala
@@ -290,7 +290,7 @@ class ToSubstraitRel extends AbstractLogicalPlanVisitor with Logging {
   }
 
   override def visitExpand(p: Expand): relation.Rel = {
-    val fields = p.projections.map(
+    val fields = p.projections.transpose.map(
       proj => {
         relation.Expand.SwitchingField.builder
           .duplicates(
@@ -302,7 +302,6 @@ class ToSubstraitRel extends AbstractLogicalPlanVisitor with Logging {
     val names = p.output.map(_.name)
 
     relation.Expand.builder
-      .remap(relation.Rel.Remap.offset(p.child.output.size, names.size))
       .fields(fields.asJava)
       .hint(Hint.builder.addAllOutputNames(names.asJava).build())
       .input(visit(p.child))
6 changes: 5 additions & 1 deletion SubstraitPlanTestBase.scala
@@ -26,7 +26,7 @@ import io.substrait.debug.TreePrinter
 import io.substrait.extension.ExtensionCollector
 import io.substrait.plan.{Plan, PlanProtoConverter, ProtoPlanConverter}
 import io.substrait.proto
-import io.substrait.relation.RelProtoConverter
+import io.substrait.relation.{ProtoRelConverter, RelProtoConverter}
 import org.scalactic.Equality
 import org.scalactic.source.Position
 import org.scalatest.Succeeded
@@ -93,6 +93,10 @@ trait SubstraitPlanTestBase { self: SharedSparkSession =>
     require(logicalPlan2.resolved);
     val pojoRel2 = new ToSubstraitRel().visit(logicalPlan2)
 
+    val extensionCollector = new ExtensionCollector;
+    val proto = new RelProtoConverter(extensionCollector).toProto(pojoRel)
+    new ProtoRelConverter(extensionCollector, SparkExtension.COLLECTION).from(proto)
+
     pojoRel2.shouldEqualPlainly(pojoRel)
     logicalPlan2
   }
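
The new test lines exercise a POJO → proto → POJO round trip purely for its side effects: RelProtoConverter serializes the rel, and ProtoRelConverter, using the merged SparkExtension.COLLECTION, parses it back, so any field the converters mishandle fails the test by throwing. A hypothetical tightening (not part of this patch) would also compare the round-tripped rel against the original:

    // Hypothetical extra assertion, reusing the converters from the test above.
    val pojoRel3 = new ProtoRelConverter(extensionCollector, SparkExtension.COLLECTION).from(proto)
    pojoRel3.shouldEqualPlainly(pojoRel)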
