Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: implement join operation #32

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 30 additions & 0 deletions include/substrait-mlir/Dialect/Substrait/IR/SubstraitOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -474,6 +474,36 @@ def Substrait_FilterOp : Substrait_RelOp<"filter", [
}];
}

def JoinTypeKind : I32EnumAttr<"JoinTypeKind",
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: move to enums file?

Copy link
Contributor Author

@dshaaban01 dshaaban01 Dec 10, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We actually don't have an enums file; the SetOpKind enum is also defined in this file. Should I open a PR that creates a SubstraitEnums.td file?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, OK! My long-pending PR for aggregate creates one. Let's move the other enums once that is merged?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's do that 👍

"The enum values correspond to those in the JoinRel.JoinType message.", [
I32EnumAttrCase<"unspecified", 0>,
I32EnumAttrCase<"inner", 1>,
I32EnumAttrCase<"outer", 2>,
I32EnumAttrCase<"left", 3>,
I32EnumAttrCase<"right", 4>,
I32EnumAttrCase<"semi", 5>,
I32EnumAttrCase<"anti", 6>,
I32EnumAttrCase<"single", 7>] >;

// Relation op that joins a `left` and a `right` input relation according to a
// `JoinTypeKind`. The op declares the methods of `InferTypeOpInterface`, so
// its result type is computed by `JoinOp::inferReturnTypes` from the operand
// types and the join type rather than being supplied independently.
def Substrait_JoinOp : Substrait_RelOp<"join", [DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
let summary = "join operation";
let description = [{
Represents a `JoinRel` message together with the `RelCommon`, left and
right `Rel` messages and `JoinType` enumeration it contains. The current
implementation assumes the join expression to be True.
}];
// TODO(daliashaaban): Add support for join expressions.
// Operands: the two input relations, followed by the join type attribute.
let arguments = (ins
Substrait_Relation:$left,
Substrait_Relation:$right,
JoinTypeKind:$join_type
);
// Single result: the relation produced by the join.
let results = (outs Substrait_Relation:$result);
// Textual form: the join type keyword, both inputs, then the input types and
// the result type, e.g.
//   join inner %0, %1 : tuple<si32>, tuple<si32> -> tuple<si32, si32>
let assemblyFormat = [{
$join_type $left `,` $right attr-dict `:` type($left) `,` type($right) `->` type($result)
}];
}

