Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement lowering of torch.aten.renorm #3388

Merged
merged 6 commits into from
Jun 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 27 additions & 0 deletions include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -6657,6 +6657,33 @@ def Torch_AtenLayerNormOp : Torch_Op<"aten.layer_norm", [
}];
}

def Torch_AtenRenormOp : Torch_Op<"aten.renorm", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::renorm : (Tensor, Scalar, int, Scalar) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
AnyTorchScalarType:$p,
Torch_IntType:$dim,
AnyTorchScalarType:$maxnorm
);
let results = (outs
AnyTorchOptionalTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenRenormOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 4, 1);
}
void AtenRenormOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 4, 1);
}
}];
let hasVerifier = 1;
}

def Torch_AtenNormScalarOp : Torch_Op<"aten.norm.Scalar", [
AllowsTypeRefinement,
HasValueSemantics,
Expand Down
74 changes: 74 additions & 0 deletions lib/Dialect/Torch/IR/TorchOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4655,6 +4655,80 @@ LogicalResult AtenNormScalarOp::verify() {
return success();
}

//===----------------------------------------------------------------------===//
// AtenRenormOp
//===----------------------------------------------------------------------===//

LogicalResult AtenRenormOp::verify() {

auto selfType = cast<BaseTensorType>(getSelf().getType());

if (!selfType.hasDtype() || !selfType.hasSizes())
return success();
BaneTrifa marked this conversation as resolved.
Show resolved Hide resolved

auto inShape = selfType.getSizes();
int64_t selfRank = inShape.size();
auto selfDtype = selfType.getDtype();

if (!isa<mlir::Float16Type, mlir::BFloat16Type, mlir::Float32Type,
mlir::Float64Type, mlir::ComplexType>(selfDtype))
return emitOpError(
"expected a float or complex type for input tensor, but got ")
<< selfDtype;

// According to the Pytoch documentation tensor need to be at least rank 2
if (selfRank <= 1)
return emitOpError("renorm: input needs at least 2 dimensions, got ")
<< selfRank << " dimensions";

// Check if argument p is valid
auto pType = getP().getType();

if (isa<mlir::ComplexType>(pType))
return emitOpError("renorm: p must be real-valued");

// The argument 'p' can be either an integer or a floating-point number,
// so we need to consider both options and check if 'p' is within the correct
// range
int64_t pInt = 1;
double_t pDouble = 1;
if (!matchPattern(getP(), m_TorchConstantInt(&pInt)) &&
!matchPattern(getP(), m_TorchConstantFloat(&pDouble)))
return success();

if (pInt <= 0 || pDouble <= 0)
return emitOpError("renorm: non-positive norm not supported");

// Check if argument maxnorm is valid
auto maxnormType = getMaxnorm().getType();
if (isa<mlir::ComplexType>(maxnormType))
return emitOpError("renorm: maxnorm must be real-valued");

// The argument 'maxnorm' can be either an integer or a floating-point number,
// so we need to consider both options and check if 'maxnorm' is within the
// correct range
int64_t maxnormInt = 0;
double_t maxnormDouble = 0;
if (!matchPattern(getMaxnorm(), m_TorchConstantInt(&maxnormInt)) &&
!matchPattern(getMaxnorm(), m_TorchConstantFloat(&maxnormDouble)))
return success();

if (maxnormInt < 0 || maxnormDouble < 0)
return emitOpError("renorm: expected maxnorm to be >= 0");

// Get the dimension
int64_t dim;
if (!matchPattern(getDim(), m_TorchConstantInt(&dim)))
return success();

// check if is dim is in the correct range
if (dim >= selfRank || dim < -selfRank)
return emitOpError("Dimension out of range (expected to be in range of [")
<< -selfRank << ", " << selfRank - 1 << "], but got " << dim;

return success();
}

//===----------------------------------------------------------------------===//
// AtenPermuteOp
//===----------------------------------------------------------------------===//
Expand Down
17 changes: 17 additions & 0 deletions lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10113,6 +10113,9 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %2 = call @__torch__.torch.jit._shape_functions.sum_mean_dim(%arg0, %0, %arg2, %1) : (!torch.list<int>, !torch.optional<list<int>>, !torch.bool, !torch.any) -> !torch.list<int>\n"
" return %2 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.renorm\"(%arg0: !torch.list<int>, %arg1: !torch.float, %arg2: !torch.int, %arg3: !torch.float) -> !torch.list<int> {\n"
" return %arg0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.norm.Scalar\"(%arg0: !torch.list<int>, %arg1: !torch.float) -> !torch.list<int> {\n"
" %false = torch.constant.bool false\n"
" %none = torch.constant.none\n"
Expand Down Expand Up @@ -13150,6 +13153,20 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" }\n"
" return %4 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.renorm\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.number, %arg2: !torch.int, %arg3: !torch.number) -> !torch.int {\n"
" %none = torch.constant.none\n"
" %str = torch.constant.str \"AssertionError: \"\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %1 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_integer_dtype(%0#1) : (!torch.int) -> !torch.bool\n"
" %2 = torch.aten.__not__ %1 : !torch.bool -> !torch.bool\n"
" torch.prim.If %2 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" return %0#1 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.norm.Scalar\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.number) -> !torch.int {\n"
" %true = torch.constant.bool true\n"
" %int5 = torch.constant.int 5\n"
Expand Down
140 changes: 140 additions & 0 deletions lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2068,6 +2068,145 @@ class DecomposeAtenMvOp : public OpRewritePattern<AtenMvOp> {
};
} // namespace

