Merge upstream up to 798bfd7
cferry-AMD authored Jul 24, 2024
2 parents 0921bb9 + 9a232db commit 7201a25
Showing 5 changed files with 70 additions and 12 deletions.
build_tools/ci/install_python_deps.sh (2 changes: 1 addition & 1 deletion)
@@ -19,7 +19,7 @@ case $torch_version in
     ;;
   stable)
     echo "::group::installing stable torch"
-    python3 -m pip install --no-cache-dir torch torchvision --index-url https://download.pytorch.org/whl/cpu
+    python3 -m pip install --no-cache-dir -r $repo_root/stable-requirements.txt
     python3 -m pip install --no-cache-dir -r $repo_root/build-requirements.txt
     echo "::endgroup::"
     ;;
lib/Conversion/TorchToLinalg/Linear.cpp (55 changes: 48 additions & 7 deletions)
@@ -127,8 +127,14 @@ class ConvertAtenMmOp : public OpConversionPattern<AtenMmOp> {
           "mismatching contracting dimension for torch.aten.mm"));
     }

+    auto resultTy = op.getType().cast<ValueTensorType>();
+    auto resultDTy = resultTy.toBuiltinTensor().getElementType();
     Type newResultType = getTypeConverter()->convertType(op.getType());
     Type elementType = newResultType.cast<TensorType>().getElementType();
+    auto accumulatorDType = getDefaultAccType(rewriter, resultDTy);
+    if (accumulatorDType != resultDTy) {
+      elementType = accumulatorDType;
+    }
     Value zeroFill = createZeroInitTensor(
         rewriter, loc, ValueRange{lhsDim0, rhsDim1}, elementType);

@@ -163,6 +169,13 @@ class ConvertAtenMmOp : public OpConversionPattern<AtenMmOp> {
                      ValueRange{lhs, rhs}, zeroFill)
                   .getResult(0);
     }
+
+    if (accumulatorDType != resultDTy) {
+      Type resultElementType =
+          newResultType.cast<RankedTensorType>().getElementType();
+      matmul = torch_to_linalg::convertTensorToElementType(
+          rewriter, loc, matmul, resultElementType);
+    }
     // When constructed with just dynamic sizes, EmptyOp will have a result
     // type which has all `?`'s for dimensions, which might not be the result
     // type of `op`. The constraints on later linalg ops means that the result
@@ -875,18 +888,22 @@ class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> {
           castIndexToInt(weightDims[i]), strideIntValues[i]));
     }

+    Type accumulatorDType = getDefaultAccType(rewriter, resultDTy);
     Value initTensor = rewriter.create<tensor::EmptyOp>(
-        loc, getAsOpFoldResult(outDims), resultDTy);
+        loc, getAsOpFoldResult(outDims), accumulatorDType);

     Value outputTensor;
+    if (accumulatorDType != resultDTy)
+      bias = torch_to_linalg::convertTensorToElementType(rewriter, loc, bias,
+                                                         accumulatorDType);
     if (bias.getType().isa<Torch::NoneType>()) {
       Value c0;
-      if (resultDTy.isa<mlir::FloatType>()) {
-        c0 = rewriter.create<arith::ConstantOp>(loc,
-                                                FloatAttr::get(resultDTy, 0.0));
-      } else if (resultDTy.isa<mlir::IntegerType>()) {
-        c0 = rewriter.create<arith::ConstantOp>(loc,
-                                                IntegerAttr::get(resultDTy, 0));
+      if (accumulatorDType.isa<mlir::FloatType>()) {
+        c0 = rewriter.create<arith::ConstantOp>(
+            loc, FloatAttr::get(accumulatorDType, 0.0));
+      } else if (accumulatorDType.isa<mlir::IntegerType>()) {
+        c0 = rewriter.create<arith::ConstantOp>(
+            loc, IntegerAttr::get(accumulatorDType, 0));
       }
       outputTensor =
           rewriter.create<linalg::FillOp>(loc, c0, initTensor).getResult(0);
@@ -973,6 +990,12 @@ class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> {
           op, "unimplemented: only 1D, 2D, and 3D convolution supported");
     };
     Type newResultType = getTypeConverter()->convertType(op.getType());
+    if (accumulatorDType != resultDTy) {
+      Type resultElementType =
+          newResultType.cast<RankedTensorType>().getElementType();
+      conv = torch_to_linalg::convertTensorToElementType(rewriter, loc, conv,
+                                                         resultElementType);
+    }
     rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, conv);
     return success();
   }
@@ -1027,6 +1050,12 @@ class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> {
     conv = transposeValue(op.getLoc(), conv, outPerms, rewriter);

     Type newResultType = getTypeConverter()->convertType(op.getType());
+    if (accumulatorDType != resultDTy) {
+      Type resultElementType =
+          newResultType.cast<RankedTensorType>().getElementType();
+      conv = torch_to_linalg::convertTensorToElementType(rewriter, loc, conv,
+                                                         resultElementType);
+    }
     rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, conv);
     return success();
   }
@@ -1065,6 +1094,12 @@ class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> {
             .getResult(0);

     Type newResultType = getTypeConverter()->convertType(op.getType());
+    if (accumulatorDType != resultDTy) {
+      Type resultElementType =
+          newResultType.cast<RankedTensorType>().getElementType();
+      conv = torch_to_linalg::convertTensorToElementType(rewriter, loc, conv,
+                                                         resultElementType);
+    }
     rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, conv);
     return success();
   }
@@ -1137,6 +1172,12 @@ class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> {
         loc, outputTensor.getType(), conv,
         expandOutputTensor.getReassociationIndices());
     Type newResultType = getTypeConverter()->convertType(op.getType());
+    if (accumulatorDType != resultDTy) {
+      Type resultElementType =
+          newResultType.cast<RankedTensorType>().getElementType();
+      conv = torch_to_linalg::convertTensorToElementType(rewriter, loc, conv,
+                                                         resultElementType);
+    }
     rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, conv);
     return success();
   }
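
For context, a minimal sketch of the behavior the Linear.cpp change above targets, assuming a torch-mlir build with the FX importer available (the module, shapes, and func_name below are illustrative, not part of this commit): for a float16 matmul, getDefaultAccType picks a wider accumulator element type, the linalg op accumulates into an init tensor of that type, and convertTensorToElementType casts back to the f16 result type. Note that fx.export_and_import only produces the torch-dialect module; the conversion above applies when that module is lowered further through TorchToLinalg.

import torch
import torch.nn as nn

from torch_mlir import fx


class MatmulF16(nn.Module):
    def forward(self, a, b):
        return torch.mm(a, b)


# Import an f16 matmul. When this module is later lowered through
# TorchToLinalg, the pattern above accumulates in the wider default
# accumulator type and casts back to f16 for the final result.
m = fx.export_and_import(
    MatmulF16(),
    torch.randn(4, 8, dtype=torch.float16),
    torch.randn(8, 4, dtype=torch.float16),
    func_name="mm_f16",
)
print(m)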
lib/Dialect/Torch/Utils/Utils.cpp (2 changes: 1 addition & 1 deletion)
@@ -590,5 +590,5 @@ Type Torch::getDefaultAccType(PatternRewriter &rewriter, Type inputType) {
     return rewriter.getI64Type();
   if (inputType.isSignedInteger(64))
     return rewriter.getI64Type();
-  llvm::report_fatal_error("unhandled type for getDefaultAccType");
+  return inputType;
 }
python/torch_mlir/fx.py (5 changes: 3 additions & 2 deletions)
@@ -3,7 +3,7 @@
 # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 # Also available under a BSD-style license. See LICENSE.

-from typing import Optional
+from typing import Optional, Union, Dict, Tuple, Any

 import warnings

@@ -20,6 +20,7 @@ def export_and_import(
     f,
     *args,
     fx_importer: Optional[FxImporter] = None,
+    dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
     experimental_support_mutation: bool = False,
     hooks: Optional[FxImporterHooks] = None,
     func_name: str = "main",
@@ -30,7 +31,7 @@ def export_and_import(

     if fx_importer is None:
         fx_importer = FxImporter(context=context, hooks=hooks)
-    prog = torch.export.export(f, args, kwargs)
+    prog = torch.export.export(f, args, kwargs, dynamic_shapes=dynamic_shapes)
    decomp_table = get_decomposition_table()
    prog = prog.run_decompositions(decomp_table)
    if experimental_support_mutation:
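
The new dynamic_shapes argument is forwarded to torch.export.export unchanged. A hedged sketch of the equivalent direct call (the module and dimension name below are illustrative):

import torch
from torch.export import Dim


class Tanh(torch.nn.Module):
    def forward(self, x):
        return torch.tanh(x)


# Mark dimension 0 of the input "x" as dynamic; this is exactly the mapping
# that export_and_import now passes through to torch.export.export.
prog = torch.export.export(
    Tanh(), (torch.randn(3, 4),), dynamic_shapes={"x": {0: Dim("batch")}}
)
print(prog)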
test/python/fx_importer/basic_test.py (18 changes: 17 additions & 1 deletion)
@@ -10,8 +10,8 @@
 from typing import Optional

 import torch
-import torch.export
 import torch.nn as nn
+from torch.export import Dim

 from torch_mlir import fx

@@ -79,3 +79,19 @@ def forward(self, x):

     m = fx.export_and_import(Basic(), torch.randn(3, 4), func_name="test_net")
     print(m)
+
+@run
+# CHECK-LABEL: test_import_frozen_exported_program_with_dynamic_shapes
+# CHECK: func.func @test_net(%[[ARG0:[a-zA-Z0-9]+]]: !torch.vtensor<[?,4],f32>) -> !torch.vtensor<[?,4],f32>
+def test_import_frozen_exported_program_with_dynamic_shapes():
+    class Basic(nn.Module):
+        def __init__(self):
+            super().__init__()
+
+        def forward(self, x):
+            return torch.tanh(x)
+
+    batch = Dim("batch")
+    dynamic_shapes = {"x": {0: batch}}
+    m = fx.export_and_import(Basic(), torch.randn(3, 4), dynamic_shapes=dynamic_shapes, func_name="test_net")
+    print(m)
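
As a quick standalone check mirroring the CHECK line above (assuming torch-mlir and a PyTorch with torch.export are installed), the dynamic batch dimension should print as ? in the imported type:

import torch
from torch.export import Dim
from torch_mlir import fx


class Tanh(torch.nn.Module):
    def forward(self, x):
        return torch.tanh(x)


m = fx.export_and_import(
    Tanh(),
    torch.randn(3, 4),
    dynamic_shapes={"x": {0: Dim("batch")}},
    func_name="test_net",
)
# The batch dimension is symbolic, so the argument type should read
# !torch.vtensor<[?,4],f32> in the printed IR.
assert "!torch.vtensor<[?,4],f32>" in str(m)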