def Substrait_NamedTableOp : Substrait_RelOp<"named_table", [
]> {
let summary = "Read operation of a named table";
Expand Down
45 changes: 45 additions & 0 deletions lib/Dialect/Substrait/IR/Substrait.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,51 @@ LiteralOp::inferReturnTypes(MLIRContext *context, std::optional<Location> loc,
return success();
}

/// Infers the result type of a `JoinOp` from its two input relations and its
/// join type.
///
/// Both operands are expected to be of tuple type. Depending on the join type,
/// the inferred tuple consists of
///   * the left fields followed by the right fields (`unspecified`, `inner`,
///     `outer`, `left`, `right`),
///   * only the left fields (`semi`, `anti`), or
///   * only the right fields (`single`).
/// Any other join type value makes the inference fail.
LogicalResult
JoinOp::inferReturnTypes(MLIRContext *context, std::optional<Location> loc,
                         ValueRange operands, DictionaryAttr attributes,
                         OpaqueProperties properties, RegionRange regions,
                         llvm::SmallVectorImpl<Type> &inferredReturnTypes) {
  // Field types of the two input relations.
  auto leftTuple = cast<TupleType>(operands[0].getType());
  auto rightTuple = cast<TupleType>(operands[1].getType());
  TypeRange leftFields = leftTuple.getTypes();
  TypeRange rightFields = rightTuple.getTypes();

  // The join type is only reachable through the adaptor at this point because
  // the op itself does not exist yet.
  Adaptor adaptor(operands, attributes, properties, regions);

  // Decide which sides contribute fields to the result.
  bool keepLeft = false;
  bool keepRight = false;
  switch (adaptor.getJoinType()) {
  case JoinTypeKind::unspecified:
  case JoinTypeKind::inner:
  case JoinTypeKind::outer:
  case JoinTypeKind::right:
  case JoinTypeKind::left:
    keepLeft = true;
    keepRight = true;
    break;
  case JoinTypeKind::semi:
  case JoinTypeKind::anti:
    keepLeft = true;
    break;
  case JoinTypeKind::single:
    keepRight = true;
    break;
  default:
    return failure();
  }

  // Assemble the result fields, left side first.
  SmallVector<mlir::Type> resultFields;
  if (keepLeft)
    llvm::append_range(resultFields, leftFields);
  if (keepRight)
    llvm::append_range(resultFields, rightFields);

  // Replace whatever the caller passed in with the single inferred type.
  inferredReturnTypes.assign(1, TupleType::get(context, resultFields));
  return success();
}

/// Verifies that the provided field names match the provided field types. While
/// the field types are potentially nested, the names are given in a single,
/// flat list and correspond to the field types in depth first order (where each
Expand Down
44 changes: 44 additions & 0 deletions lib/Target/SubstraitPB/Export.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ class SubstraitExporter {
DECLARE_EXPORT_FUNC(FieldReferenceOp, Expression)
DECLARE_EXPORT_FUNC(FetchOp, Rel)
DECLARE_EXPORT_FUNC(FilterOp, Rel)
DECLARE_EXPORT_FUNC(JoinOp, Rel)
DECLARE_EXPORT_FUNC(LiteralOp, Expression)
DECLARE_EXPORT_FUNC(ModuleOp, Plan)
DECLARE_EXPORT_FUNC(NamedTableOp, Rel)
Expand Down Expand Up @@ -261,6 +262,48 @@ FailureOr<std::unique_ptr<Rel>> SubstraitExporter::exportOperation(EmitOp op) {
return inputRel;
}

/// Exports a `JoinOp` as a `Rel` message wrapping a `JoinRel`.
///
/// Both inputs must be produced by ops implementing `RelOpInterface`; they are
/// exported recursively, the left input before the right one. The resulting
/// `JoinRel` holds a `RelCommon` with the `direct` emit kind, the two exported
/// input relations, and the op's join type cast to `JoinRel::JoinType`.
/// Returns failure if an input is not produced by a relation op (an op error
/// is emitted in that case) or if exporting an input fails.
FailureOr<std::unique_ptr<Rel>> SubstraitExporter::exportOperation(JoinOp op) {
  // Export the `left` input.
  auto leftProducer =
      llvm::dyn_cast_if_present<RelOpInterface>(op.getLeft().getDefiningOp());
  if (!leftProducer)
    return op->emitOpError(
        "left input was not produced by Substrait relation op");
  FailureOr<std::unique_ptr<Rel>> leftMessage = exportOperation(leftProducer);
  if (failed(leftMessage))
    return failure();

  // Export the `right` input.
  auto rightProducer =
      llvm::dyn_cast_if_present<RelOpInterface>(op.getRight().getDefiningOp());
  if (!rightProducer)
    return op->emitOpError(
        "right input was not produced by Substrait relation op");
  FailureOr<std::unique_ptr<Rel>> rightMessage = exportOperation(rightProducer);
  if (failed(rightMessage))
    return failure();

  // Build the `RelCommon` message with a `direct` emit kind. Ownership of the
  // `Direct` message is handed over to `RelCommon`.
  auto directEmit = std::make_unique<RelCommon::Direct>();
  auto common = std::make_unique<RelCommon>();
  common->set_allocated_direct(directEmit.release());

  // Assemble the `JoinRel` message; it takes ownership of all sub-messages.
  auto joinMessage = std::make_unique<JoinRel>();
  joinMessage->set_allocated_common(common.release());
  joinMessage->set_allocated_left(leftMessage->release());
  joinMessage->set_allocated_right(rightMessage->release());
  joinMessage->set_type(static_cast<JoinRel::JoinType>(op.getJoinType()));

  // Wrap the `JoinRel` into the generic `Rel` message.
  auto result = std::make_unique<Rel>();
  result->set_allocated_join(joinMessage.release());
  return result;
}

FailureOr<std::unique_ptr<Expression>>
SubstraitExporter::exportOperation(ExpressionOpInterface op) {
return llvm::TypeSwitch<Operation *, FailureOr<std::unique_ptr<Expression>>>(
Expand Down Expand Up @@ -791,6 +834,7 @@ SubstraitExporter::exportOperation(RelOpInterface op) {
FetchOp,
FieldReferenceOp,
FilterOp,
JoinOp,
NamedTableOp,
ProjectOp,
SetOp
Expand Down
32 changes: 32 additions & 0 deletions lib/Target/SubstraitPB/Import.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ DECLARE_IMPORT_FUNC(SetRel, Rel, SetOp)
DECLARE_IMPORT_FUNC(Expression, Expression, ExpressionOpInterface)
DECLARE_IMPORT_FUNC(FieldReference, Expression::FieldReference,
FieldReferenceOp)
DECLARE_IMPORT_FUNC(JoinRel, Rel, JoinOp)
DECLARE_IMPORT_FUNC(Literal, Expression::Literal, LiteralOp)
DECLARE_IMPORT_FUNC(NamedTable, Rel, NamedTableOp)
DECLARE_IMPORT_FUNC(Plan, Plan, PlanOp)
Expand Down Expand Up @@ -247,6 +248,34 @@ importFieldReference(ImplicitLocOpBuilder builder,
return builder.create<FieldReferenceOp>(container, indices);
}

/// Imports the `JoinRel` contained in the given `Rel` message as a `JoinOp`.
///
/// The left and right input relations are imported first (in that order) and
/// the first result of each imported op becomes an operand of the new
/// `JoinOp`. The `JoinType` stored in the message is converted into the
/// corresponding `JoinTypeKind` case.
///
/// Returns failure if importing either input fails, or (after emitting an
/// error at the builder's location) if the message carries a join type value
/// that does not correspond to any `JoinTypeKind` case.
static mlir::FailureOr<JoinOp> importJoinRel(ImplicitLocOpBuilder builder,
                                             const Rel &message) {
  const JoinRel &joinRel = message.join();

  // Import left and right inputs.
  const Rel &leftRel = joinRel.left();
  const Rel &rightRel = joinRel.right();

  mlir::FailureOr<RelOpInterface> leftOp = importRel(builder, leftRel);
  mlir::FailureOr<RelOpInterface> rightOp = importRel(builder, rightRel);

  if (failed(leftOp) || failed(rightOp))
    return failure();

  // Convert the join type. A plain `static_cast` into the optional would
  // always produce an engaged value, so unknown enum values coming from the
  // protobuf message would go undetected. The generated symbolizer returns
  // `std::nullopt` for values without a matching `JoinTypeKind` case.
  std::optional<JoinTypeKind> joinType =
      symbolizeJoinTypeKind(static_cast<uint32_t>(joinRel.type()));

  // Check for unsupported join types.
  if (!joinType)
    return mlir::emitError(builder.getLoc(), "unexpected join 'type' found");

  // Build `JoinOp`.
  Value leftVal = leftOp.value()->getResult(0);
  Value rightVal = rightOp.value()->getResult(0);

  return builder.create<JoinOp>(leftVal, rightVal, *joinType);
}

static mlir::FailureOr<LiteralOp>
importLiteral(ImplicitLocOpBuilder builder,
const Expression::Literal &message) {
Expand Down Expand Up @@ -586,6 +615,9 @@ static mlir::FailureOr<RelOpInterface> importRel(ImplicitLocOpBuilder builder,
case Rel::RelTypeCase::kFilter:
maybeOp = importFilterRel(builder, message);
break;
case Rel::RelTypeCase::kJoin:
maybeOp = importJoinRel(builder, message);
break;
case Rel::RelTypeCase::kProject:
maybeOp = importProjectRel(builder, message);
break;
Expand Down
4 changes: 4 additions & 0 deletions lib/Target/SubstraitPB/ProtobufUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ FailureOr<const RelCommon *> getCommon(const Rel &rel, Location loc) {
return getCommon(rel.fetch());
case Rel::RelTypeCase::kFilter:
return getCommon(rel.filter());
case Rel::RelTypeCase::kJoin:
return getCommon(rel.join());
case Rel::RelTypeCase::kProject:
return getCommon(rel.project());
case Rel::RelTypeCase::kRead:
Expand Down Expand Up @@ -60,6 +62,8 @@ FailureOr<RelCommon *> getMutableCommon(Rel *rel, Location loc) {
return getMutableCommon((rel->mutable_fetch()));
case Rel::RelTypeCase::kFilter:
return getMutableCommon(rel->mutable_filter());
case Rel::RelTypeCase::kJoin:
return getMutableCommon(rel->mutable_join());
case Rel::RelTypeCase::kProject:
return getMutableCommon(rel->mutable_project());
case Rel::RelTypeCase::kRead:
Expand Down
152 changes: 152 additions & 0 deletions test/Dialect/Substrait/join.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
// RUN: substrait-opt -split-input-file %s \
// RUN: | FileCheck %s

// Test case: a `join` of type `unspecified` is printed back unchanged and its
// result tuple holds the left field followed by the right field.
// CHECK-LABEL: substrait.plan
// CHECK: relation
// CHECK: %[[V0:.*]] = named_table
// CHECK: %[[V1:.*]] = named_table
// CHECK-NEXT: %[[V2:.*]] = join unspecified %[[V0]], %[[V1]]
// CHECK-SAME: : tuple<si32>, tuple<si32> -> tuple<si32, si32>
// CHECK-NEXT: yield %[[V2]] : tuple<si32, si32>

substrait.plan version 0 : 42 : 1 {
relation {
%0 = named_table @t1 as ["a"] : tuple<si32>
%1 = named_table @t2 as ["b"] : tuple<si32>
%2 = join unspecified %0, %1 : tuple<si32>, tuple<si32> -> tuple<si32, si32>
yield %2 : tuple<si32, si32>
}
}

// -----

// Test case: a `join` of type `inner` is printed back unchanged and its
// result tuple holds the left field followed by the right field.
// CHECK-LABEL: substrait.plan
// CHECK: relation
// CHECK: %[[V0:.*]] = named_table
// CHECK: %[[V1:.*]] = named_table
// CHECK-NEXT: %[[V2:.*]] = join inner %[[V0]], %[[V1]]
// CHECK-SAME: : tuple<si32>, tuple<si32> -> tuple<si32, si32>
// CHECK-NEXT: yield %[[V2]] : tuple<si32, si32>

substrait.plan version 0 : 42 : 1 {
relation {
%0 = named_table @t1 as ["a"] : tuple<si32>
%1 = named_table @t2 as ["b"] : tuple<si32>
%2 = join inner %0, %1 : tuple<si32>, tuple<si32> -> tuple<si32, si32>
yield %2 : tuple<si32, si32>
}
}

// -----

// Test case: a `join` of type `outer` is printed back unchanged and its
// result tuple holds the left field followed by the right field.
// CHECK-LABEL: substrait.plan
// CHECK: relation
// CHECK: %[[V0:.*]] = named_table
// CHECK: %[[V1:.*]] = named_table
// CHECK-NEXT: %[[V2:.*]] = join outer %[[V0]], %[[V1]]
// CHECK-SAME: : tuple<si32>, tuple<si32> -> tuple<si32, si32>
// CHECK-NEXT: yield %[[V2]] : tuple<si32, si32>

substrait.plan version 0 : 42 : 1 {
relation {
%0 = named_table @t1 as ["a"] : tuple<si32>
%1 = named_table @t2 as ["b"] : tuple<si32>
%2 = join outer %0, %1 : tuple<si32>, tuple<si32> -> tuple<si32, si32>
yield %2 : tuple<si32, si32>
}
}

// -----

// Test case: a `join` of type `left` is printed back unchanged and its
// result tuple holds the left field followed by the right field.
// CHECK-LABEL: substrait.plan
// CHECK: relation
// CHECK: %[[V0:.*]] = named_table
// CHECK: %[[V1:.*]] = named_table
// CHECK-NEXT: %[[V2:.*]] = join left %[[V0]], %[[V1]]
// CHECK-SAME: : tuple<si32>, tuple<si32> -> tuple<si32, si32>
// CHECK-NEXT: yield %[[V2]] : tuple<si32, si32>

substrait.plan version 0 : 42 : 1 {
relation {
%0 = named_table @t1 as ["a"] : tuple<si32>
%1 = named_table @t2 as ["b"] : tuple<si32>
%2 = join left %0, %1 : tuple<si32>, tuple<si32> -> tuple<si32, si32>
yield %2 : tuple<si32, si32>
}
}

// -----

// Test case: a `join` of type `right` is printed back unchanged and its
// result tuple holds the left field followed by the right field.
// CHECK-LABEL: substrait.plan
// CHECK: relation
// CHECK: %[[V0:.*]] = named_table
// CHECK: %[[V1:.*]] = named_table
// CHECK-NEXT: %[[V2:.*]] = join right %[[V0]], %[[V1]]
// CHECK-SAME: : tuple<si32>, tuple<si32> -> tuple<si32, si32>
// CHECK-NEXT: yield %[[V2]] : tuple<si32, si32>

substrait.plan version 0 : 42 : 1 {
relation {
%0 = named_table @t1 as ["a"] : tuple<si32>
%1 = named_table @t2 as ["b"] : tuple<si32>
%2 = join right %0, %1 : tuple<si32>, tuple<si32> -> tuple<si32, si32>
yield %2 : tuple<si32, si32>
}
}

// -----

// Test case: a `join` of type `semi` yields only the fields of the left
// input. The inputs use different field types (`si1` vs. `si32`) so that the
// side the result is taken from is visible in the result type.
// CHECK-LABEL: substrait.plan
// CHECK: relation
// CHECK: %[[V0:.*]] = named_table
// CHECK: %[[V1:.*]] = named_table
// CHECK-NEXT: %[[V2:.*]] = join semi %[[V0]], %[[V1]]
// CHECK-SAME: : tuple<si1>, tuple<si32> -> tuple<si1>
// CHECK-NEXT: yield %[[V2]] : tuple<si1>

substrait.plan version 0 : 42 : 1 {
relation {
%0 = named_table @t1 as ["a"] : tuple<si1>
%1 = named_table @t2 as ["b"] : tuple<si32>
%2 = join semi %0, %1 : tuple<si1>, tuple<si32> -> tuple<si1>
yield %2 : tuple<si1>
}
}

// -----

// Test case: a `join` of type `anti` yields only the fields of the left
// input. The inputs use different field types (`si1` vs. `si32`) so that the
// side the result is taken from is visible in the result type.
// CHECK-LABEL: substrait.plan
// CHECK: relation
// CHECK: %[[V0:.*]] = named_table
// CHECK: %[[V1:.*]] = named_table
// CHECK-NEXT: %[[V2:.*]] = join anti %[[V0]], %[[V1]]
// CHECK-SAME: : tuple<si1>, tuple<si32> -> tuple<si1>
// CHECK-NEXT: yield %[[V2]] : tuple<si1>

substrait.plan version 0 : 42 : 1 {
relation {
%0 = named_table @t1 as ["a"] : tuple<si1>
%1 = named_table @t2 as ["b"] : tuple<si32>
%2 = join anti %0, %1 : tuple<si1>, tuple<si32> -> tuple<si1>
yield %2 : tuple<si1>
}
}

// -----

// Test case: a `join` of type `single` yields only the fields of the right
// input. The inputs use different field types (`si32` vs. `si1`) so that the
// side the result is taken from is visible in the result type.
// CHECK-LABEL: substrait.plan
// CHECK: relation
// CHECK: %[[V0:.*]] = named_table
// CHECK: %[[V1:.*]] = named_table
// CHECK-NEXT: %[[V2:.*]] = join single %[[V0]], %[[V1]]
// CHECK-SAME: : tuple<si32>, tuple<si1> -> tuple<si1>
// CHECK-NEXT: yield %[[V2]] : tuple<si1>

substrait.plan version 0 : 42 : 1 {
relation {
%0 = named_table @t1 as ["a"] : tuple<si32>
%1 = named_table @t2 as ["b"] : tuple<si1>
%2 = join single %0, %1 : tuple<si32>, tuple<si1> -> tuple<si1>
yield %2 : tuple<si1>
}
}
Loading
Loading