Skip to content

Commit

Permalink
do declarative type inference for SetOp
Browse files Browse the repository at this point in the history
  • Loading branch information
dshaaban01 committed Dec 3, 2024
1 parent b688ac8 commit 9664db3
Show file tree
Hide file tree
Showing 6 changed files with 30 additions and 69 deletions.
4 changes: 2 additions & 2 deletions include/substrait-mlir/Dialect/Substrait/IR/SubstraitOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -537,7 +537,7 @@ def SetOpKind : I32EnumAttr<"SetOpKind",
I32EnumAttrCase<"union_all", 6>] >;

def Substrait_SetOp : Substrait_RelOp<"set", [
DeclareOpInterfaceMethods<InferTypeOpInterface>
SameOperandsAndResultType
]> {
let summary = "set operation";
let description = [{
Expand All @@ -554,7 +554,7 @@ def Substrait_SetOp : Substrait_RelOp<"set", [

let results = (outs Substrait_Relation:$result);
let assemblyFormat = [{
$kind `,` $inputs attr-dict `:` type($inputs) `->` type($result)
$kind `,` $inputs attr-dict `:` type($result)
}];
}

Expand Down
27 changes: 0 additions & 27 deletions lib/Dialect/Substrait/IR/Substrait.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -256,33 +256,6 @@ LiteralOp::inferReturnTypes(MLIRContext *context, std::optional<Location> loc,
return success();
}

LogicalResult
SetOp::inferReturnTypes(MLIRContext *context, std::optional<Location> loc,
ValueRange operands, DictionaryAttr attributes,
OpaqueProperties properties, RegionRange regions,
llvm::SmallVectorImpl<Type> &inferredReturnTypes) {

ValueRange inputs = operands;

if (inputs.size() < 2)
return ::emitError(loc.value()) << "expected at least 2 inputs";

TypeRange fieldType = cast<TupleType>(inputs[0].getType()).getTypes();

for (Value input : inputs) {
TypeRange inputFieldTypes = cast<TupleType>(input.getType()).getTypes();
if (fieldType != inputFieldTypes)
return ::emitError(loc.value())
<< "all inputs must have the same field types";
}

TypeRange fieldTypes = cast<TupleType>(inputs[0].getType()).getTypes();
auto resultType = TupleType::get(context, fieldTypes);
inferredReturnTypes = SmallVector<Type>{resultType};

return success();
}

/// Verifies that the provided field names match the provided field types. While
/// the field types are potentially nested, the names are given in a single,
/// flat list and correspond to the field types in depth first order (where each
Expand Down
12 changes: 0 additions & 12 deletions test/Dialect/Substrait/set-invalid.mlir

This file was deleted.

28 changes: 14 additions & 14 deletions test/Dialect/Substrait/set.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,15 @@
// CHECK: %[[V1:.*]] = named_table
// CHECK: %[[V2:.*]] = named_table
// CHECK-NEXT: %[[V3:.*]] = set unspecified, %[[V0]], %[[V1]], %[[V2]]
// CHECK-SAME: : tuple<si32>, tuple<si32>, tuple<si32> -> tuple<si32>
// CHECK-SAME: : tuple<si32>
// CHECK-NEXT: yield %[[V3]] : tuple<si32>

substrait.plan version 0 : 42 : 1 {
relation {
%0 = named_table @t1 as ["a"] : tuple<si32>
%1 = named_table @t2 as ["b"] : tuple<si32>
%2 = named_table @t2 as ["c"] : tuple<si32>
%3 = set unspecified, %0, %1, %2: tuple<si32>, tuple<si32>, tuple<si32> -> tuple<si32>
%3 = set unspecified, %0, %1, %2: tuple<si32>
yield %3 : tuple<si32>
}
}
Expand All @@ -27,14 +27,14 @@ substrait.plan version 0 : 42 : 1 {
// CHECK: %[[V0:.*]] = named_table
// CHECK: %[[V1:.*]] = named_table
// CHECK-NEXT: %[[V2:.*]] = set minus_primary, %[[V0]], %[[V1]]
// CHECK-SAME: : tuple<si32>, tuple<si32> -> tuple<si32>
// CHECK-SAME: : tuple<si32>
// CHECK-NEXT: yield %[[V2]] : tuple<si32>

substrait.plan version 0 : 42 : 1 {
relation {
%0 = named_table @t1 as ["a"] : tuple<si32>
%1 = named_table @t2 as ["b"] : tuple<si32>
%2 = set minus_primary, %0, %1 : tuple<si32>, tuple<si32> -> tuple<si32>
%2 = set minus_primary, %0, %1 : tuple<si32>
yield %2 : tuple<si32>
}
}
Expand All @@ -46,14 +46,14 @@ substrait.plan version 0 : 42 : 1 {
// CHECK: %[[V0:.*]] = named_table
// CHECK: %[[V1:.*]] = named_table
// CHECK-NEXT: %[[V2:.*]] = set minus_multiset, %[[V0]], %[[V1]]
// CHECK-SAME: : tuple<si32>, tuple<si32> -> tuple<si32>
// CHECK-SAME: : tuple<si32>
// CHECK-NEXT: yield %[[V2]] : tuple<si32>

substrait.plan version 0 : 42 : 1 {
relation {
%0 = named_table @t1 as ["a"] : tuple<si32>
%1 = named_table @t2 as ["b"] : tuple<si32>
%2 = set minus_multiset, %0, %1 : tuple<si32>, tuple<si32> -> tuple<si32>
%2 = set minus_multiset, %0, %1 : tuple<si32>
yield %2 : tuple<si32>
}
}
Expand All @@ -64,14 +64,14 @@ substrait.plan version 0 : 42 : 1 {
// CHECK: %[[V0:.*]] = named_table
// CHECK: %[[V1:.*]] = named_table
// CHECK-NEXT: %[[V2:.*]] = set intersection_primary, %[[V0]], %[[V1]]
// CHECK-SAME: : tuple<si32>, tuple<si32> -> tuple<si32>
// CHECK-SAME: : tuple<si32>
// CHECK-NEXT: yield %[[V2]] : tuple<si32>

substrait.plan version 0 : 42 : 1 {
relation {
%0 = named_table @t1 as ["a"] : tuple<si32>
%1 = named_table @t2 as ["b"] : tuple<si32>
%2 = set intersection_primary, %0, %1 : tuple<si32>, tuple<si32> -> tuple<si32>
%2 = set intersection_primary, %0, %1 : tuple<si32>
yield %2 : tuple<si32>
}
}
Expand All @@ -82,14 +82,14 @@ substrait.plan version 0 : 42 : 1 {
// CHECK: %[[V0:.*]] = named_table
// CHECK: %[[V1:.*]] = named_table
// CHECK-NEXT: %[[V2:.*]] = set intersection_multiset, %[[V0]], %[[V1]]
// CHECK-SAME: : tuple<si32>, tuple<si32> -> tuple<si32>
// CHECK-SAME: : tuple<si32>
// CHECK-NEXT: yield %[[V2]] : tuple<si32>

substrait.plan version 0 : 42 : 1 {
relation {
%0 = named_table @t1 as ["a"] : tuple<si32>
%1 = named_table @t2 as ["b"] : tuple<si32>
%2 = set intersection_multiset, %0, %1 : tuple<si32>, tuple<si32> -> tuple<si32>
%2 = set intersection_multiset, %0, %1 : tuple<si32>
yield %2 : tuple<si32>
}
}
Expand All @@ -100,14 +100,14 @@ substrait.plan version 0 : 42 : 1 {
// CHECK: %[[V0:.*]] = named_table
// CHECK: %[[V1:.*]] = named_table
// CHECK-NEXT: %[[V2:.*]] = set union_distinct, %[[V0]], %[[V1]]
// CHECK-SAME: : tuple<si32>, tuple<si32> -> tuple<si32>
// CHECK-SAME: : tuple<si32>
// CHECK-NEXT: yield %[[V2]] : tuple<si32>

substrait.plan version 0 : 42 : 1 {
relation {
%0 = named_table @t1 as ["a"] : tuple<si32>
%1 = named_table @t2 as ["b"] : tuple<si32>
%2 = set union_distinct, %0, %1 : tuple<si32>, tuple<si32> -> tuple<si32>
%2 = set union_distinct, %0, %1 : tuple<si32>
yield %2 : tuple<si32>
}
}
Expand All @@ -118,14 +118,14 @@ substrait.plan version 0 : 42 : 1 {
// CHECK: %[[V0:.*]] = named_table
// CHECK: %[[V1:.*]] = named_table
// CHECK-NEXT: %[[V2:.*]] = set union_all, %[[V0]], %[[V1]]
// CHECK-SAME: : tuple<si32>, tuple<si32> -> tuple<si32>
// CHECK-SAME: : tuple<si32>
// CHECK-NEXT: yield %[[V2]] : tuple<si32>

substrait.plan version 0 : 42 : 1 {
relation {
%0 = named_table @t1 as ["a"] : tuple<si32>
%1 = named_table @t2 as ["b"] : tuple<si32>
%2 = set union_all, %0, %1 : tuple<si32>, tuple<si32> -> tuple<si32>
%2 = set union_all, %0, %1 : tuple<si32>
yield %2 : tuple<si32>
}
}
14 changes: 7 additions & 7 deletions test/Target/SubstraitPB/Export/set.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ substrait.plan version 0 : 42 : 1 {
relation {
%0 = named_table @t1 as ["a"] : tuple<si32>
%1 = named_table @t2 as ["b"] : tuple<si32>
%2 = set unspecified, %0, %1 : tuple<si32>, tuple<si32> -> tuple<si32>
%2 = set unspecified, %0, %1 : tuple<si32>
yield %2 : tuple<si32>
}
}
Expand All @@ -51,7 +51,7 @@ substrait.plan version 0 : 42 : 1 {
relation {
%0 = named_table @t1 as ["a"] : tuple<si32>
%1 = named_table @t2 as ["b"] : tuple<si32>
%2 = set minus_primary, %0, %1 : tuple<si32>, tuple<si32> -> tuple<si32>
%2 = set minus_primary, %0, %1 : tuple<si32>
yield %2 : tuple<si32>
}
}
Expand All @@ -75,7 +75,7 @@ substrait.plan version 0 : 42 : 1 {
relation {
%0 = named_table @t1 as ["a"] : tuple<si32>
%1 = named_table @t2 as ["b"] : tuple<si32>
%2 = set minus_multiset, %0, %1 : tuple<si32>, tuple<si32> -> tuple<si32>
%2 = set minus_multiset, %0, %1 : tuple<si32>
yield %2 : tuple<si32>
}
}
Expand All @@ -99,7 +99,7 @@ substrait.plan version 0 : 42 : 1 {
relation {
%0 = named_table @t1 as ["a"] : tuple<si32>
%1 = named_table @t2 as ["b"] : tuple<si32>
%2 = set intersection_primary, %0, %1 : tuple<si32>, tuple<si32> -> tuple<si32>
%2 = set intersection_primary, %0, %1 : tuple<si32>
yield %2 : tuple<si32>
}
}
Expand All @@ -123,7 +123,7 @@ substrait.plan version 0 : 42 : 1 {
relation {
%0 = named_table @t1 as ["a"] : tuple<si32>
%1 = named_table @t2 as ["b"] : tuple<si32>
%2 = set intersection_multiset, %0, %1 : tuple<si32>, tuple<si32> -> tuple<si32>
%2 = set intersection_multiset, %0, %1 : tuple<si32>
yield %2 : tuple<si32>
}
}
Expand All @@ -147,7 +147,7 @@ substrait.plan version 0 : 42 : 1 {
relation {
%0 = named_table @t1 as ["a"] : tuple<si32>
%1 = named_table @t2 as ["b"] : tuple<si32>
%2 = set union_distinct, %0, %1 : tuple<si32>, tuple<si32> -> tuple<si32>
%2 = set union_distinct, %0, %1 : tuple<si32>
yield %2 : tuple<si32>
}
}
Expand All @@ -171,7 +171,7 @@ substrait.plan version 0 : 42 : 1 {
relation {
%0 = named_table @t1 as ["a"] : tuple<si32>
%1 = named_table @t2 as ["b"] : tuple<si32>
%2 = set union_all, %0, %1 : tuple<si32>, tuple<si32> -> tuple<si32>
%2 = set union_all, %0, %1 : tuple<si32>
yield %2 : tuple<si32>
}
}
14 changes: 7 additions & 7 deletions test/Target/SubstraitPB/Import/set.textpb
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
# CHECK-NEXT: relation {
# CHECK-NEXT: %[[V0:.*]] = named_table
# CHECK-NEXT: %[[V1:.*]] = named_table
# CHECK-NEXT: %[[V2:.*]] = set unspecified, %[[V0]], %[[V1]] : tuple<si32>, tuple<si32> -> tuple<si32>
# CHECK-NEXT: %[[V2:.*]] = set unspecified, %[[V0]], %[[V1]] : tuple<si32>
# CHECK-NEXT: yield %[[V2]] : tuple<si32>

relations {
Expand Down Expand Up @@ -82,7 +82,7 @@ version {
# CHECK-NEXT: relation {
# CHECK-NEXT: %[[V0:.*]] = named_table
# CHECK-NEXT: %[[V1:.*]] = named_table
# CHECK-NEXT: %[[V2:.*]] = set minus_primary, %[[V0]], %[[V1]] : tuple<si32>, tuple<si32> -> tuple<si32>
# CHECK-NEXT: %[[V2:.*]] = set minus_primary, %[[V0]], %[[V1]] : tuple<si32>
# CHECK-NEXT: yield %[[V2]] : tuple<si32>

relations {
Expand Down Expand Up @@ -151,7 +151,7 @@ version {
# CHECK-NEXT: relation {
# CHECK-NEXT: %[[V0:.*]] = named_table
# CHECK-NEXT: %[[V1:.*]] = named_table
# CHECK-NEXT: %[[V2:.*]] = set minus_multiset, %[[V0]], %[[V1]] : tuple<si32>, tuple<si32> -> tuple<si32>
# CHECK-NEXT: %[[V2:.*]] = set minus_multiset, %[[V0]], %[[V1]] : tuple<si32>
# CHECK-NEXT: yield %[[V2]] : tuple<si32>

relations {
Expand Down Expand Up @@ -220,7 +220,7 @@ version {
# CHECK-NEXT: relation {
# CHECK-NEXT: %[[V0:.*]] = named_table
# CHECK-NEXT: %[[V1:.*]] = named_table
# CHECK-NEXT: %[[V2:.*]] = set intersection_primary, %[[V0]], %[[V1]] : tuple<si32>, tuple<si32> -> tuple<si32>
# CHECK-NEXT: %[[V2:.*]] = set intersection_primary, %[[V0]], %[[V1]] : tuple<si32>
# CHECK-NEXT: yield %[[V2]] : tuple<si32>

relations {
Expand Down Expand Up @@ -289,7 +289,7 @@ version {
# CHECK-NEXT: relation {
# CHECK-NEXT: %[[V0:.*]] = named_table
# CHECK-NEXT: %[[V1:.*]] = named_table
# CHECK-NEXT: %[[V2:.*]] = set intersection_multiset, %[[V0]], %[[V1]] : tuple<si32>, tuple<si32> -> tuple<si32>
# CHECK-NEXT: %[[V2:.*]] = set intersection_multiset, %[[V0]], %[[V1]] : tuple<si32>
# CHECK-NEXT: yield %[[V2]] : tuple<si32>

relations {
Expand Down Expand Up @@ -358,7 +358,7 @@ version {
# CHECK-NEXT: relation {
# CHECK-NEXT: %[[V0:.*]] = named_table
# CHECK-NEXT: %[[V1:.*]] = named_table
# CHECK-NEXT: %[[V2:.*]] = set union_distinct, %[[V0]], %[[V1]] : tuple<si32>, tuple<si32> -> tuple<si32>
# CHECK-NEXT: %[[V2:.*]] = set union_distinct, %[[V0]], %[[V1]] : tuple<si32>
# CHECK-NEXT: yield %[[V2]] : tuple<si32>

relations {
Expand Down Expand Up @@ -427,7 +427,7 @@ version {
# CHECK-NEXT: relation {
# CHECK-NEXT: %[[V0:.*]] = named_table
# CHECK-NEXT: %[[V1:.*]] = named_table
# CHECK-NEXT: %[[V2:.*]] = set union_all, %[[V0]], %[[V1]] : tuple<si32>, tuple<si32> -> tuple<si32>
# CHECK-NEXT: %[[V2:.*]] = set union_all, %[[V0]], %[[V1]] : tuple<si32>
# CHECK-NEXT: yield %[[V2]] : tuple<si32>

relations {
Expand Down

0 comments on commit 9664db3

Please sign in to comment.