// https://github.com/pytorch/pytorch/blob/9dec41b684a4284c4e052e295314c23f0f942fec/torch/_refs/__init__.py#L3229
// Decompose aten.renorm into: linalg_vector_norm
namespace {
class DecomposeAtenRenormOp : public OpRewritePattern<AtenRenormOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(AtenRenormOp op,
PatternRewriter &rewriter) const override {

Location loc = op.getLoc();
Value self = op.getSelf();
Value dim = op.getDim();
Value p = op.getP();
Value maxnorm = op.getMaxnorm();

// Prepare all necessary variables
auto ndim = getTensorRank(self);
auto resType = cast<BaseTensorType>(self.getType());

if (!resType.hasDtype() || !resType.hasSizes()) {
return rewriter.notifyMatchFailure(op,
"result should have dtype and sizes");
}

Type dtype = resType.getDtype();
if (isa<mlir::ComplexType>(dtype)) {
return rewriter.notifyMatchFailure(
op, "lowering of aten.renorm for complex inputs dtype is "
"currently unimplemented");
}

SmallVector<int64_t> inputSize(resType.getSizes());

// Convert dim from Value to int
int64_t dimInt;
if (!matchPattern(dim, m_TorchConstantInt(&dimInt)))
return rewriter.notifyMatchFailure(op,
"Unimplemented: dim not constant int");

// Define all constants
Value cstTrue = rewriter.create<ConstantBoolOp>(loc, true);
Value cstZero = rewriter.create<Torch::ConstantIntOp>(loc, 0);
Value cstOne = rewriter.create<Torch::ConstantIntOp>(loc, 1);
Value cstNone = rewriter.create<ConstantNoneOp>(loc);

// Arragne reduce_dims tensor (vector), [0, 1, ... , dim-1, dim+1, ... ,
// ndim-1]
llvm::SmallVector<Value> reduceDimsVector;
for (u_int64_t i = 0; i < ndim; i++) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This does not build on MSVC. The portable type name (also used in other parts of this file) is uint64_t.

torch-mlir\lib\Dialect\Torch\Transforms\DecomposeComplexOps.cpp(2420): error C2065: 'u_int64_t': undeclared identifier

Sent a fix: #3519

if (i == (u_int64_t)dimInt)
continue;

Value constI = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(i));

reduceDimsVector.push_back(constI);
}

Value reduceDimsList = rewriter.create<Torch::PrimListConstructOp>(
loc,
rewriter.getType<Torch::ListType>(rewriter.getType<Torch::IntType>()),
reduceDimsVector);

// Make output shape for linalg.vector_norm operation
SmallVector<Value> inputSizeValue;
for (u_int64_t i = 0; i < inputSize.size(); i++) {
if (i != (u_int64_t)dimInt)
inputSize[i] = 1;

inputSizeValue.push_back(
rewriter.create<Torch::ConstantIntOp>(loc, inputSize[i]));
}

