add scalarization patterns to support dynamic pytorch pad exports (#3838)

1. Adds case handling for `aten.slice.Tensor` shape inference with negative strides. This is not technically allowed by native PyTorch, but it is useful for ONNX ingest; we were getting incorrect shapes for these negative-strided slice ops (see the illustration after this list).
2. Adds scalarization support for ops seen in PyTorch pad exports to ONNX. These are typically `aten.view`, `aten.transpose.int`, and `aten.slice.Tensor` with negative strides (and rank 2).
3. Allows a view op's `self` operand to be added to the worklist conditionally, based on whether the view op actually occurs as an intermediate point in a shape computation.
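
A hypothetical illustration of item 1 (not taken from the commit): native PyTorch rejects negative slice steps, so the reference result below is built with `flip` plus a forward slice, but the ONNX importer emits `aten.slice.Tensor` with a negative step directly, and the shape function must infer the same result shape for it.

```python
import torch

x = torch.zeros(2, 5)

# Conceptually, aten.slice.Tensor(x, dim=1, start=-1, end=<very negative>, step=-2)
# walks dim 1 backwards from the last element (indices 4, 2, 0). Native PyTorch
# can only express this with flip + a forward slice:
expected = x.flip(1)[:, ::2]
print(expected.shape)  # torch.Size([2, 3]) -- the shape the fixed shape function should infer
```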
zjgarvey authored Nov 1, 2024
1 parent 39d69db commit 738d45d
Showing 4 changed files with 568 additions and 27 deletions.
62 changes: 60 additions & 2 deletions lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
@@ -10131,8 +10131,66 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" return %2 : !torch.tuple<list<int>, list<int>>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.slice.Tensor\"(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.optional<int>, %arg3: !torch.optional<int>, %arg4: !torch.int) -> !torch.list<int> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.slice(%arg0, %arg1, %arg2, %arg3, %arg4) : (!torch.list<int>, !torch.int, !torch.optional<int>, !torch.optional<int>, !torch.int) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
" %int-1 = torch.constant.int -1\n"
" %none = torch.constant.none\n"
" %int0 = torch.constant.int 0\n"
" %int1 = torch.constant.int 1\n"
" %0 = torch.aten.__isnot__ %arg2, %none : !torch.optional<int>, !torch.none -> !torch.bool\n"
" %1 = torch.prim.If %0 -> (!torch.int) {\n"
" %9 = torch.prim.unchecked_cast %arg2 : !torch.optional<int> -> !torch.int\n"
" torch.prim.If.yield %9 : !torch.int\n"
" } else {\n"
" torch.prim.If.yield %int0 : !torch.int\n"
" }\n"
" %2 = torch.aten.__isnot__ %arg3, %none : !torch.optional<int>, !torch.none -> !torch.bool\n"
" %3 = torch.prim.If %2 -> (!torch.int) {\n"
" %9 = torch.prim.unchecked_cast %arg3 : !torch.optional<int> -> !torch.int\n"
" torch.prim.If.yield %9 : !torch.int\n"
" } else {\n"
" %9 = func.call @__torch__.torch.jit._shape_functions.max_int() : () -> !torch.int\n"
" torch.prim.If.yield %9 : !torch.int\n"
" }\n"
" %4 = torch.aten.lt.int %arg4, %int0 : !torch.int, !torch.int -> !torch.bool\n"
" %5:3 = torch.prim.If %4 -> (!torch.int, !torch.int, !torch.int) {\n"
" %9 = torch.aten.lt.int %1, %int0 : !torch.int, !torch.int -> !torch.bool\n"
" %10 = torch.prim.If %9 -> (!torch.int) {\n"
" %20 = torch.aten.__getitem__.t %arg0, %arg1 : !torch.list<int>, !torch.int -> !torch.int\n"
" %21 = torch.aten.add.int %1, %20 : !torch.int, !torch.int -> !torch.int\n"
" torch.prim.If.yield %21 : !torch.int\n"
" } else {\n"
" torch.prim.If.yield %1 : !torch.int\n"
" }\n"
" %11 = torch.aten.lt.int %10, %int0 : !torch.int, !torch.int -> !torch.bool\n"
" %12 = torch.prim.If %11 -> (!torch.int) {\n"
" torch.prim.If.yield %int0 : !torch.int\n"
" } else {\n"
" torch.prim.If.yield %10 : !torch.int\n"
" }\n"
" %13 = torch.aten.lt.int %3, %int0 : !torch.int, !torch.int -> !torch.bool\n"
" %14 = torch.prim.If %13 -> (!torch.int) {\n"
" %20 = torch.aten.__getitem__.t %arg0, %arg1 : !torch.list<int>, !torch.int -> !torch.int\n"
" %21 = torch.aten.add.int %3, %20 : !torch.int, !torch.int -> !torch.int\n"
" torch.prim.If.yield %21 : !torch.int\n"
" } else {\n"
" torch.prim.If.yield %3 : !torch.int\n"
" }\n"
" %15 = torch.aten.lt.int %14, %int0 : !torch.int, !torch.int -> !torch.bool\n"
" %16 = torch.prim.If %15 -> (!torch.int) {\n"
" torch.prim.If.yield %int-1 : !torch.int\n"
" } else {\n"
" torch.prim.If.yield %14 : !torch.int\n"
" }\n"
" %17 = torch.aten.add.int %16, %int1 : !torch.int, !torch.int -> !torch.int\n"
" %18 = torch.aten.add.int %12, %int1 : !torch.int, !torch.int -> !torch.int\n"
" %19 = torch.aten.neg.int %arg4 : !torch.int -> !torch.int\n"
" torch.prim.If.yield %17, %18, %19 : !torch.int, !torch.int, !torch.int\n"
" } else {\n"
" torch.prim.If.yield %1, %3, %arg4 : !torch.int, !torch.int, !torch.int\n"
" }\n"
" %6 = torch.derefine %5#0 : !torch.int to !torch.optional<int>\n"
" %7 = torch.derefine %5#1 : !torch.int to !torch.optional<int>\n"
" %8 = call @__torch__.torch.jit._shape_functions.slice(%arg0, %arg1, %6, %7, %5#2) : (!torch.list<int>, !torch.int, !torch.optional<int>, !torch.optional<int>, !torch.int) -> !torch.list<int>\n"
" return %8 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.as_strided\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.optional<int>) -> !torch.list<int> {\n"
" return %arg1 : !torch.list<int>\n"
