Add support for opset 19 AveragePool with dilations #2495

Merged 2 commits on Sep 12, 2023.
2 changes: 1 addition & 1 deletion docs/SupportedONNXOps-cpu.md
@@ -26,7 +26,7 @@ Onnx-mlir currently supports ONNX operations targeting up to opset 19. Limitatio
| **Asinh** |9 - * | | |
| **Atan** |7 - * | | |
| **Atanh** |9 - * | | |
-| **AveragePool** |6 - 18 | | |
+| **AveragePool** |6 - * | | |
| **BatchNormalization** |6 - * |Training not supported. | |
| **Bernoulli** |none | | | |
| **Binarizer** |none | | | |
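The table row now reads `6 - *`: AveragePool is supported through the latest opset, including the opset-19 `dilations` attribute this PR adds. As a quick illustration, a model that exercises the new attribute can be built as follows (a minimal sketch using the standard `onnx` Python helpers, not code from this PR):

```python
# Minimal sketch: an opset-19 AveragePool node with dilations, built with the
# standard `onnx` Python helpers (illustrative only, not part of this PR).
import onnx
from onnx import TensorProto, helper

node = helper.make_node(
    "AveragePool",
    inputs=["x"],
    outputs=["y"],
    kernel_shape=[3, 3],
    dilations=[2, 2],  # new for AveragePool in opset 19
    strides=[1, 1],
)
graph = helper.make_graph(
    [node],
    "averagepool_dilations",
    [helper.make_tensor_value_info("x", TensorProto.FLOAT, [1, 1, 8, 8])],
    # A 3x3 kernel with dilation 2 spans (3 - 1) * 2 + 1 = 5 elements per axis,
    # so an 8x8 input with stride 1 and no padding yields a 4x4 output.
    [helper.make_tensor_value_info("y", TensorProto.FLOAT, [1, 1, 4, 4])],
)
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 19)])
onnx.checker.check_model(model)
```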
@@ -405,7 +405,7 @@ def GetI64ArrayAttrStridesAveragePool: NativeCodeCall<
"($0.getDefiningOp<ONNXAveragePoolOp>()))">;

def replaceONNXAveragePoolPattern : Pattern<
-(ONNXAveragePoolOp:$res $x, $_, $_, $_, $_, $_, $_),
+(ONNXAveragePoolOp:$res $x, $_, $_, $_, $_, $_, $_, $_),
[
// Get attributes using shape helper
(GetStrAttrPaddingtypeAveragePool:$padtype $res),
2 changes: 1 addition & 1 deletion src/Builder/OpBuildTable.inc
@@ -18,7 +18,7 @@ op_dialect_version_map_["Asin"] = {7};
op_dialect_version_map_["Asinh"] = {9};
op_dialect_version_map_["Atan"] = {7};
op_dialect_version_map_["Atanh"] = {9};
op_dialect_version_map_["AveragePool"] = {11};
op_dialect_version_map_["AveragePool"] = {19};
op_dialect_version_map_["BatchNormalization"] = {15};
op_dialect_version_map_["Bernoulli"] = {15};
op_dialect_version_map_["Binarizer"] = {1};
16 changes: 1 addition & 15 deletions src/Conversion/ONNXToKrnl/NN/Pooling.cpp
@@ -57,15 +57,8 @@ Value emitScalarOpFor<ONNXMaxPoolSingleOutOp>(
//
template <typename PoolOp>
std::vector<int64_t> getDilations(PoolOp poolOp) {
-return {};
-}
-
-// MaxPool has dilations attribute.
-template <>
-std::vector<int64_t> getDilations<ONNXMaxPoolSingleOutOp>(
-ONNXMaxPoolSingleOutOp poolOp) {
std::vector<int64_t> dilations;
-auto dilationsAttribute = poolOp.getDilationsAttr();
+ArrayAttr dilationsAttribute = poolOp.getDilationsAttr();
bool isDefaultDilations = true;
for (auto dilation : dilationsAttribute.getValue()) {
int64_t dilationValue = dilation.cast<IntegerAttr>().getInt();
@@ -84,13 +77,6 @@ std::vector<int64_t> getDilations<ONNXMaxPoolSingleOutOp>(
//
template <typename PoolOp>
std::optional<ArrayAttr> getDilationAttr(PoolOp poolOp) {
-return std::nullopt;
-}
-
-// MaxPool has dilations attribute.
-template <>
-std::optional<ArrayAttr> getDilationAttr<ONNXMaxPoolSingleOutOp>(
-ONNXMaxPoolSingleOutOp poolOp) {
return poolOp.getDilations();
}

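With the MaxPool-only specializations removed, the shared `getDilations`/`getDilationAttr` helpers now feed AveragePool dilations into the same lowering path. For intuition about the semantics the lowering has to realize: a dilated window samples input elements that are `dilation` apart, so the averaged values cover a wider region than the kernel size suggests. A rough NumPy reference follows (an illustrative sketch for a single 2-D channel with no padding, not the onnx-mlir code path):

```python
# Illustrative NumPy reference for 2-D average pooling with dilations
# (single channel, no padding; count_include_pad is irrelevant here).
import numpy as np

def avgpool2d_dilated(x, kernel_shape, strides, dilations):
    kh, kw = kernel_shape
    sh, sw = strides
    dh, dw = dilations
    h, w = x.shape
    # Effective window extent per axis: (k - 1) * d + 1.
    eh, ew = (kh - 1) * dh + 1, (kw - 1) * dw + 1
    oh, ow = (h - eh) // sh + 1, (w - ew) // sw + 1
    y = np.empty((oh, ow), dtype=x.dtype)
    for i in range(oh):
        for j in range(ow):
            # The strided slice picks kh x kw samples spaced dh/dw apart.
            window = x[i * sh : i * sh + eh : dh, j * sw : j * sw + ew : dw]
            y[i, j] = window.mean()
    return y

x = np.arange(64, dtype=np.float32).reshape(8, 8)
y = avgpool2d_dilated(x, kernel_shape=[3, 3], strides=[1, 1], dilations=[2, 2])
print(y.shape)  # (4, 4)
```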
7 changes: 2 additions & 5 deletions src/Dialect/ONNX/ONNXOps.td.inc
@@ -424,11 +424,7 @@ def ONNXAveragePoolOp:ONNX_Op<"AveragePool",
```
output_spatial_shape[i] = ceil((input_spatial_shape[i] + pad_shape[i] - ((kernel_spatial_shape[i] - 1) * dilations[i] + 1)) / strides_spatial_shape[i] + 1)
```
-if ceil_mode is enabled
-
-```
-* pad_shape[i] is sum of pads along axis i
-```
+if ceil_mode is enabled `pad_shape[i]` is the sum of pads along axis `i`.

`auto_pad` is a DEPRECATED attribute. If you are using them currently, the output spatial shape will be following:
```
@@ -446,6 +442,7 @@ def ONNXAveragePoolOp:ONNX_Op<"AveragePool",
DefaultValuedStrAttr<StrAttr, "NOTSET">:$auto_pad,
DefaultValuedAttr<SI64Attr, "0">:$ceil_mode,
DefaultValuedAttr<SI64Attr, "0">:$count_include_pad,
+OptionalAttr<I64ArrayAttr>:$dilations,
I64ArrayAttr:$kernel_shape,
OptionalAttr<I64ArrayAttr>:$pads,
OptionalAttr<I64ArrayAttr>:$strides);
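A quick worked check of the output-shape formula quoted in the op description above, with made-up sizes (per the ONNX spec, the same expression uses `floor` when `ceil_mode` is disabled):

```python
# Worked check of the AveragePool output-shape formula quoted above
# (hypothetical sizes; pad_sum is the sum of pads along the axis).
import math

def out_dim(input_dim, pad_sum, kernel, dilation, stride, ceil_mode):
    effective_kernel = (kernel - 1) * dilation + 1
    v = (input_dim + pad_sum - effective_kernel) / stride + 1
    return math.ceil(v) if ceil_mode else math.floor(v)

print(out_dim(32, 2, 3, 2, 2, ceil_mode=True))   # ceil(29 / 2 + 1) = 16
print(out_dim(8, 0, 3, 2, 1, ceil_mode=False))   # floor(3 / 1 + 1) = 4
```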
6 changes: 4 additions & 2 deletions src/Dialect/ONNX/ONNXOps/NN/Pooling.cpp
@@ -86,8 +86,8 @@ LogicalResult ONNXAveragePoolOpShapeHelper::computeShape() {
ONNXAveragePoolOp poolOp = llvm::cast<ONNXAveragePoolOp>(op);
return customComputeShape(operandAdaptor.getX(), /*W*/ nullptr,
poolOp.getKernelShape(), poolOp.getAutoPad(), poolOp.getPads(),
-poolOp.getStrides(),
-/*dilation*/ std::nullopt, /*hasFilter*/ false, poolOp.getCeilMode());
+poolOp.getStrides(), poolOp.getDilations(), /*hasFilter*/ false,
+poolOp.getCeilMode());
}

} // namespace onnx_mlir
@@ -117,6 +117,8 @@ LogicalResult ONNXAveragePoolOp::verify() {
return failure();
if (failed(verifyStrides<ONNXAveragePoolOp>(this, spatialRank)))
return failure();
+if (failed(verifyDilations<ONNXAveragePoolOp>(this, spatialRank)))
+return failure();
if (failed(verifyPadding<ONNXAveragePoolOp>(this, spatialRank)))
return failure();
return success();
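The added `verifyDilations` call follows the same pattern as the `kernel_shape` and `strides` checks: when the optional attribute is present it must supply one value per spatial axis, and each value must be positive. A hypothetical Python rendering of that rule (names and return convention are illustrative, not the actual C++ helper):

```python
# Hypothetical sketch of the dilations verification rule; names and behavior
# are illustrative, not the onnx-mlir C++ implementation.
def dilations_are_valid(dilations, spatial_rank):
    if dilations is None:                  # optional attribute: treated as all ones
        return True
    if len(dilations) != spatial_rank:     # one dilation per spatial axis
        return False
    return all(d >= 1 for d in dilations)  # each dilation must be positive

assert dilations_are_valid(None, 2)
assert dilations_are_valid([2, 2], 2)
assert not dilations_are_valid([2], 2)
assert not dilations_are_valid([0, 1], 2)
```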
1 change: 0 additions & 1 deletion test/backend/inference_backend.py
@@ -148,7 +148,6 @@ def get_test_models():

# ==OP== AveragePool
# ==MIN== 1
-# ==UNSUPPORTED== 19
# TODO: original comment stated "same_upper/lower with dynamic padding-shapes not supported."
# However, I see the dyn shape test being done on all tests, including same_upper. So I am
# assuming that this comment is outdated.
2 changes: 1 addition & 1 deletion utils/gen_onnx_mlir.py
@@ -87,7 +87,7 @@
'Asinh': [9],
'Atan': [7],
'Atanh': [9],
-'AveragePool': [11],
+'AveragePool': [19],
'BatchNormalization': [15],
'Bernoulli': [15],
'Binarizer': [1],