[ONNX] Fix bug in ONNXToTorch PadOp's pads tensor rearrangement (llvm#3485)

Fix the pads tensor rearrangement so that the representation changes from
[x1_begin, x2_begin, ..., x1_end, x2_end, ...] to
[xn_begin, xn_end, ..., x2_begin, x2_end, x1_begin, x1_end], where
x1, x2, ..., xn index the dimensions being padded.
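For a concrete picture of the mapping, here is a minimal Python sketch of the same reindexing (the helper name is invented for illustration and is not part of the patch):

    # ONNX:  [x1_begin, x2_begin, ..., xn_begin, x1_end, x2_end, ..., xn_end]
    # torch: [xn_begin, xn_end, ..., x2_begin, x2_end, x1_begin, x1_end]
    def onnx_to_torch_pads(pads):
        n = len(pads) // 2
        out = []
        for i in range(n - 1, -1, -1):  # walk the padded dimensions last-to-first
            out.append(pads[i])         # xi_begin
            out.append(pads[n + i])     # xi_end
        return out

    # Rank-2 example: pads = [x1_begin, x2_begin, x1_end, x2_end]
    assert onnx_to_torch_pads([1, 2, 3, 4]) == [2, 4, 1, 3]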

---------

Co-authored-by: zjgarvey <[email protected]>
Co-authored-by: zjgarvey <[email protected]>
3 people authored Jul 3, 2024
1 parent ca0e906 commit 0fe7484
Showing 4 changed files with 11 additions and 8 deletions.
11 changes: 7 additions & 4 deletions lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp
@@ -2315,12 +2315,15 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
       }

       // The torch.pad op expects a different arrangement of padding pairs for
-      // each dimension as compared to the onnx.pad op. So, rearranging pad
-      // tensor to satisfy torch.pad op semantics.
+      // each dimension as compared to the onnx.pad op. Rearrange the pad
+      // tensor as shown below:
+      //
+      // [x1_begin, x2_begin, ..., x1_end, x2_end,...] ->
+      // [xn_begin, xn_end, ...., x2_begin, x2_end, x1_begin, x1_end]
       SmallVector<Value> padsRearrange;
-      for (uint32_t i = 0; i < padsSize / 2; i++) {
+      for (uint32_t i = padsSize - 1; i >= padsSize / 2; i--) {
+        padsRearrange.emplace_back(padsTensorValue[i - padsSize / 2]);
         padsRearrange.emplace_back(padsTensorValue[i]);
-        padsRearrange.emplace_back(padsTensorValue[(padsSize / 2) + i]);
       }

       Value padsSizeList =
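The reversed loop walks the padded dimensions from last to first, emitting each (begin, end) pair, which is the order torch.aten.pad (like torch.nn.functional.pad) expects. A quick sanity check against PyTorch; the pad amounts here are assumptions chosen to match the [3,4] -> [5,4] shapes in the regression test below, not values taken from the patch:

    import torch
    import torch.nn.functional as F

    x = torch.zeros(3, 4)
    # ONNX order for rank 2: [x1_begin, x2_begin, x1_end, x2_end]
    onnx_pads = [1, 0, 1, 0]
    # torch order: [x2_begin, x2_end, x1_begin, x1_end]
    torch_pads = [0, 0, 1, 1]
    assert F.pad(x, torch_pads).shape == (5, 4)  # dim 0 grew by 1 on each side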
4 changes: 3 additions & 1 deletion lib/Dialect/Torch/IR/TorchOps.cpp
@@ -3664,7 +3664,9 @@ OpFoldResult AtenSliceTensorOp::fold(FoldAdaptor adaptor) {
     return DenseElementsAttr::get(outType.toBuiltinTensor(), values);
   }

-  // If the input and output shapes are the same we can just fold:
+  // If the input and output shapes are the same & step == 1 we can fold:
+  if (!step || step.getValue().getSExtValue() != 1)
+    return nullptr;
   for (size_t i = 0; i < inType.getSizes().size(); ++i) {
     if (inType.getSizes()[i] != outType.getSizes()[i])
       return nullptr;
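The new guard bails out of the fold unless step is a statically known 1. The shape comparison below it only sees static sizes; a plausible failure mode (an assumption on my part, the commit does not spell it out) is a strided slice of a dynamically sized tensor whose input and output both report the same unknown size even though elements are dropped. In any case, a strided slice can share a shape with a step-1 slice while holding different contents, as this small PyTorch illustration with invented values shows:

    import torch

    x = torch.arange(6)
    identity = x[0:6:1]   # step == 1 over the full range: a true identity
    strided = x[0:6:2]    # step == 2: same start/end, only every other element
    assert torch.equal(identity, x)
    assert strided.shape == x[0:3:1].shape     # shapes can still coincide...
    assert not torch.equal(strided, x[0:3:1])  # ...while contents differ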
2 changes: 0 additions & 2 deletions projects/pt1/e2e_testing/xfail_sets.py
@@ -2216,8 +2216,6 @@
     "ElementwiseLog2IntModule_basic",
     "ElementwiseFminModule_basic",
     "ElementwiseFmaxModule_basic",
-    "FlipModuleStaticShape_basic",
-    "FlipNegativeIndexModule_basic",
     "PixelShuffleModuleStaticRank4Float32_basic",
     "ReflectionPad1dModule2dInput_Right",
     "ReflectionPad1dModule2dInput_basic",
2 changes: 1 addition & 1 deletion test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir
@@ -854,7 +854,7 @@ func.func @test_pad(%arg0: !torch.vtensor<[3,4],f32>, %arg1: !torch.vtensor<[4],
 // CHECK: %[[INT3:.+]] = torch.constant.int 3
 // CHECK: %[[SELECT_3:.+]] = torch.aten.select.int %arg1, %[[INT0]], %[[INT3]] : !torch.vtensor<[4],si64>, !torch.int, !torch.int -> !torch.vtensor<[],si64>
 // CHECK: %[[ITEM_3:.+]] = torch.aten.item %[[SELECT_3]] : !torch.vtensor<[],si64> -> !torch.int
-// CHECK: %[[LIST:.+]] = torch.prim.ListConstruct %[[ITEM_0]], %[[ITEM_2]], %[[ITEM_1]], %[[ITEM_3]] : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+// CHECK: %[[LIST:.+]] = torch.prim.ListConstruct %[[ITEM_1]], %[[ITEM_3]], %[[ITEM_0]], %[[ITEM_2]] : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
 // CHECK: %[[STR:.+]] = torch.constant.str "constant"
 // CHECK: %[[PAD:.+]] = torch.aten.pad %arg0, %[[LIST]], %[[STR]], %[[VAL]] : !torch.vtensor<[3,4],f32>, !torch.list<int>, !torch.str, !torch.float -> !torch.vtensor<[5,4],f32>
 // CHECK: return %[[PAD]] : !torch.vtensor<[5,4],f32>
