[MHLO] Pad constant mode & GatherElements to MHLO
Signed-off-by: chongsong.chen <[email protected]>
Commit 18e93f8, 1 parent defe402. Showing 8 changed files with 294 additions and 1 deletion.
New file: GatherElements.cpp (139 additions)
/*
 * SPDX-License-Identifier: Apache-2.0
 */

//===-------- GatherElements.cpp - Lowering GatherElements Op -------------===//
//
// Copyright 2020-2022 The IBM Research Authors.
//
// =============================================================================
//
// This file lowers the ONNX GatherElements operator to the MHLO dialect.
//
//===----------------------------------------------------------------------===//

#include "src/Conversion/ONNXToMhlo/DialectBuilder.hpp"
#include "src/Conversion/ONNXToMhlo/ONNXToMhloCommon.hpp"
#include "src/Dialect/ONNX/ONNXOps/ShapeHelper.hpp"
#include "src/Support/TypeUtilities.hpp"

using namespace mlir;

namespace onnx_mlir {

namespace {

struct ONNXGatherElementsOpLoweringToMhlo : public ConversionPattern {
  ONNXGatherElementsOpLoweringToMhlo(MLIRContext *ctx)
      : ConversionPattern(
            mlir::ONNXGatherElementsOp::getOperationName(), 1, ctx) {}

  LogicalResult matchAndRewrite(Operation *op, ArrayRef<Value> operands,
      ConversionPatternRewriter &rewriter) const final {
    ONNXGatherElementsOpAdaptor operandAdaptor(operands);
    ONNXGatherElementsOp gatherOp = cast<ONNXGatherElementsOp>(op);
    Location loc = op->getLoc();

    IndexExprBuilderForMhlo createIE(rewriter, loc);
    ONNXGatherElementsOpShapeHelper shapeHelper(op, operands, &createIE);
    shapeHelper.computeShapeAndAssertOnFailure();

    Type outputType = *op->result_type_begin();
    assert(isRankedShapedType(outputType) && "Expected Ranked ShapedType");

    // Operands and attributes.
    Value data = operandAdaptor.getData();
    Value indices = operandAdaptor.getIndices();
    int64_t axisLit = gatherOp.getAxis();

    ShapedType inputType = data.getType().cast<ShapedType>();
    int64_t rank = inputType.getRank(); // indices has the same rank.
    ShapedType indicesType = indices.getType().cast<ShapedType>();
    Type indexElemType = indicesType.getElementType();
    // A negative axis means counting dimensions from the back.
    axisLit = axisLit < 0 ? axisLit + rank : axisLit;

    // Normalize the indices so that all values are >= 0: negative indices
    // count from the back, so add the axis extent to them and select.
    Value zero = getShapedZero(loc, rewriter, indices);
    Value inputShape = rewriter.create<shape::ShapeOfOp>(loc, data);
    Value indicesShape = rewriter.create<shape::ShapeOfOp>(loc, indices);
    Value axisDimSize =
        rewriter.create<shape::GetExtentOp>(loc, inputShape, axisLit);
    axisDimSize =
        rewriter.create<arith::IndexCastOp>(loc, indexElemType, axisDimSize);
    axisDimSize = rewriter.create<tensor::FromElementsOp>(loc, axisDimSize);
    axisDimSize = rewriter.create<mhlo::ReshapeOp>(loc,
        RankedTensorType::get(SmallVector<int64_t>{}, indexElemType),
        axisDimSize);
    Value broadcastedAxisDimSize =
        rewriter.create<mhlo::DynamicBroadcastInDimOp>(loc, indicesType,
            axisDimSize, indicesShape, rewriter.getI64TensorAttr({}));
    Value isNegative = rewriter.create<mhlo::CompareOp>(
        loc, indices, zero, mhlo::ComparisonDirection::LT);
    Value positiveIndices = rewriter.create<mhlo::AddOp>(
        loc, indicesType, indices, broadcastedAxisDimSize);
    indices = rewriter.create<mhlo::SelectOp>(
        loc, indicesType, isNegative, positiveIndices, indices);

    // Build the start-indices shape: the shape of the indices tensor with a
    // trailing unit dimension appended for the index vector.
    Value toConcatIndexShape;
    SmallVector<Value> toConcatIndexShapeValueVec;
    for (int64_t i = 0; i < rank; i++) {
      toConcatIndexShapeValueVec.push_back(
          rewriter.create<shape::GetExtentOp>(loc, indicesShape, i));
    }
    toConcatIndexShapeValueVec.push_back(
        rewriter.create<arith::ConstantIndexOp>(loc, 1));
    toConcatIndexShape = rewriter.create<tensor::FromElementsOp>(
        loc, toConcatIndexShapeValueVec);

    ArrayRef<int64_t> indicesShapeVec = indicesType.getShape();
    SmallVector<int64_t> toConcatIndexShapeVec(
        indicesShapeVec.begin(), indicesShapeVec.end());
    toConcatIndexShapeVec.push_back(1);
    RankedTensorType toConcatIndexType =
        RankedTensorType::get(toConcatIndexShapeVec, indexElemType);

    // Along the gather axis use the normalized indices; along every other
    // dimension use an iota so each element keeps its own position.
    SmallVector<Value> toConcat;
    for (int64_t i = 0; i < inputType.getRank(); ++i) {
      if (i == axisLit) {
        toConcat.push_back(rewriter.create<mhlo::DynamicReshapeOp>(
            loc, toConcatIndexType, indices, toConcatIndexShape));
      } else {
        toConcat.push_back(
            rewriter.create<mhlo::DynamicIotaOp>(loc, toConcatIndexType,
                toConcatIndexShape, rewriter.getI64IntegerAttr(i)));
      }
    }
    auto gatherIndices = rewriter.create<mhlo::ConcatenateOp>(
        loc, toConcat, static_cast<uint64_t>(inputType.getRank()));

    // Gather dimension numbers: every input dimension is indexed and
    // collapsed, so each gathered slice is a single element.
    SmallVector<int64_t> collapsedDims;
    SmallVector<int64_t> startIndexMap;
    for (int64_t i = 0; i < rank; i++) {
      collapsedDims.push_back(i);
      startIndexMap.push_back(i);
    }
    auto dimsAttr = mhlo::GatherDimensionNumbersAttr::get(rewriter.getContext(),
        /*offsetDims=*/{},
        /*collapsedSliceDims=*/collapsedDims,
        /*startIndexMap=*/startIndexMap,
        /*indexVectorDim=*/rank);
    SmallVector<int64_t> sliceSizes(inputType.getRank(), 1);

    Value gatherValue = rewriter.create<mhlo::GatherOp>(loc, outputType, data,
        gatherIndices, dimsAttr, rewriter.getI64TensorAttr(sliceSizes));
    rewriter.replaceOp(op, gatherValue);
    return success();
  }
};

} // namespace

void populateLoweringONNXGatherElementsOpToMhloPattern(
    RewritePatternSet &patterns, MLIRContext *ctx) {
  patterns.insert<ONNXGatherElementsOpLoweringToMhlo>(ctx);
}

} // namespace onnx_mlir
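For reference, the lowering realizes the ONNX GatherElements semantics: along the gather axis each output element picks the position named by `indices`, and along all other axes it keeps its own position. Below is a minimal standalone C++ sketch of those semantics for rank-2 tensors (illustrative only; `gatherElements2D` is a made-up name, not part of this commit):

#include <cstdint>
#include <vector>

// output[i][j] = data[indices[i][j]][j] when axis == 0;
// output[i][j] = data[i][indices[i][j]] when axis == 1.
// Negative indices count from the back, mirroring the normalization
// step (add the axis extent, then select) in the lowering above.
std::vector<std::vector<float>> gatherElements2D(
    const std::vector<std::vector<float>> &data,
    const std::vector<std::vector<int64_t>> &indices, int64_t axis) {
  int64_t axisSize =
      static_cast<int64_t>(axis == 0 ? data.size() : data[0].size());
  std::vector<std::vector<float>> out(
      indices.size(), std::vector<float>(indices[0].size()));
  for (size_t i = 0; i < indices.size(); ++i) {
    for (size_t j = 0; j < indices[i].size(); ++j) {
      int64_t idx = indices[i][j];
      if (idx < 0)
        idx += axisSize; // normalize negative indices
      out[i][j] = (axis == 0) ? data[idx][j] : data[i][idx];
    }
  }
  return out;
}

The mhlo.gather built above computes the same mapping in a single op by concatenating iotas with the normalized indices into full coordinate vectors.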
New file: Pad.cpp (99 additions)
/*
 * SPDX-License-Identifier: Apache-2.0
 */

//===----------- Pad.cpp - Lowering Pad Op ------------===//
//
// Copyright 2022
//
// =============================================================================
//
// This file lowers the ONNX Pad operator to the MHLO dialect.
//
//===----------------------------------------------------------------------===//

#include "src/Conversion/ONNXToMhlo/ONNXToMhloCommon.hpp"
#include "src/Dialect/ONNX/ElementsAttr/DisposableElementsAttr.hpp"
#include "src/Dialect/ONNX/ONNXOps/ShapeHelper.hpp"
#include "src/Support/TypeUtilities.hpp"

using namespace mlir;

namespace onnx_mlir {

namespace {

struct ONNXPadOpLoweringToMhlo : public ConversionPattern {
  ONNXPadOpLoweringToMhlo(MLIRContext *ctx)
      : ConversionPattern(mlir::ONNXPadOp::getOperationName(), 1, ctx) {}

  LogicalResult matchAndRewrite(Operation *op, ArrayRef<Value> operands,
      ConversionPatternRewriter &rewriter) const final {
    Location loc = op->getLoc();
    ONNXPadOpAdaptor operandAdaptor(operands, op->getAttrDictionary());
    Value data = operandAdaptor.getData();
    Value constantValue = operandAdaptor.getConstantValue();
    Value pads = operandAdaptor.getPads();
    StringRef padMode = operandAdaptor.getMode();

    // Only the "constant" mode is supported for now.
    if (!padMode.equals_insensitive("constant"))
      return failure();
    assert(isRankedShapedType(data.getType()) && "Expected Ranked ShapedType");
    ShapedType inputType = data.getType().cast<ShapedType>();
    Type elemType = inputType.getElementType();
    int64_t rank = inputType.getRank();

    Type outputType = *op->result_type_begin();
    if (!constantValue || isNoneValue(constantValue)) {
      // Pad with zeros by default.
      constantValue = rewriter.create<mhlo::ConstantOp>(
          loc, DenseElementsAttr::get(mlir::RankedTensorType::get({}, elemType),
                   rewriter.getZeroAttr(elemType)));
    } else {
      // constantValue might be a 1-D tensor; reshape it to a scalar.
      constantValue = rewriter.create<mhlo::ReshapeOp>(
          loc, RankedTensorType::get({}, elemType), constantValue);
    }
    SmallVector<int64_t> edgePaddingLowVec(rank, 0);
    SmallVector<int64_t> edgePaddingHighVec(rank, 0);
    SmallVector<int64_t> interiorPaddingVec(rank, 0);
    if (auto valueAttribute = getElementAttributeFromMhloValue(pads)) {
      // If `pads` is constant, read its values: the first `rank` entries are
      // the begin (low) pads, the last `rank` entries the end (high) pads.
      int64_t idx = 0;
      for (IntegerAttr value : valueAttribute.getValues<IntegerAttr>()) {
        int64_t padValue = value.getInt();
        if (padValue < 0)
          return failure();
        if (idx < rank)
          edgePaddingLowVec[idx] = padValue;
        else
          edgePaddingHighVec[idx - rank] = padValue;
        idx++;
      }
    } else {
      assert(false && "Pads must be known at compile time");
    }

    mlir::DenseIntElementsAttr edgePaddingLow =
        rewriter.getI64VectorAttr(edgePaddingLowVec);
    mlir::DenseIntElementsAttr edgePaddingHigh =
        rewriter.getI64VectorAttr(edgePaddingHighVec);
    mlir::DenseIntElementsAttr interiorPadding =
        rewriter.getI64VectorAttr(interiorPaddingVec);
    Value padResult = rewriter.create<mhlo::PadOp>(loc, outputType, data,
        constantValue, edgePaddingLow, edgePaddingHigh, interiorPadding);

    rewriter.replaceOp(op, padResult);
    return success();
  }
};

} // namespace

void populateLoweringONNXPadOpToMhloPattern(
    RewritePatternSet &patterns, MLIRContext *ctx) {
  patterns.insert<ONNXPadOpLoweringToMhlo>(ctx);
}

} // namespace onnx_mlir
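A note on the `pads` decoding above: ONNX stores all begin pads first, then all end pads. A small sketch of the split, with values matching the Pad lit test at the end of this diff (`splitOnnxPads` is an illustrative helper, not part of the commit):

#include <cstdint>
#include <utility>
#include <vector>

// ONNX `pads` layout for a rank-r input:
//   [x1_begin, ..., xr_begin, x1_end, ..., xr_end]
// The pattern above copies the first r entries into edge_padding_low and
// the remaining r entries into edge_padding_high; interior padding stays 0.
std::pair<std::vector<int64_t>, std::vector<int64_t>> splitOnnxPads(
    const std::vector<int64_t> &pads, int64_t rank) {
  std::vector<int64_t> low(pads.begin(), pads.begin() + rank);
  std::vector<int64_t> high(pads.begin() + rank, pads.begin() + 2 * rank);
  return {low, high};
}
// Example: pads = [0, 0, 1, 1, 0, 0, 1, 1] with rank 4 yields
// low = [0, 0, 1, 1] and high = [0, 0, 1, 1].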
New file: test/mlir/conversion/onnx_to_mhlo/Tensor/GatherElements.mlir (18 additions)
// RUN: onnx-mlir-opt --convert-onnx-to-mhlo %s --canonicalize -split-input-file | FileCheck %s

func.func @main_gather_elements(%arg0: tensor<3x2xf32>, %arg1: tensor<2x2xi64>) -> tensor<2x2xf32> {
  %0 = "onnx.GatherElements"(%arg0, %arg1) {axis = 0 : si64} : (tensor<3x2xf32>, tensor<2x2xi64>) -> tensor<2x2xf32>
  return %0 : tensor<2x2xf32>
// CHECK: func.func @main_gather_elements([[PARAM_0_:%.+]]: tensor<3x2xf32>, [[PARAM_1_:%.+]]: tensor<2x2xi64>) -> tensor<2x2xf32> {
// CHECK-DAG: [[VAR_0_:%.+]] = mhlo.constant dense<3> : tensor<2x2xi64>
// CHECK-DAG: [[VAR_1_:%.+]] = mhlo.constant dense<0> : tensor<2x2xi64>
// CHECK-DAG: [[VAR_2_:%.+]] = mhlo.compare LT, [[PARAM_1_]], [[VAR_1_]], NOTYPE : (tensor<2x2xi64>, tensor<2x2xi64>) -> tensor<2x2xi1>
// CHECK-DAG: [[VAR_3_:%.+]] = mhlo.add [[PARAM_1_]], [[VAR_0_]] : tensor<2x2xi64>
// CHECK-NEXT: [[VAR_4_:%.+]] = mhlo.select [[VAR_2_]], [[VAR_3_]], [[PARAM_1_]] : tensor<2x2xi1>, tensor<2x2xi64>
// CHECK-NEXT: [[VAR_5_:%.+]] = mhlo.reshape [[VAR_4_]] : (tensor<2x2xi64>) -> tensor<2x2x1xi64>
// CHECK-DAG: [[VAR_6_:%.+]] = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<2xi64>
// CHECK-NEXT: [[VAR_7_:%.+]] = "mhlo.broadcast_in_dim"([[VAR_6_]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<2xi64>) -> tensor<2x2x1xi64>
// CHECK-NEXT: [[VAR_8_:%.+]] = "mhlo.concatenate"([[VAR_5_]], [[VAR_7_]]) {dimension = 2 : i64} : (tensor<2x2x1xi64>, tensor<2x2x1xi64>) -> tensor<2x2x2xi64>
// CHECK-NEXT: [[VAR_9_:%.+]] = "mhlo.gather"([[PARAM_0_]], [[VAR_8_]]) {dimension_numbers = #mhlo.gather<collapsed_slice_dims = [0, 1], start_index_map = [0, 1], index_vector_dim = 2>, indices_are_sorted = false, slice_sizes = dense<1> : tensor<2xi64>} : (tensor<3x2xf32>, tensor<2x2x2xi64>) -> tensor<2x2xf32>
// CHECK-NEXT: return [[VAR_9_]] : tensor<2x2xf32>
}
New file: Pad constant-mode lit test (13 additions)
// RUN: onnx-mlir-opt --convert-onnx-to-mhlo %s --canonicalize -split-input-file | FileCheck %s

func.func @test_pad_constant(%arg0: tensor<1x3x5x5xf32>) -> tensor<1x3x7x7xf32> {
  %0 = onnx.Constant dense<[0, 0, 1, 1, 0, 0, 1, 1]> : tensor<8xi64>
  %1 = onnx.Constant dense<2.000000e+00> : tensor<f32>
  %2 = "onnx.NoValue"() {value} : () -> none
  %3 = "onnx.Pad"(%arg0, %0, %1, %2) {mode = "constant"} : (tensor<1x3x5x5xf32>, tensor<8xi64>, tensor<f32>, none) -> tensor<1x3x7x7xf32>
  return %3 : tensor<1x3x7x7xf32>
// CHECK-LABEL: func.func @test_pad_constant(%arg0: tensor<1x3x5x5xf32>) -> tensor<1x3x7x7xf32> {
// CHECK-NEXT: %0 = mhlo.constant dense<2.000000e+00> : tensor<f32>
// CHECK-NEXT: %1 = "mhlo.pad"(%arg0, %0) {edge_padding_high = dense<[0, 0, 1, 1]> : vector<4xi64>, edge_padding_low = dense<[0, 0, 1, 1]> : vector<4xi64>, interior_padding = dense<0> : vector<4xi64>} : (tensor<1x3x5x5xf32>, tensor<f32>) -> tensor<1x3x7x7xf32>
// CHECK-NEXT: return %1 : tensor<1x3x7x7xf32>
}
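The result type in the CHECK lines follows the mhlo.pad output-extent rule; a quick sanity check, as a sketch (`paddedExtent` is a made-up helper for illustration):

#include <cstdint>

// mhlo.pad output extent per dimension:
//   out = in + low + high + max(in - 1, 0) * interior
int64_t paddedExtent(int64_t in, int64_t low, int64_t high, int64_t interior) {
  return in + low + high + (in > 0 ? in - 1 : 0) * interior;
}
// For the test above: paddedExtent(5, 1, 1, 0) == 7, so
// tensor<1x3x5x5xf32> pads to tensor<1x3x7x7xf32>.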