Implement lowering of torch.aten.renorm (#3388)
Closes nod-ai/SHARK-ModelDev#689

---------

Co-authored-by: Branko Trifkovic <[email protected]>
BaneTrifa and Branko Trifkovic authored Jun 17, 2024
1 parent 59bade3 commit 676fa8c
Showing 9 changed files with 382 additions and 0 deletions.
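
For context, `aten::renorm` rescales every sub-tensor taken along `dim` whose `p`-norm exceeds `maxnorm` so that its norm becomes (up to a small epsilon) `maxnorm`, and leaves the remaining sub-tensors untouched. A minimal eager-PyTorch sketch of the semantics; the tensor values are arbitrary illustration:

import torch

x = torch.tensor([[3.0, 4.0],   # row 0 has L2 norm 5.0
                  [0.3, 0.4]])  # row 1 has L2 norm 0.5

y = x.renorm(p=2, dim=0, maxnorm=1.0)
# Row 0 is scaled down so its norm is (up to a tiny epsilon) 1.0 -> [0.6, 0.8];
# row 1 already satisfies the constraint and is left unchanged.
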
27 changes: 27 additions & 0 deletions include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -6657,6 +6657,33 @@ def Torch_AtenLayerNormOp : Torch_Op<"aten.layer_norm", [
}];
}

def Torch_AtenRenormOp : Torch_Op<"aten.renorm", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::renorm : (Tensor, Scalar, int, Scalar) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
AnyTorchScalarType:$p,
Torch_IntType:$dim,
AnyTorchScalarType:$maxnorm
);
let results = (outs
AnyTorchOptionalTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenRenormOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 4, 1);
}
void AtenRenormOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 4, 1);
}
}];
let hasVerifier = 1;
}

def Torch_AtenNormScalarOp : Torch_Op<"aten.norm.Scalar", [
AllowsTypeRefinement,
HasValueSemantics,
74 changes: 74 additions & 0 deletions lib/Dialect/Torch/IR/TorchOps.cpp
@@ -4655,6 +4655,80 @@ LogicalResult AtenNormScalarOp::verify() {
return success();
}

//===----------------------------------------------------------------------===//
// AtenRenormOp
//===----------------------------------------------------------------------===//

LogicalResult AtenRenormOp::verify() {

auto selfType = cast<BaseTensorType>(getSelf().getType());

if (!selfType.hasDtype() || !selfType.hasSizes())
return success();

auto inShape = selfType.getSizes();
int64_t selfRank = inShape.size();
auto selfDtype = selfType.getDtype();

if (!isa<mlir::Float16Type, mlir::BFloat16Type, mlir::Float32Type,
mlir::Float64Type, mlir::ComplexType>(selfDtype))
return emitOpError(
"expected a float or complex type for input tensor, but got ")
<< selfDtype;

// According to the PyTorch documentation, the input tensor needs to have at least 2 dimensions
if (selfRank <= 1)
return emitOpError("renorm: input needs at least 2 dimensions, got ")
<< selfRank << " dimensions";

// Check if argument p is valid
auto pType = getP().getType();

if (isa<mlir::ComplexType>(pType))
return emitOpError("renorm: p must be real-valued");

// The argument 'p' can be either an integer or a floating-point number,
// so we need to consider both options and check if 'p' is within the correct
// range
int64_t pInt = 1;
double_t pDouble = 1;
if (!matchPattern(getP(), m_TorchConstantInt(&pInt)) &&
!matchPattern(getP(), m_TorchConstantFloat(&pDouble)))
return success();

if (pInt <= 0 || pDouble <= 0)
return emitOpError("renorm: non-positive norm not supported");

// Check if argument maxnorm is valid
auto maxnormType = getMaxnorm().getType();
if (isa<mlir::ComplexType>(maxnormType))
return emitOpError("renorm: maxnorm must be real-valued");

// The argument 'maxnorm' can be either an integer or a floating-point number,
// so we need to consider both options and check if 'maxnorm' is within the
// correct range
int64_t maxnormInt = 0;
double_t maxnormDouble = 0;
if (!matchPattern(getMaxnorm(), m_TorchConstantInt(&maxnormInt)) &&
!matchPattern(getMaxnorm(), m_TorchConstantFloat(&maxnormDouble)))
return success();

if (maxnormInt < 0 || maxnormDouble < 0)
return emitOpError("renorm: expected maxnorm to be >= 0");

// Get the dimension
int64_t dim;
if (!matchPattern(getDim(), m_TorchConstantInt(&dim)))
return success();

// Check if dim is within the valid range
if (dim >= selfRank || dim < -selfRank)
return emitOpError("Dimension out of range (expected to be in range of [")
<< -selfRank << ", " << selfRank - 1 << "], but got " << dim;

return success();
}

//===----------------------------------------------------------------------===//
// AtenPermuteOp
//===----------------------------------------------------------------------===//
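
The verifier above mirrors the argument validation that eager PyTorch performs for `torch.renorm`. A small sketch of calls that these checks accept and reject; the shapes and values are illustrative only:

import torch

x = torch.randn(3, 4)
y = x.renorm(p=2, dim=0, maxnorm=5.0)   # valid: float input, rank >= 2, p > 0, maxnorm >= 0, dim in range

# Each of the following violates one of the constraints checked above:
# torch.randn(3).renorm(2, 0, 5.0)      # input has fewer than 2 dimensions
# x.renorm(p=-1, dim=0, maxnorm=5.0)    # non-positive norm is not supported
# x.renorm(p=2, dim=0, maxnorm=-1.0)    # maxnorm must be >= 0
# x.renorm(p=2, dim=5, maxnorm=5.0)     # dim out of range for a rank-2 tensor
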
17 changes: 17 additions & 0 deletions lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
@@ -10119,6 +10119,9 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %2 = call @__torch__.torch.jit._shape_functions.sum_mean_dim(%arg0, %0, %arg2, %1) : (!torch.list<int>, !torch.optional<list<int>>, !torch.bool, !torch.any) -> !torch.list<int>\n"
" return %2 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.renorm\"(%arg0: !torch.list<int>, %arg1: !torch.float, %arg2: !torch.int, %arg3: !torch.float) -> !torch.list<int> {\n"
" return %arg0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.norm.Scalar\"(%arg0: !torch.list<int>, %arg1: !torch.float) -> !torch.list<int> {\n"
" %false = torch.constant.bool false\n"
" %none = torch.constant.none\n"
@@ -13162,6 +13165,20 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" }\n"
" return %4 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.renorm\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.number, %arg2: !torch.int, %arg3: !torch.number) -> !torch.int {\n"
" %none = torch.constant.none\n"
" %str = torch.constant.str \"AssertionError: \"\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %1 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_integer_dtype(%0#1) : (!torch.int) -> !torch.bool\n"
" %2 = torch.aten.__not__ %1 : !torch.bool -> !torch.bool\n"
" torch.prim.If %2 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" return %0#1 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.norm.Scalar\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.number) -> !torch.int {\n"
" %true = torch.constant.bool true\n"
" %int5 = torch.constant.int 5\n"
140 changes: 140 additions & 0 deletions lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
@@ -2069,6 +2069,145 @@ class DecomposeAtenMvOp : public OpRewritePattern<AtenMvOp> {
};
} // namespace

// https://github.com/pytorch/pytorch/blob/9dec41b684a4284c4e052e295314c23f0f942fec/torch/_refs/__init__.py#L3229
// Decompose aten.renorm into aten.linalg_vector_norm plus elementwise ops
namespace {
class DecomposeAtenRenormOp : public OpRewritePattern<AtenRenormOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(AtenRenormOp op,
PatternRewriter &rewriter) const override {

Location loc = op.getLoc();
Value self = op.getSelf();
Value dim = op.getDim();
Value p = op.getP();
Value maxnorm = op.getMaxnorm();

// Prepare all necessary variables
auto ndim = getTensorRank(self);
auto resType = cast<BaseTensorType>(self.getType());

if (!resType.hasDtype() || !resType.hasSizes()) {
return rewriter.notifyMatchFailure(op,
"result should have dtype and sizes");
}

Type dtype = resType.getDtype();
if (isa<mlir::ComplexType>(dtype)) {
return rewriter.notifyMatchFailure(
op, "lowering of aten.renorm for complex inputs dtype is "
"currently unimplemented");
}

SmallVector<int64_t> inputSize(resType.getSizes());

// Convert dim from Value to int
int64_t dimInt;
if (!matchPattern(dim, m_TorchConstantInt(&dimInt)))
return rewriter.notifyMatchFailure(op,
"Unimplemented: dim not constant int");

// Define all constants
Value cstTrue = rewriter.create<ConstantBoolOp>(loc, true);
Value cstZero = rewriter.create<Torch::ConstantIntOp>(loc, 0);
Value cstOne = rewriter.create<Torch::ConstantIntOp>(loc, 1);
Value cstNone = rewriter.create<ConstantNoneOp>(loc);

// Arrange the reduce_dims list: [0, 1, ..., dim-1, dim+1, ..., ndim-1]
llvm::SmallVector<Value> reduceDimsVector;
for (u_int64_t i = 0; i < ndim; i++) {
if (i == (u_int64_t)dimInt)
continue;

Value constI = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(i));

reduceDimsVector.push_back(constI);
}

Value reduceDimsList = rewriter.create<Torch::PrimListConstructOp>(
loc,
rewriter.getType<Torch::ListType>(rewriter.getType<Torch::IntType>()),
reduceDimsVector);

// Make output shape for linalg.vector_norm operation
SmallVector<Value> inputSizeValue;
for (u_int64_t i = 0; i < inputSize.size(); i++) {
if (i != (u_int64_t)dimInt)
inputSize[i] = 1;

inputSizeValue.push_back(
rewriter.create<Torch::ConstantIntOp>(loc, inputSize[i]));
}

// Prepare arguments for linalg.vector_norm
Value dtypeValue;
Type vectorNormOutType;

if (isa<mlir::Float16Type, mlir::BFloat16Type>(dtype)) {
dtype = cast<Type>(rewriter.getF32Type());
dtypeValue = getDtypeIntValueForType(rewriter, loc, dtype);
vectorNormOutType = resType.getWithSizesAndDtype(inputSize, dtype);
} else {
dtypeValue = cstNone;
vectorNormOutType = resType.getWithSizesAndDtype(inputSize, dtype);
}

auto norm = rewriter.create<AtenLinalgVectorNormOp>(
loc, vectorNormOutType, self, p, reduceDimsList, cstTrue, dtypeValue);

// Define the epsilon constant 1e-7
mlir::FloatType f64Type = rewriter.getF64Type();
Value epsValue = rewriter.create<ConstantFloatOp>(
loc, rewriter.getFloatAttr(f64Type, 1e-7));

Value normPlusEps = rewriter.create<AtenAddScalarOp>(
loc, vectorNormOutType, norm, epsValue, cstOne);

Value maxnormTensorValue = rewriter.create<AtenFullLikeOp>(
loc, normPlusEps.getType(), normPlusEps, maxnorm, cstNone, cstNone,
cstNone, cstNone, cstNone);

// Divide maxnorm by normPlusEps
auto divideMaxnormAndNorm = rewriter.create<AtenDivTensorOp>(
loc, vectorNormOutType, maxnormTensorValue, normPlusEps);

// The next few lines correspond to this PyTorch code:
// norm_factor = torch.where(norm > maxnorm, maxnorm / (norm + eps), 1.0)
auto boolTensorType = rewriter.getType<ValueTensorType>(
cast<BaseTensorType>(vectorNormOutType).getOptionalSizes(),
rewriter.getI1Type());

Value greaterThanMaxnorm =
rewriter.create<AtenGtScalarOp>(loc, boolTensorType, norm, maxnorm);

Value cstOnetensor = rewriter.create<AtenFullLikeOp>(
loc, normPlusEps.getType(), normPlusEps, cstOne, cstNone, cstNone,
cstNone, cstNone, cstNone);

auto normFactor = rewriter.create<AtenWhereSelfOp>(
loc, vectorNormOutType, greaterThanMaxnorm, divideMaxnormAndNorm,
cstOnetensor);

// Convert norm_factor to the input dtype
Value normFactorFinal = rewriter.create<PrimsConvertElementTypeOp>(
loc, resType.getWithSizesAndDtype(inputSize, resType.getDtype()),
normFactor, getDtypeIntValueForType(rewriter, loc, resType.getDtype()));

// Multiply the input tensor by the norm factor
auto output = rewriter.create<AtenMulTensorOp>(loc, self.getType(), self,
normFactorFinal);

rewriter.replaceOpWithNewOp<AtenContiguousOp>(op, self.getType(), output,
/*memory_format*/ cstZero);

return success();
}
};
} // namespace
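
In eager-PyTorch terms, the pattern above corresponds roughly to the sketch below (following the `torch._refs` implementation linked at the top of the pattern). The `1e-7` epsilon and the float32 accumulation for f16/bf16 inputs match the C++ code; the helper name and the use of `ones_like` in place of `full_like` are illustrative simplifications:

import torch

def renorm_decomposed(self: torch.Tensor, p, dim: int, maxnorm) -> torch.Tensor:
    # Reduce over every dimension except `dim`, keeping dims for broadcasting.
    reduce_dims = [d for d in range(self.ndim) if d != dim]

    # Compute the norm in float32 when the input is f16/bf16, as the pattern does.
    acc_dtype = torch.float32 if self.dtype in (torch.float16, torch.bfloat16) else None
    norm = torch.linalg.vector_norm(self, ord=p, dim=reduce_dims,
                                    keepdim=True, dtype=acc_dtype)

    # norm_factor = maxnorm / (norm + eps) where the norm exceeds maxnorm, else 1.
    eps = 1e-7
    norm_factor = torch.where(norm > maxnorm,
                              maxnorm / (norm + eps),
                              torch.ones_like(norm))

    # Cast the factor back to the input dtype, scale, and force a contiguous result.
    return (self * norm_factor.to(self.dtype)).contiguous()

On floating-point inputs this agrees with `torch.renorm` up to the epsilon term.
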

// Decompose aten.linalg_cross into: aten.broadcast_to, aten.index_select,
// aten.add.Tensor and aten.mul.Tensor. See
// https://github.com/pytorch/pytorch/blob/ed3c256b61f05720843454a9282aa7c903da2c81/torch/_refs/linalg/__init__.py#L70.
@@ -8081,6 +8220,7 @@ class DecomposeComplexOpsPass
addPatternIfTargetOpIsIllegal<DecomposeAtenSelectIntOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenMatmulOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenMvOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenRenormOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenLinalgCrossOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenPixelShuffleOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenTOp>(patterns);
1 change: 1 addition & 0 deletions lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp
@@ -402,6 +402,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
target.addIllegalOp<AtenNormScalarOptDimOp>();
target.addIllegalOp<AtenSelectIntOp>();
target.addIllegalOp<AtenMvOp>();
target.addIllegalOp<AtenRenormOp>();
target.addIllegalOp<AtenLinalgCrossOp>();
target.addIllegalOp<AtenPixelShuffleOp>();
target.addIllegalOp<AtenTOp>();
12 changes: 12 additions & 0 deletions projects/pt1/e2e_testing/xfail_sets.py
@@ -1473,6 +1473,9 @@
"ElementwiseLogSigmoidModule_basic",
"ElementwiseHardshrinkStaticModule_basic",
"ElementwiseSoftshrinkStaticModule_basic",
"RenormModuleFloat16_basic",
"RenormModuleFloat32NegativeDim_basic",
"RenormModuleFloat32_basic",
}

STABLEHLO_CRASHING_SET = set()
@@ -1949,6 +1952,8 @@
"LinspaceOneSizeModule_basic",
"LinspaceTwoSizeModule_basic",
"TorchPrimLoopForLikeTensorArgModule_basic",
"RenormModuleFloat32NegativeDim_basic",
"RenormModuleFloat32_basic",
}

MAKE_FX_TOSA_PASS_SET = (
@@ -1982,6 +1987,8 @@
"ViewSizeDimLedByCollapsedOnesModule_basic",
"ViewSizeFromOtherTensor_basic",
"ScaledDotProductAttentionDifferentModule_basic",
"RenormModuleFloat32NegativeDim_basic",
"RenormModuleFloat32_basic",
}
) - {
### Test failing in make_fx_tosa but not in tosa
@@ -2695,6 +2702,11 @@
"IndexPutHackedTwin3DIntNonAccumulateModule_basic",
# RuntimeError: unsupported input type: Device
"PrimsIotaModule_basic",
# Error: exporting 'aten::renorm' to ONNX opset version 17 is not supported.
"RenormModuleFloat16_basic",
"RenormModuleFloat32NegativeDim_basic",
"RenormModuleFloat32_basic",
"RenormModuleFloat32DynamicDims_basic",
# Failure - unknown
"BernoulliModule_basic",
"Conv_Transpose1dModule_basic",
Additional changed file (path not shown):
@@ -1998,6 +1998,9 @@ def aten〇linalg_norm〡shape(self: List[int], ord: Optional[float] = None, dim
def aten〇frobenius_norm〇dim〡shape(self: List[int], dim: List[int], keepdim: bool = False) -> List[int]:
return upstream_shape_functions.sum_mean_dim(self, dim, keepdim, 0)

def aten〇renorm〡shape(self: List[int], p: float, dim: int, maxnorm: float) -> List[int]:
return self

def aten〇norm〇Scalar〡shape(self: List[int], p: float = 2) -> List[int]:
return upstream_shape_functions.sum_mean_dim(self, None, False, None)

@@ -4416,6 +4419,20 @@ def aten〇linalg_norm〡dtype(self_rank_dtype: Tuple[int, int], ord: Optional[U
return dtype
return aten〇std〡dtype(self_rank_dtype)

@check_dtype_function(
_check_tensors_with_the_same_dtype(
tensor_shapes=[(3,3)],
error_types={torch.bool, torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64},
p=1,
dim=0,
maxnorm=5)
)
def aten〇renorm〡dtype(self_rank_dtype: Tuple[int, int], p: Union[int, float, complex], dim: int, maxnorm: Union[int, float, complex]) -> int:
self_rank, self_dtype = self_rank_dtype
assert not is_integer_dtype(self_dtype)

return self_dtype

@check_dtype_function(
_check_tensors_with_the_same_dtype(
num_of_tensors=1,
Additional changed file (path not shown):
@@ -587,6 +587,7 @@ def emit_with_mutating_variants(key, **kwargs):
emit(
"aten::layer_norm : (Tensor, int[], Tensor?, Tensor?, float, bool) -> (Tensor)"
)
emit("aten::renorm : (Tensor, Scalar, int, Scalar) -> (Tensor)", has_verifier=True)
emit("aten::norm.Scalar : (Tensor, Scalar) -> (Tensor)", has_verifier=True)
emit("aten::norm.ScalarOpt_dim : (Tensor, Scalar?, int[], bool) -> (Tensor)")
emit(