Skip to content

Commit

Permalink
feat: add NestedLoopJoin rel
Browse files Browse the repository at this point in the history
  • Loading branch information
danepitkin committed Oct 10, 2023
1 parent 0689c6b commit c88d941
Show file tree
Hide file tree
Showing 8 changed files with 183 additions and 0 deletions.
25 changes: 25 additions & 0 deletions core/src/main/java/io/substrait/dsl/SubstraitBuilder.java
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import io.substrait.relation.Filter;
import io.substrait.relation.Join;
import io.substrait.relation.NamedScan;
import io.substrait.relation.NestedLoopJoin;
import io.substrait.relation.Project;
import io.substrait.relation.Rel;
import io.substrait.relation.Set;
Expand Down Expand Up @@ -188,6 +189,30 @@ private NamedScan namedScan(
return NamedScan.builder().names(tableName).initialSchema(namedStruct).remap(remap).build();
}

public NestedLoopJoin nestedLoopJoin(
Function<JoinInput, Expression> conditionFn,
NestedLoopJoin.JoinType joinType,
Rel left,
Rel right) {
return nestedLoopJoin(conditionFn, joinType, Optional.empty(), left, right);
}

private NestedLoopJoin nestedLoopJoin(
Function<JoinInput, Expression> conditionFn,
NestedLoopJoin.JoinType joinType,
Optional<Rel.Remap> remap,
Rel left,
Rel right) {
var condition = conditionFn.apply(new JoinInput(left, right));
return NestedLoopJoin.builder()
.left(left)
.right(right)
.condition(condition)
.joinType(joinType)
.remap(remap)
.build();
}

