From 0b4ba3bf8f221f0a08457d29a0215bec9976329b Mon Sep 17 00:00:00 2001 From: Vibhatha Lakmal Abeykoon Date: Wed, 4 Oct 2023 04:40:38 +0530 Subject: [PATCH 01/12] feat: initial commit --- .../io/substrait/relation/AbstractJoin.java | 9 +++ .../main/java/io/substrait/relation/Join.java | 45 +++++------ .../relation/RelCopyOnWriteVisitor.java | 25 ++++++ .../substrait/relation/RelProtoConverter.java | 32 ++++---- .../io/substrait/relation/RelVisitor.java | 4 + .../substrait/relation/physical/HashJoin.java | 79 +++++++++++++++++++ 6 files changed, 155 insertions(+), 39 deletions(-) create mode 100644 core/src/main/java/io/substrait/relation/AbstractJoin.java create mode 100644 core/src/main/java/io/substrait/relation/physical/HashJoin.java diff --git a/core/src/main/java/io/substrait/relation/AbstractJoin.java b/core/src/main/java/io/substrait/relation/AbstractJoin.java new file mode 100644 index 00000000..c542e22f --- /dev/null +++ b/core/src/main/java/io/substrait/relation/AbstractJoin.java @@ -0,0 +1,9 @@ +package io.substrait.relation; + +import io.substrait.expression.Expression; +import java.util.Optional; + +public abstract class AbstractJoin extends BiRel implements HasExtension { + + public abstract Optional getPostJoinFilter(); +} diff --git a/core/src/main/java/io/substrait/relation/Join.java b/core/src/main/java/io/substrait/relation/Join.java index b571c732..1aa48c57 100644 --- a/core/src/main/java/io/substrait/relation/Join.java +++ b/core/src/main/java/io/substrait/relation/Join.java @@ -9,14 +9,29 @@ import org.immutables.value.Value; @Value.Immutable -public abstract class Join extends BiRel implements HasExtension { +public abstract class Join extends AbstractJoin { public abstract Optional getCondition(); - public abstract Optional getPostJoinFilter(); - public abstract JoinType getJoinType(); + @Override + protected Type.Struct deriveRecordType() { + Stream leftTypes = + switch (getJoinType()) { + case RIGHT, OUTER -> getLeft().getRecordType().fields().stream() + .map(TypeCreator::asNullable); + default -> getLeft().getRecordType().fields().stream(); + }; + Stream rightTypes = + switch (getJoinType()) { + case LEFT, OUTER -> getRight().getRecordType().fields().stream() + .map(TypeCreator::asNullable); + default -> getRight().getRecordType().fields().stream(); + }; + return TypeCreator.REQUIRED.struct(Stream.concat(leftTypes, rightTypes)); + } + public static enum JoinType { UNKNOWN(JoinRel.JoinType.JOIN_TYPE_UNSPECIFIED), INNER(JoinRel.JoinType.JOIN_TYPE_INNER), @@ -32,36 +47,18 @@ public static enum JoinType { this.proto = proto; } - public JoinRel.JoinType toProto() { - return proto; - } - public static JoinType fromProto(JoinRel.JoinType proto) { for (var v : values()) { if (v.proto == proto) { return v; } } - throw new IllegalArgumentException("Unknown type: " + proto); } - } - @Override - protected Type.Struct deriveRecordType() { - Stream leftTypes = - switch (getJoinType()) { - case RIGHT, OUTER -> getLeft().getRecordType().fields().stream() - .map(TypeCreator::asNullable); - default -> getLeft().getRecordType().fields().stream(); - }; - Stream rightTypes = - switch (getJoinType()) { - case LEFT, OUTER -> getRight().getRecordType().fields().stream() - .map(TypeCreator::asNullable); - default -> getRight().getRecordType().fields().stream(); - }; - return TypeCreator.REQUIRED.struct(Stream.concat(leftTypes, rightTypes)); + public JoinRel.JoinType toProto() { + return proto; + } } @Override diff --git a/core/src/main/java/io/substrait/relation/RelCopyOnWriteVisitor.java b/core/src/main/java/io/substrait/relation/RelCopyOnWriteVisitor.java index 14ef98a0..3463d814 100644 --- a/core/src/main/java/io/substrait/relation/RelCopyOnWriteVisitor.java +++ b/core/src/main/java/io/substrait/relation/RelCopyOnWriteVisitor.java @@ -6,6 +6,8 @@ import io.substrait.expression.FieldReference; import io.substrait.expression.FunctionArg; import io.substrait.expression.ImmutableFieldReference; +import io.substrait.relation.physical.HashJoin; +import io.substrait.relation.physical.ImmutableHashJoin; import io.substrait.type.Type; import java.util.ArrayList; import java.util.List; @@ -166,6 +168,29 @@ public Optional visit(Cross cross) throws RuntimeException { .build()); } + @Override + public Optional visit(HashJoin hashJoin) throws RuntimeException { + var left = hashJoin.getLeft().accept(this); + var right = hashJoin.getRight().accept(this); + // var condition = join.getCondition().flatMap(t -> visitExpression(t)); + var postFilter = hashJoin.getPostJoinFilter().flatMap(t -> visitExpression(t)); + // if (allEmpty(left, right, condition, postFilter)) { + // return Optional.empty(); + // } + return Optional.of( + ImmutableHashJoin.builder() + .from(hashJoin) + .left(left.orElse(hashJoin.getLeft())) + .right(right.orElse(hashJoin.getRight())) + // .condition( + // Optional.ofNullable(condition.orElseGet(() -> + // join.getCondition().orElse(null)))) + .postJoinFilter( + Optional.ofNullable( + postFilter.orElseGet(() -> hashJoin.getPostJoinFilter().orElse(null)))) + .build()); + } + private Optional visitExpression(Expression expression) { ExpressionVisitor, RuntimeException> visitor = new AbstractExpressionVisitor<>() { diff --git a/core/src/main/java/io/substrait/relation/RelProtoConverter.java b/core/src/main/java/io/substrait/relation/RelProtoConverter.java index 25c18103..225e4341 100644 --- a/core/src/main/java/io/substrait/relation/RelProtoConverter.java +++ b/core/src/main/java/io/substrait/relation/RelProtoConverter.java @@ -4,23 +4,10 @@ import io.substrait.expression.FunctionArg; import io.substrait.expression.proto.ExpressionProtoConverter; import io.substrait.extension.ExtensionCollector; -import io.substrait.proto.AggregateFunction; -import io.substrait.proto.AggregateRel; -import io.substrait.proto.CrossRel; -import io.substrait.proto.ExtensionLeafRel; -import io.substrait.proto.ExtensionMultiRel; -import io.substrait.proto.ExtensionSingleRel; -import io.substrait.proto.FetchRel; -import io.substrait.proto.FilterRel; -import io.substrait.proto.JoinRel; -import io.substrait.proto.ProjectRel; -import io.substrait.proto.ReadRel; +import io.substrait.proto.*; import io.substrait.proto.Rel; -import io.substrait.proto.RelCommon; -import io.substrait.proto.SetRel; -import io.substrait.proto.SortField; -import io.substrait.proto.SortRel; import io.substrait.relation.files.FileOrFiles; +import io.substrait.relation.physical.HashJoin; import io.substrait.type.proto.TypeProtoConverter; import java.util.Collection; import java.util.List; @@ -228,6 +215,21 @@ public Rel visit(ExtensionTable extensionTable) throws RuntimeException { return Rel.newBuilder().setRead(builder).build(); } + @Override + public Rel visit(HashJoin hashJoin) throws RuntimeException { + var builder = + HashJoinRel.newBuilder() + .setCommon(common(hashJoin)) + .setLeft(toProto(hashJoin.getLeft())) + .setRight(toProto(hashJoin.getRight())) + .setType(hashJoin.getJoinType().toProto()); + + // hashJoin.getLeftKeys().ifPresent(t -> builder.setLeftKeys()); + + hashJoin.getExtension().ifPresent(ae -> builder.setAdvancedExtension(ae.toProto())); + return Rel.newBuilder().setHashJoin(builder).build(); + } + @Override public Rel visit(Project project) throws RuntimeException { var builder = diff --git a/core/src/main/java/io/substrait/relation/RelVisitor.java b/core/src/main/java/io/substrait/relation/RelVisitor.java index c685cd98..e8e78aaf 100644 --- a/core/src/main/java/io/substrait/relation/RelVisitor.java +++ b/core/src/main/java/io/substrait/relation/RelVisitor.java @@ -1,5 +1,7 @@ package io.substrait.relation; +import io.substrait.relation.physical.HashJoin; + public interface RelVisitor { OUTPUT visit(Aggregate aggregate) throws EXCEPTION; @@ -32,4 +34,6 @@ public interface RelVisitor { OUTPUT visit(ExtensionMulti extensionMulti) throws EXCEPTION; OUTPUT visit(ExtensionTable extensionTable) throws EXCEPTION; + + OUTPUT visit(HashJoin hashJoin) throws EXCEPTION; } diff --git a/core/src/main/java/io/substrait/relation/physical/HashJoin.java b/core/src/main/java/io/substrait/relation/physical/HashJoin.java new file mode 100644 index 00000000..68e297b3 --- /dev/null +++ b/core/src/main/java/io/substrait/relation/physical/HashJoin.java @@ -0,0 +1,79 @@ +package io.substrait.relation.physical; + +import io.substrait.expression.FieldReference; +import io.substrait.proto.HashJoinRel; +import io.substrait.relation.AbstractJoin; +import io.substrait.relation.RelVisitor; +import io.substrait.type.Type; +import io.substrait.type.TypeCreator; +import java.util.List; +import java.util.Optional; +import java.util.stream.Stream; +import org.immutables.value.Value; + +@Value.Immutable +public abstract class HashJoin extends AbstractJoin { + + public abstract Optional> getLeftKeys(); + + public abstract Optional> getRightKeys(); + + public abstract JoinType getJoinType(); + + @Override + protected Type.Struct deriveRecordType() { + Stream leftTypes = + switch (getJoinType()) { + case RIGHT, OUTER -> getLeft().getRecordType().fields().stream() + .map(TypeCreator::asNullable); + default -> getLeft().getRecordType().fields().stream(); + }; + Stream rightTypes = + switch (getJoinType()) { + case LEFT, OUTER -> getRight().getRecordType().fields().stream() + .map(TypeCreator::asNullable); + default -> getRight().getRecordType().fields().stream(); + }; + return TypeCreator.REQUIRED.struct(Stream.concat(leftTypes, rightTypes)); + } + + public static enum JoinType { + UNKNOWN(HashJoinRel.JoinType.JOIN_TYPE_UNSPECIFIED), + INNER(HashJoinRel.JoinType.JOIN_TYPE_INNER), + OUTER(HashJoinRel.JoinType.JOIN_TYPE_OUTER), + LEFT(HashJoinRel.JoinType.JOIN_TYPE_LEFT), + RIGHT(HashJoinRel.JoinType.JOIN_TYPE_RIGHT), + LEFT_SEMI(HashJoinRel.JoinType.JOIN_TYPE_LEFT_SEMI), + RIGHT_SEMI(HashJoinRel.JoinType.JOIN_TYPE_RIGHT_SEMI), + LEFT_ANTI(HashJoinRel.JoinType.JOIN_TYPE_LEFT_ANTI), + RIGHT_ANTO(HashJoinRel.JoinType.JOIN_TYPE_RIGHT_ANTI); + + private HashJoinRel.JoinType proto; + + JoinType(HashJoinRel.JoinType proto) { + this.proto = proto; + } + + public static JoinType fromProto(HashJoinRel.JoinType proto) { + for (var v : values()) { + if (v.proto == proto) { + return v; + } + } + throw new IllegalArgumentException("Unknown type: " + proto); + } + + public HashJoinRel.JoinType toProto() { + return proto; + } + } + + @Override + public O accept(RelVisitor visitor) throws E { + return visitor.visit(this); + } + + public static ImmutableHashJoin.Builder builder() { + return ImmutableHashJoin.builder(); + } +} From 985b10207fa1e4c4ee7d0f7a479ca44b7ec12eb3 Mon Sep 17 00:00:00 2001 From: vibhatha Date: Wed, 4 Oct 2023 21:41:35 +0530 Subject: [PATCH 02/12] feat: initial feature v1 --- .../relation/AbstractRelVisitor.java | 7 ++++++ .../relation/RelCopyOnWriteVisitor.java | 14 +++++------ .../substrait/relation/RelProtoConverter.java | 23 ++++++++++++++++++- .../substrait/relation/physical/HashJoin.java | 2 +- 4 files changed, 37 insertions(+), 9 deletions(-) diff --git a/core/src/main/java/io/substrait/relation/AbstractRelVisitor.java b/core/src/main/java/io/substrait/relation/AbstractRelVisitor.java index f46e8899..645f692e 100644 --- a/core/src/main/java/io/substrait/relation/AbstractRelVisitor.java +++ b/core/src/main/java/io/substrait/relation/AbstractRelVisitor.java @@ -1,5 +1,7 @@ package io.substrait.relation; +import io.substrait.relation.physical.HashJoin; + public abstract class AbstractRelVisitor implements RelVisitor { public abstract OUTPUT visitFallback(Rel rel); @@ -83,4 +85,9 @@ public OUTPUT visit(ExtensionMulti extensionMulti) throws EXCEPTION { public OUTPUT visit(ExtensionTable extensionTable) throws EXCEPTION { return visitFallback(extensionTable); } + + @Override + public OUTPUT visit(HashJoin hashJoin) throws EXCEPTION { + return visitFallback(hashJoin); + } } diff --git a/core/src/main/java/io/substrait/relation/RelCopyOnWriteVisitor.java b/core/src/main/java/io/substrait/relation/RelCopyOnWriteVisitor.java index 3463d814..03d87458 100644 --- a/core/src/main/java/io/substrait/relation/RelCopyOnWriteVisitor.java +++ b/core/src/main/java/io/substrait/relation/RelCopyOnWriteVisitor.java @@ -172,19 +172,19 @@ public Optional visit(Cross cross) throws RuntimeException { public Optional visit(HashJoin hashJoin) throws RuntimeException { var left = hashJoin.getLeft().accept(this); var right = hashJoin.getRight().accept(this); - // var condition = join.getCondition().flatMap(t -> visitExpression(t)); + var leftKeys = hashJoin.getLeftKeys(); + var rightKeys = hashJoin.getRightKeys(); var postFilter = hashJoin.getPostJoinFilter().flatMap(t -> visitExpression(t)); - // if (allEmpty(left, right, condition, postFilter)) { - // return Optional.empty(); - // } + if (allEmpty(left, right, leftKeys, rightKeys, postFilter)) { + return Optional.empty(); + } return Optional.of( ImmutableHashJoin.builder() .from(hashJoin) .left(left.orElse(hashJoin.getLeft())) .right(right.orElse(hashJoin.getRight())) - // .condition( - // Optional.ofNullable(condition.orElseGet(() -> - // join.getCondition().orElse(null)))) + .leftKeys(leftKeys) + .rightKeys(rightKeys) .postJoinFilter( Optional.ofNullable( postFilter.orElseGet(() -> hashJoin.getPostJoinFilter().orElse(null)))) diff --git a/core/src/main/java/io/substrait/relation/RelProtoConverter.java b/core/src/main/java/io/substrait/relation/RelProtoConverter.java index 225e4341..c42d84dc 100644 --- a/core/src/main/java/io/substrait/relation/RelProtoConverter.java +++ b/core/src/main/java/io/substrait/relation/RelProtoConverter.java @@ -1,6 +1,7 @@ package io.substrait.relation; import io.substrait.expression.Expression; +import io.substrait.expression.FieldReference; import io.substrait.expression.FunctionArg; import io.substrait.expression.proto.ExpressionProtoConverter; import io.substrait.extension.ExtensionCollector; @@ -55,6 +56,10 @@ private List toProtoS(Collection sorts) { .collect(java.util.stream.Collectors.toList()); } + private io.substrait.proto.Expression.FieldReference toProto(FieldReference fieldReference) { + return fieldReference.accept(exprProtoConverter).getSelection(); + } + @Override public Rel visit(Aggregate aggregate) throws RuntimeException { var builder = @@ -224,7 +229,23 @@ public Rel visit(HashJoin hashJoin) throws RuntimeException { .setRight(toProto(hashJoin.getRight())) .setType(hashJoin.getJoinType().toProto()); - // hashJoin.getLeftKeys().ifPresent(t -> builder.setLeftKeys()); + hashJoin + .getLeftKeys() + .ifPresent( + keys -> { + for (int i = 0; i < keys.size(); i++) { + builder.setLeftKeys(i, toProto(keys.get(i))); + } + }); + + hashJoin + .getRightKeys() + .ifPresent( + keys -> { + for (int i = 0; i < keys.size(); i++) { + builder.setRightKeys(i, toProto(keys.get(i))); + } + }); hashJoin.getExtension().ifPresent(ae -> builder.setAdvancedExtension(ae.toProto())); return Rel.newBuilder().setHashJoin(builder).build(); diff --git a/core/src/main/java/io/substrait/relation/physical/HashJoin.java b/core/src/main/java/io/substrait/relation/physical/HashJoin.java index 68e297b3..6a6c2a27 100644 --- a/core/src/main/java/io/substrait/relation/physical/HashJoin.java +++ b/core/src/main/java/io/substrait/relation/physical/HashJoin.java @@ -46,7 +46,7 @@ public static enum JoinType { LEFT_SEMI(HashJoinRel.JoinType.JOIN_TYPE_LEFT_SEMI), RIGHT_SEMI(HashJoinRel.JoinType.JOIN_TYPE_RIGHT_SEMI), LEFT_ANTI(HashJoinRel.JoinType.JOIN_TYPE_LEFT_ANTI), - RIGHT_ANTO(HashJoinRel.JoinType.JOIN_TYPE_RIGHT_ANTI); + RIGHT_ANTI(HashJoinRel.JoinType.JOIN_TYPE_RIGHT_ANTI); private HashJoinRel.JoinType proto; From aa3e3904fa59e9b01024824222085e795d553c32 Mon Sep 17 00:00:00 2001 From: Vibhatha Lakmal Abeykoon Date: Thu, 5 Oct 2023 12:07:55 +0530 Subject: [PATCH 03/12] fix: formatting --- .../substrait/relation/RelProtoConverter.java | 17 ++++++++++++++++- 1 file changed, 16 insertions(+), 1 deletion(-) diff --git a/core/src/main/java/io/substrait/relation/RelProtoConverter.java b/core/src/main/java/io/substrait/relation/RelProtoConverter.java index c42d84dc..a9707b78 100644 --- a/core/src/main/java/io/substrait/relation/RelProtoConverter.java +++ b/core/src/main/java/io/substrait/relation/RelProtoConverter.java @@ -5,8 +5,23 @@ import io.substrait.expression.FunctionArg; import io.substrait.expression.proto.ExpressionProtoConverter; import io.substrait.extension.ExtensionCollector; -import io.substrait.proto.*; +import io.substrait.proto.AggregateFunction; +import io.substrait.proto.AggregateRel; +import io.substrait.proto.CrossRel; +import io.substrait.proto.ExtensionLeafRel; +import io.substrait.proto.ExtensionMultiRel; +import io.substrait.proto.ExtensionSingleRel; +import io.substrait.proto.FetchRel; +import io.substrait.proto.FilterRel; +import io.substrait.proto.HashJoinRel; +import io.substrait.proto.JoinRel; +import io.substrait.proto.ProjectRel; +import io.substrait.proto.ReadRel; import io.substrait.proto.Rel; +import io.substrait.proto.RelCommon; +import io.substrait.proto.SetRel; +import io.substrait.proto.SortField; +import io.substrait.proto.SortRel; import io.substrait.relation.files.FileOrFiles; import io.substrait.relation.physical.HashJoin; import io.substrait.type.proto.TypeProtoConverter; From 1695370005ceb0d494a398a4514b96aff07808f3 Mon Sep 17 00:00:00 2001 From: Vibhatha Lakmal Abeykoon Date: Tue, 10 Oct 2023 18:22:17 +0530 Subject: [PATCH 04/12] feat: include protorel conerter --- .../substrait/relation/ProtoRelConverter.java | 34 +++++++++++++++++++ 1 file changed, 34 insertions(+) diff --git a/core/src/main/java/io/substrait/relation/ProtoRelConverter.java b/core/src/main/java/io/substrait/relation/ProtoRelConverter.java index 0cc846f2..1db5dc96 100644 --- a/core/src/main/java/io/substrait/relation/ProtoRelConverter.java +++ b/core/src/main/java/io/substrait/relation/ProtoRelConverter.java @@ -17,6 +17,7 @@ import io.substrait.proto.ExtensionSingleRel; import io.substrait.proto.FetchRel; import io.substrait.proto.FilterRel; +import io.substrait.proto.HashJoinRel; import io.substrait.proto.JoinRel; import io.substrait.proto.ProjectRel; import io.substrait.proto.ReadRel; @@ -27,6 +28,7 @@ import io.substrait.relation.files.FileOrFiles; import io.substrait.relation.files.ImmutableFileFormat; import io.substrait.relation.files.ImmutableFileOrFiles; +import io.substrait.relation.physical.HashJoin; import io.substrait.type.ImmutableNamedStruct; import io.substrait.type.NamedStruct; import io.substrait.type.Type; @@ -95,6 +97,9 @@ public Rel from(io.substrait.proto.Rel rel) { case EXTENSION_MULTI -> { return newExtensionMulti(rel.getExtensionMulti()); } + case HASH_JOIN -> { + return newHashJoin(rel.getHashJoin()); + } default -> { throw new UnsupportedOperationException("Unsupported RelTypeCase of " + relType); } @@ -490,6 +495,35 @@ private Set newSet(SetRel rel) { return builder.build(); } + private Rel newHashJoin(HashJoinRel rel) { + Rel left = from(rel.getLeft()); + Rel right = from(rel.getRight()); + Type.Struct leftStruct = left.getRecordType(); + Type.Struct rightStruct = right.getRecordType(); + Type.Struct unionedStruct = Type.Struct.builder().from(leftStruct).from(rightStruct).build(); + var converter = new ProtoExpressionConverter(lookup, extensions, unionedStruct, this); + var builder = + HashJoin.builder() + .left(left) + .right(right) + .leftKeys( + rel.getLeftKeysList().stream().map(converter::from).collect(Collectors.toList())) + .rightKeys( + rel.getRightKeysList().stream().map(converter::from).collect(Collectors.toList())) + .joinType(HashJoin.JoinType.fromProto(rel.getType())) + .postJoinFilter( + Optional.ofNullable( + rel.hasPostJoinFilter() ? converter.from(rel.getPostJoinFilter()) : null)); + + builder + .commonExtension(optionalAdvancedExtension(rel.getCommon())) + .remap(optionalRelmap(rel.getCommon())); + if (rel.hasAdvancedExtension()) { + builder.extension(advancedExtension(rel.getAdvancedExtension())); + } + return builder.build(); + } + private static Optional optionalRelmap(io.substrait.proto.RelCommon relCommon) { return Optional.ofNullable( relCommon.hasEmit() ? Rel.Remap.of(relCommon.getEmit().getOutputMappingList()) : null); From 00c8c7447c61cc91e9e32ea715f318fb982c0973 Mon Sep 17 00:00:00 2001 From: Vibhatha Lakmal Abeykoon Date: Wed, 11 Oct 2023 05:34:22 +0530 Subject: [PATCH 05/12] fix: remove abstract join interface --- .../io/substrait/relation/AbstractJoin.java | 9 ---- .../main/java/io/substrait/relation/Join.java | 44 ++++++++++--------- .../substrait/relation/physical/HashJoin.java | 40 +++++++++-------- 3 files changed, 45 insertions(+), 48 deletions(-) delete mode 100644 core/src/main/java/io/substrait/relation/AbstractJoin.java diff --git a/core/src/main/java/io/substrait/relation/AbstractJoin.java b/core/src/main/java/io/substrait/relation/AbstractJoin.java deleted file mode 100644 index c542e22f..00000000 --- a/core/src/main/java/io/substrait/relation/AbstractJoin.java +++ /dev/null @@ -1,9 +0,0 @@ -package io.substrait.relation; - -import io.substrait.expression.Expression; -import java.util.Optional; - -public abstract class AbstractJoin extends BiRel implements HasExtension { - - public abstract Optional getPostJoinFilter(); -} diff --git a/core/src/main/java/io/substrait/relation/Join.java b/core/src/main/java/io/substrait/relation/Join.java index 1aa48c57..5b95beef 100644 --- a/core/src/main/java/io/substrait/relation/Join.java +++ b/core/src/main/java/io/substrait/relation/Join.java @@ -9,28 +9,13 @@ import org.immutables.value.Value; @Value.Immutable -public abstract class Join extends AbstractJoin { +public abstract class Join extends BiRel implements HasExtension { public abstract Optional getCondition(); - public abstract JoinType getJoinType(); + public abstract Optional getPostJoinFilter(); - @Override - protected Type.Struct deriveRecordType() { - Stream leftTypes = - switch (getJoinType()) { - case RIGHT, OUTER -> getLeft().getRecordType().fields().stream() - .map(TypeCreator::asNullable); - default -> getLeft().getRecordType().fields().stream(); - }; - Stream rightTypes = - switch (getJoinType()) { - case LEFT, OUTER -> getRight().getRecordType().fields().stream() - .map(TypeCreator::asNullable); - default -> getRight().getRecordType().fields().stream(); - }; - return TypeCreator.REQUIRED.struct(Stream.concat(leftTypes, rightTypes)); - } + public abstract JoinType getJoinType(); public static enum JoinType { UNKNOWN(JoinRel.JoinType.JOIN_TYPE_UNSPECIFIED), @@ -47,6 +32,10 @@ public static enum JoinType { this.proto = proto; } + public JoinRel.JoinType toProto() { + return proto; + } + public static JoinType fromProto(JoinRel.JoinType proto) { for (var v : values()) { if (v.proto == proto) { @@ -55,10 +44,23 @@ public static JoinType fromProto(JoinRel.JoinType proto) { } throw new IllegalArgumentException("Unknown type: " + proto); } + } - public JoinRel.JoinType toProto() { - return proto; - } + @Override + protected Type.Struct deriveRecordType() { + Stream leftTypes = + switch (getJoinType()) { + case RIGHT, OUTER -> getLeft().getRecordType().fields().stream() + .map(TypeCreator::asNullable); + default -> getLeft().getRecordType().fields().stream(); + }; + Stream rightTypes = + switch (getJoinType()) { + case LEFT, OUTER -> getRight().getRecordType().fields().stream() + .map(TypeCreator::asNullable); + default -> getRight().getRecordType().fields().stream(); + }; + return TypeCreator.REQUIRED.struct(Stream.concat(leftTypes, rightTypes)); } @Override diff --git a/core/src/main/java/io/substrait/relation/physical/HashJoin.java b/core/src/main/java/io/substrait/relation/physical/HashJoin.java index 6a6c2a27..429c76e2 100644 --- a/core/src/main/java/io/substrait/relation/physical/HashJoin.java +++ b/core/src/main/java/io/substrait/relation/physical/HashJoin.java @@ -1,8 +1,10 @@ package io.substrait.relation.physical; +import io.substrait.expression.Expression; import io.substrait.expression.FieldReference; import io.substrait.proto.HashJoinRel; -import io.substrait.relation.AbstractJoin; +import io.substrait.relation.BiRel; +import io.substrait.relation.HasExtension; import io.substrait.relation.RelVisitor; import io.substrait.type.Type; import io.substrait.type.TypeCreator; @@ -12,7 +14,7 @@ import org.immutables.value.Value; @Value.Immutable -public abstract class HashJoin extends AbstractJoin { +public abstract class HashJoin extends BiRel implements HasExtension { public abstract Optional> getLeftKeys(); @@ -20,22 +22,7 @@ public abstract class HashJoin extends AbstractJoin { public abstract JoinType getJoinType(); - @Override - protected Type.Struct deriveRecordType() { - Stream leftTypes = - switch (getJoinType()) { - case RIGHT, OUTER -> getLeft().getRecordType().fields().stream() - .map(TypeCreator::asNullable); - default -> getLeft().getRecordType().fields().stream(); - }; - Stream rightTypes = - switch (getJoinType()) { - case LEFT, OUTER -> getRight().getRecordType().fields().stream() - .map(TypeCreator::asNullable); - default -> getRight().getRecordType().fields().stream(); - }; - return TypeCreator.REQUIRED.struct(Stream.concat(leftTypes, rightTypes)); - } + public abstract Optional getPostJoinFilter(); public static enum JoinType { UNKNOWN(HashJoinRel.JoinType.JOIN_TYPE_UNSPECIFIED), @@ -68,6 +55,23 @@ public HashJoinRel.JoinType toProto() { } } + @Override + protected Type.Struct deriveRecordType() { + Stream leftTypes = + switch (getJoinType()) { + case RIGHT, OUTER -> getLeft().getRecordType().fields().stream() + .map(TypeCreator::asNullable); + default -> getLeft().getRecordType().fields().stream(); + }; + Stream rightTypes = + switch (getJoinType()) { + case LEFT, OUTER -> getRight().getRecordType().fields().stream() + .map(TypeCreator::asNullable); + default -> getRight().getRecordType().fields().stream(); + }; + return TypeCreator.REQUIRED.struct(Stream.concat(leftTypes, rightTypes)); + } + @Override public O accept(RelVisitor visitor) throws E { return visitor.visit(this); From 37e775f5062964fb08c286063b0e276c206c5dca Mon Sep 17 00:00:00 2001 From: Vibhatha Lakmal Abeykoon Date: Wed, 11 Oct 2023 05:35:26 +0530 Subject: [PATCH 06/12] fix: unwanted line --- core/src/main/java/io/substrait/relation/Join.java | 1 + 1 file changed, 1 insertion(+) diff --git a/core/src/main/java/io/substrait/relation/Join.java b/core/src/main/java/io/substrait/relation/Join.java index 5b95beef..b571c732 100644 --- a/core/src/main/java/io/substrait/relation/Join.java +++ b/core/src/main/java/io/substrait/relation/Join.java @@ -42,6 +42,7 @@ public static JoinType fromProto(JoinRel.JoinType proto) { return v; } } + throw new IllegalArgumentException("Unknown type: " + proto); } } From 5f10b13e8c4ce68a8813bd3afd6a762525013fa5 Mon Sep 17 00:00:00 2001 From: Vibhatha Lakmal Abeykoon Date: Wed, 11 Oct 2023 05:50:51 +0530 Subject: [PATCH 07/12] fix: adding anti and semi conditions --- core/src/main/java/io/substrait/relation/physical/HashJoin.java | 2 ++ 1 file changed, 2 insertions(+) diff --git a/core/src/main/java/io/substrait/relation/physical/HashJoin.java b/core/src/main/java/io/substrait/relation/physical/HashJoin.java index 429c76e2..c78de5e8 100644 --- a/core/src/main/java/io/substrait/relation/physical/HashJoin.java +++ b/core/src/main/java/io/substrait/relation/physical/HashJoin.java @@ -61,12 +61,14 @@ protected Type.Struct deriveRecordType() { switch (getJoinType()) { case RIGHT, OUTER -> getLeft().getRecordType().fields().stream() .map(TypeCreator::asNullable); + case RIGHT_ANTI, RIGHT_SEMI -> Stream.empty(); default -> getLeft().getRecordType().fields().stream(); }; Stream rightTypes = switch (getJoinType()) { case LEFT, OUTER -> getRight().getRecordType().fields().stream() .map(TypeCreator::asNullable); + case LEFT_ANTI, LEFT_SEMI -> Stream.empty(); default -> getRight().getRecordType().fields().stream(); }; return TypeCreator.REQUIRED.struct(Stream.concat(leftTypes, rightTypes)); From 9064d3e08ed278eea62ed7eafccbb2290601547b Mon Sep 17 00:00:00 2001 From: Vibhatha Lakmal Abeykoon Date: Wed, 11 Oct 2023 06:35:39 +0530 Subject: [PATCH 08/12] feat: adding test cases --- .../io/substrait/dsl/SubstraitBuilder.java | 23 +++++++++++++++++++ .../type/proto/ExtensionRoundtripTest.java | 16 +++++++++++++ 2 files changed, 39 insertions(+) diff --git a/core/src/main/java/io/substrait/dsl/SubstraitBuilder.java b/core/src/main/java/io/substrait/dsl/SubstraitBuilder.java index e87c071a..f897d65d 100644 --- a/core/src/main/java/io/substrait/dsl/SubstraitBuilder.java +++ b/core/src/main/java/io/substrait/dsl/SubstraitBuilder.java @@ -23,6 +23,7 @@ import io.substrait.relation.Rel; import io.substrait.relation.Set; import io.substrait.relation.Sort; +import io.substrait.relation.physical.HashJoin; import io.substrait.type.ImmutableType; import io.substrait.type.NamedStruct; import io.substrait.type.Type; @@ -165,6 +166,28 @@ private Join join( .build(); } + public HashJoin hashJoin( + int[] leftKeys, int[] rightKeys, HashJoin.JoinType joinType, Rel left, Rel right) { + return hashJoin(leftKeys, rightKeys, joinType, Optional.empty(), left, right); + } + + public HashJoin hashJoin( + int[] leftKeys, + int[] rightKeys, + HashJoin.JoinType joinType, + Optional remap, + Rel left, + Rel right) { + return HashJoin.builder() + .left(left) + .right(right) + .leftKeys(this.fieldReferences(left, leftKeys)) + .rightKeys(this.fieldReferences(right, rightKeys)) + .joinType(joinType) + .remap(remap) + .build(); + } + public NamedScan namedScan( Iterable tableName, Iterable columnNames, Iterable types) { return namedScan(tableName, columnNames, types, Optional.empty()); diff --git a/core/src/test/java/io/substrait/type/proto/ExtensionRoundtripTest.java b/core/src/test/java/io/substrait/type/proto/ExtensionRoundtripTest.java index d340cd01..01f00d89 100644 --- a/core/src/test/java/io/substrait/type/proto/ExtensionRoundtripTest.java +++ b/core/src/test/java/io/substrait/type/proto/ExtensionRoundtripTest.java @@ -26,6 +26,7 @@ import io.substrait.relation.Set; import io.substrait.relation.Sort; import io.substrait.relation.VirtualTableScan; +import io.substrait.relation.physical.HashJoin; import io.substrait.relation.utils.StringHolder; import io.substrait.relation.utils.StringHolderHandlingProtoRelConverter; import io.substrait.type.NamedStruct; @@ -174,6 +175,21 @@ void join() { verifyRoundTrip(rel); } + @Test + void hashJoin() { + int[] left_keys = {}; + int[] right_keys = {}; + Rel rel = + HashJoin.builder() + .from( + b.hashJoin( + left_keys, right_keys, HashJoin.JoinType.INNER, commonTable, commonTable)) + .commonExtension(commonExtension) + .extension(relExtension) + .build(); + verifyRoundTrip(rel); + } + @Test void project() { Rel rel = From 4162cc246d3f0127ac2ddb9ee647af3e512fb509 Mon Sep 17 00:00:00 2001 From: Vibhatha Lakmal Abeykoon Date: Thu, 19 Oct 2023 09:14:50 +0530 Subject: [PATCH 09/12] fix: addressing reviews v1 --- .../io/substrait/dsl/SubstraitBuilder.java | 16 ++++++---- .../relation/RelCopyOnWriteVisitor.java | 2 +- .../substrait/relation/RelProtoConverter.java | 30 +++++++++---------- .../substrait/relation/physical/HashJoin.java | 4 +-- .../type/proto/ExtensionRoundtripTest.java | 16 ++++++---- 5 files changed, 39 insertions(+), 29 deletions(-) diff --git a/core/src/main/java/io/substrait/dsl/SubstraitBuilder.java b/core/src/main/java/io/substrait/dsl/SubstraitBuilder.java index f897d65d..b330c0d0 100644 --- a/core/src/main/java/io/substrait/dsl/SubstraitBuilder.java +++ b/core/src/main/java/io/substrait/dsl/SubstraitBuilder.java @@ -167,13 +167,17 @@ private Join join( } public HashJoin hashJoin( - int[] leftKeys, int[] rightKeys, HashJoin.JoinType joinType, Rel left, Rel right) { + List leftKeys, + List rightKeys, + HashJoin.JoinType joinType, + Rel left, + Rel right) { return hashJoin(leftKeys, rightKeys, joinType, Optional.empty(), left, right); } public HashJoin hashJoin( - int[] leftKeys, - int[] rightKeys, + List leftKeys, + List rightKeys, HashJoin.JoinType joinType, Optional remap, Rel left, @@ -181,8 +185,10 @@ public HashJoin hashJoin( return HashJoin.builder() .left(left) .right(right) - .leftKeys(this.fieldReferences(left, leftKeys)) - .rightKeys(this.fieldReferences(right, rightKeys)) + .leftKeys( + this.fieldReferences(left, leftKeys.stream().mapToInt(Integer::intValue).toArray())) + .rightKeys( + this.fieldReferences(right, rightKeys.stream().mapToInt(Integer::intValue).toArray())) .joinType(joinType) .remap(remap) .build(); diff --git a/core/src/main/java/io/substrait/relation/RelCopyOnWriteVisitor.java b/core/src/main/java/io/substrait/relation/RelCopyOnWriteVisitor.java index 03d87458..0dddfbd9 100644 --- a/core/src/main/java/io/substrait/relation/RelCopyOnWriteVisitor.java +++ b/core/src/main/java/io/substrait/relation/RelCopyOnWriteVisitor.java @@ -175,7 +175,7 @@ public Optional visit(HashJoin hashJoin) throws RuntimeException { var leftKeys = hashJoin.getLeftKeys(); var rightKeys = hashJoin.getRightKeys(); var postFilter = hashJoin.getPostJoinFilter().flatMap(t -> visitExpression(t)); - if (allEmpty(left, right, leftKeys, rightKeys, postFilter)) { + if (allEmpty(left, right, postFilter)) { return Optional.empty(); } return Optional.of( diff --git a/core/src/main/java/io/substrait/relation/RelProtoConverter.java b/core/src/main/java/io/substrait/relation/RelProtoConverter.java index a9707b78..8843abdd 100644 --- a/core/src/main/java/io/substrait/relation/RelProtoConverter.java +++ b/core/src/main/java/io/substrait/relation/RelProtoConverter.java @@ -173,6 +173,8 @@ public Rel visit(Join join) throws RuntimeException { join.getCondition().ifPresent(t -> builder.setExpression(toProto(t))); + join.getPostJoinFilter().ifPresent(t -> builder.setPostJoinFilter(toProto(t))); + join.getExtension().ifPresent(ae -> builder.setAdvancedExtension(ae.toProto())); return Rel.newBuilder().setJoin(builder).build(); } @@ -244,23 +246,19 @@ public Rel visit(HashJoin hashJoin) throws RuntimeException { .setRight(toProto(hashJoin.getRight())) .setType(hashJoin.getJoinType().toProto()); - hashJoin - .getLeftKeys() - .ifPresent( - keys -> { - for (int i = 0; i < keys.size(); i++) { - builder.setLeftKeys(i, toProto(keys.get(i))); - } - }); + List leftKeys = hashJoin.getLeftKeys(); + List rightKeys = hashJoin.getRightKeys(); - hashJoin - .getRightKeys() - .ifPresent( - keys -> { - for (int i = 0; i < keys.size(); i++) { - builder.setRightKeys(i, toProto(keys.get(i))); - } - }); + if (leftKeys.size() != rightKeys.size()) { + throw new RuntimeException("Number of left and right keys must be equal."); + } + + for (int idx = 0; idx < hashJoin.getLeftKeys().size(); idx++) { + builder.setLeftKeys(idx, toProto(leftKeys.get(idx))); + builder.setRightKeys(idx, toProto(rightKeys.get(idx))); + } + + hashJoin.getPostJoinFilter().ifPresent(t -> builder.setPostJoinFilter(toProto(t))); hashJoin.getExtension().ifPresent(ae -> builder.setAdvancedExtension(ae.toProto())); return Rel.newBuilder().setHashJoin(builder).build(); diff --git a/core/src/main/java/io/substrait/relation/physical/HashJoin.java b/core/src/main/java/io/substrait/relation/physical/HashJoin.java index c78de5e8..6d0e68f8 100644 --- a/core/src/main/java/io/substrait/relation/physical/HashJoin.java +++ b/core/src/main/java/io/substrait/relation/physical/HashJoin.java @@ -16,9 +16,9 @@ @Value.Immutable public abstract class HashJoin extends BiRel implements HasExtension { - public abstract Optional> getLeftKeys(); + public abstract List getLeftKeys(); - public abstract Optional> getRightKeys(); + public abstract List getRightKeys(); public abstract JoinType getJoinType(); diff --git a/core/src/test/java/io/substrait/type/proto/ExtensionRoundtripTest.java b/core/src/test/java/io/substrait/type/proto/ExtensionRoundtripTest.java index 01f00d89..ad492470 100644 --- a/core/src/test/java/io/substrait/type/proto/ExtensionRoundtripTest.java +++ b/core/src/test/java/io/substrait/type/proto/ExtensionRoundtripTest.java @@ -33,6 +33,7 @@ import io.substrait.type.Type; import io.substrait.type.TypeCreator; import java.util.Collections; +import java.util.List; import java.util.stream.Collectors; import java.util.stream.Stream; import org.junit.jupiter.api.Nested; @@ -177,17 +178,22 @@ void join() { @Test void hashJoin() { - int[] left_keys = {}; - int[] right_keys = {}; - Rel rel = + // with empty keys + List leftEmptyKeys = Collections.emptyList(); + List rightEmptyKeys = Collections.emptyList(); + Rel relWithoutKeys = HashJoin.builder() .from( b.hashJoin( - left_keys, right_keys, HashJoin.JoinType.INNER, commonTable, commonTable)) + leftEmptyKeys, + rightEmptyKeys, + HashJoin.JoinType.INNER, + commonTable, + commonTable)) .commonExtension(commonExtension) .extension(relExtension) .build(); - verifyRoundTrip(rel); + verifyRoundTrip(relWithoutKeys); } @Test From 51549909373163b269760ceccfa8a48cd8d46e67 Mon Sep 17 00:00:00 2001 From: Vibhatha Lakmal Abeykoon Date: Fri, 20 Oct 2023 06:34:24 +0530 Subject: [PATCH 10/12] fix: getting rid of loop and using streaming instead --- .../main/java/io/substrait/relation/RelProtoConverter.java | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/core/src/main/java/io/substrait/relation/RelProtoConverter.java b/core/src/main/java/io/substrait/relation/RelProtoConverter.java index 8843abdd..d9b19f95 100644 --- a/core/src/main/java/io/substrait/relation/RelProtoConverter.java +++ b/core/src/main/java/io/substrait/relation/RelProtoConverter.java @@ -253,10 +253,8 @@ public Rel visit(HashJoin hashJoin) throws RuntimeException { throw new RuntimeException("Number of left and right keys must be equal."); } - for (int idx = 0; idx < hashJoin.getLeftKeys().size(); idx++) { - builder.setLeftKeys(idx, toProto(leftKeys.get(idx))); - builder.setRightKeys(idx, toProto(rightKeys.get(idx))); - } + builder.addAllLeftKeys(leftKeys.stream().map(this::toProto).collect(Collectors.toList())); + builder.addAllRightKeys(rightKeys.stream().map(this::toProto).collect(Collectors.toList())); hashJoin.getPostJoinFilter().ifPresent(t -> builder.setPostJoinFilter(toProto(t))); From 4c6b42c08372869e154b46000f2ce082130d4c16 Mon Sep 17 00:00:00 2001 From: Vibhatha Lakmal Abeykoon Date: Fri, 20 Oct 2023 08:41:09 +0530 Subject: [PATCH 11/12] fix: address review v2 --- .../substrait/relation/ProtoRelConverter.java | 17 ++-- .../type/proto/JoinRoundtripTest.java | 78 +++++++++++++++++++ 2 files changed, 89 insertions(+), 6 deletions(-) create mode 100644 core/src/test/java/io/substrait/type/proto/JoinRoundtripTest.java diff --git a/core/src/main/java/io/substrait/relation/ProtoRelConverter.java b/core/src/main/java/io/substrait/relation/ProtoRelConverter.java index 1db5dc96..8d5641f3 100644 --- a/core/src/main/java/io/substrait/relation/ProtoRelConverter.java +++ b/core/src/main/java/io/substrait/relation/ProtoRelConverter.java @@ -39,6 +39,7 @@ import java.util.Optional; import java.util.stream.Collectors; import java.util.stream.IntStream; +import java.util.stream.Stream; /** Converts from {@link io.substrait.proto.Rel} to {@link io.substrait.relation.Rel} */ public class ProtoRelConverter { @@ -498,22 +499,26 @@ private Set newSet(SetRel rel) { private Rel newHashJoin(HashJoinRel rel) { Rel left = from(rel.getLeft()); Rel right = from(rel.getRight()); + var leftKeys = rel.getLeftKeysList(); + var rightKeys = rel.getRightKeysList(); + var rightOffSetKeys = + Stream.concat(leftKeys.stream(), rightKeys.stream()).collect(Collectors.toList()); Type.Struct leftStruct = left.getRecordType(); Type.Struct rightStruct = right.getRecordType(); Type.Struct unionedStruct = Type.Struct.builder().from(leftStruct).from(rightStruct).build(); - var converter = new ProtoExpressionConverter(lookup, extensions, unionedStruct, this); + var leftConverter = new ProtoExpressionConverter(lookup, extensions, leftStruct, this); + var rightConverter = new ProtoExpressionConverter(lookup, extensions, rightStruct, this); + var unionConverter = new ProtoExpressionConverter(lookup, extensions, unionedStruct, this); var builder = HashJoin.builder() .left(left) .right(right) - .leftKeys( - rel.getLeftKeysList().stream().map(converter::from).collect(Collectors.toList())) - .rightKeys( - rel.getRightKeysList().stream().map(converter::from).collect(Collectors.toList())) + .leftKeys(leftKeys.stream().map(leftConverter::from).collect(Collectors.toList())) + .rightKeys(rightKeys.stream().map(rightConverter::from).collect(Collectors.toList())) .joinType(HashJoin.JoinType.fromProto(rel.getType())) .postJoinFilter( Optional.ofNullable( - rel.hasPostJoinFilter() ? converter.from(rel.getPostJoinFilter()) : null)); + rel.hasPostJoinFilter() ? unionConverter.from(rel.getPostJoinFilter()) : null)); builder .commonExtension(optionalAdvancedExtension(rel.getCommon())) diff --git a/core/src/test/java/io/substrait/type/proto/JoinRoundtripTest.java b/core/src/test/java/io/substrait/type/proto/JoinRoundtripTest.java new file mode 100644 index 00000000..c12d1cb3 --- /dev/null +++ b/core/src/test/java/io/substrait/type/proto/JoinRoundtripTest.java @@ -0,0 +1,78 @@ +package io.substrait.type.proto; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +import io.substrait.TestBase; +import io.substrait.dsl.SubstraitBuilder; +import io.substrait.extension.AdvancedExtension; +import io.substrait.extension.ExtensionCollector; +import io.substrait.extension.SimpleExtension; +import io.substrait.relation.ProtoRelConverter; +import io.substrait.relation.Rel; +import io.substrait.relation.RelProtoConverter; +import io.substrait.relation.physical.HashJoin; +import io.substrait.relation.utils.StringHolder; +import io.substrait.relation.utils.StringHolderHandlingProtoRelConverter; +import io.substrait.type.TypeCreator; +import java.util.Arrays; +import java.util.List; +import org.junit.jupiter.api.Test; + +public class JoinRoundtripTest extends TestBase { + + final SimpleExtension.ExtensionCollection extensions = defaultExtensionCollection; + + TypeCreator R = TypeCreator.REQUIRED; + + final SubstraitBuilder b = new SubstraitBuilder(extensions); + + final ExtensionCollector functionCollector = new ExtensionCollector(); + final RelProtoConverter relProtoConverter = new RelProtoConverter(functionCollector); + final ProtoRelConverter protoRelConverter = + new StringHolderHandlingProtoRelConverter(functionCollector, extensions); + + final Rel leftTable = + b.namedScan( + Arrays.asList("T1"), + Arrays.asList("a", "b", "c"), + Arrays.asList(R.I64, R.FP64, R.STRING)); + + final Rel rightTable = + b.namedScan( + Arrays.asList("T2"), + Arrays.asList("d", "e", "f"), + Arrays.asList(R.FP64, R.STRING, R.I64)); + + final AdvancedExtension commonExtension = + AdvancedExtension.builder() + .enhancement(new StringHolder("COMMON ENHANCEMENT")) + .optimization(new StringHolder("COMMON OPTIMIZATION")) + .build(); + + final AdvancedExtension relExtension = + AdvancedExtension.builder() + .enhancement(new StringHolder("REL ENHANCEMENT")) + .optimization(new StringHolder("REL OPTIMIZATION")) + .build(); + + void verifyRoundTrip(Rel rel) { + io.substrait.proto.Rel protoRel = relProtoConverter.toProto(rel); + Rel relReturned = protoRelConverter.from(protoRel); + assertEquals(rel, relReturned); + } + + @Test + void hashJoin() { + List leftEmptyKeys = Arrays.asList(0, 1); + List rightEmptyKeys = Arrays.asList(2, 0); + Rel relWithoutKeys = + HashJoin.builder() + .from( + b.hashJoin( + leftEmptyKeys, rightEmptyKeys, HashJoin.JoinType.INNER, leftTable, rightTable)) + .commonExtension(commonExtension) + .extension(relExtension) + .build(); + verifyRoundTrip(relWithoutKeys); + } +} From aeaa0f2f544db61b014d82aefee613cffc5a42ba Mon Sep 17 00:00:00 2001 From: Vibhatha Lakmal Abeykoon Date: Wed, 25 Oct 2023 06:26:49 +0530 Subject: [PATCH 12/12] fix: addressing reviews v3 --- .../substrait/relation/ProtoRelConverter.java | 4 +--- .../type/proto/JoinRoundtripTest.java | 24 +++---------------- 2 files changed, 4 insertions(+), 24 deletions(-) diff --git a/core/src/main/java/io/substrait/relation/ProtoRelConverter.java b/core/src/main/java/io/substrait/relation/ProtoRelConverter.java index 8d5641f3..9ae2cb95 100644 --- a/core/src/main/java/io/substrait/relation/ProtoRelConverter.java +++ b/core/src/main/java/io/substrait/relation/ProtoRelConverter.java @@ -39,7 +39,6 @@ import java.util.Optional; import java.util.stream.Collectors; import java.util.stream.IntStream; -import java.util.stream.Stream; /** Converts from {@link io.substrait.proto.Rel} to {@link io.substrait.relation.Rel} */ public class ProtoRelConverter { @@ -501,8 +500,7 @@ private Rel newHashJoin(HashJoinRel rel) { Rel right = from(rel.getRight()); var leftKeys = rel.getLeftKeysList(); var rightKeys = rel.getRightKeysList(); - var rightOffSetKeys = - Stream.concat(leftKeys.stream(), rightKeys.stream()).collect(Collectors.toList()); + Type.Struct leftStruct = left.getRecordType(); Type.Struct rightStruct = right.getRecordType(); Type.Struct unionedStruct = Type.Struct.builder().from(leftStruct).from(rightStruct).build(); diff --git a/core/src/test/java/io/substrait/type/proto/JoinRoundtripTest.java b/core/src/test/java/io/substrait/type/proto/JoinRoundtripTest.java index c12d1cb3..2204178f 100644 --- a/core/src/test/java/io/substrait/type/proto/JoinRoundtripTest.java +++ b/core/src/test/java/io/substrait/type/proto/JoinRoundtripTest.java @@ -4,14 +4,12 @@ import io.substrait.TestBase; import io.substrait.dsl.SubstraitBuilder; -import io.substrait.extension.AdvancedExtension; import io.substrait.extension.ExtensionCollector; import io.substrait.extension.SimpleExtension; import io.substrait.relation.ProtoRelConverter; import io.substrait.relation.Rel; import io.substrait.relation.RelProtoConverter; import io.substrait.relation.physical.HashJoin; -import io.substrait.relation.utils.StringHolder; import io.substrait.relation.utils.StringHolderHandlingProtoRelConverter; import io.substrait.type.TypeCreator; import java.util.Arrays; @@ -43,18 +41,6 @@ public class JoinRoundtripTest extends TestBase { Arrays.asList("d", "e", "f"), Arrays.asList(R.FP64, R.STRING, R.I64)); - final AdvancedExtension commonExtension = - AdvancedExtension.builder() - .enhancement(new StringHolder("COMMON ENHANCEMENT")) - .optimization(new StringHolder("COMMON OPTIMIZATION")) - .build(); - - final AdvancedExtension relExtension = - AdvancedExtension.builder() - .enhancement(new StringHolder("REL ENHANCEMENT")) - .optimization(new StringHolder("REL OPTIMIZATION")) - .build(); - void verifyRoundTrip(Rel rel) { io.substrait.proto.Rel protoRel = relProtoConverter.toProto(rel); Rel relReturned = protoRelConverter.from(protoRel); @@ -63,15 +49,11 @@ void verifyRoundTrip(Rel rel) { @Test void hashJoin() { - List leftEmptyKeys = Arrays.asList(0, 1); - List rightEmptyKeys = Arrays.asList(2, 0); + List leftKeys = Arrays.asList(0, 1); + List rightKeys = Arrays.asList(2, 0); Rel relWithoutKeys = HashJoin.builder() - .from( - b.hashJoin( - leftEmptyKeys, rightEmptyKeys, HashJoin.JoinType.INNER, leftTable, rightTable)) - .commonExtension(commonExtension) - .extension(relExtension) + .from(b.hashJoin(leftKeys, rightKeys, HashJoin.JoinType.INNER, leftTable, rightTable)) .build(); verifyRoundTrip(relWithoutKeys); }