Skip to content

Commit

Permalink
feat: include protorel conerter
Browse files Browse the repository at this point in the history
  • Loading branch information
vibhatha committed Oct 10, 2023
1 parent 334da34 commit c43cd2b
Showing 1 changed file with 34 additions and 0 deletions.
34 changes: 34 additions & 0 deletions core/src/main/java/io/substrait/relation/ProtoRelConverter.java
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import io.substrait.proto.ExtensionSingleRel;
import io.substrait.proto.FetchRel;
import io.substrait.proto.FilterRel;
import io.substrait.proto.HashJoinRel;
import io.substrait.proto.JoinRel;
import io.substrait.proto.ProjectRel;
import io.substrait.proto.ReadRel;
Expand All @@ -27,6 +28,7 @@
import io.substrait.relation.files.FileOrFiles;
import io.substrait.relation.files.ImmutableFileFormat;
import io.substrait.relation.files.ImmutableFileOrFiles;
import io.substrait.relation.physical.HashJoin;
import io.substrait.type.ImmutableNamedStruct;
import io.substrait.type.NamedStruct;
import io.substrait.type.Type;
Expand Down Expand Up @@ -95,6 +97,9 @@ public Rel from(io.substrait.proto.Rel rel) {
case EXTENSION_MULTI -> {
return newExtensionMulti(rel.getExtensionMulti());
}
case HASH_JOIN -> {
return newHashJoin(rel.getHashJoin());
}
default -> {
throw new UnsupportedOperationException("Unsupported RelTypeCase of " + relType);
}
Expand Down Expand Up @@ -490,6 +495,35 @@ private Set newSet(SetRel rel) {
return builder.build();
}

private Rel newHashJoin(HashJoinRel rel) {
Rel left = from(rel.getLeft());
Rel right = from(rel.getRight());
Type.Struct leftStruct = left.getRecordType();
Type.Struct rightStruct = right.getRecordType();
Type.Struct unionedStruct = Type.Struct.builder().from(leftStruct).from(rightStruct).build();
var converter = new ProtoExpressionConverter(lookup, extensions, unionedStruct, this);
var builder =
HashJoin.builder()
.left(left)
.right(right)
.leftKeys(
rel.getLeftKeysList().stream().map(converter::from).collect(Collectors.toList()))
.rightKeys(
rel.getRightKeysList().stream().map(converter::from).collect(Collectors.toList()))
.joinType(HashJoin.JoinType.fromProto(rel.getType()))
.postJoinFilter(
Optional.ofNullable(
rel.hasPostJoinFilter() ? converter.from(rel.getPostJoinFilter()) : null));

builder
.commonExtension(optionalAdvancedExtension(rel.getCommon()))
.remap(optionalRelmap(rel.getCommon()));
if (rel.hasAdvancedExtension()) {
builder.extension(advancedExtension(rel.getAdvancedExtension()));
}
return builder.build();
}

private static Optional<Rel.Remap> optionalRelmap(io.substrait.proto.RelCommon relCommon) {
return Optional.ofNullable(
relCommon.hasEmit() ? Rel.Remap.of(relCommon.getEmit().getOutputMappingList()) : null);
Expand Down

0 comments on commit c43cd2b

Please sign in to comment.