Skip to content

Commit

Permalink
Merge branch 'main' into pr_support_dynamic_indices_in_gather_scatter_nd
Browse files Browse the repository at this point in the history
  • Loading branch information
negiyas committed Sep 14, 2023
2 parents e174c05 + a54a48c commit cc4b350
Show file tree
Hide file tree
Showing 19 changed files with 647 additions and 74 deletions.
2 changes: 1 addition & 1 deletion docs/SupportedONNXOps-cpu.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ Onnx-mlir currently supports ONNX operations targeting up to opset 19. Limitatio
| **Asinh** |9 - * | | |
| **Atan** |7 - * | | |
| **Atanh** |9 - * | | |
| **AveragePool** |6 - 18 | | |
| **AveragePool** |6 - * | | |
| **BatchNormalization** |6 - * |Training not supported. | |
| **Bernoulli** |none | | | |
| **Binarizer** |none | | | |
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -405,7 +405,7 @@ def GetI64ArrayAttrStridesAveragePool: NativeCodeCall<
"($0.getDefiningOp<ONNXAveragePoolOp>()))">;

def replaceONNXAveragePoolPattern : Pattern<
(ONNXAveragePoolOp:$res $x, $_, $_, $_, $_, $_, $_),
(ONNXAveragePoolOp:$res $x, $_, $_, $_, $_, $_, $_, $_),
[
// Get attributes using shape helper
(GetStrAttrPaddingtypeAveragePool:$padtype $res),
Expand Down
11 changes: 11 additions & 0 deletions src/Accelerators/NNPA/Transform/ZLow/ZLowRewrite.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,12 @@ class StickViewUnstickRemovalPattern : public OpRewritePattern<ZLowStickOp> {
ZLowStickOp stickOp, PatternRewriter &rewriter) const override {
Value stickInput = stickOp.getX();

// Do not handle NCHW layout stickification that transposes data
// internally.
std::string stickLayout = stickOp.getLayout().value().str();
if (stickLayout == LAYOUT_NCHW)
return failure();

// Input is a block argument, ignore it.
if (stickInput.dyn_cast<BlockArgument>())
return failure();
Expand All @@ -157,6 +163,11 @@ class StickViewUnstickRemovalPattern : public OpRewritePattern<ZLowStickOp> {
ZLowUnstickOp userOp = llvm::dyn_cast<ZLowUnstickOp>(user);
if (!userOp)
continue;
// Do not handle NCHW layout stickification that transposes data
// internally.
std::string unstickLayout = userOp.getLayout().value().str();
if (unstickLayout == LAYOUT_NCHW)
continue;
// UnstickOp must be before the view operation.
if (userOp.getOut() == viewSource &&
user->isBeforeInBlock(viewOp.getOperation())) {
Expand Down
5 changes: 4 additions & 1 deletion src/Builder/FrontendDialectTransformer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1235,7 +1235,10 @@ class FrontendGenImpl {
}

for (auto &name : functionProto.output()) {
outputs.push_back(LookupOnnxName(name));
// Skip missing optional outputs: they are not mapped.
if (const Value *valuePtr = frontend_symbols_.GetByOnnxName(name)) {
outputs.push_back(*valuePtr);
}
}

frontend_symbols_.popScope(scopeName);
Expand Down
4 changes: 2 additions & 2 deletions src/Builder/OpBuildTable.inc
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ op_dialect_version_map_["Asin"] = {7};
op_dialect_version_map_["Asinh"] = {9};
op_dialect_version_map_["Atan"] = {7};
op_dialect_version_map_["Atanh"] = {9};
op_dialect_version_map_["AveragePool"] = {11};
op_dialect_version_map_["AveragePool"] = {19};
op_dialect_version_map_["BatchNormalization"] = {15};
op_dialect_version_map_["Bernoulli"] = {15};
op_dialect_version_map_["Binarizer"] = {1};
Expand Down Expand Up @@ -157,7 +157,7 @@ op_dialect_version_map_["ReduceSum"] = {13, 11};
op_dialect_version_map_["ReduceSumSquare"] = {18, 13};
op_dialect_version_map_["Relu"] = {14};
op_dialect_version_map_["Reshape"] = {19};
op_dialect_version_map_["Resize"] = {18, 13, 11, 10};
op_dialect_version_map_["Resize"] = {19, 13, 11, 10};
op_dialect_version_map_["ReverseSequence"] = {10};
op_dialect_version_map_["RoiAlign"] = {16};
op_dialect_version_map_["Round"] = {11};
Expand Down
58 changes: 54 additions & 4 deletions src/Conversion/ONNXToKrnl/Math/Elementwise.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1119,7 +1119,7 @@ Value emitScalarOpFor<ONNXModOp>(ConversionPatternRewriter &rewriter,
CheckIfCustomScalarOpIsSupported<ONNXModOp>(elementType);
Value dividend = scalarOperands[0];
Value divisor = scalarOperands[1];
MultiDialectBuilder<MathBuilder> create(rewriter, loc);
MultiDialectBuilder<MathBuilder, KrnlBuilder> create(rewriter, loc);

// TODO: here we assume fmod=1; what should happen if that is not the case?
if (create.math.isFloatWithVector(elementType)) {
Expand All @@ -1136,9 +1136,59 @@ Value emitScalarOpFor<ONNXModOp>(ConversionPatternRewriter &rewriter,
#endif
}
if (create.math.isIntegerWithVector(elementType)) {
// TODO: implement
llvm_unreachable("not support integers at this moment since MLIR integers "
"are signless.");
// "math.rem" returns a negative (or zero) value for a negative dividend and
// a positive (or zero) value for a positive dividend. We call math.rem's
// return value "mathRemainder". However, onnx.Mod should return a negative
// (or zero) value for a negative divisor and a positive (or zero) value for
// a positive divisor. We call the value that onnx.Mod should return
// "onnxMod". The following table shows mathRemainder, onnxMod and their
// difference (= onnxMod - mathRemainder) for some inputs.
//
// dividend | 7 | 7 | -7 | -7 | 6 | 6 | -6 | -6 |
// divisor | 3 | -3 | 3 | -3 | 3 | -3 | 3 | -3 |
// ------------------------+-----+----+----+----+----+----+----+----+
// mathRemainder | 1 | 1 | -1 | -1 | 0 | 0 | 0 | 0 |
// onnxMod | 1 | -2 | 2 | -1 | 0 | 0 | 0 | 0 |
// onnxMod - mathRemainder | 0 | -3 | 3 | 0 | 0 | 0 | 0 | 0 |
//
// The following code shows the logic to get onnxMod from mathRemainder:
//
// int dividend, divisor;
// int mathRemainder = dividend % divisor;
// int adjustedRemainder = mathRemainder + divisor;
//
// // Note: "^" below denotes "exclusive or".
// if ((mathRemainder != 0) && ((dividend < 0) ^ (divisor < 0)))
// return adjustedRemainder;
// return mathRemainder;

Value mathRemainder = create.math.rem(dividend, divisor);
Value adjustedRemainder = create.math.add(mathRemainder, divisor);
Value zero = create.math.constant(elementType, 0);
Value falseVal = create.math.constant(rewriter.getI1Type(), 0);
Value isMathRemainderNonZero =
create.math.eq(create.math.eq(mathRemainder, zero), falseVal);
Value isDividendMinus = create.math.slt(dividend, zero);
Value isDivisorMinus = create.math.slt(divisor, zero);
Value exclusiveOrOfIsDividendMinusAndIsDivisorMinus = create.math.eq(
create.math.eq(isDividendMinus, isDivisorMinus), falseVal);
Value needAdjust = create.math.andi(
isMathRemainderNonZero, exclusiveOrOfIsDividendMinusAndIsDivisorMinus);
Value answer =
create.math.select(needAdjust, adjustedRemainder, mathRemainder);

#ifdef DEBUG_ONNX_MOD
create.krnl.printf("XXXX emitScalarOpFor<ONNXModOp>: diviend=", dividend,
dividend.getType());
create.krnl.printf(", divisor=", divisor, divisor.getType());
create.krnl.printf(
", mathReminder=", mathRemainder, mathRemainder.getType());
create.krnl.printf(
", adjustedReminder=", adjustedRemainder, adjustedRemainder.getType());
create.krnl.printf(", Answer=", answer, answer.getType());
create.krnl.printf("\n");
#endif

return answer;
}
llvm_unreachable("unsupported element type");
}
Expand Down
16 changes: 1 addition & 15 deletions src/Conversion/ONNXToKrnl/NN/Pooling.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,15 +57,8 @@ Value emitScalarOpFor<ONNXMaxPoolSingleOutOp>(
//
template <typename PoolOp>
std::vector<int64_t> getDilations(PoolOp poolOp) {
return {};
}

// MaxPool has dilations attribute.
template <>
std::vector<int64_t> getDilations<ONNXMaxPoolSingleOutOp>(
ONNXMaxPoolSingleOutOp poolOp) {
std::vector<int64_t> dilations;
auto dilationsAttribute = poolOp.getDilationsAttr();
ArrayAttr dilationsAttribute = poolOp.getDilationsAttr();
bool isDefaultDilations = true;
for (auto dilation : dilationsAttribute.getValue()) {
int64_t dilationValue = dilation.cast<IntegerAttr>().getInt();
Expand All @@ -84,13 +77,6 @@ std::vector<int64_t> getDilations<ONNXMaxPoolSingleOutOp>(
//
template <typename PoolOp>
std::optional<ArrayAttr> getDilationAttr(PoolOp poolOp) {
return std::nullopt;
}

// MaxPool has dilations attribute.
template <>
std::optional<ArrayAttr> getDilationAttr<ONNXMaxPoolSingleOutOp>(
ONNXMaxPoolSingleOutOp poolOp) {
return poolOp.getDilations();
}

Expand Down
16 changes: 12 additions & 4 deletions src/Dialect/ONNX/ElementsAttr/ElementsAttrBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -629,21 +629,29 @@ ElementsAttr ElementsAttrBuilder::gather(
ArrayRef<int64_t> inputShape = inputType.getShape();
assert(axis < inputShape.size() && "gather axis out of range");
auto postAxisShape = inputShape.drop_front(axis + 1);
ArrayRef<int64_t> indicesShape = indices.getShapedType().getShape();
ShapedType indicesType = indices.getShapedType();
assert(indicesType.getElementType().isSignlessInteger() &&
"gather indices must be i32 or i64");
ArrayRef<int64_t> indicesShape = indicesType.getShape();
SmallVector<int64_t> outShape(inputShape.take_front(axis));
outShape.append(indicesShape.begin(), indicesShape.end());
outShape.append(postAxisShape.begin(), postAxisShape.end());
auto outType = inputType.clone(outShape);
return fromWideNums(outType, [&](MutableArrayRef<WideNum> dst) {
size_t postAxisNumElements = ShapedType::getNumElements(postAxisShape);
ArrayBuffer<WideNum> src = getElementsWideNums(input);
ArrayBuffer<int64_t> indicesArray = getElementsArray<int64_t>(indices);
// Convert indices of any signed int element type to int64 by
// first promoting to WideNum and then casting to int64.
// In practice we support both int32 and int64 in this way.
ArrayBuffer<WideNum> indicesWideNums = getElementsWideNums(indices);
ArrayRef<int64_t> indicesArray =
castArrayRef<int64_t, WideNum>(indicesWideNums.get());
size_t axisInputSize = inputShape[axis];
size_t inputBlockLen = axisInputSize * postAxisNumElements;
size_t outBlockLen = indicesArray.get().size() * postAxisNumElements;
size_t outBlockLen = indicesArray.size() * postAxisNumElements;
size_t start = 0;
WideNum *out = dst.begin();
for (int64_t idx : indicesArray.get()) {
for (int64_t idx : indicesArray) {
int64_t adjustedIdx = idx < 0 ? idx + axisInputSize : idx;
const WideNum *in = src.get().begin() + adjustedIdx * postAxisNumElements;
for (size_t offset = start; offset < dst.size(); offset += outBlockLen) {
Expand Down
Loading

0 comments on commit cc4b350

Please sign in to comment.