// Prepare arguments for linalg.vector_norm
Value dtypeValue;
Type vectorNormOutType;

if (isa<mlir::Float16Type, mlir::BFloat16Type>(dtype)) {
dtype = cast<Type>(rewriter.getF32Type());
BaneTrifa marked this conversation as resolved.
Show resolved Hide resolved
dtypeValue = getDtypeIntValueForType(rewriter, loc, dtype);
vectorNormOutType = resType.getWithSizesAndDtype(inputSize, dtype);
} else {
dtypeValue = cstNone;
vectorNormOutType = resType.getWithSizesAndDtype(inputSize, dtype);
}

auto norm = rewriter.create<AtenLinalgVectorNormOp>(
loc, vectorNormOutType, self, p, reduceDimsList, cstTrue, dtypeValue);

// Define epsiolon constant 10^-7
mlir::FloatType f64Type = rewriter.getF64Type();
Value epsValue = rewriter.create<ConstantFloatOp>(
loc, rewriter.getFloatAttr(f64Type, 1e-7));

Value normPlusEps = rewriter.create<AtenAddScalarOp>(
loc, vectorNormOutType, norm, epsValue, cstOne);

Value maxnormTensorValue = rewriter.create<AtenFullLikeOp>(
loc, normPlusEps.getType(), normPlusEps, maxnorm, cstNone, cstNone,
cstNone, cstNone, cstNone);

// Divide maxnorm and normPlusEps
auto divideMaxnormAndNorm = rewriter.create<AtenDivTensorOp>(
loc, vectorNormOutType, maxnormTensorValue, normPlusEps);

// Next few lines corespond to this pythorch code: norm_factor =
// torch.where(norm > maxnorm, maxnorm / (norm + eps), 1.0)
auto boolTensorType = rewriter.getType<ValueTensorType>(
cast<BaseTensorType>(vectorNormOutType).getOptionalSizes(),
BaneTrifa marked this conversation as resolved.
Show resolved Hide resolved
rewriter.getI1Type());

Value greaterThanMaxnorm =
rewriter.create<AtenGtScalarOp>(loc, boolTensorType, norm, maxnorm);

Value cstOnetensor = rewriter.create<AtenFullLikeOp>(
loc, normPlusEps.getType(), normPlusEps, cstOne, cstNone, cstNone,
cstNone, cstNone, cstNone);

auto normFactor = rewriter.create<AtenWhereSelfOp>(
loc, vectorNormOutType, greaterThanMaxnorm, divideMaxnormAndNorm,
cstOnetensor);

// Converte norm_factor to input dtype
Value normFactorFinal = rewriter.create<PrimsConvertElementTypeOp>(
loc, resType.getWithSizesAndDtype(inputSize, resType.getDtype()),
normFactor, getDtypeIntValueForType(rewriter, loc, resType.getDtype()));

// Multiply input tensor with norm factor
auto output = rewriter.create<AtenMulTensorOp>(loc, self.getType(), self,
normFactorFinal);

rewriter.replaceOpWithNewOp<AtenContiguousOp>(op, self.getType(), output,
/*memory_format*/ cstZero);

return success();
}
};
} // namespace

// Decompose aten.linalg_cross into: aten.broadcast_to, aten.index_select,
// aten.add.Tensor and aten.mull.Tensor. See
// https://github.com/pytorch/pytorch/blob/ed3c256b61f05720843454a9282aa7c903da2c81/torch/_refs/linalg/__init__.py#L70.
Expand Down Expand Up @@ -8080,6 +8219,7 @@ class DecomposeComplexOpsPass
addPatternIfTargetOpIsIllegal<DecomposeAtenSelectIntOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenMatmulOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenMvOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenRenormOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenLinalgCrossOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenPixelShuffleOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenTOp>(patterns);
Expand Down
1 change: 1 addition & 0 deletions lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -402,6 +402,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
target.addIllegalOp<AtenNormScalarOptDimOp>();
target.addIllegalOp<AtenSelectIntOp>();
target.addIllegalOp<AtenMvOp>();
target.addIllegalOp<AtenRenormOp>();
target.addIllegalOp<AtenLinalgCrossOp>();
target.addIllegalOp<AtenPixelShuffleOp>();
target.addIllegalOp<AtenTOp>();
Expand Down
12 changes: 12 additions & 0 deletions projects/pt1/e2e_testing/xfail_sets.py
Original file line number Diff line number Diff line change
Expand Up @@ -1473,6 +1473,9 @@
"ElementwiseLogSigmoidModule_basic",
"ElementwiseHardshrinkStaticModule_basic",
"ElementwiseSoftshrinkStaticModule_basic",
"RenormModuleFloat16_basic",
"RenormModuleFloat32NegativeDim_basic",
"RenormModuleFloat32_basic",
}

