Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: initial HashJoinRel support #187

Merged
merged 12 commits into from
Oct 25, 2023
29 changes: 29 additions & 0 deletions core/src/main/java/io/substrait/dsl/SubstraitBuilder.java
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import io.substrait.relation.Rel;
import io.substrait.relation.Set;
import io.substrait.relation.Sort;
import io.substrait.relation.physical.HashJoin;
import io.substrait.type.ImmutableType;
import io.substrait.type.NamedStruct;
import io.substrait.type.Type;
Expand Down Expand Up @@ -165,6 +166,34 @@ private Join join(
.build();
}

/**
 * Builds a {@link HashJoin} of {@code left} and {@code right} without an output remap.
 *
 * <p>Delegates to the overload that takes an {@code Optional<Rel.Remap>}, supplying an empty
 * one.
 *
 * @param leftKeys field indexes used to build the left key references
 * @param rightKeys field indexes used to build the right key references
 * @param joinType the kind of join to perform
 * @param left the left input relation
 * @param right the right input relation
 * @return the constructed hash join
 */
public HashJoin hashJoin(
    List<Integer> leftKeys,
    List<Integer> rightKeys,
    HashJoin.JoinType joinType,
    Rel left,
    Rel right) {
  Optional<Rel.Remap> noRemap = Optional.empty();
  return hashJoin(leftKeys, rightKeys, joinType, noRemap, left, right);
}

/**
 * Builds a {@link HashJoin} of {@code left} and {@code right}.
 *
 * <p>The integer key lists are unboxed and handed to {@code fieldReferences} together with the
 * matching input relation to obtain the key references stored on the join.
 *
 * @param leftKeys field indexes used to build the left key references
 * @param rightKeys field indexes used to build the right key references
 * @param joinType the kind of join to perform
 * @param remap optional output remapping applied to the join
 * @param left the left input relation
 * @param right the right input relation
 * @return the constructed hash join
 */
public HashJoin hashJoin(
    List<Integer> leftKeys,
    List<Integer> rightKeys,
    HashJoin.JoinType joinType,
    Optional<Rel.Remap> remap,
    Rel left,
    Rel right) {
  // fieldReferences takes primitive indexes, so unbox the lists up front.
  int[] leftIndexes = leftKeys.stream().mapToInt(Integer::intValue).toArray();
  int[] rightIndexes = rightKeys.stream().mapToInt(Integer::intValue).toArray();

  return HashJoin.builder()
      .left(left)
      .right(right)
      .leftKeys(this.fieldReferences(left, leftIndexes))
      .rightKeys(this.fieldReferences(right, rightIndexes))
      .joinType(joinType)
      .remap(remap)
      .build();
}