public Project project(Function<Rel, Iterable<? extends Expression>> expressionsFn, Rel input) {
return project(expressionsFn, Optional.empty(), input);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,11 @@ public OUTPUT visit(Join join) throws EXCEPTION {
return visitFallback(join);
}

@Override
public OUTPUT visit(NestedLoopJoin nestedLoopJoin) throws EXCEPTION {
return visitFallback(nestedLoopJoin);
}

@Override
public OUTPUT visit(Set set) throws EXCEPTION {
return visitFallback(set);
Expand Down
75 changes: 75 additions & 0 deletions core/src/main/java/io/substrait/relation/NestedLoopJoin.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
package io.substrait.relation;

import io.substrait.expression.Expression;
import io.substrait.proto.NestedLoopJoinRel;
import io.substrait.type.Type;
import io.substrait.type.TypeCreator;
import java.util.Optional;
import java.util.stream.Stream;
import org.immutables.value.Value;

@Value.Immutable
public abstract class NestedLoopJoin extends BiRel implements HasExtension {

public abstract Optional<Expression> getCondition();

public abstract JoinType getJoinType();

public static enum JoinType {
UNKNOWN(NestedLoopJoinRel.JoinType.JOIN_TYPE_UNSPECIFIED),
INNER(NestedLoopJoinRel.JoinType.JOIN_TYPE_INNER),
OUTER(NestedLoopJoinRel.JoinType.JOIN_TYPE_OUTER),
LEFT(NestedLoopJoinRel.JoinType.JOIN_TYPE_LEFT),
RIGHT(NestedLoopJoinRel.JoinType.JOIN_TYPE_RIGHT),
LEFT_SEMI(NestedLoopJoinRel.JoinType.JOIN_TYPE_LEFT_SEMI),
RIGHT_SEMI(NestedLoopJoinRel.JoinType.JOIN_TYPE_RIGHT_SEMI),
LEFT_ANTI(NestedLoopJoinRel.JoinType.JOIN_TYPE_LEFT_ANTI),
RIGHT_ANTI(NestedLoopJoinRel.JoinType.JOIN_TYPE_RIGHT_ANTI);

private NestedLoopJoinRel.JoinType proto;

JoinType(NestedLoopJoinRel.JoinType proto) {
this.proto = proto;
}

public NestedLoopJoinRel.JoinType toProto() {
return proto;
}

public static JoinType fromProto(NestedLoopJoinRel.JoinType proto) {
for (var v : values()) {
if (v.proto == proto) {
return v;
}
}

throw new IllegalArgumentException("Unknown type: " + proto);
}
}

@Override
protected Type.Struct deriveRecordType() {
Stream<Type> leftTypes =
switch (getJoinType()) {
case RIGHT, RIGHT_SEMI, RIGHT_ANTI, OUTER -> getLeft().getRecordType().fields().stream()
.map(TypeCreator::asNullable);
default -> getLeft().getRecordType().fields().stream();
};
Stream<Type> rightTypes =
switch (getJoinType()) {
case LEFT, LEFT_SEMI, LEFT_ANTI, OUTER -> getRight().getRecordType().fields().stream()
.map(TypeCreator::asNullable);
default -> getRight().getRecordType().fields().stream();
};
return TypeCreator.REQUIRED.struct(Stream.concat(leftTypes, rightTypes));
}

@Override
public <O, E extends Exception> O accept(RelVisitor<O, E> visitor) throws E {
return visitor.visit(this);
}

public static ImmutableNestedLoopJoin.Builder builder() {
return ImmutableNestedLoopJoin.builder();
}
}
27 changes: 27 additions & 0 deletions core/src/main/java/io/substrait/relation/ProtoRelConverter.java
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import io.substrait.proto.FetchRel;
import io.substrait.proto.FilterRel;
import io.substrait.proto.JoinRel;
import io.substrait.proto.NestedLoopJoinRel;
import io.substrait.proto.ProjectRel;
import io.substrait.proto.ReadRel;
import io.substrait.proto.SetRel;
Expand Down Expand Up @@ -77,6 +78,9 @@ public Rel from(io.substrait.proto.Rel rel) {
case JOIN -> {
return newJoin(rel.getJoin());
}
case NESTED_LOOP_JOIN -> {
return newNestedLoopJoin(rel.getNestedLoopJoin());
}
case SET -> {
return newSet(rel.getSet());
}
Expand Down Expand Up @@ -460,6 +464,29 @@ private Join newJoin(JoinRel rel) {
return builder.build();
}

private NestedLoopJoin newNestedLoopJoin(NestedLoopJoinRel rel) {
Rel left = from(rel.getLeft());
Rel right = from(rel.getRight());
Type.Struct leftStruct = left.getRecordType();
Type.Struct rightStruct = right.getRecordType();
Type.Struct unionedStruct = Type.Struct.builder().from(leftStruct).from(rightStruct).build();
var converter = new ProtoExpressionConverter(lookup, extensions, unionedStruct, this);
var builder =
NestedLoopJoin.builder()
.left(left)
.right(right)
.condition(converter.from(rel.getExpression()))
.joinType(NestedLoopJoin.JoinType.fromProto(rel.getType()));

builder
.commonExtension(optionalAdvancedExtension(rel.getCommon()))
.remap(optionalRelmap(rel.getCommon()));
if (rel.hasAdvancedExtension()) {
builder.extension(advancedExtension(rel.getAdvancedExtension()));
}
return builder.build();
}

private Rel newCross(CrossRel rel) {
Rel left = from(rel.getLeft());
Rel right = from(rel.getRight());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,25 @@ public Optional<Rel> visit(Join join) throws RuntimeException {
.build());
}

@Override
public Optional<Rel> visit(NestedLoopJoin nestedLoopJoin) throws RuntimeException {
var left = nestedLoopJoin.getLeft().accept(this);
var right = nestedLoopJoin.getRight().accept(this);
var condition = nestedLoopJoin.getCondition().flatMap(t -> visitExpression(t));
if (allEmpty(left, right, condition)) {
return Optional.empty();
}
return Optional.of(
ImmutableNestedLoopJoin.builder()
.from(nestedLoopJoin)
.left(left.orElse(nestedLoopJoin.getLeft()))
.right(right.orElse(nestedLoopJoin.getRight()))
.condition(
Optional.ofNullable(
condition.orElseGet(() -> nestedLoopJoin.getCondition().orElse(null))))
.build());
}

@Override
public Optional<Rel> visit(Set set) throws RuntimeException {
return transformList(set.getInputs(), t -> t.accept(this))
Expand Down
16 changes: 16 additions & 0 deletions core/src/main/java/io/substrait/relation/RelProtoConverter.java
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import io.substrait.proto.FetchRel;
import io.substrait.proto.FilterRel;
import io.substrait.proto.JoinRel;
import io.substrait.proto.NestedLoopJoinRel;
import io.substrait.proto.ProjectRel;
import io.substrait.proto.ReadRel;
import io.substrait.proto.Rel;
Expand Down Expand Up @@ -170,6 +171,21 @@ public Rel visit(Join join) throws RuntimeException {
return Rel.newBuilder().setJoin(builder).build();
}

@Override
public Rel visit(NestedLoopJoin nestedLoopJoin) throws RuntimeException {
var builder =
NestedLoopJoinRel.newBuilder()
.setCommon(common(nestedLoopJoin))
.setLeft(toProto(nestedLoopJoin.getLeft()))
.setRight(toProto(nestedLoopJoin.getRight()))
.setType(nestedLoopJoin.getJoinType().toProto());

nestedLoopJoin.getCondition().ifPresent(t -> builder.setExpression(toProto(t)));

nestedLoopJoin.getExtension().ifPresent(ae -> builder.setAdvancedExtension(ae.toProto()));
return Rel.newBuilder().setNestedLoopJoin(builder).build();
}

@Override
public Rel visit(Set set) throws RuntimeException {
var builder = SetRel.newBuilder().setCommon(common(set)).setOp(set.getSetOp().toProto());
Expand Down
2 changes: 2 additions & 0 deletions core/src/main/java/io/substrait/relation/RelVisitor.java
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ public interface RelVisitor<OUTPUT, EXCEPTION extends Exception> {

OUTPUT visit(Join join) throws EXCEPTION;

OUTPUT visit(NestedLoopJoin nestedLoopJoin) throws EXCEPTION;

OUTPUT visit(Set set) throws EXCEPTION;

OUTPUT visit(NamedScan namedScan) throws EXCEPTION;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import io.substrait.relation.Join;
import io.substrait.relation.LocalFiles;
import io.substrait.relation.NamedScan;
import io.substrait.relation.NestedLoopJoin;
import io.substrait.relation.Project;
import io.substrait.relation.ProtoRelConverter;
import io.substrait.relation.Rel;
Expand Down Expand Up @@ -174,6 +175,19 @@ void join() {
verifyRoundTrip(rel);
}

@Test
void nested_loop_join() {
Rel rel =
NestedLoopJoin.builder()
.from(
b.nestedLoopJoin(
__ -> b.bool(true), NestedLoopJoin.JoinType.INNER, commonTable, commonTable))
.commonExtension(commonExtension)
.extension(relExtension)
.build();
verifyRoundTrip(rel);
}

@Test
void project() {
Rel rel =
Expand Down

0 comments on commit c88d941

Please sign in to comment.