Implement lowering of torch.aten.renorm
Branko Trifkovic committed May 28, 2024
1 parent e0a5adb commit 3b8c9ef
Showing 9 changed files with 367 additions and 0 deletions.
27 changes: 27 additions & 0 deletions include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -6604,6 +6604,33 @@ def Torch_AtenLayerNormOp : Torch_Op<"aten.layer_norm", [
}];
}

def Torch_AtenRenormOp : Torch_Op<"aten.renorm", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::renorm : (Tensor, Scalar, int, Scalar) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
AnyTorchScalarType:$p,
Torch_IntType:$dim,
AnyTorchScalarType:$maxnorm
);
let results = (outs
AnyTorchOptionalTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenRenormOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 4, 1);
}
void AtenRenormOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 4, 1);
}
}];
let hasVerifier = 1;
}

def Torch_AtenNormScalarOp : Torch_Op<"aten.norm.Scalar", [
AllowsTypeRefinement,
HasValueSemantics,
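For reference, the new op mirrors torch.renorm, which rescales every slice taken along dim whose p-norm exceeds maxnorm so that its norm equals maxnorm; slices already within the bound are left unchanged. A quick illustration in eager PyTorch (the tensor values are only illustrative):

import torch

x = torch.tensor([[1.0, 1.0, 1.0],
                  [2.0, 2.0, 2.0],
                  [3.0, 3.0, 3.0]])
# p=2, dim=0, maxnorm=5.0: the row 2-norms are ~1.73, ~3.46 and ~5.20,
# so only the last row is rescaled down to have 2-norm 5.0.
y = torch.renorm(x, 2, 0, 5.0)
print(y)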
88 changes: 88 additions & 0 deletions lib/Dialect/Torch/IR/TorchOps.cpp
@@ -4641,6 +4641,94 @@ LogicalResult AtenNormScalarOp::verify() {
return success();
}

//===----------------------------------------------------------------------===//
// AtenRenormOp
//===----------------------------------------------------------------------===//

LogicalResult AtenRenormOp::verify() {

auto selfType = cast<BaseTensorType>(getSelf().getType());

if (!selfType.hasDtype() || !selfType.hasSizes()) {
return success();
}

auto selfDtype = selfType.getDtype();
int64_t selfRank = selfType.getSizes().size();

// Check if dtype is one of those supported by renorm operation.
// ComplexType will match any torch complex types, but each float must be
// checked individually.
if (selfDtype.isa<mlir::ComplexType>()) {
return emitOpError(
"lowering for complex type input tensor is currently unsupported");
}

if (!selfDtype.isa<mlir::Float16Type, mlir::BFloat16Type, mlir::Float32Type,
mlir::Float64Type, mlir::ComplexType>()) {
return emitOpError(
"expected a float or complex type for input tensor, but got ")
<< selfDtype;
}

// According to the PyTorch documentation, the input tensor needs to be at
// least rank 2.
if (selfRank <= 1) {
return emitOpError("renorm: input needs at least 2 dimensions, got ")
<< selfRank << " dimensions";
}

// Check if argument p is valid
auto pType = getP().getType();
double p;

if (pType.isa<mlir::ComplexType>()) {
return emitOpError("renorm: p must be real-valued");
}

if (!matchPattern(getP(), m_TorchConstantFloat(&p))) {
return emitOpError("renorm: p must be real-valued");
}

if (p <= 0) {
return emitOpError("renorm: non-positive norm not supported");
}

// Check if argument maxnorm is valid
auto maxnormType = getMaxnorm().getType();
double maxnorm;
if (maxnormType.isa<mlir::ComplexType>()) {
return emitOpError("renorm: maxnorm must be real-valued");
}

if (!matchPattern(getMaxnorm(), m_TorchConstantFloat(&maxnorm))) {
return emitOpError("renorm: maxnorm must be real-valued");
}

if (maxnorm < 0) {
return emitOpError("renorm: expected maxnorm to be >= 0 but got ")
<< maxnorm;
}

// Get the dimension
int64_t dim;
if (!matchPattern(getDim(), m_TorchConstantInt(&dim))) {
return emitOpError("dim must be a constant int");
}

// Check that dim is in the valid range
if (dim >= selfRank || dim < -selfRank) {
return emitOpError("Dimension out of range (expected to be in range of [")
<< -selfRank << ", " << selfRank - 1 << "], but got " << dim;
}

// Canonicalize dimension if necessary
if (dim < 0) {
dim += selfRank;
}

return success();
}

//===----------------------------------------------------------------------===//
// AtenPermuteOp
//===----------------------------------------------------------------------===//
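The renorm verifier added above mirrors the errors eager PyTorch raises for invalid inputs. A small sanity check against stock PyTorch (illustrative, exact error texts may differ):

import torch

# Rank-1 inputs are rejected (the `selfRank <= 1` check above).
try:
    torch.renorm(torch.ones(3), 2, 0, 1.0)
except RuntimeError as e:
    print(e)  # e.g. "renorm: input needs at least 2 dimensions, ..."

# Integer inputs are rejected (the float/complex dtype check above).
try:
    torch.renorm(torch.ones(2, 2, dtype=torch.int64), 2, 0, 1.0)
except RuntimeError as e:
    print(e)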
17 changes: 17 additions & 0 deletions lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
@@ -9689,6 +9689,9 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %2 = call @__torch__.torch.jit._shape_functions.sum_mean_dim(%arg0, %0, %arg2, %1) : (!torch.list<int>, !torch.optional<list<int>>, !torch.bool, !torch.any) -> !torch.list<int>\n"
" return %2 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.renorm\"(%arg0: !torch.list<int>, %arg1: !torch.float, %arg2: !torch.int, %arg3: !torch.float) -> !torch.list<int> {\n"
" return %arg0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.norm.Scalar\"(%arg0: !torch.list<int>, %arg1: !torch.float) -> !torch.list<int> {\n"
" %false = torch.constant.bool false\n"
" %none = torch.constant.none\n"
@@ -12694,6 +12697,20 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" }\n"
" return %4 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.renorm\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.number, %arg2: !torch.int, %arg3: !torch.number) -> !torch.int {\n"
" %none = torch.constant.none\n"
" %str = torch.constant.str \"AssertionError: \"\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %1 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_integer_dtype(%0#1) : (!torch.int) -> !torch.bool\n"
" %2 = torch.aten.__not__ %1 : !torch.bool -> !torch.bool\n"
" torch.prim.If %2 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" return %0#1 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.norm.Scalar\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.number) -> !torch.int {\n"
" %true = torch.constant.bool true\n"
" %int5 = torch.constant.int 5\n"
135 changes: 135 additions & 0 deletions lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
@@ -2088,6 +2088,140 @@ class DecomposeAtenMvOp : public OpRewritePattern<AtenMvOp> {
};
} // namespace

// https://github.com/pytorch/pytorch/blob/9dec41b684a4284c4e052e295314c23f0f942fec/torch/_refs/__init__.py#L3229
// Decompose aten.renorm into aten.linalg_vector_norm plus elementwise ops.
namespace {
class DecomposeAtenRenormOp : public OpRewritePattern<AtenRenormOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(AtenRenormOp op,
PatternRewriter &rewriter) const override {

Location loc = op.getLoc();
Value self = op.getSelf();
Value dim = op.getDim();
Value p = op.getP();
Value maxnorm = op.getMaxnorm();

auto resType = cast<BaseTensorType>(self.getType());
if (!resType.hasDtype() || !resType.hasSizes())
return rewriter.notifyMatchFailure(
op, "expected input tensor to have known dtype and sizes");

Type dtype = resType.getDtype();
int64_t ndim = resType.getSizes().size();

ArrayRef<int64_t> inputSizeArrayRef = resType.getSizes();
SmallVector<int64_t> inputSize(inputSizeArrayRef.begin(),
inputSizeArrayRef.end());

// Convert dim from Value to int
int64_t dimInt;
if (!matchPattern(dim, m_TorchConstantInt(&dimInt)))
return rewriter.notifyMatchFailure(op,
"Unimplemented: dim not constant int");

// Define all constants
Value cstTrue = rewriter.create<ConstantBoolOp>(loc, true);
Value cstZero = rewriter.create<Torch::ConstantIntOp>(loc, 0);
Value cstOne = rewriter.create<Torch::ConstantIntOp>(loc, 1);
Value cstNone = rewriter.create<ConstantNoneOp>(loc);

// Arrange the reduce_dims list: [0, 1, ..., dim-1, dim+1, ..., ndim-1]
llvm::SmallVector<Value> reduceDimsVector;
for (int64_t i = 0; i < ndim; i++) {
if (i == dimInt)
continue;

reduceDimsVector.push_back(rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(i)));
}

Value reduceDimsList = rewriter.create<Torch::PrimListConstructOp>(
loc,
rewriter.getType<Torch::ListType>(rewriter.getType<Torch::IntType>()),
reduceDimsVector);

// Build the output shape for the linalg.vector_norm operation: every
// reduced dimension collapses to 1, only `dim` keeps its size.
SmallVector<Value> inputSizeValue;
for (size_t i = 0; i < inputSize.size(); i++) {
if ((int64_t)i != dimInt) {
inputSize[i] = 1;
}
inputSizeValue.push_back(
rewriter.create<Torch::ConstantIntOp>(loc, inputSize[i]));
}

// Prepare arguments for linalg.vector_norm
Value dtypeValue;
Type vectorNormOutType;

if (dtype.isa<mlir::Float16Type, mlir::BFloat16Type>()) {
dtype = cast<Type>(rewriter.getF32Type());
dtypeValue = getDtypeIntValueForType(rewriter, loc, dtype);
vectorNormOutType = resType.getWithSizesAndDtype(inputSize, dtype);
} else {
dtypeValue = cstNone;
vectorNormOutType = resType.getWithSizesAndDtype(inputSize, dtype);
}

auto norm = rewriter.create<AtenLinalgVectorNormOp>(
loc, vectorNormOutType, self, p, reduceDimsList, cstTrue, dtypeValue);

// Define epsilon constant 1e-7
mlir::FloatType f64Type = rewriter.getF64Type();
Value epsValue = rewriter.create<ConstantFloatOp>(
loc, rewriter.getFloatAttr(f64Type, 1e-7));

Value normPlusEps = rewriter.create<AtenAddScalarOp>(
loc, vectorNormOutType, norm, epsValue, cstOne);

Value sizeList = rewriter.create<PrimListConstructOp>(
loc, Torch::ListType::get(Torch::IntType::get(op.getContext())),
inputSizeValue);

Value maxnormTensorValue =
createInitTensor(rewriter, loc, cast<BaseTensorType>(vectorNormOutType),
maxnorm, sizeList);

// Compute maxnorm / (norm + eps)
auto divideMaxnormAndNorm = rewriter.create<AtenDivTensorOp>(
loc, vectorNormOutType, maxnormTensorValue, normPlusEps);

// The next few lines correspond to this PyTorch code:
// norm_factor = torch.where(norm > maxnorm, maxnorm / (norm + eps), 1.0)
auto boolTensorType = rewriter.getType<ValueTensorType>(
cast<BaseTensorType>(vectorNormOutType).getOptionalSizes(),
rewriter.getI1Type());

Value greaterThanMaxnorm =
rewriter.create<AtenGtScalarOp>(loc, boolTensorType, norm, maxnorm);

Value cstOnetensor =
createInitTensor(rewriter, loc, cast<BaseTensorType>(vectorNormOutType),
cstOne, sizeList);

auto normFactor = rewriter.create<AtenWhereSelfOp>(
loc, vectorNormOutType, greaterThanMaxnorm, divideMaxnormAndNorm,
cstOnetensor);

// Convert norm_factor back to the input dtype
Value normFactorFinal = rewriter.create<PrimsConvertElementTypeOp>(
loc, resType.getWithSizesAndDtype(inputSize, resType.getDtype()),
normFactor, getDtypeIntValueForType(rewriter, loc, resType.getDtype()));

// Multiply input tensor with norm factor
auto output = rewriter.create<AtenMulTensorOp>(loc, self.getType(), self,
normFactorFinal);

rewriter.replaceOpWithNewOp<AtenContiguousOp>(op, self.getType(), output,
/*memory_format*/ cstZero);

return success();
}
};
} // namespace

// Decompose aten.linalg_cross into: aten.broadcast_to, aten.index_select,
// aten.add.Tensor and aten.mull.Tensor. See
// https://github.com/pytorch/pytorch/blob/ed3c256b61f05720843454a9282aa7c903da2c81/torch/_refs/linalg/__init__.py#L70.
@@ -7979,6 +8113,7 @@ class DecomposeComplexOpsPass
addPatternIfTargetOpIsIllegal<DecomposeAtenSelectIntOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenMatmulOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenMvOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenRenormOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenLinalgCrossOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenPixelShuffleOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenTOp>(patterns);
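For readability, here is a rough eager-mode sketch of what the DecomposeAtenRenormOp pattern added above emits, following the torch._refs implementation linked at the top of the class. The helper name and the explicit float32 accumulation for half-precision inputs are only illustrations of the C++ logic, not an existing PyTorch API:

import torch

def renorm_decomposition_sketch(self, p, dim, maxnorm):
    # Reduce over every dimension except `dim`, keeping dims so the norm
    # broadcasts against the input.
    reduce_dims = [d for d in range(self.ndim) if d != dim]
    # Half-precision inputs are normed in float32, as in the pattern above.
    acc_dtype = (torch.float32
                 if self.dtype in (torch.float16, torch.bfloat16) else None)
    norm = torch.linalg.vector_norm(self, ord=p, dim=reduce_dims,
                                    keepdim=True, dtype=acc_dtype)
    eps = 1e-7
    norm_factor = torch.where(norm > maxnorm, maxnorm / (norm + eps),
                              torch.ones_like(norm))
    # Rescale the offending slices and keep the result contiguous.
    return (self * norm_factor.to(self.dtype)).contiguous()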
1 change: 1 addition & 0 deletions lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp
@@ -403,6 +403,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
target.addIllegalOp<AtenNormScalarOptDimOp>();
target.addIllegalOp<AtenSelectIntOp>();
target.addIllegalOp<AtenMvOp>();
target.addIllegalOp<AtenRenormOp>();
target.addIllegalOp<AtenLinalgCrossOp>();
target.addIllegalOp<AtenPixelShuffleOp>();
target.addIllegalOp<AtenTOp>();
11 changes: 11 additions & 0 deletions projects/pt1/e2e_testing/xfail_sets.py
@@ -1446,6 +1446,9 @@
"ElementwiseLogSigmoidModule_basic",
"ElementwiseHardshrinkStaticModule_basic",
"ElementwiseSoftshrinkStaticModule_basic",
"RenormModuleFloat16_basic",
"RenormModuleFloat32NegativeDim_basic",
"RenormModuleFloat32_basic",
}

STABLEHLO_CRASHING_SET = set()
@@ -1913,6 +1916,8 @@
"LinspaceOneSizeModule_basic",
"LinspaceTwoSizeModule_basic",
"TorchPrimLoopForLikeTensorArgModule_basic",
"RenormModuleFloat32NegativeDim_basic",
"RenormModuleFloat32_basic",
}

MAKE_FX_TOSA_PASS_SET = (
Expand Down Expand Up @@ -1944,6 +1949,8 @@
"ViewSizeDimLedAndFollowedByCollapsedOnesModule_basic",
"ViewSizeDimLedByCollapsedOnesModule_basic",
"ViewSizeFromOtherTensor_basic",
"RenormModuleFloat32NegativeDim_basic",
"RenormModuleFloat32_basic",
}
) - {
### Test failing in make_fx_tosa but not in tosa
@@ -2649,6 +2656,10 @@
"IndexPutHackedTwin3DIntNonAccumulateModule_basic",
# RuntimeError: unsupported input type: Device
"PrimsIotaModule_basic",
# Error: 'aten::renorm' to ONNX opset version 17 is not supported.
"RenormModuleFloat16_basic",
"RenormModuleFloat32NegativeDim_basic",
"RenormModuleFloat32_basic",
# Failure - unknown
"BernoulliModule_basic",
"Conv2dWithPaddingDilationStrideStaticModule_depthwise_multiplier",
@@ -1875,6 +1875,9 @@ def aten〇linalg_norm〡shape(self: List[int], ord: Optional[float] = None, dim
def aten〇frobenius_norm〇dim〡shape(self: List[int], dim: List[int], keepdim: bool = False) -> List[int]:
return upstream_shape_functions.sum_mean_dim(self, dim, keepdim, 0)

def aten〇renorm〡shape(self: List[int], p: float, dim: int, maxnorm: float) -> List[int]:
return self

def aten〇norm〇Scalar〡shape(self: List[int], p: float = 2) -> List[int]:
return upstream_shape_functions.sum_mean_dim(self, None, False, None)

@@ -4262,6 +4265,20 @@ def aten〇linalg_norm〡dtype(self_rank_dtype: Tuple[int, int], ord: Optional[U
return dtype
return aten〇std〡dtype(self_rank_dtype)

@check_dtype_function(
_check_tensors_with_the_same_dtype(
tensor_shapes=[(3,3)],
error_types={torch.bool, torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64},
p=1,
dim=0,
maxnorm=5)
)
def aten〇renorm〡dtype(self_rank_dtype: Tuple[int, int], p: Union[int, float, complex], dim: int, maxnorm: Union[int, float, complex]) -> int:
self_rank, self_dtype = self_rank_dtype
assert not is_integer_dtype(self_dtype)

return self_dtype

@check_dtype_function(
_check_tensors_with_the_same_dtype(
num_of_tensors=1,
@@ -586,6 +586,7 @@ def emit_with_mutating_variants(key, **kwargs):
emit(
"aten::layer_norm : (Tensor, int[], Tensor?, Tensor?, float, bool) -> (Tensor)"
)
emit("aten::renorm : (Tensor, Scalar, int, Scalar) -> (Tensor)", has_verifier=True)
emit("aten::norm.Scalar : (Tensor, Scalar) -> (Tensor)", has_verifier=True)
emit("aten::norm.ScalarOpt_dim : (Tensor, Scalar?, int[], bool) -> (Tensor)")
emit(
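The remaining changed file, not expanded in this view, presumably adds the RenormModule e2e test cases referenced in xfail_sets.py above. A hypothetical sketch of what such a test case typically looks like in the torch-mlir e2e suite; the shapes, p, dim, and maxnorm values are illustrative, not the actual test contents:

import torch
from torch_mlir_e2e_test.framework import TestUtils
from torch_mlir_e2e_test.registry import register_test_case
from torch_mlir_e2e_test.annotations import annotate_args, export


class RenormModuleFloat32(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.p = 2
        self.dim = 1
        self.maxnorm = 10

    @export
    @annotate_args([None, ([3, 3], torch.float32, True)])
    def forward(self, x):
        return torch.ops.aten.renorm(x, self.p, self.dim, self.maxnorm)


@register_test_case(module_factory=lambda: RenormModuleFloat32())
def RenormModuleFloat32_basic(module, tu: TestUtils):
    module.forward(tu.rand(3, 3))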
