Skip to content

Commit

Permalink
fix: addressing reviews v1
Browse files Browse the repository at this point in the history
  • Loading branch information
vibhatha committed Oct 19, 2023
1 parent d6b3165 commit da6339e
Show file tree
Hide file tree
Showing 5 changed files with 39 additions and 29 deletions.
16 changes: 11 additions & 5 deletions core/src/main/java/io/substrait/dsl/SubstraitBuilder.java
Original file line number Diff line number Diff line change
Expand Up @@ -167,22 +167,28 @@ private Join join(
}

public HashJoin hashJoin(
int[] leftKeys, int[] rightKeys, HashJoin.JoinType joinType, Rel left, Rel right) {
List<Integer> leftKeys,
List<Integer> rightKeys,
HashJoin.JoinType joinType,
Rel left,
Rel right) {
return hashJoin(leftKeys, rightKeys, joinType, Optional.empty(), left, right);
}

public HashJoin hashJoin(
int[] leftKeys,
int[] rightKeys,
List<Integer> leftKeys,
List<Integer> rightKeys,
HashJoin.JoinType joinType,
Optional<Rel.Remap> remap,
Rel left,
Rel right) {
return HashJoin.builder()
.left(left)
.right(right)
.leftKeys(this.fieldReferences(left, leftKeys))
.rightKeys(this.fieldReferences(right, rightKeys))
.leftKeys(
this.fieldReferences(left, leftKeys.stream().mapToInt(Integer::intValue).toArray()))
.rightKeys(
this.fieldReferences(right, rightKeys.stream().mapToInt(Integer::intValue).toArray()))
.joinType(joinType)
.remap(remap)
.build();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ public Optional<Rel> visit(HashJoin hashJoin) throws RuntimeException {
var leftKeys = hashJoin.getLeftKeys();
var rightKeys = hashJoin.getRightKeys();
var postFilter = hashJoin.getPostJoinFilter().flatMap(t -> visitExpression(t));
if (allEmpty(left, right, leftKeys, rightKeys, postFilter)) {
if (allEmpty(left, right, postFilter)) {
return Optional.empty();
}
return Optional.of(
Expand Down
30 changes: 14 additions & 16 deletions core/src/main/java/io/substrait/relation/RelProtoConverter.java
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,8 @@ public Rel visit(Join join) throws RuntimeException {

join.getCondition().ifPresent(t -> builder.setExpression(toProto(t)));

join.getPostJoinFilter().ifPresent(t -> builder.setPostJoinFilter(toProto(t)));

join.getExtension().ifPresent(ae -> builder.setAdvancedExtension(ae.toProto()));
return Rel.newBuilder().setJoin(builder).build();
}
Expand Down Expand Up @@ -244,23 +246,19 @@ public Rel visit(HashJoin hashJoin) throws RuntimeException {
.setRight(toProto(hashJoin.getRight()))
.setType(hashJoin.getJoinType().toProto());

hashJoin
.getLeftKeys()
.ifPresent(
keys -> {
for (int i = 0; i < keys.size(); i++) {
builder.setLeftKeys(i, toProto(keys.get(i)));
}
});
List<FieldReference> leftKeys = hashJoin.getLeftKeys();
List<FieldReference> rightKeys = hashJoin.getRightKeys();

hashJoin
.getRightKeys()
.ifPresent(
keys -> {
for (int i = 0; i < keys.size(); i++) {
builder.setRightKeys(i, toProto(keys.get(i)));
}
});
if (leftKeys.size() != rightKeys.size()) {
throw new RuntimeException("Number of left and right keys must be equal.");
}

for (int idx = 0; idx < hashJoin.getLeftKeys().size(); idx++) {
builder.setLeftKeys(idx, toProto(leftKeys.get(idx)));
builder.setRightKeys(idx, toProto(rightKeys.get(idx)));
}

hashJoin.getPostJoinFilter().ifPresent(t -> builder.setPostJoinFilter(toProto(t)));

hashJoin.getExtension().ifPresent(ae -> builder.setAdvancedExtension(ae.toProto()));
return Rel.newBuilder().setHashJoin(builder).build();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,9 @@
@Value.Immutable
public abstract class HashJoin extends BiRel implements HasExtension {

public abstract Optional<List<FieldReference>> getLeftKeys();
public abstract List<FieldReference> getLeftKeys();

public abstract Optional<List<FieldReference>> getRightKeys();
public abstract List<FieldReference> getRightKeys();

public abstract JoinType getJoinType();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
import io.substrait.type.Type;
import io.substrait.type.TypeCreator;
import java.util.Collections;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.junit.jupiter.api.Nested;
Expand Down Expand Up @@ -177,17 +178,22 @@ void join() {

@Test
void hashJoin() {
int[] left_keys = {};
int[] right_keys = {};
Rel rel =
// with empty keys
List<Integer> leftEmptyKeys = Collections.emptyList();
List<Integer> rightEmptyKeys = Collections.emptyList();
Rel relWithoutKeys =
HashJoin.builder()
.from(
b.hashJoin(
left_keys, right_keys, HashJoin.JoinType.INNER, commonTable, commonTable))
leftEmptyKeys,
rightEmptyKeys,
HashJoin.JoinType.INNER,
commonTable,
commonTable))
.commonExtension(commonExtension)
.extension(relExtension)
.build();
verifyRoundTrip(rel);
verifyRoundTrip(relWithoutKeys);
}

@Test
Expand Down

0 comments on commit da6339e

Please sign in to comment.