diff --git a/build_tools/ci/install_python_deps.sh b/build_tools/ci/install_python_deps.sh
index 6b49689ce8ea..7acd900ec9ca 100755
--- a/build_tools/ci/install_python_deps.sh
+++ b/build_tools/ci/install_python_deps.sh
@@ -19,7 +19,7 @@ case $torch_version in
     ;;
   stable)
     echo "::group::installing stable torch"
-    python3 -m pip install --no-cache-dir torch torchvision --index-url https://download.pytorch.org/whl/cpu
+    python3 -m pip install --no-cache-dir -r $repo_root/stable-requirements.txt
     python3 -m pip install --no-cache-dir -r $repo_root/build-requirements.txt
     echo "::endgroup::"
     ;;
diff --git a/lib/Conversion/TorchToLinalg/Linear.cpp b/lib/Conversion/TorchToLinalg/Linear.cpp
index 1db603cc5aa0..e6d9fce11a06 100644
--- a/lib/Conversion/TorchToLinalg/Linear.cpp
+++ b/lib/Conversion/TorchToLinalg/Linear.cpp
@@ -127,8 +127,14 @@ class ConvertAtenMmOp : public OpConversionPattern<AtenMmOp> {
           "mismatching contracting dimension for torch.aten.mm"));
     }
 
+    auto resultTy = op.getType().cast<ValueTensorType>();
+    auto resultDTy = resultTy.toBuiltinTensor().getElementType();
     Type newResultType = getTypeConverter()->convertType(op.getType());
     Type elementType = newResultType.cast<TensorType>().getElementType();
+    auto accumulatorDType = getDefaultAccType(rewriter, resultDTy);
+    if (accumulatorDType != resultDTy) {
+      elementType = accumulatorDType;
+    }
     Value zeroFill = createZeroInitTensor(
         rewriter, loc, ValueRange{lhsDim0, rhsDim1}, elementType);
 
@@ -163,6 +169,13 @@ class ConvertAtenMmOp : public OpConversionPattern<AtenMmOp> {
                                             ValueRange{lhs, rhs}, zeroFill)
                   .getResult(0);
     }
+
+    if (accumulatorDType != resultDTy) {
+      Type resultElementType =
+          newResultType.cast<RankedTensorType>().getElementType();
+      matmul = torch_to_linalg::convertTensorToElementType(
+          rewriter, loc, matmul, resultElementType);
+    }
     // When constructed with just dynamic sizes, EmptyOp will have a result
     // type which has all `?`'s for dimensions, which might not be the result
     // type of `op`. The constraints on later linalg ops means that the result
@@ -875,18 +888,22 @@ class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> {
           castIndexToInt(weightDims[i]), strideIntValues[i]));
     }
 
+    Type accumulatorDType = getDefaultAccType(rewriter, resultDTy);
     Value initTensor = rewriter.create<tensor::EmptyOp>(
-        loc, getAsOpFoldResult(outDims), resultDTy);
+        loc, getAsOpFoldResult(outDims), accumulatorDType);
 
     Value outputTensor;
+    if (accumulatorDType != resultDTy)
+      bias = torch_to_linalg::convertTensorToElementType(rewriter, loc, bias,
+                                                         accumulatorDType);
     if (bias.getType().isa<Torch::NoneType>()) {
       Value c0;
-      if (resultDTy.isa<mlir::FloatType>()) {
-        c0 = rewriter.create<arith::ConstantOp>(loc,
-                                                FloatAttr::get(resultDTy, 0.0));
-      } else if (resultDTy.isa<mlir::IntegerType>()) {
-        c0 = rewriter.create<arith::ConstantOp>(loc,
-                                                IntegerAttr::get(resultDTy, 0));
+      if (accumulatorDType.isa<mlir::FloatType>()) {
+        c0 = rewriter.create<arith::ConstantOp>(
+            loc, FloatAttr::get(accumulatorDType, 0.0));
+      } else if (accumulatorDType.isa<mlir::IntegerType>()) {
+        c0 = rewriter.create<arith::ConstantOp>(
+            loc, IntegerAttr::get(accumulatorDType, 0));
       }
       outputTensor =
           rewriter.create<linalg::FillOp>(loc, c0, initTensor).getResult(0);
@@ -973,6 +990,12 @@ class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> {
             op, "unimplemented: only 1D, 2D, and 3D convolution supported");
       };
       Type newResultType = getTypeConverter()->convertType(op.getType());
+      if (accumulatorDType != resultDTy) {
+        Type resultElementType =
+            newResultType.cast<RankedTensorType>().getElementType();
+        conv = torch_to_linalg::convertTensorToElementType(rewriter, loc, conv,
+                                                           resultElementType);
+      }
       rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, conv);
       return success();
     }
@@ -1027,6 +1050,12 @@ class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> {
       conv = transposeValue(op.getLoc(), conv, outPerms, rewriter);
 
       Type newResultType = getTypeConverter()->convertType(op.getType());
+      if (accumulatorDType != resultDTy) {
+        Type resultElementType =
+            newResultType.cast<RankedTensorType>().getElementType();
+        conv = torch_to_linalg::convertTensorToElementType(rewriter, loc, conv,
+                                                           resultElementType);
+      }
       rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, conv);
       return success();
     }
@@ -1065,6 +1094,12 @@ class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> {
                  .getResult(0);
 
       Type newResultType = getTypeConverter()->convertType(op.getType());
+      if (accumulatorDType != resultDTy) {
+        Type resultElementType =
+            newResultType.cast<RankedTensorType>().getElementType();
+        conv = torch_to_linalg::convertTensorToElementType(rewriter, loc, conv,
+                                                           resultElementType);
+      }
       rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, conv);
       return success();
     }
@@ -1137,6 +1172,12 @@ class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> {
           loc, outputTensor.getType(), conv,
           expandOutputTensor.getReassociationIndices());
       Type newResultType = getTypeConverter()->convertType(op.getType());
+      if (accumulatorDType != resultDTy) {
+        Type resultElementType =
+            newResultType.cast<RankedTensorType>().getElementType();
+        conv = torch_to_linalg::convertTensorToElementType(rewriter, loc, conv,
+                                                           resultElementType);
+      }
       rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, conv);
       return success();
     }
diff --git a/lib/Dialect/Torch/Utils/Utils.cpp b/lib/Dialect/Torch/Utils/Utils.cpp
index 2993a2697360..c6b23a2f4862 100644
--- a/lib/Dialect/Torch/Utils/Utils.cpp
+++ b/lib/Dialect/Torch/Utils/Utils.cpp
@@ -590,5 +590,5 @@ Type Torch::getDefaultAccType(PatternRewriter &rewriter, Type inputType) {
     return rewriter.getI64Type();
   if (inputType.isSignedInteger(64))
     return rewriter.getI64Type();
-  llvm::report_fatal_error("unhandled type for getDefaultAccType");
+  return inputType;
 }
diff --git a/python/torch_mlir/fx.py b/python/torch_mlir/fx.py
index 3622efafd9d2..c1e3ac8fc4e1 100644
--- a/python/torch_mlir/fx.py
+++ b/python/torch_mlir/fx.py
@@ -3,7 +3,7 @@
 # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 # Also available under a BSD-style license. See LICENSE.
 
-from typing import Optional
+from typing import Optional, Union, Dict, Tuple, Any
 
 import warnings
 
@@ -20,6 +20,7 @@ def export_and_import(
     f,
     *args,
     fx_importer: Optional[FxImporter] = None,
+    dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
     experimental_support_mutation: bool = False,
     hooks: Optional[FxImporterHooks] = None,
     func_name: str = "main",
@@ -30,7 +31,7 @@ def export_and_import(
 
     if fx_importer is None:
         fx_importer = FxImporter(context=context, hooks=hooks)
-    prog = torch.export.export(f, args, kwargs)
+    prog = torch.export.export(f, args, kwargs, dynamic_shapes=dynamic_shapes)
     decomp_table = get_decomposition_table()
     prog = prog.run_decompositions(decomp_table)
     if experimental_support_mutation:
diff --git a/test/python/fx_importer/basic_test.py b/test/python/fx_importer/basic_test.py
index a51032273999..4d001290c179 100644
--- a/test/python/fx_importer/basic_test.py
+++ b/test/python/fx_importer/basic_test.py
@@ -10,8 +10,8 @@ from typing import Optional
 
 import torch
-import torch.export
 import torch.nn as nn
+from torch.export import Dim
 
 from torch_mlir import fx
 
@@ -79,3 +79,19 @@ def forward(self, x):
 
     m = fx.export_and_import(Basic(), torch.randn(3, 4), func_name="test_net")
     print(m)
+
+@run
+# CHECK-LABEL: test_import_frozen_exported_program_with_dynamic_shapes
+# CHECK: func.func @test_net(%[[ARG0:[a-zA-Z0-9]+]]: !torch.vtensor<[?,4],f32>) -> !torch.vtensor<[?,4],f32>
+def test_import_frozen_exported_program_with_dynamic_shapes():
+    class Basic(nn.Module):
+        def __init__(self):
+            super().__init__()
+
+        def forward(self, x):
+            return torch.tanh(x)
+
+    batch = Dim("batch")
+    dynamic_shapes = {"x": {0: batch}}
+    m = fx.export_and_import(Basic(), torch.randn(3, 4), dynamic_shapes=dynamic_shapes, func_name="test_net")
+    print(m)
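
As a usage sketch of the new `dynamic_shapes` argument added above (not part of the patch; the `Scale` module, tensor shapes, and `func_name` below are made up for illustration), the keyword is forwarded unchanged to `torch.export.export`, so it accepts the same spec format:

    import torch
    import torch.nn as nn
    from torch.export import Dim

    from torch_mlir import fx


    class Scale(nn.Module):
        def forward(self, x):
            return x * 2.0


    # Mark dim 0 of `x` as dynamic; the imported MLIR should then use
    # !torch.vtensor<[?,8],f32> for the argument instead of a fixed size.
    batch = Dim("batch")
    module = fx.export_and_import(
        Scale(),
        torch.randn(3, 8),
        dynamic_shapes={"x": {0: batch}},
        func_name="scale",
    )
    print(module)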