diff --git a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp index 233c6be7e5bf..ead29d59a59e 100644 --- a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp +++ b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp @@ -10131,8 +10131,66 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " return %2 : !torch.tuple, list>\n" " }\n" " func.func @\"__torch_mlir_shape_fn.aten.slice.Tensor\"(%arg0: !torch.list, %arg1: !torch.int, %arg2: !torch.optional, %arg3: !torch.optional, %arg4: !torch.int) -> !torch.list {\n" -" %0 = call @__torch__.torch.jit._shape_functions.slice(%arg0, %arg1, %arg2, %arg3, %arg4) : (!torch.list, !torch.int, !torch.optional, !torch.optional, !torch.int) -> !torch.list\n" -" return %0 : !torch.list\n" +" %int-1 = torch.constant.int -1\n" +" %none = torch.constant.none\n" +" %int0 = torch.constant.int 0\n" +" %int1 = torch.constant.int 1\n" +" %0 = torch.aten.__isnot__ %arg2, %none : !torch.optional, !torch.none -> !torch.bool\n" +" %1 = torch.prim.If %0 -> (!torch.int) {\n" +" %9 = torch.prim.unchecked_cast %arg2 : !torch.optional -> !torch.int\n" +" torch.prim.If.yield %9 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %int0 : !torch.int\n" +" }\n" +" %2 = torch.aten.__isnot__ %arg3, %none : !torch.optional, !torch.none -> !torch.bool\n" +" %3 = torch.prim.If %2 -> (!torch.int) {\n" +" %9 = torch.prim.unchecked_cast %arg3 : !torch.optional -> !torch.int\n" +" torch.prim.If.yield %9 : !torch.int\n" +" } else {\n" +" %9 = func.call @__torch__.torch.jit._shape_functions.max_int() : () -> !torch.int\n" +" torch.prim.If.yield %9 : !torch.int\n" +" }\n" +" %4 = torch.aten.lt.int %arg4, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" %5:3 = torch.prim.If %4 -> (!torch.int, !torch.int, !torch.int) {\n" +" %9 = torch.aten.lt.int %1, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" %10 = torch.prim.If %9 -> (!torch.int) {\n" +" %20 = torch.aten.__getitem__.t %arg0, %arg1 : !torch.list, !torch.int -> !torch.int\n" +" %21 = torch.aten.add.int %1, %20 : !torch.int, !torch.int -> !torch.int\n" +" torch.prim.If.yield %21 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %1 : !torch.int\n" +" }\n" +" %11 = torch.aten.lt.int %10, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" %12 = torch.prim.If %11 -> (!torch.int) {\n" +" torch.prim.If.yield %int0 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %10 : !torch.int\n" +" }\n" +" %13 = torch.aten.lt.int %3, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" %14 = torch.prim.If %13 -> (!torch.int) {\n" +" %20 = torch.aten.__getitem__.t %arg0, %arg1 : !torch.list, !torch.int -> !torch.int\n" +" %21 = torch.aten.add.int %3, %20 : !torch.int, !torch.int -> !torch.int\n" +" torch.prim.If.yield %21 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %3 : !torch.int\n" +" }\n" +" %15 = torch.aten.lt.int %14, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" %16 = torch.prim.If %15 -> (!torch.int) {\n" +" torch.prim.If.yield %int-1 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %14 : !torch.int\n" +" }\n" +" %17 = torch.aten.add.int %16, %int1 : !torch.int, !torch.int -> !torch.int\n" +" %18 = torch.aten.add.int %12, %int1 : !torch.int, !torch.int -> !torch.int\n" +" %19 = torch.aten.neg.int %arg4 : !torch.int -> !torch.int\n" +" torch.prim.If.yield %17, %18, %19 : !torch.int, !torch.int, !torch.int\n" +" } else {\n" +" torch.prim.If.yield %1, %3, %arg4 : !torch.int, !torch.int, !torch.int\n" +" }\n" +" %6 = torch.derefine %5#0 
: !torch.int to !torch.optional\n" +" %7 = torch.derefine %5#1 : !torch.int to !torch.optional\n" +" %8 = call @__torch__.torch.jit._shape_functions.slice(%arg0, %arg1, %6, %7, %5#2) : (!torch.list, !torch.int, !torch.optional, !torch.optional, !torch.int) -> !torch.list\n" +" return %8 : !torch.list\n" " }\n" " func.func @\"__torch_mlir_shape_fn.aten.as_strided\"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.list, %arg3: !torch.optional) -> !torch.list {\n" " return %arg1 : !torch.list\n" diff --git a/lib/Dialect/Torch/Transforms/ScalarizeShapes.cpp b/lib/Dialect/Torch/Transforms/ScalarizeShapes.cpp index 345b5e156125..9a85fbaa8646 100644 --- a/lib/Dialect/Torch/Transforms/ScalarizeShapes.cpp +++ b/lib/Dialect/Torch/Transforms/ScalarizeShapes.cpp @@ -17,6 +17,7 @@ #include "torch-mlir/Dialect/Torch/IR/TorchOps.h" #include "torch-mlir/Dialect/Torch/IR/TorchTypes.h" #include "torch-mlir/Dialect/Torch/Transforms/Passes.h" +#include "torch-mlir/Dialect/Torch/Utils/Utils.h" #include "torch-mlir/Dialect/TorchConversion/Transforms/BackendTypeConversion.h" using namespace mlir; @@ -310,7 +311,9 @@ class PropagateAtenIndexSelectPattern auto selfShape = selfTy.getSizes(); int64_t selfRank = selfShape.size(); - dim = dim < 0 ? dim + selfRank : dim; + dim = toPositiveDim(dim, selfRank); + if (!isValidDim(dim, selfRank)) + return failure(); int64_t dimLength = elements.size(); if (selfShape[dim] != dimLength) return rewriter.notifyMatchFailure( @@ -362,6 +365,11 @@ class PropagateAtenSliceTensorPattern auto loc = op.getLoc(); ImplicitLocOpBuilder b(loc, rewriter); + auto selfTy = cast(op.getSelf().getType()); + auto resultTy = cast(op.getType()); + if (!selfTy.areAllSizesKnown() || !resultTy.areAllSizesKnown()) + return rewriter.notifyMatchFailure(op, "requires static sizes"); + SmallVector elements; if (failed(getListFromTensor(op.getSelf(), elements))) return failure(); @@ -379,39 +387,69 @@ class PropagateAtenSliceTensorPattern if (!matchPattern(op.getStep(), m_TorchConstantInt(&step))) return rewriter.notifyMatchFailure(op, "requires a constant step"); - if (step < 0) - return rewriter.notifyMatchFailure(op, "requires a positive step value"); - - auto selfTy = cast(op.getSelf().getType()); auto selfShape = selfTy.getSizes(); + auto resultShape = resultTy.getSizes(); int64_t selfRank = selfShape.size(); // Correct for negative indexing: - dim = dim < 0 ? dim + selfRank : dim; + dim = toPositiveDim(dim, selfRank); + if (!isValidDim(dim, selfRank)) + return failure(); - int64_t dimLength = elements.size(); + int64_t dimLength = selfShape[dim]; start = start < 0 ? start + dimLength : start; end = end < 0 ? end + dimLength : end; + end = (end < 0) ? -1 : end; + end = (end < 0 && step > 0) ? 0 : end; start = start < 0 ? 0 : start; - end = end < 0 ? 0 : end; end = end > dimLength ? 
dimLength : end;
 
-    if (selfShape[dim] != dimLength)
-      return rewriter.notifyMatchFailure(
-          op, "dim length does not match number of elements");
+    int64_t frontDimProd = 1, backDimProd = 1;
+    for (int64_t i = 0; i < selfRank; i++) {
+      if (i < dim)
+        frontDimProd *= selfShape[i];
+      if (i > dim)
+        backDimProd *= selfShape[i];
+    }
+    int64_t fullDimProd = frontDimProd * dimLength * backDimProd;
+    if (fullDimProd != (int64_t)elements.size())
+      return rewriter.notifyMatchFailure(op, "unexpected number of elements.");
+
+    // [d0,d1]    i -> (i//d1, i % d1) -> (i//d1) * d1 + (i % d1)
+    // [d0,d1,d2] i -> (i//(d1*d2), (i//d2) % d1, i % d2)
+
+    auto isSliceIdx = [&](int64_t i) {
+      int64_t dimidx = (i / backDimProd) % dimLength;
+      bool onStep = ((dimidx - start) % step == 0);
+      bool beforeEnd = (step < 0 && dimidx > end);
+      beforeEnd = beforeEnd || (step > 0 && dimidx < end);
+      bool afterBegin = (step < 0 && dimidx <= start);
+      afterBegin = afterBegin || (step > 0 && dimidx >= start);
+      return onStep && beforeEnd && afterBegin;
+    };
 
-    for (int64_t i = 0; i < selfRank; ++i) {
-      if (i == dim)
+    auto flipIdx = [&](int64_t i) {
+      int64_t frontIdx = (i / (backDimProd * dimLength));
+      int64_t dimIdx = (i / (backDimProd)) % dimLength;
+      int64_t flipDimIdx = dimLength - 1 - dimIdx;
+      int64_t backIdx = i % (backDimProd);
+      return frontIdx * (dimLength * backDimProd) + flipDimIdx * (backDimProd) +
+             backIdx;
+    };
+    SmallVector<OpFoldResult> selected;
+    for (int64_t i = 0; i < (int64_t)elements.size(); i++) {
+      if (!isSliceIdx(i))
         continue;
-      if (selfShape[i] != 1)
-        return rewriter.notifyMatchFailure(op,
-                                           "expects unary non-dim dimension");
+      int64_t index = (step > 0) ? i : flipIdx(i);
+      selected.push_back(elements[index]);
     }
 
-    SmallVector<OpFoldResult> selected;
-    for (int i = start; i < end; i += step)
-      selected.push_back(elements[i]);
+    fullDimProd = (fullDimProd * resultShape[dim]) / selfShape[dim];
+    if ((int64_t)selected.size() != fullDimProd)
+      return rewriter.notifyMatchFailure(
+          op, "Constructed slice values have an incompatible number of "
              "elements to match the provided return type.");
 
     SmallVector<Value> values;
     if (failed(materializeFolds(b, selected, values)))
@@ -424,6 +462,114 @@ class PropagateAtenSliceTensorPattern
 };
 } // namespace
 
+namespace {
+class PropagateAtenTransposeIntPattern
+    : public OpRewritePattern<AtenTransposeIntOp> {
+public:
+  using OpRewritePattern<AtenTransposeIntOp>::OpRewritePattern;
+  LogicalResult matchAndRewrite(AtenTransposeIntOp op,
+                                PatternRewriter &rewriter) const override {
+    auto loc = op.getLoc();
+    ImplicitLocOpBuilder b(loc, rewriter);
+
+    auto selfTy = cast<ValueTensorType>(op.getSelf().getType());
+    auto resultTy = cast<ValueTensorType>(op.getType());
+    if (!selfTy.areAllSizesKnown() || !resultTy.areAllSizesKnown())
+      return rewriter.notifyMatchFailure(op, "requires static sizes");
+
+    SmallVector<OpFoldResult> elements;
+    if (failed(getListFromTensor(op.getSelf(), elements)))
+      return failure();
+
+    int64_t dim0, dim1;
+    if (!matchPattern(op.getDim0(), m_TorchConstantInt(&dim0)))
+      return failure();
+    if (!matchPattern(op.getDim1(), m_TorchConstantInt(&dim1)))
+      return failure();
+
+    ArrayRef<int64_t> selfSizes = selfTy.getSizes();
+    int64_t rank = selfSizes.size();
+
+    dim0 = toPositiveDim(dim0, rank);
+    dim1 = toPositiveDim(dim1, rank);
+    if (!isValidDim(dim0, rank) || !isValidDim(dim1, rank))
+      return failure();
+
+    if (dim0 == dim1) {
+      rewriter.replaceOp(op, op.getSelf());
+      return success();
+    }
+
+    if (dim0 > dim1) {
+      // swap dim0 and dim1
+      dim0 = dim0 + dim1;
+      dim1 = dim0 - dim1;
+      dim0 -= dim1;
+    }
+
+    // A generic transpose will look like...
+    // [frontDimsFlat, dim0, midDimsFlat, dim1, backDimsFlat] -> .
+    // [frontDimsFlat, dim1, midDimsFlat, dim0, backDimsFlat] .
+    // If any of front, mid, or back don't actually exist (e.g. dim0 = 0, or
+    // dim1 = dim0 + 1), the reassociation of completely flattened indices will
+    // remain unaffected by the artificially unsqueezed dims.
+    // --------
+    // Setting some notation, let D0,D1,D2,D3,D4 be the respective dim sizes of
+    // "self". Let D'j be the transpose dim sizes, and Djk = Dj*Dk. Let fl_trans
+    // and fl_self be 1-D flattened tensors. Then:
+    // --------
+    // fl_trans[i] =
+    //  = trans[i/D'1234, i/(D'234) % D'1, i/(D'34) % D'2, i/D'4 % D'3, i % D'4]
+    //  = trans[i/D1234, i/D214 % D3, i/D14 % D2, i/D4 % D1, i % D4]
+    //  = self[i/D1234, i/D4 % D1, i/D14 % D2, i/D214 % D3, i % D4]
+    //  = fl_self[dot.prod(indices, (D1234,D234,D34,D4,1))] .
+    // --------
+    // reassoc(i) = (i/(D1234)) * D1234 +
+    //              (i/D4 % D1) * D234 +
+    //              (i/(D14) % D2) * D34 +
+    //              (i/(D214) % D3) * D4 +
+    //              (i % D4) .
+
+    SmallVector<int64_t> D(5, 1);
+    int64_t i = -1;
+    // D[0] corresponds to flattened front dims
+    while (++i < dim0)
+      D[0] *= selfSizes[i];
+    // D[1] is the earliest transpose dim
+    D[1] = selfSizes[i];
+    // D[2] corresponds to flattened middle dims
+    while (++i < dim1)
+      D[2] *= selfSizes[i];
+    // D[3] is the later transpose dim
+    D[3] = selfSizes[i];
+    // D[4] corresponds to flattened back dims
+    while (++i < rank)
+      D[4] *= selfSizes[i];
+
+    int64_t D1234 = D[1] * D[2] * D[3] * D[4];
+    int64_t fullDP = D[0] * D1234;
+    if (fullDP != (int64_t)elements.size())
+      return failure();
+    auto reassoc = [&](int64_t i) {
+      return (i / D1234) * D1234 + ((i / D[4]) % D[1]) * D[2] * D[3] * D[4] +
+             ((i / (D[1] * D[4])) % D[2]) * D[3] * D[4] +
+             ((i / (D[2] * D[1] * D[4])) % D[3]) * D[4] + (i % D[4]);
+    };
+    SmallVector<OpFoldResult> transposedFolds;
+    transposedFolds.reserve(fullDP);
+    for (int64_t i = 0; i < fullDP; i++)
+      transposedFolds.push_back(elements[reassoc(i)]);
+
+    SmallVector<Value> transposedVals;
+    if (failed(materializeFolds(b, transposedFolds, transposedVals)))
+      return failure();
+
+    Value result = constructAtenTensorOpFromList(b, resultTy, transposedVals);
+    rewriter.replaceOp(op, result);
+    return success();
+  }
+};
+} // namespace
 namespace {
 class PropagateAtenWhereSelfPattern : public OpRewritePattern<AtenWhereSelfOp> {
 public:
@@ -600,6 +746,27 @@ class PropagateAtenItemPattern : public OpRewritePattern<AtenItemOp> {
 };
 } // namespace
 
+namespace {
+template <typename AtenViewLikeOp>
+class PropagateAtenViewLikePattern : public OpRewritePattern<AtenViewLikeOp> {
+public:
+  using OpRewritePattern<AtenViewLikeOp>::OpRewritePattern;
+  LogicalResult matchAndRewrite(AtenViewLikeOp op,
+                                PatternRewriter &rewriter) const override {
+    SmallVector<OpFoldResult> selfFolds;
+    if (failed(getListFromTensor(op.getSelf(), selfFolds)))
+      return failure();
+    ImplicitLocOpBuilder b(op.getLoc(), rewriter);
+    SmallVector<Value> selfVals;
+    if (failed(materializeFolds(b, selfFolds, selfVals)))
+      return failure();
+    Value result = constructAtenTensorOpFromList(b, op.getType(), selfVals);
+    rewriter.replaceOp(op, result);
+    return success();
+  }
+};
+} // namespace
+
 namespace {
 template <typename OpTy>
 struct ArithmeticHelper {
@@ -1065,6 +1232,34 @@ bool isAnchorOp(Operation *op) {
          isPrimListOfInts(op);
 }
 
+// The argument to this function, op, is a user of some source op, srcOp. If
+// this function returns true, we want to invalidate srcOp as a target for
+// shape scalarization.
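+// For example, an aten.view that merely reshapes a shape tensor so that a
+// later canonicalization can fold it, and whose own users are not in the
+// worklist, returns true here: scalarizing its source would not feed any
+// shape computation we intend to reduce.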
+bool isInvalidValidViewConsumer(Operation *op,
+                                SetVector<Operation *> &workList) {
+  // if the consumer isn't a view op, don't invalidate it
+  auto view = dyn_cast_or_null<AtenViewOp>(op);
+  if (!view)
+    return false;
+  auto resultTy = dyn_cast<ValueTensorType>(view.getType());
+  if (!resultTy || !resultTy.hasDtype())
+    return true;
+  // if the view op doesn't return integer types, then srcOp is not a shape
+  // tensor. note: prim lists will always get added before reaching this
+  // function call.
+  if (!isa<mlir::IntegerType>(resultTy.getDtype()))
+    return true;
+  // check uses of the view op.
+  // If the view op has a use in our worklist, then it needs to be scalarized.
+  for (OpOperand &use : op->getUses()) {
+    Operation *userOp = use.getOwner();
+    if (workList.contains(userOp))
+      return false;
+  }
+  // invalidate, since the view op was added as a one-off for canonicalization.
+  return true;
+}
 
 void populateScalarizationFoldPatterns(RewritePatternSet &patterns) {
   patterns.insert<FoldAtenSqueezePattern<AtenSqueezeOp>,
                   FoldAtenSqueezePattern<AtenSqueezeDimOp>,
@@ -1078,6 +1273,11 @@ void populateScalarizationCanonicalizePatterns(RewritePatternSet &patterns) {
 }
 
 void populateScalarizationPropagationPatterns(RewritePatternSet &patterns) {
+  patterns.add<PropagateAtenViewLikePattern<AtenViewOp>>(patterns.getContext(),
+                                                         /*benefit=*/10);
+  patterns.insert<PropagateAtenViewLikePattern<AtenFlattenUsingIntsOp>,
+                  PropagateAtenViewLikePattern<AtenUnflattenIntOp>>(
+      patterns.getContext());
   // A note on division: onnx.Div from int, int -> int types rounds towards
   // zero. The torch DivTensorOp actually doesn't allow returning an int dtype,
   // but this was artificially plumbed through. Unfortunately, there is no
@@ -1088,6 +1288,7 @@ void populateScalarizationPropagationPatterns(RewritePatternSet &patterns) {
                   PropagateAtenItemPattern, PropagateAtenShapeToTensorPattern,
                   PropagateAtenSliceTensorPattern, PropagateAtenEqTensorPattern,
                   PropagateAtenWhereSelfPattern, PropagateAtenBroadcastToPattern,
+                  PropagateAtenTransposeIntPattern,
                   PropagateAtenArithmeticPattern,
                   PropagateAtenArithmeticPattern,
                   PropagateAtenArithmeticPattern,
@@ -1105,9 +1306,6 @@ void populateScalarizationRemovePatterns(RewritePatternSet &patterns) {
                   RemoveUnusedPattern,
                   RemoveUnusedPattern,
                   RemoveUnusedPattern,
-                  RemoveUnusedPattern,
-                  RemoveUnusedPattern,
-                  RemoveUnusedPattern,
                   RemoveUnusedPattern>(
       patterns.getContext());
 }
@@ -1168,12 +1366,12 @@ class ScalarizeShapesPass : public ScalarizeShapesBase<ScalarizeShapesPass> {
       // shapeCalculationOps. Its consumer (%1) is indeed a shape
       // calculation op, but the size.int op is an elementary unit of shape
      // computation. No further gathering of producers is necessary to
-      // reduce this. Similarly, don't add the `self` of a view op.
+      // reduce this. Similarly, don't always add the `self` of a view op.
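+      // A view's `self` is only gathered when the view produces an integer
+      // shape tensor and at least one of the view's users is already in the
+      // worklist (see isInvalidValidViewConsumer above).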
       for (OpOperand &use : op->getUses()) {
         Operation *userOp = use.getOwner();
         if (shapeCalculationOps.contains(userOp) &&
             !isSourceOpForShapeScalarization(userOp) &&
-            !isa<AtenViewOp>(userOp)) {
+            !isInvalidValidViewConsumer(userOp, shapeCalculationOps)) {
           shapeCalculationOps.insert(op);
           return;
         }
diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py
index 65cc18837edb..06437574d8f0 100644
--- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py
+++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py
@@ -1903,7 +1903,33 @@ def aten〇_weight_norm_interface〡shape(v: List[int], g: List[int], dim: int =
     return upstream_shape_functions.unary(v), upstream_shape_functions.unary(g)
 
 def aten〇slice〇Tensor〡shape(self: List[int], dim: int = 0, start: Optional[int] = None, end: Optional[int] = None, step: int = 1) -> List[int]:
-    return upstream_shape_functions.slice(self, dim, start, end, step)
+    start_val = start if start is not None else 0
+    end_val = end if end is not None else upstream_shape_functions.max_int()
+    if (step < 0):
+        # Convert to equivalent positive-step parameters, which requires swapping start and end.
+        # If the parameters are in the normal range (0 <= start < d and -1 <= end <= start), then
+        # swapped_end = start + 1 and swapped_start = end + 1.
+        # The shift of inclusion can cause issues if these parameters are not already resolved on the left.
+        # e.g. start = -1, end = -3. The valid start is actually d-1 and the valid end is d-3, so we
+        # should get swapped_end = d; but adding 1 to start before making it valid would result in an
+        # incorrect, yet "valid", swapped_end = 0 for forward slicing.
+        # Additionally, if adding d once doesn't make these values non-negative (adding it twice would),
+        # we need to clamp after resolving; otherwise the upstream function will try to resolve them a
+        # second time.
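+        # Worked example (mirrors the scalarize-shapes tests): d = 4, start = -1,
+        # end = -max_int, step = -1. Resolving gives start_val = 3 and end_val is
+        # clamped to -1, so the swap below yields start_val = 0, end_val = 4,
+        # step = 1, which covers the same four positions in forward order.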
+ if start_val < 0: + start_val += self[dim] + if start_val < 0: + start_val = 0 + if end_val < 0: + end_val += self[dim] + if end_val < 0: + end_val = -1 + + tmp = end_val + 1 + end_val = start_val + 1 + start_val = tmp + step = -step + return upstream_shape_functions.slice(self,dim,start_val,end_val,step) + def aten〇as_strided〡shape(self: List[int], size: List[int], stride: List[int], storage_offset: Optional[int] = None) -> List[int]: return size diff --git a/test/Dialect/Torch/scalarize-shapes.mlir b/test/Dialect/Torch/scalarize-shapes.mlir index 166e2fda564e..f5193b701d8a 100644 --- a/test/Dialect/Torch/scalarize-shapes.mlir +++ b/test/Dialect/Torch/scalarize-shapes.mlir @@ -374,3 +374,262 @@ func.func @squeeze_dim_full_fold(%arg0: !torch.vtensor<[?,?],si64>) -> !torch.li %59 = torch.prim.ListConstruct %58 : (!torch.int) -> !torch.list return %59 : !torch.list } + +// ----- + +// CHECK-LABEL: @pytorch_dynamic_pad_export_view$prop( +func.func @pytorch_dynamic_pad_export_view$prop(%arg0: !torch.vtensor<[?,144,?,?],f32>) -> !torch.vtensor<[4,2],si64> { + // CHECK: %[[I0:.*]] = torch.constant.int 0 + // CHECK: %[[x0:.*]] = torch.aten.size.int %arg0, %[[I0]] : !torch.vtensor<[?,144,?,?],f32>, !torch.int -> !torch.int + // CHECK: %[[I2:.*]] = torch.constant.int 2 + // CHECK: %[[x1:.*]] = torch.aten.size.int %arg0, %[[I2]] : !torch.vtensor<[?,144,?,?],f32>, !torch.int -> !torch.int + // CHECK: %[[I3:.*]] = torch.constant.int 3 + // CHECK: %[[x2:.*]] = torch.aten.size.int %arg0, %[[I3]] : !torch.vtensor<[?,144,?,?],f32>, !torch.int -> !torch.int + // CHECK: %[[I144:.*]] = torch.constant.int 144 + // CHECK: %[[I0_0:.*]] = torch.constant.int 0 + // CHECK: %[[I0_1:.*]] = torch.constant.int 0 + // CHECK: %[[I0_2:.*]] = torch.constant.int 0 + // CHECK: %[[I0_3:.*]] = torch.constant.int 0 + // CHECK: %[[x3:.*]] = torch.prim.ListConstruct %[[x0]], %[[I144]], %[[x1]], %[[x2]], %[[I0_0]], %[[I0_1]], %[[I0_2]], %[[I0_3]] : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + // CHECK: %[[none:.*]] = torch.constant.none + // CHECK: %[[false:.*]] = torch.constant.bool false + // CHECK: %[[x4:.*]] = torch.aten.tensor %[[x3]], %[[none]], %[[none]], %[[false]] : !torch.list, !torch.none, !torch.none, !torch.bool -> !torch.vtensor<[4,2],si64> + // CHECK: return %[[x4]] : !torch.vtensor<[4,2],si64> + %0 = torch.vtensor.literal(dense<0> : tensor<4xsi64>) : !torch.vtensor<[4],si64> + %1 = torch.vtensor.literal(dense<1> : tensor<1xsi64>) : !torch.vtensor<[1],si64> + %2 = torch.vtensor.literal(dense<0> : tensor<1xsi64>) : !torch.vtensor<[1],si64> + %int-9223372036854775807 = torch.constant.int -9223372036854775807 + %int-1 = torch.constant.int -1 + %int2 = torch.constant.int 2 + %int1 = torch.constant.int 1 + %int0 = torch.constant.int 0 + %3 = torch.aten._shape_as_tensor %arg0 : !torch.vtensor<[?,144,?,?],f32> -> !torch.vtensor<[4],si64> + %4 = torch.prim.ListConstruct %3, %0 : (!torch.vtensor<[4],si64>, !torch.vtensor<[4],si64>) -> !torch.list + %5 = torch.aten.cat %4, %int0 : !torch.list, !torch.int -> !torch.vtensor<[8],si64> + %6 = torch.prim.ListConstruct %int-1, %int2 : (!torch.int, !torch.int) -> !torch.list + %7 = torch.aten.view %5, %6 : !torch.vtensor<[8],si64>, !torch.list -> !torch.vtensor<[4,2],si64> + %8 = torch.aten.slice.Tensor %7, %int0, %int-1, %int-9223372036854775807, %int-1 : !torch.vtensor<[4,2],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,2],si64> + %9 = torch.aten.transpose.int %8, %int0, %int1 : 
!torch.vtensor<[4,2],si64>, !torch.int, !torch.int -> !torch.vtensor<[2,4],si64> + %10 = torch.prim.ListConstruct %int-1 : (!torch.int) -> !torch.list + %11 = torch.aten.view %9, %10 : !torch.vtensor<[2,4],si64>, !torch.list -> !torch.vtensor<[8],si64> + %12 = torch.aten.index_select %11, %int0, %1 : !torch.vtensor<[8],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64> + %13 = torch.aten.item %12 : !torch.vtensor<[1],si64> -> !torch.int + %14 = torch.aten.index_select %11, %int0, %2 : !torch.vtensor<[8],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64> + %15 = torch.aten.item %14 : !torch.vtensor<[1],si64> -> !torch.int + %16 = torch.prim.ListConstruct %13, %15: (!torch.int, !torch.int) -> !torch.list + return %7 : !torch.vtensor<[4,2],si64> +} + +// ----- + +// CHECK-LABEL: @pytorch_dynamic_pad_export_slice$prop( +func.func @pytorch_dynamic_pad_export_slice$prop(%arg0: !torch.vtensor<[?,144,?,?],f32>) -> !torch.vtensor<[4,2],si64> { + // CHECK: %[[I0:.*]] = torch.constant.int 0 + // CHECK: %[[x0:.*]] = torch.aten.size.int %arg0, %[[I0]] : !torch.vtensor<[?,144,?,?],f32>, !torch.int -> !torch.int + // CHECK: %[[I2:.*]] = torch.constant.int 2 + // CHECK: %[[x1:.*]] = torch.aten.size.int %arg0, %[[I2]] : !torch.vtensor<[?,144,?,?],f32>, !torch.int -> !torch.int + // CHECK: %[[I3:.*]] = torch.constant.int 3 + // CHECK: %[[x2:.*]] = torch.aten.size.int %arg0, %[[I3]] : !torch.vtensor<[?,144,?,?],f32>, !torch.int -> !torch.int + // CHECK: %[[I0_0:.*]] = torch.constant.int 0 + // CHECK: %[[I0_1:.*]] = torch.constant.int 0 + // CHECK: %[[I0_2:.*]] = torch.constant.int 0 + // CHECK: %[[I0_3:.*]] = torch.constant.int 0 + // CHECK: %[[I144:.*]] = torch.constant.int 144 + // CHECK: %[[x3:.*]] = torch.prim.ListConstruct %[[I0_0]], %[[I0_1]], %[[I0_2]], %[[I0_3]], %[[x1]], %[[x2]], %[[x0]], %[[I144]] : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + // CHECK: %[[none:.*]] = torch.constant.none + // CHECK: %[[false:.*]] = torch.constant.bool false + // CHECK: %[[x4:.*]] = torch.aten.tensor %[[x3]], %[[none]], %[[none]], %[[false]] : !torch.list, !torch.none, !torch.none, !torch.bool -> !torch.vtensor<[4,2],si64> + // CHECK: return %[[x4]] : !torch.vtensor<[4,2],si64> + %0 = torch.vtensor.literal(dense<0> : tensor<4xsi64>) : !torch.vtensor<[4],si64> + %1 = torch.vtensor.literal(dense<1> : tensor<1xsi64>) : !torch.vtensor<[1],si64> + %2 = torch.vtensor.literal(dense<0> : tensor<1xsi64>) : !torch.vtensor<[1],si64> + %int-9223372036854775807 = torch.constant.int -9223372036854775807 + %int-1 = torch.constant.int -1 + %int2 = torch.constant.int 2 + %int1 = torch.constant.int 1 + %int0 = torch.constant.int 0 + %3 = torch.aten._shape_as_tensor %arg0 : !torch.vtensor<[?,144,?,?],f32> -> !torch.vtensor<[4],si64> + %4 = torch.prim.ListConstruct %3, %0 : (!torch.vtensor<[4],si64>, !torch.vtensor<[4],si64>) -> !torch.list + %5 = torch.aten.cat %4, %int0 : !torch.list, !torch.int -> !torch.vtensor<[8],si64> + %6 = torch.prim.ListConstruct %int-1, %int2 : (!torch.int, !torch.int) -> !torch.list + %7 = torch.aten.view %5, %6 : !torch.vtensor<[8],si64>, !torch.list -> !torch.vtensor<[4,2],si64> + %8 = torch.aten.slice.Tensor %7, %int0, %int-1, %int-9223372036854775807, %int-1 : !torch.vtensor<[4,2],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,2],si64> + %9 = torch.aten.transpose.int %8, %int0, %int1 : !torch.vtensor<[4,2],si64>, !torch.int, !torch.int -> !torch.vtensor<[2,4],si64> + 
%10 = torch.prim.ListConstruct %int-1 : (!torch.int) -> !torch.list + %11 = torch.aten.view %9, %10 : !torch.vtensor<[2,4],si64>, !torch.list -> !torch.vtensor<[8],si64> + %12 = torch.aten.index_select %11, %int0, %1 : !torch.vtensor<[8],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64> + %13 = torch.aten.item %12 : !torch.vtensor<[1],si64> -> !torch.int + %14 = torch.aten.index_select %11, %int0, %2 : !torch.vtensor<[8],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64> + %15 = torch.aten.item %14 : !torch.vtensor<[1],si64> -> !torch.int + %16 = torch.prim.ListConstruct %13, %15: (!torch.int, !torch.int) -> !torch.list + return %8 : !torch.vtensor<[4,2],si64> +} + +// ----- + +// CHECK-LABEL: @pytorch_dynamic_pad_export_transpose$prop( +func.func @pytorch_dynamic_pad_export_transpose$prop(%arg0: !torch.vtensor<[?,144,?,?],f32>) -> !torch.vtensor<[2,4],si64> { + // CHECK: %[[I0:.*]] = torch.constant.int 0 + // CHECK: %[[DIM0:.*]] = torch.aten.size.int %arg0, %[[I0]] : !torch.vtensor<[?,144,?,?],f32>, !torch.int -> !torch.int + // CHECK: %[[I2:.*]] = torch.constant.int 2 + // CHECK: %[[DIM2:.*]] = torch.aten.size.int %arg0, %[[I2]] : !torch.vtensor<[?,144,?,?],f32>, !torch.int -> !torch.int + // CHECK: %[[I3:.*]] = torch.constant.int 3 + // CHECK: %[[DIM3:.*]] = torch.aten.size.int %arg0, %[[I3]] : !torch.vtensor<[?,144,?,?],f32>, !torch.int -> !torch.int + // CHECK: %[[I0_0:.*]] = torch.constant.int 0 + // CHECK: %[[I0_1:.*]] = torch.constant.int 0 + // CHECK: %[[I0_2:.*]] = torch.constant.int 0 + // CHECK: %[[I0_3:.*]] = torch.constant.int 0 + // CHECK: %[[DIM1:.*]] = torch.constant.int 144 + // CHECK: %[[x3:.*]] = torch.prim.ListConstruct %[[I0_0]], %[[I0_1]], %[[DIM2]], %[[DIM0]], %[[I0_2]], %[[I0_3]], %[[DIM3]], %[[DIM1]] : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + // CHECK: %[[none:.*]] = torch.constant.none + // CHECK: %[[false:.*]] = torch.constant.bool false + // CHECK: %[[x4:.*]] = torch.aten.tensor %[[x3]], %[[none]], %[[none]], %[[false]] : !torch.list, !torch.none, !torch.none, !torch.bool -> !torch.vtensor<[2,4],si64> + // CHECK: return %[[x4]] : !torch.vtensor<[2,4],si64> + %0 = torch.vtensor.literal(dense<0> : tensor<4xsi64>) : !torch.vtensor<[4],si64> + %1 = torch.vtensor.literal(dense<1> : tensor<1xsi64>) : !torch.vtensor<[1],si64> + %2 = torch.vtensor.literal(dense<0> : tensor<1xsi64>) : !torch.vtensor<[1],si64> + %int-9223372036854775807 = torch.constant.int -9223372036854775807 + %int-1 = torch.constant.int -1 + %int2 = torch.constant.int 2 + %int1 = torch.constant.int 1 + %int0 = torch.constant.int 0 + %3 = torch.aten._shape_as_tensor %arg0 : !torch.vtensor<[?,144,?,?],f32> -> !torch.vtensor<[4],si64> + %4 = torch.prim.ListConstruct %3, %0 : (!torch.vtensor<[4],si64>, !torch.vtensor<[4],si64>) -> !torch.list + %5 = torch.aten.cat %4, %int0 : !torch.list, !torch.int -> !torch.vtensor<[8],si64> + %6 = torch.prim.ListConstruct %int-1, %int2 : (!torch.int, !torch.int) -> !torch.list + %7 = torch.aten.view %5, %6 : !torch.vtensor<[8],si64>, !torch.list -> !torch.vtensor<[4,2],si64> + %8 = torch.aten.slice.Tensor %7, %int0, %int-1, %int-9223372036854775807, %int-1 : !torch.vtensor<[4,2],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,2],si64> + %9 = torch.aten.transpose.int %8, %int0, %int1 : !torch.vtensor<[4,2],si64>, !torch.int, !torch.int -> !torch.vtensor<[2,4],si64> + %10 = torch.prim.ListConstruct %int-1 : (!torch.int) -> 
!torch.list + %11 = torch.aten.view %9, %10 : !torch.vtensor<[2,4],si64>, !torch.list -> !torch.vtensor<[8],si64> + %12 = torch.aten.index_select %11, %int0, %1 : !torch.vtensor<[8],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64> + %13 = torch.aten.item %12 : !torch.vtensor<[1],si64> -> !torch.int + %14 = torch.aten.index_select %11, %int0, %2 : !torch.vtensor<[8],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64> + %15 = torch.aten.item %14 : !torch.vtensor<[1],si64> -> !torch.int + %16 = torch.prim.ListConstruct %13, %15: (!torch.int, !torch.int) -> !torch.list + return %9 : !torch.vtensor<[2,4],si64> +} + +// ----- + +// CHECK-LABEL: @pytorch_dynamic_pad_export_full( +func.func @pytorch_dynamic_pad_export_full(%arg0: !torch.vtensor<[?,144,?,?],f32>) -> !torch.list { + // CHECK: %[[I2:.*]] = torch.constant.int 2 + // CHECK: %[[DIM2:.*]] = torch.aten.size.int %arg0, %[[I2]] : !torch.vtensor<[?,144,?,?],f32>, !torch.int -> !torch.int + // CHECK: %[[I0:.*]] = torch.constant.int 0 + // CHECK: %[[x1:.*]] = torch.prim.ListConstruct %[[DIM2]], %[[I0]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: return %[[x1]] : !torch.list + %0 = torch.vtensor.literal(dense<0> : tensor<4xsi64>) : !torch.vtensor<[4],si64> + %1 = torch.vtensor.literal(dense<2> : tensor<1xsi64>) : !torch.vtensor<[1],si64> + %2 = torch.vtensor.literal(dense<0> : tensor<1xsi64>) : !torch.vtensor<[1],si64> + %int-9223372036854775807 = torch.constant.int -9223372036854775807 + %int-1 = torch.constant.int -1 + %int2 = torch.constant.int 2 + %int1 = torch.constant.int 1 + %int0 = torch.constant.int 0 + %3 = torch.aten._shape_as_tensor %arg0 : !torch.vtensor<[?,144,?,?],f32> -> !torch.vtensor<[4],si64> + %4 = torch.prim.ListConstruct %3, %0 : (!torch.vtensor<[4],si64>, !torch.vtensor<[4],si64>) -> !torch.list + %5 = torch.aten.cat %4, %int0 : !torch.list, !torch.int -> !torch.vtensor<[8],si64> + %6 = torch.prim.ListConstruct %int-1, %int2 : (!torch.int, !torch.int) -> !torch.list + %7 = torch.aten.view %5, %6 : !torch.vtensor<[8],si64>, !torch.list -> !torch.vtensor<[4,2],si64> + %8 = torch.aten.slice.Tensor %7, %int0, %int-1, %int-9223372036854775807, %int-1 : !torch.vtensor<[4,2],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,2],si64> + %9 = torch.aten.transpose.int %8, %int0, %int1 : !torch.vtensor<[4,2],si64>, !torch.int, !torch.int -> !torch.vtensor<[2,4],si64> + %10 = torch.prim.ListConstruct %int-1 : (!torch.int) -> !torch.list + %11 = torch.aten.view %9, %10 : !torch.vtensor<[2,4],si64>, !torch.list -> !torch.vtensor<[8],si64> + %12 = torch.aten.index_select %11, %int0, %1 : !torch.vtensor<[8],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64> + %13 = torch.aten.item %12 : !torch.vtensor<[1],si64> -> !torch.int + %14 = torch.aten.index_select %11, %int0, %2 : !torch.vtensor<[8],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64> + %15 = torch.aten.item %14 : !torch.vtensor<[1],si64> -> !torch.int + %16 = torch.prim.ListConstruct %13, %15: (!torch.int, !torch.int) -> !torch.list + return %16 : !torch.list +} + +// ----- + +// CHECK-LABEL: @transpose$prop_3d_0_1 +func.func @transpose$prop_3d_0_1(%arg0: !torch.vtensor<[?,?,?,?],f32>, %arg1: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[2,2,2],si64> { + // CHECK: %[[I0:.*]] = torch.constant.int 0 + // CHECK: %[[SIZE0_0:.*]] = torch.aten.size.int %arg0, %[[I0]] : !torch.vtensor<[?,?,?,?],f32>, !torch.int -> !torch.int + // CHECK: %[[I1:.*]] = 
torch.constant.int 1 + // CHECK: %[[SIZE0_1:.*]] = torch.aten.size.int %arg0, %[[I1]] : !torch.vtensor<[?,?,?,?],f32>, !torch.int -> !torch.int + // CHECK: %[[I2:.*]] = torch.constant.int 2 + // CHECK: %[[SIZE0_2:.*]] = torch.aten.size.int %arg0, %[[I2]] : !torch.vtensor<[?,?,?,?],f32>, !torch.int -> !torch.int + // CHECK: %[[I3:.*]] = torch.constant.int 3 + // CHECK: %[[SIZE0_3:.*]] = torch.aten.size.int %arg0, %[[I3]] : !torch.vtensor<[?,?,?,?],f32>, !torch.int -> !torch.int + // CHECK: %[[I0_0:.*]] = torch.constant.int 0 + // CHECK: %[[SIZE1_0:.*]] = torch.aten.size.int %arg1, %[[I0_0]] : !torch.vtensor<[?,?,?,?],f32>, !torch.int -> !torch.int + // CHECK: %[[I1_1:.*]] = torch.constant.int 1 + // CHECK: %[[SIZE1_1:.*]] = torch.aten.size.int %arg1, %[[I1_1]] : !torch.vtensor<[?,?,?,?],f32>, !torch.int -> !torch.int + // CHECK: %[[I2_2:.*]] = torch.constant.int 2 + // CHECK: %[[SIZE1_2:.*]] = torch.aten.size.int %arg1, %[[I2_2]] : !torch.vtensor<[?,?,?,?],f32>, !torch.int -> !torch.int + // CHECK: %[[I3_3:.*]] = torch.constant.int 3 + // CHECK: %[[SIZE1_3:.*]] = torch.aten.size.int %arg1, %[[I3_3]] : !torch.vtensor<[?,?,?,?],f32>, !torch.int -> !torch.int + // CHECK: %[[LIST:.*]] = torch.prim.ListConstruct %[[SIZE0_0]], %[[SIZE0_1]], %[[SIZE1_0]], %[[SIZE1_1]], %[[SIZE0_2]], %[[SIZE0_3]], %[[SIZE1_2]], %[[SIZE1_3]] : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + // CHECK: %[[NONE:.*]] = torch.constant.none + // CHECK: %[[FALSE:.*]] = torch.constant.bool false + // CHECK: %[[TENSOR:.*]] = torch.aten.tensor %[[LIST]], %[[NONE]], %[[NONE]], %[[FALSE]] : !torch.list, !torch.none, !torch.none, !torch.bool -> !torch.vtensor<[2,2,2],si64> + // CHECK: return %[[TENSOR]] : !torch.vtensor<[2,2,2],si64> + %0 = torch.vtensor.literal(dense<2> : tensor<1xsi64>) : !torch.vtensor<[1],si64> + %int-1 = torch.constant.int -1 + %int2 = torch.constant.int 2 + %int1 = torch.constant.int 1 + %int0 = torch.constant.int 0 + %1 = torch.aten._shape_as_tensor %arg0 : !torch.vtensor<[?,?,?,?],f32> -> !torch.vtensor<[4],si64> + %2 = torch.aten._shape_as_tensor %arg1 : !torch.vtensor<[?,?,?,?],f32> -> !torch.vtensor<[4],si64> + %3 = torch.prim.ListConstruct %1, %2 : (!torch.vtensor<[4],si64>, !torch.vtensor<[4],si64>) -> !torch.list + %4 = torch.aten.cat %3, %int0 : !torch.list, !torch.int -> !torch.vtensor<[8],si64> + %5 = torch.prim.ListConstruct %int2, %int2, %int2 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %6 = torch.aten.view %4, %5 : !torch.vtensor<[8],si64>, !torch.list -> !torch.vtensor<[2,2,2],si64> + %7 = torch.aten.transpose.int %6, %int0, %int1 : !torch.vtensor<[2,2,2],si64>, !torch.int, !torch.int -> !torch.vtensor<[2,2,2],si64> + %8 = torch.prim.ListConstruct %int-1 : (!torch.int) -> !torch.list + %9 = torch.aten.view %7, %8 : !torch.vtensor<[2,2,2],si64>, !torch.list -> !torch.vtensor<[8],si64> + %10 = torch.aten.index_select %9, %int0, %0 : !torch.vtensor<[8],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64> + %11 = torch.aten.item %10 : !torch.vtensor<[1],si64> -> !torch.int + %12 = torch.prim.ListConstruct %11 : (!torch.int) -> !torch.list + return %7 : !torch.vtensor<[2,2,2],si64> +} + +// ----- + +// CHECK-LABEL: @transpose$prop_3d_m1_0 +func.func @transpose$prop_3d_m1_0(%arg0: !torch.vtensor<[?,?,?,?],f32>, %arg1: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[2,2,2],si64> { + // CHECK: %[[I0:.*]] = torch.constant.int 0 + // CHECK: %[[SIZE0_0:.*]] = torch.aten.size.int %arg0, %[[I0]] : 
!torch.vtensor<[?,?,?,?],f32>, !torch.int -> !torch.int + // CHECK: %[[I1:.*]] = torch.constant.int 1 + // CHECK: %[[SIZE0_1:.*]] = torch.aten.size.int %arg0, %[[I1]] : !torch.vtensor<[?,?,?,?],f32>, !torch.int -> !torch.int + // CHECK: %[[I2:.*]] = torch.constant.int 2 + // CHECK: %[[SIZE0_2:.*]] = torch.aten.size.int %arg0, %[[I2]] : !torch.vtensor<[?,?,?,?],f32>, !torch.int -> !torch.int + // CHECK: %[[I3:.*]] = torch.constant.int 3 + // CHECK: %[[SIZE0_3:.*]] = torch.aten.size.int %arg0, %[[I3]] : !torch.vtensor<[?,?,?,?],f32>, !torch.int -> !torch.int + // CHECK: %[[I0_0:.*]] = torch.constant.int 0 + // CHECK: %[[SIZE1_0:.*]] = torch.aten.size.int %arg1, %[[I0_0]] : !torch.vtensor<[?,?,?,?],f32>, !torch.int -> !torch.int + // CHECK: %[[I1_1:.*]] = torch.constant.int 1 + // CHECK: %[[SIZE1_1:.*]] = torch.aten.size.int %arg1, %[[I1_1]] : !torch.vtensor<[?,?,?,?],f32>, !torch.int -> !torch.int + // CHECK: %[[I2_2:.*]] = torch.constant.int 2 + // CHECK: %[[SIZE1_2:.*]] = torch.aten.size.int %arg1, %[[I2_2]] : !torch.vtensor<[?,?,?,?],f32>, !torch.int -> !torch.int + // CHECK: %[[I3_3:.*]] = torch.constant.int 3 + // CHECK: %[[SIZE1_3:.*]] = torch.aten.size.int %arg1, %[[I3_3]] : !torch.vtensor<[?,?,?,?],f32>, !torch.int -> !torch.int + // CHECK: %[[LIST:.*]] = torch.prim.ListConstruct %[[SIZE0_0]], %[[SIZE1_0]], %[[SIZE0_2]], %[[SIZE1_2]], %[[SIZE0_1]], %[[SIZE1_1]], %[[SIZE0_3]], %[[SIZE1_3]] : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + // CHECK: %[[NONE:.*]] = torch.constant.none + // CHECK: %[[FALSE:.*]] = torch.constant.bool false + // CHECK: %[[TENSOR:.*]] = torch.aten.tensor %[[LIST]], %[[NONE]], %[[NONE]], %[[FALSE]] : !torch.list, !torch.none, !torch.none, !torch.bool -> !torch.vtensor<[2,2,2],si64> + // CHECK: return %[[TENSOR]] : !torch.vtensor<[2,2,2],si64> + %0 = torch.vtensor.literal(dense<2> : tensor<1xsi64>) : !torch.vtensor<[1],si64> + %int-1 = torch.constant.int -1 + %int2 = torch.constant.int 2 + %int1 = torch.constant.int 1 + %int0 = torch.constant.int 0 + %1 = torch.aten._shape_as_tensor %arg0 : !torch.vtensor<[?,?,?,?],f32> -> !torch.vtensor<[4],si64> + %2 = torch.aten._shape_as_tensor %arg1 : !torch.vtensor<[?,?,?,?],f32> -> !torch.vtensor<[4],si64> + %3 = torch.prim.ListConstruct %1, %2 : (!torch.vtensor<[4],si64>, !torch.vtensor<[4],si64>) -> !torch.list + %4 = torch.aten.cat %3, %int0 : !torch.list, !torch.int -> !torch.vtensor<[8],si64> + %5 = torch.prim.ListConstruct %int2, %int2, %int2 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %6 = torch.aten.view %4, %5 : !torch.vtensor<[8],si64>, !torch.list -> !torch.vtensor<[2,2,2],si64> + %7 = torch.aten.transpose.int %6, %int-1, %int0 : !torch.vtensor<[2,2,2],si64>, !torch.int, !torch.int -> !torch.vtensor<[2,2,2],si64> + %8 = torch.prim.ListConstruct %int-1 : (!torch.int) -> !torch.list + %9 = torch.aten.view %7, %8 : !torch.vtensor<[2,2,2],si64>, !torch.list -> !torch.vtensor<[8],si64> + %10 = torch.aten.index_select %9, %int0, %0 : !torch.vtensor<[8],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64> + %11 = torch.aten.item %10 : !torch.vtensor<[1],si64> -> !torch.int + %12 = torch.prim.ListConstruct %11 : (!torch.int) -> !torch.list + return %7 : !torch.vtensor<[2,2,2],si64> +}
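
The index arithmetic introduced in this patch can be sanity-checked with a short standalone Python sketch. It is illustrative only and not part of the patch: positive_step_slice_len is a simplified stand-in for the upstream slice shape computation, and the helper names (convert_negative_step, reassoc, transpose_flat_reference) are invented for the example. The first block mirrors the negative-step conversion added to aten〇slice〇Tensor〡shape and checks it against Python's native slicing; the second mirrors the flattened-index reassociation in PropagateAtenTransposeIntPattern and checks it against a brute-force transpose.

def positive_step_slice_len(d, start, end, step):
    # Length of [start:end:step] over a dim of size d for step > 0, with clamping.
    start = max(0, min(start, d))
    end = max(0, min(end, d))
    return max(0, (end - start + step - 1) // step)

def convert_negative_step(d, start, end, step):
    # Mirrors the negative-step branch added to the aten.slice.Tensor shape fn.
    assert step < 0
    if start < 0:
        start += d
        if start < 0:
            start = 0
    if end < 0:
        end += d
        if end < 0:
            end = -1
    # Swap: new start = old end + 1, new end = old start + 1, step flipped.
    start, end = end + 1, start + 1
    return start, end, -step

for d, start, end, step in [(4, -1, -(2**63 - 1), -1), (5, 3, 0, -2), (5, -2, -4, -1)]:
    expected = len(list(range(d))[start:end:step])
    s, e, st = convert_negative_step(d, start, end, step)
    assert positive_step_slice_len(d, s, e, st) == expected

def reassoc(i, D):
    # Mirrors the reassoc lambda in PropagateAtenTransposeIntPattern,
    # with D = [frontDimsFlat, dim0, midDimsFlat, dim1, backDimsFlat].
    D1234 = D[1] * D[2] * D[3] * D[4]
    return ((i // D1234) * D1234
            + ((i // D[4]) % D[1]) * D[2] * D[3] * D[4]
            + ((i // (D[1] * D[4])) % D[2]) * D[3] * D[4]
            + ((i // (D[2] * D[1] * D[4])) % D[3]) * D[4]
            + (i % D[4]))

def transpose_flat_reference(vals, D):
    # Brute force: walk the transposed layout [front, d1, mid, d0, back] in
    # row-major order and gather from the original row-major buffer.
    out = []
    for a in range(D[0]):
        for d in range(D[3]):
            for c in range(D[2]):
                for b in range(D[1]):
                    for e in range(D[4]):
                        out.append(vals[(((a * D[1] + b) * D[2] + c) * D[3] + d) * D[4] + e])
    return out

D = [2, 3, 4, 5, 6]
vals = list(range(D[0] * D[1] * D[2] * D[3] * D[4]))
assert [vals[reassoc(i, D)] for i in range(len(vals))] == transpose_flat_reference(vals, D)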