diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp index 14aa41bef349..6a5b06683fbc 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp @@ -356,9 +356,16 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( cstKernel.push_back(rewriter.create( binder.getLoc(), rewriter.getI64IntegerAttr(i))); } - for (int64_t i : padding) { + // Onnx pads format: [x1_begin, x2_begin…x1_end, x2_end,…] + // Pytorch pads format: [x1, x2,...] or [x], assume begin==end for all axes x. + int64_t paddingSizeHalf = padding.size()/2; + for (int64_t i = 0; i < paddingSizeHalf; ++i) { + // Check if onnx padding attribute is symmetric. + if(padding[i] != padding[i + paddingSizeHalf]) + return rewriter.notifyMatchFailure( + binder.op, "onnx padding attribute is not symmetric"); cstPadding.push_back(rewriter.create( - binder.getLoc(), rewriter.getI64IntegerAttr(i))); + binder.getLoc(), rewriter.getI64IntegerAttr(padding[i]))); } for (int64_t i : strides) { cstStrides.push_back(rewriter.create(