STABLEHLO_CRASHING_SET = set()
Expand Down Expand Up @@ -1949,6 +1952,8 @@
"LinspaceOneSizeModule_basic",
"LinspaceTwoSizeModule_basic",
"TorchPrimLoopForLikeTensorArgModule_basic",
"RenormModuleFloat32NegativeDim_basic",
"RenormModuleFloat32_basic",
}

MAKE_FX_TOSA_PASS_SET = (
Expand Down Expand Up @@ -1982,6 +1987,8 @@
"ViewSizeDimLedByCollapsedOnesModule_basic",
"ViewSizeFromOtherTensor_basic",
"ScaledDotProductAttentionDifferentModule_basic",
"RenormModuleFloat32NegativeDim_basic",
"RenormModuleFloat32_basic",
}
) - {
### Test failing in make_fx_tosa but not in tosa
Expand Down Expand Up @@ -2690,6 +2697,11 @@
"IndexPutHackedTwin3DIntNonAccumulateModule_basic",
# RuntimeError: unsupported input type: Device
"PrimsIotaModule_basic",
# Error: 'aten::renorm' to ONNX opset version 17 is not supported.
"RenormModuleFloat16_basic",
"RenormModuleFloat32NegativeDim_basic",
"RenormModuleFloat32_basic",
"RenormModuleFloat32DynamicDims_basic",
# Failure - unknown
"BernoulliModule_basic",
"Conv_Transpose1dModule_basic",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1990,6 +1990,9 @@ def aten〇linalg_norm〡shape(self: List[int], ord: Optional[float] = None, dim
def aten〇frobenius_norm〇dim〡shape(self: List[int], dim: List[int], keepdim: bool = False) -> List[int]:
return upstream_shape_functions.sum_mean_dim(self, dim, keepdim, 0)

def aten〇renorm〡shape(self: List[int], p: float, dim: int, maxnorm: float) -> List[int]:
return self

def aten〇norm〇Scalar〡shape(self: List[int], p: float = 2) -> List[int]:
return upstream_shape_functions.sum_mean_dim(self, None, False, None)

Expand Down Expand Up @@ -4401,6 +4404,20 @@ def aten〇linalg_norm〡dtype(self_rank_dtype: Tuple[int, int], ord: Optional[U
return dtype
return aten〇std〡dtype(self_rank_dtype)

@check_dtype_function(
_check_tensors_with_the_same_dtype(
tensor_shapes=[(3,3)],
error_types={torch.bool, torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64},
p=1,
dim=0,
maxnorm=5)
)
def aten〇renorm〡dtype(self_rank_dtype: Tuple[int, int], p: Union[int, float, complex], dim: int, maxnorm: Union[int, float, complex]) -> int:
self_rank, self_dtype = self_rank_dtype
assert not is_integer_dtype(self_dtype)

return self_dtype

@check_dtype_function(
_check_tensors_with_the_same_dtype(
num_of_tensors=1,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -587,6 +587,7 @@ def emit_with_mutating_variants(key, **kwargs):
emit(
"aten::layer_norm : (Tensor, int[], Tensor?, Tensor?, float, bool) -> (Tensor)"
)
emit("aten::renorm : (Tensor, Scalar, int, Scalar) -> (Tensor)", has_verifier=True)
emit("aten::norm.Scalar : (Tensor, Scalar) -> (Tensor)", has_verifier=True)
emit("aten::norm.ScalarOpt_dim : (Tensor, Scalar?, int[], bool) -> (Tensor)")
emit(
Expand Down
Loading
Loading