diff --git a/lib/Dialect/Torch/IR/TorchOps.cpp b/lib/Dialect/Torch/IR/TorchOps.cpp
index 0842cff331fc..97fc5494621b 100644
--- a/lib/Dialect/Torch/IR/TorchOps.cpp
+++ b/lib/Dialect/Torch/IR/TorchOps.cpp
@@ -3700,6 +3700,12 @@ OpFoldResult AtenRemainderScalarOp::fold(FoldAdaptor adaptor) {
 //===----------------------------------------------------------------------===//
 
 OpFoldResult AtenAddIntOp::fold(FoldAdaptor adaptor) {
+  auto intLhs = dyn_cast_or_null<IntegerAttr>(adaptor.getA());
+  auto intRhs = dyn_cast_or_null<IntegerAttr>(adaptor.getB());
+  if (intRhs && intRhs.getValue().getSExtValue() == 0)
+    return getA();
+  if (intLhs && intLhs.getValue().getSExtValue() == 0)
+    return getB();
   return atenBinaryIntOperatorFoldHelper(
       adaptor.getOperands(), [](int64_t a, int64_t b) { return a + b; });
 }
@@ -3709,6 +3715,9 @@ OpFoldResult AtenAddIntOp::fold(FoldAdaptor adaptor) {
 //===----------------------------------------------------------------------===//
 
 OpFoldResult AtenSubIntOp::fold(FoldAdaptor adaptor) {
+  if (getA() == getB())
+    return IntegerAttr::get(
+        IntegerType::get(getContext(), 64, IntegerType::Signless), 0);
   return atenBinaryIntOperatorFoldHelper(
       adaptor.getOperands(), [](int64_t a, int64_t b) { return a - b; });
 }
diff --git a/lib/Dialect/Torch/Transforms/ScalarizeShapes.cpp b/lib/Dialect/Torch/Transforms/ScalarizeShapes.cpp
index 0e88bd8d6322..345b5e156125 100644
--- a/lib/Dialect/Torch/Transforms/ScalarizeShapes.cpp
+++ b/lib/Dialect/Torch/Transforms/ScalarizeShapes.cpp
@@ -86,42 +86,62 @@ LogicalResult getListFromTensor(Value value, SmallVector<OpFoldResult> &vals) {
         getAsOpFoldResult(full.getFillValue()));
     return success();
   }
-  // TODO: Add a case for unsqueeze of a primnumtotensorscalarop?
+
+  if (auto unsqueeze = value.getDefiningOp<AtenUnsqueezeOp>()) {
+    Value usqSelf = unsqueeze.getSelf();
+    if (auto numToTensor =
+            usqSelf.getDefiningOp<Torch::PrimNumToTensorScalarOp>()) {
+      vals.push_back(getAsOpFoldResult(numToTensor.getA()));
+      return success();
+    }
+  }
+
+  // A common rank 0 tensor producer
+  if (auto numToTensor =
+          value.getDefiningOp<Torch::PrimNumToTensorScalarOp>()) {
+    vals.push_back(getAsOpFoldResult(numToTensor.getA()));
+    return success();
+  }
 
   // Last supported case: ValueTensorLiteralOp
   auto literalOp = value.getDefiningOp<Torch::ValueTensorLiteralOp>();
   if (!literalOp)
     return failure();
 
-  // Check the type. We make sure the type is not unsigned here before trying to
-  // materialize
+  // Check the type.
   auto ty = cast<ValueTensorType>(literalOp.getType());
   if (!ty.hasSizes() || ty.getSizes().size() > 1)
     return failure();
 
-  int64_t listSize = ty.getSizes().size() == 1 ? ty.getSizes().front() : 1;
+  // make sure the type is not unsigned here before trying to materialize
   auto intTy = dyn_cast_or_null<IntegerType>(ty.getDtype());
   if (!intTy || intTy.isUnsigned())
     return failure();
 
+  // if we have a rank 0 literal, we will be adding one element to the list
+  int64_t listSize = ty.getSizes().size() == 1 ? ty.getSizes().front() : 1;
+
+  if (listSize > kMaxFold)
+    return failure();
+
+  // check for a splat or dense attr
   auto splattr = dyn_cast_or_null<SplatElementsAttr>(literalOp.getValue());
   auto denseAttr = dyn_cast_or_null<DenseIntElementsAttr>(literalOp.getValue());
 
   if (!splattr && !denseAttr)
     return failure();
 
+  // These are not mutually exclusive, so try splat first.
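+  // (For example, a constant such as dense<3> : tensor<4xsi64> parses to a
+  // SplatElementsAttr, and that same attribute also satisfies the
+  // DenseIntElementsAttr cast above, so both handles can be non-null.)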
   if (splattr) {
     auto attr = splattr.getSplatValue<Attribute>();
     vals.resize((int64_t)vals.size() + listSize, attr);
+    return success();
   }
 
-  if (denseAttr && !splattr) {
-    for (auto e : denseAttr.getValues<Attribute>())
-      vals.push_back(e);
-  }
-
-  if ((int64_t)vals.size() != listSize)
+  // remaining case: denseAttr
+  if ((int64_t)denseAttr.getValues<Attribute>().size() != listSize)
     return failure();
-
+  for (auto e : denseAttr.getValues<Attribute>())
+    vals.push_back(e);
   return success();
 }
 
@@ -143,6 +163,45 @@ Value constructAtenTensorOpFromList(ImplicitLocOpBuilder b, mlir::Type resultTy,
 // [scalarOpsA -> ListA -> TensorA] -> OpB. Then OpB will be able to
 // getListFromTensor(A), and further propagate scalarization.
 
+namespace {
+class PropagateAtenBroadcastToPattern
+    : public OpRewritePattern<AtenBroadcastToOp> {
+public:
+  using OpRewritePattern<AtenBroadcastToOp>::OpRewritePattern;
+  LogicalResult matchAndRewrite(AtenBroadcastToOp op,
+                                PatternRewriter &rewriter) const override {
+    constexpr int64_t kMaxFold = 16;
+    // for tensor<si64>, or tensor<1xsi64>, broadcasted to tensor<nxsi64>, grab
+    // the element and convert to a full op.
+    auto ty = cast<ValueTensorType>(op.getType());
+    if (!ty.areAllSizesKnown() || ty.getSizes().size() != 1)
+      return failure();
+
+    if (ty.getSizes()[0] > kMaxFold)
+      return failure();
+
+    SmallVector<OpFoldResult> fillFold;
+    if (failed(getListFromTensor(op.getSelf(), fillFold)) ||
+        fillFold.size() != 1)
+      return failure();
+    ImplicitLocOpBuilder b(op.getLoc(), rewriter);
+    SmallVector<Value> fillVals;
+    if (failed(materializeFolds(b, fillFold, fillVals)))
+      return failure();
+
+    Value size = b.create<Torch::ConstantIntOp>(ty.getSizes().front());
+    Value sizeList = b.create<Torch::PrimListConstructOp>(
+        rewriter.getType<Torch::ListType>(rewriter.getType<Torch::IntType>()),
+        size);
+    Value none = b.create<Torch::ConstantNoneOp>();
+    Value cstFalse = b.create<Torch::ConstantBoolOp>(false);
+    rewriter.replaceOpWithNewOp<AtenFullOp>(op, ty, sizeList, fillVals.front(),
+                                            none, none, none, cstFalse);
+    return success();
+  }
+};
+} // namespace
+
 namespace {
 class PropagateAtenShapeToTensorPattern
     : public OpRewritePattern<Aten_ShapeAsTensorOp> {
@@ -541,9 +600,128 @@ class PropagateAtenItemPattern : public OpRewritePattern<AtenItemOp> {
 };
 } // namespace
 
+namespace {
+
+template <typename OpTy> struct ArithmeticHelper {
+  static LogicalResult getAlphaAndVerify(OpTy &op, int64_t &alpha) {
+    alpha = 1;
+    return success();
+  }
+};
+
+template <> struct ArithmeticHelper<AtenAddTensorOp> {
+  static LogicalResult getAlphaAndVerify(AtenAddTensorOp &op, int64_t &alpha) {
+    if (!matchPattern(op.getAlpha(), m_TorchConstantInt(&alpha)) || alpha != 1)
+      return failure();
+    return success();
+  }
+};
+
+template <> struct ArithmeticHelper<AtenSubTensorOp> {
+  static LogicalResult getAlphaAndVerify(AtenSubTensorOp &op, int64_t &alpha) {
+    if (!matchPattern(op.getAlpha(), m_TorchConstantInt(&alpha)) || alpha != 1)
+      return failure();
+    return success();
+  }
+};
+
+template <typename OpTy, typename ScalarOpTy>
+class PropagateAtenArithmeticPattern : public OpRewritePattern<OpTy> {
+public:
+  using OpRewritePattern<OpTy>::OpRewritePattern;
+  LogicalResult matchAndRewrite(OpTy op,
+                                PatternRewriter &rewriter) const override {
+    // Check type
+    auto resultTy = cast<ValueTensorType>(op.getType());
+    if (resultTy.getSizes().size() > 1)
+      return rewriter.notifyMatchFailure(op, "unsupported: rank > 1");
+    if (!resultTy.hasDtype() || !isa<mlir::IntegerType>(resultTy.getDtype()))
+      return rewriter.notifyMatchFailure(op, "not an int type");
+
+    int64_t alpha;
+    if (failed(ArithmeticHelper<OpTy>::getAlphaAndVerify(op, alpha)))
+      return rewriter.notifyMatchFailure(op, "alpha must be 1");
+
+    ImplicitLocOpBuilder b(op.getLoc(), rewriter);
+    SmallVector<OpFoldResult> selfFold, otherFold;
+    if (failed(getListFromTensor(op.getSelf(), selfFold)) ||
+        failed(getListFromTensor(op.getOther(), otherFold)) ||
+        selfFold.size() != otherFold.size())
+      return failure();
+    SmallVector<Value> selfVals, otherVals;
+    if (failed(materializeFolds(b, selfFold, selfVals)) ||
+        failed(materializeFolds(b, otherFold, otherVals)))
+      return failure();
+    SmallVector<OpFoldResult> resultFolds;
+    for (uint64_t i = 0; i < selfVals.size(); i++) {
+      resultFolds.push_back(b.createOrFold<ScalarOpTy>(
+          selfVals[i].getType(), selfVals[i], otherVals[i]));
+    }
+    SmallVector<Value> resultVals;
+    if (failed(materializeFolds(b, resultFolds, resultVals)))
+      return failure();
+
+    if (resultTy.getSizes().size() == 0) {
+      rewriter.replaceOpWithNewOp<Torch::PrimNumToTensorScalarOp>(
+          op, resultTy, resultVals.front());
+      return success();
+    }
+
+    Value result = constructAtenTensorOpFromList(b, resultTy, resultVals);
+    rewriter.replaceOp(op, result);
+    return success();
+  }
+};
+} // namespace
+
 /// ------ Fold Patterns ------ ///
 // These are shape-specific folding patterns
 
+namespace {
+class FoldAtenEqIntPattern : public OpRewritePattern<AtenEqIntOp> {
+public:
+  using OpRewritePattern<AtenEqIntOp>::OpRewritePattern;
+  LogicalResult matchAndRewrite(AtenEqIntOp op,
+                                PatternRewriter &rewriter) const override {
+    // Replaces (size.int == 0) with false and adds an assert. These
+    // comparisons get generated because onnx.Reshape considers 0 to mean
+    // "don't change this dim". However, if the size we are passing to
+    // onnx.Reshape is a tensor dim, it is never supposed to be interpreted
+    // as "don't change this dim".
+    int64_t otherInt;
+    if (!matchPattern(op.getB(), m_TorchConstantInt(&otherInt)) ||
+        otherInt != 0)
+      return failure();
+
+    // in case the shape is a product of two ints, check each
+    if (auto mulOp = op.getA().getDefiningOp<AtenMulIntOp>()) {
+      Value self = mulOp.getA();
+      Value other = mulOp.getB();
+      Value selfEq =
+          rewriter.create<AtenEqIntOp>(op.getLoc(), self, op.getB());
+      Value otherEq =
+          rewriter.create<AtenEqIntOp>(op.getLoc(), other, op.getB());
+      rewriter.replaceOpWithNewOp<Aten__Or__BoolOp>(op, selfEq, otherEq);
+      return success();
+    }
+
+    // if lhs is a size.int op, assert size > 0 and replace with false.
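+    // e.g. `%b = torch.aten.eq.int %size, %int0` (with %size produced by
+    // torch.aten.size.int) becomes a runtime assert that %size > 0, and %b
+    // itself is replaced by a constant false.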
+    if (auto sizeOp = op.getA().getDefiningOp<AtenSizeIntOp>()) {
+      Value selfGtOther = rewriter.create<AtenGtIntOp>(
+          op.getLoc(), op.getType(), op.getA(), op.getB());
+      rewriter.create<Torch::RuntimeAssertOp>(
+          op.getLoc(), selfGtOther,
+          rewriter.getStringAttr("Expected dim size > 0."));
+      Value cstFalse =
+          rewriter.create<Torch::ConstantBoolOp>(op.getLoc(), false);
+      rewriter.replaceOp(op, cstFalse);
+      return success();
+    }
+
+    return failure();
+  }
+};
+} // namespace
+
 namespace {
 class FoldAtenTensorSplatPattern : public OpRewritePattern<AtenTensorOp> {
 public:
@@ -594,16 +772,24 @@ class FoldAtenTensorSplatPattern : public OpRewritePattern<AtenTensorOp> {
 } // namespace
 
 namespace {
-class FoldAtenSqueezePattern : public OpRewritePattern<AtenSqueezeOp> {
+template <typename SqueezeOp>
+class FoldAtenSqueezePattern : public OpRewritePattern<SqueezeOp> {
 public:
-  using OpRewritePattern<AtenSqueezeOp>::OpRewritePattern;
-  LogicalResult matchAndRewrite(AtenSqueezeOp op,
+  using OpRewritePattern<SqueezeOp>::OpRewritePattern;
+  LogicalResult matchAndRewrite(SqueezeOp op,
                                 PatternRewriter &rewriter) const override {
     auto resultTy = cast<ValueTensorType>(op.getType());
     if (!resultTy.hasSizes() || !resultTy.areAllSizesKnown())
      return rewriter.notifyMatchFailure(op, "Unknown result shape");
 
-    if (auto atenFull = op.getSelf().getDefiningOp<AtenFullOp>()) {
+    Value self = op.getSelf();
+    if (auto atenFull = self.getDefiningOp<AtenFullOp>()) {
+      // in the rank 0 case, just return the rank 0 scalar
+      if (resultTy.getSizes().size() == 0) {
+        rewriter.replaceOpWithNewOp<Torch::PrimNumToTensorScalarOp>(
+            op, resultTy, atenFull.getFillValue());
+        return success();
+      }
       SmallVector<Value> sizes;
       for (int i = 0, s = resultTy.getSizes().size(); i < s; ++i)
         sizes.push_back(rewriter.create<Torch::ConstantIntOp>(
@@ -874,9 +1060,16 @@ bool isPrimListOfInts(Operation *op) {
   return llvm::isa<Torch::IntType>(listType.getContainedType());
 }
 
+bool isAnchorOp(Operation *op) {
+  return isa<Torch::RuntimeAssertOp>(op) || isa<AtenArangeStartStepOp>(op) ||
+         isPrimListOfInts(op);
+}
+
 void populateScalarizationFoldPatterns(RewritePatternSet &patterns) {
-  patterns.insert<FoldAtenSqueezePattern, FoldAtenUnsqueezePattern,
-                  FoldAtenWhereSelf, FoldAtenTensorSplatPattern>(
+  patterns.insert<FoldAtenSqueezePattern<AtenSqueezeOp>,
+                  FoldAtenSqueezePattern<AtenSqueezeDimOp>,
+                  FoldAtenUnsqueezePattern, FoldAtenWhereSelf,
+                  FoldAtenTensorSplatPattern, FoldAtenEqIntPattern>(
       patterns.getContext());
 }
 
@@ -885,10 +1078,21 @@ void populateScalarizationCanonicalizePatterns(RewritePatternSet &patterns) {
 }
 
 void populateScalarizationPropagationPatterns(RewritePatternSet &patterns) {
-  patterns.insert<PropagateAtenCatPattern, PropagateAtenIndexSelectPattern,
-                  PropagateAtenItemPattern, PropagateAtenShapeToTensorPattern,
-                  PropagateAtenSliceTensorPattern, PropagateAtenEqTensorPattern,
-                  PropagateAtenWhereSelfPattern>(patterns.getContext());
+  // A note on division: onnx.Div on (int, int) -> int rounds towards zero.
+  // The torch DivTensorOp actually doesn't allow returning an int dtype, but
+  // this was artificially plumbed through. Unfortunately, there is no scalar
+  // trunc div op in torch; however, we can safely assume all operands are
+  // positive, so floor divide should be a sufficient scalar replacement.
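+  // For instance, with positive operands trunc(7/2) and floor(7/2) both give
+  // 3; the two only diverge for negative quotients (e.g. -7/2 -> -3 vs -4).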
+  patterns.insert<
+      PropagateAtenCatPattern, PropagateAtenIndexSelectPattern,
+      PropagateAtenItemPattern, PropagateAtenShapeToTensorPattern,
+      PropagateAtenSliceTensorPattern, PropagateAtenEqTensorPattern,
+      PropagateAtenWhereSelfPattern, PropagateAtenBroadcastToPattern,
+      PropagateAtenArithmeticPattern<AtenAddTensorOp, AtenAddIntOp>,
+      PropagateAtenArithmeticPattern<AtenSubTensorOp, AtenSubIntOp>,
+      PropagateAtenArithmeticPattern<AtenMulTensorOp, AtenMulIntOp>,
+      PropagateAtenArithmeticPattern<AtenDivTensorOp, AtenFloordivIntOp>>(
+      patterns.getContext());
 }
 
 void populateScalarizationRemovePatterns(RewritePatternSet &patterns) {
@@ -940,7 +1144,7 @@ class ScalarizeShapesPass : public ScalarizeShapesBase<ScalarizeShapesPass> {
           [&](Operation *op) {
             // Walking bottom-up, start adding ops when we reach an anchor point
             // (a prim list of ints)
-            if (isPrimListOfInts(op)) {
+            if (isAnchorOp(op)) {
              shapeCalculationOps.insert(op);
              return;
            }
diff --git a/test/Dialect/Torch/scalarize-shapes.mlir b/test/Dialect/Torch/scalarize-shapes.mlir
index 7f6aa8a26ebb..166e2fda564e 100644
--- a/test/Dialect/Torch/scalarize-shapes.mlir
+++ b/test/Dialect/Torch/scalarize-shapes.mlir
@@ -75,6 +75,99 @@ func.func @literal_item() -> !torch.int {
   return %out : !torch.int
 }
 
+// -----
+
+// CHECK-LABEL: @arith_prop
+func.func @arith_prop(%arg0 : !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
+  // CHECK: %[[float0:.*]] = torch.constant.float 0.000000e+00
+  // CHECK: %[[int0:.*]] = torch.constant.int 0
+  // CHECK: %[[x0:.*]] = torch.aten.size.int %arg0, %[[int0]] : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.int
+  // CHECK: %[[int1:.*]] = torch.constant.int 1
+  // CHECK: %[[x1:.*]] = torch.aten.size.int %arg0, %[[int1]] : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.int
+  // CHECK: %[[int12:.*]] = torch.constant.int 12
+  // CHECK: %[[int1_0:.*]] = torch.constant.int 1
+  // CHECK: %[[x2:.*]] = torch.aten.floordiv.int %[[x0]], %[[int12]] : !torch.int, !torch.int -> !torch.int
+  // CHECK: %[[x3:.*]] = torch.aten.floordiv.int %[[x1]], %[[int1_0]] : !torch.int, !torch.int -> !torch.int
+  // CHECK: %[[int12_1:.*]] = torch.constant.int 12
+  // CHECK: %[[int1_2:.*]] = torch.constant.int 1
+  // CHECK: %[[x4:.*]] = torch.aten.mul.int %[[x2]], %[[int12_1]] : !torch.int, !torch.int -> !torch.int
+  // CHECK: %[[x5:.*]] = torch.aten.mul.int %[[x3]], %[[int1_2]] : !torch.int, !torch.int -> !torch.int
+  // CHECK: %[[x6:.*]] = torch.aten.sub.int %[[x0]], %[[x4]] : !torch.int, !torch.int -> !torch.int
+  // CHECK: %[[x7:.*]] = torch.aten.sub.int %[[x1]], %[[x5]] : !torch.int, !torch.int -> !torch.int
+  // CHECK: %[[x8:.*]] = torch.prim.ListConstruct %[[x7]], %[[x6]] : (!torch.int, !torch.int) -> !torch.list<int>
+  // CHECK: %[[x9:.*]] = torch.aten.constant_pad_nd %arg0, %[[x8]], %[[float0]] : !torch.vtensor<[?,?],f32>, !torch.list<int>, !torch.float -> !torch.vtensor<[?,?],f32>
+  // CHECK: return %[[x9]] : !torch.vtensor<[?,?],f32>
+  %0 = torch.vtensor.literal(dense<1> : tensor<si64>) : !torch.vtensor<[],si64>
+  %1 = torch.vtensor.literal(dense<0> : tensor<si64>) : !torch.vtensor<[],si64>
+  %float0.000000e00 = torch.constant.float 0.000000e+00
+  %int1 = torch.constant.int 1
+  %2 = torch.vtensor.literal(dense<[12, 1]> : tensor<2xsi64>) : !torch.vtensor<[2],si64>
+  %int0 = torch.constant.int 0
+  %3 = torch.aten._shape_as_tensor %arg0 : !torch.vtensor<[?,?],f32> -> !torch.vtensor<[2],si64>
+  %4 = torch.aten.div.Tensor %3, %2 : !torch.vtensor<[2],si64>, !torch.vtensor<[2],si64> -> !torch.vtensor<[2],si64>
+  %5 = torch.aten.mul.Tensor %4, %2 : !torch.vtensor<[2],si64>, !torch.vtensor<[2],si64> -> !torch.vtensor<[2],si64>
+  %6 = torch.aten.sub.Tensor %3, %5, %int1 : !torch.vtensor<[2],si64>, !torch.vtensor<[2],si64>, !torch.int -> !torch.vtensor<[2],si64>
+  %7 = torch.aten.index_select %6, %int0, %1 : !torch.vtensor<[2],si64>, !torch.int, !torch.vtensor<[],si64> -> !torch.vtensor<[],si64>
+  %8 = torch.aten.index_select %6, %int0, %0 : !torch.vtensor<[2],si64>, !torch.int, !torch.vtensor<[],si64> -> !torch.vtensor<[],si64>
+  %9 = torch.aten.item %7 : !torch.vtensor<[],si64> -> !torch.int
+  %10 = torch.aten.item %8 : !torch.vtensor<[],si64> -> !torch.int
+  %11 = torch.prim.ListConstruct %10, %9 : (!torch.int, !torch.int) -> !torch.list<int>
+  %12 = torch.aten.constant_pad_nd %arg0, %11, %float0.000000e00 : !torch.vtensor<[?,?],f32>, !torch.list<int>, !torch.float -> !torch.vtensor<[?,?],f32>
+  return %12 : !torch.vtensor<[?,?],f32>
+}
+
+// -----
+
+// CHECK-LABEL: @broadcast_prop
+func.func @broadcast_prop(%arg0 : !torch.vtensor<[?,?],f32>) -> !torch.int {
+  // CHECK: %[[I0:.*]] = torch.constant.int 0
+  // CHECK: %[[SZE:.*]] = torch.aten.size.int %arg0, %[[I0]] : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.int
+  // CHECK: return %[[SZE]] : !torch.int
+  %dim = torch.constant.int 0
+  %size = torch.aten.size.int %arg0, %dim : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.int
+  %shape = torch.prim.NumToTensor.Scalar %size : !torch.int -> !torch.vtensor<[],si32>
+  %int3 = torch.constant.int 3
+  %idx = torch.vtensor.literal(dense<-1> : tensor<si32>) : !torch.vtensor<[],si32>
+  %bcastlist = torch.prim.ListConstruct %int3 : (!torch.int) -> !torch.list<int>
+  %bcast = torch.aten.broadcast_to %shape, %bcastlist : !torch.vtensor<[],si32>, !torch.list<int> -> !torch.vtensor<[3],si32>
+  %select = torch.aten.index_select %bcast, %dim, %idx : !torch.vtensor<[3],si32>, !torch.int, !torch.vtensor<[],si32> -> !torch.vtensor<[],si32>
+  %out = torch.aten.item %select : !torch.vtensor<[],si32> -> !torch.int
+  %list = torch.prim.ListConstruct %out : (!torch.int) -> !torch.list<int>
+  return %out : !torch.int
+}
+
+// -----
+
+// CHECK-LABEL: @eq_int_fold
+func.func @eq_int_fold(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,1],f32> {
+  // CHECK: %[[int1:.*]] = torch.constant.int 1
+  // CHECK: %[[int0:.*]] = torch.constant.int 0
+  // CHECK: %[[sze0:.*]] = torch.aten.size.int %arg0, %[[int0]] : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.int
+  // CHECK: %[[sze1:.*]] = torch.aten.size.int %arg0, %[[int1]] : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.int
+  // CHECK: %[[mul:.*]] = torch.aten.mul.int %[[sze0]], %[[sze1]] : !torch.int, !torch.int -> !torch.int
+  // CHECK: %[[gt0:.*]] = torch.aten.gt.int %[[sze0]], %[[int0]] : !torch.int, !torch.int -> !torch.bool
+  // CHECK: torch.runtime.assert %[[gt0]], "Expected dim size > 0."
+  // CHECK: %[[gt1:.*]] = torch.aten.gt.int %[[sze1]], %[[int0]] : !torch.int, !torch.int -> !torch.bool
+  // CHECK: torch.runtime.assert %[[gt1]], "Expected dim size > 0."
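+  // (the eq.int fold split (sze0 * sze1 == 0) into per-operand compares,
+  // each becoming a runtime assert that the size is > 0 plus a constant
+  // false, so the where.self below resolves to the mul result)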
+  // CHECK: %[[list:.*]] = torch.prim.ListConstruct %[[mul]], %[[int1]] : (!torch.int, !torch.int) -> !torch.list<int>
+  // CHECK: %[[view:.*]] = torch.aten.view %arg0, %[[list]] : !torch.vtensor<[?,?],f32>, !torch.list<int> -> !torch.vtensor<[?,1],f32>
+  // CHECK: return %[[view]] : !torch.vtensor<[?,1],f32>
+  %int1 = torch.constant.int 1
+  %int0 = torch.constant.int 0
+  %0 = torch.aten.size.int %arg0, %int0 : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.int
+  %1 = torch.aten.size.int %arg0, %int1 : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.int
+  %2 = torch.aten.mul.int %0, %1 : !torch.int, !torch.int -> !torch.int
+  %3 = torch.aten.eq.int %2, %int0 : !torch.int, !torch.int -> !torch.bool
+  %4 = torch.aten.Int.bool %3 : !torch.bool -> !torch.int
+  %5 = torch.prim.NumToTensor.Scalar %4 : !torch.int -> !torch.vtensor<[],i1>
+  %6 = torch.prim.NumToTensor.Scalar %0 : !torch.int -> !torch.vtensor<[],si64>
+  %7 = torch.prim.NumToTensor.Scalar %2 : !torch.int -> !torch.vtensor<[],si64>
+  %8 = torch.aten.where.self %5, %6, %7 : !torch.vtensor<[],i1>, !torch.vtensor<[],si64>, !torch.vtensor<[],si64> -> !torch.vtensor<[],si64>
+  %9 = torch.aten.item %8 : !torch.vtensor<[],si64> -> !torch.int
+  %10 = torch.prim.ListConstruct %9, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
+  %11 = torch.aten.view %arg0, %10 : !torch.vtensor<[?,?],f32>, !torch.list<int> -> !torch.vtensor<[?,1],f32>
+  return %11 : !torch.vtensor<[?,1],f32>
+}
 
 // -----
 
diff --git a/test/Dialect/Torch/torch-onnx-to-torch-backend-pipeline.mlir b/test/Dialect/Torch/torch-onnx-to-torch-backend-pipeline.mlir
index 038f5686d6a4..752398474ce7 100644
--- a/test/Dialect/Torch/torch-onnx-to-torch-backend-pipeline.mlir
+++ b/test/Dialect/Torch/torch-onnx-to-torch-backend-pipeline.mlir
@@ -36,8 +36,8 @@ func.func @test_triu_decompose(%arg0: !torch.vtensor<[4,5],si64>) -> !torch.vten
 module {
   // CHECK-LABEL: func.func @test_scalarize
   func.func @test_scalarize(%arg0: !torch.vtensor<[?,?,16,64],f32>) -> !torch.vtensor<[?,?,?],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 21 : si64, torch.onnx_meta.producer_name = "pytorch", torch.onnx_meta.producer_version = "1.11.0"} {
-    // CHECK: %[[INT2:.+]] = torch.constant.int 2
-    // CHECK: %[[INT3:.+]] = torch.constant.int 3
+    // CHECK-DAG: %[[INT2:.+]] = torch.constant.int 2
+    // CHECK-DAG: %[[INT3:.+]] = torch.constant.int 3
     // CHECK: %[[ADD:.+]] = torch.aten.flatten.using_ints %arg0, %[[INT2]], %[[INT3]] : !torch.vtensor<[?,?,16,64],f32>, !torch.int, !torch.int -> !torch.vtensor<[?,?,1024],f32>
     %0 = torch.operator "onnx.Shape"(%arg0) : (!torch.vtensor<[?,?,16,64],f32>) -> !torch.vtensor<[4],si64>
     %1 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<__21> : tensor<si64>} : () -> !torch.vtensor<[],si64>