From 4fbcef78199634474b1ab9d6ebc86cb4946bcbac Mon Sep 17 00:00:00 2001 From: AmosLewis Date: Thu, 25 Apr 2024 19:56:23 +0000 Subject: [PATCH] [ONNX] Fix padding attributes for onnx.AveragePool --- lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp index 14aa41bef349..6a5b06683fbc 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp @@ -356,9 +356,16 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( cstKernel.push_back(rewriter.create( binder.getLoc(), rewriter.getI64IntegerAttr(i))); } - for (int64_t i : padding) { + // Onnx pads format: [x1_begin, x2_begin…x1_end, x2_end,…] + // Pytorch pads format: [x1, x2,...] or [x], assume begin==end for all axes x. + int64_t paddingSizeHalf = padding.size()/2; + for (int64_t i = 0; i < paddingSizeHalf; ++i) { + // Check if onnx padding attribute is symmetric. + if(padding[i] != padding[i + paddingSizeHalf]) + return rewriter.notifyMatchFailure( + binder.op, "onnx padding attribute is not symmetric"); cstPadding.push_back(rewriter.create( - binder.getLoc(), rewriter.getI64IntegerAttr(i))); + binder.getLoc(), rewriter.getI64IntegerAttr(padding[i]))); } for (int64_t i : strides) { cstStrides.push_back(rewriter.create(