Skip to content

Commit

Permalink
[MHLO] Pad constant mode & GatherElements to MHLO
Browse files Browse the repository at this point in the history
Signed-off-by: chongsong.chen <[email protected]>
  • Loading branch information
chenchongsong committed Jul 17, 2023
1 parent defe402 commit 18e93f8
Show file tree
Hide file tree
Showing 8 changed files with 294 additions and 1 deletion.
2 changes: 2 additions & 0 deletions src/Conversion/ONNXToMhlo/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,9 @@ add_onnx_mlir_library(OMONNXToMhlo
Tensor/Expand.cpp
Tensor/Flatten.cpp
Tensor/Gather.cpp
Tensor/GatherElements.cpp
Tensor/Identity.cpp
Tensor/Pad.cpp
Tensor/Reshape.cpp
Tensor/Shape.cpp
Tensor/Slice.cpp
Expand Down
5 changes: 4 additions & 1 deletion src/Conversion/ONNXToMhlo/ConvertONNXToMhlo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,9 @@ void populateONNXToMhloConversionPattern(
populateLoweringONNXExpandOpToMhloPattern(patterns, ctx);
populateLoweringONNXFlattenOpToMhloPattern(patterns, ctx);
populateLoweringONNXGatherOpToMhloPattern(patterns, ctx);
populateLoweringONNXGatherElementsOpToMhloPattern(patterns, ctx);
populateLoweringONNXIdentityOpToMhloPattern(patterns, ctx);
populateLoweringONNXPadOpToMhloPattern(patterns, ctx);
populateLoweringONNXReshapeOpToMhloPattern(patterns, ctx);
populateLoweringONNXShapeOpToMhloPattern(patterns, ctx);
populateLoweringONNXSliceOpToMhloPattern(patterns, ctx);
Expand Down Expand Up @@ -89,7 +91,8 @@ void FrontendToMhloLoweringPass::runOnOperation() {
// Added affine as some affine maps are generated by IndexExpression. It could
// be disabled and/or replaced by shape max/min.
target.addLegalDialect<mhlo::MhloDialect, func::FuncDialect,
arith::ArithDialect, shape::ShapeDialect, mlir::affine::AffineDialect>();
arith::ArithDialect, shape::ShapeDialect, mlir::affine::AffineDialect,
tensor::TensorDialect>();
// Needed to support unsigned int computations. To be removed if we use a
// scheme that does not rely on the UnrealizedConversionCastOp.
target.addLegalOp<::mlir::UnrealizedConversionCastOp>();
Expand Down
13 changes: 13 additions & 0 deletions src/Conversion/ONNXToMhlo/ONNXToMhloCommon.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -93,4 +93,17 @@ llvm::SmallVector<Value, 4> getBroadcastedOperands(
}
return broadcastedOperands;
}

ElementsAttr getElementAttributeFromMhloValue(Value value) {
auto definingOp = value.getDefiningOp();
if (auto constantOp = dyn_cast_or_null<mhlo::ConstantOp>(definingOp))
return constantOp.getValue().dyn_cast<ElementsAttr>();
else if (auto constantOp =
dyn_cast_or_null<mlir::ONNXConstantOp>(definingOp)) {
if (constantOp.getValue().has_value())
return constantOp.getValueAttr().dyn_cast<ElementsAttr>();
}
return nullptr;
}

} // namespace onnx_mlir
6 changes: 6 additions & 0 deletions src/Conversion/ONNXToMhlo/ONNXToMhloCommon.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
Expand Down Expand Up @@ -113,6 +114,8 @@ llvm::SmallVector<Value, 4> getBroadcastedOperands(
llvm::SmallVector<Value, 4> &operands, Type outputType,
ConversionPatternRewriter &rewriter, Location loc, int64_t outputRank);

mlir::ElementsAttr getElementAttributeFromMhloValue(mlir::Value value);

// `Math` directory methods:
void populateLoweringONNXClipOpToMhloPattern(
RewritePatternSet &, MLIRContext *);
Expand Down Expand Up @@ -148,8 +151,11 @@ void populateLoweringONNXFlattenOpToMhloPattern(
RewritePatternSet &, MLIRContext *);
void populateLoweringONNXGatherOpToMhloPattern(
RewritePatternSet &, MLIRContext *);
void populateLoweringONNXGatherElementsOpToMhloPattern(
RewritePatternSet &, MLIRContext *);
void populateLoweringONNXIdentityOpToMhloPattern(
RewritePatternSet &, MLIRContext *);
void populateLoweringONNXPadOpToMhloPattern(RewritePatternSet &, MLIRContext *);
void populateLoweringONNXReshapeOpToMhloPattern(
RewritePatternSet &, MLIRContext *);
void populateLoweringONNXShapeOpToMhloPattern(
Expand Down
139 changes: 139 additions & 0 deletions src/Conversion/ONNXToMhlo/Tensor/GatherElements.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
/*
* SPDX-License-Identifier: Apache-2.0
*/

//===-------- GatherElements.cpp - Lowering GatherElements Op -------------===//
//
// Copyright 2020-2022 The IBM Research Authors.
//
// =============================================================================
//
// This file lowers the ONNX GatherElements Operator to Mhlo dialect.
//
//===----------------------------------------------------------------------===//

#include "src/Conversion/ONNXToMhlo/DialectBuilder.hpp"
#include "src/Conversion/ONNXToMhlo/ONNXToMhloCommon.hpp"
#include "src/Dialect/ONNX/ONNXOps/ShapeHelper.hpp"
#include "src/Support/TypeUtilities.hpp"

using namespace mlir;

namespace onnx_mlir {

namespace {

struct ONNXGatherElementsOpLoweringToMhlo : public ConversionPattern {
ONNXGatherElementsOpLoweringToMhlo(MLIRContext *ctx)
: ConversionPattern(
mlir::ONNXGatherElementsOp::getOperationName(), 1, ctx) {}

LogicalResult matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {
ONNXGatherElementsOpAdaptor operandAdaptor(operands);
ONNXGatherElementsOp gatherOp = cast<ONNXGatherElementsOp>(op);
Location loc = op->getLoc();

IndexExprBuilderForMhlo createIE(rewriter, loc);
ONNXGatherElementsOpShapeHelper shapeHelper(op, operands, &createIE);
shapeHelper.computeShapeAndAssertOnFailure();

Type outputType = *op->result_type_begin();
assert(isRankedShapedType(outputType) && "Expected Ranked ShapedType");

// Operands and attributes.
Value data = operandAdaptor.getData();
Value indices = operandAdaptor.getIndices();
int64_t axisLit = gatherOp.getAxis();

ShapedType inputType = data.getType().cast<ShapedType>();
int64_t rank = inputType.getRank(); // indices has the same rank
ShapedType indicesType = indices.getType().cast<ShapedType>();
Type indexElemType = indicesType.getElementType();
// Negative value means counting dimensions from the back.
axisLit = axisLit < 0 ? axisLit + rank : axisLit;

// make sure all index values >= 0
Value zero = getShapedZero(loc, rewriter, indices);
Value inputShape = rewriter.create<shape::ShapeOfOp>(loc, data);
Value indicesShape = rewriter.create<shape::ShapeOfOp>(loc, indices);
Value axisDimSize =
rewriter.create<shape::GetExtentOp>(loc, inputShape, axisLit);
axisDimSize =
rewriter.create<arith::IndexCastOp>(loc, indexElemType, axisDimSize);
axisDimSize = rewriter.create<tensor::FromElementsOp>(loc, axisDimSize);
axisDimSize = rewriter.create<mhlo::ReshapeOp>(loc,
RankedTensorType::get(SmallVector<int64_t>{}, indexElemType),
axisDimSize);
Value broadcastedAxisDimSize =
rewriter.create<mhlo::DynamicBroadcastInDimOp>(loc, indicesType,
axisDimSize, indicesShape, rewriter.getI64TensorAttr({}));
Value isNegative = rewriter.create<mhlo::CompareOp>(
loc, indices, zero, mhlo::ComparisonDirection::LT);
Value positiveIndices = rewriter.create<mhlo::AddOp>(
loc, indicesType, indices, broadcastedAxisDimSize);
indices = rewriter.create<mhlo::SelectOp>(
loc, indicesType, isNegative, positiveIndices, indices);

// start indices
Value toConcatIndexShape;
SmallVector<Value> toConcatIndexShapeValueVec;
for (size_t i = 0; i < rank; i++) {
toConcatIndexShapeValueVec.push_back(
rewriter.create<shape::GetExtentOp>(loc, indicesShape, i));
}
toConcatIndexShapeValueVec.push_back(
rewriter.create<arith::ConstantIndexOp>(loc, 1));
toConcatIndexShape = rewriter.create<tensor::FromElementsOp>(
loc, toConcatIndexShapeValueVec);

ArrayRef<int64_t> indicesShapeVec = indicesType.getShape();
SmallVector<int64_t> toConcatIndexShapeVec(
indicesShapeVec.begin(), indicesShapeVec.end());
toConcatIndexShapeVec.push_back(1);
RankedTensorType toConcatIndexType =
RankedTensorType::get(toConcatIndexShapeVec, indexElemType);

SmallVector<Value> toConcat;
for (int64_t i = 0; i < inputType.getRank(); ++i) {
if (i == axisLit) {
toConcat.push_back(rewriter.create<mhlo::DynamicReshapeOp>(
loc, toConcatIndexType, indices, toConcatIndexShape));
} else {
toConcat.push_back(
rewriter.create<mhlo::DynamicIotaOp>(loc, toConcatIndexType,
toConcatIndexShape, rewriter.getI64IntegerAttr(i)));
}
}
auto gatherIndicies = rewriter.create<mhlo::ConcatenateOp>(
loc, toConcat, static_cast<uint64_t>(inputType.getRank()));

// dimsAttr
SmallVector<int64_t> collapsedDims;
SmallVector<int64_t> startIndexMap;
for (int64_t i = 0; i < rank; i++) {
collapsedDims.push_back(i);
startIndexMap.push_back(i);
}
auto dimsAttr = mhlo::GatherDimensionNumbersAttr::get(rewriter.getContext(),
/*offsetDims=*/{},
/*collapsedSliceDims=*/collapsedDims,
/*startIndexMap=*/startIndexMap,
/*indexVecDim=*/rank);
SmallVector<int64_t> sliceSizes(inputType.getRank(), 1);

Value gatherValue = rewriter.create<mhlo::GatherOp>(loc, outputType, data,
gatherIndicies, dimsAttr, rewriter.getI64TensorAttr(sliceSizes));
rewriter.replaceOp(op, gatherValue);
return success();
}
};

} // namespace

void populateLoweringONNXGatherElementsOpToMhloPattern(
RewritePatternSet &patterns, MLIRContext *ctx) {
patterns.insert<ONNXGatherElementsOpLoweringToMhlo>(ctx);
}

} // namespace onnx_mlir
99 changes: 99 additions & 0 deletions src/Conversion/ONNXToMhlo/Tensor/Pad.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
/*
* SPDX-License-Identifier: Apache-2.0
*/

//===----------- Pad.cpp - Lowering Pad Op ------------===//
//
// Copyright 2022
//
// =============================================================================
//
// This file lowers ONNX Pad Operators to Mhlo dialect.
//
//===----------------------------------------------------------------------===//

#include "src/Conversion/ONNXToMhlo/ONNXToMhloCommon.hpp"
#include "src/Dialect/ONNX/ElementsAttr/DisposableElementsAttr.hpp"
#include "src/Dialect/ONNX/ONNXOps/ShapeHelper.hpp"
#include "src/Support/TypeUtilities.hpp"

using namespace mlir;

namespace onnx_mlir {

namespace {

struct ONNXPadOpLoweringToMhlo : public ConversionPattern {
ONNXPadOpLoweringToMhlo(MLIRContext *ctx)
: ConversionPattern(mlir::ONNXPadOp::getOperationName(), 1, ctx) {}

LogicalResult matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {

Location loc = op->getLoc();
ONNXPadOpAdaptor operandAdaptor(operands, op->getAttrDictionary());
Value data = operandAdaptor.getData();
Value constantValue = operandAdaptor.getConstantValue();
Value pads = operandAdaptor.getPads();
StringRef padMode = operandAdaptor.getMode();

if (!padMode.equals_insensitive("constant"))
return failure();
assert(isRankedShapedType(data.getType()) && "Expected Ranked ShapedType");
ShapedType inputType = data.getType().cast<ShapedType>();
Type elemType = inputType.getElementType();
int64_t rank = inputType.getRank();

Type outputType = *op->result_type_begin();
if (!constantValue || isNoneValue(constantValue)) {
// Pad with zeros by default
constantValue = rewriter.create<mhlo::ConstantOp>(
loc, DenseElementsAttr::get(mlir::RankedTensorType::get({}, elemType),
rewriter.getZeroAttr(elemType)));
} else {
// constantValue might be 1D tensor, reshape it to scalar
constantValue = rewriter.create<mhlo::ReshapeOp>(
loc, RankedTensorType::get({}, elemType), constantValue);
}
SmallVector<int64_t> edgePaddingLowVec(rank, 0);
SmallVector<int64_t> edgePaddingHighVec(rank, 0);
SmallVector<int64_t> interiorPaddingVec(rank, 0);
if (auto valueAttribute = getElementAttributeFromMhloValue(pads)) {
// If `pads` are constants, read them."
int64_t idx = 0;
for (IntegerAttr value : valueAttribute.getValues<IntegerAttr>()) {
int64_t padValue = value.getInt();
if (padValue < 0)
return failure();
if (idx < rank)
edgePaddingLowVec[idx] = padValue;
else
edgePaddingHighVec[idx - rank] = padValue;
idx++;
}
} else {
assert(false && "Pads must be known at compile time");
}

mlir::DenseIntElementsAttr edgePaddingLow =
rewriter.getI64VectorAttr(edgePaddingLowVec);
mlir::DenseIntElementsAttr edgePaddingHigh =
rewriter.getI64VectorAttr(edgePaddingHighVec);
mlir::DenseIntElementsAttr interiorPadding =
rewriter.getI64VectorAttr(interiorPaddingVec);
Value padResult = rewriter.create<mhlo::PadOp>(loc, outputType, data,
constantValue, edgePaddingLow, edgePaddingHigh, interiorPadding);

rewriter.replaceOp(op, padResult);
return success();
}
};

} // namespace

void populateLoweringONNXPadOpToMhloPattern(
RewritePatternSet &patterns, MLIRContext *ctx) {
patterns.insert<ONNXPadOpLoweringToMhlo>(ctx);
}

} // namespace onnx_mlir
18 changes: 18 additions & 0 deletions test/mlir/conversion/onnx_to_mhlo/Tensor/GatherElements.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
// RUN: onnx-mlir-opt --convert-onnx-to-mhlo %s --canonicalize -split-input-file | FileCheck %s

func.func @main_gather_elements(%arg0: tensor<3x2xf32>, %arg1: tensor<2x2xi64>) -> tensor<2x2xf32> {
%0 = "onnx.GatherElements"(%arg0, %arg1) {axis = 0 : si64} : (tensor<3x2xf32>, tensor<2x2xi64>) -> tensor<2x2xf32>
return %0 : tensor<2x2xf32>
// CHECK: func.func @main_gather_elements([[PARAM_0_:%.+]]: tensor<3x2xf32>, [[PARAM_1_:%.+]]: tensor<2x2xi64>) -> tensor<2x2xf32> {
// CHECK-DAG: [[VAR_0_:%.+]] = mhlo.constant dense<3> : tensor<2x2xi64>
// CHECK-DAG: [[VAR_1_:%.+]] = mhlo.constant dense<0> : tensor<2x2xi64>
// CHECK-DAG: [[VAR_2_:%.+]] = mhlo.compare LT, [[PARAM_1_]], [[VAR_1_]], NOTYPE : (tensor<2x2xi64>, tensor<2x2xi64>) -> tensor<2x2xi1>
// CHECK-DAG: [[VAR_3_:%.+]] = mhlo.add [[PARAM_1_]], [[VAR_0_]] : tensor<2x2xi64>
// CHECK-NEXT: [[VAR_4_:%.+]] = mhlo.select [[VAR_2_]], [[VAR_3_]], [[PARAM_1_]] : tensor<2x2xi1>, tensor<2x2xi64>
// CHECK-NEXT: [[VAR_5_:%.+]] = mhlo.reshape [[VAR_4_]] : (tensor<2x2xi64>) -> tensor<2x2x1xi64>
// CHECK-DAG: [[VAR_6_:%.+]] = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<2xi64>
// CHECK-NEXT: [[VAR_7_:%.+]] = "mhlo.broadcast_in_dim"([[VAR_6_]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<2xi64>) -> tensor<2x2x1xi64>
// CHECK-NEXT: [[VAR_8_:%.+]] = "mhlo.concatenate"([[VAR_5_]], [[VAR_7_]]) {dimension = 2 : i64} : (tensor<2x2x1xi64>, tensor<2x2x1xi64>) -> tensor<2x2x2xi64>
// CHECK-NEXT: [[VAR_9_:%.+]] = "mhlo.gather"([[PARAM_0_]], [[VAR_8_]]) {dimension_numbers = #mhlo.gather<collapsed_slice_dims = [0, 1], start_index_map = [0, 1], index_vector_dim = 2>, indices_are_sorted = false, slice_sizes = dense<1> : tensor<2xi64>} : (tensor<3x2xf32>, tensor<2x2x2xi64>) -> tensor<2x2xf32>
// CHECK-NEXT: return [[VAR_9_]] : tensor<2x2xf32>
}
13 changes: 13 additions & 0 deletions test/mlir/conversion/onnx_to_mhlo/Tensor/Pad.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
// RUN: onnx-mlir-opt --convert-onnx-to-mhlo %s --canonicalize -split-input-file | FileCheck %s

func.func @test_pad_constant(%arg0: tensor<1x3x5x5xf32>) -> tensor<1x3x7x7xf32> {
%0 = onnx.Constant dense<[0, 0, 1, 1, 0, 0, 1, 1]> : tensor<8xi64>
%1 = onnx.Constant dense<2.000000e+00> : tensor<f32>
%2 = "onnx.NoValue"() {value} : () -> none
%3 = "onnx.Pad"(%arg0, %0, %1, %2) {mode = "constant"} : (tensor<1x3x5x5xf32>, tensor<8xi64>, tensor<f32>, none) -> tensor<1x3x7x7xf32>
return %3 : tensor<1x3x7x7xf32>
// CHECK-LABEL: func.func @test_pad_constant(%arg0: tensor<1x3x5x5xf32>) -> tensor<1x3x7x7xf32> {
// CHECK-NEXT: %0 = mhlo.constant dense<2.000000e+00> : tensor<f32>
// CHECK-NEXT: %1 = "mhlo.pad"(%arg0, %0) {edge_padding_high = dense<[0, 0, 1, 1]> : vector<4xi64>, edge_padding_low = dense<[0, 0, 1, 1]> : vector<4xi64>, interior_padding = dense<0> : vector<4xi64>} : (tensor<1x3x5x5xf32>, tensor<f32>) -> tensor<1x3x7x7xf32>
// CHECK-NEXT: return %1 : tensor<1x3x7x7xf32>
}

0 comments on commit 18e93f8

Please sign in to comment.