public NamedScan namedScan(
Iterable<String> tableName, Iterable<String> columnNames, Iterable<Type> types) {
return namedScan(tableName, columnNames, types, Optional.empty());
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
package io.substrait.relation;

import io.substrait.relation.physical.HashJoin;

public abstract class AbstractRelVisitor<OUTPUT, EXCEPTION extends Exception>
implements RelVisitor<OUTPUT, EXCEPTION> {
public abstract OUTPUT visitFallback(Rel rel);
Expand Down Expand Up @@ -83,4 +85,9 @@ public OUTPUT visit(ExtensionMulti extensionMulti) throws EXCEPTION {
public OUTPUT visit(ExtensionTable extensionTable) throws EXCEPTION {
// No ExtensionTable-specific handling in this base class; defer to the shared fallback.
return visitFallback(extensionTable);
}

/**
 * Handles a {@link HashJoin} by delegating to {@link #visitFallback(Rel)}.
 *
 * @param hashJoin the hash join being visited
 * @return whatever {@code visitFallback} returns for this relation
 */
@Override
public OUTPUT visit(HashJoin hashJoin) throws EXCEPTION {
return visitFallback(hashJoin);
}
}
37 changes: 37 additions & 0 deletions core/src/main/java/io/substrait/relation/ProtoRelConverter.java
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import io.substrait.proto.ExtensionSingleRel;
import io.substrait.proto.FetchRel;
import io.substrait.proto.FilterRel;
import io.substrait.proto.HashJoinRel;
import io.substrait.proto.JoinRel;
import io.substrait.proto.ProjectRel;
import io.substrait.proto.ReadRel;
Expand All @@ -27,6 +28,7 @@
import io.substrait.relation.files.FileOrFiles;
import io.substrait.relation.files.ImmutableFileFormat;
import io.substrait.relation.files.ImmutableFileOrFiles;
import io.substrait.relation.physical.HashJoin;
import io.substrait.type.ImmutableNamedStruct;
import io.substrait.type.NamedStruct;
import io.substrait.type.Type;
Expand Down Expand Up @@ -95,6 +97,9 @@ public Rel from(io.substrait.proto.Rel rel) {
case EXTENSION_MULTI -> {
return newExtensionMulti(rel.getExtensionMulti());
}
case HASH_JOIN -> {
return newHashJoin(rel.getHashJoin());
}
default -> {
throw new UnsupportedOperationException("Unsupported RelTypeCase of " + relType);
}
Expand Down Expand Up @@ -490,6 +495,38 @@ private Set newSet(SetRel rel) {
return builder.build();
}

/**
 * Converts a protobuf {@link HashJoinRel} into a {@link HashJoin}.
 *
 * <p>Both inputs are converted first. Left keys are then converted against the left input's
 * record type, right keys against the right input's record type, and the optional post-join
 * filter against a struct built from both record types.
 *
 * @param rel the protobuf hash join to convert
 * @return the converted relation
 */
private Rel newHashJoin(HashJoinRel rel) {
  Rel leftInput = from(rel.getLeft());
  Rel rightInput = from(rel.getRight());

  Type.Struct leftType = leftInput.getRecordType();
  Type.Struct rightType = rightInput.getRecordType();
  Type.Struct combinedType = Type.Struct.builder().from(leftType).from(rightType).build();

  // One converter per scope: each side's keys, plus the combined struct for the filter.
  var leftKeyConverter = new ProtoExpressionConverter(lookup, extensions, leftType, this);
  var rightKeyConverter = new ProtoExpressionConverter(lookup, extensions, rightType, this);
  var filterConverter = new ProtoExpressionConverter(lookup, extensions, combinedType, this);

  var hashJoin =
      HashJoin.builder()
          .left(leftInput)
          .right(rightInput)
          .leftKeys(
              rel.getLeftKeysList().stream()
                  .map(leftKeyConverter::from)
                  .collect(Collectors.toList()))
          .rightKeys(
              rel.getRightKeysList().stream()
                  .map(rightKeyConverter::from)
                  .collect(Collectors.toList()))
          .joinType(HashJoin.JoinType.fromProto(rel.getType()))
          .postJoinFilter(
              Optional.ofNullable(
                  rel.hasPostJoinFilter()
                      ? filterConverter.from(rel.getPostJoinFilter())
                      : null))
          .commonExtension(optionalAdvancedExtension(rel.getCommon()))
          .remap(optionalRelmap(rel.getCommon()));

  if (rel.hasAdvancedExtension()) {
    hashJoin.extension(advancedExtension(rel.getAdvancedExtension()));
  }
  return hashJoin.build();
}

private static Optional<Rel.Remap> optionalRelmap(io.substrait.proto.RelCommon relCommon) {
return Optional.ofNullable(
relCommon.hasEmit() ? Rel.Remap.of(relCommon.getEmit().getOutputMappingList()) : null);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
import io.substrait.expression.FieldReference;
import io.substrait.expression.FunctionArg;
import io.substrait.expression.ImmutableFieldReference;
import io.substrait.relation.physical.HashJoin;
import io.substrait.relation.physical.ImmutableHashJoin;
import io.substrait.type.Type;
import java.util.ArrayList;
import java.util.List;
Expand Down Expand Up @@ -166,6 +168,29 @@ public Optional<Rel> visit(Cross cross) throws RuntimeException {
.build());
}

@Override
public Optional<Rel> visit(HashJoin hashJoin) throws RuntimeException {
var left = hashJoin.getLeft().accept(this);
var right = hashJoin.getRight().accept(this);
var leftKeys = hashJoin.getLeftKeys();
var rightKeys = hashJoin.getRightKeys();
var postFilter = hashJoin.getPostJoinFilter().flatMap(t -> visitExpression(t));
if (allEmpty(left, right, postFilter)) {
return Optional.empty();
}
vbarua marked this conversation as resolved.
Show resolved Hide resolved
return Optional.of(
ImmutableHashJoin.builder()
.from(hashJoin)
.left(left.orElse(hashJoin.getLeft()))
.right(right.orElse(hashJoin.getRight()))
.leftKeys(leftKeys)
.rightKeys(rightKeys)
.postJoinFilter(
Optional.ofNullable(
postFilter.orElseGet(() -> hashJoin.getPostJoinFilter().orElse(null))))
vbarua marked this conversation as resolved.
Show resolved Hide resolved
.build());
}

private Optional<Expression> visitExpression(Expression expression) {
ExpressionVisitor<Optional<Expression>, RuntimeException> visitor =
new AbstractExpressionVisitor<>() {
Expand Down
34 changes: 34 additions & 0 deletions core/src/main/java/io/substrait/relation/RelProtoConverter.java
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package io.substrait.relation;

import io.substrait.expression.Expression;
import io.substrait.expression.FieldReference;
import io.substrait.expression.FunctionArg;
import io.substrait.expression.proto.ExpressionProtoConverter;
import io.substrait.extension.ExtensionCollector;
Expand All @@ -12,6 +13,7 @@
import io.substrait.proto.ExtensionSingleRel;
import io.substrait.proto.FetchRel;
import io.substrait.proto.FilterRel;
import io.substrait.proto.HashJoinRel;
import io.substrait.proto.JoinRel;
import io.substrait.proto.ProjectRel;
import io.substrait.proto.ReadRel;
Expand All @@ -21,6 +23,7 @@
import io.substrait.proto.SortField;
import io.substrait.proto.SortRel;
import io.substrait.relation.files.FileOrFiles;
import io.substrait.relation.physical.HashJoin;
import io.substrait.type.proto.TypeProtoConverter;
import java.util.Collection;
import java.util.List;
Expand Down Expand Up @@ -68,6 +71,10 @@ private List<SortField> toProtoS(Collection<Expression.SortField> sorts) {
.collect(java.util.stream.Collectors.toList());
}

/**
 * Converts a {@link FieldReference} to its protobuf field-reference message.
 *
 * <p>The reference is first run through {@code exprProtoConverter}; the {@code selection} of the
 * resulting protobuf expression is what gets returned.
 */
private io.substrait.proto.Expression.FieldReference toProto(FieldReference fieldReference) {
  var asExpression = fieldReference.accept(exprProtoConverter);
  return asExpression.getSelection();
}

@Override
public Rel visit(Aggregate aggregate) throws RuntimeException {
var builder =
Expand Down Expand Up @@ -166,6 +173,8 @@ public Rel visit(Join join) throws RuntimeException {

join.getCondition().ifPresent(t -> builder.setExpression(toProto(t)));

join.getPostJoinFilter().ifPresent(t -> builder.setPostJoinFilter(toProto(t)));

join.getExtension().ifPresent(ae -> builder.setAdvancedExtension(ae.toProto()));
return Rel.newBuilder().setJoin(builder).build();
}
Expand Down Expand Up @@ -228,6 +237,31 @@ public Rel visit(ExtensionTable extensionTable) throws RuntimeException {
return Rel.newBuilder().setRead(builder).build();
}

@Override
public Rel visit(HashJoin hashJoin) throws RuntimeException {
var builder =
HashJoinRel.newBuilder()
.setCommon(common(hashJoin))
.setLeft(toProto(hashJoin.getLeft()))
.setRight(toProto(hashJoin.getRight()))
.setType(hashJoin.getJoinType().toProto());

List<FieldReference> leftKeys = hashJoin.getLeftKeys();
List<FieldReference> rightKeys = hashJoin.getRightKeys();

if (leftKeys.size() != rightKeys.size()) {
throw new RuntimeException("Number of left and right keys must be equal.");
}

builder.addAllLeftKeys(leftKeys.stream().map(this::toProto).collect(Collectors.toList()));
builder.addAllRightKeys(rightKeys.stream().map(this::toProto).collect(Collectors.toList()));

hashJoin.getPostJoinFilter().ifPresent(t -> builder.setPostJoinFilter(toProto(t)));

vbarua marked this conversation as resolved.
Show resolved Hide resolved
hashJoin.getExtension().ifPresent(ae -> builder.setAdvancedExtension(ae.toProto()));
return Rel.newBuilder().setHashJoin(builder).build();
}

@Override
public Rel visit(Project project) throws RuntimeException {
var builder =
Expand Down
4 changes: 4 additions & 0 deletions core/src/main/java/io/substrait/relation/RelVisitor.java
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
package io.substrait.relation;

import io.substrait.relation.physical.HashJoin;

public interface RelVisitor<OUTPUT, EXCEPTION extends Exception> {
OUTPUT visit(Aggregate aggregate) throws EXCEPTION;

Expand Down Expand Up @@ -32,4 +34,6 @@ public interface RelVisitor<OUTPUT, EXCEPTION extends Exception> {
OUTPUT visit(ExtensionMulti extensionMulti) throws EXCEPTION;

OUTPUT visit(ExtensionTable extensionTable) throws EXCEPTION;

/** Visits a {@link HashJoin} relation. */
OUTPUT visit(HashJoin hashJoin) throws EXCEPTION;
}
85 changes: 85 additions & 0 deletions core/src/main/java/io/substrait/relation/physical/HashJoin.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
package io.substrait.relation.physical;

import io.substrait.expression.Expression;
import io.substrait.expression.FieldReference;
import io.substrait.proto.HashJoinRel;
import io.substrait.relation.BiRel;
import io.substrait.relation.HasExtension;
import io.substrait.relation.RelVisitor;
import io.substrait.type.Type;
import io.substrait.type.TypeCreator;
import java.util.List;
import java.util.Optional;
import java.util.stream.Stream;
import org.immutables.value.Value;

/**
 * Relation describing a hash join of two inputs, convertible to and from the protobuf {@link
 * HashJoinRel}.
 *
 * <p>The concrete implementation and its builder ({@code ImmutableHashJoin}) are generated by
 * the Immutables annotation processor; obtain a builder via {@link #builder()}.
 */
@Value.Immutable
public abstract class HashJoin extends BiRel implements HasExtension {

  /** Returns the key field references for the left input. */
  public abstract List<FieldReference> getLeftKeys();

  /** Returns the key field references for the right input. */
  public abstract List<FieldReference> getRightKeys();

  /** Returns the kind of join to perform. */
  public abstract JoinType getJoinType();

  /** Returns the filter applied after the join, if any. */
  public abstract Optional<Expression> getPostJoinFilter();

  /** Join kinds supported by {@link HashJoin}, each paired with its protobuf counterpart. */
  public enum JoinType {
    UNKNOWN(HashJoinRel.JoinType.JOIN_TYPE_UNSPECIFIED),
    INNER(HashJoinRel.JoinType.JOIN_TYPE_INNER),
    OUTER(HashJoinRel.JoinType.JOIN_TYPE_OUTER),
    LEFT(HashJoinRel.JoinType.JOIN_TYPE_LEFT),
    RIGHT(HashJoinRel.JoinType.JOIN_TYPE_RIGHT),
    LEFT_SEMI(HashJoinRel.JoinType.JOIN_TYPE_LEFT_SEMI),
    RIGHT_SEMI(HashJoinRel.JoinType.JOIN_TYPE_RIGHT_SEMI),
    LEFT_ANTI(HashJoinRel.JoinType.JOIN_TYPE_LEFT_ANTI),
    RIGHT_ANTI(HashJoinRel.JoinType.JOIN_TYPE_RIGHT_ANTI);

    // Set once in the constructor and never reassigned.
    private final HashJoinRel.JoinType proto;

    JoinType(HashJoinRel.JoinType proto) {
      this.proto = proto;
    }

    /**
     * Looks up the constant paired with the given protobuf join type.
     *
     * @param proto the protobuf join type
     * @return the matching constant
     * @throws IllegalArgumentException if no constant is paired with {@code proto}
     */
    public static JoinType fromProto(HashJoinRel.JoinType proto) {
      for (var v : values()) {
        if (v.proto == proto) {
          return v;
        }
      }
      throw new IllegalArgumentException("Unknown type: " + proto);
    }

    /** Returns the protobuf join type paired with this constant. */
    public HashJoinRel.JoinType toProto() {
      return proto;
    }
  }

  /**
   * Derives the output record type from the two inputs and the join type.
   *
   * <p>Left fields are passed through {@code TypeCreator::asNullable} for {@code RIGHT} and
   * {@code OUTER} joins, omitted for {@code RIGHT_ANTI} and {@code RIGHT_SEMI}, and used as-is
   * otherwise. Right fields are treated symmetrically: made nullable for {@code LEFT} and {@code
   * OUTER}, omitted for {@code LEFT_ANTI} and {@code LEFT_SEMI}, and used as-is otherwise. The
   * result is a struct created via {@code TypeCreator.REQUIRED} holding the left fields followed
   * by the right fields.
   */
  @Override
  protected Type.Struct deriveRecordType() {
    Stream<Type> leftTypes =
        switch (getJoinType()) {
          case RIGHT, OUTER -> getLeft().getRecordType().fields().stream()
              .map(TypeCreator::asNullable);
          case RIGHT_ANTI, RIGHT_SEMI -> Stream.empty();
          default -> getLeft().getRecordType().fields().stream();
        };
    Stream<Type> rightTypes =
        switch (getJoinType()) {
          case LEFT, OUTER -> getRight().getRecordType().fields().stream()
              .map(TypeCreator::asNullable);
          case LEFT_ANTI, LEFT_SEMI -> Stream.empty();
          default -> getRight().getRecordType().fields().stream();
        };
    return TypeCreator.REQUIRED.struct(Stream.concat(leftTypes, rightTypes));
  }

  @Override
  public <O, E extends Exception> O accept(RelVisitor<O, E> visitor) throws E {
    return visitor.visit(this);
  }

  /** Returns a new builder for the generated immutable implementation. */
  public static ImmutableHashJoin.Builder builder() {
    return ImmutableHashJoin.builder();
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,14 @@
import io.substrait.relation.Set;
import io.substrait.relation.Sort;
import io.substrait.relation.VirtualTableScan;
import io.substrait.relation.physical.HashJoin;
import io.substrait.relation.utils.StringHolder;
import io.substrait.relation.utils.StringHolderHandlingProtoRelConverter;
import io.substrait.type.NamedStruct;
import io.substrait.type.Type;
import io.substrait.type.TypeCreator;
import java.util.Collections;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.junit.jupiter.api.Nested;
Expand Down Expand Up @@ -174,6 +176,26 @@ void join() {
verifyRoundTrip(rel);
}

/** Round-trips an inner hash join that has no join keys but carries both extensions. */
@Test
void hashJoin() {
  // Degenerate case: a hash join with empty key lists on both sides.
  List<Integer> noKeys = Collections.emptyList();
  HashJoin keylessJoin =
      b.hashJoin(noKeys, noKeys, HashJoin.JoinType.INNER, commonTable, commonTable);

  Rel relWithoutKeys =
      HashJoin.builder()
          .from(keylessJoin)
          .commonExtension(commonExtension)
          .extension(relExtension)
          .build();

  verifyRoundTrip(relWithoutKeys);
}

@Test
void project() {
Rel rel =
Expand Down
Loading